diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ErrorHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ErrorHandler.java
index 968777fba7..2e15671a25 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ErrorHandler.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ErrorHandler.java
@@ -17,6 +17,7 @@
 
 package org.apache.spark.network.shuffle;
 
+import java.io.FileNotFoundException;
 import java.net.ConnectException;
 
 import com.google.common.base.Throwables;
@@ -82,8 +83,12 @@ public interface ErrorHandler {
 
     @Override
     public boolean shouldRetryError(Throwable t) {
-      // If it is a connection time out or a connection closed exception, no need to retry.
-      if (t.getCause() != null && t.getCause() instanceof ConnectException) {
+      // If it is a connection time-out or a connection closed exception, no need to retry.
+      // If it is a FileNotFoundException originating from the client while pushing the shuffle
+      // blocks to the server, there is likewise no need to retry. We still log this exception
+      // once, which helps with debugging.
+      if (t.getCause() != null && (t.getCause() instanceof ConnectException ||
+          t.getCause() instanceof FileNotFoundException)) {
         return false;
       }
       // If the block is too late, there is no need to retry it
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala
index 88d084ce1b..53687bbd27 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.shuffle
 
-import java.io.File
+import java.io.{File, FileNotFoundException}
 import java.net.ConnectException
 import java.nio.ByteBuffer
 import java.util.concurrent.ExecutorService
@@ -71,6 +71,12 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging {
     // blocks to just that host and continue push blocks to other hosts. So, here push of
     // all blocks will only stop when it is "Too Late". Also see updateStateAndCheckIfPushMore.
     override def shouldRetryError(t: Throwable): Boolean = {
+      // If it is a FileNotFoundException originating from the client while pushing the shuffle
+      // blocks to the server, then we stop pushing all the blocks because this indicates the
+      // shuffle files are deleted and subsequent block pushes will also fail.
+      if (t.getCause != null && t.getCause.isInstanceOf[FileNotFoundException]) {
+        return false
+      }
       // If the block is too late, there is no need to retry it
       !Throwables.getStackTraceAsString(t).contains(BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX)
     }
@@ -100,10 +106,22 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging {
     pushRequests ++= Utils.randomize(requests)
 
     submitTask(() => {
-      pushUpToMax()
+      tryPushUpToMax()
    })
  }

+  private[shuffle] def tryPushUpToMax(): Unit = {
+    try {
+      pushUpToMax()
+    } catch {
+      case e: FileNotFoundException =>
+        logWarning("The shuffle files got deleted when this shuffle-block-push-thread " +
+          "was reading from them, which can happen when the job finishes and the driver " +
+          "instructs the executor to clean up the shuffle. In this case, push of the blocks " +
+          "belonging to this shuffle will stop.", e)
+    }
+  }
+
   /**
    * Triggers the push. It's a separate method for testing.
    * VisibleForTesting
@@ -201,7 +219,7 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging {
         submitTask(() => {
           if (updateStateAndCheckIfPushMore(
             sizeMap(result.blockId), address, remainingBlocks, result)) {
-            pushUpToMax()
+            tryPushUpToMax()
           }
         })
       }
@@ -297,7 +315,8 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging {
       }
     }
     if (pushResult.failure != null && !errorHandler.shouldRetryError(pushResult.failure)) {
-      logDebug(s"Received after merge is finalized from $address. Not pushing any more blocks.")
+      logDebug(s"Encountered an exception from $address which indicates that the push needs " +
+        "to stop.")
       return false
     } else {
       remainingBlocks.isEmpty && (pushRequests.nonEmpty || deferredPushRequests.nonEmpty)
diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala
index de6a2c9391..6a07fefad2 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.shuffle
 
-import java.io.File
+import java.io.{File, FileNotFoundException, IOException}
 import java.net.ConnectException
 import java.nio.ByteBuffer
 import java.util.concurrent.LinkedBlockingQueue
@@ -324,8 +324,32 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
     assert(pusher.unreachableBlockMgrs.size == 2)
   }
 
+  test("SPARK-36255: FileNotFoundException stops the push") {
+    when(dependency.getMergerLocs).thenReturn(
+      Seq(BlockManagerId("client1", "client1", 1), BlockManagerId("client2", "client2", 2)))
+    conf.set("spark.reducer.maxReqsInFlight", "1")
+    val pusher = new TestShuffleBlockPusher(conf)
+    when(shuffleClient.pushBlocks(any(), any(), any(), any(), any()))
+      .thenAnswer((invocation: InvocationOnMock) => {
+        val pushedBlocks = invocation.getArguments()(2).asInstanceOf[Array[String]]
+        val blockFetchListener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+        pushedBlocks.foreach(blockId => {
+          blockFetchListener.onBlockFetchFailure(
+            blockId, new IOException("Failed to send RPC",
+              new FileNotFoundException("file not found")))
+        })
+      })
+    pusher.initiateBlockPush(
+      mock(classOf[File]), Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0)
+    pusher.runPendingTasks()
+    verify(shuffleClient, times(1))
+      .pushBlocks(any(), any(), any(), any(), any())
+    assert(pusher.tasks.isEmpty)
+    ShuffleBlockPusher.stop()
+  }
+
   private class TestShuffleBlockPusher(conf: SparkConf) extends ShuffleBlockPusher(conf) {
-    private[this] val tasks = new LinkedBlockingQueue[Runnable]
+    val tasks = new LinkedBlockingQueue[Runnable]
 
     override protected def submitTask(task: Runnable): Unit = {
       tasks.add(task)
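
For reference, a minimal standalone sketch (not part of the patch; the object name and file names below are illustrative) of the cause-based classification both handlers now perform: a FileNotFoundException cause marks the failure as fatal to the push, while other failures remain retryable.

import java.io.{FileNotFoundException, IOException}
import java.net.ConnectException

object RetryDecisionSketch {
  // Mirrors the check added to BlockPushErrorHandler.shouldRetryError above:
  // a ConnectException or FileNotFoundException cause means the failure is
  // not transient, so there is no point in retrying the block.
  def shouldRetryError(t: Throwable): Boolean = {
    val cause = t.getCause
    !(cause != null &&
      (cause.isInstanceOf[ConnectException] || cause.isInstanceOf[FileNotFoundException]))
  }

  def main(args: Array[String]): Unit = {
    // The new test wraps FileNotFoundException the same way: as the cause of
    // the IOException delivered to the BlockFetchingListener.
    val shuffleFilesGone =
      new IOException("Failed to send RPC", new FileNotFoundException("shuffle_0_0_0.data"))
    val transientFailure = new IOException("Failed to send RPC")
    assert(!shouldRetryError(shuffleFilesGone)) // stop: the shuffle files are deleted
    assert(shouldRetryError(transientFailure))  // other IO failures may still be retried
  }
}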
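
Likewise, a reduced sketch of the catch-and-stop pattern that tryPushUpToMax() introduces, assuming a stand-in pushUpToMax() that fails the way a cleaned-up shuffle would (println substitutes for Spark's logWarning):

import java.io.FileNotFoundException

object CatchAndStopSketch {
  // Stand-in for the real pushUpToMax(): reading a shuffle file that the
  // driver has already instructed the executor to delete.
  private def pushUpToMax(): Unit =
    throw new FileNotFoundException("shuffle_0_0_0.data")

  // Same shape as the new tryPushUpToMax(): log the exception once and
  // return normally, so the push thread stops quietly instead of letting
  // the exception escape into the executor's thread pool.
  def tryPushUpToMax(): Unit = {
    try {
      pushUpToMax()
    } catch {
      case e: FileNotFoundException =>
        println(s"Shuffle files were deleted; stopping block push: ${e.getMessage}")
    }
  }

  def main(args: Array[String]): Unit = tryPushUpToMax()
}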