Added a fast and low-memory append-only map implementation for cogroup

and parallel reduce operations
This commit is contained in:
Matei Zaharia 2013-08-13 18:09:40 -07:00
parent e67d5b962a
commit b535db7d89
5 changed files with 411 additions and 29 deletions

View file

@ -21,8 +21,10 @@ import java.util.{HashMap => JHashMap}
import scala.collection.JavaConversions._
import spark.util.AppendOnlyMap
/** A set of functions used to aggregate data.
*
*
* @param createCombiner function to create the initial value of the aggregation.
* @param mergeValue function to merge a new value into the aggregation result.
* @param mergeCombiners function to merge outputs from multiple mergeValue function.
@ -33,27 +35,29 @@ case class Aggregator[K, V, C] (
mergeCombiners: (C, C) => C) {
def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]) : Iterator[(K, C)] = {
val combiners = new JHashMap[K, C]
for (kv <- iter) {
val oldC = combiners.get(kv._1)
if (oldC == null) {
combiners.put(kv._1, createCombiner(kv._2))
} else {
combiners.put(kv._1, mergeValue(oldC, kv._2))
}
val combiners = new AppendOnlyMap[K, C]
for ((k, v) <- iter) {
combiners.changeValue(k, (hadValue, oldValue) => {
if (hadValue) {
mergeValue(oldValue, v)
} else {
createCombiner(v)
}
})
}
combiners.iterator
}
def combineCombinersByKey(iter: Iterator[(K, C)]) : Iterator[(K, C)] = {
val combiners = new JHashMap[K, C]
iter.foreach { case(k, c) =>
val oldC = combiners.get(k)
if (oldC == null) {
combiners.put(k, c)
} else {
combiners.put(k, mergeCombiners(oldC, c))
}
val combiners = new AppendOnlyMap[K, C]
for ((k, c) <- iter) {
combiners.changeValue(k, (hadValue, oldValue) => {
if (hadValue) {
mergeCombiners(oldValue, c)
} else {
c
}
})
}
combiners.iterator
}

View file

@ -25,6 +25,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.{Partition, Partitioner, SparkEnv, TaskContext}
import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
import org.apache.spark.util.AppendOnlyMap
private[spark] sealed trait CoGroupSplitDep extends Serializable
@ -105,17 +106,12 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
val split = s.asInstanceOf[CoGroupPartition]
val numRdds = split.deps.size
// e.g. for `(k, a) cogroup (k, b)`, K -> Seq(ArrayBuffer as, ArrayBuffer bs)
val map = new JHashMap[K, Seq[ArrayBuffer[Any]]]
val map = new AppendOnlyMap[K, Seq[ArrayBuffer[Any]]]
def getSeq(k: K): Seq[ArrayBuffer[Any]] = {
val seq = map.get(k)
if (seq != null) {
seq
} else {
val seq = Array.fill(numRdds)(new ArrayBuffer[Any])
map.put(k, seq)
seq
}
map.changeValue(k, (hadValue, oldValue) => {
if (hadValue) oldValue else Array.fill(numRdds)(new ArrayBuffer[Any])
})
}
val ser = SparkEnv.get.serializerManager.get(serializerClass)
@ -134,7 +130,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
}
}
}
JavaConversions.mapAsScalaMap(map).iterator
map.iterator
}
override def clearDependencies() {

View file

@ -0,0 +1,241 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package spark.util
/**
* A simple open hash table optimized for the append-only use case, where keys
* are never removed, but the value for each key may be changed.
*
* This implementation uses quadratic probing with a power-of-2 hash table
* size, which is guaranteed to explore all spaces for each key (see
* http://en.wikipedia.org/wiki/Quadratic_probing).
*
* TODO: Cache the hash values of each key? java.util.HashMap does that.
*/
private[spark]
class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] with Serializable {
if (!isPowerOf2(initialCapacity)) {
throw new IllegalArgumentException("Initial capacity must be power of 2")
}
if (initialCapacity >= (1 << 30)) {
throw new IllegalArgumentException("Can't make capacity bigger than 2^29 elements")
}
private var capacity = initialCapacity
private var curSize = 0
// Holds keys and values in the same array for memory locality; specifically, the order of
// elements is key0, value0, key1, value1, key2, value2, etc.
private var data = new Array[AnyRef](2 * capacity)
// Treat the null key differently so we can use nulls in "data" to represent empty items.
private var haveNullValue = false
private var nullValue: V = null.asInstanceOf[V]
private val LOAD_FACTOR = 0.7
/** Get the value for a given key */
def apply(key: K): V = {
val k = key.asInstanceOf[AnyRef]
if (k.eq(null)) {
return nullValue
}
val mask = capacity - 1
var pos = rehash(k.hashCode) & mask
var i = 1
while (true) {
val curKey = data(2 * pos)
if (curKey.eq(k) || curKey.eq(null) || curKey == k) {
return data(2 * pos + 1).asInstanceOf[V]
} else {
val delta = i
pos = (pos + delta) & mask
i += 1
}
}
return null.asInstanceOf[V]
}
/** Set the value for a key */
def update(key: K, value: V) {
val k = key.asInstanceOf[AnyRef]
if (k.eq(null)) {
if (!haveNullValue) {
incrementSize()
}
nullValue = value
haveNullValue = true
return
}
val isNewEntry = putInto(data, k, value.asInstanceOf[AnyRef])
if (isNewEntry) {
incrementSize()
}
}
/**
* Set the value for key to updateFunc(hadValue, oldValue), where oldValue will be the old value
* for key, if any, or null otherwise. Returns the newly updated value.
*/
def changeValue(key: K, updateFunc: (Boolean, V) => V): V = {
val k = key.asInstanceOf[AnyRef]
if (k.eq(null)) {
if (!haveNullValue) {
incrementSize()
}
nullValue = updateFunc(haveNullValue, nullValue)
haveNullValue = true
return nullValue
}
val mask = capacity - 1
var pos = rehash(k.hashCode) & mask
var i = 1
while (true) {
val curKey = data(2 * pos)
if (curKey.eq(null)) {
val newValue = updateFunc(false, null.asInstanceOf[V])
data(2 * pos) = k
data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
incrementSize()
return newValue
} else if (curKey.eq(k) || curKey == k) {
val newValue = updateFunc(true, data(2*pos + 1).asInstanceOf[V])
data(2*pos + 1) = newValue.asInstanceOf[AnyRef]
return newValue
} else {
val delta = i
pos = (pos + delta) & mask
i += 1
}
}
null.asInstanceOf[V] // Never reached but needed to keep compiler happy
}
/** Iterator method from Iterable */
override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] {
var pos = -1
/** Get the next value we should return from next(), or null if we're finished iterating */
def nextValue(): (K, V) = {
if (pos == -1) { // Treat position -1 as looking at the null value
if (haveNullValue) {
return (null.asInstanceOf[K], nullValue)
}
pos += 1
}
while (pos < capacity) {
if (!data(2 * pos).eq(null)) {
return (data(2 * pos).asInstanceOf[K], data(2 * pos + 1).asInstanceOf[V])
}
pos += 1
}
null
}
override def hasNext: Boolean = nextValue() != null
override def next(): (K, V) = {
val value = nextValue()
if (value == null) {
throw new NoSuchElementException("End of iterator")
}
pos += 1
value
}
}
override def size: Int = curSize
/** Increase table size by 1, rehashing if necessary */
private def incrementSize() {
curSize += 1
if (curSize > LOAD_FACTOR * capacity) {
growTable()
}
}
/**
* Re-hash a value to deal better with hash functions that don't differ
* in the lower bits, similar to java.util.HashMap
*/
private def rehash(h: Int): Int = {
val r = h ^ (h >>> 20) ^ (h >>> 12)
r ^ (r >>> 7) ^ (r >>> 4)
}
/**
* Put an entry into a table represented by data, returning true if
* this increases the size of the table or false otherwise. Assumes
* that "data" has at least one empty slot.
*/
private def putInto(data: Array[AnyRef], key: AnyRef, value: AnyRef): Boolean = {
val mask = (data.length / 2) - 1
var pos = rehash(key.hashCode) & mask
var i = 1
while (true) {
val curKey = data(2 * pos)
if (curKey.eq(null)) {
data(2 * pos) = key
data(2 * pos + 1) = value.asInstanceOf[AnyRef]
return true
} else if (curKey.eq(key) || curKey == key) {
data(2 * pos + 1) = value.asInstanceOf[AnyRef]
return false
} else {
val delta = i
pos = (pos + delta) & mask
i += 1
}
}
return false // Never reached but needed to keep compiler happy
}
/** Double the table's size and re-hash everything */
private def growTable() {
val newCapacity = capacity * 2
if (newCapacity >= (1 << 30)) {
// We can't make the table this big because we want an array of 2x
// that size for our data, but array sizes are at most Int.MaxValue
throw new Exception("Can't make capacity bigger than 2^29 elements")
}
val newData = new Array[AnyRef](2 * newCapacity)
var pos = 0
while (pos < capacity) {
if (!data(2 * pos).eq(null)) {
putInto(newData, data(2 * pos), data(2 * pos + 1))
}
pos += 1
}
data = newData
capacity = newCapacity
}
private def isPowerOf2(num: Int): Boolean = {
var n = num
while (n > 0) {
if (n == 1) {
return true
} else if (n % 2 == 1) {
return false
} else {
n /= 2
}
}
return false
}
}

View file

@ -44,7 +44,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
}
val d = sc.parallelize(1 to 1e4.toInt, 64).map{i => w(i)}
d.count
d.count()
Thread.sleep(1000)
listener.stageInfos.size should be (1)
@ -55,7 +55,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
val d4 = d2.cogroup(d3, 64).map{case(k,(v1,v2)) => w(k) -> (v1.size, v2.size)}
d4.setName("A Cogroup")
d4.collectAsMap
d4.collectAsMap()
Thread.sleep(1000)
listener.stageInfos.size should be (4)

View file

@ -0,0 +1,141 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package spark.util
import scala.collection.mutable.HashSet
import org.scalatest.FunSuite
class AppendOnlyMapSuite extends FunSuite {
test("initialization") {
val goodMap1 = new AppendOnlyMap[Int, Int](1)
assert(goodMap1.size === 0)
val goodMap2 = new AppendOnlyMap[Int, Int](256)
assert(goodMap2.size === 0)
intercept[IllegalArgumentException] {
new AppendOnlyMap[Int, Int](255) // Invalid map size: not power of 2
}
intercept[IllegalArgumentException] {
new AppendOnlyMap[Int, Int](1 << 30) // Invalid map size: bigger than 2^29
}
intercept[IllegalArgumentException] {
new AppendOnlyMap[Int, Int](-1) // Invalid map size: not power of 2
}
}
test("object keys and values") {
val map = new AppendOnlyMap[String, String]()
for (i <- 1 to 100) {
map("" + i) = "" + i
}
assert(map.size === 100)
for (i <- 1 to 100) {
assert(map("" + i) === "" + i)
}
assert(map("0") === null)
assert(map("101") === null)
assert(map(null) === null)
val set = new HashSet[(String, String)]
for ((k, v) <- map) { // Test the foreach method
set += ((k, v))
}
assert(set === (1 to 100).map(_.toString).map(x => (x, x)).toSet)
}
test("primitive keys and values") {
val map = new AppendOnlyMap[Int, Int]()
for (i <- 1 to 100) {
map(i) = i
}
assert(map.size === 100)
for (i <- 1 to 100) {
assert(map(i) === i)
}
assert(map(0) === null)
assert(map(101) === null)
val set = new HashSet[(Int, Int)]
for ((k, v) <- map) { // Test the foreach method
set += ((k, v))
}
assert(set === (1 to 100).map(x => (x, x)).toSet)
}
test("null keys") {
val map = new AppendOnlyMap[String, String]()
for (i <- 1 to 100) {
map("" + i) = "" + i
}
assert(map.size === 100)
assert(map(null) === null)
map(null) = "hello"
assert(map.size === 101)
assert(map(null) === "hello")
}
test("null values") {
val map = new AppendOnlyMap[String, String]()
for (i <- 1 to 100) {
map("" + i) = null
}
assert(map.size === 100)
assert(map("1") === null)
assert(map(null) === null)
assert(map.size === 100)
map(null) = null
assert(map.size === 101)
assert(map(null) === null)
}
test("changeValue") {
val map = new AppendOnlyMap[String, String]()
for (i <- 1 to 100) {
map("" + i) = "" + i
}
assert(map.size === 100)
for (i <- 1 to 100) {
val res = map.changeValue("" + i, (hadValue, oldValue) => {
assert(hadValue === true)
assert(oldValue === "" + i)
oldValue + "!"
})
assert(res === i + "!")
}
// Iterate from 101 to 400 to make sure the map grows a couple of times, because we had a
// bug where changeValue would return the wrong result when the map grew on that insert
for (i <- 101 to 400) {
val res = map.changeValue("" + i, (hadValue, oldValue) => {
assert(hadValue === false)
i + "!"
})
assert(res === i + "!")
}
assert(map.size === 400)
assert(map(null) === null)
map.changeValue(null, (hadValue, oldValue) => {
assert(hadValue === false)
"null!"
})
assert(map.size === 401)
map.changeValue(null, (hadValue, oldValue) => {
assert(hadValue === true)
assert(oldValue === "null!")
"null!!"
})
assert(map.size === 401)
}
}