From e463ae492068d2922e1d50c051a87f8010953dff Mon Sep 17 00:00:00 2001
From: Tathagata Das
Date: Wed, 28 Nov 2012 14:05:01 -0800
Subject: [PATCH] Modified StorageLevel and BlockManagerId to cache common
 objects and use cached object while deserializing.

---
 .../scala/spark/storage/BlockManager.scala    | 28 +----------
 .../scala/spark/storage/BlockManagerId.scala  | 48 +++++++++++++++++++
 .../scala/spark/storage/StorageLevel.scala    | 28 ++++++++++-
 .../spark/storage/BlockManagerSuite.scala     | 26 ++++++++++
 4 files changed, 101 insertions(+), 29 deletions(-)
 create mode 100644 core/src/main/scala/spark/storage/BlockManagerId.scala

diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
index 70d6d8369d..e4aa9247a3 100644
--- a/core/src/main/scala/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -20,33 +20,7 @@ import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}
 
 import sun.nio.ch.DirectBuffer
 
-private[spark] class BlockManagerId(var ip: String, var port: Int) extends Externalizable {
-  def this() = this(null, 0)  // For deserialization only
-
-  def this(in: ObjectInput) = this(in.readUTF(), in.readInt())
-
-  override def writeExternal(out: ObjectOutput) {
-    out.writeUTF(ip)
-    out.writeInt(port)
-  }
-
-  override def readExternal(in: ObjectInput) {
-    ip = in.readUTF()
-    port = in.readInt()
-  }
-
-  override def toString = "BlockManagerId(" + ip + ", " + port + ")"
-
-  override def hashCode = ip.hashCode * 41 + port
-
-  override def equals(that: Any) = that match {
-    case id: BlockManagerId => port == id.port && ip == id.ip
-    case _ => false
-  }
-}
-
-
-private[spark]
+private[spark]
 case class BlockException(blockId: String, message: String, ex: Exception = null)
 extends Exception(message)
 
diff --git a/core/src/main/scala/spark/storage/BlockManagerId.scala b/core/src/main/scala/spark/storage/BlockManagerId.scala
new file mode 100644
index 0000000000..4933cc6606
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockManagerId.scala
@@ -0,0 +1,48 @@
+package spark.storage
+
+import java.io.{IOException, ObjectOutput, ObjectInput, Externalizable}
+import java.util.concurrent.ConcurrentHashMap
+
+private[spark] class BlockManagerId(var ip: String, var port: Int) extends Externalizable {
+  def this() = this(null, 0)  // For deserialization only
+
+  def this(in: ObjectInput) = this(in.readUTF(), in.readInt())
+
+  override def writeExternal(out: ObjectOutput) {
+    out.writeUTF(ip)
+    out.writeInt(port)
+  }
+
+  override def readExternal(in: ObjectInput) {
+    ip = in.readUTF()
+    port = in.readInt()
+  }
+
+  @throws(classOf[IOException])
+  private def readResolve(): Object = {
+    BlockManagerId.getCachedBlockManagerId(this)
+  }
+
+
+  override def toString = "BlockManagerId(" + ip + ", " + port + ")"
+
+  override def hashCode = ip.hashCode * 41 + port
+
+  override def equals(that: Any) = that match {
+    case id: BlockManagerId => port == id.port && ip == id.ip
+    case _ => false
+  }
+}
+
+object BlockManagerId {
+  val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]()
+
+  def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = {
+    if (blockManagerIdCache.containsKey(id)) {
+      blockManagerIdCache.get(id)
+    } else {
+      blockManagerIdCache.put(id, id)
+      id
+    }
+  }
+}
\ No newline at end of file
diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala
index c497f03e0c..eb88eb2759 100644
--- a/core/src/main/scala/spark/storage/StorageLevel.scala
+++ b/core/src/main/scala/spark/storage/StorageLevel.scala
@@ -1,6 +1,9 @@
 package spark.storage
 
-import java.io.{Externalizable, ObjectInput, ObjectOutput}
+import java.io.{IOException, Externalizable, ObjectInput, ObjectOutput}
+import collection.mutable
+import util.Random
+import collection.mutable.ArrayBuffer
 
 /**
  * Flags for controlling the storage of an RDD. Each StorageLevel records whether to use memory,
@@ -17,7 +20,8 @@ class StorageLevel(
   extends Externalizable {
 
   // TODO: Also add fields for caching priority, dataset ID, and flushing.
-
+  assert(replication < 40, "Replication restricted to be less than 40 for calculating hashcodes")
+
   def this(flags: Int, replication: Int) {
     this((flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0, replication)
   }
@@ -27,6 +31,10 @@ class StorageLevel(
   override def clone(): StorageLevel = new StorageLevel(
     this.useDisk, this.useMemory, this.deserialized, this.replication)
 
+  override def hashCode(): Int = {
+    toInt * 41 + replication
+  }
+
   override def equals(other: Any): Boolean = other match {
     case s: StorageLevel =>
       s.useDisk == useDisk &&
@@ -66,6 +74,11 @@ class StorageLevel(
     replication = in.readByte()
   }
 
+  @throws(classOf[IOException])
+  private def readResolve(): Object = {
+    StorageLevel.getCachedStorageLevel(this)
+  }
+
   override def toString: String =
     "StorageLevel(%b, %b, %b, %d)".format(useDisk, useMemory, deserialized, replication)
 }
@@ -82,4 +95,15 @@ object StorageLevel {
   val MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2)
   val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false)
   val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2)
+
+  val storageLevelCache = new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]()
+
+  def getCachedStorageLevel(level: StorageLevel): StorageLevel = {
+    if (storageLevelCache.containsKey(level)) {
+      storageLevelCache.get(level)
+    } else {
+      storageLevelCache.put(level, level)
+      level
+    }
+  }
 }
diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala
index 0e78228134..a2d5e39859 100644
--- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala
@@ -57,6 +57,32 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
     }
   }
 
+  test("StorageLevel object caching") {
+    val level1 = new StorageLevel(false, false, false, 3)
+    val level2 = new StorageLevel(false, false, false, 3)
+    val bytes1 = spark.Utils.serialize(level1)
+    val level1_ = spark.Utils.deserialize[StorageLevel](bytes1)
+    val bytes2 = spark.Utils.serialize(level2)
+    val level2_ = spark.Utils.deserialize[StorageLevel](bytes2)
+    assert(level1_ === level1, "Deserialized level1 not same as original level1")
+    assert(level2_ === level2, "Deserialized level2 not same as original level2")
+    assert(level1_ === level2_, "Deserialized level1 not same as deserialized level2")
+    assert(level2_.eq(level1_), "Deserialized level2 not the same object as deserialized level1")
+  }
+
+  test("BlockManagerId object caching") {
+    val id1 = new BlockManagerId("ip-address", 500)
+    val id2 = new BlockManagerId("ip-address", 500)
+    val bytes1 = spark.Utils.serialize(id1)
+    val id1_ = spark.Utils.deserialize[BlockManagerId](bytes1)
+    val bytes2 = spark.Utils.serialize(id2)
+    val id2_ = spark.Utils.deserialize[BlockManagerId](bytes2)
+    assert(id1_ === id1, "Deserialized id1 not same as original id1")
+    assert(id2_ === id2, "Deserialized id2 not same as original id2")
+    assert(id1_ === id2_, "Deserialized id1 not same as deserialized id2")
+    assert(id2_.eq(id1_), "Deserialized id2 not the same object as deserialized id1")
+  }
+
   test("master + 1 manager interaction") {
     store = new BlockManager(master, serializer, 2000)
     val a1 = new Array[Byte](400)
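
Note on the mechanism: both classes rely on Java serialization's
readResolve() hook. After readExternal() has repopulated a freshly
allocated instance, the serialization runtime calls readResolve() and
hands its result back to the caller, so every deserialized
BlockManagerId and StorageLevel is swapped for the canonical instance
in the cache. A minimal self-contained sketch of the same interning
pattern (the class name CachedId and its fields are illustrative, not
part of this patch; it also uses the atomic putIfAbsent, whereas the
patch's containsKey/put pair could let two racing deserializers briefly
return distinct objects):

    import java.io.ObjectStreamException
    import java.util.concurrent.ConcurrentHashMap

    // Illustrative stand-in for BlockManagerId; equals and hashCode must
    // agree so the map can find an equivalent previously cached instance.
    class CachedId(val ip: String, val port: Int) extends Serializable {
      override def hashCode = ip.hashCode * 41 + port
      override def equals(that: Any) = that match {
        case id: CachedId => port == id.port && ip == id.ip
        case _ => false
      }
      // Called by the serialization machinery after the object is read;
      // whatever it returns replaces the freshly deserialized instance.
      @throws(classOf[ObjectStreamException])
      private def readResolve(): Object = CachedId.intern(this)
    }

    object CachedId {
      private val cache = new ConcurrentHashMap[CachedId, CachedId]()
      def intern(id: CachedId): CachedId = {
        // putIfAbsent is atomic: all threads agree on one canonical object.
        val prev = cache.putIfAbsent(id, id)
        if (prev == null) id else prev
      }
    }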
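
To see the caching take effect, round-trip two equal instances through
plain Java serialization and compare references, which is essentially
what the new BlockManagerSuite tests assert via eq (the helper and the
values below are hypothetical):

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream,
      ObjectInputStream, ObjectOutputStream}

    def roundTrip[T <: AnyRef](obj: T): T = {
      val bos = new ByteArrayOutputStream()
      val oos = new ObjectOutputStream(bos)
      oos.writeObject(obj)
      oos.close()
      val ois = new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray))
      ois.readObject().asInstanceOf[T]
    }

    val a = roundTrip(new CachedId("10.0.0.1", 7077))
    val b = roundTrip(new CachedId("10.0.0.1", 7077))
    assert(a eq b)  // readResolve routed both through the cache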