Initial commit for job killing.

commit cbc48be13b (parent 3443d3fd43)
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator[T])
+  extends Iterator[T] {
+
+  def hasNext: Boolean = !context.interrupted && delegate.hasNext
+
+  def next(): T = delegate.next()
+}
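The new `InterruptibleIterator` is the cooperative-cancellation hook the rest of this commit builds on: `Task.kill()` flips `TaskContext.interrupted`, and any computation that pulls records through this wrapper stops at its next `hasNext` check. A minimal standalone sketch of the idea (the classes below are illustrative stand-ins, not the commit's actual types):

```scala
// Illustrative stand-ins for TaskContext and InterruptibleIterator.
class FakeContext {
  @volatile var interrupted = false  // flipped by the killing thread
}

class Interruptible[T](ctx: FakeContext, delegate: Iterator[T]) extends Iterator[T] {
  // The flag is polled on every hasNext, so a kill takes effect at the next
  // record boundary instead of preempting work mid-record.
  def hasNext: Boolean = !ctx.interrupted && delegate.hasNext
  def next(): T = delegate.next()
}

object InterruptDemo extends App {
  val ctx = new FakeContext
  val it = new Interruptible(ctx, Iterator.from(1))
  println(it.take(3).toList)  // List(1, 2, 3)
  ctx.interrupted = true
  println(it.hasNext)         // false: no further records are produced
}
```

Because the flag is only polled at record boundaries, cancellation is prompt for record-at-a-time work but cannot interrupt a single long-running `next()`.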
@@ -812,6 +812,13 @@ class SparkContext(
     result
   }

+  /**
+   * Kill a running job.
+   */
+  def killJob(jobId: Int) {
+    dagScheduler.killJob(jobId)
+  }
+
   /**
    * Clean a closure to make it ready to be serialized and sent to tasks
    * (removes unreferenced variables in $outer's, updates REPL variables)
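A hypothetical driver-side use of the new API. The job id is the main assumption here: nothing in this commit returns the id of a submitted job, so the sketch assumes the first submitted job gets id 0, and the app name and RDD are made up for the demo:

```scala
// Illustrative sketch: run an action in a background thread, then kill its job.
import org.apache.spark.SparkContext

object KillJobDemo extends App {
  val sc = new SparkContext("local[2]", "kill-demo")
  val runner = new Thread(new Runnable {
    def run() {
      try {
        sc.parallelize(1 to 10000000).map(identity).count()
      } catch {
        // killJob fails the job's listener with SparkException("Job killed")
        case e: Exception => println("job ended: " + e.getMessage)
      }
    }
  })
  runner.start()
  Thread.sleep(1000)  // give the job time to be submitted
  sc.killJob(0)       // assumes the first submitted job has id 0
  runner.join()
  sc.stop()
}
```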
@@ -25,6 +25,7 @@ class TaskContext(
   val splitId: Int,
   val attemptId: Long,
   val runningLocally: Boolean = false,
+  @volatile var interrupted: Boolean = false,
   val taskMetrics: TaskMetrics = TaskMetrics.empty()
 ) extends Serializable {
@@ -17,7 +17,7 @@

 package org.apache.spark.executor

-import java.io.{File}
+import java.io.File
 import java.lang.management.ManagementFactory
 import java.nio.ByteBuffer
 import java.util.concurrent._
@@ -31,7 +31,8 @@ import org.apache.spark.util.Utils


 /**
- * The Mesos executor for Spark.
+ * The backend executor for Spark. The executor maintains a thread pool and uses it to execute
+ * tasks.
  */
 private[spark] class Executor(
     executorId: String,
@@ -101,18 +102,38 @@ private[spark] class Executor(
   val executorSource = new ExecutorSource(this, executorId)

   // Initialize Spark environment (using system properties read above)
-  val env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0, false, false)
+  private val env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0,
+    isDriver = false, isLocal = false)
   SparkEnv.set(env)
   env.metricsSystem.registerSource(executorSource)

-  private val akkaFrameSize = env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size")
+  private val akkaFrameSize = {
+    env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size")
+  }

   // Start worker thread pool
   val threadPool = new ThreadPoolExecutor(
     1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])

+  // Maintains the list of running tasks.
+  private val runningTasks = new ConcurrentHashMap[Long, TaskRunner]
+
   def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) {
-    threadPool.execute(new TaskRunner(context, taskId, serializedTask))
+    val task = new TaskRunner(context, taskId, serializedTask)
+    runningTasks.put(taskId, task)
+    threadPool.execute(task)
   }

+  def killTask(taskId: Long) {
+    val task = runningTasks.get(taskId)
+    if (task != null) {
+      task.kill()
+      // We also remove the task in the finally block of TaskRunner.run.
+      // The reason we need to remove it here as well is that killTask might be called before
+      // the task is even launched, in which case that finally block is never reached.
+      // ConcurrentHashMap's remove is idempotent.
+      runningTasks.remove(taskId)
+    }
+  }
+
   /** Get the Yarn approved local directories. */
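`launchTask`/`killTask` above form a small race-tolerant registry: register before submitting, kill through the registry, and remove in both the kill path and the runner's `finally`, leaning on `ConcurrentHashMap.remove` being idempotent. A self-contained sketch of the same pattern (all names invented for illustration):

```scala
import java.util.concurrent.{ConcurrentHashMap, Executors}

// Minimal registry of cancellable work items, mirroring launchTask/killTask.
class Registry {
  private val running = new ConcurrentHashMap[Long, Worker]
  private val pool = Executors.newCachedThreadPool()

  def launch(id: Long, body: () => Unit) {
    val w = new Worker(id, body)
    running.put(id, w)    // register before submitting, so kill() can find it
    pool.execute(w)
  }

  def kill(id: Long) {
    val w = running.get(id)
    if (w != null) {
      w.kill()
      running.remove(id)  // safe even if the worker's finally also removes it
    }
  }

  class Worker(id: Long, body: () => Unit) extends Runnable {
    @volatile private var killed = false
    def kill() { killed = true }
    def run() {
      try { if (!killed) body() }
      finally { running.remove(id) }  // idempotent with the remove in kill()
    }
  }
}
```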
@@ -124,15 +145,26 @@ private[spark] class Executor(
       .getOrElse(Option(System.getenv("LOCAL_DIRS"))
       .getOrElse(""))

-    if (localDirs.isEmpty()) {
+    if (localDirs.isEmpty) {
       throw new Exception("Yarn Local dirs can't be empty")
     }
-    return localDirs
+    localDirs
   }

   class TaskRunner(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer)
     extends Runnable {

+    @volatile private var killed = false
+    @volatile private var task: Task[Any] = _
+
+    def kill() {
+      // Note that there is a tiny possibility of racing here.
+      killed = true
+      if (task != null) {
+        task.kill()
+      }
+    }
+
     override def run() {
       val startTime = System.currentTimeMillis()
       SparkEnv.set(env)
@@ -150,22 +182,29 @@ private[spark] class Executor(
         Accumulators.clear()
         val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
         updateDependencies(taskFiles, taskJars)
-        val task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
+        task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
+
+        // If this task has been killed before we deserialized it, let's quit now. Otherwise,
+        // continue executing the task.
+        if (killed) {
+          context.statusUpdate(taskId, TaskState.KILLED, ser.serialize(""))
+        }
+
         attemptedTask = Some(task)
-        logInfo("Its epoch is " + task.epoch)
+        logDebug("Its epoch is " + task.epoch)
         env.mapOutputTracker.updateEpoch(task.epoch)
         taskStart = System.currentTimeMillis()
         val value = task.run(taskId.toInt)
         val taskFinish = System.currentTimeMillis()
         for (m <- task.metrics) {
-          m.hostname = Utils.localHostName
+          m.hostname = Utils.localHostName()
           m.executorDeserializeTime = (taskStart - startTime).toInt
           m.executorRunTime = (taskFinish - taskStart).toInt
           m.jvmGCTime = getTotalGCTime - startGCTime
         }
-        // TODO I'd also like to track the time it takes to serialize the task results, but that is a huge
-        // headache, b/c we need to serialize the task metrics first. If TaskMetrics had a custom
-        // serialized format, we could just change the relevant bytes in the byte buffer
+        // TODO I'd also like to track the time it takes to serialize the task results, but that
+        // is a huge headache, b/c we need to serialize the task metrics first. If TaskMetrics had
+        // a custom serialized format, we could just change the relevant bytes in the byte buffer.
         val accumUpdates = Accumulators.values
         val result = new TaskResult(value, accumUpdates, task.metrics.getOrElse(null))
         val serializedResult = ser.serialize(result)
@@ -198,6 +237,8 @@ private[spark] class Executor(
           logError("Exception in task ID " + taskId, t)
           //System.exit(1)
         }
+      } finally {
+        runningTasks.remove(taskId)
       }
     }
   }
@@ -207,7 +248,7 @@ private[spark] class Executor(
    * created by the interpreter to the search path
    */
   private def createClassLoader(): ExecutorURLClassLoader = {
-    var loader = this.getClass.getClassLoader
+    val loader = this.getClass.getClassLoader

     // For each of the jars in the jarSet, add them to the class loader.
     // We assume each of the files has already been fetched.
@@ -229,7 +270,7 @@ private[spark] class Executor(
         val klass = Class.forName("org.apache.spark.repl.ExecutorClassLoader")
           .asInstanceOf[Class[_ <: ClassLoader]]
         val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader])
-        return constructor.newInstance(classUri, parent)
+        constructor.newInstance(classUri, parent)
       } catch {
         case _: ClassNotFoundException =>
           logError("Could not find org.apache.spark.repl.ExecutorClassLoader on classpath!")
@@ -237,7 +278,7 @@ private[spark] class Executor(
           null
       }
     } else {
-      return parent
+      parent
     }
   }

@@ -69,6 +69,12 @@ private[spark] class StandaloneExecutorBackend(
         executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask)
       }

+    case KillTask(taskId, _) =>
+      logInfo("Kill task %s %s".format(taskId, executorId))
+      if (executor != null) {
+        executor.killTask(taskId)
+      }
+
     case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) =>
       logError("Driver terminated or disconnected! Shutting down.")
       System.exit(1)
@@ -128,6 +128,7 @@ class DAGScheduler(
   // stray messages to detect.
   val failedEpoch = new HashMap[String, Long]

   // stage id to the active job
+  val idToActiveJob = new HashMap[Int, ActiveJob]

   val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done
@@ -334,6 +335,35 @@ class DAGScheduler(
     listener.awaitResult() // Will throw an exception if the job fails
   }

+  def killJob(jobId: Int) {
+    activeJobs.find(job => job.jobId == jobId).foreach(job => killJob(job))
+  }
+
+  private def killJob(job: ActiveJob) {
+    logInfo("Killing job %d and cleaning up its stages".format(job.jobId))
+    activeJobs.remove(job)
+    idToActiveJob.remove(job.jobId)
+    val stage = job.finalStage
+    resultStageToJob.remove(stage)
+    killStage(stage)
+    // Recursively remove all parent stages.
+    stage.parents.foreach(p => killStage(p))
+    job.listener.jobFailed(new SparkException("Job killed"))
+  }
+
+  private def killStage(stage: Stage) {
+    logInfo("Killing stage %d".format(stage.id))
+    stageIdToStage.remove(stage.id)
+    if (stage.isShuffleMap) {
+      shuffleToMapStage.remove(stage.id)
+    }
+    waiting.remove(stage)
+    pendingTasks.remove(stage)
+    running.remove(stage)
+    taskSched.killTasks(stage.id)
+    stage.parents.foreach(p => killStage(p))
+  }
+
   /**
    * Process one event retrieved from the event queue.
    * Returns true if we should stop the event loop.
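`killJob` unwinds the job's DAG from the final stage upward, and `killStage` recurses through `stage.parents`. A toy model of that traversal (the `Stage` here is invented; the visited set is an addition of this sketch, whereas the commit's version simply relies on the map removals being idempotent when it revisits a shared parent):

```scala
object KillStageDemo extends App {
  import scala.collection.mutable

  // Toy model of the recursive stage cancellation in killJob/killStage.
  case class Stage(id: Int, parents: List[Stage])

  def killStage(stage: Stage, killed: mutable.Set[Int]) {
    if (killed.add(stage.id)) {            // skip stages already torn down
      println("killing stage " + stage.id) // stand-in for map removals + killTasks
      stage.parents.foreach(killStage(_, killed))
    }
  }

  val a = Stage(1, Nil)
  val b = Stage(2, List(a))
  val c = Stage(3, List(a, b))             // a diamond: stage 1 is a shared parent
  killStage(c, mutable.Set.empty[Int])     // prints stages 3, 1, 2, each exactly once
}
```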
@@ -579,6 +609,11 @@ class DAGScheduler(
    */
   private def handleTaskCompletion(event: CompletionEvent) {
     val task = event.task
+
+    if (!stageIdToStage.contains(task.stageId)) {
+      // Skip all the actions if the stage has been cancelled.
+      return
+    }
     val stage = stageIdToStage(task.stageId)

     def markStageAsFinished(stage: Stage) = {
@@ -38,17 +38,17 @@ private[spark] object ResultTask {
     synchronized {
       val old = serializedInfoCache.get(stageId).orNull
       if (old != null) {
-        return old
+        old
       } else {
         val out = new ByteArrayOutputStream
-        val ser = SparkEnv.get.closureSerializer.newInstance
+        val ser = SparkEnv.get.closureSerializer.newInstance()
         val objOut = ser.serializeStream(new GZIPOutputStream(out))
         objOut.writeObject(rdd)
         objOut.writeObject(func)
         objOut.close()
         val bytes = out.toByteArray
         serializedInfoCache.put(stageId, bytes)
-        return bytes
+        bytes
       }
     }
   }
@@ -56,11 +56,11 @@ private[spark] object ResultTask {
   def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) = {
     val loader = Thread.currentThread.getContextClassLoader
     val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
-    val ser = SparkEnv.get.closureSerializer.newInstance
+    val ser = SparkEnv.get.closureSerializer.newInstance()
     val objIn = ser.deserializeStream(in)
     val rdd = objIn.readObject().asInstanceOf[RDD[_]]
     val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _]
-    return (rdd, func)
+    (rdd, func)
   }

   def clearCache() {
@@ -75,25 +75,20 @@ private[spark] class ResultTask[T, U](
     stageId: Int,
     var rdd: RDD[T],
     var func: (TaskContext, Iterator[T]) => U,
-    var partition: Int,
+    _partition: Int,
     @transient locs: Seq[TaskLocation],
     var outputId: Int)
-  extends Task[U](stageId) with Externalizable {
+  extends Task[U](stageId, _partition) with Externalizable {

   def this() = this(0, null, null, 0, null, 0)

-  var split = if (rdd == null) {
-    null
-  } else {
-    rdd.partitions(partition)
-  }
+  var split = if (rdd == null) null else rdd.partitions(partition)

   @transient private val preferredLocs: Seq[TaskLocation] = {
     if (locs == null) Nil else locs.toSet.toSeq
   }

-  override def run(attemptId: Long): U = {
-    val context = new TaskContext(stageId, partition, attemptId, runningLocally = false)
+  override def runTask(context: TaskContext): U = {
     metrics = Some(context.taskMetrics)
     try {
       func(context, rdd.iterator(split, context))
@@ -53,7 +53,7 @@ private[spark] object ShuffleMapTask {
         objOut.close()
         val bytes = out.toByteArray
         serializedInfoCache.put(stageId, bytes)
-        return bytes
+        bytes
       }
     }
   }
@@ -66,7 +66,7 @@ private[spark] object ShuffleMapTask {
     val objIn = ser.deserializeStream(in)
     val rdd = objIn.readObject().asInstanceOf[RDD[_]]
     val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]]
-    return (rdd, dep)
+    (rdd, dep)
   }
 }

@@ -75,7 +75,7 @@ private[spark] object ShuffleMapTask {
     val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
     val objIn = new ObjectInputStream(in)
     val set = objIn.readObject().asInstanceOf[Array[(String, Long)]].toMap
-    return (HashMap(set.toSeq: _*))
+    HashMap(set.toSeq: _*)
   }

   def clearCache() {
@@ -89,9 +89,9 @@ private[spark] class ShuffleMapTask(
     stageId: Int,
     var rdd: RDD[_],
     var dep: ShuffleDependency[_,_],
-    var partition: Int,
+    _partition: Int,
     @transient private var locs: Seq[TaskLocation])
-  extends Task[MapStatus](stageId)
+  extends Task[MapStatus](stageId, _partition)
   with Externalizable
   with Logging {

@@ -129,11 +129,9 @@ private[spark] class ShuffleMapTask(
     split = in.readObject().asInstanceOf[Partition]
   }

-  override def run(attemptId: Long): MapStatus = {
+  override def runTask(context: TaskContext): MapStatus = {
     val numOutputSplits = dep.partitioner.numPartitions
-
-    val taskContext = new TaskContext(stageId, partition, attemptId, runningLocally = false)
-    metrics = Some(taskContext.taskMetrics)
+    metrics = Some(context.taskMetrics)

     val blockManager = SparkEnv.get.blockManager
     var shuffle: ShuffleBlocks = null
@@ -146,7 +144,7 @@ private[spark] class ShuffleMapTask(
       buckets = shuffle.acquireWriters(partition)

       // Write the map output to its associated buckets.
-      for (elem <- rdd.iterator(split, taskContext)) {
+      for (elem <- rdd.iterator(split, context)) {
         val pair = elem.asInstanceOf[Product2[Any, Any]]
         val bucketId = dep.partitioner.getPartition(pair._1)
         buckets.writers(bucketId).write(pair)
@@ -167,7 +165,7 @@ private[spark] class ShuffleMapTask(
       shuffleMetrics.shuffleBytesWritten = totalBytes
       metrics.get.shuffleWriteMetrics = Some(shuffleMetrics)

-      return new MapStatus(blockManager.blockManagerId, compressedSizes)
+      new MapStatus(blockManager.blockManagerId, compressedSizes)
     } catch { case e: Exception =>
       // If there is an exception from running the task, revert the partial writes
       // and throw the exception upstream to Spark.
@@ -181,7 +179,7 @@ private[spark] class ShuffleMapTask(
         shuffle.releaseWriters(buckets)
       }
       // Execute the callbacks on task completion.
-      taskContext.executeOnCompleteCallbacks()
+      context.executeOnCompleteCallbacks()
     }
   }

@@ -17,25 +17,59 @@

 package org.apache.spark.scheduler

-import org.apache.spark.serializer.SerializerInstance
 import java.io.{DataInputStream, DataOutputStream}
 import java.nio.ByteBuffer
-import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
-import org.apache.spark.util.ByteBufferInputStream

+import scala.collection.mutable.HashMap
+
+import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
+
+import org.apache.spark.TaskContext
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.serializer.SerializerInstance
+import org.apache.spark.util.ByteBufferInputStream
+
+
 /**
  * A task to execute on a worker node.
  */
-private[spark] abstract class Task[T](val stageId: Int) extends Serializable {
-  def run(attemptId: Long): T
+private[spark] abstract class Task[T](val stageId: Int, var partition: Int) extends Serializable {
+
+  def run(attemptId: Long): T = {
+    context = new TaskContext(stageId, partition, attemptId, runningLocally = false)
+    if (killed) {
+      kill()
+    }
+    runTask(context)
+  }
+
+  def runTask(context: TaskContext): T
+
   def preferredLocations: Seq[TaskLocation] = Nil

-  var epoch: Long = -1 // Map output tracker epoch. Will be set by TaskScheduler.
+  // Map output tracker epoch. Will be set by TaskScheduler.
+  var epoch: Long = -1

   var metrics: Option[TaskMetrics] = None
+
+  // Task context, to be initialized in run().
+  @transient protected var context: TaskContext = _
+
+  // A flag to indicate whether the task is killed. This is used in case context is not yet
+  // initialized when kill() is invoked.
+  @volatile @transient private var killed = false
+
+  /**
+   * Kills a task by setting the interrupted flag to true. This relies on the upper level Spark
+   * code and user code to properly handle the flag. This function should be idempotent so it can
+   * be called multiple times.
+   */
+  def kill() {
+    killed = true
+    if (context != null) {
+      context.interrupted = true
+    }
+  }
 }

 /**
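`Task.kill()` has to work even when `run()` has not yet created the `TaskContext`, which is why the class keeps a `@volatile` `killed` flag and re-applies it once the context exists. A simplified standalone model of that handoff (types invented for illustration):

```scala
// Simplified model of Task's killed-flag handoff to a lazily created context.
object KillFlagDemo extends App {
  class Ctx { @volatile var interrupted = false }

  class Work {
    @volatile private var killed = false
    private var ctx: Ctx = _

    // Idempotent: safe to call before, during, or after run().
    def kill() {
      killed = true
      if (ctx != null) ctx.interrupted = true
    }

    def run() {
      ctx = new Ctx
      if (killed) kill()  // a kill that raced with startup is applied here
      if (ctx.interrupted) println("killed before launch; not running")
      else println("running task body")
    }
  }

  val w = new Work
  w.kill()  // kill arrives before run() has created the context
  w.run()   // prints "killed before launch; not running"
}
```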
@@ -19,6 +19,7 @@ package org.apache.spark.scheduler

 import org.apache.spark.scheduler.cluster.Pool
 import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode
+
 /**
  * Low-level task scheduler interface, implemented by both ClusterScheduler and LocalScheduler.
  * These schedulers get sets of tasks submitted to them from the DAGScheduler for each stage,
@@ -44,6 +45,9 @@ private[spark] trait TaskScheduler {
   // Submit a sequence of tasks to run.
   def submitTasks(taskSet: TaskSet): Unit

+  // Kill all the tasks in a given stage.
+  def killTasks(stageId: Int)
+
   // Set a listener for upcalls. This is guaranteed to be set before submitTasks is called.
   def setListener(listener: TaskSchedulerListener): Unit
@@ -31,5 +31,9 @@ private[spark] class TaskSet(
     val properties: Properties) {
   val id: String = stageId + "." + attempt

+  def kill() {
+    tasks.foreach(_.kill())
+  }
+
   override def toString: String = "TaskSet " + id
 }
@@ -166,6 +166,20 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
       backend.reviveOffers()
     }

+  override def killTasks(stageId: Int) {
+    synchronized {
+      schedulableBuilder.popTaskSetManagers(stageId).foreach { t =>
+        val ts = t.asInstanceOf[TaskSetManager].taskSet
+        ts.kill()
+        val taskIds = taskSetTaskIds(ts.id)
+        taskIds.foreach { tid =>
+          val execId = taskIdToExecutorId(tid)
+          backend.killTask(tid, execId)
+        }
+      }
+    }
+  }
+
   def taskSetFinished(manager: TaskSetManager) {
     this.synchronized {
       activeTaskSets -= manager.taskSet.id
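Putting the pieces together, a kill now flows driver, then `DAGScheduler.killJob`, then `killStage`, then `TaskScheduler.killTasks`, then `SchedulerBackend.killTask`, then the `KillTask` actor message, and finally `Executor.killTask`. A self-contained schematic of the scheduler-to-backend hop with simplified stand-in types (not compilable against Spark itself):

```scala
// Schematic of the kill-propagation chain; all types are invented stand-ins.
object KillChainDemo extends App {
  trait Backend { def killTask(taskId: Long, execId: String): Unit }

  // Mirrors ClusterScheduler.killTasks: look up each task's executor, delegate.
  class Sched(backend: Backend, assignment: Map[Long, String]) {
    def killTasks(taskIds: Seq[Long]) {
      taskIds.foreach(tid => backend.killTask(tid, assignment(tid)))
    }
  }

  val backend = new Backend {
    // In the real code, StandaloneSchedulerBackend sends KillTask(taskId, execId)
    // through the driver actor to the executor, which calls Executor.killTask.
    def killTask(taskId: Long, execId: String) {
      println("KillTask(" + taskId + ") -> executor " + execId)
    }
  }

  new Sched(backend, Map(1L -> "exec-0", 2L -> "exec-1")).killTasks(Seq(1L, 2L))
}
```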
@@ -32,8 +32,17 @@ import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode
  * addTaskSetManager: build the leaf nodes (TaskSetManagers)
  */
 private[spark] trait SchedulableBuilder {
+  def rootPool: Pool
+
   def buildPools()
+
   def addTaskSetManager(manager: Schedulable, properties: Properties)
+
+  def popTaskSetManagers(stageId: Int): Iterable[Schedulable] = {
+    val taskSets = rootPool.schedulableQueue.filter(_.stageId == stageId)
+    taskSets.foreach(rootPool.removeSchedulable)
+    taskSets
+  }
 }

 private[spark] class FIFOSchedulableBuilder(val rootPool: Pool)
@@ -17,7 +17,7 @@

 package org.apache.spark.scheduler.cluster

-import org.apache.spark.{SparkContext}
+import org.apache.spark.SparkContext

 /**
  * A backend interface for cluster scheduling systems that allows plugging in different ones under
@@ -30,6 +30,10 @@ private[spark] trait SchedulerBackend {
   def reviveOffers(): Unit
   def defaultParallelism(): Int

+  def killTask(taskId: Long, executorId: String) {
+    throw new UnsupportedOperationException
+  }
+
   // Memory used by each executor (in megabytes)
   protected val executorMemory: Int = SparkContext.executorMemoryRequested

@@ -30,6 +30,8 @@ private[spark] object StandaloneClusterMessages {
   // Driver to executors
   case class LaunchTask(task: TaskDescription) extends StandaloneClusterMessage

+  case class KillTask(taskId: Long, executor: String) extends StandaloneClusterMessage
+
   case class RegisteredExecutor(sparkProperties: Seq[(String, String)])
     extends StandaloneClusterMessage

@@ -90,6 +90,9 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: ActorSystem)
       case ReviveOffers =>
         makeOffers()

+      case KillTask(taskId, executorId) =>
+        executorActor(executorId) ! KillTask(taskId, executorId)
+
       case StopDriver =>
         sender ! true
         context.stop(self)
@@ -179,6 +182,10 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: ActorSystem)
     driverActor ! ReviveOffers
   }

+  override def killTask(taskId: Long, executorId: String) {
+    driverActor ! KillTask(taskId, executorId)
+  }
+
   override def defaultParallelism() = Option(System.getProperty("spark.default.parallelism"))
     .map(_.toInt).getOrElse(math.max(totalCoreCount.get(), 2))

@@ -128,6 +128,12 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext)
     }
   }

+  override def killTasks(stageId: Int) = synchronized {
+    schedulableBuilder.popTaskSetManagers(stageId).foreach {
+      _.asInstanceOf[TaskSetManager].taskSet.kill()
+    }
+  }
+
   def resourceOffer(freeCores: Int): Seq[TaskDescription] = {
     synchronized {
       var freeCpuCores = freeCores
@@ -58,7 +58,8 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar {
     }

     whenExecuting(blockManager) {
-      val context = new TaskContext(0, 0, 0, runningLocally = false, null)
+      val context = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false,
+        taskMetrics = null)
       val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
       assert(value.toList === List(1, 2, 3, 4))
     }
@@ -70,7 +71,8 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar {
     }

     whenExecuting(blockManager) {
-      val context = new TaskContext(0, 0, 0, runningLocally = false, null)
+      val context = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false,
+        taskMetrics = null)
       val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
       assert(value.toList === List(5, 6, 7))
     }
@@ -83,7 +85,8 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar {
     }

     whenExecuting(blockManager) {
-      val context = new TaskContext(0, 0, 0, runningLocally = true, null)
+      val context = new TaskContext(0, 0, 0, runningLocally = true, interrupted = false,
+        taskMetrics = null)
       val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
       assert(value.toList === List(1, 2, 3, 4))
     }
@@ -495,7 +495,7 @@ public class JavaAPISuite implements Serializable {
   @Test
   public void iterator() {
     JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
-    TaskContext context = new TaskContext(0, 0, 0, false, null);
+    TaskContext context = new TaskContext(0, 0, 0, false, false, null);
     Assert.assertEquals(1, rdd.iterator(rdd.splits().get(0), context).next().intValue());
   }

@@ -29,12 +29,13 @@ import org.apache.spark.SparkContext
 import org.apache.spark.Partition
 import org.apache.spark.TaskContext
 import org.apache.spark.{Dependency, ShuffleDependency, OneToOneDependency}
-import org.apache.spark.{FetchFailed, Success, TaskEndReason}
+import org.apache.spark.{Success, TaskEndReason}
 import org.apache.spark.storage.{BlockManagerId, BlockManagerMaster}

 import org.apache.spark.scheduler.cluster.Pool
 import org.apache.spark.scheduler.cluster.SchedulingMode
 import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode
+import org.apache.spark.FetchFailed

 /**
  * Tests for DAGScheduler. These tests directly call the event processing functions in DAGScheduler
@@ -62,6 +63,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
       taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch)
       taskSets += taskSet
     }
+    override def killTasks(stageId: Int) {}
     override def setListener(listener: TaskSchedulerListener) = {}
     override def defaultParallelism() = 2
   }
@@ -17,10 +17,11 @@

 package org.apache.spark.scheduler.cluster

+import org.apache.spark.TaskContext
 import org.apache.spark.scheduler.{TaskLocation, Task}

-class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId) {
-  override def run(attemptId: Long): Int = 0
+class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0) {
+  override def runTask(context: TaskContext): Int = 0

   override def preferredLocations: Seq[TaskLocation] = prefLocs
 }