diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java index fbbe8ac0f1..e762bd2071 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java @@ -18,15 +18,33 @@ package org.apache.spark.network.shuffle; import java.io.Closeable; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Arrays; import java.util.Collections; +import java.util.Map; +import java.util.concurrent.CompletableFuture; import com.codahale.metrics.MetricSet; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; +import org.apache.spark.network.shuffle.protocol.GetLocalDirsForExecutors; +import org.apache.spark.network.shuffle.protocol.LocalDirsForExecutors; /** * Provides an interface for reading both shuffle files and RDD blocks, either from an Executor * or external service. */ public abstract class BlockStoreClient implements Closeable { + protected final Logger logger = LoggerFactory.getLogger(this.getClass()); + + protected volatile TransportClientFactory clientFactory; + protected String appId; /** * Fetch a sequence of blocks from a remote node asynchronously, @@ -61,4 +79,60 @@ public abstract class BlockStoreClient implements Closeable { // Return an empty MetricSet by default. return () -> Collections.emptyMap(); } + + protected void checkInit() { + assert appId != null : "Called before init()"; + } + + /** + * Request the local disk directories for executors which are located on the same host as + * the current BlockStoreClient (it can be ExternalBlockStoreClient or NettyBlockTransferService). + * + * @param host the host of BlockManager or ExternalShuffleService. It should be the same host + * as the current BlockStoreClient. + * @param port the port of BlockManager or ExternalShuffleService. + * @param execIds a collection of executor Ids, which specifies the target executors whose + * local directories we want to get. There could be multiple executor Ids if + * BlockStoreClient is implemented by ExternalBlockStoreClient, since the request + * handler, ExternalShuffleService, can serve multiple executors on the same node; + * or only one executor Id if BlockStoreClient is implemented by + * NettyBlockTransferService. + * @param hostLocalDirsCompletable a CompletableFuture which contains a map from executor Id + * to its local directories if the request handler replies + * successfully. Otherwise, it contains a specific error.
+ */ + public void getHostLocalDirs( + String host, + int port, + String[] execIds, + CompletableFuture<Map<String, String[]>> hostLocalDirsCompletable) { + checkInit(); + GetLocalDirsForExecutors getLocalDirsMessage = new GetLocalDirsForExecutors(appId, execIds); + try { + TransportClient client = clientFactory.createClient(host, port); + client.sendRpc(getLocalDirsMessage.toByteBuffer(), new RpcResponseCallback() { + @Override + public void onSuccess(ByteBuffer response) { + try { + BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteBuffer(response); + hostLocalDirsCompletable.complete( + ((LocalDirsForExecutors) msgObj).getLocalDirsByExec()); + } catch (Throwable t) { + logger.warn("Error while trying to get the host local dirs for " + + Arrays.toString(getLocalDirsMessage.execIds), t.getCause()); + hostLocalDirsCompletable.completeExceptionally(t); + } + } + + @Override + public void onFailure(Throwable t) { + logger.warn("Error while trying to get the host local dirs for " + + Arrays.toString(getLocalDirsMessage.execIds), t.getCause()); + hostLocalDirsCompletable.completeExceptionally(t); + } + }); + } catch (IOException | InterruptedException e) { + hostLocalDirsCompletable.completeExceptionally(e); + } + } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java index 9fdb6322c9..76e23e7c69 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java @@ -21,7 +21,6 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.util.Arrays; import java.util.List; -import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Future; @@ -30,10 +29,7 @@ import com.google.common.collect.Lists; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientBootstrap; -import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.shuffle.protocol.*; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import org.apache.spark.network.TransportContext; import org.apache.spark.network.crypto.AuthClientBootstrap; @@ -47,16 +43,11 @@ import org.apache.spark.network.util.TransportConf; * (via BlockTransferService), which has the downside of losing the data if we lose the executors. */ public class ExternalBlockStoreClient extends BlockStoreClient { - private static final Logger logger = LoggerFactory.getLogger(ExternalBlockStoreClient.class); - private final TransportConf conf; private final boolean authEnabled; private final SecretKeyHolder secretKeyHolder; private final long registrationTimeoutMs; - protected volatile TransportClientFactory clientFactory; - protected String appId; - /** * Creates an external shuffle client, with SASL optionally enabled. If SASL is not enabled, * then secretKeyHolder may be null. @@ -72,10 +63,6 @@ public class ExternalBlockStoreClient extends BlockStoreClient { this.registrationTimeoutMs = registrationTimeoutMs; } - protected void checkInit() { - assert appId != null : "Called before init()"; - } - /** * Initializes the BlockStoreClient, specifying this Executor's appId. * Must be called before any other method on the BlockStoreClient.
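For reference, a minimal sketch (not part of the patch) of how a Scala caller can drive the shared getHostLocalDirs API that now lives in BlockStoreClient. The helper name printHostLocalDirs and the host, port and executor ids are illustrative placeholders; the callback wiring mirrors what HostLocalDirManager does later in this change.

import java.util.concurrent.CompletableFuture
import scala.collection.JavaConverters._

import org.apache.spark.network.shuffle.BlockStoreClient

def printHostLocalDirs(
    client: BlockStoreClient, host: String, port: Int, execIds: Array[String]): Unit = {
  // Completed with a map from executor id to its local directories,
  // or completed exceptionally if the RPC fails.
  val dirsFuture = new CompletableFuture[java.util.Map[String, Array[String]]]()
  client.getHostLocalDirs(host, port, execIds, dirsFuture)
  dirsFuture.whenComplete { (dirsByExec, error) =>
    if (dirsByExec != null) {
      dirsByExec.asScala.foreach { case (execId, dirs) =>
        println(s"$execId -> ${dirs.mkString(", ")}")
      }
    } else {
      error.printStackTrace()
    }
  }
}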
@@ -188,43 +175,6 @@ public class ExternalBlockStoreClient extends BlockStoreClient { return numRemovedBlocksFuture; } - public void getHostLocalDirs( - String host, - int port, - String[] execIds, - CompletableFuture<Map<String, String[]>> hostLocalDirsCompletable) { - checkInit(); - GetLocalDirsForExecutors getLocalDirsMessage = new GetLocalDirsForExecutors(appId, execIds); - try { - TransportClient client = clientFactory.createClient(host, port); - client.sendRpc(getLocalDirsMessage.toByteBuffer(), new RpcResponseCallback() { - @Override - public void onSuccess(ByteBuffer response) { - try { - BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteBuffer(response); - hostLocalDirsCompletable.complete( - ((LocalDirsForExecutors) msgObj).getLocalDirsByExec()); - } catch (Throwable t) { - logger.warn("Error trying to get the host local dirs for " + - Arrays.toString(getLocalDirsMessage.execIds) + " via external shuffle service", - t.getCause()); - hostLocalDirsCompletable.completeExceptionally(t); - } - } - - @Override - public void onFailure(Throwable t) { - logger.warn("Error trying to get the host local dirs for " + - Arrays.toString(getLocalDirsMessage.execIds) + " via external shuffle service", - t.getCause()); - hostLocalDirsCompletable.completeExceptionally(t); - } - }); - } catch (IOException | InterruptedException e) { - hostLocalDirsCompletable.completeExceptionally(e); - } - } - @Override public void close() { checkInit(); diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index b308115935..4b6770b319 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -1415,10 +1415,9 @@ package object config { private[spark] val SHUFFLE_HOST_LOCAL_DISK_READING_ENABLED = ConfigBuilder("spark.shuffle.readHostLocalDisk") - .doc(s"If enabled (and `${SHUFFLE_USE_OLD_FETCH_PROTOCOL.key}` is disabled and external " + - s"shuffle `${SHUFFLE_SERVICE_ENABLED.key}` is enabled), shuffle " + - "blocks requested from those block managers which are running on the same host are read " + - "from the disk directly instead of being fetched as remote blocks over the network.") + .doc(s"If enabled (and `${SHUFFLE_USE_OLD_FETCH_PROTOCOL.key}` is disabled), shuffle " + + "blocks requested from those block managers which are running on the same host are " + + "read from the disk directly instead of being fetched as remote blocks over the network.") .version("3.0.0") .booleanConf .createWithDefault(true) diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala index 0bd5774b63..62fbc16616 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala @@ -27,6 +27,11 @@ import org.apache.spark.storage.{BlockId, ShuffleBlockId, StorageLevel} private[spark] trait BlockDataManager { + /** + * Get the local directories that are used by the BlockManager to save blocks to disk. + */ + def getLocalDiskDirs: Array[String] + /** * Interface to get host-local shuffle block data. Throws an exception if the block cannot be * found or cannot be read successfully.
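As a usage sketch (illustrative, not part of the patch), host-local disk reading is driven by the spark.shuffle.readHostLocalDisk flag documented above together with the old-protocol flag; here `spark.shuffle.useOldFetchProtocol` is assumed to be the key behind SHUFFLE_USE_OLD_FETCH_PROTOCOL, and with this change the external shuffle service is no longer a prerequisite.

import org.apache.spark.SparkConf

val conf = new SparkConf()
  // Default is already true; shown only for explicitness.
  .set("spark.shuffle.readHostLocalDisk", "true")
  // Must stay disabled (its default) for host-local reads to be used.
  .set("spark.shuffle.useOldFetchProtocol", "false")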
diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index 70a159f3ee..98129b62b5 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -34,7 +34,7 @@ import org.apache.spark.util.ThreadUtils * BlockTransferService contains both client and server inside. */ private[spark] -abstract class BlockTransferService extends BlockStoreClient with Logging { +abstract class BlockTransferService extends BlockStoreClient { /** * Initialize the transfer service by giving it the BlockDataManager that can be used to fetch diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index 62726f7e14..5f831dc666 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -29,7 +29,7 @@ import org.apache.spark.network.client.{RpcResponseCallback, StreamCallbackWithI import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager} import org.apache.spark.network.shuffle.protocol._ import org.apache.spark.serializer.Serializer -import org.apache.spark.storage.{BlockId, ShuffleBlockBatchId, ShuffleBlockId, StorageLevel} +import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockBatchId, ShuffleBlockId, StorageLevel} /** * Serves requests to open blocks by simply registering one chunk per block requested. @@ -113,6 +113,26 @@ class NettyBlockRpcServer( s"when there is not sufficient space available to store the block.") responseContext.onFailure(exception) } + + case getLocalDirs: GetLocalDirsForExecutors => + val isIncorrectAppId = getLocalDirs.appId != appId + val execNum = getLocalDirs.execIds.length + if (isIncorrectAppId || execNum != 1) { + val errorMsg = "Invalid GetLocalDirsForExecutors request: " + + s"${if (isIncorrectAppId) s"incorrect application id: ${getLocalDirs.appId};"}" + + s"${if (execNum != 1) s"incorrect executor number: $execNum (expected 1);"}" + responseContext.onFailure(new IllegalStateException(errorMsg)) + } else { + val expectedExecId = blockManager.asInstanceOf[BlockManager].executorId + val actualExecId = getLocalDirs.execIds.head + if (actualExecId != expectedExecId) { + responseContext.onFailure(new IllegalStateException( + s"Invalid executor id: $actualExecId, expected $expectedExecId.")) + } else { + responseContext.onSuccess(new LocalDirsForExecutors( + Map(actualExecId -> blockManager.getLocalDiskDirs).asJava).toByteBuffer) + } + } } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 5d9cea068b..806fbf5279 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -19,7 +19,9 @@ package org.apache.spark.network.netty import java.io.IOException import java.nio.ByteBuffer +import java.util import java.util.{HashMap => JHashMap, Map => JMap} +import java.util.concurrent.CompletableFuture import scala.collection.JavaConverters._ import scala.concurrent.{Future, Promise} @@ -33,11 +35,11 @@ import org.apache.spark.ExecutorDeadException import 
org.apache.spark.internal.config import org.apache.spark.network._ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} -import org.apache.spark.network.client.{RpcResponseCallback, TransportClientBootstrap, TransportClientFactory} +import org.apache.spark.network.client.{RpcResponseCallback, TransportClient, TransportClientBootstrap, TransportClientFactory} import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap} import org.apache.spark.network.server._ import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, OneForOneBlockFetcher, RetryingBlockFetcher} -import org.apache.spark.network.shuffle.protocol.{UploadBlock, UploadBlockStream} +import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, GetLocalDirsForExecutors, LocalDirsForExecutors, UploadBlock, UploadBlockStream} import org.apache.spark.network.util.JavaUtils import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.serializer.JavaSerializer @@ -65,8 +67,6 @@ private[spark] class NettyBlockTransferService( private[this] var transportContext: TransportContext = _ private[this] var server: TransportServer = _ - private[this] var clientFactory: TransportClientFactory = _ - private[this] var appId: String = _ override def init(blockDataManager: BlockDataManager): Unit = { val rpcHandler = new NettyBlockRpcServer(conf.getAppId, serializer, blockDataManager) @@ -80,7 +80,7 @@ private[spark] class NettyBlockTransferService( clientFactory = transportContext.createClientFactory(clientBootstrap.toSeq.asJava) server = createServer(serverBootstrap.toList) appId = conf.getAppId - logInfo(s"Server created on ${hostName}:${server.getPort}") + logger.info(s"Server created on $hostName:${server.getPort}") } /** Creates and binds the TransportServer, possibly trying multiple ports. 
*/ @@ -113,7 +113,9 @@ private[spark] class NettyBlockTransferService( blockIds: Array[String], listener: BlockFetchingListener, tempFileManager: DownloadFileManager): Unit = { - logTrace(s"Fetch blocks from $host:$port (executor id $execId)") + if (logger.isTraceEnabled) { + logger.trace(s"Fetch blocks from $host:$port (executor id $execId)") + } try { val maxRetries = transportConf.maxIORetries() val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter { @@ -146,7 +148,7 @@ private[spark] class NettyBlockTransferService( } } catch { case e: Exception => - logError("Exception while beginning fetchBlocks", e) + logger.error("Exception while beginning fetchBlocks", e) blockIds.foreach(listener.onBlockFetchFailure(_, e)) } } @@ -174,12 +176,14 @@ private[spark] class NettyBlockTransferService( blockId.isShuffle) val callback = new RpcResponseCallback { override def onSuccess(response: ByteBuffer): Unit = { - logTrace(s"Successfully uploaded block $blockId${if (asStream) " as stream" else ""}") + if (logger.isTraceEnabled) { + logger.trace(s"Successfully uploaded block $blockId${if (asStream) " as stream" else ""}") + } result.success((): Unit) } override def onFailure(e: Throwable): Unit = { - logError(s"Error while uploading $blockId${if (asStream) " as stream" else ""}", e) + logger.error(s"Error while uploading $blockId${if (asStream) " as stream" else ""}", e) result.failure(e) } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 5072340f33..ff0f38a247 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -120,9 +120,7 @@ private[spark] class ByteBufferBlockData( private[spark] class HostLocalDirManager( futureExecutionContext: ExecutionContext, cacheSize: Int, - externalBlockStoreClient: ExternalBlockStoreClient, - host: String, - externalShuffleServicePort: Int) extends Logging { + blockStoreClient: BlockStoreClient) extends Logging { private val executorIdToLocalDirsCache = CacheBuilder @@ -130,24 +128,25 @@ private[spark] class HostLocalDirManager( .maximumSize(cacheSize) .build[String, Array[String]]() - private[spark] def getCachedHostLocalDirs() - : scala.collection.Map[String, Array[String]] = executorIdToLocalDirsCache.synchronized { - import scala.collection.JavaConverters._ - return executorIdToLocalDirsCache.asMap().asScala - } + private[spark] def getCachedHostLocalDirs: Map[String, Array[String]] = + executorIdToLocalDirsCache.synchronized { + executorIdToLocalDirsCache.asMap().asScala.toMap + } private[spark] def getHostLocalDirs( + host: String, + port: Int, executorIds: Array[String])( - callback: Try[java.util.Map[String, Array[String]]] => Unit): Unit = { + callback: Try[Map[String, Array[String]]] => Unit): Unit = { val hostLocalDirsCompletable = new CompletableFuture[java.util.Map[String, Array[String]]] - externalBlockStoreClient.getHostLocalDirs( + blockStoreClient.getHostLocalDirs( host, - externalShuffleServicePort, + port, executorIds, hostLocalDirsCompletable) hostLocalDirsCompletable.whenComplete { (hostLocalDirs, throwable) => if (hostLocalDirs != null) { - callback(Success(hostLocalDirs)) + callback(Success(hostLocalDirs.asScala.toMap)) executorIdToLocalDirsCache.synchronized { executorIdToLocalDirsCache.putAll(hostLocalDirs) } @@ -165,7 +164,7 @@ private[spark] class HostLocalDirManager( * Note that [[initialize()]] must be called before the BlockManager 
is usable. */ private[spark] class BlockManager( - executorId: String, + val executorId: String, rpcEnv: RpcEnv, val master: BlockManagerMaster, val serializerManager: SerializerManager, @@ -212,7 +211,7 @@ private[spark] class BlockManager( private val maxOnHeapMemory = memoryManager.maxOnHeapStorageMemory private val maxOffHeapMemory = memoryManager.maxOffHeapStorageMemory - private val externalShuffleServicePort = StorageUtils.externalShuffleServicePort(conf) + private[spark] val externalShuffleServicePort = StorageUtils.externalShuffleServicePort(conf) var blockManagerId: BlockManagerId = _ @@ -265,6 +264,8 @@ private[spark] class BlockManager( shuffleManager.shuffleBlockResolver.asInstanceOf[MigratableResolver] } + override def getLocalDiskDirs: Array[String] = diskBlockManager.localDirsString + /** * Abstraction for storing blocks from bytes, whether they start in memory or on disk. * @@ -492,20 +493,17 @@ private[spark] class BlockManager( registerWithExternalShuffleServer() } - hostLocalDirManager = + hostLocalDirManager = { if (conf.get(config.SHUFFLE_HOST_LOCAL_DISK_READING_ENABLED) && !conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)) { - externalBlockStoreClient.map { blockStoreClient => - new HostLocalDirManager( - futureExecutionContext, - conf.get(config.STORAGE_LOCAL_DISK_BY_EXECUTORS_CACHE_SIZE), - blockStoreClient, - blockManagerId.host, - externalShuffleServicePort) - } + Some(new HostLocalDirManager( + futureExecutionContext, + conf.get(config.STORAGE_LOCAL_DISK_BY_EXECUTORS_CACHE_SIZE), + blockStoreClient)) } else { None } + } logInfo(s"Initialized BlockManager: $blockManagerId") } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index a2843da056..57b6a38ae6 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -210,13 +210,18 @@ final class ShuffleBlockFetcherIterator( while (iter.hasNext) { val result = iter.next() result match { - case SuccessFetchResult(_, _, address, _, buf, _) => + case SuccessFetchResult(blockId, mapIndex, address, _, buf, _) => if (address != blockManager.blockManagerId) { - shuffleMetrics.incRemoteBytesRead(buf.size) - if (buf.isInstanceOf[FileSegmentManagedBuffer]) { - shuffleMetrics.incRemoteBytesReadToDisk(buf.size) + if (hostLocalBlocks.contains(blockId -> mapIndex)) { + shuffleMetrics.incLocalBlocksFetched(1) + shuffleMetrics.incLocalBytesRead(buf.size) + } else { + shuffleMetrics.incRemoteBytesRead(buf.size) + if (buf.isInstanceOf[FileSegmentManagedBuffer]) { + shuffleMetrics.incRemoteBytesReadToDisk(buf.size) + } + shuffleMetrics.incRemoteBlocksFetched(1) } - shuffleMetrics.incRemoteBlocksFetched(1) } buf.release() case _ => @@ -290,9 +295,6 @@ final class ShuffleBlockFetcherIterator( var hostLocalBlockBytes = 0L var remoteBlockBytes = 0L - val hostLocalDirReadingEnabled = - blockManager.hostLocalDirManager != null && blockManager.hostLocalDirManager.isDefined - for ((address, blockInfos) <- blocksByAddress) { if (address.executorId == blockManager.blockManagerId.executorId) { checkBlockSizes(blockInfos) @@ -301,7 +303,8 @@ final class ShuffleBlockFetcherIterator( numBlocksToFetch += mergedBlockInfos.size localBlocks ++= mergedBlockInfos.map(info => (info.blockId, info.mapIndex)) localBlockBytes += mergedBlockInfos.map(_.size).sum - } else if (hostLocalDirReadingEnabled && address.host == 
blockManager.blockManagerId.host) { + } else if (blockManager.hostLocalDirManager.isDefined && + address.host == blockManager.blockManagerId.host) { checkBlockSizes(blockInfos) val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded( blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch) @@ -463,54 +466,72 @@ final class ShuffleBlockFetcherIterator( * track in-memory are the ManagedBuffer references themselves. */ private[this] def fetchHostLocalBlocks(hostLocalDirManager: HostLocalDirManager): Unit = { - val cachedDirsByExec = hostLocalDirManager.getCachedHostLocalDirs() - val (hostLocalBlocksWithCachedDirs, hostLocalBlocksWithMissingDirs) = - hostLocalBlocksByExecutor - .map { case (hostLocalBmId, bmInfos) => - (hostLocalBmId, bmInfos, cachedDirsByExec.get(hostLocalBmId.executorId)) - }.partition(_._3.isDefined) - val bmId = blockManager.blockManagerId - val immutableHostLocalBlocksWithoutDirs = - hostLocalBlocksWithMissingDirs.map { case (hostLocalBmId, bmInfos, _) => - hostLocalBmId -> bmInfos - }.toMap - if (immutableHostLocalBlocksWithoutDirs.nonEmpty) { - logDebug(s"Asynchronous fetching host-local blocks without cached executors' dir: " + - s"${immutableHostLocalBlocksWithoutDirs.mkString(", ")}") - val execIdsWithoutDirs = immutableHostLocalBlocksWithoutDirs.keys.map(_.executorId).toArray - hostLocalDirManager.getHostLocalDirs(execIdsWithoutDirs) { - case Success(dirs) => - immutableHostLocalBlocksWithoutDirs.foreach { case (hostLocalBmId, blockInfos) => - blockInfos.takeWhile { case (blockId, _, mapIndex) => - fetchHostLocalBlock( - blockId, - mapIndex, - dirs.get(hostLocalBmId.executorId), - hostLocalBmId) - } - } - logDebug(s"Got host-local blocks (without cached executors' dir) in " + - s"${Utils.getUsedTimeNs(startTimeNs)}") + val cachedDirsByExec = hostLocalDirManager.getCachedHostLocalDirs + val (hostLocalBlocksWithCachedDirs, hostLocalBlocksWithMissingDirs) = { + val (hasCache, noCache) = hostLocalBlocksByExecutor.partition { case (hostLocalBmId, _) => + cachedDirsByExec.contains(hostLocalBmId.executorId) + } + (hasCache.toMap, noCache.toMap) + } - case Failure(throwable) => - logError(s"Error occurred while fetching host local blocks", throwable) - val (hostLocalBmId, blockInfoSeq) = immutableHostLocalBlocksWithoutDirs.head - val (blockId, _, mapIndex) = blockInfoSeq.head - results.put(FailureFetchResult(blockId, mapIndex, hostLocalBmId, throwable)) + if (hostLocalBlocksWithMissingDirs.nonEmpty) { + logDebug(s"Asynchronous fetching host-local blocks without cached executors' dir: " + + s"${hostLocalBlocksWithMissingDirs.mkString(", ")}") + + // If the external shuffle service is enabled, we fetch the local directories for + // multiple executors from the external shuffle service, which is located on the same host + // as the executors, in a single request. Otherwise, we fetch the local directories from + // those executors directly, one by one. There won't be many of these requests, since in + // practice it is rare for a single host to run many executors at the same time.
+ val dirFetchRequests = if (blockManager.externalShuffleServiceEnabled) { + val host = blockManager.blockManagerId.host + val port = blockManager.externalShuffleServicePort + Seq((host, port, hostLocalBlocksWithMissingDirs.keys.toArray)) + } else { + hostLocalBlocksWithMissingDirs.keys.map(bmId => (bmId.host, bmId.port, Array(bmId))).toSeq + } + + dirFetchRequests.foreach { case (host, port, bmIds) => + hostLocalDirManager.getHostLocalDirs(host, port, bmIds.map(_.executorId)) { + case Success(dirsByExecId) => + fetchMultipleHostLocalBlocks( + hostLocalBlocksWithMissingDirs.filterKeys(bmIds.contains), + dirsByExecId, + cached = false) + + case Failure(throwable) => + logError("Error occurred while fetching host local blocks", throwable) + val bmId = bmIds.head + val blockInfoSeq = hostLocalBlocksWithMissingDirs(bmId) + val (blockId, _, mapIndex) = blockInfoSeq.head + results.put(FailureFetchResult(blockId, mapIndex, bmId, throwable)) + } } } + if (hostLocalBlocksWithCachedDirs.nonEmpty) { logDebug(s"Synchronous fetching host-local blocks with cached executors' dir: " + s"${hostLocalBlocksWithCachedDirs.mkString(", ")}") - hostLocalBlocksWithCachedDirs.foreach { case (_, blockInfos, localDirs) => - blockInfos.foreach { case (blockId, _, mapIndex) => - if (!fetchHostLocalBlock(blockId, mapIndex, localDirs.get, bmId)) { - return - } - } + fetchMultipleHostLocalBlocks(hostLocalBlocksWithCachedDirs, cachedDirsByExec, cached = true) + } + } + + private def fetchMultipleHostLocalBlocks( + bmIdToBlocks: Map[BlockManagerId, Seq[(BlockId, Long, Int)]], + localDirsByExecId: Map[String, Array[String]], + cached: Boolean): Unit = { + // We use `forall` because once there's a failed block fetch, `fetchHostLocalBlock` will put + // a `FailureFetchResult` immediately to the `results`. So there's no reason to fetch the + // remaining blocks. 
+ val allFetchSucceeded = bmIdToBlocks.forall { case (bmId, blockInfos) => + blockInfos.forall { case (blockId, _, mapIndex) => + fetchHostLocalBlock(blockId, mapIndex, localDirsByExecId(bmId.executorId), bmId) } - logDebug(s"Got host-local blocks (with cached executors' dir) in " + - s"${Utils.getUsedTimeNs(startTimeNs)}") + } + if (allFetchSucceeded) { + logDebug(s"Got host-local blocks from ${bmIdToBlocks.keys.mkString(", ")} " + + s"(${if (cached) "with" else "without"} cached executors' dir) " + + s"in ${Utils.getUsedTimeNs(startTimeNs)}") } } diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index 9026447e5a..48c1cc5906 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -100,50 +100,6 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll wi e.getMessage should include ("Fetch failure will not retry stage due to testing config") } - test("SPARK-27651: read host local shuffle blocks from disk and avoid network remote fetches") { - val confWithHostLocalRead = - conf.clone.set(config.SHUFFLE_HOST_LOCAL_DISK_READING_ENABLED, true) - confWithHostLocalRead.set(config.STORAGE_LOCAL_DISK_BY_EXECUTORS_CACHE_SIZE, 5) - sc = new SparkContext("local-cluster[2,1,1024]", "test", confWithHostLocalRead) - sc.getConf.get(config.SHUFFLE_HOST_LOCAL_DISK_READING_ENABLED) should equal(true) - sc.env.blockManager.externalShuffleServiceEnabled should equal(true) - sc.env.blockManager.hostLocalDirManager.isDefined should equal(true) - sc.env.blockManager.blockStoreClient.getClass should equal(classOf[ExternalBlockStoreClient]) - - // In a slow machine, one executor may register hundreds of milliseconds ahead of the other one. - // If we don't wait for all executors, it's possible that only one executor runs all jobs. Then - // all shuffle blocks will be in this executor, ShuffleBlockFetcherIterator will directly fetch - // local blocks from the local BlockManager and won't send requests to ExternalShuffleService. - // In this case, we won't receive FetchFailed. And it will make this test fail. - // Therefore, we should wait until all executors are up - TestUtils.waitUntilExecutorsUp(sc, 2, 60000) - - val rdd = sc.parallelize(0 until 1000, 10) - .map { i => (i, 1) } - .reduceByKey(_ + _) - - rdd.count() - rdd.count() - - val cachedExecutors = rdd.mapPartitions { _ => - SparkEnv.get.blockManager.hostLocalDirManager.map { localDirManager => - localDirManager.getCachedHostLocalDirs().keySet.iterator - }.getOrElse(Iterator.empty) - }.collect().toSet - - // both executors are caching the dirs of the other one - cachedExecutors should equal(sc.getExecutorIds().toSet) - - // Invalidate the registered executors, disallowing access to their shuffle blocks (without - // deleting the actual shuffle files, so we could access them without the shuffle service). - // As directories are already cached there is no request to external shuffle service. 
- rpcHandler.applicationRemoved(sc.conf.getAppId, false /* cleanupLocalDirs */) - - // Now Spark will not receive FetchFailed as host local blocks are read from the cached local - // disk directly - rdd.collect().map(_._2).sum should equal(1000) - } - test("SPARK-25888: using external shuffle service fetching disk persisted blocks") { val confWithRddFetchEnabled = conf.clone.set(config.SHUFFLE_SERVICE_FETCH_RDD_ENABLED, true) sc = new SparkContext("local-cluster[1,1,1024]", "test", confWithRddFetchEnabled) diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala index baa878eb14..fa1a75d076 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala @@ -118,8 +118,8 @@ class NettyBlockTransferServiceSuite .thenAnswer(_ => {hitExecutorDeadException = true}) service0 = createService(port, driverEndpointRef) - val clientFactoryField = service0.getClass.getField( - "org$apache$spark$network$netty$NettyBlockTransferService$$clientFactory") + val clientFactoryField = service0.getClass + .getSuperclass.getSuperclass.getDeclaredField("clientFactory") clientFactoryField.setAccessible(true) clientFactoryField.set(service0, clientFactory) diff --git a/core/src/test/scala/org/apache/spark/shuffle/HostLocalShuffleReadingSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/HostLocalShuffleReadingSuite.scala new file mode 100644 index 0000000000..12c40f4462 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/HostLocalShuffleReadingSuite.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import org.scalatest.matchers.must.Matchers +import org.scalatest.matchers.should.Matchers._ + +import org.apache.spark._ +import org.apache.spark.internal.config._ +import org.apache.spark.network.TransportContext +import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf} +import org.apache.spark.network.server.TransportServer +import org.apache.spark.network.shuffle.{ExternalBlockHandler, ExternalBlockStoreClient} +import org.apache.spark.util.Utils + +/** + * This is an end-to-end test suite used to test host-local shuffle reading.
+ */ +class HostLocalShuffleReadingSuite extends SparkFunSuite with Matchers with LocalSparkContext { + var rpcHandler: ExternalBlockHandler = _ + var server: TransportServer = _ + var transportContext: TransportContext = _ + + override def afterEach(): Unit = { + Option(rpcHandler).foreach { handler => + Utils.tryLogNonFatalError { + server.close() + } + Utils.tryLogNonFatalError { + handler.close() + } + Utils.tryLogNonFatalError { + transportContext.close() + } + server = null + rpcHandler = null + transportContext = null + } + super.afterEach() + } + + Seq(true, false).foreach { isESSEnabled => /* ESS: external shuffle service */ + val conf = new SparkConf() + .set(SHUFFLE_HOST_LOCAL_DISK_READING_ENABLED, true) + + val (essStatus, blockStoreClientClass) = if (isESSEnabled) { + // LocalSparkCluster will disable the ExternalShuffleService by default. Therefore, + // we have to manually set up a server embedded with an ExternalBlockHandler to + // mimic an ExternalShuffleService. Then, executors on the Worker can successfully + // find an ExternalShuffleService to connect to. + val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores = 2) + rpcHandler = new ExternalBlockHandler(transportConf, null) + transportContext = new TransportContext(transportConf, rpcHandler) + server = transportContext.createServer() + conf.set(SHUFFLE_SERVICE_PORT, server.getPort) + + ("enabled (SPARK-27651)", classOf[ExternalBlockStoreClient]) + } else { + ("disabled (SPARK-32077)", classOf[NettyBlockTransferService]) + } + + test(s"host local shuffle reading with external shuffle service $essStatus") { + conf.set(SHUFFLE_SERVICE_ENABLED, isESSEnabled) + .set(STORAGE_LOCAL_DISK_BY_EXECUTORS_CACHE_SIZE, 5) + sc = new SparkContext("local-cluster[2,1,1024]", "test-host-local-shuffle-reading", conf) + // On a slow machine, one executor may register hundreds of milliseconds ahead of the other + // one. If we don't wait for all executors, it's possible that only one executor runs all + // jobs. Then all shuffle blocks will be in this executor, ShuffleBlockFetcherIterator will + // directly fetch local blocks from the local BlockManager and won't send requests to + // BlockStoreClient. In this case, we won't receive FetchFailed, and that will make this + // test fail. Therefore, we should wait until all executors are up. + TestUtils.waitUntilExecutorsUp(sc, 2, 60000) + + sc.getConf.get(SHUFFLE_HOST_LOCAL_DISK_READING_ENABLED) should equal(true) + sc.env.blockManager.externalShuffleServiceEnabled should equal(isESSEnabled) + sc.env.blockManager.hostLocalDirManager.isDefined should equal(true) + sc.env.blockManager.blockStoreClient.getClass should equal(blockStoreClientClass) + + val rdd = sc.parallelize(0 until 1000, 10) + .map { i => + SparkEnv.get.blockManager.hostLocalDirManager.map { localDirManager => + // No shuffle fetch yet.
So the cache must be empty + assert(localDirManager.getCachedHostLocalDirs.isEmpty) + } + (i, 1) + }.reduceByKey(_ + _) + + // raise a job and trigger the shuffle fetching during the job + assert(rdd.count() === 1000) + + val cachedExecutors = rdd.mapPartitions { _ => + SparkEnv.get.blockManager.hostLocalDirManager.map { localDirManager => + localDirManager.getCachedHostLocalDirs.keySet.iterator + }.getOrElse(Iterator.empty) + }.collect().toSet + + // both executors are caching the dirs of the other one + cachedExecutors should equal(sc.getExecutorIds().toSet) + + Option(rpcHandler).foreach { handler => + // Invalidate the registered executors, disallowing access to their shuffle blocks (without + // deleting the actual shuffle files, so we could access them without the shuffle service). + // As directories are already cached there is no request to external shuffle service. + handler.applicationRemoved(sc.conf.getAppId, false /* cleanupLocalDirs */) + } + + val (local, remote) = rdd.map { case (_, _) => + val shuffleReadMetrics = TaskContext.get().taskMetrics().shuffleReadMetrics + ((shuffleReadMetrics.localBytesRead, shuffleReadMetrics.localBlocksFetched), + (shuffleReadMetrics.remoteBytesRead, shuffleReadMetrics.remoteBlocksFetched)) + }.collect().unzip + // Spark should read the shuffle data locally from the cached directories on the same host, + // so there's no remote fetching at all. + val (localBytesRead, localBlocksFetched) = local.unzip + val (remoteBytesRead, remoteBlocksFetched) = remote.unzip + assert(localBytesRead.sum > 0 && localBlocksFetched.sum > 0) + assert(remoteBytesRead.sum === 0 && remoteBlocksFetched.sum === 0) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index bf1379ceb8..99c43b12d6 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -66,6 +66,16 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT transfer } + private def createMockBlockManager(): BlockManager = { + val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-local-host", 1) + doReturn(localBmId).when(blockManager).blockManagerId + // By default, the mock BlockManager returns None for hostLocalDirManager. One could + // still use initHostLocalDirManager() to specify a custom hostLocalDirManager. 
+ doReturn(None).when(blockManager).hostLocalDirManager + blockManager + } + private def initHostLocalDirManager( blockManager: BlockManager, hostLocalDirs: Map[String, Array[String]]): Unit = { @@ -73,9 +83,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val hostLocalDirManager = new HostLocalDirManager( futureExecutionContext = global, cacheSize = 1, - externalBlockStoreClient = mockExternalBlockStoreClient, - host = "localhost", - externalShuffleServicePort = 7337) + blockStoreClient = mockExternalBlockStoreClient) when(blockManager.hostLocalDirManager).thenReturn(Some(hostLocalDirManager)) when(mockExternalBlockStoreClient.getHostLocalDirs(any(), any(), any(), any())) @@ -116,9 +124,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } test("successful 3 local + 4 host local + 2 remote reads") { - val blockManager = mock(classOf[BlockManager]) - val localBmId = BlockManagerId("test-local-client", "test-local-host", 1) - doReturn(localBmId).when(blockManager).blockManagerId + val blockManager = createMockBlockManager() + val localBmId = blockManager.blockManagerId // Make sure blockManager.getBlockData would return the blocks val localBlocks = Map[BlockId, ManagedBuffer]( @@ -197,13 +204,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // 2 remote blocks are read from the same block manager verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any()) - assert(blockManager.hostLocalDirManager.get.getCachedHostLocalDirs().size === 1) + assert(blockManager.hostLocalDirManager.get.getCachedHostLocalDirs.size === 1) } test("error during accessing host local dirs for executors") { - val blockManager = mock(classOf[BlockManager]) - val localBmId = BlockManagerId("test-local-client", "test-local-host", 1) - doReturn(localBmId).when(blockManager).blockManagerId + val blockManager = createMockBlockManager() val hostLocalBlocks = Map[BlockId, ManagedBuffer]( ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer()) @@ -218,9 +223,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val hostLocalDirManager = new HostLocalDirManager( futureExecutionContext = global, cacheSize = 1, - externalBlockStoreClient = mockExternalBlockStoreClient, - host = "localhost", - externalShuffleServicePort = 7337) + blockStoreClient = mockExternalBlockStoreClient) when(blockManager.hostLocalDirManager).thenReturn(Some(hostLocalDirManager)) when(mockExternalBlockStoreClient.getHostLocalDirs(any(), any(), any(), any())) @@ -256,10 +259,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } test("Hit maxBytesInFlight limitation before maxBlocksInFlightPerAddress") { - val blockManager = mock(classOf[BlockManager]) - val localBmId = BlockManagerId("test-client", "test-local-host", 1) - doReturn(localBmId).when(blockManager).blockManagerId - + val blockManager = createMockBlockManager() val remoteBmId1 = BlockManagerId("test-remote-client-1", "test-remote-host1", 1) val remoteBmId2 = BlockManagerId("test-remote-client-2", "test-remote-host2", 2) val blockId1 = ShuffleBlockId(0, 1, 0) @@ -301,10 +301,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } test("Hit maxBlocksInFlightPerAddress limitation before maxBytesInFlight") { - val blockManager = mock(classOf[BlockManager]) - val localBmId = BlockManagerId("test-client", "test-local-host", 1) - doReturn(localBmId).when(blockManager).blockManagerId 
- + val blockManager = createMockBlockManager() val remoteBmId = BlockManagerId("test-remote-client-1", "test-remote-host", 2) val blockId1 = ShuffleBlockId(0, 1, 0) val blockId2 = ShuffleBlockId(0, 2, 0) @@ -348,10 +345,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } test("fetch continuous blocks in batch successful 3 local + 4 host local + 2 remote reads") { - val blockManager = mock(classOf[BlockManager]) - val localBmId = BlockManagerId("test-client", "test-local-host", 1) - doReturn(localBmId).when(blockManager).blockManagerId - + val blockManager = createMockBlockManager() + val localBmId = blockManager.blockManagerId // Make sure blockManager.getBlockData would return the merged block val localBlocks = Seq[BlockId]( ShuffleBlockId(0, 0, 0), @@ -431,14 +426,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT verify(blockManager, times(1)) .getHostLocalShuffleData(any(), meq(Array("local-dir"))) - assert(blockManager.hostLocalDirManager.get.getCachedHostLocalDirs().size === 1) + assert(blockManager.hostLocalDirManager.get.getCachedHostLocalDirs.size === 1) } test("fetch continuous blocks in batch should respect maxBytesInFlight") { - val blockManager = mock(classOf[BlockManager]) - val localBmId = BlockManagerId("test-client", "test-local-host", 1) - doReturn(localBmId).when(blockManager).blockManagerId - + val blockManager = createMockBlockManager() // Make sure remote blocks would return the merged block val remoteBmId1 = BlockManagerId("test-client-1", "test-client-1", 1) val remoteBmId2 = BlockManagerId("test-client-2", "test-client-2", 2) @@ -494,10 +486,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } test("fetch continuous blocks in batch should respect maxBlocksInFlightPerAddress") { - val blockManager = mock(classOf[BlockManager]) - val localBmId = BlockManagerId("test-client", "test-local-host", 1) - doReturn(localBmId).when(blockManager).blockManagerId - + val blockManager = createMockBlockManager() // Make sure remote blocks would return the merged block val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 1) val remoteBlocks = Seq( @@ -549,10 +538,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } test("release current unexhausted buffer in case the task completes early") { - val blockManager = mock(classOf[BlockManager]) - val localBmId = BlockManagerId("test-client", "test-client", 1) - doReturn(localBmId).when(blockManager).blockManagerId - + val blockManager = createMockBlockManager() // Make sure remote blocks would return val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) val blocks = Map[BlockId, ManagedBuffer]( @@ -617,10 +603,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } test("fail all blocks if any of the remote request fails") { - val blockManager = mock(classOf[BlockManager]) - val localBmId = BlockManagerId("test-client", "test-client", 1) - doReturn(localBmId).when(blockManager).blockManagerId - + val blockManager = createMockBlockManager() // Make sure remote blocks would return val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) val blocks = Map[BlockId, ManagedBuffer]( @@ -707,10 +690,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } test("retry corrupt blocks") { - val blockManager = mock(classOf[BlockManager]) - val localBmId = BlockManagerId("test-client", "test-client", 1) - 
doReturn(localBmId).when(blockManager).blockManagerId - + val blockManager = createMockBlockManager() // Make sure remote blocks would return val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) val blocks = Map[BlockId, ManagedBuffer]( @@ -785,9 +765,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT test("big blocks are also checked for corruption") { val streamLength = 10000L - val blockManager = mock(classOf[BlockManager]) - val localBlockManagerId = BlockManagerId("local-client", "local-client", 1) - doReturn(localBlockManagerId).when(blockManager).blockManagerId + val blockManager = createMockBlockManager() // This stream will throw IOException when the first byte is read val corruptBuffer1 = mockCorruptBuffer(streamLength, 0) @@ -906,10 +884,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } test("retry corrupt blocks (disabled)") { - val blockManager = mock(classOf[BlockManager]) - val localBmId = BlockManagerId("test-client", "test-client", 1) - doReturn(localBmId).when(blockManager).blockManagerId - + val blockManager = createMockBlockManager() // Make sure remote blocks would return val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) val blocks = Map[BlockId, ManagedBuffer]( @@ -971,10 +946,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT test("Blocks should be shuffled to disk when size of the request is above the" + " threshold(maxReqSizeShuffleToMem).") { - val blockManager = mock(classOf[BlockManager]) - val localBmId = BlockManagerId("test-client", "test-client", 1) - doReturn(localBmId).when(blockManager).blockManagerId - + val blockManager = createMockBlockManager() val diskBlockManager = mock(classOf[DiskBlockManager]) val tmpDir = Utils.createTempDir() doReturn{ @@ -1036,10 +1008,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } test("fail zero-size blocks") { - val blockManager = mock(classOf[BlockManager]) - val localBmId = BlockManagerId("test-client", "test-client", 1) - doReturn(localBmId).when(blockManager).blockManagerId - + val blockManager = createMockBlockManager() // Make sure remote blocks would return val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) val blocks = Map[BlockId, ManagedBuffer](