Initial commit for job killing.
parent 3443d3fd43
commit cbc48be13b
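This commit threads a kill path through the whole stack: SparkContext.killJob delegates to DAGScheduler.killJob, which drops the job's stages and asks the TaskScheduler to kill their tasks; the standalone backend forwards a KillTask message to the executor, whose TaskRunner marks the Task killed and sets TaskContext.interrupted; the new InterruptibleIterator then stops yielding records. A hedged sketch of the driver-side call this adds (how a caller obtains the integer job id is not part of this diff, so the id below is a placeholder):

    // Sketch only; assumes a job is already running and that its id is known by other means.
    val sc = new SparkContext("local[2]", "kill-demo")
    val someJobId = 0            // placeholder, not provided by any API in this commit
    sc.killJob(someJobId)        // added in the SparkContext hunk below; forwards to DAGScheduler.killJob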
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator[T])
+  extends Iterator[T] {
+
+  def hasNext: Boolean = !context.interrupted && delegate.hasNext
+
+  def next(): T = delegate.next()
+}
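InterruptibleIterator only consults context.interrupted in hasNext, so a kill takes effect between records rather than preempting user code mid-element. A minimal sketch, assuming the usual RDD plumbing (the record source below is made up for illustration, not part of this commit), of how a compute implementation could wrap its raw iterator:

    import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
    import org.apache.spark.rdd.RDD

    class SketchRDD(sc: SparkContext) extends RDD[Int](sc, Nil) {
      override def getPartitions: Array[Partition] =
        Array(new Partition { override val index = 0 })

      override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
        val raw = Iterator.range(0, 1000000)   // stands in for real record reading
        // Once the task is killed, hasNext returns false and the consumer drains no more records.
        new InterruptibleIterator(context, raw)
      }
    }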
@@ -812,6 +812,13 @@ class SparkContext(
     result
   }
 
+  /**
+   * Kill a running job.
+   */
+  def killJob(jobId: Int) {
+    dagScheduler.killJob(jobId)
+  }
+
   /**
    * Clean a closure to make it ready to serialized and send to tasks
    * (removes unreferenced variables in $outer's, updates REPL variables)
@@ -25,6 +25,7 @@ class TaskContext(
   val splitId: Int,
   val attemptId: Long,
   val runningLocally: Boolean = false,
+  @volatile var interrupted: Boolean = false,
   val taskMetrics: TaskMetrics = TaskMetrics.empty()
 ) extends Serializable {
 
@@ -17,7 +17,7 @@
 
 package org.apache.spark.executor
 
-import java.io.{File}
+import java.io.File
 import java.lang.management.ManagementFactory
 import java.nio.ByteBuffer
 import java.util.concurrent._

@@ -31,7 +31,8 @@ import org.apache.spark.util.Utils
 
 /**
- * The Mesos executor for Spark.
+ * The backend executor for Spark. The executor maintains a thread pool and uses it to execute
+ * tasks.
  */
 private[spark] class Executor(
     executorId: String,

@@ -101,18 +102,38 @@ private[spark] class Executor(
   val executorSource = new ExecutorSource(this, executorId)
 
   // Initialize Spark environment (using system properties read above)
-  val env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0, false, false)
+  private val env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0,
+    isDriver = false, isLocal = false)
   SparkEnv.set(env)
   env.metricsSystem.registerSource(executorSource)
 
-  private val akkaFrameSize = env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size")
+  private val akkaFrameSize = {
+    env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size")
+  }
 
   // Start worker thread pool
   val threadPool = new ThreadPoolExecutor(
     1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])
 
+  // Maintains the list of running tasks.
+  private val runningTasks = new ConcurrentHashMap[Long, TaskRunner]
+
   def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) {
-    threadPool.execute(new TaskRunner(context, taskId, serializedTask))
+    val task = new TaskRunner(context, taskId, serializedTask)
+    runningTasks.put(taskId, task)
+    threadPool.execute(task)
   }
 
+  def killTask(taskId: Long) {
+    val task = runningTasks.get(taskId)
+    if (task != null) {
+      task.kill()
+      // We remove the task also in the finally block in TaskRunner.run.
+      // The reason we need to remove it here is because killTask might be called before the task
+      // is even launched, and never reaching that finally block. ConcurrentHashMap's remove is
+      // idempotent.
+      runningTasks.remove(taskId)
+    }
+  }
+
   /** Get the Yarn approved local directories. */
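The comment in killTask leans on the fact that removing an absent key from a ConcurrentHashMap is a harmless no-op, since the same task id may be removed both there and in the finally block added to TaskRunner.run further below. A tiny standalone illustration of that property (not Spark code):

    import java.util.concurrent.ConcurrentHashMap

    object RemoveTwice {
      def main(args: Array[String]) {
        val m = new ConcurrentHashMap[Long, String]
        m.put(1L, "task-1")
        m.remove(1L)           // removes the entry
        m.remove(1L)           // second removal does nothing and throws nothing
        println(m.isEmpty)     // prints: true
      }
    }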
@@ -124,15 +145,26 @@ private[spark] class Executor(
       .getOrElse(Option(System.getenv("LOCAL_DIRS"))
       .getOrElse(""))
 
-    if (localDirs.isEmpty()) {
+    if (localDirs.isEmpty) {
       throw new Exception("Yarn Local dirs can't be empty")
     }
-    return localDirs
+    localDirs
   }
 
   class TaskRunner(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer)
     extends Runnable {
 
+    @volatile private var killed = false
+    @volatile private var task: Task[Any] = _
+
+    def kill() {
+      // Note that there is a tiny possibiliy of raising here.
+      killed = true
+      if (task != null) {
+        task.kill()
+      }
+    }
+
     override def run() {
       val startTime = System.currentTimeMillis()
       SparkEnv.set(env)

@@ -150,22 +182,29 @@ private[spark] class Executor(
         Accumulators.clear()
         val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
         updateDependencies(taskFiles, taskJars)
-        val task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
+        task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
+
+        // If this task has been killed before we deserialized it, let's quit now. Otherwise,
+        // continue executing the task.
+        if (killed) {
+          context.statusUpdate(taskId, TaskState.KILLED, ser.serialize(""))
+        }
+
         attemptedTask = Some(task)
-        logInfo("Its epoch is " + task.epoch)
+        logDebug("Its epoch is " + task.epoch)
         env.mapOutputTracker.updateEpoch(task.epoch)
         taskStart = System.currentTimeMillis()
         val value = task.run(taskId.toInt)
         val taskFinish = System.currentTimeMillis()
         for (m <- task.metrics) {
-          m.hostname = Utils.localHostName
+          m.hostname = Utils.localHostName()
           m.executorDeserializeTime = (taskStart - startTime).toInt
           m.executorRunTime = (taskFinish - taskStart).toInt
           m.jvmGCTime = getTotalGCTime - startGCTime
         }
-        //TODO I'd also like to track the time it takes to serialize the task results, but that is huge headache, b/c
-        // we need to serialize the task metrics first. If TaskMetrics had a custom serialized format, we could
-        // just change the relevants bytes in the byte buffer
+        // TODO I'd also like to track the time it takes to serialize the task results, but that is
+        // huge headache, b/c we need to serialize the task metrics first. If TaskMetrics had a
+        // custom serialized format, we could just change the relevants bytes in the byte buffer
         val accumUpdates = Accumulators.values
         val result = new TaskResult(value, accumUpdates, task.metrics.getOrElse(null))
         val serializedResult = ser.serialize(result)

@@ -198,6 +237,8 @@ private[spark] class Executor(
           logError("Exception in task ID " + taskId, t)
           //System.exit(1)
         }
+      } finally {
+        runningTasks.remove(taskId)
       }
     }
   }

@@ -207,7 +248,7 @@ private[spark] class Executor(
    * created by the interpreter to the search path
    */
   private def createClassLoader(): ExecutorURLClassLoader = {
-    var loader = this.getClass.getClassLoader
+    val loader = this.getClass.getClassLoader
 
     // For each of the jars in the jarSet, add them to the class loader.
     // We assume each of the files has already been fetched.

@@ -229,7 +270,7 @@ private[spark] class Executor(
         val klass = Class.forName("org.apache.spark.repl.ExecutorClassLoader")
           .asInstanceOf[Class[_ <: ClassLoader]]
         val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader])
-        return constructor.newInstance(classUri, parent)
+        constructor.newInstance(classUri, parent)
       } catch {
         case _: ClassNotFoundException =>
           logError("Could not find org.apache.spark.repl.ExecutorClassLoader on classpath!")

@@ -237,7 +278,7 @@ private[spark] class Executor(
           null
       }
     } else {
-      return parent
+      parent
     }
   }
 
@@ -69,6 +69,12 @@ private[spark] class StandaloneExecutorBackend(
         executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask)
       }
 
+    case KillTask(taskId, _) =>
+      logInfo("Kill task %s %s".format(taskId, executorId))
+      if (executor != null) {
+        executor.killTask(taskId)
+      }
+
     case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) =>
       logError("Driver terminated or disconnected! Shutting down.")
       System.exit(1)
@@ -128,6 +128,7 @@ class DAGScheduler(
   // stray messages to detect.
   val failedEpoch = new HashMap[String, Long]
 
+  // stage id to the active job
   val idToActiveJob = new HashMap[Int, ActiveJob]
 
   val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done

@@ -334,6 +335,35 @@ class DAGScheduler(
     listener.awaitResult()    // Will throw an exception if the job fails
   }
 
+  def killJob(jobId: Int) {
+    activeJobs.find(job => job.jobId == jobId).foreach(job => killJob(job))
+  }
+
+  private def killJob(job: ActiveJob) {
+    logInfo("Killing Job and cleaning up stages %d".format(job.jobId))
+    activeJobs.remove(job)
+    idToActiveJob.remove(job.jobId)
+    val stage = job.finalStage
+    resultStageToJob.remove(stage)
+    killStage(stage)
+    // recursively remove all parent stages
+    stage.parents.foreach(p => killStage(p))
+    job.listener.jobFailed(new SparkException("Job killed"))
+  }
+
+  private def killStage(stage: Stage) {
+    logInfo("Killing Stage %s".format(stage.id))
+    stageIdToStage.remove(stage.id)
+    if (stage.isShuffleMap) {
+      shuffleToMapStage.remove(stage.id)
+    }
+    waiting.remove(stage)
+    pendingTasks.remove(stage)
+    running.remove(stage)
+    taskSched.killTasks(stage.id)
+    stage.parents.foreach(p => killStage(p))
+  }
+
   /**
    * Process one event retrieved from the event queue.
    * Returns true if we should stop the event loop.
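Because killJob fails the job's listener with SparkException("Job killed"), a driver thread blocked in runJob should see that failure surface as an exception. A hedged sketch of what a caller might observe (the exact propagation goes through JobWaiter, which this diff does not touch):

    import org.apache.spark.{SparkContext, SparkException}

    val sc = new SparkContext("local[2]", "kill-demo")
    val rdd = sc.parallelize(1 to 1000, 4)
    try {
      sc.runJob(rdd, (it: Iterator[Int]) => it.sum)   // blocks until the job finishes or fails
    } catch {
      case e: SparkException if e.getMessage.contains("Job killed") =>
        println("job was killed from another thread")
    }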
@@ -579,6 +609,11 @@ class DAGScheduler(
    */
   private def handleTaskCompletion(event: CompletionEvent) {
     val task = event.task
+
+    if (!stageIdToStage.contains(task.stageId)) {
+      // Skip all the actions if the stage has been cancelled.
+      return
+    }
     val stage = stageIdToStage(task.stageId)
 
     def markStageAsFinished(stage: Stage) = {
@@ -38,17 +38,17 @@ private[spark] object ResultTask {
     synchronized {
       val old = serializedInfoCache.get(stageId).orNull
       if (old != null) {
-        return old
+        old
       } else {
         val out = new ByteArrayOutputStream
-        val ser = SparkEnv.get.closureSerializer.newInstance
+        val ser = SparkEnv.get.closureSerializer.newInstance()
         val objOut = ser.serializeStream(new GZIPOutputStream(out))
         objOut.writeObject(rdd)
         objOut.writeObject(func)
         objOut.close()
         val bytes = out.toByteArray
         serializedInfoCache.put(stageId, bytes)
-        return bytes
+        bytes
       }
     }
   }

@@ -56,11 +56,11 @@ private[spark] object ResultTask {
   def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) = {
     val loader = Thread.currentThread.getContextClassLoader
     val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
-    val ser = SparkEnv.get.closureSerializer.newInstance
+    val ser = SparkEnv.get.closureSerializer.newInstance()
     val objIn = ser.deserializeStream(in)
     val rdd = objIn.readObject().asInstanceOf[RDD[_]]
     val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _]
-    return (rdd, func)
+    (rdd, func)
   }
 
   def clearCache() {

@@ -75,25 +75,20 @@ private[spark] class ResultTask[T, U](
     stageId: Int,
     var rdd: RDD[T],
     var func: (TaskContext, Iterator[T]) => U,
-    var partition: Int,
+    _partition: Int,
     @transient locs: Seq[TaskLocation],
     var outputId: Int)
-  extends Task[U](stageId) with Externalizable {
+  extends Task[U](stageId, _partition) with Externalizable {
 
   def this() = this(0, null, null, 0, null, 0)
 
-  var split = if (rdd == null) {
-    null
-  } else {
-    rdd.partitions(partition)
-  }
+  var split = if (rdd == null) null else rdd.partitions(partition)
 
   @transient private val preferredLocs: Seq[TaskLocation] = {
     if (locs == null) Nil else locs.toSet.toSeq
   }
 
-  override def run(attemptId: Long): U = {
-    val context = new TaskContext(stageId, partition, attemptId, runningLocally = false)
+  override def runTask(context: TaskContext): U = {
     metrics = Some(context.taskMetrics)
     try {
       func(context, rdd.iterator(split, context))
@@ -53,7 +53,7 @@ private[spark] object ShuffleMapTask {
       objOut.close()
       val bytes = out.toByteArray
       serializedInfoCache.put(stageId, bytes)
-      return bytes
+      bytes
     }
   }
 }

@@ -66,7 +66,7 @@ private[spark] object ShuffleMapTask {
     val objIn = ser.deserializeStream(in)
     val rdd = objIn.readObject().asInstanceOf[RDD[_]]
     val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]]
-    return (rdd, dep)
+    (rdd, dep)
   }
 }
 

@@ -75,7 +75,7 @@ private[spark] object ShuffleMapTask {
     val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
     val objIn = new ObjectInputStream(in)
     val set = objIn.readObject().asInstanceOf[Array[(String, Long)]].toMap
-    return (HashMap(set.toSeq: _*))
+    HashMap(set.toSeq: _*)
   }
 
   def clearCache() {

@@ -89,9 +89,9 @@ private[spark] class ShuffleMapTask(
     stageId: Int,
     var rdd: RDD[_],
     var dep: ShuffleDependency[_,_],
-    var partition: Int,
+    _partition: Int,
     @transient private var locs: Seq[TaskLocation])
-  extends Task[MapStatus](stageId)
+  extends Task[MapStatus](stageId, _partition)
   with Externalizable
   with Logging {
 

@@ -129,11 +129,9 @@ private[spark] class ShuffleMapTask(
     split = in.readObject().asInstanceOf[Partition]
   }
 
-  override def run(attemptId: Long): MapStatus = {
+  override def runTask(context: TaskContext): MapStatus = {
     val numOutputSplits = dep.partitioner.numPartitions
-
-    val taskContext = new TaskContext(stageId, partition, attemptId, runningLocally = false)
-    metrics = Some(taskContext.taskMetrics)
+    metrics = Some(context.taskMetrics)
 
     val blockManager = SparkEnv.get.blockManager
     var shuffle: ShuffleBlocks = null

@@ -146,7 +144,7 @@ private[spark] class ShuffleMapTask(
       buckets = shuffle.acquireWriters(partition)
 
       // Write the map output to its associated buckets.
-      for (elem <- rdd.iterator(split, taskContext)) {
+      for (elem <- rdd.iterator(split, context)) {
         val pair = elem.asInstanceOf[Product2[Any, Any]]
         val bucketId = dep.partitioner.getPartition(pair._1)
         buckets.writers(bucketId).write(pair)

@@ -167,7 +165,7 @@ private[spark] class ShuffleMapTask(
       shuffleMetrics.shuffleBytesWritten = totalBytes
       metrics.get.shuffleWriteMetrics = Some(shuffleMetrics)
 
-      return new MapStatus(blockManager.blockManagerId, compressedSizes)
+      new MapStatus(blockManager.blockManagerId, compressedSizes)
     } catch { case e: Exception =>
       // If there is an exception from running the task, revert the partial writes
       // and throw the exception upstream to Spark.

@@ -181,7 +179,7 @@ private[spark] class ShuffleMapTask(
         shuffle.releaseWriters(buckets)
       }
       // Execute the callbacks on task completion.
-      taskContext.executeOnCompleteCallbacks()
+      context.executeOnCompleteCallbacks()
     }
   }
 
@@ -17,25 +17,59 @@
 
 package org.apache.spark.scheduler
 
-import org.apache.spark.serializer.SerializerInstance
 import java.io.{DataInputStream, DataOutputStream}
 import java.nio.ByteBuffer
-import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
-import org.apache.spark.util.ByteBufferInputStream
+
 import scala.collection.mutable.HashMap
+
+import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
+
+import org.apache.spark.TaskContext
 import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.serializer.SerializerInstance
+import org.apache.spark.util.ByteBufferInputStream
 
 /**
  * A task to execute on a worker node.
  */
-private[spark] abstract class Task[T](val stageId: Int) extends Serializable {
-  def run(attemptId: Long): T
+private[spark] abstract class Task[T](val stageId: Int, var partition: Int) extends Serializable {
+
+  def run(attemptId: Long): T = {
+    context = new TaskContext(stageId, partition, attemptId, runningLocally = false)
+    if (killed) {
+      kill()
+    }
+    runTask(context)
+  }
+
+  def runTask(context: TaskContext): T
 
   def preferredLocations: Seq[TaskLocation] = Nil
 
-  var epoch: Long = -1   // Map output tracker epoch. Will be set by TaskScheduler.
+  // Map output tracker epoch. Will be set by TaskScheduler.
+  var epoch: Long = -1
 
   var metrics: Option[TaskMetrics] = None
+
+  // Task context, to be initialized in run().
+  @transient protected var context: TaskContext = _
+
+  // A flag to indicate whether the task is killed. This is used in case context is not yet
+  // initialized when kill() is invoked.
+  @volatile @transient private var killed = false
+
+  /**
+   * Kills a task by setting the interrupted flag to true. This relies on the upper level Spark
+   * code and user code to properly handle the flag. This function should be idempotent so it can
+   * be called multiple times.
+   */
+  def kill() {
+    killed = true
+    if (context != null) {
+      context.interrupted = true
+    }
+  }
 }
 
 /**
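Under the reworked base class, subclasses implement runTask(context) instead of run(attemptId): run() now builds the TaskContext, honours a kill that arrived before execution, and then delegates. A minimal hypothetical subclass (in the spirit of the FakeTask change in the test hunks further down, and assumed to live under the spark package so the private[spark] base class is visible) showing how a cooperative loop observes a kill:

    import org.apache.spark.TaskContext
    import org.apache.spark.scheduler.Task

    // Hypothetical example, not part of this commit.
    private[spark] class CountingTask(stageId: Int, partition: Int, n: Int)
      extends Task[Int](stageId, partition) {

      override def runTask(context: TaskContext): Int = {
        var i = 0
        while (i < n && !context.interrupted) {   // stop early once kill() flips the flag
          i += 1
        }
        i
      }
    }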
@@ -19,6 +19,7 @@ package org.apache.spark.scheduler
 
 import org.apache.spark.scheduler.cluster.Pool
 import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode
 
 /**
  * Low-level task scheduler interface, implemented by both ClusterScheduler and LocalScheduler.
  * These schedulers get sets of tasks submitted to them from the DAGScheduler for each stage,

@@ -44,6 +45,9 @@ private[spark] trait TaskScheduler {
   // Submit a sequence of tasks to run.
   def submitTasks(taskSet: TaskSet): Unit
 
+  // Kill the stage.
+  def killTasks(stageId: Int)
+
   // Set a listener for upcalls. This is guaranteed to be set before submitTasks is called.
   def setListener(listener: TaskSchedulerListener): Unit
 
@@ -31,5 +31,9 @@ private[spark] class TaskSet(
     val properties: Properties) {
   val id: String = stageId + "." + attempt
 
+  def kill() {
+    tasks.foreach(_.kill())
+  }
+
   override def toString: String = "TaskSet " + id
 }
@@ -166,6 +166,20 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
     backend.reviveOffers()
   }
 
+  override def killTasks(stageId: Int) {
+    synchronized {
+      schedulableBuilder.popTaskSetManagers(stageId).foreach { t =>
+        val ts = t.asInstanceOf[TaskSetManager].taskSet
+        ts.kill()
+        val taskIds = taskSetTaskIds(ts.id)
+        taskIds.foreach { tid =>
+          val execId = taskIdToExecutorId(tid)
+          backend.killTask(tid, execId)
+        }
+      }
+    }
+  }
+
   def taskSetFinished(manager: TaskSetManager) {
     this.synchronized {
       activeTaskSets -= manager.taskSet.id
@@ -32,8 +32,17 @@ import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode
  * addTaskSetManager: build the leaf nodes(TaskSetManagers)
  */
 private[spark] trait SchedulableBuilder {
+  def rootPool: Pool
+
   def buildPools()
+
   def addTaskSetManager(manager: Schedulable, properties: Properties)
+
+  def popTaskSetManagers(stageId: Int): Iterable[Schedulable] = {
+    val taskSets = rootPool.schedulableQueue.filter(_.stageId == stageId)
+    taskSets.foreach(rootPool.removeSchedulable)
+    taskSets
+  }
 }
 
 private[spark] class FIFOSchedulableBuilder(val rootPool: Pool)
@@ -17,7 +17,7 @@
 
 package org.apache.spark.scheduler.cluster
 
-import org.apache.spark.{SparkContext}
+import org.apache.spark.SparkContext
 
 /**
  * A backend interface for cluster scheduling systems that allows plugging in different ones under

@@ -30,6 +30,10 @@ private[spark] trait SchedulerBackend {
   def reviveOffers(): Unit
   def defaultParallelism(): Int
 
+  def killTask(taskId: Long, executorId: String) {
+    throw new UnsupportedOperationException
+  }
+
   // Memory used by each executor (in megabytes)
   protected val executorMemory: Int = SparkContext.executorMemoryRequested
 
@@ -30,6 +30,8 @@ private[spark] object StandaloneClusterMessages {
   // Driver to executors
   case class LaunchTask(task: TaskDescription) extends StandaloneClusterMessage
 
+  case class KillTask(taskId: Long, executor: String) extends StandaloneClusterMessage
+
   case class RegisteredExecutor(sparkProperties: Seq[(String, String)])
     extends StandaloneClusterMessage
 
@@ -90,6 +90,9 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
       case ReviveOffers =>
         makeOffers()
 
+      case KillTask(taskId, executorId) =>
+        executorActor(executorId) ! KillTask(taskId, executorId)
+
       case StopDriver =>
         sender ! true
         context.stop(self)

@@ -179,6 +182,10 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
     driverActor ! ReviveOffers
   }
 
+  override def killTask(taskId: Long, executorId: String) {
+    driverActor ! KillTask(taskId, executorId)
+  }
+
   override def defaultParallelism() = Option(System.getProperty("spark.default.parallelism"))
     .map(_.toInt).getOrElse(math.max(totalCoreCount.get(), 2))
 
@@ -128,6 +128,12 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
     }
   }
 
+  override def killTasks(stageId: Int) = synchronized {
+    schedulableBuilder.popTaskSetManagers(stageId).foreach {
+      _.asInstanceOf[TaskSetManager].taskSet.kill()
+    }
+  }
+
   def resourceOffer(freeCores: Int): Seq[TaskDescription] = {
     synchronized {
       var freeCpuCores = freeCores
@@ -58,7 +58,8 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
     }
 
     whenExecuting(blockManager) {
-      val context = new TaskContext(0, 0, 0, runningLocally = false, null)
+      val context = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false,
+        taskMetrics = null)
       val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
       assert(value.toList === List(1, 2, 3, 4))
     }

@@ -70,7 +71,8 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
     }
 
     whenExecuting(blockManager) {
-      val context = new TaskContext(0, 0, 0, runningLocally = false, null)
+      val context = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false,
+        taskMetrics = null)
       val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
       assert(value.toList === List(5, 6, 7))
     }

@@ -83,7 +85,8 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
     }
 
     whenExecuting(blockManager) {
-      val context = new TaskContext(0, 0, 0, runningLocally = true, null)
+      val context = new TaskContext(0, 0, 0, runningLocally = true, interrupted = false,
+        taskMetrics = null)
       val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
       assert(value.toList === List(1, 2, 3, 4))
     }
@@ -495,7 +495,7 @@ public class JavaAPISuite implements Serializable {
   @Test
   public void iterator() {
     JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
-    TaskContext context = new TaskContext(0, 0, 0, false, null);
+    TaskContext context = new TaskContext(0, 0, 0, false, false, null);
     Assert.assertEquals(1, rdd.iterator(rdd.splits().get(0), context).next().intValue());
   }
 
@@ -29,12 +29,13 @@ import org.apache.spark.SparkContext
 import org.apache.spark.Partition
 import org.apache.spark.TaskContext
 import org.apache.spark.{Dependency, ShuffleDependency, OneToOneDependency}
-import org.apache.spark.{FetchFailed, Success, TaskEndReason}
+import org.apache.spark.{Success, TaskEndReason}
 import org.apache.spark.storage.{BlockManagerId, BlockManagerMaster}
 
 import org.apache.spark.scheduler.cluster.Pool
 import org.apache.spark.scheduler.cluster.SchedulingMode
 import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode
+import org.apache.spark.FetchFailed
 
 /**
  * Tests for DAGScheduler. These tests directly call the event processing functions in DAGScheduler

@@ -62,6 +63,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
       taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch)
       taskSets += taskSet
     }
+    override def killTasks(stageId: Int) {}
     override def setListener(listener: TaskSchedulerListener) = {}
     override def defaultParallelism() = 2
   }
@@ -17,10 +17,11 @@
 
 package org.apache.spark.scheduler.cluster
 
+import org.apache.spark.TaskContext
 import org.apache.spark.scheduler.{TaskLocation, Task}
 
-class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId) {
-  override def run(attemptId: Long): Int = 0
+class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0) {
+  override def runTask(context: TaskContext): Int = 0
 
   override def preferredLocations: Seq[TaskLocation] = prefLocs
 }