Added fold() and aggregate() operations that reuse an object to
merge results into rather than requiring a new object allocation for each element merged. Fixes #95.
This commit is contained in:
parent
09dd58b3a7
commit
22b8fcf632
|
@ -133,7 +133,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
|
|||
val cleanF = sc.clean(f)
|
||||
val reducePartition: Iterator[T] => Option[T] = iter => {
|
||||
if (iter.hasNext)
|
||||
Some(iter.reduceLeft(f))
|
||||
Some(iter.reduceLeft(cleanF))
|
||||
else
|
||||
None
|
||||
}
|
||||
|
@ -144,7 +144,36 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
|
|||
if (results.size == 0)
|
||||
throw new UnsupportedOperationException("empty collection")
|
||||
else
|
||||
return results.reduceLeft(f)
|
||||
return results.reduceLeft(cleanF)
|
||||
}
|
||||
|
||||
/**
 * Aggregate the elements of each partition, and then the results for all the
 * partitions, using a given associative function and a neutral "zero value".
 * The function op(t1, t2) is allowed to modify t1 and return it as its result
 * value to avoid object allocation; however, it should not modify t2.
 *
 * @param zeroValue neutral element for op (e.g. 0 for +, Nil for list concat)
 * @param op associative function used to merge two T values
 * @return the folded result, or zeroValue for an empty RDD
 */
def fold(zeroValue: T)(op: (T, T) => T): T = {
  // Clean the closure before shipping it to the cluster.
  val cleanOp = sc.clean(op)
  // Fold within each partition on the workers, then fold the per-partition
  // results on the driver, starting again from the same zero value.
  val results = sc.runJob(this, (iter: Iterator[T]) => iter.fold(zeroValue)(cleanOp))
  results.fold(zeroValue)(cleanOp)
}
|
||||
|
||||
/**
 * Aggregate the elements of each partition, and then the results for all the
 * partitions, using given combine functions and a neutral "zero value". This
 * function can return a different result type, U, than the type of this RDD, T.
 * Thus, we need one operation for merging a T into an U and one operation for
 * merging two U's, as in scala.TraversableOnce. Both of these functions are
 * allowed to modify and return their first argument instead of creating a new U
 * to avoid memory allocation.
 *
 * @param zeroValue neutral starting value for both seqOp and combOp
 * @param seqOp merges a T element into a U accumulator (may mutate the U)
 * @param combOp merges two U accumulators (may mutate its first argument)
 * @return the aggregated result, or zeroValue for an empty RDD
 */
def aggregate[U: ClassManifest](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = {
  // Clean both closures before shipping them to the cluster.
  val cleanSeqOp = sc.clean(seqOp)
  val cleanCombOp = sc.clean(combOp)
  // Aggregate each partition on the workers with seqOp/combOp, then combine
  // the per-partition results on the driver using combOp alone.
  val results = sc.runJob(this,
    (iter: Iterator[T]) => iter.aggregate(zeroValue)(cleanSeqOp, cleanCombOp))
  results.fold(zeroValue)(cleanCombOp)
}
|
||||
|
||||
def count(): Long = {
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
package spark
|
||||
|
||||
import scala.collection.mutable.HashMap
|
||||
import org.scalatest.FunSuite
|
||||
import SparkContext._
|
||||
|
||||
|
@ -9,6 +10,7 @@ class RDDSuite extends FunSuite {
|
|||
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
|
||||
assert(nums.collect().toList === List(1, 2, 3, 4))
|
||||
assert(nums.reduce(_ + _) === 10)
|
||||
assert(nums.fold(0)(_ + _) === 10)
|
||||
assert(nums.map(_.toString).collect().toList === List("1", "2", "3", "4"))
|
||||
assert(nums.filter(_ > 2).collect().toList === List(3, 4))
|
||||
assert(nums.flatMap(x => 1 to x).collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4))
|
||||
|
@ -18,4 +20,26 @@ class RDDSuite extends FunSuite {
|
|||
assert(partitionSums.collect().toList === List(3, 7))
|
||||
sc.stop()
|
||||
}
|
||||
|
||||
test("aggregate") {
  val sc = new SparkContext("local", "test")
  val pairs = sc.makeRDD(Array(("a", 1), ("b", 2), ("a", 2), ("c", 5), ("a", 3)))
  type StringMap = HashMap[String, Int]
  // A map that treats absent keys as 0, so counts can be accumulated in place.
  val emptyMap = new StringMap {
    override def default(key: String): Int = 0
  }
  // Folds one (key, count) pair into the accumulator map, mutating and
  // returning it — exercises the allocation-free seqOp contract.
  val mergeElement: (StringMap, (String, Int)) => StringMap = (acc, kv) => {
    acc(kv._1) += kv._2
    acc
  }
  // Merges the second map into the first, mutating and returning the first —
  // exercises the allocation-free combOp contract.
  val mergeMaps: (StringMap, StringMap) => StringMap = (acc, other) => {
    other.foreach { case (key, value) => acc(key) += value }
    acc
  }
  val result = pairs.aggregate(emptyMap)(mergeElement, mergeMaps)
  assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5)))
  sc.stop()
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue