[SPARK-19597][CORE] test case for task deserialization errors

Adds a test case that ensures that Executors gracefully handle a task that fails to deserialize, by sending back a reasonable failure message.  This does not change any behavior (the prior behavior was already correct), it just adds a test case to prevent regression.

Author: Imran Rashid <irashid@cloudera.com>

Closes #16930 from squito/executor_task_deserialization.
This commit is contained in:
Imran Rashid 2017-02-24 13:03:37 -08:00 committed by Kay Ousterhout
parent 5cbd3b59ba
commit 5f74148bb4
2 changed files with 108 additions and 33 deletions

View file

@ -148,6 +148,8 @@ private[spark] class Executor(
startDriverHeartbeater()
private[executor] def numRunningTasks: Int = runningTasks.size()
def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
val tr = new TaskRunner(context, taskDescription)
runningTasks.put(taskDescription.taskId, tr)

View file

@ -17,16 +17,21 @@
package org.apache.spark.executor
import java.io.{Externalizable, ObjectInput, ObjectOutput}
import java.nio.ByteBuffer
import java.util.Properties
import java.util.concurrent.CountDownLatch
import java.util.concurrent.{CountDownLatch, TimeUnit}
import scala.collection.mutable.Map
import scala.concurrent.duration._
import org.mockito.Matchers._
import org.mockito.Mockito.{mock, when}
import org.mockito.ArgumentCaptor
import org.mockito.Matchers.{any, eq => meq}
import org.mockito.Mockito.{inOrder, when}
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.scalatest.concurrent.Eventually
import org.scalatest.mock.MockitoSugar
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
@ -36,35 +41,15 @@ import org.apache.spark.rpc.RpcEnv
import org.apache.spark.scheduler.{FakeTask, TaskDescription}
import org.apache.spark.serializer.JavaSerializer
class ExecutorSuite extends SparkFunSuite {
class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar with Eventually {
test("SPARK-15963: Catch `TaskKilledException` correctly in Executor.TaskRunner") {
// mock some objects to make Executor.launchTask() happy
val conf = new SparkConf
val serializer = new JavaSerializer(conf)
val mockEnv = mock(classOf[SparkEnv])
val mockRpcEnv = mock(classOf[RpcEnv])
val mockMetricsSystem = mock(classOf[MetricsSystem])
val mockMemoryManager = mock(classOf[MemoryManager])
when(mockEnv.conf).thenReturn(conf)
when(mockEnv.serializer).thenReturn(serializer)
when(mockEnv.rpcEnv).thenReturn(mockRpcEnv)
when(mockEnv.metricsSystem).thenReturn(mockMetricsSystem)
when(mockEnv.memoryManager).thenReturn(mockMemoryManager)
when(mockEnv.closureSerializer).thenReturn(serializer)
val fakeTaskMetrics = serializer.newInstance().serialize(TaskMetrics.registered).array()
val serializedTask = serializer.newInstance().serialize(
new FakeTask(0, 0, Nil, fakeTaskMetrics))
val taskDescription = new TaskDescription(
taskId = 0,
attemptNumber = 0,
executorId = "",
name = "",
index = 0,
addedFiles = Map[String, Long](),
addedJars = Map[String, Long](),
properties = new Properties,
serializedTask)
val env = createMockEnv(conf, serializer)
val serializedTask = serializer.newInstance().serialize(new FakeTask(0, 0))
val taskDescription = createFakeTaskDescription(serializedTask)
// we use latches to force the program to run in this order:
// +-----------------------------+---------------------------------------+
@ -86,7 +71,7 @@ class ExecutorSuite extends SparkFunSuite {
val executorSuiteHelper = new ExecutorSuiteHelper
val mockExecutorBackend = mock(classOf[ExecutorBackend])
val mockExecutorBackend = mock[ExecutorBackend]
when(mockExecutorBackend.statusUpdate(any(), any(), any()))
.thenAnswer(new Answer[Unit] {
var firstTime = true
@ -102,8 +87,8 @@ class ExecutorSuite extends SparkFunSuite {
val taskState = invocationOnMock.getArguments()(1).asInstanceOf[TaskState]
executorSuiteHelper.taskState = taskState
val taskEndReason = invocationOnMock.getArguments()(2).asInstanceOf[ByteBuffer]
executorSuiteHelper.testFailedReason
= serializer.newInstance().deserialize(taskEndReason)
executorSuiteHelper.testFailedReason =
serializer.newInstance().deserialize(taskEndReason)
// let the main test thread check `taskState` and `testFailedReason`
executorSuiteHelper.latch3.countDown()
}
@ -112,16 +97,20 @@ class ExecutorSuite extends SparkFunSuite {
var executor: Executor = null
try {
executor = new Executor("id", "localhost", mockEnv, userClassPath = Nil, isLocal = true)
executor = new Executor("id", "localhost", env, userClassPath = Nil, isLocal = true)
// the task will be launched in a dedicated worker thread
executor.launchTask(mockExecutorBackend, taskDescription)
executorSuiteHelper.latch1.await()
if (!executorSuiteHelper.latch1.await(5, TimeUnit.SECONDS)) {
fail("executor did not send first status update in time")
}
// we know the task will be started, but not yet deserialized, because of the latches we
// use in mockExecutorBackend.
executor.killAllTasks(true)
executorSuiteHelper.latch2.countDown()
executorSuiteHelper.latch3.await()
if (!executorSuiteHelper.latch3.await(5, TimeUnit.SECONDS)) {
fail("executor did not send second status update in time")
}
// `testFailedReason` should be `TaskKilled`; `taskState` should be `KILLED`
assert(executorSuiteHelper.testFailedReason === TaskKilled)
@ -133,6 +122,79 @@ class ExecutorSuite extends SparkFunSuite {
}
}
}
test("Gracefully handle error in task deserialization") {
val conf = new SparkConf
val serializer = new JavaSerializer(conf)
val env = createMockEnv(conf, serializer)
val serializedTask = serializer.newInstance().serialize(new NonDeserializableTask)
val taskDescription = createFakeTaskDescription(serializedTask)
val failReason = runTaskAndGetFailReason(taskDescription)
failReason match {
case ef: ExceptionFailure =>
assert(ef.exception.isDefined)
assert(ef.exception.get.getMessage() === NonDeserializableTask.errorMsg)
case _ =>
fail(s"unexpected failure type: $failReason")
}
}
private def createMockEnv(conf: SparkConf, serializer: JavaSerializer): SparkEnv = {
val mockEnv = mock[SparkEnv]
val mockRpcEnv = mock[RpcEnv]
val mockMetricsSystem = mock[MetricsSystem]
val mockMemoryManager = mock[MemoryManager]
when(mockEnv.conf).thenReturn(conf)
when(mockEnv.serializer).thenReturn(serializer)
when(mockEnv.rpcEnv).thenReturn(mockRpcEnv)
when(mockEnv.metricsSystem).thenReturn(mockMetricsSystem)
when(mockEnv.memoryManager).thenReturn(mockMemoryManager)
when(mockEnv.closureSerializer).thenReturn(serializer)
SparkEnv.set(mockEnv)
mockEnv
}
private def createFakeTaskDescription(serializedTask: ByteBuffer): TaskDescription = {
new TaskDescription(
taskId = 0,
attemptNumber = 0,
executorId = "",
name = "",
index = 0,
addedFiles = Map[String, Long](),
addedJars = Map[String, Long](),
properties = new Properties,
serializedTask)
}
private def runTaskAndGetFailReason(taskDescription: TaskDescription): TaskFailedReason = {
val mockBackend = mock[ExecutorBackend]
var executor: Executor = null
try {
executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true)
// the task will be launched in a dedicated worker thread
executor.launchTask(mockBackend, taskDescription)
eventually(timeout(5 seconds), interval(10 milliseconds)) {
assert(executor.numRunningTasks === 0)
}
} finally {
if (executor != null) {
executor.stop()
}
}
val orderedMock = inOrder(mockBackend)
val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer])
orderedMock.verify(mockBackend)
.statusUpdate(meq(0L), meq(TaskState.RUNNING), statusCaptor.capture())
orderedMock.verify(mockBackend)
.statusUpdate(meq(0L), meq(TaskState.FAILED), statusCaptor.capture())
// first statusUpdate for RUNNING has empty data
assert(statusCaptor.getAllValues().get(0).remaining() === 0)
// second update is more interesting
val failureData = statusCaptor.getAllValues.get(1)
SparkEnv.get.closureSerializer.newInstance().deserialize[TaskFailedReason](failureData)
}
}
// Helps to test("SPARK-15963")
@ -145,3 +207,14 @@ private class ExecutorSuiteHelper {
@volatile var taskState: TaskState = _
@volatile var testFailedReason: TaskFailedReason = _
}
private class NonDeserializableTask extends FakeTask(0, 0) with Externalizable {
def writeExternal(out: ObjectOutput): Unit = {}
def readExternal(in: ObjectInput): Unit = {
throw new RuntimeException(NonDeserializableTask.errorMsg)
}
}
private object NonDeserializableTask {
val errorMsg = "failure in deserialization"
}