Modified StorageLevel and BlockManagerId to cache common objects and use cached object while deserializing.
This commit is contained in:
parent
d5e7aad039
commit
e463ae4920
|
@ -20,33 +20,7 @@ import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}
|
||||||
import sun.nio.ch.DirectBuffer
|
import sun.nio.ch.DirectBuffer
|
||||||
|
|
||||||
|
|
||||||
private[spark] class BlockManagerId(var ip: String, var port: Int) extends Externalizable {
|
private[spark]
|
||||||
def this() = this(null, 0) // For deserialization only
|
|
||||||
|
|
||||||
def this(in: ObjectInput) = this(in.readUTF(), in.readInt())
|
|
||||||
|
|
||||||
override def writeExternal(out: ObjectOutput) {
|
|
||||||
out.writeUTF(ip)
|
|
||||||
out.writeInt(port)
|
|
||||||
}
|
|
||||||
|
|
||||||
override def readExternal(in: ObjectInput) {
|
|
||||||
ip = in.readUTF()
|
|
||||||
port = in.readInt()
|
|
||||||
}
|
|
||||||
|
|
||||||
override def toString = "BlockManagerId(" + ip + ", " + port + ")"
|
|
||||||
|
|
||||||
override def hashCode = ip.hashCode * 41 + port
|
|
||||||
|
|
||||||
override def equals(that: Any) = that match {
|
|
||||||
case id: BlockManagerId => port == id.port && ip == id.ip
|
|
||||||
case _ => false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
private[spark]
|
|
||||||
case class BlockException(blockId: String, message: String, ex: Exception = null)
|
case class BlockException(blockId: String, message: String, ex: Exception = null)
|
||||||
extends Exception(message)
|
extends Exception(message)
|
||||||
|
|
||||||
|
|
48
core/src/main/scala/spark/storage/BlockManagerId.scala
Normal file
48
core/src/main/scala/spark/storage/BlockManagerId.scala
Normal file
|
@ -0,0 +1,48 @@
|
||||||
|
package spark.storage
|
||||||
|
|
||||||
|
import java.io.{IOException, ObjectOutput, ObjectInput, Externalizable}
|
||||||
|
import java.util.concurrent.ConcurrentHashMap
|
||||||
|
|
||||||
|
/**
 * Identifies a BlockManager by the (ip, port) pair it listens on.
 *
 * Instances are interned: readResolve() replaces every deserialized copy with
 * a canonical instance from [[BlockManagerId.blockManagerIdCache]], so equal
 * ids deserialized many times share a single object.
 *
 * The fields are vars (and there is a no-arg constructor) only to satisfy
 * java.io.Externalizable; treat instances as immutable after construction.
 */
private[spark] class BlockManagerId(var ip: String, var port: Int) extends Externalizable {

  def this() = this(null, 0)  // For deserialization only

  def this(in: ObjectInput) = this(in.readUTF(), in.readInt())

  override def writeExternal(out: ObjectOutput) {
    out.writeUTF(ip)
    out.writeInt(port)
  }

  override def readExternal(in: ObjectInput) {
    ip = in.readUTF()
    port = in.readInt()
  }

  // Called by Java serialization after readExternal; funnels the fresh copy
  // through the cache so callers always see the canonical instance.
  @throws(classOf[IOException])
  private def readResolve(): Object = {
    BlockManagerId.getCachedBlockManagerId(this)
  }

  override def toString = "BlockManagerId(" + ip + ", " + port + ")"

  // NOTE: ip is null for a not-yet-deserialized instance (no-arg constructor);
  // such instances must not be hashed. Normal instances always have an ip.
  override def hashCode = ip.hashCode * 41 + port

  override def equals(that: Any) = that match {
    case id: BlockManagerId => port == id.port && ip == id.ip
    case _ => false
  }
}

object BlockManagerId {

  // Canonical instances, keyed by themselves (equals/hashCode on ip+port).
  val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]()

  /** Return the canonical cached instance equal to `id`, inserting `id` if absent. */
  def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = {
    // putIfAbsent is a single atomic operation; the previous
    // containsKey-then-put sequence could race and let two threads insert
    // (and hand out) distinct instances for the same id.
    val existing = blockManagerIdCache.putIfAbsent(id, id)
    if (existing != null) existing else id
  }
}
|
|
@ -1,6 +1,9 @@
|
||||||
package spark.storage
|
package spark.storage
|
||||||
|
|
||||||
import java.io.{Externalizable, ObjectInput, ObjectOutput}
|
import java.io.{IOException, Externalizable, ObjectInput, ObjectOutput}
|
||||||
|
import collection.mutable
|
||||||
|
import util.Random
|
||||||
|
import collection.mutable.ArrayBuffer
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Flags for controlling the storage of an RDD. Each StorageLevel records whether to use memory,
|
* Flags for controlling the storage of an RDD. Each StorageLevel records whether to use memory,
|
||||||
|
@ -17,7 +20,8 @@ class StorageLevel(
|
||||||
extends Externalizable {
|
extends Externalizable {
|
||||||
|
|
||||||
// TODO: Also add fields for caching priority, dataset ID, and flushing.
|
// TODO: Also add fields for caching priority, dataset ID, and flushing.
|
||||||
|
assert(replication < 40, "Replication restricted to be less than 40 for calculating hashcodes")
|
||||||
|
|
||||||
def this(flags: Int, replication: Int) {
|
def this(flags: Int, replication: Int) {
|
||||||
this((flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0, replication)
|
this((flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0, replication)
|
||||||
}
|
}
|
||||||
|
@ -27,6 +31,10 @@ class StorageLevel(
|
||||||
override def clone(): StorageLevel = new StorageLevel(
|
override def clone(): StorageLevel = new StorageLevel(
|
||||||
this.useDisk, this.useMemory, this.deserialized, this.replication)
|
this.useDisk, this.useMemory, this.deserialized, this.replication)
|
||||||
|
|
||||||
|
override def hashCode(): Int = {
|
||||||
|
toInt * 41 + replication
|
||||||
|
}
|
||||||
|
|
||||||
override def equals(other: Any): Boolean = other match {
|
override def equals(other: Any): Boolean = other match {
|
||||||
case s: StorageLevel =>
|
case s: StorageLevel =>
|
||||||
s.useDisk == useDisk &&
|
s.useDisk == useDisk &&
|
||||||
|
@ -66,6 +74,11 @@ class StorageLevel(
|
||||||
replication = in.readByte()
|
replication = in.readByte()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@throws(classOf[IOException])
|
||||||
|
private def readResolve(): Object = {
|
||||||
|
StorageLevel.getCachedStorageLevel(this)
|
||||||
|
}
|
||||||
|
|
||||||
override def toString: String =
|
override def toString: String =
|
||||||
"StorageLevel(%b, %b, %b, %d)".format(useDisk, useMemory, deserialized, replication)
|
"StorageLevel(%b, %b, %b, %d)".format(useDisk, useMemory, deserialized, replication)
|
||||||
}
|
}
|
||||||
|
@ -82,4 +95,15 @@ object StorageLevel {
|
||||||
val MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2)
|
val MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2)
|
||||||
val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false)
|
val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false)
|
||||||
val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2)
|
val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2)
|
||||||
|
|
||||||
|
// Canonical StorageLevel instances: readResolve() funnels every deserialized
// copy through this cache so equal levels share one object.
val storageLevelCache = new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]()

/** Return the canonical cached instance equal to `level`, inserting it if absent. */
def getCachedStorageLevel(level: StorageLevel): StorageLevel = {
  // putIfAbsent is a single atomic operation; the previous
  // containsKey-then-put sequence could race and let two threads insert
  // (and hand out) distinct instances for the same level.
  val existing = storageLevelCache.putIfAbsent(level, level)
  if (existing != null) existing else level
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -57,6 +57,32 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("StorageLevel object caching") {
  // Two distinct but equal instances; after a serialize/deserialize round
  // trip both must resolve (via readResolve) to a single cached object.
  val level1 = new StorageLevel(false, false, false, 3)
  val level2 = new StorageLevel(false, false, false, 3)
  val bytes1 = spark.Utils.serialize(level1)
  val level1_ = spark.Utils.deserialize[StorageLevel](bytes1)
  val bytes2 = spark.Utils.serialize(level2)
  val level2_ = spark.Utils.deserialize[StorageLevel](bytes2)
  assert(level1_ === level1, "Deserialized level1 not same as original level1")
  // Fixed assertion message: it previously said "original level1" here.
  assert(level2_ === level2, "Deserialized level2 not same as original level2")
  assert(level1_ === level2_, "Deserialized level1 not same as deserialized level2")
  assert(level2_.eq(level1_), "Deserialized level2 not the same object as deserialized level1")
}
||||||
|
|
||||||
|
test("BlockManagerId object caching") {
  // Bug fix: this test previously constructed StorageLevel instances, so it
  // duplicated the StorageLevel test and never exercised BlockManagerId
  // serialization or its readResolve interning at all.
  val id1 = new BlockManagerId("XXX", 1)
  val id2 = new BlockManagerId("XXX", 1)
  val bytes1 = spark.Utils.serialize(id1)
  val id1_ = spark.Utils.deserialize[BlockManagerId](bytes1)
  val bytes2 = spark.Utils.serialize(id2)
  val id2_ = spark.Utils.deserialize[BlockManagerId](bytes2)
  assert(id1_ === id1, "Deserialized id1 not same as original id1")
  assert(id2_ === id2, "Deserialized id2 not same as original id2")
  assert(id1_ === id2_, "Deserialized id1 not same as deserialized id2")
  assert(id2_.eq(id1_), "Deserialized id2 not the same object as deserialized id1")
}
|
||||||
|
|
||||||
test("master + 1 manager interaction") {
|
test("master + 1 manager interaction") {
|
||||||
store = new BlockManager(master, serializer, 2000)
|
store = new BlockManager(master, serializer, 2000)
|
||||||
val a1 = new Array[Byte](400)
|
val a1 = new Array[Byte](400)
|
||||||
|
|
Loading…
Reference in a new issue