diff --git a/core/src/main/scala/org/apache/spark/util/hash/OpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/hash/OpenHashMap.scala index a376d1015a..af282d5651 100644 --- a/core/src/main/scala/org/apache/spark/util/hash/OpenHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/hash/OpenHashMap.scala @@ -27,14 +27,21 @@ package org.apache.spark.util.hash */ private[spark] class OpenHashMap[K >: Null : ClassManifest, @specialized(Long, Int, Double) V: ClassManifest]( - initialCapacity: Int) + var keySet: OpenHashSet[K], var _values: Array[V]) extends Iterable[(K, V)] with Serializable { - def this() = this(64) + /** + * Allocate an OpenHashMap with a fixed initial capacity + */ + def this(initialCapacity: Int = 64) = + this(new OpenHashSet[K](initialCapacity), new Array[V](initialCapacity)) + + /** + * Allocate an OpenHashMap that uses the given key set + */ + def this(keySet: OpenHashSet[K]) = this(keySet, new Array[V](keySet.capacity)) - protected var _keySet = new OpenHashSet[K](initialCapacity) - private var _values = new Array[V](_keySet.capacity) @transient private var _oldValues: Array[V] = null @@ -42,14 +49,14 @@ class OpenHashMap[K >: Null : ClassManifest, @specialized(Long, Int, Double) V: private var haveNullValue = false private var nullValue: V = null.asInstanceOf[V] - override def size: Int = if (haveNullValue) _keySet.size + 1 else _keySet.size + override def size: Int = if (haveNullValue) keySet.size + 1 else keySet.size /** Get the value for a given key */ def apply(k: K): V = { if (k == null) { nullValue } else { - val pos = _keySet.getPos(k) + val pos = keySet.getPos(k) if (pos < 0) { null.asInstanceOf[V] } else { @@ -64,9 +71,26 @@ class OpenHashMap[K >: Null : ClassManifest, @specialized(Long, Int, Double) V: haveNullValue = true nullValue = v } else { - val pos = _keySet.fastAdd(k) & OpenHashSet.POSITION_MASK + val pos = keySet.fastAdd(k) & OpenHashSet.POSITION_MASK _values(pos) = v - _keySet.rehashIfNeeded(k, grow, 
move) + keySet.rehashIfNeeded(k, grow, move) + _oldValues = null + } + } + + /** Set the value for a key, merging with the stored value via mergeF */ + def update(k: K, v: V, mergeF: (V,V) => V) { + if (k == null) { + if(haveNullValue) { + nullValue = mergeF(nullValue, v) + } else { + haveNullValue = true + nullValue = v + } + } else { + val pos = keySet.fastAdd(k) & OpenHashSet.POSITION_MASK + _values(pos) = mergeF(_values(pos), v) + keySet.rehashIfNeeded(k, grow, move) _oldValues = null } } @@ -87,11 +111,11 @@ class OpenHashMap[K >: Null : ClassManifest, @specialized(Long, Int, Double) V: } nullValue } else { - val pos = _keySet.fastAdd(k) + val pos = keySet.fastAdd(k) if ((pos & OpenHashSet.EXISTENCE_MASK) != 0) { val newValue = defaultValue _values(pos & OpenHashSet.POSITION_MASK) = newValue - _keySet.rehashIfNeeded(k, grow, move) + keySet.rehashIfNeeded(k, grow, move) newValue } else { _values(pos) = mergeValue(_values(pos)) @@ -113,9 +137,9 @@ class OpenHashMap[K >: Null : ClassManifest, @specialized(Long, Int, Double) V: } pos += 1 } - pos = _keySet.nextPos(pos) + pos = keySet.nextPos(pos) if (pos >= 0) { - val ret = (_keySet.getValue(pos), _values(pos)) + val ret = (keySet.getValue(pos), _values(pos)) pos += 1 ret } else { @@ -146,3 +170,4 @@ class OpenHashMap[K >: Null : ClassManifest, @specialized(Long, Int, Double) V: _values(newPos) = _oldValues(oldPos) } } + diff --git a/core/src/main/scala/org/apache/spark/util/hash/PrimitiveKeyOpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/hash/PrimitiveKeyOpenHashMap.scala index 14c1367207..cbfb2361b4 100644 --- a/core/src/main/scala/org/apache/spark/util/hash/PrimitiveKeyOpenHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/hash/PrimitiveKeyOpenHashMap.scala @@ -28,35 +28,56 @@ package org.apache.spark.util.hash private[spark] class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassManifest, @specialized(Long, Int, Double) V: ClassManifest]( - initialCapacity: Int) + var keySet: OpenHashSet[K], var _values: Array[V]) 
extends Iterable[(K, V)] with Serializable { - def this() = this(64) + /** + * Allocate a PrimitiveKeyOpenHashMap with a fixed initial capacity + */ + def this(initialCapacity: Int = 64) = + this(new OpenHashSet[K](initialCapacity), new Array[V](initialCapacity)) + + /** + * Allocate a PrimitiveKeyOpenHashMap that uses the given key set + */ + def this(keySet: OpenHashSet[K]) = this(keySet, new Array[V](keySet.capacity)) require(classManifest[K] == classManifest[Long] || classManifest[K] == classManifest[Int]) - protected var _keySet = new OpenHashSet[K](initialCapacity) - private var _values = new Array[V](_keySet.capacity) - private var _oldValues: Array[V] = null - override def size = _keySet.size + override def size = keySet.size /** Get the value for a given key */ def apply(k: K): V = { - val pos = _keySet.getPos(k) + val pos = keySet.getPos(k) _values(pos) } /** Set the value for a key */ def update(k: K, v: V) { - val pos = _keySet.fastAdd(k) & OpenHashSet.POSITION_MASK + val pos = keySet.fastAdd(k) & OpenHashSet.POSITION_MASK _values(pos) = v - _keySet.rehashIfNeeded(k, grow, move) + keySet.rehashIfNeeded(k, grow, move) _oldValues = null } + + /** Set the value for a new key, or merge into the existing value via mergeF */ + def update(k: K, v: V, mergeF: (V,V) => V) { + val pos = keySet.fastAdd(k) + val ind = pos & OpenHashSet.POSITION_MASK + if ((pos & OpenHashSet.EXISTENCE_MASK) != 0) { // if first add + _values(ind) = v + } else { + _values(ind) = mergeF(_values(ind), v) + } + keySet.rehashIfNeeded(k, grow, move) + _oldValues = null + } + + /** * If the key doesn't exist yet in the hash map, set its value to defaultValue; otherwise, * set its value to mergeValue(oldValue). @@ -64,11 +85,11 @@ class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassManifest, * @return the newly updated value. 
*/ def changeValue(k: K, defaultValue: => V, mergeValue: (V) => V): V = { - val pos = _keySet.fastAdd(k) + val pos = keySet.fastAdd(k) if ((pos & OpenHashSet.EXISTENCE_MASK) != 0) { val newValue = defaultValue _values(pos & OpenHashSet.POSITION_MASK) = newValue - _keySet.rehashIfNeeded(k, grow, move) + keySet.rehashIfNeeded(k, grow, move) newValue } else { _values(pos) = mergeValue(_values(pos)) @@ -82,9 +103,9 @@ class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassManifest, /** Get the next value we should return from next(), or null if we're finished iterating */ def computeNextPair(): (K, V) = { - pos = _keySet.nextPos(pos) + pos = keySet.nextPos(pos) if (pos >= 0) { - val ret = (_keySet.getValue(pos), _values(pos)) + val ret = (keySet.getValue(pos), _values(pos)) pos += 1 ret } else {