BlockManager.getMultiple returns a custom iterator to enable tracking of shuffle performance
This commit is contained in:
parent
e319ac74c1
commit
b29f9cc978
|
@ -446,152 +446,8 @@ class BlockManager(
|
|||
* so that we can control the maxMegabytesInFlight for the fetch.
|
||||
*/
|
||||
def getMultiple(blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])])
|
||||
: Iterator[(String, Option[Iterator[Any]])] = {
|
||||
|
||||
if (blocksByAddress == null) {
|
||||
throw new IllegalArgumentException("BlocksByAddress is null")
|
||||
}
|
||||
val totalBlocks = blocksByAddress.map(_._2.size).sum
|
||||
logDebug("Getting " + totalBlocks + " blocks")
|
||||
var startTime = System.currentTimeMillis
|
||||
val localBlockIds = new ArrayBuffer[String]()
|
||||
val remoteBlockIds = new HashSet[String]()
|
||||
|
||||
// A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize
|
||||
// the block (since we want all deserializaton to happen in the calling thread); can also
|
||||
// represent a fetch failure if size == -1.
|
||||
class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) {
|
||||
def failed: Boolean = size == -1
|
||||
}
|
||||
|
||||
// A queue to hold our results.
|
||||
val results = new LinkedBlockingQueue[FetchResult]
|
||||
|
||||
// A request to fetch one or more blocks, complete with their sizes
|
||||
class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) {
|
||||
val size = blocks.map(_._2).sum
|
||||
}
|
||||
|
||||
// Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
|
||||
// the number of bytes in flight is limited to maxBytesInFlight
|
||||
val fetchRequests = new Queue[FetchRequest]
|
||||
|
||||
// Current bytes in flight from our requests
|
||||
var bytesInFlight = 0L
|
||||
|
||||
def sendRequest(req: FetchRequest) {
|
||||
logDebug("Sending request for %d blocks (%s) from %s".format(
|
||||
req.blocks.size, Utils.memoryBytesToString(req.size), req.address.ip))
|
||||
val cmId = new ConnectionManagerId(req.address.ip, req.address.port)
|
||||
val blockMessageArray = new BlockMessageArray(req.blocks.map {
|
||||
case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId))
|
||||
})
|
||||
bytesInFlight += req.size
|
||||
val sizeMap = req.blocks.toMap // so we can look up the size of each blockID
|
||||
val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
|
||||
future.onSuccess {
|
||||
case Some(message) => {
|
||||
val bufferMessage = message.asInstanceOf[BufferMessage]
|
||||
val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
|
||||
for (blockMessage <- blockMessageArray) {
|
||||
if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
|
||||
throw new SparkException(
|
||||
"Unexpected message " + blockMessage.getType + " received from " + cmId)
|
||||
}
|
||||
val blockId = blockMessage.getId
|
||||
results.put(new FetchResult(
|
||||
blockId, sizeMap(blockId), () => dataDeserialize(blockId, blockMessage.getData)))
|
||||
logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
|
||||
}
|
||||
}
|
||||
case None => {
|
||||
logError("Could not get block(s) from " + cmId)
|
||||
for ((blockId, size) <- req.blocks) {
|
||||
results.put(new FetchResult(blockId, -1, null))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Split local and remote blocks. Remote blocks are further split into FetchRequests of size
|
||||
// at most maxBytesInFlight in order to limit the amount of data in flight.
|
||||
val remoteRequests = new ArrayBuffer[FetchRequest]
|
||||
for ((address, blockInfos) <- blocksByAddress) {
|
||||
if (address == blockManagerId) {
|
||||
localBlockIds ++= blockInfos.map(_._1)
|
||||
} else {
|
||||
remoteBlockIds ++= blockInfos.map(_._1)
|
||||
// Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
|
||||
// smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
|
||||
// nodes, rather than blocking on reading output from one node.
|
||||
val minRequestSize = math.max(maxBytesInFlight / 5, 1L)
|
||||
logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize)
|
||||
val iterator = blockInfos.iterator
|
||||
var curRequestSize = 0L
|
||||
var curBlocks = new ArrayBuffer[(String, Long)]
|
||||
while (iterator.hasNext) {
|
||||
val (blockId, size) = iterator.next()
|
||||
curBlocks += ((blockId, size))
|
||||
curRequestSize += size
|
||||
if (curRequestSize >= minRequestSize) {
|
||||
// Add this FetchRequest
|
||||
remoteRequests += new FetchRequest(address, curBlocks)
|
||||
curRequestSize = 0
|
||||
curBlocks = new ArrayBuffer[(String, Long)]
|
||||
}
|
||||
}
|
||||
// Add in the final request
|
||||
if (!curBlocks.isEmpty) {
|
||||
remoteRequests += new FetchRequest(address, curBlocks)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Add the remote requests into our queue in a random order
|
||||
fetchRequests ++= Utils.randomize(remoteRequests)
|
||||
|
||||
// Send out initial requests for blocks, up to our maxBytesInFlight
|
||||
while (!fetchRequests.isEmpty &&
|
||||
(bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
|
||||
sendRequest(fetchRequests.dequeue())
|
||||
}
|
||||
|
||||
val numGets = remoteBlockIds.size - fetchRequests.size
|
||||
logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime))
|
||||
|
||||
// Get the local blocks while remote blocks are being fetched. Note that it's okay to do
|
||||
// these all at once because they will just memory-map some files, so they won't consume
|
||||
// any memory that might exceed our maxBytesInFlight
|
||||
startTime = System.currentTimeMillis
|
||||
for (id <- localBlockIds) {
|
||||
getLocal(id) match {
|
||||
case Some(iter) => {
|
||||
results.put(new FetchResult(id, 0, () => iter)) // Pass 0 as size since it's not in flight
|
||||
logDebug("Got local block " + id)
|
||||
}
|
||||
case None => {
|
||||
throw new BlockException(id, "Could not get block " + id + " from local machine")
|
||||
}
|
||||
}
|
||||
}
|
||||
logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
|
||||
|
||||
// Return an iterator that will read fetched blocks off the queue as they arrive.
|
||||
return new Iterator[(String, Option[Iterator[Any]])] {
|
||||
var resultsGotten = 0
|
||||
|
||||
def hasNext: Boolean = resultsGotten < totalBlocks
|
||||
|
||||
def next(): (String, Option[Iterator[Any]]) = {
|
||||
resultsGotten += 1
|
||||
val result = results.take()
|
||||
bytesInFlight -= result.size
|
||||
if (!fetchRequests.isEmpty &&
|
||||
(bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
|
||||
sendRequest(fetchRequests.dequeue())
|
||||
}
|
||||
(result.blockId, if (result.failed) None else Some(result.deserialize()))
|
||||
}
|
||||
}
|
||||
: BlockFetcherIterator = {
|
||||
return new BlockFetcherIterator(this, blocksByAddress)
|
||||
}
|
||||
|
||||
def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean)
|
||||
|
@ -986,3 +842,165 @@ object BlockManager extends Logging {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class BlockFetcherIterator(
|
||||
private val blockManager: BlockManager,
|
||||
val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])]
|
||||
) extends Iterator[(String, Option[Iterator[Any]])] with Logging {
|
||||
|
||||
import blockManager._
|
||||
|
||||
private var remoteBytesRead = 0l
|
||||
|
||||
if (blocksByAddress == null) {
|
||||
throw new IllegalArgumentException("BlocksByAddress is null")
|
||||
}
|
||||
val totalBlocks = blocksByAddress.map(_._2.size).sum
|
||||
logDebug("Getting " + totalBlocks + " blocks")
|
||||
var startTime = System.currentTimeMillis
|
||||
val localBlockIds = new ArrayBuffer[String]()
|
||||
val remoteBlockIds = new HashSet[String]()
|
||||
|
||||
// A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize
|
||||
// the block (since we want all deserializaton to happen in the calling thread); can also
|
||||
// represent a fetch failure if size == -1.
|
||||
class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) {
|
||||
def failed: Boolean = size == -1
|
||||
}
|
||||
|
||||
// A queue to hold our results.
|
||||
val results = new LinkedBlockingQueue[FetchResult]
|
||||
|
||||
// A request to fetch one or more blocks, complete with their sizes
|
||||
class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) {
|
||||
val size = blocks.map(_._2).sum
|
||||
}
|
||||
|
||||
// Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
|
||||
// the number of bytes in flight is limited to maxBytesInFlight
|
||||
val fetchRequests = new Queue[FetchRequest]
|
||||
|
||||
// Current bytes in flight from our requests
|
||||
var bytesInFlight = 0L
|
||||
|
||||
def sendRequest(req: FetchRequest) {
|
||||
logDebug("Sending request for %d blocks (%s) from %s".format(
|
||||
req.blocks.size, Utils.memoryBytesToString(req.size), req.address.ip))
|
||||
val cmId = new ConnectionManagerId(req.address.ip, req.address.port)
|
||||
val blockMessageArray = new BlockMessageArray(req.blocks.map {
|
||||
case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId))
|
||||
})
|
||||
bytesInFlight += req.size
|
||||
val sizeMap = req.blocks.toMap // so we can look up the size of each blockID
|
||||
val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
|
||||
future.onSuccess {
|
||||
case Some(message) => {
|
||||
val bufferMessage = message.asInstanceOf[BufferMessage]
|
||||
val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
|
||||
for (blockMessage <- blockMessageArray) {
|
||||
if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
|
||||
throw new SparkException(
|
||||
"Unexpected message " + blockMessage.getType + " received from " + cmId)
|
||||
}
|
||||
val blockId = blockMessage.getId
|
||||
results.put(new FetchResult(
|
||||
blockId, sizeMap(blockId), () => dataDeserialize(blockId, blockMessage.getData)))
|
||||
remoteBytesRead += req.size
|
||||
logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
|
||||
}
|
||||
}
|
||||
case None => {
|
||||
logError("Could not get block(s) from " + cmId)
|
||||
for ((blockId, size) <- req.blocks) {
|
||||
results.put(new FetchResult(blockId, -1, null))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Split local and remote blocks. Remote blocks are further split into FetchRequests of size
|
||||
// at most maxBytesInFlight in order to limit the amount of data in flight.
|
||||
val remoteRequests = new ArrayBuffer[FetchRequest]
|
||||
for ((address, blockInfos) <- blocksByAddress) {
|
||||
if (address == blockManagerId) {
|
||||
localBlockIds ++= blockInfos.map(_._1)
|
||||
} else {
|
||||
remoteBlockIds ++= blockInfos.map(_._1)
|
||||
// Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
|
||||
// smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
|
||||
// nodes, rather than blocking on reading output from one node.
|
||||
val minRequestSize = math.max(maxBytesInFlight / 5, 1L)
|
||||
logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize)
|
||||
val iterator = blockInfos.iterator
|
||||
var curRequestSize = 0L
|
||||
var curBlocks = new ArrayBuffer[(String, Long)]
|
||||
while (iterator.hasNext) {
|
||||
val (blockId, size) = iterator.next()
|
||||
curBlocks += ((blockId, size))
|
||||
curRequestSize += size
|
||||
if (curRequestSize >= minRequestSize) {
|
||||
// Add this FetchRequest
|
||||
remoteRequests += new FetchRequest(address, curBlocks)
|
||||
curRequestSize = 0
|
||||
curBlocks = new ArrayBuffer[(String, Long)]
|
||||
}
|
||||
}
|
||||
// Add in the final request
|
||||
if (!curBlocks.isEmpty) {
|
||||
remoteRequests += new FetchRequest(address, curBlocks)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Add the remote requests into our queue in a random order
|
||||
fetchRequests ++= Utils.randomize(remoteRequests)
|
||||
|
||||
// Send out initial requests for blocks, up to our maxBytesInFlight
|
||||
while (!fetchRequests.isEmpty &&
|
||||
(bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
|
||||
sendRequest(fetchRequests.dequeue())
|
||||
}
|
||||
|
||||
val numGets = remoteBlockIds.size - fetchRequests.size
|
||||
logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime))
|
||||
|
||||
// Get the local blocks while remote blocks are being fetched. Note that it's okay to do
|
||||
// these all at once because they will just memory-map some files, so they won't consume
|
||||
// any memory that might exceed our maxBytesInFlight
|
||||
startTime = System.currentTimeMillis
|
||||
for (id <- localBlockIds) {
|
||||
getLocal(id) match {
|
||||
case Some(iter) => {
|
||||
results.put(new FetchResult(id, 0, () => iter)) // Pass 0 as size since it's not in flight
|
||||
logDebug("Got local block " + id)
|
||||
}
|
||||
case None => {
|
||||
throw new BlockException(id, "Could not get block " + id + " from local machine")
|
||||
}
|
||||
}
|
||||
}
|
||||
logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
|
||||
|
||||
//an iterator that will read fetched blocks off the queue as they arrive.
|
||||
var resultsGotten = 0
|
||||
|
||||
def hasNext: Boolean = resultsGotten < totalBlocks
|
||||
|
||||
def next(): (String, Option[Iterator[Any]]) = {
|
||||
resultsGotten += 1
|
||||
val result = results.take()
|
||||
bytesInFlight -= result.size
|
||||
if (!fetchRequests.isEmpty &&
|
||||
(bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
|
||||
sendRequest(fetchRequests.dequeue())
|
||||
}
|
||||
(result.blockId, if (result.failed) None else Some(result.deserialize()))
|
||||
}
|
||||
|
||||
|
||||
//methods to profile the block fetching
|
||||
|
||||
def numLocalBlocks = localBlockIds.size
|
||||
def numRemoteBlocks = remoteBlockIds.size
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue