diff --git a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala new file mode 100644 index 0000000000..d2659db2d0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator[T]) + extends Iterator[T] { + + def hasNext: Boolean = !context.interrupted && delegate.hasNext + + def next(): T = delegate.next() +} diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 72540c712a..ff3e780edb 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -812,6 +812,13 @@ class SparkContext( result } + /** + * Kill a running job. + */ + def killJob(jobId: Int) { + dagScheduler.killJob(jobId) + } + /** * Clean a closure to make it ready to serialized and send to tasks * (removes unreferenced variables in $outer's, updates REPL variables) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index c2c358c7ad..0b1542c9bb 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -25,6 +25,7 @@ class TaskContext( val splitId: Int, val attemptId: Long, val runningLocally: Boolean = false, + @volatile var interrupted: Boolean = false, val taskMetrics: TaskMetrics = TaskMetrics.empty() ) extends Serializable { diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index ceae3b8289..86a813d51e 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -17,7 +17,7 @@ package org.apache.spark.executor -import java.io.{File} +import java.io.File import java.lang.management.ManagementFactory import java.nio.ByteBuffer import java.util.concurrent._ @@ -31,7 +31,8 @@ import org.apache.spark.util.Utils /** - * The Mesos executor for Spark. + * The backend executor for Spark. The executor maintains a thread pool and uses it to execute + * tasks. 
*/ private[spark] class Executor( executorId: String, @@ -101,18 +102,38 @@ private[spark] class Executor( val executorSource = new ExecutorSource(this, executorId) // Initialize Spark environment (using system properties read above) - val env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0, false, false) + private val env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0, + isDriver = false, isLocal = false) SparkEnv.set(env) env.metricsSystem.registerSource(executorSource) - private val akkaFrameSize = env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size") + private val akkaFrameSize = { + env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size") + } // Start worker thread pool val threadPool = new ThreadPoolExecutor( 1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable]) + // Maintains the list of running tasks. + private val runningTasks = new ConcurrentHashMap[Long, TaskRunner] + def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) { - threadPool.execute(new TaskRunner(context, taskId, serializedTask)) + val task = new TaskRunner(context, taskId, serializedTask) + runningTasks.put(taskId, task) + threadPool.execute(task) + } + + def killTask(taskId: Long) { + val task = runningTasks.get(taskId) + if (task != null) { + task.kill() + // We also remove the task in the finally block of TaskRunner.run. + // The reason we need to remove it here is because killTask might be called before the task + // is even launched, and would never reach that finally block. ConcurrentHashMap's remove is + // idempotent. + runningTasks.remove(taskId) + } } /** Get the Yarn approved local directories. */ @@ -124,15 +145,26 @@ private[spark] class Executor( .getOrElse(Option(System.getenv("LOCAL_DIRS")) .getOrElse("")) - if (localDirs.isEmpty()) { + if (localDirs.isEmpty) { throw new Exception("Yarn Local dirs can't be empty") } - return localDirs + localDirs } class TaskRunner(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) extends Runnable { + @volatile private var killed = false + @volatile private var task: Task[Any] = _ + + def kill() { + // Note that there is a tiny possibility of a race here. + killed = true + if (task != null) { + task.kill() + } + } + override def run() { val startTime = System.currentTimeMillis() SparkEnv.set(env) @@ -150,22 +182,29 @@ private[spark] class Executor( Accumulators.clear() val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask) updateDependencies(taskFiles, taskJars) - val task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader) + task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader) + + // If this task has been killed before we deserialized it, let's quit now. Otherwise, + continue executing the task.
+ if (killed) { + context.statusUpdate(taskId, TaskState.KILLED, ser.serialize("")) + } + attemptedTask = Some(task) - logInfo("Its epoch is " + task.epoch) + logDebug("Its epoch is " + task.epoch) env.mapOutputTracker.updateEpoch(task.epoch) taskStart = System.currentTimeMillis() val value = task.run(taskId.toInt) val taskFinish = System.currentTimeMillis() for (m <- task.metrics) { - m.hostname = Utils.localHostName + m.hostname = Utils.localHostName() m.executorDeserializeTime = (taskStart - startTime).toInt m.executorRunTime = (taskFinish - taskStart).toInt m.jvmGCTime = getTotalGCTime - startGCTime } - //TODO I'd also like to track the time it takes to serialize the task results, but that is huge headache, b/c - // we need to serialize the task metrics first. If TaskMetrics had a custom serialized format, we could - // just change the relevants bytes in the byte buffer + // TODO I'd also like to track the time it takes to serialize the task results, but that is + // huge headache, b/c we need to serialize the task metrics first. If TaskMetrics had a + // custom serialized format, we could just change the relevants bytes in the byte buffer val accumUpdates = Accumulators.values val result = new TaskResult(value, accumUpdates, task.metrics.getOrElse(null)) val serializedResult = ser.serialize(result) @@ -198,6 +237,8 @@ private[spark] class Executor( logError("Exception in task ID " + taskId, t) //System.exit(1) } + } finally { + runningTasks.remove(taskId) } } } @@ -207,7 +248,7 @@ private[spark] class Executor( * created by the interpreter to the search path */ private def createClassLoader(): ExecutorURLClassLoader = { - var loader = this.getClass.getClassLoader + val loader = this.getClass.getClassLoader // For each of the jars in the jarSet, add them to the class loader. // We assume each of the files has already been fetched. @@ -229,7 +270,7 @@ private[spark] class Executor( val klass = Class.forName("org.apache.spark.repl.ExecutorClassLoader") .asInstanceOf[Class[_ <: ClassLoader]] val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader]) - return constructor.newInstance(classUri, parent) + constructor.newInstance(classUri, parent) } catch { case _: ClassNotFoundException => logError("Could not find org.apache.spark.repl.ExecutorClassLoader on classpath!") @@ -237,7 +278,7 @@ private[spark] class Executor( null } } else { - return parent + parent } } diff --git a/core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala index 7839023868..c2cc59740a 100644 --- a/core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala @@ -69,6 +69,12 @@ private[spark] class StandaloneExecutorBackend( executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask) } + case KillTask(taskId, _) => + logInfo("Kill task %s %s".format(taskId, executorId)) + if (executor != null) { + executor.killTask(taskId) + } + case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) => logError("Driver terminated or disconnected! 
Shutting down.") System.exit(1) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 3e3f04f087..1d491d581a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -128,6 +128,7 @@ class DAGScheduler( // stray messages to detect. val failedEpoch = new HashMap[String, Long] + // stage id to the active job val idToActiveJob = new HashMap[Int, ActiveJob] val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done @@ -334,6 +335,35 @@ class DAGScheduler( listener.awaitResult() // Will throw an exception if the job fails } + def killJob(jobId: Int) { + activeJobs.find(job => job.jobId == jobId).foreach(job => killJob(job)) + } + + private def killJob(job: ActiveJob) { + logInfo("Killing Job and cleaning up stages %d".format(job.jobId)) + activeJobs.remove(job) + idToActiveJob.remove(job.jobId) + val stage = job.finalStage + resultStageToJob.remove(stage) + killStage(stage) + // recursively remove all parent stages + stage.parents.foreach(p => killStage(p)) + job.listener.jobFailed(new SparkException("Job killed")) + } + + private def killStage(stage: Stage) { + logInfo("Killing Stage %s".format(stage.id)) + stageIdToStage.remove(stage.id) + if (stage.isShuffleMap) { + shuffleToMapStage.remove(stage.id) + } + waiting.remove(stage) + pendingTasks.remove(stage) + running.remove(stage) + taskSched.killTasks(stage.id) + stage.parents.foreach(p => killStage(p)) + } + /** * Process one event retrieved from the event queue. * Returns true if we should stop the event loop. @@ -579,6 +609,11 @@ class DAGScheduler( */ private def handleTaskCompletion(event: CompletionEvent) { val task = event.task + + if (!stageIdToStage.contains(task.stageId)) { + // Skip all the actions if the stage has been cancelled. 
+ return + } val stage = stageIdToStage(task.stageId) def markStageAsFinished(stage: Stage) = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 07e8317e3a..c084059859 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -38,17 +38,17 @@ private[spark] object ResultTask { synchronized { val old = serializedInfoCache.get(stageId).orNull if (old != null) { - return old + old } else { val out = new ByteArrayOutputStream - val ser = SparkEnv.get.closureSerializer.newInstance + val ser = SparkEnv.get.closureSerializer.newInstance() val objOut = ser.serializeStream(new GZIPOutputStream(out)) objOut.writeObject(rdd) objOut.writeObject(func) objOut.close() val bytes = out.toByteArray serializedInfoCache.put(stageId, bytes) - return bytes + bytes } } } @@ -56,11 +56,11 @@ private[spark] object ResultTask { def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) = { val loader = Thread.currentThread.getContextClassLoader val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) - val ser = SparkEnv.get.closureSerializer.newInstance + val ser = SparkEnv.get.closureSerializer.newInstance() val objIn = ser.deserializeStream(in) val rdd = objIn.readObject().asInstanceOf[RDD[_]] val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _] - return (rdd, func) + (rdd, func) } def clearCache() { @@ -75,25 +75,20 @@ private[spark] class ResultTask[T, U]( stageId: Int, var rdd: RDD[T], var func: (TaskContext, Iterator[T]) => U, - var partition: Int, + _partition: Int, @transient locs: Seq[TaskLocation], var outputId: Int) - extends Task[U](stageId) with Externalizable { + extends Task[U](stageId, _partition) with Externalizable { def this() = this(0, null, null, 0, null, 0) - var split = if (rdd == null) { - null - } else { - rdd.partitions(partition) - } + var split = if (rdd == null) null else rdd.partitions(partition) @transient private val preferredLocs: Seq[TaskLocation] = { if (locs == null) Nil else locs.toSet.toSeq } - override def run(attemptId: Long): U = { - val context = new TaskContext(stageId, partition, attemptId, runningLocally = false) + override def runTask(context: TaskContext): U = { metrics = Some(context.taskMetrics) try { func(context, rdd.iterator(split, context)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index d23df0dd2b..1904ee89c6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -53,7 +53,7 @@ private[spark] object ShuffleMapTask { objOut.close() val bytes = out.toByteArray serializedInfoCache.put(stageId, bytes) - return bytes + bytes } } } @@ -66,7 +66,7 @@ private[spark] object ShuffleMapTask { val objIn = ser.deserializeStream(in) val rdd = objIn.readObject().asInstanceOf[RDD[_]] val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]] - return (rdd, dep) + (rdd, dep) } } @@ -75,7 +75,7 @@ private[spark] object ShuffleMapTask { val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) val objIn = new ObjectInputStream(in) val set = objIn.readObject().asInstanceOf[Array[(String, Long)]].toMap - return (HashMap(set.toSeq: _*)) + HashMap(set.toSeq: _*) } def clearCache() { @@ -89,9 +89,9 @@ 
private[spark] class ShuffleMapTask( stageId: Int, var rdd: RDD[_], var dep: ShuffleDependency[_,_], - var partition: Int, + _partition: Int, @transient private var locs: Seq[TaskLocation]) - extends Task[MapStatus](stageId) + extends Task[MapStatus](stageId, _partition) with Externalizable with Logging { @@ -129,11 +129,9 @@ private[spark] class ShuffleMapTask( split = in.readObject().asInstanceOf[Partition] } - override def run(attemptId: Long): MapStatus = { + override def runTask(context: TaskContext): MapStatus = { val numOutputSplits = dep.partitioner.numPartitions - - val taskContext = new TaskContext(stageId, partition, attemptId, runningLocally = false) - metrics = Some(taskContext.taskMetrics) + metrics = Some(context.taskMetrics) val blockManager = SparkEnv.get.blockManager var shuffle: ShuffleBlocks = null @@ -146,7 +144,7 @@ private[spark] class ShuffleMapTask( buckets = shuffle.acquireWriters(partition) // Write the map output to its associated buckets. - for (elem <- rdd.iterator(split, taskContext)) { + for (elem <- rdd.iterator(split, context)) { val pair = elem.asInstanceOf[Product2[Any, Any]] val bucketId = dep.partitioner.getPartition(pair._1) buckets.writers(bucketId).write(pair) @@ -167,7 +165,7 @@ private[spark] class ShuffleMapTask( shuffleMetrics.shuffleBytesWritten = totalBytes metrics.get.shuffleWriteMetrics = Some(shuffleMetrics) - return new MapStatus(blockManager.blockManagerId, compressedSizes) + new MapStatus(blockManager.blockManagerId, compressedSizes) } catch { case e: Exception => // If there is an exception from running the task, revert the partial writes // and throw the exception upstream to Spark. @@ -181,7 +179,7 @@ private[spark] class ShuffleMapTask( shuffle.releaseWriters(buckets) } // Execute the callbacks on task completion. - taskContext.executeOnCompleteCallbacks() + context.executeOnCompleteCallbacks() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 598d91752a..45945b64e8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -17,25 +17,59 @@ package org.apache.spark.scheduler -import org.apache.spark.serializer.SerializerInstance import java.io.{DataInputStream, DataOutputStream} import java.nio.ByteBuffer -import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream -import org.apache.spark.util.ByteBufferInputStream + import scala.collection.mutable.HashMap + +import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream + +import org.apache.spark.TaskContext import org.apache.spark.executor.TaskMetrics +import org.apache.spark.serializer.SerializerInstance +import org.apache.spark.util.ByteBufferInputStream + /** * A task to execute on a worker node. */ -private[spark] abstract class Task[T](val stageId: Int) extends Serializable { - def run(attemptId: Long): T +private[spark] abstract class Task[T](val stageId: Int, var partition: Int) extends Serializable { + + def run(attemptId: Long): T = { + context = new TaskContext(stageId, partition, attemptId, runningLocally = false) + if (killed) { + kill() + } + runTask(context) + } + + def runTask(context: TaskContext): T + def preferredLocations: Seq[TaskLocation] = Nil - var epoch: Long = -1 // Map output tracker epoch. Will be set by TaskScheduler. + // Map output tracker epoch. Will be set by TaskScheduler. + var epoch: Long = -1 var metrics: Option[TaskMetrics] = None + // Task context, to be initialized in run(). 
+ @transient protected var context: TaskContext = _ + + // A flag to indicate whether the task is killed. This is used in case context is not yet + // initialized when kill() is invoked. + @volatile @transient private var killed = false + + /** + * Kills a task by setting the interrupted flag to true. This relies on the upper level Spark + * code and user code to properly handle the flag. This function should be idempotent so it can + * be called multiple times. + */ + def kill() { + killed = true + if (context != null) { + context.interrupted = true + } + } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 63be8ba3f5..fe5d360fa9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -19,6 +19,7 @@ package org.apache.spark.scheduler import org.apache.spark.scheduler.cluster.Pool import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode + /** * Low-level task scheduler interface, implemented by both ClusterScheduler and LocalScheduler. * These schedulers get sets of tasks submitted to them from the DAGScheduler for each stage, @@ -44,6 +45,9 @@ private[spark] trait TaskScheduler { // Submit a sequence of tasks to run. def submitTasks(taskSet: TaskSet): Unit + // Kill the stage. + def killTasks(stageId: Int) + // Set a listener for upcalls. This is guaranteed to be set before submitTasks is called. def setListener(listener: TaskSchedulerListener): Unit diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala index c3ad325156..03bf760837 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala @@ -31,5 +31,9 @@ private[spark] class TaskSet( val properties: Properties) { val id: String = stageId + "." 
+ attempt + def kill() { + tasks.foreach(_.kill()) + } + override def toString: String = "TaskSet " + id } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala index 919acce828..51526164f2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala @@ -166,6 +166,20 @@ private[spark] class ClusterScheduler(val sc: SparkContext) backend.reviveOffers() } + override def killTasks(stageId: Int) { + synchronized { + schedulableBuilder.popTaskSetManagers(stageId).foreach { t => + val ts = t.asInstanceOf[TaskSetManager].taskSet + ts.kill() + val taskIds = taskSetTaskIds(ts.id) + taskIds.foreach { tid => + val execId = taskIdToExecutorId(tid) + backend.killTask(tid, execId) + } + } + } + } + def taskSetFinished(manager: TaskSetManager) { this.synchronized { activeTaskSets -= manager.taskSet.id diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulableBuilder.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulableBuilder.scala index f80823317b..523fb9642a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulableBuilder.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulableBuilder.scala @@ -32,8 +32,17 @@ import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode * addTaskSetManager: build the leaf nodes(TaskSetManagers) */ private[spark] trait SchedulableBuilder { + def rootPool: Pool + def buildPools() + def addTaskSetManager(manager: Schedulable, properties: Properties) + + def popTaskSetManagers(stageId: Int): Iterable[Schedulable] = { + val taskSets = rootPool.schedulableQueue.filter(_.stageId == stageId) + taskSets.foreach(rootPool.removeSchedulable) + taskSets + } } private[spark] class FIFOSchedulableBuilder(val rootPool: Pool) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala index d57eb3276f..c0578dcaa1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler.cluster -import org.apache.spark.{SparkContext} +import org.apache.spark.SparkContext /** * A backend interface for cluster scheduling systems that allows plugging in different ones under @@ -30,6 +30,10 @@ private[spark] trait SchedulerBackend { def reviveOffers(): Unit def defaultParallelism(): Int + def killTask(taskId: Long, executorId: String) { + throw new UnsupportedOperationException + } + // Memory used by each executor (in megabytes) protected val executorMemory: Int = SparkContext.executorMemoryRequested diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala index 9c36d221f6..73cee06c96 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala @@ -30,6 +30,8 @@ private[spark] object StandaloneClusterMessages { // Driver to executors case class LaunchTask(task: TaskDescription) extends StandaloneClusterMessage + case class KillTask(taskId: Long, 
executor: String) extends StandaloneClusterMessage + case class RegisteredExecutor(sparkProperties: Seq[(String, String)]) extends StandaloneClusterMessage diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index b4ea0be415..c999e0ab71 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -90,6 +90,9 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor case ReviveOffers => makeOffers() + case KillTask(taskId, executorId) => + executorActor(executorId) ! KillTask(taskId, executorId) + case StopDriver => sender ! true context.stop(self) @@ -179,6 +182,10 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor driverActor ! ReviveOffers } + override def killTask(taskId: Long, executorId: String) { + driverActor ! KillTask(taskId, executorId) + } + override def defaultParallelism() = Option(System.getProperty("spark.default.parallelism")) .map(_.toInt).getOrElse(math.max(totalCoreCount.get(), 2)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala index 8cb4d1396f..7dd0c7d743 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala @@ -128,6 +128,12 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: } } + override def killTasks(stageId: Int) = synchronized { + schedulableBuilder.popTaskSetManagers(stageId).foreach { + _.asInstanceOf[TaskSetManager].taskSet.kill() + } + } + def resourceOffer(freeCores: Int): Seq[TaskDescription] = { synchronized { var freeCpuCores = freeCores diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index 3a7171c488..f14e12df3d 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -58,7 +58,8 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } whenExecuting(blockManager) { - val context = new TaskContext(0, 0, 0, runningLocally = false, null) + val context = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false, + taskMetrics = null) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(1, 2, 3, 4)) } @@ -70,7 +71,8 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } whenExecuting(blockManager) { - val context = new TaskContext(0, 0, 0, runningLocally = false, null) + val context = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false, + taskMetrics = null) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(5, 6, 7)) } @@ -83,7 +85,8 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } whenExecuting(blockManager) { - val context = new TaskContext(0, 0, 0, runningLocally = true, null) + val context = new TaskContext(0, 0, 0, runningLocally = true, interrupted = false, + taskMetrics = null) val value = cacheManager.getOrCompute(rdd, split, context, 
StorageLevel.MEMORY_ONLY) assert(value.toList === List(1, 2, 3, 4)) } diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java index 591c1d498d..7b0bb89ab2 100644 --- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java @@ -495,7 +495,7 @@ public class JavaAPISuite implements Serializable { @Test public void iterator() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2); - TaskContext context = new TaskContext(0, 0, 0, false, null); + TaskContext context = new TaskContext(0, 0, 0, false, false, null); Assert.assertEquals(1, rdd.iterator(rdd.splits().get(0), context).next().intValue()); } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 94f66c94c6..180211cbda 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -29,12 +29,13 @@ import org.apache.spark.SparkContext import org.apache.spark.Partition import org.apache.spark.TaskContext import org.apache.spark.{Dependency, ShuffleDependency, OneToOneDependency} -import org.apache.spark.{FetchFailed, Success, TaskEndReason} +import org.apache.spark.{Success, TaskEndReason} import org.apache.spark.storage.{BlockManagerId, BlockManagerMaster} import org.apache.spark.scheduler.cluster.Pool import org.apache.spark.scheduler.cluster.SchedulingMode import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode +import org.apache.spark.FetchFailed /** * Tests for DAGScheduler. These tests directly call the event processing functions in DAGScheduler @@ -62,6 +63,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch) taskSets += taskSet } + override def killTasks(stageId: Int) {} override def setListener(listener: TaskSchedulerListener) = {} override def defaultParallelism() = 2 } diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala index 2f12aaed18..0f01515179 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala @@ -17,10 +17,11 @@ package org.apache.spark.scheduler.cluster +import org.apache.spark.TaskContext import org.apache.spark.scheduler.{TaskLocation, Task} -class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId) { - override def run(attemptId: Long): Int = 0 +class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0) { + override def runTask(context: TaskContext): Int = 0 override def preferredLocations: Seq[TaskLocation] = prefLocs }
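None of the hunks above show InterruptibleIterator being consumed, so the contract it sets up with TaskContext.interrupted is easy to miss: once the flag is set, hasNext simply reports exhaustion rather than throwing, and the consuming loop drains out at its next check. The stand-alone sketch below illustrates only that contract; the object name and the dummy stage/partition/attempt ids are invented for the example, and it sits in the org.apache.spark package solely so it can build a TaskContext directly, the way CacheManagerSuite does.

// This demo sits in the Spark package purely so it can construct a TaskContext directly,
// the way CacheManagerSuite does; normal user code never builds one by hand.
package org.apache.spark

object InterruptionContractDemo {
  def main(args: Array[String]) {
    // Dummy ids; on an executor the TaskContext is created inside Task.run().
    val context = new TaskContext(stageId = 0, splitId = 0, attemptId = 0)
    val rows = new InterruptibleIterator(context, Iterator.range(0, 1000000))

    var consumed = 0
    while (rows.hasNext) {          // hasNext re-reads context.interrupted on every call
      rows.next()
      consumed += 1
      if (consumed == 10) {
        context.interrupted = true  // this is exactly what Task.kill() does
      }
    }
    // Prints 10: once the flag flips, the iterator reports exhaustion instead of
    // throwing, so downstream code just sees a short partition.
    println("consumed " + consumed + " rows")
  }
}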
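The patch also never wires InterruptibleIterator into a concrete RDD, so the natural integration point is worth spelling out: an RDD's compute method can wrap whatever iterator it produces before handing it back. The RDD and partition classes below are invented for illustration and assume the 0.8-era RDD subclassing API (getPartitions/compute); they are not part of this change.

import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD

// Invented names for the sketch; the patch itself adds no such RDD.
case class SlicePartition(index: Int) extends Partition

class InterruptibleRangeRDD(sc: SparkContext, rowsPerSlice: Int, slices: Int)
  extends RDD[Int](sc, Nil) {

  override def getPartitions: Array[Partition] =
    Array.tabulate[Partition](slices)(i => SlicePartition(i))

  override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
    val i = split.index
    val raw = Iterator.range(i * rowsPerSlice, (i + 1) * rowsPerSlice)
    // Wrapping the raw iterator is all this RDD needs to do to honour kills: once
    // Task.kill() sets context.interrupted, hasNext reports exhaustion at the next check.
    new InterruptibleIterator(context, raw)
  }
}

In the real code base the wrapping more plausibly belongs in a shared path such as RDD.iterator or the shuffle fetch iterator, so that every RDD benefits without per-RDD changes; this diff leaves that choice open.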
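Finally, SparkContext.killJob(jobId) is the one new public entry point, but nothing above exercises it. The driver-side sketch below assumes the caller already knows the job id; the patch adds no API for discovering one, so the use of 0 for the first submitted job is an assumption of the sketch, not something the code guarantees.

import org.apache.spark.SparkContext

object KillJobDemo {
  def main(args: Array[String]) {
    val sc = new SparkContext("local[2]", "kill-job-demo")

    // Run a deliberately slow job in a background thread so the main thread is free
    // to issue the kill.
    val background = new Thread(new Runnable {
      def run() {
        try {
          sc.parallelize(1 to 100, 2).map { i => Thread.sleep(500); i }.count()
          println("job finished normally")
        } catch {
          // DAGScheduler.killJob fails the job's listener with SparkException("Job killed"),
          // which is what the blocked runJob call surfaces here.
          case e: Exception => println("job was killed: " + e.getMessage)
        }
      }
    })
    background.start()

    Thread.sleep(2000)
    // Job id 0 is assumed for the first job submitted on this context (see note above).
    sc.killJob(0)

    background.join()
    sc.stop()
  }
}

Killing a job marks it failed immediately and wakes the blocked runJob call, but tasks that are already executing only stop early at points where they consult context.interrupted (for example through an InterruptibleIterator); a task blocked in Thread.sleep like the one above runs to completion, and its result is then ignored because handleTaskCompletion bails out once the stage is gone from stageIdToStage.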