[SPARK-28584][CORE] Fix thread safety issue in blacklist timer, tests

There's a small, probably very hard to hit thread-safety issue in the blacklist
abort timers in the task scheduler, where they access a non-thread-safe map without
locks.

In the tests, the code was also calling methods on the TaskSetManager without
holding the proper locks, which could cause threads to call non-thread-safe
TSM methods concurrently.

Closes #25317 from vanzin/SPARK-28584.

Authored-by: Marcelo Vanzin <vanzin@cloudera.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
This commit is contained in:
Marcelo Vanzin 2019-08-01 10:37:47 -07:00 committed by Dongjoon Hyun
parent 8e1602a04f
commit 607fb87906
2 changed files with 24 additions and 20 deletions

View file

@@ -552,7 +552,7 @@ private[spark] class TaskSchedulerImpl(
taskSet: TaskSetManager,
taskIndex: Int): TimerTask = {
new TimerTask() {
override def run() {
override def run(): Unit = TaskSchedulerImpl.this.synchronized {
if (unschedulableTaskSetToExpiryTime.contains(taskSet) &&
unschedulableTaskSetToExpiryTime(taskSet) <= clock.getTimeMillis()) {
logInfo("Cannot schedule any task because of complete blacklisting. " +

View file

@@ -418,7 +418,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
val taskIndex = task.index
(0 until 4).foreach { attempt =>
assert(task.attemptNumber === attempt)
tsm.handleFailedTask(task.taskId, TaskState.FAILED, TaskResultLost)
failTask(task.taskId, TaskState.FAILED, TaskResultLost, tsm)
val nextAttempts =
taskScheduler.resourceOffers(IndexedSeq(WorkerOffer("executor4", "host4", 1))).flatten
if (attempt < 3) {
@@ -550,11 +550,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
// Fail the running task
val failedTask = firstTaskAttempts.find(_.executorId == "executor0").get
taskScheduler.statusUpdate(failedTask.taskId, TaskState.FAILED, ByteBuffer.allocate(0))
// we explicitly call the handleFailedTask method here to avoid adding a sleep in the test suite
// Reason being - handleFailedTask is run by an executor service and there is a momentary delay
// before it is launched and this fails the assertion check.
tsm.handleFailedTask(failedTask.taskId, TaskState.FAILED, UnknownReason)
failTask(failedTask.taskId, TaskState.FAILED, UnknownReason, tsm)
when(tsm.taskSetBlacklistHelperOpt.get.isExecutorBlacklistedForTask(
"executor0", failedTask.index)).thenReturn(true)
@@ -586,11 +582,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
// Fail the running task
val failedTask = firstTaskAttempts.head
taskScheduler.statusUpdate(failedTask.taskId, TaskState.FAILED, ByteBuffer.allocate(0))
// we explicitly call the handleFailedTask method here to avoid adding a sleep in the test suite
// Reason being - handleFailedTask is run by an executor service and there is a momentary delay
// before it is launched and this fails the assertion check.
tsm.handleFailedTask(failedTask.taskId, TaskState.FAILED, UnknownReason)
failTask(failedTask.taskId, TaskState.FAILED, UnknownReason, tsm)
when(tsm.taskSetBlacklistHelperOpt.get.isExecutorBlacklistedForTask(
"executor0", failedTask.index)).thenReturn(true)
@@ -632,8 +624,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
// Fail the running task
val failedTask = firstTaskAttempts.head
taskScheduler.statusUpdate(failedTask.taskId, TaskState.FAILED, ByteBuffer.allocate(0))
tsm.handleFailedTask(failedTask.taskId, TaskState.FAILED, UnknownReason)
failTask(failedTask.taskId, TaskState.FAILED, UnknownReason, tsm)
when(tsm.taskSetBlacklistHelperOpt.get.isExecutorBlacklistedForTask(
"executor0", failedTask.index)).thenReturn(true)
@@ -647,8 +638,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
val tsm2 = stageToMockTaskSetManager(1)
val failedTask2 = secondTaskAttempts.head
taskScheduler.statusUpdate(failedTask2.taskId, TaskState.FAILED, ByteBuffer.allocate(0))
tsm2.handleFailedTask(failedTask2.taskId, TaskState.FAILED, UnknownReason)
failTask(failedTask2.taskId, TaskState.FAILED, UnknownReason, tsm2)
when(tsm2.taskSetBlacklistHelperOpt.get.isExecutorBlacklistedForTask(
"executor0", failedTask2.index)).thenReturn(true)
@@ -696,8 +686,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
// Fail the running task
val failedTask = taskAttempts.head
taskScheduler.statusUpdate(failedTask.taskId, TaskState.FAILED, ByteBuffer.allocate(0))
tsm.handleFailedTask(failedTask.taskId, TaskState.FAILED, UnknownReason)
failTask(failedTask.taskId, TaskState.FAILED, UnknownReason, tsm)
when(tsm.taskSetBlacklistHelperOpt.get.isExecutorBlacklistedForTask(
"executor0", failedTask.index)).thenReturn(true)
@@ -845,7 +834,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
// Fail one of the tasks, but leave the other running.
val failedTask = firstTaskAttempts.find(_.executorId == "executor0").get
taskScheduler.handleFailedTask(tsm, failedTask.taskId, TaskState.FAILED, TaskResultLost)
failTask(failedTask.taskId, TaskState.FAILED, TaskResultLost, tsm)
// At this point, our failed task could run on the other executor, so don't give up the task
// set yet.
assert(!failedTaskSet)
@@ -905,7 +894,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
// fail all the tasks on the bad executor
firstTaskAttempts.foreach { taskAttempt =>
taskScheduler.handleFailedTask(tsm, taskAttempt.taskId, TaskState.FAILED, TaskResultLost)
failTask(taskAttempt.taskId, TaskState.FAILED, TaskResultLost, tsm)
}
// Here is the main check of this test -- we have the same offers again, and we schedule it
@@ -1276,4 +1265,19 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
assert(ArrayBuffer("0") === taskDescriptions(0).resources.get(GPU).get.addresses)
assert(ArrayBuffer("1") === taskDescriptions(1).resources.get(GPU).get.addresses)
}
/**
 * Used by tests to simulate a task failure. This calls the failure handler explicitly, to ensure
 * that all the state is updated when this method returns. Otherwise, there's no way to know when
 * that happens, since the operation is performed asynchronously by the TaskResultGetter.
 *
 * @param tid ID of the task to fail.
 * @param state terminal state to report for the task (e.g. TaskState.FAILED).
 * @param reason failure reason forwarded to the scheduler's failure handler.
 * @param tsm the TaskSetManager that owns the failing task.
 */
private def failTask(
tid: Long,
state: TaskState.TaskState,
reason: TaskFailedReason,
tsm: TaskSetManager): Unit = {
// Report the status update with an empty payload, then invoke the failure
// handler directly (synchronously) so scheduler state is consistent on return.
taskScheduler.statusUpdate(tid, state, ByteBuffer.allocate(0))
taskScheduler.handleFailedTask(tsm, tid, state, reason)
}
}