Add unit test and address rest of Reynold's comments

This commit is contained in:
Aaron Davidson 2013-10-12 22:28:31 -07:00
parent a395911138
commit d60352283c
10 changed files with 144 additions and 20 deletions

View file

@ -36,7 +36,7 @@ class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> {
@Override
public void messageReceived(ChannelHandlerContext ctx, String blockIdString) {
BlockId blockId = BlockId.fromString(blockIdString);
BlockId blockId = BlockId.apply(blockIdString);
String path = pResolver.getAbsolutePath(blockId.filename());
// if getFilePath returns null, close the channel
if (path == null) {

View file

@ -58,7 +58,7 @@ private[spark] object FileHeader {
for (i <- 1 to idLength) {
idBuilder += buf.readByte().asInstanceOf[Char]
}
val blockId = BlockId.fromString(idBuilder.toString())
val blockId = BlockId(idBuilder.toString())
new FileHeader(length, blockId)
}

View file

@ -100,7 +100,7 @@ private[spark] object ShuffleCopier extends Logging {
}
val host = args(0)
val port = args(1).toInt
val blockId = BlockId.fromString(args(2))
val blockId = BlockId(args(2))
val threads = if (args.length > 3) args(3).toInt else 10
val copiers = Executors.newFixedThreadPool(80)

View file

@ -55,7 +55,7 @@ private[spark] object ShuffleSender {
val pResovler = new PathResolver {
override def getAbsolutePath(blockIdString: String): String = {
val blockId = BlockId.fromString(blockIdString)
val blockId = BlockId(blockIdString)
if (!blockId.isShuffle) {
throw new Exception("Block " + blockId + " is not a shuffle block")
}

View file

@ -42,28 +42,29 @@ private[spark] abstract class BlockId {
}
}
case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId {
private[spark] case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId {
def filename = "rdd_" + rddId + "_" + splitIndex
}
private[spark]
case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
def filename = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
}
case class BroadcastBlockId(broadcastId: Long) extends BlockId {
private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId {
def filename = "broadcast_" + broadcastId
}
case class TaskResultBlockId(taskId: Long) extends BlockId {
private[spark] case class TaskResultBlockId(taskId: Long) extends BlockId {
def filename = "taskresult_" + taskId
}
case class StreamBlockId(streamId: Int, uniqueId: Long) extends BlockId {
private[spark] case class StreamBlockId(streamId: Int, uniqueId: Long) extends BlockId {
def filename = "input-" + streamId + "-" + uniqueId
}
// Intended only for testing purposes
case class TestBlockId(id: String) extends BlockId {
private[spark] case class TestBlockId(id: String) extends BlockId {
def filename = "test_" + id
}
@ -76,7 +77,7 @@ private[spark] object BlockId {
val StreamInput = "input-([0-9]+)-([0-9]+)".r
val Test = "test_(.*)".r
def fromString(id: String) = id match {
def apply(id: String) = id match {
case RDD(rddId, splitIndex) => RDDBlockId(rddId.toInt, splitIndex.toInt)
case Shuffle(shuffleId, mapId, reduceId) =>
ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)

View file

@ -70,7 +70,7 @@ private[storage] object BlockManagerMessages {
override def readExternal(in: ObjectInput) {
blockManagerId = BlockManagerId(in)
blockId = BlockId.fromString(in.readUTF())
blockId = BlockId(in.readUTF())
storageLevel = StorageLevel(in)
memSize = in.readLong()
diskSize = in.readLong()

View file

@ -74,7 +74,7 @@ private[spark] class BlockMessage() {
for (i <- 1 to idLength) {
idBuilder += buffer.getChar()
}
id = BlockId.fromString(idBuilder.toString)
id = BlockId(idBuilder.toString)
if (typ == BlockMessage.TYPE_PUT_BLOCK) {
@ -117,7 +117,7 @@ private[spark] class BlockMessage() {
def toBufferMessage: BufferMessage = {
val startTime = System.currentTimeMillis
val buffers = new ArrayBuffer[ByteBuffer]()
var buffer = ByteBuffer.allocate(4 + 4 + id.filename.length * 2) // TODO: Why x2?
var buffer = ByteBuffer.allocate(4 + 4 + id.filename.length * 2)
buffer.putInt(typ).putInt(id.filename.length)
id.filename.foreach((x: Char) => buffer.putChar(x))
buffer.flip()
@ -201,8 +201,8 @@ private[spark] object BlockMessage {
def main(args: Array[String]) {
val B = new BlockMessage()
B.set(new PutBlock(
new TestBlockId("ABC"), ByteBuffer.allocate(10), StorageLevel.MEMORY_AND_DISK_SER_2))
val blockId = TestBlockId("ABC")
B.set(new PutBlock(blockId, ByteBuffer.allocate(10), StorageLevel.MEMORY_AND_DISK_SER_2))
val bMsg = B.toBufferMessage
val C = new BlockMessage()
C.set(bMsg)

View file

@ -316,7 +316,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
private[storage] def startShuffleBlockSender(port: Int): Int = {
val pResolver = new PathResolver {
override def getAbsolutePath(blockIdString: String): String = {
val blockId = BlockId.fromString(blockIdString)
val blockId = BlockId(blockIdString)
if (!blockId.isShuffle) null
else DiskStore.this.getFile(blockId).getAbsolutePath
}

View file

@ -25,15 +25,15 @@ private[spark]
case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long,
blocks: Map[BlockId, BlockStatus]) {
def memUsed() = blocks.values.map(_.memSize).reduceOption(_+_).getOrElse(0l)
def memUsed() = blocks.values.map(_.memSize).reduceOption(_+_).getOrElse(0L)
def memUsedByRDD(rddId: Int) =
rddBlocks.filterKeys(_.rddId == rddId).values.map(_.memSize).reduceOption(_+_).getOrElse(0l)
rddBlocks.filterKeys(_.rddId == rddId).values.map(_.memSize).reduceOption(_+_).getOrElse(0L)
def diskUsed() = blocks.values.map(_.diskSize).reduceOption(_+_).getOrElse(0l)
def diskUsed() = blocks.values.map(_.diskSize).reduceOption(_+_).getOrElse(0L)
def diskUsedByRDD(rddId: Int) =
rddBlocks.filterKeys(_.rddId == rddId).values.map(_.diskSize).reduceOption(_+_).getOrElse(0l)
rddBlocks.filterKeys(_.rddId == rddId).values.map(_.diskSize).reduceOption(_+_).getOrElse(0L)
def memRemaining : Long = maxMem - memUsed()

View file

@ -0,0 +1,123 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.storage
import org.scalatest.FunSuite
class BlockIdSuite extends FunSuite {
def assertSame(id1: BlockId, id2: BlockId) {
assert(id1.filename === id2.filename)
assert(id1.toString === id2.toString)
assert(id1.hashCode === id2.hashCode)
assert(id1 === id2)
}
def assertDifferent(id1: BlockId, id2: BlockId) {
assert(id1.filename != id2.filename)
assert(id1.toString != id2.toString)
assert(id1.hashCode != id2.hashCode)
assert(id1 != id2)
}
test("basic-functions") {
case class MyBlockId(filename: String) extends BlockId
val id = MyBlockId("a")
assertSame(id, MyBlockId("a"))
assertDifferent(id, MyBlockId("b"))
assert(id.asRDDId === None)
try {
// Try to deserialize an invalid block id.
BlockId("a")
fail()
} catch {
case e: IllegalStateException => // OK
case _ => fail()
}
}
test("rdd") {
val id = RDDBlockId(1, 2)
assertSame(id, RDDBlockId(1, 2))
assertDifferent(id, RDDBlockId(1, 1))
assert(id.toString === "rdd_1_2")
assert(id.asRDDId.get.rddId === 1)
assert(id.asRDDId.get.splitIndex === 2)
assert(id.isRDD)
assertSame(id, BlockId(id.toString))
}
test("shuffle") {
val id = ShuffleBlockId(1, 2, 3)
assertSame(id, ShuffleBlockId(1, 2, 3))
assertDifferent(id, ShuffleBlockId(3, 2, 3))
assert(id.toString === "shuffle_1_2_3")
assert(id.asRDDId === None)
assert(id.shuffleId === 1)
assert(id.mapId === 2)
assert(id.reduceId === 3)
assert(id.isShuffle)
assertSame(id, BlockId(id.toString))
}
test("broadcast") {
val id = BroadcastBlockId(42)
assertSame(id, BroadcastBlockId(42))
assertDifferent(id, BroadcastBlockId(123))
assert(id.toString === "broadcast_42")
assert(id.asRDDId === None)
assert(id.broadcastId === 42)
assert(id.isBroadcast)
assertSame(id, BlockId(id.toString))
}
test("taskresult") {
val id = TaskResultBlockId(60)
assertSame(id, TaskResultBlockId(60))
assertDifferent(id, TaskResultBlockId(61))
assert(id.toString === "taskresult_60")
assert(id.asRDDId === None)
assert(id.taskId === 60)
assert(!id.isRDD)
assertSame(id, BlockId(id.toString))
}
test("stream") {
val id = StreamBlockId(1, 100)
assertSame(id, StreamBlockId(1, 100))
assertDifferent(id, StreamBlockId(2, 101))
assert(id.toString === "input-1-100")
assert(id.asRDDId === None)
assert(id.streamId === 1)
assert(id.uniqueId === 100)
assert(!id.isBroadcast)
assertSame(id, BlockId(id.toString))
}
test("test") {
val id = TestBlockId("abc")
assertSame(id, TestBlockId("abc"))
assertDifferent(id, TestBlockId("ab"))
assert(id.toString === "test_abc")
assert(id.asRDDId === None)
assert(id.id === "abc")
assert(!id.isShuffle)
assertSame(id, BlockId(id.toString))
}
}