From 9148b968cf34b898c36a7f9672382533ee54db2d Mon Sep 17 00:00:00 2001
From: Mark Hamstra
Date: Mon, 4 Mar 2013 15:48:47 -0800
Subject: [PATCH] mapWith, flatMapWith and filterWith

---
 core/src/main/scala/spark/RDD.scala      | 59 ++++++++++++++++++++-
 core/src/test/scala/spark/RDDSuite.scala | 66 ++++++++++++++++++++++++
 2 files changed, 124 insertions(+), 1 deletion(-)

diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 584efa8adf..de791598a4 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -364,6 +364,63 @@ abstract class RDD[T: ClassManifest](
     preservesPartitioning: Boolean = false): RDD[U] =
     new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning)
 
+  /**
+   * Maps f over this RDD where f takes an additional parameter of type A. This
+   * additional parameter is produced by a factory method T => A which is called
+   * on each invocation of f. This factory method is produced by the factoryBuilder,
+   * an instance of which is constructed in each partition from the partition index
+   * and a seed value of type B.
+   */
+  def mapWith[A: ClassManifest, B: ClassManifest, U: ClassManifest](
+      f:(A, T) => U,
+      factoryBuilder: (Int, B) => (T => A),
+      factorySeed: B,
+      preservesPartitioning: Boolean = false): RDD[U] = {
+    def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
+      val factory = factoryBuilder(index, factorySeed)
+      iter.map(t => f(factory(t), t))
+    }
+    new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
+  }
+
+  /**
+   * FlatMaps f over this RDD where f takes an additional parameter of type A. This
+   * additional parameter is produced by a factory method T => A which is called
+   * on each invocation of f. This factory method is produced by the factoryBuilder,
+   * an instance of which is constructed in each partition from the partition index
+   * and a seed value of type B.
+   */
+  def flatMapWith[A: ClassManifest, B: ClassManifest, U: ClassManifest](
+      f:(A, T) => Seq[U],
+      factoryBuilder: (Int, B) => (T => A),
+      factorySeed: B,
+      preservesPartitioning: Boolean = false): RDD[U] = {
+    def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
+      val factory = factoryBuilder(index, factorySeed)
+      iter.flatMap(t => f(factory(t), t))
+    }
+    new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
+  }
+
+  /**
+   * Filters this RDD with p, where p takes an additional parameter of type A. This
+   * additional parameter is produced by a factory method T => A which is called
+   * on each invocation of p. This factory method is produced by the factoryBuilder,
+   * an instance of which is constructed in each partition from the partition index
+   * and a seed value of type B.
+   */
+  def filterWith[A: ClassManifest, B: ClassManifest](
+      p:(A, T) => Boolean,
+      factoryBuilder: (Int, B) => (T => A),
+      factorySeed: B,
+      preservesPartitioning: Boolean = false): RDD[T] = {
+    def iterF(index: Int, iter: Iterator[T]): Iterator[T] = {
+      val factory = factoryBuilder(index, factorySeed)
+      iter.filter(t => p(factory(t), t))
+    }
+    new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
+  }
+
   /**
    * Zips this RDD with another one, returning key-value pairs with the first element in each RDD,
    * second element in each RDD, etc. Assumes that the two RDDs have the *same number of
@@ -404,7 +461,7 @@ abstract class RDD[T: ClassManifest](
 
   /**
    * Return an RDD with the elements from `this` that are not in `other`.
-   * 
+   *
    * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
    * RDD will be <= us.
    */
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index 9739ba869b..ced8170300 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -178,4 +178,70 @@ class RDDSuite extends FunSuite with LocalSparkContext {
     assert(prunedData.size === 1)
     assert(prunedData(0) === 10)
   }
+
+  test("mapWith") {
+    sc = new SparkContext("local", "test")
+    val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2)
+    val randoms = ones.mapWith(
+      (random: Double, t: Int) => random * t,
+      (index: Int, seed: Int) => {
+        val prng = new java.util.Random(index + seed)
+        (_ => prng.nextDouble)},
+      42).
+      collect()
+    val prn42_3 = {
+      val prng42 = new java.util.Random(42)
+      prng42.nextDouble(); prng42.nextDouble(); prng42.nextDouble()
+    }
+    val prn43_3 = {
+      val prng43 = new java.util.Random(43)
+      prng43.nextDouble(); prng43.nextDouble(); prng43.nextDouble()
+    }
+    assert(randoms(2) === prn42_3)
+    assert(randoms(5) === prn43_3)
+  }
+
+  test("flatMapWith") {
+    sc = new SparkContext("local", "test")
+    val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2)
+    val randoms = ones.flatMapWith(
+      (random: Double, t: Int) => Seq(random * t, random * t * 10),
+      (index: Int, seed: Int) => {
+        val prng = new java.util.Random(index + seed)
+        (_ => prng.nextDouble)},
+      42).
+      collect()
+    val prn42_3 = {
+      val prng42 = new java.util.Random(42)
+      prng42.nextDouble(); prng42.nextDouble(); prng42.nextDouble()
+    }
+    val prn43_3 = {
+      val prng43 = new java.util.Random(43)
+      prng43.nextDouble(); prng43.nextDouble(); prng43.nextDouble()
+    }
+    assert(randoms(5) === prn42_3 * 10)
+    assert(randoms(11) === prn43_3 * 10)
+  }
+
+  test("filterWith") {
+    import java.util.Random
+    sc = new SparkContext("local", "test")
+    val ints = sc.makeRDD(Array(1, 2, 3, 4, 5, 6), 2)
+    val sample = ints.filterWith(
+      (random: Int, t: Int) => random == 0,
+      (index: Int, seed: Int) => {
+        val prng = new Random(index + seed)
+        (_ => prng.nextInt(3))},
+      42).
+      collect()
+    val checkSample = {
+      val prng42 = new Random(42)
+      val prng43 = new Random(43)
+      Array(1, 2, 3, 4, 5, 6).filter{i =>
+        if (i < 4) 0 == prng42.nextInt(3)
+        else 0 == prng43.nextInt(3)}
+    }
+    assert(sample.size === checkSample.size)
+    for (i <- 0 until sample.size) assert(sample(i) === checkSample(i))
+  }
 }
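
For reference, a minimal usage sketch of the new operators (not part of the patch; `sc`, `points`, `jittered`, `sampled`, and the seed 42 are illustrative names and values). The intended pattern is one deterministically seeded PRNG per partition, built by factoryBuilder from the partition index and the seed, so per-element randomness is reproducible for a fixed partitioning:

  // Illustrative sketch only: assumes an existing SparkContext `sc`.
  val points = sc.parallelize(1 to 1000, 4)

  // Jitter each value with a partition-local PRNG seeded from (partition index + seed).
  val jittered = points.mapWith(
    (r: Double, x: Int) => x + r,                  // f: combines the element with the per-element Double
    (index: Int, seed: Int) => {                   // factoryBuilder: one PRNG per partition
      val prng = new java.util.Random(index + seed)
      (_: Int) => prng.nextDouble                  // factory T => A: a fresh draw per element
    },
    42)                                            // factorySeed

  // Keep roughly a third of the elements, again reproducibly per partition.
  val sampled = points.filterWith(
    (r: Int, t: Int) => r == 0,
    (index: Int, seed: Int) => {
      val prng = new java.util.Random(index + seed)
      (_: Int) => prng.nextInt(3)
    },
    42)

flatMapWith follows the same shape, with f returning a Seq[U] per element; any type with a ClassManifest can serve as the seed type B.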