Modified StorageLevel and BlockManagerId to cache common objects and use the cached objects when deserializing.

Author: Tathagata Das
Date: 2012-11-28 14:05:01 -08:00
Commit: e463ae4920
Parent: d5e7aad039

4 changed files with 101 additions and 29 deletions
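The mechanism both classes rely on below is the JVM serialization hook readResolve(): when ObjectInputStream finishes reading an object whose class declares readResolve(), it substitutes that method's return value for the freshly built instance. Interning deserialized objects this way collapses the many equal ids and levels that arrive in block messages onto one shared copy each. A minimal, self-contained sketch of the pattern, using an illustrative Endpoint class rather than the Spark classes:

import java.io._
import java.util.concurrent.ConcurrentHashMap

// Illustrative stand-in for BlockManagerId/StorageLevel, not Spark code.
class Endpoint(val host: String, val port: Int) extends Serializable {
  override def hashCode: Int = host.hashCode * 41 + port
  override def equals(that: Any): Boolean = that match {
    case e: Endpoint => e.host == host && e.port == port
    case _ => false
  }
  // Invoked by ObjectInputStream after deserialization; its return value
  // replaces the newly built instance, so equal copies share one object.
  private def readResolve(): Object = Endpoint.getCached(this)
}

object Endpoint {
  private val cache = new ConcurrentHashMap[Endpoint, Endpoint]()

  def getCached(e: Endpoint): Endpoint = {
    val prev = cache.putIfAbsent(e, e) // atomic check-then-insert
    if (prev == null) e else prev
  }

  // Round-trip through Java serialization.
  def roundtrip(e: Endpoint): Endpoint = {
    val bos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bos)
    oos.writeObject(e)
    oos.close()
    new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray))
      .readObject().asInstanceOf[Endpoint]
  }
}

// Equal copies deserialized independently are the very same object:
// Endpoint.roundtrip(new Endpoint("a", 1)) eq Endpoint.roundtrip(new Endpoint("a", 1)) // true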

spark/storage/BlockManager.scala

@@ -20,33 +20,7 @@ import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}
 import sun.nio.ch.DirectBuffer
-private[spark] class BlockManagerId(var ip: String, var port: Int) extends Externalizable {
-  def this() = this(null, 0) // For deserialization only
-  def this(in: ObjectInput) = this(in.readUTF(), in.readInt())
-  override def writeExternal(out: ObjectOutput) {
-    out.writeUTF(ip)
-    out.writeInt(port)
-  }
-  override def readExternal(in: ObjectInput) {
-    ip = in.readUTF()
-    port = in.readInt()
-  }
-  override def toString = "BlockManagerId(" + ip + ", " + port + ")"
-  override def hashCode = ip.hashCode * 41 + port
-  override def equals(that: Any) = that match {
-    case id: BlockManagerId => port == id.port && ip == id.ip
-    case _ => false
-  }
-}
 private[spark]
 case class BlockException(blockId: String, message: String, ex: Exception = null)
   extends Exception(message)

spark/storage/BlockManagerId.scala (new file)

@@ -0,0 +1,48 @@
+package spark.storage
+
+import java.io.{IOException, ObjectOutput, ObjectInput, Externalizable}
+import java.util.concurrent.ConcurrentHashMap
+
+private[spark] class BlockManagerId(var ip: String, var port: Int) extends Externalizable {
+  def this() = this(null, 0) // For deserialization only
+  def this(in: ObjectInput) = this(in.readUTF(), in.readInt())
+
+  override def writeExternal(out: ObjectOutput) {
+    out.writeUTF(ip)
+    out.writeInt(port)
+  }
+
+  override def readExternal(in: ObjectInput) {
+    ip = in.readUTF()
+    port = in.readInt()
+  }
+
+  @throws(classOf[IOException])
+  private def readResolve(): Object = {
+    BlockManagerId.getCachedBlockManagerId(this)
+  }
+
+  override def toString = "BlockManagerId(" + ip + ", " + port + ")"
+
+  override def hashCode = ip.hashCode * 41 + port
+
+  override def equals(that: Any) = that match {
+    case id: BlockManagerId => port == id.port && ip == id.ip
+    case _ => false
+  }
+}
+
+object BlockManagerId {
+  val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]()
+
+  def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = {
+    if (blockManagerIdCache.containsKey(id)) {
+      blockManagerIdCache.get(id)
+    } else {
+      blockManagerIdCache.put(id, id)
+      id
+    }
+  }
+}
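One caveat in getCachedBlockManagerId above: the containsKey check and the later put are separate operations, so two threads deserializing equal ids at the same time can each install and return a different instance, weakening the reference-equality guarantee the cache is meant to provide (equals and hashCode still behave correctly). A sketch of an atomic variant using ConcurrentHashMap.putIfAbsent:

def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = {
  // putIfAbsent returns the value already cached, or null if this call
  // inserted id; every caller therefore agrees on one canonical object.
  val prev = blockManagerIdCache.putIfAbsent(id, id)
  if (prev == null) id else prev
}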

spark/storage/StorageLevel.scala

@@ -1,6 +1,9 @@
 package spark.storage
-import java.io.{Externalizable, ObjectInput, ObjectOutput}
+import java.io.{IOException, Externalizable, ObjectInput, ObjectOutput}
+import collection.mutable
+import util.Random
+import collection.mutable.ArrayBuffer
 /**
  * Flags for controlling the storage of an RDD. Each StorageLevel records whether to use memory,
@@ -17,7 +20,8 @@ class StorageLevel(
   extends Externalizable {
   // TODO: Also add fields for caching priority, dataset ID, and flushing.
+  assert(replication < 40, "Replication restricted to be less than 40 for calculating hashcodes")
   def this(flags: Int, replication: Int) {
     this((flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0, replication)
   }
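The flags-based constructor above fixes the bit layout that toInt (defined elsewhere in this file) presumably uses: useDisk is bit 2, useMemory bit 1, deserialized bit 0, so toInt ranges over 0 to 7. That is also what the new assert is for: in the hashCode added below, toInt * 41 + replication, any replication below 41 cannot bridge the gap of 41 between consecutive toInt values, so distinct levels get distinct hash codes. A sketch of the presumed encoding:

// Presumed packing, mirroring the (flags, replication) constructor;
// the real toInt lives elsewhere in StorageLevel.scala.
def toInt(useDisk: Boolean, useMemory: Boolean, deserialized: Boolean): Int = {
  var ret = 0
  if (useDisk) ret |= 4
  if (useMemory) ret |= 2
  if (deserialized) ret |= 1
  ret // 0..7, so toInt * 41 + replication is injective for replication < 41
}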
@@ -27,6 +31,10 @@
   override def clone(): StorageLevel = new StorageLevel(
     this.useDisk, this.useMemory, this.deserialized, this.replication)
+  override def hashCode(): Int = {
+    toInt * 41 + replication
+  }
   override def equals(other: Any): Boolean = other match {
     case s: StorageLevel =>
       s.useDisk == useDisk &&
@@ -66,6 +74,11 @@ class StorageLevel(
     replication = in.readByte()
   }
+  @throws(classOf[IOException])
+  private def readResolve(): Object = {
+    StorageLevel.getCachedStorageLevel(this)
+  }
   override def toString: String =
     "StorageLevel(%b, %b, %b, %d)".format(useDisk, useMemory, deserialized, replication)
 }
@@ -82,4 +95,15 @@ object StorageLevel {
   val MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2)
   val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false)
   val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2)
+  val storageLevelCache = new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]()
+
+  def getCachedStorageLevel(level: StorageLevel): StorageLevel = {
+    if (storageLevelCache.containsKey(level)) {
+      storageLevelCache.get(level)
+    } else {
+      storageLevelCache.put(level, level)
+      level
+    }
+  }
 }
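With the cache in place, independently deserialized copies of equal levels resolve to one shared instance, which is what the new tests below check with eq. A quick usage sketch in the style of those tests, reusing spark.Utils.serialize/deserialize:

// Both round-trips resolve, via readResolve, to the same cached instance.
val x = spark.Utils.deserialize[StorageLevel](spark.Utils.serialize(StorageLevel.MEMORY_AND_DISK_2))
val y = spark.Utils.deserialize[StorageLevel](spark.Utils.serialize(new StorageLevel(true, true, true, 2)))
assert(x eq y)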

spark/storage/BlockManagerSuite.scala

@@ -57,6 +57,32 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
     }
   }
 
+  test("StorageLevel object caching") {
+    val level1 = new StorageLevel(false, false, false, 3)
+    val level2 = new StorageLevel(false, false, false, 3)
+    val bytes1 = spark.Utils.serialize(level1)
+    val level1_ = spark.Utils.deserialize[StorageLevel](bytes1)
+    val bytes2 = spark.Utils.serialize(level2)
+    val level2_ = spark.Utils.deserialize[StorageLevel](bytes2)
+    assert(level1_ === level1, "Deserialized level1 not same as original level1")
+    assert(level2_ === level2, "Deserialized level2 not same as original level2")
+    assert(level1_ === level2_, "Deserialized level1 not same as deserialized level2")
+    assert(level2_.eq(level1_), "Deserialized level2 not the same object as deserialized level1")
+  }
+
+  test("BlockManagerId object caching") {
+    val id1 = new BlockManagerId("ip1", 1)
+    val id2 = new BlockManagerId("ip1", 1)
+    val bytes1 = spark.Utils.serialize(id1)
+    val id1_ = spark.Utils.deserialize[BlockManagerId](bytes1)
+    val bytes2 = spark.Utils.serialize(id2)
+    val id2_ = spark.Utils.deserialize[BlockManagerId](bytes2)
+    assert(id1_ === id1, "Deserialized id1 not same as original id1")
+    assert(id2_ === id2, "Deserialized id2 not same as original id2")
+    assert(id1_ === id2_, "Deserialized id1 not same as deserialized id2")
+    assert(id2_.eq(id1_), "Deserialized id2 not the same object as deserialized id1")
+  }
+
   test("master + 1 manager interaction") {
     store = new BlockManager(master, serializer, 2000)
     val a1 = new Array[Byte](400)