Merge branch 'mesos'
commit faa4d9e31f

@@ -98,6 +98,11 @@
       <artifactId>scalacheck_${scala.version}</artifactId>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>org.easymock</groupId>
+      <artifactId>easymock</artifactId>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>com.novocode</groupId>
       <artifactId>junit-interface</artifactId>

@@ -61,17 +61,3 @@ class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int)
     }
   }
 }
-
-
-/**
- * Represents a dependency between the PartitionPruningRDD and its parent. In this
- * case, the child RDD contains a subset of partitions of the parents'.
- */
-class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boolean)
-  extends NarrowDependency[T](rdd) {
-
-  @transient
-  val partitions: Array[Split] = rdd.splits.filter(s => partitionFilterFunc(s.index))
-
-  override def getParents(partitionId: Int) = List(partitions(partitionId).index)
-}

@@ -170,7 +170,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolea
     }
   }

-  def cleanup(cleanupTime: Long) {
+  private def cleanup(cleanupTime: Long) {
     mapStatuses.clearOldValues(cleanupTime)
     cachedSerializedStatuses.clearOldValues(cleanupTime)
   }

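Note: `cleanup` is invoked by a `MetadataCleaner` timer over timestamped maps, so narrowing it to `private` does not change behavior. A minimal sketch of that pattern, with a hypothetical `TimeStampedMap` standing in for Spark's `TimeStampedHashMap`:

```scala
import scala.collection.mutable

// Entries carry an insertion timestamp; a timer periodically drops old ones.
class TimeStampedMap[K, V] {
  private val underlying = new mutable.HashMap[K, (V, Long)]

  def put(k: K, v: V): Unit = synchronized {
    underlying(k) = (v, System.currentTimeMillis)
  }

  def get(k: K): Option[V] = synchronized { underlying.get(k).map(_._1) }

  // Drop every entry inserted strictly before cutoffTime (mirrors clearOldValues).
  def clearOldValues(cutoffTime: Long): Unit = synchronized {
    val stale = underlying.filter { case (_, (_, ts)) => ts < cutoffTime }.keys.toList
    stale.foreach(underlying.remove)
  }
}
```
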
@@ -465,7 +465,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
         val res = self.context.runJob(self, process _, Array(index), false)
         res(0)
       case None =>
-        self.filter(_._1 == key).map(_._2).collect
+        self.filter(_._1 == key).map(_._2).collect()
     }
   }

@@ -590,7 +590,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](

       var count = 0
       while(iter.hasNext) {
-        val record = iter.next
+        val record = iter.next()
         count += 1
         writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef])
       }

@@ -385,20 +385,22 @@ abstract class RDD[T: ClassManifest](
     val reducePartition: Iterator[T] => Option[T] = iter => {
       if (iter.hasNext) {
         Some(iter.reduceLeft(cleanF))
-      }else {
+      } else {
         None
       }
     }
-    val options = sc.runJob(this, reducePartition)
-    val results = new ArrayBuffer[T]
-    for (opt <- options; elem <- opt) {
-      results += elem
-    }
-    if (results.size == 0) {
-      throw new UnsupportedOperationException("empty collection")
-    } else {
-      return results.reduceLeft(cleanF)
-    }
+    var jobResult: Option[T] = None
+    val mergeResult = (index: Int, taskResult: Option[T]) => {
+      if (taskResult != None) {
+        jobResult = jobResult match {
+          case Some(value) => Some(f(value, taskResult.get))
+          case None => taskResult
+        }
+      }
+    }
+    sc.runJob(this, reducePartition, mergeResult)
+    // Get the final result out of our Option, or throw an exception if the RDD was empty
+    jobResult.getOrElse(throw new UnsupportedOperationException("empty collection"))
   }

   /**
@@ -408,9 +410,13 @@ abstract class RDD[T: ClassManifest](
    * modify t2.
    */
   def fold(zeroValue: T)(op: (T, T) => T): T = {
+    // Clone the zero value since we will also be serializing it as part of tasks
+    var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
     val cleanOp = sc.clean(op)
-    val results = sc.runJob(this, (iter: Iterator[T]) => iter.fold(zeroValue)(cleanOp))
-    return results.fold(zeroValue)(cleanOp)
+    val foldPartition = (iter: Iterator[T]) => iter.fold(zeroValue)(cleanOp)
+    val mergeResult = (index: Int, taskResult: T) => jobResult = op(jobResult, taskResult)
+    sc.runJob(this, foldPartition, mergeResult)
+    jobResult
   }

   /**
@@ -422,11 +428,14 @@ abstract class RDD[T: ClassManifest](
    * allocation.
    */
   def aggregate[U: ClassManifest](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = {
+    // Clone the zero value since we will also be serializing it as part of tasks
+    var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
     val cleanSeqOp = sc.clean(seqOp)
     val cleanCombOp = sc.clean(combOp)
-    val results = sc.runJob(this,
-        (iter: Iterator[T]) => iter.aggregate(zeroValue)(cleanSeqOp, cleanCombOp))
-    return results.fold(zeroValue)(cleanCombOp)
+    val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
+    val mergeResult = (index: Int, taskResult: U) => jobResult = combOp(jobResult, taskResult)
+    sc.runJob(this, aggregatePartition, mergeResult)
+    jobResult
   }

   /**
@@ -437,7 +446,7 @@ abstract class RDD[T: ClassManifest](
       var result = 0L
       while (iter.hasNext) {
         result += 1L
-        iter.next
+        iter.next()
       }
       result
     }).sum
@@ -452,7 +461,7 @@ abstract class RDD[T: ClassManifest](
       var result = 0L
       while (iter.hasNext) {
         result += 1L
-        iter.next
+        iter.next()
       }
       result
     }

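Note: the rewritten `reduce`/`fold`/`aggregate` no longer materialize an array of per-partition results and fold it afterwards; each task's result is merged into a single driver-side value as it arrives. A minimal sketch of the shape of the change, with local iterators standing in for partitions and a plain loop standing in for `runJob`:

```scala
// Instead of collecting every partition's result and folding the array,
// merge each result into the running value as soon as it is available.
def foldLocally[T](partitions: Seq[Iterator[T]], zero: T)(op: (T, T) => T): T = {
  var jobResult = zero  // in Spark this is a serialized clone of the zero value
  val foldPartition = (iter: Iterator[T]) => iter.foldLeft(zero)(op)
  val mergeResult = (index: Int, taskResult: T) => jobResult = op(jobResult, taskResult)
  partitions.zipWithIndex.foreach { case (iter, i) => mergeResult(i, foldPartition(iter)) }
  jobResult
}

// foldLocally(Seq(Iterator(1, 2), Iterator(3, 4)), 0)(_ + _) == 10
```
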
@@ -46,6 +46,7 @@ import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, C
 import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
 import storage.BlockManagerUI
 import util.{MetadataCleaner, TimeStampedHashMap}
+import storage.{StorageStatus, StorageUtils, RDDInfo}

 /**
  * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
@@ -107,8 +108,9 @@ class SparkContext(

   // Environment variables to pass to our executors
   private[spark] val executorEnvs = HashMap[String, String]()
+  // Note: SPARK_MEM is included for Mesos, but overwritten for standalone mode in ExecutorRunner
   for (key <- Seq("SPARK_MEM", "SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS",
-       "SPARK_TESTING")) {
+      "SPARK_TESTING")) {
     val value = System.getenv(key)
     if (value != null) {
       executorEnvs(key) = value
@@ -187,6 +189,7 @@ class SparkContext(
   taskScheduler.start()

   private var dagScheduler = new DAGScheduler(taskScheduler)
+  dagScheduler.start()

   /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
   val hadoopConfiguration = {
@@ -467,12 +470,27 @@ class SparkContext(
    * Return a map from the slave to the max memory available for caching and the remaining
    * memory available for caching.
    */
-  def getSlavesMemoryStatus: Map[String, (Long, Long)] = {
+  def getExecutorMemoryStatus: Map[String, (Long, Long)] = {
     env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) =>
       (blockManagerId.ip + ":" + blockManagerId.port, mem)
     }
   }

+  /**
+   * Return information about what RDDs are cached, if they are in mem or on disk, how much space
+   * they take, etc.
+   */
+  def getRDDStorageInfo : Array[RDDInfo] = {
+    StorageUtils.rddInfoFromStorageStatus(getExecutorStorageStatus, this)
+  }
+
+  /**
+   * Return information about blocks stored in all of the slaves
+   */
+  def getExecutorStorageStatus : Array[StorageStatus] = {
+    env.blockManager.master.getStorageStatus
+  }
+
   /**
    * Clear the job's list of files added by `addFile` so that they do not get downloaded to
    * any new nodes.
@@ -543,10 +561,30 @@ class SparkContext(
   }

   /**
-   * Run a function on a given set of partitions in an RDD and return the results. This is the main
-   * entry point to the scheduler, by which all actions get launched. The allowLocal flag specifies
-   * whether the scheduler can run the computation on the driver rather than shipping it out to the
-   * cluster, for short actions like first().
+   * Run a function on a given set of partitions in an RDD and pass the results to the given
+   * handler function. This is the main entry point for all actions in Spark. The allowLocal
+   * flag specifies whether the scheduler can run the computation on the driver rather than
+   * shipping it out to the cluster, for short actions like first().
    */
+  def runJob[T, U: ClassManifest](
+      rdd: RDD[T],
+      func: (TaskContext, Iterator[T]) => U,
+      partitions: Seq[Int],
+      allowLocal: Boolean,
+      resultHandler: (Int, U) => Unit) {
+    val callSite = Utils.getSparkCallSite
+    logInfo("Starting job: " + callSite)
+    val start = System.nanoTime
+    val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler)
+    logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
+    rdd.doCheckpoint()
+    result
+  }
+
+  /**
+   * Run a function on a given set of partitions in an RDD and return the results as an array. The
+   * allowLocal flag specifies whether the scheduler can run the computation on the driver rather
+   * than shipping it out to the cluster, for short actions like first().
+   */
   def runJob[T, U: ClassManifest](
       rdd: RDD[T],
@@ -554,13 +592,9 @@ class SparkContext(
       partitions: Seq[Int],
       allowLocal: Boolean
       ): Array[U] = {
-    val callSite = Utils.getSparkCallSite
-    logInfo("Starting job: " + callSite)
-    val start = System.nanoTime
-    val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal)
-    logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
-    rdd.doCheckpoint()
-    result
+    val results = new Array[U](partitions.size)
+    runJob[T, U](rdd, func, partitions, allowLocal, (index, res) => results(index) = res)
+    results
   }

   /**
@@ -590,6 +624,29 @@ class SparkContext(
     runJob(rdd, func, 0 until rdd.splits.size, false)
   }

+  /**
+   * Run a job on all partitions in an RDD and pass the results to a handler function.
+   */
+  def runJob[T, U: ClassManifest](
+      rdd: RDD[T],
+      processPartition: (TaskContext, Iterator[T]) => U,
+      resultHandler: (Int, U) => Unit)
+  {
+    runJob[T, U](rdd, processPartition, 0 until rdd.splits.size, false, resultHandler)
+  }
+
+  /**
+   * Run a job on all partitions in an RDD and pass the results to a handler function.
+   */
+  def runJob[T, U: ClassManifest](
+      rdd: RDD[T],
+      processPartition: Iterator[T] => U,
+      resultHandler: (Int, U) => Unit)
+  {
+    val processFunc = (context: TaskContext, iter: Iterator[T]) => processPartition(iter)
+    runJob[T, U](rdd, processFunc, 0 until rdd.splits.size, false, resultHandler)
+  }
+
   /**
    * Run a job that can return approximate results.
    */

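Note: the handler-based overloads let an action consume each partition's result as its task finishes, instead of buffering everything into an `Array[U]` first. A hedged usage sketch against the overload added above (`sc` and a 4-split `rdd` of strings are assumed to be in scope):

```scala
// Hypothetical caller: collect per-partition record counts through the handler.
val counts = new Array[Long](4)  // one slot per split of the assumed rdd
sc.runJob(
  rdd,
  (iter: Iterator[String]) => iter.size.toLong,         // processPartition
  (index: Int, result: Long) => counts(index) = result  // resultHandler
)
```
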
@@ -12,6 +12,7 @@ import scala.io.Source
 import com.google.common.io.Files
 import com.google.common.util.concurrent.ThreadFactoryBuilder
 import scala.Some
+import spark.serializer.SerializerInstance

 /**
  * Various utility methods used by Spark.
@@ -446,4 +447,11 @@ private object Utils extends Logging {
     socket.close()
     portBound
   }
+
+  /**
+   * Clone an object using a Spark serializer.
+   */
+  def clone[T](value: T, serializer: SerializerInstance): T = {
+    serializer.deserialize[T](serializer.serialize(value))
+  }
 }

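Note: `fold`/`aggregate` serialize the zero value into every task, so the driver's running result must start from a private copy rather than alias the captured zero. `Utils.clone` gets that copy by round-tripping through the serializer; the same idea with plain Java serialization standing in for Spark's `SerializerInstance`:

```scala
import java.io._

// Serialize then deserialize to obtain a deep, independent copy of the value.
def cloneViaSerialization[T <: Serializable](value: T): T = {
  val bytes = new ByteArrayOutputStream()
  val out = new ObjectOutputStream(bytes)
  out.writeObject(value)
  out.close()
  val in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
  in.readObject().asInstanceOf[T]
}
```
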
@@ -18,35 +18,23 @@ import scala.collection.mutable.ArrayBuffer
 private[spark]
 class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) extends Logging {

-  val localIpAddress = Utils.localIpAddress
+  private val localIpAddress = Utils.localIpAddress
+  private val masterActorSystems = ArrayBuffer[ActorSystem]()
+  private val workerActorSystems = ArrayBuffer[ActorSystem]()

-  var masterActor : ActorRef = _
-  var masterActorSystem : ActorSystem = _
-  var masterPort : Int = _
-  var masterUrl : String = _
-
-  val workerActorSystems = ArrayBuffer[ActorSystem]()
-  val workerActors = ArrayBuffer[ActorRef]()
-
-  def start() : String = {
+  def start(): String = {
     logInfo("Starting a local Spark cluster with " + numWorkers + " workers.")

     /* Start the Master */
-    val (actorSystem, masterPort) = AkkaUtils.createActorSystem("sparkMaster", localIpAddress, 0)
-    masterActorSystem = actorSystem
-    masterUrl = "spark://" + localIpAddress + ":" + masterPort
-    masterActor = masterActorSystem.actorOf(
-      Props(new Master(localIpAddress, masterPort, 0)), name = "Master")
+    val (masterSystem, masterPort) = Master.startSystemAndActor(localIpAddress, 0, 0)
+    masterActorSystems += masterSystem
+    val masterUrl = "spark://" + localIpAddress + ":" + masterPort

-    /* Start the Slaves */
+    /* Start the Workers */
     for (workerNum <- 1 to numWorkers) {
-      val (actorSystem, boundPort) =
-        AkkaUtils.createActorSystem("sparkWorker" + workerNum, localIpAddress, 0)
-      workerActorSystems += actorSystem
-      val actor = actorSystem.actorOf(
-        Props(new Worker(localIpAddress, boundPort, 0, coresPerWorker, memoryPerWorker, masterUrl)),
-        name = "Worker")
-      workerActors += actor
+      val (workerSystem, _) = Worker.startSystemAndActor(localIpAddress, 0, 0, coresPerWorker,
+        memoryPerWorker, masterUrl, null, Some(workerNum))
+      workerActorSystems += workerSystem
     }

     return masterUrl
@@ -57,7 +45,7 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I
     // Stop the workers before the master so they don't get upset that it disconnected
     workerActorSystems.foreach(_.shutdown())
     workerActorSystems.foreach(_.awaitTermination())
-    masterActorSystem.shutdown()
-    masterActorSystem.awaitTermination()
+    masterActorSystems.foreach(_.shutdown())
+    masterActorSystems.foreach(_.awaitTermination())
   }
 }

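Note: with the actor-system bookkeeping moved into private buffers, a test can start and tear down an in-process cluster without reaching into its internals. A hedged usage sketch (constructor as in the hunk above; the `stop()` method is assumed from the second hunk's shutdown logic):

```scala
// Hypothetical test driver for LocalSparkCluster.
val cluster = new LocalSparkCluster(numWorkers = 2, coresPerWorker = 1, memoryPerWorker = 512)
val masterUrl = cluster.start()  // e.g. "spark://127.0.0.1:<boundPort>"
try {
  // ... point a SparkContext at masterUrl and run a job ...
} finally {
  cluster.stop()  // workers shut down first, then all master actor systems
}
```
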
@@ -9,6 +9,7 @@ import spark.{SparkException, Logging}
 import akka.remote.RemoteClientLifeCycleEvent
 import akka.remote.RemoteClientShutdown
 import spark.deploy.RegisterJob
+import spark.deploy.master.Master
 import akka.remote.RemoteClientDisconnected
 import akka.actor.Terminated
 import akka.dispatch.Await
@@ -24,26 +25,18 @@ private[spark] class Client(
     listener: ClientListener)
   extends Logging {

-  val MASTER_REGEX = "spark://([^:]+):([0-9]+)".r
-
   var actor: ActorRef = null
   var jobId: String = null

-  if (MASTER_REGEX.unapplySeq(masterUrl) == None) {
-    throw new SparkException("Invalid master URL: " + masterUrl)
-  }
-
   class ClientActor extends Actor with Logging {
     var master: ActorRef = null
     var masterAddress: Address = null
     var alreadyDisconnected = false  // To avoid calling listener.disconnected() multiple times

     override def preStart() {
-      val Seq(masterHost, masterPort) = MASTER_REGEX.unapplySeq(masterUrl).get
-      logInfo("Connecting to master spark://" + masterHost + ":" + masterPort)
-      val akkaUrl = "akka://spark@%s:%s/user/Master".format(masterHost, masterPort)
+      logInfo("Connecting to master " + masterUrl)
       try {
-        master = context.actorFor(akkaUrl)
+        master = context.actorFor(Master.toAkkaUrl(masterUrl))
         masterAddress = master.path.address
         master ! RegisterJob(jobDescription)
         context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])

@@ -262,11 +262,29 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
 }

 private[spark] object Master {
+  private val systemName = "sparkMaster"
+  private val actorName = "Master"
+  private val sparkUrlRegex = "spark://([^:]+):([0-9]+)".r
+
   def main(argStrings: Array[String]) {
     val args = new MasterArguments(argStrings)
-    val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", args.ip, args.port)
-    val actor = actorSystem.actorOf(
-      Props(new Master(args.ip, boundPort, args.webUiPort)), name = "Master")
+    val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort)
     actorSystem.awaitTermination()
   }
+
+  /** Returns an `akka://...` URL for the Master actor given a sparkUrl `spark://host:ip`. */
+  def toAkkaUrl(sparkUrl: String): String = {
+    sparkUrl match {
+      case sparkUrlRegex(host, port) =>
+        "akka://%s@%s:%s/user/%s".format(systemName, host, port, actorName)
+      case _ =>
+        throw new SparkException("Invalid master URL: " + sparkUrl)
+    }
+  }
+
+  def startSystemAndActor(host: String, port: Int, webUiPort: Int): (ActorSystem, Int) = {
+    val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port)
+    val actor = actorSystem.actorOf(Props(new Master(host, boundPort, webUiPort)), name = actorName)
+    (actorSystem, boundPort)
+  }
 }

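Note: `toAkkaUrl` centralizes the `spark://` URL parsing that `Client` and `Worker` previously duplicated with their own `MASTER_REGEX` copies, so a malformed URL now fails in one place. What it computes, per the regex and format string above:

```scala
// "spark://host:7077" -> "akka://sparkMaster@host:7077/user/Master"
val url = Master.toAkkaUrl("spark://localhost:7077")
// url == "akka://sparkMaster@localhost:7077/user/Master"
// Master.toAkkaUrl("not-a-url") throws SparkException("Invalid master URL: ...")
```
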
@@ -45,13 +45,9 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct
               case (jobId, Some(js)) if (js.equalsIgnoreCase("json")) =>
                 val future = master ? RequestMasterState
                 val jobInfo = for (masterState <- future.mapTo[MasterState]) yield {
-                  masterState.activeJobs.find(_.id == jobId) match {
-                    case Some(job) => job
-                    case _ => masterState.completedJobs.find(_.id == jobId) match {
-                      case Some(job) => job
-                      case _ => null
-                    }
-                  }
+                  masterState.activeJobs.find(_.id == jobId).getOrElse({
+                    masterState.completedJobs.find(_.id == jobId).getOrElse(null)
+                  })
                 }
                 respondWithMediaType(MediaTypes.`application/json`) { ctx =>
                   ctx.complete(jobInfo.mapTo[JobInfo])
@@ -61,14 +57,10 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct
                 val future = master ? RequestMasterState
                 future.map { state =>
                   val masterState = state.asInstanceOf[MasterState]

-                  masterState.activeJobs.find(_.id == jobId) match {
-                    case Some(job) => spark.deploy.master.html.job_details.render(job)
-                    case _ => masterState.completedJobs.find(_.id == jobId) match {
-                      case Some(job) => spark.deploy.master.html.job_details.render(job)
-                      case _ => null
-                    }
-                  }
+                  val job = masterState.activeJobs.find(_.id == jobId).getOrElse({
+                    masterState.completedJobs.find(_.id == jobId).getOrElse(null)
+                  })
+                  spark.deploy.master.html.job_details.render(job)
                 }
               }

@@ -113,8 +113,7 @@ private[spark] class ExecutorRunner(
     for ((key, value) <- jobDesc.command.environment) {
       env.put(key, value)
     }
-    env.put("SPARK_CORES", cores.toString)
-    env.put("SPARK_MEMORY", memory.toString)
+    env.put("SPARK_MEM", memory.toString + "m")
     // In case we are running this from within the Spark Shell, avoid creating a "scala"
     // parent process for the executor command
     env.put("SPARK_LAUNCH_WITH_SCALA", "0")

@@ -1,7 +1,7 @@
 package spark.deploy.worker

 import scala.collection.mutable.{ArrayBuffer, HashMap}
-import akka.actor.{ActorRef, Props, Actor}
+import akka.actor.{ActorRef, Props, Actor, ActorSystem}
 import spark.{Logging, Utils}
 import spark.util.AkkaUtils
 import spark.deploy._
@@ -13,6 +13,7 @@ import akka.remote.RemoteClientDisconnected
 import spark.deploy.RegisterWorker
 import spark.deploy.LaunchExecutor
 import spark.deploy.RegisterWorkerFailed
+import spark.deploy.master.Master
 import akka.actor.Terminated
 import java.io.File

@@ -27,7 +28,6 @@ private[spark] class Worker(
   extends Actor with Logging {

   val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs
-  val MASTER_REGEX = "spark://([^:]+):([0-9]+)".r

   var master: ActorRef = null
   var masterWebUiUrl : String = ""
@@ -48,11 +48,7 @@ private[spark] class Worker(
   def memoryFree: Int = memory - memoryUsed

   def createWorkDir() {
-    workDir = if (workDirPath != null) {
-      new File(workDirPath)
-    } else {
-      new File(sparkHome, "work")
-    }
+    workDir = Option(workDirPath).map(new File(_)).getOrElse(new File(sparkHome, "work"))
     try {
       if (!workDir.exists() && !workDir.mkdirs()) {
         logError("Failed to create work directory " + workDir)
@@ -68,8 +64,7 @@ private[spark] class Worker(
   override def preStart() {
     logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format(
       ip, port, cores, Utils.memoryMegabytesToString(memory)))
-    val envVar = System.getenv("SPARK_HOME")
-    sparkHome = new File(if (envVar == null) "." else envVar)
+    sparkHome = new File(Option(System.getenv("SPARK_HOME")).getOrElse("."))
     logInfo("Spark home: " + sparkHome)
     createWorkDir()
     connectToMaster()
@@ -77,24 +72,15 @@ private[spark] class Worker(
   }

   def connectToMaster() {
-    masterUrl match {
-      case MASTER_REGEX(masterHost, masterPort) => {
-        logInfo("Connecting to master spark://" + masterHost + ":" + masterPort)
-        val akkaUrl = "akka://spark@%s:%s/user/Master".format(masterHost, masterPort)
-        try {
-          master = context.actorFor(akkaUrl)
-          master ! RegisterWorker(workerId, ip, port, cores, memory, webUiPort, publicAddress)
-          context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
-          context.watch(master) // Doesn't work with remote actors, but useful for testing
-        } catch {
-          case e: Exception =>
-            logError("Failed to connect to master", e)
-            System.exit(1)
-        }
-      }
-
-      case _ =>
-        logError("Invalid master URL: " + masterUrl)
+    logInfo("Connecting to master " + masterUrl)
+    try {
+      master = context.actorFor(Master.toAkkaUrl(masterUrl))
+      master ! RegisterWorker(workerId, ip, port, cores, memory, webUiPort, publicAddress)
+      context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
+      context.watch(master) // Doesn't work with remote actors, but useful for testing
+    } catch {
+      case e: Exception =>
+        logError("Failed to connect to master", e)
+        System.exit(1)
     }
   }
@@ -183,11 +169,19 @@ private[spark] class Worker(
 private[spark] object Worker {
   def main(argStrings: Array[String]) {
     val args = new WorkerArguments(argStrings)
-    val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", args.ip, args.port)
-    val actor = actorSystem.actorOf(
-      Props(new Worker(args.ip, boundPort, args.webUiPort, args.cores, args.memory,
-        args.master, args.workDir)),
-      name = "Worker")
+    val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort, args.cores,
+      args.memory, args.master, args.workDir)
     actorSystem.awaitTermination()
   }
+
+  def startSystemAndActor(host: String, port: Int, webUiPort: Int, cores: Int, memory: Int,
+    masterUrl: String, workDir: String, workerNumber: Option[Int] = None): (ActorSystem, Int) = {
+    // The LocalSparkCluster runs multiple local sparkWorkerX actor systems
+    val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("")
+    val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port)
+    val actor = actorSystem.actorOf(Props(new Worker(host, boundPort, webUiPort, cores, memory,
+      masterUrl, workDir)), name = "Worker")
+    (actorSystem, boundPort)
+  }
+
 }

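Note: `Worker.startSystemAndActor` derives a unique actor-system name per local worker, which is what lets `LocalSparkCluster` host several workers in one JVM. The naming scheme from the hunk above, isolated:

```scala
// LocalSparkCluster passes Some(workerNum); a standalone worker passes None.
def workerSystemName(workerNumber: Option[Int]): String =
  "sparkWorker" + workerNumber.map(_.toString).getOrElse("")

// workerSystemName(Some(1)) == "sparkWorker1"
// workerSystemName(None)    == "sparkWorker"
```
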
@@ -32,7 +32,7 @@ private[spark] class ApproximateActionListener[T, U, R](
       if (finishedTasks == totalTasks) {
         // If we had already returned a PartialResult, set its final value
         resultObject.foreach(r => r.setFinalValue(evaluator.currentResult()))
-        // Notify any waiting thread that may have called getResult
+        // Notify any waiting thread that may have called awaitResult
         this.notifyAll()
       }
     }
@@ -49,7 +49,7 @@ private[spark] class ApproximateActionListener[T, U, R](
    * Waits for up to timeout milliseconds since the listener was created and then returns a
    * PartialResult with the result so far. This may be complete if the whole job is done.
    */
-  def getResult(): PartialResult[R] = synchronized {
+  def awaitResult(): PartialResult[R] = synchronized {
     val finishTime = startTime + timeout
     while (true) {
       val time = System.currentTimeMillis()

@@ -1,24 +1,42 @@
 package spark.rdd

-import spark.{PruneDependency, RDD, SparkEnv, Split, TaskContext}
+import spark.{NarrowDependency, RDD, SparkEnv, Split, TaskContext}


+class PartitionPruningRDDSplit(idx: Int, val parentSplit: Split) extends Split {
+  override val index = idx
+}
+
+
+/**
+ * Represents a dependency between the PartitionPruningRDD and its parent. In this
+ * case, the child RDD contains a subset of partitions of the parents'.
+ */
+class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boolean)
+  extends NarrowDependency[T](rdd) {
+
+  @transient
+  val partitions: Array[Split] = rdd.splits.filter(s => partitionFilterFunc(s.index))
+    .zipWithIndex.map { case(split, idx) => new PartitionPruningRDDSplit(idx, split) : Split }
+
+  override def getParents(partitionId: Int) = List(partitions(partitionId).index)
+}
+
+
 /**
  * A RDD used to prune RDD partitions/splits so we can avoid launching tasks on
  * all partitions. An example use case: If we know the RDD is partitioned by range,
  * and the execution DAG has a filter on the key, we can avoid launching tasks
  * on partitions that don't have the range covering the key.
- *
- * TODO: This currently doesn't give partition IDs properly!
  */
 class PartitionPruningRDD[T: ClassManifest](
     @transient prev: RDD[T],
     @transient partitionFilterFunc: Int => Boolean)
   extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) {

-  override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(split, context)
+  override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(
+    split.asInstanceOf[PartitionPruningRDDSplit].parentSplit, context)

-  override protected def getSplits = getDependencies.head.asInstanceOf[PruneDependency[T]].partitions
+  override protected def getSplits =
+    getDependencies.head.asInstanceOf[PruneDependency[T]].partitions

   override val partitioner = firstParent[T].partitioner
 }

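Note: the new `PartitionPruningRDDSplit` renumbers the surviving partitions `0..n-1` while remembering the parent split, which is what fixes the partition-ID problem; `compute` unwraps the split to read the right parent partition. A hedged usage sketch (`sc` assumed in scope):

```scala
// Hypothetical: keep only even-numbered parent partitions.
val rdd = sc.parallelize(1 to 100, 10)
val pruned = new PartitionPruningRDD(rdd, partitionIndex => partitionIndex % 2 == 0)
// pruned has 5 splits with indices 0..4, each delegating to parent splits 0, 2, 4, 6, 8
```
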
@@ -23,7 +23,16 @@ import util.{MetadataCleaner, TimeStampedHashMap}
  * and to report fetch failures (the submitTasks method, and code to add CompletionEvents).
  */
 private[spark]
-class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with Logging {
+class DAGScheduler(
+    taskSched: TaskScheduler,
+    mapOutputTracker: MapOutputTracker,
+    blockManagerMaster: BlockManagerMaster,
+    env: SparkEnv)
+  extends TaskSchedulerListener with Logging {
+
+  def this(taskSched: TaskScheduler) {
+    this(taskSched, SparkEnv.get.mapOutputTracker, SparkEnv.get.blockManager.master, SparkEnv.get)
+  }
   taskSched.setListener(this)

   // Called by TaskScheduler to report task completions or failures.
@@ -66,10 +75,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with

   var cacheLocs = new HashMap[Int, Array[List[String]]]

-  val env = SparkEnv.get
-  val mapOutputTracker = env.mapOutputTracker
-  val blockManagerMaster = env.blockManager.master
-
   // For tracking failed nodes, we use the MapOutputTracker's generation number, which is
   // sent with every task. When we detect a node failing, we note the current generation number
   // and failed executor, increment it for new tasks, and use this to ignore stray ShuffleMapTask
@@ -90,14 +95,16 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
   val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup)

   // Start a thread to run the DAGScheduler event loop
-  new Thread("DAGScheduler") {
-    setDaemon(true)
-    override def run() {
-      DAGScheduler.this.run()
-    }
-  }.start()
+  def start() {
+    new Thread("DAGScheduler") {
+      setDaemon(true)
+      override def run() {
+        DAGScheduler.this.run()
+      }
+    }.start()
+  }

-  def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
+  private def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
     if (!cacheLocs.contains(rdd.id)) {
       val blockIds = rdd.splits.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray
       cacheLocs(rdd.id) = blockManagerMaster.getLocations(blockIds).map {
@@ -107,7 +114,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
     cacheLocs(rdd.id)
   }

-  def clearCacheLocs() {
+  private def clearCacheLocs() {
     cacheLocs.clear()
   }

@@ -116,7 +123,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
    * The priority value passed in will be used if the stage doesn't already exist with
    * a lower priority (we assume that priorities always increase across jobs for now).
    */
-  def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], priority: Int): Stage = {
+  private def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], priority: Int): Stage = {
     shuffleToMapStage.get(shuffleDep.shuffleId) match {
       case Some(stage) => stage
       case None =>
@@ -131,11 +138,11 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
    * as a result stage for the final RDD used directly in an action. The stage will also be given
    * the provided priority.
    */
-  def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_]], priority: Int): Stage = {
-    // Kind of ugly: need to register RDDs with the cache and map output tracker here
-    // since we can't do it in the RDD constructor because # of splits is unknown
-    logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")")
+  private def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_]], priority: Int): Stage = {
     if (shuffleDep != None) {
+      // Kind of ugly: need to register RDDs with the cache and map output tracker here
+      // since we can't do it in the RDD constructor because # of splits is unknown
+      logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")")
       mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size)
     }
     val id = nextStageId.getAndIncrement()
@@ -148,7 +155,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
    * Get or create the list of parent stages for a given RDD. The stages will be assigned the
    * provided priority if they haven't already been created with a lower priority.
    */
-  def getParentStages(rdd: RDD[_], priority: Int): List[Stage] = {
+  private def getParentStages(rdd: RDD[_], priority: Int): List[Stage] = {
     val parents = new HashSet[Stage]
     val visited = new HashSet[RDD[_]]
     def visit(r: RDD[_]) {
@@ -170,25 +177,22 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
     parents.toList
   }

-  def getMissingParentStages(stage: Stage): List[Stage] = {
+  private def getMissingParentStages(stage: Stage): List[Stage] = {
     val missing = new HashSet[Stage]
     val visited = new HashSet[RDD[_]]
     def visit(rdd: RDD[_]) {
       if (!visited(rdd)) {
         visited += rdd
-        val locs = getCacheLocs(rdd)
-        for (p <- 0 until rdd.splits.size) {
-          if (locs(p) == Nil) {
-            for (dep <- rdd.dependencies) {
-              dep match {
-                case shufDep: ShuffleDependency[_,_] =>
-                  val mapStage = getShuffleMapStage(shufDep, stage.priority)
-                  if (!mapStage.isAvailable) {
-                    missing += mapStage
-                  }
-                case narrowDep: NarrowDependency[_] =>
-                  visit(narrowDep.rdd)
-              }
+        if (getCacheLocs(rdd).contains(Nil)) {
+          for (dep <- rdd.dependencies) {
+            dep match {
+              case shufDep: ShuffleDependency[_,_] =>
+                val mapStage = getShuffleMapStage(shufDep, stage.priority)
+                if (!mapStage.isAvailable) {
+                  missing += mapStage
+                }
+              case narrowDep: NarrowDependency[_] =>
+                visit(narrowDep.rdd)
            }
          }
        }
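Note: moving the collaborators from `SparkEnv.get` fields into constructor parameters (with an auxiliary `this(taskSched)` for production) is a standard seam for unit tests — a test can hand the scheduler stubbed dependencies. The pattern in miniature, with stand-in types since the real collaborators need a running cluster:

```scala
// Primary constructor takes collaborators; auxiliary constructor fills them in
// from a global default for production use.
class Scheduler(backend: String, tracker: String) {  // String = stand-in types
  def this() = this(Scheduler.defaultBackend, Scheduler.defaultTracker)
}
object Scheduler {
  val defaultBackend = "cluster-backend"
  val defaultTracker = "map-output-tracker"
}
// Production: new Scheduler(); tests: new Scheduler("stub-backend", "stub-tracker")
```
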
@@ -198,23 +202,45 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
     missing.toList
   }

+  /**
+   * Returns (and does not submit) a JobSubmitted event suitable to run a given job, and a
+   * JobWaiter whose getResult() method will return the result of the job when it is complete.
+   *
+   * The job is assumed to have at least one partition; zero partition jobs should be handled
+   * without a JobSubmitted event.
+   */
+  private[scheduler] def prepareJob[T, U: ClassManifest](
+      finalRdd: RDD[T],
+      func: (TaskContext, Iterator[T]) => U,
+      partitions: Seq[Int],
+      callSite: String,
+      allowLocal: Boolean,
+      resultHandler: (Int, U) => Unit)
+    : (JobSubmitted, JobWaiter[U]) =
+  {
+    assert(partitions.size > 0)
+    val waiter = new JobWaiter(partitions.size, resultHandler)
+    val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
+    val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter)
+    return (toSubmit, waiter)
+  }
+
   def runJob[T, U: ClassManifest](
       finalRdd: RDD[T],
       func: (TaskContext, Iterator[T]) => U,
       partitions: Seq[Int],
       callSite: String,
-      allowLocal: Boolean)
-    : Array[U] =
+      allowLocal: Boolean,
+      resultHandler: (Int, U) => Unit)
   {
     if (partitions.size == 0) {
-      return new Array[U](0)
+      return
     }
-    val waiter = new JobWaiter(partitions.size)
-    val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
-    eventQueue.put(JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter))
-    waiter.getResult() match {
-      case JobSucceeded(results: Seq[_]) =>
-        return results.asInstanceOf[Seq[U]].toArray
+    val (toSubmit, waiter) = prepareJob(
+      finalRdd, func, partitions, callSite, allowLocal, resultHandler)
+    eventQueue.put(toSubmit)
+    waiter.awaitResult() match {
+      case JobSucceeded => {}
       case JobFailed(exception: Exception) =>
        logInfo("Failed to run " + callSite)
        throw exception
@@ -233,90 +259,117 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
     val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
     val partitions = (0 until rdd.splits.size).toArray
     eventQueue.put(JobSubmitted(rdd, func2, partitions, false, callSite, listener))
-    return listener.getResult()    // Will throw an exception if the job fails
+    return listener.awaitResult()    // Will throw an exception if the job fails
   }

   /**
+   * Process one event retrieved from the event queue.
+   * Returns true if we should stop the event loop.
+   */
+  private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = {
+    event match {
+      case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener) =>
+        val runId = nextRunId.getAndIncrement()
+        val finalStage = newStage(finalRDD, None, runId)
+        val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener)
+        clearCacheLocs()
+        logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length +
+          " output partitions (allowLocal=" + allowLocal + ")")
+        logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")")
+        logInfo("Parents of final stage: " + finalStage.parents)
+        logInfo("Missing parents: " + getMissingParentStages(finalStage))
+        if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) {
+          // Compute very short actions like first() or take() with no parent stages locally.
+          runLocally(job)
+        } else {
+          activeJobs += job
+          resultStageToJob(finalStage) = job
+          submitStage(finalStage)
+        }
+
+      case ExecutorLost(execId) =>
+        handleExecutorLost(execId)
+
+      case completion: CompletionEvent =>
+        handleTaskCompletion(completion)
+
+      case TaskSetFailed(taskSet, reason) =>
+        abortStage(idToStage(taskSet.stageId), reason)
+
+      case StopDAGScheduler =>
+        // Cancel any active jobs
+        for (job <- activeJobs) {
+          val error = new SparkException("Job cancelled because SparkContext was shut down")
+          job.listener.jobFailed(error)
+        }
+        return true
+    }
+    return false
+  }
+
+  /**
+   * Resubmit any failed stages. Ordinarily called after a small amount of time has passed since
+   * the last fetch failure.
+   */
+  private[scheduler] def resubmitFailedStages() {
+    logInfo("Resubmitting failed stages")
+    clearCacheLocs()
+    val failed2 = failed.toArray
+    failed.clear()
+    for (stage <- failed2.sortBy(_.priority)) {
+      submitStage(stage)
+    }
+  }
+
+  /**
+   * Check for waiting or failed stages which are now eligible for resubmission.
+   * Ordinarily run on every iteration of the event loop.
+   */
+  private[scheduler] def submitWaitingStages() {
+    // TODO: We might want to run this less often, when we are sure that something has become
+    // runnable that wasn't before.
+    logTrace("Checking for newly runnable parent stages")
+    logTrace("running: " + running)
+    logTrace("waiting: " + waiting)
+    logTrace("failed: " + failed)
+    val waiting2 = waiting.toArray
+    waiting.clear()
+    for (stage <- waiting2.sortBy(_.priority)) {
+      submitStage(stage)
+    }
+  }
+
+
+  /**
    * The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure
    * events and responds by launching tasks. This runs in a dedicated thread and receives events
    * via the eventQueue.
    */
-  def run() {
+  private def run() {
     SparkEnv.set(env)

     while (true) {
       val event = eventQueue.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS)
-      val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability
       if (event != null) {
         logDebug("Got event of type " + event.getClass.getName)
       }

-      event match {
-        case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener) =>
-          val runId = nextRunId.getAndIncrement()
-          val finalStage = newStage(finalRDD, None, runId)
-          val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener)
-          clearCacheLocs()
-          logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length +
-            " output partitions")
-          logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")")
-          logInfo("Parents of final stage: " + finalStage.parents)
-          logInfo("Missing parents: " + getMissingParentStages(finalStage))
-          if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) {
-            // Compute very short actions like first() or take() with no parent stages locally.
-            runLocally(job)
-          } else {
-            activeJobs += job
-            resultStageToJob(finalStage) = job
-            submitStage(finalStage)
-          }
-
-        case ExecutorLost(execId) =>
-          handleExecutorLost(execId)
-
-        case completion: CompletionEvent =>
-          handleTaskCompletion(completion)
-
-        case TaskSetFailed(taskSet, reason) =>
-          abortStage(idToStage(taskSet.stageId), reason)
-
-        case StopDAGScheduler =>
-          // Cancel any active jobs
-          for (job <- activeJobs) {
-            val error = new SparkException("Job cancelled because SparkContext was shut down")
-            job.listener.jobFailed(error)
-          }
-          return
-
-        case null =>
-          // queue.poll() timed out, ignore it
+      if (event != null) {
+        if (processEvent(event)) {
+          return
+        }
       }

+      val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability
       // Periodically resubmit failed stages if some map output fetches have failed and we have
       // waited at least RESUBMIT_TIMEOUT. We wait for this short time because when a node fails,
       // tasks on many other nodes are bound to get a fetch failure, and they won't all get it at
       // the same time, so we want to make sure we've identified all the reduce tasks that depend
       // on the failed node.
       if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) {
-        logInfo("Resubmitting failed stages")
-        clearCacheLocs()
-        val failed2 = failed.toArray
-        failed.clear()
-        for (stage <- failed2.sortBy(_.priority)) {
-          submitStage(stage)
-        }
+        resubmitFailedStages()
       } else {
-        // TODO: We might want to run this less often, when we are sure that something has become
-        // runnable that wasn't before.
-        logTrace("Checking for newly runnable parent stages")
-        logTrace("running: " + running)
-        logTrace("waiting: " + waiting)
-        logTrace("failed: " + failed)
-        val waiting2 = waiting.toArray
-        waiting.clear()
-        for (stage <- waiting2.sortBy(_.priority)) {
-          submitStage(stage)
-        }
+        submitWaitingStages()
       }
     }
   }
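Note: extracting `processEvent` out of the polling loop means each event transition can be exercised directly from a test, without a background thread or a real queue. The refactoring in miniature, with stand-in event types:

```scala
sealed trait Event
case object Stop extends Event
case class Work(n: Int) extends Event

// The match lives in a method; returning true tells the loop to stop.
def processEvent(e: Event): Boolean = e match {
  case Work(n) => println("processed " + n); false
  case Stop    => true
}

// The loop itself only polls and delegates.
def run(queue: Iterator[Event]): Unit =
  while (queue.hasNext && !processEvent(queue.next())) ()

// A test can assert on processEvent(Work(1)) directly.
run(Iterator(Work(1), Work(2), Stop))
```
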
@@ -326,7 +379,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
    * We run the operation in a separate thread just in case it takes a bunch of time, so that we
    * don't block the DAGScheduler event loop or other concurrent jobs.
    */
-  def runLocally(job: ActiveJob) {
+  private def runLocally(job: ActiveJob) {
     logInfo("Computing the requested partition locally")
     new Thread("Local computation of job " + job.runId) {
       override def run() {
@@ -349,13 +402,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
     }.start()
   }

-  def submitStage(stage: Stage) {
+  /** Submits stage, but first recursively submits any missing parents. */
+  private def submitStage(stage: Stage) {
     logDebug("submitStage(" + stage + ")")
     if (!waiting(stage) && !running(stage) && !failed(stage)) {
       val missing = getMissingParentStages(stage).sortBy(_.id)
       logDebug("missing: " + missing)
       if (missing == Nil) {
-        logInfo("Submitting " + stage + " (" + stage.origin + "), which has no missing parents")
+        logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
         submitMissingTasks(stage)
         running += stage
       } else {
@@ -367,7 +421,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
     }
   }

-  def submitMissingTasks(stage: Stage) {
+  /** Called when stage's parents are available and we can now do its task. */
+  private def submitMissingTasks(stage: Stage) {
     logDebug("submitMissingTasks(" + stage + ")")
     // Get our pending tasks and remember them in our pendingTasks entry
     val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
@@ -388,7 +443,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
       }
     }
     if (tasks.size > 0) {
-      logInfo("Submitting " + tasks.size + " missing tasks from " + stage)
+      logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
       myPending ++= tasks
       logDebug("New pending tasks: " + myPending)
       taskSched.submitTasks(
@@ -407,7 +462,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
    * Responds to a task finishing. This is called inside the event loop so it assumes that it can
    * modify the scheduler's internal state. Use taskEnded() to post a task end event from outside.
    */
-  def handleTaskCompletion(event: CompletionEvent) {
+  private def handleTaskCompletion(event: CompletionEvent) {
     val task = event.task
     val stage = idToStage(task.stageId)

@@ -492,7 +547,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
         waiting --= newlyRunnable
         running ++= newlyRunnable
         for (stage <- newlyRunnable.sortBy(_.id)) {
-          logInfo("Submitting " + stage + " (" + stage.origin + "), which is now runnable")
+          logInfo("Submitting " + stage + " (" + stage.rdd + "), which is now runnable")
           submitMissingTasks(stage)
         }
       }
@@ -541,12 +596,12 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
    * Optionally the generation during which the failure was caught can be passed to avoid allowing
    * stray fetch failures from possibly retriggering the detection of a node as lost.
    */
-  def handleExecutorLost(execId: String, maybeGeneration: Option[Long] = None) {
+  private def handleExecutorLost(execId: String, maybeGeneration: Option[Long] = None) {
     val currentGeneration = maybeGeneration.getOrElse(mapOutputTracker.getGeneration)
     if (!failedGeneration.contains(execId) || failedGeneration(execId) < currentGeneration) {
       failedGeneration(execId) = currentGeneration
       logInfo("Executor lost: %s (generation %d)".format(execId, currentGeneration))
-      env.blockManager.master.removeExecutor(execId)
+      blockManagerMaster.removeExecutor(execId)
       // TODO: This will be really slow if we keep accumulating shuffle map stages
       for ((shuffleId, stage) <- shuffleToMapStage) {
         stage.removeOutputsOnExecutor(execId)
@@ -567,7 +622,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
    * Aborts all jobs depending on a particular Stage. This is called in response to a task set
    * being cancelled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
    */
-  def abortStage(failedStage: Stage, reason: String) {
+  private def abortStage(failedStage: Stage, reason: String) {
     val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
     for (resultStage <- dependentStages) {
       val job = resultStageToJob(resultStage)
@@ -583,7 +638,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
   /**
    * Return true if one of stage's ancestors is target.
    */
-  def stageDependsOn(stage: Stage, target: Stage): Boolean = {
+  private def stageDependsOn(stage: Stage, target: Stage): Boolean = {
     if (stage == target) {
       return true
     }
@@ -610,7 +665,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
     visitedRdds.contains(target.rdd)
   }

-  def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = {
+  private def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = {
     // If the partition is cached, return the cache locations
     val cached = getCacheLocs(rdd)(partition)
     if (cached != Nil) {
@@ -636,7 +691,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
     return Nil
   }

-  def cleanup(cleanupTime: Long) {
+  private def cleanup(cleanupTime: Long) {
     var sizeBefore = idToStage.size
     idToStage.clearOldValues(cleanupTime)
     logInfo("idToStage " + sizeBefore + " --> " + idToStage.size)

@@ -5,5 +5,5 @@ package spark.scheduler
  */
 private[spark] sealed trait JobResult

-private[spark] case class JobSucceeded(results: Seq[_]) extends JobResult
+private[spark] case object JobSucceeded extends JobResult
 private[spark] case class JobFailed(exception: Exception) extends JobResult

@@ -3,10 +3,12 @@ package spark.scheduler
 import scala.collection.mutable.ArrayBuffer

 /**
- * An object that waits for a DAGScheduler job to complete.
+ * An object that waits for a DAGScheduler job to complete. As tasks finish, it passes their
+ * results to the given handler function.
  */
-private[spark] class JobWaiter(totalTasks: Int) extends JobListener {
-  private val taskResults = ArrayBuffer.fill[Any](totalTasks)(null)
+private[spark] class JobWaiter[T](totalTasks: Int, resultHandler: (Int, T) => Unit)
+  extends JobListener {
+
   private var finishedTasks = 0

   private var jobFinished = false  // Is the job as a whole finished (succeeded or failed)?
@@ -17,11 +19,11 @@ private[spark] class JobWaiter(totalTasks: Int) extends JobListener {
       if (jobFinished) {
         throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter")
       }
-      taskResults(index) = result
+      resultHandler(index, result.asInstanceOf[T])
       finishedTasks += 1
       if (finishedTasks == totalTasks) {
         jobFinished = true
-        jobResult = JobSucceeded(taskResults)
+        jobResult = JobSucceeded
         this.notifyAll()
       }
     }
@@ -38,7 +40,7 @@ private[spark] class JobWaiter(totalTasks: Int) extends JobListener {
     }
   }

-  def getResult(): JobResult = synchronized {
+  def awaitResult(): JobResult = synchronized {
     while (!jobFinished) {
       this.wait()
     }

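Note: `JobWaiter` is now generic and streams each task's result straight to the handler, so `JobSucceeded` no longer needs to carry a `Seq[_]` of buffered results. The wait/notify handshake in miniature, same shape as `taskSucceeded`/`awaitResult` above:

```scala
class Waiter[T](totalTasks: Int, resultHandler: (Int, T) => Unit) {
  private var finishedTasks = 0

  // Called once per finished task: forward the result, wake the waiter when done.
  def taskSucceeded(index: Int, result: T): Unit = synchronized {
    resultHandler(index, result)
    finishedTasks += 1
    if (finishedTasks == totalTasks) notifyAll()
  }

  // Blocks until every task has reported in.
  def awaitResult(): Unit = synchronized {
    while (finishedTasks < totalTasks) wait()
  }
}
```
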
@@ -32,7 +32,7 @@ private[spark] object ShuffleMapTask {
       return old
     } else {
       val out = new ByteArrayOutputStream
-      val ser = SparkEnv.get.closureSerializer.newInstance
+      val ser = SparkEnv.get.closureSerializer.newInstance()
       val objOut = ser.serializeStream(new GZIPOutputStream(out))
       objOut.writeObject(rdd)
       objOut.writeObject(dep)
@@ -48,7 +48,7 @@ private[spark] object ShuffleMapTask {
     synchronized {
       val loader = Thread.currentThread.getContextClassLoader
       val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
-      val ser = SparkEnv.get.closureSerializer.newInstance
+      val ser = SparkEnv.get.closureSerializer.newInstance()
       val objIn = ser.deserializeStream(in)
       val rdd = objIn.readObject().asInstanceOf[RDD[_]]
       val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]]
@@ -127,7 +127,6 @@ private[spark] class ShuffleMapTask(
         val bucketId = dep.partitioner.getPartition(pair._1)
         buckets(bucketId) += pair
       }
-      val bucketIterators = buckets.map(_.iterator)

       val compressedSizes = new Array[Byte](numOutputSplits)

@@ -135,7 +134,7 @@ private[spark] class ShuffleMapTask(
       for (i <- 0 until numOutputSplits) {
         val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i
         // Get a Scala iterator from Java map
-        val iter: Iterator[(Any, Any)] = bucketIterators(i)
+        val iter: Iterator[(Any, Any)] = buckets(i).iterator
         val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false)
         compressedSizes(i) = MapOutputTracker.compressSize(size)
       }

@@ -86,7 +86,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
     }
   }

-  def submitTasks(taskSet: TaskSet) {
+  override def submitTasks(taskSet: TaskSet) {
     val tasks = taskSet.tasks
     logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
     this.synchronized {

@@ -1,5 +1,7 @@
 package spark.scheduler.cluster

+import spark.Utils
+
 /**
  * A backend interface for cluster scheduling systems that allows plugging in different ones under
  * ClusterScheduler. We assume a Mesos-like model where the application gets resource offers as
@@ -11,5 +13,15 @@ private[spark] trait SchedulerBackend {
   def reviveOffers(): Unit
   def defaultParallelism(): Int

+  // Memory used by each executor (in megabytes)
+  protected val executorMemory = {
+    // TODO: Might need to add some extra memory for the non-heap parts of the JVM
+    Option(System.getProperty("spark.executor.memory"))
+      .orElse(Option(System.getenv("SPARK_MEM")))
+      .map(Utils.memoryStringToMb)
+      .getOrElse(512)
+  }
+
+
   // TODO: Probably want to add a killTask too
 }

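Note: the three copies of the SPARK_MEM if/else (in the standalone and both Mesos backends, removed in the hunks below) collapse into this one Option chain, which also gives the `spark.executor.memory` property precedence over the environment variable. The precedence it encodes, with a simplified stand-in parser so the snippet runs outside Spark:

```scala
// Simplified stand-in for Utils.memoryStringToMb ("512m" -> 512, "2g" -> 2048).
def memoryStringToMb(s: String): Int =
  if (s.endsWith("g")) s.dropRight(1).toInt * 1024 else s.stripSuffix("m").toInt

val executorMemory =
  Option(System.getProperty("spark.executor.memory"))  // 1. explicit property
    .orElse(Option(System.getenv("SPARK_MEM")))        // 2. legacy env var
    .map(memoryStringToMb)
    .getOrElse(512)                                    // 3. default 512 MB
```
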
@@ -20,16 +20,6 @@ private[spark] class SparkDeploySchedulerBackend(

   val maxCores = System.getProperty("spark.cores.max", Int.MaxValue.toString).toInt

-  // Memory used by each executor (in megabytes)
-  val executorMemory = {
-    if (System.getenv("SPARK_MEM") != null) {
-      Utils.memoryStringToMb(System.getenv("SPARK_MEM"))
-      // TODO: Might need to add some extra memory for the non-heap parts of the JVM
-    } else {
-      512
-    }
-  }
-
   override def start() {
     super.start()

@ -17,10 +17,7 @@ import java.nio.ByteBuffer
|
|||
/**
|
||||
* Schedules the tasks within a single TaskSet in the ClusterScheduler.
|
||||
*/
|
||||
private[spark] class TaskSetManager(
|
||||
sched: ClusterScheduler,
|
||||
val taskSet: TaskSet)
|
||||
extends Logging {
|
||||
private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSet) extends Logging {
|
||||
|
||||
// Maximum time to wait to run a task in a preferred location (in ms)
|
||||
val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
|
||||
|
@ -100,7 +97,7 @@ private[spark] class TaskSetManager(
|
|||
}
|
||||
|
||||
// Add a task to all the pending-task lists that it should be on.
|
||||
def addPendingTask(index: Int) {
|
||||
private def addPendingTask(index: Int) {
|
||||
val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive
|
||||
if (locations.size == 0) {
|
||||
pendingTasksWithNoPrefs += index
|
||||
|
@ -115,7 +112,7 @@ private[spark] class TaskSetManager(
|
|||
|
||||
// Return the pending tasks list for a given host, or an empty list if
|
||||
// there is no map entry for that host
|
||||
def getPendingTasksForHost(host: String): ArrayBuffer[Int] = {
|
||||
private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = {
|
||||
pendingTasksForHost.getOrElse(host, ArrayBuffer())
|
||||
}
|
||||
|
||||
|
@ -123,7 +120,7 @@ private[spark] class TaskSetManager(
|
|||
// Return None if the list is empty.
|
||||
// This method also cleans up any tasks in the list that have already
|
||||
// been launched, since we want that to happen lazily.
|
||||
def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
|
||||
private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
|
||||
while (!list.isEmpty) {
|
||||
val index = list.last
|
||||
list.trimEnd(1)
|
||||
|
@ -137,7 +134,7 @@ private[spark] class TaskSetManager(
|
|||
// Return a speculative task for a given host if any are available. The task should not have an
|
||||
// attempt running on this host, in case the host is slow. In addition, if localOnly is set, the
|
||||
// task must have a preference for this host (or no preferred locations at all).
|
||||
def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = {
|
||||
private def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = {
|
||||
val hostsAlive = sched.hostsAlive
|
||||
speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set
|
||||
val localTask = speculatableTasks.find {
|
||||
|
@ -162,7 +159,7 @@ private[spark] class TaskSetManager(
|
|||
|
||||
// Dequeue a pending task for a given node and return its index.
|
||||
// If localOnly is set to false, allow non-local tasks as well.
|
||||
def findTask(host: String, localOnly: Boolean): Option[Int] = {
|
||||
private def findTask(host: String, localOnly: Boolean): Option[Int] = {
|
||||
val localTask = findTaskFromList(getPendingTasksForHost(host))
|
||||
if (localTask != None) {
|
||||
return localTask
|
||||
|
@ -184,7 +181,7 @@ private[spark] class TaskSetManager(
|
|||
// Does a host count as a preferred location for a task? This is true if
|
||||
// either the task has preferred locations and this host is one, or it has
|
||||
// no preferred locations (in which we still count the launch as preferred).
|
||||
def isPreferredLocation(task: Task[_], host: String): Boolean = {
|
||||
private def isPreferredLocation(task: Task[_], host: String): Boolean = {
|
||||
val locs = task.preferredLocations
|
||||
return (locs.contains(host) || locs.isEmpty)
|
||||
}
|
||||
|
@ -335,7 +332,7 @@ private[spark] class TaskSetManager(
|
|||
if (numFailures(index) > MAX_TASK_FAILURES) {
|
||||
logError("Task %s:%d failed more than %d times; aborting job".format(
|
||||
taskSet.id, index, MAX_TASK_FAILURES))
|
||||
abort("Task %d failed more than %d times".format(index, MAX_TASK_FAILURES))
|
||||
abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
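The methods made private above implement TaskSetManager's locality-wait policy: an offer from a task's preferred host is taken immediately, while a non-local launch is only allowed after spark.locality.wait milliseconds have passed since the last preferred launch. A standalone sketch of that idea, with illustrative names rather than TaskSetManager's actual fields:

// Minimal sketch (not the TaskSetManager code itself) of delay scheduling:
// prefer a task whose preferred hosts include the offering host, and fall
// back to an arbitrary pending task only after LOCALITY_WAIT ms have passed.
object LocalityWaitDemo {
  val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong

  case class PendingTask(index: Int, preferredHosts: Set[String])

  def findTask(
      pending: Seq[PendingTask],
      host: String,
      lastPreferredLaunchTime: Long): Option[PendingTask] = {
    // A task with no preferences counts as local anywhere.
    val local = pending.find(t => t.preferredHosts.isEmpty || t.preferredHosts(host))
    val waitedLongEnough = System.currentTimeMillis - lastPreferredLaunchTime > LOCALITY_WAIT
    local.orElse(if (waitedLongEnough) pending.headOption else None)
  }

  def main(args: Array[String]) {
    val pending = Seq(PendingTask(0, Set("hostA")), PendingTask(1, Set("hostB")))
    // Offer from hostB: task 1 is local and is picked immediately.
    println(findTask(pending, "hostB", System.currentTimeMillis))
  }
}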

@@ -53,7 +53,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
  }

  def runTask(task: Task[_], idInJob: Int, attemptId: Int) {
    logInfo("Running task " + idInJob)
    logInfo("Running " + task)
    // Set the Spark execution environment for the worker thread
    SparkEnv.set(env)
    try {

@@ -80,7 +80,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
        val resultToReturn = ser.deserialize[Any](ser.serialize(result))
        val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
          ser.serialize(Accumulators.values))
        logInfo("Finished task " + idInJob)
        logInfo("Finished " + task)

        // If the threadpool has not already been shutdown, notify DAGScheduler
        if (!Thread.currentThread().isInterrupted)

@@ -35,16 +35,6 @@ private[spark] class CoarseMesosSchedulerBackend(

  val MAX_SLAVE_FAILURES = 2 // Blacklist a slave after this many failures

  // Memory used by each executor (in megabytes)
  val executorMemory = {
    if (System.getenv("SPARK_MEM") != null) {
      Utils.memoryStringToMb(System.getenv("SPARK_MEM"))
      // TODO: Might need to add some extra memory for the non-heap parts of the JVM
    } else {
      512
    }
  }

  // Lock used to wait for scheduler to be registered
  var isRegistered = false
  val registeredLock = new Object()

@@ -29,16 +29,6 @@ private[spark] class MesosSchedulerBackend(
    with MScheduler
    with Logging {

  // Memory used by each executor (in megabytes)
  val EXECUTOR_MEMORY = {
    if (System.getenv("SPARK_MEM") != null) {
      Utils.memoryStringToMb(System.getenv("SPARK_MEM"))
      // TODO: Might need to add some extra memory for the non-heap parts of the JVM
    } else {
      512
    }
  }

  // Lock used to wait for scheduler to be registered
  var isRegistered = false
  val registeredLock = new Object()

@@ -89,7 +79,7 @@ private[spark] class MesosSchedulerBackend(
    val memory = Resource.newBuilder()
      .setName("mem")
      .setType(Value.Type.SCALAR)
      .setScalar(Value.Scalar.newBuilder().setValue(EXECUTOR_MEMORY).build())
      .setScalar(Value.Scalar.newBuilder().setValue(executorMemory).build())
      .build()
    val command = CommandInfo.newBuilder()
      .setValue(execScript)

@@ -161,7 +151,7 @@ private[spark] class MesosSchedulerBackend(
      def enoughMemory(o: Offer) = {
        val mem = getResource(o.getResourcesList, "mem")
        val slaveId = o.getSlaveId.getValue
        mem >= EXECUTOR_MEMORY || slaveIdsWithExecutors.contains(slaveId)
        mem >= executorMemory || slaveIdsWithExecutors.contains(slaveId)
      }

      for ((offer, index) <- offers.zipWithIndex if enoughMemory(offer)) {
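The renamed executorMemory is used twice in this backend: once when sizing the executor's Mesos resource request and once when filtering resource offers. An offer with too little memory is still usable if that slave already runs one of our executors, since the memory was accounted for when the executor started. A self-contained sketch of the filter, where Offer is a stand-in case class rather than the Mesos protobuf type:

// Sketch of the offer-filtering rule in enoughMemory above.
object OfferFilterDemo {
  case class Offer(slaveId: String, memMb: Int)

  val executorMemory = 512
  val slaveIdsWithExecutors = Set("slave-1")

  def enoughMemory(o: Offer): Boolean =
    o.memMb >= executorMemory || slaveIdsWithExecutors.contains(o.slaveId)

  def main(args: Array[String]) {
    val offers = Seq(Offer("slave-1", 128), Offer("slave-2", 1024), Offer("slave-3", 128))
    // Keeps slave-1 (existing executor) and slave-2 (enough memory); drops slave-3.
    println(offers.filter(enoughMemory))
  }
}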

@@ -950,6 +950,7 @@ class BlockManager(
    blockInfo.clear()
    memoryStore.clear()
    diskStore.clear()
    metadataCleaner.cancel()
    logInfo("BlockManager stopped")
  }
}

@@ -27,8 +27,6 @@ private[spark] class BlockManagerMaster(
  val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "3000").toInt

  val DRIVER_AKKA_ACTOR_NAME = "BlockMasterManager"
  val SLAVE_AKKA_ACTOR_NAME = "BlockSlaveManager"
  val DEFAULT_MANAGER_IP: String = Utils.localHostName()

  val timeout = 10.seconds
  var driverActor: ActorRef = {

@@ -117,6 +115,10 @@ private[spark] class BlockManagerMaster(
    askDriverWithReply[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus)
  }

  def getStorageStatus: Array[StorageStatus] = {
    askDriverWithReply[ArrayBuffer[StorageStatus]](GetStorageStatus).toArray
  }

  /** Stop the driver actor, called only on the Spark driver node */
  def stop() {
    if (driverActor != null) {

@@ -1,13 +1,10 @@
package spark.storage

import akka.actor.{ActorRef, ActorSystem}
import akka.pattern.ask
import akka.util.Timeout
import akka.util.duration._
import cc.spray.directives._
import cc.spray.typeconversion.TwirlSupport._
import cc.spray.Directives
import scala.collection.mutable.ArrayBuffer
import spark.{Logging, SparkContext}
import spark.util.AkkaUtils
import spark.Utils

@@ -48,32 +45,26 @@ class BlockManagerUI(val actorSystem: ActorSystem, blockManagerMaster: ActorRef,
      path("") {
        completeWith {
          // Request the current storage status from the Master
          val future = blockManagerMaster ? GetStorageStatus
          future.map { status =>
            // Calculate macro-level statistics
            val storageStatusList = status.asInstanceOf[ArrayBuffer[StorageStatus]].toArray
            val maxMem = storageStatusList.map(_.maxMem).reduce(_+_)
            val remainingMem = storageStatusList.map(_.memRemaining).reduce(_+_)
            val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize))
              .reduceOption(_+_).getOrElse(0L)
            val rdds = StorageUtils.rddInfoFromStorageStatus(storageStatusList, sc)
            spark.storage.html.index.
              render(maxMem, remainingMem, diskSpaceUsed, rdds, storageStatusList)
          }
          val storageStatusList = sc.getExecutorStorageStatus
          // Calculate macro-level statistics
          val maxMem = storageStatusList.map(_.maxMem).reduce(_+_)
          val remainingMem = storageStatusList.map(_.memRemaining).reduce(_+_)
          val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize))
            .reduceOption(_+_).getOrElse(0L)
          val rdds = StorageUtils.rddInfoFromStorageStatus(storageStatusList, sc)
          spark.storage.html.index.
            render(maxMem, remainingMem, diskSpaceUsed, rdds, storageStatusList)
        }
      } ~
      path("rdd") {
        parameter("id") { id =>
          completeWith {
            val future = blockManagerMaster ? GetStorageStatus
            future.map { status =>
              val prefix = "rdd_" + id.toString
              val storageStatusList = status.asInstanceOf[ArrayBuffer[StorageStatus]].toArray
              val filteredStorageStatusList = StorageUtils.
                filterStorageStatusByPrefix(storageStatusList, prefix)
              val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head
              spark.storage.html.rdd.render(rddInfo, filteredStorageStatusList)
            }
            val prefix = "rdd_" + id.toString
            val storageStatusList = sc.getExecutorStorageStatus
            val filteredStorageStatusList = StorageUtils.
              filterStorageStatusByPrefix(storageStatusList, prefix)
            val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head
            spark.storage.html.rdd.render(rddInfo, filteredStorageStatusList)
          }
        }
      } ~

@@ -1,6 +1,6 @@
package spark.storage

import spark.SparkContext
import spark.{Utils, SparkContext}
import BlockManagerMasterActor.BlockStatus

private[spark]

@@ -22,8 +22,13 @@ case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long,
}

case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel,
  numPartitions: Int, memSize: Long, diskSize: Long)
  numCachedPartitions: Int, numPartitions: Int, memSize: Long, diskSize: Long) {
  override def toString = {
    import Utils.memoryBytesToString
    "RDD \"%s\" (%d) Storage: %s; CachedPartitions: %d; TotalPartitions: %d; MemorySize: %s; DiskSize: %s".format(name, id,
      storageLevel.toString, numCachedPartitions, numPartitions, memoryBytesToString(memSize), memoryBytesToString(diskSize))
  }
}

/* Helper methods for storage-related objects */
private[spark]

@@ -38,8 +43,6 @@ object StorageUtils {
  /* Given a list of BlockStatus objects, returns information for each RDD */
  def rddInfoFromBlockStatusList(infos: Map[String, BlockStatus],
    sc: SparkContext) : Array[RDDInfo] = {
    // Find all RDD Blocks (ignore broadcast variables)
    val rddBlocks = infos.filterKeys(_.startsWith("rdd"))

    // Group by rddId, ignore the partition name
    val groupedRddBlocks = infos.groupBy { case(k, v) =>

@@ -56,10 +59,11 @@ object StorageUtils {
      // Find the id of the RDD, e.g. rdd_1 => 1
      val rddId = rddKey.split("_").last.toInt
      // Get the friendly name for the rdd, if available.
      val rddName = Option(sc.persistentRdds(rddId).name).getOrElse(rddKey)
      val rddStorageLevel = sc.persistentRdds(rddId).getStorageLevel

      RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, memSize, diskSize)
      val rdd = sc.persistentRdds(rddId)
      val rddName = Option(rdd.name).getOrElse(rddKey)
      val rddStorageLevel = rdd.getStorageLevel

      RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, rdd.splits.size, memSize, diskSize)
    }.toArray
  }

@@ -75,4 +79,4 @@ object StorageUtils {
  }
}
}
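The grouping step above turns flat block keys of the form "rdd_<rddId>_<split>" into per-RDD statistics, which is what lets RDDInfo now report both cached and total partition counts. A toy sketch of the same technique, with made-up sizes standing in for BlockStatus entries:

// Sketch: strip the split suffix from each block key so all blocks of one
// RDD group together, then aggregate counts and sizes per RDD.
object RddBlockGroupingDemo {
  def main(args: Array[String]) {
    val blocks = Map("rdd_1_0" -> 100L, "rdd_1_1" -> 120L, "rdd_2_0" -> 50L)
    val grouped = blocks.groupBy { case (k, _) => k.split("_").dropRight(1).mkString("_") }
    for ((rddKey, parts) <- grouped) {
      val rddId = rddKey.split("_").last.toInt
      println("RDD %d: %d cached partitions, %d bytes in memory".format(
        rddId, parts.size, parts.values.sum))
    }
  }
}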

@@ -18,9 +18,13 @@ import java.util.concurrent.TimeoutException
 * Various utility classes for working with Akka.
 */
private[spark] object AkkaUtils {

  /**
   * Creates an ActorSystem ready for remoting, with various Spark features. Returns both the
   * ActorSystem itself and its port (which is hard to get from Akka).
   *
   * Note: the `name` parameter is important, as even if a client sends a message to the right
   * host + port, if the system name is incorrect, Akka will drop the message.
   */
  def createActorSystem(name: String, host: String, port: Int): (ActorSystem, Int) = {
    val akkaThreads = System.getProperty("spark.akka.threads", "4").toInt

@@ -30,6 +34,7 @@ private[spark] object AkkaUtils {
    val akkaConf = ConfigFactory.parseString("""
      akka.daemonic = on
      akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"]
      akka.stdout-loglevel = "ERROR"
      akka.actor.provider = "akka.remote.RemoteActorRefProvider"
      akka.remote.transport = "akka.remote.netty.NettyRemoteTransport"
      akka.remote.log-remote-lifecycle-events = on

@@ -41,7 +46,7 @@ private[spark] object AkkaUtils {
      akka.actor.default-dispatcher.throughput = %d
      """.format(host, port, akkaTimeout, akkaFrameSize, akkaThreads, akkaBatchSize))

    val actorSystem = ActorSystem("spark", akkaConf, getClass.getClassLoader)
    val actorSystem = ActorSystem(name, akkaConf, getClass.getClassLoader)

    // Figure out the port number we bound to, in case port was passed as 0. This is a bit of a
    // hack because Akka doesn't let you figure out the port through the public API yet.
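The reason the name change matters: a remote Akka actor is addressed by a path that embeds the ActorSystem's name, so a system created as "spark" is unreachable at a path built for another name even on the right host and port. An illustrative sketch, assuming the Akka 2.0 Netty remoting address format Spark uses at this point:

// Sketch only: shows why the ActorSystem name is part of the remote address.
object ActorPathDemo {
  def driverActorUrl(systemName: String, host: String, port: Int, actor: String): String =
    "akka://%s@%s:%d/user/%s".format(systemName, host, port, actor)

  def main(args: Array[String]) {
    // A client looking up this path must use the same system name the server was created with.
    println(driverActorUrl("spark", "localhost", 7077, "BlockMasterManager"))
  }
}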

@@ -9,12 +9,12 @@ import spark.Logging
 * Runs a timer task to periodically clean up metadata (e.g. old files or hashtable entries)
 */
class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging {
  val delaySeconds = MetadataCleaner.getDelaySeconds
  val periodSeconds = math.max(10, delaySeconds / 10)
  val timer = new Timer(name + " cleanup timer", true)
  private val delaySeconds = MetadataCleaner.getDelaySeconds
  private val periodSeconds = math.max(10, delaySeconds / 10)
  private val timer = new Timer(name + " cleanup timer", true)

  val task = new TimerTask {
    def run() {
  private val task = new TimerTask {
    override def run() {
      try {
        cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000))
        logInfo("Ran metadata cleaner for " + name)
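MetadataCleaner wraps a common pattern: a daemon java.util.Timer periodically invokes a cleanup function with a "delete anything older than this" timestamp. A minimal sketch of that pattern with illustrative values, not Spark's defaults:

import java.util.{Timer, TimerTask}

// Sketch: a daemon timer that fires a cleanup function with an expiry threshold.
object CleanerDemo {
  def main(args: Array[String]) {
    val delaySeconds = 60
    val periodSeconds = math.max(10, delaySeconds / 10)
    val timer = new Timer("demo cleanup timer", true) // true = daemon thread
    val task = new TimerTask {
      override def run() {
        val threshold = System.currentTimeMillis() - (delaySeconds * 1000)
        println("would clear entries older than " + threshold)
      }
    }
    timer.schedule(task, periodSeconds * 1000, periodSeconds * 1000)
    Thread.sleep(15000) // let the daemon timer fire once before the JVM exits
  }
}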

@@ -11,7 +11,11 @@
    <strong>Storage Level:</strong>
    @(rddInfo.storageLevel.description)
  <li>
    <strong>Partitions:</strong>
    <strong>Cached Partitions:</strong>
    @(rddInfo.numCachedPartitions)
  </li>
  <li>
    <strong>Total Partitions:</strong>
    @(rddInfo.numPartitions)
  </li>
  <li>

@@ -6,7 +6,8 @@
  <tr>
    <th>RDD Name</th>
    <th>Storage Level</th>
    <th>Partitions</th>
    <th>Cached Partitions</th>
    <th>Fraction Partitions Cached</th>
    <th>Size in Memory</th>
    <th>Size on Disk</th>
  </tr>

@@ -21,7 +22,8 @@
    </td>
    <td>@(rdd.storageLevel.description)
    </td>
    <td>@rdd.numPartitions</td>
    <td>@rdd.numCachedPartitions</td>
    <td>@(rdd.numCachedPartitions / rdd.numPartitions.toDouble)</td>
    <td>@{Utils.memoryBytesToString(rdd.memSize)}</td>
    <td>@{Utils.memoryBytesToString(rdd.diskSize)}</td>
  </tr>

@@ -12,9 +12,10 @@ class RDDSuite extends FunSuite with LocalSparkContext {
    val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
    assert(nums.collect().toList === List(1, 2, 3, 4))
    val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2)
    assert(dups.distinct.count === 4)
    assert(dups.distinct().collect === dups.distinct.collect)
    assert(dups.distinct(2).collect === dups.distinct.collect)
    assert(dups.distinct().count() === 4)
    assert(dups.distinct.count === 4) // Can distinct and count be called without parentheses?
    assert(dups.distinct().collect === dups.distinct().collect)
    assert(dups.distinct(2).collect === dups.distinct().collect)
    assert(nums.reduce(_ + _) === 10)
    assert(nums.fold(0)(_ + _) === 10)
    assert(nums.map(_.toString).collect().toList === List("1", "2", "3", "4"))

@@ -31,6 +32,10 @@ class RDDSuite extends FunSuite with LocalSparkContext {
      case(split, iter) => Iterator((split, iter.reduceLeft(_ + _)))
    }
    assert(partitionSumsWithSplit.collect().toList === List((0, 3), (1, 7)))

    intercept[UnsupportedOperationException] {
      nums.filter(_ > 5).reduce(_ + _)
    }
  }

  test("SparkContext.union") {

@@ -164,7 +169,7 @@ class RDDSuite extends FunSuite with LocalSparkContext {
    // Note that split number starts from 0, so > 8 means only 10th partition left.
    val prunedRdd = new PartitionPruningRDD(data, splitNum => splitNum > 8)
    assert(prunedRdd.splits.size === 1)
    val prunedData = prunedRdd.collect
    val prunedData = prunedRdd.collect()
    assert(prunedData.size === 1)
    assert(prunedData(0) === 10)
  }

core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala (new file, 663 lines)

@@ -0,0 +1,663 @@
package spark.scheduler

import scala.collection.mutable.{Map, HashMap}

import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.TimeLimitedTests
import org.scalatest.mock.EasyMockSugar
import org.scalatest.time.{Span, Seconds}

import org.easymock.EasyMock._
import org.easymock.Capture
import org.easymock.EasyMock
import org.easymock.{IAnswer, IArgumentMatcher}

import akka.actor.ActorSystem

import spark.storage.BlockManager
import spark.storage.BlockManagerId
import spark.storage.BlockManagerMaster
import spark.{Dependency, ShuffleDependency, OneToOneDependency}
import spark.FetchFailedException
import spark.MapOutputTracker
import spark.RDD
import spark.SparkContext
import spark.SparkException
import spark.Split
import spark.TaskContext
import spark.TaskEndReason

import spark.{FetchFailed, Success}

/**
 * Tests for DAGScheduler. These tests directly call the event processing functions in DAGScheduler
 * rather than spawning an event loop thread as happens in the real code. They use EasyMock
 * to mock out two classes that DAGScheduler interacts with: TaskScheduler (to which TaskSets are
 * submitted) and BlockManagerMaster (from which cache locations are retrieved and to which dead
 * host notifications are sent). In addition, tests may check for side effects on a non-mocked
 * MapOutputTracker instance.
 *
 * Tests primarily consist of running DAGScheduler#processEvent and
 * DAGScheduler#submitWaitingStages (via test utility functions like runEvent or respondToTaskSet)
 * and capturing the resulting TaskSets from the mock TaskScheduler.
 */
class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar with TimeLimitedTests {

  // impose a time limit on this test in case we don't let the job finish, in which case
  // JobWaiter#getResult will hang.
  override val timeLimit = Span(5, Seconds)

  val sc: SparkContext = new SparkContext("local", "DAGSchedulerSuite")
  var scheduler: DAGScheduler = null
  val taskScheduler = mock[TaskScheduler]
  val blockManagerMaster = mock[BlockManagerMaster]
  var mapOutputTracker: MapOutputTracker = null
  var schedulerThread: Thread = null
  var schedulerException: Throwable = null

  /**
   * Set of EasyMock argument matchers that match a TaskSet for a given RDD.
   * We cache these so we do not create duplicate matchers for the same RDD.
   * This allows us to easily set up a sequence of expectations for task sets for
   * that RDD.
   */
  val taskSetMatchers = new HashMap[MyRDD, IArgumentMatcher]

  /**
   * Set of cache locations to return from our mock BlockManagerMaster.
   * Keys are (rdd ID, partition ID). Anything not present will return an empty
   * list of cache locations silently.
   */
  val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]]

  /**
   * JobWaiter for the last JobSubmitted event we pushed, so that tests (most of which
   * submit only one job) need not track it explicitly.
   */
  var lastJobWaiter: JobWaiter[Int] = null

  /**
   * Array into which we are accumulating the results from the last job asynchronously.
   */
  var lastJobResult: Array[Int] = null

  /**
   * Tell EasyMockSugar what mock objects we want to be configured by expecting {...}
   * and whenExecuting {...} */
  implicit val mocks = MockObjects(taskScheduler, blockManagerMaster)

  /**
   * Utility function to reset mocks and set expectations on them. EasyMock wants mock objects
   * to be reset after each time their expectations are set, and we tend to check mock object
   * calls over a single call to DAGScheduler.
   *
   * We also set a default expectation here that blockManagerMaster.getLocations can be called
   * and will return values from cacheLocations.
   */
  def resetExpecting(f: => Unit) {
    reset(taskScheduler)
    reset(blockManagerMaster)
    expecting {
      expectGetLocations()
      f
    }
  }

  before {
    taskSetMatchers.clear()
    cacheLocations.clear()
    val actorSystem = ActorSystem("test")
    mapOutputTracker = new MapOutputTracker(actorSystem, true)
    resetExpecting {
      taskScheduler.setListener(anyObject())
    }
    whenExecuting {
      scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null)
    }
  }

  after {
    assert(scheduler.processEvent(StopDAGScheduler))
    resetExpecting {
      taskScheduler.stop()
    }
    whenExecuting {
      scheduler.stop()
    }
    sc.stop()
    System.clearProperty("spark.master.port")
  }

  def makeBlockManagerId(host: String): BlockManagerId =
    BlockManagerId("exec-" + host, host, 12345)

  /**
   * Type of RDD we use for testing. Note that we should never call the real RDD compute methods.
   * This is a pair RDD type so it can always be used in ShuffleDependencies.
   */
  type MyRDD = RDD[(Int, Int)]

  /**
   * Create an RDD for passing to DAGScheduler. These RDDs will use the dependencies and
   * preferredLocations (if any) that are passed to them. They are deliberately not executable
   * so we can test that DAGScheduler does not try to execute RDDs locally.
   */
  def makeRdd(
      numSplits: Int,
      dependencies: List[Dependency[_]],
      locations: Seq[Seq[String]] = Nil
    ): MyRDD = {
    val maxSplit = numSplits - 1
    return new MyRDD(sc, dependencies) {
      override def compute(split: Split, context: TaskContext): Iterator[(Int, Int)] =
        throw new RuntimeException("should not be reached")
      override def getSplits() = (0 to maxSplit).map(i => new Split {
        override def index = i
      }).toArray
      override def getPreferredLocations(split: Split): Seq[String] =
        if (locations.isDefinedAt(split.index))
          locations(split.index)
        else
          Nil
      override def toString: String = "DAGSchedulerSuiteRDD " + id
    }
  }

  /**
   * EasyMock matcher method. For use as an argument matcher for a TaskSet whose first task
   * is from a particular RDD.
   */
  def taskSetForRdd(rdd: MyRDD): TaskSet = {
    val matcher = taskSetMatchers.getOrElseUpdate(rdd,
      new IArgumentMatcher {
        override def matches(actual: Any): Boolean = {
          val taskSet = actual.asInstanceOf[TaskSet]
          taskSet.tasks(0) match {
            case rt: ResultTask[_, _] => rt.rdd.id == rdd.id
            case smt: ShuffleMapTask => smt.rdd.id == rdd.id
            case _ => false
          }
        }
        override def appendTo(buf: StringBuffer) {
          buf.append("taskSetForRdd(" + rdd + ")")
        }
      })
    EasyMock.reportMatcher(matcher)
    return null
  }

  /**
   * Set up an EasyMock expectation to respond to blockManagerMaster.getLocations() called from
   * cacheLocations.
   */
  def expectGetLocations(): Unit = {
    EasyMock.expect(blockManagerMaster.getLocations(anyObject().asInstanceOf[Array[String]])).
        andAnswer(new IAnswer[Seq[Seq[BlockManagerId]]] {
      override def answer(): Seq[Seq[BlockManagerId]] = {
        val blocks = getCurrentArguments()(0).asInstanceOf[Array[String]]
        return blocks.map { name =>
          val pieces = name.split("_")
          if (pieces(0) == "rdd") {
            val key = pieces(1).toInt -> pieces(2).toInt
            if (cacheLocations.contains(key)) {
              cacheLocations(key)
            } else {
              Seq[BlockManagerId]()
            }
          } else {
            Seq[BlockManagerId]()
          }
        }.toSeq
      }
    }).anyTimes()
  }

  /**
   * Process the supplied event as if it were the top of the DAGScheduler event queue, expecting
   * the scheduler not to exit.
   *
   * After processing the event, submit waiting stages as is done on most iterations of the
   * DAGScheduler event loop.
   */
  def runEvent(event: DAGSchedulerEvent) {
    assert(!scheduler.processEvent(event))
    scheduler.submitWaitingStages()
  }

  /**
   * Expect a TaskSet for the specified RDD to be submitted to the TaskScheduler. Should be
   * called from a resetExpecting { ... } block.
   *
   * Returns an EasyMock Capture that will contain the task set after the stage is submitted.
   * Most tests should use interceptStage() instead of this directly.
   */
  def expectStage(rdd: MyRDD): Capture[TaskSet] = {
    val taskSetCapture = new Capture[TaskSet]
    taskScheduler.submitTasks(and(capture(taskSetCapture), taskSetForRdd(rdd)))
    return taskSetCapture
  }

  /**
   * Expect the supplied code snippet to submit a stage for the specified RDD.
   * Return the resulting TaskSet. First marks all the tasks as belonging to the
   * current MapOutputTracker generation.
   */
  def interceptStage(rdd: MyRDD)(f: => Unit): TaskSet = {
    var capture: Capture[TaskSet] = null
    resetExpecting {
      capture = expectStage(rdd)
    }
    whenExecuting {
      f
    }
    val taskSet = capture.getValue
    for (task <- taskSet.tasks) {
      task.generation = mapOutputTracker.getGeneration
    }
    return taskSet
  }

  /**
   * Send the given CompletionEvent messages for the tasks in the TaskSet.
   */
  def respondToTaskSet(taskSet: TaskSet, results: Seq[(TaskEndReason, Any)]) {
    assert(taskSet.tasks.size >= results.size)
    for ((result, i) <- results.zipWithIndex) {
      if (i < taskSet.tasks.size) {
        runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, Map[Long, Any]()))
      }
    }
  }

  /**
   * Assert that the supplied TaskSet has exactly the given preferredLocations.
   */
  def expectTaskSetLocations(taskSet: TaskSet, locations: Seq[Seq[String]]) {
    assert(locations.size === taskSet.tasks.size)
    for ((expectLocs, taskLocs) <-
            taskSet.tasks.map(_.preferredLocations).zip(locations)) {
      assert(expectLocs === taskLocs)
    }
  }

  /**
   * When we submit dummy Jobs, this is the compute function we supply. Except in a local test
   * below, we do not expect this function to ever be executed; instead, we will return results
   * directly through CompletionEvents.
   */
  def jobComputeFunc(context: TaskContext, it: Iterator[(Int, Int)]): Int =
    it.next._1.asInstanceOf[Int]

  /**
   * Start a job to compute the given RDD. Returns the JobWaiter that will
   * collect the result of the job via callbacks from DAGScheduler.
   */
  def submitRdd(rdd: MyRDD, allowLocal: Boolean = false): (JobWaiter[Int], Array[Int]) = {
    val resultArray = new Array[Int](rdd.splits.size)
    val (toSubmit, waiter) = scheduler.prepareJob[(Int, Int), Int](
        rdd,
        jobComputeFunc,
        (0 to (rdd.splits.size - 1)),
        "test-site",
        allowLocal,
        (i: Int, value: Int) => resultArray(i) = value
    )
    lastJobWaiter = waiter
    lastJobResult = resultArray
    runEvent(toSubmit)
    return (waiter, resultArray)
  }

  /**
   * Assert that a job we started has failed.
   */
  def expectJobException(waiter: JobWaiter[Int] = lastJobWaiter) {
    waiter.awaitResult() match {
      case JobSucceeded => fail()
      case JobFailed(_) => return
    }
  }

  /**
   * Assert that a job we started has succeeded and has the given result.
   */
  def expectJobResult(expected: Array[Int], waiter: JobWaiter[Int] = lastJobWaiter,
                      result: Array[Int] = lastJobResult) {
    waiter.awaitResult match {
      case JobSucceeded =>
        assert(expected === result)
      case JobFailed(_) =>
        fail()
    }
  }

  def makeMapStatus(host: String, reduces: Int): MapStatus =
    new MapStatus(makeBlockManagerId(host), Array.fill[Byte](reduces)(2))

  test("zero split job") {
    val rdd = makeRdd(0, Nil)
    var numResults = 0
    def accumulateResult(partition: Int, value: Int) {
      numResults += 1
    }
    scheduler.runJob(rdd, jobComputeFunc, Seq(), "test-site", false, accumulateResult)
    assert(numResults === 0)
  }

  test("run trivial job") {
    val rdd = makeRdd(1, Nil)
    val taskSet = interceptStage(rdd) { submitRdd(rdd) }
    respondToTaskSet(taskSet, List( (Success, 42) ))
    expectJobResult(Array(42))
  }

  test("local job") {
    val rdd = new MyRDD(sc, Nil) {
      override def compute(split: Split, context: TaskContext): Iterator[(Int, Int)] =
        Array(42 -> 0).iterator
      override def getSplits() = Array( new Split { override def index = 0 } )
      override def getPreferredLocations(split: Split) = Nil
      override def toString = "DAGSchedulerSuite Local RDD"
    }
    submitRdd(rdd, true)
    expectJobResult(Array(42))
  }

  test("run trivial job w/ dependency") {
    val baseRdd = makeRdd(1, Nil)
    val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd)))
    val taskSet = interceptStage(finalRdd) { submitRdd(finalRdd) }
    respondToTaskSet(taskSet, List( (Success, 42) ))
    expectJobResult(Array(42))
  }

  test("cache location preferences w/ dependency") {
    val baseRdd = makeRdd(1, Nil)
    val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd)))
    cacheLocations(baseRdd.id -> 0) =
      Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))
    val taskSet = interceptStage(finalRdd) { submitRdd(finalRdd) }
    expectTaskSetLocations(taskSet, List(Seq("hostA", "hostB")))
    respondToTaskSet(taskSet, List( (Success, 42) ))
    expectJobResult(Array(42))
  }

  test("trivial job failure") {
    val rdd = makeRdd(1, Nil)
    val taskSet = interceptStage(rdd) { submitRdd(rdd) }
    runEvent(TaskSetFailed(taskSet, "test failure"))
    expectJobException()
  }

  test("run trivial shuffle") {
    val shuffleMapRdd = makeRdd(2, Nil)
    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
    val shuffleId = shuffleDep.shuffleId
    val reduceRdd = makeRdd(1, List(shuffleDep))

    val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
    val secondStage = interceptStage(reduceRdd) {
      respondToTaskSet(firstStage, List(
        (Success, makeMapStatus("hostA", 1)),
        (Success, makeMapStatus("hostB", 1))
      ))
    }
    assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
           Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
    respondToTaskSet(secondStage, List( (Success, 42) ))
    expectJobResult(Array(42))
  }

  test("run trivial shuffle with fetch failure") {
    val shuffleMapRdd = makeRdd(2, Nil)
    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
    val shuffleId = shuffleDep.shuffleId
    val reduceRdd = makeRdd(2, List(shuffleDep))

    val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
    val secondStage = interceptStage(reduceRdd) {
      respondToTaskSet(firstStage, List(
        (Success, makeMapStatus("hostA", 1)),
        (Success, makeMapStatus("hostB", 1))
      ))
    }
    resetExpecting {
      blockManagerMaster.removeExecutor("exec-hostA")
    }
    whenExecuting {
      respondToTaskSet(secondStage, List(
        (Success, 42),
        (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0), null)
      ))
    }
    val thirdStage = interceptStage(shuffleMapRdd) {
      scheduler.resubmitFailedStages()
    }
    val fourthStage = interceptStage(reduceRdd) {
      respondToTaskSet(thirdStage, List( (Success, makeMapStatus("hostA", 1)) ))
    }
    assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
           Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
    respondToTaskSet(fourthStage, List( (Success, 43) ))
    expectJobResult(Array(42, 43))
  }

  test("ignore late map task completions") {
    val shuffleMapRdd = makeRdd(2, Nil)
    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
    val shuffleId = shuffleDep.shuffleId
    val reduceRdd = makeRdd(2, List(shuffleDep))

    val taskSet = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
    val oldGeneration = mapOutputTracker.getGeneration
    resetExpecting {
      blockManagerMaster.removeExecutor("exec-hostA")
    }
    whenExecuting {
      runEvent(ExecutorLost("exec-hostA"))
    }
    val newGeneration = mapOutputTracker.getGeneration
    assert(newGeneration > oldGeneration)
    val noAccum = Map[Long, Any]()
    // We rely on the event queue being ordered and increasing the generation number by 1
    // should be ignored for being too old
    runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum))
    // should work because it's a non-failed host
    runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), noAccum))
    // should be ignored for being too old
    runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum))
    taskSet.tasks(1).generation = newGeneration
    val secondStage = interceptStage(reduceRdd) {
      runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum))
    }
    assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
           Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))
    respondToTaskSet(secondStage, List( (Success, 42), (Success, 43) ))
    expectJobResult(Array(42, 43))
  }

  test("run trivial shuffle with out-of-band failure and retry") {
    val shuffleMapRdd = makeRdd(2, Nil)
    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
    val shuffleId = shuffleDep.shuffleId
    val reduceRdd = makeRdd(1, List(shuffleDep))

    val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
    resetExpecting {
      blockManagerMaster.removeExecutor("exec-hostA")
    }
    whenExecuting {
      runEvent(ExecutorLost("exec-hostA"))
    }
    // DAGScheduler will immediately resubmit the stage after it appears to have no pending tasks
    // rather than marking it as failed and waiting.
    val secondStage = interceptStage(shuffleMapRdd) {
      respondToTaskSet(firstStage, List(
        (Success, makeMapStatus("hostA", 1)),
        (Success, makeMapStatus("hostB", 1))
      ))
    }
    val thirdStage = interceptStage(reduceRdd) {
      respondToTaskSet(secondStage, List(
        (Success, makeMapStatus("hostC", 1))
      ))
    }
    assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
           Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB")))
    respondToTaskSet(thirdStage, List( (Success, 42) ))
    expectJobResult(Array(42))
  }

  test("recursive shuffle failures") {
    val shuffleOneRdd = makeRdd(2, Nil)
    val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
    val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
    val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
    val finalRdd = makeRdd(1, List(shuffleDepTwo))

    val firstStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) }
    val secondStage = interceptStage(shuffleTwoRdd) {
      respondToTaskSet(firstStage, List(
        (Success, makeMapStatus("hostA", 2)),
        (Success, makeMapStatus("hostB", 2))
      ))
    }
    val thirdStage = interceptStage(finalRdd) {
      respondToTaskSet(secondStage, List(
        (Success, makeMapStatus("hostA", 1)),
        (Success, makeMapStatus("hostC", 1))
      ))
    }
    resetExpecting {
      blockManagerMaster.removeExecutor("exec-hostA")
    }
    whenExecuting {
      respondToTaskSet(thirdStage, List(
        (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)
      ))
    }
    val recomputeOne = interceptStage(shuffleOneRdd) {
      scheduler.resubmitFailedStages()
    }
    val recomputeTwo = interceptStage(shuffleTwoRdd) {
      respondToTaskSet(recomputeOne, List(
        (Success, makeMapStatus("hostA", 2))
      ))
    }
    val finalStage = interceptStage(finalRdd) {
      respondToTaskSet(recomputeTwo, List(
        (Success, makeMapStatus("hostA", 1))
      ))
    }
    respondToTaskSet(finalStage, List( (Success, 42) ))
    expectJobResult(Array(42))
  }

  test("cached post-shuffle") {
    val shuffleOneRdd = makeRdd(2, Nil)
    val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
    val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
    val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
    val finalRdd = makeRdd(1, List(shuffleDepTwo))

    val firstShuffleStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) }
    cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD"))
    cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC"))
    val secondShuffleStage = interceptStage(shuffleTwoRdd) {
      respondToTaskSet(firstShuffleStage, List(
        (Success, makeMapStatus("hostA", 2)),
        (Success, makeMapStatus("hostB", 2))
      ))
    }
    val reduceStage = interceptStage(finalRdd) {
      respondToTaskSet(secondShuffleStage, List(
        (Success, makeMapStatus("hostA", 1)),
        (Success, makeMapStatus("hostB", 1))
      ))
    }
    resetExpecting {
      blockManagerMaster.removeExecutor("exec-hostA")
    }
    whenExecuting {
      respondToTaskSet(reduceStage, List(
        (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)
      ))
    }
    // DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun.
    val recomputeTwo = interceptStage(shuffleTwoRdd) {
      scheduler.resubmitFailedStages()
    }
    expectTaskSetLocations(recomputeTwo, Seq(Seq("hostD")))
    val finalRetry = interceptStage(finalRdd) {
      respondToTaskSet(recomputeTwo, List(
        (Success, makeMapStatus("hostD", 1))
      ))
    }
    respondToTaskSet(finalRetry, List( (Success, 42) ))
    expectJobResult(Array(42))
  }

  test("cached post-shuffle but fails") {
    val shuffleOneRdd = makeRdd(2, Nil)
    val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
    val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
    val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
    val finalRdd = makeRdd(1, List(shuffleDepTwo))

    val firstShuffleStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) }
    cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD"))
    cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC"))
    val secondShuffleStage = interceptStage(shuffleTwoRdd) {
      respondToTaskSet(firstShuffleStage, List(
        (Success, makeMapStatus("hostA", 2)),
        (Success, makeMapStatus("hostB", 2))
      ))
    }
    val reduceStage = interceptStage(finalRdd) {
      respondToTaskSet(secondShuffleStage, List(
        (Success, makeMapStatus("hostA", 1)),
        (Success, makeMapStatus("hostB", 1))
      ))
    }
    resetExpecting {
      blockManagerMaster.removeExecutor("exec-hostA")
    }
    whenExecuting {
      respondToTaskSet(reduceStage, List(
        (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)
      ))
    }
    val recomputeTwoCached = interceptStage(shuffleTwoRdd) {
      scheduler.resubmitFailedStages()
    }
    expectTaskSetLocations(recomputeTwoCached, Seq(Seq("hostD")))
    intercept[FetchFailedException]{
      mapOutputTracker.getServerStatuses(shuffleDepOne.shuffleId, 0)
    }

    // Simulate the shuffle input data failing to be cached.
    cacheLocations.remove(shuffleTwoRdd.id -> 0)
    respondToTaskSet(recomputeTwoCached, List(
      (FetchFailed(null, shuffleDepOne.shuffleId, 0, 0), null)
    ))

    // After the fetch failure, DAGScheduler should recheck the cache and decide to resubmit
    // everything.
    val recomputeOne = interceptStage(shuffleOneRdd) {
      scheduler.resubmitFailedStages()
    }
    // We use hostA here to make sure DAGScheduler doesn't think it's still dead.
    val recomputeTwoUncached = interceptStage(shuffleTwoRdd) {
      respondToTaskSet(recomputeOne, List( (Success, makeMapStatus("hostA", 1)) ))
    }
    expectTaskSetLocations(recomputeTwoUncached, Seq(Seq[String]()))
    val finalRetry = interceptStage(finalRdd) {
      respondToTaskSet(recomputeTwoUncached, List( (Success, makeMapStatus("hostA", 1)) ))
    }
    respondToTaskSet(finalRetry, List( (Success, 42) ))
    expectJobResult(Array(42))
  }
}
pom.xml
@@ -273,6 +273,12 @@
      <version>1.8</version>
      <scope>test</scope>
    </dependency>
    <dependency>
      <groupId>org.easymock</groupId>
      <artifactId>easymock</artifactId>
      <version>3.1</version>
      <scope>test</scope>
    </dependency>
    <dependency>
      <groupId>org.scalacheck</groupId>
      <artifactId>scalacheck_${scala.version}</artifactId>
@@ -92,7 +92,8 @@ object SparkBuild extends Build {
    "org.eclipse.jetty" % "jetty-server" % "7.5.3.v20111011",
    "org.scalatest" %% "scalatest" % "1.8" % "test",
    "org.scalacheck" %% "scalacheck" % "1.9" % "test",
    "com.novocode" % "junit-interface" % "0.8" % "test"
    "com.novocode" % "junit-interface" % "0.8" % "test",
    "org.easymock" % "easymock" % "3.1" % "test"
  ),
  parallelExecution := false,
  /* Workaround for issue #206 (fixed after SBT 0.11.0) */
@@ -1,8 +1,6 @@
import os
import atexit
import shutil
import sys
import tempfile
from threading import Lock
from tempfile import NamedTemporaryFile

@@ -24,11 +22,10 @@ class SparkContext(object):
    broadcast variables on that cluster.
    """

    gateway = launch_gateway()
    jvm = gateway.jvm
    _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
    _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile
    _takePartition = jvm.PythonRDD.takePartition
    _gateway = None
    _jvm = None
    _writeIteratorToPickleFile = None
    _takePartition = None
    _next_accum_id = 0
    _active_spark_context = None
    _lock = Lock()

@@ -56,6 +53,13 @@ class SparkContext(object):
            raise ValueError("Cannot run multiple SparkContexts at once")
        else:
            SparkContext._active_spark_context = self
            if not SparkContext._gateway:
                SparkContext._gateway = launch_gateway()
                SparkContext._jvm = SparkContext._gateway.jvm
                SparkContext._writeIteratorToPickleFile = \
                    SparkContext._jvm.PythonRDD.writeIteratorToPickleFile
                SparkContext._takePartition = \
                    SparkContext._jvm.PythonRDD.takePartition
        self.master = master
        self.jobName = jobName
        self.sparkHome = sparkHome or None  # None becomes null in Py4J

@@ -63,8 +67,8 @@ class SparkContext(object):
        self.batchSize = batchSize  # -1 represents an unlimited batch size

        # Create the Java SparkContext through Py4J
        empty_string_array = self.gateway.new_array(self.jvm.String, 0)
        self._jsc = self.jvm.JavaSparkContext(master, jobName, sparkHome,
        empty_string_array = self._gateway.new_array(self._jvm.String, 0)
        self._jsc = self._jvm.JavaSparkContext(master, jobName, sparkHome,
                                               empty_string_array)

        # Create a single Accumulator in Java that we'll send all our updates through;

@@ -72,8 +76,8 @@ class SparkContext(object):
        self._accumulatorServer = accumulators._start_update_server()
        (host, port) = self._accumulatorServer.server_address
        self._javaAccumulator = self._jsc.accumulator(
            self.jvm.java.util.ArrayList(),
            self.jvm.PythonAccumulatorParam(host, port))
            self._jvm.java.util.ArrayList(),
            self._jvm.PythonAccumulatorParam(host, port))

        self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
        # Broadcast's __reduce__ method stores Broadcast instances here.

@@ -88,6 +92,11 @@ class SparkContext(object):
        SparkFiles._sc = self
        sys.path.append(SparkFiles.getRootDirectory())

        # Create a temporary directory inside spark.local.dir:
        local_dir = self._jvm.spark.Utils.getLocalDir()
        self._temp_dir = \
            self._jvm.spark.Utils.createTempDir(local_dir).getAbsolutePath()

    @property
    def defaultParallelism(self):
        """

@@ -120,14 +129,14 @@ class SparkContext(object):
        # Calling the Java parallelize() method with an ArrayList is too slow,
        # because it sends O(n) Py4J commands. As an alternative, serialized
        # objects are written to a file and loaded through textFile().
        tempFile = NamedTemporaryFile(delete=False)
        atexit.register(lambda: os.unlink(tempFile.name))
        tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
        if self.batchSize != 1:
            c = batched(c, self.batchSize)
        for x in c:
            write_with_length(dump_pickle(x), tempFile)
        tempFile.close()
        jrdd = self._readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
        readRDDFromPickleFile = self._jvm.PythonRDD.readRDDFromPickleFile
        jrdd = readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
        return RDD(jrdd, self)

    def textFile(self, name, minSplits=None):

@@ -240,7 +249,9 @@ class SparkContext(object):


def _test():
    import atexit
    import doctest
    import tempfile
    globs = globals().copy()
    globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
    globs['tempdir'] = tempfile.mkdtemp()

@@ -35,4 +35,4 @@ class SparkFiles(object):
            return cls._root_directory
        else:
            # This will have to change if we support multiple SparkContexts:
            return cls._sc.jvm.spark.SparkFiles.getRootDirectory()
            return cls._sc._jvm.spark.SparkFiles.getRootDirectory()

@@ -1,4 +1,3 @@
import atexit
from base64 import standard_b64encode as b64enc
import copy
from collections import defaultdict

@@ -264,12 +263,8 @@ class RDD(object):
        # Transferring lots of data through Py4J can be slow because
        # socket.readline() is inefficient. Instead, we'll dump the data to a
        # file and read it back.
        tempFile = NamedTemporaryFile(delete=False)
        tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir)
        tempFile.close()
        def clean_up_file():
            try: os.unlink(tempFile.name)
            except: pass
        atexit.register(clean_up_file)
        self.ctx._writeIteratorToPickleFile(iterator, tempFile.name)
        # Read the data into Python and deserialize it:
        with open(tempFile.name, 'rb') as tempFile:

@@ -407,7 +402,7 @@ class RDD(object):
            return (str(x).encode("utf-8") for x in iterator)
        keyed = PipelinedRDD(self, func)
        keyed._bypass_serializer = True
        keyed._jrdd.map(self.ctx.jvm.BytesToString()).saveAsTextFile(path)
        keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path)

    # Pair functions

@@ -550,8 +545,8 @@ class RDD(object):
            yield dump_pickle(Batch(items))
        keyed = PipelinedRDD(self, add_shuffle_key)
        keyed._bypass_serializer = True
        pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
        partitioner = self.ctx.jvm.PythonPartitioner(numSplits,
        pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
        partitioner = self.ctx._jvm.PythonPartitioner(numSplits,
                                                      id(partitionFunc))
        jrdd = pairRDD.partitionBy(partitioner).values()
        rdd = RDD(jrdd, self.ctx)

@@ -730,13 +725,13 @@ class PipelinedRDD(RDD):
        pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
        broadcast_vars = ListConverter().convert(
            [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
            self.ctx.gateway._gateway_client)
            self.ctx._gateway._gateway_client)
        self.ctx._pickled_broadcast_vars.clear()
        class_manifest = self._prev_jrdd.classManifest()
        env = copy.copy(self.ctx.environment)
        env['PYTHONPATH'] = os.environ.get("PYTHONPATH", "")
        env = MapConverter().convert(env, self.ctx.gateway._gateway_client)
        python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(),
        env = MapConverter().convert(env, self.ctx._gateway._gateway_client)
        python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
            pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec,
            broadcast_vars, self.ctx._javaAccumulator, class_manifest)
        self._jrdd_val = python_rdd.asJavaRDD()

@@ -26,7 +26,7 @@ class PySparkTestCase(unittest.TestCase):
        sys.path = self._old_sys_path
        # To avoid Akka rebinding to the same port, since it doesn't unbind
        # immediately on shutdown
        self.sc.jvm.System.clearProperty("spark.driver.port")
        self.sc._jvm.System.clearProperty("spark.driver.port")


class TestCheckpoint(PySparkTestCase):

@@ -108,5 +108,14 @@ class TestAddFile(PySparkTestCase):
        self.assertEqual("Hello World!", UserClass().hello())


class TestIO(PySparkTestCase):

    def test_stdout_redirection(self):
        import subprocess
        def func(x):
            subprocess.check_call('ls', shell=True)
        self.sc.parallelize([1]).foreach(func)


if __name__ == "__main__":
    unittest.main()

@@ -1,6 +1,7 @@
"""
Worker that receives input from Piped RDD.
"""
import os
import sys
import traceback
from base64 import standard_b64decode

@@ -15,8 +16,8 @@ from pyspark.serializers import write_with_length, read_with_length, write_int,


# Redirect stdout to stderr so that users must return values from functions.
old_stdout = sys.stdout
sys.stdout = sys.stderr
old_stdout = os.fdopen(os.dup(1), 'w')
os.dup2(2, 1)


def load_obj():