Add unit test and address rest of Reynold's comments

This commit is contained in:
Aaron Davidson 2013-10-12 22:28:31 -07:00
parent a395911138
commit d60352283c
10 changed files with 144 additions and 20 deletions

View file

@ -36,7 +36,7 @@ class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> {
@Override @Override
public void messageReceived(ChannelHandlerContext ctx, String blockIdString) { public void messageReceived(ChannelHandlerContext ctx, String blockIdString) {
BlockId blockId = BlockId.fromString(blockIdString); BlockId blockId = BlockId.apply(blockIdString);
String path = pResolver.getAbsolutePath(blockId.filename()); String path = pResolver.getAbsolutePath(blockId.filename());
// if getFilePath returns null, close the channel // if getFilePath returns null, close the channel
if (path == null) { if (path == null) {

View file

@ -58,7 +58,7 @@ private[spark] object FileHeader {
for (i <- 1 to idLength) { for (i <- 1 to idLength) {
idBuilder += buf.readByte().asInstanceOf[Char] idBuilder += buf.readByte().asInstanceOf[Char]
} }
val blockId = BlockId.fromString(idBuilder.toString()) val blockId = BlockId(idBuilder.toString())
new FileHeader(length, blockId) new FileHeader(length, blockId)
} }

View file

@ -100,7 +100,7 @@ private[spark] object ShuffleCopier extends Logging {
} }
val host = args(0) val host = args(0)
val port = args(1).toInt val port = args(1).toInt
val blockId = BlockId.fromString(args(2)) val blockId = BlockId(args(2))
val threads = if (args.length > 3) args(3).toInt else 10 val threads = if (args.length > 3) args(3).toInt else 10
val copiers = Executors.newFixedThreadPool(80) val copiers = Executors.newFixedThreadPool(80)

View file

@ -55,7 +55,7 @@ private[spark] object ShuffleSender {
val pResovler = new PathResolver { val pResovler = new PathResolver {
override def getAbsolutePath(blockIdString: String): String = { override def getAbsolutePath(blockIdString: String): String = {
val blockId = BlockId.fromString(blockIdString) val blockId = BlockId(blockIdString)
if (!blockId.isShuffle) { if (!blockId.isShuffle) {
throw new Exception("Block " + blockId + " is not a shuffle block") throw new Exception("Block " + blockId + " is not a shuffle block")
} }

View file

@ -42,28 +42,29 @@ private[spark] abstract class BlockId {
} }
} }
case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId { private[spark] case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId {
def filename = "rdd_" + rddId + "_" + splitIndex def filename = "rdd_" + rddId + "_" + splitIndex
} }
private[spark]
case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
def filename = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId def filename = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
} }
case class BroadcastBlockId(broadcastId: Long) extends BlockId { private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId {
def filename = "broadcast_" + broadcastId def filename = "broadcast_" + broadcastId
} }
case class TaskResultBlockId(taskId: Long) extends BlockId { private[spark] case class TaskResultBlockId(taskId: Long) extends BlockId {
def filename = "taskresult_" + taskId def filename = "taskresult_" + taskId
} }
case class StreamBlockId(streamId: Int, uniqueId: Long) extends BlockId { private[spark] case class StreamBlockId(streamId: Int, uniqueId: Long) extends BlockId {
def filename = "input-" + streamId + "-" + uniqueId def filename = "input-" + streamId + "-" + uniqueId
} }
// Intended only for testing purposes // Intended only for testing purposes
case class TestBlockId(id: String) extends BlockId { private[spark] case class TestBlockId(id: String) extends BlockId {
def filename = "test_" + id def filename = "test_" + id
} }
@ -76,7 +77,7 @@ private[spark] object BlockId {
val StreamInput = "input-([0-9]+)-([0-9]+)".r val StreamInput = "input-([0-9]+)-([0-9]+)".r
val Test = "test_(.*)".r val Test = "test_(.*)".r
def fromString(id: String) = id match { def apply(id: String) = id match {
case RDD(rddId, splitIndex) => RDDBlockId(rddId.toInt, splitIndex.toInt) case RDD(rddId, splitIndex) => RDDBlockId(rddId.toInt, splitIndex.toInt)
case Shuffle(shuffleId, mapId, reduceId) => case Shuffle(shuffleId, mapId, reduceId) =>
ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt) ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)

View file

@ -70,7 +70,7 @@ private[storage] object BlockManagerMessages {
override def readExternal(in: ObjectInput) { override def readExternal(in: ObjectInput) {
blockManagerId = BlockManagerId(in) blockManagerId = BlockManagerId(in)
blockId = BlockId.fromString(in.readUTF()) blockId = BlockId(in.readUTF())
storageLevel = StorageLevel(in) storageLevel = StorageLevel(in)
memSize = in.readLong() memSize = in.readLong()
diskSize = in.readLong() diskSize = in.readLong()

View file

@ -74,7 +74,7 @@ private[spark] class BlockMessage() {
for (i <- 1 to idLength) { for (i <- 1 to idLength) {
idBuilder += buffer.getChar() idBuilder += buffer.getChar()
} }
id = BlockId.fromString(idBuilder.toString) id = BlockId(idBuilder.toString)
if (typ == BlockMessage.TYPE_PUT_BLOCK) { if (typ == BlockMessage.TYPE_PUT_BLOCK) {
@ -117,7 +117,7 @@ private[spark] class BlockMessage() {
def toBufferMessage: BufferMessage = { def toBufferMessage: BufferMessage = {
val startTime = System.currentTimeMillis val startTime = System.currentTimeMillis
val buffers = new ArrayBuffer[ByteBuffer]() val buffers = new ArrayBuffer[ByteBuffer]()
var buffer = ByteBuffer.allocate(4 + 4 + id.filename.length * 2) // TODO: Why x2? var buffer = ByteBuffer.allocate(4 + 4 + id.filename.length * 2)
buffer.putInt(typ).putInt(id.filename.length) buffer.putInt(typ).putInt(id.filename.length)
id.filename.foreach((x: Char) => buffer.putChar(x)) id.filename.foreach((x: Char) => buffer.putChar(x))
buffer.flip() buffer.flip()
@ -201,8 +201,8 @@ private[spark] object BlockMessage {
def main(args: Array[String]) { def main(args: Array[String]) {
val B = new BlockMessage() val B = new BlockMessage()
B.set(new PutBlock( val blockId = TestBlockId("ABC")
new TestBlockId("ABC"), ByteBuffer.allocate(10), StorageLevel.MEMORY_AND_DISK_SER_2)) B.set(new PutBlock(blockId, ByteBuffer.allocate(10), StorageLevel.MEMORY_AND_DISK_SER_2))
val bMsg = B.toBufferMessage val bMsg = B.toBufferMessage
val C = new BlockMessage() val C = new BlockMessage()
C.set(bMsg) C.set(bMsg)

View file

@ -316,7 +316,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
private[storage] def startShuffleBlockSender(port: Int): Int = { private[storage] def startShuffleBlockSender(port: Int): Int = {
val pResolver = new PathResolver { val pResolver = new PathResolver {
override def getAbsolutePath(blockIdString: String): String = { override def getAbsolutePath(blockIdString: String): String = {
val blockId = BlockId.fromString(blockIdString) val blockId = BlockId(blockIdString)
if (!blockId.isShuffle) null if (!blockId.isShuffle) null
else DiskStore.this.getFile(blockId).getAbsolutePath else DiskStore.this.getFile(blockId).getAbsolutePath
} }

View file

@ -25,15 +25,15 @@ private[spark]
case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long, case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long,
blocks: Map[BlockId, BlockStatus]) { blocks: Map[BlockId, BlockStatus]) {
def memUsed() = blocks.values.map(_.memSize).reduceOption(_+_).getOrElse(0l) def memUsed() = blocks.values.map(_.memSize).reduceOption(_+_).getOrElse(0L)
def memUsedByRDD(rddId: Int) = def memUsedByRDD(rddId: Int) =
rddBlocks.filterKeys(_.rddId == rddId).values.map(_.memSize).reduceOption(_+_).getOrElse(0l) rddBlocks.filterKeys(_.rddId == rddId).values.map(_.memSize).reduceOption(_+_).getOrElse(0L)
def diskUsed() = blocks.values.map(_.diskSize).reduceOption(_+_).getOrElse(0l) def diskUsed() = blocks.values.map(_.diskSize).reduceOption(_+_).getOrElse(0L)
def diskUsedByRDD(rddId: Int) = def diskUsedByRDD(rddId: Int) =
rddBlocks.filterKeys(_.rddId == rddId).values.map(_.diskSize).reduceOption(_+_).getOrElse(0l) rddBlocks.filterKeys(_.rddId == rddId).values.map(_.diskSize).reduceOption(_+_).getOrElse(0L)
def memRemaining : Long = maxMem - memUsed() def memRemaining : Long = maxMem - memUsed()

View file

@ -0,0 +1,123 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.storage
import org.scalatest.FunSuite
/**
 * Unit tests for [[BlockId]] and its concrete case-class variants.
 *
 * Each variant is checked for equality semantics, its string form, its
 * variant-specific accessors and type predicates, and for a round trip through
 * `BlockId(String)`.
 */
class BlockIdSuite extends FunSuite {

  /**
   * Asserts that two ids are interchangeable: same filename, same string form,
   * same hash code, and equal to each other.
   */
  def assertSame(id1: BlockId, id2: BlockId): Unit = {
    assert(id1.filename === id2.filename)
    assert(id1.toString === id2.toString)
    assert(id1.hashCode === id2.hashCode)
    assert(id1 === id2)
  }

  /**
   * Asserts that two ids are distinguishable by filename, string form, hash code
   * and equality.
   *
   * NOTE: distinct hash codes are not guaranteed for unequal objects in general;
   * this check relies on the sample ids used in this suite not colliding.
   */
  def assertDifferent(id1: BlockId, id2: BlockId): Unit = {
    assert(id1.filename != id2.filename)
    assert(id1.toString != id2.toString)
    assert(id1.hashCode != id2.hashCode)
    assert(id1 != id2)
  }

  test("basic-functions") {
    // Minimal ad-hoc subclass used to exercise the behaviour inherited from BlockId.
    case class MyBlockId(filename: String) extends BlockId

    val id = MyBlockId("a")
    assertSame(id, MyBlockId("a"))
    assertDifferent(id, MyBlockId("b"))
    assert(id.asRDDId === None)

    // Try to deserialize an invalid block id. `intercept` fails the test when no
    // exception, or an exception of a different type, is thrown, and it avoids a
    // catch-all handler that would also swallow fatal errors.
    intercept[IllegalStateException] {
      BlockId("a")
    }
  }

  test("rdd") {
    val id = RDDBlockId(1, 2)
    assertSame(id, RDDBlockId(1, 2))
    assertDifferent(id, RDDBlockId(1, 1))
    assert(id.toString === "rdd_1_2")
    assert(id.asRDDId.get.rddId === 1)
    assert(id.asRDDId.get.splitIndex === 2)
    assert(id.isRDD)
    // Parsing the string form must yield an equal id.
    assertSame(id, BlockId(id.toString))
  }

  test("shuffle") {
    val id = ShuffleBlockId(1, 2, 3)
    assertSame(id, ShuffleBlockId(1, 2, 3))
    assertDifferent(id, ShuffleBlockId(3, 2, 3))
    assert(id.toString === "shuffle_1_2_3")
    assert(id.asRDDId === None)
    assert(id.shuffleId === 1)
    assert(id.mapId === 2)
    assert(id.reduceId === 3)
    assert(id.isShuffle)
    // Parsing the string form must yield an equal id.
    assertSame(id, BlockId(id.toString))
  }

  test("broadcast") {
    val id = BroadcastBlockId(42)
    assertSame(id, BroadcastBlockId(42))
    assertDifferent(id, BroadcastBlockId(123))
    assert(id.toString === "broadcast_42")
    assert(id.asRDDId === None)
    assert(id.broadcastId === 42)
    assert(id.isBroadcast)
    // Parsing the string form must yield an equal id.
    assertSame(id, BlockId(id.toString))
  }

  test("taskresult") {
    val id = TaskResultBlockId(60)
    assertSame(id, TaskResultBlockId(60))
    assertDifferent(id, TaskResultBlockId(61))
    assert(id.toString === "taskresult_60")
    assert(id.asRDDId === None)
    assert(id.taskId === 60)
    assert(!id.isRDD)
    // Parsing the string form must yield an equal id.
    assertSame(id, BlockId(id.toString))
  }

  test("stream") {
    val id = StreamBlockId(1, 100)
    assertSame(id, StreamBlockId(1, 100))
    assertDifferent(id, StreamBlockId(2, 101))
    assert(id.toString === "input-1-100")
    assert(id.asRDDId === None)
    assert(id.streamId === 1)
    assert(id.uniqueId === 100)
    assert(!id.isBroadcast)
    // Parsing the string form must yield an equal id.
    assertSame(id, BlockId(id.toString))
  }

  test("test") {
    val id = TestBlockId("abc")
    assertSame(id, TestBlockId("abc"))
    assertDifferent(id, TestBlockId("ab"))
    assert(id.toString === "test_abc")
    assert(id.asRDDId === None)
    assert(id.id === "abc")
    assert(!id.isShuffle)
    // Parsing the string form must yield an equal id.
    assertSame(id, BlockId(id.toString))
  }
}