1. Add unit test for local scheduler
2. Move localTaskSetManager to a new file
This commit is contained in:
parent ecceb101d3
commit c3db3ea554
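For context, a minimal sketch of the behaviour the new suite exercises, using only APIs that appear in this diff (System.setProperty for the scheduling mode, a local[N] SparkContext, parallelize/collect). The object name and the tiny job are illustrative only and are not part of the commit:

// Minimal sketch: run a small job on the local scheduler in FIFO mode.
// Assumes the spark classes shown in this diff are on the classpath.
import spark.SparkContext

object LocalFifoSketch {
  def main(args: Array[String]): Unit = {
    // The suite below switches between "FIFO" and "FAIR" via this property.
    System.setProperty("spark.cluster.schedulingmode", "FIFO")
    val sc = new SparkContext("local[4]", "local-fifo-sketch")
    // With local[4] there are 4 cores, so at most 4 tasks run at once;
    // further task sets wait in the FIFO queue, which is what the suite asserts.
    val doubled = sc.parallelize(1 to 4, 4).map(_ * 2).collect()
    assert(doubled.sameElements(Array(2, 4, 6, 8)))
    sc.stop()
  }
}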
@@ -15,7 +15,7 @@ import spark.scheduler.cluster._
 import akka.actor._
 
 /**
- * A simple TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
+ * A FIFO or Fair TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
  * the scheduler also allows each task to fail up to maxFailures times, which is useful for
  * testing fault recovery.
  */
@@ -26,10 +26,8 @@ private[spark] case class LocalStatusUpdate(taskId: Long, state: TaskState, seri
 private[spark] class LocalActor(localScheduler: LocalScheduler, var freeCores: Int) extends Actor with Logging {
   def receive = {
     case LocalReviveOffers =>
-      logInfo("LocalReviveOffers")
       launchTask(localScheduler.resourceOffer(freeCores))
     case LocalStatusUpdate(taskId, state, serializeData) =>
-      logInfo("LocalStatusUpdate")
       freeCores += 1
       localScheduler.statusUpdate(taskId, state, serializeData)
       launchTask(localScheduler.resourceOffer(freeCores))
@@ -48,168 +46,6 @@ private[spark] class LocalActor(localScheduler: LocalScheduler, var freeCores: I
   }
 }
 
-private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: TaskSet) extends TaskSetManager with Logging {
-  var parent: Schedulable = null
-  var weight: Int = 1
-  var minShare: Int = 0
-  var runningTasks: Int = 0
-  var priority: Int = taskSet.priority
-  var stageId: Int = taskSet.stageId
-  var name: String = "TaskSet_"+taskSet.stageId.toString
-
-  var failCount = new Array[Int](taskSet.tasks.size)
-  val taskInfos = new HashMap[Long, TaskInfo]
-  val numTasks = taskSet.tasks.size
-  var numFinished = 0
-  val ser = SparkEnv.get.closureSerializer.newInstance()
-  val copiesRunning = new Array[Int](numTasks)
-  val finished = new Array[Boolean](numTasks)
-  val numFailures = new Array[Int](numTasks)
-
-  def increaseRunningTasks(taskNum: Int): Unit = {
-    runningTasks += taskNum
-    if (parent != null) {
-      parent.increaseRunningTasks(taskNum)
-    }
-  }
-
-  def decreaseRunningTasks(taskNum: Int): Unit = {
-    runningTasks -= taskNum
-    if (parent != null) {
-      parent.decreaseRunningTasks(taskNum)
-    }
-  }
-
-  def addSchedulable(schedulable: Schedulable): Unit = {
-    //nothing
-  }
-
-  def removeSchedulable(schedulable: Schedulable): Unit = {
-    //nothing
-  }
-
-  def getSchedulableByName(name: String): Schedulable = {
-    return null
-  }
-
-  def executorLost(executorId: String, host: String): Unit = {
-    //nothing
-  }
-
-  def checkSpeculatableTasks(): Boolean = {
-    return true
-  }
-
-  def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
-    var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]
-    sortedTaskSetQueue += this
-    return sortedTaskSetQueue
-  }
-
-  def hasPendingTasks(): Boolean = {
-    return true
-  }
-
-  def findTask(): Option[Int] = {
-    for (i <- 0 to numTasks-1) {
-      if (copiesRunning(i) == 0 && !finished(i)) {
-        return Some(i)
-      }
-    }
-    return None
-  }
-
-  def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = {
-    Thread.currentThread().setContextClassLoader(sched.classLoader)
-    SparkEnv.set(sched.env)
-    logInfo("availableCpus:%d,numFinished:%d,numTasks:%d".format(availableCpus.toInt, numFinished, numTasks))
-    if (availableCpus > 0 && numFinished < numTasks) {
-      findTask() match {
-        case Some(index) =>
-          logInfo(taskSet.tasks(index).toString)
-          val taskId = sched.attemptId.getAndIncrement()
-          val task = taskSet.tasks(index)
-          logInfo("taskId:%d,task:%s".format(index,task))
-          val info = new TaskInfo(taskId, index, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL)
-          taskInfos(taskId) = info
-          val bytes = Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser)
-          logInfo("Size of task " + taskId + " is " + bytes.limit + " bytes")
-          val taskName = "task %s:%d".format(taskSet.id, index)
-          copiesRunning(index) += 1
-          increaseRunningTasks(1)
-          return Some(new TaskDescription(taskId, null, taskName, bytes))
-        case None => {}
-      }
-    }
-    return None
-  }
-
-  def numPendingTasksForHostPort(hostPort: String): Int = {
-    return 0
-  }
-
-  def numRackLocalPendingTasksForHost(hostPort :String): Int = {
-    return 0
-  }
-
-  def numPendingTasksForHost(hostPort: String): Int = {
-    return 0
-  }
-
-  def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
-    state match {
-      case TaskState.FINISHED =>
-        taskEnded(tid, state, serializedData)
-      case TaskState.FAILED =>
-        taskFailed(tid, state, serializedData)
-      case _ => {}
-    }
-  }
-
-  def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) {
-    val info = taskInfos(tid)
-    val index = info.index
-    val task = taskSet.tasks(index)
-    info.markSuccessful()
-    val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader)
-    result.metrics.resultSize = serializedData.limit()
-    sched.listener.taskEnded(task, Success, result.value, result.accumUpdates, info, result.metrics)
-    numFinished += 1
-    decreaseRunningTasks(1)
-    finished(index) = true
-    if (numFinished == numTasks) {
-      sched.taskSetFinished(this)
-    }
-  }
-
-  def taskFailed(tid: Long, state: TaskState, serializedData: ByteBuffer) {
-    val info = taskInfos(tid)
-    val index = info.index
-    val task = taskSet.tasks(index)
-    info.markFailed()
-    decreaseRunningTasks(1)
-    val reason: ExceptionFailure = ser.deserialize[ExceptionFailure](serializedData, getClass.getClassLoader)
-    if (!finished(index)) {
-      copiesRunning(index) -= 1
-      numFailures(index) += 1
-      val locs = reason.stackTrace.map(loc => "\tat %s".format(loc.toString))
-      logInfo("Loss was due to %s\n%s\n%s".format(reason.className, reason.description, locs.mkString("\n")))
-      if (numFailures(index) > 4) {
-        val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format(taskSet.id, index, 4, reason.description)
-        //logError(errorMessage)
-        //sched.listener.taskEnded(task, reason, null, null, info, null)
-        sched.listener.taskSetFailed(taskSet, errorMessage)
-        sched.taskSetFinished(this)
-        decreaseRunningTasks(runningTasks)
-      }
-    }
-  }
-
-  def error(message: String) {
-  }
-}
-
 private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext)
   extends TaskScheduler
   with Logging {
@@ -233,7 +69,6 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
   val taskSetTaskIds = new HashMap[String, HashSet[Long]]
 
   var localActor: ActorRef = null
   // TODO: Need to take into account stage priority in scheduling
 
   override def start() {
     //default scheduler is FIFO
@@ -250,7 +85,6 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
     }
     schedulableBuilder.buildPools()
 
-    //val properties = new ArrayBuffer[(String, String)]
     localActor = env.actorSystem.actorOf(
       Props(new LocalActor(this, threads)), "Test")
   }
@@ -260,51 +94,56 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
   }
 
   override def submitTasks(taskSet: TaskSet) {
-    var manager = new LocalTaskSetManager(this, taskSet)
-    schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
-    activeTaskSets(taskSet.id) = manager
-    taskSetTaskIds(taskSet.id) = new HashSet[Long]()
-    localActor ! LocalReviveOffers
+    synchronized {
+      var manager = new LocalTaskSetManager(this, taskSet)
+      schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
+      activeTaskSets(taskSet.id) = manager
+      taskSetTaskIds(taskSet.id) = new HashSet[Long]()
+      localActor ! LocalReviveOffers
+    }
   }
 
   def resourceOffer(freeCores: Int): Seq[TaskDescription] = {
-    var freeCpuCores = freeCores
-    val tasks = new ArrayBuffer[TaskDescription](freeCores)
-    val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue()
-    for (manager <- sortedTaskSetQueue) {
-      logInfo("parentName:%s,name:%s,runningTasks:%s".format(manager.parent.name, manager.name, manager.runningTasks))
-    }
+    synchronized {
+      var freeCpuCores = freeCores
+      val tasks = new ArrayBuffer[TaskDescription](freeCores)
+      val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue()
+      for (manager <- sortedTaskSetQueue) {
+        logInfo("parentName:%s,name:%s,runningTasks:%s".format(manager.parent.name, manager.name, manager.runningTasks))
+      }
 
-      var launchTask = false
-      for (manager <- sortedTaskSetQueue) {
+      var launchTask = false
+      for (manager <- sortedTaskSetQueue) {
         do {
           launchTask = false
           logInfo("freeCores is" + freeCpuCores)
           manager.slaveOffer(null,null,freeCpuCores) match {
             case Some(task) =>
-            tasks += task
-            taskIdToTaskSetId(task.taskId) = manager.taskSet.id
-            taskSetTaskIds(manager.taskSet.id) += task.taskId
-            freeCpuCores -= 1
-            launchTask = true
+              tasks += task
+              taskIdToTaskSetId(task.taskId) = manager.taskSet.id
+              taskSetTaskIds(manager.taskSet.id) += task.taskId
+              freeCpuCores -= 1
+              launchTask = true
             case None => {}
           }
         } while(launchTask)
       }
-    return tasks
-  }
+      return tasks
+    }
   }
 
   def taskSetFinished(manager: TaskSetManager) {
-    activeTaskSets -= manager.taskSet.id
-    manager.parent.removeSchedulable(manager)
-    logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name))
-    taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
-    taskSetTaskIds -= manager.taskSet.id
+    synchronized {
+      activeTaskSets -= manager.taskSet.id
+      manager.parent.removeSchedulable(manager)
+      logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name))
+      taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
+      taskSetTaskIds -= manager.taskSet.id
+    }
   }
 
   def runTask(taskId: Long, bytes: ByteBuffer) {
     logInfo("Running " + taskId)
-    val info = new TaskInfo(taskId, 0 , System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL)
+    val info = new TaskInfo(taskId, 0, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL)
     // Set the Spark execution environment for the worker thread
     SparkEnv.set(env)
     val ser = SparkEnv.get.closureSerializer.newInstance()
@@ -344,8 +183,8 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
         case t: Throwable => {
           val failure = new ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace)
           localActor ! LocalStatusUpdate(taskId, TaskState.FAILED, ser.serialize(failure))
         }
       }
     }
   }
 
   /**
@@ -376,11 +215,13 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
     }
   }
 
-  def statusUpdate(taskId :Long, state: TaskState, serializedData: ByteBuffer)
-  {
-    val taskSetId = taskIdToTaskSetId(taskId)
-    val taskSetManager = activeTaskSets(taskSetId)
-    taskSetManager.statusUpdate(taskId, state, serializedData)
+  def statusUpdate(taskId :Long, state: TaskState, serializedData: ByteBuffer) {
+    synchronized {
+      val taskSetId = taskIdToTaskSetId(taskId)
+      val taskSetManager = activeTaskSets(taskSetId)
+      taskSetTaskIds(taskSetId) -= taskId
+      taskSetManager.statusUpdate(taskId, state, serializedData)
+    }
   }
 
   override def stop() {
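The synchronized blocks added in the hunks above guard the scheduler's shared mutable maps (activeTaskSets, taskIdToTaskSetId, taskSetTaskIds), which are presumably touched both by the thread that calls submitTasks and by the LocalActor driving resourceOffer and statusUpdate. A standalone sketch of that locking pattern in plain Scala, with illustrative names and no Spark types:

// Two threads mutate one mutable.HashMap, so every access goes through the owner's lock,
// mirroring the synchronized wrappers added to submitTasks/resourceOffer/statusUpdate above.
import scala.collection.mutable.HashMap

object SynchronizedMapSketch {
  private val taskSetTaskIds = new HashMap[String, Int]()   // stands in for the scheduler's shared state

  def add(id: String): Unit = synchronized { taskSetTaskIds(id) = taskSetTaskIds.getOrElse(id, 0) + 1 }
  def remove(id: String): Unit = synchronized { taskSetTaskIds -= id }

  def main(args: Array[String]): Unit = {
    val adder   = new Thread { override def run() { for (i <- 1 to 1000) add("ts_" + i) } }
    val remover = new Thread { override def run() { for (i <- 1 to 1000) remove("ts_" + i) } }
    adder.start(); remover.start()
    adder.join(); remover.join()
    // Without the synchronized wrappers, concurrent mutation could corrupt the map.
    println("entries left: " + taskSetTaskIds.size)
  }
}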
New file (LocalTaskSetManager, moved out of the scheduler file per the commit message):

@@ -0,0 +1,173 @@
package spark.scheduler.local

import java.io.File
import java.util.concurrent.atomic.AtomicInteger
import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet

import spark._
import spark.TaskState.TaskState
import spark.scheduler._
import spark.scheduler.cluster._

private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: TaskSet) extends TaskSetManager with Logging {
  var parent: Schedulable = null
  var weight: Int = 1
  var minShare: Int = 0
  var runningTasks: Int = 0
  var priority: Int = taskSet.priority
  var stageId: Int = taskSet.stageId
  var name: String = "TaskSet_"+taskSet.stageId.toString

  var failCount = new Array[Int](taskSet.tasks.size)
  val taskInfos = new HashMap[Long, TaskInfo]
  val numTasks = taskSet.tasks.size
  var numFinished = 0
  val ser = SparkEnv.get.closureSerializer.newInstance()
  val copiesRunning = new Array[Int](numTasks)
  val finished = new Array[Boolean](numTasks)
  val numFailures = new Array[Int](numTasks)
  val MAX_TASK_FAILURES = sched.maxFailures

  def increaseRunningTasks(taskNum: Int): Unit = {
    runningTasks += taskNum
    if (parent != null) {
      parent.increaseRunningTasks(taskNum)
    }
  }

  def decreaseRunningTasks(taskNum: Int): Unit = {
    runningTasks -= taskNum
    if (parent != null) {
      parent.decreaseRunningTasks(taskNum)
    }
  }

  def addSchedulable(schedulable: Schedulable): Unit = {
    //nothing
  }

  def removeSchedulable(schedulable: Schedulable): Unit = {
    //nothing
  }

  def getSchedulableByName(name: String): Schedulable = {
    return null
  }

  def executorLost(executorId: String, host: String): Unit = {
    //nothing
  }

  def checkSpeculatableTasks(): Boolean = {
    return true
  }

  def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
    var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]
    sortedTaskSetQueue += this
    return sortedTaskSetQueue
  }

  def hasPendingTasks(): Boolean = {
    return true
  }

  def findTask(): Option[Int] = {
    for (i <- 0 to numTasks-1) {
      if (copiesRunning(i) == 0 && !finished(i)) {
        return Some(i)
      }
    }
    return None
  }

  def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = {
    SparkEnv.set(sched.env)
    logDebug("availableCpus:%d,numFinished:%d,numTasks:%d".format(availableCpus.toInt, numFinished, numTasks))
    if (availableCpus > 0 && numFinished < numTasks) {
      findTask() match {
        case Some(index) =>
          logInfo(taskSet.tasks(index).toString)
          val taskId = sched.attemptId.getAndIncrement()
          val task = taskSet.tasks(index)
          val info = new TaskInfo(taskId, index, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL)
          taskInfos(taskId) = info
          val bytes = Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser)
          logInfo("Size of task " + taskId + " is " + bytes.limit + " bytes")
          val taskName = "task %s:%d".format(taskSet.id, index)
          copiesRunning(index) += 1
          increaseRunningTasks(1)
          return Some(new TaskDescription(taskId, null, taskName, bytes))
        case None => {}
      }
    }
    return None
  }

  def numPendingTasksForHostPort(hostPort: String): Int = {
    return 0
  }

  def numRackLocalPendingTasksForHost(hostPort :String): Int = {
    return 0
  }

  def numPendingTasksForHost(hostPort: String): Int = {
    return 0
  }

  def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
    state match {
      case TaskState.FINISHED =>
        taskEnded(tid, state, serializedData)
      case TaskState.FAILED =>
        taskFailed(tid, state, serializedData)
      case _ => {}
    }
  }

  def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) {
    val info = taskInfos(tid)
    val index = info.index
    val task = taskSet.tasks(index)
    info.markSuccessful()
    val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader)
    result.metrics.resultSize = serializedData.limit()
    sched.listener.taskEnded(task, Success, result.value, result.accumUpdates, info, result.metrics)
    numFinished += 1
    decreaseRunningTasks(1)
    finished(index) = true
    if (numFinished == numTasks) {
      sched.taskSetFinished(this)
    }
  }

  def taskFailed(tid: Long, state: TaskState, serializedData: ByteBuffer) {
    val info = taskInfos(tid)
    val index = info.index
    val task = taskSet.tasks(index)
    info.markFailed()
    decreaseRunningTasks(1)
    val reason: ExceptionFailure = ser.deserialize[ExceptionFailure](serializedData, getClass.getClassLoader)
    if (!finished(index)) {
      copiesRunning(index) -= 1
      numFailures(index) += 1
      val locs = reason.stackTrace.map(loc => "\tat %s".format(loc.toString))
      logInfo("Loss was due to %s\n%s\n%s".format(reason.className, reason.description, locs.mkString("\n")))
      if (numFailures(index) > MAX_TASK_FAILURES) {
        val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format(taskSet.id, index, 4, reason.description)
        decreaseRunningTasks(runningTasks)
        sched.listener.taskSetFailed(taskSet, errorMessage)
        // need to delete failed Taskset from schedule queue
        sched.taskSetFinished(this)
      }
    }
  }

  def error(message: String) {
  }
}
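One behavioral detail of the moved class: taskFailed keeps retrying a task index until its failure count exceeds MAX_TASK_FAILURES (sched.maxFailures), then aborts the whole task set. A condensed, self-contained sketch of that counting rule, with made-up values:

// Self-contained sketch of the abort rule in taskFailed above; the values are hypothetical.
object FailureCountSketch {
  def main(args: Array[String]): Unit = {
    val MAX_TASK_FAILURES = 4                 // mirrors sched.maxFailures
    val numFailures = new Array[Int](3)       // three pretend tasks

    // Returns true once the given task index has failed more than MAX_TASK_FAILURES times.
    def shouldAbort(index: Int): Boolean = {
      numFailures(index) += 1
      numFailures(index) > MAX_TASK_FAILURES
    }

    val outcomes = (1 to 5).map(_ => shouldAbort(0))
    assert(!outcomes.take(4).exists(identity))   // the first four failures are tolerated
    assert(outcomes.last)                        // the fifth failure aborts the task set
  }
}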
core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala (new file, 171 lines)

@@ -0,0 +1,171 @@
package spark.scheduler

import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter

import spark._
import spark.scheduler._
import spark.scheduler.cluster._
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.{ConcurrentMap, HashMap}
import java.util.concurrent.Semaphore
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicInteger

import java.util.Properties

class Lock() {
  var finished = false
  def jobWait() = {
    synchronized {
      while(!finished) {
        this.wait()
      }
    }
  }

  def jobFinished() = {
    synchronized {
      finished = true
      this.notifyAll()
    }
  }
}

object TaskThreadInfo {
  val threadToLock = HashMap[Int, Lock]()
  val threadToRunning = HashMap[Int, Boolean]()
}


class LocalSchedulerSuite extends FunSuite with LocalSparkContext {

  def createThread(threadIndex: Int, poolName: String, sc: SparkContext, sem: Semaphore) {

    TaskThreadInfo.threadToRunning(threadIndex) = false
    val nums = sc.parallelize(threadIndex to threadIndex, 1)
    TaskThreadInfo.threadToLock(threadIndex) = new Lock()
    new Thread {
      if (poolName != null) {
        sc.addLocalProperties("spark.scheduler.cluster.fair.pool",poolName)
      }
      override def run() {
        val ans = nums.map(number => {
          TaskThreadInfo.threadToRunning(number) = true
          TaskThreadInfo.threadToLock(number).jobWait()
          number
        }).collect()
        assert(ans.toList === List(threadIndex))
        sem.release()
        TaskThreadInfo.threadToRunning(threadIndex) = false
      }
    }.start()
    Thread.sleep(2000)
  }

  test("Local FIFO scheduler end-to-end test") {
    System.setProperty("spark.cluster.schedulingmode", "FIFO")
    sc = new SparkContext("local[4]", "test")
    val sem = new Semaphore(0)

    createThread(1,null,sc,sem)
    createThread(2,null,sc,sem)
    createThread(3,null,sc,sem)
    createThread(4,null,sc,sem)
    createThread(5,null,sc,sem)
    createThread(6,null,sc,sem)
    assert(TaskThreadInfo.threadToRunning(1) === true)
    assert(TaskThreadInfo.threadToRunning(2) === true)
    assert(TaskThreadInfo.threadToRunning(3) === true)
    assert(TaskThreadInfo.threadToRunning(4) === true)
    assert(TaskThreadInfo.threadToRunning(5) === false)
    assert(TaskThreadInfo.threadToRunning(6) === false)

    TaskThreadInfo.threadToLock(1).jobFinished()
    Thread.sleep(1000)

    assert(TaskThreadInfo.threadToRunning(1) === false)
    assert(TaskThreadInfo.threadToRunning(2) === true)
    assert(TaskThreadInfo.threadToRunning(3) === true)
    assert(TaskThreadInfo.threadToRunning(4) === true)
    assert(TaskThreadInfo.threadToRunning(5) === true)
    assert(TaskThreadInfo.threadToRunning(6) === false)

    TaskThreadInfo.threadToLock(3).jobFinished()
    Thread.sleep(1000)

    assert(TaskThreadInfo.threadToRunning(1) === false)
    assert(TaskThreadInfo.threadToRunning(2) === true)
    assert(TaskThreadInfo.threadToRunning(3) === false)
    assert(TaskThreadInfo.threadToRunning(4) === true)
    assert(TaskThreadInfo.threadToRunning(5) === true)
    assert(TaskThreadInfo.threadToRunning(6) === true)

    TaskThreadInfo.threadToLock(2).jobFinished()
    TaskThreadInfo.threadToLock(4).jobFinished()
    TaskThreadInfo.threadToLock(5).jobFinished()
    TaskThreadInfo.threadToLock(6).jobFinished()
    sem.acquire(6)
  }

  test("Local fair scheduler end-to-end test") {
    sc = new SparkContext("local[8]", "LocalSchedulerSuite")
    val sem = new Semaphore(0)
    System.setProperty("spark.cluster.schedulingmode", "FAIR")
    val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
    System.setProperty("spark.fairscheduler.allocation.file", xmlPath)

    createThread(10,"1",sc,sem)
    createThread(20,"2",sc,sem)
    createThread(30,"3",sc,sem)

    assert(TaskThreadInfo.threadToRunning(10) === true)
    assert(TaskThreadInfo.threadToRunning(20) === true)
    assert(TaskThreadInfo.threadToRunning(30) === true)

    createThread(11,"1",sc,sem)
    createThread(21,"2",sc,sem)
    createThread(31,"3",sc,sem)

    assert(TaskThreadInfo.threadToRunning(11) === true)
    assert(TaskThreadInfo.threadToRunning(21) === true)
    assert(TaskThreadInfo.threadToRunning(31) === true)

    createThread(12,"1",sc,sem)
    createThread(22,"2",sc,sem)
    createThread(32,"3",sc,sem)

    assert(TaskThreadInfo.threadToRunning(12) === true)
    assert(TaskThreadInfo.threadToRunning(22) === true)
    assert(TaskThreadInfo.threadToRunning(32) === false)

    TaskThreadInfo.threadToLock(10).jobFinished()
    Thread.sleep(1000)
    assert(TaskThreadInfo.threadToRunning(32) === true)

    createThread(23,"2",sc,sem)
    createThread(33,"3",sc,sem)

    TaskThreadInfo.threadToLock(11).jobFinished()
    Thread.sleep(1000)

    assert(TaskThreadInfo.threadToRunning(23) === true)
    assert(TaskThreadInfo.threadToRunning(33) === false)

    TaskThreadInfo.threadToLock(12).jobFinished()
    Thread.sleep(1000)

    assert(TaskThreadInfo.threadToRunning(33) === true)

    TaskThreadInfo.threadToLock(20).jobFinished()
    TaskThreadInfo.threadToLock(21).jobFinished()
    TaskThreadInfo.threadToLock(22).jobFinished()
    TaskThreadInfo.threadToLock(23).jobFinished()
    TaskThreadInfo.threadToLock(30).jobFinished()
    TaskThreadInfo.threadToLock(31).jobFinished()
    TaskThreadInfo.threadToLock(32).jobFinished()
    TaskThreadInfo.threadToLock(33).jobFinished()

    sem.acquire(11)
  }
}