Updated according to Matei's code review comment.

Reynold Xin 2013-05-03 01:02:16 -07:00
parent dd7bef3147
commit 2bc895a829
10 changed files with 99 additions and 89 deletions

spark/ShuffleFetcher.scala

@@ -10,7 +10,7 @@ private[spark] abstract class ShuffleFetcher {
* @return An iterator over the elements of the fetched shuffle outputs.
*/
def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics,
serializer: Serializer = Serializer.default): Iterator[(K,V)]
serializer: Serializer = SparkEnv.get.serializerManager.default): Iterator[(K,V)]
/** Stop the fetcher */
def stop() {}
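
A hedged call-site sketch of what this default buys: when the serializer argument is omitted, fetch resolves it through the current SparkEnv's SerializerManager rather than the removed Serializer singleton. The shuffleId, reduceId, and metrics values below are placeholders, not part of this commit.

val fetcher = SparkEnv.get.shuffleFetcher

// Omitting the last argument falls back to SparkEnv.get.serializerManager.default.
val defaultIter = fetcher.fetch[Int, String](shuffleId, reduceId, metrics)

// A particular shuffle can still supply its own serializer explicitly.
val kryoIter = fetcher.fetch[Int, String](shuffleId, reduceId, metrics,
  SparkEnv.get.serializerManager.get("spark.KryoSerializer"))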

spark/SparkEnv.scala

@@ -7,7 +7,7 @@ import spark.broadcast.BroadcastManager
import spark.storage.BlockManager
import spark.storage.BlockManagerMaster
import spark.network.ConnectionManager
import spark.serializer.Serializer
import spark.serializer.{Serializer, SerializerManager}
import spark.util.AkkaUtils
@@ -21,6 +21,7 @@ import spark.util.AkkaUtils
class SparkEnv (
val executorId: String,
val actorSystem: ActorSystem,
val serializerManager: SerializerManager,
val serializer: Serializer,
val closureSerializer: Serializer,
val cacheManager: CacheManager,
@@ -92,10 +93,12 @@ object SparkEnv extends Logging {
Class.forName(name, true, classLoader).newInstance().asInstanceOf[T]
}
val serializer = Serializer.setDefault(
val serializerManager = new SerializerManager
val serializer = serializerManager.setDefault(
System.getProperty("spark.serializer", "spark.JavaSerializer"))
val closureSerializer = Serializer.get(
val closureSerializer = serializerManager.get(
System.getProperty("spark.closure.serializer", "spark.JavaSerializer"))
def registerOrLookup(name: String, newActor: => Actor): ActorRef = {
@@ -155,6 +158,7 @@ object SparkEnv extends Logging {
new SparkEnv(
executorId,
actorSystem,
serializerManager,
serializer,
closureSerializer,
cacheManager,
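
For orientation, the call-site change that the RDD and task files below repeat; a minimal sketch, assuming a serializerClass: String is in scope and a SparkEnv has been created for the current thread.

// Before this commit (singleton registry on the removed object Serializer):
//   val ser = Serializer.get(serializerClass)
// After this commit (registry held by the per-process SparkEnv):
val ser: Serializer = SparkEnv.get.serializerManager.get(serializerClass)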

spark/rdd/CoGroupedRDD.scala

@@ -8,7 +8,6 @@ import scala.collection.mutable.ArrayBuffer
import spark.{Aggregator, Logging, Partition, Partitioner, RDD, SparkEnv, TaskContext}
import spark.{Dependency, OneToOneDependency, ShuffleDependency}
import spark.serializer.Serializer
private[spark] sealed trait CoGroupSplitDep extends Serializable
@@ -114,7 +113,7 @@ class CoGroupedRDD[K](
}
}
val ser = Serializer.get(serializerClass)
val ser = SparkEnv.get.serializerManager.get(serializerClass)
for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
// Read them from the parent

spark/rdd/ShuffledRDD.scala

@@ -2,7 +2,6 @@ package spark.rdd
import spark.{Partitioner, RDD, SparkEnv, ShuffleDependency, Partition, TaskContext}
import spark.SparkContext._
import spark.serializer.Serializer
private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
@@ -32,7 +31,7 @@ class ShuffledRDD[K, V](
override def compute(split: Partition, context: TaskContext): Iterator[(K, V)] = {
val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
SparkEnv.get.shuffleFetcher.fetch[K, V](
shuffledId, split.index, context.taskMetrics, Serializer.get(serializerClass))
SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index, context.taskMetrics,
SparkEnv.get.serializerManager.get(serializerClass))
}
}

spark/rdd/SubtractedRDD.scala

@@ -11,7 +11,7 @@ import spark.Partition
import spark.SparkEnv
import spark.ShuffleDependency
import spark.OneToOneDependency
import spark.serializer.Serializer
/**
* An optimized version of cogroup for set difference/subtraction.
@@ -68,7 +68,7 @@ private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassM
override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = {
val partition = p.asInstanceOf[CoGroupPartition]
val serializer = Serializer.get(serializerClass)
val serializer = SparkEnv.get.serializerManager.get(serializerClass)
val map = new JHashMap[K, ArrayBuffer[V]]
def getSeq(k: K): ArrayBuffer[V] = {
val seq = map.get(k)

spark/scheduler/ShuffleMapTask.scala

@@ -14,7 +14,6 @@ import com.ning.compress.lzf.LZFOutputStream
import spark._
import spark.executor.ShuffleWriteMetrics
import spark.serializer.Serializer
import spark.storage._
import spark.util.{TimeStampedHashMap, MetadataCleaner}
@@ -139,12 +138,12 @@ private[spark] class ShuffleMapTask(
metrics = Some(taskContext.taskMetrics)
val blockManager = SparkEnv.get.blockManager
var shuffle: ShuffleBlockManager#Shuffle = null
var shuffle: ShuffleBlocks = null
var buckets: ShuffleWriterGroup = null
try {
// Obtain all the block writers for shuffle blocks.
val ser = Serializer.get(dep.serializerClass)
val ser = SparkEnv.get.serializerManager.get(dep.serializerClass)
shuffle = blockManager.shuffleBlockManager.forShuffle(dep.shuffleId, numOutputSplits, ser)
buckets = shuffle.acquireWriters(partition)

spark/serializer/Serializer.scala

@@ -2,7 +2,6 @@ package spark.serializer
import java.io.{EOFException, InputStream, OutputStream}
import java.nio.ByteBuffer
import java.util.concurrent.ConcurrentHashMap
import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
@@ -19,47 +18,6 @@ trait Serializer {
}
/**
* A singleton object that can be used to fetch serializer objects based on the serializer
* class name. If a previous instance of the serializer object has been created, the get
* method returns that instead of creating a new one.
*/
object Serializer {
private val serializers = new ConcurrentHashMap[String, Serializer]
private var _default: Serializer = _
def default = _default
def setDefault(clsName: String): Serializer = {
_default = get(clsName)
_default
}
def get(clsName: String): Serializer = {
if (clsName == null) {
default
} else {
var serializer = serializers.get(clsName)
if (serializer != null) {
// If the serializer has been created previously, reuse that.
serializer
} else this.synchronized {
// Otherwise, create a new one. But make sure no other thread has attempted
// to create another new one at the same time.
serializer = serializers.get(clsName)
if (serializer == null) {
val clsLoader = Thread.currentThread.getContextClassLoader
serializer = Class.forName(clsName, true, clsLoader).newInstance().asInstanceOf[Serializer]
serializers.put(clsName, serializer)
}
serializer
}
}
}
}
/**
* An instance of a serializer, for use by one thread at a time.
*/

spark/serializer/SerializerManager.scala (new file)

@@ -0,0 +1,45 @@
package spark.serializer
import java.util.concurrent.ConcurrentHashMap
/**
* A service that returns a serializer object given the serializer's class name. If a previous
* instance of the serializer object has been created, the get method returns that instead of
* creating a new one.
*/
private[spark] class SerializerManager {
private val serializers = new ConcurrentHashMap[String, Serializer]
private var _default: Serializer = _
def default = _default
def setDefault(clsName: String): Serializer = {
_default = get(clsName)
_default
}
def get(clsName: String): Serializer = {
if (clsName == null) {
default
} else {
var serializer = serializers.get(clsName)
if (serializer != null) {
// If the serializer has been created previously, reuse that.
serializer
} else this.synchronized {
// Otherwise, create a new one. But make sure no other thread has attempted
// to create another new one at the same time.
serializer = serializers.get(clsName)
if (serializer == null) {
val clsLoader = Thread.currentThread.getContextClassLoader
serializer =
Class.forName(clsName, true, clsLoader).newInstance().asInstanceOf[Serializer]
serializers.put(clsName, serializer)
}
serializer
}
}
}
}
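
A minimal usage sketch of the manager above, mirroring the SparkEnv wiring earlier in this commit. The spark.JavaSerializer and spark.KryoSerializer class names are the ones this codebase ships; everything else here is illustrative.

val serializerManager = new SerializerManager

// The default comes from spark.serializer, falling back to Java serialization.
val default = serializerManager.setDefault(
  System.getProperty("spark.serializer", "spark.JavaSerializer"))

// get(null) returns the default, and repeated lookups reuse the cached instance.
assert(serializerManager.get(null) eq default)
assert(serializerManager.get("spark.KryoSerializer") eq
  serializerManager.get("spark.KryoSerializer"))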

spark/storage/DiskStore.scala

@@ -2,6 +2,7 @@ package spark.storage
import java.io.{File, FileOutputStream, OutputStream, RandomAccessFile}
import java.nio.ByteBuffer
import java.nio.channels.FileChannel
import java.nio.channels.FileChannel.MapMode
import java.util.{Random, Date}
import java.text.SimpleDateFormat
@@ -26,14 +27,16 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
private val f: File = createFile(blockId /*, allowAppendExisting */)
private var repositionableStream: FastBufferedOutputStream = null
// The file channel, used for repositioning / truncating the file.
private var channel: FileChannel = null
private var bs: OutputStream = null
private var objOut: SerializationStream = null
private var validLength = 0L
private var lastValidPosition = 0L
override def open(): DiskBlockObjectWriter = {
repositionableStream = new FastBufferedOutputStream(new FileOutputStream(f))
bs = blockManager.wrapForCompression(blockId, repositionableStream)
val fos = new FileOutputStream(f, true)
channel = fos.getChannel()
bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos))
objOut = serializer.newInstance().serializeStream(bs)
this
}
@@ -41,9 +44,9 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
override def close() {
objOut.close()
bs.close()
objOut = null
channel = null
bs = null
repositionableStream = null
objOut = null
// Invoke the close callback handler.
super.close()
}
@@ -54,25 +57,23 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
// Return the number of bytes written for this commit.
override def commit(): Long = {
bs.flush()
validLength = repositionableStream.position()
validLength
val prevPos = lastValidPosition
lastValidPosition = channel.position()
lastValidPosition - prevPos
}
override def revertPartialWrites() {
// Flush the outstanding writes and delete the file.
objOut.close()
bs.close()
objOut = null
bs = null
repositionableStream = null
f.delete()
// Discard current writes. We do this by flushing the outstanding writes and
// truncate the file to the last valid position.
bs.flush()
channel.truncate(lastValidPosition)
}
override def write(value: Any) {
objOut.writeObject(value)
}
override def size(): Long = validLength
override def size(): Long = lastValidPosition
}
val MAX_DIR_CREATION_ATTEMPTS: Int = 10
@@ -86,7 +87,8 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
addShutdownHook()
def getBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int): BlockObjectWriter = {
def getBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int)
: BlockObjectWriter = {
new DiskBlockObjectWriter(blockId, serializer, bufferSize)
}
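
To show the commit/revert scheme above in isolation: the writer remembers the channel position at the last successful commit and truncates back to it to discard partial writes. This is a simplified, hypothetical sketch; the real DiskBlockObjectWriter layers compression and a serialization stream on top, which are omitted here.

import java.io.{File, FileOutputStream, OutputStream}
import java.nio.channels.FileChannel

class TruncatingWriter(file: File) {
  // Append mode, as in open() above; the channel starts positioned at the end of the file.
  private val fos = new FileOutputStream(file, true)
  private val channel: FileChannel = fos.getChannel()
  private val out: OutputStream = fos
  private var lastValidPosition: Long = channel.position()

  def write(bytes: Array[Byte]) { out.write(bytes) }

  // Flush, record the new high-water mark, and return the bytes written by this commit.
  def commit(): Long = {
    out.flush()
    val prevPos = lastValidPosition
    lastValidPosition = channel.position()
    lastValidPosition - prevPos
  }

  // Flush outstanding writes, then drop everything written since the last commit.
  def revertPartialWrites() {
    out.flush()
    channel.truncate(lastValidPosition)
  }

  def size(): Long = lastValidPosition

  def close() { out.close() }
}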

spark/storage/ShuffleBlockManager.scala

@@ -7,27 +7,31 @@ private[spark]
class ShuffleWriterGroup(val id: Int, val writers: Array[BlockObjectWriter])
private[spark]
trait ShuffleBlocks {
def acquireWriters(mapId: Int): ShuffleWriterGroup
def releaseWriters(group: ShuffleWriterGroup)
}
private[spark]
class ShuffleBlockManager(blockManager: BlockManager) {
def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer): Shuffle = {
new Shuffle(shuffleId, numBuckets, serializer)
}
class Shuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer) {
// Get a group of writers for a map task.
def acquireWriters(mapId: Int): ShuffleWriterGroup = {
val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId)
blockManager.getDiskBlockWriter(blockId, serializer, bufferSize).open()
def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer): ShuffleBlocks = {
new ShuffleBlocks {
// Get a group of writers for a map task.
override def acquireWriters(mapId: Int): ShuffleWriterGroup = {
val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId)
blockManager.getDiskBlockWriter(blockId, serializer, bufferSize).open()
}
new ShuffleWriterGroup(mapId, writers)
}
new ShuffleWriterGroup(mapId, writers)
}
def releaseWriters(group: ShuffleWriterGroup) = {
// Nothing really to release here.
override def releaseWriters(group: ShuffleWriterGroup) = {
// Nothing really to release here.
}
}
}
}
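
A hedged usage sketch of the ShuffleBlocks trait, following the ShuffleMapTask hunk earlier in this commit; blockManager, dep, numOutputSplits, partition, and the records iterator are assumed to be in scope.

val ser = SparkEnv.get.serializerManager.get(dep.serializerClass)
val shuffle: ShuffleBlocks =
  blockManager.shuffleBlockManager.forShuffle(dep.shuffleId, numOutputSplits, ser)
val buckets: ShuffleWriterGroup = shuffle.acquireWriters(partition)
try {
  for ((k, v) <- records) {
    val bucketId = dep.partitioner.getPartition(k)
    buckets.writers(bucketId).write((k, v))
  }
  buckets.writers.foreach(_.commit())
} finally {
  // Nothing to release with this implementation, but call it for symmetry.
  shuffle.releaseWriters(buckets)
}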