[SPARK-25746][SQL] Refactoring ExpressionEncoder to get rid of flat flag

## What changes were proposed in this pull request?

This was inspired while implementing #21732. Currently `ScalaReflection` needs to consider how `ExpressionEncoder` uses the generated serializers and deserializers, and `ExpressionEncoder` has a weird `flat` flag. After discussion with cloud-fan, it seems better to refactor `ExpressionEncoder`. It should also make SPARK-24762 easier to do.

To summarize the proposed changes:

1. `serializerFor` and `deserializerFor` return expressions for serializing/deserializing a given input expression for a given type. They are private and should not be called directly.
2. `serializerForType` and `deserializerForType` return an expression for serializing/deserializing an object of type `T` to/from its Spark SQL representation. They assume the input object/Spark SQL representation is located at ordinal 0 of a row.

In other words, `serializerForType` and `deserializerForType` return expressions that serialize/deserialize a JVM object to/from a Spark SQL value as a whole.

A serializer returned by `serializerForType` serializes the object at `row(0)` into the corresponding Spark SQL representation, e.g. a primitive value, array, map, or struct.

A deserializer returned by `deserializerForType` deserializes the input field at `row(0)` into an object of the given type.
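To make the contract concrete, here is a minimal sketch (not part of the diff) that mirrors the helper methods added to `ScalaReflectionSuite`; the `Map[Int, Int]` expectations at the end are taken from the existing assertions in that suite:

```scala
import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions.Expression

// Both methods take a Scala reflection `Type` and assume the value sits at
// ordinal 0 of a row, so no extra input expression needs to be passed in.
def serializerFor[T: TypeTag]: Expression =
  ScalaReflection.serializerForType(ScalaReflection.localTypeOf[T])

def deserializerFor[T: TypeTag]: Expression =
  ScalaReflection.deserializerForType(ScalaReflection.localTypeOf[T])

// For example (per ScalaReflectionSuite):
// serializerFor[Map[Int, Int]].dataType == MapType(IntegerType, IntegerType, valueContainsNull = false)
// deserializerFor[Map[Int, Int]].dataType == ObjectType(classOf[Map[_, _]])
```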

3. The constructor of `ExpressionEncoder` now takes a pair of object-level serializer and deserializer expressions for type `T`, and uses them to derive the serializer and deserializer for `T` <-> row conversion. `ExpressionEncoder` doesn't need to remember whether the serializer is flat anymore. When we need to construct a new `ExpressionEncoder` based on existing ones, we only need to change the input location in the object-level serializer and deserializer, as sketched below.
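For reference, a condensed sketch of the new construction path (paraphrasing `ExpressionEncoder.apply` after this change; `Person` is just a hypothetical example class):

```scala
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.typeTag

import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

case class Person(name: String, age: Int)  // hypothetical example type

// The encoder is built from the object-level serializer/deserializer only.
// `schema`, the flattened `serializer` list, the row-level `deserializer` and
// `isSerializedAsStruct` are all derived inside ExpressionEncoder now.
val mirror = ScalaReflection.mirror
val tpe = typeTag[Person].in(mirror).tpe
val cls = mirror.runtimeClass(tpe)

val encoder = new ExpressionEncoder[Person](
  objSerializer = ScalaReflection.serializerForType(tpe),
  objDeserializer = ScalaReflection.deserializerForType(tpe),
  clsTag = ClassTag(cls))
```

When building derived encoders (e.g. the tuple encoder), only the input locations need to be rewritten: the single `BoundReference` in each child's `objSerializer` is replaced with an `Invoke` on the corresponding tuple field, and the single `GetColumnByOrdinal(0, _)` in each child's `objDeserializer` is replaced with a `GetStructField` at the right ordinal.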

## How was this patch tested?

Existing tests.

Closes #22749 from viirya/SPARK-24762-refactor.

Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Liang-Chi Hsieh 2018-10-25 19:27:45 +08:00 committed by Wenchen Fan
parent ddd1b1e8ae
commit cb5ea201df
12 changed files with 304 additions and 285 deletions


@ -203,12 +203,10 @@ object Encoders {
validatePublicClass[T]()
ExpressionEncoder[T](
schema = new StructType().add("value", BinaryType),
flat = true,
serializer = Seq(
objSerializer =
EncodeUsingSerializer(
BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)),
deserializer =
BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo),
objDeserializer =
DecodeUsingSerializer[T](
Cast(GetColumnByOrdinal(0, BinaryType), BinaryType),
classTag[T],


@ -187,26 +187,23 @@ object JavaTypeInference {
}
/**
* Returns an expression that can be used to deserialize an internal row to an object of java bean
* `T` with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes
* of the same name as the constructor arguments. Nested classes will have their fields accessed
* using UnresolvedExtractValue.
* Returns an expression that can be used to deserialize a Spark SQL representation to an object
* of java bean `T` with a compatible schema. The Spark SQL representation is located at ordinal
* 0 of a row, i.e., `GetColumnByOrdinal(0, _)`. Nested classes will have their fields accessed
* using `UnresolvedExtractValue`.
*/
def deserializerFor(beanClass: Class[_]): Expression = {
deserializerFor(TypeToken.of(beanClass), None)
val typeToken = TypeToken.of(beanClass)
deserializerFor(typeToken, GetColumnByOrdinal(0, inferDataType(typeToken)._1))
}
private def deserializerFor(typeToken: TypeToken[_], path: Option[Expression]): Expression = {
private def deserializerFor(typeToken: TypeToken[_], path: Expression): Expression = {
/** Returns the current path with a sub-field extracted. */
def addToPath(part: String): Expression = path
.map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
.getOrElse(UnresolvedAttribute(part))
/** Returns the current path or `GetColumnByOrdinal`. */
def getPath: Expression = path.getOrElse(GetColumnByOrdinal(0, inferDataType(typeToken)._1))
def addToPath(part: String): Expression = UnresolvedExtractValue(path,
expressions.Literal(part))
typeToken.getRawType match {
case c if !inferExternalType(c).isInstanceOf[ObjectType] => getPath
case c if !inferExternalType(c).isInstanceOf[ObjectType] => path
case c if c == classOf[java.lang.Short] ||
c == classOf[java.lang.Integer] ||
@ -219,7 +216,7 @@ object JavaTypeInference {
c,
ObjectType(c),
"valueOf",
getPath :: Nil,
path :: Nil,
returnNullable = false)
case c if c == classOf[java.sql.Date] =>
@ -227,7 +224,7 @@ object JavaTypeInference {
DateTimeUtils.getClass,
ObjectType(c),
"toJavaDate",
getPath :: Nil,
path :: Nil,
returnNullable = false)
case c if c == classOf[java.sql.Timestamp] =>
@ -235,14 +232,14 @@ object JavaTypeInference {
DateTimeUtils.getClass,
ObjectType(c),
"toJavaTimestamp",
getPath :: Nil,
path :: Nil,
returnNullable = false)
case c if c == classOf[java.lang.String] =>
Invoke(getPath, "toString", ObjectType(classOf[String]))
Invoke(path, "toString", ObjectType(classOf[String]))
case c if c == classOf[java.math.BigDecimal] =>
Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]))
Invoke(path, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]))
case c if c.isArray =>
val elementType = c.getComponentType
@ -258,12 +255,12 @@ object JavaTypeInference {
}
primitiveMethod.map { method =>
Invoke(getPath, method, ObjectType(c))
Invoke(path, method, ObjectType(c))
}.getOrElse {
Invoke(
MapObjects(
p => deserializerFor(typeToken.getComponentType, Some(p)),
getPath,
p => deserializerFor(typeToken.getComponentType, p),
path,
inferDataType(elementType)._1),
"array",
ObjectType(c))
@ -272,8 +269,8 @@ object JavaTypeInference {
case c if listType.isAssignableFrom(typeToken) =>
val et = elementType(typeToken)
UnresolvedMapObjects(
p => deserializerFor(et, Some(p)),
getPath,
p => deserializerFor(et, p),
path,
customCollectionCls = Some(c))
case _ if mapType.isAssignableFrom(typeToken) =>
@ -282,16 +279,16 @@ object JavaTypeInference {
val keyData =
Invoke(
UnresolvedMapObjects(
p => deserializerFor(keyType, Some(p)),
GetKeyArrayFromMap(getPath)),
p => deserializerFor(keyType, p),
GetKeyArrayFromMap(path)),
"array",
ObjectType(classOf[Array[Any]]))
val valueData =
Invoke(
UnresolvedMapObjects(
p => deserializerFor(valueType, Some(p)),
GetValueArrayFromMap(getPath)),
p => deserializerFor(valueType, p),
GetValueArrayFromMap(path)),
"array",
ObjectType(classOf[Array[Any]]))
@ -307,7 +304,7 @@ object JavaTypeInference {
other,
ObjectType(other),
"valueOf",
Invoke(getPath, "toString", ObjectType(classOf[String]), returnNullable = false) :: Nil,
Invoke(path, "toString", ObjectType(classOf[String]), returnNullable = false) :: Nil,
returnNullable = false)
case other =>
@ -316,7 +313,7 @@ object JavaTypeInference {
val fieldName = p.getName
val fieldType = typeToken.method(p.getReadMethod).getReturnType
val (_, nullable) = inferDataType(fieldType)
val constructor = deserializerFor(fieldType, Some(addToPath(fieldName)))
val constructor = deserializerFor(fieldType, addToPath(fieldName))
val setter = if (nullable) {
constructor
} else {
@ -328,28 +325,23 @@ object JavaTypeInference {
val newInstance = NewInstance(other, Nil, ObjectType(other), propagateNull = false)
val result = InitializeJavaBean(newInstance, setters)
if (path.nonEmpty) {
expressions.If(
IsNull(getPath),
expressions.Literal.create(null, ObjectType(other)),
result
)
} else {
expressions.If(
IsNull(path),
expressions.Literal.create(null, ObjectType(other)),
result
}
)
}
}
/**
* Returns an expression for serializing an object of the given type to an internal row.
* Returns an expression for serializing an object of the given type to a Spark SQL
* representation. The input object is located at ordinal 0 of a row, i.e.,
* `BoundReference(0, _)`.
*/
def serializerFor(beanClass: Class[_]): CreateNamedStruct = {
def serializerFor(beanClass: Class[_]): Expression = {
val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true)
val nullSafeInput = AssertNotNull(inputObject, Seq("top level input bean"))
serializerFor(nullSafeInput, TypeToken.of(beanClass)) match {
case expressions.If(_, _, s: CreateNamedStruct) => s
case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil)
}
serializerFor(nullSafeInput, TypeToken.of(beanClass))
}
private def serializerFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = {


@ -24,7 +24,7 @@ import scala.util.Properties
import org.apache.commons.lang3.reflect.ConstructorUtils
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, GenericArrayData, MapData}
@ -129,21 +129,44 @@ object ScalaReflection extends ScalaReflection {
}
/**
* Returns an expression that can be used to deserialize an input row to an object of type `T`
* with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes
* of the same name as the constructor arguments. Nested classes will have their fields accessed
* using UnresolvedExtractValue.
* When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff
* and lost the required data type, which may lead to runtime error if the real type doesn't
* match the encoder's schema.
* For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type
* is [a: int, b: long], then we will hit runtime error and say that we can't construct class
* `Data` with int and long, because we lost the information that `b` should be a string.
*
* When used on a primitive type, the constructor will instead default to extracting the value
* from ordinal 0 (since there are no names to map to). The actual location can be moved by
* calling resolve/bind with a new schema.
* This method help us "remember" the required data type by adding a `UpCast`. Note that we
* only need to do this for leaf nodes.
*/
def deserializerFor[T : TypeTag]: Expression = {
val tpe = localTypeOf[T]
private def upCastToExpectedType(expr: Expression, expected: DataType,
walkedTypePath: Seq[String]): Expression = expected match {
case _: StructType => expr
case _: ArrayType => expr
// TODO: ideally we should also skip MapType, but nested StructType inside MapType is rare and
// it's not trivial to support by-name resolution for StructType inside MapType.
case _ => UpCast(expr, expected, walkedTypePath)
}
/**
* Returns an expression that can be used to deserialize a Spark SQL representation to an object
* of type `T` with a compatible schema. The Spark SQL representation is located at ordinal 0 of
* a row, i.e., `GetColumnByOrdinal(0, _)`. Nested classes will have their fields accessed using
* `UnresolvedExtractValue`.
*
* The returned expression is used by `ExpressionEncoder`. The encoder will resolve and bind this
* deserializer expression when using it.
*/
def deserializerForType(tpe: `Type`): Expression = {
val clsName = getClassNameFromType(tpe)
val walkedTypePath = s"""- root class: "$clsName"""" :: Nil
val expr = deserializerFor(tpe, None, walkedTypePath)
val Schema(_, nullable) = schemaFor(tpe)
val Schema(dataType, nullable) = schemaFor(tpe)
// Assumes we are deserializing the first column of a row.
val input = upCastToExpectedType(GetColumnByOrdinal(0, dataType), dataType,
walkedTypePath)
val expr = deserializerFor(tpe, input, walkedTypePath)
if (nullable) {
expr
} else {
@ -151,16 +174,22 @@ object ScalaReflection extends ScalaReflection {
}
}
/**
* Returns an expression that can be used to deserialize an input expression to an object of type
* `T` with a compatible schema.
*
* @param tpe The `Type` of deserialized object.
* @param path The expression which can be used to extract serialized value.
* @param walkedTypePath The paths from top to bottom to access current field when deserializing.
*/
private def deserializerFor(
tpe: `Type`,
path: Option[Expression],
path: Expression,
walkedTypePath: Seq[String]): Expression = cleanUpReflectionObjects {
/** Returns the current path with a sub-field extracted. */
def addToPath(part: String, dataType: DataType, walkedTypePath: Seq[String]): Expression = {
val newPath = path
.map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
.getOrElse(UnresolvedAttribute.quoted(part))
val newPath = UnresolvedExtractValue(path, expressions.Literal(part))
upCastToExpectedType(newPath, dataType, walkedTypePath)
}
@ -169,46 +198,12 @@ object ScalaReflection extends ScalaReflection {
ordinal: Int,
dataType: DataType,
walkedTypePath: Seq[String]): Expression = {
val newPath = path
.map(p => GetStructField(p, ordinal))
.getOrElse(GetColumnByOrdinal(ordinal, dataType))
val newPath = GetStructField(path, ordinal)
upCastToExpectedType(newPath, dataType, walkedTypePath)
}
/** Returns the current path or `GetColumnByOrdinal`. */
def getPath: Expression = {
val dataType = schemaFor(tpe).dataType
if (path.isDefined) {
path.get
} else {
upCastToExpectedType(GetColumnByOrdinal(0, dataType), dataType, walkedTypePath)
}
}
/**
* When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff
* and lost the required data type, which may lead to runtime error if the real type doesn't
* match the encoder's schema.
* For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type
* is [a: int, b: long], then we will hit runtime error and say that we can't construct class
* `Data` with int and long, because we lost the information that `b` should be a string.
*
* This method help us "remember" the required data type by adding a `UpCast`. Note that we
* only need to do this for leaf nodes.
*/
def upCastToExpectedType(
expr: Expression,
expected: DataType,
walkedTypePath: Seq[String]): Expression = expected match {
case _: StructType => expr
case _: ArrayType => expr
// TODO: ideally we should also skip MapType, but nested StructType inside MapType is rare and
// it's not trivial to support by-name resolution for StructType inside MapType.
case _ => UpCast(expr, expected, walkedTypePath)
}
tpe.dealias match {
case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath
case t if !dataTypeFor(t).isInstanceOf[ObjectType] => path
case t if t <:< localTypeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
@ -219,44 +214,44 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< localTypeOf[java.lang.Integer] =>
val boxedType = classOf[java.lang.Integer]
val objectType = ObjectType(boxedType)
StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)
StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false)
case t if t <:< localTypeOf[java.lang.Long] =>
val boxedType = classOf[java.lang.Long]
val objectType = ObjectType(boxedType)
StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)
StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false)
case t if t <:< localTypeOf[java.lang.Double] =>
val boxedType = classOf[java.lang.Double]
val objectType = ObjectType(boxedType)
StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)
StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false)
case t if t <:< localTypeOf[java.lang.Float] =>
val boxedType = classOf[java.lang.Float]
val objectType = ObjectType(boxedType)
StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)
StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false)
case t if t <:< localTypeOf[java.lang.Short] =>
val boxedType = classOf[java.lang.Short]
val objectType = ObjectType(boxedType)
StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)
StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false)
case t if t <:< localTypeOf[java.lang.Byte] =>
val boxedType = classOf[java.lang.Byte]
val objectType = ObjectType(boxedType)
StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)
StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false)
case t if t <:< localTypeOf[java.lang.Boolean] =>
val boxedType = classOf[java.lang.Boolean]
val objectType = ObjectType(boxedType)
StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)
StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false)
case t if t <:< localTypeOf[java.sql.Date] =>
StaticInvoke(
DateTimeUtils.getClass,
ObjectType(classOf[java.sql.Date]),
"toJavaDate",
getPath :: Nil,
path :: Nil,
returnNullable = false)
case t if t <:< localTypeOf[java.sql.Timestamp] =>
@ -264,25 +259,25 @@ object ScalaReflection extends ScalaReflection {
DateTimeUtils.getClass,
ObjectType(classOf[java.sql.Timestamp]),
"toJavaTimestamp",
getPath :: Nil,
path :: Nil,
returnNullable = false)
case t if t <:< localTypeOf[java.lang.String] =>
Invoke(getPath, "toString", ObjectType(classOf[String]), returnNullable = false)
Invoke(path, "toString", ObjectType(classOf[String]), returnNullable = false)
case t if t <:< localTypeOf[java.math.BigDecimal] =>
Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]),
Invoke(path, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]),
returnNullable = false)
case t if t <:< localTypeOf[BigDecimal] =>
Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal]), returnNullable = false)
Invoke(path, "toBigDecimal", ObjectType(classOf[BigDecimal]), returnNullable = false)
case t if t <:< localTypeOf[java.math.BigInteger] =>
Invoke(getPath, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]),
Invoke(path, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]),
returnNullable = false)
case t if t <:< localTypeOf[scala.math.BigInt] =>
Invoke(getPath, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]),
Invoke(path, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]),
returnNullable = false)
case t if t <:< localTypeOf[Array[_]] =>
@ -294,7 +289,7 @@ object ScalaReflection extends ScalaReflection {
val mapFunction: Expression => Expression = element => {
// upcast the array element to the data type the encoder expected.
val casted = upCastToExpectedType(element, dataType, newTypePath)
val converter = deserializerFor(elementType, Some(casted), newTypePath)
val converter = deserializerFor(elementType, casted, newTypePath)
if (elementNullable) {
converter
} else {
@ -302,7 +297,7 @@ object ScalaReflection extends ScalaReflection {
}
}
val arrayData = UnresolvedMapObjects(mapFunction, getPath)
val arrayData = UnresolvedMapObjects(mapFunction, path)
val arrayCls = arrayClassFor(elementType)
if (elementNullable) {
@ -334,7 +329,7 @@ object ScalaReflection extends ScalaReflection {
val mapFunction: Expression => Expression = element => {
// upcast the array element to the data type the encoder expected.
val casted = upCastToExpectedType(element, dataType, newTypePath)
val converter = deserializerFor(elementType, Some(casted), newTypePath)
val converter = deserializerFor(elementType, casted, newTypePath)
if (elementNullable) {
converter
} else {
@ -349,16 +344,16 @@ object ScalaReflection extends ScalaReflection {
classOf[scala.collection.Set[_]]
case _ => mirror.runtimeClass(t.typeSymbol.asClass)
}
UnresolvedMapObjects(mapFunction, getPath, Some(cls))
UnresolvedMapObjects(mapFunction, path, Some(cls))
case t if t <:< localTypeOf[Map[_, _]] =>
// TODO: add walked type path for map
val TypeRef(_, _, Seq(keyType, valueType)) = t
CatalystToExternalMap(
p => deserializerFor(keyType, Some(p), walkedTypePath),
p => deserializerFor(valueType, Some(p), walkedTypePath),
getPath,
p => deserializerFor(keyType, p, walkedTypePath),
p => deserializerFor(valueType, p, walkedTypePath),
path,
mirror.runtimeClass(t.typeSymbol.asClass)
)
@ -368,7 +363,7 @@ object ScalaReflection extends ScalaReflection {
udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
Nil,
dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil)
case t if UDTRegistration.exists(getClassNameFromType(t)) =>
val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance()
@ -377,7 +372,7 @@ object ScalaReflection extends ScalaReflection {
udt.getClass,
Nil,
dataType = ObjectType(udt.getClass))
Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil)
case t if definedByConstructorParams(t) =>
val params = getConstructorParameters(t)
@ -392,12 +387,12 @@ object ScalaReflection extends ScalaReflection {
val constructor = if (cls.getName startsWith "scala.Tuple") {
deserializerFor(
fieldType,
Some(addToPathOrdinal(i, dataType, newTypePath)),
addToPathOrdinal(i, dataType, newTypePath),
newTypePath)
} else {
deserializerFor(
fieldType,
Some(addToPath(fieldName, dataType, newTypePath)),
addToPath(fieldName, dataType, newTypePath),
newTypePath)
}
@ -410,20 +405,17 @@ object ScalaReflection extends ScalaReflection {
val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false)
if (path.nonEmpty) {
expressions.If(
IsNull(getPath),
expressions.Literal.create(null, ObjectType(cls)),
newInstance
)
} else {
expressions.If(
IsNull(path),
expressions.Literal.create(null, ObjectType(cls)),
newInstance
}
)
}
}
/**
* Returns an expression for serializing an object of type T to an internal row.
* Returns an expression for serializing an object of type T to Spark SQL representation. The
* input object is located at ordinal 0 of a row, i.e., `BoundReference(0, _)`.
*
* If the given type is not supported, i.e. there is no encoder can be built for this type,
* an [[UnsupportedOperationException]] will be thrown with detailed error message to explain
@ -434,17 +426,21 @@ object ScalaReflection extends ScalaReflection {
* * the element type of [[Array]] or [[Seq]]: `array element class: "abc.xyz.MyClass"`
* * the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")`
*/
def serializerFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = {
val tpe = localTypeOf[T]
def serializerForType(tpe: `Type`): Expression = ScalaReflection.cleanUpReflectionObjects {
val clsName = getClassNameFromType(tpe)
val walkedTypePath = s"""- root class: "$clsName"""" :: Nil
serializerFor(inputObject, tpe, walkedTypePath) match {
case expressions.If(_, _, s: CreateNamedStruct) if definedByConstructorParams(tpe) => s
case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil)
}
// The input object to `ExpressionEncoder` is located at first column of an row.
val inputObject = BoundReference(0, dataTypeFor(tpe),
nullable = !tpe.typeSymbol.asClass.isPrimitive)
serializerFor(inputObject, tpe, walkedTypePath)
}
/** Helper for extracting internal fields from a case class. */
/**
* Returns an expression for serializing the value of an input expression into Spark SQL
* internal representation.
*/
private def serializerFor(
inputObject: Expression,
tpe: `Type`,


@ -25,10 +25,11 @@ import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaRefle
import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, NewInstance}
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, InitializeJavaBean, Invoke, NewInstance}
import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LocalRelation}
import org.apache.spark.sql.types.{BooleanType, ObjectType, StructField, StructType}
import org.apache.spark.sql.types.{ObjectType, StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
/**
@ -43,8 +44,8 @@ import org.apache.spark.util.Utils
* to the name `value`.
*/
object ExpressionEncoder {
def apply[T : TypeTag](): ExpressionEncoder[T] = {
// We convert the not-serializable TypeTag into StructType and ClassTag.
val mirror = ScalaReflection.mirror
val tpe = typeTag[T].in(mirror).tpe
@ -58,25 +59,11 @@ object ExpressionEncoder {
}
val cls = mirror.runtimeClass(tpe)
val flat = !ScalaReflection.definedByConstructorParams(tpe)
val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = !cls.isPrimitive)
val nullSafeInput = if (flat) {
inputObject
} else {
// For input object of Product type, we can't encode it to row if it's null, as Spark SQL
// doesn't allow top-level row to be null, only its columns can be null.
AssertNotNull(inputObject, Seq("top level Product input object"))
}
val serializer = ScalaReflection.serializerFor[T](nullSafeInput)
val deserializer = ScalaReflection.deserializerFor[T]
val schema = serializer.dataType
val serializer = ScalaReflection.serializerForType(tpe)
val deserializer = ScalaReflection.deserializerForType(tpe)
new ExpressionEncoder[T](
schema,
flat,
serializer.flatten,
serializer,
deserializer,
ClassTag[T](cls))
}
@ -86,14 +73,12 @@ object ExpressionEncoder {
val schema = JavaTypeInference.inferDataType(beanClass)._1
assert(schema.isInstanceOf[StructType])
val serializer = JavaTypeInference.serializerFor(beanClass)
val deserializer = JavaTypeInference.deserializerFor(beanClass)
val objSerializer = JavaTypeInference.serializerFor(beanClass)
val objDeserializer = JavaTypeInference.deserializerFor(beanClass)
new ExpressionEncoder[T](
schema.asInstanceOf[StructType],
flat = false,
serializer.flatten,
deserializer,
objSerializer,
objDeserializer,
ClassTag[T](beanClass))
}
@ -103,75 +88,59 @@ object ExpressionEncoder {
* name/positional binding is preserved.
*/
def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
// TODO: check if encoders length is more than 22 and throw exception for it.
encoders.foreach(_.assertUnresolved())
val schema = StructType(encoders.zipWithIndex.map {
case (e, i) =>
val (dataType, nullable) = if (e.flat) {
e.schema.head.dataType -> e.schema.head.nullable
} else {
e.schema -> true
}
StructField(s"_${i + 1}", dataType, nullable)
StructField(s"_${i + 1}", e.objSerializer.dataType, e.objSerializer.nullable)
})
val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
val serializer = encoders.zipWithIndex.map { case (enc, index) =>
val originalInputObject = enc.serializer.head.collect { case b: BoundReference => b }.head
val serializers = encoders.zipWithIndex.map { case (enc, index) =>
val boundRefs = enc.objSerializer.collect { case b: BoundReference => b }.distinct
assert(boundRefs.size == 1, "object serializer should have only one bound reference but " +
s"there are ${boundRefs.size}")
val originalInputObject = boundRefs.head
val newInputObject = Invoke(
BoundReference(0, ObjectType(cls), nullable = true),
s"_${index + 1}",
originalInputObject.dataType)
originalInputObject.dataType,
returnNullable = originalInputObject.nullable)
val newSerializer = enc.serializer.map(_.transformUp {
case b: BoundReference if b == originalInputObject => newInputObject
})
val serializerExpr = if (enc.flat) {
newSerializer.head
} else {
// For non-flat encoder, the input object is not top level anymore after being combined to
// a tuple encoder, thus it can be null and we should wrap the `CreateStruct` with `If` and
// null check to handle null case correctly.
// e.g. for Encoder[(Int, String)], the serializer expressions will create 2 columns, and is
// not able to handle the case when the input tuple is null. This is not a problem as there
// is a check to make sure the input object won't be null. However, if this encoder is used
// to create a bigger tuple encoder, the original input object becomes a filed of the new
// input tuple and can be null. So instead of creating a struct directly here, we should add
// a null/None check and return a null struct if the null/None check fails.
val struct = CreateStruct(newSerializer)
val nullCheck = Or(
IsNull(newInputObject),
Invoke(Literal.fromObject(None), "equals", BooleanType, newInputObject :: Nil))
If(nullCheck, Literal.create(null, struct.dataType), struct)
val newSerializer = enc.objSerializer.transformUp {
case b: BoundReference => newInputObject
}
Alias(serializerExpr, s"_${index + 1}")()
Alias(newSerializer, s"_${index + 1}")()
}
val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) =>
if (enc.flat) {
enc.deserializer.transform {
case g: GetColumnByOrdinal => g.copy(ordinal = index)
}
val getColumnsByOrdinals = enc.objDeserializer.collect { case c: GetColumnByOrdinal => c }
.distinct
assert(getColumnsByOrdinals.size == 1, "object deserializer should have only one " +
s"`GetColumnByOrdinal`, but there are ${getColumnsByOrdinals.size}")
val input = GetStructField(GetColumnByOrdinal(0, schema), index)
val newDeserializer = enc.objDeserializer.transformUp {
case GetColumnByOrdinal(0, _) => input
}
if (schema(index).nullable) {
If(IsNull(input), Literal.create(null, newDeserializer.dataType), newDeserializer)
} else {
val input = GetColumnByOrdinal(index, enc.schema)
val deserialized = enc.deserializer.transformUp {
case UnresolvedAttribute(nameParts) =>
assert(nameParts.length == 1)
UnresolvedExtractValue(input, Literal(nameParts.head))
case GetColumnByOrdinal(ordinal, _) => GetStructField(input, ordinal)
}
If(IsNull(input), Literal.create(null, deserialized.dataType), deserialized)
newDeserializer
}
}
val serializer = If(IsNull(BoundReference(0, ObjectType(cls), nullable = true)),
Literal.create(null, schema), CreateStruct(serializers))
val deserializer =
NewInstance(cls, childrenDeserializers, ObjectType(cls), propagateNull = false)
new ExpressionEncoder[Any](
schema,
flat = false,
serializer,
deserializer,
ClassTag(cls))
@ -212,21 +181,91 @@ object ExpressionEncoder {
* A generic encoder for JVM objects that uses Catalyst Expressions for a `serializer`
* and a `deserializer`.
*
* @param schema The schema after converting `T` to a Spark SQL row.
* @param serializer A set of expressions, one for each top-level field that can be used to
* extract the values from a raw object into an [[InternalRow]].
* @param deserializer An expression that will construct an object given an [[InternalRow]].
* @param objSerializer An expression that can be used to encode a raw object to corresponding
* Spark SQL representation that can be a primitive column, array, map or a
* struct. This represents how Spark SQL generally serializes an object of
* type `T`.
* @param objDeserializer An expression that will construct an object given a Spark SQL
* representation. This represents how Spark SQL generally deserializes
* a serialized value in Spark SQL representation back to an object of
* type `T`.
* @param clsTag A classtag for `T`.
*/
case class ExpressionEncoder[T](
schema: StructType,
flat: Boolean,
serializer: Seq[Expression],
deserializer: Expression,
objSerializer: Expression,
objDeserializer: Expression,
clsTag: ClassTag[T])
extends Encoder[T] {
if (flat) require(serializer.size == 1)
/**
* A sequence of expressions, one for each top-level field that can be used to
* extract the values from a raw object into an [[InternalRow]]:
* 1. If `serializer` encodes a raw object to a struct, strip the outer If-IsNull and get
* the `CreateNamedStruct`.
* 2. For other cases, wrap the single serializer with `CreateNamedStruct`.
*/
val serializer: Seq[NamedExpression] = {
val clsName = Utils.getSimpleName(clsTag.runtimeClass)
if (isSerializedAsStruct) {
val nullSafeSerializer = objSerializer.transformUp {
case r: BoundReference =>
// For input object of Product type, we can't encode it to row if it's null, as Spark SQL
// doesn't allow top-level row to be null, only its columns can be null.
AssertNotNull(r, Seq("top level Product or row object"))
}
nullSafeSerializer match {
case If(_: IsNull, _, s: CreateNamedStruct) => s
case s: CreateNamedStruct => s
case _ =>
throw new RuntimeException(s"class $clsName has unexpected serializer: $objSerializer")
}
} else {
// For other input objects like primitive, array, map, etc., we construct a struct to wrap
// the serializer, which is a column of a row.
CreateNamedStruct(Literal("value") :: objSerializer :: Nil)
}
}.flatten
/**
* Returns an expression that can be used to deserialize an input row to an object of type `T`
* with a compatible schema. Fields of the row will be extracted using `UnresolvedAttribute`
* of the same name as the constructor arguments.
*
* For complex objects that are encoded to structs, fields of the struct will be extracted using
* `GetColumnByOrdinal` with corresponding ordinal.
*/
val deserializer: Expression = {
if (isSerializedAsStruct) {
// We serialize this kind of object to a root-level row. The input of the general deserializer
// is a `GetColumnByOrdinal(0)` expression that extracts the first column of a row. We need to
// transform the attribute accessors.
objDeserializer.transform {
case UnresolvedExtractValue(GetColumnByOrdinal(0, _),
Literal(part: UTF8String, StringType)) =>
UnresolvedAttribute.quoted(part.toString)
case GetStructField(GetColumnByOrdinal(0, dt), ordinal, _) =>
GetColumnByOrdinal(ordinal, dt)
case If(IsNull(GetColumnByOrdinal(0, _)), _, n: NewInstance) => n
case If(IsNull(GetColumnByOrdinal(0, _)), _, i: InitializeJavaBean) => i
}
} else {
// For other input objects like primitive, array, map, etc., we deserialize the first column
// of a row to the object.
objDeserializer
}
}
// The schema after converting `T` to a Spark SQL row. This schema is dependent on the given
// serializer.
val schema: StructType = StructType(serializer.map { s =>
StructField(s.name, s.dataType, s.nullable)
})
/**
* Returns true if the type `T` is serialized as a struct.
*/
def isSerializedAsStruct: Boolean = objSerializer.dataType.isInstanceOf[StructType]
// serializer expressions are used to encode an object to a row, while the object is usually an
// intermediate value produced inside an operator, not from the output of the child operator. This
@ -258,7 +297,7 @@ case class ExpressionEncoder[T](
analyzer.checkAnalysis(analyzedPlan)
val resolved = SimplifyCasts(analyzedPlan).asInstanceOf[DeserializeToObject].deserializer
val bound = BindReferences.bindReference(resolved, attrs)
copy(deserializer = bound)
copy(objDeserializer = bound)
}
@transient


@ -58,12 +58,10 @@ object RowEncoder {
def apply(schema: StructType): ExpressionEncoder[Row] = {
val cls = classOf[Row]
val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
val serializer = serializerFor(AssertNotNull(inputObject, Seq("top level row object")), schema)
val deserializer = deserializerFor(schema)
val serializer = serializerFor(inputObject, schema)
val deserializer = deserializerFor(GetColumnByOrdinal(0, serializer.dataType), schema)
new ExpressionEncoder[Row](
schema,
flat = false,
serializer.asInstanceOf[CreateNamedStruct].flatten,
serializer,
deserializer,
ClassTag(cls))
}
@ -237,13 +235,9 @@ object RowEncoder {
case udt: UserDefinedType[_] => ObjectType(udt.userClass)
}
private def deserializerFor(schema: StructType): Expression = {
private def deserializerFor(input: Expression, schema: StructType): Expression = {
val fields = schema.zipWithIndex.map { case (f, i) =>
val dt = f.dataType match {
case p: PythonUserDefinedType => p.sqlType
case other => other
}
deserializerFor(GetColumnByOrdinal(i, dt))
deserializerFor(GetStructField(input, i))
}
CreateExternalRow(fields, schema)
}


@ -19,12 +19,13 @@ package org.apache.spark.sql.catalyst
import java.sql.{Date, Timestamp}
import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, Literal, SpecificInternalRow, UpCast}
import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, If, SpecificInternalRow, UpCast}
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
case class PrimitiveData(
intField: Int,
@ -112,6 +113,14 @@ object TestingUDT {
class ScalaReflectionSuite extends SparkFunSuite {
import org.apache.spark.sql.catalyst.ScalaReflection._
// A helper method used to test `ScalaReflection.serializerForType`.
private def serializerFor[T: TypeTag]: Expression =
serializerForType(ScalaReflection.localTypeOf[T])
// A helper method used to test `ScalaReflection.deserializerForType`.
private def deserializerFor[T: TypeTag]: Expression =
deserializerForType(ScalaReflection.localTypeOf[T])
test("SQLUserDefinedType annotation on Scala structure") {
val schema = schemaFor[TestingUDT.NestedStruct]
assert(schema === Schema(
@ -263,13 +272,9 @@ class ScalaReflectionSuite extends SparkFunSuite {
test("SPARK-15062: Get correct serializer for List[_]") {
val list = List(1, 2, 3)
val serializer = serializerFor[List[Int]](BoundReference(
0, ObjectType(list.getClass), nullable = false))
assert(serializer.children.size == 2)
assert(serializer.children.head.isInstanceOf[Literal])
assert(serializer.children.head.asInstanceOf[Literal].value === UTF8String.fromString("value"))
assert(serializer.children.last.isInstanceOf[NewInstance])
assert(serializer.children.last.asInstanceOf[NewInstance]
val serializer = serializerFor[List[Int]]
assert(serializer.isInstanceOf[NewInstance])
assert(serializer.asInstanceOf[NewInstance]
.cls.isAssignableFrom(classOf[org.apache.spark.sql.catalyst.util.GenericArrayData]))
}
@ -280,59 +285,58 @@ class ScalaReflectionSuite extends SparkFunSuite {
test("serialize and deserialize arbitrary sequence types") {
import scala.collection.immutable.Queue
val queueSerializer = serializerFor[Queue[Int]](BoundReference(
0, ObjectType(classOf[Queue[Int]]), nullable = false))
assert(queueSerializer.dataType.head.dataType ==
val queueSerializer = serializerFor[Queue[Int]]
assert(queueSerializer.dataType ==
ArrayType(IntegerType, containsNull = false))
val queueDeserializer = deserializerFor[Queue[Int]]
assert(queueDeserializer.dataType == ObjectType(classOf[Queue[_]]))
import scala.collection.mutable.ArrayBuffer
val arrayBufferSerializer = serializerFor[ArrayBuffer[Int]](BoundReference(
0, ObjectType(classOf[ArrayBuffer[Int]]), nullable = false))
assert(arrayBufferSerializer.dataType.head.dataType ==
val arrayBufferSerializer = serializerFor[ArrayBuffer[Int]]
assert(arrayBufferSerializer.dataType ==
ArrayType(IntegerType, containsNull = false))
val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]]
assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]]))
}
test("serialize and deserialize arbitrary map types") {
val mapSerializer = serializerFor[Map[Int, Int]](BoundReference(
0, ObjectType(classOf[Map[Int, Int]]), nullable = false))
assert(mapSerializer.dataType.head.dataType ==
val mapSerializer = serializerFor[Map[Int, Int]]
assert(mapSerializer.dataType ==
MapType(IntegerType, IntegerType, valueContainsNull = false))
val mapDeserializer = deserializerFor[Map[Int, Int]]
assert(mapDeserializer.dataType == ObjectType(classOf[Map[_, _]]))
import scala.collection.immutable.HashMap
val hashMapSerializer = serializerFor[HashMap[Int, Int]](BoundReference(
0, ObjectType(classOf[HashMap[Int, Int]]), nullable = false))
assert(hashMapSerializer.dataType.head.dataType ==
val hashMapSerializer = serializerFor[HashMap[Int, Int]]
assert(hashMapSerializer.dataType ==
MapType(IntegerType, IntegerType, valueContainsNull = false))
val hashMapDeserializer = deserializerFor[HashMap[Int, Int]]
assert(hashMapDeserializer.dataType == ObjectType(classOf[HashMap[_, _]]))
import scala.collection.mutable.{LinkedHashMap => LHMap}
val linkedHashMapSerializer = serializerFor[LHMap[Long, String]](BoundReference(
0, ObjectType(classOf[LHMap[Long, String]]), nullable = false))
assert(linkedHashMapSerializer.dataType.head.dataType ==
val linkedHashMapSerializer = serializerFor[LHMap[Long, String]]
assert(linkedHashMapSerializer.dataType ==
MapType(LongType, StringType, valueContainsNull = true))
val linkedHashMapDeserializer = deserializerFor[LHMap[Long, String]]
assert(linkedHashMapDeserializer.dataType == ObjectType(classOf[LHMap[_, _]]))
}
test("SPARK-22442: Generate correct field names for special characters") {
val serializer = serializerFor[SpecialCharAsFieldData](BoundReference(
0, ObjectType(classOf[SpecialCharAsFieldData]), nullable = false))
val serializer = serializerFor[SpecialCharAsFieldData]
.collect {
case If(_, _, s: CreateNamedStruct) => s
}.head
val deserializer = deserializerFor[SpecialCharAsFieldData]
assert(serializer.dataType(0).name == "field.1")
assert(serializer.dataType(1).name == "field 2")
val argumentsFields = deserializer.asInstanceOf[NewInstance].arguments.flatMap { _.collect {
case UpCast(u: UnresolvedAttribute, _, _) => u.nameParts
val newInstance = deserializer.collect { case n: NewInstance => n }.head
val argumentsFields = newInstance.arguments.flatMap { _.collect {
case UpCast(u: UnresolvedExtractValue, _, _) => u.extraction.toString
}}
assert(argumentsFields(0) == Seq("field.1"))
assert(argumentsFields(1) == Seq("field 2"))
assert(argumentsFields(0) == "field.1")
assert(argumentsFields(1) == "field 2")
}
test("SPARK-22472: add null check for top-level primitive values") {
@ -351,8 +355,8 @@ class ScalaReflectionSuite extends SparkFunSuite {
test("SPARK-23835: add null check to non-nullable types in Tuples") {
def numberOfCheckedArguments(deserializer: Expression): Int = {
assert(deserializer.isInstanceOf[NewInstance])
deserializer.asInstanceOf[NewInstance].arguments.count(_.isInstanceOf[AssertNotNull])
val newInstance = deserializer.collect { case n: NewInstance => n}.head
newInstance.arguments.count(_.isInstanceOf[AssertNotNull])
}
assert(numberOfCheckedArguments(deserializerFor[(Double, Double)]) == 2)
assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)]) == 1)


@ -28,9 +28,9 @@ import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData}
import org.apache.spark.sql.catalyst.analysis.AnalysisTest
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.CodegenInterpretedPlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@ -348,7 +348,7 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes
test("nullable of encoder serializer") {
def checkNullable[T: Encoder](nullable: Boolean): Unit = {
assert(encoderFor[T].serializer.forall(_.nullable === nullable))
assert(encoderFor[T].objSerializer.nullable === nullable)
}
// test for flat encoders


@ -239,7 +239,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
val encoder = RowEncoder(schema)
val e = intercept[RuntimeException](encoder.toRow(null))
assert(e.getMessage.contains("Null value appeared in non-nullable field"))
assert(e.getMessage.contains("top level row object"))
assert(e.getMessage.contains("top level Product or row object"))
}
test("RowEncoder should validate external type") {


@ -1087,7 +1087,7 @@ class Dataset[T] private[sql](
// Note that we do this before joining them, to enable the join operator to return null for one
// side, in cases like outer-join.
val left = {
val combined = if (this.exprEnc.flat) {
val combined = if (!this.exprEnc.isSerializedAsStruct) {
assert(joined.left.output.length == 1)
Alias(joined.left.output.head, "_1")()
} else {
@ -1097,7 +1097,7 @@ class Dataset[T] private[sql](
}
val right = {
val combined = if (other.exprEnc.flat) {
val combined = if (!other.exprEnc.isSerializedAsStruct) {
assert(joined.right.output.length == 1)
Alias(joined.right.output.head, "_2")()
} else {
@ -1110,14 +1110,14 @@ class Dataset[T] private[sql](
// combine the outputs of each join side.
val conditionExpr = joined.condition.get transformUp {
case a: Attribute if joined.left.outputSet.contains(a) =>
if (this.exprEnc.flat) {
if (!this.exprEnc.isSerializedAsStruct) {
left.output.head
} else {
val index = joined.left.output.indexWhere(_.exprId == a.exprId)
GetStructField(left.output.head, index)
}
case a: Attribute if joined.right.outputSet.contains(a) =>
if (other.exprEnc.flat) {
if (!other.exprEnc.isSerializedAsStruct) {
right.output.head
} else {
val index = joined.right.output.indexWhere(_.exprId == a.exprId)
@ -1390,7 +1390,7 @@ class Dataset[T] private[sql](
implicit val encoder = c1.encoder
val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, logicalPlan)
if (encoder.flat) {
if (!encoder.isSerializedAsStruct) {
new Dataset[U1](sparkSession, project, encoder)
} else {
// Flattens inner fields of U1


@ -457,7 +457,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
val encoders = columns.map(_.encoder)
val namedColumns =
columns.map(_.withInputType(vExprEnc, dataAttributes).named)
val keyColumn = if (kExprEnc.flat) {
val keyColumn = if (!kExprEnc.isSerializedAsStruct) {
assert(groupingAttributes.length == 1)
groupingAttributes.head
} else {


@ -38,18 +38,14 @@ object TypedAggregateExpression {
val bufferSerializer = bufferEncoder.namedExpressions
val outputEncoder = encoderFor[OUT]
val outputType = if (outputEncoder.flat) {
outputEncoder.schema.head.dataType
} else {
outputEncoder.schema
}
val outputType = outputEncoder.objSerializer.dataType
// Checks if the buffer object is simple, i.e. the buffer encoder is flat and the serializer
// expression is an alias of `BoundReference`, which means the buffer object doesn't need
// serialization.
val isSimpleBuffer = {
bufferSerializer.head match {
case Alias(_: BoundReference, _) if bufferEncoder.flat => true
case Alias(_: BoundReference, _) if !bufferEncoder.isSerializedAsStruct => true
case _ => false
}
}
@ -71,7 +67,7 @@ object TypedAggregateExpression {
outputEncoder.serializer,
outputEncoder.deserializer.dataType,
outputType,
!outputEncoder.flat || outputEncoder.schema.head.nullable)
outputEncoder.objSerializer.nullable)
} else {
ComplexTypedAggregateExpression(
aggregator.asInstanceOf[Aggregator[Any, Any, Any]],
@ -82,7 +78,7 @@ object TypedAggregateExpression {
bufferEncoder.resolveAndBind().deserializer,
outputEncoder.serializer,
outputType,
!outputEncoder.flat || outputEncoder.schema.head.nullable)
outputEncoder.objSerializer.nullable)
}
}
}


@ -1065,7 +1065,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("Dataset should throw RuntimeException if top-level product input object is null") {
val e = intercept[RuntimeException](Seq(ClassData("a", 1), null).toDS())
assert(e.getMessage.contains("Null value appeared in non-nullable field"))
assert(e.getMessage.contains("top level Product input object"))
assert(e.getMessage.contains("top level Product or row object"))
}
test("dropDuplicates") {