Merge pull request #353 from stephenh/tupleBy

Add RDD.tupleBy.
This commit is contained in:
Matei Zaharia 2013-01-08 22:24:03 -08:00
commit d0bae072ea
4 changed files with 28 additions and 0 deletions

View file

@ -517,6 +517,13 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
.saveAsSequenceFile(path)
}
/**
* Creates tuples of the elements in this RDD by applying `f`.
*/
def keyBy[K](f: T => K): RDD[(K, T)] = {
map(x => (f(x), x))
}
/** A private method for tests, to look at the contents of each partition */
private[spark] def collectPartitions(): Array[Array[T]] = {
sc.runJob(this, (iter: Iterator[T]) => iter.toArray)

View file

@ -298,4 +298,12 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* Save this RDD as a SequenceFile of serialized objects.
*/
def saveAsObjectFile(path: String) = rdd.saveAsObjectFile(path)
/**
* Creates tuples of the elements in this RDD by applying `f`.
*/
def keyBy[K](f: JFunction[T, K]): JavaPairRDD[K, T] = {
implicit val kcm: ClassManifest[K] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
JavaPairRDD.fromRDD(rdd.keyBy(f))
}
}

View file

@ -629,4 +629,16 @@ public class JavaAPISuite implements Serializable {
floatAccum.setValue(5.0f);
Assert.assertEquals((Float) 5.0f, floatAccum.value());
}
@Test
public void keyBy() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2));
List<Tuple2<String, Integer>> s = rdd.keyBy(new Function<Integer, String>() {
public String call(Integer t) throws Exception {
return t.toString();
}
}).collect();
Assert.assertEquals(new Tuple2<String, Integer>("1", 1), s.get(0));
Assert.assertEquals(new Tuple2<String, Integer>("2", 2), s.get(1));
}
}

View file

@ -36,6 +36,7 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
assert(nums.union(nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4))
assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4)))
assert(nums.collect({ case i if i >= 3 => i.toString }).collect().toList === List("3", "4"))
assert(nums.keyBy(_.toString).collect().toList === List(("1", 1), ("2", 2), ("3", 3), ("4", 4)))
val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _)))
assert(partitionSums.collect().toList === List(3, 7))