[SPARK-11564][SQL][FOLLOW-UP] clean up java tuple encoder

We need to support custom classes such as Java beans and to combine them into tuples, which is very hard to do with the TypeTag-based approach.
We should keep only the compose-based way to create tuple encoders.

This PR also moves `Encoder` to `org.apache.spark.sql`.
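
For reference, a minimal sketch of what the compose-based API looks like from Java, pieced together from the test changes in this commit (the standalone setup below is an assumption; `JavaDatasetSuite` wires up `TestSQLContext` similarly in its `setUp`):

```java
import java.util.Arrays;
import java.util.List;

import scala.Tuple2;

import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.test.TestSQLContext;

public class TupleEncoderExample {
  public static void main(String[] args) {
    // Local context setup (assumed here), mirroring the test suite.
    JavaSparkContext jsc = new JavaSparkContext("local", "TupleEncoderExample");
    TestSQLContext context = new TestSQLContext(jsc.sc());

    // A tuple encoder is assembled from the encoders of its fields;
    // no TypeTag is reconstructed from Class objects.
    Encoder<Tuple2<Integer, String>> encoder =
        Encoders.tuple(Encoders.INT(), Encoders.STRING());

    List<Tuple2<Integer, String>> data = Arrays.asList(
        new Tuple2<Integer, String>(1, "a"), new Tuple2<Integer, String>(2, "b"));
    Dataset<Tuple2<Integer, String>> ds = context.createDataset(data, encoder);
    System.out.println(ds.collectAsList());  // [(1,a), (2,b)]

    jsc.stop();
  }
}
```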

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9567 from cloud-fan/java.
Wenchen Fan, 2015-11-11 10:52:23 -08:00, committed by Michael Armbrust
parent 9c57bc0efc
commit ec2b807212
14 changed files with 65 additions and 113 deletions


@@ -15,14 +15,15 @@
  * limitations under the License.
  */

-package org.apache.spark.sql.catalyst.encoders
+package org.apache.spark.sql

+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.{ObjectType, StructField, StructType}
+import org.apache.spark.util.Utils
+
 import scala.reflect.ClassTag
-
-import org.apache.spark.util.Utils
-import org.apache.spark.sql.types.{ObjectType, StructField, StructType}
-import org.apache.spark.sql.catalyst.expressions._

 /**
  * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation.
  *
@@ -38,9 +39,7 @@ trait Encoder[T] extends Serializable {
   def clsTag: ClassTag[T]
 }

-object Encoder {
-  import scala.reflect.runtime.universe._
-
+object Encoders {
   def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true)
   def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true)
   def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true)
@@ -129,54 +128,4 @@ object Encoder
       constructExpression,
       ClassTag.apply(cls))
   }
-
-  def typeTagOfTuple2[T1 : TypeTag, T2 : TypeTag]: TypeTag[(T1, T2)] = typeTag[(T1, T2)]
-
-  private def getTypeTag[T](c: Class[T]): TypeTag[T] = {
-    import scala.reflect.api
-
-    // val mirror = runtimeMirror(c.getClassLoader)
-    val mirror = rootMirror
-    val sym = mirror.staticClass(c.getName)
-    val tpe = sym.selfType
-    TypeTag(mirror, new api.TypeCreator {
-      def apply[U <: api.Universe with Singleton](m: api.Mirror[U]) =
-        if (m eq mirror) tpe.asInstanceOf[U # Type]
-        else throw new IllegalArgumentException(
-          s"Type tag defined in $mirror cannot be migrated to other mirrors.")
-    })
-  }
-
-  def forTuple[T1, T2](c1: Class[T1], c2: Class[T2]): Encoder[(T1, T2)] = {
-    implicit val typeTag1 = getTypeTag(c1)
-    implicit val typeTag2 = getTypeTag(c2)
-    ExpressionEncoder[(T1, T2)]()
-  }
-
-  def forTuple[T1, T2, T3](c1: Class[T1], c2: Class[T2], c3: Class[T3]): Encoder[(T1, T2, T3)] = {
-    implicit val typeTag1 = getTypeTag(c1)
-    implicit val typeTag2 = getTypeTag(c2)
-    implicit val typeTag3 = getTypeTag(c3)
-    ExpressionEncoder[(T1, T2, T3)]()
-  }
-
-  def forTuple[T1, T2, T3, T4](
-      c1: Class[T1], c2: Class[T2], c3: Class[T3], c4: Class[T4]): Encoder[(T1, T2, T3, T4)] = {
-    implicit val typeTag1 = getTypeTag(c1)
-    implicit val typeTag2 = getTypeTag(c2)
-    implicit val typeTag3 = getTypeTag(c3)
-    implicit val typeTag4 = getTypeTag(c4)
-    ExpressionEncoder[(T1, T2, T3, T4)]()
-  }
-
-  def forTuple[T1, T2, T3, T4, T5](
-      c1: Class[T1], c2: Class[T2], c3: Class[T3], c4: Class[T4], c5: Class[T5])
-    : Encoder[(T1, T2, T3, T4, T5)] = {
-    implicit val typeTag1 = getTypeTag(c1)
-    implicit val typeTag2 = getTypeTag(c2)
-    implicit val typeTag3 = getTypeTag(c3)
-    implicit val typeTag4 = getTypeTag(c4)
-    implicit val typeTag5 = getTypeTag(c5)
-    ExpressionEncoder[(T1, T2, T3, T4, T5)]()
-  }
 }
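
The deleted `getTypeTag` helper rebuilt a `TypeTag` from a bare `Class` against `rootMirror` (note the commented-out `runtimeMirror(c.getClassLoader)`), which is exactly where custom classes such as Java beans become hard to support. The compose-based `Encoders.tuple` takes `Encoder` arguments instead of `Class` arguments, and since it returns an `Encoder` itself, tuple encoders nest, as the updated `testNestedTupleEncoder` further down exercises. A minimal sketch (the class name and the printed tag are illustrative):

```java
import scala.Tuple2;

import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;

public class NestedTupleEncoderExample {
  public static void main(String[] args) {
    // Encoders compose: the result of Encoders.tuple is itself an Encoder,
    // so it can be fed back in as a field encoder. The removed Class-based
    // forTuple overloads had no way to express this nesting.
    Encoder<Tuple2<Integer, Tuple2<String, Long>>> nested =
        Encoders.tuple(Encoders.INT(),
            Encoders.tuple(Encoders.STRING(), Encoders.LONG()));
    System.out.println(nested.clsTag());  // prints something like scala.Tuple2
  }
}
```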


@@ -17,18 +17,18 @@
 package org.apache.spark.sql.catalyst.encoders

-import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtractValue, UnresolvedAttribute}
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
-import org.apache.spark.util.Utils

 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.{typeTag, TypeTag}

+import org.apache.spark.util.Utils
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtractValue, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.types.{StructField, DataType, ObjectType, StructType}
+import org.apache.spark.sql.types.{StructField, ObjectType, StructType}

 /**
  * A factory for constructing encoders that convert objects and primitves to and from the


@@ -17,10 +17,11 @@
 package org.apache.spark.sql.catalyst

+import org.apache.spark.sql.Encoder
+
 package object encoders {
   private[sql] def encoderFor[A : Encoder]: ExpressionEncoder[A] = implicitly[Encoder[A]] match {
     case e: ExpressionEncoder[A] => e
     case _ => sys.error(s"Only expression encoders are supported today")
   }
 }


@@ -17,6 +17,7 @@
 package org.apache.spark.sql.catalyst.plans.logical

+import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression


@@ -23,7 +23,7 @@ import org.apache.spark.annotation.Experimental
 import org.apache.spark.Logging
 import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
+import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.DataTypeParser
 import org.apache.spark.sql.types._


@@ -23,7 +23,6 @@ import java.util.Properties
 import scala.language.implicitConversions
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.TypeTag
-import scala.util.control.NonFatal

 import com.fasterxml.jackson.core.JsonFactory
 import org.apache.commons.lang3.StringUtils

@@ -35,7 +34,6 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.encoders.Encoder
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}


@@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.function.{Function2 => JFunction2, Function3 => JFunction3, _}
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute}
-import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Encoder}
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor}
 import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, Alias, Attribute}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression


@@ -33,7 +33,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
 import org.apache.spark.sql.SQLConf.SQLConfEntry
 import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
+import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.errors.DialectException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer}


@@ -20,9 +20,10 @@ package org.apache.spark.sql.execution.aggregate
 import scala.language.existentials

 import org.apache.spark.Logging
+import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
+import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._


@@ -17,7 +17,8 @@
 package org.apache.spark.sql.expressions

-import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete}
 import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
 import org.apache.spark.sql.{Dataset, DataFrame, TypedColumn}


@@ -26,7 +26,7 @@ import scala.util.Try
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star}
-import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, Encoder}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint


@@ -30,8 +30,8 @@ import org.apache.spark.Accumulator;
 import org.apache.spark.SparkContext;
 import org.apache.spark.api.java.function.*;
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.catalyst.encoders.Encoder;
-import org.apache.spark.sql.catalyst.encoders.Encoder$;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.Encoders;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.GroupedDataset;
 import org.apache.spark.sql.test.TestSQLContext;

@@ -41,7 +41,6 @@ import static org.apache.spark.sql.functions.*;
 public class JavaDatasetSuite implements Serializable {
   private transient JavaSparkContext jsc;
   private transient TestSQLContext context;
-  private transient Encoder$ e = Encoder$.MODULE$;

   @Before
   public void setUp() {
@@ -66,7 +65,7 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testCollect() {
     List<String> data = Arrays.asList("hello", "world");
-    Dataset<String> ds = context.createDataset(data, e.STRING());
+    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
     List<String> collected = ds.collectAsList();
     Assert.assertEquals(Arrays.asList("hello", "world"), collected);
   }

@@ -74,7 +73,7 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testTake() {
     List<String> data = Arrays.asList("hello", "world");
-    Dataset<String> ds = context.createDataset(data, e.STRING());
+    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
     List<String> collected = ds.takeAsList(1);
     Assert.assertEquals(Arrays.asList("hello"), collected);
   }

@@ -82,7 +81,7 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testCommonOperation() {
     List<String> data = Arrays.asList("hello", "world");
-    Dataset<String> ds = context.createDataset(data, e.STRING());
+    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
     Assert.assertEquals("hello", ds.first());

     Dataset<String> filtered = ds.filter(new FilterFunction<String>() {
@@ -99,7 +98,7 @@ public class JavaDatasetSuite implements Serializable {
       public Integer call(String v) throws Exception {
         return v.length();
       }
-    }, e.INT());
+    }, Encoders.INT());
     Assert.assertEquals(Arrays.asList(5, 5), mapped.collectAsList());

     Dataset<String> parMapped = ds.mapPartitions(new MapPartitionsFunction<String, String>() {

@@ -111,7 +110,7 @@ public class JavaDatasetSuite implements Serializable {
         }
         return ls;
       }
-    }, e.STRING());
+    }, Encoders.STRING());
     Assert.assertEquals(Arrays.asList("HELLO", "WORLD"), parMapped.collectAsList());

     Dataset<String> flatMapped = ds.flatMap(new FlatMapFunction<String, String>() {

@@ -123,7 +122,7 @@ public class JavaDatasetSuite implements Serializable {
         }
         return ls;
       }
-    }, e.STRING());
+    }, Encoders.STRING());
     Assert.assertEquals(
       Arrays.asList("h", "e", "l", "l", "o", "w", "o", "r", "l", "d"),
       flatMapped.collectAsList());

@@ -133,7 +132,7 @@ public class JavaDatasetSuite implements Serializable {
   public void testForeach() {
     final Accumulator<Integer> accum = jsc.accumulator(0);
     List<String> data = Arrays.asList("a", "b", "c");
-    Dataset<String> ds = context.createDataset(data, e.STRING());
+    Dataset<String> ds = context.createDataset(data, Encoders.STRING());

     ds.foreach(new ForeachFunction<String>() {
       @Override
@Test @Test
public void testReduce() { public void testReduce() {
List<Integer> data = Arrays.asList(1, 2, 3); List<Integer> data = Arrays.asList(1, 2, 3);
Dataset<Integer> ds = context.createDataset(data, e.INT()); Dataset<Integer> ds = context.createDataset(data, Encoders.INT());
int reduced = ds.reduce(new ReduceFunction<Integer>() { int reduced = ds.reduce(new ReduceFunction<Integer>() {
@Override @Override
@ -161,13 +160,13 @@ public class JavaDatasetSuite implements Serializable {
@Test @Test
public void testGroupBy() { public void testGroupBy() {
List<String> data = Arrays.asList("a", "foo", "bar"); List<String> data = Arrays.asList("a", "foo", "bar");
Dataset<String> ds = context.createDataset(data, e.STRING()); Dataset<String> ds = context.createDataset(data, Encoders.STRING());
GroupedDataset<Integer, String> grouped = ds.groupBy(new MapFunction<String, Integer>() { GroupedDataset<Integer, String> grouped = ds.groupBy(new MapFunction<String, Integer>() {
@Override @Override
public Integer call(String v) throws Exception { public Integer call(String v) throws Exception {
return v.length(); return v.length();
} }
}, e.INT()); }, Encoders.INT());
Dataset<String> mapped = grouped.map(new MapGroupFunction<Integer, String, String>() { Dataset<String> mapped = grouped.map(new MapGroupFunction<Integer, String, String>() {
@Override @Override
@ -178,7 +177,7 @@ public class JavaDatasetSuite implements Serializable {
} }
return sb.toString(); return sb.toString();
} }
}, e.STRING()); }, Encoders.STRING());
Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList()); Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList());
@@ -193,18 +192,18 @@ public class JavaDatasetSuite implements Serializable {
           return Collections.singletonList(sb.toString());
         }
       },
-      e.STRING());
+      Encoders.STRING());

     Assert.assertEquals(Arrays.asList("1a", "3foobar"), flatMapped.collectAsList());

     List<Integer> data2 = Arrays.asList(2, 6, 10);
-    Dataset<Integer> ds2 = context.createDataset(data2, e.INT());
+    Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT());
     GroupedDataset<Integer, Integer> grouped2 = ds2.groupBy(new MapFunction<Integer, Integer>() {
       @Override
       public Integer call(Integer v) throws Exception {
         return v / 2;
       }
-    }, e.INT());
+    }, Encoders.INT());

     Dataset<String> cogrouped = grouped.cogroup(
       grouped2,

@@ -225,7 +224,7 @@ public class JavaDatasetSuite implements Serializable {
           return Collections.singletonList(sb.toString());
         }
       },
-      e.STRING());
+      Encoders.STRING());

     Assert.assertEquals(Arrays.asList("1a#2", "3foobar#6", "5#10"), cogrouped.collectAsList());
   }
@@ -233,8 +232,9 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testGroupByColumn() {
     List<String> data = Arrays.asList("a", "foo", "bar");
-    Dataset<String> ds = context.createDataset(data, e.STRING());
-    GroupedDataset<Integer, String> grouped = ds.groupBy(length(col("value"))).asKey(e.INT());
+    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
+    GroupedDataset<Integer, String> grouped =
+      ds.groupBy(length(col("value"))).asKey(Encoders.INT());

     Dataset<String> mapped = grouped.map(
       new MapGroupFunction<Integer, String, String>() {

@@ -247,7 +247,7 @@ public class JavaDatasetSuite implements Serializable {
           return sb.toString();
         }
       },
-      e.STRING());
+      Encoders.STRING());

     Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList());
   }

@@ -255,11 +255,11 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testSelect() {
     List<Integer> data = Arrays.asList(2, 6);
-    Dataset<Integer> ds = context.createDataset(data, e.INT());
+    Dataset<Integer> ds = context.createDataset(data, Encoders.INT());

     Dataset<Tuple2<Integer, String>> selected = ds.select(
       expr("value + 1"),
-      col("value").cast("string")).as(e.tuple(e.INT(), e.STRING()));
+      col("value").cast("string")).as(Encoders.tuple(Encoders.INT(), Encoders.STRING()));

     Assert.assertEquals(
       Arrays.asList(tuple2(3, "2"), tuple2(7, "6")),
@@ -269,14 +269,14 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testSetOperation() {
     List<String> data = Arrays.asList("abc", "abc", "xyz");
-    Dataset<String> ds = context.createDataset(data, e.STRING());
+    Dataset<String> ds = context.createDataset(data, Encoders.STRING());

     Assert.assertEquals(
       Arrays.asList("abc", "xyz"),
       sort(ds.distinct().collectAsList().toArray(new String[0])));

     List<String> data2 = Arrays.asList("xyz", "foo", "foo");
-    Dataset<String> ds2 = context.createDataset(data2, e.STRING());
+    Dataset<String> ds2 = context.createDataset(data2, Encoders.STRING());

     Dataset<String> intersected = ds.intersect(ds2);
     Assert.assertEquals(Arrays.asList("xyz"), intersected.collectAsList());

@@ -298,9 +298,9 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testJoin() {
     List<Integer> data = Arrays.asList(1, 2, 3);
-    Dataset<Integer> ds = context.createDataset(data, e.INT()).as("a");
+    Dataset<Integer> ds = context.createDataset(data, Encoders.INT()).as("a");
     List<Integer> data2 = Arrays.asList(2, 3, 4);
-    Dataset<Integer> ds2 = context.createDataset(data2, e.INT()).as("b");
+    Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT()).as("b");

     Dataset<Tuple2<Integer, Integer>> joined =
       ds.joinWith(ds2, col("a.value").equalTo(col("b.value")));
@@ -311,26 +311,28 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testTupleEncoder() {
-    Encoder<Tuple2<Integer, String>> encoder2 = e.tuple(e.INT(), e.STRING());
+    Encoder<Tuple2<Integer, String>> encoder2 = Encoders.tuple(Encoders.INT(), Encoders.STRING());
     List<Tuple2<Integer, String>> data2 = Arrays.asList(tuple2(1, "a"), tuple2(2, "b"));
     Dataset<Tuple2<Integer, String>> ds2 = context.createDataset(data2, encoder2);
     Assert.assertEquals(data2, ds2.collectAsList());

-    Encoder<Tuple3<Integer, Long, String>> encoder3 = e.tuple(e.INT(), e.LONG(), e.STRING());
+    Encoder<Tuple3<Integer, Long, String>> encoder3 =
+      Encoders.tuple(Encoders.INT(), Encoders.LONG(), Encoders.STRING());
     List<Tuple3<Integer, Long, String>> data3 =
       Arrays.asList(new Tuple3<Integer, Long, String>(1, 2L, "a"));
     Dataset<Tuple3<Integer, Long, String>> ds3 = context.createDataset(data3, encoder3);
     Assert.assertEquals(data3, ds3.collectAsList());

     Encoder<Tuple4<Integer, String, Long, String>> encoder4 =
-      e.tuple(e.INT(), e.STRING(), e.LONG(), e.STRING());
+      Encoders.tuple(Encoders.INT(), Encoders.STRING(), Encoders.LONG(), Encoders.STRING());
     List<Tuple4<Integer, String, Long, String>> data4 =
       Arrays.asList(new Tuple4<Integer, String, Long, String>(1, "b", 2L, "a"));
     Dataset<Tuple4<Integer, String, Long, String>> ds4 = context.createDataset(data4, encoder4);
     Assert.assertEquals(data4, ds4.collectAsList());

     Encoder<Tuple5<Integer, String, Long, String, Boolean>> encoder5 =
-      e.tuple(e.INT(), e.STRING(), e.LONG(), e.STRING(), e.BOOLEAN());
+      Encoders.tuple(Encoders.INT(), Encoders.STRING(), Encoders.LONG(), Encoders.STRING(),
+        Encoders.BOOLEAN());
     List<Tuple5<Integer, String, Long, String, Boolean>> data5 =
       Arrays.asList(new Tuple5<Integer, String, Long, String, Boolean>(1, "b", 2L, "a", true));
     Dataset<Tuple5<Integer, String, Long, String, Boolean>> ds5 =

@@ -342,7 +344,7 @@ public class JavaDatasetSuite implements Serializable {
   public void testNestedTupleEncoder() {
     // test ((int, string), string)
     Encoder<Tuple2<Tuple2<Integer, String>, String>> encoder =
-      e.tuple(e.tuple(e.INT(), e.STRING()), e.STRING());
+      Encoders.tuple(Encoders.tuple(Encoders.INT(), Encoders.STRING()), Encoders.STRING());
     List<Tuple2<Tuple2<Integer, String>, String>> data =
       Arrays.asList(tuple2(tuple2(1, "a"), "a"), tuple2(tuple2(2, "b"), "b"));
     Dataset<Tuple2<Tuple2<Integer, String>, String>> ds = context.createDataset(data, encoder);

@@ -350,7 +352,8 @@ public class JavaDatasetSuite implements Serializable {
     // test (int, (string, string, long))
     Encoder<Tuple2<Integer, Tuple3<String, String, Long>>> encoder2 =
-      e.tuple(e.INT(), e.tuple(e.STRING(), e.STRING(), e.LONG()));
+      Encoders.tuple(Encoders.INT(),
+        Encoders.tuple(Encoders.STRING(), Encoders.STRING(), Encoders.LONG()));
     List<Tuple2<Integer, Tuple3<String, String, Long>>> data2 =
       Arrays.asList(tuple2(1, new Tuple3<String, String, Long>("a", "b", 3L)));
     Dataset<Tuple2<Integer, Tuple3<String, String, Long>>> ds2 =

@@ -359,7 +362,8 @@ public class JavaDatasetSuite implements Serializable {
     // test (int, ((string, long), string))
     Encoder<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> encoder3 =
-      e.tuple(e.INT(), e.tuple(e.tuple(e.STRING(), e.LONG()), e.STRING()));
+      Encoders.tuple(Encoders.INT(),
+        Encoders.tuple(Encoders.tuple(Encoders.STRING(), Encoders.LONG()), Encoders.STRING()));
     List<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> data3 =
       Arrays.asList(tuple2(1, tuple2(tuple2("a", 2L), "b")));
     Dataset<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> ds3 =


@@ -17,13 +17,11 @@
 package org.apache.spark.sql

-import org.apache.spark.sql.catalyst.encoders.Encoder
-import org.apache.spark.sql.functions._

 import scala.language.postfixOps

 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.expressions.Aggregator

 /** An `Aggregator` that adds up any numeric type returned by the given function. */


@@ -24,7 +24,6 @@ import scala.collection.JavaConverters._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.columnar.InMemoryRelation
-import org.apache.spark.sql.catalyst.encoders.Encoder

 abstract class QueryTest extends PlanTest {