Change the default broadcast implementation to a simple HTTP-based broadcast. Fixes #139.
This commit is contained in:
Matei Zaharia 2012-06-09 15:58:07 -07:00
parent a96558caa3
commit e75b1b5cb4
5 changed files with 144 additions and 5 deletions

View file

@@ -76,6 +76,12 @@ object Utils {
}
} catch { case e: IOException => ; }
}
// Add a shutdown hook to delete the temp dir when the JVM exits
Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dir " + dir) {
override def run() {
Utils.deleteRecursively(dir)
}
})
return dir
}

View file

@@ -33,7 +33,7 @@ object Broadcast extends Logging with Serializable {
def initialize (isMaster__ : Boolean): Unit = synchronized {
if (!initialized) {
val broadcastFactoryClass = System.getProperty(
"spark.broadcast.factory", "spark.broadcast.DfsBroadcastFactory")
"spark.broadcast.factory", "spark.broadcast.HttpBroadcastFactory")
broadcastFactory =
Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
@@ -219,4 +219,4 @@ class SpeedTracker extends Serializable {
}
override def toString = sourceToSpeedMap.toString
}
}

View file

@@ -7,6 +7,6 @@ package spark.broadcast
* entire Spark job.
*/
trait BroadcastFactory {
  /** Set up the factory on this node; `isMaster` is true on the master/driver. */
  def initialize(isMaster: Boolean): Unit

  /** Create a broadcast variable wrapping `value_`; `isLocal` is true in local mode. */
  def newBroadcast[T](value_ : T, isLocal: Boolean): Broadcast[T]
}

View file

@@ -0,0 +1,110 @@
package spark.broadcast
import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}
import java.io._
import java.net._
import java.util.UUID
import it.unimi.dsi.fastutil.io.FastBufferedInputStream
import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
import spark._
/**
 * A broadcast variable whose value is served over HTTP from the master.
 * On construction (on the master) the value is cached locally and, unless
 * running in local mode, serialized to a file served by the broadcast HTTP
 * server. Workers fetch and cache the value lazily on first deserialization.
 *
 * NOTE(review): `uuid` is not defined in this file — presumably inherited
 * from Broadcast[T]; confirm against the superclass.
 */
class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean)
extends Broadcast[T] with Logging with Serializable {
def value = value_
// Pre-populate the local cache so tasks on the master never hit HTTP.
HttpBroadcast.synchronized {
HttpBroadcast.values.put(uuid, 0, value_)
}
// In non-local mode, also write the serialized value to a file that the
// master's HTTP server will serve to workers.
if (!isLocal) {
HttpBroadcast.write(uuid, value_)
}
// Called by JVM when deserializing an object
// (Java serialization hook: runs on each worker when a task containing this
// broadcast is deserialized. Checks the local cache first and fetches over
// HTTP only on a miss, so each node downloads the value at most once.)
private def readObject(in: ObjectInputStream): Unit = {
in.defaultReadObject()
HttpBroadcast.synchronized {
val cachedVal = HttpBroadcast.values.get(uuid, 0)
if (cachedVal != null) {
value_ = cachedVal.asInstanceOf[T]
} else {
logInfo("Started reading broadcast variable " + uuid)
val start = System.nanoTime
value_ = HttpBroadcast.read(uuid).asInstanceOf[T]
// Cache for subsequent tasks on this node.
HttpBroadcast.values.put(uuid, 0, value_)
val time = (System.nanoTime - start) / 1e9
logInfo("Reading broadcast variable " + uuid + " took " + time + " s")
}
}
}
}
/** A [[BroadcastFactory]] that creates HTTP-served broadcast variables. */
class HttpBroadcastFactory extends BroadcastFactory {
  /** Delegates one-time setup to the shared HttpBroadcast object. */
  def initialize(isMaster: Boolean): Unit = {
    HttpBroadcast.initialize(isMaster)
  }

  /** Wraps `value_` in a new HttpBroadcast. */
  def newBroadcast[T](value_ : T, isLocal: Boolean) = {
    new HttpBroadcast[T](value_, isLocal)
  }
}
private object HttpBroadcast extends Logging {
  // Per-node cache keyspace used to memoize deserialized broadcast values.
  val values = SparkEnv.get.cache.newKeySpace()

  private var initialized = false
  private var broadcastDir: File = null
  private var compress: Boolean = false
  private var bufferSize: Int = 65536
  private var serverUri: String = null
  private var server: HttpServer = null

  /**
   * One-time setup. On the master this starts an HTTP server over a fresh
   * temp directory and publishes its URI via the "spark.httpBroadcast.uri"
   * system property; every node then reads that property to locate the
   * server. Idempotent: repeated calls are no-ops.
   */
  def initialize(isMaster: Boolean): Unit = {
    synchronized {
      if (!initialized) {
        bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
        compress = System.getProperty("spark.compress", "false").toBoolean
        if (isMaster) {
          createServer()
        }
        serverUri = System.getProperty("spark.httpBroadcast.uri")
        initialized = true
      }
    }
  }

  // Start the HTTP server serving broadcast files from a new temp directory
  // and advertise its URI to workers through a system property.
  private def createServer() {
    broadcastDir = Utils.createTempDir()
    server = new HttpServer(broadcastDir)
    server.start()
    serverUri = server.uri
    System.setProperty("spark.httpBroadcast.uri", serverUri)
    logInfo("Broadcast server started at " + serverUri)
  }

  /**
   * Serialize `value` to a file named "broadcast-<uuid>" under the broadcast
   * directory (LZF-compressed when spark.compress is set) so the HTTP server
   * can serve it to workers.
   */
  def write(uuid: UUID, value: Any) {
    val file = new File(broadcastDir, "broadcast-" + uuid)
    val out: OutputStream = if (compress) {
      new LZFOutputStream(new FileOutputStream(file)) // Does its own buffering
    } else {
      new FastBufferedOutputStream(new FileOutputStream(file), bufferSize)
    }
    val ser = SparkEnv.get.serializer.newInstance()
    val serOut = ser.outputStream(out)
    serOut.writeObject(value)
    serOut.close()
  }

  /**
   * Fetch the broadcast value with the given UUID from the master's HTTP
   * server and deserialize it.
   *
   * @return the deserialized broadcast value
   */
  def read(uuid: UUID): Any = {
    val url = serverUri + "/broadcast-" + uuid
    // val, not var: the stream is never reassigned.
    val in = if (compress) {
      new LZFInputStream(new URL(url).openStream()) // Does its own buffering
    } else {
      new FastBufferedInputStream(new URL(url).openStream(), bufferSize)
    }
    val ser = SparkEnv.get.serializer.newInstance()
    val serIn = ser.inputStream(in)
    val obj = serIn.readObject()
    serIn.close()
    // BUG FIX: the last expression was serIn.close(), so read() returned
    // Unit (which conforms to Any) and the deserialized value was discarded.
    obj
  }
}

View file

@@ -0,0 +1,23 @@
package spark
import org.scalatest.FunSuite
/** Tests that broadcast variables deliver the same value to every task. */
class BroadcastSuite extends FunSuite {

  test("basic broadcast") {
    val sc = new SparkContext("local", "test")
    val data = List(1, 2, 3, 4)
    val broadcasted = sc.broadcast(data)
    // Each of the two tasks should see the full list (sum = 10).
    val results = sc.parallelize(1 to 2).map(i => (i, broadcasted.value.sum))
    assert(results.collect.toSet === Set((1, 10), (2, 10)))
    sc.stop()
  }

  test("broadcast variables accessed in multiple threads") {
    // local[10] runs tasks on ten threads concurrently.
    val sc = new SparkContext("local[10]", "test")
    val data = List(1, 2, 3, 4)
    val broadcasted = sc.broadcast(data)
    val results = sc.parallelize(1 to 10).map(i => (i, broadcasted.value.sum))
    assert(results.collect.toSet === (1 to 10).map(i => (i, 10)).toSet)
    sc.stop()
  }
}