Merge branch 'master' into scala-2.9

Conflicts:
	repl/src/main/scala/spark/repl/SparkInterpreterLoop.scala
This commit is contained in:
Matei Zaharia 2011-06-26 19:22:27 -07:00
commit bae8a97968
24 changed files with 3621 additions and 2330 deletions

48
README
View file

@ -1,48 +0,0 @@
ONLINE DOCUMENTATION
You can find the latest Spark documentation, including a programming guide,
on the project wiki at http://github.com/mesos/spark/wiki. This file only
contains basic setup instructions.
BUILDING
Spark requires Scala 2.8. This version has been tested with 2.8.1.final.
The project is built using Simple Build Tool (SBT), which is packaged with it.
To build Spark and its example programs, run sbt/sbt update compile.
To run Spark, you will need to have Scala's bin in your $PATH, or you
will need to set the SCALA_HOME environment variable to point to where
you've installed Scala. Scala must be accessible through one of these
methods on Mesos slave nodes as well as on the master.
To run one of the examples, use ./run <class> <params>. For example,
./run spark.examples.SparkLR will run the Logistic Regression example.
Each of the example programs prints usage help if no params are given.
All of the Spark samples take a <host> parameter that is the Mesos master
to connect to. This can be a Mesos URL, or "local" to run locally with one
thread, or "local[N]" to run locally with N threads.
CONFIGURATION
Spark can be configured through two files: conf/java-opts and conf/spark-env.sh.
In java-opts, you can add flags to be passed to the JVM when running Spark.
In spark-env.sh, you can set any environment variables you wish to be available
when running Spark programs, such as PATH, SCALA_HOME, etc. There are also
several Spark-specific variables you can set:
- SPARK_CLASSPATH: Extra entries to be added to the classpath, separated by ":".
- SPARK_MEM: Memory for Spark to use, in the format used by java's -Xmx option
(for example, 200m meams 200 MB, 1g means 1 GB, etc).
- SPARK_LIBRARY_PATH: Extra entries to add to java.library.path for locating
shared libraries.
- SPARK_JAVA_OPTS: Extra options to pass to JVM.
Note that spark-env.sh must be a shell script (it must be executable and start
with a #! header to specify the shell to use).

63
README.md Normal file
View file

@ -0,0 +1,63 @@
# Spark
Lightning-Fast Cluster Computing - <http://www.spark-project.org/>
## Online Documentation
You can find the latest Spark documentation, including a programming
guide, on the project wiki at <http://github.com/mesos/spark/wiki>. This
file only contains basic setup instructions.
## Building
Spark requires Scala 2.8. This version has been tested with 2.8.1.final.
Experimental support for Scala 2.9 is available in the `scala-2.9` branch.
The project is built using Simple Build Tool (SBT), which is packaged with it.
To build Spark and its example programs, run:
sbt/sbt update compile
To run Spark, you will need to have Scala's bin in your `PATH`, or you
will need to set the `SCALA_HOME` environment variable to point to where
you've installed Scala. Scala must be accessible through one of these
methods on Mesos slave nodes as well as on the master.
To run one of the examples, use `./run <class> <params>`. For example:
./run spark.examples.SparkLR local[2]
will run the Logistic Regression example locally on 2 CPUs.
Each of the example programs prints usage help if no params are given.
All of the Spark samples take a `<host>` parameter that is the Mesos master
to connect to. This can be a Mesos URL, or "local" to run locally with one
thread, or "local[N]" to run locally with N threads.
## Configuration
Spark can be configured through two files: `conf/java-opts` and
`conf/spark-env.sh`.
In `java-opts`, you can add flags to be passed to the JVM when running Spark.
In `spark-env.sh`, you can set any environment variables you wish to be available
when running Spark programs, such as `PATH`, `SCALA_HOME`, etc. There are also
several Spark-specific variables you can set:
- `SPARK_CLASSPATH`: Extra entries to be added to the classpath, separated by ":".
- `SPARK_MEM`: Memory for Spark to use, in the format used by java's `-Xmx`
option (for example, `-Xmx200m` means 200 MB, `-Xmx1g` means 1 GB, etc).
- `SPARK_LIBRARY_PATH`: Extra entries to add to `java.library.path` for locating
shared libraries.
- `SPARK_JAVA_OPTS`: Extra options to pass to JVM.
Note that `spark-env.sh` must be a shell script (it must be executable and start
with a `#!` header to specify the shell to use).

File diff suppressed because it is too large Load diff

View file

@ -1,140 +0,0 @@
package spark
import java.util.{BitSet, UUID}
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
@serializable
trait Broadcast[T] {
val uuid = UUID.randomUUID
def value: T
// We cannot have an abstract readObject here due to some weird issues with
// readObject having to be 'private' in sub-classes. Possibly a Scala bug!
override def toString = "spark.Broadcast(" + uuid + ")"
}
trait BroadcastFactory {
def initialize (isMaster: Boolean): Unit
def newBroadcast[T] (value_ : T, isLocal: Boolean): Broadcast[T]
}
private object Broadcast
extends Logging {
private var initialized = false
private var broadcastFactory: BroadcastFactory = null
// Called by SparkContext or Executor before using Broadcast
def initialize (isMaster: Boolean): Unit = synchronized {
if (!initialized) {
val broadcastFactoryClass = System.getProperty("spark.broadcast.factory",
"spark.DfsBroadcastFactory")
val booleanArgs = Array[AnyRef] (isMaster.asInstanceOf[AnyRef])
broadcastFactory =
Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
// Initialize appropriate BroadcastFactory and BroadcastObject
broadcastFactory.initialize(isMaster)
initialized = true
}
}
def getBroadcastFactory: BroadcastFactory = {
if (broadcastFactory == null) {
throw new SparkException ("Broadcast.getBroadcastFactory called before initialize")
}
broadcastFactory
}
// Returns a standard ThreadFactory except all threads are daemons
private def newDaemonThreadFactory: ThreadFactory = {
new ThreadFactory {
def newThread(r: Runnable): Thread = {
var t = Executors.defaultThreadFactory.newThread (r)
t.setDaemon (true)
return t
}
}
}
// Wrapper over newCachedThreadPool
def newDaemonCachedThreadPool: ThreadPoolExecutor = {
var threadPool =
Executors.newCachedThreadPool.asInstanceOf[ThreadPoolExecutor]
threadPool.setThreadFactory (newDaemonThreadFactory)
return threadPool
}
// Wrapper over newFixedThreadPool
def newDaemonFixedThreadPool (nThreads: Int): ThreadPoolExecutor = {
var threadPool =
Executors.newFixedThreadPool (nThreads).asInstanceOf[ThreadPoolExecutor]
threadPool.setThreadFactory (newDaemonThreadFactory)
return threadPool
}
}
@serializable
case class SourceInfo (val hostAddress: String, val listenPort: Int,
val totalBlocks: Int, val totalBytes: Int)
extends Comparable[SourceInfo] with Logging {
var currentLeechers = 0
var receptionFailed = false
var hasBlocks = 0
var hasBlocksBitVector: BitSet = new BitSet (totalBlocks)
// Ascending sort based on leecher count
def compareTo (o: SourceInfo): Int = (currentLeechers - o.currentLeechers)}
object SourceInfo {
// Constants for special values of listenPort
val TxNotStartedRetry = -1
val TxOverGoToHDFS = 0
// Other constants
val StopBroadcast = -2
val UnusedParam = 0
}
@serializable
case class BroadcastBlock (val blockID: Int, val byteArray: Array[Byte]) { }
@serializable
case class VariableInfo (@transient val arrayOfBlocks : Array[BroadcastBlock],
val totalBlocks: Int, val totalBytes: Int) {
@transient var hasBlocks = 0
}
@serializable
class SpeedTracker {
// Mapping 'source' to '(totalTime, numBlocks)'
private var sourceToSpeedMap = Map[SourceInfo, (Long, Int)] ()
def addDataPoint (srcInfo: SourceInfo, timeInMillis: Long): Unit = {
sourceToSpeedMap.synchronized {
if (!sourceToSpeedMap.contains(srcInfo)) {
sourceToSpeedMap += (srcInfo -> (timeInMillis, 1))
} else {
val tTnB = sourceToSpeedMap (srcInfo)
sourceToSpeedMap += (srcInfo -> (tTnB._1 + timeInMillis, tTnB._2 + 1))
}
}
}
def getTimePerBlock (srcInfo: SourceInfo): Double = {
sourceToSpeedMap.synchronized {
val tTnB = sourceToSpeedMap (srcInfo)
return tTnB._1 / tTnB._2
}
}
override def toString = sourceToSpeedMap.toString
}

View file

@ -1,873 +0,0 @@
package spark
import java.io._
import java.net._
import java.util.{Comparator, PriorityQueue, Random, UUID}
import scala.collection.mutable.{Map, Set}
@serializable
class ChainedBroadcast[T] (@transient var value_ : T, isLocal: Boolean)
extends Broadcast[T] with Logging {
def value = value_
ChainedBroadcast.synchronized {
ChainedBroadcast.values.put (uuid, value_)
}
@transient var arrayOfBlocks: Array[BroadcastBlock] = null
@transient var totalBytes = -1
@transient var totalBlocks = -1
@transient var hasBlocks = 0
@transient var listenPortLock = new Object
@transient var guidePortLock = new Object
@transient var totalBlocksLock = new Object
@transient var hasBlocksLock = new Object
@transient var pqOfSources = new PriorityQueue[SourceInfo]
@transient var serveMR: ServeMultipleRequests = null
@transient var guideMR: GuideMultipleRequests = null
@transient var hostAddress = InetAddress.getLocalHost.getHostAddress
@transient var listenPort = -1
@transient var guidePort = -1
@transient var hasCopyInHDFS = false
@transient var stopBroadcast = false
// Must call this after all the variables have been created/initialized
if (!isLocal) {
sendBroadcast
}
def sendBroadcast (): Unit = {
logInfo ("Local host address: " + hostAddress)
// Store a persistent copy in HDFS
// TODO: Turned OFF for now
// val out = new ObjectOutputStream (DfsBroadcast.openFileForWriting(uuid))
// out.writeObject (value_)
// out.close
// TODO: Fix this at some point
hasCopyInHDFS = true
// Create a variableInfo object and store it in valueInfos
var variableInfo = blockifyObject (value_, ChainedBroadcast.BlockSize)
guideMR = new GuideMultipleRequests
guideMR.setDaemon (true)
guideMR.start
logInfo ("GuideMultipleRequests started...")
serveMR = new ServeMultipleRequests
serveMR.setDaemon (true)
serveMR.start
logInfo ("ServeMultipleRequests started...")
// Prepare the value being broadcasted
// TODO: Refactoring and clean-up required here
arrayOfBlocks = variableInfo.arrayOfBlocks
totalBytes = variableInfo.totalBytes
totalBlocks = variableInfo.totalBlocks
hasBlocks = variableInfo.totalBlocks
while (listenPort == -1) {
listenPortLock.synchronized {
listenPortLock.wait
}
}
pqOfSources = new PriorityQueue[SourceInfo]
val masterSource_0 =
SourceInfo (hostAddress, listenPort, totalBlocks, totalBytes)
pqOfSources.add (masterSource_0)
// Register with the Tracker
while (guidePort == -1) {
guidePortLock.synchronized {
guidePortLock.wait
}
}
ChainedBroadcast.registerValue (uuid, guidePort)
}
private def readObject (in: ObjectInputStream): Unit = {
in.defaultReadObject
ChainedBroadcast.synchronized {
val cachedVal = ChainedBroadcast.values.get (uuid)
if (cachedVal != null) {
value_ = cachedVal.asInstanceOf[T]
} else {
// Initializing everything because Master will only send null/0 values
initializeSlaveVariables
logInfo ("Local host address: " + hostAddress)
serveMR = new ServeMultipleRequests
serveMR.setDaemon (true)
serveMR.start
logInfo ("ServeMultipleRequests started...")
val start = System.nanoTime
val receptionSucceeded = receiveBroadcast (uuid)
// If does not succeed, then get from HDFS copy
if (receptionSucceeded) {
value_ = unBlockifyObject[T]
ChainedBroadcast.values.put (uuid, value_)
} else {
val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid))
value_ = fileIn.readObject.asInstanceOf[T]
ChainedBroadcast.values.put(uuid, value_)
fileIn.close
}
val time = (System.nanoTime - start) / 1e9
logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s")
}
}
}
private def initializeSlaveVariables: Unit = {
arrayOfBlocks = null
totalBytes = -1
totalBlocks = -1
hasBlocks = 0
listenPortLock = new Object
totalBlocksLock = new Object
hasBlocksLock = new Object
serveMR = null
hostAddress = InetAddress.getLocalHost.getHostAddress
listenPort = -1
stopBroadcast = false
}
private def blockifyObject (obj: T, blockSize: Int): VariableInfo = {
val baos = new ByteArrayOutputStream
val oos = new ObjectOutputStream (baos)
oos.writeObject (obj)
oos.close
baos.close
val byteArray = baos.toByteArray
val bais = new ByteArrayInputStream (byteArray)
var blockNum = (byteArray.length / blockSize)
if (byteArray.length % blockSize != 0)
blockNum += 1
var retVal = new Array[BroadcastBlock] (blockNum)
var blockID = 0
for (i <- 0 until (byteArray.length, blockSize)) {
val thisBlockSize = math.min (blockSize, byteArray.length - i)
var tempByteArray = new Array[Byte] (thisBlockSize)
val hasRead = bais.read (tempByteArray, 0, thisBlockSize)
retVal (blockID) = new BroadcastBlock (blockID, tempByteArray)
blockID += 1
}
bais.close
var variableInfo = VariableInfo (retVal, blockNum, byteArray.length)
variableInfo.hasBlocks = blockNum
return variableInfo
}
private def unBlockifyObject[A]: A = {
var retByteArray = new Array[Byte] (totalBytes)
for (i <- 0 until totalBlocks) {
System.arraycopy (arrayOfBlocks(i).byteArray, 0, retByteArray,
i * ChainedBroadcast.BlockSize, arrayOfBlocks(i).byteArray.length)
}
byteArrayToObject (retByteArray)
}
private def byteArrayToObject[A] (bytes: Array[Byte]): A = {
val in = new ObjectInputStream (new ByteArrayInputStream (bytes))
val retVal = in.readObject.asInstanceOf[A]
in.close
return retVal
}
def getMasterListenPort (variableUUID: UUID): Int = {
var clientSocketToTracker: Socket = null
var oosTracker: ObjectOutputStream = null
var oisTracker: ObjectInputStream = null
var masterListenPort: Int = SourceInfo.TxOverGoToHDFS
var retriesLeft = ChainedBroadcast.MaxRetryCount
do {
try {
// Connect to the tracker to find out the guide
val clientSocketToTracker =
new Socket(ChainedBroadcast.MasterHostAddress, ChainedBroadcast.MasterTrackerPort)
val oosTracker =
new ObjectOutputStream (clientSocketToTracker.getOutputStream)
oosTracker.flush
val oisTracker =
new ObjectInputStream (clientSocketToTracker.getInputStream)
// Send UUID and receive masterListenPort
oosTracker.writeObject (uuid)
oosTracker.flush
masterListenPort = oisTracker.readObject.asInstanceOf[Int]
} catch {
case e: Exception => {
logInfo ("getMasterListenPort had a " + e)
}
} finally {
if (oisTracker != null) {
oisTracker.close
}
if (oosTracker != null) {
oosTracker.close
}
if (clientSocketToTracker != null) {
clientSocketToTracker.close
}
}
retriesLeft -= 1
Thread.sleep (ChainedBroadcast.ranGen.nextInt (
ChainedBroadcast.MaxKnockInterval - ChainedBroadcast.MinKnockInterval) +
ChainedBroadcast.MinKnockInterval)
} while (retriesLeft > 0 && masterListenPort == SourceInfo.TxNotStartedRetry)
logInfo ("Got this guidePort from Tracker: " + masterListenPort)
return masterListenPort
}
def receiveBroadcast (variableUUID: UUID): Boolean = {
val masterListenPort = getMasterListenPort (variableUUID)
if (masterListenPort == SourceInfo.TxOverGoToHDFS ||
masterListenPort == SourceInfo.TxNotStartedRetry) {
// TODO: SourceInfo.TxNotStartedRetry is not really in use because we go
// to HDFS anyway when receiveBroadcast returns false
return false
}
// Wait until hostAddress and listenPort are created by the
// ServeMultipleRequests thread
while (listenPort == -1) {
listenPortLock.synchronized {
listenPortLock.wait
}
}
var clientSocketToMaster: Socket = null
var oosMaster: ObjectOutputStream = null
var oisMaster: ObjectInputStream = null
// Connect and receive broadcast from the specified source, retrying the
// specified number of times in case of failures
var retriesLeft = ChainedBroadcast.MaxRetryCount
do {
// Connect to Master and send this worker's Information
clientSocketToMaster =
new Socket(ChainedBroadcast.MasterHostAddress, masterListenPort)
// TODO: Guiding object connection is reusable
oosMaster =
new ObjectOutputStream (clientSocketToMaster.getOutputStream)
oosMaster.flush
oisMaster =
new ObjectInputStream (clientSocketToMaster.getInputStream)
logInfo ("Connected to Master's guiding object")
// Send local source information
oosMaster.writeObject(SourceInfo (hostAddress, listenPort,
SourceInfo.UnusedParam, SourceInfo.UnusedParam))
oosMaster.flush
// Receive source information from Master
var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo]
totalBlocks = sourceInfo.totalBlocks
arrayOfBlocks = new Array[BroadcastBlock] (totalBlocks)
totalBlocksLock.synchronized {
totalBlocksLock.notifyAll
}
totalBytes = sourceInfo.totalBytes
logInfo ("Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort)
val start = System.nanoTime
val receptionSucceeded = receiveSingleTransmission (sourceInfo)
val time = (System.nanoTime - start) / 1e9
// Updating some statistics in sourceInfo. Master will be using them later
if (!receptionSucceeded) {
sourceInfo.receptionFailed = true
}
// Send back statistics to the Master
oosMaster.writeObject (sourceInfo)
if (oisMaster != null) {
oisMaster.close
}
if (oosMaster != null) {
oosMaster.close
}
if (clientSocketToMaster != null) {
clientSocketToMaster.close
}
retriesLeft -= 1
} while (retriesLeft > 0 && hasBlocks < totalBlocks)
return (hasBlocks == totalBlocks)
}
// Tries to receive broadcast from the source and returns Boolean status.
// This might be called multiple times to retry a defined number of times.
private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = {
var clientSocketToSource: Socket = null
var oosSource: ObjectOutputStream = null
var oisSource: ObjectInputStream = null
var receptionSucceeded = false
try {
// Connect to the source to get the object itself
clientSocketToSource =
new Socket (sourceInfo.hostAddress, sourceInfo.listenPort)
oosSource =
new ObjectOutputStream (clientSocketToSource.getOutputStream)
oosSource.flush
oisSource =
new ObjectInputStream (clientSocketToSource.getInputStream)
logInfo ("Inside receiveSingleTransmission")
logInfo ("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks)
// Send the range
oosSource.writeObject((hasBlocks, totalBlocks))
oosSource.flush
for (i <- hasBlocks until totalBlocks) {
val recvStartTime = System.currentTimeMillis
val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
val receptionTime = (System.currentTimeMillis - recvStartTime)
logInfo ("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.")
arrayOfBlocks(hasBlocks) = bcBlock
hasBlocks += 1
// Set to true if at least one block is received
receptionSucceeded = true
hasBlocksLock.synchronized {
hasBlocksLock.notifyAll
}
}
} catch {
case e: Exception => {
logInfo ("receiveSingleTransmission had a " + e)
}
} finally {
if (oisSource != null) {
oisSource.close
}
if (oosSource != null) {
oosSource.close
}
if (clientSocketToSource != null) {
clientSocketToSource.close
}
}
return receptionSucceeded
}
class GuideMultipleRequests
extends Thread with Logging {
// Keep track of sources that have completed reception
private var setOfCompletedSources = Set[SourceInfo] ()
override def run: Unit = {
var threadPool = Broadcast.newDaemonCachedThreadPool
var serverSocket: ServerSocket = null
serverSocket = new ServerSocket (0)
guidePort = serverSocket.getLocalPort
logInfo ("GuideMultipleRequests => " + serverSocket + " " + guidePort)
guidePortLock.synchronized {
guidePortLock.notifyAll
}
try {
// Don't stop until there is a copy in HDFS
while (!stopBroadcast || !hasCopyInHDFS) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout (ChainedBroadcast.ServerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
case e: Exception => {
logInfo ("GuideMultipleRequests Timeout.")
// Stop broadcast if at least one worker has connected and
// everyone connected so far are done. Comparing with
// pqOfSources.size - 1, because it includes the Guide itself
if (pqOfSources.size > 1 &&
setOfCompletedSources.size == pqOfSources.size - 1) {
stopBroadcast = true
}
}
}
if (clientSocket != null) {
logInfo ("Guide: Accepted new client connection: " + clientSocket)
try {
threadPool.execute (new GuideSingleRequest (clientSocket))
} catch {
// In failure, close the socket here; else, the thread will close it
case ioe: IOException => clientSocket.close
}
}
}
logInfo ("Sending stopBroadcast notifications...")
sendStopBroadcastNotifications
ChainedBroadcast.unregisterValue (uuid)
} finally {
if (serverSocket != null) {
logInfo ("GuideMultipleRequests now stopping...")
serverSocket.close
}
}
// Shutdown the thread pool
threadPool.shutdown
}
private def sendStopBroadcastNotifications: Unit = {
pqOfSources.synchronized {
var pqIter = pqOfSources.iterator
while (pqIter.hasNext) {
var sourceInfo = pqIter.next
var guideSocketToSource: Socket = null
var gosSource: ObjectOutputStream = null
var gisSource: ObjectInputStream = null
try {
// Connect to the source
guideSocketToSource =
new Socket (sourceInfo.hostAddress, sourceInfo.listenPort)
gosSource =
new ObjectOutputStream (guideSocketToSource.getOutputStream)
gosSource.flush
gisSource =
new ObjectInputStream (guideSocketToSource.getInputStream)
// Send stopBroadcast signal. Range = SourceInfo.StopBroadcast*2
gosSource.writeObject ((SourceInfo.StopBroadcast,
SourceInfo.StopBroadcast))
gosSource.flush
} catch {
case e: Exception => {
logInfo ("sendStopBroadcastNotifications had a " + e)
}
} finally {
if (gisSource != null) {
gisSource.close
}
if (gosSource != null) {
gosSource.close
}
if (guideSocketToSource != null) {
guideSocketToSource.close
}
}
}
}
}
class GuideSingleRequest (val clientSocket: Socket)
extends Thread with Logging {
private val oos = new ObjectOutputStream (clientSocket.getOutputStream)
oos.flush
private val ois = new ObjectInputStream (clientSocket.getInputStream)
private var selectedSourceInfo: SourceInfo = null
private var thisWorkerInfo:SourceInfo = null
override def run: Unit = {
try {
logInfo ("new GuideSingleRequest is running")
// Connecting worker is sending in its hostAddress and listenPort it will
// be listening to. Other fields are invalid (SourceInfo.UnusedParam)
var sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
pqOfSources.synchronized {
// Select a suitable source and send it back to the worker
selectedSourceInfo = selectSuitableSource (sourceInfo)
logInfo ("Sending selectedSourceInfo: " + selectedSourceInfo)
oos.writeObject (selectedSourceInfo)
oos.flush
// Add this new (if it can finish) source to the PQ of sources
thisWorkerInfo = SourceInfo (sourceInfo.hostAddress,
sourceInfo.listenPort, totalBlocks, totalBytes)
logInfo ("Adding possible new source to pqOfSources: " + thisWorkerInfo)
pqOfSources.add (thisWorkerInfo)
}
// Wait till the whole transfer is done. Then receive and update source
// statistics in pqOfSources
sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
pqOfSources.synchronized {
// This should work since SourceInfo is a case class
assert (pqOfSources.contains (selectedSourceInfo))
// Remove first
pqOfSources.remove (selectedSourceInfo)
// TODO: Removing a source based on just one failure notification!
// Update sourceInfo and put it back in, IF reception succeeded
if (!sourceInfo.receptionFailed) {
// Add thisWorkerInfo to sources that have completed reception
setOfCompletedSources += thisWorkerInfo
selectedSourceInfo.currentLeechers -= 1
// Put it back
pqOfSources.add (selectedSourceInfo)
}
}
} catch {
// If something went wrong, e.g., the worker at the other end died etc.
// then close everything up
case e: Exception => {
// Assuming that exception caused due to receiver worker failure.
// Remove failed worker from pqOfSources and update leecherCount of
// corresponding source worker
pqOfSources.synchronized {
if (selectedSourceInfo != null) {
// Remove first
pqOfSources.remove (selectedSourceInfo)
// Update leecher count and put it back in
selectedSourceInfo.currentLeechers -= 1
pqOfSources.add (selectedSourceInfo)
}
// Remove thisWorkerInfo
if (pqOfSources != null) {
pqOfSources.remove (thisWorkerInfo)
}
}
}
} finally {
ois.close
oos.close
clientSocket.close
}
}
// TODO: Caller must have a synchronized block on pqOfSources
// TODO: If a worker fails to get the broadcasted variable from a source and
// comes back to Master, this function might choose the worker itself as a
// source tp create a dependency cycle (this worker was put into pqOfSources
// as a streming source when it first arrived). The length of this cycle can
// be arbitrarily long.
private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = {
// Select one based on the ordering strategy (e.g., least leechers etc.)
// take is a blocking call removing the element from PQ
var selectedSource = pqOfSources.poll
assert (selectedSource != null)
// Update leecher count
selectedSource.currentLeechers += 1
// Add it back and then return
pqOfSources.add (selectedSource)
return selectedSource
}
}
}
class ServeMultipleRequests
extends Thread with Logging {
override def run: Unit = {
var threadPool = Broadcast.newDaemonCachedThreadPool
var serverSocket: ServerSocket = null
serverSocket = new ServerSocket (0)
listenPort = serverSocket.getLocalPort
logInfo ("ServeMultipleRequests started with " + serverSocket)
listenPortLock.synchronized {
listenPortLock.notifyAll
}
try {
while (!stopBroadcast) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout (ChainedBroadcast.ServerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
case e: Exception => {
logInfo ("ServeMultipleRequests Timeout.")
}
}
if (clientSocket != null) {
logInfo ("Serve: Accepted new client connection: " + clientSocket)
try {
threadPool.execute (new ServeSingleRequest (clientSocket))
} catch {
// In failure, close socket here; else, the thread will close it
case ioe: IOException => clientSocket.close
}
}
}
} finally {
if (serverSocket != null) {
logInfo ("ServeMultipleRequests now stopping...")
serverSocket.close
}
}
// Shutdown the thread pool
threadPool.shutdown
}
class ServeSingleRequest (val clientSocket: Socket)
extends Thread with Logging {
private val oos = new ObjectOutputStream (clientSocket.getOutputStream)
oos.flush
private val ois = new ObjectInputStream (clientSocket.getInputStream)
private var sendFrom = 0
private var sendUntil = totalBlocks
override def run: Unit = {
try {
logInfo ("new ServeSingleRequest is running")
// Receive range to send
var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)]
sendFrom = rangeToSend._1
sendUntil = rangeToSend._2
if (sendFrom == SourceInfo.StopBroadcast &&
sendUntil == SourceInfo.StopBroadcast) {
stopBroadcast = true
} else {
// Carry on
sendObject
}
} catch {
// If something went wrong, e.g., the worker at the other end died etc.
// then close everything up
case e: Exception => {
logInfo ("ServeSingleRequest had a " + e)
}
} finally {
logInfo ("ServeSingleRequest is closing streams and sockets")
ois.close
oos.close
clientSocket.close
}
}
private def sendObject: Unit = {
// Wait till receiving the SourceInfo from Master
while (totalBlocks == -1) {
totalBlocksLock.synchronized {
totalBlocksLock.wait
}
}
for (i <- sendFrom until sendUntil) {
while (i == hasBlocks) {
hasBlocksLock.synchronized {
hasBlocksLock.wait
}
}
try {
oos.writeObject (arrayOfBlocks(i))
oos.flush
} catch {
case e: Exception => {
logInfo ("sendObject had a " + e)
}
}
logInfo ("Sent block: " + i + " to " + clientSocket)
}
}
}
}
}
class ChainedBroadcastFactory
extends BroadcastFactory {
def initialize (isMaster: Boolean) = ChainedBroadcast.initialize (isMaster)
def newBroadcast[T] (value_ : T, isLocal: Boolean) =
new ChainedBroadcast[T] (value_, isLocal)
}
private object ChainedBroadcast
extends Logging {
val values = SparkEnv.get.cache.newKeySpace()
var valueToGuidePortMap = Map[UUID, Int] ()
// Random number generator
var ranGen = new Random
private var initialized = false
private var isMaster_ = false
private var MasterHostAddress_ = InetAddress.getLocalHost.getHostAddress
private var MasterTrackerPort_ : Int = 22222
private var BlockSize_ : Int = 512 * 1024
private var MaxRetryCount_ : Int = 2
private var TrackerSocketTimeout_ : Int = 50000
private var ServerSocketTimeout_ : Int = 10000
private var trackMV: TrackMultipleValues = null
private var MinKnockInterval_ = 500
private var MaxKnockInterval_ = 999
def initialize (isMaster__ : Boolean): Unit = {
synchronized {
if (!initialized) {
// Fix for issue #42
MasterHostAddress_ =
System.getProperty ("spark.broadcast.masterHostAddress", "")
MasterTrackerPort_ =
System.getProperty ("spark.broadcast.masterTrackerPort", "22222").toInt
BlockSize_ =
System.getProperty ("spark.broadcast.blockSize", "512").toInt * 1024
MaxRetryCount_ =
System.getProperty ("spark.broadcast.maxRetryCount", "2").toInt
TrackerSocketTimeout_ =
System.getProperty ("spark.broadcast.trackerSocketTimeout", "50000").toInt
ServerSocketTimeout_ =
System.getProperty ("spark.broadcast.serverSocketTimeout", "10000").toInt
MinKnockInterval_ =
System.getProperty ("spark.broadcast.minKnockInterval", "500").toInt
MaxKnockInterval_ =
System.getProperty ("spark.broadcast.maxKnockInterval", "999").toInt
isMaster_ = isMaster__
if (isMaster) {
trackMV = new TrackMultipleValues
trackMV.setDaemon (true)
trackMV.start
logInfo ("TrackMultipleValues started...")
}
// Initialize DfsBroadcast to be used for broadcast variable persistence
DfsBroadcast.initialize
initialized = true
}
}
}
def MasterHostAddress = MasterHostAddress_
def MasterTrackerPort = MasterTrackerPort_
def BlockSize = BlockSize_
def MaxRetryCount = MaxRetryCount_
def TrackerSocketTimeout = TrackerSocketTimeout_
def ServerSocketTimeout = ServerSocketTimeout_
def isMaster = isMaster_
def MinKnockInterval = MinKnockInterval_
def MaxKnockInterval = MaxKnockInterval_
def registerValue (uuid: UUID, guidePort: Int): Unit = {
valueToGuidePortMap.synchronized {
valueToGuidePortMap += (uuid -> guidePort)
logInfo ("New value registered with the Tracker " + valueToGuidePortMap)
}
}
def unregisterValue (uuid: UUID): Unit = {
valueToGuidePortMap.synchronized {
valueToGuidePortMap (uuid) = SourceInfo.TxOverGoToHDFS
logInfo ("Value unregistered from the Tracker " + valueToGuidePortMap)
}
}
class TrackMultipleValues
extends Thread with Logging {
override def run: Unit = {
var threadPool = Broadcast.newDaemonCachedThreadPool
var serverSocket: ServerSocket = null
serverSocket = new ServerSocket (ChainedBroadcast.MasterTrackerPort)
logInfo ("TrackMultipleValues" + serverSocket)
try {
while (true) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout (TrackerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
case e: Exception => {
logInfo ("TrackMultipleValues Timeout. Stopping listening...")
}
}
if (clientSocket != null) {
try {
threadPool.execute (new Thread {
override def run: Unit = {
val oos = new ObjectOutputStream (clientSocket.getOutputStream)
oos.flush
val ois = new ObjectInputStream (clientSocket.getInputStream)
try {
val uuid = ois.readObject.asInstanceOf[UUID]
var guidePort =
if (valueToGuidePortMap.contains (uuid)) {
valueToGuidePortMap (uuid)
} else SourceInfo.TxNotStartedRetry
logInfo ("TrackMultipleValues: Got new request: " + clientSocket + " for " + uuid + " : " + guidePort)
oos.writeObject (guidePort)
} catch {
case e: Exception => {
logInfo ("TrackMultipleValues had a " + e)
}
} finally {
ois.close
oos.close
clientSocket.close
}
}
})
} catch {
// In failure, close socket here; else, client thread will close
case ioe: IOException => clientSocket.close
}
}
}
} finally {
serverSocket.close
}
// Shutdown the thread pool
threadPool.shutdown
}
}
}

View file

@ -9,6 +9,8 @@ import scala.collection.mutable.ArrayBuffer
import mesos.{ExecutorArgs, ExecutorDriver, MesosExecutorDriver}
import mesos.{TaskDescription, TaskState, TaskStatus}
import spark.broadcast._
/**
* The Mesos executor for Spark.
*/

View file

@ -7,6 +7,7 @@ import java.util.concurrent.atomic.AtomicLong
import scala.collection.mutable.{ArrayBuffer, HashMap}
import spark._
object LocalFileShuffle extends Logging {
private var initialized = false
@ -29,9 +30,9 @@ object LocalFileShuffle extends Logging {
while (!foundLocalDir && tries < 10) {
tries += 1
try {
localDirUuid = UUID.randomUUID()
localDirUuid = UUID.randomUUID
localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
if (!localDir.exists()) {
if (!localDir.exists) {
localDir.mkdirs()
foundLocalDir = true
}
@ -47,6 +48,7 @@ object LocalFileShuffle extends Logging {
shuffleDir = new File(localDir, "shuffle")
shuffleDir.mkdirs()
logInfo("Shuffle dir: " + shuffleDir)
val extServerPort = System.getProperty(
"spark.localFileShuffle.external.server.port", "-1").toInt
if (extServerPort != -1) {
@ -65,6 +67,7 @@ object LocalFileShuffle extends Logging {
serverUri = server.uri
}
initialized = true
logInfo("Local URI: " + serverUri)
}
}

View file

@ -242,6 +242,44 @@ extends RDD[Array[T]](prev.context) {
}
}
def leftOuterJoin[W](other: RDD[(K, W)], numSplits: Int): RDD[(K, (V, Option[W]))] = {
val vs: RDD[(K, Either[V, W])] = self.map { case (k, v) => (k, Left(v)) }
val ws: RDD[(K, Either[V, W])] = other.map { case (k, w) => (k, Right(w)) }
(vs ++ ws).groupByKey(numSplits).flatMap {
case (k, seq) => {
val vbuf = new ArrayBuffer[V]
val wbuf = new ArrayBuffer[Option[W]]
seq.foreach(_ match {
case Left(v) => vbuf += v
case Right(w) => wbuf += Some(w)
})
if (wbuf.isEmpty) {
wbuf += None
}
for (v <- vbuf; w <- wbuf) yield (k, (v, w))
}
}
}
def rightOuterJoin[W](other: RDD[(K, W)], numSplits: Int): RDD[(K, (Option[V], W))] = {
val vs: RDD[(K, Either[V, W])] = self.map { case (k, v) => (k, Left(v)) }
val ws: RDD[(K, Either[V, W])] = other.map { case (k, w) => (k, Right(w)) }
(vs ++ ws).groupByKey(numSplits).flatMap {
case (k, seq) => {
val vbuf = new ArrayBuffer[Option[V]]
val wbuf = new ArrayBuffer[W]
seq.foreach(_ match {
case Left(v) => vbuf += Some(v)
case Right(w) => wbuf += w
})
if (vbuf.isEmpty) {
vbuf += None
}
for (v <- vbuf; w <- wbuf) yield (k, (v, w))
}
}
}
def combineByKey[C](createCombiner: V => C,
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C)
@ -261,6 +299,14 @@ extends RDD[Array[T]](prev.context) {
join(other, numCores)
}
def leftOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (V, Option[W]))] = {
leftOuterJoin(other, numCores)
}
def rightOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (Option[V], W))] = {
rightOuterJoin(other, numCores)
}
def numCores = self.context.numCores
def collectAsMap(): Map[K, V] = HashMap(self.collect(): _*)
@ -301,6 +347,23 @@ extends RDD[Array[T]](prev.context) {
(k, (vs.asInstanceOf[Seq[V]], w1s.asInstanceOf[Seq[W1]], w2s.asInstanceOf[Seq[W2]]))
}
}
def lookup(key: K): Seq[V] = {
self.partitioner match {
case Some(p) =>
val index = p.getPartition(key)
def process(it: Iterator[(K, V)]): Seq[V] = {
val buf = new ArrayBuffer[V]
for ((k, v) <- it if k == key)
buf += v
buf
}
val res = self.context.runJob(self, process, Array(index))
res(0)
case None =>
throw new UnsupportedOperationException("lookup() called on an RDD without a partitioner")
}
}
}
class MappedValuesRDD[K, V, U](

View file

@ -1,15 +0,0 @@
package spark
/**
* A trait for shuffle system. Given an input RDD and combiner functions
* for PairRDDExtras.combineByKey(), returns an output RDD.
*/
@serializable
trait Shuffle[K, V, C] {
def compute(input: RDD[(K, V)],
numOutputSplits: Int,
createCombiner: V => C,
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C)
: RDD[(K, C)]
}

View file

@ -8,6 +8,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.mapred.SequenceFileInputFormat
import spark.broadcast._
class SparkContext(
master: String,

View file

@ -3,6 +3,7 @@ package spark
import java.io._
import java.net.InetAddress
import java.util.UUID
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
import scala.collection.mutable.ArrayBuffer
import scala.util.Random
@ -117,12 +118,43 @@ object Utils {
/**
* Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4)
*/
def localIpAddress(): String = {
// Get local IP as an array of four bytes
val bytes = InetAddress.getLocalHost().getAddress()
// Convert the bytes to ints (keeping in mind that they may be negative)
// and join them into a string
return bytes.map(b => (b.toInt + 256) % 256).mkString(".")
def localIpAddress(): String = InetAddress.getLocalHost.getHostAddress
/**
* Returns a standard ThreadFactory except all threads are daemons
*/
private def newDaemonThreadFactory: ThreadFactory = {
new ThreadFactory {
def newThread(r: Runnable): Thread = {
var t = Executors.defaultThreadFactory.newThread (r)
t.setDaemon (true)
return t
}
}
}
/**
* Wrapper over newCachedThreadPool
*/
def newDaemonCachedThreadPool(): ThreadPoolExecutor = {
var threadPool =
Executors.newCachedThreadPool.asInstanceOf[ThreadPoolExecutor]
threadPool.setThreadFactory (newDaemonThreadFactory)
return threadPool
}
/**
* Wrapper over newFixedThreadPool
*/
def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = {
var threadPool =
Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor]
threadPool.setThreadFactory(newDaemonThreadFactory)
return threadPool
}
/**

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,228 @@
package spark.broadcast
import java.io._
import java.net._
import java.util.{BitSet, UUID}
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
import spark._
@serializable
trait Broadcast[T] {
val uuid = UUID.randomUUID
def value: T
// We cannot have an abstract readObject here due to some weird issues with
// readObject having to be 'private' in sub-classes. Possibly a Scala bug!
override def toString = "spark.Broadcast(" + uuid + ")"
}
object Broadcast
extends Logging {
// Messages
val REGISTER_BROADCAST_TRACKER = 0
val UNREGISTER_BROADCAST_TRACKER = 1
val FIND_BROADCAST_TRACKER = 2
val GET_UPDATED_SHARE = 3
private var initialized = false
private var isMaster_ = false
private var broadcastFactory: BroadcastFactory = null
// Called by SparkContext or Executor before using Broadcast
def initialize (isMaster__ : Boolean): Unit = synchronized {
if (!initialized) {
val broadcastFactoryClass = System.getProperty(
"spark.broadcast.factory", "spark.broadcast.DfsBroadcastFactory")
broadcastFactory =
Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
// Setup isMaster before using it
isMaster_ = isMaster__
// Set masterHostAddress to the master's IP address for the slaves to read
if (isMaster) {
System.setProperty("spark.broadcast.masterHostAddress", Utils.localIpAddress)
}
// Initialize appropriate BroadcastFactory and BroadcastObject
broadcastFactory.initialize(isMaster)
initialized = true
}
}
def getBroadcastFactory: BroadcastFactory = {
if (broadcastFactory == null) {
throw new SparkException ("Broadcast.getBroadcastFactory called before initialize")
}
broadcastFactory
}
// Load common broadcast-related config parameters
private var MasterHostAddress_ = System.getProperty(
"spark.broadcast.masterHostAddress", "")
private var MasterTrackerPort_ = System.getProperty(
"spark.broadcast.masterTrackerPort", "11111").toInt
private var BlockSize_ = System.getProperty(
"spark.broadcast.blockSize", "4096").toInt * 1024
private var MaxRetryCount_ = System.getProperty(
"spark.broadcast.maxRetryCount", "2").toInt
private var TrackerSocketTimeout_ = System.getProperty(
"spark.broadcast.trackerSocketTimeout", "50000").toInt
private var ServerSocketTimeout_ = System.getProperty(
"spark.broadcast.serverSocketTimeout", "10000").toInt
private var MinKnockInterval_ = System.getProperty(
"spark.broadcast.minKnockInterval", "500").toInt
private var MaxKnockInterval_ = System.getProperty(
"spark.broadcast.maxKnockInterval", "999").toInt
// Load ChainedBroadcast config params
// Load TreeBroadcast config params
private var MaxDegree_ = System.getProperty("spark.broadcast.maxDegree", "2").toInt
// Load BitTorrentBroadcast config params
private var MaxPeersInGuideResponse_ = System.getProperty(
"spark.broadcast.maxPeersInGuideResponse", "4").toInt
private var MaxRxSlots_ = System.getProperty(
"spark.broadcast.maxRxSlots", "4").toInt
private var MaxTxSlots_ = System.getProperty(
"spark.broadcast.maxTxSlots", "4").toInt
private var MaxChatTime_ = System.getProperty(
"spark.broadcast.maxChatTime", "500").toInt
private var MaxChatBlocks_ = System.getProperty(
"spark.broadcast.maxChatBlocks", "1024").toInt
private var EndGameFraction_ = System.getProperty(
"spark.broadcast.endGameFraction", "0.95").toDouble
def isMaster = isMaster_
// Common config params
def MasterHostAddress = MasterHostAddress_
def MasterTrackerPort = MasterTrackerPort_
def BlockSize = BlockSize_
def MaxRetryCount = MaxRetryCount_
def TrackerSocketTimeout = TrackerSocketTimeout_
def ServerSocketTimeout = ServerSocketTimeout_
def MinKnockInterval = MinKnockInterval_
def MaxKnockInterval = MaxKnockInterval_
// ChainedBroadcast configs
// TreeBroadcast configs
def MaxDegree = MaxDegree_
// BitTorrentBroadcast configs
def MaxPeersInGuideResponse = MaxPeersInGuideResponse_
def MaxRxSlots = MaxRxSlots_
def MaxTxSlots = MaxTxSlots_
def MaxChatTime = MaxChatTime_
def MaxChatBlocks = MaxChatBlocks_
def EndGameFraction = EndGameFraction_
// Helper functions to convert an object to Array[BroadcastBlock]
def blockifyObject[IN](obj: IN): VariableInfo = {
val baos = new ByteArrayOutputStream
val oos = new ObjectOutputStream(baos)
oos.writeObject(obj)
oos.close()
baos.close()
val byteArray = baos.toByteArray
val bais = new ByteArrayInputStream(byteArray)
var blockNum = (byteArray.length / Broadcast.BlockSize)
if (byteArray.length % Broadcast.BlockSize != 0)
blockNum += 1
var retVal = new Array[BroadcastBlock](blockNum)
var blockID = 0
for (i <- 0 until (byteArray.length, Broadcast.BlockSize)) {
val thisBlockSize = math.min(Broadcast.BlockSize, byteArray.length - i)
var tempByteArray = new Array[Byte](thisBlockSize)
val hasRead = bais.read(tempByteArray, 0, thisBlockSize)
retVal(blockID) = new BroadcastBlock(blockID, tempByteArray)
blockID += 1
}
bais.close()
var variableInfo = VariableInfo(retVal, blockNum, byteArray.length)
variableInfo.hasBlocks = blockNum
return variableInfo
}
// Helper function to convert Array[BroadcastBlock] to object
def unBlockifyObject[OUT](arrayOfBlocks: Array[BroadcastBlock],
totalBytes: Int,
totalBlocks: Int): OUT = {
var retByteArray = new Array[Byte](totalBytes)
for (i <- 0 until totalBlocks) {
System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
i * Broadcast.BlockSize, arrayOfBlocks(i).byteArray.length)
}
byteArrayToObject(retByteArray)
}
private def byteArrayToObject[OUT](bytes: Array[Byte]): OUT = {
val in = new ObjectInputStream (new ByteArrayInputStream (bytes)){
override def resolveClass(desc: ObjectStreamClass) =
Class.forName(desc.getName, false, currentThread.getContextClassLoader)
}
val retVal = in.readObject.asInstanceOf[OUT]
in.close()
return retVal
}
}
@serializable
case class BroadcastBlock (val blockID: Int, val byteArray: Array[Byte]) { }
@serializable
case class VariableInfo (@transient val arrayOfBlocks : Array[BroadcastBlock],
val totalBlocks: Int,
val totalBytes: Int) {
@transient var hasBlocks = 0
}
@serializable
class SpeedTracker {
// Mapping 'source' to '(totalTime, numBlocks)'
private var sourceToSpeedMap = Map[SourceInfo, (Long, Int)] ()
def addDataPoint (srcInfo: SourceInfo, timeInMillis: Long): Unit = {
sourceToSpeedMap.synchronized {
if (!sourceToSpeedMap.contains(srcInfo)) {
sourceToSpeedMap += (srcInfo -> (timeInMillis, 1))
} else {
val tTnB = sourceToSpeedMap (srcInfo)
sourceToSpeedMap += (srcInfo -> (tTnB._1 + timeInMillis, tTnB._2 + 1))
}
}
}
def getTimePerBlock (srcInfo: SourceInfo): Double = {
sourceToSpeedMap.synchronized {
val tTnB = sourceToSpeedMap (srcInfo)
return tTnB._1 / tTnB._2
}
}
override def toString = sourceToSpeedMap.toString
}

View file

@ -0,0 +1,12 @@
package spark.broadcast
/**
* An interface for all the broadcast implementations in Spark (to allow
* multiple broadcast implementations). SparkContext uses a user-specified
* BroadcastFactory implementation to instantiate a particular broadcast for the
* entire Spark job.
*/
trait BroadcastFactory {
def initialize (isMaster: Boolean): Unit
def newBroadcast[T] (value_ : T, isLocal: Boolean): Broadcast[T]
}

View file

@ -0,0 +1,792 @@
package spark.broadcast
import java.io._
import java.net._
import java.util.{Comparator, PriorityQueue, Random, UUID}
import scala.collection.mutable.{Map, Set}
import scala.math
import spark._
@serializable
class ChainedBroadcast[T](@transient var value_ : T, isLocal: Boolean)
extends Broadcast[T] with Logging {
def value = value_
ChainedBroadcast.synchronized {
ChainedBroadcast.values.put(uuid, value_)
}
@transient var arrayOfBlocks: Array[BroadcastBlock] = null
@transient var totalBytes = -1
@transient var totalBlocks = -1
@transient var hasBlocks = 0
// CHANGED: BlockSize in the Broadcast object is expected to change over time
@transient var blockSize = Broadcast.BlockSize
@transient var listenPortLock = new Object
@transient var guidePortLock = new Object
@transient var totalBlocksLock = new Object
@transient var hasBlocksLock = new Object
@transient var pqOfSources = new PriorityQueue[SourceInfo]
@transient var serveMR: ServeMultipleRequests = null
@transient var guideMR: GuideMultipleRequests = null
@transient var hostAddress = Utils.localIpAddress
@transient var listenPort = -1
@transient var guidePort = -1
@transient var hasCopyInHDFS = false
@transient var stopBroadcast = false
// Must call this after all the variables have been created/initialized
if (!isLocal) {
sendBroadcast
}
def sendBroadcast(): Unit = {
logInfo("Local host address: " + hostAddress)
// Store a persistent copy in HDFS
// TODO: Turned OFF for now
// val out = new ObjectOutputStream(DfsBroadcast.openFileForWriting(uuid))
// out.writeObject(value_)
// out.close()
// TODO: Fix this at some point
hasCopyInHDFS = true
// Create a variableInfo object and store it in valueInfos
var variableInfo = Broadcast.blockifyObject(value_)
guideMR = new GuideMultipleRequests
guideMR.setDaemon(true)
guideMR.start()
logInfo("GuideMultipleRequests started...")
serveMR = new ServeMultipleRequests
serveMR.setDaemon(true)
serveMR.start()
logInfo("ServeMultipleRequests started...")
// Prepare the value being broadcasted
// TODO: Refactoring and clean-up required here
arrayOfBlocks = variableInfo.arrayOfBlocks
totalBytes = variableInfo.totalBytes
totalBlocks = variableInfo.totalBlocks
hasBlocks = variableInfo.totalBlocks
while (listenPort == -1) {
listenPortLock.synchronized {
listenPortLock.wait
}
}
pqOfSources = new PriorityQueue[SourceInfo]
val masterSource =
SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes, blockSize)
pqOfSources.add(masterSource)
// Register with the Tracker
while (guidePort == -1) {
guidePortLock.synchronized {
guidePortLock.wait
}
}
ChainedBroadcast.registerValue(uuid, guidePort)
}
private def readObject(in: ObjectInputStream): Unit = {
in.defaultReadObject
ChainedBroadcast.synchronized {
val cachedVal = ChainedBroadcast.values.get(uuid)
if (cachedVal != null) {
value_ = cachedVal.asInstanceOf[T]
} else {
// Initializing everything because Master will only send null/0 values
initializeSlaveVariables
logInfo("Local host address: " + hostAddress)
serveMR = new ServeMultipleRequests
serveMR.setDaemon(true)
serveMR.start()
logInfo("ServeMultipleRequests started...")
val start = System.nanoTime
val receptionSucceeded = receiveBroadcast(uuid)
// If does not succeed, then get from HDFS copy
if (receptionSucceeded) {
value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
ChainedBroadcast.values.put(uuid, value_)
} else {
val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid))
value_ = fileIn.readObject.asInstanceOf[T]
ChainedBroadcast.values.put(uuid, value_)
fileIn.close()
}
val time =(System.nanoTime - start) / 1e9
logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s")
}
}
}
private def initializeSlaveVariables: Unit = {
arrayOfBlocks = null
totalBytes = -1
totalBlocks = -1
hasBlocks = 0
blockSize = -1
listenPortLock = new Object
totalBlocksLock = new Object
hasBlocksLock = new Object
serveMR = null
hostAddress = Utils.localIpAddress
listenPort = -1
stopBroadcast = false
}
def getMasterListenPort(variableUUID: UUID): Int = {
var clientSocketToTracker: Socket = null
var oosTracker: ObjectOutputStream = null
var oisTracker: ObjectInputStream = null
var masterListenPort: Int = SourceInfo.TxOverGoToHDFS
var retriesLeft = Broadcast.MaxRetryCount
do {
try {
// Connect to the tracker to find out the guide
clientSocketToTracker =
new Socket(Broadcast.MasterHostAddress, Broadcast.MasterTrackerPort)
oosTracker =
new ObjectOutputStream(clientSocketToTracker.getOutputStream)
oosTracker.flush()
oisTracker =
new ObjectInputStream(clientSocketToTracker.getInputStream)
// Send UUID and receive masterListenPort
oosTracker.writeObject(uuid)
oosTracker.flush()
masterListenPort = oisTracker.readObject.asInstanceOf[Int]
} catch {
case e: Exception => {
logInfo("getMasterListenPort had a " + e)
}
} finally {
if (oisTracker != null) {
oisTracker.close()
}
if (oosTracker != null) {
oosTracker.close()
}
if (clientSocketToTracker != null) {
clientSocketToTracker.close()
}
}
retriesLeft -= 1
Thread.sleep(ChainedBroadcast.ranGen.nextInt(
Broadcast.MaxKnockInterval - Broadcast.MinKnockInterval) +
Broadcast.MinKnockInterval)
} while (retriesLeft > 0 && masterListenPort == SourceInfo.TxNotStartedRetry)
logInfo("Got this guidePort from Tracker: " + masterListenPort)
return masterListenPort
}
def receiveBroadcast(variableUUID: UUID): Boolean = {
val masterListenPort = getMasterListenPort(variableUUID)
if (masterListenPort == SourceInfo.TxOverGoToHDFS ||
masterListenPort == SourceInfo.TxNotStartedRetry) {
// TODO: SourceInfo.TxNotStartedRetry is not really in use because we go
// to HDFS anyway when receiveBroadcast returns false
return false
}
// Wait until hostAddress and listenPort are created by the
// ServeMultipleRequests thread
while (listenPort == -1) {
listenPortLock.synchronized {
listenPortLock.wait
}
}
var clientSocketToMaster: Socket = null
var oosMaster: ObjectOutputStream = null
var oisMaster: ObjectInputStream = null
// Connect and receive broadcast from the specified source, retrying the
// specified number of times in case of failures
var retriesLeft = Broadcast.MaxRetryCount
do {
// Connect to Master and send this worker's Information
clientSocketToMaster =
new Socket(Broadcast.MasterHostAddress, masterListenPort)
// TODO: Guiding object connection is reusable
oosMaster =
new ObjectOutputStream(clientSocketToMaster.getOutputStream)
oosMaster.flush()
oisMaster =
new ObjectInputStream(clientSocketToMaster.getInputStream)
logInfo("Connected to Master's guiding object")
// Send local source information
oosMaster.writeObject(SourceInfo(hostAddress, listenPort))
oosMaster.flush()
// Receive source information from Master
var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo]
totalBlocks = sourceInfo.totalBlocks
arrayOfBlocks = new Array[BroadcastBlock](totalBlocks)
totalBlocksLock.synchronized {
totalBlocksLock.notifyAll
}
totalBytes = sourceInfo.totalBytes
logInfo("Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort)
val start = System.nanoTime
val receptionSucceeded = receiveSingleTransmission(sourceInfo)
val time =(System.nanoTime - start) / 1e9
// Updating some statistics in sourceInfo. Master will be using them later
if (!receptionSucceeded) {
sourceInfo.receptionFailed = true
}
// Send back statistics to the Master
oosMaster.writeObject(sourceInfo)
if (oisMaster != null) {
oisMaster.close()
}
if (oosMaster != null) {
oosMaster.close()
}
if (clientSocketToMaster != null) {
clientSocketToMaster.close()
}
retriesLeft -= 1
} while (retriesLeft > 0 && hasBlocks < totalBlocks)
return(hasBlocks == totalBlocks)
}
// Tries to receive broadcast from the source and returns Boolean status.
// This might be called multiple times to retry a defined number of times.
private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = {
var clientSocketToSource: Socket = null
var oosSource: ObjectOutputStream = null
var oisSource: ObjectInputStream = null
var receptionSucceeded = false
try {
// Connect to the source to get the object itself
clientSocketToSource =
new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
oosSource =
new ObjectOutputStream(clientSocketToSource.getOutputStream)
oosSource.flush()
oisSource =
new ObjectInputStream(clientSocketToSource.getInputStream)
logInfo("Inside receiveSingleTransmission")
logInfo("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks)
// Send the range
oosSource.writeObject((hasBlocks, totalBlocks))
oosSource.flush()
for (i <- hasBlocks until totalBlocks) {
val recvStartTime = System.currentTimeMillis
val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
val receptionTime =(System.currentTimeMillis - recvStartTime)
logInfo("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.")
arrayOfBlocks(hasBlocks) = bcBlock
hasBlocks += 1
// Set to true if at least one block is received
receptionSucceeded = true
hasBlocksLock.synchronized {
hasBlocksLock.notifyAll
}
}
} catch {
case e: Exception => {
logInfo("receiveSingleTransmission had a " + e)
}
} finally {
if (oisSource != null) {
oisSource.close()
}
if (oosSource != null) {
oosSource.close()
}
if (clientSocketToSource != null) {
clientSocketToSource.close()
}
}
return receptionSucceeded
}
class GuideMultipleRequests
extends Thread with Logging {
// Keep track of sources that have completed reception
private var setOfCompletedSources = Set[SourceInfo]()
override def run: Unit = {
var threadPool = Utils.newDaemonCachedThreadPool()
var serverSocket: ServerSocket = null
serverSocket = new ServerSocket(0)
guidePort = serverSocket.getLocalPort
logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort)
guidePortLock.synchronized {
guidePortLock.notifyAll
}
try {
// Don't stop until there is a copy in HDFS
while (!stopBroadcast || !hasCopyInHDFS) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
case e: Exception => {
logInfo("GuideMultipleRequests Timeout.")
// Stop broadcast if at least one worker has connected and
// everyone connected so far are done. Comparing with
// pqOfSources.size - 1, because it includes the Guide itself
if (pqOfSources.size > 1 &&
setOfCompletedSources.size == pqOfSources.size - 1) {
stopBroadcast = true
}
}
}
if (clientSocket != null) {
logInfo("Guide: Accepted new client connection: " + clientSocket)
try {
threadPool.execute(new GuideSingleRequest(clientSocket))
} catch {
// In failure, close the socket here; else, the thread will close it
case ioe: IOException => clientSocket.close()
}
}
}
logInfo("Sending stopBroadcast notifications...")
sendStopBroadcastNotifications
ChainedBroadcast.unregisterValue(uuid)
} finally {
if (serverSocket != null) {
logInfo("GuideMultipleRequests now stopping...")
serverSocket.close()
}
}
// Shutdown the thread pool
threadPool.shutdown()
}
private def sendStopBroadcastNotifications: Unit = {
pqOfSources.synchronized {
var pqIter = pqOfSources.iterator
while (pqIter.hasNext) {
var sourceInfo = pqIter.next
var guideSocketToSource: Socket = null
var gosSource: ObjectOutputStream = null
var gisSource: ObjectInputStream = null
try {
// Connect to the source
guideSocketToSource =
new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
gosSource =
new ObjectOutputStream(guideSocketToSource.getOutputStream)
gosSource.flush()
gisSource =
new ObjectInputStream(guideSocketToSource.getInputStream)
// Send stopBroadcast signal. Range = SourceInfo.StopBroadcast*2
gosSource.writeObject((SourceInfo.StopBroadcast,
SourceInfo.StopBroadcast))
gosSource.flush()
} catch {
case e: Exception => {
logInfo("sendStopBroadcastNotifications had a " + e)
}
} finally {
if (gisSource != null) {
gisSource.close()
}
if (gosSource != null) {
gosSource.close()
}
if (guideSocketToSource != null) {
guideSocketToSource.close()
}
}
}
}
}
class GuideSingleRequest(val clientSocket: Socket)
extends Thread with Logging {
private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
oos.flush()
private val ois = new ObjectInputStream(clientSocket.getInputStream)
private var selectedSourceInfo: SourceInfo = null
private var thisWorkerInfo:SourceInfo = null
override def run: Unit = {
try {
logInfo("new GuideSingleRequest is running")
// Connecting worker is sending in its hostAddress and listenPort it will
// be listening to. Other fields are invalid(SourceInfo.UnusedParam)
var sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
pqOfSources.synchronized {
// Select a suitable source and send it back to the worker
selectedSourceInfo = selectSuitableSource(sourceInfo)
logInfo("Sending selectedSourceInfo: " + selectedSourceInfo)
oos.writeObject(selectedSourceInfo)
oos.flush()
// Add this new(if it can finish) source to the PQ of sources
thisWorkerInfo = SourceInfo(sourceInfo.hostAddress,
sourceInfo.listenPort, totalBlocks, totalBytes, blockSize)
logInfo("Adding possible new source to pqOfSources: " + thisWorkerInfo)
pqOfSources.add(thisWorkerInfo)
}
// Wait till the whole transfer is done. Then receive and update source
// statistics in pqOfSources
sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
pqOfSources.synchronized {
// This should work since SourceInfo is a case class
assert(pqOfSources.contains(selectedSourceInfo))
// Remove first
pqOfSources.remove(selectedSourceInfo)
// TODO: Removing a source based on just one failure notification!
// Update sourceInfo and put it back in, IF reception succeeded
if (!sourceInfo.receptionFailed) {
// Add thisWorkerInfo to sources that have completed reception
setOfCompletedSources.synchronized {
setOfCompletedSources += thisWorkerInfo
}
selectedSourceInfo.currentLeechers -= 1
// Put it back
pqOfSources.add(selectedSourceInfo)
}
}
} catch {
// If something went wrong, e.g., the worker at the other end died etc.
// then close everything up
case e: Exception => {
// Assuming that exception caused due to receiver worker failure.
// Remove failed worker from pqOfSources and update leecherCount of
// corresponding source worker
pqOfSources.synchronized {
if (selectedSourceInfo != null) {
// Remove first
pqOfSources.remove(selectedSourceInfo)
// Update leecher count and put it back in
selectedSourceInfo.currentLeechers -= 1
pqOfSources.add(selectedSourceInfo)
}
// Remove thisWorkerInfo
if (pqOfSources != null) {
pqOfSources.remove(thisWorkerInfo)
}
}
}
} finally {
ois.close()
oos.close()
clientSocket.close()
}
}
// FIXME: Caller must have a synchronized block on pqOfSources
// FIXME: If a worker fails to get the broadcasted variable from a source and
// comes back to Master, this function might choose the worker itself as a
// source tp create a dependency cycle(this worker was put into pqOfSources
// as a streming source when it first arrived). The length of this cycle can
// be arbitrarily long.
private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = {
// Select one based on the ordering strategy(e.g., least leechers etc.)
// take is a blocking call removing the element from PQ
var selectedSource = pqOfSources.poll
assert(selectedSource != null)
// Update leecher count
selectedSource.currentLeechers += 1
// Add it back and then return
pqOfSources.add(selectedSource)
return selectedSource
}
}
}
class ServeMultipleRequests
extends Thread with Logging {
override def run: Unit = {
var threadPool = Utils.newDaemonCachedThreadPool()
var serverSocket: ServerSocket = null
serverSocket = new ServerSocket(0)
listenPort = serverSocket.getLocalPort
logInfo("ServeMultipleRequests started with " + serverSocket)
listenPortLock.synchronized {
listenPortLock.notifyAll
}
try {
while (!stopBroadcast) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
case e: Exception => {
logInfo("ServeMultipleRequests Timeout.")
}
}
if (clientSocket != null) {
logInfo("Serve: Accepted new client connection: " + clientSocket)
try {
threadPool.execute(new ServeSingleRequest(clientSocket))
} catch {
// In failure, close socket here; else, the thread will close it
case ioe: IOException => clientSocket.close()
}
}
}
} finally {
if (serverSocket != null) {
logInfo("ServeMultipleRequests now stopping...")
serverSocket.close()
}
}
// Shutdown the thread pool
threadPool.shutdown()
}
class ServeSingleRequest(val clientSocket: Socket)
extends Thread with Logging {
private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
oos.flush()
private val ois = new ObjectInputStream(clientSocket.getInputStream)
private var sendFrom = 0
private var sendUntil = totalBlocks
override def run: Unit = {
try {
logInfo("new ServeSingleRequest is running")
// Receive range to send
var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)]
sendFrom = rangeToSend._1
sendUntil = rangeToSend._2
if (sendFrom == SourceInfo.StopBroadcast &&
sendUntil == SourceInfo.StopBroadcast) {
stopBroadcast = true
} else {
// Carry on
sendObject
}
} catch {
// If something went wrong, e.g., the worker at the other end died etc.
// then close everything up
case e: Exception => {
logInfo("ServeSingleRequest had a " + e)
}
} finally {
logInfo("ServeSingleRequest is closing streams and sockets")
ois.close()
oos.close()
clientSocket.close()
}
}
private def sendObject: Unit = {
// Wait till receiving the SourceInfo from Master
while (totalBlocks == -1) {
totalBlocksLock.synchronized {
totalBlocksLock.wait
}
}
for (i <- sendFrom until sendUntil) {
while (i == hasBlocks) {
hasBlocksLock.synchronized {
hasBlocksLock.wait
}
}
try {
oos.writeObject(arrayOfBlocks(i))
oos.flush()
} catch {
case e: Exception => {
logInfo("sendObject had a " + e)
}
}
logInfo("Sent block: " + i + " to " + clientSocket)
}
}
}
}
}
class ChainedBroadcastFactory
extends BroadcastFactory {
def initialize(isMaster: Boolean) = ChainedBroadcast.initialize(isMaster)
def newBroadcast[T](value_ : T, isLocal: Boolean) =
new ChainedBroadcast[T](value_, isLocal)
}
private object ChainedBroadcast
extends Logging {
val values = SparkEnv.get.cache.newKeySpace()
var valueToGuidePortMap = Map[UUID, Int]()
// Random number generator
var ranGen = new Random
private var initialized = false
private var isMaster_ = false
private var trackMV: TrackMultipleValues = null
def initialize(isMaster__ : Boolean): Unit = {
synchronized {
if (!initialized) {
isMaster_ = isMaster__
if (isMaster) {
trackMV = new TrackMultipleValues
trackMV.setDaemon(true)
trackMV.start()
// TODO: Logging the following line makes the Spark framework ID not
// getting logged, cause it calls logInfo before log4j is initialized
logInfo("TrackMultipleValues started...")
}
// Initialize DfsBroadcast to be used for broadcast variable persistence
DfsBroadcast.initialize
initialized = true
}
}
}
def isMaster = isMaster_
def registerValue(uuid: UUID, guidePort: Int): Unit = {
valueToGuidePortMap.synchronized {
valueToGuidePortMap +=(uuid -> guidePort)
logInfo("New value registered with the Tracker " + valueToGuidePortMap)
}
}
def unregisterValue(uuid: UUID): Unit = {
valueToGuidePortMap.synchronized {
valueToGuidePortMap(uuid) = SourceInfo.TxOverGoToHDFS
logInfo("Value unregistered from the Tracker " + valueToGuidePortMap)
}
}
class TrackMultipleValues
extends Thread with Logging {
override def run: Unit = {
var threadPool = Utils.newDaemonCachedThreadPool()
var serverSocket: ServerSocket = null
serverSocket = new ServerSocket(Broadcast.MasterTrackerPort)
logInfo("TrackMultipleValues" + serverSocket)
try {
while (true) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout(Broadcast.TrackerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
case e: Exception => {
logInfo("TrackMultipleValues Timeout. Stopping listening...")
}
}
if (clientSocket != null) {
try {
threadPool.execute(new Thread {
override def run: Unit = {
val oos = new ObjectOutputStream(clientSocket.getOutputStream)
oos.flush()
val ois = new ObjectInputStream(clientSocket.getInputStream)
try {
val uuid = ois.readObject.asInstanceOf[UUID]
var guidePort =
if (valueToGuidePortMap.contains(uuid)) {
valueToGuidePortMap(uuid)
} else SourceInfo.TxNotStartedRetry
logInfo("TrackMultipleValues: Got new request: " + clientSocket + " for " + uuid + " : " + guidePort)
oos.writeObject(guidePort)
} catch {
case e: Exception => {
logInfo("TrackMultipleValues had a " + e)
}
} finally {
ois.close()
oos.close()
clientSocket.close()
}
}
})
} catch {
// In failure, close socket here; else, client thread will close
case ioe: IOException => clientSocket.close()
}
}
}
} finally {
serverSocket.close()
}
// Shutdown the thread pool
threadPool.shutdown()
}
}
}

View file

@ -1,4 +1,6 @@
package spark
package spark.broadcast
import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}
import java.io._
import java.net._
@ -7,7 +9,7 @@ import java.util.UUID
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path, RawLocalFileSystem}
import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}
import spark._
@serializable
class DfsBroadcast[T](@transient var value_ : T, isLocal: Boolean)

View file

@ -0,0 +1,41 @@
package spark.broadcast
import java.util.BitSet
import spark._
/**
* Used to keep and pass around information of peers involved in a broadcast
*
* CHANGED: Keep track of the blockSize for THIS broadcast variable.
* Broadcast.BlockSize is expected to be updated across different broadcasts
*/
@serializable
case class SourceInfo (val hostAddress: String,
val listenPort: Int,
val totalBlocks: Int = SourceInfo.UnusedParam,
val totalBytes: Int = SourceInfo.UnusedParam,
val blockSize: Int = Broadcast.BlockSize)
extends Comparable[SourceInfo] with Logging {
var currentLeechers = 0
var receptionFailed = false
var hasBlocks = 0
var hasBlocksBitVector: BitSet = new BitSet (totalBlocks)
// Ascending sort based on leecher count
def compareTo (o: SourceInfo): Int = (currentLeechers - o.currentLeechers)
}
/**
* Helper Object of SourceInfo for its constants
*/
object SourceInfo {
// Constants for special values of listenPort
val TxNotStartedRetry = -1
val TxOverGoToHDFS = 0
// Other constants
val StopBroadcast = -2
val UnusedParam = 0
}

View file

@ -0,0 +1,807 @@
package spark.broadcast
import java.io._
import java.net._
import java.util.{Comparator, Random, UUID}
import scala.collection.mutable.{ListBuffer, Map, Set}
import scala.math
import spark._
@serializable
class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean)
extends Broadcast[T] with Logging {
def value = value_
TreeBroadcast.synchronized {
TreeBroadcast.values.put(uuid, value_)
}
@transient var arrayOfBlocks: Array[BroadcastBlock] = null
@transient var totalBytes = -1
@transient var totalBlocks = -1
@transient var hasBlocks = 0
// CHANGED: BlockSize in the Broadcast object is expected to change over time
@transient var blockSize = Broadcast.BlockSize
@transient var listenPortLock = new Object
@transient var guidePortLock = new Object
@transient var totalBlocksLock = new Object
@transient var hasBlocksLock = new Object
@transient var listOfSources = ListBuffer[SourceInfo]()
@transient var serveMR: ServeMultipleRequests = null
@transient var guideMR: GuideMultipleRequests = null
@transient var hostAddress = Utils.localIpAddress
@transient var listenPort = -1
@transient var guidePort = -1
@transient var hasCopyInHDFS = false
@transient var stopBroadcast = false
// Must call this after all the variables have been created/initialized
if (!isLocal) {
sendBroadcast
}
def sendBroadcast(): Unit = {
logInfo("Local host address: " + hostAddress)
// Store a persistent copy in HDFS
// TODO: Turned OFF for now
// val out = new ObjectOutputStream(DfsBroadcast.openFileForWriting(uuid))
// out.writeObject(value_)
// out.close()
// TODO: Fix this at some point
hasCopyInHDFS = true
// Create a variableInfo object and store it in valueInfos
var variableInfo = Broadcast.blockifyObject(value_)
// Prepare the value being broadcasted
// TODO: Refactoring and clean-up required here
arrayOfBlocks = variableInfo.arrayOfBlocks
totalBytes = variableInfo.totalBytes
totalBlocks = variableInfo.totalBlocks
hasBlocks = variableInfo.totalBlocks
guideMR = new GuideMultipleRequests
guideMR.setDaemon(true)
guideMR.start
logInfo("GuideMultipleRequests started...")
// Must always come AFTER guideMR is created
while (guidePort == -1) {
guidePortLock.synchronized {
guidePortLock.wait
}
}
serveMR = new ServeMultipleRequests
serveMR.setDaemon(true)
serveMR.start
logInfo("ServeMultipleRequests started...")
// Must always come AFTER serveMR is created
while (listenPort == -1) {
listenPortLock.synchronized {
listenPortLock.wait
}
}
// Must always come AFTER listenPort is created
val masterSource =
SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes, blockSize)
listOfSources += masterSource
// Register with the Tracker
TreeBroadcast.registerValue(uuid, guidePort)
}
private def readObject(in: ObjectInputStream): Unit = {
in.defaultReadObject
TreeBroadcast.synchronized {
val cachedVal = TreeBroadcast.values.get(uuid)
if (cachedVal != null) {
value_ = cachedVal.asInstanceOf[T]
} else {
// Initializing everything because Master will only send null/0 values
initializeSlaveVariables
logInfo("Local host address: " + hostAddress)
serveMR = new ServeMultipleRequests
serveMR.setDaemon(true)
serveMR.start
logInfo("ServeMultipleRequests started...")
val start = System.nanoTime
val receptionSucceeded = receiveBroadcast(uuid)
// If does not succeed, then get from HDFS copy
if (receptionSucceeded) {
value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
TreeBroadcast.values.put(uuid, value_)
} else {
val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid))
value_ = fileIn.readObject.asInstanceOf[T]
TreeBroadcast.values.put(uuid, value_)
fileIn.close()
}
val time = (System.nanoTime - start) / 1e9
logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s")
}
}
}
private def initializeSlaveVariables: Unit = {
arrayOfBlocks = null
totalBytes = -1
totalBlocks = -1
hasBlocks = 0
blockSize = -1
listenPortLock = new Object
totalBlocksLock = new Object
hasBlocksLock = new Object
serveMR = null
hostAddress = Utils.localIpAddress
listenPort = -1
stopBroadcast = false
}
def getMasterListenPort(variableUUID: UUID): Int = {
var clientSocketToTracker: Socket = null
var oosTracker: ObjectOutputStream = null
var oisTracker: ObjectInputStream = null
var masterListenPort: Int = SourceInfo.TxOverGoToHDFS
var retriesLeft = Broadcast.MaxRetryCount
do {
try {
// Connect to the tracker to find out the guide
clientSocketToTracker =
new Socket(Broadcast.MasterHostAddress, Broadcast.MasterTrackerPort)
oosTracker =
new ObjectOutputStream(clientSocketToTracker.getOutputStream)
oosTracker.flush()
oisTracker =
new ObjectInputStream(clientSocketToTracker.getInputStream)
// Send UUID and receive masterListenPort
oosTracker.writeObject(uuid)
oosTracker.flush()
masterListenPort = oisTracker.readObject.asInstanceOf[Int]
} catch {
case e: Exception => {
logInfo("getMasterListenPort had a " + e)
}
} finally {
if (oisTracker != null) {
oisTracker.close()
}
if (oosTracker != null) {
oosTracker.close()
}
if (clientSocketToTracker != null) {
clientSocketToTracker.close()
}
}
retriesLeft -= 1
Thread.sleep(TreeBroadcast.ranGen.nextInt(
Broadcast.MaxKnockInterval - Broadcast.MinKnockInterval) +
Broadcast.MinKnockInterval)
} while (retriesLeft > 0 && masterListenPort == SourceInfo.TxNotStartedRetry)
logInfo("Got this guidePort from Tracker: " + masterListenPort)
return masterListenPort
}
def receiveBroadcast(variableUUID: UUID): Boolean = {
val masterListenPort = getMasterListenPort(variableUUID)
if (masterListenPort == SourceInfo.TxOverGoToHDFS ||
masterListenPort == SourceInfo.TxNotStartedRetry) {
// TODO: SourceInfo.TxNotStartedRetry is not really in use because we go
// to HDFS anyway when receiveBroadcast returns false
return false
}
// Wait until hostAddress and listenPort are created by the
// ServeMultipleRequests thread
while (listenPort == -1) {
listenPortLock.synchronized {
listenPortLock.wait
}
}
var clientSocketToMaster: Socket = null
var oosMaster: ObjectOutputStream = null
var oisMaster: ObjectInputStream = null
// Connect and receive broadcast from the specified source, retrying the
// specified number of times in case of failures
var retriesLeft = Broadcast.MaxRetryCount
do {
// Connect to Master and send this worker's Information
clientSocketToMaster =
new Socket(Broadcast.MasterHostAddress, masterListenPort)
// TODO: Guiding object connection is reusable
oosMaster =
new ObjectOutputStream(clientSocketToMaster.getOutputStream)
oosMaster.flush()
oisMaster =
new ObjectInputStream(clientSocketToMaster.getInputStream)
logInfo("Connected to Master's guiding object")
// Send local source information
oosMaster.writeObject(SourceInfo(hostAddress, listenPort))
oosMaster.flush()
// Receive source information from Master
var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo]
totalBlocks = sourceInfo.totalBlocks
arrayOfBlocks = new Array[BroadcastBlock](totalBlocks)
totalBlocksLock.synchronized {
totalBlocksLock.notifyAll
}
totalBytes = sourceInfo.totalBytes
blockSize = sourceInfo.blockSize
logInfo("Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort)
val start = System.nanoTime
val receptionSucceeded = receiveSingleTransmission(sourceInfo)
val time = (System.nanoTime - start) / 1e9
// Updating some statistics in sourceInfo. Master will be using them later
if (!receptionSucceeded) {
sourceInfo.receptionFailed = true
}
// Send back statistics to the Master
oosMaster.writeObject(sourceInfo)
if (oisMaster != null) {
oisMaster.close()
}
if (oosMaster != null) {
oosMaster.close()
}
if (clientSocketToMaster != null) {
clientSocketToMaster.close()
}
retriesLeft -= 1
} while (retriesLeft > 0 && hasBlocks < totalBlocks)
return (hasBlocks == totalBlocks)
}
// Tries to receive broadcast from the source and returns Boolean status.
// This might be called multiple times to retry a defined number of times.
private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = {
var clientSocketToSource: Socket = null
var oosSource: ObjectOutputStream = null
var oisSource: ObjectInputStream = null
var receptionSucceeded = false
try {
// Connect to the source to get the object itself
clientSocketToSource =
new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
oosSource =
new ObjectOutputStream(clientSocketToSource.getOutputStream)
oosSource.flush()
oisSource =
new ObjectInputStream(clientSocketToSource.getInputStream)
logInfo("Inside receiveSingleTransmission")
logInfo("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks)
// Send the range
oosSource.writeObject((hasBlocks, totalBlocks))
oosSource.flush()
for (i <- hasBlocks until totalBlocks) {
val recvStartTime = System.currentTimeMillis
val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
val receptionTime = (System.currentTimeMillis - recvStartTime)
logInfo("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.")
arrayOfBlocks(hasBlocks) = bcBlock
hasBlocks += 1
// Set to true if at least one block is received
receptionSucceeded = true
hasBlocksLock.synchronized {
hasBlocksLock.notifyAll
}
}
} catch {
case e: Exception => {
logInfo("receiveSingleTransmission had a " + e)
}
} finally {
if (oisSource != null) {
oisSource.close()
}
if (oosSource != null) {
oosSource.close()
}
if (clientSocketToSource != null) {
clientSocketToSource.close()
}
}
return receptionSucceeded
}
class GuideMultipleRequests
extends Thread with Logging {
// Keep track of sources that have completed reception
private var setOfCompletedSources = Set[SourceInfo]()
override def run: Unit = {
var threadPool = Utils.newDaemonCachedThreadPool()
var serverSocket: ServerSocket = null
serverSocket = new ServerSocket(0)
guidePort = serverSocket.getLocalPort
logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort)
guidePortLock.synchronized {
guidePortLock.notifyAll
}
try {
// Don't stop until there is a copy in HDFS
while (!stopBroadcast || !hasCopyInHDFS) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
case e: Exception => {
logInfo("GuideMultipleRequests Timeout.")
// Stop broadcast if at least one worker has connected and
// everyone connected so far are done. Comparing with
// listOfSources.size - 1, because it includes the Guide itself
if (listOfSources.size > 1 &&
setOfCompletedSources.size == listOfSources.size - 1) {
stopBroadcast = true
}
}
}
if (clientSocket != null) {
logInfo("Guide: Accepted new client connection: " + clientSocket)
try {
threadPool.execute(new GuideSingleRequest(clientSocket))
} catch {
// In failure, close() the socket here; else, the thread will close() it
case ioe: IOException => clientSocket.close()
}
}
}
logInfo("Sending stopBroadcast notifications...")
sendStopBroadcastNotifications
TreeBroadcast.unregisterValue(uuid)
} finally {
if (serverSocket != null) {
logInfo("GuideMultipleRequests now stopping...")
serverSocket.close()
}
}
// Shutdown the thread pool
threadPool.shutdown
}
private def sendStopBroadcastNotifications: Unit = {
listOfSources.synchronized {
var listIter = listOfSources.iterator
while (listIter.hasNext) {
var sourceInfo = listIter.next
var guideSocketToSource: Socket = null
var gosSource: ObjectOutputStream = null
var gisSource: ObjectInputStream = null
try {
// Connect to the source
guideSocketToSource =
new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
gosSource =
new ObjectOutputStream(guideSocketToSource.getOutputStream)
gosSource.flush()
gisSource =
new ObjectInputStream(guideSocketToSource.getInputStream)
// Send stopBroadcast signal. Range = SourceInfo.StopBroadcast*2
gosSource.writeObject((SourceInfo.StopBroadcast,
SourceInfo.StopBroadcast))
gosSource.flush()
} catch {
case e: Exception => {
logInfo("sendStopBroadcastNotifications had a " + e)
}
} finally {
if (gisSource != null) {
gisSource.close()
}
if (gosSource != null) {
gosSource.close()
}
if (guideSocketToSource != null) {
guideSocketToSource.close()
}
}
}
}
}
class GuideSingleRequest(val clientSocket: Socket)
extends Thread with Logging {
private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
oos.flush()
private val ois = new ObjectInputStream(clientSocket.getInputStream)
private var selectedSourceInfo: SourceInfo = null
private var thisWorkerInfo:SourceInfo = null
override def run: Unit = {
try {
logInfo("new GuideSingleRequest is running")
// Connecting worker is sending in its hostAddress and listenPort it will
// be listening to. Other fields are invalid (SourceInfo.UnusedParam)
var sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
listOfSources.synchronized {
// Select a suitable source and send it back to the worker
selectedSourceInfo = selectSuitableSource(sourceInfo)
logInfo("Sending selectedSourceInfo: " + selectedSourceInfo)
oos.writeObject(selectedSourceInfo)
oos.flush()
// Add this new (if it can finish) source to the list of sources
thisWorkerInfo = SourceInfo(sourceInfo.hostAddress,
sourceInfo.listenPort, totalBlocks, totalBytes, blockSize)
logInfo("Adding possible new source to listOfSources: " + thisWorkerInfo)
listOfSources += thisWorkerInfo
}
// Wait till the whole transfer is done. Then receive and update source
// statistics in listOfSources
sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
listOfSources.synchronized {
// This should work since SourceInfo is a case class
assert(listOfSources.contains(selectedSourceInfo))
// Remove first
listOfSources = listOfSources - selectedSourceInfo
// TODO: Removing a source based on just one failure notification!
// Update sourceInfo and put it back in, IF reception succeeded
if (!sourceInfo.receptionFailed) {
// Add thisWorkerInfo to sources that have completed reception
setOfCompletedSources.synchronized {
setOfCompletedSources += thisWorkerInfo
}
selectedSourceInfo.currentLeechers -= 1
// Put it back
listOfSources += selectedSourceInfo
}
}
} catch {
// If something went wrong, e.g., the worker at the other end died etc.
// then close() everything up
case e: Exception => {
// Assuming that exception caused due to receiver worker failure.
// Remove failed worker from listOfSources and update leecherCount of
// corresponding source worker
listOfSources.synchronized {
if (selectedSourceInfo != null) {
// Remove first
listOfSources = listOfSources - selectedSourceInfo
// Update leecher count and put it back in
selectedSourceInfo.currentLeechers -= 1
listOfSources += selectedSourceInfo
}
// Remove thisWorkerInfo
if (listOfSources != null) {
listOfSources = listOfSources - thisWorkerInfo
}
}
}
} finally {
ois.close()
oos.close()
clientSocket.close()
}
}
// FIXME: Caller must have a synchronized block on listOfSources
// FIXME: If a worker fails to get the broadcasted variable from a source
// and comes back to the Master, this function might choose the worker
// itself as a source to create a dependency cycle (this worker was put
// into listOfSources as a streming source when it first arrived). The
// length of this cycle can be arbitrarily long.
private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = {
// Select one with the most leechers. This will level-wise fill the tree
var maxLeechers = -1
var selectedSource: SourceInfo = null
listOfSources.foreach { source =>
if (source != skipSourceInfo &&
source.currentLeechers < Broadcast.MaxDegree &&
source.currentLeechers > maxLeechers) {
selectedSource = source
maxLeechers = source.currentLeechers
}
}
// Update leecher count
selectedSource.currentLeechers += 1
return selectedSource
}
}
}
class ServeMultipleRequests
extends Thread with Logging {
override def run: Unit = {
var threadPool = Utils.newDaemonCachedThreadPool()
var serverSocket: ServerSocket = null
serverSocket = new ServerSocket(0)
listenPort = serverSocket.getLocalPort
logInfo("ServeMultipleRequests started with " + serverSocket)
listenPortLock.synchronized {
listenPortLock.notifyAll
}
try {
while (!stopBroadcast) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
case e: Exception => {
logInfo("ServeMultipleRequests Timeout.")
}
}
if (clientSocket != null) {
logInfo("Serve: Accepted new client connection: " + clientSocket)
try {
threadPool.execute(new ServeSingleRequest(clientSocket))
} catch {
// In failure, close() socket here; else, the thread will close() it
case ioe: IOException => clientSocket.close()
}
}
}
} finally {
if (serverSocket != null) {
logInfo("ServeMultipleRequests now stopping...")
serverSocket.close()
}
}
// Shutdown the thread pool
threadPool.shutdown
}
class ServeSingleRequest(val clientSocket: Socket)
extends Thread with Logging {
private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
oos.flush()
private val ois = new ObjectInputStream(clientSocket.getInputStream)
private var sendFrom = 0
private var sendUntil = totalBlocks
override def run: Unit = {
try {
logInfo("new ServeSingleRequest is running")
// Receive range to send
var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)]
sendFrom = rangeToSend._1
sendUntil = rangeToSend._2
if (sendFrom == SourceInfo.StopBroadcast &&
sendUntil == SourceInfo.StopBroadcast) {
stopBroadcast = true
} else {
// Carry on
sendObject
}
} catch {
// If something went wrong, e.g., the worker at the other end died etc.
// then close() everything up
case e: Exception => {
logInfo("ServeSingleRequest had a " + e)
}
} finally {
logInfo("ServeSingleRequest is closing streams and sockets")
ois.close()
oos.close()
clientSocket.close()
}
}
private def sendObject: Unit = {
// Wait till receiving the SourceInfo from Master
while (totalBlocks == -1) {
totalBlocksLock.synchronized {
totalBlocksLock.wait
}
}
for (i <- sendFrom until sendUntil) {
while (i == hasBlocks) {
hasBlocksLock.synchronized {
hasBlocksLock.wait
}
}
try {
oos.writeObject(arrayOfBlocks(i))
oos.flush()
} catch {
case e: Exception => {
logInfo("sendObject had a " + e)
}
}
logInfo("Sent block: " + i + " to " + clientSocket)
}
}
}
}
}
class TreeBroadcastFactory
extends BroadcastFactory {
def initialize(isMaster: Boolean) = TreeBroadcast.initialize(isMaster)
def newBroadcast[T](value_ : T, isLocal: Boolean) =
new TreeBroadcast[T](value_, isLocal)
}
private object TreeBroadcast
extends Logging {
val values = SparkEnv.get.cache.newKeySpace()
var valueToGuidePortMap = Map[UUID, Int]()
// Random number generator
var ranGen = new Random
private var initialized = false
private var isMaster_ = false
private var trackMV: TrackMultipleValues = null
private var MaxDegree_ : Int = 2
def initialize(isMaster__ : Boolean): Unit = {
synchronized {
if (!initialized) {
isMaster_ = isMaster__
if (isMaster) {
trackMV = new TrackMultipleValues
trackMV.setDaemon(true)
trackMV.start
// TODO: Logging the following line makes the Spark framework ID not
// getting logged, cause it calls logInfo before log4j is initialized
logInfo("TrackMultipleValues started...")
}
// Initialize DfsBroadcast to be used for broadcast variable persistence
DfsBroadcast.initialize
initialized = true
}
}
}
def isMaster = isMaster_
def registerValue(uuid: UUID, guidePort: Int): Unit = {
valueToGuidePortMap.synchronized {
valueToGuidePortMap += (uuid -> guidePort)
logInfo("New value registered with the Tracker " + valueToGuidePortMap)
}
}
def unregisterValue(uuid: UUID): Unit = {
valueToGuidePortMap.synchronized {
valueToGuidePortMap(uuid) = SourceInfo.TxOverGoToHDFS
logInfo("Value unregistered from the Tracker " + valueToGuidePortMap)
}
}
class TrackMultipleValues
extends Thread with Logging {
override def run: Unit = {
var threadPool = Utils.newDaemonCachedThreadPool()
var serverSocket: ServerSocket = null
serverSocket = new ServerSocket(Broadcast.MasterTrackerPort)
logInfo("TrackMultipleValues" + serverSocket)
try {
while (true) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout(Broadcast.TrackerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
case e: Exception => {
logInfo("TrackMultipleValues Timeout. Stopping listening...")
}
}
if (clientSocket != null) {
try {
threadPool.execute(new Thread {
override def run: Unit = {
val oos = new ObjectOutputStream(clientSocket.getOutputStream)
oos.flush()
val ois = new ObjectInputStream(clientSocket.getInputStream)
try {
val uuid = ois.readObject.asInstanceOf[UUID]
var guidePort =
if (valueToGuidePortMap.contains(uuid)) {
valueToGuidePortMap(uuid)
} else SourceInfo.TxNotStartedRetry
logInfo("TrackMultipleValues: Got new request: " + clientSocket + " for " + uuid + " : " + guidePort)
oos.writeObject(guidePort)
} catch {
case e: Exception => {
logInfo("TrackMultipleValues had a " + e)
}
} finally {
ois.close()
oos.close()
clientSocket.close()
}
}
})
} catch {
// In failure, close() socket here; else, client thread will close()
case ioe: IOException => clientSocket.close()
}
}
}
} finally {
serverSocket.close()
}
// Shutdown the thread pool
threadPool.shutdown
}
}
}

View file

@ -5,8 +5,8 @@ import org.scalatest.prop.Checkers
import org.scalacheck.Arbitrary._
import org.scalacheck.Gen
import org.scalacheck.Prop._
import SparkContext._
import scala.collection.mutable.ArrayBuffer
class ShuffleSuite extends FunSuite {
test("groupByKey") {
@ -115,6 +115,38 @@ class ShuffleSuite extends FunSuite {
sc.stop()
}
test("leftOuterJoin") {
val sc = new SparkContext("local", "test")
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
val joined = rdd1.leftOuterJoin(rdd2).collect()
assert(joined.size === 5)
assert(joined.toSet === Set(
(1, (1, Some('x'))),
(1, (2, Some('x'))),
(2, (1, Some('y'))),
(2, (1, Some('z'))),
(3, (1, None))
))
sc.stop()
}
test("rightOuterJoin") {
val sc = new SparkContext("local", "test")
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
val joined = rdd1.rightOuterJoin(rdd2).collect()
assert(joined.size === 5)
assert(joined.toSet === Set(
(1, (Some(1), 'x')),
(1, (Some(2), 'x')),
(2, (Some(1), 'y')),
(2, (Some(1), 'z')),
(4, (None, 'w'))
))
sc.stop()
}
test("join with no matches") {
val sc = new SparkContext("local", "test")
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
@ -138,4 +170,20 @@ class ShuffleSuite extends FunSuite {
))
sc.stop()
}
test("groupWith") {
val sc = new SparkContext("local", "test")
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
val joined = rdd1.groupWith(rdd2).collect()
assert(joined.size === 4)
assert(joined.toSet === Set(
(1, (ArrayBuffer(1, 2), ArrayBuffer('x'))),
(2, (ArrayBuffer(1), ArrayBuffer('y', 'z'))),
(3, (ArrayBuffer(1), ArrayBuffer())),
(4, (ArrayBuffer(), ArrayBuffer('w')))
))
sc.stop()
}
}

View file

@ -5,9 +5,10 @@ import spark.SparkContext
object BroadcastTest {
def main(args: Array[String]) {
if (args.length == 0) {
System.err.println("Usage: BroadcastTest <host> [<slices>]")
System.err.println("Usage: BroadcastTest <host> [<slices>] [numElem]")
System.exit(1)
}
val spark = new SparkContext(args(0), "Broadcast Test")
val slices = if (args.length > 1) args(1).toInt else 2
val num = if (args.length > 2) args(2).toInt else 1000000
@ -16,14 +17,8 @@ object BroadcastTest {
for (i <- 0 until arr1.length)
arr1(i) = i
// var arr2 = new Array[Int](num * 2)
// for (i <- 0 until arr2.length)
// arr2(i) = i
val barr1 = spark.broadcast(arr1)
// val barr2 = spark.broadcast(arr2)
spark.parallelize(1 to 10, slices).foreach {
// i => println(barr1.value.size + barr2.value.size)
i => println(barr1.value.size)
}
}

View file

@ -0,0 +1,37 @@
package spark.examples
import spark.SparkContext
import spark.SparkContext._
import java.util.Random
object GroupByTest {
def main(args: Array[String]) {
if (args.length == 0) {
System.err.println("Usage: GroupByTest <host> [numMappers] [numKVPairs] [KeySize] [numReducers]")
System.exit(1)
}
var numMappers = if (args.length > 1) args(1).toInt else 2
var numKVPairs = if (args.length > 2) args(2).toInt else 1000
var valSize = if (args.length > 3) args(3).toInt else 1000
var numReducers = if (args.length > 4) args(4).toInt else numMappers
val sc = new SparkContext(args(0), "GroupBy Test")
val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p =>
val ranGen = new Random
var arr1 = new Array[(Int, Array[Byte])](numKVPairs)
for (i <- 0 until numKVPairs) {
val byteArr = new Array[Byte](valSize)
ranGen.nextBytes(byteArr)
arr1(i) = (ranGen.nextInt(Int.MaxValue), byteArr)
}
arr1
}.cache
// Enforce that everything has been calculated and in cache
pairs1.count
println(pairs1.groupByKey(numReducers).count)
}
}

View file

@ -0,0 +1,30 @@
package spark.examples
import spark.SparkContext
object MultiBroadcastTest {
def main(args: Array[String]) {
if (args.length == 0) {
System.err.println("Usage: BroadcastTest <host> [<slices>] [numElem]")
System.exit(1)
}
val spark = new SparkContext(args(0), "Broadcast Test")
val slices = if (args.length > 1) args(1).toInt else 2
val num = if (args.length > 2) args(2).toInt else 1000000
var arr1 = new Array[Int](num)
for (i <- 0 until arr1.length)
arr1(i) = i
var arr2 = new Array[Int](num)
for (i <- 0 until arr2.length)
arr2(i) = i
val barr1 = spark.broadcast(arr1)
val barr2 = spark.broadcast(arr2)
spark.parallelize(1 to 10, slices).foreach {
i => println(barr1.value.size + barr2.value.size)
}
}
}

View file

@ -0,0 +1,51 @@
package spark.examples
import spark.SparkContext
import spark.SparkContext._
import java.util.Random
object SimpleSkewedGroupByTest {
def main(args: Array[String]) {
if (args.length == 0) {
System.err.println("Usage: SimpleSkewedGroupByTest <host> " +
"[numMappers] [numKVPairs] [valSize] [numReducers] [ratio]")
System.exit(1)
}
var numMappers = if (args.length > 1) args(1).toInt else 2
var numKVPairs = if (args.length > 2) args(2).toInt else 1000
var valSize = if (args.length > 3) args(3).toInt else 1000
var numReducers = if (args.length > 4) args(4).toInt else numMappers
var ratio = if (args.length > 5) args(5).toInt else 5.0
val sc = new SparkContext(args(0), "GroupBy Test")
val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p =>
val ranGen = new Random
var result = new Array[(Int, Array[Byte])](numKVPairs)
for (i <- 0 until numKVPairs) {
val byteArr = new Array[Byte](valSize)
ranGen.nextBytes(byteArr)
val offset = ranGen.nextInt(1000) * numReducers
if (ranGen.nextDouble < ratio / (numReducers + ratio - 1)) {
// give ratio times higher chance of generating key 0 (for reducer 0)
result(i) = (offset, byteArr)
} else {
// generate a key for one of the other reducers
val key = 1 + ranGen.nextInt(numReducers-1) + offset
result(i) = (key, byteArr)
}
}
result
}.cache
// Enforce that everything has been calculated and in cache
pairs1.count
println("RESULT: " + pairs1.groupByKey(numReducers).count)
// Print how many keys each reducer got (for debugging)
//println("RESULT: " + pairs1.groupByKey(numReducers)
// .map{case (k,v) => (k, v.size)}
// .collectAsMap)
}
}

View file

@ -0,0 +1,41 @@
package spark.examples
import spark.SparkContext
import spark.SparkContext._
import java.util.Random
object SkewedGroupByTest {
def main(args: Array[String]) {
if (args.length == 0) {
System.err.println("Usage: GroupByTest <host> [numMappers] [numKVPairs] [KeySize] [numReducers]")
System.exit(1)
}
var numMappers = if (args.length > 1) args(1).toInt else 2
var numKVPairs = if (args.length > 2) args(2).toInt else 1000
var valSize = if (args.length > 3) args(3).toInt else 1000
var numReducers = if (args.length > 4) args(4).toInt else numMappers
val sc = new SparkContext(args(0), "GroupBy Test")
val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p =>
val ranGen = new Random
// map output sizes lineraly increase from the 1st to the last
numKVPairs = (1. * (p + 1) / numMappers * numKVPairs).toInt
var arr1 = new Array[(Int, Array[Byte])](numKVPairs)
for (i <- 0 until numKVPairs) {
val byteArr = new Array[Byte](valSize)
ranGen.nextBytes(byteArr)
arr1(i) = (ranGen.nextInt(Int.MaxValue), byteArr)
}
arr1
}.cache
// Enforce that everything has been calculated and in cache
pairs1.count
println(pairs1.groupByKey(numReducers).count)
}
}