Merge branch 'mesos'

haitao.yao 2013-02-04 11:40:15 +08:00
commit faa4d9e31f
43 changed files with 1195 additions and 393 deletions

View file

@ -98,6 +98,11 @@
<artifactId>scalacheck_${scala.version}</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.easymock</groupId>
<artifactId>easymock</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.novocode</groupId>
<artifactId>junit-interface</artifactId>

View file

@ -61,17 +61,3 @@ class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int)
}
}
}
/**
* Represents a dependency between the PartitionPruningRDD and its parent. In this
* case, the child RDD contains a subset of the parent's partitions.
*/
class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boolean)
extends NarrowDependency[T](rdd) {
@transient
val partitions: Array[Split] = rdd.splits.filter(s => partitionFilterFunc(s.index))
override def getParents(partitionId: Int) = List(partitions(partitionId).index)
}

View file

@ -170,7 +170,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolea
}
}
def cleanup(cleanupTime: Long) {
private def cleanup(cleanupTime: Long) {
mapStatuses.clearOldValues(cleanupTime)
cachedSerializedStatuses.clearOldValues(cleanupTime)
}

View file

@ -465,7 +465,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
val res = self.context.runJob(self, process _, Array(index), false)
res(0)
case None =>
self.filter(_._1 == key).map(_._2).collect
self.filter(_._1 == key).map(_._2).collect()
}
}
@ -590,7 +590,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
var count = 0
while(iter.hasNext) {
val record = iter.next
val record = iter.next()
count += 1
writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef])
}

View file

@ -385,20 +385,22 @@ abstract class RDD[T: ClassManifest](
val reducePartition: Iterator[T] => Option[T] = iter => {
if (iter.hasNext) {
Some(iter.reduceLeft(cleanF))
}else {
} else {
None
}
}
val options = sc.runJob(this, reducePartition)
val results = new ArrayBuffer[T]
for (opt <- options; elem <- opt) {
results += elem
}
if (results.size == 0) {
throw new UnsupportedOperationException("empty collection")
} else {
return results.reduceLeft(cleanF)
var jobResult: Option[T] = None
val mergeResult = (index: Int, taskResult: Option[T]) => {
if (taskResult != None) {
jobResult = jobResult match {
case Some(value) => Some(f(value, taskResult.get))
case None => taskResult
}
}
}
sc.runJob(this, reducePartition, mergeResult)
// Get the final result out of our Option, or throw an exception if the RDD was empty
jobResult.getOrElse(throw new UnsupportedOperationException("empty collection"))
}
/**
@ -408,9 +410,13 @@ abstract class RDD[T: ClassManifest](
* modify t2.
*/
def fold(zeroValue: T)(op: (T, T) => T): T = {
// Clone the zero value since we will also be serializing it as part of tasks
var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
val cleanOp = sc.clean(op)
val results = sc.runJob(this, (iter: Iterator[T]) => iter.fold(zeroValue)(cleanOp))
return results.fold(zeroValue)(cleanOp)
val foldPartition = (iter: Iterator[T]) => iter.fold(zeroValue)(cleanOp)
val mergeResult = (index: Int, taskResult: T) => jobResult = op(jobResult, taskResult)
sc.runJob(this, foldPartition, mergeResult)
jobResult
}
/**
@ -422,11 +428,14 @@ abstract class RDD[T: ClassManifest](
* allocation.
*/
def aggregate[U: ClassManifest](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = {
// Clone the zero value since we will also be serializing it as part of tasks
var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
val cleanSeqOp = sc.clean(seqOp)
val cleanCombOp = sc.clean(combOp)
val results = sc.runJob(this,
(iter: Iterator[T]) => iter.aggregate(zeroValue)(cleanSeqOp, cleanCombOp))
return results.fold(zeroValue)(cleanCombOp)
val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
val mergeResult = (index: Int, taskResult: U) => jobResult = combOp(jobResult, taskResult)
sc.runJob(this, aggregatePartition, mergeResult)
jobResult
}
/**
@ -437,7 +446,7 @@ abstract class RDD[T: ClassManifest](
var result = 0L
while (iter.hasNext) {
result += 1L
iter.next
iter.next()
}
result
}).sum
@ -452,7 +461,7 @@ abstract class RDD[T: ClassManifest](
var result = 0L
while (iter.hasNext) {
result += 1L
iter.next
iter.next()
}
result
}
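The rewrites of reduce, fold and aggregate above stop collecting an Array of per-partition results; each task's result is now merged into a single driver-side value through the resultHandler form of runJob added in SparkContext below. A minimal, hedged usage sketch of the unchanged caller-facing behaviour, assuming an existing SparkContext sc:

val nums = sc.parallelize(1 to 100, 4)

// reduce: reduceLeft inside each task, partition results merged on the driver
val total = nums.reduce(_ + _)                    // 5050

// aggregate: seqOp runs inside each task, combOp folds partition results in
// one at a time as they arrive (no intermediate Array[U] any more)
val (sum, count) = nums.aggregate((0, 0))(
  (acc, x) => (acc._1 + x, acc._2 + 1),
  (a, b) => (a._1 + b._1, a._2 + b._2))

// An empty RDD still throws UnsupportedOperationException("empty collection")
// from reduce, now via jobResult.getOrElse(...) rather than checking results.size.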

View file

@ -46,6 +46,7 @@ import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, C
import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
import storage.BlockManagerUI
import util.{MetadataCleaner, TimeStampedHashMap}
import storage.{StorageStatus, StorageUtils, RDDInfo}
/**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
@ -107,8 +108,9 @@ class SparkContext(
// Environment variables to pass to our executors
private[spark] val executorEnvs = HashMap[String, String]()
// Note: SPARK_MEM is included for Mesos, but overwritten for standalone mode in ExecutorRunner
for (key <- Seq("SPARK_MEM", "SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS",
"SPARK_TESTING")) {
"SPARK_TESTING")) {
val value = System.getenv(key)
if (value != null) {
executorEnvs(key) = value
@ -187,6 +189,7 @@ class SparkContext(
taskScheduler.start()
private var dagScheduler = new DAGScheduler(taskScheduler)
dagScheduler.start()
/** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
val hadoopConfiguration = {
@ -467,12 +470,27 @@ class SparkContext(
* Return a map from the slave to the max memory available for caching and the remaining
* memory available for caching.
*/
def getSlavesMemoryStatus: Map[String, (Long, Long)] = {
def getExecutorMemoryStatus: Map[String, (Long, Long)] = {
env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) =>
(blockManagerId.ip + ":" + blockManagerId.port, mem)
}
}
/**
* Return information about what RDDs are cached, if they are in mem or on disk, how much space
* they take, etc.
*/
def getRDDStorageInfo : Array[RDDInfo] = {
StorageUtils.rddInfoFromStorageStatus(getExecutorStorageStatus, this)
}
/**
* Return information about blocks stored in all of the slaves
*/
def getExecutorStorageStatus : Array[StorageStatus] = {
env.blockManager.master.getStorageStatus
}
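A hedged usage sketch of the methods above (getSlavesMemoryStatus renamed to getExecutorMemoryStatus, plus the two new storage-inspection calls), assuming a live SparkContext sc; BlockManagerUI later in this diff switches to these instead of querying the master actor itself:

for ((hostPort, (maxMem, remaining)) <- sc.getExecutorMemoryStatus) {
  println(hostPort + ": " + remaining + " of " + maxMem + " bytes free for caching")
}
sc.getRDDStorageInfo.foreach(info => println(info))   // formatted by RDDInfo.toString below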
/**
* Clear the job's list of files added by `addFile` so that they do not get downloaded to
* any new nodes.
@ -543,10 +561,30 @@ class SparkContext(
}
/**
* Run a function on a given set of partitions in an RDD and return the results. This is the main
* entry point to the scheduler, by which all actions get launched. The allowLocal flag specifies
* whether the scheduler can run the computation on the driver rather than shipping it out to the
* cluster, for short actions like first().
* Run a function on a given set of partitions in an RDD and pass the results to the given
* handler function. This is the main entry point for all actions in Spark. The allowLocal
* flag specifies whether the scheduler can run the computation on the driver rather than
* shipping it out to the cluster, for short actions like first().
*/
def runJob[T, U: ClassManifest](
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
allowLocal: Boolean,
resultHandler: (Int, U) => Unit) {
val callSite = Utils.getSparkCallSite
logInfo("Starting job: " + callSite)
val start = System.nanoTime
val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler)
logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
rdd.doCheckpoint()
result
}
/**
* Run a function on a given set of partitions in an RDD and return the results as an array. The
* allowLocal flag specifies whether the scheduler can run the computation on the driver rather
* than shipping it out to the cluster, for short actions like first().
*/
def runJob[T, U: ClassManifest](
rdd: RDD[T],
@ -554,13 +592,9 @@ class SparkContext(
partitions: Seq[Int],
allowLocal: Boolean
): Array[U] = {
val callSite = Utils.getSparkCallSite
logInfo("Starting job: " + callSite)
val start = System.nanoTime
val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal)
logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
rdd.doCheckpoint()
result
val results = new Array[U](partitions.size)
runJob[T, U](rdd, func, partitions, allowLocal, (index, res) => results(index) = res)
results
}
/**
@ -590,6 +624,29 @@ class SparkContext(
runJob(rdd, func, 0 until rdd.splits.size, false)
}
/**
* Run a job on all partitions in an RDD and pass the results to a handler function.
*/
def runJob[T, U: ClassManifest](
rdd: RDD[T],
processPartition: (TaskContext, Iterator[T]) => U,
resultHandler: (Int, U) => Unit)
{
runJob[T, U](rdd, processPartition, 0 until rdd.splits.size, false, resultHandler)
}
/**
* Run a job on all partitions in an RDD and pass the results to a handler function.
*/
def runJob[T, U: ClassManifest](
rdd: RDD[T],
processPartition: Iterator[T] => U,
resultHandler: (Int, U) => Unit)
{
val processFunc = (context: TaskContext, iter: Iterator[T]) => processPartition(iter)
runJob[T, U](rdd, processFunc, 0 until rdd.splits.size, false, resultHandler)
}
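A hedged sketch of the new handler-based entry point, using a hypothetical RDD[String] named lines: instead of materializing an Array[U] of partition results, the caller receives one callback per finished partition, which is exactly what reduce, fold and aggregate now build on.

var totalChars = 0L
sc.runJob(lines,
  (iter: Iterator[String]) => iter.map(_.length.toLong).sum,   // per-partition work
  (index: Int, partial: Long) => totalChars += partial)        // driver-side merge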
/**
* Run a job that can return approximate results.
*/

View file

@ -12,6 +12,7 @@ import scala.io.Source
import com.google.common.io.Files
import com.google.common.util.concurrent.ThreadFactoryBuilder
import scala.Some
import spark.serializer.SerializerInstance
/**
* Various utility methods used by Spark.
@ -446,4 +447,11 @@ private object Utils extends Logging {
socket.close()
portBound
}
/**
* Clone an object using a Spark serializer.
*/
def clone[T](value: T, serializer: SerializerInstance): T = {
serializer.deserialize[T](serializer.serialize(value))
}
}
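Utils.clone is what fold and aggregate above now use to copy their zero value before mutating it on the driver, since the same zero object is also serialized into tasks. A minimal sketch mirroring that call (sc.env is spark-internal, so this assumes code inside the spark package):

val zero = scala.collection.mutable.ArrayBuffer[Int]()
val copy = Utils.clone(zero, sc.env.closureSerializer.newInstance())
copy += 1   // mutating the copy leaves `zero` untouched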

View file

@ -18,35 +18,23 @@ import scala.collection.mutable.ArrayBuffer
private[spark]
class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) extends Logging {
val localIpAddress = Utils.localIpAddress
private val localIpAddress = Utils.localIpAddress
private val masterActorSystems = ArrayBuffer[ActorSystem]()
private val workerActorSystems = ArrayBuffer[ActorSystem]()
var masterActor : ActorRef = _
var masterActorSystem : ActorSystem = _
var masterPort : Int = _
var masterUrl : String = _
val workerActorSystems = ArrayBuffer[ActorSystem]()
val workerActors = ArrayBuffer[ActorRef]()
def start() : String = {
def start(): String = {
logInfo("Starting a local Spark cluster with " + numWorkers + " workers.")
/* Start the Master */
val (actorSystem, masterPort) = AkkaUtils.createActorSystem("sparkMaster", localIpAddress, 0)
masterActorSystem = actorSystem
masterUrl = "spark://" + localIpAddress + ":" + masterPort
masterActor = masterActorSystem.actorOf(
Props(new Master(localIpAddress, masterPort, 0)), name = "Master")
val (masterSystem, masterPort) = Master.startSystemAndActor(localIpAddress, 0, 0)
masterActorSystems += masterSystem
val masterUrl = "spark://" + localIpAddress + ":" + masterPort
/* Start the Slaves */
/* Start the Workers */
for (workerNum <- 1 to numWorkers) {
val (actorSystem, boundPort) =
AkkaUtils.createActorSystem("sparkWorker" + workerNum, localIpAddress, 0)
workerActorSystems += actorSystem
val actor = actorSystem.actorOf(
Props(new Worker(localIpAddress, boundPort, 0, coresPerWorker, memoryPerWorker, masterUrl)),
name = "Worker")
workerActors += actor
val (workerSystem, _) = Worker.startSystemAndActor(localIpAddress, 0, 0, coresPerWorker,
memoryPerWorker, masterUrl, null, Some(workerNum))
workerActorSystems += workerSystem
}
return masterUrl
@ -57,7 +45,7 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I
// Stop the workers before the master so they don't get upset that it disconnected
workerActorSystems.foreach(_.shutdown())
workerActorSystems.foreach(_.awaitTermination())
masterActorSystem.shutdown()
masterActorSystem.awaitTermination()
masterActorSystems.foreach(_.shutdown())
masterActorSystems.foreach(_.awaitTermination())
}
}
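LocalSparkCluster now just delegates to Master.startSystemAndActor and Worker.startSystemAndActor and keeps the resulting actor systems, so local-cluster mode exercises the same startup path as a real deployment. A hedged sketch of how it is driven:

val cluster = new LocalSparkCluster(numWorkers = 2, coresPerWorker = 1, memoryPerWorker = 512)
val masterUrl = cluster.start()   // e.g. "spark://<local-ip>:<port>"
// ... point a SparkContext or deploy client at masterUrl ...
cluster.stop()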

View file

@ -9,6 +9,7 @@ import spark.{SparkException, Logging}
import akka.remote.RemoteClientLifeCycleEvent
import akka.remote.RemoteClientShutdown
import spark.deploy.RegisterJob
import spark.deploy.master.Master
import akka.remote.RemoteClientDisconnected
import akka.actor.Terminated
import akka.dispatch.Await
@ -24,26 +25,18 @@ private[spark] class Client(
listener: ClientListener)
extends Logging {
val MASTER_REGEX = "spark://([^:]+):([0-9]+)".r
var actor: ActorRef = null
var jobId: String = null
if (MASTER_REGEX.unapplySeq(masterUrl) == None) {
throw new SparkException("Invalid master URL: " + masterUrl)
}
class ClientActor extends Actor with Logging {
var master: ActorRef = null
var masterAddress: Address = null
var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times
override def preStart() {
val Seq(masterHost, masterPort) = MASTER_REGEX.unapplySeq(masterUrl).get
logInfo("Connecting to master spark://" + masterHost + ":" + masterPort)
val akkaUrl = "akka://spark@%s:%s/user/Master".format(masterHost, masterPort)
logInfo("Connecting to master " + masterUrl)
try {
master = context.actorFor(akkaUrl)
master = context.actorFor(Master.toAkkaUrl(masterUrl))
masterAddress = master.path.address
master ! RegisterJob(jobDescription)
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])

View file

@ -262,11 +262,29 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
}
private[spark] object Master {
private val systemName = "sparkMaster"
private val actorName = "Master"
private val sparkUrlRegex = "spark://([^:]+):([0-9]+)".r
def main(argStrings: Array[String]) {
val args = new MasterArguments(argStrings)
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", args.ip, args.port)
val actor = actorSystem.actorOf(
Props(new Master(args.ip, boundPort, args.webUiPort)), name = "Master")
val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort)
actorSystem.awaitTermination()
}
/** Returns an `akka://...` URL for the Master actor given a sparkUrl `spark://host:ip`. */
def toAkkaUrl(sparkUrl: String): String = {
sparkUrl match {
case sparkUrlRegex(host, port) =>
"akka://%s@%s:%s/user/%s".format(systemName, host, port, actorName)
case _ =>
throw new SparkException("Invalid master URL: " + sparkUrl)
}
}
def startSystemAndActor(host: String, port: Int, webUiPort: Int): (ActorSystem, Int) = {
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port)
val actor = actorSystem.actorOf(Props(new Master(host, boundPort, webUiPort)), name = actorName)
(actorSystem, boundPort)
}
}
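toAkkaUrl centralizes the spark:// URL parsing that Client and Worker previously duplicated. Illustrative behaviour, given the systemName and actorName constants defined above:

Master.toAkkaUrl("spark://host1:7077")
// => "akka://sparkMaster@host1:7077/user/Master"
Master.toAkkaUrl("host1:7077")
// => throws SparkException("Invalid master URL: host1:7077")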

View file

@ -45,13 +45,9 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct
case (jobId, Some(js)) if (js.equalsIgnoreCase("json")) =>
val future = master ? RequestMasterState
val jobInfo = for (masterState <- future.mapTo[MasterState]) yield {
masterState.activeJobs.find(_.id == jobId) match {
case Some(job) => job
case _ => masterState.completedJobs.find(_.id == jobId) match {
case Some(job) => job
case _ => null
}
}
masterState.activeJobs.find(_.id == jobId).getOrElse({
masterState.completedJobs.find(_.id == jobId).getOrElse(null)
})
}
respondWithMediaType(MediaTypes.`application/json`) { ctx =>
ctx.complete(jobInfo.mapTo[JobInfo])
@ -61,14 +57,10 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct
val future = master ? RequestMasterState
future.map { state =>
val masterState = state.asInstanceOf[MasterState]
masterState.activeJobs.find(_.id == jobId) match {
case Some(job) => spark.deploy.master.html.job_details.render(job)
case _ => masterState.completedJobs.find(_.id == jobId) match {
case Some(job) => spark.deploy.master.html.job_details.render(job)
case _ => null
}
}
val job = masterState.activeJobs.find(_.id == jobId).getOrElse({
masterState.completedJobs.find(_.id == jobId).getOrElse(null)
})
spark.deploy.master.html.job_details.render(job)
}
}
}

View file

@ -113,8 +113,7 @@ private[spark] class ExecutorRunner(
for ((key, value) <- jobDesc.command.environment) {
env.put(key, value)
}
env.put("SPARK_CORES", cores.toString)
env.put("SPARK_MEMORY", memory.toString)
env.put("SPARK_MEM", memory.toString + "m")
// In case we are running this from within the Spark Shell, avoid creating a "scala"
// parent process for the executor command
env.put("SPARK_LAUNCH_WITH_SCALA", "0")

View file

@ -1,7 +1,7 @@
package spark.deploy.worker
import scala.collection.mutable.{ArrayBuffer, HashMap}
import akka.actor.{ActorRef, Props, Actor}
import akka.actor.{ActorRef, Props, Actor, ActorSystem}
import spark.{Logging, Utils}
import spark.util.AkkaUtils
import spark.deploy._
@ -13,6 +13,7 @@ import akka.remote.RemoteClientDisconnected
import spark.deploy.RegisterWorker
import spark.deploy.LaunchExecutor
import spark.deploy.RegisterWorkerFailed
import spark.deploy.master.Master
import akka.actor.Terminated
import java.io.File
@ -27,7 +28,6 @@ private[spark] class Worker(
extends Actor with Logging {
val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs
val MASTER_REGEX = "spark://([^:]+):([0-9]+)".r
var master: ActorRef = null
var masterWebUiUrl : String = ""
@ -48,11 +48,7 @@ private[spark] class Worker(
def memoryFree: Int = memory - memoryUsed
def createWorkDir() {
workDir = if (workDirPath != null) {
new File(workDirPath)
} else {
new File(sparkHome, "work")
}
workDir = Option(workDirPath).map(new File(_)).getOrElse(new File(sparkHome, "work"))
try {
if (!workDir.exists() && !workDir.mkdirs()) {
logError("Failed to create work directory " + workDir)
@ -68,8 +64,7 @@ private[spark] class Worker(
override def preStart() {
logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format(
ip, port, cores, Utils.memoryMegabytesToString(memory)))
val envVar = System.getenv("SPARK_HOME")
sparkHome = new File(if (envVar == null) "." else envVar)
sparkHome = new File(Option(System.getenv("SPARK_HOME")).getOrElse("."))
logInfo("Spark home: " + sparkHome)
createWorkDir()
connectToMaster()
@ -77,24 +72,15 @@ private[spark] class Worker(
}
def connectToMaster() {
masterUrl match {
case MASTER_REGEX(masterHost, masterPort) => {
logInfo("Connecting to master spark://" + masterHost + ":" + masterPort)
val akkaUrl = "akka://spark@%s:%s/user/Master".format(masterHost, masterPort)
try {
master = context.actorFor(akkaUrl)
master ! RegisterWorker(workerId, ip, port, cores, memory, webUiPort, publicAddress)
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
context.watch(master) // Doesn't work with remote actors, but useful for testing
} catch {
case e: Exception =>
logError("Failed to connect to master", e)
System.exit(1)
}
}
case _ =>
logError("Invalid master URL: " + masterUrl)
logInfo("Connecting to master " + masterUrl)
try {
master = context.actorFor(Master.toAkkaUrl(masterUrl))
master ! RegisterWorker(workerId, ip, port, cores, memory, webUiPort, publicAddress)
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
context.watch(master) // Doesn't work with remote actors, but useful for testing
} catch {
case e: Exception =>
logError("Failed to connect to master", e)
System.exit(1)
}
}
@ -183,11 +169,19 @@ private[spark] class Worker(
private[spark] object Worker {
def main(argStrings: Array[String]) {
val args = new WorkerArguments(argStrings)
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", args.ip, args.port)
val actor = actorSystem.actorOf(
Props(new Worker(args.ip, boundPort, args.webUiPort, args.cores, args.memory,
args.master, args.workDir)),
name = "Worker")
val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort, args.cores,
args.memory, args.master, args.workDir)
actorSystem.awaitTermination()
}
def startSystemAndActor(host: String, port: Int, webUiPort: Int, cores: Int, memory: Int,
masterUrl: String, workDir: String, workerNumber: Option[Int] = None): (ActorSystem, Int) = {
// The LocalSparkCluster runs multiple local sparkWorkerX actor systems
val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("")
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port)
val actor = actorSystem.actorOf(Props(new Worker(host, boundPort, webUiPort, cores, memory,
masterUrl, workDir)), name = "Worker")
(actorSystem, boundPort)
}
}

View file

@ -32,7 +32,7 @@ private[spark] class ApproximateActionListener[T, U, R](
if (finishedTasks == totalTasks) {
// If we had already returned a PartialResult, set its final value
resultObject.foreach(r => r.setFinalValue(evaluator.currentResult()))
// Notify any waiting thread that may have called getResult
// Notify any waiting thread that may have called awaitResult
this.notifyAll()
}
}
@ -49,7 +49,7 @@ private[spark] class ApproximateActionListener[T, U, R](
* Waits for up to timeout milliseconds since the listener was created and then returns a
* PartialResult with the result so far. This may be complete if the whole job is done.
*/
def getResult(): PartialResult[R] = synchronized {
def awaitResult(): PartialResult[R] = synchronized {
val finishTime = startTime + timeout
while (true) {
val time = System.currentTimeMillis()

View file

@ -1,24 +1,42 @@
package spark.rdd
import spark.{PruneDependency, RDD, SparkEnv, Split, TaskContext}
import spark.{NarrowDependency, RDD, SparkEnv, Split, TaskContext}
class PartitionPruningRDDSplit(idx: Int, val parentSplit: Split) extends Split {
override val index = idx
}
/**
* Represents a dependency between the PartitionPruningRDD and its parent. In this
* case, the child RDD contains a subset of the parent's partitions.
*/
class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boolean)
extends NarrowDependency[T](rdd) {
@transient
val partitions: Array[Split] = rdd.splits.filter(s => partitionFilterFunc(s.index))
.zipWithIndex.map { case(split, idx) => new PartitionPruningRDDSplit(idx, split) : Split }
override def getParents(partitionId: Int) = List(partitions(partitionId).index)
}
/**
* An RDD used to prune RDD partitions/splits so we can avoid launching tasks on
* all partitions. An example use case: If we know the RDD is partitioned by range,
* and the execution DAG has a filter on the key, we can avoid launching tasks
* on partitions that don't have the range covering the key.
*
* TODO: This currently doesn't give partition IDs properly!
*/
class PartitionPruningRDD[T: ClassManifest](
@transient prev: RDD[T],
@transient partitionFilterFunc: Int => Boolean)
extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) {
override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(split, context)
override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(
split.asInstanceOf[PartitionPruningRDDSplit].parentSplit, context)
override protected def getSplits =
getDependencies.head.asInstanceOf[PruneDependency[T]].partitions
override val partitioner = firstParent[T].partitioner
}
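The fix above is that compute now passes the parent's own split to firstParent.iterator rather than the re-indexed PartitionPruningRDDSplit, so each retained partition reads the correct parent data. A hedged usage sketch, with rangeRdd standing in for some range-partitioned RDD:

// Keep only partitions 5..9; tasks are launched for just those five splits.
val pruned = new PartitionPruningRDD(rangeRdd, idx => idx >= 5 && idx < 10)
pruned.count()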

View file

@ -23,7 +23,16 @@ import util.{MetadataCleaner, TimeStampedHashMap}
* and to report fetch failures (the submitTasks method, and code to add CompletionEvents).
*/
private[spark]
class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with Logging {
class DAGScheduler(
taskSched: TaskScheduler,
mapOutputTracker: MapOutputTracker,
blockManagerMaster: BlockManagerMaster,
env: SparkEnv)
extends TaskSchedulerListener with Logging {
def this(taskSched: TaskScheduler) {
this(taskSched, SparkEnv.get.mapOutputTracker, SparkEnv.get.blockManager.master, SparkEnv.get)
}
taskSched.setListener(this)
// Called by TaskScheduler to report task completions or failures.
@ -66,10 +75,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
var cacheLocs = new HashMap[Int, Array[List[String]]]
val env = SparkEnv.get
val mapOutputTracker = env.mapOutputTracker
val blockManagerMaster = env.blockManager.master
// For tracking failed nodes, we use the MapOutputTracker's generation number, which is
// sent with every task. When we detect a node failing, we note the current generation number
// and failed executor, increment it for new tasks, and use this to ignore stray ShuffleMapTask
@ -90,14 +95,16 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup)
// Start a thread to run the DAGScheduler event loop
new Thread("DAGScheduler") {
setDaemon(true)
override def run() {
DAGScheduler.this.run()
}
}.start()
def start() {
new Thread("DAGScheduler") {
setDaemon(true)
override def run() {
DAGScheduler.this.run()
}
}.start()
}
def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
private def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
if (!cacheLocs.contains(rdd.id)) {
val blockIds = rdd.splits.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray
cacheLocs(rdd.id) = blockManagerMaster.getLocations(blockIds).map {
@ -107,7 +114,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
cacheLocs(rdd.id)
}
def clearCacheLocs() {
private def clearCacheLocs() {
cacheLocs.clear()
}
@ -116,7 +123,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* The priority value passed in will be used if the stage doesn't already exist with
* a lower priority (we assume that priorities always increase across jobs for now).
*/
def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], priority: Int): Stage = {
private def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], priority: Int): Stage = {
shuffleToMapStage.get(shuffleDep.shuffleId) match {
case Some(stage) => stage
case None =>
@ -131,11 +138,11 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* as a result stage for the final RDD used directly in an action. The stage will also be given
* the provided priority.
*/
def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_]], priority: Int): Stage = {
// Kind of ugly: need to register RDDs with the cache and map output tracker here
// since we can't do it in the RDD constructor because # of splits is unknown
logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")")
private def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_]], priority: Int): Stage = {
if (shuffleDep != None) {
// Kind of ugly: need to register RDDs with the cache and map output tracker here
// since we can't do it in the RDD constructor because # of splits is unknown
logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")")
mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size)
}
val id = nextStageId.getAndIncrement()
@ -148,7 +155,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* Get or create the list of parent stages for a given RDD. The stages will be assigned the
* provided priority if they haven't already been created with a lower priority.
*/
def getParentStages(rdd: RDD[_], priority: Int): List[Stage] = {
private def getParentStages(rdd: RDD[_], priority: Int): List[Stage] = {
val parents = new HashSet[Stage]
val visited = new HashSet[RDD[_]]
def visit(r: RDD[_]) {
@ -170,25 +177,22 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
parents.toList
}
def getMissingParentStages(stage: Stage): List[Stage] = {
private def getMissingParentStages(stage: Stage): List[Stage] = {
val missing = new HashSet[Stage]
val visited = new HashSet[RDD[_]]
def visit(rdd: RDD[_]) {
if (!visited(rdd)) {
visited += rdd
val locs = getCacheLocs(rdd)
for (p <- 0 until rdd.splits.size) {
if (locs(p) == Nil) {
for (dep <- rdd.dependencies) {
dep match {
case shufDep: ShuffleDependency[_,_] =>
val mapStage = getShuffleMapStage(shufDep, stage.priority)
if (!mapStage.isAvailable) {
missing += mapStage
}
case narrowDep: NarrowDependency[_] =>
visit(narrowDep.rdd)
}
if (getCacheLocs(rdd).contains(Nil)) {
for (dep <- rdd.dependencies) {
dep match {
case shufDep: ShuffleDependency[_,_] =>
val mapStage = getShuffleMapStage(shufDep, stage.priority)
if (!mapStage.isAvailable) {
missing += mapStage
}
case narrowDep: NarrowDependency[_] =>
visit(narrowDep.rdd)
}
}
}
@ -198,23 +202,45 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
missing.toList
}
/**
* Returns (and does not submit) a JobSubmitted event suitable to run a given job, and a
* JobWaiter whose getResult() method will return the result of the job when it is complete.
*
* The job is assumed to have at least one partition; zero partition jobs should be handled
* without a JobSubmitted event.
*/
private[scheduler] def prepareJob[T, U: ClassManifest](
finalRdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
callSite: String,
allowLocal: Boolean,
resultHandler: (Int, U) => Unit)
: (JobSubmitted, JobWaiter[U]) =
{
assert(partitions.size > 0)
val waiter = new JobWaiter(partitions.size, resultHandler)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter)
return (toSubmit, waiter)
}
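A test-style sketch of why the constructor and these methods were opened up: with injected dependencies (taskScheduler, tracker and bmMaster here are assumed test doubles) a suite in the spark.scheduler package can build events with prepareJob and drive processEvent synchronously, bypassing the background event-loop thread.

val scheduler = new DAGScheduler(taskScheduler, tracker, bmMaster, SparkEnv.get)
val results = new Array[Long](1)
val (event, waiter) = scheduler.prepareJob(
  rdd,                                                     // hypothetical RDD[Int]
  (ctx: TaskContext, it: Iterator[Int]) => it.size.toLong, // per-partition function
  Seq(0), "test", false,
  (index: Int, res: Long) => results(index) = res)
scheduler.processEvent(event)   // returns true only for StopDAGScheduler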
def runJob[T, U: ClassManifest](
finalRdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
callSite: String,
allowLocal: Boolean)
: Array[U] =
allowLocal: Boolean,
resultHandler: (Int, U) => Unit)
{
if (partitions.size == 0) {
return new Array[U](0)
return
}
val waiter = new JobWaiter(partitions.size)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
eventQueue.put(JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter))
waiter.getResult() match {
case JobSucceeded(results: Seq[_]) =>
return results.asInstanceOf[Seq[U]].toArray
val (toSubmit, waiter) = prepareJob(
finalRdd, func, partitions, callSite, allowLocal, resultHandler)
eventQueue.put(toSubmit)
waiter.awaitResult() match {
case JobSucceeded => {}
case JobFailed(exception: Exception) =>
logInfo("Failed to run " + callSite)
throw exception
@ -233,90 +259,117 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val partitions = (0 until rdd.splits.size).toArray
eventQueue.put(JobSubmitted(rdd, func2, partitions, false, callSite, listener))
return listener.getResult() // Will throw an exception if the job fails
return listener.awaitResult() // Will throw an exception if the job fails
}
/**
* Process one event retrieved from the event queue.
* Returns true if we should stop the event loop.
*/
private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = {
event match {
case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener) =>
val runId = nextRunId.getAndIncrement()
val finalStage = newStage(finalRDD, None, runId)
val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener)
clearCacheLocs()
logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length +
" output partitions (allowLocal=" + allowLocal + ")")
logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")")
logInfo("Parents of final stage: " + finalStage.parents)
logInfo("Missing parents: " + getMissingParentStages(finalStage))
if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) {
// Compute very short actions like first() or take() with no parent stages locally.
runLocally(job)
} else {
activeJobs += job
resultStageToJob(finalStage) = job
submitStage(finalStage)
}
case ExecutorLost(execId) =>
handleExecutorLost(execId)
case completion: CompletionEvent =>
handleTaskCompletion(completion)
case TaskSetFailed(taskSet, reason) =>
abortStage(idToStage(taskSet.stageId), reason)
case StopDAGScheduler =>
// Cancel any active jobs
for (job <- activeJobs) {
val error = new SparkException("Job cancelled because SparkContext was shut down")
job.listener.jobFailed(error)
}
return true
}
return false
}
/**
* Resubmit any failed stages. Ordinarily called after a small amount of time has passed since
* the last fetch failure.
*/
private[scheduler] def resubmitFailedStages() {
logInfo("Resubmitting failed stages")
clearCacheLocs()
val failed2 = failed.toArray
failed.clear()
for (stage <- failed2.sortBy(_.priority)) {
submitStage(stage)
}
}
/**
* Check for waiting or failed stages which are now eligible for resubmission.
* Ordinarily run on every iteration of the event loop.
*/
private[scheduler] def submitWaitingStages() {
// TODO: We might want to run this less often, when we are sure that something has become
// runnable that wasn't before.
logTrace("Checking for newly runnable parent stages")
logTrace("running: " + running)
logTrace("waiting: " + waiting)
logTrace("failed: " + failed)
val waiting2 = waiting.toArray
waiting.clear()
for (stage <- waiting2.sortBy(_.priority)) {
submitStage(stage)
}
}
/**
* The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure
* events and responds by launching tasks. This runs in a dedicated thread and receives events
* via the eventQueue.
*/
def run() {
private def run() {
SparkEnv.set(env)
while (true) {
val event = eventQueue.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS)
val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability
if (event != null) {
logDebug("Got event of type " + event.getClass.getName)
}
event match {
case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener) =>
val runId = nextRunId.getAndIncrement()
val finalStage = newStage(finalRDD, None, runId)
val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener)
clearCacheLocs()
logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length +
" output partitions")
logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")")
logInfo("Parents of final stage: " + finalStage.parents)
logInfo("Missing parents: " + getMissingParentStages(finalStage))
if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) {
// Compute very short actions like first() or take() with no parent stages locally.
runLocally(job)
} else {
activeJobs += job
resultStageToJob(finalStage) = job
submitStage(finalStage)
}
case ExecutorLost(execId) =>
handleExecutorLost(execId)
case completion: CompletionEvent =>
handleTaskCompletion(completion)
case TaskSetFailed(taskSet, reason) =>
abortStage(idToStage(taskSet.stageId), reason)
case StopDAGScheduler =>
// Cancel any active jobs
for (job <- activeJobs) {
val error = new SparkException("Job cancelled because SparkContext was shut down")
job.listener.jobFailed(error)
}
if (event != null) {
if (processEvent(event)) {
return
case null =>
// queue.poll() timed out, ignore it
}
}
val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability
// Periodically resubmit failed stages if some map output fetches have failed and we have
// waited at least RESUBMIT_TIMEOUT. We wait for this short time because when a node fails,
// tasks on many other nodes are bound to get a fetch failure, and they won't all get it at
// the same time, so we want to make sure we've identified all the reduce tasks that depend
// on the failed node.
if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) {
logInfo("Resubmitting failed stages")
clearCacheLocs()
val failed2 = failed.toArray
failed.clear()
for (stage <- failed2.sortBy(_.priority)) {
submitStage(stage)
}
resubmitFailedStages()
} else {
// TODO: We might want to run this less often, when we are sure that something has become
// runnable that wasn't before.
logTrace("Checking for newly runnable parent stages")
logTrace("running: " + running)
logTrace("waiting: " + waiting)
logTrace("failed: " + failed)
val waiting2 = waiting.toArray
waiting.clear()
for (stage <- waiting2.sortBy(_.priority)) {
submitStage(stage)
}
submitWaitingStages()
}
}
}
@ -326,7 +379,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* We run the operation in a separate thread just in case it takes a bunch of time, so that we
* don't block the DAGScheduler event loop or other concurrent jobs.
*/
def runLocally(job: ActiveJob) {
private def runLocally(job: ActiveJob) {
logInfo("Computing the requested partition locally")
new Thread("Local computation of job " + job.runId) {
override def run() {
@ -349,13 +402,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
}.start()
}
def submitStage(stage: Stage) {
/** Submits stage, but first recursively submits any missing parents. */
private def submitStage(stage: Stage) {
logDebug("submitStage(" + stage + ")")
if (!waiting(stage) && !running(stage) && !failed(stage)) {
val missing = getMissingParentStages(stage).sortBy(_.id)
logDebug("missing: " + missing)
if (missing == Nil) {
logInfo("Submitting " + stage + " (" + stage.origin + "), which has no missing parents")
logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
submitMissingTasks(stage)
running += stage
} else {
@ -367,7 +421,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
}
}
def submitMissingTasks(stage: Stage) {
/** Called when stage's parents are available and we can now do its task. */
private def submitMissingTasks(stage: Stage) {
logDebug("submitMissingTasks(" + stage + ")")
// Get our pending tasks and remember them in our pendingTasks entry
val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
@ -388,7 +443,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
}
}
if (tasks.size > 0) {
logInfo("Submitting " + tasks.size + " missing tasks from " + stage)
logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
myPending ++= tasks
logDebug("New pending tasks: " + myPending)
taskSched.submitTasks(
@ -407,7 +462,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* Responds to a task finishing. This is called inside the event loop so it assumes that it can
* modify the scheduler's internal state. Use taskEnded() to post a task end event from outside.
*/
def handleTaskCompletion(event: CompletionEvent) {
private def handleTaskCompletion(event: CompletionEvent) {
val task = event.task
val stage = idToStage(task.stageId)
@ -492,7 +547,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
waiting --= newlyRunnable
running ++= newlyRunnable
for (stage <- newlyRunnable.sortBy(_.id)) {
logInfo("Submitting " + stage + " (" + stage.origin + "), which is now runnable")
logInfo("Submitting " + stage + " (" + stage.rdd + "), which is now runnable")
submitMissingTasks(stage)
}
}
@ -541,12 +596,12 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* Optionally the generation during which the failure was caught can be passed to avoid allowing
* stray fetch failures from possibly retriggering the detection of a node as lost.
*/
def handleExecutorLost(execId: String, maybeGeneration: Option[Long] = None) {
private def handleExecutorLost(execId: String, maybeGeneration: Option[Long] = None) {
val currentGeneration = maybeGeneration.getOrElse(mapOutputTracker.getGeneration)
if (!failedGeneration.contains(execId) || failedGeneration(execId) < currentGeneration) {
failedGeneration(execId) = currentGeneration
logInfo("Executor lost: %s (generation %d)".format(execId, currentGeneration))
env.blockManager.master.removeExecutor(execId)
blockManagerMaster.removeExecutor(execId)
// TODO: This will be really slow if we keep accumulating shuffle map stages
for ((shuffleId, stage) <- shuffleToMapStage) {
stage.removeOutputsOnExecutor(execId)
@ -567,7 +622,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* Aborts all jobs depending on a particular Stage. This is called in response to a task set
* being cancelled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
*/
def abortStage(failedStage: Stage, reason: String) {
private def abortStage(failedStage: Stage, reason: String) {
val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
for (resultStage <- dependentStages) {
val job = resultStageToJob(resultStage)
@ -583,7 +638,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
/**
* Return true if one of stage's ancestors is target.
*/
def stageDependsOn(stage: Stage, target: Stage): Boolean = {
private def stageDependsOn(stage: Stage, target: Stage): Boolean = {
if (stage == target) {
return true
}
@ -610,7 +665,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
visitedRdds.contains(target.rdd)
}
def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = {
private def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = {
// If the partition is cached, return the cache locations
val cached = getCacheLocs(rdd)(partition)
if (cached != Nil) {
@ -636,7 +691,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
return Nil
}
def cleanup(cleanupTime: Long) {
private def cleanup(cleanupTime: Long) {
var sizeBefore = idToStage.size
idToStage.clearOldValues(cleanupTime)
logInfo("idToStage " + sizeBefore + " --> " + idToStage.size)

View file

@ -5,5 +5,5 @@ package spark.scheduler
*/
private[spark] sealed trait JobResult
private[spark] case class JobSucceeded(results: Seq[_]) extends JobResult
private[spark] case object JobSucceeded extends JobResult
private[spark] case class JobFailed(exception: Exception) extends JobResult

View file

@ -3,10 +3,12 @@ package spark.scheduler
import scala.collection.mutable.ArrayBuffer
/**
* An object that waits for a DAGScheduler job to complete.
* An object that waits for a DAGScheduler job to complete. As tasks finish, it passes their
* results to the given handler function.
*/
private[spark] class JobWaiter(totalTasks: Int) extends JobListener {
private val taskResults = ArrayBuffer.fill[Any](totalTasks)(null)
private[spark] class JobWaiter[T](totalTasks: Int, resultHandler: (Int, T) => Unit)
extends JobListener {
private var finishedTasks = 0
private var jobFinished = false // Is the job as a whole finished (succeeded or failed)?
@ -17,11 +19,11 @@ private[spark] class JobWaiter(totalTasks: Int) extends JobListener {
if (jobFinished) {
throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter")
}
taskResults(index) = result
resultHandler(index, result.asInstanceOf[T])
finishedTasks += 1
if (finishedTasks == totalTasks) {
jobFinished = true
jobResult = JobSucceeded(taskResults)
jobResult = JobSucceeded
this.notifyAll()
}
}
@ -38,7 +40,7 @@ private[spark] class JobWaiter(totalTasks: Int) extends JobListener {
}
}
def getResult(): JobResult = synchronized {
def awaitResult(): JobResult = synchronized {
while (!jobFinished) {
this.wait()
}

View file

@ -32,7 +32,7 @@ private[spark] object ShuffleMapTask {
return old
} else {
val out = new ByteArrayOutputStream
val ser = SparkEnv.get.closureSerializer.newInstance
val ser = SparkEnv.get.closureSerializer.newInstance()
val objOut = ser.serializeStream(new GZIPOutputStream(out))
objOut.writeObject(rdd)
objOut.writeObject(dep)
@ -48,7 +48,7 @@ private[spark] object ShuffleMapTask {
synchronized {
val loader = Thread.currentThread.getContextClassLoader
val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
val ser = SparkEnv.get.closureSerializer.newInstance
val ser = SparkEnv.get.closureSerializer.newInstance()
val objIn = ser.deserializeStream(in)
val rdd = objIn.readObject().asInstanceOf[RDD[_]]
val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]]
@ -127,7 +127,6 @@ private[spark] class ShuffleMapTask(
val bucketId = dep.partitioner.getPartition(pair._1)
buckets(bucketId) += pair
}
val bucketIterators = buckets.map(_.iterator)
val compressedSizes = new Array[Byte](numOutputSplits)
@ -135,7 +134,7 @@ private[spark] class ShuffleMapTask(
for (i <- 0 until numOutputSplits) {
val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i
// Get a Scala iterator from Java map
val iter: Iterator[(Any, Any)] = bucketIterators(i)
val iter: Iterator[(Any, Any)] = buckets(i).iterator
val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false)
compressedSizes(i) = MapOutputTracker.compressSize(size)
}

View file

@ -86,7 +86,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
}
def submitTasks(taskSet: TaskSet) {
override def submitTasks(taskSet: TaskSet) {
val tasks = taskSet.tasks
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized {

View file

@ -1,5 +1,7 @@
package spark.scheduler.cluster
import spark.Utils
/**
* A backend interface for cluster scheduling systems that allows plugging in different ones under
* ClusterScheduler. We assume a Mesos-like model where the application gets resource offers as
@ -11,5 +13,15 @@ private[spark] trait SchedulerBackend {
def reviveOffers(): Unit
def defaultParallelism(): Int
// Memory used by each executor (in megabytes)
protected val executorMemory = {
// TODO: Might need to add some extra memory for the non-heap parts of the JVM
Option(System.getProperty("spark.executor.memory"))
.orElse(Option(System.getenv("SPARK_MEM")))
.map(Utils.memoryStringToMb)
.getOrElse(512)
}
// TODO: Probably want to add a killTask too
}
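The executor memory default now lives in SchedulerBackend once, instead of being duplicated in SparkDeploySchedulerBackend and both Mesos backends (all removed below). The precedence, as a sketch: the spark.executor.memory system property wins, then the SPARK_MEM environment variable, then 512 MB.

System.setProperty("spark.executor.memory", "2g")
// => executorMemory == 2048 in any backend constructed afterwards,
//    whatever SPARK_MEM is set to.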

View file

@ -20,16 +20,6 @@ private[spark] class SparkDeploySchedulerBackend(
val maxCores = System.getProperty("spark.cores.max", Int.MaxValue.toString).toInt
// Memory used by each executor (in megabytes)
val executorMemory = {
if (System.getenv("SPARK_MEM") != null) {
Utils.memoryStringToMb(System.getenv("SPARK_MEM"))
// TODO: Might need to add some extra memory for the non-heap parts of the JVM
} else {
512
}
}
override def start() {
super.start()

View file

@ -17,10 +17,7 @@ import java.nio.ByteBuffer
/**
* Schedules the tasks within a single TaskSet in the ClusterScheduler.
*/
private[spark] class TaskSetManager(
sched: ClusterScheduler,
val taskSet: TaskSet)
extends Logging {
private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSet) extends Logging {
// Maximum time to wait to run a task in a preferred location (in ms)
val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
@ -100,7 +97,7 @@ private[spark] class TaskSetManager(
}
// Add a task to all the pending-task lists that it should be on.
def addPendingTask(index: Int) {
private def addPendingTask(index: Int) {
val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive
if (locations.size == 0) {
pendingTasksWithNoPrefs += index
@ -115,7 +112,7 @@ private[spark] class TaskSetManager(
// Return the pending tasks list for a given host, or an empty list if
// there is no map entry for that host
def getPendingTasksForHost(host: String): ArrayBuffer[Int] = {
private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = {
pendingTasksForHost.getOrElse(host, ArrayBuffer())
}
@ -123,7 +120,7 @@ private[spark] class TaskSetManager(
// Return None if the list is empty.
// This method also cleans up any tasks in the list that have already
// been launched, since we want that to happen lazily.
def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
while (!list.isEmpty) {
val index = list.last
list.trimEnd(1)
@ -137,7 +134,7 @@ private[spark] class TaskSetManager(
// Return a speculative task for a given host if any are available. The task should not have an
// attempt running on this host, in case the host is slow. In addition, if localOnly is set, the
// task must have a preference for this host (or no preferred locations at all).
def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = {
private def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = {
val hostsAlive = sched.hostsAlive
speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set
val localTask = speculatableTasks.find {
@ -162,7 +159,7 @@ private[spark] class TaskSetManager(
// Dequeue a pending task for a given node and return its index.
// If localOnly is set to false, allow non-local tasks as well.
def findTask(host: String, localOnly: Boolean): Option[Int] = {
private def findTask(host: String, localOnly: Boolean): Option[Int] = {
val localTask = findTaskFromList(getPendingTasksForHost(host))
if (localTask != None) {
return localTask
@ -184,7 +181,7 @@ private[spark] class TaskSetManager(
// Does a host count as a preferred location for a task? This is true if
// either the task has preferred locations and this host is one, or it has
// no preferred locations (in which we still count the launch as preferred).
def isPreferredLocation(task: Task[_], host: String): Boolean = {
private def isPreferredLocation(task: Task[_], host: String): Boolean = {
val locs = task.preferredLocations
return (locs.contains(host) || locs.isEmpty)
}
@ -335,7 +332,7 @@ private[spark] class TaskSetManager(
if (numFailures(index) > MAX_TASK_FAILURES) {
logError("Task %s:%d failed more than %d times; aborting job".format(
taskSet.id, index, MAX_TASK_FAILURES))
abort("Task %d failed more than %d times".format(index, MAX_TASK_FAILURES))
abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES))
}
}
} else {

View file

@ -53,7 +53,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
}
def runTask(task: Task[_], idInJob: Int, attemptId: Int) {
logInfo("Running task " + idInJob)
logInfo("Running " + task)
// Set the Spark execution environment for the worker thread
SparkEnv.set(env)
try {
@ -80,7 +80,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
val resultToReturn = ser.deserialize[Any](ser.serialize(result))
val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
ser.serialize(Accumulators.values))
logInfo("Finished task " + idInJob)
logInfo("Finished " + task)
// If the threadpool has not already been shutdown, notify DAGScheduler
if (!Thread.currentThread().isInterrupted)

View file

@ -35,16 +35,6 @@ private[spark] class CoarseMesosSchedulerBackend(
val MAX_SLAVE_FAILURES = 2 // Blacklist a slave after this many failures
// Memory used by each executor (in megabytes)
val executorMemory = {
if (System.getenv("SPARK_MEM") != null) {
Utils.memoryStringToMb(System.getenv("SPARK_MEM"))
// TODO: Might need to add some extra memory for the non-heap parts of the JVM
} else {
512
}
}
// Lock used to wait for scheduler to be registered
var isRegistered = false
val registeredLock = new Object()

View file

@ -29,16 +29,6 @@ private[spark] class MesosSchedulerBackend(
with MScheduler
with Logging {
// Memory used by each executor (in megabytes)
val EXECUTOR_MEMORY = {
if (System.getenv("SPARK_MEM") != null) {
Utils.memoryStringToMb(System.getenv("SPARK_MEM"))
// TODO: Might need to add some extra memory for the non-heap parts of the JVM
} else {
512
}
}
// Lock used to wait for scheduler to be registered
var isRegistered = false
val registeredLock = new Object()
@ -89,7 +79,7 @@ private[spark] class MesosSchedulerBackend(
val memory = Resource.newBuilder()
.setName("mem")
.setType(Value.Type.SCALAR)
.setScalar(Value.Scalar.newBuilder().setValue(EXECUTOR_MEMORY).build())
.setScalar(Value.Scalar.newBuilder().setValue(executorMemory).build())
.build()
val command = CommandInfo.newBuilder()
.setValue(execScript)
@ -161,7 +151,7 @@ private[spark] class MesosSchedulerBackend(
def enoughMemory(o: Offer) = {
val mem = getResource(o.getResourcesList, "mem")
val slaveId = o.getSlaveId.getValue
mem >= EXECUTOR_MEMORY || slaveIdsWithExecutors.contains(slaveId)
mem >= executorMemory || slaveIdsWithExecutors.contains(slaveId)
}
for ((offer, index) <- offers.zipWithIndex if enoughMemory(offer)) {

View file

@ -950,6 +950,7 @@ class BlockManager(
blockInfo.clear()
memoryStore.clear()
diskStore.clear()
metadataCleaner.cancel()
logInfo("BlockManager stopped")
}
}

View file

@ -27,8 +27,6 @@ private[spark] class BlockManagerMaster(
val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "3000").toInt
val DRIVER_AKKA_ACTOR_NAME = "BlockMasterManager"
val SLAVE_AKKA_ACTOR_NAME = "BlockSlaveManager"
val DEFAULT_MANAGER_IP: String = Utils.localHostName()
val timeout = 10.seconds
var driverActor: ActorRef = {
@ -117,6 +115,10 @@ private[spark] class BlockManagerMaster(
askDriverWithReply[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus)
}
def getStorageStatus: Array[StorageStatus] = {
askDriverWithReply[ArrayBuffer[StorageStatus]](GetStorageStatus).toArray
}
/** Stop the driver actor, called only on the Spark driver node */
def stop() {
if (driverActor != null) {

View file

@ -1,13 +1,10 @@
package spark.storage
import akka.actor.{ActorRef, ActorSystem}
import akka.pattern.ask
import akka.util.Timeout
import akka.util.duration._
import cc.spray.directives._
import cc.spray.typeconversion.TwirlSupport._
import cc.spray.Directives
import scala.collection.mutable.ArrayBuffer
import spark.{Logging, SparkContext}
import spark.util.AkkaUtils
import spark.Utils
@ -48,32 +45,26 @@ class BlockManagerUI(val actorSystem: ActorSystem, blockManagerMaster: ActorRef,
path("") {
completeWith {
// Request the current storage status from the Master
val future = blockManagerMaster ? GetStorageStatus
future.map { status =>
// Calculate macro-level statistics
val storageStatusList = status.asInstanceOf[ArrayBuffer[StorageStatus]].toArray
val maxMem = storageStatusList.map(_.maxMem).reduce(_+_)
val remainingMem = storageStatusList.map(_.memRemaining).reduce(_+_)
val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize))
.reduceOption(_+_).getOrElse(0L)
val rdds = StorageUtils.rddInfoFromStorageStatus(storageStatusList, sc)
spark.storage.html.index.
render(maxMem, remainingMem, diskSpaceUsed, rdds, storageStatusList)
}
val storageStatusList = sc.getExecutorStorageStatus
// Calculate macro-level statistics
val maxMem = storageStatusList.map(_.maxMem).reduce(_+_)
val remainingMem = storageStatusList.map(_.memRemaining).reduce(_+_)
val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize))
.reduceOption(_+_).getOrElse(0L)
val rdds = StorageUtils.rddInfoFromStorageStatus(storageStatusList, sc)
spark.storage.html.index.
render(maxMem, remainingMem, diskSpaceUsed, rdds, storageStatusList)
}
} ~
path("rdd") {
parameter("id") { id =>
completeWith {
val future = blockManagerMaster ? GetStorageStatus
future.map { status =>
val prefix = "rdd_" + id.toString
val storageStatusList = status.asInstanceOf[ArrayBuffer[StorageStatus]].toArray
val filteredStorageStatusList = StorageUtils.
filterStorageStatusByPrefix(storageStatusList, prefix)
val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head
spark.storage.html.rdd.render(rddInfo, filteredStorageStatusList)
}
val prefix = "rdd_" + id.toString
val storageStatusList = sc.getExecutorStorageStatus
val filteredStorageStatusList = StorageUtils.
filterStorageStatusByPrefix(storageStatusList, prefix)
val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head
spark.storage.html.rdd.render(rddInfo, filteredStorageStatusList)
}
}
} ~

View file

@ -1,6 +1,6 @@
package spark.storage
import spark.SparkContext
import spark.{Utils, SparkContext}
import BlockManagerMasterActor.BlockStatus
private[spark]
@ -22,8 +22,13 @@ case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long,
}
case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel,
numPartitions: Int, memSize: Long, diskSize: Long)
numCachedPartitions: Int, numPartitions: Int, memSize: Long, diskSize: Long) {
override def toString = {
import Utils.memoryBytesToString
"RDD \"%s\" (%d) Storage: %s; CachedPartitions: %d; TotalPartitions: %d; MemorySize: %s; DiskSize: %s".format(name, id,
storageLevel.toString, numCachedPartitions, numPartitions, memoryBytesToString(memSize), memoryBytesToString(diskSize))
}
}
/* Helper methods for storage-related objects */
private[spark]
@ -38,8 +43,6 @@ object StorageUtils {
/* Given a list of BlockStatus objets, returns information for each RDD */
def rddInfoFromBlockStatusList(infos: Map[String, BlockStatus],
sc: SparkContext) : Array[RDDInfo] = {
// Find all RDD Blocks (ignore broadcast variables)
val rddBlocks = infos.filterKeys(_.startsWith("rdd"))
// Group by rddId, ignore the partition name
val groupedRddBlocks = infos.groupBy { case(k, v) =>
@@ -56,10 +59,11 @@ object StorageUtils {
// Find the id of the RDD, e.g. rdd_1 => 1
val rddId = rddKey.split("_").last.toInt
// Get the friendly name for the rdd, if available.
val rddName = Option(sc.persistentRdds(rddId).name).getOrElse(rddKey)
val rddStorageLevel = sc.persistentRdds(rddId).getStorageLevel
RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, memSize, diskSize)
val rdd = sc.persistentRdds(rddId)
val rddName = Option(rdd.name).getOrElse(rddKey)
val rddStorageLevel = rdd.getStorageLevel
RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, rdd.splits.size, memSize, diskSize)
}.toArray
}
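To make the grouping concrete (the block names are illustrative, following the rdd_&lt;rddId&gt;_&lt;splitIndex&gt; naming referenced above and in the DAGScheduler tests below):

// Sketch: blocks "rdd_1_0" and "rdd_1_3" group under the key "rdd_1", and the RDD id
// is recovered from that key.
val rddKey = "rdd_1"
val rddId = rddKey.split("_").last.toInt  // 1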
@@ -75,4 +79,4 @@ object StorageUtils {
}
}
}

View file

@@ -18,9 +18,13 @@ import java.util.concurrent.TimeoutException
* Various utility classes for working with Akka.
*/
private[spark] object AkkaUtils {
/**
* Creates an ActorSystem ready for remoting, with various Spark features. Returns both the
* ActorSystem itself and its port (which is hard to get from Akka).
*
* Note: the `name` parameter is important, as even if a client sends a message to the right
* host + port, if the system name is incorrect, Akka will drop the message.
*/
def createActorSystem(name: String, host: String, port: Int): (ActorSystem, Int) = {
val akkaThreads = System.getProperty("spark.akka.threads", "4").toInt
@@ -30,6 +34,7 @@ private[spark] object AkkaUtils {
val akkaConf = ConfigFactory.parseString("""
akka.daemonic = on
akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"]
akka.stdout-loglevel = "ERROR"
akka.actor.provider = "akka.remote.RemoteActorRefProvider"
akka.remote.transport = "akka.remote.netty.NettyRemoteTransport"
akka.remote.log-remote-lifecycle-events = on
@@ -41,7 +46,7 @@ private[spark] object AkkaUtils {
akka.actor.default-dispatcher.throughput = %d
""".format(host, port, akkaTimeout, akkaFrameSize, akkaThreads, akkaBatchSize))
val actorSystem = ActorSystem("spark", akkaConf, getClass.getClassLoader)
val actorSystem = ActorSystem(name, akkaConf, getClass.getClassLoader)
// Figure out the port number we bound to, in case port was passed as 0. This is a bit of a
// hack because Akka doesn't let you figure out the port through the public API yet.
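A hedged sketch of why the system name matters (host, port, and actor name here are made up): the remote actor path embeds the ActorSystem name, so a client using the wrong name is ignored even when host and port are correct.

// Illustrative only: bind a system named "spark" and build a path to an actor inside it.
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", "localhost", 0)
val url = "akka://spark@localhost:%d/user/SomeActor".format(boundPort)
// Addressing "akka://other@localhost:%d/user/SomeActor" would be silently dropped by Akka.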

View file

@@ -9,12 +9,12 @@ import spark.Logging
* Runs a timer task to periodically clean up metadata (e.g. old files or hashtable entries)
*/
class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging {
val delaySeconds = MetadataCleaner.getDelaySeconds
val periodSeconds = math.max(10, delaySeconds / 10)
val timer = new Timer(name + " cleanup timer", true)
private val delaySeconds = MetadataCleaner.getDelaySeconds
private val periodSeconds = math.max(10, delaySeconds / 10)
private val timer = new Timer(name + " cleanup timer", true)
val task = new TimerTask {
def run() {
private val task = new TimerTask {
override def run() {
try {
cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000))
logInfo("Ran metadata cleaner for " + name)
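For context, a minimal usage sketch (the cache and its contents are hypothetical, not from this commit): MetadataCleaner simply invokes the supplied (Long) => Unit closure on a timer, passing a cutoff timestamp in milliseconds.

import scala.collection.mutable
// Hypothetical cache keyed by insertion time in millis; the closure evicts stale entries.
val cache = new mutable.HashMap[Long, String]()
val cleaner = new MetadataCleaner("exampleCache", cutoff => cache.retain((ts, _) => ts >= cutoff))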

View file

@@ -11,7 +11,11 @@
<strong>Storage Level:</strong>
@(rddInfo.storageLevel.description)
<li>
<strong>Partitions:</strong>
<strong>Cached Partitions:</strong>
@(rddInfo.numCachedPartitions)
</li>
<li>
<strong>Total Partitions:</strong>
@(rddInfo.numPartitions)
</li>
<li>

View file

@@ -6,7 +6,8 @@
<tr>
<th>RDD Name</th>
<th>Storage Level</th>
<th>Partitions</th>
<th>Cached Partitions</th>
<th>Fraction Partitions Cached</th>
<th>Size in Memory</th>
<th>Size on Disk</th>
</tr>
@@ -21,7 +22,8 @@
</td>
<td>@(rdd.storageLevel.description)
</td>
<td>@rdd.numPartitions</td>
<td>@rdd.numCachedPartitions</td>
<td>@(rdd.numCachedPartitions / rdd.numPartitions.toDouble)</td>
<td>@{Utils.memoryBytesToString(rdd.memSize)}</td>
<td>@{Utils.memoryBytesToString(rdd.diskSize)}</td>
</tr>

View file

@@ -12,9 +12,10 @@ class RDDSuite extends FunSuite with LocalSparkContext {
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
assert(nums.collect().toList === List(1, 2, 3, 4))
val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2)
assert(dups.distinct.count === 4)
assert(dups.distinct().collect === dups.distinct.collect)
assert(dups.distinct(2).collect === dups.distinct.collect)
assert(dups.distinct().count() === 4)
assert(dups.distinct.count === 4) // Can distinct and count be called without parentheses?
assert(dups.distinct().collect === dups.distinct().collect)
assert(dups.distinct(2).collect === dups.distinct().collect)
assert(nums.reduce(_ + _) === 10)
assert(nums.fold(0)(_ + _) === 10)
assert(nums.map(_.toString).collect().toList === List("1", "2", "3", "4"))
@@ -31,6 +32,10 @@ class RDDSuite extends FunSuite with LocalSparkContext {
case(split, iter) => Iterator((split, iter.reduceLeft(_ + _)))
}
assert(partitionSumsWithSplit.collect().toList === List((0, 3), (1, 7)))
intercept[UnsupportedOperationException] {
nums.filter(_ > 5).reduce(_ + _)
}
}
test("SparkContext.union") {
@@ -164,7 +169,7 @@ class RDDSuite extends FunSuite with LocalSparkContext {
// Note that split number starts from 0, so > 8 means only 10th partition left.
val prunedRdd = new PartitionPruningRDD(data, splitNum => splitNum > 8)
assert(prunedRdd.splits.size === 1)
val prunedData = prunedRdd.collect
val prunedData = prunedRdd.collect()
assert(prunedData.size === 1)
assert(prunedData(0) === 10)
}

View file

@@ -0,0 +1,663 @@
package spark.scheduler
import scala.collection.mutable.{Map, HashMap}
import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.TimeLimitedTests
import org.scalatest.mock.EasyMockSugar
import org.scalatest.time.{Span, Seconds}
import org.easymock.EasyMock._
import org.easymock.Capture
import org.easymock.EasyMock
import org.easymock.{IAnswer, IArgumentMatcher}
import akka.actor.ActorSystem
import spark.storage.BlockManager
import spark.storage.BlockManagerId
import spark.storage.BlockManagerMaster
import spark.{Dependency, ShuffleDependency, OneToOneDependency}
import spark.FetchFailedException
import spark.MapOutputTracker
import spark.RDD
import spark.SparkContext
import spark.SparkException
import spark.Split
import spark.TaskContext
import spark.TaskEndReason
import spark.{FetchFailed, Success}
/**
* Tests for DAGScheduler. These tests directly call the event processing functions in DAGScheduler
* rather than spawning an event loop thread as happens in the real code. They use EasyMock
* to mock out two classes that DAGScheduler interacts with: TaskScheduler (to which TaskSets are
* submitted) and BlockManagerMaster (from which cache locations are retrieved and to which dead
* host notifications are sent). In addition, tests may check for side effects on a non-mocked
* MapOutputTracker instance.
*
* Tests primarily consist of running DAGScheduler#processEvent and
* DAGScheduler#submitWaitingStages (via test utility functions like runEvent or respondToTaskSet)
* and capturing the resulting TaskSets from the mock TaskScheduler.
*/
class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar with TimeLimitedTests {
// impose a time limit on this test in case we don't let the job finish, in which case
// JobWaiter#getResult will hang.
override val timeLimit = Span(5, Seconds)
val sc: SparkContext = new SparkContext("local", "DAGSchedulerSuite")
var scheduler: DAGScheduler = null
val taskScheduler = mock[TaskScheduler]
val blockManagerMaster = mock[BlockManagerMaster]
var mapOutputTracker: MapOutputTracker = null
var schedulerThread: Thread = null
var schedulerException: Throwable = null
/**
* Set of EasyMock argument matchers that match a TaskSet for a given RDD.
* We cache these so we do not create duplicate matchers for the same RDD.
* This allows us to easily set up a sequence of expectations for task sets for
* that RDD.
*/
val taskSetMatchers = new HashMap[MyRDD, IArgumentMatcher]
/**
* Set of cache locations to return from our mock BlockManagerMaster.
* Keys are (rdd ID, partition ID). Anything not present will return an empty
* list of cache locations silently.
*/
val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]]
/**
* JobWaiter for the last JobSubmitted event we pushed, kept here so that tests (most of
* which submit only one job) do not need to track it explicitly.
*/
var lastJobWaiter: JobWaiter[Int] = null
/**
* Array into which we are accumulating the results from the last job asynchronously.
*/
var lastJobResult: Array[Int] = null
/**
* Tell EasyMockSugar what mock objects we want to be configured by expecting {...}
* and whenExecuting {...} */
implicit val mocks = MockObjects(taskScheduler, blockManagerMaster)
/**
* Utility function to reset mocks and set expectations on them. EasyMock wants mock objects
* to be reset after each time their expectations are set, and we tend to check mock object
* calls over a single call to DAGScheduler.
*
* We also set a default expectation here that blockManagerMaster.getLocations can be called
* and will return values from cacheLocations.
*/
def resetExpecting(f: => Unit) {
reset(taskScheduler)
reset(blockManagerMaster)
expecting {
expectGetLocations()
f
}
}
before {
taskSetMatchers.clear()
cacheLocations.clear()
val actorSystem = ActorSystem("test")
mapOutputTracker = new MapOutputTracker(actorSystem, true)
resetExpecting {
taskScheduler.setListener(anyObject())
}
whenExecuting {
scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null)
}
}
after {
assert(scheduler.processEvent(StopDAGScheduler))
resetExpecting {
taskScheduler.stop()
}
whenExecuting {
scheduler.stop()
}
sc.stop()
System.clearProperty("spark.master.port")
}
def makeBlockManagerId(host: String): BlockManagerId =
BlockManagerId("exec-" + host, host, 12345)
/**
* Type of RDD we use for testing. Note that we should never call the real RDD compute methods.
* This is a pair RDD type so it can always be used in ShuffleDependencies.
*/
type MyRDD = RDD[(Int, Int)]
/**
* Create an RDD for passing to DAGScheduler. These RDDs will use the dependencies and
* preferredLocations (if any) that are passed to them. They are deliberately not executable
* so we can test that DAGScheduler does not try to execute RDDs locally.
*/
def makeRdd(
numSplits: Int,
dependencies: List[Dependency[_]],
locations: Seq[Seq[String]] = Nil
): MyRDD = {
val maxSplit = numSplits - 1
return new MyRDD(sc, dependencies) {
override def compute(split: Split, context: TaskContext): Iterator[(Int, Int)] =
throw new RuntimeException("should not be reached")
override def getSplits() = (0 to maxSplit).map(i => new Split {
override def index = i
}).toArray
override def getPreferredLocations(split: Split): Seq[String] =
if (locations.isDefinedAt(split.index))
locations(split.index)
else
Nil
override def toString: String = "DAGSchedulerSuiteRDD " + id
}
}
/**
* EasyMock matcher method. For use as an argument matcher for a TaskSet whose first task
* is from a particular RDD.
*/
def taskSetForRdd(rdd: MyRDD): TaskSet = {
val matcher = taskSetMatchers.getOrElseUpdate(rdd,
new IArgumentMatcher {
override def matches(actual: Any): Boolean = {
val taskSet = actual.asInstanceOf[TaskSet]
taskSet.tasks(0) match {
case rt: ResultTask[_, _] => rt.rdd.id == rdd.id
case smt: ShuffleMapTask => smt.rdd.id == rdd.id
case _ => false
}
}
override def appendTo(buf: StringBuffer) {
buf.append("taskSetForRdd(" + rdd + ")")
}
})
EasyMock.reportMatcher(matcher)
return null
}
/**
* Set up an EasyMock expectation to respond to blockManagerMaster.getLocations() calls
* by answering with values from cacheLocations.
*/
def expectGetLocations(): Unit = {
EasyMock.expect(blockManagerMaster.getLocations(anyObject().asInstanceOf[Array[String]])).
andAnswer(new IAnswer[Seq[Seq[BlockManagerId]]] {
override def answer(): Seq[Seq[BlockManagerId]] = {
val blocks = getCurrentArguments()(0).asInstanceOf[Array[String]]
return blocks.map { name =>
val pieces = name.split("_")
if (pieces(0) == "rdd") {
val key = pieces(1).toInt -> pieces(2).toInt
if (cacheLocations.contains(key)) {
cacheLocations(key)
} else {
Seq[BlockManagerId]()
}
} else {
Seq[BlockManagerId]()
}
}.toSeq
}
}).anyTimes()
}
/**
* Process the supplied event as if it were the top of the DAGScheduler event queue, expecting
* the scheduler not to exit.
*
* After processing the event, submit waiting stages as is done on most iterations of the
* DAGScheduler event loop.
*/
def runEvent(event: DAGSchedulerEvent) {
assert(!scheduler.processEvent(event))
scheduler.submitWaitingStages()
}
/**
* Expect a TaskSet for the specified RDD to be submitted to the TaskScheduler. Should be
* called from a resetExpecting { ... } block.
*
* Returns an EasyMock Capture that will contain the task set after the stage is submitted.
* Most tests should use interceptStage() instead of this directly.
*/
def expectStage(rdd: MyRDD): Capture[TaskSet] = {
val taskSetCapture = new Capture[TaskSet]
taskScheduler.submitTasks(and(capture(taskSetCapture), taskSetForRdd(rdd)))
return taskSetCapture
}
/**
* Expect the supplied code snippet to submit a stage for the specified RDD.
* Return the resulting TaskSet. First marks all the tasks as belonging to the
* current MapOutputTracker generation.
*/
def interceptStage(rdd: MyRDD)(f: => Unit): TaskSet = {
var capture: Capture[TaskSet] = null
resetExpecting {
capture = expectStage(rdd)
}
whenExecuting {
f
}
val taskSet = capture.getValue
for (task <- taskSet.tasks) {
task.generation = mapOutputTracker.getGeneration
}
return taskSet
}
/**
* Send the given CompletionEvent messages for the tasks in the TaskSet.
*/
def respondToTaskSet(taskSet: TaskSet, results: Seq[(TaskEndReason, Any)]) {
assert(taskSet.tasks.size >= results.size)
for ((result, i) <- results.zipWithIndex) {
if (i < taskSet.tasks.size) {
runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, Map[Long, Any]()))
}
}
}
/**
* Assert that the supplied TaskSet has exactly the given preferredLocations.
*/
def expectTaskSetLocations(taskSet: TaskSet, locations: Seq[Seq[String]]) {
assert(locations.size === taskSet.tasks.size)
for ((expectLocs, taskLocs) <-
taskSet.tasks.map(_.preferredLocations).zip(locations)) {
assert(expectLocs === taskLocs)
}
}
/**
* When we submit dummy Jobs, this is the compute function we supply. Except in a local test
* below, we do not expect this function to ever be executed; instead, we will return results
* directly through CompletionEvents.
*/
def jobComputeFunc(context: TaskContext, it: Iterator[(Int, Int)]): Int =
it.next._1.asInstanceOf[Int]
/**
* Start a job to compute the given RDD. Returns the JobWaiter and result array that will
* collect the results of the job via callbacks from DAGScheduler.
*/
def submitRdd(rdd: MyRDD, allowLocal: Boolean = false): (JobWaiter[Int], Array[Int]) = {
val resultArray = new Array[Int](rdd.splits.size)
val (toSubmit, waiter) = scheduler.prepareJob[(Int, Int), Int](
rdd,
jobComputeFunc,
(0 to (rdd.splits.size - 1)),
"test-site",
allowLocal,
(i: Int, value: Int) => resultArray(i) = value
)
lastJobWaiter = waiter
lastJobResult = resultArray
runEvent(toSubmit)
return (waiter, resultArray)
}
/**
* Assert that a job we started has failed.
*/
def expectJobException(waiter: JobWaiter[Int] = lastJobWaiter) {
waiter.awaitResult() match {
case JobSucceeded => fail()
case JobFailed(_) => return
}
}
/**
* Assert that a job we started has succeeded and has the given result.
*/
def expectJobResult(expected: Array[Int], waiter: JobWaiter[Int] = lastJobWaiter,
result: Array[Int] = lastJobResult) {
waiter.awaitResult match {
case JobSucceeded =>
assert(expected === result)
case JobFailed(_) =>
fail()
}
}
def makeMapStatus(host: String, reduces: Int): MapStatus =
new MapStatus(makeBlockManagerId(host), Array.fill[Byte](reduces)(2))
test("zero split job") {
val rdd = makeRdd(0, Nil)
var numResults = 0
def accumulateResult(partition: Int, value: Int) {
numResults += 1
}
scheduler.runJob(rdd, jobComputeFunc, Seq(), "test-site", false, accumulateResult)
assert(numResults === 0)
}
test("run trivial job") {
val rdd = makeRdd(1, Nil)
val taskSet = interceptStage(rdd) { submitRdd(rdd) }
respondToTaskSet(taskSet, List( (Success, 42) ))
expectJobResult(Array(42))
}
test("local job") {
val rdd = new MyRDD(sc, Nil) {
override def compute(split: Split, context: TaskContext): Iterator[(Int, Int)] =
Array(42 -> 0).iterator
override def getSplits() = Array( new Split { override def index = 0 } )
override def getPreferredLocations(split: Split) = Nil
override def toString = "DAGSchedulerSuite Local RDD"
}
submitRdd(rdd, true)
expectJobResult(Array(42))
}
test("run trivial job w/ dependency") {
val baseRdd = makeRdd(1, Nil)
val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd)))
val taskSet = interceptStage(finalRdd) { submitRdd(finalRdd) }
respondToTaskSet(taskSet, List( (Success, 42) ))
expectJobResult(Array(42))
}
test("cache location preferences w/ dependency") {
val baseRdd = makeRdd(1, Nil)
val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd)))
cacheLocations(baseRdd.id -> 0) =
Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))
val taskSet = interceptStage(finalRdd) { submitRdd(finalRdd) }
expectTaskSetLocations(taskSet, List(Seq("hostA", "hostB")))
respondToTaskSet(taskSet, List( (Success, 42) ))
expectJobResult(Array(42))
}
test("trivial job failure") {
val rdd = makeRdd(1, Nil)
val taskSet = interceptStage(rdd) { submitRdd(rdd) }
runEvent(TaskSetFailed(taskSet, "test failure"))
expectJobException()
}
test("run trivial shuffle") {
val shuffleMapRdd = makeRdd(2, Nil)
val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
val shuffleId = shuffleDep.shuffleId
val reduceRdd = makeRdd(1, List(shuffleDep))
val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
val secondStage = interceptStage(reduceRdd) {
respondToTaskSet(firstStage, List(
(Success, makeMapStatus("hostA", 1)),
(Success, makeMapStatus("hostB", 1))
))
}
assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
respondToTaskSet(secondStage, List( (Success, 42) ))
expectJobResult(Array(42))
}
test("run trivial shuffle with fetch failure") {
val shuffleMapRdd = makeRdd(2, Nil)
val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
val shuffleId = shuffleDep.shuffleId
val reduceRdd = makeRdd(2, List(shuffleDep))
val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
val secondStage = interceptStage(reduceRdd) {
respondToTaskSet(firstStage, List(
(Success, makeMapStatus("hostA", 1)),
(Success, makeMapStatus("hostB", 1))
))
}
resetExpecting {
blockManagerMaster.removeExecutor("exec-hostA")
}
whenExecuting {
respondToTaskSet(secondStage, List(
(Success, 42),
(FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0), null)
))
}
val thirdStage = interceptStage(shuffleMapRdd) {
scheduler.resubmitFailedStages()
}
val fourthStage = interceptStage(reduceRdd) {
respondToTaskSet(thirdStage, List( (Success, makeMapStatus("hostA", 1)) ))
}
assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
respondToTaskSet(fourthStage, List( (Success, 43) ))
expectJobResult(Array(42, 43))
}
test("ignore late map task completions") {
val shuffleMapRdd = makeRdd(2, Nil)
val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
val shuffleId = shuffleDep.shuffleId
val reduceRdd = makeRdd(2, List(shuffleDep))
val taskSet = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
val oldGeneration = mapOutputTracker.getGeneration
resetExpecting {
blockManagerMaster.removeExecutor("exec-hostA")
}
whenExecuting {
runEvent(ExecutorLost("exec-hostA"))
}
val newGeneration = mapOutputTracker.getGeneration
assert(newGeneration > oldGeneration)
val noAccum = Map[Long, Any]()
// We rely on the event queue being ordered: the executor loss above bumped the generation,
// so this completion, which still carries the old generation for the failed host,
// should be ignored for being too old
runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum))
// should work because it's a non-failed host
runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), noAccum))
// should be ignored for being too old
runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum))
taskSet.tasks(1).generation = newGeneration
val secondStage = interceptStage(reduceRdd) {
runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum))
}
assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))
respondToTaskSet(secondStage, List( (Success, 42), (Success, 43) ))
expectJobResult(Array(42, 43))
}
test("run trivial shuffle with out-of-band failure and retry") {
val shuffleMapRdd = makeRdd(2, Nil)
val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
val shuffleId = shuffleDep.shuffleId
val reduceRdd = makeRdd(1, List(shuffleDep))
val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
resetExpecting {
blockManagerMaster.removeExecutor("exec-hostA")
}
whenExecuting {
runEvent(ExecutorLost("exec-hostA"))
}
// DAGScheduler will immediately resubmit the stage after it appears to have no pending tasks
// rather than marking it as failed and waiting.
val secondStage = interceptStage(shuffleMapRdd) {
respondToTaskSet(firstStage, List(
(Success, makeMapStatus("hostA", 1)),
(Success, makeMapStatus("hostB", 1))
))
}
val thirdStage = interceptStage(reduceRdd) {
respondToTaskSet(secondStage, List(
(Success, makeMapStatus("hostC", 1))
))
}
assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB")))
respondToTaskSet(thirdStage, List( (Success, 42) ))
expectJobResult(Array(42))
}
test("recursive shuffle failures") {
val shuffleOneRdd = makeRdd(2, Nil)
val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
val finalRdd = makeRdd(1, List(shuffleDepTwo))
val firstStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) }
val secondStage = interceptStage(shuffleTwoRdd) {
respondToTaskSet(firstStage, List(
(Success, makeMapStatus("hostA", 2)),
(Success, makeMapStatus("hostB", 2))
))
}
val thirdStage = interceptStage(finalRdd) {
respondToTaskSet(secondStage, List(
(Success, makeMapStatus("hostA", 1)),
(Success, makeMapStatus("hostC", 1))
))
}
resetExpecting {
blockManagerMaster.removeExecutor("exec-hostA")
}
whenExecuting {
respondToTaskSet(thirdStage, List(
(FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)
))
}
val recomputeOne = interceptStage(shuffleOneRdd) {
scheduler.resubmitFailedStages()
}
val recomputeTwo = interceptStage(shuffleTwoRdd) {
respondToTaskSet(recomputeOne, List(
(Success, makeMapStatus("hostA", 2))
))
}
val finalStage = interceptStage(finalRdd) {
respondToTaskSet(recomputeTwo, List(
(Success, makeMapStatus("hostA", 1))
))
}
respondToTaskSet(finalStage, List( (Success, 42) ))
expectJobResult(Array(42))
}
test("cached post-shuffle") {
val shuffleOneRdd = makeRdd(2, Nil)
val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
val finalRdd = makeRdd(1, List(shuffleDepTwo))
val firstShuffleStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) }
cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD"))
cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC"))
val secondShuffleStage = interceptStage(shuffleTwoRdd) {
respondToTaskSet(firstShuffleStage, List(
(Success, makeMapStatus("hostA", 2)),
(Success, makeMapStatus("hostB", 2))
))
}
val reduceStage = interceptStage(finalRdd) {
respondToTaskSet(secondShuffleStage, List(
(Success, makeMapStatus("hostA", 1)),
(Success, makeMapStatus("hostB", 1))
))
}
resetExpecting {
blockManagerMaster.removeExecutor("exec-hostA")
}
whenExecuting {
respondToTaskSet(reduceStage, List(
(FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)
))
}
// DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun.
val recomputeTwo = interceptStage(shuffleTwoRdd) {
scheduler.resubmitFailedStages()
}
expectTaskSetLocations(recomputeTwo, Seq(Seq("hostD")))
val finalRetry = interceptStage(finalRdd) {
respondToTaskSet(recomputeTwo, List(
(Success, makeMapStatus("hostD", 1))
))
}
respondToTaskSet(finalRetry, List( (Success, 42) ))
expectJobResult(Array(42))
}
test("cached post-shuffle but fails") {
val shuffleOneRdd = makeRdd(2, Nil)
val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
val finalRdd = makeRdd(1, List(shuffleDepTwo))
val firstShuffleStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) }
cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD"))
cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC"))
val secondShuffleStage = interceptStage(shuffleTwoRdd) {
respondToTaskSet(firstShuffleStage, List(
(Success, makeMapStatus("hostA", 2)),
(Success, makeMapStatus("hostB", 2))
))
}
val reduceStage = interceptStage(finalRdd) {
respondToTaskSet(secondShuffleStage, List(
(Success, makeMapStatus("hostA", 1)),
(Success, makeMapStatus("hostB", 1))
))
}
resetExpecting {
blockManagerMaster.removeExecutor("exec-hostA")
}
whenExecuting {
respondToTaskSet(reduceStage, List(
(FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)
))
}
val recomputeTwoCached = interceptStage(shuffleTwoRdd) {
scheduler.resubmitFailedStages()
}
expectTaskSetLocations(recomputeTwoCached, Seq(Seq("hostD")))
intercept[FetchFailedException]{
mapOutputTracker.getServerStatuses(shuffleDepOne.shuffleId, 0)
}
// Simulate the shuffle input data failing to be cached.
cacheLocations.remove(shuffleTwoRdd.id -> 0)
respondToTaskSet(recomputeTwoCached, List(
(FetchFailed(null, shuffleDepOne.shuffleId, 0, 0), null)
))
// After the fetch failure, DAGScheduler should recheck the cache and decide to resubmit
// everything.
val recomputeOne = interceptStage(shuffleOneRdd) {
scheduler.resubmitFailedStages()
}
// We use hostA here to make sure DAGScheduler doesn't think it's still dead.
val recomputeTwoUncached = interceptStage(shuffleTwoRdd) {
respondToTaskSet(recomputeOne, List( (Success, makeMapStatus("hostA", 1)) ))
}
expectTaskSetLocations(recomputeTwoUncached, Seq(Seq[String]()))
val finalRetry = interceptStage(finalRdd) {
respondToTaskSet(recomputeTwoUncached, List( (Success, makeMapStatus("hostA", 1)) ))
}
respondToTaskSet(finalRetry, List( (Success, 42) ))
expectJobResult(Array(42))
}
}

View file

@@ -273,6 +273,12 @@
<version>1.8</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.easymock</groupId>
<artifactId>easymock</artifactId>
<version>3.1</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.scalacheck</groupId>
<artifactId>scalacheck_${scala.version}</artifactId>

View file

@@ -92,7 +92,8 @@ object SparkBuild extends Build {
"org.eclipse.jetty" % "jetty-server" % "7.5.3.v20111011",
"org.scalatest" %% "scalatest" % "1.8" % "test",
"org.scalacheck" %% "scalacheck" % "1.9" % "test",
"com.novocode" % "junit-interface" % "0.8" % "test"
"com.novocode" % "junit-interface" % "0.8" % "test",
"org.easymock" % "easymock" % "3.1" % "test"
),
parallelExecution := false,
/* Workaround for issue #206 (fixed after SBT 0.11.0) */

View file

@@ -1,8 +1,6 @@
import os
import atexit
import shutil
import sys
import tempfile
from threading import Lock
from tempfile import NamedTemporaryFile
@@ -24,11 +22,10 @@ class SparkContext(object):
broadcast variables on that cluster.
"""
gateway = launch_gateway()
jvm = gateway.jvm
_readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
_writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile
_takePartition = jvm.PythonRDD.takePartition
_gateway = None
_jvm = None
_writeIteratorToPickleFile = None
_takePartition = None
_next_accum_id = 0
_active_spark_context = None
_lock = Lock()
@@ -56,6 +53,13 @@ class SparkContext(object):
raise ValueError("Cannot run multiple SparkContexts at once")
else:
SparkContext._active_spark_context = self
if not SparkContext._gateway:
SparkContext._gateway = launch_gateway()
SparkContext._jvm = SparkContext._gateway.jvm
SparkContext._writeIteratorToPickleFile = \
SparkContext._jvm.PythonRDD.writeIteratorToPickleFile
SparkContext._takePartition = \
SparkContext._jvm.PythonRDD.takePartition
self.master = master
self.jobName = jobName
self.sparkHome = sparkHome or None # None becomes null in Py4J
@@ -63,8 +67,8 @@ class SparkContext(object):
self.batchSize = batchSize # -1 represents an unlimited batch size
# Create the Java SparkContext through Py4J
empty_string_array = self.gateway.new_array(self.jvm.String, 0)
self._jsc = self.jvm.JavaSparkContext(master, jobName, sparkHome,
empty_string_array = self._gateway.new_array(self._jvm.String, 0)
self._jsc = self._jvm.JavaSparkContext(master, jobName, sparkHome,
empty_string_array)
# Create a single Accumulator in Java that we'll send all our updates through;
@@ -72,8 +76,8 @@ class SparkContext(object):
self._accumulatorServer = accumulators._start_update_server()
(host, port) = self._accumulatorServer.server_address
self._javaAccumulator = self._jsc.accumulator(
self.jvm.java.util.ArrayList(),
self.jvm.PythonAccumulatorParam(host, port))
self._jvm.java.util.ArrayList(),
self._jvm.PythonAccumulatorParam(host, port))
self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
# Broadcast's __reduce__ method stores Broadcast instances here.
@@ -88,6 +92,11 @@ class SparkContext(object):
SparkFiles._sc = self
sys.path.append(SparkFiles.getRootDirectory())
# Create a temporary directory inside spark.local.dir:
local_dir = self._jvm.spark.Utils.getLocalDir()
self._temp_dir = \
self._jvm.spark.Utils.createTempDir(local_dir).getAbsolutePath()
@property
def defaultParallelism(self):
"""
@@ -120,14 +129,14 @@ class SparkContext(object):
# Calling the Java parallelize() method with an ArrayList is too slow,
# because it sends O(n) Py4J commands. As an alternative, serialized
# objects are written to a file and loaded through textFile().
tempFile = NamedTemporaryFile(delete=False)
atexit.register(lambda: os.unlink(tempFile.name))
tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
if self.batchSize != 1:
c = batched(c, self.batchSize)
for x in c:
write_with_length(dump_pickle(x), tempFile)
tempFile.close()
jrdd = self._readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
readRDDFromPickleFile = self._jvm.PythonRDD.readRDDFromPickleFile
jrdd = readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
return RDD(jrdd, self)
def textFile(self, name, minSplits=None):
@@ -240,7 +249,9 @@ class SparkContext(object):
def _test():
import atexit
import doctest
import tempfile
globs = globals().copy()
globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
globs['tempdir'] = tempfile.mkdtemp()

View file

@@ -35,4 +35,4 @@ class SparkFiles(object):
return cls._root_directory
else:
# This will have to change if we support multiple SparkContexts:
return cls._sc.jvm.spark.SparkFiles.getRootDirectory()
return cls._sc._jvm.spark.SparkFiles.getRootDirectory()

View file

@@ -1,4 +1,3 @@
import atexit
from base64 import standard_b64encode as b64enc
import copy
from collections import defaultdict
@@ -264,12 +263,8 @@ class RDD(object):
# Transferring lots of data through Py4J can be slow because
# socket.readline() is inefficient. Instead, we'll dump the data to a
# file and read it back.
tempFile = NamedTemporaryFile(delete=False)
tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir)
tempFile.close()
def clean_up_file():
try: os.unlink(tempFile.name)
except: pass
atexit.register(clean_up_file)
self.ctx._writeIteratorToPickleFile(iterator, tempFile.name)
# Read the data into Python and deserialize it:
with open(tempFile.name, 'rb') as tempFile:
@@ -407,7 +402,7 @@ class RDD(object):
return (str(x).encode("utf-8") for x in iterator)
keyed = PipelinedRDD(self, func)
keyed._bypass_serializer = True
keyed._jrdd.map(self.ctx.jvm.BytesToString()).saveAsTextFile(path)
keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path)
# Pair functions
@@ -550,8 +545,8 @@ class RDD(object):
yield dump_pickle(Batch(items))
keyed = PipelinedRDD(self, add_shuffle_key)
keyed._bypass_serializer = True
pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
partitioner = self.ctx.jvm.PythonPartitioner(numSplits,
pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
partitioner = self.ctx._jvm.PythonPartitioner(numSplits,
id(partitionFunc))
jrdd = pairRDD.partitionBy(partitioner).values()
rdd = RDD(jrdd, self.ctx)
@@ -730,13 +725,13 @@ class PipelinedRDD(RDD):
pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
broadcast_vars = ListConverter().convert(
[x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
self.ctx.gateway._gateway_client)
self.ctx._gateway._gateway_client)
self.ctx._pickled_broadcast_vars.clear()
class_manifest = self._prev_jrdd.classManifest()
env = copy.copy(self.ctx.environment)
env['PYTHONPATH'] = os.environ.get("PYTHONPATH", "")
env = MapConverter().convert(env, self.ctx.gateway._gateway_client)
python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(),
env = MapConverter().convert(env, self.ctx._gateway._gateway_client)
python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec,
broadcast_vars, self.ctx._javaAccumulator, class_manifest)
self._jrdd_val = python_rdd.asJavaRDD()

View file

@@ -26,7 +26,7 @@ class PySparkTestCase(unittest.TestCase):
sys.path = self._old_sys_path
# To avoid Akka rebinding to the same port, since it doesn't unbind
# immediately on shutdown
self.sc.jvm.System.clearProperty("spark.driver.port")
self.sc._jvm.System.clearProperty("spark.driver.port")
class TestCheckpoint(PySparkTestCase):
@@ -108,5 +108,14 @@ class TestAddFile(PySparkTestCase):
self.assertEqual("Hello World!", UserClass().hello())
class TestIO(PySparkTestCase):
def test_stdout_redirection(self):
import subprocess
def func(x):
subprocess.check_call('ls', shell=True)
self.sc.parallelize([1]).foreach(func)
if __name__ == "__main__":
unittest.main()

View file

@@ -1,6 +1,7 @@
"""
Worker that receives input from Piped RDD.
"""
import os
import sys
import traceback
from base64 import standard_b64decode
@@ -15,8 +16,8 @@ from pyspark.serializers import write_with_length, read_with_length, write_int,
# Redirect stdout to stderr so that users must return values from functions.
old_stdout = sys.stdout
sys.stdout = sys.stderr
old_stdout = os.fdopen(os.dup(1), 'w')
os.dup2(2, 1)
def load_obj():