Addressed Matei's code review comments

Kay Ousterhout 2013-09-30 10:11:59 -07:00
parent c75eb14fe5
commit 58b764b7c6
6 changed files with 45 additions and 29 deletions

View file

@@ -26,10 +26,7 @@ import java.nio.ByteBuffer
 import org.apache.spark.util.Utils

 // Task result. Also contains updates to accumulator variables.
-// TODO: Use of distributed cache to return result is a hack to get around
-// what seems to be a bug with messages over 60KB in libprocess; fix it
-private[spark]
-sealed abstract class TaskResult[T]
+private[spark] sealed trait TaskResult[T]

 /** A reference to a DirectTaskResult that has been stored in the worker's BlockManager. */
 private[spark]
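For orientation: the sealed trait introduced here has two concrete results that the rest of this commit refers to, a DirectTaskResult sent back inline and an IndirectTaskResult that only names a block stored in the worker's BlockManager. A rough sketch, with constructor parameters simplified (the real DirectTaskResult also carries the accumulator updates mentioned in the comment above):

// Sketch only -- simplified from the actual definitions in this file.
private[spark] sealed trait TaskResult[T]

/** A reference to a DirectTaskResult that has been stored in the worker's BlockManager. */
private[spark] case class IndirectTaskResult[T](blockId: String) extends TaskResult[T]

/** A result small enough to be sent back to the driver directly in the status update. */
private[spark] class DirectTaskResult[T](val value: T) extends TaskResult[T]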

View file

@@ -100,7 +100,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
     System.getProperty("spark.scheduler.mode", "FIFO"))

   // This is a var so that we can reset it for testing purposes.
-  private[spark] var taskResultResolver = new TaskResultResolver(sc.env, this)
+  private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this)

   override def setListener(listener: TaskSchedulerListener) {
     this.listener = listener
@@ -267,10 +267,10 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
          activeTaskSets.get(taskSetId).foreach { taskSet =>
            if (state == TaskState.FINISHED) {
              taskSet.removeRunningTask(tid)
-              taskResultResolver.enqueueSuccessfulTask(taskSet, tid, serializedData)
+              taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
            } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
              taskSet.removeRunningTask(tid)
-              taskResultResolver.enqueueFailedTask(taskSet, tid, state, serializedData)
+              taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
            }
          }
        case None =>
@@ -338,8 +338,8 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
     if (jarServer != null) {
       jarServer.stop()
     }
-    if (taskResultResolver != null) {
-      taskResultResolver.stop()
+    if (taskResultGetter != null) {
+      taskResultGetter.stop()
     }

     // sleeping for an arbitrary 5 seconds : to ensure that messages are sent out.

View file

@@ -25,7 +25,6 @@ import scala.collection.mutable.HashMap
 import scala.collection.mutable.HashSet
 import scala.math.max
 import scala.math.min
-import scala.Some

 import org.apache.spark._
 import org.apache.spark.TaskState.TaskState
@@ -458,8 +457,6 @@ private[spark] class ClusterTaskSetManager(
     removeRunningTask(tid)
     val index = info.index
     info.markFailed()
-    // Count failed attempts only on FAILED and LOST state (not on KILLED)
-    var countFailedTaskAttempt = (state == TaskState.FAILED || state == TaskState.LOST)
     if (!successful(index)) {
       logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
       copiesRunning(index) -= 1
@@ -505,7 +502,6 @@ private[spark] class ClusterTaskSetManager(
        case TaskResultLost =>
          logInfo("Lost result for TID %s on host %s".format(tid, info.host))
-          countFailedTaskAttempt = true
          sched.listener.taskEnded(tasks(index), TaskResultLost, null, null, info, null)

        case _ => {}
@@ -513,7 +509,7 @@ private[spark] class ClusterTaskSetManager(
    }
    // On non-fetch failures, re-enqueue the task as pending for a max number of retries
    addPendingTask(index)
-    if (countFailedTaskAttempt) {
+    if (state != TaskState.KILLED) {
      numFailures(index) += 1
      if (numFailures(index) > MAX_TASK_FAILURES) {
        logError("Task %s:%d failed more than %d times; aborting job".format(

View file

@@ -26,17 +26,16 @@ import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskRes
 import org.apache.spark.serializer.SerializerInstance

 /**
- * Runs a thread pool that deserializes and remotely fetches (if neceessary) task results.
+ * Runs a thread pool that deserializes and remotely fetches (if necessary) task results.
  */
-private[spark] class TaskResultResolver(sparkEnv: SparkEnv, scheduler: ClusterScheduler)
+private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterScheduler)
   extends Logging {
-  private val MIN_THREADS = 20
-  private val MAX_THREADS = 60
-  private val KEEP_ALIVE_SECONDS = 60
+  private val MIN_THREADS = System.getProperty("spark.resultGetter.minThreads", "4").toInt
+  private val MAX_THREADS = System.getProperty("spark.resultGetter.maxThreads", "4").toInt
   private val getTaskResultExecutor = new ThreadPoolExecutor(
     MIN_THREADS,
     MAX_THREADS,
-    KEEP_ALIVE_SECONDS,
+    0L,
     TimeUnit.SECONDS,
     new LinkedBlockingDeque[Runnable],
     new ResultResolverThreadFactory)
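With both bounds read from system properties and defaulting to 4, the executor behaves as a fixed-size pool unless the properties are overridden before the ClusterScheduler (and therefore the TaskResultGetter) is constructed; the zero keep-alive only matters when MAX_THREADS is raised above MIN_THREADS, in which case surplus idle threads are reclaimed immediately. A hedged usage sketch, reusing the property names from the diff above:

// Must be set before the SparkContext is created, since the values are read
// once when the TaskResultGetter is constructed.
System.setProperty("spark.resultGetter.minThreads", "8")
System.setProperty("spark.resultGetter.maxThreads", "8")

// Read exactly as in the code above: falls back to 4 when the property is unset.
val minThreads = System.getProperty("spark.resultGetter.minThreads", "4").toInt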

View file

@@ -253,6 +253,23 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
     assert(manager.resourceOffer("exec2", "host2", 1, ANY) === None)
   }

+  test("task result lost") {
+    sc = new SparkContext("local", "test")
+    val sched = new FakeClusterScheduler(sc, ("exec1", "host1"))
+    val taskSet = createTaskSet(1)
+    val clock = new FakeClock
+    val manager = new ClusterTaskSetManager(sched, taskSet, clock)
+
+    assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0)
+
+    // Tell it the task has finished but the result was lost.
+    manager.handleFailedTask(0, TaskState.FINISHED, Some(TaskResultLost))
+    assert(sched.endedTasks(0) === TaskResultLost)
+
+    // Re-offer the host -- now we should get task 0 again.
+    assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0)
+  }
+
   /**
    * Utility method to create a TaskSet, potentially setting a particular sequence of preferred
    * locations for each task (given as varargs) if this sequence is not empty.

View file

@@ -15,7 +15,7 @@
  * limitations under the License.
  */

-package org.apache.spark.scheduler
+package org.apache.spark.scheduler.cluster

 import java.nio.ByteBuffer
@@ -23,16 +23,16 @@ import org.scalatest.BeforeAndAfter
 import org.scalatest.FunSuite

 import org.apache.spark.{LocalSparkContext, SparkContext, SparkEnv}
-import org.apache.spark.scheduler.cluster.{ClusterScheduler, ClusterTaskSetManager, TaskResultResolver}
+import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskResult}

 /**
- * Removes the TaskResult from the BlockManager before delegating to a normal TaskResultResolver.
+ * Removes the TaskResult from the BlockManager before delegating to a normal TaskResultGetter.
  *
  * Used to test the case where a BlockManager evicts the task result (or dies) before the
  * TaskResult is retrieved.
  */
-class ResultDeletingTaskResultResolver(sparkEnv: SparkEnv, scheduler: ClusterScheduler)
-  extends TaskResultResolver(sparkEnv, scheduler) {
+class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterScheduler)
+  extends TaskResultGetter(sparkEnv, scheduler) {
   var removedResult = false

   override def enqueueSuccessfulTask(
@@ -44,7 +44,7 @@ class ResultDeletingTaskResultResolver(sparkEnv: SparkEnv, scheduler: ClusterSch
        case IndirectTaskResult(blockId) =>
          sparkEnv.blockManager.master.removeBlock(blockId)
        case directResult: DirectTaskResult[_] =>
-          taskSetManager.abort("Expect only indirect results")
+          taskSetManager.abort("Internal error: expect only indirect results")
      }
      serializedData.rewind()
      removedResult = true
@@ -56,9 +56,11 @@ class ResultDeletingTaskResultResolver(sparkEnv: SparkEnv, scheduler: ClusterSch
 /**
  * Tests related to handling task results (both direct and indirect).
  */
-class TaskResultResolverSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
+class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {

-  before {
+  override def beforeAll() {
+    super.beforeAll()
+
     // Set the Akka frame size to be as small as possible (it must be an integer, so 1 is as small
     // as we can make it) so the tests don't take too long.
     System.setProperty("spark.akka.frameSize", "1")
@@ -67,6 +69,11 @@ class TaskResultResolverSuite extends FunSuite with BeforeAndAfter with LocalSpa
     sc = new SparkContext("local-cluster[1,1,512]", "test")
   }

+  override def afterAll() {
+    super.afterAll()
+    System.clearProperty("spark.akka.frameSize")
+  }
+
   test("handling results smaller than Akka frame size") {
     val result = sc.parallelize(Seq(1), 1).map(x => 2 * x).reduce((x, y) => x)
     assert(result === 2)
@@ -93,7 +100,7 @@ class TaskResultResolverSuite extends FunSuite with BeforeAndAfter with LocalSpa
      assert(false, "Expect local cluster to use ClusterScheduler")
      throw new ClassCastException
    }
-    scheduler.taskResultResolver = new ResultDeletingTaskResultResolver(sc.env, scheduler)
+    scheduler.taskResultGetter = new ResultDeletingTaskResultGetter(sc.env, scheduler)
    val akkaFrameSize =
      sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt
    val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x)