Added fork()/join() operations for SparkContext, as well as corresponding changes to MesosScheduler to support multiple ParallelOperations.

Justin Ma 2010-09-12 09:01:44 -07:00
parent 6f0d2c1cbc
commit 0896fd6219
2 changed files with 93 additions and 51 deletions
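The user-facing addition is small: fork(f) schedules the block f to run asynchronously on a Scala actor and returns a SparkAsyncLock, whose join() blocks until the block has completed. A minimal usage sketch (the master string and the forked bodies are illustrative, not part of this commit):

    val sc = new SparkContext("local[2]", "forkJoinExample")
    // Launch two blocks of work concurrently; each block may itself
    // run parallel operations through the SparkContext.
    val left = sc.fork { println("left branch done") }
    val right = sc.fork { println("right branch done") }
    left.join()   // blocks until the first forked block finishes
    right.join()  // blocks until the second forked block finishes

Because the scheduler now keeps a map of active operations instead of a single activeOp, two fork()ed blocks can each be inside runTasks() at the same time.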

src/scala/spark/MesosScheduler.scala

@@ -3,6 +3,7 @@ package spark
 import java.io.File
 import scala.collection.mutable.Map
+import scala.collection.mutable.HashMap
 import scala.collection.JavaConversions._
 import mesos.{Scheduler => NScheduler}
@@ -31,7 +32,15 @@ extends NScheduler with spark.Scheduler
   val registeredLock = new Object()

   // Current callback object (may be null)
-  var activeOp: ParallelOperation = null
+  var activeOps = new HashMap[Int, ParallelOperation]
+  private var nextOpId = 0
+  private[spark] var taskIdToOpId = new HashMap[Int, Int]
+
+  def newOpId(): Int = {
+    val id = nextOpId
+    nextOpId += 1
+    return id
+  }

   // Incrementing task ID
   private var nextTaskId = 0
@@ -62,27 +71,29 @@ extends NScheduler with spark.Scheduler
     new ExecutorInfo(new File("spark-executor").getCanonicalPath(), execArg)

   override def runTasks[T: ClassManifest](tasks: Array[Task[T]]): Array[T] = {
+    var opId = 0
     runTasksMutex.synchronized {
       waitForRegister()
-      val myOp = new SimpleParallelOperation(this, tasks)
-      try {
-        this.synchronized {
-          this.activeOp = myOp
-        }
-        driver.reviveOffers();
-        myOp.join();
-      } finally {
-        this.synchronized {
-          this.activeOp = null
-        }
-      }
-      if (myOp.errorHappened)
-        throw new SparkException(myOp.errorMessage, myOp.errorCode)
-      else
-        return myOp.results
+      opId = newOpId()
     }
+    val myOp = new SimpleParallelOperation(this, tasks, opId)
+    try {
+      this.synchronized {
+        this.activeOps(myOp.opId) = myOp
+      }
+      driver.reviveOffers();
+      myOp.join();
+    } finally {
+      this.synchronized {
+        this.activeOps.remove(myOp.opId)
+      }
+    }
+    if (myOp.errorHappened)
+      throw new SparkException(myOp.errorMessage, myOp.errorCode)
+    else
+      return myOp.results
   }

   override def registered(d: SchedulerDriver, frameworkId: String) {
@@ -104,28 +115,26 @@ extends NScheduler with spark.Scheduler
       d: SchedulerDriver, oid: String, offers: java.util.List[SlaveOffer]) {
     synchronized {
       val tasks = new java.util.ArrayList[TaskDescription]
-      if (activeOp != null) {
-        try {
-          val availableCpus = offers.map(_.getParams.get("cpus").toInt)
-          val availableMem = offers.map(_.getParams.get("mem").toInt)
-          var resourcesAvailable = true
-          while (resourcesAvailable) {
-            resourcesAvailable = false
-            for (i <- 0 until offers.size.toInt) {
-              activeOp.slaveOffer(offers.get(i), availableCpus(i), availableMem(i)) match {
-                case Some(task) =>
-                  tasks.add(task)
-                  availableCpus(i) -= task.getParams.get("cpus").toInt
-                  availableMem(i) -= task.getParams.get("mem").toInt
-                  resourcesAvailable = resourcesAvailable || true
-                case None => {}
-              }
-            }
-          }
-        } catch {
-          case e: Exception => e.printStackTrace
-        }
-      }
+      val availableCpus = offers.map(_.getParams.get("cpus").toInt)
+      val availableMem = offers.map(_.getParams.get("mem").toInt)
+      var resourcesAvailable = true
+      while (resourcesAvailable) {
+        resourcesAvailable = false
+        for (i <- 0 until offers.size.toInt; (opId, activeOp) <- activeOps) {
+          try {
+            activeOp.slaveOffer(offers.get(i), availableCpus(i), availableMem(i)) match {
+              case Some(task) =>
+                tasks.add(task)
+                availableCpus(i) -= task.getParams.get("cpus").toInt
+                availableMem(i) -= task.getParams.get("mem").toInt
+                resourcesAvailable = resourcesAvailable || true
+              case None => {}
+            }
+          } catch {
+            case e: Exception => e.printStackTrace
+          }
+        }
+      }
       val params = new java.util.HashMap[String, String]
       params.put("timeout", "1")
       d.replyToOffer(oid, tasks, params) // TODO: use smaller timeout
@@ -135,9 +144,15 @@ extends NScheduler with spark.Scheduler
   override def statusUpdate(d: SchedulerDriver, status: TaskStatus) {
     synchronized {
       try {
-        if (activeOp != null) {
-          activeOp.statusUpdate(status)
+        taskIdToOpId.get(status.getTaskId) match {
+          case Some(opId) =>
+            if (activeOps.contains(opId)) {
+              activeOps(opId).statusUpdate(status)
+            }
+          case None =>
+            println("TID " + status.getTaskId + " already finished")
         }
       } catch {
         case e: Exception => e.printStackTrace
       }
@@ -146,11 +161,13 @@ extends NScheduler with spark.Scheduler
   override def error(d: SchedulerDriver, code: Int, message: String) {
     synchronized {
-      if (activeOp != null) {
-        try {
-          activeOp.error(code, message)
-        } catch {
-          case e: Exception => e.printStackTrace
-        }
+      if (activeOps.size > 0) {
+        for ((opId, activeOp) <- activeOps) {
+          try {
+            activeOp.error(code, message)
+          } catch {
+            case e: Exception => e.printStackTrace
+          }
+        }
       } else {
         val msg = "Mesos error: %s (error code: %d)".format(message, code)
@@ -180,7 +197,7 @@ trait ParallelOperation
 class SimpleParallelOperation[T: ClassManifest](
-    sched: MesosScheduler, tasks: Array[Task[T]])
+    sched: MesosScheduler, tasks: Array[Task[T]], val opId: Int)
 extends ParallelOperation
 {
   // Maximum time to wait to run a task in a preferred location (in ms)
@@ -235,10 +252,10 @@ extends ParallelOperation
             tasks(i).preferredLocations.isEmpty))
         {
           val taskId = sched.newTaskId()
+          sched.taskIdToOpId(taskId) = opId
           tidToIndex(taskId) = i
-          //printf("Starting task %d as TID %s on slave %s: %s (%s)\n",
-          printf("Starting task %d as TID %s on slave %s: %s (%s)",
-                 i, taskId, offer.getSlaveId, offer.getHost,
+          printf("Starting task %d as opId %d, TID %s on slave %s: %s (%s)",
+                 i, opId, taskId, offer.getSlaveId, offer.getHost,
                  if(checkPref) "preferred" else "non-preferred")
           tasks(i).markStarted(offer)
           launched(i) = true
@@ -274,7 +291,7 @@ extends ParallelOperation
   def taskFinished(status: TaskStatus) {
     val tid = status.getTaskId
-    print("Finished TID " + tid)
+    print("Finished opId " + opId + " TID " + tid)
     if (!finished(tidToIndex(tid))) {
       // Deserialize task result
       val result = Utils.deserialize[TaskResult[T]](status.getData)
@@ -283,6 +300,8 @@ extends ParallelOperation
       Accumulators.add(callingThread, result.accumUpdates)
       // Mark finished and stop if we've finished all the tasks
       finished(tidToIndex(tid)) = true
+      // Remove TID -> opId mapping from sched
+      sched.taskIdToOpId.remove(tid)
       tasksFinished += 1
       println(", finished " + tasksFinished + "/" + numTasks)
@@ -295,7 +314,7 @@ extends ParallelOperation
   def taskLost(status: TaskStatus) {
     val tid = status.getTaskId
-    println("Lost TID " + tid)
+    println("Lost opId " + opId + " TID " + tid)
     if (!finished(tidToIndex(tid))) {
       launched(tidToIndex(tid)) = false
       tasksLaunched -= 1

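On the MesosScheduler side, the change amounts to two pieces of bookkeeping: resource offers are now walked per slave and per operation (the for (i <- ...; (opId, activeOp) <- activeOps) loop above), and every launched task is recorded in taskIdToOpId so that later status updates can be routed back to the operation that launched the task. A stripped-down sketch of that routing pattern, with a simplified Op trait standing in for ParallelOperation and a bare Int task id standing in for Mesos' TaskStatus (names here are illustrative, not from the commit):

    import scala.collection.mutable.HashMap

    trait Op { def statusUpdate(taskId: Int): Unit }

    val activeOps = new HashMap[Int, Op]      // opId -> operation
    val taskIdToOpId = new HashMap[Int, Int]  // TID  -> opId

    // Route a task's update to whichever operation launched it;
    // updates for unknown TIDs (already-finished operations) are dropped.
    def route(taskId: Int): Unit =
      for (opId <- taskIdToOpId.get(taskId); op <- activeOps.get(opId))
        op.statusUpdate(taskId)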
src/scala/spark/SparkContext.scala

@@ -4,6 +4,17 @@ import java.io._
 import java.util.UUID
 import scala.collection.mutable.ArrayBuffer
+import scala.actors.Actor._
+
+case class SparkAsyncLock(var finished: Boolean = false) {
+  def join() {
+    this.synchronized {
+      while (!finished) {
+        this.wait
+      }
+    }
+  }
+}

 class SparkContext(master: String, frameworkName: String) {
   Broadcast.initialize(true)
@@ -21,6 +32,18 @@ class SparkContext(master: String, frameworkName: String) {
   def broadcast[T](value: T) = new CentralizedHDFSBroadcast(value, local)
   //def broadcast[T](value: T) = new ChainedStreamingBroadcast(value, local)

+  def fork(f: => Unit): SparkAsyncLock = {
+    val thisLock = new SparkAsyncLock
+    actor {
+      f
+      thisLock.synchronized {
+        thisLock.finished = true
+        thisLock.notifyAll()
+      }
+    }
+    thisLock
+  }
+
   def textFile(path: String) = new HdfsTextFile(this, path)

   val LOCAL_REGEX = """local\[([0-9]+)\]""".r
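SparkAsyncLock itself is a plain wait()/notifyAll() latch: join() loops inside a synchronized block until the forked actor flips finished and wakes all waiters, and the while loop re-checks the flag so spurious wakeups are harmless. The same pattern reduced to a self-contained snippet, with a plain thread standing in for the actor (Latch and release() are illustrative names, not from this commit):

    case class Latch(var finished: Boolean = false) {
      def join(): Unit = this.synchronized {
        while (!finished) this.wait()   // re-check the flag after every wakeup
      }
      def release(): Unit = this.synchronized {
        finished = true
        this.notifyAll()                // wake every thread blocked in join()
      }
    }

    val latch = Latch()
    new Thread(new Runnable {
      def run() { Thread.sleep(100); latch.release() }
    }).start()
    latch.join()  // returns once release() has run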