[SPARK-3027] TaskContext: tighten visibility and provide Java friendly callback API

Note that this also passes the TaskContext itself to the TaskCompletionListener. In the future we can mark the TaskContext with the exception object if an exception occurs during task execution.

Author: Reynold Xin <rxin@apache.org>

Closes #1938 from rxin/TaskContext and squashes the following commits:

145de43 [Reynold Xin] Added JavaTaskCompletionListenerImpl for Java API friendly guarantee.
f435ea5 [Reynold Xin] Added license header for TaskCompletionListener.
dc4ed27 [Reynold Xin] [SPARK-3027] TaskContext: tighten the visibility and provide Java friendly callback API
Reynold Xin 2014-08-14 18:37:02 -07:00
parent fa5a08e67d
commit 655699f8b7
14 changed files with 144 additions and 23 deletions
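
For context, the snippet below is a minimal, hypothetical usage sketch of the listener API this commit introduces. The LinesRDD class, the file path, and the reader are illustrative assumptions and not part of the patch; only the TaskContext methods (addTaskCompletionListener and the deprecated addOnCompleteCallback) come from the diff that follows.

import java.io.{BufferedReader, FileReader}

import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD

// Hypothetical RDD (not part of this commit) whose compute method closes its
// input stream through the new TaskContext listener API instead of the
// deprecated addOnCompleteCallback.
class LinesRDD(sc: SparkContext, path: String) extends RDD[String](sc, Nil) {

  override def getPartitions: Array[Partition] =
    Array(new Partition { override def index: Int = 0 })

  override def compute(split: Partition, context: TaskContext): Iterator[String] = {
    val reader = new BufferedReader(new FileReader(path))

    // Scala closure form: the listener now receives the TaskContext itself and
    // runs on success, failure, or cancellation.
    context.addTaskCompletionListener { ctx => reader.close() }

    // Equivalent Java friendly form (requires importing
    // org.apache.spark.util.TaskCompletionListener):
    // context.addTaskCompletionListener(new TaskCompletionListener {
    //   override def onTaskCompletion(ctx: TaskContext): Unit = reader.close()
    // })

    Iterator.continually(reader.readLine()).takeWhile(_ != null)
  }
}

The RDD would be executed through SparkContext.runJob as usual; the point of the sketch is only that the registered listener closes the stream whether the task succeeds, fails, or is killed.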

@@ -33,7 +33,7 @@ class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator
// is allowed. The assumption is that Thread.interrupted does not have a memory fence in read
// (just a volatile field in C), while context.interrupted is a volatile in the JVM, which
// introduces an expensive read fence.
if (context.interrupted) {
if (context.isInterrupted) {
throw new TaskKilledException
} else {
delegate.hasNext

@@ -21,10 +21,18 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.util.TaskCompletionListener
/**
* :: DeveloperApi ::
* Contextual information about a task which can be read or mutated during execution.
*
* @param stageId stage id
* @param partitionId index of the partition
* @param attemptId the number of attempts to execute this task
* @param runningLocally whether the task is running locally in the driver JVM
* @param taskMetrics performance metrics of the task
*/
@DeveloperApi
class TaskContext(
@@ -39,13 +47,45 @@ class TaskContext(
def splitId = partitionId
// List of callback functions to execute when the task completes.
@transient private val onCompleteCallbacks = new ArrayBuffer[() => Unit]
@transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener]
// Whether the corresponding task has been killed.
@volatile var interrupted: Boolean = false
@volatile private var interrupted: Boolean = false
// Whether the task has completed, before the onCompleteCallbacks are executed.
@volatile var completed: Boolean = false
// Whether the task has completed.
@volatile private var completed: Boolean = false
/** Checks whether the task has completed. */
def isCompleted: Boolean = completed
/** Checks whether the task has been killed. */
def isInterrupted: Boolean = interrupted
// TODO: Also track whether the task has completed successfully or with exception.
/**
* Add a (Java friendly) listener to be executed on task completion.
* This will be called in all situation - success, failure, or cancellation.
*
* An example use is for HadoopRDD to register a callback to close the input stream.
*/
def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
onCompleteCallbacks += listener
this
}
/**
* Add a listener in the form of a Scala closure to be executed on task completion.
* This will be called in all situation - success, failure, or cancellation.
*
* An example use is for HadoopRDD to register a callback to close the input stream.
*/
def addTaskCompletionListener(f: TaskContext => Unit): this.type = {
onCompleteCallbacks += new TaskCompletionListener {
override def onTaskCompletion(context: TaskContext): Unit = f(context)
}
this
}
/**
* Add a callback function to be executed on task completion. An example use
@@ -53,13 +93,22 @@ class TaskContext(
* Will be called in any situation - success, failure, or cancellation.
* @param f Callback function.
*/
@deprecated("use addTaskCompletionListener", "1.1.0")
def addOnCompleteCallback(f: () => Unit) {
onCompleteCallbacks += f
onCompleteCallbacks += new TaskCompletionListener {
override def onTaskCompletion(context: TaskContext): Unit = f()
}
}
def executeOnCompleteCallbacks() {
/** Marks the task as completed and triggers the listeners. */
private[spark] def markTaskCompleted(): Unit = {
completed = true
// Process complete callbacks in the reverse order of registration
onCompleteCallbacks.reverse.foreach { _() }
onCompleteCallbacks.reverse.foreach { _.onTaskCompletion(this) }
}
/** Marks the task for interruption, i.e. cancellation. */
private[spark] def markInterrupted(): Unit = {
interrupted = true
}
}

@@ -68,7 +68,7 @@ private[spark] class PythonRDD(
// Start a thread to feed the process input from our parent's iterator
val writerThread = new WriterThread(env, worker, split, context)
context.addOnCompleteCallback { () =>
context.addTaskCompletionListener { context =>
writerThread.shutdownOnTaskCompletion()
// Cleanup the worker socket. This will also cause the Python worker to exit.
@@ -137,7 +137,7 @@ private[spark] class PythonRDD(
}
} catch {
case e: Exception if context.interrupted =>
case e: Exception if context.isInterrupted =>
logDebug("Exception thrown after task interruption", e)
throw new TaskKilledException
@@ -176,7 +176,7 @@ private[spark] class PythonRDD(
/** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */
def shutdownOnTaskCompletion() {
assert(context.completed)
assert(context.isCompleted)
this.interrupt()
}
@@ -209,7 +209,7 @@ private[spark] class PythonRDD(
PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
dataOut.flush()
} catch {
case e: Exception if context.completed || context.interrupted =>
case e: Exception if context.isCompleted || context.isInterrupted =>
logDebug("Exception thrown after task completion (likely due to cleanup)", e)
case e: Exception =>
@@ -235,10 +235,10 @@ private[spark] class PythonRDD(
override def run() {
// Kill the worker if it is interrupted, checking until task completion.
// TODO: This has a race condition if interruption occurs, as completed may still become true.
while (!context.interrupted && !context.completed) {
while (!context.isInterrupted && !context.isCompleted) {
Thread.sleep(2000)
}
if (!context.completed) {
if (!context.isCompleted) {
try {
logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
env.destroyPythonWorker(pythonExec, envVars.toMap, worker)

@@ -141,7 +141,7 @@ private[spark] object CheckpointRDD extends Logging {
val deserializeStream = serializer.deserializeStream(fileInputStream)
// Register an on-task-completion callback to close the input stream.
context.addOnCompleteCallback(() => deserializeStream.close())
context.addTaskCompletionListener(context => deserializeStream.close())
deserializeStream.asIterator.asInstanceOf[Iterator[T]]
}

@@ -197,7 +197,7 @@ class HadoopRDD[K, V](
reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
// Register an on-task-completion callback to close the input stream.
context.addOnCompleteCallback{ () => closeIfNeeded() }
context.addTaskCompletionListener{ context => closeIfNeeded() }
val key: K = reader.createKey()
val value: V = reader.createValue()

@@ -68,7 +68,7 @@ class JdbcRDD[T: ClassTag](
}
override def compute(thePart: Partition, context: TaskContext) = new NextIterator[T] {
context.addOnCompleteCallback{ () => closeIfNeeded() }
context.addTaskCompletionListener{ context => closeIfNeeded() }
val part = thePart.asInstanceOf[JdbcPartition]
val conn = getConnection()
val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)

@@ -129,7 +129,7 @@ class NewHadoopRDD[K, V](
context.taskMetrics.inputMetrics = Some(inputMetrics)
// Register an on-task-completion callback to close the input stream.
context.addOnCompleteCallback(() => close())
context.addTaskCompletionListener(context => close())
var havePair = false
var finished = false

@@ -634,7 +634,7 @@ class DAGScheduler(
val result = job.func(taskContext, rdd.iterator(split, taskContext))
job.listener.taskSucceeded(0, result)
} finally {
taskContext.executeOnCompleteCallbacks()
taskContext.markTaskCompleted()
}
} catch {
case e: Exception =>

@@ -61,7 +61,7 @@ private[spark] class ResultTask[T, U](
try {
func(context, rdd.iterator(partition, context))
} finally {
context.executeOnCompleteCallbacks()
context.markTaskCompleted()
}
}

@@ -74,7 +74,7 @@ private[spark] class ShuffleMapTask(
}
throw e
} finally {
context.executeOnCompleteCallbacks()
context.markTaskCompleted()
}
}

@@ -87,7 +87,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
def kill(interruptThread: Boolean) {
_killed = true
if (context != null) {
context.interrupted = true
context.markInterrupted()
}
if (interruptThread && taskThread != null) {
taskThread.interrupt()

@@ -0,0 +1,33 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.util
import java.util.EventListener
import org.apache.spark.TaskContext
import org.apache.spark.annotation.DeveloperApi
/**
* :: DeveloperApi ::
*
* Listener providing a callback function to invoke when a task's execution completes.
*/
@DeveloperApi
trait TaskCompletionListener extends EventListener {
def onTaskCompletion(context: TaskContext)
}

@@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.util;
import org.apache.spark.TaskContext;
/**
* A simple implementation of TaskCompletionListener that makes sure TaskCompletionListener and
* TaskContext is Java friendly.
*/
public class JavaTaskCompletionListenerImpl implements TaskCompletionListener {
@Override
public void onTaskCompletion(TaskContext context) {
context.isCompleted();
context.isInterrupted();
context.stageId();
context.partitionId();
context.runningLocally();
context.taskMetrics();
context.addTaskCompletionListener(this);
}
}

@@ -32,7 +32,7 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte
val rdd = new RDD[String](sc, List()) {
override def getPartitions = Array[Partition](StubPartition(0))
override def compute(split: Partition, context: TaskContext) = {
context.addOnCompleteCallback(() => TaskContextSuite.completed = true)
context.addTaskCompletionListener(context => TaskContextSuite.completed = true)
sys.error("failed")
}
}