1. Add unit test for local scheduler

2. Move LocalTaskSetManager to a new file
Andrew Xia 2013-05-30 20:49:40 +08:00
parent ecceb101d3
commit c3db3ea554
3 changed files with 385 additions and 200 deletions

LocalScheduler.scala

@@ -15,7 +15,7 @@ import spark.scheduler.cluster._
import akka.actor._
/**
* A simple TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
* A FIFO or Fair TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
* the scheduler also allows each task to fail up to maxFailures times, which is useful for
* testing fault recovery.
*/
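The comment above is the contract this file implements: tasks run in a local thread pool, and each task may be retried up to maxFailures times before the job is aborted. A minimal usage sketch, assuming the local[threads, maxFailures] master string available in Spark of this era (the object name and the numbers are illustrative only):

import spark.SparkContext

// Sketch only: "local[4, 2]" is assumed to request 4 worker threads and allow
// up to 2 failures per task before the job is aborted.
object LocalRetryExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local[4, 2]", "local-retry-example")
    val sum = sc.parallelize(1 to 100, 4).map(_ * 2).reduce(_ + _)
    println("sum = " + sum)
    sc.stop()
  }
}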
@@ -26,10 +26,8 @@ private[spark] case class LocalStatusUpdate(taskId: Long, state: TaskState, seri
private[spark] class LocalActor(localScheduler: LocalScheduler, var freeCores: Int) extends Actor with Logging {
def receive = {
case LocalReviveOffers =>
logInfo("LocalReviveOffers")
launchTask(localScheduler.resourceOffer(freeCores))
case LocalStatusUpdate(taskId, state, serializeData) =>
logInfo("LocalStatusUpdate")
freeCores += 1
localScheduler.statusUpdate(taskId, state, serializeData)
launchTask(localScheduler.resourceOffer(freeCores))
@@ -48,168 +46,6 @@ private[spark] class LocalActor(localScheduler: LocalScheduler, var freeCores: I
}
}
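For orientation, the round trip the actor-based design depends on is: submitTasks sends LocalReviveOffers, the actor asks the scheduler for as many task descriptions as it has free cores and launches them, and each LocalStatusUpdate returns a core and immediately re-offers. The following simplified sketch models only that core-accounting loop; it is hypothetical, uses no Akka, and all names are invented for illustration:

import scala.collection.mutable.Queue

// Hypothetical model of the LocalActor's accounting: launching a task takes a
// core, a completed task gives it back and triggers another resource offer.
object OfferLoopSketch {
  var freeCores = 2
  val pending = Queue("task-0", "task-1", "task-2")

  def resourceOffer(cores: Int): Seq[String] =
    (1 to math.min(cores, pending.size)).map(_ => pending.dequeue())

  def launchTasks(tasks: Seq[String]) {
    freeCores -= tasks.size
    tasks.foreach(t => println("launched " + t))
  }

  def statusUpdate(task: String) {
    println("finished " + task)
    freeCores += 1                         // the finished task frees its core
    launchTasks(resourceOffer(freeCores))  // re-offer, like LocalStatusUpdate
  }

  def main(args: Array[String]) {
    launchTasks(resourceOffer(freeCores))  // like LocalReviveOffers
    statusUpdate("task-0")
    statusUpdate("task-1")
    statusUpdate("task-2")
  }
}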
private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: TaskSet) extends TaskSetManager with Logging {
var parent: Schedulable = null
var weight: Int = 1
var minShare: Int = 0
var runningTasks: Int = 0
var priority: Int = taskSet.priority
var stageId: Int = taskSet.stageId
var name: String = "TaskSet_"+taskSet.stageId.toString
var failCount = new Array[Int](taskSet.tasks.size)
val taskInfos = new HashMap[Long, TaskInfo]
val numTasks = taskSet.tasks.size
var numFinished = 0
val ser = SparkEnv.get.closureSerializer.newInstance()
val copiesRunning = new Array[Int](numTasks)
val finished = new Array[Boolean](numTasks)
val numFailures = new Array[Int](numTasks)
def increaseRunningTasks(taskNum: Int): Unit = {
runningTasks += taskNum
if (parent != null) {
parent.increaseRunningTasks(taskNum)
}
}
def decreaseRunningTasks(taskNum: Int): Unit = {
runningTasks -= taskNum
if (parent != null) {
parent.decreaseRunningTasks(taskNum)
}
}
def addSchedulable(schedulable: Schedulable): Unit = {
//nothing
}
def removeSchedulable(schedulable: Schedulable): Unit = {
//nothing
}
def getSchedulableByName(name: String): Schedulable = {
return null
}
def executorLost(executorId: String, host: String): Unit = {
//nothing
}
def checkSpeculatableTasks(): Boolean = {
return true
}
def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]
sortedTaskSetQueue += this
return sortedTaskSetQueue
}
def hasPendingTasks(): Boolean = {
return true
}
def findTask(): Option[Int] = {
for (i <- 0 to numTasks-1) {
if (copiesRunning(i) == 0 && !finished(i)) {
return Some(i)
}
}
return None
}
def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = {
Thread.currentThread().setContextClassLoader(sched.classLoader)
SparkEnv.set(sched.env)
logInfo("availableCpus:%d,numFinished:%d,numTasks:%d".format(availableCpus.toInt, numFinished, numTasks))
if (availableCpus > 0 && numFinished < numTasks) {
findTask() match {
case Some(index) =>
logInfo(taskSet.tasks(index).toString)
val taskId = sched.attemptId.getAndIncrement()
val task = taskSet.tasks(index)
logInfo("taskId:%d,task:%s".format(index,task))
val info = new TaskInfo(taskId, index, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL)
taskInfos(taskId) = info
val bytes = Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser)
logInfo("Size of task " + taskId + " is " + bytes.limit + " bytes")
val taskName = "task %s:%d".format(taskSet.id, index)
copiesRunning(index) += 1
increaseRunningTasks(1)
return Some(new TaskDescription(taskId, null, taskName, bytes))
case None => {}
}
}
return None
}
def numPendingTasksForHostPort(hostPort: String): Int = {
return 0
}
def numRackLocalPendingTasksForHost(hostPort :String): Int = {
return 0
}
def numPendingTasksForHost(hostPort: String): Int = {
return 0
}
def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
state match {
case TaskState.FINISHED =>
taskEnded(tid, state, serializedData)
case TaskState.FAILED =>
taskFailed(tid, state, serializedData)
case _ => {}
}
}
def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) {
val info = taskInfos(tid)
val index = info.index
val task = taskSet.tasks(index)
info.markSuccessful()
val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader)
result.metrics.resultSize = serializedData.limit()
sched.listener.taskEnded(task, Success, result.value, result.accumUpdates, info, result.metrics)
numFinished += 1
decreaseRunningTasks(1)
finished(index) = true
if (numFinished == numTasks) {
sched.taskSetFinished(this)
}
}
def taskFailed(tid: Long, state: TaskState, serializedData: ByteBuffer) {
val info = taskInfos(tid)
val index = info.index
val task = taskSet.tasks(index)
info.markFailed()
decreaseRunningTasks(1)
val reason: ExceptionFailure = ser.deserialize[ExceptionFailure](serializedData, getClass.getClassLoader)
if (!finished(index)) {
copiesRunning(index) -= 1
numFailures(index) += 1
val locs = reason.stackTrace.map(loc => "\tat %s".format(loc.toString))
logInfo("Loss was due to %s\n%s\n%s".format(reason.className, reason.description, locs.mkString("\n")))
if (numFailures(index) > 4) {
val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format(taskSet.id, index, 4, reason.description)
//logError(errorMessage)
//sched.listener.taskEnded(task, reason, null, null, info, null)
sched.listener.taskSetFailed(taskSet, errorMessage)
sched.taskSetFinished(this)
decreaseRunningTasks(runningTasks)
}
}
}
def error(message: String) {
}
}
private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext)
extends TaskScheduler
with Logging {
@@ -233,7 +69,6 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
val taskSetTaskIds = new HashMap[String, HashSet[Long]]
var localActor: ActorRef = null
// TODO: Need to take into account stage priority in scheduling
override def start() {
//default scheduler is FIFO
@@ -250,7 +85,6 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
}
schedulableBuilder.buildPools()
//val properties = new ArrayBuffer[(String, String)]
localActor = env.actorSystem.actorOf(
Props(new LocalActor(this, threads)), "Test")
}
@@ -260,14 +94,17 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
}
override def submitTasks(taskSet: TaskSet) {
synchronized {
var manager = new LocalTaskSetManager(this, taskSet)
schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
activeTaskSets(taskSet.id) = manager
taskSetTaskIds(taskSet.id) = new HashSet[Long]()
localActor ! LocalReviveOffers
}
}
def resourceOffer(freeCores: Int): Seq[TaskDescription] = {
synchronized {
var freeCpuCores = freeCores
val tasks = new ArrayBuffer[TaskDescription](freeCores)
val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue()
@@ -279,7 +116,6 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
for (manager <- sortedTaskSetQueue) {
do {
launchTask = false
logInfo("freeCores is" + freeCpuCores)
manager.slaveOffer(null,null,freeCpuCores) match {
case Some(task) =>
tasks += task
@@ -293,18 +129,21 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
}
return tasks
}
}
def taskSetFinished(manager: TaskSetManager) {
synchronized {
activeTaskSets -= manager.taskSet.id
manager.parent.removeSchedulable(manager)
logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name))
taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
taskSetTaskIds -= manager.taskSet.id
}
}
def runTask(taskId: Long, bytes: ByteBuffer) {
logInfo("Running " + taskId)
val info = new TaskInfo(taskId, 0 , System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL)
val info = new TaskInfo(taskId, 0, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL)
// Set the Spark execution environment for the worker thread
SparkEnv.set(env)
val ser = SparkEnv.get.closureSerializer.newInstance()
@@ -376,12 +215,14 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
}
}
def statusUpdate(taskId :Long, state: TaskState, serializedData: ByteBuffer)
{
def statusUpdate(taskId :Long, state: TaskState, serializedData: ByteBuffer) {
synchronized {
val taskSetId = taskIdToTaskSetId(taskId)
val taskSetManager = activeTaskSets(taskSetId)
taskSetTaskIds(taskSetId) -= taskId
taskSetManager.statusUpdate(taskId, state, serializedData)
}
}
override def stop() {
threadPool.shutdownNow()

LocalTaskSetManager.scala (new file)

@@ -0,0 +1,173 @@
package spark.scheduler.local
import java.io.File
import java.util.concurrent.atomic.AtomicInteger
import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import spark._
import spark.TaskState.TaskState
import spark.scheduler._
import spark.scheduler.cluster._
private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: TaskSet) extends TaskSetManager with Logging {
var parent: Schedulable = null
var weight: Int = 1
var minShare: Int = 0
var runningTasks: Int = 0
var priority: Int = taskSet.priority
var stageId: Int = taskSet.stageId
var name: String = "TaskSet_"+taskSet.stageId.toString
var failCount = new Array[Int](taskSet.tasks.size)
val taskInfos = new HashMap[Long, TaskInfo]
val numTasks = taskSet.tasks.size
var numFinished = 0
val ser = SparkEnv.get.closureSerializer.newInstance()
val copiesRunning = new Array[Int](numTasks)
val finished = new Array[Boolean](numTasks)
val numFailures = new Array[Int](numTasks)
val MAX_TASK_FAILURES = sched.maxFailures
def increaseRunningTasks(taskNum: Int): Unit = {
runningTasks += taskNum
if (parent != null) {
parent.increaseRunningTasks(taskNum)
}
}
def decreaseRunningTasks(taskNum: Int): Unit = {
runningTasks -= taskNum
if (parent != null) {
parent.decreaseRunningTasks(taskNum)
}
}
def addSchedulable(schedulable: Schedulable): Unit = {
//nothing
}
def removeSchedulable(schedulable: Schedulable): Unit = {
//nothing
}
def getSchedulableByName(name: String): Schedulable = {
return null
}
def executorLost(executorId: String, host: String): Unit = {
//nothing
}
def checkSpeculatableTasks(): Boolean = {
return true
}
def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]
sortedTaskSetQueue += this
return sortedTaskSetQueue
}
def hasPendingTasks(): Boolean = {
return true
}
def findTask(): Option[Int] = {
for (i <- 0 to numTasks-1) {
if (copiesRunning(i) == 0 && !finished(i)) {
return Some(i)
}
}
return None
}
def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = {
SparkEnv.set(sched.env)
logDebug("availableCpus:%d,numFinished:%d,numTasks:%d".format(availableCpus.toInt, numFinished, numTasks))
if (availableCpus > 0 && numFinished < numTasks) {
findTask() match {
case Some(index) =>
logInfo(taskSet.tasks(index).toString)
val taskId = sched.attemptId.getAndIncrement()
val task = taskSet.tasks(index)
val info = new TaskInfo(taskId, index, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL)
taskInfos(taskId) = info
val bytes = Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser)
logInfo("Size of task " + taskId + " is " + bytes.limit + " bytes")
val taskName = "task %s:%d".format(taskSet.id, index)
copiesRunning(index) += 1
increaseRunningTasks(1)
return Some(new TaskDescription(taskId, null, taskName, bytes))
case None => {}
}
}
return None
}
def numPendingTasksForHostPort(hostPort: String): Int = {
return 0
}
def numRackLocalPendingTasksForHost(hostPort :String): Int = {
return 0
}
def numPendingTasksForHost(hostPort: String): Int = {
return 0
}
def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
state match {
case TaskState.FINISHED =>
taskEnded(tid, state, serializedData)
case TaskState.FAILED =>
taskFailed(tid, state, serializedData)
case _ => {}
}
}
def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) {
val info = taskInfos(tid)
val index = info.index
val task = taskSet.tasks(index)
info.markSuccessful()
val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader)
result.metrics.resultSize = serializedData.limit()
sched.listener.taskEnded(task, Success, result.value, result.accumUpdates, info, result.metrics)
numFinished += 1
decreaseRunningTasks(1)
finished(index) = true
if (numFinished == numTasks) {
sched.taskSetFinished(this)
}
}
def taskFailed(tid: Long, state: TaskState, serializedData: ByteBuffer) {
val info = taskInfos(tid)
val index = info.index
val task = taskSet.tasks(index)
info.markFailed()
decreaseRunningTasks(1)
val reason: ExceptionFailure = ser.deserialize[ExceptionFailure](serializedData, getClass.getClassLoader)
if (!finished(index)) {
copiesRunning(index) -= 1
numFailures(index) += 1
val locs = reason.stackTrace.map(loc => "\tat %s".format(loc.toString))
logInfo("Loss was due to %s\n%s\n%s".format(reason.className, reason.description, locs.mkString("\n")))
if (numFailures(index) > MAX_TASK_FAILURES) {
val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format(taskSet.id, index, MAX_TASK_FAILURES, reason.description)
decreaseRunningTasks(runningTasks)
sched.listener.taskSetFailed(taskSet, errorMessage)
// need to delete failed Taskset from schedule queue
sched.taskSetFinished(this)
}
}
}
def error(message: String) {
}
}

LocalSchedulerSuite.scala (new file)

@@ -0,0 +1,171 @@
package spark.scheduler
import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter
import spark._
import spark.scheduler._
import spark.scheduler.cluster._
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.{ConcurrentMap, HashMap}
import java.util.concurrent.Semaphore
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicInteger
import java.util.Properties
class Lock() {
var finished = false
def jobWait() = {
synchronized {
while(!finished) {
this.wait()
}
}
}
def jobFinished() = {
synchronized {
finished = true
this.notifyAll()
}
}
}
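The Lock above is a plain monitor: jobWait() blocks on the object until jobFinished() sets the flag and calls notifyAll(). A standalone sketch of that hand-off (the demo object is hypothetical, not part of the suite), showing how the tests keep a task "running" until they are done asserting:

// Hypothetical demo: a worker thread parks in jobWait() until the main thread
// calls jobFinished(), mirroring how the tests below hold tasks open while
// they assert the scheduler's state.
object LockDemo {
  def main(args: Array[String]) {
    val lock = new Lock()
    val worker = new Thread {
      override def run() {
        lock.jobWait()              // blocks until jobFinished() notifies
        println("task released")
      }
    }
    worker.start()
    Thread.sleep(500)               // stands in for the assertion phase
    lock.jobFinished()              // finished = true; notifyAll()
    worker.join()
  }
}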
object TaskThreadInfo {
val threadToLock = HashMap[Int, Lock]()
val threadToRunning = HashMap[Int, Boolean]()
}
class LocalSchedulerSuite extends FunSuite with LocalSparkContext {
def createThread(threadIndex: Int, poolName: String, sc: SparkContext, sem: Semaphore) {
TaskThreadInfo.threadToRunning(threadIndex) = false
val nums = sc.parallelize(threadIndex to threadIndex, 1)
TaskThreadInfo.threadToLock(threadIndex) = new Lock()
new Thread {
if (poolName != null) {
sc.addLocalProperties("spark.scheduler.cluster.fair.pool",poolName)
}
override def run() {
val ans = nums.map(number => {
TaskThreadInfo.threadToRunning(number) = true
TaskThreadInfo.threadToLock(number).jobWait()
number
}).collect()
assert(ans.toList === List(threadIndex))
sem.release()
TaskThreadInfo.threadToRunning(threadIndex) = false
}
}.start()
Thread.sleep(2000)
}
test("Local FIFO scheduler end-to-end test") {
System.setProperty("spark.cluster.schedulingmode", "FIFO")
sc = new SparkContext("local[4]", "test")
val sem = new Semaphore(0)
createThread(1,null,sc,sem)
createThread(2,null,sc,sem)
createThread(3,null,sc,sem)
createThread(4,null,sc,sem)
createThread(5,null,sc,sem)
createThread(6,null,sc,sem)
assert(TaskThreadInfo.threadToRunning(1) === true)
assert(TaskThreadInfo.threadToRunning(2) === true)
assert(TaskThreadInfo.threadToRunning(3) === true)
assert(TaskThreadInfo.threadToRunning(4) === true)
assert(TaskThreadInfo.threadToRunning(5) === false)
assert(TaskThreadInfo.threadToRunning(6) === false)
TaskThreadInfo.threadToLock(1).jobFinished()
Thread.sleep(1000)
assert(TaskThreadInfo.threadToRunning(1) === false)
assert(TaskThreadInfo.threadToRunning(2) === true)
assert(TaskThreadInfo.threadToRunning(3) === true)
assert(TaskThreadInfo.threadToRunning(4) === true)
assert(TaskThreadInfo.threadToRunning(5) === true)
assert(TaskThreadInfo.threadToRunning(6) === false)
TaskThreadInfo.threadToLock(3).jobFinished()
Thread.sleep(1000)
assert(TaskThreadInfo.threadToRunning(1) === false)
assert(TaskThreadInfo.threadToRunning(2) === true)
assert(TaskThreadInfo.threadToRunning(3) === false)
assert(TaskThreadInfo.threadToRunning(4) === true)
assert(TaskThreadInfo.threadToRunning(5) === true)
assert(TaskThreadInfo.threadToRunning(6) === true)
TaskThreadInfo.threadToLock(2).jobFinished()
TaskThreadInfo.threadToLock(4).jobFinished()
TaskThreadInfo.threadToLock(5).jobFinished()
TaskThreadInfo.threadToLock(6).jobFinished()
sem.acquire(6)
}
test("Local fair scheduler end-to-end test") {
sc = new SparkContext("local[8]", "LocalSchedulerSuite")
val sem = new Semaphore(0)
System.setProperty("spark.cluster.schedulingmode", "FAIR")
val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
System.setProperty("spark.fairscheduler.allocation.file", xmlPath)
createThread(10,"1",sc,sem)
createThread(20,"2",sc,sem)
createThread(30,"3",sc,sem)
assert(TaskThreadInfo.threadToRunning(10) === true)
assert(TaskThreadInfo.threadToRunning(20) === true)
assert(TaskThreadInfo.threadToRunning(30) === true)
createThread(11,"1",sc,sem)
createThread(21,"2",sc,sem)
createThread(31,"3",sc,sem)
assert(TaskThreadInfo.threadToRunning(11) === true)
assert(TaskThreadInfo.threadToRunning(21) === true)
assert(TaskThreadInfo.threadToRunning(31) === true)
createThread(12,"1",sc,sem)
createThread(22,"2",sc,sem)
createThread(32,"3",sc,sem)
assert(TaskThreadInfo.threadToRunning(12) === true)
assert(TaskThreadInfo.threadToRunning(22) === true)
assert(TaskThreadInfo.threadToRunning(32) === false)
TaskThreadInfo.threadToLock(10).jobFinished()
Thread.sleep(1000)
assert(TaskThreadInfo.threadToRunning(32) === true)
createThread(23,"2",sc,sem)
createThread(33,"3",sc,sem)
TaskThreadInfo.threadToLock(11).jobFinished()
Thread.sleep(1000)
assert(TaskThreadInfo.threadToRunning(23) === true)
assert(TaskThreadInfo.threadToRunning(33) === false)
TaskThreadInfo.threadToLock(12).jobFinished()
Thread.sleep(1000)
assert(TaskThreadInfo.threadToRunning(33) === true)
TaskThreadInfo.threadToLock(20).jobFinished()
TaskThreadInfo.threadToLock(21).jobFinished()
TaskThreadInfo.threadToLock(22).jobFinished()
TaskThreadInfo.threadToLock(23).jobFinished()
TaskThreadInfo.threadToLock(30).jobFinished()
TaskThreadInfo.threadToLock(31).jobFinished()
TaskThreadInfo.threadToLock(32).jobFinished()
TaskThreadInfo.threadToLock(33).jobFinished()
sem.acquire(11)
}
}
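The fair-scheduler test above depends on a fairscheduler.xml test resource defining pools named "1", "2" and "3"; the minShare and weight values in that resource are what make the concurrency assertions hold. A hedged sketch of such a fixture follows; the element names and the sample values are assumptions about the allocation-file format, not taken from this commit:

import java.io.{File, PrintWriter}

// Hypothetical fixture writer: emits a pools file and points the property the
// test reads (spark.fairscheduler.allocation.file) at it. Element names and
// minShare/weight values are illustrative assumptions only.
object WriteFairSchedulerXml {
  def main(args: Array[String]) {
    val xml =
      """<?xml version="1.0"?>
        |<allocations>
        |  <pool name="1"><schedulingMode>FIFO</schedulingMode><minShare>2</minShare><weight>1</weight></pool>
        |  <pool name="2"><schedulingMode>FIFO</schedulingMode><minShare>3</minShare><weight>1</weight></pool>
        |  <pool name="3"><schedulingMode>FIFO</schedulingMode><minShare>0</minShare><weight>1</weight></pool>
        |</allocations>""".stripMargin
    val file = new File("fairscheduler.xml")
    val out = new PrintWriter(file)
    try { out.write(xml) } finally { out.close() }
    System.setProperty("spark.fairscheduler.allocation.file", file.getAbsolutePath)
  }
}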