[SPARK-32917][SHUFFLE][CORE] Adds support for executors to push shuffle blocks after successful map task completion

### What changes were proposed in this pull request?
This is the shuffle writer side change where executors can push data to remote shuffle services. This is needed for push-based shuffle - SPIP [SPARK-30602](https://issues.apache.org/jira/browse/SPARK-30602).
Summary of changes:
- This adds support for executors to push shuffle blocks after map tasks complete writing shuffle data.
- This also introduces a timeout specifically for creating connection to remote shuffle services.

### Why are the changes needed?
- These changes are needed for push-based shuffle. Refer to the SPIP in [SPARK-30602](https://issues.apache.org/jira/browse/SPARK-30602).
- The main reason to create a separate connection creation timeout is because the existing `connectionTimeoutMs` is overloaded and is used for connection creation timeouts as well as connection idle timeout. The connection creation timeout should be much lower than the idle timeouts. The default for `connectionTimeoutMs` is 120s. This is quite high for just establishing the connections.  If a shuffle server node is bad then the connection creation will fail within few seconds. However, an overloaded shuffle server may take much longer to respond to a request and the channel can stay idle for a much longer time which is expected.  Another reason is that with push-based shuffle, an executor may be fetching shuffle data and pushing shuffle data (next stage) simultaneously. Both these tasks will share the same connections with the shuffle service. If there is a bad shuffle server node and the connection creation timeout is very high then both these tasks end up waiting a long time time eventually impacting the performance.

### Does this PR introduce _any_ user-facing change?
Yes. This PR introduces client-side configs for push-based shuffle. If push-based shuffle is turned-off then the users will not see any change.

### How was this patch tested?
Added unit tests.
The reference PR with the consolidated changes covering the complete implementation is also provided in [SPARK-30602](https://issues.apache.org/jira/browse/SPARK-30602).
We have already verified the functionality and the improved performance as documented in the SPIP doc.

Lead-authored-by: Min Shen mshenlinkedin.com
Co-authored-by: Chandni Singh chsinghlinkedin.com
Co-authored-by: Ye Zhou yezhoulinkedin.com

Closes #30312 from otterc/SPARK-32917.

Lead-authored-by: Chandni Singh <singh.chandni@gmail.com>
Co-authored-by: Chandni Singh <chsingh@linkedin.com>
Co-authored-by: Min Shen <mshen@linked.in.com>
Co-authored-by: Ye Zhou <yezhou@linkedin.com>
Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
This commit is contained in:
Chandni Singh 2021-01-08 12:21:56 -06:00 committed by Mridul Muralidharan
parent 0781ed4f5b
commit d00f0695b7
12 changed files with 896 additions and 12 deletions

View file

@ -254,7 +254,7 @@ public class TransportClientFactory implements Closeable {
// Disable Nagle's Algorithm since we don't want packets to wait
.option(ChannelOption.TCP_NODELAY, true)
.option(ChannelOption.SO_KEEPALIVE, true)
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs())
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionCreationTimeoutMs())
.option(ChannelOption.ALLOCATOR, pooledAllocator);
if (conf.receiveBuf() > 0) {
@ -280,9 +280,10 @@ public class TransportClientFactory implements Closeable {
// Connect to the remote server
long preConnect = System.nanoTime();
ChannelFuture cf = bootstrap.connect(address);
if (!cf.await(conf.connectionTimeoutMs())) {
if (!cf.await(conf.connectionCreationTimeoutMs())) {
throw new IOException(
String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs()));
String.format("Connecting to %s timed out (%s ms)",
address, conf.connectionCreationTimeoutMs()));
} else if (cf.cause() != null) {
throw new IOException(String.format("Failed to connect to %s", address), cf.cause());
}

View file

@ -19,6 +19,7 @@ package org.apache.spark.network.util;
import java.util.Locale;
import java.util.Properties;
import java.util.concurrent.TimeUnit;
import com.google.common.primitives.Ints;
import io.netty.util.NettyRuntime;
@ -31,6 +32,7 @@ public class TransportConf {
private final String SPARK_NETWORK_IO_MODE_KEY;
private final String SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY;
private final String SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY;
private final String SPARK_NETWORK_IO_CONNECTIONCREATIONTIMEOUT_KEY;
private final String SPARK_NETWORK_IO_BACKLOG_KEY;
private final String SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY;
private final String SPARK_NETWORK_IO_SERVERTHREADS_KEY;
@ -54,6 +56,7 @@ public class TransportConf {
SPARK_NETWORK_IO_MODE_KEY = getConfKey("io.mode");
SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY = getConfKey("io.preferDirectBufs");
SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY = getConfKey("io.connectionTimeout");
SPARK_NETWORK_IO_CONNECTIONCREATIONTIMEOUT_KEY = getConfKey("io.connectionCreationTimeout");
SPARK_NETWORK_IO_BACKLOG_KEY = getConfKey("io.backLog");
SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY = getConfKey("io.numConnectionsPerPeer");
SPARK_NETWORK_IO_SERVERTHREADS_KEY = getConfKey("io.serverThreads");
@ -94,7 +97,7 @@ public class TransportConf {
return conf.getBoolean(SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY, true);
}
/** Connect timeout in milliseconds. Default 120 secs. */
/** Connection idle timeout in milliseconds. Default 120 secs. */
public int connectionTimeoutMs() {
long defaultNetworkTimeoutS = JavaUtils.timeStringAsSec(
conf.get("spark.network.timeout", "120s"));
@ -103,6 +106,14 @@ public class TransportConf {
return (int) defaultTimeoutMs;
}
/** Connect creation timeout in milliseconds. Default 30 secs. */
public int connectionCreationTimeoutMs() {
long connectionTimeoutS = TimeUnit.MILLISECONDS.toSeconds(connectionTimeoutMs());
long defaultTimeoutMs = JavaUtils.timeStringAsSec(
conf.get(SPARK_NETWORK_IO_CONNECTIONCREATIONTIMEOUT_KEY, connectionTimeoutS + "s")) * 1000;
return (int) defaultTimeoutMs;
}
/** Number of concurrent connections between two nodes for fetching data. */
public int numConnectionsPerPeer() {
return conf.getInt(SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY, 1);

View file

@ -31,7 +31,6 @@ import scala.Product2;
import scala.Tuple2;
import scala.collection.Iterator;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.io.Closeables;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -178,8 +177,8 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
}
}
@VisibleForTesting
long[] getPartitionLengths() {
@Override
public long[] getPartitionLengths() {
return partitionLengths;
}

View file

@ -88,6 +88,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
@Nullable private MapStatus mapStatus;
@Nullable private ShuffleExternalSorter sorter;
@Nullable private long[] partitionLengths;
private long peakMemoryUsedBytes = 0;
/** Subclass of ByteArrayOutputStream that exposes `buf` directly. */
@ -219,7 +220,6 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
serOutputStream = null;
final SpillInfo[] spills = sorter.closeAndGetSpills();
sorter = null;
final long[] partitionLengths;
try {
partitionLengths = mergeSpills(spills);
} finally {
@ -543,4 +543,9 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
channel.close();
}
}
@Override
public long[] getPartitionLengths() {
return partitionLengths;
}
}

View file

@ -47,7 +47,7 @@ import org.apache.spark.metrics.source.JVMCPUSource
import org.apache.spark.resource.ResourceInformation
import org.apache.spark.rpc.RpcTimeout
import org.apache.spark.scheduler._
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.shuffle.{FetchFailedException, ShuffleBlockPusher}
import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
import org.apache.spark.util._
import org.apache.spark.util.io.ChunkedByteBuffer
@ -325,6 +325,7 @@ private[spark] class Executor(
case NonFatal(e) =>
logWarning("Unable to stop heartbeater", e)
}
ShuffleBlockPusher.stop()
threadPool.shutdown()
// Notify plugins that executor is shutting down so they can terminate cleanly

View file

@ -2030,4 +2030,33 @@ package object config {
.version("3.1.0")
.doubleConf
.createWithDefault(5)
private[spark] val SHUFFLE_NUM_PUSH_THREADS =
ConfigBuilder("spark.shuffle.push.numPushThreads")
.doc("Specify the number of threads in the block pusher pool. These threads assist " +
"in creating connections and pushing blocks to remote shuffle services. By default, the " +
"threadpool size is equal to the number of spark executor cores.")
.version("3.2.0")
.intConf
.createOptional
private[spark] val SHUFFLE_MAX_BLOCK_SIZE_TO_PUSH =
ConfigBuilder("spark.shuffle.push.maxBlockSizeToPush")
.doc("The max size of an individual block to push to the remote shuffle services. Blocks " +
"larger than this threshold are not pushed to be merged remotely. These shuffle blocks " +
"will be fetched by the executors in the original manner.")
.version("3.2.0")
.bytesConf(ByteUnit.BYTE)
.createWithDefaultString("1m")
private[spark] val SHUFFLE_MAX_BLOCK_BATCH_SIZE_FOR_PUSH =
ConfigBuilder("spark.shuffle.push.maxBlockBatchSize")
.doc("The max size of a batch of shuffle blocks to be grouped into a single push request.")
.version("3.2.0")
.bytesConf(ByteUnit.BYTE)
// Default is 3m because it is greater than 2m which is the default value for
// TransportConf#memoryMapBytes. If this defaults to 2m as well it is very likely that each
// batch of block will be loaded in memory with memory mapping, which has higher overhead
// with small MB sized chunk of data.
.createWithDefaultString("3m")
}

View file

@ -0,0 +1,450 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.shuffle
import java.io.File
import java.net.ConnectException
import java.nio.ByteBuffer
import java.util.concurrent.ExecutorService
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue}
import com.google.common.base.Throwables
import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv}
import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.launcher.SparkLauncher
import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer, NioManagedBuffer}
import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.shuffle.BlockFetchingListener
import org.apache.spark.network.shuffle.ErrorHandler.BlockPushErrorHandler
import org.apache.spark.network.util.TransportConf
import org.apache.spark.shuffle.ShuffleBlockPusher._
import org.apache.spark.storage.{BlockId, BlockManagerId, ShufflePushBlockId}
import org.apache.spark.util.{ThreadUtils, Utils}
/**
* Used for pushing shuffle blocks to remote shuffle services when push shuffle is enabled.
* When push shuffle is enabled, it is created after the shuffle writer finishes writing the shuffle
* file and initiates the block push process.
*
* @param conf spark configuration
*/
@Since("3.2.0")
private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging {
private[this] val maxBlockSizeToPush = conf.get(SHUFFLE_MAX_BLOCK_SIZE_TO_PUSH)
private[this] val maxBlockBatchSize = conf.get(SHUFFLE_MAX_BLOCK_BATCH_SIZE_FOR_PUSH)
private[this] val maxBytesInFlight =
conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024
private[this] val maxReqsInFlight = conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue)
private[this] val maxBlocksInFlightPerAddress = conf.get(REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS)
private[this] var bytesInFlight = 0L
private[this] var reqsInFlight = 0
private[this] val numBlocksInFlightPerAddress = new HashMap[BlockManagerId, Int]()
private[this] val deferredPushRequests = new HashMap[BlockManagerId, Queue[PushRequest]]()
private[this] val pushRequests = new Queue[PushRequest]
private[this] val errorHandler = createErrorHandler()
// VisibleForTesting
private[shuffle] val unreachableBlockMgrs = new HashSet[BlockManagerId]()
// VisibleForTesting
private[shuffle] def createErrorHandler(): BlockPushErrorHandler = {
new BlockPushErrorHandler() {
// For a connection exception against a particular host, we will stop pushing any
// blocks to just that host and continue push blocks to other hosts. So, here push of
// all blocks will only stop when it is "Too Late". Also see updateStateAndCheckIfPushMore.
override def shouldRetryError(t: Throwable): Boolean = {
// If the block is too late, there is no need to retry it
!Throwables.getStackTraceAsString(t).contains(BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX)
}
}
}
/**
* Initiates the block push.
*
* @param dataFile mapper generated shuffle data file
* @param partitionLengths array of shuffle block size so we can tell shuffle block
* @param dep shuffle dependency to get shuffle ID and the location of remote shuffle
* services to push local shuffle blocks
* @param mapIndex map index of the shuffle map task
*/
private[shuffle] def initiateBlockPush(
dataFile: File,
partitionLengths: Array[Long],
dep: ShuffleDependency[_, _, _],
mapIndex: Int): Unit = {
val numPartitions = dep.partitioner.numPartitions
val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle")
val requests = prepareBlockPushRequests(numPartitions, mapIndex, dep.shuffleId, dataFile,
partitionLengths, dep.getMergerLocs, transportConf)
// Randomize the orders of the PushRequest, so different mappers pushing blocks at the same
// time won't be pushing the same ranges of shuffle partitions.
pushRequests ++= Utils.randomize(requests)
submitTask(() => {
pushUpToMax()
})
}
/**
* Triggers the push. It's a separate method for testing.
* VisibleForTesting
*/
protected def submitTask(task: Runnable): Unit = {
if (BLOCK_PUSHER_POOL != null) {
BLOCK_PUSHER_POOL.execute(task)
}
}
/**
* Since multiple block push threads could potentially be calling pushUpToMax for the same
* mapper, we synchronize access to this method so that only one thread can push blocks for
* a given mapper. This helps to simplify access to the shared states. The down side of this
* is that we could unnecessarily block other mappers' block pushes if all the threads
* are occupied by block pushes from the same mapper.
*
* This code is similar to ShuffleBlockFetcherIterator#fetchUpToMaxBytes in how it throttles
* the data transfer between shuffle client/server.
*/
private def pushUpToMax(): Unit = synchronized {
// Process any outstanding deferred push requests if possible.
if (deferredPushRequests.nonEmpty) {
for ((remoteAddress, defReqQueue) <- deferredPushRequests) {
while (isRemoteBlockPushable(defReqQueue) &&
!isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front)) {
val request = defReqQueue.dequeue()
logDebug(s"Processing deferred push request for $remoteAddress with "
+ s"${request.blocks.length} blocks")
sendRequest(request)
if (defReqQueue.isEmpty) {
deferredPushRequests -= remoteAddress
}
}
}
}
// Process any regular push requests if possible.
while (isRemoteBlockPushable(pushRequests)) {
val request = pushRequests.dequeue()
val remoteAddress = request.address
if (isRemoteAddressMaxedOut(remoteAddress, request)) {
logDebug(s"Deferring push request for $remoteAddress with ${request.blocks.size} blocks")
deferredPushRequests.getOrElseUpdate(remoteAddress, new Queue[PushRequest]())
.enqueue(request)
} else {
sendRequest(request)
}
}
def isRemoteBlockPushable(pushReqQueue: Queue[PushRequest]): Boolean = {
pushReqQueue.nonEmpty &&
(bytesInFlight == 0 ||
(reqsInFlight + 1 <= maxReqsInFlight &&
bytesInFlight + pushReqQueue.front.size <= maxBytesInFlight))
}
// Checks if sending a new push request will exceed the max no. of blocks being pushed to a
// given remote address.
def isRemoteAddressMaxedOut(remoteAddress: BlockManagerId, request: PushRequest): Boolean = {
(numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0)
+ request.blocks.size) > maxBlocksInFlightPerAddress
}
}
/**
* Push blocks to remote shuffle server. The callback listener will invoke #pushUpToMax again
* to trigger pushing the next batch of blocks once some block transfer is done in the current
* batch. This way, we decouple the map task from the block push process, since it is netty
* client thread instead of task execution thread which takes care of majority of the block
* pushes.
*/
private def sendRequest(request: PushRequest): Unit = {
bytesInFlight += request.size
reqsInFlight += 1
numBlocksInFlightPerAddress(request.address) = numBlocksInFlightPerAddress.getOrElseUpdate(
request.address, 0) + request.blocks.length
val sizeMap = request.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
val address = request.address
val blockIds = request.blocks.map(_._1.toString)
val remainingBlocks = new HashSet[String]() ++= blockIds
val blockPushListener = new BlockFetchingListener {
// Initiating a connection and pushing blocks to a remote shuffle service is always handled by
// the block-push-threads. We should not initiate the connection creation in the
// blockPushListener callbacks which are invoked by the netty eventloop because:
// 1. TrasportClient.createConnection(...) blocks for connection to be established and it's
// recommended to avoid any blocking operations in the eventloop;
// 2. The actual connection creation is a task that gets added to the task queue of another
// eventloop which could have eventloops eventually blocking each other.
// Once the blockPushListener is notified of the block push success or failure, we
// just delegate it to block-push-threads.
def handleResult(result: PushResult): Unit = {
submitTask(() => {
if (updateStateAndCheckIfPushMore(
sizeMap(result.blockId), address, remainingBlocks, result)) {
pushUpToMax()
}
})
}
override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
logTrace(s"Push for block $blockId to $address successful.")
handleResult(PushResult(blockId, null))
}
override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = {
// check the message or it's cause to see it needs to be logged.
if (!errorHandler.shouldLogError(exception)) {
logTrace(s"Pushing block $blockId to $address failed.", exception)
} else {
logWarning(s"Pushing block $blockId to $address failed.", exception)
}
handleResult(PushResult(blockId, exception))
}
}
SparkEnv.get.blockManager.blockStoreClient.pushBlocks(
address.host, address.port, blockIds.toArray,
sliceReqBufferIntoBlockBuffers(request.reqBuffer, request.blocks.map(_._2)),
blockPushListener)
}
/**
* Given the ManagedBuffer representing all the continuous blocks inside the shuffle data file
* for a PushRequest and an array of individual block sizes, load the buffer from disk into
* memory and slice it into multiple smaller buffers representing each block.
*
* With nio ByteBuffer, the individual block buffers share data with the initial in memory
* buffer loaded from disk. Thus only one copy of the block data is kept in memory.
* @param reqBuffer A {{FileSegmentManagedBuffer}} representing all the continuous blocks in
* the shuffle data file for a PushRequest
* @param blockSizes Array of block sizes
* @return Array of in memory buffer for each individual block
*/
private def sliceReqBufferIntoBlockBuffers(
reqBuffer: ManagedBuffer,
blockSizes: Seq[Int]): Array[ManagedBuffer] = {
if (blockSizes.size == 1) {
Array(reqBuffer)
} else {
val inMemoryBuffer = reqBuffer.nioByteBuffer()
val blockOffsets = new Array[Int](blockSizes.size)
var offset = 0
for (index <- blockSizes.indices) {
blockOffsets(index) = offset
offset += blockSizes(index)
}
blockOffsets.zip(blockSizes).map {
case (offset, size) =>
new NioManagedBuffer(inMemoryBuffer.duplicate()
.position(offset)
.limit(offset + size).asInstanceOf[ByteBuffer].slice())
}.toArray
}
}
/**
* Updates the stats and based on the previous push result decides whether to push more blocks
* or stop.
*
* @param bytesPushed number of bytes pushed.
* @param address address of the remote service
* @param remainingBlocks remaining blocks
* @param pushResult result of the last push
* @return true if more blocks should be pushed; false otherwise.
*/
private def updateStateAndCheckIfPushMore(
bytesPushed: Long,
address: BlockManagerId,
remainingBlocks: HashSet[String],
pushResult: PushResult): Boolean = synchronized {
remainingBlocks -= pushResult.blockId
bytesInFlight -= bytesPushed
numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
if (remainingBlocks.isEmpty) {
reqsInFlight -= 1
}
if (pushResult.failure != null && pushResult.failure.getCause.isInstanceOf[ConnectException]) {
// Remove all the blocks for this address just once because removing from pushRequests
// is expensive. If there is a ConnectException for the first block, all the subsequent
// blocks to that address will fail, so should avoid removing multiple times.
if (!unreachableBlockMgrs.contains(address)) {
var removed = 0
unreachableBlockMgrs.add(address)
removed += pushRequests.dequeueAll(req => req.address == address).length
removed += deferredPushRequests.remove(address).map(_.length).getOrElse(0)
logWarning(s"Received a ConnectException from $address. " +
s"Dropping $removed push-requests and " +
s"not pushing any more blocks to this address.")
}
}
if (pushResult.failure != null && !errorHandler.shouldRetryError(pushResult.failure)) {
logDebug(s"Received after merge is finalized from $address. Not pushing any more blocks.")
return false
} else {
remainingBlocks.isEmpty && (pushRequests.nonEmpty || deferredPushRequests.nonEmpty)
}
}
/**
* Convert the shuffle data file of the current mapper into a list of PushRequest. Basically,
* continuous blocks in the shuffle file are grouped into a single request to allow more
* efficient read of the block data. Each mapper for a given shuffle will receive the same
* list of BlockManagerIds as the target location to push the blocks to. All mappers in the
* same shuffle will map shuffle partition ranges to individual target locations in a consistent
* manner to make sure each target location receives shuffle blocks belonging to the same set
* of partition ranges. 0-length blocks and blocks that are large enough will be skipped.
*
* @param numPartitions sumber of shuffle partitions in the shuffle file
* @param partitionId map index of the current mapper
* @param shuffleId shuffleId of current shuffle
* @param dataFile shuffle data file
* @param partitionLengths array of sizes of blocks in the shuffle data file
* @param mergerLocs target locations to push blocks to
* @param transportConf transportConf used to create FileSegmentManagedBuffer
* @return List of the PushRequest, randomly shuffled.
*
* VisibleForTesting
*/
private[shuffle] def prepareBlockPushRequests(
numPartitions: Int,
partitionId: Int,
shuffleId: Int,
dataFile: File,
partitionLengths: Array[Long],
mergerLocs: Seq[BlockManagerId],
transportConf: TransportConf): Seq[PushRequest] = {
var offset = 0L
var currentReqSize = 0
var currentReqOffset = 0L
var currentMergerId = 0
val numMergers = mergerLocs.length
val requests = new ArrayBuffer[PushRequest]
var blocks = new ArrayBuffer[(BlockId, Int)]
for (reduceId <- 0 until numPartitions) {
val blockSize = partitionLengths(reduceId)
logDebug(
s"Block ${ShufflePushBlockId(shuffleId, partitionId, reduceId)} is of size $blockSize")
// Skip 0-length blocks and blocks that are large enough
if (blockSize > 0) {
val mergerId = math.min(math.floor(reduceId * 1.0 / numPartitions * numMergers),
numMergers - 1).asInstanceOf[Int]
// Start a new PushRequest if the current request goes beyond the max batch size,
// or the number of blocks in the current request goes beyond the limit per destination,
// or the next block push location is for a different shuffle service, or the next block
// exceeds the max block size to push limit. This guarantees that each PushRequest
// represents continuous blocks in the shuffle file to be pushed to the same shuffle
// service, and does not go beyond existing limitations.
if (currentReqSize + blockSize <= maxBlockBatchSize
&& blocks.size < maxBlocksInFlightPerAddress
&& mergerId == currentMergerId && blockSize <= maxBlockSizeToPush) {
// Add current block to current batch
currentReqSize += blockSize.toInt
} else {
if (blocks.nonEmpty) {
// Convert the previous batch into a PushRequest
requests += PushRequest(mergerLocs(currentMergerId), blocks.toSeq,
createRequestBuffer(transportConf, dataFile, currentReqOffset, currentReqSize))
blocks = new ArrayBuffer[(BlockId, Int)]
}
// Start a new batch
currentReqSize = 0
// Set currentReqOffset to -1 so we are able to distinguish between the initial value
// of currentReqOffset and when we are about to start a new batch
currentReqOffset = -1
currentMergerId = mergerId
}
// Only push blocks under the size limit
if (blockSize <= maxBlockSizeToPush) {
val blockSizeInt = blockSize.toInt
blocks += ((ShufflePushBlockId(shuffleId, partitionId, reduceId), blockSizeInt))
// Only update currentReqOffset if the current block is the first in the request
if (currentReqOffset == -1) {
currentReqOffset = offset
}
if (currentReqSize == 0) {
currentReqSize += blockSizeInt
}
}
}
offset += blockSize
}
// Add in the final request
if (blocks.nonEmpty) {
requests += PushRequest(mergerLocs(currentMergerId), blocks.toSeq,
createRequestBuffer(transportConf, dataFile, currentReqOffset, currentReqSize))
}
requests.toSeq
}
// Visible for testing
protected def createRequestBuffer(
conf: TransportConf,
dataFile: File,
offset: Long,
length: Long): ManagedBuffer = {
new FileSegmentManagedBuffer(conf, dataFile, offset, length)
}
}
private[spark] object ShuffleBlockPusher {
/**
* A request to push blocks to a remote shuffle service
* @param address remote shuffle service location to push blocks to
* @param blocks list of block IDs and their sizes
* @param reqBuffer a chunk of data in the shuffle data file corresponding to the continuous
* blocks represented in this request
*/
private[spark] case class PushRequest(
address: BlockManagerId,
blocks: Seq[(BlockId, Int)],
reqBuffer: ManagedBuffer) {
val size = blocks.map(_._2).sum
}
/**
* Result of the block push.
* @param blockId blockId
* @param failure exception if the push was unsuccessful; null otherwise;
*/
private case class PushResult(blockId: String, failure: Throwable)
private val BLOCK_PUSHER_POOL: ExecutorService = {
val conf = SparkEnv.get.conf
if (Utils.isPushBasedShuffleEnabled(conf)) {
val numThreads = conf.get(SHUFFLE_NUM_PUSH_THREADS)
.getOrElse(conf.getInt(SparkLauncher.EXECUTOR_CORES, 1))
ThreadUtils.newDaemonFixedThreadPool(numThreads, "shuffle-block-push-thread")
} else {
null
}
}
/**
* Stop the shuffle pusher pool if it isn't null.
*/
private[spark] def stop(): Unit = {
if (BLOCK_PUSHER_POOL != null) {
BLOCK_PUSHER_POOL.shutdown()
}
}
}

View file

@ -21,6 +21,7 @@ import org.apache.spark.{Partition, ShuffleDependency, SparkEnv, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.util.Utils
/**
* The interface for customizing shuffle write process. The driver create a ShuffleWriteProcessor
@ -57,7 +58,23 @@ private[spark] class ShuffleWriteProcessor extends Serializable with Logging {
createMetricsReporter(context))
writer.write(
rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
writer.stop(success = true).get
val mapStatus = writer.stop(success = true)
if (mapStatus.isDefined) {
// Initiate shuffle push process if push based shuffle is enabled
// The map task only takes care of converting the shuffle data file into multiple
// block push requests. It delegates pushing the blocks to a different thread-pool -
// ShuffleBlockPusher.BLOCK_PUSHER_POOL.
if (Utils.isPushBasedShuffleEnabled(SparkEnv.get.conf) && dep.getMergerLocs.nonEmpty) {
manager.shuffleBlockResolver match {
case resolver: IndexShuffleBlockResolver =>
val dataFile = resolver.getDataFile(dep.shuffleId, mapId)
new ShuffleBlockPusher(SparkEnv.get.conf)
.initiateBlockPush(dataFile, writer.getPartitionLengths(), dep, partition.index)
case _ =>
}
}
}
mapStatus.get
} catch {
case e: Exception =>
try {

View file

@ -31,4 +31,7 @@ private[spark] abstract class ShuffleWriter[K, V] {
/** Close this writer, passing along whether the map completed */
def stop(success: Boolean): Option[MapStatus]
/** Get the lengths of each partition */
def getPartitionLengths(): Array[Long]
}

View file

@ -45,6 +45,8 @@ private[spark] class SortShuffleWriter[K, V, C](
private var mapStatus: MapStatus = null
private var partitionLengths: Array[Long] = _
private val writeMetrics = context.taskMetrics().shuffleWriteMetrics
/** Write a bunch of records to this task's output */
@ -67,7 +69,7 @@ private[spark] class SortShuffleWriter[K, V, C](
val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter(
dep.shuffleId, mapId, dep.partitioner.numPartitions)
sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)
val partitionLengths = mapOutputWriter.commitAllPartitions().getPartitionLengths
partitionLengths = mapOutputWriter.commitAllPartitions().getPartitionLengths
mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
}
@ -93,6 +95,8 @@ private[spark] class SortShuffleWriter[K, V, C](
}
}
}
override def getPartitionLengths(): Array[Long] = partitionLengths
}
private[spark] object SortShuffleWriter {

View file

@ -20,7 +20,7 @@ package org.apache.spark.storage
import java.util.UUID
import org.apache.spark.SparkException
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.annotation.{DeveloperApi, Since}
/**
* :: DeveloperApi ::
@ -81,6 +81,12 @@ case class ShuffleIndexBlockId(shuffleId: Int, mapId: Long, reduceId: Int) exten
override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index"
}
@Since("3.2.0")
@DeveloperApi
case class ShufflePushBlockId(shuffleId: Int, mapIndex: Int, reduceId: Int) extends BlockId {
override def name: String = "shufflePush_" + shuffleId + "_" + mapIndex + "_" + reduceId
}
@DeveloperApi
case class BroadcastBlockId(broadcastId: Long, field: String = "") extends BlockId {
override def name: String = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field)
@ -122,6 +128,7 @@ object BlockId {
val SHUFFLE_BATCH = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)_([0-9]+)".r
val SHUFFLE_DATA = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).data".r
val SHUFFLE_INDEX = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).index".r
val SHUFFLE_PUSH = "shufflePush_([0-9]+)_([0-9]+)_([0-9]+)".r
val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r
val TASKRESULT = "taskresult_([0-9]+)".r
val STREAM = "input-([0-9]+)-([0-9]+)".r
@ -140,6 +147,8 @@ object BlockId {
ShuffleDataBlockId(shuffleId.toInt, mapId.toLong, reduceId.toInt)
case SHUFFLE_INDEX(shuffleId, mapId, reduceId) =>
ShuffleIndexBlockId(shuffleId.toInt, mapId.toLong, reduceId.toInt)
case SHUFFLE_PUSH(shuffleId, mapIndex, reduceId) =>
ShufflePushBlockId(shuffleId.toInt, mapIndex.toInt, reduceId.toInt)
case BROADCAST(broadcastId, field) =>
BroadcastBlockId(broadcastId.toLong, field.stripPrefix("_"))
case TASKRESULT(taskId) =>

View file

@ -0,0 +1,355 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.shuffle
import java.io.File
import java.net.ConnectException
import java.nio.ByteBuffer
import java.util.concurrent.LinkedBlockingQueue
import scala.collection.mutable.ArrayBuffer
import org.mockito.{Mock, MockitoAnnotations}
import org.mockito.Answers.RETURNS_SMART_NULLS
import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito._
import org.mockito.invocation.InvocationOnMock
import org.scalatest.BeforeAndAfterEach
import org.apache.spark._
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.shuffle.{BlockFetchingListener, BlockStoreClient}
import org.apache.spark.network.shuffle.ErrorHandler.BlockPushErrorHandler
import org.apache.spark.network.util.TransportConf
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.shuffle.ShuffleBlockPusher.PushRequest
import org.apache.spark.storage._
class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
@Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = _
@Mock(answer = RETURNS_SMART_NULLS) private var dependency: ShuffleDependency[Int, Int, Int] = _
@Mock(answer = RETURNS_SMART_NULLS) private var shuffleClient: BlockStoreClient = _
private var conf: SparkConf = _
private var pushedBlocks = new ArrayBuffer[String]
override def beforeEach(): Unit = {
super.beforeEach()
conf = new SparkConf(loadDefaults = false)
MockitoAnnotations.initMocks(this)
when(dependency.partitioner).thenReturn(new HashPartitioner(8))
when(dependency.serializer).thenReturn(new JavaSerializer(conf))
when(dependency.getMergerLocs).thenReturn(Seq(BlockManagerId("test-client", "test-client", 1)))
conf.set("spark.shuffle.push.based.enabled", "true")
conf.set("spark.shuffle.service.enabled", "true")
// Set the env because the shuffler writer gets the shuffle client instance from the env.
val mockEnv = mock(classOf[SparkEnv])
when(mockEnv.conf).thenReturn(conf)
when(mockEnv.blockManager).thenReturn(blockManager)
SparkEnv.set(mockEnv)
when(blockManager.blockStoreClient).thenReturn(shuffleClient)
}
override def afterEach(): Unit = {
pushedBlocks.clear()
super.afterEach()
}
private def interceptPushedBlocksForSuccess(): Unit = {
when(shuffleClient.pushBlocks(any(), any(), any(), any(), any()))
.thenAnswer((invocation: InvocationOnMock) => {
val blocks = invocation.getArguments()(2).asInstanceOf[Array[String]]
pushedBlocks ++= blocks
val managedBuffers = invocation.getArguments()(3).asInstanceOf[Array[ManagedBuffer]]
val blockFetchListener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
(blocks, managedBuffers).zipped.foreach((blockId, buffer) => {
blockFetchListener.onBlockFetchSuccess(blockId, buffer)
})
})
}
private def verifyPushRequests(
pushRequests: Seq[PushRequest],
expectedSizes: Seq[Int]): Unit = {
(pushRequests, expectedSizes).zipped.foreach((req, size) => {
assert(req.size == size)
})
}
test("A batch of blocks is limited by maxBlocksBatchSize") {
conf.set("spark.shuffle.push.maxBlockBatchSize", "1m")
conf.set("spark.shuffle.push.maxBlockSizeToPush", "2048k")
val blockPusher = new TestShuffleBlockPusher(conf)
val mergerLocs = dependency.getMergerLocs.map(loc => BlockManagerId("", loc.host, loc.port))
val largeBlockSize = 2 * 1024 * 1024
val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0,
mock(classOf[File]), Array(2, 2, 2, largeBlockSize, largeBlockSize), mergerLocs,
mock(classOf[TransportConf]))
assert(pushRequests.length == 3)
verifyPushRequests(pushRequests, Seq(6, largeBlockSize, largeBlockSize))
}
test("Large blocks are excluded in the preparation") {
conf.set("spark.shuffle.push.maxBlockSizeToPush", "1k")
val blockPusher = new TestShuffleBlockPusher(conf)
val mergerLocs = dependency.getMergerLocs.map(loc => BlockManagerId("", loc.host, loc.port))
val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0,
mock(classOf[File]), Array(2, 2, 2, 1028, 1024), mergerLocs, mock(classOf[TransportConf]))
assert(pushRequests.length == 2)
verifyPushRequests(pushRequests, Seq(6, 1024))
}
test("Number of blocks in a push request are limited by maxBlocksInFlightPerAddress ") {
conf.set("spark.reducer.maxBlocksInFlightPerAddress", "1")
val blockPusher = new TestShuffleBlockPusher(conf)
val mergerLocs = dependency.getMergerLocs.map(loc => BlockManagerId("", loc.host, loc.port))
val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0,
mock(classOf[File]), Array(2, 2, 2, 2, 2), mergerLocs, mock(classOf[TransportConf]))
assert(pushRequests.length == 5)
verifyPushRequests(pushRequests, Seq(2, 2, 2, 2, 2))
}
test("Basic block push") {
interceptPushedBlocksForSuccess()
val blockPusher = new TestShuffleBlockPusher(conf)
blockPusher.initiateBlockPush(mock(classOf[File]),
Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0)
blockPusher.runPendingTasks()
verify(shuffleClient, times(1))
.pushBlocks(any(), any(), any(), any(), any())
assert(pushedBlocks.length == dependency.partitioner.numPartitions)
ShuffleBlockPusher.stop()
}
test("Large blocks are skipped for push") {
conf.set("spark.shuffle.push.maxBlockSizeToPush", "1k")
interceptPushedBlocksForSuccess()
val pusher = new TestShuffleBlockPusher(conf)
pusher.initiateBlockPush(
mock(classOf[File]), Array(2, 2, 2, 2, 2, 2, 2, 1100), dependency, 0)
pusher.runPendingTasks()
verify(shuffleClient, times(1))
.pushBlocks(any(), any(), any(), any(), any())
assert(pushedBlocks.length == dependency.partitioner.numPartitions - 1)
ShuffleBlockPusher.stop()
}
test("Number of blocks in flight per address are limited by maxBlocksInFlightPerAddress") {
conf.set("spark.reducer.maxBlocksInFlightPerAddress", "1")
interceptPushedBlocksForSuccess()
val pusher = new TestShuffleBlockPusher(conf)
pusher.initiateBlockPush(
mock(classOf[File]), Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0)
pusher.runPendingTasks()
verify(shuffleClient, times(8))
.pushBlocks(any(), any(), any(), any(), any())
assert(pushedBlocks.length == dependency.partitioner.numPartitions)
ShuffleBlockPusher.stop()
}
test("Hit maxBlocksInFlightPerAddress limit so that the blocks are deferred") {
conf.set("spark.reducer.maxBlocksInFlightPerAddress", "2")
var blockPendingResponse : String = null
var listener : BlockFetchingListener = null
when(shuffleClient.pushBlocks(any(), any(), any(), any(), any()))
.thenAnswer((invocation: InvocationOnMock) => {
val blocks = invocation.getArguments()(2).asInstanceOf[Array[String]]
pushedBlocks ++= blocks
val managedBuffers = invocation.getArguments()(3).asInstanceOf[Array[ManagedBuffer]]
val blockFetchListener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
// Expecting 2 blocks
assert(blocks.length == 2)
if (blockPendingResponse == null) {
blockPendingResponse = blocks(1)
listener = blockFetchListener
// Respond with success only for the first block which will cause all the rest of the
// blocks to be deferred
blockFetchListener.onBlockFetchSuccess(blocks(0), managedBuffers(0))
} else {
(blocks, managedBuffers).zipped.foreach((blockId, buffer) => {
blockFetchListener.onBlockFetchSuccess(blockId, buffer)
})
}
})
val pusher = new TestShuffleBlockPusher(conf)
pusher.initiateBlockPush(
mock(classOf[File]), Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0)
pusher.runPendingTasks()
verify(shuffleClient, times(1))
.pushBlocks(any(), any(), any(), any(), any())
assert(pushedBlocks.length == 2)
// this will trigger push of deferred blocks
listener.onBlockFetchSuccess(blockPendingResponse, mock(classOf[ManagedBuffer]))
pusher.runPendingTasks()
verify(shuffleClient, times(4))
.pushBlocks(any(), any(), any(), any(), any())
assert(pushedBlocks.length == 8)
ShuffleBlockPusher.stop()
}
test("Number of shuffle blocks grouped in a single push request is limited by " +
"maxBlockBatchSize") {
conf.set("spark.shuffle.push.maxBlockBatchSize", "1m")
interceptPushedBlocksForSuccess()
val pusher = new TestShuffleBlockPusher(conf)
pusher.initiateBlockPush(mock(classOf[File]),
Array.fill(dependency.partitioner.numPartitions) { 512 * 1024 }, dependency, 0)
pusher.runPendingTasks()
verify(shuffleClient, times(4))
.pushBlocks(any(), any(), any(), any(), any())
assert(pushedBlocks.length == dependency.partitioner.numPartitions)
ShuffleBlockPusher.stop()
}
test("Error retries") {
val pusher = new ShuffleBlockPusher(conf)
val errorHandler = pusher.createErrorHandler()
assert(
!errorHandler.shouldRetryError(new RuntimeException(
new IllegalArgumentException(BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX))))
assert(errorHandler.shouldRetryError(new RuntimeException(new ConnectException())))
assert(
errorHandler.shouldRetryError(new RuntimeException(new IllegalArgumentException(
BlockPushErrorHandler.BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX))))
assert (errorHandler.shouldRetryError(new Throwable()))
}
test("Error logging") {
val pusher = new ShuffleBlockPusher(conf)
val errorHandler = pusher.createErrorHandler()
assert(
!errorHandler.shouldLogError(new RuntimeException(
new IllegalArgumentException(BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX))))
assert(!errorHandler.shouldLogError(new RuntimeException(
new IllegalArgumentException(
BlockPushErrorHandler.BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX))))
assert(errorHandler.shouldLogError(new Throwable()))
}
test("Blocks are continued to push even when a block push fails with collision " +
"exception") {
conf.set("spark.reducer.maxBlocksInFlightPerAddress", "1")
val pusher = new TestShuffleBlockPusher(conf)
var failBlock: Boolean = true
when(shuffleClient.pushBlocks(any(), any(), any(), any(), any()))
.thenAnswer((invocation: InvocationOnMock) => {
val blocks = invocation.getArguments()(2).asInstanceOf[Array[String]]
val blockFetchListener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
blocks.foreach(blockId => {
if (failBlock) {
failBlock = false
// Fail the first block with the collision exception.
blockFetchListener.onBlockFetchFailure(blockId, new RuntimeException(
new IllegalArgumentException(
BlockPushErrorHandler.BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX)))
} else {
pushedBlocks += blockId
blockFetchListener.onBlockFetchSuccess(blockId, mock(classOf[ManagedBuffer]))
}
})
})
pusher.initiateBlockPush(
mock(classOf[File]), Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0)
pusher.runPendingTasks()
verify(shuffleClient, times(8))
.pushBlocks(any(), any(), any(), any(), any())
assert(pushedBlocks.length == 7)
}
test("More blocks are not pushed when a block push fails with too late " +
"exception") {
conf.set("spark.reducer.maxBlocksInFlightPerAddress", "1")
val pusher = new TestShuffleBlockPusher(conf)
var failBlock: Boolean = true
when(shuffleClient.pushBlocks(any(), any(), any(), any(), any()))
.thenAnswer((invocation: InvocationOnMock) => {
val blocks = invocation.getArguments()(2).asInstanceOf[Array[String]]
val blockFetchListener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
blocks.foreach(blockId => {
if (failBlock) {
failBlock = false
// Fail the first block with the too late exception.
blockFetchListener.onBlockFetchFailure(blockId, new RuntimeException(
new IllegalArgumentException(BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX)))
} else {
pushedBlocks += blockId
blockFetchListener.onBlockFetchSuccess(blockId, mock(classOf[ManagedBuffer]))
}
})
})
pusher.initiateBlockPush(
mock(classOf[File]), Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0)
pusher.runPendingTasks()
verify(shuffleClient, times(1))
.pushBlocks(any(), any(), any(), any(), any())
assert(pushedBlocks.isEmpty)
}
test("Connect exceptions remove all the push requests for that host") {
when(dependency.getMergerLocs).thenReturn(
Seq(BlockManagerId("client1", "client1", 1), BlockManagerId("client2", "client2", 2)))
conf.set("spark.reducer.maxBlocksInFlightPerAddress", "2")
when(shuffleClient.pushBlocks(any(), any(), any(), any(), any()))
.thenAnswer((invocation: InvocationOnMock) => {
val blocks = invocation.getArguments()(2).asInstanceOf[Array[String]]
pushedBlocks ++= blocks
val blockFetchListener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
blocks.foreach(blockId => {
blockFetchListener.onBlockFetchFailure(
blockId, new RuntimeException(new ConnectException()))
})
})
val pusher = new TestShuffleBlockPusher(conf)
pusher.initiateBlockPush(
mock(classOf[File]), Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0)
pusher.runPendingTasks()
verify(shuffleClient, times(2))
.pushBlocks(any(), any(), any(), any(), any())
// 2 blocks for each merger locations
assert(pushedBlocks.length == 4)
assert(pusher.unreachableBlockMgrs.size == 2)
}
private class TestShuffleBlockPusher(conf: SparkConf) extends ShuffleBlockPusher(conf) {
private[this] val tasks = new LinkedBlockingQueue[Runnable]
override protected def submitTask(task: Runnable): Unit = {
tasks.add(task)
}
def runPendingTasks(): Unit = {
// This ensures that all the submitted tasks - updateStateAndCheckIfPushMore and pushUpToMax
// are run synchronously.
while (!tasks.isEmpty) {
tasks.take().run()
}
}
override protected def createRequestBuffer(
conf: TransportConf,
dataFile: File,
offset: Long,
length: Long): ManagedBuffer = {
val managedBuffer = mock(classOf[ManagedBuffer])
val byteBuffer = new Array[Byte](length.toInt)
when(managedBuffer.nioByteBuffer()).thenReturn(ByteBuffer.wrap(byteBuffer))
managedBuffer
}
}
}