Fixed a failure recovery bug and added some tests for fault recovery.

Matei Zaharia 2012-01-13 19:17:27 -08:00
parent eb05154b7a
commit fd5581a0d3
3 changed files with 54 additions and 33 deletions

core/src/main/scala/spark/DAGScheduler.scala

@@ -238,7 +238,7 @@ private trait DAGScheduler extends Scheduler with Logging {
       case smt: ShuffleMapTask =>
         val stage = idToStage(smt.stageId)
         stage.addOutputLoc(smt.partition, evt.result.asInstanceOf[String])
-        if (pendingTasks(stage).isEmpty) {
+        if (running.contains(stage) && pendingTasks(stage).isEmpty) {
           logInfo(stage + " finished; looking for newly runnable stages")
           running -= stage
           if (stage.shuffleDep != None) {
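
The one-line fix above makes the completion handler idempotent: a ShuffleMapTask result can arrive for a stage that has already been marked finished (for example a duplicate or late completion after the stage was resubmitted), and without the running.contains(stage) guard such an event would pass the isEmpty test and re-run the "stage finished" bookkeeping. A minimal standalone sketch of the pattern, with hypothetical names rather than code from this commit:

import scala.collection.mutable

object CompletionGuard {
  val running = mutable.Set(1)                         // stages still executing
  val pendingTasks = mutable.Map(1 -> mutable.Set(0))  // outstanding partitions

  def onTaskFinished(stageId: Int, partition: Int) {
    pendingTasks(stageId) -= partition
    // Without the running.contains guard, a duplicate completion event
    // would find the pending set empty and "finish" the stage a second time.
    if (running.contains(stageId) && pendingTasks(stageId).isEmpty) {
      running -= stageId
      println("stage " + stageId + " finished")
    }
  }
}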

core/src/main/scala/spark/LocalScheduler.scala

@@ -1,14 +1,16 @@
 package spark
 
-import java.util.concurrent._
+import java.util.concurrent.Executors
+import java.util.concurrent.atomic.AtomicInteger
 
 /**
  * A simple Scheduler implementation that runs tasks locally in a thread pool.
+ * Optionally the scheduler also allows each task to fail up to maxFailures times,
+ * which is useful for testing fault recovery.
  */
-private class LocalScheduler(threads: Int) extends DAGScheduler with Logging {
-  var attemptId = 0
-  var threadPool: ExecutorService =
-    Executors.newFixedThreadPool(threads, DaemonThreadFactory)
+private class LocalScheduler(threads: Int, maxFailures: Int) extends DAGScheduler with Logging {
+  var attemptId = new AtomicInteger(0)
+  var threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory)
 
   val env = SparkEnv.get
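
Note the switch from var attemptId = 0 to an AtomicInteger: once failed tasks can be resubmitted from pool worker threads, attempt IDs are no longer handed out from a single thread, and a plain var increment is a read-modify-write that can lose updates. A small illustration in plain Scala (not from this commit):

import java.util.concurrent.atomic.AtomicInteger

val attemptId = new AtomicInteger(0)
val first = attemptId.getAndIncrement()   // returns 0; counter is now 1
val second = attemptId.getAndIncrement()  // returns 1; counter is now 2
// "var id = 0; id = id + 1" performs the same bump non-atomically and can
// hand out duplicate IDs under concurrent resubmission.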
@@ -17,37 +19,52 @@ private class LocalScheduler(threads: Int) extends DAGScheduler with Logging {
   override def waitForRegister() {}
 
   override def submitTasks(tasks: Seq[Task[_]]) {
-    tasks.zipWithIndex.foreach { case (task, i) =>
-      val myAttemptId = attemptId
-      attemptId = attemptId + 1
+    val failCount = new Array[Int](tasks.size)
+
+    def submitTask(task: Task[_], idInJob: Int) {
+      val myAttemptId = attemptId.getAndIncrement()
       threadPool.submit(new Runnable {
         def run() {
-          logInfo("Running task " + i)
-          // Set the Spark execution environment for the worker thread
-          SparkEnv.set(env)
-          try {
-            // Serialize and deserialize the task so that accumulators are
-            // changed to thread-local ones; this adds a bit of unnecessary
-            // overhead but matches how the Mesos Executor works
-            Accumulators.clear
-            val bytes = Utils.serialize(tasks(i))
-            logInfo("Size of task " + i + " is " + bytes.size + " bytes")
-            val deserializedTask = Utils.deserialize[Task[_]](
-              bytes, Thread.currentThread.getContextClassLoader)
-            val result: Any = deserializedTask.run(myAttemptId)
-            val accumUpdates = Accumulators.values
-            logInfo("Finished task " + i)
-            taskEnded(tasks(i), Success, result, accumUpdates)
-          } catch {
-            case t: Throwable => {
-              // TODO: Do something nicer here
-              logError("Exception in task " + i, t)
+          runTask(task, idInJob, myAttemptId)
+        }
+      })
+    }
+
+    def runTask(task: Task[_], idInJob: Int, attemptId: Int) {
+      logInfo("Running task " + idInJob)
+      // Set the Spark execution environment for the worker thread
+      SparkEnv.set(env)
+      try {
+        // Serialize and deserialize the task so that accumulators are
+        // changed to thread-local ones; this adds a bit of unnecessary
+        // overhead but matches how the Mesos Executor works
+        Accumulators.clear
+        val bytes = Utils.serialize(task)
+        logInfo("Size of task " + idInJob + " is " + bytes.size + " bytes")
+        val deserializedTask = Utils.deserialize[Task[_]](
+          bytes, Thread.currentThread.getContextClassLoader)
+        val result: Any = deserializedTask.run(attemptId)
+        val accumUpdates = Accumulators.values
+        logInfo("Finished task " + idInJob)
+        taskEnded(task, Success, result, accumUpdates)
+      } catch {
+        case t: Throwable => {
+          logError("Exception in task " + idInJob, t)
+          failCount.synchronized {
+            failCount(idInJob) += 1
+            if (failCount(idInJob) <= maxFailures) {
+              submitTask(task, idInJob)
+            } else {
+              // TODO: Do something nicer here to return all the way to the user
               System.exit(1)
               null
             }
           }
         }
-      })
+      }
     }
+
+    for ((task, i) <- tasks.zipWithIndex) {
+      submitTask(task, i)
+    }
   }
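
The restructured submitTasks splits submission (submitTask) from execution (runTask): on failure, runTask bumps a per-task counter under failCount.synchronized and resubmits the same task until it has failed more than maxFailures times, and only then gives up with System.exit(1). The recursion is safe because submitTask only enqueues a Runnable on the pool rather than running the task on the current stack. A stripped-down sketch of the same resubmit-on-failure loop, using a hypothetical helper rather than the scheduler's actual API:

import java.util.concurrent.Executors

object RetrySketch {
  val pool = Executors.newFixedThreadPool(2)
  val maxFailures = 2
  val failCount = new Array[Int](1)  // one slot per task in the job

  def submit(idInJob: Int)(work: => Unit) {
    pool.submit(new Runnable {
      def run() {
        try {
          work
        } catch {
          case t: Throwable =>
            failCount.synchronized {
              failCount(idInJob) += 1
              if (failCount(idInJob) <= maxFailures) {
                submit(idInJob)(work)  // re-enqueue; does not grow the stack
              } else {
                println("task " + idInJob + " failed permanently: " + t)
              }
            }
        }
      }
    })
  }
}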

core/src/main/scala/spark/SparkContext.scala

@@ -58,11 +58,15 @@ extends Logging {
   private var scheduler: Scheduler = {
     // Regular expression used for local[N] master format
     val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
+    // Regular expression for local[N, maxRetries], used in tests with failing tasks
+    val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+),([0-9]+)\]""".r
     master match {
       case "local" =>
-        new LocalScheduler(1)
+        new LocalScheduler(1, 0)
       case LOCAL_N_REGEX(threads) =>
-        new LocalScheduler(threads.toInt)
+        new LocalScheduler(threads.toInt, 0)
+      case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
+        new LocalScheduler(threads.toInt, maxFailures.toInt)
       case _ =>
         System.loadLibrary("mesos")
         new MesosScheduler(this, master, frameworkName)
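
With LOCAL_N_FAILURES_REGEX in place, a test can ask for a local scheduler that retries failed tasks. The regex accepts no whitespace after the comma, so the master string must be written "local[2,3]" (2 threads, each task allowed up to 3 failures), not "local[2, 3]". A usage sketch, assuming the two-argument SparkContext constructor implied by the master and frameworkName fields above, with a hypothetical app name:

// Each task may fail up to 3 times before the whole job is aborted,
// which lets a test exercise the fault-recovery path.
val sc = new SparkContext("local[2,3]", "FaultRecoveryTest")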