Added fold() and aggregate() operations that reuse an object to merge results
into rather than requiring a new object allocation for each element merged.
Fixes #95.
Matei Zaharia 2011-11-30 11:36:36 -08:00
parent 09dd58b3a7
commit 22b8fcf632
2 changed files with 55 additions and 2 deletions
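
For orientation, here is a rough usage sketch of the two new operations. It is
illustrative only: the SparkContext sc and the sample data are assumed, and it
mirrors the tests added below rather than anything that is part of this commit.

import scala.collection.mutable.HashMap

// fold: the result has the same type as the elements; a plain sum needs no reuse.
val total = sc.makeRDD(Array(1, 2, 3, 4), 2).fold(0)(_ + _)  // 10

// aggregate: the accumulator type (a mutable HashMap) differs from the element
// type (String). Both functions mutate and return their first argument, so no
// new object is allocated for each element merged.
val words = sc.makeRDD(Array("a", "b", "a", "c"))
val zero: HashMap[String, Int] = new HashMap[String, Int] {
  override def default(key: String): Int = 0
}
val counts = words.aggregate(zero)(
  (m, w) => { m(w) += 1; m },
  (m1, m2) => { for ((k, v) <- m2) m1(k) += v; m1 }
)
// counts.toSet == Set(("a", 2), ("b", 1), ("c", 1))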

@@ -133,7 +133,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
     val cleanF = sc.clean(f)
     val reducePartition: Iterator[T] => Option[T] = iter => {
       if (iter.hasNext)
-        Some(iter.reduceLeft(f))
+        Some(iter.reduceLeft(cleanF))
       else
         None
     }
@@ -144,7 +144,36 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
     if (results.size == 0)
       throw new UnsupportedOperationException("empty collection")
     else
-      return results.reduceLeft(f)
+      return results.reduceLeft(cleanF)
   }
+
+  /**
+   * Aggregate the elements of each partition, and then the results for all the
+   * partitions, using a given associative function and a neutral "zero value".
+   * The function op(t1, t2) is allowed to modify t1 and return it as its result
+   * value to avoid object allocation; however, it should not modify t2.
+   */
+  def fold(zeroValue: T)(op: (T, T) => T): T = {
+    val cleanOp = sc.clean(op)
+    val results = sc.runJob(this, (iter: Iterator[T]) => iter.fold(zeroValue)(cleanOp))
+    return results.fold(zeroValue)(cleanOp)
+  }
+
+  /**
+   * Aggregate the elements of each partition, and then the results for all the
+   * partitions, using given combine functions and a neutral "zero value". This
+   * function can return a different result type, U, than the type of this RDD, T.
+   * Thus, we need one operation for merging a T into a U and one operation for
+   * merging two U's, as in scala.TraversableOnce. Both of these functions are
+   * allowed to modify and return their first argument instead of creating a new U
+   * to avoid memory allocation.
+   */
+  def aggregate[U: ClassManifest](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = {
+    val cleanSeqOp = sc.clean(seqOp)
+    val cleanCombOp = sc.clean(combOp)
+    val results = sc.runJob(this,
+      (iter: Iterator[T]) => iter.aggregate(zeroValue)(cleanSeqOp, cleanCombOp))
+    return results.fold(zeroValue)(cleanCombOp)
+  }
 
   def count(): Long = {
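
One consequence of the implementation above is worth noting: zeroValue is folded
in once per partition and then once more when the per-partition results are
combined on the driver (results.fold(zeroValue)(...)), so it really does need to
be neutral for the combine operation. A rough illustration with fold, assuming a
2-partition RDD as in the tests below:

// The extra 1 is added once in each of the 2 partitions and once on the driver,
// so this yields 13 rather than 11.
sc.makeRDD(Array(1, 2, 3, 4), 2).fold(1)(_ + _)  // => 13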

@@ -1,5 +1,6 @@
 package spark
+import scala.collection.mutable.HashMap
 import org.scalatest.FunSuite
 import SparkContext._
@@ -9,6 +10,7 @@ class RDDSuite extends FunSuite {
     val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
     assert(nums.collect().toList === List(1, 2, 3, 4))
     assert(nums.reduce(_ + _) === 10)
+    assert(nums.fold(0)(_ + _) === 10)
     assert(nums.map(_.toString).collect().toList === List("1", "2", "3", "4"))
     assert(nums.filter(_ > 2).collect().toList === List(3, 4))
     assert(nums.flatMap(x => 1 to x).collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4))
@@ -18,4 +20,26 @@ class RDDSuite extends FunSuite {
     assert(partitionSums.collect().toList === List(3, 7))
     sc.stop()
   }
+
+  test("aggregate") {
+    val sc = new SparkContext("local", "test")
+    val pairs = sc.makeRDD(Array(("a", 1), ("b", 2), ("a", 2), ("c", 5), ("a", 3)))
+    type StringMap = HashMap[String, Int]
+    val emptyMap = new StringMap {
+      override def default(key: String): Int = 0
+    }
+    val mergeElement: (StringMap, (String, Int)) => StringMap = (map, pair) => {
+      map(pair._1) += pair._2
+      map
+    }
+    val mergeMaps: (StringMap, StringMap) => StringMap = (map1, map2) => {
+      for ((key, value) <- map2) {
+        map1(key) += value
+      }
+      map1
+    }
+    val result = pairs.aggregate(emptyMap)(mergeElement, mergeMaps)
+    assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5)))
+    sc.stop()
+  }
 }