Merge pull request #67 from kayousterhout/remove_tsl

Removed TaskSchedulerListener interface.

The interface was used only by the DAG scheduler, so there was no need to define a
separate interface, and the naming made the code very confusing to read: "listener"
referred to the DAG scheduler, rather than to SparkListeners, which implement a
nearly-identical interface but serve a different function.
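For reference, here is a minimal, self-contained sketch of the wiring this change produces. The stub types and the `StubScheduler` class are hypothetical stand-ins (not the real Spark classes); `setDAGScheduler` and the upcall method names come from the diff below. The TaskScheduler now holds a DAGScheduler reference directly and makes its upcalls on it, instead of going through a separate TaskSchedulerListener.

```scala
// Hypothetical stubs standing in for Spark's Task[_] and TaskInfo.
case class Task(id: Int)
case class TaskInfo(index: Int)

trait TaskScheduler {
  // Set the DAG scheduler for upcalls; set before any tasks are submitted.
  def setDAGScheduler(dagScheduler: DAGScheduler): Unit
}

class DAGScheduler(taskSched: TaskScheduler) {
  taskSched.setDAGScheduler(this)   // was: taskSched.setListener(this)

  // Formerly overrides of TaskSchedulerListener; now ordinary methods on DAGScheduler.
  def taskStarted(task: Task, info: TaskInfo): Unit =
    println(s"task ${task.id} started (index ${info.index})")
  def executorLost(execId: String): Unit =
    println(s"executor $execId lost")
}

// Hypothetical scheduler implementation mirroring ClusterScheduler/LocalScheduler.
class StubScheduler extends TaskScheduler {
  var dagScheduler: DAGScheduler = null   // was: var listener: TaskSchedulerListener
  override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {
    this.dagScheduler = dagScheduler
  }
  // Upcalls now go straight to the DAGScheduler.
  def reportStart(task: Task, info: TaskInfo): Unit =
    dagScheduler.taskStarted(task, info)
}
```

With this shape, constructing `new DAGScheduler(scheduler)` performs the registration in the constructor, which mirrors how the real DAGScheduler now calls `taskSched.setDAGScheduler(this)`.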

@mateiz - is there a reason for this interface that I'm missing?
Matei Zaharia 2013-10-17 11:12:28 -07:00
commit cf64f63f8a
9 changed files with 63 additions and 105 deletions

View file

@@ -55,20 +55,20 @@ class DAGScheduler(
     mapOutputTracker: MapOutputTracker,
     blockManagerMaster: BlockManagerMaster,
     env: SparkEnv)
-  extends TaskSchedulerListener with Logging {
+  extends Logging {
   def this(taskSched: TaskScheduler) {
     this(taskSched, SparkEnv.get.mapOutputTracker, SparkEnv.get.blockManager.master, SparkEnv.get)
   }
-  taskSched.setListener(this)
+  taskSched.setDAGScheduler(this)
   // Called by TaskScheduler to report task's starting.
-  override def taskStarted(task: Task[_], taskInfo: TaskInfo) {
+  def taskStarted(task: Task[_], taskInfo: TaskInfo) {
     eventQueue.put(BeginEvent(task, taskInfo))
   }
   // Called by TaskScheduler to report task completions or failures.
-  override def taskEnded(
+  def taskEnded(
       task: Task[_],
       reason: TaskEndReason,
       result: Any,
@@ -79,18 +79,18 @@ class DAGScheduler(
   }
   // Called by TaskScheduler when an executor fails.
-  override def executorLost(execId: String) {
+  def executorLost(execId: String) {
     eventQueue.put(ExecutorLost(execId))
   }
   // Called by TaskScheduler when a host is added
-  override def executorGained(execId: String, host: String) {
+  def executorGained(execId: String, host: String) {
     eventQueue.put(ExecutorGained(execId, host))
   }
   // Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or
   // cancellation of the job itself.
-  override def taskSetFailed(taskSet: TaskSet, reason: String) {
+  def taskSetFailed(taskSet: TaskSet, reason: String) {
     eventQueue.put(TaskSetFailed(taskSet, reason))
   }

View file

@@ -24,8 +24,7 @@ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
  * Each TaskScheduler schedulers task for a single SparkContext.
  * These schedulers get sets of tasks submitted to them from the DAGScheduler for each stage,
  * and are responsible for sending the tasks to the cluster, running them, retrying if there
- * are failures, and mitigating stragglers. They return events to the DAGScheduler through
- * the TaskSchedulerListener interface.
+ * are failures, and mitigating stragglers. They return events to the DAGScheduler.
  */
 private[spark] trait TaskScheduler {
@@ -48,8 +47,8 @@ private[spark] trait TaskScheduler {
   // Cancel a stage.
   def cancelTasks(stageId: Int)
-  // Set a listener for upcalls. This is guaranteed to be set before submitTasks is called.
-  def setListener(listener: TaskSchedulerListener): Unit
+  // Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called.
+  def setDAGScheduler(dagScheduler: DAGScheduler): Unit
   // Get the default level of parallelism to use in the cluster, as a hint for sizing jobs.
   def defaultParallelism(): Int

View file

@@ -1,44 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.scheduler
-import scala.collection.mutable.Map
-import org.apache.spark.TaskEndReason
-import org.apache.spark.executor.TaskMetrics
-/**
- * Interface for getting events back from the TaskScheduler.
- */
-private[spark] trait TaskSchedulerListener {
-  // A task has started.
-  def taskStarted(task: Task[_], taskInfo: TaskInfo)
-  // A task has finished or failed.
-  def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any],
-    taskInfo: TaskInfo, taskMetrics: TaskMetrics): Unit
-  // A node was added to the cluster.
-  def executorGained(execId: String, host: String): Unit
-  // A node was lost from the cluster.
-  def executorLost(execId: String): Unit
-  // The TaskScheduler wants to abort an entire task set.
-  def taskSetFailed(taskSet: TaskSet, reason: String): Unit
-}

View file

@@ -79,7 +79,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
   private val executorIdToHost = new HashMap[String, String]
   // Listener object to pass upcalls into
-  var listener: TaskSchedulerListener = null
+  var dagScheduler: DAGScheduler = null
   var backend: SchedulerBackend = null
@@ -94,8 +94,8 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
   // This is a var so that we can reset it for testing purposes.
   private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this)
-  override def setListener(listener: TaskSchedulerListener) {
-    this.listener = listener
+  override def setDAGScheduler(dagScheduler: DAGScheduler) {
+    this.dagScheduler = dagScheduler
   }
   def initialize(context: SchedulerBackend) {
@@ -297,7 +297,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
     }
     // Update the DAGScheduler without holding a lock on this, since that can deadlock
     if (failedExecutor != None) {
-      listener.executorLost(failedExecutor.get)
+      dagScheduler.executorLost(failedExecutor.get)
       backend.reviveOffers()
     }
     if (taskFailed) {
@@ -397,9 +397,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
        logError("Lost an executor " + executorId + " (already removed): " + reason)
       }
     }
-    // Call listener.executorLost without holding the lock on this to prevent deadlock
+    // Call dagScheduler.executorLost without holding the lock on this to prevent deadlock
     if (failedExecutor != None) {
-      listener.executorLost(failedExecutor.get)
+      dagScheduler.executorLost(failedExecutor.get)
       backend.reviveOffers()
     }
   }
@@ -418,7 +418,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
   }
   def executorGained(execId: String, host: String) {
-    listener.executorGained(execId, host)
+    dagScheduler.executorGained(execId, host)
   }
   def getExecutorsAliveOnHost(host: String): Option[Set[String]] = synchronized {

View file

@@ -415,11 +415,11 @@ private[spark] class ClusterTaskSetManager(
   }
   private def taskStarted(task: Task[_], info: TaskInfo) {
-    sched.listener.taskStarted(task, info)
+    sched.dagScheduler.taskStarted(task, info)
   }
   /**
-   * Marks the task as successful and notifies the listener that a task has ended.
+   * Marks the task as successful and notifies the DAGScheduler that a task has ended.
    */
   def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]) = {
     val info = taskInfos(tid)
@@ -429,7 +429,7 @@ private[spark] class ClusterTaskSetManager(
     if (!successful(index)) {
       logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format(
         tid, info.duration, info.host, tasksSuccessful, numTasks))
-      sched.listener.taskEnded(
+      sched.dagScheduler.taskEnded(
         tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
       // Mark successful and stop if all the tasks have succeeded.
@@ -445,7 +445,8 @@ private[spark] class ClusterTaskSetManager(
   }
   /**
-   * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the listener.
+   * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the
+   * DAG Scheduler.
    */
   def handleFailedTask(tid: Long, state: TaskState, reason: Option[TaskEndReason]) {
     val info = taskInfos(tid)
@@ -463,7 +464,7 @@ private[spark] class ClusterTaskSetManager(
     reason.foreach {
       case fetchFailed: FetchFailed =>
         logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress)
-        sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null)
+        sched.dagScheduler.taskEnded(tasks(index), fetchFailed, null, null, info, null)
         successful(index) = true
         tasksSuccessful += 1
         sched.taskSetFinished(this)
@@ -472,11 +473,11 @@ private[spark] class ClusterTaskSetManager(
       case TaskKilled =>
         logWarning("Task %d was killed.".format(tid))
-        sched.listener.taskEnded(tasks(index), reason.get, null, null, info, null)
+        sched.dagScheduler.taskEnded(tasks(index), reason.get, null, null, info, null)
         return
       case ef: ExceptionFailure =>
-        sched.listener.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null))
+        sched.dagScheduler.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null))
         val key = ef.description
         val now = clock.getTime()
         val (printFull, dupCount) = {
@@ -504,7 +505,7 @@ private[spark] class ClusterTaskSetManager(
       case TaskResultLost =>
         logWarning("Lost result for TID %s on host %s".format(tid, info.host))
-        sched.listener.taskEnded(tasks(index), TaskResultLost, null, null, info, null)
+        sched.dagScheduler.taskEnded(tasks(index), TaskResultLost, null, null, info, null)
       case _ => {}
     }
@@ -533,7 +534,7 @@ private[spark] class ClusterTaskSetManager(
     failed = true
     causeOfFailure = message
     // TODO: Kill running tasks if we were not terminated due to a Mesos error
-    sched.listener.taskSetFailed(taskSet, message)
+    sched.dagScheduler.taskSetFailed(taskSet, message)
     removeAllRunningTasks()
     sched.taskSetFinished(this)
   }
@@ -606,7 +607,7 @@ private[spark] class ClusterTaskSetManager(
       addPendingTask(index)
       // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
       // stage finishes when a total of tasks.size tasks finish.
-      sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info, null)
+      sched.dagScheduler.taskEnded(tasks(index), Resubmitted, null, null, info, null)
     }
   }
 }

View file

@@ -81,7 +81,7 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
   val env = SparkEnv.get
   val attemptId = new AtomicInteger
-  var listener: TaskSchedulerListener = null
+  var dagScheduler: DAGScheduler = null
   // Application dependencies (added through SparkContext) that we've fetched so far on this node.
   // Each map holds the master's timestamp for the version of that file or JAR we got.
@@ -114,8 +114,8 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
     localActor = env.actorSystem.actorOf(Props(new LocalActor(this, threads)), "Test")
   }
-  override def setListener(listener: TaskSchedulerListener) {
-    this.listener = listener
+  override def setDAGScheduler(dagScheduler: DAGScheduler) {
+    this.dagScheduler = dagScheduler
   }
   override def submitTasks(taskSet: TaskSet) {

View file

@@ -133,7 +133,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
   }
   def taskStarted(task: Task[_], info: TaskInfo) {
-    sched.listener.taskStarted(task, info)
+    sched.dagScheduler.taskStarted(task, info)
   }
   def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) {
@@ -148,7 +148,8 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
       }
     }
     result.metrics.resultSize = serializedData.limit()
-    sched.listener.taskEnded(task, Success, result.value, result.accumUpdates, info, result.metrics)
+    sched.dagScheduler.taskEnded(task, Success, result.value, result.accumUpdates, info,
+      result.metrics)
     numFinished += 1
     decreaseRunningTasks(1)
     finished(index) = true
@@ -165,7 +166,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
     decreaseRunningTasks(1)
     val reason: ExceptionFailure = ser.deserialize[ExceptionFailure](
       serializedData, getClass.getClassLoader)
-    sched.listener.taskEnded(task, reason, null, null, info, reason.metrics.getOrElse(null))
+    sched.dagScheduler.taskEnded(task, reason, null, null, info, reason.metrics.getOrElse(null))
     if (!finished(index)) {
       copiesRunning(index) -= 1
       numFailures(index) += 1
@@ -176,7 +177,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
       val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format(
         taskSet.id, index, 4, reason.description)
       decreaseRunningTasks(runningTasks)
-      sched.listener.taskSetFailed(taskSet, errorMessage)
+      sched.dagScheduler.taskSetFailed(taskSet, errorMessage)
       // need to delete failed Taskset from schedule queue
       sched.taskSetFinished(this)
     }
@@ -184,7 +185,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
   }
   override def error(message: String) {
-    sched.listener.taskSetFailed(taskSet, message)
+    sched.dagScheduler.taskSetFailed(taskSet, message)
     sched.taskSetFinished(this)
   }
 }

View file

@@ -60,7 +60,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
     taskSets += taskSet
   }
   override def cancelTasks(stageId: Int) {}
-  override def setListener(listener: TaskSchedulerListener) = {}
+  override def setDAGScheduler(dagScheduler: DAGScheduler) = {}
   override def defaultParallelism() = 2
 }

View file

@@ -28,6 +28,30 @@ import org.apache.spark.executor.TaskMetrics
 import java.nio.ByteBuffer
 import org.apache.spark.util.{Utils, FakeClock}
+class FakeDAGScheduler(taskScheduler: FakeClusterScheduler) extends DAGScheduler(taskScheduler) {
+  override def taskStarted(task: Task[_], taskInfo: TaskInfo) {
+    taskScheduler.startedTasks += taskInfo.index
+  }
+  override def taskEnded(
+      task: Task[_],
+      reason: TaskEndReason,
+      result: Any,
+      accumUpdates: mutable.Map[Long, Any],
+      taskInfo: TaskInfo,
+      taskMetrics: TaskMetrics) {
+    taskScheduler.endedTasks(taskInfo.index) = reason
+  }
+  override def executorGained(execId: String, host: String) {}
+  override def executorLost(execId: String) {}
+  override def taskSetFailed(taskSet: TaskSet, reason: String) {
+    taskScheduler.taskSetsFailed += taskSet.id
+  }
+}
 /**
  * A mock ClusterScheduler implementation that just remembers information about tasks started and
  * feedback received from the TaskSetManagers. Note that it's important to initialize this with
@@ -44,30 +68,7 @@ class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /*
   val executors = new mutable.HashMap[String, String] ++ liveExecutors
-  listener = new TaskSchedulerListener {
-    def taskStarted(task: Task[_], taskInfo: TaskInfo) {
-      startedTasks += taskInfo.index
-    }
-    def taskEnded(
-        task: Task[_],
-        reason: TaskEndReason,
-        result: Any,
-        accumUpdates: mutable.Map[Long, Any],
-        taskInfo: TaskInfo,
-        taskMetrics: TaskMetrics)
-    {
-      endedTasks(taskInfo.index) = reason
-    }
-    def executorGained(execId: String, host: String) {}
-    def executorLost(execId: String) {}
-    def taskSetFailed(taskSet: TaskSet, reason: String) {
-      taskSetsFailed += taskSet.id
-    }
-  }
+  dagScheduler = new FakeDAGScheduler(this)
   def removeExecutor(execId: String): Unit = executors -= execId