// core/src/main/scala/spark/MapOutputTracker.scala
package spark
import java.util.concurrent.ConcurrentHashMap
import scala.actors._
import scala.actors.Actor._
import scala.actors.remote._
import scala.collection.mutable.HashSet
/** Messages understood by the map output tracker actor. */
sealed trait MapOutputTrackerMessage
/** Request the map output server URIs recorded for the given shuffle ID. */
final case class GetMapOutputLocations(shuffleId: Int) extends MapOutputTrackerMessage
/** Ask the tracker actor to shut down; it replies 'OK before exiting. */
case object StopMapOutputTracker extends MapOutputTrackerMessage
/**
 * Daemon actor that runs on the master and serves map output locations to
 * remote nodes. It answers GetMapOutputLocations requests by looking up the
 * shared serverUris table and stops cleanly on StopMapOutputTracker.
 *
 * @param serverUris map from shuffle ID to the output server URI of each map
 *                   task; shared with the MapOutputTracker that created us
 */
class MapOutputTrackerActor(serverUris: ConcurrentHashMap[Int, Array[String]])
extends DaemonActor with Logging {
  def act() {
    val port = System.getProperty("spark.master.port").toInt
    // Make this actor reachable from other JVMs under the 'MapOutputTracker name.
    RemoteActor.alive(port)
    RemoteActor.register('MapOutputTracker, self)
    logInfo("Registered actor on port " + port)
    loop {
      react {
        case GetMapOutputLocations(shuffleId) =>
          logInfo("Asked to get map output locations for shuffle " + shuffleId)
          // May reply null if no outputs were registered for this shuffle yet.
          reply(serverUris.get(shuffleId))
        case StopMapOutputTracker =>
          // Acknowledge before exiting so the caller's !? returns.
          reply('OK)
          exit()
      }
    }
  }
}
/**
 * Tracks the locations (server URIs) of the map outputs for each shuffle.
 *
 * On the master (isMaster = true) this starts a MapOutputTrackerActor that
 * serves locations out of our own table; on worker nodes it connects to the
 * master's actor via RemoteActor and fetches locations on demand, caching
 * them locally in serverUris.
 *
 * @param isMaster whether this process is the Spark master
 */
class MapOutputTracker(isMaster: Boolean) extends Logging {
  var trackerActor: AbstractActor = null

  // Map from shuffle ID to the output server URI of each map task.
  private val serverUris = new ConcurrentHashMap[Int, Array[String]]

  if (isMaster) {
    // Serve our own table of locations to remote nodes.
    val tracker = new MapOutputTrackerActor(serverUris)
    tracker.start()
    trackerActor = tracker
  } else {
    // Connect to the tracker actor running on the master.
    val host = System.getProperty("spark.master.host")
    val port = System.getProperty("spark.master.port").toInt
    trackerActor = RemoteActor.select(Node(host, port), 'MapOutputTracker)
  }

  /**
   * Record the output location of a single map task.
   *
   * @param shuffleId shuffle the map task belongs to
   * @param numMaps   total number of map tasks in the shuffle (used to size
   *                  the location array on first registration)
   * @param mapId     index of the map task within the shuffle
   * @param serverUri URI of the server holding this task's output
   */
  def registerMapOutput(shuffleId: Int, numMaps: Int, mapId: Int, serverUri: String) {
    // NOTE(review): this check-then-put is not atomic; presumably all
    // registrations for a shuffle come from a single scheduler thread -- confirm.
    var array = serverUris.get(shuffleId)
    if (array == null) {
      array = Array.fill[String](numMaps)(null)
      serverUris.put(shuffleId, array)
    }
    array(mapId) = serverUri
  }

  /** Record the output locations of all map tasks of a shuffle at once. */
  def registerMapOutputs(shuffleId: Int, locs: Array[String]) {
    serverUris.put(shuffleId, locs.clone()) // defensive copy of the caller's array
  }

  // Remembers which shuffle IDs are currently being fetched from the master,
  // so concurrent reducers don't all request the same locations.
  val fetching = new HashSet[Int]

  /**
   * Get the map output locations for the given shuffle, fetching them from
   * the master (and caching them) if they are not known locally. Only one
   * thread performs the remote fetch for a given shuffle; any others wait on
   * the `fetching` set until the winner publishes the result.
   */
  def getServerUris(shuffleId: Int): Array[String] = {
    val locs = serverUris.get(shuffleId)
    if (locs != null) {
      return locs
    }
    logInfo("Don't have map outputs for " + shuffleId + ", fetching them")
    fetching.synchronized {
      if (fetching.contains(shuffleId)) {
        // Someone else is fetching it; wait for them to be done
        while (fetching.contains(shuffleId)) {
          try {
            fetching.wait()
          } catch {
            // Keep waiting if interrupted; the fetcher will notify us when done.
            case _: InterruptedException =>
          }
        }
        return serverUris.get(shuffleId)
      } else {
        fetching += shuffleId
      }
    }
    // We won the race to fetch the output locs; do so
    logInfo("Doing the fetch; tracker actor = " + trackerActor)
    val fetched = (trackerActor !? GetMapOutputLocations(shuffleId)).asInstanceOf[Array[String]]
    logInfo("Got locations: " + fetched.mkString(", "))
    serverUris.put(shuffleId, fetched)
    fetching.synchronized {
      fetching -= shuffleId
      fetching.notifyAll() // wake any threads waiting in the loop above
    }
    fetched
  }

  /** Build the URI from which a reducer can fetch one map task's output. */
  def getMapOutputUri(serverUri: String, shuffleId: Int, mapId: Int, reduceId: Int): String = {
    "%s/shuffle/%s/%s/%s".format(serverUri, shuffleId, mapId, reduceId)
  }

  /** Stop the tracker actor (if any) and clear all cached locations. */
  def stop() {
    trackerActor !? StopMapOutputTracker
    serverUris.clear()
    trackerActor = null
  }
}