Merge branch 'mesos-master' into streaming

Tathagata Das 2013-02-07 13:59:31 -08:00
commit 4cc223b478
177 changed files with 3881 additions and 2225 deletions

View file

@ -45,11 +45,6 @@
<profiles> <profiles>
<profile> <profile>
<id>hadoop1</id> <id>hadoop1</id>
<activation>
<property>
<name>!hadoopVersion</name>
</property>
</activation>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.spark-project</groupId> <groupId>org.spark-project</groupId>
@ -77,12 +72,6 @@
</profile> </profile>
<profile> <profile>
<id>hadoop2</id> <id>hadoop2</id>
<activation>
<property>
<name>hadoopVersion</name>
<value>2</value>
</property>
</activation>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.spark-project</groupId> <groupId>org.spark-project</groupId>

View file

@ -23,7 +23,7 @@ class BagelSuite extends FunSuite with Assertions with BeforeAndAfter {
sc = null sc = null
} }
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.master.port") System.clearProperty("spark.driver.port")
} }
test("halting by voting") { test("halting by voting") {

View file

@ -26,7 +26,8 @@ fi
# Set SPARK_PUBLIC_DNS so the master report the correct webUI address to the slaves # Set SPARK_PUBLIC_DNS so the master report the correct webUI address to the slaves
if [ "$SPARK_PUBLIC_DNS" = "" ]; then if [ "$SPARK_PUBLIC_DNS" = "" ]; then
# If we appear to be running on EC2, use the public address by default: # If we appear to be running on EC2, use the public address by default:
- if [[ `hostname` == *ec2.internal ]]; then
+ # NOTE: ec2-metadata is installed on Amazon Linux AMI. Check based on that and hostname
+ if command -v ec2-metadata > /dev/null || [[ `hostname` == *ec2.internal ]]; then
export SPARK_PUBLIC_DNS=`wget -q -O - http://instance-data.ec2.internal/latest/meta-data/public-hostname` export SPARK_PUBLIC_DNS=`wget -q -O - http://instance-data.ec2.internal/latest/meta-data/public-hostname`
fi fi
fi fi

View file

@ -98,6 +98,11 @@
<artifactId>scalacheck_${scala.version}</artifactId> <artifactId>scalacheck_${scala.version}</artifactId>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency>
<groupId>org.easymock</groupId>
<artifactId>easymock</artifactId>
<scope>test</scope>
</dependency>
<dependency> <dependency>
<groupId>com.novocode</groupId> <groupId>com.novocode</groupId>
<artifactId>junit-interface</artifactId> <artifactId>junit-interface</artifactId>
@ -163,11 +168,6 @@
<profiles> <profiles>
<profile> <profile>
<id>hadoop1</id> <id>hadoop1</id>
<activation>
<property>
<name>!hadoopVersion</name>
</property>
</activation>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.apache.hadoop</groupId> <groupId>org.apache.hadoop</groupId>
@ -220,12 +220,6 @@
</profile> </profile>
<profile> <profile>
<id>hadoop2</id> <id>hadoop2</id>
<activation>
<property>
<name>hadoopVersion</name>
<value>2</value>
</property>
</activation>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.apache.hadoop</groupId> <groupId>org.apache.hadoop</groupId>

View file

@ -25,8 +25,7 @@ class Accumulable[R, T] (
extends Serializable { extends Serializable {
val id = Accumulators.newId val id = Accumulators.newId
- @transient
- private var value_ = initialValue // Current value on master
+ @transient private var value_ = initialValue // Current value on master
val zero = param.zero(initialValue) // Zero value to be passed to workers val zero = param.zero(initialValue) // Zero value to be passed to workers
var deserialized = false var deserialized = false

View file

@ -0,0 +1,65 @@
package spark
import scala.collection.mutable.{ArrayBuffer, HashSet}
import spark.storage.{BlockManager, StorageLevel}
/** Spark class responsible for passing RDDs split contents to the BlockManager and making
sure a node doesn't load two copies of an RDD at once.
*/
private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
private val loading = new HashSet[String]
/** Gets or computes an RDD split. Used by RDD.iterator() when an RDD is cached. */
def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext, storageLevel: StorageLevel)
: Iterator[T] = {
val key = "rdd_%d_%d".format(rdd.id, split.index)
logInfo("Cache key is " + key)
blockManager.get(key) match {
case Some(cachedValues) =>
// Split is in cache, so just return its values
logInfo("Found partition in cache!")
return cachedValues.asInstanceOf[Iterator[T]]
case None =>
// Mark the split as loading (unless someone else marks it first)
loading.synchronized {
if (loading.contains(key)) {
logInfo("Loading contains " + key + ", waiting...")
while (loading.contains(key)) {
try {loading.wait()} catch {case _ =>}
}
logInfo("Loading no longer contains " + key + ", so returning cached result")
// See whether someone else has successfully loaded it. The main way this would fail
// is for the RDD-level cache eviction policy if someone else has loaded the same RDD
// partition but we didn't want to make space for it. However, that case is unlikely
// because it's unlikely that two threads would work on the same RDD partition. One
// downside of the current code is that threads wait serially if this does happen.
blockManager.get(key) match {
case Some(values) =>
return values.asInstanceOf[Iterator[T]]
case None =>
logInfo("Whoever was loading " + key + " failed; we'll try it ourselves")
loading.add(key)
}
} else {
loading.add(key)
}
}
try {
// If we got here, we have to load the split
val elements = new ArrayBuffer[Any]
logInfo("Computing partition " + split)
elements ++= rdd.computeOrReadCheckpoint(split, context)
// Try to put this block in the blockManager
blockManager.put(key, elements, storageLevel, true)
return elements.iterator.asInstanceOf[Iterator[T]]
} finally {
loading.synchronized {
loading.remove(key)
loading.notifyAll()
}
}
}
}
}
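For orientation (not part of the commit): a minimal sketch of how this class is exercised once RDD.iterator delegates caching to it (see the RDD.scala changes later in this commit). The SparkContext and data are illustrative.

import spark.SparkContext
import spark.storage.StorageLevel

val sc = new SparkContext("local", "cacheManagerDemo")
val rdd = sc.parallelize(1 to 1000).persist(StorageLevel.MEMORY_ONLY)
rdd.count()  // first pass: each partition is computed and handed to the BlockManager
rdd.count()  // second pass: partitions are served from the cache via getOrCompute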

View file

@ -1,240 +0,0 @@
package spark
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import akka.actor._
import akka.dispatch._
import akka.pattern.ask
import akka.remote._
import akka.util.Duration
import akka.util.Timeout
import akka.util.duration._
import spark.storage.BlockManager
import spark.storage.StorageLevel
import util.{TimeStampedHashSet, MetadataCleaner, TimeStampedHashMap}
private[spark] sealed trait CacheTrackerMessage
private[spark] case class AddedToCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
extends CacheTrackerMessage
private[spark] case class DroppedFromCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
extends CacheTrackerMessage
private[spark] case class MemoryCacheLost(host: String) extends CacheTrackerMessage
private[spark] case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheTrackerMessage
private[spark] case class SlaveCacheStarted(host: String, size: Long) extends CacheTrackerMessage
private[spark] case object GetCacheStatus extends CacheTrackerMessage
private[spark] case object GetCacheLocations extends CacheTrackerMessage
private[spark] case object StopCacheTracker extends CacheTrackerMessage
private[spark] class CacheTrackerActor extends Actor with Logging {
// TODO: Should probably store (String, CacheType) tuples
private val locs = new TimeStampedHashMap[Int, Array[List[String]]]
/**
* A map from the slave's host name to its cache size.
*/
private val slaveCapacity = new HashMap[String, Long]
private val slaveUsage = new HashMap[String, Long]
private val metadataCleaner = new MetadataCleaner("CacheTrackerActor", locs.clearOldValues)
private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L)
private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L)
private def getCacheAvailable(host: String): Long = getCacheCapacity(host) - getCacheUsage(host)
def receive = {
case SlaveCacheStarted(host: String, size: Long) =>
slaveCapacity.put(host, size)
slaveUsage.put(host, 0)
sender ! true
case RegisterRDD(rddId: Int, numPartitions: Int) =>
logInfo("Registering RDD " + rddId + " with " + numPartitions + " partitions")
locs(rddId) = Array.fill[List[String]](numPartitions)(Nil)
sender ! true
case AddedToCache(rddId, partition, host, size) =>
slaveUsage.put(host, getCacheUsage(host) + size)
locs(rddId)(partition) = host :: locs(rddId)(partition)
sender ! true
case DroppedFromCache(rddId, partition, host, size) =>
slaveUsage.put(host, getCacheUsage(host) - size)
// Do a sanity check to make sure usage is greater than 0.
locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host)
sender ! true
case MemoryCacheLost(host) =>
logInfo("Memory cache lost on " + host)
for ((id, locations) <- locs) {
for (i <- 0 until locations.length) {
locations(i) = locations(i).filterNot(_ == host)
}
}
sender ! true
case GetCacheLocations =>
logInfo("Asked for current cache locations")
sender ! locs.map{case (rrdId, array) => (rrdId -> array.clone())}
case GetCacheStatus =>
val status = slaveCapacity.map { case (host, capacity) =>
(host, capacity, getCacheUsage(host))
}.toSeq
sender ! status
case StopCacheTracker =>
logInfo("Stopping CacheTrackerActor")
sender ! true
metadataCleaner.cancel()
context.stop(self)
}
}
private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: BlockManager)
extends Logging {
// Tracker actor on the master, or remote reference to it on workers
val ip: String = System.getProperty("spark.master.host", "localhost")
val port: Int = System.getProperty("spark.master.port", "7077").toInt
val actorName: String = "CacheTracker"
val timeout = 10.seconds
var trackerActor: ActorRef = if (isMaster) {
val actor = actorSystem.actorOf(Props[CacheTrackerActor], name = actorName)
logInfo("Registered CacheTrackerActor actor")
actor
} else {
val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName)
actorSystem.actorFor(url)
}
// TODO: Consider removing this HashSet completely as locs CacheTrackerActor already
// keeps track of registered RDDs
val registeredRddIds = new TimeStampedHashSet[Int]
// Remembers which splits are currently being loaded (on worker nodes)
val loading = new HashSet[String]
val metadataCleaner = new MetadataCleaner("CacheTracker", registeredRddIds.clearOldValues)
// Send a message to the trackerActor and get its result within a default timeout, or
// throw a SparkException if this fails.
def askTracker(message: Any): Any = {
try {
val future = trackerActor.ask(message)(timeout)
return Await.result(future, timeout)
} catch {
case e: Exception =>
throw new SparkException("Error communicating with CacheTracker", e)
}
}
// Send a one-way message to the trackerActor, to which we expect it to reply with true.
def communicate(message: Any) {
if (askTracker(message) != true) {
throw new SparkException("Error reply received from CacheTracker")
}
}
// Registers an RDD (on master only)
def registerRDD(rddId: Int, numPartitions: Int) {
registeredRddIds.synchronized {
if (!registeredRddIds.contains(rddId)) {
logInfo("Registering RDD ID " + rddId + " with cache")
registeredRddIds += rddId
communicate(RegisterRDD(rddId, numPartitions))
}
}
}
// For BlockManager.scala only
def cacheLost(host: String) {
communicate(MemoryCacheLost(host))
logInfo("CacheTracker successfully removed entries on " + host)
}
// Get the usage status of slave caches. Each tuple in the returned sequence
// is in the form of (host name, capacity, usage).
def getCacheStatus(): Seq[(String, Long, Long)] = {
askTracker(GetCacheStatus).asInstanceOf[Seq[(String, Long, Long)]]
}
// For BlockManager.scala only
def notifyFromBlockManager(t: AddedToCache) {
communicate(t)
}
// Get a snapshot of the currently known locations
def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = {
askTracker(GetCacheLocations).asInstanceOf[HashMap[Int, Array[List[String]]]]
}
// Gets or computes an RDD split
def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext, storageLevel: StorageLevel)
: Iterator[T] = {
val key = "rdd_%d_%d".format(rdd.id, split.index)
logInfo("Cache key is " + key)
blockManager.get(key) match {
case Some(cachedValues) =>
// Split is in cache, so just return its values
logInfo("Found partition in cache!")
return cachedValues.asInstanceOf[Iterator[T]]
case None =>
// Mark the split as loading (unless someone else marks it first)
loading.synchronized {
if (loading.contains(key)) {
logInfo("Loading contains " + key + ", waiting...")
while (loading.contains(key)) {
try {loading.wait()} catch {case _ =>}
}
logInfo("Loading no longer contains " + key + ", so returning cached result")
// See whether someone else has successfully loaded it. The main way this would fail
// is for the RDD-level cache eviction policy if someone else has loaded the same RDD
// partition but we didn't want to make space for it. However, that case is unlikely
// because it's unlikely that two threads would work on the same RDD partition. One
// downside of the current code is that threads wait serially if this does happen.
blockManager.get(key) match {
case Some(values) =>
return values.asInstanceOf[Iterator[T]]
case None =>
logInfo("Whoever was loading " + key + " failed; we'll try it ourselves")
loading.add(key)
}
} else {
loading.add(key)
}
}
try {
// If we got here, we have to load the split
val elements = new ArrayBuffer[Any]
logInfo("Computing partition " + split)
elements ++= rdd.compute(split, context)
// Try to put this block in the blockManager
blockManager.put(key, elements, storageLevel, true)
return elements.iterator.asInstanceOf[Iterator[T]]
} finally {
loading.synchronized {
loading.remove(key)
loading.notifyAll()
}
}
}
}
// Called by the Cache to report that an entry has been dropped from it
def dropEntry(rddId: Int, partition: Int) {
communicate(DroppedFromCache(rddId, partition, Utils.localHostName()))
}
def stop() {
communicate(StopCacheTracker)
registeredRddIds.clear()
trackerActor = null
}
}

View file

@ -1,18 +0,0 @@
package spark
import java.util.concurrent.ThreadFactory
/**
* A ThreadFactory that creates daemon threads
*/
private object DaemonThreadFactory extends ThreadFactory {
override def newThread(r: Runnable): Thread = new DaemonThread(r)
}
private class DaemonThread(r: Runnable = null) extends Thread {
override def run() {
if (r != null) {
r.run()
}
}
}

View file

@ -5,6 +5,7 @@ package spark
*/ */
abstract class Dependency[T](val rdd: RDD[T]) extends Serializable abstract class Dependency[T](val rdd: RDD[T]) extends Serializable
/** /**
* Base class for dependencies where each partition of the parent RDD is used by at most one * Base class for dependencies where each partition of the parent RDD is used by at most one
* partition of the child RDD. Narrow dependencies allow for pipelined execution. * partition of the child RDD. Narrow dependencies allow for pipelined execution.
@ -12,12 +13,13 @@ abstract class Dependency[T](val rdd: RDD[T]) extends Serializable
abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) { abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
/** /**
* Get the parent partitions for a child partition. * Get the parent partitions for a child partition.
- * @param outputPartition a partition of the child RDD
+ * @param partitionId a partition of the child RDD
* @return the partitions of the parent RDD that the child partition depends upon
*/
- def getParents(outputPartition: Int): Seq[Int]
+ def getParents(partitionId: Int): Seq[Int]
} }
/** /**
* Represents a dependency on the output of a shuffle stage. * Represents a dependency on the output of a shuffle stage.
* @param shuffleId the shuffle id * @param shuffleId the shuffle id
@ -32,6 +34,7 @@ class ShuffleDependency[K, V](
val shuffleId: Int = rdd.context.newShuffleId() val shuffleId: Int = rdd.context.newShuffleId()
} }
/** /**
* Represents a one-to-one dependency between partitions of the parent and child RDDs. * Represents a one-to-one dependency between partitions of the parent and child RDDs.
*/ */
@ -39,6 +42,7 @@ class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) {
override def getParents(partitionId: Int) = List(partitionId) override def getParents(partitionId: Int) = List(partitionId)
} }
/** /**
* Represents a one-to-one dependency between ranges of partitions in the parent and child RDDs. * Represents a one-to-one dependency between ranges of partitions in the parent and child RDDs.
* @param rdd the parent RDD * @param rdd the parent RDD
@ -48,7 +52,7 @@ class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) {
*/ */
class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int) class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int)
extends NarrowDependency[T](rdd) { extends NarrowDependency[T](rdd) {
override def getParents(partitionId: Int) = { override def getParents(partitionId: Int) = {
if (partitionId >= outStart && partitionId < outStart + length) { if (partitionId >= outStart && partitionId < outStart + length) {
List(partitionId - outStart + inStart) List(partitionId - outStart + inStart)
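To make the index arithmetic above concrete, a small sketch (not part of the commit; an existing SparkContext sc is assumed): with inStart = 3, outStart = 0 and length = 4, child partition 2 maps to parent partition 2 - 0 + 3 = 5.

val parent = sc.parallelize(1 to 80, 8)        // illustrative parent RDD
val dep = new RangeDependency(parent, 3, 0, 4)
dep.getParents(2)  // List(5): child partition 2 reads parent partition 5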

View file

@ -1,9 +1,7 @@
package spark package spark
- import java.io.{File, PrintWriter}
- import java.net.URL
- import scala.collection.mutable.HashMap
- import org.apache.hadoop.fs.FileUtil
+ import java.io.{File}
+ import com.google.common.io.Files
private[spark] class HttpFileServer extends Logging { private[spark] class HttpFileServer extends Logging {
@ -40,7 +38,7 @@ private[spark] class HttpFileServer extends Logging {
} }
def addFileToDir(file: File, dir: File) : String = { def addFileToDir(file: File, dir: File) : String = {
- Utils.copyFile(file, new File(dir, file.getName))
+ Files.copy(file, new File(dir, file.getName))
return dir + "/" + file.getName return dir + "/" + file.getName
} }

View file

@ -4,6 +4,7 @@ import java.io.File
import java.net.InetAddress import java.net.InetAddress
import org.eclipse.jetty.server.Server import org.eclipse.jetty.server.Server
import org.eclipse.jetty.server.bio.SocketConnector
import org.eclipse.jetty.server.handler.DefaultHandler import org.eclipse.jetty.server.handler.DefaultHandler
import org.eclipse.jetty.server.handler.HandlerList import org.eclipse.jetty.server.handler.HandlerList
import org.eclipse.jetty.server.handler.ResourceHandler import org.eclipse.jetty.server.handler.ResourceHandler
@ -27,7 +28,13 @@ private[spark] class HttpServer(resourceBase: File) extends Logging {
if (server != null) { if (server != null) {
throw new ServerStateException("Server is already started") throw new ServerStateException("Server is already started")
} else { } else {
- server = new Server(0)
+ server = new Server()
val connector = new SocketConnector
connector.setMaxIdleTime(60*1000)
connector.setSoLingerTime(-1)
connector.setPort(0)
server.addConnector(connector)
val threadPool = new QueuedThreadPool val threadPool = new QueuedThreadPool
threadPool.setDaemon(true) threadPool.setDaemon(true)
server.setThreadPool(threadPool) server.setThreadPool(threadPool)

View file

@ -206,5 +206,8 @@ class KryoSerializer extends spark.serializer.Serializer with Logging {
kryo kryo
} }
- def newInstance(): SerializerInstance = new KryoSerializerInstance(this)
+ def newInstance(): SerializerInstance = {
+   this.kryo.get().setClassLoader(Thread.currentThread().getContextClassLoader)
+   new KryoSerializerInstance(this)
+ }
} }
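A brief note on the change above (a sketch, not part of the commit): binding each new serializer instance to the calling thread's context classloader lets Kryo resolve classes that only became visible at runtime, e.g. via SparkContext.addJar. Round-tripping a simple value looks the same as before:

val ser = new spark.KryoSerializer
val si = ser.newInstance()             // now picks up the current context classloader
val buf = si.serialize("hello")        // serialize a simple, already-registered type
val out = si.deserialize[String](buf)  // "hello"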

View file

@ -11,8 +11,7 @@ import org.slf4j.LoggerFactory
trait Logging { trait Logging {
// Make the log field transient so that objects with Logging can // Make the log field transient so that objects with Logging can
// be serialized and used on another machine // be serialized and used on another machine
- @transient
- private var log_ : Logger = null
+ @transient private var log_ : Logger = null
// Method to get or create the logger for this object // Method to get or create the logger for this object
protected def log: Logger = { protected def log: Logger = {

View file

@ -38,10 +38,7 @@ private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Ac
} }
} }
- private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logging {
- val ip: String = System.getProperty("spark.master.host", "localhost")
- val port: Int = System.getProperty("spark.master.port", "7077").toInt
- val actorName: String = "MapOutputTracker"
+ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolean) extends Logging {
val timeout = 10.seconds val timeout = 10.seconds
@ -56,11 +53,14 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
var cacheGeneration = generation var cacheGeneration = generation
val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]] val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
- var trackerActor: ActorRef = if (isMaster) {
+ val actorName: String = "MapOutputTracker"
+ var trackerActor: ActorRef = if (isDriver) {
val actor = actorSystem.actorOf(Props(new MapOutputTrackerActor(this)), name = actorName) val actor = actorSystem.actorOf(Props(new MapOutputTrackerActor(this)), name = actorName)
logInfo("Registered MapOutputTrackerActor actor") logInfo("Registered MapOutputTrackerActor actor")
actor actor
} else { } else {
val ip = System.getProperty("spark.driver.host", "localhost")
val port = System.getProperty("spark.driver.port", "7077").toInt
val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName) val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName)
actorSystem.actorFor(url) actorSystem.actorFor(url)
} }
@ -114,7 +114,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
var array = mapStatuses(shuffleId) var array = mapStatuses(shuffleId)
if (array != null) { if (array != null) {
array.synchronized { array.synchronized {
- if (array(mapId) != null && array(mapId).address == bmAddress) {
+ if (array(mapId) != null && array(mapId).location == bmAddress) {
array(mapId) = null array(mapId) = null
} }
} }
@ -170,7 +170,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
} }
} }
- def cleanup(cleanupTime: Long) {
+ private def cleanup(cleanupTime: Long) {
mapStatuses.clearOldValues(cleanupTime) mapStatuses.clearOldValues(cleanupTime)
cachedSerializedStatuses.clearOldValues(cleanupTime) cachedSerializedStatuses.clearOldValues(cleanupTime)
} }
@ -277,7 +277,7 @@ private[spark] object MapOutputTracker {
throw new FetchFailedException(null, shuffleId, -1, reduceId, throw new FetchFailedException(null, shuffleId, -1, reduceId,
new Exception("Missing an output location for shuffle " + shuffleId)) new Exception("Missing an output location for shuffle " + shuffleId))
} else { } else {
- (status.address, decompressSize(status.compressedSizes(reduceId)))
+ (status.location, decompressSize(status.compressedSizes(reduceId)))
} }
} }
} }

View file

@ -465,7 +465,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
val res = self.context.runJob(self, process _, Array(index), false) val res = self.context.runJob(self, process _, Array(index), false)
res(0) res(0)
case None => case None =>
- self.filter(_._1 == key).map(_._2).collect
+ self.filter(_._1 == key).map(_._2).collect()
} }
} }
@ -485,18 +485,6 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
saveAsNewAPIHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]]) saveAsNewAPIHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]])
} }
/**
* Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat`
* (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD.
*/
def saveAsNewAPIHadoopFile(
path: String,
keyClass: Class[_],
valueClass: Class[_],
outputFormatClass: Class[_ <: NewOutputFormat[_, _]]) {
saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass, new Configuration)
}
/** /**
* Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat` * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat`
* (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD. * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD.
@ -506,7 +494,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
keyClass: Class[_], keyClass: Class[_],
valueClass: Class[_], valueClass: Class[_],
outputFormatClass: Class[_ <: NewOutputFormat[_, _]], outputFormatClass: Class[_ <: NewOutputFormat[_, _]],
- conf: Configuration) {
+ conf: Configuration = self.context.hadoopConfiguration) {
val job = new NewAPIHadoopJob(conf) val job = new NewAPIHadoopJob(conf)
job.setOutputKeyClass(keyClass) job.setOutputKeyClass(keyClass)
job.setOutputValueClass(valueClass) job.setOutputValueClass(valueClass)
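With the Configuration parameter now defaulting to self.context.hadoopConfiguration, the explicit overload deleted above becomes redundant; a sketch of a call that relies on the default (paths and types are illustrative, and an existing SparkContext sc is assumed):

import spark.SparkContext._
import org.apache.hadoop.io.{IntWritable, Text}
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat

val pairs = sc.parallelize(Seq((new IntWritable(1), new Text("a"))))
pairs.saveAsNewAPIHadoopFile("/tmp/out", classOf[IntWritable], classOf[Text],
  classOf[TextOutputFormat[IntWritable, Text]])   // no Configuration argument needed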
@ -557,7 +545,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
keyClass: Class[_], keyClass: Class[_],
valueClass: Class[_], valueClass: Class[_],
outputFormatClass: Class[_ <: OutputFormat[_, _]], outputFormatClass: Class[_ <: OutputFormat[_, _]],
- conf: JobConf = new JobConf) {
+ conf: JobConf = new JobConf(self.context.hadoopConfiguration)) {
conf.setOutputKeyClass(keyClass) conf.setOutputKeyClass(keyClass)
conf.setOutputValueClass(valueClass) conf.setOutputValueClass(valueClass)
// conf.setOutputFormat(outputFormatClass) // Doesn't work in Scala 2.9 due to what may be a generics bug // conf.setOutputFormat(outputFormatClass) // Doesn't work in Scala 2.9 due to what may be a generics bug
@ -602,7 +590,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
var count = 0 var count = 0
while(iter.hasNext) { while(iter.hasNext) {
- val record = iter.next
+ val record = iter.next()
count += 1 count += 1
writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef]) writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef])
} }
@ -661,9 +649,7 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
} }
private[spark] private[spark]
- class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U)
-   extends RDD[(K, U)](prev) {
+ class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U) extends RDD[(K, U)](prev) {
override def getSplits = firstParent[(K, V)].splits override def getSplits = firstParent[(K, V)].splits
override val partitioner = firstParent[(K, V)].partitioner override val partitioner = firstParent[(K, V)].partitioner
override def compute(split: Split, context: TaskContext) = override def compute(split: Split, context: TaskContext) =

View file

@ -23,32 +23,28 @@ private[spark] class ParallelCollectionSplit[T: ClassManifest](
} }
private[spark] class ParallelCollection[T: ClassManifest]( private[spark] class ParallelCollection[T: ClassManifest](
- @transient sc : SparkContext,
+ @transient sc: SparkContext,
@transient data: Seq[T],
numSlices: Int,
- locationPrefs : Map[Int,Seq[String]])
+ locationPrefs: Map[Int,Seq[String]])
extends RDD[T](sc, Nil) { extends RDD[T](sc, Nil) {
// TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets // TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets
// cached. It might be worthwhile to write the data to a file in the DFS and read it in the split // cached. It might be worthwhile to write the data to a file in the DFS and read it in the split
// instead. // instead.
// UPDATE: A parallel collection can be checkpointed to HDFS, which achieves this goal. // UPDATE: A parallel collection can be checkpointed to HDFS, which achieves this goal.
- @transient
- var splits_ : Array[Split] = {
+ @transient var splits_ : Array[Split] = {
val slices = ParallelCollection.slice(data, numSlices).toArray val slices = ParallelCollection.slice(data, numSlices).toArray
slices.indices.map(i => new ParallelCollectionSplit(id, i, slices(i))).toArray slices.indices.map(i => new ParallelCollectionSplit(id, i, slices(i))).toArray
} }
- override def getSplits = splits_.asInstanceOf[Array[Split]]
+ override def getSplits = splits_
override def compute(s: Split, context: TaskContext) = override def compute(s: Split, context: TaskContext) =
s.asInstanceOf[ParallelCollectionSplit[T]].iterator s.asInstanceOf[ParallelCollectionSplit[T]].iterator
override def getPreferredLocations(s: Split): Seq[String] = { override def getPreferredLocations(s: Split): Seq[String] = {
- locationPrefs.get(s.index) match {
-   case Some(s) => s
-   case _ => Nil
- }
+ locationPrefs.getOrElse(s.index, Nil)
} }
override def clearDependencies() { override def clearDependencies() {
@ -56,7 +52,6 @@ private[spark] class ParallelCollection[T: ClassManifest](
} }
} }
private object ParallelCollection { private object ParallelCollection {
/** /**
* Slice a collection into numSlices sub-collections. One extra thing we do here is to treat Range * Slice a collection into numSlices sub-collections. One extra thing we do here is to treat Range

View file

@ -1,27 +1,17 @@
package spark package spark
import java.io.{ObjectOutputStream, IOException, EOFException, ObjectInputStream}
import java.net.URL import java.net.URL
import java.util.{Date, Random} import java.util.{Date, Random}
import java.util.{HashMap => JHashMap} import java.util.{HashMap => JHashMap}
import java.util.concurrent.atomic.AtomicLong
import scala.collection.Map import scala.collection.Map
import scala.collection.JavaConversions.mapAsScalaMap import scala.collection.JavaConversions.mapAsScalaMap
import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap import scala.collection.mutable.HashMap
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.BytesWritable import org.apache.hadoop.io.BytesWritable
import org.apache.hadoop.io.NullWritable import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.io.Text import org.apache.hadoop.io.Text
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapred.FileOutputCommitter
import org.apache.hadoop.mapred.HadoopWriter
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapred.OutputCommitter
import org.apache.hadoop.mapred.OutputFormat
import org.apache.hadoop.mapred.SequenceFileOutputFormat
import org.apache.hadoop.mapred.TextOutputFormat import org.apache.hadoop.mapred.TextOutputFormat
import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}
@ -30,7 +20,6 @@ import spark.partial.BoundedDouble
import spark.partial.CountEvaluator import spark.partial.CountEvaluator
import spark.partial.GroupedCountEvaluator import spark.partial.GroupedCountEvaluator
import spark.partial.PartialResult import spark.partial.PartialResult
import spark.rdd.BlockRDD
import spark.rdd.CartesianRDD import spark.rdd.CartesianRDD
import spark.rdd.FilteredRDD import spark.rdd.FilteredRDD
import spark.rdd.FlatMappedRDD import spark.rdd.FlatMappedRDD
@ -73,11 +62,11 @@ import SparkContext._
* on RDD internals. * on RDD internals.
*/ */
abstract class RDD[T: ClassManifest]( abstract class RDD[T: ClassManifest](
- @transient var sc: SparkContext,
- var dependencies_ : List[Dependency[_]]
+ @transient private var sc: SparkContext,
+ @transient private var deps: Seq[Dependency[_]]
) extends Serializable with Logging { ) extends Serializable with Logging {
/** Construct an RDD with just a one-to-one dependency on one parent */
def this(@transient oneParent: RDD[_]) = def this(@transient oneParent: RDD[_]) =
this(oneParent.context , List(new OneToOneDependency(oneParent))) this(oneParent.context , List(new OneToOneDependency(oneParent)))
@ -85,14 +74,20 @@ abstract class RDD[T: ClassManifest](
// Methods that should be implemented by subclasses of RDD // Methods that should be implemented by subclasses of RDD
// ======================================================================= // =======================================================================
- /** Function for computing a given partition. */
+ /** Implemented by subclasses to compute a given partition. */
def compute(split: Split, context: TaskContext): Iterator[T]
- /** Set of partitions in this RDD. */
- protected def getSplits(): Array[Split]
+ /**
+  * Implemented by subclasses to return the set of partitions in this RDD. This method will only
+  * be called once, so it is safe to implement a time-consuming computation in it.
+  */
+ protected def getSplits: Array[Split]
- /** How this RDD depends on any parent RDDs. */
- protected def getDependencies(): List[Dependency[_]] = dependencies_
+ /**
+  * Implemented by subclasses to return how this RDD depends on parent RDDs. This method will only
+  * be called once, so it is safe to implement a time-consuming computation in it.
+  */
+ protected def getDependencies: Seq[Dependency[_]] = deps
/** Optionally overridden by subclasses to specify placement preferences. */ /** Optionally overridden by subclasses to specify placement preferences. */
protected def getPreferredLocations(split: Split): Seq[String] = Nil protected def getPreferredLocations(split: Split): Seq[String] = Nil
@ -100,7 +95,6 @@ abstract class RDD[T: ClassManifest](
/** Optionally overridden by subclasses to specify how they are partitioned. */ /** Optionally overridden by subclasses to specify how they are partitioned. */
val partitioner: Option[Partitioner] = None val partitioner: Option[Partitioner] = None
// ======================================================================= // =======================================================================
// Methods and fields available on all RDDs // Methods and fields available on all RDDs
// ======================================================================= // =======================================================================
@ -108,6 +102,15 @@ abstract class RDD[T: ClassManifest](
/** A unique ID for this RDD (within its SparkContext). */ /** A unique ID for this RDD (within its SparkContext). */
val id = sc.newRddId() val id = sc.newRddId()
/** A friendly name for this RDD */
var name: String = null
/** Assign a name to this RDD */
def setName(_name: String) = {
name = _name
this
}
/** /**
* Set this RDD's storage level to persist its values across operations after the first time * Set this RDD's storage level to persist its values across operations after the first time
* it is computed. Can only be called once on each RDD. * it is computed. Can only be called once on each RDD.
@ -119,6 +122,8 @@ abstract class RDD[T: ClassManifest](
"Cannot change storage level of an RDD after it was already assigned a level") "Cannot change storage level of an RDD after it was already assigned a level")
} }
storageLevel = newLevel storageLevel = newLevel
// Register the RDD with the SparkContext
sc.persistentRdds(id) = this
this this
} }
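Taken together, the two additions above mean a persisted RDD registers itself with its SparkContext and can carry a friendly name, which is what the new storage UI reports on. A small usage sketch (assuming an existing SparkContext sc; the path is illustrative):

val users = sc.textFile("/data/users.txt").setName("users").cache()
users.count()
// users is now tracked in sc.persistentRdds under its id, and the name
// shows up in toString and the block manager UI.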
@ -131,15 +136,24 @@ abstract class RDD[T: ClassManifest](
/** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */ /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */
def getStorageLevel = storageLevel def getStorageLevel = storageLevel
// Our dependencies and splits will be gotten by calling subclass's methods below, and will
// be overwritten when we're checkpointed
private var dependencies_ : Seq[Dependency[_]] = null
@transient private var splits_ : Array[Split] = null
/** An Option holding our checkpoint RDD, if we are checkpointed */
private def checkpointRDD: Option[RDD[T]] = checkpointData.flatMap(_.checkpointRDD)
/**
- * Get the preferred location of a split, taking into account whether the
+ * Get the list of dependencies of this RDD, taking into account whether the
 * RDD is checkpointed or not.
 */
- final def preferredLocations(split: Split): Seq[String] = {
-   if (isCheckpointed) {
-     checkpointData.get.getPreferredLocations(split)
-   } else {
-     getPreferredLocations(split)
-   }
- }
+ final def dependencies: Seq[Dependency[_]] = {
+   checkpointRDD.map(r => List(new OneToOneDependency(r))).getOrElse {
+     if (dependencies_ == null) {
+       dependencies_ = getDependencies
+     }
+     dependencies_
+   }
+ }
@ -148,22 +162,21 @@ abstract class RDD[T: ClassManifest](
* RDD is checkpointed or not. * RDD is checkpointed or not.
*/ */
final def splits: Array[Split] = {
-   if (isCheckpointed) {
-     checkpointData.get.getSplits
-   } else {
-     getSplits
-   }
+   checkpointRDD.map(_.splits).getOrElse {
+     if (splits_ == null) {
+       splits_ = getSplits
+     }
+     splits_
+   }
}
/**
- * Get the list of dependencies of this RDD, taking into account whether the
+ * Get the preferred location of a split, taking into account whether the
 * RDD is checkpointed or not.
 */
- final def dependencies: List[Dependency[_]] = {
-   if (isCheckpointed) {
-     dependencies_
-   } else {
-     getDependencies
-   }
- }
+ final def preferredLocations(split: Split): Seq[String] = {
+   checkpointRDD.map(_.getPreferredLocations(split)).getOrElse {
+     getPreferredLocations(split)
+   }
+ }
@ -173,10 +186,19 @@ abstract class RDD[T: ClassManifest](
* subclasses of RDD. * subclasses of RDD.
*/ */
final def iterator(split: Split, context: TaskContext): Iterator[T] = { final def iterator(split: Split, context: TaskContext): Iterator[T] = {
if (storageLevel != StorageLevel.NONE) {
SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel)
} else {
computeOrReadCheckpoint(split, context)
}
}
/**
* Compute an RDD partition or read it from a checkpoint if the RDD is checkpointing.
*/
private[spark] def computeOrReadCheckpoint(split: Split, context: TaskContext): Iterator[T] = {
if (isCheckpointed) {
-   checkpointData.get.iterator(split, context)
- } else if (storageLevel != StorageLevel.NONE) {
-   SparkEnv.get.cacheTracker.getOrCompute[T](this, split, context, storageLevel)
+   firstParent[T].iterator(split, context)
} else {
  compute(split, context)
}
@ -363,20 +385,22 @@ abstract class RDD[T: ClassManifest](
val reducePartition: Iterator[T] => Option[T] = iter => {
  if (iter.hasNext) {
    Some(iter.reduceLeft(cleanF))
-   }else {
+   } else {
    None
  }
}
- val options = sc.runJob(this, reducePartition)
- val results = new ArrayBuffer[T]
- for (opt <- options; elem <- opt) {
-   results += elem
- }
- if (results.size == 0) {
-   throw new UnsupportedOperationException("empty collection")
- } else {
-   return results.reduceLeft(cleanF)
- }
+ var jobResult: Option[T] = None
+ val mergeResult = (index: Int, taskResult: Option[T]) => {
+   if (taskResult != None) {
+     jobResult = jobResult match {
+       case Some(value) => Some(f(value, taskResult.get))
+       case None => taskResult
+     }
+   }
+ }
+ sc.runJob(this, reducePartition, mergeResult)
+ // Get the final result out of our Option, or throw an exception if the RDD was empty
+ jobResult.getOrElse(throw new UnsupportedOperationException("empty collection"))
}
/** /**
@ -386,9 +410,13 @@ abstract class RDD[T: ClassManifest](
* modify t2. * modify t2.
*/ */
def fold(zeroValue: T)(op: (T, T) => T): T = { def fold(zeroValue: T)(op: (T, T) => T): T = {
// Clone the zero value since we will also be serializing it as part of tasks
var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
val cleanOp = sc.clean(op) val cleanOp = sc.clean(op)
- val results = sc.runJob(this, (iter: Iterator[T]) => iter.fold(zeroValue)(cleanOp))
- return results.fold(zeroValue)(cleanOp)
+ val foldPartition = (iter: Iterator[T]) => iter.fold(zeroValue)(cleanOp)
+ val mergeResult = (index: Int, taskResult: T) => jobResult = op(jobResult, taskResult)
+ sc.runJob(this, foldPartition, mergeResult)
+ jobResult
} }
/** /**
@ -400,11 +428,14 @@ abstract class RDD[T: ClassManifest](
* allocation. * allocation.
*/ */
def aggregate[U: ClassManifest](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = { def aggregate[U: ClassManifest](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = {
// Clone the zero value since we will also be serializing it as part of tasks
var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
val cleanSeqOp = sc.clean(seqOp) val cleanSeqOp = sc.clean(seqOp)
val cleanCombOp = sc.clean(combOp) val cleanCombOp = sc.clean(combOp)
- val results = sc.runJob(this,
-   (iter: Iterator[T]) => iter.aggregate(zeroValue)(cleanSeqOp, cleanCombOp))
- return results.fold(zeroValue)(cleanCombOp)
+ val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
+ val mergeResult = (index: Int, taskResult: U) => jobResult = combOp(jobResult, taskResult)
+ sc.runJob(this, aggregatePartition, mergeResult)
+ jobResult
} }
/** /**
@ -415,7 +446,7 @@ abstract class RDD[T: ClassManifest](
var result = 0L var result = 0L
while (iter.hasNext) { while (iter.hasNext) {
result += 1L result += 1L
- iter.next
+ iter.next()
} }
result result
}).sum }).sum
@ -430,7 +461,7 @@ abstract class RDD[T: ClassManifest](
var result = 0L var result = 0L
while (iter.hasNext) { while (iter.hasNext) {
result += 1L result += 1L
- iter.next
+ iter.next()
} }
result result
} }
@ -567,15 +598,15 @@ abstract class RDD[T: ClassManifest](
/** /**
* Return whether this RDD has been checkpointed or not * Return whether this RDD has been checkpointed or not
*/ */
- def isCheckpointed(): Boolean = {
-   if (checkpointData.isDefined) checkpointData.get.isCheckpointed() else false
+ def isCheckpointed: Boolean = {
+   checkpointData.map(_.isCheckpointed).getOrElse(false)
} }
/** /**
* Gets the name of the file to which this RDD was checkpointed * Gets the name of the file to which this RDD was checkpointed
*/ */
- def getCheckpointFile(): Option[String] = {
-   if (checkpointData.isDefined) checkpointData.get.getCheckpointFile() else None
+ def getCheckpointFile: Option[String] = {
+   checkpointData.flatMap(_.getCheckpointFile)
} }
// ======================================================================= // =======================================================================
@ -600,31 +631,52 @@ abstract class RDD[T: ClassManifest](
def context = sc def context = sc
/** /**
- * Performs the checkpointing of this RDD by saving this . It is called by the DAGScheduler
+ * Performs the checkpointing of this RDD by saving this. It is called by the DAGScheduler
* after a job using this RDD has completed (therefore the RDD has been materialized and * after a job using this RDD has completed (therefore the RDD has been materialized and
* potentially stored in memory). doCheckpoint() is called recursively on the parent RDDs. * potentially stored in memory). doCheckpoint() is called recursively on the parent RDDs.
*/ */
- protected[spark] def doCheckpoint() {
-   if (checkpointData.isDefined) checkpointData.get.doCheckpoint()
-   dependencies.foreach(_.rdd.doCheckpoint())
+ private[spark] def doCheckpoint() {
+   if (checkpointData.isDefined) {
+     checkpointData.get.doCheckpoint()
+   } else {
+     dependencies.foreach(_.rdd.doCheckpoint())
+   }
}
/**
- * Changes the dependencies of this RDD from its original parents to the new RDD
- * (`newRDD`) created from the checkpoint file.
+ * Changes the dependencies of this RDD from its original parents to a new RDD (`newRDD`)
+ * created from the checkpoint file, and forget its old dependencies and splits.
 */
- protected[spark] def changeDependencies(newRDD: RDD[_]) {
+ private[spark] def markCheckpointed(checkpointRDD: RDD[_]) {
  clearDependencies()
-   dependencies_ = List(new OneToOneDependency(newRDD))
+   dependencies_ = null
+   splits_ = null
+   deps = null // Forget the constructor argument for dependencies too
}
/** /**
* Clears the dependencies of this RDD. This method must ensure that all references * Clears the dependencies of this RDD. This method must ensure that all references
* to the original parent RDDs is removed to enable the parent RDDs to be garbage * to the original parent RDDs is removed to enable the parent RDDs to be garbage
* collected. Subclasses of RDD may override this method for implementing their own cleaning * collected. Subclasses of RDD may override this method for implementing their own cleaning
- * logic. See [[spark.rdd.UnionRDD]] and [[spark.rdd.ShuffledRDD]] to get a better idea.
+ * logic. See [[spark.rdd.UnionRDD]] for an example.
 */
- protected[spark] def clearDependencies() {
+ protected def clearDependencies() {
dependencies_ = null dependencies_ = null
} }
/** A description of this RDD and its recursive dependencies for debugging. */
def toDebugString(): String = {
def debugString(rdd: RDD[_], prefix: String = ""): Seq[String] = {
Seq(prefix + rdd + " (" + rdd.splits.size + " splits)") ++
rdd.dependencies.flatMap(d => debugString(d.rdd, prefix + " "))
}
debugString(this).mkString("\n")
}
override def toString(): String = "%s%s[%d] at %s".format(
Option(name).map(_ + " ").getOrElse(""),
getClass.getSimpleName,
id,
origin)
} }
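The rewritten reduce, fold and aggregate above merge each task's result into the driver-side answer as it arrives (through the runJob callback) instead of collecting all partial results first; callers see the same behaviour. A sketch, assuming an existing SparkContext sc:

val nums = sc.parallelize(1 to 100, 4)
nums.reduce(_ + _)               // 5050, task results merged one at a time on the driver
nums.fold(0)(_ + _)              // 5050, the zero value is cloned before being shipped with tasks
nums.aggregate(0)(_ + _, _ + _)  // 5050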

View file

@ -20,7 +20,7 @@ private[spark] object CheckpointState extends Enumeration {
* of the checkpointed RDD. * of the checkpointed RDD.
*/ */
private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T]) private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T])
extends Logging with Serializable { extends Logging with Serializable {
import CheckpointState._ import CheckpointState._
@ -31,7 +31,7 @@ extends Logging with Serializable {
@transient var cpFile: Option[String] = None @transient var cpFile: Option[String] = None
// The CheckpointRDD created from the checkpoint file, that is, the new parent the associated RDD. // The CheckpointRDD created from the checkpoint file, that is, the new parent the associated RDD.
- @transient var cpRDD: Option[RDD[T]] = None
+ var cpRDD: Option[RDD[T]] = None
// Mark the RDD for checkpointing // Mark the RDD for checkpointing
def markForCheckpoint() { def markForCheckpoint() {
@ -41,12 +41,12 @@ extends Logging with Serializable {
} }
// Is the RDD already checkpointed // Is the RDD already checkpointed
- def isCheckpointed(): Boolean = {
+ def isCheckpointed: Boolean = {
RDDCheckpointData.synchronized { cpState == Checkpointed } RDDCheckpointData.synchronized { cpState == Checkpointed }
} }
// Get the file to which this RDD was checkpointed to as an Option // Get the file to which this RDD was checkpointed to as an Option
- def getCheckpointFile(): Option[String] = {
+ def getCheckpointFile: Option[String] = {
RDDCheckpointData.synchronized { cpFile } RDDCheckpointData.synchronized { cpFile }
} }
@ -71,7 +71,7 @@ extends Logging with Serializable {
RDDCheckpointData.synchronized { RDDCheckpointData.synchronized {
cpFile = Some(path) cpFile = Some(path)
cpRDD = Some(newRDD) cpRDD = Some(newRDD)
- rdd.changeDependencies(newRDD)
+ rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and splits
cpState = Checkpointed cpState = Checkpointed
RDDCheckpointData.clearTaskCaches() RDDCheckpointData.clearTaskCaches()
logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id) logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id)
@ -79,7 +79,7 @@ extends Logging with Serializable {
} }
// Get preferred location of a split after checkpointing // Get preferred location of a split after checkpointing
- def getPreferredLocations(split: Split) = {
+ def getPreferredLocations(split: Split): Seq[String] = {
RDDCheckpointData.synchronized { RDDCheckpointData.synchronized {
cpRDD.get.preferredLocations(split) cpRDD.get.preferredLocations(split)
} }
@ -91,9 +91,10 @@ extends Logging with Serializable {
} }
} }
- // Get iterator. This is called at the worker nodes.
- def iterator(split: Split, context: TaskContext): Iterator[T] = {
-   rdd.firstParent[T].iterator(split, context)
+ def checkpointRDD: Option[RDD[T]] = {
+   RDDCheckpointData.synchronized {
+     cpRDD
+   }
}
} }

View file

@ -1,6 +1,7 @@
package spark package spark
import java.io._ import java.io._
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.atomic.AtomicInteger
import java.net.{URI, URLClassLoader} import java.net.{URI, URLClassLoader}
import java.lang.ref.WeakReference import java.lang.ref.WeakReference
@ -8,6 +9,7 @@ import java.lang.ref.WeakReference
import scala.collection.Map import scala.collection.Map
import scala.collection.generic.Growable import scala.collection.generic.Growable
import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.collection.JavaConversions._
import akka.actor.Actor import akka.actor.Actor
import akka.actor.Actor._ import akka.actor.Actor._
@ -42,6 +44,9 @@ import scheduler.{ResultTask, ShuffleMapTask, DAGScheduler, TaskScheduler}
import spark.scheduler.local.LocalScheduler import spark.scheduler.local.LocalScheduler
import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler} import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler}
import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
import storage.BlockManagerUI
import util.{MetadataCleaner, TimeStampedHashMap}
import storage.{StorageStatus, StorageUtils, RDDInfo}
/** /**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
@ -57,59 +62,55 @@ import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend
class SparkContext( class SparkContext(
val master: String, val master: String,
val jobName: String, val jobName: String,
- val sparkHome: String,
- val jars: Seq[String],
- environment: Map[String, String])
+ val sparkHome: String = null,
+ val jars: Seq[String] = Nil,
+ environment: Map[String, String] = Map())
extends Logging { extends Logging {
/**
* @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
* @param jobName A name for your job, to display on the cluster web UI
* @param sparkHome Location where Spark is installed on cluster nodes.
* @param jars Collection of JARs to send to the cluster. These can be paths on the local file
* system or HDFS, HTTP, HTTPS, or FTP URLs.
*/
def this(master: String, jobName: String, sparkHome: String, jars: Seq[String]) =
this(master, jobName, sparkHome, jars, Map())
/**
* @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
* @param jobName A name for your job, to display on the cluster web UI
*/
def this(master: String, jobName: String) = this(master, jobName, null, Nil, Map())
// Ensure logging is initialized before we spawn any threads // Ensure logging is initialized before we spawn any threads
initLogging() initLogging()
- // Set Spark master host and port system properties
+ // Set Spark driver host and port system properties
- if (System.getProperty("spark.master.host") == null) {
-   System.setProperty("spark.master.host", Utils.localIpAddress)
+ if (System.getProperty("spark.driver.host") == null) {
+   System.setProperty("spark.driver.host", Utils.localIpAddress)
}
- if (System.getProperty("spark.master.port") == null) {
-   System.setProperty("spark.master.port", "0")
+ if (System.getProperty("spark.driver.port") == null) {
+   System.setProperty("spark.driver.port", "0")
}
private val isLocal = (master == "local" || master.startsWith("local[")) private val isLocal = (master == "local" || master.startsWith("local["))
// Create the Spark execution environment (cache, map output tracker, etc) // Create the Spark execution environment (cache, map output tracker, etc)
private[spark] val env = SparkEnv.createFromSystemProperties( private[spark] val env = SparkEnv.createFromSystemProperties(
System.getProperty("spark.master.host"), "<driver>",
System.getProperty("spark.master.port").toInt, System.getProperty("spark.driver.host"),
System.getProperty("spark.driver.port").toInt,
true, true,
isLocal) isLocal)
SparkEnv.set(env) SparkEnv.set(env)
// Start the BlockManager UI
private[spark] val ui = new BlockManagerUI(
env.actorSystem, env.blockManager.master.driverActor, this)
ui.start()
// Used to store a URL for each static file/jar together with the file's local timestamp // Used to store a URL for each static file/jar together with the file's local timestamp
private[spark] val addedFiles = HashMap[String, Long]() private[spark] val addedFiles = HashMap[String, Long]()
private[spark] val addedJars = HashMap[String, Long]() private[spark] val addedJars = HashMap[String, Long]()
// Keeps track of all persisted RDDs
private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]()
private[spark] val metadataCleaner = new MetadataCleaner("SparkContext", this.cleanup)
// Add each JAR given through the constructor // Add each JAR given through the constructor
jars.foreach { addJar(_) } jars.foreach { addJar(_) }
// Environment variables to pass to our executors // Environment variables to pass to our executors
private[spark] val executorEnvs = HashMap[String, String]() private[spark] val executorEnvs = HashMap[String, String]()
// Note: SPARK_MEM is included for Mesos, but overwritten for standalone mode in ExecutorRunner
for (key <- Seq("SPARK_MEM", "SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS", for (key <- Seq("SPARK_MEM", "SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS",
"SPARK_TESTING")) { "SPARK_TESTING")) {
val value = System.getenv(key) val value = System.getenv(key)
if (value != null) { if (value != null) {
executorEnvs(key) = value executorEnvs(key) = value
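Because sparkHome, jars and environment now have default values, the auxiliary constructors removed above are no longer needed; the simplest form is a sketch like:

import spark.SparkContext

val sc = new SparkContext("local[2]", "example")   // equivalent to the old two-argument constructor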
@ -127,6 +128,8 @@ class SparkContext(
val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r
// Regular expression for connecting to Spark deploy clusters // Regular expression for connecting to Spark deploy clusters
val SPARK_REGEX = """(spark://.*)""".r val SPARK_REGEX = """(spark://.*)""".r
//Regular expression for connection to Mesos cluster
val MESOS_REGEX = """(mesos://.*)""".r
master match { master match {
case "local" => case "local" =>
@ -167,6 +170,9 @@ class SparkContext(
scheduler scheduler
case _ => case _ =>
if (MESOS_REGEX.findFirstIn(master).isEmpty) {
logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master))
}
MesosNativeLibrary.load() MesosNativeLibrary.load()
val scheduler = new ClusterScheduler(this) val scheduler = new ClusterScheduler(this)
val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean
@ -183,6 +189,26 @@ class SparkContext(
taskScheduler.start() taskScheduler.start()
private var dagScheduler = new DAGScheduler(taskScheduler) private var dagScheduler = new DAGScheduler(taskScheduler)
dagScheduler.start()
/** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
val hadoopConfiguration = {
val conf = new Configuration()
// Explicitly check for S3 environment variables
if (System.getenv("AWS_ACCESS_KEY_ID") != null && System.getenv("AWS_SECRET_ACCESS_KEY") != null) {
conf.set("fs.s3.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID"))
conf.set("fs.s3n.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID"))
conf.set("fs.s3.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY"))
conf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY"))
}
// Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar"
for (key <- System.getProperties.toMap[String, String].keys if key.startsWith("spark.hadoop.")) {
conf.set(key.substring("spark.hadoop.".length), System.getProperty(key))
}
val bufferSize = System.getProperty("spark.buffer.size", "65536")
conf.set("io.file.buffer.size", bufferSize)
conf
}
private[spark] var checkpointDir: Option[String] = None private[spark] var checkpointDir: Option[String] = None
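The spark.hadoop.* convention introduced above lets arbitrary Hadoop settings be fed in through system properties; a sketch (the property chosen is only an example):

import spark.SparkContext

System.setProperty("spark.hadoop.mapred.output.compress", "true")  // set before creating the context
val sc = new SparkContext("local", "hadoopConfDemo")
sc.hadoopConfiguration.get("mapred.output.compress")               // "true"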
@ -238,10 +264,8 @@ class SparkContext(
valueClass: Class[V], valueClass: Class[V],
minSplits: Int = defaultMinSplits minSplits: Int = defaultMinSplits
) : RDD[(K, V)] = { ) : RDD[(K, V)] = {
- val conf = new JobConf()
+ val conf = new JobConf(hadoopConfiguration)
FileInputFormat.setInputPaths(conf, path) FileInputFormat.setInputPaths(conf, path)
val bufferSize = System.getProperty("spark.buffer.size", "65536")
conf.set("io.file.buffer.size", bufferSize)
new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits) new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits)
} }
@ -282,8 +306,7 @@ class SparkContext(
path, path,
fm.erasure.asInstanceOf[Class[F]], fm.erasure.asInstanceOf[Class[F]],
km.erasure.asInstanceOf[Class[K]], km.erasure.asInstanceOf[Class[K]],
vm.erasure.asInstanceOf[Class[V]], vm.erasure.asInstanceOf[Class[V]])
new Configuration)
} }
/** /**
@ -295,7 +318,7 @@ class SparkContext(
fClass: Class[F], fClass: Class[F],
kClass: Class[K], kClass: Class[K],
vClass: Class[V], vClass: Class[V],
conf: Configuration): RDD[(K, V)] = { conf: Configuration = hadoopConfiguration): RDD[(K, V)] = {
val job = new NewHadoopJob(conf) val job = new NewHadoopJob(conf)
NewFileInputFormat.addInputPath(job, new Path(path)) NewFileInputFormat.addInputPath(job, new Path(path))
val updatedConf = job.getConfiguration val updatedConf = job.getConfiguration
@ -307,7 +330,7 @@ class SparkContext(
* and extra configuration options to pass to the input format. * and extra configuration options to pass to the input format.
*/ */
def newAPIHadoopRDD[K, V, F <: NewInputFormat[K, V]]( def newAPIHadoopRDD[K, V, F <: NewInputFormat[K, V]](
conf: Configuration, conf: Configuration = hadoopConfiguration,
fClass: Class[F], fClass: Class[F],
kClass: Class[K], kClass: Class[K],
vClass: Class[V]): RDD[(K, V)] = { vClass: Class[V]): RDD[(K, V)] = {
@ -390,14 +413,14 @@ class SparkContext(
/** /**
* Create an [[spark.Accumulator]] variable of a given type, which tasks can "add" values * Create an [[spark.Accumulator]] variable of a given type, which tasks can "add" values
* to using the `+=` method. Only the master can access the accumulator's `value`. * to using the `+=` method. Only the driver can access the accumulator's `value`.
*/ */
def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) = def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) =
new Accumulator(initialValue, param) new Accumulator(initialValue, param)
/** /**
* Create an [[spark.Accumulable]] shared variable, to which tasks can add values with `+=`. * Create an [[spark.Accumulable]] shared variable, to which tasks can add values with `+=`.
* Only the master can access the accumulable's `value`. * Only the driver can access the accumulable's `value`.
* @tparam T accumulator type * @tparam T accumulator type
* @tparam R type that can be added to the accumulator * @tparam R type that can be added to the accumulator
*/ */
@ -422,9 +445,10 @@ class SparkContext(
def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal) def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal)
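A short usage sketch of the shared-variable APIs documented above, reflecting the master-to-driver wording change: accumulators are written with `+=` inside tasks and read only on the driver, while broadcast values are read through `.value` on the executors. The data and app name are made up for illustration.

import spark.SparkContext
import spark.SparkContext._

val sc = new SparkContext("local", "sharedVarsDemo")

val sum = sc.accumulator(0)                       // uses the implicit IntAccumulatorParam
val lookup = sc.broadcast(Map(1 -> "a", 2 -> "b"))

sc.parallelize(1 to 100).foreach { i =>
  sum += i                                        // tasks may only add
  lookup.value.get(i)                             // read-only access to the broadcast value
}
println(sum.value)                                // readable only on the driver: 5050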
/** /**
* Add a file to be downloaded into the working directory of this Spark job on every node. * Add a file to be downloaded with this Spark job on every node.
* The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
* filesystems), or an HTTP, HTTPS or FTP URI. * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs,
* use `SparkFiles.get(path)` to find its download location.
*/ */
def addFile(path: String) { def addFile(path: String) {
val uri = new URI(path) val uri = new URI(path)
@ -437,7 +461,7 @@ class SparkContext(
// Fetch the file locally in case a job is executed locally. // Fetch the file locally in case a job is executed locally.
// Jobs that run through LocalScheduler will already fetch the required dependencies, // Jobs that run through LocalScheduler will already fetch the required dependencies,
// but jobs run in DAGScheduler.runLocally() will not so we must fetch the files here. // but jobs run in DAGScheduler.runLocally() will not so we must fetch the files here.
Utils.fetchFile(path, new File(".")) Utils.fetchFile(path, new File(SparkFiles.getRootDirectory))
logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key)) logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
} }
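A sketch of the revised `addFile` contract described above: the path passed to `addFile` is only the upload source, and tasks resolve their local copy through `SparkFiles.get`. The file path here is hypothetical.

import spark.{SparkContext, SparkFiles}

val sc = new SparkContext("local", "addFileDemo")
sc.addFile("/tmp/lookup.txt")                      // hypothetical local file shipped to every node

val lineCounts = sc.parallelize(1 to 4).map { _ =>
  // Resolved against SparkFiles.getRootDirectory on whichever node runs the task
  val localPath = SparkFiles.get("lookup.txt")
  scala.io.Source.fromFile(localPath).getLines().size
}.collect()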
@ -446,12 +470,27 @@ class SparkContext(
* Return a map from the slave to the max memory available for caching and the remaining * Return a map from the slave to the max memory available for caching and the remaining
* memory available for caching. * memory available for caching.
*/ */
def getSlavesMemoryStatus: Map[String, (Long, Long)] = { def getExecutorMemoryStatus: Map[String, (Long, Long)] = {
env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) => env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) =>
(blockManagerId.ip + ":" + blockManagerId.port, mem) (blockManagerId.ip + ":" + blockManagerId.port, mem)
} }
} }
/**
* Return information about what RDDs are cached, if they are in mem or on disk, how much space
* they take, etc.
*/
def getRDDStorageInfo : Array[RDDInfo] = {
StorageUtils.rddInfoFromStorageStatus(getExecutorStorageStatus, this)
}
/**
* Return information about blocks stored in all of the slaves
*/
def getExecutorStorageStatus : Array[StorageStatus] = {
env.blockManager.master.getStorageStatus
}
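The renamed status methods above allow a quick cluster health dump from the driver; a minimal sketch (the output formatting is illustrative only, and `sc` is an existing SparkContext):

// Max vs. remaining cache memory per executor, keyed by "ip:port"
for ((executor, (maxMem, remaining)) <- sc.getExecutorMemoryStatus) {
  println(executor + ": " + remaining + " / " + maxMem + " bytes free")
}

// Which RDDs are cached and how much space they take
sc.getRDDStorageInfo.foreach(info => println(info))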
/** /**
* Clear the job's list of files added by `addFile` so that they do not get downloaded to * Clear the job's list of files added by `addFile` so that they do not get downloaded to
* any new nodes. * any new nodes.
@ -486,6 +525,7 @@ class SparkContext(
/** Shut down the SparkContext. */ /** Shut down the SparkContext. */
def stop() { def stop() {
if (dagScheduler != null) { if (dagScheduler != null) {
metadataCleaner.cancel()
dagScheduler.stop() dagScheduler.stop()
dagScheduler = null dagScheduler = null
taskScheduler = null taskScheduler = null
@ -521,10 +561,30 @@ class SparkContext(
} }
/** /**
* Run a function on a given set of partitions in an RDD and return the results. This is the main * Run a function on a given set of partitions in an RDD and pass the results to the given
* entry point to the scheduler, by which all actions get launched. The allowLocal flag specifies * handler function. This is the main entry point for all actions in Spark. The allowLocal
* whether the scheduler can run the computation on the master rather than shipping it out to the * flag specifies whether the scheduler can run the computation on the driver rather than
* cluster, for short actions like first(). * shipping it out to the cluster, for short actions like first().
*/
def runJob[T, U: ClassManifest](
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
allowLocal: Boolean,
resultHandler: (Int, U) => Unit) {
val callSite = Utils.getSparkCallSite
logInfo("Starting job: " + callSite)
val start = System.nanoTime
val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler)
logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
rdd.doCheckpoint()
result
}
/**
* Run a function on a given set of partitions in an RDD and return the results as an array. The
* allowLocal flag specifies whether the scheduler can run the computation on the driver rather
* than shipping it out to the cluster, for short actions like first().
*/ */
def runJob[T, U: ClassManifest]( def runJob[T, U: ClassManifest](
rdd: RDD[T], rdd: RDD[T],
@ -532,13 +592,9 @@ class SparkContext(
partitions: Seq[Int], partitions: Seq[Int],
allowLocal: Boolean allowLocal: Boolean
): Array[U] = { ): Array[U] = {
val callSite = Utils.getSparkCallSite val results = new Array[U](partitions.size)
logInfo("Starting job: " + callSite) runJob[T, U](rdd, func, partitions, allowLocal, (index, res) => results(index) = res)
val start = System.nanoTime results
val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal)
logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
rdd.doCheckpoint()
result
} }
/** /**
@ -568,6 +624,29 @@ class SparkContext(
runJob(rdd, func, 0 until rdd.splits.size, false) runJob(rdd, func, 0 until rdd.splits.size, false)
} }
/**
* Run a job on all partitions in an RDD and pass the results to a handler function.
*/
def runJob[T, U: ClassManifest](
rdd: RDD[T],
processPartition: (TaskContext, Iterator[T]) => U,
resultHandler: (Int, U) => Unit)
{
runJob[T, U](rdd, processPartition, 0 until rdd.splits.size, false, resultHandler)
}
/**
* Run a job on all partitions in an RDD and pass the results to a handler function.
*/
def runJob[T, U: ClassManifest](
rdd: RDD[T],
processPartition: Iterator[T] => U,
resultHandler: (Int, U) => Unit)
{
val processFunc = (context: TaskContext, iter: Iterator[T]) => processPartition(iter)
runJob[T, U](rdd, processFunc, 0 until rdd.splits.size, false, resultHandler)
}
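A sketch of the new `resultHandler` entry point: instead of materializing an Array[U], each partition's result is handed to a callback as it arrives on the driver. Everything other than the `runJob` signature itself is illustrative.

val data = sc.parallelize(1 to 1000, 4)
val partialSums = new Array[Int](4)

sc.runJob(data,
  (iter: Iterator[Int]) => iter.sum,                          // computed on each partition
  (index: Int, result: Int) => partialSums(index) = result)   // invoked on the driver per partition

println(partialSums.mkString(", "))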
/** /**
* Run a job that can return approximate results. * Run a job that can return approximate results.
*/ */
@ -628,6 +707,11 @@ class SparkContext(
/** Register a new RDD, returning its RDD ID */ /** Register a new RDD, returning its RDD ID */
private[spark] def newRddId(): Int = nextRddId.getAndIncrement() private[spark] def newRddId(): Int = nextRddId.getAndIncrement()
/** Called by MetadataCleaner to clean up the persistentRdds map periodically */
private[spark] def cleanup(cleanupTime: Long) {
persistentRdds.clearOldValues(cleanupTime)
}
} }
/** /**
@ -646,6 +730,16 @@ object SparkContext {
def zero(initialValue: Int) = 0 def zero(initialValue: Int) = 0
} }
implicit object LongAccumulatorParam extends AccumulatorParam[Long] {
def addInPlace(t1: Long, t2: Long) = t1 + t2
def zero(initialValue: Long) = 0l
}
implicit object FloatAccumulatorParam extends AccumulatorParam[Float] {
def addInPlace(t1: Float, t2: Float) = t1 + t2
def zero(initialValue: Float) = 0f
}
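The Long and Float params above follow the same two-method pattern, so extending it to the types mentioned in the TODO below is mechanical. A hypothetical example for accumulating a list of strings (not part of this change; it would live in user code):

import spark.AccumulatorParam

implicit object StringListAccumulatorParam extends AccumulatorParam[List[String]] {
  def addInPlace(t1: List[String], t2: List[String]) = t1 ::: t2
  def zero(initialValue: List[String]) = Nil
}

// With the implicit in scope:
//   val warnings = sc.accumulator(List[String]())
//   rdd.foreach(x => if (x < 0) warnings += List("negative value: " + x))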
// TODO: Add AccumulatorParams for other types, e.g. lists and strings // TODO: Add AccumulatorParams for other types, e.g. lists and strings
implicit def rddToPairRDDFunctions[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) = implicit def rddToPairRDDFunctions[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) =

View file

@ -19,27 +19,23 @@ import spark.util.AkkaUtils
* SparkEnv.get (e.g. after creating a SparkContext) and set it with SparkEnv.set. * SparkEnv.get (e.g. after creating a SparkContext) and set it with SparkEnv.set.
*/ */
class SparkEnv ( class SparkEnv (
val executorId: String,
val actorSystem: ActorSystem, val actorSystem: ActorSystem,
val serializer: Serializer, val serializer: Serializer,
val closureSerializer: Serializer, val closureSerializer: Serializer,
val cacheTracker: CacheTracker, val cacheManager: CacheManager,
val mapOutputTracker: MapOutputTracker, val mapOutputTracker: MapOutputTracker,
val shuffleFetcher: ShuffleFetcher, val shuffleFetcher: ShuffleFetcher,
val broadcastManager: BroadcastManager, val broadcastManager: BroadcastManager,
val blockManager: BlockManager, val blockManager: BlockManager,
val connectionManager: ConnectionManager, val connectionManager: ConnectionManager,
val httpFileServer: HttpFileServer val httpFileServer: HttpFileServer,
val sparkFilesDir: String
) { ) {
/** No-parameter constructor for unit tests. */
def this() = {
this(null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null, null)
}
def stop() { def stop() {
httpFileServer.stop() httpFileServer.stop()
mapOutputTracker.stop() mapOutputTracker.stop()
cacheTracker.stop()
shuffleFetcher.stop() shuffleFetcher.stop()
broadcastManager.stop() broadcastManager.stop()
blockManager.stop() blockManager.stop()
@ -63,17 +59,18 @@ object SparkEnv extends Logging {
} }
def createFromSystemProperties( def createFromSystemProperties(
executorId: String,
hostname: String, hostname: String,
port: Int, port: Int,
isMaster: Boolean, isDriver: Boolean,
isLocal: Boolean isLocal: Boolean): SparkEnv = {
) : SparkEnv = {
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port) val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port)
// Bit of a hack: If this is the master and our port was 0 (meaning bind to any free port), // Bit of a hack: If this is the driver and our port was 0 (meaning bind to any free port),
// figure out which port number Akka actually bound to and set spark.master.port to it. // figure out which port number Akka actually bound to and set spark.driver.port to it.
if (isMaster && port == 0) { if (isDriver && port == 0) {
System.setProperty("spark.master.port", boundPort.toString) System.setProperty("spark.driver.port", boundPort.toString)
} }
val classLoader = Thread.currentThread.getContextClassLoader val classLoader = Thread.currentThread.getContextClassLoader
@ -87,23 +84,22 @@ object SparkEnv extends Logging {
val serializer = instantiateClass[Serializer]("spark.serializer", "spark.JavaSerializer") val serializer = instantiateClass[Serializer]("spark.serializer", "spark.JavaSerializer")
val masterIp: String = System.getProperty("spark.master.host", "localhost") val driverIp: String = System.getProperty("spark.driver.host", "localhost")
val masterPort: Int = System.getProperty("spark.master.port", "7077").toInt val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt
val blockManagerMaster = new BlockManagerMaster( val blockManagerMaster = new BlockManagerMaster(
actorSystem, isMaster, isLocal, masterIp, masterPort) actorSystem, isDriver, isLocal, driverIp, driverPort)
val blockManager = new BlockManager(actorSystem, blockManagerMaster, serializer) val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, serializer)
val connectionManager = blockManager.connectionManager val connectionManager = blockManager.connectionManager
val broadcastManager = new BroadcastManager(isMaster) val broadcastManager = new BroadcastManager(isDriver)
val closureSerializer = instantiateClass[Serializer]( val closureSerializer = instantiateClass[Serializer](
"spark.closure.serializer", "spark.JavaSerializer") "spark.closure.serializer", "spark.JavaSerializer")
val cacheTracker = new CacheTracker(actorSystem, isMaster, blockManager) val cacheManager = new CacheManager(blockManager)
blockManager.cacheTracker = cacheTracker
val mapOutputTracker = new MapOutputTracker(actorSystem, isMaster) val mapOutputTracker = new MapOutputTracker(actorSystem, isDriver)
val shuffleFetcher = instantiateClass[ShuffleFetcher]( val shuffleFetcher = instantiateClass[ShuffleFetcher](
"spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher") "spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher")
@ -112,6 +108,15 @@ object SparkEnv extends Logging {
httpFileServer.initialize() httpFileServer.initialize()
System.setProperty("spark.fileserver.uri", httpFileServer.serverUri) System.setProperty("spark.fileserver.uri", httpFileServer.serverUri)
// Set the sparkFiles directory, used when downloading dependencies. In local mode,
// this is a temporary directory; in distributed mode, this is the executor's current working
// directory.
val sparkFilesDir: String = if (isDriver) {
Utils.createTempDir().getAbsolutePath
} else {
"."
}
// Warn about deprecated spark.cache.class property // Warn about deprecated spark.cache.class property
if (System.getProperty("spark.cache.class") != null) { if (System.getProperty("spark.cache.class") != null) {
logWarning("The spark.cache.class property is no longer being used! Specify storage " + logWarning("The spark.cache.class property is no longer being used! Specify storage " +
@ -119,15 +124,17 @@ object SparkEnv extends Logging {
} }
new SparkEnv( new SparkEnv(
executorId,
actorSystem, actorSystem,
serializer, serializer,
closureSerializer, closureSerializer,
cacheTracker, cacheManager,
mapOutputTracker, mapOutputTracker,
shuffleFetcher, shuffleFetcher,
broadcastManager, broadcastManager,
blockManager, blockManager,
connectionManager, connectionManager,
httpFileServer) httpFileServer,
sparkFilesDir)
} }
} }

View file

@ -0,0 +1,25 @@
package spark;
import java.io.File;
/**
* Resolves paths to files added through `SparkContext.addFile()`.
*/
public class SparkFiles {
private SparkFiles() {}
/**
* Get the absolute path of a file added through `SparkContext.addFile()`.
*/
public static String get(String filename) {
return new File(getRootDirectory(), filename).getAbsolutePath();
}
/**
* Get the root directory that contains files added through `SparkContext.addFile()`.
*/
public static String getRootDirectory() {
return SparkEnv.get().sparkFilesDir();
}
}

View file

@ -5,8 +5,7 @@ import scala.collection.mutable.ArrayBuffer
class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Long) extends Serializable { class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Long) extends Serializable {
@transient @transient val onCompleteCallbacks = new ArrayBuffer[() => Unit]
val onCompleteCallbacks = new ArrayBuffer[() => Unit]
// Add a callback function to be executed on task completion. An example use // Add a callback function to be executed on task completion. An example use
// is for HadoopRDD to register a callback to close the input stream. // is for HadoopRDD to register a callback to close the input stream.

View file

@ -1,7 +1,7 @@
package spark package spark
import java.io._ import java.io._
import java.net.{NetworkInterface, InetAddress, URL, URI} import java.net._
import java.util.{Locale, Random, UUID} import java.util.{Locale, Random, UUID}
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor} import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
import org.apache.hadoop.conf.Configuration import org.apache.hadoop.conf.Configuration
@ -10,6 +10,9 @@ import scala.collection.mutable.ArrayBuffer
import scala.collection.JavaConversions._ import scala.collection.JavaConversions._
import scala.io.Source import scala.io.Source
import com.google.common.io.Files import com.google.common.io.Files
import com.google.common.util.concurrent.ThreadFactoryBuilder
import scala.Some
import spark.serializer.SerializerInstance
/** /**
* Various utility methods used by Spark. * Various utility methods used by Spark.
@ -111,20 +114,6 @@ private object Utils extends Logging {
} }
} }
/** Copy a file on the local file system */
def copyFile(source: File, dest: File) {
val in = new FileInputStream(source)
val out = new FileOutputStream(dest)
copyStream(in, out, true)
}
/** Download a file from a given URL to the local filesystem */
def downloadFile(url: URL, localPath: String) {
val in = url.openStream()
val out = new FileOutputStream(localPath)
Utils.copyStream(in, out, true)
}
/** /**
* Download a file requested by the executor. Supports fetching the file in a variety of ways, * Download a file requested by the executor. Supports fetching the file in a variety of ways,
* including HTTP, HDFS and files on a standard filesystem, based on the URL parameter. * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
@ -201,7 +190,7 @@ private object Utils extends Logging {
Utils.execute(Seq("tar", "-xf", filename), targetDir) Utils.execute(Seq("tar", "-xf", filename), targetDir)
} }
// Make the file executable - That's necessary for scripts // Make the file executable - That's necessary for scripts
FileUtil.chmod(filename, "a+x") FileUtil.chmod(targetFile.getAbsolutePath, "a+x")
} }
/** /**
@ -251,7 +240,8 @@ private object Utils extends Logging {
// Address resolves to something like 127.0.1.1, which happens on Debian; try to find // Address resolves to something like 127.0.1.1, which happens on Debian; try to find
// a better address using the local network interfaces // a better address using the local network interfaces
for (ni <- NetworkInterface.getNetworkInterfaces) { for (ni <- NetworkInterface.getNetworkInterfaces) {
for (addr <- ni.getInetAddresses if !addr.isLinkLocalAddress && !addr.isLoopbackAddress) { for (addr <- ni.getInetAddresses if !addr.isLinkLocalAddress &&
!addr.isLoopbackAddress && addr.isInstanceOf[Inet4Address]) {
// We've found an address that looks reasonable! // We've found an address that looks reasonable!
logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" + logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" +
" a loopback address: " + address.getHostAddress + "; using " + addr.getHostAddress + " a loopback address: " + address.getHostAddress + "; using " + addr.getHostAddress +
@ -286,29 +276,14 @@ private object Utils extends Logging {
customHostname.getOrElse(InetAddress.getLocalHost.getHostName) customHostname.getOrElse(InetAddress.getLocalHost.getHostName)
} }
/** private[spark] val daemonThreadFactory: ThreadFactory =
* Returns a standard ThreadFactory except all threads are daemons. new ThreadFactoryBuilder().setDaemon(true).build()
*/
private def newDaemonThreadFactory: ThreadFactory = {
new ThreadFactory {
def newThread(r: Runnable): Thread = {
var t = Executors.defaultThreadFactory.newThread (r)
t.setDaemon (true)
return t
}
}
}
/** /**
* Wrapper over newCachedThreadPool. * Wrapper over newCachedThreadPool.
*/ */
def newDaemonCachedThreadPool(): ThreadPoolExecutor = { def newDaemonCachedThreadPool(): ThreadPoolExecutor =
var threadPool = Executors.newCachedThreadPool.asInstanceOf[ThreadPoolExecutor] Executors.newCachedThreadPool(daemonThreadFactory).asInstanceOf[ThreadPoolExecutor]
threadPool.setThreadFactory (newDaemonThreadFactory)
return threadPool
}
/** /**
* Return the string to tell how long has passed in seconds. The passing parameter should be in * Return the string to tell how long has passed in seconds. The passing parameter should be in
@ -321,13 +296,8 @@ private object Utils extends Logging {
/** /**
* Wrapper over newFixedThreadPool. * Wrapper over newFixedThreadPool.
*/ */
def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = { def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor =
var threadPool = Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor] Executors.newFixedThreadPool(nThreads, daemonThreadFactory).asInstanceOf[ThreadPoolExecutor]
threadPool.setThreadFactory(newDaemonThreadFactory)
return threadPool
}
/** /**
* Delete a file or directory and its contents recursively. * Delete a file or directory and its contents recursively.
@ -463,4 +433,25 @@ private object Utils extends Logging {
} }
"%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine) "%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine)
} }
/**
* Try to find a free port to bind to on the local host. This should ideally never be needed,
* except that, unfortunately, some of the networking libraries we currently rely on (e.g. Spray)
* don't let users bind to port 0 and then figure out which free port they actually bound to.
* We work around this by binding a ServerSocket and immediately unbinding it. This is *not*
* necessarily guaranteed to work, but it's the best we can do.
*/
def findFreePort(): Int = {
val socket = new ServerSocket(0)
val portBound = socket.getLocalPort
socket.close()
portBound
}
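A sketch of how this port-probing helper is meant to be used with libraries that cannot bind to port 0 themselves; `AkkaUtils.createActorSystem` is the in-tree caller shape assumed here, both are `private[spark]`, and there is an inherent race between closing the probe socket and the real bind.

val probedPort = Utils.findFreePort()
// Hand the probed port to a component that insists on a concrete port number.
// (Racy by nature: another process could grab the port in between.)
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("probe", Utils.localIpAddress, probedPort)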
/**
* Clone an object using a Spark serializer.
*/
def clone[T](value: T, serializer: SerializerInstance): T = {
serializer.deserialize[T](serializer.serialize(value))
}
} }

View file

@ -12,7 +12,7 @@ import spark.storage.StorageLevel
import com.google.common.base.Optional import com.google.common.base.Optional
trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround[T] {
def wrapRDD(rdd: RDD[T]): This def wrapRDD(rdd: RDD[T]): This
implicit val classManifest: ClassManifest[T] implicit val classManifest: ClassManifest[T]
@ -82,10 +82,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
} }
/** /**
* Return a new RDD by first applying a function to all elements of this * Part of the workaround for SPARK-668; called in PairFlatMapWorkaround.java.
* RDD, and then flattening the results.
*/ */
def flatMap[K, V](f: PairFlatMapFunction[T, K, V]): JavaPairRDD[K, V] = { private[spark] def doFlatMap[K, V](f: PairFlatMapFunction[T, K, V]): JavaPairRDD[K, V] = {
import scala.collection.JavaConverters._ import scala.collection.JavaConverters._
def fn = (x: T) => f.apply(x).asScala def fn = (x: T) => f.apply(x).asScala
def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]] def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]]
@ -307,23 +306,20 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
implicit val kcm: ClassManifest[K] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]] implicit val kcm: ClassManifest[K] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
JavaPairRDD.fromRDD(rdd.keyBy(f)) JavaPairRDD.fromRDD(rdd.keyBy(f))
} }
/** /**
* Mark this RDD for checkpointing. The RDD will be saved to a file inside `checkpointDir` * Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint
* (set using setCheckpointDir()) and all references to its parent RDDs will be removed. * directory set with SparkContext.setCheckpointDir() and all references to its parent
* This is used to truncate very long lineages. In the current implementation, Spark will save * RDDs will be removed. This function must be called before any job has been
* this RDD to a file (using saveAsObjectFile()) after the first job using this RDD is done. * executed on this RDD. It is strongly recommended that this RDD is persisted in
* Hence, it is strongly recommended to use checkpoint() on RDDs when * memory, otherwise saving it on a file will require recomputation.
* (i) checkpoint() is called before the any job has been executed on this RDD.
* (ii) This RDD has been made to persist in memory. Otherwise saving it on a file will
* require recomputation.
*/ */
def checkpoint() = rdd.checkpoint() def checkpoint() = rdd.checkpoint()
/** /**
* Return whether this RDD has been checkpointed or not * Return whether this RDD has been checkpointed or not
*/ */
def isCheckpointed(): Boolean = rdd.isCheckpointed() def isCheckpointed: Boolean = rdd.isCheckpointed
/** /**
* Gets the name of the file to which this RDD was checkpointed * Gets the name of the file to which this RDD was checkpointed
@ -334,4 +330,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
case _ => Optional.absent() case _ => Optional.absent()
} }
} }
/** A description of this RDD and its recursive dependencies for debugging. */
def toDebugString(): String = {
rdd.toDebugString()
}
} }

View file

@ -323,9 +323,10 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
def getSparkHome(): Option[String] = sc.getSparkHome() def getSparkHome(): Option[String] = sc.getSparkHome()
/** /**
* Add a file to be downloaded into the working directory of this Spark job on every node. * Add a file to be downloaded with this Spark job on every node.
* The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
* filesystems), or an HTTP, HTTPS or FTP URI. * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs,
* use `SparkFiles.get(path)` to find its download location.
*/ */
def addFile(path: String) { def addFile(path: String) {
sc.addFile(path) sc.addFile(path)
@ -357,20 +358,28 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
} }
/** /**
* Set the directory under which RDDs are going to be checkpointed. This method will * Returns the Hadoop configuration used for the Hadoop code (e.g. file systems) we reuse.
* create this directory and will throw an exception of the path already exists (to avoid */
* overwriting existing files may be overwritten). The directory will be deleted on exit def hadoopConfiguration(): Configuration = {
* if indicated. sc.hadoopConfiguration
}
/**
* Set the directory under which RDDs are going to be checkpointed. The directory must
* be a HDFS path if running on a cluster. If the directory does not exist, it will
* be created. If the directory exists and useExisting is set to true, then the
* existing directory will be used. Otherwise an exception will be thrown to
* prevent accidental overriding of checkpoint files in the existing directory.
*/ */
def setCheckpointDir(dir: String, useExisting: Boolean) { def setCheckpointDir(dir: String, useExisting: Boolean) {
sc.setCheckpointDir(dir, useExisting) sc.setCheckpointDir(dir, useExisting)
} }
/** /**
* Set the directory under which RDDs are going to be checkpointed. This method will * Set the directory under which RDDs are going to be checkpointed. The directory must
* create this directory and will throw an exception of the path already exists (to avoid * be a HDFS path if running on a cluster. If the directory does not exist, it will
* overwriting existing files may be overwritten). The directory will be deleted on exit * be created. If the directory exists, an exception will be thrown to prevent accidental
* if indicated. * overriding of checkpoint files.
*/ */
def setCheckpointDir(dir: String) { def setCheckpointDir(dir: String) {
sc.setCheckpointDir(dir) sc.setCheckpointDir(dir)
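A sketch of the checkpointing flow as documented above, using the Scala API for brevity; the HDFS path is a placeholder, and persisting before checkpointing follows the recommendation in the JavaRDDLike comment earlier in this diff.

sc.setCheckpointDir("hdfs://namenode:9000/tmp/spark-checkpoints", true)  // reuse the directory if it exists

val lineage = sc.parallelize(1 to 10).map(_ * 2).filter(_ > 5)
lineage.cache()          // recommended so the checkpoint write does not recompute the RDD
lineage.checkpoint()     // must be called before any job has run on this RDD
lineage.count()          // the checkpoint is materialized after this first job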

View file

@ -0,0 +1,20 @@
package spark.api.java;
import spark.api.java.JavaPairRDD;
import spark.api.java.JavaRDDLike;
import spark.api.java.function.PairFlatMapFunction;
import java.io.Serializable;
/**
* Workaround for SPARK-668.
*/
class PairFlatMapWorkaround<T> implements Serializable {
/**
* Return a new RDD by first applying a function to all elements of this
* RDD, and then flattening the results.
*/
public <K, V> JavaPairRDD<K, V> flatMap(PairFlatMapFunction<T, K, V> f) {
return ((JavaRDDLike <T, ?>) this).doFlatMap(f);
}
}

View file

@ -17,4 +17,15 @@ public class StorageLevels {
public static final StorageLevel MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2); public static final StorageLevel MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2);
public static final StorageLevel MEMORY_AND_DISK_SER = new StorageLevel(true, true, false, 1); public static final StorageLevel MEMORY_AND_DISK_SER = new StorageLevel(true, true, false, 1);
public static final StorageLevel MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2); public static final StorageLevel MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2);
/**
* Create a new StorageLevel object.
* @param useDisk saved to disk, if true
* @param useMemory saved to memory, if true
* @param deserialized saved as deserialized objects, if true
* @param replication replication factor
*/
public static StorageLevel create(boolean useDisk, boolean useMemory, boolean deserialized, int replication) {
return StorageLevel.apply(useDisk, useMemory, deserialized, replication);
}
} }
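The new Java factory above simply delegates to `StorageLevel.apply`; from Scala the equivalent level would be constructed directly. A small sketch with an illustrative replication factor:

import spark.storage.StorageLevel

// Memory-only, serialized, replicated to two nodes -- the same flags the Java
// StorageLevels.create(false, true, false, 2) call would produce.
val memorySer2 = StorageLevel(false, true, false, 2)
sc.parallelize(1 to 100).persist(memorySer2)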

View file

@ -6,8 +6,17 @@ import java.util.Arrays
/** /**
* A [[spark.Partitioner]] that performs handling of byte arrays, for use by the Python API. * A [[spark.Partitioner]] that performs handling of byte arrays, for use by the Python API.
*
* Stores the unique id() of the Python-side partitioning function so that it is incorporated into
* equality comparisons. Correctness requires that the id is a unique identifier for the
* lifetime of the job (i.e. that it is not re-used as the id of a different partitioning
* function). This can be ensured by using the Python id() function and maintaining a reference
* to the Python partitioning function so that its id() is not reused.
*/ */
private[spark] class PythonPartitioner(override val numPartitions: Int) extends Partitioner { private[spark] class PythonPartitioner(
override val numPartitions: Int,
val pyPartitionFunctionId: Long)
extends Partitioner {
override def getPartition(key: Any): Int = { override def getPartition(key: Any): Int = {
if (key == null) { if (key == null) {
@ -32,7 +41,7 @@ private[spark] class PythonPartitioner(override val numPartitions: Int) extends
override def equals(other: Any): Boolean = other match { override def equals(other: Any): Boolean = other match {
case h: PythonPartitioner => case h: PythonPartitioner =>
h.numPartitions == numPartitions h.numPartitions == numPartitions && h.pyPartitionFunctionId == pyPartitionFunctionId
case _ => case _ =>
false false
} }

View file

@ -67,6 +67,8 @@ private[spark] class PythonRDD[T: ClassManifest](
val dOut = new DataOutputStream(proc.getOutputStream) val dOut = new DataOutputStream(proc.getOutputStream)
// Split index // Split index
dOut.writeInt(split.index) dOut.writeInt(split.index)
// sparkFilesDir
PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dOut)
// Broadcast variables // Broadcast variables
dOut.writeInt(broadcastVars.length) dOut.writeInt(broadcastVars.length)
for (broadcast <- broadcastVars) { for (broadcast <- broadcastVars) {
@ -101,21 +103,27 @@ private[spark] class PythonRDD[T: ClassManifest](
private def read(): Array[Byte] = { private def read(): Array[Byte] = {
try { try {
val length = stream.readInt() stream.readInt() match {
if (length != -1) { case length if length > 0 =>
val obj = new Array[Byte](length) val obj = new Array[Byte](length)
stream.readFully(obj) stream.readFully(obj)
obj obj
} else { case -2 =>
// We've finished the data section of the output, but we can still read some // Signals that an exception has been thrown in python
// accumulator updates; let's do that, breaking when we get EOFException val exLength = stream.readInt()
while (true) { val obj = new Array[Byte](exLength)
val len2 = stream.readInt() stream.readFully(obj)
val update = new Array[Byte](len2) throw new PythonException(new String(obj))
stream.readFully(update) case -1 =>
accumulator += Collections.singletonList(update) // We've finished the data section of the output, but we can still read some
} // accumulator updates; let's do that, breaking when we get EOFException
new Array[Byte](0) while (true) {
val len2 = stream.readInt()
val update = new Array[Byte](len2)
stream.readFully(update)
accumulator += Collections.singletonList(update)
}
new Array[Byte](0)
} }
} catch { } catch {
case eof: EOFException => { case eof: EOFException => {
@ -135,11 +143,12 @@ private[spark] class PythonRDD[T: ClassManifest](
} }
} }
override def checkpoint() { }
val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
} }
/** Thrown for exceptions in user Python code. */
private class PythonException(msg: String) extends Exception(msg)
/** /**
* Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python. * Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python.
* This is used by PySpark's shuffle operations. * This is used by PySpark's shuffle operations.
@ -152,7 +161,6 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
case Seq(a, b) => (a, b) case Seq(a, b) => (a, b)
case x => throw new Exception("PairwiseRDD: unexpected value: " + x) case x => throw new Exception("PairwiseRDD: unexpected value: " + x)
} }
override def checkpoint() { }
val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this) val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this)
} }
@ -230,6 +238,11 @@ private[spark] object PythonRDD {
} }
def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) { def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) {
import scala.collection.JavaConverters._
writeIteratorToPickleFile(items.asScala, filename)
}
def writeIteratorToPickleFile[T](items: Iterator[T], filename: String) {
val file = new DataOutputStream(new FileOutputStream(filename)) val file = new DataOutputStream(new FileOutputStream(filename))
for (item <- items) { for (item <- items) {
writeAsPickle(item, file) writeAsPickle(item, file)
@ -237,8 +250,10 @@ private[spark] object PythonRDD {
file.close() file.close()
} }
def takePartition[T](rdd: RDD[T], partition: Int): java.util.Iterator[T] = def takePartition[T](rdd: RDD[T], partition: Int): Iterator[T] = {
rdd.context.runJob(rdd, ((x: Iterator[T]) => x), Seq(partition), true).head implicit val cm : ClassManifest[T] = rdd.elementClassManifest
rdd.context.runJob(rdd, ((x: Iterator[T]) => x.toArray), Seq(partition), true).head.iterator
}
} }
private object Pickle { private object Pickle {
@ -252,11 +267,6 @@ private object Pickle {
val APPENDS: Byte = 'e' val APPENDS: Byte = 'e'
} }
private class ExtractValue extends spark.api.java.function.Function[(Array[Byte],
Array[Byte]), Array[Byte]] {
override def call(pair: (Array[Byte], Array[Byte])) : Array[Byte] = pair._2
}
private class BytesToString extends spark.api.java.function.Function[Array[Byte], String] { private class BytesToString extends spark.api.java.function.Function[Array[Byte], String] {
override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8") override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8")
} }

View file

@ -31,7 +31,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
@transient var totalBlocks = -1 @transient var totalBlocks = -1
@transient var hasBlocks = new AtomicInteger(0) @transient var hasBlocks = new AtomicInteger(0)
// Used ONLY by Master to track how many unique blocks have been sent out // Used ONLY by driver to track how many unique blocks have been sent out
@transient var sentBlocks = new AtomicInteger(0) @transient var sentBlocks = new AtomicInteger(0)
@transient var listenPortLock = new Object @transient var listenPortLock = new Object
@ -42,7 +42,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
@transient var serveMR: ServeMultipleRequests = null @transient var serveMR: ServeMultipleRequests = null
// Used only in Master // Used only in driver
@transient var guideMR: GuideMultipleRequests = null @transient var guideMR: GuideMultipleRequests = null
// Used only in Workers // Used only in Workers
@ -99,14 +99,14 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
} }
// Must always come AFTER listenPort is created // Must always come AFTER listenPort is created
val masterSource = val driverSource =
SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes) SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes)
hasBlocksBitVector.synchronized { hasBlocksBitVector.synchronized {
masterSource.hasBlocksBitVector = hasBlocksBitVector driverSource.hasBlocksBitVector = hasBlocksBitVector
} }
// In the beginning, this is the only known source to Guide // In the beginning, this is the only known source to Guide
listOfSources += masterSource listOfSources += driverSource
// Register with the Tracker // Register with the Tracker
MultiTracker.registerBroadcast(id, MultiTracker.registerBroadcast(id,
@ -122,7 +122,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
case None => case None =>
logInfo("Started reading broadcast variable " + id) logInfo("Started reading broadcast variable " + id)
// Initializing everything because Master will only send null/0 values // Initializing everything because driver will only send null/0 values
// Only the 1st worker in a node can be here. Others will get from cache // Only the 1st worker in a node can be here. Others will get from cache
initializeWorkerVariables() initializeWorkerVariables()
@ -151,7 +151,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
} }
} }
// Initialize variables in the worker node. Master sends everything as 0/null // Initialize variables in the worker node. Driver sends everything as 0/null
private def initializeWorkerVariables() { private def initializeWorkerVariables() {
arrayOfBlocks = null arrayOfBlocks = null
hasBlocksBitVector = null hasBlocksBitVector = null
@ -248,7 +248,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
// Receive source information from Guide // Receive source information from Guide
var suitableSources = var suitableSources =
oisGuide.readObject.asInstanceOf[ListBuffer[SourceInfo]] oisGuide.readObject.asInstanceOf[ListBuffer[SourceInfo]]
logDebug("Received suitableSources from Master " + suitableSources) logDebug("Received suitableSources from Driver " + suitableSources)
addToListOfSources(suitableSources) addToListOfSources(suitableSources)
@ -532,7 +532,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
oosSource.writeObject(blockToAskFor) oosSource.writeObject(blockToAskFor)
oosSource.flush() oosSource.flush()
// CHANGED: Master might send some other block than the one // CHANGED: Driver might send some other block than the one
// requested to ensure fast spreading of all blocks. // requested to ensure fast spreading of all blocks.
val recvStartTime = System.currentTimeMillis val recvStartTime = System.currentTimeMillis
val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock] val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
@ -982,9 +982,9 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
// Receive which block to send // Receive which block to send
var blockToSend = ois.readObject.asInstanceOf[Int] var blockToSend = ois.readObject.asInstanceOf[Int]
// If it is master AND at least one copy of each block has not been // If it is driver AND at least one copy of each block has not been
// sent out already, MODIFY blockToSend // sent out already, MODIFY blockToSend
if (MultiTracker.isMaster && sentBlocks.get < totalBlocks) { if (MultiTracker.isDriver && sentBlocks.get < totalBlocks) {
blockToSend = sentBlocks.getAndIncrement blockToSend = sentBlocks.getAndIncrement
} }
@ -1031,7 +1031,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
private[spark] class BitTorrentBroadcastFactory private[spark] class BitTorrentBroadcastFactory
extends BroadcastFactory { extends BroadcastFactory {
def initialize(isMaster: Boolean) { MultiTracker.initialize(isMaster) } def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) }
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new BitTorrentBroadcast[T](value_, isLocal, id) new BitTorrentBroadcast[T](value_, isLocal, id)

View file

@ -15,7 +15,7 @@ abstract class Broadcast[T](private[spark] val id: Long) extends Serializable {
} }
private[spark] private[spark]
class BroadcastManager(val isMaster_ : Boolean) extends Logging with Serializable { class BroadcastManager(val _isDriver: Boolean) extends Logging with Serializable {
private var initialized = false private var initialized = false
private var broadcastFactory: BroadcastFactory = null private var broadcastFactory: BroadcastFactory = null
@ -33,7 +33,7 @@ class BroadcastManager(val isMaster_ : Boolean) extends Logging with Serializabl
Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
// Initialize appropriate BroadcastFactory and BroadcastObject // Initialize appropriate BroadcastFactory and BroadcastObject
broadcastFactory.initialize(isMaster) broadcastFactory.initialize(isDriver)
initialized = true initialized = true
} }
@ -49,5 +49,5 @@ class BroadcastManager(val isMaster_ : Boolean) extends Logging with Serializabl
def newBroadcast[T](value_ : T, isLocal: Boolean) = def newBroadcast[T](value_ : T, isLocal: Boolean) =
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement()) broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
def isMaster = isMaster_ def isDriver = _isDriver
} }

View file

@ -7,7 +7,7 @@ package spark.broadcast
* entire Spark job. * entire Spark job.
*/ */
private[spark] trait BroadcastFactory { private[spark] trait BroadcastFactory {
def initialize(isMaster: Boolean): Unit def initialize(isDriver: Boolean): Unit
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T]
def stop(): Unit def stop(): Unit
} }

View file

@ -48,7 +48,7 @@ extends Broadcast[T](id) with Logging with Serializable {
} }
private[spark] class HttpBroadcastFactory extends BroadcastFactory { private[spark] class HttpBroadcastFactory extends BroadcastFactory {
def initialize(isMaster: Boolean) { HttpBroadcast.initialize(isMaster) } def initialize(isDriver: Boolean) { HttpBroadcast.initialize(isDriver) }
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new HttpBroadcast[T](value_, isLocal, id) new HttpBroadcast[T](value_, isLocal, id)
@ -69,12 +69,12 @@ private object HttpBroadcast extends Logging {
private val cleaner = new MetadataCleaner("HttpBroadcast", cleanup) private val cleaner = new MetadataCleaner("HttpBroadcast", cleanup)
def initialize(isMaster: Boolean) { def initialize(isDriver: Boolean) {
synchronized { synchronized {
if (!initialized) { if (!initialized) {
bufferSize = System.getProperty("spark.buffer.size", "65536").toInt bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
compress = System.getProperty("spark.broadcast.compress", "true").toBoolean compress = System.getProperty("spark.broadcast.compress", "true").toBoolean
if (isMaster) { if (isDriver) {
createServer() createServer()
} }
serverUri = System.getProperty("spark.httpBroadcast.uri") serverUri = System.getProperty("spark.httpBroadcast.uri")

View file

@ -23,25 +23,24 @@ extends Logging {
var ranGen = new Random var ranGen = new Random
private var initialized = false private var initialized = false
private var isMaster_ = false private var _isDriver = false
private var stopBroadcast = false private var stopBroadcast = false
private var trackMV: TrackMultipleValues = null private var trackMV: TrackMultipleValues = null
def initialize(isMaster__ : Boolean) { def initialize(__isDriver: Boolean) {
synchronized { synchronized {
if (!initialized) { if (!initialized) {
_isDriver = __isDriver
isMaster_ = isMaster__ if (isDriver) {
if (isMaster) {
trackMV = new TrackMultipleValues trackMV = new TrackMultipleValues
trackMV.setDaemon(true) trackMV.setDaemon(true)
trackMV.start() trackMV.start()
// Set masterHostAddress to the master's IP address for the slaves to read // Set DriverHostAddress to the driver's IP address for the slaves to read
System.setProperty("spark.MultiTracker.MasterHostAddress", Utils.localIpAddress) System.setProperty("spark.MultiTracker.DriverHostAddress", Utils.localIpAddress)
} }
initialized = true initialized = true
@ -54,10 +53,10 @@ extends Logging {
} }
// Load common parameters // Load common parameters
private var MasterHostAddress_ = System.getProperty( private var DriverHostAddress_ = System.getProperty(
"spark.MultiTracker.MasterHostAddress", "") "spark.MultiTracker.DriverHostAddress", "")
private var MasterTrackerPort_ = System.getProperty( private var DriverTrackerPort_ = System.getProperty(
"spark.broadcast.masterTrackerPort", "11111").toInt "spark.broadcast.driverTrackerPort", "11111").toInt
private var BlockSize_ = System.getProperty( private var BlockSize_ = System.getProperty(
"spark.broadcast.blockSize", "4096").toInt * 1024 "spark.broadcast.blockSize", "4096").toInt * 1024
private var MaxRetryCount_ = System.getProperty( private var MaxRetryCount_ = System.getProperty(
@ -91,11 +90,11 @@ extends Logging {
private var EndGameFraction_ = System.getProperty( private var EndGameFraction_ = System.getProperty(
"spark.broadcast.endGameFraction", "0.95").toDouble "spark.broadcast.endGameFraction", "0.95").toDouble
def isMaster = isMaster_ def isDriver = _isDriver
// Common config params // Common config params
def MasterHostAddress = MasterHostAddress_ def DriverHostAddress = DriverHostAddress_
def MasterTrackerPort = MasterTrackerPort_ def DriverTrackerPort = DriverTrackerPort_
def BlockSize = BlockSize_ def BlockSize = BlockSize_
def MaxRetryCount = MaxRetryCount_ def MaxRetryCount = MaxRetryCount_
@ -123,7 +122,7 @@ extends Logging {
var threadPool = Utils.newDaemonCachedThreadPool() var threadPool = Utils.newDaemonCachedThreadPool()
var serverSocket: ServerSocket = null var serverSocket: ServerSocket = null
serverSocket = new ServerSocket(MasterTrackerPort) serverSocket = new ServerSocket(DriverTrackerPort)
logInfo("TrackMultipleValues started at " + serverSocket) logInfo("TrackMultipleValues started at " + serverSocket)
try { try {
@ -235,7 +234,7 @@ extends Logging {
try { try {
// Connect to the tracker to find out GuideInfo // Connect to the tracker to find out GuideInfo
clientSocketToTracker = clientSocketToTracker =
new Socket(MultiTracker.MasterHostAddress, MultiTracker.MasterTrackerPort) new Socket(MultiTracker.DriverHostAddress, MultiTracker.DriverTrackerPort)
oosTracker = oosTracker =
new ObjectOutputStream(clientSocketToTracker.getOutputStream) new ObjectOutputStream(clientSocketToTracker.getOutputStream)
oosTracker.flush() oosTracker.flush()
@ -276,7 +275,7 @@ extends Logging {
} }
def registerBroadcast(id: Long, gInfo: SourceInfo) { def registerBroadcast(id: Long, gInfo: SourceInfo) {
val socket = new Socket(MultiTracker.MasterHostAddress, MasterTrackerPort) val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort)
val oosST = new ObjectOutputStream(socket.getOutputStream) val oosST = new ObjectOutputStream(socket.getOutputStream)
oosST.flush() oosST.flush()
val oisST = new ObjectInputStream(socket.getInputStream) val oisST = new ObjectInputStream(socket.getInputStream)
@ -303,7 +302,7 @@ extends Logging {
} }
def unregisterBroadcast(id: Long) { def unregisterBroadcast(id: Long) {
val socket = new Socket(MultiTracker.MasterHostAddress, MasterTrackerPort) val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort)
val oosST = new ObjectOutputStream(socket.getOutputStream) val oosST = new ObjectOutputStream(socket.getOutputStream)
oosST.flush() oosST.flush()
val oisST = new ObjectInputStream(socket.getInputStream) val oisST = new ObjectInputStream(socket.getInputStream)

View file

@ -98,7 +98,7 @@ extends Broadcast[T](id) with Logging with Serializable {
case None => case None =>
logInfo("Started reading broadcast variable " + id) logInfo("Started reading broadcast variable " + id)
// Initializing everything because Master will only send null/0 values // Initializing everything because Driver will only send null/0 values
// Only the 1st worker in a node can be here. Others will get from cache // Only the 1st worker in a node can be here. Others will get from cache
initializeWorkerVariables() initializeWorkerVariables()
@ -157,55 +157,55 @@ extends Broadcast[T](id) with Logging with Serializable {
listenPortLock.synchronized { listenPortLock.wait() } listenPortLock.synchronized { listenPortLock.wait() }
} }
var clientSocketToMaster: Socket = null var clientSocketToDriver: Socket = null
var oosMaster: ObjectOutputStream = null var oosDriver: ObjectOutputStream = null
var oisMaster: ObjectInputStream = null var oisDriver: ObjectInputStream = null
// Connect and receive broadcast from the specified source, retrying the // Connect and receive broadcast from the specified source, retrying the
// specified number of times in case of failures // specified number of times in case of failures
var retriesLeft = MultiTracker.MaxRetryCount var retriesLeft = MultiTracker.MaxRetryCount
do { do {
// Connect to Master and send this worker's Information // Connect to Driver and send this worker's Information
clientSocketToMaster = new Socket(MultiTracker.MasterHostAddress, gInfo.listenPort) clientSocketToDriver = new Socket(MultiTracker.DriverHostAddress, gInfo.listenPort)
oosMaster = new ObjectOutputStream(clientSocketToMaster.getOutputStream) oosDriver = new ObjectOutputStream(clientSocketToDriver.getOutputStream)
oosMaster.flush() oosDriver.flush()
oisMaster = new ObjectInputStream(clientSocketToMaster.getInputStream) oisDriver = new ObjectInputStream(clientSocketToDriver.getInputStream)
logDebug("Connected to Master's guiding object") logDebug("Connected to Driver's guiding object")
// Send local source information // Send local source information
oosMaster.writeObject(SourceInfo(hostAddress, listenPort)) oosDriver.writeObject(SourceInfo(hostAddress, listenPort))
oosMaster.flush() oosDriver.flush()
// Receive source information from Master // Receive source information from Driver
var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo] var sourceInfo = oisDriver.readObject.asInstanceOf[SourceInfo]
totalBlocks = sourceInfo.totalBlocks totalBlocks = sourceInfo.totalBlocks
arrayOfBlocks = new Array[BroadcastBlock](totalBlocks) arrayOfBlocks = new Array[BroadcastBlock](totalBlocks)
totalBlocksLock.synchronized { totalBlocksLock.notifyAll() } totalBlocksLock.synchronized { totalBlocksLock.notifyAll() }
totalBytes = sourceInfo.totalBytes totalBytes = sourceInfo.totalBytes
logDebug("Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort) logDebug("Received SourceInfo from Driver:" + sourceInfo + " My Port: " + listenPort)
val start = System.nanoTime val start = System.nanoTime
val receptionSucceeded = receiveSingleTransmission(sourceInfo) val receptionSucceeded = receiveSingleTransmission(sourceInfo)
val time = (System.nanoTime - start) / 1e9 val time = (System.nanoTime - start) / 1e9
// Updating some statistics in sourceInfo. Master will be using them later // Updating some statistics in sourceInfo. Driver will be using them later
if (!receptionSucceeded) { if (!receptionSucceeded) {
sourceInfo.receptionFailed = true sourceInfo.receptionFailed = true
} }
// Send back statistics to the Master // Send back statistics to the Driver
oosMaster.writeObject(sourceInfo) oosDriver.writeObject(sourceInfo)
if (oisMaster != null) { if (oisDriver != null) {
oisMaster.close() oisDriver.close()
} }
if (oosMaster != null) { if (oosDriver != null) {
oosMaster.close() oosDriver.close()
} }
if (clientSocketToMaster != null) { if (clientSocketToDriver != null) {
clientSocketToMaster.close() clientSocketToDriver.close()
} }
retriesLeft -= 1 retriesLeft -= 1
@ -552,7 +552,7 @@ extends Broadcast[T](id) with Logging with Serializable {
} }
private def sendObject() { private def sendObject() {
// Wait till receiving the SourceInfo from Master // Wait till receiving the SourceInfo from Driver
while (totalBlocks == -1) { while (totalBlocks == -1) {
totalBlocksLock.synchronized { totalBlocksLock.wait() } totalBlocksLock.synchronized { totalBlocksLock.wait() }
} }
@ -576,7 +576,7 @@ extends Broadcast[T](id) with Logging with Serializable {
private[spark] class TreeBroadcastFactory private[spark] class TreeBroadcastFactory
extends BroadcastFactory { extends BroadcastFactory {
def initialize(isMaster: Boolean) { MultiTracker.initialize(isMaster) } def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) }
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new TreeBroadcast[T](value_, isLocal, id) new TreeBroadcast[T](value_, isLocal, id)

View file

@ -4,7 +4,6 @@ import spark.deploy.ExecutorState.ExecutorState
import spark.deploy.master.{WorkerInfo, JobInfo} import spark.deploy.master.{WorkerInfo, JobInfo}
import spark.deploy.worker.ExecutorRunner import spark.deploy.worker.ExecutorRunner
import scala.collection.immutable.List import scala.collection.immutable.List
import scala.collection.mutable.HashMap
private[spark] sealed trait DeployMessage extends Serializable private[spark] sealed trait DeployMessage extends Serializable
@ -42,7 +41,8 @@ private[spark] case class LaunchExecutor(
execId: Int, execId: Int,
jobDesc: JobDescription, jobDesc: JobDescription,
cores: Int, cores: Int,
memory: Int) memory: Int,
sparkHome: String)
extends DeployMessage extends DeployMessage

View file

@ -4,7 +4,8 @@ private[spark] class JobDescription(
val name: String, val name: String,
val cores: Int, val cores: Int,
val memoryPerSlave: Int, val memoryPerSlave: Int,
val command: Command) val command: Command,
val sparkHome: String)
extends Serializable { extends Serializable {
val user = System.getProperty("user.name", "<unknown>") val user = System.getProperty("user.name", "<unknown>")

View file

@ -9,43 +9,32 @@ import spark.{Logging, Utils}
import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.ArrayBuffer
/**
* Testing class that creates a Spark standalone process in-cluster (that is, running the
* spark.deploy.master.Master and spark.deploy.worker.Workers in the same JVMs). Executors launched
* by the Workers still run in separate JVMs. This can be used to test distributed operation and
* fault recovery without spinning up a lot of processes.
*/
private[spark] private[spark]
class LocalSparkCluster(numSlaves: Int, coresPerSlave: Int, memoryPerSlave: Int) extends Logging { class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) extends Logging {
val localIpAddress = Utils.localIpAddress private val localIpAddress = Utils.localIpAddress
private val masterActorSystems = ArrayBuffer[ActorSystem]()
private val workerActorSystems = ArrayBuffer[ActorSystem]()
var masterActor : ActorRef = _ def start(): String = {
var masterActorSystem : ActorSystem = _ logInfo("Starting a local Spark cluster with " + numWorkers + " workers.")
var masterPort : Int = _
var masterUrl : String = _
val slaveActorSystems = ArrayBuffer[ActorSystem]()
val slaveActors = ArrayBuffer[ActorRef]()
def start() : String = {
logInfo("Starting a local Spark cluster with " + numSlaves + " slaves.")
/* Start the Master */ /* Start the Master */
val (actorSystem, masterPort) = AkkaUtils.createActorSystem("sparkMaster", localIpAddress, 0) val (masterSystem, masterPort) = Master.startSystemAndActor(localIpAddress, 0, 0)
masterActorSystem = actorSystem masterActorSystems += masterSystem
masterUrl = "spark://" + localIpAddress + ":" + masterPort val masterUrl = "spark://" + localIpAddress + ":" + masterPort
val actor = masterActorSystem.actorOf(
Props(new Master(localIpAddress, masterPort, 0)), name = "Master")
masterActor = actor
/* Start the Slaves */ /* Start the Workers */
for (slaveNum <- 1 to numSlaves) { for (workerNum <- 1 to numWorkers) {
/* We can pretend to test distributed stuff by giving the slaves distinct hostnames. val (workerSystem, _) = Worker.startSystemAndActor(localIpAddress, 0, 0, coresPerWorker,
All of 127/8 should be a loopback, we use 127.100.*.* in hopes that it is memoryPerWorker, masterUrl, null, Some(workerNum))
sufficiently distinctive. */ workerActorSystems += workerSystem
val slaveIpAddress = "127.100.0." + (slaveNum % 256)
val (actorSystem, boundPort) =
AkkaUtils.createActorSystem("sparkWorker" + slaveNum, slaveIpAddress, 0)
slaveActorSystems += actorSystem
val actor = actorSystem.actorOf(
Props(new Worker(slaveIpAddress, boundPort, 0, coresPerSlave, memoryPerSlave, masterUrl)),
name = "Worker")
slaveActors += actor
} }
return masterUrl return masterUrl
@@ -53,10 +42,10 @@ class LocalSparkCluster(numSlaves: Int, coresPerSlave: Int, memoryPerSlave: Int)
def stop() { def stop() {
logInfo("Shutting down local Spark cluster.") logInfo("Shutting down local Spark cluster.")
// Stop the slaves before the master so they don't get upset that it disconnected // Stop the workers before the master so they don't get upset that it disconnected
slaveActorSystems.foreach(_.shutdown()) workerActorSystems.foreach(_.shutdown())
slaveActorSystems.foreach(_.awaitTermination()) workerActorSystems.foreach(_.awaitTermination())
masterActorSystem.shutdown() masterActorSystems.foreach(_.shutdown())
masterActorSystem.awaitTermination() masterActorSystems.foreach(_.awaitTermination())
} }
} }
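A minimal sketch of driving the refactored cluster directly, assuming spark-package test code and illustrative worker counts (LocalSparkCluster is private[spark]):

    package spark.deploy

    object LocalClusterSketch {
      def main(args: Array[String]) {
        // Two workers with one core and 512 MB each; start() boots a master plus
        // numbered sparkWorkerN actor systems in this JVM and returns the spark:// URL.
        val cluster = new LocalSparkCluster(numWorkers = 2, coresPerWorker = 1, memoryPerWorker = 512)
        val masterUrl = cluster.start()
        println("Local cluster master at " + masterUrl)
        cluster.stop()  // workers first, then the master, as in stop() above
      }
    }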

View file

@@ -9,6 +9,7 @@ import spark.{SparkException, Logging}
import akka.remote.RemoteClientLifeCycleEvent import akka.remote.RemoteClientLifeCycleEvent
import akka.remote.RemoteClientShutdown import akka.remote.RemoteClientShutdown
import spark.deploy.RegisterJob import spark.deploy.RegisterJob
import spark.deploy.master.Master
import akka.remote.RemoteClientDisconnected import akka.remote.RemoteClientDisconnected
import akka.actor.Terminated import akka.actor.Terminated
import akka.dispatch.Await import akka.dispatch.Await
@@ -24,26 +25,18 @@ private[spark] class Client(
listener: ClientListener) listener: ClientListener)
extends Logging { extends Logging {
val MASTER_REGEX = "spark://([^:]+):([0-9]+)".r
var actor: ActorRef = null var actor: ActorRef = null
var jobId: String = null var jobId: String = null
if (MASTER_REGEX.unapplySeq(masterUrl) == None) {
throw new SparkException("Invalid master URL: " + masterUrl)
}
class ClientActor extends Actor with Logging { class ClientActor extends Actor with Logging {
var master: ActorRef = null var master: ActorRef = null
var masterAddress: Address = null var masterAddress: Address = null
var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times
override def preStart() { override def preStart() {
val Seq(masterHost, masterPort) = MASTER_REGEX.unapplySeq(masterUrl).get logInfo("Connecting to master " + masterUrl)
logInfo("Connecting to master spark://" + masterHost + ":" + masterPort)
val akkaUrl = "akka://spark@%s:%s/user/Master".format(masterHost, masterPort)
try { try {
master = context.actorFor(akkaUrl) master = context.actorFor(Master.toAkkaUrl(masterUrl))
masterAddress = master.path.address masterAddress = master.path.address
master ! RegisterJob(jobDescription) master ! RegisterJob(jobDescription)
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])

View file

@@ -12,7 +12,7 @@ private[spark] trait ClientListener {
def disconnected(): Unit def disconnected(): Unit
def executorAdded(id: String, workerId: String, host: String, cores: Int, memory: Int): Unit def executorAdded(fullId: String, workerId: String, host: String, cores: Int, memory: Int): Unit
def executorRemoved(id: String, message: String, exitStatus: Option[Int]): Unit def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]): Unit
} }

View file

@@ -25,7 +25,7 @@ private[spark] object TestClient {
val url = args(0) val url = args(0)
val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0) val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0)
val desc = new JobDescription( val desc = new JobDescription(
"TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map())) "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home")
val listener = new TestListener val listener = new TestListener
val client = new Client(actorSystem, url, desc, listener) val client = new Client(actorSystem, url, desc, listener)
client.start() client.start()

View file

@@ -10,7 +10,7 @@ private[spark] class JobInfo(
val id: String, val id: String,
val desc: JobDescription, val desc: JobDescription,
val submitDate: Date, val submitDate: Date,
val actor: ActorRef) val driver: ActorRef)
{ {
var state = JobState.WAITING var state = JobState.WAITING
var executors = new mutable.HashMap[Int, ExecutorInfo] var executors = new mutable.HashMap[Int, ExecutorInfo]

View file

@@ -88,7 +88,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
execOption match { execOption match {
case Some(exec) => { case Some(exec) => {
exec.state = state exec.state = state
exec.job.actor ! ExecutorUpdated(execId, state, message, exitStatus) exec.job.driver ! ExecutorUpdated(execId, state, message, exitStatus)
if (ExecutorState.isFinished(state)) { if (ExecutorState.isFinished(state)) {
val jobInfo = idToJob(jobId) val jobInfo = idToJob(jobId)
// Remove this executor from the worker and job // Remove this executor from the worker and job
@@ -97,14 +97,12 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
exec.worker.removeExecutor(exec) exec.worker.removeExecutor(exec)
// Only retry certain number of times so we don't go into an infinite loop. // Only retry certain number of times so we don't go into an infinite loop.
if (jobInfo.incrementRetryCount <= JobState.MAX_NUM_RETRY) { if (jobInfo.incrementRetryCount < JobState.MAX_NUM_RETRY) {
schedule() schedule()
} else { } else {
val e = new SparkException("Job %s wth ID %s failed %d times.".format( logError("Job %s with ID %s failed %d times, removing it".format(
jobInfo.desc.name, jobInfo.id, jobInfo.retryCount)) jobInfo.desc.name, jobInfo.id, jobInfo.retryCount))
logError(e.getMessage, e) removeJob(jobInfo)
throw e
//System.exit(1)
} }
} }
} }
@@ -173,7 +171,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
for (pos <- 0 until numUsable) { for (pos <- 0 until numUsable) {
if (assigned(pos) > 0) { if (assigned(pos) > 0) {
val exec = job.addExecutor(usableWorkers(pos), assigned(pos)) val exec = job.addExecutor(usableWorkers(pos), assigned(pos))
launchExecutor(usableWorkers(pos), exec) launchExecutor(usableWorkers(pos), exec, job.desc.sparkHome)
job.state = JobState.RUNNING job.state = JobState.RUNNING
} }
} }
@@ -186,7 +184,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
val coresToUse = math.min(worker.coresFree, job.coresLeft) val coresToUse = math.min(worker.coresFree, job.coresLeft)
if (coresToUse > 0) { if (coresToUse > 0) {
val exec = job.addExecutor(worker, coresToUse) val exec = job.addExecutor(worker, coresToUse)
launchExecutor(worker, exec) launchExecutor(worker, exec, job.desc.sparkHome)
job.state = JobState.RUNNING job.state = JobState.RUNNING
} }
} }
@@ -195,11 +193,11 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
} }
} }
def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo) { def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo, sparkHome: String) {
logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) logInfo("Launching executor " + exec.fullId + " on worker " + worker.id)
worker.addExecutor(exec) worker.addExecutor(exec)
worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory) worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory, sparkHome)
exec.job.actor ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory) exec.job.driver ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory)
} }
def addWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int, def addWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int,
@@ -221,19 +219,19 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
actorToWorker -= worker.actor actorToWorker -= worker.actor
addressToWorker -= worker.actor.path.address addressToWorker -= worker.actor.path.address
for (exec <- worker.executors.values) { for (exec <- worker.executors.values) {
exec.job.actor ! ExecutorStateChanged(exec.job.id, exec.id, ExecutorState.LOST, None, None) exec.job.driver ! ExecutorStateChanged(exec.job.id, exec.id, ExecutorState.LOST, None, None)
exec.job.executors -= exec.id exec.job.executors -= exec.id
} }
} }
def addJob(desc: JobDescription, actor: ActorRef): JobInfo = { def addJob(desc: JobDescription, driver: ActorRef): JobInfo = {
val now = System.currentTimeMillis() val now = System.currentTimeMillis()
val date = new Date(now) val date = new Date(now)
val job = new JobInfo(now, newJobId(date), desc, date, actor) val job = new JobInfo(now, newJobId(date), desc, date, driver)
jobs += job jobs += job
idToJob(job.id) = job idToJob(job.id) = job
actorToJob(sender) = job actorToJob(driver) = job
addressToJob(sender.path.address) = job addressToJob(driver.path.address) = job
return job return job
} }
@@ -242,8 +240,8 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
logInfo("Removing job " + job.id) logInfo("Removing job " + job.id)
jobs -= job jobs -= job
idToJob -= job.id idToJob -= job.id
actorToJob -= job.actor actorToJob -= job.driver
addressToWorker -= job.actor.path.address addressToWorker -= job.driver.path.address
completedJobs += job // Remember it in our history completedJobs += job // Remember it in our history
waitingJobs -= job waitingJobs -= job
for (exec <- job.executors.values) { for (exec <- job.executors.values) {
@@ -264,11 +262,29 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
} }
private[spark] object Master { private[spark] object Master {
private val systemName = "sparkMaster"
private val actorName = "Master"
private val sparkUrlRegex = "spark://([^:]+):([0-9]+)".r
def main(argStrings: Array[String]) { def main(argStrings: Array[String]) {
val args = new MasterArguments(argStrings) val args = new MasterArguments(argStrings)
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", args.ip, args.port) val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort)
val actor = actorSystem.actorOf(
Props(new Master(args.ip, boundPort, args.webUiPort)), name = "Master")
actorSystem.awaitTermination() actorSystem.awaitTermination()
} }
/** Returns an `akka://...` URL for the Master actor given a sparkUrl `spark://host:ip`. */
def toAkkaUrl(sparkUrl: String): String = {
sparkUrl match {
case sparkUrlRegex(host, port) =>
"akka://%s@%s:%s/user/%s".format(systemName, host, port, actorName)
case _ =>
throw new SparkException("Invalid master URL: " + sparkUrl)
}
}
def startSystemAndActor(host: String, port: Int, webUiPort: Int): (ActorSystem, Int) = {
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port)
val actor = actorSystem.actorOf(Props(new Master(host, boundPort, webUiPort)), name = actorName)
(actorSystem, boundPort)
}
} }
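A hedged illustration of the translation the new helper performs, with an invented host and port (Master is private[spark], so callers such as Client and Worker live inside the spark package):

    // "spark://<host>:<port>" becomes the Akka path of the Master actor:
    val akkaUrl = Master.toAkkaUrl("spark://host1:7077")
    // akkaUrl == "akka://sparkMaster@host1:7077/user/Master"
    // Any other URL shape makes toAkkaUrl throw a SparkException, replacing the
    // regex checks previously duplicated in Client and Worker.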

View file

@@ -14,12 +14,15 @@ import cc.spray.typeconversion.SprayJsonSupport._
import spark.deploy._ import spark.deploy._
import spark.deploy.JsonProtocol._ import spark.deploy.JsonProtocol._
/**
* Web UI server for the standalone master.
*/
private[spark] private[spark]
class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Directives { class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Directives {
val RESOURCE_DIR = "spark/deploy/master/webui" val RESOURCE_DIR = "spark/deploy/master/webui"
val STATIC_RESOURCE_DIR = "spark/deploy/static" val STATIC_RESOURCE_DIR = "spark/deploy/static"
implicit val timeout = Timeout(1 seconds) implicit val timeout = Timeout(10 seconds)
val handler = { val handler = {
get { get {
@@ -42,13 +45,9 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct
case (jobId, Some(js)) if (js.equalsIgnoreCase("json")) => case (jobId, Some(js)) if (js.equalsIgnoreCase("json")) =>
val future = master ? RequestMasterState val future = master ? RequestMasterState
val jobInfo = for (masterState <- future.mapTo[MasterState]) yield { val jobInfo = for (masterState <- future.mapTo[MasterState]) yield {
masterState.activeJobs.find(_.id == jobId) match { masterState.activeJobs.find(_.id == jobId).getOrElse({
case Some(job) => job masterState.completedJobs.find(_.id == jobId).getOrElse(null)
case _ => masterState.completedJobs.find(_.id == jobId) match { })
case Some(job) => job
case _ => null
}
}
} }
respondWithMediaType(MediaTypes.`application/json`) { ctx => respondWithMediaType(MediaTypes.`application/json`) { ctx =>
ctx.complete(jobInfo.mapTo[JobInfo]) ctx.complete(jobInfo.mapTo[JobInfo])
@@ -58,14 +57,10 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct
val future = master ? RequestMasterState val future = master ? RequestMasterState
future.map { state => future.map { state =>
val masterState = state.asInstanceOf[MasterState] val masterState = state.asInstanceOf[MasterState]
val job = masterState.activeJobs.find(_.id == jobId).getOrElse({
masterState.activeJobs.find(_.id == jobId) match { masterState.completedJobs.find(_.id == jobId).getOrElse(null)
case Some(job) => spark.deploy.master.html.job_details.render(job) })
case _ => masterState.completedJobs.find(_.id == jobId) match { spark.deploy.master.html.job_details.render(job)
case Some(job) => spark.deploy.master.html.job_details.render(job)
case _ => null
}
}
} }
} }
} }
@@ -76,5 +71,4 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct
getFromResourceDirectory(RESOURCE_DIR) getFromResourceDirectory(RESOURCE_DIR)
} }
} }
} }

View file

@@ -65,9 +65,9 @@ private[spark] class ExecutorRunner(
} }
} }
/** Replace variables such as {{SLAVEID}} and {{CORES}} in a command argument passed to us */ /** Replace variables such as {{EXECUTOR_ID}} and {{CORES}} in a command argument passed to us */
def substituteVariables(argument: String): String = argument match { def substituteVariables(argument: String): String = argument match {
case "{{SLAVEID}}" => workerId case "{{EXECUTOR_ID}}" => execId.toString
case "{{HOSTNAME}}" => hostname case "{{HOSTNAME}}" => hostname
case "{{CORES}}" => cores.toString case "{{CORES}}" => cores.toString
case other => other case other => other
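A standalone mirror of the renamed substitution, with invented executor values, showing what a command template expands to:

    object SubstituteSketch {
      def main(args: Array[String]) {
        val execId = 7                    // invented values for illustration
        val hostname = "worker-host"
        val cores = 4
        def substitute(argument: String): String = argument match {
          case "{{EXECUTOR_ID}}" => execId.toString
          case "{{HOSTNAME}}" => hostname
          case "{{CORES}}" => cores.toString
          case other => other
        }
        // Prints List(7, worker-host, 4, --someFlag)
        println(Seq("{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}", "--someFlag").map(substitute))
      }
    }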
@@ -106,11 +106,6 @@ private[spark] class ExecutorRunner(
throw new IOException("Failed to create directory " + executorDir) throw new IOException("Failed to create directory " + executorDir)
} }
// Download the files it depends on into it (disabled for now)
//for (url <- jobDesc.fileUrls) {
// fetchFile(url, executorDir)
//}
// Launch the process // Launch the process
val command = buildCommandSeq() val command = buildCommandSeq()
val builder = new ProcessBuilder(command: _*).directory(executorDir) val builder = new ProcessBuilder(command: _*).directory(executorDir)
@@ -118,8 +113,7 @@ private[spark] class ExecutorRunner(
for ((key, value) <- jobDesc.command.environment) { for ((key, value) <- jobDesc.command.environment) {
env.put(key, value) env.put(key, value)
} }
env.put("SPARK_CORES", cores.toString) env.put("SPARK_MEM", memory.toString + "m")
env.put("SPARK_MEMORY", memory.toString)
// In case we are running this from within the Spark Shell, avoid creating a "scala" // In case we are running this from within the Spark Shell, avoid creating a "scala"
// parent process for the executor command // parent process for the executor command
env.put("SPARK_LAUNCH_WITH_SCALA", "0") env.put("SPARK_LAUNCH_WITH_SCALA", "0")

View file

@@ -1,19 +1,17 @@
package spark.deploy.worker package spark.deploy.worker
import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.collection.mutable.{ArrayBuffer, HashMap}
import akka.actor.{ActorRef, Props, Actor} import akka.actor.{ActorRef, Props, Actor, ActorSystem, Terminated}
import spark.{Logging, Utils} import spark.{Logging, Utils}
import spark.util.AkkaUtils import spark.util.AkkaUtils
import spark.deploy._ import spark.deploy._
import akka.remote.RemoteClientLifeCycleEvent import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClientDisconnected}
import java.text.SimpleDateFormat import java.text.SimpleDateFormat
import java.util.Date import java.util.Date
import akka.remote.RemoteClientShutdown
import akka.remote.RemoteClientDisconnected
import spark.deploy.RegisterWorker import spark.deploy.RegisterWorker
import spark.deploy.LaunchExecutor import spark.deploy.LaunchExecutor
import spark.deploy.RegisterWorkerFailed import spark.deploy.RegisterWorkerFailed
import akka.actor.Terminated import spark.deploy.master.Master
import java.io.File import java.io.File
private[spark] class Worker( private[spark] class Worker(
@@ -27,7 +25,6 @@ private[spark] class Worker(
extends Actor with Logging { extends Actor with Logging {
val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs
val MASTER_REGEX = "spark://([^:]+):([0-9]+)".r
var master: ActorRef = null var master: ActorRef = null
var masterWebUiUrl : String = "" var masterWebUiUrl : String = ""
@@ -48,11 +45,7 @@ private[spark] class Worker(
def memoryFree: Int = memory - memoryUsed def memoryFree: Int = memory - memoryUsed
def createWorkDir() { def createWorkDir() {
workDir = if (workDirPath != null) { workDir = Option(workDirPath).map(new File(_)).getOrElse(new File(sparkHome, "work"))
new File(workDirPath)
} else {
new File(sparkHome, "work")
}
try { try {
if (!workDir.exists() && !workDir.mkdirs()) { if (!workDir.exists() && !workDir.mkdirs()) {
logError("Failed to create work directory " + workDir) logError("Failed to create work directory " + workDir)
@@ -68,8 +61,7 @@ private[spark] class Worker(
override def preStart() { override def preStart() {
logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format( logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format(
ip, port, cores, Utils.memoryMegabytesToString(memory))) ip, port, cores, Utils.memoryMegabytesToString(memory)))
val envVar = System.getenv("SPARK_HOME") sparkHome = new File(Option(System.getenv("SPARK_HOME")).getOrElse("."))
sparkHome = new File(if (envVar == null) "." else envVar)
logInfo("Spark home: " + sparkHome) logInfo("Spark home: " + sparkHome)
createWorkDir() createWorkDir()
connectToMaster() connectToMaster()
@@ -77,24 +69,15 @@ private[spark] class Worker(
} }
def connectToMaster() { def connectToMaster() {
masterUrl match { logInfo("Connecting to master " + masterUrl)
case MASTER_REGEX(masterHost, masterPort) => { try {
logInfo("Connecting to master spark://" + masterHost + ":" + masterPort) master = context.actorFor(Master.toAkkaUrl(masterUrl))
val akkaUrl = "akka://spark@%s:%s/user/Master".format(masterHost, masterPort) master ! RegisterWorker(workerId, ip, port, cores, memory, webUiPort, publicAddress)
try { context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
master = context.actorFor(akkaUrl) context.watch(master) // Doesn't work with remote actors, but useful for testing
master ! RegisterWorker(workerId, ip, port, cores, memory, webUiPort, publicAddress) } catch {
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) case e: Exception =>
context.watch(master) // Doesn't work with remote actors, but useful for testing logError("Failed to connect to master", e)
} catch {
case e: Exception =>
logError("Failed to connect to master", e)
System.exit(1)
}
}
case _ =>
logError("Invalid master URL: " + masterUrl)
System.exit(1) System.exit(1)
} }
} }
@@ -119,10 +102,10 @@ private[spark] class Worker(
logError("Worker registration failed: " + message) logError("Worker registration failed: " + message)
System.exit(1) System.exit(1)
case LaunchExecutor(jobId, execId, jobDesc, cores_, memory_) => case LaunchExecutor(jobId, execId, jobDesc, cores_, memory_, execSparkHome_) =>
logInfo("Asked to launch executor %s/%d for %s".format(jobId, execId, jobDesc.name)) logInfo("Asked to launch executor %s/%d for %s".format(jobId, execId, jobDesc.name))
val manager = new ExecutorRunner( val manager = new ExecutorRunner(
jobId, execId, jobDesc, cores_, memory_, self, workerId, ip, sparkHome, workDir) jobId, execId, jobDesc, cores_, memory_, self, workerId, ip, new File(execSparkHome_), workDir)
executors(jobId + "/" + execId) = manager executors(jobId + "/" + execId) = manager
manager.start() manager.start()
coresUsed += cores_ coresUsed += cores_
@@ -134,7 +117,9 @@ private[spark] class Worker(
val fullId = jobId + "/" + execId val fullId = jobId + "/" + execId
if (ExecutorState.isFinished(state)) { if (ExecutorState.isFinished(state)) {
val executor = executors(fullId) val executor = executors(fullId)
logInfo("Executor " + fullId + " finished with state " + state) logInfo("Executor " + fullId + " finished with state " + state +
message.map(" message " + _).getOrElse("") +
exitStatus.map(" exitStatus " + _).getOrElse(""))
finishedExecutors(fullId) = executor finishedExecutors(fullId) = executor
executors -= fullId executors -= fullId
coresUsed -= executor.cores coresUsed -= executor.cores
@@ -143,9 +128,13 @@ private[spark] class Worker(
case KillExecutor(jobId, execId) => case KillExecutor(jobId, execId) =>
val fullId = jobId + "/" + execId val fullId = jobId + "/" + execId
val executor = executors(fullId) executors.get(fullId) match {
logInfo("Asked to kill executor " + fullId) case Some(executor) =>
executor.kill() logInfo("Asked to kill executor " + fullId)
executor.kill()
case None =>
logInfo("Asked to kill unknown executor " + fullId)
}
case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) => case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) =>
masterDisconnected() masterDisconnected()
@@ -177,11 +166,19 @@ private[spark] class Worker(
private[spark] object Worker { private[spark] object Worker {
def main(argStrings: Array[String]) { def main(argStrings: Array[String]) {
val args = new WorkerArguments(argStrings) val args = new WorkerArguments(argStrings)
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", args.ip, args.port) val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort, args.cores,
val actor = actorSystem.actorOf( args.memory, args.master, args.workDir)
Props(new Worker(args.ip, boundPort, args.webUiPort, args.cores, args.memory,
args.master, args.workDir)),
name = "Worker")
actorSystem.awaitTermination() actorSystem.awaitTermination()
} }
def startSystemAndActor(host: String, port: Int, webUiPort: Int, cores: Int, memory: Int,
masterUrl: String, workDir: String, workerNumber: Option[Int] = None): (ActorSystem, Int) = {
// The LocalSparkCluster runs multiple local sparkWorkerX actor systems
val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("")
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port)
val actor = actorSystem.actorOf(Props(new Worker(host, boundPort, webUiPort, cores, memory,
masterUrl, workDir)), name = "Worker")
(actorSystem, boundPort)
}
} }

View file

@@ -13,12 +13,15 @@ import cc.spray.typeconversion.SprayJsonSupport._
import spark.deploy.{WorkerState, RequestWorkerState} import spark.deploy.{WorkerState, RequestWorkerState}
import spark.deploy.JsonProtocol._ import spark.deploy.JsonProtocol._
/**
* Web UI server for the standalone worker.
*/
private[spark] private[spark]
class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Directives { class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Directives {
val RESOURCE_DIR = "spark/deploy/worker/webui" val RESOURCE_DIR = "spark/deploy/worker/webui"
val STATIC_RESOURCE_DIR = "spark/deploy/static" val STATIC_RESOURCE_DIR = "spark/deploy/static"
implicit val timeout = Timeout(1 seconds) implicit val timeout = Timeout(10 seconds)
val handler = { val handler = {
get { get {
@@ -50,5 +53,4 @@ class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Direct
getFromResourceDirectory(RESOURCE_DIR) getFromResourceDirectory(RESOURCE_DIR)
} }
} }
} }

View file

@@ -30,7 +30,7 @@ private[spark] class Executor extends Logging {
initLogging() initLogging()
def initialize(slaveHostname: String, properties: Seq[(String, String)]) { def initialize(executorId: String, slaveHostname: String, properties: Seq[(String, String)]) {
// Make sure the local hostname we report matches the cluster scheduler's name for this host // Make sure the local hostname we report matches the cluster scheduler's name for this host
Utils.setCustomHostname(slaveHostname) Utils.setCustomHostname(slaveHostname)
@@ -64,7 +64,7 @@ private[spark] class Executor extends Logging {
) )
// Initialize Spark environment (using system properties read above) // Initialize Spark environment (using system properties read above)
env = SparkEnv.createFromSystemProperties(slaveHostname, 0, false, false) env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0, false, false)
SparkEnv.set(env) SparkEnv.set(env)
// Start worker thread pool // Start worker thread pool
@@ -159,22 +159,24 @@ private[spark] class Executor extends Logging {
* SparkContext. Also adds any new JARs we fetched to the class loader. * SparkContext. Also adds any new JARs we fetched to the class loader.
*/ */
private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) { private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {
// Fetch missing dependencies synchronized {
for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { // Fetch missing dependencies
logInfo("Fetching " + name + " with timestamp " + timestamp) for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
Utils.fetchFile(name, new File(".")) logInfo("Fetching " + name + " with timestamp " + timestamp)
currentFiles(name) = timestamp Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
} currentFiles(name) = timestamp
for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { }
logInfo("Fetching " + name + " with timestamp " + timestamp) for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
Utils.fetchFile(name, new File(".")) logInfo("Fetching " + name + " with timestamp " + timestamp)
currentJars(name) = timestamp Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
// Add it to our class loader currentJars(name) = timestamp
val localName = name.split("/").last // Add it to our class loader
val url = new File(".", localName).toURI.toURL val localName = name.split("/").last
if (!urlClassLoader.getURLs.contains(url)) { val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL
logInfo("Adding " + url + " to class loader") if (!urlClassLoader.getURLs.contains(url)) {
urlClassLoader.addURL(url) logInfo("Adding " + url + " to class loader")
urlClassLoader.addURL(url)
}
} }
} }
} }

View file

@@ -29,9 +29,14 @@ private[spark] class MesosExecutorBackend(executor: Executor)
executorInfo: ExecutorInfo, executorInfo: ExecutorInfo,
frameworkInfo: FrameworkInfo, frameworkInfo: FrameworkInfo,
slaveInfo: SlaveInfo) { slaveInfo: SlaveInfo) {
logInfo("Registered with Mesos as executor ID " + executorInfo.getExecutorId.getValue)
this.driver = driver this.driver = driver
val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray) val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray)
executor.initialize(slaveInfo.getHostname, properties) executor.initialize(
executorInfo.getExecutorId.getValue,
slaveInfo.getHostname,
properties
)
} }
override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo) { override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo) {

View file

@@ -4,75 +4,72 @@ import java.nio.ByteBuffer
import spark.Logging import spark.Logging
import spark.TaskState.TaskState import spark.TaskState.TaskState
import spark.util.AkkaUtils import spark.util.AkkaUtils
import akka.actor.{ActorRef, Actor, Props} import akka.actor.{ActorRef, Actor, Props, Terminated}
import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClientDisconnected}
import java.util.concurrent.{TimeUnit, ThreadPoolExecutor, SynchronousQueue} import java.util.concurrent.{TimeUnit, ThreadPoolExecutor, SynchronousQueue}
import akka.remote.RemoteClientLifeCycleEvent
import spark.scheduler.cluster._ import spark.scheduler.cluster._
import spark.scheduler.cluster.RegisteredSlave import spark.scheduler.cluster.RegisteredExecutor
import spark.scheduler.cluster.LaunchTask import spark.scheduler.cluster.LaunchTask
import spark.scheduler.cluster.RegisterSlaveFailed import spark.scheduler.cluster.RegisterExecutorFailed
import spark.scheduler.cluster.RegisterSlave import spark.scheduler.cluster.RegisterExecutor
private[spark] class StandaloneExecutorBackend( private[spark] class StandaloneExecutorBackend(
executor: Executor, executor: Executor,
masterUrl: String, driverUrl: String,
slaveId: String, executorId: String,
hostname: String, hostname: String,
cores: Int) cores: Int)
extends Actor extends Actor
with ExecutorBackend with ExecutorBackend
with Logging { with Logging {
var master: ActorRef = null var driver: ActorRef = null
override def preStart() { override def preStart() {
try { logInfo("Connecting to driver: " + driverUrl)
logInfo("Connecting to master: " + masterUrl) driver = context.actorFor(driverUrl)
master = context.actorFor(masterUrl) driver ! RegisterExecutor(executorId, hostname, cores)
master ! RegisterSlave(slaveId, hostname, cores) context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) context.watch(driver) // Doesn't work with remote actors, but useful for testing
context.watch(master) // Doesn't work with remote actors, but useful for testing
} catch {
case e: Exception =>
logError("Failed to connect to master", e)
System.exit(1)
}
} }
override def receive = { override def receive = {
case RegisteredSlave(sparkProperties) => case RegisteredExecutor(sparkProperties) =>
logInfo("Successfully registered with master") logInfo("Successfully registered with driver")
executor.initialize(hostname, sparkProperties) executor.initialize(executorId, hostname, sparkProperties)
case RegisterSlaveFailed(message) => case RegisterExecutorFailed(message) =>
logError("Slave registration failed: " + message) logError("Slave registration failed: " + message)
System.exit(1) System.exit(1)
case LaunchTask(taskDesc) => case LaunchTask(taskDesc) =>
logInfo("Got assigned task " + taskDesc.taskId) logInfo("Got assigned task " + taskDesc.taskId)
executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask) executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask)
case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) =>
logError("Driver terminated or disconnected! Shutting down.")
System.exit(1)
} }
override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) { override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
master ! StatusUpdate(slaveId, taskId, state, data) driver ! StatusUpdate(executorId, taskId, state, data)
} }
} }
private[spark] object StandaloneExecutorBackend { private[spark] object StandaloneExecutorBackend {
def run(masterUrl: String, slaveId: String, hostname: String, cores: Int) { def run(driverUrl: String, executorId: String, hostname: String, cores: Int) {
// Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor // Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor
// before getting started with all our system properties, etc // before getting started with all our system properties, etc
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0) val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0)
val actor = actorSystem.actorOf( val actor = actorSystem.actorOf(
Props(new StandaloneExecutorBackend(new Executor, masterUrl, slaveId, hostname, cores)), Props(new StandaloneExecutorBackend(new Executor, driverUrl, executorId, hostname, cores)),
name = "Executor") name = "Executor")
actorSystem.awaitTermination() actorSystem.awaitTermination()
} }
def main(args: Array[String]) { def main(args: Array[String]) {
if (args.length != 4) { if (args.length != 4) {
System.err.println("Usage: StandaloneExecutorBackend <master> <slaveId> <hostname> <cores>") System.err.println("Usage: StandaloneExecutorBackend <driverUrl> <executorId> <hostname> <cores>")
System.exit(1) System.exit(1)
} }
run(args(0), args(1), args(2), args(3).toInt) run(args(0), args(1), args(2), args(3).toInt)
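For orientation, a sketch of the renamed parameters with invented values; the driver URL format here is an assumption, since the real one is built by the scheduler backend, and the object is private[spark]:

    StandaloneExecutorBackend.run(
      driverUrl = "akka://spark@driver-host:51234/user/StandaloneScheduler",  // assumed actor path
      executorId = "3",
      hostname = "worker-host",
      cores = 4)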

View file

@@ -12,7 +12,14 @@ import java.net._
private[spark] private[spark]
abstract class Connection(val channel: SocketChannel, val selector: Selector) extends Logging { abstract class Connection(val channel: SocketChannel, val selector: Selector,
val remoteConnectionManagerId: ConnectionManagerId) extends Logging {
def this(channel_ : SocketChannel, selector_ : Selector) = {
this(channel_, selector_,
ConnectionManagerId.fromSocketAddress(
channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]
))
}
channel.configureBlocking(false) channel.configureBlocking(false)
channel.socket.setTcpNoDelay(true) channel.socket.setTcpNoDelay(true)
@@ -25,7 +32,6 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector) ex
var onKeyInterestChangeCallback: (Connection, Int) => Unit = null var onKeyInterestChangeCallback: (Connection, Int) => Unit = null
val remoteAddress = getRemoteAddress() val remoteAddress = getRemoteAddress()
val remoteConnectionManagerId = ConnectionManagerId.fromSocketAddress(remoteAddress)
def key() = channel.keyFor(selector) def key() = channel.keyFor(selector)
@@ -103,8 +109,9 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector) ex
} }
private[spark] class SendingConnection(val address: InetSocketAddress, selector_ : Selector) private[spark] class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
extends Connection(SocketChannel.open, selector_) { remoteId_ : ConnectionManagerId)
extends Connection(SocketChannel.open, selector_, remoteId_) {
class Outbox(fair: Int = 0) { class Outbox(fair: Int = 0) {
val messages = new Queue[Message]() val messages = new Queue[Message]()

View file

@@ -52,9 +52,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
val sendMessageRequests = new Queue[(Message, SendingConnection)] val sendMessageRequests = new Queue[(Message, SendingConnection)]
implicit val futureExecContext = ExecutionContext.fromExecutor( implicit val futureExecContext = ExecutionContext.fromExecutor(Utils.newDaemonCachedThreadPool())
Executors.newCachedThreadPool(DaemonThreadFactory))
var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null
serverChannel.configureBlocking(false) serverChannel.configureBlocking(false)
@@ -300,7 +299,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) { private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) {
def startNewConnection(): SendingConnection = { def startNewConnection(): SendingConnection = {
val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port) val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port)
val newConnection = connectionRequests.getOrElseUpdate(connectionManagerId, new SendingConnection(inetSocketAddress, selector)) val newConnection = connectionRequests.getOrElseUpdate(connectionManagerId,
new SendingConnection(inetSocketAddress, selector, connectionManagerId))
newConnection newConnection
} }
val lookupKey = ConnectionManagerId.fromSocketAddress(connectionManagerId.toSocketAddress) val lookupKey = ConnectionManagerId.fromSocketAddress(connectionManagerId.toSocketAddress)

View file

@@ -32,7 +32,7 @@ private[spark] class ApproximateActionListener[T, U, R](
if (finishedTasks == totalTasks) { if (finishedTasks == totalTasks) {
// If we had already returned a PartialResult, set its final value // If we had already returned a PartialResult, set its final value
resultObject.foreach(r => r.setFinalValue(evaluator.currentResult())) resultObject.foreach(r => r.setFinalValue(evaluator.currentResult()))
// Notify any waiting thread that may have called getResult // Notify any waiting thread that may have called awaitResult
this.notifyAll() this.notifyAll()
} }
} }
@@ -49,7 +49,7 @@ private[spark] class ApproximateActionListener[T, U, R](
* Waits for up to timeout milliseconds since the listener was created and then returns a * Waits for up to timeout milliseconds since the listener was created and then returns a
* PartialResult with the result so far. This may be complete if the whole job is done. * PartialResult with the result so far. This may be complete if the whole job is done.
*/ */
def getResult(): PartialResult[R] = synchronized { def awaitResult(): PartialResult[R] = synchronized {
val finishTime = startTime + timeout val finishTime = startTime + timeout
while (true) { while (true) {
val time = System.currentTimeMillis() val time = System.currentTimeMillis()

View file

@@ -11,13 +11,11 @@ private[spark]
class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String]) class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String])
extends RDD[T](sc, Nil) { extends RDD[T](sc, Nil) {
@transient @transient var splits_ : Array[Split] = (0 until blockIds.size).map(i => {
var splits_ : Array[Split] = (0 until blockIds.size).map(i => {
new BlockRDDSplit(blockIds(i), i).asInstanceOf[Split] new BlockRDDSplit(blockIds(i), i).asInstanceOf[Split]
}).toArray }).toArray
@transient @transient lazy val locations_ = {
lazy val locations_ = {
val blockManager = SparkEnv.get.blockManager val blockManager = SparkEnv.get.blockManager
/*val locations = blockIds.map(id => blockManager.getLocations(id))*/ /*val locations = blockIds.map(id => blockManager.getLocations(id))*/
val locations = blockManager.getLocations(blockIds) val locations = blockManager.getLocations(blockIds)

View file

@@ -1,7 +1,7 @@
package spark.rdd package spark.rdd
import java.io.{ObjectOutputStream, IOException} import java.io.{ObjectOutputStream, IOException}
import spark.{OneToOneDependency, NarrowDependency, RDD, SparkContext, Split, TaskContext} import spark._
private[spark] private[spark]
@@ -35,8 +35,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
val numSplitsInRdd2 = rdd2.splits.size val numSplitsInRdd2 = rdd2.splits.size
@transient override def getSplits: Array[Split] = {
var splits_ = {
// create the cross product split // create the cross product split
val array = new Array[Split](rdd1.splits.size * rdd2.splits.size) val array = new Array[Split](rdd1.splits.size * rdd2.splits.size)
for (s1 <- rdd1.splits; s2 <- rdd2.splits) { for (s1 <- rdd1.splits; s2 <- rdd2.splits) {
@@ -46,8 +45,6 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
array array
} }
override def getSplits = splits_
override def getPreferredLocations(split: Split) = { override def getPreferredLocations(split: Split) = {
val currSplit = split.asInstanceOf[CartesianSplit] val currSplit = split.asInstanceOf[CartesianSplit]
rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2) rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)
@@ -59,7 +56,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
y <- rdd2.iterator(currSplit.s2, context)) yield (x, y) y <- rdd2.iterator(currSplit.s2, context)) yield (x, y)
} }
var deps_ = List( override def getDependencies: Seq[Dependency[_]] = List(
new NarrowDependency(rdd1) { new NarrowDependency(rdd1) {
def getParents(id: Int): Seq[Int] = List(id / numSplitsInRdd2) def getParents(id: Int): Seq[Int] = List(id / numSplitsInRdd2)
}, },
@@ -68,11 +65,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
} }
) )
override def getDependencies = deps_
override def clearDependencies() { override def clearDependencies() {
deps_ = Nil
splits_ = null
rdd1 = null rdd1 = null
rdd2 = null rdd2 = null
} }

View file

@@ -9,23 +9,26 @@ import org.apache.hadoop.fs.Path
import java.io.{File, IOException, EOFException} import java.io.{File, IOException, EOFException}
import java.text.NumberFormat import java.text.NumberFormat
private[spark] class CheckpointRDDSplit(idx: Int, val splitFile: String) extends Split { private[spark] class CheckpointRDDSplit(val index: Int) extends Split {}
override val index: Int = idx
}
/** /**
* This RDD represents a RDD checkpoint file (similar to HadoopRDD). * This RDD represents a RDD checkpoint file (similar to HadoopRDD).
*/ */
private[spark] private[spark]
class CheckpointRDD[T: ClassManifest](sc: SparkContext, checkpointPath: String) class CheckpointRDD[T: ClassManifest](sc: SparkContext, val checkpointPath: String)
extends RDD[T](sc, Nil) { extends RDD[T](sc, Nil) {
@transient val path = new Path(checkpointPath) @transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration)
@transient val fs = path.getFileSystem(new Configuration())
@transient val splits_ : Array[Split] = { @transient val splits_ : Array[Split] = {
val splitFiles = fs.listStatus(path).map(_.getPath.toString).filter(_.contains("part-")).sorted val dirContents = fs.listStatus(new Path(checkpointPath))
splitFiles.zipWithIndex.map(x => new CheckpointRDDSplit(x._2, x._1)).toArray val splitFiles = dirContents.map(_.getPath.toString).filter(_.contains("part-")).sorted
val numSplits = splitFiles.size
if (!splitFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) ||
!splitFiles(numSplits-1).endsWith(CheckpointRDD.splitIdToFile(numSplits-1))) {
throw new SparkException("Invalid checkpoint directory: " + checkpointPath)
}
Array.tabulate(numSplits)(i => new CheckpointRDDSplit(i))
} }
checkpointData = Some(new RDDCheckpointData[T](this)) checkpointData = Some(new RDDCheckpointData[T](this))
@@ -34,36 +37,34 @@ class CheckpointRDD[T: ClassManifest](sc: SparkContext, checkpointPath: String)
override def getSplits = splits_ override def getSplits = splits_
override def getPreferredLocations(split: Split): Seq[String] = { override def getPreferredLocations(split: Split): Seq[String] = {
val status = fs.getFileStatus(path) val status = fs.getFileStatus(new Path(checkpointPath))
val locations = fs.getFileBlockLocations(status, 0, status.getLen) val locations = fs.getFileBlockLocations(status, 0, status.getLen)
locations.firstOption.toList.flatMap(_.getHosts).filter(_ != "localhost") locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost")
} }
override def compute(split: Split, context: TaskContext): Iterator[T] = { override def compute(split: Split, context: TaskContext): Iterator[T] = {
CheckpointRDD.readFromFile(split.asInstanceOf[CheckpointRDDSplit].splitFile, context) val file = new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index))
CheckpointRDD.readFromFile(file, context)
} }
override def checkpoint() { override def checkpoint() {
// Do nothing. Hadoop RDD should not be checkpointed. // Do nothing. CheckpointRDD should not be checkpointed.
} }
} }
private[spark] object CheckpointRDD extends Logging { private[spark] object CheckpointRDD extends Logging {
def splitIdToFileName(splitId: Int): String = { def splitIdToFile(splitId: Int): String = {
val numfmt = NumberFormat.getInstance() "part-%05d".format(splitId)
numfmt.setMinimumIntegerDigits(5)
numfmt.setGroupingUsed(false)
"part-" + numfmt.format(splitId)
} }
def writeToFile[T](path: String, blockSize: Int = -1)(context: TaskContext, iterator: Iterator[T]) { def writeToFile[T](path: String, blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) {
val outputDir = new Path(path) val outputDir = new Path(path)
val fs = outputDir.getFileSystem(new Configuration()) val fs = outputDir.getFileSystem(new Configuration())
val finalOutputName = splitIdToFileName(context.splitId) val finalOutputName = splitIdToFile(ctx.splitId)
val finalOutputPath = new Path(outputDir, finalOutputName) val finalOutputPath = new Path(outputDir, finalOutputName)
val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + context.attemptId) val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptId)
if (fs.exists(tempOutputPath)) { if (fs.exists(tempOutputPath)) {
throw new IOException("Checkpoint failed: temporary path " + throw new IOException("Checkpoint failed: temporary path " +
@@ -83,22 +84,22 @@ private[spark] object CheckpointRDD extends Logging {
serializeStream.close() serializeStream.close()
if (!fs.rename(tempOutputPath, finalOutputPath)) { if (!fs.rename(tempOutputPath, finalOutputPath)) {
if (!fs.delete(finalOutputPath, true)) { if (!fs.exists(finalOutputPath)) {
throw new IOException("Checkpoint failed: failed to delete earlier output of task " fs.delete(tempOutputPath, false)
+ context.attemptId)
}
if (!fs.rename(tempOutputPath, finalOutputPath)) {
throw new IOException("Checkpoint failed: failed to save output of task: " throw new IOException("Checkpoint failed: failed to save output of task: "
+ context.attemptId) + ctx.attemptId + " and final output path does not exist")
} else {
// Some other copy of this task must've finished before us and renamed it
logInfo("Final output path " + finalOutputPath + " already exists; not overwriting it")
fs.delete(tempOutputPath, false)
} }
} }
} }
def readFromFile[T](path: String, context: TaskContext): Iterator[T] = { def readFromFile[T](path: Path, context: TaskContext): Iterator[T] = {
val inputPath = new Path(path) val fs = path.getFileSystem(new Configuration())
val fs = inputPath.getFileSystem(new Configuration())
val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
val fileInputStream = fs.open(inputPath, bufferSize) val fileInputStream = fs.open(path, bufferSize)
val serializer = SparkEnv.get.serializer.newInstance() val serializer = SparkEnv.get.serializer.newInstance()
val deserializeStream = serializer.deserializeStream(fileInputStream) val deserializeStream = serializer.deserializeStream(fileInputStream)

View file

@@ -45,8 +45,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner)
val aggr = new CoGroupAggregator val aggr = new CoGroupAggregator
@transient @transient var deps_ = {
var deps_ = {
val deps = new ArrayBuffer[Dependency[_]] val deps = new ArrayBuffer[Dependency[_]]
for ((rdd, index) <- rdds.zipWithIndex) { for ((rdd, index) <- rdds.zipWithIndex) {
if (rdd.partitioner == Some(part)) { if (rdd.partitioner == Some(part)) {
@@ -63,8 +62,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner)
override def getDependencies = deps_ override def getDependencies = deps_
@transient @transient var splits_ : Array[Split] = {
var splits_ : Array[Split] = {
val array = new Array[Split](part.numPartitions) val array = new Array[Split](part.numPartitions)
for (i <- 0 until array.size) { for (i <- 0 until array.size) {
array(i) = new CoGroupSplit(i, rdds.zipWithIndex.map { case (r, j) => array(i) = new CoGroupSplit(i, rdds.zipWithIndex.map { case (r, j) =>
@@ -86,6 +84,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner)
override def compute(s: Split, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = { override def compute(s: Split, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = {
val split = s.asInstanceOf[CoGroupSplit] val split = s.asInstanceOf[CoGroupSplit]
val numRdds = split.deps.size val numRdds = split.deps.size
// e.g. for `(k, a) cogroup (k, b)`, K -> Seq(ArrayBuffer as, ArrayBuffer bs)
val map = new JHashMap[K, Seq[ArrayBuffer[Any]]] val map = new JHashMap[K, Seq[ArrayBuffer[Any]]]
def getSeq(k: K): Seq[ArrayBuffer[Any]] = { def getSeq(k: K): Seq[ArrayBuffer[Any]] = {
val seq = map.get(k) val seq = map.get(k)
@@ -106,13 +105,10 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner)
} }
case ShuffleCoGroupSplitDep(shuffleId) => { case ShuffleCoGroupSplitDep(shuffleId) => {
// Read map outputs of shuffle // Read map outputs of shuffle
def mergePair(pair: (K, Seq[Any])) {
val mySeq = getSeq(pair._1)
for (v <- pair._2)
mySeq(depNum) += v
}
val fetcher = SparkEnv.get.shuffleFetcher val fetcher = SparkEnv.get.shuffleFetcher
fetcher.fetch[K, Seq[Any]](shuffleId, split.index).foreach(mergePair) for ((k, vs) <- fetcher.fetch[K, Seq[Any]](shuffleId, split.index)) {
getSeq(k)(depNum) ++= vs
}
} }
} }
JavaConversions.mapAsScalaMap(map).iterator JavaConversions.mapAsScalaMap(map).iterator
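A toy illustration of the per-key layout compute() builds, using invented data; cogroup here goes through the PairRDDFunctions implicits, hence the SparkContext._ import:

    import spark.SparkContext
    import spark.SparkContext._

    object CoGroupSketch {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "cogroup-sketch")
        val a = sc.parallelize(Seq(("k1", 1), ("k1", 2), ("k2", 3)))
        val b = sc.parallelize(Seq(("k1", "x"), ("k3", "y")))
        // One Seq per input RDD for every key, mirroring getSeq(k)(depNum) ++= vs:
        //   (k1, (Seq(1, 2), Seq(x))), (k2, (Seq(3), Seq())), (k3, (Seq(), Seq(y)))
        a.cogroup(b).collect().foreach(println)
        sc.stop()
      }
    }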

View file

@@ -27,11 +27,11 @@ private[spark] case class CoalescedRDDSplit(
* or to avoid having a large number of small tasks when processing a directory with many files. * or to avoid having a large number of small tasks when processing a directory with many files.
*/ */
class CoalescedRDD[T: ClassManifest]( class CoalescedRDD[T: ClassManifest](
var prev: RDD[T], @transient var prev: RDD[T],
maxPartitions: Int) maxPartitions: Int)
extends RDD[T](prev.context, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs extends RDD[T](prev.context, Nil) { // Nil since we implement getDependencies
@transient var splits_ : Array[Split] = { override def getSplits: Array[Split] = {
val prevSplits = prev.splits val prevSplits = prev.splits
if (prevSplits.length < maxPartitions) { if (prevSplits.length < maxPartitions) {
prevSplits.map(_.index).map{idx => new CoalescedRDDSplit(idx, prev, Array(idx)) } prevSplits.map(_.index).map{idx => new CoalescedRDDSplit(idx, prev, Array(idx)) }
@@ -44,26 +44,20 @@ class CoalescedRDD[T: ClassManifest](
} }
} }
override def getSplits = splits_
override def compute(split: Split, context: TaskContext): Iterator[T] = { override def compute(split: Split, context: TaskContext): Iterator[T] = {
split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap { parentSplit => split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap { parentSplit =>
firstParent[T].iterator(parentSplit, context) firstParent[T].iterator(parentSplit, context)
} }
} }
var deps_ : List[Dependency[_]] = List( override def getDependencies: Seq[Dependency[_]] = List(
new NarrowDependency(prev) { new NarrowDependency(prev) {
def getParents(id: Int): Seq[Int] = def getParents(id: Int): Seq[Int] =
splits(id).asInstanceOf[CoalescedRDDSplit].parentsIndices splits(id).asInstanceOf[CoalescedRDDSplit].parentsIndices
} }
) )
override def getDependencies() = deps_
override def clearDependencies() { override def clearDependencies() {
deps_ = Nil
splits_ = null
prev = null prev = null
} }
} }

View file

@@ -3,13 +3,11 @@ package spark.rdd
import spark.{RDD, Split, TaskContext} import spark.{RDD, Split, TaskContext}
private[spark] private[spark]
class MappedRDD[U: ClassManifest, T: ClassManifest]( class MappedRDD[U: ClassManifest, T: ClassManifest](prev: RDD[T], f: T => U)
prev: RDD[T],
f: T => U)
extends RDD[U](prev) { extends RDD[U](prev) {
override def getSplits = firstParent[T].splits override def getSplits = firstParent[T].splits
override def compute(split: Split, context: TaskContext) = override def compute(split: Split, context: TaskContext) =
firstParent[T].iterator(split, context).map(f) firstParent[T].iterator(split, context).map(f)
} }

View file

@@ -37,11 +37,9 @@ class NewHadoopRDD[K, V](
formatter.format(new Date()) formatter.format(new Date())
} }
@transient @transient private val jobId = new JobID(jobtrackerId, id)
private val jobId = new JobID(jobtrackerId, id)
@transient @transient private val splits_ : Array[Split] = {
private val splits_ : Array[Split] = {
val inputFormat = inputFormatClass.newInstance val inputFormat = inputFormatClass.newInstance
val jobContext = newJobContext(conf, jobId) val jobContext = newJobContext(conf, jobId)
val rawSplits = inputFormat.getSplits(jobContext).toArray val rawSplits = inputFormat.getSplits(jobContext).toArray

View file

@@ -0,0 +1,42 @@
package spark.rdd
import spark.{NarrowDependency, RDD, SparkEnv, Split, TaskContext}
class PartitionPruningRDDSplit(idx: Int, val parentSplit: Split) extends Split {
override val index = idx
}
/**
* Represents a dependency between the PartitionPruningRDD and its parent. In this
* case, the child RDD contains a subset of partitions of the parents'.
*/
class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boolean)
extends NarrowDependency[T](rdd) {
@transient
val partitions: Array[Split] = rdd.splits.filter(s => partitionFilterFunc(s.index))
.zipWithIndex.map { case(split, idx) => new PartitionPruningRDDSplit(idx, split) : Split }
override def getParents(partitionId: Int) = List(partitions(partitionId).index)
}
/**
* A RDD used to prune RDD partitions/splits so we can avoid launching tasks on
* all partitions. An example use case: If we know the RDD is partitioned by range,
* and the execution DAG has a filter on the key, we can avoid launching tasks
* on partitions that don't have the range covering the key.
*/
class PartitionPruningRDD[T: ClassManifest](
@transient prev: RDD[T],
@transient partitionFilterFunc: Int => Boolean)
extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) {
override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(
split.asInstanceOf[PartitionPruningRDDSplit].parentSplit, context)
override protected def getSplits =
getDependencies.head.asInstanceOf[PruneDependency[T]].partitions
}
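A minimal, hedged usage sketch of the new RDD; the data and the claim about which splits could hold which keys are illustrative only:

    import spark.SparkContext
    import spark.rdd.PartitionPruningRDD

    object PruneSketch {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "prune-sketch")
        val data = sc.parallelize(1 to 100, 10)   // 10 splits of a toy dataset
        // Suppose only splits 5-9 can contain the keys we need: prune the rest so
        // no tasks are launched on them.
        val pruned = new PartitionPruningRDD(data, splitIndex => splitIndex >= 5)
        println(pruned.splits.size)   // 5
        println(pruned.count())       // counts elements from the kept splits only
        sc.stop()
      }
    }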

View file

@@ -19,13 +19,12 @@ class SampledRDD[T: ClassManifest](
seed: Int) seed: Int)
extends RDD[T](prev) { extends RDD[T](prev) {
@transient @transient var splits_ : Array[Split] = {
var splits_ : Array[Split] = {
val rg = new Random(seed) val rg = new Random(seed)
firstParent[T].splits.map(x => new SampledRDDSplit(x, rg.nextInt)) firstParent[T].splits.map(x => new SampledRDDSplit(x, rg.nextInt))
} }
override def getSplits = splits_.asInstanceOf[Array[Split]] override def getSplits = splits_
override def getPreferredLocations(split: Split) = override def getPreferredLocations(split: Split) =
firstParent[T].preferredLocations(split.asInstanceOf[SampledRDDSplit].prev) firstParent[T].preferredLocations(split.asInstanceOf[SampledRDDSplit].prev)


@ -22,17 +22,10 @@ class ShuffledRDD[K, V](
override val partitioner = Some(part) override val partitioner = Some(part)
@transient override def getSplits = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i))
var splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i))
override def getSplits = splits_
override def compute(split: Split, context: TaskContext): Iterator[(K, V)] = { override def compute(split: Split, context: TaskContext): Iterator[(K, V)] = {
val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index) SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index)
} }
override def clearDependencies() {
splits_ = null
}
} }


@ -26,10 +26,9 @@ private[spark] class UnionSplit[T: ClassManifest](idx: Int, rdd: RDD[T], splitIn
class UnionRDD[T: ClassManifest]( class UnionRDD[T: ClassManifest](
sc: SparkContext, sc: SparkContext,
@transient var rdds: Seq[RDD[T]]) @transient var rdds: Seq[RDD[T]])
extends RDD[T](sc, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs extends RDD[T](sc, Nil) { // Nil since we implement getDependencies
@transient override def getSplits: Array[Split] = {
var splits_ : Array[Split] = {
val array = new Array[Split](rdds.map(_.splits.size).sum) val array = new Array[Split](rdds.map(_.splits.size).sum)
var pos = 0 var pos = 0
for (rdd <- rdds; split <- rdd.splits) { for (rdd <- rdds; split <- rdd.splits) {
@ -39,20 +38,16 @@ class UnionRDD[T: ClassManifest](
array array
} }
override def getSplits = splits_ override def getDependencies: Seq[Dependency[_]] = {
@transient var deps_ = {
val deps = new ArrayBuffer[Dependency[_]] val deps = new ArrayBuffer[Dependency[_]]
var pos = 0 var pos = 0
for (rdd <- rdds) { for (rdd <- rdds) {
deps += new RangeDependency(rdd, 0, pos, rdd.splits.size) deps += new RangeDependency(rdd, 0, pos, rdd.splits.size)
pos += rdd.splits.size pos += rdd.splits.size
} }
deps.toList deps
} }
override def getDependencies = deps_
override def compute(s: Split, context: TaskContext): Iterator[T] = override def compute(s: Split, context: TaskContext): Iterator[T] =
s.asInstanceOf[UnionSplit[T]].iterator(context) s.asInstanceOf[UnionSplit[T]].iterator(context)
@ -60,8 +55,6 @@ class UnionRDD[T: ClassManifest](
s.asInstanceOf[UnionSplit[T]].preferredLocations() s.asInstanceOf[UnionSplit[T]].preferredLocations()
override def clearDependencies() { override def clearDependencies() {
deps_ = null
splits_ = null
rdds = null rdds = null
} }
} }
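To see how the RangeDependency offsets line up, here is a small sketch with invented split counts; it assumes getParents(p) maps a child split back as p - outStart + inStart, which is not shown in this hunk:

// Sketch only: split counts invented. For a union of rddA (3 splits) and rddB (2 splits),
// getDependencies yields RangeDependency(rddA, 0, 0, 3) and RangeDependency(rddB, 0, 3, 2).
// Assuming getParents(p) answers List(p - outStart + inStart) when p falls inside the range:
def parentSplitOf(p: Int, inStart: Int, outStart: Int, length: Int): Option[Int] =
  if (p >= outStart && p < outStart + length) Some(p - outStart + inStart) else None

parentSplitOf(1, inStart = 0, outStart = 0, length = 3)   // Some(1): child split 1 -> rddA's split 1
parentSplitOf(4, inStart = 0, outStart = 3, length = 2)   // Some(1): child split 4 -> rddB's split 1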


@ -32,10 +32,7 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest](
extends RDD[(T, U)](sc, List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2))) extends RDD[(T, U)](sc, List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2)))
with Serializable { with Serializable {
// TODO: FIX THIS. override def getSplits: Array[Split] = {
@transient
var splits_ : Array[Split] = {
if (rdd1.splits.size != rdd2.splits.size) { if (rdd1.splits.size != rdd2.splits.size) {
throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions") throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions")
} }
@ -46,8 +43,6 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest](
array array
} }
override def getSplits = splits_
override def compute(s: Split, context: TaskContext): Iterator[(T, U)] = { override def compute(s: Split, context: TaskContext): Iterator[(T, U)] = {
val (split1, split2) = s.asInstanceOf[ZippedSplit[T, U]].splits val (split1, split2) = s.asInstanceOf[ZippedSplit[T, U]].splits
rdd1.iterator(split1, context).zip(rdd2.iterator(split2, context)) rdd1.iterator(split1, context).zip(rdd2.iterator(split2, context))
@ -59,7 +54,6 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest](
} }
override def clearDependencies() { override def clearDependencies() {
splits_ = null
rdd1 = null rdd1 = null
rdd2 = null rdd2 = null
} }


@ -23,7 +23,16 @@ import util.{MetadataCleaner, TimeStampedHashMap}
* and to report fetch failures (the submitTasks method, and code to add CompletionEvents). * and to report fetch failures (the submitTasks method, and code to add CompletionEvents).
*/ */
private[spark] private[spark]
class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with Logging { class DAGScheduler(
taskSched: TaskScheduler,
mapOutputTracker: MapOutputTracker,
blockManagerMaster: BlockManagerMaster,
env: SparkEnv)
extends TaskSchedulerListener with Logging {
def this(taskSched: TaskScheduler) {
this(taskSched, SparkEnv.get.mapOutputTracker, SparkEnv.get.blockManager.master, SparkEnv.get)
}
taskSched.setListener(this) taskSched.setListener(this)
// Called by TaskScheduler to report task completions or failures. // Called by TaskScheduler to report task completions or failures.
@ -35,12 +44,12 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
eventQueue.put(CompletionEvent(task, reason, result, accumUpdates)) eventQueue.put(CompletionEvent(task, reason, result, accumUpdates))
} }
// Called by TaskScheduler when a host fails. // Called by TaskScheduler when an executor fails.
override def hostLost(host: String) { override def executorLost(execId: String) {
eventQueue.put(HostLost(host)) eventQueue.put(ExecutorLost(execId))
} }
// Called by TaskScheduler to cancel an entier TaskSet due to repeated failures. // Called by TaskScheduler to cancel an entire TaskSet due to repeated failures.
override def taskSetFailed(taskSet: TaskSet, reason: String) { override def taskSetFailed(taskSet: TaskSet, reason: String) {
eventQueue.put(TaskSetFailed(taskSet, reason)) eventQueue.put(TaskSetFailed(taskSet, reason))
} }
@ -54,8 +63,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
// resubmit failed stages // resubmit failed stages
val POLL_TIMEOUT = 10L val POLL_TIMEOUT = 10L
private val lock = new Object // Used for access to the entire DAGScheduler
private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent] private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent]
val nextRunId = new AtomicInteger(0) val nextRunId = new AtomicInteger(0)
@ -68,12 +75,13 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
var cacheLocs = new HashMap[Int, Array[List[String]]] var cacheLocs = new HashMap[Int, Array[List[String]]]
val env = SparkEnv.get // For tracking failed nodes, we use the MapOutputTracker's generation number, which is
val cacheTracker = env.cacheTracker // sent with every task. When we detect a node failing, we note the current generation number
val mapOutputTracker = env.mapOutputTracker // and failed executor, increment it for new tasks, and use this to ignore stray ShuffleMapTask
// results.
val deadHosts = new HashSet[String] // TODO: The code currently assumes these can't come back; // TODO: Garbage collect information about failure generations when we know there are no more
// that's not going to be a realistic assumption in general // stray messages to detect.
val failedGeneration = new HashMap[String, Long]
val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done
val running = new HashSet[Stage] // Stages we are running right now val running = new HashSet[Stage] // Stages we are running right now
@ -87,19 +95,27 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup) val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup)
// Start a thread to run the DAGScheduler event loop // Start a thread to run the DAGScheduler event loop
new Thread("DAGScheduler") { def start() {
setDaemon(true) new Thread("DAGScheduler") {
override def run() { setDaemon(true)
DAGScheduler.this.run() override def run() {
} DAGScheduler.this.run()
}.start() }
}.start()
}
def getCacheLocs(rdd: RDD[_]): Array[List[String]] = { private def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
if (!cacheLocs.contains(rdd.id)) {
val blockIds = rdd.splits.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray
cacheLocs(rdd.id) = blockManagerMaster.getLocations(blockIds).map {
locations => locations.map(_.ip).toList
}.toArray
}
cacheLocs(rdd.id) cacheLocs(rdd.id)
} }
def updateCacheLocs() { private def clearCacheLocs() {
cacheLocs = cacheTracker.getLocationsSnapshot() cacheLocs.clear()
} }
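For reference, a rough sketch of the block naming this lookup relies on; the RDD id and split count are invented:

// Sketch only: a hypothetical RDD with id 7 and 3 splits.
val blockIds = (0 until 3).map(index => "rdd_%d_%d".format(7, index)).toArray
// Array(rdd_7_0, rdd_7_1, rdd_7_2); the BlockManagerMaster reports which block managers hold
// each of these, and only their IPs are kept as the cache locations for the RDD's splits.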
/** /**
@ -107,7 +123,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* The priority value passed in will be used if the stage doesn't already exist with * The priority value passed in will be used if the stage doesn't already exist with
* a lower priority (we assume that priorities always increase across jobs for now). * a lower priority (we assume that priorities always increase across jobs for now).
*/ */
def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], priority: Int): Stage = { private def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], priority: Int): Stage = {
shuffleToMapStage.get(shuffleDep.shuffleId) match { shuffleToMapStage.get(shuffleDep.shuffleId) match {
case Some(stage) => stage case Some(stage) => stage
case None => case None =>
@ -122,12 +138,11 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* as a result stage for the final RDD used directly in an action. The stage will also be given * as a result stage for the final RDD used directly in an action. The stage will also be given
* the provided priority. * the provided priority.
*/ */
def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_]], priority: Int): Stage = { private def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_]], priority: Int): Stage = {
// Kind of ugly: need to register RDDs with the cache and map output tracker here
// since we can't do it in the RDD constructor because # of splits is unknown
logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")")
cacheTracker.registerRDD(rdd.id, rdd.splits.size)
if (shuffleDep != None) { if (shuffleDep != None) {
// Kind of ugly: need to register RDDs with the cache and map output tracker here
// since we can't do it in the RDD constructor because # of splits is unknown
logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")")
mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size) mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size)
} }
val id = nextStageId.getAndIncrement() val id = nextStageId.getAndIncrement()
@ -140,7 +155,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* Get or create the list of parent stages for a given RDD. The stages will be assigned the * Get or create the list of parent stages for a given RDD. The stages will be assigned the
* provided priority if they haven't already been created with a lower priority. * provided priority if they haven't already been created with a lower priority.
*/ */
def getParentStages(rdd: RDD[_], priority: Int): List[Stage] = { private def getParentStages(rdd: RDD[_], priority: Int): List[Stage] = {
val parents = new HashSet[Stage] val parents = new HashSet[Stage]
val visited = new HashSet[RDD[_]] val visited = new HashSet[RDD[_]]
def visit(r: RDD[_]) { def visit(r: RDD[_]) {
@ -148,8 +163,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
visited += r visited += r
// Kind of ugly: need to register RDDs with the cache here since // Kind of ugly: need to register RDDs with the cache here since
// we can't do it in its constructor because # of splits is unknown // we can't do it in its constructor because # of splits is unknown
logInfo("Registering parent RDD " + r.id + " (" + r.origin + ")")
cacheTracker.registerRDD(r.id, r.splits.size)
for (dep <- r.dependencies) { for (dep <- r.dependencies) {
dep match { dep match {
case shufDep: ShuffleDependency[_,_] => case shufDep: ShuffleDependency[_,_] =>
@ -164,25 +177,22 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
parents.toList parents.toList
} }
def getMissingParentStages(stage: Stage): List[Stage] = { private def getMissingParentStages(stage: Stage): List[Stage] = {
val missing = new HashSet[Stage] val missing = new HashSet[Stage]
val visited = new HashSet[RDD[_]] val visited = new HashSet[RDD[_]]
def visit(rdd: RDD[_]) { def visit(rdd: RDD[_]) {
if (!visited(rdd)) { if (!visited(rdd)) {
visited += rdd visited += rdd
val locs = getCacheLocs(rdd) if (getCacheLocs(rdd).contains(Nil)) {
for (p <- 0 until rdd.splits.size) { for (dep <- rdd.dependencies) {
if (locs(p) == Nil) { dep match {
for (dep <- rdd.dependencies) { case shufDep: ShuffleDependency[_,_] =>
dep match { val mapStage = getShuffleMapStage(shufDep, stage.priority)
case shufDep: ShuffleDependency[_,_] => if (!mapStage.isAvailable) {
val mapStage = getShuffleMapStage(shufDep, stage.priority) missing += mapStage
if (!mapStage.isAvailable) { }
missing += mapStage case narrowDep: NarrowDependency[_] =>
} visit(narrowDep.rdd)
case narrowDep: NarrowDependency[_] =>
visit(narrowDep.rdd)
}
} }
} }
} }
@ -192,23 +202,45 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
missing.toList missing.toList
} }
/**
* Returns (and does not submit) a JobSubmitted event suitable to run a given job, and a
   * JobWaiter that can be used to block until the job completes (results are delivered through the resultHandler).
*
* The job is assumed to have at least one partition; zero partition jobs should be handled
* without a JobSubmitted event.
*/
private[scheduler] def prepareJob[T, U: ClassManifest](
finalRdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
callSite: String,
allowLocal: Boolean,
resultHandler: (Int, U) => Unit)
: (JobSubmitted, JobWaiter[U]) =
{
assert(partitions.size > 0)
val waiter = new JobWaiter(partitions.size, resultHandler)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter)
return (toSubmit, waiter)
}
def runJob[T, U: ClassManifest]( def runJob[T, U: ClassManifest](
finalRdd: RDD[T], finalRdd: RDD[T],
func: (TaskContext, Iterator[T]) => U, func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int], partitions: Seq[Int],
callSite: String, callSite: String,
allowLocal: Boolean) allowLocal: Boolean,
: Array[U] = resultHandler: (Int, U) => Unit)
{ {
if (partitions.size == 0) { if (partitions.size == 0) {
return new Array[U](0) return
} }
val waiter = new JobWaiter(partitions.size) val (toSubmit, waiter) = prepareJob(
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] finalRdd, func, partitions, callSite, allowLocal, resultHandler)
eventQueue.put(JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter)) eventQueue.put(toSubmit)
waiter.getResult() match { waiter.awaitResult() match {
case JobSucceeded(results: Seq[_]) => case JobSucceeded => {}
return results.asInstanceOf[Seq[U]].toArray
case JobFailed(exception: Exception) => case JobFailed(exception: Exception) =>
logInfo("Failed to run " + callSite) logInfo("Failed to run " + callSite)
throw exception throw exception
@ -227,90 +259,117 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val partitions = (0 until rdd.splits.size).toArray val partitions = (0 until rdd.splits.size).toArray
eventQueue.put(JobSubmitted(rdd, func2, partitions, false, callSite, listener)) eventQueue.put(JobSubmitted(rdd, func2, partitions, false, callSite, listener))
return listener.getResult() // Will throw an exception if the job fails return listener.awaitResult() // Will throw an exception if the job fails
} }
/**
* Process one event retrieved from the event queue.
* Returns true if we should stop the event loop.
*/
private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = {
event match {
case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener) =>
val runId = nextRunId.getAndIncrement()
val finalStage = newStage(finalRDD, None, runId)
val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener)
clearCacheLocs()
logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length +
" output partitions (allowLocal=" + allowLocal + ")")
logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")")
logInfo("Parents of final stage: " + finalStage.parents)
logInfo("Missing parents: " + getMissingParentStages(finalStage))
if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) {
// Compute very short actions like first() or take() with no parent stages locally.
runLocally(job)
} else {
activeJobs += job
resultStageToJob(finalStage) = job
submitStage(finalStage)
}
case ExecutorLost(execId) =>
handleExecutorLost(execId)
case completion: CompletionEvent =>
handleTaskCompletion(completion)
case TaskSetFailed(taskSet, reason) =>
abortStage(idToStage(taskSet.stageId), reason)
case StopDAGScheduler =>
// Cancel any active jobs
for (job <- activeJobs) {
val error = new SparkException("Job cancelled because SparkContext was shut down")
job.listener.jobFailed(error)
}
return true
}
return false
}
/**
* Resubmit any failed stages. Ordinarily called after a small amount of time has passed since
* the last fetch failure.
*/
private[scheduler] def resubmitFailedStages() {
logInfo("Resubmitting failed stages")
clearCacheLocs()
val failed2 = failed.toArray
failed.clear()
for (stage <- failed2.sortBy(_.priority)) {
submitStage(stage)
}
}
/**
* Check for waiting or failed stages which are now eligible for resubmission.
* Ordinarily run on every iteration of the event loop.
*/
private[scheduler] def submitWaitingStages() {
// TODO: We might want to run this less often, when we are sure that something has become
// runnable that wasn't before.
logTrace("Checking for newly runnable parent stages")
logTrace("running: " + running)
logTrace("waiting: " + waiting)
logTrace("failed: " + failed)
val waiting2 = waiting.toArray
waiting.clear()
for (stage <- waiting2.sortBy(_.priority)) {
submitStage(stage)
}
}
/** /**
* The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure * The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure
* events and responds by launching tasks. This runs in a dedicated thread and receives events * events and responds by launching tasks. This runs in a dedicated thread and receives events
* via the eventQueue. * via the eventQueue.
*/ */
def run() { private def run() {
SparkEnv.set(env) SparkEnv.set(env)
while (true) { while (true) {
val event = eventQueue.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS) val event = eventQueue.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS)
val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability
if (event != null) { if (event != null) {
logDebug("Got event of type " + event.getClass.getName) logDebug("Got event of type " + event.getClass.getName)
} }
event match { if (event != null) {
case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener) => if (processEvent(event)) {
val runId = nextRunId.getAndIncrement()
val finalStage = newStage(finalRDD, None, runId)
val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener)
updateCacheLocs()
logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length +
" output partitions")
logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")")
logInfo("Parents of final stage: " + finalStage.parents)
logInfo("Missing parents: " + getMissingParentStages(finalStage))
if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) {
// Compute very short actions like first() or take() with no parent stages locally.
runLocally(job)
} else {
activeJobs += job
resultStageToJob(finalStage) = job
submitStage(finalStage)
}
case HostLost(host) =>
handleHostLost(host)
case completion: CompletionEvent =>
handleTaskCompletion(completion)
case TaskSetFailed(taskSet, reason) =>
abortStage(idToStage(taskSet.stageId), reason)
case StopDAGScheduler =>
// Cancel any active jobs
for (job <- activeJobs) {
val error = new SparkException("Job cancelled because SparkContext was shut down")
job.listener.jobFailed(error)
}
return return
}
case null =>
// queue.poll() timed out, ignore it
} }
val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability
// Periodically resubmit failed stages if some map output fetches have failed and we have // Periodically resubmit failed stages if some map output fetches have failed and we have
// waited at least RESUBMIT_TIMEOUT. We wait for this short time because when a node fails, // waited at least RESUBMIT_TIMEOUT. We wait for this short time because when a node fails,
// tasks on many other nodes are bound to get a fetch failure, and they won't all get it at // tasks on many other nodes are bound to get a fetch failure, and they won't all get it at
// the same time, so we want to make sure we've identified all the reduce tasks that depend // the same time, so we want to make sure we've identified all the reduce tasks that depend
// on the failed node. // on the failed node.
if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) { if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) {
logInfo("Resubmitting failed stages") resubmitFailedStages()
updateCacheLocs()
val failed2 = failed.toArray
failed.clear()
for (stage <- failed2.sortBy(_.priority)) {
submitStage(stage)
}
} else { } else {
// TODO: We might want to run this less often, when we are sure that something has become submitWaitingStages()
// runnable that wasn't before.
logDebug("Checking for newly runnable parent stages")
logDebug("running: " + running)
logDebug("waiting: " + waiting)
logDebug("failed: " + failed)
val waiting2 = waiting.toArray
waiting.clear()
for (stage <- waiting2.sortBy(_.priority)) {
submitStage(stage)
}
} }
} }
} }
@ -320,7 +379,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* We run the operation in a separate thread just in case it takes a bunch of time, so that we * We run the operation in a separate thread just in case it takes a bunch of time, so that we
* don't block the DAGScheduler event loop or other concurrent jobs. * don't block the DAGScheduler event loop or other concurrent jobs.
*/ */
def runLocally(job: ActiveJob) { private def runLocally(job: ActiveJob) {
logInfo("Computing the requested partition locally") logInfo("Computing the requested partition locally")
new Thread("Local computation of job " + job.runId) { new Thread("Local computation of job " + job.runId) {
override def run() { override def run() {
@ -329,9 +388,12 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val rdd = job.finalStage.rdd val rdd = job.finalStage.rdd
val split = rdd.splits(job.partitions(0)) val split = rdd.splits(job.partitions(0))
val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0) val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0)
val result = job.func(taskContext, rdd.iterator(split, taskContext)) try {
taskContext.executeOnCompleteCallbacks() val result = job.func(taskContext, rdd.iterator(split, taskContext))
job.listener.taskSucceeded(0, result) job.listener.taskSucceeded(0, result)
} finally {
taskContext.executeOnCompleteCallbacks()
}
} catch { } catch {
case e: Exception => case e: Exception =>
job.listener.jobFailed(e) job.listener.jobFailed(e)
@ -340,13 +402,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
}.start() }.start()
} }
def submitStage(stage: Stage) { /** Submits stage, but first recursively submits any missing parents. */
private def submitStage(stage: Stage) {
logDebug("submitStage(" + stage + ")") logDebug("submitStage(" + stage + ")")
if (!waiting(stage) && !running(stage) && !failed(stage)) { if (!waiting(stage) && !running(stage) && !failed(stage)) {
val missing = getMissingParentStages(stage).sortBy(_.id) val missing = getMissingParentStages(stage).sortBy(_.id)
logDebug("missing: " + missing) logDebug("missing: " + missing)
if (missing == Nil) { if (missing == Nil) {
logInfo("Submitting " + stage + " (" + stage.origin + "), which has no missing parents") logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
submitMissingTasks(stage) submitMissingTasks(stage)
running += stage running += stage
} else { } else {
@ -358,7 +421,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
} }
} }
def submitMissingTasks(stage: Stage) { /** Called when stage's parents are available and we can now do its task. */
private def submitMissingTasks(stage: Stage) {
logDebug("submitMissingTasks(" + stage + ")") logDebug("submitMissingTasks(" + stage + ")")
// Get our pending tasks and remember them in our pendingTasks entry // Get our pending tasks and remember them in our pendingTasks entry
val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet) val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
@ -379,11 +443,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
} }
} }
if (tasks.size > 0) { if (tasks.size > 0) {
logInfo("Submitting " + tasks.size + " missing tasks from " + stage) logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
myPending ++= tasks myPending ++= tasks
logDebug("New pending tasks: " + myPending) logDebug("New pending tasks: " + myPending)
taskSched.submitTasks( taskSched.submitTasks(
new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.priority)) new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.priority))
if (!stage.submissionTime.isDefined) {
stage.submissionTime = Some(System.currentTimeMillis())
}
} else { } else {
logDebug("Stage " + stage + " is actually done; %b %d %d".format( logDebug("Stage " + stage + " is actually done; %b %d %d".format(
stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions)) stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
@ -395,9 +462,18 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* Responds to a task finishing. This is called inside the event loop so it assumes that it can * Responds to a task finishing. This is called inside the event loop so it assumes that it can
* modify the scheduler's internal state. Use taskEnded() to post a task end event from outside. * modify the scheduler's internal state. Use taskEnded() to post a task end event from outside.
*/ */
def handleTaskCompletion(event: CompletionEvent) { private def handleTaskCompletion(event: CompletionEvent) {
val task = event.task val task = event.task
val stage = idToStage(task.stageId) val stage = idToStage(task.stageId)
def markStageAsFinished(stage: Stage) = {
val serviceTime = stage.submissionTime match {
case Some(t) => "%.03f".format((System.currentTimeMillis() - t) / 1000.0)
        case _ => "Unknown"
}
logInfo("%s (%s) finished in %s s".format(stage, stage.origin, serviceTime))
running -= stage
}
event.reason match { event.reason match {
case Success => case Success =>
logInfo("Completed " + task) logInfo("Completed " + task)
@ -412,13 +488,13 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
if (!job.finished(rt.outputId)) { if (!job.finished(rt.outputId)) {
job.finished(rt.outputId) = true job.finished(rt.outputId) = true
job.numFinished += 1 job.numFinished += 1
job.listener.taskSucceeded(rt.outputId, event.result)
// If the whole job has finished, remove it // If the whole job has finished, remove it
if (job.numFinished == job.numPartitions) { if (job.numFinished == job.numPartitions) {
activeJobs -= job activeJobs -= job
resultStageToJob -= stage resultStageToJob -= stage
running -= stage markStageAsFinished(stage)
} }
job.listener.taskSucceeded(rt.outputId, event.result)
} }
case None => case None =>
logInfo("Ignoring result from " + rt + " because its job has finished") logInfo("Ignoring result from " + rt + " because its job has finished")
@ -427,23 +503,32 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
case smt: ShuffleMapTask => case smt: ShuffleMapTask =>
val stage = idToStage(smt.stageId) val stage = idToStage(smt.stageId)
val status = event.result.asInstanceOf[MapStatus] val status = event.result.asInstanceOf[MapStatus]
val host = status.address.ip val execId = status.location.executorId
logInfo("ShuffleMapTask finished with host " + host) logDebug("ShuffleMapTask finished on " + execId)
if (!deadHosts.contains(host)) { // TODO: Make sure hostnames are consistent with Mesos if (failedGeneration.contains(execId) && smt.generation <= failedGeneration(execId)) {
logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId)
} else {
stage.addOutputLoc(smt.partition, status) stage.addOutputLoc(smt.partition, status)
} }
if (running.contains(stage) && pendingTasks(stage).isEmpty) { if (running.contains(stage) && pendingTasks(stage).isEmpty) {
logInfo(stage + " (" + stage.origin + ") finished; looking for newly runnable stages") markStageAsFinished(stage)
running -= stage logInfo("looking for newly runnable stages")
logInfo("running: " + running) logInfo("running: " + running)
logInfo("waiting: " + waiting) logInfo("waiting: " + waiting)
logInfo("failed: " + failed) logInfo("failed: " + failed)
if (stage.shuffleDep != None) { if (stage.shuffleDep != None) {
// We supply true to increment the generation number here in case this is a
// recomputation of the map outputs. In that case, some nodes may have cached
// locations with holes (from when we detected the error) and will need the
// generation incremented to refetch them.
// TODO: Only increment the generation number if this is not the first time
// we registered these map outputs.
mapOutputTracker.registerMapOutputs( mapOutputTracker.registerMapOutputs(
stage.shuffleDep.get.shuffleId, stage.shuffleDep.get.shuffleId,
stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray) stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray,
true)
} }
updateCacheLocs() clearCacheLocs()
if (stage.outputLocs.count(_ == Nil) != 0) { if (stage.outputLocs.count(_ == Nil) != 0) {
// Some tasks had failed; let's resubmit this stage // Some tasks had failed; let's resubmit this stage
// TODO: Lower-level scheduler should also deal with this // TODO: Lower-level scheduler should also deal with this
@ -462,7 +547,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
waiting --= newlyRunnable waiting --= newlyRunnable
running ++= newlyRunnable running ++= newlyRunnable
for (stage <- newlyRunnable.sortBy(_.id)) { for (stage <- newlyRunnable.sortBy(_.id)) {
logInfo("Submitting " + stage + " (" + stage.origin + "), which is now runnable") logInfo("Submitting " + stage + " (" + stage.rdd + "), which is now runnable")
submitMissingTasks(stage) submitMissingTasks(stage)
} }
} }
@ -493,9 +578,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
// Remember that a fetch failed now; this is used to resubmit the broken // Remember that a fetch failed now; this is used to resubmit the broken
// stages later, after a small wait (to give other tasks the chance to fail) // stages later, after a small wait (to give other tasks the chance to fail)
lastFetchFailureTime = System.currentTimeMillis() // TODO: Use pluggable clock lastFetchFailureTime = System.currentTimeMillis() // TODO: Use pluggable clock
// TODO: mark the host as failed only if there were lots of fetch failures on it // TODO: mark the executor as failed only if there were lots of fetch failures on it
if (bmAddress != null) { if (bmAddress != null) {
handleHostLost(bmAddress.ip) handleExecutorLost(bmAddress.executorId, Some(task.generation))
} }
case other => case other =>
@ -505,22 +590,31 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
} }
/** /**
* Responds to a host being lost. This is called inside the event loop so it assumes that it can * Responds to an executor being lost. This is called inside the event loop, so it assumes it can
* modify the scheduler's internal state. Use hostLost() to post a host lost event from outside. * modify the scheduler's internal state. Use executorLost() to post a loss event from outside.
*
   * The generation in which the failure was observed can optionally be passed, so that stray
   * fetch failures do not retrigger detection of a node that has already been handled as lost.
*/ */
def handleHostLost(host: String) { private def handleExecutorLost(execId: String, maybeGeneration: Option[Long] = None) {
if (!deadHosts.contains(host)) { val currentGeneration = maybeGeneration.getOrElse(mapOutputTracker.getGeneration)
logInfo("Host lost: " + host) if (!failedGeneration.contains(execId) || failedGeneration(execId) < currentGeneration) {
deadHosts += host failedGeneration(execId) = currentGeneration
env.blockManager.master.notifyADeadHost(host) logInfo("Executor lost: %s (generation %d)".format(execId, currentGeneration))
blockManagerMaster.removeExecutor(execId)
// TODO: This will be really slow if we keep accumulating shuffle map stages // TODO: This will be really slow if we keep accumulating shuffle map stages
for ((shuffleId, stage) <- shuffleToMapStage) { for ((shuffleId, stage) <- shuffleToMapStage) {
stage.removeOutputsOnHost(host) stage.removeOutputsOnExecutor(execId)
val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray
mapOutputTracker.registerMapOutputs(shuffleId, locs, true) mapOutputTracker.registerMapOutputs(shuffleId, locs, true)
} }
cacheTracker.cacheLost(host) if (shuffleToMapStage.isEmpty) {
updateCacheLocs() mapOutputTracker.incrementGeneration()
}
clearCacheLocs()
} else {
logDebug("Additional executor lost message for " + execId +
"(generation " + currentGeneration + ")")
} }
} }
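A short sketch of the guard this generation bookkeeping enables when ShuffleMapTask results come back; the executor IDs and generation numbers are invented:

// Sketch only: executor IDs and generation numbers are invented.
val failedGeneration = scala.collection.mutable.HashMap[String, Long]()
failedGeneration("exec-2") = 7L                // "exec-2" was lost while the generation was 7
def acceptMapOutput(execId: String, taskGeneration: Long): Boolean =
  !(failedGeneration.contains(execId) && taskGeneration <= failedGeneration(execId))

acceptMapOutput("exec-2", 6L)   // false: stray result from before the failure, ignored
acceptMapOutput("exec-2", 8L)   // true:  task launched after the generation was bumped
acceptMapOutput("exec-1", 6L)   // true:  this executor was never marked as failed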
@ -528,7 +622,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* Aborts all jobs depending on a particular Stage. This is called in response to a task set * Aborts all jobs depending on a particular Stage. This is called in response to a task set
* being cancelled by the TaskScheduler. Use taskSetFailed() to inject this event from outside. * being cancelled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
*/ */
def abortStage(failedStage: Stage, reason: String) { private def abortStage(failedStage: Stage, reason: String) {
val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
for (resultStage <- dependentStages) { for (resultStage <- dependentStages) {
val job = resultStageToJob(resultStage) val job = resultStageToJob(resultStage)
@ -544,7 +638,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
/** /**
* Return true if one of stage's ancestors is target. * Return true if one of stage's ancestors is target.
*/ */
def stageDependsOn(stage: Stage, target: Stage): Boolean = { private def stageDependsOn(stage: Stage, target: Stage): Boolean = {
if (stage == target) { if (stage == target) {
return true return true
} }
@ -571,7 +665,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
visitedRdds.contains(target.rdd) visitedRdds.contains(target.rdd)
} }
def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = { private def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = {
// If the partition is cached, return the cache locations // If the partition is cached, return the cache locations
val cached = getCacheLocs(rdd)(partition) val cached = getCacheLocs(rdd)(partition)
if (cached != Nil) { if (cached != Nil) {
@ -597,7 +691,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
return Nil return Nil
} }
def cleanup(cleanupTime: Long) { private def cleanup(cleanupTime: Long) {
var sizeBefore = idToStage.size var sizeBefore = idToStage.size
idToStage.clearOldValues(cleanupTime) idToStage.clearOldValues(cleanupTime)
logInfo("idToStage " + sizeBefore + " --> " + idToStage.size) logInfo("idToStage " + sizeBefore + " --> " + idToStage.size)


@ -28,7 +28,7 @@ private[spark] case class CompletionEvent(
accumUpdates: Map[Long, Any]) accumUpdates: Map[Long, Any])
extends DAGSchedulerEvent extends DAGSchedulerEvent
private[spark] case class HostLost(host: String) extends DAGSchedulerEvent private[spark] case class ExecutorLost(execId: String) extends DAGSchedulerEvent
private[spark] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent private[spark] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent


@ -5,5 +5,5 @@ package spark.scheduler
*/ */
private[spark] sealed trait JobResult private[spark] sealed trait JobResult
private[spark] case class JobSucceeded(results: Seq[_]) extends JobResult private[spark] case object JobSucceeded extends JobResult
private[spark] case class JobFailed(exception: Exception) extends JobResult private[spark] case class JobFailed(exception: Exception) extends JobResult


@ -3,10 +3,12 @@ package spark.scheduler
import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.ArrayBuffer
/** /**
* An object that waits for a DAGScheduler job to complete. * An object that waits for a DAGScheduler job to complete. As tasks finish, it passes their
* results to the given handler function.
*/ */
private[spark] class JobWaiter(totalTasks: Int) extends JobListener { private[spark] class JobWaiter[T](totalTasks: Int, resultHandler: (Int, T) => Unit)
private val taskResults = ArrayBuffer.fill[Any](totalTasks)(null) extends JobListener {
private var finishedTasks = 0 private var finishedTasks = 0
private var jobFinished = false // Is the job as a whole finished (succeeded or failed)? private var jobFinished = false // Is the job as a whole finished (succeeded or failed)?
@ -17,11 +19,11 @@ private[spark] class JobWaiter(totalTasks: Int) extends JobListener {
if (jobFinished) { if (jobFinished) {
throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter") throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter")
} }
taskResults(index) = result resultHandler(index, result.asInstanceOf[T])
finishedTasks += 1 finishedTasks += 1
if (finishedTasks == totalTasks) { if (finishedTasks == totalTasks) {
jobFinished = true jobFinished = true
jobResult = JobSucceeded(taskResults) jobResult = JobSucceeded
this.notifyAll() this.notifyAll()
} }
} }
@ -38,7 +40,7 @@ private[spark] class JobWaiter(totalTasks: Int) extends JobListener {
} }
} }
def getResult(): JobResult = synchronized { def awaitResult(): JobResult = synchronized {
while (!jobFinished) { while (!jobFinished) {
this.wait() this.wait()
} }
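Because JobWaiter is private[spark], a sketch like the following would have to live inside the spark.scheduler package (for example in a test); it drives taskSucceeded by hand purely to show how results now reach the caller through the handler instead of an internal buffer:

import spark.scheduler.{JobWaiter, JobSucceeded, JobFailed}

val results = new Array[Int](3)
val waiter = new JobWaiter[Int](3, (index, value) => results(index) = value)
// Normally the DAGScheduler event loop calls taskSucceeded; here we simulate three tasks finishing.
waiter.taskSucceeded(0, 10)
waiter.taskSucceeded(1, 20)
waiter.taskSucceeded(2, 30)
waiter.awaitResult() match {
  case JobSucceeded         => println(results.mkString(","))   // 10,20,30
  case JobFailed(exception) => throw exception
}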


@ -8,19 +8,19 @@ import java.io.{ObjectOutput, ObjectInput, Externalizable}
* task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks. * task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks.
* The map output sizes are compressed using MapOutputTracker.compressSize. * The map output sizes are compressed using MapOutputTracker.compressSize.
*/ */
private[spark] class MapStatus(var address: BlockManagerId, var compressedSizes: Array[Byte]) private[spark] class MapStatus(var location: BlockManagerId, var compressedSizes: Array[Byte])
extends Externalizable { extends Externalizable {
def this() = this(null, null) // For deserialization only def this() = this(null, null) // For deserialization only
def writeExternal(out: ObjectOutput) { def writeExternal(out: ObjectOutput) {
address.writeExternal(out) location.writeExternal(out)
out.writeInt(compressedSizes.length) out.writeInt(compressedSizes.length)
out.write(compressedSizes) out.write(compressedSizes)
} }
def readExternal(in: ObjectInput) { def readExternal(in: ObjectInput) {
address = new BlockManagerId(in) location = BlockManagerId(in)
compressedSizes = new Array[Byte](in.readInt()) compressedSizes = new Array[Byte](in.readInt())
in.readFully(compressedSizes) in.readFully(compressedSizes)
} }


@ -32,7 +32,7 @@ private[spark] object ShuffleMapTask {
return old return old
} else { } else {
val out = new ByteArrayOutputStream val out = new ByteArrayOutputStream
val ser = SparkEnv.get.closureSerializer.newInstance val ser = SparkEnv.get.closureSerializer.newInstance()
val objOut = ser.serializeStream(new GZIPOutputStream(out)) val objOut = ser.serializeStream(new GZIPOutputStream(out))
objOut.writeObject(rdd) objOut.writeObject(rdd)
objOut.writeObject(dep) objOut.writeObject(dep)
@ -48,7 +48,7 @@ private[spark] object ShuffleMapTask {
synchronized { synchronized {
val loader = Thread.currentThread.getContextClassLoader val loader = Thread.currentThread.getContextClassLoader
val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
val ser = SparkEnv.get.closureSerializer.newInstance val ser = SparkEnv.get.closureSerializer.newInstance()
val objIn = ser.deserializeStream(in) val objIn = ser.deserializeStream(in)
val rdd = objIn.readObject().asInstanceOf[RDD[_]] val rdd = objIn.readObject().asInstanceOf[RDD[_]]
val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]] val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]]
@ -81,7 +81,7 @@ private[spark] class ShuffleMapTask(
with Externalizable with Externalizable
with Logging { with Logging {
def this() = this(0, null, null, 0, null) protected def this() = this(0, null, null, 0, null)
var split = if (rdd == null) { var split = if (rdd == null) {
null null
@ -117,34 +117,33 @@ private[spark] class ShuffleMapTask(
override def run(attemptId: Long): MapStatus = { override def run(attemptId: Long): MapStatus = {
val numOutputSplits = dep.partitioner.numPartitions val numOutputSplits = dep.partitioner.numPartitions
val partitioner = dep.partitioner
val taskContext = new TaskContext(stageId, partition, attemptId) val taskContext = new TaskContext(stageId, partition, attemptId)
try {
// Partition the map output.
val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)])
for (elem <- rdd.iterator(split, taskContext)) {
val pair = elem.asInstanceOf[(Any, Any)]
val bucketId = dep.partitioner.getPartition(pair._1)
buckets(bucketId) += pair
}
// Partition the map output. val compressedSizes = new Array[Byte](numOutputSplits)
val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)])
for (elem <- rdd.iterator(split, taskContext)) { val blockManager = SparkEnv.get.blockManager
val pair = elem.asInstanceOf[(Any, Any)] for (i <- 0 until numOutputSplits) {
val bucketId = partitioner.getPartition(pair._1) val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i
buckets(bucketId) += pair // Get a Scala iterator from Java map
val iter: Iterator[(Any, Any)] = buckets(i).iterator
val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false)
compressedSizes(i) = MapOutputTracker.compressSize(size)
}
return new MapStatus(blockManager.blockManagerId, compressedSizes)
} finally {
// Execute the callbacks on task completion.
taskContext.executeOnCompleteCallbacks()
} }
val bucketIterators = buckets.map(_.iterator)
val compressedSizes = new Array[Byte](numOutputSplits)
val blockManager = SparkEnv.get.blockManager
for (i <- 0 until numOutputSplits) {
val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i
// Get a Scala iterator from Java map
val iter: Iterator[(Any, Any)] = bucketIterators(i)
val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false)
compressedSizes(i) = MapOutputTracker.compressSize(size)
}
// Execute the callbacks on task completion.
taskContext.executeOnCompleteCallbacks()
return new MapStatus(blockManager.blockManagerId, compressedSizes)
} }
override def preferredLocations: Seq[String] = locs override def preferredLocations: Seq[String] = locs
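A sketch of the block IDs a single map task writes under this scheme; the shuffle id, map partition and reducer count are invented, and the remark about the reduce side is an assumption about how these blocks are later fetched:

// Sketch only: shuffle id 2, map partition 0, 3 reduce partitions, all invented.
val shuffleId = 2
val mapPartition = 0
val blockIds = (0 until 3).map(reduceId => "shuffle_" + shuffleId + "_" + mapPartition + "_" + reduceId)
// Vector(shuffle_2_0_0, shuffle_2_0_1, shuffle_2_0_2). Presumably the reduce side later asks the
// block managers recorded in the MapStatuses for every shuffle_<id>_<map>_<reduce> block it needs,
// using the compressed sizes as a hint for how much data to expect.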


@ -32,6 +32,9 @@ private[spark] class Stage(
val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil) val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil)
var numAvailableOutputs = 0 var numAvailableOutputs = 0
  /** When the first task was submitted to the scheduler. */
var submissionTime: Option[Long] = None
private var nextAttemptId = 0 private var nextAttemptId = 0
def isAvailable: Boolean = { def isAvailable: Boolean = {
@ -51,18 +54,18 @@ private[spark] class Stage(
def removeOutputLoc(partition: Int, bmAddress: BlockManagerId) { def removeOutputLoc(partition: Int, bmAddress: BlockManagerId) {
val prevList = outputLocs(partition) val prevList = outputLocs(partition)
val newList = prevList.filterNot(_.address == bmAddress) val newList = prevList.filterNot(_.location == bmAddress)
outputLocs(partition) = newList outputLocs(partition) = newList
if (prevList != Nil && newList == Nil) { if (prevList != Nil && newList == Nil) {
numAvailableOutputs -= 1 numAvailableOutputs -= 1
} }
} }
def removeOutputsOnHost(host: String) { def removeOutputsOnExecutor(execId: String) {
var becameUnavailable = false var becameUnavailable = false
for (partition <- 0 until numPartitions) { for (partition <- 0 until numPartitions) {
val prevList = outputLocs(partition) val prevList = outputLocs(partition)
val newList = prevList.filterNot(_.address.ip == host) val newList = prevList.filterNot(_.location.executorId == execId)
outputLocs(partition) = newList outputLocs(partition) = newList
if (prevList != Nil && newList == Nil) { if (prevList != Nil && newList == Nil) {
becameUnavailable = true becameUnavailable = true
@ -70,7 +73,8 @@ private[spark] class Stage(
} }
} }
if (becameUnavailable) { if (becameUnavailable) {
logInfo("%s is now unavailable on %s (%d/%d, %s)".format(this, host, numAvailableOutputs, numPartitions, isAvailable)) logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format(
this, execId, numAvailableOutputs, numPartitions, isAvailable))
} }
} }
@ -82,7 +86,7 @@ private[spark] class Stage(
def origin: String = rdd.origin def origin: String = rdd.origin
override def toString = "Stage " + id // + ": [RDD = " + rdd.id + ", isShuffle = " + isShuffleMap + "]" override def toString = "Stage " + id
override def hashCode(): Int = id override def hashCode(): Int = id
} }


@ -12,7 +12,7 @@ private[spark] trait TaskSchedulerListener {
def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any]): Unit def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any]): Unit
// A node was lost from the cluster. // A node was lost from the cluster.
def hostLost(host: String): Unit def executorLost(execId: String): Unit
// The TaskScheduler wants to abort an entire task set. // The TaskScheduler wants to abort an entire task set.
def taskSetFailed(taskSet: TaskSet, reason: String): Unit def taskSetFailed(taskSet: TaskSet, reason: String): Unit


@ -27,19 +27,20 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager] var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager]
val taskIdToTaskSetId = new HashMap[Long, String] val taskIdToTaskSetId = new HashMap[Long, String]
val taskIdToSlaveId = new HashMap[Long, String] val taskIdToExecutorId = new HashMap[Long, String]
val taskSetTaskIds = new HashMap[String, HashSet[Long]] val taskSetTaskIds = new HashMap[String, HashSet[Long]]
// Incrementing Mesos task IDs // Incrementing Mesos task IDs
val nextTaskId = new AtomicLong(0) val nextTaskId = new AtomicLong(0)
// Which hosts in the cluster are alive (contains hostnames) // IDs of the executors we currently have
val hostsAlive = new HashSet[String] val activeExecutorIds = new HashSet[String]
// Which slave IDs we have executors on // The set of executors we have on each host; this is used to compute hostsAlive, which
val slaveIdsWithExecutors = new HashSet[String] // in turn is used to decide when we can attain data locality on a given host
val executorsByHost = new HashMap[String, HashSet[String]]
val slaveIdToHost = new HashMap[String, String] val executorIdToHost = new HashMap[String, String]
// JAR server, if any JARs were added by the user to the SparkContext // JAR server, if any JARs were added by the user to the SparkContext
var jarServer: HttpServer = null var jarServer: HttpServer = null
@ -85,7 +86,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
} }
} }
def submitTasks(taskSet: TaskSet) { override def submitTasks(taskSet: TaskSet) {
val tasks = taskSet.tasks val tasks = taskSet.tasks
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks") logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized { this.synchronized {
@ -102,7 +103,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
activeTaskSets -= manager.taskSet.id activeTaskSets -= manager.taskSet.id
activeTaskSetsQueue -= manager activeTaskSetsQueue -= manager
taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id) taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
taskIdToSlaveId --= taskSetTaskIds(manager.taskSet.id) taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id)
taskSetTaskIds.remove(manager.taskSet.id) taskSetTaskIds.remove(manager.taskSet.id)
} }
} }
@ -117,8 +118,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
SparkEnv.set(sc.env) SparkEnv.set(sc.env)
// Mark each slave as alive and remember its hostname // Mark each slave as alive and remember its hostname
for (o <- offers) { for (o <- offers) {
slaveIdToHost(o.slaveId) = o.hostname executorIdToHost(o.executorId) = o.hostname
hostsAlive += o.hostname
} }
// Build a list of tasks to assign to each slave // Build a list of tasks to assign to each slave
val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores)) val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores))
@ -128,16 +128,20 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
do { do {
launchedTask = false launchedTask = false
for (i <- 0 until offers.size) { for (i <- 0 until offers.size) {
val sid = offers(i).slaveId val execId = offers(i).executorId
val host = offers(i).hostname val host = offers(i).hostname
manager.slaveOffer(sid, host, availableCpus(i)) match { manager.slaveOffer(execId, host, availableCpus(i)) match {
case Some(task) => case Some(task) =>
tasks(i) += task tasks(i) += task
val tid = task.taskId val tid = task.taskId
taskIdToTaskSetId(tid) = manager.taskSet.id taskIdToTaskSetId(tid) = manager.taskSet.id
taskSetTaskIds(manager.taskSet.id) += tid taskSetTaskIds(manager.taskSet.id) += tid
taskIdToSlaveId(tid) = sid taskIdToExecutorId(tid) = execId
slaveIdsWithExecutors += sid activeExecutorIds += execId
if (!executorsByHost.contains(host)) {
executorsByHost(host) = new HashSet()
}
executorsByHost(host) += execId
availableCpus(i) -= 1 availableCpus(i) -= 1
launchedTask = true launchedTask = true
@ -152,25 +156,21 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
var taskSetToUpdate: Option[TaskSetManager] = None var taskSetToUpdate: Option[TaskSetManager] = None
var failedHost: Option[String] = None var failedExecutor: Option[String] = None
var taskFailed = false var taskFailed = false
synchronized { synchronized {
try { try {
if (state == TaskState.LOST && taskIdToSlaveId.contains(tid)) { if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) {
// We lost the executor on this slave, so remember that it's gone // We lost this entire executor, so remember that it's gone
val slaveId = taskIdToSlaveId(tid) val execId = taskIdToExecutorId(tid)
val host = slaveIdToHost(slaveId) if (activeExecutorIds.contains(execId)) {
if (hostsAlive.contains(host)) { removeExecutor(execId)
slaveIdsWithExecutors -= slaveId failedExecutor = Some(execId)
hostsAlive -= host
activeTaskSetsQueue.foreach(_.hostLost(host))
failedHost = Some(host)
} }
} }
taskIdToTaskSetId.get(tid) match { taskIdToTaskSetId.get(tid) match {
case Some(taskSetId) => case Some(taskSetId) =>
if (activeTaskSets.contains(taskSetId)) { if (activeTaskSets.contains(taskSetId)) {
//activeTaskSets(taskSetId).statusUpdate(status)
taskSetToUpdate = Some(activeTaskSets(taskSetId)) taskSetToUpdate = Some(activeTaskSets(taskSetId))
} }
if (TaskState.isFinished(state)) { if (TaskState.isFinished(state)) {
@ -178,7 +178,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
if (taskSetTaskIds.contains(taskSetId)) { if (taskSetTaskIds.contains(taskSetId)) {
taskSetTaskIds(taskSetId) -= tid taskSetTaskIds(taskSetId) -= tid
} }
taskIdToSlaveId.remove(tid) taskIdToExecutorId.remove(tid)
} }
if (state == TaskState.FAILED) { if (state == TaskState.FAILED) {
taskFailed = true taskFailed = true
@ -190,12 +190,12 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
case e: Exception => logError("Exception in statusUpdate", e) case e: Exception => logError("Exception in statusUpdate", e)
} }
} }
// Update the task set and DAGScheduler without holding a lock on this, because that can deadlock // Update the task set and DAGScheduler without holding a lock on this, since that can deadlock
if (taskSetToUpdate != None) { if (taskSetToUpdate != None) {
taskSetToUpdate.get.statusUpdate(tid, state, serializedData) taskSetToUpdate.get.statusUpdate(tid, state, serializedData)
} }
if (failedHost != None) { if (failedExecutor != None) {
listener.hostLost(failedHost.get) listener.executorLost(failedExecutor.get)
backend.reviveOffers() backend.reviveOffers()
} }
if (taskFailed) { if (taskFailed) {
@ -249,27 +249,42 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
} }
} }
def slaveLost(slaveId: String, reason: ExecutorLossReason) { def executorLost(executorId: String, reason: ExecutorLossReason) {
var failedHost: Option[String] = None var failedExecutor: Option[String] = None
synchronized { synchronized {
val host = slaveIdToHost(slaveId) if (activeExecutorIds.contains(executorId)) {
if (hostsAlive.contains(host)) { val host = executorIdToHost(executorId)
logError("Lost an executor on " + host + ": " + reason) logError("Lost executor %s on %s: %s".format(executorId, host, reason))
slaveIdsWithExecutors -= slaveId removeExecutor(executorId)
hostsAlive -= host failedExecutor = Some(executorId)
activeTaskSetsQueue.foreach(_.hostLost(host))
failedHost = Some(host)
} else { } else {
// We may get multiple slaveLost() calls with different loss reasons. For example, one // We may get multiple executorLost() calls with different loss reasons. For example, one
// may be triggered by a dropped connection from the slave while another may be a report // may be triggered by a dropped connection from the slave while another may be a report
// of executor termination from Mesos. We produce log messages for both so we eventually // of executor termination from Mesos. We produce log messages for both so we eventually
// report the termination reason. // report the termination reason.
logError("Lost an executor on " + host + " (already removed): " + reason) logError("Lost an executor " + executorId + " (already removed): " + reason)
} }
} }
if (failedHost != None) { // Call listener.executorLost without holding the lock on this to prevent deadlock
listener.hostLost(failedHost.get) if (failedExecutor != None) {
listener.executorLost(failedExecutor.get)
backend.reviveOffers() backend.reviveOffers()
} }
} }
/** Get a list of hosts that currently have executors */
def hostsAlive: scala.collection.Set[String] = executorsByHost.keySet
/** Remove an executor from all our data structures and mark it as lost */
private def removeExecutor(executorId: String) {
activeExecutorIds -= executorId
val host = executorIdToHost(executorId)
val execs = executorsByHost.getOrElse(host, new HashSet)
execs -= executorId
if (execs.isEmpty) {
executorsByHost -= host
}
executorIdToHost -= executorId
activeTaskSetsQueue.foreach(_.executorLost(executorId, host))
}
} }
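The new removeExecutor/hostsAlive pair above keeps three indexes consistent: the set of active executor ids, the executor-to-host map, and the per-host executor sets. A trimmed-down, standalone sketch of that bookkeeping (ExecutorBookkeepingSketch is a made-up name, and the task-set notification is left out):

import scala.collection.mutable.{HashMap, HashSet}

object ExecutorBookkeepingSketch {
  // Executors currently registered with the scheduler, plus two reverse indexes.
  val activeExecutorIds = new HashSet[String]
  val executorIdToHost = new HashMap[String, String]
  val executorsByHost = new HashMap[String, HashSet[String]]

  // Drop an executor from every index; a host disappears once its last executor is gone.
  def removeExecutor(executorId: String) {
    activeExecutorIds -= executorId
    for (host <- executorIdToHost.get(executorId)) {
      val execs = executorsByHost.getOrElse(host, new HashSet[String])
      execs -= executorId
      if (execs.isEmpty) {
        executorsByHost -= host
      }
    }
    executorIdToHost -= executorId
  }

  // "Alive" hosts are simply those that still have at least one registered executor.
  def hostsAlive: scala.collection.Set[String] = executorsByHost.keySet
}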

View file

@ -1,5 +1,7 @@
package spark.scheduler.cluster package spark.scheduler.cluster
import spark.Utils
/** /**
* A backend interface for cluster scheduling systems that allows plugging in different ones under * A backend interface for cluster scheduling systems that allows plugging in different ones under
* ClusterScheduler. We assume a Mesos-like model where the application gets resource offers as * ClusterScheduler. We assume a Mesos-like model where the application gets resource offers as
@ -11,5 +13,15 @@ private[spark] trait SchedulerBackend {
def reviveOffers(): Unit def reviveOffers(): Unit
def defaultParallelism(): Int def defaultParallelism(): Int
// Memory used by each executor (in megabytes)
protected val executorMemory = {
// TODO: Might need to add some extra memory for the non-heap parts of the JVM
Option(System.getProperty("spark.executor.memory"))
.orElse(Option(System.getenv("SPARK_MEM")))
.map(Utils.memoryStringToMb)
.getOrElse(512)
}
// TODO: Probably want to add a killTask too // TODO: Probably want to add a killTask too
} }
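The executorMemory value that moves into this trait encodes a simple precedence: the spark.executor.memory system property wins over the SPARK_MEM environment variable, with 512 MB as the fallback. A minimal sketch, assuming a simplified memoryStringToMb (the real one lives in spark.Utils and handles more formats):

object ExecutorMemorySketch {
  // Simplified stand-in for spark.Utils.memoryStringToMb: accepts values such as "512m" or "2g".
  def memoryStringToMb(str: String): Int = {
    val lower = str.toLowerCase.trim
    if (lower.endsWith("g")) lower.dropRight(1).toInt * 1024
    else if (lower.endsWith("m")) lower.dropRight(1).toInt
    else if (lower.endsWith("k")) lower.dropRight(1).toInt / 1024
    else lower.toInt // assume a bare number is already in megabytes
  }

  // Resolution order used above: system property, then SPARK_MEM, then a 512 MB default.
  def resolveExecutorMemory(): Int =
    Option(System.getProperty("spark.executor.memory"))
      .orElse(Option(System.getenv("SPARK_MEM")))
      .map(memoryStringToMb)
      .getOrElse(512)
}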

View file

@ -1,4 +0,0 @@
package spark.scheduler.cluster
private[spark]
class SlaveResources(val slaveId: String, val hostname: String, val coresFree: Int) {}

View file

@ -19,34 +19,25 @@ private[spark] class SparkDeploySchedulerBackend(
var shutdownCallback : (SparkDeploySchedulerBackend) => Unit = _ var shutdownCallback : (SparkDeploySchedulerBackend) => Unit = _
val maxCores = System.getProperty("spark.cores.max", Int.MaxValue.toString).toInt val maxCores = System.getProperty("spark.cores.max", Int.MaxValue.toString).toInt
val executorIdToSlaveId = new HashMap[String, String]
// Memory used by each executor (in megabytes)
val executorMemory = {
if (System.getenv("SPARK_MEM") != null) {
Utils.memoryStringToMb(System.getenv("SPARK_MEM"))
// TODO: Might need to add some extra memory for the non-heap parts of the JVM
} else {
512
}
}
override def start() { override def start() {
super.start() super.start()
val masterUrl = "akka://spark@%s:%s/user/%s".format( // The endpoint for executors to talk to us
System.getProperty("spark.master.host"), System.getProperty("spark.master.port"), val driverUrl = "akka://spark@%s:%s/user/%s".format(
System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"),
StandaloneSchedulerBackend.ACTOR_NAME) StandaloneSchedulerBackend.ACTOR_NAME)
val args = Seq(masterUrl, "{{SLAVEID}}", "{{HOSTNAME}}", "{{CORES}}") val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}")
val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs) val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs)
val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command) val sparkHome = sc.getSparkHome().getOrElse(throw new IllegalArgumentException("must supply spark home for spark standalone"))
val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, sparkHome)
client = new Client(sc.env.actorSystem, master, jobDesc, this) client = new Client(sc.env.actorSystem, master, jobDesc, this)
client.start() client.start()
} }
override def stop() { override def stop() {
stopping = true; stopping = true
super.stop() super.stop()
client.stop() client.stop()
if (shutdownCallback != null) { if (shutdownCallback != null) {
@ -54,35 +45,28 @@ private[spark] class SparkDeploySchedulerBackend(
} }
} }
def connected(jobId: String) { override def connected(jobId: String) {
logInfo("Connected to Spark cluster with job ID " + jobId) logInfo("Connected to Spark cluster with job ID " + jobId)
} }
def disconnected() { override def disconnected() {
if (!stopping) { if (!stopping) {
logError("Disconnected from Spark cluster!") logError("Disconnected from Spark cluster!")
scheduler.error("Disconnected from Spark cluster") scheduler.error("Disconnected from Spark cluster")
} }
} }
def executorAdded(id: String, workerId: String, host: String, cores: Int, memory: Int) { override def executorAdded(executorId: String, workerId: String, host: String, cores: Int, memory: Int) {
executorIdToSlaveId += id -> workerId
logInfo("Granted executor ID %s on host %s with %d cores, %s RAM".format( logInfo("Granted executor ID %s on host %s with %d cores, %s RAM".format(
id, host, cores, Utils.memoryMegabytesToString(memory))) executorId, host, cores, Utils.memoryMegabytesToString(memory)))
} }
def executorRemoved(id: String, message: String, exitStatus: Option[Int]) { override def executorRemoved(executorId: String, message: String, exitStatus: Option[Int]) {
val reason: ExecutorLossReason = exitStatus match { val reason: ExecutorLossReason = exitStatus match {
case Some(code) => ExecutorExited(code) case Some(code) => ExecutorExited(code)
case None => SlaveLost(message) case None => SlaveLost(message)
} }
logInfo("Executor %s removed: %s".format(id, message)) logInfo("Executor %s removed: %s".format(executorId, message))
executorIdToSlaveId.get(id) match { scheduler.executorLost(executorId, reason)
case Some(slaveId) =>
executorIdToSlaveId.remove(id)
scheduler.slaveLost(slaveId, reason)
case None =>
logInfo("No slave ID known for executor %s".format(id))
}
} }
} }
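For context on the driverUrl change above: the executor launch command is just the Akka address of the driver's scheduler actor plus placeholder tokens that the worker substitutes when it starts the executor process. A small sketch with made-up example values (the real code reads spark.driver.host, spark.driver.port and the backend's ACTOR_NAME):

object LaunchArgsSketch {
  // Same URL shape as driverUrl above.
  def driverUrl(host: String, port: String, actorName: String): String =
    "akka://spark@%s:%s/user/%s".format(host, port, actorName)

  def main(args: Array[String]) {
    val url = driverUrl("127.0.0.1", "50501", "StandaloneScheduler") // example values only
    // The {{...}} tokens are passed through verbatim and filled in on the worker side.
    val launchArgs = Seq(url, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}")
    println(launchArgs.mkString(" "))
  }
}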

View file

@ -6,32 +6,34 @@ import spark.util.SerializableBuffer
private[spark] sealed trait StandaloneClusterMessage extends Serializable private[spark] sealed trait StandaloneClusterMessage extends Serializable
// Master to slaves // Driver to executors
private[spark] private[spark]
case class LaunchTask(task: TaskDescription) extends StandaloneClusterMessage case class LaunchTask(task: TaskDescription) extends StandaloneClusterMessage
private[spark] private[spark]
case class RegisteredSlave(sparkProperties: Seq[(String, String)]) extends StandaloneClusterMessage case class RegisteredExecutor(sparkProperties: Seq[(String, String)])
extends StandaloneClusterMessage
private[spark] private[spark]
case class RegisterSlaveFailed(message: String) extends StandaloneClusterMessage case class RegisterExecutorFailed(message: String) extends StandaloneClusterMessage
// Slaves to master // Executors to driver
private[spark] private[spark]
case class RegisterSlave(slaveId: String, host: String, cores: Int) extends StandaloneClusterMessage case class RegisterExecutor(executorId: String, host: String, cores: Int)
extends StandaloneClusterMessage
private[spark] private[spark]
case class StatusUpdate(slaveId: String, taskId: Long, state: TaskState, data: SerializableBuffer) case class StatusUpdate(executorId: String, taskId: Long, state: TaskState, data: SerializableBuffer)
extends StandaloneClusterMessage extends StandaloneClusterMessage
private[spark] private[spark]
object StatusUpdate { object StatusUpdate {
/** Alternate factory method that takes a ByteBuffer directly for the data field */ /** Alternate factory method that takes a ByteBuffer directly for the data field */
def apply(slaveId: String, taskId: Long, state: TaskState, data: ByteBuffer): StatusUpdate = { def apply(executorId: String, taskId: Long, state: TaskState, data: ByteBuffer): StatusUpdate = {
StatusUpdate(slaveId, taskId, state, new SerializableBuffer(data)) StatusUpdate(executorId, taskId, state, new SerializableBuffer(data))
} }
} }
// Internal messages in master // Internal messages in driver
private[spark] case object ReviveOffers extends StandaloneClusterMessage private[spark] case object ReviveOffers extends StandaloneClusterMessage
private[spark] case object StopMaster extends StandaloneClusterMessage private[spark] case object StopDriver extends StandaloneClusterMessage
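The renamed messages above describe a simple registration handshake: an executor sends RegisterExecutor, and the driver replies with either RegisteredExecutor (carrying the spark.* properties) or RegisterExecutorFailed. A plain-Scala sketch of that exchange, without the Akka plumbing (HandshakeSketch is a hypothetical name):

import scala.collection.mutable.HashSet

sealed trait ClusterMsg
case class RegisterExecutor(executorId: String, host: String, cores: Int) extends ClusterMsg
case class RegisteredExecutor(sparkProperties: Seq[(String, String)]) extends ClusterMsg
case class RegisterExecutorFailed(message: String) extends ClusterMsg

object HandshakeSketch {
  private val registered = new HashSet[String]

  // Duplicate executor IDs are rejected; everyone else receives the driver's spark.* properties.
  def handle(msg: RegisterExecutor, sparkProperties: Seq[(String, String)]): ClusterMsg =
    if (registered.contains(msg.executorId)) {
      RegisterExecutorFailed("Duplicate executor ID: " + msg.executorId)
    } else {
      registered += msg.executorId
      RegisteredExecutor(sparkProperties)
    }
}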

View file

@ -23,13 +23,13 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
// Use an atomic variable to track total number of cores in the cluster for simplicity and speed // Use an atomic variable to track total number of cores in the cluster for simplicity and speed
var totalCoreCount = new AtomicInteger(0) var totalCoreCount = new AtomicInteger(0)
class MasterActor(sparkProperties: Seq[(String, String)]) extends Actor { class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor {
val slaveActor = new HashMap[String, ActorRef] val executorActor = new HashMap[String, ActorRef]
val slaveAddress = new HashMap[String, Address] val executorAddress = new HashMap[String, Address]
val slaveHost = new HashMap[String, String] val executorHost = new HashMap[String, String]
val freeCores = new HashMap[String, Int] val freeCores = new HashMap[String, Int]
val actorToSlaveId = new HashMap[ActorRef, String] val actorToExecutorId = new HashMap[ActorRef, String]
val addressToSlaveId = new HashMap[Address, String] val addressToExecutorId = new HashMap[Address, String]
override def preStart() { override def preStart() {
// Listen for remote client disconnection events, since they don't go through Akka's watch() // Listen for remote client disconnection events, since they don't go through Akka's watch()
@ -37,86 +37,86 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
} }
def receive = { def receive = {
case RegisterSlave(slaveId, host, cores) => case RegisterExecutor(executorId, host, cores) =>
if (slaveActor.contains(slaveId)) { if (executorActor.contains(executorId)) {
sender ! RegisterSlaveFailed("Duplicate slave ID: " + slaveId) sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId)
} else { } else {
logInfo("Registered slave: " + sender + " with ID " + slaveId) logInfo("Registered executor: " + sender + " with ID " + executorId)
sender ! RegisteredSlave(sparkProperties) sender ! RegisteredExecutor(sparkProperties)
context.watch(sender) context.watch(sender)
slaveActor(slaveId) = sender executorActor(executorId) = sender
slaveHost(slaveId) = host executorHost(executorId) = host
freeCores(slaveId) = cores freeCores(executorId) = cores
slaveAddress(slaveId) = sender.path.address executorAddress(executorId) = sender.path.address
actorToSlaveId(sender) = slaveId actorToExecutorId(sender) = executorId
addressToSlaveId(sender.path.address) = slaveId addressToExecutorId(sender.path.address) = executorId
totalCoreCount.addAndGet(cores) totalCoreCount.addAndGet(cores)
makeOffers() makeOffers()
} }
case StatusUpdate(slaveId, taskId, state, data) => case StatusUpdate(executorId, taskId, state, data) =>
scheduler.statusUpdate(taskId, state, data.value) scheduler.statusUpdate(taskId, state, data.value)
if (TaskState.isFinished(state)) { if (TaskState.isFinished(state)) {
freeCores(slaveId) += 1 freeCores(executorId) += 1
makeOffers(slaveId) makeOffers(executorId)
} }
case ReviveOffers => case ReviveOffers =>
makeOffers() makeOffers()
case StopMaster => case StopDriver =>
sender ! true sender ! true
context.stop(self) context.stop(self)
case Terminated(actor) => case Terminated(actor) =>
actorToSlaveId.get(actor).foreach(removeSlave(_, "Akka actor terminated")) actorToExecutorId.get(actor).foreach(removeExecutor(_, "Akka actor terminated"))
case RemoteClientDisconnected(transport, address) => case RemoteClientDisconnected(transport, address) =>
addressToSlaveId.get(address).foreach(removeSlave(_, "remote Akka client disconnected")) addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client disconnected"))
case RemoteClientShutdown(transport, address) => case RemoteClientShutdown(transport, address) =>
addressToSlaveId.get(address).foreach(removeSlave(_, "remote Akka client shutdown")) addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client shutdown"))
} }
// Make fake resource offers on all slaves // Make fake resource offers on all executors
def makeOffers() { def makeOffers() {
launchTasks(scheduler.resourceOffers( launchTasks(scheduler.resourceOffers(
slaveHost.toArray.map {case (id, host) => new WorkerOffer(id, host, freeCores(id))})) executorHost.toArray.map {case (id, host) => new WorkerOffer(id, host, freeCores(id))}))
} }
// Make fake resource offers on just one slave // Make fake resource offers on just one executor
def makeOffers(slaveId: String) { def makeOffers(executorId: String) {
launchTasks(scheduler.resourceOffers( launchTasks(scheduler.resourceOffers(
Seq(new WorkerOffer(slaveId, slaveHost(slaveId), freeCores(slaveId))))) Seq(new WorkerOffer(executorId, executorHost(executorId), freeCores(executorId)))))
} }
// Launch tasks returned by a set of resource offers // Launch tasks returned by a set of resource offers
def launchTasks(tasks: Seq[Seq[TaskDescription]]) { def launchTasks(tasks: Seq[Seq[TaskDescription]]) {
for (task <- tasks.flatten) { for (task <- tasks.flatten) {
freeCores(task.slaveId) -= 1 freeCores(task.executorId) -= 1
slaveActor(task.slaveId) ! LaunchTask(task) executorActor(task.executorId) ! LaunchTask(task)
} }
} }
// Remove a disconnected slave from the cluster // Remove a disconnected slave from the cluster
def removeSlave(slaveId: String, reason: String) { def removeExecutor(executorId: String, reason: String) {
logInfo("Slave " + slaveId + " disconnected, so removing it") logInfo("Slave " + executorId + " disconnected, so removing it")
val numCores = freeCores(slaveId) val numCores = freeCores(executorId)
actorToSlaveId -= slaveActor(slaveId) actorToExecutorId -= executorActor(executorId)
addressToSlaveId -= slaveAddress(slaveId) addressToExecutorId -= executorAddress(executorId)
slaveActor -= slaveId executorActor -= executorId
slaveHost -= slaveId executorHost -= executorId
freeCores -= slaveId freeCores -= executorId
slaveHost -= slaveId executorHost -= executorId
totalCoreCount.addAndGet(-numCores) totalCoreCount.addAndGet(-numCores)
scheduler.slaveLost(slaveId, SlaveLost(reason)) scheduler.executorLost(executorId, SlaveLost(reason))
} }
} }
var masterActor: ActorRef = null var driverActor: ActorRef = null
val taskIdsOnSlave = new HashMap[String, HashSet[String]] val taskIdsOnSlave = new HashMap[String, HashSet[String]]
def start() { override def start() {
val properties = new ArrayBuffer[(String, String)] val properties = new ArrayBuffer[(String, String)]
val iterator = System.getProperties.entrySet.iterator val iterator = System.getProperties.entrySet.iterator
while (iterator.hasNext) { while (iterator.hasNext) {
@ -126,15 +126,15 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
properties += ((key, value)) properties += ((key, value))
} }
} }
masterActor = actorSystem.actorOf( driverActor = actorSystem.actorOf(
Props(new MasterActor(properties)), name = StandaloneSchedulerBackend.ACTOR_NAME) Props(new DriverActor(properties)), name = StandaloneSchedulerBackend.ACTOR_NAME)
} }
def stop() { override def stop() {
try { try {
if (masterActor != null) { if (driverActor != null) {
val timeout = 5.seconds val timeout = 5.seconds
val future = masterActor.ask(StopMaster)(timeout) val future = driverActor.ask(StopDriver)(timeout)
Await.result(future, timeout) Await.result(future, timeout)
} }
} catch { } catch {
@ -143,11 +143,11 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
} }
} }
def reviveOffers() { override def reviveOffers() {
masterActor ! ReviveOffers driverActor ! ReviveOffers
} }
def defaultParallelism(): Int = math.max(totalCoreCount.get(), 2) override def defaultParallelism(): Int = math.max(totalCoreCount.get(), 2)
} }
private[spark] object StandaloneSchedulerBackend { private[spark] object StandaloneSchedulerBackend {
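Behind the renames in DriverActor above, the accounting is straightforward: every registered executor contributes a (id, host, freeCores) offer, launching a task debits one core, and a finished status update credits it back. A standalone toy model (the names are illustrative):

import scala.collection.mutable.HashMap

object CoreAccountingSketch {
  val executorHost = new HashMap[String, String]
  val freeCores = new HashMap[String, Int]

  def registerExecutor(executorId: String, host: String, cores: Int) {
    executorHost(executorId) = host
    freeCores(executorId) = cores
  }

  // Equivalent of makeOffers(): one (id, host, freeCores) triple per known executor.
  def makeOffers(): Seq[(String, String, Int)] =
    executorHost.toSeq.map { case (id, host) => (id, host, freeCores(id)) }

  def taskLaunched(executorId: String) { freeCores(executorId) -= 1 }
  def taskFinished(executorId: String) { freeCores(executorId) += 1 }
}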

View file

@ -5,7 +5,7 @@ import spark.util.SerializableBuffer
private[spark] class TaskDescription( private[spark] class TaskDescription(
val taskId: Long, val taskId: Long,
val slaveId: String, val executorId: String,
val name: String, val name: String,
_serializedTask: ByteBuffer) _serializedTask: ByteBuffer)
extends Serializable { extends Serializable {

View file

@ -4,7 +4,12 @@ package spark.scheduler.cluster
* Information about a running task attempt inside a TaskSet. * Information about a running task attempt inside a TaskSet.
*/ */
private[spark] private[spark]
class TaskInfo(val taskId: Long, val index: Int, val launchTime: Long, val host: String) { class TaskInfo(
val taskId: Long,
val index: Int,
val launchTime: Long,
val executorId: String,
val host: String) {
var finishTime: Long = 0 var finishTime: Long = 0
var failed = false var failed = false

View file

@ -17,10 +17,7 @@ import java.nio.ByteBuffer
/** /**
* Schedules the tasks within a single TaskSet in the ClusterScheduler. * Schedules the tasks within a single TaskSet in the ClusterScheduler.
*/ */
private[spark] class TaskSetManager( private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSet) extends Logging {
sched: ClusterScheduler,
val taskSet: TaskSet)
extends Logging {
// Maximum time to wait to run a task in a preferred location (in ms) // Maximum time to wait to run a task in a preferred location (in ms)
val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
@ -100,7 +97,7 @@ private[spark] class TaskSetManager(
} }
// Add a task to all the pending-task lists that it should be on. // Add a task to all the pending-task lists that it should be on.
def addPendingTask(index: Int) { private def addPendingTask(index: Int) {
val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive
if (locations.size == 0) { if (locations.size == 0) {
pendingTasksWithNoPrefs += index pendingTasksWithNoPrefs += index
@ -115,7 +112,7 @@ private[spark] class TaskSetManager(
// Return the pending tasks list for a given host, or an empty list if // Return the pending tasks list for a given host, or an empty list if
// there is no map entry for that host // there is no map entry for that host
def getPendingTasksForHost(host: String): ArrayBuffer[Int] = { private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = {
pendingTasksForHost.getOrElse(host, ArrayBuffer()) pendingTasksForHost.getOrElse(host, ArrayBuffer())
} }
@ -123,7 +120,7 @@ private[spark] class TaskSetManager(
// Return None if the list is empty. // Return None if the list is empty.
// This method also cleans up any tasks in the list that have already // This method also cleans up any tasks in the list that have already
// been launched, since we want that to happen lazily. // been launched, since we want that to happen lazily.
def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = { private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
while (!list.isEmpty) { while (!list.isEmpty) {
val index = list.last val index = list.last
list.trimEnd(1) list.trimEnd(1)
@ -137,11 +134,12 @@ private[spark] class TaskSetManager(
// Return a speculative task for a given host if any are available. The task should not have an // Return a speculative task for a given host if any are available. The task should not have an
// attempt running on this host, in case the host is slow. In addition, if localOnly is set, the // attempt running on this host, in case the host is slow. In addition, if localOnly is set, the
// task must have a preference for this host (or no preferred locations at all). // task must have a preference for this host (or no preferred locations at all).
def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = { private def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = {
val hostsAlive = sched.hostsAlive
speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set
val localTask = speculatableTasks.find { val localTask = speculatableTasks.find {
index => index =>
val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive val locations = tasks(index).preferredLocations.toSet & hostsAlive
val attemptLocs = taskAttempts(index).map(_.host) val attemptLocs = taskAttempts(index).map(_.host)
(locations.size == 0 || locations.contains(host)) && !attemptLocs.contains(host) (locations.size == 0 || locations.contains(host)) && !attemptLocs.contains(host)
} }
@ -161,7 +159,7 @@ private[spark] class TaskSetManager(
// Dequeue a pending task for a given node and return its index. // Dequeue a pending task for a given node and return its index.
// If localOnly is set to false, allow non-local tasks as well. // If localOnly is set to false, allow non-local tasks as well.
def findTask(host: String, localOnly: Boolean): Option[Int] = { private def findTask(host: String, localOnly: Boolean): Option[Int] = {
val localTask = findTaskFromList(getPendingTasksForHost(host)) val localTask = findTaskFromList(getPendingTasksForHost(host))
if (localTask != None) { if (localTask != None) {
return localTask return localTask
@ -183,13 +181,13 @@ private[spark] class TaskSetManager(
// Does a host count as a preferred location for a task? This is true if // Does a host count as a preferred location for a task? This is true if
// either the task has preferred locations and this host is one, or it has // either the task has preferred locations and this host is one, or it has
// no preferred locations (in which case we still count the launch as preferred). // no preferred locations (in which case we still count the launch as preferred).
def isPreferredLocation(task: Task[_], host: String): Boolean = { private def isPreferredLocation(task: Task[_], host: String): Boolean = {
val locs = task.preferredLocations val locs = task.preferredLocations
return (locs.contains(host) || locs.isEmpty) return (locs.contains(host) || locs.isEmpty)
} }
// Respond to an offer of a single slave from the scheduler by finding a task // Respond to an offer of a single slave from the scheduler by finding a task
def slaveOffer(slaveId: String, host: String, availableCpus: Double): Option[TaskDescription] = { def slaveOffer(execId: String, host: String, availableCpus: Double): Option[TaskDescription] = {
if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) { if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) {
val time = System.currentTimeMillis val time = System.currentTimeMillis
val localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT) val localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT)
@ -206,11 +204,11 @@ private[spark] class TaskSetManager(
} else { } else {
"non-preferred, not one of " + task.preferredLocations.mkString(", ") "non-preferred, not one of " + task.preferredLocations.mkString(", ")
} }
logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format( logInfo("Starting task %s:%d as TID %s on executor %s: %s (%s)".format(
taskSet.id, index, taskId, slaveId, host, prefStr)) taskSet.id, index, taskId, execId, host, prefStr))
// Do various bookkeeping // Do various bookkeeping
copiesRunning(index) += 1 copiesRunning(index) += 1
val info = new TaskInfo(taskId, index, time, host) val info = new TaskInfo(taskId, index, time, execId, host)
taskInfos(taskId) = info taskInfos(taskId) = info
taskAttempts(index) = info :: taskAttempts(index) taskAttempts(index) = info :: taskAttempts(index)
if (preferred) { if (preferred) {
@ -224,7 +222,7 @@ private[spark] class TaskSetManager(
logInfo("Serialized task %s:%d as %d bytes in %d ms".format( logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
taskSet.id, index, serializedTask.limit, timeTaken)) taskSet.id, index, serializedTask.limit, timeTaken))
val taskName = "task %s:%d".format(taskSet.id, index) val taskName = "task %s:%d".format(taskSet.id, index)
return Some(new TaskDescription(taskId, slaveId, taskName, serializedTask)) return Some(new TaskDescription(taskId, execId, taskName, serializedTask))
} }
case _ => case _ =>
} }
@ -334,7 +332,7 @@ private[spark] class TaskSetManager(
if (numFailures(index) > MAX_TASK_FAILURES) { if (numFailures(index) > MAX_TASK_FAILURES) {
logError("Task %s:%d failed more than %d times; aborting job".format( logError("Task %s:%d failed more than %d times; aborting job".format(
taskSet.id, index, MAX_TASK_FAILURES)) taskSet.id, index, MAX_TASK_FAILURES))
abort("Task %d failed more than %d times".format(index, MAX_TASK_FAILURES)) abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES))
} }
} }
} else { } else {
@ -356,19 +354,22 @@ private[spark] class TaskSetManager(
sched.taskSetFinished(this) sched.taskSetFinished(this)
} }
def hostLost(hostname: String) { def executorLost(execId: String, hostname: String) {
logInfo("Re-queueing tasks for " + hostname + " from TaskSet " + taskSet.id) logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id)
// If some task has preferred locations only on hostname, put it in the no-prefs list val newHostsAlive = sched.hostsAlive
// to avoid the wait from delay scheduling // If some task has preferred locations only on hostname, and there are no more executors there,
for (index <- getPendingTasksForHost(hostname)) { // put it in the no-prefs list to avoid the wait from delay scheduling
val newLocs = tasks(index).preferredLocations.toSet & sched.hostsAlive if (!newHostsAlive.contains(hostname)) {
if (newLocs.isEmpty) { for (index <- getPendingTasksForHost(hostname)) {
pendingTasksWithNoPrefs += index val newLocs = tasks(index).preferredLocations.toSet & newHostsAlive
if (newLocs.isEmpty) {
pendingTasksWithNoPrefs += index
}
} }
} }
// Re-enqueue any tasks that ran on the failed host if this is a shuffle map stage // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage
if (tasks(0).isInstanceOf[ShuffleMapTask]) { if (tasks(0).isInstanceOf[ShuffleMapTask]) {
for ((tid, info) <- taskInfos if info.host == hostname) { for ((tid, info) <- taskInfos if info.executorId == execId) {
val index = taskInfos(tid).index val index = taskInfos(tid).index
if (finished(index)) { if (finished(index)) {
finished(index) = false finished(index) = false
@ -382,7 +383,7 @@ private[spark] class TaskSetManager(
} }
} }
// Also re-enqueue any tasks that were running on the node // Also re-enqueue any tasks that were running on the node
for ((tid, info) <- taskInfos if info.running && info.host == hostname) { for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
taskLost(tid, TaskState.KILLED, null) taskLost(tid, TaskState.KILLED, null)
} }
} }
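The slaveOffer path above relies on delay scheduling: after a preferred (host-local) launch, the manager keeps turning down non-local tasks for LOCALITY_WAIT milliseconds. A condensed sketch of the two checks involved (DelaySchedulingSketch is a made-up container):

object DelaySchedulingSketch {
  // Same property name and default as the diff.
  val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong

  var lastPreferredLaunchTime = System.currentTimeMillis

  // While a preferred task launched recently, only host-local work is accepted.
  def localOnly(now: Long): Boolean = now - lastPreferredLaunchTime < LOCALITY_WAIT

  // A host counts as preferred if the task names it, or if the task has no preference at all.
  def isPreferredLocation(preferredLocations: Set[String], host: String): Boolean =
    preferredLocations.isEmpty || preferredLocations.contains(host)
}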

View file

@ -1,8 +1,8 @@
package spark.scheduler.cluster package spark.scheduler.cluster
/** /**
* Represents free resources available on a worker node. * Represents free resources available on an executor.
*/ */
private[spark] private[spark]
class WorkerOffer(val slaveId: String, val hostname: String, val cores: Int) { class WorkerOffer(val executorId: String, val hostname: String, val cores: Int) {
} }

View file

@ -20,7 +20,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
with Logging { with Logging {
var attemptId = new AtomicInteger(0) var attemptId = new AtomicInteger(0)
var threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory) var threadPool = Utils.newDaemonFixedThreadPool(threads)
val env = SparkEnv.get val env = SparkEnv.get
var listener: TaskSchedulerListener = null var listener: TaskSchedulerListener = null
@ -53,7 +53,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
} }
def runTask(task: Task[_], idInJob: Int, attemptId: Int) { def runTask(task: Task[_], idInJob: Int, attemptId: Int) {
logInfo("Running task " + idInJob) logInfo("Running " + task)
// Set the Spark execution environment for the worker thread // Set the Spark execution environment for the worker thread
SparkEnv.set(env) SparkEnv.set(env)
try { try {
@ -80,7 +80,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
val resultToReturn = ser.deserialize[Any](ser.serialize(result)) val resultToReturn = ser.deserialize[Any](ser.serialize(result))
val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]]( val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
ser.serialize(Accumulators.values)) ser.serialize(Accumulators.values))
logInfo("Finished task " + idInJob) logInfo("Finished " + task)
// If the threadpool has not already been shutdown, notify DAGScheduler // If the threadpool has not already been shutdown, notify DAGScheduler
if (!Thread.currentThread().isInterrupted) if (!Thread.currentThread().isInterrupted)
@ -116,16 +116,16 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
// Fetch missing dependencies // Fetch missing dependencies
for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp) logInfo("Fetching " + name + " with timestamp " + timestamp)
Utils.fetchFile(name, new File(".")) Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
currentFiles(name) = timestamp currentFiles(name) = timestamp
} }
for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp) logInfo("Fetching " + name + " with timestamp " + timestamp)
Utils.fetchFile(name, new File(".")) Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
currentJars(name) = timestamp currentJars(name) = timestamp
// Add it to our class loader // Add it to our class loader
val localName = name.split("/").last val localName = name.split("/").last
val url = new File(".", localName).toURI.toURL val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL
if (!classLoader.getURLs.contains(url)) { if (!classLoader.getURLs.contains(url)) {
logInfo("Adding " + url + " to class loader") logInfo("Adding " + url + " to class loader")
classLoader.addURL(url) classLoader.addURL(url)
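The change above only swaps the working directory to SparkFiles.getRootDirectory, but the surrounding mechanism is worth spelling out: a fetched jar is resolved inside that directory and its URL appended to the executor's class loader. A sketch, with a hypothetical MutableURLClassLoader standing in for the scheduler's own loader (addURL is protected on URLClassLoader, hence the subclass):

import java.io.File
import java.net.{URL, URLClassLoader}

class MutableURLClassLoader(urls: Array[URL], parent: ClassLoader)
  extends URLClassLoader(urls, parent) {
  override def addURL(url: URL) { super.addURL(url) }
}

object JarLoadingSketch {
  def addFetchedJar(loader: MutableURLClassLoader, rootDir: String, name: String) {
    val localName = name.split("/").last
    val url = new File(rootDir, localName).toURI.toURL
    if (!loader.getURLs.contains(url)) {
      loader.addURL(url)
    }
  }
}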

View file

@ -35,16 +35,6 @@ private[spark] class CoarseMesosSchedulerBackend(
val MAX_SLAVE_FAILURES = 2 // Blacklist a slave after this many failures val MAX_SLAVE_FAILURES = 2 // Blacklist a slave after this many failures
// Memory used by each executor (in megabytes)
val executorMemory = {
if (System.getenv("SPARK_MEM") != null) {
Utils.memoryStringToMb(System.getenv("SPARK_MEM"))
// TODO: Might need to add some extra memory for the non-heap parts of the JVM
} else {
512
}
}
// Lock used to wait for scheduler to be registered // Lock used to wait for scheduler to be registered
var isRegistered = false var isRegistered = false
val registeredLock = new Object() val registeredLock = new Object()
@ -64,13 +54,9 @@ private[spark] class CoarseMesosSchedulerBackend(
val taskIdToSlaveId = new HashMap[Int, String] val taskIdToSlaveId = new HashMap[Int, String]
val failuresBySlaveId = new HashMap[String, Int] // How many times tasks on each slave failed val failuresBySlaveId = new HashMap[String, Int] // How many times tasks on each slave failed
val sparkHome = sc.getSparkHome() match { val sparkHome = sc.getSparkHome().getOrElse(throw new SparkException(
case Some(path) => "Spark home is not set; set it through the spark.home system " +
path "property, the SPARK_HOME environment variable or the SparkContext constructor"))
case None =>
throw new SparkException("Spark home is not set; set it through the spark.home system " +
"property, the SPARK_HOME environment variable or the SparkContext constructor")
}
val extraCoresPerSlave = System.getProperty("spark.mesos.extra.cores", "0").toInt val extraCoresPerSlave = System.getProperty("spark.mesos.extra.cores", "0").toInt
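The sparkHome rewrite just above replaces a match with Option.getOrElse(throw ...), which type-checks because a throw expression has type Nothing. The same fail-fast idiom in isolation (a plain RuntimeException stands in for SparkException to keep the sketch self-contained):

object RequiredSettingSketch {
  def required(setting: Option[String], what: String): String =
    setting.getOrElse(throw new RuntimeException(
      what + " is not set; set it through a system property, an environment variable " +
      "or the SparkContext constructor"))

  def main(args: Array[String]) {
    println(required(Option(System.getenv("SPARK_HOME")), "Spark home"))
  }
}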
@ -108,11 +94,11 @@ private[spark] class CoarseMesosSchedulerBackend(
def createCommand(offer: Offer, numCores: Int): CommandInfo = { def createCommand(offer: Offer, numCores: Int): CommandInfo = {
val runScript = new File(sparkHome, "run").getCanonicalPath val runScript = new File(sparkHome, "run").getCanonicalPath
val masterUrl = "akka://spark@%s:%s/user/%s".format( val driverUrl = "akka://spark@%s:%s/user/%s".format(
System.getProperty("spark.master.host"), System.getProperty("spark.master.port"), System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"),
StandaloneSchedulerBackend.ACTOR_NAME) StandaloneSchedulerBackend.ACTOR_NAME)
val command = "\"%s\" spark.executor.StandaloneExecutorBackend %s %s %s %d".format( val command = "\"%s\" spark.executor.StandaloneExecutorBackend %s %s %s %d".format(
runScript, masterUrl, offer.getSlaveId.getValue, offer.getHostname, numCores) runScript, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores)
val environment = Environment.newBuilder() val environment = Environment.newBuilder()
sc.executorEnvs.foreach { case (key, value) => sc.executorEnvs.foreach { case (key, value) =>
environment.addVariables(Environment.Variable.newBuilder() environment.addVariables(Environment.Variable.newBuilder()
@ -184,7 +170,7 @@ private[spark] class CoarseMesosSchedulerBackend(
} }
/** Helper function to pull out a resource from a Mesos Resources protobuf */ /** Helper function to pull out a resource from a Mesos Resources protobuf */
def getResource(res: JList[Resource], name: String): Double = { private def getResource(res: JList[Resource], name: String): Double = {
for (r <- res if r.getName == name) { for (r <- res if r.getName == name) {
return r.getScalar.getValue return r.getScalar.getValue
} }
@ -193,7 +179,7 @@ private[spark] class CoarseMesosSchedulerBackend(
} }
/** Build a Mesos resource protobuf object */ /** Build a Mesos resource protobuf object */
def createResource(resourceName: String, quantity: Double): Protos.Resource = { private def createResource(resourceName: String, quantity: Double): Protos.Resource = {
Resource.newBuilder() Resource.newBuilder()
.setName(resourceName) .setName(resourceName)
.setType(Value.Type.SCALAR) .setType(Value.Type.SCALAR)
@ -202,7 +188,7 @@ private[spark] class CoarseMesosSchedulerBackend(
} }
/** Check whether a Mesos task state represents a finished task */ /** Check whether a Mesos task state represents a finished task */
def isFinished(state: MesosTaskState) = { private def isFinished(state: MesosTaskState) = {
state == MesosTaskState.TASK_FINISHED || state == MesosTaskState.TASK_FINISHED ||
state == MesosTaskState.TASK_FAILED || state == MesosTaskState.TASK_FAILED ||
state == MesosTaskState.TASK_KILLED || state == MesosTaskState.TASK_KILLED ||

View file

@ -29,16 +29,6 @@ private[spark] class MesosSchedulerBackend(
with MScheduler with MScheduler
with Logging { with Logging {
// Memory used by each executor (in megabytes)
val EXECUTOR_MEMORY = {
if (System.getenv("SPARK_MEM") != null) {
Utils.memoryStringToMb(System.getenv("SPARK_MEM"))
// TODO: Might need to add some extra memory for the non-heap parts of the JVM
} else {
512
}
}
// Lock used to wait for scheduler to be registered // Lock used to wait for scheduler to be registered
var isRegistered = false var isRegistered = false
val registeredLock = new Object() val registeredLock = new Object()
@ -51,7 +41,7 @@ private[spark] class MesosSchedulerBackend(
val taskIdToSlaveId = new HashMap[Long, String] val taskIdToSlaveId = new HashMap[Long, String]
// An ExecutorInfo for our tasks // An ExecutorInfo for our tasks
var executorInfo: ExecutorInfo = null var execArgs: Array[Byte] = null
override def start() { override def start() {
synchronized { synchronized {
@ -70,19 +60,14 @@ private[spark] class MesosSchedulerBackend(
} }
}.start() }.start()
executorInfo = createExecutorInfo()
waitForRegister() waitForRegister()
} }
} }
def createExecutorInfo(): ExecutorInfo = { def createExecutorInfo(execId: String): ExecutorInfo = {
val sparkHome = sc.getSparkHome() match { val sparkHome = sc.getSparkHome().getOrElse(throw new SparkException(
case Some(path) => "Spark home is not set; set it through the spark.home system " +
path "property, the SPARK_HOME environment variable or the SparkContext constructor"))
case None =>
throw new SparkException("Spark home is not set; set it through the spark.home system " +
"property, the SPARK_HOME environment variable or the SparkContext constructor")
}
val execScript = new File(sparkHome, "spark-executor").getCanonicalPath val execScript = new File(sparkHome, "spark-executor").getCanonicalPath
val environment = Environment.newBuilder() val environment = Environment.newBuilder()
sc.executorEnvs.foreach { case (key, value) => sc.executorEnvs.foreach { case (key, value) =>
@ -94,14 +79,14 @@ private[spark] class MesosSchedulerBackend(
val memory = Resource.newBuilder() val memory = Resource.newBuilder()
.setName("mem") .setName("mem")
.setType(Value.Type.SCALAR) .setType(Value.Type.SCALAR)
.setScalar(Value.Scalar.newBuilder().setValue(EXECUTOR_MEMORY).build()) .setScalar(Value.Scalar.newBuilder().setValue(executorMemory).build())
.build() .build()
val command = CommandInfo.newBuilder() val command = CommandInfo.newBuilder()
.setValue(execScript) .setValue(execScript)
.setEnvironment(environment) .setEnvironment(environment)
.build() .build()
ExecutorInfo.newBuilder() ExecutorInfo.newBuilder()
.setExecutorId(ExecutorID.newBuilder().setValue("default").build()) .setExecutorId(ExecutorID.newBuilder().setValue(execId).build())
.setCommand(command) .setCommand(command)
.setData(ByteString.copyFrom(createExecArg())) .setData(ByteString.copyFrom(createExecArg()))
.addResources(memory) .addResources(memory)
@ -113,17 +98,20 @@ private[spark] class MesosSchedulerBackend(
* containing all the spark.* system properties in the form of (String, String) pairs. * containing all the spark.* system properties in the form of (String, String) pairs.
*/ */
private def createExecArg(): Array[Byte] = { private def createExecArg(): Array[Byte] = {
val props = new HashMap[String, String] if (execArgs == null) {
val iterator = System.getProperties.entrySet.iterator val props = new HashMap[String, String]
while (iterator.hasNext) { val iterator = System.getProperties.entrySet.iterator
val entry = iterator.next while (iterator.hasNext) {
val (key, value) = (entry.getKey.toString, entry.getValue.toString) val entry = iterator.next
if (key.startsWith("spark.")) { val (key, value) = (entry.getKey.toString, entry.getValue.toString)
props(key) = value if (key.startsWith("spark.")) {
props(key) = value
}
} }
// Serialize the map as an array of (String, String) pairs
execArgs = Utils.serialize(props.toArray)
} }
// Serialize the map as an array of (String, String) pairs return execArgs
return Utils.serialize(props.toArray)
} }
override def offerRescinded(d: SchedulerDriver, o: OfferID) {} override def offerRescinded(d: SchedulerDriver, o: OfferID) {}
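The execArgs change above turns createExecArg into a memoised computation: the spark.* system properties are captured and serialised once, and the same byte array is reused for every executor launched afterwards. A standalone sketch using plain Java serialisation (the real code goes through spark.Utils.serialize):

import java.io.{ByteArrayOutputStream, ObjectOutputStream}
import scala.collection.JavaConverters._

object ExecArgSketch {
  private var execArgs: Array[Byte] = null

  def createExecArg(): Array[Byte] = {
    if (execArgs == null) {
      // Snapshot the spark.* properties as (key, value) pairs.
      val props = System.getProperties.stringPropertyNames.asScala.toSeq
        .filter(_.startsWith("spark."))
        .map(key => (key, System.getProperty(key)))
        .toArray
      val bytes = new ByteArrayOutputStream()
      val out = new ObjectOutputStream(bytes)
      out.writeObject(props)
      out.close()
      execArgs = bytes.toByteArray
    }
    execArgs
  }
}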
@ -163,7 +151,7 @@ private[spark] class MesosSchedulerBackend(
def enoughMemory(o: Offer) = { def enoughMemory(o: Offer) = {
val mem = getResource(o.getResourcesList, "mem") val mem = getResource(o.getResourcesList, "mem")
val slaveId = o.getSlaveId.getValue val slaveId = o.getSlaveId.getValue
mem >= EXECUTOR_MEMORY || slaveIdsWithExecutors.contains(slaveId) mem >= executorMemory || slaveIdsWithExecutors.contains(slaveId)
} }
for ((offer, index) <- offers.zipWithIndex if enoughMemory(offer)) { for ((offer, index) <- offers.zipWithIndex if enoughMemory(offer)) {
@ -220,7 +208,7 @@ private[spark] class MesosSchedulerBackend(
return MesosTaskInfo.newBuilder() return MesosTaskInfo.newBuilder()
.setTaskId(taskId) .setTaskId(taskId)
.setSlaveId(SlaveID.newBuilder().setValue(slaveId).build()) .setSlaveId(SlaveID.newBuilder().setValue(slaveId).build())
.setExecutor(executorInfo) .setExecutor(createExecutorInfo(slaveId))
.setName(task.name) .setName(task.name)
.addResources(cpuResource) .addResources(cpuResource)
.setData(ByteString.copyFrom(task.serializedTask)) .setData(ByteString.copyFrom(task.serializedTask))
@ -272,7 +260,7 @@ private[spark] class MesosSchedulerBackend(
synchronized { synchronized {
slaveIdsWithExecutors -= slaveId.getValue slaveIdsWithExecutors -= slaveId.getValue
} }
scheduler.slaveLost(slaveId.getValue, reason) scheduler.executorLost(slaveId.getValue, reason)
} }
override def slaveLost(d: SchedulerDriver, slaveId: SlaveID) { override def slaveLost(d: SchedulerDriver, slaveId: SlaveID) {

View file

@ -16,7 +16,7 @@ import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}
import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
import spark.{CacheTracker, Logging, SizeEstimator, SparkEnv, SparkException, Utils} import spark.{Logging, SizeEstimator, SparkEnv, SparkException, Utils}
import spark.network._ import spark.network._
import spark.serializer.Serializer import spark.serializer.Serializer
import spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStampedHashMap} import spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStampedHashMap}
@ -30,6 +30,7 @@ extends Exception(message)
private[spark] private[spark]
class BlockManager( class BlockManager(
executorId: String,
actorSystem: ActorSystem, actorSystem: ActorSystem,
val master: BlockManagerMaster, val master: BlockManagerMaster,
val serializer: Serializer, val serializer: Serializer,
@ -68,11 +69,8 @@ class BlockManager(
val connectionManager = new ConnectionManager(0) val connectionManager = new ConnectionManager(0)
implicit val futureExecContext = connectionManager.futureExecContext implicit val futureExecContext = connectionManager.futureExecContext
val connectionManagerId = connectionManager.id val blockManagerId = BlockManagerId(
val blockManagerId = new BlockManagerId(connectionManagerId.host, connectionManagerId.port) executorId, connectionManager.id.host, connectionManager.id.port)
// TODO: This will be removed after cacheTracker is removed from the code base.
var cacheTracker: CacheTracker = null
// Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory // Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory
// for receiving shuffle outputs) // for receiving shuffle outputs)
@ -93,7 +91,10 @@ class BlockManager(
val slaveActor = master.actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)), val slaveActor = master.actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)),
name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next) name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next)
@volatile private var shuttingDown = false // Pending reregistration action being executed asynchronously or null if none
// is pending. Accesses should synchronize on asyncReregisterLock.
var asyncReregisterTask: Future[Unit] = null
val asyncReregisterLock = new Object
private def heartBeat() { private def heartBeat() {
if (!master.sendHeartBeat(blockManagerId)) { if (!master.sendHeartBeat(blockManagerId)) {
@ -109,8 +110,9 @@ class BlockManager(
/** /**
* Construct a BlockManager with a memory limit set based on system properties. * Construct a BlockManager with a memory limit set based on system properties.
*/ */
def this(actorSystem: ActorSystem, master: BlockManagerMaster, serializer: Serializer) = { def this(execId: String, actorSystem: ActorSystem, master: BlockManagerMaster,
this(actorSystem, master, serializer, BlockManager.getMaxMemoryFromSystemProperties) serializer: Serializer) = {
this(execId, actorSystem, master, serializer, BlockManager.getMaxMemoryFromSystemProperties)
} }
/** /**
@ -150,6 +152,8 @@ class BlockManager(
/** /**
* Reregister with the master and report all blocks to it. This will be called by the heart beat * Reregister with the master and report all blocks to it. This will be called by the heart beat
* thread if our heartbeat to the block manager indicates that we were not registered. * thread if our heartbeat to the block manager indicates that we were not registered.
*
* Note that this method must be called without any BlockInfo locks held.
*/ */
def reregister() { def reregister() {
// TODO: We might need to rate limit reregistering. // TODO: We might need to rate limit reregistering.
@ -158,6 +162,32 @@ class BlockManager(
reportAllBlocks() reportAllBlocks()
} }
/**
* Reregister with the master sometime soon.
*/
def asyncReregister() {
asyncReregisterLock.synchronized {
if (asyncReregisterTask == null) {
asyncReregisterTask = Future[Unit] {
reregister()
asyncReregisterLock.synchronized {
asyncReregisterTask = null
}
}
}
}
}
/**
* For testing. Wait for any pending asynchronous reregistration; otherwise, do nothing.
*/
def waitForAsyncReregister() {
val task = asyncReregisterTask
if (task != null) {
Await.ready(task, Duration.Inf)
}
}
/** /**
* Get storage level of local block. If no info exists for the block, then returns null. * Get storage level of local block. If no info exists for the block, then returns null.
*/ */
@ -173,7 +203,7 @@ class BlockManager(
if (needReregister) { if (needReregister) {
logInfo("Got told to reregister updating block " + blockId) logInfo("Got told to reregister updating block " + blockId)
// Reregistering will report our new block for free. // Reregistering will report our new block for free.
reregister() asyncReregister()
} }
logDebug("Told master about block " + blockId) logDebug("Told master about block " + blockId)
} }
@ -191,7 +221,7 @@ class BlockManager(
case level => case level =>
val inMem = level.useMemory && memoryStore.contains(blockId) val inMem = level.useMemory && memoryStore.contains(blockId)
val onDisk = level.useDisk && diskStore.contains(blockId) val onDisk = level.useDisk && diskStore.contains(blockId)
val storageLevel = new StorageLevel(onDisk, inMem, level.deserialized, level.replication) val storageLevel = StorageLevel(onDisk, inMem, level.deserialized, level.replication)
val memSize = if (inMem) memoryStore.getSize(blockId) else 0L val memSize = if (inMem) memoryStore.getSize(blockId) else 0L
val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L
(storageLevel, memSize, diskSize, info.tellMaster) (storageLevel, memSize, diskSize, info.tellMaster)
@ -213,7 +243,7 @@ class BlockManager(
val startTimeMs = System.currentTimeMillis val startTimeMs = System.currentTimeMillis
var managers = master.getLocations(blockId) var managers = master.getLocations(blockId)
val locations = managers.map(_.ip) val locations = managers.map(_.ip)
logDebug("Get block locations in " + Utils.getUsedTimeMs(startTimeMs)) logDebug("Got block locations in " + Utils.getUsedTimeMs(startTimeMs))
return locations return locations
} }
@ -223,7 +253,7 @@ class BlockManager(
def getLocations(blockIds: Array[String]): Array[Seq[String]] = { def getLocations(blockIds: Array[String]): Array[Seq[String]] = {
val startTimeMs = System.currentTimeMillis val startTimeMs = System.currentTimeMillis
val locations = master.getLocations(blockIds).map(_.map(_.ip).toSeq).toArray val locations = master.getLocations(blockIds).map(_.map(_.ip).toSeq).toArray
logDebug("Get multiple block location in " + Utils.getUsedTimeMs(startTimeMs)) logDebug("Got multiple block location in " + Utils.getUsedTimeMs(startTimeMs))
return locations return locations
} }
@ -615,7 +645,7 @@ class BlockManager(
var size = 0L var size = 0L
myInfo.synchronized { myInfo.synchronized {
logDebug("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) logTrace("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
+ " to get into synchronized block") + " to get into synchronized block")
if (level.useMemory) { if (level.useMemory) {
@ -647,8 +677,10 @@ class BlockManager(
} }
logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs)) logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs))
// Replicate block if required // Replicate block if required
if (level.replication > 1) { if (level.replication > 1) {
val remoteStartTime = System.currentTimeMillis
// Serialize the block if not already done // Serialize the block if not already done
if (bytesAfterPut == null) { if (bytesAfterPut == null) {
if (valuesAfterPut == null) { if (valuesAfterPut == null) {
@ -658,16 +690,10 @@ class BlockManager(
bytesAfterPut = dataSerialize(blockId, valuesAfterPut) bytesAfterPut = dataSerialize(blockId, valuesAfterPut)
} }
replicate(blockId, bytesAfterPut, level) replicate(blockId, bytesAfterPut, level)
logDebug("Put block " + blockId + " remotely took " + Utils.getUsedTimeMs(remoteStartTime))
} }
BlockManager.dispose(bytesAfterPut) BlockManager.dispose(bytesAfterPut)
// TODO: This code will be removed when CacheTracker is gone.
if (blockId.startsWith("rdd")) {
notifyCacheTracker(blockId)
}
logDebug("Put block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs))
return size return size
} }
@ -733,11 +759,6 @@ class BlockManager(
} }
} }
// TODO: This code will be removed when CacheTracker is gone.
if (blockId.startsWith("rdd")) {
notifyCacheTracker(blockId)
}
// If replication had started, then wait for it to finish // If replication had started, then wait for it to finish
if (level.replication > 1) { if (level.replication > 1) {
if (replicationFuture == null) { if (replicationFuture == null) {
@ -760,8 +781,7 @@ class BlockManager(
*/ */
var cachedPeers: Seq[BlockManagerId] = null var cachedPeers: Seq[BlockManagerId] = null
private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) { private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) {
val tLevel: StorageLevel = val tLevel = StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1)
new StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1)
if (cachedPeers == null) { if (cachedPeers == null) {
cachedPeers = master.getPeers(blockManagerId, level.replication - 1) cachedPeers = master.getPeers(blockManagerId, level.replication - 1)
} }
@ -780,16 +800,6 @@ class BlockManager(
} }
} }
// TODO: This code will be removed when CacheTracker is gone.
private def notifyCacheTracker(key: String) {
if (cacheTracker != null) {
val rddInfo = key.split("_")
val rddId: Int = rddInfo(1).toInt
val partition: Int = rddInfo(2).toInt
cacheTracker.notifyFromBlockManager(spark.AddedToCache(rddId, partition, host))
}
}
/** /**
* Read a block consisting of a single object. * Read a block consisting of a single object.
*/ */
@ -940,6 +950,7 @@ class BlockManager(
blockInfo.clear() blockInfo.clear()
memoryStore.clear() memoryStore.clear()
diskStore.clear() diskStore.clear()
metadataCleaner.cancel()
logInfo("BlockManager stopped") logInfo("BlockManager stopped")
} }
} }
@ -968,7 +979,7 @@ object BlockManager extends Logging {
*/ */
def dispose(buffer: ByteBuffer) { def dispose(buffer: ByteBuffer) {
if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) { if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) {
logDebug("Unmapping " + buffer) logTrace("Unmapping " + buffer)
if (buffer.asInstanceOf[DirectBuffer].cleaner() != null) { if (buffer.asInstanceOf[DirectBuffer].cleaner() != null) {
buffer.asInstanceOf[DirectBuffer].cleaner().clean() buffer.asInstanceOf[DirectBuffer].cleaner().clean()
} }
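The asyncReregister/waitForAsyncReregister pair added earlier in this file implements an "at most one pending re-registration" pattern: a lock guards the slot holding the in-flight future, and the future clears that slot itself once the work is done. A generic sketch (AsyncOnce is a hypothetical name; scala.concurrent keeps the sketch self-contained, whereas the real code runs on the block manager's own execution context):

import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration
import ExecutionContext.Implicits.global

class AsyncOnce(work: () => Unit) {
  private val lock = new Object
  @volatile private var pending: Future[Unit] = null

  // Start the work unless a run is already in flight.
  def trigger() {
    lock.synchronized {
      if (pending == null) {
        pending = Future {
          work()
          lock.synchronized { pending = null }
        }
      }
    }
  }

  // For tests: block until any in-flight run has completed.
  def waitForCompletion() {
    val task = pending
    if (task != null) {
      Await.ready(task, Duration.Inf)
    }
  }
}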

View file

@ -3,38 +3,67 @@ package spark.storage
import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentHashMap
/**
* This class represents a unique identifier for a BlockManager.
* The first two constructors of this class are made private to ensure that
* BlockManagerId objects can be created only using the factory method in
* [[spark.storage.BlockManager$]]. This allows de-duplication of ID objects.
* Also, constructor parameters are private to ensure that parameters cannot
* be modified from outside this class.
*/
private[spark] class BlockManagerId private (
private var executorId_ : String,
private var ip_ : String,
private var port_ : Int
) extends Externalizable {
private[spark] class BlockManagerId(var ip: String, var port: Int) extends Externalizable { private def this() = this(null, null, 0) // For deserialization only
def this() = this(null, 0) // For deserialization only
def this(in: ObjectInput) = this(in.readUTF(), in.readInt()) def executorId: String = executorId_
def ip: String = ip_
def port: Int = port_
override def writeExternal(out: ObjectOutput) { override def writeExternal(out: ObjectOutput) {
out.writeUTF(ip) out.writeUTF(executorId_)
out.writeInt(port) out.writeUTF(ip_)
out.writeInt(port_)
} }
override def readExternal(in: ObjectInput) { override def readExternal(in: ObjectInput) {
ip = in.readUTF() executorId_ = in.readUTF()
port = in.readInt() ip_ = in.readUTF()
port_ = in.readInt()
} }
@throws(classOf[IOException]) @throws(classOf[IOException])
private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this) private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this)
override def toString = "BlockManagerId(" + ip + ", " + port + ")" override def toString = "BlockManagerId(%s, %s, %d)".format(executorId, ip, port)
override def hashCode = ip.hashCode * 41 + port override def hashCode: Int = (executorId.hashCode * 41 + ip.hashCode) * 41 + port
override def equals(that: Any) = that match { override def equals(that: Any) = that match {
case id: BlockManagerId => port == id.port && ip == id.ip case id: BlockManagerId =>
case _ => false executorId == id.executorId && port == id.port && ip == id.ip
case _ =>
false
} }
} }
private[spark] object BlockManagerId { private[spark] object BlockManagerId {
def apply(execId: String, ip: String, port: Int) =
getCachedBlockManagerId(new BlockManagerId(execId, ip, port))
def apply(in: ObjectInput) = {
val obj = new BlockManagerId()
obj.readExternal(in)
getCachedBlockManagerId(obj)
}
val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]() val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]()
def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = { def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = {
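The constructor-hiding in BlockManagerId above is an interning pattern: every id is produced by the companion's factory and looked up in a shared cache, so equal ids end up as the same object, and readResolve routes deserialised copies through that same cache. A minimal standalone version of the pattern (InternedId is a made-up name):

import java.util.concurrent.ConcurrentHashMap

class InternedId private (val executorId: String, val ip: String, val port: Int)
  extends Serializable {
  override def hashCode: Int = (executorId.hashCode * 41 + ip.hashCode) * 41 + port
  override def equals(that: Any): Boolean = that match {
    case o: InternedId => executorId == o.executorId && ip == o.ip && port == o.port
    case _ => false
  }
  override def toString = "InternedId(%s, %s, %d)".format(executorId, ip, port)
}

object InternedId {
  private val cache = new ConcurrentHashMap[InternedId, InternedId]()

  // Return the cached instance if an equal id already exists, otherwise cache this one.
  def apply(executorId: String, ip: String, port: Int): InternedId = {
    val id = new InternedId(executorId, ip, port)
    val existing = cache.putIfAbsent(id, id)
    if (existing == null) id else existing
  }
}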

View file

@ -1,6 +1,10 @@
package spark.storage package spark.storage
import scala.collection.mutable.ArrayBuffer import java.io._
import java.util.{HashMap => JHashMap}
import scala.collection.JavaConverters._
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
import scala.util.Random import scala.util.Random
import akka.actor.{Actor, ActorRef, ActorSystem, Props} import akka.actor.{Actor, ActorRef, ActorSystem, Props}
@ -11,52 +15,49 @@ import akka.util.duration._
import spark.{Logging, SparkException, Utils} import spark.{Logging, SparkException, Utils}
private[spark] class BlockManagerMaster( private[spark] class BlockManagerMaster(
val actorSystem: ActorSystem, val actorSystem: ActorSystem,
isMaster: Boolean, isDriver: Boolean,
isLocal: Boolean, isLocal: Boolean,
masterIp: String, driverIp: String,
masterPort: Int) driverPort: Int)
extends Logging { extends Logging {
val AKKA_RETRY_ATTEMPS: Int = System.getProperty("spark.akka.num.retries", "3").toInt val AKKA_RETRY_ATTEMPTS: Int = System.getProperty("spark.akka.num.retries", "3").toInt
val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "3000").toInt val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "3000").toInt
val MASTER_AKKA_ACTOR_NAME = "BlockMasterManager" val DRIVER_AKKA_ACTOR_NAME = "BlockMasterManager"
val SLAVE_AKKA_ACTOR_NAME = "BlockSlaveManager"
val DEFAULT_MANAGER_IP: String = Utils.localHostName()
val timeout = 10.seconds val timeout = 10.seconds
var masterActor: ActorRef = { var driverActor: ActorRef = {
if (isMaster) { if (isDriver) {
val masterActor = actorSystem.actorOf(Props(new BlockManagerMasterActor(isLocal)), val driverActor = actorSystem.actorOf(Props(new BlockManagerMasterActor(isLocal)),
name = MASTER_AKKA_ACTOR_NAME) name = DRIVER_AKKA_ACTOR_NAME)
logInfo("Registered BlockManagerMaster Actor") logInfo("Registered BlockManagerMaster Actor")
masterActor driverActor
} else { } else {
val url = "akka://spark@%s:%s/user/%s".format(masterIp, masterPort, MASTER_AKKA_ACTOR_NAME) val url = "akka://spark@%s:%s/user/%s".format(driverIp, driverPort, DRIVER_AKKA_ACTOR_NAME)
logInfo("Connecting to BlockManagerMaster: " + url) logInfo("Connecting to BlockManagerMaster: " + url)
actorSystem.actorFor(url) actorSystem.actorFor(url)
} }
} }
/** Remove a dead host from the master actor. This is only called on the master side. */ /** Remove a dead executor from the driver actor. This is only called on the driver side. */
def notifyADeadHost(host: String) { def removeExecutor(execId: String) {
tell(RemoveHost(host)) tell(RemoveExecutor(execId))
logInfo("Removed " + host + " successfully in notifyADeadHost") logInfo("Removed " + execId + " successfully in removeExecutor")
} }
/** /**
* Send the master actor a heart beat from the slave. Returns true if everything works out, * Send the driver actor a heart beat from the slave. Returns true if everything works out,
* false if the master does not know about the given block manager, which means the block * false if the driver does not know about the given block manager, which means the block
* manager should re-register. * manager should re-register.
*/ */
def sendHeartBeat(blockManagerId: BlockManagerId): Boolean = { def sendHeartBeat(blockManagerId: BlockManagerId): Boolean = {
askMasterWithRetry[Boolean](HeartBeat(blockManagerId)) askDriverWithReply[Boolean](HeartBeat(blockManagerId))
} }
/** Register the BlockManager's id with the master. */ /** Register the BlockManager's id with the driver. */
def registerBlockManager( def registerBlockManager(
blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
logInfo("Trying to register BlockManager") logInfo("Trying to register BlockManager")
@ -70,25 +71,25 @@ private[spark] class BlockManagerMaster(
storageLevel: StorageLevel, storageLevel: StorageLevel,
memSize: Long, memSize: Long,
diskSize: Long): Boolean = { diskSize: Long): Boolean = {
val res = askMasterWithRetry[Boolean]( val res = askDriverWithReply[Boolean](
UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize)) UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize))
logInfo("Updated info of block " + blockId) logInfo("Updated info of block " + blockId)
res res
} }
/** Get locations of the blockId from the master */ /** Get locations of the blockId from the driver */
def getLocations(blockId: String): Seq[BlockManagerId] = { def getLocations(blockId: String): Seq[BlockManagerId] = {
askMasterWithRetry[Seq[BlockManagerId]](GetLocations(blockId)) askDriverWithReply[Seq[BlockManagerId]](GetLocations(blockId))
} }
/** Get locations of multiple blockIds from the master */ /** Get locations of multiple blockIds from the driver */
def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = { def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = {
askMasterWithRetry[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) askDriverWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds))
} }
/** Get ids of other nodes in the cluster from the master */ /** Get ids of other nodes in the cluster from the driver */
def getPeers(blockManagerId: BlockManagerId, numPeers: Int): Seq[BlockManagerId] = { def getPeers(blockManagerId: BlockManagerId, numPeers: Int): Seq[BlockManagerId] = {
val result = askMasterWithRetry[Seq[BlockManagerId]](GetPeers(blockManagerId, numPeers)) val result = askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId, numPeers))
if (result.length != numPeers) { if (result.length != numPeers) {
throw new SparkException( throw new SparkException(
"Error getting peers, only got " + result.size + " instead of " + numPeers) "Error getting peers, only got " + result.size + " instead of " + numPeers)
@ -98,10 +99,10 @@ private[spark] class BlockManagerMaster(
/** /**
* Remove a block from the slaves that have it. This can only be used to remove * Remove a block from the slaves that have it. This can only be used to remove
* blocks that the master knows about. * blocks that the driver knows about.
*/ */
def removeBlock(blockId: String) { def removeBlock(blockId: String) {
askMasterWithRetry(RemoveBlock(blockId)) askDriverWithReply(RemoveBlock(blockId))
} }
/** /**
@ -111,41 +112,45 @@ private[spark] class BlockManagerMaster(
* amount of remaining memory. * amount of remaining memory.
*/ */
def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = { def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = {
askMasterWithRetry[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) askDriverWithReply[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus)
} }
/** Stop the master actor, called only on the Spark master node */ def getStorageStatus: Array[StorageStatus] = {
askDriverWithReply[ArrayBuffer[StorageStatus]](GetStorageStatus).toArray
}
/** Stop the driver actor, called only on the Spark driver node */
def stop() { def stop() {
if (masterActor != null) { if (driverActor != null) {
tell(StopBlockManagerMaster) tell(StopBlockManagerMaster)
masterActor = null driverActor = null
logInfo("BlockManagerMaster stopped") logInfo("BlockManagerMaster stopped")
} }
} }
/** Send a one-way message to the master actor, to which we expect it to reply with true. */ /** Send a one-way message to the master actor, to which we expect it to reply with true. */
private def tell(message: Any) { private def tell(message: Any) {
if (!askMasterWithRetry[Boolean](message)) { if (!askDriverWithReply[Boolean](message)) {
throw new SparkException("BlockManagerMasterActor returned false, expected true.") throw new SparkException("BlockManagerMasterActor returned false, expected true.")
} }
} }
/** /**
* Send a message to the master actor and get its result within a default timeout, or * Send a message to the driver actor and get its result within a default timeout, or
* throw a SparkException if this fails. * throw a SparkException if this fails.
*/ */
private def askMasterWithRetry[T](message: Any): T = { private def askDriverWithReply[T](message: Any): T = {
// TODO: Consider removing multiple attempts // TODO: Consider removing multiple attempts
if (masterActor == null) { if (driverActor == null) {
throw new SparkException("Error sending message to BlockManager as masterActor is null " + throw new SparkException("Error sending message to BlockManager as driverActor is null " +
"[message = " + message + "]") "[message = " + message + "]")
} }
var attempts = 0 var attempts = 0
var lastException: Exception = null var lastException: Exception = null
while (attempts < AKKA_RETRY_ATTEMPS) { while (attempts < AKKA_RETRY_ATTEMPTS) {
attempts += 1 attempts += 1
try { try {
val future = masterActor.ask(message)(timeout) val future = driverActor.ask(message)(timeout)
val result = Await.result(future, timeout) val result = Await.result(future, timeout)
if (result == null) { if (result == null) {
throw new Exception("BlockManagerMaster returned null") throw new Exception("BlockManagerMaster returned null")

View file

@ -23,9 +23,8 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
private val blockManagerInfo = private val blockManagerInfo =
new HashMap[BlockManagerId, BlockManagerMasterActor.BlockManagerInfo] new HashMap[BlockManagerId, BlockManagerMasterActor.BlockManagerInfo]
// Mapping from host name to block manager id. We allow multiple block managers // Mapping from executor ID to block manager ID.
// on the same host name (ip). private val blockManagerIdByExecutor = new HashMap[String, BlockManagerId]
private val blockManagerIdByHost = new HashMap[String, ArrayBuffer[BlockManagerId]]
// Mapping from block id to the set of block managers that have the block. // Mapping from block id to the set of block managers that have the block.
private val blockLocations = new JHashMap[String, Pair[Int, HashSet[BlockManagerId]]] private val blockLocations = new JHashMap[String, Pair[Int, HashSet[BlockManagerId]]]
@ -68,11 +67,14 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
case GetMemoryStatus => case GetMemoryStatus =>
getMemoryStatus getMemoryStatus
case GetStorageStatus =>
getStorageStatus
case RemoveBlock(blockId) => case RemoveBlock(blockId) =>
removeBlock(blockId) removeBlock(blockId)
case RemoveHost(host) => case RemoveExecutor(execId) =>
removeHost(host) removeExecutor(execId)
sender ! true sender ! true
case StopBlockManagerMaster => case StopBlockManagerMaster =>
@ -96,16 +98,12 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
def removeBlockManager(blockManagerId: BlockManagerId) { def removeBlockManager(blockManagerId: BlockManagerId) {
val info = blockManagerInfo(blockManagerId) val info = blockManagerInfo(blockManagerId)
// Remove the block manager from blockManagerIdByHost. If the list of block // Remove the block manager from blockManagerIdByExecutor.
// managers belonging to the IP is empty, remove the entry from the hash map. blockManagerIdByExecutor -= blockManagerId.executorId
blockManagerIdByHost.get(blockManagerId.ip).foreach { managers: ArrayBuffer[BlockManagerId] =>
managers -= blockManagerId
if (managers.size == 0) blockManagerIdByHost.remove(blockManagerId.ip)
}
// Remove it from blockManagerInfo and remove all the blocks. // Remove it from blockManagerInfo and remove all the blocks.
blockManagerInfo.remove(blockManagerId) blockManagerInfo.remove(blockManagerId)
var iterator = info.blocks.keySet.iterator val iterator = info.blocks.keySet.iterator
while (iterator.hasNext) { while (iterator.hasNext) {
val blockId = iterator.next val blockId = iterator.next
val locations = blockLocations.get(blockId)._2 val locations = blockLocations.get(blockId)._2
@ -117,7 +115,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
} }
def expireDeadHosts() { def expireDeadHosts() {
logDebug("Checking for hosts with no recent heart beats in BlockManagerMaster.") logTrace("Checking for hosts with no recent heart beats in BlockManagerMaster.")
val now = System.currentTimeMillis() val now = System.currentTimeMillis()
val minSeenTime = now - slaveTimeout val minSeenTime = now - slaveTimeout
val toRemove = new HashSet[BlockManagerId] val toRemove = new HashSet[BlockManagerId]
@ -130,17 +128,15 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
toRemove.foreach(removeBlockManager) toRemove.foreach(removeBlockManager)
} }
def removeHost(host: String) { def removeExecutor(execId: String) {
logInfo("Trying to remove the host: " + host + " from BlockManagerMaster.") logInfo("Trying to remove executor " + execId + " from BlockManagerMaster.")
logInfo("Previous hosts: " + blockManagerInfo.keySet.toSeq) blockManagerIdByExecutor.get(execId).foreach(removeBlockManager)
blockManagerIdByHost.get(host).foreach(_.foreach(removeBlockManager))
logInfo("Current hosts: " + blockManagerInfo.keySet.toSeq)
sender ! true sender ! true
} }
def heartBeat(blockManagerId: BlockManagerId) { def heartBeat(blockManagerId: BlockManagerId) {
if (!blockManagerInfo.contains(blockManagerId)) { if (!blockManagerInfo.contains(blockManagerId)) {
if (blockManagerId.ip == Utils.localHostName() && !isLocal) { if (blockManagerId.executorId == "<driver>" && !isLocal) {
sender ! true sender ! true
} else { } else {
sender ! false sender ! false
@ -177,24 +173,28 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
sender ! res sender ! res
} }
private def register(blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { private def getStorageStatus() {
val startTimeMs = System.currentTimeMillis() val res = blockManagerInfo.map { case(blockManagerId, info) =>
val tmp = " " + blockManagerId + " " import collection.JavaConverters._
StorageStatus(blockManagerId, info.maxMem, info.blocks.asScala.toMap)
}
sender ! res
}
if (blockManagerId.ip == Utils.localHostName() && !isLocal) { private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
logInfo("Got Register Msg from master node, don't register it") if (id.executorId == "<driver>" && !isLocal) {
} else { // Got a register message from the master node; don't register it
blockManagerIdByHost.get(blockManagerId.ip) match { } else if (!blockManagerInfo.contains(id)) {
case Some(managers) => blockManagerIdByExecutor.get(id.executorId) match {
// A block manager of the same host name already exists. case Some(manager) =>
logInfo("Got another registration for host " + blockManagerId) // A block manager of the same host name already exists
managers += blockManagerId logError("Got two different block manager registrations on " + id.executorId)
System.exit(1)
case None => case None =>
blockManagerIdByHost += (blockManagerId.ip -> ArrayBuffer(blockManagerId)) blockManagerIdByExecutor(id.executorId) = id
} }
blockManagerInfo(id) = new BlockManagerMasterActor.BlockManagerInfo(
blockManagerInfo += (blockManagerId -> new BlockManagerMasterActor.BlockManagerInfo( id, System.currentTimeMillis(), maxMemSize, slaveActor)
blockManagerId, System.currentTimeMillis(), maxMemSize, slaveActor))
} }
sender ! true sender ! true
} }
@ -206,11 +206,8 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
memSize: Long, memSize: Long,
diskSize: Long) { diskSize: Long) {
val startTimeMs = System.currentTimeMillis()
val tmp = " " + blockManagerId + " " + blockId + " "
if (!blockManagerInfo.contains(blockManagerId)) { if (!blockManagerInfo.contains(blockManagerId)) {
if (blockManagerId.ip == Utils.localHostName() && !isLocal) { if (blockManagerId.executorId == "<driver>" && !isLocal) {
// We intentionally do not register the master (except in local mode), // We intentionally do not register the master (except in local mode),
// so we should not indicate failure. // so we should not indicate failure.
sender ! true sender ! true
@ -342,8 +339,8 @@ object BlockManagerMasterActor {
_lastSeenMs = System.currentTimeMillis() _lastSeenMs = System.currentTimeMillis()
} }
def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long, diskSize: Long) def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long,
: Unit = synchronized { diskSize: Long) {
updateLastSeenMs() updateLastSeenMs()

View file

@ -54,11 +54,9 @@ class UpdateBlockInfo(
} }
override def readExternal(in: ObjectInput) { override def readExternal(in: ObjectInput) {
blockManagerId = new BlockManagerId() blockManagerId = BlockManagerId(in)
blockManagerId.readExternal(in)
blockId = in.readUTF() blockId = in.readUTF()
storageLevel = new StorageLevel() storageLevel = StorageLevel(in)
storageLevel.readExternal(in)
memSize = in.readInt() memSize = in.readInt()
diskSize = in.readInt() diskSize = in.readInt()
} }
@ -90,7 +88,7 @@ private[spark]
case class GetPeers(blockManagerId: BlockManagerId, size: Int) extends ToBlockManagerMaster case class GetPeers(blockManagerId: BlockManagerId, size: Int) extends ToBlockManagerMaster
private[spark] private[spark]
case class RemoveHost(host: String) extends ToBlockManagerMaster case class RemoveExecutor(execId: String) extends ToBlockManagerMaster
private[spark] private[spark]
case object StopBlockManagerMaster extends ToBlockManagerMaster case object StopBlockManagerMaster extends ToBlockManagerMaster
@ -100,3 +98,6 @@ case object GetMemoryStatus extends ToBlockManagerMaster
private[spark] private[spark]
case object ExpireDeadHosts extends ToBlockManagerMaster case object ExpireDeadHosts extends ToBlockManagerMaster
private[spark]
case object GetStorageStatus extends ToBlockManagerMaster

View file

@ -0,0 +1,76 @@
package spark.storage
import akka.actor.{ActorRef, ActorSystem}
import akka.util.Timeout
import akka.util.duration._
import cc.spray.typeconversion.TwirlSupport._
import cc.spray.Directives
import spark.{Logging, SparkContext}
import spark.util.AkkaUtils
import spark.Utils
/**
* Web UI server for the BlockManager inside each SparkContext.
*/
private[spark]
class BlockManagerUI(val actorSystem: ActorSystem, blockManagerMaster: ActorRef, sc: SparkContext)
extends Directives with Logging {
val STATIC_RESOURCE_DIR = "spark/deploy/static"
implicit val timeout = Timeout(10 seconds)
/** Start a HTTP server to run the Web interface */
def start() {
try {
val port = if (System.getProperty("spark.ui.port") != null) {
System.getProperty("spark.ui.port").toInt
} else {
// TODO: Unfortunately, it's not possible to pass port 0 to spray and figure out which
// random port it bound to, so we have to try to find a local one by creating a socket.
Utils.findFreePort()
}
AkkaUtils.startSprayServer(actorSystem, "0.0.0.0", port, handler, "BlockManagerHTTPServer")
logInfo("Started BlockManager web UI at http://%s:%d".format(Utils.localHostName(), port))
} catch {
case e: Exception =>
logError("Failed to create BlockManager WebUI", e)
System.exit(1)
}
}
val handler = {
get {
path("") {
completeWith {
// Request the current storage status from the Master
val storageStatusList = sc.getExecutorStorageStatus
// Calculate macro-level statistics
val maxMem = storageStatusList.map(_.maxMem).reduce(_+_)
val remainingMem = storageStatusList.map(_.memRemaining).reduce(_+_)
val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize))
.reduceOption(_+_).getOrElse(0L)
val rdds = StorageUtils.rddInfoFromStorageStatus(storageStatusList, sc)
spark.storage.html.index.
render(maxMem, remainingMem, diskSpaceUsed, rdds, storageStatusList)
}
} ~
path("rdd") {
parameter("id") { id =>
completeWith {
val prefix = "rdd_" + id.toString
val storageStatusList = sc.getExecutorStorageStatus
val filteredStorageStatusList = StorageUtils.
filterStorageStatusByPrefix(storageStatusList, prefix)
val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head
spark.storage.html.rdd.render(rddInfo, filteredStorageStatusList)
}
}
} ~
pathPrefix("static") {
getFromResourceDirectory(STATIC_RESOURCE_DIR)
}
}
}
}

View file

@ -64,7 +64,7 @@ private[spark] class BlockMessage() {
val booleanInt = buffer.getInt() val booleanInt = buffer.getInt()
val replication = buffer.getInt() val replication = buffer.getInt()
level = new StorageLevel(booleanInt, replication) level = StorageLevel(booleanInt, replication)
val dataLength = buffer.getInt() val dataLength = buffer.getInt()
data = ByteBuffer.allocate(dataLength) data = ByteBuffer.allocate(dataLength)

View file

@ -7,25 +7,30 @@ import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
* whether to drop the RDD to disk if it falls out of memory, whether to keep the data in memory * whether to drop the RDD to disk if it falls out of memory, whether to keep the data in memory
* in a serialized format, and whether to replicate the RDD partitions on multiple nodes. * in a serialized format, and whether to replicate the RDD partitions on multiple nodes.
* The [[spark.storage.StorageLevel$]] singleton object contains some static constants for * The [[spark.storage.StorageLevel$]] singleton object contains some static constants for
* commonly useful storage levels. * commonly useful storage levels. To create your own storage level object, use the factory method
* of the singleton object (`StorageLevel(...)`).
*/ */
class StorageLevel( class StorageLevel private(
var useDisk: Boolean, private var useDisk_ : Boolean,
var useMemory: Boolean, private var useMemory_ : Boolean,
var deserialized: Boolean, private var deserialized_ : Boolean,
var replication: Int = 1) private var replication_ : Int = 1)
extends Externalizable { extends Externalizable {
// TODO: Also add fields for caching priority, dataset ID, and flushing. // TODO: Also add fields for caching priority, dataset ID, and flushing.
private def this(flags: Int, replication: Int) {
assert(replication < 40, "Replication restricted to be less than 40 for calculating hashcodes")
def this(flags: Int, replication: Int) {
this((flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0, replication) this((flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0, replication)
} }
def this() = this(false, true, false) // For deserialization def this() = this(false, true, false) // For deserialization
def useDisk = useDisk_
def useMemory = useMemory_
def deserialized = deserialized_
def replication = replication_
assert(replication < 40, "Replication restricted to be less than 40 for calculating hashcodes")
override def clone(): StorageLevel = new StorageLevel( override def clone(): StorageLevel = new StorageLevel(
this.useDisk, this.useMemory, this.deserialized, this.replication) this.useDisk, this.useMemory, this.deserialized, this.replication)
@ -43,13 +48,13 @@ class StorageLevel(
def toInt: Int = { def toInt: Int = {
var ret = 0 var ret = 0
if (useDisk) { if (useDisk_) {
ret |= 4 ret |= 4
} }
if (useMemory) { if (useMemory_) {
ret |= 2 ret |= 2
} }
if (deserialized) { if (deserialized_) {
ret |= 1 ret |= 1
} }
return ret return ret
@ -57,15 +62,15 @@ class StorageLevel(
override def writeExternal(out: ObjectOutput) { override def writeExternal(out: ObjectOutput) {
out.writeByte(toInt) out.writeByte(toInt)
out.writeByte(replication) out.writeByte(replication_)
} }
override def readExternal(in: ObjectInput) { override def readExternal(in: ObjectInput) {
val flags = in.readByte() val flags = in.readByte()
useDisk = (flags & 4) != 0 useDisk_ = (flags & 4) != 0
useMemory = (flags & 2) != 0 useMemory_ = (flags & 2) != 0
deserialized = (flags & 1) != 0 deserialized_ = (flags & 1) != 0
replication = in.readByte() replication_ = in.readByte()
} }
@throws(classOf[IOException]) @throws(classOf[IOException])
@ -75,6 +80,14 @@ class StorageLevel(
"StorageLevel(%b, %b, %b, %d)".format(useDisk, useMemory, deserialized, replication) "StorageLevel(%b, %b, %b, %d)".format(useDisk, useMemory, deserialized, replication)
override def hashCode(): Int = toInt * 41 + replication override def hashCode(): Int = toInt * 41 + replication
def description : String = {
var result = ""
result += (if (useDisk) "Disk " else "")
result += (if (useMemory) "Memory " else "")
result += (if (deserialized) "Deserialized " else "Serialized")
result += "%sx Replicated".format(replication)
result
}
} }
@ -91,6 +104,21 @@ object StorageLevel {
val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false) val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false)
val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2) val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2)
/** Create a new StorageLevel object */
def apply(useDisk: Boolean, useMemory: Boolean, deserialized: Boolean, replication: Int = 1) =
getCachedStorageLevel(new StorageLevel(useDisk, useMemory, deserialized, replication))
/** Create a new StorageLevel object from its integer representation */
def apply(flags: Int, replication: Int) =
getCachedStorageLevel(new StorageLevel(flags, replication))
/** Read StorageLevel object from ObjectInput stream */
def apply(in: ObjectInput) = {
val obj = new StorageLevel()
obj.readExternal(in)
getCachedStorageLevel(obj)
}
private[spark] private[spark]
val storageLevelCache = new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]() val storageLevelCache = new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]()
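
The private (flags, replication) constructor and toInt use a three-bit encoding: disk = 4, memory = 2, deserialized = 1, plus a separate replication byte. A small stand-alone sketch of that encoding (the function names are illustrative):

// Pack the three storage flags into one Int, mirroring StorageLevel.toInt.
def encode(useDisk: Boolean, useMemory: Boolean, deserialized: Boolean): Int =
  (if (useDisk) 4 else 0) | (if (useMemory) 2 else 0) | (if (deserialized) 1 else 0)

// Unpack with the same masks used by the (flags, replication) constructor.
def decode(flags: Int): (Boolean, Boolean, Boolean) =
  ((flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0)

// MEMORY_AND_DISK_SER is (disk, memory, serialized), i.e. 4 | 2 = 6.
assert(encode(true, true, false) == 6)
val (disk, mem, deser) = decode(6)
assert(disk && mem && !deser)
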

View file

@ -0,0 +1,82 @@
package spark.storage
import spark.{Utils, SparkContext}
import BlockManagerMasterActor.BlockStatus
private[spark]
case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long,
blocks: Map[String, BlockStatus]) {
def memUsed(blockPrefix: String = "") = {
blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.memSize).
reduceOption(_+_).getOrElse(0l)
}
def diskUsed(blockPrefix: String = "") = {
blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.diskSize).
reduceOption(_+_).getOrElse(0l)
}
def memRemaining : Long = maxMem - memUsed()
}
case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel,
numCachedPartitions: Int, numPartitions: Int, memSize: Long, diskSize: Long) {
override def toString = {
import Utils.memoryBytesToString
"RDD \"%s\" (%d) Storage: %s; CachedPartitions: %d; TotalPartitions: %d; MemorySize: %s; DiskSize: %s".format(name, id,
storageLevel.toString, numCachedPartitions, numPartitions, memoryBytesToString(memSize), memoryBytesToString(diskSize))
}
}
/* Helper methods for storage-related objects */
private[spark]
object StorageUtils {
/* Given the current storage status of the BlockManager, returns information for each RDD */
def rddInfoFromStorageStatus(storageStatusList: Array[StorageStatus],
sc: SparkContext) : Array[RDDInfo] = {
rddInfoFromBlockStatusList(storageStatusList.flatMap(_.blocks).toMap, sc)
}
/* Given a list of BlockStatus objects, returns information for each RDD */
def rddInfoFromBlockStatusList(infos: Map[String, BlockStatus],
sc: SparkContext) : Array[RDDInfo] = {
// Group by rddId, ignore the partition name
val groupedRddBlocks = infos.groupBy { case(k, v) =>
k.substring(0,k.lastIndexOf('_'))
}.mapValues(_.values.toArray)
// For each RDD, generate an RDDInfo object
groupedRddBlocks.map { case(rddKey, rddBlocks) =>
// Add up memory and disk sizes
val memSize = rddBlocks.map(_.memSize).reduce(_ + _)
val diskSize = rddBlocks.map(_.diskSize).reduce(_ + _)
// Find the id of the RDD, e.g. rdd_1 => 1
val rddId = rddKey.split("_").last.toInt
// Get the friendly name for the rdd, if available.
val rdd = sc.persistentRdds(rddId)
val rddName = Option(rdd.name).getOrElse(rddKey)
val rddStorageLevel = rdd.getStorageLevel
RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, rdd.splits.size, memSize, diskSize)
}.toArray
}
/* Removes all BlockStatus objects that are not part of a block prefix */
def filterStorageStatusByPrefix(storageStatusList: Array[StorageStatus],
prefix: String) : Array[StorageStatus] = {
storageStatusList.map { status =>
val newBlocks = status.blocks.filterKeys(_.startsWith(prefix))
//val newRemainingMem = status.maxMem - newBlocks.values.map(_.memSize).reduce(_ + _)
StorageStatus(status.blockManagerId, status.maxMem, newBlocks)
}
}
}
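
The grouping above relies on block ids following the rdd_<rddId>_<partition> convention, so stripping the final _ segment yields one key per RDD. A hedged illustration on plain strings (the block ids are made up):

val blockIds = Seq("rdd_3_0", "rdd_3_1", "rdd_7_0")

// Group by the RDD prefix, exactly as rddInfoFromBlockStatusList groups BlockStatus keys.
val byRdd = blockIds.groupBy(id => id.substring(0, id.lastIndexOf('_')))
// byRdd == Map("rdd_3" -> Seq("rdd_3_0", "rdd_3_1"), "rdd_7" -> Seq("rdd_7_0"))

// Recover the numeric id from each grouped key, e.g. "rdd_3" => 3.
val rddIds = byRdd.keySet.map(_.split("_").last.toInt)
// rddIds == Set(3, 7)
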

View file

@ -75,10 +75,11 @@ private[spark] object ThreadingTest {
System.setProperty("spark.kryoserializer.buffer.mb", "1") System.setProperty("spark.kryoserializer.buffer.mb", "1")
val actorSystem = ActorSystem("test") val actorSystem = ActorSystem("test")
val serializer = new KryoSerializer val serializer = new KryoSerializer
val masterIp: String = System.getProperty("spark.master.host", "localhost") val driverIp: String = System.getProperty("spark.driver.host", "localhost")
val masterPort: Int = System.getProperty("spark.master.port", "7077").toInt val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt
val blockManagerMaster = new BlockManagerMaster(actorSystem, true, true, masterIp, masterPort) val blockManagerMaster = new BlockManagerMaster(actorSystem, true, true, driverIp, driverPort)
val blockManager = new BlockManager(actorSystem, blockManagerMaster, serializer, 1024 * 1024) val blockManager = new BlockManager(
"<driver>", actorSystem, blockManagerMaster, serializer, 1024 * 1024)
val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i)) val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i))
val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue)) val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue))
producers.foreach(_.start) producers.foreach(_.start)

View file

@ -1,6 +1,6 @@
package spark.util package spark.util
import akka.actor.{Props, ActorSystemImpl, ActorSystem} import akka.actor.{ActorRef, Props, ActorSystemImpl, ActorSystem}
import com.typesafe.config.ConfigFactory import com.typesafe.config.ConfigFactory
import akka.util.duration._ import akka.util.duration._
import akka.pattern.ask import akka.pattern.ask
@ -18,9 +18,13 @@ import java.util.concurrent.TimeoutException
* Various utility classes for working with Akka. * Various utility classes for working with Akka.
*/ */
private[spark] object AkkaUtils { private[spark] object AkkaUtils {
/** /**
* Creates an ActorSystem ready for remoting, with various Spark features. Returns both the * Creates an ActorSystem ready for remoting, with various Spark features. Returns both the
* ActorSystem itself and its port (which is hard to get from Akka). * ActorSystem itself and its port (which is hard to get from Akka).
*
* Note: the `name` parameter is important, as even if a client sends a message to the right
* host + port, if the system name is incorrect, Akka will drop the message.
*/ */
def createActorSystem(name: String, host: String, port: Int): (ActorSystem, Int) = { def createActorSystem(name: String, host: String, port: Int): (ActorSystem, Int) = {
val akkaThreads = System.getProperty("spark.akka.threads", "4").toInt val akkaThreads = System.getProperty("spark.akka.threads", "4").toInt
@ -30,8 +34,10 @@ private[spark] object AkkaUtils {
val akkaConf = ConfigFactory.parseString(""" val akkaConf = ConfigFactory.parseString("""
akka.daemonic = on akka.daemonic = on
akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"] akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"]
akka.stdout-loglevel = "ERROR"
akka.actor.provider = "akka.remote.RemoteActorRefProvider" akka.actor.provider = "akka.remote.RemoteActorRefProvider"
akka.remote.transport = "akka.remote.netty.NettyRemoteTransport" akka.remote.transport = "akka.remote.netty.NettyRemoteTransport"
akka.remote.log-remote-lifecycle-events = on
akka.remote.netty.hostname = "%s" akka.remote.netty.hostname = "%s"
akka.remote.netty.port = %d akka.remote.netty.port = %d
akka.remote.netty.connection-timeout = %ds akka.remote.netty.connection-timeout = %ds
@ -40,7 +46,7 @@ private[spark] object AkkaUtils {
akka.actor.default-dispatcher.throughput = %d akka.actor.default-dispatcher.throughput = %d
""".format(host, port, akkaTimeout, akkaFrameSize, akkaThreads, akkaBatchSize)) """.format(host, port, akkaTimeout, akkaFrameSize, akkaThreads, akkaBatchSize))
val actorSystem = ActorSystem("spark", akkaConf, getClass.getClassLoader) val actorSystem = ActorSystem(name, akkaConf, getClass.getClassLoader)
// Figure out the port number we bound to, in case port was passed as 0. This is a bit of a // Figure out the port number we bound to, in case port was passed as 0. This is a bit of a
// hack because Akka doesn't let you figure out the port through the public API yet. // hack because Akka doesn't let you figure out the port through the public API yet.
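
The warning about the name parameter is concrete: remote actor paths embed the ActorSystem name, so the URL the BlockManagerMaster builds only resolves when both sides use the same name. A hedged sketch of the path format (host, port and actor name are placeholders):

// Remote actor path format used when looking up the driver-side actors:
//   akka://<systemName>@<host>:<port>/user/<actorName>
val url = "akka://%s@%s:%d/user/%s".format("spark", "driver-host", 7077, "BlockMasterManager")
// If <systemName> differs from the name passed to createActorSystem ("spark" here),
// Akka drops messages sent to this path instead of raising an error.
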
@ -51,21 +57,22 @@ private[spark] object AkkaUtils {
/** /**
* Creates a Spray HTTP server bound to a given IP and port with a given Spray Route object to * Creates a Spray HTTP server bound to a given IP and port with a given Spray Route object to
* handle requests. Throws a SparkException if this fails. * handle requests. Returns the bound port or throws a SparkException on failure.
*/ */
def startSprayServer(actorSystem: ActorSystem, ip: String, port: Int, route: Route) { def startSprayServer(actorSystem: ActorSystem, ip: String, port: Int, route: Route,
name: String = "HttpServer"): ActorRef = {
val ioWorker = new IoWorker(actorSystem).start() val ioWorker = new IoWorker(actorSystem).start()
val httpService = actorSystem.actorOf(Props(new HttpService(route))) val httpService = actorSystem.actorOf(Props(new HttpService(route)))
val rootService = actorSystem.actorOf(Props(new SprayCanRootService(httpService))) val rootService = actorSystem.actorOf(Props(new SprayCanRootService(httpService)))
val server = actorSystem.actorOf( val server = actorSystem.actorOf(
Props(new HttpServer(ioWorker, SingletonHandler(rootService))), name = "HttpServer") Props(new HttpServer(ioWorker, SingletonHandler(rootService))), name = name)
actorSystem.registerOnTermination { ioWorker.stop() } actorSystem.registerOnTermination { ioWorker.stop() }
val timeout = 3.seconds val timeout = 3.seconds
val future = server.ask(HttpServer.Bind(ip, port))(timeout) val future = server.ask(HttpServer.Bind(ip, port))(timeout)
try { try {
Await.result(future, timeout) match { Await.result(future, timeout) match {
case bound: HttpServer.Bound => case bound: HttpServer.Bound =>
return return server
case other: Any => case other: Any =>
throw new SparkException("Failed to bind web UI to port " + port + ": " + other) throw new SparkException("Failed to bind web UI to port " + port + ": " + other)
} }

View file

@ -5,29 +5,29 @@ import java.util.{TimerTask, Timer}
import spark.Logging import spark.Logging
/**
* Runs a timer task to periodically clean up metadata (e.g. old files or hashtable entries)
*/
class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging { class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging {
private val delaySeconds = MetadataCleaner.getDelaySeconds
private val periodSeconds = math.max(10, delaySeconds / 10)
private val timer = new Timer(name + " cleanup timer", true)
val delaySeconds = MetadataCleaner.getDelaySeconds private val task = new TimerTask {
val periodSeconds = math.max(10, delaySeconds / 10) override def run() {
val timer = new Timer(name + " cleanup timer", true)
val task = new TimerTask {
def run() {
try { try {
if (delaySeconds > 0) { cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000))
cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000)) logInfo("Ran metadata cleaner for " + name)
logInfo("Ran metadata cleaner for " + name)
}
} catch { } catch {
case e: Exception => logError("Error running cleanup task for " + name, e) case e: Exception => logError("Error running cleanup task for " + name, e)
} }
} }
} }
if (periodSeconds > 0) { if (delaySeconds > 0) {
logInfo( logDebug(
"Starting metadata cleaner for " + name + " with delay of " + delaySeconds + " seconds and " "Starting metadata cleaner for " + name + " with delay of " + delaySeconds + " seconds " +
+ "period of " + periodSeconds + " secs") "and period of " + periodSeconds + " secs")
timer.schedule(task, periodSeconds * 1000, periodSeconds * 1000) timer.schedule(task, periodSeconds * 1000, periodSeconds * 1000)
} }
@ -38,7 +38,7 @@ class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging
object MetadataCleaner { object MetadataCleaner {
def getDelaySeconds = (System.getProperty("spark.cleaner.delay", "-100").toDouble * 60).toInt def getDelaySeconds = System.getProperty("spark.cleaner.delay", "-1").toInt
def setDelaySeconds(delay: Long) { System.setProperty("spark.cleaner.delay", delay.toString) } def setDelaySeconds(delay: Int) { System.setProperty("spark.cleaner.delay", delay.toString) }
} }
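
The cleaner only arms its timer when spark.cleaner.delay is positive, and then invokes the supplied function with a cutoff timestamp every period. A hedged usage sketch (the map, property value and cleanup body are illustrative; only the MetadataCleaner constructor shown above is assumed):

import spark.util.MetadataCleaner
import scala.collection.mutable

// Enable cleanup before constructing the cleaner: entries older than 3600 s become eligible.
System.setProperty("spark.cleaner.delay", "3600")

// A throwaway map of (key -> insertion time in ms) standing in for Spark's internal metadata.
val timestamped = new mutable.HashMap[String, Long]()

// The cleanup function receives a cutoff time in milliseconds and drops stale entries.
// The timer runs on its own thread, hence the synchronized block around the mutation.
val cleaner = new MetadataCleaner("example", (cutoffMs: Long) =>
  timestamped.synchronized {
    timestamped.retain { case (_, insertedMs) => insertedMs >= cutoffMs }
  }
)
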

View file

@ -63,9 +63,9 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() with spark.Logging {
override def empty: Map[A, B] = new TimeStampedHashMap[A, B]() override def empty: Map[A, B] = new TimeStampedHashMap[A, B]()
override def size(): Int = internalMap.size() override def size: Int = internalMap.size
override def foreach[U](f: ((A, B)) => U): Unit = { override def foreach[U](f: ((A, B)) => U) {
val iterator = internalMap.entrySet().iterator() val iterator = internalMap.entrySet().iterator()
while(iterator.hasNext) { while(iterator.hasNext) {
val entry = iterator.next() val entry = iterator.next()

Some files were not shown because too many files have changed in this diff.