[SPARK-11954][SQL] Encoder for JavaBeans

Create Java versions of `constructorFor` and `extractorFor` in `JavaTypeInference`.

Author: Wenchen Fan <wenchen@databricks.com>

This patch had conflicts when merged, resolved by
Committer: Michael Armbrust <michael@databricks.com>

Closes #9937 from cloud-fan/pojo.
Wenchen Fan 2015-12-01 10:35:12 -08:00 committed by Michael Armbrust
parent 9df24624af
commit fd95eeaf49
9 changed files with 608 additions and 20 deletions


@ -97,6 +97,24 @@ object Encoders {
*/
def STRING: Encoder[java.lang.String] = ExpressionEncoder()
/**
* Creates an encoder for Java Bean of type T.
*
* T must be publicly accessible.
*
* Supported types for Java Bean fields:
* - primitive types: boolean, int, double, etc.
* - boxed types: Boolean, Integer, Double, etc.
* - String
* - java.math.BigDecimal
* - time related: java.sql.Date, java.sql.Timestamp
* - collection types: only array and java.util.List currently, map support is in progress
* - nested java bean.
*
* @since 1.6.0
*/
def bean[T](beanClass: Class[T]): Encoder[T] = ExpressionEncoder.javaBean(beanClass)
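For example, a minimal usage sketch of the new API (the `Person` bean and the `context` SQLContext handle are illustrative, not part of this patch; it mirrors the Java test added at the end of this commit):

public class Person implements Serializable {
  // A conventional bean: public class, no-arg constructor, getter/setter pairs.
  private String name;
  private int age;
  public String getName() { return name; }
  public void setName(String name) { this.name = name; }
  public int getAge() { return age; }
  public void setAge(int age) { this.age = age; }
}

// Build a typed Dataset directly from bean instances.
Encoder<Person> encoder = Encoders.bean(Person.class);
Dataset<Person> people = context.createDataset(Arrays.asList(new Person()), encoder);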
/**
* (Scala-specific) Creates an encoder that serializes objects of type T using Kryo.
* This encoder maps T into a single byte array (binary) field.


@ -17,14 +17,20 @@
package org.apache.spark.sql.catalyst
import java.beans.Introspector
import java.beans.{PropertyDescriptor, Introspector}
import java.lang.{Iterable => JIterable}
import java.util.{Iterator => JIterator, Map => JMap}
import java.util.{Iterator => JIterator, Map => JMap, List => JList}
import scala.language.existentials
import com.google.common.reflect.TypeToken
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils}
import org.apache.spark.unsafe.types.UTF8String
/**
* Type-inference utilities for POJOs and Java collections.
@ -33,13 +39,14 @@ object JavaTypeInference {
private val iterableType = TypeToken.of(classOf[JIterable[_]])
private val mapType = TypeToken.of(classOf[JMap[_, _]])
private val listType = TypeToken.of(classOf[JList[_]])
private val iteratorReturnType = classOf[JIterable[_]].getMethod("iterator").getGenericReturnType
private val nextReturnType = classOf[JIterator[_]].getMethod("next").getGenericReturnType
private val keySetReturnType = classOf[JMap[_, _]].getMethod("keySet").getGenericReturnType
private val valuesReturnType = classOf[JMap[_, _]].getMethod("values").getGenericReturnType
/**
* Infers the corresponding SQL data type of a JavaClean class.
* Infers the corresponding SQL data type of a JavaBean class.
* @param beanClass Java type
* @return (SQL data type, nullable)
*/
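As a sketch of what this inference yields for the hypothetical `Person` bean above (an `int` property and a `String` property; field order follows the bean's property descriptors):

// Roughly the StructType produced by inferDataType for Person:
StructType expected = new StructType()
    .add("age", DataTypes.IntegerType, false)   // primitive -> not nullable
    .add("name", DataTypes.StringType, true);   // reference type -> nullable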
@ -58,6 +65,8 @@ object JavaTypeInference {
(c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)
case c: Class[_] if c == classOf[java.lang.String] => (StringType, true)
case c: Class[_] if c == classOf[Array[Byte]] => (BinaryType, true)
case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false)
case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false)
case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false)
@ -87,15 +96,14 @@ object JavaTypeInference {
(ArrayType(dataType, nullable), true)
case _ if mapType.isAssignableFrom(typeToken) =>
val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]]
val mapSupertype = typeToken2.getSupertype(classOf[JMap[_, _]])
val keyType = elementType(mapSupertype.resolveType(keySetReturnType))
val valueType = elementType(mapSupertype.resolveType(valuesReturnType))
val (keyType, valueType) = mapKeyValueType(typeToken)
val (keyDataType, _) = inferDataType(keyType)
val (valueDataType, nullable) = inferDataType(valueType)
(MapType(keyDataType, valueDataType, nullable), true)
case _ =>
// TODO: we should only collect properties that have getter and setter. However, some tests
// pass in scala case class as java bean class which doesn't have getter and setter.
val beanInfo = Introspector.getBeanInfo(typeToken.getRawType)
val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
val fields = properties.map { property =>
@ -107,11 +115,294 @@ object JavaTypeInference {
}
}
private def getJavaBeanProperties(beanClass: Class[_]): Array[PropertyDescriptor] = {
val beanInfo = Introspector.getBeanInfo(beanClass)
beanInfo.getPropertyDescriptors
.filter(p => p.getReadMethod != null && p.getWriteMethod != null)
}
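Only properties exposing both a getter and a setter survive this filter. A hypothetical bean illustrating what gets dropped:

public class PartialBean implements Serializable {
  private int id;
  public int getId() { return id; }
  public void setId(int id) { this.id = id; }
  // Read-only property: no setter, so getJavaBeanProperties filters it out
  // and it never appears in the inferred schema.
  public String getLabel() { return "constant"; }
}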
private def elementType(typeToken: TypeToken[_]): TypeToken[_] = {
val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JIterable[_]]]
val iterableSupertype = typeToken2.getSupertype(classOf[JIterable[_]])
val iteratorType = iterableSupertype.resolveType(iteratorReturnType)
val itemType = iteratorType.resolveType(nextReturnType)
itemType
val iterableSuperType = typeToken2.getSupertype(classOf[JIterable[_]])
val iteratorType = iterableSuperType.resolveType(iteratorReturnType)
iteratorType.resolveType(nextReturnType)
}
private def mapKeyValueType(typeToken: TypeToken[_]): (TypeToken[_], TypeToken[_]) = {
val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]]
val mapSuperType = typeToken2.getSupertype(classOf[JMap[_, _]])
val keyType = elementType(mapSuperType.resolveType(keySetReturnType))
val valueType = elementType(mapSuperType.resolveType(valuesReturnType))
keyType -> valueType
}
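Both helpers lean on Guava's `TypeToken` to recover generic element types. A standalone sketch of the same resolution steps, using hypothetical example types (the class name `TypeResolutionSketch` is invented for illustration):

import com.google.common.reflect.TypeToken;
import java.lang.reflect.Type;
import java.util.*;

class TypeResolutionSketch {
  static void resolve() throws NoSuchMethodException {
    Type iteratorReturn = Iterable.class.getMethod("iterator").getGenericReturnType();
    Type nextReturn = Iterator.class.getMethod("next").getGenericReturnType();

    // List<String>: getSupertype rewrites the token as Iterable<String>, then the
    // iterator()/next() return types resolve to the element type String.
    TypeToken<List<String>> listToken = new TypeToken<List<String>>() {};
    TypeToken<?> element = listToken.getSupertype(Iterable.class)
        .resolveType(iteratorReturn)
        .resolveType(nextReturn);                 // -> java.lang.String

    // Map keys and values go through keySet()/values() the same way.
    Type keySetReturn = Map.class.getMethod("keySet").getGenericReturnType();
    TypeToken<Map<String, Integer>> mapToken = new TypeToken<Map<String, Integer>>() {};
    TypeToken<?> keys = mapToken.getSupertype(Map.class).resolveType(keySetReturn);
    // keys represents Set<String>; applying the element resolution above yields String.
  }
}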
/**
* Returns the Spark SQL DataType for a given java class. Where this is not an exact mapping
* to a native type, an ObjectType is returned.
*
* Unlike `inferDataType`, this function doesn't do any massaging of types into the Spark SQL type
* system. As a result, ObjectType will be returned for things like boxed Integers.
*/
private def inferExternalType(cls: Class[_]): DataType = cls match {
case c if c == java.lang.Boolean.TYPE => BooleanType
case c if c == java.lang.Byte.TYPE => ByteType
case c if c == java.lang.Short.TYPE => ShortType
case c if c == java.lang.Integer.TYPE => IntegerType
case c if c == java.lang.Long.TYPE => LongType
case c if c == java.lang.Float.TYPE => FloatType
case c if c == java.lang.Double.TYPE => DoubleType
case c if c == classOf[Array[Byte]] => BinaryType
case _ => ObjectType(cls)
}
/**
* Returns an expression that can be used to construct an object of java bean `T` given an input
* row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes
* of the same name as the constructor arguments. Nested classes will have their fields accessed
* using UnresolvedExtractValue.
*/
def constructorFor(beanClass: Class[_]): Expression = {
constructorFor(TypeToken.of(beanClass), None)
}
private def constructorFor(typeToken: TypeToken[_], path: Option[Expression]): Expression = {
/** Returns the current path with a sub-field extracted. */
def addToPath(part: String): Expression = path
.map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
.getOrElse(UnresolvedAttribute(part))
/** Returns the current path or `BoundReference`. */
def getPath: Expression = path.getOrElse(BoundReference(0, inferDataType(typeToken)._1, true))
typeToken.getRawType match {
case c if !inferExternalType(c).isInstanceOf[ObjectType] => getPath
case c if c == classOf[java.lang.Short] =>
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
case c if c == classOf[java.lang.Integer] =>
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
case c if c == classOf[java.lang.Long] =>
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
case c if c == classOf[java.lang.Double] =>
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
case c if c == classOf[java.lang.Byte] =>
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
case c if c == classOf[java.lang.Float] =>
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
case c if c == classOf[java.lang.Boolean] =>
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
case c if c == classOf[java.sql.Date] =>
StaticInvoke(
DateTimeUtils,
ObjectType(c),
"toJavaDate",
getPath :: Nil,
propagateNull = true)
case c if c == classOf[java.sql.Timestamp] =>
StaticInvoke(
DateTimeUtils,
ObjectType(c),
"toJavaTimestamp",
getPath :: Nil,
propagateNull = true)
case c if c == classOf[java.lang.String] =>
Invoke(getPath, "toString", ObjectType(classOf[String]))
case c if c == classOf[java.math.BigDecimal] =>
Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]))
case c if c.isArray =>
val elementType = c.getComponentType
val primitiveMethod = elementType match {
case c if c == java.lang.Boolean.TYPE => Some("toBooleanArray")
case c if c == java.lang.Byte.TYPE => Some("toByteArray")
case c if c == java.lang.Short.TYPE => Some("toShortArray")
case c if c == java.lang.Integer.TYPE => Some("toIntArray")
case c if c == java.lang.Long.TYPE => Some("toLongArray")
case c if c == java.lang.Float.TYPE => Some("toFloatArray")
case c if c == java.lang.Double.TYPE => Some("toDoubleArray")
case _ => None
}
primitiveMethod.map { method =>
Invoke(getPath, method, ObjectType(c))
}.getOrElse {
Invoke(
MapObjects(
p => constructorFor(typeToken.getComponentType, Some(p)),
getPath,
inferDataType(elementType)._1),
"array",
ObjectType(c))
}
case c if listType.isAssignableFrom(typeToken) =>
val et = elementType(typeToken)
val array =
Invoke(
MapObjects(
p => constructorFor(et, Some(p)),
getPath,
inferDataType(et)._1),
"array",
ObjectType(classOf[Array[Any]]))
StaticInvoke(classOf[java.util.Arrays], ObjectType(c), "asList", array :: Nil)
case _ if mapType.isAssignableFrom(typeToken) =>
val (keyType, valueType) = mapKeyValueType(typeToken)
val keyDataType = inferDataType(keyType)._1
val valueDataType = inferDataType(valueType)._1
val keyData =
Invoke(
MapObjects(
p => constructorFor(keyType, Some(p)),
Invoke(getPath, "keyArray", ArrayType(keyDataType)),
keyDataType),
"array",
ObjectType(classOf[Array[Any]]))
val valueData =
Invoke(
MapObjects(
p => constructorFor(valueType, Some(p)),
Invoke(getPath, "valueArray", ArrayType(valueDataType)),
valueDataType),
"array",
ObjectType(classOf[Array[Any]]))
StaticInvoke(
ArrayBasedMapData,
ObjectType(classOf[JMap[_, _]]),
"toJavaMap",
keyData :: valueData :: Nil)
case other =>
val properties = getJavaBeanProperties(other)
assert(properties.length > 0)
val setters = properties.map { p =>
val fieldName = p.getName
val fieldType = typeToken.method(p.getReadMethod).getReturnType
p.getWriteMethod.getName -> constructorFor(fieldType, Some(addToPath(fieldName)))
}.toMap
val newInstance = NewInstance(other, Nil, propagateNull = false, ObjectType(other))
val result = InitializeJavaBean(newInstance, setters)
if (path.nonEmpty) {
expressions.If(
IsNull(getPath),
expressions.Literal.create(null, ObjectType(other)),
result
)
} else {
result
}
}
}
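Taken together, the deserializer built here behaves, for the hypothetical `Person` bean, roughly like the hand-written Java below (a conceptual sketch against a generic `Row`; the real expressions are compiled and operate on InternalRow):

// Pseudo-expansion of NewInstance + InitializeJavaBean for a top-level bean:
Person bean = new Person();                               // no-arg constructor
bean.setAge(row.getInt(row.fieldIndex("age")));           // primitive read directly
bean.setName(row.isNullAt(row.fieldIndex("name"))
    ? null
    : row.getString(row.fieldIndex("name")));             // String materialized via toString
// A nested bean field gets an extra guard: if its whole struct is null, the field
// stays null instead of an empty bean being instantiated (the If(IsNull(...)) branch above).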
/**
* Returns expressions for extracting all the fields from the given type.
*/
def extractorsFor(beanClass: Class[_]): CreateNamedStruct = {
val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true)
extractorFor(inputObject, TypeToken.of(beanClass)).asInstanceOf[CreateNamedStruct]
}
private def extractorFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = {
def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = {
val (dataType, nullable) = inferDataType(elementType)
if (ScalaReflection.isNativeType(dataType)) {
NewInstance(
classOf[GenericArrayData],
input :: Nil,
dataType = ArrayType(dataType, nullable))
} else {
MapObjects(extractorFor(_, elementType), input, ObjectType(elementType.getRawType))
}
}
if (!inputObject.dataType.isInstanceOf[ObjectType]) {
inputObject
} else {
typeToken.getRawType match {
case c if c == classOf[String] =>
StaticInvoke(
classOf[UTF8String],
StringType,
"fromString",
inputObject :: Nil)
case c if c == classOf[java.sql.Timestamp] =>
StaticInvoke(
DateTimeUtils,
TimestampType,
"fromJavaTimestamp",
inputObject :: Nil)
case c if c == classOf[java.sql.Date] =>
StaticInvoke(
DateTimeUtils,
DateType,
"fromJavaDate",
inputObject :: Nil)
case c if c == classOf[java.math.BigDecimal] =>
StaticInvoke(
Decimal,
DecimalType.SYSTEM_DEFAULT,
"apply",
inputObject :: Nil)
case c if c == classOf[java.lang.Boolean] =>
Invoke(inputObject, "booleanValue", BooleanType)
case c if c == classOf[java.lang.Byte] =>
Invoke(inputObject, "byteValue", ByteType)
case c if c == classOf[java.lang.Short] =>
Invoke(inputObject, "shortValue", ShortType)
case c if c == classOf[java.lang.Integer] =>
Invoke(inputObject, "intValue", IntegerType)
case c if c == classOf[java.lang.Long] =>
Invoke(inputObject, "longValue", LongType)
case c if c == classOf[java.lang.Float] =>
Invoke(inputObject, "floatValue", FloatType)
case c if c == classOf[java.lang.Double] =>
Invoke(inputObject, "doubleValue", DoubleType)
case _ if typeToken.isArray =>
toCatalystArray(inputObject, typeToken.getComponentType)
case _ if listType.isAssignableFrom(typeToken) =>
toCatalystArray(inputObject, elementType(typeToken))
case _ if mapType.isAssignableFrom(typeToken) =>
// TODO: for a java map, if we get the keys and values via `keySet` and `values`, we cannot
// guarantee they have the same iteration order (unlike a Scala map).
// A possible solution is creating a new `MapObjects` that can iterate a map directly.
throw new UnsupportedOperationException("map type is not supported currently")
case other =>
val properties = getJavaBeanProperties(other)
if (properties.length > 0) {
CreateNamedStruct(properties.flatMap { p =>
val fieldName = p.getName
val fieldType = typeToken.method(p.getReadMethod).getReturnType
val fieldValue = Invoke(
inputObject,
p.getReadMethod.getName,
inferExternalType(fieldType.getRawType))
expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil
})
} else {
throw new UnsupportedOperationException(s"no encoder found for ${other.getName}")
}
}
}
}
}
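The serializer side is the mirror image: read each property through its getter and convert the value to Catalyst's internal representation, emitting one (name, value) pair per field. Roughly, again for the hypothetical `Person` bean:

// Pseudo-expansion of extractorsFor(Person.class):
Object[] struct = new Object[] {
    bean.getAge(),                                        // primitives pass through unchanged
    bean.getName() == null
        ? null
        : UTF8String.fromString(bean.getName())           // String -> UTF8String, per the match above
};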


@ -29,8 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.{JavaTypeInference, InternalRow, ScalaReflection}
import org.apache.spark.sql.types.{StructField, ObjectType, StructType}
/**
@ -68,6 +67,22 @@ object ExpressionEncoder {
ClassTag[T](cls))
}
// TODO: improve error message for java bean encoder.
def javaBean[T](beanClass: Class[T]): ExpressionEncoder[T] = {
val schema = JavaTypeInference.inferDataType(beanClass)._1
assert(schema.isInstanceOf[StructType])
val toRowExpression = JavaTypeInference.extractorsFor(beanClass)
val fromRowExpression = JavaTypeInference.constructorFor(beanClass)
new ExpressionEncoder[T](
schema.asInstanceOf[StructType],
flat = false,
toRowExpression.flatten,
fromRowExpression,
ClassTag[T](beanClass))
}
/**
* Given a set of N encoders, constructs a new encoder that produces objects as items in an
* N-tuple. Note that these encoders should be unresolved so that information about
@ -216,7 +231,7 @@ case class ExpressionEncoder[T](
*/
def assertUnresolved(): Unit = {
(fromRowExpression +: toRowExpressions).foreach(_.foreach {
case a: AttributeReference =>
case a: AttributeReference if a.name != "loopVar" =>
sys.error(s"Unresolved encoder expected, but $a was found.")
case _ =>
})


@ -346,7 +346,8 @@ case class LambdaVariable(value: String, isNull: String, dataType: DataType) ext
* as an ArrayType. This is similar to a typical map operation, but where the lambda function
* is expressed using catalyst expressions.
*
* The following collection ObjectTypes are currently supported: Seq, Array, ArrayData
* The following collection ObjectTypes are currently supported:
* Seq, Array, ArrayData, java.util.List
*
* @param function A function that returns an expression, given an attribute that can be used
* to access the current value. This is done as a lambda function so that
@ -386,6 +387,8 @@ case class MapObjects(
(".size()", (i: String) => s".apply($i)", false)
case ObjectType(cls) if cls.isArray =>
(".length", (i: String) => s"[$i]", false)
case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
(".size()", (i: String) => s".get($i)", false)
case ArrayType(t, _) =>
val (sqlType, primitiveElement) = t match {
case m: MapType => (m, false)
@ -596,3 +599,40 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B
override def dataType: DataType = ObjectType(tag.runtimeClass)
}
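The `MapObjects` change a few lines up teaches the generated loop how to walk a `java.util.List`: length comes from size() and elements from get(i). In spirit, the emitted loop looks like this hand-written sketch (`convertElement` stands in for the supplied lambda expression):

int length = input.size();                  // accessor pair chosen for java.util.List
Object[] converted = new Object[length];
for (int loopIndex = 0; loopIndex < length; loopIndex++) {
    Object loopVar = input.get(loopIndex);  // current element bound to the loop variable
    converted[loopIndex] = convertElement(loopVar);
}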
/**
* Initialize a Java Bean instance by setting its field values via setters.
*/
case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Expression])
extends Expression {
override def nullable: Boolean = beanInstance.nullable
override def children: Seq[Expression] = beanInstance +: setters.values.toSeq
override def dataType: DataType = beanInstance.dataType
override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val instanceGen = beanInstance.gen(ctx)
val initialize = setters.map {
case (setterMethod, fieldValue) =>
val fieldGen = fieldValue.gen(ctx)
s"""
${fieldGen.code}
${instanceGen.value}.$setterMethod(${fieldGen.value});
"""
}
ev.isNull = instanceGen.isNull
ev.value = instanceGen.value
s"""
${instanceGen.code}
if (!${instanceGen.isNull}) {
${initialize.mkString("\n")}
}
"""
}
}


@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.trees
import scala.collection.Map
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.types.{StructType, DataType}
@ -191,6 +193,19 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
case nonChild: AnyRef => nonChild
case null => null
}
case m: Map[_, _] => m.mapValues {
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = remainingNewChildren.remove(0)
val oldChild = remainingOldChildren.remove(0)
if (newChild fastEquals oldChild) {
oldChild
} else {
changed = true
newChild
}
case nonChild: AnyRef => nonChild
case null => null
}.view.force // `mapValues` is lazy and we need to force it to materialize
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = remainingNewChildren.remove(0)
val oldChild = remainingOldChildren.remove(0)
@ -262,7 +277,17 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
} else {
Some(arg)
}
case m: Map[_, _] => m
case m: Map[_, _] => m.mapValues {
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = nextOperation(arg.asInstanceOf[BaseType], rule)
if (!(newChild fastEquals arg)) {
changed = true
newChild
} else {
arg
}
case other => other
}.view.force // `mapValues` is lazy and we need to force it to materialize
case d: DataType => d // Avoid unpacking Structs
case args: Traversable[_] => args.map {
case arg: TreeNode[_] if containsChild(arg) =>


@ -70,4 +70,9 @@ object ArrayBasedMapData {
def toScalaMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = {
keys.zip(values).toMap
}
def toJavaMap(keys: Array[Any], values: Array[Any]): java.util.Map[Any, Any] = {
import scala.collection.JavaConverters._
keys.zip(values).toMap.asJava
}
}
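The new `toJavaMap` helper simply pairs up the parallel key and value arrays into a `java.util.Map`; a tiny sketch of the equivalent behaviour in plain Java:

Object[] keys = {"a", "b"};
Object[] values = {1, 2};
Map<Object, Object> map = new HashMap<>();
for (int i = 0; i < keys.length; i++) {
    map.put(keys[i], values[i]);            // toJavaMap(keys, values) yields {a=1, b=2}
}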


@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.util
import scala.collection.JavaConverters._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.{DataType, Decimal}
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@ -24,6 +26,7 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
class GenericArrayData(val array: Array[Any]) extends ArrayData {
def this(seq: Seq[Any]) = this(seq.toArray)
def this(list: java.util.List[Any]) = this(list.asScala)
// TODO: This is boxing. We should specialize.
def this(primitiveArray: Array[Int]) = this(primitiveArray.toSeq)


@ -38,6 +38,13 @@ case class ComplexPlan(exprs: Seq[Seq[Expression]])
override def output: Seq[Attribute] = Nil
}
case class ExpressionInMap(map: Map[String, Expression]) extends Expression with Unevaluable {
override def children: Seq[Expression] = map.values.toSeq
override def nullable: Boolean = true
override def dataType: NullType = NullType
override lazy val resolved = true
}
class TreeNodeSuite extends SparkFunSuite {
test("top node changed") {
val after = Literal(1) transform { case Literal(1, _) => Literal(2) }
@ -236,4 +243,22 @@ class TreeNodeSuite extends SparkFunSuite {
val expected = ComplexPlan(Seq(Seq(Literal("1")), Seq(Literal("2"))))
assert(expected === actual)
}
test("expressions inside a map") {
val expression = ExpressionInMap(Map("1" -> Literal(1), "2" -> Literal(2)))
{
val actual = expression.transform {
case Literal(i: Int, _) => Literal(i + 1)
}
val expected = ExpressionInMap(Map("1" -> Literal(2), "2" -> Literal(3)))
assert(actual === expected)
}
{
val actual = expression.withNewChildren(Seq(Literal(2), Literal(3)))
val expected = ExpressionInMap(Map("1" -> Literal(2), "2" -> Literal(3)))
assert(actual === expected)
}
}
}


@ -31,14 +31,15 @@ import org.apache.spark.Accumulator;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.function.*;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.GroupedDataset;
import org.apache.spark.sql.*;
import org.apache.spark.sql.expressions.Aggregator;
import org.apache.spark.sql.test.TestSQLContext;
import org.apache.spark.sql.catalyst.encoders.OuterScopes;
import org.apache.spark.sql.catalyst.expressions.GenericRow;
import org.apache.spark.sql.types.StructType;
import static org.apache.spark.sql.functions.*;
import static org.apache.spark.sql.types.DataTypes.*;
public class JavaDatasetSuite implements Serializable {
private transient JavaSparkContext jsc;
@ -506,4 +507,169 @@ public class JavaDatasetSuite implements Serializable {
public void testKryoEncoderErrorMessageForPrivateClass() {
Encoders.kryo(PrivateClassTest.class);
}
public class SimpleJavaBean implements Serializable {
private boolean a;
private int b;
private byte[] c;
private String[] d;
private List<String> e;
private List<Long> f;
public boolean isA() {
return a;
}
public void setA(boolean a) {
this.a = a;
}
public int getB() {
return b;
}
public void setB(int b) {
this.b = b;
}
public byte[] getC() {
return c;
}
public void setC(byte[] c) {
this.c = c;
}
public String[] getD() {
return d;
}
public void setD(String[] d) {
this.d = d;
}
public List<String> getE() {
return e;
}
public void setE(List<String> e) {
this.e = e;
}
public List<Long> getF() {
return f;
}
public void setF(List<Long> f) {
this.f = f;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
SimpleJavaBean that = (SimpleJavaBean) o;
if (a != that.a) return false;
if (b != that.b) return false;
if (!Arrays.equals(c, that.c)) return false;
if (!Arrays.equals(d, that.d)) return false;
if (!e.equals(that.e)) return false;
return f.equals(that.f);
}
@Override
public int hashCode() {
int result = (a ? 1 : 0);
result = 31 * result + b;
result = 31 * result + Arrays.hashCode(c);
result = 31 * result + Arrays.hashCode(d);
result = 31 * result + e.hashCode();
result = 31 * result + f.hashCode();
return result;
}
}
public class NestedJavaBean implements Serializable {
private SimpleJavaBean a;
public SimpleJavaBean getA() {
return a;
}
public void setA(SimpleJavaBean a) {
this.a = a;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
NestedJavaBean that = (NestedJavaBean) o;
return a.equals(that.a);
}
@Override
public int hashCode() {
return a.hashCode();
}
}
@Test
public void testJavaBeanEncoder() {
OuterScopes.addOuterScope(this);
SimpleJavaBean obj1 = new SimpleJavaBean();
obj1.setA(true);
obj1.setB(3);
obj1.setC(new byte[]{1, 2});
obj1.setD(new String[]{"hello", null});
obj1.setE(Arrays.asList("a", "b"));
obj1.setF(Arrays.asList(100L, null, 200L));
SimpleJavaBean obj2 = new SimpleJavaBean();
obj2.setA(false);
obj2.setB(30);
obj2.setC(new byte[]{3, 4});
obj2.setD(new String[]{null, "world"});
obj2.setE(Arrays.asList("x", "y"));
obj2.setF(Arrays.asList(300L, null, 400L));
List<SimpleJavaBean> data = Arrays.asList(obj1, obj2);
Dataset<SimpleJavaBean> ds = context.createDataset(data, Encoders.bean(SimpleJavaBean.class));
Assert.assertEquals(data, ds.collectAsList());
NestedJavaBean obj3 = new NestedJavaBean();
obj3.setA(obj1);
List<NestedJavaBean> data2 = Arrays.asList(obj3);
Dataset<NestedJavaBean> ds2 = context.createDataset(data2, Encoders.bean(NestedJavaBean.class));
Assert.assertEquals(data2, ds2.collectAsList());
Row row1 = new GenericRow(new Object[]{
true,
3,
new byte[]{1, 2},
new String[]{"hello", null},
Arrays.asList("a", "b"),
Arrays.asList(100L, null, 200L)});
Row row2 = new GenericRow(new Object[]{
false,
30,
new byte[]{3, 4},
new String[]{null, "world"},
Arrays.asList("x", "y"),
Arrays.asList(300L, null, 400L)});
StructType schema = new StructType()
.add("a", BooleanType, false)
.add("b", IntegerType, false)
.add("c", BinaryType)
.add("d", createArrayType(StringType))
.add("e", createArrayType(StringType))
.add("f", createArrayType(LongType));
Dataset<SimpleJavaBean> ds3 = context.createDataFrame(Arrays.asList(row1, row2), schema)
.as(Encoders.bean(SimpleJavaBean.class));
Assert.assertEquals(data, ds3.collectAsList());
}
}