Merge remote-tracking branch 'upstream/dev' into dev

Mosharaf Chowdhury 2012-08-28 14:56:57 -07:00
commit c74455f309
11 changed files with 91 additions and 33 deletions

View file

@@ -48,8 +48,9 @@ class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
}
}
} catch {
// TODO: this is really ugly -- let's find a better way of throwing a FetchFailedException
case be: BlockException => {
val regex = "shuffledid_([0-9]*)_([0-9]*)_([0-9]]*)".r
val regex = "shuffleid_([0-9]*)_([0-9]*)_([0-9]]*)".r
be.blockId match {
case regex(sId, mId, rId) => {
val address = addresses(mId.toInt)
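The hunk above corrects the block-id prefix from "shuffledid" to "shuffleid" and then uses Scala's regex extractor to recover the shuffle, map, and reduce ids from the block id. A minimal standalone sketch of that extractor pattern, with a cleaned-up digit group and an invented block id (no Spark classes involved):

```scala
// Illustrative only: parse "shuffleid_<shuffleId>_<mapId>_<reduceId>" with a
// Scala regex extractor, the same pattern BlockStoreShuffleFetcher relies on.
object BlockIdRegexSketch {
  def main(args: Array[String]) {
    val regex = "shuffleid_([0-9]*)_([0-9]*)_([0-9]*)".r
    "shuffleid_3_7_42" match {
      case regex(sId, mId, rId) =>
        println("shuffle=" + sId + " map=" + mId + " reduce=" + rId)
      case _ =>
        println("not a shuffle block id")
    }
  }
}
```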

View file

@@ -116,7 +116,7 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg
def getServerAddresses(shuffleId: Int): Array[BlockManagerId] = {
val locs = bmAddresses.get(shuffleId)
if (locs == null) {
logInfo("Don't have map outputs for shuffe " + shuffleId + ", fetching them")
logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
fetching.synchronized {
if (fetching.contains(shuffleId)) {
// Someone else is fetching it; wait for them to be done
@@ -158,6 +158,7 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg
def incrementGeneration() {
generationLock.synchronized {
generation += 1
logDebug("Increasing generation to " + generation)
}
}

View file

@@ -63,6 +63,7 @@ class Executor extends Logging {
Thread.currentThread.setContextClassLoader(classLoader)
Accumulators.clear()
val task = ser.deserialize[Task[Any]](serializedTask, classLoader)
logInfo("Its generation is " + task.generation)
env.mapOutputTracker.updateGeneration(task.generation)
val value = task.run(taskId.toInt)
val accumUpdates = Accumulators.values
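The two hunks above are halves of one mechanism: MapOutputTracker bumps a generation counter under generationLock whenever map outputs change, every task carries the generation it was launched with, and the executor feeds that value back into its local tracker. A hedged sketch of the counter pattern in isolation; the cache-clearing behaviour of updateGeneration is an assumption about the tracker, not something shown in this diff:

```scala
import scala.collection.mutable.HashMap

// Sketch (not the Spark source): a lock-guarded, monotonically increasing
// generation counter. Workers that receive a newer generation drop their
// cached map-output locations so the next lookup goes back to the master.
class GenerationTrackerSketch {
  private val generationLock = new Object
  private var generation: Long = 0
  private val cachedLocations = new HashMap[Int, Array[String]]  // shuffleId -> addresses

  def incrementGeneration() {
    generationLock.synchronized {
      generation += 1
    }
  }

  def getGeneration: Long = generationLock.synchronized { generation }

  // Called on a worker with the generation carried by an incoming task.
  def updateGeneration(newGen: Long) {
    generationLock.synchronized {
      if (newGen > generation) {
        cachedLocations.clear()   // assumed: stale addresses are discarded
        generation = newGen
      }
    }
  }
}
```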

View file

@@ -111,7 +111,7 @@ extends Connection(SocketChannel.open, selector_) {
messages.synchronized{
/*messages += message*/
messages.enqueue(message)
logInfo("Added [" + message + "] to outbox for sending to [" + remoteConnectionManagerId + "]")
logDebug("Added [" + message + "] to outbox for sending to [" + remoteConnectionManagerId + "]")
}
}
@@ -136,7 +136,7 @@ extends Connection(SocketChannel.open, selector_) {
return chunk
}
/*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/
logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "] in " + message.timeTaken )
logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "] in " + message.timeTaken )
}
}
None

View file

@@ -14,7 +14,8 @@ import scala.collection.mutable.SynchronizedQueue
import scala.collection.mutable.Queue
import scala.collection.mutable.ArrayBuffer
import akka.dispatch.{Promise, ExecutionContext, Future}
import akka.dispatch.{Await, Promise, ExecutionContext, Future}
import akka.util.Duration
case class ConnectionManagerId(host: String, port: Int) {
def toSocketAddress() = new InetSocketAddress(host, port)
@@ -247,7 +248,7 @@ class ConnectionManager(port: Int) extends Logging {
}
private def handleMessage(connectionManagerId: ConnectionManagerId, message: Message) {
logInfo("Handling [" + message + "] from [" + connectionManagerId + "]")
logDebug("Handling [" + message + "] from [" + connectionManagerId + "]")
message match {
case bufferMessage: BufferMessage => {
if (bufferMessage.hasAckId) {
@@ -305,7 +306,7 @@
}
val connection = connectionsById.getOrElse(connectionManagerId, startNewConnection())
message.senderAddress = id.toSocketAddress()
logInfo("Sending [" + message + "] to [" + connectionManagerId + "]")
logDebug("Sending [" + message + "] to [" + connectionManagerId + "]")
/*connection.send(message)*/
sendMessageRequests.synchronized {
sendMessageRequests += ((message, connection))
@@ -325,7 +326,7 @@
}
def sendMessageReliablySync(connectionManagerId: ConnectionManagerId, message: Message): Option[Message] = {
sendMessageReliably(connectionManagerId, message)()
Await.result(sendMessageReliably(connectionManagerId, message), Duration.Inf)
}
def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]) {
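sendMessageReliablySync used to block by applying the returned Future directly; the new code blocks through akka.dispatch.Await with an explicit Duration, the Akka 2.x idiom this file now imports. A small sketch of that call in isolation; the promise and its payload are invented, and only the akka.dispatch / akka.util API imported above is assumed:

```scala
import java.util.concurrent.Executors
import akka.dispatch.{Await, ExecutionContext, Promise}
import akka.util.Duration

// Sketch: block on an Akka 2.x future via Await.result instead of future.apply().
object AwaitSketch {
  def main(args: Array[String]) {
    implicit val execContext = ExecutionContext.fromExecutor(Executors.newCachedThreadPool())
    val promise = Promise[String]()     // stands in for sendMessageReliably's result
    promise.success("ack")
    // Await.result blocks the calling thread until the future completes;
    // Duration.Inf mirrors the ConnectionManager change above.
    println(Await.result(promise, Duration.Inf))
    System.exit(0)                      // the non-daemon pool threads would otherwise linger
  }
}
```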

View file

@@ -16,8 +16,11 @@ import spark._
import spark.storage._
object ShuffleMapTask {
// A simple map from the stage id to the serialized byte array of a task.
// Serves as a cache for task serialization, because serialization can be
// expensive on the master node if it needs to launch thousands of tasks.
val serializedInfoCache = new HashMap[Int, Array[Byte]]
val deserializedInfoCache = new HashMap[Int, (RDD[_], ShuffleDependency[_,_,_])]
def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_,_]): Array[Byte] = {
synchronized {
@@ -39,29 +42,21 @@
def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_,_,_]) = {
synchronized {
val old = deserializedInfoCache.get(stageId)
if (old != null) {
return old
} else {
val loader = Thread.currentThread.getContextClassLoader
val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
val objIn = new ObjectInputStream(in) {
override def resolveClass(desc: ObjectStreamClass) =
Class.forName(desc.getName, false, loader)
}
val rdd = objIn.readObject().asInstanceOf[RDD[_]]
val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_,_]]
val tuple = (rdd, dep)
deserializedInfoCache.put(stageId, tuple)
return tuple
val loader = Thread.currentThread.getContextClassLoader
val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
val objIn = new ObjectInputStream(in) {
override def resolveClass(desc: ObjectStreamClass) =
Class.forName(desc.getName, false, loader)
}
val rdd = objIn.readObject().asInstanceOf[RDD[_]]
val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_,_]]
return (rdd, dep)
}
}
def clearCache() {
synchronized {
serializedInfoCache.clear()
deserializedInfoCache.clear()
}
}
}
@@ -90,6 +85,7 @@ class ShuffleMapTask(
out.writeInt(bytes.length)
out.write(bytes)
out.writeInt(partition)
out.writeLong(generation)
out.writeObject(split)
}
@@ -102,6 +98,7 @@ class ShuffleMapTask(
rdd = rdd_
dep = dep_
partition = in.readInt()
generation = in.readLong()
split = in.readObject().asInstanceOf[Split]
}

View file

@@ -2,13 +2,14 @@ package spark.scheduler.cluster
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
import akka.actor.{Props, Actor, ActorRef, ActorSystem}
import akka.actor._
import akka.util.duration._
import akka.pattern.ask
import spark.{SparkException, Logging, TaskState}
import akka.dispatch.Await
import java.util.concurrent.atomic.AtomicInteger
import akka.remote.{RemoteClientShutdown, RemoteClientDisconnected, RemoteClientLifeCycleEvent}
/**
* A standalone scheduler backend, which waits for standalone executors to connect to it through
@@ -23,8 +24,16 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
class MasterActor(sparkProperties: Seq[(String, String)]) extends Actor {
val slaveActor = new HashMap[String, ActorRef]
val slaveAddress = new HashMap[String, Address]
val slaveHost = new HashMap[String, String]
val freeCores = new HashMap[String, Int]
val actorToSlaveId = new HashMap[ActorRef, String]
val addressToSlaveId = new HashMap[Address, String]
override def preStart() {
// Listen for remote client disconnection events, since they don't go through Akka's watch()
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
}
def receive = {
case RegisterSlave(slaveId, host, cores) =>
@@ -33,9 +42,13 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
} else {
logInfo("Registered slave: " + sender + " with ID " + slaveId)
sender ! RegisteredSlave(sparkProperties)
context.watch(sender)
slaveActor(slaveId) = sender
slaveHost(slaveId) = host
freeCores(slaveId) = cores
slaveAddress(slaveId) = sender.path.address
actorToSlaveId(sender) = slaveId
addressToSlaveId(sender.path.address) = slaveId
totalCoreCount.addAndGet(cores)
makeOffers()
}
@@ -54,7 +67,14 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
sender ! true
context.stop(self)
// TODO: Deal with nodes disconnecting too! (Including decreasing totalCoreCount)
case Terminated(actor) =>
actorToSlaveId.get(actor).foreach(removeSlave)
case RemoteClientDisconnected(transport, address) =>
addressToSlaveId.get(address).foreach(removeSlave)
case RemoteClientShutdown(transport, address) =>
addressToSlaveId.get(address).foreach(removeSlave)
}
// Make fake resource offers on all slaves
@@ -76,6 +96,20 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
slaveActor(task.slaveId) ! LaunchTask(task)
}
}
// Remove a disconnected slave from the cluster
def removeSlave(slaveId: String) {
logInfo("Slave " + slaveId + " disconnected, so removing it")
val numCores = freeCores(slaveId)
actorToSlaveId -= slaveActor(slaveId)
addressToSlaveId -= slaveAddress(slaveId)
slaveActor -= slaveId
slaveHost -= slaveId
freeCores -= slaveId
slaveHost -= slaveId
totalCoreCount.addAndGet(-numCores)
scheduler.slaveLost(slaveId)
}
}
var masterActor: ActorRef = null
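The backend now detects dead slaves in two ways: context.watch plus the Terminated message for actors Akka can track, and RemoteClientDisconnected / RemoteClientShutdown lifecycle events for the remote transport. A hedged sketch of just the watch/Terminated half, with invented actor and message names and no Spark classes:

```scala
import akka.actor._
import scala.collection.mutable.HashMap

// Sketch of the watch/Terminated pattern used above to notice a slave dying
// and clean up the bookkeeping maps keyed by its ActorRef.
class WatcherSketch extends Actor {
  val registered = new HashMap[ActorRef, String]

  def receive = {
    case ("register", id: String) =>
      context.watch(sender)            // deliver Terminated(sender) if it stops
      registered(sender) = id
    case Terminated(actor) =>
      registered.remove(actor).foreach(id => println("Slave " + id + " terminated, removing it"))
  }
}

object WatcherSketch {
  def main(args: Array[String]) {
    val system = ActorSystem("WatcherSketch")
    val watcher = system.actorOf(Props(new WatcherSketch), "watcher")
    val slave = system.actorOf(Props(new Actor { def receive = { case _ => } }), "slave")
    watcher.tell(("register", "slave-1"), slave)  // register with slave as the sender
    system.stop(slave)                            // triggers Terminated at the watcher
    Thread.sleep(500)
    system.shutdown()
  }
}
```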

View file

@@ -20,6 +20,8 @@ class TaskInfo(val taskId: Long, val index: Int, val launchTime: Long, val host:
def successful: Boolean = finished && !failed
def running: Boolean = !finished
def duration: Long = {
if (!finished) {
throw new UnsupportedOperationException("duration() called on unfinished tasks")

View file

@@ -88,6 +88,7 @@ class TaskSetManager(
// Figure out the current map output tracker generation and set it on all tasks
val generation = sched.mapOutputTracker.getGeneration
logDebug("Generation for " + taskSet.id + ": " + generation)
for (t <- tasks) {
t.generation = generation
}
@@ -264,6 +265,11 @@ class TaskSetManager(
def taskLost(tid: Long, state: TaskState, serializedData: ByteBuffer) {
val info = taskInfos(tid)
if (info.failed) {
// We might get two task-lost messages for the same task in coarse-grained Mesos mode,
// or even from Mesos itself when acks get delayed.
return
}
val index = info.index
info.markFailed()
if (!finished(index)) {
@@ -340,7 +346,7 @@
}
def hostLost(hostname: String) {
logInfo("Re-queueing tasks for " + hostname)
logInfo("Re-queueing tasks for " + hostname + " from TaskSet " + taskSet.id)
// If some task has preferred locations only on hostname, put it in the no-prefs list
// to avoid the wait from delay scheduling
for (index <- getPendingTasksForHost(hostname)) {
@@ -349,7 +355,7 @@
pendingTasksWithNoPrefs += index
}
}
// Also re-enqueue any tasks that ran on the failed host if this is a shuffle map stage
// Re-enqueue any tasks that ran on the failed host if this is a shuffle map stage
if (tasks(0).isInstanceOf[ShuffleMapTask]) {
for ((tid, info) <- taskInfos if info.host == hostname) {
val index = taskInfos(tid).index
@@ -364,6 +370,10 @@
}
}
}
// Also re-enqueue any tasks that were running on the node
for ((tid, info) <- taskInfos if info.running && info.host == hostname) {
taskLost(tid, TaskState.KILLED, null)
}
}
/**

View file

@@ -364,6 +364,12 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
val startTimeMs = System.currentTimeMillis
var bytes: ByteBuffer = null
// If we need to replicate the data, we'll want access to the values, but because our
// put will read the whole iterator, there will be no values left. For the case where
// the put serializes data, we'll remember the bytes, above; but for the case where
// it doesn't, such as MEMORY_ONLY_DESER, let's rely on the put returning an Iterator.
var valuesAfterPut: Iterator[Any] = null
locker.getLock(blockId).synchronized {
logDebug("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
@@ -391,7 +397,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
// If only save to memory
memoryStore.putValues(blockId, values, level) match {
case Right(newBytes) => bytes = newBytes
case _ =>
case Left(newIterator) => valuesAfterPut = newIterator
}
} else {
// If only save to disk
@@ -408,8 +414,13 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
// Replicate block if required
if (level.replication > 1) {
// Serialize the block if not already done
if (bytes == null) {
bytes = dataSerialize(values) // serialize the block if not already done
if (valuesAfterPut == null) {
throw new SparkException(
"Underlying put returned neither an Iterator nor bytes! This shouldn't happen.")
}
bytes = dataSerialize(valuesAfterPut)
}
replicate(blockId, bytes, level)
}
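The change above keeps the iterator that a memory-only put hands back (the Left case) so that, when replication is requested, the values can still be serialized exactly once; a put that already produced bytes returns them as Right. An illustrative sketch of that Either handling, detached from the BlockManager API (serialize below is a stand-in, not dataSerialize):

```scala
import java.nio.ByteBuffer

// Sketch: pick the bytes to replicate from a put that returned either
// already-serialized bytes (Right) or an iterator over the values (Left).
object PutResultSketch {
  def serialize(values: Iterator[Any]): ByteBuffer =
    ByteBuffer.wrap(values.map(_.toString).mkString(",").getBytes("UTF-8"))

  def bytesForReplication(putResult: Either[Iterator[Any], ByteBuffer]): ByteBuffer =
    putResult match {
      case Right(bytes)         => bytes                     // put already serialized the data
      case Left(valuesAfterPut) => serialize(valuesAfterPut) // e.g. a deserialized in-memory put
    }

  def main(args: Array[String]) {
    println(bytesForReplication(Left(Iterator(1, 2, 3))).remaining())
  }
}
```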

View file

@@ -30,7 +30,7 @@ object SparkLR {
}
val sc = new SparkContext(args(0), "SparkLR")
val numSlices = if (args.length > 1) args(1).toInt else 2
val data = generateData
val points = sc.parallelize(generateData, numSlices).cache()
// Initialize w to a random value
var w = Vector(D, _ => 2 * rand.nextDouble - 1)
@@ -38,7 +38,7 @@ object SparkLR {
for (i <- 1 to ITERATIONS) {
println("On iteration " + i)
val gradient = sc.parallelize(data, numSlices).map { p =>
val gradient = points.map { p =>
(1 / (1 + exp(-p.y * (w dot p.x))) - 1) * p.y * p.x
}.reduce(_ + _)
w -= gradient
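The SparkLR change parallelizes and caches the input once, outside the loop, instead of calling sc.parallelize(data, numSlices) on every iteration. A minimal sketch of that pattern against the 2012-era spark API used above, with a local master and toy data standing in for the LR points:

```scala
import spark.SparkContext
import spark.SparkContext._

// Sketch: cache an RDD that is reused across iterations so it is built once.
object CachePatternSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "CachePatternSketch")
    val points = sc.parallelize(1 to 1000, 2).cache()   // materialized once, reused below
    for (i <- 1 to 5) {
      val sum = points.map(_ * 2).reduce(_ + _)
      println("Iteration " + i + ": " + sum)
    }
    sc.stop()
  }
}
```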