Merge branch 'streaming' into ScrapCode-streaming

Conflicts:
	streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala
	streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala
Tathagata Das 2013-02-18 13:26:12 -08:00
commit 6a6e6bda57
242 changed files with 9609 additions and 3478 deletions

View file

@ -23,7 +23,7 @@ class BagelSuite extends FunSuite with Assertions with BeforeAndAfter {
sc = null
}
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.master.port")
System.clearProperty("spark.driver.port")
}
test("halting by voting") {

View file

@ -26,7 +26,8 @@ fi
# Set SPARK_PUBLIC_DNS so the master reports the correct webUI address to the slaves
if [ "$SPARK_PUBLIC_DNS" = "" ]; then
# If we appear to be running on EC2, use the public address by default:
if [[ `hostname` == *ec2.internal ]]; then
# NOTE: ec2-metadata is installed on Amazon Linux AMI. Check based on that and hostname
if command -v ec2-metadata > /dev/null || [[ `hostname` == *ec2.internal ]]; then
export SPARK_PUBLIC_DNS=`wget -q -O - http://instance-data.ec2.internal/latest/meta-data/public-hostname`
fi
fi

View file

@ -71,6 +71,10 @@
<groupId>cc.spray</groupId>
<artifactId>spray-server</artifactId>
</dependency>
<dependency>
<groupId>cc.spray</groupId>
<artifactId>spray-json_${scala.version}</artifactId>
</dependency>
<dependency>
<groupId>org.tomdz.twirl</groupId>
<artifactId>twirl-api</artifactId>
@ -94,6 +98,11 @@
<artifactId>scalacheck_${scala.version}</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.easymock</groupId>
<artifactId>easymock</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.novocode</groupId>
<artifactId>junit-interface</artifactId>

View file

@ -25,8 +25,7 @@ class Accumulable[R, T] (
extends Serializable {
val id = Accumulators.newId
@transient
private var value_ = initialValue // Current value on master
@transient private var value_ = initialValue // Current value on master
val zero = param.zero(initialValue) // Zero value to be passed to workers
var deserialized = false
@ -63,9 +62,12 @@ class Accumulable[R, T] (
/**
* Access the accumulator's current value; only allowed on master.
*/
def value = {
if (!deserialized) value_
else throw new UnsupportedOperationException("Can't read accumulator value in task")
def value: R = {
if (!deserialized) {
value_
} else {
throw new UnsupportedOperationException("Can't read accumulator value in task")
}
}
/**
@ -82,11 +84,18 @@ class Accumulable[R, T] (
/**
* Set the accumulator's value; only allowed on master.
*/
def value_= (r: R) {
if (!deserialized) value_ = r
def value_= (newValue: R) {
if (!deserialized) value_ = newValue
else throw new UnsupportedOperationException("Can't assign accumulator value in task")
}
/**
* Set the accumulator's value; only allowed on master
*/
def setValue(newValue: R) {
this.value = newValue
}
// Called by Java when deserializing an object
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
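Side note on the accumulator change above: tasks may only add to an accumulator, while reading or assigning its value is restricted to the driver. A minimal usage sketch (assumes an existing SparkContext named sc; not part of this commit):

import spark.SparkContext._           // brings in the implicit AccumulatorParam for Int
val errorCount = sc.accumulator(0)    // created on the driver
sc.parallelize(1 to 100).foreach { i =>
  if (i % 10 == 0) errorCount += 1    // tasks may only add to the accumulator
}
println(errorCount.value)             // reading value is only allowed on the driver;
                                      // calling .value inside a task throws UnsupportedOperationException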

View file

@ -1,118 +0,0 @@
package spark
import java.util.LinkedHashMap
/**
* An implementation of Cache that estimates the sizes of its entries and attempts to limit its
* total memory usage to a fraction of the JVM heap. Objects' sizes are estimated using
* SizeEstimator, which has limitations; most notably, we will overestimate total memory used if
* some cache entries have pointers to a shared object. Nonetheless, this Cache should work well
* when most of the space is used by arrays of primitives or of simple classes.
*/
private[spark] class BoundedMemoryCache(maxBytes: Long) extends Cache with Logging {
logInfo("BoundedMemoryCache.maxBytes = " + maxBytes)
def this() {
this(BoundedMemoryCache.getMaxBytes)
}
private var currentBytes = 0L
private val map = new LinkedHashMap[(Any, Int), Entry](32, 0.75f, true)
override def get(datasetId: Any, partition: Int): Any = {
synchronized {
val entry = map.get((datasetId, partition))
if (entry != null) {
entry.value
} else {
null
}
}
}
override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = {
val key = (datasetId, partition)
logInfo("Asked to add key " + key)
val size = estimateValueSize(key, value)
synchronized {
if (size > getCapacity) {
return CachePutFailure()
} else if (ensureFreeSpace(datasetId, size)) {
logInfo("Adding key " + key)
map.put(key, new Entry(value, size))
currentBytes += size
logInfo("Number of entries is now " + map.size)
return CachePutSuccess(size)
} else {
logInfo("Didn't add key " + key + " because we would have evicted part of same dataset")
return CachePutFailure()
}
}
}
override def getCapacity: Long = maxBytes
/**
* Estimate sizeOf 'value'
*/
private def estimateValueSize(key: (Any, Int), value: Any) = {
val startTime = System.currentTimeMillis
val size = SizeEstimator.estimate(value.asInstanceOf[AnyRef])
val timeTaken = System.currentTimeMillis - startTime
logInfo("Estimated size for key %s is %d".format(key, size))
logInfo("Size estimation for key %s took %d ms".format(key, timeTaken))
size
}
/**
* Remove least recently used entries from the map until at least space bytes are free, in order
* to make space for a partition from the given dataset ID. If this cannot be done without
* evicting other data from the same dataset, returns false; otherwise, returns true. Assumes
* that a lock is held on the BoundedMemoryCache.
*/
private def ensureFreeSpace(datasetId: Any, space: Long): Boolean = {
logInfo("ensureFreeSpace(%s, %d) called with curBytes=%d, maxBytes=%d".format(
datasetId, space, currentBytes, maxBytes))
val iter = map.entrySet.iterator // Will give entries in LRU order
while (maxBytes - currentBytes < space && iter.hasNext) {
val mapEntry = iter.next()
val (entryDatasetId, entryPartition) = mapEntry.getKey
if (entryDatasetId == datasetId) {
// Cannot make space without removing part of the same dataset, or a more recently used one
return false
}
reportEntryDropped(entryDatasetId, entryPartition, mapEntry.getValue)
currentBytes -= mapEntry.getValue.size
iter.remove()
}
return true
}
protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) {
logInfo("Dropping key (%s, %d) of size %d to make space".format(datasetId, partition, entry.size))
// TODO: remove BoundedMemoryCache
val (keySpaceId, innerDatasetId) = datasetId.asInstanceOf[(Any, Any)]
innerDatasetId match {
case rddId: Int =>
SparkEnv.get.cacheTracker.dropEntry(rddId, partition)
case broadcastUUID: java.util.UUID =>
// TODO: Maybe something should be done if the broadcasted variable falls out of cache
case _ =>
}
}
}
// An entry in our map; stores a cached object and its size in bytes
private[spark] case class Entry(value: Any, size: Long)
private[spark] object BoundedMemoryCache {
/**
* Get maximum cache capacity from system configuration
*/
def getMaxBytes: Long = {
val memoryFractionToUse = System.getProperty("spark.boundedMemoryCache.memoryFraction", "0.66").toDouble
(Runtime.getRuntime.maxMemory * memoryFractionToUse).toLong
}
}

View file

@ -0,0 +1,65 @@
package spark
import scala.collection.mutable.{ArrayBuffer, HashSet}
import spark.storage.{BlockManager, StorageLevel}
/** Spark class responsible for passing an RDD's split contents to the BlockManager and making
sure a node doesn't load two copies of an RDD at once.
*/
private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
private val loading = new HashSet[String]
/** Gets or computes an RDD split. Used by RDD.iterator() when an RDD is cached. */
def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext, storageLevel: StorageLevel)
: Iterator[T] = {
val key = "rdd_%d_%d".format(rdd.id, split.index)
logInfo("Cache key is " + key)
blockManager.get(key) match {
case Some(cachedValues) =>
// Split is in cache, so just return its values
logInfo("Found partition in cache!")
return cachedValues.asInstanceOf[Iterator[T]]
case None =>
// Mark the split as loading (unless someone else marks it first)
loading.synchronized {
if (loading.contains(key)) {
logInfo("Loading contains " + key + ", waiting...")
while (loading.contains(key)) {
try {loading.wait()} catch {case _ =>}
}
logInfo("Loading no longer contains " + key + ", so returning cached result")
// See whether someone else has successfully loaded it. The main way this would fail
// is for the RDD-level cache eviction policy if someone else has loaded the same RDD
// partition but we didn't want to make space for it. However, that case is unlikely
// because it's unlikely that two threads would work on the same RDD partition. One
// downside of the current code is that threads wait serially if this does happen.
blockManager.get(key) match {
case Some(values) =>
return values.asInstanceOf[Iterator[T]]
case None =>
logInfo("Whoever was loading " + key + " failed; we'll try it ourselves")
loading.add(key)
}
} else {
loading.add(key)
}
}
try {
// If we got here, we have to load the split
val elements = new ArrayBuffer[Any]
logInfo("Computing partition " + split)
elements ++= rdd.computeOrReadCheckpoint(split, context)
// Try to put this block in the blockManager
blockManager.put(key, elements, storageLevel, true)
return elements.iterator.asInstanceOf[Iterator[T]]
} finally {
loading.synchronized {
loading.remove(key)
loading.notifyAll()
}
}
}
}
}
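For context, the getOrCompute path above is only taken for persisted RDDs; a rough sketch of how it gets exercised (assumes an existing SparkContext sc; illustrative only):

val data = sc.parallelize(1 to 1000).map(_ * 2)
data.cache()       // persist in memory so RDD.iterator() goes through CacheManager.getOrCompute
data.count()       // first action computes each partition and hands it to the BlockManager
data.count()       // second action finds the blocks in the cache ("Found partition in cache!")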

View file

@ -1,240 +0,0 @@
package spark
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import akka.actor._
import akka.dispatch._
import akka.pattern.ask
import akka.remote._
import akka.util.Duration
import akka.util.Timeout
import akka.util.duration._
import spark.storage.BlockManager
import spark.storage.StorageLevel
import util.{TimeStampedHashSet, MetadataCleaner, TimeStampedHashMap}
private[spark] sealed trait CacheTrackerMessage
private[spark] case class AddedToCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
extends CacheTrackerMessage
private[spark] case class DroppedFromCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
extends CacheTrackerMessage
private[spark] case class MemoryCacheLost(host: String) extends CacheTrackerMessage
private[spark] case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheTrackerMessage
private[spark] case class SlaveCacheStarted(host: String, size: Long) extends CacheTrackerMessage
private[spark] case object GetCacheStatus extends CacheTrackerMessage
private[spark] case object GetCacheLocations extends CacheTrackerMessage
private[spark] case object StopCacheTracker extends CacheTrackerMessage
private[spark] class CacheTrackerActor extends Actor with Logging {
// TODO: Should probably store (String, CacheType) tuples
private val locs = new TimeStampedHashMap[Int, Array[List[String]]]
/**
* A map from the slave's host name to its cache size.
*/
private val slaveCapacity = new HashMap[String, Long]
private val slaveUsage = new HashMap[String, Long]
private val metadataCleaner = new MetadataCleaner("CacheTrackerActor", locs.clearOldValues)
private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L)
private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L)
private def getCacheAvailable(host: String): Long = getCacheCapacity(host) - getCacheUsage(host)
def receive = {
case SlaveCacheStarted(host: String, size: Long) =>
slaveCapacity.put(host, size)
slaveUsage.put(host, 0)
sender ! true
case RegisterRDD(rddId: Int, numPartitions: Int) =>
logInfo("Registering RDD " + rddId + " with " + numPartitions + " partitions")
locs(rddId) = Array.fill[List[String]](numPartitions)(Nil)
sender ! true
case AddedToCache(rddId, partition, host, size) =>
slaveUsage.put(host, getCacheUsage(host) + size)
locs(rddId)(partition) = host :: locs(rddId)(partition)
sender ! true
case DroppedFromCache(rddId, partition, host, size) =>
slaveUsage.put(host, getCacheUsage(host) - size)
// Do a sanity check to make sure usage is greater than 0.
locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host)
sender ! true
case MemoryCacheLost(host) =>
logInfo("Memory cache lost on " + host)
for ((id, locations) <- locs) {
for (i <- 0 until locations.length) {
locations(i) = locations(i).filterNot(_ == host)
}
}
sender ! true
case GetCacheLocations =>
logInfo("Asked for current cache locations")
sender ! locs.map{case (rrdId, array) => (rrdId -> array.clone())}
case GetCacheStatus =>
val status = slaveCapacity.map { case (host, capacity) =>
(host, capacity, getCacheUsage(host))
}.toSeq
sender ! status
case StopCacheTracker =>
logInfo("Stopping CacheTrackerActor")
sender ! true
metadataCleaner.cancel()
context.stop(self)
}
}
private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: BlockManager)
extends Logging {
// Tracker actor on the master, or remote reference to it on workers
val ip: String = System.getProperty("spark.master.host", "localhost")
val port: Int = System.getProperty("spark.master.port", "7077").toInt
val actorName: String = "CacheTracker"
val timeout = 10.seconds
var trackerActor: ActorRef = if (isMaster) {
val actor = actorSystem.actorOf(Props[CacheTrackerActor], name = actorName)
logInfo("Registered CacheTrackerActor actor")
actor
} else {
val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName)
actorSystem.actorFor(url)
}
// TODO: Consider removing this HashSet completely as locs CacheTrackerActor already
// keeps track of registered RDDs
val registeredRddIds = new TimeStampedHashSet[Int]
// Remembers which splits are currently being loaded (on worker nodes)
val loading = new HashSet[String]
val metadataCleaner = new MetadataCleaner("CacheTracker", registeredRddIds.clearOldValues)
// Send a message to the trackerActor and get its result within a default timeout, or
// throw a SparkException if this fails.
def askTracker(message: Any): Any = {
try {
val future = trackerActor.ask(message)(timeout)
return Await.result(future, timeout)
} catch {
case e: Exception =>
throw new SparkException("Error communicating with CacheTracker", e)
}
}
// Send a one-way message to the trackerActor, to which we expect it to reply with true.
def communicate(message: Any) {
if (askTracker(message) != true) {
throw new SparkException("Error reply received from CacheTracker")
}
}
// Registers an RDD (on master only)
def registerRDD(rddId: Int, numPartitions: Int) {
registeredRddIds.synchronized {
if (!registeredRddIds.contains(rddId)) {
logInfo("Registering RDD ID " + rddId + " with cache")
registeredRddIds += rddId
communicate(RegisterRDD(rddId, numPartitions))
}
}
}
// For BlockManager.scala only
def cacheLost(host: String) {
communicate(MemoryCacheLost(host))
logInfo("CacheTracker successfully removed entries on " + host)
}
// Get the usage status of slave caches. Each tuple in the returned sequence
// is in the form of (host name, capacity, usage).
def getCacheStatus(): Seq[(String, Long, Long)] = {
askTracker(GetCacheStatus).asInstanceOf[Seq[(String, Long, Long)]]
}
// For BlockManager.scala only
def notifyFromBlockManager(t: AddedToCache) {
communicate(t)
}
// Get a snapshot of the currently known locations
def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = {
askTracker(GetCacheLocations).asInstanceOf[HashMap[Int, Array[List[String]]]]
}
// Gets or computes an RDD split
def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext, storageLevel: StorageLevel)
: Iterator[T] = {
val key = "rdd_%d_%d".format(rdd.id, split.index)
logInfo("Cache key is " + key)
blockManager.get(key) match {
case Some(cachedValues) =>
// Split is in cache, so just return its values
logInfo("Found partition in cache!")
return cachedValues.asInstanceOf[Iterator[T]]
case None =>
// Mark the split as loading (unless someone else marks it first)
loading.synchronized {
if (loading.contains(key)) {
logInfo("Loading contains " + key + ", waiting...")
while (loading.contains(key)) {
try {loading.wait()} catch {case _ =>}
}
logInfo("Loading no longer contains " + key + ", so returning cached result")
// See whether someone else has successfully loaded it. The main way this would fail
// is for the RDD-level cache eviction policy if someone else has loaded the same RDD
// partition but we didn't want to make space for it. However, that case is unlikely
// because it's unlikely that two threads would work on the same RDD partition. One
// downside of the current code is that threads wait serially if this does happen.
blockManager.get(key) match {
case Some(values) =>
return values.asInstanceOf[Iterator[T]]
case None =>
logInfo("Whoever was loading " + key + " failed; we'll try it ourselves")
loading.add(key)
}
} else {
loading.add(key)
}
}
try {
// If we got here, we have to load the split
val elements = new ArrayBuffer[Any]
logInfo("Computing partition " + split)
elements ++= rdd.compute(split, context)
// Try to put this block in the blockManager
blockManager.put(key, elements, storageLevel, true)
return elements.iterator.asInstanceOf[Iterator[T]]
} finally {
loading.synchronized {
loading.remove(key)
loading.notifyAll()
}
}
}
}
// Called by the Cache to report that an entry has been dropped from it
def dropEntry(rddId: Int, partition: Int) {
communicate(DroppedFromCache(rddId, partition, Utils.localHostName()))
}
def stop() {
communicate(StopCacheTracker)
registeredRddIds.clear()
trackerActor = null
}
}

View file

@ -1,18 +0,0 @@
package spark
import java.util.concurrent.ThreadFactory
/**
* A ThreadFactory that creates daemon threads
*/
private object DaemonThreadFactory extends ThreadFactory {
override def newThread(r: Runnable): Thread = new DaemonThread(r)
}
private class DaemonThread(r: Runnable = null) extends Thread {
override def run() {
if (r != null) {
r.run()
}
}
}

View file

@ -5,6 +5,7 @@ package spark
*/
abstract class Dependency[T](val rdd: RDD[T]) extends Serializable
/**
* Base class for dependencies where each partition of the parent RDD is used by at most one
* partition of the child RDD. Narrow dependencies allow for pipelined execution.
@ -12,12 +13,13 @@ abstract class Dependency[T](val rdd: RDD[T]) extends Serializable
abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
/**
* Get the parent partitions for a child partition.
* @param outputPartition a partition of the child RDD
* @param partitionId a partition of the child RDD
* @return the partitions of the parent RDD that the child partition depends upon
*/
def getParents(outputPartition: Int): Seq[Int]
def getParents(partitionId: Int): Seq[Int]
}
/**
* Represents a dependency on the output of a shuffle stage.
* @param shuffleId the shuffle id
@ -32,6 +34,7 @@ class ShuffleDependency[K, V](
val shuffleId: Int = rdd.context.newShuffleId()
}
/**
* Represents a one-to-one dependency between partitions of the parent and child RDDs.
*/
@ -39,6 +42,7 @@ class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) {
override def getParents(partitionId: Int) = List(partitionId)
}
/**
* Represents a one-to-one dependency between ranges of partitions in the parent and child RDDs.
* @param rdd the parent RDD
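To illustrate the getParents contract renamed above (outputPartition to partitionId), a purely hypothetical NarrowDependency in which each child partition reads a fixed window of parent partitions could look like this (not part of this commit):

import spark.{NarrowDependency, RDD}

// Hypothetical dependency: child partition i depends on parent partitions i and i + 1.
class SlidingWindowDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) {
  override def getParents(partitionId: Int): Seq[Int] =
    List(partitionId, partitionId + 1).filter(_ < rdd.splits.size)
}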

View file

@ -1,9 +1,7 @@
package spark
import java.io.{File, PrintWriter}
import java.net.URL
import scala.collection.mutable.HashMap
import org.apache.hadoop.fs.FileUtil
import java.io.{File}
import com.google.common.io.Files
private[spark] class HttpFileServer extends Logging {
@ -40,7 +38,7 @@ private[spark] class HttpFileServer extends Logging {
}
def addFileToDir(file: File, dir: File) : String = {
Utils.copyFile(file, new File(dir, file.getName))
Files.copy(file, new File(dir, file.getName))
return dir + "/" + file.getName
}

View file

@ -4,6 +4,7 @@ import java.io.File
import java.net.InetAddress
import org.eclipse.jetty.server.Server
import org.eclipse.jetty.server.bio.SocketConnector
import org.eclipse.jetty.server.handler.DefaultHandler
import org.eclipse.jetty.server.handler.HandlerList
import org.eclipse.jetty.server.handler.ResourceHandler
@ -27,7 +28,13 @@ private[spark] class HttpServer(resourceBase: File) extends Logging {
if (server != null) {
throw new ServerStateException("Server is already started")
} else {
server = new Server(0)
server = new Server()
val connector = new SocketConnector
connector.setMaxIdleTime(60*1000)
connector.setSoLingerTime(-1)
connector.setPort(0)
server.addConnector(connector)
val threadPool = new QueuedThreadPool
threadPool.setDaemon(true)
server.setThreadPool(threadPool)

View file

@ -206,5 +206,8 @@ class KryoSerializer extends spark.serializer.Serializer with Logging {
kryo
}
def newInstance(): SerializerInstance = new KryoSerializerInstance(this)
def newInstance(): SerializerInstance = {
this.kryo.get().setClassLoader(Thread.currentThread().getContextClassLoader)
new KryoSerializerInstance(this)
}
}

View file

@ -11,8 +11,7 @@ import org.slf4j.LoggerFactory
trait Logging {
// Make the log field transient so that objects with Logging can
// be serialized and used on another machine
@transient
private var log_ : Logger = null
@transient private var log_ : Logger = null
// Method to get or create the logger for this object
protected def log: Logger = {

View file

@ -38,10 +38,7 @@ private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Ac
}
}
private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logging {
val ip: String = System.getProperty("spark.master.host", "localhost")
val port: Int = System.getProperty("spark.master.port", "7077").toInt
val actorName: String = "MapOutputTracker"
private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolean) extends Logging {
val timeout = 10.seconds
@ -56,11 +53,14 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
var cacheGeneration = generation
val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
var trackerActor: ActorRef = if (isMaster) {
val actorName: String = "MapOutputTracker"
var trackerActor: ActorRef = if (isDriver) {
val actor = actorSystem.actorOf(Props(new MapOutputTrackerActor(this)), name = actorName)
logInfo("Registered MapOutputTrackerActor actor")
actor
} else {
val ip = System.getProperty("spark.driver.host", "localhost")
val port = System.getProperty("spark.driver.port", "7077").toInt
val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName)
actorSystem.actorFor(url)
}
@ -114,7 +114,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
var array = mapStatuses(shuffleId)
if (array != null) {
array.synchronized {
if (array(mapId) != null && array(mapId).address == bmAddress) {
if (array(mapId) != null && array(mapId).location == bmAddress) {
array(mapId) = null
}
}
@ -142,8 +142,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
case e: InterruptedException =>
}
}
return mapStatuses(shuffleId).map(status =>
(status.address, MapOutputTracker.decompressSize(status.compressedSizes(reduceId))))
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, mapStatuses(shuffleId))
} else {
fetching += shuffleId
}
@ -159,25 +158,19 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
fetchedStatuses = deserializeStatuses(fetchedBytes)
logInfo("Got the output locations")
mapStatuses.put(shuffleId, fetchedStatuses)
if (fetchedStatuses.contains(null)) {
throw new FetchFailedException(null, shuffleId, -1, reduceId,
new Exception("Missing an output location for shuffle " + shuffleId))
}
} finally {
fetching.synchronized {
fetching -= shuffleId
fetching.notifyAll()
}
}
return fetchedStatuses.map(s =>
(s.address, MapOutputTracker.decompressSize(s.compressedSizes(reduceId))))
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
} else {
return statuses.map(s =>
(s.address, MapOutputTracker.decompressSize(s.compressedSizes(reduceId))))
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
}
}
def cleanup(cleanupTime: Long) {
private def cleanup(cleanupTime: Long) {
mapStatuses.clearOldValues(cleanupTime)
cachedSerializedStatuses.clearOldValues(cleanupTime)
}
@ -267,6 +260,28 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
private[spark] object MapOutputTracker {
private val LOG_BASE = 1.1
// Convert an array of MapStatuses to locations and sizes for a given reduce ID. If
// any of the statuses is null (indicating a missing location due to a failed mapper),
// throw a FetchFailedException.
def convertMapStatuses(
shuffleId: Int,
reduceId: Int,
statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = {
if (statuses == null) {
throw new FetchFailedException(null, shuffleId, -1, reduceId,
new Exception("Missing all output locations for shuffle " + shuffleId))
}
statuses.map {
status =>
if (status == null) {
throw new FetchFailedException(null, shuffleId, -1, reduceId,
new Exception("Missing an output location for shuffle " + shuffleId))
} else {
(status.location, decompressSize(status.compressedSizes(reduceId)))
}
}
}
/**
* Compress a size in bytes to 8 bits for efficient reporting of map output sizes.
* We do this by encoding the log base 1.1 of the size as an integer, which can support
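A rough sketch of the log-base-1.1 size encoding described in the comment above; the exact rounding in MapOutputTracker may differ, so treat this as illustrative only:

// Illustrative encode/decode of a byte count into a single byte using log base 1.1.
val LOG_BASE = 1.1
def compressSize(size: Long): Byte =
  if (size <= 1L) size.toByte
  else (math.log(size) / math.log(LOG_BASE)).toInt.toByte
def decompressSize(compressed: Byte): Long =
  if (compressed == 0) 0L
  else math.pow(LOG_BASE, compressed & 0xFF).toLong

val b = compressSize(1000000L)   // stored as one byte (144 when read back unsigned)
println(decompressSize(b))       // ~913000, i.e. within roughly 10% of the true size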

View file

@ -465,7 +465,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
val res = self.context.runJob(self, process _, Array(index), false)
res(0)
case None =>
self.filter(_._1 == key).map(_._2).collect
self.filter(_._1 == key).map(_._2).collect()
}
}
@ -485,18 +485,6 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
saveAsNewAPIHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]])
}
/**
* Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat`
* (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD.
*/
def saveAsNewAPIHadoopFile(
path: String,
keyClass: Class[_],
valueClass: Class[_],
outputFormatClass: Class[_ <: NewOutputFormat[_, _]]) {
saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass, new Configuration)
}
/**
* Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat`
* (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD.
@ -506,7 +494,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
keyClass: Class[_],
valueClass: Class[_],
outputFormatClass: Class[_ <: NewOutputFormat[_, _]],
conf: Configuration) {
conf: Configuration = self.context.hadoopConfiguration) {
val job = new NewAPIHadoopJob(conf)
job.setOutputKeyClass(keyClass)
job.setOutputValueClass(valueClass)
@ -557,7 +545,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
keyClass: Class[_],
valueClass: Class[_],
outputFormatClass: Class[_ <: OutputFormat[_, _]],
conf: JobConf = new JobConf) {
conf: JobConf = new JobConf(self.context.hadoopConfiguration)) {
conf.setOutputKeyClass(keyClass)
conf.setOutputValueClass(valueClass)
// conf.setOutputFormat(outputFormatClass) // Doesn't work in Scala 2.9 due to what may be a generics bug
@ -602,7 +590,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
var count = 0
while(iter.hasNext) {
val record = iter.next
val record = iter.next()
count += 1
writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef])
}
@ -615,6 +603,16 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
writer.cleanup()
}
/**
* Return an RDD with the keys of each tuple.
*/
def keys: RDD[K] = self.map(_._1)
/**
* Return an RDD with the values of each tuple.
*/
def values: RDD[V] = self.map(_._2)
private[spark] def getKeyClass() = implicitly[ClassManifest[K]].erasure
private[spark] def getValueClass() = implicitly[ClassManifest[V]].erasure
@ -651,9 +649,7 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
}
private[spark]
class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U)
extends RDD[(K, U)](prev) {
class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U) extends RDD[(K, U)](prev) {
override def getSplits = firstParent[(K, V)].splits
override val partitioner = firstParent[(K, V)].partitioner
override def compute(split: Split, context: TaskContext) =

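Quick illustration of the new keys and values helpers added above (assumes an existing SparkContext sc and the usual import spark.SparkContext._ for the pair-RDD conversions; illustrative only):

import spark.SparkContext._    // implicit conversion to PairRDDFunctions
val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("c", 3)))
val ks = pairs.keys            // RDD[String]: "a", "b", "c"
val vs = pairs.values          // RDD[Int]: 1, 2, 3
println(vs.reduce(_ + _))      // 6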
View file

@ -23,32 +23,28 @@ private[spark] class ParallelCollectionSplit[T: ClassManifest](
}
private[spark] class ParallelCollection[T: ClassManifest](
@transient sc : SparkContext,
@transient sc: SparkContext,
@transient data: Seq[T],
numSlices: Int,
locationPrefs : Map[Int,Seq[String]])
locationPrefs: Map[Int,Seq[String]])
extends RDD[T](sc, Nil) {
// TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets
// cached. It might be worthwhile to write the data to a file in the DFS and read it in the split
// instead.
// UPDATE: A parallel collection can be checkpointed to HDFS, which achieves this goal.
@transient
var splits_ : Array[Split] = {
@transient var splits_ : Array[Split] = {
val slices = ParallelCollection.slice(data, numSlices).toArray
slices.indices.map(i => new ParallelCollectionSplit(id, i, slices(i))).toArray
}
override def getSplits = splits_.asInstanceOf[Array[Split]]
override def getSplits = splits_
override def compute(s: Split, context: TaskContext) =
s.asInstanceOf[ParallelCollectionSplit[T]].iterator
override def getPreferredLocations(s: Split): Seq[String] = {
locationPrefs.get(s.index) match {
case Some(s) => s
case _ => Nil
}
locationPrefs.getOrElse(s.index, Nil)
}
override def clearDependencies() {
@ -56,7 +52,6 @@ private[spark] class ParallelCollection[T: ClassManifest](
}
}
private object ParallelCollection {
/**
* Slice a collection into numSlices sub-collections. One extra thing we do here is to treat Range

View file

@ -1,27 +1,17 @@
package spark
import java.io.{ObjectOutputStream, IOException, EOFException, ObjectInputStream}
import java.net.URL
import java.util.{Date, Random}
import java.util.{HashMap => JHashMap}
import java.util.concurrent.atomic.AtomicLong
import scala.collection.Map
import scala.collection.JavaConversions.mapAsScalaMap
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.BytesWritable
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.io.Text
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapred.FileOutputCommitter
import org.apache.hadoop.mapred.HadoopWriter
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapred.OutputCommitter
import org.apache.hadoop.mapred.OutputFormat
import org.apache.hadoop.mapred.SequenceFileOutputFormat
import org.apache.hadoop.mapred.TextOutputFormat
import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}
@ -30,7 +20,6 @@ import spark.partial.BoundedDouble
import spark.partial.CountEvaluator
import spark.partial.GroupedCountEvaluator
import spark.partial.PartialResult
import spark.rdd.BlockRDD
import spark.rdd.CartesianRDD
import spark.rdd.FilteredRDD
import spark.rdd.FlatMappedRDD
@ -73,11 +62,11 @@ import SparkContext._
* on RDD internals.
*/
abstract class RDD[T: ClassManifest](
@transient var sc: SparkContext,
var dependencies_ : List[Dependency[_]]
@transient private var sc: SparkContext,
@transient private var deps: Seq[Dependency[_]]
) extends Serializable with Logging {
/** Construct an RDD with just a one-to-one dependency on one parent */
def this(@transient oneParent: RDD[_]) =
this(oneParent.context , List(new OneToOneDependency(oneParent)))
@ -85,14 +74,20 @@ abstract class RDD[T: ClassManifest](
// Methods that should be implemented by subclasses of RDD
// =======================================================================
/** Function for computing a given partition. */
/** Implemented by subclasses to compute a given partition. */
def compute(split: Split, context: TaskContext): Iterator[T]
/** Set of partitions in this RDD. */
protected def getSplits(): Array[Split]
/**
* Implemented by subclasses to return the set of partitions in this RDD. This method will only
* be called once, so it is safe to implement a time-consuming computation in it.
*/
protected def getSplits: Array[Split]
/** How this RDD depends on any parent RDDs. */
protected def getDependencies(): List[Dependency[_]] = dependencies_
/**
* Implemented by subclasses to return how this RDD depends on parent RDDs. This method will only
* be called once, so it is safe to implement a time-consuming computation in it.
*/
protected def getDependencies: Seq[Dependency[_]] = deps
/** Optionally overridden by subclasses to specify placement preferences. */
protected def getPreferredLocations(split: Split): Seq[String] = Nil
@ -100,7 +95,6 @@ abstract class RDD[T: ClassManifest](
/** Optionally overridden by subclasses to specify how they are partitioned. */
val partitioner: Option[Partitioner] = None
// =======================================================================
// Methods and fields available on all RDDs
// =======================================================================
@ -108,6 +102,15 @@ abstract class RDD[T: ClassManifest](
/** A unique ID for this RDD (within its SparkContext). */
val id = sc.newRddId()
/** A friendly name for this RDD */
var name: String = null
/** Assign a name to this RDD */
def setName(_name: String) = {
name = _name
this
}
/**
* Set this RDD's storage level to persist its values across operations after the first time
* it is computed. Can only be called once on each RDD.
@ -119,6 +122,8 @@ abstract class RDD[T: ClassManifest](
"Cannot change storage level of an RDD after it was already assigned a level")
}
storageLevel = newLevel
// Register the RDD with the SparkContext
sc.persistentRdds(id) = this
this
}
@ -131,15 +136,24 @@ abstract class RDD[T: ClassManifest](
/** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */
def getStorageLevel = storageLevel
// Our dependencies and splits will be gotten by calling subclass's methods below, and will
// be overwritten when we're checkpointed
private var dependencies_ : Seq[Dependency[_]] = null
@transient private var splits_ : Array[Split] = null
/** An Option holding our checkpoint RDD, if we are checkpointed */
private def checkpointRDD: Option[RDD[T]] = checkpointData.flatMap(_.checkpointRDD)
/**
* Get the preferred location of a split, taking into account whether the
* Get the list of dependencies of this RDD, taking into account whether the
* RDD is checkpointed or not.
*/
final def preferredLocations(split: Split): Seq[String] = {
if (isCheckpointed) {
checkpointData.get.getPreferredLocations(split)
} else {
getPreferredLocations(split)
final def dependencies: Seq[Dependency[_]] = {
checkpointRDD.map(r => List(new OneToOneDependency(r))).getOrElse {
if (dependencies_ == null) {
dependencies_ = getDependencies
}
dependencies_
}
}
@ -148,22 +162,21 @@ abstract class RDD[T: ClassManifest](
* RDD is checkpointed or not.
*/
final def splits: Array[Split] = {
if (isCheckpointed) {
checkpointData.get.getSplits
} else {
getSplits
checkpointRDD.map(_.splits).getOrElse {
if (splits_ == null) {
splits_ = getSplits
}
splits_
}
}
/**
* Get the list of dependencies of this RDD, taking into account whether the
* Get the preferred location of a split, taking into account whether the
* RDD is checkpointed or not.
*/
final def dependencies: List[Dependency[_]] = {
if (isCheckpointed) {
dependencies_
} else {
getDependencies
final def preferredLocations(split: Split): Seq[String] = {
checkpointRDD.map(_.getPreferredLocations(split)).getOrElse {
getPreferredLocations(split)
}
}
@ -173,10 +186,19 @@ abstract class RDD[T: ClassManifest](
* subclasses of RDD.
*/
final def iterator(split: Split, context: TaskContext): Iterator[T] = {
if (storageLevel != StorageLevel.NONE) {
SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel)
} else {
computeOrReadCheckpoint(split, context)
}
}
/**
* Compute an RDD partition or read it from a checkpoint if the RDD is checkpointing.
*/
private[spark] def computeOrReadCheckpoint(split: Split, context: TaskContext): Iterator[T] = {
if (isCheckpointed) {
checkpointData.get.iterator(split, context)
} else if (storageLevel != StorageLevel.NONE) {
SparkEnv.get.cacheTracker.getOrCompute[T](this, split, context, storageLevel)
firstParent[T].iterator(split, context)
} else {
compute(split, context)
}
@ -348,6 +370,13 @@ abstract class RDD[T: ClassManifest](
*/
def toArray(): Array[T] = collect()
/**
* Return an RDD that contains all matching values by applying `f`.
*/
def collect[U: ClassManifest](f: PartialFunction[T, U]): RDD[U] = {
filter(f.isDefinedAt).map(f)
}
/**
* Reduces the elements of this RDD using the specified associative binary operator.
*/
@ -356,21 +385,23 @@ abstract class RDD[T: ClassManifest](
val reducePartition: Iterator[T] => Option[T] = iter => {
if (iter.hasNext) {
Some(iter.reduceLeft(cleanF))
}else {
} else {
None
}
}
val options = sc.runJob(this, reducePartition)
val results = new ArrayBuffer[T]
for (opt <- options; elem <- opt) {
results += elem
var jobResult: Option[T] = None
val mergeResult = (index: Int, taskResult: Option[T]) => {
if (taskResult != None) {
jobResult = jobResult match {
case Some(value) => Some(f(value, taskResult.get))
case None => taskResult
}
if (results.size == 0) {
throw new UnsupportedOperationException("empty collection")
} else {
return results.reduceLeft(cleanF)
}
}
sc.runJob(this, reducePartition, mergeResult)
// Get the final result out of our Option, or throw an exception if the RDD was empty
jobResult.getOrElse(throw new UnsupportedOperationException("empty collection"))
}
/**
* Aggregate the elements of each partition, and then the results for all the partitions, using a
@ -379,9 +410,13 @@ abstract class RDD[T: ClassManifest](
* modify t2.
*/
def fold(zeroValue: T)(op: (T, T) => T): T = {
// Clone the zero value since we will also be serializing it as part of tasks
var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
val cleanOp = sc.clean(op)
val results = sc.runJob(this, (iter: Iterator[T]) => iter.fold(zeroValue)(cleanOp))
return results.fold(zeroValue)(cleanOp)
val foldPartition = (iter: Iterator[T]) => iter.fold(zeroValue)(cleanOp)
val mergeResult = (index: Int, taskResult: T) => jobResult = op(jobResult, taskResult)
sc.runJob(this, foldPartition, mergeResult)
jobResult
}
/**
@ -393,11 +428,14 @@ abstract class RDD[T: ClassManifest](
* allocation.
*/
def aggregate[U: ClassManifest](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = {
// Clone the zero value since we will also be serializing it as part of tasks
var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
val cleanSeqOp = sc.clean(seqOp)
val cleanCombOp = sc.clean(combOp)
val results = sc.runJob(this,
(iter: Iterator[T]) => iter.aggregate(zeroValue)(cleanSeqOp, cleanCombOp))
return results.fold(zeroValue)(cleanCombOp)
val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
val mergeResult = (index: Int, taskResult: U) => jobResult = combOp(jobResult, taskResult)
sc.runJob(this, aggregatePartition, mergeResult)
jobResult
}
/**
@ -408,7 +446,7 @@ abstract class RDD[T: ClassManifest](
var result = 0L
while (iter.hasNext) {
result += 1L
iter.next
iter.next()
}
result
}).sum
@ -423,7 +461,7 @@ abstract class RDD[T: ClassManifest](
var result = 0L
while (iter.hasNext) {
result += 1L
iter.next
iter.next()
}
result
}
@ -529,23 +567,29 @@ abstract class RDD[T: ClassManifest](
.saveAsSequenceFile(path)
}
/**
* Creates tuples of the elements in this RDD by applying `f`.
*/
def keyBy[K](f: T => K): RDD[(K, T)] = {
map(x => (f(x), x))
}
/** A private method for tests, to look at the contents of each partition */
private[spark] def collectPartitions(): Array[Array[T]] = {
sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
}
/**
* Mark this RDD for checkpointing. The RDD will be saved to a file inside `checkpointDir`
* (set using setCheckpointDir()) and all references to its parent RDDs will be removed.
* This is used to truncate very long lineages. In the current implementation, Spark will save
* this RDD to a file (using saveAsObjectFile()) after the first job using this RDD is done.
* Hence, it is strongly recommended to use checkpoint() on RDDs when
* (i) checkpoint() is called before the any job has been executed on this RDD.
* (ii) This RDD has been made to persist in memory. Otherwise saving it on a file will
* require recomputation.
* Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint
* directory set with SparkContext.setCheckpointDir() and all references to its parent
* RDDs will be removed. This function must be called before any job has been
* executed on this RDD. It is strongly recommended that this RDD is persisted in
* memory, otherwise saving it on a file will require recomputation.
*/
def checkpoint() {
if (checkpointData.isEmpty) {
if (context.checkpointDir.isEmpty) {
throw new Exception("Checkpoint directory has not been set in the SparkContext")
} else if (checkpointData.isEmpty) {
checkpointData = Some(new RDDCheckpointData(this))
checkpointData.get.markForCheckpoint()
}
@ -554,15 +598,15 @@ abstract class RDD[T: ClassManifest](
/**
* Return whether this RDD has been checkpointed or not
*/
def isCheckpointed(): Boolean = {
if (checkpointData.isDefined) checkpointData.get.isCheckpointed() else false
def isCheckpointed: Boolean = {
checkpointData.map(_.isCheckpointed).getOrElse(false)
}
/**
* Gets the name of the file to which this RDD was checkpointed
*/
def getCheckpointFile(): Option[String] = {
if (checkpointData.isDefined) checkpointData.get.getCheckpointFile() else None
def getCheckpointFile: Option[String] = {
checkpointData.flatMap(_.getCheckpointFile)
}
// =======================================================================
@ -587,31 +631,52 @@ abstract class RDD[T: ClassManifest](
def context = sc
/**
* Performs the checkpointing of this RDD by saving this . It is called by the DAGScheduler
* Performs the checkpointing of this RDD by saving this. It is called by the DAGScheduler
* after a job using this RDD has completed (therefore the RDD has been materialized and
* potentially stored in memory). doCheckpoint() is called recursively on the parent RDDs.
*/
protected[spark] def doCheckpoint() {
if (checkpointData.isDefined) checkpointData.get.doCheckpoint()
private[spark] def doCheckpoint() {
if (checkpointData.isDefined) {
checkpointData.get.doCheckpoint()
} else {
dependencies.foreach(_.rdd.doCheckpoint())
}
}
/**
* Changes the dependencies of this RDD from its original parents to the new RDD
* (`newRDD`) created from the checkpoint file.
* Changes the dependencies of this RDD from its original parents to a new RDD (`newRDD`)
* created from the checkpoint file, and forget its old dependencies and splits.
*/
protected[spark] def changeDependencies(newRDD: RDD[_]) {
private[spark] def markCheckpointed(checkpointRDD: RDD[_]) {
clearDependencies()
dependencies_ = List(new OneToOneDependency(newRDD))
dependencies_ = null
splits_ = null
deps = null // Forget the constructor argument for dependencies too
}
/**
* Clears the dependencies of this RDD. This method must ensure that all references
* to the original parent RDDs is removed to enable the parent RDDs to be garbage
* collected. Subclasses of RDD may override this method for implementing their own cleaning
* logic. See [[spark.rdd.UnionRDD]] and [[spark.rdd.ShuffledRDD]] to get a better idea.
* logic. See [[spark.rdd.UnionRDD]] for an example.
*/
protected[spark] def clearDependencies() {
protected def clearDependencies() {
dependencies_ = null
}
/** A description of this RDD and its recursive dependencies for debugging. */
def toDebugString(): String = {
def debugString(rdd: RDD[_], prefix: String = ""): Seq[String] = {
Seq(prefix + rdd + " (" + rdd.splits.size + " splits)") ++
rdd.dependencies.flatMap(d => debugString(d.rdd, prefix + " "))
}
debugString(this).mkString("\n")
}
override def toString(): String = "%s%s[%d] at %s".format(
Option(name).map(_ + " ").getOrElse(""),
getClass.getSimpleName,
id,
origin)
}
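A small usage sketch for several RDD methods introduced above, namely setName, keyBy, the partial-function collect, and toDebugString (assumes an existing SparkContext sc; illustrative only):

val words = sc.parallelize(Seq("spark", "streaming", "rdd")).setName("words")
val byLength = words.keyBy(_.length)                                      // RDD[(Int, String)]
val longOnes = words.collect { case w if w.length > 4 => w.toUpperCase }  // partial-function collect
println(words.toDebugString())                                            // RDD and its recursive dependencies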

View file

@ -20,7 +20,7 @@ private[spark] object CheckpointState extends Enumeration {
* of the checkpointed RDD.
*/
private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T])
extends Logging with Serializable {
extends Logging with Serializable {
import CheckpointState._
@ -31,7 +31,7 @@ extends Logging with Serializable {
@transient var cpFile: Option[String] = None
// The CheckpointRDD created from the checkpoint file, that is, the new parent the associated RDD.
@transient var cpRDD: Option[RDD[T]] = None
var cpRDD: Option[RDD[T]] = None
// Mark the RDD for checkpointing
def markForCheckpoint() {
@ -41,12 +41,12 @@ extends Logging with Serializable {
}
// Is the RDD already checkpointed
def isCheckpointed(): Boolean = {
def isCheckpointed: Boolean = {
RDDCheckpointData.synchronized { cpState == Checkpointed }
}
// Get the file to which this RDD was checkpointed to as an Option
def getCheckpointFile(): Option[String] = {
def getCheckpointFile: Option[String] = {
RDDCheckpointData.synchronized { cpFile }
}
@ -63,7 +63,7 @@ extends Logging with Serializable {
}
// Save to file, and reload it as an RDD
val path = new Path(rdd.context.checkpointDir, "rdd-" + rdd.id).toString
val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id).toString
rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path) _)
val newRDD = new CheckpointRDD[T](rdd.context, path)
@ -71,7 +71,7 @@ extends Logging with Serializable {
RDDCheckpointData.synchronized {
cpFile = Some(path)
cpRDD = Some(newRDD)
rdd.changeDependencies(newRDD)
rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and splits
cpState = Checkpointed
RDDCheckpointData.clearTaskCaches()
logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id)
@ -79,7 +79,7 @@ extends Logging with Serializable {
}
// Get preferred location of a split after checkpointing
def getPreferredLocations(split: Split) = {
def getPreferredLocations(split: Split): Seq[String] = {
RDDCheckpointData.synchronized {
cpRDD.get.preferredLocations(split)
}
@ -91,9 +91,10 @@ extends Logging with Serializable {
}
}
// Get iterator. This is called at the worker nodes.
def iterator(split: Split, context: TaskContext): Iterator[T] = {
rdd.firstParent[T].iterator(split, context)
def checkpointRDD: Option[RDD[T]] = {
RDDCheckpointData.synchronized {
cpRDD
}
}
}
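Checkpointing as reorganized above is driven from user code roughly as follows (the checkpoint directory is hypothetical; illustrative only):

sc.setCheckpointDir("hdfs:///tmp/spark-checkpoints")   // hypothetical path
val numbers = sc.parallelize(1 to 1000).map(_ + 1)
numbers.cache()        // recommended, so writing the checkpoint file does not recompute the RDD
numbers.checkpoint()   // must be called before any job has been run on this RDD
numbers.count()        // after this job, doCheckpoint() saves the data and truncates the lineage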

View file

@ -42,7 +42,13 @@ class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : Cla
if (classOf[Writable].isAssignableFrom(classManifest[T].erasure)) {
classManifest[T].erasure
} else {
implicitly[T => Writable].getClass.getMethods()(0).getReturnType
// We get the type of the Writable class by looking at the apply method which converts
// from T to Writable. Since we have two apply methods we filter out the one which
// is of the form "java.lang.Object apply(java.lang.Object)"
implicitly[T => Writable].getClass.getDeclaredMethods().filter(
m => m.getReturnType().toString != "java.lang.Object" &&
m.getName() == "apply")(0).getReturnType
}
// TODO: use something like WritableConverter to avoid reflection
}

View file

@ -9,7 +9,6 @@ import java.util.Random
import javax.management.MBeanServer
import java.lang.management.ManagementFactory
import com.sun.management.HotSpotDiagnosticMXBean
import scala.collection.mutable.ArrayBuffer
@ -76,12 +75,20 @@ private[spark] object SizeEstimator extends Logging {
if (System.getProperty("spark.test.useCompressedOops") != null) {
return System.getProperty("spark.test.useCompressedOops").toBoolean
}
try {
val hotSpotMBeanName = "com.sun.management:type=HotSpotDiagnostic"
val server = ManagementFactory.getPlatformMBeanServer()
// NOTE: This should throw an exception in non-Sun JVMs
val hotSpotMBeanClass = Class.forName("com.sun.management.HotSpotDiagnosticMXBean")
val getVMMethod = hotSpotMBeanClass.getDeclaredMethod("getVMOption",
Class.forName("java.lang.String"))
val bean = ManagementFactory.newPlatformMXBeanProxy(server,
hotSpotMBeanName, classOf[HotSpotDiagnosticMXBean])
return bean.getVMOption("UseCompressedOops").getValue.toBoolean
hotSpotMBeanName, hotSpotMBeanClass)
// TODO: We could use reflection on the VMOption returned ?
return getVMMethod.invoke(bean, "UseCompressedOops").toString.contains("true")
} catch {
case e: Exception => {
// Guess whether they've enabled UseCompressedOops based on whether maxMemory < 32 GB

View file

@ -1,6 +1,7 @@
package spark
import java.io._
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger
import java.net.{URI, URLClassLoader}
import java.lang.ref.WeakReference
@ -8,6 +9,7 @@ import java.lang.ref.WeakReference
import scala.collection.Map
import scala.collection.generic.Growable
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.collection.JavaConversions._
import akka.actor.Actor
import akka.actor.Actor._
@ -42,6 +44,9 @@ import scheduler.{ResultTask, ShuffleMapTask, DAGScheduler, TaskScheduler}
import spark.scheduler.local.LocalScheduler
import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler}
import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
import storage.BlockManagerUI
import util.{MetadataCleaner, TimeStampedHashMap}
import storage.{StorageStatus, StorageUtils, RDDInfo}
/**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
@ -57,57 +62,53 @@ import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend
class SparkContext(
val master: String,
val jobName: String,
val sparkHome: String,
val jars: Seq[String],
environment: Map[String, String])
val sparkHome: String = null,
val jars: Seq[String] = Nil,
environment: Map[String, String] = Map())
extends Logging {
/**
* @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
* @param jobName A name for your job, to display on the cluster web UI
* @param sparkHome Location where Spark is installed on cluster nodes.
* @param jars Collection of JARs to send to the cluster. These can be paths on the local file
* system or HDFS, HTTP, HTTPS, or FTP URLs.
*/
def this(master: String, jobName: String, sparkHome: String, jars: Seq[String]) =
this(master, jobName, sparkHome, jars, Map())
/**
* @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
* @param jobName A name for your job, to display on the cluster web UI
*/
def this(master: String, jobName: String) = this(master, jobName, null, Nil, Map())
// Ensure logging is initialized before we spawn any threads
initLogging()
// Set Spark master host and port system properties
if (System.getProperty("spark.master.host") == null) {
System.setProperty("spark.master.host", Utils.localIpAddress)
// Set Spark driver host and port system properties
if (System.getProperty("spark.driver.host") == null) {
System.setProperty("spark.driver.host", Utils.localIpAddress)
}
if (System.getProperty("spark.master.port") == null) {
System.setProperty("spark.master.port", "0")
if (System.getProperty("spark.driver.port") == null) {
System.setProperty("spark.driver.port", "0")
}
private val isLocal = (master == "local" || master.startsWith("local["))
// Create the Spark execution environment (cache, map output tracker, etc)
private[spark] val env = SparkEnv.createFromSystemProperties(
System.getProperty("spark.master.host"),
System.getProperty("spark.master.port").toInt,
"<driver>",
System.getProperty("spark.driver.host"),
System.getProperty("spark.driver.port").toInt,
true,
isLocal)
SparkEnv.set(env)
// Start the BlockManager UI
private[spark] val ui = new BlockManagerUI(
env.actorSystem, env.blockManager.master.driverActor, this)
ui.start()
// Used to store a URL for each static file/jar together with the file's local timestamp
private[spark] val addedFiles = HashMap[String, Long]()
private[spark] val addedJars = HashMap[String, Long]()
// Keeps track of all persisted RDDs
private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]()
private[spark] val metadataCleaner = new MetadataCleaner("SparkContext", this.cleanup)
// Add each JAR given through the constructor
jars.foreach { addJar(_) }
// Environment variables to pass to our executors
private[spark] val executorEnvs = HashMap[String, String]()
// Note: SPARK_MEM is included for Mesos, but overwritten for standalone mode in ExecutorRunner
for (key <- Seq("SPARK_MEM", "SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS",
"SPARK_TESTING")) {
val value = System.getenv(key)
@ -127,6 +128,8 @@ class SparkContext(
val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r
// Regular expression for connecting to Spark deploy clusters
val SPARK_REGEX = """(spark://.*)""".r
// Regular expression for connecting to a Mesos cluster
val MESOS_REGEX = """(mesos://.*)""".r
master match {
case "local" =>
@ -167,6 +170,9 @@ class SparkContext(
scheduler
case _ =>
if (MESOS_REGEX.findFirstIn(master).isEmpty) {
logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master))
}
MesosNativeLibrary.load()
val scheduler = new ClusterScheduler(this)
val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean
@ -183,8 +189,28 @@ class SparkContext(
taskScheduler.start()
private var dagScheduler = new DAGScheduler(taskScheduler)
dagScheduler.start()
private[spark] var checkpointDir: String = null
/** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
val hadoopConfiguration = {
val conf = new Configuration()
// Explicitly check for S3 environment variables
if (System.getenv("AWS_ACCESS_KEY_ID") != null && System.getenv("AWS_SECRET_ACCESS_KEY") != null) {
conf.set("fs.s3.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID"))
conf.set("fs.s3n.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID"))
conf.set("fs.s3.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY"))
conf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY"))
}
// Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar"
for (key <- System.getProperties.toMap[String, String].keys if key.startsWith("spark.hadoop.")) {
conf.set(key.substring("spark.hadoop.".length), System.getProperty(key))
}
val bufferSize = System.getProperty("spark.buffer.size", "65536")
conf.set("io.file.buffer.size", bufferSize)
conf
}
private[spark] var checkpointDir: Option[String] = None
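The spark.hadoop.* copying above means Hadoop settings can be injected through Java system properties; a hypothetical example (property key and value chosen for illustration):

// Hypothetical: a "spark.hadoop.foo=bar" system property surfaces as "foo=bar" in hadoopConfiguration.
System.setProperty("spark.hadoop.dfs.replication", "2")
val sc = new SparkContext("local", "HadoopConfDemo")
println(sc.hadoopConfiguration.get("dfs.replication"))      // "2"
println(sc.hadoopConfiguration.get("io.file.buffer.size"))  // "65536" unless spark.buffer.size overrides it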
// Methods for creating RDDs
@ -238,10 +264,8 @@ class SparkContext(
valueClass: Class[V],
minSplits: Int = defaultMinSplits
) : RDD[(K, V)] = {
val conf = new JobConf()
val conf = new JobConf(hadoopConfiguration)
FileInputFormat.setInputPaths(conf, path)
val bufferSize = System.getProperty("spark.buffer.size", "65536")
conf.set("io.file.buffer.size", bufferSize)
new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits)
}
@ -282,8 +306,7 @@ class SparkContext(
path,
fm.erasure.asInstanceOf[Class[F]],
km.erasure.asInstanceOf[Class[K]],
vm.erasure.asInstanceOf[Class[V]],
new Configuration)
vm.erasure.asInstanceOf[Class[V]])
}
/**
@ -295,7 +318,7 @@ class SparkContext(
fClass: Class[F],
kClass: Class[K],
vClass: Class[V],
conf: Configuration): RDD[(K, V)] = {
conf: Configuration = hadoopConfiguration): RDD[(K, V)] = {
val job = new NewHadoopJob(conf)
NewFileInputFormat.addInputPath(job, new Path(path))
val updatedConf = job.getConfiguration
@ -307,7 +330,7 @@ class SparkContext(
* and extra configuration options to pass to the input format.
*/
def newAPIHadoopRDD[K, V, F <: NewInputFormat[K, V]](
conf: Configuration,
conf: Configuration = hadoopConfiguration,
fClass: Class[F],
kClass: Class[K],
vClass: Class[V]): RDD[(K, V)] = {
@ -390,14 +413,14 @@ class SparkContext(
/**
* Create an [[spark.Accumulator]] variable of a given type, which tasks can "add" values
* to using the `+=` method. Only the master can access the accumulator's `value`.
* to using the `+=` method. Only the driver can access the accumulator's `value`.
*/
def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) =
new Accumulator(initialValue, param)
/**
* Create an [[spark.Accumulable]] shared variable, to which tasks can add values with `+=`.
* Only the master can access the accumulable's `value`.
* Only the driver can access the accumulable's `value`.
* @tparam T accumulator type
* @tparam R type that can be added to the accumulator
*/
@ -422,9 +445,10 @@ class SparkContext(
def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal)
/**
* Add a file to be downloaded into the working directory of this Spark job on every node.
* Add a file to be downloaded with this Spark job on every node.
* The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
* filesystems), or an HTTP, HTTPS or FTP URI.
* filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs,
* use `SparkFiles.get(path)` to find its download location.
*/
def addFile(path: String) {
val uri = new URI(path)
@ -437,7 +461,7 @@ class SparkContext(
// Fetch the file locally in case a job is executed locally.
// Jobs that run through LocalScheduler will already fetch the required dependencies,
// but jobs run in DAGScheduler.runLocally() will not so we must fetch the files here.
Utils.fetchFile(path, new File("."))
Utils.fetchFile(path, new File(SparkFiles.getRootDirectory))
logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
}
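Usage sketch (path and file name are hypothetical, `sc` is an existing SparkContext): the driver registers the file once and each task resolves its local copy through SparkFiles.

    import spark.SparkFiles

    sc.addFile("hdfs://namenode:9000/config/lookup.txt")    // hypothetical path
    val localPaths = sc.parallelize(1 to 4)
      .map(i => (i, SparkFiles.get("lookup.txt")))          // resolved on each executor
      .collect()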
@ -446,12 +470,27 @@ class SparkContext(
* Return a map from the slave to the max memory available for caching and the remaining
* memory available for caching.
*/
def getSlavesMemoryStatus: Map[String, (Long, Long)] = {
def getExecutorMemoryStatus: Map[String, (Long, Long)] = {
env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) =>
(blockManagerId.ip + ":" + blockManagerId.port, mem)
}
}
/**
* Return information about what RDDs are cached, if they are in mem or on disk, how much space
* they take, etc.
*/
def getRDDStorageInfo : Array[RDDInfo] = {
StorageUtils.rddInfoFromStorageStatus(getExecutorStorageStatus, this)
}
/**
* Return information about blocks stored in all of the slaves
*/
def getExecutorStorageStatus : Array[StorageStatus] = {
env.blockManager.master.getStorageStatus
}
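For example (a sketch assuming an active SparkContext `sc`), these can be polled to log cache usage:

    for ((executor, (maxMem, remaining)) <- sc.getExecutorMemoryStatus) {
      println(executor + ": " + remaining + " of " + maxMem + " bytes free for caching")
    }
    sc.getRDDStorageInfo.foreach(info => println(info))   // one summary line per cached RDD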
/**
* Clear the job's list of files added by `addFile` so that they do not get downloaded to
* any new nodes.
@ -486,6 +525,7 @@ class SparkContext(
/** Shut down the SparkContext. */
def stop() {
if (dagScheduler != null) {
metadataCleaner.cancel()
dagScheduler.stop()
dagScheduler = null
taskScheduler = null
@ -521,10 +561,30 @@ class SparkContext(
}
/**
* Run a function on a given set of partitions in an RDD and return the results. This is the main
* entry point to the scheduler, by which all actions get launched. The allowLocal flag specifies
* whether the scheduler can run the computation on the master rather than shipping it out to the
* cluster, for short actions like first().
* Run a function on a given set of partitions in an RDD and pass the results to the given
* handler function. This is the main entry point for all actions in Spark. The allowLocal
* flag specifies whether the scheduler can run the computation on the driver rather than
* shipping it out to the cluster, for short actions like first().
*/
def runJob[T, U: ClassManifest](
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
allowLocal: Boolean,
resultHandler: (Int, U) => Unit) {
val callSite = Utils.getSparkCallSite
logInfo("Starting job: " + callSite)
val start = System.nanoTime
val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler)
logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
rdd.doCheckpoint()
result
}
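A hedged sketch of this entry point (partition function and handler are illustrative, `sc` is an existing SparkContext): each partition's result is handed to the handler as it completes, without building an array first.

    import spark.TaskContext

    val data = sc.parallelize(1 to 1000, 8)
    sc.runJob(
      data,
      (context: TaskContext, iter: Iterator[Int]) => iter.sum,   // per-partition work
      Seq(0, 1),                                                  // partitions to run
      false,                                                      // allowLocal
      (index: Int, partSum: Int) => println("partition " + index + " -> " + partSum))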
/**
* Run a function on a given set of partitions in an RDD and return the results as an array. The
* allowLocal flag specifies whether the scheduler can run the computation on the driver rather
* than shipping it out to the cluster, for short actions like first().
*/
def runJob[T, U: ClassManifest](
rdd: RDD[T],
@ -532,13 +592,9 @@ class SparkContext(
partitions: Seq[Int],
allowLocal: Boolean
): Array[U] = {
val callSite = Utils.getSparkCallSite
logInfo("Starting job: " + callSite)
val start = System.nanoTime
val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal)
logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
rdd.doCheckpoint()
result
val results = new Array[U](partitions.size)
runJob[T, U](rdd, func, partitions, allowLocal, (index, res) => results(index) = res)
results
}
/**
@ -568,6 +624,29 @@ class SparkContext(
runJob(rdd, func, 0 until rdd.splits.size, false)
}
/**
* Run a job on all partitions in an RDD and pass the results to a handler function.
*/
def runJob[T, U: ClassManifest](
rdd: RDD[T],
processPartition: (TaskContext, Iterator[T]) => U,
resultHandler: (Int, U) => Unit)
{
runJob[T, U](rdd, processPartition, 0 until rdd.splits.size, false, resultHandler)
}
/**
* Run a job on all partitions in an RDD and pass the results to a handler function.
*/
def runJob[T, U: ClassManifest](
rdd: RDD[T],
processPartition: Iterator[T] => U,
resultHandler: (Int, U) => Unit)
{
val processFunc = (context: TaskContext, iter: Iterator[T]) => processPartition(iter)
runJob[T, U](rdd, processFunc, 0 until rdd.splits.size, false, resultHandler)
}
/**
* Run a job that can return approximate results.
*/
@ -595,10 +674,11 @@ class SparkContext(
}
/**
* Set the directory under which RDDs are going to be checkpointed. This method will
* create this directory and will throw an exception of the path already exists (to avoid
* overwriting existing files may be overwritten). The directory will be deleted on exit
* if indicated.
* Set the directory under which RDDs are going to be checkpointed. The directory must
* be a HDFS path if running on a cluster. If the directory does not exist, it will
* be created. If the directory exists and useExisting is set to true, then the
* existing directory will be used. Otherwise an exception will be thrown to
* prevent accidental overriding of checkpoint files in the existing directory.
*/
def setCheckpointDir(dir: String, useExisting: Boolean = false) {
val path = new Path(dir)
@ -610,7 +690,7 @@ class SparkContext(
fs.mkdirs(path)
}
}
checkpointDir = dir
checkpointDir = Some(dir)
}
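Brief usage sketch (the HDFS path is hypothetical): set the directory first, then mark RDDs for checkpointing before running jobs on them.

    sc.setCheckpointDir("hdfs://namenode:9000/spark/checkpoints")
    val data = sc.parallelize(1 to 100, 4).map(_ * 2)
    data.checkpoint()   // marked; actually saved when the first job on `data` runs
    data.count()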
/** Default level of parallelism to use when not given by user (e.g. for reduce tasks) */
@ -627,6 +707,11 @@ class SparkContext(
/** Register a new RDD, returning its RDD ID */
private[spark] def newRddId(): Int = nextRddId.getAndIncrement()
/** Called by MetadataCleaner to clean up the persistentRdds map periodically */
private[spark] def cleanup(cleanupTime: Long) {
persistentRdds.clearOldValues(cleanupTime)
}
}
/**
@ -645,6 +730,16 @@ object SparkContext {
def zero(initialValue: Int) = 0
}
implicit object LongAccumulatorParam extends AccumulatorParam[Long] {
def addInPlace(t1: Long, t2: Long) = t1 + t2
def zero(initialValue: Long) = 0l
}
implicit object FloatAccumulatorParam extends AccumulatorParam[Float] {
def addInPlace(t1: Float, t2: Float) = t1 + t2
def zero(initialValue: Float) = 0f
}
// TODO: Add AccumulatorParams for other types, e.g. lists and strings
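As a sketch of the extension point the TODO refers to, a hypothetical String-concatenating param (not part of the codebase) would look like this when defined in some enclosing object:

    implicit object StringAccumulatorParam extends AccumulatorParam[String] {
      def addInPlace(t1: String, t2: String) = t1 + t2
      def zero(initialValue: String) = ""
    }
    // With this implicit in scope, sc.accumulator("") yields an Accumulator[String].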
implicit def rddToPairRDDFunctions[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) =


@ -19,27 +19,23 @@ import spark.util.AkkaUtils
* SparkEnv.get (e.g. after creating a SparkContext) and set it with SparkEnv.set.
*/
class SparkEnv (
val executorId: String,
val actorSystem: ActorSystem,
val serializer: Serializer,
val closureSerializer: Serializer,
val cacheTracker: CacheTracker,
val cacheManager: CacheManager,
val mapOutputTracker: MapOutputTracker,
val shuffleFetcher: ShuffleFetcher,
val broadcastManager: BroadcastManager,
val blockManager: BlockManager,
val connectionManager: ConnectionManager,
val httpFileServer: HttpFileServer
val httpFileServer: HttpFileServer,
val sparkFilesDir: String
) {
/** No-parameter constructor for unit tests. */
def this() = {
this(null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null, null)
}
def stop() {
httpFileServer.stop()
mapOutputTracker.stop()
cacheTracker.stop()
shuffleFetcher.stop()
broadcastManager.stop()
blockManager.stop()
@ -63,17 +59,18 @@ object SparkEnv extends Logging {
}
def createFromSystemProperties(
executorId: String,
hostname: String,
port: Int,
isMaster: Boolean,
isLocal: Boolean
) : SparkEnv = {
isDriver: Boolean,
isLocal: Boolean): SparkEnv = {
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port)
// Bit of a hack: If this is the master and our port was 0 (meaning bind to any free port),
// figure out which port number Akka actually bound to and set spark.master.port to it.
if (isMaster && port == 0) {
System.setProperty("spark.master.port", boundPort.toString)
// Bit of a hack: If this is the driver and our port was 0 (meaning bind to any free port),
// figure out which port number Akka actually bound to and set spark.driver.port to it.
if (isDriver && port == 0) {
System.setProperty("spark.driver.port", boundPort.toString)
}
val classLoader = Thread.currentThread.getContextClassLoader
@ -87,23 +84,22 @@ object SparkEnv extends Logging {
val serializer = instantiateClass[Serializer]("spark.serializer", "spark.JavaSerializer")
val masterIp: String = System.getProperty("spark.master.host", "localhost")
val masterPort: Int = System.getProperty("spark.master.port", "7077").toInt
val driverIp: String = System.getProperty("spark.driver.host", "localhost")
val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt
val blockManagerMaster = new BlockManagerMaster(
actorSystem, isMaster, isLocal, masterIp, masterPort)
val blockManager = new BlockManager(actorSystem, blockManagerMaster, serializer)
actorSystem, isDriver, isLocal, driverIp, driverPort)
val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, serializer)
val connectionManager = blockManager.connectionManager
val broadcastManager = new BroadcastManager(isMaster)
val broadcastManager = new BroadcastManager(isDriver)
val closureSerializer = instantiateClass[Serializer](
"spark.closure.serializer", "spark.JavaSerializer")
val cacheTracker = new CacheTracker(actorSystem, isMaster, blockManager)
blockManager.cacheTracker = cacheTracker
val cacheManager = new CacheManager(blockManager)
val mapOutputTracker = new MapOutputTracker(actorSystem, isMaster)
val mapOutputTracker = new MapOutputTracker(actorSystem, isDriver)
val shuffleFetcher = instantiateClass[ShuffleFetcher](
"spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher")
@ -112,6 +108,15 @@ object SparkEnv extends Logging {
httpFileServer.initialize()
System.setProperty("spark.fileserver.uri", httpFileServer.serverUri)
// Set the sparkFiles directory, used when downloading dependencies. In local mode,
// this is a temporary directory; in distributed mode, this is the executor's current working
// directory.
val sparkFilesDir: String = if (isDriver) {
Utils.createTempDir().getAbsolutePath
} else {
"."
}
// Warn about deprecated spark.cache.class property
if (System.getProperty("spark.cache.class") != null) {
logWarning("The spark.cache.class property is no longer being used! Specify storage " +
@ -119,15 +124,17 @@ object SparkEnv extends Logging {
}
new SparkEnv(
executorId,
actorSystem,
serializer,
closureSerializer,
cacheTracker,
cacheManager,
mapOutputTracker,
shuffleFetcher,
broadcastManager,
blockManager,
connectionManager,
httpFileServer)
httpFileServer,
sparkFilesDir)
}
}


@ -0,0 +1,25 @@
package spark;
import java.io.File;
/**
* Resolves paths to files added through `SparkContext.addFile()`.
*/
public class SparkFiles {
private SparkFiles() {}
/**
* Get the absolute path of a file added through `SparkContext.addFile()`.
*/
public static String get(String filename) {
return new File(getRootDirectory(), filename).getAbsolutePath();
}
/**
* Get the root directory that contains files added through `SparkContext.addFile()`.
*/
public static String getRootDirectory() {
return SparkEnv.get().sparkFilesDir();
}
}


@ -5,8 +5,7 @@ import scala.collection.mutable.ArrayBuffer
class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Long) extends Serializable {
@transient
val onCompleteCallbacks = new ArrayBuffer[() => Unit]
@transient val onCompleteCallbacks = new ArrayBuffer[() => Unit]
// Add a callback function to be executed on task completion. An example use
// is for HadoopRDD to register a callback to close the input stream.


@ -1,7 +1,7 @@
package spark
import java.io._
import java.net.{NetworkInterface, InetAddress, URL, URI}
import java.net._
import java.util.{Locale, Random, UUID}
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
import org.apache.hadoop.conf.Configuration
@ -10,6 +10,9 @@ import scala.collection.mutable.ArrayBuffer
import scala.collection.JavaConversions._
import scala.io.Source
import com.google.common.io.Files
import com.google.common.util.concurrent.ThreadFactoryBuilder
import scala.Some
import spark.serializer.SerializerInstance
/**
* Various utility methods used by Spark.
@ -111,20 +114,6 @@ private object Utils extends Logging {
}
}
/** Copy a file on the local file system */
def copyFile(source: File, dest: File) {
val in = new FileInputStream(source)
val out = new FileOutputStream(dest)
copyStream(in, out, true)
}
/** Download a file from a given URL to the local filesystem */
def downloadFile(url: URL, localPath: String) {
val in = url.openStream()
val out = new FileOutputStream(localPath)
Utils.copyStream(in, out, true)
}
/**
* Download a file requested by the executor. Supports fetching the file in a variety of ways,
* including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
@ -134,7 +123,7 @@ private object Utils extends Logging {
*/
def fetchFile(url: String, targetDir: File) {
val filename = url.split("/").last
val tempDir = System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir"))
val tempDir = getLocalDir
val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir))
val targetFile = new File(targetDir, filename)
val uri = new URI(url)
@ -201,7 +190,16 @@ private object Utils extends Logging {
Utils.execute(Seq("tar", "-xf", filename), targetDir)
}
// Make the file executable - That's necessary for scripts
FileUtil.chmod(filename, "a+x")
FileUtil.chmod(targetFile.getAbsolutePath, "a+x")
}
/**
* Get a temporary directory using Spark's spark.local.dir property, if set. This will always
* return a single directory, even though the spark.local.dir property might be a list of
* multiple paths.
*/
def getLocalDir: String = {
System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")).split(',')(0)
}
/**
@ -242,7 +240,8 @@ private object Utils extends Logging {
// Address resolves to something like 127.0.1.1, which happens on Debian; try to find
// a better address using the local network interfaces
for (ni <- NetworkInterface.getNetworkInterfaces) {
for (addr <- ni.getInetAddresses if !addr.isLinkLocalAddress && !addr.isLoopbackAddress) {
for (addr <- ni.getInetAddresses if !addr.isLinkLocalAddress &&
!addr.isLoopbackAddress && addr.isInstanceOf[Inet4Address]) {
// We've found an address that looks reasonable!
logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" +
" a loopback address: " + address.getHostAddress + "; using " + addr.getHostAddress +
@ -277,29 +276,14 @@ private object Utils extends Logging {
customHostname.getOrElse(InetAddress.getLocalHost.getHostName)
}
/**
* Returns a standard ThreadFactory except all threads are daemons.
*/
private def newDaemonThreadFactory: ThreadFactory = {
new ThreadFactory {
def newThread(r: Runnable): Thread = {
var t = Executors.defaultThreadFactory.newThread (r)
t.setDaemon (true)
return t
}
}
}
private[spark] val daemonThreadFactory: ThreadFactory =
new ThreadFactoryBuilder().setDaemon(true).build()
/**
* Wrapper over newCachedThreadPool.
*/
def newDaemonCachedThreadPool(): ThreadPoolExecutor = {
var threadPool = Executors.newCachedThreadPool.asInstanceOf[ThreadPoolExecutor]
threadPool.setThreadFactory (newDaemonThreadFactory)
return threadPool
}
def newDaemonCachedThreadPool(): ThreadPoolExecutor =
Executors.newCachedThreadPool(daemonThreadFactory).asInstanceOf[ThreadPoolExecutor]
/**
* Return the string to tell how long has passed in seconds. The passing parameter should be in
@ -312,13 +296,8 @@ private object Utils extends Logging {
/**
* Wrapper over newFixedThreadPool.
*/
def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = {
var threadPool = Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor]
threadPool.setThreadFactory(newDaemonThreadFactory)
return threadPool
}
def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor =
Executors.newFixedThreadPool(nThreads, daemonThreadFactory).asInstanceOf[ThreadPoolExecutor]
/**
* Delete a file or directory and its contents recursively.
@ -454,4 +433,25 @@ private object Utils extends Logging {
}
"%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine)
}
/**
* Try to find a free port to bind to on the local host. This should ideally never be needed,
* except that, unfortunately, some of the networking libraries we currently rely on (e.g. Spray)
* don't let users bind to port 0 and then figure out which free port they actually bound to.
* We work around this by binding a ServerSocket and immediately unbinding it. This is *not*
* necessarily guaranteed to work, but it's the best we can do.
*/
def findFreePort(): Int = {
val socket = new ServerSocket(0)
val portBound = socket.getLocalPort
socket.close()
portBound
}
/**
* Clone an object using a Spark serializer.
*/
def clone[T](value: T, serializer: SerializerInstance): T = {
serializer.deserialize[T](serializer.serialize(value))
}
}


@ -471,6 +471,16 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
implicit def toOrdered(x: K): Ordered[K] = new KeyOrdering(x)
fromRDD(new OrderedRDDFunctions(rdd).sortByKey(ascending))
}
/**
* Return an RDD with the keys of each tuple.
*/
def keys(): JavaRDD[K] = JavaRDD.fromRDD[K](rdd.map(_._1))
/**
* Return an RDD with the values of each tuple.
*/
def values(): JavaRDD[V] = JavaRDD.fromRDD[V](rdd.map(_._2))
}
object JavaPairRDD {


@ -12,7 +12,7 @@ import spark.storage.StorageLevel
import com.google.common.base.Optional
trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround[T] {
def wrapRDD(rdd: RDD[T]): This
implicit val classManifest: ClassManifest[T]
@ -82,10 +82,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
}
/**
* Return a new RDD by first applying a function to all elements of this
* RDD, and then flattening the results.
* Part of the workaround for SPARK-668; called in PairFlatMapWorkaround.java.
*/
def flatMap[K, V](f: PairFlatMapFunction[T, K, V]): JavaPairRDD[K, V] = {
private[spark] def doFlatMap[K, V](f: PairFlatMapFunction[T, K, V]): JavaPairRDD[K, V] = {
import scala.collection.JavaConverters._
def fn = (x: T) => f.apply(x).asScala
def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]]
@ -301,21 +300,26 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
def saveAsObjectFile(path: String) = rdd.saveAsObjectFile(path)
/**
* Mark this RDD for checkpointing. The RDD will be saved to a file inside `checkpointDir`
* (set using setCheckpointDir()) and all references to its parent RDDs will be removed.
* This is used to truncate very long lineages. In the current implementation, Spark will save
* this RDD to a file (using saveAsObjectFile()) after the first job using this RDD is done.
* Hence, it is strongly recommended to use checkpoint() on RDDs when
* (i) checkpoint() is called before the any job has been executed on this RDD.
* (ii) This RDD has been made to persist in memory. Otherwise saving it on a file will
* require recomputation.
* Creates tuples of the elements in this RDD by applying `f`.
*/
def keyBy[K](f: JFunction[T, K]): JavaPairRDD[K, T] = {
implicit val kcm: ClassManifest[K] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
JavaPairRDD.fromRDD(rdd.keyBy(f))
}
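For example (a sketch on the underlying Scala RDD, assuming an existing SparkContext `sc`):

    val words = sc.parallelize(Seq("spark", "rdd", "broadcast"))
    val byLength = words.keyBy(_.length)   // (5, "spark"), (3, "rdd"), (9, "broadcast")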
/**
* Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint
* directory set with SparkContext.setCheckpointDir() and all references to its parent
* RDDs will be removed. This function must be called before any job has been
* executed on this RDD. It is strongly recommended that this RDD is persisted in
* memory, otherwise saving it on a file will require recomputation.
*/
def checkpoint() = rdd.checkpoint()
/**
* Return whether this RDD has been checkpointed or not
*/
def isCheckpointed(): Boolean = rdd.isCheckpointed()
def isCheckpointed: Boolean = rdd.isCheckpointed
/**
* Gets the name of the file to which this RDD was checkpointed
@ -326,4 +330,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
case _ => Optional.absent()
}
}
/** A description of this RDD and its recursive dependencies for debugging. */
def toDebugString(): String = {
rdd.toDebugString()
}
}


@ -277,6 +277,19 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
def doubleAccumulator(initialValue: Double): Accumulator[java.lang.Double] =
sc.accumulator(initialValue)(DoubleAccumulatorParam).asInstanceOf[Accumulator[java.lang.Double]]
/**
* Create an [[spark.Accumulator]] integer variable, which tasks can "add" values
* to using the `add` method. Only the master can access the accumulator's `value`.
*/
def accumulator(initialValue: Int): Accumulator[java.lang.Integer] = intAccumulator(initialValue)
/**
* Create an [[spark.Accumulator]] double variable, which tasks can "add" values
* to using the `add` method. Only the master can access the accumulator's `value`.
*/
def accumulator(initialValue: Double): Accumulator[java.lang.Double] =
doubleAccumulator(initialValue)
/**
* Create an [[spark.Accumulator]] variable of a given type, which tasks can "add" values
* to using the `add` method. Only the master can access the accumulator's `value`.
@ -310,9 +323,10 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
def getSparkHome(): Option[String] = sc.getSparkHome()
/**
* Add a file to be downloaded into the working directory of this Spark job on every node.
* Add a file to be downloaded with this Spark job on every node.
* The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
* filesystems), or an HTTP, HTTPS or FTP URI.
* filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs,
* use `SparkFiles.get(path)` to find its download location.
*/
def addFile(path: String) {
sc.addFile(path)
@ -344,20 +358,28 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
}
/**
* Set the directory under which RDDs are going to be checkpointed. This method will
* create this directory and will throw an exception of the path already exists (to avoid
* overwriting existing files may be overwritten). The directory will be deleted on exit
* if indicated.
* Returns the Hadoop configuration used for the Hadoop code (e.g. file systems) we reuse.
*/
def hadoopConfiguration(): Configuration = {
sc.hadoopConfiguration
}
/**
* Set the directory under which RDDs are going to be checkpointed. The directory must
* be a HDFS path if running on a cluster. If the directory does not exist, it will
* be created. If the directory exists and useExisting is set to true, then the
* existing directory will be used. Otherwise an exception will be thrown to
* prevent accidental overriding of checkpoint files in the existing directory.
*/
def setCheckpointDir(dir: String, useExisting: Boolean) {
sc.setCheckpointDir(dir, useExisting)
}
/**
* Set the directory under which RDDs are going to be checkpointed. This method will
* create this directory and will throw an exception of the path already exists (to avoid
* overwriting existing files may be overwritten). The directory will be deleted on exit
* if indicated.
* Set the directory under which RDDs are going to be checkpointed. The directory must
* be a HDFS path if running on a cluster. If the directory does not exist, it will
* be created. If the directory exists, an exception will be thrown to prevent accidental
* overriding of checkpoint files.
*/
def setCheckpointDir(dir: String) {
sc.setCheckpointDir(dir)


@ -0,0 +1,20 @@
package spark.api.java;
import spark.api.java.JavaPairRDD;
import spark.api.java.JavaRDDLike;
import spark.api.java.function.PairFlatMapFunction;
import java.io.Serializable;
/**
* Workaround for SPARK-668.
*/
class PairFlatMapWorkaround<T> implements Serializable {
/**
* Return a new RDD by first applying a function to all elements of this
* RDD, and then flattening the results.
*/
public <K, V> JavaPairRDD<K, V> flatMap(PairFlatMapFunction<T, K, V> f) {
return ((JavaRDDLike <T, ?>) this).doFlatMap(f);
}
}


@ -17,4 +17,15 @@ public class StorageLevels {
public static final StorageLevel MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2);
public static final StorageLevel MEMORY_AND_DISK_SER = new StorageLevel(true, true, false, 1);
public static final StorageLevel MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2);
/**
* Create a new StorageLevel object.
* @param useDisk saved to disk, if true
* @param useMemory saved to memory, if true
* @param deserialized saved as deserialized objects, if true
* @param replication replication factor
*/
public static StorageLevel create(boolean useDisk, boolean useMemory, boolean deserialized, int replication) {
return StorageLevel.apply(useDisk, useMemory, deserialized, replication);
}
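Illustrative call from Scala (equivalent to the predefined MEMORY_AND_DISK_SER_2 constant):

    import spark.api.java.StorageLevels

    val level = StorageLevels.create(true, true, false, 2)   // disk, memory, serialized, 2 replicas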
}

View file

@ -0,0 +1,48 @@
package spark.api.python
import spark.Partitioner
import java.util.Arrays
/**
* A [[spark.Partitioner]] that performs handling of byte arrays, for use by the Python API.
*
* Stores the unique id() of the Python-side partitioning function so that it is incorporated into
* equality comparisons. Correctness requires that the id is a unique identifier for the
* lifetime of the job (i.e. that it is not re-used as the id of a different partitioning
* function). This can be ensured by using the Python id() function and maintaining a reference
* to the Python partitioning function so that its id() is not reused.
*/
private[spark] class PythonPartitioner(
override val numPartitions: Int,
val pyPartitionFunctionId: Long)
extends Partitioner {
override def getPartition(key: Any): Int = {
if (key == null) {
return 0
}
else {
val hashCode = {
if (key.isInstanceOf[Array[Byte]]) {
Arrays.hashCode(key.asInstanceOf[Array[Byte]])
} else {
key.hashCode()
}
}
val mod = hashCode % numPartitions
if (mod < 0) {
mod + numPartitions
} else {
mod // Guard against negative hash codes
}
}
}
override def equals(other: Any): Boolean = other match {
case h: PythonPartitioner =>
h.numPartitions == numPartitions && h.pyPartitionFunctionId == pyPartitionFunctionId
case _ =>
false
}
}
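The negative-hash guard above is a non-negative modulo; the same idea as a standalone helper (sketch only):

    def nonNegativeMod(hashCode: Int, numPartitions: Int): Int = {
      val mod = hashCode % numPartitions
      if (mod < 0) mod + numPartitions else mod   // result always in [0, numPartitions)
    }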


@ -0,0 +1,309 @@
package spark.api.python
import java.io._
import java.net._
import java.util.{List => JList, ArrayList => JArrayList, Collections}
import scala.collection.JavaConversions._
import scala.io.Source
import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
import spark.broadcast.Broadcast
import spark._
import spark.rdd.PipedRDD
private[spark] class PythonRDD[T: ClassManifest](
parent: RDD[T],
command: Seq[String],
envVars: java.util.Map[String, String],
preservePartitoning: Boolean,
pythonExec: String,
broadcastVars: JList[Broadcast[Array[Byte]]],
accumulator: Accumulator[JList[Array[Byte]]])
extends RDD[Array[Byte]](parent) {
// Similar to Runtime.exec(), if we are given a single string, split it into words
// using a standard StringTokenizer (i.e. by spaces)
def this(parent: RDD[T], command: String, envVars: java.util.Map[String, String],
preservePartitoning: Boolean, pythonExec: String,
broadcastVars: JList[Broadcast[Array[Byte]]],
accumulator: Accumulator[JList[Array[Byte]]]) =
this(parent, PipedRDD.tokenize(command), envVars, preservePartitoning, pythonExec,
broadcastVars, accumulator)
override def getSplits = parent.splits
override val partitioner = if (preservePartitoning) parent.partitioner else None
override def compute(split: Split, context: TaskContext): Iterator[Array[Byte]] = {
val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME")
val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/python/pyspark/worker.py"))
// Add the environmental variables to the process.
val currentEnvVars = pb.environment()
for ((variable, value) <- envVars) {
currentEnvVars.put(variable, value)
}
val proc = pb.start()
val env = SparkEnv.get
// Start a thread to print the process's stderr to ours
new Thread("stderr reader for " + command) {
override def run() {
for (line <- Source.fromInputStream(proc.getErrorStream).getLines) {
System.err.println(line)
}
}
}.start()
// Start a thread to feed the process input from our parent's iterator
new Thread("stdin writer for " + command) {
override def run() {
SparkEnv.set(env)
val out = new PrintWriter(proc.getOutputStream)
val dOut = new DataOutputStream(proc.getOutputStream)
// Split index
dOut.writeInt(split.index)
// sparkFilesDir
PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dOut)
// Broadcast variables
dOut.writeInt(broadcastVars.length)
for (broadcast <- broadcastVars) {
dOut.writeLong(broadcast.id)
dOut.writeInt(broadcast.value.length)
dOut.write(broadcast.value)
dOut.flush()
}
// Serialized user code
for (elem <- command) {
out.println(elem)
}
out.flush()
// Data values
for (elem <- parent.iterator(split, context)) {
PythonRDD.writeAsPickle(elem, dOut)
}
dOut.flush()
out.flush()
proc.getOutputStream.close()
}
}.start()
// Return an iterator that read lines from the process's stdout
val stream = new DataInputStream(proc.getInputStream)
return new Iterator[Array[Byte]] {
def next(): Array[Byte] = {
val obj = _nextObj
_nextObj = read()
obj
}
private def read(): Array[Byte] = {
try {
stream.readInt() match {
case length if length > 0 =>
val obj = new Array[Byte](length)
stream.readFully(obj)
obj
case -2 =>
// Signals that an exception has been thrown in python
val exLength = stream.readInt()
val obj = new Array[Byte](exLength)
stream.readFully(obj)
throw new PythonException(new String(obj))
case -1 =>
// We've finished the data section of the output, but we can still read some
// accumulator updates; let's do that, breaking when we get EOFException
while (true) {
val len2 = stream.readInt()
val update = new Array[Byte](len2)
stream.readFully(update)
accumulator += Collections.singletonList(update)
}
new Array[Byte](0)
}
} catch {
case eof: EOFException => {
val exitStatus = proc.waitFor()
if (exitStatus != 0) {
throw new Exception("Subprocess exited with status " + exitStatus)
}
new Array[Byte](0)
}
case e => throw e
}
}
var _nextObj = read()
def hasNext = _nextObj.length != 0
}
}
val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
}
/** Thrown for exceptions in user Python code. */
private class PythonException(msg: String) extends Exception(msg)
/**
* Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python.
* This is used by PySpark's shuffle operations.
*/
private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
RDD[(Array[Byte], Array[Byte])](prev) {
override def getSplits = prev.splits
override def compute(split: Split, context: TaskContext) =
prev.iterator(split, context).grouped(2).map {
case Seq(a, b) => (a, b)
case x => throw new Exception("PairwiseRDD: unexpected value: " + x)
}
val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this)
}
private[spark] object PythonRDD {
/** Strips the pickle PROTO and STOP opcodes from the start and end of a pickle */
def stripPickle(arr: Array[Byte]) : Array[Byte] = {
arr.slice(2, arr.length - 1)
}
/**
* Write strings, pickled Python objects, or pairs of pickled objects to a data output stream.
* The data format is a 32-bit integer representing the pickled object's length (in bytes),
* followed by the pickled data.
*
* Pickle module:
*
* http://docs.python.org/2/library/pickle.html
*
* The pickle protocol is documented in the source of the `pickle` and `pickletools` modules:
*
* http://hg.python.org/cpython/file/2.6/Lib/pickle.py
* http://hg.python.org/cpython/file/2.6/Lib/pickletools.py
*
* @param elem the object to write
* @param dOut a data output stream
*/
def writeAsPickle(elem: Any, dOut: DataOutputStream) {
if (elem.isInstanceOf[Array[Byte]]) {
val arr = elem.asInstanceOf[Array[Byte]]
dOut.writeInt(arr.length)
dOut.write(arr)
} else if (elem.isInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]) {
val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]
val length = t._1.length + t._2.length - 3 - 3 + 4 // stripPickle() removes 3 bytes
dOut.writeInt(length)
dOut.writeByte(Pickle.PROTO)
dOut.writeByte(Pickle.TWO)
dOut.write(PythonRDD.stripPickle(t._1))
dOut.write(PythonRDD.stripPickle(t._2))
dOut.writeByte(Pickle.TUPLE2)
dOut.writeByte(Pickle.STOP)
} else if (elem.isInstanceOf[String]) {
// For uniformity, strings are wrapped into Pickles.
val s = elem.asInstanceOf[String].getBytes("UTF-8")
val length = 2 + 1 + 4 + s.length + 1
dOut.writeInt(length)
dOut.writeByte(Pickle.PROTO)
dOut.writeByte(Pickle.TWO)
dOut.write(Pickle.BINUNICODE)
dOut.writeInt(Integer.reverseBytes(s.length))
dOut.write(s)
dOut.writeByte(Pickle.STOP)
} else {
throw new Exception("Unexpected RDD type")
}
}
def readRDDFromPickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) :
JavaRDD[Array[Byte]] = {
val file = new DataInputStream(new FileInputStream(filename))
val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
try {
while (true) {
val length = file.readInt()
val obj = new Array[Byte](length)
file.readFully(obj)
objs.append(obj)
}
} catch {
case eof: EOFException => {}
case e => throw e
}
JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
}
def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) {
import scala.collection.JavaConverters._
writeIteratorToPickleFile(items.asScala, filename)
}
def writeIteratorToPickleFile[T](items: Iterator[T], filename: String) {
val file = new DataOutputStream(new FileOutputStream(filename))
for (item <- items) {
writeAsPickle(item, file)
}
file.close()
}
def takePartition[T](rdd: RDD[T], partition: Int): Iterator[T] = {
implicit val cm : ClassManifest[T] = rdd.elementClassManifest
rdd.context.runJob(rdd, ((x: Iterator[T]) => x.toArray), Seq(partition), true).head.iterator
}
}
private object Pickle {
val PROTO: Byte = 0x80.toByte
val TWO: Byte = 0x02.toByte
val BINUNICODE: Byte = 'X'
val STOP: Byte = '.'
val TUPLE2: Byte = 0x86.toByte
val EMPTY_LIST: Byte = ']'
val MARK: Byte = '('
val APPENDS: Byte = 'e'
}
private class BytesToString extends spark.api.java.function.Function[Array[Byte], String] {
override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8")
}
/**
* Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it
* collects a list of pickled strings that we pass to Python through a socket.
*/
class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
extends AccumulatorParam[JList[Array[Byte]]] {
override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]])
: JList[Array[Byte]] = {
if (serverHost == null) {
// This happens on the worker node, where we just want to remember all the updates
val1.addAll(val2)
val1
} else {
// This happens on the master, where we pass the updates to Python through a socket
val socket = new Socket(serverHost, serverPort)
val in = socket.getInputStream
val out = new DataOutputStream(socket.getOutputStream)
out.writeInt(val2.size)
for (array <- val2) {
out.writeInt(array.length)
out.write(array)
}
out.flush()
// Wait for a byte from the Python side as an acknowledgement
val byteRead = in.read()
if (byteRead == -1) {
throw new SparkException("EOF reached before Python server acknowledged")
}
socket.close()
null
}
}
}


@ -31,7 +31,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
@transient var totalBlocks = -1
@transient var hasBlocks = new AtomicInteger(0)
// Used ONLY by Master to track how many unique blocks have been sent out
// Used ONLY by driver to track how many unique blocks have been sent out
@transient var sentBlocks = new AtomicInteger(0)
@transient var listenPortLock = new Object
@ -42,7 +42,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
@transient var serveMR: ServeMultipleRequests = null
// Used only in Master
// Used only in driver
@transient var guideMR: GuideMultipleRequests = null
// Used only in Workers
@ -99,14 +99,14 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
}
// Must always come AFTER listenPort is created
val masterSource =
val driverSource =
SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes)
hasBlocksBitVector.synchronized {
masterSource.hasBlocksBitVector = hasBlocksBitVector
driverSource.hasBlocksBitVector = hasBlocksBitVector
}
// In the beginning, this is the only known source to Guide
listOfSources += masterSource
listOfSources += driverSource
// Register with the Tracker
MultiTracker.registerBroadcast(id,
@ -122,7 +122,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
case None =>
logInfo("Started reading broadcast variable " + id)
// Initializing everything because Master will only send null/0 values
// Initializing everything because driver will only send null/0 values
// Only the 1st worker in a node can be here. Others will get from cache
initializeWorkerVariables()
@ -151,7 +151,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
}
}
// Initialize variables in the worker node. Master sends everything as 0/null
// Initialize variables in the worker node. Driver sends everything as 0/null
private def initializeWorkerVariables() {
arrayOfBlocks = null
hasBlocksBitVector = null
@ -248,7 +248,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
// Receive source information from Guide
var suitableSources =
oisGuide.readObject.asInstanceOf[ListBuffer[SourceInfo]]
logDebug("Received suitableSources from Master " + suitableSources)
logDebug("Received suitableSources from Driver " + suitableSources)
addToListOfSources(suitableSources)
@ -532,7 +532,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
oosSource.writeObject(blockToAskFor)
oosSource.flush()
// CHANGED: Master might send some other block than the one
// CHANGED: Driver might send some other block than the one
// requested to ensure fast spreading of all blocks.
val recvStartTime = System.currentTimeMillis
val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
@ -982,9 +982,9 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
// Receive which block to send
var blockToSend = ois.readObject.asInstanceOf[Int]
// If it is master AND at least one copy of each block has not been
// If it is driver AND at least one copy of each block has not been
// sent out already, MODIFY blockToSend
if (MultiTracker.isMaster && sentBlocks.get < totalBlocks) {
if (MultiTracker.isDriver && sentBlocks.get < totalBlocks) {
blockToSend = sentBlocks.getAndIncrement
}
@ -1031,7 +1031,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
private[spark] class BitTorrentBroadcastFactory
extends BroadcastFactory {
def initialize(isMaster: Boolean) { MultiTracker.initialize(isMaster) }
def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) }
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new BitTorrentBroadcast[T](value_, isLocal, id)


@ -5,7 +5,7 @@ import java.util.concurrent.atomic.AtomicLong
import spark._
abstract class Broadcast[T](id: Long) extends Serializable {
abstract class Broadcast[T](private[spark] val id: Long) extends Serializable {
def value: T
// We cannot have an abstract readObject here due to some weird issues with
@ -15,7 +15,7 @@ abstract class Broadcast[T](id: Long) extends Serializable {
}
private[spark]
class BroadcastManager(val isMaster_ : Boolean) extends Logging with Serializable {
class BroadcastManager(val _isDriver: Boolean) extends Logging with Serializable {
private var initialized = false
private var broadcastFactory: BroadcastFactory = null
@ -33,7 +33,7 @@ class BroadcastManager(val isMaster_ : Boolean) extends Logging with Serializabl
Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
// Initialize appropriate BroadcastFactory and BroadcastObject
broadcastFactory.initialize(isMaster)
broadcastFactory.initialize(isDriver)
initialized = true
}
@ -49,5 +49,5 @@ class BroadcastManager(val isMaster_ : Boolean) extends Logging with Serializabl
def newBroadcast[T](value_ : T, isLocal: Boolean) =
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
def isMaster = isMaster_
def isDriver = _isDriver
}


@ -7,7 +7,7 @@ package spark.broadcast
* entire Spark job.
*/
private[spark] trait BroadcastFactory {
def initialize(isMaster: Boolean): Unit
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long): Broadcast[T]
def initialize(isDriver: Boolean): Unit
def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T]
def stop(): Unit
}


@ -48,7 +48,7 @@ extends Broadcast[T](id) with Logging with Serializable {
}
private[spark] class HttpBroadcastFactory extends BroadcastFactory {
def initialize(isMaster: Boolean) { HttpBroadcast.initialize(isMaster) }
def initialize(isDriver: Boolean) { HttpBroadcast.initialize(isDriver) }
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new HttpBroadcast[T](value_, isLocal, id)
@ -69,12 +69,12 @@ private object HttpBroadcast extends Logging {
private val cleaner = new MetadataCleaner("HttpBroadcast", cleanup)
def initialize(isMaster: Boolean) {
def initialize(isDriver: Boolean) {
synchronized {
if (!initialized) {
bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
compress = System.getProperty("spark.broadcast.compress", "true").toBoolean
if (isMaster) {
if (isDriver) {
createServer()
}
serverUri = System.getProperty("spark.httpBroadcast.uri")
@ -95,7 +95,7 @@ private object HttpBroadcast extends Logging {
}
private def createServer() {
broadcastDir = Utils.createTempDir()
broadcastDir = Utils.createTempDir(Utils.getLocalDir)
server = new HttpServer(broadcastDir)
server.start()
serverUri = server.uri


@ -23,25 +23,24 @@ extends Logging {
var ranGen = new Random
private var initialized = false
private var isMaster_ = false
private var _isDriver = false
private var stopBroadcast = false
private var trackMV: TrackMultipleValues = null
def initialize(isMaster__ : Boolean) {
def initialize(__isDriver: Boolean) {
synchronized {
if (!initialized) {
_isDriver = __isDriver
isMaster_ = isMaster__
if (isMaster) {
if (isDriver) {
trackMV = new TrackMultipleValues
trackMV.setDaemon(true)
trackMV.start()
// Set masterHostAddress to the master's IP address for the slaves to read
System.setProperty("spark.MultiTracker.MasterHostAddress", Utils.localIpAddress)
// Set DriverHostAddress to the driver's IP address for the slaves to read
System.setProperty("spark.MultiTracker.DriverHostAddress", Utils.localIpAddress)
}
initialized = true
@ -54,10 +53,10 @@ extends Logging {
}
// Load common parameters
private var MasterHostAddress_ = System.getProperty(
"spark.MultiTracker.MasterHostAddress", "")
private var MasterTrackerPort_ = System.getProperty(
"spark.broadcast.masterTrackerPort", "11111").toInt
private var DriverHostAddress_ = System.getProperty(
"spark.MultiTracker.DriverHostAddress", "")
private var DriverTrackerPort_ = System.getProperty(
"spark.broadcast.driverTrackerPort", "11111").toInt
private var BlockSize_ = System.getProperty(
"spark.broadcast.blockSize", "4096").toInt * 1024
private var MaxRetryCount_ = System.getProperty(
@ -91,11 +90,11 @@ extends Logging {
private var EndGameFraction_ = System.getProperty(
"spark.broadcast.endGameFraction", "0.95").toDouble
def isMaster = isMaster_
def isDriver = _isDriver
// Common config params
def MasterHostAddress = MasterHostAddress_
def MasterTrackerPort = MasterTrackerPort_
def DriverHostAddress = DriverHostAddress_
def DriverTrackerPort = DriverTrackerPort_
def BlockSize = BlockSize_
def MaxRetryCount = MaxRetryCount_
@ -123,7 +122,7 @@ extends Logging {
var threadPool = Utils.newDaemonCachedThreadPool()
var serverSocket: ServerSocket = null
serverSocket = new ServerSocket(MasterTrackerPort)
serverSocket = new ServerSocket(DriverTrackerPort)
logInfo("TrackMultipleValues started at " + serverSocket)
try {
@ -235,7 +234,7 @@ extends Logging {
try {
// Connect to the tracker to find out GuideInfo
clientSocketToTracker =
new Socket(MultiTracker.MasterHostAddress, MultiTracker.MasterTrackerPort)
new Socket(MultiTracker.DriverHostAddress, MultiTracker.DriverTrackerPort)
oosTracker =
new ObjectOutputStream(clientSocketToTracker.getOutputStream)
oosTracker.flush()
@ -276,7 +275,7 @@ extends Logging {
}
def registerBroadcast(id: Long, gInfo: SourceInfo) {
val socket = new Socket(MultiTracker.MasterHostAddress, MasterTrackerPort)
val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort)
val oosST = new ObjectOutputStream(socket.getOutputStream)
oosST.flush()
val oisST = new ObjectInputStream(socket.getInputStream)
@ -303,7 +302,7 @@ extends Logging {
}
def unregisterBroadcast(id: Long) {
val socket = new Socket(MultiTracker.MasterHostAddress, MasterTrackerPort)
val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort)
val oosST = new ObjectOutputStream(socket.getOutputStream)
oosST.flush()
val oisST = new ObjectInputStream(socket.getInputStream)


@ -98,7 +98,7 @@ extends Broadcast[T](id) with Logging with Serializable {
case None =>
logInfo("Started reading broadcast variable " + id)
// Initializing everything because Master will only send null/0 values
// Initializing everything because Driver will only send null/0 values
// Only the 1st worker in a node can be here. Others will get from cache
initializeWorkerVariables()
@ -157,55 +157,55 @@ extends Broadcast[T](id) with Logging with Serializable {
listenPortLock.synchronized { listenPortLock.wait() }
}
var clientSocketToMaster: Socket = null
var oosMaster: ObjectOutputStream = null
var oisMaster: ObjectInputStream = null
var clientSocketToDriver: Socket = null
var oosDriver: ObjectOutputStream = null
var oisDriver: ObjectInputStream = null
// Connect and receive broadcast from the specified source, retrying the
// specified number of times in case of failures
var retriesLeft = MultiTracker.MaxRetryCount
do {
// Connect to Master and send this worker's Information
clientSocketToMaster = new Socket(MultiTracker.MasterHostAddress, gInfo.listenPort)
oosMaster = new ObjectOutputStream(clientSocketToMaster.getOutputStream)
oosMaster.flush()
oisMaster = new ObjectInputStream(clientSocketToMaster.getInputStream)
// Connect to Driver and send this worker's Information
clientSocketToDriver = new Socket(MultiTracker.DriverHostAddress, gInfo.listenPort)
oosDriver = new ObjectOutputStream(clientSocketToDriver.getOutputStream)
oosDriver.flush()
oisDriver = new ObjectInputStream(clientSocketToDriver.getInputStream)
logDebug("Connected to Master's guiding object")
logDebug("Connected to Driver's guiding object")
// Send local source information
oosMaster.writeObject(SourceInfo(hostAddress, listenPort))
oosMaster.flush()
oosDriver.writeObject(SourceInfo(hostAddress, listenPort))
oosDriver.flush()
// Receive source information from Master
var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo]
// Receive source information from Driver
var sourceInfo = oisDriver.readObject.asInstanceOf[SourceInfo]
totalBlocks = sourceInfo.totalBlocks
arrayOfBlocks = new Array[BroadcastBlock](totalBlocks)
totalBlocksLock.synchronized { totalBlocksLock.notifyAll() }
totalBytes = sourceInfo.totalBytes
logDebug("Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort)
logDebug("Received SourceInfo from Driver:" + sourceInfo + " My Port: " + listenPort)
val start = System.nanoTime
val receptionSucceeded = receiveSingleTransmission(sourceInfo)
val time = (System.nanoTime - start) / 1e9
// Updating some statistics in sourceInfo. Master will be using them later
// Updating some statistics in sourceInfo. Driver will be using them later
if (!receptionSucceeded) {
sourceInfo.receptionFailed = true
}
// Send back statistics to the Master
oosMaster.writeObject(sourceInfo)
// Send back statistics to the Driver
oosDriver.writeObject(sourceInfo)
if (oisMaster != null) {
oisMaster.close()
if (oisDriver != null) {
oisDriver.close()
}
if (oosMaster != null) {
oosMaster.close()
if (oosDriver != null) {
oosDriver.close()
}
if (clientSocketToMaster != null) {
clientSocketToMaster.close()
if (clientSocketToDriver != null) {
clientSocketToDriver.close()
}
retriesLeft -= 1
@ -552,7 +552,7 @@ extends Broadcast[T](id) with Logging with Serializable {
}
private def sendObject() {
// Wait till receiving the SourceInfo from Master
// Wait till receiving the SourceInfo from Driver
while (totalBlocks == -1) {
totalBlocksLock.synchronized { totalBlocksLock.wait() }
}
@ -576,7 +576,7 @@ extends Broadcast[T](id) with Logging with Serializable {
private[spark] class TreeBroadcastFactory
extends BroadcastFactory {
def initialize(isMaster: Boolean) { MultiTracker.initialize(isMaster) }
def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) }
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new TreeBroadcast[T](value_, isLocal, id)


@ -4,7 +4,6 @@ import spark.deploy.ExecutorState.ExecutorState
import spark.deploy.master.{WorkerInfo, JobInfo}
import spark.deploy.worker.ExecutorRunner
import scala.collection.immutable.List
import scala.collection.mutable.HashMap
private[spark] sealed trait DeployMessage extends Serializable
@ -42,7 +41,8 @@ private[spark] case class LaunchExecutor(
execId: Int,
jobDesc: JobDescription,
cores: Int,
memory: Int)
memory: Int,
sparkHome: String)
extends DeployMessage


@ -4,7 +4,8 @@ private[spark] class JobDescription(
val name: String,
val cores: Int,
val memoryPerSlave: Int,
val command: Command)
val command: Command,
val sparkHome: String)
extends Serializable {
val user = System.getProperty("user.name", "<unknown>")


@ -0,0 +1,78 @@
package spark.deploy
import master.{JobInfo, WorkerInfo}
import worker.ExecutorRunner
import cc.spray.json._
/**
* spray-json helper class containing implicit conversion to json for marshalling responses
*/
private[spark] object JsonProtocol extends DefaultJsonProtocol {
implicit object WorkerInfoJsonFormat extends RootJsonWriter[WorkerInfo] {
def write(obj: WorkerInfo) = JsObject(
"id" -> JsString(obj.id),
"host" -> JsString(obj.host),
"webuiaddress" -> JsString(obj.webUiAddress),
"cores" -> JsNumber(obj.cores),
"coresused" -> JsNumber(obj.coresUsed),
"memory" -> JsNumber(obj.memory),
"memoryused" -> JsNumber(obj.memoryUsed)
)
}
implicit object JobInfoJsonFormat extends RootJsonWriter[JobInfo] {
def write(obj: JobInfo) = JsObject(
"starttime" -> JsNumber(obj.startTime),
"id" -> JsString(obj.id),
"name" -> JsString(obj.desc.name),
"cores" -> JsNumber(obj.desc.cores),
"user" -> JsString(obj.desc.user),
"memoryperslave" -> JsNumber(obj.desc.memoryPerSlave),
"submitdate" -> JsString(obj.submitDate.toString))
}
implicit object JobDescriptionJsonFormat extends RootJsonWriter[JobDescription] {
def write(obj: JobDescription) = JsObject(
"name" -> JsString(obj.name),
"cores" -> JsNumber(obj.cores),
"memoryperslave" -> JsNumber(obj.memoryPerSlave),
"user" -> JsString(obj.user)
)
}
implicit object ExecutorRunnerJsonFormat extends RootJsonWriter[ExecutorRunner] {
def write(obj: ExecutorRunner) = JsObject(
"id" -> JsNumber(obj.execId),
"memory" -> JsNumber(obj.memory),
"jobid" -> JsString(obj.jobId),
"jobdesc" -> obj.jobDesc.toJson.asJsObject
)
}
implicit object MasterStateJsonFormat extends RootJsonWriter[MasterState] {
def write(obj: MasterState) = JsObject(
"url" -> JsString("spark://" + obj.uri),
"workers" -> JsArray(obj.workers.toList.map(_.toJson)),
"cores" -> JsNumber(obj.workers.map(_.cores).sum),
"coresused" -> JsNumber(obj.workers.map(_.coresUsed).sum),
"memory" -> JsNumber(obj.workers.map(_.memory).sum),
"memoryused" -> JsNumber(obj.workers.map(_.memoryUsed).sum),
"activejobs" -> JsArray(obj.activeJobs.toList.map(_.toJson)),
"completedjobs" -> JsArray(obj.completedJobs.toList.map(_.toJson))
)
}
implicit object WorkerStateJsonFormat extends RootJsonWriter[WorkerState] {
def write(obj: WorkerState) = JsObject(
"id" -> JsString(obj.workerId),
"masterurl" -> JsString(obj.masterUrl),
"masterwebuiurl" -> JsString(obj.masterWebUiUrl),
"cores" -> JsNumber(obj.cores),
"coresused" -> JsNumber(obj.coresUsed),
"memory" -> JsNumber(obj.memory),
"memoryused" -> JsNumber(obj.memoryUsed),
"executors" -> JsArray(obj.executors.toList.map(_.toJson)),
"finishedexecutors" -> JsArray(obj.finishedExecutors.toList.map(_.toJson))
)
}
}


@ -9,43 +9,32 @@ import spark.{Logging, Utils}
import scala.collection.mutable.ArrayBuffer
/**
* Testing class that creates a Spark standalone process in-cluster (that is, running the
* spark.deploy.master.Master and spark.deploy.worker.Workers in the same JVMs). Executors launched
* by the Workers still run in separate JVMs. This can be used to test distributed operation and
* fault recovery without spinning up a lot of processes.
*/
private[spark]
class LocalSparkCluster(numSlaves: Int, coresPerSlave: Int, memoryPerSlave: Int) extends Logging {
class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) extends Logging {
val localIpAddress = Utils.localIpAddress
private val localIpAddress = Utils.localIpAddress
private val masterActorSystems = ArrayBuffer[ActorSystem]()
private val workerActorSystems = ArrayBuffer[ActorSystem]()
var masterActor : ActorRef = _
var masterActorSystem : ActorSystem = _
var masterPort : Int = _
var masterUrl : String = _
val slaveActorSystems = ArrayBuffer[ActorSystem]()
val slaveActors = ArrayBuffer[ActorRef]()
def start() : String = {
logInfo("Starting a local Spark cluster with " + numSlaves + " slaves.")
def start(): String = {
logInfo("Starting a local Spark cluster with " + numWorkers + " workers.")
/* Start the Master */
val (actorSystem, masterPort) = AkkaUtils.createActorSystem("sparkMaster", localIpAddress, 0)
masterActorSystem = actorSystem
masterUrl = "spark://" + localIpAddress + ":" + masterPort
val actor = masterActorSystem.actorOf(
Props(new Master(localIpAddress, masterPort, 0)), name = "Master")
masterActor = actor
val (masterSystem, masterPort) = Master.startSystemAndActor(localIpAddress, 0, 0)
masterActorSystems += masterSystem
val masterUrl = "spark://" + localIpAddress + ":" + masterPort
/* Start the Slaves */
for (slaveNum <- 1 to numSlaves) {
/* We can pretend to test distributed stuff by giving the slaves distinct hostnames.
All of 127/8 should be a loopback, we use 127.100.*.* in hopes that it is
sufficiently distinctive. */
val slaveIpAddress = "127.100.0." + (slaveNum % 256)
val (actorSystem, boundPort) =
AkkaUtils.createActorSystem("sparkWorker" + slaveNum, slaveIpAddress, 0)
slaveActorSystems += actorSystem
val actor = actorSystem.actorOf(
Props(new Worker(slaveIpAddress, boundPort, 0, coresPerSlave, memoryPerSlave, masterUrl)),
name = "Worker")
slaveActors += actor
/* Start the Workers */
for (workerNum <- 1 to numWorkers) {
val (workerSystem, _) = Worker.startSystemAndActor(localIpAddress, 0, 0, coresPerWorker,
memoryPerWorker, masterUrl, null, Some(workerNum))
workerActorSystems += workerSystem
}
return masterUrl
@ -53,10 +42,10 @@ class LocalSparkCluster(numSlaves: Int, coresPerSlave: Int, memoryPerSlave: Int)
def stop() {
logInfo("Shutting down local Spark cluster.")
// Stop the slaves before the master so they don't get upset that it disconnected
slaveActorSystems.foreach(_.shutdown())
slaveActorSystems.foreach(_.awaitTermination())
masterActorSystem.shutdown()
masterActorSystem.awaitTermination()
// Stop the workers before the master so they don't get upset that it disconnected
workerActorSystems.foreach(_.shutdown())
workerActorSystems.foreach(_.awaitTermination())
masterActorSystems.foreach(_.shutdown())
masterActorSystems.foreach(_.awaitTermination())
}
}


@ -9,6 +9,7 @@ import spark.{SparkException, Logging}
import akka.remote.RemoteClientLifeCycleEvent
import akka.remote.RemoteClientShutdown
import spark.deploy.RegisterJob
import spark.deploy.master.Master
import akka.remote.RemoteClientDisconnected
import akka.actor.Terminated
import akka.dispatch.Await
@ -24,26 +25,18 @@ private[spark] class Client(
listener: ClientListener)
extends Logging {
val MASTER_REGEX = "spark://([^:]+):([0-9]+)".r
var actor: ActorRef = null
var jobId: String = null
if (MASTER_REGEX.unapplySeq(masterUrl) == None) {
throw new SparkException("Invalid master URL: " + masterUrl)
}
class ClientActor extends Actor with Logging {
var master: ActorRef = null
var masterAddress: Address = null
var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times
override def preStart() {
val Seq(masterHost, masterPort) = MASTER_REGEX.unapplySeq(masterUrl).get
logInfo("Connecting to master spark://" + masterHost + ":" + masterPort)
val akkaUrl = "akka://spark@%s:%s/user/Master".format(masterHost, masterPort)
logInfo("Connecting to master " + masterUrl)
try {
master = context.actorFor(akkaUrl)
master = context.actorFor(Master.toAkkaUrl(masterUrl))
masterAddress = master.path.address
master ! RegisterJob(jobDescription)
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])

View file

@ -12,7 +12,7 @@ private[spark] trait ClientListener {
def disconnected(): Unit
def executorAdded(id: String, workerId: String, host: String, cores: Int, memory: Int): Unit
def executorAdded(fullId: String, workerId: String, host: String, cores: Int, memory: Int): Unit
def executorRemoved(id: String, message: String, exitStatus: Option[Int]): Unit
def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]): Unit
}

View file

@ -25,7 +25,7 @@ private[spark] object TestClient {
val url = args(0)
val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0)
val desc = new JobDescription(
"TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()))
"TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home")
val listener = new TestListener
val client = new Client(actorSystem, url, desc, listener)
client.start()

View file

@ -10,7 +10,7 @@ private[spark] class JobInfo(
val id: String,
val desc: JobDescription,
val submitDate: Date,
val actor: ActorRef)
val driver: ActorRef)
{
var state = JobState.WAITING
var executors = new mutable.HashMap[Int, ExecutorInfo]

View file

@ -88,7 +88,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
execOption match {
case Some(exec) => {
exec.state = state
exec.job.actor ! ExecutorUpdated(execId, state, message, exitStatus)
exec.job.driver ! ExecutorUpdated(execId, state, message, exitStatus)
if (ExecutorState.isFinished(state)) {
val jobInfo = idToJob(jobId)
// Remove this executor from the worker and job
@ -97,14 +97,12 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
exec.worker.removeExecutor(exec)
// Only retry certain number of times so we don't go into an infinite loop.
if (jobInfo.incrementRetryCount <= JobState.MAX_NUM_RETRY) {
if (jobInfo.incrementRetryCount < JobState.MAX_NUM_RETRY) {
schedule()
} else {
val e = new SparkException("Job %s wth ID %s failed %d times.".format(
logError("Job %s with ID %s failed %d times, removing it".format(
jobInfo.desc.name, jobInfo.id, jobInfo.retryCount))
logError(e.getMessage, e)
throw e
//System.exit(1)
removeJob(jobInfo)
}
}
}
@ -173,7 +171,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
for (pos <- 0 until numUsable) {
if (assigned(pos) > 0) {
val exec = job.addExecutor(usableWorkers(pos), assigned(pos))
launchExecutor(usableWorkers(pos), exec)
launchExecutor(usableWorkers(pos), exec, job.desc.sparkHome)
job.state = JobState.RUNNING
}
}
@ -186,7 +184,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
val coresToUse = math.min(worker.coresFree, job.coresLeft)
if (coresToUse > 0) {
val exec = job.addExecutor(worker, coresToUse)
launchExecutor(worker, exec)
launchExecutor(worker, exec, job.desc.sparkHome)
job.state = JobState.RUNNING
}
}
@ -195,11 +193,11 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
}
}
def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo) {
def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo, sparkHome: String) {
logInfo("Launching executor " + exec.fullId + " on worker " + worker.id)
worker.addExecutor(exec)
worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory)
exec.job.actor ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory)
worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory, sparkHome)
exec.job.driver ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory)
}
def addWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int,
@ -221,19 +219,19 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
actorToWorker -= worker.actor
addressToWorker -= worker.actor.path.address
for (exec <- worker.executors.values) {
exec.job.actor ! ExecutorStateChanged(exec.job.id, exec.id, ExecutorState.LOST, None, None)
exec.job.driver ! ExecutorStateChanged(exec.job.id, exec.id, ExecutorState.LOST, None, None)
exec.job.executors -= exec.id
}
}
def addJob(desc: JobDescription, actor: ActorRef): JobInfo = {
def addJob(desc: JobDescription, driver: ActorRef): JobInfo = {
val now = System.currentTimeMillis()
val date = new Date(now)
val job = new JobInfo(now, newJobId(date), desc, date, actor)
val job = new JobInfo(now, newJobId(date), desc, date, driver)
jobs += job
idToJob(job.id) = job
actorToJob(sender) = job
addressToJob(sender.path.address) = job
actorToJob(driver) = job
addressToJob(driver.path.address) = job
return job
}
@ -242,8 +240,8 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
logInfo("Removing job " + job.id)
jobs -= job
idToJob -= job.id
actorToJob -= job.actor
addressToWorker -= job.actor.path.address
actorToJob -= job.driver
addressToWorker -= job.driver.path.address
completedJobs += job // Remember it in our history
waitingJobs -= job
for (exec <- job.executors.values) {
@ -264,11 +262,29 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
}
private[spark] object Master {
private val systemName = "sparkMaster"
private val actorName = "Master"
private val sparkUrlRegex = "spark://([^:]+):([0-9]+)".r
def main(argStrings: Array[String]) {
val args = new MasterArguments(argStrings)
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", args.ip, args.port)
val actor = actorSystem.actorOf(
Props(new Master(args.ip, boundPort, args.webUiPort)), name = "Master")
val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort)
actorSystem.awaitTermination()
}
/** Returns an `akka://...` URL for the Master actor given a sparkUrl `spark://host:ip`. */
def toAkkaUrl(sparkUrl: String): String = {
sparkUrl match {
case sparkUrlRegex(host, port) =>
"akka://%s@%s:%s/user/%s".format(systemName, host, port, actorName)
case _ =>
throw new SparkException("Invalid master URL: " + sparkUrl)
}
}
def startSystemAndActor(host: String, port: Int, webUiPort: Int): (ActorSystem, Int) = {
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port)
val actor = actorSystem.actorOf(Props(new Master(host, boundPort, webUiPort)), name = actorName)
(actorSystem, boundPort)
}
}
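For concreteness, the translation performed by toAkkaUrl above (host and port illustrative):

import spark.deploy.master.Master

val akkaUrl = Master.toAkkaUrl("spark://host1:7077")
// akkaUrl == "akka://sparkMaster@host1:7077/user/Master"
// A malformed URL such as "host1:7077" instead throws SparkException("Invalid master URL: ...")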

View file

@ -8,18 +8,31 @@ import akka.util.duration._
import cc.spray.Directives
import cc.spray.directives._
import cc.spray.typeconversion.TwirlSupport._
import spark.deploy._
import cc.spray.http.MediaTypes
import cc.spray.typeconversion.SprayJsonSupport._
import spark.deploy._
import spark.deploy.JsonProtocol._
/**
* Web UI server for the standalone master.
*/
private[spark]
class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Directives {
val RESOURCE_DIR = "spark/deploy/master/webui"
val STATIC_RESOURCE_DIR = "spark/deploy/static"
implicit val timeout = Timeout(1 seconds)
implicit val timeout = Timeout(10 seconds)
val handler = {
get {
path("") {
(path("") & parameters('format ?)) {
case Some(js) if js.equalsIgnoreCase("json") =>
val future = master ? RequestMasterState
respondWithMediaType(MediaTypes.`application/json`) { ctx =>
ctx.complete(future.mapTo[MasterState])
}
case _ =>
completeWith {
val future = master ? RequestMasterState
future.map {
@ -28,18 +41,26 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct
}
} ~
path("job") {
parameter("jobId") { jobId =>
parameters("jobId", 'format ?) {
case (jobId, Some(js)) if (js.equalsIgnoreCase("json")) =>
val future = master ? RequestMasterState
val jobInfo = for (masterState <- future.mapTo[MasterState]) yield {
masterState.activeJobs.find(_.id == jobId).getOrElse({
masterState.completedJobs.find(_.id == jobId).getOrElse(null)
})
}
respondWithMediaType(MediaTypes.`application/json`) { ctx =>
ctx.complete(jobInfo.mapTo[JobInfo])
}
case (jobId, _) =>
completeWith {
val future = master ? RequestMasterState
future.map { state =>
val masterState = state.asInstanceOf[MasterState]
// A bit ugly an inefficient, but we won't have a number of jobs
// so large that it will make a significant difference.
(masterState.activeJobs ++ masterState.completedJobs).find(_.id == jobId) match {
case Some(job) => spark.deploy.master.html.job_details.render(job)
case _ => null
}
val job = masterState.activeJobs.find(_.id == jobId).getOrElse({
masterState.completedJobs.find(_.id == jobId).getOrElse(null)
})
spark.deploy.master.html.job_details.render(job)
}
}
}
@ -50,5 +71,4 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct
getFromResourceDirectory(RESOURCE_DIR)
}
}
}
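A rough sketch of exercising the new JSON endpoint, assuming the master UI is reachable at the illustrative address below; the same ?format=json parameter applies to the worker UI further down.

import scala.io.Source

// Fetch the master state as JSON instead of the rendered HTML page.
val json = Source.fromURL("http://master-host:8080/?format=json").mkString
println(json)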

View file

@ -65,9 +65,9 @@ private[spark] class ExecutorRunner(
}
}
/** Replace variables such as {{SLAVEID}} and {{CORES}} in a command argument passed to us */
/** Replace variables such as {{EXECUTOR_ID}} and {{CORES}} in a command argument passed to us */
def substituteVariables(argument: String): String = argument match {
case "{{SLAVEID}}" => workerId
case "{{EXECUTOR_ID}}" => execId.toString
case "{{HOSTNAME}}" => hostname
case "{{CORES}}" => cores.toString
case other => other
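// A standalone sketch of the substitution rule above (values illustrative). Note that only
// whole arguments are replaced, not substrings:
//   def substitute(arg: String, execId: Int, hostname: String, cores: Int): String = arg match {
//     case "{{EXECUTOR_ID}}" => execId.toString
//     case "{{HOSTNAME}}"    => hostname
//     case "{{CORES}}"       => cores.toString
//     case other             => other
//   }
//   Seq("{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}").map(substitute(_, 3, "host1", 4))
//     == Seq("3", "host1", "4")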
@ -106,11 +106,6 @@ private[spark] class ExecutorRunner(
throw new IOException("Failed to create directory " + executorDir)
}
// Download the files it depends on into it (disabled for now)
//for (url <- jobDesc.fileUrls) {
// fetchFile(url, executorDir)
//}
// Launch the process
val command = buildCommandSeq()
val builder = new ProcessBuilder(command: _*).directory(executorDir)
@ -118,8 +113,7 @@ private[spark] class ExecutorRunner(
for ((key, value) <- jobDesc.command.environment) {
env.put(key, value)
}
env.put("SPARK_CORES", cores.toString)
env.put("SPARK_MEMORY", memory.toString)
env.put("SPARK_MEM", memory.toString + "m")
// In case we are running this from within the Spark Shell, avoid creating a "scala"
// parent process for the executor command
env.put("SPARK_LAUNCH_WITH_SCALA", "0")

View file

@ -1,19 +1,17 @@
package spark.deploy.worker
import scala.collection.mutable.{ArrayBuffer, HashMap}
import akka.actor.{ActorRef, Props, Actor}
import akka.actor.{ActorRef, Props, Actor, ActorSystem, Terminated}
import spark.{Logging, Utils}
import spark.util.AkkaUtils
import spark.deploy._
import akka.remote.RemoteClientLifeCycleEvent
import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClientDisconnected}
import java.text.SimpleDateFormat
import java.util.Date
import akka.remote.RemoteClientShutdown
import akka.remote.RemoteClientDisconnected
import spark.deploy.RegisterWorker
import spark.deploy.LaunchExecutor
import spark.deploy.RegisterWorkerFailed
import akka.actor.Terminated
import spark.deploy.master.Master
import java.io.File
private[spark] class Worker(
@ -27,7 +25,6 @@ private[spark] class Worker(
extends Actor with Logging {
val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs
val MASTER_REGEX = "spark://([^:]+):([0-9]+)".r
var master: ActorRef = null
var masterWebUiUrl : String = ""
@ -48,11 +45,7 @@ private[spark] class Worker(
def memoryFree: Int = memory - memoryUsed
def createWorkDir() {
workDir = if (workDirPath != null) {
new File(workDirPath)
} else {
new File(sparkHome, "work")
}
workDir = Option(workDirPath).map(new File(_)).getOrElse(new File(sparkHome, "work"))
try {
if (!workDir.exists() && !workDir.mkdirs()) {
logError("Failed to create work directory " + workDir)
@ -68,8 +61,7 @@ private[spark] class Worker(
override def preStart() {
logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format(
ip, port, cores, Utils.memoryMegabytesToString(memory)))
val envVar = System.getenv("SPARK_HOME")
sparkHome = new File(if (envVar == null) "." else envVar)
sparkHome = new File(Option(System.getenv("SPARK_HOME")).getOrElse("."))
logInfo("Spark home: " + sparkHome)
createWorkDir()
connectToMaster()
@ -77,12 +69,9 @@ private[spark] class Worker(
}
def connectToMaster() {
masterUrl match {
case MASTER_REGEX(masterHost, masterPort) => {
logInfo("Connecting to master spark://" + masterHost + ":" + masterPort)
val akkaUrl = "akka://spark@%s:%s/user/Master".format(masterHost, masterPort)
logInfo("Connecting to master " + masterUrl)
try {
master = context.actorFor(akkaUrl)
master = context.actorFor(Master.toAkkaUrl(masterUrl))
master ! RegisterWorker(workerId, ip, port, cores, memory, webUiPort, publicAddress)
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
context.watch(master) // Doesn't work with remote actors, but useful for testing
@ -93,12 +82,6 @@ private[spark] class Worker(
}
}
case _ =>
logError("Invalid master URL: " + masterUrl)
System.exit(1)
}
}
def startWebUi() {
val webUi = new WorkerWebUI(context.system, self)
try {
@ -119,10 +102,10 @@ private[spark] class Worker(
logError("Worker registration failed: " + message)
System.exit(1)
case LaunchExecutor(jobId, execId, jobDesc, cores_, memory_) =>
case LaunchExecutor(jobId, execId, jobDesc, cores_, memory_, execSparkHome_) =>
logInfo("Asked to launch executor %s/%d for %s".format(jobId, execId, jobDesc.name))
val manager = new ExecutorRunner(
jobId, execId, jobDesc, cores_, memory_, self, workerId, ip, sparkHome, workDir)
jobId, execId, jobDesc, cores_, memory_, self, workerId, ip, new File(execSparkHome_), workDir)
executors(jobId + "/" + execId) = manager
manager.start()
coresUsed += cores_
@ -134,7 +117,9 @@ private[spark] class Worker(
val fullId = jobId + "/" + execId
if (ExecutorState.isFinished(state)) {
val executor = executors(fullId)
logInfo("Executor " + fullId + " finished with state " + state)
logInfo("Executor " + fullId + " finished with state " + state +
message.map(" message " + _).getOrElse("") +
exitStatus.map(" exitStatus " + _).getOrElse(""))
finishedExecutors(fullId) = executor
executors -= fullId
coresUsed -= executor.cores
@ -143,9 +128,13 @@ private[spark] class Worker(
case KillExecutor(jobId, execId) =>
val fullId = jobId + "/" + execId
val executor = executors(fullId)
executors.get(fullId) match {
case Some(executor) =>
logInfo("Asked to kill executor " + fullId)
executor.kill()
case None =>
logInfo("Asked to kill unknown executor " + fullId)
}
case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) =>
masterDisconnected()
@ -177,11 +166,19 @@ private[spark] class Worker(
private[spark] object Worker {
def main(argStrings: Array[String]) {
val args = new WorkerArguments(argStrings)
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", args.ip, args.port)
val actor = actorSystem.actorOf(
Props(new Worker(args.ip, boundPort, args.webUiPort, args.cores, args.memory,
args.master, args.workDir)),
name = "Worker")
val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort, args.cores,
args.memory, args.master, args.workDir)
actorSystem.awaitTermination()
}
def startSystemAndActor(host: String, port: Int, webUiPort: Int, cores: Int, memory: Int,
masterUrl: String, workDir: String, workerNumber: Option[Int] = None): (ActorSystem, Int) = {
// The LocalSparkCluster runs multiple local sparkWorkerX actor systems
val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("")
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port)
val actor = actorSystem.actorOf(Props(new Worker(host, boundPort, webUiPort, cores, memory,
masterUrl, workDir)), name = "Worker")
(actorSystem, boundPort)
}
}

View file

@ -104,9 +104,25 @@ private[spark] class WorkerArguments(args: Array[String]) {
}
def inferDefaultMemory(): Int = {
val bean = ManagementFactory.getOperatingSystemMXBean
.asInstanceOf[com.sun.management.OperatingSystemMXBean]
val totalMb = (bean.getTotalPhysicalMemorySize / 1024 / 1024).toInt
val ibmVendor = System.getProperty("java.vendor").contains("IBM")
var totalMb = 0
try {
val bean = ManagementFactory.getOperatingSystemMXBean()
if (ibmVendor) {
val beanClass = Class.forName("com.ibm.lang.management.OperatingSystemMXBean")
val method = beanClass.getDeclaredMethod("getTotalPhysicalMemory")
totalMb = (method.invoke(bean).asInstanceOf[Long] / 1024 / 1024).toInt
} else {
val beanClass = Class.forName("com.sun.management.OperatingSystemMXBean")
val method = beanClass.getDeclaredMethod("getTotalPhysicalMemorySize")
totalMb = (method.invoke(bean).asInstanceOf[Long] / 1024 / 1024).toInt
}
} catch {
case e: Exception => {
totalMb = 2*1024
System.out.println("Failed to get total physical memory. Using " + totalMb + " MB")
}
}
// Leave out 1 GB for the operating system, but don't return a negative memory size
math.max(totalMb - 1024, 512)
}
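Worked example of the default-memory rule above (all figures illustrative):

//   8192 MB physical memory -> math.max(8192 - 1024, 512) == 7168 MB for the worker
//   1024 MB physical memory -> math.max(1024 - 1024, 512) ==  512 MB (the floor)
//   reflection fails        -> totalMb defaults to 2048, giving 1024 MB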

View file

@ -7,18 +7,32 @@ import akka.util.Timeout
import akka.util.duration._
import cc.spray.Directives
import cc.spray.typeconversion.TwirlSupport._
import spark.deploy.{WorkerState, RequestWorkerState}
import cc.spray.http.MediaTypes
import cc.spray.typeconversion.SprayJsonSupport._
import spark.deploy.{WorkerState, RequestWorkerState}
import spark.deploy.JsonProtocol._
/**
* Web UI server for the standalone worker.
*/
private[spark]
class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Directives {
val RESOURCE_DIR = "spark/deploy/worker/webui"
val STATIC_RESOURCE_DIR = "spark/deploy/static"
implicit val timeout = Timeout(1 seconds)
implicit val timeout = Timeout(10 seconds)
val handler = {
get {
path("") {
(path("") & parameters('format ?)) {
case Some(js) if js.equalsIgnoreCase("json") => {
val future = worker ? RequestWorkerState
respondWithMediaType(MediaTypes.`application/json`) { ctx =>
ctx.complete(future.mapTo[WorkerState])
}
}
case _ =>
completeWith{
val future = worker ? RequestWorkerState
future.map { workerState =>
@ -39,5 +53,4 @@ class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Direct
getFromResourceDirectory(RESOURCE_DIR)
}
}
}

View file

@ -30,7 +30,7 @@ private[spark] class Executor extends Logging {
initLogging()
def initialize(slaveHostname: String, properties: Seq[(String, String)]) {
def initialize(executorId: String, slaveHostname: String, properties: Seq[(String, String)]) {
// Make sure the local hostname we report matches the cluster scheduler's name for this host
Utils.setCustomHostname(slaveHostname)
@ -64,7 +64,7 @@ private[spark] class Executor extends Logging {
)
// Initialize Spark environment (using system properties read above)
env = SparkEnv.createFromSystemProperties(slaveHostname, 0, false, false)
env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0, false, false)
SparkEnv.set(env)
// Start worker thread pool
@ -159,23 +159,25 @@ private[spark] class Executor extends Logging {
* SparkContext. Also adds any new JARs we fetched to the class loader.
*/
private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {
synchronized {
// Fetch missing dependencies
for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
Utils.fetchFile(name, new File("."))
Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
currentFiles(name) = timestamp
}
for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
Utils.fetchFile(name, new File("."))
Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
currentJars(name) = timestamp
// Add it to our class loader
val localName = name.split("/").last
val url = new File(".", localName).toURI.toURL
val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL
if (!urlClassLoader.getURLs.contains(url)) {
logInfo("Adding " + url + " to class loader")
urlClassLoader.addURL(url)
}
}
}
}
}
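To make the jar handling above concrete (URL and file name are made up):

//   name      = "http://repo.example.com/libs/my-lib.jar"        // entry from newJars
//   localName = name.split("/").last                             // "my-lib.jar"
//   url       = new File(SparkFiles.getRootDirectory, localName).toURI.toURL
// The jar is fetched into SparkFiles.getRootDirectory and its file:// URL is added to the
// executor's URLClassLoader only if it is not already present.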

View file

@ -29,9 +29,14 @@ private[spark] class MesosExecutorBackend(executor: Executor)
executorInfo: ExecutorInfo,
frameworkInfo: FrameworkInfo,
slaveInfo: SlaveInfo) {
logInfo("Registered with Mesos as executor ID " + executorInfo.getExecutorId.getValue)
this.driver = driver
val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray)
executor.initialize(slaveInfo.getHostname, properties)
executor.initialize(
executorInfo.getExecutorId.getValue,
slaveInfo.getHostname,
properties
)
}
override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo) {

View file

@ -4,78 +4,72 @@ import java.nio.ByteBuffer
import spark.Logging
import spark.TaskState.TaskState
import spark.util.AkkaUtils
import akka.actor.{ActorRef, Actor, Props}
import akka.actor.{ActorRef, Actor, Props, Terminated}
import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClientDisconnected}
import java.util.concurrent.{TimeUnit, ThreadPoolExecutor, SynchronousQueue}
import akka.remote.RemoteClientLifeCycleEvent
import spark.scheduler.cluster._
import spark.scheduler.cluster.RegisteredSlave
import spark.scheduler.cluster.RegisteredExecutor
import spark.scheduler.cluster.LaunchTask
import spark.scheduler.cluster.RegisterSlaveFailed
import spark.scheduler.cluster.RegisterSlave
import spark.scheduler.cluster.RegisterExecutorFailed
import spark.scheduler.cluster.RegisterExecutor
private[spark] class StandaloneExecutorBackend(
executor: Executor,
masterUrl: String,
slaveId: String,
driverUrl: String,
executorId: String,
hostname: String,
cores: Int)
extends Actor
with ExecutorBackend
with Logging {
val threadPool = new ThreadPoolExecutor(
1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])
var master: ActorRef = null
var driver: ActorRef = null
override def preStart() {
try {
logInfo("Connecting to master: " + masterUrl)
master = context.actorFor(masterUrl)
master ! RegisterSlave(slaveId, hostname, cores)
logInfo("Connecting to driver: " + driverUrl)
driver = context.actorFor(driverUrl)
driver ! RegisterExecutor(executorId, hostname, cores)
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
context.watch(master) // Doesn't work with remote actors, but useful for testing
} catch {
case e: Exception =>
logError("Failed to connect to master", e)
System.exit(1)
}
context.watch(driver) // Doesn't work with remote actors, but useful for testing
}
override def receive = {
case RegisteredSlave(sparkProperties) =>
logInfo("Successfully registered with master")
executor.initialize(hostname, sparkProperties)
case RegisteredExecutor(sparkProperties) =>
logInfo("Successfully registered with driver")
executor.initialize(executorId, hostname, sparkProperties)
case RegisterSlaveFailed(message) =>
case RegisterExecutorFailed(message) =>
logError("Slave registration failed: " + message)
System.exit(1)
case LaunchTask(taskDesc) =>
logInfo("Got assigned task " + taskDesc.taskId)
executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask)
case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) =>
logError("Driver terminated or disconnected! Shutting down.")
System.exit(1)
}
override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
master ! StatusUpdate(slaveId, taskId, state, data)
driver ! StatusUpdate(executorId, taskId, state, data)
}
}
private[spark] object StandaloneExecutorBackend {
def run(masterUrl: String, slaveId: String, hostname: String, cores: Int) {
def run(driverUrl: String, executorId: String, hostname: String, cores: Int) {
// Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor
// before getting started with all our system properties, etc
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0)
val actor = actorSystem.actorOf(
Props(new StandaloneExecutorBackend(new Executor, masterUrl, slaveId, hostname, cores)),
Props(new StandaloneExecutorBackend(new Executor, driverUrl, executorId, hostname, cores)),
name = "Executor")
actorSystem.awaitTermination()
}
def main(args: Array[String]) {
if (args.length != 4) {
System.err.println("Usage: StandaloneExecutorBackend <master> <slaveId> <hostname> <cores>")
System.err.println("Usage: StandaloneExecutorBackend <driverUrl> <executorId> <hostname> <cores>")
System.exit(1)
}
run(args(0), args(1), args(2), args(3).toInt)
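// Example command line matching the usage string above (all values illustrative):
//   StandaloneExecutorBackend <akka://... driver actor URL> exec-1 worker-host.example.com 4
// i.e. run(driverUrl, executorId = "exec-1", hostname = "worker-host.example.com", cores = 4)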

View file

@ -12,7 +12,14 @@ import java.net._
private[spark]
abstract class Connection(val channel: SocketChannel, val selector: Selector) extends Logging {
abstract class Connection(val channel: SocketChannel, val selector: Selector,
val remoteConnectionManagerId: ConnectionManagerId) extends Logging {
def this(channel_ : SocketChannel, selector_ : Selector) = {
this(channel_, selector_,
ConnectionManagerId.fromSocketAddress(
channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]
))
}
channel.configureBlocking(false)
channel.socket.setTcpNoDelay(true)
@ -25,7 +32,6 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector) ex
var onKeyInterestChangeCallback: (Connection, Int) => Unit = null
val remoteAddress = getRemoteAddress()
val remoteConnectionManagerId = ConnectionManagerId.fromSocketAddress(remoteAddress)
def key() = channel.keyFor(selector)
@ -103,8 +109,9 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector) ex
}
private[spark] class SendingConnection(val address: InetSocketAddress, selector_ : Selector)
extends Connection(SocketChannel.open, selector_) {
private[spark] class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
remoteId_ : ConnectionManagerId)
extends Connection(SocketChannel.open, selector_, remoteId_) {
class Outbox(fair: Int = 0) {
val messages = new Queue[Message]()
@ -135,8 +142,11 @@ extends Connection(SocketChannel.open, selector_) {
val chunk = message.getChunkForSending(defaultChunkSize)
if (chunk.isDefined) {
messages += message // this is probably incorrect, it wont work as fifo
if (!message.started) logDebug("Starting to send [" + message + "]")
if (!message.started) {
logDebug("Starting to send [" + message + "]")
message.started = true
message.startTime = System.currentTimeMillis
}
return chunk
} else {
/*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/

View file

@ -43,17 +43,16 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
}
val selector = SelectorProvider.provider.openSelector()
val handleMessageExecutor = Executors.newFixedThreadPool(4)
val handleMessageExecutor = Executors.newFixedThreadPool(System.getProperty("spark.core.connection.handler.threads","20").toInt)
val serverChannel = ServerSocketChannel.open()
val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
val messageStatuses = new HashMap[Int, MessageStatus]
val connectionRequests = new SynchronizedQueue[SendingConnection]
val connectionRequests = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
val sendMessageRequests = new Queue[(Message, SendingConnection)]
implicit val futureExecContext = ExecutionContext.fromExecutor(
Executors.newCachedThreadPool(DaemonThreadFactory))
implicit val futureExecContext = ExecutionContext.fromExecutor(Utils.newDaemonCachedThreadPool())
var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null
@ -79,10 +78,10 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
def run() {
try {
while(!selectorThread.isInterrupted) {
while(!connectionRequests.isEmpty) {
val sendingConnection = connectionRequests.dequeue
for( (connectionManagerId, sendingConnection) <- connectionRequests) {
sendingConnection.connect()
addConnection(sendingConnection)
connectionRequests -= connectionManagerId
}
sendMessageRequests.synchronized {
while(!sendMessageRequests.isEmpty) {
@ -300,8 +299,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) {
def startNewConnection(): SendingConnection = {
val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port)
val newConnection = new SendingConnection(inetSocketAddress, selector)
connectionRequests += newConnection
val newConnection = connectionRequests.getOrElseUpdate(connectionManagerId,
new SendingConnection(inetSocketAddress, selector, connectionManagerId))
newConnection
}
val lookupKey = ConnectionManagerId.fromSocketAddress(connectionManagerId.toSocketAddress)
@ -473,6 +472,7 @@ private[spark] object ConnectionManager {
val mb = size * count / 1024.0 / 1024.0
val ms = finishTime - startTime
val tput = mb * 1000.0 / ms
println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)")
println("--------------------------")
println()
}

View file

@ -13,8 +13,14 @@ import akka.util.duration._
private[spark] object ConnectionManagerTest extends Logging{
def main(args: Array[String]) {
//<mesos cluster> - the master URL
//<slaves file> - a list of slaves to run connectionTest on
//[num of tasks] - the number of parallel tasks to be initiated; default is the number of slave hosts
//[size of msg in MB (integer)] - the size of messages to be sent in each task, default is 10
//[count] - how many times to run, default is 3
//[await time in seconds] - await time (in seconds), default is 600
if (args.length < 2) {
println("Usage: ConnectionManagerTest <mesos cluster> <slaves file>")
println("Usage: ConnectionManagerTest <mesos cluster> <slaves file> [num of tasks] [size of msg in MB (integer)] [count] [await time in seconds)] ")
System.exit(1)
}
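// Example invocation (values illustrative): 8 parallel tasks sending 64 MB messages,
// 5 rounds, waiting up to 300 seconds per round:
//   ConnectionManagerTest spark://master-host:7077 /path/to/slaves.txt 8 64 5 300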
@ -29,16 +35,19 @@ private[spark] object ConnectionManagerTest extends Logging{
/*println("Slaves")*/
/*slaves.foreach(println)*/
val slaveConnManagerIds = sc.parallelize(0 until slaves.length, slaves.length).map(
val tasknum = if (args.length > 2) args(2).toInt else slaves.length
val size = ( if (args.length > 3) (args(3).toInt) else 10 ) * 1024 * 1024
val count = if (args.length > 4) args(4).toInt else 3
val awaitTime = (if (args.length > 5) args(5).toInt else 600 ).second
println("Running "+count+" rounds of test: " + "parallel tasks = " + tasknum + ", msg size = " + size/1024/1024 + " MB, awaitTime = " + awaitTime)
val slaveConnManagerIds = sc.parallelize(0 until tasknum, tasknum).map(
i => SparkEnv.get.connectionManager.id).collect()
println("\nSlave ConnectionManagerIds")
slaveConnManagerIds.foreach(println)
println
val count = 10
(0 until count).foreach(i => {
val resultStrs = sc.parallelize(0 until slaves.length, slaves.length).map(i => {
val resultStrs = sc.parallelize(0 until tasknum, tasknum).map(i => {
val connManager = SparkEnv.get.connectionManager
val thisConnManagerId = connManager.id
connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
@ -46,7 +55,6 @@ private[spark] object ConnectionManagerTest extends Logging{
None
})
val size = 100 * 1024 * 1024
val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
buffer.flip
@ -56,13 +64,13 @@ private[spark] object ConnectionManagerTest extends Logging{
logInfo("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]")
connManager.sendMessageReliably(slaveConnManagerId, bufferMessage)
})
val results = futures.map(f => Await.result(f, 1.second))
val results = futures.map(f => Await.result(f, awaitTime))
val finishTime = System.currentTimeMillis
Thread.sleep(5000)
val mb = size * results.size / 1024.0 / 1024.0
val ms = finishTime - startTime
val resultStr = "Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s"
val resultStr = thisConnManagerId + " Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s"
logInfo(resultStr)
resultStr
}).collect()

View file

@ -32,7 +32,7 @@ private[spark] class ApproximateActionListener[T, U, R](
if (finishedTasks == totalTasks) {
// If we had already returned a PartialResult, set its final value
resultObject.foreach(r => r.setFinalValue(evaluator.currentResult()))
// Notify any waiting thread that may have called getResult
// Notify any waiting thread that may have called awaitResult
this.notifyAll()
}
}
@ -49,7 +49,7 @@ private[spark] class ApproximateActionListener[T, U, R](
* Waits for up to timeout milliseconds since the listener was created and then returns a
* PartialResult with the result so far. This may be complete if the whole job is done.
*/
def getResult(): PartialResult[R] = synchronized {
def awaitResult(): PartialResult[R] = synchronized {
val finishTime = startTime + timeout
while (true) {
val time = System.currentTimeMillis()

View file

@ -11,13 +11,11 @@ private[spark]
class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String])
extends RDD[T](sc, Nil) {
@transient
var splits_ : Array[Split] = (0 until blockIds.size).map(i => {
@transient var splits_ : Array[Split] = (0 until blockIds.size).map(i => {
new BlockRDDSplit(blockIds(i), i).asInstanceOf[Split]
}).toArray
@transient
lazy val locations_ = {
@transient lazy val locations_ = {
val blockManager = SparkEnv.get.blockManager
/*val locations = blockIds.map(id => blockManager.getLocations(id))*/
val locations = blockManager.getLocations(blockIds)

View file

@ -1,7 +1,7 @@
package spark.rdd
import java.io.{ObjectOutputStream, IOException}
import spark.{OneToOneDependency, NarrowDependency, RDD, SparkContext, Split, TaskContext}
import spark._
private[spark]
@ -35,8 +35,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
val numSplitsInRdd2 = rdd2.splits.size
@transient
var splits_ = {
override def getSplits: Array[Split] = {
// create the cross product split
val array = new Array[Split](rdd1.splits.size * rdd2.splits.size)
for (s1 <- rdd1.splits; s2 <- rdd2.splits) {
@ -46,8 +45,6 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
array
}
override def getSplits = splits_
override def getPreferredLocations(split: Split) = {
val currSplit = split.asInstanceOf[CartesianSplit]
rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)
@ -59,7 +56,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
y <- rdd2.iterator(currSplit.s2, context)) yield (x, y)
}
var deps_ = List(
override def getDependencies: Seq[Dependency[_]] = List(
new NarrowDependency(rdd1) {
def getParents(id: Int): Seq[Int] = List(id / numSplitsInRdd2)
},
@ -68,11 +65,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
}
)
override def getDependencies = deps_
override def clearDependencies() {
deps_ = Nil
splits_ = null
rdd1 = null
rdd2 = null
}

View file

@ -9,23 +9,26 @@ import org.apache.hadoop.fs.Path
import java.io.{File, IOException, EOFException}
import java.text.NumberFormat
private[spark] class CheckpointRDDSplit(idx: Int, val splitFile: String) extends Split {
override val index: Int = idx
}
private[spark] class CheckpointRDDSplit(val index: Int) extends Split {}
/**
* This RDD represents a RDD checkpoint file (similar to HadoopRDD).
*/
private[spark]
class CheckpointRDD[T: ClassManifest](sc: SparkContext, checkpointPath: String)
class CheckpointRDD[T: ClassManifest](sc: SparkContext, val checkpointPath: String)
extends RDD[T](sc, Nil) {
@transient val path = new Path(checkpointPath)
@transient val fs = path.getFileSystem(new Configuration())
@transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration)
@transient val splits_ : Array[Split] = {
val splitFiles = fs.listStatus(path).map(_.getPath.toString).filter(_.contains("part-")).sorted
splitFiles.zipWithIndex.map(x => new CheckpointRDDSplit(x._2, x._1)).toArray
val dirContents = fs.listStatus(new Path(checkpointPath))
val splitFiles = dirContents.map(_.getPath.toString).filter(_.contains("part-")).sorted
val numSplits = splitFiles.size
if (numSplits > 0 && (!splitFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) ||
!splitFiles(numSplits-1).endsWith(CheckpointRDD.splitIdToFile(numSplits-1)))) {
throw new SparkException("Invalid checkpoint directory: " + checkpointPath)
}
Array.tabulate(numSplits)(i => new CheckpointRDDSplit(i))
}
checkpointData = Some(new RDDCheckpointData[T](this))
@ -34,36 +37,34 @@ class CheckpointRDD[T: ClassManifest](sc: SparkContext, checkpointPath: String)
override def getSplits = splits_
override def getPreferredLocations(split: Split): Seq[String] = {
val status = fs.getFileStatus(path)
val status = fs.getFileStatus(new Path(checkpointPath))
val locations = fs.getFileBlockLocations(status, 0, status.getLen)
locations.firstOption.toList.flatMap(_.getHosts).filter(_ != "localhost")
locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost")
}
override def compute(split: Split, context: TaskContext): Iterator[T] = {
CheckpointRDD.readFromFile(split.asInstanceOf[CheckpointRDDSplit].splitFile, context)
val file = new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index))
CheckpointRDD.readFromFile(file, context)
}
override def checkpoint() {
// Do nothing. Hadoop RDD should not be checkpointed.
// Do nothing. CheckpointRDD should not be checkpointed.
}
}
private[spark] object CheckpointRDD extends Logging {
def splitIdToFileName(splitId: Int): String = {
val numfmt = NumberFormat.getInstance()
numfmt.setMinimumIntegerDigits(5)
numfmt.setGroupingUsed(false)
"part-" + numfmt.format(splitId)
def splitIdToFile(splitId: Int): String = {
"part-%05d".format(splitId)
}
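// Zero-padded naming, for reference:
//   CheckpointRDD.splitIdToFile(0)  == "part-00000"
//   CheckpointRDD.splitIdToFile(42) == "part-00042"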
def writeToFile[T](path: String, blockSize: Int = -1)(context: TaskContext, iterator: Iterator[T]) {
def writeToFile[T](path: String, blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) {
val outputDir = new Path(path)
val fs = outputDir.getFileSystem(new Configuration())
val finalOutputName = splitIdToFileName(context.splitId)
val finalOutputName = splitIdToFile(ctx.splitId)
val finalOutputPath = new Path(outputDir, finalOutputName)
val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + context.attemptId)
val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptId)
if (fs.exists(tempOutputPath)) {
throw new IOException("Checkpoint failed: temporary path " +
@ -83,22 +84,22 @@ private[spark] object CheckpointRDD extends Logging {
serializeStream.close()
if (!fs.rename(tempOutputPath, finalOutputPath)) {
if (!fs.delete(finalOutputPath, true)) {
throw new IOException("Checkpoint failed: failed to delete earlier output of task "
+ context.attemptId)
}
if (!fs.rename(tempOutputPath, finalOutputPath)) {
if (!fs.exists(finalOutputPath)) {
fs.delete(tempOutputPath, false)
throw new IOException("Checkpoint failed: failed to save output of task: "
+ context.attemptId)
+ ctx.attemptId + " and final output path does not exist")
} else {
// Some other copy of this task must've finished before us and renamed it
logInfo("Final output path " + finalOutputPath + " already exists; not overwriting it")
fs.delete(tempOutputPath, false)
}
}
}
def readFromFile[T](path: String, context: TaskContext): Iterator[T] = {
val inputPath = new Path(path)
val fs = inputPath.getFileSystem(new Configuration())
def readFromFile[T](path: Path, context: TaskContext): Iterator[T] = {
val fs = path.getFileSystem(new Configuration())
val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
val fileInputStream = fs.open(inputPath, bufferSize)
val fileInputStream = fs.open(path, bufferSize)
val serializer = SparkEnv.get.serializer.newInstance()
val deserializeStream = serializer.deserializeStream(fileInputStream)

View file

@ -1,9 +1,9 @@
package spark.rdd
import java.io.{ObjectOutputStream, IOException}
import java.util.{HashMap => JHashMap}
import scala.collection.JavaConversions
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import spark.{Aggregator, Logging, Partitioner, RDD, SparkEnv, Split, TaskContext}
import spark.{Dependency, OneToOneDependency, ShuffleDependency}
@ -45,8 +45,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner)
val aggr = new CoGroupAggregator
@transient
var deps_ = {
@transient var deps_ = {
val deps = new ArrayBuffer[Dependency[_]]
for ((rdd, index) <- rdds.zipWithIndex) {
if (rdd.partitioner == Some(part)) {
@ -63,8 +62,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner)
override def getDependencies = deps_
@transient
var splits_ : Array[Split] = {
@transient var splits_ : Array[Split] = {
val array = new Array[Split](part.numPartitions)
for (i <- 0 until array.size) {
array(i) = new CoGroupSplit(i, rdds.zipWithIndex.map { case (r, j) =>
@ -86,9 +84,17 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner)
override def compute(s: Split, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = {
val split = s.asInstanceOf[CoGroupSplit]
val numRdds = split.deps.size
val map = new HashMap[K, Seq[ArrayBuffer[Any]]]
// e.g. for `(k, a) cogroup (k, b)`, K -> Seq(ArrayBuffer as, ArrayBuffer bs)
val map = new JHashMap[K, Seq[ArrayBuffer[Any]]]
def getSeq(k: K): Seq[ArrayBuffer[Any]] = {
map.getOrElseUpdate(k, Array.fill(numRdds)(new ArrayBuffer[Any]))
val seq = map.get(k)
if (seq != null) {
seq
} else {
val seq = Array.fill(numRdds)(new ArrayBuffer[Any])
map.put(k, seq)
seq
}
}
for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
case NarrowCoGroupSplitDep(rdd, itsSplitIndex, itsSplit) => {
@ -99,16 +105,13 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner)
}
case ShuffleCoGroupSplitDep(shuffleId) => {
// Read map outputs of shuffle
def mergePair(pair: (K, Seq[Any])) {
val mySeq = getSeq(pair._1)
for (v <- pair._2)
mySeq(depNum) += v
}
val fetcher = SparkEnv.get.shuffleFetcher
fetcher.fetch[K, Seq[Any]](shuffleId, split.index).foreach(mergePair)
for ((k, vs) <- fetcher.fetch[K, Seq[Any]](shuffleId, split.index)) {
getSeq(k)(depNum) ++= vs
}
}
map.iterator
}
JavaConversions.mapAsScalaMap(map).iterator
}
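// Illustration of the per-key structure built above: cogrouping two pair RDDs such as
//   rdd1 = (("a", 1), ("a", 2))  and  rdd2 = (("a", "x"))
// leaves map("a") == Seq(ArrayBuffer(1, 2), ArrayBuffer("x")) -- one buffer per parent RDD,
// indexed by depNum.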
override def clearDependencies() {

View file

@ -27,11 +27,11 @@ private[spark] case class CoalescedRDDSplit(
* or to avoid having a large number of small tasks when processing a directory with many files.
*/
class CoalescedRDD[T: ClassManifest](
var prev: RDD[T],
@transient var prev: RDD[T],
maxPartitions: Int)
extends RDD[T](prev.context, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs
extends RDD[T](prev.context, Nil) { // Nil since we implement getDependencies
@transient var splits_ : Array[Split] = {
override def getSplits: Array[Split] = {
val prevSplits = prev.splits
if (prevSplits.length < maxPartitions) {
prevSplits.map(_.index).map{idx => new CoalescedRDDSplit(idx, prev, Array(idx)) }
@ -44,26 +44,20 @@ class CoalescedRDD[T: ClassManifest](
}
}
override def getSplits = splits_
override def compute(split: Split, context: TaskContext): Iterator[T] = {
split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap { parentSplit =>
firstParent[T].iterator(parentSplit, context)
}
}
var deps_ : List[Dependency[_]] = List(
override def getDependencies: Seq[Dependency[_]] = List(
new NarrowDependency(prev) {
def getParents(id: Int): Seq[Int] =
splits(id).asInstanceOf[CoalescedRDDSplit].parentsIndices
}
)
override def getDependencies() = deps_
override def clearDependencies() {
deps_ = Nil
splits_ = null
prev = null
}
}

View file

@ -9,6 +9,8 @@ private[spark] class FilteredRDD[T: ClassManifest](
override def getSplits = firstParent[T].splits
override val partitioner = prev.partitioner // Since filter cannot change a partition's keys
override def compute(split: Split, context: TaskContext) =
firstParent[T].iterator(split, context).filter(f)
}

View file

@ -3,9 +3,7 @@ package spark.rdd
import spark.{RDD, Split, TaskContext}
private[spark]
class MappedRDD[U: ClassManifest, T: ClassManifest](
prev: RDD[T],
f: T => U)
class MappedRDD[U: ClassManifest, T: ClassManifest](prev: RDD[T], f: T => U)
extends RDD[U](prev) {
override def getSplits = firstParent[T].splits

View file

@ -37,11 +37,9 @@ class NewHadoopRDD[K, V](
formatter.format(new Date())
}
@transient
private val jobId = new JobID(jobtrackerId, id)
@transient private val jobId = new JobID(jobtrackerId, id)
@transient
private val splits_ : Array[Split] = {
@transient private val splits_ : Array[Split] = {
val inputFormat = inputFormatClass.newInstance
val jobContext = newJobContext(conf, jobId)
val rawSplits = inputFormat.getSplits(jobContext).toArray

View file

@ -0,0 +1,42 @@
package spark.rdd
import spark.{NarrowDependency, RDD, SparkEnv, Split, TaskContext}
class PartitionPruningRDDSplit(idx: Int, val parentSplit: Split) extends Split {
override val index = idx
}
/**
* Represents a dependency between the PartitionPruningRDD and its parent. In this
 * case, the child RDD contains a subset of the parent's partitions.
*/
class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boolean)
extends NarrowDependency[T](rdd) {
@transient
val partitions: Array[Split] = rdd.splits.filter(s => partitionFilterFunc(s.index))
.zipWithIndex.map { case(split, idx) => new PartitionPruningRDDSplit(idx, split) : Split }
override def getParents(partitionId: Int) = List(partitions(partitionId).index)
}
/**
 * An RDD used to prune RDD partitions/splits so we can avoid launching tasks on
* all partitions. An example use case: If we know the RDD is partitioned by range,
* and the execution DAG has a filter on the key, we can avoid launching tasks
* on partitions that don't have the range covering the key.
*/
class PartitionPruningRDD[T: ClassManifest](
@transient prev: RDD[T],
@transient partitionFilterFunc: Int => Boolean)
extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) {
override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(
split.asInstanceOf[PartitionPruningRDDSplit].parentSplit, context)
override protected def getSplits =
getDependencies.head.asInstanceOf[PruneDependency[T]].partitions
}
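A minimal usage sketch, assuming an existing SparkContext `sc` (names and numbers illustrative): keep only the even-numbered partitions without scheduling tasks on the rest.

import spark.rdd.PartitionPruningRDD

val data   = sc.parallelize(1 to 100, 10)
val pruned = new PartitionPruningRDD(data, partitionIndex => partitionIndex % 2 == 0)
// pruned exposes 5 splits, each backed by one even-indexed split of `data`.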

View file

@ -19,13 +19,12 @@ class SampledRDD[T: ClassManifest](
seed: Int)
extends RDD[T](prev) {
@transient
var splits_ : Array[Split] = {
@transient var splits_ : Array[Split] = {
val rg = new Random(seed)
firstParent[T].splits.map(x => new SampledRDDSplit(x, rg.nextInt))
}
override def getSplits = splits_.asInstanceOf[Array[Split]]
override def getSplits = splits_
override def getPreferredLocations(split: Split) =
firstParent[T].preferredLocations(split.asInstanceOf[SampledRDDSplit].prev)

View file

@ -22,17 +22,10 @@ class ShuffledRDD[K, V](
override val partitioner = Some(part)
@transient
var splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i))
override def getSplits = splits_
override def getSplits = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i))
override def compute(split: Split, context: TaskContext): Iterator[(K, V)] = {
val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index)
}
override def clearDependencies() {
splits_ = null
}
}

View file

@ -26,10 +26,9 @@ private[spark] class UnionSplit[T: ClassManifest](idx: Int, rdd: RDD[T], splitIn
class UnionRDD[T: ClassManifest](
sc: SparkContext,
@transient var rdds: Seq[RDD[T]])
extends RDD[T](sc, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs
extends RDD[T](sc, Nil) { // Nil since we implement getDependencies
@transient
var splits_ : Array[Split] = {
override def getSplits: Array[Split] = {
val array = new Array[Split](rdds.map(_.splits.size).sum)
var pos = 0
for (rdd <- rdds; split <- rdd.splits) {
@ -39,20 +38,16 @@ class UnionRDD[T: ClassManifest](
array
}
override def getSplits = splits_
@transient var deps_ = {
override def getDependencies: Seq[Dependency[_]] = {
val deps = new ArrayBuffer[Dependency[_]]
var pos = 0
for (rdd <- rdds) {
deps += new RangeDependency(rdd, 0, pos, rdd.splits.size)
pos += rdd.splits.size
}
deps.toList
deps
}
override def getDependencies = deps_
override def compute(s: Split, context: TaskContext): Iterator[T] =
s.asInstanceOf[UnionSplit[T]].iterator(context)
@ -60,8 +55,6 @@ class UnionRDD[T: ClassManifest](
s.asInstanceOf[UnionSplit[T]].preferredLocations()
override def clearDependencies() {
deps_ = null
splits_ = null
rdds = null
}
}

View file

@ -32,10 +32,7 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest](
extends RDD[(T, U)](sc, List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2)))
with Serializable {
// TODO: FIX THIS.
@transient
var splits_ : Array[Split] = {
override def getSplits: Array[Split] = {
if (rdd1.splits.size != rdd2.splits.size) {
throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions")
}
@ -46,8 +43,6 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest](
array
}
override def getSplits = splits_
override def compute(s: Split, context: TaskContext): Iterator[(T, U)] = {
val (split1, split2) = s.asInstanceOf[ZippedSplit[T, U]].splits
rdd1.iterator(split1, context).zip(rdd2.iterator(split2, context))
@ -59,7 +54,6 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest](
}
override def clearDependencies() {
splits_ = null
rdd1 = null
rdd2 = null
}

View file

@ -23,7 +23,16 @@ import util.{MetadataCleaner, TimeStampedHashMap}
* and to report fetch failures (the submitTasks method, and code to add CompletionEvents).
*/
private[spark]
class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with Logging {
class DAGScheduler(
taskSched: TaskScheduler,
mapOutputTracker: MapOutputTracker,
blockManagerMaster: BlockManagerMaster,
env: SparkEnv)
extends TaskSchedulerListener with Logging {
def this(taskSched: TaskScheduler) {
this(taskSched, SparkEnv.get.mapOutputTracker, SparkEnv.get.blockManager.master, SparkEnv.get)
}
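// With the explicit constructor above, a test can (illustratively) inject its own
// collaborators instead of pulling them from SparkEnv.get, e.g.
//   new DAGScheduler(mockTaskScheduler, mockMapOutputTracker, mockBlockManagerMaster, testEnv)
// while production code keeps using the one-argument convenience constructor.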
taskSched.setListener(this)
// Called by TaskScheduler to report task completions or failures.
@ -35,12 +44,12 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
eventQueue.put(CompletionEvent(task, reason, result, accumUpdates))
}
// Called by TaskScheduler when a host fails.
override def hostLost(host: String) {
eventQueue.put(HostLost(host))
// Called by TaskScheduler when an executor fails.
override def executorLost(execId: String) {
eventQueue.put(ExecutorLost(execId))
}
// Called by TaskScheduler to cancel an entier TaskSet due to repeated failures.
// Called by TaskScheduler to cancel an entire TaskSet due to repeated failures.
override def taskSetFailed(taskSet: TaskSet, reason: String) {
eventQueue.put(TaskSetFailed(taskSet, reason))
}
@ -54,8 +63,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
// resubmit failed stages
val POLL_TIMEOUT = 10L
private val lock = new Object // Used for access to the entire DAGScheduler
private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent]
val nextRunId = new AtomicInteger(0)
@ -68,12 +75,13 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
var cacheLocs = new HashMap[Int, Array[List[String]]]
val env = SparkEnv.get
val cacheTracker = env.cacheTracker
val mapOutputTracker = env.mapOutputTracker
val deadHosts = new HashSet[String] // TODO: The code currently assumes these can't come back;
// that's not going to be a realistic assumption in general
// For tracking failed nodes, we use the MapOutputTracker's generation number, which is
// sent with every task. When we detect a node failing, we note the current generation number
// and failed executor, increment it for new tasks, and use this to ignore stray ShuffleMapTask
// results.
// TODO: Garbage collect information about failure generations when we know there are no more
// stray messages to detect.
val failedGeneration = new HashMap[String, Long]
val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done
val running = new HashSet[Stage] // Stages we are running right now
@ -87,19 +95,27 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup)
// Start a thread to run the DAGScheduler event loop
def start() {
new Thread("DAGScheduler") {
setDaemon(true)
override def run() {
DAGScheduler.this.run()
}
}.start()
}
def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
private def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
if (!cacheLocs.contains(rdd.id)) {
val blockIds = rdd.splits.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray
cacheLocs(rdd.id) = blockManagerMaster.getLocations(blockIds).map {
locations => locations.map(_.ip).toList
}.toArray
}
cacheLocs(rdd.id)
}
def updateCacheLocs() {
cacheLocs = cacheTracker.getLocationsSnapshot()
private def clearCacheLocs() {
cacheLocs.clear()
}
/**
@ -107,7 +123,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* The priority value passed in will be used if the stage doesn't already exist with
* a lower priority (we assume that priorities always increase across jobs for now).
*/
def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], priority: Int): Stage = {
private def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], priority: Int): Stage = {
shuffleToMapStage.get(shuffleDep.shuffleId) match {
case Some(stage) => stage
case None =>
@ -122,12 +138,11 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* as a result stage for the final RDD used directly in an action. The stage will also be given
* the provided priority.
*/
def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_]], priority: Int): Stage = {
private def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_]], priority: Int): Stage = {
if (shuffleDep != None) {
// Kind of ugly: need to register RDDs with the cache and map output tracker here
// since we can't do it in the RDD constructor because # of splits is unknown
logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")")
cacheTracker.registerRDD(rdd.id, rdd.splits.size)
if (shuffleDep != None) {
mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size)
}
val id = nextStageId.getAndIncrement()
@ -140,7 +155,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* Get or create the list of parent stages for a given RDD. The stages will be assigned the
* provided priority if they haven't already been created with a lower priority.
*/
def getParentStages(rdd: RDD[_], priority: Int): List[Stage] = {
private def getParentStages(rdd: RDD[_], priority: Int): List[Stage] = {
val parents = new HashSet[Stage]
val visited = new HashSet[RDD[_]]
def visit(r: RDD[_]) {
@ -148,8 +163,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
visited += r
// Kind of ugly: need to register RDDs with the cache here since
// we can't do it in its constructor because # of splits is unknown
logInfo("Registering parent RDD " + r.id + " (" + r.origin + ")")
cacheTracker.registerRDD(r.id, r.splits.size)
for (dep <- r.dependencies) {
dep match {
case shufDep: ShuffleDependency[_,_] =>
@ -164,15 +177,13 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
parents.toList
}
def getMissingParentStages(stage: Stage): List[Stage] = {
private def getMissingParentStages(stage: Stage): List[Stage] = {
val missing = new HashSet[Stage]
val visited = new HashSet[RDD[_]]
def visit(rdd: RDD[_]) {
if (!visited(rdd)) {
visited += rdd
val locs = getCacheLocs(rdd)
for (p <- 0 until rdd.splits.size) {
if (locs(p) == Nil) {
if (getCacheLocs(rdd).contains(Nil)) {
for (dep <- rdd.dependencies) {
dep match {
case shufDep: ShuffleDependency[_,_] =>
@ -187,28 +198,49 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
}
}
}
}
visit(stage.rdd)
missing.toList
}
/**
* Returns (and does not submit) a JobSubmitted event suitable to run a given job, and a
 * JobWaiter whose awaitResult() method will return the result of the job when it is complete.
*
* The job is assumed to have at least one partition; zero partition jobs should be handled
* without a JobSubmitted event.
*/
private[scheduler] def prepareJob[T, U: ClassManifest](
finalRdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
callSite: String,
allowLocal: Boolean,
resultHandler: (Int, U) => Unit)
: (JobSubmitted, JobWaiter[U]) =
{
assert(partitions.size > 0)
val waiter = new JobWaiter(partitions.size, resultHandler)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter)
return (toSubmit, waiter)
}
def runJob[T, U: ClassManifest](
finalRdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
callSite: String,
allowLocal: Boolean)
: Array[U] =
allowLocal: Boolean,
resultHandler: (Int, U) => Unit)
{
if (partitions.size == 0) {
return new Array[U](0)
return
}
val waiter = new JobWaiter(partitions.size)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
eventQueue.put(JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter))
waiter.getResult() match {
case JobSucceeded(results: Seq[_]) =>
return results.asInstanceOf[Seq[U]].toArray
val (toSubmit, waiter) = prepareJob(
finalRdd, func, partitions, callSite, allowLocal, resultHandler)
eventQueue.put(toSubmit)
waiter.awaitResult() match {
case JobSucceeded => {}
case JobFailed(exception: Exception) =>
logInfo("Failed to run " + callSite)
throw exception
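// Editor's sketch (illustrative, not part of this commit): a caller that still wants an
// Array[U] back from the callback-based runJob above can pass a handler that fills a
// pre-allocated array. The helper name runJobToArray is hypothetical.
private def runJobToArray[T, U: ClassManifest](
    finalRdd: RDD[T],
    func: (TaskContext, Iterator[T]) => U,
    partitions: Seq[Int],
    callSite: String,
    allowLocal: Boolean): Array[U] = {
  val results = new Array[U](partitions.size)
  runJob(finalRdd, func, partitions, callSite, allowLocal,
    (index: Int, result: U) => results(index) = result)
  results
}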
@ -227,32 +259,22 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val partitions = (0 until rdd.splits.size).toArray
eventQueue.put(JobSubmitted(rdd, func2, partitions, false, callSite, listener))
return listener.getResult() // Will throw an exception if the job fails
return listener.awaitResult() // Will throw an exception if the job fails
}
/**
* The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure
* events and responds by launching tasks. This runs in a dedicated thread and receives events
* via the eventQueue.
* Process one event retrieved from the event queue.
* Returns true if we should stop the event loop.
*/
def run() {
SparkEnv.set(env)
while (true) {
val event = eventQueue.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS)
val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability
if (event != null) {
logDebug("Got event of type " + event.getClass.getName)
}
private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = {
event match {
case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener) =>
val runId = nextRunId.getAndIncrement()
val finalStage = newStage(finalRDD, None, runId)
val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener)
updateCacheLocs()
clearCacheLocs()
logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length +
" output partitions")
" output partitions (allowLocal=" + allowLocal + ")")
logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")")
logInfo("Parents of final stage: " + finalStage.parents)
logInfo("Missing parents: " + getMissingParentStages(finalStage))
@ -265,8 +287,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
submitStage(finalStage)
}
case HostLost(host) =>
handleHostLost(host)
case ExecutorLost(execId) =>
handleExecutorLost(execId)
case completion: CompletionEvent =>
handleTaskCompletion(completion)
@ -280,37 +302,74 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val error = new SparkException("Job cancelled because SparkContext was shut down")
job.listener.jobFailed(error)
}
return
case null =>
// queue.poll() timed out, ignore it
return true
}
return false
}
/**
* Resubmit any failed stages. Ordinarily called after a small amount of time has passed since
* the last fetch failure.
*/
private[scheduler] def resubmitFailedStages() {
logInfo("Resubmitting failed stages")
clearCacheLocs()
val failed2 = failed.toArray
failed.clear()
for (stage <- failed2.sortBy(_.priority)) {
submitStage(stage)
}
}
/**
* Check for waiting or failed stages which are now eligible for resubmission.
* Ordinarily run on every iteration of the event loop.
*/
private[scheduler] def submitWaitingStages() {
// TODO: We might want to run this less often, when we are sure that something has become
// runnable that wasn't before.
logTrace("Checking for newly runnable parent stages")
logTrace("running: " + running)
logTrace("waiting: " + waiting)
logTrace("failed: " + failed)
val waiting2 = waiting.toArray
waiting.clear()
for (stage <- waiting2.sortBy(_.priority)) {
submitStage(stage)
}
}
/**
* The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure
* events and responds by launching tasks. This runs in a dedicated thread and receives events
* via the eventQueue.
*/
private def run() {
SparkEnv.set(env)
while (true) {
val event = eventQueue.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS)
if (event != null) {
logDebug("Got event of type " + event.getClass.getName)
if (processEvent(event)) {
return
}
}
val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability
// Periodically resubmit failed stages if some map output fetches have failed and we have
// waited at least RESUBMIT_TIMEOUT. We wait for this short time because when a node fails,
// tasks on many other nodes are bound to get a fetch failure, and they won't all get it at
// the same time, so we want to make sure we've identified all the reduce tasks that depend
// on the failed node.
if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) {
logInfo("Resubmitting failed stages")
updateCacheLocs()
val failed2 = failed.toArray
failed.clear()
for (stage <- failed2.sortBy(_.priority)) {
submitStage(stage)
}
resubmitFailedStages()
} else {
// TODO: We might want to run this less often, when we are sure that something has become
// runnable that wasn't before.
logDebug("Checking for newly runnable parent stages")
logDebug("running: " + running)
logDebug("waiting: " + waiting)
logDebug("failed: " + failed)
val waiting2 = waiting.toArray
waiting.clear()
for (stage <- waiting2.sortBy(_.priority)) {
submitStage(stage)
}
submitWaitingStages()
}
}
}
@ -320,7 +379,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* We run the operation in a separate thread just in case it takes a bunch of time, so that we
* don't block the DAGScheduler event loop or other concurrent jobs.
*/
def runLocally(job: ActiveJob) {
private def runLocally(job: ActiveJob) {
logInfo("Computing the requested partition locally")
new Thread("Local computation of job " + job.runId) {
override def run() {
@ -329,9 +388,12 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val rdd = job.finalStage.rdd
val split = rdd.splits(job.partitions(0))
val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0)
try {
val result = job.func(taskContext, rdd.iterator(split, taskContext))
taskContext.executeOnCompleteCallbacks()
job.listener.taskSucceeded(0, result)
} finally {
taskContext.executeOnCompleteCallbacks()
}
} catch {
case e: Exception =>
job.listener.jobFailed(e)
@ -340,13 +402,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
}.start()
}
def submitStage(stage: Stage) {
/** Submits stage, but first recursively submits any missing parents. */
private def submitStage(stage: Stage) {
logDebug("submitStage(" + stage + ")")
if (!waiting(stage) && !running(stage) && !failed(stage)) {
val missing = getMissingParentStages(stage).sortBy(_.id)
logDebug("missing: " + missing)
if (missing == Nil) {
logInfo("Submitting " + stage + " (" + stage.origin + "), which has no missing parents")
logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
submitMissingTasks(stage)
running += stage
} else {
@ -358,7 +421,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
}
}
def submitMissingTasks(stage: Stage) {
/** Called when a stage's parents are available and we can now run its tasks. */
private def submitMissingTasks(stage: Stage) {
logDebug("submitMissingTasks(" + stage + ")")
// Get our pending tasks and remember them in our pendingTasks entry
val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
@ -379,11 +443,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
}
}
if (tasks.size > 0) {
logInfo("Submitting " + tasks.size + " missing tasks from " + stage)
logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
myPending ++= tasks
logDebug("New pending tasks: " + myPending)
taskSched.submitTasks(
new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.priority))
if (!stage.submissionTime.isDefined) {
stage.submissionTime = Some(System.currentTimeMillis())
}
} else {
logDebug("Stage " + stage + " is actually done; %b %d %d".format(
stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
@ -395,9 +462,18 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* Responds to a task finishing. This is called inside the event loop so it assumes that it can
* modify the scheduler's internal state. Use taskEnded() to post a task end event from outside.
*/
def handleTaskCompletion(event: CompletionEvent) {
private def handleTaskCompletion(event: CompletionEvent) {
val task = event.task
val stage = idToStage(task.stageId)
def markStageAsFinished(stage: Stage) = {
val serviceTime = stage.submissionTime match {
case Some(t) => "%.03f".format((System.currentTimeMillis() - t) / 1000.0)
case _ => "Unkown"
}
logInfo("%s (%s) finished in %s s".format(stage, stage.origin, serviceTime))
running -= stage
}
event.reason match {
case Success =>
logInfo("Completed " + task)
@ -412,13 +488,13 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
if (!job.finished(rt.outputId)) {
job.finished(rt.outputId) = true
job.numFinished += 1
job.listener.taskSucceeded(rt.outputId, event.result)
// If the whole job has finished, remove it
if (job.numFinished == job.numPartitions) {
activeJobs -= job
resultStageToJob -= stage
running -= stage
markStageAsFinished(stage)
}
job.listener.taskSucceeded(rt.outputId, event.result)
}
case None =>
logInfo("Ignoring result from " + rt + " because its job has finished")
@ -427,23 +503,32 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
case smt: ShuffleMapTask =>
val stage = idToStage(smt.stageId)
val status = event.result.asInstanceOf[MapStatus]
val host = status.address.ip
logInfo("ShuffleMapTask finished with host " + host)
if (!deadHosts.contains(host)) { // TODO: Make sure hostnames are consistent with Mesos
val execId = status.location.executorId
logDebug("ShuffleMapTask finished on " + execId)
if (failedGeneration.contains(execId) && smt.generation <= failedGeneration(execId)) {
logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId)
} else {
stage.addOutputLoc(smt.partition, status)
}
if (running.contains(stage) && pendingTasks(stage).isEmpty) {
logInfo(stage + " (" + stage.origin + ") finished; looking for newly runnable stages")
running -= stage
markStageAsFinished(stage)
logInfo("looking for newly runnable stages")
logInfo("running: " + running)
logInfo("waiting: " + waiting)
logInfo("failed: " + failed)
if (stage.shuffleDep != None) {
// We supply true to increment the generation number here in case this is a
// recomputation of the map outputs. In that case, some nodes may have cached
// locations with holes (from when we detected the error) and will need the
// generation incremented to refetch them.
// TODO: Only increment the generation number if this is not the first time
// we registered these map outputs.
mapOutputTracker.registerMapOutputs(
stage.shuffleDep.get.shuffleId,
stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray)
stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray,
true)
}
updateCacheLocs()
clearCacheLocs()
if (stage.outputLocs.count(_ == Nil) != 0) {
// Some tasks had failed; let's resubmit this stage
// TODO: Lower-level scheduler should also deal with this
@ -462,7 +547,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
waiting --= newlyRunnable
running ++= newlyRunnable
for (stage <- newlyRunnable.sortBy(_.id)) {
logInfo("Submitting " + stage + " (" + stage.origin + "), which is now runnable")
logInfo("Submitting " + stage + " (" + stage.rdd + "), which is now runnable")
submitMissingTasks(stage)
}
}
@ -493,9 +578,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
// Remember that a fetch failed now; this is used to resubmit the broken
// stages later, after a small wait (to give other tasks the chance to fail)
lastFetchFailureTime = System.currentTimeMillis() // TODO: Use pluggable clock
// TODO: mark the host as failed only if there were lots of fetch failures on it
// TODO: mark the executor as failed only if there were lots of fetch failures on it
if (bmAddress != null) {
handleHostLost(bmAddress.ip)
handleExecutorLost(bmAddress.executorId, Some(task.generation))
}
case other =>
@ -505,22 +590,31 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
}
/**
* Responds to a host being lost. This is called inside the event loop so it assumes that it can
* modify the scheduler's internal state. Use hostLost() to post a host lost event from outside.
* Responds to an executor being lost. This is called inside the event loop, so it assumes it can
* modify the scheduler's internal state. Use executorLost() to post a loss event from outside.
*
* Optionally the generation during which the failure was caught can be passed to avoid allowing
* stray fetch failures from possibly retriggering the detection of a node as lost.
*/
def handleHostLost(host: String) {
if (!deadHosts.contains(host)) {
logInfo("Host lost: " + host)
deadHosts += host
env.blockManager.master.notifyADeadHost(host)
private def handleExecutorLost(execId: String, maybeGeneration: Option[Long] = None) {
val currentGeneration = maybeGeneration.getOrElse(mapOutputTracker.getGeneration)
if (!failedGeneration.contains(execId) || failedGeneration(execId) < currentGeneration) {
failedGeneration(execId) = currentGeneration
logInfo("Executor lost: %s (generation %d)".format(execId, currentGeneration))
blockManagerMaster.removeExecutor(execId)
// TODO: This will be really slow if we keep accumulating shuffle map stages
for ((shuffleId, stage) <- shuffleToMapStage) {
stage.removeOutputsOnHost(host)
stage.removeOutputsOnExecutor(execId)
val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray
mapOutputTracker.registerMapOutputs(shuffleId, locs, true)
}
cacheTracker.cacheLost(host)
updateCacheLocs()
if (shuffleToMapStage.isEmpty) {
mapOutputTracker.incrementGeneration()
}
clearCacheLocs()
} else {
logDebug("Additional executor lost message for " + execId +
"(generation " + currentGeneration + ")")
}
}
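// Editor's note (worked example, not in the diff): the generation check above de-duplicates
// loss reports. If "exec-1" was first marked lost at generation 5, failedGeneration("exec-1")
// is 5; a later stray fetch failure carrying Some(3L) fails the
// failedGeneration(execId) < currentGeneration test, so it only logs the debug message and
// leaves the scheduler state untouched.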
@ -528,7 +622,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* Aborts all jobs depending on a particular Stage. This is called in response to a task set
* being cancelled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
*/
def abortStage(failedStage: Stage, reason: String) {
private def abortStage(failedStage: Stage, reason: String) {
val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
for (resultStage <- dependentStages) {
val job = resultStageToJob(resultStage)
@ -544,7 +638,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
/**
* Return true if one of stage's ancestors is target.
*/
def stageDependsOn(stage: Stage, target: Stage): Boolean = {
private def stageDependsOn(stage: Stage, target: Stage): Boolean = {
if (stage == target) {
return true
}
@ -571,7 +665,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
visitedRdds.contains(target.rdd)
}
def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = {
private def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = {
// If the partition is cached, return the cache locations
val cached = getCacheLocs(rdd)(partition)
if (cached != Nil) {
@ -597,7 +691,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
return Nil
}
def cleanup(cleanupTime: Long) {
private def cleanup(cleanupTime: Long) {
var sizeBefore = idToStage.size
idToStage.clearOldValues(cleanupTime)
logInfo("idToStage " + sizeBefore + " --> " + idToStage.size)

View file

@ -28,7 +28,7 @@ private[spark] case class CompletionEvent(
accumUpdates: Map[Long, Any])
extends DAGSchedulerEvent
private[spark] case class HostLost(host: String) extends DAGSchedulerEvent
private[spark] case class ExecutorLost(execId: String) extends DAGSchedulerEvent
private[spark] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent

View file

@ -5,5 +5,5 @@ package spark.scheduler
*/
private[spark] sealed trait JobResult
private[spark] case class JobSucceeded(results: Seq[_]) extends JobResult
private[spark] case object JobSucceeded extends JobResult
private[spark] case class JobFailed(exception: Exception) extends JobResult

View file

@ -3,10 +3,12 @@ package spark.scheduler
import scala.collection.mutable.ArrayBuffer
/**
* An object that waits for a DAGScheduler job to complete.
* An object that waits for a DAGScheduler job to complete. As tasks finish, it passes their
* results to the given handler function.
*/
private[spark] class JobWaiter(totalTasks: Int) extends JobListener {
private val taskResults = ArrayBuffer.fill[Any](totalTasks)(null)
private[spark] class JobWaiter[T](totalTasks: Int, resultHandler: (Int, T) => Unit)
extends JobListener {
private var finishedTasks = 0
private var jobFinished = false // Is the job as a whole finished (succeeded or failed)?
@ -17,11 +19,11 @@ private[spark] class JobWaiter(totalTasks: Int) extends JobListener {
if (jobFinished) {
throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter")
}
taskResults(index) = result
resultHandler(index, result.asInstanceOf[T])
finishedTasks += 1
if (finishedTasks == totalTasks) {
jobFinished = true
jobResult = JobSucceeded(taskResults)
jobResult = JobSucceeded
this.notifyAll()
}
}
@ -38,7 +40,7 @@ private[spark] class JobWaiter(totalTasks: Int) extends JobListener {
}
}
def getResult(): JobResult = synchronized {
def awaitResult(): JobResult = synchronized {
while (!jobFinished) {
this.wait()
}
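// Editor's sketch (illustrative, not part of this commit): with the new constructor, each
// finished task's result is pushed straight into the handler, and awaitResult() only reports
// overall success or failure. The four-partition job below is hypothetical.
val partitionResults = new Array[Int](4)
val waiter = new JobWaiter[Int](4, (index, value) => partitionResults(index) = value)
// ... pass `waiter` as the JobListener when submitting the job, then block on it:
waiter.awaitResult() match {
  case JobSucceeded => partitionResults.foreach(println)
  case JobFailed(exception: Exception) => throw exception
}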

View file

@ -8,19 +8,19 @@ import java.io.{ObjectOutput, ObjectInput, Externalizable}
* task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks.
* The map output sizes are compressed using MapOutputTracker.compressSize.
*/
private[spark] class MapStatus(var address: BlockManagerId, var compressedSizes: Array[Byte])
private[spark] class MapStatus(var location: BlockManagerId, var compressedSizes: Array[Byte])
extends Externalizable {
def this() = this(null, null) // For deserialization only
def writeExternal(out: ObjectOutput) {
address.writeExternal(out)
location.writeExternal(out)
out.writeInt(compressedSizes.length)
out.write(compressedSizes)
}
def readExternal(in: ObjectInput) {
address = new BlockManagerId(in)
location = BlockManagerId(in)
compressedSizes = new Array[Byte](in.readInt())
in.readFully(compressedSizes)
}

View file

@ -72,9 +72,11 @@ private[spark] class ResultTask[T, U](
override def run(attemptId: Long): U = {
val context = new TaskContext(stageId, partition, attemptId)
val result = func(context, rdd.iterator(split, context))
try {
func(context, rdd.iterator(split, context))
} finally {
context.executeOnCompleteCallbacks()
result
}
}
override def preferredLocations: Seq[String] = locs

View file

@ -32,7 +32,7 @@ private[spark] object ShuffleMapTask {
return old
} else {
val out = new ByteArrayOutputStream
val ser = SparkEnv.get.closureSerializer.newInstance
val ser = SparkEnv.get.closureSerializer.newInstance()
val objOut = ser.serializeStream(new GZIPOutputStream(out))
objOut.writeObject(rdd)
objOut.writeObject(dep)
@ -48,7 +48,7 @@ private[spark] object ShuffleMapTask {
synchronized {
val loader = Thread.currentThread.getContextClassLoader
val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
val ser = SparkEnv.get.closureSerializer.newInstance
val ser = SparkEnv.get.closureSerializer.newInstance()
val objIn = ser.deserializeStream(in)
val rdd = objIn.readObject().asInstanceOf[RDD[_]]
val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]]
@ -81,7 +81,7 @@ private[spark] class ShuffleMapTask(
with Externalizable
with Logging {
def this() = this(0, null, null, 0, null)
protected def this() = this(0, null, null, 0, null)
var split = if (rdd == null) {
null
@ -117,18 +117,16 @@ private[spark] class ShuffleMapTask(
override def run(attemptId: Long): MapStatus = {
val numOutputSplits = dep.partitioner.numPartitions
val partitioner = dep.partitioner
val taskContext = new TaskContext(stageId, partition, attemptId)
try {
// Partition the map output.
val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)])
for (elem <- rdd.iterator(split, taskContext)) {
val pair = elem.asInstanceOf[(Any, Any)]
val bucketId = partitioner.getPartition(pair._1)
val bucketId = dep.partitioner.getPartition(pair._1)
buckets(bucketId) += pair
}
val bucketIterators = buckets.map(_.iterator)
val compressedSizes = new Array[Byte](numOutputSplits)
@ -136,15 +134,16 @@ private[spark] class ShuffleMapTask(
for (i <- 0 until numOutputSplits) {
val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i
// Get a Scala iterator from Java map
val iter: Iterator[(Any, Any)] = bucketIterators(i)
val iter: Iterator[(Any, Any)] = buckets(i).iterator
val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false)
compressedSizes(i) = MapOutputTracker.compressSize(size)
}
return new MapStatus(blockManager.blockManagerId, compressedSizes)
} finally {
// Execute the callbacks on task completion.
taskContext.executeOnCompleteCallbacks()
return new MapStatus(blockManager.blockManagerId, compressedSizes)
}
}
override def preferredLocations: Seq[String] = locs

View file

@ -32,6 +32,9 @@ private[spark] class Stage(
val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil)
var numAvailableOutputs = 0
/** Time when the stage's first task was submitted to the scheduler. */
var submissionTime: Option[Long] = None
private var nextAttemptId = 0
def isAvailable: Boolean = {
@ -51,18 +54,18 @@ private[spark] class Stage(
def removeOutputLoc(partition: Int, bmAddress: BlockManagerId) {
val prevList = outputLocs(partition)
val newList = prevList.filterNot(_.address == bmAddress)
val newList = prevList.filterNot(_.location == bmAddress)
outputLocs(partition) = newList
if (prevList != Nil && newList == Nil) {
numAvailableOutputs -= 1
}
}
def removeOutputsOnHost(host: String) {
def removeOutputsOnExecutor(execId: String) {
var becameUnavailable = false
for (partition <- 0 until numPartitions) {
val prevList = outputLocs(partition)
val newList = prevList.filterNot(_.address.ip == host)
val newList = prevList.filterNot(_.location.executorId == execId)
outputLocs(partition) = newList
if (prevList != Nil && newList == Nil) {
becameUnavailable = true
@ -70,7 +73,8 @@ private[spark] class Stage(
}
}
if (becameUnavailable) {
logInfo("%s is now unavailable on %s (%d/%d, %s)".format(this, host, numAvailableOutputs, numPartitions, isAvailable))
logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format(
this, execId, numAvailableOutputs, numPartitions, isAvailable))
}
}
@ -82,7 +86,7 @@ private[spark] class Stage(
def origin: String = rdd.origin
override def toString = "Stage " + id // + ": [RDD = " + rdd.id + ", isShuffle = " + isShuffleMap + "]"
override def toString = "Stage " + id
override def hashCode(): Int = id
}

View file

@ -12,7 +12,7 @@ private[spark] trait TaskSchedulerListener {
def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any]): Unit
// A node was lost from the cluster.
def hostLost(host: String): Unit
def executorLost(execId: String): Unit
// The TaskScheduler wants to abort an entire task set.
def taskSetFailed(taskSet: TaskSet, reason: String): Unit

View file

@ -27,19 +27,20 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager]
val taskIdToTaskSetId = new HashMap[Long, String]
val taskIdToSlaveId = new HashMap[Long, String]
val taskIdToExecutorId = new HashMap[Long, String]
val taskSetTaskIds = new HashMap[String, HashSet[Long]]
// Incrementing Mesos task IDs
val nextTaskId = new AtomicLong(0)
// Which hosts in the cluster are alive (contains hostnames)
val hostsAlive = new HashSet[String]
// Which executor IDs we have executors on
val activeExecutorIds = new HashSet[String]
// Which slave IDs we have executors on
val slaveIdsWithExecutors = new HashSet[String]
// The set of executors we have on each host; this is used to compute hostsAlive, which
// in turn is used to decide when we can attain data locality on a given host
val executorsByHost = new HashMap[String, HashSet[String]]
val slaveIdToHost = new HashMap[String, String]
val executorIdToHost = new HashMap[String, String]
// JAR server, if any JARs were added by the user to the SparkContext
var jarServer: HttpServer = null
@ -85,7 +86,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
}
def submitTasks(taskSet: TaskSet) {
override def submitTasks(taskSet: TaskSet) {
val tasks = taskSet.tasks
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized {
@ -102,7 +103,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
activeTaskSets -= manager.taskSet.id
activeTaskSetsQueue -= manager
taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
taskIdToSlaveId --= taskSetTaskIds(manager.taskSet.id)
taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id)
taskSetTaskIds.remove(manager.taskSet.id)
}
}
@ -117,8 +118,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
SparkEnv.set(sc.env)
// Mark each slave as alive and remember its hostname
for (o <- offers) {
slaveIdToHost(o.slaveId) = o.hostname
hostsAlive += o.hostname
executorIdToHost(o.executorId) = o.hostname
}
// Build a list of tasks to assign to each slave
val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores))
@ -128,16 +128,20 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
do {
launchedTask = false
for (i <- 0 until offers.size) {
val sid = offers(i).slaveId
val execId = offers(i).executorId
val host = offers(i).hostname
manager.slaveOffer(sid, host, availableCpus(i)) match {
manager.slaveOffer(execId, host, availableCpus(i)) match {
case Some(task) =>
tasks(i) += task
val tid = task.taskId
taskIdToTaskSetId(tid) = manager.taskSet.id
taskSetTaskIds(manager.taskSet.id) += tid
taskIdToSlaveId(tid) = sid
slaveIdsWithExecutors += sid
taskIdToExecutorId(tid) = execId
activeExecutorIds += execId
if (!executorsByHost.contains(host)) {
executorsByHost(host) = new HashSet()
}
executorsByHost(host) += execId
availableCpus(i) -= 1
launchedTask = true
@ -152,25 +156,21 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
var taskSetToUpdate: Option[TaskSetManager] = None
var failedHost: Option[String] = None
var failedExecutor: Option[String] = None
var taskFailed = false
synchronized {
try {
if (state == TaskState.LOST && taskIdToSlaveId.contains(tid)) {
// We lost the executor on this slave, so remember that it's gone
val slaveId = taskIdToSlaveId(tid)
val host = slaveIdToHost(slaveId)
if (hostsAlive.contains(host)) {
slaveIdsWithExecutors -= slaveId
hostsAlive -= host
activeTaskSetsQueue.foreach(_.hostLost(host))
failedHost = Some(host)
if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) {
// We lost this entire executor, so remember that it's gone
val execId = taskIdToExecutorId(tid)
if (activeExecutorIds.contains(execId)) {
removeExecutor(execId)
failedExecutor = Some(execId)
}
}
taskIdToTaskSetId.get(tid) match {
case Some(taskSetId) =>
if (activeTaskSets.contains(taskSetId)) {
//activeTaskSets(taskSetId).statusUpdate(status)
taskSetToUpdate = Some(activeTaskSets(taskSetId))
}
if (TaskState.isFinished(state)) {
@ -178,7 +178,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
if (taskSetTaskIds.contains(taskSetId)) {
taskSetTaskIds(taskSetId) -= tid
}
taskIdToSlaveId.remove(tid)
taskIdToExecutorId.remove(tid)
}
if (state == TaskState.FAILED) {
taskFailed = true
@ -190,12 +190,12 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
case e: Exception => logError("Exception in statusUpdate", e)
}
}
// Update the task set and DAGScheduler without holding a lock on this, because that can deadlock
// Update the task set and DAGScheduler without holding a lock on this, since that can deadlock
if (taskSetToUpdate != None) {
taskSetToUpdate.get.statusUpdate(tid, state, serializedData)
}
if (failedHost != None) {
listener.hostLost(failedHost.get)
if (failedExecutor != None) {
listener.executorLost(failedExecutor.get)
backend.reviveOffers()
}
if (taskFailed) {
@ -249,27 +249,42 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
}
def slaveLost(slaveId: String, reason: ExecutorLossReason) {
var failedHost: Option[String] = None
def executorLost(executorId: String, reason: ExecutorLossReason) {
var failedExecutor: Option[String] = None
synchronized {
val host = slaveIdToHost(slaveId)
if (hostsAlive.contains(host)) {
logError("Lost an executor on " + host + ": " + reason)
slaveIdsWithExecutors -= slaveId
hostsAlive -= host
activeTaskSetsQueue.foreach(_.hostLost(host))
failedHost = Some(host)
if (activeExecutorIds.contains(executorId)) {
val host = executorIdToHost(executorId)
logError("Lost executor %s on %s: %s".format(executorId, host, reason))
removeExecutor(executorId)
failedExecutor = Some(executorId)
} else {
// We may get multiple slaveLost() calls with different loss reasons. For example, one
// We may get multiple executorLost() calls with different loss reasons. For example, one
// may be triggered by a dropped connection from the slave while another may be a report
// of executor termination from Mesos. We produce log messages for both so we eventually
// report the termination reason.
logError("Lost an executor on " + host + " (already removed): " + reason)
logError("Lost an executor " + executorId + " (already removed): " + reason)
}
}
if (failedHost != None) {
listener.hostLost(failedHost.get)
// Call listener.executorLost without holding the lock on this to prevent deadlock
if (failedExecutor != None) {
listener.executorLost(failedExecutor.get)
backend.reviveOffers()
}
}
/** Get the set of hosts that currently have executors */
def hostsAlive: scala.collection.Set[String] = executorsByHost.keySet
/** Remove an executor from all our data structures and mark it as lost */
private def removeExecutor(executorId: String) {
activeExecutorIds -= executorId
val host = executorIdToHost(executorId)
val execs = executorsByHost.getOrElse(host, new HashSet)
execs -= executorId
if (execs.isEmpty) {
executorsByHost -= host
}
executorIdToHost -= executorId
activeTaskSetsQueue.foreach(_.executorLost(executorId, host))
}
}
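// Editor's sketch (self-contained illustration, not part of the commit): executorsByHost is
// what now backs hostsAlive, so a host only drops out once its last executor is removed.
import scala.collection.mutable.{HashMap, HashSet}
val byHost = new HashMap[String, HashSet[String]]
byHost.getOrElseUpdate("host-a", new HashSet()) += "exec-1"
byHost.getOrElseUpdate("host-a", new HashSet()) += "exec-2"
assert(byHost.keySet == Set("host-a"))  // host-a counts as alive
byHost("host-a") -= "exec-1"            // first executor lost: host still alive
byHost("host-a") -= "exec-2"            // last executor on the host lost
if (byHost("host-a").isEmpty) byHost -= "host-a"
assert(byHost.isEmpty)                  // host-a no longer reported by hostsAlive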

View file

@ -1,5 +1,7 @@
package spark.scheduler.cluster
import spark.Utils
/**
* A backend interface for cluster scheduling systems that allows plugging in different ones under
* ClusterScheduler. We assume a Mesos-like model where the application gets resource offers as
@ -11,5 +13,15 @@ private[spark] trait SchedulerBackend {
def reviveOffers(): Unit
def defaultParallelism(): Int
// Memory used by each executor (in megabytes)
protected val executorMemory = {
// TODO: Might need to add some extra memory for the non-heap parts of the JVM
Option(System.getProperty("spark.executor.memory"))
.orElse(Option(System.getenv("SPARK_MEM")))
.map(Utils.memoryStringToMb)
.getOrElse(512)
}
// TODO: Probably want to add a killTask too
}
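// Editor's note (illustrative, not part of the commit): the executorMemory lookup above
// resolves, in order, the spark.executor.memory property, then SPARK_MEM, then a 512 MB
// default. For example, with only SPARK_MEM=1g set in the environment it works out to:
val mem = Option(System.getProperty("spark.executor.memory"))  // None in this scenario
  .orElse(Option(System.getenv("SPARK_MEM")))                  // Some("1g")
  .map(Utils.memoryStringToMb)                                 // Some(1024), assuming "1g" -> 1024
  .getOrElse(512)                                              // => 1024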

View file

@ -1,4 +0,0 @@
package spark.scheduler.cluster
private[spark]
class SlaveResources(val slaveId: String, val hostname: String, val coresFree: Int) {}

View file

@ -19,34 +19,25 @@ private[spark] class SparkDeploySchedulerBackend(
var shutdownCallback : (SparkDeploySchedulerBackend) => Unit = _
val maxCores = System.getProperty("spark.cores.max", Int.MaxValue.toString).toInt
val executorIdToSlaveId = new HashMap[String, String]
// Memory used by each executor (in megabytes)
val executorMemory = {
if (System.getenv("SPARK_MEM") != null) {
Utils.memoryStringToMb(System.getenv("SPARK_MEM"))
// TODO: Might need to add some extra memory for the non-heap parts of the JVM
} else {
512
}
}
override def start() {
super.start()
val masterUrl = "akka://spark@%s:%s/user/%s".format(
System.getProperty("spark.master.host"), System.getProperty("spark.master.port"),
// The endpoint for executors to talk to us
val driverUrl = "akka://spark@%s:%s/user/%s".format(
System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"),
StandaloneSchedulerBackend.ACTOR_NAME)
val args = Seq(masterUrl, "{{SLAVEID}}", "{{HOSTNAME}}", "{{CORES}}")
val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}")
val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs)
val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command)
val sparkHome = sc.getSparkHome().getOrElse(throw new IllegalArgumentException("must supply spark home for spark standalone"))
val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, sparkHome)
client = new Client(sc.env.actorSystem, master, jobDesc, this)
client.start()
}
override def stop() {
stopping = true;
stopping = true
super.stop()
client.stop()
if (shutdownCallback != null) {
@ -54,35 +45,28 @@ private[spark] class SparkDeploySchedulerBackend(
}
}
def connected(jobId: String) {
override def connected(jobId: String) {
logInfo("Connected to Spark cluster with job ID " + jobId)
}
def disconnected() {
override def disconnected() {
if (!stopping) {
logError("Disconnected from Spark cluster!")
scheduler.error("Disconnected from Spark cluster")
}
}
def executorAdded(id: String, workerId: String, host: String, cores: Int, memory: Int) {
executorIdToSlaveId += id -> workerId
override def executorAdded(executorId: String, workerId: String, host: String, cores: Int, memory: Int) {
logInfo("Granted executor ID %s on host %s with %d cores, %s RAM".format(
id, host, cores, Utils.memoryMegabytesToString(memory)))
executorId, host, cores, Utils.memoryMegabytesToString(memory)))
}
def executorRemoved(id: String, message: String, exitStatus: Option[Int]) {
override def executorRemoved(executorId: String, message: String, exitStatus: Option[Int]) {
val reason: ExecutorLossReason = exitStatus match {
case Some(code) => ExecutorExited(code)
case None => SlaveLost(message)
}
logInfo("Executor %s removed: %s".format(id, message))
executorIdToSlaveId.get(id) match {
case Some(slaveId) =>
executorIdToSlaveId.remove(id)
scheduler.slaveLost(slaveId, reason)
case None =>
logInfo("No slave ID known for executor %s".format(id))
}
logInfo("Executor %s removed: %s".format(executorId, message))
scheduler.executorLost(executorId, reason)
}
}
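// Editor's note (illustrative values, not part of the commit): with spark.driver.host set to
// 192.168.1.10 and spark.driver.port to 55555, the driverUrl built in start() above would look
// roughly like this; it is the address the standalone executors use to reach the driver:
val exampleDriverUrl = "akka://spark@%s:%s/user/%s".format(
  "192.168.1.10", "55555", StandaloneSchedulerBackend.ACTOR_NAME)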

View file

@ -6,32 +6,34 @@ import spark.util.SerializableBuffer
private[spark] sealed trait StandaloneClusterMessage extends Serializable
// Master to slaves
// Driver to executors
private[spark]
case class LaunchTask(task: TaskDescription) extends StandaloneClusterMessage
private[spark]
case class RegisteredSlave(sparkProperties: Seq[(String, String)]) extends StandaloneClusterMessage
case class RegisteredExecutor(sparkProperties: Seq[(String, String)])
extends StandaloneClusterMessage
private[spark]
case class RegisterSlaveFailed(message: String) extends StandaloneClusterMessage
case class RegisterExecutorFailed(message: String) extends StandaloneClusterMessage
// Slaves to master
// Executors to driver
private[spark]
case class RegisterSlave(slaveId: String, host: String, cores: Int) extends StandaloneClusterMessage
case class RegisterExecutor(executorId: String, host: String, cores: Int)
extends StandaloneClusterMessage
private[spark]
case class StatusUpdate(slaveId: String, taskId: Long, state: TaskState, data: SerializableBuffer)
case class StatusUpdate(executorId: String, taskId: Long, state: TaskState, data: SerializableBuffer)
extends StandaloneClusterMessage
private[spark]
object StatusUpdate {
/** Alternate factory method that takes a ByteBuffer directly for the data field */
def apply(slaveId: String, taskId: Long, state: TaskState, data: ByteBuffer): StatusUpdate = {
StatusUpdate(slaveId, taskId, state, new SerializableBuffer(data))
def apply(executorId: String, taskId: Long, state: TaskState, data: ByteBuffer): StatusUpdate = {
StatusUpdate(executorId, taskId, state, new SerializableBuffer(data))
}
}
// Internal messages in master
// Internal messages in driver
private[spark] case object ReviveOffers extends StandaloneClusterMessage
private[spark] case object StopMaster extends StandaloneClusterMessage
private[spark] case object StopDriver extends StandaloneClusterMessage
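// Editor's sketch (illustrative, not part of the commit): an executor-side status report built
// with the alternate StatusUpdate factory above; the IDs and payload are made up.
val payload = java.nio.ByteBuffer.wrap(Array[Byte](1, 2, 3))
val update = StatusUpdate("exec-1", 42L, spark.TaskState.FINISHED, payload)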

View file

@ -23,13 +23,13 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
// Use an atomic variable to track total number of cores in the cluster for simplicity and speed
var totalCoreCount = new AtomicInteger(0)
class MasterActor(sparkProperties: Seq[(String, String)]) extends Actor {
val slaveActor = new HashMap[String, ActorRef]
val slaveAddress = new HashMap[String, Address]
val slaveHost = new HashMap[String, String]
class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor {
val executorActor = new HashMap[String, ActorRef]
val executorAddress = new HashMap[String, Address]
val executorHost = new HashMap[String, String]
val freeCores = new HashMap[String, Int]
val actorToSlaveId = new HashMap[ActorRef, String]
val addressToSlaveId = new HashMap[Address, String]
val actorToExecutorId = new HashMap[ActorRef, String]
val addressToExecutorId = new HashMap[Address, String]
override def preStart() {
// Listen for remote client disconnection events, since they don't go through Akka's watch()
@ -37,86 +37,86 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
}
def receive = {
case RegisterSlave(slaveId, host, cores) =>
if (slaveActor.contains(slaveId)) {
sender ! RegisterSlaveFailed("Duplicate slave ID: " + slaveId)
case RegisterExecutor(executorId, host, cores) =>
if (executorActor.contains(executorId)) {
sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId)
} else {
logInfo("Registered slave: " + sender + " with ID " + slaveId)
sender ! RegisteredSlave(sparkProperties)
logInfo("Registered executor: " + sender + " with ID " + executorId)
sender ! RegisteredExecutor(sparkProperties)
context.watch(sender)
slaveActor(slaveId) = sender
slaveHost(slaveId) = host
freeCores(slaveId) = cores
slaveAddress(slaveId) = sender.path.address
actorToSlaveId(sender) = slaveId
addressToSlaveId(sender.path.address) = slaveId
executorActor(executorId) = sender
executorHost(executorId) = host
freeCores(executorId) = cores
executorAddress(executorId) = sender.path.address
actorToExecutorId(sender) = executorId
addressToExecutorId(sender.path.address) = executorId
totalCoreCount.addAndGet(cores)
makeOffers()
}
case StatusUpdate(slaveId, taskId, state, data) =>
case StatusUpdate(executorId, taskId, state, data) =>
scheduler.statusUpdate(taskId, state, data.value)
if (TaskState.isFinished(state)) {
freeCores(slaveId) += 1
makeOffers(slaveId)
freeCores(executorId) += 1
makeOffers(executorId)
}
case ReviveOffers =>
makeOffers()
case StopMaster =>
case StopDriver =>
sender ! true
context.stop(self)
case Terminated(actor) =>
actorToSlaveId.get(actor).foreach(removeSlave(_, "Akka actor terminated"))
actorToExecutorId.get(actor).foreach(removeExecutor(_, "Akka actor terminated"))
case RemoteClientDisconnected(transport, address) =>
addressToSlaveId.get(address).foreach(removeSlave(_, "remote Akka client disconnected"))
addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client disconnected"))
case RemoteClientShutdown(transport, address) =>
addressToSlaveId.get(address).foreach(removeSlave(_, "remote Akka client shutdown"))
addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client shutdown"))
}
// Make fake resource offers on all slaves
// Make fake resource offers on all executors
def makeOffers() {
launchTasks(scheduler.resourceOffers(
slaveHost.toArray.map {case (id, host) => new WorkerOffer(id, host, freeCores(id))}))
executorHost.toArray.map {case (id, host) => new WorkerOffer(id, host, freeCores(id))}))
}
// Make fake resource offers on just one slave
def makeOffers(slaveId: String) {
// Make fake resource offers on just one executor
def makeOffers(executorId: String) {
launchTasks(scheduler.resourceOffers(
Seq(new WorkerOffer(slaveId, slaveHost(slaveId), freeCores(slaveId)))))
Seq(new WorkerOffer(executorId, executorHost(executorId), freeCores(executorId)))))
}
// Launch tasks returned by a set of resource offers
def launchTasks(tasks: Seq[Seq[TaskDescription]]) {
for (task <- tasks.flatten) {
freeCores(task.slaveId) -= 1
slaveActor(task.slaveId) ! LaunchTask(task)
freeCores(task.executorId) -= 1
executorActor(task.executorId) ! LaunchTask(task)
}
}
// Remove a disconnected slave from the cluster
def removeSlave(slaveId: String, reason: String) {
logInfo("Slave " + slaveId + " disconnected, so removing it")
val numCores = freeCores(slaveId)
actorToSlaveId -= slaveActor(slaveId)
addressToSlaveId -= slaveAddress(slaveId)
slaveActor -= slaveId
slaveHost -= slaveId
freeCores -= slaveId
slaveHost -= slaveId
def removeExecutor(executorId: String, reason: String) {
logInfo("Slave " + executorId + " disconnected, so removing it")
val numCores = freeCores(executorId)
actorToExecutorId -= executorActor(executorId)
addressToExecutorId -= executorAddress(executorId)
executorActor -= executorId
executorHost -= executorId
freeCores -= executorId
totalCoreCount.addAndGet(-numCores)
scheduler.slaveLost(slaveId, SlaveLost(reason))
scheduler.executorLost(executorId, SlaveLost(reason))
}
}
var masterActor: ActorRef = null
var driverActor: ActorRef = null
val taskIdsOnSlave = new HashMap[String, HashSet[String]]
def start() {
override def start() {
val properties = new ArrayBuffer[(String, String)]
val iterator = System.getProperties.entrySet.iterator
while (iterator.hasNext) {
@ -126,15 +126,15 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
properties += ((key, value))
}
}
masterActor = actorSystem.actorOf(
Props(new MasterActor(properties)), name = StandaloneSchedulerBackend.ACTOR_NAME)
driverActor = actorSystem.actorOf(
Props(new DriverActor(properties)), name = StandaloneSchedulerBackend.ACTOR_NAME)
}
def stop() {
override def stop() {
try {
if (masterActor != null) {
if (driverActor != null) {
val timeout = 5.seconds
val future = masterActor.ask(StopMaster)(timeout)
val future = driverActor.ask(StopDriver)(timeout)
Await.result(future, timeout)
}
} catch {
@ -143,11 +143,11 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
}
}
def reviveOffers() {
masterActor ! ReviveOffers
override def reviveOffers() {
driverActor ! ReviveOffers
}
def defaultParallelism(): Int = math.max(totalCoreCount.get(), 2)
override def defaultParallelism(): Int = math.max(totalCoreCount.get(), 2)
}
private[spark] object StandaloneSchedulerBackend {

View file

@ -5,7 +5,7 @@ import spark.util.SerializableBuffer
private[spark] class TaskDescription(
val taskId: Long,
val slaveId: String,
val executorId: String,
val name: String,
_serializedTask: ByteBuffer)
extends Serializable {

View file

@ -4,7 +4,12 @@ package spark.scheduler.cluster
* Information about a running task attempt inside a TaskSet.
*/
private[spark]
class TaskInfo(val taskId: Long, val index: Int, val launchTime: Long, val host: String) {
class TaskInfo(
val taskId: Long,
val index: Int,
val launchTime: Long,
val executorId: String,
val host: String) {
var finishTime: Long = 0
var failed = false

View file

@ -17,10 +17,7 @@ import java.nio.ByteBuffer
/**
* Schedules the tasks within a single TaskSet in the ClusterScheduler.
*/
private[spark] class TaskSetManager(
sched: ClusterScheduler,
val taskSet: TaskSet)
extends Logging {
private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSet) extends Logging {
// Maximum time to wait to run a task in a preferred location (in ms)
val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
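// Editor's sketch (hypothetical helper, not part of the commit): slaveOffer below compares the
// time since the last preferred-location launch against this wait, so that non-local tasks are
// only considered once the wait has expired (delay scheduling).
private def preferLocalTasksOnly(lastPreferredLaunchTime: Long): Boolean =
  (System.currentTimeMillis - lastPreferredLaunchTime) < LOCALITY_WAIT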
@ -100,7 +97,7 @@ private[spark] class TaskSetManager(
}
// Add a task to all the pending-task lists that it should be on.
def addPendingTask(index: Int) {
private def addPendingTask(index: Int) {
val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive
if (locations.size == 0) {
pendingTasksWithNoPrefs += index
@ -115,7 +112,7 @@ private[spark] class TaskSetManager(
// Return the pending tasks list for a given host, or an empty list if
// there is no map entry for that host
def getPendingTasksForHost(host: String): ArrayBuffer[Int] = {
private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = {
pendingTasksForHost.getOrElse(host, ArrayBuffer())
}
@ -123,7 +120,7 @@ private[spark] class TaskSetManager(
// Return None if the list is empty.
// This method also cleans up any tasks in the list that have already
// been launched, since we want that to happen lazily.
def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
while (!list.isEmpty) {
val index = list.last
list.trimEnd(1)
@ -137,11 +134,12 @@ private[spark] class TaskSetManager(
// Return a speculative task for a given host if any are available. The task should not have an
// attempt running on this host, in case the host is slow. In addition, if localOnly is set, the
// task must have a preference for this host (or no preferred locations at all).
def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = {
private def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = {
val hostsAlive = sched.hostsAlive
speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set
val localTask = speculatableTasks.find {
index =>
val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive
val locations = tasks(index).preferredLocations.toSet & hostsAlive
val attemptLocs = taskAttempts(index).map(_.host)
(locations.size == 0 || locations.contains(host)) && !attemptLocs.contains(host)
}
@ -161,7 +159,7 @@ private[spark] class TaskSetManager(
// Dequeue a pending task for a given node and return its index.
// If localOnly is set to false, allow non-local tasks as well.
def findTask(host: String, localOnly: Boolean): Option[Int] = {
private def findTask(host: String, localOnly: Boolean): Option[Int] = {
val localTask = findTaskFromList(getPendingTasksForHost(host))
if (localTask != None) {
return localTask
@ -183,13 +181,13 @@ private[spark] class TaskSetManager(
// Does a host count as a preferred location for a task? This is true if
// either the task has preferred locations and this host is one, or it has
// no preferred locations (in which case we still count the launch as preferred).
def isPreferredLocation(task: Task[_], host: String): Boolean = {
private def isPreferredLocation(task: Task[_], host: String): Boolean = {
val locs = task.preferredLocations
return (locs.contains(host) || locs.isEmpty)
}
// Respond to an offer of a single slave from the scheduler by finding a task
def slaveOffer(slaveId: String, host: String, availableCpus: Double): Option[TaskDescription] = {
def slaveOffer(execId: String, host: String, availableCpus: Double): Option[TaskDescription] = {
if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) {
val time = System.currentTimeMillis
val localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT)
@ -201,12 +199,16 @@ private[spark] class TaskSetManager(
val taskId = sched.newTaskId()
// Figure out whether this should count as a preferred launch
val preferred = isPreferredLocation(task, host)
val prefStr = if (preferred) "preferred" else "non-preferred"
logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format(
taskSet.id, index, taskId, slaveId, host, prefStr))
val prefStr = if (preferred) {
"preferred"
} else {
"non-preferred, not one of " + task.preferredLocations.mkString(", ")
}
logInfo("Starting task %s:%d as TID %s on executor %s: %s (%s)".format(
taskSet.id, index, taskId, execId, host, prefStr))
// Do various bookkeeping
copiesRunning(index) += 1
val info = new TaskInfo(taskId, index, time, host)
val info = new TaskInfo(taskId, index, time, execId, host)
taskInfos(taskId) = info
taskAttempts(index) = info :: taskAttempts(index)
if (preferred) {
@ -220,7 +222,7 @@ private[spark] class TaskSetManager(
logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
taskSet.id, index, serializedTask.limit, timeTaken))
val taskName = "task %s:%d".format(taskSet.id, index)
return Some(new TaskDescription(taskId, slaveId, taskName, serializedTask))
return Some(new TaskDescription(taskId, execId, taskName, serializedTask))
}
case _ =>
}
@ -330,7 +332,7 @@ private[spark] class TaskSetManager(
if (numFailures(index) > MAX_TASK_FAILURES) {
logError("Task %s:%d failed more than %d times; aborting job".format(
taskSet.id, index, MAX_TASK_FAILURES))
abort("Task %d failed more than %d times".format(index, MAX_TASK_FAILURES))
abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES))
}
}
} else {
@ -352,19 +354,22 @@ private[spark] class TaskSetManager(
sched.taskSetFinished(this)
}
def hostLost(hostname: String) {
logInfo("Re-queueing tasks for " + hostname + " from TaskSet " + taskSet.id)
// If some task has preferred locations only on hostname, put it in the no-prefs list
// to avoid the wait from delay scheduling
def executorLost(execId: String, hostname: String) {
logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id)
val newHostsAlive = sched.hostsAlive
// If some task has preferred locations only on hostname, and there are no more executors there,
// put it in the no-prefs list to avoid the wait from delay scheduling
if (!newHostsAlive.contains(hostname)) {
for (index <- getPendingTasksForHost(hostname)) {
val newLocs = tasks(index).preferredLocations.toSet & sched.hostsAlive
val newLocs = tasks(index).preferredLocations.toSet & newHostsAlive
if (newLocs.isEmpty) {
pendingTasksWithNoPrefs += index
}
}
// Re-enqueue any tasks that ran on the failed host if this is a shuffle map stage
}
// Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage
if (tasks(0).isInstanceOf[ShuffleMapTask]) {
for ((tid, info) <- taskInfos if info.host == hostname) {
for ((tid, info) <- taskInfos if info.executorId == execId) {
val index = taskInfos(tid).index
if (finished(index)) {
finished(index) = false
@ -378,7 +383,7 @@ private[spark] class TaskSetManager(
}
}
// Also re-enqueue any tasks that were running on the node
for ((tid, info) <- taskInfos if info.running && info.host == hostname) {
for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
taskLost(tid, TaskState.KILLED, null)
}
}

View file

@ -1,8 +1,8 @@
package spark.scheduler.cluster
/**
* Represents free resources available on a worker node.
* Represents free resources available on an executor.
*/
private[spark]
class WorkerOffer(val slaveId: String, val hostname: String, val cores: Int) {
class WorkerOffer(val executorId: String, val hostname: String, val cores: Int) {
}

View file

@ -20,7 +20,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
with Logging {
var attemptId = new AtomicInteger(0)
var threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory)
var threadPool = Utils.newDaemonFixedThreadPool(threads)
val env = SparkEnv.get
var listener: TaskSchedulerListener = null
@ -53,7 +53,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
}
def runTask(task: Task[_], idInJob: Int, attemptId: Int) {
logInfo("Running task " + idInJob)
logInfo("Running " + task)
// Set the Spark execution environment for the worker thread
SparkEnv.set(env)
try {
@ -80,7 +80,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
val resultToReturn = ser.deserialize[Any](ser.serialize(result))
val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
ser.serialize(Accumulators.values))
logInfo("Finished task " + idInJob)
logInfo("Finished " + task)
// If the thread pool has not already been shut down, notify DAGScheduler
if (!Thread.currentThread().isInterrupted)
@ -116,16 +116,16 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
// Fetch missing dependencies
for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
Utils.fetchFile(name, new File("."))
Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
currentFiles(name) = timestamp
}
for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
Utils.fetchFile(name, new File("."))
Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
currentJars(name) = timestamp
// Add it to our class loader
val localName = name.split("/").last
val url = new File(".", localName).toURI.toURL
val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL
if (!classLoader.getURLs.contains(url)) {
logInfo("Adding " + url + " to class loader")
classLoader.addURL(url)

View file

@ -35,16 +35,6 @@ private[spark] class CoarseMesosSchedulerBackend(
val MAX_SLAVE_FAILURES = 2 // Blacklist a slave after this many failures
// Memory used by each executor (in megabytes)
val executorMemory = {
if (System.getenv("SPARK_MEM") != null) {
Utils.memoryStringToMb(System.getenv("SPARK_MEM"))
// TODO: Might need to add some extra memory for the non-heap parts of the JVM
} else {
512
}
}
// Lock used to wait for scheduler to be registered
var isRegistered = false
val registeredLock = new Object()
@ -64,13 +54,9 @@ private[spark] class CoarseMesosSchedulerBackend(
val taskIdToSlaveId = new HashMap[Int, String]
val failuresBySlaveId = new HashMap[String, Int] // How many times tasks on each slave failed
val sparkHome = sc.getSparkHome() match {
case Some(path) =>
path
case None =>
throw new SparkException("Spark home is not set; set it through the spark.home system " +
"property, the SPARK_HOME environment variable or the SparkContext constructor")
}
val sparkHome = sc.getSparkHome().getOrElse(throw new SparkException(
"Spark home is not set; set it through the spark.home system " +
"property, the SPARK_HOME environment variable or the SparkContext constructor"))
val extraCoresPerSlave = System.getProperty("spark.mesos.extra.cores", "0").toInt
@ -108,11 +94,11 @@ private[spark] class CoarseMesosSchedulerBackend(
def createCommand(offer: Offer, numCores: Int): CommandInfo = {
val runScript = new File(sparkHome, "run").getCanonicalPath
val masterUrl = "akka://spark@%s:%s/user/%s".format(
System.getProperty("spark.master.host"), System.getProperty("spark.master.port"),
val driverUrl = "akka://spark@%s:%s/user/%s".format(
System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"),
StandaloneSchedulerBackend.ACTOR_NAME)
val command = "\"%s\" spark.executor.StandaloneExecutorBackend %s %s %s %d".format(
runScript, masterUrl, offer.getSlaveId.getValue, offer.getHostname, numCores)
runScript, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores)
val environment = Environment.newBuilder()
sc.executorEnvs.foreach { case (key, value) =>
environment.addVariables(Environment.Variable.newBuilder()
@ -184,7 +170,7 @@ private[spark] class CoarseMesosSchedulerBackend(
}
/** Helper function to pull out a resource from a Mesos Resources protobuf */
def getResource(res: JList[Resource], name: String): Double = {
private def getResource(res: JList[Resource], name: String): Double = {
for (r <- res if r.getName == name) {
return r.getScalar.getValue
}
@ -193,7 +179,7 @@ private[spark] class CoarseMesosSchedulerBackend(
}
/** Build a Mesos resource protobuf object */
def createResource(resourceName: String, quantity: Double): Protos.Resource = {
private def createResource(resourceName: String, quantity: Double): Protos.Resource = {
Resource.newBuilder()
.setName(resourceName)
.setType(Value.Type.SCALAR)
@ -202,7 +188,7 @@ private[spark] class CoarseMesosSchedulerBackend(
}
/** Check whether a Mesos task state represents a finished task */
def isFinished(state: MesosTaskState) = {
private def isFinished(state: MesosTaskState) = {
state == MesosTaskState.TASK_FINISHED ||
state == MesosTaskState.TASK_FAILED ||
state == MesosTaskState.TASK_KILLED ||
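
The executorMemory blocks removed from both Mesos backends (the value now comes from shared scheduler code) boiled down to reading SPARK_MEM and converting a JVM-style size string to megabytes, defaulting to 512. A rough sketch of that conversion, assumed behaviour only and not the actual Utils.memoryStringToMb implementation:

object MemorySetting {
  // Convert strings such as "512m", "2g" or a raw byte count to megabytes.
  def memoryStringToMb(str: String): Int = {
    val lower = str.toLowerCase
    if (lower.endsWith("k"))      (lower.dropRight(1).toLong / 1024).toInt
    else if (lower.endsWith("m")) lower.dropRight(1).toInt
    else if (lower.endsWith("g")) lower.dropRight(1).toInt * 1024
    else                          (lower.toLong / (1024 * 1024)).toInt
  }

  // Default to 512 MB when SPARK_MEM is not set.
  def executorMemoryMb(sparkMem: Option[String]): Int =
    sparkMem.map(memoryStringToMb).getOrElse(512)
}

For example, executorMemoryMb(Option(System.getenv("SPARK_MEM"))) reproduces the removed logic.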

View file

@ -29,16 +29,6 @@ private[spark] class MesosSchedulerBackend(
with MScheduler
with Logging {
// Memory used by each executor (in megabytes)
val EXECUTOR_MEMORY = {
if (System.getenv("SPARK_MEM") != null) {
Utils.memoryStringToMb(System.getenv("SPARK_MEM"))
// TODO: Might need to add some extra memory for the non-heap parts of the JVM
} else {
512
}
}
// Lock used to wait for scheduler to be registered
var isRegistered = false
val registeredLock = new Object()
@ -51,7 +41,7 @@ private[spark] class MesosSchedulerBackend(
val taskIdToSlaveId = new HashMap[Long, String]
// An ExecutorInfo for our tasks
var executorInfo: ExecutorInfo = null
var execArgs: Array[Byte] = null
override def start() {
synchronized {
@ -70,19 +60,14 @@ private[spark] class MesosSchedulerBackend(
}
}.start()
executorInfo = createExecutorInfo()
waitForRegister()
}
}
def createExecutorInfo(): ExecutorInfo = {
val sparkHome = sc.getSparkHome() match {
case Some(path) =>
path
case None =>
throw new SparkException("Spark home is not set; set it through the spark.home system " +
"property, the SPARK_HOME environment variable or the SparkContext constructor")
}
def createExecutorInfo(execId: String): ExecutorInfo = {
val sparkHome = sc.getSparkHome().getOrElse(throw new SparkException(
"Spark home is not set; set it through the spark.home system " +
"property, the SPARK_HOME environment variable or the SparkContext constructor"))
val execScript = new File(sparkHome, "spark-executor").getCanonicalPath
val environment = Environment.newBuilder()
sc.executorEnvs.foreach { case (key, value) =>
@ -94,14 +79,14 @@ private[spark] class MesosSchedulerBackend(
val memory = Resource.newBuilder()
.setName("mem")
.setType(Value.Type.SCALAR)
.setScalar(Value.Scalar.newBuilder().setValue(EXECUTOR_MEMORY).build())
.setScalar(Value.Scalar.newBuilder().setValue(executorMemory).build())
.build()
val command = CommandInfo.newBuilder()
.setValue(execScript)
.setEnvironment(environment)
.build()
ExecutorInfo.newBuilder()
.setExecutorId(ExecutorID.newBuilder().setValue("default").build())
.setExecutorId(ExecutorID.newBuilder().setValue(execId).build())
.setCommand(command)
.setData(ByteString.copyFrom(createExecArg()))
.addResources(memory)
@ -113,6 +98,7 @@ private[spark] class MesosSchedulerBackend(
* containing all the spark.* system properties in the form of (String, String) pairs.
*/
private def createExecArg(): Array[Byte] = {
if (execArgs == null) {
val props = new HashMap[String, String]
val iterator = System.getProperties.entrySet.iterator
while (iterator.hasNext) {
@ -123,7 +109,9 @@ private[spark] class MesosSchedulerBackend(
}
}
// Serialize the map as an array of (String, String) pairs
return Utils.serialize(props.toArray)
execArgs = Utils.serialize(props.toArray)
}
return execArgs
}
override def offerRescinded(d: SchedulerDriver, o: OfferID) {}
@ -163,7 +151,7 @@ private[spark] class MesosSchedulerBackend(
def enoughMemory(o: Offer) = {
val mem = getResource(o.getResourcesList, "mem")
val slaveId = o.getSlaveId.getValue
mem >= EXECUTOR_MEMORY || slaveIdsWithExecutors.contains(slaveId)
mem >= executorMemory || slaveIdsWithExecutors.contains(slaveId)
}
for ((offer, index) <- offers.zipWithIndex if enoughMemory(offer)) {
@ -220,7 +208,7 @@ private[spark] class MesosSchedulerBackend(
return MesosTaskInfo.newBuilder()
.setTaskId(taskId)
.setSlaveId(SlaveID.newBuilder().setValue(slaveId).build())
.setExecutor(executorInfo)
.setExecutor(createExecutorInfo(slaveId))
.setName(task.name)
.addResources(cpuResource)
.setData(ByteString.copyFrom(task.serializedTask))
@ -272,7 +260,7 @@ private[spark] class MesosSchedulerBackend(
synchronized {
slaveIdsWithExecutors -= slaveId.getValue
}
scheduler.slaveLost(slaveId.getValue, reason)
scheduler.executorLost(slaveId.getValue, reason)
}
override def slaveLost(d: SchedulerDriver, slaveId: SlaveID) {
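
The execArgs change above is a lazy-once cache: serialize the spark.* system properties the first time an executor argument is needed and reuse the same bytes afterwards. A self-contained sketch of that idiom, where serialize is a stand-in for Utils.serialize:

import java.io.{ByteArrayOutputStream, ObjectOutputStream}
import scala.collection.JavaConverters._
import scala.collection.mutable.HashMap

object ExecArgCache {
  private var execArgs: Array[Byte] = null

  def createExecArg(): Array[Byte] = {
    if (execArgs == null) {
      val props = new HashMap[String, String]
      for (entry <- System.getProperties.entrySet.asScala) {
        val key = entry.getKey.toString
        if (key.startsWith("spark.")) {
          props(key) = entry.getValue.toString
        }
      }
      // Serialize the map as an array of (String, String) pairs, exactly once.
      execArgs = serialize(props.toArray)
    }
    execArgs
  }

  private def serialize(obj: AnyRef): Array[Byte] = {
    val bos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bos)
    oos.writeObject(obj)
    oos.close()
    bos.toByteArray
  }
}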

View file

@ -16,7 +16,7 @@ import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}
import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
import spark.{CacheTracker, Logging, SizeEstimator, SparkEnv, SparkException, Utils}
import spark.{Logging, SizeEstimator, SparkEnv, SparkException, Utils}
import spark.network._
import spark.serializer.Serializer
import spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStampedHashMap}
@ -30,6 +30,7 @@ extends Exception(message)
private[spark]
class BlockManager(
executorId: String,
actorSystem: ActorSystem,
val master: BlockManagerMaster,
val serializer: Serializer,
@ -68,11 +69,8 @@ class BlockManager(
val connectionManager = new ConnectionManager(0)
implicit val futureExecContext = connectionManager.futureExecContext
val connectionManagerId = connectionManager.id
val blockManagerId = new BlockManagerId(connectionManagerId.host, connectionManagerId.port)
// TODO: This will be removed after cacheTracker is removed from the code base.
var cacheTracker: CacheTracker = null
val blockManagerId = BlockManagerId(
executorId, connectionManager.id.host, connectionManager.id.port)
// Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory
// for receiving shuffle outputs)
@ -93,7 +91,10 @@ class BlockManager(
val slaveActor = master.actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)),
name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next)
@volatile private var shuttingDown = false
// Pending reregistration action being executed asynchronously or null if none
// is pending. Accesses should synchronize on asyncReregisterLock.
var asyncReregisterTask: Future[Unit] = null
val asyncReregisterLock = new Object
private def heartBeat() {
if (!master.sendHeartBeat(blockManagerId)) {
@ -109,8 +110,9 @@ class BlockManager(
/**
* Construct a BlockManager with a memory limit set based on system properties.
*/
def this(actorSystem: ActorSystem, master: BlockManagerMaster, serializer: Serializer) = {
this(actorSystem, master, serializer, BlockManager.getMaxMemoryFromSystemProperties)
def this(execId: String, actorSystem: ActorSystem, master: BlockManagerMaster,
serializer: Serializer) = {
this(execId, actorSystem, master, serializer, BlockManager.getMaxMemoryFromSystemProperties)
}
/**
@ -150,6 +152,8 @@ class BlockManager(
/**
* Reregister with the master and report all blocks to it. This will be called by the heart beat
* thread if our heartbeat to the block manager indicates that we were not registered.
*
* Note that this method must be called without any BlockInfo locks held.
*/
def reregister() {
// TODO: We might need to rate limit reregistering.
@ -158,6 +162,32 @@ class BlockManager(
reportAllBlocks()
}
/**
* Reregister with the master sometime soon.
*/
def asyncReregister() {
asyncReregisterLock.synchronized {
if (asyncReregisterTask == null) {
asyncReregisterTask = Future[Unit] {
reregister()
asyncReregisterLock.synchronized {
asyncReregisterTask = null
}
}
}
}
}
/**
* For testing. Wait for any pending asynchronous reregistration; otherwise, do nothing.
*/
def waitForAsyncReregister() {
val task = asyncReregisterTask
if (task != null) {
Await.ready(task, Duration.Inf)
}
}
/**
* Get storage level of local block. If no info exists for the block, then returns null.
*/
@ -173,7 +203,7 @@ class BlockManager(
if (needReregister) {
logInfo("Got told to reregister updating block " + blockId)
// Reregistering will report our new block for free.
reregister()
asyncReregister()
}
logDebug("Told master about block " + blockId)
}
@ -191,7 +221,7 @@ class BlockManager(
case level =>
val inMem = level.useMemory && memoryStore.contains(blockId)
val onDisk = level.useDisk && diskStore.contains(blockId)
val storageLevel = new StorageLevel(onDisk, inMem, level.deserialized, level.replication)
val storageLevel = StorageLevel(onDisk, inMem, level.deserialized, level.replication)
val memSize = if (inMem) memoryStore.getSize(blockId) else 0L
val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L
(storageLevel, memSize, diskSize, info.tellMaster)
@ -213,7 +243,7 @@ class BlockManager(
val startTimeMs = System.currentTimeMillis
var managers = master.getLocations(blockId)
val locations = managers.map(_.ip)
logDebug("Get block locations in " + Utils.getUsedTimeMs(startTimeMs))
logDebug("Got block locations in " + Utils.getUsedTimeMs(startTimeMs))
return locations
}
@ -223,7 +253,7 @@ class BlockManager(
def getLocations(blockIds: Array[String]): Array[Seq[String]] = {
val startTimeMs = System.currentTimeMillis
val locations = master.getLocations(blockIds).map(_.map(_.ip).toSeq).toArray
logDebug("Get multiple block location in " + Utils.getUsedTimeMs(startTimeMs))
logDebug("Got multiple block location in " + Utils.getUsedTimeMs(startTimeMs))
return locations
}
@ -615,7 +645,7 @@ class BlockManager(
var size = 0L
myInfo.synchronized {
logDebug("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
logTrace("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
+ " to get into synchronized block")
if (level.useMemory) {
@ -647,8 +677,10 @@ class BlockManager(
}
logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs))
// Replicate block if required
if (level.replication > 1) {
val remoteStartTime = System.currentTimeMillis
// Serialize the block if not already done
if (bytesAfterPut == null) {
if (valuesAfterPut == null) {
@ -658,16 +690,10 @@ class BlockManager(
bytesAfterPut = dataSerialize(blockId, valuesAfterPut)
}
replicate(blockId, bytesAfterPut, level)
logDebug("Put block " + blockId + " remotely took " + Utils.getUsedTimeMs(remoteStartTime))
}
BlockManager.dispose(bytesAfterPut)
// TODO: This code will be removed when CacheTracker is gone.
if (blockId.startsWith("rdd")) {
notifyCacheTracker(blockId)
}
logDebug("Put block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs))
return size
}
@ -733,11 +759,6 @@ class BlockManager(
}
}
// TODO: This code will be removed when CacheTracker is gone.
if (blockId.startsWith("rdd")) {
notifyCacheTracker(blockId)
}
// If replication had started, then wait for it to finish
if (level.replication > 1) {
if (replicationFuture == null) {
@ -760,8 +781,7 @@ class BlockManager(
*/
var cachedPeers: Seq[BlockManagerId] = null
private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) {
val tLevel: StorageLevel =
new StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1)
val tLevel = StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1)
if (cachedPeers == null) {
cachedPeers = master.getPeers(blockManagerId, level.replication - 1)
}
@ -780,16 +800,6 @@ class BlockManager(
}
}
// TODO: This code will be removed when CacheTracker is gone.
private def notifyCacheTracker(key: String) {
if (cacheTracker != null) {
val rddInfo = key.split("_")
val rddId: Int = rddInfo(1).toInt
val partition: Int = rddInfo(2).toInt
cacheTracker.notifyFromBlockManager(spark.AddedToCache(rddId, partition, host))
}
}
/**
* Read a block consisting of a single object.
*/
@ -940,6 +950,7 @@ class BlockManager(
blockInfo.clear()
memoryStore.clear()
diskStore.clear()
metadataCleaner.cancel()
logInfo("BlockManager stopped")
}
}
@ -968,7 +979,7 @@ object BlockManager extends Logging {
*/
def dispose(buffer: ByteBuffer) {
if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) {
logDebug("Unmapping " + buffer)
logTrace("Unmapping " + buffer)
if (buffer.asInstanceOf[DirectBuffer].cleaner() != null) {
buffer.asInstanceOf[DirectBuffer].cleaner().clean()
}
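
The asyncReregister/waitForAsyncReregister additions form a small guarded-future pattern: at most one re-registration runs at a time, the guard is cleared when it finishes, and tests can block on the pending task. A sketch of the same shape using scala.concurrent for illustration (the file itself uses Akka's futures):

import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration.Duration

class Reregistrar(reregister: () => Unit) {
  // Pending re-registration, or null if none is pending; guarded by lock.
  private var asyncTask: Future[Unit] = null
  private val lock = new Object

  def asyncReregister(): Unit = lock.synchronized {
    if (asyncTask == null) {
      asyncTask = Future {
        reregister()
        lock.synchronized { asyncTask = null }   // clear the guard once finished
      }
    }
  }

  // For testing: wait for any pending re-registration to complete.
  def waitForAsyncReregister(): Unit = {
    val task = asyncTask
    if (task != null) {
      Await.ready(task, Duration.Inf)
    }
  }
}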

View file

@ -3,38 +3,67 @@ package spark.storage
import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
import java.util.concurrent.ConcurrentHashMap
/**
* This class represents a unique identifier for a BlockManager.
* The first two constructors of this class are made private to ensure that
* BlockManagerId objects can be created only using the factory method in
* [[spark.storage.BlockManager$]]. This allows de-duplication of ID objects.
* Also, constructor parameters are private to ensure that parameters cannot
* be modified from outside this class.
*/
private[spark] class BlockManagerId private (
private var executorId_ : String,
private var ip_ : String,
private var port_ : Int
) extends Externalizable {
private[spark] class BlockManagerId(var ip: String, var port: Int) extends Externalizable {
def this() = this(null, 0) // For deserialization only
private def this() = this(null, null, 0) // For deserialization only
def this(in: ObjectInput) = this(in.readUTF(), in.readInt())
def executorId: String = executorId_
def ip: String = ip_
def port: Int = port_
override def writeExternal(out: ObjectOutput) {
out.writeUTF(ip)
out.writeInt(port)
out.writeUTF(executorId_)
out.writeUTF(ip_)
out.writeInt(port_)
}
override def readExternal(in: ObjectInput) {
ip = in.readUTF()
port = in.readInt()
executorId_ = in.readUTF()
ip_ = in.readUTF()
port_ = in.readInt()
}
@throws(classOf[IOException])
private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this)
override def toString = "BlockManagerId(" + ip + ", " + port + ")"
override def toString = "BlockManagerId(%s, %s, %d)".format(executorId, ip, port)
override def hashCode = ip.hashCode * 41 + port
override def hashCode: Int = (executorId.hashCode * 41 + ip.hashCode) * 41 + port
override def equals(that: Any) = that match {
case id: BlockManagerId => port == id.port && ip == id.ip
case _ => false
case id: BlockManagerId =>
executorId == id.executorId && port == id.port && ip == id.ip
case _ =>
false
}
}
private[spark] object BlockManagerId {
def apply(execId: String, ip: String, port: Int) =
getCachedBlockManagerId(new BlockManagerId(execId, ip, port))
def apply(in: ObjectInput) = {
val obj = new BlockManagerId()
obj.readExternal(in)
getCachedBlockManagerId(obj)
}
val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]()
def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = {
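
The reworked BlockManagerId above combines three things: a private constructor, a factory apply that interns instances in a ConcurrentHashMap, and readResolve so deserialized copies collapse onto the cached instance. A stripped-down sketch of the same idiom with a hypothetical two-field ID:

class NodeId private (val host: String, val port: Int) extends Serializable {
  override def hashCode: Int = host.hashCode * 41 + port
  override def equals(that: Any): Boolean = that match {
    case id: NodeId => host == id.host && port == id.port
    case _          => false
  }
  override def toString: String = "NodeId(" + host + ", " + port + ")"
  // Deserialized copies are swapped for the cached instance.
  private def readResolve(): Object = NodeId.intern(this)
}

object NodeId {
  private val cache = new java.util.concurrent.ConcurrentHashMap[NodeId, NodeId]()
  def apply(host: String, port: Int): NodeId = intern(new NodeId(host, port))
  private def intern(id: NodeId): NodeId = {
    val prev = cache.putIfAbsent(id, id)
    if (prev == null) id else prev
  }
}

Interning keeps the many identical IDs that flow through block messages from accumulating as separate objects on the driver.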

View file

@ -1,6 +1,10 @@
package spark.storage
import scala.collection.mutable.ArrayBuffer
import java.io._
import java.util.{HashMap => JHashMap}
import scala.collection.JavaConverters._
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
import scala.util.Random
import akka.actor.{Actor, ActorRef, ActorSystem, Props}
@ -11,52 +15,49 @@ import akka.util.duration._
import spark.{Logging, SparkException, Utils}
private[spark] class BlockManagerMaster(
val actorSystem: ActorSystem,
isMaster: Boolean,
isDriver: Boolean,
isLocal: Boolean,
masterIp: String,
masterPort: Int)
driverIp: String,
driverPort: Int)
extends Logging {
val AKKA_RETRY_ATTEMPS: Int = System.getProperty("spark.akka.num.retries", "3").toInt
val AKKA_RETRY_ATTEMPTS: Int = System.getProperty("spark.akka.num.retries", "3").toInt
val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "3000").toInt
val MASTER_AKKA_ACTOR_NAME = "BlockMasterManager"
val SLAVE_AKKA_ACTOR_NAME = "BlockSlaveManager"
val DEFAULT_MANAGER_IP: String = Utils.localHostName()
val DRIVER_AKKA_ACTOR_NAME = "BlockMasterManager"
val timeout = 10.seconds
var masterActor: ActorRef = {
if (isMaster) {
val masterActor = actorSystem.actorOf(Props(new BlockManagerMasterActor(isLocal)),
name = MASTER_AKKA_ACTOR_NAME)
var driverActor: ActorRef = {
if (isDriver) {
val driverActor = actorSystem.actorOf(Props(new BlockManagerMasterActor(isLocal)),
name = DRIVER_AKKA_ACTOR_NAME)
logInfo("Registered BlockManagerMaster Actor")
masterActor
driverActor
} else {
val url = "akka://spark@%s:%s/user/%s".format(masterIp, masterPort, MASTER_AKKA_ACTOR_NAME)
val url = "akka://spark@%s:%s/user/%s".format(driverIp, driverPort, DRIVER_AKKA_ACTOR_NAME)
logInfo("Connecting to BlockManagerMaster: " + url)
actorSystem.actorFor(url)
}
}
/** Remove a dead host from the master actor. This is only called on the master side. */
def notifyADeadHost(host: String) {
tell(RemoveHost(host))
logInfo("Removed " + host + " successfully in notifyADeadHost")
/** Remove a dead executor from the driver actor. This is only called on the driver side. */
def removeExecutor(execId: String) {
tell(RemoveExecutor(execId))
logInfo("Removed " + execId + " successfully in removeExecutor")
}
/**
* Send the master actor a heart beat from the slave. Returns true if everything works out,
* false if the master does not know about the given block manager, which means the block
* Send the driver actor a heart beat from the slave. Returns true if everything works out,
* false if the driver does not know about the given block manager, which means the block
* manager should re-register.
*/
def sendHeartBeat(blockManagerId: BlockManagerId): Boolean = {
askMasterWithRetry[Boolean](HeartBeat(blockManagerId))
askDriverWithReply[Boolean](HeartBeat(blockManagerId))
}
/** Register the BlockManager's id with the master. */
/** Register the BlockManager's id with the driver. */
def registerBlockManager(
blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
logInfo("Trying to register BlockManager")
@ -70,25 +71,25 @@ private[spark] class BlockManagerMaster(
storageLevel: StorageLevel,
memSize: Long,
diskSize: Long): Boolean = {
val res = askMasterWithRetry[Boolean](
val res = askDriverWithReply[Boolean](
UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize))
logInfo("Updated info of block " + blockId)
res
}
/** Get locations of the blockId from the master */
/** Get locations of the blockId from the driver */
def getLocations(blockId: String): Seq[BlockManagerId] = {
askMasterWithRetry[Seq[BlockManagerId]](GetLocations(blockId))
askDriverWithReply[Seq[BlockManagerId]](GetLocations(blockId))
}
/** Get locations of multiple blockIds from the master */
/** Get locations of multiple blockIds from the driver */
def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = {
askMasterWithRetry[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds))
askDriverWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds))
}
/** Get ids of other nodes in the cluster from the master */
/** Get ids of other nodes in the cluster from the driver */
def getPeers(blockManagerId: BlockManagerId, numPeers: Int): Seq[BlockManagerId] = {
val result = askMasterWithRetry[Seq[BlockManagerId]](GetPeers(blockManagerId, numPeers))
val result = askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId, numPeers))
if (result.length != numPeers) {
throw new SparkException(
"Error getting peers, only got " + result.size + " instead of " + numPeers)
@ -98,10 +99,10 @@ private[spark] class BlockManagerMaster(
/**
* Remove a block from the slaves that have it. This can only be used to remove
* blocks that the master knows about.
* blocks that the driver knows about.
*/
def removeBlock(blockId: String) {
askMasterWithRetry(RemoveBlock(blockId))
askDriverWithReply(RemoveBlock(blockId))
}
/**
@ -111,41 +112,45 @@ private[spark] class BlockManagerMaster(
* amount of remaining memory.
*/
def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = {
askMasterWithRetry[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus)
askDriverWithReply[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus)
}
/** Stop the master actor, called only on the Spark master node */
def getStorageStatus: Array[StorageStatus] = {
askDriverWithReply[ArrayBuffer[StorageStatus]](GetStorageStatus).toArray
}
/** Stop the driver actor, called only on the Spark driver node */
def stop() {
if (masterActor != null) {
if (driverActor != null) {
tell(StopBlockManagerMaster)
masterActor = null
driverActor = null
logInfo("BlockManagerMaster stopped")
}
}
/** Send a one-way message to the master actor, to which we expect it to reply with true. */
private def tell(message: Any) {
if (!askMasterWithRetry[Boolean](message)) {
if (!askDriverWithReply[Boolean](message)) {
throw new SparkException("BlockManagerMasterActor returned false, expected true.")
}
}
/**
* Send a message to the master actor and get its result within a default timeout, or
* Send a message to the driver actor and get its result within a default timeout, or
* throw a SparkException if this fails.
*/
private def askMasterWithRetry[T](message: Any): T = {
private def askDriverWithReply[T](message: Any): T = {
// TODO: Consider removing multiple attempts
if (masterActor == null) {
throw new SparkException("Error sending message to BlockManager as masterActor is null " +
if (driverActor == null) {
throw new SparkException("Error sending message to BlockManager as driverActor is null " +
"[message = " + message + "]")
}
var attempts = 0
var lastException: Exception = null
while (attempts < AKKA_RETRY_ATTEMPS) {
while (attempts < AKKA_RETRY_ATTEMPTS) {
attempts += 1
try {
val future = masterActor.ask(message)(timeout)
val future = driverActor.ask(message)(timeout)
val result = Await.result(future, timeout)
if (result == null) {
throw new Exception("BlockManagerMaster returned null")

View file

@ -23,9 +23,8 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
private val blockManagerInfo =
new HashMap[BlockManagerId, BlockManagerMasterActor.BlockManagerInfo]
// Mapping from host name to block manager id. We allow multiple block managers
// on the same host name (ip).
private val blockManagerIdByHost = new HashMap[String, ArrayBuffer[BlockManagerId]]
// Mapping from executor ID to block manager ID.
private val blockManagerIdByExecutor = new HashMap[String, BlockManagerId]
// Mapping from block id to the set of block managers that have the block.
private val blockLocations = new JHashMap[String, Pair[Int, HashSet[BlockManagerId]]]
@ -68,11 +67,14 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
case GetMemoryStatus =>
getMemoryStatus
case GetStorageStatus =>
getStorageStatus
case RemoveBlock(blockId) =>
removeBlock(blockId)
case RemoveHost(host) =>
removeHost(host)
case RemoveExecutor(execId) =>
removeExecutor(execId)
sender ! true
case StopBlockManagerMaster =>
@ -96,16 +98,12 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
def removeBlockManager(blockManagerId: BlockManagerId) {
val info = blockManagerInfo(blockManagerId)
// Remove the block manager from blockManagerIdByHost. If the list of block
// managers belonging to the IP is empty, remove the entry from the hash map.
blockManagerIdByHost.get(blockManagerId.ip).foreach { managers: ArrayBuffer[BlockManagerId] =>
managers -= blockManagerId
if (managers.size == 0) blockManagerIdByHost.remove(blockManagerId.ip)
}
// Remove the block manager from blockManagerIdByExecutor.
blockManagerIdByExecutor -= blockManagerId.executorId
// Remove it from blockManagerInfo and remove all the blocks.
blockManagerInfo.remove(blockManagerId)
var iterator = info.blocks.keySet.iterator
val iterator = info.blocks.keySet.iterator
while (iterator.hasNext) {
val blockId = iterator.next
val locations = blockLocations.get(blockId)._2
@ -117,7 +115,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
}
def expireDeadHosts() {
logDebug("Checking for hosts with no recent heart beats in BlockManagerMaster.")
logTrace("Checking for hosts with no recent heart beats in BlockManagerMaster.")
val now = System.currentTimeMillis()
val minSeenTime = now - slaveTimeout
val toRemove = new HashSet[BlockManagerId]
@ -130,17 +128,15 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
toRemove.foreach(removeBlockManager)
}
def removeHost(host: String) {
logInfo("Trying to remove the host: " + host + " from BlockManagerMaster.")
logInfo("Previous hosts: " + blockManagerInfo.keySet.toSeq)
blockManagerIdByHost.get(host).foreach(_.foreach(removeBlockManager))
logInfo("Current hosts: " + blockManagerInfo.keySet.toSeq)
def removeExecutor(execId: String) {
logInfo("Trying to remove executor " + execId + " from BlockManagerMaster.")
blockManagerIdByExecutor.get(execId).foreach(removeBlockManager)
sender ! true
}
def heartBeat(blockManagerId: BlockManagerId) {
if (!blockManagerInfo.contains(blockManagerId)) {
if (blockManagerId.ip == Utils.localHostName() && !isLocal) {
if (blockManagerId.executorId == "<driver>" && !isLocal) {
sender ! true
} else {
sender ! false
@ -177,24 +173,28 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
sender ! res
}
private def register(blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
val startTimeMs = System.currentTimeMillis()
val tmp = " " + blockManagerId + " "
if (blockManagerId.ip == Utils.localHostName() && !isLocal) {
logInfo("Got Register Msg from master node, don't register it")
} else {
blockManagerIdByHost.get(blockManagerId.ip) match {
case Some(managers) =>
// A block manager of the same host name already exists.
logInfo("Got another registration for host " + blockManagerId)
managers += blockManagerId
case None =>
blockManagerIdByHost += (blockManagerId.ip -> ArrayBuffer(blockManagerId))
private def getStorageStatus() {
val res = blockManagerInfo.map { case(blockManagerId, info) =>
import collection.JavaConverters._
StorageStatus(blockManagerId, info.maxMem, info.blocks.asScala.toMap)
}
sender ! res
}
blockManagerInfo += (blockManagerId -> new BlockManagerMasterActor.BlockManagerInfo(
blockManagerId, System.currentTimeMillis(), maxMemSize, slaveActor))
private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
if (id.executorId == "<driver>" && !isLocal) {
// Got a register message from the master node; don't register it
} else if (!blockManagerInfo.contains(id)) {
blockManagerIdByExecutor.get(id.executorId) match {
case Some(manager) =>
// A block manager of the same host name already exists
logError("Got two different block manager registrations on " + id.executorId)
System.exit(1)
case None =>
blockManagerIdByExecutor(id.executorId) = id
}
blockManagerInfo(id) = new BlockManagerMasterActor.BlockManagerInfo(
id, System.currentTimeMillis(), maxMemSize, slaveActor)
}
sender ! true
}
@ -206,11 +206,8 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
memSize: Long,
diskSize: Long) {
val startTimeMs = System.currentTimeMillis()
val tmp = " " + blockManagerId + " " + blockId + " "
if (!blockManagerInfo.contains(blockManagerId)) {
if (blockManagerId.ip == Utils.localHostName() && !isLocal) {
if (blockManagerId.executorId == "<driver>" && !isLocal) {
// We intentionally do not register the master (except in local mode),
// so we should not indicate failure.
sender ! true
@ -342,8 +339,8 @@ object BlockManagerMasterActor {
_lastSeenMs = System.currentTimeMillis()
}
def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long, diskSize: Long)
: Unit = synchronized {
def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long,
diskSize: Long) {
updateLastSeenMs()
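
The registration change replaces the per-host list of block managers with a single mapping from executor ID to block manager ID, and treats a second registration for the same executor as a fatal inconsistency. A small sketch of that bookkeeping, using sys.error where the actor logs an error and exits:

import scala.collection.mutable.HashMap

class RegistrationTable {
  // One block manager per executor ID.
  private val blockManagerIdByExecutor = new HashMap[String, String]

  def register(execId: String, blockManagerId: String): Unit = {
    blockManagerIdByExecutor.get(execId) match {
      case Some(existing) =>
        // Two different registrations for the same executor indicate a bug.
        sys.error("Got two different block manager registrations on " + execId)
      case None =>
        blockManagerIdByExecutor(execId) = blockManagerId
    }
  }

  def removeExecutor(execId: String): Option[String] =
    blockManagerIdByExecutor.remove(execId)
}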

View file

@ -54,11 +54,9 @@ class UpdateBlockInfo(
}
override def readExternal(in: ObjectInput) {
blockManagerId = new BlockManagerId()
blockManagerId.readExternal(in)
blockManagerId = BlockManagerId(in)
blockId = in.readUTF()
storageLevel = new StorageLevel()
storageLevel.readExternal(in)
storageLevel = StorageLevel(in)
memSize = in.readInt()
diskSize = in.readInt()
}
@ -90,7 +88,7 @@ private[spark]
case class GetPeers(blockManagerId: BlockManagerId, size: Int) extends ToBlockManagerMaster
private[spark]
case class RemoveHost(host: String) extends ToBlockManagerMaster
case class RemoveExecutor(execId: String) extends ToBlockManagerMaster
private[spark]
case object StopBlockManagerMaster extends ToBlockManagerMaster
@ -100,3 +98,6 @@ case object GetMemoryStatus extends ToBlockManagerMaster
private[spark]
case object ExpireDeadHosts extends ToBlockManagerMaster
private[spark]
case object GetStorageStatus extends ToBlockManagerMaster

View file

@ -0,0 +1,76 @@
package spark.storage
import akka.actor.{ActorRef, ActorSystem}
import akka.util.Timeout
import akka.util.duration._
import cc.spray.typeconversion.TwirlSupport._
import cc.spray.Directives
import spark.{Logging, SparkContext}
import spark.util.AkkaUtils
import spark.Utils
/**
* Web UI server for the BlockManager inside each SparkContext.
*/
private[spark]
class BlockManagerUI(val actorSystem: ActorSystem, blockManagerMaster: ActorRef, sc: SparkContext)
extends Directives with Logging {
val STATIC_RESOURCE_DIR = "spark/deploy/static"
implicit val timeout = Timeout(10 seconds)
/** Start a HTTP server to run the Web interface */
def start() {
try {
val port = if (System.getProperty("spark.ui.port") != null) {
System.getProperty("spark.ui.port").toInt
} else {
// TODO: Unfortunately, it's not possible to pass port 0 to spray and figure out which
// random port it bound to, so we have to try to find a local one by creating a socket.
Utils.findFreePort()
}
AkkaUtils.startSprayServer(actorSystem, "0.0.0.0", port, handler, "BlockManagerHTTPServer")
logInfo("Started BlockManager web UI at http://%s:%d".format(Utils.localHostName(), port))
} catch {
case e: Exception =>
logError("Failed to create BlockManager WebUI", e)
System.exit(1)
}
}
val handler = {
get {
path("") {
completeWith {
// Request the current storage status from the Master
val storageStatusList = sc.getExecutorStorageStatus
// Calculate macro-level statistics
val maxMem = storageStatusList.map(_.maxMem).reduce(_+_)
val remainingMem = storageStatusList.map(_.memRemaining).reduce(_+_)
val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize))
.reduceOption(_+_).getOrElse(0L)
val rdds = StorageUtils.rddInfoFromStorageStatus(storageStatusList, sc)
spark.storage.html.index.
render(maxMem, remainingMem, diskSpaceUsed, rdds, storageStatusList)
}
} ~
path("rdd") {
parameter("id") { id =>
completeWith {
val prefix = "rdd_" + id.toString
val storageStatusList = sc.getExecutorStorageStatus
val filteredStorageStatusList = StorageUtils.
filterStorageStatusByPrefix(storageStatusList, prefix)
val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head
spark.storage.html.rdd.render(rddInfo, filteredStorageStatusList)
}
}
} ~
pathPrefix("static") {
getFromResourceDirectory(STATIC_RESOURCE_DIR)
}
}
}
}
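
The TODO in start() notes that spray cannot be handed port 0, so the UI picks a free port itself. The usual workaround, which is presumably what Utils.findFreePort does, is to bind a throwaway socket to port 0 and read back the port the OS assigned:

import java.net.ServerSocket

object PortUtil {
  // Bind to port 0, let the OS choose a free port, then release it for the real server.
  def findFreePort(): Int = {
    val socket = new ServerSocket(0)
    try {
      socket.getLocalPort
    } finally {
      socket.close()
    }
  }
}

There is an inherent race between closing the probe socket and binding the real server, which is acceptable for a development web UI.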

Some files were not shown because too many files have changed in this diff.