[SPARK-8103][core] DAGScheduler should not submit multiple concurrent attempts for a stage

https://issues.apache.org/jira/browse/SPARK-8103

cc kayousterhout (thanks for the extra test case)

Author: Imran Rashid <irashid@cloudera.com>
Author: Kay Ousterhout <kayousterhout@gmail.com>
Author: Imran Rashid <squito@users.noreply.github.com>

Closes #6750 from squito/SPARK-8103 and squashes the following commits:

fb3acfc [Imran Rashid] fix log msg
e01b7aa [Imran Rashid] fix some comments, style
584acd4 [Imran Rashid] simplify going from taskId to taskSetMgr
e43ac25 [Imran Rashid] Merge branch 'master' into SPARK-8103
6bc23af [Imran Rashid] update log msg
4470fa1 [Imran Rashid] rename
c04707e [Imran Rashid] style
88b61cc [Imran Rashid] add tests to make sure that TaskSchedulerImpl schedules correctly with zombie attempts
d7f1ef2 [Imran Rashid] get rid of activeTaskSets
a21c8b5 [Imran Rashid] Merge branch 'master' into SPARK-8103
906d626 [Imran Rashid] fix merge
109900e [Imran Rashid] Merge branch 'master' into SPARK-8103
c0d4d90 [Imran Rashid] Revert "Index active task sets by stage Id rather than by task set id"
f025154 [Imran Rashid] Merge pull request #2 from kayousterhout/imran_SPARK-8103
baf46e1 [Kay Ousterhout] Index active task sets by stage Id rather than by task set id
19685bb [Imran Rashid] switch to using latestInfo.attemptId, and add comments
a5f7c8c [Imran Rashid] remove comment for reviewers
227b40d [Imran Rashid] style
517b6e5 [Imran Rashid] get rid of SparkIllegalStateException
b2faef5 [Imran Rashid] faster check for conflicting task sets
6542b42 [Imran Rashid] remove extra stageAttemptId
ada7726 [Imran Rashid] reviewer feedback
d8eb202 [Imran Rashid] Merge branch 'master' into SPARK-8103
46bc26a [Imran Rashid] more cleanup of debug garbage
cb245da [Imran Rashid] finally found the issue ... clean up debug stuff
8c29707 [Imran Rashid] Merge branch 'master' into SPARK-8103
89a59b6 [Imran Rashid] more printlns ...
9601b47 [Imran Rashid] more debug printlns
ecb4e7d [Imran Rashid] debugging printlns
b6bc248 [Imran Rashid] style
55f4a94 [Imran Rashid] get rid of more random test case since kays tests are clearer
7021d28 [Imran Rashid] update test since listenerBus.waitUntilEmpty now throws an exception instead of returning a boolean
883fe49 [Kay Ousterhout] Unit tests for concurrent stages issue
6e14683 [Imran Rashid] unit test just to make sure we fail fast on concurrent attempts
06a0af6 [Imran Rashid] ignore for jenkins
c443def [Imran Rashid] better fix and simpler test case
28d70aa [Imran Rashid] wip on getting a better test case ...
a9bf31f [Imran Rashid] wip
Imran Rashid 2015-07-20 10:28:32 -07:00 committed by Kay Ousterhout
parent c6fe9b4a17
commit 80e2568b25
13 changed files with 383 additions and 86 deletions
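
In short: every Task now records the attempt ID of the stage that created it, the DAGScheduler ignores FetchFailed events that come from stale stage attempts, and TaskSchedulerImpl indexes TaskSetManagers by (stageId, stageAttemptId) and rejects a second non-zombie task set for the same stage. A minimal sketch of the core idea, using simplified stand-in types rather than the real Spark classes:

// Simplified stand-ins for illustration only; not the actual Spark classes.
case class StageInfo(attemptId: Int)
class Stage(val id: Int) { var latestInfo: StageInfo = StageInfo(attemptId = 0) }
case class Task(stageId: Int, stageAttemptId: Int, partitionId: Int)

// Submission side: each task is stamped with the attempt that created it.
def makeTasks(stage: Stage, partitions: Seq[Int]): Seq[Task] =
  partitions.map(p => Task(stage.id, stage.latestInfo.attemptId, p))

// Failure side: only a FetchFailed from the newest attempt of the stage may
// trigger marking the stage as failed and scheduling a resubmission.
def shouldHandleFetchFailure(failedStage: Stage, task: Task): Boolean =
  failedStage.latestInfo.attemptId == task.stageAttemptId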


@@ -857,7 +857,6 @@ class DAGScheduler(
// Get our pending tasks and remember them in our pendingTasks entry
stage.pendingTasks.clear()
// First figure out the indexes of partition ids to compute.
val partitionsToCompute: Seq[Int] = {
stage match {
@@ -918,7 +917,7 @@ class DAGScheduler(
partitionsToCompute.map { id =>
val locs = getPreferredLocs(stage.rdd, id)
val part = stage.rdd.partitions(id)
new ShuffleMapTask(stage.id, taskBinary, part, locs)
new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs)
}
case stage: ResultStage =>
@@ -927,7 +926,7 @@ class DAGScheduler(
val p: Int = job.partitions(id)
val part = stage.rdd.partitions(p)
val locs = getPreferredLocs(stage.rdd, p)
new ResultTask(stage.id, taskBinary, part, locs, id)
new ResultTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs, id)
}
}
} catch {
@@ -1069,10 +1068,11 @@ class DAGScheduler(
val execId = status.location.executorId
logDebug("ShuffleMapTask finished on " + execId)
if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId)
logInfo(s"Ignoring possibly bogus $smt completion from executor $execId")
} else {
shuffleStage.addOutputLoc(smt.partitionId, status)
}
if (runningStages.contains(shuffleStage) && shuffleStage.pendingTasks.isEmpty) {
markStageAsFinished(shuffleStage)
logInfo("looking for newly runnable stages")
@@ -1132,38 +1132,48 @@ class DAGScheduler(
val failedStage = stageIdToStage(task.stageId)
val mapStage = shuffleToMapStage(shuffleId)
// It is likely that we receive multiple FetchFailed for a single stage (because we have
// multiple tasks running concurrently on different executors). In that case, it is possible
// the fetch failure has already been handled by the scheduler.
if (runningStages.contains(failedStage)) {
logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
s"due to a fetch failure from $mapStage (${mapStage.name})")
markStageAsFinished(failedStage, Some(failureMessage))
}
if (failedStage.latestInfo.attemptId != task.stageAttemptId) {
logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" +
s" ${task.stageAttemptId} and there is a more recent attempt for that stage " +
s"(attempt ID ${failedStage.latestInfo.attemptId}) running")
} else {
if (disallowStageRetryForTest) {
abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
} else if (failedStages.isEmpty) {
// Don't schedule an event to resubmit failed stages if failed isn't empty, because
// in that case the event will already have been scheduled.
// TODO: Cancel running tasks in the stage
logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
s"$failedStage (${failedStage.name}) due to fetch failure")
messageScheduler.schedule(new Runnable {
override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
}, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
}
failedStages += failedStage
failedStages += mapStage
// Mark the map whose fetch failed as broken in the map stage
if (mapId != -1) {
mapStage.removeOutputLoc(mapId, bmAddress)
mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
}
// It is likely that we receive multiple FetchFailed for a single stage (because we have
// multiple tasks running concurrently on different executors). In that case, it is
// possible the fetch failure has already been handled by the scheduler.
if (runningStages.contains(failedStage)) {
logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
s"due to a fetch failure from $mapStage (${mapStage.name})")
markStageAsFinished(failedStage, Some(failureMessage))
} else {
logDebug(s"Received fetch failure from $task, but its from $failedStage which is no " +
s"longer running")
}
// TODO: mark the executor as failed only if there were lots of fetch failures on it
if (bmAddress != null) {
handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
if (disallowStageRetryForTest) {
abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
} else if (failedStages.isEmpty) {
// Don't schedule an event to resubmit failed stages if failed isn't empty, because
// in that case the event will already have been scheduled.
// TODO: Cancel running tasks in the stage
logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
s"$failedStage (${failedStage.name}) due to fetch failure")
messageScheduler.schedule(new Runnable {
override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
}, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
}
failedStages += failedStage
failedStages += mapStage
// Mark the map whose fetch failed as broken in the map stage
if (mapId != -1) {
mapStage.removeOutputLoc(mapId, bmAddress)
mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
}
// TODO: mark the executor as failed only if there were lots of fetch failures on it
if (bmAddress != null) {
handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
}
}
case commitDenied: TaskCommitDenied =>


@@ -41,11 +41,12 @@ import org.apache.spark.rdd.RDD
*/
private[spark] class ResultTask[T, U](
stageId: Int,
stageAttemptId: Int,
taskBinary: Broadcast[Array[Byte]],
partition: Partition,
@transient locs: Seq[TaskLocation],
val outputId: Int)
extends Task[U](stageId, partition.index) with Serializable {
extends Task[U](stageId, stageAttemptId, partition.index) with Serializable {
@transient private[this] val preferredLocs: Seq[TaskLocation] = {
if (locs == null) Nil else locs.toSet.toSeq


@@ -40,14 +40,15 @@ import org.apache.spark.shuffle.ShuffleWriter
*/
private[spark] class ShuffleMapTask(
stageId: Int,
stageAttemptId: Int,
taskBinary: Broadcast[Array[Byte]],
partition: Partition,
@transient private var locs: Seq[TaskLocation])
extends Task[MapStatus](stageId, partition.index) with Logging {
extends Task[MapStatus](stageId, stageAttemptId, partition.index) with Logging {
/** A constructor used only in test suites. This does not require passing in an RDD. */
def this(partitionId: Int) {
this(0, null, new Partition { override def index: Int = 0 }, null)
this(0, 0, null, new Partition { override def index: Int = 0 }, null)
}
@transient private val preferredLocs: Seq[TaskLocation] = {


@@ -43,7 +43,10 @@ import org.apache.spark.util.Utils
* @param stageId id of the stage this task belongs to
* @param partitionId index of the number in the RDD
*/
private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
private[spark] abstract class Task[T](
val stageId: Int,
val stageAttemptId: Int,
var partitionId: Int) extends Serializable {
/**
* The key of the Map is the accumulator id and the value of the Map is the latest accumulator


@@ -75,9 +75,9 @@ private[spark] class TaskSchedulerImpl(
// TaskSetManagers are not thread safe, so any access to one should be synchronized
// on this class.
val activeTaskSets = new HashMap[String, TaskSetManager]
private val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, TaskSetManager]]
val taskIdToTaskSetId = new HashMap[Long, String]
private[scheduler] val taskIdToTaskSetManager = new HashMap[Long, TaskSetManager]
val taskIdToExecutorId = new HashMap[Long, String]
@volatile private var hasReceivedTask = false
@@ -162,7 +162,17 @@ private[spark] class TaskSchedulerImpl(
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized {
val manager = createTaskSetManager(taskSet, maxTaskFailures)
activeTaskSets(taskSet.id) = manager
val stage = taskSet.stageId
val stageTaskSets =
taskSetsByStageIdAndAttempt.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
stageTaskSets(taskSet.stageAttemptId) = manager
val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
ts.taskSet != taskSet && !ts.isZombie
}
if (conflictingTaskSet) {
throw new IllegalStateException(s"more than one active taskSet for stage $stage:" +
s" ${stageTaskSets.toSeq.map{_._2.taskSet.id}.mkString(",")}")
}
schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
if (!isLocal && !hasReceivedTask) {
@@ -192,19 +202,21 @@ private[spark] class TaskSchedulerImpl(
override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
logInfo("Cancelling stage " + stageId)
activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
// There are two possible cases here:
// 1. The task set manager has been created and some tasks have been scheduled.
// In this case, send a kill signal to the executors to kill the task and then abort
// the stage.
// 2. The task set manager has been created but no tasks have been scheduled. In this case,
// simply abort the stage.
tsm.runningTasksSet.foreach { tid =>
val execId = taskIdToExecutorId(tid)
backend.killTask(tid, execId, interruptThread)
taskSetsByStageIdAndAttempt.get(stageId).foreach { attempts =>
attempts.foreach { case (_, tsm) =>
// There are two possible cases here:
// 1. The task set manager has been created and some tasks have been scheduled.
// In this case, send a kill signal to the executors to kill the task and then abort
// the stage.
// 2. The task set manager has been created but no tasks have been scheduled. In this case,
// simply abort the stage.
tsm.runningTasksSet.foreach { tid =>
val execId = taskIdToExecutorId(tid)
backend.killTask(tid, execId, interruptThread)
}
tsm.abort("Stage %s cancelled".format(stageId))
logInfo("Stage %d was cancelled".format(stageId))
}
tsm.abort("Stage %s cancelled".format(stageId))
logInfo("Stage %d was cancelled".format(stageId))
}
}
@@ -214,7 +226,12 @@ private[spark] class TaskSchedulerImpl(
* cleaned up.
*/
def taskSetFinished(manager: TaskSetManager): Unit = synchronized {
activeTaskSets -= manager.taskSet.id
taskSetsByStageIdAndAttempt.get(manager.taskSet.stageId).foreach { taskSetsForStage =>
taskSetsForStage -= manager.taskSet.stageAttemptId
if (taskSetsForStage.isEmpty) {
taskSetsByStageIdAndAttempt -= manager.taskSet.stageId
}
}
manager.parent.removeSchedulable(manager)
logInfo("Removed TaskSet %s, whose tasks have all completed, from pool %s"
.format(manager.taskSet.id, manager.parent.name))
@@ -235,7 +252,7 @@ private[spark] class TaskSchedulerImpl(
for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
tasks(i) += task
val tid = task.taskId
taskIdToTaskSetId(tid) = taskSet.taskSet.id
taskIdToTaskSetManager(tid) = taskSet
taskIdToExecutorId(tid) = execId
executorsByHost(host) += execId
availableCpus(i) -= CPUS_PER_TASK
@@ -319,26 +336,24 @@ private[spark] class TaskSchedulerImpl(
failedExecutor = Some(execId)
}
}
taskIdToTaskSetId.get(tid) match {
case Some(taskSetId) =>
taskIdToTaskSetManager.get(tid) match {
case Some(taskSet) =>
if (TaskState.isFinished(state)) {
taskIdToTaskSetId.remove(tid)
taskIdToTaskSetManager.remove(tid)
taskIdToExecutorId.remove(tid)
}
activeTaskSets.get(taskSetId).foreach { taskSet =>
if (state == TaskState.FINISHED) {
taskSet.removeRunningTask(tid)
taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
} else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
taskSet.removeRunningTask(tid)
taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
}
if (state == TaskState.FINISHED) {
taskSet.removeRunningTask(tid)
taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
} else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
taskSet.removeRunningTask(tid)
taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
}
case None =>
logError(
("Ignoring update with state %s for TID %s because its task set is gone (this is " +
"likely the result of receiving duplicate task finished status updates)")
.format(state, tid))
"likely the result of receiving duplicate task finished status updates)")
.format(state, tid))
}
} catch {
case e: Exception => logError("Exception in statusUpdate", e)
@@ -363,9 +378,9 @@ private[spark] class TaskSchedulerImpl(
val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized {
taskMetrics.flatMap { case (id, metrics) =>
taskIdToTaskSetId.get(id)
.flatMap(activeTaskSets.get)
.map(taskSetMgr => (id, taskSetMgr.stageId, taskSetMgr.taskSet.attempt, metrics))
taskIdToTaskSetManager.get(id).map { taskSetMgr =>
(id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics)
}
}
}
dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId)
@@ -397,9 +412,12 @@ private[spark] class TaskSchedulerImpl(
def error(message: String) {
synchronized {
if (activeTaskSets.nonEmpty) {
if (taskSetsByStageIdAndAttempt.nonEmpty) {
// Have each task set throw a SparkException with the error
for ((taskSetId, manager) <- activeTaskSets) {
for {
attempts <- taskSetsByStageIdAndAttempt.values
manager <- attempts.values
} {
try {
manager.abort(message)
} catch {
@@ -520,6 +538,17 @@ private[spark] class TaskSchedulerImpl(
override def applicationAttemptId(): Option[String] = backend.applicationAttemptId()
private[scheduler] def taskSetManagerForAttempt(
stageId: Int,
stageAttemptId: Int): Option[TaskSetManager] = {
for {
attempts <- taskSetsByStageIdAndAttempt.get(stageId)
manager <- attempts.get(stageAttemptId)
} yield {
manager
}
}
}
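
The new private[scheduler] helper above is how a (stageId, stageAttemptId) pair is resolved back to its TaskSetManager. A hypothetical usage sketch, where the scheduler value and the IDs are assumed for illustration and mirror what the TaskSchedulerImplSuite tests below do:

// Assumed: scheduler is a TaskSchedulerImpl instance; stage 3 / attempt 0 are illustrative.
// Mark attempt 0 of stage 3 as zombie, if that attempt is still registered.
scheduler.taskSetManagerForAttempt(stageId = 3, stageAttemptId = 0).foreach { tsm =>
  tsm.isZombie = true
}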


@@ -26,10 +26,10 @@ import java.util.Properties
private[spark] class TaskSet(
val tasks: Array[Task[_]],
val stageId: Int,
val attempt: Int,
val stageAttemptId: Int,
val priority: Int,
val properties: Properties) {
val id: String = stageId + "." + attempt
val id: String = stageId + "." + stageAttemptId
override def toString: String = "TaskSet " + id
}


@@ -191,15 +191,14 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
for (task <- tasks.flatten) {
val serializedTask = ser.serialize(task)
if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
val taskSetId = scheduler.taskIdToTaskSetId(task.taskId)
scheduler.activeTaskSets.get(taskSetId).foreach { taskSet =>
scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr =>
try {
var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
"spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " +
"spark.akka.frameSize or using broadcast variables for large values."
msg = msg.format(task.taskId, task.index, serializedTask.limit, akkaFrameSize,
AkkaUtils.reservedSizeBytes)
taskSet.abort(msg)
taskSetMgr.abort(msg)
} catch {
case e: Exception => logError("Exception in error callback", e)
}


@@ -101,9 +101,15 @@ class DAGSchedulerSuite
/** Length of time to wait while draining listener events. */
val WAIT_TIMEOUT_MILLIS = 10000
val sparkListener = new SparkListener() {
val submittedStageInfos = new HashSet[StageInfo]
val successfulStages = new HashSet[Int]
val failedStages = new ArrayBuffer[Int]
val stageByOrderOfExecution = new ArrayBuffer[Int]
override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
submittedStageInfos += stageSubmitted.stageInfo
}
override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) {
val stageInfo = stageCompleted.stageInfo
stageByOrderOfExecution += stageInfo.stageId
@@ -150,6 +156,7 @@ class DAGSchedulerSuite
// Enable local execution for this test
val conf = new SparkConf().set("spark.localExecution.enabled", "true")
sc = new SparkContext("local", "DAGSchedulerSuite", conf)
sparkListener.submittedStageInfos.clear()
sparkListener.successfulStages.clear()
sparkListener.failedStages.clear()
failure = null
@@ -547,6 +554,140 @@ class DAGSchedulerSuite
assert(sparkListener.failedStages.size == 1)
}
/**
* This tests the case where another FetchFailed comes in while the map stage is getting
* re-run.
*/
test("late fetch failures don't cause multiple concurrent attempts for the same map stage") {
val shuffleMapRdd = new MyRDD(sc, 2, Nil)
val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
val shuffleId = shuffleDep.shuffleId
val reduceRdd = new MyRDD(sc, 2, List(shuffleDep))
submit(reduceRdd, Array(0, 1))
val mapStageId = 0
def countSubmittedMapStageAttempts(): Int = {
sparkListener.submittedStageInfos.count(_.stageId == mapStageId)
}
// The map stage should have been submitted.
assert(countSubmittedMapStageAttempts() === 1)
complete(taskSets(0), Seq(
(Success, makeMapStatus("hostA", 2)),
(Success, makeMapStatus("hostB", 2))))
// The MapOutputTracker should know about both map output locations.
assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) ===
Array("hostA", "hostB"))
assert(mapOutputTracker.getServerStatuses(shuffleId, 1).map(_._1.host) ===
Array("hostA", "hostB"))
// The first result task fails, with a fetch failure for the output from the first mapper.
runEvent(CompletionEvent(
taskSets(1).tasks(0),
FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
null,
Map[Long, Any](),
createFakeTaskInfo(),
null))
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(sparkListener.failedStages.contains(1))
// Trigger resubmission of the failed map stage.
runEvent(ResubmitFailedStages)
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
// Another attempt for the map stage should have been submitted, resulting in 2 total attempts.
assert(countSubmittedMapStageAttempts() === 2)
// The second ResultTask fails, with a fetch failure for the output from the second mapper.
runEvent(CompletionEvent(
taskSets(1).tasks(1),
FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"),
null,
Map[Long, Any](),
createFakeTaskInfo(),
null))
// Another ResubmitFailedStages event should not result in another attempt for the map
// stage being run concurrently.
// NOTE: the actual ResubmitFailedStages may get called at any time during this, but it
// shouldn't affect anything -- our calling it just makes *SURE* it gets called between the
// desired event and our check.
runEvent(ResubmitFailedStages)
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(countSubmittedMapStageAttempts() === 2)
}
/**
* This tests the case where a late FetchFailed comes in after the map stage has finished getting
* retried and a new reduce stage starts running.
*/
test("extremely late fetch failures don't cause multiple concurrent attempts for " +
"the same stage") {
val shuffleMapRdd = new MyRDD(sc, 2, Nil)
val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
val shuffleId = shuffleDep.shuffleId
val reduceRdd = new MyRDD(sc, 2, List(shuffleDep))
submit(reduceRdd, Array(0, 1))
def countSubmittedReduceStageAttempts(): Int = {
sparkListener.submittedStageInfos.count(_.stageId == 1)
}
def countSubmittedMapStageAttempts(): Int = {
sparkListener.submittedStageInfos.count(_.stageId == 0)
}
// The map stage should have been submitted.
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(countSubmittedMapStageAttempts() === 1)
// Complete the map stage.
complete(taskSets(0), Seq(
(Success, makeMapStatus("hostA", 2)),
(Success, makeMapStatus("hostB", 2))))
// The reduce stage should have been submitted.
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(countSubmittedReduceStageAttempts() === 1)
// The first result task fails, with a fetch failure for the output from the first mapper.
runEvent(CompletionEvent(
taskSets(1).tasks(0),
FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
null,
Map[Long, Any](),
createFakeTaskInfo(),
null))
// Trigger resubmission of the failed map stage and finish the re-started map task.
runEvent(ResubmitFailedStages)
complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1))))
// Because the map stage finished, another attempt for the reduce stage should have been
// submitted, resulting in 2 total attempts for each of the map and reduce stages.
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(countSubmittedMapStageAttempts() === 2)
assert(countSubmittedReduceStageAttempts() === 2)
// A late FetchFailed arrives from the second task in the original reduce stage.
runEvent(CompletionEvent(
taskSets(1).tasks(1),
FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"),
null,
Map[Long, Any](),
createFakeTaskInfo(),
null))
// Running ResubmitFailedStages shouldn't result in any more attempts for the map stage, because
// the FetchFailed should have been ignored
runEvent(ResubmitFailedStages)
// The FetchFailed from the original reduce stage should be ignored.
assert(countSubmittedMapStageAttempts() === 2)
}
test("ignore late map task completions") {
val shuffleMapRdd = new MyRDD(sc, 2, Nil)
val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)


@@ -19,7 +19,7 @@ package org.apache.spark.scheduler
import org.apache.spark.TaskContext
class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0) {
class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0, 0) {
override def runTask(context: TaskContext): Int = 0
override def preferredLocations: Seq[TaskLocation] = prefLocs
@@ -31,12 +31,16 @@ object FakeTask {
* locations for each task (given as varargs) if this sequence is not empty.
*/
def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
createTaskSet(numTasks, 0, prefLocs: _*)
}
def createTaskSet(numTasks: Int, stageAttemptId: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
if (prefLocs.size != 0 && prefLocs.size != numTasks) {
throw new IllegalArgumentException("Wrong number of task locations")
}
val tasks = Array.tabulate[Task[_]](numTasks) { i =>
new FakeTask(i, if (prefLocs.size != 0) prefLocs(i) else Nil)
}
new TaskSet(tasks, 0, 0, 0, null)
new TaskSet(tasks, 0, stageAttemptId, 0, null)
}
}


@@ -25,7 +25,7 @@ import org.apache.spark.TaskContext
* A Task implementation that fails to serialize.
*/
private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int)
extends Task[Array[Byte]](stageId, 0) {
extends Task[Array[Byte]](stageId, 0, 0) {
override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte]
override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]()


@@ -41,8 +41,8 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
}
val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
val func = (c: TaskContext, i: Iterator[String]) => i.next()
val task = new ResultTask[String, String](
0, sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0)
val task = new ResultTask[String, String](0, 0,
sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0)
intercept[RuntimeException] {
task.run(0, 0)
}


@@ -33,7 +33,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
val taskScheduler = new TaskSchedulerImpl(sc)
taskScheduler.initialize(new FakeSchedulerBackend)
// Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
val dagScheduler = new DAGScheduler(sc, taskScheduler) {
new DAGScheduler(sc, taskScheduler) {
override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
override def executorAdded(execId: String, host: String) {}
}
@@ -67,7 +67,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
val taskScheduler = new TaskSchedulerImpl(sc)
taskScheduler.initialize(new FakeSchedulerBackend)
// Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
val dagScheduler = new DAGScheduler(sc, taskScheduler) {
new DAGScheduler(sc, taskScheduler) {
override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
override def executorAdded(execId: String, host: String) {}
}
@@ -128,4 +128,113 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
assert(taskDescriptions.map(_.executorId) === Seq("executor0"))
}
test("refuse to schedule concurrent attempts for the same stage (SPARK-8103)") {
sc = new SparkContext("local", "TaskSchedulerImplSuite")
val taskScheduler = new TaskSchedulerImpl(sc)
taskScheduler.initialize(new FakeSchedulerBackend)
// Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
val dagScheduler = new DAGScheduler(sc, taskScheduler) {
override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
override def executorAdded(execId: String, host: String) {}
}
taskScheduler.setDAGScheduler(dagScheduler)
val attempt1 = FakeTask.createTaskSet(1, 0)
val attempt2 = FakeTask.createTaskSet(1, 1)
taskScheduler.submitTasks(attempt1)
intercept[IllegalStateException] { taskScheduler.submitTasks(attempt2) }
// OK to submit multiple if previous attempts are all zombie
taskScheduler.taskSetManagerForAttempt(attempt1.stageId, attempt1.stageAttemptId)
.get.isZombie = true
taskScheduler.submitTasks(attempt2)
val attempt3 = FakeTask.createTaskSet(1, 2)
intercept[IllegalStateException] { taskScheduler.submitTasks(attempt3) }
taskScheduler.taskSetManagerForAttempt(attempt2.stageId, attempt2.stageAttemptId)
.get.isZombie = true
taskScheduler.submitTasks(attempt3)
}
test("don't schedule more tasks after a taskset is zombie") {
sc = new SparkContext("local", "TaskSchedulerImplSuite")
val taskScheduler = new TaskSchedulerImpl(sc)
taskScheduler.initialize(new FakeSchedulerBackend)
// Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
new DAGScheduler(sc, taskScheduler) {
override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
override def executorAdded(execId: String, host: String) {}
}
val numFreeCores = 1
val workerOffers = Seq(new WorkerOffer("executor0", "host0", numFreeCores))
val attempt1 = FakeTask.createTaskSet(10)
// submit attempt 1, offer some resources, some tasks get scheduled
taskScheduler.submitTasks(attempt1)
val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten
assert(1 === taskDescriptions.length)
// now mark attempt 1 as a zombie
taskScheduler.taskSetManagerForAttempt(attempt1.stageId, attempt1.stageAttemptId)
.get.isZombie = true
// don't schedule anything on another resource offer
val taskDescriptions2 = taskScheduler.resourceOffers(workerOffers).flatten
assert(0 === taskDescriptions2.length)
// if we schedule another attempt for the same stage, it should get scheduled
val attempt2 = FakeTask.createTaskSet(10, 1)
// submit attempt 2, offer some resources, some tasks get scheduled
taskScheduler.submitTasks(attempt2)
val taskDescriptions3 = taskScheduler.resourceOffers(workerOffers).flatten
assert(1 === taskDescriptions3.length)
val mgr = taskScheduler.taskIdToTaskSetManager.get(taskDescriptions3(0).taskId).get
assert(mgr.taskSet.stageAttemptId === 1)
}
test("if a zombie attempt finishes, continue scheduling tasks for non-zombie attempts") {
sc = new SparkContext("local", "TaskSchedulerImplSuite")
val taskScheduler = new TaskSchedulerImpl(sc)
taskScheduler.initialize(new FakeSchedulerBackend)
// Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
new DAGScheduler(sc, taskScheduler) {
override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
override def executorAdded(execId: String, host: String) {}
}
val numFreeCores = 10
val workerOffers = Seq(new WorkerOffer("executor0", "host0", numFreeCores))
val attempt1 = FakeTask.createTaskSet(10)
// submit attempt 1, offer some resources, some tasks get scheduled
taskScheduler.submitTasks(attempt1)
val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten
assert(10 === taskDescriptions.length)
// now mark attempt 1 as a zombie
val mgr1 = taskScheduler.taskSetManagerForAttempt(attempt1.stageId, attempt1.stageAttemptId).get
mgr1.isZombie = true
// don't schedule anything on another resource offer
val taskDescriptions2 = taskScheduler.resourceOffers(workerOffers).flatten
assert(0 === taskDescriptions2.length)
// submit attempt 2
val attempt2 = FakeTask.createTaskSet(10, 1)
taskScheduler.submitTasks(attempt2)
// attempt 1 finished (this can happen even if it was marked zombie earlier -- all tasks were
// already submitted, and then they finish)
taskScheduler.taskSetFinished(mgr1)
// now with another resource offer, we should still schedule all the tasks in attempt2
val taskDescriptions3 = taskScheduler.resourceOffers(workerOffers).flatten
assert(10 === taskDescriptions3.length)
taskDescriptions3.foreach { task =>
val mgr = taskScheduler.taskIdToTaskSetManager.get(task.taskId).get
assert(mgr.taskSet.stageAttemptId === 1)
}
}
}


@@ -136,7 +136,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex
/**
* A Task implementation that results in a large serialized task.
*/
class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0) {
class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0) {
val randomBuffer = new Array[Byte](TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024)
val random = new Random(0)
random.nextBytes(randomBuffer)