Fixed a failure recovery bug and added some tests for fault recovery.
This commit is contained in:
parent
eb05154b7a
commit
fd5581a0d3
|
@ -238,7 +238,7 @@ private trait DAGScheduler extends Scheduler with Logging {
|
|||
case smt: ShuffleMapTask =>
|
||||
val stage = idToStage(smt.stageId)
|
||||
stage.addOutputLoc(smt.partition, evt.result.asInstanceOf[String])
|
||||
if (pendingTasks(stage).isEmpty) {
|
||||
if (running.contains(stage) && pendingTasks(stage).isEmpty) {
|
||||
logInfo(stage + " finished; looking for newly runnable stages")
|
||||
running -= stage
|
||||
if (stage.shuffleDep != None) {
|
||||
|
|
|
@ -1,14 +1,16 @@
|
|||
package spark
|
||||
|
||||
import java.util.concurrent._
|
||||
import java.util.concurrent.Executors
|
||||
import java.util.concurrent.atomic.AtomicInteger
|
||||
|
||||
/**
|
||||
* A simple Scheduler implementation that runs tasks locally in a thread pool.
|
||||
* Optionally the scheduler also allows each task to fail up to maxFailures times,
|
||||
* which is useful for testing fault recovery.
|
||||
*/
|
||||
private class LocalScheduler(threads: Int) extends DAGScheduler with Logging {
|
||||
var attemptId = 0
|
||||
var threadPool: ExecutorService =
|
||||
Executors.newFixedThreadPool(threads, DaemonThreadFactory)
|
||||
private class LocalScheduler(threads: Int, maxFailures: Int) extends DAGScheduler with Logging {
|
||||
var attemptId = new AtomicInteger(0)
|
||||
var threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory)
|
||||
|
||||
val env = SparkEnv.get
|
||||
|
||||
|
@ -17,37 +19,52 @@ private class LocalScheduler(threads: Int) extends DAGScheduler with Logging {
|
|||
override def waitForRegister() {}
|
||||
|
||||
override def submitTasks(tasks: Seq[Task[_]]) {
|
||||
tasks.zipWithIndex.foreach { case (task, i) =>
|
||||
val myAttemptId = attemptId
|
||||
attemptId = attemptId + 1
|
||||
val failCount = new Array[Int](tasks.size)
|
||||
|
||||
def submitTask(task: Task[_], idInJob: Int) {
|
||||
val myAttemptId = attemptId.getAndIncrement()
|
||||
threadPool.submit(new Runnable {
|
||||
def run() {
|
||||
logInfo("Running task " + i)
|
||||
// Set the Spark execution environment for the worker thread
|
||||
SparkEnv.set(env)
|
||||
try {
|
||||
// Serialize and deserialize the task so that accumulators are
|
||||
// changed to thread-local ones; this adds a bit of unnecessary
|
||||
// overhead but matches how the Mesos Executor works
|
||||
Accumulators.clear
|
||||
val bytes = Utils.serialize(tasks(i))
|
||||
logInfo("Size of task " + i + " is " + bytes.size + " bytes")
|
||||
val deserializedTask = Utils.deserialize[Task[_]](
|
||||
bytes, Thread.currentThread.getContextClassLoader)
|
||||
val result: Any = deserializedTask.run(myAttemptId)
|
||||
val accumUpdates = Accumulators.values
|
||||
logInfo("Finished task " + i)
|
||||
taskEnded(tasks(i), Success, result, accumUpdates)
|
||||
} catch {
|
||||
case t: Throwable => {
|
||||
// TODO: Do something nicer here
|
||||
logError("Exception in task " + i, t)
|
||||
runTask(task, idInJob, myAttemptId)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
def runTask(task: Task[_], idInJob: Int, attemptId: Int) {
|
||||
logInfo("Running task " + idInJob)
|
||||
// Set the Spark execution environment for the worker thread
|
||||
SparkEnv.set(env)
|
||||
try {
|
||||
// Serialize and deserialize the task so that accumulators are
|
||||
// changed to thread-local ones; this adds a bit of unnecessary
|
||||
// overhead but matches how the Mesos Executor works
|
||||
Accumulators.clear
|
||||
val bytes = Utils.serialize(task)
|
||||
logInfo("Size of task " + idInJob + " is " + bytes.size + " bytes")
|
||||
val deserializedTask = Utils.deserialize[Task[_]](
|
||||
bytes, Thread.currentThread.getContextClassLoader)
|
||||
val result: Any = deserializedTask.run(attemptId)
|
||||
val accumUpdates = Accumulators.values
|
||||
logInfo("Finished task " + idInJob)
|
||||
taskEnded(task, Success, result, accumUpdates)
|
||||
} catch {
|
||||
case t: Throwable => {
|
||||
logError("Exception in task " + idInJob, t)
|
||||
failCount.synchronized {
|
||||
failCount(idInJob) += 1
|
||||
if (failCount(idInJob) <= maxFailures) {
|
||||
submitTask(task, idInJob)
|
||||
} else {
|
||||
// TODO: Do something nicer here to return all the way to the user
|
||||
System.exit(1)
|
||||
null
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for ((task, i) <- tasks.zipWithIndex) {
|
||||
submitTask(task, i)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -58,11 +58,15 @@ extends Logging {
|
|||
private var scheduler: Scheduler = {
|
||||
// Regular expression used for local[N] master format
|
||||
val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
|
||||
// Regular expression for local[N,maxFailures] (no space after the comma), used in tests with failing tasks
|
||||
val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+),([0-9]+)\]""".r
|
||||
master match {
|
||||
case "local" =>
|
||||
new LocalScheduler(1)
|
||||
new LocalScheduler(1, 0)
|
||||
case LOCAL_N_REGEX(threads) =>
|
||||
new LocalScheduler(threads.toInt)
|
||||
new LocalScheduler(threads.toInt, 0)
|
||||
case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
|
||||
new LocalScheduler(threads.toInt, maxFailures.toInt)
|
||||
case _ =>
|
||||
System.loadLibrary("mesos")
|
||||
new MesosScheduler(this, master, frameworkName)
|
||||
|
|
Loading…
Reference in a new issue