Merge branch 'mos-bt'
This merge keeps only the broadcast work in mos-bt because the structure of shuffle has changed with the new RDD design. We still need some kind of parallel shuffle but that will be added later. Conflicts: core/src/main/scala/spark/BitTorrentBroadcast.scala core/src/main/scala/spark/ChainedBroadcast.scala core/src/main/scala/spark/RDD.scala core/src/main/scala/spark/SparkContext.scala core/src/main/scala/spark/Utils.scala core/src/main/scala/spark/shuffle/BasicLocalFileShuffle.scala core/src/main/scala/spark/shuffle/DfsShuffle.scala
This commit is contained in:
commit
c4dd68ae21
File diff suppressed because it is too large
Load diff
|
@ -1,140 +0,0 @@
|
|||
package spark
|
||||
|
||||
import java.util.{BitSet, UUID}
|
||||
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
|
||||
|
||||
@serializable
|
||||
trait Broadcast[T] {
|
||||
val uuid = UUID.randomUUID
|
||||
|
||||
def value: T
|
||||
|
||||
// We cannot have an abstract readObject here due to some weird issues with
|
||||
// readObject having to be 'private' in sub-classes. Possibly a Scala bug!
|
||||
|
||||
override def toString = "spark.Broadcast(" + uuid + ")"
|
||||
}
|
||||
|
||||
trait BroadcastFactory {
|
||||
def initialize (isMaster: Boolean): Unit
|
||||
def newBroadcast[T] (value_ : T, isLocal: Boolean): Broadcast[T]
|
||||
}
|
||||
|
||||
private object Broadcast
|
||||
extends Logging {
|
||||
private var initialized = false
|
||||
private var broadcastFactory: BroadcastFactory = null
|
||||
|
||||
// Called by SparkContext or Executor before using Broadcast
|
||||
def initialize (isMaster: Boolean): Unit = synchronized {
|
||||
if (!initialized) {
|
||||
val broadcastFactoryClass = System.getProperty("spark.broadcast.factory",
|
||||
"spark.DfsBroadcastFactory")
|
||||
val booleanArgs = Array[AnyRef] (isMaster.asInstanceOf[AnyRef])
|
||||
|
||||
broadcastFactory =
|
||||
Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
|
||||
|
||||
// Initialize appropriate BroadcastFactory and BroadcastObject
|
||||
broadcastFactory.initialize(isMaster)
|
||||
|
||||
initialized = true
|
||||
}
|
||||
}
|
||||
|
||||
def getBroadcastFactory: BroadcastFactory = {
|
||||
if (broadcastFactory == null) {
|
||||
throw new SparkException ("Broadcast.getBroadcastFactory called before initialize")
|
||||
}
|
||||
broadcastFactory
|
||||
}
|
||||
|
||||
// Returns a standard ThreadFactory except all threads are daemons
|
||||
private def newDaemonThreadFactory: ThreadFactory = {
|
||||
new ThreadFactory {
|
||||
def newThread(r: Runnable): Thread = {
|
||||
var t = Executors.defaultThreadFactory.newThread (r)
|
||||
t.setDaemon (true)
|
||||
return t
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Wrapper over newCachedThreadPool
|
||||
def newDaemonCachedThreadPool: ThreadPoolExecutor = {
|
||||
var threadPool =
|
||||
Executors.newCachedThreadPool.asInstanceOf[ThreadPoolExecutor]
|
||||
|
||||
threadPool.setThreadFactory (newDaemonThreadFactory)
|
||||
|
||||
return threadPool
|
||||
}
|
||||
|
||||
// Wrapper over newFixedThreadPool
|
||||
def newDaemonFixedThreadPool (nThreads: Int): ThreadPoolExecutor = {
|
||||
var threadPool =
|
||||
Executors.newFixedThreadPool (nThreads).asInstanceOf[ThreadPoolExecutor]
|
||||
|
||||
threadPool.setThreadFactory (newDaemonThreadFactory)
|
||||
|
||||
return threadPool
|
||||
}
|
||||
}
|
||||
|
||||
@serializable
|
||||
case class SourceInfo (val hostAddress: String, val listenPort: Int,
|
||||
val totalBlocks: Int, val totalBytes: Int)
|
||||
extends Comparable[SourceInfo] with Logging {
|
||||
|
||||
var currentLeechers = 0
|
||||
var receptionFailed = false
|
||||
|
||||
var hasBlocks = 0
|
||||
var hasBlocksBitVector: BitSet = new BitSet (totalBlocks)
|
||||
|
||||
// Ascending sort based on leecher count
|
||||
def compareTo (o: SourceInfo): Int = (currentLeechers - o.currentLeechers)}
|
||||
|
||||
object SourceInfo {
|
||||
// Constants for special values of listenPort
|
||||
val TxNotStartedRetry = -1
|
||||
val TxOverGoToHDFS = 0
|
||||
// Other constants
|
||||
val StopBroadcast = -2
|
||||
val UnusedParam = 0
|
||||
}
|
||||
|
||||
@serializable
|
||||
case class BroadcastBlock (val blockID: Int, val byteArray: Array[Byte]) { }
|
||||
|
||||
@serializable
|
||||
case class VariableInfo (@transient val arrayOfBlocks : Array[BroadcastBlock],
|
||||
val totalBlocks: Int, val totalBytes: Int) {
|
||||
@transient var hasBlocks = 0
|
||||
}
|
||||
|
||||
@serializable
|
||||
class SpeedTracker {
|
||||
// Mapping 'source' to '(totalTime, numBlocks)'
|
||||
private var sourceToSpeedMap = Map[SourceInfo, (Long, Int)] ()
|
||||
|
||||
def addDataPoint (srcInfo: SourceInfo, timeInMillis: Long): Unit = {
|
||||
sourceToSpeedMap.synchronized {
|
||||
if (!sourceToSpeedMap.contains(srcInfo)) {
|
||||
sourceToSpeedMap += (srcInfo -> (timeInMillis, 1))
|
||||
} else {
|
||||
val tTnB = sourceToSpeedMap (srcInfo)
|
||||
sourceToSpeedMap += (srcInfo -> (tTnB._1 + timeInMillis, tTnB._2 + 1))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def getTimePerBlock (srcInfo: SourceInfo): Double = {
|
||||
sourceToSpeedMap.synchronized {
|
||||
val tTnB = sourceToSpeedMap (srcInfo)
|
||||
return tTnB._1 / tTnB._2
|
||||
}
|
||||
}
|
||||
|
||||
override def toString = sourceToSpeedMap.toString
|
||||
}
|
|
@ -1,873 +0,0 @@
|
|||
package spark
|
||||
|
||||
import java.io._
|
||||
import java.net._
|
||||
import java.util.{Comparator, PriorityQueue, Random, UUID}
|
||||
|
||||
import scala.collection.mutable.{Map, Set}
|
||||
|
||||
@serializable
|
||||
class ChainedBroadcast[T] (@transient var value_ : T, isLocal: Boolean)
|
||||
extends Broadcast[T] with Logging {
|
||||
|
||||
def value = value_
|
||||
|
||||
ChainedBroadcast.synchronized {
|
||||
ChainedBroadcast.values.put (uuid, value_)
|
||||
}
|
||||
|
||||
@transient var arrayOfBlocks: Array[BroadcastBlock] = null
|
||||
@transient var totalBytes = -1
|
||||
@transient var totalBlocks = -1
|
||||
@transient var hasBlocks = 0
|
||||
|
||||
@transient var listenPortLock = new Object
|
||||
@transient var guidePortLock = new Object
|
||||
@transient var totalBlocksLock = new Object
|
||||
@transient var hasBlocksLock = new Object
|
||||
|
||||
@transient var pqOfSources = new PriorityQueue[SourceInfo]
|
||||
|
||||
@transient var serveMR: ServeMultipleRequests = null
|
||||
@transient var guideMR: GuideMultipleRequests = null
|
||||
|
||||
@transient var hostAddress = InetAddress.getLocalHost.getHostAddress
|
||||
@transient var listenPort = -1
|
||||
@transient var guidePort = -1
|
||||
|
||||
@transient var hasCopyInHDFS = false
|
||||
@transient var stopBroadcast = false
|
||||
|
||||
// Must call this after all the variables have been created/initialized
|
||||
if (!isLocal) {
|
||||
sendBroadcast
|
||||
}
|
||||
|
||||
def sendBroadcast (): Unit = {
|
||||
logInfo ("Local host address: " + hostAddress)
|
||||
|
||||
// Store a persistent copy in HDFS
|
||||
// TODO: Turned OFF for now
|
||||
// val out = new ObjectOutputStream (DfsBroadcast.openFileForWriting(uuid))
|
||||
// out.writeObject (value_)
|
||||
// out.close
|
||||
// TODO: Fix this at some point
|
||||
hasCopyInHDFS = true
|
||||
|
||||
// Create a variableInfo object and store it in valueInfos
|
||||
var variableInfo = blockifyObject (value_, ChainedBroadcast.BlockSize)
|
||||
|
||||
guideMR = new GuideMultipleRequests
|
||||
guideMR.setDaemon (true)
|
||||
guideMR.start
|
||||
logInfo ("GuideMultipleRequests started...")
|
||||
|
||||
serveMR = new ServeMultipleRequests
|
||||
serveMR.setDaemon (true)
|
||||
serveMR.start
|
||||
logInfo ("ServeMultipleRequests started...")
|
||||
|
||||
// Prepare the value being broadcasted
|
||||
// TODO: Refactoring and clean-up required here
|
||||
arrayOfBlocks = variableInfo.arrayOfBlocks
|
||||
totalBytes = variableInfo.totalBytes
|
||||
totalBlocks = variableInfo.totalBlocks
|
||||
hasBlocks = variableInfo.totalBlocks
|
||||
|
||||
while (listenPort == -1) {
|
||||
listenPortLock.synchronized {
|
||||
listenPortLock.wait
|
||||
}
|
||||
}
|
||||
|
||||
pqOfSources = new PriorityQueue[SourceInfo]
|
||||
val masterSource_0 =
|
||||
SourceInfo (hostAddress, listenPort, totalBlocks, totalBytes)
|
||||
pqOfSources.add (masterSource_0)
|
||||
|
||||
// Register with the Tracker
|
||||
while (guidePort == -1) {
|
||||
guidePortLock.synchronized {
|
||||
guidePortLock.wait
|
||||
}
|
||||
}
|
||||
ChainedBroadcast.registerValue (uuid, guidePort)
|
||||
}
|
||||
|
||||
private def readObject (in: ObjectInputStream): Unit = {
|
||||
in.defaultReadObject
|
||||
ChainedBroadcast.synchronized {
|
||||
val cachedVal = ChainedBroadcast.values.get (uuid)
|
||||
if (cachedVal != null) {
|
||||
value_ = cachedVal.asInstanceOf[T]
|
||||
} else {
|
||||
// Initializing everything because Master will only send null/0 values
|
||||
initializeSlaveVariables
|
||||
|
||||
logInfo ("Local host address: " + hostAddress)
|
||||
|
||||
serveMR = new ServeMultipleRequests
|
||||
serveMR.setDaemon (true)
|
||||
serveMR.start
|
||||
logInfo ("ServeMultipleRequests started...")
|
||||
|
||||
val start = System.nanoTime
|
||||
|
||||
val receptionSucceeded = receiveBroadcast (uuid)
|
||||
// If does not succeed, then get from HDFS copy
|
||||
if (receptionSucceeded) {
|
||||
value_ = unBlockifyObject[T]
|
||||
ChainedBroadcast.values.put (uuid, value_)
|
||||
} else {
|
||||
val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid))
|
||||
value_ = fileIn.readObject.asInstanceOf[T]
|
||||
ChainedBroadcast.values.put(uuid, value_)
|
||||
fileIn.close
|
||||
}
|
||||
|
||||
val time = (System.nanoTime - start) / 1e9
|
||||
logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private def initializeSlaveVariables: Unit = {
|
||||
arrayOfBlocks = null
|
||||
totalBytes = -1
|
||||
totalBlocks = -1
|
||||
hasBlocks = 0
|
||||
|
||||
listenPortLock = new Object
|
||||
totalBlocksLock = new Object
|
||||
hasBlocksLock = new Object
|
||||
|
||||
serveMR = null
|
||||
|
||||
hostAddress = InetAddress.getLocalHost.getHostAddress
|
||||
listenPort = -1
|
||||
|
||||
stopBroadcast = false
|
||||
}
|
||||
|
||||
private def blockifyObject (obj: T, blockSize: Int): VariableInfo = {
|
||||
val baos = new ByteArrayOutputStream
|
||||
val oos = new ObjectOutputStream (baos)
|
||||
oos.writeObject (obj)
|
||||
oos.close
|
||||
baos.close
|
||||
val byteArray = baos.toByteArray
|
||||
val bais = new ByteArrayInputStream (byteArray)
|
||||
|
||||
var blockNum = (byteArray.length / blockSize)
|
||||
if (byteArray.length % blockSize != 0)
|
||||
blockNum += 1
|
||||
|
||||
var retVal = new Array[BroadcastBlock] (blockNum)
|
||||
var blockID = 0
|
||||
|
||||
for (i <- 0 until (byteArray.length, blockSize)) {
|
||||
val thisBlockSize = math.min (blockSize, byteArray.length - i)
|
||||
var tempByteArray = new Array[Byte] (thisBlockSize)
|
||||
val hasRead = bais.read (tempByteArray, 0, thisBlockSize)
|
||||
|
||||
retVal (blockID) = new BroadcastBlock (blockID, tempByteArray)
|
||||
blockID += 1
|
||||
}
|
||||
bais.close
|
||||
|
||||
var variableInfo = VariableInfo (retVal, blockNum, byteArray.length)
|
||||
variableInfo.hasBlocks = blockNum
|
||||
|
||||
return variableInfo
|
||||
}
|
||||
|
||||
private def unBlockifyObject[A]: A = {
|
||||
var retByteArray = new Array[Byte] (totalBytes)
|
||||
for (i <- 0 until totalBlocks) {
|
||||
System.arraycopy (arrayOfBlocks(i).byteArray, 0, retByteArray,
|
||||
i * ChainedBroadcast.BlockSize, arrayOfBlocks(i).byteArray.length)
|
||||
}
|
||||
byteArrayToObject (retByteArray)
|
||||
}
|
||||
|
||||
private def byteArrayToObject[A] (bytes: Array[Byte]): A = {
|
||||
val in = new ObjectInputStream (new ByteArrayInputStream (bytes))
|
||||
val retVal = in.readObject.asInstanceOf[A]
|
||||
in.close
|
||||
return retVal
|
||||
}
|
||||
|
||||
def getMasterListenPort (variableUUID: UUID): Int = {
|
||||
var clientSocketToTracker: Socket = null
|
||||
var oosTracker: ObjectOutputStream = null
|
||||
var oisTracker: ObjectInputStream = null
|
||||
|
||||
var masterListenPort: Int = SourceInfo.TxOverGoToHDFS
|
||||
|
||||
var retriesLeft = ChainedBroadcast.MaxRetryCount
|
||||
do {
|
||||
try {
|
||||
// Connect to the tracker to find out the guide
|
||||
val clientSocketToTracker =
|
||||
new Socket(ChainedBroadcast.MasterHostAddress, ChainedBroadcast.MasterTrackerPort)
|
||||
val oosTracker =
|
||||
new ObjectOutputStream (clientSocketToTracker.getOutputStream)
|
||||
oosTracker.flush
|
||||
val oisTracker =
|
||||
new ObjectInputStream (clientSocketToTracker.getInputStream)
|
||||
|
||||
// Send UUID and receive masterListenPort
|
||||
oosTracker.writeObject (uuid)
|
||||
oosTracker.flush
|
||||
masterListenPort = oisTracker.readObject.asInstanceOf[Int]
|
||||
} catch {
|
||||
case e: Exception => {
|
||||
logInfo ("getMasterListenPort had a " + e)
|
||||
}
|
||||
} finally {
|
||||
if (oisTracker != null) {
|
||||
oisTracker.close
|
||||
}
|
||||
if (oosTracker != null) {
|
||||
oosTracker.close
|
||||
}
|
||||
if (clientSocketToTracker != null) {
|
||||
clientSocketToTracker.close
|
||||
}
|
||||
}
|
||||
retriesLeft -= 1
|
||||
|
||||
Thread.sleep (ChainedBroadcast.ranGen.nextInt (
|
||||
ChainedBroadcast.MaxKnockInterval - ChainedBroadcast.MinKnockInterval) +
|
||||
ChainedBroadcast.MinKnockInterval)
|
||||
|
||||
} while (retriesLeft > 0 && masterListenPort == SourceInfo.TxNotStartedRetry)
|
||||
|
||||
logInfo ("Got this guidePort from Tracker: " + masterListenPort)
|
||||
return masterListenPort
|
||||
}
|
||||
|
||||
def receiveBroadcast (variableUUID: UUID): Boolean = {
|
||||
val masterListenPort = getMasterListenPort (variableUUID)
|
||||
|
||||
if (masterListenPort == SourceInfo.TxOverGoToHDFS ||
|
||||
masterListenPort == SourceInfo.TxNotStartedRetry) {
|
||||
// TODO: SourceInfo.TxNotStartedRetry is not really in use because we go
|
||||
// to HDFS anyway when receiveBroadcast returns false
|
||||
return false
|
||||
}
|
||||
|
||||
// Wait until hostAddress and listenPort are created by the
|
||||
// ServeMultipleRequests thread
|
||||
while (listenPort == -1) {
|
||||
listenPortLock.synchronized {
|
||||
listenPortLock.wait
|
||||
}
|
||||
}
|
||||
|
||||
var clientSocketToMaster: Socket = null
|
||||
var oosMaster: ObjectOutputStream = null
|
||||
var oisMaster: ObjectInputStream = null
|
||||
|
||||
// Connect and receive broadcast from the specified source, retrying the
|
||||
// specified number of times in case of failures
|
||||
var retriesLeft = ChainedBroadcast.MaxRetryCount
|
||||
do {
|
||||
// Connect to Master and send this worker's Information
|
||||
clientSocketToMaster =
|
||||
new Socket(ChainedBroadcast.MasterHostAddress, masterListenPort)
|
||||
// TODO: Guiding object connection is reusable
|
||||
oosMaster =
|
||||
new ObjectOutputStream (clientSocketToMaster.getOutputStream)
|
||||
oosMaster.flush
|
||||
oisMaster =
|
||||
new ObjectInputStream (clientSocketToMaster.getInputStream)
|
||||
|
||||
logInfo ("Connected to Master's guiding object")
|
||||
|
||||
// Send local source information
|
||||
oosMaster.writeObject(SourceInfo (hostAddress, listenPort,
|
||||
SourceInfo.UnusedParam, SourceInfo.UnusedParam))
|
||||
oosMaster.flush
|
||||
|
||||
// Receive source information from Master
|
||||
var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo]
|
||||
totalBlocks = sourceInfo.totalBlocks
|
||||
arrayOfBlocks = new Array[BroadcastBlock] (totalBlocks)
|
||||
totalBlocksLock.synchronized {
|
||||
totalBlocksLock.notifyAll
|
||||
}
|
||||
totalBytes = sourceInfo.totalBytes
|
||||
|
||||
logInfo ("Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort)
|
||||
|
||||
val start = System.nanoTime
|
||||
val receptionSucceeded = receiveSingleTransmission (sourceInfo)
|
||||
val time = (System.nanoTime - start) / 1e9
|
||||
|
||||
// Updating some statistics in sourceInfo. Master will be using them later
|
||||
if (!receptionSucceeded) {
|
||||
sourceInfo.receptionFailed = true
|
||||
}
|
||||
|
||||
// Send back statistics to the Master
|
||||
oosMaster.writeObject (sourceInfo)
|
||||
|
||||
if (oisMaster != null) {
|
||||
oisMaster.close
|
||||
}
|
||||
if (oosMaster != null) {
|
||||
oosMaster.close
|
||||
}
|
||||
if (clientSocketToMaster != null) {
|
||||
clientSocketToMaster.close
|
||||
}
|
||||
|
||||
retriesLeft -= 1
|
||||
} while (retriesLeft > 0 && hasBlocks < totalBlocks)
|
||||
|
||||
return (hasBlocks == totalBlocks)
|
||||
}
|
||||
|
||||
// Tries to receive broadcast from the source and returns Boolean status.
|
||||
// This might be called multiple times to retry a defined number of times.
|
||||
private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = {
|
||||
var clientSocketToSource: Socket = null
|
||||
var oosSource: ObjectOutputStream = null
|
||||
var oisSource: ObjectInputStream = null
|
||||
|
||||
var receptionSucceeded = false
|
||||
try {
|
||||
// Connect to the source to get the object itself
|
||||
clientSocketToSource =
|
||||
new Socket (sourceInfo.hostAddress, sourceInfo.listenPort)
|
||||
oosSource =
|
||||
new ObjectOutputStream (clientSocketToSource.getOutputStream)
|
||||
oosSource.flush
|
||||
oisSource =
|
||||
new ObjectInputStream (clientSocketToSource.getInputStream)
|
||||
|
||||
logInfo ("Inside receiveSingleTransmission")
|
||||
logInfo ("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks)
|
||||
|
||||
// Send the range
|
||||
oosSource.writeObject((hasBlocks, totalBlocks))
|
||||
oosSource.flush
|
||||
|
||||
for (i <- hasBlocks until totalBlocks) {
|
||||
val recvStartTime = System.currentTimeMillis
|
||||
val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
|
||||
val receptionTime = (System.currentTimeMillis - recvStartTime)
|
||||
|
||||
logInfo ("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.")
|
||||
|
||||
arrayOfBlocks(hasBlocks) = bcBlock
|
||||
hasBlocks += 1
|
||||
// Set to true if at least one block is received
|
||||
receptionSucceeded = true
|
||||
hasBlocksLock.synchronized {
|
||||
hasBlocksLock.notifyAll
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
case e: Exception => {
|
||||
logInfo ("receiveSingleTransmission had a " + e)
|
||||
}
|
||||
} finally {
|
||||
if (oisSource != null) {
|
||||
oisSource.close
|
||||
}
|
||||
if (oosSource != null) {
|
||||
oosSource.close
|
||||
}
|
||||
if (clientSocketToSource != null) {
|
||||
clientSocketToSource.close
|
||||
}
|
||||
}
|
||||
|
||||
return receptionSucceeded
|
||||
}
|
||||
|
||||
class GuideMultipleRequests
|
||||
extends Thread with Logging {
|
||||
// Keep track of sources that have completed reception
|
||||
private var setOfCompletedSources = Set[SourceInfo] ()
|
||||
|
||||
override def run: Unit = {
|
||||
var threadPool = Broadcast.newDaemonCachedThreadPool
|
||||
var serverSocket: ServerSocket = null
|
||||
|
||||
serverSocket = new ServerSocket (0)
|
||||
guidePort = serverSocket.getLocalPort
|
||||
logInfo ("GuideMultipleRequests => " + serverSocket + " " + guidePort)
|
||||
|
||||
guidePortLock.synchronized {
|
||||
guidePortLock.notifyAll
|
||||
}
|
||||
|
||||
try {
|
||||
// Don't stop until there is a copy in HDFS
|
||||
while (!stopBroadcast || !hasCopyInHDFS) {
|
||||
var clientSocket: Socket = null
|
||||
try {
|
||||
serverSocket.setSoTimeout (ChainedBroadcast.ServerSocketTimeout)
|
||||
clientSocket = serverSocket.accept
|
||||
} catch {
|
||||
case e: Exception => {
|
||||
logInfo ("GuideMultipleRequests Timeout.")
|
||||
|
||||
// Stop broadcast if at least one worker has connected and
|
||||
// everyone connected so far are done. Comparing with
|
||||
// pqOfSources.size - 1, because it includes the Guide itself
|
||||
if (pqOfSources.size > 1 &&
|
||||
setOfCompletedSources.size == pqOfSources.size - 1) {
|
||||
stopBroadcast = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if (clientSocket != null) {
|
||||
logInfo ("Guide: Accepted new client connection: " + clientSocket)
|
||||
try {
|
||||
threadPool.execute (new GuideSingleRequest (clientSocket))
|
||||
} catch {
|
||||
// In failure, close the socket here; else, the thread will close it
|
||||
case ioe: IOException => clientSocket.close
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logInfo ("Sending stopBroadcast notifications...")
|
||||
sendStopBroadcastNotifications
|
||||
|
||||
ChainedBroadcast.unregisterValue (uuid)
|
||||
} finally {
|
||||
if (serverSocket != null) {
|
||||
logInfo ("GuideMultipleRequests now stopping...")
|
||||
serverSocket.close
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown the thread pool
|
||||
threadPool.shutdown
|
||||
}
|
||||
|
||||
private def sendStopBroadcastNotifications: Unit = {
|
||||
pqOfSources.synchronized {
|
||||
var pqIter = pqOfSources.iterator
|
||||
while (pqIter.hasNext) {
|
||||
var sourceInfo = pqIter.next
|
||||
|
||||
var guideSocketToSource: Socket = null
|
||||
var gosSource: ObjectOutputStream = null
|
||||
var gisSource: ObjectInputStream = null
|
||||
|
||||
try {
|
||||
// Connect to the source
|
||||
guideSocketToSource =
|
||||
new Socket (sourceInfo.hostAddress, sourceInfo.listenPort)
|
||||
gosSource =
|
||||
new ObjectOutputStream (guideSocketToSource.getOutputStream)
|
||||
gosSource.flush
|
||||
gisSource =
|
||||
new ObjectInputStream (guideSocketToSource.getInputStream)
|
||||
|
||||
// Send stopBroadcast signal. Range = SourceInfo.StopBroadcast*2
|
||||
gosSource.writeObject ((SourceInfo.StopBroadcast,
|
||||
SourceInfo.StopBroadcast))
|
||||
gosSource.flush
|
||||
} catch {
|
||||
case e: Exception => {
|
||||
logInfo ("sendStopBroadcastNotifications had a " + e)
|
||||
}
|
||||
} finally {
|
||||
if (gisSource != null) {
|
||||
gisSource.close
|
||||
}
|
||||
if (gosSource != null) {
|
||||
gosSource.close
|
||||
}
|
||||
if (guideSocketToSource != null) {
|
||||
guideSocketToSource.close
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class GuideSingleRequest (val clientSocket: Socket)
|
||||
extends Thread with Logging {
|
||||
private val oos = new ObjectOutputStream (clientSocket.getOutputStream)
|
||||
oos.flush
|
||||
private val ois = new ObjectInputStream (clientSocket.getInputStream)
|
||||
|
||||
private var selectedSourceInfo: SourceInfo = null
|
||||
private var thisWorkerInfo:SourceInfo = null
|
||||
|
||||
override def run: Unit = {
|
||||
try {
|
||||
logInfo ("new GuideSingleRequest is running")
|
||||
// Connecting worker is sending in its hostAddress and listenPort it will
|
||||
// be listening to. Other fields are invalid (SourceInfo.UnusedParam)
|
||||
var sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
|
||||
|
||||
pqOfSources.synchronized {
|
||||
// Select a suitable source and send it back to the worker
|
||||
selectedSourceInfo = selectSuitableSource (sourceInfo)
|
||||
logInfo ("Sending selectedSourceInfo: " + selectedSourceInfo)
|
||||
oos.writeObject (selectedSourceInfo)
|
||||
oos.flush
|
||||
|
||||
// Add this new (if it can finish) source to the PQ of sources
|
||||
thisWorkerInfo = SourceInfo (sourceInfo.hostAddress,
|
||||
sourceInfo.listenPort, totalBlocks, totalBytes)
|
||||
logInfo ("Adding possible new source to pqOfSources: " + thisWorkerInfo)
|
||||
pqOfSources.add (thisWorkerInfo)
|
||||
}
|
||||
|
||||
// Wait till the whole transfer is done. Then receive and update source
|
||||
// statistics in pqOfSources
|
||||
sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
|
||||
|
||||
pqOfSources.synchronized {
|
||||
// This should work since SourceInfo is a case class
|
||||
assert (pqOfSources.contains (selectedSourceInfo))
|
||||
|
||||
// Remove first
|
||||
pqOfSources.remove (selectedSourceInfo)
|
||||
// TODO: Removing a source based on just one failure notification!
|
||||
|
||||
// Update sourceInfo and put it back in, IF reception succeeded
|
||||
if (!sourceInfo.receptionFailed) {
|
||||
// Add thisWorkerInfo to sources that have completed reception
|
||||
setOfCompletedSources += thisWorkerInfo
|
||||
|
||||
selectedSourceInfo.currentLeechers -= 1
|
||||
|
||||
// Put it back
|
||||
pqOfSources.add (selectedSourceInfo)
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// If something went wrong, e.g., the worker at the other end died etc.
|
||||
// then close everything up
|
||||
case e: Exception => {
|
||||
// Assuming that exception caused due to receiver worker failure.
|
||||
// Remove failed worker from pqOfSources and update leecherCount of
|
||||
// corresponding source worker
|
||||
pqOfSources.synchronized {
|
||||
if (selectedSourceInfo != null) {
|
||||
// Remove first
|
||||
pqOfSources.remove (selectedSourceInfo)
|
||||
// Update leecher count and put it back in
|
||||
selectedSourceInfo.currentLeechers -= 1
|
||||
pqOfSources.add (selectedSourceInfo)
|
||||
}
|
||||
|
||||
// Remove thisWorkerInfo
|
||||
if (pqOfSources != null) {
|
||||
pqOfSources.remove (thisWorkerInfo)
|
||||
}
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
ois.close
|
||||
oos.close
|
||||
clientSocket.close
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Caller must have a synchronized block on pqOfSources
|
||||
// TODO: If a worker fails to get the broadcasted variable from a source and
|
||||
// comes back to Master, this function might choose the worker itself as a
|
||||
// source tp create a dependency cycle (this worker was put into pqOfSources
|
||||
// as a streming source when it first arrived). The length of this cycle can
|
||||
// be arbitrarily long.
|
||||
private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = {
|
||||
// Select one based on the ordering strategy (e.g., least leechers etc.)
|
||||
// take is a blocking call removing the element from PQ
|
||||
var selectedSource = pqOfSources.poll
|
||||
assert (selectedSource != null)
|
||||
// Update leecher count
|
||||
selectedSource.currentLeechers += 1
|
||||
// Add it back and then return
|
||||
pqOfSources.add (selectedSource)
|
||||
return selectedSource
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class ServeMultipleRequests
|
||||
extends Thread with Logging {
|
||||
override def run: Unit = {
|
||||
var threadPool = Broadcast.newDaemonCachedThreadPool
|
||||
var serverSocket: ServerSocket = null
|
||||
|
||||
serverSocket = new ServerSocket (0)
|
||||
listenPort = serverSocket.getLocalPort
|
||||
logInfo ("ServeMultipleRequests started with " + serverSocket)
|
||||
|
||||
listenPortLock.synchronized {
|
||||
listenPortLock.notifyAll
|
||||
}
|
||||
|
||||
try {
|
||||
while (!stopBroadcast) {
|
||||
var clientSocket: Socket = null
|
||||
try {
|
||||
serverSocket.setSoTimeout (ChainedBroadcast.ServerSocketTimeout)
|
||||
clientSocket = serverSocket.accept
|
||||
} catch {
|
||||
case e: Exception => {
|
||||
logInfo ("ServeMultipleRequests Timeout.")
|
||||
}
|
||||
}
|
||||
if (clientSocket != null) {
|
||||
logInfo ("Serve: Accepted new client connection: " + clientSocket)
|
||||
try {
|
||||
threadPool.execute (new ServeSingleRequest (clientSocket))
|
||||
} catch {
|
||||
// In failure, close socket here; else, the thread will close it
|
||||
case ioe: IOException => clientSocket.close
|
||||
}
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
if (serverSocket != null) {
|
||||
logInfo ("ServeMultipleRequests now stopping...")
|
||||
serverSocket.close
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown the thread pool
|
||||
threadPool.shutdown
|
||||
}
|
||||
|
||||
class ServeSingleRequest (val clientSocket: Socket)
|
||||
extends Thread with Logging {
|
||||
private val oos = new ObjectOutputStream (clientSocket.getOutputStream)
|
||||
oos.flush
|
||||
private val ois = new ObjectInputStream (clientSocket.getInputStream)
|
||||
|
||||
private var sendFrom = 0
|
||||
private var sendUntil = totalBlocks
|
||||
|
||||
override def run: Unit = {
|
||||
try {
|
||||
logInfo ("new ServeSingleRequest is running")
|
||||
|
||||
// Receive range to send
|
||||
var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)]
|
||||
sendFrom = rangeToSend._1
|
||||
sendUntil = rangeToSend._2
|
||||
|
||||
if (sendFrom == SourceInfo.StopBroadcast &&
|
||||
sendUntil == SourceInfo.StopBroadcast) {
|
||||
stopBroadcast = true
|
||||
} else {
|
||||
// Carry on
|
||||
sendObject
|
||||
}
|
||||
} catch {
|
||||
// If something went wrong, e.g., the worker at the other end died etc.
|
||||
// then close everything up
|
||||
case e: Exception => {
|
||||
logInfo ("ServeSingleRequest had a " + e)
|
||||
}
|
||||
} finally {
|
||||
logInfo ("ServeSingleRequest is closing streams and sockets")
|
||||
ois.close
|
||||
oos.close
|
||||
clientSocket.close
|
||||
}
|
||||
}
|
||||
|
||||
private def sendObject: Unit = {
|
||||
// Wait till receiving the SourceInfo from Master
|
||||
while (totalBlocks == -1) {
|
||||
totalBlocksLock.synchronized {
|
||||
totalBlocksLock.wait
|
||||
}
|
||||
}
|
||||
|
||||
for (i <- sendFrom until sendUntil) {
|
||||
while (i == hasBlocks) {
|
||||
hasBlocksLock.synchronized {
|
||||
hasBlocksLock.wait
|
||||
}
|
||||
}
|
||||
try {
|
||||
oos.writeObject (arrayOfBlocks(i))
|
||||
oos.flush
|
||||
} catch {
|
||||
case e: Exception => {
|
||||
logInfo ("sendObject had a " + e)
|
||||
}
|
||||
}
|
||||
logInfo ("Sent block: " + i + " to " + clientSocket)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class ChainedBroadcastFactory
|
||||
extends BroadcastFactory {
|
||||
def initialize (isMaster: Boolean) = ChainedBroadcast.initialize (isMaster)
|
||||
def newBroadcast[T] (value_ : T, isLocal: Boolean) =
|
||||
new ChainedBroadcast[T] (value_, isLocal)
|
||||
}
|
||||
|
||||
private object ChainedBroadcast
|
||||
extends Logging {
|
||||
val values = SparkEnv.get.cache.newKeySpace()
|
||||
|
||||
var valueToGuidePortMap = Map[UUID, Int] ()
|
||||
|
||||
// Random number generator
|
||||
var ranGen = new Random
|
||||
|
||||
private var initialized = false
|
||||
private var isMaster_ = false
|
||||
|
||||
private var MasterHostAddress_ = InetAddress.getLocalHost.getHostAddress
|
||||
private var MasterTrackerPort_ : Int = 22222
|
||||
private var BlockSize_ : Int = 512 * 1024
|
||||
private var MaxRetryCount_ : Int = 2
|
||||
|
||||
private var TrackerSocketTimeout_ : Int = 50000
|
||||
private var ServerSocketTimeout_ : Int = 10000
|
||||
|
||||
private var trackMV: TrackMultipleValues = null
|
||||
|
||||
private var MinKnockInterval_ = 500
|
||||
private var MaxKnockInterval_ = 999
|
||||
|
||||
def initialize (isMaster__ : Boolean): Unit = {
|
||||
synchronized {
|
||||
if (!initialized) {
|
||||
// Fix for issue #42
|
||||
MasterHostAddress_ =
|
||||
System.getProperty ("spark.broadcast.masterHostAddress", "")
|
||||
MasterTrackerPort_ =
|
||||
System.getProperty ("spark.broadcast.masterTrackerPort", "22222").toInt
|
||||
BlockSize_ =
|
||||
System.getProperty ("spark.broadcast.blockSize", "512").toInt * 1024
|
||||
MaxRetryCount_ =
|
||||
System.getProperty ("spark.broadcast.maxRetryCount", "2").toInt
|
||||
|
||||
TrackerSocketTimeout_ =
|
||||
System.getProperty ("spark.broadcast.trackerSocketTimeout", "50000").toInt
|
||||
ServerSocketTimeout_ =
|
||||
System.getProperty ("spark.broadcast.serverSocketTimeout", "10000").toInt
|
||||
|
||||
MinKnockInterval_ =
|
||||
System.getProperty ("spark.broadcast.minKnockInterval", "500").toInt
|
||||
MaxKnockInterval_ =
|
||||
System.getProperty ("spark.broadcast.maxKnockInterval", "999").toInt
|
||||
|
||||
isMaster_ = isMaster__
|
||||
|
||||
if (isMaster) {
|
||||
trackMV = new TrackMultipleValues
|
||||
trackMV.setDaemon (true)
|
||||
trackMV.start
|
||||
logInfo ("TrackMultipleValues started...")
|
||||
}
|
||||
|
||||
// Initialize DfsBroadcast to be used for broadcast variable persistence
|
||||
DfsBroadcast.initialize
|
||||
|
||||
initialized = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def MasterHostAddress = MasterHostAddress_
|
||||
def MasterTrackerPort = MasterTrackerPort_
|
||||
def BlockSize = BlockSize_
|
||||
def MaxRetryCount = MaxRetryCount_
|
||||
|
||||
def TrackerSocketTimeout = TrackerSocketTimeout_
|
||||
def ServerSocketTimeout = ServerSocketTimeout_
|
||||
|
||||
def isMaster = isMaster_
|
||||
|
||||
def MinKnockInterval = MinKnockInterval_
|
||||
def MaxKnockInterval = MaxKnockInterval_
|
||||
|
||||
def registerValue (uuid: UUID, guidePort: Int): Unit = {
|
||||
valueToGuidePortMap.synchronized {
|
||||
valueToGuidePortMap += (uuid -> guidePort)
|
||||
logInfo ("New value registered with the Tracker " + valueToGuidePortMap)
|
||||
}
|
||||
}
|
||||
|
||||
def unregisterValue (uuid: UUID): Unit = {
|
||||
valueToGuidePortMap.synchronized {
|
||||
valueToGuidePortMap (uuid) = SourceInfo.TxOverGoToHDFS
|
||||
logInfo ("Value unregistered from the Tracker " + valueToGuidePortMap)
|
||||
}
|
||||
}
|
||||
|
||||
class TrackMultipleValues
|
||||
extends Thread with Logging {
|
||||
override def run: Unit = {
|
||||
var threadPool = Broadcast.newDaemonCachedThreadPool
|
||||
var serverSocket: ServerSocket = null
|
||||
|
||||
serverSocket = new ServerSocket (ChainedBroadcast.MasterTrackerPort)
|
||||
logInfo ("TrackMultipleValues" + serverSocket)
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
var clientSocket: Socket = null
|
||||
try {
|
||||
serverSocket.setSoTimeout (TrackerSocketTimeout)
|
||||
clientSocket = serverSocket.accept
|
||||
} catch {
|
||||
case e: Exception => {
|
||||
logInfo ("TrackMultipleValues Timeout. Stopping listening...")
|
||||
}
|
||||
}
|
||||
|
||||
if (clientSocket != null) {
|
||||
try {
|
||||
threadPool.execute (new Thread {
|
||||
override def run: Unit = {
|
||||
val oos = new ObjectOutputStream (clientSocket.getOutputStream)
|
||||
oos.flush
|
||||
val ois = new ObjectInputStream (clientSocket.getInputStream)
|
||||
try {
|
||||
val uuid = ois.readObject.asInstanceOf[UUID]
|
||||
var guidePort =
|
||||
if (valueToGuidePortMap.contains (uuid)) {
|
||||
valueToGuidePortMap (uuid)
|
||||
} else SourceInfo.TxNotStartedRetry
|
||||
logInfo ("TrackMultipleValues: Got new request: " + clientSocket + " for " + uuid + " : " + guidePort)
|
||||
oos.writeObject (guidePort)
|
||||
} catch {
|
||||
case e: Exception => {
|
||||
logInfo ("TrackMultipleValues had a " + e)
|
||||
}
|
||||
} finally {
|
||||
ois.close
|
||||
oos.close
|
||||
clientSocket.close
|
||||
}
|
||||
}
|
||||
})
|
||||
} catch {
|
||||
// In failure, close socket here; else, client thread will close
|
||||
case ioe: IOException => clientSocket.close
|
||||
}
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
serverSocket.close
|
||||
}
|
||||
|
||||
// Shutdown the thread pool
|
||||
threadPool.shutdown
|
||||
}
|
||||
}
|
||||
}
|
|
@ -9,6 +9,8 @@ import scala.collection.mutable.ArrayBuffer
|
|||
import mesos.{ExecutorArgs, ExecutorDriver, MesosExecutorDriver}
|
||||
import mesos.{TaskDescription, TaskState, TaskStatus}
|
||||
|
||||
import spark.broadcast._
|
||||
|
||||
/**
|
||||
* The Mesos executor for Spark.
|
||||
*/
|
||||
|
|
|
@ -7,6 +7,7 @@ import java.util.concurrent.atomic.AtomicLong
|
|||
|
||||
import scala.collection.mutable.{ArrayBuffer, HashMap}
|
||||
|
||||
import spark._
|
||||
|
||||
object LocalFileShuffle extends Logging {
|
||||
private var initialized = false
|
||||
|
@ -29,9 +30,9 @@ object LocalFileShuffle extends Logging {
|
|||
while (!foundLocalDir && tries < 10) {
|
||||
tries += 1
|
||||
try {
|
||||
localDirUuid = UUID.randomUUID()
|
||||
localDirUuid = UUID.randomUUID
|
||||
localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
|
||||
if (!localDir.exists()) {
|
||||
if (!localDir.exists) {
|
||||
localDir.mkdirs()
|
||||
foundLocalDir = true
|
||||
}
|
||||
|
@ -47,6 +48,7 @@ object LocalFileShuffle extends Logging {
|
|||
shuffleDir = new File(localDir, "shuffle")
|
||||
shuffleDir.mkdirs()
|
||||
logInfo("Shuffle dir: " + shuffleDir)
|
||||
|
||||
val extServerPort = System.getProperty(
|
||||
"spark.localFileShuffle.external.server.port", "-1").toInt
|
||||
if (extServerPort != -1) {
|
||||
|
@ -65,6 +67,7 @@ object LocalFileShuffle extends Logging {
|
|||
serverUri = server.uri
|
||||
}
|
||||
initialized = true
|
||||
logInfo("Local URI: " + serverUri)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,15 +0,0 @@
|
|||
package spark
|
||||
|
||||
/**
|
||||
* A trait for shuffle system. Given an input RDD and combiner functions
|
||||
* for PairRDDExtras.combineByKey(), returns an output RDD.
|
||||
*/
|
||||
@serializable
|
||||
trait Shuffle[K, V, C] {
|
||||
def compute(input: RDD[(K, V)],
|
||||
numOutputSplits: Int,
|
||||
createCombiner: V => C,
|
||||
mergeValue: (C, V) => C,
|
||||
mergeCombiners: (C, C) => C)
|
||||
: RDD[(K, C)]
|
||||
}
|
|
@ -8,6 +8,7 @@ import scala.collection.mutable.ArrayBuffer
|
|||
import org.apache.hadoop.mapred.InputFormat
|
||||
import org.apache.hadoop.mapred.SequenceFileInputFormat
|
||||
|
||||
import spark.broadcast._
|
||||
|
||||
class SparkContext(
|
||||
master: String,
|
||||
|
|
|
@ -3,6 +3,7 @@ package spark
|
|||
import java.io._
|
||||
import java.net.InetAddress
|
||||
import java.util.UUID
|
||||
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
|
||||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
import scala.util.Random
|
||||
|
@ -117,12 +118,43 @@ object Utils {
|
|||
/**
|
||||
* Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4)
|
||||
*/
|
||||
def localIpAddress(): String = {
|
||||
// Get local IP as an array of four bytes
|
||||
val bytes = InetAddress.getLocalHost().getAddress()
|
||||
// Convert the bytes to ints (keeping in mind that they may be negative)
|
||||
// and join them into a string
|
||||
return bytes.map(b => (b.toInt + 256) % 256).mkString(".")
|
||||
def localIpAddress(): String = InetAddress.getLocalHost.getHostAddress
|
||||
|
||||
/**
|
||||
* Returns a standard ThreadFactory except all threads are daemons
|
||||
*/
|
||||
private def newDaemonThreadFactory: ThreadFactory = {
|
||||
new ThreadFactory {
|
||||
def newThread(r: Runnable): Thread = {
|
||||
var t = Executors.defaultThreadFactory.newThread (r)
|
||||
t.setDaemon (true)
|
||||
return t
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Wrapper over newCachedThreadPool
|
||||
*/
|
||||
def newDaemonCachedThreadPool(): ThreadPoolExecutor = {
|
||||
var threadPool =
|
||||
Executors.newCachedThreadPool.asInstanceOf[ThreadPoolExecutor]
|
||||
|
||||
threadPool.setThreadFactory (newDaemonThreadFactory)
|
||||
|
||||
return threadPool
|
||||
}
|
||||
|
||||
/**
|
||||
* Wrapper over newFixedThreadPool
|
||||
*/
|
||||
def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = {
|
||||
var threadPool =
|
||||
Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor]
|
||||
|
||||
threadPool.setThreadFactory(newDaemonThreadFactory)
|
||||
|
||||
return threadPool
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
1355
core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala
Normal file
1355
core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala
Normal file
File diff suppressed because it is too large
Load diff
228
core/src/main/scala/spark/broadcast/Broadcast.scala
Normal file
228
core/src/main/scala/spark/broadcast/Broadcast.scala
Normal file
|
@ -0,0 +1,228 @@
|
|||
package spark.broadcast
|
||||
|
||||
import java.io._
|
||||
import java.net._
|
||||
import java.util.{BitSet, UUID}
|
||||
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
|
||||
|
||||
import spark._
|
||||
|
||||
@serializable
|
||||
trait Broadcast[T] {
|
||||
val uuid = UUID.randomUUID
|
||||
|
||||
def value: T
|
||||
|
||||
// We cannot have an abstract readObject here due to some weird issues with
|
||||
// readObject having to be 'private' in sub-classes. Possibly a Scala bug!
|
||||
|
||||
override def toString = "spark.Broadcast(" + uuid + ")"
|
||||
}
|
||||
|
||||
object Broadcast
|
||||
extends Logging {
|
||||
// Messages
|
||||
val REGISTER_BROADCAST_TRACKER = 0
|
||||
val UNREGISTER_BROADCAST_TRACKER = 1
|
||||
val FIND_BROADCAST_TRACKER = 2
|
||||
val GET_UPDATED_SHARE = 3
|
||||
|
||||
private var initialized = false
|
||||
private var isMaster_ = false
|
||||
private var broadcastFactory: BroadcastFactory = null
|
||||
|
||||
// Called by SparkContext or Executor before using Broadcast
|
||||
def initialize (isMaster__ : Boolean): Unit = synchronized {
|
||||
if (!initialized) {
|
||||
val broadcastFactoryClass = System.getProperty(
|
||||
"spark.broadcast.factory", "spark.broadcast.DfsBroadcastFactory")
|
||||
|
||||
broadcastFactory =
|
||||
Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
|
||||
|
||||
// Setup isMaster before using it
|
||||
isMaster_ = isMaster__
|
||||
|
||||
// Set masterHostAddress to the master's IP address for the slaves to read
|
||||
if (isMaster) {
|
||||
System.setProperty("spark.broadcast.masterHostAddress", Utils.localIpAddress)
|
||||
}
|
||||
|
||||
// Initialize appropriate BroadcastFactory and BroadcastObject
|
||||
broadcastFactory.initialize(isMaster)
|
||||
|
||||
initialized = true
|
||||
}
|
||||
}
|
||||
|
||||
def getBroadcastFactory: BroadcastFactory = {
|
||||
if (broadcastFactory == null) {
|
||||
throw new SparkException ("Broadcast.getBroadcastFactory called before initialize")
|
||||
}
|
||||
broadcastFactory
|
||||
}
|
||||
|
||||
// Load common broadcast-related config parameters
|
||||
private var MasterHostAddress_ = System.getProperty(
|
||||
"spark.broadcast.masterHostAddress", "")
|
||||
private var MasterTrackerPort_ = System.getProperty(
|
||||
"spark.broadcast.masterTrackerPort", "11111").toInt
|
||||
private var BlockSize_ = System.getProperty(
|
||||
"spark.broadcast.blockSize", "4096").toInt * 1024
|
||||
private var MaxRetryCount_ = System.getProperty(
|
||||
"spark.broadcast.maxRetryCount", "2").toInt
|
||||
|
||||
private var TrackerSocketTimeout_ = System.getProperty(
|
||||
"spark.broadcast.trackerSocketTimeout", "50000").toInt
|
||||
private var ServerSocketTimeout_ = System.getProperty(
|
||||
"spark.broadcast.serverSocketTimeout", "10000").toInt
|
||||
|
||||
private var MinKnockInterval_ = System.getProperty(
|
||||
"spark.broadcast.minKnockInterval", "500").toInt
|
||||
private var MaxKnockInterval_ = System.getProperty(
|
||||
"spark.broadcast.maxKnockInterval", "999").toInt
|
||||
|
||||
// Load ChainedBroadcast config params
|
||||
|
||||
// Load TreeBroadcast config params
|
||||
private var MaxDegree_ = System.getProperty("spark.broadcast.maxDegree", "2").toInt
|
||||
|
||||
// Load BitTorrentBroadcast config params
|
||||
private var MaxPeersInGuideResponse_ = System.getProperty(
|
||||
"spark.broadcast.maxPeersInGuideResponse", "4").toInt
|
||||
|
||||
private var MaxRxSlots_ = System.getProperty(
|
||||
"spark.broadcast.maxRxSlots", "4").toInt
|
||||
private var MaxTxSlots_ = System.getProperty(
|
||||
"spark.broadcast.maxTxSlots", "4").toInt
|
||||
|
||||
private var MaxChatTime_ = System.getProperty(
|
||||
"spark.broadcast.maxChatTime", "500").toInt
|
||||
private var MaxChatBlocks_ = System.getProperty(
|
||||
"spark.broadcast.maxChatBlocks", "1024").toInt
|
||||
|
||||
private var EndGameFraction_ = System.getProperty(
|
||||
"spark.broadcast.endGameFraction", "0.95").toDouble
|
||||
|
||||
def isMaster = isMaster_
|
||||
|
||||
// Common config params
|
||||
def MasterHostAddress = MasterHostAddress_
|
||||
def MasterTrackerPort = MasterTrackerPort_
|
||||
def BlockSize = BlockSize_
|
||||
def MaxRetryCount = MaxRetryCount_
|
||||
|
||||
def TrackerSocketTimeout = TrackerSocketTimeout_
|
||||
def ServerSocketTimeout = ServerSocketTimeout_
|
||||
|
||||
def MinKnockInterval = MinKnockInterval_
|
||||
def MaxKnockInterval = MaxKnockInterval_
|
||||
|
||||
// ChainedBroadcast configs
|
||||
|
||||
// TreeBroadcast configs
|
||||
def MaxDegree = MaxDegree_
|
||||
|
||||
// BitTorrentBroadcast configs
|
||||
def MaxPeersInGuideResponse = MaxPeersInGuideResponse_
|
||||
|
||||
def MaxRxSlots = MaxRxSlots_
|
||||
def MaxTxSlots = MaxTxSlots_
|
||||
|
||||
def MaxChatTime = MaxChatTime_
|
||||
def MaxChatBlocks = MaxChatBlocks_
|
||||
|
||||
def EndGameFraction = EndGameFraction_
|
||||
|
||||
// Helper functions to convert an object to Array[BroadcastBlock]
|
||||
def blockifyObject[IN](obj: IN): VariableInfo = {
|
||||
val baos = new ByteArrayOutputStream
|
||||
val oos = new ObjectOutputStream(baos)
|
||||
oos.writeObject(obj)
|
||||
oos.close()
|
||||
baos.close()
|
||||
val byteArray = baos.toByteArray
|
||||
val bais = new ByteArrayInputStream(byteArray)
|
||||
|
||||
var blockNum = (byteArray.length / Broadcast.BlockSize)
|
||||
if (byteArray.length % Broadcast.BlockSize != 0)
|
||||
blockNum += 1
|
||||
|
||||
var retVal = new Array[BroadcastBlock](blockNum)
|
||||
var blockID = 0
|
||||
|
||||
for (i <- 0 until (byteArray.length, Broadcast.BlockSize)) {
|
||||
val thisBlockSize = math.min(Broadcast.BlockSize, byteArray.length - i)
|
||||
var tempByteArray = new Array[Byte](thisBlockSize)
|
||||
val hasRead = bais.read(tempByteArray, 0, thisBlockSize)
|
||||
|
||||
retVal(blockID) = new BroadcastBlock(blockID, tempByteArray)
|
||||
blockID += 1
|
||||
}
|
||||
bais.close()
|
||||
|
||||
var variableInfo = VariableInfo(retVal, blockNum, byteArray.length)
|
||||
variableInfo.hasBlocks = blockNum
|
||||
|
||||
return variableInfo
|
||||
}
|
||||
|
||||
// Helper function to convert Array[BroadcastBlock] to object
|
||||
def unBlockifyObject[OUT](arrayOfBlocks: Array[BroadcastBlock],
|
||||
totalBytes: Int,
|
||||
totalBlocks: Int): OUT = {
|
||||
|
||||
var retByteArray = new Array[Byte](totalBytes)
|
||||
for (i <- 0 until totalBlocks) {
|
||||
System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
|
||||
i * Broadcast.BlockSize, arrayOfBlocks(i).byteArray.length)
|
||||
}
|
||||
byteArrayToObject(retByteArray)
|
||||
}
|
||||
|
||||
private def byteArrayToObject[OUT](bytes: Array[Byte]): OUT = {
|
||||
val in = new ObjectInputStream (new ByteArrayInputStream (bytes)){
|
||||
override def resolveClass(desc: ObjectStreamClass) =
|
||||
Class.forName(desc.getName, false, currentThread.getContextClassLoader)
|
||||
}
|
||||
val retVal = in.readObject.asInstanceOf[OUT]
|
||||
in.close()
|
||||
return retVal
|
||||
}
|
||||
}
|
||||
|
||||
@serializable
|
||||
case class BroadcastBlock (val blockID: Int, val byteArray: Array[Byte]) { }
|
||||
|
||||
@serializable
|
||||
case class VariableInfo (@transient val arrayOfBlocks : Array[BroadcastBlock],
|
||||
val totalBlocks: Int,
|
||||
val totalBytes: Int) {
|
||||
@transient var hasBlocks = 0
|
||||
}
|
||||
|
||||
@serializable
|
||||
class SpeedTracker {
|
||||
// Mapping 'source' to '(totalTime, numBlocks)'
|
||||
private var sourceToSpeedMap = Map[SourceInfo, (Long, Int)] ()
|
||||
|
||||
def addDataPoint (srcInfo: SourceInfo, timeInMillis: Long): Unit = {
|
||||
sourceToSpeedMap.synchronized {
|
||||
if (!sourceToSpeedMap.contains(srcInfo)) {
|
||||
sourceToSpeedMap += (srcInfo -> (timeInMillis, 1))
|
||||
} else {
|
||||
val tTnB = sourceToSpeedMap (srcInfo)
|
||||
sourceToSpeedMap += (srcInfo -> (tTnB._1 + timeInMillis, tTnB._2 + 1))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def getTimePerBlock (srcInfo: SourceInfo): Double = {
|
||||
sourceToSpeedMap.synchronized {
|
||||
val tTnB = sourceToSpeedMap (srcInfo)
|
||||
return tTnB._1 / tTnB._2
|
||||
}
|
||||
}
|
||||
|
||||
override def toString = sourceToSpeedMap.toString
|
||||
}
|
12
core/src/main/scala/spark/broadcast/BroadcastFactory.scala
Normal file
12
core/src/main/scala/spark/broadcast/BroadcastFactory.scala
Normal file
|
@ -0,0 +1,12 @@
|
|||
package spark.broadcast
|
||||
|
||||
/**
|
||||
* An interface for all the broadcast implementations in Spark (to allow
|
||||
* multiple broadcast implementations). SparkContext uses a user-specified
|
||||
* BroadcastFactory implementation to instantiate a particular broadcast for the
|
||||
* entire Spark job.
|
||||
*/
|
||||
trait BroadcastFactory {
|
||||
def initialize (isMaster: Boolean): Unit
|
||||
def newBroadcast[T] (value_ : T, isLocal: Boolean): Broadcast[T]
|
||||
}
|
792
core/src/main/scala/spark/broadcast/ChainedBroadcast.scala
Normal file
792
core/src/main/scala/spark/broadcast/ChainedBroadcast.scala
Normal file
|
@ -0,0 +1,792 @@
|
|||
package spark.broadcast
|
||||
|
||||
import java.io._
|
||||
import java.net._
|
||||
import java.util.{Comparator, PriorityQueue, Random, UUID}
|
||||
|
||||
import scala.collection.mutable.{Map, Set}
|
||||
import scala.math
|
||||
|
||||
import spark._
|
||||
|
||||
@serializable
|
||||
class ChainedBroadcast[T](@transient var value_ : T, isLocal: Boolean)
|
||||
extends Broadcast[T] with Logging {
|
||||
|
||||
def value = value_
|
||||
|
||||
ChainedBroadcast.synchronized {
|
||||
ChainedBroadcast.values.put(uuid, value_)
|
||||
}
|
||||
|
||||
@transient var arrayOfBlocks: Array[BroadcastBlock] = null
|
||||
@transient var totalBytes = -1
|
||||
@transient var totalBlocks = -1
|
||||
@transient var hasBlocks = 0
|
||||
// CHANGED: BlockSize in the Broadcast object is expected to change over time
|
||||
@transient var blockSize = Broadcast.BlockSize
|
||||
|
||||
@transient var listenPortLock = new Object
|
||||
@transient var guidePortLock = new Object
|
||||
@transient var totalBlocksLock = new Object
|
||||
@transient var hasBlocksLock = new Object
|
||||
|
||||
@transient var pqOfSources = new PriorityQueue[SourceInfo]
|
||||
|
||||
@transient var serveMR: ServeMultipleRequests = null
|
||||
@transient var guideMR: GuideMultipleRequests = null
|
||||
|
||||
@transient var hostAddress = Utils.localIpAddress
|
||||
@transient var listenPort = -1
|
||||
@transient var guidePort = -1
|
||||
|
||||
@transient var hasCopyInHDFS = false
|
||||
@transient var stopBroadcast = false
|
||||
|
||||
// Must call this after all the variables have been created/initialized
|
||||
if (!isLocal) {
|
||||
sendBroadcast
|
||||
}
|
||||
|
||||
def sendBroadcast(): Unit = {
|
||||
logInfo("Local host address: " + hostAddress)
|
||||
|
||||
// Store a persistent copy in HDFS
|
||||
// TODO: Turned OFF for now
|
||||
// val out = new ObjectOutputStream(DfsBroadcast.openFileForWriting(uuid))
|
||||
// out.writeObject(value_)
|
||||
// out.close()
|
||||
// TODO: Fix this at some point
|
||||
hasCopyInHDFS = true
|
||||
|
||||
// Create a variableInfo object and store it in valueInfos
|
||||
var variableInfo = Broadcast.blockifyObject(value_)
|
||||
|
||||
guideMR = new GuideMultipleRequests
|
||||
guideMR.setDaemon(true)
|
||||
guideMR.start()
|
||||
logInfo("GuideMultipleRequests started...")
|
||||
|
||||
serveMR = new ServeMultipleRequests
|
||||
serveMR.setDaemon(true)
|
||||
serveMR.start()
|
||||
logInfo("ServeMultipleRequests started...")
|
||||
|
||||
// Prepare the value being broadcasted
|
||||
// TODO: Refactoring and clean-up required here
|
||||
arrayOfBlocks = variableInfo.arrayOfBlocks
|
||||
totalBytes = variableInfo.totalBytes
|
||||
totalBlocks = variableInfo.totalBlocks
|
||||
hasBlocks = variableInfo.totalBlocks
|
||||
|
||||
while (listenPort == -1) {
|
||||
listenPortLock.synchronized {
|
||||
listenPortLock.wait
|
||||
}
|
||||
}
|
||||
|
||||
pqOfSources = new PriorityQueue[SourceInfo]
|
||||
val masterSource =
|
||||
SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes, blockSize)
|
||||
pqOfSources.add(masterSource)
|
||||
|
||||
// Register with the Tracker
|
||||
while (guidePort == -1) {
|
||||
guidePortLock.synchronized {
|
||||
guidePortLock.wait
|
||||
}
|
||||
}
|
||||
ChainedBroadcast.registerValue(uuid, guidePort)
|
||||
}
|
||||
|
||||
private def readObject(in: ObjectInputStream): Unit = {
|
||||
in.defaultReadObject
|
||||
ChainedBroadcast.synchronized {
|
||||
val cachedVal = ChainedBroadcast.values.get(uuid)
|
||||
if (cachedVal != null) {
|
||||
value_ = cachedVal.asInstanceOf[T]
|
||||
} else {
|
||||
// Initializing everything because Master will only send null/0 values
|
||||
initializeSlaveVariables
|
||||
|
||||
logInfo("Local host address: " + hostAddress)
|
||||
|
||||
serveMR = new ServeMultipleRequests
|
||||
serveMR.setDaemon(true)
|
||||
serveMR.start()
|
||||
logInfo("ServeMultipleRequests started...")
|
||||
|
||||
val start = System.nanoTime
|
||||
|
||||
val receptionSucceeded = receiveBroadcast(uuid)
|
||||
// If does not succeed, then get from HDFS copy
|
||||
if (receptionSucceeded) {
|
||||
value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
|
||||
ChainedBroadcast.values.put(uuid, value_)
|
||||
} else {
|
||||
val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid))
|
||||
value_ = fileIn.readObject.asInstanceOf[T]
|
||||
ChainedBroadcast.values.put(uuid, value_)
|
||||
fileIn.close()
|
||||
}
|
||||
|
||||
val time =(System.nanoTime - start) / 1e9
|
||||
logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private def initializeSlaveVariables: Unit = {
|
||||
arrayOfBlocks = null
|
||||
totalBytes = -1
|
||||
totalBlocks = -1
|
||||
hasBlocks = 0
|
||||
blockSize = -1
|
||||
|
||||
listenPortLock = new Object
|
||||
totalBlocksLock = new Object
|
||||
hasBlocksLock = new Object
|
||||
|
||||
serveMR = null
|
||||
|
||||
hostAddress = Utils.localIpAddress
|
||||
listenPort = -1
|
||||
|
||||
stopBroadcast = false
|
||||
}
|
||||
|
||||
def getMasterListenPort(variableUUID: UUID): Int = {
|
||||
var clientSocketToTracker: Socket = null
|
||||
var oosTracker: ObjectOutputStream = null
|
||||
var oisTracker: ObjectInputStream = null
|
||||
|
||||
var masterListenPort: Int = SourceInfo.TxOverGoToHDFS
|
||||
|
||||
var retriesLeft = Broadcast.MaxRetryCount
|
||||
do {
|
||||
try {
|
||||
// Connect to the tracker to find out the guide
|
||||
clientSocketToTracker =
|
||||
new Socket(Broadcast.MasterHostAddress, Broadcast.MasterTrackerPort)
|
||||
oosTracker =
|
||||
new ObjectOutputStream(clientSocketToTracker.getOutputStream)
|
||||
oosTracker.flush()
|
||||
oisTracker =
|
||||
new ObjectInputStream(clientSocketToTracker.getInputStream)
|
||||
|
||||
// Send UUID and receive masterListenPort
|
||||
oosTracker.writeObject(uuid)
|
||||
oosTracker.flush()
|
||||
masterListenPort = oisTracker.readObject.asInstanceOf[Int]
|
||||
} catch {
|
||||
case e: Exception => {
|
||||
logInfo("getMasterListenPort had a " + e)
|
||||
}
|
||||
} finally {
|
||||
if (oisTracker != null) {
|
||||
oisTracker.close()
|
||||
}
|
||||
if (oosTracker != null) {
|
||||
oosTracker.close()
|
||||
}
|
||||
if (clientSocketToTracker != null) {
|
||||
clientSocketToTracker.close()
|
||||
}
|
||||
}
|
||||
retriesLeft -= 1
|
||||
|
||||
Thread.sleep(ChainedBroadcast.ranGen.nextInt(
|
||||
Broadcast.MaxKnockInterval - Broadcast.MinKnockInterval) +
|
||||
Broadcast.MinKnockInterval)
|
||||
|
||||
} while (retriesLeft > 0 && masterListenPort == SourceInfo.TxNotStartedRetry)
|
||||
|
||||
logInfo("Got this guidePort from Tracker: " + masterListenPort)
|
||||
return masterListenPort
|
||||
}
|
||||
|
||||
def receiveBroadcast(variableUUID: UUID): Boolean = {
|
||||
val masterListenPort = getMasterListenPort(variableUUID)
|
||||
|
||||
if (masterListenPort == SourceInfo.TxOverGoToHDFS ||
|
||||
masterListenPort == SourceInfo.TxNotStartedRetry) {
|
||||
// TODO: SourceInfo.TxNotStartedRetry is not really in use because we go
|
||||
// to HDFS anyway when receiveBroadcast returns false
|
||||
return false
|
||||
}
|
||||
|
||||
// Wait until hostAddress and listenPort are created by the
|
||||
// ServeMultipleRequests thread
|
||||
while (listenPort == -1) {
|
||||
listenPortLock.synchronized {
|
||||
listenPortLock.wait
|
||||
}
|
||||
}
|
||||
|
||||
var clientSocketToMaster: Socket = null
|
||||
var oosMaster: ObjectOutputStream = null
|
||||
var oisMaster: ObjectInputStream = null
|
||||
|
||||
// Connect and receive broadcast from the specified source, retrying the
|
||||
// specified number of times in case of failures
|
||||
var retriesLeft = Broadcast.MaxRetryCount
|
||||
do {
|
||||
// Connect to Master and send this worker's Information
|
||||
clientSocketToMaster =
|
||||
new Socket(Broadcast.MasterHostAddress, masterListenPort)
|
||||
// TODO: Guiding object connection is reusable
|
||||
oosMaster =
|
||||
new ObjectOutputStream(clientSocketToMaster.getOutputStream)
|
||||
oosMaster.flush()
|
||||
oisMaster =
|
||||
new ObjectInputStream(clientSocketToMaster.getInputStream)
|
||||
|
||||
logInfo("Connected to Master's guiding object")
|
||||
|
||||
// Send local source information
|
||||
oosMaster.writeObject(SourceInfo(hostAddress, listenPort))
|
||||
oosMaster.flush()
|
||||
|
||||
// Receive source information from Master
|
||||
var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo]
|
||||
totalBlocks = sourceInfo.totalBlocks
|
||||
arrayOfBlocks = new Array[BroadcastBlock](totalBlocks)
|
||||
totalBlocksLock.synchronized {
|
||||
totalBlocksLock.notifyAll
|
||||
}
|
||||
totalBytes = sourceInfo.totalBytes
|
||||
|
||||
logInfo("Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort)
|
||||
|
||||
val start = System.nanoTime
|
||||
val receptionSucceeded = receiveSingleTransmission(sourceInfo)
|
||||
val time =(System.nanoTime - start) / 1e9
|
||||
|
||||
// Updating some statistics in sourceInfo. Master will be using them later
|
||||
if (!receptionSucceeded) {
|
||||
sourceInfo.receptionFailed = true
|
||||
}
|
||||
|
||||
// Send back statistics to the Master
|
||||
oosMaster.writeObject(sourceInfo)
|
||||
|
||||
if (oisMaster != null) {
|
||||
oisMaster.close()
|
||||
}
|
||||
if (oosMaster != null) {
|
||||
oosMaster.close()
|
||||
}
|
||||
if (clientSocketToMaster != null) {
|
||||
clientSocketToMaster.close()
|
||||
}
|
||||
|
||||
retriesLeft -= 1
|
||||
} while (retriesLeft > 0 && hasBlocks < totalBlocks)
|
||||
|
||||
return(hasBlocks == totalBlocks)
|
||||
}
|
||||
|
||||
// Tries to receive broadcast from the source and returns Boolean status.
|
||||
// This might be called multiple times to retry a defined number of times.
|
||||
private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = {
|
||||
var clientSocketToSource: Socket = null
|
||||
var oosSource: ObjectOutputStream = null
|
||||
var oisSource: ObjectInputStream = null
|
||||
|
||||
var receptionSucceeded = false
|
||||
try {
|
||||
// Connect to the source to get the object itself
|
||||
clientSocketToSource =
|
||||
new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
|
||||
oosSource =
|
||||
new ObjectOutputStream(clientSocketToSource.getOutputStream)
|
||||
oosSource.flush()
|
||||
oisSource =
|
||||
new ObjectInputStream(clientSocketToSource.getInputStream)
|
||||
|
||||
logInfo("Inside receiveSingleTransmission")
|
||||
logInfo("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks)
|
||||
|
||||
// Send the range
|
||||
oosSource.writeObject((hasBlocks, totalBlocks))
|
||||
oosSource.flush()
|
||||
|
||||
for (i <- hasBlocks until totalBlocks) {
|
||||
val recvStartTime = System.currentTimeMillis
|
||||
val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
|
||||
val receptionTime =(System.currentTimeMillis - recvStartTime)
|
||||
|
||||
logInfo("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.")
|
||||
|
||||
arrayOfBlocks(hasBlocks) = bcBlock
|
||||
hasBlocks += 1
|
||||
// Set to true if at least one block is received
|
||||
receptionSucceeded = true
|
||||
hasBlocksLock.synchronized {
|
||||
hasBlocksLock.notifyAll
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
case e: Exception => {
|
||||
logInfo("receiveSingleTransmission had a " + e)
|
||||
}
|
||||
} finally {
|
||||
if (oisSource != null) {
|
||||
oisSource.close()
|
||||
}
|
||||
if (oosSource != null) {
|
||||
oosSource.close()
|
||||
}
|
||||
if (clientSocketToSource != null) {
|
||||
clientSocketToSource.close()
|
||||
}
|
||||
}
|
||||
|
||||
return receptionSucceeded
|
||||
}
|
||||
|
||||
class GuideMultipleRequests
|
||||
extends Thread with Logging {
|
||||
// Keep track of sources that have completed reception
|
||||
private var setOfCompletedSources = Set[SourceInfo]()
|
||||
|
||||
override def run: Unit = {
|
||||
var threadPool = Utils.newDaemonCachedThreadPool()
|
||||
var serverSocket: ServerSocket = null
|
||||
|
||||
serverSocket = new ServerSocket(0)
|
||||
guidePort = serverSocket.getLocalPort
|
||||
logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort)
|
||||
|
||||
guidePortLock.synchronized {
|
||||
guidePortLock.notifyAll
|
||||
}
|
||||
|
||||
try {
|
||||
// Don't stop until there is a copy in HDFS
|
||||
while (!stopBroadcast || !hasCopyInHDFS) {
|
||||
var clientSocket: Socket = null
|
||||
try {
|
||||
serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout)
|
||||
clientSocket = serverSocket.accept
|
||||
} catch {
|
||||
case e: Exception => {
|
||||
logInfo("GuideMultipleRequests Timeout.")
|
||||
|
||||
// Stop broadcast if at least one worker has connected and
|
||||
// everyone connected so far are done. Comparing with
|
||||
// pqOfSources.size - 1, because it includes the Guide itself
|
||||
if (pqOfSources.size > 1 &&
|
||||
setOfCompletedSources.size == pqOfSources.size - 1) {
|
||||
stopBroadcast = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if (clientSocket != null) {
|
||||
logInfo("Guide: Accepted new client connection: " + clientSocket)
|
||||
try {
|
||||
threadPool.execute(new GuideSingleRequest(clientSocket))
|
||||
} catch {
|
||||
// In failure, close the socket here; else, the thread will close it
|
||||
case ioe: IOException => clientSocket.close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logInfo("Sending stopBroadcast notifications...")
|
||||
sendStopBroadcastNotifications
|
||||
|
||||
ChainedBroadcast.unregisterValue(uuid)
|
||||
} finally {
|
||||
if (serverSocket != null) {
|
||||
logInfo("GuideMultipleRequests now stopping...")
|
||||
serverSocket.close()
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown the thread pool
|
||||
threadPool.shutdown()
|
||||
}
|
||||
|
||||
private def sendStopBroadcastNotifications: Unit = {
|
||||
pqOfSources.synchronized {
|
||||
var pqIter = pqOfSources.iterator
|
||||
while (pqIter.hasNext) {
|
||||
var sourceInfo = pqIter.next
|
||||
|
||||
var guideSocketToSource: Socket = null
|
||||
var gosSource: ObjectOutputStream = null
|
||||
var gisSource: ObjectInputStream = null
|
||||
|
||||
try {
|
||||
// Connect to the source
|
||||
guideSocketToSource =
|
||||
new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
|
||||
gosSource =
|
||||
new ObjectOutputStream(guideSocketToSource.getOutputStream)
|
||||
gosSource.flush()
|
||||
gisSource =
|
||||
new ObjectInputStream(guideSocketToSource.getInputStream)
|
||||
|
||||
// Send stopBroadcast signal. Range = SourceInfo.StopBroadcast*2
|
||||
gosSource.writeObject((SourceInfo.StopBroadcast,
|
||||
SourceInfo.StopBroadcast))
|
||||
gosSource.flush()
|
||||
} catch {
|
||||
case e: Exception => {
|
||||
logInfo("sendStopBroadcastNotifications had a " + e)
|
||||
}
|
||||
} finally {
|
||||
if (gisSource != null) {
|
||||
gisSource.close()
|
||||
}
|
||||
if (gosSource != null) {
|
||||
gosSource.close()
|
||||
}
|
||||
if (guideSocketToSource != null) {
|
||||
guideSocketToSource.close()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class GuideSingleRequest(val clientSocket: Socket)
|
||||
extends Thread with Logging {
|
||||
private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
|
||||
oos.flush()
|
||||
private val ois = new ObjectInputStream(clientSocket.getInputStream)
|
||||
|
||||
private var selectedSourceInfo: SourceInfo = null
|
||||
private var thisWorkerInfo:SourceInfo = null
|
||||
|
||||
override def run: Unit = {
|
||||
try {
|
||||
logInfo("new GuideSingleRequest is running")
|
||||
// Connecting worker is sending in its hostAddress and listenPort it will
|
||||
// be listening to. Other fields are invalid(SourceInfo.UnusedParam)
|
||||
var sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
|
||||
|
||||
pqOfSources.synchronized {
|
||||
// Select a suitable source and send it back to the worker
|
||||
selectedSourceInfo = selectSuitableSource(sourceInfo)
|
||||
logInfo("Sending selectedSourceInfo: " + selectedSourceInfo)
|
||||
oos.writeObject(selectedSourceInfo)
|
||||
oos.flush()
|
||||
|
||||
// Add this new(if it can finish) source to the PQ of sources
|
||||
thisWorkerInfo = SourceInfo(sourceInfo.hostAddress,
|
||||
sourceInfo.listenPort, totalBlocks, totalBytes, blockSize)
|
||||
logInfo("Adding possible new source to pqOfSources: " + thisWorkerInfo)
|
||||
pqOfSources.add(thisWorkerInfo)
|
||||
}
|
||||
|
||||
// Wait till the whole transfer is done. Then receive and update source
|
||||
// statistics in pqOfSources
|
||||
sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
|
||||
|
||||
pqOfSources.synchronized {
|
||||
// This should work since SourceInfo is a case class
|
||||
assert(pqOfSources.contains(selectedSourceInfo))
|
||||
|
||||
// Remove first
|
||||
pqOfSources.remove(selectedSourceInfo)
|
||||
// TODO: Removing a source based on just one failure notification!
|
||||
|
||||
// Update sourceInfo and put it back in, IF reception succeeded
|
||||
if (!sourceInfo.receptionFailed) {
|
||||
// Add thisWorkerInfo to sources that have completed reception
|
||||
setOfCompletedSources.synchronized {
|
||||
setOfCompletedSources += thisWorkerInfo
|
||||
}
|
||||
|
||||
selectedSourceInfo.currentLeechers -= 1
|
||||
|
||||
// Put it back
|
||||
pqOfSources.add(selectedSourceInfo)
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// If something went wrong, e.g., the worker at the other end died etc.
|
||||
// then close everything up
|
||||
case e: Exception => {
|
||||
// Assuming that exception caused due to receiver worker failure.
|
||||
// Remove failed worker from pqOfSources and update leecherCount of
|
||||
// corresponding source worker
|
||||
pqOfSources.synchronized {
|
||||
if (selectedSourceInfo != null) {
|
||||
// Remove first
|
||||
pqOfSources.remove(selectedSourceInfo)
|
||||
// Update leecher count and put it back in
|
||||
selectedSourceInfo.currentLeechers -= 1
|
||||
pqOfSources.add(selectedSourceInfo)
|
||||
}
|
||||
|
||||
// Remove thisWorkerInfo
|
||||
if (pqOfSources != null) {
|
||||
pqOfSources.remove(thisWorkerInfo)
|
||||
}
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
ois.close()
|
||||
oos.close()
|
||||
clientSocket.close()
|
||||
}
|
||||
}
|
||||
|
||||
// FIXME: Caller must have a synchronized block on pqOfSources
|
||||
// FIXME: If a worker fails to get the broadcasted variable from a source and
|
||||
// comes back to Master, this function might choose the worker itself as a
|
||||
// source tp create a dependency cycle(this worker was put into pqOfSources
|
||||
// as a streming source when it first arrived). The length of this cycle can
|
||||
// be arbitrarily long.
|
||||
private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = {
|
||||
// Select one based on the ordering strategy(e.g., least leechers etc.)
|
||||
// take is a blocking call removing the element from PQ
|
||||
var selectedSource = pqOfSources.poll
|
||||
assert(selectedSource != null)
|
||||
// Update leecher count
|
||||
selectedSource.currentLeechers += 1
|
||||
// Add it back and then return
|
||||
pqOfSources.add(selectedSource)
|
||||
return selectedSource
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class ServeMultipleRequests
|
||||
extends Thread with Logging {
|
||||
override def run: Unit = {
|
||||
var threadPool = Utils.newDaemonCachedThreadPool()
|
||||
var serverSocket: ServerSocket = null
|
||||
|
||||
serverSocket = new ServerSocket(0)
|
||||
listenPort = serverSocket.getLocalPort
|
||||
logInfo("ServeMultipleRequests started with " + serverSocket)
|
||||
|
||||
listenPortLock.synchronized {
|
||||
listenPortLock.notifyAll
|
||||
}
|
||||
|
||||
try {
|
||||
while (!stopBroadcast) {
|
||||
var clientSocket: Socket = null
|
||||
try {
|
||||
serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout)
|
||||
clientSocket = serverSocket.accept
|
||||
} catch {
|
||||
case e: Exception => {
|
||||
logInfo("ServeMultipleRequests Timeout.")
|
||||
}
|
||||
}
|
||||
if (clientSocket != null) {
|
||||
logInfo("Serve: Accepted new client connection: " + clientSocket)
|
||||
try {
|
||||
threadPool.execute(new ServeSingleRequest(clientSocket))
|
||||
} catch {
|
||||
// In failure, close socket here; else, the thread will close it
|
||||
case ioe: IOException => clientSocket.close()
|
||||
}
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
if (serverSocket != null) {
|
||||
logInfo("ServeMultipleRequests now stopping...")
|
||||
serverSocket.close()
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown the thread pool
|
||||
threadPool.shutdown()
|
||||
}
|
||||
|
||||
class ServeSingleRequest(val clientSocket: Socket)
|
||||
extends Thread with Logging {
|
||||
private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
|
||||
oos.flush()
|
||||
private val ois = new ObjectInputStream(clientSocket.getInputStream)
|
||||
|
||||
private var sendFrom = 0
|
||||
private var sendUntil = totalBlocks
|
||||
|
||||
override def run: Unit = {
|
||||
try {
|
||||
logInfo("new ServeSingleRequest is running")
|
||||
|
||||
// Receive range to send
|
||||
var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)]
|
||||
sendFrom = rangeToSend._1
|
||||
sendUntil = rangeToSend._2
|
||||
|
||||
if (sendFrom == SourceInfo.StopBroadcast &&
|
||||
sendUntil == SourceInfo.StopBroadcast) {
|
||||
stopBroadcast = true
|
||||
} else {
|
||||
// Carry on
|
||||
sendObject
|
||||
}
|
||||
} catch {
|
||||
// If something went wrong, e.g., the worker at the other end died etc.
|
||||
// then close everything up
|
||||
case e: Exception => {
|
||||
logInfo("ServeSingleRequest had a " + e)
|
||||
}
|
||||
} finally {
|
||||
logInfo("ServeSingleRequest is closing streams and sockets")
|
||||
ois.close()
|
||||
oos.close()
|
||||
clientSocket.close()
|
||||
}
|
||||
}
|
||||
|
||||
private def sendObject: Unit = {
|
||||
// Wait till receiving the SourceInfo from Master
|
||||
while (totalBlocks == -1) {
|
||||
totalBlocksLock.synchronized {
|
||||
totalBlocksLock.wait
|
||||
}
|
||||
}
|
||||
|
||||
for (i <- sendFrom until sendUntil) {
|
||||
while (i == hasBlocks) {
|
||||
hasBlocksLock.synchronized {
|
||||
hasBlocksLock.wait
|
||||
}
|
||||
}
|
||||
try {
|
||||
oos.writeObject(arrayOfBlocks(i))
|
||||
oos.flush()
|
||||
} catch {
|
||||
case e: Exception => {
|
||||
logInfo("sendObject had a " + e)
|
||||
}
|
||||
}
|
||||
logInfo("Sent block: " + i + " to " + clientSocket)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class ChainedBroadcastFactory
|
||||
extends BroadcastFactory {
|
||||
def initialize(isMaster: Boolean) = ChainedBroadcast.initialize(isMaster)
|
||||
def newBroadcast[T](value_ : T, isLocal: Boolean) =
|
||||
new ChainedBroadcast[T](value_, isLocal)
|
||||
}
|
||||
|
||||
private object ChainedBroadcast
|
||||
extends Logging {
|
||||
val values = SparkEnv.get.cache.newKeySpace()
|
||||
|
||||
var valueToGuidePortMap = Map[UUID, Int]()
|
||||
|
||||
// Random number generator
|
||||
var ranGen = new Random
|
||||
|
||||
private var initialized = false
|
||||
private var isMaster_ = false
|
||||
|
||||
private var trackMV: TrackMultipleValues = null
|
||||
|
||||
def initialize(isMaster__ : Boolean): Unit = {
|
||||
synchronized {
|
||||
if (!initialized) {
|
||||
isMaster_ = isMaster__
|
||||
|
||||
if (isMaster) {
|
||||
trackMV = new TrackMultipleValues
|
||||
trackMV.setDaemon(true)
|
||||
trackMV.start()
|
||||
// TODO: Logging the following line makes the Spark framework ID not
|
||||
// getting logged, cause it calls logInfo before log4j is initialized
|
||||
logInfo("TrackMultipleValues started...")
|
||||
}
|
||||
|
||||
// Initialize DfsBroadcast to be used for broadcast variable persistence
|
||||
DfsBroadcast.initialize
|
||||
|
||||
initialized = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def isMaster = isMaster_
|
||||
|
||||
def registerValue(uuid: UUID, guidePort: Int): Unit = {
|
||||
valueToGuidePortMap.synchronized {
|
||||
valueToGuidePortMap +=(uuid -> guidePort)
|
||||
logInfo("New value registered with the Tracker " + valueToGuidePortMap)
|
||||
}
|
||||
}
|
||||
|
||||
def unregisterValue(uuid: UUID): Unit = {
|
||||
valueToGuidePortMap.synchronized {
|
||||
valueToGuidePortMap(uuid) = SourceInfo.TxOverGoToHDFS
|
||||
logInfo("Value unregistered from the Tracker " + valueToGuidePortMap)
|
||||
}
|
||||
}
|
||||
|
||||
class TrackMultipleValues
|
||||
extends Thread with Logging {
|
||||
override def run: Unit = {
|
||||
var threadPool = Utils.newDaemonCachedThreadPool()
|
||||
var serverSocket: ServerSocket = null
|
||||
|
||||
serverSocket = new ServerSocket(Broadcast.MasterTrackerPort)
|
||||
logInfo("TrackMultipleValues" + serverSocket)
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
var clientSocket: Socket = null
|
||||
try {
|
||||
serverSocket.setSoTimeout(Broadcast.TrackerSocketTimeout)
|
||||
clientSocket = serverSocket.accept
|
||||
} catch {
|
||||
case e: Exception => {
|
||||
logInfo("TrackMultipleValues Timeout. Stopping listening...")
|
||||
}
|
||||
}
|
||||
|
||||
if (clientSocket != null) {
|
||||
try {
|
||||
threadPool.execute(new Thread {
|
||||
override def run: Unit = {
|
||||
val oos = new ObjectOutputStream(clientSocket.getOutputStream)
|
||||
oos.flush()
|
||||
val ois = new ObjectInputStream(clientSocket.getInputStream)
|
||||
try {
|
||||
val uuid = ois.readObject.asInstanceOf[UUID]
|
||||
var guidePort =
|
||||
if (valueToGuidePortMap.contains(uuid)) {
|
||||
valueToGuidePortMap(uuid)
|
||||
} else SourceInfo.TxNotStartedRetry
|
||||
logInfo("TrackMultipleValues: Got new request: " + clientSocket + " for " + uuid + " : " + guidePort)
|
||||
oos.writeObject(guidePort)
|
||||
} catch {
|
||||
case e: Exception => {
|
||||
logInfo("TrackMultipleValues had a " + e)
|
||||
}
|
||||
} finally {
|
||||
ois.close()
|
||||
oos.close()
|
||||
clientSocket.close()
|
||||
}
|
||||
}
|
||||
})
|
||||
} catch {
|
||||
// In failure, close socket here; else, client thread will close
|
||||
case ioe: IOException => clientSocket.close()
|
||||
}
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
serverSocket.close()
|
||||
}
|
||||
|
||||
// Shutdown the thread pool
|
||||
threadPool.shutdown()
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,4 +1,6 @@
|
|||
package spark
|
||||
package spark.broadcast
|
||||
|
||||
import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}
|
||||
|
||||
import java.io._
|
||||
import java.net._
|
||||
|
@ -7,7 +9,7 @@ import java.util.UUID
|
|||
import org.apache.hadoop.conf.Configuration
|
||||
import org.apache.hadoop.fs.{FileSystem, Path, RawLocalFileSystem}
|
||||
|
||||
import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}
|
||||
import spark._
|
||||
|
||||
@serializable
|
||||
class DfsBroadcast[T](@transient var value_ : T, isLocal: Boolean)
|
41
core/src/main/scala/spark/broadcast/SourceInfo.scala
Normal file
41
core/src/main/scala/spark/broadcast/SourceInfo.scala
Normal file
|
@ -0,0 +1,41 @@
|
|||
package spark.broadcast
|
||||
|
||||
import java.util.BitSet
|
||||
|
||||
import spark._
|
||||
|
||||
/**
|
||||
* Used to keep and pass around information of peers involved in a broadcast
|
||||
*
|
||||
* CHANGED: Keep track of the blockSize for THIS broadcast variable.
|
||||
* Broadcast.BlockSize is expected to be updated across different broadcasts
|
||||
*/
|
||||
@serializable
|
||||
case class SourceInfo (val hostAddress: String,
|
||||
val listenPort: Int,
|
||||
val totalBlocks: Int = SourceInfo.UnusedParam,
|
||||
val totalBytes: Int = SourceInfo.UnusedParam,
|
||||
val blockSize: Int = Broadcast.BlockSize)
|
||||
extends Comparable[SourceInfo] with Logging {
|
||||
|
||||
var currentLeechers = 0
|
||||
var receptionFailed = false
|
||||
|
||||
var hasBlocks = 0
|
||||
var hasBlocksBitVector: BitSet = new BitSet (totalBlocks)
|
||||
|
||||
// Ascending sort based on leecher count
|
||||
def compareTo (o: SourceInfo): Int = (currentLeechers - o.currentLeechers)
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper Object of SourceInfo for its constants
|
||||
*/
|
||||
object SourceInfo {
|
||||
// Constants for special values of listenPort
|
||||
val TxNotStartedRetry = -1
|
||||
val TxOverGoToHDFS = 0
|
||||
// Other constants
|
||||
val StopBroadcast = -2
|
||||
val UnusedParam = 0
|
||||
}
|
807
core/src/main/scala/spark/broadcast/TreeBroadcast.scala
Normal file
807
core/src/main/scala/spark/broadcast/TreeBroadcast.scala
Normal file
|
@ -0,0 +1,807 @@
|
|||
package spark.broadcast
|
||||
|
||||
import java.io._
|
||||
import java.net._
|
||||
import java.util.{Comparator, Random, UUID}
|
||||
|
||||
import scala.collection.mutable.{ListBuffer, Map, Set}
|
||||
import scala.math
|
||||
|
||||
import spark._
|
||||
|
||||
@serializable
|
||||
class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean)
|
||||
extends Broadcast[T] with Logging {
|
||||
|
||||
def value = value_
|
||||
|
||||
TreeBroadcast.synchronized {
|
||||
TreeBroadcast.values.put(uuid, value_)
|
||||
}
|
||||
|
||||
@transient var arrayOfBlocks: Array[BroadcastBlock] = null
|
||||
@transient var totalBytes = -1
|
||||
@transient var totalBlocks = -1
|
||||
@transient var hasBlocks = 0
|
||||
// CHANGED: BlockSize in the Broadcast object is expected to change over time
|
||||
@transient var blockSize = Broadcast.BlockSize
|
||||
|
||||
@transient var listenPortLock = new Object
|
||||
@transient var guidePortLock = new Object
|
||||
@transient var totalBlocksLock = new Object
|
||||
@transient var hasBlocksLock = new Object
|
||||
|
||||
@transient var listOfSources = ListBuffer[SourceInfo]()
|
||||
|
||||
@transient var serveMR: ServeMultipleRequests = null
|
||||
@transient var guideMR: GuideMultipleRequests = null
|
||||
|
||||
@transient var hostAddress = Utils.localIpAddress
|
||||
@transient var listenPort = -1
|
||||
@transient var guidePort = -1
|
||||
|
||||
@transient var hasCopyInHDFS = false
|
||||
@transient var stopBroadcast = false
|
||||
|
||||
// Must call this after all the variables have been created/initialized
|
||||
if (!isLocal) {
|
||||
sendBroadcast
|
||||
}
|
||||
|
||||
def sendBroadcast(): Unit = {
|
||||
logInfo("Local host address: " + hostAddress)
|
||||
|
||||
// Store a persistent copy in HDFS
|
||||
// TODO: Turned OFF for now
|
||||
// val out = new ObjectOutputStream(DfsBroadcast.openFileForWriting(uuid))
|
||||
// out.writeObject(value_)
|
||||
// out.close()
|
||||
// TODO: Fix this at some point
|
||||
hasCopyInHDFS = true
|
||||
|
||||
// Create a variableInfo object and store it in valueInfos
|
||||
var variableInfo = Broadcast.blockifyObject(value_)
|
||||
|
||||
// Prepare the value being broadcasted
|
||||
// TODO: Refactoring and clean-up required here
|
||||
arrayOfBlocks = variableInfo.arrayOfBlocks
|
||||
totalBytes = variableInfo.totalBytes
|
||||
totalBlocks = variableInfo.totalBlocks
|
||||
hasBlocks = variableInfo.totalBlocks
|
||||
|
||||
guideMR = new GuideMultipleRequests
|
||||
guideMR.setDaemon(true)
|
||||
guideMR.start
|
||||
logInfo("GuideMultipleRequests started...")
|
||||
|
||||
// Must always come AFTER guideMR is created
|
||||
while (guidePort == -1) {
|
||||
guidePortLock.synchronized {
|
||||
guidePortLock.wait
|
||||
}
|
||||
}
|
||||
|
||||
serveMR = new ServeMultipleRequests
|
||||
serveMR.setDaemon(true)
|
||||
serveMR.start
|
||||
logInfo("ServeMultipleRequests started...")
|
||||
|
||||
// Must always come AFTER serveMR is created
|
||||
while (listenPort == -1) {
|
||||
listenPortLock.synchronized {
|
||||
listenPortLock.wait
|
||||
}
|
||||
}
|
||||
|
||||
// Must always come AFTER listenPort is created
|
||||
val masterSource =
|
||||
SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes, blockSize)
|
||||
listOfSources += masterSource
|
||||
|
||||
// Register with the Tracker
|
||||
TreeBroadcast.registerValue(uuid, guidePort)
|
||||
}
|
||||
|
||||
private def readObject(in: ObjectInputStream): Unit = {
|
||||
in.defaultReadObject
|
||||
TreeBroadcast.synchronized {
|
||||
val cachedVal = TreeBroadcast.values.get(uuid)
|
||||
if (cachedVal != null) {
|
||||
value_ = cachedVal.asInstanceOf[T]
|
||||
} else {
|
||||
// Initializing everything because Master will only send null/0 values
|
||||
initializeSlaveVariables
|
||||
|
||||
logInfo("Local host address: " + hostAddress)
|
||||
|
||||
serveMR = new ServeMultipleRequests
|
||||
serveMR.setDaemon(true)
|
||||
serveMR.start
|
||||
logInfo("ServeMultipleRequests started...")
|
||||
|
||||
val start = System.nanoTime
|
||||
|
||||
val receptionSucceeded = receiveBroadcast(uuid)
|
||||
// If does not succeed, then get from HDFS copy
|
||||
if (receptionSucceeded) {
|
||||
value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
|
||||
TreeBroadcast.values.put(uuid, value_)
|
||||
} else {
|
||||
val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid))
|
||||
value_ = fileIn.readObject.asInstanceOf[T]
|
||||
TreeBroadcast.values.put(uuid, value_)
|
||||
fileIn.close()
|
||||
}
|
||||
|
||||
val time = (System.nanoTime - start) / 1e9
|
||||
logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private def initializeSlaveVariables: Unit = {
|
||||
arrayOfBlocks = null
|
||||
totalBytes = -1
|
||||
totalBlocks = -1
|
||||
hasBlocks = 0
|
||||
blockSize = -1
|
||||
|
||||
listenPortLock = new Object
|
||||
totalBlocksLock = new Object
|
||||
hasBlocksLock = new Object
|
||||
|
||||
serveMR = null
|
||||
|
||||
hostAddress = Utils.localIpAddress
|
||||
listenPort = -1
|
||||
|
||||
stopBroadcast = false
|
||||
}
|
||||
|
||||
def getMasterListenPort(variableUUID: UUID): Int = {
|
||||
var clientSocketToTracker: Socket = null
|
||||
var oosTracker: ObjectOutputStream = null
|
||||
var oisTracker: ObjectInputStream = null
|
||||
|
||||
var masterListenPort: Int = SourceInfo.TxOverGoToHDFS
|
||||
|
||||
var retriesLeft = Broadcast.MaxRetryCount
|
||||
do {
|
||||
try {
|
||||
// Connect to the tracker to find out the guide
|
||||
clientSocketToTracker =
|
||||
new Socket(Broadcast.MasterHostAddress, Broadcast.MasterTrackerPort)
|
||||
oosTracker =
|
||||
new ObjectOutputStream(clientSocketToTracker.getOutputStream)
|
||||
oosTracker.flush()
|
||||
oisTracker =
|
||||
new ObjectInputStream(clientSocketToTracker.getInputStream)
|
||||
|
||||
// Send UUID and receive masterListenPort
|
||||
oosTracker.writeObject(uuid)
|
||||
oosTracker.flush()
|
||||
masterListenPort = oisTracker.readObject.asInstanceOf[Int]
|
||||
} catch {
|
||||
case e: Exception => {
|
||||
logInfo("getMasterListenPort had a " + e)
|
||||
}
|
||||
} finally {
|
||||
if (oisTracker != null) {
|
||||
oisTracker.close()
|
||||
}
|
||||
if (oosTracker != null) {
|
||||
oosTracker.close()
|
||||
}
|
||||
if (clientSocketToTracker != null) {
|
||||
clientSocketToTracker.close()
|
||||
}
|
||||
}
|
||||
retriesLeft -= 1
|
||||
|
||||
Thread.sleep(TreeBroadcast.ranGen.nextInt(
|
||||
Broadcast.MaxKnockInterval - Broadcast.MinKnockInterval) +
|
||||
Broadcast.MinKnockInterval)
|
||||
|
||||
} while (retriesLeft > 0 && masterListenPort == SourceInfo.TxNotStartedRetry)
|
||||
|
||||
logInfo("Got this guidePort from Tracker: " + masterListenPort)
|
||||
return masterListenPort
|
||||
}
|
||||
|
||||
def receiveBroadcast(variableUUID: UUID): Boolean = {
|
||||
val masterListenPort = getMasterListenPort(variableUUID)
|
||||
|
||||
if (masterListenPort == SourceInfo.TxOverGoToHDFS ||
|
||||
masterListenPort == SourceInfo.TxNotStartedRetry) {
|
||||
// TODO: SourceInfo.TxNotStartedRetry is not really in use because we go
|
||||
// to HDFS anyway when receiveBroadcast returns false
|
||||
return false
|
||||
}
|
||||
|
||||
// Wait until hostAddress and listenPort are created by the
|
||||
// ServeMultipleRequests thread
|
||||
while (listenPort == -1) {
|
||||
listenPortLock.synchronized {
|
||||
listenPortLock.wait
|
||||
}
|
||||
}
|
||||
|
||||
var clientSocketToMaster: Socket = null
|
||||
var oosMaster: ObjectOutputStream = null
|
||||
var oisMaster: ObjectInputStream = null
|
||||
|
||||
// Connect and receive broadcast from the specified source, retrying the
|
||||
// specified number of times in case of failures
|
||||
var retriesLeft = Broadcast.MaxRetryCount
|
||||
do {
|
||||
// Connect to Master and send this worker's Information
|
||||
clientSocketToMaster =
|
||||
new Socket(Broadcast.MasterHostAddress, masterListenPort)
|
||||
// TODO: Guiding object connection is reusable
|
||||
oosMaster =
|
||||
new ObjectOutputStream(clientSocketToMaster.getOutputStream)
|
||||
oosMaster.flush()
|
||||
oisMaster =
|
||||
new ObjectInputStream(clientSocketToMaster.getInputStream)
|
||||
|
||||
logInfo("Connected to Master's guiding object")
|
||||
|
||||
// Send local source information
|
||||
oosMaster.writeObject(SourceInfo(hostAddress, listenPort))
|
||||
oosMaster.flush()
|
||||
|
||||
// Receive source information from Master
|
||||
var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo]
|
||||
totalBlocks = sourceInfo.totalBlocks
|
||||
arrayOfBlocks = new Array[BroadcastBlock](totalBlocks)
|
||||
totalBlocksLock.synchronized {
|
||||
totalBlocksLock.notifyAll
|
||||
}
|
||||
totalBytes = sourceInfo.totalBytes
|
||||
blockSize = sourceInfo.blockSize
|
||||
|
||||
logInfo("Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort)
|
||||
|
||||
val start = System.nanoTime
|
||||
val receptionSucceeded = receiveSingleTransmission(sourceInfo)
|
||||
val time = (System.nanoTime - start) / 1e9
|
||||
|
||||
// Updating some statistics in sourceInfo. Master will be using them later
|
||||
if (!receptionSucceeded) {
|
||||
sourceInfo.receptionFailed = true
|
||||
}
|
||||
|
||||
// Send back statistics to the Master
|
||||
oosMaster.writeObject(sourceInfo)
|
||||
|
||||
if (oisMaster != null) {
|
||||
oisMaster.close()
|
||||
}
|
||||
if (oosMaster != null) {
|
||||
oosMaster.close()
|
||||
}
|
||||
if (clientSocketToMaster != null) {
|
||||
clientSocketToMaster.close()
|
||||
}
|
||||
|
||||
retriesLeft -= 1
|
||||
} while (retriesLeft > 0 && hasBlocks < totalBlocks)
|
||||
|
||||
return (hasBlocks == totalBlocks)
|
||||
}
|
||||
|
||||
// Tries to receive broadcast from the source and returns Boolean status.
|
||||
// This might be called multiple times to retry a defined number of times.
|
||||
private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = {
|
||||
var clientSocketToSource: Socket = null
|
||||
var oosSource: ObjectOutputStream = null
|
||||
var oisSource: ObjectInputStream = null
|
||||
|
||||
var receptionSucceeded = false
|
||||
try {
|
||||
// Connect to the source to get the object itself
|
||||
clientSocketToSource =
|
||||
new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
|
||||
oosSource =
|
||||
new ObjectOutputStream(clientSocketToSource.getOutputStream)
|
||||
oosSource.flush()
|
||||
oisSource =
|
||||
new ObjectInputStream(clientSocketToSource.getInputStream)
|
||||
|
||||
logInfo("Inside receiveSingleTransmission")
|
||||
logInfo("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks)
|
||||
|
||||
// Send the range
|
||||
oosSource.writeObject((hasBlocks, totalBlocks))
|
||||
oosSource.flush()
|
||||
|
||||
for (i <- hasBlocks until totalBlocks) {
|
||||
val recvStartTime = System.currentTimeMillis
|
||||
val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
|
||||
val receptionTime = (System.currentTimeMillis - recvStartTime)
|
||||
|
||||
logInfo("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.")
|
||||
|
||||
arrayOfBlocks(hasBlocks) = bcBlock
|
||||
hasBlocks += 1
|
||||
// Set to true if at least one block is received
|
||||
receptionSucceeded = true
|
||||
hasBlocksLock.synchronized {
|
||||
hasBlocksLock.notifyAll
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
case e: Exception => {
|
||||
logInfo("receiveSingleTransmission had a " + e)
|
||||
}
|
||||
} finally {
|
||||
if (oisSource != null) {
|
||||
oisSource.close()
|
||||
}
|
||||
if (oosSource != null) {
|
||||
oosSource.close()
|
||||
}
|
||||
if (clientSocketToSource != null) {
|
||||
clientSocketToSource.close()
|
||||
}
|
||||
}
|
||||
|
||||
return receptionSucceeded
|
||||
}
|
||||
|
||||
class GuideMultipleRequests
|
||||
extends Thread with Logging {
|
||||
// Keep track of sources that have completed reception
|
||||
private var setOfCompletedSources = Set[SourceInfo]()
|
||||
|
||||
override def run: Unit = {
|
||||
var threadPool = Utils.newDaemonCachedThreadPool()
|
||||
var serverSocket: ServerSocket = null
|
||||
|
||||
serverSocket = new ServerSocket(0)
|
||||
guidePort = serverSocket.getLocalPort
|
||||
logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort)
|
||||
|
||||
guidePortLock.synchronized {
|
||||
guidePortLock.notifyAll
|
||||
}
|
||||
|
||||
try {
|
||||
// Don't stop until there is a copy in HDFS
|
||||
while (!stopBroadcast || !hasCopyInHDFS) {
|
||||
var clientSocket: Socket = null
|
||||
try {
|
||||
serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout)
|
||||
clientSocket = serverSocket.accept
|
||||
} catch {
|
||||
case e: Exception => {
|
||||
logInfo("GuideMultipleRequests Timeout.")
|
||||
|
||||
// Stop broadcast if at least one worker has connected and
|
||||
// everyone connected so far are done. Comparing with
|
||||
// listOfSources.size - 1, because it includes the Guide itself
|
||||
if (listOfSources.size > 1 &&
|
||||
setOfCompletedSources.size == listOfSources.size - 1) {
|
||||
stopBroadcast = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if (clientSocket != null) {
|
||||
logInfo("Guide: Accepted new client connection: " + clientSocket)
|
||||
try {
|
||||
threadPool.execute(new GuideSingleRequest(clientSocket))
|
||||
} catch {
|
||||
// In failure, close() the socket here; else, the thread will close() it
|
||||
case ioe: IOException => clientSocket.close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logInfo("Sending stopBroadcast notifications...")
|
||||
sendStopBroadcastNotifications
|
||||
|
||||
TreeBroadcast.unregisterValue(uuid)
|
||||
} finally {
|
||||
if (serverSocket != null) {
|
||||
logInfo("GuideMultipleRequests now stopping...")
|
||||
serverSocket.close()
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown the thread pool
|
||||
threadPool.shutdown
|
||||
}
|
||||
|
||||
private def sendStopBroadcastNotifications: Unit = {
|
||||
listOfSources.synchronized {
|
||||
var listIter = listOfSources.iterator
|
||||
while (listIter.hasNext) {
|
||||
var sourceInfo = listIter.next
|
||||
|
||||
var guideSocketToSource: Socket = null
|
||||
var gosSource: ObjectOutputStream = null
|
||||
var gisSource: ObjectInputStream = null
|
||||
|
||||
try {
|
||||
// Connect to the source
|
||||
guideSocketToSource =
|
||||
new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
|
||||
gosSource =
|
||||
new ObjectOutputStream(guideSocketToSource.getOutputStream)
|
||||
gosSource.flush()
|
||||
gisSource =
|
||||
new ObjectInputStream(guideSocketToSource.getInputStream)
|
||||
|
||||
// Send stopBroadcast signal. Range = SourceInfo.StopBroadcast*2
|
||||
gosSource.writeObject((SourceInfo.StopBroadcast,
|
||||
SourceInfo.StopBroadcast))
|
||||
gosSource.flush()
|
||||
} catch {
|
||||
case e: Exception => {
|
||||
logInfo("sendStopBroadcastNotifications had a " + e)
|
||||
}
|
||||
} finally {
|
||||
if (gisSource != null) {
|
||||
gisSource.close()
|
||||
}
|
||||
if (gosSource != null) {
|
||||
gosSource.close()
|
||||
}
|
||||
if (guideSocketToSource != null) {
|
||||
guideSocketToSource.close()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class GuideSingleRequest(val clientSocket: Socket)
|
||||
extends Thread with Logging {
|
||||
private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
|
||||
oos.flush()
|
||||
private val ois = new ObjectInputStream(clientSocket.getInputStream)
|
||||
|
||||
private var selectedSourceInfo: SourceInfo = null
|
||||
private var thisWorkerInfo:SourceInfo = null
|
||||
|
||||
override def run: Unit = {
|
||||
try {
|
||||
logInfo("new GuideSingleRequest is running")
|
||||
// Connecting worker is sending in its hostAddress and listenPort it will
|
||||
// be listening to. Other fields are invalid (SourceInfo.UnusedParam)
|
||||
var sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
|
||||
|
||||
listOfSources.synchronized {
|
||||
// Select a suitable source and send it back to the worker
|
||||
selectedSourceInfo = selectSuitableSource(sourceInfo)
|
||||
logInfo("Sending selectedSourceInfo: " + selectedSourceInfo)
|
||||
oos.writeObject(selectedSourceInfo)
|
||||
oos.flush()
|
||||
|
||||
// Add this new (if it can finish) source to the list of sources
|
||||
thisWorkerInfo = SourceInfo(sourceInfo.hostAddress,
|
||||
sourceInfo.listenPort, totalBlocks, totalBytes, blockSize)
|
||||
logInfo("Adding possible new source to listOfSources: " + thisWorkerInfo)
|
||||
listOfSources += thisWorkerInfo
|
||||
}
|
||||
|
||||
// Wait till the whole transfer is done. Then receive and update source
|
||||
// statistics in listOfSources
|
||||
sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
|
||||
|
||||
listOfSources.synchronized {
|
||||
// This should work since SourceInfo is a case class
|
||||
assert(listOfSources.contains(selectedSourceInfo))
|
||||
|
||||
// Remove first
|
||||
listOfSources = listOfSources - selectedSourceInfo
|
||||
// TODO: Removing a source based on just one failure notification!
|
||||
|
||||
// Update sourceInfo and put it back in, IF reception succeeded
|
||||
if (!sourceInfo.receptionFailed) {
|
||||
// Add thisWorkerInfo to sources that have completed reception
|
||||
setOfCompletedSources.synchronized {
|
||||
setOfCompletedSources += thisWorkerInfo
|
||||
}
|
||||
|
||||
selectedSourceInfo.currentLeechers -= 1
|
||||
|
||||
// Put it back
|
||||
listOfSources += selectedSourceInfo
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// If something went wrong, e.g., the worker at the other end died etc.
|
||||
// then close() everything up
|
||||
case e: Exception => {
|
||||
// Assuming that exception caused due to receiver worker failure.
|
||||
// Remove failed worker from listOfSources and update leecherCount of
|
||||
// corresponding source worker
|
||||
listOfSources.synchronized {
|
||||
if (selectedSourceInfo != null) {
|
||||
// Remove first
|
||||
listOfSources = listOfSources - selectedSourceInfo
|
||||
// Update leecher count and put it back in
|
||||
selectedSourceInfo.currentLeechers -= 1
|
||||
listOfSources += selectedSourceInfo
|
||||
}
|
||||
|
||||
// Remove thisWorkerInfo
|
||||
if (listOfSources != null) {
|
||||
listOfSources = listOfSources - thisWorkerInfo
|
||||
}
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
ois.close()
|
||||
oos.close()
|
||||
clientSocket.close()
|
||||
}
|
||||
}
|
||||
|
||||
// FIXME: Caller must have a synchronized block on listOfSources
|
||||
// FIXME: If a worker fails to get the broadcasted variable from a source
|
||||
// and comes back to the Master, this function might choose the worker
|
||||
// itself as a source to create a dependency cycle (this worker was put
|
||||
// into listOfSources as a streming source when it first arrived). The
|
||||
// length of this cycle can be arbitrarily long.
|
||||
private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = {
|
||||
// Select one with the most leechers. This will level-wise fill the tree
|
||||
|
||||
var maxLeechers = -1
|
||||
var selectedSource: SourceInfo = null
|
||||
|
||||
listOfSources.foreach { source =>
|
||||
if (source != skipSourceInfo &&
|
||||
source.currentLeechers < Broadcast.MaxDegree &&
|
||||
source.currentLeechers > maxLeechers) {
|
||||
selectedSource = source
|
||||
maxLeechers = source.currentLeechers
|
||||
}
|
||||
}
|
||||
|
||||
// Update leecher count
|
||||
selectedSource.currentLeechers += 1
|
||||
|
||||
return selectedSource
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class ServeMultipleRequests
|
||||
extends Thread with Logging {
|
||||
override def run: Unit = {
|
||||
var threadPool = Utils.newDaemonCachedThreadPool()
|
||||
var serverSocket: ServerSocket = null
|
||||
|
||||
serverSocket = new ServerSocket(0)
|
||||
listenPort = serverSocket.getLocalPort
|
||||
logInfo("ServeMultipleRequests started with " + serverSocket)
|
||||
|
||||
listenPortLock.synchronized {
|
||||
listenPortLock.notifyAll
|
||||
}
|
||||
|
||||
try {
|
||||
while (!stopBroadcast) {
|
||||
var clientSocket: Socket = null
|
||||
try {
|
||||
serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout)
|
||||
clientSocket = serverSocket.accept
|
||||
} catch {
|
||||
case e: Exception => {
|
||||
logInfo("ServeMultipleRequests Timeout.")
|
||||
}
|
||||
}
|
||||
if (clientSocket != null) {
|
||||
logInfo("Serve: Accepted new client connection: " + clientSocket)
|
||||
try {
|
||||
threadPool.execute(new ServeSingleRequest(clientSocket))
|
||||
} catch {
|
||||
// In failure, close() socket here; else, the thread will close() it
|
||||
case ioe: IOException => clientSocket.close()
|
||||
}
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
if (serverSocket != null) {
|
||||
logInfo("ServeMultipleRequests now stopping...")
|
||||
serverSocket.close()
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown the thread pool
|
||||
threadPool.shutdown
|
||||
}
|
||||
|
||||
class ServeSingleRequest(val clientSocket: Socket)
|
||||
extends Thread with Logging {
|
||||
private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
|
||||
oos.flush()
|
||||
private val ois = new ObjectInputStream(clientSocket.getInputStream)
|
||||
|
||||
private var sendFrom = 0
|
||||
private var sendUntil = totalBlocks
|
||||
|
||||
override def run: Unit = {
|
||||
try {
|
||||
logInfo("new ServeSingleRequest is running")
|
||||
|
||||
// Receive range to send
|
||||
var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)]
|
||||
sendFrom = rangeToSend._1
|
||||
sendUntil = rangeToSend._2
|
||||
|
||||
if (sendFrom == SourceInfo.StopBroadcast &&
|
||||
sendUntil == SourceInfo.StopBroadcast) {
|
||||
stopBroadcast = true
|
||||
} else {
|
||||
// Carry on
|
||||
sendObject
|
||||
}
|
||||
} catch {
|
||||
// If something went wrong, e.g., the worker at the other end died etc.
|
||||
// then close() everything up
|
||||
case e: Exception => {
|
||||
logInfo("ServeSingleRequest had a " + e)
|
||||
}
|
||||
} finally {
|
||||
logInfo("ServeSingleRequest is closing streams and sockets")
|
||||
ois.close()
|
||||
oos.close()
|
||||
clientSocket.close()
|
||||
}
|
||||
}
|
||||
|
||||
private def sendObject: Unit = {
|
||||
// Wait till receiving the SourceInfo from Master
|
||||
while (totalBlocks == -1) {
|
||||
totalBlocksLock.synchronized {
|
||||
totalBlocksLock.wait
|
||||
}
|
||||
}
|
||||
|
||||
for (i <- sendFrom until sendUntil) {
|
||||
while (i == hasBlocks) {
|
||||
hasBlocksLock.synchronized {
|
||||
hasBlocksLock.wait
|
||||
}
|
||||
}
|
||||
try {
|
||||
oos.writeObject(arrayOfBlocks(i))
|
||||
oos.flush()
|
||||
} catch {
|
||||
case e: Exception => {
|
||||
logInfo("sendObject had a " + e)
|
||||
}
|
||||
}
|
||||
logInfo("Sent block: " + i + " to " + clientSocket)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class TreeBroadcastFactory
|
||||
extends BroadcastFactory {
|
||||
def initialize(isMaster: Boolean) = TreeBroadcast.initialize(isMaster)
|
||||
def newBroadcast[T](value_ : T, isLocal: Boolean) =
|
||||
new TreeBroadcast[T](value_, isLocal)
|
||||
}
|
||||
|
||||
private object TreeBroadcast
|
||||
extends Logging {
|
||||
val values = SparkEnv.get.cache.newKeySpace()
|
||||
|
||||
var valueToGuidePortMap = Map[UUID, Int]()
|
||||
|
||||
// Random number generator
|
||||
var ranGen = new Random
|
||||
|
||||
private var initialized = false
|
||||
private var isMaster_ = false
|
||||
|
||||
private var trackMV: TrackMultipleValues = null
|
||||
|
||||
private var MaxDegree_ : Int = 2
|
||||
|
||||
def initialize(isMaster__ : Boolean): Unit = {
|
||||
synchronized {
|
||||
if (!initialized) {
|
||||
isMaster_ = isMaster__
|
||||
|
||||
if (isMaster) {
|
||||
trackMV = new TrackMultipleValues
|
||||
trackMV.setDaemon(true)
|
||||
trackMV.start
|
||||
// TODO: Logging the following line makes the Spark framework ID not
|
||||
// getting logged, cause it calls logInfo before log4j is initialized
|
||||
logInfo("TrackMultipleValues started...")
|
||||
}
|
||||
|
||||
// Initialize DfsBroadcast to be used for broadcast variable persistence
|
||||
DfsBroadcast.initialize
|
||||
|
||||
initialized = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def isMaster = isMaster_
|
||||
|
||||
def registerValue(uuid: UUID, guidePort: Int): Unit = {
|
||||
valueToGuidePortMap.synchronized {
|
||||
valueToGuidePortMap += (uuid -> guidePort)
|
||||
logInfo("New value registered with the Tracker " + valueToGuidePortMap)
|
||||
}
|
||||
}
|
||||
|
||||
def unregisterValue(uuid: UUID): Unit = {
|
||||
valueToGuidePortMap.synchronized {
|
||||
valueToGuidePortMap(uuid) = SourceInfo.TxOverGoToHDFS
|
||||
logInfo("Value unregistered from the Tracker " + valueToGuidePortMap)
|
||||
}
|
||||
}
|
||||
|
||||
class TrackMultipleValues
|
||||
extends Thread with Logging {
|
||||
override def run: Unit = {
|
||||
var threadPool = Utils.newDaemonCachedThreadPool()
|
||||
var serverSocket: ServerSocket = null
|
||||
|
||||
serverSocket = new ServerSocket(Broadcast.MasterTrackerPort)
|
||||
logInfo("TrackMultipleValues" + serverSocket)
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
var clientSocket: Socket = null
|
||||
try {
|
||||
serverSocket.setSoTimeout(Broadcast.TrackerSocketTimeout)
|
||||
clientSocket = serverSocket.accept
|
||||
} catch {
|
||||
case e: Exception => {
|
||||
logInfo("TrackMultipleValues Timeout. Stopping listening...")
|
||||
}
|
||||
}
|
||||
|
||||
if (clientSocket != null) {
|
||||
try {
|
||||
threadPool.execute(new Thread {
|
||||
override def run: Unit = {
|
||||
val oos = new ObjectOutputStream(clientSocket.getOutputStream)
|
||||
oos.flush()
|
||||
val ois = new ObjectInputStream(clientSocket.getInputStream)
|
||||
try {
|
||||
val uuid = ois.readObject.asInstanceOf[UUID]
|
||||
var guidePort =
|
||||
if (valueToGuidePortMap.contains(uuid)) {
|
||||
valueToGuidePortMap(uuid)
|
||||
} else SourceInfo.TxNotStartedRetry
|
||||
logInfo("TrackMultipleValues: Got new request: " + clientSocket + " for " + uuid + " : " + guidePort)
|
||||
oos.writeObject(guidePort)
|
||||
} catch {
|
||||
case e: Exception => {
|
||||
logInfo("TrackMultipleValues had a " + e)
|
||||
}
|
||||
} finally {
|
||||
ois.close()
|
||||
oos.close()
|
||||
clientSocket.close()
|
||||
}
|
||||
}
|
||||
})
|
||||
} catch {
|
||||
// In failure, close() socket here; else, client thread will close()
|
||||
case ioe: IOException => clientSocket.close()
|
||||
}
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
serverSocket.close()
|
||||
}
|
||||
|
||||
// Shutdown the thread pool
|
||||
threadPool.shutdown
|
||||
}
|
||||
}
|
||||
}
|
|
@ -5,9 +5,10 @@ import spark.SparkContext
|
|||
object BroadcastTest {
|
||||
def main(args: Array[String]) {
|
||||
if (args.length == 0) {
|
||||
System.err.println("Usage: BroadcastTest <host> [<slices>]")
|
||||
System.err.println("Usage: BroadcastTest <host> [<slices>] [numElem]")
|
||||
System.exit(1)
|
||||
}
|
||||
|
||||
val spark = new SparkContext(args(0), "Broadcast Test")
|
||||
val slices = if (args.length > 1) args(1).toInt else 2
|
||||
val num = if (args.length > 2) args(2).toInt else 1000000
|
||||
|
@ -16,14 +17,8 @@ object BroadcastTest {
|
|||
for (i <- 0 until arr1.length)
|
||||
arr1(i) = i
|
||||
|
||||
// var arr2 = new Array[Int](num * 2)
|
||||
// for (i <- 0 until arr2.length)
|
||||
// arr2(i) = i
|
||||
|
||||
val barr1 = spark.broadcast(arr1)
|
||||
// val barr2 = spark.broadcast(arr2)
|
||||
spark.parallelize(1 to 10, slices).foreach {
|
||||
// i => println(barr1.value.size + barr2.value.size)
|
||||
i => println(barr1.value.size)
|
||||
}
|
||||
}
|
||||
|
|
37
examples/src/main/scala/spark/examples/GroupByTest.scala
Normal file
37
examples/src/main/scala/spark/examples/GroupByTest.scala
Normal file
|
@ -0,0 +1,37 @@
|
|||
package spark.examples
|
||||
|
||||
import spark.SparkContext
|
||||
import spark.SparkContext._
|
||||
import java.util.Random
|
||||
|
||||
object GroupByTest {
|
||||
def main(args: Array[String]) {
|
||||
if (args.length == 0) {
|
||||
System.err.println("Usage: GroupByTest <host> [numMappers] [numKVPairs] [KeySize] [numReducers]")
|
||||
System.exit(1)
|
||||
}
|
||||
|
||||
var numMappers = if (args.length > 1) args(1).toInt else 2
|
||||
var numKVPairs = if (args.length > 2) args(2).toInt else 1000
|
||||
var valSize = if (args.length > 3) args(3).toInt else 1000
|
||||
var numReducers = if (args.length > 4) args(4).toInt else numMappers
|
||||
|
||||
val sc = new SparkContext(args(0), "GroupBy Test")
|
||||
|
||||
val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p =>
|
||||
val ranGen = new Random
|
||||
var arr1 = new Array[(Int, Array[Byte])](numKVPairs)
|
||||
for (i <- 0 until numKVPairs) {
|
||||
val byteArr = new Array[Byte](valSize)
|
||||
ranGen.nextBytes(byteArr)
|
||||
arr1(i) = (ranGen.nextInt(Int.MaxValue), byteArr)
|
||||
}
|
||||
arr1
|
||||
}.cache
|
||||
// Enforce that everything has been calculated and in cache
|
||||
pairs1.count
|
||||
|
||||
println(pairs1.groupByKey(numReducers).count)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
package spark.examples
|
||||
|
||||
import spark.SparkContext
|
||||
|
||||
object MultiBroadcastTest {
|
||||
def main(args: Array[String]) {
|
||||
if (args.length == 0) {
|
||||
System.err.println("Usage: BroadcastTest <host> [<slices>] [numElem]")
|
||||
System.exit(1)
|
||||
}
|
||||
|
||||
val spark = new SparkContext(args(0), "Broadcast Test")
|
||||
val slices = if (args.length > 1) args(1).toInt else 2
|
||||
val num = if (args.length > 2) args(2).toInt else 1000000
|
||||
|
||||
var arr1 = new Array[Int](num)
|
||||
for (i <- 0 until arr1.length)
|
||||
arr1(i) = i
|
||||
|
||||
var arr2 = new Array[Int](num)
|
||||
for (i <- 0 until arr2.length)
|
||||
arr2(i) = i
|
||||
|
||||
val barr1 = spark.broadcast(arr1)
|
||||
val barr2 = spark.broadcast(arr2)
|
||||
spark.parallelize(1 to 10, slices).foreach {
|
||||
i => println(barr1.value.size + barr2.value.size)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,51 @@
|
|||
package spark.examples
|
||||
|
||||
import spark.SparkContext
|
||||
import spark.SparkContext._
|
||||
import java.util.Random
|
||||
|
||||
object SimpleSkewedGroupByTest {
|
||||
def main(args: Array[String]) {
|
||||
if (args.length == 0) {
|
||||
System.err.println("Usage: SimpleSkewedGroupByTest <host> " +
|
||||
"[numMappers] [numKVPairs] [valSize] [numReducers] [ratio]")
|
||||
System.exit(1)
|
||||
}
|
||||
|
||||
var numMappers = if (args.length > 1) args(1).toInt else 2
|
||||
var numKVPairs = if (args.length > 2) args(2).toInt else 1000
|
||||
var valSize = if (args.length > 3) args(3).toInt else 1000
|
||||
var numReducers = if (args.length > 4) args(4).toInt else numMappers
|
||||
var ratio = if (args.length > 5) args(5).toInt else 5.0
|
||||
|
||||
val sc = new SparkContext(args(0), "GroupBy Test")
|
||||
|
||||
val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p =>
|
||||
val ranGen = new Random
|
||||
var result = new Array[(Int, Array[Byte])](numKVPairs)
|
||||
for (i <- 0 until numKVPairs) {
|
||||
val byteArr = new Array[Byte](valSize)
|
||||
ranGen.nextBytes(byteArr)
|
||||
val offset = ranGen.nextInt(1000) * numReducers
|
||||
if (ranGen.nextDouble < ratio / (numReducers + ratio - 1)) {
|
||||
// give ratio times higher chance of generating key 0 (for reducer 0)
|
||||
result(i) = (offset, byteArr)
|
||||
} else {
|
||||
// generate a key for one of the other reducers
|
||||
val key = 1 + ranGen.nextInt(numReducers-1) + offset
|
||||
result(i) = (key, byteArr)
|
||||
}
|
||||
}
|
||||
result
|
||||
}.cache
|
||||
// Enforce that everything has been calculated and in cache
|
||||
pairs1.count
|
||||
|
||||
println("RESULT: " + pairs1.groupByKey(numReducers).count)
|
||||
// Print how many keys each reducer got (for debugging)
|
||||
//println("RESULT: " + pairs1.groupByKey(numReducers)
|
||||
// .map{case (k,v) => (k, v.size)}
|
||||
// .collectAsMap)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
package spark.examples
|
||||
|
||||
import spark.SparkContext
|
||||
import spark.SparkContext._
|
||||
import java.util.Random
|
||||
|
||||
object SkewedGroupByTest {
|
||||
def main(args: Array[String]) {
|
||||
if (args.length == 0) {
|
||||
System.err.println("Usage: GroupByTest <host> [numMappers] [numKVPairs] [KeySize] [numReducers]")
|
||||
System.exit(1)
|
||||
}
|
||||
|
||||
var numMappers = if (args.length > 1) args(1).toInt else 2
|
||||
var numKVPairs = if (args.length > 2) args(2).toInt else 1000
|
||||
var valSize = if (args.length > 3) args(3).toInt else 1000
|
||||
var numReducers = if (args.length > 4) args(4).toInt else numMappers
|
||||
|
||||
val sc = new SparkContext(args(0), "GroupBy Test")
|
||||
|
||||
val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p =>
|
||||
val ranGen = new Random
|
||||
|
||||
// map output sizes lineraly increase from the 1st to the last
|
||||
numKVPairs = (1. * (p + 1) / numMappers * numKVPairs).toInt
|
||||
|
||||
var arr1 = new Array[(Int, Array[Byte])](numKVPairs)
|
||||
for (i <- 0 until numKVPairs) {
|
||||
val byteArr = new Array[Byte](valSize)
|
||||
ranGen.nextBytes(byteArr)
|
||||
arr1(i) = (ranGen.nextInt(Int.MaxValue), byteArr)
|
||||
}
|
||||
arr1
|
||||
}.cache
|
||||
// Enforce that everything has been calculated and in cache
|
||||
pairs1.count
|
||||
|
||||
println(pairs1.groupByKey(numReducers).count)
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in a new issue