package spark

import java.io._
import java.util.concurrent.ConcurrentHashMap

import akka.actor._
import akka.dispatch._
import akka.pattern.ask
import akka.remote._
import akka.util.Duration
import akka.util.Timeout
import akka.util.duration._

import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet

import scheduler.MapStatus
import spark.storage.BlockManagerId
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
import util.{CleanupTask, TimeStampedHashMap}

private[spark] sealed trait MapOutputTrackerMessage
private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String)
  extends MapOutputTrackerMessage
private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
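
// The tracker actor answers two messages: GetMapOutputStatuses(shuffleId, requester) returns the
// GZIP-compressed, serialized statuses for a shuffle, and StopMapOutputTracker makes the actor
// acknowledge with true and then stop itself.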
private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Actor with Logging {
  def receive = {
    case GetMapOutputStatuses(shuffleId: Int, requester: String) =>
      logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + requester)
      sender ! tracker.getSerializedLocations(shuffleId)

    case StopMapOutputTracker =>
      logInfo("MapOutputTrackerActor stopped!")
      sender ! true
      context.stop(self)
  }
}

private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logging {
  val ip: String = System.getProperty("spark.master.host", "localhost")
  val port: Int = System.getProperty("spark.master.port", "7077").toInt
  val actorName: String = "MapOutputTracker"

  val timeout = 10.seconds

  var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]

  // Incremented every time a fetch fails so that client nodes know to clear
  // their cache of map output locations if this happens.
  private var generation: Long = 0
  private val generationLock = new java.lang.Object

  // Cache a serialized version of the output statuses for each shuffle to send them out faster
  var cacheGeneration = generation
  val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]

  var trackerActor: ActorRef = if (isMaster) {
    val actor = actorSystem.actorOf(Props(new MapOutputTrackerActor(this)), name = actorName)
    logInfo("Registered MapOutputTrackerActor actor")
    actor
  } else {
    val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName)
    actorSystem.actorFor(url)
  }

  val cleanupTask = new CleanupTask("MapOutputTracker", this.cleanup)

  // Send a message to the trackerActor and get its result within a default timeout, or
  // throw a SparkException if this fails.
  def askTracker(message: Any): Any = {
    try {
      val future = trackerActor.ask(message)(timeout)
      return Await.result(future, timeout)
    } catch {
      case e: Exception =>
        throw new SparkException("Error communicating with MapOutputTracker", e)
    }
  }

  // Send a one-way message to the trackerActor, to which we expect it to reply with true.
  def communicate(message: Any) {
    if (askTracker(message) != true) {
      throw new SparkException("Error reply received from MapOutputTracker")
    }
  }
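
  // Register a new shuffle on the master, reserving one (initially empty) slot per map task.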
  def registerShuffle(shuffleId: Int, numMaps: Int) {
    if (mapStatuses.get(shuffleId).isDefined) {
      throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
    }
    mapStatuses.put(shuffleId, new Array[MapStatus](numMaps))
  }
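
  // Record the status (location and compressed output sizes) for one completed map task.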
  def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
    val array = mapStatuses(shuffleId)
    array.synchronized {
      array(mapId) = status
    }
  }
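
  // Register all map outputs for a shuffle at once, optionally bumping the generation so that
  // workers refetch the locations.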
  def registerMapOutputs(
      shuffleId: Int,
      statuses: Array[MapStatus],
      changeGeneration: Boolean = false) {
    mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses)
    if (changeGeneration) {
      incrementGeneration()
    }
  }
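
  // Drop the output registered at the given BlockManager (e.g. after a failed fetch) and bump
  // the generation so workers clear their cached locations.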
  def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
    val array = mapStatuses(shuffleId)
    if (array != null) {
      array.synchronized {
        // The entry may already have been cleared, so guard against a null status before
        // comparing addresses
        if (array(mapId) != null && array(mapId).address == bmAddress) {
          array(mapId) = null
        }
      }
      incrementGeneration()
    } else {
      throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
    }
  }

  // Remembers which map output locations are currently being fetched on a worker
  val fetching = new HashSet[Int]
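
  // Three cases below: the statuses are already known locally; another thread on this node is
  // already fetching them, in which case we wait on `fetching`; or we win the race and fetch
  // them from the tracker ourselves, caching the result and waking any waiters.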
  // Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle
  def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
    val statuses = mapStatuses.get(shuffleId).orNull
    if (statuses == null) {
      logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
      fetching.synchronized {
        if (fetching.contains(shuffleId)) {
          // Someone else is fetching it; wait for them to be done
          while (fetching.contains(shuffleId)) {
            try {
              fetching.wait()
            } catch {
              case e: InterruptedException =>
            }
          }
          return mapStatuses(shuffleId).map(status =>
            (status.address, MapOutputTracker.decompressSize(status.compressedSizes(reduceId))))
        } else {
          fetching += shuffleId
        }
      }
      // We won the race to fetch the output locs; do so
      logInfo("Doing the fetch; tracker actor = " + trackerActor)
      val host = System.getProperty("spark.hostname", Utils.localHostName)
      val fetchedBytes = askTracker(GetMapOutputStatuses(shuffleId, host)).asInstanceOf[Array[Byte]]
      val fetchedStatuses = deserializeStatuses(fetchedBytes)
      logInfo("Got the output locations")
      mapStatuses.put(shuffleId, fetchedStatuses)
      fetching.synchronized {
        fetching -= shuffleId
        fetching.notifyAll()
      }
      return fetchedStatuses.map(s =>
        (s.address, MapOutputTracker.decompressSize(s.compressedSizes(reduceId))))
    } else {
      return statuses.map(s =>
        (s.address, MapOutputTracker.decompressSize(s.compressedSizes(reduceId))))
    }
  }
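
  // Drop entries older than cleanupTime from both tables; run periodically via cleanupTask.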
  def cleanup(cleanupTime: Long) {
    mapStatuses.cleanup(cleanupTime)
    cachedSerializedStatuses.cleanup(cleanupTime)
  }
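
  // Shut down: stop the tracker actor, clear state, and cancel the periodic cleanup.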
  def stop() {
    communicate(StopMapOutputTracker)
    mapStatuses.clear()
    cleanupTask.cancel()
    trackerActor = null
  }

  // Called on master to increment the generation number
  def incrementGeneration() {
    generationLock.synchronized {
      generation += 1
      logDebug("Increasing generation to " + generation)
    }
  }

  // Called on master or workers to get current generation number
  def getGeneration: Long = {
    generationLock.synchronized {
      return generation
    }
  }

  // Called on workers to update the generation number, potentially clearing old outputs
  // because of a fetch failure. (Each Mesos task calls this with the latest generation
  // number on the master at the time it was created.)
  def updateGeneration(newGen: Long) {
    generationLock.synchronized {
      if (newGen > generation) {
        logInfo("Updating generation to " + newGen + " and clearing cache")
        mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
        generation = newGen
      }
    }
  }
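
  // Illustrative sequence: the master is at generation 5 when a fetch fails; unregisterMapOutput
  // bumps it to 6; tasks launched from then on carry 6, so a worker still at 5 calls
  // updateGeneration(6), drops its cached map statuses, and refetches them on demand.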

  // Serialize the statuses for a shuffle, using (and refreshing) the per-generation cache of
  // serialized bytes so repeated requests don't redo the work.
  def getSerializedLocations(shuffleId: Int): Array[Byte] = {
    var statuses: Array[MapStatus] = null
    var generationGotten: Long = -1
    generationLock.synchronized {
      if (generation > cacheGeneration) {
        cachedSerializedStatuses.clear()
        cacheGeneration = generation
      }
      cachedSerializedStatuses.get(shuffleId) match {
        case Some(bytes) =>
          return bytes
        case None =>
          statuses = mapStatuses(shuffleId)
          generationGotten = generation
      }
    }
    // If we got here, we failed to find the serialized locations in the cache, so we pulled
    // out a snapshot of the locations as "statuses"; let's serialize and return that
    val bytes = serializeStatuses(statuses)
    logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
    // Add them into the table only if the generation hasn't changed while we were working
    generationLock.synchronized {
      if (generation == generationGotten) {
        cachedSerializedStatuses(shuffleId) = bytes
      }
    }
    return bytes
  }

  // Serialize an array of map output locations into an efficient byte format so that we can send
  // it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
  // generally be pretty compressible because many map outputs will be on the same hostname.
  def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = {
    val out = new ByteArrayOutputStream
    val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
    objOut.writeObject(statuses)
    objOut.close()
    out.toByteArray
  }
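
  // Round trip (illustrative): deserializeStatuses(serializeStatuses(statuses)) reconstructs an
  // equivalent Array[MapStatus], since the statuses are Java-serialized and GZIP is lossless.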

  // Opposite of serializeStatuses.
  def deserializeStatuses(bytes: Array[Byte]): Array[MapStatus] = {
    val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
    objIn.readObject().asInstanceOf[Array[MapStatus]]
  }
}

private[spark] object MapOutputTracker {
  private val LOG_BASE = 1.1

  /**
   * Compress a size in bytes to 8 bits for efficient reporting of map output sizes.
   * We do this by encoding the log base 1.1 of the size as an integer, which can support
   * sizes up to 35 GB with at most 10% error.
   */
  def compressSize(size: Long): Byte = {
    if (size <= 1L) {
      0
    } else {
      math.min(255, math.ceil(math.log(size) / math.log(LOG_BASE)).toInt).toByte
    }
  }
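
  // Worked example: compressSize(1000000L) = ceil(ln(1e6) / ln(1.1)) = ceil(144.95) = 145, and
  // decompressSize(145) = 1.1^145 ≈ 1.0045e6, an overestimate of about 0.45%. The largest code,
  // 255, decodes to 1.1^255 ≈ 3.6e10 bytes, which is where the "up to 35 GB" figure comes from.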

  /**
   * Decompress an 8-bit encoded block size, using the reverse operation of compressSize.
   */
  def decompressSize(compressedSize: Byte): Long = {
    if (compressedSize == 0) {
      1
    } else {
      // Bytes are signed on the JVM, so mask with 0xFF to recover the full 0-255 code range
      math.pow(LOG_BASE, (compressedSize & 0xFF)).toLong
    }
  }
}