diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index d762f11125..975a6e4eeb 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -148,6 +148,8 @@ private[spark] class Executor(
 
   startDriverHeartbeater()
 
+  private[executor] def numRunningTasks: Int = runningTasks.size()
+
   def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
     val tr = new TaskRunner(context, taskDescription)
     runningTasks.put(taskDescription.taskId, tr)
diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
index f94baaa30d..b743ff5376 100644
--- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -17,16 +17,21 @@
 
 package org.apache.spark.executor
 
+import java.io.{Externalizable, ObjectInput, ObjectOutput}
 import java.nio.ByteBuffer
 import java.util.Properties
-import java.util.concurrent.CountDownLatch
+import java.util.concurrent.{CountDownLatch, TimeUnit}
 
 import scala.collection.mutable.Map
+import scala.concurrent.duration._
 
-import org.mockito.Matchers._
-import org.mockito.Mockito.{mock, when}
+import org.mockito.ArgumentCaptor
+import org.mockito.Matchers.{any, eq => meq}
+import org.mockito.Mockito.{inOrder, when}
 import org.mockito.invocation.InvocationOnMock
 import org.mockito.stubbing.Answer
+import org.scalatest.concurrent.Eventually
+import org.scalatest.mock.MockitoSugar
 
 import org.apache.spark._
 import org.apache.spark.TaskState.TaskState
@@ -36,35 +41,15 @@ import org.apache.spark.rpc.RpcEnv
 import org.apache.spark.scheduler.{FakeTask, TaskDescription}
 import org.apache.spark.serializer.JavaSerializer
 
-class ExecutorSuite extends SparkFunSuite {
+class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar with Eventually {
 
   test("SPARK-15963: Catch `TaskKilledException` correctly in Executor.TaskRunner") {
     // mock some objects to make Executor.launchTask() happy
     val conf = new SparkConf
     val serializer = new JavaSerializer(conf)
-    val mockEnv = mock(classOf[SparkEnv])
-    val mockRpcEnv = mock(classOf[RpcEnv])
-    val mockMetricsSystem = mock(classOf[MetricsSystem])
-    val mockMemoryManager = mock(classOf[MemoryManager])
-    when(mockEnv.conf).thenReturn(conf)
-    when(mockEnv.serializer).thenReturn(serializer)
-    when(mockEnv.rpcEnv).thenReturn(mockRpcEnv)
-    when(mockEnv.metricsSystem).thenReturn(mockMetricsSystem)
-    when(mockEnv.memoryManager).thenReturn(mockMemoryManager)
-    when(mockEnv.closureSerializer).thenReturn(serializer)
-    val fakeTaskMetrics = serializer.newInstance().serialize(TaskMetrics.registered).array()
-    val serializedTask = serializer.newInstance().serialize(
-      new FakeTask(0, 0, Nil, fakeTaskMetrics))
-    val taskDescription = new TaskDescription(
-      taskId = 0,
-      attemptNumber = 0,
-      executorId = "",
-      name = "",
-      index = 0,
-      addedFiles = Map[String, Long](),
-      addedJars = Map[String, Long](),
-      properties = new Properties,
-      serializedTask)
+    val env = createMockEnv(conf, serializer)
+    val serializedTask = serializer.newInstance().serialize(new FakeTask(0, 0))
+    val taskDescription = createFakeTaskDescription(serializedTask)
 
     // we use latches to force the program to run in this order:
     // +-----------------------------+---------------------------------------+
@@ -86,7 +71,7 @@ class ExecutorSuite extends SparkFunSuite {
 
     val executorSuiteHelper = new ExecutorSuiteHelper
 
-    val mockExecutorBackend = mock(classOf[ExecutorBackend])
+    val mockExecutorBackend = mock[ExecutorBackend]
     when(mockExecutorBackend.statusUpdate(any(), any(), any()))
       .thenAnswer(new Answer[Unit] {
         var firstTime = true
@@ -102,8 +87,8 @@ class ExecutorSuite extends SparkFunSuite {
             val taskState = invocationOnMock.getArguments()(1).asInstanceOf[TaskState]
             executorSuiteHelper.taskState = taskState
             val taskEndReason = invocationOnMock.getArguments()(2).asInstanceOf[ByteBuffer]
-            executorSuiteHelper.testFailedReason
-              = serializer.newInstance().deserialize(taskEndReason)
+            executorSuiteHelper.testFailedReason =
+              serializer.newInstance().deserialize(taskEndReason)
             // let the main test thread check `taskState` and `testFailedReason`
             executorSuiteHelper.latch3.countDown()
           }
@@ -112,16 +97,20 @@
 
     var executor: Executor = null
     try {
-      executor = new Executor("id", "localhost", mockEnv, userClassPath = Nil, isLocal = true)
+      executor = new Executor("id", "localhost", env, userClassPath = Nil, isLocal = true)
       // the task will be launched in a dedicated worker thread
       executor.launchTask(mockExecutorBackend, taskDescription)
 
-      executorSuiteHelper.latch1.await()
+      if (!executorSuiteHelper.latch1.await(5, TimeUnit.SECONDS)) {
+        fail("executor did not send first status update in time")
+      }
       // we know the task will be started, but not yet deserialized, because of the latches we
       // use in mockExecutorBackend.
       executor.killAllTasks(true)
       executorSuiteHelper.latch2.countDown()
-      executorSuiteHelper.latch3.await()
+      if (!executorSuiteHelper.latch3.await(5, TimeUnit.SECONDS)) {
+        fail("executor did not send second status update in time")
+      }
 
       // `testFailedReason` should be `TaskKilled`; `taskState` should be `KILLED`
       assert(executorSuiteHelper.testFailedReason === TaskKilled)
@@ -133,6 +122,79 @@
       }
     }
   }
+
+  test("Gracefully handle error in task deserialization") {
+    val conf = new SparkConf
+    val serializer = new JavaSerializer(conf)
+    val env = createMockEnv(conf, serializer)
+    val serializedTask = serializer.newInstance().serialize(new NonDeserializableTask)
+    val taskDescription = createFakeTaskDescription(serializedTask)
+
+    val failReason = runTaskAndGetFailReason(taskDescription)
+    failReason match {
+      case ef: ExceptionFailure =>
+        assert(ef.exception.isDefined)
+        assert(ef.exception.get.getMessage() === NonDeserializableTask.errorMsg)
+      case _ =>
+        fail(s"unexpected failure type: $failReason")
+    }
+  }
+
+  private def createMockEnv(conf: SparkConf, serializer: JavaSerializer): SparkEnv = {
+    val mockEnv = mock[SparkEnv]
+    val mockRpcEnv = mock[RpcEnv]
+    val mockMetricsSystem = mock[MetricsSystem]
+    val mockMemoryManager = mock[MemoryManager]
+    when(mockEnv.conf).thenReturn(conf)
+    when(mockEnv.serializer).thenReturn(serializer)
+    when(mockEnv.rpcEnv).thenReturn(mockRpcEnv)
+    when(mockEnv.metricsSystem).thenReturn(mockMetricsSystem)
+    when(mockEnv.memoryManager).thenReturn(mockMemoryManager)
+    when(mockEnv.closureSerializer).thenReturn(serializer)
+    SparkEnv.set(mockEnv)
+    mockEnv
+  }
+
+  private def createFakeTaskDescription(serializedTask: ByteBuffer): TaskDescription = {
+    new TaskDescription(
+      taskId = 0,
+      attemptNumber = 0,
+      executorId = "",
+      name = "",
+      index = 0,
+      addedFiles = Map[String, Long](),
+      addedJars = Map[String, Long](),
+      properties = new Properties,
+      serializedTask)
+  }
+
+  private def runTaskAndGetFailReason(taskDescription: TaskDescription): TaskFailedReason = {
+    val mockBackend = mock[ExecutorBackend]
+    var executor: Executor = null
+    try {
+      executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true)
+      // the task will be launched in a dedicated worker thread
+      executor.launchTask(mockBackend, taskDescription)
+      eventually(timeout(5 seconds), interval(10 milliseconds)) {
+        assert(executor.numRunningTasks === 0)
+      }
+    } finally {
+      if (executor != null) {
+        executor.stop()
+      }
+    }
+    val orderedMock = inOrder(mockBackend)
+    val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer])
+    orderedMock.verify(mockBackend)
+      .statusUpdate(meq(0L), meq(TaskState.RUNNING), statusCaptor.capture())
+    orderedMock.verify(mockBackend)
+      .statusUpdate(meq(0L), meq(TaskState.FAILED), statusCaptor.capture())
+    // first statusUpdate for RUNNING has empty data
+    assert(statusCaptor.getAllValues().get(0).remaining() === 0)
+    // second update is more interesting
+    val failureData = statusCaptor.getAllValues.get(1)
+    SparkEnv.get.closureSerializer.newInstance().deserialize[TaskFailedReason](failureData)
+  }
 }
 
 // Helps to test("SPARK-15963")
@@ -145,3 +207,14 @@ private class ExecutorSuiteHelper {
   @volatile var taskState: TaskState = _
   @volatile var testFailedReason: TaskFailedReason = _
 }
+
+private class NonDeserializableTask extends FakeTask(0, 0) with Externalizable {
+  def writeExternal(out: ObjectOutput): Unit = {}
+  def readExternal(in: ObjectInput): Unit = {
+    throw new RuntimeException(NonDeserializableTask.errorMsg)
+  }
+}
+
+private object NonDeserializableTask {
+  val errorMsg = "failure in deserialization"
+}
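
Note on the `NonDeserializableTask` trick used above: an `Externalizable` class can serialize successfully while deferring all failure to `readExternal`, which only runs when the executor deserializes the task, so the error surfaces exactly where the patch wants to test it. Below is a minimal, self-contained sketch of that behavior outside Spark; the `FailsOnRead` class and demo object are hypothetical names invented for illustration, not part of this patch.

```scala
import java.io._

// Serializes fine (writeExternal writes nothing), but always throws when read
// back: readExternal runs only at deserialization time, mirroring
// NonDeserializableTask in the test above.
class FailsOnRead extends Externalizable {
  def writeExternal(out: ObjectOutput): Unit = {}
  def readExternal(in: ObjectInput): Unit = {
    throw new RuntimeException("failure in deserialization")
  }
}

object FailsOnReadDemo {
  def main(args: Array[String]): Unit = {
    val buf = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(buf)
    oos.writeObject(new FailsOnRead)  // succeeds: writeExternal never throws
    oos.close()

    val ois = new ObjectInputStream(new ByteArrayInputStream(buf.toByteArray))
    try {
      ois.readObject()                // throws here, inside readExternal
    } catch {
      case e: RuntimeException => println(s"caught at deserialization: ${e.getMessage}")
    }
  }
}
```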
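For readers unfamiliar with the verification style in `runTaskAndGetFailReason`: `inOrder` asserts the relative order of the two `statusUpdate` calls, while a single `ArgumentCaptor` accumulates the `ByteBuffer` payload of each verified call so the test can assert that RUNNING carries no data and FAILED carries the serialized failure. A minimal sketch of the same pattern, assuming the Mockito 1.x API this branch uses; the `Backend` trait here is a hypothetical stand-in for `ExecutorBackend`, not Spark code.

```scala
import java.nio.ByteBuffer

import org.mockito.ArgumentCaptor
import org.mockito.Matchers.{eq => meq}
import org.mockito.Mockito.{inOrder, mock}

// Stand-in with just enough surface to show the inOrder + captor pattern.
trait Backend {
  def statusUpdate(taskId: Long, state: String, data: ByteBuffer): Unit
}

object CaptorDemo {
  def main(args: Array[String]): Unit = {
    val backend = mock(classOf[Backend])

    // Simulate what an executor would do: RUNNING first, then FAILED with a payload.
    backend.statusUpdate(0L, "RUNNING", ByteBuffer.allocate(0))
    backend.statusUpdate(0L, "FAILED", ByteBuffer.wrap(Array[Byte](1, 2, 3)))

    // inOrder checks the call sequence; the captor collects every payload.
    val ordered = inOrder(backend)
    val captor = ArgumentCaptor.forClass(classOf[ByteBuffer])
    ordered.verify(backend).statusUpdate(meq(0L), meq("RUNNING"), captor.capture())
    ordered.verify(backend).statusUpdate(meq(0L), meq("FAILED"), captor.capture())
    assert(captor.getAllValues.get(0).remaining() == 0)  // RUNNING carries no data
    assert(captor.getAllValues.get(1).remaining() == 3)  // FAILED carries the payload
  }
}
```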