From ca02a923327166ddacfccf26003543c15bc1c04c Mon Sep 17 00:00:00 2001 From: Mosharaf Chowdhury Date: Mon, 9 Jul 2012 21:35:39 -0700 Subject: [PATCH] Refactored TrackMultipleValues out. --- .../spark/broadcast/BitTorrentBroadcast.scala | 226 ++---------------- .../scala/spark/broadcast/Broadcast.scala | 182 +++++++++++++- .../scala/spark/broadcast/TreeBroadcast.scala | 130 +--------- 3 files changed, 210 insertions(+), 328 deletions(-) diff --git a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala index 5fca5a46d0..aab3a15587 100644 --- a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala @@ -2,7 +2,7 @@ package spark.broadcast import java.io._ import java.net._ -import java.util.{BitSet, Comparator, Random, Timer, TimerTask, UUID} +import java.util.{BitSet, Comparator, Timer, TimerTask, UUID} import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.{ListBuffer, Map, Set} @@ -15,8 +15,8 @@ extends Broadcast[T] with Logging with Serializable { def value = value_ - BitTorrentBroadcast.synchronized { - BitTorrentBroadcast.values.put(uuid, 0, value_) + Broadcast.synchronized { + Broadcast.values.put(uuid, 0, value_) } @transient var arrayOfBlocks: Array[BroadcastBlock] = null @@ -109,14 +109,14 @@ extends Broadcast[T] with Logging with Serializable { listOfSources += masterSource // Register with the Tracker - registerBroadcast(uuid, + Broadcast.registerBroadcast(uuid, SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes)) } private def readObject(in: ObjectInputStream) { in.defaultReadObject() - BitTorrentBroadcast.synchronized { - val cachedVal = BitTorrentBroadcast.values.get(uuid, 0) + Broadcast.synchronized { + val cachedVal = Broadcast.values.get(uuid, 0) if (cachedVal != null) { value_ = cachedVal.asInstanceOf[T] @@ -137,7 +137,7 @@ extends Broadcast[T] with Logging with Serializable { val receptionSucceeded = receiveBroadcast(uuid) if (receptionSucceeded) { value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) - BitTorrentBroadcast.values.put(uuid, 0, value_) + Broadcast.values.put(uuid, 0, value_) } else { logError("Reading Broadcasted variable " + uuid + " failed") } @@ -171,58 +171,6 @@ extends Broadcast[T] with Logging with Serializable { stopBroadcast = false } - private def registerBroadcast(uuid: UUID, gInfo: SourceInfo) { - val socket = new Socket(Broadcast.MasterHostAddress, - Broadcast.MasterTrackerPort) - val oosST = new ObjectOutputStream(socket.getOutputStream) - oosST.flush() - val oisST = new ObjectInputStream(socket.getInputStream) - - // Send messageType/intention - oosST.writeObject(Broadcast.REGISTER_BROADCAST_TRACKER) - oosST.flush() - - // Send UUID of this broadcast - oosST.writeObject(uuid) - oosST.flush() - - // Send this tracker's information - oosST.writeObject(gInfo) - oosST.flush() - - // Receive ACK and throw it away - oisST.readObject.asInstanceOf[Int] - - // Shut stuff down - oisST.close() - oosST.close() - socket.close() - } - - private def unregisterBroadcast(uuid: UUID) { - val socket = new Socket(Broadcast.MasterHostAddress, - Broadcast.MasterTrackerPort) - val oosST = new ObjectOutputStream(socket.getOutputStream) - oosST.flush() - val oisST = new ObjectInputStream(socket.getInputStream) - - // Send messageType/intention - oosST.writeObject(Broadcast.UNREGISTER_BROADCAST_TRACKER) - oosST.flush() - - // Send UUID of this broadcast - oosST.writeObject(uuid) - oosST.flush() - - // Receive ACK and throw it away - oisST.readObject.asInstanceOf[Int] - - // Shut stuff down - oisST.close() - oosST.close() - socket.close() - } - private def getLocalSourceInfo: SourceInfo = { // Wait till hostName and listenPort are OK while (listenPort == -1) { @@ -274,7 +222,7 @@ extends Broadcast[T] with Logging with Serializable { // Keep exchaning information until all blocks have been received while (hasBlocks.get < totalBlocks) { talkOnce - Thread.sleep(BitTorrentBroadcast.ranGen.nextInt( + Thread.sleep(Broadcast.ranGen.nextInt( Broadcast.MaxKnockInterval - Broadcast.MinKnockInterval) + Broadcast.MinKnockInterval) } @@ -354,7 +302,7 @@ extends Broadcast[T] with Logging with Serializable { } } - Thread.sleep(BitTorrentBroadcast.ranGen.nextInt( + Thread.sleep(Broadcast.ranGen.nextInt( Broadcast.MaxKnockInterval - Broadcast.MinKnockInterval) + Broadcast.MinKnockInterval) @@ -492,7 +440,7 @@ extends Broadcast[T] with Logging with Serializable { // Now always picking randomly if (curPeer == null && peersNotInUse.size > 0) { // Pick uniformly the i'th required peer - var i = BitTorrentBroadcast.ranGen.nextInt(peersNotInUse.size) + var i = Broadcast.ranGen.nextInt(peersNotInUse.size) var peerIter = peersNotInUse.iterator curPeer = peerIter.next @@ -563,7 +511,7 @@ extends Broadcast[T] with Logging with Serializable { // Sort the peers based on how many rare blocks they have peersWithRareBlocks.sortBy(_._2) - var randomNumber = BitTorrentBroadcast.ranGen.nextDouble + var randomNumber = Broadcast.ranGen.nextDouble var tempSum = 0.0 var i = 0 @@ -732,7 +680,7 @@ extends Broadcast[T] with Logging with Serializable { return -1 } else { // Pick uniformly the i'th required block - var i = BitTorrentBroadcast.ranGen.nextInt(needBlocksBitVector.cardinality) + var i = Broadcast.ranGen.nextInt(needBlocksBitVector.cardinality) var pickedBlockIndex = needBlocksBitVector.nextSetBit(0) while (i > 0) { @@ -804,7 +752,7 @@ extends Broadcast[T] with Logging with Serializable { return -1 } else { // Pick uniformly the i'th index - var i = BitTorrentBroadcast.ranGen.nextInt(minBlocksIndices.size) + var i = Broadcast.ranGen.nextInt(minBlocksIndices.size) return minBlocksIndices(i) } } @@ -885,7 +833,7 @@ extends Broadcast[T] with Logging with Serializable { logInfo("Sending stopBroadcast notifications...") sendStopBroadcastNotifications - unregisterBroadcast(uuid) + Broadcast.unregisterBroadcast(uuid) } finally { if (serverSocket != null) { logInfo("GuideMultipleRequests now stopping...") @@ -1000,7 +948,7 @@ extends Broadcast[T] with Logging with Serializable { var i = -1 do { - i = BitTorrentBroadcast.ranGen.nextInt(listOfSources.size) + i = Broadcast.ranGen.nextInt(listOfSources.size) } while (alreadyPicked.get(i)) var peerIter = listOfSources.iterator @@ -1114,7 +1062,7 @@ extends Broadcast[T] with Logging with Serializable { // If it is master AND at least one copy of each block has not been // sent out already, MODIFY blockToSend - if (BitTorrentBroadcast.isMaster && sentBlocks.get < totalBlocks) { + if (Broadcast.isMaster && sentBlocks.get < totalBlocks) { blockToSend = sentBlocks.getAndIncrement } @@ -1170,150 +1118,10 @@ extends Broadcast[T] with Logging with Serializable { class BitTorrentBroadcastFactory extends BroadcastFactory { def initialize(isMaster: Boolean) { - BitTorrentBroadcast.initialize(isMaster) + // BitTorrentBroadcast.initialize(isMaster) } def newBroadcast[T](value_ : T, isLocal: Boolean) = { new BitTorrentBroadcast[T](value_, isLocal) } } - -private object BitTorrentBroadcast -extends Logging { - val values = SparkEnv.get.cache.newKeySpace() - - var valueToGuideMap = Map[UUID, SourceInfo]() - - // Random number generator - var ranGen = new Random - - private var initialized = false - private var isMaster_ = false - - private var trackMV: TrackMultipleValues = null - - def initialize(isMaster__ : Boolean) { - synchronized { - if (!initialized) { - isMaster_ = isMaster__ - if (isMaster) { - trackMV = new TrackMultipleValues - trackMV.setDaemon(true) - trackMV.start() - } - initialized = true - } - } - } - - def isMaster = isMaster_ - - class TrackMultipleValues - extends Thread with Logging { - override def run() { - var threadPool = Utils.newDaemonCachedThreadPool() - var serverSocket: ServerSocket = null - - serverSocket = new ServerSocket(Broadcast.MasterTrackerPort) - logInfo("TrackMultipleValues" + serverSocket) - - try { - while (true) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(Broadcast.TrackerSocketTimeout) - clientSocket = serverSocket.accept() - } catch { - case e: Exception => { - logInfo("TrackMultipleValues Timeout. Stopping listening...") - } - } - - if (clientSocket != null) { - try { - threadPool.execute(new Thread { - override def run() { - val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - val ois = new ObjectInputStream(clientSocket.getInputStream) - - try { - // First, read message type - val messageType = ois.readObject.asInstanceOf[Int] - - if (messageType == Broadcast.REGISTER_BROADCAST_TRACKER) { - // Receive UUID - val uuid = ois.readObject.asInstanceOf[UUID] - // Receive hostAddress and listenPort - val gInfo = ois.readObject.asInstanceOf[SourceInfo] - - // Add to the map - valueToGuideMap.synchronized { - valueToGuideMap += (uuid -> gInfo) - } - - logInfo ("New broadcast registered with TrackMultipleValues " + uuid + " " + valueToGuideMap) - - // Send dummy ACK - oos.writeObject(-1) - oos.flush() - } else if (messageType == Broadcast.UNREGISTER_BROADCAST_TRACKER) { - // Receive UUID - val uuid = ois.readObject.asInstanceOf[UUID] - - // Remove from the map - valueToGuideMap.synchronized { - valueToGuideMap(uuid) = SourceInfo("", SourceInfo.TxOverGoToDefault) - logInfo("Value unregistered from the Tracker " + valueToGuideMap) - } - - logInfo ("Broadcast unregistered from TrackMultipleValues " + uuid + " " + valueToGuideMap) - - // Send dummy ACK - oos.writeObject(-1) - oos.flush() - } else if (messageType == Broadcast.FIND_BROADCAST_TRACKER) { - // Receive UUID - val uuid = ois.readObject.asInstanceOf[UUID] - - var gInfo = - if (valueToGuideMap.contains(uuid)) valueToGuideMap(uuid) - else SourceInfo("", SourceInfo.TxNotStartedRetry) - - logInfo("TrackMultipleValues: Got new request: " + clientSocket + " for " + uuid + " : " + gInfo.listenPort) - - // Send reply back - oos.writeObject(gInfo) - oos.flush() - } else if (messageType == Broadcast.GET_UPDATED_SHARE) { - // TODO: Not implemented - } else { - throw new SparkException("Undefined messageType at TrackMultipleValues") - } - } catch { - case e: Exception => { - logInfo("TrackMultipleValues had a " + e) - } - } finally { - ois.close() - oos.close() - clientSocket.close() - } - } - }) - } catch { - // In failure, close socket here; else, client thread will close - case ioe: IOException => { - clientSocket.close() - } - } - } - } - } finally { - serverSocket.close() - } - // Shutdown the thread pool - threadPool.shutdown() - } - } -} diff --git a/core/src/main/scala/spark/broadcast/Broadcast.scala b/core/src/main/scala/spark/broadcast/Broadcast.scala index 182c0851bc..72bfc35b74 100644 --- a/core/src/main/scala/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/spark/broadcast/Broadcast.scala @@ -2,9 +2,11 @@ package spark.broadcast import java.io._ import java.net._ -import java.util.{BitSet, UUID} +import java.util.{BitSet, UUID, Random} import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor} +import scala.collection.mutable.Map + import spark._ trait Broadcast[T] extends Serializable { @@ -30,6 +32,18 @@ object Broadcast extends Logging with Serializable { private var isMaster_ = false private var broadcastFactory: BroadcastFactory = null + // Cache of broadcasted objects + val values = SparkEnv.get.cache.newKeySpace() + + // Map to keep track of guides of ongoing broadcasts + var valueToGuideMap = Map[UUID, SourceInfo]() + + // Random number generator + var ranGen = new Random + + // Tracker object + private var trackMV: TrackMultipleValues = null + // Called by SparkContext or Executor before using Broadcast def initialize(isMaster__ : Boolean) { synchronized { @@ -46,6 +60,11 @@ object Broadcast extends Logging with Serializable { // Set masterHostAddress to the master's IP address for the slaves to read if (isMaster) { System.setProperty("spark.broadcast.masterHostAddress", Utils.localIpAddress) + + // Start the tracker + trackMV = new TrackMultipleValues + trackMV.setDaemon(true) + trackMV.start() } // Initialize appropriate BroadcastFactory and BroadcastObject @@ -127,6 +146,167 @@ object Broadcast extends Logging with Serializable { def EndGameFraction = EndGameFraction_ + class TrackMultipleValues + extends Thread with Logging { + override def run() { + var threadPool = Utils.newDaemonCachedThreadPool() + var serverSocket: ServerSocket = null + + serverSocket = new ServerSocket(Broadcast.MasterTrackerPort) + logInfo("TrackMultipleValues" + serverSocket) + + try { + while (true) { + var clientSocket: Socket = null + try { + serverSocket.setSoTimeout(Broadcast.TrackerSocketTimeout) + clientSocket = serverSocket.accept() + } catch { + case e: Exception => { + logInfo("TrackMultipleValues Timeout. Stopping listening...") + } + } + + if (clientSocket != null) { + try { + threadPool.execute(new Thread { + override def run() { + val oos = new ObjectOutputStream(clientSocket.getOutputStream) + oos.flush() + val ois = new ObjectInputStream(clientSocket.getInputStream) + + try { + // First, read message type + val messageType = ois.readObject.asInstanceOf[Int] + + if (messageType == Broadcast.REGISTER_BROADCAST_TRACKER) { + // Receive UUID + val uuid = ois.readObject.asInstanceOf[UUID] + // Receive hostAddress and listenPort + val gInfo = ois.readObject.asInstanceOf[SourceInfo] + + // Add to the map + valueToGuideMap.synchronized { + valueToGuideMap += (uuid -> gInfo) + } + + logInfo ("New broadcast registered with TrackMultipleValues " + uuid + " " + valueToGuideMap) + + // Send dummy ACK + oos.writeObject(-1) + oos.flush() + } else if (messageType == Broadcast.UNREGISTER_BROADCAST_TRACKER) { + // Receive UUID + val uuid = ois.readObject.asInstanceOf[UUID] + + // Remove from the map + valueToGuideMap.synchronized { + valueToGuideMap(uuid) = SourceInfo("", SourceInfo.TxOverGoToDefault) + logInfo("Value unregistered from the Tracker " + valueToGuideMap) + } + + logInfo ("Broadcast unregistered from TrackMultipleValues " + uuid + " " + valueToGuideMap) + + // Send dummy ACK + oos.writeObject(-1) + oos.flush() + } else if (messageType == Broadcast.FIND_BROADCAST_TRACKER) { + // Receive UUID + val uuid = ois.readObject.asInstanceOf[UUID] + + var gInfo = + if (valueToGuideMap.contains(uuid)) valueToGuideMap(uuid) + else SourceInfo("", SourceInfo.TxNotStartedRetry) + + logInfo("TrackMultipleValues: Got new request: " + clientSocket + " for " + uuid + " : " + gInfo.listenPort) + + // Send reply back + oos.writeObject(gInfo) + oos.flush() + } else if (messageType == Broadcast.GET_UPDATED_SHARE) { + // TODO: Not implemented + } else { + throw new SparkException("Undefined messageType at TrackMultipleValues") + } + } catch { + case e: Exception => { + logInfo("TrackMultipleValues had a " + e) + } + } finally { + ois.close() + oos.close() + clientSocket.close() + } + } + }) + } catch { + // In failure, close socket here; else, client thread will close + case ioe: IOException => { + clientSocket.close() + } + } + } + } + } finally { + serverSocket.close() + } + // Shutdown the thread pool + threadPool.shutdown() + } + } + + def registerBroadcast(uuid: UUID, gInfo: SourceInfo) { + val socket = new Socket(Broadcast.MasterHostAddress, + Broadcast.MasterTrackerPort) + val oosST = new ObjectOutputStream(socket.getOutputStream) + oosST.flush() + val oisST = new ObjectInputStream(socket.getInputStream) + + // Send messageType/intention + oosST.writeObject(Broadcast.REGISTER_BROADCAST_TRACKER) + oosST.flush() + + // Send UUID of this broadcast + oosST.writeObject(uuid) + oosST.flush() + + // Send this tracker's information + oosST.writeObject(gInfo) + oosST.flush() + + // Receive ACK and throw it away + oisST.readObject.asInstanceOf[Int] + + // Shut stuff down + oisST.close() + oosST.close() + socket.close() + } + + def unregisterBroadcast(uuid: UUID) { + val socket = new Socket(Broadcast.MasterHostAddress, + Broadcast.MasterTrackerPort) + val oosST = new ObjectOutputStream(socket.getOutputStream) + oosST.flush() + val oisST = new ObjectInputStream(socket.getInputStream) + + // Send messageType/intention + oosST.writeObject(Broadcast.UNREGISTER_BROADCAST_TRACKER) + oosST.flush() + + // Send UUID of this broadcast + oosST.writeObject(uuid) + oosST.flush() + + // Receive ACK and throw it away + oisST.readObject.asInstanceOf[Int] + + // Shut stuff down + oisST.close() + oosST.close() + socket.close() + } + // Helper method to convert an object to Array[BroadcastBlock] def blockifyObject[IN](obj: IN): VariableInfo = { val baos = new ByteArrayOutputStream diff --git a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala index 4bb363a15e..758c3b0e01 100644 --- a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala @@ -14,8 +14,8 @@ extends Broadcast[T] with Logging with Serializable { def value = value_ - TreeBroadcast.synchronized { - TreeBroadcast.values.put(uuid, 0, value_) + Broadcast.synchronized { + Broadcast.values.put(uuid, 0, value_) } @transient var arrayOfBlocks: Array[BroadcastBlock] = null @@ -87,13 +87,14 @@ extends Broadcast[T] with Logging with Serializable { listOfSources += masterSource // Register with the Tracker - TreeBroadcast.registerValue(uuid, guidePort) + Broadcast.registerBroadcast(uuid, + SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes)) } private def readObject(in: ObjectInputStream) { in.defaultReadObject() - TreeBroadcast.synchronized { - val cachedVal = TreeBroadcast.values.get(uuid, 0) + Broadcast.synchronized { + val cachedVal = Broadcast.values.get(uuid, 0) if (cachedVal != null) { value_ = cachedVal.asInstanceOf[T] } else { @@ -112,7 +113,7 @@ extends Broadcast[T] with Logging with Serializable { val receptionSucceeded = receiveBroadcast(uuid) if (receptionSucceeded) { value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) - TreeBroadcast.values.put(uuid, 0, value_) + Broadcast.values.put(uuid, 0, value_) } else { logError("Reading Broadcasted variable " + uuid + " failed") } @@ -181,7 +182,7 @@ extends Broadcast[T] with Logging with Serializable { } retriesLeft -= 1 - Thread.sleep(TreeBroadcast.ranGen.nextInt( + Thread.sleep(Broadcast.ranGen.nextInt( Broadcast.MaxKnockInterval - Broadcast.MinKnockInterval) + Broadcast.MinKnockInterval) @@ -382,7 +383,7 @@ extends Broadcast[T] with Logging with Serializable { logInfo("Sending stopBroadcast notifications...") sendStopBroadcastNotifications - TreeBroadcast.unregisterValue(uuid) + Broadcast.unregisterBroadcast(uuid) } finally { if (serverSocket != null) { logInfo("GuideMultipleRequests now stopping...") @@ -666,116 +667,9 @@ extends Broadcast[T] with Logging with Serializable { class TreeBroadcastFactory extends BroadcastFactory { - def initialize(isMaster: Boolean) = TreeBroadcast.initialize(isMaster) + def initialize(isMaster: Boolean) { + // TreeBroadcast.initialize(isMaster) + } def newBroadcast[T](value_ : T, isLocal: Boolean) = new TreeBroadcast[T](value_, isLocal) } - -private object TreeBroadcast -extends Logging { - val values = SparkEnv.get.cache.newKeySpace() - - var valueToGuidePortMap = Map[UUID, Int]() - - // Random number generator - var ranGen = new Random - - private var initialized = false - private var isMaster_ = false - - private var trackMV: TrackMultipleValues = null - - private var MaxDegree_ : Int = 2 - - def initialize(isMaster__ : Boolean) { - synchronized { - if (!initialized) { - isMaster_ = isMaster__ - if (isMaster) { - trackMV = new TrackMultipleValues - trackMV.setDaemon(true) - trackMV.start() - } - initialized = true - } - } - } - - def isMaster = isMaster_ - - def registerValue(uuid: UUID, guidePort: Int) { - valueToGuidePortMap.synchronized { - valueToGuidePortMap += (uuid -> guidePort) - logInfo("New value registered with the Tracker " + valueToGuidePortMap) - } - } - - def unregisterValue(uuid: UUID) { - valueToGuidePortMap.synchronized { - valueToGuidePortMap(uuid) = SourceInfo.TxOverGoToDefault - logInfo("Value unregistered from the Tracker " + valueToGuidePortMap) - } - } - - class TrackMultipleValues - extends Thread with Logging { - override def run() { - var threadPool = Utils.newDaemonCachedThreadPool() - var serverSocket: ServerSocket = null - - serverSocket = new ServerSocket(Broadcast.MasterTrackerPort) - logInfo("TrackMultipleValues" + serverSocket) - - try { - while (true) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(Broadcast.TrackerSocketTimeout) - clientSocket = serverSocket.accept - } catch { - case e: Exception => { - logInfo("TrackMultipleValues Timeout. Stopping listening...") - } - } - - if (clientSocket != null) { - try { - threadPool.execute(new Thread { - override def run() { - val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - val ois = new ObjectInputStream(clientSocket.getInputStream) - try { - val uuid = ois.readObject.asInstanceOf[UUID] - var guidePort = - if (valueToGuidePortMap.contains(uuid)) { - valueToGuidePortMap(uuid) - } else SourceInfo.TxNotStartedRetry - logInfo("TrackMultipleValues: Got new request: " + clientSocket + " for " + uuid + " : " + guidePort) - oos.writeObject(guidePort) - } catch { - case e: Exception => { - logInfo("TrackMultipleValues had a " + e) - } - } finally { - ois.close() - oos.close() - clientSocket.close() - } - } - }) - } catch { - // In failure, close() socket here; else, client thread will close() - case ioe: IOException => clientSocket.close() - } - } - } - } finally { - serverSocket.close() - } - - // Shutdown the thread pool - threadPool.shutdown() - } - } -}