[SPARK-8103][core] DAGScheduler should not submit multiple concurrent attempts for a stage
https://issues.apache.org/jira/browse/SPARK-8103

cc kayousterhout (thanks for the extra test case)

Author: Imran Rashid <irashid@cloudera.com>
Author: Kay Ousterhout <kayousterhout@gmail.com>
Author: Imran Rashid <squito@users.noreply.github.com>

Closes #6750 from squito/SPARK-8103 and squashes the following commits:

fb3acfc [Imran Rashid] fix log msg
e01b7aa [Imran Rashid] fix some comments, style
584acd4 [Imran Rashid] simplify going from taskId to taskSetMgr
e43ac25 [Imran Rashid] Merge branch 'master' into SPARK-8103
6bc23af [Imran Rashid] update log msg
4470fa1 [Imran Rashid] rename
c04707e [Imran Rashid] style
88b61cc [Imran Rashid] add tests to make sure that TaskSchedulerImpl schedules correctly with zombie attempts
d7f1ef2 [Imran Rashid] get rid of activeTaskSets
a21c8b5 [Imran Rashid] Merge branch 'master' into SPARK-8103
906d626 [Imran Rashid] fix merge
109900e [Imran Rashid] Merge branch 'master' into SPARK-8103
c0d4d90 [Imran Rashid] Revert "Index active task sets by stage Id rather than by task set id"
f025154 [Imran Rashid] Merge pull request #2 from kayousterhout/imran_SPARK-8103
baf46e1 [Kay Ousterhout] Index active task sets by stage Id rather than by task set id
19685bb [Imran Rashid] switch to using latestInfo.attemptId, and add comments
a5f7c8c [Imran Rashid] remove comment for reviewers
227b40d [Imran Rashid] style
517b6e5 [Imran Rashid] get rid of SparkIllegalStateException
b2faef5 [Imran Rashid] faster check for conflicting task sets
6542b42 [Imran Rashid] remove extra stageAttemptId
ada7726 [Imran Rashid] reviewer feedback
d8eb202 [Imran Rashid] Merge branch 'master' into SPARK-8103
46bc26a [Imran Rashid] more cleanup of debug garbage
cb245da [Imran Rashid] finally found the issue ... clean up debug stuff
8c29707 [Imran Rashid] Merge branch 'master' into SPARK-8103
89a59b6 [Imran Rashid] more printlns ...
9601b47 [Imran Rashid] more debug printlns
ecb4e7d [Imran Rashid] debugging printlns
b6bc248 [Imran Rashid] style
55f4a94 [Imran Rashid] get rid of more random test case since kays tests are clearer
7021d28 [Imran Rashid] update test since listenerBus.waitUntilEmpty now throws an exception instead of returning a boolean
883fe49 [Kay Ousterhout] Unit tests for concurrent stages issue
6e14683 [Imran Rashid] unit test just to make sure we fail fast on concurrent attempts
06a0af6 [Imran Rashid] ignore for jenkins
c443def [Imran Rashid] better fix and simpler test case
28d70aa [Imran Rashid] wip on getting a better test case ...
a9bf31f [Imran Rashid] wip
This commit is contained in:
parent c6fe9b4a17
commit 80e2568b25
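The heart of the change is that TaskSetManagers are now registered per (stage id, stage attempt id), and a new task set is rejected only when another non-zombie attempt of the same stage is still active. The following is a minimal, self-contained sketch of that guard using simplified stand-in types (ToyTaskSet, ToyManager and byStageAndAttempt are illustrative names, not the real Spark classes changed below):

import scala.collection.mutable.HashMap

// Simplified stand-ins for TaskSet and TaskSetManager; illustration only.
object ConcurrentAttemptGuardSketch {
  case class ToyTaskSet(stageId: Int, stageAttemptId: Int)
  class ToyManager(val taskSet: ToyTaskSet) { var isZombie = false }

  // Registry keyed by stage id, then by stage attempt id, mirroring the
  // taskSetsByStageIdAndAttempt map introduced in TaskSchedulerImpl below.
  private val byStageAndAttempt = new HashMap[Int, HashMap[Int, ToyManager]]

  def submit(taskSet: ToyTaskSet): ToyManager = {
    val stageTaskSets =
      byStageAndAttempt.getOrElseUpdate(taskSet.stageId, new HashMap[Int, ToyManager])
    val manager = new ToyManager(taskSet)
    stageTaskSets(taskSet.stageAttemptId) = manager
    // Refuse the submission if any *other* non-zombie attempt is still registered.
    val conflicting = stageTaskSets.values.exists(m => m.taskSet != taskSet && !m.isZombie)
    if (conflicting) {
      throw new IllegalStateException(
        s"more than one active taskSet for stage ${taskSet.stageId}")
    }
    manager
  }

  def main(args: Array[String]): Unit = {
    val attempt0 = submit(ToyTaskSet(stageId = 0, stageAttemptId = 0))
    attempt0.isZombie = true                              // finished / superseded attempt
    submit(ToyTaskSet(stageId = 0, stageAttemptId = 1))   // accepted: attempt 0 is a zombie
  }
}

Submitting attempt 1 while attempt 0 is still live would throw; once attempt 0 is marked as a zombie, attempt 1 is accepted, which is exactly the behaviour the new TaskSchedulerImplSuite tests at the bottom of this diff exercise.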
@@ -857,7 +857,6 @@ class DAGScheduler(
     // Get our pending tasks and remember them in our pendingTasks entry
     stage.pendingTasks.clear()

     // First figure out the indexes of partition ids to compute.
     val partitionsToCompute: Seq[Int] = {
       stage match {
@@ -918,7 +917,7 @@ class DAGScheduler(
         partitionsToCompute.map { id =>
           val locs = getPreferredLocs(stage.rdd, id)
           val part = stage.rdd.partitions(id)
-          new ShuffleMapTask(stage.id, taskBinary, part, locs)
+          new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs)
         }

       case stage: ResultStage =>
@@ -927,7 +926,7 @@ class DAGScheduler(
           val p: Int = job.partitions(id)
           val part = stage.rdd.partitions(p)
           val locs = getPreferredLocs(stage.rdd, p)
-          new ResultTask(stage.id, taskBinary, part, locs, id)
+          new ResultTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs, id)
         }
       }
     } catch {
@@ -1069,10 +1068,11 @@ class DAGScheduler(
             val execId = status.location.executorId
             logDebug("ShuffleMapTask finished on " + execId)
             if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
-              logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId)
+              logInfo(s"Ignoring possibly bogus $smt completion from executor $execId")
             } else {
               shuffleStage.addOutputLoc(smt.partitionId, status)
             }

             if (runningStages.contains(shuffleStage) && shuffleStage.pendingTasks.isEmpty) {
               markStageAsFinished(shuffleStage)
               logInfo("looking for newly runnable stages")
@@ -1132,38 +1132,48 @@ class DAGScheduler(
         val failedStage = stageIdToStage(task.stageId)
         val mapStage = shuffleToMapStage(shuffleId)

-        // It is likely that we receive multiple FetchFailed for a single stage (because we have
-        // multiple tasks running concurrently on different executors). In that case, it is possible
-        // the fetch failure has already been handled by the scheduler.
-        if (runningStages.contains(failedStage)) {
-          logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
-            s"due to a fetch failure from $mapStage (${mapStage.name})")
-          markStageAsFinished(failedStage, Some(failureMessage))
-        }
+        if (failedStage.latestInfo.attemptId != task.stageAttemptId) {
+          logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" +
+            s" ${task.stageAttemptId} and there is a more recent attempt for that stage " +
+            s"(attempt ID ${failedStage.latestInfo.attemptId}) running")
+        } else {

-        if (disallowStageRetryForTest) {
-          abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
-        } else if (failedStages.isEmpty) {
-          // Don't schedule an event to resubmit failed stages if failed isn't empty, because
-          // in that case the event will already have been scheduled.
-          // TODO: Cancel running tasks in the stage
-          logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
-            s"$failedStage (${failedStage.name}) due to fetch failure")
-          messageScheduler.schedule(new Runnable {
-            override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
-          }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
-        }
-        failedStages += failedStage
-        failedStages += mapStage
-        // Mark the map whose fetch failed as broken in the map stage
-        if (mapId != -1) {
-          mapStage.removeOutputLoc(mapId, bmAddress)
-          mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
-        }
+          // It is likely that we receive multiple FetchFailed for a single stage (because we have
+          // multiple tasks running concurrently on different executors). In that case, it is
+          // possible the fetch failure has already been handled by the scheduler.
+          if (runningStages.contains(failedStage)) {
+            logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
+              s"due to a fetch failure from $mapStage (${mapStage.name})")
+            markStageAsFinished(failedStage, Some(failureMessage))
+          } else {
+            logDebug(s"Received fetch failure from $task, but its from $failedStage which is no " +
+              s"longer running")
+          }

-        // TODO: mark the executor as failed only if there were lots of fetch failures on it
-        if (bmAddress != null) {
-          handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
+          if (disallowStageRetryForTest) {
+            abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
+          } else if (failedStages.isEmpty) {
+            // Don't schedule an event to resubmit failed stages if failed isn't empty, because
+            // in that case the event will already have been scheduled.
+            // TODO: Cancel running tasks in the stage
+            logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
+              s"$failedStage (${failedStage.name}) due to fetch failure")
+            messageScheduler.schedule(new Runnable {
+              override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
+            }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
+          }
+          failedStages += failedStage
+          failedStages += mapStage
+          // Mark the map whose fetch failed as broken in the map stage
+          if (mapId != -1) {
+            mapStage.removeOutputLoc(mapId, bmAddress)
+            mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
+          }

+          // TODO: mark the executor as failed only if there were lots of fetch failures on it
+          if (bmAddress != null) {
+            handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
+          }
         }

       case commitDenied: TaskCommitDenied =>
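The main behavioural change in the hunk above is the early attempt-id comparison: a FetchFailed reported by a task from an older attempt of the stage is logged and ignored instead of triggering another round of failure handling and resubmission. A hedged restatement of that predicate in isolation (the function and parameter names are illustrative, not Spark API):

// Only a fetch failure from the stage's most recent attempt is acted upon;
// reports from superseded attempts are ignored.
def shouldHandleFetchFailure(latestStageAttemptId: Int, taskStageAttemptId: Int): Boolean =
  latestStageAttemptId == taskStageAttemptId

// Example: a straggler from attempt 0 reporting after attempt 1 started is dropped.
assert(!shouldHandleFetchFailure(latestStageAttemptId = 1, taskStageAttemptId = 0))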
@@ -41,11 +41,12 @@ import org.apache.spark.rdd.RDD
  */
 private[spark] class ResultTask[T, U](
     stageId: Int,
+    stageAttemptId: Int,
     taskBinary: Broadcast[Array[Byte]],
     partition: Partition,
     @transient locs: Seq[TaskLocation],
     val outputId: Int)
-  extends Task[U](stageId, partition.index) with Serializable {
+  extends Task[U](stageId, stageAttemptId, partition.index) with Serializable {

   @transient private[this] val preferredLocs: Seq[TaskLocation] = {
     if (locs == null) Nil else locs.toSet.toSeq
@@ -40,14 +40,15 @@ import org.apache.spark.shuffle.ShuffleWriter
  */
 private[spark] class ShuffleMapTask(
     stageId: Int,
+    stageAttemptId: Int,
     taskBinary: Broadcast[Array[Byte]],
     partition: Partition,
     @transient private var locs: Seq[TaskLocation])
-  extends Task[MapStatus](stageId, partition.index) with Logging {
+  extends Task[MapStatus](stageId, stageAttemptId, partition.index) with Logging {

   /** A constructor used only in test suites. This does not require passing in an RDD. */
   def this(partitionId: Int) {
-    this(0, null, new Partition { override def index: Int = 0 }, null)
+    this(0, 0, null, new Partition { override def index: Int = 0 }, null)
   }

   @transient private val preferredLocs: Seq[TaskLocation] = {
@@ -43,7 +43,10 @@ import org.apache.spark.util.Utils
  * @param stageId id of the stage this task belongs to
  * @param partitionId index of the number in the RDD
  */
-private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
+private[spark] abstract class Task[T](
+    val stageId: Int,
+    val stageAttemptId: Int,
+    var partitionId: Int) extends Serializable {

   /**
    * The key of the Map is the accumulator id and the value of the Map is the latest accumulator
@@ -75,9 +75,9 @@ private[spark] class TaskSchedulerImpl(

   // TaskSetManagers are not thread safe, so any access to one should be synchronized
   // on this class.
-  val activeTaskSets = new HashMap[String, TaskSetManager]
+  private val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, TaskSetManager]]

-  val taskIdToTaskSetId = new HashMap[Long, String]
+  private[scheduler] val taskIdToTaskSetManager = new HashMap[Long, TaskSetManager]
   val taskIdToExecutorId = new HashMap[Long, String]

   @volatile private var hasReceivedTask = false
@@ -162,7 +162,17 @@ private[spark] class TaskSchedulerImpl(
     logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
     this.synchronized {
       val manager = createTaskSetManager(taskSet, maxTaskFailures)
-      activeTaskSets(taskSet.id) = manager
+      val stage = taskSet.stageId
+      val stageTaskSets =
+        taskSetsByStageIdAndAttempt.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
+      stageTaskSets(taskSet.stageAttemptId) = manager
+      val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
+        ts.taskSet != taskSet && !ts.isZombie
+      }
+      if (conflictingTaskSet) {
+        throw new IllegalStateException(s"more than one active taskSet for stage $stage:" +
+          s" ${stageTaskSets.toSeq.map{_._2.taskSet.id}.mkString(",")}")
+      }
       schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)

       if (!isLocal && !hasReceivedTask) {
@@ -192,19 +202,21 @@ private[spark] class TaskSchedulerImpl(

   override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
     logInfo("Cancelling stage " + stageId)
-    activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
-      // There are two possible cases here:
-      // 1. The task set manager has been created and some tasks have been scheduled.
-      //    In this case, send a kill signal to the executors to kill the task and then abort
-      //    the stage.
-      // 2. The task set manager has been created but no tasks has been scheduled. In this case,
-      //    simply abort the stage.
-      tsm.runningTasksSet.foreach { tid =>
-        val execId = taskIdToExecutorId(tid)
-        backend.killTask(tid, execId, interruptThread)
+    taskSetsByStageIdAndAttempt.get(stageId).foreach { attempts =>
+      attempts.foreach { case (_, tsm) =>
+        // There are two possible cases here:
+        // 1. The task set manager has been created and some tasks have been scheduled.
+        //    In this case, send a kill signal to the executors to kill the task and then abort
+        //    the stage.
+        // 2. The task set manager has been created but no tasks has been scheduled. In this case,
+        //    simply abort the stage.
+        tsm.runningTasksSet.foreach { tid =>
+          val execId = taskIdToExecutorId(tid)
+          backend.killTask(tid, execId, interruptThread)
+        }
+        tsm.abort("Stage %s cancelled".format(stageId))
+        logInfo("Stage %d was cancelled".format(stageId))
       }
-      tsm.abort("Stage %s cancelled".format(stageId))
-      logInfo("Stage %d was cancelled".format(stageId))
     }
   }
@@ -214,7 +226,12 @@ private[spark] class TaskSchedulerImpl(
    * cleaned up.
    */
   def taskSetFinished(manager: TaskSetManager): Unit = synchronized {
-    activeTaskSets -= manager.taskSet.id
+    taskSetsByStageIdAndAttempt.get(manager.taskSet.stageId).foreach { taskSetsForStage =>
+      taskSetsForStage -= manager.taskSet.stageAttemptId
+      if (taskSetsForStage.isEmpty) {
+        taskSetsByStageIdAndAttempt -= manager.taskSet.stageId
+      }
+    }
     manager.parent.removeSchedulable(manager)
     logInfo("Removed TaskSet %s, whose tasks have all completed, from pool %s"
       .format(manager.taskSet.id, manager.parent.name))
@@ -235,7 +252,7 @@ private[spark] class TaskSchedulerImpl(
       for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
         tasks(i) += task
         val tid = task.taskId
-        taskIdToTaskSetId(tid) = taskSet.taskSet.id
+        taskIdToTaskSetManager(tid) = taskSet
         taskIdToExecutorId(tid) = execId
         executorsByHost(host) += execId
         availableCpus(i) -= CPUS_PER_TASK
@@ -319,26 +336,24 @@ private[spark] class TaskSchedulerImpl(
             failedExecutor = Some(execId)
           }
         }
-        taskIdToTaskSetId.get(tid) match {
-          case Some(taskSetId) =>
+        taskIdToTaskSetManager.get(tid) match {
+          case Some(taskSet) =>
             if (TaskState.isFinished(state)) {
-              taskIdToTaskSetId.remove(tid)
+              taskIdToTaskSetManager.remove(tid)
               taskIdToExecutorId.remove(tid)
             }
-            activeTaskSets.get(taskSetId).foreach { taskSet =>
-              if (state == TaskState.FINISHED) {
-                taskSet.removeRunningTask(tid)
-                taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
-              } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
-                taskSet.removeRunningTask(tid)
-                taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
-              }
-            }
+            if (state == TaskState.FINISHED) {
+              taskSet.removeRunningTask(tid)
+              taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
+            } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
+              taskSet.removeRunningTask(tid)
+              taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
+            }
           case None =>
             logError(
               ("Ignoring update with state %s for TID %s because its task set is gone (this is " +
-                "likely the result of receiving duplicate task finished status updates)")
-                .format(state, tid))
+                "likely the result of receiving duplicate task finished status updates)")
+              .format(state, tid))
         }
       } catch {
         case e: Exception => logError("Exception in statusUpdate", e)
@@ -363,9 +378,9 @@ private[spark] class TaskSchedulerImpl(

     val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized {
       taskMetrics.flatMap { case (id, metrics) =>
-        taskIdToTaskSetId.get(id)
-          .flatMap(activeTaskSets.get)
-          .map(taskSetMgr => (id, taskSetMgr.stageId, taskSetMgr.taskSet.attempt, metrics))
+        taskIdToTaskSetManager.get(id).map { taskSetMgr =>
+          (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics)
+        }
       }
     }
     dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId)
@@ -397,9 +412,12 @@ private[spark] class TaskSchedulerImpl(

   def error(message: String) {
     synchronized {
-      if (activeTaskSets.nonEmpty) {
+      if (taskSetsByStageIdAndAttempt.nonEmpty) {
         // Have each task set throw a SparkException with the error
-        for ((taskSetId, manager) <- activeTaskSets) {
+        for {
+          attempts <- taskSetsByStageIdAndAttempt.values
+          manager <- attempts.values
+        } {
           try {
             manager.abort(message)
           } catch {
@@ -520,6 +538,17 @@ private[spark] class TaskSchedulerImpl(

   override def applicationAttemptId(): Option[String] = backend.applicationAttemptId()

+  private[scheduler] def taskSetManagerForAttempt(
+      stageId: Int,
+      stageAttemptId: Int): Option[TaskSetManager] = {
+    for {
+      attempts <- taskSetsByStageIdAndAttempt.get(stageId)
+      manager <- attempts.get(stageAttemptId)
+    } yield {
+      manager
+    }
+  }
+
 }
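The tests below use this helper to flip an existing attempt into the zombie state before submitting the next attempt. A small usage sketch (it assumes a TaskSchedulerImpl instance named taskScheduler that already has a task set registered for stage 0, attempt 0):

// Look up the manager for (stageId = 0, stageAttemptId = 0), if any, and mark it as a
// zombie so that a later attempt for the same stage no longer counts as conflicting.
taskScheduler.taskSetManagerForAttempt(0, 0).foreach { manager =>
  manager.isZombie = true
}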
@@ -26,10 +26,10 @@ import java.util.Properties
 private[spark] class TaskSet(
     val tasks: Array[Task[_]],
     val stageId: Int,
-    val attempt: Int,
+    val stageAttemptId: Int,
     val priority: Int,
     val properties: Properties) {
-  val id: String = stageId + "." + attempt
+  val id: String = stageId + "." + stageAttemptId

   override def toString: String = "TaskSet " + id
 }
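A small illustration of the renamed parameter (the values are arbitrary, and this only compiles from code inside the org.apache.spark package since TaskSet is private[spark]): the task set id is still the stage id and the attempt id joined by a dot.

// Attempt 1 of stage 3 yields the TaskSet id "3.1".
val taskSet = new TaskSet(
  tasks = Array.empty[Task[_]], stageId = 3, stageAttemptId = 1, priority = 0, properties = null)
assert(taskSet.id == "3.1")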
@@ -191,15 +191,14 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
       for (task <- tasks.flatten) {
         val serializedTask = ser.serialize(task)
         if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
-          val taskSetId = scheduler.taskIdToTaskSetId(task.taskId)
-          scheduler.activeTaskSets.get(taskSetId).foreach { taskSet =>
+          scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr =>
             try {
               var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
                 "spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " +
                 "spark.akka.frameSize or using broadcast variables for large values."
               msg = msg.format(task.taskId, task.index, serializedTask.limit, akkaFrameSize,
                 AkkaUtils.reservedSizeBytes)
-              taskSet.abort(msg)
+              taskSetMgr.abort(msg)
             } catch {
               case e: Exception => logError("Exception in error callback", e)
             }
@@ -101,9 +101,15 @@ class DAGSchedulerSuite
   /** Length of time to wait while draining listener events. */
   val WAIT_TIMEOUT_MILLIS = 10000
   val sparkListener = new SparkListener() {
+    val submittedStageInfos = new HashSet[StageInfo]
     val successfulStages = new HashSet[Int]
     val failedStages = new ArrayBuffer[Int]
     val stageByOrderOfExecution = new ArrayBuffer[Int]
+
+    override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
+      submittedStageInfos += stageSubmitted.stageInfo
+    }
+
     override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) {
       val stageInfo = stageCompleted.stageInfo
       stageByOrderOfExecution += stageInfo.stageId
@@ -150,6 +156,7 @@ class DAGSchedulerSuite
     // Enable local execution for this test
     val conf = new SparkConf().set("spark.localExecution.enabled", "true")
     sc = new SparkContext("local", "DAGSchedulerSuite", conf)
+    sparkListener.submittedStageInfos.clear()
     sparkListener.successfulStages.clear()
     sparkListener.failedStages.clear()
     failure = null
@@ -547,6 +554,140 @@ class DAGSchedulerSuite
     assert(sparkListener.failedStages.size == 1)
   }

+  /**
+   * This tests the case where another FetchFailed comes in while the map stage is getting
+   * re-run.
+   */
+  test("late fetch failures don't cause multiple concurrent attempts for the same map stage") {
+    val shuffleMapRdd = new MyRDD(sc, 2, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+    val shuffleId = shuffleDep.shuffleId
+    val reduceRdd = new MyRDD(sc, 2, List(shuffleDep))
+    submit(reduceRdd, Array(0, 1))
+
+    val mapStageId = 0
+    def countSubmittedMapStageAttempts(): Int = {
+      sparkListener.submittedStageInfos.count(_.stageId == mapStageId)
+    }
+
+    // The map stage should have been submitted.
+    assert(countSubmittedMapStageAttempts() === 1)
+
+    complete(taskSets(0), Seq(
+      (Success, makeMapStatus("hostA", 2)),
+      (Success, makeMapStatus("hostB", 2))))
+    // The MapOutputTracker should know about both map output locations.
+    assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) ===
+      Array("hostA", "hostB"))
+    assert(mapOutputTracker.getServerStatuses(shuffleId, 1).map(_._1.host) ===
+      Array("hostA", "hostB"))
+
+    // The first result task fails, with a fetch failure for the output from the first mapper.
+    runEvent(CompletionEvent(
+      taskSets(1).tasks(0),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
+      null,
+      Map[Long, Any](),
+      createFakeTaskInfo(),
+      null))
+    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+    assert(sparkListener.failedStages.contains(1))
+
+    // Trigger resubmission of the failed map stage.
+    runEvent(ResubmitFailedStages)
+    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+
+    // Another attempt for the map stage should have been submitted, resulting in 2 total attempts.
+    assert(countSubmittedMapStageAttempts() === 2)
+
+    // The second ResultTask fails, with a fetch failure for the output from the second mapper.
+    runEvent(CompletionEvent(
+      taskSets(1).tasks(1),
+      FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"),
+      null,
+      Map[Long, Any](),
+      createFakeTaskInfo(),
+      null))
+
+    // Another ResubmitFailedStages event should not result in another attempt for the map
+    // stage being run concurrently.
+    // NOTE: the actual ResubmitFailedStages may get called at any time during this, but it
+    // shouldn't effect anything -- our calling it just makes *SURE* it gets called between the
+    // desired event and our check.
+    runEvent(ResubmitFailedStages)
+    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+    assert(countSubmittedMapStageAttempts() === 2)
+
+  }
+
+  /**
+   * This tests the case where a late FetchFailed comes in after the map stage has finished getting
+   * retried and a new reduce stage starts running.
+   */
+  test("extremely late fetch failures don't cause multiple concurrent attempts for " +
+    "the same stage") {
+    val shuffleMapRdd = new MyRDD(sc, 2, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+    val shuffleId = shuffleDep.shuffleId
+    val reduceRdd = new MyRDD(sc, 2, List(shuffleDep))
+    submit(reduceRdd, Array(0, 1))
+
+    def countSubmittedReduceStageAttempts(): Int = {
+      sparkListener.submittedStageInfos.count(_.stageId == 1)
+    }
+    def countSubmittedMapStageAttempts(): Int = {
+      sparkListener.submittedStageInfos.count(_.stageId == 0)
+    }
+
+    // The map stage should have been submitted.
+    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+    assert(countSubmittedMapStageAttempts() === 1)
+
+    // Complete the map stage.
+    complete(taskSets(0), Seq(
+      (Success, makeMapStatus("hostA", 2)),
+      (Success, makeMapStatus("hostB", 2))))
+
+    // The reduce stage should have been submitted.
+    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+    assert(countSubmittedReduceStageAttempts() === 1)
+
+    // The first result task fails, with a fetch failure for the output from the first mapper.
+    runEvent(CompletionEvent(
+      taskSets(1).tasks(0),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
+      null,
+      Map[Long, Any](),
+      createFakeTaskInfo(),
+      null))
+
+    // Trigger resubmission of the failed map stage and finish the re-started map task.
+    runEvent(ResubmitFailedStages)
+    complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1))))
+
+    // Because the map stage finished, another attempt for the reduce stage should have been
+    // submitted, resulting in 2 total attempts for each the map and the reduce stage.
+    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+    assert(countSubmittedMapStageAttempts() === 2)
+    assert(countSubmittedReduceStageAttempts() === 2)
+
+    // A late FetchFailed arrives from the second task in the original reduce stage.
+    runEvent(CompletionEvent(
+      taskSets(1).tasks(1),
+      FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"),
+      null,
+      Map[Long, Any](),
+      createFakeTaskInfo(),
+      null))
+
+    // Running ResubmitFailedStages shouldn't result in any more attempts for the map stage, because
+    // the FetchFailed should have been ignored
+    runEvent(ResubmitFailedStages)
+
+    // The FetchFailed from the original reduce stage should be ignored.
+    assert(countSubmittedMapStageAttempts() === 2)
+  }
+
   test("ignore late map task completions") {
     val shuffleMapRdd = new MyRDD(sc, 2, Nil)
     val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
@@ -19,7 +19,7 @@ package org.apache.spark.scheduler

 import org.apache.spark.TaskContext

-class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0) {
+class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0, 0) {
   override def runTask(context: TaskContext): Int = 0

   override def preferredLocations: Seq[TaskLocation] = prefLocs
@@ -31,12 +31,16 @@ object FakeTask {
    * locations for each task (given as varargs) if this sequence is not empty.
    */
   def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
+    createTaskSet(numTasks, 0, prefLocs: _*)
+  }
+
+  def createTaskSet(numTasks: Int, stageAttemptId: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
     if (prefLocs.size != 0 && prefLocs.size != numTasks) {
       throw new IllegalArgumentException("Wrong number of task locations")
     }
     val tasks = Array.tabulate[Task[_]](numTasks) { i =>
       new FakeTask(i, if (prefLocs.size != 0) prefLocs(i) else Nil)
     }
-    new TaskSet(tasks, 0, 0, 0, null)
+    new TaskSet(tasks, 0, stageAttemptId, 0, null)
   }
 }
@@ -25,7 +25,7 @@ import org.apache.spark.TaskContext
  * A Task implementation that fails to serialize.
  */
 private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int)
-  extends Task[Array[Byte]](stageId, 0) {
+  extends Task[Array[Byte]](stageId, 0, 0) {

   override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte]
   override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]()
@@ -41,8 +41,8 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     }
     val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
     val func = (c: TaskContext, i: Iterator[String]) => i.next()
-    val task = new ResultTask[String, String](
-      0, sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0)
+    val task = new ResultTask[String, String](0, 0,
+      sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0)
     intercept[RuntimeException] {
       task.run(0, 0)
     }
@@ -33,7 +33,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
     val taskScheduler = new TaskSchedulerImpl(sc)
     taskScheduler.initialize(new FakeSchedulerBackend)
     // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
-    val dagScheduler = new DAGScheduler(sc, taskScheduler) {
+    new DAGScheduler(sc, taskScheduler) {
       override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
       override def executorAdded(execId: String, host: String) {}
     }
@@ -67,7 +67,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
     val taskScheduler = new TaskSchedulerImpl(sc)
     taskScheduler.initialize(new FakeSchedulerBackend)
     // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
-    val dagScheduler = new DAGScheduler(sc, taskScheduler) {
+    new DAGScheduler(sc, taskScheduler) {
       override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
       override def executorAdded(execId: String, host: String) {}
     }
@@ -128,4 +128,113 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
     assert(taskDescriptions.map(_.executorId) === Seq("executor0"))
   }

+  test("refuse to schedule concurrent attempts for the same stage (SPARK-8103)") {
+    sc = new SparkContext("local", "TaskSchedulerImplSuite")
+    val taskScheduler = new TaskSchedulerImpl(sc)
+    taskScheduler.initialize(new FakeSchedulerBackend)
+    // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
+    val dagScheduler = new DAGScheduler(sc, taskScheduler) {
+      override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
+      override def executorAdded(execId: String, host: String) {}
+    }
+    taskScheduler.setDAGScheduler(dagScheduler)
+    val attempt1 = FakeTask.createTaskSet(1, 0)
+    val attempt2 = FakeTask.createTaskSet(1, 1)
+    taskScheduler.submitTasks(attempt1)
+    intercept[IllegalStateException] { taskScheduler.submitTasks(attempt2) }
+
+    // OK to submit multiple if previous attempts are all zombie
+    taskScheduler.taskSetManagerForAttempt(attempt1.stageId, attempt1.stageAttemptId)
+      .get.isZombie = true
+    taskScheduler.submitTasks(attempt2)
+    val attempt3 = FakeTask.createTaskSet(1, 2)
+    intercept[IllegalStateException] { taskScheduler.submitTasks(attempt3) }
+    taskScheduler.taskSetManagerForAttempt(attempt2.stageId, attempt2.stageAttemptId)
+      .get.isZombie = true
+    taskScheduler.submitTasks(attempt3)
+  }
+
+  test("don't schedule more tasks after a taskset is zombie") {
+    sc = new SparkContext("local", "TaskSchedulerImplSuite")
+    val taskScheduler = new TaskSchedulerImpl(sc)
+    taskScheduler.initialize(new FakeSchedulerBackend)
+    // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
+    new DAGScheduler(sc, taskScheduler) {
+      override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
+      override def executorAdded(execId: String, host: String) {}
+    }
+
+    val numFreeCores = 1
+    val workerOffers = Seq(new WorkerOffer("executor0", "host0", numFreeCores))
+    val attempt1 = FakeTask.createTaskSet(10)
+
+    // submit attempt 1, offer some resources, some tasks get scheduled
+    taskScheduler.submitTasks(attempt1)
+    val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten
+    assert(1 === taskDescriptions.length)
+
+    // now mark attempt 1 as a zombie
+    taskScheduler.taskSetManagerForAttempt(attempt1.stageId, attempt1.stageAttemptId)
+      .get.isZombie = true
+
+    // don't schedule anything on another resource offer
+    val taskDescriptions2 = taskScheduler.resourceOffers(workerOffers).flatten
+    assert(0 === taskDescriptions2.length)
+
+    // if we schedule another attempt for the same stage, it should get scheduled
+    val attempt2 = FakeTask.createTaskSet(10, 1)
+
+    // submit attempt 2, offer some resources, some tasks get scheduled
+    taskScheduler.submitTasks(attempt2)
+    val taskDescriptions3 = taskScheduler.resourceOffers(workerOffers).flatten
+    assert(1 === taskDescriptions3.length)
+    val mgr = taskScheduler.taskIdToTaskSetManager.get(taskDescriptions3(0).taskId).get
+    assert(mgr.taskSet.stageAttemptId === 1)
+  }
+
+  test("if a zombie attempt finishes, continue scheduling tasks for non-zombie attempts") {
+    sc = new SparkContext("local", "TaskSchedulerImplSuite")
+    val taskScheduler = new TaskSchedulerImpl(sc)
+    taskScheduler.initialize(new FakeSchedulerBackend)
+    // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
+    new DAGScheduler(sc, taskScheduler) {
+      override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
+      override def executorAdded(execId: String, host: String) {}
+    }
+
+    val numFreeCores = 10
+    val workerOffers = Seq(new WorkerOffer("executor0", "host0", numFreeCores))
+    val attempt1 = FakeTask.createTaskSet(10)
+
+    // submit attempt 1, offer some resources, some tasks get scheduled
+    taskScheduler.submitTasks(attempt1)
+    val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten
+    assert(10 === taskDescriptions.length)
+
+    // now mark attempt 1 as a zombie
+    val mgr1 = taskScheduler.taskSetManagerForAttempt(attempt1.stageId, attempt1.stageAttemptId).get
+    mgr1.isZombie = true
+
+    // don't schedule anything on another resource offer
+    val taskDescriptions2 = taskScheduler.resourceOffers(workerOffers).flatten
+    assert(0 === taskDescriptions2.length)
+
+    // submit attempt 2
+    val attempt2 = FakeTask.createTaskSet(10, 1)
+    taskScheduler.submitTasks(attempt2)
+
+    // attempt 1 finished (this can happen even if it was marked zombie earlier -- all tasks were
+    // already submitted, and then they finish)
+    taskScheduler.taskSetFinished(mgr1)
+
+    // now with another resource offer, we should still schedule all the tasks in attempt2
+    val taskDescriptions3 = taskScheduler.resourceOffers(workerOffers).flatten
+    assert(10 === taskDescriptions3.length)
+
+    taskDescriptions3.foreach { task =>
+      val mgr = taskScheduler.taskIdToTaskSetManager.get(task.taskId).get
+      assert(mgr.taskSet.stageAttemptId === 1)
+    }
+  }
+
 }
@@ -136,7 +136,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex
 /**
  * A Task implementation that results in a large serialized task.
  */
-class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0) {
+class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0) {
   val randomBuffer = new Array[Byte](TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024)
   val random = new Random(0)
   random.nextBytes(randomBuffer)