From bacade6caf7527737dc6f02b1c2ca9114e02d8bc Mon Sep 17 00:00:00 2001
From: Tathagata Das
Date: Tue, 22 Jan 2013 22:55:26 -0800
Subject: [PATCH] Modified BlockManagerId API to ensure zero duplicate objects.
 Fixed BlockManagerId test case in BlockManagerSuite.

---
 .../scala/spark/scheduler/MapStatus.scala     |  2 +-
 .../scala/spark/storage/BlockManager.scala    |  2 +-
 .../scala/spark/storage/BlockManagerId.scala  | 33 +++++++++++++++----
 .../spark/storage/BlockManagerMessages.scala  |  3 +-
 .../scala/spark/MapOutputTrackerSuite.scala   | 22 ++++++-------
 .../spark/storage/BlockManagerSuite.scala     | 18 +++++-----
 6 files changed, 51 insertions(+), 29 deletions(-)

diff --git a/core/src/main/scala/spark/scheduler/MapStatus.scala b/core/src/main/scala/spark/scheduler/MapStatus.scala
index 4532d9497f..fae643f3a8 100644
--- a/core/src/main/scala/spark/scheduler/MapStatus.scala
+++ b/core/src/main/scala/spark/scheduler/MapStatus.scala
@@ -20,7 +20,7 @@ private[spark] class MapStatus(var address: BlockManagerId, var compressedSizes:
   }
 
   def readExternal(in: ObjectInput) {
-    address = new BlockManagerId(in)
+    address = BlockManagerId(in)
     compressedSizes = new Array[Byte](in.readInt())
     in.readFully(compressedSizes)
   }
diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
index 7a8ac10cdd..596a69c583 100644
--- a/core/src/main/scala/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -69,7 +69,7 @@ class BlockManager(
   implicit val futureExecContext = connectionManager.futureExecContext
 
   val connectionManagerId = connectionManager.id
-  val blockManagerId = new BlockManagerId(connectionManagerId.host, connectionManagerId.port)
+  val blockManagerId = BlockManagerId(connectionManagerId.host, connectionManagerId.port)
 
   // TODO: This will be removed after cacheTracker is removed from the code base.
   var cacheTracker: CacheTracker = null
diff --git a/core/src/main/scala/spark/storage/BlockManagerId.scala b/core/src/main/scala/spark/storage/BlockManagerId.scala
index 488679f049..26c98f2ac8 100644
--- a/core/src/main/scala/spark/storage/BlockManagerId.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerId.scala
@@ -3,20 +3,35 @@ package spark.storage
 import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
 import java.util.concurrent.ConcurrentHashMap
 
+/**
+ * This class represents a unique identifier for a BlockManager.
+ * The first two constructors of this class are made private to ensure that
+ * BlockManagerId objects can be created only using the factory method in
+ * [[spark.storage.BlockManagerId$]]. This allows de-duplication of id objects.
+ * Also, constructor parameters are private to ensure that they cannot
+ * be modified from outside this class.
+ */
+private[spark] class BlockManagerId private (
+    private var ip_ : String,
+    private var port_ : Int
+  ) extends Externalizable {
+
+  private def this(in: ObjectInput) = this(in.readUTF(), in.readInt())
 
-private[spark] class BlockManagerId(var ip: String, var port: Int) extends Externalizable {
   def this() = this(null, 0)  // For deserialization only
 
-  def this(in: ObjectInput) = this(in.readUTF(), in.readInt())
+  def ip = ip_
+
+  def port = port_
 
   override def writeExternal(out: ObjectOutput) {
-    out.writeUTF(ip)
-    out.writeInt(port)
+    out.writeUTF(ip_)
+    out.writeInt(port_)
   }
 
   override def readExternal(in: ObjectInput) {
-    ip = in.readUTF()
-    port = in.readInt()
+    ip_ = in.readUTF()
+    port_ = in.readInt()
   }
 
   @throws(classOf[IOException])
@@ -35,6 +50,12 @@ private[spark] class BlockManagerId(var ip: String, var port: Int) extends Exter
 
 
 private[spark] object BlockManagerId {
+  def apply(ip: String, port: Int) =
+    getCachedBlockManagerId(new BlockManagerId(ip, port))
+
+  def apply(in: ObjectInput) =
+    getCachedBlockManagerId(new BlockManagerId(in))
+
   val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]()
 
   def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = {
diff --git a/core/src/main/scala/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/spark/storage/BlockManagerMessages.scala
index d73a9b790f..7437fc63eb 100644
--- a/core/src/main/scala/spark/storage/BlockManagerMessages.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerMessages.scala
@@ -54,8 +54,7 @@ class UpdateBlockInfo(
   }
 
   override def readExternal(in: ObjectInput) {
-    blockManagerId = new BlockManagerId()
-    blockManagerId.readExternal(in)
+    blockManagerId = BlockManagerId(in)
     blockId = in.readUTF()
     storageLevel = new StorageLevel()
     storageLevel.readExternal(in)
diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
index d3dd3a8fa4..095f415978 100644
--- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
@@ -47,13 +47,13 @@ class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter {
     val compressedSize1000 = MapOutputTracker.compressSize(1000L)
     val compressedSize10000 = MapOutputTracker.compressSize(10000L)
     val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
     val size10000 = MapOutputTracker.decompressSize(compressedSize10000)
-    tracker.registerMapOutput(10, 0, new MapStatus(new BlockManagerId("hostA", 1000),
+    tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("hostA", 1000),
         Array(compressedSize1000, compressedSize10000)))
-    tracker.registerMapOutput(10, 1, new MapStatus(new BlockManagerId("hostB", 1000),
+    tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("hostB", 1000),
         Array(compressedSize10000, compressedSize1000)))
     val statuses = tracker.getServerStatuses(10, 0)
-    assert(statuses.toSeq === Seq((new BlockManagerId("hostA", 1000), size1000),
-                                  (new BlockManagerId("hostB", 1000), size10000)))
+    assert(statuses.toSeq === Seq((BlockManagerId("hostA", 1000), size1000),
+                                  (BlockManagerId("hostB", 1000), size10000)))
     tracker.stop()
   }
 
@@ -65,14 +65,14 @@ class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter {
     val compressedSize10000 = MapOutputTracker.compressSize(10000L)
     val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
     val size10000 = MapOutputTracker.decompressSize(compressedSize10000)
-    tracker.registerMapOutput(10, 0, new MapStatus(new BlockManagerId("hostA", 1000),
+    tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("hostA", 1000),
         Array(compressedSize1000, compressedSize1000, compressedSize1000)))
-    tracker.registerMapOutput(10, 1, new MapStatus(new BlockManagerId("hostB", 1000),
+    tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("hostB", 1000),
         Array(compressedSize10000, compressedSize1000, compressedSize1000)))
 
     // As if we had two simultaneous fetch failures
-    tracker.unregisterMapOutput(10, 0, new BlockManagerId("hostA", 1000))
-    tracker.unregisterMapOutput(10, 0, new BlockManagerId("hostA", 1000))
+    tracker.unregisterMapOutput(10, 0, BlockManagerId("hostA", 1000))
+    tracker.unregisterMapOutput(10, 0, BlockManagerId("hostA", 1000))
 
     // The remaining reduce task might try to grab the output despite the shuffle failure;
     // this should cause it to fail, and the scheduler will ignore the failure due to the
@@ -95,13 +95,13 @@ class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter {
     val compressedSize1000 = MapOutputTracker.compressSize(1000L)
     val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
     masterTracker.registerMapOutput(10, 0, new MapStatus(
-      new BlockManagerId("hostA", 1000), Array(compressedSize1000)))
+      BlockManagerId("hostA", 1000), Array(compressedSize1000)))
     masterTracker.incrementGeneration()
     slaveTracker.updateGeneration(masterTracker.getGeneration)
     assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
-           Seq((new BlockManagerId("hostA", 1000), size1000)))
+           Seq((BlockManagerId("hostA", 1000), size1000)))
 
-    masterTracker.unregisterMapOutput(10, 0, new BlockManagerId("hostA", 1000))
+    masterTracker.unregisterMapOutput(10, 0, BlockManagerId("hostA", 1000))
     masterTracker.incrementGeneration()
     slaveTracker.updateGeneration(masterTracker.getGeneration)
     intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala
index 8f86e3170e..a33d3324ba 100644
--- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala
@@ -82,16 +82,18 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
   }
 
   test("BlockManagerId object caching") {
-    val id1 = new StorageLevel(false, false, false, 3)
-    val id2 = new StorageLevel(false, false, false, 3)
+    val id1 = BlockManagerId("XXX", 1)
+    val id2 = BlockManagerId("XXX", 1) // this should return the same object as id1
+    assert(id2 === id1, "id2 is not the same as id1")
+    assert(id2.eq(id1), "id2 is not the same object as id1")
     val bytes1 = spark.Utils.serialize(id1)
-    val id1_ = spark.Utils.deserialize[StorageLevel](bytes1)
+    val id1_ = spark.Utils.deserialize[BlockManagerId](bytes1)
     val bytes2 = spark.Utils.serialize(id2)
-    val id2_ = spark.Utils.deserialize[StorageLevel](bytes2)
-    assert(id1_ === id1, "Deserialized id1 not same as original id1")
-    assert(id2_ === id2, "Deserialized id2 not same as original id1")
-    assert(id1_ === id2_, "Deserialized id1 not same as deserialized id2")
-    assert(id2_.eq(id1_), "Deserialized id2 not the same object as deserialized level1")
+    val id2_ = spark.Utils.deserialize[BlockManagerId](bytes2)
+    assert(id1_ === id1, "Deserialized id1 is not the same as original id1")
+    assert(id1_.eq(id1), "Deserialized id1 is not the same object as original id1")
+    assert(id2_ === id2, "Deserialized id2 is not the same as original id2")
+    assert(id2_.eq(id1), "Deserialized id2 is not the same object as original id1")
   }
 
   test("master + 1 manager interaction") {
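
Note on the caching scheme: the body of getCachedBlockManagerId falls outside
the hunk context above. The sketch below is a minimal, self-contained
illustration of the interning pattern the patch relies on, not the actual
Spark implementation; the names InterningSketch, Id, and getCached are
hypothetical stand-ins. It shows why two equal ids end up reference-identical,
which is exactly what the new eq assertions in BlockManagerSuite verify.

import java.util.concurrent.ConcurrentHashMap

// Illustrative sketch only: `Id` stands in for BlockManagerId and `getCached`
// for getCachedBlockManagerId, whose real body is elided from the patch.
object InterningSketch {
  case class Id(ip: String, port: Int)

  // Maps each id to its canonical instance; key and value are the same object.
  private val cache = new ConcurrentHashMap[Id, Id]()

  // putIfAbsent returns the previously mapped value, or null if none existed,
  // so every caller ends up holding the single canonical instance for `id`.
  def getCached(id: Id): Id = {
    val prev = cache.putIfAbsent(id, id)
    if (prev == null) id else prev
  }

  def main(args: Array[String]): Unit = {
    val a = getCached(Id("hostA", 1000))
    val b = getCached(Id("hostA", 1000))
    assert(a eq b)  // equal ids are reference-identical after interning
  }
}

Because both apply overloads route construction through the cache,
deserialization paths such as MapStatus.readExternal and
UpdateBlockInfo.readExternal also return the canonical instance, so a
long-lived JVM holds at most one BlockManagerId per (ip, port) pair.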