Merge pull request #57 from aarondav/bid

Refactor BlockId into an actual type

Converts all of our BlockId strings into actual BlockId types. Here are some advantages of doing this now:

+ Type safety
+ Code clarity - it's now obvious what the key of a shuffle or RDD block is, for instance. Additionally, having BlockId appear in tuple/map type signatures is a big readability bonus: a Seq[(String, BlockStatus)] is not very clear. Further, we can now use more Scala features, like matching on BlockId types (illustrated in the sketch below).
+ Explicit usage - we can now formally tell where various BlockIds are being used (without doing string searches); this makes updating current BlockIds a much clearer, compiler-supported process. (I'm looking at you, shuffle file consolidation.)
+ It will only get harder to make this change as time goes on.

The downside, of course, is that this is a very invasive change touching a lot of different files, which will inevitably lead to merge conflicts for many.
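As a concrete illustration of the pattern-matching point (an editor's sketch against the BlockId types added by this patch, not code from the diff; the demo package and object names are hypothetical, and the code has to live under the org.apache.spark package because BlockId is private[spark]):

    package org.apache.spark.demo  // hypothetical package, chosen only to satisfy private[spark]

    import org.apache.spark.storage.{BlockId, RDDBlockId, ShuffleBlockId}

    object BlockIdExample {
      // Before this patch callers regex-parsed strings like "shuffle_1_2_0";
      // now the parsed, typed key can be matched on directly.
      def describe(id: BlockId): String = id match {
        case RDDBlockId(rddId, splitIndex)       => "partition " + splitIndex + " of RDD " + rddId
        case ShuffleBlockId(shuffleId, mapId, _) => "shuffle " + shuffleId + " output of map " + mapId
        case other                               => "other block: " + other.name
      }

      def main(args: Array[String]) {
        // BlockId.apply parses the legacy string form back into a typed id.
        println(describe(BlockId("rdd_3_7")))        // partition 7 of RDD 3
        println(describe(BlockId("shuffle_1_2_0")))  // shuffle 1 output of map 2
      }
    }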
Reynold Xin 2013-10-14 14:20:01 -07:00
commit 3b11f43e36
44 changed files with 544 additions and 385 deletions

@@ -21,6 +21,7 @@ import io.netty.buffer.ByteBuf;
 import io.netty.channel.ChannelHandlerContext;
 import io.netty.channel.ChannelInboundByteHandlerAdapter;
+import org.apache.spark.storage.BlockId;
 abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter {
@@ -33,7 +34,7 @@ abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter {
 }
 public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header);
-public abstract void handleError(String blockId);
+public abstract void handleError(BlockId blockId);
 @Override
 public ByteBuf newInboundBuffer(ChannelHandlerContext ctx) {

@@ -24,6 +24,7 @@ import io.netty.channel.ChannelHandlerContext;
 import io.netty.channel.ChannelInboundMessageHandlerAdapter;
 import io.netty.channel.DefaultFileRegion;
+import org.apache.spark.storage.BlockId;
 class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> {
@@ -34,8 +35,9 @@ class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> {
 }
 @Override
-public void messageReceived(ChannelHandlerContext ctx, String blockId) {
-String path = pResolver.getAbsolutePath(blockId);
+public void messageReceived(ChannelHandlerContext ctx, String blockIdString) {
+BlockId blockId = BlockId.apply(blockIdString);
+String path = pResolver.getAbsolutePath(blockId.name());
 // if getFilePath returns null, close the channel
 if (path == null) {
 //ctx.close();

@@ -22,7 +22,7 @@ import scala.collection.mutable.HashMap
 import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics}
 import org.apache.spark.serializer.Serializer
-import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}
 import org.apache.spark.util.CompletionIterator
@@ -45,12 +45,12 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
 splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
 }
-val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])] = splitsByAddress.toSeq.map {
+val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map {
 case (address, splits) =>
-(address, splits.map(s => ("shuffle_%d_%d_%d".format(shuffleId, s._1, reduceId), s._2)))
+(address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
 }
-def unpackBlock(blockPair: (String, Option[Iterator[Any]])) : Iterator[T] = {
+def unpackBlock(blockPair: (BlockId, Option[Iterator[Any]])) : Iterator[T] = {
 val blockId = blockPair._1
 val blockOption = blockPair._2
 blockOption match {
@@ -58,9 +58,8 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
 block.asInstanceOf[Iterator[T]]
 }
 case None => {
-val regex = "shuffle_([0-9]*)_([0-9]*)_([0-9]*)".r
 blockId match {
-case regex(shufId, mapId, _) =>
+case ShuffleBlockId(shufId, mapId, _) =>
 val address = statuses(mapId.toInt)._1
 throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, null)
 case _ =>

@@ -18,7 +18,7 @@
 package org.apache.spark
 import scala.collection.mutable.{ArrayBuffer, HashSet}
-import org.apache.spark.storage.{BlockManager, StorageLevel}
+import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, RDDBlockId}
 import org.apache.spark.rdd.RDD
@@ -28,12 +28,12 @@ import org.apache.spark.rdd.RDD
 private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
 /** Keys of RDD splits that are being computed/loaded. */
-private val loading = new HashSet[String]()
+private val loading = new HashSet[RDDBlockId]()
 /** Gets or computes an RDD split. Used by RDD.iterator() when an RDD is cached. */
 def getOrCompute[T](rdd: RDD[T], split: Partition, context: TaskContext, storageLevel: StorageLevel)
 : Iterator[T] = {
-val key = "rdd_%d_%d".format(rdd.id, split.index)
+val key = RDDBlockId(rdd.id, split.index)
 logDebug("Looking for partition " + key)
 blockManager.get(key) match {
 case Some(values) =>
@@ -73,7 +73,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
 if (context.runningLocally) { return computedValues }
 val elements = new ArrayBuffer[Any]
 elements ++= computedValues
-blockManager.put(key, elements, storageLevel, true)
+blockManager.put(key, elements, storageLevel, tellMaster = true)
 return elements.iterator.asInstanceOf[Iterator[T]]
 } finally {
 loading.synchronized {

@@ -26,7 +26,7 @@ import scala.collection.mutable.{ListBuffer, Map, Set}
 import scala.math
 import org.apache.spark._
-import org.apache.spark.storage.{BlockManager, StorageLevel}
+import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
 import org.apache.spark.util.Utils
 private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
@@ -36,7 +36,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
 def value = value_
-def blockId: String = BlockManager.toBroadcastId(id)
+def blockId = BroadcastBlockId(id)
 MultiTracker.synchronized {
 SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)

@@ -25,16 +25,15 @@ import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
 import org.apache.spark.{HttpServer, Logging, SparkEnv}
 import org.apache.spark.io.CompressionCodec
-import org.apache.spark.storage.{BlockManager, StorageLevel}
-import org.apache.spark.util.{MetadataCleanerType, Utils, MetadataCleaner, TimeStampedHashSet}
+import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
+import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils}
 private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
 extends Broadcast[T](id) with Logging with Serializable {
 def value = value_
-def blockId: String = BlockManager.toBroadcastId(id)
+def blockId = BroadcastBlockId(id)
 HttpBroadcast.synchronized {
 SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
@@ -121,7 +120,7 @@ private object HttpBroadcast extends Logging {
 }
 def write(id: Long, value: Any) {
-val file = new File(broadcastDir, "broadcast-" + id)
+val file = new File(broadcastDir, BroadcastBlockId(id).name)
 val out: OutputStream = {
 if (compress) {
 compressionCodec.compressedOutputStream(new FileOutputStream(file))
@@ -137,7 +136,7 @@ private object HttpBroadcast extends Logging {
 }
 def read[T](id: Long): T = {
-val url = serverUri + "/broadcast-" + id
+val url = serverUri + "/" + BroadcastBlockId(id).name
 val in = {
 if (compress) {
 compressionCodec.compressedInputStream(new URL(url).openStream())

@@ -19,13 +19,11 @@ package org.apache.spark.broadcast
 import java.io._
 import java.net._
-import java.util.{Comparator, Random, UUID}
-import scala.collection.mutable.{ListBuffer, Map, Set}
-import scala.math
+import scala.collection.mutable.{ListBuffer, Set}
 import org.apache.spark._
-import org.apache.spark.storage.{BlockManager, StorageLevel}
+import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
 import org.apache.spark.util.Utils
 private[spark] class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
@@ -33,7 +31,7 @@ extends Broadcast[T](id) with Logging with Serializable {
 def value = value_
-def blockId = BlockManager.toBroadcastId(id)
+def blockId = BroadcastBlockId(id)
 MultiTracker.synchronized {
 SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)

@@ -27,7 +27,7 @@ import scala.collection.mutable.HashMap
 import org.apache.spark.scheduler._
 import org.apache.spark._
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
 import org.apache.spark.util.Utils
 /**
@@ -173,7 +173,7 @@ private[spark] class Executor(
 val serializedResult = {
 if (serializedDirectResult.limit >= akkaFrameSize - 1024) {
 logInfo("Storing result for " + taskId + " in local BlockManager")
-val blockId = "taskresult_" + taskId
+val blockId = TaskResultBlockId(taskId)
 env.blockManager.putBytes(
 blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)
 ser.serialize(new IndirectTaskResult[Any](blockId))

@@ -20,17 +20,18 @@ package org.apache.spark.network.netty
 import io.netty.buffer._
 import org.apache.spark.Logging
+import org.apache.spark.storage.{TestBlockId, BlockId}
 private[spark] class FileHeader (
 val fileLen: Int,
-val blockId: String) extends Logging {
+val blockId: BlockId) extends Logging {
 lazy val buffer = {
 val buf = Unpooled.buffer()
 buf.capacity(FileHeader.HEADER_SIZE)
 buf.writeInt(fileLen)
-buf.writeInt(blockId.length)
-blockId.foreach((x: Char) => buf.writeByte(x))
+buf.writeInt(blockId.name.length)
+blockId.name.foreach((x: Char) => buf.writeByte(x))
 //padding the rest of header
 if (FileHeader.HEADER_SIZE - buf.readableBytes > 0 ) {
 buf.writeZero(FileHeader.HEADER_SIZE - buf.readableBytes)
@@ -57,18 +58,15 @@ private[spark] object FileHeader {
 for (i <- 1 to idLength) {
 idBuilder += buf.readByte().asInstanceOf[Char]
 }
-val blockId = idBuilder.toString()
+val blockId = BlockId(idBuilder.toString())
 new FileHeader(length, blockId)
 }
-def main (args:Array[String]){
-val header = new FileHeader(25,"block_0");
-val buf = header.buffer;
-val newheader = FileHeader.create(buf);
-System.out.println("id="+newheader.blockId+",size="+newheader.fileLen)
+def main (args:Array[String]) {
+val header = new FileHeader(25, TestBlockId("my_block"))
+val buf = header.buffer
+val newHeader = FileHeader.create(buf)
+System.out.println("id=" + newHeader.blockId + ",size=" + newHeader.fileLen)
 }
 }

@@ -27,12 +27,13 @@ import org.apache.spark.Logging
 import org.apache.spark.network.ConnectionManagerId
 import scala.collection.JavaConverters._
+import org.apache.spark.storage.BlockId
 private[spark] class ShuffleCopier extends Logging {
-def getBlock(host: String, port: Int, blockId: String,
-resultCollectCallback: (String, Long, ByteBuf) => Unit) {
+def getBlock(host: String, port: Int, blockId: BlockId,
+resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) {
 val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback)
 val connectTimeout = System.getProperty("spark.shuffle.netty.connect.timeout", "60000").toInt
@@ -41,7 +42,7 @@ private[spark] class ShuffleCopier extends Logging {
 try {
 fc.init()
 fc.connect(host, port)
-fc.sendRequest(blockId)
+fc.sendRequest(blockId.name)
 fc.waitForClose()
 fc.close()
 } catch {
@@ -53,14 +54,14 @@ private[spark] class ShuffleCopier extends Logging {
 }
 }
-def getBlock(cmId: ConnectionManagerId, blockId: String,
-resultCollectCallback: (String, Long, ByteBuf) => Unit) {
+def getBlock(cmId: ConnectionManagerId, blockId: BlockId,
+resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) {
 getBlock(cmId.host, cmId.port, blockId, resultCollectCallback)
 }
 def getBlocks(cmId: ConnectionManagerId,
-blocks: Seq[(String, Long)],
-resultCollectCallback: (String, Long, ByteBuf) => Unit) {
+blocks: Seq[(BlockId, Long)],
+resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) {
 for ((blockId, size) <- blocks) {
 getBlock(cmId, blockId, resultCollectCallback)
@@ -71,7 +72,7 @@ private[spark] class ShuffleCopier extends Logging {
 private[spark] object ShuffleCopier extends Logging {
-private class ShuffleClientHandler(resultCollectCallBack: (String, Long, ByteBuf) => Unit)
+private class ShuffleClientHandler(resultCollectCallBack: (BlockId, Long, ByteBuf) => Unit)
 extends FileClientHandler with Logging {
 override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) {
@@ -79,14 +80,14 @@ private[spark] object ShuffleCopier extends Logging {
 resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen))
 }
-override def handleError(blockId: String) {
+override def handleError(blockId: BlockId) {
 if (!isComplete) {
 resultCollectCallBack(blockId, -1, null)
 }
 }
 }
-def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) {
+def echoResultCollectCallBack(blockId: BlockId, size: Long, content: ByteBuf) {
 if (size != -1) {
 logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"")
 }
@@ -99,7 +100,7 @@ private[spark] object ShuffleCopier extends Logging {
 }
 val host = args(0)
 val port = args(1).toInt
-val file = args(2)
+val blockId = BlockId(args(2))
 val threads = if (args.length > 3) args(3).toInt else 10
 val copiers = Executors.newFixedThreadPool(80)
@@ -107,12 +108,12 @@ private[spark] object ShuffleCopier extends Logging {
 Executors.callable(new Runnable() {
 def run() {
 val copier = new ShuffleCopier()
-copier.getBlock(host, port, file, echoResultCollectCallBack)
+copier.getBlock(host, port, blockId, echoResultCollectCallBack)
 }
 })
 }).asJava
 copiers.invokeAll(tasks)
-copiers.shutdown
+copiers.shutdown()
 System.exit(0)
 }
 }

@@ -21,7 +21,7 @@ import java.io.File
 import org.apache.spark.Logging
 import org.apache.spark.util.Utils
-import org.apache.spark.storage.ShuffleBlockManager
+import org.apache.spark.storage.BlockId
 private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) extends Logging {
@@ -54,8 +54,9 @@ private[spark] object ShuffleSender {
 val localDirs = args.drop(2).map(new File(_))
 val pResovler = new PathResolver {
-override def getAbsolutePath(blockId: String): String = {
-if (!ShuffleBlockManager.isShuffle(blockId)) {
+override def getAbsolutePath(blockIdString: String): String = {
+val blockId = BlockId(blockIdString)
+if (!blockId.isShuffle) {
 throw new Exception("Block " + blockId + " is not a shuffle block")
 }
 // Figure out which local directory it hashes to, and which subdirectory in that
@@ -63,7 +64,7 @@ private[spark] object ShuffleSender {
 val dirId = hash % localDirs.length
 val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
 val subDir = new File(localDirs(dirId), "%02x".format(subDirId))
-val file = new File(subDir, blockId)
+val file = new File(subDir, blockId.name)
 return file.getAbsolutePath
 }
 }

@@ -18,14 +18,14 @@
 package org.apache.spark.rdd
 import org.apache.spark.{SparkContext, SparkEnv, Partition, TaskContext}
-import org.apache.spark.storage.BlockManager
+import org.apache.spark.storage.{BlockId, BlockManager}
-private[spark] class BlockRDDPartition(val blockId: String, idx: Int) extends Partition {
+private[spark] class BlockRDDPartition(val blockId: BlockId, idx: Int) extends Partition {
 val index = idx
 }
 private[spark]
-class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String])
+class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[BlockId])
 extends RDD[T](sc, Nil) {
 @transient lazy val locations_ = BlockManager.blockIdsToHosts(blockIds, SparkEnv.get)

@@ -28,8 +28,8 @@ import org.apache.spark._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
-import org.apache.spark.storage.{BlockManager, BlockManagerMaster}
-import org.apache.spark.util.{MetadataCleanerType, MetadataCleaner, TimeStampedHashMap}
+import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerMaster, RDDBlockId}
+import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
 /**
 * The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of
@@ -156,7 +156,7 @@ class DAGScheduler(
 private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = {
 if (!cacheLocs.contains(rdd.id)) {
-val blockIds = rdd.partitions.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray
+val blockIds = rdd.partitions.indices.map(index=> RDDBlockId(rdd.id, index)).toArray[BlockId]
 val locs = BlockManager.blockIdsToBlockManagers(blockIds, env, blockManagerMaster)
 cacheLocs(rdd.id) = blockIds.map { id =>
 locs.getOrElse(id, Nil).map(bm => TaskLocation(bm.host, bm.executorId))

@@ -24,13 +24,14 @@ import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.{SparkEnv}
 import java.nio.ByteBuffer
 import org.apache.spark.util.Utils
+import org.apache.spark.storage.BlockId
 // Task result. Also contains updates to accumulator variables.
 private[spark] sealed trait TaskResult[T]
 /** A reference to a DirectTaskResult that has been stored in the worker's BlockManager. */
 private[spark]
-case class IndirectTaskResult[T](val blockId: String) extends TaskResult[T] with Serializable
+case class IndirectTaskResult[T](blockId: BlockId) extends TaskResult[T] with Serializable
 /** A TaskResult that contains the task's return value and accumulator updates. */
 private[spark]

@@ -26,9 +26,8 @@ import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput}
 import com.twitter.chill.{EmptyScalaKryoInstantiator, AllScalaRegistrar}
 import org.apache.spark.{SerializableWritable, Logging}
-import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock, StorageLevel}
 import org.apache.spark.broadcast.HttpBroadcast
+import org.apache.spark.storage.{GetBlock,GotBlock, PutBlock, StorageLevel, TestBlockId}
 /**
 * A Spark serializer that uses the [[http://code.google.com/p/kryo/wiki/V1Documentation Kryo 1.x library]].
@@ -43,13 +42,14 @@ class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging
 val kryo = instantiator.newKryo()
 val classLoader = Thread.currentThread.getContextClassLoader
+val blockId = TestBlockId("1")
 // Register some commonly used classes
 val toRegister: Seq[AnyRef] = Seq(
 ByteBuffer.allocate(1),
 StorageLevel.MEMORY_ONLY,
-PutBlock("1", ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY),
-GotBlock("1", ByteBuffer.allocate(1)),
-GetBlock("1"),
+PutBlock(blockId, ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY),
+GotBlock(blockId, ByteBuffer.allocate(1)),
+GetBlock(blockId),
 1 to 10,
 1 until 10,
 1L to 10L,

@@ -18,5 +18,5 @@
 package org.apache.spark.storage
 private[spark]
-case class BlockException(blockId: String, message: String) extends Exception(message)
+case class BlockException(blockId: BlockId, message: String) extends Exception(message)

@@ -47,7 +47,7 @@ import org.apache.spark.util.Utils
 */
 private[storage]
-trait BlockFetcherIterator extends Iterator[(String, Option[Iterator[Any]])]
+trait BlockFetcherIterator extends Iterator[(BlockId, Option[Iterator[Any]])]
 with Logging with BlockFetchTracker {
 def initialize()
 }
@@ -57,20 +57,20 @@ private[storage]
 object BlockFetcherIterator {
 // A request to fetch one or more blocks, complete with their sizes
-class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) {
+class FetchRequest(val address: BlockManagerId, val blocks: Seq[(BlockId, Long)]) {
 val size = blocks.map(_._2).sum
 }
 // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize
 // the block (since we want all deserializaton to happen in the calling thread); can also
 // represent a fetch failure if size == -1.
-class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) {
+class FetchResult(val blockId: BlockId, val size: Long, val deserialize: () => Iterator[Any]) {
 def failed: Boolean = size == -1
 }
 class BasicBlockFetcherIterator(
 private val blockManager: BlockManager,
-val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])],
+val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
 serializer: Serializer)
 extends BlockFetcherIterator {
@@ -92,12 +92,12 @@ object BlockFetcherIterator {
 // This represents the number of local blocks, also counting zero-sized blocks
 private var numLocal = 0
 // BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks
-protected val localBlocksToFetch = new ArrayBuffer[String]()
+protected val localBlocksToFetch = new ArrayBuffer[BlockId]()
 // This represents the number of remote blocks, also counting zero-sized blocks
 private var numRemote = 0
 // BlockIds for remote blocks that need to be fetched. Excludes zero-sized blocks
-protected val remoteBlocksToFetch = new HashSet[String]()
+protected val remoteBlocksToFetch = new HashSet[BlockId]()
 // A queue to hold our results.
 protected val results = new LinkedBlockingQueue[FetchResult]
@@ -167,7 +167,7 @@ object BlockFetcherIterator {
 logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize)
 val iterator = blockInfos.iterator
 var curRequestSize = 0L
-var curBlocks = new ArrayBuffer[(String, Long)]
+var curBlocks = new ArrayBuffer[(BlockId, Long)]
 while (iterator.hasNext) {
 val (blockId, size) = iterator.next()
 // Skip empty blocks
@@ -183,7 +183,7 @@ object BlockFetcherIterator {
 // Add this FetchRequest
 remoteRequests += new FetchRequest(address, curBlocks)
 curRequestSize = 0
-curBlocks = new ArrayBuffer[(String, Long)]
+curBlocks = new ArrayBuffer[(BlockId, Long)]
 }
 }
 // Add in the final request
@@ -241,7 +241,7 @@ object BlockFetcherIterator {
 override def hasNext: Boolean = resultsGotten < _numBlocksToFetch
-override def next(): (String, Option[Iterator[Any]]) = {
+override def next(): (BlockId, Option[Iterator[Any]]) = {
 resultsGotten += 1
 val startFetchWait = System.currentTimeMillis()
 val result = results.take()
@@ -267,7 +267,7 @@ object BlockFetcherIterator {
 class NettyBlockFetcherIterator(
 blockManager: BlockManager,
-blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])],
+blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
 serializer: Serializer)
 extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer) {
@@ -303,7 +303,7 @@ object BlockFetcherIterator {
 override protected def sendRequest(req: FetchRequest) {
-def putResult(blockId: String, blockSize: Long, blockData: ByteBuf) {
+def putResult(blockId: BlockId, blockSize: Long, blockData: ByteBuf) {
 val fetchResult = new FetchResult(blockId, blockSize,
 () => dataDeserialize(blockId, blockData.nioBuffer, serializer))
 results.put(fetchResult)
@@ -337,7 +337,7 @@ object BlockFetcherIterator {
 logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
 }
-override def next(): (String, Option[Iterator[Any]]) = {
+override def next(): (BlockId, Option[Iterator[Any]]) = {
 resultsGotten += 1
 val result = results.take()
 // If all the results has been retrieved, copiers will exit automatically

@@ -0,0 +1,96 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.storage
/**
* Identifies a particular Block of data, usually associated with a single file.
* A Block can be uniquely identified by its filename, but each type of Block has a different
* set of keys which produce its unique name.
*
* If your BlockId should be serializable, be sure to add it to the BlockId.fromString() method.
*/
private[spark] sealed abstract class BlockId {
/** A globally unique identifier for this Block. Can be used for ser/de. */
def name: String
// convenience methods
def asRDDId = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None
def isRDD = isInstanceOf[RDDBlockId]
def isShuffle = isInstanceOf[ShuffleBlockId]
def isBroadcast = isInstanceOf[BroadcastBlockId]
override def toString = name
override def hashCode = name.hashCode
override def equals(other: Any): Boolean = other match {
case o: BlockId => getClass == o.getClass && name.equals(o.name)
case _ => false
}
}
private[spark] case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId {
def name = "rdd_" + rddId + "_" + splitIndex
}
private[spark]
case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
}
private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId {
def name = "broadcast_" + broadcastId
}
private[spark] case class TaskResultBlockId(taskId: Long) extends BlockId {
def name = "taskresult_" + taskId
}
private[spark] case class StreamBlockId(streamId: Int, uniqueId: Long) extends BlockId {
def name = "input-" + streamId + "-" + uniqueId
}
// Intended only for testing purposes
private[spark] case class TestBlockId(id: String) extends BlockId {
def name = "test_" + id
}
private[spark] object BlockId {
val RDD = "rdd_([0-9]+)_([0-9]+)".r
val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
val BROADCAST = "broadcast_([0-9]+)".r
val TASKRESULT = "taskresult_([0-9]+)".r
val STREAM = "input-([0-9]+)-([0-9]+)".r
val TEST = "test_(.*)".r
/** Converts a BlockId "name" String back into a BlockId. */
def apply(id: String) = id match {
case RDD(rddId, splitIndex) =>
RDDBlockId(rddId.toInt, splitIndex.toInt)
case SHUFFLE(shuffleId, mapId, reduceId) =>
ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
case BROADCAST(broadcastId) =>
BroadcastBlockId(broadcastId.toLong)
case TASKRESULT(taskId) =>
TaskResultBlockId(taskId.toLong)
case STREAM(streamId, uniqueId) =>
StreamBlockId(streamId.toInt, uniqueId.toLong)
case TEST(value) =>
TestBlockId(value)
case _ =>
throw new IllegalStateException("Unrecognized BlockId: " + id)
}
}
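For reference, a quick illustration of the name/apply round trip defined above (an editor's sketch, not part of the patch; the helper object is hypothetical and assumes it lives in the same org.apache.spark.storage package, since these types are private[spark]):

    private[spark] object BlockIdRoundTripExample {
      def check() {
        val shuffle = ShuffleBlockId(2, 5, 0)
        assert(shuffle.name == "shuffle_2_5_0")   // name is the canonical string form
        assert(BlockId(shuffle.name) == shuffle)  // apply() parses it back into the same id

        // Typed accessors replace string prefix checks like startsWith("rdd_").
        val rdd: BlockId = BlockId("rdd_3_7")
        assert(rdd.isRDD && rdd.asRDDId.exists(_.rddId == 3))
      }
    }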

@@ -37,7 +37,6 @@ import org.apache.spark.util._
 import sun.nio.ch.DirectBuffer
 private[spark] class BlockManager(
 executorId: String,
 actorSystem: ActorSystem,
@@ -103,7 +102,7 @@ private[spark] class BlockManager(
 val shuffleBlockManager = new ShuffleBlockManager(this)
-private val blockInfo = new TimeStampedHashMap[String, BlockInfo]
+private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]
 private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory)
 private[storage] val diskStore: DiskStore =
@@ -249,7 +248,7 @@ private[spark] class BlockManager(
 /**
 * Get storage level of local block. If no info exists for the block, then returns null.
 */
-def getLevel(blockId: String): StorageLevel = blockInfo.get(blockId).map(_.level).orNull
+def getLevel(blockId: BlockId): StorageLevel = blockInfo.get(blockId).map(_.level).orNull
 /**
 * Tell the master about the current storage status of a block. This will send a block update
@@ -259,7 +258,7 @@ private[spark] class BlockManager(
 * droppedMemorySize exists to account for when block is dropped from memory to disk (so it is still valid).
 * This ensures that update in master will compensate for the increase in memory on slave.
 */
-def reportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L) {
+def reportBlockStatus(blockId: BlockId, info: BlockInfo, droppedMemorySize: Long = 0L) {
 val needReregister = !tryToReportBlockStatus(blockId, info, droppedMemorySize)
 if (needReregister) {
 logInfo("Got told to reregister updating block " + blockId)
@@ -274,7 +273,7 @@ private[spark] class BlockManager(
 * which will be true if the block was successfully recorded and false if
 * the slave needs to re-register.
 */
-private def tryToReportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L): Boolean = {
+private def tryToReportBlockStatus(blockId: BlockId, info: BlockInfo, droppedMemorySize: Long = 0L): Boolean = {
 val (curLevel, inMemSize, onDiskSize, tellMaster) = info.synchronized {
 info.level match {
 case null =>
@@ -299,7 +298,7 @@ private[spark] class BlockManager(
 /**
 * Get locations of an array of blocks.
 */
-def getLocationBlockIds(blockIds: Array[String]): Array[Seq[BlockManagerId]] = {
+def getLocationBlockIds(blockIds: Array[BlockId]): Array[Seq[BlockManagerId]] = {
 val startTimeMs = System.currentTimeMillis
 val locations = master.getLocations(blockIds).toArray
 logDebug("Got multiple block location in " + Utils.getUsedTimeMs(startTimeMs))
@@ -311,7 +310,7 @@ private[spark] class BlockManager(
 * shuffle blocks. It is safe to do so without a lock on block info since disk store
 * never deletes (recent) items.
 */
-def getLocalFromDisk(blockId: String, serializer: Serializer): Option[Iterator[Any]] = {
+def getLocalFromDisk(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = {
 diskStore.getValues(blockId, serializer).orElse(
 sys.error("Block " + blockId + " not found on disk, though it should be"))
 }
@@ -319,7 +318,7 @@ private[spark] class BlockManager(
 /**
 * Get block from local block manager.
 */
-def getLocal(blockId: String): Option[Iterator[Any]] = {
+def getLocal(blockId: BlockId): Option[Iterator[Any]] = {
 logDebug("Getting local block " + blockId)
 val info = blockInfo.get(blockId).orNull
 if (info != null) {
@@ -400,13 +399,13 @@ private[spark] class BlockManager(
 /**
 * Get block from the local block manager as serialized bytes.
 */
-def getLocalBytes(blockId: String): Option[ByteBuffer] = {
+def getLocalBytes(blockId: BlockId): Option[ByteBuffer] = {
 // TODO: This whole thing is very similar to getLocal; we need to refactor it somehow
 logDebug("Getting local block " + blockId + " as bytes")
 // As an optimization for map output fetches, if the block is for a shuffle, return it
 // without acquiring a lock; the disk store never deletes (recent) items so this should work
-if (ShuffleBlockManager.isShuffle(blockId)) {
+if (blockId.isShuffle) {
 return diskStore.getBytes(blockId) match {
 case Some(bytes) =>
 Some(bytes)
@@ -473,7 +472,7 @@ private[spark] class BlockManager(
 /**
 * Get block from remote block managers.
 */
-def getRemote(blockId: String): Option[Iterator[Any]] = {
+def getRemote(blockId: BlockId): Option[Iterator[Any]] = {
 if (blockId == null) {
 throw new IllegalArgumentException("Block Id is null")
 }
@@ -498,7 +497,7 @@ private[spark] class BlockManager(
 /**
 * Get block from remote block managers as serialized bytes.
 */
-def getRemoteBytes(blockId: String): Option[ByteBuffer] = {
+def getRemoteBytes(blockId: BlockId): Option[ByteBuffer] = {
 // TODO: As with getLocalBytes, this is very similar to getRemote and perhaps should be
 // refactored.
 if (blockId == null) {
@@ -523,7 +522,7 @@ private[spark] class BlockManager(
 /**
 * Get a block from the block manager (either local or remote).
 */
-def get(blockId: String): Option[Iterator[Any]] = {
+def get(blockId: BlockId): Option[Iterator[Any]] = {
 val local = getLocal(blockId)
 if (local.isDefined) {
 logInfo("Found block %s locally".format(blockId))
@@ -544,7 +543,7 @@ private[spark] class BlockManager(
 * so that we can control the maxMegabytesInFlight for the fetch.
 */
 def getMultiple(
-blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], serializer: Serializer)
+blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], serializer: Serializer)
 : BlockFetcherIterator = {
 val iter =
@@ -558,7 +557,7 @@ private[spark] class BlockManager(
 iter
 }
-def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean)
+def put(blockId: BlockId, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean)
 : Long = {
 val elements = new ArrayBuffer[Any]
 elements ++= values
@@ -570,7 +569,7 @@ private[spark] class BlockManager(
 * This is currently used for writing shuffle files out. Callers should handle error
 * cases.
 */
-def getDiskBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int)
+def getDiskBlockWriter(blockId: BlockId, serializer: Serializer, bufferSize: Int)
 : BlockObjectWriter = {
 val writer = diskStore.getBlockWriter(blockId, serializer, bufferSize)
 writer.registerCloseEventHandler(() => {
@@ -584,7 +583,7 @@ private[spark] class BlockManager(
 /**
 * Put a new block of values to the block manager. Returns its (estimated) size in bytes.
 */
-def put(blockId: String, values: ArrayBuffer[Any], level: StorageLevel,
+def put(blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel,
 tellMaster: Boolean = true) : Long = {
 if (blockId == null) {
@@ -704,7 +703,7 @@ private[spark] class BlockManager(
 * Put a new block of serialized bytes to the block manager.
 */
 def putBytes(
-blockId: String, bytes: ByteBuffer, level: StorageLevel, tellMaster: Boolean = true) {
+blockId: BlockId, bytes: ByteBuffer, level: StorageLevel, tellMaster: Boolean = true) {
 if (blockId == null) {
 throw new IllegalArgumentException("Block Id is null")
@@ -805,7 +804,7 @@ private[spark] class BlockManager(
 * Replicate block to another node.
 */
 var cachedPeers: Seq[BlockManagerId] = null
-private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) {
+private def replicate(blockId: BlockId, data: ByteBuffer, level: StorageLevel) {
 val tLevel = StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1)
 if (cachedPeers == null) {
 cachedPeers = master.getPeers(blockManagerId, level.replication - 1)
@@ -828,14 +827,14 @@ private[spark] class BlockManager(
 /**
 * Read a block consisting of a single object.
 */
-def getSingle(blockId: String): Option[Any] = {
+def getSingle(blockId: BlockId): Option[Any] = {
 get(blockId).map(_.next())
 }
 /**
 * Write a block consisting of a single object.
 */
-def putSingle(blockId: String, value: Any, level: StorageLevel, tellMaster: Boolean = true) {
+def putSingle(blockId: BlockId, value: Any, level: StorageLevel, tellMaster: Boolean = true) {
 put(blockId, Iterator(value), level, tellMaster)
 }
@@ -843,7 +842,7 @@ private[spark] class BlockManager(
 * Drop a block from memory, possibly putting it on disk if applicable. Called when the memory
 * store reaches its limit and needs to free up space.
 */
-def dropFromMemory(blockId: String, data: Either[ArrayBuffer[Any], ByteBuffer]) {
+def dropFromMemory(blockId: BlockId, data: Either[ArrayBuffer[Any], ByteBuffer]) {
 logInfo("Dropping block " + blockId + " from memory")
 val info = blockInfo.get(blockId).orNull
 if (info != null) {
@@ -892,16 +891,15 @@ private[spark] class BlockManager(
 // TODO: Instead of doing a linear scan on the blockInfo map, create another map that maps
 // from RDD.id to blocks.
 logInfo("Removing RDD " + rddId)
-val rddPrefix = "rdd_" + rddId + "_"
-val blocksToRemove = blockInfo.filter(_._1.startsWith(rddPrefix)).map(_._1)
-blocksToRemove.foreach(blockId => removeBlock(blockId, false))
+val blocksToRemove = blockInfo.keys.flatMap(_.asRDDId).filter(_.rddId == rddId)
+blocksToRemove.foreach(blockId => removeBlock(blockId, tellMaster = false))
 blocksToRemove.size
 }
 /**
 * Remove a block from both memory and disk.
 */
-def removeBlock(blockId: String, tellMaster: Boolean = true) {
+def removeBlock(blockId: BlockId, tellMaster: Boolean = true) {
 logInfo("Removing block " + blockId)
 val info = blockInfo.get(blockId).orNull
 if (info != null) info.synchronized {
@@ -928,7 +926,7 @@ private[spark] class BlockManager(
 while (iterator.hasNext) {
 val entry = iterator.next()
 val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2)
-if (time < cleanupTime && ! BlockManager.isBroadcastBlock(id) ) {
+if (time < cleanupTime && !id.isBroadcast) {
 info.synchronized {
 val level = info.level
 if (level.useMemory) {
@@ -951,7 +949,7 @@ private[spark] class BlockManager(
 while (iterator.hasNext) {
 val entry = iterator.next()
 val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2)
-if (time < cleanupTime && BlockManager.isBroadcastBlock(id) ) {
+if (time < cleanupTime && id.isBroadcast) {
 info.synchronized {
 val level = info.level
 if (level.useMemory) {
@@ -968,34 +966,29 @@ private[spark] class BlockManager(
 }
 }
-def shouldCompress(blockId: String): Boolean = {
-if (ShuffleBlockManager.isShuffle(blockId)) {
-compressShuffle
-} else if (BlockManager.isBroadcastBlock(blockId)) {
-compressBroadcast
-} else if (blockId.startsWith("rdd_")) {
-compressRdds
-} else {
-false // Won't happen in a real cluster, but it can in tests
-}
+def shouldCompress(blockId: BlockId): Boolean = blockId match {
+case ShuffleBlockId(_, _, _) => compressShuffle
+case BroadcastBlockId(_) => compressBroadcast
+case RDDBlockId(_, _) => compressRdds
+case _ => false
 }
 /**
 * Wrap an output stream for compression if block compression is enabled for its block type
 */
-def wrapForCompression(blockId: String, s: OutputStream): OutputStream = {
+def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = {
 if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s
 }
 /**
 * Wrap an input stream for compression if block compression is enabled for its block type
 */
-def wrapForCompression(blockId: String, s: InputStream): InputStream = {
+def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = {
 if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s
 }
 def dataSerialize(
-blockId: String,
+blockId: BlockId,
 values: Iterator[Any],
 serializer: Serializer = defaultSerializer): ByteBuffer = {
 val byteStream = new FastByteArrayOutputStream(4096)
@@ -1010,7 +1003,7 @@ private[spark] class BlockManager(
 * the iterator is reached.
 */
 def dataDeserialize(
-blockId: String,
+blockId: BlockId,
 bytes: ByteBuffer,
 serializer: Serializer = defaultSerializer): Iterator[Any] = {
 bytes.rewind()
@@ -1065,10 +1058,10 @@ private[spark] object BlockManager extends Logging {
 }
 def blockIdsToBlockManagers(
-blockIds: Array[String],
+blockIds: Array[BlockId],
 env: SparkEnv,
 blockManagerMaster: BlockManagerMaster = null)
-: Map[String, Seq[BlockManagerId]] =
+: Map[BlockId, Seq[BlockManagerId]] =
 {
 // env == null and blockManagerMaster != null is used in tests
 assert (env != null || blockManagerMaster != null)
@@ -1078,7 +1071,7 @@ private[spark] object BlockManager extends Logging {
 blockManagerMaster.getLocations(blockIds)
 }
-val blockManagers = new HashMap[String, Seq[BlockManagerId]]
+val blockManagers = new HashMap[BlockId, Seq[BlockManagerId]]
 for (i <- 0 until blockIds.length) {
 blockManagers(blockIds(i)) = blockLocations(i)
 }
@@ -1086,25 +1079,21 @@ private[spark] object BlockManager extends Logging {
 }
 def blockIdsToExecutorIds(
-blockIds: Array[String],
+blockIds: Array[BlockId],
 env: SparkEnv,
 blockManagerMaster: BlockManagerMaster = null)
-: Map[String, Seq[String]] =
+: Map[BlockId, Seq[String]] =
 {
 blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.executorId))
 }
 def blockIdsToHosts(
-blockIds: Array[String],
+blockIds: Array[BlockId],
 env: SparkEnv,
 blockManagerMaster: BlockManagerMaster = null)
-: Map[String, Seq[String]] =
+: Map[BlockId, Seq[String]] =
 {
 blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.host))
 }
-def isBroadcastBlock(blockId: String): Boolean = null != blockId && blockId.startsWith("broadcast_")
-def toBroadcastId(id: Long): String = "broadcast_" + id
 }

@@ -60,7 +60,7 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi
 def updateBlockInfo(
 blockManagerId: BlockManagerId,
-blockId: String,
+blockId: BlockId,
 storageLevel: StorageLevel,
 memSize: Long,
 diskSize: Long): Boolean = {
@@ -71,12 +71,12 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi
 }
 /** Get locations of the blockId from the driver */
-def getLocations(blockId: String): Seq[BlockManagerId] = {
+def getLocations(blockId: BlockId): Seq[BlockManagerId] = {
 askDriverWithReply[Seq[BlockManagerId]](GetLocations(blockId))
 }
 /** Get locations of multiple blockIds from the driver */
-def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = {
+def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = {
 askDriverWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds))
 }
@@ -94,7 +94,7 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi
 * Remove a block from the slaves that have it. This can only be used to remove
 * blocks that the driver knows about.
 */
-def removeBlock(blockId: String) {
+def removeBlock(blockId: BlockId) {
 askDriverWithReply(RemoveBlock(blockId))
 }

View file

@ -48,7 +48,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
private val blockManagerIdByExecutor = new mutable.HashMap[String, BlockManagerId] private val blockManagerIdByExecutor = new mutable.HashMap[String, BlockManagerId]
// Mapping from block id to the set of block managers that have the block. // Mapping from block id to the set of block managers that have the block.
private val blockLocations = new JHashMap[String, mutable.HashSet[BlockManagerId]] private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]]
val akkaTimeout = Duration.create( val akkaTimeout = Duration.create(
System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
@ -129,10 +129,9 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
// First remove the metadata for the given RDD, and then asynchronously remove the blocks // First remove the metadata for the given RDD, and then asynchronously remove the blocks
// from the slaves. // from the slaves.
val prefix = "rdd_" + rddId + "_"
// Find all blocks for the given RDD, remove the block from both blockLocations and // Find all blocks for the given RDD, remove the block from both blockLocations and
// the blockManagerInfo that is tracking the blocks. // the blockManagerInfo that is tracking the blocks.
val blocks = blockLocations.keySet().filter(_.startsWith(prefix)) val blocks = blockLocations.keys.flatMap(_.asRDDId).filter(_.rddId == rddId)
blocks.foreach { blockId => blocks.foreach { blockId =>
val bms: mutable.HashSet[BlockManagerId] = blockLocations.get(blockId) val bms: mutable.HashSet[BlockManagerId] = blockLocations.get(blockId)
bms.foreach(bm => blockManagerInfo.get(bm).foreach(_.removeBlock(blockId))) bms.foreach(bm => blockManagerInfo.get(bm).foreach(_.removeBlock(blockId)))
@ -198,7 +197,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
// Remove a block from the slaves that have it. This can only be used to remove // Remove a block from the slaves that have it. This can only be used to remove
// blocks that the master knows about. // blocks that the master knows about.
private def removeBlockFromWorkers(blockId: String) { private def removeBlockFromWorkers(blockId: BlockId) {
val locations = blockLocations.get(blockId) val locations = blockLocations.get(blockId)
if (locations != null) { if (locations != null) {
locations.foreach { blockManagerId: BlockManagerId => locations.foreach { blockManagerId: BlockManagerId =>
@ -247,7 +246,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
private def updateBlockInfo( private def updateBlockInfo(
blockManagerId: BlockManagerId, blockManagerId: BlockManagerId,
blockId: String, blockId: BlockId,
storageLevel: StorageLevel, storageLevel: StorageLevel,
memSize: Long, memSize: Long,
diskSize: Long) { diskSize: Long) {
@ -292,11 +291,11 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
sender ! true sender ! true
} }
private def getLocations(blockId: String): Seq[BlockManagerId] = { private def getLocations(blockId: BlockId): Seq[BlockManagerId] = {
if (blockLocations.containsKey(blockId)) blockLocations.get(blockId).toSeq else Seq.empty if (blockLocations.containsKey(blockId)) blockLocations.get(blockId).toSeq else Seq.empty
} }
private def getLocationsMultipleBlockIds(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = { private def getLocationsMultipleBlockIds(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = {
blockIds.map(blockId => getLocations(blockId)) blockIds.map(blockId => getLocations(blockId))
} }
@ -330,7 +329,7 @@ object BlockManagerMasterActor {
private var _remainingMem: Long = maxMem private var _remainingMem: Long = maxMem
// Mapping from block id to its status. // Mapping from block id to its status.
private val _blocks = new JHashMap[String, BlockStatus] private val _blocks = new JHashMap[BlockId, BlockStatus]
logInfo("Registering block manager %s with %s RAM".format( logInfo("Registering block manager %s with %s RAM".format(
blockManagerId.hostPort, Utils.bytesToString(maxMem))) blockManagerId.hostPort, Utils.bytesToString(maxMem)))
@ -339,7 +338,7 @@ object BlockManagerMasterActor {
_lastSeenMs = System.currentTimeMillis() _lastSeenMs = System.currentTimeMillis()
} }
def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long, def updateBlockInfo(blockId: BlockId, storageLevel: StorageLevel, memSize: Long,
diskSize: Long) { diskSize: Long) {
updateLastSeenMs() updateLastSeenMs()
@ -383,7 +382,7 @@ object BlockManagerMasterActor {
} }
} }
def removeBlock(blockId: String) { def removeBlock(blockId: BlockId) {
if (_blocks.containsKey(blockId)) { if (_blocks.containsKey(blockId)) {
_remainingMem += _blocks.get(blockId).memSize _remainingMem += _blocks.get(blockId).memSize
_blocks.remove(blockId) _blocks.remove(blockId)
@ -394,7 +393,7 @@ object BlockManagerMasterActor {
def lastSeenMs: Long = _lastSeenMs def lastSeenMs: Long = _lastSeenMs
def blocks: JHashMap[String, BlockStatus] = _blocks def blocks: JHashMap[BlockId, BlockStatus] = _blocks
override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem

View file

@ -30,7 +30,7 @@ private[storage] object BlockManagerMessages {
// Remove a block from the slaves that have it. This can only be used to remove // Remove a block from the slaves that have it. This can only be used to remove
// blocks that the master knows about. // blocks that the master knows about.
case class RemoveBlock(blockId: String) extends ToBlockManagerSlave case class RemoveBlock(blockId: BlockId) extends ToBlockManagerSlave
// Remove all blocks belonging to a specific RDD. // Remove all blocks belonging to a specific RDD.
case class RemoveRdd(rddId: Int) extends ToBlockManagerSlave case class RemoveRdd(rddId: Int) extends ToBlockManagerSlave
@ -51,7 +51,7 @@ private[storage] object BlockManagerMessages {
class UpdateBlockInfo( class UpdateBlockInfo(
var blockManagerId: BlockManagerId, var blockManagerId: BlockManagerId,
var blockId: String, var blockId: BlockId,
var storageLevel: StorageLevel, var storageLevel: StorageLevel,
var memSize: Long, var memSize: Long,
var diskSize: Long) var diskSize: Long)
@ -62,7 +62,7 @@ private[storage] object BlockManagerMessages {
override def writeExternal(out: ObjectOutput) { override def writeExternal(out: ObjectOutput) {
blockManagerId.writeExternal(out) blockManagerId.writeExternal(out)
out.writeUTF(blockId) out.writeUTF(blockId.name)
storageLevel.writeExternal(out) storageLevel.writeExternal(out)
out.writeLong(memSize) out.writeLong(memSize)
out.writeLong(diskSize) out.writeLong(diskSize)
@ -70,7 +70,7 @@ private[storage] object BlockManagerMessages {
override def readExternal(in: ObjectInput) { override def readExternal(in: ObjectInput) {
blockManagerId = BlockManagerId(in) blockManagerId = BlockManagerId(in)
blockId = in.readUTF() blockId = BlockId(in.readUTF())
storageLevel = StorageLevel(in) storageLevel = StorageLevel(in)
memSize = in.readLong() memSize = in.readLong()
diskSize = in.readLong() diskSize = in.readLong()
@ -79,7 +79,7 @@ private[storage] object BlockManagerMessages {
object UpdateBlockInfo { object UpdateBlockInfo {
def apply(blockManagerId: BlockManagerId, def apply(blockManagerId: BlockManagerId,
blockId: String, blockId: BlockId,
storageLevel: StorageLevel, storageLevel: StorageLevel,
memSize: Long, memSize: Long,
diskSize: Long): UpdateBlockInfo = { diskSize: Long): UpdateBlockInfo = {
@ -87,14 +87,14 @@ private[storage] object BlockManagerMessages {
} }
// For pattern-matching // For pattern-matching
def unapply(h: UpdateBlockInfo): Option[(BlockManagerId, String, StorageLevel, Long, Long)] = { def unapply(h: UpdateBlockInfo): Option[(BlockManagerId, BlockId, StorageLevel, Long, Long)] = {
Some((h.blockManagerId, h.blockId, h.storageLevel, h.memSize, h.diskSize)) Some((h.blockManagerId, h.blockId, h.storageLevel, h.memSize, h.diskSize))
} }
} }
case class GetLocations(blockId: String) extends ToBlockManagerMaster case class GetLocations(blockId: BlockId) extends ToBlockManagerMaster
case class GetLocationsMultipleBlockIds(blockIds: Array[String]) extends ToBlockManagerMaster case class GetLocationsMultipleBlockIds(blockIds: Array[BlockId]) extends ToBlockManagerMaster
case class GetPeers(blockManagerId: BlockManagerId, size: Int) extends ToBlockManagerMaster case class GetPeers(blockManagerId: BlockManagerId, size: Int) extends ToBlockManagerMaster

View file

@ -77,7 +77,7 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends
} }
} }
private def putBlock(id: String, bytes: ByteBuffer, level: StorageLevel) { private def putBlock(id: BlockId, bytes: ByteBuffer, level: StorageLevel) {
val startTimeMs = System.currentTimeMillis() val startTimeMs = System.currentTimeMillis()
logDebug("PutBlock " + id + " started from " + startTimeMs + " with data: " + bytes) logDebug("PutBlock " + id + " started from " + startTimeMs + " with data: " + bytes)
blockManager.putBytes(id, bytes, level) blockManager.putBytes(id, bytes, level)
@ -85,7 +85,7 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends
+ " with data size: " + bytes.limit) + " with data size: " + bytes.limit)
} }
private def getBlock(id: String): ByteBuffer = { private def getBlock(id: BlockId): ByteBuffer = {
val startTimeMs = System.currentTimeMillis() val startTimeMs = System.currentTimeMillis()
logDebug("GetBlock " + id + " started from " + startTimeMs) logDebug("GetBlock " + id + " started from " + startTimeMs)
val buffer = blockManager.getLocalBytes(id) match { val buffer = blockManager.getLocalBytes(id) match {

View file

@ -24,9 +24,9 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.network._ import org.apache.spark.network._
private[spark] case class GetBlock(id: String) private[spark] case class GetBlock(id: BlockId)
private[spark] case class GotBlock(id: String, data: ByteBuffer) private[spark] case class GotBlock(id: BlockId, data: ByteBuffer)
private[spark] case class PutBlock(id: String, data: ByteBuffer, level: StorageLevel) private[spark] case class PutBlock(id: BlockId, data: ByteBuffer, level: StorageLevel)
private[spark] class BlockMessage() { private[spark] class BlockMessage() {
// Un-initialized: typ = 0 // Un-initialized: typ = 0
@ -34,7 +34,7 @@ private[spark] class BlockMessage() {
// GotBlock: typ = 2 // GotBlock: typ = 2
// PutBlock: typ = 3 // PutBlock: typ = 3
private var typ: Int = BlockMessage.TYPE_NON_INITIALIZED private var typ: Int = BlockMessage.TYPE_NON_INITIALIZED
private var id: String = null private var id: BlockId = null
private var data: ByteBuffer = null private var data: ByteBuffer = null
private var level: StorageLevel = null private var level: StorageLevel = null
@ -74,7 +74,7 @@ private[spark] class BlockMessage() {
for (i <- 1 to idLength) { for (i <- 1 to idLength) {
idBuilder += buffer.getChar() idBuilder += buffer.getChar()
} }
id = idBuilder.toString() id = BlockId(idBuilder.toString)
if (typ == BlockMessage.TYPE_PUT_BLOCK) { if (typ == BlockMessage.TYPE_PUT_BLOCK) {
@ -109,28 +109,17 @@ private[spark] class BlockMessage() {
set(buffer) set(buffer)
} }
def getType: Int = { def getType: Int = typ
return typ def getId: BlockId = id
} def getData: ByteBuffer = data
def getLevel: StorageLevel = level
def getId: String = {
return id
}
def getData: ByteBuffer = {
return data
}
def getLevel: StorageLevel = {
return level
}
def toBufferMessage: BufferMessage = { def toBufferMessage: BufferMessage = {
val startTime = System.currentTimeMillis val startTime = System.currentTimeMillis
val buffers = new ArrayBuffer[ByteBuffer]() val buffers = new ArrayBuffer[ByteBuffer]()
var buffer = ByteBuffer.allocate(4 + 4 + id.length() * 2) var buffer = ByteBuffer.allocate(4 + 4 + id.name.length * 2)
buffer.putInt(typ).putInt(id.length()) buffer.putInt(typ).putInt(id.name.length)
id.foreach((x: Char) => buffer.putChar(x)) id.name.foreach((x: Char) => buffer.putChar(x))
buffer.flip() buffer.flip()
buffers += buffer buffers += buffer
@ -212,7 +201,8 @@ private[spark] object BlockMessage {
def main(args: Array[String]) { def main(args: Array[String]) {
val B = new BlockMessage() val B = new BlockMessage()
B.set(new PutBlock("ABC", ByteBuffer.allocate(10), StorageLevel.MEMORY_AND_DISK_SER_2)) val blockId = TestBlockId("ABC")
B.set(new PutBlock(blockId, ByteBuffer.allocate(10), StorageLevel.MEMORY_AND_DISK_SER_2))
val bMsg = B.toBufferMessage val bMsg = B.toBufferMessage
val C = new BlockMessage() val C = new BlockMessage()
C.set(bMsg) C.set(bMsg)

View file

@ -116,9 +116,10 @@ private[spark] object BlockMessageArray {
if (i % 2 == 0) { if (i % 2 == 0) {
val buffer = ByteBuffer.allocate(100) val buffer = ByteBuffer.allocate(100)
buffer.clear buffer.clear
BlockMessage.fromPutBlock(PutBlock(i.toString, buffer, StorageLevel.MEMORY_ONLY_SER)) BlockMessage.fromPutBlock(PutBlock(TestBlockId(i.toString), buffer,
StorageLevel.MEMORY_ONLY_SER))
} else { } else {
BlockMessage.fromGetBlock(GetBlock(i.toString)) BlockMessage.fromGetBlock(GetBlock(TestBlockId(i.toString)))
} }
} }
val blockMessageArray = new BlockMessageArray(blockMessages) val blockMessageArray = new BlockMessageArray(blockMessages)

View file

@ -25,7 +25,7 @@ package org.apache.spark.storage
* *
* This interface does not support concurrent writes. * This interface does not support concurrent writes.
*/ */
abstract class BlockObjectWriter(val blockId: String) { abstract class BlockObjectWriter(val blockId: BlockId) {
var closeEventHandler: () => Unit = _ var closeEventHandler: () => Unit = _

View file

@ -27,7 +27,7 @@ import org.apache.spark.Logging
*/ */
private[spark] private[spark]
abstract class BlockStore(val blockManager: BlockManager) extends Logging { abstract class BlockStore(val blockManager: BlockManager) extends Logging {
def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel)
/** /**
* Put in a block and, possibly, also return its content as either bytes or another Iterator. * Put in a block and, possibly, also return its content as either bytes or another Iterator.
@ -36,26 +36,26 @@ abstract class BlockStore(val blockManager: BlockManager) extends Logging {
* @return a PutResult that contains the size of the data, as well as the values put if * @return a PutResult that contains the size of the data, as well as the values put if
* returnValues is true (if not, the result's data field can be null) * returnValues is true (if not, the result's data field can be null)
*/ */
def putValues(blockId: String, values: ArrayBuffer[Any], level: StorageLevel, def putValues(blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel,
returnValues: Boolean) : PutResult returnValues: Boolean) : PutResult
/** /**
* Return the size of a block in bytes. * Return the size of a block in bytes.
*/ */
def getSize(blockId: String): Long def getSize(blockId: BlockId): Long
def getBytes(blockId: String): Option[ByteBuffer] def getBytes(blockId: BlockId): Option[ByteBuffer]
def getValues(blockId: String): Option[Iterator[Any]] def getValues(blockId: BlockId): Option[Iterator[Any]]
/** /**
* Remove a block, if it exists. * Remove a block, if it exists.
* @param blockId the block to remove. * @param blockId the block to remove.
* @return True if the block was found and removed, False otherwise. * @return True if the block was found and removed, False otherwise.
*/ */
def remove(blockId: String): Boolean def remove(blockId: BlockId): Boolean
def contains(blockId: String): Boolean def contains(blockId: BlockId): Boolean
def clear() { } def clear() { }
} }

View file

@ -42,7 +42,7 @@ import org.apache.spark.util.Utils
private class DiskStore(blockManager: BlockManager, rootDirs: String) private class DiskStore(blockManager: BlockManager, rootDirs: String)
extends BlockStore(blockManager) with Logging { extends BlockStore(blockManager) with Logging {
class DiskBlockObjectWriter(blockId: String, serializer: Serializer, bufferSize: Int) class DiskBlockObjectWriter(blockId: BlockId, serializer: Serializer, bufferSize: Int)
extends BlockObjectWriter(blockId) { extends BlockObjectWriter(blockId) {
private val f: File = createFile(blockId /*, allowAppendExisting */) private val f: File = createFile(blockId /*, allowAppendExisting */)
@ -124,16 +124,16 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
addShutdownHook() addShutdownHook()
def getBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int) def getBlockWriter(blockId: BlockId, serializer: Serializer, bufferSize: Int)
: BlockObjectWriter = { : BlockObjectWriter = {
new DiskBlockObjectWriter(blockId, serializer, bufferSize) new DiskBlockObjectWriter(blockId, serializer, bufferSize)
} }
override def getSize(blockId: String): Long = { override def getSize(blockId: BlockId): Long = {
getFile(blockId).length() getFile(blockId).length()
} }
override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) { override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) {
// So that we do not modify the input offsets ! // So that we do not modify the input offsets !
// duplicate does not copy buffer, so inexpensive // duplicate does not copy buffer, so inexpensive
val bytes = _bytes.duplicate() val bytes = _bytes.duplicate()
@ -163,7 +163,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
} }
override def putValues( override def putValues(
blockId: String, blockId: BlockId,
values: ArrayBuffer[Any], values: ArrayBuffer[Any],
level: StorageLevel, level: StorageLevel,
returnValues: Boolean) returnValues: Boolean)
@ -192,13 +192,13 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
} }
} }
override def getBytes(blockId: String): Option[ByteBuffer] = { override def getBytes(blockId: BlockId): Option[ByteBuffer] = {
val file = getFile(blockId) val file = getFile(blockId)
val bytes = getFileBytes(file) val bytes = getFileBytes(file)
Some(bytes) Some(bytes)
} }
override def getValues(blockId: String): Option[Iterator[Any]] = { override def getValues(blockId: BlockId): Option[Iterator[Any]] = {
getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes)) getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes))
} }
@ -206,11 +206,11 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
* A version of getValues that allows a custom serializer. This is used as part of the * A version of getValues that allows a custom serializer. This is used as part of the
* shuffle short-circuit code. * shuffle short-circuit code.
*/ */
def getValues(blockId: String, serializer: Serializer): Option[Iterator[Any]] = { def getValues(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = {
getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes, serializer)) getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes, serializer))
} }
override def remove(blockId: String): Boolean = { override def remove(blockId: BlockId): Boolean = {
val file = getFile(blockId) val file = getFile(blockId)
if (file.exists()) { if (file.exists()) {
file.delete() file.delete()
@ -219,11 +219,11 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
} }
} }
override def contains(blockId: String): Boolean = { override def contains(blockId: BlockId): Boolean = {
getFile(blockId).exists() getFile(blockId).exists()
} }
private def createFile(blockId: String, allowAppendExisting: Boolean = false): File = { private def createFile(blockId: BlockId, allowAppendExisting: Boolean = false): File = {
val file = getFile(blockId) val file = getFile(blockId)
if (!allowAppendExisting && file.exists()) { if (!allowAppendExisting && file.exists()) {
// NOTE(shivaram): Delete the file if it exists. This might happen if a ShuffleMap task // NOTE(shivaram): Delete the file if it exists. This might happen if a ShuffleMap task
@ -234,7 +234,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
file file
} }
private def getFile(blockId: String): File = { private def getFile(blockId: BlockId): File = {
logDebug("Getting file for block " + blockId) logDebug("Getting file for block " + blockId)
// Figure out which local directory it hashes to, and which subdirectory in that // Figure out which local directory it hashes to, and which subdirectory in that
@ -258,7 +258,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
} }
} }
new File(subDir, blockId) new File(subDir, blockId.name)
} }
private def createLocalDirs(): Array[File] = { private def createLocalDirs(): Array[File] = {
@ -307,7 +307,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
} }
} }
if (shuffleSender != null) { if (shuffleSender != null) {
shuffleSender.stop shuffleSender.stop()
} }
} }
}) })
@ -315,11 +315,10 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
private[storage] def startShuffleBlockSender(port: Int): Int = { private[storage] def startShuffleBlockSender(port: Int): Int = {
val pResolver = new PathResolver { val pResolver = new PathResolver {
override def getAbsolutePath(blockId: String): String = { override def getAbsolutePath(blockIdString: String): String = {
if (!blockId.startsWith("shuffle_")) { val blockId = BlockId(blockIdString)
return null if (!blockId.isShuffle) null
} else DiskStore.this.getFile(blockId).getAbsolutePath
DiskStore.this.getFile(blockId).getAbsolutePath()
} }
} }
shuffleSender = new ShuffleSender(port, pResolver) shuffleSender = new ShuffleSender(port, pResolver)

View file

@ -32,7 +32,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
case class Entry(value: Any, size: Long, deserialized: Boolean) case class Entry(value: Any, size: Long, deserialized: Boolean)
private val entries = new LinkedHashMap[String, Entry](32, 0.75f, true) private val entries = new LinkedHashMap[BlockId, Entry](32, 0.75f, true)
@volatile private var currentMemory = 0L @volatile private var currentMemory = 0L
// Object used to ensure that only one thread is putting blocks and if necessary, dropping // Object used to ensure that only one thread is putting blocks and if necessary, dropping
// blocks from the memory store. // blocks from the memory store.
@ -42,13 +42,13 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
def freeMemory: Long = maxMemory - currentMemory def freeMemory: Long = maxMemory - currentMemory
override def getSize(blockId: String): Long = { override def getSize(blockId: BlockId): Long = {
entries.synchronized { entries.synchronized {
entries.get(blockId).size entries.get(blockId).size
} }
} }
override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) { override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) {
// Work on a duplicate - since the original input might be used elsewhere. // Work on a duplicate - since the original input might be used elsewhere.
val bytes = _bytes.duplicate() val bytes = _bytes.duplicate()
bytes.rewind() bytes.rewind()
@ -64,7 +64,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
} }
override def putValues( override def putValues(
blockId: String, blockId: BlockId,
values: ArrayBuffer[Any], values: ArrayBuffer[Any],
level: StorageLevel, level: StorageLevel,
returnValues: Boolean) returnValues: Boolean)
@ -81,7 +81,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
} }
} }
override def getBytes(blockId: String): Option[ByteBuffer] = { override def getBytes(blockId: BlockId): Option[ByteBuffer] = {
val entry = entries.synchronized { val entry = entries.synchronized {
entries.get(blockId) entries.get(blockId)
} }
@ -94,7 +94,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
} }
} }
override def getValues(blockId: String): Option[Iterator[Any]] = { override def getValues(blockId: BlockId): Option[Iterator[Any]] = {
val entry = entries.synchronized { val entry = entries.synchronized {
entries.get(blockId) entries.get(blockId)
} }
@ -108,7 +108,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
} }
} }
override def remove(blockId: String): Boolean = { override def remove(blockId: BlockId): Boolean = {
entries.synchronized { entries.synchronized {
val entry = entries.remove(blockId) val entry = entries.remove(blockId)
if (entry != null) { if (entry != null) {
@ -131,14 +131,10 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
} }
/** /**
* Return the RDD ID that a given block ID is from, or null if it is not an RDD block. * Return the RDD ID that a given block ID is from, or None if it is not an RDD block.
*/ */
private def getRddId(blockId: String): String = { private def getRddId(blockId: BlockId): Option[Int] = {
if (blockId.startsWith("rdd_")) { blockId.asRDDId.map(_.rddId)
blockId.split('_')(1)
} else {
null
}
} }
/** /**
@ -151,7 +147,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
* blocks to free memory for one block, another thread may use up the freed space for * blocks to free memory for one block, another thread may use up the freed space for
* another block. * another block.
*/ */
private def tryToPut(blockId: String, value: Any, size: Long, deserialized: Boolean): Boolean = { private def tryToPut(blockId: BlockId, value: Any, size: Long, deserialized: Boolean): Boolean = {
// TODO: It's possible to optimize the locking by locking entries only when selecting blocks // TODO: It's possible to optimize the locking by locking entries only when selecting blocks
// to be dropped. Once the to-be-dropped blocks have been selected, and lock on entries has been // to be dropped. Once the to-be-dropped blocks have been selected, and lock on entries has been
// released, it must be ensured that those to-be-dropped blocks are not double counted for // released, it must be ensured that those to-be-dropped blocks are not double counted for
@ -195,7 +191,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
* Assumes that a lock is held by the caller to ensure only one thread is dropping blocks. * Assumes that a lock is held by the caller to ensure only one thread is dropping blocks.
* Otherwise, the freed space may fill up before the caller puts in their new value. * Otherwise, the freed space may fill up before the caller puts in their new value.
*/ */
private def ensureFreeSpace(blockIdToAdd: String, space: Long): Boolean = { private def ensureFreeSpace(blockIdToAdd: BlockId, space: Long): Boolean = {
logInfo("ensureFreeSpace(%d) called with curMem=%d, maxMem=%d".format( logInfo("ensureFreeSpace(%d) called with curMem=%d, maxMem=%d".format(
space, currentMemory, maxMemory)) space, currentMemory, maxMemory))
@ -207,7 +203,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
if (maxMemory - currentMemory < space) { if (maxMemory - currentMemory < space) {
val rddToAdd = getRddId(blockIdToAdd) val rddToAdd = getRddId(blockIdToAdd)
val selectedBlocks = new ArrayBuffer[String]() val selectedBlocks = new ArrayBuffer[BlockId]()
var selectedMemory = 0L var selectedMemory = 0L
// This is synchronized to ensure that the set of entries is not changed // This is synchronized to ensure that the set of entries is not changed
@ -218,7 +214,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
while (maxMemory - (currentMemory - selectedMemory) < space && iterator.hasNext) { while (maxMemory - (currentMemory - selectedMemory) < space && iterator.hasNext) {
val pair = iterator.next() val pair = iterator.next()
val blockId = pair.getKey val blockId = pair.getKey
if (rddToAdd != null && rddToAdd == getRddId(blockId)) { if (rddToAdd != None && rddToAdd == getRddId(blockId)) {
logInfo("Will not store " + blockIdToAdd + " as it would require dropping another " + logInfo("Will not store " + blockIdToAdd + " as it would require dropping another " +
"block from the same RDD") "block from the same RDD")
return false return false
@ -252,7 +248,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
return true return true
} }
override def contains(blockId: String): Boolean = { override def contains(blockId: BlockId): Boolean = {
entries.synchronized { entries.containsKey(blockId) } entries.synchronized { entries.containsKey(blockId) }
} }
} }
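The getRddId helper above now leans on the Option-returning asRDDId accessor, which also composes with pattern matching (one of the motivations in the commit message). A small sketch, assuming RDDBlockId is a case class with an (rddId, splitIndex) extractor as BlockIdSuite suggests:

    import org.apache.spark.storage.{BlockId, RDDBlockId}

    // Two equivalent ways to recover an RDD id, or None for non-RDD blocks.
    def rddIdOf(blockId: BlockId): Option[Int] =
      blockId.asRDDId.map(_.rddId)

    def rddIdByMatch(blockId: BlockId): Option[Int] = blockId match {
      case RDDBlockId(rddId, _) => Some(rddId)  // assumes a case-class extractor
      case _ => None
    }

Either form lets the eviction check compare typed RDD ids rather than substrings of "rdd_<id>_<split>" names.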

View file

@ -30,7 +30,6 @@ trait ShuffleBlocks {
def releaseWriters(group: ShuffleWriterGroup) def releaseWriters(group: ShuffleWriterGroup)
} }
private[spark] private[spark]
class ShuffleBlockManager(blockManager: BlockManager) { class ShuffleBlockManager(blockManager: BlockManager) {
@ -40,7 +39,7 @@ class ShuffleBlockManager(blockManager: BlockManager) {
override def acquireWriters(mapId: Int): ShuffleWriterGroup = { override def acquireWriters(mapId: Int): ShuffleWriterGroup = {
val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024 val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId) val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
blockManager.getDiskBlockWriter(blockId, serializer, bufferSize) blockManager.getDiskBlockWriter(blockId, serializer, bufferSize)
} }
new ShuffleWriterGroup(mapId, writers) new ShuffleWriterGroup(mapId, writers)
@ -52,16 +51,3 @@ class ShuffleBlockManager(blockManager: BlockManager) {
} }
} }
} }
private[spark]
object ShuffleBlockManager {
// Returns the block id for a given shuffle block.
def blockId(shuffleId: Int, bucketId: Int, groupId: Int): String = {
"shuffle_" + shuffleId + "_" + groupId + "_" + bucketId
}
// Returns true if the block is a shuffle block.
def isShuffle(blockId: String): Boolean = blockId.startsWith("shuffle_")
}
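The removed companion object is subsumed by the ShuffleBlockId type: its name carries the same "shuffle_<shuffleId>_<mapId>_<reduceId>" encoding (the old helper's bucketId/groupId argument swap is now absorbed by named fields), and isShuffle replaces the prefix test. A minimal sketch, with the field order taken from BlockIdSuite:

    import org.apache.spark.storage.{BlockId, ShuffleBlockId}

    object ShuffleIdSketch extends App {
      // Sketch: the old string helpers expressed against the typed API.
      val id: BlockId = ShuffleBlockId(1, 2, 3)  // shuffleId = 1, mapId = 2, reduceId = 3
      assert(id.name == "shuffle_1_2_3")         // same name encoding as the removed helper
      assert(id.isShuffle)                       // replaces blockId.startsWith("shuffle_")
    }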

View file

@ -23,20 +23,24 @@ import org.apache.spark.util.Utils
private[spark] private[spark]
case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long, case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long,
blocks: Map[String, BlockStatus]) { blocks: Map[BlockId, BlockStatus]) {
def memUsed(blockPrefix: String = "") = { def memUsed() = blocks.values.map(_.memSize).reduceOption(_+_).getOrElse(0L)
blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.memSize).
reduceOption(_+_).getOrElse(0l)
}
def diskUsed(blockPrefix: String = "") = { def memUsedByRDD(rddId: Int) =
blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.diskSize). rddBlocks.filterKeys(_.rddId == rddId).values.map(_.memSize).reduceOption(_+_).getOrElse(0L)
reduceOption(_+_).getOrElse(0l)
} def diskUsed() = blocks.values.map(_.diskSize).reduceOption(_+_).getOrElse(0L)
def diskUsedByRDD(rddId: Int) =
rddBlocks.filterKeys(_.rddId == rddId).values.map(_.diskSize).reduceOption(_+_).getOrElse(0L)
def memRemaining : Long = maxMem - memUsed() def memRemaining : Long = maxMem - memUsed()
def rddBlocks = blocks.flatMap {
case (rdd: RDDBlockId, status) => Some(rdd, status)
case _ => None
}
} }
case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel, case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel,
@ -60,7 +64,7 @@ object StorageUtils {
/* Returns RDD-level information, compiled from a list of StorageStatus objects */ /* Returns RDD-level information, compiled from a list of StorageStatus objects */
def rddInfoFromStorageStatus(storageStatusList: Seq[StorageStatus], def rddInfoFromStorageStatus(storageStatusList: Seq[StorageStatus],
sc: SparkContext) : Array[RDDInfo] = { sc: SparkContext) : Array[RDDInfo] = {
rddInfoFromBlockStatusList(storageStatusList.flatMap(_.blocks).toMap, sc) rddInfoFromBlockStatusList(storageStatusList.flatMap(_.rddBlocks).toMap[RDDBlockId, BlockStatus], sc)
} }
/* Returns a map of blocks to their locations, compiled from a list of StorageStatus objects */ /* Returns a map of blocks to their locations, compiled from a list of StorageStatus objects */
@ -71,26 +75,21 @@ object StorageUtils {
} }
/* Given a list of BlockStatus objects, returns information for each RDD */ /* Given a list of BlockStatus objects, returns information for each RDD */
def rddInfoFromBlockStatusList(infos: Map[String, BlockStatus], def rddInfoFromBlockStatusList(infos: Map[RDDBlockId, BlockStatus],
sc: SparkContext) : Array[RDDInfo] = { sc: SparkContext) : Array[RDDInfo] = {
// Group by rddId, ignore the partition name // Group by rddId, ignore the partition name
val groupedRddBlocks = infos.filterKeys(_.startsWith("rdd_")).groupBy { case(k, v) => val groupedRddBlocks = infos.groupBy { case(k, v) => k.rddId }.mapValues(_.values.toArray)
k.substring(0,k.lastIndexOf('_'))
}.mapValues(_.values.toArray)
// For each RDD, generate an RDDInfo object // For each RDD, generate an RDDInfo object
val rddInfos = groupedRddBlocks.map { case (rddKey, rddBlocks) => val rddInfos = groupedRddBlocks.map { case (rddId, rddBlocks) =>
// Add up memory and disk sizes // Add up memory and disk sizes
val memSize = rddBlocks.map(_.memSize).reduce(_ + _) val memSize = rddBlocks.map(_.memSize).reduce(_ + _)
val diskSize = rddBlocks.map(_.diskSize).reduce(_ + _) val diskSize = rddBlocks.map(_.diskSize).reduce(_ + _)
// Find the id of the RDD, e.g. rdd_1 => 1
val rddId = rddKey.split("_").last.toInt
// Get the friendly name and storage level for the RDD, if available // Get the friendly name and storage level for the RDD, if available
sc.persistentRdds.get(rddId).map { r => sc.persistentRdds.get(rddId).map { r =>
val rddName = Option(r.name).getOrElse(rddKey) val rddName = Option(r.name).getOrElse(rddId.toString)
val rddStorageLevel = r.getStorageLevel val rddStorageLevel = r.getStorageLevel
RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, r.partitions.size, memSize, diskSize) RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, r.partitions.size, memSize, diskSize)
} }
@ -101,16 +100,14 @@ object StorageUtils {
rddInfos rddInfos
} }
/* Removes all BlockStatus object that are not part of a block prefix */ /* Filters storage status by a given RDD id. */
def filterStorageStatusByPrefix(storageStatusList: Array[StorageStatus], def filterStorageStatusByRDD(storageStatusList: Array[StorageStatus], rddId: Int)
prefix: String) : Array[StorageStatus] = { : Array[StorageStatus] = {
storageStatusList.map { status => storageStatusList.map { status =>
val newBlocks = status.blocks.filterKeys(_.startsWith(prefix)) val newBlocks = status.rddBlocks.filterKeys(_.rddId == rddId).toMap[BlockId, BlockStatus]
//val newRemainingMem = status.maxMem - newBlocks.values.map(_.memSize).reduce(_ + _) //val newRemainingMem = status.maxMem - newBlocks.values.map(_.memSize).reduce(_ + _)
StorageStatus(status.blockManagerId, status.maxMem, newBlocks) StorageStatus(status.blockManagerId, status.maxMem, newBlocks)
} }
} }
} }
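With the prefix-based memUsed(prefix)/diskUsed(prefix) overloads gone, per-RDD accounting goes through the rddBlocks view and the *ByRDD helpers added to StorageStatus above. A short usage sketch (StorageStatus is private[spark], so this would live inside Spark; the summarize name is illustrative only):

    import org.apache.spark.storage.StorageStatus

    // Sketch: per-RDD accounting with the typed helpers.
    def summarize(status: StorageStatus, rddId: Int): (Long, Long, Int) = {
      val mem  = status.memUsedByRDD(rddId)    // memory used by this RDD's cached blocks
      val disk = status.diskUsedByRDD(rddId)   // disk used by this RDD's cached blocks
      val num  = status.rddBlocks.count { case (id, _) => id.rddId == rddId }
      (mem, disk, num)                         // typed keys, no string parsing
    }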

View file

@ -36,11 +36,11 @@ private[spark] object ThreadingTest {
val numBlocksPerProducer = 20000 val numBlocksPerProducer = 20000
private[spark] class ProducerThread(manager: BlockManager, id: Int) extends Thread { private[spark] class ProducerThread(manager: BlockManager, id: Int) extends Thread {
val queue = new ArrayBlockingQueue[(String, Seq[Int])](100) val queue = new ArrayBlockingQueue[(BlockId, Seq[Int])](100)
override def run() { override def run() {
for (i <- 1 to numBlocksPerProducer) { for (i <- 1 to numBlocksPerProducer) {
val blockId = "b-" + id + "-" + i val blockId = TestBlockId("b-" + id + "-" + i)
val blockSize = Random.nextInt(1000) val blockSize = Random.nextInt(1000)
val block = (1 to blockSize).map(_ => Random.nextInt()) val block = (1 to blockSize).map(_ => Random.nextInt())
val level = randomLevel() val level = randomLevel()
@ -64,7 +64,7 @@ private[spark] object ThreadingTest {
private[spark] class ConsumerThread( private[spark] class ConsumerThread(
manager: BlockManager, manager: BlockManager,
queue: ArrayBlockingQueue[(String, Seq[Int])] queue: ArrayBlockingQueue[(BlockId, Seq[Int])]
) extends Thread { ) extends Thread {
var numBlockConsumed = 0 var numBlockConsumed = 0

View file

@ -21,7 +21,7 @@ import javax.servlet.http.HttpServletRequest
import scala.xml.Node import scala.xml.Node
import org.apache.spark.storage.{StorageStatus, StorageUtils} import org.apache.spark.storage.{BlockId, StorageStatus, StorageUtils}
import org.apache.spark.storage.BlockManagerMasterActor.BlockStatus import org.apache.spark.storage.BlockManagerMasterActor.BlockStatus
import org.apache.spark.ui.UIUtils._ import org.apache.spark.ui.UIUtils._
import org.apache.spark.ui.Page._ import org.apache.spark.ui.Page._
@ -33,21 +33,20 @@ private[spark] class RDDPage(parent: BlockManagerUI) {
val sc = parent.sc val sc = parent.sc
def render(request: HttpServletRequest): Seq[Node] = { def render(request: HttpServletRequest): Seq[Node] = {
val id = request.getParameter("id") val id = request.getParameter("id").toInt
val prefix = "rdd_" + id.toString
val storageStatusList = sc.getExecutorStorageStatus val storageStatusList = sc.getExecutorStorageStatus
val filteredStorageStatusList = StorageUtils. val filteredStorageStatusList = StorageUtils.filterStorageStatusByRDD(storageStatusList, id)
filterStorageStatusByPrefix(storageStatusList, prefix)
val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head
val workerHeaders = Seq("Host", "Memory Usage", "Disk Usage") val workerHeaders = Seq("Host", "Memory Usage", "Disk Usage")
val workers = filteredStorageStatusList.map((prefix, _)) val workers = filteredStorageStatusList.map((id, _))
val workerTable = listingTable(workerHeaders, workerRow, workers) val workerTable = listingTable(workerHeaders, workerRow, workers)
val blockHeaders = Seq("Block Name", "Storage Level", "Size in Memory", "Size on Disk", val blockHeaders = Seq("Block Name", "Storage Level", "Size in Memory", "Size on Disk",
"Executors") "Executors")
val blockStatuses = filteredStorageStatusList.flatMap(_.blocks).toArray.sortWith(_._1 < _._1) val blockStatuses = filteredStorageStatusList.flatMap(_.blocks).toArray.
sortWith(_._1.name < _._1.name)
val blockLocations = StorageUtils.blockLocationsFromStorageStatus(filteredStorageStatusList) val blockLocations = StorageUtils.blockLocationsFromStorageStatus(filteredStorageStatusList)
val blocks = blockStatuses.map { val blocks = blockStatuses.map {
case(id, status) => (id, status, blockLocations.get(id).getOrElse(Seq("UNKNOWN"))) case(id, status) => (id, status, blockLocations.get(id).getOrElse(Seq("UNKNOWN")))
@ -99,7 +98,7 @@ private[spark] class RDDPage(parent: BlockManagerUI) {
headerSparkPage(content, parent.sc, "RDD Storage Info for " + rddInfo.name, Storage) headerSparkPage(content, parent.sc, "RDD Storage Info for " + rddInfo.name, Storage)
} }
def blockRow(row: (String, BlockStatus, Seq[String])): Seq[Node] = { def blockRow(row: (BlockId, BlockStatus, Seq[String])): Seq[Node] = {
val (id, block, locations) = row val (id, block, locations) = row
<tr> <tr>
<td>{id}</td> <td>{id}</td>
@ -118,15 +117,15 @@ private[spark] class RDDPage(parent: BlockManagerUI) {
</tr> </tr>
} }
def workerRow(worker: (String, StorageStatus)): Seq[Node] = { def workerRow(worker: (Int, StorageStatus)): Seq[Node] = {
val (prefix, status) = worker val (rddId, status) = worker
<tr> <tr>
<td>{status.blockManagerId.host + ":" + status.blockManagerId.port}</td> <td>{status.blockManagerId.host + ":" + status.blockManagerId.port}</td>
<td> <td>
{Utils.bytesToString(status.memUsed(prefix))} {Utils.bytesToString(status.memUsedByRDD(rddId))}
({Utils.bytesToString(status.memRemaining)} Remaining) ({Utils.bytesToString(status.memRemaining)} Remaining)
</td> </td>
<td>{Utils.bytesToString(status.diskUsed(prefix))}</td> <td>{Utils.bytesToString(status.diskUsedByRDD(rddId))}</td>
</tr> </tr>
} }
} }

View file

@ -23,7 +23,7 @@ import org.scalatest.{BeforeAndAfter, FunSuite}
import org.scalatest.mock.EasyMockSugar import org.scalatest.mock.EasyMockSugar
import org.apache.spark.rdd.RDD import org.apache.spark.rdd.RDD
import org.apache.spark.storage.{BlockManager, StorageLevel} import org.apache.spark.storage.{BlockManager, RDDBlockId, StorageLevel}
// TODO: Test the CacheManager's thread-safety aspects // TODO: Test the CacheManager's thread-safety aspects
class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar { class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar {
@ -52,9 +52,9 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
test("get uncached rdd") { test("get uncached rdd") {
expecting { expecting {
blockManager.get("rdd_0_0").andReturn(None) blockManager.get(RDDBlockId(0, 0)).andReturn(None)
blockManager.put("rdd_0_0", ArrayBuffer[Any](1, 2, 3, 4), StorageLevel.MEMORY_ONLY, true). blockManager.put(RDDBlockId(0, 0), ArrayBuffer[Any](1, 2, 3, 4), StorageLevel.MEMORY_ONLY,
andReturn(0) true).andReturn(0)
} }
whenExecuting(blockManager) { whenExecuting(blockManager) {
@ -66,7 +66,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
test("get cached rdd") { test("get cached rdd") {
expecting { expecting {
blockManager.get("rdd_0_0").andReturn(Some(ArrayBuffer(5, 6, 7).iterator)) blockManager.get(RDDBlockId(0, 0)).andReturn(Some(ArrayBuffer(5, 6, 7).iterator))
} }
whenExecuting(blockManager) { whenExecuting(blockManager) {
@ -79,7 +79,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
test("get uncached local rdd") { test("get uncached local rdd") {
expecting { expecting {
// Local computation should not persist the resulting value, so don't expect a put(). // Local computation should not persist the resulting value, so don't expect a put().
blockManager.get("rdd_0_0").andReturn(None) blockManager.get(RDDBlockId(0, 0)).andReturn(None)
} }
whenExecuting(blockManager) { whenExecuting(blockManager) {

View file

@ -21,7 +21,7 @@ import org.scalatest.FunSuite
import java.io.File import java.io.File
import org.apache.spark.rdd._ import org.apache.spark.rdd._
import org.apache.spark.SparkContext._ import org.apache.spark.SparkContext._
import storage.StorageLevel import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId}
import org.apache.spark.util.Utils import org.apache.spark.util.Utils
class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
@ -83,7 +83,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
} }
test("BlockRDD") { test("BlockRDD") {
val blockId = "id" val blockId = TestBlockId("id")
val blockManager = SparkEnv.get.blockManager val blockManager = SparkEnv.get.blockManager
blockManager.putSingle(blockId, "test", StorageLevel.MEMORY_ONLY) blockManager.putSingle(blockId, "test", StorageLevel.MEMORY_ONLY)
val blockRDD = new BlockRDD[String](sc, Array(blockId)) val blockRDD = new BlockRDD[String](sc, Array(blockId))
@ -191,7 +191,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
} }
test("CheckpointRDD with zero partitions") { test("CheckpointRDD with zero partitions") {
val rdd = new BlockRDD[Int](sc, Array[String]()) val rdd = new BlockRDD[Int](sc, Array[BlockId]())
assert(rdd.partitions.size === 0) assert(rdd.partitions.size === 0)
assert(rdd.isCheckpointed === false) assert(rdd.isCheckpointed === false)
rdd.checkpoint() rdd.checkpoint()

View file

@ -18,24 +18,14 @@
package org.apache.spark package org.apache.spark
import network.ConnectionManagerId import network.ConnectionManagerId
import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.Timeouts._ import org.scalatest.concurrent.Timeouts._
import org.scalatest.FunSuite
import org.scalatest.matchers.ShouldMatchers import org.scalatest.matchers.ShouldMatchers
import org.scalatest.prop.Checkers
import org.scalatest.time.{Span, Millis} import org.scalatest.time.{Span, Millis}
import org.scalacheck.Arbitrary._
import org.scalacheck.Gen
import org.scalacheck.Prop._
import org.eclipse.jetty.server.{Server, Request, Handler}
import com.google.common.io.Files
import scala.collection.mutable.ArrayBuffer
import SparkContext._ import SparkContext._
import storage.{GetBlock, BlockManagerWorker, StorageLevel} import org.apache.spark.storage.{BlockManagerWorker, GetBlock, RDDBlockId, StorageLevel}
import ui.JettyUtils
class NotSerializableClass class NotSerializableClass
@ -193,7 +183,7 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
// Get all the locations of the first partition and try to fetch the partitions // Get all the locations of the first partition and try to fetch the partitions
// from those locations. // from those locations.
val blockIds = data.partitions.indices.map(index => "rdd_%d_%d".format(data.id, index)).toArray val blockIds = data.partitions.indices.map(index => RDDBlockId(data.id, index)).toArray
val blockId = blockIds(0) val blockId = blockIds(0)
val blockManager = SparkEnv.get.blockManager val blockManager = SparkEnv.get.blockManager
blockManager.master.getLocations(blockId).foreach(id => { blockManager.master.getLocations(blockId).foreach(id => {

View file

@ -30,7 +30,7 @@ import org.apache.spark.Partition
import org.apache.spark.TaskContext import org.apache.spark.TaskContext
import org.apache.spark.{Dependency, ShuffleDependency, OneToOneDependency} import org.apache.spark.{Dependency, ShuffleDependency, OneToOneDependency}
import org.apache.spark.{FetchFailed, Success, TaskEndReason} import org.apache.spark.{FetchFailed, Success, TaskEndReason}
import org.apache.spark.storage.{BlockManagerId, BlockManagerMaster} import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster}
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
@ -75,15 +75,10 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]] val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]]
// stub out BlockManagerMaster.getLocations to use our cacheLocations // stub out BlockManagerMaster.getLocations to use our cacheLocations
val blockManagerMaster = new BlockManagerMaster(null) { val blockManagerMaster = new BlockManagerMaster(null) {
override def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = { override def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = {
blockIds.map { name => blockIds.map {
val pieces = name.split("_") _.asRDDId.map(id => (id.rddId -> id.splitIndex)).flatMap(key => cacheLocations.get(key)).
if (pieces(0) == "rdd") { getOrElse(Seq())
val key = pieces(1).toInt -> pieces(2).toInt
cacheLocations.getOrElse(key, Seq())
} else {
Seq()
}
}.toSeq }.toSeq
} }
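The chained Option calls in the stubbed getLocations above are dense, so here is an unrolled sketch of the same mapping, written as a standalone function purely for readability (behavior is intended to be identical):

    import org.apache.spark.storage.{BlockId, BlockManagerId}

    // For each block id: if it is an RDD block, look up its (rddId, splitIndex)
    // in cacheLocations; anything else (or a cache miss) maps to Seq().
    def stubbedLocations(
        blockIds: Array[BlockId],
        cacheLocations: collection.mutable.HashMap[(Int, Int), Seq[BlockManagerId]])
      : Seq[Seq[BlockManagerId]] = {
      blockIds.map { blockId =>
        blockId.asRDDId match {
          case Some(id) => cacheLocations.getOrElse(id.rddId -> id.splitIndex, Seq())
          case None => Seq()
        }
      }.toSeq
    }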
override def removeExecutor(execId: String) { override def removeExecutor(execId: String) {

View file

@ -23,6 +23,7 @@ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
import org.apache.spark.{LocalSparkContext, SparkContext, SparkEnv} import org.apache.spark.{LocalSparkContext, SparkContext, SparkEnv}
import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskResult} import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskResult}
import org.apache.spark.storage.TaskResultBlockId
/** /**
* Removes the TaskResult from the BlockManager before delegating to a normal TaskResultGetter. * Removes the TaskResult from the BlockManager before delegating to a normal TaskResultGetter.
@ -85,7 +86,7 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndA
val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x) val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x)
assert(result === 1.to(akkaFrameSize).toArray) assert(result === 1.to(akkaFrameSize).toArray)
val RESULT_BLOCK_ID = "taskresult_0" val RESULT_BLOCK_ID = TaskResultBlockId(0)
assert(sc.env.blockManager.master.getLocations(RESULT_BLOCK_ID).size === 0, assert(sc.env.blockManager.master.getLocations(RESULT_BLOCK_ID).size === 0,
"Expect result to be removed from the block manager.") "Expect result to be removed from the block manager.")
} }

View file

@ -0,0 +1,114 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.storage
import org.scalatest.FunSuite
class BlockIdSuite extends FunSuite {
def assertSame(id1: BlockId, id2: BlockId) {
assert(id1.name === id2.name)
assert(id1.hashCode === id2.hashCode)
assert(id1 === id2)
}
def assertDifferent(id1: BlockId, id2: BlockId) {
assert(id1.name != id2.name)
assert(id1.hashCode != id2.hashCode)
assert(id1 != id2)
}
test("test-bad-deserialization") {
try {
// Try to deserialize an invalid block id.
BlockId("myblock")
fail()
} catch {
case e: IllegalStateException => // OK
case _ => fail()
}
}
test("rdd") {
val id = RDDBlockId(1, 2)
assertSame(id, RDDBlockId(1, 2))
assertDifferent(id, RDDBlockId(1, 1))
assert(id.name === "rdd_1_2")
assert(id.asRDDId.get.rddId === 1)
assert(id.asRDDId.get.splitIndex === 2)
assert(id.isRDD)
assertSame(id, BlockId(id.toString))
}
test("shuffle") {
val id = ShuffleBlockId(1, 2, 3)
assertSame(id, ShuffleBlockId(1, 2, 3))
assertDifferent(id, ShuffleBlockId(3, 2, 3))
assert(id.name === "shuffle_1_2_3")
assert(id.asRDDId === None)
assert(id.shuffleId === 1)
assert(id.mapId === 2)
assert(id.reduceId === 3)
assert(id.isShuffle)
assertSame(id, BlockId(id.toString))
}
test("broadcast") {
val id = BroadcastBlockId(42)
assertSame(id, BroadcastBlockId(42))
assertDifferent(id, BroadcastBlockId(123))
assert(id.name === "broadcast_42")
assert(id.asRDDId === None)
assert(id.broadcastId === 42)
assert(id.isBroadcast)
assertSame(id, BlockId(id.toString))
}
test("taskresult") {
val id = TaskResultBlockId(60)
assertSame(id, TaskResultBlockId(60))
assertDifferent(id, TaskResultBlockId(61))
assert(id.name === "taskresult_60")
assert(id.asRDDId === None)
assert(id.taskId === 60)
assert(!id.isRDD)
assertSame(id, BlockId(id.toString))
}
test("stream") {
val id = StreamBlockId(1, 100)
assertSame(id, StreamBlockId(1, 100))
assertDifferent(id, StreamBlockId(2, 101))
assert(id.name === "input-1-100")
assert(id.asRDDId === None)
assert(id.streamId === 1)
assert(id.uniqueId === 100)
assert(!id.isBroadcast)
assertSame(id, BlockId(id.toString))
}
test("test") {
val id = TestBlockId("abc")
assertSame(id, TestBlockId("abc"))
assertDifferent(id, TestBlockId("ab"))
assert(id.name === "test_abc")
assert(id.asRDDId === None)
assert(id.id === "abc")
assert(!id.isShuffle)
assertSame(id, BlockId(id.toString))
}
}
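For readers who don't want to open BlockId.scala itself, the suite above pins down the string round-trip: each subclass encodes its fields into name, and BlockId(...) parses a name back, throwing IllegalStateException on anything it doesn't recognize. The following is a rough sketch consistent with those expectations; the regexes and numeric field widths are inferred from the tests and may differ from the actual implementation in this patch:

    import org.apache.spark.storage._

    // Sketch only -- not the parser added by this patch.
    object BlockIdParserSketch {
      private val RddRe = "rdd_([0-9]+)_([0-9]+)".r
      private val ShuffleRe = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
      private val BroadcastRe = "broadcast_([0-9]+)".r
      private val TaskResultRe = "taskresult_([0-9]+)".r
      private val StreamRe = "input-([0-9]+)-([0-9]+)".r
      private val TestRe = "test_(.*)".r

      def apply(name: String): BlockId = name match {
        case RddRe(rdd, split) => RDDBlockId(rdd.toInt, split.toInt)
        case ShuffleRe(shuffle, map, reduce) => ShuffleBlockId(shuffle.toInt, map.toInt, reduce.toInt)
        case BroadcastRe(broadcast) => BroadcastBlockId(broadcast.toLong)
        case TaskResultRe(task) => TaskResultBlockId(task.toLong)
        case StreamRe(stream, unique) => StreamBlockId(stream.toInt, unique.toLong)
        case TestRe(value) => TestBlockId(value)
        case _ => throw new IllegalStateException("Unrecognized BlockId: " + name)
      }
    }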

View file

@ -32,7 +32,6 @@ import org.scalatest.time.SpanSugar._
import org.apache.spark.util.{SizeEstimator, Utils, AkkaUtils, ByteBufferInputStream} import org.apache.spark.util.{SizeEstimator, Utils, AkkaUtils, ByteBufferInputStream}
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodTester { class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodTester {
var store: BlockManager = null var store: BlockManager = null
var store2: BlockManager = null var store2: BlockManager = null
@ -46,6 +45,10 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
System.setProperty("spark.kryoserializer.buffer.mb", "1") System.setProperty("spark.kryoserializer.buffer.mb", "1")
val serializer = new KryoSerializer val serializer = new KryoSerializer
// Implicitly convert strings to BlockIds for test clarity.
implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value)
def rdd(rddId: Int, splitId: Int) = RDDBlockId(rddId, splitId)
before { before {
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0) val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0)
this.actorSystem = actorSystem this.actorSystem = actorSystem
@ -229,31 +232,31 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
val a2 = new Array[Byte](400) val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400) val a3 = new Array[Byte](400)
// Putting a1, a2 and a3 in memory. // Putting a1, a2 and a3 in memory.
store.putSingle("rdd_0_0", a1, StorageLevel.MEMORY_ONLY) store.putSingle(rdd(0, 0), a1, StorageLevel.MEMORY_ONLY)
store.putSingle("rdd_0_1", a2, StorageLevel.MEMORY_ONLY) store.putSingle(rdd(0, 1), a2, StorageLevel.MEMORY_ONLY)
store.putSingle("nonrddblock", a3, StorageLevel.MEMORY_ONLY) store.putSingle("nonrddblock", a3, StorageLevel.MEMORY_ONLY)
master.removeRdd(0, blocking = false) master.removeRdd(0, blocking = false)
eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
store.getSingle("rdd_0_0") should be (None) store.getSingle(rdd(0, 0)) should be (None)
master.getLocations("rdd_0_0") should have size 0 master.getLocations(rdd(0, 0)) should have size 0
} }
eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
store.getSingle("rdd_0_1") should be (None) store.getSingle(rdd(0, 1)) should be (None)
master.getLocations("rdd_0_1") should have size 0 master.getLocations(rdd(0, 1)) should have size 0
} }
eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
store.getSingle("nonrddblock") should not be (None) store.getSingle("nonrddblock") should not be (None)
master.getLocations("nonrddblock") should have size (1) master.getLocations("nonrddblock") should have size (1)
} }
store.putSingle("rdd_0_0", a1, StorageLevel.MEMORY_ONLY) store.putSingle(rdd(0, 0), a1, StorageLevel.MEMORY_ONLY)
store.putSingle("rdd_0_1", a2, StorageLevel.MEMORY_ONLY) store.putSingle(rdd(0, 1), a2, StorageLevel.MEMORY_ONLY)
master.removeRdd(0, blocking = true) master.removeRdd(0, blocking = true)
store.getSingle("rdd_0_0") should be (None) store.getSingle(rdd(0, 0)) should be (None)
master.getLocations("rdd_0_0") should have size 0 master.getLocations(rdd(0, 0)) should have size 0
store.getSingle("rdd_0_1") should be (None) store.getSingle(rdd(0, 1)) should be (None)
master.getLocations("rdd_0_1") should have size 0 master.getLocations(rdd(0, 1)) should have size 0
} }
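The raw string "nonrddblock" in the test above keeps compiling only because of the implicit StringToBlockId conversion added at the top of this suite: any String passed where a BlockId is expected is wrapped in a TestBlockId (so its stored name is actually "test_nonrddblock", per the naming in BlockIdSuite). A quick sketch:

    import org.apache.spark.storage.{BlockId, TestBlockId}

    object ImplicitIdSketch {
      implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value)

      // Both values denote the same test block once the implicit is in scope.
      val viaImplicit: BlockId = "nonrddblock"
      val explicitly: BlockId = TestBlockId("nonrddblock")
    }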
test("reregistration on heart beat") { test("reregistration on heart beat") {
@ -372,41 +375,41 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
val a1 = new Array[Byte](400) val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400) val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400) val a3 = new Array[Byte](400)
store.putSingle("rdd_0_1", a1, StorageLevel.MEMORY_ONLY) store.putSingle(rdd(0, 1), a1, StorageLevel.MEMORY_ONLY)
store.putSingle("rdd_0_2", a2, StorageLevel.MEMORY_ONLY) store.putSingle(rdd(0, 2), a2, StorageLevel.MEMORY_ONLY)
store.putSingle("rdd_0_3", a3, StorageLevel.MEMORY_ONLY) store.putSingle(rdd(0, 3), a3, StorageLevel.MEMORY_ONLY)
// Even though we accessed rdd_0_3 last, it should not have replaced partitions 1 and 2 // Even though we accessed rdd_0_3 last, it should not have replaced partitions 1 and 2
// from the same RDD // from the same RDD
assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store") assert(store.getSingle(rdd(0, 3)) === None, "rdd_0_3 was in store")
assert(store.getSingle("rdd_0_2") != None, "rdd_0_2 was not in store") assert(store.getSingle(rdd(0, 2)) != None, "rdd_0_2 was not in store")
assert(store.getSingle("rdd_0_1") != None, "rdd_0_1 was not in store") assert(store.getSingle(rdd(0, 1)) != None, "rdd_0_1 was not in store")
// Check that rdd_0_3 doesn't replace them even after further accesses // Check that rdd_0_3 doesn't replace them even after further accesses
assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store") assert(store.getSingle(rdd(0, 3)) === None, "rdd_0_3 was in store")
assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store") assert(store.getSingle(rdd(0, 3)) === None, "rdd_0_3 was in store")
assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store") assert(store.getSingle(rdd(0, 3)) === None, "rdd_0_3 was in store")
} }
test("in-memory LRU for partitions of multiple RDDs") { test("in-memory LRU for partitions of multiple RDDs") {
store = new BlockManager("<driver>", actorSystem, master, serializer, 1200) store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
store.putSingle("rdd_0_1", new Array[Byte](400), StorageLevel.MEMORY_ONLY) store.putSingle(rdd(0, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
store.putSingle("rdd_0_2", new Array[Byte](400), StorageLevel.MEMORY_ONLY) store.putSingle(rdd(0, 2), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
store.putSingle("rdd_1_1", new Array[Byte](400), StorageLevel.MEMORY_ONLY) store.putSingle(rdd(1, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
// At this point rdd_1_1 should've replaced rdd_0_1 // At this point rdd_1_1 should've replaced rdd_0_1
assert(store.memoryStore.contains("rdd_1_1"), "rdd_1_1 was not in store") assert(store.memoryStore.contains(rdd(1, 1)), "rdd_1_1 was not in store")
assert(!store.memoryStore.contains("rdd_0_1"), "rdd_0_1 was in store") assert(!store.memoryStore.contains(rdd(0, 1)), "rdd_0_1 was in store")
assert(store.memoryStore.contains("rdd_0_2"), "rdd_0_2 was not in store") assert(store.memoryStore.contains(rdd(0, 2)), "rdd_0_2 was not in store")
// Do a get() on rdd_0_2 so that it is the most recently used item // Do a get() on rdd_0_2 so that it is the most recently used item
assert(store.getSingle("rdd_0_2") != None, "rdd_0_2 was not in store") assert(store.getSingle(rdd(0, 2)) != None, "rdd_0_2 was not in store")
// Put in more partitions from RDD 0; they should replace rdd_1_1 // Put in more partitions from RDD 0; they should replace rdd_1_1
store.putSingle("rdd_0_3", new Array[Byte](400), StorageLevel.MEMORY_ONLY) store.putSingle(rdd(0, 3), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
store.putSingle("rdd_0_4", new Array[Byte](400), StorageLevel.MEMORY_ONLY) store.putSingle(rdd(0, 4), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
// Now rdd_1_1 should be dropped to add rdd_0_3, but then rdd_0_2 should *not* be dropped // Now rdd_1_1 should be dropped to add rdd_0_3, but then rdd_0_2 should *not* be dropped
// when we try to add rdd_0_4. // when we try to add rdd_0_4.
assert(!store.memoryStore.contains("rdd_1_1"), "rdd_1_1 was in store") assert(!store.memoryStore.contains(rdd(1, 1)), "rdd_1_1 was in store")
assert(!store.memoryStore.contains("rdd_0_1"), "rdd_0_1 was in store") assert(!store.memoryStore.contains(rdd(0, 1)), "rdd_0_1 was in store")
assert(!store.memoryStore.contains("rdd_0_4"), "rdd_0_4 was in store") assert(!store.memoryStore.contains(rdd(0, 4)), "rdd_0_4 was in store")
assert(store.memoryStore.contains("rdd_0_2"), "rdd_0_2 was not in store") assert(store.memoryStore.contains(rdd(0, 2)), "rdd_0_2 was not in store")
assert(store.memoryStore.contains("rdd_0_3"), "rdd_0_3 was not in store") assert(store.memoryStore.contains(rdd(0, 3)), "rdd_0_3 was not in store")
} }
test("on-disk storage") { test("on-disk storage") {
@ -590,43 +593,46 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
try { try {
System.setProperty("spark.shuffle.compress", "true") System.setProperty("spark.shuffle.compress", "true")
store = new BlockManager("exec1", actorSystem, master, serializer, 2000) store = new BlockManager("exec1", actorSystem, master, serializer, 2000)
store.putSingle("shuffle_0_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize("shuffle_0_0_0") <= 100, "shuffle_0_0_0 was not compressed") assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) <= 100,
"shuffle_0_0_0 was not compressed")
store.stop() store.stop()
store = null store = null
System.setProperty("spark.shuffle.compress", "false") System.setProperty("spark.shuffle.compress", "false")
store = new BlockManager("exec2", actorSystem, master, serializer, 2000) store = new BlockManager("exec2", actorSystem, master, serializer, 2000)
store.putSingle("shuffle_0_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize("shuffle_0_0_0") >= 1000, "shuffle_0_0_0 was compressed") assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) >= 1000,
"shuffle_0_0_0 was compressed")
store.stop() store.stop()
store = null store = null
System.setProperty("spark.broadcast.compress", "true") System.setProperty("spark.broadcast.compress", "true")
store = new BlockManager("exec3", actorSystem, master, serializer, 2000) store = new BlockManager("exec3", actorSystem, master, serializer, 2000)
store.putSingle("broadcast_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize("broadcast_0") <= 100, "broadcast_0 was not compressed") assert(store.memoryStore.getSize(BroadcastBlockId(0)) <= 100,
"broadcast_0 was not compressed")
store.stop() store.stop()
store = null store = null
System.setProperty("spark.broadcast.compress", "false") System.setProperty("spark.broadcast.compress", "false")
store = new BlockManager("exec4", actorSystem, master, serializer, 2000) store = new BlockManager("exec4", actorSystem, master, serializer, 2000)
store.putSingle("broadcast_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize("broadcast_0") >= 1000, "broadcast_0 was compressed") assert(store.memoryStore.getSize(BroadcastBlockId(0)) >= 1000, "broadcast_0 was compressed")
store.stop() store.stop()
store = null store = null
System.setProperty("spark.rdd.compress", "true") System.setProperty("spark.rdd.compress", "true")
store = new BlockManager("exec5", actorSystem, master, serializer, 2000) store = new BlockManager("exec5", actorSystem, master, serializer, 2000)
store.putSingle("rdd_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize("rdd_0_0") <= 100, "rdd_0_0 was not compressed") assert(store.memoryStore.getSize(rdd(0, 0)) <= 100, "rdd_0_0 was not compressed")
store.stop() store.stop()
store = null store = null
System.setProperty("spark.rdd.compress", "false") System.setProperty("spark.rdd.compress", "false")
store = new BlockManager("exec6", actorSystem, master, serializer, 2000) store = new BlockManager("exec6", actorSystem, master, serializer, 2000)
store.putSingle("rdd_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize("rdd_0_0") >= 1000, "rdd_0_0 was compressed") assert(store.memoryStore.getSize(rdd(0, 0)) >= 1000, "rdd_0_0 was compressed")
store.stop() store.stop()
store = null store = null
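The BlockManagerSuite hunks above rely on the two helpers added at the top of the suite: an implicit String => TestBlockId conversion, so blocks without a dedicated type (e.g. "nonrddblock") still compile against the typed API, and an rdd(rddId, splitId) shorthand for RDDBlockId. A minimal standalone sketch of that pattern, with a toy BlockId hierarchy standing in for org.apache.spark.storage (the class shapes are taken from this diff; the name formats are assumptions, not quoted from Spark):

    object BlockIdSketch {
      sealed abstract class BlockId { def name: String }
      case class RDDBlockId(rddId: Int, splitId: Int) extends BlockId {
        def name = "rdd_" + rddId + "_" + splitId
      }
      case class TestBlockId(id: String) extends BlockId {
        def name = "test_" + id
      }

      // Same trick as the suite: bare strings become TestBlockIds, so calls like
      // putSingle("nonrddblock", ...) keep compiling against the typed API.
      implicit def stringToBlockId(value: String): BlockId = TestBlockId(value)
      def rdd(rddId: Int, splitId: Int) = RDDBlockId(rddId, splitId)

      def lookup(id: BlockId) = println("looking up " + id.name)

      def main(args: Array[String]) {
        lookup(rdd(0, 0))      // RDDBlockId(0, 0), i.e. the old "rdd_0_0"
        lookup("nonrddblock")  // implicitly TestBlockId("nonrddblock")
      }
    }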
View file
@ -30,10 +30,11 @@ import akka.actor._
import akka.pattern.ask import akka.pattern.ask
import akka.util.duration._ import akka.util.duration._
import akka.dispatch._ import akka.dispatch._
import org.apache.spark.storage.BlockId
private[streaming] sealed trait NetworkInputTrackerMessage private[streaming] sealed trait NetworkInputTrackerMessage
private[streaming] case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage private[streaming] case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage
private[streaming] case class AddBlocks(streamId: Int, blockIds: Seq[String], metadata: Any) extends NetworkInputTrackerMessage private[streaming] case class AddBlocks(streamId: Int, blockIds: Seq[BlockId], metadata: Any) extends NetworkInputTrackerMessage
private[streaming] case class DeregisterReceiver(streamId: Int, msg: String) extends NetworkInputTrackerMessage private[streaming] case class DeregisterReceiver(streamId: Int, msg: String) extends NetworkInputTrackerMessage
/** /**
@ -48,7 +49,7 @@ class NetworkInputTracker(
val networkInputStreamMap = Map(networkInputStreams.map(x => (x.id, x)): _*) val networkInputStreamMap = Map(networkInputStreams.map(x => (x.id, x)): _*)
val receiverExecutor = new ReceiverExecutor() val receiverExecutor = new ReceiverExecutor()
val receiverInfo = new HashMap[Int, ActorRef] val receiverInfo = new HashMap[Int, ActorRef]
val receivedBlockIds = new HashMap[Int, Queue[String]] val receivedBlockIds = new HashMap[Int, Queue[BlockId]]
val timeout = 5000.milliseconds val timeout = 5000.milliseconds
var currentTime: Time = null var currentTime: Time = null
@ -67,9 +68,9 @@ class NetworkInputTracker(
} }
/** Return all the blocks received from a receiver. */ /** Return all the blocks received from a receiver. */
def getBlockIds(receiverId: Int, time: Time): Array[String] = synchronized { def getBlockIds(receiverId: Int, time: Time): Array[BlockId] = synchronized {
val queue = receivedBlockIds.synchronized { val queue = receivedBlockIds.synchronized {
receivedBlockIds.getOrElse(receiverId, new Queue[String]()) receivedBlockIds.getOrElse(receiverId, new Queue[BlockId]())
} }
val result = queue.synchronized { val result = queue.synchronized {
queue.dequeueAll(x => true) queue.dequeueAll(x => true)
@ -92,7 +93,7 @@ class NetworkInputTracker(
case AddBlocks(streamId, blockIds, metadata) => { case AddBlocks(streamId, blockIds, metadata) => {
val tmp = receivedBlockIds.synchronized { val tmp = receivedBlockIds.synchronized {
if (!receivedBlockIds.contains(streamId)) { if (!receivedBlockIds.contains(streamId)) {
receivedBlockIds += ((streamId, new Queue[String])) receivedBlockIds += ((streamId, new Queue[BlockId]))
} }
receivedBlockIds(streamId) receivedBlockIds(streamId)
} }
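NetworkInputTracker's bookkeeping is now typed end to end: AddBlocks carries a Seq[BlockId] and the per-stream queues are Queue[BlockId]. A hedged sketch of that bookkeeping with the Akka plumbing stripped out (StreamBlockId stands in for the ids a receiver would report; only the shapes visible in this diff are assumed):

    import scala.collection.mutable.{HashMap, Queue}

    object TrackerSketch {
      case class StreamBlockId(streamId: Int, uniqueId: Long) {
        def name = "input-" + streamId + "-" + uniqueId
      }

      // streamId -> queue of typed block ids, drained when a batch is generated.
      val receivedBlockIds = new HashMap[Int, Queue[StreamBlockId]]

      def addBlocks(streamId: Int, blockIds: Seq[StreamBlockId]) {
        val queue = receivedBlockIds.getOrElseUpdate(streamId, new Queue[StreamBlockId])
        queue ++= blockIds
      }

      def getBlockIds(streamId: Int): Array[StreamBlockId] = {
        val queue = receivedBlockIds.getOrElse(streamId, new Queue[StreamBlockId])
        queue.dequeueAll(_ => true).toArray
      }

      def main(args: Array[String]) {
        addBlocks(0, Seq(StreamBlockId(0, 1L), StreamBlockId(0, 2L)))
        println(getBlockIds(0).map(_.name).mkString(", "))  // input-0-1, input-0-2
      }
    }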
View file
@ -31,7 +31,7 @@ import org.apache.spark.streaming.util.{RecurringTimer, SystemClock}
import org.apache.spark.streaming._ import org.apache.spark.streaming._
import org.apache.spark.{Logging, SparkEnv} import org.apache.spark.{Logging, SparkEnv}
import org.apache.spark.rdd.{RDD, BlockRDD} import org.apache.spark.rdd.{RDD, BlockRDD}
import org.apache.spark.storage.StorageLevel import org.apache.spark.storage.{BlockId, StorageLevel, StreamBlockId}
/** /**
* Abstract class for defining any InputDStream that has to start a receiver on worker * Abstract class for defining any InputDStream that has to start a receiver on worker
@ -69,7 +69,7 @@ abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : Streaming
val blockIds = ssc.networkInputTracker.getBlockIds(id, validTime) val blockIds = ssc.networkInputTracker.getBlockIds(id, validTime)
Some(new BlockRDD[T](ssc.sc, blockIds)) Some(new BlockRDD[T](ssc.sc, blockIds))
} else { } else {
Some(new BlockRDD[T](ssc.sc, Array[String]())) Some(new BlockRDD[T](ssc.sc, Array[BlockId]()))
} }
} }
} }
@ -77,7 +77,7 @@ abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : Streaming
private[streaming] sealed trait NetworkReceiverMessage private[streaming] sealed trait NetworkReceiverMessage
private[streaming] case class StopReceiver(msg: String) extends NetworkReceiverMessage private[streaming] case class StopReceiver(msg: String) extends NetworkReceiverMessage
private[streaming] case class ReportBlock(blockId: String, metadata: Any) extends NetworkReceiverMessage private[streaming] case class ReportBlock(blockId: BlockId, metadata: Any) extends NetworkReceiverMessage
private[streaming] case class ReportError(msg: String) extends NetworkReceiverMessage private[streaming] case class ReportError(msg: String) extends NetworkReceiverMessage
/** /**
@ -158,7 +158,7 @@ abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Log
/** /**
* Pushes a block (as an ArrayBuffer filled with data) into the block manager. * Pushes a block (as an ArrayBuffer filled with data) into the block manager.
*/ */
def pushBlock(blockId: String, arrayBuffer: ArrayBuffer[T], metadata: Any, level: StorageLevel) { def pushBlock(blockId: BlockId, arrayBuffer: ArrayBuffer[T], metadata: Any, level: StorageLevel) {
env.blockManager.put(blockId, arrayBuffer.asInstanceOf[ArrayBuffer[Any]], level) env.blockManager.put(blockId, arrayBuffer.asInstanceOf[ArrayBuffer[Any]], level)
actor ! ReportBlock(blockId, metadata) actor ! ReportBlock(blockId, metadata)
} }
@ -166,7 +166,7 @@ abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Log
/** /**
* Pushes a block (as bytes) into the block manager. * Pushes a block (as bytes) into the block manager.
*/ */
def pushBlock(blockId: String, bytes: ByteBuffer, metadata: Any, level: StorageLevel) { def pushBlock(blockId: BlockId, bytes: ByteBuffer, metadata: Any, level: StorageLevel) {
env.blockManager.putBytes(blockId, bytes, level) env.blockManager.putBytes(blockId, bytes, level)
actor ! ReportBlock(blockId, metadata) actor ! ReportBlock(blockId, metadata)
} }
@ -209,7 +209,7 @@ abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Log
class BlockGenerator(storageLevel: StorageLevel) class BlockGenerator(storageLevel: StorageLevel)
extends Serializable with Logging { extends Serializable with Logging {
case class Block(id: String, buffer: ArrayBuffer[T], metadata: Any = null) case class Block(id: BlockId, buffer: ArrayBuffer[T], metadata: Any = null)
val clock = new SystemClock() val clock = new SystemClock()
val blockInterval = System.getProperty("spark.streaming.blockInterval", "200").toLong val blockInterval = System.getProperty("spark.streaming.blockInterval", "200").toLong
@ -241,7 +241,7 @@ abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Log
val newBlockBuffer = currentBuffer val newBlockBuffer = currentBuffer
currentBuffer = new ArrayBuffer[T] currentBuffer = new ArrayBuffer[T]
if (newBlockBuffer.size > 0) { if (newBlockBuffer.size > 0) {
val blockId = "input-" + NetworkReceiver.this.streamId + "-" + (time - blockInterval) val blockId = StreamBlockId(NetworkReceiver.this.streamId, time - blockInterval)
val newBlock = new Block(blockId, newBlockBuffer) val newBlock = new Block(blockId, newBlockBuffer)
blocksForPushing.add(newBlock) blocksForPushing.add(newBlock)
} }
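In NetworkInputDStream the whole push path is now typed: BlockGenerator mints a StreamBlockId per block interval instead of concatenating "input-" + streamId + "-" + time, and that same id flows through pushBlock and ReportBlock. A hedged, self-contained sketch of that flow, with toy stand-ins for the BlockManager and the tracker actor (only the shapes visible in this diff are assumed):

    import scala.collection.mutable.ArrayBuffer

    object GeneratorSketch {
      case class StreamBlockId(streamId: Int, uniqueId: Long)
      case class Block(id: StreamBlockId, buffer: ArrayBuffer[String], metadata: Any = null)
      case class ReportBlock(blockId: StreamBlockId, metadata: Any)

      val blockInterval = 200L
      val streamId = 1

      // Stand-ins for env.blockManager.put(...) and `actor ! ReportBlock(...)`.
      def putInBlockManager(id: StreamBlockId, data: ArrayBuffer[String]) =
        println("stored " + data.size + " records under " + id)
      def notifyTracker(msg: ReportBlock) = println("reported " + msg.blockId)

      def pushBlock(blockId: StreamBlockId, buffer: ArrayBuffer[String], metadata: Any) {
        putInBlockManager(blockId, buffer)
        notifyTracker(ReportBlock(blockId, metadata))
      }

      def main(args: Array[String]) {
        val currentBuffer = ArrayBuffer("a", "b", "c")
        // Same naming rule as BlockGenerator above: one id per block-interval boundary.
        val block = Block(StreamBlockId(streamId, System.currentTimeMillis() - blockInterval), currentBuffer)
        pushBlock(block.id, block.buffer, block.metadata)
      }
    }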
View file
@ -18,7 +18,7 @@
package org.apache.spark.streaming.dstream package org.apache.spark.streaming.dstream
import org.apache.spark.Logging import org.apache.spark.Logging
import org.apache.spark.storage.StorageLevel import org.apache.spark.storage.{StorageLevel, StreamBlockId}
import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.StreamingContext
import java.net.InetSocketAddress import java.net.InetSocketAddress
@ -71,7 +71,7 @@ class RawNetworkReceiver(host: String, port: Int, storageLevel: StorageLevel)
var nextBlockNumber = 0 var nextBlockNumber = 0
while (true) { while (true) {
val buffer = queue.take() val buffer = queue.take()
val blockId = "input-" + streamId + "-" + nextBlockNumber val blockId = StreamBlockId(streamId, nextBlockNumber)
nextBlockNumber += 1 nextBlockNumber += 1
pushBlock(blockId, buffer, null, storageLevel) pushBlock(blockId, buffer, null, storageLevel)
} }
View file
@ -21,7 +21,7 @@ import akka.actor.{ Actor, PoisonPill, Props, SupervisorStrategy }
import akka.actor.{ actorRef2Scala, ActorRef } import akka.actor.{ actorRef2Scala, ActorRef }
import akka.actor.{ PossiblyHarmful, OneForOneStrategy } import akka.actor.{ PossiblyHarmful, OneForOneStrategy }
import org.apache.spark.storage.StorageLevel import org.apache.spark.storage.{StorageLevel, StreamBlockId}
import org.apache.spark.streaming.dstream.NetworkReceiver import org.apache.spark.streaming.dstream.NetworkReceiver
import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.atomic.AtomicInteger
@ -159,7 +159,7 @@ private[streaming] class ActorReceiver[T: ClassManifest](
protected def pushBlock(iter: Iterator[T]) { protected def pushBlock(iter: Iterator[T]) {
val buffer = new ArrayBuffer[T] val buffer = new ArrayBuffer[T]
buffer ++= iter buffer ++= iter
pushBlock("block-" + streamId + "-" + System.nanoTime(), buffer, null, storageLevel) pushBlock(StreamBlockId(streamId, System.nanoTime()), buffer, null, storageLevel)
} }
protected def onStart() = { protected def onStart() = {
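Taken together, the last three files replace three different hand-rolled string schemes with StreamBlockId: the interval-stamped ids of BlockGenerator, the counter-based ids of RawNetworkReceiver, and the nanosecond-stamped ids of ActorReceiver (which previously used a "block-" prefix). A hedged side-by-side of the three id sources (toy StreamBlockId; the name format is an assumption):

    object NamingSketch {
      case class StreamBlockId(streamId: Int, uniqueId: Long) {
        def name = "input-" + streamId + "-" + uniqueId
      }

      def main(args: Array[String]) {
        val streamId = 7
        val blockInterval = 200L
        var nextBlockNumber = 0

        // BlockGenerator (NetworkInputDStream): keyed by block-interval boundary.
        val byInterval = StreamBlockId(streamId, System.currentTimeMillis() - blockInterval)
        // RawNetworkReceiver: keyed by a per-receiver counter.
        val byCounter = StreamBlockId(streamId, nextBlockNumber)
        nextBlockNumber += 1
        // ActorReceiver: keyed by a nanosecond timestamp, effectively unique per push.
        val byNanos = StreamBlockId(streamId, System.nanoTime())

        Seq(byInterval, byCounter, byNanos).foreach(id => println(id.name))
      }
    }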