BlockManager.getMultiple returns a custom iterator, to enable tracking of shuffle performance

Imran Rashid 2013-02-05 13:40:30 -08:00
parent e319ac74c1
commit b29f9cc978
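
As a rough illustration of the intent behind this change (not part of the diff below), a caller on the shuffle-read path could drain the returned BlockFetcherIterator and then read its new profiling accessors. Everything in this sketch except BlockManager.getMultiple, numLocalBlocks and numRemoteBlocks is hypothetical, and it assumes the code lives alongside the Spark classes shown in the diff.

// Sketch only: fetchAndReport and its record counting are made up for illustration.
def fetchAndReport(
    blockManager: BlockManager,
    blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])]): Long = {
  val itr = blockManager.getMultiple(blocksByAddress)  // now returns a BlockFetcherIterator
  var records = 0L
  for ((blockId, blockOption) <- itr) {
    blockOption match {
      case Some(blockData) => blockData.foreach(_ => records += 1)  // consume the block's records
      case None => throw new SparkException("Failed to fetch block " + blockId)
    }
  }
  // Once drained, the iterator can report how the fetch broke down:
  println("Read " + records + " records from " + itr.numLocalBlocks + " local and " +
    itr.numRemoteBlocks + " remote blocks")
  records
}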


@@ -446,152 +446,8 @@ class BlockManager(
* so that we can control the maxBytesInFlight for the fetch.
*/
def getMultiple(blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])])
: Iterator[(String, Option[Iterator[Any]])] = {
if (blocksByAddress == null) {
throw new IllegalArgumentException("BlocksByAddress is null")
}
val totalBlocks = blocksByAddress.map(_._2.size).sum
logDebug("Getting " + totalBlocks + " blocks")
var startTime = System.currentTimeMillis
val localBlockIds = new ArrayBuffer[String]()
val remoteBlockIds = new HashSet[String]()
// A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize
// the block (since we want all deserialization to happen in the calling thread); can also
// represent a fetch failure if size == -1.
class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) {
def failed: Boolean = size == -1
}
// A queue to hold our results.
val results = new LinkedBlockingQueue[FetchResult]
// A request to fetch one or more blocks, complete with their sizes
class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) {
val size = blocks.map(_._2).sum
}
// Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
// the number of bytes in flight is limited to maxBytesInFlight
val fetchRequests = new Queue[FetchRequest]
// Current bytes in flight from our requests
var bytesInFlight = 0L
def sendRequest(req: FetchRequest) {
logDebug("Sending request for %d blocks (%s) from %s".format(
req.blocks.size, Utils.memoryBytesToString(req.size), req.address.ip))
val cmId = new ConnectionManagerId(req.address.ip, req.address.port)
val blockMessageArray = new BlockMessageArray(req.blocks.map {
case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId))
})
bytesInFlight += req.size
val sizeMap = req.blocks.toMap // so we can look up the size of each blockID
val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
future.onSuccess {
case Some(message) => {
val bufferMessage = message.asInstanceOf[BufferMessage]
val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
for (blockMessage <- blockMessageArray) {
if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
throw new SparkException(
"Unexpected message " + blockMessage.getType + " received from " + cmId)
}
val blockId = blockMessage.getId
results.put(new FetchResult(
blockId, sizeMap(blockId), () => dataDeserialize(blockId, blockMessage.getData)))
logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
}
}
case None => {
logError("Could not get block(s) from " + cmId)
for ((blockId, size) <- req.blocks) {
results.put(new FetchResult(blockId, -1, null))
}
}
}
}
// Split local and remote blocks. Remote blocks are further split into FetchRequests of size
// at most maxBytesInFlight in order to limit the amount of data in flight.
val remoteRequests = new ArrayBuffer[FetchRequest]
for ((address, blockInfos) <- blocksByAddress) {
if (address == blockManagerId) {
localBlockIds ++= blockInfos.map(_._1)
} else {
remoteBlockIds ++= blockInfos.map(_._1)
// Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
// smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
// nodes, rather than blocking on reading output from one node.
val minRequestSize = math.max(maxBytesInFlight / 5, 1L)
logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize)
val iterator = blockInfos.iterator
var curRequestSize = 0L
var curBlocks = new ArrayBuffer[(String, Long)]
while (iterator.hasNext) {
val (blockId, size) = iterator.next()
curBlocks += ((blockId, size))
curRequestSize += size
if (curRequestSize >= minRequestSize) {
// Add this FetchRequest
remoteRequests += new FetchRequest(address, curBlocks)
curRequestSize = 0
curBlocks = new ArrayBuffer[(String, Long)]
}
}
// Add in the final request
if (!curBlocks.isEmpty) {
remoteRequests += new FetchRequest(address, curBlocks)
}
}
}
// Add the remote requests into our queue in a random order
fetchRequests ++= Utils.randomize(remoteRequests)
// Send out initial requests for blocks, up to our maxBytesInFlight
while (!fetchRequests.isEmpty &&
(bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
sendRequest(fetchRequests.dequeue())
}
val numGets = remoteBlockIds.size - fetchRequests.size
logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime))
// Get the local blocks while remote blocks are being fetched. Note that it's okay to do
// these all at once because they will just memory-map some files, so they won't consume
// any memory that might exceed our maxBytesInFlight
startTime = System.currentTimeMillis
for (id <- localBlockIds) {
getLocal(id) match {
case Some(iter) => {
results.put(new FetchResult(id, 0, () => iter)) // Pass 0 as size since it's not in flight
logDebug("Got local block " + id)
}
case None => {
throw new BlockException(id, "Could not get block " + id + " from local machine")
}
}
}
logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
// Return an iterator that will read fetched blocks off the queue as they arrive.
return new Iterator[(String, Option[Iterator[Any]])] {
var resultsGotten = 0
def hasNext: Boolean = resultsGotten < totalBlocks
def next(): (String, Option[Iterator[Any]]) = {
resultsGotten += 1
val result = results.take()
bytesInFlight -= result.size
if (!fetchRequests.isEmpty &&
(bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
sendRequest(fetchRequests.dequeue())
}
(result.blockId, if (result.failed) None else Some(result.deserialize()))
}
}
: BlockFetcherIterator = {
return new BlockFetcherIterator(this, blocksByAddress)
}
def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean)
@@ -986,3 +842,165 @@ object BlockManager extends Logging {
}
}
}
class BlockFetcherIterator(
private val blockManager: BlockManager,
val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])]
) extends Iterator[(String, Option[Iterator[Any]])] with Logging {
import blockManager._
private var remoteBytesRead = 0L
if (blocksByAddress == null) {
throw new IllegalArgumentException("BlocksByAddress is null")
}
val totalBlocks = blocksByAddress.map(_._2.size).sum
logDebug("Getting " + totalBlocks + " blocks")
var startTime = System.currentTimeMillis
val localBlockIds = new ArrayBuffer[String]()
val remoteBlockIds = new HashSet[String]()
// A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize
// the block (since we want all deserialization to happen in the calling thread); can also
// represent a fetch failure if size == -1.
class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) {
def failed: Boolean = size == -1
}
// A queue to hold our results.
val results = new LinkedBlockingQueue[FetchResult]
// A request to fetch one or more blocks, complete with their sizes
class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) {
val size = blocks.map(_._2).sum
}
// Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
// the number of bytes in flight is limited to maxBytesInFlight
val fetchRequests = new Queue[FetchRequest]
// Current bytes in flight from our requests
var bytesInFlight = 0L
def sendRequest(req: FetchRequest) {
logDebug("Sending request for %d blocks (%s) from %s".format(
req.blocks.size, Utils.memoryBytesToString(req.size), req.address.ip))
val cmId = new ConnectionManagerId(req.address.ip, req.address.port)
val blockMessageArray = new BlockMessageArray(req.blocks.map {
case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId))
})
bytesInFlight += req.size
val sizeMap = req.blocks.toMap // so we can look up the size of each blockID
val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
future.onSuccess {
case Some(message) => {
val bufferMessage = message.asInstanceOf[BufferMessage]
val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
for (blockMessage <- blockMessageArray) {
if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
throw new SparkException(
"Unexpected message " + blockMessage.getType + " received from " + cmId)
}
val blockId = blockMessage.getId
results.put(new FetchResult(
blockId, sizeMap(blockId), () => dataDeserialize(blockId, blockMessage.getData)))
remoteBytesRead += sizeMap(blockId) // count this block's bytes only, not the whole request
logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
}
}
case None => {
logError("Could not get block(s) from " + cmId)
for ((blockId, size) <- req.blocks) {
results.put(new FetchResult(blockId, -1, null))
}
}
}
}
// Split local and remote blocks. Remote blocks are further split into FetchRequests of size
// at most maxBytesInFlight in order to limit the amount of data in flight.
val remoteRequests = new ArrayBuffer[FetchRequest]
for ((address, blockInfos) <- blocksByAddress) {
if (address == blockManagerId) {
localBlockIds ++= blockInfos.map(_._1)
} else {
remoteBlockIds ++= blockInfos.map(_._1)
// Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
// smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
// nodes, rather than blocking on reading output from one node.
val minRequestSize = math.max(maxBytesInFlight / 5, 1L)
logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize)
val iterator = blockInfos.iterator
var curRequestSize = 0L
var curBlocks = new ArrayBuffer[(String, Long)]
while (iterator.hasNext) {
val (blockId, size) = iterator.next()
curBlocks += ((blockId, size))
curRequestSize += size
if (curRequestSize >= minRequestSize) {
// Add this FetchRequest
remoteRequests += new FetchRequest(address, curBlocks)
curRequestSize = 0
curBlocks = new ArrayBuffer[(String, Long)]
}
}
// Add in the final request
if (!curBlocks.isEmpty) {
remoteRequests += new FetchRequest(address, curBlocks)
}
}
}
// Add the remote requests into our queue in a random order
fetchRequests ++= Utils.randomize(remoteRequests)
// Send out initial requests for blocks, up to our maxBytesInFlight
while (!fetchRequests.isEmpty &&
(bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
sendRequest(fetchRequests.dequeue())
}
val numGets = remoteBlockIds.size - fetchRequests.size
logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime))
// Get the local blocks while remote blocks are being fetched. Note that it's okay to do
// these all at once because they will just memory-map some files, so they won't consume
// any memory that might exceed our maxBytesInFlight
startTime = System.currentTimeMillis
for (id <- localBlockIds) {
getLocal(id) match {
case Some(iter) => {
results.put(new FetchResult(id, 0, () => iter)) // Pass 0 as size since it's not in flight
logDebug("Got local block " + id)
}
case None => {
throw new BlockException(id, "Could not get block " + id + " from local machine")
}
}
}
logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
// Iterator implementation: read fetched blocks off the results queue as they arrive.
var resultsGotten = 0
def hasNext: Boolean = resultsGotten < totalBlocks
def next(): (String, Option[Iterator[Any]]) = {
resultsGotten += 1
val result = results.take()
bytesInFlight -= result.size
if (!fetchRequests.isEmpty &&
(bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
sendRequest(fetchRequests.dequeue())
}
(result.blockId, if (result.failed) None else Some(result.deserialize()))
}
// Methods to profile the block fetching
def numLocalBlocks = localBlockIds.size
def numRemoteBlocks = remoteBlockIds.size
}
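
For reference, here is a standalone sketch (not part of the commit, with made-up block sizes) of how the minRequestSize = maxBytesInFlight / 5 grouping rule above behaves: with a hypothetical 48 MB maxBytesInFlight, minRequestSize is roughly 9.6 MB, so twenty 3 MB blocks are grouped into five requests of four blocks each, allowing up to about five fetches in flight at once.

import scala.collection.mutable.ArrayBuffer

// Standalone Scala sketch of the request-grouping policy; all sizes are hypothetical.
object GroupingSketch extends App {
  val maxBytesInFlight = 48L * 1024 * 1024                  // assumed cap, not from the commit
  val minRequestSize = math.max(maxBytesInFlight / 5, 1L)   // ~9.6 MB, same rule as above
  val blocks = (1 to 20).map(i => ("shuffle_0_" + i + "_0", 3L * 1024 * 1024)) // twenty 3 MB blocks

  val requests = new ArrayBuffer[Seq[(String, Long)]]
  var curBlocks = new ArrayBuffer[(String, Long)]
  var curRequestSize = 0L
  for ((blockId, size) <- blocks) {
    curBlocks += ((blockId, size))
    curRequestSize += size
    if (curRequestSize >= minRequestSize) {                 // close off a request at ~1/5 of the cap
      requests += curBlocks.toSeq
      curBlocks = new ArrayBuffer[(String, Long)]
      curRequestSize = 0L
    }
  }
  if (!curBlocks.isEmpty) requests += curBlocks.toSeq
  println(requests.map(_.size))                             // prints ArrayBuffer(4, 4, 4, 4, 4)
}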