Added fork()/join() operations for SparkContext, as well as corresponding changes to MesosScheduler to support multiple ParallelOperations.
This commit is contained in:
parent
6f0d2c1cbc
commit
0896fd6219
|
@ -3,6 +3,7 @@ package spark
|
|||
import java.io.File
|
||||
|
||||
import scala.collection.mutable.Map
|
||||
import scala.collection.mutable.HashMap
|
||||
import scala.collection.JavaConversions._
|
||||
|
||||
import mesos.{Scheduler => NScheduler}
|
||||
|
@ -31,7 +32,15 @@ extends NScheduler with spark.Scheduler
|
|||
val registeredLock = new Object()
|
||||
|
||||
// Current callback object (may be null)
|
||||
var activeOp: ParallelOperation = null
|
||||
var activeOps = new HashMap[Int, ParallelOperation]
|
||||
private var nextOpId = 0
|
||||
private[spark] var taskIdToOpId = new HashMap[Int, Int]
|
||||
|
||||
def newOpId(): Int = {
|
||||
val id = nextOpId
|
||||
nextOpId += 1
|
||||
return id
|
||||
}
|
||||
|
||||
// Incrementing task ID
|
||||
private var nextTaskId = 0
|
||||
|
@ -62,27 +71,29 @@ extends NScheduler with spark.Scheduler
|
|||
new ExecutorInfo(new File("spark-executor").getCanonicalPath(), execArg)
|
||||
|
||||
override def runTasks[T: ClassManifest](tasks: Array[Task[T]]): Array[T] = {
|
||||
var opId = 0
|
||||
runTasksMutex.synchronized {
|
||||
waitForRegister()
|
||||
val myOp = new SimpleParallelOperation(this, tasks)
|
||||
|
||||
try {
|
||||
this.synchronized {
|
||||
this.activeOp = myOp
|
||||
}
|
||||
driver.reviveOffers();
|
||||
myOp.join();
|
||||
} finally {
|
||||
this.synchronized {
|
||||
this.activeOp = null
|
||||
}
|
||||
}
|
||||
|
||||
if (myOp.errorHappened)
|
||||
throw new SparkException(myOp.errorMessage, myOp.errorCode)
|
||||
else
|
||||
return myOp.results
|
||||
opId = newOpId()
|
||||
}
|
||||
val myOp = new SimpleParallelOperation(this, tasks, opId)
|
||||
|
||||
try {
|
||||
this.synchronized {
|
||||
this.activeOps(myOp.opId) = myOp
|
||||
}
|
||||
driver.reviveOffers();
|
||||
myOp.join();
|
||||
} finally {
|
||||
this.synchronized {
|
||||
this.activeOps.remove(myOp.opId)
|
||||
}
|
||||
}
|
||||
|
||||
if (myOp.errorHappened)
|
||||
throw new SparkException(myOp.errorMessage, myOp.errorCode)
|
||||
else
|
||||
return myOp.results
|
||||
}
|
||||
|
||||
override def registered(d: SchedulerDriver, frameworkId: String) {
|
||||
|
@ -104,28 +115,26 @@ extends NScheduler with spark.Scheduler
|
|||
d: SchedulerDriver, oid: String, offers: java.util.List[SlaveOffer]) {
|
||||
synchronized {
|
||||
val tasks = new java.util.ArrayList[TaskDescription]
|
||||
if (activeOp != null) {
|
||||
try {
|
||||
val availableCpus = offers.map(_.getParams.get("cpus").toInt)
|
||||
val availableMem = offers.map(_.getParams.get("mem").toInt)
|
||||
var resourcesAvailable = true
|
||||
while (resourcesAvailable) {
|
||||
resourcesAvailable = false
|
||||
for (i <- 0 until offers.size.toInt) {
|
||||
activeOp.slaveOffer(offers.get(i), availableCpus(i), availableMem(i)) match {
|
||||
case Some(task) =>
|
||||
tasks.add(task)
|
||||
availableCpus(i) -= task.getParams.get("cpus").toInt
|
||||
availableMem(i) -= task.getParams.get("mem").toInt
|
||||
resourcesAvailable = resourcesAvailable || true
|
||||
case None => {}
|
||||
}
|
||||
val availableCpus = offers.map(_.getParams.get("cpus").toInt)
|
||||
val availableMem = offers.map(_.getParams.get("mem").toInt)
|
||||
var resourcesAvailable = true
|
||||
while (resourcesAvailable) {
|
||||
resourcesAvailable = false
|
||||
for (i <- 0 until offers.size.toInt; (opId, activeOp) <- activeOps) {
|
||||
try {
|
||||
activeOp.slaveOffer(offers.get(i), availableCpus(i), availableMem(i)) match {
|
||||
case Some(task) =>
|
||||
tasks.add(task)
|
||||
availableCpus(i) -= task.getParams.get("cpus").toInt
|
||||
availableMem(i) -= task.getParams.get("mem").toInt
|
||||
resourcesAvailable = resourcesAvailable || true
|
||||
case None => {}
|
||||
}
|
||||
} catch {
|
||||
case e: Exception => e.printStackTrace
|
||||
}
|
||||
} catch {
|
||||
case e: Exception => e.printStackTrace
|
||||
}
|
||||
}
|
||||
}
|
||||
val params = new java.util.HashMap[String, String]
|
||||
params.put("timeout", "1")
|
||||
d.replyToOffer(oid, tasks, params) // TODO: use smaller timeout
|
||||
|
@ -135,9 +144,15 @@ extends NScheduler with spark.Scheduler
|
|||
override def statusUpdate(d: SchedulerDriver, status: TaskStatus) {
|
||||
synchronized {
|
||||
try {
|
||||
if (activeOp != null) {
|
||||
activeOp.statusUpdate(status)
|
||||
taskIdToOpId.get(status.getTaskId) match {
|
||||
case Some(opId) =>
|
||||
if (activeOps.contains(opId)) {
|
||||
activeOps(opId).statusUpdate(status)
|
||||
}
|
||||
case None =>
|
||||
println("TID " + status.getTaskId + "already finished")
|
||||
}
|
||||
|
||||
} catch {
|
||||
case e: Exception => e.printStackTrace
|
||||
}
|
||||
|
@ -146,11 +161,13 @@ extends NScheduler with spark.Scheduler
|
|||
|
||||
override def error(d: SchedulerDriver, code: Int, message: String) {
|
||||
synchronized {
|
||||
if (activeOp != null) {
|
||||
try {
|
||||
activeOp.error(code, message)
|
||||
} catch {
|
||||
case e: Exception => e.printStackTrace
|
||||
if (activeOps.size > 0) {
|
||||
for ((opId, activeOp) <- activeOps) {
|
||||
try {
|
||||
activeOp.error(code, message)
|
||||
} catch {
|
||||
case e: Exception => e.printStackTrace
|
||||
}
|
||||
}
|
||||
} else {
|
||||
val msg = "Mesos error: %s (error code: %d)".format(message, code)
|
||||
|
@ -180,7 +197,7 @@ trait ParallelOperation {
|
|||
|
||||
|
||||
class SimpleParallelOperation[T: ClassManifest](
|
||||
sched: MesosScheduler, tasks: Array[Task[T]])
|
||||
sched: MesosScheduler, tasks: Array[Task[T]], val opId: Int)
|
||||
extends ParallelOperation
|
||||
{
|
||||
// Maximum time to wait to run a task in a preferred location (in ms)
|
||||
|
@ -235,10 +252,10 @@ extends ParallelOperation
|
|||
tasks(i).preferredLocations.isEmpty))
|
||||
{
|
||||
val taskId = sched.newTaskId()
|
||||
sched.taskIdToOpId(taskId) = opId
|
||||
tidToIndex(taskId) = i
|
||||
//printf("Starting task %d as TID %s on slave %s: %s (%s)\n",
|
||||
printf("Starting task %d as TID %s on slave %s: %s (%s)",
|
||||
i, taskId, offer.getSlaveId, offer.getHost,
|
||||
printf("Starting task %d as opId %d, TID %s on slave %s: %s (%s)",
|
||||
i, opId, taskId, offer.getSlaveId, offer.getHost,
|
||||
if(checkPref) "preferred" else "non-preferred")
|
||||
tasks(i).markStarted(offer)
|
||||
launched(i) = true
|
||||
|
@ -274,7 +291,7 @@ extends ParallelOperation
|
|||
|
||||
def taskFinished(status: TaskStatus) {
|
||||
val tid = status.getTaskId
|
||||
print("Finished TID " + tid)
|
||||
print("Finished opId " + opId + " TID " + tid)
|
||||
if (!finished(tidToIndex(tid))) {
|
||||
// Deserialize task result
|
||||
val result = Utils.deserialize[TaskResult[T]](status.getData)
|
||||
|
@ -283,6 +300,8 @@ extends ParallelOperation
|
|||
Accumulators.add(callingThread, result.accumUpdates)
|
||||
// Mark finished and stop if we've finished all the tasks
|
||||
finished(tidToIndex(tid)) = true
|
||||
// Remove TID -> opId mapping from sched
|
||||
sched.taskIdToOpId.remove(tid)
|
||||
tasksFinished += 1
|
||||
|
||||
println(", finished " + tasksFinished + "/" + numTasks)
|
||||
|
@ -295,7 +314,7 @@ extends ParallelOperation
|
|||
|
||||
def taskLost(status: TaskStatus) {
|
||||
val tid = status.getTaskId
|
||||
println("Lost TID " + tid)
|
||||
println("Lost opId " + opId + " TID " + tid)
|
||||
if (!finished(tidToIndex(tid))) {
|
||||
launched(tidToIndex(tid)) = false
|
||||
tasksLaunched -= 1
|
||||
|
|
|
@ -4,6 +4,17 @@ import java.io._
|
|||
import java.util.UUID
|
||||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
import scala.actors.Actor._
|
||||
|
||||
case class SparkAsyncLock(var finished: Boolean = false) {
|
||||
def join() {
|
||||
this.synchronized {
|
||||
while (!finished) {
|
||||
this.wait
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class SparkContext(master: String, frameworkName: String) {
|
||||
Broadcast.initialize(true)
|
||||
|
@ -21,6 +32,18 @@ class SparkContext(master: String, frameworkName: String) {
|
|||
def broadcast[T](value: T) = new CentralizedHDFSBroadcast(value, local)
|
||||
//def broadcast[T](value: T) = new ChainedStreamingBroadcast(value, local)
|
||||
|
||||
def fork(f: => Unit): SparkAsyncLock = {
|
||||
val thisLock = new SparkAsyncLock
|
||||
actor {
|
||||
f
|
||||
thisLock.synchronized {
|
||||
thisLock.finished = true
|
||||
thisLock.notifyAll()
|
||||
}
|
||||
}
|
||||
thisLock
|
||||
}
|
||||
|
||||
def textFile(path: String) = new HdfsTextFile(this, path)
|
||||
|
||||
val LOCAL_REGEX = """local\[([0-9]+)\]""".r
|
||||
|
|
Loading…
Reference in a new issue