From a61cf40ab97e3be467d2e6e50ffb78c9fa8a503b Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Fri, 11 Oct 2013 15:58:14 -0700
Subject: [PATCH] Job cancellation: addressed code review feedback from Kay.

---
 .../scala/org/apache/spark/FutureAction.scala |  4 +-
 .../scala/org/apache/spark/SparkContext.scala |  4 +-
 .../org/apache/spark/executor/Executor.scala  |  5 +-
 .../apache/spark/rdd/AsyncRDDActions.scala    |  6 +--
 .../apache/spark/scheduler/DAGScheduler.scala | 20 +++++--
 .../spark/scheduler/DAGSchedulerEvent.scala   |  2 +
 .../org/apache/spark/scheduler/Pool.scala     |  3 ++
 .../spark/scheduler/SchedulableBuilder.scala  | 22 --------
 .../scheduler/cluster/ClusterScheduler.scala  | 10 ++--
 .../cluster/ClusterTaskSetManager.scala       |  5 +-
 .../scheduler/local/LocalScheduler.scala      | 54 ++++++++++---------
 .../apache/spark/JobCancellationSuite.scala   | 29 ++++++----
 12 files changed, 85 insertions(+), 79 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala
index eab2957632..9185b9529c 100644
--- a/core/src/main/scala/org/apache/spark/FutureAction.scala
+++ b/core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -46,7 +46,7 @@ trait FutureAction[T] extends Future[T] {
   override def ready(atMost: Duration)(implicit permit: CanAwait): FutureAction.this.type
 
   /**
-   * Await and return the result (of type T) of this action.
+   * Awaits and returns the result (of type T) of this action.
    * @param atMost maximum wait time, which may be negative (no waiting is done), Duration.Inf
    *               for unbounded waiting, or a finite positive duration
    * @throws Exception exception during action execution
@@ -76,7 +76,7 @@ trait FutureAction[T] extends Future[T] {
   override def value: Option[Try[T]]
 
   /**
-   * Block and return the result of this job.
+   * Blocks and returns the result of this job.
    */
   @throws(classOf[Exception])
   def get(): T = Await.result(this, Duration.Inf)
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 17247be78a..52fc4dd869 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -881,9 +881,7 @@ class SparkContext(
    * Cancel all jobs that have been scheduled or are running.
    */
   def cancelAllJobs() {
-    dagScheduler.activeJobs.foreach { job =>
-      dagScheduler.cancelJob(job.jobId)
-    }
+    dagScheduler.cancelAllJobs()
   }
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 16258f3521..dea265615b 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -114,8 +114,8 @@ private[spark] class Executor(
     }
   }
 
-  // Akka's message frame size. This is only used to warn the user when the task result is greater
-  // than this value, in which case Akka will silently drop the task result message.
+  // Akka's message frame size. If the task result is bigger than this, we use the block manager
+  // to send the result back.
   private val akkaFrameSize = {
     env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size")
   }
@@ -198,6 +198,7 @@ private[spark] class Executor(
         if (killed) {
           logInfo("Executor killed task " + taskId)
           execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
+          return
         }
 
         attemptedTask = Some(task)
diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
index 32af795d4c..1f24ee8cd3 100644
--- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -31,7 +31,7 @@ import org.apache.spark.{Logging, CancellablePromise, FutureAction}
 class AsyncRDDActions[T: ClassManifest](self: RDD[T]) extends Serializable with Logging {
 
   /**
-   * Return a future for counting the number of elements in the RDD.
+   * Returns a future for counting the number of elements in the RDD.
    */
   def countAsync(): FutureAction[Long] = {
     val totalCount = new AtomicLong
@@ -51,7 +51,7 @@ class AsyncRDDActions[T: ClassManifest](self: RDD[T]) extends Serializable with
   }
 
   /**
-   * Return a future for retrieving all elements of this RDD.
+   * Returns a future for retrieving all elements of this RDD.
    */
   def collectAsync(): FutureAction[Seq[T]] = {
     val results = new Array[Array[T]](self.partitions.size)
@@ -60,7 +60,7 @@ class AsyncRDDActions[T: ClassManifest](self: RDD[T]) extends Serializable with
   }
 
   /**
-   * The async version of take that returns a FutureAction.
+   * Returns a future for retrieving the first num elements of the RDD.
    */
   def takeAsync(num: Int): FutureAction[Seq[T]] = {
     val promise = new CancellablePromise[Seq[T]]
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 7278237a41..c5b28b8286 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -88,7 +88,8 @@ class DAGScheduler(
     eventQueue.put(ExecutorGained(execId, host))
   }
 
-  // Called by TaskScheduler to cancel an entire TaskSet due to repeated failures.
+  // Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or
+  // cancellation of the job itself.
   override def taskSetFailed(taskSet: TaskSet, reason: String) {
     eventQueue.put(TaskSetFailed(taskSet, reason))
   }
@@ -336,11 +337,18 @@ class DAGScheduler(
   /**
    * Cancel a job that is running or waiting in the queue.
    */
-  def cancelJob(jobId: Int): Unit = this.synchronized {
+  def cancelJob(jobId: Int) {
     logInfo("Asked to cancel job " + jobId)
     eventQueue.put(JobCancelled(jobId))
   }
 
+  /**
+   * Cancel all jobs that are running or waiting in the queue.
+   */
+  def cancelAllJobs() {
+    eventQueue.put(AllJobsCancelled)
+  }
+
   /**
    * Process one event retrieved from the event queue.
    * Returns true if we should stop the event loop.
@@ -373,6 +381,12 @@ class DAGScheduler(
           taskSched.cancelTasks(stage.id)
         }
 
+      case AllJobsCancelled =>
+        // Cancel all running jobs.
+        running.foreach { stage =>
+          taskSched.cancelTasks(stage.id)
+        }
+
       case ExecutorGained(execId, host) =>
         handleExecutorGained(execId, host)
 
@@ -777,7 +791,7 @@ class DAGScheduler(
     failedStage.completionTime = Some(System.currentTimeMillis())
     for (resultStage <- dependentStages) {
       val job = resultStageToJob(resultStage)
-      val error = new SparkException("Job failed: " + reason)
+      val error = new SparkException("Job aborted: " + reason)
       job.listener.jobFailed(error)
       listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(failedStage))))
       idToActiveJob -= resultStage.jobId
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index 0d4d4edc55..ee89bfb38d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -46,6 +46,8 @@ private[scheduler] case class JobSubmitted(
 
 private[scheduler] case class JobCancelled(jobId: Int) extends DAGSchedulerEvent
 
+private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent
+
 private[scheduler] case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent
 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
index 8b33319d02..596f9adde9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
@@ -43,7 +43,10 @@ private[spark] class Pool(
 
   var runningTasks = 0
   var priority = 0
+
+  // A pool's stage id is used to break ties in scheduling.
   var stageId = -1
+
   var name = poolName
   var parent: Pool = null
 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
index 873801e867..356fe56bf3 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
@@ -35,28 +35,6 @@ private[spark] trait SchedulableBuilder {
   def buildPools()
 
   def addTaskSetManager(manager: Schedulable, properties: Properties)
-
-  /**
-   * Find the TaskSetManager for the given stage. In fair scheduler, this function examines
-   * all the pools to find the TaskSetManager.
-   */
-  def getTaskSetManagers(stageId: Int): Option[TaskSetManager] = {
-    def getTsm(pool: Pool): Option[TaskSetManager] = {
-      pool.schedulableQueue.foreach {
-        case tsm: TaskSetManager =>
-          if (tsm.stageId == stageId) {
-            return Some(tsm)
-          }
-        case pool: Pool =>
-          val found = getTsm(pool)
-          if (found.isDefined) {
-            return getTsm(pool)
-          }
-      }
-      None
-    }
-    getTsm(rootPool)
-  }
 }
 
 private[spark] class FIFOSchedulableBuilder(val rootPool: Pool)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
index 6c12ff7370..7a72ff0474 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
@@ -166,21 +166,21 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
 
   override def cancelTasks(stageId: Int): Unit = synchronized {
     logInfo("Cancelling stage " + stageId)
-    schedulableBuilder.getTaskSetManagers(stageId).foreach { tsm =>
+    activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
       // There are two possible cases here:
       // 1. The task set manager has been created and some tasks have been scheduled.
-      //    In this case, send a kill signal to the executors to kill the task.
+      //    In this case, send a kill signal to the executors to kill the task and then abort
+      //    the stage.
       // 2. The task set manager has been created but no tasks has been scheduled. In this case,
-      //    simply abort the task set.
+      //    simply abort the stage.
       val taskIds = taskSetTaskIds(tsm.taskSet.id)
       if (taskIds.size > 0) {
         taskIds.foreach { tid =>
           val execId = taskIdToExecutorId(tid)
           backend.killTask(tid, execId)
         }
-      } else {
-        tsm.error("Stage %d was cancelled before any tasks was launched".format(stageId))
       }
+      tsm.error("Stage %d was cancelled".format(stageId))
     }
   }
 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
index 0a9544bd6d..1198bac6dd 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
@@ -17,8 +17,7 @@
 
 package org.apache.spark.scheduler.cluster
 
-import java.nio.ByteBuffer
-import java.util.{Arrays, NoSuchElementException}
+import java.util.Arrays
 
 import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.HashMap
@@ -27,7 +26,7 @@ import scala.math.max
 import scala.math.min
 
 import org.apache.spark.{ExceptionFailure, FetchFailed, Logging, Resubmitted, SparkEnv,
-  SparkException, Success, TaskEndReason, TaskResultLost, TaskState, TaskKilled}
+  Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState}
 import org.apache.spark.TaskState.TaskState
 import org.apache.spark.scheduler._
 import org.apache.spark.util.{SystemClock, Clock}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
index dc6509d195..b445260d1b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
@@ -57,8 +57,10 @@ class LocalActor(localScheduler: LocalScheduler, private var freeCores: Int)
       launchTask(localScheduler.resourceOffer(freeCores))
 
     case LocalStatusUpdate(taskId, state, serializeData) =>
-      freeCores += 1
-      launchTask(localScheduler.resourceOffer(freeCores))
+      if (TaskState.isFinished(state)) {
+        freeCores += 1
+        launchTask(localScheduler.resourceOffer(freeCores))
+      }
 
     case KillTask(taskId) =>
       executor.killTask(taskId)
@@ -128,20 +130,21 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
 
   override def cancelTasks(stageId: Int): Unit = synchronized {
     logInfo("Cancelling stage " + stageId)
-    schedulableBuilder.getTaskSetManagers(stageId).foreach { tsm =>
+    logInfo("Stages in active task sets: " + activeTaskSets.map(_._2.stageId))
+    activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
       // There are two possible cases here:
       // 1. The task set manager has been created and some tasks have been scheduled.
-      //    In this case, send a kill signal to the executors to kill the task.
+      //    In this case, send a kill signal to the executors to kill the task and then abort
+      //    the stage.
       // 2. The task set manager has been created but no tasks has been scheduled. In this case,
-      //    simply abort the task set.
+      //    simply abort the stage.
       val taskIds = taskSetTaskIds(tsm.taskSet.id)
       if (taskIds.size > 0) {
         taskIds.foreach { tid =>
           localActor ! KillTask(tid)
         }
-      } else {
-        tsm.error("Stage %d was cancelled before any tasks was launched".format(stageId))
       }
+      tsm.error("Stage %d was cancelled".format(stageId))
     }
   }
 
@@ -185,26 +188,27 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
   }
 
   override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) {
-    if (TaskState.isFinished(state)) synchronized {
-      taskIdToTaskSetId.get(taskId) match {
-        case Some(taskSetId) =>
-          val taskSetManager = activeTaskSets(taskSetId)
-          taskSetTaskIds(taskSetId) -= taskId
+    if (TaskState.isFinished(state)) {
+      synchronized {
+        taskIdToTaskSetId.get(taskId) match {
+          case Some(taskSetId) =>
+            val taskSetManager = activeTaskSets(taskSetId)
+            taskSetTaskIds(taskSetId) -= taskId
 
-          state match {
-            case TaskState.FINISHED =>
-              taskSetManager.taskEnded(taskId, state, serializedData)
-            case TaskState.FAILED =>
-              taskSetManager.taskFailed(taskId, state, serializedData)
-            case TaskState.KILLED =>
-              taskSetManager.error("Task %d was killed".format(taskId))
-            case _ => {}
-          }
-
-          localActor ! LocalStatusUpdate(taskId, state, serializedData)
-        case None =>
-          logInfo("Ignoring update from TID " + taskId + " because its task set is gone")
+            state match {
+              case TaskState.FINISHED =>
+                taskSetManager.taskEnded(taskId, state, serializedData)
+              case TaskState.FAILED =>
+                taskSetManager.taskFailed(taskId, state, serializedData)
+              case TaskState.KILLED =>
+                taskSetManager.error("Task %d was killed".format(taskId))
+              case _ => {}
+            }
+          case None =>
+            logInfo("Ignoring update from TID " + taskId + " because its task set is gone")
+        }
       }
+      localActor ! LocalStatusUpdate(taskId, state, serializedData)
     }
   }
diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
index eca309cf29..53c225391c 100644
--- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
@@ -31,9 +31,8 @@ import org.apache.spark.scheduler.{SparkListenerTaskStart, SparkListener}
 
 /**
  * Test suite for cancelling running jobs. We run the cancellation tasks for single job action
- * (e.g. count) as well as multi-job action (e.g. take). We test in the combination of:
- *   - FIFO vs fair scheduler
- *   - local vs local cluster
+ * (e.g. count) as well as multi-job action (e.g. take). We test the local and cluster schedulers
+ * in both FIFO and fair scheduling modes.
  */
 class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
   with LocalSparkContext {
@@ -48,14 +47,8 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
     sc = new SparkContext("local[2]", "test")
     testCount()
     testTake()
-    resetSparkContext()
-  }
-
-  test("cluster mode, FIFO scheduler") {
-    System.setProperty("spark.scheduler.mode", "FIFO")
-    sc = new SparkContext("local-cluster[2,1,512]", "test")
-    testCount()
-    testTake()
+    // Make sure we can still launch tasks.
+    assert(sc.parallelize(1 to 10, 2).count === 10)
     resetSparkContext()
   }
 
@@ -66,6 +59,18 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
     sc = new SparkContext("local[2]", "test")
     testCount()
     testTake()
+    // Make sure we can still launch tasks.
+    assert(sc.parallelize(1 to 10, 2).count === 10)
+    resetSparkContext()
+  }
+
+  test("cluster mode, FIFO scheduler") {
+    System.setProperty("spark.scheduler.mode", "FIFO")
+    sc = new SparkContext("local-cluster[2,1,512]", "test")
+    testCount()
+    testTake()
+    // Make sure we can still launch tasks.
+    assert(sc.parallelize(1 to 10, 2).count === 10)
     resetSparkContext()
   }
 
@@ -76,6 +81,8 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
     sc = new SparkContext("local-cluster[2,1,512]", "test")
     testCount()
     testTake()
+    // Make sure we can still launch tasks.
+    assert(sc.parallelize(1 to 10, 2).count === 10)
     resetSparkContext()
   }
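
A minimal usage sketch of the cancellation path this patch wires together, for illustration
only (not part of the diff). It assumes the implicit conversion that adds AsyncRDDActions to
RDDs is importable from SparkContext._; only countAsync(), get(), cancelAllJobs() and the
"Job aborted:" failure message are taken from the changes above.

// Illustrative sketch -- the SparkContext._ import and the exact exception type
// are assumptions; countAsync(), get() and cancelAllJobs() appear in the diff.
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._

object CancellationSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local[2]", "cancellation-sketch")

    // Launch a job without blocking; countAsync() returns a FutureAction[Long].
    val future = sc.parallelize(1 to 1000000, 4).countAsync()

    // cancelAllJobs() now posts a single AllJobsCancelled event; the
    // DAGScheduler then calls taskSched.cancelTasks() on every running stage.
    sc.cancelAllJobs()

    // get() blocks on the result; a cancelled job is expected to surface as an
    // exception whose message starts with "Job aborted:".
    try {
      println("count = " + future.get())
    } catch {
      case e: Exception => println("job did not complete: " + e.getMessage)
    }

    sc.stop()
  }
}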