[SPARK-2897][SPARK-2920] TorrentBroadcast now uses the serializer class specified in the Spark option "spark.serializer"

Author: GuoQiang Li <witgo@qq.com>

Closes #1836 from witgo/SPARK-2897 and squashes the following commits:

23cdc5b [GuoQiang Li] review commit
ada4fba [GuoQiang Li] TorrentBroadcast does not support broadcast compression
fb91792 [GuoQiang Li] org.apache.spark.broadcast.TorrentBroadcast now uses the serializer class specified in the Spark option "spark.serializer"
This commit is contained in:
GuoQiang Li 2014-08-08 16:57:26 -07:00 committed by Reynold Xin
parent 74d6f62264
commit ec79063fad
2 changed files with 33 additions and 8 deletions

View file

@ -17,14 +17,15 @@
package org.apache.spark.broadcast
import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream}
import java.io.{ByteArrayOutputStream, ByteArrayInputStream, InputStream,
ObjectInputStream, ObjectOutputStream, OutputStream}
import scala.reflect.ClassTag
import scala.util.Random
import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
import org.apache.spark.util.Utils
/**
* A [[org.apache.spark.broadcast.Broadcast]] implementation that uses a BitTorrent-like
@ -214,11 +215,15 @@ private[broadcast] object TorrentBroadcast extends Logging {
private lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024
private var initialized = false
private var conf: SparkConf = null
private var compress: Boolean = false
private var compressionCodec: CompressionCodec = null
def initialize(_isDriver: Boolean, conf: SparkConf) {
TorrentBroadcast.conf = conf // TODO: we might have to fix it in tests
synchronized {
if (!initialized) {
compress = conf.getBoolean("spark.broadcast.compress", true)
compressionCodec = CompressionCodec.createCodec(conf)
initialized = true
}
}
@ -228,8 +233,13 @@ private[broadcast] object TorrentBroadcast extends Logging {
initialized = false
}
def blockifyObject[T](obj: T): TorrentInfo = {
val byteArray = Utils.serialize[T](obj)
def blockifyObject[T: ClassTag](obj: T): TorrentInfo = {
val bos = new ByteArrayOutputStream()
val out: OutputStream = if (compress) compressionCodec.compressedOutputStream(bos) else bos
val ser = SparkEnv.get.serializer.newInstance()
val serOut = ser.serializeStream(out)
serOut.writeObject[T](obj).close()
val byteArray = bos.toByteArray
val bais = new ByteArrayInputStream(byteArray)
var blockNum = byteArray.length / BLOCK_SIZE
@ -255,7 +265,7 @@ private[broadcast] object TorrentBroadcast extends Logging {
info
}
def unBlockifyObject[T](
def unBlockifyObject[T: ClassTag](
arrayOfBlocks: Array[TorrentBlock],
totalBytes: Int,
totalBlocks: Int): T = {
@ -264,7 +274,16 @@ private[broadcast] object TorrentBroadcast extends Logging {
System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
i * BLOCK_SIZE, arrayOfBlocks(i).byteArray.length)
}
Utils.deserialize[T](retByteArray, Thread.currentThread.getContextClassLoader)
val in: InputStream = {
val arrIn = new ByteArrayInputStream(retByteArray)
if (compress) compressionCodec.compressedInputStream(arrIn) else arrIn
}
val ser = SparkEnv.get.serializer.newInstance()
val serIn = ser.deserializeStream(in)
val obj = serIn.readObject[T]()
serIn.close()
obj
}
/**

View file

@ -44,7 +44,10 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
test("Accessing HttpBroadcast variables in a local cluster") {
val numSlaves = 4
sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", httpConf)
val conf = httpConf.clone
conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
conf.set("spark.broadcast.compress", "true")
sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", conf)
val list = List[Int](1, 2, 3, 4)
val broadcast = sc.broadcast(list)
val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum))
@ -69,7 +72,10 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
test("Accessing TorrentBroadcast variables in a local cluster") {
val numSlaves = 4
sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", torrentConf)
val conf = torrentConf.clone
conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
conf.set("spark.broadcast.compress", "true")
sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", conf)
val list = List[Int](1, 2, 3, 4)
val broadcast = sc.broadcast(list)
val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum))