[SPARK-19820][CORE] Add interface to kill tasks w/ a reason

This commit adds a `killTaskAttempt` method to SparkContext, allowing users to
kill a running task attempt so that it can be re-scheduled elsewhere.

This also refactors the task kill path to allow a reason to be specified for the kill. The reason is propagated opaquely through events and shows up in the UI automatically as `(N killed: $reason)` and `TaskKilled ($reason)`. Without this change, there is no way to surface this feedback to the user through the UI.

Currently used reasons are "stage cancelled", "another attempt succeeded", and "killed via SparkContext.killTaskAttempt". The user can also specify a custom reason through `SparkContext.killTaskAttempt`.
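
For example, a task attempt can be killed from the driver with a custom reason (a minimal sketch; the task ID shown is illustrative and would normally be read from the Spark UI or `SparkListener.onTaskStart`):

```scala
// Hypothetical usage: kill a running task attempt with a custom reason.
val killed = sc.killTaskAttempt(
  taskId = 42L, // illustrative ID
  interruptThread = true,
  reason = "straggler, rescheduling elsewhere")
if (!killed) {
  println("No running task attempt with that ID was found")
}
```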

cc rxin

In the stage overview UI the reasons are summarized:
![1](https://cloud.githubusercontent.com/assets/14922/23929209/a83b2862-08e1-11e7-8b3e-ae1967bbe2e5.png)

Within the stage UI you can see individual task kill reasons:
![2](https://cloud.githubusercontent.com/assets/14922/23929200/9a798692-08e1-11e7-8697-72b27ad8a287.png)

Existing tests; also manually killed some stages in the UI and verified the messages appear as expected.

Author: Eric Liang <ekl@databricks.com>
Author: Eric Liang <ekl@google.com>

Closes #17166 from ericl/kill-reason.
Eric Liang 2017-03-23 23:30:40 -07:00 committed by Kay Ousterhout
parent 19596c28b6
commit 8e558041aa
43 changed files with 289 additions and 115 deletions


@ -23,7 +23,6 @@ import java.util.LinkedList;
import org.apache.avro.reflect.Nullable;
import org.apache.spark.TaskContext;
import org.apache.spark.TaskKilledException;
import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.unsafe.Platform;
@ -291,8 +290,8 @@ public final class UnsafeInMemorySorter {
// to avoid performance overhead. This check is added here in `loadNext()` instead of in
// `hasNext()` because it's technically possible for the caller to be relying on
// `getNumRecords()` instead of `hasNext()` to know when to stop.
if (taskContext != null && taskContext.isInterrupted()) {
throw new TaskKilledException();
if (taskContext != null) {
taskContext.killTaskIfInterrupted();
}
// This pointer points to a 4-byte record length, followed by the record's bytes
final long recordPointer = array.get(offset + position);


@ -24,7 +24,6 @@ import com.google.common.io.Closeables;
import org.apache.spark.SparkEnv;
import org.apache.spark.TaskContext;
import org.apache.spark.TaskKilledException;
import org.apache.spark.io.NioBufferedFileInputStream;
import org.apache.spark.serializer.SerializerManager;
import org.apache.spark.storage.BlockId;
@ -102,8 +101,8 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implemen
// to avoid performance overhead. This check is added here in `loadNext()` instead of in
// `hasNext()` because it's technically possible for the caller to be relying on
// `getNumRecords()` instead of `hasNext()` to know when to stop.
if (taskContext != null && taskContext.isInterrupted()) {
throw new TaskKilledException();
if (taskContext != null) {
taskContext.killTaskIfInterrupted();
}
recordLength = din.readInt();
keyPrefix = din.readLong();


@ -33,12 +33,9 @@ class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator
// is allowed. The assumption is that Thread.interrupted does not have a memory fence in read
// (just a volatile field in C), while context.interrupted is a volatile in the JVM, which
// introduces an expensive read fence.
if (context.isInterrupted) {
throw new TaskKilledException
} else {
context.killTaskIfInterrupted()
delegate.hasNext
}
}
def next(): T = delegate.next()
}


@ -2249,6 +2249,24 @@ class SparkContext(config: SparkConf) extends Logging {
dagScheduler.cancelStage(stageId, None)
}
/**
* Kill and reschedule the given task attempt. Task ids can be obtained from the Spark UI
* or through SparkListener.onTaskStart.
*
* @param taskId the task ID to kill. This id uniquely identifies the task attempt.
* @param interruptThread whether to interrupt the thread running the task.
* @param reason the reason for killing the task, which should be a short string. If a task
* is killed multiple times with different reasons, only one reason will be reported.
*
* @return Whether the task was successfully killed.
*/
def killTaskAttempt(
taskId: Long,
interruptThread: Boolean = true,
reason: String = "killed via SparkContext.killTaskAttempt"): Boolean = {
dagScheduler.killTaskAttempt(taskId, interruptThread, reason)
}
/**
* Clean a closure to make it ready to be serialized and sent to tasks
* (removes unreferenced variables in $outer's, updates REPL variables)


@ -184,6 +184,16 @@ abstract class TaskContext extends Serializable {
@DeveloperApi
def getMetricsSources(sourceName: String): Seq[Source]
/**
* If the task is interrupted, throws TaskKilledException with the reason for the interrupt.
*/
private[spark] def killTaskIfInterrupted(): Unit
/**
* If the task is interrupted, the reason this task was killed, otherwise None.
*/
private[spark] def getKillReason(): Option[String]
/**
* Returns the manager for this task's managed memory.
*/
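
Since `killTaskIfInterrupted` and `getKillReason` are `private[spark]`, user task code that wants to cooperate with kills can still poll the public `isInterrupted()` flag. A minimal sketch (the `rdd` and `process` names are hypothetical):

```scala
import org.apache.spark.{TaskContext, TaskKilledException}

// Sketch: long-running task code that honors kill requests cooperatively.
rdd.foreachPartition { iter =>
  val ctx = TaskContext.get()
  iter.foreach { row =>
    if (ctx.isInterrupted()) {
      // The executor catches TaskKilledException and reports the task as killed.
      throw new TaskKilledException("cooperative kill")
    }
    process(row) // hypothetical per-record work
  }
}
```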


@ -59,8 +59,8 @@ private[spark] class TaskContextImpl(
/** List of callback functions to execute when the task fails. */
@transient private val onFailureCallbacks = new ArrayBuffer[TaskFailureListener]
// Whether the corresponding task has been killed.
@volatile private var interrupted: Boolean = false
// If defined, the corresponding task has been killed and this option contains the reason.
@volatile private var reasonIfKilled: Option[String] = None
// Whether the task has completed.
private var completed: Boolean = false
@ -140,8 +140,19 @@ private[spark] class TaskContextImpl(
}
/** Marks the task for interruption, i.e. cancellation. */
private[spark] def markInterrupted(): Unit = {
interrupted = true
private[spark] def markInterrupted(reason: String): Unit = {
reasonIfKilled = Some(reason)
}
private[spark] override def killTaskIfInterrupted(): Unit = {
val reason = reasonIfKilled
if (reason.isDefined) {
throw new TaskKilledException(reason.get)
}
}
private[spark] override def getKillReason(): Option[String] = {
reasonIfKilled
}
@GuardedBy("this")
@ -149,7 +160,7 @@ private[spark] class TaskContextImpl(
override def isRunningLocally(): Boolean = false
override def isInterrupted(): Boolean = interrupted
override def isInterrupted(): Boolean = reasonIfKilled.isDefined
override def getLocalProperty(key: String): String = localProperties.getProperty(key)


@ -212,8 +212,8 @@ case object TaskResultLost extends TaskFailedReason {
* Task was killed intentionally and needs to be rescheduled.
*/
@DeveloperApi
case object TaskKilled extends TaskFailedReason {
override def toErrorString: String = "TaskKilled (killed intentionally)"
case class TaskKilled(reason: String) extends TaskFailedReason {
override def toErrorString: String = s"TaskKilled ($reason)"
override def countTowardsTaskFailures: Boolean = false
}
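
Since `TaskKilled` is now a case class carrying the reason, listeners can pattern-match on it. A sketch (the listener name is illustrative):

```scala
import org.apache.spark.TaskKilled
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

// Sketch: log each kill reason as tasks end.
class KillReasonListener extends SparkListener {
  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
    taskEnd.reason match {
      case TaskKilled(reason) =>
        println(s"Task ${taskEnd.taskInfo.taskId} was killed: $reason")
      case _ => // other end reasons are not kills; ignore
    }
  }
}
```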


@ -24,4 +24,6 @@ import org.apache.spark.annotation.DeveloperApi
* Exception thrown when a task is explicitly killed (i.e., task failure is expected).
*/
@DeveloperApi
class TaskKilledException extends RuntimeException
class TaskKilledException(val reason: String) extends RuntimeException {
def this() = this("unknown reason")
}
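
The auxiliary no-arg constructor keeps existing callers source-compatible; a quick sketch:

```scala
import org.apache.spark.TaskKilledException

val legacy = new TaskKilledException()          // reason == "unknown reason"
val custom = new TaskKilledException("preempted") // illustrative custom reason
```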


@ -215,7 +215,7 @@ private[spark] class PythonRunner(
case e: Exception if context.isInterrupted =>
logDebug("Exception thrown after task interruption", e)
throw new TaskKilledException
throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason"))
case e: Exception if env.isStopped =>
logDebug("Exception thrown after context is stopped", e)


@ -97,11 +97,11 @@ private[spark] class CoarseGrainedExecutorBackend(
executor.launchTask(this, taskDesc)
}
case KillTask(taskId, _, interruptThread) =>
case KillTask(taskId, _, interruptThread, reason) =>
if (executor == null) {
exitExecutor(1, "Received KillTask command but executor was null")
} else {
executor.killTask(taskId, interruptThread)
executor.killTask(taskId, interruptThread, reason)
}
case StopExecutor =>


@ -158,7 +158,7 @@ private[spark] class Executor(
threadPool.execute(tr)
}
def killTask(taskId: Long, interruptThread: Boolean): Unit = {
def killTask(taskId: Long, interruptThread: Boolean, reason: String): Unit = {
val taskRunner = runningTasks.get(taskId)
if (taskRunner != null) {
if (taskReaperEnabled) {
@ -168,7 +168,8 @@ private[spark] class Executor(
case Some(existingReaper) => interruptThread && !existingReaper.interruptThread
}
if (shouldCreateReaper) {
val taskReaper = new TaskReaper(taskRunner, interruptThread = interruptThread)
val taskReaper = new TaskReaper(
taskRunner, interruptThread = interruptThread, reason = reason)
taskReaperForTask(taskId) = taskReaper
Some(taskReaper)
} else {
@ -178,7 +179,7 @@ private[spark] class Executor(
// Execute the TaskReaper from outside of the synchronized block.
maybeNewTaskReaper.foreach(taskReaperPool.execute)
} else {
taskRunner.kill(interruptThread = interruptThread)
taskRunner.kill(interruptThread = interruptThread, reason = reason)
}
}
}
@ -189,8 +190,9 @@ private[spark] class Executor(
* tasks instead of taking the JVM down.
* @param interruptThread whether to interrupt the task thread
*/
def killAllTasks(interruptThread: Boolean) : Unit = {
runningTasks.keys().asScala.foreach(t => killTask(t, interruptThread = interruptThread))
def killAllTasks(interruptThread: Boolean, reason: String) : Unit = {
runningTasks.keys().asScala.foreach(t =>
killTask(t, interruptThread = interruptThread, reason = reason))
}
def stop(): Unit = {
@ -217,8 +219,8 @@ private[spark] class Executor(
val threadName = s"Executor task launch worker for task $taskId"
private val taskName = taskDescription.name
/** Whether this task has been killed. */
@volatile private var killed = false
/** If specified, this task has been killed and this option contains the reason. */
@volatile private var reasonIfKilled: Option[String] = None
@volatile private var threadId: Long = -1
@ -239,13 +241,13 @@ private[spark] class Executor(
*/
@volatile var task: Task[Any] = _
def kill(interruptThread: Boolean): Unit = {
logInfo(s"Executor is trying to kill $taskName (TID $taskId)")
killed = true
def kill(interruptThread: Boolean, reason: String): Unit = {
logInfo(s"Executor is trying to kill $taskName (TID $taskId), reason: $reason")
reasonIfKilled = Some(reason)
if (task != null) {
synchronized {
if (!finished) {
task.kill(interruptThread)
task.kill(interruptThread, reason)
}
}
}
@ -296,12 +298,13 @@ private[spark] class Executor(
// If this task has been killed before we deserialized it, let's quit now. Otherwise,
// continue executing the task.
if (killed) {
val killReason = reasonIfKilled
if (killReason.isDefined) {
// Throw an exception rather than returning, because returning within a try{} block
// causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
// exception will be caught by the catch block, leading to an incorrect ExceptionFailure
// for the task.
throw new TaskKilledException
throw new TaskKilledException(killReason.get)
}
logDebug("Task " + taskId + "'s epoch is " + task.epoch)
@ -358,9 +361,7 @@ private[spark] class Executor(
} else 0L
// If the task has been killed, let's fail it.
if (task.killed) {
throw new TaskKilledException
}
task.context.killTaskIfInterrupted()
val resultSer = env.serializer.newInstance()
val beforeSerialization = System.currentTimeMillis()
@ -426,15 +427,17 @@ private[spark] class Executor(
setTaskFinishedAndClearInterruptStatus()
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
case _: TaskKilledException =>
logInfo(s"Executor killed $taskName (TID $taskId)")
case t: TaskKilledException =>
logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}")
setTaskFinishedAndClearInterruptStatus()
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason)))
case _: InterruptedException if task.killed =>
logInfo(s"Executor interrupted and killed $taskName (TID $taskId)")
case _: InterruptedException if task.reasonIfKilled.isDefined =>
val killReason = task.reasonIfKilled.getOrElse("unknown reason")
logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason")
setTaskFinishedAndClearInterruptStatus()
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
execBackend.statusUpdate(
taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason)))
case CausedBy(cDE: CommitDeniedException) =>
val reason = cDE.toTaskFailedReason
@ -512,7 +515,8 @@ private[spark] class Executor(
*/
private class TaskReaper(
taskRunner: TaskRunner,
val interruptThread: Boolean)
val interruptThread: Boolean,
val reason: String)
extends Runnable {
private[this] val taskId: Long = taskRunner.taskId
@ -533,7 +537,7 @@ private[spark] class Executor(
// Only attempt to kill the task once. If interruptThread = false then a second kill
// attempt would be a no-op and if interruptThread = true then it may not be safe or
// effective to interrupt multiple times:
taskRunner.kill(interruptThread = interruptThread)
taskRunner.kill(interruptThread = interruptThread, reason = reason)
// Monitor the killed task until it exits. The synchronization logic here is complicated
// because we don't want to synchronize on the taskRunner while possibly taking a thread
// dump, but we also need to be careful to avoid races between checking whether the task


@ -738,6 +738,15 @@ class DAGScheduler(
eventProcessLoop.post(StageCancelled(stageId, reason))
}
/**
* Kill a given task. It will be retried.
*
* @return Whether the task was successfully killed.
*/
def killTaskAttempt(taskId: Long, interruptThread: Boolean, reason: String): Boolean = {
taskScheduler.killTaskAttempt(taskId, interruptThread, reason)
}
/**
* Resubmit any failed stages. Ordinarily called after a small amount of time has passed since
* the last fetch failure.
@ -1353,7 +1362,7 @@ class DAGScheduler(
case TaskResultLost =>
// Do nothing here; the TaskScheduler handles these failures and resubmits the task.
case _: ExecutorLostFailure | TaskKilled | UnknownReason =>
case _: ExecutorLostFailure | _: TaskKilled | UnknownReason =>
// Unrecognized failure - also do nothing. If the task fails repeatedly, the TaskScheduler
// will abort the job.
}


@ -30,8 +30,21 @@ private[spark] trait SchedulerBackend {
def reviveOffers(): Unit
def defaultParallelism(): Int
def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit =
/**
* Requests that an executor kills a running task.
*
* @param taskId Id of the task.
* @param executorId Id of the executor the task is running on.
* @param interruptThread Whether the executor should interrupt the task thread.
* @param reason The reason for the task kill.
*/
def killTask(
taskId: Long,
executorId: String,
interruptThread: Boolean,
reason: String): Unit =
throw new UnsupportedOperationException
def isReady(): Boolean = true
/**


@ -89,8 +89,8 @@ private[spark] abstract class Task[T](
TaskContext.setTaskContext(context)
taskThread = Thread.currentThread()
if (_killed) {
kill(interruptThread = false)
if (_reasonIfKilled != null) {
kill(interruptThread = false, _reasonIfKilled)
}
new CallerContext(
@ -158,17 +158,17 @@ private[spark] abstract class Task[T](
// The actual Thread on which the task is running, if any. Initialized in run().
@volatile @transient private var taskThread: Thread = _
// A flag to indicate whether the task is killed. This is used in case context is not yet
// initialized when kill() is invoked.
@volatile @transient private var _killed = false
// If non-null, this task has been killed and the reason is as specified. This is used in case
// context is not yet initialized when kill() is invoked.
@volatile @transient private var _reasonIfKilled: String = null
protected var _executorDeserializeTime: Long = 0
protected var _executorDeserializeCpuTime: Long = 0
/**
* Whether the task has been killed.
* If defined, this task has been killed and this option contains the reason.
*/
def killed: Boolean = _killed
def reasonIfKilled: Option[String] = Option(_reasonIfKilled)
/**
* Returns the amount of time spent deserializing the RDD and function to be run.
@ -201,10 +201,11 @@ private[spark] abstract class Task[T](
* be called multiple times.
* If interruptThread is true, we will also call Thread.interrupt() on the Task's executor thread.
*/
def kill(interruptThread: Boolean) {
_killed = true
def kill(interruptThread: Boolean, reason: String) {
require(reason != null)
_reasonIfKilled = reason
if (context != null) {
context.markInterrupted()
context.markInterrupted(reason)
}
if (interruptThread && taskThread != null) {
taskThread.interrupt()


@ -54,6 +54,13 @@ private[spark] trait TaskScheduler {
// Cancel a stage.
def cancelTasks(stageId: Int, interruptThread: Boolean): Unit
/**
* Kills a task attempt.
*
* @return Whether the task was successfully killed.
*/
def killTaskAttempt(taskId: Long, interruptThread: Boolean, reason: String): Boolean
// Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called.
def setDAGScheduler(dagScheduler: DAGScheduler): Unit


@ -241,7 +241,7 @@ private[spark] class TaskSchedulerImpl private[scheduler](
// simply abort the stage.
tsm.runningTasksSet.foreach { tid =>
val execId = taskIdToExecutorId(tid)
backend.killTask(tid, execId, interruptThread)
backend.killTask(tid, execId, interruptThread, reason = "stage cancelled")
}
tsm.abort("Stage %s cancelled".format(stageId))
logInfo("Stage %d was cancelled".format(stageId))
@ -249,6 +249,18 @@ private[spark] class TaskSchedulerImpl private[scheduler](
}
}
override def killTaskAttempt(taskId: Long, interruptThread: Boolean, reason: String): Boolean = {
logInfo(s"Killing task $taskId: $reason")
val execId = taskIdToExecutorId.get(taskId)
if (execId.isDefined) {
backend.killTask(taskId, execId.get, interruptThread, reason)
true
} else {
logWarning(s"Could not kill task $taskId because no task with that ID was found.")
false
}
}
/**
* Called to indicate that all task attempts (including speculated tasks) associated with the
* given TaskSetManager have completed, so state associated with the TaskSetManager should be
@ -469,7 +481,7 @@ private[spark] class TaskSchedulerImpl private[scheduler](
taskState: TaskState,
reason: TaskFailedReason): Unit = synchronized {
taskSetManager.handleFailedTask(tid, taskState, reason)
if (!taskSetManager.isZombie && taskState != TaskState.KILLED) {
if (!taskSetManager.isZombie && !taskSetManager.someAttemptSucceeded(tid)) {
// Need to revive offers again now that the task set manager state has been updated to
// reflect failed tasks that need to be re-run.
backend.reviveOffers()


@ -101,6 +101,10 @@ private[spark] class TaskSetManager(
override def runningTasks: Int = runningTasksSet.size
def someAttemptSucceeded(tid: Long): Boolean = {
successful(taskInfos(tid).index)
}
// True once no more tasks should be launched for this task set manager. TaskSetManagers enter
// the zombie state once at least one attempt of each task has completed successfully, or if the
// task set is aborted (for example, because it was killed). TaskSetManagers remain in the zombie
@ -722,7 +726,11 @@ private[spark] class TaskSetManager(
logInfo(s"Killing attempt ${attemptInfo.attemptNumber} for task ${attemptInfo.id} " +
s"in stage ${taskSet.id} (TID ${attemptInfo.taskId}) on ${attemptInfo.host} " +
s"as the attempt ${info.attemptNumber} succeeded on ${info.host}")
sched.backend.killTask(attemptInfo.taskId, attemptInfo.executorId, true)
sched.backend.killTask(
attemptInfo.taskId,
attemptInfo.executorId,
interruptThread = true,
reason = "another attempt succeeded")
}
if (!successful(index)) {
tasksSuccessful += 1


@ -40,7 +40,7 @@ private[spark] object CoarseGrainedClusterMessages {
// Driver to executors
case class LaunchTask(data: SerializableBuffer) extends CoarseGrainedClusterMessage
case class KillTask(taskId: Long, executor: String, interruptThread: Boolean)
case class KillTask(taskId: Long, executor: String, interruptThread: Boolean, reason: String)
extends CoarseGrainedClusterMessage
case class KillExecutorsOnHost(host: String)


@ -132,10 +132,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
case ReviveOffers =>
makeOffers()
case KillTask(taskId, executorId, interruptThread) =>
case KillTask(taskId, executorId, interruptThread, reason) =>
executorDataMap.get(executorId) match {
case Some(executorInfo) =>
executorInfo.executorEndpoint.send(KillTask(taskId, executorId, interruptThread))
executorInfo.executorEndpoint.send(
KillTask(taskId, executorId, interruptThread, reason))
case None =>
// Ignoring the task kill since the executor is not registered.
logWarning(s"Attempted to kill task $taskId for unknown executor $executorId.")
@ -428,8 +429,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
driverEndpoint.send(ReviveOffers)
}
override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) {
driverEndpoint.send(KillTask(taskId, executorId, interruptThread))
override def killTask(
taskId: Long, executorId: String, interruptThread: Boolean, reason: String) {
driverEndpoint.send(KillTask(taskId, executorId, interruptThread, reason))
}
override def defaultParallelism(): Int = {


@ -34,7 +34,7 @@ private case class ReviveOffers()
private case class StatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer)
private case class KillTask(taskId: Long, interruptThread: Boolean)
private case class KillTask(taskId: Long, interruptThread: Boolean, reason: String)
private case class StopExecutor()
@ -70,8 +70,8 @@ private[spark] class LocalEndpoint(
reviveOffers()
}
case KillTask(taskId, interruptThread) =>
executor.killTask(taskId, interruptThread)
case KillTask(taskId, interruptThread, reason) =>
executor.killTask(taskId, interruptThread, reason)
}
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
@ -143,8 +143,9 @@ private[spark] class LocalSchedulerBackend(
override def defaultParallelism(): Int =
scheduler.conf.getInt("spark.default.parallelism", totalCores)
override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) {
localEndpoint.send(KillTask(taskId, interruptThread))
override def killTask(
taskId: Long, executorId: String, interruptThread: Boolean, reason: String) {
localEndpoint.send(KillTask(taskId, interruptThread, reason))
}
override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) {


@ -342,7 +342,7 @@ private[spark] object UIUtils extends Logging {
completed: Int,
failed: Int,
skipped: Int,
killed: Int,
reasonToNumKilled: Map[String, Int],
total: Int): Seq[Node] = {
val completeWidth = "width: %s%%".format((completed.toDouble/total)*100)
// started + completed can be > total when there are speculative tasks
@ -354,7 +354,10 @@ private[spark] object UIUtils extends Logging {
{completed}/{total}
{ if (failed > 0) s"($failed failed)" }
{ if (skipped > 0) s"($skipped skipped)" }
{ if (killed > 0) s"($killed killed)" }
{ reasonToNumKilled.toSeq.sortBy(-_._2).map {
case (reason, count) => s"($count killed: $reason)"
}
}
</span>
<div class="bar bar-completed" style={completeWidth}></div>
<div class="bar bar-running" style={startWidth}></div>
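For a concrete sense of the rendering, a hypothetical map produces one summary fragment per reason, ordered by descending count:

```scala
// Sketch: how the kill summary in the progress bar is formatted.
val reasonToNumKilled = Map("stage cancelled" -> 2, "another attempt succeeded" -> 1)
val fragments = reasonToNumKilled.toSeq.sortBy(-_._2).map {
  case (reason, count) => s"($count killed: $reason)"
}
// fragments == Seq("(2 killed: stage cancelled)", "(1 killed: another attempt succeeded)")
```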


@ -630,8 +630,8 @@ private[ui] class JobPagedTable(
</td>
<td class="progress-cell">
{UIUtils.makeProgressBar(started = job.numActiveTasks, completed = job.numCompletedTasks,
failed = job.numFailedTasks, skipped = job.numSkippedTasks, killed = job.numKilledTasks,
total = job.numTasks - job.numSkippedTasks)}
failed = job.numFailedTasks, skipped = job.numSkippedTasks,
reasonToNumKilled = job.reasonToNumKilled, total = job.numTasks - job.numSkippedTasks)}
</td>
</tr>
}


@ -133,9 +133,9 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: Stage
</td>
<td>{executorIdToAddress.getOrElse(k, "CANNOT FIND ADDRESS")}</td>
<td sorttable_customkey={v.taskTime.toString}>{UIUtils.formatDuration(v.taskTime)}</td>
<td>{v.failedTasks + v.succeededTasks + v.killedTasks}</td>
<td>{v.failedTasks + v.succeededTasks + v.reasonToNumKilled.map(_._2).sum}</td>
<td>{v.failedTasks}</td>
<td>{v.killedTasks}</td>
<td>{v.reasonToNumKilled.map(_._2).sum}</td>
<td>{v.succeededTasks}</td>
{if (stageData.hasInput) {
<td sorttable_customkey={v.inputBytes.toString}>


@ -371,8 +371,9 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
taskEnd.reason match {
case Success =>
execSummary.succeededTasks += 1
case TaskKilled =>
execSummary.killedTasks += 1
case kill: TaskKilled =>
execSummary.reasonToNumKilled = execSummary.reasonToNumKilled.updated(
kill.reason, execSummary.reasonToNumKilled.getOrElse(kill.reason, 0) + 1)
case _ =>
execSummary.failedTasks += 1
}
@ -385,9 +386,10 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
stageData.completedIndices.add(info.index)
stageData.numCompleteTasks += 1
None
case TaskKilled =>
stageData.numKilledTasks += 1
Some(TaskKilled.toErrorString)
case kill: TaskKilled =>
stageData.reasonToNumKilled = stageData.reasonToNumKilled.updated(
kill.reason, stageData.reasonToNumKilled.getOrElse(kill.reason, 0) + 1)
Some(kill.toErrorString)
case e: ExceptionFailure => // Handle ExceptionFailure because we might have accumUpdates
stageData.numFailedTasks += 1
Some(e.toErrorString)
@ -422,8 +424,9 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
taskEnd.reason match {
case Success =>
jobData.numCompletedTasks += 1
case TaskKilled =>
jobData.numKilledTasks += 1
case kill: TaskKilled =>
jobData.reasonToNumKilled = jobData.reasonToNumKilled.updated(
kill.reason, jobData.reasonToNumKilled.getOrElse(kill.reason, 0) + 1)
case _ =>
jobData.numFailedTasks += 1
}


@ -300,7 +300,7 @@ private[ui] class StagePagedTable(
<td class="progress-cell">
{UIUtils.makeProgressBar(started = stageData.numActiveTasks,
completed = stageData.completedIndices.size, failed = stageData.numFailedTasks,
skipped = 0, killed = stageData.numKilledTasks, total = info.numTasks)}
skipped = 0, reasonToNumKilled = stageData.reasonToNumKilled, total = info.numTasks)}
</td>
<td>{data.inputReadWithUnit}</td>
<td>{data.outputWriteWithUnit}</td>


@ -32,7 +32,7 @@ private[spark] object UIData {
var taskTime : Long = 0
var failedTasks : Int = 0
var succeededTasks : Int = 0
var killedTasks : Int = 0
var reasonToNumKilled : Map[String, Int] = Map.empty
var inputBytes : Long = 0
var inputRecords : Long = 0
var outputBytes : Long = 0
@ -64,7 +64,7 @@ private[spark] object UIData {
var numCompletedTasks: Int = 0,
var numSkippedTasks: Int = 0,
var numFailedTasks: Int = 0,
var numKilledTasks: Int = 0,
var reasonToNumKilled: Map[String, Int] = Map.empty,
/* Stages */
var numActiveStages: Int = 0,
// This needs to be a set instead of a simple count to prevent double-counting of rerun stages:
@ -78,7 +78,7 @@ private[spark] object UIData {
var numCompleteTasks: Int = _
var completedIndices = new OpenHashSet[Int]()
var numFailedTasks: Int = _
var numKilledTasks: Int = _
var reasonToNumKilled: Map[String, Int] = Map.empty
var executorRunTime: Long = _
var executorCpuTime: Long = _


@ -390,6 +390,8 @@ private[spark] object JsonProtocol {
("Executor ID" -> executorId) ~
("Exit Caused By App" -> exitCausedByApp) ~
("Loss Reason" -> reason.map(_.toString))
case taskKilled: TaskKilled =>
("Kill Reason" -> taskKilled.reason)
case _ => Utils.emptyJson
}
("Reason" -> reason) ~ json
@ -877,7 +879,10 @@ private[spark] object JsonProtocol {
}))
ExceptionFailure(className, description, stackTrace, fullStackTrace, None, accumUpdates)
case `taskResultLost` => TaskResultLost
case `taskKilled` => TaskKilled
case `taskKilled` =>
val killReason = Utils.jsonOption(json \ "Kill Reason")
.map(_.extract[String]).getOrElse("unknown reason")
TaskKilled(killReason)
case `taskCommitDenied` =>
// Unfortunately, the `TaskCommitDenied` message was introduced in 1.3.0 but the JSON
// de/serialization logic was not added until 1.5.1. To provide backward compatibility
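
In event logs, the end reason now carries the kill reason, and logs written before this change still deserialize (a sketch of the shape; field order is illustrative):

```scala
// Sketch: serialized form of a killed-task end reason after this change.
val killedJson = """{"Reason": "TaskKilled", "Kill Reason": "another attempt succeeded"}"""
// Older logs that lack "Kill Reason" deserialize to TaskKilled("unknown reason").
```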


@ -34,7 +34,7 @@ import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFor
import org.scalatest.concurrent.Eventually
import org.scalatest.Matchers._
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskStart}
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskEnd, SparkListenerTaskStart}
import org.apache.spark.util.Utils
@ -540,6 +540,48 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
}
}
// Launches one task that will run forever. Once the SparkListener detects the task has
// started, kill and re-schedule it. The second run of the task will complete immediately.
// If this test times out, then the first version of the task wasn't killed successfully.
test("Killing tasks") {
sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
SparkContextSuite.isTaskStarted = false
SparkContextSuite.taskKilled = false
SparkContextSuite.taskSucceeded = false
val listener = new SparkListener {
override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
eventually(timeout(10.seconds)) {
assert(SparkContextSuite.isTaskStarted)
}
if (!SparkContextSuite.taskKilled) {
SparkContextSuite.taskKilled = true
sc.killTaskAttempt(taskStart.taskInfo.taskId, true, "first attempt will hang")
}
}
override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
if (taskEnd.taskInfo.attemptNumber == 1 && taskEnd.reason == Success) {
SparkContextSuite.taskSucceeded = true
}
}
}
sc.addSparkListener(listener)
eventually(timeout(20.seconds)) {
sc.parallelize(1 to 1).foreach { x =>
// first attempt will hang
if (!SparkContextSuite.isTaskStarted) {
SparkContextSuite.isTaskStarted = true
Thread.sleep(9999999)
}
// second attempt succeeds immediately
}
}
eventually(timeout(10.seconds)) {
assert(SparkContextSuite.taskSucceeded)
}
}
test("SPARK-19446: DebugFilesystem.assertNoOpenStreams should report " +
"open streams to help debugging") {
val fs = new DebugFilesystem()
@ -555,11 +597,12 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
assert(exc.getCause() != null)
stream.close()
}
}
object SparkContextSuite {
@volatile var cancelJob = false
@volatile var cancelStage = false
@volatile var isTaskStarted = false
@volatile var taskKilled = false
@volatile var taskSucceeded = false
}


@ -110,14 +110,14 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
}
// we know the task will be started, but not yet deserialized, because of the latches we
// use in mockExecutorBackend.
executor.killAllTasks(true)
executor.killAllTasks(true, "test")
executorSuiteHelper.latch2.countDown()
if (!executorSuiteHelper.latch3.await(5, TimeUnit.SECONDS)) {
fail("executor did not send second status update in time")
}
// `testFailedReason` should be `TaskKilled`; `taskState` should be `KILLED`
assert(executorSuiteHelper.testFailedReason === TaskKilled)
assert(executorSuiteHelper.testFailedReason === TaskKilled("test"))
assert(executorSuiteHelper.taskState === TaskState.KILLED)
}
finally {


@ -126,6 +126,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
override def cancelTasks(stageId: Int, interruptThread: Boolean) {
cancelledStages += stageId
}
override def killTaskAttempt(
taskId: Long, interruptThread: Boolean, reason: String): Boolean = false
override def setDAGScheduler(dagScheduler: DAGScheduler) = {}
override def defaultParallelism() = 2
override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {}
@ -552,6 +554,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
override def cancelTasks(stageId: Int, interruptThread: Boolean) {
throw new UnsupportedOperationException
}
override def killTaskAttempt(
taskId: Long, interruptThread: Boolean, reason: String): Boolean = {
throw new UnsupportedOperationException
}
override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {}
override def defaultParallelism(): Int = 2
override def executorHeartbeatReceived(


@ -79,6 +79,8 @@ private class DummyTaskScheduler extends TaskScheduler {
override def stop(): Unit = {}
override def submitTasks(taskSet: TaskSet): Unit = {}
override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = {}
override def killTaskAttempt(
taskId: Long, interruptThread: Boolean, reason: String): Boolean = false
override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {}
override def defaultParallelism(): Int = 2
override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {}


@ -176,13 +176,13 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter {
assert(!outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter))
// The non-authorized committer fails
outputCommitCoordinator.taskCompleted(
stage, partition, attemptNumber = nonAuthorizedCommitter, reason = TaskKilled)
stage, partition, attemptNumber = nonAuthorizedCommitter, reason = TaskKilled("test"))
// New tasks should still not be able to commit because the authorized committer has not failed
assert(
!outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 1))
// The authorized committer now fails, clearing the lock
outputCommitCoordinator.taskCompleted(
stage, partition, attemptNumber = authorizedCommitter, reason = TaskKilled)
stage, partition, attemptNumber = authorizedCommitter, reason = TaskKilled("test"))
// A new task should now be allowed to become the authorized committer
assert(
outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 2))


@ -410,7 +410,8 @@ private[spark] abstract class MockBackend(
}
}
override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = {
override def killTask(
taskId: Long, executorId: String, interruptThread: Boolean, reason: String): Unit = {
// We have to implement this b/c of SPARK-15385.
// It's OK for this to be a no-op, because even if a backend does implement killTask,
// it really can only be "best-effort" in any case, and the scheduler should be robust to that.


@ -677,7 +677,11 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2"))
sched.initialize(new FakeSchedulerBackend() {
override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = {}
override def killTask(
taskId: Long,
executorId: String,
interruptThread: Boolean,
reason: String): Unit = {}
})
// Keep track of the number of tasks that are resubmitted,
@ -935,7 +939,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
// Complete the speculative attempt for the running task
manager.handleSuccessfulTask(4, createTaskResult(3, accumUpdatesByTask(3)))
// Verify that it kills other running attempt
verify(sched.backend).killTask(3, "exec2", true)
verify(sched.backend).killTask(3, "exec2", true, "another attempt succeeded")
// Because the SchedulerBackend was a mock, the 2nd copy of the task won't actually be
// killed, so the FakeTaskScheduler is only told about the successful completion
// of the speculated task.
@ -1023,14 +1027,14 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
manager.handleSuccessfulTask(speculativeTask.taskId, createTaskResult(3, accumUpdatesByTask(3)))
// Verify that it kills other running attempt
val origTask = originalTasks(speculativeTask.index)
verify(sched.backend).killTask(origTask.taskId, "exec2", true)
verify(sched.backend).killTask(origTask.taskId, "exec2", true, "another attempt succeeded")
// Because the SchedulerBackend was a mock, the 2nd copy of the task won't actually be
// killed, so the FakeTaskScheduler is only told about the successful completion
// of the speculated task.
assert(sched.endedTasks(3) === Success)
// also because the scheduler is a mock, our manager isn't notified about the task killed event,
// so we do that manually
manager.handleFailedTask(origTask.taskId, TaskState.KILLED, TaskKilled)
manager.handleFailedTask(origTask.taskId, TaskState.KILLED, TaskKilled("test"))
// this task has "failed" 4 times, but one of them doesn't count, so keep running the stage
assert(manager.tasksSuccessful === 4)
assert(!manager.isZombie)
@ -1047,7 +1051,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
createTaskResult(3, accumUpdatesByTask(3)))
// Verify that it kills other running attempt
val origTask2 = originalTasks(speculativeTask2.index)
verify(sched.backend).killTask(origTask2.taskId, "exec2", true)
verify(sched.backend).killTask(origTask2.taskId, "exec2", true, "another attempt succeeded")
assert(manager.tasksSuccessful === 5)
assert(manager.isZombie)
}
@ -1102,8 +1106,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
ExecutorLostFailure(taskDescs(1).executorId, exitCausedByApp = false, reason = None))
tsmSpy.handleFailedTask(taskDescs(2).taskId, TaskState.FAILED,
TaskCommitDenied(0, 2, 0))
tsmSpy.handleFailedTask(taskDescs(3).taskId, TaskState.KILLED,
TaskKilled)
tsmSpy.handleFailedTask(taskDescs(3).taskId, TaskState.KILLED, TaskKilled("test"))
// Make sure that the blacklist ignored all of the task failures above, since they aren't
// the fault of the executor where the task was running.


@ -110,7 +110,7 @@ class UIUtilsSuite extends SparkFunSuite {
}
test("SPARK-11906: Progress bar should not overflow because of speculative tasks") {
val generated = makeProgressBar(2, 3, 0, 0, 0, 4).head.child.filter(_.label == "div")
val generated = makeProgressBar(2, 3, 0, 0, Map.empty, 4).head.child.filter(_.label == "div")
val expected = Seq(
<div class="bar bar-completed" style="width: 75.0%"></div>,
<div class="bar bar-running" style="width: 25.0%"></div>


@ -274,8 +274,9 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with
// Make sure killed tasks are accounted for correctly.
listener.onTaskEnd(
SparkListenerTaskEnd(task.stageId, 0, taskType, TaskKilled, taskInfo, metrics))
assert(listener.stageIdToData((task.stageId, 0)).numKilledTasks === 1)
SparkListenerTaskEnd(
task.stageId, 0, taskType, TaskKilled("test"), taskInfo, metrics))
assert(listener.stageIdToData((task.stageId, 0)).reasonToNumKilled === Map("test" -> 1))
// Make sure we count success as success.
listener.onTaskEnd(


@ -164,7 +164,7 @@ class JsonProtocolSuite extends SparkFunSuite {
testTaskEndReason(fetchMetadataFailed)
testTaskEndReason(exceptionFailure)
testTaskEndReason(TaskResultLost)
testTaskEndReason(TaskKilled)
testTaskEndReason(TaskKilled("test"))
testTaskEndReason(TaskCommitDenied(2, 3, 4))
testTaskEndReason(ExecutorLostFailure("100", true, Some("Induced failure")))
testTaskEndReason(UnknownReason)
@ -676,7 +676,8 @@ private[spark] object JsonProtocolSuite extends Assertions {
assert(r1.fullStackTrace === r2.fullStackTrace)
assertSeqEquals[AccumulableInfo](r1.accumUpdates, r2.accumUpdates, (a, b) => a.equals(b))
case (TaskResultLost, TaskResultLost) =>
case (TaskKilled, TaskKilled) =>
case (r1: TaskKilled, r2: TaskKilled) =>
assert(r1.reason == r2.reason)
case (TaskCommitDenied(jobId1, partitionId1, attemptNumber1),
TaskCommitDenied(jobId2, partitionId2, attemptNumber2)) =>
assert(jobId1 === jobId2)


@ -66,6 +66,19 @@ object MimaExcludes {
// [SPARK-17161] Removing Python-friendly constructors not needed
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.OneVsRestModel.this"),
// [SPARK-19820] Allow reason to be specified to task kill
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.TaskKilled$"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.productElement"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.productArity"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.canEqual"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.productIterator"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.countTowardsTaskFailures"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.productPrefix"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.toErrorString"),
ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.TaskKilled.toString"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.killTaskIfInterrupted"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.getKillReason"),
// [SPARK-19876] Add one time trigger, and improve Trigger APIs
ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.sql.streaming.Trigger"),
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.streaming.ProcessingTime")


@ -104,7 +104,8 @@ private[spark] class MesosExecutorBackend
logError("Received KillTask but executor was null")
} else {
// TODO: Determine the 'interruptOnCancel' property set for the given job.
executor.killTask(t.getValue.toLong, interruptThread = false)
executor.killTask(
t.getValue.toLong, interruptThread = false, reason = "killed by mesos")
}
}


@ -428,7 +428,8 @@ private[spark] class MesosFineGrainedSchedulerBackend(
recordSlaveLost(d, slaveId, ExecutorExited(status, exitCausedByApp = true))
}
override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = {
override def killTask(
taskId: Long, executorId: String, interruptThread: Boolean, reason: String): Unit = {
schedulerDriver.killTask(
TaskID.newBuilder()
.setValue(taskId.toString).build()


@ -101,9 +101,7 @@ class FileScanRDD(
// Kill the task in case it has been marked as killed. This logic is from
// InterruptibleIterator, but we inline it here instead of wrapping the iterator in order
// to avoid performance overhead.
if (context.isInterrupted()) {
throw new TaskKilledException
}
context.killTaskIfInterrupted()
(currentIterator != null && currentIterator.hasNext) || nextIterator()
}
def next(): Object = {


@ -97,7 +97,7 @@ private[ui] abstract class BatchTableBase(tableId: String, batchInterval: Long)
completed = batch.numCompletedOutputOp,
failed = batch.numFailedOutputOp,
skipped = 0,
killed = 0,
reasonToNumKilled = Map.empty,
total = batch.outputOperations.size)
}
</td>


@ -146,7 +146,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") {
completed = sparkJob.numCompletedTasks,
failed = sparkJob.numFailedTasks,
skipped = sparkJob.numSkippedTasks,
killed = sparkJob.numKilledTasks,
reasonToNumKilled = sparkJob.reasonToNumKilled,
total = sparkJob.numTasks - sparkJob.numSkippedTasks)
}
</td>