From 22b8fcf632de8e00d25ab529eb347ed2288e20fc Mon Sep 17 00:00:00 2001
From: Matei Zaharia
Date: Wed, 30 Nov 2011 11:36:36 -0800
Subject: [PATCH] Added fold() and aggregate() operations that reuse an object
 to merge results into rather than requiring a new object allocation for each
 element merged. Fixes #95.

---
 core/src/main/scala/spark/RDD.scala      | 33 ++++++++++++++++++++++--
 core/src/test/scala/spark/RDDSuite.scala | 24 +++++++++++++++++
 2 files changed, 55 insertions(+), 2 deletions(-)

diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 445d520bc2..eb31e90123 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -133,7 +133,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
     val cleanF = sc.clean(f)
     val reducePartition: Iterator[T] => Option[T] = iter => {
       if (iter.hasNext)
-        Some(iter.reduceLeft(f))
+        Some(iter.reduceLeft(cleanF))
       else
         None
     }
@@ -144,7 +144,36 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
     if (results.size == 0)
       throw new UnsupportedOperationException("empty collection")
     else
-      return results.reduceLeft(f)
+      return results.reduceLeft(cleanF)
+  }
+
+  /**
+   * Aggregate the elements of each partition, and then the results for all the
+   * partitions, using a given associative function and a neutral "zero value".
+   * The function op(t1, t2) is allowed to modify t1 and return it as its result
+   * value to avoid object allocation; however, it should not modify t2.
+   */
+  def fold(zeroValue: T)(op: (T, T) => T): T = {
+    val cleanOp = sc.clean(op)
+    val results = sc.runJob(this, (iter: Iterator[T]) => iter.fold(zeroValue)(cleanOp))
+    return results.fold(zeroValue)(cleanOp)
+  }
+
+  /**
+   * Aggregate the elements of each partition, and then the results for all the
+   * partitions, using given combine functions and a neutral "zero value". This
+   * function can return a different result type, U, than the type of this RDD, T.
+   * Thus, we need one operation for merging a T into an U and one operation for
+   * merging two U's, as in scala.TraversableOnce. Both of these functions are
+   * allowed to modify and return their first argument instead of creating a new U
+   * to avoid memory allocation.
+   */
+  def aggregate[U: ClassManifest](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = {
+    val cleanSeqOp = sc.clean(seqOp)
+    val cleanCombOp = sc.clean(combOp)
+    val results = sc.runJob(this,
+      (iter: Iterator[T]) => iter.aggregate(zeroValue)(cleanSeqOp, cleanCombOp))
+    return results.fold(zeroValue)(cleanCombOp)
   }
 
   def count(): Long = {
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index 06d438d9e2..7199b634b7 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -1,5 +1,6 @@
 package spark
 
+import scala.collection.mutable.HashMap
 import org.scalatest.FunSuite
 import SparkContext._
 
@@ -9,6 +10,7 @@ class RDDSuite extends FunSuite {
     val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
     assert(nums.collect().toList === List(1, 2, 3, 4))
     assert(nums.reduce(_ + _) === 10)
+    assert(nums.fold(0)(_ + _) === 10)
     assert(nums.map(_.toString).collect().toList === List("1", "2", "3", "4"))
     assert(nums.filter(_ > 2).collect().toList === List(3, 4))
     assert(nums.flatMap(x => 1 to x).collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4))
@@ -18,4 +20,26 @@ class RDDSuite extends FunSuite {
     assert(partitionSums.collect().toList === List(3, 7))
     sc.stop()
   }
+
+  test("aggregate") {
+    val sc = new SparkContext("local", "test")
+    val pairs = sc.makeRDD(Array(("a", 1), ("b", 2), ("a", 2), ("c", 5), ("a", 3)))
+    type StringMap = HashMap[String, Int]
+    val emptyMap = new StringMap {
+      override def default(key: String): Int = 0
+    }
+    val mergeElement: (StringMap, (String, Int)) => StringMap = (map, pair) => {
+      map(pair._1) += pair._2
+      map
+    }
+    val mergeMaps: (StringMap, StringMap) => StringMap = (map1, map2) => {
+      for ((key, value) <- map2) {
+        map1(key) += value
+      }
+      map1
+    }
+    val result = pairs.aggregate(emptyMap)(mergeElement, mergeMaps)
+    assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5)))
+    sc.stop()
+  }
 }
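
Note: for reviewers, a minimal driver-side sketch of how the two new operations would be called, assuming a local SparkContext as in RDDSuite (the histogram logic mirrors the "aggregate" test above; names like "example" are illustrative only):

    import scala.collection.mutable.HashMap
    import spark.SparkContext

    val sc = new SparkContext("local", "example")

    // fold(): element type and result type are the same. The zero value
    // must be neutral for op, since it is applied once per partition and
    // once more when merging the per-partition results.
    val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
    val sum = nums.fold(0)(_ + _)  // 10

    // aggregate(): result type (HashMap) differs from element type
    // ((String, Int)). seqOp folds one element into a map, combOp merges
    // two maps; both may mutate and return their first argument, which is
    // the allocation-saving behavior this patch enables.
    val pairs = sc.makeRDD(Array(("a", 1), ("b", 2), ("a", 2)))
    val counts = pairs.aggregate(new HashMap[String, Int] {
      override def default(key: String): Int = 0
    })(
      (map, pair) => { map(pair._1) += pair._2; map },
      (map1, map2) => { for ((k, v) <- map2) map1(k) += v; map1 }
    )  // contains ("a", 3), ("b", 2)

    sc.stop()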