diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index cfd6dc8b2a..68ccab24db 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -76,6 +76,12 @@ object Utils {
         }
       } catch { case e: IOException => ; }
     }
+    // Add a shutdown hook to delete the temp dir when the JVM exits
+    Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dir " + dir) {
+      override def run() {
+        Utils.deleteRecursively(dir)
+      }
+    })
     return dir
   }
 
diff --git a/core/src/main/scala/spark/broadcast/Broadcast.scala b/core/src/main/scala/spark/broadcast/Broadcast.scala
index cdf05fe5de..06049749a9 100644
--- a/core/src/main/scala/spark/broadcast/Broadcast.scala
+++ b/core/src/main/scala/spark/broadcast/Broadcast.scala
@@ -33,7 +33,7 @@ object Broadcast extends Logging with Serializable {
   def initialize (isMaster__ : Boolean): Unit = synchronized {
     if (!initialized) {
       val broadcastFactoryClass = System.getProperty(
-        "spark.broadcast.factory", "spark.broadcast.DfsBroadcastFactory")
+        "spark.broadcast.factory", "spark.broadcast.HttpBroadcastFactory")
 
       broadcastFactory =
         Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
@@ -219,4 +219,4 @@ class SpeedTracker extends Serializable {
   }
 
   override def toString = sourceToSpeedMap.toString
-}
\ No newline at end of file
+}
diff --git a/core/src/main/scala/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/spark/broadcast/BroadcastFactory.scala
index 341746d18e..b18908f789 100644
--- a/core/src/main/scala/spark/broadcast/BroadcastFactory.scala
+++ b/core/src/main/scala/spark/broadcast/BroadcastFactory.scala
@@ -7,6 +7,6 @@ package spark.broadcast
  * entire Spark job.
  */
 trait BroadcastFactory {
-  def initialize (isMaster: Boolean): Unit
-  def newBroadcast[T] (value_ : T, isLocal: Boolean): Broadcast[T]
-}
\ No newline at end of file
+  def initialize(isMaster: Boolean): Unit
+  def newBroadcast[T](value_ : T, isLocal: Boolean): Broadcast[T]
+}
diff --git a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala
new file mode 100644
index 0000000000..4714816591
--- /dev/null
+++ b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala
@@ -0,0 +1,114 @@
+package spark.broadcast
+
+import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}
+
+import java.io._
+import java.net._
+import java.util.UUID
+
+import it.unimi.dsi.fastutil.io.FastBufferedInputStream
+import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
+
+import spark._
+
+class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean)
+extends Broadcast[T] with Logging with Serializable {
+
+  def value = value_
+
+  HttpBroadcast.synchronized {
+    HttpBroadcast.values.put(uuid, 0, value_)
+  }
+
+  if (!isLocal) {
+    HttpBroadcast.write(uuid, value_)
+  }
+
+  // Called by JVM when deserializing an object
+  private def readObject(in: ObjectInputStream): Unit = {
+    in.defaultReadObject()
+    HttpBroadcast.synchronized {
+      val cachedVal = HttpBroadcast.values.get(uuid, 0)
+      if (cachedVal != null) {
+        value_ = cachedVal.asInstanceOf[T]
+      } else {
+        logInfo("Started reading broadcast variable " + uuid)
+        val start = System.nanoTime
+        value_ = HttpBroadcast.read(uuid).asInstanceOf[T]
+        HttpBroadcast.values.put(uuid, 0, value_)
+        val time = (System.nanoTime - start) / 1e9
+        logInfo("Reading broadcast variable " + uuid + " took " + time + " s")
+      }
+    }
+  }
+}
+
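+// Instantiated reflectively via Class.forName(...).newInstance when
+// "spark.broadcast.factory" is left at its default of
+// "spark.broadcast.HttpBroadcastFactory" (see Broadcast.initialize).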
+class HttpBroadcastFactory extends BroadcastFactory {
+  def initialize(isMaster: Boolean): Unit = HttpBroadcast.initialize(isMaster)
+  def newBroadcast[T](value_ : T, isLocal: Boolean) = new HttpBroadcast[T](value_, isLocal)
+}
+
+private object HttpBroadcast extends Logging {
+  val values = SparkEnv.get.cache.newKeySpace()
+
+  private var initialized = false
+
+  private var broadcastDir: File = null
+  private var compress: Boolean = false
+  private var bufferSize: Int = 65536
+  private var serverUri: String = null
+  private var server: HttpServer = null
+
+  def initialize(isMaster: Boolean): Unit = {
+    synchronized {
+      if (!initialized) {
+        bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+        compress = System.getProperty("spark.compress", "false").toBoolean
+        if (isMaster) {
+          createServer()
+        }
+        serverUri = System.getProperty("spark.httpBroadcast.uri")
+        initialized = true
+      }
+    }
+  }
+
+  private def createServer() {
+    broadcastDir = Utils.createTempDir()
+    server = new HttpServer(broadcastDir)
+    server.start()
+    serverUri = server.uri
+    System.setProperty("spark.httpBroadcast.uri", serverUri)
+    logInfo("Broadcast server started at " + serverUri)
+  }
+
+  def write(uuid: UUID, value: Any) {
+    val file = new File(broadcastDir, "broadcast-" + uuid)
+    val out: OutputStream = if (compress) {
+      new LZFOutputStream(new FileOutputStream(file)) // Does its own buffering
+    } else {
+      new FastBufferedOutputStream(new FileOutputStream(file), bufferSize)
+    }
+    val ser = SparkEnv.get.serializer.newInstance()
+    val serOut = ser.outputStream(out)
+    serOut.writeObject(value)
+    serOut.close()
+  }
+
+  def read(uuid: UUID): Any = {
+    val url = serverUri + "/broadcast-" + uuid
+    val in = if (compress) {
+      new LZFInputStream(new URL(url).openStream()) // Does its own buffering
+    } else {
+      new FastBufferedInputStream(new URL(url).openStream(), bufferSize)
+    }
+    val ser = SparkEnv.get.serializer.newInstance()
+    val serIn = ser.inputStream(in)
+    val obj = serIn.readObject()
+    serIn.close()
+    obj  // return the deserialized value
+  }
+}
diff --git a/core/src/test/scala/spark/BroadcastSuite.scala b/core/src/test/scala/spark/BroadcastSuite.scala
new file mode 100644
index 0000000000..750703de30
--- /dev/null
+++ b/core/src/test/scala/spark/BroadcastSuite.scala
@@ -0,0 +1,35 @@
+package spark
+
+import org.scalatest.FunSuite
+
+class BroadcastSuite extends FunSuite {
+  test("basic broadcast") {
+    val sc = new SparkContext("local", "test")
+    val list = List(1, 2, 3, 4)
+    val listBroadcast = sc.broadcast(list)
+    val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum))
+    assert(results.collect.toSet === Set((1, 10), (2, 10)))
+    sc.stop()
+  }
+
+  test("broadcast variables accessed in multiple threads") {
+    val sc = new SparkContext("local[10]", "test")
+    val list = List(1, 2, 3, 4)
+    val listBroadcast = sc.broadcast(list)
+    val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum))
+    assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet)
+    sc.stop()
+  }
+}
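+
+// Note: both tests use a local master, so isLocal is true and each task reads
+// the value from the in-process cache; the HTTP write/read path is not hit.
+// A hypothetical cluster-mode test (clusterUrl is a placeholder) would be:
+//
+//   test("broadcast over HTTP") {
+//     val sc = new SparkContext(clusterUrl, "test")
+//     val listBroadcast = sc.broadcast(List(1, 2, 3, 4))
+//     val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum))
+//     assert(results.collect.toSet === Set((1, 10), (2, 10)))
+//     sc.stop()
+//   }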