[SPARK-19597][CORE] test case for task deserialization errors
Adds a test case that ensures that Executors gracefully handle a task that fails to deserialize, by sending back a reasonable failure message. This does not change any behavior (the prior behavior was already correct); it just adds a test case to prevent regression. Author: Imran Rashid <irashid@cloudera.com>. Closes #16930 from squito/executor_task_deserialization.
This commit is contained in:
parent
5cbd3b59ba
commit
5f74148bb4
|
@ -148,6 +148,8 @@ private[spark] class Executor(
|
|||
|
||||
startDriverHeartbeater()
|
||||
|
||||
// Returns the current number of entries in `runningTasks` (launchTask adds one per launched task).
private[executor] def numRunningTasks: Int = runningTasks.size()
|
||||
|
||||
def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
|
||||
val tr = new TaskRunner(context, taskDescription)
|
||||
runningTasks.put(taskDescription.taskId, tr)
|
||||
|
|
|
@ -17,16 +17,21 @@
|
|||
|
||||
package org.apache.spark.executor
|
||||
|
||||
import java.io.{Externalizable, ObjectInput, ObjectOutput}
|
||||
import java.nio.ByteBuffer
|
||||
import java.util.Properties
|
||||
import java.util.concurrent.CountDownLatch
|
||||
import java.util.concurrent.{CountDownLatch, TimeUnit}
|
||||
|
||||
import scala.collection.mutable.Map
|
||||
import scala.concurrent.duration._
|
||||
|
||||
import org.mockito.Matchers._
|
||||
import org.mockito.Mockito.{mock, when}
|
||||
import org.mockito.ArgumentCaptor
|
||||
import org.mockito.Matchers.{any, eq => meq}
|
||||
import org.mockito.Mockito.{inOrder, when}
|
||||
import org.mockito.invocation.InvocationOnMock
|
||||
import org.mockito.stubbing.Answer
|
||||
import org.scalatest.concurrent.Eventually
|
||||
import org.scalatest.mock.MockitoSugar
|
||||
|
||||
import org.apache.spark._
|
||||
import org.apache.spark.TaskState.TaskState
|
||||
|
@ -36,35 +41,15 @@ import org.apache.spark.rpc.RpcEnv
|
|||
import org.apache.spark.scheduler.{FakeTask, TaskDescription}
|
||||
import org.apache.spark.serializer.JavaSerializer
|
||||
|
||||
class ExecutorSuite extends SparkFunSuite {
|
||||
class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar with Eventually {
|
||||
|
||||
test("SPARK-15963: Catch `TaskKilledException` correctly in Executor.TaskRunner") {
|
||||
// mock some objects to make Executor.launchTask() happy
|
||||
val conf = new SparkConf
|
||||
val serializer = new JavaSerializer(conf)
|
||||
val mockEnv = mock(classOf[SparkEnv])
|
||||
val mockRpcEnv = mock(classOf[RpcEnv])
|
||||
val mockMetricsSystem = mock(classOf[MetricsSystem])
|
||||
val mockMemoryManager = mock(classOf[MemoryManager])
|
||||
when(mockEnv.conf).thenReturn(conf)
|
||||
when(mockEnv.serializer).thenReturn(serializer)
|
||||
when(mockEnv.rpcEnv).thenReturn(mockRpcEnv)
|
||||
when(mockEnv.metricsSystem).thenReturn(mockMetricsSystem)
|
||||
when(mockEnv.memoryManager).thenReturn(mockMemoryManager)
|
||||
when(mockEnv.closureSerializer).thenReturn(serializer)
|
||||
val fakeTaskMetrics = serializer.newInstance().serialize(TaskMetrics.registered).array()
|
||||
val serializedTask = serializer.newInstance().serialize(
|
||||
new FakeTask(0, 0, Nil, fakeTaskMetrics))
|
||||
val taskDescription = new TaskDescription(
|
||||
taskId = 0,
|
||||
attemptNumber = 0,
|
||||
executorId = "",
|
||||
name = "",
|
||||
index = 0,
|
||||
addedFiles = Map[String, Long](),
|
||||
addedJars = Map[String, Long](),
|
||||
properties = new Properties,
|
||||
serializedTask)
|
||||
val env = createMockEnv(conf, serializer)
|
||||
val serializedTask = serializer.newInstance().serialize(new FakeTask(0, 0))
|
||||
val taskDescription = createFakeTaskDescription(serializedTask)
|
||||
|
||||
// we use latches to force the program to run in this order:
|
||||
// +-----------------------------+---------------------------------------+
|
||||
|
@ -86,7 +71,7 @@ class ExecutorSuite extends SparkFunSuite {
|
|||
|
||||
val executorSuiteHelper = new ExecutorSuiteHelper
|
||||
|
||||
val mockExecutorBackend = mock(classOf[ExecutorBackend])
|
||||
val mockExecutorBackend = mock[ExecutorBackend]
|
||||
when(mockExecutorBackend.statusUpdate(any(), any(), any()))
|
||||
.thenAnswer(new Answer[Unit] {
|
||||
var firstTime = true
|
||||
|
@ -102,8 +87,8 @@ class ExecutorSuite extends SparkFunSuite {
|
|||
val taskState = invocationOnMock.getArguments()(1).asInstanceOf[TaskState]
|
||||
executorSuiteHelper.taskState = taskState
|
||||
val taskEndReason = invocationOnMock.getArguments()(2).asInstanceOf[ByteBuffer]
|
||||
executorSuiteHelper.testFailedReason
|
||||
= serializer.newInstance().deserialize(taskEndReason)
|
||||
executorSuiteHelper.testFailedReason =
|
||||
serializer.newInstance().deserialize(taskEndReason)
|
||||
// let the main test thread check `taskState` and `testFailedReason`
|
||||
executorSuiteHelper.latch3.countDown()
|
||||
}
|
||||
|
@ -112,16 +97,20 @@ class ExecutorSuite extends SparkFunSuite {
|
|||
|
||||
var executor: Executor = null
|
||||
try {
|
||||
executor = new Executor("id", "localhost", mockEnv, userClassPath = Nil, isLocal = true)
|
||||
executor = new Executor("id", "localhost", env, userClassPath = Nil, isLocal = true)
|
||||
// the task will be launched in a dedicated worker thread
|
||||
executor.launchTask(mockExecutorBackend, taskDescription)
|
||||
|
||||
executorSuiteHelper.latch1.await()
|
||||
if (!executorSuiteHelper.latch1.await(5, TimeUnit.SECONDS)) {
|
||||
fail("executor did not send first status update in time")
|
||||
}
|
||||
// we know the task will be started, but not yet deserialized, because of the latches we
|
||||
// use in mockExecutorBackend.
|
||||
executor.killAllTasks(true)
|
||||
executorSuiteHelper.latch2.countDown()
|
||||
executorSuiteHelper.latch3.await()
|
||||
if (!executorSuiteHelper.latch3.await(5, TimeUnit.SECONDS)) {
|
||||
fail("executor did not send second status update in time")
|
||||
}
|
||||
|
||||
// `testFailedReason` should be `TaskKilled`; `taskState` should be `KILLED`
|
||||
assert(executorSuiteHelper.testFailedReason === TaskKilled)
|
||||
|
@ -133,6 +122,79 @@ class ExecutorSuite extends SparkFunSuite {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Launches a task whose deserialization always throws and checks that the failure
// reported back is an ExceptionFailure carrying the original error message.
test("Gracefully handle error in task deserialization") {
  val sparkConf = new SparkConf
  val javaSerializer = new JavaSerializer(sparkConf)
  // Called for its side effect: it registers the mocked environment via SparkEnv.set,
  // which runTaskAndGetFailReason later reads back through SparkEnv.get.
  createMockEnv(sparkConf, javaSerializer)
  val taskBytes = javaSerializer.newInstance().serialize(new NonDeserializableTask)
  val description = createFakeTaskDescription(taskBytes)

  runTaskAndGetFailReason(description) match {
    case failure: ExceptionFailure =>
      assert(failure.exception.isDefined)
      assert(failure.exception.get.getMessage() === NonDeserializableTask.errorMsg)
    case other =>
      fail(s"unexpected failure type: $other")
  }
}
|
||||
|
||||
/**
 * Creates a mocked SparkEnv that hands out the given conf and serializer (the serializer is
 * also used as the closure serializer) together with mocked RpcEnv, MetricsSystem and
 * MemoryManager instances. The mock is installed with SparkEnv.set before being returned.
 */
private def createMockEnv(conf: SparkConf, serializer: JavaSerializer): SparkEnv = {
  // Create every mock up front, before any stubbing begins.
  val env = mock[SparkEnv]
  val rpc = mock[RpcEnv]
  val metrics = mock[MetricsSystem]
  val memory = mock[MemoryManager]

  // Real objects supplied by the caller.
  when(env.conf).thenReturn(conf)
  when(env.serializer).thenReturn(serializer)
  when(env.closureSerializer).thenReturn(serializer)

  // Mocked collaborators.
  when(env.rpcEnv).thenReturn(rpc)
  when(env.metricsSystem).thenReturn(metrics)
  when(env.memoryManager).thenReturn(memory)

  SparkEnv.set(env)
  env
}
|
||||
|
||||
/**
 * Wraps the given serialized task bytes in a TaskDescription with task id 0, attempt 0,
 * index 0, empty executor id and name, no added files or jars, and empty properties.
 */
private def createFakeTaskDescription(serializedTask: ByteBuffer): TaskDescription = {
  // Two distinct empty maps, one for files and one for jars.
  val noFiles = Map.empty[String, Long]
  val noJars = Map.empty[String, Long]
  val emptyProps = new Properties
  new TaskDescription(
    taskId = 0,
    attemptNumber = 0,
    executorId = "",
    name = "",
    index = 0,
    addedFiles = noFiles,
    addedJars = noJars,
    properties = emptyProps,
    serializedTask)
}
|
||||
|
||||
/**
 * Launches the task on a fresh Executor built from SparkEnv.get, waits (up to 5 seconds) for
 * the executor to report no running tasks, stops the executor, and then verifies that the
 * mocked backend received a RUNNING update followed by a FAILED update for task 0.
 *
 * @param taskDescription the task to launch
 * @return the TaskFailedReason deserialized from the payload of the FAILED status update
 */
private def runTaskAndGetFailReason(taskDescription: TaskDescription): TaskFailedReason = {
  val backend = mock[ExecutorBackend]
  // If construction throws there is nothing to stop, so the executor is built outside the try.
  val executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true)
  try {
    // the task will be launched in a dedicated worker thread
    executor.launchTask(backend, taskDescription)
    eventually(timeout(5.seconds), interval(10.milliseconds)) {
      assert(executor.numRunningTasks === 0)
    }
  } finally {
    executor.stop()
  }

  // Both status updates feed the same captor, in the order they are verified.
  val payloads = ArgumentCaptor.forClass(classOf[ByteBuffer])
  val ordered = inOrder(backend)
  ordered.verify(backend).statusUpdate(meq(0L), meq(TaskState.RUNNING), payloads.capture())
  ordered.verify(backend).statusUpdate(meq(0L), meq(TaskState.FAILED), payloads.capture())

  // first statusUpdate for RUNNING has empty data
  val runningPayload = payloads.getAllValues().get(0)
  assert(runningPayload.remaining() === 0)

  // second update carries the serialized failure reason
  val failedPayload = payloads.getAllValues.get(1)
  val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
  closureSerializer.deserialize[TaskFailedReason](failedPayload)
}
|
||||
}
|
||||
|
||||
// Helps to test("SPARK-15963")
|
||||
|
@ -145,3 +207,14 @@ private class ExecutorSuiteHelper {
|
|||
@volatile var taskState: TaskState = _
|
||||
@volatile var testFailedReason: TaskFailedReason = _
|
||||
}
|
||||
|
||||
/**
 * A FakeTask that serializes fine (it writes nothing) but always throws a RuntimeException
 * with NonDeserializableTask.errorMsg when it is read back through Externalizable.
 */
private class NonDeserializableTask extends FakeTask(0, 0) with Externalizable {
  // Nothing is written; only the read side matters for this test fixture.
  override def writeExternal(out: ObjectOutput): Unit = ()

  override def readExternal(in: ObjectInput): Unit =
    throw new RuntimeException(NonDeserializableTask.errorMsg)
}
|
||||
|
||||
/** Holds the message thrown by NonDeserializableTask.readExternal. */
private object NonDeserializableTask {
  val errorMsg: String = "failure in deserialization"
}
|
||||
|
|
Loading…
Reference in a new issue