Merge pull request #407 from woggling/no-cache-tracker

Eliminate CacheTracker
This commit is contained in:
Matei Zaharia 2013-01-23 12:28:48 -08:00
commit 1a3aeeca23
7 changed files with 83 additions and 411 deletions

View file

@ -0,0 +1,65 @@
package spark
import scala.collection.mutable.{ArrayBuffer, HashSet}
import spark.storage.{BlockManager, StorageLevel}
/** Spark class responsible for passing RDDs split contents to the BlockManager and making
sure a node doesn't load two copies of an RDD at once.
*/
private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
private val loading = new HashSet[String]
/** Gets or computes an RDD split. Used by RDD.iterator() when a RDD is cached. */
def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext, storageLevel: StorageLevel)
: Iterator[T] = {
val key = "rdd_%d_%d".format(rdd.id, split.index)
logInfo("Cache key is " + key)
blockManager.get(key) match {
case Some(cachedValues) =>
// Split is in cache, so just return its values
logInfo("Found partition in cache!")
return cachedValues.asInstanceOf[Iterator[T]]
case None =>
// Mark the split as loading (unless someone else marks it first)
loading.synchronized {
if (loading.contains(key)) {
logInfo("Loading contains " + key + ", waiting...")
while (loading.contains(key)) {
try {loading.wait()} catch {case _ =>}
}
logInfo("Loading no longer contains " + key + ", so returning cached result")
// See whether someone else has successfully loaded it. The main way this would fail
// is for the RDD-level cache eviction policy if someone else has loaded the same RDD
// partition but we didn't want to make space for it. However, that case is unlikely
// because it's unlikely that two threads would work on the same RDD partition. One
// downside of the current code is that threads wait serially if this does happen.
blockManager.get(key) match {
case Some(values) =>
return values.asInstanceOf[Iterator[T]]
case None =>
logInfo("Whoever was loading " + key + " failed; we'll try it ourselves")
loading.add(key)
}
} else {
loading.add(key)
}
}
try {
// If we got here, we have to load the split
val elements = new ArrayBuffer[Any]
logInfo("Computing partition " + split)
elements ++= rdd.compute(split, context)
// Try to put this block in the blockManager
blockManager.put(key, elements, storageLevel, true)
return elements.iterator.asInstanceOf[Iterator[T]]
} finally {
loading.synchronized {
loading.remove(key)
loading.notifyAll()
}
}
}
}
}

View file

@ -1,240 +0,0 @@
package spark
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import akka.actor._
import akka.dispatch._
import akka.pattern.ask
import akka.remote._
import akka.util.Duration
import akka.util.Timeout
import akka.util.duration._
import spark.storage.BlockManager
import spark.storage.StorageLevel
import util.{TimeStampedHashSet, MetadataCleaner, TimeStampedHashMap}
private[spark] sealed trait CacheTrackerMessage
private[spark] case class AddedToCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
extends CacheTrackerMessage
private[spark] case class DroppedFromCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
extends CacheTrackerMessage
private[spark] case class MemoryCacheLost(host: String) extends CacheTrackerMessage
private[spark] case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheTrackerMessage
private[spark] case class SlaveCacheStarted(host: String, size: Long) extends CacheTrackerMessage
private[spark] case object GetCacheStatus extends CacheTrackerMessage
private[spark] case object GetCacheLocations extends CacheTrackerMessage
private[spark] case object StopCacheTracker extends CacheTrackerMessage
private[spark] class CacheTrackerActor extends Actor with Logging {
// TODO: Should probably store (String, CacheType) tuples
private val locs = new TimeStampedHashMap[Int, Array[List[String]]]
/**
* A map from the slave's host name to its cache size.
*/
private val slaveCapacity = new HashMap[String, Long]
private val slaveUsage = new HashMap[String, Long]
private val metadataCleaner = new MetadataCleaner("CacheTrackerActor", locs.clearOldValues)
private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L)
private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L)
private def getCacheAvailable(host: String): Long = getCacheCapacity(host) - getCacheUsage(host)
def receive = {
case SlaveCacheStarted(host: String, size: Long) =>
slaveCapacity.put(host, size)
slaveUsage.put(host, 0)
sender ! true
case RegisterRDD(rddId: Int, numPartitions: Int) =>
logInfo("Registering RDD " + rddId + " with " + numPartitions + " partitions")
locs(rddId) = Array.fill[List[String]](numPartitions)(Nil)
sender ! true
case AddedToCache(rddId, partition, host, size) =>
slaveUsage.put(host, getCacheUsage(host) + size)
locs(rddId)(partition) = host :: locs(rddId)(partition)
sender ! true
case DroppedFromCache(rddId, partition, host, size) =>
slaveUsage.put(host, getCacheUsage(host) - size)
// Do a sanity check to make sure usage is greater than 0.
locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host)
sender ! true
case MemoryCacheLost(host) =>
logInfo("Memory cache lost on " + host)
for ((id, locations) <- locs) {
for (i <- 0 until locations.length) {
locations(i) = locations(i).filterNot(_ == host)
}
}
sender ! true
case GetCacheLocations =>
logInfo("Asked for current cache locations")
sender ! locs.map{case (rrdId, array) => (rrdId -> array.clone())}
case GetCacheStatus =>
val status = slaveCapacity.map { case (host, capacity) =>
(host, capacity, getCacheUsage(host))
}.toSeq
sender ! status
case StopCacheTracker =>
logInfo("Stopping CacheTrackerActor")
sender ! true
metadataCleaner.cancel()
context.stop(self)
}
}
private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: BlockManager)
extends Logging {
// Tracker actor on the master, or remote reference to it on workers
val ip: String = System.getProperty("spark.master.host", "localhost")
val port: Int = System.getProperty("spark.master.port", "7077").toInt
val actorName: String = "CacheTracker"
val timeout = 10.seconds
var trackerActor: ActorRef = if (isMaster) {
val actor = actorSystem.actorOf(Props[CacheTrackerActor], name = actorName)
logInfo("Registered CacheTrackerActor actor")
actor
} else {
val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName)
actorSystem.actorFor(url)
}
// TODO: Consider removing this HashSet completely as locs CacheTrackerActor already
// keeps track of registered RDDs
val registeredRddIds = new TimeStampedHashSet[Int]
// Remembers which splits are currently being loaded (on worker nodes)
val loading = new HashSet[String]
val metadataCleaner = new MetadataCleaner("CacheTracker", registeredRddIds.clearOldValues)
// Send a message to the trackerActor and get its result within a default timeout, or
// throw a SparkException if this fails.
def askTracker(message: Any): Any = {
try {
val future = trackerActor.ask(message)(timeout)
return Await.result(future, timeout)
} catch {
case e: Exception =>
throw new SparkException("Error communicating with CacheTracker", e)
}
}
// Send a one-way message to the trackerActor, to which we expect it to reply with true.
def communicate(message: Any) {
if (askTracker(message) != true) {
throw new SparkException("Error reply received from CacheTracker")
}
}
// Registers an RDD (on master only)
def registerRDD(rddId: Int, numPartitions: Int) {
registeredRddIds.synchronized {
if (!registeredRddIds.contains(rddId)) {
logInfo("Registering RDD ID " + rddId + " with cache")
registeredRddIds += rddId
communicate(RegisterRDD(rddId, numPartitions))
}
}
}
// For BlockManager.scala only
def cacheLost(host: String) {
communicate(MemoryCacheLost(host))
logInfo("CacheTracker successfully removed entries on " + host)
}
// Get the usage status of slave caches. Each tuple in the returned sequence
// is in the form of (host name, capacity, usage).
def getCacheStatus(): Seq[(String, Long, Long)] = {
askTracker(GetCacheStatus).asInstanceOf[Seq[(String, Long, Long)]]
}
// For BlockManager.scala only
def notifyFromBlockManager(t: AddedToCache) {
communicate(t)
}
// Get a snapshot of the currently known locations
def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = {
askTracker(GetCacheLocations).asInstanceOf[HashMap[Int, Array[List[String]]]]
}
// Gets or computes an RDD split
def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext, storageLevel: StorageLevel)
: Iterator[T] = {
val key = "rdd_%d_%d".format(rdd.id, split.index)
logInfo("Cache key is " + key)
blockManager.get(key) match {
case Some(cachedValues) =>
// Split is in cache, so just return its values
logInfo("Found partition in cache!")
return cachedValues.asInstanceOf[Iterator[T]]
case None =>
// Mark the split as loading (unless someone else marks it first)
loading.synchronized {
if (loading.contains(key)) {
logInfo("Loading contains " + key + ", waiting...")
while (loading.contains(key)) {
try {loading.wait()} catch {case _ =>}
}
logInfo("Loading no longer contains " + key + ", so returning cached result")
// See whether someone else has successfully loaded it. The main way this would fail
// is for the RDD-level cache eviction policy if someone else has loaded the same RDD
// partition but we didn't want to make space for it. However, that case is unlikely
// because it's unlikely that two threads would work on the same RDD partition. One
// downside of the current code is that threads wait serially if this does happen.
blockManager.get(key) match {
case Some(values) =>
return values.asInstanceOf[Iterator[T]]
case None =>
logInfo("Whoever was loading " + key + " failed; we'll try it ourselves")
loading.add(key)
}
} else {
loading.add(key)
}
}
try {
// If we got here, we have to load the split
val elements = new ArrayBuffer[Any]
logInfo("Computing partition " + split)
elements ++= rdd.compute(split, context)
// Try to put this block in the blockManager
blockManager.put(key, elements, storageLevel, true)
return elements.iterator.asInstanceOf[Iterator[T]]
} finally {
loading.synchronized {
loading.remove(key)
loading.notifyAll()
}
}
}
}
// Called by the Cache to report that an entry has been dropped from it
def dropEntry(rddId: Int, partition: Int) {
communicate(DroppedFromCache(rddId, partition, Utils.localHostName()))
}
def stop() {
communicate(StopCacheTracker)
registeredRddIds.clear()
trackerActor = null
}
}

View file

@ -176,7 +176,7 @@ abstract class RDD[T: ClassManifest](
if (isCheckpointed) {
checkpointData.get.iterator(split, context)
} else if (storageLevel != StorageLevel.NONE) {
SparkEnv.get.cacheTracker.getOrCompute[T](this, split, context, storageLevel)
SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel)
} else {
compute(split, context)
}

View file

@ -22,7 +22,7 @@ class SparkEnv (
val actorSystem: ActorSystem,
val serializer: Serializer,
val closureSerializer: Serializer,
val cacheTracker: CacheTracker,
val cacheManager: CacheManager,
val mapOutputTracker: MapOutputTracker,
val shuffleFetcher: ShuffleFetcher,
val broadcastManager: BroadcastManager,
@ -35,7 +35,6 @@ class SparkEnv (
def stop() {
httpFileServer.stop()
mapOutputTracker.stop()
cacheTracker.stop()
shuffleFetcher.stop()
broadcastManager.stop()
blockManager.stop()
@ -96,8 +95,7 @@ object SparkEnv extends Logging {
val closureSerializer = instantiateClass[Serializer](
"spark.closure.serializer", "spark.JavaSerializer")
val cacheTracker = new CacheTracker(actorSystem, isMaster, blockManager)
blockManager.cacheTracker = cacheTracker
val cacheManager = new CacheManager(blockManager)
val mapOutputTracker = new MapOutputTracker(actorSystem, isMaster)
@ -127,7 +125,7 @@ object SparkEnv extends Logging {
actorSystem,
serializer,
closureSerializer,
cacheTracker,
cacheManager,
mapOutputTracker,
shuffleFetcher,
broadcastManager,

View file

@ -69,8 +69,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
var cacheLocs = new HashMap[Int, Array[List[String]]]
val env = SparkEnv.get
val cacheTracker = env.cacheTracker
val mapOutputTracker = env.mapOutputTracker
val blockManagerMaster = env.blockManager.master
val deadHosts = new HashSet[String] // TODO: The code currently assumes these can't come back;
// that's not going to be a realistic assumption in general
@ -95,11 +95,17 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
}.start()
def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
if (!cacheLocs.contains(rdd.id)) {
val blockIds = rdd.splits.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray
cacheLocs(rdd.id) = blockManagerMaster.getLocations(blockIds).map {
locations => locations.map(_.ip).toList
}.toArray
}
cacheLocs(rdd.id)
}
def updateCacheLocs() {
cacheLocs = cacheTracker.getLocationsSnapshot()
def clearCacheLocs() {
cacheLocs.clear
}
/**
@ -126,7 +132,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
// Kind of ugly: need to register RDDs with the cache and map output tracker here
// since we can't do it in the RDD constructor because # of splits is unknown
logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")")
cacheTracker.registerRDD(rdd.id, rdd.splits.size)
if (shuffleDep != None) {
mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size)
}
@ -148,8 +153,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
visited += r
// Kind of ugly: need to register RDDs with the cache here since
// we can't do it in its constructor because # of splits is unknown
logInfo("Registering parent RDD " + r.id + " (" + r.origin + ")")
cacheTracker.registerRDD(r.id, r.splits.size)
for (dep <- r.dependencies) {
dep match {
case shufDep: ShuffleDependency[_,_] =>
@ -250,7 +253,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val runId = nextRunId.getAndIncrement()
val finalStage = newStage(finalRDD, None, runId)
val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener)
updateCacheLocs()
clearCacheLocs()
logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length +
" output partitions")
logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")")
@ -293,7 +296,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
// on the failed node.
if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) {
logInfo("Resubmitting failed stages")
updateCacheLocs()
clearCacheLocs()
val failed2 = failed.toArray
failed.clear()
for (stage <- failed2.sortBy(_.priority)) {
@ -443,7 +446,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
stage.shuffleDep.get.shuffleId,
stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray)
}
updateCacheLocs()
clearCacheLocs()
if (stage.outputLocs.count(_ == Nil) != 0) {
// Some tasks had failed; let's resubmit this stage
// TODO: Lower-level scheduler should also deal with this
@ -519,8 +522,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray
mapOutputTracker.registerMapOutputs(shuffleId, locs, true)
}
cacheTracker.cacheLost(host)
updateCacheLocs()
clearCacheLocs()
}
}

View file

@ -16,7 +16,7 @@ import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}
import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
import spark.{CacheTracker, Logging, SizeEstimator, SparkEnv, SparkException, Utils}
import spark.{Logging, SizeEstimator, SparkEnv, SparkException, Utils}
import spark.network._
import spark.serializer.Serializer
import spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStampedHashMap}
@ -71,9 +71,6 @@ class BlockManager(
val connectionManagerId = connectionManager.id
val blockManagerId = BlockManagerId(connectionManagerId.host, connectionManagerId.port)
// TODO: This will be removed after cacheTracker is removed from the code base.
var cacheTracker: CacheTracker = null
// Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory
// for receiving shuffle outputs)
val maxBytesInFlight =
@ -662,10 +659,6 @@ class BlockManager(
BlockManager.dispose(bytesAfterPut)
// TODO: This code will be removed when CacheTracker is gone.
if (blockId.startsWith("rdd")) {
notifyCacheTracker(blockId)
}
logDebug("Put block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs))
return size
@ -733,11 +726,6 @@ class BlockManager(
}
}
// TODO: This code will be removed when CacheTracker is gone.
if (blockId.startsWith("rdd")) {
notifyCacheTracker(blockId)
}
// If replication had started, then wait for it to finish
if (level.replication > 1) {
if (replicationFuture == null) {
@ -779,16 +767,6 @@ class BlockManager(
}
}
// TODO: This code will be removed when CacheTracker is gone.
private def notifyCacheTracker(key: String) {
if (cacheTracker != null) {
val rddInfo = key.split("_")
val rddId: Int = rddInfo(1).toInt
val partition: Int = rddInfo(2).toInt
cacheTracker.notifyFromBlockManager(spark.AddedToCache(rddId, partition, host))
}
}
/**
* Read a block consisting of a single object.
*/

View file

@ -1,131 +0,0 @@
package spark
import org.scalatest.FunSuite
import scala.collection.mutable.HashMap
import akka.actor._
import akka.dispatch._
import akka.pattern.ask
import akka.remote._
import akka.util.Duration
import akka.util.Timeout
import akka.util.duration._
class CacheTrackerSuite extends FunSuite {
// Send a message to an actor and wait for a reply, in a blocking manner
private def ask(actor: ActorRef, message: Any): Any = {
try {
val timeout = 10.seconds
val future = actor.ask(message)(timeout)
return Await.result(future, timeout)
} catch {
case e: Exception =>
throw new SparkException("Error communicating with actor", e)
}
}
test("CacheTrackerActor slave initialization & cache status") {
//System.setProperty("spark.master.port", "1345")
val initialSize = 2L << 20
val actorSystem = ActorSystem("test")
val tracker = actorSystem.actorOf(Props[CacheTrackerActor])
assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true)
assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 0L)))
assert(ask(tracker, StopCacheTracker) === true)
actorSystem.shutdown()
actorSystem.awaitTermination()
}
test("RegisterRDD") {
//System.setProperty("spark.master.port", "1345")
val initialSize = 2L << 20
val actorSystem = ActorSystem("test")
val tracker = actorSystem.actorOf(Props[CacheTrackerActor])
assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true)
assert(ask(tracker, RegisterRDD(1, 3)) === true)
assert(ask(tracker, RegisterRDD(2, 1)) === true)
assert(getCacheLocations(tracker) === Map(1 -> List(Nil, Nil, Nil), 2 -> List(Nil)))
assert(ask(tracker, StopCacheTracker) === true)
actorSystem.shutdown()
actorSystem.awaitTermination()
}
test("AddedToCache") {
//System.setProperty("spark.master.port", "1345")
val initialSize = 2L << 20
val actorSystem = ActorSystem("test")
val tracker = actorSystem.actorOf(Props[CacheTrackerActor])
assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true)
assert(ask(tracker, RegisterRDD(1, 2)) === true)
assert(ask(tracker, RegisterRDD(2, 1)) === true)
assert(ask(tracker, AddedToCache(1, 0, "host001", 2L << 15)) === true)
assert(ask(tracker, AddedToCache(1, 1, "host001", 2L << 11)) === true)
assert(ask(tracker, AddedToCache(2, 0, "host001", 3L << 10)) === true)
assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 72704L)))
assert(getCacheLocations(tracker) ===
Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
assert(ask(tracker, StopCacheTracker) === true)
actorSystem.shutdown()
actorSystem.awaitTermination()
}
test("DroppedFromCache") {
//System.setProperty("spark.master.port", "1345")
val initialSize = 2L << 20
val actorSystem = ActorSystem("test")
val tracker = actorSystem.actorOf(Props[CacheTrackerActor])
assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true)
assert(ask(tracker, RegisterRDD(1, 2)) === true)
assert(ask(tracker, RegisterRDD(2, 1)) === true)
assert(ask(tracker, AddedToCache(1, 0, "host001", 2L << 15)) === true)
assert(ask(tracker, AddedToCache(1, 1, "host001", 2L << 11)) === true)
assert(ask(tracker, AddedToCache(2, 0, "host001", 3L << 10)) === true)
assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 72704L)))
assert(getCacheLocations(tracker) ===
Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
assert(ask(tracker, DroppedFromCache(1, 1, "host001", 2L << 11)) === true)
assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 68608L)))
assert(getCacheLocations(tracker) ===
Map(1 -> List(List("host001"),List()), 2 -> List(List("host001"))))
assert(ask(tracker, StopCacheTracker) === true)
actorSystem.shutdown()
actorSystem.awaitTermination()
}
/**
* Helper function to get cacheLocations from CacheTracker
*/
def getCacheLocations(tracker: ActorRef): HashMap[Int, List[List[String]]] = {
val answer = ask(tracker, GetCacheLocations).asInstanceOf[HashMap[Int, Array[List[String]]]]
answer.map { case (i, arr) => (i, arr.toList) }
}
}