[SPARK-11564][SQL] Dataset Java API audit

A few changes:

1. Removed fold, since it can be confusing for distributed collections.
2. Created specific interfaces for each Dataset function (e.g. MapFunction, ReduceFunction, MapPartitionsFunction)
3. Added more documentation and test cases.

The other thing I'm considering doing is to have a "collector" interface for FlatMapFunction and MapPartitionsFunction, similar to MapReduce's map function.

Author: Reynold Xin <rxin@databricks.com>

Closes #9531 from rxin/SPARK-11564.
This commit is contained in:
Reynold Xin 2015-11-08 20:57:09 -08:00
parent b2d195e137
commit 97b7080cf2
14 changed files with 317 additions and 98 deletions

View file

@ -0,0 +1,29 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.api.java.function;
import java.io.Serializable;
/**
* Base interface for a function used in Dataset's filter function.
*
* If the function returns true, the element is discarded in the returned Dataset.
*/
public interface FilterFunction<T> extends Serializable {
boolean call(T value) throws Exception;
}

View file

@ -0,0 +1,29 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.api.java.function;
import java.io.Serializable;
/**
* Base interface for a function used in Dataset's foreach function.
*
* Spark will invoke the call function on each element in the input Dataset.
*/
public interface ForeachFunction<T> extends Serializable {
void call(T t) throws Exception;
}

View file

@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.api.java.function;
import java.io.Serializable;
import java.util.Iterator;
/**
* Base interface for a function used in Dataset's foreachPartition function.
*/
public interface ForeachPartitionFunction<T> extends Serializable {
void call(Iterator<T> t) throws Exception;
}

View file

@ -23,5 +23,5 @@ import java.io.Serializable;
* A zero-argument function that returns an R.
*/
public interface Function0<R> extends Serializable {
public R call() throws Exception;
R call() throws Exception;
}

View file

@ -0,0 +1,27 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.api.java.function;
import java.io.Serializable;
/**
* Base interface for a map function used in Dataset's map function.
*/
public interface MapFunction<T, U> extends Serializable {
U call(T value) throws Exception;
}

View file

@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.api.java.function;
import java.io.Serializable;
import java.util.Iterator;
/**
* Base interface for function used in Dataset's mapPartitions.
*/
public interface MapPartitionsFunction<T, U> extends Serializable {
Iterable<U> call(Iterator<T> input) throws Exception;
}

View file

@ -0,0 +1,27 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.api.java.function;
import java.io.Serializable;
/**
* Base interface for function used in Dataset's reduce.
*/
public interface ReduceFunction<T> extends Serializable {
T call(T v1, T v2) throws Exception;
}

View file

@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.encoders
import scala.reflect.ClassTag
import org.apache.spark.util.Utils
import org.apache.spark.sql.types.{DataType, ObjectType, StructField, StructType}
import org.apache.spark.sql.types.{ObjectType, StructField, StructType}
import org.apache.spark.sql.catalyst.expressions._
/**
@ -100,7 +100,7 @@ object Encoder {
expr.transformUp {
case BoundReference(0, t: ObjectType, _) =>
Invoke(
BoundReference(0, ObjectType(cls), true),
BoundReference(0, ObjectType(cls), nullable = true),
s"_${index + 1}",
t)
}
@ -114,13 +114,13 @@ object Encoder {
} else {
enc.constructExpression.transformUp {
case BoundReference(ordinal, dt, _) =>
GetInternalRowField(BoundReference(index, enc.schema, true), ordinal, dt)
GetInternalRowField(BoundReference(index, enc.schema, nullable = true), ordinal, dt)
}
}
}
val constructExpression =
NewInstance(cls, constructExpressions, false, ObjectType(cls))
NewInstance(cls, constructExpressions, propagateNull = false, ObjectType(cls))
new ExpressionEncoder[Any](
schema,
@ -130,7 +130,6 @@ object Encoder {
ClassTag.apply(cls))
}
def typeTagOfTuple2[T1 : TypeTag, T2 : TypeTag]: TypeTag[(T1, T2)] = typeTag[(T1, T2)]
private def getTypeTag[T](c: Class[T]): TypeTag[T] = {
@ -148,9 +147,36 @@ object Encoder {
})
}
def forTuple2[T1, T2](c1: Class[T1], c2: Class[T2]): Encoder[(T1, T2)] = {
def forTuple[T1, T2](c1: Class[T1], c2: Class[T2]): Encoder[(T1, T2)] = {
implicit val typeTag1 = getTypeTag(c1)
implicit val typeTag2 = getTypeTag(c2)
ExpressionEncoder[(T1, T2)]()
}
def forTuple[T1, T2, T3](c1: Class[T1], c2: Class[T2], c3: Class[T3]): Encoder[(T1, T2, T3)] = {
implicit val typeTag1 = getTypeTag(c1)
implicit val typeTag2 = getTypeTag(c2)
implicit val typeTag3 = getTypeTag(c3)
ExpressionEncoder[(T1, T2, T3)]()
}
def forTuple[T1, T2, T3, T4](
c1: Class[T1], c2: Class[T2], c3: Class[T3], c4: Class[T4]): Encoder[(T1, T2, T3, T4)] = {
implicit val typeTag1 = getTypeTag(c1)
implicit val typeTag2 = getTypeTag(c2)
implicit val typeTag3 = getTypeTag(c3)
implicit val typeTag4 = getTypeTag(c4)
ExpressionEncoder[(T1, T2, T3, T4)]()
}
def forTuple[T1, T2, T3, T4, T5](
c1: Class[T1], c2: Class[T2], c3: Class[T3], c4: Class[T4], c5: Class[T5])
: Encoder[(T1, T2, T3, T4, T5)] = {
implicit val typeTag1 = getTypeTag(c1)
implicit val typeTag2 = getTypeTag(c2)
implicit val typeTag3 = getTypeTag(c3)
implicit val typeTag4 = getTypeTag(c4)
implicit val typeTag5 = getTypeTag(c5)
ExpressionEncoder[(T1, T2, T3, T4, T5)]()
}
}

View file

@ -1478,18 +1478,54 @@ class DataFrame private[sql](
/**
* Returns the first `n` rows in the [[DataFrame]].
*
* Running take requires moving data into the application's driver process, and doing so on a
* very large dataset can crash the driver process with OutOfMemoryError.
*
* @group action
* @since 1.3.0
*/
def take(n: Int): Array[Row] = head(n)
/**
* Returns the first `n` rows in the [[DataFrame]] as a list.
*
* Running take requires moving data into the application's driver process, and doing so with
* a very large `n` can crash the driver process with OutOfMemoryError.
*
* @group action
* @since 1.6.0
*/
def takeAsList(n: Int): java.util.List[Row] = java.util.Arrays.asList(take(n) : _*)
/**
* Returns an array that contains all of [[Row]]s in this [[DataFrame]].
*
* Running take requires moving data into the application's driver process, and doing so with
* a very large `n` can crash the driver process with OutOfMemoryError.
*
* For Java API, use [[collectAsList]].
*
* @group action
* @since 1.3.0
*/
def collect(): Array[Row] = collect(needCallback = true)
/**
* Returns a Java list that contains all of [[Row]]s in this [[DataFrame]].
*
* Running collect requires moving all the data into the application's driver process, and
* doing so on a very large dataset can crash the driver process with OutOfMemoryError.
*
* @group action
* @since 1.3.0
*/
def collectAsList(): java.util.List[Row] = withCallback("collectAsList", this) { _ =>
withNewExecutionId {
java.util.Arrays.asList(rdd.collect() : _*)
}
}
private def collect(needCallback: Boolean): Array[Row] = {
def execute(): Array[Row] = withNewExecutionId {
queryExecution.executedPlan.executeCollectPublic()
@ -1502,17 +1538,6 @@ class DataFrame private[sql](
}
}
/**
* Returns a Java list that contains all of [[Row]]s in this [[DataFrame]].
* @group action
* @since 1.3.0
*/
def collectAsList(): java.util.List[Row] = withCallback("collectAsList", this) { _ =>
withNewExecutionId {
java.util.Arrays.asList(rdd.collect() : _*)
}
}
/**
* Returns the number of rows in the [[DataFrame]].
* @group action

View file

@ -22,7 +22,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, _}
import org.apache.spark.api.java.function._
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
@ -75,7 +75,11 @@ class Dataset[T] private[sql](
private[sql] def this(sqlContext: SQLContext, plan: LogicalPlan)(implicit encoder: Encoder[T]) =
this(sqlContext, new QueryExecution(sqlContext, plan), encoder)
/** Returns the schema of the encoded form of the objects in this [[Dataset]]. */
/**
* Returns the schema of the encoded form of the objects in this [[Dataset]].
*
* @since 1.6.0
*/
def schema: StructType = encoder.schema
/* ************* *
@ -103,6 +107,7 @@ class Dataset[T] private[sql](
/**
* Applies a logical alias to this [[Dataset]] that can be used to disambiguate columns that have
* the same name after two Datasets have been joined.
* @since 1.6.0
*/
def as(alias: String): Dataset[T] = withPlan(Subquery(alias, _))
@ -166,8 +171,7 @@ class Dataset[T] private[sql](
* Returns a new [[Dataset]] that only contains elements where `func` returns `true`.
* @since 1.6.0
*/
def filter(func: JFunction[T, java.lang.Boolean]): Dataset[T] =
filter(t => func.call(t).booleanValue())
def filter(func: FilterFunction[T]): Dataset[T] = filter(t => func.call(t))
/**
* (Scala-specific)
@ -181,7 +185,7 @@ class Dataset[T] private[sql](
* Returns a new [[Dataset]] that contains the result of applying `func` to each element.
* @since 1.6.0
*/
def map[U](func: JFunction[T, U], encoder: Encoder[U]): Dataset[U] =
def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] =
map(t => func.call(t))(encoder)
/**
@ -205,10 +209,8 @@ class Dataset[T] private[sql](
* Returns a new [[Dataset]] that contains the result of applying `func` to each element.
* @since 1.6.0
*/
def mapPartitions[U](
f: FlatMapFunction[java.util.Iterator[T], U],
encoder: Encoder[U]): Dataset[U] = {
val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).iterator().asScala
def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).iterator.asScala
mapPartitions(func)(encoder)
}
@ -248,7 +250,7 @@ class Dataset[T] private[sql](
* Runs `func` on each element of this Dataset.
* @since 1.6.0
*/
def foreach(func: VoidFunction[T]): Unit = foreach(func.call(_))
def foreach(func: ForeachFunction[T]): Unit = foreach(func.call(_))
/**
* (Scala-specific)
@ -262,7 +264,7 @@ class Dataset[T] private[sql](
* Runs `func` on each partition of this Dataset.
* @since 1.6.0
*/
def foreachPartition(func: VoidFunction[java.util.Iterator[T]]): Unit =
def foreachPartition(func: ForeachPartitionFunction[T]): Unit =
foreachPartition(it => func.call(it.asJava))
/* ************* *
@ -271,7 +273,7 @@ class Dataset[T] private[sql](
/**
* (Scala-specific)
* Reduces the elements of this Dataset using the specified binary function. The given function
* Reduces the elements of this Dataset using the specified binary function. The given function
* must be commutative and associative or the result may be non-deterministic.
* @since 1.6.0
*/
@ -279,33 +281,11 @@ class Dataset[T] private[sql](
/**
* (Java-specific)
* Reduces the elements of this Dataset using the specified binary function. The given function
* Reduces the elements of this Dataset using the specified binary function. The given function
* must be commutative and associative or the result may be non-deterministic.
* @since 1.6.0
*/
def reduce(func: JFunction2[T, T, T]): T = reduce(func.call(_, _))
/**
* (Scala-specific)
* Aggregates the elements of each partition, and then the results for all the partitions, using a
* given associative and commutative function and a neutral "zero value".
*
* This behaves somewhat differently than the fold operations implemented for non-distributed
* collections in functional languages like Scala. This fold operation may be applied to
* partitions individually, and then those results will be folded into the final result.
* If op is not commutative, then the result may differ from that of a fold applied to a
* non-distributed collection.
* @since 1.6.0
*/
def fold(zeroValue: T)(op: (T, T) => T): T = rdd.fold(zeroValue)(op)
/**
* (Java-specific)
* Aggregates the elements of each partition, and then the results for all the partitions, using a
* given associative and commutative function and a neutral "zero value".
* @since 1.6.0
*/
def fold(zeroValue: T, func: JFunction2[T, T, T]): T = fold(zeroValue)(func.call(_, _))
def reduce(func: ReduceFunction[T]): T = reduce(func.call(_, _))
/**
* (Scala-specific)
@ -351,7 +331,7 @@ class Dataset[T] private[sql](
* Returns a [[GroupedDataset]] where the data is grouped by the given key function.
* @since 1.6.0
*/
def groupBy[K](f: JFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] =
def groupBy[K](f: MapFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] =
groupBy(f.call(_))(encoder)
/* ****************** *
@ -367,7 +347,7 @@ class Dataset[T] private[sql](
*/
// Copied from Dataframe to make sure we don't have invalid overloads.
@scala.annotation.varargs
def select(cols: Column*): DataFrame = toDF().select(cols: _*)
protected def select(cols: Column*): DataFrame = toDF().select(cols: _*)
/**
* Returns a new [[Dataset]] by computing the given [[Column]] expression for each element.
@ -462,8 +442,7 @@ class Dataset[T] private[sql](
* and thus is not affected by a custom `equals` function defined on `T`.
* @since 1.6.0
*/
def intersect(other: Dataset[T]): Dataset[T] =
withPlan[T](other)(Intersect)
def intersect(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Intersect)
/**
* Returns a new [[Dataset]] that contains the elements of both this and the `other` [[Dataset]]
@ -473,8 +452,7 @@ class Dataset[T] private[sql](
* duplicate items. As such, it is analagous to `UNION ALL` in SQL.
* @since 1.6.0
*/
def union(other: Dataset[T]): Dataset[T] =
withPlan[T](other)(Union)
def union(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Union)
/**
* Returns a new [[Dataset]] where any elements present in `other` have been removed.
@ -542,27 +520,47 @@ class Dataset[T] private[sql](
def first(): T = rdd.first()
/**
* Collects the elements to an Array.
* Returns an array that contains all the elements in this [[Dataset]].
*
* Running collect requires moving all the data into the application's driver process, and
* doing so on a very large dataset can crash the driver process with OutOfMemoryError.
*
* For Java API, use [[collectAsList]].
* @since 1.6.0
*/
def collect(): Array[T] = rdd.collect()
/**
* (Java-specific)
* Collects the elements to a Java list.
* Returns an array that contains all the elements in this [[Dataset]].
*
* Due to the incompatibility problem between Scala and Java, the return type of [[collect()]] at
* Java side is `java.lang.Object`, which is not easy to use. Java user can use this method
* instead and keep the generic type for result.
* Running collect requires moving all the data into the application's driver process, and
* doing so on a very large dataset can crash the driver process with OutOfMemoryError.
*
* For Java API, use [[collectAsList]].
* @since 1.6.0
*/
def collectAsList(): java.util.List[T] = rdd.collect().toSeq.asJava
/**
* Returns the first `num` elements of this [[Dataset]] as an array.
*
* Running take requires moving data into the application's driver process, and doing so with
* a very large `n` can crash the driver process with OutOfMemoryError.
*
* @since 1.6.0
*/
def collectAsList(): java.util.List[T] =
rdd.collect().toSeq.asJava
/** Returns the first `num` elements of this [[Dataset]] as an Array. */
def take(num: Int): Array[T] = rdd.take(num)
/**
* Returns the first `num` elements of this [[Dataset]] as an array.
*
* Running take requires moving data into the application's driver process, and doing so with
* a very large `n` can crash the driver process with OutOfMemoryError.
*
* @since 1.6.0
*/
def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*)
/* ******************** *
* Internal Functions *
* ******************** */

View file

@ -65,6 +65,13 @@ public class JavaDataFrameSuite {
Assert.assertEquals(1, df.select("key").collect()[0].get(0));
}
@Test
public void testCollectAndTake() {
DataFrame df = context.table("testData").filter("key = 1 or key = 2 or key = 3");
Assert.assertEquals(3, df.select("key").collectAsList().size());
Assert.assertEquals(2, df.select("key").takeAsList(2).size());
}
/**
* See SPARK-5904. Abstract vararg methods defined in Scala do not work in Java.
*/

View file

@ -68,8 +68,16 @@ public class JavaDatasetSuite implements Serializable {
public void testCollect() {
List<String> data = Arrays.asList("hello", "world");
Dataset<String> ds = context.createDataset(data, e.STRING());
String[] collected = (String[]) ds.collect();
Assert.assertEquals(Arrays.asList("hello", "world"), Arrays.asList(collected));
List<String> collected = ds.collectAsList();
Assert.assertEquals(Arrays.asList("hello", "world"), collected);
}
@Test
public void testTake() {
List<String> data = Arrays.asList("hello", "world");
Dataset<String> ds = context.createDataset(data, e.STRING());
List<String> collected = ds.takeAsList(1);
Assert.assertEquals(Arrays.asList("hello"), collected);
}
@Test
@ -78,16 +86,16 @@ public class JavaDatasetSuite implements Serializable {
Dataset<String> ds = context.createDataset(data, e.STRING());
Assert.assertEquals("hello", ds.first());
Dataset<String> filtered = ds.filter(new Function<String, Boolean>() {
Dataset<String> filtered = ds.filter(new FilterFunction<String>() {
@Override
public Boolean call(String v) throws Exception {
public boolean call(String v) throws Exception {
return v.startsWith("h");
}
});
Assert.assertEquals(Arrays.asList("hello"), filtered.collectAsList());
Dataset<Integer> mapped = ds.map(new Function<String, Integer>() {
Dataset<Integer> mapped = ds.map(new MapFunction<String, Integer>() {
@Override
public Integer call(String v) throws Exception {
return v.length();
@ -95,7 +103,7 @@ public class JavaDatasetSuite implements Serializable {
}, e.INT());
Assert.assertEquals(Arrays.asList(5, 5), mapped.collectAsList());
Dataset<String> parMapped = ds.mapPartitions(new FlatMapFunction<Iterator<String>, String>() {
Dataset<String> parMapped = ds.mapPartitions(new MapPartitionsFunction<String, String>() {
@Override
public Iterable<String> call(Iterator<String> it) throws Exception {
List<String> ls = new LinkedList<String>();
@ -128,7 +136,7 @@ public class JavaDatasetSuite implements Serializable {
List<String> data = Arrays.asList("a", "b", "c");
Dataset<String> ds = context.createDataset(data, e.STRING());
ds.foreach(new VoidFunction<String>() {
ds.foreach(new ForeachFunction<String>() {
@Override
public void call(String s) throws Exception {
accum.add(1);
@ -142,28 +150,20 @@ public class JavaDatasetSuite implements Serializable {
List<Integer> data = Arrays.asList(1, 2, 3);
Dataset<Integer> ds = context.createDataset(data, e.INT());
int reduced = ds.reduce(new Function2<Integer, Integer, Integer>() {
int reduced = ds.reduce(new ReduceFunction<Integer>() {
@Override
public Integer call(Integer v1, Integer v2) throws Exception {
return v1 + v2;
}
});
Assert.assertEquals(6, reduced);
int folded = ds.fold(1, new Function2<Integer, Integer, Integer>() {
@Override
public Integer call(Integer v1, Integer v2) throws Exception {
return v1 * v2;
}
});
Assert.assertEquals(6, folded);
}
@Test
public void testGroupBy() {
List<String> data = Arrays.asList("a", "foo", "bar");
Dataset<String> ds = context.createDataset(data, e.STRING());
GroupedDataset<Integer, String> grouped = ds.groupBy(new Function<String, Integer>() {
GroupedDataset<Integer, String> grouped = ds.groupBy(new MapFunction<String, Integer>() {
@Override
public Integer call(String v) throws Exception {
return v.length();
@ -187,7 +187,7 @@ public class JavaDatasetSuite implements Serializable {
List<Integer> data2 = Arrays.asList(2, 6, 10);
Dataset<Integer> ds2 = context.createDataset(data2, e.INT());
GroupedDataset<Integer, Integer> grouped2 = ds2.groupBy(new Function<Integer, Integer>() {
GroupedDataset<Integer, Integer> grouped2 = ds2.groupBy(new MapFunction<Integer, Integer>() {
@Override
public Integer call(Integer v) throws Exception {
return v / 2;

View file

@ -75,11 +75,6 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
assert(ds.reduce(_ + _) == 6)
}
test("fold") {
val ds = Seq(1, 2, 3).toDS()
assert(ds.fold(0)(_ + _) == 6)
}
test("groupBy function, keys") {
val ds = Seq(1, 2, 3, 4, 5).toDS()
val grouped = ds.groupBy(_ % 2)

View file

@ -61,6 +61,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
assert(ds.collect() === Array(ClassData("a", 1), ClassData("b", 2), ClassData("c", 3)))
}
test("as case class - take") {
val ds = Seq((1, "a"), (2, "b"), (3, "c")).toDF("b", "a").as[ClassData]
assert(ds.take(2) === Array(ClassData("a", 1), ClassData("b", 2)))
}
test("map") {
val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
checkAnswer(
@ -137,11 +142,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
assert(ds.reduce((a, b) => ("sum", a._2 + b._2)) == ("sum", 6))
}
test("fold") {
val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
assert(ds.fold(("", 0))((a, b) => ("sum", a._2 + b._2)) == ("sum", 6))
}
test("joinWith, flat schema") {
val ds1 = Seq(1, 2, 3).toDS().as("a")
val ds2 = Seq(1, 2).toDS().as("b")