[SPARK-3543] Write TaskContext in Java and expose it through a static accessor.

Author: Prashant Sharma <prashant.s@imaginea.com>
Author: Shashank Sharma <shashank21j@gmail.com>

Closes #2425 from ScrapCodes/SPARK-3543/withTaskContext and squashes the following commits:

8ae414c [Shashank Sharma] CR
ee8bd00 [Prashant Sharma] Added internal API in docs comments.
ddb8cbe [Prashant Sharma] Moved setting the thread local to where TaskContext is instantiated.
a7d5e23 [Prashant Sharma] Added doc comments.
edf945e [Prashant Sharma] Code review git add -A
f716fd1 [Prashant Sharma] introduced thread local for getting the task context.
333c7d6 [Prashant Sharma] Translated Task context from scala to java.
Authored by Prashant Sharma on 2014-09-26 21:29:54 -07:00; committed by Reynold Xin
parent f872e4fb80
commit 5e34855cf0
7 changed files with 284 additions and 131 deletions
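For illustration, a minimal sketch of what this patch enables in task-side user code (not part of the diff; assumes a live SparkContext named sc): the context is now reachable through the static accessor instead of being threaded through method parameters.

import org.apache.spark.TaskContext

// The executor thread calls TaskContext.setTaskContext before running the
// task body, so get() returns the current task's context inside a closure.
sc.parallelize(1 to 4, 2).foreachPartition { _ =>
  val ctx = TaskContext.get()
  println(s"stage=${ctx.getStageId} partition=${ctx.getPartitionId}")
}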


@@ -0,0 +1,274 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import scala.Function0;
import scala.Function1;
import scala.Unit;
import scala.collection.JavaConversions;
import org.apache.spark.annotation.DeveloperApi;
import org.apache.spark.executor.TaskMetrics;
import org.apache.spark.util.TaskCompletionListener;
import org.apache.spark.util.TaskCompletionListenerException;
/**
* :: DeveloperApi ::
* Contextual information about a task which can be read or mutated during execution.
*/
@DeveloperApi
public class TaskContext implements Serializable {
private int stageId;
private int partitionId;
private long attemptId;
private boolean runningLocally;
private TaskMetrics taskMetrics;
/**
* :: DeveloperApi ::
* Contextual information about a task which can be read or mutated during execution.
*
* @param stageId stage id
* @param partitionId index of the partition
* @param attemptId the number of attempts to execute this task
* @param runningLocally whether the task is running locally in the driver JVM
* @param taskMetrics performance metrics of the task
*/
@DeveloperApi
public TaskContext(Integer stageId, Integer partitionId, Long attemptId, Boolean runningLocally,
TaskMetrics taskMetrics) {
this.attemptId = attemptId;
this.partitionId = partitionId;
this.runningLocally = runningLocally;
this.stageId = stageId;
this.taskMetrics = taskMetrics;
}
/**
* :: DeveloperApi ::
* Contextual information about a task which can be read or mutated during execution.
*
* @param stageId stage id
* @param partitionId index of the partition
* @param attemptId the number of attempts to execute this task
* @param runningLocally whether the task is running locally in the driver JVM
*/
@DeveloperApi
public TaskContext(Integer stageId, Integer partitionId, Long attemptId,
Boolean runningLocally) {
this.attemptId = attemptId;
this.partitionId = partitionId;
this.runningLocally = runningLocally;
this.stageId = stageId;
this.taskMetrics = TaskMetrics.empty();
}
/**
* :: DeveloperApi ::
* Contextual information about a task which can be read or mutated during execution.
*
* @param stageId stage id
* @param partitionId index of the partition
* @param attemptId the number of attempts to execute this task
*/
@DeveloperApi
public TaskContext(Integer stageId, Integer partitionId, Long attemptId) {
this.attemptId = attemptId;
this.partitionId = partitionId;
this.runningLocally = false;
this.stageId = stageId;
this.taskMetrics = TaskMetrics.empty();
}
private static ThreadLocal<TaskContext> taskContext =
new ThreadLocal<TaskContext>();
/**
* :: Internal API ::
* This is a Spark-internal API, not intended to be called from user programs.
*/
public static void setTaskContext(TaskContext tc) {
taskContext.set(tc);
}
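/**
 * Returns the TaskContext of the currently running task, or null if the
 * calling thread is not running a task (the context lives in the thread
 * local above).
 */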
public static TaskContext get() {
return taskContext.get();
}
/**
* :: Internal API ::
* Clears the thread local; Spark calls this once the task has finished.
*/
public static void remove() {
taskContext.remove();
}
// List of callback functions to execute when the task completes.
private transient List<TaskCompletionListener> onCompleteCallbacks =
new ArrayList<TaskCompletionListener>();
// Whether the corresponding task has been killed.
private volatile Boolean interrupted = false;
// Whether the task has completed.
private volatile Boolean completed = false;
/**
* Checks whether the task has completed.
*/
public Boolean isCompleted() {
return completed;
}
/**
* Checks whether the task has been killed.
*/
public Boolean isInterrupted() {
return interrupted;
}
/**
* Add a (Java friendly) listener to be executed on task completion.
* This will be called in all situations - success, failure, or cancellation.
* <p/>
* An example use is for HadoopRDD to register a callback to close the input stream.
*/
public TaskContext addTaskCompletionListener(TaskCompletionListener listener) {
onCompleteCallbacks.add(listener);
return this;
}
/**
* Add a listener in the form of a Scala closure to be executed on task completion.
* This will be called in all situations - success, failure, or cancellation.
* <p/>
* An example use is for HadoopRDD to register a callback to close the input stream.
*/
public TaskContext addTaskCompletionListener(final Function1<TaskContext, Unit> f) {
onCompleteCallbacks.add(new TaskCompletionListener() {
@Override
public void onTaskCompletion(TaskContext context) {
f.apply(context);
}
});
return this;
}
/**
* Add a callback function to be executed on task completion. An example use
* is for HadoopRDD to register a callback to close the input stream.
* Will be called in any situation - success, failure, or cancellation.
*
* Deprecated: use addTaskCompletionListener
*
* @param f Callback function.
*/
@Deprecated
public void addOnCompleteCallback(final Function0<Unit> f) {
onCompleteCallbacks.add(new TaskCompletionListener() {
@Override
public void onTaskCompletion(TaskContext context) {
f.apply();
}
});
}
/**
* :: Internal API ::
* Marks the task as completed and triggers the listeners.
*/
public void markTaskCompleted() throws TaskCompletionListenerException {
completed = true;
List<String> errorMsgs = new ArrayList<String>(2);
// Process complete callbacks in the reverse order of registration
List<TaskCompletionListener> revlist =
new ArrayList<TaskCompletionListener>(onCompleteCallbacks);
Collections.reverse(revlist);
for (TaskCompletionListener tcl: revlist) {
try {
tcl.onTaskCompletion(this);
} catch (Throwable e) {
errorMsgs.add(e.getMessage());
}
}
if (!errorMsgs.isEmpty()) {
throw new TaskCompletionListenerException(JavaConversions.asScalaBuffer(errorMsgs));
}
}
/**
* :: Internal API ::
* Marks the task for interruption, i.e. cancellation.
*/
public void markInterrupted() {
interrupted = true;
}
/** Deprecated: use getStageId() */
@Deprecated
public int stageId() {
return stageId;
}
/** Deprecated: use getPartitionId() */
@Deprecated
public int partitionId() {
return partitionId;
}
/** Deprecated: use getAttemptId() */
@Deprecated
public long attemptId() {
return attemptId;
}
/** Deprecated: use getRunningLocally() */
@Deprecated
public boolean runningLocally() {
return runningLocally;
}
public boolean getRunningLocally() {
return runningLocally;
}
public int getStageId() {
return stageId;
}
public int getPartitionId() {
return partitionId;
}
public long getAttemptId() {
return attemptId;
}
/** :: Internal API :: */
public TaskMetrics taskMetrics() {
return taskMetrics;
}
}
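As a usage note for the listener API above, a hedged Scala sketch of the HadoopRDD-style pattern the Javadoc mentions: registering a listener that closes a per-task resource (registerCleanup and the stream are illustrative names, not part of this patch).

import java.io.InputStream
import org.apache.spark.TaskContext
import org.apache.spark.util.TaskCompletionListener

// Close the stream whether the task succeeds, fails, or is cancelled.
def registerCleanup(stream: InputStream): Unit = {
  TaskContext.get().addTaskCompletionListener(new TaskCompletionListener {
    override def onTaskCompletion(context: TaskContext): Unit = stream.close()
  })
}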


@@ -1,126 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.util.{TaskCompletionListenerException, TaskCompletionListener}
/**
* :: DeveloperApi ::
* Contextual information about a task which can be read or mutated during execution.
*
* @param stageId stage id
* @param partitionId index of the partition
* @param attemptId the number of attempts to execute this task
* @param runningLocally whether the task is running locally in the driver JVM
* @param taskMetrics performance metrics of the task
*/
@DeveloperApi
class TaskContext(
val stageId: Int,
val partitionId: Int,
val attemptId: Long,
val runningLocally: Boolean = false,
private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty)
extends Serializable with Logging {
@deprecated("use partitionId", "0.8.1")
def splitId = partitionId
// List of callback functions to execute when the task completes.
@transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener]
// Whether the corresponding task has been killed.
@volatile private var interrupted: Boolean = false
// Whether the task has completed.
@volatile private var completed: Boolean = false
/** Checks whether the task has completed. */
def isCompleted: Boolean = completed
/** Checks whether the task has been killed. */
def isInterrupted: Boolean = interrupted
// TODO: Also track whether the task has completed successfully or with exception.
/**
* Add a (Java friendly) listener to be executed on task completion.
* This will be called in all situations - success, failure, or cancellation.
*
* An example use is for HadoopRDD to register a callback to close the input stream.
*/
def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
onCompleteCallbacks += listener
this
}
/**
* Add a listener in the form of a Scala closure to be executed on task completion.
* This will be called in all situations - success, failure, or cancellation.
*
* An example use is for HadoopRDD to register a callback to close the input stream.
*/
def addTaskCompletionListener(f: TaskContext => Unit): this.type = {
onCompleteCallbacks += new TaskCompletionListener {
override def onTaskCompletion(context: TaskContext): Unit = f(context)
}
this
}
/**
* Add a callback function to be executed on task completion. An example use
* is for HadoopRDD to register a callback to close the input stream.
* Will be called in any situation - success, failure, or cancellation.
* @param f Callback function.
*/
@deprecated("use addTaskCompletionListener", "1.1.0")
def addOnCompleteCallback(f: () => Unit) {
onCompleteCallbacks += new TaskCompletionListener {
override def onTaskCompletion(context: TaskContext): Unit = f()
}
}
/** Marks the task as completed and triggers the listeners. */
private[spark] def markTaskCompleted(): Unit = {
completed = true
val errorMsgs = new ArrayBuffer[String](2)
// Process complete callbacks in the reverse order of registration
onCompleteCallbacks.reverse.foreach { listener =>
try {
listener.onTaskCompletion(this)
} catch {
case e: Throwable =>
errorMsgs += e.getMessage
logError("Error in TaskCompletionListener", e)
}
}
if (errorMsgs.nonEmpty) {
throw new TaskCompletionListenerException(errorMsgs)
}
}
/** Marks the task for interruption, i.e. cancellation. */
private[spark] def markInterrupted(): Unit = {
interrupted = true
}
}
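Although the Scala class above is deleted, Scala callers keep the closure-based registration it offered, because the Java replacement retains a Function1 overload; an illustrative sketch:

// A Scala function literal is a Function1, so it binds to the
// addTaskCompletionListener(Function1[TaskContext, Unit]) overload.
TaskContext.get().addTaskCompletionListener { (ctx: TaskContext) =>
  // release per-task resources here
}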


@@ -619,6 +619,7 @@ abstract class RDD[T: ClassTag](
* should be `false` unless this is a pair RDD and the input function doesn't modify the keys.
*/
@DeveloperApi
+ @deprecated("use TaskContext.get", "1.2.0")
def mapPartitionsWithContext[U: ClassTag](
f: (TaskContext, Iterator[T]) => Iterator[U],
preservesPartitioning: Boolean = false): RDD[U] = {
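A migration sketch for the deprecation above (illustrative user code, assuming an existing rdd): rather than receiving the context as an argument, the closure now fetches it from the static accessor.

// Before (deprecated):
//   rdd.mapPartitionsWithContext((ctx, iter) => iter.map(x => (ctx.partitionId, x)))
// After:
rdd.mapPartitions { iter =>
  val ctx = TaskContext.get()
  iter.map(x => (ctx.getPartitionId, x))
}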


@@ -634,12 +634,14 @@ class DAGScheduler(
val rdd = job.finalStage.rdd
val split = rdd.partitions(job.partitions(0))
val taskContext =
- new TaskContext(job.finalStage.id, job.partitions(0), 0, runningLocally = true)
+ new TaskContext(job.finalStage.id, job.partitions(0), 0, true)
+ TaskContext.setTaskContext(taskContext)
try {
val result = job.func(taskContext, rdd.iterator(split, taskContext))
job.listener.taskSucceeded(0, result)
} finally {
taskContext.markTaskCompleted()
+ TaskContext.remove()
}
} catch {
case e: Exception =>
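
The shape of this change (mirrored in Task below) is a publish/cleanup lifecycle around the task body; schematically:

TaskContext.setTaskContext(taskContext)  // publish to the thread local
try {
  // run the task body; user code may now call TaskContext.get()
} finally {
  taskContext.markTaskCompleted()  // fire completion listeners
  TaskContext.remove()             // don't leak the context to the next task on this thread
}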


@@ -45,7 +45,8 @@ import org.apache.spark.util.Utils
private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
final def run(attemptId: Long): T = {
- context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false)
+ context = new TaskContext(stageId, partitionId, attemptId, false)
+ TaskContext.setTaskContext(context)
context.taskMetrics.hostname = Utils.localHostName()
taskThread = Thread.currentThread()
if (_killed) {
@@ -92,7 +93,8 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
if (interruptThread && taskThread != null) {
taskThread.interrupt()
}
}
+ TaskContext.remove()
}
}
/**


@@ -776,7 +776,7 @@ public class JavaAPISuite implements Serializable {
@Test
public void iterator() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
- TaskContext context = new TaskContext(0, 0, 0, false, new TaskMetrics());
+ TaskContext context = new TaskContext(0, 0, 0L, false, new TaskMetrics());
Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue());
}


@@ -94,7 +94,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
}
whenExecuting(blockManager) {
- val context = new TaskContext(0, 0, 0, runningLocally = true)
+ val context = new TaskContext(0, 0, 0, true)
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(1, 2, 3, 4))
}