Updated according to Matei's code review comment.
This commit is contained in:
parent
dd7bef3147
commit
2bc895a829
|
@ -10,7 +10,7 @@ private[spark] abstract class ShuffleFetcher {
|
|||
* @return An iterator over the elements of the fetched shuffle outputs.
|
||||
*/
|
||||
def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics,
|
||||
serializer: Serializer = Serializer.default): Iterator[(K,V)]
|
||||
serializer: Serializer = SparkEnv.get.serializerManager.default): Iterator[(K,V)]
|
||||
|
||||
/** Stop the fetcher */
|
||||
def stop() {}
|
||||
|
|
|
@ -7,7 +7,7 @@ import spark.broadcast.BroadcastManager
|
|||
import spark.storage.BlockManager
|
||||
import spark.storage.BlockManagerMaster
|
||||
import spark.network.ConnectionManager
|
||||
import spark.serializer.Serializer
|
||||
import spark.serializer.{Serializer, SerializerManager}
|
||||
import spark.util.AkkaUtils
|
||||
|
||||
|
||||
|
@ -21,6 +21,7 @@ import spark.util.AkkaUtils
|
|||
class SparkEnv (
|
||||
val executorId: String,
|
||||
val actorSystem: ActorSystem,
|
||||
val serializerManager: SerializerManager,
|
||||
val serializer: Serializer,
|
||||
val closureSerializer: Serializer,
|
||||
val cacheManager: CacheManager,
|
||||
|
@ -92,10 +93,12 @@ object SparkEnv extends Logging {
|
|||
Class.forName(name, true, classLoader).newInstance().asInstanceOf[T]
|
||||
}
|
||||
|
||||
val serializer = Serializer.setDefault(
|
||||
val serializerManager = new SerializerManager
|
||||
|
||||
val serializer = serializerManager.setDefault(
|
||||
System.getProperty("spark.serializer", "spark.JavaSerializer"))
|
||||
|
||||
val closureSerializer = Serializer.get(
|
||||
val closureSerializer = serializerManager.get(
|
||||
System.getProperty("spark.closure.serializer", "spark.JavaSerializer"))
|
||||
|
||||
def registerOrLookup(name: String, newActor: => Actor): ActorRef = {
|
||||
|
@ -155,6 +158,7 @@ object SparkEnv extends Logging {
|
|||
new SparkEnv(
|
||||
executorId,
|
||||
actorSystem,
|
||||
serializerManager,
|
||||
serializer,
|
||||
closureSerializer,
|
||||
cacheManager,
|
||||
|
|
|
@ -8,7 +8,6 @@ import scala.collection.mutable.ArrayBuffer
|
|||
|
||||
import spark.{Aggregator, Logging, Partition, Partitioner, RDD, SparkEnv, TaskContext}
|
||||
import spark.{Dependency, OneToOneDependency, ShuffleDependency}
|
||||
import spark.serializer.Serializer
|
||||
|
||||
|
||||
private[spark] sealed trait CoGroupSplitDep extends Serializable
|
||||
|
@ -114,7 +113,7 @@ class CoGroupedRDD[K](
|
|||
}
|
||||
}
|
||||
|
||||
val ser = Serializer.get(serializerClass)
|
||||
val ser = SparkEnv.get.serializerManager.get(serializerClass)
|
||||
for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
|
||||
case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
|
||||
// Read them from the parent
|
||||
|
|
|
@ -2,7 +2,6 @@ package spark.rdd
|
|||
|
||||
import spark.{Partitioner, RDD, SparkEnv, ShuffleDependency, Partition, TaskContext}
|
||||
import spark.SparkContext._
|
||||
import spark.serializer.Serializer
|
||||
|
||||
|
||||
private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
|
||||
|
@ -32,7 +31,7 @@ class ShuffledRDD[K, V](
|
|||
|
||||
override def compute(split: Partition, context: TaskContext): Iterator[(K, V)] = {
|
||||
val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
|
||||
SparkEnv.get.shuffleFetcher.fetch[K, V](
|
||||
shuffledId, split.index, context.taskMetrics, Serializer.get(serializerClass))
|
||||
SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index, context.taskMetrics,
|
||||
SparkEnv.get.serializerManager.get(serializerClass))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,7 +11,7 @@ import spark.Partition
|
|||
import spark.SparkEnv
|
||||
import spark.ShuffleDependency
|
||||
import spark.OneToOneDependency
|
||||
import spark.serializer.Serializer
|
||||
|
||||
|
||||
/**
|
||||
* An optimized version of cogroup for set difference/subtraction.
|
||||
|
@ -68,7 +68,7 @@ private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassM
|
|||
|
||||
override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = {
|
||||
val partition = p.asInstanceOf[CoGroupPartition]
|
||||
val serializer = Serializer.get(serializerClass)
|
||||
val serializer = SparkEnv.get.serializerManager.get(serializerClass)
|
||||
val map = new JHashMap[K, ArrayBuffer[V]]
|
||||
def getSeq(k: K): ArrayBuffer[V] = {
|
||||
val seq = map.get(k)
|
||||
|
|
|
@ -14,7 +14,6 @@ import com.ning.compress.lzf.LZFOutputStream
|
|||
|
||||
import spark._
|
||||
import spark.executor.ShuffleWriteMetrics
|
||||
import spark.serializer.Serializer
|
||||
import spark.storage._
|
||||
import spark.util.{TimeStampedHashMap, MetadataCleaner}
|
||||
|
||||
|
@ -139,12 +138,12 @@ private[spark] class ShuffleMapTask(
|
|||
metrics = Some(taskContext.taskMetrics)
|
||||
|
||||
val blockManager = SparkEnv.get.blockManager
|
||||
var shuffle: ShuffleBlockManager#Shuffle = null
|
||||
var shuffle: ShuffleBlocks = null
|
||||
var buckets: ShuffleWriterGroup = null
|
||||
|
||||
try {
|
||||
// Obtain all the block writers for shuffle blocks.
|
||||
val ser = Serializer.get(dep.serializerClass)
|
||||
val ser = SparkEnv.get.serializerManager.get(dep.serializerClass)
|
||||
shuffle = blockManager.shuffleBlockManager.forShuffle(dep.shuffleId, numOutputSplits, ser)
|
||||
buckets = shuffle.acquireWriters(partition)
|
||||
|
||||
|
|
|
@ -2,7 +2,6 @@ package spark.serializer
|
|||
|
||||
import java.io.{EOFException, InputStream, OutputStream}
|
||||
import java.nio.ByteBuffer
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
|
||||
import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
|
||||
|
||||
|
@ -19,47 +18,6 @@ trait Serializer {
|
|||
}
|
||||
|
||||
|
||||
/**
|
||||
* A singleton object that can be used to fetch serializer objects based on the serializer
|
||||
* class name. If a previous instance of the serializer object has been created, the get
|
||||
* method returns that instead of creating a new one.
|
||||
*/
|
||||
object Serializer {
|
||||
|
||||
private val serializers = new ConcurrentHashMap[String, Serializer]
|
||||
private var _default: Serializer = _
|
||||
|
||||
def default = _default
|
||||
|
||||
def setDefault(clsName: String): Serializer = {
|
||||
_default = get(clsName)
|
||||
_default
|
||||
}
|
||||
|
||||
def get(clsName: String): Serializer = {
|
||||
if (clsName == null) {
|
||||
default
|
||||
} else {
|
||||
var serializer = serializers.get(clsName)
|
||||
if (serializer != null) {
|
||||
// If the serializer has been created previously, reuse that.
|
||||
serializer
|
||||
} else this.synchronized {
|
||||
// Otherwise, create a new one. But make sure no other thread has attempted
|
||||
// to create another new one at the same time.
|
||||
serializer = serializers.get(clsName)
|
||||
if (serializer == null) {
|
||||
val clsLoader = Thread.currentThread.getContextClassLoader
|
||||
serializer = Class.forName(clsName, true, clsLoader).newInstance().asInstanceOf[Serializer]
|
||||
serializers.put(clsName, serializer)
|
||||
}
|
||||
serializer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* An instance of a serializer, for use by one thread at a time.
|
||||
*/
|
||||
|
|
45
core/src/main/scala/spark/serializer/SerializerManager.scala
Normal file
45
core/src/main/scala/spark/serializer/SerializerManager.scala
Normal file
|
@ -0,0 +1,45 @@
|
|||
package spark.serializer
|
||||
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
|
||||
|
||||
/**
|
||||
* A service that returns a serializer object given the serializer's class name. If a previous
|
||||
* instance of the serializer object has been created, the get method returns that instead of
|
||||
* creating a new one.
|
||||
*/
|
||||
private[spark] class SerializerManager {
|
||||
|
||||
private val serializers = new ConcurrentHashMap[String, Serializer]
|
||||
private var _default: Serializer = _
|
||||
|
||||
def default = _default
|
||||
|
||||
def setDefault(clsName: String): Serializer = {
|
||||
_default = get(clsName)
|
||||
_default
|
||||
}
|
||||
|
||||
def get(clsName: String): Serializer = {
|
||||
if (clsName == null) {
|
||||
default
|
||||
} else {
|
||||
var serializer = serializers.get(clsName)
|
||||
if (serializer != null) {
|
||||
// If the serializer has been created previously, reuse that.
|
||||
serializer
|
||||
} else this.synchronized {
|
||||
// Otherwise, create a new one. But make sure no other thread has attempted
|
||||
// to create another new one at the same time.
|
||||
serializer = serializers.get(clsName)
|
||||
if (serializer == null) {
|
||||
val clsLoader = Thread.currentThread.getContextClassLoader
|
||||
serializer =
|
||||
Class.forName(clsName, true, clsLoader).newInstance().asInstanceOf[Serializer]
|
||||
serializers.put(clsName, serializer)
|
||||
}
|
||||
serializer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -2,6 +2,7 @@ package spark.storage
|
|||
|
||||
import java.io.{File, FileOutputStream, OutputStream, RandomAccessFile}
|
||||
import java.nio.ByteBuffer
|
||||
import java.nio.channels.FileChannel
|
||||
import java.nio.channels.FileChannel.MapMode
|
||||
import java.util.{Random, Date}
|
||||
import java.text.SimpleDateFormat
|
||||
|
@ -26,14 +27,16 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
|
|||
|
||||
private val f: File = createFile(blockId /*, allowAppendExisting */)
|
||||
|
||||
private var repositionableStream: FastBufferedOutputStream = null
|
||||
// The file channel, used for repositioning / truncating the file.
|
||||
private var channel: FileChannel = null
|
||||
private var bs: OutputStream = null
|
||||
private var objOut: SerializationStream = null
|
||||
private var validLength = 0L
|
||||
private var lastValidPosition = 0L
|
||||
|
||||
override def open(): DiskBlockObjectWriter = {
|
||||
repositionableStream = new FastBufferedOutputStream(new FileOutputStream(f))
|
||||
bs = blockManager.wrapForCompression(blockId, repositionableStream)
|
||||
val fos = new FileOutputStream(f, true)
|
||||
channel = fos.getChannel()
|
||||
bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos))
|
||||
objOut = serializer.newInstance().serializeStream(bs)
|
||||
this
|
||||
}
|
||||
|
@ -41,9 +44,9 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
|
|||
override def close() {
|
||||
objOut.close()
|
||||
bs.close()
|
||||
objOut = null
|
||||
channel = null
|
||||
bs = null
|
||||
repositionableStream = null
|
||||
objOut = null
|
||||
// Invoke the close callback handler.
|
||||
super.close()
|
||||
}
|
||||
|
@ -54,25 +57,23 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
|
|||
// Return the number of bytes written for this commit.
|
||||
override def commit(): Long = {
|
||||
bs.flush()
|
||||
validLength = repositionableStream.position()
|
||||
validLength
|
||||
val prevPos = lastValidPosition
|
||||
lastValidPosition = channel.position()
|
||||
lastValidPosition - prevPos
|
||||
}
|
||||
|
||||
override def revertPartialWrites() {
|
||||
// Flush the outstanding writes and delete the file.
|
||||
objOut.close()
|
||||
bs.close()
|
||||
objOut = null
|
||||
bs = null
|
||||
repositionableStream = null
|
||||
f.delete()
|
||||
// Discard current writes. We do this by flushing the outstanding writes and
|
||||
// truncate the file to the last valid position.
|
||||
bs.flush()
|
||||
channel.truncate(lastValidPosition)
|
||||
}
|
||||
|
||||
override def write(value: Any) {
|
||||
objOut.writeObject(value)
|
||||
}
|
||||
|
||||
override def size(): Long = validLength
|
||||
override def size(): Long = lastValidPosition
|
||||
}
|
||||
|
||||
val MAX_DIR_CREATION_ATTEMPTS: Int = 10
|
||||
|
@ -86,7 +87,8 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
|
|||
|
||||
addShutdownHook()
|
||||
|
||||
def getBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int): BlockObjectWriter = {
|
||||
def getBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int)
|
||||
: BlockObjectWriter = {
|
||||
new DiskBlockObjectWriter(blockId, serializer, bufferSize)
|
||||
}
|
||||
|
||||
|
|
|
@ -7,27 +7,31 @@ private[spark]
|
|||
class ShuffleWriterGroup(val id: Int, val writers: Array[BlockObjectWriter])
|
||||
|
||||
|
||||
private[spark]
|
||||
trait ShuffleBlocks {
|
||||
def acquireWriters(mapId: Int): ShuffleWriterGroup
|
||||
def releaseWriters(group: ShuffleWriterGroup)
|
||||
}
|
||||
|
||||
|
||||
private[spark]
|
||||
class ShuffleBlockManager(blockManager: BlockManager) {
|
||||
|
||||
def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer): Shuffle = {
|
||||
new Shuffle(shuffleId, numBuckets, serializer)
|
||||
}
|
||||
|
||||
class Shuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer) {
|
||||
|
||||
// Get a group of writers for a map task.
|
||||
def acquireWriters(mapId: Int): ShuffleWriterGroup = {
|
||||
val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
|
||||
val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
|
||||
val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId)
|
||||
blockManager.getDiskBlockWriter(blockId, serializer, bufferSize).open()
|
||||
def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer): ShuffleBlocks = {
|
||||
new ShuffleBlocks {
|
||||
// Get a group of writers for a map task.
|
||||
override def acquireWriters(mapId: Int): ShuffleWriterGroup = {
|
||||
val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
|
||||
val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
|
||||
val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId)
|
||||
blockManager.getDiskBlockWriter(blockId, serializer, bufferSize).open()
|
||||
}
|
||||
new ShuffleWriterGroup(mapId, writers)
|
||||
}
|
||||
new ShuffleWriterGroup(mapId, writers)
|
||||
}
|
||||
|
||||
def releaseWriters(group: ShuffleWriterGroup) = {
|
||||
// Nothing really to release here.
|
||||
override def releaseWriters(group: ShuffleWriterGroup) = {
|
||||
// Nothing really to release here.
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue