[SPARK-19659] Fetch big blocks to disk when shuffle-read.

## What changes were proposed in this pull request?

Currently the whole block is fetched into memory (off-heap by default) during shuffle read. A block is identified by (shuffleId, mapId, reduceId), so it can be large in skew situations. If an OOM happens during shuffle read, the job is killed and users are advised to "Consider boosting spark.yarn.executor.memoryOverhead". Adjusting the parameter and allocating more memory can resolve the OOM, but this approach is not well suited to a production environment, especially a data warehouse.
With Spark SQL as the data engine in a warehouse, users want a unified set of parameters (e.g. memory) with less wasted resource (allocated but unused). This need is especially strong when migrating to Spark from another engine (e.g. Hive), since tuning the parameter for thousands of SQL queries one by one is very time consuming.
Skew is not always easy to predict; when it happens, it makes sense to fetch remote blocks to disk for shuffle read rather than kill the job with an OOM.

In this PR, I propose to fetch big blocks to disk (as also mentioned in SPARK-3019):

1. Track the average block size and also the outliers (larger than 2 * avgSize) in MapStatus (a sketch follows this list);
2. Request memory from `MemoryManager` before fetching blocks, and release it back to `MemoryManager` when the `ManagedBuffer` is released;
3. Fetch remote blocks to disk when acquiring memory from `MemoryManager` fails; otherwise fetch to memory.
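The bookkeeping in step 1 can be pictured as follows. This is a minimal sketch with a hypothetical `trackOutliers` helper, not the actual `MapStatus` changes in this commit:

```scala
// Sketch only: summarize block sizes by their average, but record the exact
// sizes of huge blocks (> 2 * avg) so the reducer can treat them specially.
// `trackOutliers` is a hypothetical helper, not Spark's real MapStatus API.
def trackOutliers(blockSizes: Array[Long]): (Long, Map[Int, Long]) = {
  val avg = if (blockSizes.isEmpty) 0L else blockSizes.sum / blockSizes.length
  val hugeBlocks = blockSizes.zipWithIndex.collect {
    case (size, reduceId) if size > 2 * avg => reduceId -> size
  }.toMap
  (avg, hugeBlocks)
}
```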

This improves memory control when shuffling blocks and helps avoid OOM in scenarios like the following:
1. A single huge block;
2. The sizes of many blocks are underestimated in `MapStatus`, so the actual footprint of the blocks is much larger than estimated.

## How was this patch tested?
Added unit tests in `MapStatusSuite` and `ShuffleBlockFetcherIteratorSuite`.

Author: jinxing <jinxing6042@126.com>

Closes #16989 from jinxing64/SPARK-19659.
Authored by jinxing on 2017-05-25 16:11:30 +08:00; committed by Wenchen Fan.
Commit 3f94e64aa8 (parent 731462a04f); 17 changed files with 255 additions and 48 deletions.

`OneForOneStreamManager.java`:

@@ -23,6 +23,8 @@ import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import scala.Tuple2;
import com.google.common.base.Preconditions;
import io.netty.channel.Channel;
import org.slf4j.Logger;
@@ -94,6 +96,25 @@ public class OneForOneStreamManager extends StreamManager {
return nextChunk;
}
@Override
public ManagedBuffer openStream(String streamChunkId) {
Tuple2<Long, Integer> streamIdAndChunkId = parseStreamChunkId(streamChunkId);
return getChunk(streamIdAndChunkId._1, streamIdAndChunkId._2);
}
public static String genStreamChunkId(long streamId, int chunkId) {
return String.format("%d_%d", streamId, chunkId);
}
public static Tuple2<Long, Integer> parseStreamChunkId(String streamChunkId) {
String[] array = streamChunkId.split("_");
assert array.length == 2:
"Stream id and chunk index should be specified when open stream for fetching block.";
long streamId = Long.valueOf(array[0]);
int chunkIndex = Integer.valueOf(array[1]);
return new Tuple2<>(streamId, chunkIndex);
}
@Override
public void connectionTerminated(Channel channel) {
// Close all streams which have been associated with the channel.
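The two new helpers round-trip a (streamId, chunkIndex) pair through a single string id of the form `streamId_chunkIndex`. A quick illustration of the round trip, using the methods shown above:

```scala
import org.apache.spark.network.server.OneForOneStreamManager

// genStreamChunkId(1234L, 5) yields "1234_5"; parseStreamChunkId reverses it.
val id = OneForOneStreamManager.genStreamChunkId(1234L, 5)
val (streamId, chunkIndex) = OneForOneStreamManager.parseStreamChunkId(id)
assert(streamId.longValue == 1234L && chunkIndex.intValue == 5)
```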

`ExternalShuffleClient.java`:

@@ -17,6 +17,7 @@
package org.apache.spark.network.shuffle;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.List;
@@ -86,14 +87,16 @@ public class ExternalShuffleClient extends ShuffleClient {
int port,
String execId,
String[] blockIds,
BlockFetchingListener listener) {
BlockFetchingListener listener,
File[] shuffleFiles) {
checkInit();
logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId);
try {
RetryingBlockFetcher.BlockFetchStarter blockFetchStarter =
(blockIds1, listener1) -> {
TransportClient client = clientFactory.createClient(host, port);
new OneForOneBlockFetcher(client, appId, execId, blockIds1, listener1).start();
new OneForOneBlockFetcher(client, appId, execId, blockIds1, listener1, conf,
shuffleFiles).start();
};
int maxRetries = conf.maxIORetries();

`OneForOneBlockFetcher.java`:

@@ -17,19 +17,28 @@
package org.apache.spark.network.shuffle;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.WritableByteChannel;
import java.util.Arrays;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.StreamCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.server.OneForOneStreamManager;
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
import org.apache.spark.network.shuffle.protocol.OpenBlocks;
import org.apache.spark.network.shuffle.protocol.StreamHandle;
import org.apache.spark.network.util.TransportConf;
/**
* Simple wrapper on top of a TransportClient which interprets each chunk as a whole block, and
@@ -48,6 +57,8 @@ public class OneForOneBlockFetcher {
private final String[] blockIds;
private final BlockFetchingListener listener;
private final ChunkReceivedCallback chunkCallback;
private TransportConf transportConf = null;
private File[] shuffleFiles = null;
private StreamHandle streamHandle = null;
@@ -56,12 +67,20 @@
String appId,
String execId,
String[] blockIds,
BlockFetchingListener listener) {
BlockFetchingListener listener,
TransportConf transportConf,
File[] shuffleFiles) {
this.client = client;
this.openMessage = new OpenBlocks(appId, execId, blockIds);
this.blockIds = blockIds;
this.listener = listener;
this.chunkCallback = new ChunkCallback();
this.transportConf = transportConf;
if (shuffleFiles != null) {
this.shuffleFiles = shuffleFiles;
assert this.shuffleFiles.length == blockIds.length:
"Number of shuffle files should equal to blocks";
}
}
/** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */
@@ -100,7 +119,12 @@
// Immediately request all chunks -- we expect that the total size of the request is
// reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]].
for (int i = 0; i < streamHandle.numChunks; i++) {
client.fetchChunk(streamHandle.streamId, i, chunkCallback);
if (shuffleFiles != null) {
client.stream(OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i),
new DownloadCallback(shuffleFiles[i], i));
} else {
client.fetchChunk(streamHandle.streamId, i, chunkCallback);
}
}
} catch (Exception e) {
logger.error("Failed while starting block fetches after success", e);
@@ -126,4 +150,38 @@
}
}
}
private class DownloadCallback implements StreamCallback {
private WritableByteChannel channel = null;
private File targetFile = null;
private int chunkIndex;
public DownloadCallback(File targetFile, int chunkIndex) throws IOException {
this.targetFile = targetFile;
this.channel = Channels.newChannel(new FileOutputStream(targetFile));
this.chunkIndex = chunkIndex;
}
@Override
public void onData(String streamId, ByteBuffer buf) throws IOException {
channel.write(buf);
}
@Override
public void onComplete(String streamId) throws IOException {
channel.close();
ManagedBuffer buffer = new FileSegmentManagedBuffer(transportConf, targetFile, 0,
targetFile.length());
listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer);
}
@Override
public void onFailure(String streamId, Throwable cause) throws IOException {
channel.close();
// On receipt of a failure, fail every block from chunkIndex onwards.
String[] remainingBlockIds = Arrays.copyOfRange(blockIds, chunkIndex, blockIds.length);
failRemainingBlocks(remainingBlockIds, cause);
}
}
}
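With the extended constructor, the caller selects the fetch mode: pass one target `File` per block to stream to disk, or `null` to keep the original in-memory chunk fetch. A minimal sketch of a caller, assuming an already-connected `TransportClient` and a configured `TransportConf` (the app/executor ids and temp-file naming are illustrative):

```scala
import java.io.File
import org.apache.spark.network.client.TransportClient
import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher}
import org.apache.spark.network.util.TransportConf

def startFetch(
    client: TransportClient,
    blockIds: Array[String],
    listener: BlockFetchingListener,
    conf: TransportConf,
    toDisk: Boolean): Unit = {
  // One target file per block id; null selects the in-memory path.
  val shuffleFiles: Array[File] =
    if (toDisk) blockIds.map(id => File.createTempFile(id, ".shuffle")) else null
  new OneForOneBlockFetcher(client, "app-id", "exec-id", blockIds, listener, conf, shuffleFiles)
    .start()
}
```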

`ShuffleClient.java`:

@@ -18,6 +18,7 @@
package org.apache.spark.network.shuffle;
import java.io.Closeable;
import java.io.File;
/** Provides an interface for reading shuffle files, either from an Executor or external service. */
public abstract class ShuffleClient implements Closeable {
@@ -40,5 +41,6 @@ public abstract class ShuffleClient implements Closeable {
int port,
String execId,
String[] blockIds,
BlockFetchingListener listener);
BlockFetchingListener listener,
File[] shuffleFiles);
}

`SaslIntegrationSuite.java`:

@@ -204,7 +204,7 @@ public class SaslIntegrationSuite {
String[] blockIds = { "shuffle_2_3_4", "shuffle_6_7_8" };
OneForOneBlockFetcher fetcher =
new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener);
new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener, conf, null);
fetcher.start();
blockFetchLatch.await();
checkSecurityException(exception.get());

`ExternalShuffleIntegrationSuite.java`:

@@ -158,7 +158,7 @@ public class ExternalShuffleIntegrationSuite {
}
}
}
});
}, null);
if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) {
fail("Timeout getting response from the server");

`OneForOneBlockFetcherSuite.java`:

@@ -46,8 +46,13 @@ import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
import org.apache.spark.network.shuffle.protocol.OpenBlocks;
import org.apache.spark.network.shuffle.protocol.StreamHandle;
import org.apache.spark.network.util.MapConfigProvider;
import org.apache.spark.network.util.TransportConf;
public class OneForOneBlockFetcherSuite {
private static final TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY);
@Test
public void testFetchOne() {
LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
@@ -126,7 +131,7 @@ public class OneForOneBlockFetcherSuite {
BlockFetchingListener listener = mock(BlockFetchingListener.class);
String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
OneForOneBlockFetcher fetcher =
new OneForOneBlockFetcher(client, "app-id", "exec-id", blockIds, listener);
new OneForOneBlockFetcher(client, "app-id", "exec-id", blockIds, listener, conf, null);
// Respond to the "OpenBlocks" message with an appropriate ShuffleStreamHandle with streamId 123
doAnswer(invocationOnMock -> {

`package.scala` (`org.apache.spark.internal.config`):

@@ -287,4 +287,10 @@ package object config {
.bytesConf(ByteUnit.BYTE)
.createWithDefault(100 * 1024 * 1024)
private[spark] val REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM =
ConfigBuilder("spark.reducer.maxReqSizeShuffleToMem")
.doc("The blocks of a shuffle request will be fetched to disk when size of the request is " +
"above this threshold. This is to avoid a giant request takes too much memory.")
.bytesConf(ByteUnit.BYTE)
.createWithDefaultString("200m")
}
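Because the new entry is a bytes conf created with `createWithDefaultString("200m")`, any size-suffixed string is accepted. A sketch of overriding it per application (the app name is illustrative):

```scala
import org.apache.spark.SparkConf

// Fetch any shuffle request larger than 100 MB to disk instead of memory.
val conf = new SparkConf()
  .setAppName("shuffle-to-disk-demo") // illustrative
  .set("spark.reducer.maxReqSizeShuffleToMem", "100m")
```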

`BlockTransferService.scala`:

@@ -17,7 +17,7 @@
package org.apache.spark.network
import java.io.Closeable
import java.io.{Closeable, File}
import java.nio.ByteBuffer
import scala.concurrent.{Future, Promise}
@@ -67,7 +67,8 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo
port: Int,
execId: String,
blockIds: Array[String],
listener: BlockFetchingListener): Unit
listener: BlockFetchingListener,
shuffleFiles: Array[File]): Unit
/**
* Upload a single block to a remote node, available only after [[init]] is invoked.
@@ -100,7 +101,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo
ret.flip()
result.success(new NioManagedBuffer(ret))
}
})
}, shuffleFiles = null)
ThreadUtils.awaitResult(result.future, Duration.Inf)
}

`NettyBlockTransferService.scala`:

@@ -17,6 +17,7 @@
package org.apache.spark.network.netty
import java.io.File
import java.nio.ByteBuffer
import scala.collection.JavaConverters._
@@ -88,13 +89,15 @@ private[spark] class NettyBlockTransferService(
port: Int,
execId: String,
blockIds: Array[String],
listener: BlockFetchingListener): Unit = {
listener: BlockFetchingListener,
shuffleFiles: Array[File]): Unit = {
logTrace(s"Fetch blocks from $host:$port (executor id $execId)")
try {
val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
val client = clientFactory.createClient(host, port)
new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener).start()
new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener,
transportConf, shuffleFiles).start()
}
}

`BlockStoreShuffleReader.scala`:

@@ -18,7 +18,7 @@
package org.apache.spark.shuffle
import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.internal.{config, Logging}
import org.apache.spark.serializer.SerializerManager
import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator}
import org.apache.spark.util.CompletionIterator
@@ -51,6 +51,7 @@ private[spark] class BlockStoreShuffleReader[K, C](
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
SparkEnv.get.conf.get(config.REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM),
SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))
val serializerInstance = dep.serializer.newInstance()

`ShuffleBlockFetcherIterator.scala`:

@@ -17,7 +17,7 @@
package org.apache.spark.storage
import java.io.{InputStream, IOException}
import java.io.{File, InputStream, IOException}
import java.nio.ByteBuffer
import java.util.concurrent.LinkedBlockingQueue
import javax.annotation.concurrent.GuardedBy
@@ -52,6 +52,7 @@ import org.apache.spark.util.io.ChunkedByteBufferOutputStream
* @param streamWrapper A function to wrap the returned input stream.
* @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
* @param maxReqsInFlight max number of remote requests to fetch blocks at any given point.
* @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory.
* @param detectCorrupt whether to detect any corruption in fetched blocks.
*/
private[spark]
@@ -63,6 +64,7 @@ final class ShuffleBlockFetcherIterator(
streamWrapper: (BlockId, InputStream) => InputStream,
maxBytesInFlight: Long,
maxReqsInFlight: Int,
maxReqSizeShuffleToMem: Long,
detectCorrupt: Boolean)
extends Iterator[(BlockId, InputStream)] with Logging {
@@ -129,6 +131,12 @@ final class ShuffleBlockFetcherIterator(
@GuardedBy("this")
private[this] var isZombie = false
/**
* A set to store the files used for shuffling remote huge blocks. Files in this set will be
* deleted when cleanup. This is a layer of defensiveness against disk file leaks.
*/
val shuffleFilesSet = mutable.HashSet[File]()
initialize()
// Decrements the buffer reference count.
@@ -163,6 +171,11 @@ final class ShuffleBlockFetcherIterator(
case _ =>
}
}
shuffleFilesSet.foreach { file =>
if (!file.delete()) {
logInfo("Failed to cleanup shuffle fetch temp file " + file.getAbsolutePath());
}
}
}
private[this] def sendRequest(req: FetchRequest) {
@@ -175,33 +188,45 @@
val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
val remainingBlocks = new HashSet[String]() ++= sizeMap.keys
val blockIds = req.blocks.map(_._1.toString)
val address = req.address
shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
new BlockFetchingListener {
override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
// Only add the buffer to results queue if the iterator is not zombie,
// i.e. cleanup() has not been called yet.
ShuffleBlockFetcherIterator.this.synchronized {
if (!isZombie) {
// Increment the ref count because we need to pass this to a different thread.
// This needs to be released after use.
buf.retain()
remainingBlocks -= blockId
results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf,
remainingBlocks.isEmpty))
logDebug("remainingBlocks: " + remainingBlocks)
}
}
logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
}
override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
results.put(new FailureFetchResult(BlockId(blockId), address, e))
val blockFetchingListener = new BlockFetchingListener {
override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
// Only add the buffer to results queue if the iterator is not zombie,
// i.e. cleanup() has not been called yet.
ShuffleBlockFetcherIterator.this.synchronized {
if (!isZombie) {
// Increment the ref count because we need to pass this to a different thread.
// This needs to be released after use.
buf.retain()
remainingBlocks -= blockId
results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf,
remainingBlocks.isEmpty))
logDebug("remainingBlocks: " + remainingBlocks)
}
}
logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
}
)
override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
results.put(new FailureFetchResult(BlockId(blockId), address, e))
}
}
// Shuffle remote blocks to disk when the request is too large.
// TODO: Encryption and compression should be considered.
if (req.size > maxReqSizeShuffleToMem) {
val shuffleFiles = blockIds.map {
bId => blockManager.diskBlockManager.createTempLocalBlock()._2
}.toArray
shuffleFilesSet ++= shuffleFiles
shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
blockFetchingListener, shuffleFiles)
} else {
shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
blockFetchingListener, null)
}
}
private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
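To make the threshold in the hunk above concrete: with the default of `200m`, a request bundling three 100 MB blocks (300 MB in total) is streamed to temp files on disk, while a request of two 50 MB blocks stays in memory. The same check with illustrative sizes:

```scala
val maxReqSizeShuffleToMem = 200L * 1024 * 1024 // default "200m"

// A request is shuffled to disk when the summed size of its blocks
// exceeds the threshold; otherwise it is fetched to memory.
def shuffleToDisk(blockSizes: Seq[Long]): Boolean =
  blockSizes.sum > maxReqSizeShuffleToMem

assert(shuffleToDisk(Seq.fill(3)(100L * 1024 * 1024)))  // 300 MB: to disk
assert(!shuffleToDisk(Seq.fill(2)(50L * 1024 * 1024)))  // 100 MB: to memory
```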

`MapOutputTrackerSuite.scala`:

@@ -19,7 +19,7 @@ package org.apache.spark
import scala.collection.mutable.ArrayBuffer
import org.mockito.Matchers.{any, isA}
import org.mockito.Matchers.any
import org.mockito.Mockito._
import org.apache.spark.broadcast.BroadcastManager

`NettyBlockTransferSecuritySuite.scala`:

@@ -165,7 +165,7 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi
override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
promise.success(data.retain())
}
})
}, null)
ThreadUtils.awaitReady(promise.future, FiniteDuration(10, TimeUnit.SECONDS))
promise.future.value.get

`BlockManagerSuite.scala`:

@@ -17,6 +17,7 @@
package org.apache.spark.storage
import java.io.File
import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer
@@ -1290,7 +1291,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
port: Int,
execId: String,
blockIds: Array[String],
listener: BlockFetchingListener): Unit = {
listener: BlockFetchingListener,
shuffleFiles: Array[File]): Unit = {
listener.onBlockFetchSuccess("mockBlockId", new NioManagedBuffer(ByteBuffer.allocate(1)))
}

`ShuffleBlockFetcherIteratorSuite.scala`:

@@ -18,6 +18,7 @@
package org.apache.spark.storage
import java.io.{File, InputStream, IOException}
import java.util.UUID
import java.util.concurrent.Semaphore
import scala.concurrent.ExecutionContext.Implicits.global
@@ -44,7 +45,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
/** Creates a mock [[BlockTransferService]] that returns data from the given map. */
private def createMockTransfer(data: Map[BlockId, ManagedBuffer]): BlockTransferService = {
val transfer = mock(classOf[BlockTransferService])
when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
.thenAnswer(new Answer[Unit] {
override def answer(invocation: InvocationOnMock): Unit = {
val blocks = invocation.getArguments()(3).asInstanceOf[Array[String]]
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
@@ -106,6 +108,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
(_, in) => in,
48 * 1024 * 1024,
Int.MaxValue,
Int.MaxValue,
true)
// 3 local blocks fetched in initialization
@@ -134,7 +137,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
// 3 local blocks, and 2 remote blocks
// (but from the same block manager so one call to fetchBlocks)
verify(blockManager, times(3)).getBlockData(any())
verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any())
verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any())
}
test("release current unexhausted buffer in case the task completes early") {
@@ -153,7 +156,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val sem = new Semaphore(0)
val transfer = mock(classOf[BlockTransferService])
when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
.thenAnswer(new Answer[Unit] {
override def answer(invocation: InvocationOnMock): Unit = {
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
Future {
@@ -181,6 +185,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
(_, in) => in,
48 * 1024 * 1024,
Int.MaxValue,
Int.MaxValue,
true)
verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release()
@@ -218,7 +223,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val sem = new Semaphore(0)
val transfer = mock(classOf[BlockTransferService])
when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
.thenAnswer(new Answer[Unit] {
override def answer(invocation: InvocationOnMock): Unit = {
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
Future {
@@ -246,6 +252,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
(_, in) => in,
48 * 1024 * 1024,
Int.MaxValue,
Int.MaxValue,
true)
// Continue only after the mock calls onBlockFetchFailure
@@ -281,7 +288,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 0, 100)
val transfer = mock(classOf[BlockTransferService])
when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
.thenAnswer(new Answer[Unit] {
override def answer(invocation: InvocationOnMock): Unit = {
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
Future {
@@ -309,6 +317,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
(_, in) => new LimitedInputStream(in, 100),
48 * 1024 * 1024,
Int.MaxValue,
Int.MaxValue,
true)
// Continue only after the mock calls onBlockFetchFailure
@@ -318,7 +327,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val (id1, _) = iterator.next()
assert(id1 === ShuffleBlockId(0, 0, 0))
when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
.thenAnswer(new Answer[Unit] {
override def answer(invocation: InvocationOnMock): Unit = {
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
Future {
@@ -359,7 +369,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
when(corruptBuffer.createInputStream()).thenReturn(corruptStream)
val transfer = mock(classOf[BlockTransferService])
when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
.thenAnswer(new Answer[Unit] {
override def answer(invocation: InvocationOnMock): Unit = {
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
Future {
@@ -387,6 +398,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
(_, in) => new LimitedInputStream(in, 100),
48 * 1024 * 1024,
Int.MaxValue,
Int.MaxValue,
false)
// Continue only after the mock calls onBlockFetchFailure
@@ -401,4 +413,64 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
assert(id3 === ShuffleBlockId(0, 2, 0))
}
test("Blocks should be shuffled to disk when size of the request is above the" +
" threshold(maxReqSizeShuffleToMem).") {
val blockManager = mock(classOf[BlockManager])
val localBmId = BlockManagerId("test-client", "test-client", 1)
doReturn(localBmId).when(blockManager).blockManagerId
val diskBlockManager = mock(classOf[DiskBlockManager])
doReturn{
var blockId = new TempLocalBlockId(UUID.randomUUID())
(blockId, new File(blockId.name))
}.when(diskBlockManager).createTempLocalBlock()
doReturn(diskBlockManager).when(blockManager).diskBlockManager
val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
val remoteBlocks = Map[BlockId, ManagedBuffer](
ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer())
val transfer = mock(classOf[BlockTransferService])
var shuffleFiles: Array[File] = null
when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
.thenAnswer(new Answer[Unit] {
override def answer(invocation: InvocationOnMock): Unit = {
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
shuffleFiles = invocation.getArguments()(5).asInstanceOf[Array[File]]
Future {
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0)))
}
}
})
val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
(remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L)).toSeq))
// Set maxReqSizeShuffleToMem to be 200.
val iterator1 = new ShuffleBlockFetcherIterator(
TaskContext.empty(),
transfer,
blockManager,
blocksByAddress1,
(_, in) => in,
Int.MaxValue,
Int.MaxValue,
200,
true)
assert(shuffleFiles === null)
val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
(remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq))
// Set maxReqSizeShuffleToMem to be 200.
val iterator2 = new ShuffleBlockFetcherIterator(
TaskContext.empty(),
transfer,
blockManager,
blocksByAddress2,
(_, in) => in,
Int.MaxValue,
Int.MaxValue,
200,
true)
assert(shuffleFiles != null)
}
}

`docs/configuration.md`:

@@ -519,6 +519,14 @@ Apart from these, the following properties are also available, and may be useful
By allowing it to limit the number of fetch requests, this scenario can be mitigated.
</td>
</tr>
<tr>
<td><code>spark.reducer.maxReqSizeShuffleToMem</code></td>
<td>200m</td>
<td>
The blocks of a shuffle request will be fetched to disk when size of the request is above
this threshold. This is to avoid a giant request takes too much memory.
</td>
</tr>
<tr>
<td><code>spark.shuffle.compress</code></td>
<td>true</td>