// TreeBroadcast is an extension of ChainedBroadcast in which the maximum out-degree (maxDegree) of each node is configurable; with maxDegree = 1 it is equivalent to ChainedBroadcast.

package spark
import java.util.{Comparator, Random, UUID}
import scala.collection.mutable.{ListBuffer, Map, Set}
/**
 * Broadcast variable whose serialized value is split into `BroadcastBlock`s
 * (see `blockifyObject`) and exchanged between nodes over sockets.
 *
 * @param value_  the value being broadcast; declared `@transient`
 * @param isLocal checked once, by the `!isLocal` guard that follows the field
 *                initializers
 * @tparam T type of the broadcast value
 */
class TreeBroadcast[T] (@transient var value_ : T, isLocal: Boolean)
extends Broadcast[T] with Logging {
// Returns the current broadcast value.
def value = value_
// Publish the value in the shared TreeBroadcast.values key space under this
// broadcast's uuid while holding the TreeBroadcast object's monitor.
// NOTE(review): `uuid` is not declared in this file; presumably it is
// inherited from Broadcast. Confirm.
TreeBroadcast.synchronized {
TreeBroadcast.values.put (uuid, value_)
// NOTE(review): no closing brace for the synchronized block is visible before
// the field declarations below. Confirm against the original file.
// Blocks of the serialized value; null until sendBroadcast or receiveBroadcast
// assigns it.
@transient var arrayOfBlocks: Array[BroadcastBlock] = null
// Size of the serialized value in bytes; -1 until known.
@transient var totalBytes = -1
// Number of blocks the value is split into; -1 until known.
@transient var totalBlocks = -1
// Number of blocks currently held by this node.
@transient var hasBlocks = 0
// Monitor objects used in `synchronized` blocks around the fields they are
// named after.
@transient var listenPortLock = new Object
@transient var guidePortLock = new Object
@transient var totalBlocksLock = new Object
@transient var hasBlocksLock = new Object
// Sources known to the guide; sendBroadcast adds this node as the first entry.
@transient var listOfSources = ListBuffer[SourceInfo]()
// Thread objects created by sendBroadcast / readObject.
@transient var serveMR: ServeMultipleRequests = null
@transient var guideMR: GuideMultipleRequests = null
// Address of the local host as reported by InetAddress.
@transient var hostAddress = InetAddress.getLocalHost.getHostAddress
// Local ports of the serving and guiding server sockets; -1 until the
// corresponding thread's run method assigns them.
@transient var listenPort = -1
@transient var guidePort = -1
// Set to true by sendBroadcast; read by the GuideMultipleRequests accept loop.
@transient var hasCopyInHDFS = false
// Read by the accept loops of GuideMultipleRequests and ServeMultipleRequests.
@transient var stopBroadcast = false
// Must call this after all the variables have been created/initialized
// NOTE(review): the body of this guard is not visible; presumably it invokes
// sendBroadcast. Confirm.
if (!isLocal) {
/**
 * Prepares this broadcast for distribution: blockifies `value_`, creates the
 * guide and serve thread objects, records this node as the first source and
 * registers the guide port with the tracker.
 *
 * NOTE(review): this listing appears to be missing statements. Neither thread
 * is started here, the two `while (... == -1)` loops have no visible body
 * (presumably a wait on the corresponding lock), and no closing braces are
 * present. Confirm against the original file before relying on the exact
 * control flow.
 */
def sendBroadcast (): Unit = {
logInfo ("Local host address: " + hostAddress)
// Store a persistent copy in HDFS
// TODO: Turned OFF for now
// val out = new ObjectOutputStream (DfsBroadcast.openFileForWriting(uuid))
// out.writeObject (value_)
// out.close()
// TODO: Fix this at some point
// The flag is raised even though the HDFS write above is commented out.
hasCopyInHDFS = true
// Create a variableInfo object and store it in valueInfos
var variableInfo = blockifyObject (value_, TreeBroadcast.BlockSize)
// Prepare the value being broadcasted
// TODO: Refactoring and clean-up required here
arrayOfBlocks = variableInfo.arrayOfBlocks
totalBytes = variableInfo.totalBytes
totalBlocks = variableInfo.totalBlocks
// The broadcasting node already holds every block.
hasBlocks = variableInfo.totalBlocks
guideMR = new GuideMultipleRequests
guideMR.setDaemon (true)
logInfo ("GuideMultipleRequests started...")
// Must always come AFTER guideMR is created
// guidePort is assigned by GuideMultipleRequests.run.
while (guidePort == -1) {
guidePortLock.synchronized {
serveMR = new ServeMultipleRequests
serveMR.setDaemon (true)
logInfo ("ServeMultipleRequests started...")
// Must always come AFTER serveMR is created
// listenPort is assigned by ServeMultipleRequests.run.
while (listenPort == -1) {
listenPortLock.synchronized {
// Must always come AFTER listenPort is created
val masterSource =
SourceInfo (hostAddress, listenPort, totalBlocks, totalBytes)
listOfSources = listOfSources + masterSource
// Register with the Tracker
TreeBroadcast.registerValue (uuid, guidePort)
/**
 * Re-creates `value_` on the receiving side. The signature matches the Java
 * serialization `readObject` hook.
 *
 * If TreeBroadcast.values already holds an entry for `uuid`, that entry is
 * used. Otherwise the value is fetched with receiveBroadcast and rebuilt from
 * its blocks; if reception fails, it is read through
 * DfsBroadcast.openFileForReading instead. In both cases the result is stored
 * in TreeBroadcast.values.
 *
 * NOTE(review): `in` is never used in the visible code, the serve thread is
 * created but not started, `fileIn` is not closed, and no closing braces are
 * present. Statements appear to be missing from this listing. Confirm against
 * the original file.
 *
 * @param in stream this object is being deserialized from
 */
private def readObject (in: ObjectInputStream): Unit = {
TreeBroadcast.synchronized {
val cachedVal = TreeBroadcast.values.get (uuid)
if (cachedVal != null) {
value_ = cachedVal.asInstanceOf[T]
} else {
// Initializing everything because Master will only send null/0 values
// NOTE(review): no initialization call follows this comment in the listing;
// presumably initializeSlaveVariables was invoked here. Confirm.
logInfo ("Local host address: " + hostAddress)
serveMR = new ServeMultipleRequests
serveMR.setDaemon (true)
logInfo ("ServeMultipleRequests started...")
val start = System.nanoTime
val receptionSucceeded = receiveBroadcast (uuid)
// If does not succeed, then get from HDFS copy
if (receptionSucceeded) {
value_ = unBlockifyObject[T]
TreeBroadcast.values.put (uuid, value_)
} else {
val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid))
value_ = fileIn.readObject.asInstanceOf[T]
TreeBroadcast.values.put(uuid, value_)
// Elapsed time in seconds since `start` was taken.
val time = (System.nanoTime - start) / 1e9
logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s")
/**
 * Resets this node's reception/serving state to its initial values: the block
 * bookkeeping, the listenPort/totalBlocks/hasBlocks monitors, the serving
 * thread reference, the local host address, the listen port and the stop flag.
 * The guide-side fields (guideMR, guidePort, guidePortLock, listOfSources,
 * hasCopyInHDFS) are not touched.
 *
 * NOTE(review): no call to this method is visible in this file; the comment
 * "Initializing everything..." in readObject suggests it was meant to be
 * invoked there. Confirm. The closing brace is also not visible.
 */
private def initializeSlaveVariables: Unit = {
arrayOfBlocks = null
totalBytes = -1
totalBlocks = -1
hasBlocks = 0
// Fresh monitor objects replace the existing ones.
listenPortLock = new Object
totalBlocksLock = new Object
hasBlocksLock = new Object
serveMR = null
// Looks up the local host address again.
hostAddress = InetAddress.getLocalHost.getHostAddress
listenPort = -1
stopBroadcast = false
/**
 * Serializes `obj` with Java serialization and splits the resulting bytes into
 * consecutive blocks of `blockSize` bytes (the last block may be shorter).
 *
 * @param obj       value to serialize
 * @param blockSize maximum number of bytes per block; must be positive
 * @return a VariableInfo holding the blocks, the block count and the total
 *         byte count, with `hasBlocks` set to the block count
 */
private def blockifyObject (obj: T, blockSize: Int): VariableInfo = {
  require (blockSize > 0, "blockSize must be positive, got " + blockSize)

  // Close the object stream before taking the bytes so that everything it has
  // buffered is flushed into the byte array.
  val baos = new ByteArrayOutputStream
  val oos = new ObjectOutputStream (baos)
  try {
    oos.writeObject (obj)
  } finally {
    oos.close()
  }
  val byteArray = baos.toByteArray

  // Ceiling division: a trailing partial block counts as one more block.
  val fullBlocks = byteArray.length / blockSize
  val blockNum =
    if (byteArray.length % blockSize != 0) fullBlocks + 1 else fullBlocks

  val blocks = new Array[BroadcastBlock] (blockNum)
  for (blockID <- 0 until blockNum) {
    val offset = blockID * blockSize
    val thisBlockSize = Math.min (blockSize, byteArray.length - offset)
    val blockBytes = new Array[Byte] (thisBlockSize)
    // Copy this block's slice of the serialized bytes; without the copy the
    // block would carry only zeros.
    System.arraycopy (byteArray, offset, blockBytes, 0, thisBlockSize)
    blocks (blockID) = new BroadcastBlock (blockID, blockBytes)
  }

  val variableInfo = VariableInfo (blocks, blockNum, byteArray.length)
  variableInfo.hasBlocks = blockNum
  variableInfo
}
/**
 * Reassembles the blocks in `arrayOfBlocks` into one array of `totalBytes`
 * bytes and deserializes it.
 *
 * Block `i` is copied to offset `i * TreeBroadcast.BlockSize`, so the blocks
 * must have been produced with that block size.
 *
 * @tparam A type the deserialized object is cast to
 * @return the deserialized object
 */
private def unBlockifyObject[A]: A = {
  val retByteArray = new Array[Byte] (totalBytes)
  for (i <- 0 until totalBlocks) {
    val blockBytes = arrayOfBlocks(i).byteArray
    System.arraycopy (blockBytes, 0, retByteArray,
      i * TreeBroadcast.BlockSize, blockBytes.length)
  }
  // Deserialize once, after every block has been copied into place.
  byteArrayToObject[A] (retByteArray)
}
/**
 * Deserializes a single object from `bytes` using Java serialization.
 *
 * @param bytes serialized form of one object, as written by an
 *              ObjectOutputStream
 * @tparam A type the deserialized object is cast to (the cast is unchecked)
 * @return the deserialized object
 */
private def byteArrayToObject[A] (bytes: Array[Byte]): A = {
  val in = new ObjectInputStream (new ByteArrayInputStream (bytes))
  try {
    in.readObject.asInstanceOf[A]
  } finally {
    // Release the stream even when readObject throws.
    in.close()
  }
}
/**
 * Asks the tracker at TreeBroadcast.MasterHostAddress /
 * TreeBroadcast.MasterTrackerPort which port is registered for this broadcast.
 * The request is repeated while retries remain and the answer is
 * SourceInfo.TxNotStartedRetry.
 *
 * NOTE(review): the three null checks in the `finally` section have no visible
 * body (presumably they close the streams and the socket), the Thread.sleep
 * expression is cut off after the `+`, and no closing braces are present.
 * Confirm against the original file.
 *
 * @param variableUUID not read by the visible code; the request sends this
 *                     object's own `uuid`
 * @return the last value read from the tracker, or the initial
 *         SourceInfo.TxOverGoToHDFS if no reply was ever read
 */
def getMasterListenPort (variableUUID: UUID): Int = {
var clientSocketToTracker: Socket = null
var oosTracker: ObjectOutputStream = null
var oisTracker: ObjectInputStream = null
// Default answer when the tracker cannot be reached.
var masterListenPort: Int = SourceInfo.TxOverGoToHDFS
var retriesLeft = TreeBroadcast.MaxRetryCount
do {
try {
// Connect to the tracker to find out the guide
clientSocketToTracker =
new Socket(TreeBroadcast.MasterHostAddress, TreeBroadcast.MasterTrackerPort)
oosTracker =
new ObjectOutputStream (clientSocketToTracker.getOutputStream)
oisTracker =
new ObjectInputStream (clientSocketToTracker.getInputStream)
// Send UUID and receive masterListenPort
oosTracker.writeObject (uuid)
masterListenPort = oisTracker.readObject.asInstanceOf[Int]
} catch {
// Failures are only logged; masterListenPort keeps its previous value.
case e: Exception => {
logInfo ("getMasterListenPort had a " + e)
} finally {
if (oisTracker != null) {
if (oosTracker != null) {
if (clientSocketToTracker != null) {
retriesLeft -= 1
// Pause for a random interval before the next attempt.
Thread.sleep (TreeBroadcast.ranGen.nextInt (
TreeBroadcast.MaxKnockInterval - TreeBroadcast.MinKnockInterval) +
} while (retriesLeft > 0 && masterListenPort == SourceInfo.TxNotStartedRetry)
logInfo ("Got this guidePort from Tracker: " + masterListenPort)
return masterListenPort
/**
 * Receives the broadcast value's blocks from a source chosen by the master's
 * guide.
 *
 * Steps visible in the code: look up the guide port through the tracker;
 * return false if the answer is TxOverGoToHDFS or TxNotStartedRetry; connect to
 * the guide, send this node's address and listen port, and read back the
 * SourceInfo to fetch from; fetch blocks with receiveSingleTransmission; report
 * the (possibly failed) SourceInfo back to the guide. The connect/fetch/report
 * sequence repeats while retries remain and blocks are still missing.
 *
 * NOTE(review): the `while (listenPort == -1)` loop has no visible body, the
 * trailing null checks have no visible body (presumably close calls), `time` is
 * computed but not used in the visible code, and no closing braces are present.
 * Confirm against the original file.
 *
 * @param variableUUID passed to getMasterListenPort
 * @return true when hasBlocks equals totalBlocks after the last attempt
 */
def receiveBroadcast (variableUUID: UUID): Boolean = {
val masterListenPort = getMasterListenPort (variableUUID)
if (masterListenPort == SourceInfo.TxOverGoToHDFS ||
masterListenPort == SourceInfo.TxNotStartedRetry) {
// TODO: SourceInfo.TxNotStartedRetry is not really in use because we go
// to HDFS anyway when receiveBroadcast returns false
return false
// Wait until hostAddress and listenPort are created by the
// ServeMultipleRequests thread
while (listenPort == -1) {
listenPortLock.synchronized {
var clientSocketToMaster: Socket = null
var oosMaster: ObjectOutputStream = null
var oisMaster: ObjectInputStream = null
// Connect and receive broadcast from the specified source, retrying the
// specified number of times in case of failures
var retriesLeft = TreeBroadcast.MaxRetryCount
do {
// Connect to Master and send this worker's Information
clientSocketToMaster =
new Socket(TreeBroadcast.MasterHostAddress, masterListenPort)
// TODO: Guiding object connection is reusable
oosMaster =
new ObjectOutputStream (clientSocketToMaster.getOutputStream)
oisMaster =
new ObjectInputStream (clientSocketToMaster.getInputStream)
logInfo ("Connected to Master's guiding object")
// Send local source information
oosMaster.writeObject(SourceInfo (hostAddress, listenPort,
SourceInfo.UnusedParam, SourceInfo.UnusedParam))
// Receive source information from Master
var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo]
totalBlocks = sourceInfo.totalBlocks
// NOTE(review): this allocation sits inside the retry loop, so a retry
// replaces the array while hasBlocks keeps its value. Verify that blocks
// received in an earlier attempt are not lost.
arrayOfBlocks = new Array[BroadcastBlock] (totalBlocks)
totalBlocksLock.synchronized {
totalBytes = sourceInfo.totalBytes
logInfo ("Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort)
val start = System.nanoTime
val receptionSucceeded = receiveSingleTransmission (sourceInfo)
val time = (System.nanoTime - start) / 1e9
// Updating some statistics in sourceInfo. Master will be using them later
if (!receptionSucceeded) {
sourceInfo.receptionFailed = true
// Send back statistics to the Master
oosMaster.writeObject (sourceInfo)
if (oisMaster != null) {
if (oosMaster != null) {
if (clientSocketToMaster != null) {
retriesLeft -= 1
} while (retriesLeft > 0 && hasBlocks < totalBlocks)
return (hasBlocks == totalBlocks)
// Tries to receive broadcast from the source and returns Boolean status.
// This might be called multiple times to retry a defined number of times.
/**
 * Connects to `sourceInfo.hostAddress` / `sourceInfo.listenPort`, requests the
 * block range (hasBlocks, totalBlocks) and stores each received block in
 * `arrayOfBlocks` at index `hasBlocks`, incrementing `hasBlocks` as it goes.
 * Exceptions are logged, not propagated.
 *
 * NOTE(review): the `hasBlocksLock.synchronized` block has no visible body
 * (presumably it notifies threads waiting for new blocks), the null checks in
 * the `finally` section have no visible body (presumably close calls), and no
 * closing braces are present. Confirm against the original file.
 *
 * @param sourceInfo the source to fetch blocks from
 * @return true if at least one block was received during this call
 */
private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = {
var clientSocketToSource: Socket = null
var oosSource: ObjectOutputStream = null
var oisSource: ObjectInputStream = null
var receptionSucceeded = false
try {
// Connect to the source to get the object itself
clientSocketToSource =
new Socket (sourceInfo.hostAddress, sourceInfo.listenPort)
oosSource =
new ObjectOutputStream (clientSocketToSource.getOutputStream)
oisSource =
new ObjectInputStream (clientSocketToSource.getInputStream)
logInfo ("Inside receiveSingleTransmission")
logInfo ("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks)
// Send the range
oosSource.writeObject((hasBlocks, totalBlocks))
for (i <- hasBlocks until totalBlocks) {
val recvStartTime = System.currentTimeMillis
val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
val receptionTime = (System.currentTimeMillis - recvStartTime)
logInfo ("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.")
// Blocks are stored in arrival order, not by bcBlock.blockID.
arrayOfBlocks(hasBlocks) = bcBlock
hasBlocks += 1
// Set to true if at least one block is received
receptionSucceeded = true
hasBlocksLock.synchronized {
} catch {
case e: Exception => {
logInfo ("receiveSingleTransmission had a " + e)
} finally {
if (oisSource != null) {
if (oosSource != null) {
if (clientSocketToSource != null) {
return receptionSucceeded
/**
 * Thread that accepts connections from workers on a server socket and hands
 * each accepted socket to a GuideSingleRequest through a daemon thread pool.
 * Its run method publishes the socket's local port in `guidePort`.
 *
 * NOTE(review): statements appear to be missing from this listing. The
 * `guidePortLock.synchronized` block has no visible body, no call to
 * sendStopBroadcastNotifications follows the "Sending stopBroadcast
 * notifications..." log line, the `finally` section shows no close or shutdown
 * calls, and no closing braces are present. Confirm against the original file.
 */
class GuideMultipleRequests
extends Thread with Logging {
// Keep track of sources that have completed reception
private var setOfCompletedSources = Set[SourceInfo] ()
override def run: Unit = {
var threadPool = Broadcast.newDaemonCachedThreadPool
var serverSocket: ServerSocket = null
// Port 0 lets the system choose the port; the chosen port is read back below.
serverSocket = new ServerSocket (0)
guidePort = serverSocket.getLocalPort
logInfo ("GuideMultipleRequests => " + serverSocket + " " + guidePort)
guidePortLock.synchronized {
try {
// Don't stop until there is a copy in HDFS
while (!stopBroadcast || !hasCopyInHDFS) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout (TreeBroadcast.ServerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
// Every exception from accept is logged as a timeout.
case e: Exception => {
logInfo ("GuideMultipleRequests Timeout.")
// Stop broadcast if at least one worker has connected and
// everyone connected so far are done. Comparing with
// listOfSources.size - 1, because it includes the Guide itself
if (listOfSources.size > 1 &&
setOfCompletedSources.size == listOfSources.size - 1) {
stopBroadcast = true
if (clientSocket != null) {
logInfo ("Guide: Accepted new client connection: " + clientSocket)
try {
threadPool.execute (new GuideSingleRequest (clientSocket))
} catch {
// In failure, close() the socket here; else, the thread will close() it
case ioe: IOException => clientSocket.close()
logInfo ("Sending stopBroadcast notifications...")
// Marks this broadcast as finished in the tracker's map.
TreeBroadcast.unregisterValue (uuid)
} finally {
if (serverSocket != null) {
logInfo ("GuideMultipleRequests now stopping...")
// Shutdown the thread pool
/**
 * Iterates over `listOfSources` (holding its monitor) and, for each entry,
 * opens a connection to the source's host and listen port to send a stop
 * signal. Exceptions are logged, not propagated.
 *
 * NOTE(review): this listing is truncated in several places. The initializer
 * of `sourceInfo` is missing (presumably the iterator's next element), the
 * writeObject call is cut off after the first tuple element, the null checks in
 * the `finally` section have no visible body, and no closing braces are
 * present. No call to this method is visible in this file either. Confirm
 * against the original file.
 */
private def sendStopBroadcastNotifications: Unit = {
listOfSources.synchronized {
var listIter = listOfSources.iterator
while (listIter.hasNext) {
var sourceInfo =
var guideSocketToSource: Socket = null
var gosSource: ObjectOutputStream = null
var gisSource: ObjectInputStream = null
try {
// Connect to the source
guideSocketToSource =
new Socket (sourceInfo.hostAddress, sourceInfo.listenPort)
gosSource =
new ObjectOutputStream (guideSocketToSource.getOutputStream)
gisSource =
new ObjectInputStream (guideSocketToSource.getInputStream)
// Send stopBroadcast signal. Range = SourceInfo.StopBroadcast*2
gosSource.writeObject ((SourceInfo.StopBroadcast,
} catch {
case e: Exception => {
logInfo ("sendStopBroadcastNotifications had a " + e)
} finally {
if (gisSource != null) {
if (gosSource != null) {
if (guideSocketToSource != null) {
/**
 * Handles one worker connection on the guide side: reads the worker's
 * SourceInfo, picks a source for it with selectSuitableSource, sends that
 * choice back, adds the worker to `listOfSources`, then waits for the worker's
 * report and updates `listOfSources` / `setOfCompletedSources` accordingly.
 *
 * NOTE(review): the `finally` section shows no close calls and no closing
 * braces are present anywhere in this block, so the exact extent of each
 * `synchronized` / `if` body cannot be determined from this listing. Confirm
 * against the original file.
 *
 * @param clientSocket accepted connection to the worker; both object streams
 *                     are opened on it when the instance is constructed
 */
class GuideSingleRequest (val clientSocket: Socket)
extends Thread with Logging {
private val oos = new ObjectOutputStream (clientSocket.getOutputStream)
private val ois = new ObjectInputStream (clientSocket.getInputStream)
// Source handed to the worker, and the worker's own entry in listOfSources.
private var selectedSourceInfo: SourceInfo = null
private var thisWorkerInfo:SourceInfo = null
override def run: Unit = {
try {
logInfo ("new GuideSingleRequest is running")
// Connecting worker is sending in its hostAddress and listenPort it will
// be listening to. Other fields are invalid (SourceInfo.UnusedParam)
var sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
listOfSources.synchronized {
// Select a suitable source and send it back to the worker
selectedSourceInfo = selectSuitableSource (sourceInfo)
logInfo ("Sending selectedSourceInfo: " + selectedSourceInfo)
oos.writeObject (selectedSourceInfo)
// Add this new (if it can finish) source to the list of sources
thisWorkerInfo = SourceInfo (sourceInfo.hostAddress,
sourceInfo.listenPort, totalBlocks, totalBytes)
logInfo ("Adding possible new source to listOfSources: " + thisWorkerInfo)
listOfSources = listOfSources + thisWorkerInfo
// Wait till the whole transfer is done. Then receive and update source
// statistics in listOfSources
sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
listOfSources.synchronized {
// This should work since SourceInfo is a case class
assert (listOfSources.contains (selectedSourceInfo))
// Remove first
listOfSources = listOfSources - selectedSourceInfo
// TODO: Removing a source based on just one failure notification!
// Update sourceInfo and put it back in, IF reception succeeded
if (!sourceInfo.receptionFailed) {
// Add thisWorkerInfo to sources that have completed reception
setOfCompletedSources += thisWorkerInfo
selectedSourceInfo.currentLeechers -= 1
// Put it back
listOfSources = listOfSources + selectedSourceInfo
} catch {
// If something went wrong, e.g., the worker at the other end died etc.
// then close() everything up
case e: Exception => {
// Assuming that exception caused due to receiver worker failure.
// Remove failed worker from listOfSources and update leecherCount of
// corresponding source worker
listOfSources.synchronized {
if (selectedSourceInfo != null) {
// Remove first
listOfSources = listOfSources - selectedSourceInfo
// Update leecher count and put it back in
selectedSourceInfo.currentLeechers -= 1
listOfSources = listOfSources + selectedSourceInfo
// Remove thisWorkerInfo
if (listOfSources != null) {
listOfSources = listOfSources - thisWorkerInfo
} finally {
// TODO: Caller must have a synchronized block on listOfSources
// TODO: If a worker fails to get the broadcasted variable from a source
// and comes back to the Master, this function might choose the worker
// itself as a source to create a dependency cycle (this worker was put
// into listOfSources as a streaming source when it first arrived). The
// length of this cycle can be arbitrarily long.
/**
 * Picks the source a new worker should fetch from and increments that
 * source's leecher count.
 *
 * Among the entries of `listOfSources` that differ from `skipSourceInfo` and
 * have fewer than TreeBroadcast.MaxDegree leechers, the first one with the
 * highest leecher count is chosen; this fills the tree level by level. If no
 * entry qualifies, the chosen reference is still null when its counter is
 * incremented, which fails with a NullPointerException.
 *
 * @param skipSourceInfo entry that must not be selected
 * @return the selected source, with `currentLeechers` already incremented
 */
private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = {
  var bestSoFar: SourceInfo = null
  var bestLeecherCount = -1

  for (candidate <- listOfSources) {
    val eligible =
      candidate != skipSourceInfo &&
      candidate.currentLeechers < TreeBroadcast.MaxDegree
    // Strict comparison: on a tie the earlier entry stays selected.
    if (eligible && candidate.currentLeechers > bestLeecherCount) {
      bestSoFar = candidate
      bestLeecherCount = candidate.currentLeechers
    }
  }

  // Account for the worker that is about to attach to the chosen source.
  bestSoFar.currentLeechers += 1
  bestSoFar
}
/**
 * Thread that accepts connections on a server socket until `stopBroadcast` is
 * set and hands each accepted socket to a ServeSingleRequest through a daemon
 * thread pool. Its run method publishes the socket's local port in
 * `listenPort`.
 *
 * NOTE(review): the `listenPortLock.synchronized` block has no visible body
 * (presumably it notifies threads waiting for listenPort), the `finally`
 * section shows no close or shutdown calls, and no closing braces are present.
 * Confirm against the original file.
 */
class ServeMultipleRequests
extends Thread with Logging {
override def run: Unit = {
var threadPool = Broadcast.newDaemonCachedThreadPool
var serverSocket: ServerSocket = null
// Port 0 lets the system choose the port; the chosen port is read back below.
serverSocket = new ServerSocket (0)
listenPort = serverSocket.getLocalPort
logInfo ("ServeMultipleRequests started with " + serverSocket)
listenPortLock.synchronized {
try {
while (!stopBroadcast) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout (TreeBroadcast.ServerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
// Every exception from accept is logged as a timeout.
case e: Exception => {
logInfo ("ServeMultipleRequests Timeout.")
if (clientSocket != null) {
logInfo ("Serve: Accepted new client connection: " + clientSocket)
try {
threadPool.execute (new ServeSingleRequest (clientSocket))
} catch {
// In failure, close() socket here; else, the thread will close() it
case ioe: IOException => clientSocket.close()
} finally {
if (serverSocket != null) {
logInfo ("ServeMultipleRequests now stopping...")
// Shutdown the thread pool
/**
 * Handles one incoming connection on the serving side. It reads an (Int, Int)
 * range from the peer; if both values equal SourceInfo.StopBroadcast it sets
 * `stopBroadcast`, otherwise it takes the "Carry on" branch.
 *
 * NOTE(review): the "Carry on" branch has no visible statement (presumably it
 * calls sendObject), the `finally` section only logs and shows no close calls,
 * and no closing braces are present. Confirm against the original file.
 *
 * @param clientSocket accepted connection to the peer; both object streams are
 *                     opened on it when the instance is constructed
 */
class ServeSingleRequest (val clientSocket: Socket)
extends Thread with Logging {
private val oos = new ObjectOutputStream (clientSocket.getOutputStream)
private val ois = new ObjectInputStream (clientSocket.getInputStream)
// Block range to send; overwritten with the range received from the peer.
private var sendFrom = 0
private var sendUntil = totalBlocks
override def run: Unit = {
try {
logInfo ("new ServeSingleRequest is running")
// Receive range to send
var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)]
sendFrom = rangeToSend._1
sendUntil = rangeToSend._2
if (sendFrom == SourceInfo.StopBroadcast &&
sendUntil == SourceInfo.StopBroadcast) {
stopBroadcast = true
} else {
// Carry on
} catch {
// If something went wrong, e.g., the worker at the other end died etc.
// then close() everything up
case e: Exception => {
logInfo ("ServeSingleRequest had a " + e)
} finally {
logInfo ("ServeSingleRequest is closing streams and sockets")
/**
 * Writes blocks `sendFrom` (inclusive) to `sendUntil` (exclusive) of
 * `arrayOfBlocks` to the peer, one object per block. Write failures are logged,
 * not propagated.
 *
 * NOTE(review): the `while (totalBlocks == -1)` and `while (i == hasBlocks)`
 * loops have no visible body beyond the opening of a `synchronized` block
 * (presumably each waits on its lock until the value changes), and no closing
 * braces are present. No call to this method is visible in this file. Confirm
 * against the original file.
 */
private def sendObject: Unit = {
// Wait till receiving the SourceInfo from Master
while (totalBlocks == -1) {
totalBlocksLock.synchronized {
for (i <- sendFrom until sendUntil) {
// Block i is not available yet while hasBlocks still equals i.
while (i == hasBlocks) {
hasBlocksLock.synchronized {
try {
oos.writeObject (arrayOfBlocks(i))
} catch {
case e: Exception => {
logInfo ("sendObject had a " + e)
logInfo ("Sent block: " + i + " to " + clientSocket)
/**
 * BroadcastFactory that produces TreeBroadcast instances.
 */
class TreeBroadcastFactory extends BroadcastFactory {

  /** Forwards to `TreeBroadcast.initialize`. */
  def initialize (isMaster: Boolean) = {
    TreeBroadcast.initialize (isMaster)
  }

  /** Wraps `value_` in a new TreeBroadcast. */
  def newBroadcast[T] (value_ : T, isLocal: Boolean) = {
    new TreeBroadcast[T] (value_, isLocal)
  }
}
/**
 * Shared state and configuration for TreeBroadcast: the cache of received
 * values, the tracker's uuid-to-guide-port map and the settings that
 * `initialize` reads from system properties.
 */
private object TreeBroadcast
extends Logging {
// Key space in which broadcast values are cached by uuid.
val values = Cache.newKeySpace()
// Tracker state: guide port (or a SourceInfo status code) per broadcast uuid.
var valueToGuidePortMap = Map[UUID, Int] ()
// Random number generator
var ranGen = new Random
// Set by initialize() so that the settings are read only once.
private var initialized = false
private var isMaster_ = false
// Defaults; initialize() replaces them with values from system properties.
private var MasterHostAddress_ = InetAddress.getLocalHost.getHostAddress
private var MasterTrackerPort_ : Int = 22222
// 512 * 1024; used as a byte count by blockifyObject / unBlockifyObject.
private var BlockSize_ : Int = 512 * 1024
private var MaxDegree_ : Int = 1
private var MaxRetryCount_ : Int = 2
// Passed to ServerSocket.setSoTimeout by the tracker and server threads.
private var TrackerSocketTimeout_ : Int = 50000
private var ServerSocketTimeout_ : Int = 10000
private var trackMV: TrackMultipleValues = null
// Bounds of the random retry pause in getMasterListenPort.
private var MinKnockInterval_ = 500
private var MaxKnockInterval_ = 999
/**
 * Reads the `spark.broadcast.*` system properties into this object's settings
 * and, on the master, creates the TrackMultipleValues thread object. The work
 * is done under this object's monitor and only while `initialized` is false.
 *
 * NOTE(review): when `spark.broadcast.masterHostAddress` is not set, the empty
 * string replaces the local-host default assigned at declaration; confirm that
 * this is intended. The tracker thread is not started in the visible code,
 * nothing follows the "Initialize DfsBroadcast" comment, and no closing braces
 * are present. Confirm against the original file.
 *
 * @param isMaster__ whether this process acts as the master; stored in
 *                   `isMaster_`
 */
def initialize (isMaster__ : Boolean): Unit = {
synchronized {
if (!initialized) {
MasterHostAddress_ =
System.getProperty ("spark.broadcast.masterHostAddress", "")
MasterTrackerPort_ = System.getProperty (
"spark.broadcast.masterTrackerPort", "22222").toInt
// The property value is multiplied by 1024 before being stored.
BlockSize_ = System.getProperty (
"spark.broadcast.blockSize", "512").toInt * 1024
MaxDegree_ = System.getProperty (
"spark.broadcast.maxDegree", "1").toInt
MaxRetryCount_ = System.getProperty (
"spark.broadcast.maxRetryCount", "2").toInt
TrackerSocketTimeout_ = System.getProperty (
"spark.broadcast.trackerSocketTimeout", "50000").toInt
ServerSocketTimeout_ = System.getProperty (
"spark.broadcast.serverSocketTimeout", "10000").toInt
MinKnockInterval_ = System.getProperty (
"spark.broadcast.minKnockInterval", "500").toInt
MaxKnockInterval_ = System.getProperty (
"spark.broadcast.maxKnockInterval", "999").toInt
isMaster_ = isMaster__
if (isMaster) {
trackMV = new TrackMultipleValues
trackMV.setDaemon (true)
logInfo ("TrackMultipleValues started...")
// Initialize DfsBroadcast to be used for broadcast variable persistence
initialized = true
// Read-only accessors for the settings held in the private fields above.
def MasterHostAddress: String = MasterHostAddress_
def MasterTrackerPort: Int = MasterTrackerPort_
def BlockSize: Int = BlockSize_
def MaxDegree: Int = MaxDegree_
def MaxRetryCount: Int = MaxRetryCount_
def TrackerSocketTimeout: Int = TrackerSocketTimeout_
def ServerSocketTimeout: Int = ServerSocketTimeout_
def isMaster: Boolean = isMaster_
def MinKnockInterval: Int = MinKnockInterval_
def MaxKnockInterval: Int = MaxKnockInterval_
/**
 * Records `guidePort` as the guide port of broadcast `uuid` in the tracker map
 * and logs the updated map, holding the map's monitor throughout.
 *
 * @param uuid      identifier of the broadcast variable
 * @param guidePort port to associate with `uuid`
 */
def registerValue (uuid: UUID, guidePort: Int): Unit =
  valueToGuidePortMap.synchronized {
    valueToGuidePortMap (uuid) = guidePort
    logInfo ("New value registered with the Tracker " + valueToGuidePortMap)
  }
/**
 * Replaces the entry for `uuid` in the tracker map with
 * SourceInfo.TxOverGoToHDFS (the entry is overwritten, not removed) and logs
 * the updated map, holding the map's monitor throughout.
 *
 * @param uuid identifier of the broadcast variable
 */
def unregisterValue (uuid: UUID): Unit =
  valueToGuidePortMap.synchronized {
    valueToGuidePortMap += (uuid -> SourceInfo.TxOverGoToHDFS)
    logInfo ("Value unregistered from the Tracker " + valueToGuidePortMap)
  }
/**
 * Tracker thread: listens on TreeBroadcast.MasterTrackerPort and, for each
 * accepted connection, runs a task on a daemon thread pool that reads a UUID
 * and replies with the port stored for it in `valueToGuidePortMap`, or with
 * SourceInfo.TxNotStartedRetry when the UUID is unknown.
 *
 * NOTE(review): the map is read here without taking the monitor that
 * registerValue / unregisterValue hold; verify that this is intended. The log
 * message on accept failure says "Stopping listening..." but no exit from the
 * `while (true)` loop is visible. Both `finally` sections are empty in this
 * listing (presumably they close the streams, sockets and the thread pool),
 * and no closing braces are present. Confirm against the original file.
 */
class TrackMultipleValues
extends Thread with Logging {
override def run: Unit = {
var threadPool = Broadcast.newDaemonCachedThreadPool
var serverSocket: ServerSocket = null
serverSocket = new ServerSocket (TreeBroadcast.MasterTrackerPort)
logInfo ("TrackMultipleValues" + serverSocket)
try {
while (true) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout (TrackerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
// Every exception from accept is logged as a timeout.
case e: Exception => {
logInfo ("TrackMultipleValues Timeout. Stopping listening...")
if (clientSocket != null) {
try {
// Answer the request on a pool thread so the accept loop is not blocked.
threadPool.execute (new Thread {
override def run: Unit = {
val oos = new ObjectOutputStream (clientSocket.getOutputStream)
val ois = new ObjectInputStream (clientSocket.getInputStream)
try {
val uuid = ois.readObject.asInstanceOf[UUID]
var guidePort =
if (valueToGuidePortMap.contains (uuid)) {
valueToGuidePortMap (uuid)
} else SourceInfo.TxNotStartedRetry
logInfo ("TrackMultipleValues: Got new request: " + clientSocket + " for " + uuid + " : " + guidePort)
oos.writeObject (guidePort)
} catch {
case e: Exception => {
logInfo ("TrackMultipleValues had a " + e)
} finally {
} catch {
// In failure, close() socket here; else, client thread will close()
case ioe: IOException => clientSocket.close()
} finally {
// Shutdown the thread pool