Adding flatMap

Patrick Wendell 2013-01-17 09:04:56 -08:00
parent d5570c7968
commit 61b877c688
2 changed files with 106 additions and 4 deletions

JavaDStreamLike.scala

@@ -64,6 +64,27 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable
     new JavaPairDStream(dstream.map(f)(cm))(f.keyType(), f.valueType())
   }
 
+  /**
+   * Return a new DStream by applying a function to all elements of this DStream,
+   * and then flattening the results
+   */
+  def flatMap[U](f: FlatMapFunction[T, U]): JavaDStream[U] = {
+    import scala.collection.JavaConverters._
+    def fn = (x: T) => f.apply(x).asScala
+    new JavaDStream(dstream.flatMap(fn)(f.elementType()))(f.elementType())
+  }
+
+  /**
+   * Return a new DStream by applying a function to all elements of this DStream,
+   * and then flattening the results
+   */
+  def flatMap[K, V](f: PairFlatMapFunction[T, K, V]): JavaPairDStream[K, V] = {
+    import scala.collection.JavaConverters._
+    def fn = (x: T) => f.apply(x).asScala
+    def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]]
+    new JavaPairDStream(dstream.flatMap(fn)(cm))(f.keyType(), f.valueType())
+  }
+
   /**
    * Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDDs
    * of this DStream. Applying mapPartitions() to an RDD applies a function to each partition
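For reference, a minimal sketch of how the two new overloads might be called from Java; the `lines` input stream and the word-splitting logic are illustrative assumptions, not part of this commit:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import scala.Tuple2;
import spark.api.java.function.FlatMapFunction;
import spark.api.java.function.PairFlatMapFunction;
import spark.streaming.api.java.JavaDStream;
import spark.streaming.api.java.JavaPairDStream;

// 'lines' is an assumed JavaDStream<String>, e.g. built from a streaming context.
JavaDStream<String> words = lines.flatMap(new FlatMapFunction<String, String>() {
  @Override
  public Iterable<String> call(String line) {
    // Each input element may produce zero or more output elements.
    return Arrays.asList(line.split(" "));
  }
});

// The pair variant produces a JavaPairDStream instead of a JavaDStream.
JavaPairDStream<String, Integer> pairs = lines.flatMap(new PairFlatMapFunction<String, String, Integer>() {
  @Override
  public Iterable<Tuple2<String, Integer>> call(String line) throws Exception {
    List<Tuple2<String, Integer>> out = new ArrayList<Tuple2<String, Integer>>();
    for (String word : line.split(" ")) {
      out.add(new Tuple2<String, Integer>(word, 1));
    }
    return out;
  }
});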
@@ -151,4 +172,12 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable
       transformFunc.call(new JavaRDD[T](in), time).rdd
     dstream.transform(scalaTransform(_, _))
   }
+
+  /**
+   * Enable periodic checkpointing of RDDs of this DStream
+   * @param interval Time interval after which generated RDD will be checkpointed
+   */
+  def checkpoint(interval: Duration) = {
+    dstream.checkpoint(interval)
+  }
 }
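The new checkpoint method simply forwards the interval to the wrapped DStream. A minimal usage sketch, assuming `words` is a JavaDStream obtained as above and that Duration is constructed from milliseconds (both assumptions, not shown in this commit):

import spark.streaming.Duration;

// Ask Spark Streaming to checkpoint the RDDs generated by this stream
// roughly every 10 seconds (10000 ms), assuming a checkpoint directory
// has been configured on the streaming context.
words.checkpoint(new Duration(10000));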

JavaAPISuite.java

@@ -11,10 +11,7 @@ import org.junit.Test;
 import scala.Tuple2;
 import spark.HashPartitioner;
 import spark.api.java.JavaRDD;
-import spark.api.java.function.FlatMapFunction;
-import spark.api.java.function.Function;
-import spark.api.java.function.Function2;
-import spark.api.java.function.PairFunction;
+import spark.api.java.function.*;
 import spark.storage.StorageLevel;
 import spark.streaming.api.java.JavaDStream;
 import spark.streaming.api.java.JavaPairDStream;
@@ -308,6 +305,82 @@ public class JavaAPISuite implements Serializable {
     assertOrderInvariantEquals(expected, result);
   }
 
+  @Test
+  public void testFlatMap() {
+    List<List<String>> inputData = Arrays.asList(
+        Arrays.asList("go", "giants"),
+        Arrays.asList("boo", "dodgers"),
+        Arrays.asList("athletics"));
+
+    List<List<String>> expected = Arrays.asList(
+        Arrays.asList("g","o","g","i","a","n","t","s"),
+        Arrays.asList("b", "o", "o", "d","o","d","g","e","r","s"),
+        Arrays.asList("a","t","h","l","e","t","i","c","s"));
+
+    JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1);
+    JavaDStream flatMapped = stream.flatMap(new FlatMapFunction<String, String>() {
+      @Override
+      public Iterable<String> call(String x) {
+        return Lists.newArrayList(x.split("(?!^)"));
+      }
+    });
+    JavaTestUtils.attachTestOutputStream(flatMapped);
+    List<List<String>> result = JavaTestUtils.runStreams(sc, 3, 3);
+
+    assertOrderInvariantEquals(expected, result);
+  }
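Both tests split a word into one-character strings with split("(?!^)"): the zero-width negative lookahead matches at every position except the start of the string, so no leading empty string is produced. A small standalone illustration, not part of the commit:

// "giants".split("(?!^)") -> ["g", "i", "a", "n", "t", "s"]
// A plain split("") would additionally yield a leading empty string on older JDKs,
// which the (?!^) lookahead avoids.
String[] letters = "giants".split("(?!^)");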
+
+  @Test
+  public void testPairFlatMap() {
+    List<List<String>> inputData = Arrays.asList(
+        Arrays.asList("giants"),
+        Arrays.asList("dodgers"),
+        Arrays.asList("athletics"));
+
+    List<List<Tuple2<Integer, String>>> expected = Arrays.asList(
+        Arrays.asList(
+            new Tuple2<Integer, String>(6, "g"),
+            new Tuple2<Integer, String>(6, "i"),
+            new Tuple2<Integer, String>(6, "a"),
+            new Tuple2<Integer, String>(6, "n"),
+            new Tuple2<Integer, String>(6, "t"),
+            new Tuple2<Integer, String>(6, "s")),
+        Arrays.asList(
+            new Tuple2<Integer, String>(7, "d"),
+            new Tuple2<Integer, String>(7, "o"),
+            new Tuple2<Integer, String>(7, "d"),
+            new Tuple2<Integer, String>(7, "g"),
+            new Tuple2<Integer, String>(7, "e"),
+            new Tuple2<Integer, String>(7, "r"),
+            new Tuple2<Integer, String>(7, "s")),
+        Arrays.asList(
+            new Tuple2<Integer, String>(9, "a"),
+            new Tuple2<Integer, String>(9, "t"),
+            new Tuple2<Integer, String>(9, "h"),
+            new Tuple2<Integer, String>(9, "l"),
+            new Tuple2<Integer, String>(9, "e"),
+            new Tuple2<Integer, String>(9, "t"),
+            new Tuple2<Integer, String>(9, "i"),
+            new Tuple2<Integer, String>(9, "c"),
+            new Tuple2<Integer, String>(9, "s")));
+
+    JavaDStream stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1);
+    JavaPairDStream flatMapped = stream.flatMap(new PairFlatMapFunction<String, Integer, String>() {
+      @Override
+      public Iterable<Tuple2<Integer, String>> call(String in) throws Exception {
+        List<Tuple2<Integer, String>> out = Lists.newArrayList();
+        for (String letter: in.split("(?!^)")) {
+          out.add(new Tuple2<Integer, String>(in.length(), letter));
+        }
+        return out;
+      }
+    });
+    JavaTestUtils.attachTestOutputStream(flatMapped);
+    List<List<Tuple2<Integer, String>>> result = JavaTestUtils.runStreams(sc, 3, 3);
+
+    Assert.assertEquals(expected, result);
+  }
+
   @Test
   public void testUnion() {
     List<List<Integer>> inputData1 = Arrays.asList(