[SPARK-11564][SQL][FOLLOW-UP] clean up java tuple encoder
We need to support custom classes like Java beans and combine them into tuples, and it's very hard to do that with the TypeTag-based approach. We should keep only the compose-based way to create tuple encoders. This PR also moves `Encoder` to `org.apache.spark.sql`. Author: Wenchen Fan <wenchen@databricks.com> Closes #9567 from cloud-fan/java.
This commit is contained in:
parent
9c57bc0efc
commit
ec2b807212
|
@ -15,14 +15,15 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.catalyst.encoders
|
||||
package org.apache.spark.sql
|
||||
|
||||
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.types.{ObjectType, StructField, StructType}
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
import scala.reflect.ClassTag
|
||||
|
||||
import org.apache.spark.util.Utils
|
||||
import org.apache.spark.sql.types.{ObjectType, StructField, StructType}
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
|
||||
/**
|
||||
* Used to convert a JVM object of type `T` to and from the internal Spark SQL representation.
|
||||
*
|
||||
|
@ -38,9 +39,7 @@ trait Encoder[T] extends Serializable {
|
|||
def clsTag: ClassTag[T]
|
||||
}
|
||||
|
||||
object Encoder {
|
||||
import scala.reflect.runtime.universe._
|
||||
|
||||
object Encoders {
|
||||
def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true)
|
||||
def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true)
|
||||
def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true)
|
||||
|
@ -129,54 +128,4 @@ object Encoder {
|
|||
constructExpression,
|
||||
ClassTag.apply(cls))
|
||||
}
|
||||
|
||||
def typeTagOfTuple2[T1 : TypeTag, T2 : TypeTag]: TypeTag[(T1, T2)] = typeTag[(T1, T2)]
|
||||
|
||||
private def getTypeTag[T](c: Class[T]): TypeTag[T] = {
|
||||
import scala.reflect.api
|
||||
|
||||
// val mirror = runtimeMirror(c.getClassLoader)
|
||||
val mirror = rootMirror
|
||||
val sym = mirror.staticClass(c.getName)
|
||||
val tpe = sym.selfType
|
||||
TypeTag(mirror, new api.TypeCreator {
|
||||
def apply[U <: api.Universe with Singleton](m: api.Mirror[U]) =
|
||||
if (m eq mirror) tpe.asInstanceOf[U # Type]
|
||||
else throw new IllegalArgumentException(
|
||||
s"Type tag defined in $mirror cannot be migrated to other mirrors.")
|
||||
})
|
||||
}
|
||||
|
||||
def forTuple[T1, T2](c1: Class[T1], c2: Class[T2]): Encoder[(T1, T2)] = {
|
||||
implicit val typeTag1 = getTypeTag(c1)
|
||||
implicit val typeTag2 = getTypeTag(c2)
|
||||
ExpressionEncoder[(T1, T2)]()
|
||||
}
|
||||
|
||||
def forTuple[T1, T2, T3](c1: Class[T1], c2: Class[T2], c3: Class[T3]): Encoder[(T1, T2, T3)] = {
|
||||
implicit val typeTag1 = getTypeTag(c1)
|
||||
implicit val typeTag2 = getTypeTag(c2)
|
||||
implicit val typeTag3 = getTypeTag(c3)
|
||||
ExpressionEncoder[(T1, T2, T3)]()
|
||||
}
|
||||
|
||||
def forTuple[T1, T2, T3, T4](
|
||||
c1: Class[T1], c2: Class[T2], c3: Class[T3], c4: Class[T4]): Encoder[(T1, T2, T3, T4)] = {
|
||||
implicit val typeTag1 = getTypeTag(c1)
|
||||
implicit val typeTag2 = getTypeTag(c2)
|
||||
implicit val typeTag3 = getTypeTag(c3)
|
||||
implicit val typeTag4 = getTypeTag(c4)
|
||||
ExpressionEncoder[(T1, T2, T3, T4)]()
|
||||
}
|
||||
|
||||
def forTuple[T1, T2, T3, T4, T5](
|
||||
c1: Class[T1], c2: Class[T2], c3: Class[T3], c4: Class[T4], c5: Class[T5])
|
||||
: Encoder[(T1, T2, T3, T4, T5)] = {
|
||||
implicit val typeTag1 = getTypeTag(c1)
|
||||
implicit val typeTag2 = getTypeTag(c2)
|
||||
implicit val typeTag3 = getTypeTag(c3)
|
||||
implicit val typeTag4 = getTypeTag(c4)
|
||||
implicit val typeTag5 = getTypeTag(c5)
|
||||
ExpressionEncoder[(T1, T2, T3, T4, T5)]()
|
||||
}
|
||||
}
|
|
@ -17,18 +17,18 @@
|
|||
|
||||
package org.apache.spark.sql.catalyst.encoders
|
||||
|
||||
import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtractValue, UnresolvedAttribute}
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
import scala.reflect.ClassTag
|
||||
import scala.reflect.runtime.universe.{typeTag, TypeTag}
|
||||
|
||||
import org.apache.spark.util.Utils
|
||||
import org.apache.spark.sql.Encoder
|
||||
import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtractValue, UnresolvedAttribute}
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.ScalaReflection
|
||||
import org.apache.spark.sql.types.{StructField, DataType, ObjectType, StructType}
|
||||
import org.apache.spark.sql.types.{StructField, ObjectType, StructType}
|
||||
|
||||
/**
|
||||
* A factory for constructing encoders that convert objects and primitves to and from the
|
||||
|
|
|
@ -17,10 +17,11 @@
|
|||
|
||||
package org.apache.spark.sql.catalyst
|
||||
|
||||
import org.apache.spark.sql.Encoder
|
||||
|
||||
package object encoders {
|
||||
private[sql] def encoderFor[A : Encoder]: ExpressionEncoder[A] = implicitly[Encoder[A]] match {
|
||||
case e: ExpressionEncoder[A] => e
|
||||
case _ => sys.error(s"Only expression encoders are supported today")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
package org.apache.spark.sql.catalyst.plans.logical
|
||||
|
||||
import org.apache.spark.sql.Encoder
|
||||
import org.apache.spark.sql.catalyst.encoders._
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
|
||||
|
|
|
@ -23,7 +23,7 @@ import org.apache.spark.annotation.Experimental
|
|||
import org.apache.spark.Logging
|
||||
import org.apache.spark.sql.functions.lit
|
||||
import org.apache.spark.sql.catalyst.analysis._
|
||||
import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
|
||||
import org.apache.spark.sql.catalyst.encoders.encoderFor
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.util.DataTypeParser
|
||||
import org.apache.spark.sql.types._
|
||||
|
|
|
@ -23,7 +23,6 @@ import java.util.Properties
|
|||
import scala.language.implicitConversions
|
||||
import scala.reflect.ClassTag
|
||||
import scala.reflect.runtime.universe.TypeTag
|
||||
import scala.util.control.NonFatal
|
||||
|
||||
import com.fasterxml.jackson.core.JsonFactory
|
||||
import org.apache.commons.lang3.StringUtils
|
||||
|
@ -35,7 +34,6 @@ import org.apache.spark.sql.catalyst.InternalRow
|
|||
import org.apache.spark.sql.catalyst.analysis._
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.aggregate._
|
||||
import org.apache.spark.sql.catalyst.encoders.Encoder
|
||||
import org.apache.spark.sql.catalyst.plans.logical._
|
||||
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
|
||||
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}
|
||||
|
|
|
@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
|
|||
import org.apache.spark.annotation.Experimental
|
||||
import org.apache.spark.api.java.function.{Function2 => JFunction2, Function3 => JFunction3, _}
|
||||
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute}
|
||||
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Encoder}
|
||||
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor}
|
||||
import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, Alias, Attribute}
|
||||
import org.apache.spark.sql.catalyst.plans.logical._
|
||||
import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
|
||||
|
|
|
@ -33,7 +33,7 @@ import org.apache.spark.rdd.RDD
|
|||
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
|
||||
import org.apache.spark.sql.SQLConf.SQLConfEntry
|
||||
import org.apache.spark.sql.catalyst.analysis._
|
||||
import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
|
||||
import org.apache.spark.sql.catalyst.encoders.encoderFor
|
||||
import org.apache.spark.sql.catalyst.errors.DialectException
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer}
|
||||
|
|
|
@ -20,9 +20,10 @@ package org.apache.spark.sql.execution.aggregate
|
|||
import scala.language.existentials
|
||||
|
||||
import org.apache.spark.Logging
|
||||
import org.apache.spark.sql.Encoder
|
||||
import org.apache.spark.sql.expressions.Aggregator
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
|
||||
import org.apache.spark.sql.catalyst.encoders.encoderFor
|
||||
import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
|
||||
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
|
|
|
@ -17,7 +17,8 @@
|
|||
|
||||
package org.apache.spark.sql.expressions
|
||||
|
||||
import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
|
||||
import org.apache.spark.sql.Encoder
|
||||
import org.apache.spark.sql.catalyst.encoders.encoderFor
|
||||
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete}
|
||||
import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
|
||||
import org.apache.spark.sql.{Dataset, DataFrame, TypedColumn}
|
||||
|
|
|
@ -26,7 +26,7 @@ import scala.util.Try
|
|||
import org.apache.spark.annotation.Experimental
|
||||
import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
|
||||
import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star}
|
||||
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, Encoder}
|
||||
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.aggregate._
|
||||
import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint
|
||||
|
|
|
@ -30,8 +30,8 @@ import org.apache.spark.Accumulator;
|
|||
import org.apache.spark.SparkContext;
|
||||
import org.apache.spark.api.java.function.*;
|
||||
import org.apache.spark.api.java.JavaSparkContext;
|
||||
import org.apache.spark.sql.catalyst.encoders.Encoder;
|
||||
import org.apache.spark.sql.catalyst.encoders.Encoder$;
|
||||
import org.apache.spark.sql.Encoder;
|
||||
import org.apache.spark.sql.Encoders;
|
||||
import org.apache.spark.sql.Dataset;
|
||||
import org.apache.spark.sql.GroupedDataset;
|
||||
import org.apache.spark.sql.test.TestSQLContext;
|
||||
|
@ -41,7 +41,6 @@ import static org.apache.spark.sql.functions.*;
|
|||
public class JavaDatasetSuite implements Serializable {
|
||||
private transient JavaSparkContext jsc;
|
||||
private transient TestSQLContext context;
|
||||
private transient Encoder$ e = Encoder$.MODULE$;
|
||||
|
||||
@Before
|
||||
public void setUp() {
|
||||
|
@ -66,7 +65,7 @@ public class JavaDatasetSuite implements Serializable {
|
|||
@Test
|
||||
public void testCollect() {
|
||||
List<String> data = Arrays.asList("hello", "world");
|
||||
Dataset<String> ds = context.createDataset(data, e.STRING());
|
||||
Dataset<String> ds = context.createDataset(data, Encoders.STRING());
|
||||
List<String> collected = ds.collectAsList();
|
||||
Assert.assertEquals(Arrays.asList("hello", "world"), collected);
|
||||
}
|
||||
|
@ -74,7 +73,7 @@ public class JavaDatasetSuite implements Serializable {
|
|||
@Test
|
||||
public void testTake() {
|
||||
List<String> data = Arrays.asList("hello", "world");
|
||||
Dataset<String> ds = context.createDataset(data, e.STRING());
|
||||
Dataset<String> ds = context.createDataset(data, Encoders.STRING());
|
||||
List<String> collected = ds.takeAsList(1);
|
||||
Assert.assertEquals(Arrays.asList("hello"), collected);
|
||||
}
|
||||
|
@ -82,7 +81,7 @@ public class JavaDatasetSuite implements Serializable {
|
|||
@Test
|
||||
public void testCommonOperation() {
|
||||
List<String> data = Arrays.asList("hello", "world");
|
||||
Dataset<String> ds = context.createDataset(data, e.STRING());
|
||||
Dataset<String> ds = context.createDataset(data, Encoders.STRING());
|
||||
Assert.assertEquals("hello", ds.first());
|
||||
|
||||
Dataset<String> filtered = ds.filter(new FilterFunction<String>() {
|
||||
|
@ -99,7 +98,7 @@ public class JavaDatasetSuite implements Serializable {
|
|||
public Integer call(String v) throws Exception {
|
||||
return v.length();
|
||||
}
|
||||
}, e.INT());
|
||||
}, Encoders.INT());
|
||||
Assert.assertEquals(Arrays.asList(5, 5), mapped.collectAsList());
|
||||
|
||||
Dataset<String> parMapped = ds.mapPartitions(new MapPartitionsFunction<String, String>() {
|
||||
|
@ -111,7 +110,7 @@ public class JavaDatasetSuite implements Serializable {
|
|||
}
|
||||
return ls;
|
||||
}
|
||||
}, e.STRING());
|
||||
}, Encoders.STRING());
|
||||
Assert.assertEquals(Arrays.asList("HELLO", "WORLD"), parMapped.collectAsList());
|
||||
|
||||
Dataset<String> flatMapped = ds.flatMap(new FlatMapFunction<String, String>() {
|
||||
|
@ -123,7 +122,7 @@ public class JavaDatasetSuite implements Serializable {
|
|||
}
|
||||
return ls;
|
||||
}
|
||||
}, e.STRING());
|
||||
}, Encoders.STRING());
|
||||
Assert.assertEquals(
|
||||
Arrays.asList("h", "e", "l", "l", "o", "w", "o", "r", "l", "d"),
|
||||
flatMapped.collectAsList());
|
||||
|
@ -133,7 +132,7 @@ public class JavaDatasetSuite implements Serializable {
|
|||
public void testForeach() {
|
||||
final Accumulator<Integer> accum = jsc.accumulator(0);
|
||||
List<String> data = Arrays.asList("a", "b", "c");
|
||||
Dataset<String> ds = context.createDataset(data, e.STRING());
|
||||
Dataset<String> ds = context.createDataset(data, Encoders.STRING());
|
||||
|
||||
ds.foreach(new ForeachFunction<String>() {
|
||||
@Override
|
||||
|
@ -147,7 +146,7 @@ public class JavaDatasetSuite implements Serializable {
|
|||
@Test
|
||||
public void testReduce() {
|
||||
List<Integer> data = Arrays.asList(1, 2, 3);
|
||||
Dataset<Integer> ds = context.createDataset(data, e.INT());
|
||||
Dataset<Integer> ds = context.createDataset(data, Encoders.INT());
|
||||
|
||||
int reduced = ds.reduce(new ReduceFunction<Integer>() {
|
||||
@Override
|
||||
|
@ -161,13 +160,13 @@ public class JavaDatasetSuite implements Serializable {
|
|||
@Test
|
||||
public void testGroupBy() {
|
||||
List<String> data = Arrays.asList("a", "foo", "bar");
|
||||
Dataset<String> ds = context.createDataset(data, e.STRING());
|
||||
Dataset<String> ds = context.createDataset(data, Encoders.STRING());
|
||||
GroupedDataset<Integer, String> grouped = ds.groupBy(new MapFunction<String, Integer>() {
|
||||
@Override
|
||||
public Integer call(String v) throws Exception {
|
||||
return v.length();
|
||||
}
|
||||
}, e.INT());
|
||||
}, Encoders.INT());
|
||||
|
||||
Dataset<String> mapped = grouped.map(new MapGroupFunction<Integer, String, String>() {
|
||||
@Override
|
||||
|
@ -178,7 +177,7 @@ public class JavaDatasetSuite implements Serializable {
|
|||
}
|
||||
return sb.toString();
|
||||
}
|
||||
}, e.STRING());
|
||||
}, Encoders.STRING());
|
||||
|
||||
Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList());
|
||||
|
||||
|
@ -193,27 +192,27 @@ public class JavaDatasetSuite implements Serializable {
|
|||
return Collections.singletonList(sb.toString());
|
||||
}
|
||||
},
|
||||
e.STRING());
|
||||
Encoders.STRING());
|
||||
|
||||
Assert.assertEquals(Arrays.asList("1a", "3foobar"), flatMapped.collectAsList());
|
||||
|
||||
List<Integer> data2 = Arrays.asList(2, 6, 10);
|
||||
Dataset<Integer> ds2 = context.createDataset(data2, e.INT());
|
||||
Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT());
|
||||
GroupedDataset<Integer, Integer> grouped2 = ds2.groupBy(new MapFunction<Integer, Integer>() {
|
||||
@Override
|
||||
public Integer call(Integer v) throws Exception {
|
||||
return v / 2;
|
||||
}
|
||||
}, e.INT());
|
||||
}, Encoders.INT());
|
||||
|
||||
Dataset<String> cogrouped = grouped.cogroup(
|
||||
grouped2,
|
||||
new CoGroupFunction<Integer, String, Integer, String>() {
|
||||
@Override
|
||||
public Iterable<String> call(
|
||||
Integer key,
|
||||
Iterator<String> left,
|
||||
Iterator<Integer> right) throws Exception {
|
||||
Integer key,
|
||||
Iterator<String> left,
|
||||
Iterator<Integer> right) throws Exception {
|
||||
StringBuilder sb = new StringBuilder(key.toString());
|
||||
while (left.hasNext()) {
|
||||
sb.append(left.next());
|
||||
|
@ -225,7 +224,7 @@ public class JavaDatasetSuite implements Serializable {
|
|||
return Collections.singletonList(sb.toString());
|
||||
}
|
||||
},
|
||||
e.STRING());
|
||||
Encoders.STRING());
|
||||
|
||||
Assert.assertEquals(Arrays.asList("1a#2", "3foobar#6", "5#10"), cogrouped.collectAsList());
|
||||
}
|
||||
|
@ -233,8 +232,9 @@ public class JavaDatasetSuite implements Serializable {
|
|||
@Test
|
||||
public void testGroupByColumn() {
|
||||
List<String> data = Arrays.asList("a", "foo", "bar");
|
||||
Dataset<String> ds = context.createDataset(data, e.STRING());
|
||||
GroupedDataset<Integer, String> grouped = ds.groupBy(length(col("value"))).asKey(e.INT());
|
||||
Dataset<String> ds = context.createDataset(data, Encoders.STRING());
|
||||
GroupedDataset<Integer, String> grouped =
|
||||
ds.groupBy(length(col("value"))).asKey(Encoders.INT());
|
||||
|
||||
Dataset<String> mapped = grouped.map(
|
||||
new MapGroupFunction<Integer, String, String>() {
|
||||
|
@ -247,7 +247,7 @@ public class JavaDatasetSuite implements Serializable {
|
|||
return sb.toString();
|
||||
}
|
||||
},
|
||||
e.STRING());
|
||||
Encoders.STRING());
|
||||
|
||||
Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList());
|
||||
}
|
||||
|
@ -255,11 +255,11 @@ public class JavaDatasetSuite implements Serializable {
|
|||
@Test
|
||||
public void testSelect() {
|
||||
List<Integer> data = Arrays.asList(2, 6);
|
||||
Dataset<Integer> ds = context.createDataset(data, e.INT());
|
||||
Dataset<Integer> ds = context.createDataset(data, Encoders.INT());
|
||||
|
||||
Dataset<Tuple2<Integer, String>> selected = ds.select(
|
||||
expr("value + 1"),
|
||||
col("value").cast("string")).as(e.tuple(e.INT(), e.STRING()));
|
||||
col("value").cast("string")).as(Encoders.tuple(Encoders.INT(), Encoders.STRING()));
|
||||
|
||||
Assert.assertEquals(
|
||||
Arrays.asList(tuple2(3, "2"), tuple2(7, "6")),
|
||||
|
@ -269,14 +269,14 @@ public class JavaDatasetSuite implements Serializable {
|
|||
@Test
|
||||
public void testSetOperation() {
|
||||
List<String> data = Arrays.asList("abc", "abc", "xyz");
|
||||
Dataset<String> ds = context.createDataset(data, e.STRING());
|
||||
Dataset<String> ds = context.createDataset(data, Encoders.STRING());
|
||||
|
||||
Assert.assertEquals(
|
||||
Arrays.asList("abc", "xyz"),
|
||||
sort(ds.distinct().collectAsList().toArray(new String[0])));
|
||||
|
||||
List<String> data2 = Arrays.asList("xyz", "foo", "foo");
|
||||
Dataset<String> ds2 = context.createDataset(data2, e.STRING());
|
||||
Dataset<String> ds2 = context.createDataset(data2, Encoders.STRING());
|
||||
|
||||
Dataset<String> intersected = ds.intersect(ds2);
|
||||
Assert.assertEquals(Arrays.asList("xyz"), intersected.collectAsList());
|
||||
|
@ -298,9 +298,9 @@ public class JavaDatasetSuite implements Serializable {
|
|||
@Test
|
||||
public void testJoin() {
|
||||
List<Integer> data = Arrays.asList(1, 2, 3);
|
||||
Dataset<Integer> ds = context.createDataset(data, e.INT()).as("a");
|
||||
Dataset<Integer> ds = context.createDataset(data, Encoders.INT()).as("a");
|
||||
List<Integer> data2 = Arrays.asList(2, 3, 4);
|
||||
Dataset<Integer> ds2 = context.createDataset(data2, e.INT()).as("b");
|
||||
Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT()).as("b");
|
||||
|
||||
Dataset<Tuple2<Integer, Integer>> joined =
|
||||
ds.joinWith(ds2, col("a.value").equalTo(col("b.value")));
|
||||
|
@ -311,26 +311,28 @@ public class JavaDatasetSuite implements Serializable {
|
|||
|
||||
@Test
|
||||
public void testTupleEncoder() {
|
||||
Encoder<Tuple2<Integer, String>> encoder2 = e.tuple(e.INT(), e.STRING());
|
||||
Encoder<Tuple2<Integer, String>> encoder2 = Encoders.tuple(Encoders.INT(), Encoders.STRING());
|
||||
List<Tuple2<Integer, String>> data2 = Arrays.asList(tuple2(1, "a"), tuple2(2, "b"));
|
||||
Dataset<Tuple2<Integer, String>> ds2 = context.createDataset(data2, encoder2);
|
||||
Assert.assertEquals(data2, ds2.collectAsList());
|
||||
|
||||
Encoder<Tuple3<Integer, Long, String>> encoder3 = e.tuple(e.INT(), e.LONG(), e.STRING());
|
||||
Encoder<Tuple3<Integer, Long, String>> encoder3 =
|
||||
Encoders.tuple(Encoders.INT(), Encoders.LONG(), Encoders.STRING());
|
||||
List<Tuple3<Integer, Long, String>> data3 =
|
||||
Arrays.asList(new Tuple3<Integer, Long, String>(1, 2L, "a"));
|
||||
Dataset<Tuple3<Integer, Long, String>> ds3 = context.createDataset(data3, encoder3);
|
||||
Assert.assertEquals(data3, ds3.collectAsList());
|
||||
|
||||
Encoder<Tuple4<Integer, String, Long, String>> encoder4 =
|
||||
e.tuple(e.INT(), e.STRING(), e.LONG(), e.STRING());
|
||||
Encoders.tuple(Encoders.INT(), Encoders.STRING(), Encoders.LONG(), Encoders.STRING());
|
||||
List<Tuple4<Integer, String, Long, String>> data4 =
|
||||
Arrays.asList(new Tuple4<Integer, String, Long, String>(1, "b", 2L, "a"));
|
||||
Dataset<Tuple4<Integer, String, Long, String>> ds4 = context.createDataset(data4, encoder4);
|
||||
Assert.assertEquals(data4, ds4.collectAsList());
|
||||
|
||||
Encoder<Tuple5<Integer, String, Long, String, Boolean>> encoder5 =
|
||||
e.tuple(e.INT(), e.STRING(), e.LONG(), e.STRING(), e.BOOLEAN());
|
||||
Encoders.tuple(Encoders.INT(), Encoders.STRING(), Encoders.LONG(), Encoders.STRING(),
|
||||
Encoders.BOOLEAN());
|
||||
List<Tuple5<Integer, String, Long, String, Boolean>> data5 =
|
||||
Arrays.asList(new Tuple5<Integer, String, Long, String, Boolean>(1, "b", 2L, "a", true));
|
||||
Dataset<Tuple5<Integer, String, Long, String, Boolean>> ds5 =
|
||||
|
@ -342,7 +344,7 @@ public class JavaDatasetSuite implements Serializable {
|
|||
public void testNestedTupleEncoder() {
|
||||
// test ((int, string), string)
|
||||
Encoder<Tuple2<Tuple2<Integer, String>, String>> encoder =
|
||||
e.tuple(e.tuple(e.INT(), e.STRING()), e.STRING());
|
||||
Encoders.tuple(Encoders.tuple(Encoders.INT(), Encoders.STRING()), Encoders.STRING());
|
||||
List<Tuple2<Tuple2<Integer, String>, String>> data =
|
||||
Arrays.asList(tuple2(tuple2(1, "a"), "a"), tuple2(tuple2(2, "b"), "b"));
|
||||
Dataset<Tuple2<Tuple2<Integer, String>, String>> ds = context.createDataset(data, encoder);
|
||||
|
@ -350,7 +352,8 @@ public class JavaDatasetSuite implements Serializable {
|
|||
|
||||
// test (int, (string, string, long))
|
||||
Encoder<Tuple2<Integer, Tuple3<String, String, Long>>> encoder2 =
|
||||
e.tuple(e.INT(), e.tuple(e.STRING(), e.STRING(), e.LONG()));
|
||||
Encoders.tuple(Encoders.INT(),
|
||||
Encoders.tuple(Encoders.STRING(), Encoders.STRING(), Encoders.LONG()));
|
||||
List<Tuple2<Integer, Tuple3<String, String, Long>>> data2 =
|
||||
Arrays.asList(tuple2(1, new Tuple3<String, String, Long>("a", "b", 3L)));
|
||||
Dataset<Tuple2<Integer, Tuple3<String, String, Long>>> ds2 =
|
||||
|
@ -359,7 +362,8 @@ public class JavaDatasetSuite implements Serializable {
|
|||
|
||||
// test (int, ((string, long), string))
|
||||
Encoder<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> encoder3 =
|
||||
e.tuple(e.INT(), e.tuple(e.tuple(e.STRING(), e.LONG()), e.STRING()));
|
||||
Encoders.tuple(Encoders.INT(),
|
||||
Encoders.tuple(Encoders.tuple(Encoders.STRING(), Encoders.LONG()), Encoders.STRING()));
|
||||
List<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> data3 =
|
||||
Arrays.asList(tuple2(1, tuple2(tuple2("a", 2L), "b")));
|
||||
Dataset<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> ds3 =
|
||||
|
|
|
@ -17,13 +17,11 @@
|
|||
|
||||
package org.apache.spark.sql
|
||||
|
||||
import org.apache.spark.sql.catalyst.encoders.Encoder
|
||||
import org.apache.spark.sql.functions._
|
||||
|
||||
import scala.language.postfixOps
|
||||
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.expressions.Aggregator
|
||||
|
||||
/** An `Aggregator` that adds up any numeric type returned by the given function. */
|
||||
|
|
|
@ -24,7 +24,6 @@ import scala.collection.JavaConverters._
|
|||
import org.apache.spark.sql.catalyst.plans._
|
||||
import org.apache.spark.sql.catalyst.util._
|
||||
import org.apache.spark.sql.columnar.InMemoryRelation
|
||||
import org.apache.spark.sql.catalyst.encoders.Encoder
|
||||
|
||||
abstract class QueryTest extends PlanTest {
|
||||
|
||||
|
|
Loading…
Reference in a new issue