Skip fetching zero-sized blocks in OIO.
Also unify splitLocalRemoteBlocks for netty/nio and add a test case
This commit is contained in:
parent
6ed71390d9
commit
618c8cae1e
|
@ -124,6 +124,7 @@ object BlockFetcherIterator {
|
|||
protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
|
||||
// Split local and remote blocks. Remote blocks are further split into FetchRequests of size
|
||||
// at most maxBytesInFlight in order to limit the amount of data in flight.
|
||||
val originalTotalBlocks = _totalBlocks
|
||||
val remoteRequests = new ArrayBuffer[FetchRequest]
|
||||
for ((address, blockInfos) <- blocksByAddress) {
|
||||
if (address == blockManagerId) {
|
||||
|
@ -140,8 +141,15 @@ object BlockFetcherIterator {
|
|||
var curBlocks = new ArrayBuffer[(String, Long)]
|
||||
while (iterator.hasNext) {
|
||||
val (blockId, size) = iterator.next()
|
||||
curBlocks += ((blockId, size))
|
||||
curRequestSize += size
|
||||
// Skip empty blocks
|
||||
if (size > 0) {
|
||||
curBlocks += ((blockId, size))
|
||||
curRequestSize += size
|
||||
} else if (size == 0) {
|
||||
_totalBlocks -= 1
|
||||
} else {
|
||||
throw new BlockException(blockId, "Negative block size " + size)
|
||||
}
|
||||
if (curRequestSize >= minRequestSize) {
|
||||
// Add this FetchRequest
|
||||
remoteRequests += new FetchRequest(address, curBlocks)
|
||||
|
@ -155,6 +163,8 @@ object BlockFetcherIterator {
|
|||
}
|
||||
}
|
||||
}
|
||||
logInfo("Getting " + _totalBlocks + " non-zero-bytes blocks out of " +
|
||||
originalTotalBlocks + " blocks")
|
||||
remoteRequests
|
||||
}
|
||||
|
||||
|
@ -278,53 +288,6 @@ object BlockFetcherIterator {
|
|||
logDebug("Sent request for remote blocks " + req.blocks + " from " + req.address.host )
|
||||
}
|
||||
|
||||
override protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
|
||||
// Split local and remote blocks. Remote blocks are further split into FetchRequests of size
|
||||
// at most maxBytesInFlight in order to limit the amount of data in flight.
|
||||
val originalTotalBlocks = _totalBlocks;
|
||||
val remoteRequests = new ArrayBuffer[FetchRequest]
|
||||
for ((address, blockInfos) <- blocksByAddress) {
|
||||
if (address == blockManagerId) {
|
||||
localBlockIds ++= blockInfos.map(_._1)
|
||||
} else {
|
||||
remoteBlockIds ++= blockInfos.map(_._1)
|
||||
// Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
|
||||
// smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
|
||||
// nodes, rather than blocking on reading output from one node.
|
||||
val minRequestSize = math.max(maxBytesInFlight / 5, 1L)
|
||||
logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize)
|
||||
val iterator = blockInfos.iterator
|
||||
var curRequestSize = 0L
|
||||
var curBlocks = new ArrayBuffer[(String, Long)]
|
||||
while (iterator.hasNext) {
|
||||
val (blockId, size) = iterator.next()
|
||||
if (size > 0) {
|
||||
curBlocks += ((blockId, size))
|
||||
curRequestSize += size
|
||||
} else if (size == 0) {
|
||||
//here we changes the totalBlocks
|
||||
_totalBlocks -= 1
|
||||
} else {
|
||||
throw new BlockException(blockId, "Negative block size " + size)
|
||||
}
|
||||
if (curRequestSize >= minRequestSize) {
|
||||
// Add this FetchRequest
|
||||
remoteRequests += new FetchRequest(address, curBlocks)
|
||||
curRequestSize = 0
|
||||
curBlocks = new ArrayBuffer[(String, Long)]
|
||||
}
|
||||
}
|
||||
// Add in the final request
|
||||
if (!curBlocks.isEmpty) {
|
||||
remoteRequests += new FetchRequest(address, curBlocks)
|
||||
}
|
||||
}
|
||||
}
|
||||
logInfo("Getting " + _totalBlocks + " non-zero-bytes blocks out of " +
|
||||
originalTotalBlocks + " blocks")
|
||||
remoteRequests
|
||||
}
|
||||
|
||||
private var copiers: List[_ <: Thread] = null
|
||||
|
||||
override def initialize() {
|
||||
|
|
|
@ -317,6 +317,33 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
|
|||
val c = new ShuffledRDD(b, new HashPartitioner(3), classOf[spark.KryoSerializer].getName)
|
||||
assert(c.count === 10)
|
||||
}
|
||||
|
||||
test("zero sized blocks") {
|
||||
// Use a local cluster with 2 processes to make sure there are both local and remote blocks
|
||||
sc = new SparkContext("local-cluster[2,1,512]", "test")
|
||||
|
||||
// 10 partitions from 4 keys
|
||||
val NUM_BLOCKS = 10
|
||||
val a = sc.parallelize(1 to 4, NUM_BLOCKS)
|
||||
val b = a.map(x => (x, x*2))
|
||||
|
||||
// NOTE: The default Java serializer doesn't create zero-sized blocks.
|
||||
// So, use Kryo
|
||||
val c = new ShuffledRDD(b, new HashPartitioner(10), classOf[spark.KryoSerializer].getName)
|
||||
|
||||
val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
|
||||
assert(c.count === 4)
|
||||
|
||||
val blockSizes = (0 until NUM_BLOCKS).flatMap { id =>
|
||||
val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id)
|
||||
statuses.map(x => x._2)
|
||||
}
|
||||
val nonEmptyBlocks = blockSizes.filter(x => x > 0)
|
||||
|
||||
// We should have at most 4 non-zero sized partitions
|
||||
assert(nonEmptyBlocks.size <= 4)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
object ShuffleSuite {
|
||||
|
|
Loading…
Reference in a new issue