From a480dec6b26f759cf60eac2a9260484eeafc508d Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Thu, 30 Aug 2012 20:01:06 -0700 Subject: [PATCH] Deserialize multi-get results in the caller's thread. This fixes an issue with shared buffers in the KryoSerializer. --- .../spark/BlockStoreShuffleFetcher.scala | 41 ++++++++----------- .../scala/spark/storage/BlockManager.scala | 17 ++++---- 2 files changed, 27 insertions(+), 31 deletions(-) diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala index 45a14c8290..0bbdb4e432 100644 --- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala @@ -32,36 +32,29 @@ class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging { (address, splits.map(i => "shuffleid_%d_%d_%d".format(shuffleId, i, reduceId))) } - try { - for ((blockId, blockOption) <- blockManager.getMultiple(blocksByAddress)) { - blockOption match { - case Some(block) => { - val values = block - for(value <- values) { - val v = value.asInstanceOf[(K, V)] - func(v._1, v._2) - } - } - case None => { - throw new BlockException(blockId, "Did not get block " + blockId) + for ((blockId, blockOption) <- blockManager.getMultiple(blocksByAddress)) { + blockOption match { + case Some(block) => { + val values = block + for(value <- values) { + val v = value.asInstanceOf[(K, V)] + func(v._1, v._2) } } - } - } catch { - // TODO: this is really ugly -- let's find a better way of throwing a FetchFailedException - case be: BlockException => { - val regex = "shuffleid_([0-9]*)_([0-9]*)_([0-9]]*)".r - be.blockId match { - case regex(sId, mId, rId) => { - val address = addresses(mId.toInt) - throw new FetchFailedException(address, sId.toInt, mId.toInt, rId.toInt, be) - } - case _ => { - throw be + case None => { + val regex = "shuffleid_([0-9]*)_([0-9]*)_([0-9]*)".r + blockId match { + case regex(shufId, mapId, reduceId) => + val addr = 
addresses(mapId.toInt) + throw new FetchFailedException(addr, shufId.toInt, mapId.toInt, reduceId.toInt, null) + case _ => + throw new SparkException( + "Failed to get block " + blockId + ", which is not a shuffle block") } } } } + logDebug("Fetching and merging outputs of shuffle %d, reduce %d took %d ms".format( shuffleId, reduceId, System.currentTimeMillis - startTime)) } diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 45f99717bc..e9197f7169 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -272,11 +272,15 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m val totalBlocks = blocksByAddress.map(_._2.size).sum logDebug("Getting " + totalBlocks + " blocks") var startTime = System.currentTimeMillis - val results = new LinkedBlockingQueue[(String, Option[Iterator[Any]])] val localBlockIds = new ArrayBuffer[String]() val remoteBlockIds = new ArrayBuffer[String]() val remoteBlockIdsPerLocation = new HashMap[BlockManagerId, Seq[String]]() + // A queue to hold our results. Because we want all the deserializing to happen in the + // caller's thread, this will actually hold functions to produce the Iterator for each block. + // For local blocks we'll have an iterator already, while for remote ones we'll deserialize. 
+ val results = new LinkedBlockingQueue[(String, Option[() => Iterator[Any]])] + // Split local and remote blocks for ((address, blockIds) <- blocksByAddress) { if (address == blockManagerId) { @@ -302,10 +306,8 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m throw new SparkException( "Unexpected message " + blockMessage.getType + " received from " + cmId) } - val buffer = blockMessage.getData val blockId = blockMessage.getId - val block = dataDeserialize(buffer) - results.put((blockId, Some(block))) + results.put((blockId, Some(() => dataDeserialize(blockMessage.getData)))) logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) }) } @@ -323,9 +325,9 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m // Get the local blocks while remote blocks are being fetched startTime = System.currentTimeMillis localBlockIds.foreach(id => { - get(id) match { + getLocal(id) match { case Some(block) => { - results.put((id, Some(block))) + results.put((id, Some(() => block))) logDebug("Got local block " + id) } case None => { @@ -343,7 +345,8 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m def next(): (String, Option[Iterator[Any]]) = { resultsGotten += 1 - results.take() + val (blockId, functionOption) = results.take() + (blockId, functionOption.map(_.apply())) } } }