[SPARK-20702][CORE] TaskContextImpl.markTaskCompleted should not hide the original error

## What changes were proposed in this pull request?

This PR adds an `error` parameter to `TaskContextImpl.markTaskCompleted` to propagate the original error.

It also fixes an issue where `TaskCompletionListenerException.getMessage` didn't include `previousError`.
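
For context (an editor's reading of the old code, not text from the PR): the pre-fix `getMessage` appended the `previousError` suffix with a trailing `+` after an `if`/`else`, and Scala attaches that `+` expression to the `else` branch alone, so the common single-listener-failure path silently dropped the task's error. A minimal sketch of the pitfall, with made-up values:

```scala
// Illustrative only -- not Spark code. The trailing `+ " suffix"` binds to
// the else branch, so when cond is true the suffix is silently dropped.
def message(cond: Boolean): String =
  if (cond) {
    "head"
  } else {
    "joined"
  } +
  " suffix"

// message(true)  == "head"           <- suffix lost: the reported bug
// message(false) == "joined suffix"
```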

## How was this patch tested?

New unit tests.

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #17942 from zsxwing/SPARK-20702.
Commit 7d6ff39106 (parent b526f70c16), committed by Shixiong Zhu on 2017-05-12 10:46:44 -07:00.
6 changed files with 68 additions and 29 deletions.

`core/src/main/scala/org/apache/spark/TaskContextImpl.scala`

```diff
@@ -110,10 +110,10 @@ private[spark] class TaskContextImpl(
   /** Marks the task as completed and triggers the completion listeners. */
   @GuardedBy("this")
-  private[spark] def markTaskCompleted(): Unit = synchronized {
+  private[spark] def markTaskCompleted(error: Option[Throwable]): Unit = synchronized {
     if (completed) return
     completed = true
-    invokeListeners(onCompleteCallbacks, "TaskCompletionListener", None) {
+    invokeListeners(onCompleteCallbacks, "TaskCompletionListener", error) {
       _.onTaskCompletion(this)
     }
   }
```
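
The `error` argument is simply threaded through to `invokeListeners`, which aggregates listener failures. A rough sketch of that pattern (an assumed shape, not Spark's actual `invokeListeners`, which differs in details such as listener ordering; `TaskCompletionListenerException` is the `private[spark]` class changed later in this diff):

```scala
import scala.collection.mutable.ArrayBuffer

// Sketch: run every listener even if some throw, then raise one aggregate
// exception that carries the task's original error as previousError.
def invokeAll[L](listeners: Seq[L], error: Option[Throwable])(callback: L => Unit): Unit = {
  val errorMessages = new ArrayBuffer[String]
  listeners.foreach { listener =>
    try {
      callback(listener)
    } catch {
      case t: Throwable => errorMessages += t.getMessage
    }
  }
  if (errorMessages.nonEmpty) {
    throw new TaskCompletionListenerException(errorMessages.toSeq, error)
  }
}
```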

`core/src/main/scala/org/apache/spark/scheduler/Task.scala`

```diff
@@ -115,26 +115,33 @@ private[spark] abstract class Task[T](
         case t: Throwable =>
           e.addSuppressed(t)
       }
+      context.markTaskCompleted(Some(e))
       throw e
     } finally {
-      // Call the task completion callbacks.
-      context.markTaskCompleted()
       try {
-        Utils.tryLogNonFatalError {
-          // Release memory used by this thread for unrolling blocks
-          SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP)
-          SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.OFF_HEAP)
-          // Notify any tasks waiting for execution memory to be freed to wake up and try to
-          // acquire memory again. This makes impossible the scenario where a task sleeps forever
-          // because there are no other tasks left to notify it. Since this is safe to do but may
-          // not be strictly necessary, we should revisit whether we can remove this in the future.
-          val memoryManager = SparkEnv.get.memoryManager
-          memoryManager.synchronized { memoryManager.notifyAll() }
-        }
+        // Call the task completion callbacks. If "markTaskCompleted" is called twice, the second
+        // one is no-op.
+        context.markTaskCompleted(None)
       } finally {
-        // Though we unset the ThreadLocal here, the context member variable itself is still queried
-        // directly in the TaskRunner to check for FetchFailedExceptions.
-        TaskContext.unset()
+        try {
+          Utils.tryLogNonFatalError {
+            // Release memory used by this thread for unrolling blocks
+            SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP)
+            SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(
+              MemoryMode.OFF_HEAP)
+            // Notify any tasks waiting for execution memory to be freed to wake up and try to
+            // acquire memory again. This makes impossible the scenario where a task sleeps forever
+            // because there are no other tasks left to notify it. Since this is safe to do but may
+            // not be strictly necessary, we should revisit whether we can remove this in the
+            // future.
+            val memoryManager = SparkEnv.get.memoryManager
+            memoryManager.synchronized { memoryManager.notifyAll() }
+          }
+        } finally {
+          // Though we unset the ThreadLocal here, the context member variable itself is still
+          // queried directly in the TaskRunner to check for FetchFailedExceptions.
+          TaskContext.unset()
+        }
       }
     }
   }
```
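
The restructured `finally` chain nests each cleanup step inside its own `try`/`finally`, so an exception thrown by the completion listeners can no longer skip the memory release or `TaskContext.unset()`. Reduced to its shape (the helper names here are hypothetical stand-ins; only the nesting structure mirrors the patch):

```scala
// Hypothetical stubs standing in for the real calls above.
def runCompletionListeners(): Unit = ()  // markTaskCompleted(None); may throw
def releaseUnrollMemory(): Unit = ()     // best-effort, wrapped in tryLogNonFatalError
def unsetTaskContext(): Unit = ()        // clears the ThreadLocal

try {
  runCompletionListeners()
} finally {
  try {
    releaseUnrollMemory()   // still runs if a completion listener threw
  } finally {
    unsetTaskContext()      // always runs last
  }
}
```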

`core/src/main/scala/org/apache/spark/util/taskListeners.scala`

```diff
@@ -55,14 +55,16 @@ class TaskCompletionListenerException(
   extends RuntimeException {

   override def getMessage: String = {
-    if (errorMessages.size == 1) {
-      errorMessages.head
-    } else {
-      errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n")
-    } +
-    previousError.map { e =>
+    val listenerErrorMessage =
+      if (errorMessages.size == 1) {
+        errorMessages.head
+      } else {
+        errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n")
+      }
+    val previousErrorMessage = previousError.map { e =>
       "\n\nPrevious exception in task: " + e.getMessage + "\n" +
         e.getStackTrace.mkString("\t", "\n\t", "")
     }.getOrElse("")
+    listenerErrorMessage + previousErrorMessage
   }
 }
```
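
Note how the fix works: binding the listener message and the previous-error suffix to separate `val`s makes the final concatenation unconditional, so `previousError` now appears regardless of whether there was one listener failure or several.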

`core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala`

```diff
@@ -100,7 +100,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     context.addTaskCompletionListener(_ => throw new Exception("blah"))

     intercept[TaskCompletionListenerException] {
-      context.markTaskCompleted()
+      context.markTaskCompleted(None)
     }

     verify(listener, times(1)).onTaskCompletion(any())
@@ -231,10 +231,10 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
   test("immediately call a completion listener if the context is completed") {
     var invocations = 0
     val context = TaskContext.empty()
-    context.markTaskCompleted()
+    context.markTaskCompleted(None)
     context.addTaskCompletionListener(_ => invocations += 1)
     assert(invocations == 1)
-    context.markTaskCompleted()
+    context.markTaskCompleted(None)
     assert(invocations == 1)
   }
@@ -254,6 +254,36 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     assert(lastError == error)
     assert(invocations == 1)
   }

+  test("TaskCompletionListenerException.getMessage should include previousError") {
+    val listenerErrorMessage = "exception in listener"
+    val taskErrorMessage = "exception in task"
+    val e = new TaskCompletionListenerException(
+      Seq(listenerErrorMessage),
+      Some(new RuntimeException(taskErrorMessage)))
+    assert(e.getMessage.contains(listenerErrorMessage) && e.getMessage.contains(taskErrorMessage))
+  }
+
+  test("all TaskCompletionListeners should be called even if some fail or a task") {
+    val context = TaskContext.empty()
+    val listener = mock(classOf[TaskCompletionListener])
+    context.addTaskCompletionListener(_ => throw new Exception("exception in listener1"))
+    context.addTaskCompletionListener(listener)
+    context.addTaskCompletionListener(_ => throw new Exception("exception in listener3"))
+
+    val e = intercept[TaskCompletionListenerException] {
+      context.markTaskCompleted(Some(new Exception("exception in task")))
+    }
+
+    // Make sure listener 2 was called.
+    verify(listener, times(1)).onTaskCompletion(any())
+
+    // also need to check failure in TaskCompletionListener does not mask earlier exception
+    assert(e.getMessage.contains("exception in listener1"))
+    assert(e.getMessage.contains("exception in listener3"))
+    assert(e.getMessage.contains("exception in task"))
+  }
 }

 private object TaskContextSuite {
```
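
For reference, following the `getMessage` construction above, the message asserted in the new multi-failure test should look roughly like this (which listener failure comes first depends on the order listeners are invoked, and the stack trace is elided):

```text
Exception 0: <one listener's message>
Exception 1: <the other listener's message>

Previous exception in task: exception in task
	<tab-indented stack trace of the task exception>
```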

`core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala`

```diff
@@ -145,7 +145,7 @@ class PartiallySerializedBlockSuite
     try {
       TaskContext.setTaskContext(TaskContext.empty())
       val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
-      TaskContext.get().asInstanceOf[TaskContextImpl].markTaskCompleted()
+      TaskContext.get().asInstanceOf[TaskContextImpl].markTaskCompleted(None)
       Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer).dispose()
       Mockito.verifyNoMoreInteractions(memoryStore)
     } finally {
```

`core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala`

```diff
@@ -192,7 +192,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     // Complete the task; then the 2nd block buffer should be exhausted
     verify(blocks(ShuffleBlockId(0, 1, 0)), times(0)).release()
-    taskContext.markTaskCompleted()
+    taskContext.markTaskCompleted(None)
     verify(blocks(ShuffleBlockId(0, 1, 0)), times(1)).release()

     // The 3rd block should not be retained because the iterator is already in zombie state
```