Initial commit for job killing.

Reynold Xin 2013-09-16 18:54:06 -07:00
parent 3443d3fd43
commit cbc48be13b
21 changed files with 255 additions and 56 deletions

@@ -0,0 +1,26 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark

class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator[T])
  extends Iterator[T] {

  def hasNext: Boolean = !context.interrupted && delegate.hasNext

  def next(): T = delegate.next()
}
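The wrapper is deliberately small: hasNext consults the TaskContext's interrupted flag before delegating, so whoever consumes the wrapped iterator stops as soon as the flag flips. A driver-side sketch of that behaviour (illustrative only, not part of the commit; the demo object is invented, and the constructor arguments follow the TaskContext signature as changed below):

package org.apache.spark

object InterruptibleIteratorDemo {
  def main(args: Array[String]) {
    // Hand-built context whose interrupted flag we flip manually.
    val context = new TaskContext(0, 0, 0, runningLocally = true, taskMetrics = null)
    val iter = new InterruptibleIterator(context, Iterator(1, 2, 3, 4, 5))

    println(iter.next())        // 1 -- the flag is still false
    context.interrupted = true  // what Task.kill() does through the context
    println(iter.hasNext)       // false -- iteration stops cooperatively
  }
}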

@@ -812,6 +812,13 @@ class SparkContext(
    result
  }

+  /**
+   * Kill a running job.
+   */
+  def killJob(jobId: Int) {
+    dagScheduler.killJob(jobId)
+  }
+
  /**
   * Clean a closure to make it ready to serialized and send to tasks
   * (removes unreferenced variables in $outer's, updates REPL variables)
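killJob only takes a job id, and nothing in this commit hands that id back to the caller of runJob, so a driver has to obtain it some other way. A rough sketch of how the API could be exercised (illustrative only, not part of the commit; it assumes the DAGScheduler numbers jobs from 0, so the first job submitted gets id 0):

import org.apache.spark.SparkContext

object KillJobSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local[2]", "kill-job-demo")

    // Run a deliberately slow job in a background thread.
    val runner = new Thread(new Runnable {
      def run() {
        try {
          sc.parallelize(1 to 4, 4).foreach { _ => Thread.sleep(60 * 1000) }
        } catch {
          case e: Exception => println("Job ended with: " + e)  // expect "Job killed"
        }
      }
    })
    runner.start()

    Thread.sleep(5000)   // let the job reach the DAGScheduler
    sc.killJob(0)        // assumed id of the first submitted job
    runner.join()
    sc.stop()
  }
}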

@@ -25,6 +25,7 @@ class TaskContext(
  val splitId: Int,
  val attemptId: Long,
  val runningLocally: Boolean = false,
+  @volatile var interrupted: Boolean = false,
  val taskMetrics: TaskMetrics = TaskMetrics.empty()
) extends Serializable {

@@ -17,7 +17,7 @@
package org.apache.spark.executor

-import java.io.{File}
+import java.io.File
import java.lang.management.ManagementFactory
import java.nio.ByteBuffer
import java.util.concurrent._

@@ -31,7 +31,8 @@ import org.apache.spark.util.Utils
/**
- * The Mesos executor for Spark.
+ * The backend executor for Spark. The executor maintains a thread pool and uses it to execute
+ * tasks.
 */
private[spark] class Executor(
  executorId: String,
@@ -101,18 +102,38 @@ private[spark] class Executor(
  val executorSource = new ExecutorSource(this, executorId)

  // Initialize Spark environment (using system properties read above)
-  val env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0, false, false)
+  private val env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0,
+    isDriver = false, isLocal = false)
  SparkEnv.set(env)
  env.metricsSystem.registerSource(executorSource)

-  private val akkaFrameSize = env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size")
+  private val akkaFrameSize = {
+    env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size")
+  }

  // Start worker thread pool
  val threadPool = new ThreadPoolExecutor(
    1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])

+  // Maintains the list of running tasks.
+  private val runningTasks = new ConcurrentHashMap[Long, TaskRunner]
+
  def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) {
-    threadPool.execute(new TaskRunner(context, taskId, serializedTask))
+    val task = new TaskRunner(context, taskId, serializedTask)
+    runningTasks.put(taskId, task)
+    threadPool.execute(task)
+  }
+
+  def killTask(taskId: Long) {
+    val task = runningTasks.get(taskId)
+    if (task != null) {
+      task.kill()
+      // We also remove the task in the finally block in TaskRunner.run.
+      // The reason we need to remove it here is because killTask might be called before the task
+      // is even launched, and would otherwise never reach that finally block. ConcurrentHashMap's
+      // remove is idempotent.
+      runningTasks.remove(taskId)
+    }
  }

  /** Get the Yarn approved local directories. */
@@ -124,15 +145,26 @@ private[spark] class Executor(
        .getOrElse(Option(System.getenv("LOCAL_DIRS"))
        .getOrElse(""))

-    if (localDirs.isEmpty()) {
+    if (localDirs.isEmpty) {
      throw new Exception("Yarn Local dirs can't be empty")
    }
-    return localDirs
+    localDirs
  }

  class TaskRunner(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer)
    extends Runnable {

+    @volatile private var killed = false
+    @volatile private var task: Task[Any] = _
+
+    def kill() {
+      // Note that there is a tiny possibility of racing here.
+      killed = true
+      if (task != null) {
+        task.kill()
+      }
+    }
+
    override def run() {
      val startTime = System.currentTimeMillis()
      SparkEnv.set(env)
@@ -150,22 +182,29 @@ private[spark] class Executor(
        Accumulators.clear()
        val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
        updateDependencies(taskFiles, taskJars)
-        val task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
+        task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
+
+        // If this task has been killed before we deserialized it, let's quit now. Otherwise,
+        // continue executing the task.
+        if (killed) {
+          context.statusUpdate(taskId, TaskState.KILLED, ser.serialize(""))
+        }
+
        attemptedTask = Some(task)
-        logInfo("Its epoch is " + task.epoch)
+        logDebug("Its epoch is " + task.epoch)
        env.mapOutputTracker.updateEpoch(task.epoch)
        taskStart = System.currentTimeMillis()
        val value = task.run(taskId.toInt)
        val taskFinish = System.currentTimeMillis()
        for (m <- task.metrics) {
-          m.hostname = Utils.localHostName
+          m.hostname = Utils.localHostName()
          m.executorDeserializeTime = (taskStart - startTime).toInt
          m.executorRunTime = (taskFinish - taskStart).toInt
          m.jvmGCTime = getTotalGCTime - startGCTime
        }
-        //TODO I'd also like to track the time it takes to serialize the task results, but that is huge headache, b/c
-        // we need to serialize the task metrics first. If TaskMetrics had a custom serialized format, we could
-        // just change the relevants bytes in the byte buffer
+        // TODO I'd also like to track the time it takes to serialize the task results, but that is
+        // a huge headache, b/c we need to serialize the task metrics first. If TaskMetrics had a
+        // custom serialized format, we could just change the relevant bytes in the byte buffer.
        val accumUpdates = Accumulators.values
        val result = new TaskResult(value, accumUpdates, task.metrics.getOrElse(null))
        val serializedResult = ser.serialize(result)

@@ -198,6 +237,8 @@ private[spark] class Executor(
          logError("Exception in task ID " + taskId, t)
          //System.exit(1)
        }
+      } finally {
+        runningTasks.remove(taskId)
      }
    }
  }
@@ -207,7 +248,7 @@ private[spark] class Executor(
   * created by the interpreter to the search path
   */
  private def createClassLoader(): ExecutorURLClassLoader = {
-    var loader = this.getClass.getClassLoader
+    val loader = this.getClass.getClassLoader

    // For each of the jars in the jarSet, add them to the class loader.
    // We assume each of the files has already been fetched.

@@ -229,7 +270,7 @@ private[spark] class Executor(
        val klass = Class.forName("org.apache.spark.repl.ExecutorClassLoader")
          .asInstanceOf[Class[_ <: ClassLoader]]
        val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader])
-        return constructor.newInstance(classUri, parent)
+        constructor.newInstance(classUri, parent)
      } catch {
        case _: ClassNotFoundException =>
          logError("Could not find org.apache.spark.repl.ExecutorClassLoader on classpath!")

@@ -237,7 +278,7 @@ private[spark] class Executor(
          null
        }
      } else {
-      return parent
+      parent
    }
  }
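The double removal from runningTasks (once in killTask above and once in TaskRunner.run's finally block) is easier to see in isolation. A stripped-down model of that bookkeeping (not Spark code; all names here are made up) shows why removing in both places is safe: ConcurrentHashMap.remove on an absent key is simply a no-op.

import java.util.concurrent.{ConcurrentHashMap, Executors, TimeUnit}

object RunningTasksModel {
  val running = new ConcurrentHashMap[Long, FakeRunner]

  class FakeRunner(taskId: Long) extends Runnable {
    @volatile var killed = false
    def kill() { killed = true }
    def run() {
      try {
        if (!killed) println("task " + taskId + " ran to completion")
      } finally {
        running.remove(taskId)   // mirrors TaskRunner.run's finally block
      }
    }
  }

  def main(args: Array[String]) {
    val pool = Executors.newFixedThreadPool(1)

    // Launch path: register first, then hand off to the pool.
    val runner = new FakeRunner(1L)
    running.put(1L, runner)
    pool.execute(runner)

    // Kill path: may run before, during or after the runner; both removes are harmless.
    val t = running.get(1L)
    if (t != null) {
      t.kill()
      running.remove(1L)
    }

    pool.shutdown()
    pool.awaitTermination(5, TimeUnit.SECONDS)
  }
}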

@@ -69,6 +69,12 @@ private[spark] class StandaloneExecutorBackend(
        executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask)
      }

+    case KillTask(taskId, _) =>
+      logInfo("Kill task %s %s".format(taskId, executorId))
+      if (executor != null) {
+        executor.killTask(taskId)
+      }
+
    case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) =>
      logError("Driver terminated or disconnected! Shutting down.")
      System.exit(1)

@@ -128,6 +128,7 @@ class DAGScheduler(
  // stray messages to detect.
  val failedEpoch = new HashMap[String, Long]

+  // stage id to the active job
  val idToActiveJob = new HashMap[Int, ActiveJob]

  val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done
@@ -334,6 +335,35 @@ class DAGScheduler(
    listener.awaitResult() // Will throw an exception if the job fails
  }

+  def killJob(jobId: Int) {
+    activeJobs.find(job => job.jobId == jobId).foreach(job => killJob(job))
+  }
+
+  private def killJob(job: ActiveJob) {
+    logInfo("Killing Job and cleaning up stages %d".format(job.jobId))
+    activeJobs.remove(job)
+    idToActiveJob.remove(job.jobId)
+    val stage = job.finalStage
+    resultStageToJob.remove(stage)
+    killStage(stage)
+    // recursively remove all parent stages
+    stage.parents.foreach(p => killStage(p))
+    job.listener.jobFailed(new SparkException("Job killed"))
+  }
+
+  private def killStage(stage: Stage) {
+    logInfo("Killing Stage %s".format(stage.id))
+    stageIdToStage.remove(stage.id)
+    if (stage.isShuffleMap) {
+      shuffleToMapStage.remove(stage.id)
+    }
+    waiting.remove(stage)
+    pendingTasks.remove(stage)
+    running.remove(stage)
+    taskSched.killTasks(stage.id)
+    stage.parents.foreach(p => killStage(p))
+  }
+
  /**
   * Process one event retrieved from the event queue.
   * Returns true if we should stop the event loop.
@@ -579,6 +609,11 @@ class DAGScheduler(
   */
  private def handleTaskCompletion(event: CompletionEvent) {
    val task = event.task
+    if (!stageIdToStage.contains(task.stageId)) {
+      // Skip all the actions if the stage has been cancelled.
+      return
+    }
+
    val stage = stageIdToStage(task.stageId)

    def markStageAsFinished(stage: Stage) = {

@@ -38,17 +38,17 @@ private[spark] object ResultTask {
    synchronized {
      val old = serializedInfoCache.get(stageId).orNull
      if (old != null) {
-        return old
+        old
      } else {
        val out = new ByteArrayOutputStream
-        val ser = SparkEnv.get.closureSerializer.newInstance
+        val ser = SparkEnv.get.closureSerializer.newInstance()
        val objOut = ser.serializeStream(new GZIPOutputStream(out))
        objOut.writeObject(rdd)
        objOut.writeObject(func)
        objOut.close()
        val bytes = out.toByteArray
        serializedInfoCache.put(stageId, bytes)
-        return bytes
+        bytes
      }
    }
  }

@@ -56,11 +56,11 @@ private[spark] object ResultTask {
  def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) = {
    val loader = Thread.currentThread.getContextClassLoader
    val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
-    val ser = SparkEnv.get.closureSerializer.newInstance
+    val ser = SparkEnv.get.closureSerializer.newInstance()
    val objIn = ser.deserializeStream(in)
    val rdd = objIn.readObject().asInstanceOf[RDD[_]]
    val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _]
-    return (rdd, func)
+    (rdd, func)
  }

  def clearCache() {

@@ -75,25 +75,20 @@ private[spark] class ResultTask[T, U](
    stageId: Int,
    var rdd: RDD[T],
    var func: (TaskContext, Iterator[T]) => U,
-    var partition: Int,
+    _partition: Int,
    @transient locs: Seq[TaskLocation],
    var outputId: Int)
-  extends Task[U](stageId) with Externalizable {
+  extends Task[U](stageId, _partition) with Externalizable {

  def this() = this(0, null, null, 0, null, 0)

-  var split = if (rdd == null) {
-    null
-  } else {
-    rdd.partitions(partition)
-  }
+  var split = if (rdd == null) null else rdd.partitions(partition)

  @transient private val preferredLocs: Seq[TaskLocation] = {
    if (locs == null) Nil else locs.toSet.toSeq
  }

-  override def run(attemptId: Long): U = {
-    val context = new TaskContext(stageId, partition, attemptId, runningLocally = false)
+  override def runTask(context: TaskContext): U = {
    metrics = Some(context.taskMetrics)
    try {
      func(context, rdd.iterator(split, context))

@@ -53,7 +53,7 @@ private[spark] object ShuffleMapTask {
      objOut.close()
      val bytes = out.toByteArray
      serializedInfoCache.put(stageId, bytes)
-      return bytes
+      bytes
    }
  }
}

@@ -66,7 +66,7 @@ private[spark] object ShuffleMapTask {
    val objIn = ser.deserializeStream(in)
    val rdd = objIn.readObject().asInstanceOf[RDD[_]]
    val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]]
-    return (rdd, dep)
+    (rdd, dep)
  }
}

@@ -75,7 +75,7 @@ private[spark] object ShuffleMapTask {
    val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
    val objIn = new ObjectInputStream(in)
    val set = objIn.readObject().asInstanceOf[Array[(String, Long)]].toMap
-    return (HashMap(set.toSeq: _*))
+    HashMap(set.toSeq: _*)
  }

  def clearCache() {

@@ -89,9 +89,9 @@ private[spark] class ShuffleMapTask(
    stageId: Int,
    var rdd: RDD[_],
    var dep: ShuffleDependency[_,_],
-    var partition: Int,
+    _partition: Int,
    @transient private var locs: Seq[TaskLocation])
-  extends Task[MapStatus](stageId)
+  extends Task[MapStatus](stageId, _partition)
  with Externalizable
  with Logging {

@@ -129,11 +129,9 @@ private[spark] class ShuffleMapTask(
    split = in.readObject().asInstanceOf[Partition]
  }

-  override def run(attemptId: Long): MapStatus = {
+  override def runTask(context: TaskContext): MapStatus = {
    val numOutputSplits = dep.partitioner.numPartitions
+    metrics = Some(context.taskMetrics)

-    val taskContext = new TaskContext(stageId, partition, attemptId, runningLocally = false)
-    metrics = Some(taskContext.taskMetrics)
    val blockManager = SparkEnv.get.blockManager
    var shuffle: ShuffleBlocks = null

@@ -146,7 +144,7 @@ private[spark] class ShuffleMapTask(
      buckets = shuffle.acquireWriters(partition)

      // Write the map output to its associated buckets.
-      for (elem <- rdd.iterator(split, taskContext)) {
+      for (elem <- rdd.iterator(split, context)) {
        val pair = elem.asInstanceOf[Product2[Any, Any]]
        val bucketId = dep.partitioner.getPartition(pair._1)
        buckets.writers(bucketId).write(pair)

@@ -167,7 +165,7 @@ private[spark] class ShuffleMapTask(
      shuffleMetrics.shuffleBytesWritten = totalBytes
      metrics.get.shuffleWriteMetrics = Some(shuffleMetrics)

-      return new MapStatus(blockManager.blockManagerId, compressedSizes)
+      new MapStatus(blockManager.blockManagerId, compressedSizes)
    } catch { case e: Exception =>
      // If there is an exception from running the task, revert the partial writes
      // and throw the exception upstream to Spark.

@@ -181,7 +179,7 @@ private[spark] class ShuffleMapTask(
        shuffle.releaseWriters(buckets)
      }
      // Execute the callbacks on task completion.
-      taskContext.executeOnCompleteCallbacks()
+      context.executeOnCompleteCallbacks()
    }
  }

@@ -17,25 +17,59 @@
package org.apache.spark.scheduler

-import org.apache.spark.serializer.SerializerInstance
import java.io.{DataInputStream, DataOutputStream}
import java.nio.ByteBuffer
-import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
-import org.apache.spark.util.ByteBufferInputStream

import scala.collection.mutable.HashMap

+import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
+
+import org.apache.spark.TaskContext
import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.serializer.SerializerInstance
+import org.apache.spark.util.ByteBufferInputStream

/**
 * A task to execute on a worker node.
 */
-private[spark] abstract class Task[T](val stageId: Int) extends Serializable {
-  def run(attemptId: Long): T
+private[spark] abstract class Task[T](val stageId: Int, var partition: Int) extends Serializable {
+
+  def run(attemptId: Long): T = {
+    context = new TaskContext(stageId, partition, attemptId, runningLocally = false)
+    if (killed) {
+      kill()
+    }
+    runTask(context)
+  }
+
+  def runTask(context: TaskContext): T
+
  def preferredLocations: Seq[TaskLocation] = Nil

-  var epoch: Long = -1 // Map output tracker epoch. Will be set by TaskScheduler.
+  // Map output tracker epoch. Will be set by TaskScheduler.
+  var epoch: Long = -1

  var metrics: Option[TaskMetrics] = None

+  // Task context, to be initialized in run().
+  @transient protected var context: TaskContext = _
+
+  // A flag to indicate whether the task is killed. This is used in case context is not yet
+  // initialized when kill() is invoked.
+  @volatile @transient private var killed = false
+
+  /**
+   * Kills a task by setting the interrupted flag to true. This relies on the upper level Spark
+   * code and user code to properly handle the flag. This function should be idempotent so it can
+   * be called multiple times.
+   */
+  def kill() {
+    killed = true
+    if (context != null) {
+      context.interrupted = true
+    }
+  }
}

/**
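A concrete subclass now only implements runTask(context); the base class builds the TaskContext in run() and replays a pre-run kill() onto it. A self-contained sketch of that contract (CountingTask and CountingTaskDemo are invented for illustration and are not part of the commit; they sit in the scheduler package only because Task is private[spark]):

package org.apache.spark.scheduler

import org.apache.spark.{InterruptibleIterator, TaskContext}

// Toy task that honours the kill flag by reading its data through InterruptibleIterator.
private[spark] class CountingTask(stageId: Int, partition: Int)
  extends Task[Int](stageId, partition) {

  override def runTask(context: TaskContext): Int = {
    val iter = new InterruptibleIterator(context, Iterator.range(0, 1000000))
    iter.size   // stops immediately if context.interrupted is already true
  }
}

object CountingTaskDemo {
  def main(args: Array[String]) {
    val task = new CountingTask(0, 0)
    task.kill()                    // kill before run(): only the killed flag is set
    val n = task.run(0L)           // run() flips context.interrupted before runTask
    println("elements seen: " + n) // 0
  }
}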

@@ -19,6 +19,7 @@ package org.apache.spark.scheduler
import org.apache.spark.scheduler.cluster.Pool
import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode

/**
 * Low-level task scheduler interface, implemented by both ClusterScheduler and LocalScheduler.
 * These schedulers get sets of tasks submitted to them from the DAGScheduler for each stage,

@@ -44,6 +45,9 @@ private[spark] trait TaskScheduler {
  // Submit a sequence of tasks to run.
  def submitTasks(taskSet: TaskSet): Unit

+  // Kill the stage.
+  def killTasks(stageId: Int)
+
  // Set a listener for upcalls. This is guaranteed to be set before submitTasks is called.
  def setListener(listener: TaskSchedulerListener): Unit

@@ -31,5 +31,9 @@ private[spark] class TaskSet(
    val properties: Properties) {

  val id: String = stageId + "." + attempt

+  def kill() {
+    tasks.foreach(_.kill())
+  }
+
  override def toString: String = "TaskSet " + id
}

@@ -166,6 +166,20 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
    backend.reviveOffers()
  }

+  override def killTasks(stageId: Int) {
+    synchronized {
+      schedulableBuilder.popTaskSetManagers(stageId).foreach { t =>
+        val ts = t.asInstanceOf[TaskSetManager].taskSet
+        ts.kill()
+        val taskIds = taskSetTaskIds(ts.id)
+        taskIds.foreach { tid =>
+          val execId = taskIdToExecutorId(tid)
+          backend.killTask(tid, execId)
+        }
+      }
+    }
+  }
+
  def taskSetFinished(manager: TaskSetManager) {
    this.synchronized {
      activeTaskSets -= manager.taskSet.id

@@ -32,8 +32,17 @@ import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode
 * addTaskSetManager: build the leaf nodes(TaskSetManagers)
 */
private[spark] trait SchedulableBuilder {
+  def rootPool: Pool
+
  def buildPools()

  def addTaskSetManager(manager: Schedulable, properties: Properties)
+
+  def popTaskSetManagers(stageId: Int): Iterable[Schedulable] = {
+    val taskSets = rootPool.schedulableQueue.filter(_.stageId == stageId)
+    taskSets.foreach(rootPool.removeSchedulable)
+    taskSets
+  }
}

private[spark] class FIFOSchedulableBuilder(val rootPool: Pool)

@@ -17,7 +17,7 @@
package org.apache.spark.scheduler.cluster

-import org.apache.spark.{SparkContext}
+import org.apache.spark.SparkContext

/**
 * A backend interface for cluster scheduling systems that allows plugging in different ones under

@@ -30,6 +30,10 @@ private[spark] trait SchedulerBackend {
  def reviveOffers(): Unit
  def defaultParallelism(): Int

+  def killTask(taskId: Long, executorId: String) {
+    throw new UnsupportedOperationException
+  }
+
  // Memory used by each executor (in megabytes)
  protected val executorMemory: Int = SparkContext.executorMemoryRequested

@@ -30,6 +30,8 @@ private[spark] object StandaloneClusterMessages {
  // Driver to executors
  case class LaunchTask(task: TaskDescription) extends StandaloneClusterMessage

+  case class KillTask(taskId: Long, executor: String) extends StandaloneClusterMessage
+
  case class RegisteredExecutor(sparkProperties: Seq[(String, String)])
    extends StandaloneClusterMessage

@@ -90,6 +90,9 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: ActorSystem)
      case ReviveOffers =>
        makeOffers()

+      case KillTask(taskId, executorId) =>
+        executorActor(executorId) ! KillTask(taskId, executorId)
+
      case StopDriver =>
        sender ! true
        context.stop(self)

@@ -179,6 +182,10 @@
    driverActor ! ReviveOffers
  }

+  override def killTask(taskId: Long, executorId: String) {
+    driverActor ! KillTask(taskId, executorId)
+  }
+
  override def defaultParallelism() = Option(System.getProperty("spark.default.parallelism"))
    .map(_.toInt).getOrElse(math.max(totalCoreCount.get(), 2))

@@ -128,6 +128,12 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext)
    }
  }

+  override def killTasks(stageId: Int) = synchronized {
+    schedulableBuilder.popTaskSetManagers(stageId).foreach {
+      _.asInstanceOf[TaskSetManager].taskSet.kill()
+    }
+  }
+
  def resourceOffer(freeCores: Int): Seq[TaskDescription] = {
    synchronized {
      var freeCpuCores = freeCores

@@ -58,7 +58,8 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
    }

    whenExecuting(blockManager) {
-      val context = new TaskContext(0, 0, 0, runningLocally = false, null)
+      val context = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false,
+        taskMetrics = null)
      val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
      assert(value.toList === List(1, 2, 3, 4))
    }

@@ -70,7 +71,8 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
    }

    whenExecuting(blockManager) {
-      val context = new TaskContext(0, 0, 0, runningLocally = false, null)
+      val context = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false,
+        taskMetrics = null)
      val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
      assert(value.toList === List(5, 6, 7))
    }

@@ -83,7 +85,8 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
    }

    whenExecuting(blockManager) {
-      val context = new TaskContext(0, 0, 0, runningLocally = true, null)
+      val context = new TaskContext(0, 0, 0, runningLocally = true, interrupted = false,
+        taskMetrics = null)
      val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
      assert(value.toList === List(1, 2, 3, 4))
    }

@@ -495,7 +495,7 @@ public class JavaAPISuite implements Serializable {
  @Test
  public void iterator() {
    JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
-    TaskContext context = new TaskContext(0, 0, 0, false, null);
+    TaskContext context = new TaskContext(0, 0, 0, false, false, null);
    Assert.assertEquals(1, rdd.iterator(rdd.splits().get(0), context).next().intValue());
  }

@@ -29,12 +29,13 @@ import org.apache.spark.SparkContext
import org.apache.spark.Partition
import org.apache.spark.TaskContext
import org.apache.spark.{Dependency, ShuffleDependency, OneToOneDependency}
-import org.apache.spark.{FetchFailed, Success, TaskEndReason}
+import org.apache.spark.{Success, TaskEndReason}
import org.apache.spark.storage.{BlockManagerId, BlockManagerMaster}
import org.apache.spark.scheduler.cluster.Pool
import org.apache.spark.scheduler.cluster.SchedulingMode
import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode
+import org.apache.spark.FetchFailed

/**
 * Tests for DAGScheduler. These tests directly call the event processing functions in DAGScheduler

@@ -62,6 +63,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext
      taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch)
      taskSets += taskSet
    }
+    override def killTasks(stageId: Int) {}
    override def setListener(listener: TaskSchedulerListener) = {}
    override def defaultParallelism() = 2
  }

@@ -17,10 +17,11 @@
package org.apache.spark.scheduler.cluster

+import org.apache.spark.TaskContext
import org.apache.spark.scheduler.{TaskLocation, Task}

-class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId) {
-  override def run(attemptId: Long): Int = 0
+class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0) {
+  override def runTask(context: TaskContext): Int = 0

  override def preferredLocations: Seq[TaskLocation] = prefLocs
}