Use stubs instead of mocks for DAGSchedulerSuite.

Stephen Haberman 2013-02-09 10:58:47 -06:00
parent 9cfa068379
commit 921be76533
8 changed files with 258 additions and 532 deletions
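
DAGSchedulerSuite previously drove the scheduler through EasyMock mocks, so every interaction had to be scripted up front inside resetExpecting { ... } / whenExecuting { ... } blocks and the mocks reset between steps. This change replaces the mocks with small hand-written stubs: a TaskScheduler that records each submitted TaskSet in a buffer, and a BlockManagerMaster whose getLocations consults a plain cacheLocations map. To make those stubs easy to construct, MapOutputTracker and BlockManagerMaster no longer build their own Akka actors; the driver-versus-executor wiring moves into a registerOrLookup helper in SparkEnv, and the trackers simply receive an ActorRef. The shape of the change, condensed from the diffs below:

    // Before: interactions scripted through EasyMock, then verified while executing.
    val taskScheduler = mock[TaskScheduler]
    resetExpecting {
      taskScheduler.submitTasks(taskSetForRdd(rdd))
    }
    whenExecuting {
      submitRdd(rdd)
    }

    // After: a plain stub records what the DAGScheduler did; tests assert afterwards.
    val taskSets = scala.collection.mutable.Buffer[TaskSet]()
    val taskScheduler = new TaskScheduler() {
      override def start() = {}
      override def stop() = {}
      override def submitTasks(taskSet: TaskSet) = { taskSets += taskSet }
      override def setListener(listener: TaskSchedulerListener) = {}
      override def defaultParallelism() = 2
    }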

core/src/main/scala/spark/MapOutputTracker.scala

@@ -38,9 +38,10 @@ private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Actor
   }
 }
 
-private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolean) extends Logging {
-  val timeout = 10.seconds
+private[spark] class MapOutputTracker extends Logging {
+
+  // Set to the MapOutputTrackerActor living on the driver
+  var trackerActor: ActorRef = _
 
   var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
@@ -53,24 +54,13 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolean)
   var cacheGeneration = generation
   val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
 
-  val actorName: String = "MapOutputTracker"
-  var trackerActor: ActorRef = if (isDriver) {
-    val actor = actorSystem.actorOf(Props(new MapOutputTrackerActor(this)), name = actorName)
-    logInfo("Registered MapOutputTrackerActor actor")
-    actor
-  } else {
-    val ip = System.getProperty("spark.driver.host", "localhost")
-    val port = System.getProperty("spark.driver.port", "7077").toInt
-    val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName)
-    actorSystem.actorFor(url)
-  }
-
   val metadataCleaner = new MetadataCleaner("MapOutputTracker", this.cleanup)
 
   // Send a message to the trackerActor and get its result within a default timeout, or
   // throw a SparkException if this fails.
   def askTracker(message: Any): Any = {
     try {
+      val timeout = 10.seconds
       val future = trackerActor.ask(message)(timeout)
       return Await.result(future, timeout)
     } catch {

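With the ActorSystem/isDriver constructor gone, whoever creates a MapOutputTracker is now responsible for assigning trackerActor before use. Local wiring reduces to three lines, the same pattern the updated MapOutputTrackerSuite uses below:

    val actorSystem = ActorSystem("test")
    val tracker = new MapOutputTracker()
    tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
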
core/src/main/scala/spark/SparkEnv.scala

@@ -1,7 +1,6 @@
 package spark
 
-import akka.actor.ActorSystem
-import akka.actor.ActorSystemImpl
+import akka.actor.{Actor, ActorRef, Props, ActorSystemImpl, ActorSystem}
 import akka.remote.RemoteActorRefProvider
 
 import serializer.Serializer
@@ -83,11 +82,23 @@ object SparkEnv extends Logging {
     }
 
     val serializer = instantiateClass[Serializer]("spark.serializer", "spark.JavaSerializer")
 
-    val driverIp: String = System.getProperty("spark.driver.host", "localhost")
-    val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt
-    val blockManagerMaster = new BlockManagerMaster(
-      actorSystem, isDriver, isLocal, driverIp, driverPort)
+    def registerOrLookup(name: String, newActor: => Actor): ActorRef = {
+      if (isDriver) {
+        logInfo("Registering " + name)
+        actorSystem.actorOf(Props(newActor), name = name)
+      } else {
+        val driverIp: String = System.getProperty("spark.driver.host", "localhost")
+        val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt
+        val url = "akka://spark@%s:%s/user/%s".format(driverIp, driverPort, name)
+        logInfo("Connecting to " + name + ": " + url)
+        actorSystem.actorFor(url)
+      }
+    }
+
+    val blockManagerMaster = new BlockManagerMaster(registerOrLookup(
+      "BlockManagerMaster",
+      new spark.storage.BlockManagerMasterActor(isLocal)))
     val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, serializer)
 
     val connectionManager = blockManager.connectionManager
@@ -99,7 +110,12 @@ object SparkEnv extends Logging {
     val cacheManager = new CacheManager(blockManager)
 
-    val mapOutputTracker = new MapOutputTracker(actorSystem, isDriver)
+    // Have to assign trackerActor after initialization as MapOutputTrackerActor
+    // requires the MapOutputTracker itself
+    val mapOutputTracker = new MapOutputTracker()
+    mapOutputTracker.trackerActor = registerOrLookup(
+      "MapOutputTracker",
+      new MapOutputTrackerActor(mapOutputTracker))
 
     val shuffleFetcher = instantiateClass[ShuffleFetcher](
       "spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher")
@@ -137,4 +153,5 @@ object SparkEnv extends Logging {
       httpFileServer,
       sparkFilesDir)
   }
 }

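registerOrLookup takes newActor as a by-name (=> Actor) parameter, so the Actor is only ever constructed on the driver branch; an executor evaluates nothing and just looks up the remote reference by URL. A minimal sketch of why that matters, with a hypothetical ExpensiveActor standing in for the real tracker and master actors:

    // Hypothetical actor whose constructor should only ever run on the driver.
    class ExpensiveActor extends Actor {
      def receive = { case msg => sender ! msg }
    }

    // Because the parameter is by-name, `new ExpensiveActor` is deferred: it runs inside
    // actorOf on the driver branch, and never runs at all on the actorFor (executor) branch.
    val ref = registerOrLookup("ExpensiveActor", new ExpensiveActor)
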
core/src/main/scala/spark/storage/BlockManager.scala

@@ -88,7 +88,7 @@ class BlockManager(
   val host = System.getProperty("spark.hostname", Utils.localHostName())
 
-  val slaveActor = master.actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)),
+  val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)),
     name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next)
 
   // Pending reregistration action being executed asynchronously or null if none
@@ -946,7 +946,7 @@ class BlockManager(
       heartBeatTask.cancel()
     }
     connectionManager.stop()
-    master.actorSystem.stop(slaveActor)
+    actorSystem.stop(slaveActor)
     blockInfo.clear()
     memoryStore.clear()
     diskStore.clear()

core/src/main/scala/spark/storage/BlockManagerMaster.scala

@@ -15,32 +15,12 @@ import akka.util.duration._
 
 import spark.{Logging, SparkException, Utils}
 
-private[spark] class BlockManagerMaster(
-    val actorSystem: ActorSystem,
-    isDriver: Boolean,
-    isLocal: Boolean,
-    driverIp: String,
-    driverPort: Int)
-  extends Logging {
+private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Logging {
 
   val AKKA_RETRY_ATTEMPTS: Int = System.getProperty("spark.akka.num.retries", "3").toInt
   val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "3000").toInt
-  val DRIVER_AKKA_ACTOR_NAME = "BlockMasterManager"
 
   val timeout = 10.seconds
-  var driverActor: ActorRef = {
-    if (isDriver) {
-      val driverActor = actorSystem.actorOf(Props(new BlockManagerMasterActor(isLocal)),
-        name = DRIVER_AKKA_ACTOR_NAME)
-      logInfo("Registered BlockManagerMaster Actor")
-      driverActor
-    } else {
-      val url = "akka://spark@%s:%s/user/%s".format(driverIp, driverPort, DRIVER_AKKA_ACTOR_NAME)
-      logInfo("Connecting to BlockManagerMaster: " + url)
-      actorSystem.actorFor(url)
-    }
-  }
 
   /** Remove a dead executor from the driver actor. This is only called on the driver side. */
   def removeExecutor(execId: String) {
@@ -59,7 +39,7 @@ private[spark] class BlockManagerMaster(
 
   /** Register the BlockManager's id with the driver. */
   def registerBlockManager(
     blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
     logInfo("Trying to register BlockManager")
     tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveActor))
     logInfo("Registered BlockManager")

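BlockManagerMaster is now a thin client around whatever driverActor it is handed, which is exactly what lets DAGSchedulerSuite substitute new BlockManagerMaster(null) { ... } with overridden methods. Driver-side construction collapses to one expression, the pattern ThreadingTest and BlockManagerSuite adopt below:

    val master = new BlockManagerMaster(
      actorSystem.actorOf(Props(new BlockManagerMasterActor(true))))
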
core/src/main/scala/spark/storage/ThreadingTest.scala

@@ -75,9 +75,8 @@ private[spark] object ThreadingTest {
     System.setProperty("spark.kryoserializer.buffer.mb", "1")
     val actorSystem = ActorSystem("test")
     val serializer = new KryoSerializer
-    val driverIp: String = System.getProperty("spark.driver.host", "localhost")
-    val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt
-    val blockManagerMaster = new BlockManagerMaster(actorSystem, true, true, driverIp, driverPort)
+    val blockManagerMaster = new BlockManagerMaster(
+      actorSystem.actorOf(Props(new BlockManagerMasterActor(true))))
     val blockManager = new BlockManager(
       "<driver>", actorSystem, blockManagerMaster, serializer, 1024 * 1024)
     val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i))

core/src/test/scala/spark/MapOutputTrackerSuite.scala

@@ -31,13 +31,15 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
   test("master start and stop") {
     val actorSystem = ActorSystem("test")
-    val tracker = new MapOutputTracker(actorSystem, true)
+    val tracker = new MapOutputTracker()
+    tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
     tracker.stop()
   }
 
   test("master register and fetch") {
     val actorSystem = ActorSystem("test")
-    val tracker = new MapOutputTracker(actorSystem, true)
+    val tracker = new MapOutputTracker()
+    tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
     tracker.registerShuffle(10, 2)
     val compressedSize1000 = MapOutputTracker.compressSize(1000L)
     val compressedSize10000 = MapOutputTracker.compressSize(10000L)
@@ -55,7 +57,8 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
   test("master register and unregister and fetch") {
     val actorSystem = ActorSystem("test")
-    val tracker = new MapOutputTracker(actorSystem, true)
+    val tracker = new MapOutputTracker()
+    tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
     tracker.registerShuffle(10, 2)
     val compressedSize1000 = MapOutputTracker.compressSize(1000L)
     val compressedSize10000 = MapOutputTracker.compressSize(10000L)
@@ -77,35 +80,34 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
   }
 
   test("remote fetch") {
-    try {
-      System.clearProperty("spark.driver.host")  // In case some previous test had set it
-      val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", "localhost", 0)
-      System.setProperty("spark.driver.port", boundPort.toString)
-      val masterTracker = new MapOutputTracker(actorSystem, true)
-      val slaveTracker = new MapOutputTracker(actorSystem, false)
+    val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", "localhost", 0)
+    val masterTracker = new MapOutputTracker()
+    masterTracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(masterTracker)))
+
+    val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", "localhost", 0)
+    val slaveTracker = new MapOutputTracker()
+    slaveTracker.trackerActor = slaveSystem.actorFor("akka://spark@localhost:" + boundPort)
+
     masterTracker.registerShuffle(10, 1)
     masterTracker.incrementGeneration()
     slaveTracker.updateGeneration(masterTracker.getGeneration)
     intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
 
     val compressedSize1000 = MapOutputTracker.compressSize(1000L)
     val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
     masterTracker.registerMapOutput(10, 0, new MapStatus(
       BlockManagerId("a", "hostA", 1000), Array(compressedSize1000)))
     masterTracker.incrementGeneration()
     slaveTracker.updateGeneration(masterTracker.getGeneration)
     assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
       Seq((BlockManagerId("a", "hostA", 1000), size1000)))
 
     masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
     masterTracker.incrementGeneration()
     slaveTracker.updateGeneration(masterTracker.getGeneration)
     intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
 
     // failure should be cached
     intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
-    } finally {
-      System.clearProperty("spark.driver.port")
-    }
   }
 }

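The rewritten remote-fetch test runs the master and slave trackers on two real ActorSystems instead of routing both through spark.driver.host/spark.driver.port system properties, which also removes the try/finally cleanup of global state. In the Akka 2.0 API Spark used at the time, the two sides meet by actor path; a sketch of the shape, with the name and port illustrative:

    // Server side: registering under a name gives the actor a stable path under /user.
    val ref = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)), name = "MapOutputTracker")
    // Client side: a second system resolves the same actor by URL.
    val remote = slaveSystem.actorFor("akka://spark@localhost:" + boundPort + "/user/MapOutputTracker")
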
core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala

@@ -4,16 +4,6 @@ import scala.collection.mutable.{Map, HashMap}
 
 import org.scalatest.FunSuite
 import org.scalatest.BeforeAndAfter
-import org.scalatest.concurrent.TimeLimitedTests
-import org.scalatest.mock.EasyMockSugar
-import org.scalatest.time.{Span, Seconds}
-
-import org.easymock.EasyMock._
-import org.easymock.Capture
-import org.easymock.EasyMock
-import org.easymock.{IAnswer, IArgumentMatcher}
-
-import akka.actor.ActorSystem
 
 import spark.storage.BlockManager
 import spark.storage.BlockManagerId
@@ -42,27 +32,26 @@ import spark.{FetchFailed, Success}
  * DAGScheduler#submitWaitingStages (via test utility functions like runEvent or respondToTaskSet)
  * and capturing the resulting TaskSets from the mock TaskScheduler.
  */
-class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar with TimeLimitedTests {
-
-  // impose a time limit on this test in case we don't let the job finish, in which case
-  // JobWaiter#getResult will hang.
-  override val timeLimit = Span(5, Seconds)
+class DAGSchedulerSuite extends FunSuite with BeforeAndAfter {
 
   val sc: SparkContext = new SparkContext("local", "DAGSchedulerSuite")
-  var scheduler: DAGScheduler = null
-  val taskScheduler = mock[TaskScheduler]
-  val blockManagerMaster = mock[BlockManagerMaster]
-  var mapOutputTracker: MapOutputTracker = null
-  var schedulerThread: Thread = null
-  var schedulerException: Throwable = null
 
-  /**
-   * Set of EasyMock argument matchers that match a TaskSet for a given RDD.
-   * We cache these so we do not create duplicate matchers for the same RDD.
-   * This allows us to easily setup a sequence of expectations for task sets for
-   * that RDD.
-   */
-  val taskSetMatchers = new HashMap[MyRDD, IArgumentMatcher]
+  /** Set of TaskSets the DAGScheduler has requested executed. */
+  val taskSets = scala.collection.mutable.Buffer[TaskSet]()
+
+  val taskScheduler = new TaskScheduler() {
+    override def start() = {}
+    override def stop() = {}
+    override def submitTasks(taskSet: TaskSet) = {
+      // normally done by TaskSetManager
+      taskSet.tasks.foreach(_.generation = mapOutputTracker.getGeneration)
+      taskSets += taskSet
+    }
+    override def setListener(listener: TaskSchedulerListener) = {}
+    override def defaultParallelism() = 2
+  }
+
+  var mapOutputTracker: MapOutputTracker = null
+  var scheduler: DAGScheduler = null
 
   /**
    * Set of cache locations to return from our mock BlockManagerMaster.
@@ -70,68 +59,46 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
    * list of cache locations silently.
    */
   val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]]
-
-  /**
-   * JobWaiter for the last JobSubmitted event we pushed. To keep tests (most of which
-   * will only submit one job) from needing to explicitly track it.
-   */
-  var lastJobWaiter: JobWaiter[Int] = null
-
-  /**
-   * Array into which we are accumulating the results from the last job asynchronously.
-   */
-  var lastJobResult: Array[Int] = null
-
-  /**
-   * Tell EasyMockSugar what mock objects we want to be configured by expecting {...}
-   * and whenExecuting {...} */
-  implicit val mocks = MockObjects(taskScheduler, blockManagerMaster)
-
-  /**
-   * Utility function to reset mocks and set expectations on them. EasyMock wants mock objects
-   * to be reset after each time their expectations are set, and we tend to check mock object
-   * calls over a single call to DAGScheduler.
-   *
-   * We also set a default expectation here that blockManagerMaster.getLocations can be called
-   * and will return values from cacheLocations.
-   */
-  def resetExpecting(f: => Unit) {
-    reset(taskScheduler)
-    reset(blockManagerMaster)
-    expecting {
-      expectGetLocations()
-      f
-    }
-  }
+  // stub out BlockManagerMaster.getLocations to use our cacheLocations
+  val blockManagerMaster = new BlockManagerMaster(null) {
+    override def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = {
+      blockIds.map { name =>
+        val pieces = name.split("_")
+        if (pieces(0) == "rdd") {
+          val key = pieces(1).toInt -> pieces(2).toInt
+          cacheLocations.getOrElse(key, Seq())
+        } else {
+          Seq()
+        }
+      }.toSeq
+    }
+    override def removeExecutor(execId: String) {
+      // don't need to propagate to the driver, which we don't have
+    }
+  }
+
+  /** The list of results that DAGScheduler has collected. */
+  val results = new HashMap[Int, Any]()
+  var failure: Exception = _
+  val listener = new JobListener() {
+    override def taskSucceeded(index: Int, result: Any) = results.put(index, result)
+    override def jobFailed(exception: Exception) = { failure = exception }
+  }
 
   before {
-    taskSetMatchers.clear()
+    taskSets.clear()
     cacheLocations.clear()
-    val actorSystem = ActorSystem("test")
-    mapOutputTracker = new MapOutputTracker(actorSystem, true)
-    resetExpecting {
-      taskScheduler.setListener(anyObject())
-    }
-    whenExecuting {
-      scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null)
-    }
+    results.clear()
+    mapOutputTracker = new MapOutputTracker()
+    scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null)
   }
 
   after {
-    assert(scheduler.processEvent(StopDAGScheduler))
-    resetExpecting {
-      taskScheduler.stop()
-    }
-    whenExecuting {
-      scheduler.stop()
-    }
+    scheduler.stop()
     sc.stop()
     System.clearProperty("spark.master.port")
   }
 
-  def makeBlockManagerId(host: String): BlockManagerId =
-    BlockManagerId("exec-" + host, host, 12345)
-
   /**
    * Type of RDD we use for testing. Note that we should never call the real RDD compute methods.
    * This is a pair RDD type so it can always be used in ShuffleDependencies.
@@ -143,7 +110,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
    * preferredLocations (if any) that are passed to them. They are deliberately not executable
    * so we can test that DAGScheduler does not try to execute RDDs locally.
    */
-  def makeRdd(
+  private def makeRdd(
     numSplits: Int,
     dependencies: List[Dependency[_]],
     locations: Seq[Seq[String]] = Nil
@@ -164,55 +131,6 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
     }
   }
 
-  /**
-   * EasyMock matcher method. For use as an argument matcher for a TaskSet whose first task
-   * is from a particular RDD.
-   */
-  def taskSetForRdd(rdd: MyRDD): TaskSet = {
-    val matcher = taskSetMatchers.getOrElseUpdate(rdd,
-      new IArgumentMatcher {
-        override def matches(actual: Any): Boolean = {
-          val taskSet = actual.asInstanceOf[TaskSet]
-          taskSet.tasks(0) match {
-            case rt: ResultTask[_, _] => rt.rdd.id == rdd.id
-            case smt: ShuffleMapTask => smt.rdd.id == rdd.id
-            case _ => false
-          }
-        }
-        override def appendTo(buf: StringBuffer) {
-          buf.append("taskSetForRdd(" + rdd + ")")
-        }
-      })
-    EasyMock.reportMatcher(matcher)
-    return null
-  }
-
-  /**
-   * Setup an EasyMock expectation to repsond to blockManagerMaster.getLocations() called from
-   * cacheLocations.
-   */
-  def expectGetLocations(): Unit = {
-    EasyMock.expect(blockManagerMaster.getLocations(anyObject().asInstanceOf[Array[String]])).
-      andAnswer(new IAnswer[Seq[Seq[BlockManagerId]]] {
-        override def answer(): Seq[Seq[BlockManagerId]] = {
-          val blocks = getCurrentArguments()(0).asInstanceOf[Array[String]]
-          return blocks.map { name =>
-            val pieces = name.split("_")
-            if (pieces(0) == "rdd") {
-              val key = pieces(1).toInt -> pieces(2).toInt
-              if (cacheLocations.contains(key)) {
-                cacheLocations(key)
-              } else {
-                Seq[BlockManagerId]()
-              }
-            } else {
-              Seq[BlockManagerId]()
-            }
-          }.toSeq
-        }
-      }).anyTimes()
-  }
-
   /**
    * Process the supplied event as if it were the top of the DAGScheduler event queue, expecting
    * the scheduler not to exit.
@@ -220,157 +138,81 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
    * After processing the event, submit waiting stages as is done on most iterations of the
    * DAGScheduler event loop.
    */
-  def runEvent(event: DAGSchedulerEvent) {
+  private def runEvent(event: DAGSchedulerEvent) {
     assert(!scheduler.processEvent(event))
     scheduler.submitWaitingStages()
   }
 
-  /**
-   * Expect a TaskSet for the specified RDD to be submitted to the TaskScheduler. Should be
-   * called from a resetExpecting { ... } block.
-   *
-   * Returns a easymock Capture that will contain the task set after the stage is submitted.
-   * Most tests should use interceptStage() instead of this directly.
-   */
-  def expectStage(rdd: MyRDD): Capture[TaskSet] = {
-    val taskSetCapture = new Capture[TaskSet]
-    taskScheduler.submitTasks(and(capture(taskSetCapture), taskSetForRdd(rdd)))
-    return taskSetCapture
-  }
-
-  /**
-   * Expect the supplied code snippet to submit a stage for the specified RDD.
-   * Return the resulting TaskSet. First marks all the tasks are belonging to the
-   * current MapOutputTracker generation.
-   */
-  def interceptStage(rdd: MyRDD)(f: => Unit): TaskSet = {
-    var capture: Capture[TaskSet] = null
-    resetExpecting {
-      capture = expectStage(rdd)
-    }
-    whenExecuting {
-      f
-    }
-    val taskSet = capture.getValue
-    for (task <- taskSet.tasks) {
-      task.generation = mapOutputTracker.getGeneration
-    }
-    return taskSet
-  }
-
-  /**
-   * Send the given CompletionEvent messages for the tasks in the TaskSet.
-   */
-  def respondToTaskSet(taskSet: TaskSet, results: Seq[(TaskEndReason, Any)]) {
-    assert(taskSet.tasks.size >= results.size)
-    for ((result, i) <- results.zipWithIndex) {
-      if (i < taskSet.tasks.size) {
-        runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, Map[Long, Any]()))
-      }
-    }
-  }
-
-  /**
-   * Assert that the supplied TaskSet has exactly the given preferredLocations.
-   */
-  def expectTaskSetLocations(taskSet: TaskSet, locations: Seq[Seq[String]]) {
-    assert(locations.size === taskSet.tasks.size)
-    for ((expectLocs, taskLocs) <-
-           taskSet.tasks.map(_.preferredLocations).zip(locations)) {
-      assert(expectLocs === taskLocs)
-    }
-  }
-
   /**
    * When we submit dummy Jobs, this is the compute function we supply. Except in a local test
    * below, we do not expect this function to ever be executed; instead, we will return results
    * directly through CompletionEvents.
    */
-  def jobComputeFunc(context: TaskContext, it: Iterator[(Int, Int)]): Int =
-    it.next._1.asInstanceOf[Int]
+  private val jobComputeFunc = (context: TaskContext, it: Iterator[(_)]) =>
+    it.next.asInstanceOf[Tuple2[_, _]]._1
 
-  /**
-   * Start a job to compute the given RDD. Returns the JobWaiter that will
-   * collect the result of the job via callbacks from DAGScheduler.
-   */
-  def submitRdd(rdd: MyRDD, allowLocal: Boolean = false): (JobWaiter[Int], Array[Int]) = {
-    val resultArray = new Array[Int](rdd.splits.size)
-    val (toSubmit, waiter) = scheduler.prepareJob[(Int, Int), Int](
-      rdd,
-      jobComputeFunc,
-      (0 to (rdd.splits.size - 1)),
-      "test-site",
-      allowLocal,
-      (i: Int, value: Int) => resultArray(i) = value
-    )
-    lastJobWaiter = waiter
-    lastJobResult = resultArray
-    runEvent(toSubmit)
-    return (waiter, resultArray)
-  }
-
-  /**
-   * Assert that a job we started has failed.
-   */
-  def expectJobException(waiter: JobWaiter[Int] = lastJobWaiter) {
-    waiter.awaitResult() match {
-      case JobSucceeded => fail()
-      case JobFailed(_) => return
-    }
-  }
-
-  /**
-   * Assert that a job we started has succeeded and has the given result.
-   */
-  def expectJobResult(expected: Array[Int], waiter: JobWaiter[Int] = lastJobWaiter,
-                      result: Array[Int] = lastJobResult) {
-    waiter.awaitResult match {
-      case JobSucceeded =>
-        assert(expected === result)
-      case JobFailed(_) =>
-        fail()
-    }
-  }
-
-  def makeMapStatus(host: String, reduces: Int): MapStatus =
-    new MapStatus(makeBlockManagerId(host), Array.fill[Byte](reduces)(2))
+  /** Send the given CompletionEvent messages for the tasks in the TaskSet. */
+  private def complete(taskSet: TaskSet, results: Seq[(TaskEndReason, Any)]) {
+    assert(taskSet.tasks.size >= results.size)
+    for ((result, i) <- results.zipWithIndex) {
+      if (i < taskSet.tasks.size) {
+        runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, Map[Long, Any]()))
+      }
+    }
+  }
+
+  /** Sends the rdd to the scheduler for scheduling. */
+  private def submit(
+      rdd: RDD[_],
+      partitions: Array[Int],
+      func: (TaskContext, Iterator[_]) => _ = jobComputeFunc,
+      allowLocal: Boolean = false,
+      listener: JobListener = listener) {
+    runEvent(JobSubmitted(rdd, func, partitions, allowLocal, null, listener))
+  }
+
+  /** Sends TaskSetFailed to the scheduler. */
+  private def failed(taskSet: TaskSet, message: String) {
+    runEvent(TaskSetFailed(taskSet, message))
+  }
 
   test("zero split job") {
     val rdd = makeRdd(0, Nil)
     var numResults = 0
-    def accumulateResult(partition: Int, value: Int) {
-      numResults += 1
-    }
-    scheduler.runJob(rdd, jobComputeFunc, Seq(), "test-site", false, accumulateResult)
+    val fakeListener = new JobListener() {
+      override def taskSucceeded(partition: Int, value: Any) = numResults += 1
+      override def jobFailed(exception: Exception) = throw exception
+    }
+    submit(rdd, Array(), listener = fakeListener)
     assert(numResults === 0)
   }
 
   test("run trivial job") {
     val rdd = makeRdd(1, Nil)
-    val taskSet = interceptStage(rdd) { submitRdd(rdd) }
-    respondToTaskSet(taskSet, List( (Success, 42) ))
-    expectJobResult(Array(42))
+    submit(rdd, Array(0))
+    complete(taskSets(0), List((Success, 42)))
+    assert(results === Map(0 -> 42))
   }
 
   test("local job") {
     val rdd = new MyRDD(sc, Nil) {
-      override def compute(split: Split, context: TaskContext): Iterator[(Int, Int)] =
-        Array(42 -> 0).iterator
-      override def getSplits() = Array( new Split { override def index = 0 } )
+      override def compute(split: Split, context: TaskContext) = Array(42 -> 0).iterator
+      override def getSplits() = Array(new Split { override def index = 0 })
       override def getPreferredLocations(split: Split) = Nil
       override def toString = "DAGSchedulerSuite Local RDD"
     }
-    submitRdd(rdd, true)
-    expectJobResult(Array(42))
+    runEvent(JobSubmitted(rdd, jobComputeFunc, Array(0), true, null, listener))
+    // this shouldn't be needed, but i haven't stubbed out runLocally yet
+    Thread.sleep(500)
+    assert(results === Map(0 -> 42))
   }
 
   test("run trivial job w/ dependency") {
     val baseRdd = makeRdd(1, Nil)
     val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd)))
-    val taskSet = interceptStage(finalRdd) { submitRdd(finalRdd) }
-    respondToTaskSet(taskSet, List( (Success, 42) ))
-    expectJobResult(Array(42))
+    submit(finalRdd, Array(0))
+    complete(taskSets(0), Seq((Success, 42)))
+    assert(results === Map(0 -> 42))
   }
 
   test("cache location preferences w/ dependency") {
@@ -378,17 +220,17 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
     val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd)))
     cacheLocations(baseRdd.id -> 0) =
       Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))
-    val taskSet = interceptStage(finalRdd) { submitRdd(finalRdd) }
-    expectTaskSetLocations(taskSet, List(Seq("hostA", "hostB")))
-    respondToTaskSet(taskSet, List( (Success, 42) ))
-    expectJobResult(Array(42))
+    submit(finalRdd, Array(0))
+    val taskSet = taskSets(0)
+    assertLocations(taskSet, Seq(Seq("hostA", "hostB")))
+    complete(taskSet, Seq((Success, 42)))
+    assert(results === Map(0 -> 42))
   }
 
   test("trivial job failure") {
-    val rdd = makeRdd(1, Nil)
-    val taskSet = interceptStage(rdd) { submitRdd(rdd) }
-    runEvent(TaskSetFailed(taskSet, "test failure"))
-    expectJobException()
+    submit(makeRdd(1, Nil), Array(0))
+    failed(taskSets(0), "some failure")
+    assert(failure.getMessage === "Job failed: some failure")
   }
 
   test("run trivial shuffle") {
@@ -396,52 +238,39 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
     val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
     val shuffleId = shuffleDep.shuffleId
     val reduceRdd = makeRdd(1, List(shuffleDep))
-
-    val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
-    val secondStage = interceptStage(reduceRdd) {
-      respondToTaskSet(firstStage, List(
-        (Success, makeMapStatus("hostA", 1)),
-        (Success, makeMapStatus("hostB", 1))
-      ))
-    }
+    submit(reduceRdd, Array(0))
+    complete(taskSets(0), Seq(
+      (Success, makeMapStatus("hostA", 1)),
+      (Success, makeMapStatus("hostB", 1))))
     assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
       Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
-    respondToTaskSet(secondStage, List( (Success, 42) ))
-    expectJobResult(Array(42))
+    complete(taskSets(1), Seq((Success, 42)))
+    assert(results === Map(0 -> 42))
   }
 
   test("run trivial shuffle with fetch failure") {
     val shuffleMapRdd = makeRdd(2, Nil)
     val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
     val shuffleId = shuffleDep.shuffleId
     val reduceRdd = makeRdd(2, List(shuffleDep))
-
-    val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
-    val secondStage = interceptStage(reduceRdd) {
-      respondToTaskSet(firstStage, List(
-        (Success, makeMapStatus("hostA", 1)),
-        (Success, makeMapStatus("hostB", 1))
-      ))
-    }
-    resetExpecting {
-      blockManagerMaster.removeExecutor("exec-hostA")
-    }
-    whenExecuting {
-      respondToTaskSet(secondStage, List(
-        (Success, 42),
-        (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0), null)
-      ))
-    }
-    val thirdStage = interceptStage(shuffleMapRdd) {
-      scheduler.resubmitFailedStages()
-    }
-    val fourthStage = interceptStage(reduceRdd) {
-      respondToTaskSet(thirdStage, List( (Success, makeMapStatus("hostA", 1)) ))
-    }
-    assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
-      Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
-    respondToTaskSet(fourthStage, List( (Success, 43) ))
-    expectJobResult(Array(42, 43))
+    submit(reduceRdd, Array(0, 1))
+    complete(taskSets(0), Seq(
+      (Success, makeMapStatus("hostA", 1)),
+      (Success, makeMapStatus("hostB", 1))))
+    // the 2nd ResultTask failed
+    complete(taskSets(1), Seq(
+      (Success, 42),
+      (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0), null)))
+    // this will get called
+    // blockManagerMaster.removeExecutor("exec-hostA")
+    // ask the scheduler to try it again
+    scheduler.resubmitFailedStages()
+    // have the 2nd attempt pass
+    complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1))))
+    // we can see both result blocks now
+    assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.ip) === Array("hostA", "hostB"))
+    complete(taskSets(3), Seq((Success, 43)))
+    assert(results === Map(0 -> 42, 1 -> 43))
   }
 
   test("ignore late map task completions") {
@@ -449,33 +278,27 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
     val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
     val shuffleId = shuffleDep.shuffleId
     val reduceRdd = makeRdd(2, List(shuffleDep))
-
-    val taskSet = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
+    submit(reduceRdd, Array(0, 1))
+    // pretend we were told hostA went away
     val oldGeneration = mapOutputTracker.getGeneration
-    resetExpecting {
-      blockManagerMaster.removeExecutor("exec-hostA")
-    }
-    whenExecuting {
-      runEvent(ExecutorLost("exec-hostA"))
-    }
+    runEvent(ExecutorLost("exec-hostA"))
     val newGeneration = mapOutputTracker.getGeneration
     assert(newGeneration > oldGeneration)
     val noAccum = Map[Long, Any]()
-    // We rely on the event queue being ordered and increasing the generation number by 1
+    val taskSet = taskSets(0)
     // should be ignored for being too old
     runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum))
     // should work because it's a non-failed host
     runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), noAccum))
     // should be ignored for being too old
     runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum))
+    // should work because it's a new generation
     taskSet.tasks(1).generation = newGeneration
-    val secondStage = interceptStage(reduceRdd) {
-      runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum))
-    }
+    runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum))
     assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
       Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))
-    respondToTaskSet(secondStage, List( (Success, 42), (Success, 43) ))
-    expectJobResult(Array(42, 43))
+    complete(taskSets(1), Seq((Success, 42), (Success, 43)))
+    assert(results === Map(0 -> 42, 1 -> 43))
   }
 
   test("run trivial shuffle with out-of-band failure and retry") {
@@ -483,76 +306,49 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
     val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
     val shuffleId = shuffleDep.shuffleId
     val reduceRdd = makeRdd(1, List(shuffleDep))
-
-    val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
-    resetExpecting {
-      blockManagerMaster.removeExecutor("exec-hostA")
-    }
-    whenExecuting {
-      runEvent(ExecutorLost("exec-hostA"))
-    }
+    submit(reduceRdd, Array(0))
+    // blockManagerMaster.removeExecutor("exec-hostA")
+    // pretend we were told hostA went away
+    runEvent(ExecutorLost("exec-hostA"))
     // DAGScheduler will immediately resubmit the stage after it appears to have no pending tasks
     // rather than marking it is as failed and waiting.
-    val secondStage = interceptStage(shuffleMapRdd) {
-      respondToTaskSet(firstStage, List(
-        (Success, makeMapStatus("hostA", 1)),
-        (Success, makeMapStatus("hostB", 1))
-      ))
-    }
-    val thirdStage = interceptStage(reduceRdd) {
-      respondToTaskSet(secondStage, List(
-        (Success, makeMapStatus("hostC", 1))
-      ))
-    }
+    complete(taskSets(0), Seq(
+      (Success, makeMapStatus("hostA", 1)),
+      (Success, makeMapStatus("hostB", 1))))
+    // have hostC complete the resubmitted task
+    complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1))))
     assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
       Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB")))
-    respondToTaskSet(thirdStage, List( (Success, 42) ))
-    expectJobResult(Array(42))
+    complete(taskSets(2), Seq((Success, 42)))
+    assert(results === Map(0 -> 42))
   }
 
   test("recursive shuffle failures") {
     val shuffleOneRdd = makeRdd(2, Nil)
     val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
     val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
     val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
     val finalRdd = makeRdd(1, List(shuffleDepTwo))
-
-    val firstStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) }
-    val secondStage = interceptStage(shuffleTwoRdd) {
-      respondToTaskSet(firstStage, List(
-        (Success, makeMapStatus("hostA", 2)),
-        (Success, makeMapStatus("hostB", 2))
-      ))
-    }
-    val thirdStage = interceptStage(finalRdd) {
-      respondToTaskSet(secondStage, List(
-        (Success, makeMapStatus("hostA", 1)),
-        (Success, makeMapStatus("hostC", 1))
-      ))
-    }
-    resetExpecting {
-      blockManagerMaster.removeExecutor("exec-hostA")
-    }
-    whenExecuting {
-      respondToTaskSet(thirdStage, List(
-        (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)
-      ))
-    }
-    val recomputeOne = interceptStage(shuffleOneRdd) {
-      scheduler.resubmitFailedStages()
-    }
-    val recomputeTwo = interceptStage(shuffleTwoRdd) {
-      respondToTaskSet(recomputeOne, List(
-        (Success, makeMapStatus("hostA", 2))
-      ))
-    }
-    val finalStage = interceptStage(finalRdd) {
-      respondToTaskSet(recomputeTwo, List(
-        (Success, makeMapStatus("hostA", 1))
-      ))
-    }
-    respondToTaskSet(finalStage, List( (Success, 42) ))
-    expectJobResult(Array(42))
+    submit(finalRdd, Array(0))
+    // have the first stage complete normally
+    complete(taskSets(0), Seq(
+      (Success, makeMapStatus("hostA", 2)),
+      (Success, makeMapStatus("hostB", 2))))
+    // have the second stage complete normally
+    complete(taskSets(1), Seq(
+      (Success, makeMapStatus("hostA", 1)),
+      (Success, makeMapStatus("hostC", 1))))
+    // fail the third stage because hostA went down
+    complete(taskSets(2), Seq(
+      (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)))
+    // TODO assert this:
+    // blockManagerMaster.removeExecutor("exec-hostA")
+    // have DAGScheduler try again
+    scheduler.resubmitFailedStages()
+    complete(taskSets(3), Seq((Success, makeMapStatus("hostA", 2))))
+    complete(taskSets(4), Seq((Success, makeMapStatus("hostA", 1))))
+    complete(taskSets(5), Seq((Success, 42)))
+    assert(results === Map(0 -> 42))
   }
 
   test("cached post-shuffle") {
@@ -561,103 +357,44 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
     val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
     val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
     val finalRdd = makeRdd(1, List(shuffleDepTwo))
-
-    val firstShuffleStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) }
+    submit(finalRdd, Array(0))
     cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD"))
     cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC"))
-    val secondShuffleStage = interceptStage(shuffleTwoRdd) {
-      respondToTaskSet(firstShuffleStage, List(
-        (Success, makeMapStatus("hostA", 2)),
-        (Success, makeMapStatus("hostB", 2))
-      ))
-    }
-    val reduceStage = interceptStage(finalRdd) {
-      respondToTaskSet(secondShuffleStage, List(
-        (Success, makeMapStatus("hostA", 1)),
-        (Success, makeMapStatus("hostB", 1))
-      ))
-    }
-    resetExpecting {
-      blockManagerMaster.removeExecutor("exec-hostA")
-    }
-    whenExecuting {
-      respondToTaskSet(reduceStage, List(
-        (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)
-      ))
-    }
+    // complete stage 2
+    complete(taskSets(0), Seq(
+      (Success, makeMapStatus("hostA", 2)),
+      (Success, makeMapStatus("hostB", 2))))
+    // complete stage 1
+    complete(taskSets(1), Seq(
+      (Success, makeMapStatus("hostA", 1)),
+      (Success, makeMapStatus("hostB", 1))))
+    // pretend stage 0 failed because hostA went down
+    complete(taskSets(2), Seq(
+      (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)))
+    // TODO assert this:
+    // blockManagerMaster.removeExecutor("exec-hostA")
     // DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun.
-    val recomputeTwo = interceptStage(shuffleTwoRdd) {
-      scheduler.resubmitFailedStages()
-    }
-    expectTaskSetLocations(recomputeTwo, Seq(Seq("hostD")))
-    val finalRetry = interceptStage(finalRdd) {
-      respondToTaskSet(recomputeTwo, List(
-        (Success, makeMapStatus("hostD", 1))
-      ))
-    }
-    respondToTaskSet(finalRetry, List( (Success, 42) ))
-    expectJobResult(Array(42))
-  }
-
-  test("cached post-shuffle but fails") {
-    val shuffleOneRdd = makeRdd(2, Nil)
-    val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
-    val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
-    val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
-    val finalRdd = makeRdd(1, List(shuffleDepTwo))
-
-    val firstShuffleStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) }
-    cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD"))
-    cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC"))
-    val secondShuffleStage = interceptStage(shuffleTwoRdd) {
-      respondToTaskSet(firstShuffleStage, List(
-        (Success, makeMapStatus("hostA", 2)),
-        (Success, makeMapStatus("hostB", 2))
-      ))
-    }
-    val reduceStage = interceptStage(finalRdd) {
-      respondToTaskSet(secondShuffleStage, List(
-        (Success, makeMapStatus("hostA", 1)),
-        (Success, makeMapStatus("hostB", 1))
-      ))
-    }
-    resetExpecting {
-      blockManagerMaster.removeExecutor("exec-hostA")
-    }
-    whenExecuting {
-      respondToTaskSet(reduceStage, List(
-        (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)
-      ))
-    }
-    val recomputeTwoCached = interceptStage(shuffleTwoRdd) {
-      scheduler.resubmitFailedStages()
-    }
-    expectTaskSetLocations(recomputeTwoCached, Seq(Seq("hostD")))
-    intercept[FetchFailedException]{
-      mapOutputTracker.getServerStatuses(shuffleDepOne.shuffleId, 0)
-    }
-    // Simulate the shuffle input data failing to be cached.
-    cacheLocations.remove(shuffleTwoRdd.id -> 0)
-    respondToTaskSet(recomputeTwoCached, List(
-      (FetchFailed(null, shuffleDepOne.shuffleId, 0, 0), null)
-    ))
-    // After the fetch failure, DAGScheduler should recheck the cache and decide to resubmit
-    // everything.
-    val recomputeOne = interceptStage(shuffleOneRdd) {
-      scheduler.resubmitFailedStages()
-    }
-    // We use hostA here to make sure DAGScheduler doesn't think it's still dead.
-    val recomputeTwoUncached = interceptStage(shuffleTwoRdd) {
-      respondToTaskSet(recomputeOne, List( (Success, makeMapStatus("hostA", 1)) ))
-    }
-    expectTaskSetLocations(recomputeTwoUncached, Seq(Seq[String]()))
-    val finalRetry = interceptStage(finalRdd) {
-      respondToTaskSet(recomputeTwoUncached, List( (Success, makeMapStatus("hostA", 1)) ))
-    }
-    respondToTaskSet(finalRetry, List( (Success, 42) ))
-    expectJobResult(Array(42))
+    scheduler.resubmitFailedStages()
+    assertLocations(taskSets(3), Seq(Seq("hostD")))
+    // allow hostD to recover
+    complete(taskSets(3), Seq((Success, makeMapStatus("hostD", 1))))
+    complete(taskSets(4), Seq((Success, 42)))
+    assert(results === Map(0 -> 42))
+  }
+
+  /** Assert that the supplied TaskSet has exactly the given preferredLocations. */
+  private def assertLocations(taskSet: TaskSet, locations: Seq[Seq[String]]) {
+    assert(locations.size === taskSet.tasks.size)
+    for ((expectLocs, taskLocs) <-
+           taskSet.tasks.map(_.preferredLocations).zip(locations)) {
+      assert(expectLocs === taskLocs)
+    }
+  }
+
+  private def makeMapStatus(host: String, reduces: Int): MapStatus =
+    new MapStatus(makeBlockManagerId(host), Array.fill[Byte](reduces)(2))
+
+  private def makeBlockManagerId(host: String): BlockManagerId =
+    BlockManagerId("exec-" + host, host, 12345)
 }

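Because the stubbed submitTasks appends to taskSets in submission order, the rewritten tests can address stages positionally: taskSets(0) is the first stage the DAGScheduler submitted, taskSets(1) the next, and so on. Condensed from the trivial-shuffle test above, the resulting test shape is linear:

    submit(reduceRdd, Array(0))                // job in
    complete(taskSets(0), Seq(
      (Success, makeMapStatus("hostA", 1)),
      (Success, makeMapStatus("hostB", 1))))   // map stage out
    complete(taskSets(1), Seq((Success, 42)))  // reduce stage out
    assert(results === Map(0 -> 42))           // job result observed
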
core/src/test/scala/spark/storage/BlockManagerSuite.scala

@@ -31,7 +31,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodTester
   before {
     actorSystem = ActorSystem("test")
-    master = new BlockManagerMaster(actorSystem, true, true, "localhost", 7077)
+    master = new BlockManagerMaster(
+      actorSystem.actorOf(Props(new spark.storage.BlockManagerMasterActor(true))))
 
     // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case
     oldArch = System.setProperty("os.arch", "amd64")