Merge branch 'mos-bt'

This merge keeps only the broadcast work in mos-bt because the structure
of shuffle has changed with the new RDD design. We still need some kind
of parallel shuffle but that will be added later.

Conflicts:
	core/src/main/scala/spark/BitTorrentBroadcast.scala
	core/src/main/scala/spark/ChainedBroadcast.scala
	core/src/main/scala/spark/RDD.scala
	core/src/main/scala/spark/SparkContext.scala
	core/src/main/scala/spark/Utils.scala
	core/src/main/scala/spark/shuffle/BasicLocalFileShuffle.scala
	core/src/main/scala/spark/shuffle/DfsShuffle.scala
This commit is contained in:
Matei Zaharia 2011-06-26 18:22:12 -07:00
commit c4dd68ae21
20 changed files with 3446 additions and 2281 deletions

File diff suppressed because it is too large Load diff

View file

@ -1,140 +0,0 @@
package spark
import java.util.{BitSet, UUID}
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
@serializable
trait Broadcast[T] {
val uuid = UUID.randomUUID
def value: T
// We cannot have an abstract readObject here due to some weird issues with
// readObject having to be 'private' in sub-classes. Possibly a Scala bug!
override def toString = "spark.Broadcast(" + uuid + ")"
}
trait BroadcastFactory {
def initialize (isMaster: Boolean): Unit
def newBroadcast[T] (value_ : T, isLocal: Boolean): Broadcast[T]
}
private object Broadcast
extends Logging {
private var initialized = false
private var broadcastFactory: BroadcastFactory = null
// Called by SparkContext or Executor before using Broadcast
def initialize (isMaster: Boolean): Unit = synchronized {
if (!initialized) {
val broadcastFactoryClass = System.getProperty("spark.broadcast.factory",
"spark.DfsBroadcastFactory")
val booleanArgs = Array[AnyRef] (isMaster.asInstanceOf[AnyRef])
broadcastFactory =
Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
// Initialize appropriate BroadcastFactory and BroadcastObject
broadcastFactory.initialize(isMaster)
initialized = true
}
}
def getBroadcastFactory: BroadcastFactory = {
if (broadcastFactory == null) {
throw new SparkException ("Broadcast.getBroadcastFactory called before initialize")
}
broadcastFactory
}
// Returns a standard ThreadFactory except all threads are daemons
private def newDaemonThreadFactory: ThreadFactory = {
new ThreadFactory {
def newThread(r: Runnable): Thread = {
var t = Executors.defaultThreadFactory.newThread (r)
t.setDaemon (true)
return t
}
}
}
// Wrapper over newCachedThreadPool
def newDaemonCachedThreadPool: ThreadPoolExecutor = {
var threadPool =
Executors.newCachedThreadPool.asInstanceOf[ThreadPoolExecutor]
threadPool.setThreadFactory (newDaemonThreadFactory)
return threadPool
}
// Wrapper over newFixedThreadPool
def newDaemonFixedThreadPool (nThreads: Int): ThreadPoolExecutor = {
var threadPool =
Executors.newFixedThreadPool (nThreads).asInstanceOf[ThreadPoolExecutor]
threadPool.setThreadFactory (newDaemonThreadFactory)
return threadPool
}
}
@serializable
case class SourceInfo (val hostAddress: String, val listenPort: Int,
val totalBlocks: Int, val totalBytes: Int)
extends Comparable[SourceInfo] with Logging {
var currentLeechers = 0
var receptionFailed = false
var hasBlocks = 0
var hasBlocksBitVector: BitSet = new BitSet (totalBlocks)
// Ascending sort based on leecher count
def compareTo (o: SourceInfo): Int = (currentLeechers - o.currentLeechers)}
object SourceInfo {
// Constants for special values of listenPort
val TxNotStartedRetry = -1
val TxOverGoToHDFS = 0
// Other constants
val StopBroadcast = -2
val UnusedParam = 0
}
@serializable
case class BroadcastBlock (val blockID: Int, val byteArray: Array[Byte]) { }
@serializable
case class VariableInfo (@transient val arrayOfBlocks : Array[BroadcastBlock],
val totalBlocks: Int, val totalBytes: Int) {
@transient var hasBlocks = 0
}
@serializable
class SpeedTracker {
// Mapping 'source' to '(totalTime, numBlocks)'
private var sourceToSpeedMap = Map[SourceInfo, (Long, Int)] ()
def addDataPoint (srcInfo: SourceInfo, timeInMillis: Long): Unit = {
sourceToSpeedMap.synchronized {
if (!sourceToSpeedMap.contains(srcInfo)) {
sourceToSpeedMap += (srcInfo -> (timeInMillis, 1))
} else {
val tTnB = sourceToSpeedMap (srcInfo)
sourceToSpeedMap += (srcInfo -> (tTnB._1 + timeInMillis, tTnB._2 + 1))
}
}
}
def getTimePerBlock (srcInfo: SourceInfo): Double = {
sourceToSpeedMap.synchronized {
val tTnB = sourceToSpeedMap (srcInfo)
return tTnB._1 / tTnB._2
}
}
override def toString = sourceToSpeedMap.toString
}

View file

@ -1,873 +0,0 @@
package spark
import java.io._
import java.net._
import java.util.{Comparator, PriorityQueue, Random, UUID}
import scala.collection.mutable.{Map, Set}
@serializable
class ChainedBroadcast[T] (@transient var value_ : T, isLocal: Boolean)
extends Broadcast[T] with Logging {
def value = value_
ChainedBroadcast.synchronized {
ChainedBroadcast.values.put (uuid, value_)
}
@transient var arrayOfBlocks: Array[BroadcastBlock] = null
@transient var totalBytes = -1
@transient var totalBlocks = -1
@transient var hasBlocks = 0
@transient var listenPortLock = new Object
@transient var guidePortLock = new Object
@transient var totalBlocksLock = new Object
@transient var hasBlocksLock = new Object
@transient var pqOfSources = new PriorityQueue[SourceInfo]
@transient var serveMR: ServeMultipleRequests = null
@transient var guideMR: GuideMultipleRequests = null
@transient var hostAddress = InetAddress.getLocalHost.getHostAddress
@transient var listenPort = -1
@transient var guidePort = -1
@transient var hasCopyInHDFS = false
@transient var stopBroadcast = false
// Must call this after all the variables have been created/initialized
if (!isLocal) {
sendBroadcast
}
def sendBroadcast (): Unit = {
logInfo ("Local host address: " + hostAddress)
// Store a persistent copy in HDFS
// TODO: Turned OFF for now
// val out = new ObjectOutputStream (DfsBroadcast.openFileForWriting(uuid))
// out.writeObject (value_)
// out.close
// TODO: Fix this at some point
hasCopyInHDFS = true
// Create a variableInfo object and store it in valueInfos
var variableInfo = blockifyObject (value_, ChainedBroadcast.BlockSize)
guideMR = new GuideMultipleRequests
guideMR.setDaemon (true)
guideMR.start
logInfo ("GuideMultipleRequests started...")
serveMR = new ServeMultipleRequests
serveMR.setDaemon (true)
serveMR.start
logInfo ("ServeMultipleRequests started...")
// Prepare the value being broadcasted
// TODO: Refactoring and clean-up required here
arrayOfBlocks = variableInfo.arrayOfBlocks
totalBytes = variableInfo.totalBytes
totalBlocks = variableInfo.totalBlocks
hasBlocks = variableInfo.totalBlocks
while (listenPort == -1) {
listenPortLock.synchronized {
listenPortLock.wait
}
}
pqOfSources = new PriorityQueue[SourceInfo]
val masterSource_0 =
SourceInfo (hostAddress, listenPort, totalBlocks, totalBytes)
pqOfSources.add (masterSource_0)
// Register with the Tracker
while (guidePort == -1) {
guidePortLock.synchronized {
guidePortLock.wait
}
}
ChainedBroadcast.registerValue (uuid, guidePort)
}
private def readObject (in: ObjectInputStream): Unit = {
in.defaultReadObject
ChainedBroadcast.synchronized {
val cachedVal = ChainedBroadcast.values.get (uuid)
if (cachedVal != null) {
value_ = cachedVal.asInstanceOf[T]
} else {
// Initializing everything because Master will only send null/0 values
initializeSlaveVariables
logInfo ("Local host address: " + hostAddress)
serveMR = new ServeMultipleRequests
serveMR.setDaemon (true)
serveMR.start
logInfo ("ServeMultipleRequests started...")
val start = System.nanoTime
val receptionSucceeded = receiveBroadcast (uuid)
// If does not succeed, then get from HDFS copy
if (receptionSucceeded) {
value_ = unBlockifyObject[T]
ChainedBroadcast.values.put (uuid, value_)
} else {
val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid))
value_ = fileIn.readObject.asInstanceOf[T]
ChainedBroadcast.values.put(uuid, value_)
fileIn.close
}
val time = (System.nanoTime - start) / 1e9
logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s")
}
}
}
private def initializeSlaveVariables: Unit = {
arrayOfBlocks = null
totalBytes = -1
totalBlocks = -1
hasBlocks = 0
listenPortLock = new Object
totalBlocksLock = new Object
hasBlocksLock = new Object
serveMR = null
hostAddress = InetAddress.getLocalHost.getHostAddress
listenPort = -1
stopBroadcast = false
}
private def blockifyObject (obj: T, blockSize: Int): VariableInfo = {
val baos = new ByteArrayOutputStream
val oos = new ObjectOutputStream (baos)
oos.writeObject (obj)
oos.close
baos.close
val byteArray = baos.toByteArray
val bais = new ByteArrayInputStream (byteArray)
var blockNum = (byteArray.length / blockSize)
if (byteArray.length % blockSize != 0)
blockNum += 1
var retVal = new Array[BroadcastBlock] (blockNum)
var blockID = 0
for (i <- 0 until (byteArray.length, blockSize)) {
val thisBlockSize = math.min (blockSize, byteArray.length - i)
var tempByteArray = new Array[Byte] (thisBlockSize)
val hasRead = bais.read (tempByteArray, 0, thisBlockSize)
retVal (blockID) = new BroadcastBlock (blockID, tempByteArray)
blockID += 1
}
bais.close
var variableInfo = VariableInfo (retVal, blockNum, byteArray.length)
variableInfo.hasBlocks = blockNum
return variableInfo
}
private def unBlockifyObject[A]: A = {
var retByteArray = new Array[Byte] (totalBytes)
for (i <- 0 until totalBlocks) {
System.arraycopy (arrayOfBlocks(i).byteArray, 0, retByteArray,
i * ChainedBroadcast.BlockSize, arrayOfBlocks(i).byteArray.length)
}
byteArrayToObject (retByteArray)
}
private def byteArrayToObject[A] (bytes: Array[Byte]): A = {
val in = new ObjectInputStream (new ByteArrayInputStream (bytes))
val retVal = in.readObject.asInstanceOf[A]
in.close
return retVal
}
def getMasterListenPort (variableUUID: UUID): Int = {
var clientSocketToTracker: Socket = null
var oosTracker: ObjectOutputStream = null
var oisTracker: ObjectInputStream = null
var masterListenPort: Int = SourceInfo.TxOverGoToHDFS
var retriesLeft = ChainedBroadcast.MaxRetryCount
do {
try {
// Connect to the tracker to find out the guide
val clientSocketToTracker =
new Socket(ChainedBroadcast.MasterHostAddress, ChainedBroadcast.MasterTrackerPort)
val oosTracker =
new ObjectOutputStream (clientSocketToTracker.getOutputStream)
oosTracker.flush
val oisTracker =
new ObjectInputStream (clientSocketToTracker.getInputStream)
// Send UUID and receive masterListenPort
oosTracker.writeObject (uuid)
oosTracker.flush
masterListenPort = oisTracker.readObject.asInstanceOf[Int]
} catch {
case e: Exception => {
logInfo ("getMasterListenPort had a " + e)
}
} finally {
if (oisTracker != null) {
oisTracker.close
}
if (oosTracker != null) {
oosTracker.close
}
if (clientSocketToTracker != null) {
clientSocketToTracker.close
}
}
retriesLeft -= 1
Thread.sleep (ChainedBroadcast.ranGen.nextInt (
ChainedBroadcast.MaxKnockInterval - ChainedBroadcast.MinKnockInterval) +
ChainedBroadcast.MinKnockInterval)
} while (retriesLeft > 0 && masterListenPort == SourceInfo.TxNotStartedRetry)
logInfo ("Got this guidePort from Tracker: " + masterListenPort)
return masterListenPort
}
def receiveBroadcast (variableUUID: UUID): Boolean = {
val masterListenPort = getMasterListenPort (variableUUID)
if (masterListenPort == SourceInfo.TxOverGoToHDFS ||
masterListenPort == SourceInfo.TxNotStartedRetry) {
// TODO: SourceInfo.TxNotStartedRetry is not really in use because we go
// to HDFS anyway when receiveBroadcast returns false
return false
}
// Wait until hostAddress and listenPort are created by the
// ServeMultipleRequests thread
while (listenPort == -1) {
listenPortLock.synchronized {
listenPortLock.wait
}
}
var clientSocketToMaster: Socket = null
var oosMaster: ObjectOutputStream = null
var oisMaster: ObjectInputStream = null
// Connect and receive broadcast from the specified source, retrying the
// specified number of times in case of failures
var retriesLeft = ChainedBroadcast.MaxRetryCount
do {
// Connect to Master and send this worker's Information
clientSocketToMaster =
new Socket(ChainedBroadcast.MasterHostAddress, masterListenPort)
// TODO: Guiding object connection is reusable
oosMaster =
new ObjectOutputStream (clientSocketToMaster.getOutputStream)
oosMaster.flush
oisMaster =
new ObjectInputStream (clientSocketToMaster.getInputStream)
logInfo ("Connected to Master's guiding object")
// Send local source information
oosMaster.writeObject(SourceInfo (hostAddress, listenPort,
SourceInfo.UnusedParam, SourceInfo.UnusedParam))
oosMaster.flush
// Receive source information from Master
var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo]
totalBlocks = sourceInfo.totalBlocks
arrayOfBlocks = new Array[BroadcastBlock] (totalBlocks)
totalBlocksLock.synchronized {
totalBlocksLock.notifyAll
}
totalBytes = sourceInfo.totalBytes
logInfo ("Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort)
val start = System.nanoTime
val receptionSucceeded = receiveSingleTransmission (sourceInfo)
val time = (System.nanoTime - start) / 1e9
// Updating some statistics in sourceInfo. Master will be using them later
if (!receptionSucceeded) {
sourceInfo.receptionFailed = true
}
// Send back statistics to the Master
oosMaster.writeObject (sourceInfo)
if (oisMaster != null) {
oisMaster.close
}
if (oosMaster != null) {
oosMaster.close
}
if (clientSocketToMaster != null) {
clientSocketToMaster.close
}
retriesLeft -= 1
} while (retriesLeft > 0 && hasBlocks < totalBlocks)
return (hasBlocks == totalBlocks)
}
// Tries to receive broadcast from the source and returns Boolean status.
// This might be called multiple times to retry a defined number of times.
private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = {
var clientSocketToSource: Socket = null
var oosSource: ObjectOutputStream = null
var oisSource: ObjectInputStream = null
var receptionSucceeded = false
try {
// Connect to the source to get the object itself
clientSocketToSource =
new Socket (sourceInfo.hostAddress, sourceInfo.listenPort)
oosSource =
new ObjectOutputStream (clientSocketToSource.getOutputStream)
oosSource.flush
oisSource =
new ObjectInputStream (clientSocketToSource.getInputStream)
logInfo ("Inside receiveSingleTransmission")
logInfo ("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks)
// Send the range
oosSource.writeObject((hasBlocks, totalBlocks))
oosSource.flush
for (i <- hasBlocks until totalBlocks) {
val recvStartTime = System.currentTimeMillis
val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
val receptionTime = (System.currentTimeMillis - recvStartTime)
logInfo ("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.")
arrayOfBlocks(hasBlocks) = bcBlock
hasBlocks += 1
// Set to true if at least one block is received
receptionSucceeded = true
hasBlocksLock.synchronized {
hasBlocksLock.notifyAll
}
}
} catch {
case e: Exception => {
logInfo ("receiveSingleTransmission had a " + e)
}
} finally {
if (oisSource != null) {
oisSource.close
}
if (oosSource != null) {
oosSource.close
}
if (clientSocketToSource != null) {
clientSocketToSource.close
}
}
return receptionSucceeded
}
class GuideMultipleRequests
extends Thread with Logging {
// Keep track of sources that have completed reception
private var setOfCompletedSources = Set[SourceInfo] ()
override def run: Unit = {
var threadPool = Broadcast.newDaemonCachedThreadPool
var serverSocket: ServerSocket = null
serverSocket = new ServerSocket (0)
guidePort = serverSocket.getLocalPort
logInfo ("GuideMultipleRequests => " + serverSocket + " " + guidePort)
guidePortLock.synchronized {
guidePortLock.notifyAll
}
try {
// Don't stop until there is a copy in HDFS
while (!stopBroadcast || !hasCopyInHDFS) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout (ChainedBroadcast.ServerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
case e: Exception => {
logInfo ("GuideMultipleRequests Timeout.")
// Stop broadcast if at least one worker has connected and
// everyone connected so far are done. Comparing with
// pqOfSources.size - 1, because it includes the Guide itself
if (pqOfSources.size > 1 &&
setOfCompletedSources.size == pqOfSources.size - 1) {
stopBroadcast = true
}
}
}
if (clientSocket != null) {
logInfo ("Guide: Accepted new client connection: " + clientSocket)
try {
threadPool.execute (new GuideSingleRequest (clientSocket))
} catch {
// In failure, close the socket here; else, the thread will close it
case ioe: IOException => clientSocket.close
}
}
}
logInfo ("Sending stopBroadcast notifications...")
sendStopBroadcastNotifications
ChainedBroadcast.unregisterValue (uuid)
} finally {
if (serverSocket != null) {
logInfo ("GuideMultipleRequests now stopping...")
serverSocket.close
}
}
// Shutdown the thread pool
threadPool.shutdown
}
private def sendStopBroadcastNotifications: Unit = {
pqOfSources.synchronized {
var pqIter = pqOfSources.iterator
while (pqIter.hasNext) {
var sourceInfo = pqIter.next
var guideSocketToSource: Socket = null
var gosSource: ObjectOutputStream = null
var gisSource: ObjectInputStream = null
try {
// Connect to the source
guideSocketToSource =
new Socket (sourceInfo.hostAddress, sourceInfo.listenPort)
gosSource =
new ObjectOutputStream (guideSocketToSource.getOutputStream)
gosSource.flush
gisSource =
new ObjectInputStream (guideSocketToSource.getInputStream)
// Send stopBroadcast signal. Range = SourceInfo.StopBroadcast*2
gosSource.writeObject ((SourceInfo.StopBroadcast,
SourceInfo.StopBroadcast))
gosSource.flush
} catch {
case e: Exception => {
logInfo ("sendStopBroadcastNotifications had a " + e)
}
} finally {
if (gisSource != null) {
gisSource.close
}
if (gosSource != null) {
gosSource.close
}
if (guideSocketToSource != null) {
guideSocketToSource.close
}
}
}
}
}
class GuideSingleRequest (val clientSocket: Socket)
extends Thread with Logging {
private val oos = new ObjectOutputStream (clientSocket.getOutputStream)
oos.flush
private val ois = new ObjectInputStream (clientSocket.getInputStream)
private var selectedSourceInfo: SourceInfo = null
private var thisWorkerInfo:SourceInfo = null
override def run: Unit = {
try {
logInfo ("new GuideSingleRequest is running")
// Connecting worker is sending in its hostAddress and listenPort it will
// be listening to. Other fields are invalid (SourceInfo.UnusedParam)
var sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
pqOfSources.synchronized {
// Select a suitable source and send it back to the worker
selectedSourceInfo = selectSuitableSource (sourceInfo)
logInfo ("Sending selectedSourceInfo: " + selectedSourceInfo)
oos.writeObject (selectedSourceInfo)
oos.flush
// Add this new (if it can finish) source to the PQ of sources
thisWorkerInfo = SourceInfo (sourceInfo.hostAddress,
sourceInfo.listenPort, totalBlocks, totalBytes)
logInfo ("Adding possible new source to pqOfSources: " + thisWorkerInfo)
pqOfSources.add (thisWorkerInfo)
}
// Wait till the whole transfer is done. Then receive and update source
// statistics in pqOfSources
sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
pqOfSources.synchronized {
// This should work since SourceInfo is a case class
assert (pqOfSources.contains (selectedSourceInfo))
// Remove first
pqOfSources.remove (selectedSourceInfo)
// TODO: Removing a source based on just one failure notification!
// Update sourceInfo and put it back in, IF reception succeeded
if (!sourceInfo.receptionFailed) {
// Add thisWorkerInfo to sources that have completed reception
setOfCompletedSources += thisWorkerInfo
selectedSourceInfo.currentLeechers -= 1
// Put it back
pqOfSources.add (selectedSourceInfo)
}
}
} catch {
// If something went wrong, e.g., the worker at the other end died etc.
// then close everything up
case e: Exception => {
// Assuming that exception caused due to receiver worker failure.
// Remove failed worker from pqOfSources and update leecherCount of
// corresponding source worker
pqOfSources.synchronized {
if (selectedSourceInfo != null) {
// Remove first
pqOfSources.remove (selectedSourceInfo)
// Update leecher count and put it back in
selectedSourceInfo.currentLeechers -= 1
pqOfSources.add (selectedSourceInfo)
}
// Remove thisWorkerInfo
if (pqOfSources != null) {
pqOfSources.remove (thisWorkerInfo)
}
}
}
} finally {
ois.close
oos.close
clientSocket.close
}
}
// TODO: Caller must have a synchronized block on pqOfSources
// TODO: If a worker fails to get the broadcasted variable from a source and
// comes back to Master, this function might choose the worker itself as a
// source tp create a dependency cycle (this worker was put into pqOfSources
// as a streming source when it first arrived). The length of this cycle can
// be arbitrarily long.
private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = {
// Select one based on the ordering strategy (e.g., least leechers etc.)
// take is a blocking call removing the element from PQ
var selectedSource = pqOfSources.poll
assert (selectedSource != null)
// Update leecher count
selectedSource.currentLeechers += 1
// Add it back and then return
pqOfSources.add (selectedSource)
return selectedSource
}
}
}
class ServeMultipleRequests
extends Thread with Logging {
override def run: Unit = {
var threadPool = Broadcast.newDaemonCachedThreadPool
var serverSocket: ServerSocket = null
serverSocket = new ServerSocket (0)
listenPort = serverSocket.getLocalPort
logInfo ("ServeMultipleRequests started with " + serverSocket)
listenPortLock.synchronized {
listenPortLock.notifyAll
}
try {
while (!stopBroadcast) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout (ChainedBroadcast.ServerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
case e: Exception => {
logInfo ("ServeMultipleRequests Timeout.")
}
}
if (clientSocket != null) {
logInfo ("Serve: Accepted new client connection: " + clientSocket)
try {
threadPool.execute (new ServeSingleRequest (clientSocket))
} catch {
// In failure, close socket here; else, the thread will close it
case ioe: IOException => clientSocket.close
}
}
}
} finally {
if (serverSocket != null) {
logInfo ("ServeMultipleRequests now stopping...")
serverSocket.close
}
}
// Shutdown the thread pool
threadPool.shutdown
}
class ServeSingleRequest (val clientSocket: Socket)
extends Thread with Logging {
private val oos = new ObjectOutputStream (clientSocket.getOutputStream)
oos.flush
private val ois = new ObjectInputStream (clientSocket.getInputStream)
private var sendFrom = 0
private var sendUntil = totalBlocks
override def run: Unit = {
try {
logInfo ("new ServeSingleRequest is running")
// Receive range to send
var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)]
sendFrom = rangeToSend._1
sendUntil = rangeToSend._2
if (sendFrom == SourceInfo.StopBroadcast &&
sendUntil == SourceInfo.StopBroadcast) {
stopBroadcast = true
} else {
// Carry on
sendObject
}
} catch {
// If something went wrong, e.g., the worker at the other end died etc.
// then close everything up
case e: Exception => {
logInfo ("ServeSingleRequest had a " + e)
}
} finally {
logInfo ("ServeSingleRequest is closing streams and sockets")
ois.close
oos.close
clientSocket.close
}
}
private def sendObject: Unit = {
// Wait till receiving the SourceInfo from Master
while (totalBlocks == -1) {
totalBlocksLock.synchronized {
totalBlocksLock.wait
}
}
for (i <- sendFrom until sendUntil) {
while (i == hasBlocks) {
hasBlocksLock.synchronized {
hasBlocksLock.wait
}
}
try {
oos.writeObject (arrayOfBlocks(i))
oos.flush
} catch {
case e: Exception => {
logInfo ("sendObject had a " + e)
}
}
logInfo ("Sent block: " + i + " to " + clientSocket)
}
}
}
}
}
class ChainedBroadcastFactory
extends BroadcastFactory {
def initialize (isMaster: Boolean) = ChainedBroadcast.initialize (isMaster)
def newBroadcast[T] (value_ : T, isLocal: Boolean) =
new ChainedBroadcast[T] (value_, isLocal)
}
private object ChainedBroadcast
extends Logging {
val values = SparkEnv.get.cache.newKeySpace()
var valueToGuidePortMap = Map[UUID, Int] ()
// Random number generator
var ranGen = new Random
private var initialized = false
private var isMaster_ = false
private var MasterHostAddress_ = InetAddress.getLocalHost.getHostAddress
private var MasterTrackerPort_ : Int = 22222
private var BlockSize_ : Int = 512 * 1024
private var MaxRetryCount_ : Int = 2
private var TrackerSocketTimeout_ : Int = 50000
private var ServerSocketTimeout_ : Int = 10000
private var trackMV: TrackMultipleValues = null
private var MinKnockInterval_ = 500
private var MaxKnockInterval_ = 999
def initialize (isMaster__ : Boolean): Unit = {
synchronized {
if (!initialized) {
// Fix for issue #42
MasterHostAddress_ =
System.getProperty ("spark.broadcast.masterHostAddress", "")
MasterTrackerPort_ =
System.getProperty ("spark.broadcast.masterTrackerPort", "22222").toInt
BlockSize_ =
System.getProperty ("spark.broadcast.blockSize", "512").toInt * 1024
MaxRetryCount_ =
System.getProperty ("spark.broadcast.maxRetryCount", "2").toInt
TrackerSocketTimeout_ =
System.getProperty ("spark.broadcast.trackerSocketTimeout", "50000").toInt
ServerSocketTimeout_ =
System.getProperty ("spark.broadcast.serverSocketTimeout", "10000").toInt
MinKnockInterval_ =
System.getProperty ("spark.broadcast.minKnockInterval", "500").toInt
MaxKnockInterval_ =
System.getProperty ("spark.broadcast.maxKnockInterval", "999").toInt
isMaster_ = isMaster__
if (isMaster) {
trackMV = new TrackMultipleValues
trackMV.setDaemon (true)
trackMV.start
logInfo ("TrackMultipleValues started...")
}
// Initialize DfsBroadcast to be used for broadcast variable persistence
DfsBroadcast.initialize
initialized = true
}
}
}
def MasterHostAddress = MasterHostAddress_
def MasterTrackerPort = MasterTrackerPort_
def BlockSize = BlockSize_
def MaxRetryCount = MaxRetryCount_
def TrackerSocketTimeout = TrackerSocketTimeout_
def ServerSocketTimeout = ServerSocketTimeout_
def isMaster = isMaster_
def MinKnockInterval = MinKnockInterval_
def MaxKnockInterval = MaxKnockInterval_
def registerValue (uuid: UUID, guidePort: Int): Unit = {
valueToGuidePortMap.synchronized {
valueToGuidePortMap += (uuid -> guidePort)
logInfo ("New value registered with the Tracker " + valueToGuidePortMap)
}
}
def unregisterValue (uuid: UUID): Unit = {
valueToGuidePortMap.synchronized {
valueToGuidePortMap (uuid) = SourceInfo.TxOverGoToHDFS
logInfo ("Value unregistered from the Tracker " + valueToGuidePortMap)
}
}
class TrackMultipleValues
extends Thread with Logging {
override def run: Unit = {
var threadPool = Broadcast.newDaemonCachedThreadPool
var serverSocket: ServerSocket = null
serverSocket = new ServerSocket (ChainedBroadcast.MasterTrackerPort)
logInfo ("TrackMultipleValues" + serverSocket)
try {
while (true) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout (TrackerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
case e: Exception => {
logInfo ("TrackMultipleValues Timeout. Stopping listening...")
}
}
if (clientSocket != null) {
try {
threadPool.execute (new Thread {
override def run: Unit = {
val oos = new ObjectOutputStream (clientSocket.getOutputStream)
oos.flush
val ois = new ObjectInputStream (clientSocket.getInputStream)
try {
val uuid = ois.readObject.asInstanceOf[UUID]
var guidePort =
if (valueToGuidePortMap.contains (uuid)) {
valueToGuidePortMap (uuid)
} else SourceInfo.TxNotStartedRetry
logInfo ("TrackMultipleValues: Got new request: " + clientSocket + " for " + uuid + " : " + guidePort)
oos.writeObject (guidePort)
} catch {
case e: Exception => {
logInfo ("TrackMultipleValues had a " + e)
}
} finally {
ois.close
oos.close
clientSocket.close
}
}
})
} catch {
// In failure, close socket here; else, client thread will close
case ioe: IOException => clientSocket.close
}
}
}
} finally {
serverSocket.close
}
// Shutdown the thread pool
threadPool.shutdown
}
}
}

View file

@ -9,6 +9,8 @@ import scala.collection.mutable.ArrayBuffer
import mesos.{ExecutorArgs, ExecutorDriver, MesosExecutorDriver}
import mesos.{TaskDescription, TaskState, TaskStatus}
import spark.broadcast._
/**
* The Mesos executor for Spark.
*/

View file

@ -7,6 +7,7 @@ import java.util.concurrent.atomic.AtomicLong
import scala.collection.mutable.{ArrayBuffer, HashMap}
import spark._
object LocalFileShuffle extends Logging {
private var initialized = false
@ -29,9 +30,9 @@ object LocalFileShuffle extends Logging {
while (!foundLocalDir && tries < 10) {
tries += 1
try {
localDirUuid = UUID.randomUUID()
localDirUuid = UUID.randomUUID
localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
if (!localDir.exists()) {
if (!localDir.exists) {
localDir.mkdirs()
foundLocalDir = true
}
@ -47,6 +48,7 @@ object LocalFileShuffle extends Logging {
shuffleDir = new File(localDir, "shuffle")
shuffleDir.mkdirs()
logInfo("Shuffle dir: " + shuffleDir)
val extServerPort = System.getProperty(
"spark.localFileShuffle.external.server.port", "-1").toInt
if (extServerPort != -1) {
@ -65,6 +67,7 @@ object LocalFileShuffle extends Logging {
serverUri = server.uri
}
initialized = true
logInfo("Local URI: " + serverUri)
}
}

View file

@ -1,15 +0,0 @@
package spark
/**
* A trait for shuffle system. Given an input RDD and combiner functions
* for PairRDDExtras.combineByKey(), returns an output RDD.
*/
@serializable
trait Shuffle[K, V, C] {
def compute(input: RDD[(K, V)],
numOutputSplits: Int,
createCombiner: V => C,
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C)
: RDD[(K, C)]
}

View file

@ -8,6 +8,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.mapred.SequenceFileInputFormat
import spark.broadcast._
class SparkContext(
master: String,

View file

@ -3,6 +3,7 @@ package spark
import java.io._
import java.net.InetAddress
import java.util.UUID
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
import scala.collection.mutable.ArrayBuffer
import scala.util.Random
@ -117,12 +118,43 @@ object Utils {
/**
* Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4)
*/
def localIpAddress(): String = {
// Get local IP as an array of four bytes
val bytes = InetAddress.getLocalHost().getAddress()
// Convert the bytes to ints (keeping in mind that they may be negative)
// and join them into a string
return bytes.map(b => (b.toInt + 256) % 256).mkString(".")
def localIpAddress(): String = InetAddress.getLocalHost.getHostAddress
/**
* Returns a standard ThreadFactory except all threads are daemons
*/
private def newDaemonThreadFactory: ThreadFactory = {
new ThreadFactory {
def newThread(r: Runnable): Thread = {
var t = Executors.defaultThreadFactory.newThread (r)
t.setDaemon (true)
return t
}
}
}
/**
* Wrapper over newCachedThreadPool
*/
def newDaemonCachedThreadPool(): ThreadPoolExecutor = {
var threadPool =
Executors.newCachedThreadPool.asInstanceOf[ThreadPoolExecutor]
threadPool.setThreadFactory (newDaemonThreadFactory)
return threadPool
}
/**
* Wrapper over newFixedThreadPool
*/
def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = {
var threadPool =
Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor]
threadPool.setThreadFactory(newDaemonThreadFactory)
return threadPool
}
/**

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,228 @@
package spark.broadcast
import java.io._
import java.net._
import java.util.{BitSet, UUID}
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
import spark._
@serializable
trait Broadcast[T] {
val uuid = UUID.randomUUID
def value: T
// We cannot have an abstract readObject here due to some weird issues with
// readObject having to be 'private' in sub-classes. Possibly a Scala bug!
override def toString = "spark.Broadcast(" + uuid + ")"
}
object Broadcast
extends Logging {
// Messages
val REGISTER_BROADCAST_TRACKER = 0
val UNREGISTER_BROADCAST_TRACKER = 1
val FIND_BROADCAST_TRACKER = 2
val GET_UPDATED_SHARE = 3
private var initialized = false
private var isMaster_ = false
private var broadcastFactory: BroadcastFactory = null
// Called by SparkContext or Executor before using Broadcast
def initialize (isMaster__ : Boolean): Unit = synchronized {
if (!initialized) {
val broadcastFactoryClass = System.getProperty(
"spark.broadcast.factory", "spark.broadcast.DfsBroadcastFactory")
broadcastFactory =
Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
// Setup isMaster before using it
isMaster_ = isMaster__
// Set masterHostAddress to the master's IP address for the slaves to read
if (isMaster) {
System.setProperty("spark.broadcast.masterHostAddress", Utils.localIpAddress)
}
// Initialize appropriate BroadcastFactory and BroadcastObject
broadcastFactory.initialize(isMaster)
initialized = true
}
}
def getBroadcastFactory: BroadcastFactory = {
if (broadcastFactory == null) {
throw new SparkException ("Broadcast.getBroadcastFactory called before initialize")
}
broadcastFactory
}
// Load common broadcast-related config parameters
private var MasterHostAddress_ = System.getProperty(
"spark.broadcast.masterHostAddress", "")
private var MasterTrackerPort_ = System.getProperty(
"spark.broadcast.masterTrackerPort", "11111").toInt
private var BlockSize_ = System.getProperty(
"spark.broadcast.blockSize", "4096").toInt * 1024
private var MaxRetryCount_ = System.getProperty(
"spark.broadcast.maxRetryCount", "2").toInt
private var TrackerSocketTimeout_ = System.getProperty(
"spark.broadcast.trackerSocketTimeout", "50000").toInt
private var ServerSocketTimeout_ = System.getProperty(
"spark.broadcast.serverSocketTimeout", "10000").toInt
private var MinKnockInterval_ = System.getProperty(
"spark.broadcast.minKnockInterval", "500").toInt
private var MaxKnockInterval_ = System.getProperty(
"spark.broadcast.maxKnockInterval", "999").toInt
// Load ChainedBroadcast config params
// Load TreeBroadcast config params
private var MaxDegree_ = System.getProperty("spark.broadcast.maxDegree", "2").toInt
// Load BitTorrentBroadcast config params
private var MaxPeersInGuideResponse_ = System.getProperty(
"spark.broadcast.maxPeersInGuideResponse", "4").toInt
private var MaxRxSlots_ = System.getProperty(
"spark.broadcast.maxRxSlots", "4").toInt
private var MaxTxSlots_ = System.getProperty(
"spark.broadcast.maxTxSlots", "4").toInt
private var MaxChatTime_ = System.getProperty(
"spark.broadcast.maxChatTime", "500").toInt
private var MaxChatBlocks_ = System.getProperty(
"spark.broadcast.maxChatBlocks", "1024").toInt
private var EndGameFraction_ = System.getProperty(
"spark.broadcast.endGameFraction", "0.95").toDouble
def isMaster = isMaster_
// Common config params
def MasterHostAddress = MasterHostAddress_
def MasterTrackerPort = MasterTrackerPort_
def BlockSize = BlockSize_
def MaxRetryCount = MaxRetryCount_
def TrackerSocketTimeout = TrackerSocketTimeout_
def ServerSocketTimeout = ServerSocketTimeout_
def MinKnockInterval = MinKnockInterval_
def MaxKnockInterval = MaxKnockInterval_
// ChainedBroadcast configs
// TreeBroadcast configs
def MaxDegree = MaxDegree_
// BitTorrentBroadcast configs
def MaxPeersInGuideResponse = MaxPeersInGuideResponse_
def MaxRxSlots = MaxRxSlots_
def MaxTxSlots = MaxTxSlots_
def MaxChatTime = MaxChatTime_
def MaxChatBlocks = MaxChatBlocks_
def EndGameFraction = EndGameFraction_
// Helper functions to convert an object to Array[BroadcastBlock]
def blockifyObject[IN](obj: IN): VariableInfo = {
val baos = new ByteArrayOutputStream
val oos = new ObjectOutputStream(baos)
oos.writeObject(obj)
oos.close()
baos.close()
val byteArray = baos.toByteArray
val bais = new ByteArrayInputStream(byteArray)
var blockNum = (byteArray.length / Broadcast.BlockSize)
if (byteArray.length % Broadcast.BlockSize != 0)
blockNum += 1
var retVal = new Array[BroadcastBlock](blockNum)
var blockID = 0
for (i <- 0 until (byteArray.length, Broadcast.BlockSize)) {
val thisBlockSize = math.min(Broadcast.BlockSize, byteArray.length - i)
var tempByteArray = new Array[Byte](thisBlockSize)
val hasRead = bais.read(tempByteArray, 0, thisBlockSize)
retVal(blockID) = new BroadcastBlock(blockID, tempByteArray)
blockID += 1
}
bais.close()
var variableInfo = VariableInfo(retVal, blockNum, byteArray.length)
variableInfo.hasBlocks = blockNum
return variableInfo
}
// Helper function to convert Array[BroadcastBlock] to object
def unBlockifyObject[OUT](arrayOfBlocks: Array[BroadcastBlock],
totalBytes: Int,
totalBlocks: Int): OUT = {
var retByteArray = new Array[Byte](totalBytes)
for (i <- 0 until totalBlocks) {
System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
i * Broadcast.BlockSize, arrayOfBlocks(i).byteArray.length)
}
byteArrayToObject(retByteArray)
}
private def byteArrayToObject[OUT](bytes: Array[Byte]): OUT = {
val in = new ObjectInputStream (new ByteArrayInputStream (bytes)){
override def resolveClass(desc: ObjectStreamClass) =
Class.forName(desc.getName, false, currentThread.getContextClassLoader)
}
val retVal = in.readObject.asInstanceOf[OUT]
in.close()
return retVal
}
}
@serializable
case class BroadcastBlock (val blockID: Int, val byteArray: Array[Byte]) { }
@serializable
case class VariableInfo (@transient val arrayOfBlocks : Array[BroadcastBlock],
val totalBlocks: Int,
val totalBytes: Int) {
@transient var hasBlocks = 0
}
@serializable
class SpeedTracker {
// Mapping 'source' to '(totalTime, numBlocks)'
private var sourceToSpeedMap = Map[SourceInfo, (Long, Int)] ()
def addDataPoint (srcInfo: SourceInfo, timeInMillis: Long): Unit = {
sourceToSpeedMap.synchronized {
if (!sourceToSpeedMap.contains(srcInfo)) {
sourceToSpeedMap += (srcInfo -> (timeInMillis, 1))
} else {
val tTnB = sourceToSpeedMap (srcInfo)
sourceToSpeedMap += (srcInfo -> (tTnB._1 + timeInMillis, tTnB._2 + 1))
}
}
}
def getTimePerBlock (srcInfo: SourceInfo): Double = {
sourceToSpeedMap.synchronized {
val tTnB = sourceToSpeedMap (srcInfo)
return tTnB._1 / tTnB._2
}
}
override def toString = sourceToSpeedMap.toString
}

View file

@ -0,0 +1,12 @@
package spark.broadcast
/**
* An interface for all the broadcast implementations in Spark (to allow
* multiple broadcast implementations). SparkContext uses a user-specified
* BroadcastFactory implementation to instantiate a particular broadcast for the
* entire Spark job.
*/
trait BroadcastFactory {
def initialize (isMaster: Boolean): Unit
def newBroadcast[T] (value_ : T, isLocal: Boolean): Broadcast[T]
}

View file

@ -0,0 +1,792 @@
package spark.broadcast
import java.io._
import java.net._
import java.util.{Comparator, PriorityQueue, Random, UUID}
import scala.collection.mutable.{Map, Set}
import scala.math
import spark._
@serializable
class ChainedBroadcast[T](@transient var value_ : T, isLocal: Boolean)
extends Broadcast[T] with Logging {
def value = value_
ChainedBroadcast.synchronized {
ChainedBroadcast.values.put(uuid, value_)
}
@transient var arrayOfBlocks: Array[BroadcastBlock] = null
@transient var totalBytes = -1
@transient var totalBlocks = -1
@transient var hasBlocks = 0
// CHANGED: BlockSize in the Broadcast object is expected to change over time
@transient var blockSize = Broadcast.BlockSize
@transient var listenPortLock = new Object
@transient var guidePortLock = new Object
@transient var totalBlocksLock = new Object
@transient var hasBlocksLock = new Object
@transient var pqOfSources = new PriorityQueue[SourceInfo]
@transient var serveMR: ServeMultipleRequests = null
@transient var guideMR: GuideMultipleRequests = null
@transient var hostAddress = Utils.localIpAddress
@transient var listenPort = -1
@transient var guidePort = -1
@transient var hasCopyInHDFS = false
@transient var stopBroadcast = false
// Must call this after all the variables have been created/initialized
if (!isLocal) {
sendBroadcast
}
def sendBroadcast(): Unit = {
logInfo("Local host address: " + hostAddress)
// Store a persistent copy in HDFS
// TODO: Turned OFF for now
// val out = new ObjectOutputStream(DfsBroadcast.openFileForWriting(uuid))
// out.writeObject(value_)
// out.close()
// TODO: Fix this at some point
hasCopyInHDFS = true
// Create a variableInfo object and store it in valueInfos
var variableInfo = Broadcast.blockifyObject(value_)
guideMR = new GuideMultipleRequests
guideMR.setDaemon(true)
guideMR.start()
logInfo("GuideMultipleRequests started...")
serveMR = new ServeMultipleRequests
serveMR.setDaemon(true)
serveMR.start()
logInfo("ServeMultipleRequests started...")
// Prepare the value being broadcasted
// TODO: Refactoring and clean-up required here
arrayOfBlocks = variableInfo.arrayOfBlocks
totalBytes = variableInfo.totalBytes
totalBlocks = variableInfo.totalBlocks
hasBlocks = variableInfo.totalBlocks
while (listenPort == -1) {
listenPortLock.synchronized {
listenPortLock.wait
}
}
pqOfSources = new PriorityQueue[SourceInfo]
val masterSource =
SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes, blockSize)
pqOfSources.add(masterSource)
// Register with the Tracker
while (guidePort == -1) {
guidePortLock.synchronized {
guidePortLock.wait
}
}
ChainedBroadcast.registerValue(uuid, guidePort)
}
private def readObject(in: ObjectInputStream): Unit = {
in.defaultReadObject
ChainedBroadcast.synchronized {
val cachedVal = ChainedBroadcast.values.get(uuid)
if (cachedVal != null) {
value_ = cachedVal.asInstanceOf[T]
} else {
// Initializing everything because Master will only send null/0 values
initializeSlaveVariables
logInfo("Local host address: " + hostAddress)
serveMR = new ServeMultipleRequests
serveMR.setDaemon(true)
serveMR.start()
logInfo("ServeMultipleRequests started...")
val start = System.nanoTime
val receptionSucceeded = receiveBroadcast(uuid)
// If does not succeed, then get from HDFS copy
if (receptionSucceeded) {
value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
ChainedBroadcast.values.put(uuid, value_)
} else {
val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid))
value_ = fileIn.readObject.asInstanceOf[T]
ChainedBroadcast.values.put(uuid, value_)
fileIn.close()
}
val time =(System.nanoTime - start) / 1e9
logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s")
}
}
}
private def initializeSlaveVariables: Unit = {
arrayOfBlocks = null
totalBytes = -1
totalBlocks = -1
hasBlocks = 0
blockSize = -1
listenPortLock = new Object
totalBlocksLock = new Object
hasBlocksLock = new Object
serveMR = null
hostAddress = Utils.localIpAddress
listenPort = -1
stopBroadcast = false
}
def getMasterListenPort(variableUUID: UUID): Int = {
var clientSocketToTracker: Socket = null
var oosTracker: ObjectOutputStream = null
var oisTracker: ObjectInputStream = null
var masterListenPort: Int = SourceInfo.TxOverGoToHDFS
var retriesLeft = Broadcast.MaxRetryCount
do {
try {
// Connect to the tracker to find out the guide
clientSocketToTracker =
new Socket(Broadcast.MasterHostAddress, Broadcast.MasterTrackerPort)
oosTracker =
new ObjectOutputStream(clientSocketToTracker.getOutputStream)
oosTracker.flush()
oisTracker =
new ObjectInputStream(clientSocketToTracker.getInputStream)
// Send UUID and receive masterListenPort
oosTracker.writeObject(uuid)
oosTracker.flush()
masterListenPort = oisTracker.readObject.asInstanceOf[Int]
} catch {
case e: Exception => {
logInfo("getMasterListenPort had a " + e)
}
} finally {
if (oisTracker != null) {
oisTracker.close()
}
if (oosTracker != null) {
oosTracker.close()
}
if (clientSocketToTracker != null) {
clientSocketToTracker.close()
}
}
retriesLeft -= 1
Thread.sleep(ChainedBroadcast.ranGen.nextInt(
Broadcast.MaxKnockInterval - Broadcast.MinKnockInterval) +
Broadcast.MinKnockInterval)
} while (retriesLeft > 0 && masterListenPort == SourceInfo.TxNotStartedRetry)
logInfo("Got this guidePort from Tracker: " + masterListenPort)
return masterListenPort
}
def receiveBroadcast(variableUUID: UUID): Boolean = {
val masterListenPort = getMasterListenPort(variableUUID)
if (masterListenPort == SourceInfo.TxOverGoToHDFS ||
masterListenPort == SourceInfo.TxNotStartedRetry) {
// TODO: SourceInfo.TxNotStartedRetry is not really in use because we go
// to HDFS anyway when receiveBroadcast returns false
return false
}
// Wait until hostAddress and listenPort are created by the
// ServeMultipleRequests thread
while (listenPort == -1) {
listenPortLock.synchronized {
listenPortLock.wait
}
}
var clientSocketToMaster: Socket = null
var oosMaster: ObjectOutputStream = null
var oisMaster: ObjectInputStream = null
// Connect and receive broadcast from the specified source, retrying the
// specified number of times in case of failures
var retriesLeft = Broadcast.MaxRetryCount
do {
// Connect to Master and send this worker's Information
clientSocketToMaster =
new Socket(Broadcast.MasterHostAddress, masterListenPort)
// TODO: Guiding object connection is reusable
oosMaster =
new ObjectOutputStream(clientSocketToMaster.getOutputStream)
oosMaster.flush()
oisMaster =
new ObjectInputStream(clientSocketToMaster.getInputStream)
logInfo("Connected to Master's guiding object")
// Send local source information
oosMaster.writeObject(SourceInfo(hostAddress, listenPort))
oosMaster.flush()
// Receive source information from Master
var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo]
totalBlocks = sourceInfo.totalBlocks
arrayOfBlocks = new Array[BroadcastBlock](totalBlocks)
totalBlocksLock.synchronized {
totalBlocksLock.notifyAll
}
totalBytes = sourceInfo.totalBytes
logInfo("Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort)
val start = System.nanoTime
val receptionSucceeded = receiveSingleTransmission(sourceInfo)
val time =(System.nanoTime - start) / 1e9
// Updating some statistics in sourceInfo. Master will be using them later
if (!receptionSucceeded) {
sourceInfo.receptionFailed = true
}
// Send back statistics to the Master
oosMaster.writeObject(sourceInfo)
if (oisMaster != null) {
oisMaster.close()
}
if (oosMaster != null) {
oosMaster.close()
}
if (clientSocketToMaster != null) {
clientSocketToMaster.close()
}
retriesLeft -= 1
} while (retriesLeft > 0 && hasBlocks < totalBlocks)
return(hasBlocks == totalBlocks)
}
// Tries to receive broadcast from the source and returns Boolean status.
// This might be called multiple times to retry a defined number of times.
private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = {
var clientSocketToSource: Socket = null
var oosSource: ObjectOutputStream = null
var oisSource: ObjectInputStream = null
var receptionSucceeded = false
try {
// Connect to the source to get the object itself
clientSocketToSource =
new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
oosSource =
new ObjectOutputStream(clientSocketToSource.getOutputStream)
oosSource.flush()
oisSource =
new ObjectInputStream(clientSocketToSource.getInputStream)
logInfo("Inside receiveSingleTransmission")
logInfo("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks)
// Send the range
oosSource.writeObject((hasBlocks, totalBlocks))
oosSource.flush()
for (i <- hasBlocks until totalBlocks) {
val recvStartTime = System.currentTimeMillis
val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
val receptionTime =(System.currentTimeMillis - recvStartTime)
logInfo("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.")
arrayOfBlocks(hasBlocks) = bcBlock
hasBlocks += 1
// Set to true if at least one block is received
receptionSucceeded = true
hasBlocksLock.synchronized {
hasBlocksLock.notifyAll
}
}
} catch {
case e: Exception => {
logInfo("receiveSingleTransmission had a " + e)
}
} finally {
if (oisSource != null) {
oisSource.close()
}
if (oosSource != null) {
oosSource.close()
}
if (clientSocketToSource != null) {
clientSocketToSource.close()
}
}
return receptionSucceeded
}
class GuideMultipleRequests
extends Thread with Logging {
// Keep track of sources that have completed reception
private var setOfCompletedSources = Set[SourceInfo]()
override def run: Unit = {
var threadPool = Utils.newDaemonCachedThreadPool()
var serverSocket: ServerSocket = null
serverSocket = new ServerSocket(0)
guidePort = serverSocket.getLocalPort
logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort)
guidePortLock.synchronized {
guidePortLock.notifyAll
}
try {
// Don't stop until there is a copy in HDFS
while (!stopBroadcast || !hasCopyInHDFS) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
case e: Exception => {
logInfo("GuideMultipleRequests Timeout.")
// Stop broadcast if at least one worker has connected and
// everyone connected so far are done. Comparing with
// pqOfSources.size - 1, because it includes the Guide itself
if (pqOfSources.size > 1 &&
setOfCompletedSources.size == pqOfSources.size - 1) {
stopBroadcast = true
}
}
}
if (clientSocket != null) {
logInfo("Guide: Accepted new client connection: " + clientSocket)
try {
threadPool.execute(new GuideSingleRequest(clientSocket))
} catch {
// In failure, close the socket here; else, the thread will close it
case ioe: IOException => clientSocket.close()
}
}
}
logInfo("Sending stopBroadcast notifications...")
sendStopBroadcastNotifications
ChainedBroadcast.unregisterValue(uuid)
} finally {
if (serverSocket != null) {
logInfo("GuideMultipleRequests now stopping...")
serverSocket.close()
}
}
// Shutdown the thread pool
threadPool.shutdown()
}
private def sendStopBroadcastNotifications: Unit = {
pqOfSources.synchronized {
var pqIter = pqOfSources.iterator
while (pqIter.hasNext) {
var sourceInfo = pqIter.next
var guideSocketToSource: Socket = null
var gosSource: ObjectOutputStream = null
var gisSource: ObjectInputStream = null
try {
// Connect to the source
guideSocketToSource =
new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
gosSource =
new ObjectOutputStream(guideSocketToSource.getOutputStream)
gosSource.flush()
gisSource =
new ObjectInputStream(guideSocketToSource.getInputStream)
// Send stopBroadcast signal. Range = SourceInfo.StopBroadcast*2
gosSource.writeObject((SourceInfo.StopBroadcast,
SourceInfo.StopBroadcast))
gosSource.flush()
} catch {
case e: Exception => {
logInfo("sendStopBroadcastNotifications had a " + e)
}
} finally {
if (gisSource != null) {
gisSource.close()
}
if (gosSource != null) {
gosSource.close()
}
if (guideSocketToSource != null) {
guideSocketToSource.close()
}
}
}
}
}
class GuideSingleRequest(val clientSocket: Socket)
extends Thread with Logging {
private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
oos.flush()
private val ois = new ObjectInputStream(clientSocket.getInputStream)
private var selectedSourceInfo: SourceInfo = null
private var thisWorkerInfo:SourceInfo = null
override def run: Unit = {
try {
logInfo("new GuideSingleRequest is running")
// Connecting worker is sending in its hostAddress and listenPort it will
// be listening to. Other fields are invalid(SourceInfo.UnusedParam)
var sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
pqOfSources.synchronized {
// Select a suitable source and send it back to the worker
selectedSourceInfo = selectSuitableSource(sourceInfo)
logInfo("Sending selectedSourceInfo: " + selectedSourceInfo)
oos.writeObject(selectedSourceInfo)
oos.flush()
// Add this new(if it can finish) source to the PQ of sources
thisWorkerInfo = SourceInfo(sourceInfo.hostAddress,
sourceInfo.listenPort, totalBlocks, totalBytes, blockSize)
logInfo("Adding possible new source to pqOfSources: " + thisWorkerInfo)
pqOfSources.add(thisWorkerInfo)
}
// Wait till the whole transfer is done. Then receive and update source
// statistics in pqOfSources
sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
pqOfSources.synchronized {
// This should work since SourceInfo is a case class
assert(pqOfSources.contains(selectedSourceInfo))
// Remove first
pqOfSources.remove(selectedSourceInfo)
// TODO: Removing a source based on just one failure notification!
// Update sourceInfo and put it back in, IF reception succeeded
if (!sourceInfo.receptionFailed) {
// Add thisWorkerInfo to sources that have completed reception
setOfCompletedSources.synchronized {
setOfCompletedSources += thisWorkerInfo
}
selectedSourceInfo.currentLeechers -= 1
// Put it back
pqOfSources.add(selectedSourceInfo)
}
}
} catch {
// If something went wrong, e.g., the worker at the other end died etc.
// then close everything up
case e: Exception => {
// Assuming that exception caused due to receiver worker failure.
// Remove failed worker from pqOfSources and update leecherCount of
// corresponding source worker
pqOfSources.synchronized {
if (selectedSourceInfo != null) {
// Remove first
pqOfSources.remove(selectedSourceInfo)
// Update leecher count and put it back in
selectedSourceInfo.currentLeechers -= 1
pqOfSources.add(selectedSourceInfo)
}
// Remove thisWorkerInfo
if (pqOfSources != null) {
pqOfSources.remove(thisWorkerInfo)
}
}
}
} finally {
ois.close()
oos.close()
clientSocket.close()
}
}
// FIXME: Caller must have a synchronized block on pqOfSources
// FIXME: If a worker fails to get the broadcasted variable from a source and
// comes back to Master, this function might choose the worker itself as a
// source tp create a dependency cycle(this worker was put into pqOfSources
// as a streming source when it first arrived). The length of this cycle can
// be arbitrarily long.
private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = {
// Select one based on the ordering strategy(e.g., least leechers etc.)
// take is a blocking call removing the element from PQ
var selectedSource = pqOfSources.poll
assert(selectedSource != null)
// Update leecher count
selectedSource.currentLeechers += 1
// Add it back and then return
pqOfSources.add(selectedSource)
return selectedSource
}
}
}
class ServeMultipleRequests
extends Thread with Logging {
override def run: Unit = {
var threadPool = Utils.newDaemonCachedThreadPool()
var serverSocket: ServerSocket = null
serverSocket = new ServerSocket(0)
listenPort = serverSocket.getLocalPort
logInfo("ServeMultipleRequests started with " + serverSocket)
listenPortLock.synchronized {
listenPortLock.notifyAll
}
try {
while (!stopBroadcast) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
case e: Exception => {
logInfo("ServeMultipleRequests Timeout.")
}
}
if (clientSocket != null) {
logInfo("Serve: Accepted new client connection: " + clientSocket)
try {
threadPool.execute(new ServeSingleRequest(clientSocket))
} catch {
// In failure, close socket here; else, the thread will close it
case ioe: IOException => clientSocket.close()
}
}
}
} finally {
if (serverSocket != null) {
logInfo("ServeMultipleRequests now stopping...")
serverSocket.close()
}
}
// Shutdown the thread pool
threadPool.shutdown()
}
class ServeSingleRequest(val clientSocket: Socket)
extends Thread with Logging {
private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
oos.flush()
private val ois = new ObjectInputStream(clientSocket.getInputStream)
private var sendFrom = 0
private var sendUntil = totalBlocks
override def run: Unit = {
try {
logInfo("new ServeSingleRequest is running")
// Receive range to send
var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)]
sendFrom = rangeToSend._1
sendUntil = rangeToSend._2
if (sendFrom == SourceInfo.StopBroadcast &&
sendUntil == SourceInfo.StopBroadcast) {
stopBroadcast = true
} else {
// Carry on
sendObject
}
} catch {
// If something went wrong, e.g., the worker at the other end died etc.
// then close everything up
case e: Exception => {
logInfo("ServeSingleRequest had a " + e)
}
} finally {
logInfo("ServeSingleRequest is closing streams and sockets")
ois.close()
oos.close()
clientSocket.close()
}
}
private def sendObject: Unit = {
// Wait till receiving the SourceInfo from Master
while (totalBlocks == -1) {
totalBlocksLock.synchronized {
totalBlocksLock.wait
}
}
for (i <- sendFrom until sendUntil) {
while (i == hasBlocks) {
hasBlocksLock.synchronized {
hasBlocksLock.wait
}
}
try {
oos.writeObject(arrayOfBlocks(i))
oos.flush()
} catch {
case e: Exception => {
logInfo("sendObject had a " + e)
}
}
logInfo("Sent block: " + i + " to " + clientSocket)
}
}
}
}
}
class ChainedBroadcastFactory
extends BroadcastFactory {
def initialize(isMaster: Boolean) = ChainedBroadcast.initialize(isMaster)
def newBroadcast[T](value_ : T, isLocal: Boolean) =
new ChainedBroadcast[T](value_, isLocal)
}
private object ChainedBroadcast
extends Logging {
val values = SparkEnv.get.cache.newKeySpace()
var valueToGuidePortMap = Map[UUID, Int]()
// Random number generator
var ranGen = new Random
private var initialized = false
private var isMaster_ = false
private var trackMV: TrackMultipleValues = null
def initialize(isMaster__ : Boolean): Unit = {
synchronized {
if (!initialized) {
isMaster_ = isMaster__
if (isMaster) {
trackMV = new TrackMultipleValues
trackMV.setDaemon(true)
trackMV.start()
// TODO: Logging the following line makes the Spark framework ID not
// getting logged, cause it calls logInfo before log4j is initialized
logInfo("TrackMultipleValues started...")
}
// Initialize DfsBroadcast to be used for broadcast variable persistence
DfsBroadcast.initialize
initialized = true
}
}
}
def isMaster = isMaster_
def registerValue(uuid: UUID, guidePort: Int): Unit = {
valueToGuidePortMap.synchronized {
valueToGuidePortMap +=(uuid -> guidePort)
logInfo("New value registered with the Tracker " + valueToGuidePortMap)
}
}
def unregisterValue(uuid: UUID): Unit = {
valueToGuidePortMap.synchronized {
valueToGuidePortMap(uuid) = SourceInfo.TxOverGoToHDFS
logInfo("Value unregistered from the Tracker " + valueToGuidePortMap)
}
}
class TrackMultipleValues
extends Thread with Logging {
override def run: Unit = {
var threadPool = Utils.newDaemonCachedThreadPool()
var serverSocket: ServerSocket = null
serverSocket = new ServerSocket(Broadcast.MasterTrackerPort)
logInfo("TrackMultipleValues" + serverSocket)
try {
while (true) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout(Broadcast.TrackerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
case e: Exception => {
logInfo("TrackMultipleValues Timeout. Stopping listening...")
}
}
if (clientSocket != null) {
try {
threadPool.execute(new Thread {
override def run: Unit = {
val oos = new ObjectOutputStream(clientSocket.getOutputStream)
oos.flush()
val ois = new ObjectInputStream(clientSocket.getInputStream)
try {
val uuid = ois.readObject.asInstanceOf[UUID]
var guidePort =
if (valueToGuidePortMap.contains(uuid)) {
valueToGuidePortMap(uuid)
} else SourceInfo.TxNotStartedRetry
logInfo("TrackMultipleValues: Got new request: " + clientSocket + " for " + uuid + " : " + guidePort)
oos.writeObject(guidePort)
} catch {
case e: Exception => {
logInfo("TrackMultipleValues had a " + e)
}
} finally {
ois.close()
oos.close()
clientSocket.close()
}
}
})
} catch {
// In failure, close socket here; else, client thread will close
case ioe: IOException => clientSocket.close()
}
}
}
} finally {
serverSocket.close()
}
// Shutdown the thread pool
threadPool.shutdown()
}
}
}

View file

@ -1,4 +1,6 @@
package spark
package spark.broadcast
import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}
import java.io._
import java.net._
@ -7,7 +9,7 @@ import java.util.UUID
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path, RawLocalFileSystem}
import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}
import spark._
@serializable
class DfsBroadcast[T](@transient var value_ : T, isLocal: Boolean)

View file

@ -0,0 +1,41 @@
package spark.broadcast
import java.util.BitSet
import spark._
/**
* Used to keep and pass around information of peers involved in a broadcast
*
* CHANGED: Keep track of the blockSize for THIS broadcast variable.
* Broadcast.BlockSize is expected to be updated across different broadcasts
*/
@serializable
case class SourceInfo (val hostAddress: String,
val listenPort: Int,
val totalBlocks: Int = SourceInfo.UnusedParam,
val totalBytes: Int = SourceInfo.UnusedParam,
val blockSize: Int = Broadcast.BlockSize)
extends Comparable[SourceInfo] with Logging {
var currentLeechers = 0
var receptionFailed = false
var hasBlocks = 0
var hasBlocksBitVector: BitSet = new BitSet (totalBlocks)
// Ascending sort based on leecher count
def compareTo (o: SourceInfo): Int = (currentLeechers - o.currentLeechers)
}
/**
* Helper Object of SourceInfo for its constants
*/
object SourceInfo {
// Constants for special values of listenPort
val TxNotStartedRetry = -1
val TxOverGoToHDFS = 0
// Other constants
val StopBroadcast = -2
val UnusedParam = 0
}

View file

@ -0,0 +1,807 @@
package spark.broadcast
import java.io._
import java.net._
import java.util.{Comparator, Random, UUID}
import scala.collection.mutable.{ListBuffer, Map, Set}
import scala.math
import spark._
@serializable
class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean)
extends Broadcast[T] with Logging {
def value = value_
TreeBroadcast.synchronized {
TreeBroadcast.values.put(uuid, value_)
}
@transient var arrayOfBlocks: Array[BroadcastBlock] = null
@transient var totalBytes = -1
@transient var totalBlocks = -1
@transient var hasBlocks = 0
// CHANGED: BlockSize in the Broadcast object is expected to change over time
@transient var blockSize = Broadcast.BlockSize
@transient var listenPortLock = new Object
@transient var guidePortLock = new Object
@transient var totalBlocksLock = new Object
@transient var hasBlocksLock = new Object
@transient var listOfSources = ListBuffer[SourceInfo]()
@transient var serveMR: ServeMultipleRequests = null
@transient var guideMR: GuideMultipleRequests = null
@transient var hostAddress = Utils.localIpAddress
@transient var listenPort = -1
@transient var guidePort = -1
@transient var hasCopyInHDFS = false
@transient var stopBroadcast = false
// Must call this after all the variables have been created/initialized
if (!isLocal) {
sendBroadcast
}
def sendBroadcast(): Unit = {
logInfo("Local host address: " + hostAddress)
// Store a persistent copy in HDFS
// TODO: Turned OFF for now
// val out = new ObjectOutputStream(DfsBroadcast.openFileForWriting(uuid))
// out.writeObject(value_)
// out.close()
// TODO: Fix this at some point
hasCopyInHDFS = true
// Create a variableInfo object and store it in valueInfos
var variableInfo = Broadcast.blockifyObject(value_)
// Prepare the value being broadcasted
// TODO: Refactoring and clean-up required here
arrayOfBlocks = variableInfo.arrayOfBlocks
totalBytes = variableInfo.totalBytes
totalBlocks = variableInfo.totalBlocks
hasBlocks = variableInfo.totalBlocks
guideMR = new GuideMultipleRequests
guideMR.setDaemon(true)
guideMR.start
logInfo("GuideMultipleRequests started...")
// Must always come AFTER guideMR is created
while (guidePort == -1) {
guidePortLock.synchronized {
guidePortLock.wait
}
}
serveMR = new ServeMultipleRequests
serveMR.setDaemon(true)
serveMR.start
logInfo("ServeMultipleRequests started...")
// Must always come AFTER serveMR is created
while (listenPort == -1) {
listenPortLock.synchronized {
listenPortLock.wait
}
}
// Must always come AFTER listenPort is created
val masterSource =
SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes, blockSize)
listOfSources += masterSource
// Register with the Tracker
TreeBroadcast.registerValue(uuid, guidePort)
}
private def readObject(in: ObjectInputStream): Unit = {
in.defaultReadObject
TreeBroadcast.synchronized {
val cachedVal = TreeBroadcast.values.get(uuid)
if (cachedVal != null) {
value_ = cachedVal.asInstanceOf[T]
} else {
// Initializing everything because Master will only send null/0 values
initializeSlaveVariables
logInfo("Local host address: " + hostAddress)
serveMR = new ServeMultipleRequests
serveMR.setDaemon(true)
serveMR.start
logInfo("ServeMultipleRequests started...")
val start = System.nanoTime
val receptionSucceeded = receiveBroadcast(uuid)
// If does not succeed, then get from HDFS copy
if (receptionSucceeded) {
value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
TreeBroadcast.values.put(uuid, value_)
} else {
val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid))
value_ = fileIn.readObject.asInstanceOf[T]
TreeBroadcast.values.put(uuid, value_)
fileIn.close()
}
val time = (System.nanoTime - start) / 1e9
logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s")
}
}
}
private def initializeSlaveVariables: Unit = {
arrayOfBlocks = null
totalBytes = -1
totalBlocks = -1
hasBlocks = 0
blockSize = -1
listenPortLock = new Object
totalBlocksLock = new Object
hasBlocksLock = new Object
serveMR = null
hostAddress = Utils.localIpAddress
listenPort = -1
stopBroadcast = false
}
def getMasterListenPort(variableUUID: UUID): Int = {
var clientSocketToTracker: Socket = null
var oosTracker: ObjectOutputStream = null
var oisTracker: ObjectInputStream = null
var masterListenPort: Int = SourceInfo.TxOverGoToHDFS
var retriesLeft = Broadcast.MaxRetryCount
do {
try {
// Connect to the tracker to find out the guide
clientSocketToTracker =
new Socket(Broadcast.MasterHostAddress, Broadcast.MasterTrackerPort)
oosTracker =
new ObjectOutputStream(clientSocketToTracker.getOutputStream)
oosTracker.flush()
oisTracker =
new ObjectInputStream(clientSocketToTracker.getInputStream)
// Send UUID and receive masterListenPort
oosTracker.writeObject(uuid)
oosTracker.flush()
masterListenPort = oisTracker.readObject.asInstanceOf[Int]
} catch {
case e: Exception => {
logInfo("getMasterListenPort had a " + e)
}
} finally {
if (oisTracker != null) {
oisTracker.close()
}
if (oosTracker != null) {
oosTracker.close()
}
if (clientSocketToTracker != null) {
clientSocketToTracker.close()
}
}
retriesLeft -= 1
Thread.sleep(TreeBroadcast.ranGen.nextInt(
Broadcast.MaxKnockInterval - Broadcast.MinKnockInterval) +
Broadcast.MinKnockInterval)
} while (retriesLeft > 0 && masterListenPort == SourceInfo.TxNotStartedRetry)
logInfo("Got this guidePort from Tracker: " + masterListenPort)
return masterListenPort
}
def receiveBroadcast(variableUUID: UUID): Boolean = {
val masterListenPort = getMasterListenPort(variableUUID)
if (masterListenPort == SourceInfo.TxOverGoToHDFS ||
masterListenPort == SourceInfo.TxNotStartedRetry) {
// TODO: SourceInfo.TxNotStartedRetry is not really in use because we go
// to HDFS anyway when receiveBroadcast returns false
return false
}
// Wait until hostAddress and listenPort are created by the
// ServeMultipleRequests thread
while (listenPort == -1) {
listenPortLock.synchronized {
listenPortLock.wait
}
}
var clientSocketToMaster: Socket = null
var oosMaster: ObjectOutputStream = null
var oisMaster: ObjectInputStream = null
// Connect and receive broadcast from the specified source, retrying the
// specified number of times in case of failures
var retriesLeft = Broadcast.MaxRetryCount
do {
// Connect to Master and send this worker's Information
clientSocketToMaster =
new Socket(Broadcast.MasterHostAddress, masterListenPort)
// TODO: Guiding object connection is reusable
oosMaster =
new ObjectOutputStream(clientSocketToMaster.getOutputStream)
oosMaster.flush()
oisMaster =
new ObjectInputStream(clientSocketToMaster.getInputStream)
logInfo("Connected to Master's guiding object")
// Send local source information
oosMaster.writeObject(SourceInfo(hostAddress, listenPort))
oosMaster.flush()
// Receive source information from Master
var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo]
totalBlocks = sourceInfo.totalBlocks
arrayOfBlocks = new Array[BroadcastBlock](totalBlocks)
totalBlocksLock.synchronized {
totalBlocksLock.notifyAll
}
totalBytes = sourceInfo.totalBytes
blockSize = sourceInfo.blockSize
logInfo("Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort)
val start = System.nanoTime
val receptionSucceeded = receiveSingleTransmission(sourceInfo)
val time = (System.nanoTime - start) / 1e9
// Updating some statistics in sourceInfo. Master will be using them later
if (!receptionSucceeded) {
sourceInfo.receptionFailed = true
}
// Send back statistics to the Master
oosMaster.writeObject(sourceInfo)
if (oisMaster != null) {
oisMaster.close()
}
if (oosMaster != null) {
oosMaster.close()
}
if (clientSocketToMaster != null) {
clientSocketToMaster.close()
}
retriesLeft -= 1
} while (retriesLeft > 0 && hasBlocks < totalBlocks)
return (hasBlocks == totalBlocks)
}
// Tries to receive broadcast from the source and returns Boolean status.
// This might be called multiple times to retry a defined number of times.
private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = {
var clientSocketToSource: Socket = null
var oosSource: ObjectOutputStream = null
var oisSource: ObjectInputStream = null
var receptionSucceeded = false
try {
// Connect to the source to get the object itself
clientSocketToSource =
new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
oosSource =
new ObjectOutputStream(clientSocketToSource.getOutputStream)
oosSource.flush()
oisSource =
new ObjectInputStream(clientSocketToSource.getInputStream)
logInfo("Inside receiveSingleTransmission")
logInfo("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks)
// Send the range
oosSource.writeObject((hasBlocks, totalBlocks))
oosSource.flush()
for (i <- hasBlocks until totalBlocks) {
val recvStartTime = System.currentTimeMillis
val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
val receptionTime = (System.currentTimeMillis - recvStartTime)
logInfo("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.")
arrayOfBlocks(hasBlocks) = bcBlock
hasBlocks += 1
// Set to true if at least one block is received
receptionSucceeded = true
hasBlocksLock.synchronized {
hasBlocksLock.notifyAll
}
}
} catch {
case e: Exception => {
logInfo("receiveSingleTransmission had a " + e)
}
} finally {
if (oisSource != null) {
oisSource.close()
}
if (oosSource != null) {
oosSource.close()
}
if (clientSocketToSource != null) {
clientSocketToSource.close()
}
}
return receptionSucceeded
}
class GuideMultipleRequests
extends Thread with Logging {
// Keep track of sources that have completed reception
private var setOfCompletedSources = Set[SourceInfo]()
override def run: Unit = {
var threadPool = Utils.newDaemonCachedThreadPool()
var serverSocket: ServerSocket = null
serverSocket = new ServerSocket(0)
guidePort = serverSocket.getLocalPort
logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort)
guidePortLock.synchronized {
guidePortLock.notifyAll
}
try {
// Don't stop until there is a copy in HDFS
while (!stopBroadcast || !hasCopyInHDFS) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
case e: Exception => {
logInfo("GuideMultipleRequests Timeout.")
// Stop broadcast if at least one worker has connected and
// everyone connected so far are done. Comparing with
// listOfSources.size - 1, because it includes the Guide itself
if (listOfSources.size > 1 &&
setOfCompletedSources.size == listOfSources.size - 1) {
stopBroadcast = true
}
}
}
if (clientSocket != null) {
logInfo("Guide: Accepted new client connection: " + clientSocket)
try {
threadPool.execute(new GuideSingleRequest(clientSocket))
} catch {
// In failure, close() the socket here; else, the thread will close() it
case ioe: IOException => clientSocket.close()
}
}
}
logInfo("Sending stopBroadcast notifications...")
sendStopBroadcastNotifications
TreeBroadcast.unregisterValue(uuid)
} finally {
if (serverSocket != null) {
logInfo("GuideMultipleRequests now stopping...")
serverSocket.close()
}
}
// Shutdown the thread pool
threadPool.shutdown
}
private def sendStopBroadcastNotifications: Unit = {
listOfSources.synchronized {
var listIter = listOfSources.iterator
while (listIter.hasNext) {
var sourceInfo = listIter.next
var guideSocketToSource: Socket = null
var gosSource: ObjectOutputStream = null
var gisSource: ObjectInputStream = null
try {
// Connect to the source
guideSocketToSource =
new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
gosSource =
new ObjectOutputStream(guideSocketToSource.getOutputStream)
gosSource.flush()
gisSource =
new ObjectInputStream(guideSocketToSource.getInputStream)
// Send stopBroadcast signal. Range = SourceInfo.StopBroadcast*2
gosSource.writeObject((SourceInfo.StopBroadcast,
SourceInfo.StopBroadcast))
gosSource.flush()
} catch {
case e: Exception => {
logInfo("sendStopBroadcastNotifications had a " + e)
}
} finally {
if (gisSource != null) {
gisSource.close()
}
if (gosSource != null) {
gosSource.close()
}
if (guideSocketToSource != null) {
guideSocketToSource.close()
}
}
}
}
}
class GuideSingleRequest(val clientSocket: Socket)
extends Thread with Logging {
private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
oos.flush()
private val ois = new ObjectInputStream(clientSocket.getInputStream)
private var selectedSourceInfo: SourceInfo = null
private var thisWorkerInfo:SourceInfo = null
override def run: Unit = {
try {
logInfo("new GuideSingleRequest is running")
// Connecting worker is sending in its hostAddress and listenPort it will
// be listening to. Other fields are invalid (SourceInfo.UnusedParam)
var sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
listOfSources.synchronized {
// Select a suitable source and send it back to the worker
selectedSourceInfo = selectSuitableSource(sourceInfo)
logInfo("Sending selectedSourceInfo: " + selectedSourceInfo)
oos.writeObject(selectedSourceInfo)
oos.flush()
// Add this new (if it can finish) source to the list of sources
thisWorkerInfo = SourceInfo(sourceInfo.hostAddress,
sourceInfo.listenPort, totalBlocks, totalBytes, blockSize)
logInfo("Adding possible new source to listOfSources: " + thisWorkerInfo)
listOfSources += thisWorkerInfo
}
// Wait till the whole transfer is done. Then receive and update source
// statistics in listOfSources
sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
listOfSources.synchronized {
// This should work since SourceInfo is a case class
assert(listOfSources.contains(selectedSourceInfo))
// Remove first
listOfSources = listOfSources - selectedSourceInfo
// TODO: Removing a source based on just one failure notification!
// Update sourceInfo and put it back in, IF reception succeeded
if (!sourceInfo.receptionFailed) {
// Add thisWorkerInfo to sources that have completed reception
setOfCompletedSources.synchronized {
setOfCompletedSources += thisWorkerInfo
}
selectedSourceInfo.currentLeechers -= 1
// Put it back
listOfSources += selectedSourceInfo
}
}
} catch {
// If something went wrong, e.g., the worker at the other end died etc.
// then close() everything up
case e: Exception => {
// Assuming that exception caused due to receiver worker failure.
// Remove failed worker from listOfSources and update leecherCount of
// corresponding source worker
listOfSources.synchronized {
if (selectedSourceInfo != null) {
// Remove first
listOfSources = listOfSources - selectedSourceInfo
// Update leecher count and put it back in
selectedSourceInfo.currentLeechers -= 1
listOfSources += selectedSourceInfo
}
// Remove thisWorkerInfo
if (listOfSources != null) {
listOfSources = listOfSources - thisWorkerInfo
}
}
}
} finally {
ois.close()
oos.close()
clientSocket.close()
}
}
// FIXME: Caller must have a synchronized block on listOfSources
// FIXME: If a worker fails to get the broadcasted variable from a source
// and comes back to the Master, this function might choose the worker
// itself as a source to create a dependency cycle (this worker was put
// into listOfSources as a streming source when it first arrived). The
// length of this cycle can be arbitrarily long.
private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = {
// Select one with the most leechers. This will level-wise fill the tree
var maxLeechers = -1
var selectedSource: SourceInfo = null
listOfSources.foreach { source =>
if (source != skipSourceInfo &&
source.currentLeechers < Broadcast.MaxDegree &&
source.currentLeechers > maxLeechers) {
selectedSource = source
maxLeechers = source.currentLeechers
}
}
// Update leecher count
selectedSource.currentLeechers += 1
return selectedSource
}
}
}
class ServeMultipleRequests
extends Thread with Logging {
override def run: Unit = {
var threadPool = Utils.newDaemonCachedThreadPool()
var serverSocket: ServerSocket = null
serverSocket = new ServerSocket(0)
listenPort = serverSocket.getLocalPort
logInfo("ServeMultipleRequests started with " + serverSocket)
listenPortLock.synchronized {
listenPortLock.notifyAll
}
try {
while (!stopBroadcast) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
case e: Exception => {
logInfo("ServeMultipleRequests Timeout.")
}
}
if (clientSocket != null) {
logInfo("Serve: Accepted new client connection: " + clientSocket)
try {
threadPool.execute(new ServeSingleRequest(clientSocket))
} catch {
// In failure, close() socket here; else, the thread will close() it
case ioe: IOException => clientSocket.close()
}
}
}
} finally {
if (serverSocket != null) {
logInfo("ServeMultipleRequests now stopping...")
serverSocket.close()
}
}
// Shutdown the thread pool
threadPool.shutdown
}
class ServeSingleRequest(val clientSocket: Socket)
extends Thread with Logging {
private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
oos.flush()
private val ois = new ObjectInputStream(clientSocket.getInputStream)
private var sendFrom = 0
private var sendUntil = totalBlocks
override def run: Unit = {
try {
logInfo("new ServeSingleRequest is running")
// Receive range to send
var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)]
sendFrom = rangeToSend._1
sendUntil = rangeToSend._2
if (sendFrom == SourceInfo.StopBroadcast &&
sendUntil == SourceInfo.StopBroadcast) {
stopBroadcast = true
} else {
// Carry on
sendObject
}
} catch {
// If something went wrong, e.g., the worker at the other end died etc.
// then close() everything up
case e: Exception => {
logInfo("ServeSingleRequest had a " + e)
}
} finally {
logInfo("ServeSingleRequest is closing streams and sockets")
ois.close()
oos.close()
clientSocket.close()
}
}
private def sendObject: Unit = {
// Wait till receiving the SourceInfo from Master
while (totalBlocks == -1) {
totalBlocksLock.synchronized {
totalBlocksLock.wait
}
}
for (i <- sendFrom until sendUntil) {
while (i == hasBlocks) {
hasBlocksLock.synchronized {
hasBlocksLock.wait
}
}
try {
oos.writeObject(arrayOfBlocks(i))
oos.flush()
} catch {
case e: Exception => {
logInfo("sendObject had a " + e)
}
}
logInfo("Sent block: " + i + " to " + clientSocket)
}
}
}
}
}
class TreeBroadcastFactory
extends BroadcastFactory {
def initialize(isMaster: Boolean) = TreeBroadcast.initialize(isMaster)
def newBroadcast[T](value_ : T, isLocal: Boolean) =
new TreeBroadcast[T](value_, isLocal)
}
private object TreeBroadcast
extends Logging {
val values = SparkEnv.get.cache.newKeySpace()
var valueToGuidePortMap = Map[UUID, Int]()
// Random number generator
var ranGen = new Random
private var initialized = false
private var isMaster_ = false
private var trackMV: TrackMultipleValues = null
private var MaxDegree_ : Int = 2
def initialize(isMaster__ : Boolean): Unit = {
synchronized {
if (!initialized) {
isMaster_ = isMaster__
if (isMaster) {
trackMV = new TrackMultipleValues
trackMV.setDaemon(true)
trackMV.start
// TODO: Logging the following line makes the Spark framework ID not
// getting logged, cause it calls logInfo before log4j is initialized
logInfo("TrackMultipleValues started...")
}
// Initialize DfsBroadcast to be used for broadcast variable persistence
DfsBroadcast.initialize
initialized = true
}
}
}
def isMaster = isMaster_
def registerValue(uuid: UUID, guidePort: Int): Unit = {
valueToGuidePortMap.synchronized {
valueToGuidePortMap += (uuid -> guidePort)
logInfo("New value registered with the Tracker " + valueToGuidePortMap)
}
}
def unregisterValue(uuid: UUID): Unit = {
valueToGuidePortMap.synchronized {
valueToGuidePortMap(uuid) = SourceInfo.TxOverGoToHDFS
logInfo("Value unregistered from the Tracker " + valueToGuidePortMap)
}
}
class TrackMultipleValues
extends Thread with Logging {
override def run: Unit = {
var threadPool = Utils.newDaemonCachedThreadPool()
var serverSocket: ServerSocket = null
serverSocket = new ServerSocket(Broadcast.MasterTrackerPort)
logInfo("TrackMultipleValues" + serverSocket)
try {
while (true) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout(Broadcast.TrackerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
case e: Exception => {
logInfo("TrackMultipleValues Timeout. Stopping listening...")
}
}
if (clientSocket != null) {
try {
threadPool.execute(new Thread {
override def run: Unit = {
val oos = new ObjectOutputStream(clientSocket.getOutputStream)
oos.flush()
val ois = new ObjectInputStream(clientSocket.getInputStream)
try {
val uuid = ois.readObject.asInstanceOf[UUID]
var guidePort =
if (valueToGuidePortMap.contains(uuid)) {
valueToGuidePortMap(uuid)
} else SourceInfo.TxNotStartedRetry
logInfo("TrackMultipleValues: Got new request: " + clientSocket + " for " + uuid + " : " + guidePort)
oos.writeObject(guidePort)
} catch {
case e: Exception => {
logInfo("TrackMultipleValues had a " + e)
}
} finally {
ois.close()
oos.close()
clientSocket.close()
}
}
})
} catch {
// In failure, close() socket here; else, client thread will close()
case ioe: IOException => clientSocket.close()
}
}
}
} finally {
serverSocket.close()
}
// Shutdown the thread pool
threadPool.shutdown
}
}
}

View file

@ -5,9 +5,10 @@ import spark.SparkContext
object BroadcastTest {
def main(args: Array[String]) {
if (args.length == 0) {
System.err.println("Usage: BroadcastTest <host> [<slices>]")
System.err.println("Usage: BroadcastTest <host> [<slices>] [numElem]")
System.exit(1)
}
val spark = new SparkContext(args(0), "Broadcast Test")
val slices = if (args.length > 1) args(1).toInt else 2
val num = if (args.length > 2) args(2).toInt else 1000000
@ -16,14 +17,8 @@ object BroadcastTest {
for (i <- 0 until arr1.length)
arr1(i) = i
// var arr2 = new Array[Int](num * 2)
// for (i <- 0 until arr2.length)
// arr2(i) = i
val barr1 = spark.broadcast(arr1)
// val barr2 = spark.broadcast(arr2)
spark.parallelize(1 to 10, slices).foreach {
// i => println(barr1.value.size + barr2.value.size)
i => println(barr1.value.size)
}
}

View file

@ -0,0 +1,37 @@
package spark.examples
import spark.SparkContext
import spark.SparkContext._
import java.util.Random
object GroupByTest {
def main(args: Array[String]) {
if (args.length == 0) {
System.err.println("Usage: GroupByTest <host> [numMappers] [numKVPairs] [KeySize] [numReducers]")
System.exit(1)
}
var numMappers = if (args.length > 1) args(1).toInt else 2
var numKVPairs = if (args.length > 2) args(2).toInt else 1000
var valSize = if (args.length > 3) args(3).toInt else 1000
var numReducers = if (args.length > 4) args(4).toInt else numMappers
val sc = new SparkContext(args(0), "GroupBy Test")
val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p =>
val ranGen = new Random
var arr1 = new Array[(Int, Array[Byte])](numKVPairs)
for (i <- 0 until numKVPairs) {
val byteArr = new Array[Byte](valSize)
ranGen.nextBytes(byteArr)
arr1(i) = (ranGen.nextInt(Int.MaxValue), byteArr)
}
arr1
}.cache
// Enforce that everything has been calculated and in cache
pairs1.count
println(pairs1.groupByKey(numReducers).count)
}
}

View file

@ -0,0 +1,30 @@
package spark.examples
import spark.SparkContext
object MultiBroadcastTest {
def main(args: Array[String]) {
if (args.length == 0) {
System.err.println("Usage: BroadcastTest <host> [<slices>] [numElem]")
System.exit(1)
}
val spark = new SparkContext(args(0), "Broadcast Test")
val slices = if (args.length > 1) args(1).toInt else 2
val num = if (args.length > 2) args(2).toInt else 1000000
var arr1 = new Array[Int](num)
for (i <- 0 until arr1.length)
arr1(i) = i
var arr2 = new Array[Int](num)
for (i <- 0 until arr2.length)
arr2(i) = i
val barr1 = spark.broadcast(arr1)
val barr2 = spark.broadcast(arr2)
spark.parallelize(1 to 10, slices).foreach {
i => println(barr1.value.size + barr2.value.size)
}
}
}

View file

@ -0,0 +1,51 @@
package spark.examples
import spark.SparkContext
import spark.SparkContext._
import java.util.Random
object SimpleSkewedGroupByTest {
def main(args: Array[String]) {
if (args.length == 0) {
System.err.println("Usage: SimpleSkewedGroupByTest <host> " +
"[numMappers] [numKVPairs] [valSize] [numReducers] [ratio]")
System.exit(1)
}
var numMappers = if (args.length > 1) args(1).toInt else 2
var numKVPairs = if (args.length > 2) args(2).toInt else 1000
var valSize = if (args.length > 3) args(3).toInt else 1000
var numReducers = if (args.length > 4) args(4).toInt else numMappers
var ratio = if (args.length > 5) args(5).toInt else 5.0
val sc = new SparkContext(args(0), "GroupBy Test")
val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p =>
val ranGen = new Random
var result = new Array[(Int, Array[Byte])](numKVPairs)
for (i <- 0 until numKVPairs) {
val byteArr = new Array[Byte](valSize)
ranGen.nextBytes(byteArr)
val offset = ranGen.nextInt(1000) * numReducers
if (ranGen.nextDouble < ratio / (numReducers + ratio - 1)) {
// give ratio times higher chance of generating key 0 (for reducer 0)
result(i) = (offset, byteArr)
} else {
// generate a key for one of the other reducers
val key = 1 + ranGen.nextInt(numReducers-1) + offset
result(i) = (key, byteArr)
}
}
result
}.cache
// Enforce that everything has been calculated and in cache
pairs1.count
println("RESULT: " + pairs1.groupByKey(numReducers).count)
// Print how many keys each reducer got (for debugging)
//println("RESULT: " + pairs1.groupByKey(numReducers)
// .map{case (k,v) => (k, v.size)}
// .collectAsMap)
}
}

View file

@ -0,0 +1,41 @@
package spark.examples
import spark.SparkContext
import spark.SparkContext._
import java.util.Random
object SkewedGroupByTest {
def main(args: Array[String]) {
if (args.length == 0) {
System.err.println("Usage: GroupByTest <host> [numMappers] [numKVPairs] [KeySize] [numReducers]")
System.exit(1)
}
var numMappers = if (args.length > 1) args(1).toInt else 2
var numKVPairs = if (args.length > 2) args(2).toInt else 1000
var valSize = if (args.length > 3) args(3).toInt else 1000
var numReducers = if (args.length > 4) args(4).toInt else numMappers
val sc = new SparkContext(args(0), "GroupBy Test")
val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p =>
val ranGen = new Random
// map output sizes lineraly increase from the 1st to the last
numKVPairs = (1. * (p + 1) / numMappers * numKVPairs).toInt
var arr1 = new Array[(Int, Array[Byte])](numKVPairs)
for (i <- 0 until numKVPairs) {
val byteArr = new Array[Byte](valSize)
ranGen.nextBytes(byteArr)
arr1(i) = (ranGen.nextInt(Int.MaxValue), byteArr)
}
arr1
}.cache
// Enforce that everything has been calculated and in cache
pairs1.count
println(pairs1.groupByKey(numReducers).count)
}
}