Deduplicate Local and Cluster schedulers.

The code in LocalScheduler/LocalTaskSetManager was nearly identical
to the code in ClusterScheduler/ClusterTaskSetManager. The redundancy
made updating the schedulers unnecessarily painful and error-prone.
This commit combines the two into a single TaskScheduler/
TaskSetManager.
Kay Ousterhout 2013-10-30 17:07:24 -07:00
parent dc9ce16f6b
commit 5e91495f5c
22 changed files with 1280 additions and 1932 deletions


@@ -56,10 +56,9 @@ import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd._
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend,
-SparkDeploySchedulerBackend, ClusterScheduler, SimrSchedulerBackend}
SparkDeploySchedulerBackend, SimrSchedulerBackend}
import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
-import org.apache.spark.scheduler.local.LocalScheduler
import org.apache.spark.scheduler.local.LocalBackend
-import org.apache.spark.scheduler.StageInfo
import org.apache.spark.storage.{BlockManagerSource, RDDInfo, StorageStatus, StorageUtils}
import org.apache.spark.ui.SparkUI
import org.apache.spark.util.{ClosureCleaner, MetadataCleaner, MetadataCleanerType,
@@ -149,8 +148,6 @@ class SparkContext(
private[spark] var taskScheduler: TaskScheduler = {
// Regular expression used for local[N] master format
val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
-// Regular expression for local[N, maxRetries], used in tests with failing tasks
-val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r
// Regular expression for simulating a Spark cluster of [N, cores, memory] locally
val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r
// Regular expression for connecting to Spark deploy clusters
@@ -162,23 +159,26 @@ class SparkContext(
master match {
case "local" =>
-new LocalScheduler(1, 0, this)
val scheduler = new TaskScheduler(this)
val backend = new LocalBackend(scheduler, 1)
scheduler.initialize(backend)
scheduler
case LOCAL_N_REGEX(threads) =>
-new LocalScheduler(threads.toInt, 0, this)
-case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
-new LocalScheduler(threads.toInt, maxFailures.toInt, this)
val scheduler = new TaskScheduler(this)
val backend = new LocalBackend(scheduler, threads.toInt)
scheduler.initialize(backend)
scheduler
case SPARK_REGEX(sparkUrl) =>
-val scheduler = new ClusterScheduler(this)
val scheduler = new TaskScheduler(this)
val masterUrls = sparkUrl.split(",").map("spark://" + _)
val backend = new SparkDeploySchedulerBackend(scheduler, this, masterUrls, appName)
scheduler.initialize(backend)
scheduler
case SIMR_REGEX(simrUrl) =>
-val scheduler = new ClusterScheduler(this)
val scheduler = new TaskScheduler(this)
val backend = new SimrSchedulerBackend(scheduler, this, simrUrl)
scheduler.initialize(backend)
scheduler
@@ -192,7 +192,7 @@ class SparkContext(
memoryPerSlaveInt, SparkContext.executorMemoryRequested))
}
-val scheduler = new ClusterScheduler(this)
val scheduler = new TaskScheduler(this)
val localCluster = new LocalSparkCluster(
numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt)
val masterUrls = localCluster.start()
@@ -207,7 +207,7 @@ class SparkContext(
val scheduler = try {
val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler")
val cons = clazz.getConstructor(classOf[SparkContext])
-cons.newInstance(this).asInstanceOf[ClusterScheduler]
cons.newInstance(this).asInstanceOf[TaskScheduler]
} catch {
// TODO: Enumerate the exact reasons why it can fail
// But irrespective of it, it means we cannot proceed !
@@ -221,7 +221,7 @@ class SparkContext(
case MESOS_REGEX(mesosUrl) =>
MesosNativeLibrary.load()
-val scheduler = new ClusterScheduler(this)
val scheduler = new TaskScheduler(this)
val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean
val backend = if (coarseGrained) {
new CoarseMesosSchedulerBackend(scheduler, this, mesosUrl, appName)
@@ -593,9 +593,7 @@ class SparkContext(
}
addedFiles(key) = System.currentTimeMillis
-// Fetch the file locally in case a job is executed locally.
-// Jobs that run through LocalScheduler will already fetch the required dependencies,
-// but jobs run in DAGScheduler.runLocally() will not so we must fetch the files here.
// Fetch the file locally in case a job is executed using DAGScheduler.runLocally().
Utils.fetchFile(path, new File(SparkFiles.getRootDirectory))
logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))


@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.cluster
package org.apache.spark.scheduler
import org.apache.spark.executor.ExecutorExitCode


@@ -15,13 +15,13 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.cluster
package org.apache.spark.scheduler
import org.apache.spark.SparkContext
/**
-* A backend interface for cluster scheduling systems that allows plugging in different ones under
-* ClusterScheduler. We assume a Mesos-like model where the application gets resource offers as
* A backend interface for scheduling systems that allows plugging in different ones under
* TaskScheduler. We assume a Mesos-like model where the application gets resource offers as
* machines become available and can launch tasks on them.
*/
private[spark] trait SchedulerBackend {
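The trait's method list is not part of this hunk, but the scheduler code in this commit calls start(), stop(), reviveOffers(), killTask(tid, execId), and defaultParallelism() on its backend. A minimal no-op backend, sketched from those call sites rather than from the actual trait body:

// Sketch only: the method set is inferred from how TaskScheduler uses its
// backend elsewhere in this commit, not copied from the trait definition.
class NoOpSchedulerBackend extends SchedulerBackend {
  override def start(): Unit = {}
  override def stop(): Unit = {}
  override def reviveOffers(): Unit = {} // scheduler requests a new round of resource offers
  override def killTask(taskId: Long, executorId: String): Unit = {}
  override def defaultParallelism(): Int = 1
}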


@@ -15,21 +15,20 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.cluster
package org.apache.spark.scheduler
import java.nio.ByteBuffer
import java.util.concurrent.{LinkedBlockingDeque, ThreadFactory, ThreadPoolExecutor, TimeUnit}
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
-import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskResult}
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.util.Utils
/**
* Runs a thread pool that deserializes and remotely fetches (if necessary) task results.
*/
-private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterScheduler)
private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskScheduler)
extends Logging {
private val THREADS = System.getProperty("spark.resultGetter.threads", "4").toInt
private val getTaskResultExecutor = Utils.newDaemonFixedThreadPool(
@@ -42,7 +41,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSche
}
def enqueueSuccessfulTask(
-taskSetManager: ClusterTaskSetManager, tid: Long, serializedData: ByteBuffer) {
taskSetManager: TaskSetManager, tid: Long, serializedData: ByteBuffer) {
getTaskResultExecutor.execute(new Runnable {
override def run() {
try {
@@ -78,7 +77,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSche
})
}
-def enqueueFailedTask(taskSetManager: ClusterTaskSetManager, tid: Long, taskState: TaskState,
def enqueueFailedTask(taskSetManager: TaskSetManager, tid: Long, taskState: TaskState,
serializedData: ByteBuffer) {
var reason: Option[TaskEndReason] = None
getTaskResultExecutor.execute(new Runnable {
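Utils.newDaemonFixedThreadPool is not shown in this diff; a plausible equivalent in plain java.util.concurrent, for readers wondering why the getter's threads never block JVM shutdown (a sketch, not Spark's actual implementation):

import java.util.concurrent.{ExecutorService, Executors, ThreadFactory}

// Fixed-size pool of daemon threads: concurrent result fetches are bounded
// by the pool size, and daemon threads never keep the JVM alive on exit.
def newDaemonFixedThreadPool(nThreads: Int): ExecutorService =
  Executors.newFixedThreadPool(nThreads, new ThreadFactory {
    override def newThread(r: Runnable): Thread = {
      val t = new Thread(r)
      t.setDaemon(true)
      t
    }
  })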


@@ -17,39 +17,477 @@
package org.apache.spark.scheduler
import java.nio.ByteBuffer
import java.util.concurrent.atomic.AtomicLong
import java.util.{TimerTask, Timer}
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
/**
-* Low-level task scheduler interface, implemented by both ClusterScheduler and LocalScheduler.
-* Each TaskScheduler schedules tasks for a single SparkContext.
-* These schedulers get sets of tasks submitted to them from the DAGScheduler for each stage,
-* and are responsible for sending the tasks to the cluster, running them, retrying if there
-* are failures, and mitigating stragglers. They return events to the DAGScheduler.
* Schedules tasks for a single SparkContext. Receives a set of tasks from the DAGScheduler for
* each stage, and is responsible for sending tasks to executors, running them, retrying if there
* are failures, and mitigating stragglers. Returns events to the DAGScheduler.
*
* Clients should first call initialize() and start(), then submit task sets through the
* runTasks method.
*
* This class can work with multiple types of clusters by acting through a SchedulerBackend.
* It can also work with a local setup by using a LocalBackend and setting isLocal to true.
* It handles common logic, like determining a scheduling order across jobs, waking up to launch
* speculative tasks, etc.
*
* THREADING: SchedulerBackends and task-submitting clients can call this class from multiple
* threads, so it needs locks in public API methods to maintain its state. In addition, some
* SchedulerBackends synchronize on themselves when they want to send events here, and then
* acquire a lock on us, so we need to make sure that we don't try to lock the backend while
* we are holding a lock on ourselves.
*/
-private[spark] trait TaskScheduler {
-def rootPool: Pool
-def schedulingMode: SchedulingMode
-def start(): Unit
private[spark] class TaskScheduler(val sc: SparkContext, isLocal: Boolean = false) extends Logging {
// How often to check for speculative tasks
val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong
// Threshold above which we warn user initial TaskSet may be starved
val STARVATION_TIMEOUT = System.getProperty("spark.starvation.timeout", "15000").toLong
// TaskSetManagers are not thread safe, so any access to one should be synchronized
// on this class.
val activeTaskSets = new HashMap[String, TaskSetManager]
val taskIdToTaskSetId = new HashMap[Long, String]
val taskIdToExecutorId = new HashMap[Long, String]
val taskSetTaskIds = new HashMap[String, HashSet[Long]]
@volatile private var hasReceivedTask = false
@volatile private var hasLaunchedTask = false
private val starvationTimer = new Timer(true)
// Incrementing task IDs
val nextTaskId = new AtomicLong(0)
// Which executor IDs we have executors on
val activeExecutorIds = new HashSet[String]
// The set of executors we have on each host; this is used to compute hostsAlive, which
// in turn is used to decide when we can attain data locality on a given host
private val executorsByHost = new HashMap[String, HashSet[String]]
private val executorIdToHost = new HashMap[String, String]
// Listener object to pass upcalls into
var dagScheduler: DAGScheduler = null
var backend: SchedulerBackend = null
val mapOutputTracker = SparkEnv.get.mapOutputTracker
var schedulableBuilder: SchedulableBuilder = null
var rootPool: Pool = null
// default scheduler is FIFO
val schedulingMode: SchedulingMode = SchedulingMode.withName(
System.getProperty("spark.scheduler.mode", "FIFO"))
// This is a var so that we can reset it for testing purposes.
private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this)
def setDAGScheduler(dagScheduler: DAGScheduler) {
this.dagScheduler = dagScheduler
}
def initialize(context: SchedulerBackend) {
backend = context
// temporarily set rootPool name to empty
rootPool = new Pool("", schedulingMode, 0, 0)
schedulableBuilder = {
schedulingMode match {
case SchedulingMode.FIFO =>
new FIFOSchedulableBuilder(rootPool)
case SchedulingMode.FAIR =>
new FairSchedulableBuilder(rootPool)
}
}
schedulableBuilder.buildPools()
}
def newTaskId(): Long = nextTaskId.getAndIncrement()
def start() {
backend.start()
if (!isLocal && System.getProperty("spark.speculation", "false").toBoolean) {
new Thread("TaskScheduler speculation check") {
setDaemon(true)
override def run() {
logInfo("Starting speculative execution thread")
while (true) {
try {
Thread.sleep(SPECULATION_INTERVAL)
} catch {
case e: InterruptedException => {}
}
checkSpeculatableTasks()
}
}
}.start()
}
}
def submitTasks(taskSet: TaskSet) {
val tasks = taskSet.tasks
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized {
val manager = new TaskSetManager(this, taskSet)
activeTaskSets(taskSet.id) = manager
schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
taskSetTaskIds(taskSet.id) = new HashSet[Long]()
if (!isLocal && !hasReceivedTask) {
starvationTimer.scheduleAtFixedRate(new TimerTask() {
override def run() {
if (!hasLaunchedTask) {
logWarning("Initial job has not accepted any resources; " +
"check your cluster UI to ensure that workers are registered " +
"and have sufficient memory")
} else {
this.cancel()
}
}
}, STARVATION_TIMEOUT, STARVATION_TIMEOUT)
}
hasReceivedTask = true
}
backend.reviveOffers()
}
def cancelTasks(stageId: Int): Unit = synchronized {
logInfo("Cancelling stage " + stageId)
activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
// There are two possible cases here:
// 1. The task set manager has been created and some tasks have been scheduled.
// In this case, send a kill signal to the executors to kill the task and then abort
// the stage.
// 2. The task set manager has been created but no tasks have been scheduled. In this case,
// simply abort the stage.
val taskIds = taskSetTaskIds(tsm.taskSet.id)
if (taskIds.size > 0) {
taskIds.foreach { tid =>
val execId = taskIdToExecutorId(tid)
backend.killTask(tid, execId)
}
}
tsm.error("Stage %d was cancelled".format(stageId))
}
}
def taskSetFinished(manager: TaskSetManager): Unit = synchronized {
// Check to see if the given task set has been removed. This is possible in the case of
// multiple unrecoverable task failures (e.g. if the entire task set is killed when it has
// more than one running task).
if (activeTaskSets.contains(manager.taskSet.id)) {
activeTaskSets -= manager.taskSet.id
manager.parent.removeSchedulable(manager)
logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name))
taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id)
taskSetTaskIds.remove(manager.taskSet.id)
}
}
/**
* Called by cluster manager to offer resources on slaves. We respond by asking our active task
* sets for tasks in order of priority. We fill each node with tasks in a round-robin manner so
* that tasks are balanced across the cluster.
*/
def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = synchronized {
SparkEnv.set(sc.env)
// Mark each slave as alive and remember its hostname
for (o <- offers) {
executorIdToHost(o.executorId) = o.host
if (!executorsByHost.contains(o.host)) {
executorsByHost(o.host) = new HashSet[String]()
executorGained(o.executorId, o.host)
}
}
// Build a list of tasks to assign to each worker
val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores))
val availableCpus = offers.map(o => o.cores).toArray
val sortedTaskSets = rootPool.getSortedTaskSetQueue()
for (taskSet <- sortedTaskSets) {
logDebug("parentName: %s, name: %s, runningTasks: %s".format(
taskSet.parent.name, taskSet.name, taskSet.runningTasks))
}
// Take each TaskSet in our scheduling order, and then offer it each node in increasing order
// of locality levels so that it gets a chance to launch local tasks on all of them.
var launchedTask = false
for (taskSet <- sortedTaskSets; maxLocality <- TaskLocality.values) {
do {
launchedTask = false
for (i <- 0 until offers.size) {
val execId = offers(i).executorId
val host = offers(i).host
for (task <- taskSet.resourceOffer(execId, host, availableCpus(i), maxLocality)) {
tasks(i) += task
val tid = task.taskId
taskIdToTaskSetId(tid) = taskSet.taskSet.id
taskSetTaskIds(taskSet.taskSet.id) += tid
taskIdToExecutorId(tid) = execId
activeExecutorIds += execId
executorsByHost(host) += execId
availableCpus(i) -= 1
launchedTask = true
}
}
} while (launchedTask)
}
if (tasks.size > 0) {
hasLaunchedTask = true
}
return tasks
}
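// Worked example of the round-robin fill above (illustrative numbers, not
// taken from the source): two offers A and B with two cores each and one
// TaskSet holding three runnable tasks. Each pass of the do/while visits
// every offer once, so tasks land A, B, A instead of saturating A first:
//   pass 1: A <- task 0, B <- task 1 (availableCpus becomes [1, 1])
//   pass 2: A <- task 2              (availableCpus becomes [0, 1])
//   pass 3: no task launches, launchedTask stays false, the loop exits
// The enclosing loop then relaxes maxLocality one step at a time
// (PROCESS_LOCAL -> NODE_LOCAL -> RACK_LOCAL -> ANY), so better-placed
// tasks always get first claim on the cores.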
def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
var failedExecutor: Option[String] = None
var taskFailed = false
synchronized {
try {
if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) {
// We lost this entire executor, so remember that it's gone
val execId = taskIdToExecutorId(tid)
if (activeExecutorIds.contains(execId)) {
removeExecutor(execId)
failedExecutor = Some(execId)
}
}
taskIdToTaskSetId.get(tid) match {
case Some(taskSetId) =>
if (TaskState.isFinished(state)) {
taskIdToTaskSetId.remove(tid)
if (taskSetTaskIds.contains(taskSetId)) {
taskSetTaskIds(taskSetId) -= tid
}
taskIdToExecutorId.remove(tid)
}
if (state == TaskState.FAILED) {
taskFailed = true
}
activeTaskSets.get(taskSetId).foreach { taskSet =>
if (state == TaskState.FINISHED) {
taskSet.removeRunningTask(tid)
taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
} else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
taskSet.removeRunningTask(tid)
taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
}
}
case None =>
logInfo("Ignoring update from TID " + tid + " because its task set is gone")
}
} catch {
case e: Exception => logError("Exception in statusUpdate", e)
}
}
// Update the DAGScheduler without holding a lock on this, since that can deadlock
if (failedExecutor != None) {
dagScheduler.executorLost(failedExecutor.get)
backend.reviveOffers()
}
if (taskFailed) {
// Also revive offers if a task had failed for some reason other than host lost
backend.reviveOffers()
}
}
def handleTaskGettingResult(taskSetManager: TaskSetManager, tid: Long) {
taskSetManager.handleTaskGettingResult(tid)
}
def handleSuccessfulTask(
taskSetManager: TaskSetManager,
tid: Long,
taskResult: DirectTaskResult[_]) = synchronized {
taskSetManager.handleSuccessfulTask(tid, taskResult)
}
def handleFailedTask(
taskSetManager: TaskSetManager,
tid: Long,
taskState: TaskState,
reason: Option[TaskEndReason]) = synchronized {
taskSetManager.handleFailedTask(tid, taskState, reason)
if (taskState == TaskState.FINISHED) {
// The task finished successfully but the result was lost, so we should revive offers.
backend.reviveOffers()
}
}
def error(message: String) {
synchronized {
if (activeTaskSets.size > 0) {
// Have each task set throw a SparkException with the error
for ((taskSetId, manager) <- activeTaskSets) {
try {
manager.error(message)
} catch {
case e: Exception => logError("Exception in error callback", e)
}
}
} else {
// No task sets are active but we still got an error. Just exit since this
// must mean the error is during registration.
// It might be good to do something smarter here in the future.
logError("Exiting due to error from task scheduler: " + message)
System.exit(1)
}
}
}
def stop() {
if (backend != null) {
backend.stop()
}
if (taskResultGetter != null) {
taskResultGetter.stop()
}
// Sleep for an arbitrary 5 seconds to ensure that messages are sent out.
// TODO: Do something better!
Thread.sleep(5000L)
}
def defaultParallelism() = backend.defaultParallelism()
// Check for speculatable tasks in all our active jobs.
def checkSpeculatableTasks() {
var shouldRevive = false
synchronized {
shouldRevive = rootPool.checkSpeculatableTasks()
}
if (shouldRevive) {
backend.reviveOffers()
}
}
// Check for pending tasks in all our active jobs.
def hasPendingTasks: Boolean = {
synchronized {
rootPool.hasPendingTasks()
}
}
def executorLost(executorId: String, reason: ExecutorLossReason) {
var failedExecutor: Option[String] = None
synchronized {
if (activeExecutorIds.contains(executorId)) {
val hostPort = executorIdToHost(executorId)
logError("Lost executor %s on %s: %s".format(executorId, hostPort, reason))
removeExecutor(executorId)
failedExecutor = Some(executorId)
} else {
// We may get multiple executorLost() calls with different loss reasons. For example, one
// may be triggered by a dropped connection from the slave while another may be a report
// of executor termination from Mesos. We produce log messages for both so we eventually
// report the termination reason.
logError("Lost an executor " + executorId + " (already removed): " + reason)
}
}
// Call dagScheduler.executorLost without holding the lock on this to prevent deadlock
if (failedExecutor != None) {
dagScheduler.executorLost(failedExecutor.get)
backend.reviveOffers()
}
}
/** Remove an executor from all our data structures and mark it as lost */
private def removeExecutor(executorId: String) {
activeExecutorIds -= executorId
val host = executorIdToHost(executorId)
val execs = executorsByHost.getOrElse(host, new HashSet)
execs -= executorId
if (execs.isEmpty) {
executorsByHost -= host
}
executorIdToHost -= executorId
rootPool.executorLost(executorId, host)
}
def executorGained(execId: String, host: String) {
dagScheduler.executorGained(execId, host)
}
def getExecutorsAliveOnHost(host: String): Option[Set[String]] = synchronized {
executorsByHost.get(host).map(_.toSet)
}
def hasExecutorsAliveOnHost(host: String): Boolean = synchronized {
executorsByHost.contains(host)
}
def isExecutorAlive(execId: String): Boolean = synchronized {
activeExecutorIds.contains(execId)
}
// By default, rack is unknown
def getRackForHost(value: String): Option[String] = None
-// Invoked after system has successfully initialized (typically in spark context).
-// Yarn uses this to bootstrap allocation of resources based on preferred locations, wait for slave registrations, etc.
/**
* Invoked after the system has successfully been initialized. YARN uses this to bootstrap
* allocation of resources based on preferred locations, wait for slave registrations, etc.
*/
def postStartHook() { }
-// Disconnect from the cluster.
-def stop(): Unit
-// Submit a sequence of tasks to run.
-def submitTasks(taskSet: TaskSet): Unit
-// Cancel a stage.
-def cancelTasks(stageId: Int)
-// Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called.
-def setDAGScheduler(dagScheduler: DAGScheduler): Unit
-// Get the default level of parallelism to use in the cluster, as a hint for sizing jobs.
-def defaultParallelism(): Int
}
object TaskScheduler {
/**
* Used to balance containers across hosts.
*
* Accepts a map of hosts to resource offers for that host, and returns a prioritized list of
* resource offers representing the order in which the offers should be used. The resource
* offers are ordered such that we'll allocate one container on each host before allocating a
* second container on any host, and so on, in order to reduce the damage if a host fails.
*
* For example, given <h1, [o1, o2, o3]>, <h2, [o4]>, <h3, [o5, o6]>, returns
* [o1, o5, o4, o2, o6, o3]
*/
def prioritizeContainers[K, T] (map: HashMap[K, ArrayBuffer[T]]): List[T] = {
val _keyList = new ArrayBuffer[K](map.size)
_keyList ++= map.keys
// order keyList based on population of value in map
val keyList = _keyList.sortWith(
(left, right) => map(left).size > map(right).size
)
val retval = new ArrayBuffer[T](keyList.size * 2)
var index = 0
var found = true
while (found) {
found = false
for (key <- keyList) {
val containerList: ArrayBuffer[T] = map.get(key).getOrElse(null)
assert(containerList != null)
// Get the index'th entry for this host - if present
if (index < containerList.size){
retval += containerList.apply(index)
found = true
}
}
index += 1
}
retval.toList
}
}
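A quick check of prioritizeContainers against the example in its comment (a sketch; plain strings stand in for offers, since the method is generic in K and T):

import scala.collection.mutable.{ArrayBuffer, HashMap}

val offers = HashMap(
  "h1" -> ArrayBuffer("o1", "o2", "o3"),
  "h2" -> ArrayBuffer("o4"),
  "h3" -> ArrayBuffer("o5", "o6"))
// Hosts are visited in order of how many offers they hold (h1, h3, h2),
// draining one offer per host per round:
TaskScheduler.prioritizeContainers(offers) // List(o1, o5, o4, o2, o6, o3)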


@@ -17,32 +17,690 @@
package org.apache.spark.scheduler
-import java.nio.ByteBuffer
import java.util.Arrays
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import scala.math.max
import scala.math.min
import org.apache.spark.{ExceptionFailure, FetchFailed, Logging, Resubmitted, SparkEnv,
Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState}
import org.apache.spark.TaskState.TaskState
import org.apache.spark.scheduler._
import org.apache.spark.util.{SystemClock, Clock}
/**
-* Tracks and schedules the tasks within a single TaskSet. This class keeps track of the status of
-* each task and is responsible for retries on failure and locality. The main interfaces to it
-* are resourceOffer, which asks the TaskSet whether it wants to run a task on one node, and
-* statusUpdate, which tells it that one of its tasks changed state (e.g. finished).
* Schedules the tasks within a single TaskSet in the TaskScheduler. This class keeps track of
* each task, retries tasks if they fail (up to a limited number of times), and
* handles locality-aware scheduling for this TaskSet via delay scheduling. The main interfaces
* to it are resourceOffer, which asks the TaskSet whether it wants to run a task on one node,
* and statusUpdate, which tells it that one of its tasks changed state (e.g. finished).
*
* THREADING: This class is designed to only be called from code with a lock on the
* TaskScheduler (e.g. its event handlers). It should not be called from other threads.
*/
-private[spark] trait TaskSetManager extends Schedulable {
-def schedulableQueue = null
-def schedulingMode = SchedulingMode.NONE
-def taskSet: TaskSet
private[spark] class TaskSetManager(
sched: TaskScheduler,
val taskSet: TaskSet,
clock: Clock = SystemClock)
extends Schedulable with Logging
{
// CPUs to request per task
val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toInt
// Maximum times a task is allowed to fail before failing the job
val MAX_TASK_FAILURES = System.getProperty("spark.task.maxFailures", "4").toInt
// Quantile of tasks at which to start speculation
val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble
val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble
// Serializer for closures and tasks.
val env = SparkEnv.get
val ser = env.closureSerializer.newInstance()
val tasks = taskSet.tasks
val numTasks = tasks.length
val copiesRunning = new Array[Int](numTasks)
val successful = new Array[Boolean](numTasks)
val numFailures = new Array[Int](numTasks)
val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil)
var tasksSuccessful = 0
var weight = 1
var minShare = 0
var priority = taskSet.priority
var stageId = taskSet.stageId
var name = "TaskSet_"+taskSet.stageId.toString
var parent: Pool = null
var runningTasks = 0
private val runningTasksSet = new HashSet[Long]
// Set of pending tasks for each executor. These collections are actually
// treated as stacks, in which new tasks are added to the end of the
// ArrayBuffer and removed from the end. This makes it faster to detect
// tasks that repeatedly fail because whenever a task failed, it is put
// back at the head of the stack. They are also only cleaned up lazily;
// when a task is launched, it remains in all the pending lists except
// the one that it was launched from, but gets removed from them later.
private val pendingTasksForExecutor = new HashMap[String, ArrayBuffer[Int]]
// Set of pending tasks for each host. Similar to pendingTasksForExecutor,
// but at host level.
private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]]
// Set of pending tasks for each rack -- similar to the above.
private val pendingTasksForRack = new HashMap[String, ArrayBuffer[Int]]
// Set containing pending tasks with no locality preferences.
val pendingTasksWithNoPrefs = new ArrayBuffer[Int]
// Set containing all pending tasks (also used as a stack, as above).
val allPendingTasks = new ArrayBuffer[Int]
// Tasks that can be speculated. Since these will be a small fraction of total
// tasks, we'll just hold them in a HashSet.
val speculatableTasks = new HashSet[Int]
// Task index, start and finish time for each task attempt (indexed by task ID)
val taskInfos = new HashMap[Long, TaskInfo]
// Did the TaskSet fail?
var failed = false
var causeOfFailure = ""
// How frequently to reprint duplicate exceptions in full, in milliseconds
val EXCEPTION_PRINT_INTERVAL =
System.getProperty("spark.logging.exceptionPrintInterval", "10000").toLong
// Map of recent exceptions (identified by string representation and top stack frame) to
// duplicate count (how many times the same exception has appeared) and time the full exception
// was printed. This should ideally be an LRU map that can drop old exceptions automatically.
val recentExceptions = HashMap[String, (Int, Long)]()
// Figure out the current map output tracker epoch and set it on all tasks
val epoch = sched.mapOutputTracker.getEpoch
logDebug("Epoch for " + taskSet + ": " + epoch)
for (t <- tasks) {
t.epoch = epoch
}
// Add all our tasks to the pending lists. We do this in reverse order
// of task index so that tasks with low indices get launched first.
for (i <- (0 until numTasks).reverse) {
addPendingTask(i)
}
// Figure out which locality levels we have in our TaskSet, so we can do delay scheduling
val myLocalityLevels = computeValidLocalityLevels()
val localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level
// Delay scheduling variables: we keep track of our current locality level and the time we
// last launched a task at that level, and move up a level when localityWaits[curLevel] expires.
// We then move down if we manage to launch a "more local" task.
var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels
var lastLaunchTime = clock.getTime() // Time we last launched a task at this level
override def schedulableQueue = null
override def schedulingMode = SchedulingMode.NONE
/**
* Add a task to all the pending-task lists that it should be on. If readding is set, we are
* re-adding the task so only include it in each list if it's not already there.
*/
private def addPendingTask(index: Int, readding: Boolean = false) {
// Utility method that adds `index` to a list only if readding=false or it's not already there
def addTo(list: ArrayBuffer[Int]) {
if (!readding || !list.contains(index)) {
list += index
}
}
var hadAliveLocations = false
for (loc <- tasks(index).preferredLocations) {
for (execId <- loc.executorId) {
if (sched.isExecutorAlive(execId)) {
addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer))
hadAliveLocations = true
}
}
if (sched.hasExecutorsAliveOnHost(loc.host)) {
addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer))
for (rack <- sched.getRackForHost(loc.host)) {
addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer))
}
hadAliveLocations = true
}
}
if (!hadAliveLocations) {
// Even though the task might've had preferred locations, all of those hosts or executors
// are dead; put it in the no-prefs list so we can schedule it elsewhere right away.
addTo(pendingTasksWithNoPrefs)
}
if (!readding) {
allPendingTasks += index // No point scanning this whole list to find the old task there
}
}
/**
* Return the pending tasks list for a given executor ID, or an empty list if
* there is no map entry for that host
*/
private def getPendingTasksForExecutor(executorId: String): ArrayBuffer[Int] = {
pendingTasksForExecutor.getOrElse(executorId, ArrayBuffer())
}
/**
* Return the pending tasks list for a given host, or an empty list if
* there is no map entry for that host
*/
private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = {
pendingTasksForHost.getOrElse(host, ArrayBuffer())
}
/**
* Return the pending rack-local task list for a given rack, or an empty list if
* there is no map entry for that rack
*/
private def getPendingTasksForRack(rack: String): ArrayBuffer[Int] = {
pendingTasksForRack.getOrElse(rack, ArrayBuffer())
}
/**
* Dequeue a pending task from the given list and return its index.
* Return None if the list is empty.
* This method also cleans up any tasks in the list that have already
* been launched, since we want that to happen lazily.
*/
private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
while (!list.isEmpty) {
val index = list.last
list.trimEnd(1)
if (copiesRunning(index) == 0 && !successful(index)) {
return Some(index)
}
}
return None
}
/** Check whether a task is currently running an attempt on a given host */
private def hasAttemptOnHost(taskIndex: Int, host: String): Boolean = {
taskAttempts(taskIndex).exists(_.host == host)
}
/**
* Return a speculative task for a given executor if any are available. The task should not have
* an attempt running on this host, in case the host is slow. In addition, the task should meet
* the given locality constraint.
*/
private def findSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value)
: Option[(Int, TaskLocality.Value)] =
{
speculatableTasks.retain(index => !successful(index)) // Remove finished tasks from set
if (!speculatableTasks.isEmpty) {
// Check for process-local or preference-less tasks; note that tasks can be process-local
// on multiple nodes when we replicate cached blocks, as in Spark Streaming
for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
val prefs = tasks(index).preferredLocations
val executors = prefs.flatMap(_.executorId)
if (prefs.size == 0 || executors.contains(execId)) {
speculatableTasks -= index
return Some((index, TaskLocality.PROCESS_LOCAL))
}
}
// Check for node-local tasks
if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) {
for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
val locations = tasks(index).preferredLocations.map(_.host)
if (locations.contains(host)) {
speculatableTasks -= index
return Some((index, TaskLocality.NODE_LOCAL))
}
}
}
// Check for rack-local tasks
if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
for (rack <- sched.getRackForHost(host)) {
for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
val racks = tasks(index).preferredLocations.map(_.host).map(sched.getRackForHost)
if (racks.contains(rack)) {
speculatableTasks -= index
return Some((index, TaskLocality.RACK_LOCAL))
}
}
}
}
// Check for non-local tasks
if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
speculatableTasks -= index
return Some((index, TaskLocality.ANY))
}
}
}
return None
}
/**
* Dequeue a pending task for a given node and return its index and locality level.
* Only search for tasks matching the given locality constraint.
*/
private def findTask(execId: String, host: String, locality: TaskLocality.Value)
: Option[(Int, TaskLocality.Value)] =
{
for (index <- findTaskFromList(getPendingTasksForExecutor(execId))) {
return Some((index, TaskLocality.PROCESS_LOCAL))
}
if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) {
for (index <- findTaskFromList(getPendingTasksForHost(host))) {
return Some((index, TaskLocality.NODE_LOCAL))
}
}
if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
for {
rack <- sched.getRackForHost(host)
index <- findTaskFromList(getPendingTasksForRack(rack))
} {
return Some((index, TaskLocality.RACK_LOCAL))
}
}
// Look for no-pref tasks after rack-local tasks since they can run anywhere.
for (index <- findTaskFromList(pendingTasksWithNoPrefs)) {
return Some((index, TaskLocality.PROCESS_LOCAL))
}
if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
for (index <- findTaskFromList(allPendingTasks)) {
return Some((index, TaskLocality.ANY))
}
}
// Finally, if all else has failed, find a speculative task
return findSpeculativeTask(execId, host, locality)
}
/**
* Respond to an offer of a single executor from the scheduler by finding a task
*/
def resourceOffer(
execId: String,
host: String,
availableCpus: Int,
maxLocality: TaskLocality.TaskLocality)
: Option[TaskDescription] =
-def error(message: String)
{
if (tasksSuccessful < numTasks && availableCpus >= CPUS_PER_TASK) {
val curTime = clock.getTime()
var allowedLocality = getAllowedLocalityLevel(curTime)
if (allowedLocality > maxLocality) {
allowedLocality = maxLocality // We're not allowed to search for farther-away tasks
}
findTask(execId, host, allowedLocality) match {
case Some((index, taskLocality)) => {
// Found a task; do some bookkeeping and return a task description
val task = tasks(index)
val taskId = sched.newTaskId()
// Figure out whether this should count as a preferred launch
logInfo("Starting task %s:%d as TID %s on executor %s: %s (%s)".format(
taskSet.id, index, taskId, execId, host, taskLocality))
// Do various bookkeeping
copiesRunning(index) += 1
val info = new TaskInfo(taskId, index, curTime, execId, host, taskLocality)
taskInfos(taskId) = info
taskAttempts(index) = info :: taskAttempts(index)
// Update our locality level for delay scheduling
currentLocalityIndex = getLocalityIndex(taskLocality)
lastLaunchTime = curTime
// Serialize and return the task
val startTime = clock.getTime()
// We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here
// we assume the task can be serialized without exceptions.
val serializedTask = Task.serializeWithDependencies(
task, sched.sc.addedFiles, sched.sc.addedJars, ser)
val timeTaken = clock.getTime() - startTime
addRunningTask(taskId)
logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
taskSet.id, index, serializedTask.limit, timeTaken))
val taskName = "task %s:%d".format(taskSet.id, index)
if (taskAttempts(index).size == 1)
taskStarted(task,info)
return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask))
}
case _ =>
}
}
return None
}
/**
* Get the level we can launch tasks according to delay scheduling, based on current wait time.
*/
private def getAllowedLocalityLevel(curTime: Long): TaskLocality.TaskLocality = {
while (curTime - lastLaunchTime >= localityWaits(currentLocalityIndex) &&
currentLocalityIndex < myLocalityLevels.length - 1)
{
// Jump to the next locality level, and remove our waiting time for the current one since
// we don't want to count it again on the next one
lastLaunchTime += localityWaits(currentLocalityIndex)
currentLocalityIndex += 1
}
myLocalityLevels(currentLocalityIndex)
}
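// Illustrative timeline for the wait logic above, assuming the default
// 3000 ms wait at every level and myLocalityLevels = [PROCESS_LOCAL,
// NODE_LOCAL, ANY]: if no task has launched since lastLaunchTime, a call
// 4500 ms later jumps PROCESS_LOCAL -> NODE_LOCAL (consuming 3000 ms of
// the elapsed wait) and returns NODE_LOCAL, since the remaining 1500 ms
// is under the next threshold. Launching a PROCESS_LOCAL task later moves
// currentLocalityIndex back down via getLocalityIndex in resourceOffer.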
/**
* Find the index in myLocalityLevels for a given locality. This is also designed to work with
* localities that are not in myLocalityLevels (in case we somehow get those) by returning the
* next-biggest level we have. Uses the fact that the last value in myLocalityLevels is ANY.
*/
def getLocalityIndex(locality: TaskLocality.TaskLocality): Int = {
var index = 0
while (locality > myLocalityLevels(index)) {
index += 1
}
index
}
private def taskStarted(task: Task[_], info: TaskInfo) {
sched.dagScheduler.taskStarted(task, info)
}
def handleTaskGettingResult(tid: Long) = {
val info = taskInfos(tid)
info.markGettingResult()
sched.dagScheduler.taskGettingResult(tasks(info.index), info)
}
/**
* Marks the task as successful and notifies the DAGScheduler that a task has ended.
*/
def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]) = {
val info = taskInfos(tid)
val index = info.index
info.markSuccessful()
removeRunningTask(tid)
if (!successful(index)) {
logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format(
tid, info.duration, info.host, tasksSuccessful, numTasks))
sched.dagScheduler.taskEnded(
tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
// Mark successful and stop if all the tasks have succeeded.
tasksSuccessful += 1
successful(index) = true
if (tasksSuccessful == numTasks) {
sched.taskSetFinished(this)
}
} else {
logInfo("Ignorning task-finished event for TID " + tid + " because task " +
index + " has already completed successfully")
}
}
/**
* Marks the task as failed, re-adds it to the list of pending tasks, and notifies the
* DAG Scheduler.
*/
def handleFailedTask(tid: Long, state: TaskState, reason: Option[TaskEndReason]) {
val info = taskInfos(tid)
if (info.failed) {
return
}
removeRunningTask(tid)
val index = info.index
info.markFailed()
if (!successful(index)) {
logWarning("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
copiesRunning(index) -= 1
// Check if the problem is a map output fetch failure. In that case, this
// task will never succeed on any node, so tell the scheduler about it.
reason.foreach {
case fetchFailed: FetchFailed =>
logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress)
sched.dagScheduler.taskEnded(tasks(index), fetchFailed, null, null, info, null)
successful(index) = true
tasksSuccessful += 1
sched.taskSetFinished(this)
removeAllRunningTasks()
return
case TaskKilled =>
logWarning("Task %d was killed.".format(tid))
sched.dagScheduler.taskEnded(tasks(index), reason.get, null, null, info, null)
return
case ef: ExceptionFailure =>
sched.dagScheduler.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null))
val key = ef.description
val now = clock.getTime()
val (printFull, dupCount) = {
if (recentExceptions.contains(key)) {
val (dupCount, printTime) = recentExceptions(key)
if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
recentExceptions(key) = (0, now)
(true, 0)
} else {
recentExceptions(key) = (dupCount + 1, printTime)
(false, dupCount + 1)
}
} else {
recentExceptions(key) = (0, now)
(true, 0)
}
}
if (printFull) {
val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString))
logWarning("Loss was due to %s\n%s\n%s".format(
ef.className, ef.description, locs.mkString("\n")))
} else {
logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount))
}
case TaskResultLost =>
logWarning("Lost result for TID %s on host %s".format(tid, info.host))
sched.dagScheduler.taskEnded(tasks(index), TaskResultLost, null, null, info, null)
case _ => {}
}
// On non-fetch failures, re-enqueue the task as pending for a max number of retries
addPendingTask(index)
if (state != TaskState.KILLED) {
numFailures(index) += 1
if (numFailures(index) > MAX_TASK_FAILURES) {
logError("Task %s:%d failed more than %d times; aborting job".format(
taskSet.id, index, MAX_TASK_FAILURES))
abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES))
}
}
} else {
logInfo("Ignoring task-lost event for TID " + tid +
" because task " + index + " is already finished")
}
}
def error(message: String) {
// Save the error message
abort("Error: " + message)
}
def abort(message: String) {
failed = true
causeOfFailure = message
// TODO: Kill running tasks if we were not terminated due to a Mesos error
sched.dagScheduler.taskSetFailed(taskSet, message)
removeAllRunningTasks()
sched.taskSetFinished(this)
}
/** If the given task ID is not in the set of running tasks, adds it.
*
* Used to keep track of the number of running tasks, for enforcing scheduling policies.
*/
def addRunningTask(tid: Long) {
if (runningTasksSet.add(tid) && parent != null) {
parent.increaseRunningTasks(1)
}
runningTasks = runningTasksSet.size
}
/** If the given task ID is in the set of running tasks, removes it. */
def removeRunningTask(tid: Long) {
if (runningTasksSet.remove(tid) && parent != null) {
parent.decreaseRunningTasks(1)
}
runningTasks = runningTasksSet.size
}
private def removeAllRunningTasks() {
val numRunningTasks = runningTasksSet.size
runningTasksSet.clear()
if (parent != null) {
parent.decreaseRunningTasks(numRunningTasks)
}
runningTasks = 0
}
override def getSchedulableByName(name: String): Schedulable = {
return null
}
override def addSchedulable(schedulable: Schedulable) {}
override def removeSchedulable(schedulable: Schedulable) {}
override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]()
sortedTaskSetQueue += this
return sortedTaskSetQueue
}
/** Called by TaskScheduler when an executor is lost so we can re-enqueue our tasks */
override def executorLost(execId: String, host: String) {
logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id)
// Re-enqueue pending tasks for this host based on the status of the cluster -- for example, a
// task that used to have locations on only this host might now go to the no-prefs list. Note
// that it's okay if we add a task to the same queue twice (if it had multiple preferred
// locations), because findTaskFromList will skip already-running tasks.
for (index <- getPendingTasksForExecutor(execId)) {
addPendingTask(index, readding=true)
}
for (index <- getPendingTasksForHost(host)) {
addPendingTask(index, readding=true)
}
// Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage
if (tasks(0).isInstanceOf[ShuffleMapTask]) {
for ((tid, info) <- taskInfos if info.executorId == execId) {
val index = taskInfos(tid).index
if (successful(index)) {
successful(index) = false
copiesRunning(index) -= 1
tasksSuccessful -= 1
addPendingTask(index)
// Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
// stage finishes when a total of tasks.size tasks finish.
sched.dagScheduler.taskEnded(tasks(index), Resubmitted, null, null, info, null)
}
}
}
// Also re-enqueue any tasks that were running on the node
for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
handleFailedTask(tid, TaskState.KILLED, None)
}
}
/**
* Check for tasks to be speculated and return true if there are any. This is called periodically
* by the TaskScheduler.
*
* TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that
* we don't scan the whole task set. It might also help to make this sorted by launch time.
*/
override def checkSpeculatableTasks(): Boolean = {
// Can't speculate if we only have one task, or if all tasks have finished.
if (numTasks == 1 || tasksSuccessful == numTasks) {
return false
}
var foundTasks = false
val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) {
val time = clock.getTime()
val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
Arrays.sort(durations)
val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.size - 1))
val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100)
// TODO: Threshold should also look at standard deviation of task durations and have a lower
// bound based on that.
logDebug("Task length threshold for speculation: " + threshold)
for ((tid, info) <- taskInfos) {
val index = info.index
if (!successful(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
!speculatableTasks.contains(index)) {
logInfo(
"Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format(
taskSet.id, index, info.host, threshold))
speculatableTasks += index
foundTasks = true
}
}
}
return foundTasks
}
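// Worked example of the thresholds above, with illustrative numbers: for
// numTasks = 10 and the default SPECULATION_QUANTILE of 0.75, speculation
// starts only after (0.75 * 10).floor = 7 tasks have succeeded. If the
// median successful duration is then 200 ms, the default multiplier of
// 1.5 gives a threshold of max(1.5 * 200, 100) = 300 ms, so any task
// still running after 300 ms with a single copy becomes speculatable.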
override def hasPendingTasks(): Boolean = {
numTasks > 0 && tasksSuccessful < numTasks
}
private def getLocalityWait(level: TaskLocality.TaskLocality): Long = {
val defaultWait = System.getProperty("spark.locality.wait", "3000")
level match {
case TaskLocality.PROCESS_LOCAL =>
System.getProperty("spark.locality.wait.process", defaultWait).toLong
case TaskLocality.NODE_LOCAL =>
System.getProperty("spark.locality.wait.node", defaultWait).toLong
case TaskLocality.RACK_LOCAL =>
System.getProperty("spark.locality.wait.rack", defaultWait).toLong
case TaskLocality.ANY =>
0L
}
}
/**
* Compute the locality levels used in this TaskSet. Assumes that all tasks have already been
* added to queues using addPendingTask.
*/
private def computeValidLocalityLevels(): Array[TaskLocality.TaskLocality] = {
import TaskLocality.{PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY}
val levels = new ArrayBuffer[TaskLocality.TaskLocality]
if (!pendingTasksForExecutor.isEmpty && getLocalityWait(PROCESS_LOCAL) != 0) {
levels += PROCESS_LOCAL
}
if (!pendingTasksForHost.isEmpty && getLocalityWait(NODE_LOCAL) != 0) {
levels += NODE_LOCAL
}
if (!pendingTasksForRack.isEmpty && getLocalityWait(RACK_LOCAL) != 0) {
levels += RACK_LOCAL
}
levels += ANY
logDebug("Valid locality levels for " + taskSet + ": " + levels.mkString(", "))
levels.toArray
}
}
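The locality waits read in getLocalityWait above are ordinary system properties, so each level can be tuned independently. A small sketch (property names from the code above, values illustrative):

// Wait longer for a process-local slot, but never wait for rack locality.
System.setProperty("spark.locality.wait", "3000")           // default for all levels
System.setProperty("spark.locality.wait.process", "10000")  // per-level override
System.setProperty("spark.locality.wait.rack", "0")         // 0 drops RACK_LOCAL from myLocalityLevels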


@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.cluster
package org.apache.spark.scheduler
/**
* Represents free resources available on an executor.


@@ -1,486 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.scheduler.cluster
import java.nio.ByteBuffer
import java.util.concurrent.atomic.AtomicLong
import java.util.{TimerTask, Timer}
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
/**
* The main TaskScheduler implementation, for running tasks on a cluster. Clients should first call
* initialize() and start(), then submit task sets through the runTasks method.
*
* This class can work with multiple types of clusters by acting through a SchedulerBackend.
* It handles common logic, like determining a scheduling order across jobs, waking up to launch
* speculative tasks, etc.
*
* THREADING: SchedulerBackends and task-submitting clients can call this class from multiple
* threads, so it needs locks in public API methods to maintain its state. In addition, some
* SchedulerBackends synchronize on themselves when they want to send events here, and then
* acquire a lock on us, so we need to make sure that we don't try to lock the backend while
* we are holding a lock on ourselves.
*/
private[spark] class ClusterScheduler(val sc: SparkContext)
extends TaskScheduler
with Logging
{
// How often to check for speculative tasks
val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong
// Threshold above which we warn user initial TaskSet may be starved
val STARVATION_TIMEOUT = System.getProperty("spark.starvation.timeout", "15000").toLong
// ClusterTaskSetManagers are not thread safe, so any access to one should be synchronized
// on this class.
val activeTaskSets = new HashMap[String, ClusterTaskSetManager]
val taskIdToTaskSetId = new HashMap[Long, String]
val taskIdToExecutorId = new HashMap[Long, String]
val taskSetTaskIds = new HashMap[String, HashSet[Long]]
@volatile private var hasReceivedTask = false
@volatile private var hasLaunchedTask = false
private val starvationTimer = new Timer(true)
// Incrementing task IDs
val nextTaskId = new AtomicLong(0)
// Which executor IDs we have executors on
val activeExecutorIds = new HashSet[String]
// The set of executors we have on each host; this is used to compute hostsAlive, which
// in turn is used to decide when we can attain data locality on a given host
private val executorsByHost = new HashMap[String, HashSet[String]]
private val executorIdToHost = new HashMap[String, String]
// Listener object to pass upcalls into
var dagScheduler: DAGScheduler = null
var backend: SchedulerBackend = null
val mapOutputTracker = SparkEnv.get.mapOutputTracker
var schedulableBuilder: SchedulableBuilder = null
var rootPool: Pool = null
// default scheduler is FIFO
val schedulingMode: SchedulingMode = SchedulingMode.withName(
System.getProperty("spark.scheduler.mode", "FIFO"))
// This is a var so that we can reset it for testing purposes.
private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this)
override def setDAGScheduler(dagScheduler: DAGScheduler) {
this.dagScheduler = dagScheduler
}
def initialize(context: SchedulerBackend) {
backend = context
// temporarily set rootPool name to empty
rootPool = new Pool("", schedulingMode, 0, 0)
schedulableBuilder = {
schedulingMode match {
case SchedulingMode.FIFO =>
new FIFOSchedulableBuilder(rootPool)
case SchedulingMode.FAIR =>
new FairSchedulableBuilder(rootPool)
}
}
schedulableBuilder.buildPools()
}
def newTaskId(): Long = nextTaskId.getAndIncrement()
override def start() {
backend.start()
if (System.getProperty("spark.speculation", "false").toBoolean) {
new Thread("ClusterScheduler speculation check") {
setDaemon(true)
override def run() {
logInfo("Starting speculative execution thread")
while (true) {
try {
Thread.sleep(SPECULATION_INTERVAL)
} catch {
case e: InterruptedException => {}
}
checkSpeculatableTasks()
}
}
}.start()
}
}
override def submitTasks(taskSet: TaskSet) {
val tasks = taskSet.tasks
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized {
val manager = new ClusterTaskSetManager(this, taskSet)
activeTaskSets(taskSet.id) = manager
schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
taskSetTaskIds(taskSet.id) = new HashSet[Long]()
if (!hasReceivedTask) {
starvationTimer.scheduleAtFixedRate(new TimerTask() {
override def run() {
if (!hasLaunchedTask) {
logWarning("Initial job has not accepted any resources; " +
"check your cluster UI to ensure that workers are registered " +
"and have sufficient memory")
} else {
this.cancel()
}
}
}, STARVATION_TIMEOUT, STARVATION_TIMEOUT)
}
hasReceivedTask = true
}
backend.reviveOffers()
}
override def cancelTasks(stageId: Int): Unit = synchronized {
logInfo("Cancelling stage " + stageId)
activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
// There are two possible cases here:
// 1. The task set manager has been created and some tasks have been scheduled.
// In this case, send a kill signal to the executors to kill the task and then abort
// the stage.
// 2. The task set manager has been created but no tasks have been scheduled. In this case,
// simply abort the stage.
val taskIds = taskSetTaskIds(tsm.taskSet.id)
if (taskIds.size > 0) {
taskIds.foreach { tid =>
val execId = taskIdToExecutorId(tid)
backend.killTask(tid, execId)
}
}
tsm.error("Stage %d was cancelled".format(stageId))
}
}
def taskSetFinished(manager: TaskSetManager): Unit = synchronized {
// Check to see if the given task set has been removed. This is possible in the case of
// multiple unrecoverable task failures (e.g. if the entire task set is killed when it has
// more than one running task).
if (activeTaskSets.contains(manager.taskSet.id)) {
activeTaskSets -= manager.taskSet.id
manager.parent.removeSchedulable(manager)
logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name))
taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id)
taskSetTaskIds.remove(manager.taskSet.id)
}
}
/**
* Called by cluster manager to offer resources on slaves. We respond by asking our active task
* sets for tasks in order of priority. We fill each node with tasks in a round-robin manner so
* that tasks are balanced across the cluster.
*/
def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = synchronized {
SparkEnv.set(sc.env)
// Mark each slave as alive and remember its hostname
for (o <- offers) {
executorIdToHost(o.executorId) = o.host
if (!executorsByHost.contains(o.host)) {
executorsByHost(o.host) = new HashSet[String]()
executorGained(o.executorId, o.host)
}
}
// Build a list of tasks to assign to each worker
val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores))
val availableCpus = offers.map(o => o.cores).toArray
val sortedTaskSets = rootPool.getSortedTaskSetQueue()
for (taskSet <- sortedTaskSets) {
logDebug("parentName: %s, name: %s, runningTasks: %s".format(
taskSet.parent.name, taskSet.name, taskSet.runningTasks))
}
// Take each TaskSet in our scheduling order, and then offer it each node in increasing order
// of locality levels so that it gets a chance to launch local tasks on all of them.
var launchedTask = false
for (taskSet <- sortedTaskSets; maxLocality <- TaskLocality.values) {
do {
launchedTask = false
for (i <- 0 until offers.size) {
val execId = offers(i).executorId
val host = offers(i).host
for (task <- taskSet.resourceOffer(execId, host, availableCpus(i), maxLocality)) {
tasks(i) += task
val tid = task.taskId
taskIdToTaskSetId(tid) = taskSet.taskSet.id
taskSetTaskIds(taskSet.taskSet.id) += tid
taskIdToExecutorId(tid) = execId
activeExecutorIds += execId
executorsByHost(host) += execId
availableCpus(i) -= 1
launchedTask = true
}
}
} while (launchedTask)
}
if (tasks.size > 0) {
hasLaunchedTask = true
}
return tasks
}
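// Worked example (illustrative): given offers [exec1: 2 cores, exec2: 2 cores] and four
// runnable tasks at the current locality level, the do/while loop above launches at most
// one task per offer per sweep -- exec1, exec2, exec1, exec2 -- spreading the tasks
// across executors instead of saturating the first offer.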
def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
var failedExecutor: Option[String] = None
var taskFailed = false
synchronized {
try {
if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) {
// We lost this entire executor, so remember that it's gone
val execId = taskIdToExecutorId(tid)
if (activeExecutorIds.contains(execId)) {
removeExecutor(execId)
failedExecutor = Some(execId)
}
}
taskIdToTaskSetId.get(tid) match {
case Some(taskSetId) =>
if (TaskState.isFinished(state)) {
taskIdToTaskSetId.remove(tid)
if (taskSetTaskIds.contains(taskSetId)) {
taskSetTaskIds(taskSetId) -= tid
}
taskIdToExecutorId.remove(tid)
}
if (state == TaskState.FAILED) {
taskFailed = true
}
activeTaskSets.get(taskSetId).foreach { taskSet =>
if (state == TaskState.FINISHED) {
taskSet.removeRunningTask(tid)
taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
} else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
taskSet.removeRunningTask(tid)
taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
}
}
case None =>
logInfo("Ignoring update from TID " + tid + " because its task set is gone")
}
} catch {
case e: Exception => logError("Exception in statusUpdate", e)
}
}
// Update the DAGScheduler without holding a lock on this, since that can deadlock
if (failedExecutor.isDefined) {
dagScheduler.executorLost(failedExecutor.get)
backend.reviveOffers()
}
if (taskFailed) {
// Also revive offers if a task had failed for some reason other than host lost
backend.reviveOffers()
}
}
def handleTaskGettingResult(taskSetManager: ClusterTaskSetManager, tid: Long) {
taskSetManager.handleTaskGettingResult(tid)
}
def handleSuccessfulTask(
taskSetManager: ClusterTaskSetManager,
tid: Long,
taskResult: DirectTaskResult[_]) = synchronized {
taskSetManager.handleSuccessfulTask(tid, taskResult)
}
def handleFailedTask(
taskSetManager: ClusterTaskSetManager,
tid: Long,
taskState: TaskState,
reason: Option[TaskEndReason]) = synchronized {
taskSetManager.handleFailedTask(tid, taskState, reason)
if (taskState == TaskState.FINISHED) {
// The task finished successfully but the result was lost, so we should revive offers.
backend.reviveOffers()
}
}
def error(message: String) {
synchronized {
if (activeTaskSets.size > 0) {
// Have each task set throw a SparkException with the error
for ((taskSetId, manager) <- activeTaskSets) {
try {
manager.error(message)
} catch {
case e: Exception => logError("Exception in error callback", e)
}
}
} else {
// No task sets are active but we still got an error. Just exit since this
// must mean the error is during registration.
// It might be good to do something smarter here in the future.
logError("Exiting due to error from cluster scheduler: " + message)
System.exit(1)
}
}
}
override def stop() {
if (backend != null) {
backend.stop()
}
if (taskResultGetter != null) {
taskResultGetter.stop()
}
// Sleep for an arbitrary 5 seconds to ensure that messages are sent out.
// TODO: Do something better!
Thread.sleep(5000L)
}
override def defaultParallelism() = backend.defaultParallelism()
// Check for speculatable tasks in all our active jobs.
def checkSpeculatableTasks() {
var shouldRevive = false
synchronized {
shouldRevive = rootPool.checkSpeculatableTasks()
}
if (shouldRevive) {
backend.reviveOffers()
}
}
// Check for pending tasks in all our active jobs.
def hasPendingTasks: Boolean = {
synchronized {
rootPool.hasPendingTasks()
}
}
def executorLost(executorId: String, reason: ExecutorLossReason) {
var failedExecutor: Option[String] = None
synchronized {
if (activeExecutorIds.contains(executorId)) {
val hostPort = executorIdToHost(executorId)
logError("Lost executor %s on %s: %s".format(executorId, hostPort, reason))
removeExecutor(executorId)
failedExecutor = Some(executorId)
} else {
// We may get multiple executorLost() calls with different loss reasons. For example, one
// may be triggered by a dropped connection from the slave while another may be a report
// of executor termination from Mesos. We produce log messages for both so we eventually
// report the termination reason.
logError("Lost an executor " + executorId + " (already removed): " + reason)
}
}
// Call dagScheduler.executorLost without holding the lock on this to prevent deadlock
if (failedExecutor.isDefined) {
dagScheduler.executorLost(failedExecutor.get)
backend.reviveOffers()
}
}
/** Remove an executor from all our data structures and mark it as lost */
private def removeExecutor(executorId: String) {
activeExecutorIds -= executorId
val host = executorIdToHost(executorId)
val execs = executorsByHost.getOrElse(host, new HashSet)
execs -= executorId
if (execs.isEmpty) {
executorsByHost -= host
}
executorIdToHost -= executorId
rootPool.executorLost(executorId, host)
}
def executorGained(execId: String, host: String) {
dagScheduler.executorGained(execId, host)
}
def getExecutorsAliveOnHost(host: String): Option[Set[String]] = synchronized {
executorsByHost.get(host).map(_.toSet)
}
def hasExecutorsAliveOnHost(host: String): Boolean = synchronized {
executorsByHost.contains(host)
}
def isExecutorAlive(execId: String): Boolean = synchronized {
activeExecutorIds.contains(execId)
}
// By default, rack is unknown
def getRackForHost(value: String): Option[String] = None
}
object ClusterScheduler {
/**
* Used to balance containers across hosts.
*
* Accepts a map of hosts to resource offers for that host, and returns a prioritized list of
* resource offers representing the order in which the offers should be used. The resource
* offers are ordered such that we'll allocate one container on each host before allocating a
* second container on any host, and so on, in order to reduce the damage if a host fails.
*
* For example, given <h1, [o1, o2, o3]>, <h2, [o4]>, <h3, [o5, o6]>, returns
* [o1, o5, o4, o2, o6, o3]
*/
def prioritizeContainers[K, T] (map: HashMap[K, ArrayBuffer[T]]): List[T] = {
val _keyList = new ArrayBuffer[K](map.size)
_keyList ++= map.keys
// order keyList based on population of value in map
val keyList = _keyList.sortWith(
(left, right) => map(left).size > map(right).size
)
val retval = new ArrayBuffer[T](keyList.size * 2)
var index = 0
var found = true
while (found) {
found = false
for (key <- keyList) {
// Every key in keyList comes from map.keys, so this lookup always succeeds.
val containerList: ArrayBuffer[T] = map(key)
// Get the index'th entry for this host - if present
if (index < containerList.size) {
retval += containerList.apply(index)
found = true
}
}
index += 1
}
retval.toList
}
}
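A minimal sketch of prioritizeContainers in action; the hosts h1-h3 and containers o1-o6
are hypothetical, mirroring the example in the doc comment above:

import scala.collection.mutable.{ArrayBuffer, HashMap}

val offers = HashMap(
  "h1" -> ArrayBuffer("o1", "o2", "o3"),
  "h2" -> ArrayBuffer("o4"),
  "h3" -> ArrayBuffer("o5", "o6"))

// Hosts are visited in decreasing order of offer count (h1, h3, h2), taking the i-th
// offer from each host on pass i, so no host receives a second container before every
// host has one. Prints: List(o1, o5, o4, o2, o6, o3)
println(ClusterScheduler.prioritizeContainers(offers))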

View file

@ -1,703 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.scheduler.cluster
import java.util.Arrays
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import scala.math.max
import scala.math.min
import org.apache.spark.{ExceptionFailure, FetchFailed, Logging, Resubmitted, SparkEnv,
Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState}
import org.apache.spark.TaskState.TaskState
import org.apache.spark.scheduler._
import org.apache.spark.util.{SystemClock, Clock}
/**
* Schedules the tasks within a single TaskSet in the ClusterScheduler. This class keeps track of
* the status of each task, retries tasks if they fail (up to a limited number of times), and
* handles locality-aware scheduling for this TaskSet via delay scheduling. The main interfaces
* to it are resourceOffer, which asks the TaskSet whether it wants to run a task on one node,
* and statusUpdate, which tells it that one of its tasks changed state (e.g. finished).
*
* THREADING: This class is designed to only be called from code with a lock on the
* ClusterScheduler (e.g. its event handlers). It should not be called from other threads.
*/
private[spark] class ClusterTaskSetManager(
sched: ClusterScheduler,
val taskSet: TaskSet,
clock: Clock = SystemClock)
extends TaskSetManager
with Logging
{
// CPUs to request per task
val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toInt
// Maximum times a task is allowed to fail before failing the job
val MAX_TASK_FAILURES = System.getProperty("spark.task.maxFailures", "4").toInt
// Quantile of tasks at which to start speculation
val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble
val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble
// Serializer for closures and tasks.
val env = SparkEnv.get
val ser = env.closureSerializer.newInstance()
val tasks = taskSet.tasks
val numTasks = tasks.length
val copiesRunning = new Array[Int](numTasks)
val successful = new Array[Boolean](numTasks)
val numFailures = new Array[Int](numTasks)
val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil)
var tasksSuccessful = 0
var weight = 1
var minShare = 0
var priority = taskSet.priority
var stageId = taskSet.stageId
var name = "TaskSet_"+taskSet.stageId.toString
var parent: Pool = null
var runningTasks = 0
private val runningTasksSet = new HashSet[Long]
// Set of pending tasks for each executor. These collections are actually
// treated as stacks, in which new tasks are added to the end of the
// ArrayBuffer and removed from the end. This makes it faster to detect
// tasks that repeatedly fail because whenever a task failed, it is put
// back at the head of the stack. They are also only cleaned up lazily;
// when a task is launched, it remains in all the pending lists except
// the one that it was launched from, but gets removed from them later.
private val pendingTasksForExecutor = new HashMap[String, ArrayBuffer[Int]]
// Set of pending tasks for each host. Similar to pendingTasksForExecutor,
// but at host level.
private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]]
// Set of pending tasks for each rack -- similar to the above.
private val pendingTasksForRack = new HashMap[String, ArrayBuffer[Int]]
// Set containing pending tasks with no locality preferences.
val pendingTasksWithNoPrefs = new ArrayBuffer[Int]
// Set containing all pending tasks (also used as a stack, as above).
val allPendingTasks = new ArrayBuffer[Int]
// Tasks that can be speculated. Since these will be a small fraction of total
// tasks, we'll just hold them in a HashSet.
val speculatableTasks = new HashSet[Int]
// Task index, start and finish time for each task attempt (indexed by task ID)
val taskInfos = new HashMap[Long, TaskInfo]
// Did the TaskSet fail?
var failed = false
var causeOfFailure = ""
// How frequently to reprint duplicate exceptions in full, in milliseconds
val EXCEPTION_PRINT_INTERVAL =
System.getProperty("spark.logging.exceptionPrintInterval", "10000").toLong
// Map of recent exceptions (identified by string representation and top stack frame) to
// duplicate count (how many times the same exception has appeared) and time the full exception
// was printed. This should ideally be an LRU map that can drop old exceptions automatically.
val recentExceptions = HashMap[String, (Int, Long)]()
// Figure out the current map output tracker epoch and set it on all tasks
val epoch = sched.mapOutputTracker.getEpoch
logDebug("Epoch for " + taskSet + ": " + epoch)
for (t <- tasks) {
t.epoch = epoch
}
// Add all our tasks to the pending lists. We do this in reverse order
// of task index so that tasks with low indices get launched first.
for (i <- (0 until numTasks).reverse) {
addPendingTask(i)
}
// Figure out which locality levels we have in our TaskSet, so we can do delay scheduling
val myLocalityLevels = computeValidLocalityLevels()
val localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level
// Delay scheduling variables: we keep track of our current locality level and the time we
// last launched a task at that level, and move up a level when localityWaits[curLevel] expires.
// We then move down if we manage to launch a "more local" task.
var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels
var lastLaunchTime = clock.getTime() // Time we last launched a task at this level
/**
* Add a task to all the pending-task lists that it should be on. If readding is set, we are
* re-adding the task so only include it in each list if it's not already there.
*/
private def addPendingTask(index: Int, readding: Boolean = false) {
// Utility method that adds `index` to a list only if readding=false or it's not already there
def addTo(list: ArrayBuffer[Int]) {
if (!readding || !list.contains(index)) {
list += index
}
}
var hadAliveLocations = false
for (loc <- tasks(index).preferredLocations) {
for (execId <- loc.executorId) {
if (sched.isExecutorAlive(execId)) {
addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer))
hadAliveLocations = true
}
}
if (sched.hasExecutorsAliveOnHost(loc.host)) {
addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer))
for (rack <- sched.getRackForHost(loc.host)) {
addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer))
}
hadAliveLocations = true
}
}
if (!hadAliveLocations) {
// Even though the task might've had preferred locations, all of those hosts or executors
// are dead; put it in the no-prefs list so we can schedule it elsewhere right away.
addTo(pendingTasksWithNoPrefs)
}
if (!readding) {
allPendingTasks += index // No point scanning this whole list to find the old task there
}
}
/**
* Return the pending tasks list for a given executor ID, or an empty list if
* there is no map entry for that host
*/
private def getPendingTasksForExecutor(executorId: String): ArrayBuffer[Int] = {
pendingTasksForExecutor.getOrElse(executorId, ArrayBuffer())
}
/**
* Return the pending tasks list for a given host, or an empty list if
* there is no map entry for that host
*/
private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = {
pendingTasksForHost.getOrElse(host, ArrayBuffer())
}
/**
* Return the pending rack-local task list for a given rack, or an empty list if
* there is no map entry for that rack
*/
private def getPendingTasksForRack(rack: String): ArrayBuffer[Int] = {
pendingTasksForRack.getOrElse(rack, ArrayBuffer())
}
/**
* Dequeue a pending task from the given list and return its index.
* Return None if the list is empty.
* This method also cleans up any tasks in the list that have already
* been launched, since we want that to happen lazily.
*/
private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
while (!list.isEmpty) {
val index = list.last
list.trimEnd(1)
if (copiesRunning(index) == 0 && !successful(index)) {
return Some(index)
}
}
return None
}
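// Illustrative walk-through (not in the original): tasks are pushed in reverse index
// order at construction, so list.last pops index 0 first and low indices launch first.
// A failed task is re-appended by addPendingTask, putting it back on top of the stack,
// so it is retried before untouched tasks and repeated failures surface quickly.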
/** Check whether a task is currently running an attempt on a given host */
private def hasAttemptOnHost(taskIndex: Int, host: String): Boolean = {
taskAttempts(taskIndex).exists(_.host == host)
}
/**
* Return a speculative task for a given executor if any are available. The task should not have
* an attempt running on this host, in case the host is slow. In addition, the task should meet
* the given locality constraint.
*/
private def findSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value)
: Option[(Int, TaskLocality.Value)] =
{
speculatableTasks.retain(index => !successful(index)) // Remove finished tasks from set
if (!speculatableTasks.isEmpty) {
// Check for process-local or preference-less tasks; note that tasks can be process-local
// on multiple nodes when we replicate cached blocks, as in Spark Streaming
for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
val prefs = tasks(index).preferredLocations
val executors = prefs.flatMap(_.executorId)
if (prefs.size == 0 || executors.contains(execId)) {
speculatableTasks -= index
return Some((index, TaskLocality.PROCESS_LOCAL))
}
}
// Check for node-local tasks
if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) {
for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
val locations = tasks(index).preferredLocations.map(_.host)
if (locations.contains(host)) {
speculatableTasks -= index
return Some((index, TaskLocality.NODE_LOCAL))
}
}
}
// Check for rack-local tasks
if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
for (rack <- sched.getRackForHost(host)) {
for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
val racks = tasks(index).preferredLocations.map(_.host).map(sched.getRackForHost)
if (racks.contains(rack)) {
speculatableTasks -= index
return Some((index, TaskLocality.RACK_LOCAL))
}
}
}
}
// Check for non-local tasks
if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
speculatableTasks -= index
return Some((index, TaskLocality.ANY))
}
}
}
return None
}
/**
* Dequeue a pending task for a given node and return its index and locality level.
* Only search for tasks matching the given locality constraint.
*/
private def findTask(execId: String, host: String, locality: TaskLocality.Value)
: Option[(Int, TaskLocality.Value)] =
{
for (index <- findTaskFromList(getPendingTasksForExecutor(execId))) {
return Some((index, TaskLocality.PROCESS_LOCAL))
}
if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) {
for (index <- findTaskFromList(getPendingTasksForHost(host))) {
return Some((index, TaskLocality.NODE_LOCAL))
}
}
if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
for {
rack <- sched.getRackForHost(host)
index <- findTaskFromList(getPendingTasksForRack(rack))
} {
return Some((index, TaskLocality.RACK_LOCAL))
}
}
// Look for no-pref tasks after rack-local tasks since they can run anywhere.
for (index <- findTaskFromList(pendingTasksWithNoPrefs)) {
return Some((index, TaskLocality.PROCESS_LOCAL))
}
if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
for (index <- findTaskFromList(allPendingTasks)) {
return Some((index, TaskLocality.ANY))
}
}
// Finally, if all else has failed, find a speculative task
return findSpeculativeTask(execId, host, locality)
}
/**
* Respond to an offer of a single executor from the scheduler by finding a task
*/
override def resourceOffer(
execId: String,
host: String,
availableCpus: Int,
maxLocality: TaskLocality.TaskLocality)
: Option[TaskDescription] =
{
if (tasksSuccessful < numTasks && availableCpus >= CPUS_PER_TASK) {
val curTime = clock.getTime()
var allowedLocality = getAllowedLocalityLevel(curTime)
if (allowedLocality > maxLocality) {
allowedLocality = maxLocality // We're not allowed to search for farther-away tasks
}
findTask(execId, host, allowedLocality) match {
case Some((index, taskLocality)) => {
// Found a task; do some bookkeeping and return a task description
val task = tasks(index)
val taskId = sched.newTaskId()
// Figure out whether this should count as a preferred launch
logInfo("Starting task %s:%d as TID %s on executor %s: %s (%s)".format(
taskSet.id, index, taskId, execId, host, taskLocality))
// Do various bookkeeping
copiesRunning(index) += 1
val info = new TaskInfo(taskId, index, curTime, execId, host, taskLocality)
taskInfos(taskId) = info
taskAttempts(index) = info :: taskAttempts(index)
// Update our locality level for delay scheduling
currentLocalityIndex = getLocalityIndex(taskLocality)
lastLaunchTime = curTime
// Serialize and return the task
val startTime = clock.getTime()
// We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here
// we assume the task can be serialized without exceptions.
val serializedTask = Task.serializeWithDependencies(
task, sched.sc.addedFiles, sched.sc.addedJars, ser)
val timeTaken = clock.getTime() - startTime
addRunningTask(taskId)
logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
taskSet.id, index, serializedTask.limit, timeTaken))
val taskName = "task %s:%d".format(taskSet.id, index)
if (taskAttempts(index).size == 1)
taskStarted(task, info)
return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask))
}
case _ =>
}
}
return None
}
/**
* Get the level we can launch tasks according to delay scheduling, based on current wait time.
*/
private def getAllowedLocalityLevel(curTime: Long): TaskLocality.TaskLocality = {
while (curTime - lastLaunchTime >= localityWaits(currentLocalityIndex) &&
currentLocalityIndex < myLocalityLevels.length - 1)
{
// Jump to the next locality level, and remove our waiting time for the current one since
// we don't want to count it again on the next one
lastLaunchTime += localityWaits(currentLocalityIndex)
currentLocalityIndex += 1
}
myLocalityLevels(currentLocalityIndex)
}
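// Hypothetical walk-through (not in the original): with myLocalityLevels =
// [PROCESS_LOCAL, NODE_LOCAL, ANY] and localityWaits = [3000, 3000, 0], a task set whose
// last launch was 7000 ms ago advances twice:
//   7000 >= 3000 -> move to NODE_LOCAL, charge the wait (lastLaunchTime += 3000)
//   4000 >= 3000 -> move to ANY
// so the next offer may schedule its tasks anywhere in the cluster.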
/**
* Find the index in myLocalityLevels for a given locality. This is also designed to work with
* localities that are not in myLocalityLevels (in case we somehow get those) by returning the
* next-biggest level we have. Uses the fact that the last value in myLocalityLevels is ANY.
*/
def getLocalityIndex(locality: TaskLocality.TaskLocality): Int = {
var index = 0
while (locality > myLocalityLevels(index)) {
index += 1
}
index
}
private def taskStarted(task: Task[_], info: TaskInfo) {
sched.dagScheduler.taskStarted(task, info)
}
def handleTaskGettingResult(tid: Long) = {
val info = taskInfos(tid)
info.markGettingResult()
sched.dagScheduler.taskGettingResult(tasks(info.index), info)
}
/**
* Marks the task as successful and notifies the DAGScheduler that a task has ended.
*/
def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]) = {
val info = taskInfos(tid)
val index = info.index
info.markSuccessful()
removeRunningTask(tid)
if (!successful(index)) {
logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format(
tid, info.duration, info.host, tasksSuccessful, numTasks))
sched.dagScheduler.taskEnded(
tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
// Mark successful and stop if all the tasks have succeeded.
tasksSuccessful += 1
successful(index) = true
if (tasksSuccessful == numTasks) {
sched.taskSetFinished(this)
}
} else {
logInfo("Ignorning task-finished event for TID " + tid + " because task " +
index + " has already completed successfully")
}
}
/**
* Marks the task as failed, re-adds it to the list of pending tasks, and notifies the
* DAG Scheduler.
*/
def handleFailedTask(tid: Long, state: TaskState, reason: Option[TaskEndReason]) {
val info = taskInfos(tid)
if (info.failed) {
return
}
removeRunningTask(tid)
val index = info.index
info.markFailed()
if (!successful(index)) {
logWarning("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
copiesRunning(index) -= 1
// Check if the problem is a map output fetch failure. In that case, this
// task will never succeed on any node, so tell the scheduler about it.
reason.foreach {
case fetchFailed: FetchFailed =>
logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress)
sched.dagScheduler.taskEnded(tasks(index), fetchFailed, null, null, info, null)
successful(index) = true
tasksSuccessful += 1
sched.taskSetFinished(this)
removeAllRunningTasks()
return
case TaskKilled =>
logWarning("Task %d was killed.".format(tid))
sched.dagScheduler.taskEnded(tasks(index), reason.get, null, null, info, null)
return
case ef: ExceptionFailure =>
sched.dagScheduler.taskEnded(
tasks(index), ef, null, null, info, ef.metrics.getOrElse(null))
val key = ef.description
val now = clock.getTime()
val (printFull, dupCount) = {
if (recentExceptions.contains(key)) {
val (dupCount, printTime) = recentExceptions(key)
if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
recentExceptions(key) = (0, now)
(true, 0)
} else {
recentExceptions(key) = (dupCount + 1, printTime)
(false, dupCount + 1)
}
} else {
recentExceptions(key) = (0, now)
(true, 0)
}
}
if (printFull) {
val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString))
logWarning("Loss was due to %s\n%s\n%s".format(
ef.className, ef.description, locs.mkString("\n")))
} else {
logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount))
}
case TaskResultLost =>
logWarning("Lost result for TID %s on host %s".format(tid, info.host))
sched.dagScheduler.taskEnded(tasks(index), TaskResultLost, null, null, info, null)
case _ => {}
}
// On non-fetch failures, re-enqueue the task as pending for a max number of retries
addPendingTask(index)
if (state != TaskState.KILLED) {
numFailures(index) += 1
if (numFailures(index) > MAX_TASK_FAILURES) {
logError("Task %s:%d failed more than %d times; aborting job".format(
taskSet.id, index, MAX_TASK_FAILURES))
abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES))
}
}
} else {
logInfo("Ignoring task-lost event for TID " + tid +
" because task " + index + " is already finished")
}
}
override def error(message: String) {
// Save the error message
abort("Error: " + message)
}
def abort(message: String) {
failed = true
causeOfFailure = message
// TODO: Kill running tasks if we were not terminated due to a Mesos error
sched.dagScheduler.taskSetFailed(taskSet, message)
removeAllRunningTasks()
sched.taskSetFinished(this)
}
/** If the given task ID is not in the set of running tasks, adds it.
*
* Used to keep track of the number of running tasks, for enforcing scheduling policies.
*/
def addRunningTask(tid: Long) {
if (runningTasksSet.add(tid) && parent != null) {
parent.increaseRunningTasks(1)
}
runningTasks = runningTasksSet.size
}
/** If the given task ID is in the set of running tasks, removes it. */
def removeRunningTask(tid: Long) {
if (runningTasksSet.remove(tid) && parent != null) {
parent.decreaseRunningTasks(1)
}
runningTasks = runningTasksSet.size
}
private def removeAllRunningTasks() {
val numRunningTasks = runningTasksSet.size
runningTasksSet.clear()
if (parent != null) {
parent.decreaseRunningTasks(numRunningTasks)
}
runningTasks = 0
}
override def getSchedulableByName(name: String): Schedulable = {
return null
}
override def addSchedulable(schedulable: Schedulable) {}
override def removeSchedulable(schedulable: Schedulable) {}
override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
// A leaf schedulable's sorted queue contains just itself.
ArrayBuffer[TaskSetManager](this)
}
/** Called by cluster scheduler when an executor is lost so we can re-enqueue our tasks */
override def executorLost(execId: String, host: String) {
logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id)
// Re-enqueue pending tasks for this host based on the status of the cluster -- for example, a
// task that used to have locations on only this host might now go to the no-prefs list. Note
// that it's okay if we add a task to the same queue twice (if it had multiple preferred
// locations), because findTaskFromList will skip already-running tasks.
for (index <- getPendingTasksForExecutor(execId)) {
addPendingTask(index, readding=true)
}
for (index <- getPendingTasksForHost(host)) {
addPendingTask(index, readding=true)
}
// Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage
if (tasks(0).isInstanceOf[ShuffleMapTask]) {
for ((tid, info) <- taskInfos if info.executorId == execId) {
val index = taskInfos(tid).index
if (successful(index)) {
successful(index) = false
copiesRunning(index) -= 1
tasksSuccessful -= 1
addPendingTask(index)
// Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
// stage finishes when a total of tasks.size tasks finish.
sched.dagScheduler.taskEnded(tasks(index), Resubmitted, null, null, info, null)
}
}
}
// Also re-enqueue any tasks that were running on the node
for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
handleFailedTask(tid, TaskState.KILLED, None)
}
}
/**
* Check for tasks to be speculated and return true if there are any. This is called periodically
* by the ClusterScheduler.
*
* TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that
* we don't scan the whole task set. It might also help to make this sorted by launch time.
*/
override def checkSpeculatableTasks(): Boolean = {
// Can't speculate if we only have one task, or if all tasks have finished.
if (numTasks == 1 || tasksSuccessful == numTasks) {
return false
}
var foundTasks = false
val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) {
val time = clock.getTime()
val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
Arrays.sort(durations)
val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.size - 1))
val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100)
// TODO: Threshold should also look at standard deviation of task durations and have a lower
// bound based on that.
logDebug("Task length threshold for speculation: " + threshold)
for ((tid, info) <- taskInfos) {
val index = info.index
if (!successful(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
!speculatableTasks.contains(index)) {
logInfo(
"Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format(
taskSet.id, index, info.host, threshold))
speculatableTasks += index
foundTasks = true
}
}
}
return foundTasks
}
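// Worked example (illustrative): with 10 tasks, SPECULATION_QUANTILE = 0.75 and
// SPECULATION_MULTIPLIER = 1.5, speculation starts once floor(0.75 * 10) = 7 tasks have
// succeeded; if the median successful duration is 400 ms, any task with a single copy
// that has been running longer than max(1.5 * 400, 100) = 600 ms is marked speculatable.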
override def hasPendingTasks(): Boolean = {
numTasks > 0 && tasksSuccessful < numTasks
}
private def getLocalityWait(level: TaskLocality.TaskLocality): Long = {
val defaultWait = System.getProperty("spark.locality.wait", "3000")
level match {
case TaskLocality.PROCESS_LOCAL =>
System.getProperty("spark.locality.wait.process", defaultWait).toLong
case TaskLocality.NODE_LOCAL =>
System.getProperty("spark.locality.wait.node", defaultWait).toLong
case TaskLocality.RACK_LOCAL =>
System.getProperty("spark.locality.wait.rack", defaultWait).toLong
case TaskLocality.ANY =>
0L
}
}
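// Illustrative configuration (not in the original): wait up to 5 s for a node-local slot
// but never for a rack-local one; note that a zero wait also drops RACK_LOCAL from
// myLocalityLevels in computeValidLocalityLevels below.
//   System.setProperty("spark.locality.wait.node", "5000")
//   System.setProperty("spark.locality.wait.rack", "0")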
/**
* Compute the locality levels used in this TaskSet. Assumes that all tasks have already been
* added to queues using addPendingTask.
*/
private def computeValidLocalityLevels(): Array[TaskLocality.TaskLocality] = {
import TaskLocality.{PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY}
val levels = new ArrayBuffer[TaskLocality.TaskLocality]
if (!pendingTasksForExecutor.isEmpty && getLocalityWait(PROCESS_LOCAL) != 0) {
levels += PROCESS_LOCAL
}
if (!pendingTasksForHost.isEmpty && getLocalityWait(NODE_LOCAL) != 0) {
levels += NODE_LOCAL
}
if (!pendingTasksForRack.isEmpty && getLocalityWait(RACK_LOCAL) != 0) {
levels += RACK_LOCAL
}
levels += ANY
logDebug("Valid locality levels for " + taskSet + ": " + levels.mkString(", "))
levels.toArray
}
}

View file

@ -29,7 +29,8 @@ import akka.util.Duration
 import akka.util.duration._
 import org.apache.spark.{SparkException, Logging, TaskState}
-import org.apache.spark.scheduler.TaskDescription
+import org.apache.spark.scheduler.{SchedulerBackend, SlaveLost, TaskDescription, TaskScheduler,
+  WorkerOffer}
 import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
 import org.apache.spark.util.Utils
@ -42,7 +43,7 @@ import org.apache.spark.util.Utils
  * (spark.deploy.*).
  */
 private[spark]
-class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: ActorSystem)
+class CoarseGrainedSchedulerBackend(scheduler: TaskScheduler, actorSystem: ActorSystem)
   extends SchedulerBackend with Logging
 {
   // Use an atomic variable to track total number of cores in the cluster for simplicity and speed

View file

@ -19,10 +19,12 @@ package org.apache.spark.scheduler.cluster
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{Path, FileSystem}
 import org.apache.spark.{Logging, SparkContext}
+import org.apache.spark.scheduler.TaskScheduler
 private[spark] class SimrSchedulerBackend(
-    scheduler: ClusterScheduler,
+    scheduler: TaskScheduler,
     sc: SparkContext,
     driverFilePath: String)
   extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem)

View file

@ -17,14 +17,16 @@
 package org.apache.spark.scheduler.cluster
+import scala.collection.mutable.HashMap
 import org.apache.spark.{Logging, SparkContext}
 import org.apache.spark.deploy.client.{Client, ClientListener}
 import org.apache.spark.deploy.{Command, ApplicationDescription}
-import scala.collection.mutable.HashMap
+import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, TaskScheduler}
 import org.apache.spark.util.Utils
 private[spark] class SparkDeploySchedulerBackend(
-    scheduler: ClusterScheduler,
+    scheduler: TaskScheduler,
     sc: SparkContext,
     masters: Array[String],
     appName: String)

View file

@ -30,7 +30,8 @@ import org.apache.mesos._
 import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _}
 import org.apache.spark.{SparkException, Logging, SparkContext, TaskState}
-import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedulerBackend}
+import org.apache.spark.scheduler.TaskScheduler
+import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
 /**
  * A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds
@ -43,7 +44,7 @@ import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedulerBackend}
  * remove this.
  */
 private[spark] class CoarseMesosSchedulerBackend(
-    scheduler: ClusterScheduler,
+    scheduler: TaskScheduler,
     sc: SparkContext,
     master: String,
     appName: String)

View file

@ -30,9 +30,8 @@ import org.apache.mesos._
 import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _}
 import org.apache.spark.{Logging, SparkException, SparkContext, TaskState}
-import org.apache.spark.scheduler.TaskDescription
-import org.apache.spark.scheduler.cluster.{ClusterScheduler, ExecutorExited, ExecutorLossReason}
-import org.apache.spark.scheduler.cluster.{SchedulerBackend, SlaveLost, WorkerOffer}
+import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SchedulerBackend, SlaveLost,
+  TaskDescription, TaskScheduler, WorkerOffer}
 import org.apache.spark.util.Utils
 /**
@ -41,7 +40,7 @@ import org.apache.spark.util.Utils
  * from multiple apps can run on different cores) and in time (a core can switch ownership).
  */
 private[spark] class MesosSchedulerBackend(
-    scheduler: ClusterScheduler,
+    scheduler: TaskScheduler,
     sc: SparkContext,
     master: String,
     appName: String)
@ -210,7 +209,7 @@ private[spark] class MesosSchedulerBackend(
         getResource(offer.getResourcesList, "cpus").toInt)
     }
-    // Call into the ClusterScheduler
+    // Call into the TaskScheduler
     val taskLists = scheduler.resourceOffers(offerableWorkers)
     // Build a list of Mesos tasks for each slave

View file

@ -0,0 +1,73 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.scheduler.local
import java.nio.ByteBuffer
import akka.actor.{Actor, ActorRef, Props}
import org.apache.spark.{SparkContext, SparkEnv, TaskState}
import org.apache.spark.TaskState.TaskState
import org.apache.spark.executor.{Executor, ExecutorBackend}
import org.apache.spark.scheduler.{SchedulerBackend, TaskScheduler, WorkerOffer}
/**
* LocalBackend sits behind a TaskScheduler and handles launching tasks on a single Executor
* (created by the LocalBackend) running locally.
*
* THREADING: Because methods can be called both from the Executor and the TaskScheduler, and
* because the Executor class is not thread safe, all methods are synchronized.
*/
private[spark] class LocalBackend(scheduler: TaskScheduler, private val totalCores: Int)
extends SchedulerBackend with ExecutorBackend {
private var freeCores = totalCores
private val localExecutorId = "localhost"
private val localExecutorHostname = "localhost"
val executor = new Executor(localExecutorId, localExecutorHostname, Seq.empty, isLocal = true)
override def start() {
}
override def stop() {
}
override def reviveOffers() = synchronized {
val offers = Seq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores))
for (task <- scheduler.resourceOffers(offers).flatten) {
freeCores -= 1
executor.launchTask(this, task.taskId, task.serializedTask)
}
}
override def defaultParallelism() = totalCores
override def killTask(taskId: Long, executorId: String) = synchronized {
executor.killTask(taskId)
}
override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) = synchronized {
scheduler.statusUpdate(taskId, state, serializedData)
if (TaskState.isFinished(state)) {
freeCores += 1
reviveOffers()
}
}
}
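// Lifecycle sketch (illustrative, not in the original): reviveOffers() offers all free
// cores to the TaskScheduler and launches whatever it returns on the single local
// Executor; each finished task flows back through statusUpdate(), which frees the core
// and calls reviveOffers() again, keeping the executor busy until no tasks remain.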

View file

@ -1,219 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.scheduler.local
import java.nio.ByteBuffer
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
import akka.actor._
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.executor.{Executor, ExecutorBackend}
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
/**
* A FIFO or Fair TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
* the scheduler also allows each task to fail up to maxFailures times, which is useful for
* testing fault recovery.
*/
private[local]
case class LocalReviveOffers()
private[local]
case class LocalStatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer)
private[local]
case class KillTask(taskId: Long)
private[spark]
class LocalActor(localScheduler: LocalScheduler, private var freeCores: Int)
extends Actor with Logging {
val executor = new Executor("localhost", "localhost", Seq.empty, isLocal = true)
def receive = {
case LocalReviveOffers =>
launchTask(localScheduler.resourceOffer(freeCores))
case LocalStatusUpdate(taskId, state, serializedData) =>
if (TaskState.isFinished(state)) {
freeCores += 1
launchTask(localScheduler.resourceOffer(freeCores))
}
case KillTask(taskId) =>
executor.killTask(taskId)
}
private def launchTask(tasks: Seq[TaskDescription]) {
for (task <- tasks) {
freeCores -= 1
executor.launchTask(localScheduler, task.taskId, task.serializedTask)
}
}
}
private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext)
extends TaskScheduler
with ExecutorBackend
with Logging {
val env = SparkEnv.get
val attemptId = new AtomicInteger
var dagScheduler: DAGScheduler = null
// Application dependencies (added through SparkContext) that we've fetched so far on this node.
// Each map holds the master's timestamp for the version of that file or JAR we got.
val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
val currentJars: HashMap[String, Long] = new HashMap[String, Long]()
var schedulableBuilder: SchedulableBuilder = null
var rootPool: Pool = null
val schedulingMode: SchedulingMode = SchedulingMode.withName(
System.getProperty("spark.scheduler.mode", "FIFO"))
val activeTaskSets = new HashMap[String, LocalTaskSetManager]
val taskIdToTaskSetId = new HashMap[Long, String]
val taskSetTaskIds = new HashMap[String, HashSet[Long]]
var localActor: ActorRef = null
override def start() {
// temporarily set rootPool name to empty
rootPool = new Pool("", schedulingMode, 0, 0)
schedulableBuilder = {
schedulingMode match {
case SchedulingMode.FIFO =>
new FIFOSchedulableBuilder(rootPool)
case SchedulingMode.FAIR =>
new FairSchedulableBuilder(rootPool)
}
}
schedulableBuilder.buildPools()
localActor = env.actorSystem.actorOf(Props(new LocalActor(this, threads)), "Test")
}
override def setDAGScheduler(dagScheduler: DAGScheduler) {
this.dagScheduler = dagScheduler
}
override def submitTasks(taskSet: TaskSet) {
synchronized {
val manager = new LocalTaskSetManager(this, taskSet)
schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
activeTaskSets(taskSet.id) = manager
taskSetTaskIds(taskSet.id) = new HashSet[Long]()
localActor ! LocalReviveOffers
}
}
override def cancelTasks(stageId: Int): Unit = synchronized {
logInfo("Cancelling stage " + stageId)
logInfo("Cancelling stage " + activeTaskSets.map(_._2.stageId))
activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
// There are two possible cases here:
// 1. The task set manager has been created and some tasks have been scheduled.
// In this case, send a kill signal to the executors to kill the task and then abort
// the stage.
// 2. The task set manager has been created but no tasks have been scheduled. In this case,
// simply abort the stage.
val taskIds = taskSetTaskIds(tsm.taskSet.id)
if (taskIds.size > 0) {
taskIds.foreach { tid =>
localActor ! KillTask(tid)
}
}
tsm.error("Stage %d was cancelled".format(stageId))
}
}
def resourceOffer(freeCores: Int): Seq[TaskDescription] = {
synchronized {
var freeCpuCores = freeCores
val tasks = new ArrayBuffer[TaskDescription](freeCores)
val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue()
for (manager <- sortedTaskSetQueue) {
logDebug("parentName:%s,name:%s,runningTasks:%s".format(
manager.parent.name, manager.name, manager.runningTasks))
}
var launchTask = false
for (manager <- sortedTaskSetQueue) {
do {
launchTask = false
manager.resourceOffer(null, null, freeCpuCores, null) match {
case Some(task) =>
tasks += task
taskIdToTaskSetId(task.taskId) = manager.taskSet.id
taskSetTaskIds(manager.taskSet.id) += task.taskId
freeCpuCores -= 1
launchTask = true
case None => {}
}
} while (launchTask)
}
return tasks
}
}
def taskSetFinished(manager: TaskSetManager) {
synchronized {
activeTaskSets -= manager.taskSet.id
manager.parent.removeSchedulable(manager)
logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name))
taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
taskSetTaskIds -= manager.taskSet.id
}
}
override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) {
if (TaskState.isFinished(state)) {
synchronized {
taskIdToTaskSetId.get(taskId) match {
case Some(taskSetId) =>
val taskSetManager = activeTaskSets(taskSetId)
taskSetTaskIds(taskSetId) -= taskId
state match {
case TaskState.FINISHED =>
taskSetManager.taskEnded(taskId, state, serializedData)
case TaskState.FAILED =>
taskSetManager.taskFailed(taskId, state, serializedData)
case TaskState.KILLED =>
taskSetManager.error("Task %d was killed".format(taskId))
case _ => {}
}
case None =>
logInfo("Ignoring update from TID " + taskId + " because its task set is gone")
}
}
localActor ! LocalStatusUpdate(taskId, state, serializedData)
}
}
override def stop() {
}
override def defaultParallelism() = threads
}

View file

@ -1,191 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.scheduler.local
import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import org.apache.spark.{ExceptionFailure, Logging, SparkEnv, SparkException, Success, TaskState}
import org.apache.spark.TaskState.TaskState
import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Pool, Schedulable, Task,
TaskDescription, TaskInfo, TaskLocality, TaskResult, TaskSet, TaskSetManager}
private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: TaskSet)
extends TaskSetManager with Logging {
var parent: Pool = null
var weight: Int = 1
var minShare: Int = 0
var runningTasks: Int = 0
var priority: Int = taskSet.priority
var stageId: Int = taskSet.stageId
var name: String = "TaskSet_" + taskSet.stageId.toString
var failCount = new Array[Int](taskSet.tasks.size)
val taskInfos = new HashMap[Long, TaskInfo]
val numTasks = taskSet.tasks.size
var numFinished = 0
val env = SparkEnv.get
val ser = env.closureSerializer.newInstance()
val copiesRunning = new Array[Int](numTasks)
val finished = new Array[Boolean](numTasks)
val numFailures = new Array[Int](numTasks)
val MAX_TASK_FAILURES = sched.maxFailures
def increaseRunningTasks(taskNum: Int): Unit = {
runningTasks += taskNum
if (parent != null) {
parent.increaseRunningTasks(taskNum)
}
}
def decreaseRunningTasks(taskNum: Int): Unit = {
runningTasks -= taskNum
if (parent != null) {
parent.decreaseRunningTasks(taskNum)
}
}
override def addSchedulable(schedulable: Schedulable): Unit = {
// nothing
}
override def removeSchedulable(schedulable: Schedulable): Unit = {
// nothing
}
override def getSchedulableByName(name: String): Schedulable = {
return null
}
override def executorLost(executorId: String, host: String): Unit = {
// nothing
}
override def checkSpeculatableTasks() = true
override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]
sortedTaskSetQueue += this
return sortedTaskSetQueue
}
override def hasPendingTasks() = true
def findTask(): Option[Int] = {
for (i <- 0 until numTasks) {
if (copiesRunning(i) == 0 && !finished(i)) {
return Some(i)
}
}
return None
}
override def resourceOffer(
execId: String,
host: String,
availableCpus: Int,
maxLocality: TaskLocality.TaskLocality)
: Option[TaskDescription] =
{
SparkEnv.set(sched.env)
logDebug("availableCpus:%d, numFinished:%d, numTasks:%d".format(
availableCpus.toInt, numFinished, numTasks))
if (availableCpus > 0 && numFinished < numTasks) {
findTask() match {
case Some(index) =>
val taskId = sched.attemptId.getAndIncrement()
val task = taskSet.tasks(index)
val info = new TaskInfo(taskId, index, System.currentTimeMillis(), "local", "local:1",
TaskLocality.NODE_LOCAL)
taskInfos(taskId) = info
// We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here
// we assume the task can be serialized without exceptions.
val bytes = Task.serializeWithDependencies(
task, sched.sc.addedFiles, sched.sc.addedJars, ser)
logInfo("Size of task " + taskId + " is " + bytes.limit + " bytes")
val taskName = "task %s:%d".format(taskSet.id, index)
copiesRunning(index) += 1
increaseRunningTasks(1)
taskStarted(task, info)
return Some(new TaskDescription(taskId, null, taskName, index, bytes))
case None => {}
}
}
return None
}
def taskStarted(task: Task[_], info: TaskInfo) {
sched.dagScheduler.taskStarted(task, info)
}
def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) {
val info = taskInfos(tid)
val index = info.index
val task = taskSet.tasks(index)
info.markSuccessful()
val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader) match {
case directResult: DirectTaskResult[_] => directResult
case IndirectTaskResult(blockId) => {
throw new SparkException("Expect only DirectTaskResults when using LocalScheduler")
}
}
result.metrics.resultSize = serializedData.limit()
sched.dagScheduler.taskEnded(task, Success, result.value, result.accumUpdates, info,
result.metrics)
numFinished += 1
decreaseRunningTasks(1)
finished(index) = true
if (numFinished == numTasks) {
sched.taskSetFinished(this)
}
}
def taskFailed(tid: Long, state: TaskState, serializedData: ByteBuffer) {
val info = taskInfos(tid)
val index = info.index
val task = taskSet.tasks(index)
info.markFailed()
decreaseRunningTasks(1)
val reason: ExceptionFailure = ser.deserialize[ExceptionFailure](
serializedData, getClass.getClassLoader)
sched.dagScheduler.taskEnded(task, reason, null, null, info, reason.metrics.getOrElse(null))
if (!finished(index)) {
copiesRunning(index) -= 1
numFailures(index) += 1
val locs = reason.stackTrace.map(loc => "\tat %s".format(loc.toString))
logInfo("Loss was due to %s\n%s\n%s".format(
reason.className, reason.description, locs.mkString("\n")))
if (numFailures(index) > MAX_TASK_FAILURES) {
val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format(
taskSet.id, index, MAX_TASK_FAILURES, reason.description)
decreaseRunningTasks(runningTasks)
sched.dagScheduler.taskSetFailed(taskSet, errorMessage)
// need to delete failed Taskset from schedule queue
sched.taskSetFinished(this)
}
}
}
override def error(message: String) {
sched.dagScheduler.taskSetFailed(taskSet, message)
sched.taskSetFinished(this)
}
}

View file

@@ -17,7 +17,7 @@
 package org.apache.spark

-import org.scalatest.FunSuite
+import org.scalatest.{BeforeAndAfterAll, FunSuite}

 import SparkContext._
 import org.apache.spark.util.NonSerializable
@@ -37,12 +37,20 @@ object FailureSuiteState {
   }
 }

-class FailureSuite extends FunSuite with LocalSparkContext {
+class FailureSuite extends FunSuite with LocalSparkContext with BeforeAndAfterAll {
+  override def beforeAll {
+    System.setProperty("spark.task.maxFailures", "1")
+  }
+
+  override def afterAll {
+    System.clearProperty("spark.task.maxFailures")
+  }
+
   // Run a 3-task map job in which task 1 deterministically fails once, and check
   // whether the job completes successfully and we ran 4 tasks in total.
   test("failure in a single-stage job") {
-    sc = new SparkContext("local[1,1]", "test")
+    sc = new SparkContext("local[1]", "test")
     val results = sc.makeRDD(1 to 3, 3).map { x =>
       FailureSuiteState.synchronized {
         FailureSuiteState.tasksRun += 1
@@ -62,7 +70,7 @@ class FailureSuite extends FunSuite with LocalSparkContext {
   // Run a map-reduce job in which a reduce task deterministically fails once.
   test("failure in a two-stage job") {
-    sc = new SparkContext("local[1,1]", "test")
+    sc = new SparkContext("local[1]", "test")
     val results = sc.makeRDD(1 to 3).map(x => (x, x)).groupByKey(3).map {
       case (k, v) =>
         FailureSuiteState.synchronized {
@@ -82,7 +90,7 @@ class FailureSuite extends FunSuite with LocalSparkContext {
   }

   test("failure because task results are not serializable") {
-    sc = new SparkContext("local[1,1]", "test")
+    sc = new SparkContext("local[1]", "test")
     val results = sc.makeRDD(1 to 3).map(x => new NonSerializable)

     val thrown = intercept[SparkException] {
@@ -95,7 +103,7 @@ class FailureSuite extends FunSuite with LocalSparkContext {
   }

   test("failure because task closure is not serializable") {
-    sc = new SparkContext("local[1,1]", "test")
+    sc = new SparkContext("local[1]", "test")
     val a = new NonSerializable

     // Non-serializable closure in the final result stage
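
All four hunks above make the same substitution: with the local[N, maxFailures] master format gone, retries are configured through the spark.task.maxFailures property at suite scope. A minimal sketch of that convention, with an illustrative suite name:

import org.scalatest.{BeforeAndAfterAll, FunSuite}

class MaxFailuresExample extends FunSuite with BeforeAndAfterAll {
  override def beforeAll() {
    // One attempt per task: a failing task now fails the job immediately.
    System.setProperty("spark.task.maxFailures", "1")
  }

  override def afterAll() {
    // Clear the property so it does not leak into other suites.
    System.clearProperty("spark.task.maxFailures")
  }

  // test(...) bodies would create a plain "local[1]" context as in the diff above.
}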

View file

@@ -19,23 +19,26 @@ package org.apache.spark.scheduler
 import scala.collection.mutable.{Buffer, HashSet}

-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
 import org.scalatest.matchers.ShouldMatchers

 import org.apache.spark.{LocalSparkContext, SparkContext}
 import org.apache.spark.SparkContext._

 class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
-  with BeforeAndAfterAll {
+  with BeforeAndAfter with BeforeAndAfterAll {

   /** Length of time to wait while draining listener events. */
   val WAIT_TIMEOUT_MILLIS = 10000

+  before {
+    sc = new SparkContext("local", "SparkListenerSuite")
+  }
+
   override def afterAll {
     System.clearProperty("spark.akka.frameSize")
   }

   test("basic creation of StageInfo") {
-    sc = new SparkContext("local", "DAGSchedulerSuite")
     val listener = new SaveStageInfo
     sc.addSparkListener(listener)

     val rdd1 = sc.parallelize(1 to 100, 4)
@@ -56,7 +59,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
   }

   test("StageInfo with fewer tasks than partitions") {
-    sc = new SparkContext("local", "DAGSchedulerSuite")
     val listener = new SaveStageInfo
     sc.addSparkListener(listener)

     val rdd1 = sc.parallelize(1 to 100, 4)
@@ -72,7 +74,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
   }

   test("local metrics") {
-    sc = new SparkContext("local", "DAGSchedulerSuite")
     val listener = new SaveStageInfo
     sc.addSparkListener(listener)
     sc.addSparkListener(new StatsReportListener)
@@ -135,10 +136,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
   }

   test("onTaskGettingResult() called when result fetched remotely") {
-    // Need to use local cluster mode here, because results are not ever returned through the
-    // block manager when using the LocalScheduler.
-    sc = new SparkContext("local-cluster[1,1,512]", "test")
     val listener = new SaveTaskEvents
     sc.addSparkListener(listener)
@@ -157,10 +154,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
   }

   test("onTaskGettingResult() not called when result sent directly") {
-    // Need to use local cluster mode here, because results are not ever returned through the
-    // block manager when using the LocalScheduler.
-    sc = new SparkContext("local-cluster[1,1,512]", "test")
     val listener = new SaveTaskEvents
     sc.addSparkListener(listener)
View file

@@ -66,9 +66,7 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndA
   }

   before {
-    // Use local-cluster mode because results are returned differently when running with the
-    // LocalScheduler.
-    sc = new SparkContext("local-cluster[1,1,512]", "test")
+    sc = new SparkContext("local", "test")
   }

   override def afterAll {

View file

@@ -1,227 +0,0 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.scheduler.local

import java.util.concurrent.Semaphore
import java.util.concurrent.CountDownLatch

import scala.collection.mutable.HashMap

import org.scalatest.{BeforeAndAfterEach, FunSuite}

import org.apache.spark._

class Lock() {
  var finished = false

  def jobWait() = {
    synchronized {
      while (!finished) {
        this.wait()
      }
    }
  }

  def jobFinished() = {
    synchronized {
      finished = true
      this.notifyAll()
    }
  }
}

object TaskThreadInfo {
  val threadToLock = HashMap[Int, Lock]()
  val threadToRunning = HashMap[Int, Boolean]()
  val threadToStarted = HashMap[Int, CountDownLatch]()
}

/*
 * 1. Each thread runs one job.
 * 2. Each job contains one stage.
 * 3. Each stage contains only one task.
 * 4. Launched tasks must start in order (enforced via threadToStarted) to make sure each
 *    gets a CPU core; a task then waits until the test manually releases its "Lock",
 *    at which point the cluster has a free core again.
 * 5. Pending tasks must use "sleep" to make sure they have been added to the taskSetManager
 *    queue, so that they are scheduled later, once the cluster has free cores.
 */
class LocalSchedulerSuite extends FunSuite with LocalSparkContext with BeforeAndAfterEach {

  override def afterEach() {
    super.afterEach()
    System.clearProperty("spark.scheduler.mode")
  }

  def createThread(threadIndex: Int, poolName: String, sc: SparkContext, sem: Semaphore) {
    TaskThreadInfo.threadToRunning(threadIndex) = false
    val nums = sc.parallelize(threadIndex to threadIndex, 1)
    TaskThreadInfo.threadToLock(threadIndex) = new Lock()
    TaskThreadInfo.threadToStarted(threadIndex) = new CountDownLatch(1)
    new Thread {
      if (poolName != null) {
        sc.setLocalProperty("spark.scheduler.pool", poolName)
      }
      override def run() {
        val ans = nums.map(number => {
          TaskThreadInfo.threadToRunning(number) = true
          TaskThreadInfo.threadToStarted(number).countDown()
          TaskThreadInfo.threadToLock(number).jobWait()
          TaskThreadInfo.threadToRunning(number) = false
          number
        }).collect()
        assert(ans.toList === List(threadIndex))
        sem.release()
      }
    }.start()
  }

  test("Local FIFO scheduler end-to-end test") {
    System.setProperty("spark.scheduler.mode", "FIFO")
    sc = new SparkContext("local[4]", "test")
    val sem = new Semaphore(0)

    createThread(1, null, sc, sem)
    TaskThreadInfo.threadToStarted(1).await()
    createThread(2, null, sc, sem)
    TaskThreadInfo.threadToStarted(2).await()
    createThread(3, null, sc, sem)
    TaskThreadInfo.threadToStarted(3).await()
    createThread(4, null, sc, sem)
    TaskThreadInfo.threadToStarted(4).await()
    // Threads 5 and 6 (whose stages are pending) must satisfy two conditions:
    // 1. The stages (taskSetManagers) of the jobs in threads 5 and 6 are added to the
    //    taskSetManager queue before TaskThreadInfo.threadToLock(1).jobFinished() runs.
    // 2. The stage in thread 5 has higher priority than the stage in thread 6.
    // A 1-second sleep per thread is used to guarantee this.
    // TODO: any better solution?
    createThread(5, null, sc, sem)
    Thread.sleep(1000)
    createThread(6, null, sc, sem)
    Thread.sleep(1000)

    assert(TaskThreadInfo.threadToRunning(1) === true)
    assert(TaskThreadInfo.threadToRunning(2) === true)
    assert(TaskThreadInfo.threadToRunning(3) === true)
    assert(TaskThreadInfo.threadToRunning(4) === true)
    assert(TaskThreadInfo.threadToRunning(5) === false)
    assert(TaskThreadInfo.threadToRunning(6) === false)

    TaskThreadInfo.threadToLock(1).jobFinished()
    TaskThreadInfo.threadToStarted(5).await()

    assert(TaskThreadInfo.threadToRunning(1) === false)
    assert(TaskThreadInfo.threadToRunning(2) === true)
    assert(TaskThreadInfo.threadToRunning(3) === true)
    assert(TaskThreadInfo.threadToRunning(4) === true)
    assert(TaskThreadInfo.threadToRunning(5) === true)
    assert(TaskThreadInfo.threadToRunning(6) === false)

    TaskThreadInfo.threadToLock(3).jobFinished()
    TaskThreadInfo.threadToStarted(6).await()

    assert(TaskThreadInfo.threadToRunning(1) === false)
    assert(TaskThreadInfo.threadToRunning(2) === true)
    assert(TaskThreadInfo.threadToRunning(3) === false)
    assert(TaskThreadInfo.threadToRunning(4) === true)
    assert(TaskThreadInfo.threadToRunning(5) === true)
    assert(TaskThreadInfo.threadToRunning(6) === true)

    TaskThreadInfo.threadToLock(2).jobFinished()
    TaskThreadInfo.threadToLock(4).jobFinished()
    TaskThreadInfo.threadToLock(5).jobFinished()
    TaskThreadInfo.threadToLock(6).jobFinished()

    sem.acquire(6)
  }

  test("Local fair scheduler end-to-end test") {
    System.setProperty("spark.scheduler.mode", "FAIR")
    val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
    System.setProperty("spark.scheduler.allocation.file", xmlPath)
    sc = new SparkContext("local[8]", "LocalSchedulerSuite")
    val sem = new Semaphore(0)

    createThread(10, "1", sc, sem)
    TaskThreadInfo.threadToStarted(10).await()
    createThread(20, "2", sc, sem)
    TaskThreadInfo.threadToStarted(20).await()
    createThread(30, "3", sc, sem)
    TaskThreadInfo.threadToStarted(30).await()

    assert(TaskThreadInfo.threadToRunning(10) === true)
    assert(TaskThreadInfo.threadToRunning(20) === true)
    assert(TaskThreadInfo.threadToRunning(30) === true)

    createThread(11, "1", sc, sem)
    TaskThreadInfo.threadToStarted(11).await()
    createThread(21, "2", sc, sem)
    TaskThreadInfo.threadToStarted(21).await()
    createThread(31, "3", sc, sem)
    TaskThreadInfo.threadToStarted(31).await()

    assert(TaskThreadInfo.threadToRunning(11) === true)
    assert(TaskThreadInfo.threadToRunning(21) === true)
    assert(TaskThreadInfo.threadToRunning(31) === true)

    createThread(12, "1", sc, sem)
    TaskThreadInfo.threadToStarted(12).await()
    createThread(22, "2", sc, sem)
    TaskThreadInfo.threadToStarted(22).await()
    createThread(32, "3", sc, sem)

    assert(TaskThreadInfo.threadToRunning(12) === true)
    assert(TaskThreadInfo.threadToRunning(22) === true)
    assert(TaskThreadInfo.threadToRunning(32) === false)

    TaskThreadInfo.threadToLock(10).jobFinished()
    TaskThreadInfo.threadToStarted(32).await()

    assert(TaskThreadInfo.threadToRunning(32) === true)

    // 1. As in the scenario above, sleep 1s so the stages for 23 and 33 are added to the
    //    taskSetManager queue, so that a free core is assigned to stage 23 after stage 11
    //    finishes.
    // 2. The priorities of 23 and 33 are irrelevant here because the fair scheduler is in use.
    createThread(23, "2", sc, sem)
    createThread(33, "3", sc, sem)
    Thread.sleep(1000)

    TaskThreadInfo.threadToLock(11).jobFinished()
    TaskThreadInfo.threadToStarted(23).await()

    assert(TaskThreadInfo.threadToRunning(23) === true)
    assert(TaskThreadInfo.threadToRunning(33) === false)

    TaskThreadInfo.threadToLock(12).jobFinished()
    TaskThreadInfo.threadToStarted(33).await()

    assert(TaskThreadInfo.threadToRunning(33) === true)

    TaskThreadInfo.threadToLock(20).jobFinished()
    TaskThreadInfo.threadToLock(21).jobFinished()
    TaskThreadInfo.threadToLock(22).jobFinished()
    TaskThreadInfo.threadToLock(23).jobFinished()
    TaskThreadInfo.threadToLock(30).jobFinished()
    TaskThreadInfo.threadToLock(31).jobFinished()
    TaskThreadInfo.threadToLock(32).jobFinished()
    TaskThreadInfo.threadToLock(33).jobFinished()

    sem.acquire(11)
  }
}
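
The deleted suite's Lock is a classic monitor: a task parks in jobWait() until the test calls jobFinished(), which is what let the tests assert on scheduler state while cores were pinned. A standalone sketch of that gate, under hypothetical names (Gate, GateDemo):

class Gate {
  private var open = false

  def await(): Unit = synchronized {
    while (!open) {
      wait() // woken by notifyAll() in release()
    }
  }

  def release(): Unit = synchronized {
    open = true
    notifyAll()
  }
}

object GateDemo extends App {
  val gate = new Gate
  val worker = new Thread {
    override def run() {
      gate.await() // blocks here, simulating a pinned task
      println("task allowed to finish")
    }
  }
  worker.start()
  // ... assertions about scheduler state would go here, while the task is pinned ...
  gate.release()
  worker.join()
}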

View file

@@ -17,16 +17,20 @@
 package org.apache.spark.scheduler.cluster

+import org.apache.hadoop.conf.Configuration
+
 import org.apache.spark._
 import org.apache.spark.deploy.yarn.{ApplicationMaster, YarnAllocationHandler}
+import org.apache.spark.scheduler.TaskScheduler
 import org.apache.spark.util.Utils
-import org.apache.hadoop.conf.Configuration

 /**
  *
- * This is a simple extension to ClusterScheduler - to ensure that appropriate initialization of ApplicationMaster, etc is done
+ * This is a simple extension to TaskScheduler - to ensure that appropriate initialization of
+ * ApplicationMaster, etc. is done
  */
-private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) {
+private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration)
+  extends TaskScheduler(sc) {

   logInfo("Created YarnClusterScheduler")
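
Per the updated scaladoc, the YARN scheduler stays a thin subclass: it layers cluster-specific start-up onto the shared TaskScheduler rather than duplicating scheduling logic. A minimal sketch of that hook pattern, with hypothetical names (the real subclass initializes the ApplicationMaster):

class BaseScheduler {
  def start() {
    println("shared scheduling machinery starts here")
  }
}

class YarnLikeScheduler extends BaseScheduler {
  override def start() {
    // Environment-specific initialization runs first, then the shared path.
    println("YARN-specific setup, e.g. ApplicationMaster registration")
    super.start()
  }
}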