Merge remote-tracking branch 'xiajunluan/xiajunluan'

Conflicts:
	core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
Matei Zaharia 2013-06-18 00:11:50 +02:00
commit ec193c7d89
7 changed files with 1291 additions and 817 deletions


@@ -177,7 +177,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
val tasks = taskSet.tasks
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized {
val manager = new TaskSetManager(this, taskSet)
val manager = new ClusterTaskSetManager(this, taskSet)
activeTaskSets(taskSet.id) = manager
schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
taskSetTaskIds(taskSet.id) = new HashSet[Long]()


@@ -0,0 +1,747 @@
package spark.scheduler.cluster
import java.util.{HashMap => JHashMap, NoSuchElementException, Arrays}
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import scala.math.max
import scala.math.min
import spark._
import spark.scheduler._
import spark.TaskState.TaskState
import java.nio.ByteBuffer
private[spark] object TaskLocality extends Enumeration("PROCESS_LOCAL", "NODE_LOCAL", "RACK_LOCAL", "ANY") with Logging {
// process local is expected to be used ONLY within tasksetmanager for now.
val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value
type TaskLocality = Value
def isAllowed(constraint: TaskLocality, condition: TaskLocality): Boolean = {
// PROCESS_LOCAL must not be used as the constraint.
assert (constraint != TaskLocality.PROCESS_LOCAL)
constraint match {
case TaskLocality.NODE_LOCAL => condition == TaskLocality.NODE_LOCAL
case TaskLocality.RACK_LOCAL => condition == TaskLocality.NODE_LOCAL || condition == TaskLocality.RACK_LOCAL
// For anything else, allow
case _ => true
}
}
def parse(str: String): TaskLocality = {
// better way to do this ?
try {
val retval = TaskLocality.withName(str)
// Must not specify PROCESS_LOCAL !
assert (retval != TaskLocality.PROCESS_LOCAL)
retval
} catch {
case nEx: NoSuchElementException => {
logWarning("Invalid task locality specified '" + str + "', defaulting to NODE_LOCAL");
// default to preserve earlier behavior
NODE_LOCAL
}
}
}
}
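(Editor's aside, not part of the commit: the constraint/condition semantics of isAllowed above are easiest to see with a few concrete calls. The sketch below is a minimal, assumed demo object placed in the same package so it can see the private[spark] enumeration; the name TaskLocalityDemo is hypothetical.)

package spark.scheduler.cluster

object TaskLocalityDemo {
  def main(args: Array[String]) {
    // A NODE_LOCAL constraint only admits NODE_LOCAL conditions.
    assert(TaskLocality.isAllowed(TaskLocality.NODE_LOCAL, TaskLocality.NODE_LOCAL))
    assert(!TaskLocality.isAllowed(TaskLocality.NODE_LOCAL, TaskLocality.RACK_LOCAL))
    // A RACK_LOCAL constraint admits NODE_LOCAL and RACK_LOCAL conditions.
    assert(TaskLocality.isAllowed(TaskLocality.RACK_LOCAL, TaskLocality.NODE_LOCAL))
    assert(TaskLocality.isAllowed(TaskLocality.RACK_LOCAL, TaskLocality.RACK_LOCAL))
    // ANY admits everything.
    assert(TaskLocality.isAllowed(TaskLocality.ANY, TaskLocality.PROCESS_LOCAL))
    // parse() accepts the enumeration names and falls back to NODE_LOCAL (with a warning) otherwise.
    assert(TaskLocality.parse("RACK_LOCAL") == TaskLocality.RACK_LOCAL)
    assert(TaskLocality.parse("bogus") == TaskLocality.NODE_LOCAL)
  }
}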
/**
* Schedules the tasks within a single TaskSet in the ClusterScheduler.
*/
private[spark] class ClusterTaskSetManager(
sched: ClusterScheduler,
val taskSet: TaskSet)
extends TaskSetManager
with Logging {
// Maximum time to wait to run a task in a preferred location (in ms)
val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
// CPUs to request per task
val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toDouble
// Maximum times a task is allowed to fail before failing the job
val MAX_TASK_FAILURES = 4
// Quantile of tasks at which to start speculation
val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble
val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble
// Serializer for closures and tasks.
val ser = SparkEnv.get.closureSerializer.newInstance()
val tasks = taskSet.tasks
val numTasks = tasks.length
val copiesRunning = new Array[Int](numTasks)
val finished = new Array[Boolean](numTasks)
val numFailures = new Array[Int](numTasks)
val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil)
var tasksFinished = 0
var weight = 1
var minShare = 0
var runningTasks = 0
var priority = taskSet.priority
var stageId = taskSet.stageId
var name = "TaskSet_"+taskSet.stageId.toString
var parent:Schedulable = null
// Last time when we launched a preferred task (for delay scheduling)
var lastPreferredLaunchTime = System.currentTimeMillis
// List of pending tasks for each host:port (process local to the executor). These collections
// are actually treated as stacks, in which new tasks are added to the end of the
// ArrayBuffer and removed from the end. This makes it faster to detect
// tasks that repeatedly fail because whenever a task fails, it is put
// back at the head of the stack. They are also only cleaned up lazily;
// when a task is launched, it remains in all the pending lists except
// the one that it was launched from, but gets removed from them later.
private val pendingTasksForHostPort = new HashMap[String, ArrayBuffer[Int]]
// List of pending tasks for each node.
// Essentially, similar to pendingTasksForHostPort, except at host level
private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]]
// List of pending tasks for each node based on rack locality.
// Essentially, similar to pendingTasksForHost, except at rack level
private val pendingRackLocalTasksForHost = new HashMap[String, ArrayBuffer[Int]]
// List containing pending tasks with no locality preferences
val pendingTasksWithNoPrefs = new ArrayBuffer[Int]
// List containing all pending tasks (also used as a stack, as above)
val allPendingTasks = new ArrayBuffer[Int]
// Tasks that can be speculated. Since these will be a small fraction of total
// tasks, we'll just hold them in a HashSet.
val speculatableTasks = new HashSet[Int]
// Task index, start and finish time for each task attempt (indexed by task ID)
val taskInfos = new HashMap[Long, TaskInfo]
// Did the job fail?
var failed = false
var causeOfFailure = ""
// How frequently to reprint duplicate exceptions in full, in milliseconds
val EXCEPTION_PRINT_INTERVAL =
System.getProperty("spark.logging.exceptionPrintInterval", "10000").toLong
// Map of recent exceptions (identified by string representation and
// top stack frame) to duplicate count (how many times the same
// exception has appeared) and time the full exception was
// printed. This should ideally be an LRU map that can drop old
// exceptions automatically.
val recentExceptions = HashMap[String, (Int, Long)]()
// Figure out the current map output tracker generation and set it on all tasks
val generation = sched.mapOutputTracker.getGeneration
logDebug("Generation for " + taskSet.id + ": " + generation)
for (t <- tasks) {
t.generation = generation
}
// Add all our tasks to the pending lists. We do this in reverse order
// of task index so that tasks with low indices get launched first.
for (i <- (0 until numTasks).reverse) {
addPendingTask(i)
}
// Note that it follows the hierarchy.
// if we search for NODE_LOCAL, the output will include PROCESS_LOCAL and
// if we search for RACK_LOCAL, it will include PROCESS_LOCAL & NODE_LOCAL
private def findPreferredLocations(_taskPreferredLocations: Seq[String], scheduler: ClusterScheduler,
taskLocality: TaskLocality.TaskLocality): HashSet[String] = {
if (TaskLocality.PROCESS_LOCAL == taskLocality) {
// Straightforward comparison! Special-case it.
val retval = new HashSet[String]()
scheduler.synchronized {
for (location <- _taskPreferredLocations) {
if (scheduler.isExecutorAliveOnHostPort(location)) {
retval += location
}
}
}
return retval
}
val taskPreferredLocations =
if (TaskLocality.NODE_LOCAL == taskLocality) {
_taskPreferredLocations
} else {
assert (TaskLocality.RACK_LOCAL == taskLocality)
// Expand set to include all 'seen' rack local hosts.
// This works since container allocation/management happens within the master - so any rack locality information is updated in the master.
// Best-effort, and maybe sort of a kludge for now ... rework it later?
val hosts = new HashSet[String]
_taskPreferredLocations.foreach(h => {
val rackOpt = scheduler.getRackForHost(h)
if (rackOpt.isDefined) {
val hostsOpt = scheduler.getCachedHostsForRack(rackOpt.get)
if (hostsOpt.isDefined) {
hosts ++= hostsOpt.get
}
}
// Ensure that irrespective of what scheduler says, host is always added !
hosts += h
})
hosts
}
val retval = new HashSet[String]
scheduler.synchronized {
for (prefLocation <- taskPreferredLocations) {
val aliveLocationsOpt = scheduler.getExecutorsAliveOnHost(Utils.parseHostPort(prefLocation)._1)
if (aliveLocationsOpt.isDefined) {
retval ++= aliveLocationsOpt.get
}
}
}
retval
}
// Add a task to all the pending-task lists that it should be on.
private def addPendingTask(index: Int) {
// We can infer hostLocalLocations from rackLocalLocations by joining it against tasks(index).preferredLocations (with appropriate
// hostPort <-> host conversion). But not doing it, for simplicity's sake. If this becomes a performance issue, modify it.
val processLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.PROCESS_LOCAL)
val hostLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL)
val rackLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL)
if (rackLocalLocations.size == 0) {
// Current impl ensures this.
assert (processLocalLocations.size == 0)
assert (hostLocalLocations.size == 0)
pendingTasksWithNoPrefs += index
} else {
// process local locality
for (hostPort <- processLocalLocations) {
// DEBUG Code
Utils.checkHostPort(hostPort)
val hostPortList = pendingTasksForHostPort.getOrElseUpdate(hostPort, ArrayBuffer())
hostPortList += index
}
// host locality (includes process local)
for (hostPort <- hostLocalLocations) {
// DEBUG Code
Utils.checkHostPort(hostPort)
val host = Utils.parseHostPort(hostPort)._1
val hostList = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer())
hostList += index
}
// rack locality (includes process local and host local)
for (rackLocalHostPort <- rackLocalLocations) {
// DEBUG Code
Utils.checkHostPort(rackLocalHostPort)
val rackLocalHost = Utils.parseHostPort(rackLocalHostPort)._1
val list = pendingRackLocalTasksForHost.getOrElseUpdate(rackLocalHost, ArrayBuffer())
list += index
}
}
allPendingTasks += index
}
// Return the pending tasks list for a given host port (process local), or an empty list if
// there is no map entry for that host
private def getPendingTasksForHostPort(hostPort: String): ArrayBuffer[Int] = {
// DEBUG Code
Utils.checkHostPort(hostPort)
pendingTasksForHostPort.getOrElse(hostPort, ArrayBuffer())
}
// Return the pending tasks list for a given host, or an empty list if
// there is no map entry for that host
private def getPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = {
val host = Utils.parseHostPort(hostPort)._1
pendingTasksForHost.getOrElse(host, ArrayBuffer())
}
// Return the pending tasks (rack level) list for a given host, or an empty list if
// there is no map entry for that host
private def getRackLocalPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = {
val host = Utils.parseHostPort(hostPort)._1
pendingRackLocalTasksForHost.getOrElse(host, ArrayBuffer())
}
// Number of pending tasks for a given host Port (which would be process local)
def numPendingTasksForHostPort(hostPort: String): Int = {
getPendingTasksForHostPort(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) )
}
// Number of pending tasks for a given host (which would be data local)
def numPendingTasksForHost(hostPort: String): Int = {
getPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) )
}
// Number of pending rack local tasks for a given host
def numRackLocalPendingTasksForHost(hostPort: String): Int = {
getRackLocalPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) )
}
// Dequeue a pending task from the given list and return its index.
// Return None if the list is empty.
// This method also cleans up any tasks in the list that have already
// been launched, since we want that to happen lazily.
private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
while (!list.isEmpty) {
val index = list.last
list.trimEnd(1)
if (copiesRunning(index) == 0 && !finished(index)) {
return Some(index)
}
}
return None
}
// Return a speculative task for a given host if any are available. The task should not have an
// attempt running on this host, in case the host is slow. In addition, if locality is set, the
// task must have a preference for this host/rack/no preferred locations at all.
private def findSpeculativeTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = {
assert (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL))
speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set
if (speculatableTasks.size > 0) {
val localTask = speculatableTasks.find {
index =>
val locations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL)
val attemptLocs = taskAttempts(index).map(_.hostPort)
(locations.size == 0 || locations.contains(hostPort)) && !attemptLocs.contains(hostPort)
}
if (localTask != None) {
speculatableTasks -= localTask.get
return localTask
}
// check for rack locality
if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
val rackTask = speculatableTasks.find {
index =>
val locations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL)
val attemptLocs = taskAttempts(index).map(_.hostPort)
locations.contains(hostPort) && !attemptLocs.contains(hostPort)
}
if (rackTask != None) {
speculatableTasks -= rackTask.get
return rackTask
}
}
// Any task ...
if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
// Check for attemptLocs also ?
val nonLocalTask = speculatableTasks.find(i => !taskAttempts(i).map(_.hostPort).contains(hostPort))
if (nonLocalTask != None) {
speculatableTasks -= nonLocalTask.get
return nonLocalTask
}
}
}
return None
}
// Dequeue a pending task for a given node and return its index.
// Only tasks allowed by the given locality constraint are considered; speculative tasks are tried last.
private def findTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = {
val processLocalTask = findTaskFromList(getPendingTasksForHostPort(hostPort))
if (processLocalTask != None) {
return processLocalTask
}
val localTask = findTaskFromList(getPendingTasksForHost(hostPort))
if (localTask != None) {
return localTask
}
if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
val rackLocalTask = findTaskFromList(getRackLocalPendingTasksForHost(hostPort))
if (rackLocalTask != None) {
return rackLocalTask
}
}
// Look for no-pref tasks AFTER rack local tasks - this has the side effect that we will get to failed tasks later rather than sooner.
// TODO: That code path needs to be revisited (adding to no prefs list when host:port goes down).
val noPrefTask = findTaskFromList(pendingTasksWithNoPrefs)
if (noPrefTask != None) {
return noPrefTask
}
if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
val nonLocalTask = findTaskFromList(allPendingTasks)
if (nonLocalTask != None) {
return nonLocalTask
}
}
// Finally, if all else has failed, find a speculative task
return findSpeculativeTask(hostPort, locality)
}
private def isProcessLocalLocation(task: Task[_], hostPort: String): Boolean = {
Utils.checkHostPort(hostPort)
val locs = task.preferredLocations
locs.contains(hostPort)
}
private def isHostLocalLocation(task: Task[_], hostPort: String): Boolean = {
val locs = task.preferredLocations
// If no preference, consider it as host local
if (locs.isEmpty) return true
val host = Utils.parseHostPort(hostPort)._1
locs.find(h => Utils.parseHostPort(h)._1 == host).isDefined
}
// Does a host count as a rack local preferred location for a task? (Assumes the host itself is NOT a preferred location.)
// This is true if the task has preferred locations and the rack of this host matches the rack of one of
// those locations; a task with no preferred locations is not counted as rack local.
private def isRackLocalLocation(task: Task[_], hostPort: String): Boolean = {
val locs = task.preferredLocations
val preferredRacks = new HashSet[String]()
for (preferredHost <- locs) {
val rack = sched.getRackForHost(preferredHost)
if (None != rack) preferredRacks += rack.get
}
if (preferredRacks.isEmpty) return false
val hostRack = sched.getRackForHost(hostPort)
return None != hostRack && preferredRacks.contains(hostRack.get)
}
// Respond to an offer of a single slave from the scheduler by finding a task
def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = {
if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) {
// If explicitly specified, use that
val locality = if (overrideLocality != null) overrideLocality else {
// expand only if we have waited for more than LOCALITY_WAIT for a host local task ...
val time = System.currentTimeMillis
if (time - lastPreferredLaunchTime < LOCALITY_WAIT) TaskLocality.NODE_LOCAL else TaskLocality.ANY
}
findTask(hostPort, locality) match {
case Some(index) => {
// Found a task; do some bookkeeping and return a Mesos task for it
val task = tasks(index)
val taskId = sched.newTaskId()
// Figure out whether this should count as a preferred launch
val taskLocality =
if (isProcessLocalLocation(task, hostPort)) TaskLocality.PROCESS_LOCAL else
if (isHostLocalLocation(task, hostPort)) TaskLocality.NODE_LOCAL else
if (isRackLocalLocation(task, hostPort)) TaskLocality.RACK_LOCAL else
TaskLocality.ANY
val prefStr = taskLocality.toString
logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format(
taskSet.id, index, taskId, execId, hostPort, prefStr))
// Do various bookkeeping
copiesRunning(index) += 1
val time = System.currentTimeMillis
val info = new TaskInfo(taskId, index, time, execId, hostPort, taskLocality)
taskInfos(taskId) = info
taskAttempts(index) = info :: taskAttempts(index)
if (TaskLocality.NODE_LOCAL == taskLocality) {
lastPreferredLaunchTime = time
}
// Serialize and return the task
val startTime = System.currentTimeMillis
val serializedTask = Task.serializeWithDependencies(
task, sched.sc.addedFiles, sched.sc.addedJars, ser)
val timeTaken = System.currentTimeMillis - startTime
increaseRunningTasks(1)
logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
taskSet.id, index, serializedTask.limit, timeTaken))
val taskName = "task %s:%d".format(taskSet.id, index)
return Some(new TaskDescription(taskId, execId, taskName, serializedTask))
}
case _ =>
}
}
return None
}
def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
state match {
case TaskState.FINISHED =>
taskFinished(tid, state, serializedData)
case TaskState.LOST =>
taskLost(tid, state, serializedData)
case TaskState.FAILED =>
taskLost(tid, state, serializedData)
case TaskState.KILLED =>
taskLost(tid, state, serializedData)
case _ =>
}
}
def taskFinished(tid: Long, state: TaskState, serializedData: ByteBuffer) {
val info = taskInfos(tid)
if (info.failed) {
// We might get two task-lost messages for the same task in coarse-grained Mesos mode,
// or even from Mesos itself when acks get delayed.
return
}
val index = info.index
info.markSuccessful()
decreaseRunningTasks(1)
if (!finished(index)) {
tasksFinished += 1
logInfo("Finished TID %s in %d ms (progress: %d/%d)".format(
tid, info.duration, tasksFinished, numTasks))
// Deserialize task result and pass it to the scheduler
try {
val result = ser.deserialize[TaskResult[_]](serializedData)
result.metrics.resultSize = serializedData.limit()
sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
} catch {
case cnf: ClassNotFoundException =>
val loader = Thread.currentThread().getContextClassLoader
throw new SparkException("ClassNotFound with classloader: " + loader, cnf)
case ex => throw ex
}
// Mark finished and stop if we've finished all the tasks
finished(index) = true
if (tasksFinished == numTasks) {
sched.taskSetFinished(this)
}
} else {
logInfo("Ignoring task-finished event for TID " + tid +
" because task " + index + " is already finished")
}
}
def taskLost(tid: Long, state: TaskState, serializedData: ByteBuffer) {
val info = taskInfos(tid)
if (info.failed) {
// We might get two task-lost messages for the same task in coarse-grained Mesos mode,
// or even from Mesos itself when acks get delayed.
return
}
val index = info.index
info.markFailed()
decreaseRunningTasks(1)
if (!finished(index)) {
logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
copiesRunning(index) -= 1
// Check if the problem is a map output fetch failure. In that case, this
// task will never succeed on any node, so tell the scheduler about it.
if (serializedData != null && serializedData.limit() > 0) {
val reason = ser.deserialize[TaskEndReason](serializedData, getClass.getClassLoader)
reason match {
case fetchFailed: FetchFailed =>
logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress)
sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null)
finished(index) = true
tasksFinished += 1
sched.taskSetFinished(this)
decreaseRunningTasks(runningTasks)
return
case taskResultTooBig: TaskResultTooBigFailure =>
logInfo("Loss was due to task %s result exceeding Akka frame size; " +
"aborting job".format(tid))
abort("Task %s result exceeded Akka frame size".format(tid))
return
case ef: ExceptionFailure =>
val key = ef.description
val now = System.currentTimeMillis
val (printFull, dupCount) = {
if (recentExceptions.contains(key)) {
val (dupCount, printTime) = recentExceptions(key)
if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
recentExceptions(key) = (0, now)
(true, 0)
} else {
recentExceptions(key) = (dupCount + 1, printTime)
(false, dupCount + 1)
}
} else {
recentExceptions(key) = (0, now)
(true, 0)
}
}
if (printFull) {
val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString))
logInfo("Loss was due to %s\n%s\n%s".format(
ef.className, ef.description, locs.mkString("\n")))
} else {
logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount))
}
case _ => {}
}
}
// On non-fetch failures, re-enqueue the task as pending for a max number of retries
addPendingTask(index)
// Count failed attempts only on FAILED and LOST state (not on KILLED)
if (state == TaskState.FAILED || state == TaskState.LOST) {
numFailures(index) += 1
if (numFailures(index) > MAX_TASK_FAILURES) {
logError("Task %s:%d failed more than %d times; aborting job".format(
taskSet.id, index, MAX_TASK_FAILURES))
abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES))
}
}
} else {
logInfo("Ignoring task-lost event for TID " + tid +
" because task " + index + " is already finished")
}
}
def error(message: String) {
// Save the error message
abort("Error: " + message)
}
def abort(message: String) {
failed = true
causeOfFailure = message
// TODO: Kill running tasks if we were not terminated due to a Mesos error
sched.listener.taskSetFailed(taskSet, message)
decreaseRunningTasks(runningTasks)
sched.taskSetFinished(this)
}
override def increaseRunningTasks(taskNum: Int) {
runningTasks += taskNum
if (parent != null) {
parent.increaseRunningTasks(taskNum)
}
}
override def decreaseRunningTasks(taskNum: Int) {
runningTasks -= taskNum
if (parent != null) {
parent.decreaseRunningTasks(taskNum)
}
}
//TODO: for now we just find Pool not TaskSetManager, we can extend this function in future if needed
override def getSchedulableByName(name: String): Schedulable = {
return null
}
override def addSchedulable(schedulable:Schedulable) {
//nothing
}
override def removeSchedulable(schedulable:Schedulable) {
//nothing
}
override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]
sortedTaskSetQueue += this
return sortedTaskSetQueue
}
override def executorLost(execId: String, hostPort: String) {
logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id)
// If some task has preferred locations only on hostname, and there are no more executors there,
// put it in the no-prefs list to avoid the wait from delay scheduling
// host local tasks - should we push this to rack local or no pref list ? For now, preserving behavior and moving to
// no prefs list. Note, this was done due to implications related to 'waiting' for data local tasks, etc.
// Note: NOT checking process local list - since host local list is a superset of that. We need to add to no prefs only if
// there is no host local node for the task (not if there is no process local node for the task)
for (index <- getPendingTasksForHost(Utils.parseHostPort(hostPort)._1)) {
// val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL)
val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL)
if (newLocs.isEmpty) {
pendingTasksWithNoPrefs += index
}
}
// Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage
if (tasks(0).isInstanceOf[ShuffleMapTask]) {
for ((tid, info) <- taskInfos if info.executorId == execId) {
val index = taskInfos(tid).index
if (finished(index)) {
finished(index) = false
copiesRunning(index) -= 1
tasksFinished -= 1
addPendingTask(index)
// Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
// stage finishes when a total of tasks.size tasks finish.
sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info, null)
}
}
}
// Also re-enqueue any tasks that were running on the node
for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
taskLost(tid, TaskState.KILLED, null)
}
}
/**
* Check for tasks to be speculated and return true if there are any. This is called periodically
* by the ClusterScheduler.
*
* TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that
* we don't scan the whole task set. It might also help to make this sorted by launch time.
*/
override def checkSpeculatableTasks(): Boolean = {
// Can't speculate if we only have one task, or if all tasks have finished.
if (numTasks == 1 || tasksFinished == numTasks) {
return false
}
var foundTasks = false
val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
if (tasksFinished >= minFinishedForSpeculation) {
val time = System.currentTimeMillis()
val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
Arrays.sort(durations)
val medianDuration = durations(min((0.5 * numTasks).round.toInt, durations.size - 1))
val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100)
// TODO: Threshold should also look at standard deviation of task durations and have a lower
// bound based on that.
logDebug("Task length threshold for speculation: " + threshold)
for ((tid, info) <- taskInfos) {
val index = info.index
if (!finished(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
!speculatableTasks.contains(index)) {
logInfo(
"Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format(
taskSet.id, index, info.hostPort, threshold))
speculatableTasks += index
foundTasks = true
}
}
}
return foundTasks
}
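(Editor's aside, not part of the commit: to make the speculation threshold above concrete, here is a small self-contained sketch that repeats the same arithmetic with hypothetical durations and the default quantile (0.75) and multiplier (1.5). Object and variable names are illustrative only.)

import java.util.Arrays
import scala.math.{max, min}

object SpeculationThresholdDemo {
  def main(args: Array[String]) {
    val quantile = 0.75    // default spark.speculation.quantile
    val multiplier = 1.5   // default spark.speculation.multiplier
    val numTasks = 8
    // Hypothetical durations (ms) of the six successfully finished tasks.
    val durations = Array[Long](900, 1100, 1000, 1200, 950, 1050)
    // Speculation is only considered once floor(0.75 * 8) = 6 tasks have finished.
    val minFinishedForSpeculation = (quantile * numTasks).floor.toInt
    Arrays.sort(durations)
    // Median index is min(round(0.5 * 8), 5) = 4, i.e. 1100 ms here.
    val medianDuration = durations(min((0.5 * numTasks).round.toInt, durations.length - 1))
    // Threshold = max(1.5 * 1100, 100) = 1650 ms: running copies older than this become speculatable.
    val threshold = max(multiplier * medianDuration, 100)
    println("minFinished=" + minFinishedForSpeculation + ", threshold=" + threshold + " ms")
  }
}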
override def hasPendingTasks(): Boolean = {
numTasks > 0 && tasksFinished < numTasks
}
}


@@ -1,747 +1,17 @@
package spark.scheduler.cluster
import java.util.{HashMap => JHashMap, NoSuchElementException, Arrays}
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import scala.math.max
import scala.math.min
import spark._
import spark.scheduler._
import spark.TaskState.TaskState
import java.nio.ByteBuffer
private[spark] object TaskLocality extends Enumeration("PROCESS_LOCAL", "NODE_LOCAL", "RACK_LOCAL", "ANY") with Logging {
// process local is expected to be used ONLY within tasksetmanager for now.
val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value
type TaskLocality = Value
def isAllowed(constraint: TaskLocality, condition: TaskLocality): Boolean = {
// PROCESS_LOCAL must not be used as the constraint.
assert (constraint != TaskLocality.PROCESS_LOCAL)
constraint match {
case TaskLocality.NODE_LOCAL => condition == TaskLocality.NODE_LOCAL
case TaskLocality.RACK_LOCAL => condition == TaskLocality.NODE_LOCAL || condition == TaskLocality.RACK_LOCAL
// For anything else, allow
case _ => true
}
}
def parse(str: String): TaskLocality = {
// better way to do this ?
try {
val retval = TaskLocality.withName(str)
// Must not specify PROCESS_LOCAL !
assert (retval != TaskLocality.PROCESS_LOCAL)
retval
} catch {
case nEx: NoSuchElementException => {
logWarning("Invalid task locality specified '" + str + "', defaulting to NODE_LOCAL");
// default to preserve earlier behavior
NODE_LOCAL
}
}
}
}
/**
* Schedules the tasks within a single TaskSet in the ClusterScheduler.
*/
private[spark] class TaskSetManager(
sched: ClusterScheduler,
val taskSet: TaskSet)
extends Schedulable
with Logging {
// Maximum time to wait to run a task in a preferred location (in ms)
val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
// CPUs to request per task
val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toDouble
// Maximum times a task is allowed to fail before failing the job
val MAX_TASK_FAILURES = 4
// Quantile of tasks at which to start speculation
val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble
val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble
// Serializer for closures and tasks.
val ser = SparkEnv.get.closureSerializer.newInstance()
val tasks = taskSet.tasks
val numTasks = tasks.length
val copiesRunning = new Array[Int](numTasks)
val finished = new Array[Boolean](numTasks)
val numFailures = new Array[Int](numTasks)
val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil)
var tasksFinished = 0
var weight = 1
var minShare = 0
var runningTasks = 0
var priority = taskSet.priority
var stageId = taskSet.stageId
var name = "TaskSet_"+taskSet.stageId.toString
var parent:Schedulable = null
// Last time when we launched a preferred task (for delay scheduling)
var lastPreferredLaunchTime = System.currentTimeMillis
// List of pending tasks for each host:port (process local to the executor). These collections
// are actually treated as stacks, in which new tasks are added to the end of the
// ArrayBuffer and removed from the end. This makes it faster to detect
// tasks that repeatedly fail because whenever a task fails, it is put
// back at the head of the stack. They are also only cleaned up lazily;
// when a task is launched, it remains in all the pending lists except
// the one that it was launched from, but gets removed from them later.
private val pendingTasksForHostPort = new HashMap[String, ArrayBuffer[Int]]
// List of pending tasks for each node.
// Essentially, similar to pendingTasksForHostPort, except at host level
private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]]
// List of pending tasks for each node based on rack locality.
// Essentially, similar to pendingTasksForHost, except at rack level
private val pendingRackLocalTasksForHost = new HashMap[String, ArrayBuffer[Int]]
// List containing pending tasks with no locality preferences
val pendingTasksWithNoPrefs = new ArrayBuffer[Int]
// List containing all pending tasks (also used as a stack, as above)
val allPendingTasks = new ArrayBuffer[Int]
// Tasks that can be speculated. Since these will be a small fraction of total
// tasks, we'll just hold them in a HashSet.
val speculatableTasks = new HashSet[Int]
// Task index, start and finish time for each task attempt (indexed by task ID)
val taskInfos = new HashMap[Long, TaskInfo]
// Did the job fail?
var failed = false
var causeOfFailure = ""
// How frequently to reprint duplicate exceptions in full, in milliseconds
val EXCEPTION_PRINT_INTERVAL =
System.getProperty("spark.logging.exceptionPrintInterval", "10000").toLong
// Map of recent exceptions (identified by string representation and
// top stack frame) to duplicate count (how many times the same
// exception has appeared) and time the full exception was
// printed. This should ideally be an LRU map that can drop old
// exceptions automatically.
val recentExceptions = HashMap[String, (Int, Long)]()
// Figure out the current map output tracker generation and set it on all tasks
val generation = sched.mapOutputTracker.getGeneration
logDebug("Generation for " + taskSet.id + ": " + generation)
for (t <- tasks) {
t.generation = generation
}
// Add all our tasks to the pending lists. We do this in reverse order
// of task index so that tasks with low indices get launched first.
for (i <- (0 until numTasks).reverse) {
addPendingTask(i)
}
// Note that it follows the hierarchy.
// if we search for NODE_LOCAL, the output will include PROCESS_LOCAL and
// if we search for RACK_LOCAL, it will include PROCESS_LOCAL & NODE_LOCAL
private def findPreferredLocations(_taskPreferredLocations: Seq[String], scheduler: ClusterScheduler,
taskLocality: TaskLocality.TaskLocality): HashSet[String] = {
if (TaskLocality.PROCESS_LOCAL == taskLocality) {
// Straightforward comparison! Special-case it.
val retval = new HashSet[String]()
scheduler.synchronized {
for (location <- _taskPreferredLocations) {
if (scheduler.isExecutorAliveOnHostPort(location)) {
retval += location
}
}
}
return retval
}
val taskPreferredLocations =
if (TaskLocality.NODE_LOCAL == taskLocality) {
_taskPreferredLocations
} else {
assert (TaskLocality.RACK_LOCAL == taskLocality)
// Expand set to include all 'seen' rack local hosts.
// This works since container allocation/management happens within the master - so any rack locality information is updated in the master.
// Best-effort, and maybe sort of a kludge for now ... rework it later?
val hosts = new HashSet[String]
_taskPreferredLocations.foreach(h => {
val rackOpt = scheduler.getRackForHost(h)
if (rackOpt.isDefined) {
val hostsOpt = scheduler.getCachedHostsForRack(rackOpt.get)
if (hostsOpt.isDefined) {
hosts ++= hostsOpt.get
}
}
// Ensure that irrespective of what scheduler says, host is always added !
hosts += h
})
hosts
}
val retval = new HashSet[String]
scheduler.synchronized {
for (prefLocation <- taskPreferredLocations) {
val aliveLocationsOpt = scheduler.getExecutorsAliveOnHost(Utils.parseHostPort(prefLocation)._1)
if (aliveLocationsOpt.isDefined) {
retval ++= aliveLocationsOpt.get
}
}
}
retval
}
// Add a task to all the pending-task lists that it should be on.
private def addPendingTask(index: Int) {
// We can infer hostLocalLocations from rackLocalLocations by joining it against tasks(index).preferredLocations (with appropriate
// hostPort <-> host conversion). But not doing it, for simplicity's sake. If this becomes a performance issue, modify it.
val processLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.PROCESS_LOCAL)
val hostLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL)
val rackLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL)
if (rackLocalLocations.size == 0) {
// Current impl ensures this.
assert (processLocalLocations.size == 0)
assert (hostLocalLocations.size == 0)
pendingTasksWithNoPrefs += index
} else {
// process local locality
for (hostPort <- processLocalLocations) {
// DEBUG Code
Utils.checkHostPort(hostPort)
val hostPortList = pendingTasksForHostPort.getOrElseUpdate(hostPort, ArrayBuffer())
hostPortList += index
}
// host locality (includes process local)
for (hostPort <- hostLocalLocations) {
// DEBUG Code
Utils.checkHostPort(hostPort)
val host = Utils.parseHostPort(hostPort)._1
val hostList = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer())
hostList += index
}
// rack locality (includes process local and host local)
for (rackLocalHostPort <- rackLocalLocations) {
// DEBUG Code
Utils.checkHostPort(rackLocalHostPort)
val rackLocalHost = Utils.parseHostPort(rackLocalHostPort)._1
val list = pendingRackLocalTasksForHost.getOrElseUpdate(rackLocalHost, ArrayBuffer())
list += index
}
}
allPendingTasks += index
}
// Return the pending tasks list for a given host port (process local), or an empty list if
// there is no map entry for that host
private def getPendingTasksForHostPort(hostPort: String): ArrayBuffer[Int] = {
// DEBUG Code
Utils.checkHostPort(hostPort)
pendingTasksForHostPort.getOrElse(hostPort, ArrayBuffer())
}
// Return the pending tasks list for a given host, or an empty list if
// there is no map entry for that host
private def getPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = {
val host = Utils.parseHostPort(hostPort)._1
pendingTasksForHost.getOrElse(host, ArrayBuffer())
}
// Return the pending tasks (rack level) list for a given host, or an empty list if
// there is no map entry for that host
private def getRackLocalPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = {
val host = Utils.parseHostPort(hostPort)._1
pendingRackLocalTasksForHost.getOrElse(host, ArrayBuffer())
}
// Number of pending tasks for a given host Port (which would be process local)
def numPendingTasksForHostPort(hostPort: String): Int = {
getPendingTasksForHostPort(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) )
}
// Number of pending tasks for a given host (which would be data local)
def numPendingTasksForHost(hostPort: String): Int = {
getPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) )
}
// Number of pending rack local tasks for a given host
def numRackLocalPendingTasksForHost(hostPort: String): Int = {
getRackLocalPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) )
}
// Dequeue a pending task from the given list and return its index.
// Return None if the list is empty.
// This method also cleans up any tasks in the list that have already
// been launched, since we want that to happen lazily.
private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
while (!list.isEmpty) {
val index = list.last
list.trimEnd(1)
if (copiesRunning(index) == 0 && !finished(index)) {
return Some(index)
}
}
return None
}
// Return a speculative task for a given host if any are available. The task should not have an
// attempt running on this host, in case the host is slow. In addition, if locality is set, the
// task must have a preference for this host/rack/no preferred locations at all.
private def findSpeculativeTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = {
assert (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL))
speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set
if (speculatableTasks.size > 0) {
val localTask = speculatableTasks.find {
index =>
val locations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL)
val attemptLocs = taskAttempts(index).map(_.hostPort)
(locations.size == 0 || locations.contains(hostPort)) && !attemptLocs.contains(hostPort)
}
if (localTask != None) {
speculatableTasks -= localTask.get
return localTask
}
// check for rack locality
if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
val rackTask = speculatableTasks.find {
index =>
val locations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL)
val attemptLocs = taskAttempts(index).map(_.hostPort)
locations.contains(hostPort) && !attemptLocs.contains(hostPort)
}
if (rackTask != None) {
speculatableTasks -= rackTask.get
return rackTask
}
}
// Any task ...
if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
// Check for attemptLocs also ?
val nonLocalTask = speculatableTasks.find(i => !taskAttempts(i).map(_.hostPort).contains(hostPort))
if (nonLocalTask != None) {
speculatableTasks -= nonLocalTask.get
return nonLocalTask
}
}
}
return None
}
// Dequeue a pending task for a given node and return its index.
// Only tasks allowed by the given locality constraint are considered; speculative tasks are tried last.
private def findTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = {
val processLocalTask = findTaskFromList(getPendingTasksForHostPort(hostPort))
if (processLocalTask != None) {
return processLocalTask
}
val localTask = findTaskFromList(getPendingTasksForHost(hostPort))
if (localTask != None) {
return localTask
}
if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
val rackLocalTask = findTaskFromList(getRackLocalPendingTasksForHost(hostPort))
if (rackLocalTask != None) {
return rackLocalTask
}
}
// Look for no-pref tasks AFTER rack local tasks - this has the side effect that we will get to failed tasks later rather than sooner.
// TODO: That code path needs to be revisited (adding to no prefs list when host:port goes down).
val noPrefTask = findTaskFromList(pendingTasksWithNoPrefs)
if (noPrefTask != None) {
return noPrefTask
}
if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
val nonLocalTask = findTaskFromList(allPendingTasks)
if (nonLocalTask != None) {
return nonLocalTask
}
}
// Finally, if all else has failed, find a speculative task
return findSpeculativeTask(hostPort, locality)
}
private def isProcessLocalLocation(task: Task[_], hostPort: String): Boolean = {
Utils.checkHostPort(hostPort)
val locs = task.preferredLocations
locs.contains(hostPort)
}
private def isHostLocalLocation(task: Task[_], hostPort: String): Boolean = {
val locs = task.preferredLocations
// If no preference, consider it as host local
if (locs.isEmpty) return true
val host = Utils.parseHostPort(hostPort)._1
locs.find(h => Utils.parseHostPort(h)._1 == host).isDefined
}
// Does a host count as a rack local preferred location for a task? (Assumes the host itself is NOT a preferred location.)
// This is true if the task has preferred locations and the rack of this host matches the rack of one of
// those locations; a task with no preferred locations is not counted as rack local.
private def isRackLocalLocation(task: Task[_], hostPort: String): Boolean = {
val locs = task.preferredLocations
val preferredRacks = new HashSet[String]()
for (preferredHost <- locs) {
val rack = sched.getRackForHost(preferredHost)
if (None != rack) preferredRacks += rack.get
}
if (preferredRacks.isEmpty) return false
val hostRack = sched.getRackForHost(hostPort)
return None != hostRack && preferredRacks.contains(hostRack.get)
}
// Respond to an offer of a single slave from the scheduler by finding a task
def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = {
if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) {
// If explicitly specified, use that
val locality = if (overrideLocality != null) overrideLocality else {
// expand only if we have waited for more than LOCALITY_WAIT for a host local task ...
val time = System.currentTimeMillis
if (time - lastPreferredLaunchTime < LOCALITY_WAIT) TaskLocality.NODE_LOCAL else TaskLocality.ANY
}
findTask(hostPort, locality) match {
case Some(index) => {
// Found a task; do some bookkeeping and return a Mesos task for it
val task = tasks(index)
val taskId = sched.newTaskId()
// Figure out whether this should count as a preferred launch
val taskLocality =
if (isProcessLocalLocation(task, hostPort)) TaskLocality.PROCESS_LOCAL else
if (isHostLocalLocation(task, hostPort)) TaskLocality.NODE_LOCAL else
if (isRackLocalLocation(task, hostPort)) TaskLocality.RACK_LOCAL else
TaskLocality.ANY
val prefStr = taskLocality.toString
logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format(
taskSet.id, index, taskId, execId, hostPort, prefStr))
// Do various bookkeeping
copiesRunning(index) += 1
val time = System.currentTimeMillis
val info = new TaskInfo(taskId, index, time, execId, hostPort, taskLocality)
taskInfos(taskId) = info
taskAttempts(index) = info :: taskAttempts(index)
if (TaskLocality.NODE_LOCAL == taskLocality) {
lastPreferredLaunchTime = time
}
// Serialize and return the task
val startTime = System.currentTimeMillis
val serializedTask = Task.serializeWithDependencies(
task, sched.sc.addedFiles, sched.sc.addedJars, ser)
val timeTaken = System.currentTimeMillis - startTime
increaseRunningTasks(1)
logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
taskSet.id, index, serializedTask.limit, timeTaken))
val taskName = "task %s:%d".format(taskSet.id, index)
return Some(new TaskDescription(taskId, execId, taskName, serializedTask))
}
case _ =>
}
}
return None
}
def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
state match {
case TaskState.FINISHED =>
taskFinished(tid, state, serializedData)
case TaskState.LOST =>
taskLost(tid, state, serializedData)
case TaskState.FAILED =>
taskLost(tid, state, serializedData)
case TaskState.KILLED =>
taskLost(tid, state, serializedData)
case _ =>
}
}
def taskFinished(tid: Long, state: TaskState, serializedData: ByteBuffer) {
val info = taskInfos(tid)
if (info.failed) {
// We might get two task-lost messages for the same task in coarse-grained Mesos mode,
// or even from Mesos itself when acks get delayed.
return
}
val index = info.index
info.markSuccessful()
decreaseRunningTasks(1)
if (!finished(index)) {
tasksFinished += 1
logInfo("Finished TID %s in %d ms (progress: %d/%d)".format(
tid, info.duration, tasksFinished, numTasks))
// Deserialize task result and pass it to the scheduler
try {
val result = ser.deserialize[TaskResult[_]](serializedData)
result.metrics.resultSize = serializedData.limit()
sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
} catch {
case cnf: ClassNotFoundException =>
val loader = Thread.currentThread().getContextClassLoader
throw new SparkException("ClassNotFound with classloader: " + loader, cnf)
case ex => throw ex
}
// Mark finished and stop if we've finished all the tasks
finished(index) = true
if (tasksFinished == numTasks) {
sched.taskSetFinished(this)
}
} else {
logInfo("Ignoring task-finished event for TID " + tid +
" because task " + index + " is already finished")
}
}
def taskLost(tid: Long, state: TaskState, serializedData: ByteBuffer) {
val info = taskInfos(tid)
if (info.failed) {
// We might get two task-lost messages for the same task in coarse-grained Mesos mode,
// or even from Mesos itself when acks get delayed.
return
}
val index = info.index
info.markFailed()
decreaseRunningTasks(1)
if (!finished(index)) {
logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
copiesRunning(index) -= 1
// Check if the problem is a map output fetch failure. In that case, this
// task will never succeed on any node, so tell the scheduler about it.
if (serializedData != null && serializedData.limit() > 0) {
val reason = ser.deserialize[TaskEndReason](serializedData, getClass.getClassLoader)
reason match {
case fetchFailed: FetchFailed =>
logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress)
sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null)
finished(index) = true
tasksFinished += 1
sched.taskSetFinished(this)
decreaseRunningTasks(runningTasks)
return
case taskResultTooBig: TaskResultTooBigFailure =>
logInfo("Loss was due to task %s result exceeding Akka frame size;" +
"aborting job".format(tid))
abort("Task %s result exceeded Akka frame size".format(tid))
return
case ef: ExceptionFailure =>
val key = ef.description
val now = System.currentTimeMillis
val (printFull, dupCount) = {
if (recentExceptions.contains(key)) {
val (dupCount, printTime) = recentExceptions(key)
if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
recentExceptions(key) = (0, now)
(true, 0)
} else {
recentExceptions(key) = (dupCount + 1, printTime)
(false, dupCount + 1)
}
} else {
recentExceptions(key) = (0, now)
(true, 0)
}
}
if (printFull) {
val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString))
logInfo("Loss was due to %s\n%s\n%s".format(
ef.className, ef.description, locs.mkString("\n")))
} else {
logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount))
}
case _ => {}
}
}
// On non-fetch failures, re-enqueue the task as pending for a max number of retries
addPendingTask(index)
// Count failed attempts only on FAILED and LOST state (not on KILLED)
if (state == TaskState.FAILED || state == TaskState.LOST) {
numFailures(index) += 1
if (numFailures(index) > MAX_TASK_FAILURES) {
logError("Task %s:%d failed more than %d times; aborting job".format(
taskSet.id, index, MAX_TASK_FAILURES))
abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES))
}
}
} else {
logInfo("Ignoring task-lost event for TID " + tid +
" because task " + index + " is already finished")
}
}
def error(message: String) {
// Save the error message
abort("Error: " + message)
}
def abort(message: String) {
failed = true
causeOfFailure = message
// TODO: Kill running tasks if we were not terminated due to a Mesos error
sched.listener.taskSetFailed(taskSet, message)
decreaseRunningTasks(runningTasks)
sched.taskSetFinished(this)
}
override def increaseRunningTasks(taskNum: Int) {
runningTasks += taskNum
if (parent != null) {
parent.increaseRunningTasks(taskNum)
}
}
override def decreaseRunningTasks(taskNum: Int) {
runningTasks -= taskNum
if (parent != null) {
parent.decreaseRunningTasks(taskNum)
}
}
//TODO: for now we just find Pool not TaskSetManager, we can extend this function in future if needed
override def getSchedulableByName(name: String): Schedulable = {
return null
}
override def addSchedulable(schedulable:Schedulable) {
//nothing
}
override def removeSchedulable(schedulable:Schedulable) {
//nothing
}
override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]
sortedTaskSetQueue += this
return sortedTaskSetQueue
}
override def executorLost(execId: String, hostPort: String) {
logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id)
// If some task has preferred locations only on hostname, and there are no more executors there,
// put it in the no-prefs list to avoid the wait from delay scheduling
// host local tasks - should we push this to rack local or no pref list ? For now, preserving behavior and moving to
// no prefs list. Note, this was done due to implications related to 'waiting' for data local tasks, etc.
// Note: NOT checking process local list - since host local list is a superset of that. We need to add to no prefs only if
// there is no host local node for the task (not if there is no process local node for the task)
for (index <- getPendingTasksForHost(Utils.parseHostPort(hostPort)._1)) {
// val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL)
val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL)
if (newLocs.isEmpty) {
pendingTasksWithNoPrefs += index
}
}
// Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage
if (tasks(0).isInstanceOf[ShuffleMapTask]) {
for ((tid, info) <- taskInfos if info.executorId == execId) {
val index = taskInfos(tid).index
if (finished(index)) {
finished(index) = false
copiesRunning(index) -= 1
tasksFinished -= 1
addPendingTask(index)
// Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
// stage finishes when a total of tasks.size tasks finish.
sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info, null)
}
}
}
// Also re-enqueue any tasks that were running on the node
for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
taskLost(tid, TaskState.KILLED, null)
}
}
/**
* Check for tasks to be speculated and return true if there are any. This is called periodically
* by the ClusterScheduler.
*
* TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that
* we don't scan the whole task set. It might also help to make this sorted by launch time.
*/
override def checkSpeculatableTasks(): Boolean = {
// Can't speculate if we only have one task, or if all tasks have finished.
if (numTasks == 1 || tasksFinished == numTasks) {
return false
}
var foundTasks = false
val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
if (tasksFinished >= minFinishedForSpeculation) {
val time = System.currentTimeMillis()
val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
Arrays.sort(durations)
val medianDuration = durations(min((0.5 * numTasks).round.toInt, durations.size - 1))
val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100)
// TODO: Threshold should also look at standard deviation of task durations and have a lower
// bound based on that.
logDebug("Task length threshold for speculation: " + threshold)
for ((tid, info) <- taskInfos) {
val index = info.index
if (!finished(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
!speculatableTasks.contains(index)) {
logInfo(
"Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format(
taskSet.id, index, info.hostPort, threshold))
speculatableTasks += index
foundTasks = true
}
}
}
return foundTasks
}
override def hasPendingTasks(): Boolean = {
numTasks > 0 && tasksFinished < numTasks
}
private[spark] trait TaskSetManager extends Schedulable {
def taskSet: TaskSet
def slaveOffer(execId: String, hostPort: String, availableCpus: Double,
overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription]
def numPendingTasksForHostPort(hostPort: String): Int
def numRackLocalPendingTasksForHost(hostPort :String): Int
def numPendingTasksForHost(hostPort: String): Int
def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer)
def error(message: String)
}


@@ -2,19 +2,50 @@ package spark.scheduler.local
import java.io.File
import java.util.concurrent.atomic.AtomicInteger
import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import spark._
import spark.TaskState.TaskState
import spark.executor.ExecutorURLClassLoader
import spark.scheduler._
import spark.scheduler.cluster.{TaskLocality, TaskInfo}
import spark.scheduler.cluster._
import akka.actor._
/**
* A simple TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
* A FIFO or Fair TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
* the scheduler also allows each task to fail up to maxFailures times, which is useful for
* testing fault recovery.
*/
private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext)
private[spark] case class LocalReviveOffers()
private[spark] case class LocalStatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer)
private[spark] class LocalActor(localScheduler: LocalScheduler, var freeCores: Int) extends Actor with Logging {
def receive = {
case LocalReviveOffers =>
launchTask(localScheduler.resourceOffer(freeCores))
case LocalStatusUpdate(taskId, state, serializeData) =>
freeCores += 1
localScheduler.statusUpdate(taskId, state, serializeData)
launchTask(localScheduler.resourceOffer(freeCores))
}
def launchTask(tasks : Seq[TaskDescription]) {
for (task <- tasks) {
freeCores -= 1
localScheduler.threadPool.submit(new Runnable {
def run() {
localScheduler.runTask(task.taskId,task.serializedTask)
}
})
}
}
}
private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext)
extends TaskScheduler
with Logging {
@@ -30,89 +61,127 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext)
val classLoader = new ExecutorURLClassLoader(Array(), Thread.currentThread.getContextClassLoader)
// TODO: Need to take into account stage priority in scheduling
var schedulableBuilder: SchedulableBuilder = null
var rootPool: Pool = null
val activeTaskSets = new HashMap[String, TaskSetManager]
val taskIdToTaskSetId = new HashMap[Long, String]
val taskSetTaskIds = new HashMap[String, HashSet[Long]]
override def start() { }
var localActor: ActorRef = null
override def start() {
//default scheduler is FIFO
val schedulingMode = System.getProperty("spark.cluster.schedulingmode", "FIFO")
//temporarily set rootPool name to empty
rootPool = new Pool("", SchedulingMode.withName(schedulingMode), 0, 0)
schedulableBuilder = {
schedulingMode match {
case "FIFO" =>
new FIFOSchedulableBuilder(rootPool)
case "FAIR" =>
new FairSchedulableBuilder(rootPool)
}
}
schedulableBuilder.buildPools()
localActor = env.actorSystem.actorOf(Props(new LocalActor(this, threads)), "Test")
}
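(Editor's aside, not part of the commit: the scheduling mode above is read from a plain system property, so a minimal, assumed way to exercise the new fair-scheduling path in local mode is to set that property before the SparkContext, and therefore this LocalScheduler, is created. The property name comes from the code above; the two-argument SparkContext constructor and the small job below are illustrative.)

import spark.SparkContext

object LocalFairSchedulingExample {
  def main(args: Array[String]) {
    // Picked up in LocalScheduler.start(); defaults to FIFO when unset.
    System.setProperty("spark.cluster.schedulingmode", "FAIR")
    val sc = new SparkContext("local[4]", "LocalFairSchedulingExample")
    // Any job submitted from here on is scheduled through the FAIR root pool.
    val sum = sc.parallelize(1 to 1000).reduce(_ + _)
    println("sum = " + sum)
    sc.stop()
  }
}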
override def setListener(listener: TaskSchedulerListener) {
this.listener = listener
}
override def submitTasks(taskSet: TaskSet) {
val tasks = taskSet.tasks
val failCount = new Array[Int](tasks.size)
def submitTask(task: Task[_], idInJob: Int) {
val myAttemptId = attemptId.getAndIncrement()
threadPool.submit(new Runnable {
def run() {
runTask(task, idInJob, myAttemptId)
}
})
synchronized {
var manager = new LocalTaskSetManager(this, taskSet)
schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
activeTaskSets(taskSet.id) = manager
taskSetTaskIds(taskSet.id) = new HashSet[Long]()
localActor ! LocalReviveOffers
}
}
def runTask(task: Task[_], idInJob: Int, attemptId: Int) {
logInfo("Running " + task)
val info = new TaskInfo(attemptId, idInJob, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL)
// Set the Spark execution environment for the worker thread
SparkEnv.set(env)
try {
Accumulators.clear()
Thread.currentThread().setContextClassLoader(classLoader)
// Serialize and deserialize the task so that accumulators are changed to thread-local ones;
// this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
val ser = SparkEnv.get.closureSerializer.newInstance()
val bytes = Task.serializeWithDependencies(task, sc.addedFiles, sc.addedJars, ser)
logInfo("Size of task " + idInJob + " is " + bytes.limit + " bytes")
val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes)
updateDependencies(taskFiles, taskJars) // Download any files added with addFile
val deserStart = System.currentTimeMillis()
val deserializedTask = ser.deserialize[Task[_]](
taskBytes, Thread.currentThread.getContextClassLoader)
val deserTime = System.currentTimeMillis() - deserStart
// Run it
val result: Any = deserializedTask.run(attemptId)
// Serialize and deserialize the result to emulate what the Mesos
// executor does. This is useful to catch serialization errors early
// on in development (so when users move their local Spark programs
// to the cluster, they don't get surprised by serialization errors).
val serResult = ser.serialize(result)
deserializedTask.metrics.get.resultSize = serResult.limit()
val resultToReturn = ser.deserialize[Any](serResult)
val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
ser.serialize(Accumulators.values))
logInfo("Finished " + task)
info.markSuccessful()
deserializedTask.metrics.get.executorRunTime = info.duration.toInt //close enough
deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt
// If the threadpool has not already been shutdown, notify DAGScheduler
if (!Thread.currentThread().isInterrupted)
listener.taskEnded(task, Success, resultToReturn, accumUpdates, info, deserializedTask.metrics.getOrElse(null))
} catch {
case t: Throwable => {
logError("Exception in task " + idInJob, t)
failCount.synchronized {
failCount(idInJob) += 1
if (failCount(idInJob) <= maxFailures) {
submitTask(task, idInJob)
} else {
// TODO: Do something nicer here to return all the way to the user
if (!Thread.currentThread().isInterrupted) {
val failure = new ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace)
listener.taskEnded(task, failure, null, null, info, null)
}
}
}
}
def resourceOffer(freeCores: Int): Seq[TaskDescription] = {
synchronized {
var freeCpuCores = freeCores
val tasks = new ArrayBuffer[TaskDescription](freeCores)
val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue()
for (manager <- sortedTaskSetQueue) {
logDebug("parentName:%s,name:%s,runningTasks:%s".format(manager.parent.name, manager.name, manager.runningTasks))
}
}
for ((task, i) <- tasks.zipWithIndex) {
submitTask(task, i)
var launchTask = false
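// Offer cores to each TaskSetManager in scheduling order; keep offering to the same manager
// until it declines (returns None) or no free cores remain.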
for (manager <- sortedTaskSetQueue) {
do {
launchTask = false
manager.slaveOffer(null,null,freeCpuCores) match {
case Some(task) =>
tasks += task
taskIdToTaskSetId(task.taskId) = manager.taskSet.id
taskSetTaskIds(manager.taskSet.id) += task.taskId
freeCpuCores -= 1
launchTask = true
case None => {}
}
} while(launchTask)
}
return tasks
}
}
def taskSetFinished(manager: TaskSetManager) {
synchronized {
activeTaskSets -= manager.taskSet.id
manager.parent.removeSchedulable(manager)
logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name))
taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
taskSetTaskIds -= manager.taskSet.id
}
}
def runTask(taskId: Long, bytes: ByteBuffer) {
logInfo("Running " + taskId)
val info = new TaskInfo(taskId, 0, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL)
// Set the Spark execution environment for the worker thread
SparkEnv.set(env)
val ser = SparkEnv.get.closureSerializer.newInstance()
try {
Accumulators.clear()
Thread.currentThread().setContextClassLoader(classLoader)
// Serialize and deserialize the task so that accumulators are changed to thread-local ones;
// this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes)
updateDependencies(taskFiles, taskJars) // Download any files added with addFile
val deserStart = System.currentTimeMillis()
val deserializedTask = ser.deserialize[Task[_]](
taskBytes, Thread.currentThread.getContextClassLoader)
val deserTime = System.currentTimeMillis() - deserStart
// Run it
val result: Any = deserializedTask.run(taskId)
// Serialize and deserialize the result to emulate what the Mesos
// executor does. This is useful to catch serialization errors early
// on in development (so when users move their local Spark programs
// to the cluster, they don't get surprised by serialization errors).
val serResult = ser.serialize(result)
deserializedTask.metrics.get.resultSize = serResult.limit()
val resultToReturn = ser.deserialize[Any](serResult)
val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
ser.serialize(Accumulators.values))
logInfo("Finished " + taskId)
deserializedTask.metrics.get.executorRunTime = deserTime.toInt//info.duration.toInt //close enough
deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt
val taskResult = new TaskResult(result, accumUpdates, deserializedTask.metrics.getOrElse(null))
val serializedResult = ser.serialize(taskResult)
localActor ! LocalStatusUpdate(taskId, TaskState.FINISHED, serializedResult)
} catch {
case t: Throwable => {
val failure = new ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace)
localActor ! LocalStatusUpdate(taskId, TaskState.FAILED, ser.serialize(failure))
}
}
}
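// A hedged illustration (not part of this change) of what the result round trip above catches:
// returning objects that are not serializable would silently work in a purely in-process mode,
// but fails here with a NotSerializableException, just as it would on a real cluster. The
// Connection class below is a hypothetical example.
//
//   class Connection                                          // not java.io.Serializable
//   sc.parallelize(1 to 4).map(_ => new Connection).collect()
//   // fails when the task result is serialized above, instead of surprising the user later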
@ -128,6 +197,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
currentFiles(name) = timestamp
}
for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
@ -143,7 +213,16 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
}
}
override def stop() {
def statusUpdate(taskId :Long, state: TaskState, serializedData: ByteBuffer) {
synchronized {
val taskSetId = taskIdToTaskSetId(taskId)
val taskSetManager = activeTaskSets(taskSetId)
taskSetTaskIds(taskSetId) -= taskId
taskSetManager.statusUpdate(taskId, state, serializedData)
}
}
override def stop() {
threadPool.shutdownNow()
}

View file

@ -0,0 +1,172 @@
package spark.scheduler.local
import java.io.File
import java.util.concurrent.atomic.AtomicInteger
import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import spark._
import spark.TaskState.TaskState
import spark.scheduler._
import spark.scheduler.cluster._
private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: TaskSet) extends TaskSetManager with Logging {
var parent: Schedulable = null
var weight: Int = 1
var minShare: Int = 0
var runningTasks: Int = 0
var priority: Int = taskSet.priority
var stageId: Int = taskSet.stageId
var name: String = "TaskSet_"+taskSet.stageId.toString
var failCount = new Array[Int](taskSet.tasks.size)
val taskInfos = new HashMap[Long, TaskInfo]
val numTasks = taskSet.tasks.size
var numFinished = 0
val ser = SparkEnv.get.closureSerializer.newInstance()
val copiesRunning = new Array[Int](numTasks)
val finished = new Array[Boolean](numTasks)
val numFailures = new Array[Int](numTasks)
val MAX_TASK_FAILURES = sched.maxFailures
def increaseRunningTasks(taskNum: Int): Unit = {
runningTasks += taskNum
if (parent != null) {
parent.increaseRunningTasks(taskNum)
}
}
def decreaseRunningTasks(taskNum: Int): Unit = {
runningTasks -= taskNum
if (parent != null) {
parent.decreaseRunningTasks(taskNum)
}
}
def addSchedulable(schedulable: Schedulable): Unit = {
//nothing
}
def removeSchedulable(schedulable: Schedulable): Unit = {
//nothing
}
def getSchedulableByName(name: String): Schedulable = {
return null
}
def executorLost(executorId: String, host: String): Unit = {
//nothing
}
def checkSpeculatableTasks(): Boolean = {
return true
}
def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]
sortedTaskSetQueue += this
return sortedTaskSetQueue
}
def hasPendingTasks(): Boolean = {
return true
}
def findTask(): Option[Int] = {
for (i <- 0 to numTasks-1) {
if (copiesRunning(i) == 0 && !finished(i)) {
return Some(i)
}
}
return None
}
def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = {
SparkEnv.set(sched.env)
logDebug("availableCpus:%d,numFinished:%d,numTasks:%d".format(availableCpus.toInt, numFinished, numTasks))
if (availableCpus > 0 && numFinished < numTasks) {
findTask() match {
case Some(index) =>
val taskId = sched.attemptId.getAndIncrement()
val task = taskSet.tasks(index)
val info = new TaskInfo(taskId, index, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL)
taskInfos(taskId) = info
val bytes = Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser)
logInfo("Size of task " + taskId + " is " + bytes.limit + " bytes")
val taskName = "task %s:%d".format(taskSet.id, index)
copiesRunning(index) += 1
increaseRunningTasks(1)
return Some(new TaskDescription(taskId, null, taskName, bytes))
case None => {}
}
}
return None
}
def numPendingTasksForHostPort(hostPort: String): Int = {
return 0
}
def numRackLocalPendingTasksForHost(hostPort :String): Int = {
return 0
}
def numPendingTasksForHost(hostPort: String): Int = {
return 0
}
def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
state match {
case TaskState.FINISHED =>
taskEnded(tid, state, serializedData)
case TaskState.FAILED =>
taskFailed(tid, state, serializedData)
case _ => {}
}
}
def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) {
val info = taskInfos(tid)
val index = info.index
val task = taskSet.tasks(index)
info.markSuccessful()
val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader)
result.metrics.resultSize = serializedData.limit()
sched.listener.taskEnded(task, Success, result.value, result.accumUpdates, info, result.metrics)
numFinished += 1
decreaseRunningTasks(1)
finished(index) = true
if (numFinished == numTasks) {
sched.taskSetFinished(this)
}
}
def taskFailed(tid: Long, state: TaskState, serializedData: ByteBuffer) {
val info = taskInfos(tid)
val index = info.index
val task = taskSet.tasks(index)
info.markFailed()
decreaseRunningTasks(1)
val reason: ExceptionFailure = ser.deserialize[ExceptionFailure](serializedData, getClass.getClassLoader)
if (!finished(index)) {
copiesRunning(index) -= 1
numFailures(index) += 1
val locs = reason.stackTrace.map(loc => "\tat %s".format(loc.toString))
logInfo("Loss was due to %s\n%s\n%s".format(reason.className, reason.description, locs.mkString("\n")))
if (numFailures(index) > MAX_TASK_FAILURES) {
val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format(taskSet.id, index, 4, reason.description)
decreaseRunningTasks(runningTasks)
sched.listener.taskSetFailed(taskSet, errorMessage)
// Need to remove the failed TaskSet from the scheduling queue
sched.taskSetFinished(this)
}
}
}
def error(message: String) {
}
}
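Taken together, a job run against the local backend flows through LocalScheduler.submitTasks -> LocalReviveOffers -> LocalTaskSetManager.slaveOffer -> LocalScheduler.runTask -> LocalStatusUpdate -> LocalTaskSetManager.statusUpdate/taskEnded. A minimal sketch that exercises the whole path (assuming only the standard "local[N]" master string; nothing below is specific to this change):
import spark.SparkContext
val sc = new SparkContext("local[2]", "local-scheduler-example")
val sum = sc.parallelize(1 to 100, 4).map(_ * 2).reduce(_ + _)   // 4 tasks, scheduled through the path above
assert(sum == 10100)
sc.stop()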

View file

@ -16,7 +16,7 @@ class DummyTaskSetManager(
initNumTasks: Int,
clusterScheduler: ClusterScheduler,
taskSet: TaskSet)
extends TaskSetManager(clusterScheduler,taskSet) {
extends ClusterTaskSetManager(clusterScheduler,taskSet) {
parent = null
weight = 1

View file

@ -0,0 +1,206 @@
package spark.scheduler
import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter
import spark._
import spark.scheduler._
import spark.scheduler.cluster._
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.{ConcurrentMap, HashMap}
import java.util.concurrent.Semaphore
import java.util.concurrent.CountDownLatch
import java.util.Properties
class Lock() {
var finished = false
def jobWait() = {
synchronized {
while(!finished) {
this.wait()
}
}
}
def jobFinished() = {
synchronized {
finished = true
this.notifyAll()
}
}
}
object TaskThreadInfo {
val threadToLock = HashMap[Int, Lock]()
val threadToRunning = HashMap[Int, Boolean]()
val threadToStarted = HashMap[Int, CountDownLatch]()
}
/*
* 1. each thread contains one job.
* 2. each job contains one stage.
* 3. each stage only contains one task.
* 4. each launched task must be started in order (using threadToStarted) to make sure
* it gets a cpu core; it then waits until the user manually releases its "Lock",
* at which point the cluster has a free cpu core again.
* 5. each pending task must use "sleep" to make sure it has been added to the taskSetManager
* queue, so that it is scheduled later once the cluster has free cpu cores.
*/
class LocalSchedulerSuite extends FunSuite with LocalSparkContext {
def createThread(threadIndex: Int, poolName: String, sc: SparkContext, sem: Semaphore) {
TaskThreadInfo.threadToRunning(threadIndex) = false
val nums = sc.parallelize(threadIndex to threadIndex, 1)
TaskThreadInfo.threadToLock(threadIndex) = new Lock()
TaskThreadInfo.threadToStarted(threadIndex) = new CountDownLatch(1)
new Thread {
if (poolName != null) {
sc.addLocalProperties("spark.scheduler.cluster.fair.pool",poolName)
}
override def run() {
val ans = nums.map(number => {
TaskThreadInfo.threadToRunning(number) = true
TaskThreadInfo.threadToStarted(number).countDown()
TaskThreadInfo.threadToLock(number).jobWait()
TaskThreadInfo.threadToRunning(number) = false
number
}).collect()
assert(ans.toList === List(threadIndex))
sem.release()
}
}.start()
}
test("Local FIFO scheduler end-to-end test") {
System.setProperty("spark.cluster.schedulingmode", "FIFO")
sc = new SparkContext("local[4]", "test")
val sem = new Semaphore(0)
createThread(1,null,sc,sem)
TaskThreadInfo.threadToStarted(1).await()
createThread(2,null,sc,sem)
TaskThreadInfo.threadToStarted(2).await()
createThread(3,null,sc,sem)
TaskThreadInfo.threadToStarted(3).await()
createThread(4,null,sc,sem)
TaskThreadInfo.threadToStarted(4).await()
// Threads 5 and 6 (whose stages are pending) must satisfy two conditions:
// 1. the stages (taskSetManagers) of the jobs in threads 5 and 6 must be added to the
// taskSetManager queue before TaskThreadInfo.threadToLock(1).jobFinished() is executed
// 2. the stage in thread 5 must have higher priority than the stage in thread 6
// So each thread just sleeps for 1s here.
// TODO: any better solution?
createThread(5,null,sc,sem)
Thread.sleep(1000)
createThread(6,null,sc,sem)
Thread.sleep(1000)
assert(TaskThreadInfo.threadToRunning(1) === true)
assert(TaskThreadInfo.threadToRunning(2) === true)
assert(TaskThreadInfo.threadToRunning(3) === true)
assert(TaskThreadInfo.threadToRunning(4) === true)
assert(TaskThreadInfo.threadToRunning(5) === false)
assert(TaskThreadInfo.threadToRunning(6) === false)
TaskThreadInfo.threadToLock(1).jobFinished()
TaskThreadInfo.threadToStarted(5).await()
assert(TaskThreadInfo.threadToRunning(1) === false)
assert(TaskThreadInfo.threadToRunning(2) === true)
assert(TaskThreadInfo.threadToRunning(3) === true)
assert(TaskThreadInfo.threadToRunning(4) === true)
assert(TaskThreadInfo.threadToRunning(5) === true)
assert(TaskThreadInfo.threadToRunning(6) === false)
TaskThreadInfo.threadToLock(3).jobFinished()
TaskThreadInfo.threadToStarted(6).await()
assert(TaskThreadInfo.threadToRunning(1) === false)
assert(TaskThreadInfo.threadToRunning(2) === true)
assert(TaskThreadInfo.threadToRunning(3) === false)
assert(TaskThreadInfo.threadToRunning(4) === true)
assert(TaskThreadInfo.threadToRunning(5) === true)
assert(TaskThreadInfo.threadToRunning(6) === true)
TaskThreadInfo.threadToLock(2).jobFinished()
TaskThreadInfo.threadToLock(4).jobFinished()
TaskThreadInfo.threadToLock(5).jobFinished()
TaskThreadInfo.threadToLock(6).jobFinished()
sem.acquire(6)
}
test("Local fair scheduler end-to-end test") {
sc = new SparkContext("local[8]", "LocalSchedulerSuite")
val sem = new Semaphore(0)
System.setProperty("spark.cluster.schedulingmode", "FAIR")
val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
System.setProperty("spark.fairscheduler.allocation.file", xmlPath)
createThread(10,"1",sc,sem)
TaskThreadInfo.threadToStarted(10).await()
createThread(20,"2",sc,sem)
TaskThreadInfo.threadToStarted(20).await()
createThread(30,"3",sc,sem)
TaskThreadInfo.threadToStarted(30).await()
assert(TaskThreadInfo.threadToRunning(10) === true)
assert(TaskThreadInfo.threadToRunning(20) === true)
assert(TaskThreadInfo.threadToRunning(30) === true)
createThread(11,"1",sc,sem)
TaskThreadInfo.threadToStarted(11).await()
createThread(21,"2",sc,sem)
TaskThreadInfo.threadToStarted(21).await()
createThread(31,"3",sc,sem)
TaskThreadInfo.threadToStarted(31).await()
assert(TaskThreadInfo.threadToRunning(11) === true)
assert(TaskThreadInfo.threadToRunning(21) === true)
assert(TaskThreadInfo.threadToRunning(31) === true)
createThread(12,"1",sc,sem)
TaskThreadInfo.threadToStarted(12).await()
createThread(22,"2",sc,sem)
TaskThreadInfo.threadToStarted(22).await()
createThread(32,"3",sc,sem)
assert(TaskThreadInfo.threadToRunning(12) === true)
assert(TaskThreadInfo.threadToRunning(22) === true)
assert(TaskThreadInfo.threadToRunning(32) === false)
TaskThreadInfo.threadToLock(10).jobFinished()
TaskThreadInfo.threadToStarted(32).await()
assert(TaskThreadInfo.threadToRunning(32) === true)
// 1. As in the scenario above, sleep 1s so that the stages of 23 and 33 are added to the
// taskSetManager queue, and the cluster can assign a free cpu core to stage 23 once stage 11 finishes.
// 2. The priorities of 23 and 33 are meaningless here since the fair scheduler is used.
createThread(23,"2",sc,sem)
createThread(33,"3",sc,sem)
Thread.sleep(1000)
TaskThreadInfo.threadToLock(11).jobFinished()
TaskThreadInfo.threadToStarted(23).await()
assert(TaskThreadInfo.threadToRunning(23) === true)
assert(TaskThreadInfo.threadToRunning(33) === false)
TaskThreadInfo.threadToLock(12).jobFinished()
TaskThreadInfo.threadToStarted(33).await()
assert(TaskThreadInfo.threadToRunning(33) === true)
TaskThreadInfo.threadToLock(20).jobFinished()
TaskThreadInfo.threadToLock(21).jobFinished()
TaskThreadInfo.threadToLock(22).jobFinished()
TaskThreadInfo.threadToLock(23).jobFinished()
TaskThreadInfo.threadToLock(30).jobFinished()
TaskThreadInfo.threadToLock(31).jobFinished()
TaskThreadInfo.threadToLock(32).jobFinished()
TaskThreadInfo.threadToLock(33).jobFinished()
sem.acquire(11)
}
}