[SPARK-5904][SQL] DataFrame Java API test suites.

Added a new test suite to make sure Java DataFrame programs can use vararg methods properly.
Also moved all suites into the test.org.apache.spark package, so that they call Spark from outside its own packages and therefore also verify that the tested methods are genuinely publicly visible.
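For background on the vararg problem these suites exercise: a Scala method declared with repeated parameters compiles to a single method taking a Seq, which Java callers cannot invoke with natural f(a, b, c) varargs syntax; annotating the Scala definition with @scala.annotation.varargs makes the compiler also emit a Java-style T... overload (SPARK-5904 concerns the further wrinkle that this breaks for abstract vararg methods, which is why the new suite calls each vararg method from Java). A minimal sketch of the pattern, with a hypothetical Frame class and method names for illustration only, not Spark's API:

import scala.annotation.varargs

class Frame {
  // Compiles to seqSelect(Seq<String>) in bytecode only; Java callers
  // would have to construct a Seq by hand to invoke it.
  def seqSelect(cols: String*): Frame = this

  // @varargs makes scalac also emit a select(String...) forwarder,
  // so Java code can write frame.select("key", "value") directly.
  @varargs
  def select(cols: String*): Frame = this
}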

Author: Reynold Xin <rxin@databricks.com>

Closes #4751 from rxin/df-tests and squashes the following commits:

1e8b8e4 [Reynold Xin] Fixed imports and renamed JavaAPISuite.
a6ca53b [Reynold Xin] [SPARK-5904][SQL] DataFrame Java API test suites.
Reynold Xin authored on 2015-02-24 18:51:41 -08:00; committed by Michael Armbrust
parent f816e73902
commit 53a1ebf33b
7 changed files with 108 additions and 143 deletions

JavaDsl.java (deleted)

@@ -1,120 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.api.java;

import com.google.common.collect.ImmutableMap;

import org.apache.spark.sql.Column;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.types.DataTypes;

import static org.apache.spark.sql.functions.*;

/**
 * This test doesn't actually run anything. It is here to check the API compatibility for Java.
 */
public class JavaDsl {

  public static void testDataFrame(final DataFrame df) {
    DataFrame df1 = df.select("colA");
    df1 = df.select("colA", "colB");
    df1 = df.select(col("colA"), col("colB"), lit("literal value").$plus(1));
    df1 = df.filter(col("colA"));

    java.util.Map<String, String> aggExprs = ImmutableMap.<String, String>builder()
      .put("colA", "sum")
      .put("colB", "avg")
      .build();

    df1 = df.agg(aggExprs);
    df1 = df.groupBy("groupCol").agg(aggExprs);
    df1 = df.join(df1, col("key1").$eq$eq$eq(col("key2")), "outer");

    df.orderBy("colA");
    df.orderBy("colA", "colB", "colC");
    df.orderBy(col("colA").desc());
    df.orderBy(col("colA").desc(), col("colB").asc());

    df.sort("colA");
    df.sort("colA", "colB", "colC");
    df.sort(col("colA").desc());
    df.sort(col("colA").desc(), col("colB").asc());

    df.as("b");
    df.limit(5);

    df.unionAll(df1);
    df.intersect(df1);
    df.except(df1);
    df.sample(true, 0.1, 234);

    df.head();
    df.head(5);
    df.first();
    df.count();
  }

  public static void testColumn(final Column c) {
    c.asc();
    c.desc();

    c.endsWith("abcd");
    c.startsWith("afgasdf");

    c.like("asdf%");
    c.rlike("wef%asdf");

    c.as("newcol");

    c.cast("int");
    c.cast(DataTypes.IntegerType);
  }

  public static void testDsl() {
    // Creating a column.
    Column c = col("abcd");
    Column c1 = column("abcd");

    // Literals
    Column l1 = lit(1);
    Column l2 = lit(1.0);
    Column l3 = lit("abcd");

    // Functions
    Column a = upper(c);
    a = lower(c);
    a = sqrt(c);
    a = abs(c);

    // Aggregates
    a = min(c);
    a = max(c);
    a = sum(c);
    a = sumDistinct(c);
    a = countDistinct(c, a);
    a = avg(c);
    a = first(c);
    a = last(c);
  }
}

JavaApplySchemaSuite.java

@@ -15,7 +15,7 @@
  * limitations under the License.
  */

-package org.apache.spark.sql.api.java;
+package test.org.apache.spark.sql;

 import java.io.Serializable;
 import java.util.ArrayList;
@@ -39,18 +39,18 @@ import org.apache.spark.sql.types.*;
 // see http://stackoverflow.com/questions/758570/.
 public class JavaApplySchemaSuite implements Serializable {
   private transient JavaSparkContext javaCtx;
-  private transient SQLContext javaSqlCtx;
+  private transient SQLContext sqlContext;

   @Before
   public void setUp() {
-    javaSqlCtx = TestSQLContext$.MODULE$;
-    javaCtx = new JavaSparkContext(javaSqlCtx.sparkContext());
+    sqlContext = TestSQLContext$.MODULE$;
+    javaCtx = new JavaSparkContext(sqlContext.sparkContext());
   }

   @After
   public void tearDown() {
     javaCtx = null;
-    javaSqlCtx = null;
+    sqlContext = null;
   }

   public static class Person implements Serializable {
@@ -98,9 +98,9 @@ public class JavaApplySchemaSuite implements Serializable {
     fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false));
     StructType schema = DataTypes.createStructType(fields);

-    DataFrame df = javaSqlCtx.applySchema(rowRDD, schema);
+    DataFrame df = sqlContext.applySchema(rowRDD, schema);
     df.registerTempTable("people");
-    Row[] actual = javaSqlCtx.sql("SELECT * FROM people").collect();
+    Row[] actual = sqlContext.sql("SELECT * FROM people").collect();

     List<Row> expected = new ArrayList<Row>(2);
     expected.add(RowFactory.create("Michael", 29));
@@ -109,8 +109,6 @@ public class JavaApplySchemaSuite implements Serializable {
     Assert.assertEquals(expected, Arrays.asList(actual));
   }

-
-
   @Test
   public void dataFrameRDDOperations() {
     List<Person> personList = new ArrayList<Person>(2);
@@ -135,9 +133,9 @@ public class JavaApplySchemaSuite implements Serializable {
     fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false));
     StructType schema = DataTypes.createStructType(fields);

-    DataFrame df = javaSqlCtx.applySchema(rowRDD, schema);
+    DataFrame df = sqlContext.applySchema(rowRDD, schema);
     df.registerTempTable("people");
-    List<String> actual = javaSqlCtx.sql("SELECT * FROM people").toJavaRDD().map(new Function<Row, String>() {
+    List<String> actual = sqlContext.sql("SELECT * FROM people").toJavaRDD().map(new Function<Row, String>() {
       public String call(Row row) {
         return row.getString(0) + "_" + row.get(1).toString();
@@ -189,18 +187,18 @@ public class JavaApplySchemaSuite implements Serializable {
         null,
         "this is another simple string."));

-    DataFrame df1 = javaSqlCtx.jsonRDD(jsonRDD);
+    DataFrame df1 = sqlContext.jsonRDD(jsonRDD);
     StructType actualSchema1 = df1.schema();
     Assert.assertEquals(expectedSchema, actualSchema1);
     df1.registerTempTable("jsonTable1");
-    List<Row> actual1 = javaSqlCtx.sql("select * from jsonTable1").collectAsList();
+    List<Row> actual1 = sqlContext.sql("select * from jsonTable1").collectAsList();
     Assert.assertEquals(expectedResult, actual1);

-    DataFrame df2 = javaSqlCtx.jsonRDD(jsonRDD, expectedSchema);
+    DataFrame df2 = sqlContext.jsonRDD(jsonRDD, expectedSchema);
     StructType actualSchema2 = df2.schema();
     Assert.assertEquals(expectedSchema, actualSchema2);
     df2.registerTempTable("jsonTable2");
-    List<Row> actual2 = javaSqlCtx.sql("select * from jsonTable2").collectAsList();
+    List<Row> actual2 = sqlContext.sql("select * from jsonTable2").collectAsList();
     Assert.assertEquals(expectedResult, actual2);
   }
 }

JavaDataFrameSuite.java (new file)

@@ -0,0 +1,84 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package test.org.apache.spark.sql;

import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import org.apache.spark.sql.*;
import org.apache.spark.sql.test.TestSQLContext$;
import static org.apache.spark.sql.functions.*;

public class JavaDataFrameSuite {
  private transient SQLContext context;

  @Before
  public void setUp() {
    // Trigger static initializer of TestData
    TestData$.MODULE$.testData();
    context = TestSQLContext$.MODULE$;
  }

  @After
  public void tearDown() {
    context = null;
  }

  @Test
  public void testExecution() {
    DataFrame df = context.table("testData").filter("key = 1");
    Assert.assertEquals(df.select("key").collect()[0].get(0), 1);
  }

  /**
   * See SPARK-5904. Abstract vararg methods defined in Scala do not work in Java.
   */
  @Test
  public void testVarargMethods() {
    DataFrame df = context.table("testData");

    df.toDF("key1", "value1");

    df.select("key", "value");
    df.select(col("key"), col("value"));
    df.selectExpr("key", "value + 1");

    df.sort("key", "value");
    df.sort(col("key"), col("value"));
    df.orderBy("key", "value");
    df.orderBy(col("key"), col("value"));

    df.groupBy("key", "value").agg(col("key"), col("value"), sum("value"));
    df.groupBy(col("key"), col("value")).agg(col("key"), col("value"), sum("value"));
    df.agg(first("key"), sum("value"));

    df.groupBy().avg("key");
    df.groupBy().mean("key");
    df.groupBy().max("key");
    df.groupBy().min("key");
    df.groupBy().sum("key");

    // Varargs in column expressions
    df.groupBy().agg(countDistinct("key", "value"));
    df.groupBy().agg(countDistinct(col("key"), col("value")));
    df.select(coalesce(col("key")));
  }
}

View file

@@ -15,7 +15,7 @@
  * limitations under the License.
  */

-package org.apache.spark.sql.api.java;
+package test.org.apache.spark.sql;

 import java.math.BigDecimal;
 import java.sql.Date;

JavaUDFSuite.java (renamed from JavaAPISuite.java)

@@ -15,24 +15,26 @@
  * limitations under the License.
  */

-package org.apache.spark.sql.api.java;
+package test.org.apache.spark.sql;

 import java.io.Serializable;

-import org.apache.spark.sql.test.TestSQLContext$;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;

-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.api.java.UDF1;
+import org.apache.spark.sql.api.java.UDF2;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.test.TestSQLContext$;
 import org.apache.spark.sql.types.DataTypes;

 // The test suite itself is Serializable so that anonymous Function implementations can be
 // serialized, as an alternative to converting these anonymous classes to static inner classes;
 // see http://stackoverflow.com/questions/758570/.
-public class JavaAPISuite implements Serializable {
+public class JavaUDFSuite implements Serializable {
   private transient JavaSparkContext sc;
   private transient SQLContext sqlContext;

View file

@@ -14,7 +14,8 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.spark.sql.sources;
+
+package test.org.apache.spark.sql.sources;

 import java.io.File;
 import java.io.IOException;

DataFrameSuite.scala

@@ -411,7 +411,7 @@ class DataFrameSuite extends QueryTest {
     )
   }

-  test("addColumn") {
+  test("withColumn") {
     val df = testData.toDF().withColumn("newCol", col("key") + 1)
     checkAnswer(
       df,
@@ -421,7 +421,7 @@ class DataFrameSuite extends QueryTest {
     assert(df.schema.map(_.name).toSeq === Seq("key", "value", "newCol"))
   }

-  test("renameColumn") {
+  test("withColumnRenamed") {
     val df = testData.toDF().withColumn("newCol", col("key") + 1)
       .withColumnRenamed("value", "valueRenamed")
     checkAnswer(