diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
index ee367f9998..ad8e8b44d2 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
@@ -23,6 +23,8 @@ import java.util.Random;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.AtomicLong;
 
+import scala.Tuple2;
+
 import com.google.common.base.Preconditions;
 import io.netty.channel.Channel;
 import org.slf4j.Logger;
@@ -94,6 +96,25 @@ public class OneForOneStreamManager extends StreamManager {
     return nextChunk;
   }
 
+  @Override
+  public ManagedBuffer openStream(String streamChunkId) {
+    Tuple2<Long, Integer> streamIdAndChunkId = parseStreamChunkId(streamChunkId);
+    return getChunk(streamIdAndChunkId._1, streamIdAndChunkId._2);
+  }
+
+  public static String genStreamChunkId(long streamId, int chunkId) {
+    return String.format("%d_%d", streamId, chunkId);
+  }
+
+  public static Tuple2<Long, Integer> parseStreamChunkId(String streamChunkId) {
+    String[] array = streamChunkId.split("_");
+    assert array.length == 2:
+      "Stream id and chunk index should be specified when opening a stream to fetch a block.";
+    long streamId = Long.valueOf(array[0]);
+    int chunkIndex = Integer.valueOf(array[1]);
+    return new Tuple2<>(streamId, chunkIndex);
+  }
+
   @Override
   public void connectionTerminated(Channel channel) {
     // Close all streams which have been associated with the channel.
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
index 2c5827bf7d..269fa72dad 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
@@ -17,6 +17,7 @@
 
 package org.apache.spark.network.shuffle;
 
+import java.io.File;
 import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.util.List;
@@ -86,14 +87,16 @@ public class ExternalShuffleClient extends ShuffleClient {
       int port,
       String execId,
       String[] blockIds,
-      BlockFetchingListener listener) {
+      BlockFetchingListener listener,
+      File[] shuffleFiles) {
     checkInit();
     logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId);
     try {
       RetryingBlockFetcher.BlockFetchStarter blockFetchStarter =
          (blockIds1, listener1) -> {
            TransportClient client = clientFactory.createClient(host, port);
-           new OneForOneBlockFetcher(client, appId, execId, blockIds1, listener1).start();
+           new OneForOneBlockFetcher(client, appId, execId, blockIds1, listener1, conf,
+             shuffleFiles).start();
          };
 
       int maxRetries = conf.maxIORetries();
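[Illustration only, not part of the patch] The OneForOneStreamManager helpers added above encode a (streamId, chunkIndex) pair into the single "streamId_chunkIndex" string that openStream() now accepts. A minimal Scala sketch of the round trip, assuming the patched classes are on the classpath:

import org.apache.spark.network.server.OneForOneStreamManager

// Encode a stream id and chunk index into the combined stream-chunk id string.
val streamChunkId = OneForOneStreamManager.genStreamChunkId(12345L, 7) // "12345_7"
// Decode it back into a Tuple2 of (stream id, chunk index).
val (streamId, chunkIndex) = OneForOneStreamManager.parseStreamChunkId(streamChunkId)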
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
index 35f69fe35c..5f42875925 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
@@ -17,19 +17,28 @@
 
 package org.apache.spark.network.shuffle;
 
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
 import java.nio.ByteBuffer;
+import java.nio.channels.Channels;
+import java.nio.channels.WritableByteChannel;
 import java.util.Arrays;
 
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
 import org.apache.spark.network.buffer.ManagedBuffer;
 import org.apache.spark.network.client.ChunkReceivedCallback;
 import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.StreamCallback;
 import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.server.OneForOneStreamManager;
 import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
 import org.apache.spark.network.shuffle.protocol.OpenBlocks;
 import org.apache.spark.network.shuffle.protocol.StreamHandle;
+import org.apache.spark.network.util.TransportConf;
 
 /**
  * Simple wrapper on top of a TransportClient which interprets each chunk as a whole block, and
@@ -48,6 +57,8 @@ public class OneForOneBlockFetcher {
   private final String[] blockIds;
   private final BlockFetchingListener listener;
   private final ChunkReceivedCallback chunkCallback;
+  private TransportConf transportConf = null;
+  private File[] shuffleFiles = null;
 
   private StreamHandle streamHandle = null;
 
@@ -56,12 +67,20 @@ public class OneForOneBlockFetcher {
       String appId,
       String execId,
       String[] blockIds,
-      BlockFetchingListener listener) {
+      BlockFetchingListener listener,
+      TransportConf transportConf,
+      File[] shuffleFiles) {
     this.client = client;
     this.openMessage = new OpenBlocks(appId, execId, blockIds);
     this.blockIds = blockIds;
     this.listener = listener;
     this.chunkCallback = new ChunkCallback();
+    this.transportConf = transportConf;
+    if (shuffleFiles != null) {
+      this.shuffleFiles = shuffleFiles;
+      assert this.shuffleFiles.length == blockIds.length:
+        "The number of shuffle files should equal the number of blocks";
+    }
   }
 
   /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */
@@ -100,7 +119,12 @@ public class OneForOneBlockFetcher {
       // Immediately request all chunks -- we expect that the total size of the request is
       // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]].
for (int i = 0; i < streamHandle.numChunks; i++) { - client.fetchChunk(streamHandle.streamId, i, chunkCallback); + if (shuffleFiles != null) { + client.stream(OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i), + new DownloadCallback(shuffleFiles[i], i)); + } else { + client.fetchChunk(streamHandle.streamId, i, chunkCallback); + } } } catch (Exception e) { logger.error("Failed while starting block fetches after success", e); @@ -126,4 +150,38 @@ public class OneForOneBlockFetcher { } } } + + private class DownloadCallback implements StreamCallback { + + private WritableByteChannel channel = null; + private File targetFile = null; + private int chunkIndex; + + public DownloadCallback(File targetFile, int chunkIndex) throws IOException { + this.targetFile = targetFile; + this.channel = Channels.newChannel(new FileOutputStream(targetFile)); + this.chunkIndex = chunkIndex; + } + + @Override + public void onData(String streamId, ByteBuffer buf) throws IOException { + channel.write(buf); + } + + @Override + public void onComplete(String streamId) throws IOException { + channel.close(); + ManagedBuffer buffer = new FileSegmentManagedBuffer(transportConf, targetFile, 0, + targetFile.length()); + listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer); + } + + @Override + public void onFailure(String streamId, Throwable cause) throws IOException { + channel.close(); + // On receipt of a failure, fail every block from chunkIndex onwards. + String[] remainingBlockIds = Arrays.copyOfRange(blockIds, chunkIndex, blockIds.length); + failRemainingBlocks(remainingBlockIds, cause); + } + } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java index f72ab40690..978ff5a2a8 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java @@ -18,6 +18,7 @@ package org.apache.spark.network.shuffle; import java.io.Closeable; +import java.io.File; /** Provides an interface for reading shuffle files, either from an Executor or external service. 
*/ public abstract class ShuffleClient implements Closeable { @@ -40,5 +41,6 @@ public abstract class ShuffleClient implements Closeable { int port, String execId, String[] blockIds, - BlockFetchingListener listener); + BlockFetchingListener listener, + File[] shuffleFiles); } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index c0e170e5b9..0c054fc5db 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -204,7 +204,7 @@ public class SaslIntegrationSuite { String[] blockIds = { "shuffle_2_3_4", "shuffle_6_7_8" }; OneForOneBlockFetcher fetcher = - new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener); + new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener, conf, null); fetcher.start(); blockFetchLatch.await(); checkSecurityException(exception.get()); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 7a33b68217..d1d8f5b4e1 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -158,7 +158,7 @@ public class ExternalShuffleIntegrationSuite { } } } - }); + }, null); if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) { fail("Timeout getting response from the server"); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java index 3e51fea3cf..61d82214e7 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java @@ -46,8 +46,13 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; import org.apache.spark.network.shuffle.protocol.OpenBlocks; import org.apache.spark.network.shuffle.protocol.StreamHandle; +import org.apache.spark.network.util.MapConfigProvider; +import org.apache.spark.network.util.TransportConf; public class OneForOneBlockFetcherSuite { + + private static final TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); + @Test public void testFetchOne() { LinkedHashMap blocks = Maps.newLinkedHashMap(); @@ -126,7 +131,7 @@ public class OneForOneBlockFetcherSuite { BlockFetchingListener listener = mock(BlockFetchingListener.class); String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]); OneForOneBlockFetcher fetcher = - new OneForOneBlockFetcher(client, "app-id", "exec-id", blockIds, listener); + new OneForOneBlockFetcher(client, "app-id", "exec-id", blockIds, listener, conf, null); // Respond to the "OpenBlocks" message with an appropriate ShuffleStreamHandle with streamId 123 doAnswer(invocationOnMock -> { diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala 
b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index e193ed222e..f8139b706a 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -287,4 +287,10 @@ package object config {
       .bytesConf(ByteUnit.BYTE)
       .createWithDefault(100 * 1024 * 1024)
 
+  private[spark] val REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM =
+    ConfigBuilder("spark.reducer.maxReqSizeShuffleToMem")
+      .doc("The blocks of a shuffle request will be fetched to disk when the size of the " +
+        "request is above this threshold. This is to avoid a giant request taking too much memory.")
+      .bytesConf(ByteUnit.BYTE)
+      .createWithDefaultString("200m")
 }
diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
index cb9d389dd7..6860214c7f 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.network
 
-import java.io.Closeable
+import java.io.{Closeable, File}
 import java.nio.ByteBuffer
 
 import scala.concurrent.{Future, Promise}
@@ -67,7 +67,8 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo
       port: Int,
       execId: String,
       blockIds: Array[String],
-      listener: BlockFetchingListener): Unit
+      listener: BlockFetchingListener,
+      shuffleFiles: Array[File]): Unit
 
   /**
    * Upload a single block to a remote node, available only after [[init]] is invoked.
@@ -100,7 +101,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo
           ret.flip()
           result.success(new NioManagedBuffer(ret))
         }
-      })
+      }, shuffleFiles = null)
 
     ThreadUtils.awaitResult(result.future, Duration.Inf)
   }
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
index b75e91b660..b13a9c681e 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.network.netty
 
+import java.io.File
 import java.nio.ByteBuffer
 
 import scala.collection.JavaConverters._
@@ -88,13 +89,15 @@ private[spark] class NettyBlockTransferService(
       port: Int,
       execId: String,
       blockIds: Array[String],
-      listener: BlockFetchingListener): Unit = {
+      listener: BlockFetchingListener,
+      shuffleFiles: Array[File]): Unit = {
     logTrace(s"Fetch blocks from $host:$port (executor id $execId)")
     try {
       val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
         override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
           val client = clientFactory.createClient(host, port)
-          new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener).start()
+          new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener,
+            transportConf, shuffleFiles).start()
         }
       }
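[Illustration only, not part of the patch] Callers of the extended fetchBlocks() decide per request whether to hand over temporary files (large request: stream each block to disk) or null (small request: keep the existing in-memory path). A hedged Scala sketch of that decision, mirroring the logic added to ShuffleBlockFetcherIterator.sendRequest further below; the helper name tempFileFor is hypothetical:

import java.io.File
import org.apache.spark.network.BlockTransferService
import org.apache.spark.network.shuffle.BlockFetchingListener

// Sketch: pass one temp file per block id for oversized requests, null otherwise.
def fetchPossiblyToDisk(
    transfer: BlockTransferService,
    host: String,
    port: Int,
    execId: String,
    blockIds: Array[String],
    listener: BlockFetchingListener,
    reqSize: Long,
    maxReqSizeShuffleToMem: Long,
    tempFileFor: String => File): Unit = {
  // When files are supplied, the fetcher streams each chunk into its file instead of memory.
  val shuffleFiles = if (reqSize > maxReqSizeShuffleToMem) blockIds.map(tempFileFor) else null
  transfer.fetchBlocks(host, port, execId, blockIds, listener, shuffleFiles)
}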
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index ba3e0e395e..2fbac79a23 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.shuffle
 
 import org.apache.spark._
-import org.apache.spark.internal.Logging
+import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.serializer.SerializerManager
 import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator}
 import org.apache.spark.util.CompletionIterator
@@ -51,6 +51,7 @@ private[spark] class BlockStoreShuffleReader[K, C](
       // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
       SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
       SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
+      SparkEnv.get.conf.get(config.REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM),
       SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))
 
     val serializerInstance = dep.serializer.newInstance()
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index f890611763..ee35060926 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.storage
 
-import java.io.{InputStream, IOException}
+import java.io.{File, InputStream, IOException}
 import java.nio.ByteBuffer
 import java.util.concurrent.LinkedBlockingQueue
 import javax.annotation.concurrent.GuardedBy
@@ -52,6 +52,7 @@ import org.apache.spark.util.io.ChunkedByteBufferOutputStream
  * @param streamWrapper A function to wrap the returned input stream.
  * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
  * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point.
+ * @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory.
  * @param detectCorrupt whether to detect any corruption in fetched blocks.
  */
 private[spark]
@@ -63,6 +64,7 @@ final class ShuffleBlockFetcherIterator(
     streamWrapper: (BlockId, InputStream) => InputStream,
     maxBytesInFlight: Long,
     maxReqsInFlight: Int,
+    maxReqSizeShuffleToMem: Long,
     detectCorrupt: Boolean)
   extends Iterator[(BlockId, InputStream)] with Logging {
 
@@ -129,6 +131,12 @@ final class ShuffleBlockFetcherIterator(
   @GuardedBy("this")
   private[this] var isZombie = false
 
+  /**
+   * A set to store the files used for shuffling remote huge blocks. Files in this set will be
+   * deleted during cleanup. This is a layer of defensiveness against disk file leaks.
+   */
+  val shuffleFilesSet = mutable.HashSet[File]()
+
   initialize()
 
   // Decrements the buffer reference count.
@@ -163,6 +171,11 @@ final class ShuffleBlockFetcherIterator(
         case _ =>
       }
     }
+    shuffleFilesSet.foreach { file =>
+      if (!file.delete()) {
+        logInfo("Failed to clean up shuffle fetch temp file " + file.getAbsolutePath())
+      }
+    }
   }
 
   private[this] def sendRequest(req: FetchRequest) {
@@ -175,33 +188,45 @@
     val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
     val remainingBlocks = new HashSet[String]() ++= sizeMap.keys
     val blockIds = req.blocks.map(_._1.toString)
     val address = req.address
-    shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
-      new BlockFetchingListener {
-        override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
-          // Only add the buffer to results queue if the iterator is not zombie,
-          // i.e. cleanup() has not been called yet.
- ShuffleBlockFetcherIterator.this.synchronized { - if (!isZombie) { - // Increment the ref count because we need to pass this to a different thread. - // This needs to be released after use. - buf.retain() - remainingBlocks -= blockId - results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf, - remainingBlocks.isEmpty)) - logDebug("remainingBlocks: " + remainingBlocks) - } - } - logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) - } - override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = { - logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e) - results.put(new FailureFetchResult(BlockId(blockId), address, e)) + val blockFetchingListener = new BlockFetchingListener { + override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = { + // Only add the buffer to results queue if the iterator is not zombie, + // i.e. cleanup() has not been called yet. + ShuffleBlockFetcherIterator.this.synchronized { + if (!isZombie) { + // Increment the ref count because we need to pass this to a different thread. + // This needs to be released after use. + buf.retain() + remainingBlocks -= blockId + results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf, + remainingBlocks.isEmpty)) + logDebug("remainingBlocks: " + remainingBlocks) + } } + logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) } - ) + + override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = { + logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e) + results.put(new FailureFetchResult(BlockId(blockId), address, e)) + } + } + + // Shuffle remote blocks to disk when the request is too large. + // TODO: Encryption and compression should be considered. 
+ if (req.size > maxReqSizeShuffleToMem) { + val shuffleFiles = blockIds.map { + bId => blockManager.diskBlockManager.createTempLocalBlock()._2 + }.toArray + shuffleFilesSet ++= shuffleFiles + shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, + blockFetchingListener, shuffleFiles) + } else { + shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, + blockFetchingListener, null) + } } private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = { diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index bb24c6ce4d..71bedda5ac 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark import scala.collection.mutable.ArrayBuffer -import org.mockito.Matchers.{any, isA} +import org.mockito.Matchers.any import org.mockito.Mockito._ import org.apache.spark.broadcast.BroadcastManager diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index 792a1d7f57..474e30144f 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -165,7 +165,7 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { promise.success(data.retain()) } - }) + }, null) ThreadUtils.awaitReady(promise.future, FiniteDuration(10, TimeUnit.SECONDS)) promise.future.value.get diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 1e7bcdb674..0d2912ba8c 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.storage +import java.io.File import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer @@ -1290,7 +1291,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE port: Int, execId: String, blockIds: Array[String], - listener: BlockFetchingListener): Unit = { + listener: BlockFetchingListener, + shuffleFiles: Array[File]): Unit = { listener.onBlockFetchSuccess("mockBlockId", new NioManagedBuffer(ByteBuffer.allocate(1))) } diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 9900d1edc4..1f813a909f 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.storage import java.io.{File, InputStream, IOException} +import java.util.UUID import java.util.concurrent.Semaphore import scala.concurrent.ExecutionContext.Implicits.global @@ -44,7 +45,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT /** Creates a mock [[BlockTransferService]] that returns data from the given map. 
*/ private def createMockTransfer(data: Map[BlockId, ManagedBuffer]): BlockTransferService = { val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val blocks = invocation.getArguments()(3).asInstanceOf[Array[String]] val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ -106,6 +108,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT (_, in) => in, 48 * 1024 * 1024, Int.MaxValue, + Int.MaxValue, true) // 3 local blocks fetched in initialization @@ -134,7 +137,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // 3 local blocks, and 2 remote blocks // (but from the same block manager so one call to fetchBlocks) verify(blockManager, times(3)).getBlockData(any()) - verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any()) + verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any()) } test("release current unexhausted buffer in case the task completes early") { @@ -153,7 +156,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] Future { @@ -181,6 +185,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT (_, in) => in, 48 * 1024 * 1024, Int.MaxValue, + Int.MaxValue, true) verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release() @@ -218,7 +223,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] Future { @@ -246,6 +252,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT (_, in) => in, 48 * 1024 * 1024, Int.MaxValue, + Int.MaxValue, true) // Continue only after the mock calls onBlockFetchFailure @@ -281,7 +288,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 0, 100) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] Future { @@ -309,6 +317,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT (_, in) => new LimitedInputStream(in, 100), 48 * 1024 * 1024, Int.MaxValue, + Int.MaxValue, true) 
// Continue only after the mock calls onBlockFetchFailure @@ -318,7 +327,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val (id1, _) = iterator.next() assert(id1 === ShuffleBlockId(0, 0, 0)) - when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] Future { @@ -359,7 +369,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT when(corruptBuffer.createInputStream()).thenReturn(corruptStream) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] Future { @@ -387,6 +398,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT (_, in) => new LimitedInputStream(in, 100), 48 * 1024 * 1024, Int.MaxValue, + Int.MaxValue, false) // Continue only after the mock calls onBlockFetchFailure @@ -401,4 +413,64 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT assert(id3 === ShuffleBlockId(0, 2, 0)) } + test("Blocks should be shuffled to disk when size of the request is above the" + + " threshold(maxReqSizeShuffleToMem).") { + val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId + + val diskBlockManager = mock(classOf[DiskBlockManager]) + doReturn{ + var blockId = new TempLocalBlockId(UUID.randomUUID()) + (blockId, new File(blockId.name)) + }.when(diskBlockManager).createTempLocalBlock() + doReturn(diskBlockManager).when(blockManager).diskBlockManager + + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val remoteBlocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer()) + val transfer = mock(classOf[BlockTransferService]) + var shuffleFiles: Array[File] = null + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + .thenAnswer(new Answer[Unit] { + override def answer(invocation: InvocationOnMock): Unit = { + val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] + shuffleFiles = invocation.getArguments()(5).asInstanceOf[Array[File]] + Future { + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0))) + } + } + }) + + val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L)).toSeq)) + // Set maxReqSizeShuffleToMem to be 200. + val iterator1 = new ShuffleBlockFetcherIterator( + TaskContext.empty(), + transfer, + blockManager, + blocksByAddress1, + (_, in) => in, + Int.MaxValue, + Int.MaxValue, + 200, + true) + assert(shuffleFiles === null) + + val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq)) + // Set maxReqSizeShuffleToMem to be 200. 
+    val iterator2 = new ShuffleBlockFetcherIterator(
+      TaskContext.empty(),
+      transfer,
+      blockManager,
+      blocksByAddress2,
+      (_, in) => in,
+      Int.MaxValue,
+      Int.MaxValue,
+      200,
+      true)
+    assert(shuffleFiles != null)
+  }
 }
diff --git a/docs/configuration.md b/docs/configuration.md
index a6b6d5dfa5..0771e36f80 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -519,6 +519,14 @@ Apart from these, the following properties are also available, and may be useful
     By allowing it to limit the number of fetch requests, this scenario can be mitigated.
   </td>
 </tr>
+<tr>
+  <td><code>spark.reducer.maxReqSizeShuffleToMem</code></td>
+  <td>200m</td>
+  <td>
+    The blocks of a shuffle request will be fetched to disk when the size of the request is above
+    this threshold. This is to avoid a giant request taking too much memory.
+  </td>
+</tr>
 <tr>
   <td><code>spark.shuffle.compress</code></td>
   <td>true</td>
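[Illustration only, not part of the patch] The new threshold is an ordinary size property, so it can be tuned like any other setting; the 100m value below is arbitrary:

import org.apache.spark.SparkConf

// Shuffle-fetch requests larger than 100m are written to temporary files on disk;
// smaller requests keep the existing in-memory behaviour.
val conf = new SparkConf()
  .setAppName("shuffle-fetch-to-disk-example")
  .set("spark.reducer.maxReqSizeShuffleToMem", "100m")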