Merge pull request #147 from mosharaf/dev

Broadcast refactoring/cleaning up
This commit is contained in:
Matei Zaharia 2012-08-23 19:38:28 -07:00
commit 7310a6f499
12 changed files with 652 additions and 1901 deletions

View file

@ -65,7 +65,7 @@ class SparkContext(
System.setProperty("spark.master.port", "0")
}
private val isLocal = (master == "local" || master.startsWith("local["))
private val isLocal = (master == "local" || master.startsWith("local[")) && !master.startsWith("localhost")
// Create the Spark execution environment (cache, map output tracker, etc)
val env = SparkEnv.createFromSystemProperties(
@ -74,7 +74,6 @@ class SparkContext(
true,
isLocal)
SparkEnv.set(env)
Broadcast.initialize(true)
// Create and start the scheduler
private var taskScheduler: TaskScheduler = {
@ -295,14 +294,14 @@ class SparkContext(
// Keep around a weak hash map of values to Cached versions?
def broadcast[T](value: T) = Broadcast.getBroadcastFactory.newBroadcast[T] (value, isLocal)
def broadcast[T](value: T) = SparkEnv.get.broadcastManager.newBroadcast[T] (value, isLocal)
// Stop the SparkContext
def stop() {
dagScheduler.stop()
dagScheduler = null
taskScheduler = null
// TODO: Broadcast.stop(), Cache.stop()?
// TODO: Cache.stop()?
env.stop()
SparkEnv.set(null)
ShuffleMapTask.clearCache()

View file

@ -2,6 +2,7 @@ package spark
import akka.actor.ActorSystem
import spark.broadcast.BroadcastManager
import spark.storage.BlockManager
import spark.storage.BlockManagerMaster
import spark.network.ConnectionManager
@ -16,13 +17,14 @@ class SparkEnv (
val mapOutputTracker: MapOutputTracker,
val shuffleFetcher: ShuffleFetcher,
val shuffleManager: ShuffleManager,
val broadcastManager: BroadcastManager,
val blockManager: BlockManager,
val connectionManager: ConnectionManager
) {
/** No-parameter constructor for unit tests. */
def this() = {
this(null, null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null)
this(null, null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null, null)
}
def stop() {
@ -30,6 +32,7 @@ class SparkEnv (
cacheTracker.stop()
shuffleFetcher.stop()
shuffleManager.stop()
broadcastManager.stop()
blockManager.stop()
blockManager.master.stop()
actorSystem.shutdown()
@ -74,6 +77,8 @@ object SparkEnv {
val shuffleManager = new ShuffleManager()
val broadcastManager = new BroadcastManager(isMaster)
val closureSerializerClass =
System.getProperty("spark.closure.serializer", "spark.JavaSerializer")
val closureSerializer =
@ -119,6 +124,7 @@ object SparkEnv {
mapOutputTracker,
shuffleFetcher,
shuffleManager,
broadcastManager,
blockManager,
connectionManager)
}

View file

@ -2,21 +2,22 @@ package spark.broadcast
import java.io._
import java.net._
import java.util.{BitSet, Comparator, Random, Timer, TimerTask, UUID}
import java.util.{BitSet, Comparator, Timer, TimerTask, UUID}
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.{ListBuffer, Map, Set}
import scala.math
import spark._
import spark.storage.StorageLevel
class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean)
extends Broadcast[T] with Logging with Serializable {
def value = value_
BitTorrentBroadcast.synchronized {
BitTorrentBroadcast.values.put(uuid, 0, value_)
MultiTracker.synchronized {
SparkEnv.get.blockManager.putSingle(uuid.toString, value_, StorageLevel.MEMORY_ONLY, false)
}
@transient var arrayOfBlocks: Array[BroadcastBlock] = null
@ -25,8 +26,6 @@ extends Broadcast[T] with Logging with Serializable {
@transient var totalBytes = -1
@transient var totalBlocks = -1
@transient var hasBlocks = new AtomicInteger(0)
// CHANGED: BlockSize in the Broadcast object is expected to change over time
@transient var blockSize = Broadcast.BlockSize
// Used ONLY by Master to track how many unique blocks have been sent out
@transient var sentBlocks = new AtomicInteger(0)
@ -45,14 +44,10 @@ extends Broadcast[T] with Logging with Serializable {
// Used only in Workers
@transient var ttGuide: TalkToGuide = null
@transient var rxSpeeds = new SpeedTracker
@transient var txSpeeds = new SpeedTracker
@transient var hostAddress = Utils.localIpAddress
@transient var listenPort = -1
@transient var guidePort = -1
@transient var hasCopyInHDFS = false
@transient var stopBroadcast = false
// Must call this after all the variables have been created/initialized
@ -63,19 +58,10 @@ extends Broadcast[T] with Logging with Serializable {
def sendBroadcast() {
logInfo("Local host address: " + hostAddress)
// Store a persistent copy in HDFS
// TODO: Turned OFF for now. Related to persistence
// val out = new ObjectOutputStream(BroadcastCH.openFileForWriting(uuid))
// out.writeObject(value_)
// out.close()
// FIXME: Fix this at some point
hasCopyInHDFS = true
// Create a variableInfo object and store it in valueInfos
var variableInfo = Broadcast.blockifyObject(value_)
var variableInfo = MultiTracker.blockifyObject(value_)
// Prepare the value being broadcasted
// TODO: Refactoring and clean-up required here
arrayOfBlocks = variableInfo.arrayOfBlocks
totalBytes = variableInfo.totalBytes
totalBlocks = variableInfo.totalBlocks
@ -95,9 +81,7 @@ extends Broadcast[T] with Logging with Serializable {
// Must always come AFTER guideMR is created
while (guidePort == -1) {
guidePortLock.synchronized {
guidePortLock.wait()
}
guidePortLock.synchronized { guidePortLock.wait() }
}
serveMR = new ServeMultipleRequests
@ -107,14 +91,12 @@ extends Broadcast[T] with Logging with Serializable {
// Must always come AFTER serveMR is created
while (listenPort == -1) {
listenPortLock.synchronized {
listenPortLock.wait()
}
listenPortLock.synchronized { listenPortLock.wait() }
}
// Must always come AFTER listenPort is created
val masterSource =
SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes, blockSize)
SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes)
hasBlocksBitVector.synchronized {
masterSource.hasBlocksBitVector = hasBlocksBitVector
}
@ -123,46 +105,42 @@ extends Broadcast[T] with Logging with Serializable {
listOfSources += masterSource
// Register with the Tracker
registerBroadcast(uuid,
SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes, blockSize))
MultiTracker.registerBroadcast(uuid,
SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes))
}
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
BitTorrentBroadcast.synchronized {
val cachedVal = BitTorrentBroadcast.values.get(uuid, 0)
MultiTracker.synchronized {
SparkEnv.get.blockManager.getSingle(uuid.toString) match {
case Some(x) => x.asInstanceOf[T]
case None => {
logInfo("Started reading broadcast variable " + uuid)
// Initializing everything because Master will only send null/0 values
// Only the 1st worker in a node can be here. Others will get from cache
initializeWorkerVariables
if (cachedVal != null) {
value_ = cachedVal.asInstanceOf[T]
} else {
// Only the first worker in a node can ever be inside this 'else'
initializeWorkerVariables
logInfo("Local host address: " + hostAddress)
logInfo("Local host address: " + hostAddress)
// Start local ServeMultipleRequests thread first
serveMR = new ServeMultipleRequests
serveMR.setDaemon(true)
serveMR.start()
logInfo("ServeMultipleRequests started...")
// Start local ServeMultipleRequests thread first
serveMR = new ServeMultipleRequests
serveMR.setDaemon(true)
serveMR.start()
logInfo("ServeMultipleRequests started...")
val start = System.nanoTime
val start = System.nanoTime
val receptionSucceeded = receiveBroadcast(uuid)
if (receptionSucceeded) {
value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
SparkEnv.get.blockManager.putSingle(uuid.toString, value_, StorageLevel.MEMORY_ONLY, false)
} else {
logError("Reading Broadcasted variable " + uuid + " failed")
}
val receptionSucceeded = receiveBroadcast(uuid)
// If does not succeed, then get from HDFS copy
if (receptionSucceeded) {
value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
BitTorrentBroadcast.values.put(uuid, 0, value_)
} else {
// TODO: This part won't work, cause HDFS writing is turned OFF
val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid))
value_ = fileIn.readObject.asInstanceOf[T]
BitTorrentBroadcast.values.put(uuid, 0, value_)
fileIn.close()
val time = (System.nanoTime - start) / 1e9
logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s")
}
val time = (System.nanoTime - start) / 1e9
logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s")
}
}
}
@ -175,7 +153,6 @@ extends Broadcast[T] with Logging with Serializable {
totalBytes = -1
totalBlocks = -1
hasBlocks = new AtomicInteger(0)
blockSize = -1
listenPortLock = new Object
totalBlocksLock = new Object
@ -183,9 +160,6 @@ extends Broadcast[T] with Logging with Serializable {
serveMR = null
ttGuide = null
rxSpeeds = new SpeedTracker
txSpeeds = new SpeedTracker
hostAddress = Utils.localIpAddress
listenPort = -1
@ -194,75 +168,19 @@ extends Broadcast[T] with Logging with Serializable {
stopBroadcast = false
}
private def registerBroadcast(uuid: UUID, gInfo: SourceInfo) {
val socket = new Socket(Broadcast.MasterHostAddress,
Broadcast.MasterTrackerPort)
val oosST = new ObjectOutputStream(socket.getOutputStream)
oosST.flush()
val oisST = new ObjectInputStream(socket.getInputStream)
// Send messageType/intention
oosST.writeObject(Broadcast.REGISTER_BROADCAST_TRACKER)
oosST.flush()
// Send UUID of this broadcast
oosST.writeObject(uuid)
oosST.flush()
// Send this tracker's information
oosST.writeObject(gInfo)
oosST.flush()
// Receive ACK and throw it away
oisST.readObject.asInstanceOf[Int]
// Shut stuff down
oisST.close()
oosST.close()
socket.close()
}
private def unregisterBroadcast(uuid: UUID) {
val socket = new Socket(Broadcast.MasterHostAddress,
Broadcast.MasterTrackerPort)
val oosST = new ObjectOutputStream(socket.getOutputStream)
oosST.flush()
val oisST = new ObjectInputStream(socket.getInputStream)
// Send messageType/intention
oosST.writeObject(Broadcast.UNREGISTER_BROADCAST_TRACKER)
oosST.flush()
// Send UUID of this broadcast
oosST.writeObject(uuid)
oosST.flush()
// Receive ACK and throw it away
oisST.readObject.asInstanceOf[Int]
// Shut stuff down
oisST.close()
oosST.close()
socket.close()
}
private def getLocalSourceInfo: SourceInfo = {
// Wait till hostName and listenPort are OK
while (listenPort == -1) {
listenPortLock.synchronized {
listenPortLock.wait()
}
listenPortLock.synchronized { listenPortLock.wait() }
}
// Wait till totalBlocks and totalBytes are OK
while (totalBlocks == -1) {
totalBlocksLock.synchronized {
totalBlocksLock.wait()
}
totalBlocksLock.synchronized { totalBlocksLock.wait() }
}
var localSourceInfo = SourceInfo(
hostAddress, listenPort, totalBlocks, totalBytes, blockSize)
hostAddress, listenPort, totalBlocks, totalBytes)
localSourceInfo.hasBlocks = hasBlocks.get
@ -274,7 +192,7 @@ extends Broadcast[T] with Logging with Serializable {
}
// Add new SourceInfo to the listOfSources. Update if it exists already.
// TODO: Optimizing just by OR-ing the BitVectors was BAD for performance
// Optimizing just by OR-ing the BitVectors was BAD for performance
private def addToListOfSources(newSourceInfo: SourceInfo) {
listOfSources.synchronized {
if (listOfSources.contains(newSourceInfo)) {
@ -297,9 +215,9 @@ extends Broadcast[T] with Logging with Serializable {
// Keep exchaning information until all blocks have been received
while (hasBlocks.get < totalBlocks) {
talkOnce
Thread.sleep(BitTorrentBroadcast.ranGen.nextInt(
Broadcast.MaxKnockInterval - Broadcast.MinKnockInterval) +
Broadcast.MinKnockInterval)
Thread.sleep(MultiTracker.ranGen.nextInt(
MultiTracker.MaxKnockInterval - MultiTracker.MinKnockInterval) +
MultiTracker.MinKnockInterval)
}
// Talk one more time to let the Guide know of reception completion
@ -324,7 +242,7 @@ extends Broadcast[T] with Logging with Serializable {
// Receive source information from Guide
var suitableSources =
oisGuide.readObject.asInstanceOf[ListBuffer[SourceInfo]]
logInfo("Received suitableSources from Master " + suitableSources)
logDebug("Received suitableSources from Master " + suitableSources)
addToListOfSources(suitableSources)
@ -334,76 +252,17 @@ extends Broadcast[T] with Logging with Serializable {
}
}
def getGuideInfo(variableUUID: UUID): SourceInfo = {
var clientSocketToTracker: Socket = null
var oosTracker: ObjectOutputStream = null
var oisTracker: ObjectInputStream = null
var gInfo: SourceInfo = SourceInfo("", SourceInfo.TxOverGoToHDFS)
var retriesLeft = Broadcast.MaxRetryCount
do {
try {
// Connect to the tracker to find out GuideInfo
clientSocketToTracker =
new Socket(Broadcast.MasterHostAddress, Broadcast.MasterTrackerPort)
oosTracker =
new ObjectOutputStream(clientSocketToTracker.getOutputStream)
oosTracker.flush()
oisTracker =
new ObjectInputStream(clientSocketToTracker.getInputStream)
// Send messageType/intention
oosTracker.writeObject(Broadcast.FIND_BROADCAST_TRACKER)
oosTracker.flush()
// Send UUID and receive GuideInfo
oosTracker.writeObject(uuid)
oosTracker.flush()
gInfo = oisTracker.readObject.asInstanceOf[SourceInfo]
} catch {
case e: Exception => {
logInfo("getGuideInfo had a " + e)
}
} finally {
if (oisTracker != null) {
oisTracker.close()
}
if (oosTracker != null) {
oosTracker.close()
}
if (clientSocketToTracker != null) {
clientSocketToTracker.close()
}
}
Thread.sleep(BitTorrentBroadcast.ranGen.nextInt(
Broadcast.MaxKnockInterval - Broadcast.MinKnockInterval) +
Broadcast.MinKnockInterval)
retriesLeft -= 1
} while (retriesLeft > 0 && gInfo.listenPort == SourceInfo.TxNotStartedRetry)
logInfo("Got this guidePort from Tracker: " + gInfo.listenPort)
return gInfo
}
def receiveBroadcast(variableUUID: UUID): Boolean = {
val gInfo = getGuideInfo(variableUUID)
val gInfo = MultiTracker.getGuideInfo(variableUUID)
if (gInfo.listenPort == SourceInfo.TxOverGoToHDFS ||
gInfo.listenPort == SourceInfo.TxNotStartedRetry) {
// TODO: SourceInfo.TxNotStartedRetry is not really in use because we go
// to HDFS anyway when receiveBroadcast returns false
if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) {
return false
}
// Wait until hostAddress and listenPort are created by the
// ServeMultipleRequests thread
while (listenPort == -1) {
listenPortLock.synchronized {
listenPortLock.wait()
}
listenPortLock.synchronized { listenPortLock.wait() }
}
// Setup initial states of variables
@ -411,11 +270,8 @@ extends Broadcast[T] with Logging with Serializable {
arrayOfBlocks = new Array[BroadcastBlock](totalBlocks)
hasBlocksBitVector = new BitSet(totalBlocks)
numCopiesSent = new Array[Int](totalBlocks)
totalBlocksLock.synchronized {
totalBlocksLock.notifyAll()
}
totalBlocksLock.synchronized { totalBlocksLock.notifyAll() }
totalBytes = gInfo.totalBytes
blockSize = gInfo.blockSize
// Start ttGuide to periodically talk to the Guide
var ttGuide = new TalkToGuide(gInfo)
@ -432,7 +288,7 @@ extends Broadcast[T] with Logging with Serializable {
// FIXME: Must fix this. This might never break if broadcast fails.
// We should be able to break and send false. Also need to kill threads
while (hasBlocks.get < totalBlocks) {
Thread.sleep(Broadcast.MaxKnockInterval)
Thread.sleep(MultiTracker.MaxKnockInterval)
}
return true
@ -446,36 +302,34 @@ extends Broadcast[T] with Logging with Serializable {
private var blocksInRequestBitVector = new BitSet(totalBlocks)
override def run() {
var threadPool = Utils.newDaemonFixedThreadPool(Broadcast.MaxRxSlots)
var threadPool = Utils.newDaemonFixedThreadPool(MultiTracker.MaxChatSlots)
while (hasBlocks.get < totalBlocks) {
var numThreadsToCreate =
math.min(listOfSources.size, Broadcast.MaxRxSlots) -
math.min(listOfSources.size, MultiTracker.MaxChatSlots) -
threadPool.getActiveCount
while (hasBlocks.get < totalBlocks && numThreadsToCreate > 0) {
var peerToTalkTo = pickPeerToTalkToRandom
if (peerToTalkTo != null)
logInfo("Peer chosen: " + peerToTalkTo + " with " + peerToTalkTo.hasBlocksBitVector)
logDebug("Peer chosen: " + peerToTalkTo + " with " + peerToTalkTo.hasBlocksBitVector)
else
logInfo("No peer chosen...")
logDebug("No peer chosen...")
if (peerToTalkTo != null) {
threadPool.execute(new TalkToPeer(peerToTalkTo))
// Add to peersNowTalking. Remove in the thread. We have to do this
// ASAP, otherwise pickPeerToTalkTo picks the same peer more than once
peersNowTalking.synchronized {
peersNowTalking += peerToTalkTo
}
peersNowTalking.synchronized { peersNowTalking += peerToTalkTo }
}
numThreadsToCreate = numThreadsToCreate - 1
}
// Sleep for a while before starting some more threads
Thread.sleep(Broadcast.MinKnockInterval)
Thread.sleep(MultiTracker.MinKnockInterval)
}
// Shutdown the thread pool
threadPool.shutdown()
@ -487,7 +341,7 @@ extends Broadcast[T] with Logging with Serializable {
var curPeer: SourceInfo = null
var curMax = 0
logInfo("Picking peers to talk to...")
logDebug("Picking peers to talk to...")
// Find peers that are not connected right now
var peersNotInUse = ListBuffer[SourceInfo]()
@ -512,11 +366,10 @@ extends Broadcast[T] with Logging with Serializable {
}
}
// TODO: Always pick randomly or randomly pick randomly?
// Now always picking randomly
// Always picking randomly
if (curPeer == null && peersNotInUse.size > 0) {
// Pick uniformly the i'th required peer
var i = BitTorrentBroadcast.ranGen.nextInt(peersNotInUse.size)
var i = MultiTracker.ranGen.nextInt(peersNotInUse.size)
var peerIter = peersNotInUse.iterator
curPeer = peerIter.next
@ -552,8 +405,8 @@ extends Broadcast[T] with Logging with Serializable {
}
}
// TODO: A block is rare if there are at most 2 copies of that block
// TODO: This CONSTANT could be a function of the neighborhood size
// A block is considered rare if there are at most 2 copies of that block
// This CONSTANT could be a function of the neighborhood size
var rareBlocksIndices = ListBuffer[Int]()
for (i <- 0 until totalBlocks) {
if (numCopiesPerBlock(i) > 0 && numCopiesPerBlock(i) <= 2) {
@ -587,7 +440,7 @@ extends Broadcast[T] with Logging with Serializable {
// Sort the peers based on how many rare blocks they have
peersWithRareBlocks.sortBy(_._2)
var randomNumber = BitTorrentBroadcast.ranGen.nextDouble
var randomNumber = MultiTracker.ranGen.nextDouble
var tempSum = 0.0
var i = 0
@ -625,7 +478,7 @@ extends Broadcast[T] with Logging with Serializable {
}
var timeOutTimer = new Timer
timeOutTimer.schedule(timeOutTask, Broadcast.MaxKnockInterval)
timeOutTimer.schedule(timeOutTask, MultiTracker.MaxKnockInterval)
logInfo("TalkToPeer started... => " + peerToTalkTo)
@ -677,7 +530,7 @@ extends Broadcast[T] with Logging with Serializable {
val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
val receptionTime = (System.currentTimeMillis - recvStartTime)
logInfo("Received block: " + bcBlock.blockID + " from " + peerToTalkTo + " in " + receptionTime + " millis.")
logDebug("Received block: " + bcBlock.blockID + " from " + peerToTalkTo + " in " + receptionTime + " millis.")
if (!hasBlocksBitVector.get(bcBlock.blockID)) {
arrayOfBlocks(bcBlock.blockID) = bcBlock
@ -688,8 +541,6 @@ extends Broadcast[T] with Logging with Serializable {
hasBlocks.getAndIncrement
}
rxSpeeds.addDataPoint(peerToTalkTo, receptionTime)
// Some block(may NOT be blockToAskFor) has arrived.
// In any case, blockToAskFor is not in request any more
blocksInRequestBitVector.synchronized {
@ -710,7 +561,7 @@ extends Broadcast[T] with Logging with Serializable {
// connection due to timeout
case eofe: java.io.EOFException => { }
case e: Exception => {
logInfo("TalktoPeer had a " + e)
logError("TalktoPeer had a " + e)
// FIXME: Remove 'newPeerToTalkTo' from listOfSources
// We probably should have the following in some form, but not
// really here. This exception can happen if the sender just breaks connection
@ -741,8 +592,8 @@ extends Broadcast[T] with Logging with Serializable {
}
// Include blocks already in transmission ONLY IF
// BitTorrentBroadcast.EndGameFraction has NOT been achieved
if ((1.0 * hasBlocks.get / totalBlocks) < Broadcast.EndGameFraction) {
// MultiTracker.EndGameFraction has NOT been achieved
if ((1.0 * hasBlocks.get / totalBlocks) < MultiTracker.EndGameFraction) {
blocksInRequestBitVector.synchronized {
needBlocksBitVector.or(blocksInRequestBitVector)
}
@ -758,7 +609,7 @@ extends Broadcast[T] with Logging with Serializable {
return -1
} else {
// Pick uniformly the i'th required block
var i = BitTorrentBroadcast.ranGen.nextInt(needBlocksBitVector.cardinality)
var i = MultiTracker.ranGen.nextInt(needBlocksBitVector.cardinality)
var pickedBlockIndex = needBlocksBitVector.nextSetBit(0)
while (i > 0) {
@ -781,8 +632,8 @@ extends Broadcast[T] with Logging with Serializable {
}
// Include blocks already in transmission ONLY IF
// BitTorrentBroadcast.EndGameFraction has NOT been achieved
if ((1.0 * hasBlocks.get / totalBlocks) < Broadcast.EndGameFraction) {
// MultiTracker.EndGameFraction has NOT been achieved
if ((1.0 * hasBlocks.get / totalBlocks) < MultiTracker.EndGameFraction) {
blocksInRequestBitVector.synchronized {
needBlocksBitVector.or(blocksInRequestBitVector)
}
@ -830,7 +681,7 @@ extends Broadcast[T] with Logging with Serializable {
return -1
} else {
// Pick uniformly the i'th index
var i = BitTorrentBroadcast.ranGen.nextInt(minBlocksIndices.size)
var i = MultiTracker.ranGen.nextInt(minBlocksIndices.size)
return minBlocksIndices(i)
}
}
@ -848,9 +699,7 @@ extends Broadcast[T] with Logging with Serializable {
}
// Delete from peersNowTalking
peersNowTalking.synchronized {
peersNowTalking = peersNowTalking - peerToTalkTo
}
peersNowTalking.synchronized { peersNowTalking -= peerToTalkTo }
}
}
}
@ -868,20 +717,18 @@ extends Broadcast[T] with Logging with Serializable {
guidePort = serverSocket.getLocalPort
logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort)
guidePortLock.synchronized {
guidePortLock.notifyAll()
}
guidePortLock.synchronized { guidePortLock.notifyAll() }
try {
// Don't stop until there is a copy in HDFS
while (!stopBroadcast || !hasCopyInHDFS) {
while (!stopBroadcast) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout)
serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
clientSocket = serverSocket.accept()
} catch {
case e: Exception => {
logInfo("GuideMultipleRequests Timeout.")
logError("GuideMultipleRequests Timeout.")
// Stop broadcast if at least one worker has connected and
// everyone connected so far are done. Comparing with
@ -893,7 +740,7 @@ extends Broadcast[T] with Logging with Serializable {
}
}
if (clientSocket != null) {
logInfo("Guide: Accepted new client connection:" + clientSocket)
logDebug("Guide: Accepted new client connection:" + clientSocket)
try {
threadPool.execute(new GuideSingleRequest(clientSocket))
} catch {
@ -911,7 +758,7 @@ extends Broadcast[T] with Logging with Serializable {
logInfo("Sending stopBroadcast notifications...")
sendStopBroadcastNotifications
unregisterBroadcast(uuid)
MultiTracker.unregisterBroadcast(uuid)
} finally {
if (serverSocket != null) {
logInfo("GuideMultipleRequests now stopping...")
@ -930,13 +777,10 @@ extends Broadcast[T] with Logging with Serializable {
try {
// Connect to the source
guideSocketToSource =
new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
gosSource =
new ObjectOutputStream(guideSocketToSource.getOutputStream)
guideSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
gosSource = new ObjectOutputStream(guideSocketToSource.getOutputStream)
gosSource.flush()
gisSource =
new ObjectInputStream(guideSocketToSource.getInputStream)
gisSource = new ObjectInputStream(guideSocketToSource.getInputStream)
// Throw away whatever comes in
gisSource.readObject.asInstanceOf[SourceInfo]
@ -946,7 +790,7 @@ extends Broadcast[T] with Logging with Serializable {
gosSource.flush()
} catch {
case e: Exception => {
logInfo("sendStopBroadcastNotifications had a " + e)
logError("sendStopBroadcastNotifications had a " + e)
}
} finally {
if (gisSource != null) {
@ -980,7 +824,7 @@ extends Broadcast[T] with Logging with Serializable {
// Select a suitable source and send it back to the worker
selectedSources = selectSuitableSources(sourceInfo)
logInfo("Sending selectedSources:" + selectedSources)
logDebug("Sending selectedSources:" + selectedSources)
oos.writeObject(selectedSources)
oos.flush()
@ -990,12 +834,11 @@ extends Broadcast[T] with Logging with Serializable {
case e: Exception => {
// Assuming exception caused by receiver failure: remove
if (listOfSources != null) {
listOfSources.synchronized {
listOfSources = listOfSources - sourceInfo
}
listOfSources.synchronized { listOfSources -= sourceInfo }
}
}
} finally {
logInfo("GuideSingleRequest is closing streams and sockets")
ois.close()
oos.close()
clientSocket.close()
@ -1009,24 +852,22 @@ extends Broadcast[T] with Logging with Serializable {
// If skipSourceInfo.hasBlocksBitVector has all bits set to 'true'
// then add skipSourceInfo to setOfCompletedSources. Return blank.
if (skipSourceInfo.hasBlocks == totalBlocks) {
setOfCompletedSources.synchronized {
setOfCompletedSources += skipSourceInfo
}
setOfCompletedSources.synchronized { setOfCompletedSources += skipSourceInfo }
return selectedSources
}
listOfSources.synchronized {
if (listOfSources.size <= Broadcast.MaxPeersInGuideResponse) {
if (listOfSources.size <= MultiTracker.MaxPeersInGuideResponse) {
selectedSources = listOfSources.clone
} else {
var picksLeft = Broadcast.MaxPeersInGuideResponse
var picksLeft = MultiTracker.MaxPeersInGuideResponse
var alreadyPicked = new BitSet(listOfSources.size)
while (picksLeft > 0) {
var i = -1
do {
i = BitTorrentBroadcast.ranGen.nextInt(listOfSources.size)
i = MultiTracker.ranGen.nextInt(listOfSources.size)
} while (alreadyPicked.get(i))
var peerIter = listOfSources.iterator
@ -1057,8 +898,8 @@ extends Broadcast[T] with Logging with Serializable {
class ServeMultipleRequests
extends Thread with Logging {
// Server at most Broadcast.MaxTxSlots peers
var threadPool = Utils.newDaemonFixedThreadPool(Broadcast.MaxTxSlots)
// Server at most MultiTracker.MaxChatSlots peers
var threadPool = Utils.newDaemonFixedThreadPool(MultiTracker.MaxChatSlots)
override def run() {
var serverSocket = new ServerSocket(0)
@ -1066,30 +907,26 @@ extends Broadcast[T] with Logging with Serializable {
logInfo("ServeMultipleRequests started with " + serverSocket)
listenPortLock.synchronized {
listenPortLock.notifyAll()
}
listenPortLock.synchronized { listenPortLock.notifyAll() }
try {
while (!stopBroadcast) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout)
serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
clientSocket = serverSocket.accept()
} catch {
case e: Exception => {
logInfo("ServeMultipleRequests Timeout.")
logError("ServeMultipleRequests Timeout.")
}
}
if (clientSocket != null) {
logInfo("Serve: Accepted new client connection:" + clientSocket)
logDebug("Serve: Accepted new client connection:" + clientSocket)
try {
threadPool.execute(new ServeSingleRequest(clientSocket))
} catch {
// In failure, close socket here; else, the thread will close it
case ioe: IOException => {
clientSocket.close()
}
case ioe: IOException => clientSocket.close()
}
}
}
@ -1125,14 +962,13 @@ extends Broadcast[T] with Logging with Serializable {
if (rxSourceInfo.listenPort == SourceInfo.StopBroadcast) {
stopBroadcast = true
} else {
// Carry on
addToListOfSources(rxSourceInfo)
}
val startTime = System.currentTimeMillis
var curTime = startTime
var keepSending = true
var numBlocksToSend = Broadcast.MaxChatBlocks
var numBlocksToSend = MultiTracker.MaxChatBlocks
while (!stopBroadcast && keepSending && numBlocksToSend > 0) {
// Receive which block to send
@ -1140,7 +976,7 @@ extends Broadcast[T] with Logging with Serializable {
// If it is master AND at least one copy of each block has not been
// sent out already, MODIFY blockToSend
if (BitTorrentBroadcast.isMaster && sentBlocks.get < totalBlocks) {
if (MultiTracker.isMaster && sentBlocks.get < totalBlocks) {
blockToSend = sentBlocks.getAndIncrement
}
@ -1152,27 +988,21 @@ extends Broadcast[T] with Logging with Serializable {
// Receive latest SourceInfo from the receiver
rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo]
// logInfo("rxSourceInfo: " + rxSourceInfo + " with " + rxSourceInfo.hasBlocksBitVector)
logDebug("rxSourceInfo: " + rxSourceInfo + " with " + rxSourceInfo.hasBlocksBitVector)
addToListOfSources(rxSourceInfo)
curTime = System.currentTimeMillis
// Revoke sending only if there is anyone waiting in the queue
if (curTime - startTime >= Broadcast.MaxChatTime &&
if (curTime - startTime >= MultiTracker.MaxChatTime &&
threadPool.getQueue.size > 0) {
keepSending = false
}
}
} catch {
// If something went wrong, e.g., the worker at the other end died etc.
// then close everything up
// Exception can happen if the receiver stops receiving
case e: Exception => {
logInfo("ServeSingleRequest had a " + e)
}
case e: Exception => logError("ServeSingleRequest had a " + e)
} finally {
logInfo("ServeSingleRequest is closing streams and sockets")
ois.close()
// TODO: The following line causes a "java.net.SocketException: Socket closed"
oos.close()
clientSocket.close()
}
@ -1183,11 +1013,9 @@ extends Broadcast[T] with Logging with Serializable {
oos.writeObject(arrayOfBlocks(blockToSend))
oos.flush()
} catch {
case e: Exception => {
logInfo("sendBlock had a " + e)
}
case e: Exception => logError("sendBlock had a " + e)
}
logInfo("Sent block: " + blockToSend + " to " + clientSocket)
logDebug("Sent block: " + blockToSend + " to " + clientSocket)
}
}
}
@ -1195,161 +1023,7 @@ extends Broadcast[T] with Logging with Serializable {
class BitTorrentBroadcastFactory
extends BroadcastFactory {
def initialize(isMaster: Boolean) {
BitTorrentBroadcast.initialize(isMaster)
}
def newBroadcast[T](value_ : T, isLocal: Boolean) = {
new BitTorrentBroadcast[T](value_, isLocal)
}
}
private object BitTorrentBroadcast
extends Logging {
val values = SparkEnv.get.cache.newKeySpace()
var valueToGuideMap = Map[UUID, SourceInfo]()
// Random number generator
var ranGen = new Random
private var initialized = false
private var isMaster_ = false
private var trackMV: TrackMultipleValues = null
def initialize(isMaster__ : Boolean) {
synchronized {
if (!initialized) {
isMaster_ = isMaster__
if (isMaster) {
trackMV = new TrackMultipleValues
trackMV.setDaemon(true)
trackMV.start()
// TODO: Logging the following line makes the Spark framework ID not
// getting logged, cause it calls logInfo before log4j is initialized
logInfo("TrackMultipleValues started...")
}
// Initialize DfsBroadcast to be used for broadcast variable persistence
// TODO: Think about persistence
DfsBroadcast.initialize
initialized = true
}
}
}
def isMaster = isMaster_
class TrackMultipleValues
extends Thread with Logging {
override def run() {
var threadPool = Utils.newDaemonCachedThreadPool()
var serverSocket: ServerSocket = null
serverSocket = new ServerSocket(Broadcast.MasterTrackerPort)
logInfo("TrackMultipleValues" + serverSocket)
try {
while (true) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout(Broadcast.TrackerSocketTimeout)
clientSocket = serverSocket.accept()
} catch {
case e: Exception => {
logInfo("TrackMultipleValues Timeout. Stopping listening...")
}
}
if (clientSocket != null) {
try {
threadPool.execute(new Thread {
override def run() {
val oos = new ObjectOutputStream(clientSocket.getOutputStream)
oos.flush()
val ois = new ObjectInputStream(clientSocket.getInputStream)
try {
// First, read message type
val messageType = ois.readObject.asInstanceOf[Int]
if (messageType == Broadcast.REGISTER_BROADCAST_TRACKER) {
// Receive UUID
val uuid = ois.readObject.asInstanceOf[UUID]
// Receive hostAddress and listenPort
val gInfo = ois.readObject.asInstanceOf[SourceInfo]
// Add to the map
valueToGuideMap.synchronized {
valueToGuideMap += (uuid -> gInfo)
}
logInfo ("New broadcast registered with TrackMultipleValues " + uuid + " " + valueToGuideMap)
// Send dummy ACK
oos.writeObject(-1)
oos.flush()
} else if (messageType == Broadcast.UNREGISTER_BROADCAST_TRACKER) {
// Receive UUID
val uuid = ois.readObject.asInstanceOf[UUID]
// Remove from the map
valueToGuideMap.synchronized {
valueToGuideMap(uuid) = SourceInfo("", SourceInfo.TxOverGoToHDFS)
logInfo("Value unregistered from the Tracker " + valueToGuideMap)
}
logInfo ("Broadcast unregistered from TrackMultipleValues " + uuid + " " + valueToGuideMap)
// Send dummy ACK
oos.writeObject(-1)
oos.flush()
} else if (messageType == Broadcast.FIND_BROADCAST_TRACKER) {
// Receive UUID
val uuid = ois.readObject.asInstanceOf[UUID]
var gInfo =
if (valueToGuideMap.contains(uuid)) valueToGuideMap(uuid)
else SourceInfo("", SourceInfo.TxNotStartedRetry)
logInfo("TrackMultipleValues: Got new request: " + clientSocket + " for " + uuid + " : " + gInfo.listenPort)
// Send reply back
oos.writeObject(gInfo)
oos.flush()
} else if (messageType == Broadcast.GET_UPDATED_SHARE) {
// TODO: Not implemented
} else {
throw new SparkException("Undefined messageType at TrackMultipleValues")
}
} catch {
case e: Exception => {
logInfo("TrackMultipleValues had a " + e)
}
} finally {
ois.close()
oos.close()
clientSocket.close()
}
}
})
} catch {
// In failure, close socket here; else, client thread will close
case ioe: IOException => {
clientSocket.close()
}
}
}
}
} finally {
serverSocket.close()
}
// Shutdown the thread pool
threadPool.shutdown()
}
}
def initialize(isMaster: Boolean) = MultiTracker.initialize(isMaster)
def newBroadcast[T](value_ : T, isLocal: Boolean) = new BitTorrentBroadcast[T](value_, isLocal)
def stop() = MultiTracker.stop
}

View file

@ -5,6 +5,8 @@ import java.net._
import java.util.{BitSet, UUID}
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
import scala.collection.mutable.Map
import spark._
trait Broadcast[T] extends Serializable {
@ -13,24 +15,20 @@ trait Broadcast[T] extends Serializable {
def value: T
// We cannot have an abstract readObject here due to some weird issues with
// readObject having to be 'private' in sub-classes. Possibly a Scala bug!
// readObject having to be 'private' in sub-classes.
override def toString = "spark.Broadcast(" + uuid + ")"
}
object Broadcast extends Logging with Serializable {
// Messages
val REGISTER_BROADCAST_TRACKER = 0
val UNREGISTER_BROADCAST_TRACKER = 1
val FIND_BROADCAST_TRACKER = 2
val GET_UPDATED_SHARE = 3
class BroadcastManager(val isMaster_ : Boolean) extends Logging with Serializable {
private var initialized = false
private var isMaster_ = false
private var broadcastFactory: BroadcastFactory = null
initialize()
// Called by SparkContext or Executor before using Broadcast
def initialize (isMaster__ : Boolean) {
private def initialize() {
synchronized {
if (!initialized) {
val broadcastFactoryClass = System.getProperty(
@ -39,14 +37,6 @@ object Broadcast extends Logging with Serializable {
broadcastFactory =
Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
// Setup isMaster before using it
isMaster_ = isMaster__
// Set masterHostAddress to the master's IP address for the slaves to read
if (isMaster) {
System.setProperty("spark.broadcast.masterHostAddress", Utils.localIpAddress)
}
// Initialize appropriate BroadcastFactory and BroadcastObject
broadcastFactory.initialize(isMaster)
@ -55,170 +45,18 @@ object Broadcast extends Logging with Serializable {
}
}
def getBroadcastFactory: BroadcastFactory = {
def stop() {
broadcastFactory.stop()
}
private def getBroadcastFactory: BroadcastFactory = {
if (broadcastFactory == null) {
throw new SparkException ("Broadcast.getBroadcastFactory called before initialize")
}
broadcastFactory
}
// Load common broadcast-related config parameters
private var MasterHostAddress_ = System.getProperty(
"spark.broadcast.masterHostAddress", "")
private var MasterTrackerPort_ = System.getProperty(
"spark.broadcast.masterTrackerPort", "11111").toInt
private var BlockSize_ = System.getProperty(
"spark.broadcast.blockSize", "4096").toInt * 1024
private var MaxRetryCount_ = System.getProperty(
"spark.broadcast.maxRetryCount", "2").toInt
private var TrackerSocketTimeout_ = System.getProperty(
"spark.broadcast.trackerSocketTimeout", "50000").toInt
private var ServerSocketTimeout_ = System.getProperty(
"spark.broadcast.serverSocketTimeout", "10000").toInt
private var MinKnockInterval_ = System.getProperty(
"spark.broadcast.minKnockInterval", "500").toInt
private var MaxKnockInterval_ = System.getProperty(
"spark.broadcast.maxKnockInterval", "999").toInt
// Load ChainedBroadcast config params
// Load TreeBroadcast config params
private var MaxDegree_ = System.getProperty("spark.broadcast.maxDegree", "2").toInt
// Load BitTorrentBroadcast config params
private var MaxPeersInGuideResponse_ = System.getProperty(
"spark.broadcast.maxPeersInGuideResponse", "4").toInt
private var MaxRxSlots_ = System.getProperty("spark.broadcast.maxRxSlots", "4").toInt
private var MaxTxSlots_ = System.getProperty("spark.broadcast.maxTxSlots", "4").toInt
private var MaxChatTime_ = System.getProperty("spark.broadcast.maxChatTime", "500").toInt
private var MaxChatBlocks_ = System.getProperty("spark.broadcast.maxChatBlocks", "1024").toInt
private var EndGameFraction_ = System.getProperty(
"spark.broadcast.endGameFraction", "0.95").toDouble
def newBroadcast[T](value_ : T, isLocal: Boolean) = broadcastFactory.newBroadcast[T](value_, isLocal)
def isMaster = isMaster_
// Common config params
def MasterHostAddress = MasterHostAddress_
def MasterTrackerPort = MasterTrackerPort_
def BlockSize = BlockSize_
def MaxRetryCount = MaxRetryCount_
def TrackerSocketTimeout = TrackerSocketTimeout_
def ServerSocketTimeout = ServerSocketTimeout_
def MinKnockInterval = MinKnockInterval_
def MaxKnockInterval = MaxKnockInterval_
// ChainedBroadcast configs
// TreeBroadcast configs
def MaxDegree = MaxDegree_
// BitTorrentBroadcast configs
def MaxPeersInGuideResponse = MaxPeersInGuideResponse_
def MaxRxSlots = MaxRxSlots_
def MaxTxSlots = MaxTxSlots_
def MaxChatTime = MaxChatTime_
def MaxChatBlocks = MaxChatBlocks_
def EndGameFraction = EndGameFraction_
// Helper functions to convert an object to Array[BroadcastBlock]
def blockifyObject[IN](obj: IN): VariableInfo = {
val baos = new ByteArrayOutputStream
val oos = new ObjectOutputStream(baos)
oos.writeObject(obj)
oos.close()
baos.close()
val byteArray = baos.toByteArray
val bais = new ByteArrayInputStream(byteArray)
var blockNum = (byteArray.length / Broadcast.BlockSize)
if (byteArray.length % Broadcast.BlockSize != 0)
blockNum += 1
var retVal = new Array[BroadcastBlock](blockNum)
var blockID = 0
for (i <- 0 until (byteArray.length, Broadcast.BlockSize)) {
val thisBlockSize = math.min(Broadcast.BlockSize, byteArray.length - i)
var tempByteArray = new Array[Byte](thisBlockSize)
val hasRead = bais.read(tempByteArray, 0, thisBlockSize)
retVal(blockID) = new BroadcastBlock(blockID, tempByteArray)
blockID += 1
}
bais.close()
var variableInfo = VariableInfo(retVal, blockNum, byteArray.length)
variableInfo.hasBlocks = blockNum
return variableInfo
}
// Helper function to convert Array[BroadcastBlock] to object
def unBlockifyObject[OUT](arrayOfBlocks: Array[BroadcastBlock],
totalBytes: Int,
totalBlocks: Int): OUT = {
var retByteArray = new Array[Byte](totalBytes)
for (i <- 0 until totalBlocks) {
System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
i * Broadcast.BlockSize, arrayOfBlocks(i).byteArray.length)
}
byteArrayToObject(retByteArray)
}
private def byteArrayToObject[OUT](bytes: Array[Byte]): OUT = {
val in = new ObjectInputStream (new ByteArrayInputStream (bytes)) {
override def resolveClass(desc: ObjectStreamClass) =
Class.forName(desc.getName, false, Thread.currentThread.getContextClassLoader)
}
val retVal = in.readObject.asInstanceOf[OUT]
in.close()
return retVal
}
}
case class BroadcastBlock (blockID: Int, byteArray: Array[Byte]) extends Serializable
case class VariableInfo (@transient arrayOfBlocks : Array[BroadcastBlock],
totalBlocks: Int,
totalBytes: Int)
extends Serializable {
@transient
var hasBlocks = 0
}
class SpeedTracker extends Serializable {
// Mapping 'source' to '(totalTime, numBlocks)'
private var sourceToSpeedMap = Map[SourceInfo, (Long, Int)] ()
def addDataPoint (srcInfo: SourceInfo, timeInMillis: Long) {
sourceToSpeedMap.synchronized {
if (!sourceToSpeedMap.contains(srcInfo)) {
sourceToSpeedMap += (srcInfo -> (timeInMillis, 1))
} else {
val tTnB = sourceToSpeedMap (srcInfo)
sourceToSpeedMap += (srcInfo -> (tTnB._1 + timeInMillis, tTnB._2 + 1))
}
}
}
def getTimePerBlock (srcInfo: SourceInfo): Double = {
sourceToSpeedMap.synchronized {
val tTnB = sourceToSpeedMap (srcInfo)
return tTnB._1 / tTnB._2
}
}
override def toString = sourceToSpeedMap.toString
}

View file

@ -9,4 +9,5 @@ package spark.broadcast
trait BroadcastFactory {
def initialize(isMaster: Boolean): Unit
def newBroadcast[T](value_ : T, isLocal: Boolean): Broadcast[T]
def stop(): Unit
}

View file

@ -1,794 +0,0 @@
package spark.broadcast
import java.io._
import java.net._
import java.util.{Comparator, PriorityQueue, Random, UUID}
import scala.collection.mutable.{Map, Set}
import scala.math
import spark._
class ChainedBroadcast[T](@transient var value_ : T, isLocal: Boolean)
extends Broadcast[T] with Logging with Serializable {
def value = value_
ChainedBroadcast.synchronized {
ChainedBroadcast.values.put(uuid, 0, value_)
}
@transient var arrayOfBlocks: Array[BroadcastBlock] = null
@transient var totalBytes = -1
@transient var totalBlocks = -1
@transient var hasBlocks = 0
// CHANGED: BlockSize in the Broadcast object is expected to change over time
@transient var blockSize = Broadcast.BlockSize
@transient var listenPortLock = new Object
@transient var guidePortLock = new Object
@transient var totalBlocksLock = new Object
@transient var hasBlocksLock = new Object
@transient var pqOfSources = new PriorityQueue[SourceInfo]
@transient var serveMR: ServeMultipleRequests = null
@transient var guideMR: GuideMultipleRequests = null
@transient var hostAddress = Utils.localIpAddress
@transient var listenPort = -1
@transient var guidePort = -1
@transient var hasCopyInHDFS = false
@transient var stopBroadcast = false
// Must call this after all the variables have been created/initialized
if (!isLocal) {
sendBroadcast
}
def sendBroadcast() {
logInfo("Local host address: " + hostAddress)
// Store a persistent copy in HDFS
// TODO: Turned OFF for now
// val out = new ObjectOutputStream(DfsBroadcast.openFileForWriting(uuid))
// out.writeObject(value_)
// out.close()
// TODO: Fix this at some point
hasCopyInHDFS = true
// Create a variableInfo object and store it in valueInfos
var variableInfo = Broadcast.blockifyObject(value_)
guideMR = new GuideMultipleRequests
guideMR.setDaemon(true)
guideMR.start()
logInfo("GuideMultipleRequests started...")
serveMR = new ServeMultipleRequests
serveMR.setDaemon(true)
serveMR.start()
logInfo("ServeMultipleRequests started...")
// Prepare the value being broadcasted
// TODO: Refactoring and clean-up required here
arrayOfBlocks = variableInfo.arrayOfBlocks
totalBytes = variableInfo.totalBytes
totalBlocks = variableInfo.totalBlocks
hasBlocks = variableInfo.totalBlocks
while (listenPort == -1) {
listenPortLock.synchronized {
listenPortLock.wait()
}
}
pqOfSources = new PriorityQueue[SourceInfo]
val masterSource =
SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes, blockSize)
pqOfSources.add(masterSource)
// Register with the Tracker
while (guidePort == -1) {
guidePortLock.synchronized {
guidePortLock.wait()
}
}
ChainedBroadcast.registerValue(uuid, guidePort)
}
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
ChainedBroadcast.synchronized {
val cachedVal = ChainedBroadcast.values.get(uuid, 0)
if (cachedVal != null) {
value_ = cachedVal.asInstanceOf[T]
} else {
// Initializing everything because Master will only send null/0 values
initializeSlaveVariables
logInfo("Local host address: " + hostAddress)
serveMR = new ServeMultipleRequests
serveMR.setDaemon(true)
serveMR.start()
logInfo("ServeMultipleRequests started...")
val start = System.nanoTime
val receptionSucceeded = receiveBroadcast(uuid)
// If does not succeed, then get from HDFS copy
if (receptionSucceeded) {
value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
ChainedBroadcast.values.put(uuid, 0, value_)
} else {
val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid))
value_ = fileIn.readObject.asInstanceOf[T]
ChainedBroadcast.values.put(uuid, 0, value_)
fileIn.close()
}
val time =(System.nanoTime - start) / 1e9
logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s")
}
}
}
private def initializeSlaveVariables() {
arrayOfBlocks = null
totalBytes = -1
totalBlocks = -1
hasBlocks = 0
blockSize = -1
listenPortLock = new Object
totalBlocksLock = new Object
hasBlocksLock = new Object
serveMR = null
hostAddress = Utils.localIpAddress
listenPort = -1
stopBroadcast = false
}
def getMasterListenPort(variableUUID: UUID): Int = {
var clientSocketToTracker: Socket = null
var oosTracker: ObjectOutputStream = null
var oisTracker: ObjectInputStream = null
var masterListenPort: Int = SourceInfo.TxOverGoToHDFS
var retriesLeft = Broadcast.MaxRetryCount
do {
try {
// Connect to the tracker to find out the guide
clientSocketToTracker =
new Socket(Broadcast.MasterHostAddress, Broadcast.MasterTrackerPort)
oosTracker =
new ObjectOutputStream(clientSocketToTracker.getOutputStream)
oosTracker.flush()
oisTracker =
new ObjectInputStream(clientSocketToTracker.getInputStream)
// Send UUID and receive masterListenPort
oosTracker.writeObject(uuid)
oosTracker.flush()
masterListenPort = oisTracker.readObject.asInstanceOf[Int]
} catch {
case e: Exception => {
logInfo("getMasterListenPort had a " + e)
}
} finally {
if (oisTracker != null) {
oisTracker.close()
}
if (oosTracker != null) {
oosTracker.close()
}
if (clientSocketToTracker != null) {
clientSocketToTracker.close()
}
}
retriesLeft -= 1
Thread.sleep(ChainedBroadcast.ranGen.nextInt(
Broadcast.MaxKnockInterval - Broadcast.MinKnockInterval) +
Broadcast.MinKnockInterval)
} while (retriesLeft > 0 && masterListenPort == SourceInfo.TxNotStartedRetry)
logInfo("Got this guidePort from Tracker: " + masterListenPort)
return masterListenPort
}
def receiveBroadcast(variableUUID: UUID): Boolean = {
val masterListenPort = getMasterListenPort(variableUUID)
if (masterListenPort == SourceInfo.TxOverGoToHDFS ||
masterListenPort == SourceInfo.TxNotStartedRetry) {
// TODO: SourceInfo.TxNotStartedRetry is not really in use because we go
// to HDFS anyway when receiveBroadcast returns false
return false
}
// Wait until hostAddress and listenPort are created by the
// ServeMultipleRequests thread
while (listenPort == -1) {
listenPortLock.synchronized {
listenPortLock.wait()
}
}
var clientSocketToMaster: Socket = null
var oosMaster: ObjectOutputStream = null
var oisMaster: ObjectInputStream = null
// Connect and receive broadcast from the specified source, retrying the
// specified number of times in case of failures
var retriesLeft = Broadcast.MaxRetryCount
do {
// Connect to Master and send this worker's Information
clientSocketToMaster =
new Socket(Broadcast.MasterHostAddress, masterListenPort)
// TODO: Guiding object connection is reusable
oosMaster =
new ObjectOutputStream(clientSocketToMaster.getOutputStream)
oosMaster.flush()
oisMaster =
new ObjectInputStream(clientSocketToMaster.getInputStream)
logInfo("Connected to Master's guiding object")
// Send local source information
oosMaster.writeObject(SourceInfo(hostAddress, listenPort))
oosMaster.flush()
// Receive source information from Master
var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo]
totalBlocks = sourceInfo.totalBlocks
arrayOfBlocks = new Array[BroadcastBlock](totalBlocks)
totalBlocksLock.synchronized {
totalBlocksLock.notifyAll()
}
totalBytes = sourceInfo.totalBytes
logInfo("Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort)
val start = System.nanoTime
val receptionSucceeded = receiveSingleTransmission(sourceInfo)
val time =(System.nanoTime - start) / 1e9
// Updating some statistics in sourceInfo. Master will be using them later
if (!receptionSucceeded) {
sourceInfo.receptionFailed = true
}
// Send back statistics to the Master
oosMaster.writeObject(sourceInfo)
if (oisMaster != null) {
oisMaster.close()
}
if (oosMaster != null) {
oosMaster.close()
}
if (clientSocketToMaster != null) {
clientSocketToMaster.close()
}
retriesLeft -= 1
} while (retriesLeft > 0 && hasBlocks < totalBlocks)
return(hasBlocks == totalBlocks)
}
// Tries to receive broadcast from the source and returns Boolean status.
// This might be called multiple times to retry a defined number of times.
private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = {
var clientSocketToSource: Socket = null
var oosSource: ObjectOutputStream = null
var oisSource: ObjectInputStream = null
var receptionSucceeded = false
try {
// Connect to the source to get the object itself
clientSocketToSource =
new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
oosSource =
new ObjectOutputStream(clientSocketToSource.getOutputStream)
oosSource.flush()
oisSource =
new ObjectInputStream(clientSocketToSource.getInputStream)
logInfo("Inside receiveSingleTransmission")
logInfo("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks)
// Send the range
oosSource.writeObject((hasBlocks, totalBlocks))
oosSource.flush()
for (i <- hasBlocks until totalBlocks) {
val recvStartTime = System.currentTimeMillis
val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
val receptionTime =(System.currentTimeMillis - recvStartTime)
logInfo("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.")
arrayOfBlocks(hasBlocks) = bcBlock
hasBlocks += 1
// Set to true if at least one block is received
receptionSucceeded = true
hasBlocksLock.synchronized {
hasBlocksLock.notifyAll()
}
}
} catch {
case e: Exception => {
logInfo("receiveSingleTransmission had a " + e)
}
} finally {
if (oisSource != null) {
oisSource.close()
}
if (oosSource != null) {
oosSource.close()
}
if (clientSocketToSource != null) {
clientSocketToSource.close()
}
}
return receptionSucceeded
}
class GuideMultipleRequests
extends Thread with Logging {
// Keep track of sources that have completed reception
private var setOfCompletedSources = Set[SourceInfo]()
override def run() {
var threadPool = Utils.newDaemonCachedThreadPool()
var serverSocket: ServerSocket = null
serverSocket = new ServerSocket(0)
guidePort = serverSocket.getLocalPort
logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort)
guidePortLock.synchronized {
guidePortLock.notifyAll()
}
try {
// Don't stop until there is a copy in HDFS
while (!stopBroadcast || !hasCopyInHDFS) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
case e: Exception => {
logInfo("GuideMultipleRequests Timeout.")
// Stop broadcast if at least one worker has connected and
// everyone connected so far are done. Comparing with
// pqOfSources.size - 1, because it includes the Guide itself
if (pqOfSources.size > 1 &&
setOfCompletedSources.size == pqOfSources.size - 1) {
stopBroadcast = true
}
}
}
if (clientSocket != null) {
logInfo("Guide: Accepted new client connection: " + clientSocket)
try {
threadPool.execute(new GuideSingleRequest(clientSocket))
} catch {
// In failure, close the socket here; else, the thread will close it
case ioe: IOException => clientSocket.close()
}
}
}
logInfo("Sending stopBroadcast notifications...")
sendStopBroadcastNotifications
ChainedBroadcast.unregisterValue(uuid)
} finally {
if (serverSocket != null) {
logInfo("GuideMultipleRequests now stopping...")
serverSocket.close()
}
}
// Shutdown the thread pool
threadPool.shutdown()
}
private def sendStopBroadcastNotifications() {
pqOfSources.synchronized {
var pqIter = pqOfSources.iterator
while (pqIter.hasNext) {
var sourceInfo = pqIter.next
var guideSocketToSource: Socket = null
var gosSource: ObjectOutputStream = null
var gisSource: ObjectInputStream = null
try {
// Connect to the source
guideSocketToSource =
new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
gosSource =
new ObjectOutputStream(guideSocketToSource.getOutputStream)
gosSource.flush()
gisSource =
new ObjectInputStream(guideSocketToSource.getInputStream)
// Send stopBroadcast signal. Range = SourceInfo.StopBroadcast*2
gosSource.writeObject((SourceInfo.StopBroadcast,
SourceInfo.StopBroadcast))
gosSource.flush()
} catch {
case e: Exception => {
logInfo("sendStopBroadcastNotifications had a " + e)
}
} finally {
if (gisSource != null) {
gisSource.close()
}
if (gosSource != null) {
gosSource.close()
}
if (guideSocketToSource != null) {
guideSocketToSource.close()
}
}
}
}
}
class GuideSingleRequest(val clientSocket: Socket)
extends Thread with Logging {
private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
oos.flush()
private val ois = new ObjectInputStream(clientSocket.getInputStream)
private var selectedSourceInfo: SourceInfo = null
private var thisWorkerInfo:SourceInfo = null
override def run() {
try {
logInfo("new GuideSingleRequest is running")
// Connecting worker is sending in its hostAddress and listenPort it will
// be listening to. Other fields are invalid(SourceInfo.UnusedParam)
var sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
pqOfSources.synchronized {
// Select a suitable source and send it back to the worker
selectedSourceInfo = selectSuitableSource(sourceInfo)
logInfo("Sending selectedSourceInfo: " + selectedSourceInfo)
oos.writeObject(selectedSourceInfo)
oos.flush()
// Add this new(if it can finish) source to the PQ of sources
thisWorkerInfo = SourceInfo(sourceInfo.hostAddress,
sourceInfo.listenPort, totalBlocks, totalBytes, blockSize)
logInfo("Adding possible new source to pqOfSources: " + thisWorkerInfo)
pqOfSources.add(thisWorkerInfo)
}
// Wait till the whole transfer is done. Then receive and update source
// statistics in pqOfSources
sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
pqOfSources.synchronized {
// This should work since SourceInfo is a case class
assert(pqOfSources.contains(selectedSourceInfo))
// Remove first
pqOfSources.remove(selectedSourceInfo)
// TODO: Removing a source based on just one failure notification!
// Update sourceInfo and put it back in, IF reception succeeded
if (!sourceInfo.receptionFailed) {
// Add thisWorkerInfo to sources that have completed reception
setOfCompletedSources.synchronized {
setOfCompletedSources += thisWorkerInfo
}
selectedSourceInfo.currentLeechers -= 1
// Put it back
pqOfSources.add(selectedSourceInfo)
}
}
} catch {
// If something went wrong, e.g., the worker at the other end died etc.
// then close everything up
case e: Exception => {
// Assuming that exception caused due to receiver worker failure.
// Remove failed worker from pqOfSources and update leecherCount of
// corresponding source worker
pqOfSources.synchronized {
if (selectedSourceInfo != null) {
// Remove first
pqOfSources.remove(selectedSourceInfo)
// Update leecher count and put it back in
selectedSourceInfo.currentLeechers -= 1
pqOfSources.add(selectedSourceInfo)
}
// Remove thisWorkerInfo
if (pqOfSources != null) {
pqOfSources.remove(thisWorkerInfo)
}
}
}
} finally {
ois.close()
oos.close()
clientSocket.close()
}
}
// FIXME: Caller must have a synchronized block on pqOfSources
// FIXME: If a worker fails to get the broadcasted variable from a source and
// comes back to Master, this function might choose the worker itself as a
// source tp create a dependency cycle(this worker was put into pqOfSources
// as a streming source when it first arrived). The length of this cycle can
// be arbitrarily long.
private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = {
// Select one based on the ordering strategy(e.g., least leechers etc.)
// take is a blocking call removing the element from PQ
var selectedSource = pqOfSources.poll
assert(selectedSource != null)
// Update leecher count
selectedSource.currentLeechers += 1
// Add it back and then return
pqOfSources.add(selectedSource)
return selectedSource
}
}
}
class ServeMultipleRequests
extends Thread with Logging {
override def run() {
var threadPool = Utils.newDaemonCachedThreadPool()
var serverSocket: ServerSocket = null
serverSocket = new ServerSocket(0)
listenPort = serverSocket.getLocalPort
logInfo("ServeMultipleRequests started with " + serverSocket)
listenPortLock.synchronized {
listenPortLock.notifyAll()
}
try {
while (!stopBroadcast) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
case e: Exception => {
logInfo("ServeMultipleRequests Timeout.")
}
}
if (clientSocket != null) {
logInfo("Serve: Accepted new client connection: " + clientSocket)
try {
threadPool.execute(new ServeSingleRequest(clientSocket))
} catch {
// In failure, close socket here; else, the thread will close it
case ioe: IOException => clientSocket.close()
}
}
}
} finally {
if (serverSocket != null) {
logInfo("ServeMultipleRequests now stopping...")
serverSocket.close()
}
}
// Shutdown the thread pool
threadPool.shutdown()
}
class ServeSingleRequest(val clientSocket: Socket)
extends Thread with Logging {
private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
oos.flush()
private val ois = new ObjectInputStream(clientSocket.getInputStream)
private var sendFrom = 0
private var sendUntil = totalBlocks
override def run() {
try {
logInfo("new ServeSingleRequest is running")
// Receive range to send
var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)]
sendFrom = rangeToSend._1
sendUntil = rangeToSend._2
if (sendFrom == SourceInfo.StopBroadcast &&
sendUntil == SourceInfo.StopBroadcast) {
stopBroadcast = true
} else {
// Carry on
sendObject
}
} catch {
// If something went wrong, e.g., the worker at the other end died etc.
// then close everything up
case e: Exception => {
logInfo("ServeSingleRequest had a " + e)
}
} finally {
logInfo("ServeSingleRequest is closing streams and sockets")
ois.close()
oos.close()
clientSocket.close()
}
}
private def sendObject() {
// Wait till receiving the SourceInfo from Master
while (totalBlocks == -1) {
totalBlocksLock.synchronized {
totalBlocksLock.wait()
}
}
for (i <- sendFrom until sendUntil) {
while (i == hasBlocks) {
hasBlocksLock.synchronized {
hasBlocksLock.wait()
}
}
try {
oos.writeObject(arrayOfBlocks(i))
oos.flush()
} catch {
case e: Exception => {
logInfo("sendObject had a " + e)
}
}
logInfo("Sent block: " + i + " to " + clientSocket)
}
}
}
}
}
class ChainedBroadcastFactory
extends BroadcastFactory {
def initialize(isMaster: Boolean) {
ChainedBroadcast.initialize(isMaster)
}
def newBroadcast[T](value_ : T, isLocal: Boolean) = {
new ChainedBroadcast[T](value_, isLocal)
}
}
private object ChainedBroadcast
extends Logging {
val values = SparkEnv.get.cache.newKeySpace()
var valueToGuidePortMap = Map[UUID, Int]()
// Random number generator
var ranGen = new Random
private var initialized = false
private var isMaster_ = false
private var trackMV: TrackMultipleValues = null
def initialize(isMaster__ : Boolean) {
synchronized {
if (!initialized) {
isMaster_ = isMaster__
if (isMaster) {
trackMV = new TrackMultipleValues
trackMV.setDaemon(true)
trackMV.start()
// TODO: Logging the following line makes the Spark framework ID not
// getting logged, cause it calls logInfo before log4j is initialized
logInfo("TrackMultipleValues started...")
}
// Initialize DfsBroadcast to be used for broadcast variable persistence
DfsBroadcast.initialize
initialized = true
}
}
}
def isMaster = isMaster_
def registerValue(uuid: UUID, guidePort: Int) {
valueToGuidePortMap.synchronized {
valueToGuidePortMap +=(uuid -> guidePort)
logInfo("New value registered with the Tracker " + valueToGuidePortMap)
}
}
def unregisterValue(uuid: UUID) {
valueToGuidePortMap.synchronized {
valueToGuidePortMap(uuid) = SourceInfo.TxOverGoToHDFS
logInfo("Value unregistered from the Tracker " + valueToGuidePortMap)
}
}
class TrackMultipleValues
extends Thread with Logging {
override def run() {
var threadPool = Utils.newDaemonCachedThreadPool()
var serverSocket: ServerSocket = null
serverSocket = new ServerSocket(Broadcast.MasterTrackerPort)
logInfo("TrackMultipleValues" + serverSocket)
try {
while (true) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout(Broadcast.TrackerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
case e: Exception => {
logInfo("TrackMultipleValues Timeout. Stopping listening...")
}
}
if (clientSocket != null) {
try {
threadPool.execute(new Thread {
override def run() {
val oos = new ObjectOutputStream(clientSocket.getOutputStream)
oos.flush()
val ois = new ObjectInputStream(clientSocket.getInputStream)
try {
val uuid = ois.readObject.asInstanceOf[UUID]
var guidePort =
if (valueToGuidePortMap.contains(uuid)) {
valueToGuidePortMap(uuid)
} else SourceInfo.TxNotStartedRetry
logInfo("TrackMultipleValues: Got new request: " + clientSocket + " for " + uuid + " : " + guidePort)
oos.writeObject(guidePort)
} catch {
case e: Exception => {
logInfo("TrackMultipleValues had a " + e)
}
} finally {
ois.close()
oos.close()
clientSocket.close()
}
}
})
} catch {
// In failure, close socket here; else, client thread will close
case ioe: IOException => clientSocket.close()
}
}
}
} finally {
serverSocket.close()
}
// Shutdown the thread pool
threadPool.shutdown()
}
}
}

View file

@ -1,135 +0,0 @@
package spark.broadcast
import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}
import java.io._
import java.net._
import java.util.UUID
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path, RawLocalFileSystem}
import spark._
class DfsBroadcast[T](@transient var value_ : T, isLocal: Boolean)
extends Broadcast[T] with Logging with Serializable {
def value = value_
DfsBroadcast.synchronized {
DfsBroadcast.values.put(uuid, 0, value_)
}
if (!isLocal) {
sendBroadcast
}
def sendBroadcast () {
val out = new ObjectOutputStream (DfsBroadcast.openFileForWriting(uuid))
out.writeObject (value_)
out.close()
}
// Called by JVM when deserializing an object
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
DfsBroadcast.synchronized {
val cachedVal = DfsBroadcast.values.get(uuid, 0)
if (cachedVal != null) {
value_ = cachedVal.asInstanceOf[T]
} else {
logInfo( "Started reading Broadcasted variable " + uuid)
val start = System.nanoTime
val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid))
value_ = fileIn.readObject.asInstanceOf[T]
DfsBroadcast.values.put(uuid, 0, value_)
fileIn.close()
val time = (System.nanoTime - start) / 1e9
logInfo( "Reading Broadcasted variable " + uuid + " took " + time + " s")
}
}
}
}
class DfsBroadcastFactory
extends BroadcastFactory {
def initialize (isMaster: Boolean) {
DfsBroadcast.initialize
}
def newBroadcast[T] (value_ : T, isLocal: Boolean) =
new DfsBroadcast[T] (value_, isLocal)
}
private object DfsBroadcast
extends Logging {
val values = SparkEnv.get.cache.newKeySpace()
private var initialized = false
private var fileSystem: FileSystem = null
private var workDir: String = null
private var compress: Boolean = false
private var bufferSize: Int = 65536
def initialize() {
synchronized {
if (!initialized) {
bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
val dfs = System.getProperty("spark.dfs", "file:///")
if (!dfs.startsWith("file://")) {
val conf = new Configuration()
conf.setInt("io.file.buffer.size", bufferSize)
val rep = System.getProperty("spark.dfs.replication", "3").toInt
conf.setInt("dfs.replication", rep)
fileSystem = FileSystem.get(new URI(dfs), conf)
}
workDir = System.getProperty("spark.dfs.workDir", "/tmp")
compress = System.getProperty("spark.compress", "false").toBoolean
initialized = true
}
}
}
private def getPath(uuid: UUID) = new Path(workDir + "/broadcast-" + uuid)
def openFileForReading(uuid: UUID): InputStream = {
val fileStream = if (fileSystem != null) {
fileSystem.open(getPath(uuid))
} else {
// Local filesystem
new FileInputStream(getPath(uuid).toString)
}
if (compress) {
// LZF stream does its own buffering
new LZFInputStream(fileStream)
} else if (fileSystem == null) {
new BufferedInputStream(fileStream, bufferSize)
} else {
// Hadoop streams do their own buffering
fileStream
}
}
def openFileForWriting(uuid: UUID): OutputStream = {
val fileStream = if (fileSystem != null) {
fileSystem.create(getPath(uuid))
} else {
// Local filesystem
new FileOutputStream(getPath(uuid).toString)
}
if (compress) {
// LZF stream does its own buffering
new LZFOutputStream(fileStream)
} else if (fileSystem == null) {
new BufferedOutputStream(fileStream, bufferSize)
} else {
// Hadoop streams do their own buffering
fileStream
}
}
}

View file

@ -10,14 +10,15 @@ import it.unimi.dsi.fastutil.io.FastBufferedInputStream
import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
import spark._
import spark.storage.StorageLevel
class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean)
extends Broadcast[T] with Logging with Serializable {
def value = value_
HttpBroadcast.synchronized {
HttpBroadcast.values.put(uuid, 0, value_)
HttpBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(uuid.toString, value_, StorageLevel.MEMORY_ONLY, false)
}
if (!isLocal) {
@ -28,31 +29,28 @@ extends Broadcast[T] with Logging with Serializable {
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
HttpBroadcast.synchronized {
val cachedVal = HttpBroadcast.values.get(uuid, 0)
if (cachedVal != null) {
value_ = cachedVal.asInstanceOf[T]
} else {
logInfo("Started reading broadcast variable " + uuid)
val start = System.nanoTime
value_ = HttpBroadcast.read[T](uuid)
HttpBroadcast.values.put(uuid, 0, value_)
val time = (System.nanoTime - start) / 1e9
logInfo("Reading broadcast variable " + uuid + " took " + time + " s")
SparkEnv.get.blockManager.getSingle(uuid.toString) match {
case Some(x) => value_ = x.asInstanceOf[T]
case None => {
logInfo("Started reading broadcast variable " + uuid)
val start = System.nanoTime
value_ = HttpBroadcast.read[T](uuid)
SparkEnv.get.blockManager.putSingle(uuid.toString, value_, StorageLevel.MEMORY_ONLY, false)
val time = (System.nanoTime - start) / 1e9
logInfo("Reading broadcast variable " + uuid + " took " + time + " s")
}
}
}
}
}
class HttpBroadcastFactory extends BroadcastFactory {
def initialize(isMaster: Boolean) {
HttpBroadcast.initialize(isMaster)
}
def initialize(isMaster: Boolean) = HttpBroadcast.initialize(isMaster)
def newBroadcast[T](value_ : T, isLocal: Boolean) = new HttpBroadcast[T](value_, isLocal)
def stop() = HttpBroadcast.stop()
}
private object HttpBroadcast extends Logging {
val values = SparkEnv.get.cache.newKeySpace()
private var initialized = false
private var broadcastDir: File = null
@ -74,6 +72,12 @@ private object HttpBroadcast extends Logging {
}
}
}
def stop() {
if (server != null) {
server.stop()
}
}
private def createServer() {
broadcastDir = Utils.createTempDir()

View file

@ -0,0 +1,394 @@
package spark.broadcast
import java.io._
import java.net._
import java.util.{UUID, Random}
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
import scala.collection.mutable.Map
import spark._
private object MultiTracker
extends Logging {
// Tracker Messages
val REGISTER_BROADCAST_TRACKER = 0
val UNREGISTER_BROADCAST_TRACKER = 1
val FIND_BROADCAST_TRACKER = 2
// Map to keep track of guides of ongoing broadcasts
var valueToGuideMap = Map[UUID, SourceInfo]()
// Random number generator
var ranGen = new Random
private var initialized = false
private var isMaster_ = false
private var stopBroadcast = false
private var trackMV: TrackMultipleValues = null
def initialize(isMaster__ : Boolean) {
synchronized {
if (!initialized) {
isMaster_ = isMaster__
if (isMaster) {
trackMV = new TrackMultipleValues
trackMV.setDaemon(true)
trackMV.start()
// Set masterHostAddress to the master's IP address for the slaves to read
System.setProperty("spark.MultiTracker.MasterHostAddress", Utils.localIpAddress)
}
initialized = true
}
}
}
def stop() {
stopBroadcast = true
}
// Load common parameters
private var MasterHostAddress_ = System.getProperty(
"spark.MultiTracker.MasterHostAddress", "")
private var MasterTrackerPort_ = System.getProperty(
"spark.broadcast.masterTrackerPort", "11111").toInt
private var BlockSize_ = System.getProperty(
"spark.broadcast.blockSize", "4096").toInt * 1024
private var MaxRetryCount_ = System.getProperty(
"spark.broadcast.maxRetryCount", "2").toInt
private var TrackerSocketTimeout_ = System.getProperty(
"spark.broadcast.trackerSocketTimeout", "50000").toInt
private var ServerSocketTimeout_ = System.getProperty(
"spark.broadcast.serverSocketTimeout", "10000").toInt
private var MinKnockInterval_ = System.getProperty(
"spark.broadcast.minKnockInterval", "500").toInt
private var MaxKnockInterval_ = System.getProperty(
"spark.broadcast.maxKnockInterval", "999").toInt
// Load TreeBroadcast config params
private var MaxDegree_ = System.getProperty(
"spark.broadcast.maxDegree", "2").toInt
// Load BitTorrentBroadcast config params
private var MaxPeersInGuideResponse_ = System.getProperty(
"spark.broadcast.maxPeersInGuideResponse", "4").toInt
private var MaxChatSlots_ = System.getProperty(
"spark.broadcast.maxChatSlots", "4").toInt
private var MaxChatTime_ = System.getProperty(
"spark.broadcast.maxChatTime", "500").toInt
private var MaxChatBlocks_ = System.getProperty(
"spark.broadcast.maxChatBlocks", "1024").toInt
private var EndGameFraction_ = System.getProperty(
"spark.broadcast.endGameFraction", "0.95").toDouble
def isMaster = isMaster_
// Common config params
def MasterHostAddress = MasterHostAddress_
def MasterTrackerPort = MasterTrackerPort_
def BlockSize = BlockSize_
def MaxRetryCount = MaxRetryCount_
def TrackerSocketTimeout = TrackerSocketTimeout_
def ServerSocketTimeout = ServerSocketTimeout_
def MinKnockInterval = MinKnockInterval_
def MaxKnockInterval = MaxKnockInterval_
// TreeBroadcast configs
def MaxDegree = MaxDegree_
// BitTorrentBroadcast configs
def MaxPeersInGuideResponse = MaxPeersInGuideResponse_
def MaxChatSlots = MaxChatSlots_
def MaxChatTime = MaxChatTime_
def MaxChatBlocks = MaxChatBlocks_
def EndGameFraction = EndGameFraction_
class TrackMultipleValues
extends Thread with Logging {
override def run() {
var threadPool = Utils.newDaemonCachedThreadPool()
var serverSocket: ServerSocket = null
serverSocket = new ServerSocket(MasterTrackerPort)
logInfo("TrackMultipleValues started at " + serverSocket)
try {
while (!stopBroadcast) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout(TrackerSocketTimeout)
clientSocket = serverSocket.accept()
} catch {
case e: Exception => {
if (stopBroadcast) {
logInfo("Stopping TrackMultipleValues...")
}
}
}
if (clientSocket != null) {
try {
threadPool.execute(new Thread {
override def run() {
val oos = new ObjectOutputStream(clientSocket.getOutputStream)
oos.flush()
val ois = new ObjectInputStream(clientSocket.getInputStream)
try {
// First, read message type
val messageType = ois.readObject.asInstanceOf[Int]
if (messageType == REGISTER_BROADCAST_TRACKER) {
// Receive UUID
val uuid = ois.readObject.asInstanceOf[UUID]
// Receive hostAddress and listenPort
val gInfo = ois.readObject.asInstanceOf[SourceInfo]
// Add to the map
valueToGuideMap.synchronized {
valueToGuideMap += (uuid -> gInfo)
}
logInfo ("New broadcast " + uuid + " registered with TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
// Send dummy ACK
oos.writeObject(-1)
oos.flush()
} else if (messageType == UNREGISTER_BROADCAST_TRACKER) {
// Receive UUID
val uuid = ois.readObject.asInstanceOf[UUID]
// Remove from the map
valueToGuideMap.synchronized {
valueToGuideMap(uuid) = SourceInfo("", SourceInfo.TxOverGoToDefault)
}
logInfo ("Broadcast " + uuid + " unregistered from TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
// Send dummy ACK
oos.writeObject(-1)
oos.flush()
} else if (messageType == FIND_BROADCAST_TRACKER) {
// Receive UUID
val uuid = ois.readObject.asInstanceOf[UUID]
var gInfo =
if (valueToGuideMap.contains(uuid)) valueToGuideMap(uuid)
else SourceInfo("", SourceInfo.TxNotStartedRetry)
logDebug("Got new request: " + clientSocket + " for " + uuid + " : " + gInfo.listenPort)
// Send reply back
oos.writeObject(gInfo)
oos.flush()
} else {
throw new SparkException("Undefined messageType at TrackMultipleValues")
}
} catch {
case e: Exception => {
logError("TrackMultipleValues had a " + e)
}
} finally {
ois.close()
oos.close()
clientSocket.close()
}
}
})
} catch {
// In failure, close socket here; else, client thread will close
case ioe: IOException => clientSocket.close()
}
}
}
} finally {
serverSocket.close()
}
// Shutdown the thread pool
threadPool.shutdown()
}
}
def getGuideInfo(variableUUID: UUID): SourceInfo = {
var clientSocketToTracker: Socket = null
var oosTracker: ObjectOutputStream = null
var oisTracker: ObjectInputStream = null
var gInfo: SourceInfo = SourceInfo("", SourceInfo.TxOverGoToDefault)
var retriesLeft = MultiTracker.MaxRetryCount
do {
try {
// Connect to the tracker to find out GuideInfo
clientSocketToTracker =
new Socket(MultiTracker.MasterHostAddress, MultiTracker.MasterTrackerPort)
oosTracker =
new ObjectOutputStream(clientSocketToTracker.getOutputStream)
oosTracker.flush()
oisTracker =
new ObjectInputStream(clientSocketToTracker.getInputStream)
// Send messageType/intention
oosTracker.writeObject(MultiTracker.FIND_BROADCAST_TRACKER)
oosTracker.flush()
// Send UUID and receive GuideInfo
oosTracker.writeObject(variableUUID)
oosTracker.flush()
gInfo = oisTracker.readObject.asInstanceOf[SourceInfo]
} catch {
case e: Exception => logError("getGuideInfo had a " + e)
} finally {
if (oisTracker != null) {
oisTracker.close()
}
if (oosTracker != null) {
oosTracker.close()
}
if (clientSocketToTracker != null) {
clientSocketToTracker.close()
}
}
Thread.sleep(MultiTracker.ranGen.nextInt(
MultiTracker.MaxKnockInterval - MultiTracker.MinKnockInterval) +
MultiTracker.MinKnockInterval)
retriesLeft -= 1
} while (retriesLeft > 0 && gInfo.listenPort == SourceInfo.TxNotStartedRetry)
logDebug("Got this guidePort from Tracker: " + gInfo.listenPort)
return gInfo
}
def registerBroadcast(uuid: UUID, gInfo: SourceInfo) {
val socket = new Socket(MultiTracker.MasterHostAddress, MasterTrackerPort)
val oosST = new ObjectOutputStream(socket.getOutputStream)
oosST.flush()
val oisST = new ObjectInputStream(socket.getInputStream)
// Send messageType/intention
oosST.writeObject(REGISTER_BROADCAST_TRACKER)
oosST.flush()
// Send UUID of this broadcast
oosST.writeObject(uuid)
oosST.flush()
// Send this tracker's information
oosST.writeObject(gInfo)
oosST.flush()
// Receive ACK and throw it away
oisST.readObject.asInstanceOf[Int]
// Shut stuff down
oisST.close()
oosST.close()
socket.close()
}
def unregisterBroadcast(uuid: UUID) {
val socket = new Socket(MultiTracker.MasterHostAddress, MasterTrackerPort)
val oosST = new ObjectOutputStream(socket.getOutputStream)
oosST.flush()
val oisST = new ObjectInputStream(socket.getInputStream)
// Send messageType/intention
oosST.writeObject(UNREGISTER_BROADCAST_TRACKER)
oosST.flush()
// Send UUID of this broadcast
oosST.writeObject(uuid)
oosST.flush()
// Receive ACK and throw it away
oisST.readObject.asInstanceOf[Int]
// Shut stuff down
oisST.close()
oosST.close()
socket.close()
}
// Helper method to convert an object to Array[BroadcastBlock]
def blockifyObject[IN](obj: IN): VariableInfo = {
val baos = new ByteArrayOutputStream
val oos = new ObjectOutputStream(baos)
oos.writeObject(obj)
oos.close()
baos.close()
val byteArray = baos.toByteArray
val bais = new ByteArrayInputStream(byteArray)
var blockNum = (byteArray.length / BlockSize)
if (byteArray.length % BlockSize != 0)
blockNum += 1
var retVal = new Array[BroadcastBlock](blockNum)
var blockID = 0
for (i <- 0 until (byteArray.length, BlockSize)) {
val thisBlockSize = math.min(BlockSize, byteArray.length - i)
var tempByteArray = new Array[Byte](thisBlockSize)
val hasRead = bais.read(tempByteArray, 0, thisBlockSize)
retVal(blockID) = new BroadcastBlock(blockID, tempByteArray)
blockID += 1
}
bais.close()
var variableInfo = VariableInfo(retVal, blockNum, byteArray.length)
variableInfo.hasBlocks = blockNum
return variableInfo
}
// Helper method to convert Array[BroadcastBlock] to object
def unBlockifyObject[OUT](arrayOfBlocks: Array[BroadcastBlock],
totalBytes: Int,
totalBlocks: Int): OUT = {
var retByteArray = new Array[Byte](totalBytes)
for (i <- 0 until totalBlocks) {
System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
i * BlockSize, arrayOfBlocks(i).byteArray.length)
}
byteArrayToObject(retByteArray)
}
private def byteArrayToObject[OUT](bytes: Array[Byte]): OUT = {
val in = new ObjectInputStream (new ByteArrayInputStream (bytes)){
override def resolveClass(desc: ObjectStreamClass) =
Class.forName(desc.getName, false, Thread.currentThread.getContextClassLoader)
}
val retVal = in.readObject.asInstanceOf[OUT]
in.close()
return retVal
}
}
case class BroadcastBlock(blockID: Int, byteArray: Array[Byte])
extends Serializable
case class VariableInfo(@transient arrayOfBlocks : Array[BroadcastBlock],
totalBlocks: Int,
totalBytes: Int)
extends Serializable {
@transient var hasBlocks = 0
}

View file

@ -6,15 +6,11 @@ import spark._
/**
* Used to keep and pass around information of peers involved in a broadcast
*
* CHANGED: Keep track of the blockSize for THIS broadcast variable.
* Broadcast.BlockSize is expected to be updated across different broadcasts
*/
case class SourceInfo (hostAddress: String,
listenPort: Int,
totalBlocks: Int = SourceInfo.UnusedParam,
totalBytes: Int = SourceInfo.UnusedParam,
blockSize: Int = Broadcast.BlockSize)
totalBytes: Int = SourceInfo.UnusedParam)
extends Comparable[SourceInfo] with Logging {
var currentLeechers = 0
@ -33,8 +29,8 @@ extends Comparable[SourceInfo] with Logging {
object SourceInfo {
// Constants for special values of listenPort
val TxNotStartedRetry = -1
val TxOverGoToHDFS = 0
val TxOverGoToDefault = 0
// Other constants
val StopBroadcast = -2
val UnusedParam = 0
}
}

View file

@ -8,22 +8,21 @@ import scala.collection.mutable.{ListBuffer, Map, Set}
import scala.math
import spark._
import spark.storage.StorageLevel
class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean)
extends Broadcast[T] with Logging with Serializable {
def value = value_
TreeBroadcast.synchronized {
TreeBroadcast.values.put(uuid, 0, value_)
MultiTracker.synchronized {
SparkEnv.get.blockManager.putSingle(uuid.toString, value_, StorageLevel.MEMORY_ONLY, false)
}
@transient var arrayOfBlocks: Array[BroadcastBlock] = null
@transient var totalBytes = -1
@transient var totalBlocks = -1
@transient var hasBlocks = 0
// CHANGED: BlockSize in the Broadcast object is expected to change over time
@transient var blockSize = Broadcast.BlockSize
@transient var listenPortLock = new Object
@transient var guidePortLock = new Object
@ -39,7 +38,6 @@ extends Broadcast[T] with Logging with Serializable {
@transient var listenPort = -1
@transient var guidePort = -1
@transient var hasCopyInHDFS = false
@transient var stopBroadcast = false
// Must call this after all the variables have been created/initialized
@ -50,19 +48,10 @@ extends Broadcast[T] with Logging with Serializable {
def sendBroadcast() {
logInfo("Local host address: " + hostAddress)
// Store a persistent copy in HDFS
// TODO: Turned OFF for now
// val out = new ObjectOutputStream(DfsBroadcast.openFileForWriting(uuid))
// out.writeObject(value_)
// out.close()
// TODO: Fix this at some point
hasCopyInHDFS = true
// Create a variableInfo object and store it in valueInfos
var variableInfo = Broadcast.blockifyObject(value_)
var variableInfo = MultiTracker.blockifyObject(value_)
// Prepare the value being broadcasted
// TODO: Refactoring and clean-up required here
arrayOfBlocks = variableInfo.arrayOfBlocks
totalBytes = variableInfo.totalBytes
totalBlocks = variableInfo.totalBlocks
@ -75,9 +64,7 @@ extends Broadcast[T] with Logging with Serializable {
// Must always come AFTER guideMR is created
while (guidePort == -1) {
guidePortLock.synchronized {
guidePortLock.wait()
}
guidePortLock.synchronized { guidePortLock.wait() }
}
serveMR = new ServeMultipleRequests
@ -87,63 +74,59 @@ extends Broadcast[T] with Logging with Serializable {
// Must always come AFTER serveMR is created
while (listenPort == -1) {
listenPortLock.synchronized {
listenPortLock.wait()
}
listenPortLock.synchronized { listenPortLock.wait() }
}
// Must always come AFTER listenPort is created
val masterSource =
SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes, blockSize)
SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes)
listOfSources += masterSource
// Register with the Tracker
TreeBroadcast.registerValue(uuid, guidePort)
MultiTracker.registerBroadcast(uuid,
SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes))
}
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
TreeBroadcast.synchronized {
val cachedVal = TreeBroadcast.values.get(uuid, 0)
if (cachedVal != null) {
value_ = cachedVal.asInstanceOf[T]
} else {
// Initializing everything because Master will only send null/0 values
initializeSlaveVariables
MultiTracker.synchronized {
SparkEnv.get.blockManager.getSingle(uuid.toString) match {
case Some(x) => x.asInstanceOf[T]
case None => {
logInfo("Started reading broadcast variable " + uuid)
// Initializing everything because Master will only send null/0 values
// Only the 1st worker in a node can be here. Others will get from cache
initializeWorkerVariables
logInfo("Local host address: " + hostAddress)
logInfo("Local host address: " + hostAddress)
serveMR = new ServeMultipleRequests
serveMR.setDaemon(true)
serveMR.start()
logInfo("ServeMultipleRequests started...")
serveMR = new ServeMultipleRequests
serveMR.setDaemon(true)
serveMR.start()
logInfo("ServeMultipleRequests started...")
val start = System.nanoTime
val start = System.nanoTime
val receptionSucceeded = receiveBroadcast(uuid)
// If does not succeed, then get from HDFS copy
if (receptionSucceeded) {
value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
TreeBroadcast.values.put(uuid, 0, value_)
} else {
val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid))
value_ = fileIn.readObject.asInstanceOf[T]
TreeBroadcast.values.put(uuid, 0, value_)
fileIn.close()
val receptionSucceeded = receiveBroadcast(uuid)
if (receptionSucceeded) {
value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
SparkEnv.get.blockManager.putSingle(uuid.toString, value_, StorageLevel.MEMORY_ONLY, false)
} else {
logError("Reading Broadcasted variable " + uuid + " failed")
}
val time = (System.nanoTime - start) / 1e9
logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s")
}
val time = (System.nanoTime - start) / 1e9
logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s")
}
}
}
private def initializeSlaveVariables() {
private def initializeWorkerVariables() {
arrayOfBlocks = null
totalBytes = -1
totalBlocks = -1
hasBlocks = 0
blockSize = -1
listenPortLock = new Object
totalBlocksLock = new Object
@ -157,72 +140,17 @@ extends Broadcast[T] with Logging with Serializable {
stopBroadcast = false
}
def getMasterListenPort(variableUUID: UUID): Int = {
var clientSocketToTracker: Socket = null
var oosTracker: ObjectOutputStream = null
var oisTracker: ObjectInputStream = null
var masterListenPort: Int = SourceInfo.TxOverGoToHDFS
var retriesLeft = Broadcast.MaxRetryCount
do {
try {
// Connect to the tracker to find out the guide
clientSocketToTracker =
new Socket(Broadcast.MasterHostAddress, Broadcast.MasterTrackerPort)
oosTracker =
new ObjectOutputStream(clientSocketToTracker.getOutputStream)
oosTracker.flush()
oisTracker =
new ObjectInputStream(clientSocketToTracker.getInputStream)
// Send UUID and receive masterListenPort
oosTracker.writeObject(uuid)
oosTracker.flush()
masterListenPort = oisTracker.readObject.asInstanceOf[Int]
} catch {
case e: Exception => {
logInfo("getMasterListenPort had a " + e)
}
} finally {
if (oisTracker != null) {
oisTracker.close()
}
if (oosTracker != null) {
oosTracker.close()
}
if (clientSocketToTracker != null) {
clientSocketToTracker.close()
}
}
retriesLeft -= 1
Thread.sleep(TreeBroadcast.ranGen.nextInt(
Broadcast.MaxKnockInterval - Broadcast.MinKnockInterval) +
Broadcast.MinKnockInterval)
} while (retriesLeft > 0 && masterListenPort == SourceInfo.TxNotStartedRetry)
logInfo("Got this guidePort from Tracker: " + masterListenPort)
return masterListenPort
}
def receiveBroadcast(variableUUID: UUID): Boolean = {
val masterListenPort = getMasterListenPort(variableUUID)
if (masterListenPort == SourceInfo.TxOverGoToHDFS ||
masterListenPort == SourceInfo.TxNotStartedRetry) {
// TODO: SourceInfo.TxNotStartedRetry is not really in use because we go
// to HDFS anyway when receiveBroadcast returns false
val gInfo = MultiTracker.getGuideInfo(variableUUID)
if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) {
return false
}
// Wait until hostAddress and listenPort are created by the
// ServeMultipleRequests thread
while (listenPort == -1) {
listenPortLock.synchronized {
listenPortLock.wait()
}
listenPortLock.synchronized { listenPortLock.wait() }
}
var clientSocketToMaster: Socket = null
@ -231,19 +159,15 @@ extends Broadcast[T] with Logging with Serializable {
// Connect and receive broadcast from the specified source, retrying the
// specified number of times in case of failures
var retriesLeft = Broadcast.MaxRetryCount
var retriesLeft = MultiTracker.MaxRetryCount
do {
// Connect to Master and send this worker's Information
clientSocketToMaster =
new Socket(Broadcast.MasterHostAddress, masterListenPort)
// TODO: Guiding object connection is reusable
oosMaster =
new ObjectOutputStream(clientSocketToMaster.getOutputStream)
clientSocketToMaster = new Socket(MultiTracker.MasterHostAddress, gInfo.listenPort)
oosMaster = new ObjectOutputStream(clientSocketToMaster.getOutputStream)
oosMaster.flush()
oisMaster =
new ObjectInputStream(clientSocketToMaster.getInputStream)
oisMaster = new ObjectInputStream(clientSocketToMaster.getInputStream)
logInfo("Connected to Master's guiding object")
logDebug("Connected to Master's guiding object")
// Send local source information
oosMaster.writeObject(SourceInfo(hostAddress, listenPort))
@ -253,13 +177,10 @@ extends Broadcast[T] with Logging with Serializable {
var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo]
totalBlocks = sourceInfo.totalBlocks
arrayOfBlocks = new Array[BroadcastBlock](totalBlocks)
totalBlocksLock.synchronized {
totalBlocksLock.notifyAll()
}
totalBlocksLock.synchronized { totalBlocksLock.notifyAll() }
totalBytes = sourceInfo.totalBytes
blockSize = sourceInfo.blockSize
logInfo("Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort)
logDebug("Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort)
val start = System.nanoTime
val receptionSucceeded = receiveSingleTransmission(sourceInfo)
@ -289,8 +210,10 @@ extends Broadcast[T] with Logging with Serializable {
return (hasBlocks == totalBlocks)
}
// Tries to receive broadcast from the source and returns Boolean status.
// This might be called multiple times to retry a defined number of times.
/**
* Tries to receive broadcast from the source and returns Boolean status.
* This might be called multiple times to retry a defined number of times.
*/
private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = {
var clientSocketToSource: Socket = null
var oosSource: ObjectOutputStream = null
@ -299,16 +222,13 @@ extends Broadcast[T] with Logging with Serializable {
var receptionSucceeded = false
try {
// Connect to the source to get the object itself
clientSocketToSource =
new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
oosSource =
new ObjectOutputStream(clientSocketToSource.getOutputStream)
clientSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
oosSource = new ObjectOutputStream(clientSocketToSource.getOutputStream)
oosSource.flush()
oisSource =
new ObjectInputStream(clientSocketToSource.getInputStream)
oisSource = new ObjectInputStream(clientSocketToSource.getInputStream)
logInfo("Inside receiveSingleTransmission")
logInfo("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks)
logDebug("Inside receiveSingleTransmission")
logDebug("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks)
// Send the range
oosSource.writeObject((hasBlocks, totalBlocks))
@ -319,20 +239,17 @@ extends Broadcast[T] with Logging with Serializable {
val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
val receptionTime = (System.currentTimeMillis - recvStartTime)
logInfo("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.")
logDebug("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.")
arrayOfBlocks(hasBlocks) = bcBlock
hasBlocks += 1
// Set to true if at least one block is received
receptionSucceeded = true
hasBlocksLock.synchronized {
hasBlocksLock.notifyAll()
}
hasBlocksLock.synchronized { hasBlocksLock.notifyAll() }
}
} catch {
case e: Exception => {
logInfo("receiveSingleTransmission had a " + e)
}
case e: Exception => logError("receiveSingleTransmission had a " + e)
} finally {
if (oisSource != null) {
oisSource.close()
@ -361,24 +278,22 @@ extends Broadcast[T] with Logging with Serializable {
guidePort = serverSocket.getLocalPort
logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort)
guidePortLock.synchronized {
guidePortLock.notifyAll()
}
guidePortLock.synchronized { guidePortLock.notifyAll() }
try {
// Don't stop until there is a copy in HDFS
while (!stopBroadcast || !hasCopyInHDFS) {
while (!stopBroadcast) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout)
serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
case e: Exception => {
logInfo("GuideMultipleRequests Timeout.")
logError("GuideMultipleRequests Timeout.")
// Stop broadcast if at least one worker has connected and
// everyone connected so far are done. Comparing with
// listOfSources.size - 1, because it includes the Guide itself
// everyone connected so far are done.
// Comparing with listOfSources.size - 1, because the Guide itself
// is included
if (listOfSources.size > 1 &&
setOfCompletedSources.size == listOfSources.size - 1) {
stopBroadcast = true
@ -386,7 +301,7 @@ extends Broadcast[T] with Logging with Serializable {
}
}
if (clientSocket != null) {
logInfo("Guide: Accepted new client connection: " + clientSocket)
logDebug("Guide: Accepted new client connection: " + clientSocket)
try {
threadPool.execute(new GuideSingleRequest(clientSocket))
} catch {
@ -399,14 +314,13 @@ extends Broadcast[T] with Logging with Serializable {
logInfo("Sending stopBroadcast notifications...")
sendStopBroadcastNotifications
TreeBroadcast.unregisterValue(uuid)
MultiTracker.unregisterBroadcast(uuid)
} finally {
if (serverSocket != null) {
logInfo("GuideMultipleRequests now stopping...")
serverSocket.close()
}
}
// Shutdown the thread pool
threadPool.shutdown()
}
@ -423,21 +337,17 @@ extends Broadcast[T] with Logging with Serializable {
try {
// Connect to the source
guideSocketToSource =
new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
gosSource =
new ObjectOutputStream(guideSocketToSource.getOutputStream)
guideSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
gosSource = new ObjectOutputStream(guideSocketToSource.getOutputStream)
gosSource.flush()
gisSource =
new ObjectInputStream(guideSocketToSource.getInputStream)
gisSource = new ObjectInputStream(guideSocketToSource.getInputStream)
// Send stopBroadcast signal. Range = SourceInfo.StopBroadcast*2
gosSource.writeObject((SourceInfo.StopBroadcast,
SourceInfo.StopBroadcast))
// Send stopBroadcast signal
gosSource.writeObject((SourceInfo.StopBroadcast, SourceInfo.StopBroadcast))
gosSource.flush()
} catch {
case e: Exception => {
logInfo("sendStopBroadcastNotifications had a " + e)
logError("sendStopBroadcastNotifications had a " + e)
}
} finally {
if (gisSource != null) {
@ -473,14 +383,14 @@ extends Broadcast[T] with Logging with Serializable {
listOfSources.synchronized {
// Select a suitable source and send it back to the worker
selectedSourceInfo = selectSuitableSource(sourceInfo)
logInfo("Sending selectedSourceInfo: " + selectedSourceInfo)
logDebug("Sending selectedSourceInfo: " + selectedSourceInfo)
oos.writeObject(selectedSourceInfo)
oos.flush()
// Add this new (if it can finish) source to the list of sources
thisWorkerInfo = SourceInfo(sourceInfo.hostAddress,
sourceInfo.listenPort, totalBlocks, totalBytes, blockSize)
logInfo("Adding possible new source to listOfSources: " + thisWorkerInfo)
sourceInfo.listenPort, totalBlocks, totalBytes)
logDebug("Adding possible new source to listOfSources: " + thisWorkerInfo)
listOfSources += thisWorkerInfo
}
@ -492,9 +402,9 @@ extends Broadcast[T] with Logging with Serializable {
// This should work since SourceInfo is a case class
assert(listOfSources.contains(selectedSourceInfo))
// Remove first
// Remove first
// (Currently removing a source based on just one failure notification!)
listOfSources = listOfSources - selectedSourceInfo
// TODO: Removing a source based on just one failure notification!
// Update sourceInfo and put it back in, IF reception succeeded
if (!sourceInfo.receptionFailed) {
@ -503,17 +413,13 @@ extends Broadcast[T] with Logging with Serializable {
setOfCompletedSources += thisWorkerInfo
}
// Update leecher count and put it back in
selectedSourceInfo.currentLeechers -= 1
// Put it back
listOfSources += selectedSourceInfo
}
}
} catch {
// If something went wrong, e.g., the worker at the other end died etc.
// then close() everything up
case e: Exception => {
// Assuming that exception caused due to receiver worker failure.
// Remove failed worker from listOfSources and update leecherCount of
// corresponding source worker
listOfSources.synchronized {
@ -532,27 +438,23 @@ extends Broadcast[T] with Logging with Serializable {
}
}
} finally {
logInfo("GuideSingleRequest is closing streams and sockets")
ois.close()
oos.close()
clientSocket.close()
}
}
// FIXME: Caller must have a synchronized block on listOfSources
// FIXME: If a worker fails to get the broadcasted variable from a source
// and comes back to the Master, this function might choose the worker
// itself as a source to create a dependency cycle (this worker was put
// into listOfSources as a streming source when it first arrived). The
// length of this cycle can be arbitrarily long.
// Assuming the caller to have a synchronized block on listOfSources
// Select one with the most leechers. This will level-wise fill the tree
private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = {
// Select one with the most leechers. This will level-wise fill the tree
var maxLeechers = -1
var selectedSource: SourceInfo = null
listOfSources.foreach { source =>
if (source != skipSourceInfo &&
source.currentLeechers < Broadcast.MaxDegree &&
if ((source.hostAddress != skipSourceInfo.hostAddress ||
source.listenPort != skipSourceInfo.listenPort) &&
source.currentLeechers < MultiTracker.MaxDegree &&
source.currentLeechers > maxLeechers) {
selectedSource = source
maxLeechers = source.currentLeechers
@ -561,7 +463,6 @@ extends Broadcast[T] with Logging with Serializable {
// Update leecher count
selectedSource.currentLeechers += 1
return selectedSource
}
}
@ -569,35 +470,33 @@ extends Broadcast[T] with Logging with Serializable {
class ServeMultipleRequests
extends Thread with Logging {
override def run() {
var threadPool = Utils.newDaemonCachedThreadPool()
var serverSocket: ServerSocket = null
serverSocket = new ServerSocket(0)
var threadPool = Utils.newDaemonCachedThreadPool()
override def run() {
var serverSocket = new ServerSocket(0)
listenPort = serverSocket.getLocalPort
logInfo("ServeMultipleRequests started with " + serverSocket)
listenPortLock.synchronized {
listenPortLock.notifyAll()
}
listenPortLock.synchronized { listenPortLock.notifyAll() }
try {
while (!stopBroadcast) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout)
serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
case e: Exception => {
logInfo("ServeMultipleRequests Timeout.")
}
case e: Exception => logError("ServeMultipleRequests Timeout.")
}
if (clientSocket != null) {
logInfo("Serve: Accepted new client connection: " + clientSocket)
logDebug("Serve: Accepted new client connection: " + clientSocket)
try {
threadPool.execute(new ServeSingleRequest(clientSocket))
} catch {
// In failure, close() socket here; else, the thread will close() it
// In failure, close socket here; else, the thread will close it
case ioe: IOException => clientSocket.close()
}
}
@ -608,7 +507,6 @@ extends Broadcast[T] with Logging with Serializable {
serverSocket.close()
}
}
// Shutdown the thread pool
threadPool.shutdown()
}
@ -631,19 +529,14 @@ extends Broadcast[T] with Logging with Serializable {
sendFrom = rangeToSend._1
sendUntil = rangeToSend._2
if (sendFrom == SourceInfo.StopBroadcast &&
sendUntil == SourceInfo.StopBroadcast) {
// If not a valid range, stop broadcast
if (sendFrom == SourceInfo.StopBroadcast && sendUntil == SourceInfo.StopBroadcast) {
stopBroadcast = true
} else {
// Carry on
sendObject
}
} catch {
// If something went wrong, e.g., the worker at the other end died etc.
// then close() everything up
case e: Exception => {
logInfo("ServeSingleRequest had a " + e)
}
case e: Exception => logError("ServeSingleRequest had a " + e)
} finally {
logInfo("ServeSingleRequest is closing streams and sockets")
ois.close()
@ -655,26 +548,20 @@ extends Broadcast[T] with Logging with Serializable {
private def sendObject() {
// Wait till receiving the SourceInfo from Master
while (totalBlocks == -1) {
totalBlocksLock.synchronized {
totalBlocksLock.wait()
}
totalBlocksLock.synchronized { totalBlocksLock.wait() }
}
for (i <- sendFrom until sendUntil) {
while (i == hasBlocks) {
hasBlocksLock.synchronized {
hasBlocksLock.wait()
}
hasBlocksLock.synchronized { hasBlocksLock.wait() }
}
try {
oos.writeObject(arrayOfBlocks(i))
oos.flush()
} catch {
case e: Exception => {
logInfo("sendObject had a " + e)
}
case e: Exception => logError("sendObject had a " + e)
}
logInfo("Sent block: " + i + " to " + clientSocket)
logDebug("Sent block: " + i + " to " + clientSocket)
}
}
}
@ -683,124 +570,7 @@ extends Broadcast[T] with Logging with Serializable {
class TreeBroadcastFactory
extends BroadcastFactory {
def initialize(isMaster: Boolean) = TreeBroadcast.initialize(isMaster)
def newBroadcast[T](value_ : T, isLocal: Boolean) =
new TreeBroadcast[T](value_, isLocal)
}
private object TreeBroadcast
extends Logging {
val values = SparkEnv.get.cache.newKeySpace()
var valueToGuidePortMap = Map[UUID, Int]()
// Random number generator
var ranGen = new Random
private var initialized = false
private var isMaster_ = false
private var trackMV: TrackMultipleValues = null
private var MaxDegree_ : Int = 2
def initialize(isMaster__ : Boolean) {
synchronized {
if (!initialized) {
isMaster_ = isMaster__
if (isMaster) {
trackMV = new TrackMultipleValues
trackMV.setDaemon(true)
trackMV.start()
// TODO: Logging the following line makes the Spark framework ID not
// getting logged, cause it calls logInfo before log4j is initialized
logInfo("TrackMultipleValues started...")
}
// Initialize DfsBroadcast to be used for broadcast variable persistence
DfsBroadcast.initialize
initialized = true
}
}
}
def isMaster = isMaster_
def registerValue(uuid: UUID, guidePort: Int) {
valueToGuidePortMap.synchronized {
valueToGuidePortMap += (uuid -> guidePort)
logInfo("New value registered with the Tracker " + valueToGuidePortMap)
}
}
def unregisterValue(uuid: UUID) {
valueToGuidePortMap.synchronized {
valueToGuidePortMap(uuid) = SourceInfo.TxOverGoToHDFS
logInfo("Value unregistered from the Tracker " + valueToGuidePortMap)
}
}
class TrackMultipleValues
extends Thread with Logging {
override def run() {
var threadPool = Utils.newDaemonCachedThreadPool()
var serverSocket: ServerSocket = null
serverSocket = new ServerSocket(Broadcast.MasterTrackerPort)
logInfo("TrackMultipleValues" + serverSocket)
try {
while (true) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout(Broadcast.TrackerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
case e: Exception => {
logInfo("TrackMultipleValues Timeout. Stopping listening...")
}
}
if (clientSocket != null) {
try {
threadPool.execute(new Thread {
override def run() {
val oos = new ObjectOutputStream(clientSocket.getOutputStream)
oos.flush()
val ois = new ObjectInputStream(clientSocket.getInputStream)
try {
val uuid = ois.readObject.asInstanceOf[UUID]
var guidePort =
if (valueToGuidePortMap.contains(uuid)) {
valueToGuidePortMap(uuid)
} else SourceInfo.TxNotStartedRetry
logInfo("TrackMultipleValues: Got new request: " + clientSocket + " for " + uuid + " : " + guidePort)
oos.writeObject(guidePort)
} catch {
case e: Exception => {
logInfo("TrackMultipleValues had a " + e)
}
} finally {
ois.close()
oos.close()
clientSocket.close()
}
}
})
} catch {
// In failure, close() socket here; else, client thread will close()
case ioe: IOException => clientSocket.close()
}
}
}
} finally {
serverSocket.close()
}
// Shutdown the thread pool
threadPool.shutdown()
}
}
def initialize(isMaster: Boolean) = MultiTracker.initialize(isMaster)
def newBroadcast[T](value_ : T, isLocal: Boolean) = new TreeBroadcast[T](value_, isLocal)
def stop() = MultiTracker.stop
}

View file

@ -35,8 +35,6 @@ class Executor extends Logging {
// Initialize Spark environment (using system properties read above)
env = SparkEnv.createFromSystemProperties(slaveHostname, 0, false, false)
SparkEnv.set(env)
// Old stuff that isn't yet using env
Broadcast.initialize(false)
// Create our ClassLoader (using spark properties) and set it on this thread
classLoader = createClassLoader()