Addressed Matei's code review comments

Kay Ousterhout 2013-09-30 10:11:59 -07:00
parent c75eb14fe5
commit 58b764b7c6
6 changed files with 45 additions and 29 deletions

View file

@@ -26,10 +26,7 @@ import java.nio.ByteBuffer
import org.apache.spark.util.Utils
// Task result. Also contains updates to accumulator variables.
// TODO: Use of distributed cache to return result is a hack to get around
// what seems to be a bug with messages over 60KB in libprocess; fix it
private[spark]
sealed abstract class TaskResult[T]
private[spark] sealed trait TaskResult[T]
/** A reference to a DirectTaskResult that has been stored in the worker's BlockManager. */
private[spark]
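The hunk above is where TaskResult becomes a one-line private[spark] sealed trait. For context, here is a minimal sketch of the hierarchy that DirectTaskResult and IndirectTaskResult (imported elsewhere in this commit) form under it; the field lists are simplified assumptions, not the exact Spark definitions.

package org.apache.spark.scheduler

// Simplified sketch only: the real DirectTaskResult also carries accumulator
// updates and task metrics, and the blockId type may differ.
private[spark] sealed trait TaskResult[T]

// A reference to a DirectTaskResult that has been stored in the worker's BlockManager.
private[spark] case class IndirectTaskResult[T](blockId: String) extends TaskResult[T]

// A result small enough to be sent back to the driver in a single message.
private[spark] class DirectTaskResult[T](val value: T) extends TaskResult[T]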

View file

@@ -100,7 +100,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
System.getProperty("spark.scheduler.mode", "FIFO"))
// This is a var so that we can reset it for testing purposes.
private[spark] var taskResultResolver = new TaskResultResolver(sc.env, this)
private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this)
override def setListener(listener: TaskSchedulerListener) {
this.listener = listener
@@ -267,10 +267,10 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
activeTaskSets.get(taskSetId).foreach { taskSet =>
if (state == TaskState.FINISHED) {
taskSet.removeRunningTask(tid)
taskResultResolver.enqueueSuccessfulTask(taskSet, tid, serializedData)
taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
} else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
taskSet.removeRunningTask(tid)
taskResultResolver.enqueueFailedTask(taskSet, tid, state, serializedData)
taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
}
}
case None =>
@@ -338,8 +338,8 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
if (jarServer != null) {
jarServer.stop()
}
if (taskResultResolver != null) {
taskResultResolver.stop()
if (taskResultGetter != null) {
taskResultGetter.stop()
}
// sleeping for an arbitrary 5 seconds : to ensure that messages are sent out.
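Both branches of statusUpdate now hand the serialized data to taskResultGetter, and the comment above the field ("This is a var so that we can reset it for testing purposes") is the seam the new test suite relies on. A hypothetical helper, not part of the commit, showing that seam under the assumption that the caller already holds the running ClusterScheduler:

package org.apache.spark.scheduler.cluster

// Hypothetical helper (illustration only): taskResultGetter is a private[spark] var,
// so code compiled into a Spark package can swap in a custom getter, as
// TaskResultGetterSuite does with ResultDeletingTaskResultGetter further down.
object ResultGetterSwapSketch {
  def installResultGetter(scheduler: ClusterScheduler, getter: TaskResultGetter) {
    scheduler.taskResultGetter = getter
  }
}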

View file

@@ -25,7 +25,6 @@ import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import scala.math.max
import scala.math.min
import scala.Some
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
@@ -458,8 +457,6 @@ private[spark] class ClusterTaskSetManager(
removeRunningTask(tid)
val index = info.index
info.markFailed()
// Count failed attempts only on FAILED and LOST state (not on KILLED)
var countFailedTaskAttempt = (state == TaskState.FAILED || state == TaskState.LOST)
if (!successful(index)) {
logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
copiesRunning(index) -= 1
@@ -505,7 +502,6 @@ private[spark] class ClusterTaskSetManager(
case TaskResultLost =>
logInfo("Lost result for TID %s on host %s".format(tid, info.host))
countFailedTaskAttempt = true
sched.listener.taskEnded(tasks(index), TaskResultLost, null, null, info, null)
case _ => {}
@@ -513,7 +509,7 @@ private[spark] class ClusterTaskSetManager(
}
// On non-fetch failures, re-enqueue the task as pending for a max number of retries
addPendingTask(index)
if (countFailedTaskAttempt) {
if (state != TaskState.KILLED) {
numFailures(index) += 1
if (numFailures(index) > MAX_TASK_FAILURES) {
logError("Task %s:%d failed more than %d times; aborting job".format(
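The net effect of the ClusterTaskSetManager hunks: the countFailedTaskAttempt flag disappears, and every non-KILLED failure, including a FINISHED task whose result turns out to be lost, now counts toward the retry limit. A self-contained sketch of that accounting rule follows; the MAX_TASK_FAILURES value is illustrative, and the real code also re-enqueues the task and aborts the TaskSet when the limit is exceeded.

import scala.collection.mutable

import org.apache.spark.TaskState
import org.apache.spark.TaskState.TaskState

// Sketch of the retry rule after this commit (values and structure simplified).
object RetryAccountingSketch {
  val MAX_TASK_FAILURES = 4
  val numFailures = mutable.Map[Int, Int]().withDefaultValue(0)

  // Returns true when the task at `index` has exhausted its retries.
  def recordFailure(index: Int, state: TaskState): Boolean = {
    // KILLED attempts do not count; everything else, including a lost result, does.
    if (state != TaskState.KILLED) {
      numFailures(index) += 1
    }
    numFailures(index) > MAX_TASK_FAILURES
  }
}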

View file

@@ -26,17 +26,16 @@ import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskRes
import org.apache.spark.serializer.SerializerInstance
/**
* Runs a thread pool that deserializes and remotely fetches (if neceessary) task results.
* Runs a thread pool that deserializes and remotely fetches (if necessary) task results.
*/
private[spark] class TaskResultResolver(sparkEnv: SparkEnv, scheduler: ClusterScheduler)
private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterScheduler)
extends Logging {
private val MIN_THREADS = 20
private val MAX_THREADS = 60
private val KEEP_ALIVE_SECONDS = 60
private val MIN_THREADS = System.getProperty("spark.resultGetter.minThreads", "4").toInt
private val MAX_THREADS = System.getProperty("spark.resultGetter.maxThreads", "4").toInt
private val getTaskResultExecutor = new ThreadPoolExecutor(
MIN_THREADS,
MAX_THREADS,
KEEP_ALIVE_SECONDS,
0L,
TimeUnit.SECONDS,
new LinkedBlockingDeque[Runnable],
new ResultResolverThreadFactory)
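In TaskResultGetter, the fixed 20/60-thread pool with a 60-second keep-alive gives way to pool sizes read from spark.resultGetter.minThreads and spark.resultGetter.maxThreads (both defaulting to 4) and a zero keep-alive. A hedged sketch of the resulting construction; the property names are the ones added here, while the thread-factory detail is simplified.

import java.util.concurrent.{LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit}

// Sketch only: the real code installs a named daemon ThreadFactory instead of the default.
object ResultGetterPoolSketch {
  val minThreads = System.getProperty("spark.resultGetter.minThreads", "4").toInt
  val maxThreads = System.getProperty("spark.resultGetter.maxThreads", "4").toInt

  val getTaskResultExecutor = new ThreadPoolExecutor(
    minThreads,
    maxThreads,
    0L,                 // no keep-alive for threads above the core size
    TimeUnit.SECONDS,
    new LinkedBlockingDeque[Runnable])
}

// A driver expecting many simultaneously returning tasks could raise the limits before
// the scheduler starts, e.g. System.setProperty("spark.resultGetter.maxThreads", "32").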

View file

@@ -253,6 +253,23 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
assert(manager.resourceOffer("exec2", "host2", 1, ANY) === None)
}
test("task result lost") {
sc = new SparkContext("local", "test")
val sched = new FakeClusterScheduler(sc, ("exec1", "host1"))
val taskSet = createTaskSet(1)
val clock = new FakeClock
val manager = new ClusterTaskSetManager(sched, taskSet, clock)
assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0)
// Tell it the task has finished but the result was lost.
manager.handleFailedTask(0, TaskState.FINISHED, Some(TaskResultLost))
assert(sched.endedTasks(0) === TaskResultLost)
// Re-offer the host -- now we should get task 0 again.
assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0)
}
/**
* Utility method to create a TaskSet, potentially setting a particular sequence of preferred
* locations for each task (given as varargs) if this sequence is not empty.
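The new "task result lost" test calls handleFailedTask with Some(TaskResultLost) on a task that reported FINISHED, checks that the listener saw TaskResultLost as the end reason, and then verifies the task is offered again. A simplified sketch of where TaskResultLost sits follows; the real TaskEndReason family has more cases (FetchFailed, ExceptionFailure, and so on), so treat this as an assumption-level outline rather than the actual file.

package org.apache.spark

// Simplified, assumption-level sketch of the end-reason hierarchy used by the test.
private[spark] sealed trait TaskEndReason
private[spark] case object Success extends TaskEndReason
private[spark] case object TaskResultLost extends TaskEndReason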

View file

@@ -15,7 +15,7 @@
* limitations under the License.
*/
package org.apache.spark.scheduler
package org.apache.spark.scheduler.cluster
import java.nio.ByteBuffer
@@ -23,16 +23,16 @@ import org.scalatest.BeforeAndAfter
import org.scalatest.FunSuite
import org.apache.spark.{LocalSparkContext, SparkContext, SparkEnv}
import org.apache.spark.scheduler.cluster.{ClusterScheduler, ClusterTaskSetManager, TaskResultResolver}
import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskResult}
/**
* Removes the TaskResult from the BlockManager before delegating to a normal TaskResultResolver.
* Removes the TaskResult from the BlockManager before delegating to a normal TaskResultGetter.
*
* Used to test the case where a BlockManager evicts the task result (or dies) before the
* TaskResult is retrieved.
*/
class ResultDeletingTaskResultResolver(sparkEnv: SparkEnv, scheduler: ClusterScheduler)
extends TaskResultResolver(sparkEnv, scheduler) {
class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterScheduler)
extends TaskResultGetter(sparkEnv, scheduler) {
var removedResult = false
override def enqueueSuccessfulTask(
@@ -44,7 +44,7 @@ class ResultDeletingTaskResultResolver(sparkEnv: SparkEnv, scheduler: ClusterSch
case IndirectTaskResult(blockId) =>
sparkEnv.blockManager.master.removeBlock(blockId)
case directResult: DirectTaskResult[_] =>
taskSetManager.abort("Expect only indirect results")
taskSetManager.abort("Internal error: expect only indirect results")
}
serializedData.rewind()
removedResult = true
@@ -56,9 +56,11 @@ class ResultDeletingTaskResultResolver(sparkEnv: SparkEnv, scheduler: ClusterSch
/**
* Tests related to handling task results (both direct and indirect).
*/
class TaskResultResolverSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
override def beforeAll() {
super.beforeAll()
before {
// Set the Akka frame size to be as small as possible (it must be an integer, so 1 is as small
// as we can make it) so the tests don't take too long.
System.setProperty("spark.akka.frameSize", "1")
@@ -67,6 +69,11 @@ class TaskResultResolverSuite extends FunSuite with BeforeAndAfter with LocalSpa
sc = new SparkContext("local-cluster[1,1,512]", "test")
}
override def afterAll() {
super.afterAll()
System.clearProperty("spark.akka.frameSize")
}
test("handling results smaller than Akka frame size") {
val result = sc.parallelize(Seq(1), 1).map(x => 2 * x).reduce((x, y) => x)
assert(result === 2)
@@ -93,7 +100,7 @@ class TaskResultResolverSuite extends FunSuite with BeforeAndAfter with LocalSpa
assert(false, "Expect local cluster to use ClusterScheduler")
throw new ClassCastException
}
scheduler.taskResultResolver = new ResultDeletingTaskResultResolver(sc.env, scheduler)
scheduler.taskResultGetter = new ResultDeletingTaskResultGetter(sc.env, scheduler)
val akkaFrameSize =
sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt
val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x)
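This last test makes each task's result larger than the Akka frame (an array of akkaFrameSize integers is larger than the frame once serialized) and swaps in ResultDeletingTaskResultGetter, so the first fetch of the stored block fails. A hedged, illustrative sketch of the convention being exercised, namely that oversized results travel through the BlockManager as an IndirectTaskResult reference; the names below are illustrative, not the executor's actual code.

import java.nio.ByteBuffer

// Illustration only (not the executor's real code): results that fit in one Akka frame
// are sent back directly; larger ones are stored and only a block reference is returned.
object ResultEnvelopeSketch {
  def chooseEnvelope(
      serializedResult: ByteBuffer,
      akkaFrameSize: Int,
      storeInBlockManager: ByteBuffer => String): Either[ByteBuffer, String] = {
    if (serializedResult.limit < akkaFrameSize) {
      Left(serializedResult)                       // DirectTaskResult path
    } else {
      Right(storeInBlockManager(serializedResult)) // IndirectTaskResult(blockId) path
    }
  }
}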