This commit makes three changes to the (PrimitiveKey)OpenHashMap

1) _keySet --renamed--> keySet
2) keySet and _values are made externally accessible
3) added an update function which merges duplicate values
This commit is contained in:
Joseph E. Gonzalez 2013-10-31 18:09:42 -07:00
parent d74ad4ebc9
commit 4ad58e2b9a
2 changed files with 71 additions and 25 deletions

View file

@ -27,14 +27,21 @@ package org.apache.spark.util.hash
*/
private[spark]
class OpenHashMap[K >: Null : ClassManifest, @specialized(Long, Int, Double) V: ClassManifest](
initialCapacity: Int)
var keySet: OpenHashSet[K], var _values: Array[V])
extends Iterable[(K, V)]
with Serializable {
def this() = this(64)
/**
* Allocate an OpenHashMap with a fixed initial capacity
*/
def this(initialCapacity: Int = 64) =
this(new OpenHashSet[K](initialCapacity), new Array[V](initialCapacity))
/**
 * Allocate an OpenHashMap backed by an existing OpenHashSet of keys,
 * with the value array sized to the set's capacity.
 */
def this(keySet: OpenHashSet[K]) = this(keySet, new Array[V](keySet.capacity))
protected var _keySet = new OpenHashSet[K](initialCapacity)
private var _values = new Array[V](_keySet.capacity)
@transient private var _oldValues: Array[V] = null
@ -42,14 +49,14 @@ class OpenHashMap[K >: Null : ClassManifest, @specialized(Long, Int, Double) V:
private var haveNullValue = false
private var nullValue: V = null.asInstanceOf[V]
override def size: Int = if (haveNullValue) _keySet.size + 1 else _keySet.size
override def size: Int = if (haveNullValue) keySet.size + 1 else keySet.size
/** Get the value for a given key */
def apply(k: K): V = {
if (k == null) {
nullValue
} else {
val pos = _keySet.getPos(k)
val pos = keySet.getPos(k)
if (pos < 0) {
null.asInstanceOf[V]
} else {
@ -64,9 +71,26 @@ class OpenHashMap[K >: Null : ClassManifest, @specialized(Long, Int, Double) V:
haveNullValue = true
nullValue = v
} else {
val pos = _keySet.fastAdd(k) & OpenHashSet.POSITION_MASK
val pos = keySet.fastAdd(k) & OpenHashSet.POSITION_MASK
_values(pos) = v
_keySet.rehashIfNeeded(k, grow, move)
keySet.rehashIfNeeded(k, grow, move)
_oldValues = null
}
}
/** Set the value for a key */
def update(k: K, v: V, mergeF: (V,V) => V) {
if (k == null) {
if(haveNullValue) {
nullValue = mergeF(nullValue, v)
} else {
haveNullValue = true
nullValue = v
}
} else {
val pos = keySet.fastAdd(k) & OpenHashSet.POSITION_MASK
_values(pos) = mergeF(_values(pos), v)
keySet.rehashIfNeeded(k, grow, move)
_oldValues = null
}
}
@ -87,11 +111,11 @@ class OpenHashMap[K >: Null : ClassManifest, @specialized(Long, Int, Double) V:
}
nullValue
} else {
val pos = _keySet.fastAdd(k)
val pos = keySet.fastAdd(k)
if ((pos & OpenHashSet.EXISTENCE_MASK) != 0) {
val newValue = defaultValue
_values(pos & OpenHashSet.POSITION_MASK) = newValue
_keySet.rehashIfNeeded(k, grow, move)
keySet.rehashIfNeeded(k, grow, move)
newValue
} else {
_values(pos) = mergeValue(_values(pos))
@ -113,9 +137,9 @@ class OpenHashMap[K >: Null : ClassManifest, @specialized(Long, Int, Double) V:
}
pos += 1
}
pos = _keySet.nextPos(pos)
pos = keySet.nextPos(pos)
if (pos >= 0) {
val ret = (_keySet.getValue(pos), _values(pos))
val ret = (keySet.getValue(pos), _values(pos))
pos += 1
ret
} else {
@ -146,3 +170,4 @@ class OpenHashMap[K >: Null : ClassManifest, @specialized(Long, Int, Double) V:
_values(newPos) = _oldValues(oldPos)
}
}

View file

@ -28,35 +28,56 @@ package org.apache.spark.util.hash
private[spark]
class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassManifest,
@specialized(Long, Int, Double) V: ClassManifest](
initialCapacity: Int)
var keySet: OpenHashSet[K], var _values: Array[V])
extends Iterable[(K, V)]
with Serializable {
def this() = this(64)
/**
 * Allocate a PrimitiveKeyOpenHashMap with a fixed initial capacity
 */
def this(initialCapacity: Int = 64) =
this(new OpenHashSet[K](initialCapacity), new Array[V](initialCapacity))
/**
 * Allocate a PrimitiveKeyOpenHashMap backed by an existing OpenHashSet
 * of keys, with the value array sized to the set's capacity.
 */
def this(keySet: OpenHashSet[K]) = this(keySet, new Array[V](keySet.capacity))
require(classManifest[K] == classManifest[Long] || classManifest[K] == classManifest[Int])
protected var _keySet = new OpenHashSet[K](initialCapacity)
private var _values = new Array[V](_keySet.capacity)
private var _oldValues: Array[V] = null
override def size = _keySet.size
override def size = keySet.size
/** Get the value for a given key */
def apply(k: K): V = {
val pos = _keySet.getPos(k)
val pos = keySet.getPos(k)
_values(pos)
}
/** Set the value for a key */
def update(k: K, v: V) {
val pos = _keySet.fastAdd(k) & OpenHashSet.POSITION_MASK
val pos = keySet.fastAdd(k) & OpenHashSet.POSITION_MASK
_values(pos) = v
_keySet.rehashIfNeeded(k, grow, move)
keySet.rehashIfNeeded(k, grow, move)
_oldValues = null
}
/** Set the value for a key */
def update(k: K, v: V, mergeF: (V,V) => V) {
val pos = keySet.fastAdd(k)
val ind = pos & OpenHashSet.POSITION_MASK
if ((pos & OpenHashSet.EXISTENCE_MASK) != 0) { // if first add
_values(ind) = v
} else {
_values(ind) = mergeF(_values(ind), v)
}
keySet.rehashIfNeeded(k, grow, move)
_oldValues = null
}
/**
* If the key doesn't exist yet in the hash map, set its value to defaultValue; otherwise,
* set its value to mergeValue(oldValue).
@ -64,11 +85,11 @@ class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassManifest,
* @return the newly updated value.
*/
def changeValue(k: K, defaultValue: => V, mergeValue: (V) => V): V = {
val pos = _keySet.fastAdd(k)
val pos = keySet.fastAdd(k)
if ((pos & OpenHashSet.EXISTENCE_MASK) != 0) {
val newValue = defaultValue
_values(pos & OpenHashSet.POSITION_MASK) = newValue
_keySet.rehashIfNeeded(k, grow, move)
keySet.rehashIfNeeded(k, grow, move)
newValue
} else {
_values(pos) = mergeValue(_values(pos))
@ -82,9 +103,9 @@ class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassManifest,
/** Get the next value we should return from next(), or null if we're finished iterating */
def computeNextPair(): (K, V) = {
pos = _keySet.nextPos(pos)
pos = keySet.nextPos(pos)
if (pos >= 0) {
val ret = (_keySet.getValue(pos), _values(pos))
val ret = (keySet.getValue(pos), _values(pos))
pos += 1
ret
} else {