mapWith, flatMapWith and filterWith
This commit is contained in:
parent
9f0dc829cb
commit
9148b968cf
|
@ -364,6 +364,63 @@ abstract class RDD[T: ClassManifest](
|
|||
preservesPartitioning: Boolean = false): RDD[U] =
|
||||
new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning)
|
||||
|
||||
/**
|
||||
* Maps f over this RDD where f takes an additional parameter of type A. This
|
||||
* additional parameter is produced by a factory method T => A which is called
|
||||
* on each invocation of f. This factory method is produced by the factoryBuilder,
|
||||
* an instance of which is constructed in each partition from the partition index
|
||||
* and a seed value of type B.
|
||||
*/
|
||||
def mapWith[A: ClassManifest, B: ClassManifest, U: ClassManifest](
|
||||
f:(A, T) => U,
|
||||
factoryBuilder: (Int, B) => (T => A),
|
||||
factorySeed: B,
|
||||
preservesPartitioning: Boolean = false): RDD[U] = {
|
||||
def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
|
||||
val factory = factoryBuilder(index, factorySeed)
|
||||
iter.map(t => f(factory(t), t))
|
||||
}
|
||||
new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
|
||||
}
|
||||
|
||||
/**
 * FlatMaps f over this RDD, where f takes an additional parameter of type A. This
 * additional parameter is produced by a factory method T => A which is called
 * on each invocation of f. This factory method is produced by the factoryBuilder,
 * an instance of which is constructed in each partition from the partition index
 * and a seed value of type B.
 */
def flatMapWith[A: ClassManifest, B: ClassManifest, U: ClassManifest](
    f:(A, T) => Seq[U],
    factoryBuilder: (Int, B) => (T => A),
    factorySeed: B,
    preservesPartitioning: Boolean = false): RDD[U] = {
  // Per-partition transform: construct the factory once per partition, then
  // flatten the Seq that f produces for every element.
  def transformPartition(index: Int, elements: Iterator[T]): Iterator[U] = {
    val makeParam = factoryBuilder(index, factorySeed)
    elements.flatMap { element => f(makeParam(element), element) }
  }
  new MapPartitionsWithIndexRDD(this, sc.clean(transformPartition _), preservesPartitioning)
}
|
||||
|
||||
/**
 * Filters this RDD with p, where p takes an additional parameter of type A. This
 * additional parameter is produced by a factory method T => A which is called
 * on each invocation of p. This factory method is produced by the factoryBuilder,
 * an instance of which is constructed in each partition from the partition index
 * and a seed value of type B.
 */
def filterWith[A: ClassManifest, B: ClassManifest](
    p:(A, T) => Boolean,
    factoryBuilder: (Int, B) => (T => A),
    factorySeed: B,
    preservesPartitioning: Boolean = false): RDD[T] = {
  // Per-partition transform: build the factory once, then keep only the
  // elements for which the predicate holds.
  def transformPartition(index: Int, elements: Iterator[T]): Iterator[T] = {
    val makeParam = factoryBuilder(index, factorySeed)
    elements.filter { element => p(makeParam(element), element) }
  }
  new MapPartitionsWithIndexRDD(this, sc.clean(transformPartition _), preservesPartitioning)
}
|
||||
|
||||
/**
|
||||
* Zips this RDD with another one, returning key-value pairs with the first element in each RDD,
|
||||
* second element in each RDD, etc. Assumes that the two RDDs have the *same number of
|
||||
|
@ -404,7 +461,7 @@ abstract class RDD[T: ClassManifest](
|
|||
|
||||
/**
 * Return an RDD with the elements from `this` that are not in `other`.
 *
 * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
 * RDD will be <= us.
 */
|
||||
|
|
|
@ -178,4 +178,70 @@ class RDDSuite extends FunSuite with LocalSparkContext {
|
|||
assert(prunedData.size === 1)
|
||||
assert(prunedData(0) === 10)
|
||||
}
|
||||
|
||||
test("mapWith") {
  sc = new SparkContext("local", "test")
  // Two partitions of three 1s each: partition 0 seeds its PRNG with 42,
  // partition 1 with 43 (index + seed).
  val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2)
  val randoms = ones.mapWith(
      (random: Double, t: Int) => random * t,
      (index: Int, seed: Int) => {
        val prng = new java.util.Random(index + seed)
        (_ => prng.nextDouble)},
      42).
    collect()
  // Reproduce the third draw from a PRNG with the given seed, matching the
  // last element of each three-element partition.
  def thirdDraw(seed: Int): Double = {
    val prng = new java.util.Random(seed)
    prng.nextDouble(); prng.nextDouble(); prng.nextDouble()
  }
  assert(randoms(2) === thirdDraw(42))
  assert(randoms(5) === thirdDraw(43))
}
|
||||
|
||||
test("flatMapWith") {
  sc = new SparkContext("local", "test")
  // Two partitions of three 1s each; every element expands to two outputs
  // (random, random * 10), so each partition contributes six values.
  val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2)
  val randoms = ones.flatMapWith(
      (random: Double, t: Int) => Seq(random * t, random * t * 10),
      (index: Int, seed: Int) => {
        val prng = new java.util.Random(index + seed)
        (_ => prng.nextDouble)},
      42).
    collect()
  // Third draw from a PRNG with the given seed — the random used for the
  // last element of each partition.
  def thirdDraw(seed: Int): Double = {
    val prng = new java.util.Random(seed)
    prng.nextDouble(); prng.nextDouble(); prng.nextDouble()
  }
  // Index 5 is the *10 output of partition 0's third element; index 11 is
  // the *10 output of partition 1's third element.
  assert(randoms(5) === thirdDraw(42) * 10)
  assert(randoms(11) === thirdDraw(43) * 10)
}
|
||||
|
||||
test("filterWith") {
  import java.util.Random
  sc = new SparkContext("local", "test")
  // Two partitions: {1,2,3} with PRNG seed 42, {4,5,6} with PRNG seed 43.
  val ints = sc.makeRDD(Array(1, 2, 3, 4, 5, 6), 2)
  val sample = ints.filterWith(
      (random: Int, t: Int) => random == 0,
      (index: Int, seed: Int) => {
        val prng = new Random(index + seed)
        (_ => prng.nextInt(3))},
      42).
    collect()
  // Mirror the per-partition PRNG draws to compute the expected survivors.
  val expected = {
    val prng0 = new Random(42)
    val prng1 = new Random(43)
    Array(1, 2, 3, 4, 5, 6).filter { i =>
      val draw = if (i < 4) prng0.nextInt(3) else prng1.nextInt(3)
      draw == 0
    }
  }
  assert(sample.size === expected.size)
  sample.zip(expected).foreach { case (actual, want) => assert(actual === want) }
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue