Added a fast and low-memory append-only map implementation for cogroup
and parallel reduce operations
parent e67d5b962a
commit b535db7d89
@@ -21,8 +21,10 @@ import java.util.{HashMap => JHashMap}
 
 import scala.collection.JavaConversions._
 
+import spark.util.AppendOnlyMap
+
 /** A set of functions used to aggregate data.
  *
  * @param createCombiner function to create the initial value of the aggregation.
  * @param mergeValue function to merge a new value into the aggregation result.
  * @param mergeCombiners function to merge outputs from multiple mergeValue function.
@@ -33,27 +35,29 @@ case class Aggregator[K, V, C] (
     mergeCombiners: (C, C) => C) {
 
   def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]) : Iterator[(K, C)] = {
-    val combiners = new JHashMap[K, C]
-    for (kv <- iter) {
-      val oldC = combiners.get(kv._1)
-      if (oldC == null) {
-        combiners.put(kv._1, createCombiner(kv._2))
-      } else {
-        combiners.put(kv._1, mergeValue(oldC, kv._2))
-      }
+    val combiners = new AppendOnlyMap[K, C]
+    for ((k, v) <- iter) {
+      combiners.changeValue(k, (hadValue, oldValue) => {
+        if (hadValue) {
+          mergeValue(oldValue, v)
+        } else {
+          createCombiner(v)
+        }
+      })
     }
     combiners.iterator
   }
 
   def combineCombinersByKey(iter: Iterator[(K, C)]) : Iterator[(K, C)] = {
-    val combiners = new JHashMap[K, C]
-    iter.foreach { case(k, c) =>
-      val oldC = combiners.get(k)
-      if (oldC == null) {
-        combiners.put(k, c)
-      } else {
-        combiners.put(k, mergeCombiners(oldC, c))
-      }
+    val combiners = new AppendOnlyMap[K, C]
+    for ((k, c) <- iter) {
+      combiners.changeValue(k, (hadValue, oldValue) => {
+        if (hadValue) {
+          mergeCombiners(oldValue, c)
+        } else {
+          c
+        }
+      })
     }
     combiners.iterator
   }
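The point of the new API: changeValue finds the key's slot once and applies the merge function in place, where the JHashMap version paid for two probes per record (a get, then a put). A minimal word-count sketch of the calling pattern (a sketch, not part of the commit; AppendOnlyMap is private[spark], so code like this must live under the spark package):

    import spark.util.AppendOnlyMap

    val counts = new AppendOnlyMap[String, Int]
    for (word <- Seq("a", "b", "a")) {
      counts.changeValue(word, (hadValue, oldValue) =>
        if (hadValue) oldValue + 1 else 1  // merge into the existing count, or create it
      )
    }
    assert(counts("a") == 2 && counts("b") == 1)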
@@ -25,6 +25,7 @@ import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.{Partition, Partitioner, SparkEnv, TaskContext}
 import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
+import org.apache.spark.util.AppendOnlyMap
 
 
 private[spark] sealed trait CoGroupSplitDep extends Serializable
@@ -105,17 +106,12 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
     val split = s.asInstanceOf[CoGroupPartition]
     val numRdds = split.deps.size
     // e.g. for `(k, a) cogroup (k, b)`, K -> Seq(ArrayBuffer as, ArrayBuffer bs)
-    val map = new JHashMap[K, Seq[ArrayBuffer[Any]]]
+    val map = new AppendOnlyMap[K, Seq[ArrayBuffer[Any]]]
 
     def getSeq(k: K): Seq[ArrayBuffer[Any]] = {
-      val seq = map.get(k)
-      if (seq != null) {
-        seq
-      } else {
-        val seq = Array.fill(numRdds)(new ArrayBuffer[Any])
-        map.put(k, seq)
-        seq
-      }
+      map.changeValue(k, (hadValue, oldValue) => {
+        if (hadValue) oldValue else Array.fill(numRdds)(new ArrayBuffer[Any])
+      })
     }
 
     val ser = SparkEnv.get.serializerManager.get(serializerClass)
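The new getSeq collapses the old look-up-then-insert into a single probe: changeValue either returns the existing buffers or installs a fresh Array.fill(numRdds)(new ArrayBuffer[Any]) for a first-seen key. For comparison, the same find-or-insert intent in stock Scala collections (a sketch only; scala.collection.mutable.HashMap allocates an entry wrapper per key, which is the overhead AppendOnlyMap is designed to avoid):

    import scala.collection.mutable

    val m = mutable.Map.empty[String, Array[mutable.ArrayBuffer[Any]]]
    def getSeq(k: String, numRdds: Int): Array[mutable.ArrayBuffer[Any]] =
      m.getOrElseUpdate(k, Array.fill(numRdds)(new mutable.ArrayBuffer[Any]))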
@@ -134,7 +130,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
         }
       }
     }
-    JavaConversions.mapAsScalaMap(map).iterator
+    map.iterator
   }
 
   override def clearDependencies() {
core/src/main/scala/spark/util/AppendOnlyMap.scala (new file, 241 lines)

@@ -0,0 +1,241 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package spark.util
/**
 * A simple open hash table optimized for the append-only use case, where keys
 * are never removed, but the value for each key may be changed.
 *
 * This implementation uses quadratic probing with a power-of-2 hash table
 * size, which is guaranteed to explore all spaces for each key (see
 * http://en.wikipedia.org/wiki/Quadratic_probing).
 *
 * TODO: Cache the hash values of each key? java.util.HashMap does that.
 */
private[spark]
class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] with Serializable {
  if (!isPowerOf2(initialCapacity)) {
    throw new IllegalArgumentException("Initial capacity must be power of 2")
  }
  if (initialCapacity >= (1 << 30)) {
    throw new IllegalArgumentException("Can't make capacity bigger than 2^29 elements")
  }

  private var capacity = initialCapacity
  private var curSize = 0

  // Holds keys and values in the same array for memory locality; specifically, the order of
  // elements is key0, value0, key1, value1, key2, value2, etc.
  private var data = new Array[AnyRef](2 * capacity)

  // Treat the null key differently so we can use nulls in "data" to represent empty items.
  private var haveNullValue = false
  private var nullValue: V = null.asInstanceOf[V]

  private val LOAD_FACTOR = 0.7
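  // NOTE (editor's sketch, not in the committed file): the probe loops below advance
  // by a step that grows by 1 on each collision, so a key visits offsets
  // 0, 1, 3, 6, 10, ... (triangular numbers) from its home slot. For a power-of-2
  // capacity that sequence touches every slot exactly once, e.g. for capacity 8:
  //   var pos = 0; (1 to 8).map { i => val p = pos; pos = (pos + i) & 7; p }
  //   // => Vector(0, 1, 3, 6, 2, 7, 5, 4), a permutation of all 8 slots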
  /** Get the value for a given key */
  def apply(key: K): V = {
    val k = key.asInstanceOf[AnyRef]
    if (k.eq(null)) {
      return nullValue
    }
    val mask = capacity - 1
    var pos = rehash(k.hashCode) & mask
    var i = 1
    while (true) {
      val curKey = data(2 * pos)
      if (curKey.eq(k) || curKey.eq(null) || curKey == k) {
        return data(2 * pos + 1).asInstanceOf[V]
      } else {
        val delta = i
        pos = (pos + delta) & mask
        i += 1
      }
    }
    return null.asInstanceOf[V]
  }
  /** Set the value for a key */
  def update(key: K, value: V) {
    val k = key.asInstanceOf[AnyRef]
    if (k.eq(null)) {
      if (!haveNullValue) {
        incrementSize()
      }
      nullValue = value
      haveNullValue = true
      return
    }
    val isNewEntry = putInto(data, k, value.asInstanceOf[AnyRef])
    if (isNewEntry) {
      incrementSize()
    }
  }
  /**
   * Set the value for key to updateFunc(hadValue, oldValue), where oldValue will be the old value
   * for key, if any, or null otherwise. Returns the newly updated value.
   */
  def changeValue(key: K, updateFunc: (Boolean, V) => V): V = {
    val k = key.asInstanceOf[AnyRef]
    if (k.eq(null)) {
      if (!haveNullValue) {
        incrementSize()
      }
      nullValue = updateFunc(haveNullValue, nullValue)
      haveNullValue = true
      return nullValue
    }
    val mask = capacity - 1
    var pos = rehash(k.hashCode) & mask
    var i = 1
    while (true) {
      val curKey = data(2 * pos)
      if (curKey.eq(null)) {
        val newValue = updateFunc(false, null.asInstanceOf[V])
        data(2 * pos) = k
        data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
        incrementSize()
        return newValue
      } else if (curKey.eq(k) || curKey == k) {
        val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V])
        data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
        return newValue
      } else {
        val delta = i
        pos = (pos + delta) & mask
        i += 1
      }
    }
    null.asInstanceOf[V] // Never reached but needed to keep compiler happy
  }
  /** Iterator method from Iterable */
  override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] {
    var pos = -1

    /** Get the next value we should return from next(), or null if we're finished iterating */
    def nextValue(): (K, V) = {
      if (pos == -1) { // Treat position -1 as looking at the null value
        if (haveNullValue) {
          return (null.asInstanceOf[K], nullValue)
        }
        pos += 1
      }
      while (pos < capacity) {
        if (!data(2 * pos).eq(null)) {
          return (data(2 * pos).asInstanceOf[K], data(2 * pos + 1).asInstanceOf[V])
        }
        pos += 1
      }
      null
    }

    override def hasNext: Boolean = nextValue() != null

    override def next(): (K, V) = {
      val value = nextValue()
      if (value == null) {
        throw new NoSuchElementException("End of iterator")
      }
      pos += 1
      value
    }
  }
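  // NOTE (editor's sketch, not in the committed file): hasNext above re-runs
  // nextValue() from the current pos on every call; pos only advances inside
  // next(). Repeated hasNext calls therefore redo the same short scan, which is
  // harmless for the single-pass iteration this map is built for.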
  override def size: Int = curSize

  /** Increase table size by 1, rehashing if necessary */
  private def incrementSize() {
    curSize += 1
    if (curSize > LOAD_FACTOR * capacity) {
      growTable()
    }
  }
  /**
   * Re-hash a value to deal better with hash functions that don't differ
   * in the lower bits, similar to java.util.HashMap
   */
  private def rehash(h: Int): Int = {
    val r = h ^ (h >>> 20) ^ (h >>> 12)
    r ^ (r >>> 7) ^ (r >>> 4)
  }
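  // NOTE (editor's sketch, not in the committed file): this is the supplemental
  // hash that java.util.HashMap (JDK 6/7) applies to every hashCode: it XOR-folds
  // the high bits down into the low bits so that keys whose hashCodes differ only
  // above bit log2(capacity) don't all collapse into one slot after `& mask`.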
  /**
   * Put an entry into a table represented by data, returning true if
   * this increases the size of the table or false otherwise. Assumes
   * that "data" has at least one empty slot.
   */
  private def putInto(data: Array[AnyRef], key: AnyRef, value: AnyRef): Boolean = {
    val mask = (data.length / 2) - 1
    var pos = rehash(key.hashCode) & mask
    var i = 1
    while (true) {
      val curKey = data(2 * pos)
      if (curKey.eq(null)) {
        data(2 * pos) = key
        data(2 * pos + 1) = value
        return true
      } else if (curKey.eq(key) || curKey == key) {
        data(2 * pos + 1) = value
        return false
      } else {
        val delta = i
        pos = (pos + delta) & mask
        i += 1
      }
    }
    return false // Never reached but needed to keep compiler happy
  }
  /** Double the table's size and re-hash everything */
  private def growTable() {
    val newCapacity = capacity * 2
    if (newCapacity >= (1 << 30)) {
      // We can't make the table this big because we want an array of 2x
      // that size for our data, but array sizes are at most Int.MaxValue
      throw new Exception("Can't make capacity bigger than 2^29 elements")
    }
    val newData = new Array[AnyRef](2 * newCapacity)
    var pos = 0
    while (pos < capacity) {
      if (!data(2 * pos).eq(null)) {
        putInto(newData, data(2 * pos), data(2 * pos + 1))
      }
      pos += 1
    }
    data = newData
    capacity = newCapacity
  }
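  // NOTE (editor's sketch, not in the committed file): doubling keeps inserts
  // amortized O(1): a rehash costs time linear in the table size, but only
  // happens after the element count has itself doubled. putInto never hits its
  // duplicate-key branch here, since keys within the old table are unique.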
  private def isPowerOf2(num: Int): Boolean = {
    var n = num
    while (n > 0) {
      if (n == 1) {
        return true
      } else if (n % 2 == 1) {
        return false
      } else {
        n /= 2
      }
    }
    return false
  }
}
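One stylistic note on the helper above: the division loop in isPowerOf2 is correct but takes O(log n) steps. The usual constant-time idiom, shown as a drop-in alternative (a sketch, not what the commit ships):

    private def isPowerOf2(num: Int): Boolean = num > 0 && (num & (num - 1)) == 0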
@@ -44,7 +44,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
     }
 
     val d = sc.parallelize(1 to 1e4.toInt, 64).map{i => w(i)}
-    d.count
+    d.count()
     Thread.sleep(1000)
     listener.stageInfos.size should be (1)
 
@@ -55,7 +55,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
     val d4 = d2.cogroup(d3, 64).map{case(k,(v1,v2)) => w(k) -> (v1.size, v2.size)}
     d4.setName("A Cogroup")
 
-    d4.collectAsMap
+    d4.collectAsMap()
 
     Thread.sleep(1000)
     listener.stageInfos.size should be (4)
core/src/test/scala/spark/util/AppendOnlyMapSuite.scala (new file, 141 lines)

@@ -0,0 +1,141 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package spark.util

import scala.collection.mutable.HashSet

import org.scalatest.FunSuite

class AppendOnlyMapSuite extends FunSuite {
  test("initialization") {
    val goodMap1 = new AppendOnlyMap[Int, Int](1)
    assert(goodMap1.size === 0)
    val goodMap2 = new AppendOnlyMap[Int, Int](256)
    assert(goodMap2.size === 0)
    intercept[IllegalArgumentException] {
      new AppendOnlyMap[Int, Int](255) // Invalid map size: not power of 2
    }
    intercept[IllegalArgumentException] {
      new AppendOnlyMap[Int, Int](1 << 30) // Invalid map size: bigger than 2^29
    }
    intercept[IllegalArgumentException] {
      new AppendOnlyMap[Int, Int](-1) // Invalid map size: not power of 2
    }
  }

  test("object keys and values") {
    val map = new AppendOnlyMap[String, String]()
    for (i <- 1 to 100) {
      map("" + i) = "" + i
    }
    assert(map.size === 100)
    for (i <- 1 to 100) {
      assert(map("" + i) === "" + i)
    }
    assert(map("0") === null)
    assert(map("101") === null)
    assert(map(null) === null)
    val set = new HashSet[(String, String)]
    for ((k, v) <- map) { // Test the foreach method
      set += ((k, v))
    }
    assert(set === (1 to 100).map(_.toString).map(x => (x, x)).toSet)
  }

  test("primitive keys and values") {
    val map = new AppendOnlyMap[Int, Int]()
    for (i <- 1 to 100) {
      map(i) = i
    }
    assert(map.size === 100)
    for (i <- 1 to 100) {
      assert(map(i) === i)
    }
    assert(map(0) === null)
    assert(map(101) === null)
    val set = new HashSet[(Int, Int)]
    for ((k, v) <- map) { // Test the foreach method
      set += ((k, v))
    }
    assert(set === (1 to 100).map(x => (x, x)).toSet)
  }

  test("null keys") {
    val map = new AppendOnlyMap[String, String]()
    for (i <- 1 to 100) {
      map("" + i) = "" + i
    }
    assert(map.size === 100)
    assert(map(null) === null)
    map(null) = "hello"
    assert(map.size === 101)
    assert(map(null) === "hello")
  }

  test("null values") {
    val map = new AppendOnlyMap[String, String]()
    for (i <- 1 to 100) {
      map("" + i) = null
    }
    assert(map.size === 100)
    assert(map("1") === null)
    assert(map(null) === null)
    assert(map.size === 100)
    map(null) = null
    assert(map.size === 101)
    assert(map(null) === null)
  }

  test("changeValue") {
    val map = new AppendOnlyMap[String, String]()
    for (i <- 1 to 100) {
      map("" + i) = "" + i
    }
    assert(map.size === 100)
    for (i <- 1 to 100) {
      val res = map.changeValue("" + i, (hadValue, oldValue) => {
        assert(hadValue === true)
        assert(oldValue === "" + i)
        oldValue + "!"
      })
      assert(res === i + "!")
    }
    // Iterate from 101 to 400 to make sure the map grows a couple of times, because we had a
    // bug where changeValue would return the wrong result when the map grew on that insert
    for (i <- 101 to 400) {
      val res = map.changeValue("" + i, (hadValue, oldValue) => {
        assert(hadValue === false)
        i + "!"
      })
      assert(res === i + "!")
    }
    assert(map.size === 400)
    assert(map(null) === null)
    map.changeValue(null, (hadValue, oldValue) => {
      assert(hadValue === false)
      "null!"
    })
    assert(map.size === 401)
    map.changeValue(null, (hadValue, oldValue) => {
      assert(hadValue === true)
      assert(oldValue === "null!")
      "null!!"
    })
    assert(map.size === 401)
  }
}