diff --git a/src/scala/spark/Broadcast.scala b/src/scala/spark/Broadcast.scala index 52e680d050..9456dead22 100644 --- a/src/scala/spark/Broadcast.scala +++ b/src/scala/spark/Broadcast.scala @@ -2,7 +2,7 @@ package spark import java.io._ import java.net._ -import java.util.{UUID, PriorityQueue, Comparator} +import java.util.{UUID, PriorityQueue, Comparator, BitSet} import com.google.common.collect.MapMaker @@ -45,6 +45,8 @@ class ChainedStreamingBroadcast[T] (@transient var value_ : T, local: Boolean) @transient var hasBlocksLock = new Object @transient var pqOfSources = new PriorityQueue[SourceInfo] + + @transient var hasBlocksBitVector: BitSet = null @transient var serveMR: ServeMultipleRequests = null @transient var guideMR: GuideMultipleRequests = null @@ -89,7 +91,11 @@ class ChainedStreamingBroadcast[T] (@transient var value_ : T, local: Boolean) arrayOfBlocks = variableInfo.arrayOfBlocks totalBytes = variableInfo.totalBytes totalBlocks = variableInfo.totalBlocks - hasBlocks = variableInfo.totalBlocks + hasBlocks = variableInfo.totalBlocks + + hasBlocksBitVector = new BitSet (totalBlocks) + hasBlocksBitVector.set (0, totalBlocks) + while (listenPort == -1) { listenPortLock.synchronized { @@ -98,9 +104,9 @@ class ChainedStreamingBroadcast[T] (@transient var value_ : T, local: Boolean) } pqOfSources = new PriorityQueue[SourceInfo] - val masterSource_0 = + val masterSource = new SourceInfo (hostAddress, listenPort, totalBlocks, totalBytes, 0) - pqOfSources.add (masterSource_0) + pqOfSources.add (masterSource) // Register with the Tracker while (guidePort == -1) { @@ -151,9 +157,10 @@ class ChainedStreamingBroadcast[T] (@transient var value_ : T, local: Boolean) private def initializeSlaveVariables = { arrayOfBlocks = null + hasBlocksBitVector = null totalBytes = -1 totalBlocks = -1 - hasBlocks = 0 + hasBlocks = 0 listenPortLock = new Object totalBlocksLock = new Object hasBlocksLock = new Object @@ -282,6 +289,7 @@ class ChainedStreamingBroadcast[T] (@transient var value_ : T, local: Boolean) var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo] totalBlocks = sourceInfo.totalBlocks arrayOfBlocks = new Array[BroadcastBlock] (totalBlocks) + hasBlocksBitVector = new BitSet (totalBlocks) totalBlocksLock.synchronized { totalBlocksLock.notifyAll } @@ -341,6 +349,7 @@ class ChainedStreamingBroadcast[T] (@transient var value_ : T, local: Boolean) for (i <- hasBlocks until totalBlocks) { val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock] arrayOfBlocks(hasBlocks) = bcBlock + hasBlocksBitVector.set (bcBlock.blockID) hasBlocks += 1 // Set to true if at least one block is received receptionSucceeded = true