Change the default broadcast implementation to a simple HTTP-based
broadcast. Fixes #139.
commit e75b1b5cb4 (parent a96558caa3)
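In short: the master now serializes each broadcast value to a file named broadcast-<uuid> inside a temp directory and serves it from an embedded HttpServer; workers fetch it with a plain HTTP GET on first access and cache it locally. User code is unaffected, since broadcasting still goes through SparkContext. A minimal usage sketch (not part of the diff; the Map payload is just an example):

    val sc = new SparkContext("local", "example")
    // sc.broadcast now hands back an HTTP-backed broadcast by default
    val lookup = sc.broadcast(Map(1 -> "a", 2 -> "b"))
    val resolved = sc.parallelize(1 to 2).map(x => lookup.value(x)).collect()
    sc.stop()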
core/src/main/scala/spark/Utils.scala

@@ -76,6 +76,12 @@ object Utils {
         }
       } catch { case e: IOException => ; }
     }
+    // Add a shutdown hook to delete the temp dir when the JVM exits
+    Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dir " + dir) {
+      override def run() {
+        Utils.deleteRecursively(dir)
+      }
+    })
     return dir
   }
 
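This hunk supports the new broadcast server: HttpBroadcast.createServer() in the new file below serves files straight out of a Utils.createTempDir() directory, and the shutdown hook guarantees those files are removed when the JVM exits. A sketch of the caller's view (the zero-argument call matches how createServer() invokes it below):

    val dir = Utils.createTempDir()
    // anything written here is deleted by the shutdown hook on JVM exit
    val payload = new java.io.File(dir, "broadcast-example")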
core/src/main/scala/spark/broadcast/Broadcast.scala

@@ -33,7 +33,7 @@ object Broadcast extends Logging with Serializable {
   def initialize (isMaster__ : Boolean): Unit = synchronized {
     if (!initialized) {
       val broadcastFactoryClass = System.getProperty(
-        "spark.broadcast.factory", "spark.broadcast.DfsBroadcastFactory")
+        "spark.broadcast.factory", "spark.broadcast.HttpBroadcastFactory")
 
       broadcastFactory =
         Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
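Because the factory is resolved through a system property, the previous default stays reachable. A sketch of opting back into it (property name and class name copied from the hunk above); the property must be set before the first broadcast is created, since initialize reads it only once:

    // Set before constructing the SparkContext / creating any broadcast
    System.setProperty("spark.broadcast.factory", "spark.broadcast.DfsBroadcastFactory")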
The remaining hunk in this file appears to change only trailing whitespace on the final line:

@@ -219,4 +219,4 @@ class SpeedTracker extends Serializable {
   }
 
   override def toString = sourceToSpeedMap.toString
-}
+}
core/src/main/scala/spark/broadcast/BroadcastFactory.scala

@@ -7,6 +7,6 @@ package spark.broadcast
  * entire Spark job.
  */
 trait BroadcastFactory {
-  def initialize (isMaster: Boolean): Unit
-  def newBroadcast[T] (value_ : T, isLocal: Boolean): Broadcast[T]
+  def initialize(isMaster: Boolean): Unit
+  def newBroadcast[T](value_ : T, isLocal: Boolean): Broadcast[T]
 }
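This trait is the extension point that the reflective Class.forName lookup above instantiates, so any two-method implementation can be dropped in via spark.broadcast.factory. A hypothetical sketch (LoggingBroadcastFactory is not part of this commit) that wraps the HTTP implementation defined later in this diff:

    package spark.broadcast

    import spark._

    // Hypothetical: delegates to HttpBroadcastFactory and logs each broadcast
    class LoggingBroadcastFactory extends BroadcastFactory with Logging {
      private val underlying = new HttpBroadcastFactory
      def initialize(isMaster: Boolean): Unit = underlying.initialize(isMaster)
      def newBroadcast[T](value_ : T, isLocal: Boolean): Broadcast[T] = {
        logInfo("Creating broadcast (isLocal = " + isLocal + ")")
        underlying.newBroadcast[T](value_, isLocal)
      }
    }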
core/src/main/scala/spark/broadcast/HttpBroadcast.scala (new file, 110 lines)

@@ -0,0 +1,110 @@
package spark.broadcast

import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}

import java.io._
import java.net._
import java.util.UUID

import it.unimi.dsi.fastutil.io.FastBufferedInputStream
import it.unimi.dsi.fastutil.io.FastBufferedOutputStream

import spark._

class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean)
extends Broadcast[T] with Logging with Serializable {

  def value = value_

  HttpBroadcast.synchronized {
    HttpBroadcast.values.put(uuid, 0, value_)
  }

  if (!isLocal) {
    HttpBroadcast.write(uuid, value_)
  }

  // Called by JVM when deserializing an object
  private def readObject(in: ObjectInputStream): Unit = {
    in.defaultReadObject()
    HttpBroadcast.synchronized {
      val cachedVal = HttpBroadcast.values.get(uuid, 0)
      if (cachedVal != null) {
        value_ = cachedVal.asInstanceOf[T]
      } else {
        logInfo("Started reading broadcast variable " + uuid)
        val start = System.nanoTime
        value_ = HttpBroadcast.read(uuid).asInstanceOf[T]
        HttpBroadcast.values.put(uuid, 0, value_)
        val time = (System.nanoTime - start) / 1e9
        logInfo("Reading broadcast variable " + uuid + " took " + time + " s")
      }
    }
  }
}

class HttpBroadcastFactory extends BroadcastFactory {
  def initialize(isMaster: Boolean): Unit = HttpBroadcast.initialize(isMaster)
  def newBroadcast[T](value_ : T, isLocal: Boolean) = new HttpBroadcast[T](value_, isLocal)
}

private object HttpBroadcast extends Logging {
  val values = SparkEnv.get.cache.newKeySpace()

  private var initialized = false

  private var broadcastDir: File = null
  private var compress: Boolean = false
  private var bufferSize: Int = 65536
  private var serverUri: String = null
  private var server: HttpServer = null

  def initialize(isMaster: Boolean): Unit = {
    synchronized {
      if (!initialized) {
        bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
        compress = System.getProperty("spark.compress", "false").toBoolean
        if (isMaster) {
          createServer()
        }
        serverUri = System.getProperty("spark.httpBroadcast.uri")
        initialized = true
      }
    }
  }

  private def createServer() {
    broadcastDir = Utils.createTempDir()
    server = new HttpServer(broadcastDir)
    server.start()
    serverUri = server.uri
    System.setProperty("spark.httpBroadcast.uri", serverUri)
    logInfo("Broadcast server started at " + serverUri)
  }

  def write(uuid: UUID, value: Any) {
    val file = new File(broadcastDir, "broadcast-" + uuid)
    val out: OutputStream = if (compress) {
      new LZFOutputStream(new FileOutputStream(file)) // Does its own buffering
    } else {
      new FastBufferedOutputStream(new FileOutputStream(file), bufferSize)
    }
    val ser = SparkEnv.get.serializer.newInstance()
    val serOut = ser.outputStream(out)
    serOut.writeObject(value)
    serOut.close()
  }

  def read(uuid: UUID): Any = {
    val url = serverUri + "/broadcast-" + uuid
    var in = if (compress) {
      new LZFInputStream(new URL(url).openStream()) // Does its own buffering
    } else {
      new FastBufferedInputStream(new URL(url).openStream(), bufferSize)
    }
    val ser = SparkEnv.get.serializer.newInstance()
    val serIn = ser.inputStream(in)
    val obj = serIn.readObject()
    serIn.close()
    obj // return the deserialized value
  }
}
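The wire protocol is just an HTTP GET of a file whose name encodes the broadcast's UUID. A sketch of a worker-side fetch, assuming no compression and using plain Java serialization for illustration (the real code goes through SparkEnv.get.serializer):

    import java.io.ObjectInputStream
    import java.net.URL
    import java.util.UUID

    def fetch(serverUri: String, uuid: UUID): AnyRef = {
      // e.g. http://<master>:<port>/broadcast-<uuid>, as built in HttpBroadcast.read
      val in = new ObjectInputStream(new URL(serverUri + "/broadcast-" + uuid).openStream())
      try in.readObject() finally in.close()
    }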
core/src/test/scala/spark/BroadcastSuite.scala (new file, 23 lines)

@@ -0,0 +1,23 @@
package spark

import org.scalatest.FunSuite

class BroadcastSuite extends FunSuite {
  test("basic broadcast") {
    val sc = new SparkContext("local", "test")
    val list = List(1, 2, 3, 4)
    val listBroadcast = sc.broadcast(list)
    val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum))
    assert(results.collect.toSet === Set((1, 10), (2, 10)))
    sc.stop()
  }

  test("broadcast variables accessed in multiple threads") {
    val sc = new SparkContext("local[10]", "test")
    val list = List(1, 2, 3, 4)
    val listBroadcast = sc.broadcast(list)
    val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum))
    assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet)
    sc.stop()
  }
}