diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 24c436a504..43408d43e5 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -254,7 +254,7 @@ public class TransportClientFactory implements Closeable { // Disable Nagle's Algorithm since we don't want packets to wait .option(ChannelOption.TCP_NODELAY, true) .option(ChannelOption.SO_KEEPALIVE, true) - .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs()) + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionCreationTimeoutMs()) .option(ChannelOption.ALLOCATOR, pooledAllocator); if (conf.receiveBuf() > 0) { @@ -280,9 +280,10 @@ public class TransportClientFactory implements Closeable { // Connect to the remote server long preConnect = System.nanoTime(); ChannelFuture cf = bootstrap.connect(address); - if (!cf.await(conf.connectionTimeoutMs())) { + if (!cf.await(conf.connectionCreationTimeoutMs())) { throw new IOException( - String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs())); + String.format("Connecting to %s timed out (%s ms)", + address, conf.connectionCreationTimeoutMs())); } else if (cf.cause() != null) { throw new IOException(String.format("Failed to connect to %s", address), cf.cause()); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index d305dfa8e8..f051042a7a 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -19,6 +19,7 @@ package org.apache.spark.network.util; import java.util.Locale; import java.util.Properties; +import java.util.concurrent.TimeUnit; import com.google.common.primitives.Ints; import io.netty.util.NettyRuntime; @@ -31,6 +32,7 @@ public class TransportConf { private final String SPARK_NETWORK_IO_MODE_KEY; private final String SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY; private final String SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY; + private final String SPARK_NETWORK_IO_CONNECTIONCREATIONTIMEOUT_KEY; private final String SPARK_NETWORK_IO_BACKLOG_KEY; private final String SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY; private final String SPARK_NETWORK_IO_SERVERTHREADS_KEY; @@ -54,6 +56,7 @@ public class TransportConf { SPARK_NETWORK_IO_MODE_KEY = getConfKey("io.mode"); SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY = getConfKey("io.preferDirectBufs"); SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY = getConfKey("io.connectionTimeout"); + SPARK_NETWORK_IO_CONNECTIONCREATIONTIMEOUT_KEY = getConfKey("io.connectionCreationTimeout"); SPARK_NETWORK_IO_BACKLOG_KEY = getConfKey("io.backLog"); SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY = getConfKey("io.numConnectionsPerPeer"); SPARK_NETWORK_IO_SERVERTHREADS_KEY = getConfKey("io.serverThreads"); @@ -94,7 +97,7 @@ public class TransportConf { return conf.getBoolean(SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY, true); } - /** Connect timeout in milliseconds. Default 120 secs. */ + /** Connection idle timeout in milliseconds. Default 120 secs. 
*/ public int connectionTimeoutMs() { long defaultNetworkTimeoutS = JavaUtils.timeStringAsSec( conf.get("spark.network.timeout", "120s")); @@ -103,6 +106,14 @@ return (int) defaultTimeoutMs; } + /** Connection creation timeout in milliseconds. Defaults to connectionTimeoutMs (120 secs). */ + public int connectionCreationTimeoutMs() { + long connectionTimeoutS = TimeUnit.MILLISECONDS.toSeconds(connectionTimeoutMs()); + long defaultTimeoutMs = JavaUtils.timeStringAsSec( + conf.get(SPARK_NETWORK_IO_CONNECTIONCREATIONTIMEOUT_KEY, connectionTimeoutS + "s")) * 1000; + return (int) defaultTimeoutMs; + } + /** Number of concurrent connections between two nodes for fetching data. */ public int numConnectionsPerPeer() { return conf.getInt(SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY, 1); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 256789b8c7..3dbee1b13d 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -31,7 +31,6 @@ import scala.Product2; import scala.Tuple2; import scala.collection.Iterator; -import com.google.common.annotations.VisibleForTesting; import com.google.common.io.Closeables; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -178,8 +177,8 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> { } } - @VisibleForTesting - long[] getPartitionLengths() { + @Override + public long[] getPartitionLengths() { return partitionLengths; } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 79e38a824f..e8f94ba8ff 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -88,6 +88,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> { @Nullable private MapStatus mapStatus; @Nullable private ShuffleExternalSorter sorter; + @Nullable private long[] partitionLengths; private long peakMemoryUsedBytes = 0; /** Subclass of ByteArrayOutputStream that exposes `buf` directly. 
*/ @@ -219,7 +220,6 @@ public class UnsafeShuffleWriter extends ShuffleWriter { serOutputStream = null; final SpillInfo[] spills = sorter.closeAndGetSpills(); sorter = null; - final long[] partitionLengths; try { partitionLengths = mergeSpills(spills); } finally { @@ -543,4 +543,9 @@ public class UnsafeShuffleWriter extends ShuffleWriter { channel.close(); } } + + @Override + public long[] getPartitionLengths() { + return partitionLengths; + } } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index c58009c166..3865c9c987 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -47,7 +47,7 @@ import org.apache.spark.metrics.source.JVMCPUSource import org.apache.spark.resource.ResourceInformation import org.apache.spark.rpc.RpcTimeout import org.apache.spark.scheduler._ -import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.shuffle.{FetchFailedException, ShuffleBlockPusher} import org.apache.spark.storage.{StorageLevel, TaskResultBlockId} import org.apache.spark.util._ import org.apache.spark.util.io.ChunkedByteBuffer @@ -325,6 +325,7 @@ private[spark] class Executor( case NonFatal(e) => logWarning("Unable to stop heartbeater", e) } + ShuffleBlockPusher.stop() threadPool.shutdown() // Notify plugins that executor is shutting down so they can terminate cleanly diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index adaf92d5a8..84c6647028 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -2030,4 +2030,33 @@ package object config { .version("3.1.0") .doubleConf .createWithDefault(5) + + private[spark] val SHUFFLE_NUM_PUSH_THREADS = + ConfigBuilder("spark.shuffle.push.numPushThreads") + .doc("Specify the number of threads in the block pusher pool. These threads assist " + + "in creating connections and pushing blocks to remote shuffle services. By default, the " + + "threadpool size is equal to the number of spark executor cores.") + .version("3.2.0") + .intConf + .createOptional + + private[spark] val SHUFFLE_MAX_BLOCK_SIZE_TO_PUSH = + ConfigBuilder("spark.shuffle.push.maxBlockSizeToPush") + .doc("The max size of an individual block to push to the remote shuffle services. Blocks " + + "larger than this threshold are not pushed to be merged remotely. These shuffle blocks " + + "will be fetched by the executors in the original manner.") + .version("3.2.0") + .bytesConf(ByteUnit.BYTE) + .createWithDefaultString("1m") + + private[spark] val SHUFFLE_MAX_BLOCK_BATCH_SIZE_FOR_PUSH = + ConfigBuilder("spark.shuffle.push.maxBlockBatchSize") + .doc("The max size of a batch of shuffle blocks to be grouped into a single push request.") + .version("3.2.0") + .bytesConf(ByteUnit.BYTE) + // Default is 3m because it is greater than 2m which is the default value for + // TransportConf#memoryMapBytes. If this defaults to 2m as well it is very likely that each + // batch of block will be loaded in memory with memory mapping, which has higher overhead + // with small MB sized chunk of data. 
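+ // With the default 1m value of spark.shuffle.push.maxBlockSizeToPush, a single batch can
+ // therefore group at most three maximally-sized blocks.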
+ .createWithDefaultString("3m") } diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala new file mode 100644 index 0000000000..88d084ce1b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala @@ -0,0 +1,450 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import java.io.File +import java.net.ConnectException +import java.nio.ByteBuffer +import java.util.concurrent.ExecutorService + +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} + +import com.google.common.base.Throwables + +import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv} +import org.apache.spark.annotation.Since +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ +import org.apache.spark.launcher.SparkLauncher +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.netty.SparkTransportConf +import org.apache.spark.network.shuffle.BlockFetchingListener +import org.apache.spark.network.shuffle.ErrorHandler.BlockPushErrorHandler +import org.apache.spark.network.util.TransportConf +import org.apache.spark.shuffle.ShuffleBlockPusher._ +import org.apache.spark.storage.{BlockId, BlockManagerId, ShufflePushBlockId} +import org.apache.spark.util.{ThreadUtils, Utils} + +/** + * Used for pushing shuffle blocks to remote shuffle services when push shuffle is enabled. + * When push shuffle is enabled, it is created after the shuffle writer finishes writing the shuffle + * file and initiates the block push process. 
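+ * The pusher is driven by ShuffleWriteProcessor, which calls initiateBlockPush() with the
+ * shuffle data file, the partition lengths and the shuffle dependency once the map task has
+ * committed its shuffle output.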
+ * + * @param conf spark configuration + */ +@Since("3.2.0") +private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging { + private[this] val maxBlockSizeToPush = conf.get(SHUFFLE_MAX_BLOCK_SIZE_TO_PUSH) + private[this] val maxBlockBatchSize = conf.get(SHUFFLE_MAX_BLOCK_BATCH_SIZE_FOR_PUSH) + private[this] val maxBytesInFlight = + conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024 + private[this] val maxReqsInFlight = conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue) + private[this] val maxBlocksInFlightPerAddress = conf.get(REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS) + private[this] var bytesInFlight = 0L + private[this] var reqsInFlight = 0 + private[this] val numBlocksInFlightPerAddress = new HashMap[BlockManagerId, Int]() + private[this] val deferredPushRequests = new HashMap[BlockManagerId, Queue[PushRequest]]() + private[this] val pushRequests = new Queue[PushRequest] + private[this] val errorHandler = createErrorHandler() + // VisibleForTesting + private[shuffle] val unreachableBlockMgrs = new HashSet[BlockManagerId]() + + // VisibleForTesting + private[shuffle] def createErrorHandler(): BlockPushErrorHandler = { + new BlockPushErrorHandler() { + // For a connection exception against a particular host, we will stop pushing any + // blocks to just that host and continue push blocks to other hosts. So, here push of + // all blocks will only stop when it is "Too Late". Also see updateStateAndCheckIfPushMore. + override def shouldRetryError(t: Throwable): Boolean = { + // If the block is too late, there is no need to retry it + !Throwables.getStackTraceAsString(t).contains(BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX) + } + } + } + + /** + * Initiates the block push. + * + * @param dataFile mapper generated shuffle data file + * @param partitionLengths array of shuffle block size so we can tell shuffle block + * @param dep shuffle dependency to get shuffle ID and the location of remote shuffle + * services to push local shuffle blocks + * @param mapIndex map index of the shuffle map task + */ + private[shuffle] def initiateBlockPush( + dataFile: File, + partitionLengths: Array[Long], + dep: ShuffleDependency[_, _, _], + mapIndex: Int): Unit = { + val numPartitions = dep.partitioner.numPartitions + val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle") + val requests = prepareBlockPushRequests(numPartitions, mapIndex, dep.shuffleId, dataFile, + partitionLengths, dep.getMergerLocs, transportConf) + // Randomize the orders of the PushRequest, so different mappers pushing blocks at the same + // time won't be pushing the same ranges of shuffle partitions. + pushRequests ++= Utils.randomize(requests) + + submitTask(() => { + pushUpToMax() + }) + } + + /** + * Triggers the push. It's a separate method for testing. + * VisibleForTesting + */ + protected def submitTask(task: Runnable): Unit = { + if (BLOCK_PUSHER_POOL != null) { + BLOCK_PUSHER_POOL.execute(task) + } + } + + /** + * Since multiple block push threads could potentially be calling pushUpToMax for the same + * mapper, we synchronize access to this method so that only one thread can push blocks for + * a given mapper. This helps to simplify access to the shared states. The down side of this + * is that we could unnecessarily block other mappers' block pushes if all the threads + * are occupied by block pushes from the same mapper. 
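+ * Deferred requests (those previously held back by the per-address limit) are drained first,
+ * then new requests are sent from the main queue, subject to maxBytesInFlight, maxReqsInFlight
+ * and maxBlocksInFlightPerAddress.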
+ * + * This code is similar to ShuffleBlockFetcherIterator#fetchUpToMaxBytes in how it throttles + * the data transfer between shuffle client/server. + */ + private def pushUpToMax(): Unit = synchronized { + // Process any outstanding deferred push requests if possible. + if (deferredPushRequests.nonEmpty) { + for ((remoteAddress, defReqQueue) <- deferredPushRequests) { + while (isRemoteBlockPushable(defReqQueue) && + !isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front)) { + val request = defReqQueue.dequeue() + logDebug(s"Processing deferred push request for $remoteAddress with " + + s"${request.blocks.length} blocks") + sendRequest(request) + if (defReqQueue.isEmpty) { + deferredPushRequests -= remoteAddress + } + } + } + } + + // Process any regular push requests if possible. + while (isRemoteBlockPushable(pushRequests)) { + val request = pushRequests.dequeue() + val remoteAddress = request.address + if (isRemoteAddressMaxedOut(remoteAddress, request)) { + logDebug(s"Deferring push request for $remoteAddress with ${request.blocks.size} blocks") + deferredPushRequests.getOrElseUpdate(remoteAddress, new Queue[PushRequest]()) + .enqueue(request) + } else { + sendRequest(request) + } + } + + def isRemoteBlockPushable(pushReqQueue: Queue[PushRequest]): Boolean = { + pushReqQueue.nonEmpty && + (bytesInFlight == 0 || + (reqsInFlight + 1 <= maxReqsInFlight && + bytesInFlight + pushReqQueue.front.size <= maxBytesInFlight)) + } + + // Checks if sending a new push request will exceed the max no. of blocks being pushed to a + // given remote address. + def isRemoteAddressMaxedOut(remoteAddress: BlockManagerId, request: PushRequest): Boolean = { + (numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + + request.blocks.size) > maxBlocksInFlightPerAddress + } + } + + /** + * Push blocks to remote shuffle server. The callback listener will invoke #pushUpToMax again + * to trigger pushing the next batch of blocks once some block transfer is done in the current + * batch. This way, we decouple the map task from the block push process, since it is the netty + * client thread instead of the task execution thread which takes care of the majority of the + * block pushes. + */ + private def sendRequest(request: PushRequest): Unit = { + bytesInFlight += request.size + reqsInFlight += 1 + numBlocksInFlightPerAddress(request.address) = numBlocksInFlightPerAddress.getOrElseUpdate( + request.address, 0) + request.blocks.length + + val sizeMap = request.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap + val address = request.address + val blockIds = request.blocks.map(_._1.toString) + val remainingBlocks = new HashSet[String]() ++= blockIds + + val blockPushListener = new BlockFetchingListener { + // Initiating a connection and pushing blocks to a remote shuffle service is always handled by + // the block-push-threads. We should not initiate the connection creation in the + // blockPushListener callbacks which are invoked by the netty eventloop because: + // 1. TransportClient.createConnection(...) blocks for connection to be established and it's + // recommended to avoid any blocking operations in the eventloop; + // 2. The actual connection creation is a task that gets added to the task queue of another + // eventloop which could have eventloops eventually blocking each other. + // Once the blockPushListener is notified of the block push success or failure, we + // just delegate it to block-push-threads. 
+ def handleResult(result: PushResult): Unit = { + submitTask(() => { + if (updateStateAndCheckIfPushMore( + sizeMap(result.blockId), address, remainingBlocks, result)) { + pushUpToMax() + } + }) + } + + override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { + logTrace(s"Push for block $blockId to $address successful.") + handleResult(PushResult(blockId, null)) + } + + override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = { + // Check the message or its cause to see whether it needs to be logged. + if (!errorHandler.shouldLogError(exception)) { + logTrace(s"Pushing block $blockId to $address failed.", exception) + } else { + logWarning(s"Pushing block $blockId to $address failed.", exception) + } + handleResult(PushResult(blockId, exception)) + } + } + SparkEnv.get.blockManager.blockStoreClient.pushBlocks( + address.host, address.port, blockIds.toArray, + sliceReqBufferIntoBlockBuffers(request.reqBuffer, request.blocks.map(_._2)), + blockPushListener) + } + + /** + * Given the ManagedBuffer representing all the continuous blocks inside the shuffle data file + * for a PushRequest and an array of individual block sizes, load the buffer from disk into + * memory and slice it into multiple smaller buffers representing each block. + * + * With nio ByteBuffer, the individual block buffers share data with the initial in memory + * buffer loaded from disk. Thus only one copy of the block data is kept in memory. + * @param reqBuffer A {{FileSegmentManagedBuffer}} representing all the continuous blocks in + * the shuffle data file for a PushRequest + * @param blockSizes Array of block sizes + * @return Array of in memory buffer for each individual block + */ + private def sliceReqBufferIntoBlockBuffers( + reqBuffer: ManagedBuffer, + blockSizes: Seq[Int]): Array[ManagedBuffer] = { + if (blockSizes.size == 1) { + Array(reqBuffer) + } else { + val inMemoryBuffer = reqBuffer.nioByteBuffer() + val blockOffsets = new Array[Int](blockSizes.size) + var offset = 0 + for (index <- blockSizes.indices) { + blockOffsets(index) = offset + offset += blockSizes(index) + } + blockOffsets.zip(blockSizes).map { + case (offset, size) => + new NioManagedBuffer(inMemoryBuffer.duplicate() + .position(offset) + .limit(offset + size).asInstanceOf[ByteBuffer].slice()) + }.toArray + } + } + + /** + * Updates the stats and, based on the previous push result, decides whether to push more blocks + * or stop. + * + * @param bytesPushed number of bytes pushed. + * @param address address of the remote service + * @param remainingBlocks remaining blocks + * @param pushResult result of the last push + * @return true if more blocks should be pushed; false otherwise. + */ + private def updateStateAndCheckIfPushMore( + bytesPushed: Long, + address: BlockManagerId, + remainingBlocks: HashSet[String], + pushResult: PushResult): Boolean = synchronized { + remainingBlocks -= pushResult.blockId + bytesInFlight -= bytesPushed + numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1 + if (remainingBlocks.isEmpty) { + reqsInFlight -= 1 + } + if (pushResult.failure != null && pushResult.failure.getCause.isInstanceOf[ConnectException]) { + // Remove all the blocks for this address just once because removing from pushRequests + // is expensive. If there is a ConnectException for the first block, all the subsequent + // blocks to that address will fail, so we should avoid removing them multiple times. 
+ if (!unreachableBlockMgrs.contains(address)) { + var removed = 0 + unreachableBlockMgrs.add(address) + removed += pushRequests.dequeueAll(req => req.address == address).length + removed += deferredPushRequests.remove(address).map(_.length).getOrElse(0) + logWarning(s"Received a ConnectException from $address. " + + s"Dropping $removed push-requests and " + + s"not pushing any more blocks to this address.") + } + } + if (pushResult.failure != null && !errorHandler.shouldRetryError(pushResult.failure)) { + logDebug(s"Received a block push failure from $address since the merge is already finalized. " + + "Not pushing any more blocks.") + return false + } else { + remainingBlocks.isEmpty && (pushRequests.nonEmpty || deferredPushRequests.nonEmpty) + } + } + + /** + * Convert the shuffle data file of the current mapper into a list of PushRequest. Basically, + * continuous blocks in the shuffle file are grouped into a single request to allow more + * efficient read of the block data. Each mapper for a given shuffle will receive the same + * list of BlockManagerIds as the target location to push the blocks to. All mappers in the + * same shuffle will map shuffle partition ranges to individual target locations in a consistent + * manner to make sure each target location receives shuffle blocks belonging to the same set + * of partition ranges. 0-length blocks and blocks larger than the max block size to push are + * skipped. + * + * @param numPartitions number of shuffle partitions in the shuffle file + * @param partitionId map index of the current mapper + * @param shuffleId shuffleId of current shuffle + * @param dataFile shuffle data file + * @param partitionLengths array of sizes of blocks in the shuffle data file + * @param mergerLocs target locations to push blocks to + * @param transportConf transportConf used to create FileSegmentManagedBuffer + * @return List of the PushRequest; the caller randomizes their order before pushing. + * + * VisibleForTesting + */ + private[shuffle] def prepareBlockPushRequests( + numPartitions: Int, + partitionId: Int, + shuffleId: Int, + dataFile: File, + partitionLengths: Array[Long], + mergerLocs: Seq[BlockManagerId], + transportConf: TransportConf): Seq[PushRequest] = { + var offset = 0L + var currentReqSize = 0 + var currentReqOffset = 0L + var currentMergerId = 0 + val numMergers = mergerLocs.length + val requests = new ArrayBuffer[PushRequest] + var blocks = new ArrayBuffer[(BlockId, Int)] + for (reduceId <- 0 until numPartitions) { + val blockSize = partitionLengths(reduceId) + logDebug( + s"Block ${ShufflePushBlockId(shuffleId, partitionId, reduceId)} is of size $blockSize") + // Skip 0-length blocks and blocks that exceed the max block size to push + if (blockSize > 0) { + val mergerId = math.min(math.floor(reduceId * 1.0 / numPartitions * numMergers), + numMergers - 1).asInstanceOf[Int] + // Start a new PushRequest if the current request goes beyond the max batch size, + // or the number of blocks in the current request goes beyond the limit per destination, + // or the next block push location is for a different shuffle service, or the next block + // exceeds the max block size to push limit. This guarantees that each PushRequest + // represents continuous blocks in the shuffle file to be pushed to the same shuffle + // service, and does not go beyond existing limitations. 
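+ // For example, with the default 3m maxBlockBatchSize, three continuous 1m blocks destined
+ // for the same merger are grouped into a single PushRequest, while a fourth block starts a
+ // new request; a block that maps to a different mergerId always starts a new request.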
+ if (currentReqSize + blockSize <= maxBlockBatchSize + && blocks.size < maxBlocksInFlightPerAddress + && mergerId == currentMergerId && blockSize <= maxBlockSizeToPush) { + // Add current block to current batch + currentReqSize += blockSize.toInt + } else { + if (blocks.nonEmpty) { + // Convert the previous batch into a PushRequest + requests += PushRequest(mergerLocs(currentMergerId), blocks.toSeq, + createRequestBuffer(transportConf, dataFile, currentReqOffset, currentReqSize)) + blocks = new ArrayBuffer[(BlockId, Int)] + } + // Start a new batch + currentReqSize = 0 + // Set currentReqOffset to -1 so we are able to distinguish between the initial value + // of currentReqOffset and when we are about to start a new batch + currentReqOffset = -1 + currentMergerId = mergerId + } + // Only push blocks under the size limit + if (blockSize <= maxBlockSizeToPush) { + val blockSizeInt = blockSize.toInt + blocks += ((ShufflePushBlockId(shuffleId, partitionId, reduceId), blockSizeInt)) + // Only update currentReqOffset if the current block is the first in the request + if (currentReqOffset == -1) { + currentReqOffset = offset + } + if (currentReqSize == 0) { + currentReqSize += blockSizeInt + } + } + } + offset += blockSize + } + // Add in the final request + if (blocks.nonEmpty) { + requests += PushRequest(mergerLocs(currentMergerId), blocks.toSeq, + createRequestBuffer(transportConf, dataFile, currentReqOffset, currentReqSize)) + } + requests.toSeq + } + + // Visible for testing + protected def createRequestBuffer( + conf: TransportConf, + dataFile: File, + offset: Long, + length: Long): ManagedBuffer = { + new FileSegmentManagedBuffer(conf, dataFile, offset, length) + } +} + +private[spark] object ShuffleBlockPusher { + + /** + * A request to push blocks to a remote shuffle service + * @param address remote shuffle service location to push blocks to + * @param blocks list of block IDs and their sizes + * @param reqBuffer a chunk of data in the shuffle data file corresponding to the continuous + * blocks represented in this request + */ + private[spark] case class PushRequest( + address: BlockManagerId, + blocks: Seq[(BlockId, Int)], + reqBuffer: ManagedBuffer) { + val size = blocks.map(_._2).sum + } + + /** + * Result of the block push. + * @param blockId blockId + * @param failure exception if the push was unsuccessful; null otherwise; + */ + private case class PushResult(blockId: String, failure: Throwable) + + private val BLOCK_PUSHER_POOL: ExecutorService = { + val conf = SparkEnv.get.conf + if (Utils.isPushBasedShuffleEnabled(conf)) { + val numThreads = conf.get(SHUFFLE_NUM_PUSH_THREADS) + .getOrElse(conf.getInt(SparkLauncher.EXECUTOR_CORES, 1)) + ThreadUtils.newDaemonFixedThreadPool(numThreads, "shuffle-block-push-thread") + } else { + null + } + } + + /** + * Stop the shuffle pusher pool if it isn't null. 
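+ * Called from Executor.stop() when the executor shuts down.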
+ */ + private[spark] def stop(): Unit = { + if (BLOCK_PUSHER_POOL != null) { + BLOCK_PUSHER_POOL.shutdown() + } + } +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala index 1429144c6f..abff650b06 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala @@ -21,6 +21,7 @@ import org.apache.spark.{Partition, ShuffleDependency, SparkEnv, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.MapStatus +import org.apache.spark.util.Utils /** * The interface for customizing shuffle write process. The driver create a ShuffleWriteProcessor @@ -57,7 +58,23 @@ private[spark] class ShuffleWriteProcessor extends Serializable with Logging { createMetricsReporter(context)) writer.write( rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]]) - writer.stop(success = true).get + val mapStatus = writer.stop(success = true) + if (mapStatus.isDefined) { + // Initiate shuffle push process if push based shuffle is enabled + // The map task only takes care of converting the shuffle data file into multiple + // block push requests. It delegates pushing the blocks to a different thread-pool - + // ShuffleBlockPusher.BLOCK_PUSHER_POOL. + if (Utils.isPushBasedShuffleEnabled(SparkEnv.get.conf) && dep.getMergerLocs.nonEmpty) { + manager.shuffleBlockResolver match { + case resolver: IndexShuffleBlockResolver => + val dataFile = resolver.getDataFile(dep.shuffleId, mapId) + new ShuffleBlockPusher(SparkEnv.get.conf) + .initiateBlockPush(dataFile, writer.getPartitionLengths(), dep, partition.index) + case _ => + } + } + } + mapStatus.get } catch { case e: Exception => try { diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala index 4cc4ef5f18..a279b4c8f4 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala @@ -31,4 +31,7 @@ private[spark] abstract class ShuffleWriter[K, V] { /** Close this writer, passing along whether the map completed */ def stop(success: Boolean): Option[MapStatus] + + /** Get the lengths of each partition */ + def getPartitionLengths(): Array[Long] } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 83ebe3e129..af8d1e2fff 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -45,6 +45,8 @@ private[spark] class SortShuffleWriter[K, V, C]( private var mapStatus: MapStatus = null + private var partitionLengths: Array[Long] = _ + private val writeMetrics = context.taskMetrics().shuffleWriteMetrics /** Write a bunch of records to this task's output */ @@ -67,7 +69,7 @@ private[spark] class SortShuffleWriter[K, V, C]( val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter( dep.shuffleId, mapId, dep.partitioner.numPartitions) sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter) - val partitionLengths = mapOutputWriter.commitAllPartitions().getPartitionLengths + partitionLengths = mapOutputWriter.commitAllPartitions().getPartitionLengths mapStatus = 
MapStatus(blockManager.shuffleServerId, partitionLengths, mapId) } @@ -93,6 +95,8 @@ private[spark] class SortShuffleWriter[K, V, C]( } } } + + override def getPartitionLengths(): Array[Long] = partitionLengths } private[spark] object SortShuffleWriter { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 7b084e73c9..73bf809a08 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -20,7 +20,7 @@ package org.apache.spark.storage import java.util.UUID import org.apache.spark.SparkException -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} /** * :: DeveloperApi :: @@ -81,6 +81,12 @@ case class ShuffleIndexBlockId(shuffleId: Int, mapId: Long, reduceId: Int) exten override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index" } +@Since("3.2.0") +@DeveloperApi +case class ShufflePushBlockId(shuffleId: Int, mapIndex: Int, reduceId: Int) extends BlockId { + override def name: String = "shufflePush_" + shuffleId + "_" + mapIndex + "_" + reduceId +} + @DeveloperApi case class BroadcastBlockId(broadcastId: Long, field: String = "") extends BlockId { override def name: String = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field) @@ -122,6 +128,7 @@ object BlockId { val SHUFFLE_BATCH = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)_([0-9]+)".r val SHUFFLE_DATA = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).data".r val SHUFFLE_INDEX = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).index".r + val SHUFFLE_PUSH = "shufflePush_([0-9]+)_([0-9]+)_([0-9]+)".r val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r val TASKRESULT = "taskresult_([0-9]+)".r val STREAM = "input-([0-9]+)-([0-9]+)".r @@ -140,6 +147,8 @@ object BlockId { ShuffleDataBlockId(shuffleId.toInt, mapId.toLong, reduceId.toInt) case SHUFFLE_INDEX(shuffleId, mapId, reduceId) => ShuffleIndexBlockId(shuffleId.toInt, mapId.toLong, reduceId.toInt) + case SHUFFLE_PUSH(shuffleId, mapIndex, reduceId) => + ShufflePushBlockId(shuffleId.toInt, mapIndex.toInt, reduceId.toInt) case BROADCAST(broadcastId, field) => BroadcastBlockId(broadcastId.toLong, field.stripPrefix("_")) case TASKRESULT(taskId) => diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala new file mode 100644 index 0000000000..cc561e6106 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala @@ -0,0 +1,355 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle + +import java.io.File +import java.net.ConnectException +import java.nio.ByteBuffer +import java.util.concurrent.LinkedBlockingQueue + +import scala.collection.mutable.ArrayBuffer + +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Answers.RETURNS_SMART_NULLS +import org.mockito.ArgumentMatchers.any +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark._ +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.shuffle.{BlockFetchingListener, BlockStoreClient} +import org.apache.spark.network.shuffle.ErrorHandler.BlockPushErrorHandler +import org.apache.spark.network.util.TransportConf +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.shuffle.ShuffleBlockPusher.PushRequest +import org.apache.spark.storage._ + +class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach { + + @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = _ + @Mock(answer = RETURNS_SMART_NULLS) private var dependency: ShuffleDependency[Int, Int, Int] = _ + @Mock(answer = RETURNS_SMART_NULLS) private var shuffleClient: BlockStoreClient = _ + + private var conf: SparkConf = _ + private var pushedBlocks = new ArrayBuffer[String] + + override def beforeEach(): Unit = { + super.beforeEach() + conf = new SparkConf(loadDefaults = false) + MockitoAnnotations.initMocks(this) + when(dependency.partitioner).thenReturn(new HashPartitioner(8)) + when(dependency.serializer).thenReturn(new JavaSerializer(conf)) + when(dependency.getMergerLocs).thenReturn(Seq(BlockManagerId("test-client", "test-client", 1))) + conf.set("spark.shuffle.push.based.enabled", "true") + conf.set("spark.shuffle.service.enabled", "true") + // Set the env because the shuffler writer gets the shuffle client instance from the env. 
+ val mockEnv = mock(classOf[SparkEnv]) + when(mockEnv.conf).thenReturn(conf) + when(mockEnv.blockManager).thenReturn(blockManager) + SparkEnv.set(mockEnv) + when(blockManager.blockStoreClient).thenReturn(shuffleClient) + } + + override def afterEach(): Unit = { + pushedBlocks.clear() + super.afterEach() + } + + private def interceptPushedBlocksForSuccess(): Unit = { + when(shuffleClient.pushBlocks(any(), any(), any(), any(), any())) + .thenAnswer((invocation: InvocationOnMock) => { + val blocks = invocation.getArguments()(2).asInstanceOf[Array[String]] + pushedBlocks ++= blocks + val managedBuffers = invocation.getArguments()(3).asInstanceOf[Array[ManagedBuffer]] + val blockFetchListener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] + (blocks, managedBuffers).zipped.foreach((blockId, buffer) => { + blockFetchListener.onBlockFetchSuccess(blockId, buffer) + }) + }) + } + + private def verifyPushRequests( + pushRequests: Seq[PushRequest], + expectedSizes: Seq[Int]): Unit = { + (pushRequests, expectedSizes).zipped.foreach((req, size) => { + assert(req.size == size) + }) + } + + test("A batch of blocks is limited by maxBlocksBatchSize") { + conf.set("spark.shuffle.push.maxBlockBatchSize", "1m") + conf.set("spark.shuffle.push.maxBlockSizeToPush", "2048k") + val blockPusher = new TestShuffleBlockPusher(conf) + val mergerLocs = dependency.getMergerLocs.map(loc => BlockManagerId("", loc.host, loc.port)) + val largeBlockSize = 2 * 1024 * 1024 + val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0, + mock(classOf[File]), Array(2, 2, 2, largeBlockSize, largeBlockSize), mergerLocs, + mock(classOf[TransportConf])) + assert(pushRequests.length == 3) + verifyPushRequests(pushRequests, Seq(6, largeBlockSize, largeBlockSize)) + } + + test("Large blocks are excluded in the preparation") { + conf.set("spark.shuffle.push.maxBlockSizeToPush", "1k") + val blockPusher = new TestShuffleBlockPusher(conf) + val mergerLocs = dependency.getMergerLocs.map(loc => BlockManagerId("", loc.host, loc.port)) + val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0, + mock(classOf[File]), Array(2, 2, 2, 1028, 1024), mergerLocs, mock(classOf[TransportConf])) + assert(pushRequests.length == 2) + verifyPushRequests(pushRequests, Seq(6, 1024)) + } + + test("Number of blocks in a push request are limited by maxBlocksInFlightPerAddress ") { + conf.set("spark.reducer.maxBlocksInFlightPerAddress", "1") + val blockPusher = new TestShuffleBlockPusher(conf) + val mergerLocs = dependency.getMergerLocs.map(loc => BlockManagerId("", loc.host, loc.port)) + val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0, + mock(classOf[File]), Array(2, 2, 2, 2, 2), mergerLocs, mock(classOf[TransportConf])) + assert(pushRequests.length == 5) + verifyPushRequests(pushRequests, Seq(2, 2, 2, 2, 2)) + } + + test("Basic block push") { + interceptPushedBlocksForSuccess() + val blockPusher = new TestShuffleBlockPusher(conf) + blockPusher.initiateBlockPush(mock(classOf[File]), + Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0) + blockPusher.runPendingTasks() + verify(shuffleClient, times(1)) + .pushBlocks(any(), any(), any(), any(), any()) + assert(pushedBlocks.length == dependency.partitioner.numPartitions) + ShuffleBlockPusher.stop() + } + + test("Large blocks are skipped for push") { + conf.set("spark.shuffle.push.maxBlockSizeToPush", "1k") + interceptPushedBlocksForSuccess() + val pusher = new TestShuffleBlockPusher(conf) + pusher.initiateBlockPush( + mock(classOf[File]), Array(2, 2, 
2, 2, 2, 2, 2, 1100), dependency, 0) + pusher.runPendingTasks() + verify(shuffleClient, times(1)) + .pushBlocks(any(), any(), any(), any(), any()) + assert(pushedBlocks.length == dependency.partitioner.numPartitions - 1) + ShuffleBlockPusher.stop() + } + + test("Number of blocks in flight per address are limited by maxBlocksInFlightPerAddress") { + conf.set("spark.reducer.maxBlocksInFlightPerAddress", "1") + interceptPushedBlocksForSuccess() + val pusher = new TestShuffleBlockPusher(conf) + pusher.initiateBlockPush( + mock(classOf[File]), Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0) + pusher.runPendingTasks() + verify(shuffleClient, times(8)) + .pushBlocks(any(), any(), any(), any(), any()) + assert(pushedBlocks.length == dependency.partitioner.numPartitions) + ShuffleBlockPusher.stop() + } + + test("Hit maxBlocksInFlightPerAddress limit so that the blocks are deferred") { + conf.set("spark.reducer.maxBlocksInFlightPerAddress", "2") + var blockPendingResponse : String = null + var listener : BlockFetchingListener = null + when(shuffleClient.pushBlocks(any(), any(), any(), any(), any())) + .thenAnswer((invocation: InvocationOnMock) => { + val blocks = invocation.getArguments()(2).asInstanceOf[Array[String]] + pushedBlocks ++= blocks + val managedBuffers = invocation.getArguments()(3).asInstanceOf[Array[ManagedBuffer]] + val blockFetchListener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] + // Expecting 2 blocks + assert(blocks.length == 2) + if (blockPendingResponse == null) { + blockPendingResponse = blocks(1) + listener = blockFetchListener + // Respond with success only for the first block which will cause all the rest of the + // blocks to be deferred + blockFetchListener.onBlockFetchSuccess(blocks(0), managedBuffers(0)) + } else { + (blocks, managedBuffers).zipped.foreach((blockId, buffer) => { + blockFetchListener.onBlockFetchSuccess(blockId, buffer) + }) + } + }) + val pusher = new TestShuffleBlockPusher(conf) + pusher.initiateBlockPush( + mock(classOf[File]), Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0) + pusher.runPendingTasks() + verify(shuffleClient, times(1)) + .pushBlocks(any(), any(), any(), any(), any()) + assert(pushedBlocks.length == 2) + // this will trigger push of deferred blocks + listener.onBlockFetchSuccess(blockPendingResponse, mock(classOf[ManagedBuffer])) + pusher.runPendingTasks() + verify(shuffleClient, times(4)) + .pushBlocks(any(), any(), any(), any(), any()) + assert(pushedBlocks.length == 8) + ShuffleBlockPusher.stop() + } + + test("Number of shuffle blocks grouped in a single push request is limited by " + + "maxBlockBatchSize") { + conf.set("spark.shuffle.push.maxBlockBatchSize", "1m") + interceptPushedBlocksForSuccess() + val pusher = new TestShuffleBlockPusher(conf) + pusher.initiateBlockPush(mock(classOf[File]), + Array.fill(dependency.partitioner.numPartitions) { 512 * 1024 }, dependency, 0) + pusher.runPendingTasks() + verify(shuffleClient, times(4)) + .pushBlocks(any(), any(), any(), any(), any()) + assert(pushedBlocks.length == dependency.partitioner.numPartitions) + ShuffleBlockPusher.stop() + } + + test("Error retries") { + val pusher = new ShuffleBlockPusher(conf) + val errorHandler = pusher.createErrorHandler() + assert( + !errorHandler.shouldRetryError(new RuntimeException( + new IllegalArgumentException(BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX)))) + assert(errorHandler.shouldRetryError(new RuntimeException(new ConnectException()))) + assert( + 
errorHandler.shouldRetryError(new RuntimeException(new IllegalArgumentException( + BlockPushErrorHandler.BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX)))) + assert (errorHandler.shouldRetryError(new Throwable())) + } + + test("Error logging") { + val pusher = new ShuffleBlockPusher(conf) + val errorHandler = pusher.createErrorHandler() + assert( + !errorHandler.shouldLogError(new RuntimeException( + new IllegalArgumentException(BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX)))) + assert(!errorHandler.shouldLogError(new RuntimeException( + new IllegalArgumentException( + BlockPushErrorHandler.BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX)))) + assert(errorHandler.shouldLogError(new Throwable())) + } + + test("Blocks are continued to push even when a block push fails with collision " + + "exception") { + conf.set("spark.reducer.maxBlocksInFlightPerAddress", "1") + val pusher = new TestShuffleBlockPusher(conf) + var failBlock: Boolean = true + when(shuffleClient.pushBlocks(any(), any(), any(), any(), any())) + .thenAnswer((invocation: InvocationOnMock) => { + val blocks = invocation.getArguments()(2).asInstanceOf[Array[String]] + val blockFetchListener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] + blocks.foreach(blockId => { + if (failBlock) { + failBlock = false + // Fail the first block with the collision exception. + blockFetchListener.onBlockFetchFailure(blockId, new RuntimeException( + new IllegalArgumentException( + BlockPushErrorHandler.BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX))) + } else { + pushedBlocks += blockId + blockFetchListener.onBlockFetchSuccess(blockId, mock(classOf[ManagedBuffer])) + } + }) + }) + pusher.initiateBlockPush( + mock(classOf[File]), Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0) + pusher.runPendingTasks() + verify(shuffleClient, times(8)) + .pushBlocks(any(), any(), any(), any(), any()) + assert(pushedBlocks.length == 7) + } + + test("More blocks are not pushed when a block push fails with too late " + + "exception") { + conf.set("spark.reducer.maxBlocksInFlightPerAddress", "1") + val pusher = new TestShuffleBlockPusher(conf) + var failBlock: Boolean = true + when(shuffleClient.pushBlocks(any(), any(), any(), any(), any())) + .thenAnswer((invocation: InvocationOnMock) => { + val blocks = invocation.getArguments()(2).asInstanceOf[Array[String]] + val blockFetchListener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] + blocks.foreach(blockId => { + if (failBlock) { + failBlock = false + // Fail the first block with the too late exception. 
+ blockFetchListener.onBlockFetchFailure(blockId, new RuntimeException( + new IllegalArgumentException(BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX))) + } else { + pushedBlocks += blockId + blockFetchListener.onBlockFetchSuccess(blockId, mock(classOf[ManagedBuffer])) + } + }) + }) + pusher.initiateBlockPush( + mock(classOf[File]), Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0) + pusher.runPendingTasks() + verify(shuffleClient, times(1)) + .pushBlocks(any(), any(), any(), any(), any()) + assert(pushedBlocks.isEmpty) + } + + test("Connect exceptions remove all the push requests for that host") { + when(dependency.getMergerLocs).thenReturn( + Seq(BlockManagerId("client1", "client1", 1), BlockManagerId("client2", "client2", 2))) + conf.set("spark.reducer.maxBlocksInFlightPerAddress", "2") + when(shuffleClient.pushBlocks(any(), any(), any(), any(), any())) + .thenAnswer((invocation: InvocationOnMock) => { + val blocks = invocation.getArguments()(2).asInstanceOf[Array[String]] + pushedBlocks ++= blocks + val blockFetchListener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] + blocks.foreach(blockId => { + blockFetchListener.onBlockFetchFailure( + blockId, new RuntimeException(new ConnectException())) + }) + }) + val pusher = new TestShuffleBlockPusher(conf) + pusher.initiateBlockPush( + mock(classOf[File]), Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0) + pusher.runPendingTasks() + verify(shuffleClient, times(2)) + .pushBlocks(any(), any(), any(), any(), any()) + // 2 blocks for each merger locations + assert(pushedBlocks.length == 4) + assert(pusher.unreachableBlockMgrs.size == 2) + } + + private class TestShuffleBlockPusher(conf: SparkConf) extends ShuffleBlockPusher(conf) { + private[this] val tasks = new LinkedBlockingQueue[Runnable] + + override protected def submitTask(task: Runnable): Unit = { + tasks.add(task) + } + + def runPendingTasks(): Unit = { + // This ensures that all the submitted tasks - updateStateAndCheckIfPushMore and pushUpToMax + // are run synchronously. + while (!tasks.isEmpty) { + tasks.take().run() + } + } + + override protected def createRequestBuffer( + conf: TransportConf, + dataFile: File, + offset: Long, + length: Long): ManagedBuffer = { + val managedBuffer = mock(classOf[ManagedBuffer]) + val byteBuffer = new Array[Byte](length.toInt) + when(managedBuffer.nioByteBuffer()).thenReturn(ByteBuffer.wrap(byteBuffer)) + managedBuffer + } + } +}