diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
index e9197f7169..3aac8e50b4 100644
--- a/core/src/main/scala/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -2,16 +2,14 @@ package spark.storage
 
 import java.io._
 import java.nio._
-import java.nio.channels.FileChannel.MapMode
-import java.util.{HashMap => JHashMap}
-import java.util.LinkedHashMap
 import java.util.concurrent.ConcurrentHashMap
 import java.util.concurrent.LinkedBlockingQueue
 import java.util.Collections
 
 import akka.dispatch.{Await, Future}
 
 import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
+import scala.collection.mutable.{HashMap, HashSet}
+import scala.collection.mutable.Queue
 import scala.collection.JavaConversions._
 import it.unimi.dsi.fastutil.io._
@@ -273,28 +271,19 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
     logDebug("Getting " + totalBlocks + " blocks")
     var startTime = System.currentTimeMillis
     val localBlockIds = new ArrayBuffer[String]()
-    val remoteBlockIds = new ArrayBuffer[String]()
-    val remoteBlockIdsPerLocation = new HashMap[BlockManagerId, Seq[String]]()
+    val remoteBlockIds = new HashSet[String]()
 
     // A queue to hold our results. Because we want all the deserializing the happen in the
     // caller's thread, this will actually hold functions to produce the Iterator for each block.
     // For local blocks we'll have an iterator already, while for remote ones we'll deserialize.
     val results = new LinkedBlockingQueue[(String, Option[() => Iterator[Any]])]
 
-    // Split local and remote blocks
-    for ((address, blockIds) <- blocksByAddress) {
-      if (address == blockManagerId) {
-        localBlockIds ++= blockIds
-      } else {
-        remoteBlockIds ++= blockIds
-        remoteBlockIdsPerLocation(address) = blockIds
-      }
-    }
-
-    // Start getting remote blocks
-    for ((bmId, bIds) <- remoteBlockIdsPerLocation) {
-      val cmId = ConnectionManagerId(bmId.ip, bmId.port)
-      val blockMessages = bIds.map(bId => BlockMessage.fromGetBlock(GetBlock(bId)))
+    // Bound the number and memory usage of fetched remote blocks.
+    val parallelFetches = BlockManager.getNumParallelFetchesFromSystemProperties
+    val blocksToRequest = new Queue[(BlockManagerId, BlockMessage)]
+
+    def sendRequest(bmId: BlockManagerId, blockMessages: Seq[BlockMessage]) {
+      val cmId = new ConnectionManagerId(bmId.ip, bmId.port)
       val blockMessageArray = new BlockMessageArray(blockMessages)
       val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
       future.onSuccess {
@@ -312,17 +301,43 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
           })
         }
         case None => {
-          logError("Could not get blocks from " + cmId)
-          for (blockId <- bIds) {
-            results.put((blockId, None))
+          logError("Could not get block(s) from " + cmId)
+          for (blockMessage <- blockMessages) {
+            results.put((blockMessage.getId, None))
           }
         }
       }
     }
-    logDebug("Started remote gets for " + remoteBlockIds.size + " blocks in " +
+
+    // Split local and remote blocks. Remote blocks are further split into ones that will
+    // be requested initially and ones that will be added to a queue of blocks to request.
+    val initialRequestBlocks = new HashMap[BlockManagerId, ArrayBuffer[BlockMessage]]()
+    var initialRequests = 0
+    for ((address, blockIds) <- blocksByAddress) {
+      if (address == blockManagerId) {
+        localBlockIds ++= blockIds
+      } else {
+        remoteBlockIds ++= blockIds
+        blockIds.foreach{blockId =>
+          val blockMessage = BlockMessage.fromGetBlock(GetBlock(blockId))
+          if (initialRequests < parallelFetches) {
+            initialRequestBlocks.getOrElseUpdate(address, new ArrayBuffer[BlockMessage])
+              .append(blockMessage)
+            initialRequests += 1
+          } else {
+            blocksToRequest.enqueue((address, blockMessage))
+          }
+        }
+      }
+    }
+
+    // Send out initial request(s) for 'parallelFetches' blocks.
+    for ((bmId, blockMessages) <- initialRequestBlocks) { sendRequest(bmId, blockMessages) }
+
+    logDebug("Started remote gets for " + parallelFetches + " blocks in " +
       Utils.getUsedTimeMs(startTime) + " ms")
 
-    // Get the local blocks while remote blocks are being fetched
+    // Get the local blocks while remote blocks are being fetched.
     startTime = System.currentTimeMillis
     localBlockIds.foreach(id => {
       getLocal(id) match {
@@ -337,7 +352,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
     })
     logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
 
-    // Return an iterator that will read fetched blocks off the queue as they arrive
+    // Return an iterator that will read fetched blocks off the queue as they arrive.
     return new Iterator[(String, Option[Iterator[Any]])] {
       var resultsGotten = 0
 
@@ -346,6 +361,10 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
       def next(): (String, Option[Iterator[Any]]) = {
         resultsGotten += 1
         val (blockId, functionOption) = results.take()
+        if (remoteBlockIds.contains(blockId) && !blocksToRequest.isEmpty) {
+          val (bmId, blockMessage) = blocksToRequest.dequeue
+          sendRequest(bmId, Seq(blockMessage))
+        }
         (blockId, functionOption.map(_.apply()))
       }
     }
@@ -598,6 +617,11 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
 }
 
 object BlockManager {
+
+  def getNumParallelFetchesFromSystemProperties(): Int = {
+    System.getProperty("spark.blockManager.parallelFetches", "8").toInt
+  }
+
   def getMaxMemoryFromSystemProperties(): Long = {
     val memoryFraction = System.getProperty("spark.storage.memoryFraction", "0.66").toDouble
     (Runtime.getRuntime.maxMemory * memoryFraction).toLong
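
Note on the sketch below: this is an illustrative, standalone sketch of the bounded-fetch pattern the patch introduces, not code from the patch itself. Only the 'spark.blockManager.parallelFetches' property name and its default of "8" are taken from the change; every other name here (BoundedFetchSketch, the simulated sendRequest thread, the fake block ids) is hypothetical. It issues at most 'parallelFetches' simulated requests up front and sends one more queued request each time a result is consumed, mirroring the initial-request loop and the next() hook above.

// Illustrative sketch only -- not part of the patch.
import java.util.concurrent.LinkedBlockingQueue
import scala.collection.mutable.Queue

object BoundedFetchSketch {
  def main(args: Array[String]) {
    // A job could lower or raise the bound before the first shuffle fetch, e.g. to 4.
    System.setProperty("spark.blockManager.parallelFetches", "4")
    val parallelFetches =
      System.getProperty("spark.blockManager.parallelFetches", "8").toInt

    val blockIds = (1 to 10).map("block_" + _)        // stand-ins for remote block ids
    val results = new LinkedBlockingQueue[String]()   // completed "fetches"
    val blocksToRequest = new Queue[String]()         // fetches not yet issued

    // Simulate an asynchronous remote fetch that eventually delivers a result.
    def sendRequest(blockId: String) {
      new Thread(new Runnable {
        def run() {
          Thread.sleep(20)                            // pretend network latency
          results.put(blockId)
        }
      }).start()
    }

    // Issue at most 'parallelFetches' requests up front; queue the rest.
    val (initial, deferred) = blockIds.splitAt(parallelFetches)
    initial.foreach(sendRequest)
    blocksToRequest ++= deferred

    // Consume results in the caller's thread; each result taken frees one slot,
    // so one more queued request is sent, as in next() above.
    for (i <- 1 to blockIds.size) {
      val blockId = results.take()
      if (!blocksToRequest.isEmpty) sendRequest(blocksToRequest.dequeue())
      println("fetched " + blockId)
    }
  }
}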