2011-02-27 17:27:12 -05:00
|
|
|
package spark
|
|
|
|
|
|
|
|
import java.util.concurrent.ConcurrentHashMap
|
|
|
|
|
2011-02-27 22:15:52 -05:00
|
|
|
import scala.actors._
|
|
|
|
import scala.actors.Actor._
|
|
|
|
import scala.actors.remote._
|
2011-03-06 19:16:38 -05:00
|
|
|
import scala.collection.mutable.HashSet
|
2011-02-27 22:15:52 -05:00
|
|
|
|
2011-03-06 19:16:38 -05:00
|
|
|
/**
 * Base type of the messages handled by the MapOutputTracker actor.
 * Sealed, so every message variant is declared in this file.
 */
sealed trait MapOutputTrackerMessage
|
|
|
|
/**
 * Request for the map output locations of one shuffle. The tracker actor
 * replies with whatever its location table holds for `shuffleId`, which is
 * null when nothing has been registered for that shuffle.
 */
case class GetMapOutputLocations(shuffleId: Int) extends MapOutputTrackerMessage
|
2011-05-13 15:03:58 -04:00
|
|
|
/**
 * Asks the tracker actor to shut down. The actor replies with the symbol
 * 'OK and then exits.
 */
case object StopMapOutputTracker extends MapOutputTrackerMessage
|
2011-03-06 19:16:38 -05:00
|
|
|
|
|
|
|
/**
 * Daemon actor that serves map output locations to remote callers.
 *
 * On start it listens on the port given by the `spark.master.port` system
 * property (default 50501) and registers itself under the name
 * 'MapOutputTracker. It then handles two messages:
 *
 *  - `GetMapOutputLocations(id)`: replies with the entry of `serverUris` for
 *    `id` (null if there is none).
 *  - `StopMapOutputTracker`: replies 'OK and terminates the actor.
 *
 * @param serverUris table from shuffle id to the server URI of each map output
 */
class MapOutputTracker(serverUris: ConcurrentHashMap[Int, Array[String]])
  extends DaemonActor with Logging {

  def act(): Unit = {
    val listenPort = System.getProperty("spark.master.port", "50501").toInt
    RemoteActor.alive(listenPort)
    RemoteActor.register('MapOutputTracker, self)
    logInfo("Registered actor on port " + listenPort)

    loop {
      react {
        case GetMapOutputLocations(requestedShuffle) =>
          logInfo("Asked to get map output locations for shuffle " + requestedShuffle)
          val knownLocations = serverUris.get(requestedShuffle)
          reply(knownLocations)

        case StopMapOutputTracker =>
          // Acknowledge first so the sender's synchronous call returns, then exit.
          reply('OK)
          exit()
      }
    }
  }
}
|
|
|
|
|
2011-03-06 19:16:38 -05:00
|
|
|
/**
 * Process-wide registry of map output locations, keyed by shuffle id.
 *
 * After `initialize(true)` this object owns a local tracker actor that serves
 * the table to other nodes; after `initialize(false)` it holds a proxy to the
 * tracker actor on the master and fills the local table lazily through
 * `getServerUris`.
 */
object MapOutputTracker extends Logging {
  /** Local tracker actor (master) or remote proxy to it (other nodes); null until initialized. */
  var trackerActor: AbstractActor = null

  /** Shuffle id -> server URI of each map output, indexed by map id. */
  private val serverUris = new ConcurrentHashMap[Int, Array[String]]

  /**
   * Sets up `trackerActor`.
   *
   * @param isMaster when true, starts a local tracker actor backed by this
   *                 object's table; otherwise selects the remote actor named
   *                 'MapOutputTracker at `spark.master.host` /
   *                 `spark.master.port` (both system properties must be set)
   */
  def initialize(isMaster: Boolean): Unit = {
    if (isMaster) {
      val tracker = new MapOutputTracker(serverUris)
      tracker.start()
      trackerActor = tracker
    } else {
      val host = System.getProperty("spark.master.host")
      val port = System.getProperty("spark.master.port").toInt
      trackerActor = RemoteActor.select(Node(host, port), 'MapOutputTracker)
    }
  }

  /**
   * Records the server URI of a single map output.
   *
   * The per-shuffle array is created on first use with `numMaps` null slots.
   *
   * @param shuffleId shuffle the output belongs to
   * @param numMaps   number of map outputs in the shuffle; only used when the
   *                  array for `shuffleId` does not exist yet
   * @param mapId     index of the map output within the shuffle
   * @param serverUri URI of the server holding the output
   */
  def registerMapOutput(shuffleId: Int, numMaps: Int, mapId: Int, serverUri: String): Unit = {
    var array = serverUris.get(shuffleId)
    if (array == null) {
      val fresh = Array.fill[String](numMaps)(null)
      // putIfAbsent makes creation atomic: if another thread installed an array
      // between our get and here, use theirs so neither registration is lost.
      val existing = serverUris.putIfAbsent(shuffleId, fresh)
      array = if (existing == null) fresh else existing
    }
    array(mapId) = serverUri
  }

  /**
   * Replaces all locations of a shuffle with a copy of `locs`.
   *
   * @param shuffleId shuffle whose locations are being set
   * @param locs      server URI per map output; copied, so later changes to
   *                  the caller's array are not seen
   */
  def registerMapOutputs(shuffleId: Int, locs: Array[String]): Unit = {
    serverUris.put(shuffleId, Array[String]() ++ locs)
  }

  // Remembers which map output locations are currently being fetched.
  // Also serves as the monitor that waiting threads block on.
  val fetching = new HashSet[Int]

  /**
   * Returns the map output locations of a shuffle, asking the tracker actor
   * for them when they are not in the local table.
   *
   * Only one thread fetches a given shuffle at a time; other callers block
   * until that fetch finishes and then read the local table.
   *
   * @param shuffleId shuffle to look up
   * @return the server URI per map output; a thread that waited on another
   *         thread's fetch gets null if that fetch did not store a result
   */
  def getServerUris(shuffleId: Int): Array[String] = {
    val locs = serverUris.get(shuffleId)
    if (locs != null) {
      locs
    } else {
      logInfo("Don't have map outputs for " + shuffleId + ", fetching them")
      val mustFetch = fetching.synchronized {
        if (fetching.contains(shuffleId)) {
          // Someone else is fetching it; wait for them to be done
          while (fetching.contains(shuffleId)) {
            try {
              fetching.wait()
            } catch {
              // Keep waiting when interrupted, but let any other Throwable propagate.
              case _: InterruptedException =>
            }
          }
          false
        } else {
          fetching += shuffleId
          true
        }
      }

      if (!mustFetch) {
        serverUris.get(shuffleId)
      } else {
        // We won the race to fetch the output locs; do so
        logInfo("Doing the fetch; tracker actor = " + trackerActor)
        try {
          val fetched =
            (trackerActor !? GetMapOutputLocations(shuffleId)).asInstanceOf[Array[String]]
          // ConcurrentHashMap rejects null values, so a null reply throws here.
          serverUris.put(shuffleId, fetched)
          fetched
        } finally {
          // Always release the claim and wake the waiters, even when the fetch
          // or the put throws; otherwise they would block forever.
          fetching.synchronized {
            fetching -= shuffleId
            fetching.notifyAll()
          }
        }
      }
    }
  }

  /**
   * Builds the URI of one shuffle output file.
   *
   * @return `serverUri/shuffle/shuffleId/mapId/reduceId`
   */
  def getMapOutputUri(serverUri: String, shuffleId: Int, mapId: Int, reduceId: Int): String = {
    "%s/shuffle/%s/%s/%s".format(serverUri, shuffleId, mapId, reduceId)
  }

  /**
   * Sends `StopMapOutputTracker` to the tracker actor and waits for its reply,
   * then clears the local table and drops the actor reference. Does not
   * message anything when no tracker actor is set, so it is safe to call
   * before `initialize` or more than once.
   */
  def stop(): Unit = {
    if (trackerActor != null) {
      trackerActor !? StopMapOutputTracker
    }
    serverUris.clear()
    trackerActor = null
  }
}
|