diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index f219c5605b..c14c12664f 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -23,7 +23,6 @@ import java.util.LinkedList; import org.apache.avro.reflect.Nullable; import org.apache.spark.TaskContext; -import org.apache.spark.TaskKilledException; import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.unsafe.Platform; @@ -291,8 +290,8 @@ public final class UnsafeInMemorySorter { // to avoid performance overhead. This check is added here in `loadNext()` instead of in // `hasNext()` because it's technically possible for the caller to be relying on // `getNumRecords()` instead of `hasNext()` to know when to stop. - if (taskContext != null && taskContext.isInterrupted()) { - throw new TaskKilledException(); + if (taskContext != null) { + taskContext.killTaskIfInterrupted(); } // This pointer points to a 4-byte record length, followed by the record's bytes final long recordPointer = array.get(offset + position); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index b6323c624b..9521ab86a1 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -24,7 +24,6 @@ import com.google.common.io.Closeables; import org.apache.spark.SparkEnv; import org.apache.spark.TaskContext; -import org.apache.spark.TaskKilledException; import org.apache.spark.io.NioBufferedFileInputStream; import org.apache.spark.serializer.SerializerManager; import org.apache.spark.storage.BlockId; @@ -102,8 +101,8 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implemen // to avoid performance overhead. This check is added here in `loadNext()` instead of in // `hasNext()` because it's technically possible for the caller to be relying on // `getNumRecords()` instead of `hasNext()` to know when to stop. - if (taskContext != null && taskContext.isInterrupted()) { - throw new TaskKilledException(); + if (taskContext != null) { + taskContext.killTaskIfInterrupted(); } recordLength = din.readInt(); keyPrefix = din.readLong(); diff --git a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala index 5c262bcbdd..7f2c006817 100644 --- a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala +++ b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala @@ -33,11 +33,8 @@ class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator // is allowed. The assumption is that Thread.interrupted does not have a memory fence in read // (just a volatile field in C), while context.interrupted is a volatile in the JVM, which // introduces an expensive read fence. 
- if (context.isInterrupted) { - throw new TaskKilledException - } else { - delegate.hasNext - } + context.killTaskIfInterrupted() + delegate.hasNext } def next(): T = delegate.next() diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 0e36a30c93..0225fd6056 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2249,6 +2249,24 @@ class SparkContext(config: SparkConf) extends Logging { dagScheduler.cancelStage(stageId, None) } + /** + * Kill and reschedule the given task attempt. Task ids can be obtained from the Spark UI + * or through SparkListener.onTaskStart. + * + * @param taskId the task ID to kill. This id uniquely identifies the task attempt. + * @param interruptThread whether to interrupt the thread running the task. + * @param reason the reason for killing the task, which should be a short string. If a task + * is killed multiple times with different reasons, only one reason will be reported. + * + * @return Whether the task was successfully killed. + */ + def killTaskAttempt( + taskId: Long, + interruptThread: Boolean = true, + reason: String = "killed via SparkContext.killTaskAttempt"): Boolean = { + dagScheduler.killTaskAttempt(taskId, interruptThread, reason) + } + /** * Clean a closure to make it ready to serialized and send to tasks * (removes unreferenced variables in $outer's, updates REPL variables) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 5acfce1759..0b87cd503d 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -184,6 +184,16 @@ abstract class TaskContext extends Serializable { @DeveloperApi def getMetricsSources(sourceName: String): Seq[Source] + /** + * If the task is interrupted, throws TaskKilledException with the reason for the interrupt. + */ + private[spark] def killTaskIfInterrupted(): Unit + + /** + * If the task is interrupted, the reason this task was killed, otherwise None. + */ + private[spark] def getKillReason(): Option[String] + /** * Returns the manager for this task's managed memory. */ diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index f346cf8d65..8cd1d1c96a 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -59,8 +59,8 @@ private[spark] class TaskContextImpl( /** List of callback functions to execute when the task fails. */ @transient private val onFailureCallbacks = new ArrayBuffer[TaskFailureListener] - // Whether the corresponding task has been killed. - @volatile private var interrupted: Boolean = false + // If defined, the corresponding task has been killed and this option contains the reason. + @volatile private var reasonIfKilled: Option[String] = None // Whether the task has completed. private var completed: Boolean = false @@ -140,8 +140,19 @@ private[spark] class TaskContextImpl( } /** Marks the task for interruption, i.e. cancellation. 
*/ - private[spark] def markInterrupted(): Unit = { - interrupted = true + private[spark] def markInterrupted(reason: String): Unit = { + reasonIfKilled = Some(reason) + } + + private[spark] override def killTaskIfInterrupted(): Unit = { + val reason = reasonIfKilled + if (reason.isDefined) { + throw new TaskKilledException(reason.get) + } + } + + private[spark] override def getKillReason(): Option[String] = { + reasonIfKilled } @GuardedBy("this") @@ -149,7 +160,7 @@ private[spark] class TaskContextImpl( override def isRunningLocally(): Boolean = false - override def isInterrupted(): Boolean = interrupted + override def isInterrupted(): Boolean = reasonIfKilled.isDefined override def getLocalProperty(key: String): String = localProperties.getProperty(key) diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 8c1b5f7bf0..a76283e33f 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -212,8 +212,8 @@ case object TaskResultLost extends TaskFailedReason { * Task was killed intentionally and needs to be rescheduled. */ @DeveloperApi -case object TaskKilled extends TaskFailedReason { - override def toErrorString: String = "TaskKilled (killed intentionally)" +case class TaskKilled(reason: String) extends TaskFailedReason { + override def toErrorString: String = s"TaskKilled ($reason)" override def countTowardsTaskFailures: Boolean = false } diff --git a/core/src/main/scala/org/apache/spark/TaskKilledException.scala b/core/src/main/scala/org/apache/spark/TaskKilledException.scala index ad487c4efb..9dbf0d493b 100644 --- a/core/src/main/scala/org/apache/spark/TaskKilledException.scala +++ b/core/src/main/scala/org/apache/spark/TaskKilledException.scala @@ -24,4 +24,6 @@ import org.apache.spark.annotation.DeveloperApi * Exception thrown when a task is explicitly killed (i.e., task failure is expected). 
*/ @DeveloperApi -class TaskKilledException extends RuntimeException +class TaskKilledException(val reason: String) extends RuntimeException { + def this() = this("unknown reason") +} diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 04ae97ed3c..b0dd2fc187 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -215,7 +215,7 @@ private[spark] class PythonRunner( case e: Exception if context.isInterrupted => logDebug("Exception thrown after task interruption", e) - throw new TaskKilledException + throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason")) case e: Exception if env.isStopped => logDebug("Exception thrown after context is stopped", e) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index b376ecd301..ba0096d874 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -97,11 +97,11 @@ private[spark] class CoarseGrainedExecutorBackend( executor.launchTask(this, taskDesc) } - case KillTask(taskId, _, interruptThread) => + case KillTask(taskId, _, interruptThread, reason) => if (executor == null) { exitExecutor(1, "Received KillTask command but executor was null") } else { - executor.killTask(taskId, interruptThread) + executor.killTask(taskId, interruptThread, reason) } case StopExecutor => diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 790c1ae942..99b1608010 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -158,7 +158,7 @@ private[spark] class Executor( threadPool.execute(tr) } - def killTask(taskId: Long, interruptThread: Boolean): Unit = { + def killTask(taskId: Long, interruptThread: Boolean, reason: String): Unit = { val taskRunner = runningTasks.get(taskId) if (taskRunner != null) { if (taskReaperEnabled) { @@ -168,7 +168,8 @@ private[spark] class Executor( case Some(existingReaper) => interruptThread && !existingReaper.interruptThread } if (shouldCreateReaper) { - val taskReaper = new TaskReaper(taskRunner, interruptThread = interruptThread) + val taskReaper = new TaskReaper( + taskRunner, interruptThread = interruptThread, reason = reason) taskReaperForTask(taskId) = taskReaper Some(taskReaper) } else { @@ -178,7 +179,7 @@ private[spark] class Executor( // Execute the TaskReaper from outside of the synchronized block. maybeNewTaskReaper.foreach(taskReaperPool.execute) } else { - taskRunner.kill(interruptThread = interruptThread) + taskRunner.kill(interruptThread = interruptThread, reason = reason) } } } @@ -189,8 +190,9 @@ private[spark] class Executor( * tasks instead of taking the JVM down. 
* @param interruptThread whether to interrupt the task thread */ - def killAllTasks(interruptThread: Boolean) : Unit = { - runningTasks.keys().asScala.foreach(t => killTask(t, interruptThread = interruptThread)) + def killAllTasks(interruptThread: Boolean, reason: String) : Unit = { + runningTasks.keys().asScala.foreach(t => + killTask(t, interruptThread = interruptThread, reason = reason)) } def stop(): Unit = { @@ -217,8 +219,8 @@ private[spark] class Executor( val threadName = s"Executor task launch worker for task $taskId" private val taskName = taskDescription.name - /** Whether this task has been killed. */ - @volatile private var killed = false + /** If specified, this task has been killed and this option contains the reason. */ + @volatile private var reasonIfKilled: Option[String] = None @volatile private var threadId: Long = -1 @@ -239,13 +241,13 @@ private[spark] class Executor( */ @volatile var task: Task[Any] = _ - def kill(interruptThread: Boolean): Unit = { - logInfo(s"Executor is trying to kill $taskName (TID $taskId)") - killed = true + def kill(interruptThread: Boolean, reason: String): Unit = { + logInfo(s"Executor is trying to kill $taskName (TID $taskId), reason: $reason") + reasonIfKilled = Some(reason) if (task != null) { synchronized { if (!finished) { - task.kill(interruptThread) + task.kill(interruptThread, reason) } } } @@ -296,12 +298,13 @@ private[spark] class Executor( // If this task has been killed before we deserialized it, let's quit now. Otherwise, // continue executing the task. - if (killed) { + val killReason = reasonIfKilled + if (killReason.isDefined) { // Throw an exception rather than returning, because returning within a try{} block // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl // exception will be caught by the catch block, leading to an incorrect ExceptionFailure // for the task. - throw new TaskKilledException + throw new TaskKilledException(killReason.get) } logDebug("Task " + taskId + "'s epoch is " + task.epoch) @@ -358,9 +361,7 @@ private[spark] class Executor( } else 0L // If the task has been killed, let's fail it. 
- if (task.killed) { - throw new TaskKilledException - } + task.context.killTaskIfInterrupted() val resultSer = env.serializer.newInstance() val beforeSerialization = System.currentTimeMillis() @@ -426,15 +427,17 @@ private[spark] class Executor( setTaskFinishedAndClearInterruptStatus() execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) - case _: TaskKilledException => - logInfo(s"Executor killed $taskName (TID $taskId)") + case t: TaskKilledException => + logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}") setTaskFinishedAndClearInterruptStatus() - execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) + execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason))) - case _: InterruptedException if task.killed => - logInfo(s"Executor interrupted and killed $taskName (TID $taskId)") + case _: InterruptedException if task.reasonIfKilled.isDefined => + val killReason = task.reasonIfKilled.getOrElse("unknown reason") + logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason") setTaskFinishedAndClearInterruptStatus() - execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) + execBackend.statusUpdate( + taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason))) case CausedBy(cDE: CommitDeniedException) => val reason = cDE.toTaskFailedReason @@ -512,7 +515,8 @@ private[spark] class Executor( */ private class TaskReaper( taskRunner: TaskRunner, - val interruptThread: Boolean) + val interruptThread: Boolean, + val reason: String) extends Runnable { private[this] val taskId: Long = taskRunner.taskId @@ -533,7 +537,7 @@ private[spark] class Executor( // Only attempt to kill the task once. If interruptThread = false then a second kill // attempt would be a no-op and if interruptThread = true then it may not be safe or // effective to interrupt multiple times: - taskRunner.kill(interruptThread = interruptThread) + taskRunner.kill(interruptThread = interruptThread, reason = reason) // Monitor the killed task until it exits. The synchronization logic here is complicated // because we don't want to synchronize on the taskRunner while possibly taking a thread // dump, but we also need to be careful to avoid races between checking whether the task diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index d944f26875..0971731683 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -738,6 +738,15 @@ class DAGScheduler( eventProcessLoop.post(StageCancelled(stageId, reason)) } + /** + * Kill a given task. It will be retried. + * + * @return Whether the task was successfully killed. + */ + def killTaskAttempt(taskId: Long, interruptThread: Boolean, reason: String): Boolean = { + taskScheduler.killTaskAttempt(taskId, interruptThread, reason) + } + /** * Resubmit any failed stages. Ordinarily called after a small amount of time has passed since * the last fetch failure. @@ -1353,7 +1362,7 @@ class DAGScheduler( case TaskResultLost => // Do nothing here; the TaskScheduler handles these failures and resubmits the task. - case _: ExecutorLostFailure | TaskKilled | UnknownReason => + case _: ExecutorLostFailure | _: TaskKilled | UnknownReason => // Unrecognized failure - also do nothing. If the task fails repeatedly, the TaskScheduler // will abort the job. 
} diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala index 8801a761af..22db3350ab 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala @@ -30,8 +30,21 @@ private[spark] trait SchedulerBackend { def reviveOffers(): Unit def defaultParallelism(): Int - def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = + /** + * Requests that an executor kills a running task. + * + * @param taskId Id of the task. + * @param executorId Id of the executor the task is running on. + * @param interruptThread Whether the executor should interrupt the task thread. + * @param reason The reason for the task kill. + */ + def killTask( + taskId: Long, + executorId: String, + interruptThread: Boolean, + reason: String): Unit = throw new UnsupportedOperationException + def isReady(): Boolean = true /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 70213722aa..46ef23f316 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -89,8 +89,8 @@ private[spark] abstract class Task[T]( TaskContext.setTaskContext(context) taskThread = Thread.currentThread() - if (_killed) { - kill(interruptThread = false) + if (_reasonIfKilled != null) { + kill(interruptThread = false, _reasonIfKilled) } new CallerContext( @@ -158,17 +158,17 @@ private[spark] abstract class Task[T]( // The actual Thread on which the task is running, if any. Initialized in run(). @volatile @transient private var taskThread: Thread = _ - // A flag to indicate whether the task is killed. This is used in case context is not yet - // initialized when kill() is invoked. - @volatile @transient private var _killed = false + // If non-null, this task has been killed and the reason is as specified. This is used in case + // context is not yet initialized when kill() is invoked. + @volatile @transient private var _reasonIfKilled: String = null protected var _executorDeserializeTime: Long = 0 protected var _executorDeserializeCpuTime: Long = 0 /** - * Whether the task has been killed. + * If defined, this task has been killed and this option contains the reason. */ - def killed: Boolean = _killed + def reasonIfKilled: Option[String] = Option(_reasonIfKilled) /** * Returns the amount of time spent deserializing the RDD and function to be run. @@ -201,10 +201,11 @@ private[spark] abstract class Task[T]( * be called multiple times. * If interruptThread is true, we will also call Thread.interrupt() on the Task's executor thread. */ - def kill(interruptThread: Boolean) { - _killed = true + def kill(interruptThread: Boolean, reason: String) { + require(reason != null) + _reasonIfKilled = reason if (context != null) { - context.markInterrupted() + context.markInterrupted(reason) } if (interruptThread && taskThread != null) { taskThread.interrupt() diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index cd13eebe74..3de7d1f7de 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -54,6 +54,13 @@ private[spark] trait TaskScheduler { // Cancel a stage. 
def cancelTasks(stageId: Int, interruptThread: Boolean): Unit + /** + * Kills a task attempt. + * + * @return Whether the task was successfully killed. + */ + def killTaskAttempt(taskId: Long, interruptThread: Boolean, reason: String): Boolean + // Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called. def setDAGScheduler(dagScheduler: DAGScheduler): Unit diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index d6225a0873..07aea773fa 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -241,7 +241,7 @@ private[spark] class TaskSchedulerImpl private[scheduler]( // simply abort the stage. tsm.runningTasksSet.foreach { tid => val execId = taskIdToExecutorId(tid) - backend.killTask(tid, execId, interruptThread) + backend.killTask(tid, execId, interruptThread, reason = "stage cancelled") } tsm.abort("Stage %s cancelled".format(stageId)) logInfo("Stage %d was cancelled".format(stageId)) @@ -249,6 +249,18 @@ private[spark] class TaskSchedulerImpl private[scheduler]( } } + override def killTaskAttempt(taskId: Long, interruptThread: Boolean, reason: String): Boolean = { + logInfo(s"Killing task $taskId: $reason") + val execId = taskIdToExecutorId.get(taskId) + if (execId.isDefined) { + backend.killTask(taskId, execId.get, interruptThread, reason) + true + } else { + logWarning(s"Could not kill task $taskId because no task with that ID was found.") + false + } + } + /** * Called to indicate that all task attempts (including speculated tasks) associated with the * given TaskSetManager have completed, so state associated with the TaskSetManager should be @@ -469,7 +481,7 @@ private[spark] class TaskSchedulerImpl private[scheduler]( taskState: TaskState, reason: TaskFailedReason): Unit = synchronized { taskSetManager.handleFailedTask(tid, taskState, reason) - if (!taskSetManager.isZombie && taskState != TaskState.KILLED) { + if (!taskSetManager.isZombie && !taskSetManager.someAttemptSucceeded(tid)) { // Need to revive offers again now that the task set manager state has been updated to // reflect failed tasks that need to be re-run. backend.reviveOffers() diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index f4a21bca79..a177aab5f9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -101,6 +101,10 @@ private[spark] class TaskSetManager( override def runningTasks: Int = runningTasksSet.size + def someAttemptSucceeded(tid: Long): Boolean = { + successful(taskInfos(tid).index) + } + // True once no more tasks should be launched for this task set manager. TaskSetManagers enter // the zombie state once at least one attempt of each task has completed successfully, or if the // task set is aborted (for example, because it was killed). 
TaskSetManagers remain in the zombie @@ -722,7 +726,11 @@ private[spark] class TaskSetManager( logInfo(s"Killing attempt ${attemptInfo.attemptNumber} for task ${attemptInfo.id} " + s"in stage ${taskSet.id} (TID ${attemptInfo.taskId}) on ${attemptInfo.host} " + s"as the attempt ${info.attemptNumber} succeeded on ${info.host}") - sched.backend.killTask(attemptInfo.taskId, attemptInfo.executorId, true) + sched.backend.killTask( + attemptInfo.taskId, + attemptInfo.executorId, + interruptThread = true, + reason = "another attempt succeeded") } if (!successful(index)) { tasksSuccessful += 1 diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 2898cd7d17..6b49bd699a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -40,7 +40,7 @@ private[spark] object CoarseGrainedClusterMessages { // Driver to executors case class LaunchTask(data: SerializableBuffer) extends CoarseGrainedClusterMessage - case class KillTask(taskId: Long, executor: String, interruptThread: Boolean) + case class KillTask(taskId: Long, executor: String, interruptThread: Boolean, reason: String) extends CoarseGrainedClusterMessage case class KillExecutorsOnHost(host: String) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 7e2cfaccfc..4eedaaea61 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -132,10 +132,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp case ReviveOffers => makeOffers() - case KillTask(taskId, executorId, interruptThread) => + case KillTask(taskId, executorId, interruptThread, reason) => executorDataMap.get(executorId) match { case Some(executorInfo) => - executorInfo.executorEndpoint.send(KillTask(taskId, executorId, interruptThread)) + executorInfo.executorEndpoint.send( + KillTask(taskId, executorId, interruptThread, reason)) case None => // Ignoring the task kill since the executor is not registered. 
logWarning(s"Attempted to kill task $taskId for unknown executor $executorId.") @@ -428,8 +429,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp driverEndpoint.send(ReviveOffers) } - override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) { - driverEndpoint.send(KillTask(taskId, executorId, interruptThread)) + override def killTask( + taskId: Long, executorId: String, interruptThread: Boolean, reason: String) { + driverEndpoint.send(KillTask(taskId, executorId, interruptThread, reason)) } override def defaultParallelism(): Int = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala index 625f998cd4..35509bc2f8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala @@ -34,7 +34,7 @@ private case class ReviveOffers() private case class StatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) -private case class KillTask(taskId: Long, interruptThread: Boolean) +private case class KillTask(taskId: Long, interruptThread: Boolean, reason: String) private case class StopExecutor() @@ -70,8 +70,8 @@ private[spark] class LocalEndpoint( reviveOffers() } - case KillTask(taskId, interruptThread) => - executor.killTask(taskId, interruptThread) + case KillTask(taskId, interruptThread, reason) => + executor.killTask(taskId, interruptThread, reason) } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { @@ -143,8 +143,9 @@ private[spark] class LocalSchedulerBackend( override def defaultParallelism(): Int = scheduler.conf.getInt("spark.default.parallelism", totalCores) - override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) { - localEndpoint.send(KillTask(taskId, interruptThread)) + override def killTask( + taskId: Long, executorId: String, interruptThread: Boolean, reason: String) { + localEndpoint.send(KillTask(taskId, interruptThread, reason)) } override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) { diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index d161843dd2..e53d6907bc 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -342,7 +342,7 @@ private[spark] object UIUtils extends Logging { completed: Int, failed: Int, skipped: Int, - killed: Int, + reasonToNumKilled: Map[String, Int], total: Int): Seq[Node] = { val completeWidth = "width: %s%%".format((completed.toDouble/total)*100) // started + completed can be > total when there are speculative tasks @@ -354,7 +354,10 @@ private[spark] object UIUtils extends Logging { {completed}/{total} { if (failed > 0) s"($failed failed)" } { if (skipped > 0) s"($skipped skipped)" } - { if (killed > 0) s"($killed killed)" } + { reasonToNumKilled.toSeq.sortBy(-_._2).map { + case (reason, count) => s"($count killed: $reason)" + } + }
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index d217f55804..18be087074 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -630,8 +630,8 @@ private[ui] class JobPagedTable( {UIUtils.makeProgressBar(started = job.numActiveTasks, completed = job.numCompletedTasks, - failed = job.numFailedTasks, skipped = job.numSkippedTasks, killed = job.numKilledTasks, - total = job.numTasks - job.numSkippedTasks)} + failed = job.numFailedTasks, skipped = job.numSkippedTasks, + reasonToNumKilled = job.reasonToNumKilled, total = job.numTasks - job.numSkippedTasks)} } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index cd1b02addc..52f41298a1 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -133,9 +133,9 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: Stage {executorIdToAddress.getOrElse(k, "CANNOT FIND ADDRESS")} {UIUtils.formatDuration(v.taskTime)} - {v.failedTasks + v.succeededTasks + v.killedTasks} + {v.failedTasks + v.succeededTasks + v.reasonToNumKilled.map(_._2).sum} {v.failedTasks} - {v.killedTasks} + {v.reasonToNumKilled.map(_._2).sum} {v.succeededTasks} {if (stageData.hasInput) { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index e87caff426..1cf03e1541 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -371,8 +371,9 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { taskEnd.reason match { case Success => execSummary.succeededTasks += 1 - case TaskKilled => - execSummary.killedTasks += 1 + case kill: TaskKilled => + execSummary.reasonToNumKilled = execSummary.reasonToNumKilled.updated( + kill.reason, execSummary.reasonToNumKilled.getOrElse(kill.reason, 0) + 1) case _ => execSummary.failedTasks += 1 } @@ -385,9 +386,10 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { stageData.completedIndices.add(info.index) stageData.numCompleteTasks += 1 None - case TaskKilled => - stageData.numKilledTasks += 1 - Some(TaskKilled.toErrorString) + case kill: TaskKilled => + stageData.reasonToNumKilled = stageData.reasonToNumKilled.updated( + kill.reason, stageData.reasonToNumKilled.getOrElse(kill.reason, 0) + 1) + Some(kill.toErrorString) case e: ExceptionFailure => // Handle ExceptionFailure because we might have accumUpdates stageData.numFailedTasks += 1 Some(e.toErrorString) @@ -422,8 +424,9 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { taskEnd.reason match { case Success => jobData.numCompletedTasks += 1 - case TaskKilled => - jobData.numKilledTasks += 1 + case kill: TaskKilled => + jobData.reasonToNumKilled = jobData.reasonToNumKilled.updated( + kill.reason, jobData.reasonToNumKilled.getOrElse(kill.reason, 0) + 1) case _ => jobData.numFailedTasks += 1 } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index e1fa9043b6..f4caad0f58 100644 --- 
a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -300,7 +300,7 @@ private[ui] class StagePagedTable( {UIUtils.makeProgressBar(started = stageData.numActiveTasks, completed = stageData.completedIndices.size, failed = stageData.numFailedTasks, - skipped = 0, killed = stageData.numKilledTasks, total = info.numTasks)} + skipped = 0, reasonToNumKilled = stageData.reasonToNumKilled, total = info.numTasks)} {data.inputReadWithUnit} {data.outputWriteWithUnit} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index 073f7edfc2..ac1a74ad80 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -32,7 +32,7 @@ private[spark] object UIData { var taskTime : Long = 0 var failedTasks : Int = 0 var succeededTasks : Int = 0 - var killedTasks : Int = 0 + var reasonToNumKilled : Map[String, Int] = Map.empty var inputBytes : Long = 0 var inputRecords : Long = 0 var outputBytes : Long = 0 @@ -64,7 +64,7 @@ private[spark] object UIData { var numCompletedTasks: Int = 0, var numSkippedTasks: Int = 0, var numFailedTasks: Int = 0, - var numKilledTasks: Int = 0, + var reasonToNumKilled: Map[String, Int] = Map.empty, /* Stages */ var numActiveStages: Int = 0, // This needs to be a set instead of a simple count to prevent double-counting of rerun stages: @@ -78,7 +78,7 @@ private[spark] object UIData { var numCompleteTasks: Int = _ var completedIndices = new OpenHashSet[Int]() var numFailedTasks: Int = _ - var numKilledTasks: Int = _ + var reasonToNumKilled: Map[String, Int] = Map.empty var executorRunTime: Long = _ var executorCpuTime: Long = _ diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 4b4d2d10cb..2cb88919c8 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -390,6 +390,8 @@ private[spark] object JsonProtocol { ("Executor ID" -> executorId) ~ ("Exit Caused By App" -> exitCausedByApp) ~ ("Loss Reason" -> reason.map(_.toString)) + case taskKilled: TaskKilled => + ("Kill Reason" -> taskKilled.reason) case _ => Utils.emptyJson } ("Reason" -> reason) ~ json @@ -877,7 +879,10 @@ private[spark] object JsonProtocol { })) ExceptionFailure(className, description, stackTrace, fullStackTrace, None, accumUpdates) case `taskResultLost` => TaskResultLost - case `taskKilled` => TaskKilled + case `taskKilled` => + val killReason = Utils.jsonOption(json \ "Kill Reason") + .map(_.extract[String]).getOrElse("unknown reason") + TaskKilled(killReason) case `taskCommitDenied` => // Unfortunately, the `TaskCommitDenied` message was introduced in 1.3.0 but the JSON // de/serialization logic was not added until 1.5.1. 
To provide backward compatibility diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index d08a162fed..2c947556df 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -34,7 +34,7 @@ import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFor import org.scalatest.concurrent.Eventually import org.scalatest.Matchers._ -import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskStart} +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskEnd, SparkListenerTaskStart} import org.apache.spark.util.Utils @@ -540,6 +540,48 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu } } + // Launches one task that will run forever. Once the SparkListener detects the task has + // started, kill and re-schedule it. The second run of the task will complete immediately. + // If this test times out, then the first version of the task wasn't killed successfully. + test("Killing tasks") { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + + SparkContextSuite.isTaskStarted = false + SparkContextSuite.taskKilled = false + SparkContextSuite.taskSucceeded = false + + val listener = new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { + eventually(timeout(10.seconds)) { + assert(SparkContextSuite.isTaskStarted) + } + if (!SparkContextSuite.taskKilled) { + SparkContextSuite.taskKilled = true + sc.killTaskAttempt(taskStart.taskInfo.taskId, true, "first attempt will hang") + } + } + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + if (taskEnd.taskInfo.attemptNumber == 1 && taskEnd.reason == Success) { + SparkContextSuite.taskSucceeded = true + } + } + } + sc.addSparkListener(listener) + eventually(timeout(20.seconds)) { + sc.parallelize(1 to 1).foreach { x => + // first attempt will hang + if (!SparkContextSuite.isTaskStarted) { + SparkContextSuite.isTaskStarted = true + Thread.sleep(9999999) + } + // second attempt succeeds immediately + } + } + eventually(timeout(10.seconds)) { + assert(SparkContextSuite.taskSucceeded) + } + } + test("SPARK-19446: DebugFilesystem.assertNoOpenStreams should report " + "open streams to help debugging") { val fs = new DebugFilesystem() @@ -555,11 +597,12 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu assert(exc.getCause() != null) stream.close() } - } object SparkContextSuite { @volatile var cancelJob = false @volatile var cancelStage = false @volatile var isTaskStarted = false + @volatile var taskKilled = false + @volatile var taskSucceeded = false } diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index 8150fff2d0..f47e574b4f 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -110,14 +110,14 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug } // we know the task will be started, but not yet deserialized, because of the latches we // use in mockExecutorBackend. 
- executor.killAllTasks(true) + executor.killAllTasks(true, "test") executorSuiteHelper.latch2.countDown() if (!executorSuiteHelper.latch3.await(5, TimeUnit.SECONDS)) { fail("executor did not send second status update in time") } // `testFailedReason` should be `TaskKilled`; `taskState` should be `KILLED` - assert(executorSuiteHelper.testFailedReason === TaskKilled) + assert(executorSuiteHelper.testFailedReason === TaskKilled("test")) assert(executorSuiteHelper.taskState === TaskState.KILLED) } finally { diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index a9389003d5..a10941b579 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -126,6 +126,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou override def cancelTasks(stageId: Int, interruptThread: Boolean) { cancelledStages += stageId } + override def killTaskAttempt( + taskId: Long, interruptThread: Boolean, reason: String): Boolean = false override def setDAGScheduler(dagScheduler: DAGScheduler) = {} override def defaultParallelism() = 2 override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} @@ -552,6 +554,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou override def cancelTasks(stageId: Int, interruptThread: Boolean) { throw new UnsupportedOperationException } + override def killTaskAttempt( + taskId: Long, interruptThread: Boolean, reason: String): Boolean = { + throw new UnsupportedOperationException + } override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {} override def defaultParallelism(): Int = 2 override def executorHeartbeatReceived( diff --git a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala index 37c124a726..ba56af8215 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala @@ -79,6 +79,8 @@ private class DummyTaskScheduler extends TaskScheduler { override def stop(): Unit = {} override def submitTasks(taskSet: TaskSet): Unit = {} override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = {} + override def killTaskAttempt( + taskId: Long, interruptThread: Boolean, reason: String): Boolean = false override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {} override def defaultParallelism(): Int = 2 override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index 38b9d40329..e51e6a0d3f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -176,13 +176,13 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { assert(!outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter)) // The non-authorized committer fails outputCommitCoordinator.taskCompleted( - stage, partition, attemptNumber = nonAuthorizedCommitter, reason = TaskKilled) + stage, partition, attemptNumber = 
nonAuthorizedCommitter, reason = TaskKilled("test")) // New tasks should still not be able to commit because the authorized committer has not failed assert( !outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 1)) // The authorized committer now fails, clearing the lock outputCommitCoordinator.taskCompleted( - stage, partition, attemptNumber = authorizedCommitter, reason = TaskKilled) + stage, partition, attemptNumber = authorizedCommitter, reason = TaskKilled("test")) // A new task should now be allowed to become the authorized committer assert( outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 2)) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala index 398ac3d620..8103983c43 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala @@ -410,7 +410,8 @@ private[spark] abstract class MockBackend( } } - override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = { + override def killTask( + taskId: Long, executorId: String, interruptThread: Boolean, reason: String): Unit = { // We have to implement this b/c of SPARK-15385. // Its OK for this to be a no-op, because even if a backend does implement killTask, // it really can only be "best-effort" in any case, and the scheduler should be robust to that. diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 064af381a7..132caef097 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -677,7 +677,11 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2")) sched.initialize(new FakeSchedulerBackend() { - override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = {} + override def killTask( + taskId: Long, + executorId: String, + interruptThread: Boolean, + reason: String): Unit = {} }) // Keep track of the number of tasks that are resubmitted, @@ -935,7 +939,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg // Complete the speculative attempt for the running task manager.handleSuccessfulTask(4, createTaskResult(3, accumUpdatesByTask(3))) // Verify that it kills other running attempt - verify(sched.backend).killTask(3, "exec2", true) + verify(sched.backend).killTask(3, "exec2", true, "another attempt succeeded") // Because the SchedulerBackend was a mock, the 2nd copy of the task won't actually be // killed, so the FakeTaskScheduler is only told about the successful completion // of the speculated task. 
@@ -1023,14 +1027,14 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg manager.handleSuccessfulTask(speculativeTask.taskId, createTaskResult(3, accumUpdatesByTask(3))) // Verify that it kills other running attempt val origTask = originalTasks(speculativeTask.index) - verify(sched.backend).killTask(origTask.taskId, "exec2", true) + verify(sched.backend).killTask(origTask.taskId, "exec2", true, "another attempt succeeded") // Because the SchedulerBackend was a mock, the 2nd copy of the task won't actually be // killed, so the FakeTaskScheduler is only told about the successful completion // of the speculated task. assert(sched.endedTasks(3) === Success) // also because the scheduler is a mock, our manager isn't notified about the task killed event, // so we do that manually - manager.handleFailedTask(origTask.taskId, TaskState.KILLED, TaskKilled) + manager.handleFailedTask(origTask.taskId, TaskState.KILLED, TaskKilled("test")) // this task has "failed" 4 times, but one of them doesn't count, so keep running the stage assert(manager.tasksSuccessful === 4) assert(!manager.isZombie) @@ -1047,7 +1051,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg createTaskResult(3, accumUpdatesByTask(3))) // Verify that it kills other running attempt val origTask2 = originalTasks(speculativeTask2.index) - verify(sched.backend).killTask(origTask2.taskId, "exec2", true) + verify(sched.backend).killTask(origTask2.taskId, "exec2", true, "another attempt succeeded") assert(manager.tasksSuccessful === 5) assert(manager.isZombie) } @@ -1102,8 +1106,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg ExecutorLostFailure(taskDescs(1).executorId, exitCausedByApp = false, reason = None)) tsmSpy.handleFailedTask(taskDescs(2).taskId, TaskState.FAILED, TaskCommitDenied(0, 2, 0)) - tsmSpy.handleFailedTask(taskDescs(3).taskId, TaskState.KILLED, - TaskKilled) + tsmSpy.handleFailedTask(taskDescs(3).taskId, TaskState.KILLED, TaskKilled("test")) // Make sure that the blacklist ignored all of the task failures above, since they aren't // the fault of the executor where the task was running. diff --git a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala index 6335d905c0..c770fd5da7 100644 --- a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala @@ -110,7 +110,7 @@ class UIUtilsSuite extends SparkFunSuite { } test("SPARK-11906: Progress bar should not overflow because of speculative tasks") { - val generated = makeProgressBar(2, 3, 0, 0, 0, 4).head.child.filter(_.label == "div") + val generated = makeProgressBar(2, 3, 0, 0, Map.empty, 4).head.child.filter(_.label == "div") val expected = Seq(
,
diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index e3127da9a6..93964a2d56 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -274,8 +274,9 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with // Make sure killed tasks are accounted for correctly. listener.onTaskEnd( - SparkListenerTaskEnd(task.stageId, 0, taskType, TaskKilled, taskInfo, metrics)) - assert(listener.stageIdToData((task.stageId, 0)).numKilledTasks === 1) + SparkListenerTaskEnd( + task.stageId, 0, taskType, TaskKilled("test"), taskInfo, metrics)) + assert(listener.stageIdToData((task.stageId, 0)).reasonToNumKilled === Map("test" -> 1)) // Make sure we count success as success. listener.onTaskEnd( diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 9f76c74bce..a64dbeae47 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -164,7 +164,7 @@ class JsonProtocolSuite extends SparkFunSuite { testTaskEndReason(fetchMetadataFailed) testTaskEndReason(exceptionFailure) testTaskEndReason(TaskResultLost) - testTaskEndReason(TaskKilled) + testTaskEndReason(TaskKilled("test")) testTaskEndReason(TaskCommitDenied(2, 3, 4)) testTaskEndReason(ExecutorLostFailure("100", true, Some("Induced failure"))) testTaskEndReason(UnknownReason) @@ -676,7 +676,8 @@ private[spark] object JsonProtocolSuite extends Assertions { assert(r1.fullStackTrace === r2.fullStackTrace) assertSeqEquals[AccumulableInfo](r1.accumUpdates, r2.accumUpdates, (a, b) => a.equals(b)) case (TaskResultLost, TaskResultLost) => - case (TaskKilled, TaskKilled) => + case (r1: TaskKilled, r2: TaskKilled) => + assert(r1.reason == r2.reason) case (TaskCommitDenied(jobId1, partitionId1, attemptNumber1), TaskCommitDenied(jobId2, partitionId2, attemptNumber2)) => assert(jobId1 === jobId2) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 9925a8ba72..8ce9367c9b 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -66,6 +66,19 @@ object MimaExcludes { // [SPARK-17161] Removing Python-friendly constructors not needed ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.OneVsRestModel.this"), + // [SPARK-19820] Allow reason to be specified to task kill + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.TaskKilled$"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.productElement"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.productArity"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.canEqual"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.productIterator"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.countTowardsTaskFailures"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.productPrefix"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.toErrorString"), + ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.TaskKilled.toString"), + 
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.killTaskIfInterrupted"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.getKillReason"), + // [SPARK-19876] Add one time trigger, and improve Trigger APIs ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.sql.streaming.Trigger"), ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.streaming.ProcessingTime") diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala index b252539782..a086ec7ea2 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala @@ -104,7 +104,8 @@ private[spark] class MesosExecutorBackend logError("Received KillTask but executor was null") } else { // TODO: Determine the 'interruptOnCancel' property set for the given job. - executor.killTask(t.getValue.toLong, interruptThread = false) + executor.killTask( + t.getValue.toLong, interruptThread = false, reason = "killed by mesos") } } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala index f198f8893b..735c879c63 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala @@ -428,7 +428,8 @@ private[spark] class MesosFineGrainedSchedulerBackend( recordSlaveLost(d, slaveId, ExecutorExited(status, exitCausedByApp = true)) } - override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = { + override def killTask( + taskId: Long, executorId: String, interruptThread: Boolean, reason: String): Unit = { schedulerDriver.killTask( TaskID.newBuilder() .setValue(taskId.toString).build() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index a89d172a91..9df20731c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -101,9 +101,7 @@ class FileScanRDD( // Kill the task in case it has been marked as killed. This logic is from // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order // to avoid performance overhead. 
- if (context.isInterrupted()) { - throw new TaskKilledException - } + context.killTaskIfInterrupted() (currentIterator != null && currentIterator.hasNext) || nextIterator() } def next(): Object = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala index 1352ca1c4c..70b4bb466c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala @@ -97,7 +97,7 @@ private[ui] abstract class BatchTableBase(tableId: String, batchInterval: Long) completed = batch.numCompletedOutputOp, failed = batch.numFailedOutputOp, skipped = 0, - killed = 0, + reasonToNumKilled = Map.empty, total = batch.outputOperations.size) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala index 1a87fc790f..f55af6a5cc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala @@ -146,7 +146,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { completed = sparkJob.numCompletedTasks, failed = sparkJob.numFailedTasks, skipped = sparkJob.numSkippedTasks, - killed = sparkJob.numKilledTasks, + reasonToNumKilled = sparkJob.reasonToNumKilled, total = sparkJob.numTasks - sparkJob.numSkippedTasks) }
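For reference, a minimal usage sketch of the user-facing API introduced by this patch. SparkContext.killTaskAttempt, the SparkListener callbacks, and the TaskKilled(reason) end reason are taken from the diff above; the object name, the example job, and the local master are illustrative assumptions, and the kill itself is best-effort (the attempt may finish before the kill request reaches the executor):

import org.apache.spark.{SparkConf, SparkContext, TaskKilled}
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd, SparkListenerTaskStart}

object KillTaskAttemptExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("kill-example").setMaster("local"))
    sc.addSparkListener(new SparkListener {
      override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
        // Kill only the first attempt of each task; the task is then rescheduled and retried.
        if (taskStart.taskInfo.attemptNumber == 0) {
          sc.killTaskAttempt(taskStart.taskInfo.taskId, interruptThread = true,
            reason = "demo: kill the first attempt")
        }
      }
      override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = taskEnd.reason match {
        // TaskKilled is now a case class carrying the reason supplied above.
        case TaskKilled(reason) => println(s"attempt killed: $reason")
        case _ =>
      }
    })
    sc.parallelize(1 to 10).count()
    sc.stop()
  }
}

Because TaskKilled does not count toward task failures, the killed attempts are simply rescheduled and the job still completes.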