[SPARK-20702][CORE] TaskContextImpl.markTaskCompleted should not hide the original error
## What changes were proposed in this pull request? This PR adds an `error` parameter to `TaskContextImpl.markTaskCompleted` to propagate the original error. It also fixes an issue that `TaskCompletionListenerException.getMessage` doesn't include `previousError`. ## How was this patch tested? New unit tests. Author: Shixiong Zhu <shixiong@databricks.com> Closes #17942 from zsxwing/SPARK-20702.
This commit is contained in:
parent
b526f70c16
commit
7d6ff39106
|
@ -110,10 +110,10 @@ private[spark] class TaskContextImpl(
|
|||
|
||||
/** Marks the task as completed and triggers the completion listeners. */
|
||||
@GuardedBy("this")
|
||||
private[spark] def markTaskCompleted(): Unit = synchronized {
|
||||
private[spark] def markTaskCompleted(error: Option[Throwable]): Unit = synchronized {
|
||||
if (completed) return
|
||||
completed = true
|
||||
invokeListeners(onCompleteCallbacks, "TaskCompletionListener", None) {
|
||||
invokeListeners(onCompleteCallbacks, "TaskCompletionListener", error) {
|
||||
_.onTaskCompletion(this)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -115,26 +115,33 @@ private[spark] abstract class Task[T](
|
|||
case t: Throwable =>
|
||||
e.addSuppressed(t)
|
||||
}
|
||||
context.markTaskCompleted(Some(e))
|
||||
throw e
|
||||
} finally {
|
||||
// Call the task completion callbacks.
|
||||
context.markTaskCompleted()
|
||||
try {
|
||||
Utils.tryLogNonFatalError {
|
||||
// Release memory used by this thread for unrolling blocks
|
||||
SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP)
|
||||
SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.OFF_HEAP)
|
||||
// Notify any tasks waiting for execution memory to be freed to wake up and try to
|
||||
// acquire memory again. This makes impossible the scenario where a task sleeps forever
|
||||
// because there are no other tasks left to notify it. Since this is safe to do but may
|
||||
// not be strictly necessary, we should revisit whether we can remove this in the future.
|
||||
val memoryManager = SparkEnv.get.memoryManager
|
||||
memoryManager.synchronized { memoryManager.notifyAll() }
|
||||
}
|
||||
// Call the task completion callbacks. If "markTaskCompleted" is called twice, the second
|
||||
// one is no-op.
|
||||
context.markTaskCompleted(None)
|
||||
} finally {
|
||||
// Though we unset the ThreadLocal here, the context member variable itself is still queried
|
||||
// directly in the TaskRunner to check for FetchFailedExceptions.
|
||||
TaskContext.unset()
|
||||
try {
|
||||
Utils.tryLogNonFatalError {
|
||||
// Release memory used by this thread for unrolling blocks
|
||||
SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP)
|
||||
SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(
|
||||
MemoryMode.OFF_HEAP)
|
||||
// Notify any tasks waiting for execution memory to be freed to wake up and try to
|
||||
// acquire memory again. This makes impossible the scenario where a task sleeps forever
|
||||
// because there are no other tasks left to notify it. Since this is safe to do but may
|
||||
// not be strictly necessary, we should revisit whether we can remove this in the
|
||||
// future.
|
||||
val memoryManager = SparkEnv.get.memoryManager
|
||||
memoryManager.synchronized { memoryManager.notifyAll() }
|
||||
}
|
||||
} finally {
|
||||
// Though we unset the ThreadLocal here, the context member variable itself is still
|
||||
// queried directly in the TaskRunner to check for FetchFailedExceptions.
|
||||
TaskContext.unset()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -55,14 +55,16 @@ class TaskCompletionListenerException(
|
|||
extends RuntimeException {
|
||||
|
||||
override def getMessage: String = {
|
||||
if (errorMessages.size == 1) {
|
||||
errorMessages.head
|
||||
} else {
|
||||
errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n")
|
||||
} +
|
||||
previousError.map { e =>
|
||||
val listenerErrorMessage =
|
||||
if (errorMessages.size == 1) {
|
||||
errorMessages.head
|
||||
} else {
|
||||
errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n")
|
||||
}
|
||||
val previousErrorMessage = previousError.map { e =>
|
||||
"\n\nPrevious exception in task: " + e.getMessage + "\n" +
|
||||
e.getStackTrace.mkString("\t", "\n\t", "")
|
||||
}.getOrElse("")
|
||||
listenerErrorMessage + previousErrorMessage
|
||||
}
|
||||
}
|
||||
|
|
|
@ -100,7 +100,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
|
|||
context.addTaskCompletionListener(_ => throw new Exception("blah"))
|
||||
|
||||
intercept[TaskCompletionListenerException] {
|
||||
context.markTaskCompleted()
|
||||
context.markTaskCompleted(None)
|
||||
}
|
||||
|
||||
verify(listener, times(1)).onTaskCompletion(any())
|
||||
|
@ -231,10 +231,10 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
|
|||
test("immediately call a completion listener if the context is completed") {
|
||||
var invocations = 0
|
||||
val context = TaskContext.empty()
|
||||
context.markTaskCompleted()
|
||||
context.markTaskCompleted(None)
|
||||
context.addTaskCompletionListener(_ => invocations += 1)
|
||||
assert(invocations == 1)
|
||||
context.markTaskCompleted()
|
||||
context.markTaskCompleted(None)
|
||||
assert(invocations == 1)
|
||||
}
|
||||
|
||||
|
@ -254,6 +254,36 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
|
|||
assert(lastError == error)
|
||||
assert(invocations == 1)
|
||||
}
|
||||
|
||||
test("TaskCompletionListenerException.getMessage should include previousError") {
|
||||
val listenerErrorMessage = "exception in listener"
|
||||
val taskErrorMessage = "exception in task"
|
||||
val e = new TaskCompletionListenerException(
|
||||
Seq(listenerErrorMessage),
|
||||
Some(new RuntimeException(taskErrorMessage)))
|
||||
assert(e.getMessage.contains(listenerErrorMessage) && e.getMessage.contains(taskErrorMessage))
|
||||
}
|
||||
|
||||
test("all TaskCompletionListeners should be called even if some fail or a task") {
|
||||
val context = TaskContext.empty()
|
||||
val listener = mock(classOf[TaskCompletionListener])
|
||||
context.addTaskCompletionListener(_ => throw new Exception("exception in listener1"))
|
||||
context.addTaskCompletionListener(listener)
|
||||
context.addTaskCompletionListener(_ => throw new Exception("exception in listener3"))
|
||||
|
||||
val e = intercept[TaskCompletionListenerException] {
|
||||
context.markTaskCompleted(Some(new Exception("exception in task")))
|
||||
}
|
||||
|
||||
// Make sure listener 2 was called.
|
||||
verify(listener, times(1)).onTaskCompletion(any())
|
||||
|
||||
// also need to check failure in TaskCompletionListener does not mask earlier exception
|
||||
assert(e.getMessage.contains("exception in listener1"))
|
||||
assert(e.getMessage.contains("exception in listener3"))
|
||||
assert(e.getMessage.contains("exception in task"))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private object TaskContextSuite {
|
||||
|
|
|
@ -145,7 +145,7 @@ class PartiallySerializedBlockSuite
|
|||
try {
|
||||
TaskContext.setTaskContext(TaskContext.empty())
|
||||
val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
|
||||
TaskContext.get().asInstanceOf[TaskContextImpl].markTaskCompleted()
|
||||
TaskContext.get().asInstanceOf[TaskContextImpl].markTaskCompleted(None)
|
||||
Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer).dispose()
|
||||
Mockito.verifyNoMoreInteractions(memoryStore)
|
||||
} finally {
|
||||
|
|
|
@ -192,7 +192,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
|
|||
|
||||
// Complete the task; then the 2nd block buffer should be exhausted
|
||||
verify(blocks(ShuffleBlockId(0, 1, 0)), times(0)).release()
|
||||
taskContext.markTaskCompleted()
|
||||
taskContext.markTaskCompleted(None)
|
||||
verify(blocks(ShuffleBlockId(0, 1, 0)), times(1)).release()
|
||||
|
||||
// The 3rd block should not be retained because the iterator is already in zombie state
|
||||
|
|
Loading…
Reference in a new issue