[SPARK-3543] Write TaskContext in Java and expose it through a static accessor.
Author: Prashant Sharma <prashant.s@imaginea.com>
Author: Shashank Sharma <shashank21j@gmail.com>

Closes #2425 from ScrapCodes/SPARK-3543/withTaskContext and squashes the following commits:

8ae414c [Shashank Sharma] CR
ee8bd00 [Prashant Sharma] Added internal API in docs comments.
ddb8cbe [Prashant Sharma] Moved setting the thread local to where TaskContext is instantiated.
a7d5e23 [Prashant Sharma] Added doc comments.
edf945e [Prashant Sharma] Code review git add -A
f716fd1 [Prashant Sharma] introduced thread local for getting the task context.
333c7d6 [Prashant Sharma] Translated Task context from scala to java.
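In short: the task context now lives in a thread local inside the Java class and is published through TaskContext.get(), so per-task state is reachable from plain mapPartitions instead of the now-deprecated mapPartitionsWithContext. A minimal sketch of the new call pattern, assuming a running SparkContext `sc`; the sample data and tagging logic are illustrative only, not part of this patch:

    import org.apache.spark.TaskContext

    // Tag each element with the id of the partition that processed it,
    // reading the context from the new thread-local static accessor rather
    // than taking it as a mapPartitionsWithContext argument.
    val tagged = sc.parallelize(1 to 100, 4).mapPartitions { iter =>
      val ctx = TaskContext.get() // set by the scheduler before the task body runs
      iter.map(x => (ctx.getPartitionId(), x))
    }
    tagged.collect().foreach(println)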
This commit is contained in:
parent f872e4fb80
commit 5e34855cf0
core/src/main/java/org/apache/spark/TaskContext.java (new file, 274 lines)

@@ -0,0 +1,274 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import scala.Function0;
import scala.Function1;
import scala.Unit;
import scala.collection.JavaConversions;

import org.apache.spark.annotation.DeveloperApi;
import org.apache.spark.executor.TaskMetrics;
import org.apache.spark.util.TaskCompletionListener;
import org.apache.spark.util.TaskCompletionListenerException;

/**
 * :: DeveloperApi ::
 * Contextual information about a task which can be read or mutated during execution.
 */
@DeveloperApi
public class TaskContext implements Serializable {

  private int stageId;
  private int partitionId;
  private long attemptId;
  private boolean runningLocally;
  private TaskMetrics taskMetrics;

  /**
   * :: DeveloperApi ::
   * Contextual information about a task which can be read or mutated during execution.
   *
   * @param stageId stage id
   * @param partitionId index of the partition
   * @param attemptId the number of attempts to execute this task
   * @param runningLocally whether the task is running locally in the driver JVM
   * @param taskMetrics performance metrics of the task
   */
  @DeveloperApi
  public TaskContext(Integer stageId, Integer partitionId, Long attemptId, Boolean runningLocally,
                     TaskMetrics taskMetrics) {
    this.attemptId = attemptId;
    this.partitionId = partitionId;
    this.runningLocally = runningLocally;
    this.stageId = stageId;
    this.taskMetrics = taskMetrics;
  }

  /**
   * :: DeveloperApi ::
   * Contextual information about a task which can be read or mutated during execution.
   *
   * @param stageId stage id
   * @param partitionId index of the partition
   * @param attemptId the number of attempts to execute this task
   * @param runningLocally whether the task is running locally in the driver JVM
   */
  @DeveloperApi
  public TaskContext(Integer stageId, Integer partitionId, Long attemptId,
                     Boolean runningLocally) {
    this.attemptId = attemptId;
    this.partitionId = partitionId;
    this.runningLocally = runningLocally;
    this.stageId = stageId;
    this.taskMetrics = TaskMetrics.empty();
  }

  /**
   * :: DeveloperApi ::
   * Contextual information about a task which can be read or mutated during execution.
   *
   * @param stageId stage id
   * @param partitionId index of the partition
   * @param attemptId the number of attempts to execute this task
   */
  @DeveloperApi
  public TaskContext(Integer stageId, Integer partitionId, Long attemptId) {
    this.attemptId = attemptId;
    this.partitionId = partitionId;
    this.runningLocally = false;
    this.stageId = stageId;
    this.taskMetrics = TaskMetrics.empty();
  }

  private static ThreadLocal<TaskContext> taskContext =
    new ThreadLocal<TaskContext>();

  /**
   * :: Internal API ::
   * This is spark internal API, not intended to be called from user programs.
   */
  public static void setTaskContext(TaskContext tc) {
    taskContext.set(tc);
  }

  public static TaskContext get() {
    return taskContext.get();
  }

  /**
   * :: Internal API ::
   */
  public static void remove() {
    taskContext.remove();
  }

  // List of callback functions to execute when the task completes.
  private transient List<TaskCompletionListener> onCompleteCallbacks =
    new ArrayList<TaskCompletionListener>();

  // Whether the corresponding task has been killed.
  private volatile Boolean interrupted = false;

  // Whether the task has completed.
  private volatile Boolean completed = false;

  /**
   * Checks whether the task has completed.
   */
  public Boolean isCompleted() {
    return completed;
  }

  /**
   * Checks whether the task has been killed.
   */
  public Boolean isInterrupted() {
    return interrupted;
  }

  /**
   * Add a (Java friendly) listener to be executed on task completion.
   * This will be called in all situations - success, failure, or cancellation.
   * <p/>
   * An example use is for HadoopRDD to register a callback to close the input stream.
   */
  public TaskContext addTaskCompletionListener(TaskCompletionListener listener) {
    onCompleteCallbacks.add(listener);
    return this;
  }

  /**
   * Add a listener in the form of a Scala closure to be executed on task completion.
   * This will be called in all situations - success, failure, or cancellation.
   * <p/>
   * An example use is for HadoopRDD to register a callback to close the input stream.
   */
  public TaskContext addTaskCompletionListener(final Function1<TaskContext, Unit> f) {
    onCompleteCallbacks.add(new TaskCompletionListener() {
      @Override
      public void onTaskCompletion(TaskContext context) {
        f.apply(context);
      }
    });
    return this;
  }

  /**
   * Add a callback function to be executed on task completion. An example use
   * is for HadoopRDD to register a callback to close the input stream.
   * Will be called in any situation - success, failure, or cancellation.
   *
   * Deprecated: use addTaskCompletionListener
   *
   * @param f Callback function.
   */
  @Deprecated
  public void addOnCompleteCallback(final Function0<Unit> f) {
    onCompleteCallbacks.add(new TaskCompletionListener() {
      @Override
      public void onTaskCompletion(TaskContext context) {
        f.apply();
      }
    });
  }

  /**
   * ::Internal API::
   * Marks the task as completed and triggers the listeners.
   */
  public void markTaskCompleted() throws TaskCompletionListenerException {
    completed = true;
    List<String> errorMsgs = new ArrayList<String>(2);
    // Process complete callbacks in the reverse order of registration
    List<TaskCompletionListener> revlist =
      new ArrayList<TaskCompletionListener>(onCompleteCallbacks);
    Collections.reverse(revlist);
    for (TaskCompletionListener tcl: revlist) {
      try {
        tcl.onTaskCompletion(this);
      } catch (Throwable e) {
        errorMsgs.add(e.getMessage());
      }
    }

    if (!errorMsgs.isEmpty()) {
      throw new TaskCompletionListenerException(JavaConversions.asScalaBuffer(errorMsgs));
    }
  }

  /**
   * ::Internal API::
   * Marks the task for interruption, i.e. cancellation.
   */
  public void markInterrupted() {
    interrupted = true;
  }

  @Deprecated
  /** Deprecated: use getStageId() */
  public int stageId() {
    return stageId;
  }

  @Deprecated
  /** Deprecated: use getPartitionId() */
  public int partitionId() {
    return partitionId;
  }

  @Deprecated
  /** Deprecated: use getAttemptId() */
  public long attemptId() {
    return attemptId;
  }

  @Deprecated
  /** Deprecated: use getRunningLocally() */
  public boolean runningLocally() {
    return runningLocally;
  }

  public boolean getRunningLocally() {
    return runningLocally;
  }

  public int getStageId() {
    return stageId;
  }

  public int getPartitionId() {
    return partitionId;
  }

  public long getAttemptId() {
    return attemptId;
  }

  /** ::Internal API:: */
  public TaskMetrics taskMetrics() {
    return taskMetrics;
  }
}
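For user programs, addTaskCompletionListener is the main hook in the class above: it fires on success, failure, or cancellation alike, which makes it the right place to release per-task resources. A minimal Scala sketch, assuming a running SparkContext `sc`; the temp-file resource is hypothetical:

    import org.apache.spark.TaskContext
    import org.apache.spark.util.TaskCompletionListener

    val copied = sc.parallelize(1 to 1000, 8).mapPartitions { iter =>
      // Hypothetical per-task resource that must be closed when the task ends.
      val log = new java.io.PrintWriter(java.io.File.createTempFile("task-", ".log"))
      // markTaskCompleted() invokes listeners in reverse order of registration,
      // whether the task succeeded, failed, or was cancelled.
      TaskContext.get().addTaskCompletionListener(new TaskCompletionListener {
        override def onTaskCompletion(context: TaskContext): Unit = log.close()
      })
      iter.map { x => log.println(x); x }
    }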
core/src/main/scala/org/apache/spark/TaskContext.scala (deleted file, 126 lines)

@@ -1,126 +0,0 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.util.{TaskCompletionListenerException, TaskCompletionListener}


/**
 * :: DeveloperApi ::
 * Contextual information about a task which can be read or mutated during execution.
 *
 * @param stageId stage id
 * @param partitionId index of the partition
 * @param attemptId the number of attempts to execute this task
 * @param runningLocally whether the task is running locally in the driver JVM
 * @param taskMetrics performance metrics of the task
 */
@DeveloperApi
class TaskContext(
    val stageId: Int,
    val partitionId: Int,
    val attemptId: Long,
    val runningLocally: Boolean = false,
    private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty)
  extends Serializable with Logging {

  @deprecated("use partitionId", "0.8.1")
  def splitId = partitionId

  // List of callback functions to execute when the task completes.
  @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener]

  // Whether the corresponding task has been killed.
  @volatile private var interrupted: Boolean = false

  // Whether the task has completed.
  @volatile private var completed: Boolean = false

  /** Checks whether the task has completed. */
  def isCompleted: Boolean = completed

  /** Checks whether the task has been killed. */
  def isInterrupted: Boolean = interrupted

  // TODO: Also track whether the task has completed successfully or with exception.

  /**
   * Add a (Java friendly) listener to be executed on task completion.
   * This will be called in all situations - success, failure, or cancellation.
   *
   * An example use is for HadoopRDD to register a callback to close the input stream.
   */
  def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
    onCompleteCallbacks += listener
    this
  }

  /**
   * Add a listener in the form of a Scala closure to be executed on task completion.
   * This will be called in all situations - success, failure, or cancellation.
   *
   * An example use is for HadoopRDD to register a callback to close the input stream.
   */
  def addTaskCompletionListener(f: TaskContext => Unit): this.type = {
    onCompleteCallbacks += new TaskCompletionListener {
      override def onTaskCompletion(context: TaskContext): Unit = f(context)
    }
    this
  }

  /**
   * Add a callback function to be executed on task completion. An example use
   * is for HadoopRDD to register a callback to close the input stream.
   * Will be called in any situation - success, failure, or cancellation.
   * @param f Callback function.
   */
  @deprecated("use addTaskCompletionListener", "1.1.0")
  def addOnCompleteCallback(f: () => Unit) {
    onCompleteCallbacks += new TaskCompletionListener {
      override def onTaskCompletion(context: TaskContext): Unit = f()
    }
  }

  /** Marks the task as completed and triggers the listeners. */
  private[spark] def markTaskCompleted(): Unit = {
    completed = true
    val errorMsgs = new ArrayBuffer[String](2)
    // Process complete callbacks in the reverse order of registration
    onCompleteCallbacks.reverse.foreach { listener =>
      try {
        listener.onTaskCompletion(this)
      } catch {
        case e: Throwable =>
          errorMsgs += e.getMessage
          logError("Error in TaskCompletionListener", e)
      }
    }
    if (errorMsgs.nonEmpty) {
      throw new TaskCompletionListenerException(errorMsgs)
    }
  }

  /** Marks the task for interruption, i.e. cancellation. */
  private[spark] def markInterrupted(): Unit = {
    interrupted = true
  }
}
@@ -619,6 +619,7 @@ abstract class RDD[T: ClassTag](
    * should be `false` unless this is a pair RDD and the input function doesn't modify the keys.
    */
   @DeveloperApi
+  @deprecated("use TaskContext.get", "1.2.0")
   def mapPartitionsWithContext[U: ClassTag](
       f: (TaskContext, Iterator[T]) => Iterator[U],
       preservesPartitioning: Boolean = false): RDD[U] = {
@@ -634,12 +634,14 @@ class DAGScheduler(
       val rdd = job.finalStage.rdd
       val split = rdd.partitions(job.partitions(0))
       val taskContext =
-        new TaskContext(job.finalStage.id, job.partitions(0), 0, runningLocally = true)
+        new TaskContext(job.finalStage.id, job.partitions(0), 0, true)
+      TaskContext.setTaskContext(taskContext)
       try {
         val result = job.func(taskContext, rdd.iterator(split, taskContext))
         job.listener.taskSucceeded(0, result)
       } finally {
         taskContext.markTaskCompleted()
+        TaskContext.remove()
       }
     } catch {
       case e: Exception =>
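Both scheduler call sites follow the same discipline: set the thread local where the TaskContext is instantiated, and remove it when the task winds down, so a reused executor thread cannot leak a stale context into its next task. A distilled sketch of the pattern (the helper is hypothetical, not Spark source):

    import org.apache.spark.TaskContext

    // Hypothetical helper distilling the set/run/remove lifecycle used above.
    def runWithContext[T](ctx: TaskContext)(body: => T): T = {
      TaskContext.setTaskContext(ctx)   // publish the context to this thread
      try body                          // task body may now call TaskContext.get()
      finally TaskContext.remove()      // never leak into the next task on this thread
    }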
@@ -45,7 +45,8 @@ import org.apache.spark.util.Utils
 private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
 
   final def run(attemptId: Long): T = {
-    context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false)
+    context = new TaskContext(stageId, partitionId, attemptId, false)
+    TaskContext.setTaskContext(context)
     context.taskMetrics.hostname = Utils.localHostName()
     taskThread = Thread.currentThread()
     if (_killed) {

@@ -92,7 +93,8 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
     if (interruptThread && taskThread != null) {
       taskThread.interrupt()
     }
+    TaskContext.remove()
   }
 }
 
 /**
@@ -776,7 +776,7 @@ public class JavaAPISuite implements Serializable {
   @Test
   public void iterator() {
     JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
-    TaskContext context = new TaskContext(0, 0, 0, false, new TaskMetrics());
+    TaskContext context = new TaskContext(0, 0, 0L, false, new TaskMetrics());
     Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue());
   }
@@ -94,7 +94,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
     }
 
     whenExecuting(blockManager) {
-      val context = new TaskContext(0, 0, 0, runningLocally = true)
+      val context = new TaskContext(0, 0, 0, true)
       val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
       assert(value.toList === List(1, 2, 3, 4))
     }