[SPARK-25746][SQL] Refactoring ExpressionEncoder to get rid of flat flag
## What changes were proposed in this pull request? This is inspired during implementing #21732. For now `ScalaReflection` needs to consider how `ExpressionEncoder` uses generated serializers and deserializers. And `ExpressionEncoder` has a weird `flat` flag. After discussion with cloud-fan, it seems to be better to refactor `ExpressionEncoder`. It should make SPARK-24762 easier to do. To summarize the proposed changes: 1. `serializerFor` and `deserializerFor` return expressions for serializing/deserializing an input expression for a given type. They are private and should not be called directly. 2. `serializerForType` and `deserializerForType` returns an expression for serializing/deserializing for an object of type T to/from Spark SQL representation. It assumes the input object/Spark SQL representation is located at ordinal 0 of a row. So in other words, `serializerForType` and `deserializerForType` return expressions for atomically serializing/deserializing JVM object to/from Spark SQL value. A serializer returned by `serializerForType` will serialize an object at `row(0)` to a corresponding Spark SQL representation, e.g. primitive type, array, map, struct. A deserializer returned by `deserializerForType` will deserialize an input field at `row(0)` to an object with given type. 3. The construction of `ExpressionEncoder` takes a pair of serializer and deserializer for type `T`. It uses them to create serializer and deserializer for T <-> row serialization. Now `ExpressionEncoder` dones't need to remember if serializer is flat or not. When we need to construct new `ExpressionEncoder` based on existing ones, we only need to change input location in the atomic serializer and deserializer. ## How was this patch tested? Existing tests. Closes #22749 from viirya/SPARK-24762-refactor. Authored-by: Liang-Chi Hsieh <viirya@gmail.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
ddd1b1e8ae
commit
cb5ea201df
|
@ -203,12 +203,10 @@ object Encoders {
|
|||
validatePublicClass[T]()
|
||||
|
||||
ExpressionEncoder[T](
|
||||
schema = new StructType().add("value", BinaryType),
|
||||
flat = true,
|
||||
serializer = Seq(
|
||||
objSerializer =
|
||||
EncodeUsingSerializer(
|
||||
BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)),
|
||||
deserializer =
|
||||
BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo),
|
||||
objDeserializer =
|
||||
DecodeUsingSerializer[T](
|
||||
Cast(GetColumnByOrdinal(0, BinaryType), BinaryType),
|
||||
classTag[T],
|
||||
|
|
|
@ -187,26 +187,23 @@ object JavaTypeInference {
|
|||
}
|
||||
|
||||
/**
|
||||
* Returns an expression that can be used to deserialize an internal row to an object of java bean
|
||||
* `T` with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes
|
||||
* of the same name as the constructor arguments. Nested classes will have their fields accessed
|
||||
* using UnresolvedExtractValue.
|
||||
* Returns an expression that can be used to deserialize a Spark SQL representation to an object
|
||||
* of java bean `T` with a compatible schema. The Spark SQL representation is located at ordinal
|
||||
* 0 of a row, i.e., `GetColumnByOrdinal(0, _)`. Nested classes will have their fields accessed
|
||||
* using `UnresolvedExtractValue`.
|
||||
*/
|
||||
def deserializerFor(beanClass: Class[_]): Expression = {
|
||||
deserializerFor(TypeToken.of(beanClass), None)
|
||||
val typeToken = TypeToken.of(beanClass)
|
||||
deserializerFor(typeToken, GetColumnByOrdinal(0, inferDataType(typeToken)._1))
|
||||
}
|
||||
|
||||
private def deserializerFor(typeToken: TypeToken[_], path: Option[Expression]): Expression = {
|
||||
private def deserializerFor(typeToken: TypeToken[_], path: Expression): Expression = {
|
||||
/** Returns the current path with a sub-field extracted. */
|
||||
def addToPath(part: String): Expression = path
|
||||
.map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
|
||||
.getOrElse(UnresolvedAttribute(part))
|
||||
|
||||
/** Returns the current path or `GetColumnByOrdinal`. */
|
||||
def getPath: Expression = path.getOrElse(GetColumnByOrdinal(0, inferDataType(typeToken)._1))
|
||||
def addToPath(part: String): Expression = UnresolvedExtractValue(path,
|
||||
expressions.Literal(part))
|
||||
|
||||
typeToken.getRawType match {
|
||||
case c if !inferExternalType(c).isInstanceOf[ObjectType] => getPath
|
||||
case c if !inferExternalType(c).isInstanceOf[ObjectType] => path
|
||||
|
||||
case c if c == classOf[java.lang.Short] ||
|
||||
c == classOf[java.lang.Integer] ||
|
||||
|
@ -219,7 +216,7 @@ object JavaTypeInference {
|
|||
c,
|
||||
ObjectType(c),
|
||||
"valueOf",
|
||||
getPath :: Nil,
|
||||
path :: Nil,
|
||||
returnNullable = false)
|
||||
|
||||
case c if c == classOf[java.sql.Date] =>
|
||||
|
@ -227,7 +224,7 @@ object JavaTypeInference {
|
|||
DateTimeUtils.getClass,
|
||||
ObjectType(c),
|
||||
"toJavaDate",
|
||||
getPath :: Nil,
|
||||
path :: Nil,
|
||||
returnNullable = false)
|
||||
|
||||
case c if c == classOf[java.sql.Timestamp] =>
|
||||
|
@ -235,14 +232,14 @@ object JavaTypeInference {
|
|||
DateTimeUtils.getClass,
|
||||
ObjectType(c),
|
||||
"toJavaTimestamp",
|
||||
getPath :: Nil,
|
||||
path :: Nil,
|
||||
returnNullable = false)
|
||||
|
||||
case c if c == classOf[java.lang.String] =>
|
||||
Invoke(getPath, "toString", ObjectType(classOf[String]))
|
||||
Invoke(path, "toString", ObjectType(classOf[String]))
|
||||
|
||||
case c if c == classOf[java.math.BigDecimal] =>
|
||||
Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]))
|
||||
Invoke(path, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]))
|
||||
|
||||
case c if c.isArray =>
|
||||
val elementType = c.getComponentType
|
||||
|
@ -258,12 +255,12 @@ object JavaTypeInference {
|
|||
}
|
||||
|
||||
primitiveMethod.map { method =>
|
||||
Invoke(getPath, method, ObjectType(c))
|
||||
Invoke(path, method, ObjectType(c))
|
||||
}.getOrElse {
|
||||
Invoke(
|
||||
MapObjects(
|
||||
p => deserializerFor(typeToken.getComponentType, Some(p)),
|
||||
getPath,
|
||||
p => deserializerFor(typeToken.getComponentType, p),
|
||||
path,
|
||||
inferDataType(elementType)._1),
|
||||
"array",
|
||||
ObjectType(c))
|
||||
|
@ -272,8 +269,8 @@ object JavaTypeInference {
|
|||
case c if listType.isAssignableFrom(typeToken) =>
|
||||
val et = elementType(typeToken)
|
||||
UnresolvedMapObjects(
|
||||
p => deserializerFor(et, Some(p)),
|
||||
getPath,
|
||||
p => deserializerFor(et, p),
|
||||
path,
|
||||
customCollectionCls = Some(c))
|
||||
|
||||
case _ if mapType.isAssignableFrom(typeToken) =>
|
||||
|
@ -282,16 +279,16 @@ object JavaTypeInference {
|
|||
val keyData =
|
||||
Invoke(
|
||||
UnresolvedMapObjects(
|
||||
p => deserializerFor(keyType, Some(p)),
|
||||
GetKeyArrayFromMap(getPath)),
|
||||
p => deserializerFor(keyType, p),
|
||||
GetKeyArrayFromMap(path)),
|
||||
"array",
|
||||
ObjectType(classOf[Array[Any]]))
|
||||
|
||||
val valueData =
|
||||
Invoke(
|
||||
UnresolvedMapObjects(
|
||||
p => deserializerFor(valueType, Some(p)),
|
||||
GetValueArrayFromMap(getPath)),
|
||||
p => deserializerFor(valueType, p),
|
||||
GetValueArrayFromMap(path)),
|
||||
"array",
|
||||
ObjectType(classOf[Array[Any]]))
|
||||
|
||||
|
@ -307,7 +304,7 @@ object JavaTypeInference {
|
|||
other,
|
||||
ObjectType(other),
|
||||
"valueOf",
|
||||
Invoke(getPath, "toString", ObjectType(classOf[String]), returnNullable = false) :: Nil,
|
||||
Invoke(path, "toString", ObjectType(classOf[String]), returnNullable = false) :: Nil,
|
||||
returnNullable = false)
|
||||
|
||||
case other =>
|
||||
|
@ -316,7 +313,7 @@ object JavaTypeInference {
|
|||
val fieldName = p.getName
|
||||
val fieldType = typeToken.method(p.getReadMethod).getReturnType
|
||||
val (_, nullable) = inferDataType(fieldType)
|
||||
val constructor = deserializerFor(fieldType, Some(addToPath(fieldName)))
|
||||
val constructor = deserializerFor(fieldType, addToPath(fieldName))
|
||||
val setter = if (nullable) {
|
||||
constructor
|
||||
} else {
|
||||
|
@ -328,28 +325,23 @@ object JavaTypeInference {
|
|||
val newInstance = NewInstance(other, Nil, ObjectType(other), propagateNull = false)
|
||||
val result = InitializeJavaBean(newInstance, setters)
|
||||
|
||||
if (path.nonEmpty) {
|
||||
expressions.If(
|
||||
IsNull(getPath),
|
||||
expressions.Literal.create(null, ObjectType(other)),
|
||||
result
|
||||
)
|
||||
} else {
|
||||
expressions.If(
|
||||
IsNull(path),
|
||||
expressions.Literal.create(null, ObjectType(other)),
|
||||
result
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an expression for serializing an object of the given type to an internal row.
|
||||
* Returns an expression for serializing an object of the given type to a Spark SQL
|
||||
* representation. The input object is located at ordinal 0 of a row, i.e.,
|
||||
* `BoundReference(0, _)`.
|
||||
*/
|
||||
def serializerFor(beanClass: Class[_]): CreateNamedStruct = {
|
||||
def serializerFor(beanClass: Class[_]): Expression = {
|
||||
val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true)
|
||||
val nullSafeInput = AssertNotNull(inputObject, Seq("top level input bean"))
|
||||
serializerFor(nullSafeInput, TypeToken.of(beanClass)) match {
|
||||
case expressions.If(_, _, s: CreateNamedStruct) => s
|
||||
case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil)
|
||||
}
|
||||
serializerFor(nullSafeInput, TypeToken.of(beanClass))
|
||||
}
|
||||
|
||||
private def serializerFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = {
|
||||
|
|
|
@ -24,7 +24,7 @@ import scala.util.Properties
|
|||
import org.apache.commons.lang3.reflect.ConstructorUtils
|
||||
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue}
|
||||
import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue}
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.objects._
|
||||
import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, GenericArrayData, MapData}
|
||||
|
@ -129,21 +129,44 @@ object ScalaReflection extends ScalaReflection {
|
|||
}
|
||||
|
||||
/**
|
||||
* Returns an expression that can be used to deserialize an input row to an object of type `T`
|
||||
* with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes
|
||||
* of the same name as the constructor arguments. Nested classes will have their fields accessed
|
||||
* using UnresolvedExtractValue.
|
||||
* When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff
|
||||
* and lost the required data type, which may lead to runtime error if the real type doesn't
|
||||
* match the encoder's schema.
|
||||
* For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type
|
||||
* is [a: int, b: long], then we will hit runtime error and say that we can't construct class
|
||||
* `Data` with int and long, because we lost the information that `b` should be a string.
|
||||
*
|
||||
* When used on a primitive type, the constructor will instead default to extracting the value
|
||||
* from ordinal 0 (since there are no names to map to). The actual location can be moved by
|
||||
* calling resolve/bind with a new schema.
|
||||
* This method help us "remember" the required data type by adding a `UpCast`. Note that we
|
||||
* only need to do this for leaf nodes.
|
||||
*/
|
||||
def deserializerFor[T : TypeTag]: Expression = {
|
||||
val tpe = localTypeOf[T]
|
||||
private def upCastToExpectedType(expr: Expression, expected: DataType,
|
||||
walkedTypePath: Seq[String]): Expression = expected match {
|
||||
case _: StructType => expr
|
||||
case _: ArrayType => expr
|
||||
// TODO: ideally we should also skip MapType, but nested StructType inside MapType is rare and
|
||||
// it's not trivial to support by-name resolution for StructType inside MapType.
|
||||
case _ => UpCast(expr, expected, walkedTypePath)
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an expression that can be used to deserialize a Spark SQL representation to an object
|
||||
* of type `T` with a compatible schema. The Spark SQL representation is located at ordinal 0 of
|
||||
* a row, i.e., `GetColumnByOrdinal(0, _)`. Nested classes will have their fields accessed using
|
||||
* `UnresolvedExtractValue`.
|
||||
*
|
||||
* The returned expression is used by `ExpressionEncoder`. The encoder will resolve and bind this
|
||||
* deserializer expression when using it.
|
||||
*/
|
||||
def deserializerForType(tpe: `Type`): Expression = {
|
||||
val clsName = getClassNameFromType(tpe)
|
||||
val walkedTypePath = s"""- root class: "$clsName"""" :: Nil
|
||||
val expr = deserializerFor(tpe, None, walkedTypePath)
|
||||
val Schema(_, nullable) = schemaFor(tpe)
|
||||
val Schema(dataType, nullable) = schemaFor(tpe)
|
||||
|
||||
// Assumes we are deserializing the first column of a row.
|
||||
val input = upCastToExpectedType(GetColumnByOrdinal(0, dataType), dataType,
|
||||
walkedTypePath)
|
||||
|
||||
val expr = deserializerFor(tpe, input, walkedTypePath)
|
||||
if (nullable) {
|
||||
expr
|
||||
} else {
|
||||
|
@ -151,16 +174,22 @@ object ScalaReflection extends ScalaReflection {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an expression that can be used to deserialize an input expression to an object of type
|
||||
* `T` with a compatible schema.
|
||||
*
|
||||
* @param tpe The `Type` of deserialized object.
|
||||
* @param path The expression which can be used to extract serialized value.
|
||||
* @param walkedTypePath The paths from top to bottom to access current field when deserializing.
|
||||
*/
|
||||
private def deserializerFor(
|
||||
tpe: `Type`,
|
||||
path: Option[Expression],
|
||||
path: Expression,
|
||||
walkedTypePath: Seq[String]): Expression = cleanUpReflectionObjects {
|
||||
|
||||
/** Returns the current path with a sub-field extracted. */
|
||||
def addToPath(part: String, dataType: DataType, walkedTypePath: Seq[String]): Expression = {
|
||||
val newPath = path
|
||||
.map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
|
||||
.getOrElse(UnresolvedAttribute.quoted(part))
|
||||
val newPath = UnresolvedExtractValue(path, expressions.Literal(part))
|
||||
upCastToExpectedType(newPath, dataType, walkedTypePath)
|
||||
}
|
||||
|
||||
|
@ -169,46 +198,12 @@ object ScalaReflection extends ScalaReflection {
|
|||
ordinal: Int,
|
||||
dataType: DataType,
|
||||
walkedTypePath: Seq[String]): Expression = {
|
||||
val newPath = path
|
||||
.map(p => GetStructField(p, ordinal))
|
||||
.getOrElse(GetColumnByOrdinal(ordinal, dataType))
|
||||
val newPath = GetStructField(path, ordinal)
|
||||
upCastToExpectedType(newPath, dataType, walkedTypePath)
|
||||
}
|
||||
|
||||
/** Returns the current path or `GetColumnByOrdinal`. */
|
||||
def getPath: Expression = {
|
||||
val dataType = schemaFor(tpe).dataType
|
||||
if (path.isDefined) {
|
||||
path.get
|
||||
} else {
|
||||
upCastToExpectedType(GetColumnByOrdinal(0, dataType), dataType, walkedTypePath)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff
|
||||
* and lost the required data type, which may lead to runtime error if the real type doesn't
|
||||
* match the encoder's schema.
|
||||
* For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type
|
||||
* is [a: int, b: long], then we will hit runtime error and say that we can't construct class
|
||||
* `Data` with int and long, because we lost the information that `b` should be a string.
|
||||
*
|
||||
* This method help us "remember" the required data type by adding a `UpCast`. Note that we
|
||||
* only need to do this for leaf nodes.
|
||||
*/
|
||||
def upCastToExpectedType(
|
||||
expr: Expression,
|
||||
expected: DataType,
|
||||
walkedTypePath: Seq[String]): Expression = expected match {
|
||||
case _: StructType => expr
|
||||
case _: ArrayType => expr
|
||||
// TODO: ideally we should also skip MapType, but nested StructType inside MapType is rare and
|
||||
// it's not trivial to support by-name resolution for StructType inside MapType.
|
||||
case _ => UpCast(expr, expected, walkedTypePath)
|
||||
}
|
||||
|
||||
tpe.dealias match {
|
||||
case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath
|
||||
case t if !dataTypeFor(t).isInstanceOf[ObjectType] => path
|
||||
|
||||
case t if t <:< localTypeOf[Option[_]] =>
|
||||
val TypeRef(_, _, Seq(optType)) = t
|
||||
|
@ -219,44 +214,44 @@ object ScalaReflection extends ScalaReflection {
|
|||
case t if t <:< localTypeOf[java.lang.Integer] =>
|
||||
val boxedType = classOf[java.lang.Integer]
|
||||
val objectType = ObjectType(boxedType)
|
||||
StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)
|
||||
StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false)
|
||||
|
||||
case t if t <:< localTypeOf[java.lang.Long] =>
|
||||
val boxedType = classOf[java.lang.Long]
|
||||
val objectType = ObjectType(boxedType)
|
||||
StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)
|
||||
StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false)
|
||||
|
||||
case t if t <:< localTypeOf[java.lang.Double] =>
|
||||
val boxedType = classOf[java.lang.Double]
|
||||
val objectType = ObjectType(boxedType)
|
||||
StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)
|
||||
StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false)
|
||||
|
||||
case t if t <:< localTypeOf[java.lang.Float] =>
|
||||
val boxedType = classOf[java.lang.Float]
|
||||
val objectType = ObjectType(boxedType)
|
||||
StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)
|
||||
StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false)
|
||||
|
||||
case t if t <:< localTypeOf[java.lang.Short] =>
|
||||
val boxedType = classOf[java.lang.Short]
|
||||
val objectType = ObjectType(boxedType)
|
||||
StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)
|
||||
StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false)
|
||||
|
||||
case t if t <:< localTypeOf[java.lang.Byte] =>
|
||||
val boxedType = classOf[java.lang.Byte]
|
||||
val objectType = ObjectType(boxedType)
|
||||
StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)
|
||||
StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false)
|
||||
|
||||
case t if t <:< localTypeOf[java.lang.Boolean] =>
|
||||
val boxedType = classOf[java.lang.Boolean]
|
||||
val objectType = ObjectType(boxedType)
|
||||
StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)
|
||||
StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false)
|
||||
|
||||
case t if t <:< localTypeOf[java.sql.Date] =>
|
||||
StaticInvoke(
|
||||
DateTimeUtils.getClass,
|
||||
ObjectType(classOf[java.sql.Date]),
|
||||
"toJavaDate",
|
||||
getPath :: Nil,
|
||||
path :: Nil,
|
||||
returnNullable = false)
|
||||
|
||||
case t if t <:< localTypeOf[java.sql.Timestamp] =>
|
||||
|
@ -264,25 +259,25 @@ object ScalaReflection extends ScalaReflection {
|
|||
DateTimeUtils.getClass,
|
||||
ObjectType(classOf[java.sql.Timestamp]),
|
||||
"toJavaTimestamp",
|
||||
getPath :: Nil,
|
||||
path :: Nil,
|
||||
returnNullable = false)
|
||||
|
||||
case t if t <:< localTypeOf[java.lang.String] =>
|
||||
Invoke(getPath, "toString", ObjectType(classOf[String]), returnNullable = false)
|
||||
Invoke(path, "toString", ObjectType(classOf[String]), returnNullable = false)
|
||||
|
||||
case t if t <:< localTypeOf[java.math.BigDecimal] =>
|
||||
Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]),
|
||||
Invoke(path, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]),
|
||||
returnNullable = false)
|
||||
|
||||
case t if t <:< localTypeOf[BigDecimal] =>
|
||||
Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal]), returnNullable = false)
|
||||
Invoke(path, "toBigDecimal", ObjectType(classOf[BigDecimal]), returnNullable = false)
|
||||
|
||||
case t if t <:< localTypeOf[java.math.BigInteger] =>
|
||||
Invoke(getPath, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]),
|
||||
Invoke(path, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]),
|
||||
returnNullable = false)
|
||||
|
||||
case t if t <:< localTypeOf[scala.math.BigInt] =>
|
||||
Invoke(getPath, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]),
|
||||
Invoke(path, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]),
|
||||
returnNullable = false)
|
||||
|
||||
case t if t <:< localTypeOf[Array[_]] =>
|
||||
|
@ -294,7 +289,7 @@ object ScalaReflection extends ScalaReflection {
|
|||
val mapFunction: Expression => Expression = element => {
|
||||
// upcast the array element to the data type the encoder expected.
|
||||
val casted = upCastToExpectedType(element, dataType, newTypePath)
|
||||
val converter = deserializerFor(elementType, Some(casted), newTypePath)
|
||||
val converter = deserializerFor(elementType, casted, newTypePath)
|
||||
if (elementNullable) {
|
||||
converter
|
||||
} else {
|
||||
|
@ -302,7 +297,7 @@ object ScalaReflection extends ScalaReflection {
|
|||
}
|
||||
}
|
||||
|
||||
val arrayData = UnresolvedMapObjects(mapFunction, getPath)
|
||||
val arrayData = UnresolvedMapObjects(mapFunction, path)
|
||||
val arrayCls = arrayClassFor(elementType)
|
||||
|
||||
if (elementNullable) {
|
||||
|
@ -334,7 +329,7 @@ object ScalaReflection extends ScalaReflection {
|
|||
val mapFunction: Expression => Expression = element => {
|
||||
// upcast the array element to the data type the encoder expected.
|
||||
val casted = upCastToExpectedType(element, dataType, newTypePath)
|
||||
val converter = deserializerFor(elementType, Some(casted), newTypePath)
|
||||
val converter = deserializerFor(elementType, casted, newTypePath)
|
||||
if (elementNullable) {
|
||||
converter
|
||||
} else {
|
||||
|
@ -349,16 +344,16 @@ object ScalaReflection extends ScalaReflection {
|
|||
classOf[scala.collection.Set[_]]
|
||||
case _ => mirror.runtimeClass(t.typeSymbol.asClass)
|
||||
}
|
||||
UnresolvedMapObjects(mapFunction, getPath, Some(cls))
|
||||
UnresolvedMapObjects(mapFunction, path, Some(cls))
|
||||
|
||||
case t if t <:< localTypeOf[Map[_, _]] =>
|
||||
// TODO: add walked type path for map
|
||||
val TypeRef(_, _, Seq(keyType, valueType)) = t
|
||||
|
||||
CatalystToExternalMap(
|
||||
p => deserializerFor(keyType, Some(p), walkedTypePath),
|
||||
p => deserializerFor(valueType, Some(p), walkedTypePath),
|
||||
getPath,
|
||||
p => deserializerFor(keyType, p, walkedTypePath),
|
||||
p => deserializerFor(valueType, p, walkedTypePath),
|
||||
path,
|
||||
mirror.runtimeClass(t.typeSymbol.asClass)
|
||||
)
|
||||
|
||||
|
@ -368,7 +363,7 @@ object ScalaReflection extends ScalaReflection {
|
|||
udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
|
||||
Nil,
|
||||
dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
|
||||
Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
|
||||
Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil)
|
||||
|
||||
case t if UDTRegistration.exists(getClassNameFromType(t)) =>
|
||||
val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance()
|
||||
|
@ -377,7 +372,7 @@ object ScalaReflection extends ScalaReflection {
|
|||
udt.getClass,
|
||||
Nil,
|
||||
dataType = ObjectType(udt.getClass))
|
||||
Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
|
||||
Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil)
|
||||
|
||||
case t if definedByConstructorParams(t) =>
|
||||
val params = getConstructorParameters(t)
|
||||
|
@ -392,12 +387,12 @@ object ScalaReflection extends ScalaReflection {
|
|||
val constructor = if (cls.getName startsWith "scala.Tuple") {
|
||||
deserializerFor(
|
||||
fieldType,
|
||||
Some(addToPathOrdinal(i, dataType, newTypePath)),
|
||||
addToPathOrdinal(i, dataType, newTypePath),
|
||||
newTypePath)
|
||||
} else {
|
||||
deserializerFor(
|
||||
fieldType,
|
||||
Some(addToPath(fieldName, dataType, newTypePath)),
|
||||
addToPath(fieldName, dataType, newTypePath),
|
||||
newTypePath)
|
||||
}
|
||||
|
||||
|
@ -410,20 +405,17 @@ object ScalaReflection extends ScalaReflection {
|
|||
|
||||
val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false)
|
||||
|
||||
if (path.nonEmpty) {
|
||||
expressions.If(
|
||||
IsNull(getPath),
|
||||
expressions.Literal.create(null, ObjectType(cls)),
|
||||
newInstance
|
||||
)
|
||||
} else {
|
||||
expressions.If(
|
||||
IsNull(path),
|
||||
expressions.Literal.create(null, ObjectType(cls)),
|
||||
newInstance
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an expression for serializing an object of type T to an internal row.
|
||||
* Returns an expression for serializing an object of type T to Spark SQL representation. The
|
||||
* input object is located at ordinal 0 of a row, i.e., `BoundReference(0, _)`.
|
||||
*
|
||||
* If the given type is not supported, i.e. there is no encoder can be built for this type,
|
||||
* an [[UnsupportedOperationException]] will be thrown with detailed error message to explain
|
||||
|
@ -434,17 +426,21 @@ object ScalaReflection extends ScalaReflection {
|
|||
* * the element type of [[Array]] or [[Seq]]: `array element class: "abc.xyz.MyClass"`
|
||||
* * the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")`
|
||||
*/
|
||||
def serializerFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = {
|
||||
val tpe = localTypeOf[T]
|
||||
def serializerForType(tpe: `Type`): Expression = ScalaReflection.cleanUpReflectionObjects {
|
||||
val clsName = getClassNameFromType(tpe)
|
||||
val walkedTypePath = s"""- root class: "$clsName"""" :: Nil
|
||||
serializerFor(inputObject, tpe, walkedTypePath) match {
|
||||
case expressions.If(_, _, s: CreateNamedStruct) if definedByConstructorParams(tpe) => s
|
||||
case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil)
|
||||
}
|
||||
|
||||
// The input object to `ExpressionEncoder` is located at first column of an row.
|
||||
val inputObject = BoundReference(0, dataTypeFor(tpe),
|
||||
nullable = !tpe.typeSymbol.asClass.isPrimitive)
|
||||
|
||||
serializerFor(inputObject, tpe, walkedTypePath)
|
||||
}
|
||||
|
||||
/** Helper for extracting internal fields from a case class. */
|
||||
/**
|
||||
* Returns an expression for serializing the value of an input expression into Spark SQL
|
||||
* internal representation.
|
||||
*/
|
||||
private def serializerFor(
|
||||
inputObject: Expression,
|
||||
tpe: `Type`,
|
||||
|
|
|
@ -25,10 +25,11 @@ import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaRefle
|
|||
import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
|
||||
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, NewInstance}
|
||||
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, InitializeJavaBean, Invoke, NewInstance}
|
||||
import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LocalRelation}
|
||||
import org.apache.spark.sql.types.{BooleanType, ObjectType, StructField, StructType}
|
||||
import org.apache.spark.sql.types.{ObjectType, StringType, StructField, StructType}
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
/**
|
||||
|
@ -43,8 +44,8 @@ import org.apache.spark.util.Utils
|
|||
* to the name `value`.
|
||||
*/
|
||||
object ExpressionEncoder {
|
||||
|
||||
def apply[T : TypeTag](): ExpressionEncoder[T] = {
|
||||
// We convert the not-serializable TypeTag into StructType and ClassTag.
|
||||
val mirror = ScalaReflection.mirror
|
||||
val tpe = typeTag[T].in(mirror).tpe
|
||||
|
||||
|
@ -58,25 +59,11 @@ object ExpressionEncoder {
|
|||
}
|
||||
|
||||
val cls = mirror.runtimeClass(tpe)
|
||||
val flat = !ScalaReflection.definedByConstructorParams(tpe)
|
||||
|
||||
val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = !cls.isPrimitive)
|
||||
val nullSafeInput = if (flat) {
|
||||
inputObject
|
||||
} else {
|
||||
// For input object of Product type, we can't encode it to row if it's null, as Spark SQL
|
||||
// doesn't allow top-level row to be null, only its columns can be null.
|
||||
AssertNotNull(inputObject, Seq("top level Product input object"))
|
||||
}
|
||||
val serializer = ScalaReflection.serializerFor[T](nullSafeInput)
|
||||
val deserializer = ScalaReflection.deserializerFor[T]
|
||||
|
||||
val schema = serializer.dataType
|
||||
val serializer = ScalaReflection.serializerForType(tpe)
|
||||
val deserializer = ScalaReflection.deserializerForType(tpe)
|
||||
|
||||
new ExpressionEncoder[T](
|
||||
schema,
|
||||
flat,
|
||||
serializer.flatten,
|
||||
serializer,
|
||||
deserializer,
|
||||
ClassTag[T](cls))
|
||||
}
|
||||
|
@ -86,14 +73,12 @@ object ExpressionEncoder {
|
|||
val schema = JavaTypeInference.inferDataType(beanClass)._1
|
||||
assert(schema.isInstanceOf[StructType])
|
||||
|
||||
val serializer = JavaTypeInference.serializerFor(beanClass)
|
||||
val deserializer = JavaTypeInference.deserializerFor(beanClass)
|
||||
val objSerializer = JavaTypeInference.serializerFor(beanClass)
|
||||
val objDeserializer = JavaTypeInference.deserializerFor(beanClass)
|
||||
|
||||
new ExpressionEncoder[T](
|
||||
schema.asInstanceOf[StructType],
|
||||
flat = false,
|
||||
serializer.flatten,
|
||||
deserializer,
|
||||
objSerializer,
|
||||
objDeserializer,
|
||||
ClassTag[T](beanClass))
|
||||
}
|
||||
|
||||
|
@ -103,75 +88,59 @@ object ExpressionEncoder {
|
|||
* name/positional binding is preserved.
|
||||
*/
|
||||
def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
|
||||
// TODO: check if encoders length is more than 22 and throw exception for it.
|
||||
|
||||
encoders.foreach(_.assertUnresolved())
|
||||
|
||||
val schema = StructType(encoders.zipWithIndex.map {
|
||||
case (e, i) =>
|
||||
val (dataType, nullable) = if (e.flat) {
|
||||
e.schema.head.dataType -> e.schema.head.nullable
|
||||
} else {
|
||||
e.schema -> true
|
||||
}
|
||||
StructField(s"_${i + 1}", dataType, nullable)
|
||||
StructField(s"_${i + 1}", e.objSerializer.dataType, e.objSerializer.nullable)
|
||||
})
|
||||
|
||||
val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
|
||||
|
||||
val serializer = encoders.zipWithIndex.map { case (enc, index) =>
|
||||
val originalInputObject = enc.serializer.head.collect { case b: BoundReference => b }.head
|
||||
val serializers = encoders.zipWithIndex.map { case (enc, index) =>
|
||||
val boundRefs = enc.objSerializer.collect { case b: BoundReference => b }.distinct
|
||||
assert(boundRefs.size == 1, "object serializer should have only one bound reference but " +
|
||||
s"there are ${boundRefs.size}")
|
||||
|
||||
val originalInputObject = boundRefs.head
|
||||
val newInputObject = Invoke(
|
||||
BoundReference(0, ObjectType(cls), nullable = true),
|
||||
s"_${index + 1}",
|
||||
originalInputObject.dataType)
|
||||
originalInputObject.dataType,
|
||||
returnNullable = originalInputObject.nullable)
|
||||
|
||||
val newSerializer = enc.serializer.map(_.transformUp {
|
||||
case b: BoundReference if b == originalInputObject => newInputObject
|
||||
})
|
||||
|
||||
val serializerExpr = if (enc.flat) {
|
||||
newSerializer.head
|
||||
} else {
|
||||
// For non-flat encoder, the input object is not top level anymore after being combined to
|
||||
// a tuple encoder, thus it can be null and we should wrap the `CreateStruct` with `If` and
|
||||
// null check to handle null case correctly.
|
||||
// e.g. for Encoder[(Int, String)], the serializer expressions will create 2 columns, and is
|
||||
// not able to handle the case when the input tuple is null. This is not a problem as there
|
||||
// is a check to make sure the input object won't be null. However, if this encoder is used
|
||||
// to create a bigger tuple encoder, the original input object becomes a filed of the new
|
||||
// input tuple and can be null. So instead of creating a struct directly here, we should add
|
||||
// a null/None check and return a null struct if the null/None check fails.
|
||||
val struct = CreateStruct(newSerializer)
|
||||
val nullCheck = Or(
|
||||
IsNull(newInputObject),
|
||||
Invoke(Literal.fromObject(None), "equals", BooleanType, newInputObject :: Nil))
|
||||
If(nullCheck, Literal.create(null, struct.dataType), struct)
|
||||
val newSerializer = enc.objSerializer.transformUp {
|
||||
case b: BoundReference => newInputObject
|
||||
}
|
||||
Alias(serializerExpr, s"_${index + 1}")()
|
||||
|
||||
Alias(newSerializer, s"_${index + 1}")()
|
||||
}
|
||||
|
||||
val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) =>
|
||||
if (enc.flat) {
|
||||
enc.deserializer.transform {
|
||||
case g: GetColumnByOrdinal => g.copy(ordinal = index)
|
||||
}
|
||||
val getColumnsByOrdinals = enc.objDeserializer.collect { case c: GetColumnByOrdinal => c }
|
||||
.distinct
|
||||
assert(getColumnsByOrdinals.size == 1, "object deserializer should have only one " +
|
||||
s"`GetColumnByOrdinal`, but there are ${getColumnsByOrdinals.size}")
|
||||
|
||||
val input = GetStructField(GetColumnByOrdinal(0, schema), index)
|
||||
val newDeserializer = enc.objDeserializer.transformUp {
|
||||
case GetColumnByOrdinal(0, _) => input
|
||||
}
|
||||
if (schema(index).nullable) {
|
||||
If(IsNull(input), Literal.create(null, newDeserializer.dataType), newDeserializer)
|
||||
} else {
|
||||
val input = GetColumnByOrdinal(index, enc.schema)
|
||||
val deserialized = enc.deserializer.transformUp {
|
||||
case UnresolvedAttribute(nameParts) =>
|
||||
assert(nameParts.length == 1)
|
||||
UnresolvedExtractValue(input, Literal(nameParts.head))
|
||||
case GetColumnByOrdinal(ordinal, _) => GetStructField(input, ordinal)
|
||||
}
|
||||
If(IsNull(input), Literal.create(null, deserialized.dataType), deserialized)
|
||||
newDeserializer
|
||||
}
|
||||
}
|
||||
|
||||
val serializer = If(IsNull(BoundReference(0, ObjectType(cls), nullable = true)),
|
||||
Literal.create(null, schema), CreateStruct(serializers))
|
||||
val deserializer =
|
||||
NewInstance(cls, childrenDeserializers, ObjectType(cls), propagateNull = false)
|
||||
|
||||
new ExpressionEncoder[Any](
|
||||
schema,
|
||||
flat = false,
|
||||
serializer,
|
||||
deserializer,
|
||||
ClassTag(cls))
|
||||
|
@ -212,21 +181,91 @@ object ExpressionEncoder {
|
|||
* A generic encoder for JVM objects that uses Catalyst Expressions for a `serializer`
|
||||
* and a `deserializer`.
|
||||
*
|
||||
* @param schema The schema after converting `T` to a Spark SQL row.
|
||||
* @param serializer A set of expressions, one for each top-level field that can be used to
|
||||
* extract the values from a raw object into an [[InternalRow]].
|
||||
* @param deserializer An expression that will construct an object given an [[InternalRow]].
|
||||
* @param objSerializer An expression that can be used to encode a raw object to corresponding
|
||||
* Spark SQL representation that can be a primitive column, array, map or a
|
||||
* struct. This represents how Spark SQL generally serializes an object of
|
||||
* type `T`.
|
||||
* @param objDeserializer An expression that will construct an object given a Spark SQL
|
||||
* representation. This represents how Spark SQL generally deserializes
|
||||
* a serialized value in Spark SQL representation back to an object of
|
||||
* type `T`.
|
||||
* @param clsTag A classtag for `T`.
|
||||
*/
|
||||
case class ExpressionEncoder[T](
|
||||
schema: StructType,
|
||||
flat: Boolean,
|
||||
serializer: Seq[Expression],
|
||||
deserializer: Expression,
|
||||
objSerializer: Expression,
|
||||
objDeserializer: Expression,
|
||||
clsTag: ClassTag[T])
|
||||
extends Encoder[T] {
|
||||
|
||||
if (flat) require(serializer.size == 1)
|
||||
/**
|
||||
* A sequence of expressions, one for each top-level field that can be used to
|
||||
* extract the values from a raw object into an [[InternalRow]]:
|
||||
* 1. If `serializer` encodes a raw object to a struct, strip the outer If-IsNull and get
|
||||
* the `CreateNamedStruct`.
|
||||
* 2. For other cases, wrap the single serializer with `CreateNamedStruct`.
|
||||
*/
|
||||
val serializer: Seq[NamedExpression] = {
|
||||
val clsName = Utils.getSimpleName(clsTag.runtimeClass)
|
||||
|
||||
if (isSerializedAsStruct) {
|
||||
val nullSafeSerializer = objSerializer.transformUp {
|
||||
case r: BoundReference =>
|
||||
// For input object of Product type, we can't encode it to row if it's null, as Spark SQL
|
||||
// doesn't allow top-level row to be null, only its columns can be null.
|
||||
AssertNotNull(r, Seq("top level Product or row object"))
|
||||
}
|
||||
nullSafeSerializer match {
|
||||
case If(_: IsNull, _, s: CreateNamedStruct) => s
|
||||
case s: CreateNamedStruct => s
|
||||
case _ =>
|
||||
throw new RuntimeException(s"class $clsName has unexpected serializer: $objSerializer")
|
||||
}
|
||||
} else {
|
||||
// For other input objects like primitive, array, map, etc., we construct a struct to wrap
|
||||
// the serializer which is a column of an row.
|
||||
CreateNamedStruct(Literal("value") :: objSerializer :: Nil)
|
||||
}
|
||||
}.flatten
|
||||
|
||||
/**
|
||||
* Returns an expression that can be used to deserialize an input row to an object of type `T`
|
||||
* with a compatible schema. Fields of the row will be extracted using `UnresolvedAttribute`.
|
||||
* of the same name as the constructor arguments.
|
||||
*
|
||||
* For complex objects that are encoded to structs, Fields of the struct will be extracted using
|
||||
* `GetColumnByOrdinal` with corresponding ordinal.
|
||||
*/
|
||||
val deserializer: Expression = {
|
||||
if (isSerializedAsStruct) {
|
||||
// We serialized this kind of objects to root-level row. The input of general deserializer
|
||||
// is a `GetColumnByOrdinal(0)` expression to extract first column of a row. We need to
|
||||
// transform attributes accessors.
|
||||
objDeserializer.transform {
|
||||
case UnresolvedExtractValue(GetColumnByOrdinal(0, _),
|
||||
Literal(part: UTF8String, StringType)) =>
|
||||
UnresolvedAttribute.quoted(part.toString)
|
||||
case GetStructField(GetColumnByOrdinal(0, dt), ordinal, _) =>
|
||||
GetColumnByOrdinal(ordinal, dt)
|
||||
case If(IsNull(GetColumnByOrdinal(0, _)), _, n: NewInstance) => n
|
||||
case If(IsNull(GetColumnByOrdinal(0, _)), _, i: InitializeJavaBean) => i
|
||||
}
|
||||
} else {
|
||||
// For other input objects like primitive, array, map, etc., we deserialize the first column
|
||||
// of a row to the object.
|
||||
objDeserializer
|
||||
}
|
||||
}
|
||||
|
||||
// The schema after converting `T` to a Spark SQL row. This schema is dependent on the given
|
||||
// serialier.
|
||||
val schema: StructType = StructType(serializer.map { s =>
|
||||
StructField(s.name, s.dataType, s.nullable)
|
||||
})
|
||||
|
||||
/**
|
||||
* Returns true if the type `T` is serialized as a struct.
|
||||
*/
|
||||
def isSerializedAsStruct: Boolean = objSerializer.dataType.isInstanceOf[StructType]
|
||||
|
||||
// serializer expressions are used to encode an object to a row, while the object is usually an
|
||||
// intermediate value produced inside an operator, not from the output of the child operator. This
|
||||
|
@ -258,7 +297,7 @@ case class ExpressionEncoder[T](
|
|||
analyzer.checkAnalysis(analyzedPlan)
|
||||
val resolved = SimplifyCasts(analyzedPlan).asInstanceOf[DeserializeToObject].deserializer
|
||||
val bound = BindReferences.bindReference(resolved, attrs)
|
||||
copy(deserializer = bound)
|
||||
copy(objDeserializer = bound)
|
||||
}
|
||||
|
||||
@transient
|
||||
|
|
|
@ -58,12 +58,10 @@ object RowEncoder {
|
|||
def apply(schema: StructType): ExpressionEncoder[Row] = {
|
||||
val cls = classOf[Row]
|
||||
val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
|
||||
val serializer = serializerFor(AssertNotNull(inputObject, Seq("top level row object")), schema)
|
||||
val deserializer = deserializerFor(schema)
|
||||
val serializer = serializerFor(inputObject, schema)
|
||||
val deserializer = deserializerFor(GetColumnByOrdinal(0, serializer.dataType), schema)
|
||||
new ExpressionEncoder[Row](
|
||||
schema,
|
||||
flat = false,
|
||||
serializer.asInstanceOf[CreateNamedStruct].flatten,
|
||||
serializer,
|
||||
deserializer,
|
||||
ClassTag(cls))
|
||||
}
|
||||
|
@ -237,13 +235,9 @@ object RowEncoder {
|
|||
case udt: UserDefinedType[_] => ObjectType(udt.userClass)
|
||||
}
|
||||
|
||||
private def deserializerFor(schema: StructType): Expression = {
|
||||
private def deserializerFor(input: Expression, schema: StructType): Expression = {
|
||||
val fields = schema.zipWithIndex.map { case (f, i) =>
|
||||
val dt = f.dataType match {
|
||||
case p: PythonUserDefinedType => p.sqlType
|
||||
case other => other
|
||||
}
|
||||
deserializerFor(GetColumnByOrdinal(i, dt))
|
||||
deserializerFor(GetStructField(input, i))
|
||||
}
|
||||
CreateExternalRow(fields, schema)
|
||||
}
|
||||
|
|
|
@ -19,12 +19,13 @@ package org.apache.spark.sql.catalyst
|
|||
|
||||
import java.sql.{Date, Timestamp}
|
||||
|
||||
import scala.reflect.runtime.universe.TypeTag
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
|
||||
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, Literal, SpecificInternalRow, UpCast}
|
||||
import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
|
||||
import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, If, SpecificInternalRow, UpCast}
|
||||
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance}
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
|
||||
case class PrimitiveData(
|
||||
intField: Int,
|
||||
|
@ -112,6 +113,14 @@ object TestingUDT {
|
|||
class ScalaReflectionSuite extends SparkFunSuite {
|
||||
import org.apache.spark.sql.catalyst.ScalaReflection._
|
||||
|
||||
// A helper method used to test `ScalaReflection.serializerForType`.
|
||||
private def serializerFor[T: TypeTag]: Expression =
|
||||
serializerForType(ScalaReflection.localTypeOf[T])
|
||||
|
||||
// A helper method used to test `ScalaReflection.deserializerForType`.
|
||||
private def deserializerFor[T: TypeTag]: Expression =
|
||||
deserializerForType(ScalaReflection.localTypeOf[T])
|
||||
|
||||
test("SQLUserDefinedType annotation on Scala structure") {
|
||||
val schema = schemaFor[TestingUDT.NestedStruct]
|
||||
assert(schema === Schema(
|
||||
|
@ -263,13 +272,9 @@ class ScalaReflectionSuite extends SparkFunSuite {
|
|||
|
||||
test("SPARK-15062: Get correct serializer for List[_]") {
|
||||
val list = List(1, 2, 3)
|
||||
val serializer = serializerFor[List[Int]](BoundReference(
|
||||
0, ObjectType(list.getClass), nullable = false))
|
||||
assert(serializer.children.size == 2)
|
||||
assert(serializer.children.head.isInstanceOf[Literal])
|
||||
assert(serializer.children.head.asInstanceOf[Literal].value === UTF8String.fromString("value"))
|
||||
assert(serializer.children.last.isInstanceOf[NewInstance])
|
||||
assert(serializer.children.last.asInstanceOf[NewInstance]
|
||||
val serializer = serializerFor[List[Int]]
|
||||
assert(serializer.isInstanceOf[NewInstance])
|
||||
assert(serializer.asInstanceOf[NewInstance]
|
||||
.cls.isAssignableFrom(classOf[org.apache.spark.sql.catalyst.util.GenericArrayData]))
|
||||
}
|
||||
|
||||
|
@ -280,59 +285,58 @@ class ScalaReflectionSuite extends SparkFunSuite {
|
|||
|
||||
test("serialize and deserialize arbitrary sequence types") {
|
||||
import scala.collection.immutable.Queue
|
||||
val queueSerializer = serializerFor[Queue[Int]](BoundReference(
|
||||
0, ObjectType(classOf[Queue[Int]]), nullable = false))
|
||||
assert(queueSerializer.dataType.head.dataType ==
|
||||
val queueSerializer = serializerFor[Queue[Int]]
|
||||
assert(queueSerializer.dataType ==
|
||||
ArrayType(IntegerType, containsNull = false))
|
||||
val queueDeserializer = deserializerFor[Queue[Int]]
|
||||
assert(queueDeserializer.dataType == ObjectType(classOf[Queue[_]]))
|
||||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
val arrayBufferSerializer = serializerFor[ArrayBuffer[Int]](BoundReference(
|
||||
0, ObjectType(classOf[ArrayBuffer[Int]]), nullable = false))
|
||||
assert(arrayBufferSerializer.dataType.head.dataType ==
|
||||
val arrayBufferSerializer = serializerFor[ArrayBuffer[Int]]
|
||||
assert(arrayBufferSerializer.dataType ==
|
||||
ArrayType(IntegerType, containsNull = false))
|
||||
val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]]
|
||||
assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]]))
|
||||
}
|
||||
|
||||
test("serialize and deserialize arbitrary map types") {
|
||||
val mapSerializer = serializerFor[Map[Int, Int]](BoundReference(
|
||||
0, ObjectType(classOf[Map[Int, Int]]), nullable = false))
|
||||
assert(mapSerializer.dataType.head.dataType ==
|
||||
val mapSerializer = serializerFor[Map[Int, Int]]
|
||||
assert(mapSerializer.dataType ==
|
||||
MapType(IntegerType, IntegerType, valueContainsNull = false))
|
||||
val mapDeserializer = deserializerFor[Map[Int, Int]]
|
||||
assert(mapDeserializer.dataType == ObjectType(classOf[Map[_, _]]))
|
||||
|
||||
import scala.collection.immutable.HashMap
|
||||
val hashMapSerializer = serializerFor[HashMap[Int, Int]](BoundReference(
|
||||
0, ObjectType(classOf[HashMap[Int, Int]]), nullable = false))
|
||||
assert(hashMapSerializer.dataType.head.dataType ==
|
||||
val hashMapSerializer = serializerFor[HashMap[Int, Int]]
|
||||
assert(hashMapSerializer.dataType ==
|
||||
MapType(IntegerType, IntegerType, valueContainsNull = false))
|
||||
val hashMapDeserializer = deserializerFor[HashMap[Int, Int]]
|
||||
assert(hashMapDeserializer.dataType == ObjectType(classOf[HashMap[_, _]]))
|
||||
|
||||
import scala.collection.mutable.{LinkedHashMap => LHMap}
|
||||
val linkedHashMapSerializer = serializerFor[LHMap[Long, String]](BoundReference(
|
||||
0, ObjectType(classOf[LHMap[Long, String]]), nullable = false))
|
||||
assert(linkedHashMapSerializer.dataType.head.dataType ==
|
||||
val linkedHashMapSerializer = serializerFor[LHMap[Long, String]]
|
||||
assert(linkedHashMapSerializer.dataType ==
|
||||
MapType(LongType, StringType, valueContainsNull = true))
|
||||
val linkedHashMapDeserializer = deserializerFor[LHMap[Long, String]]
|
||||
assert(linkedHashMapDeserializer.dataType == ObjectType(classOf[LHMap[_, _]]))
|
||||
}
|
||||
|
||||
test("SPARK-22442: Generate correct field names for special characters") {
|
||||
val serializer = serializerFor[SpecialCharAsFieldData](BoundReference(
|
||||
0, ObjectType(classOf[SpecialCharAsFieldData]), nullable = false))
|
||||
val serializer = serializerFor[SpecialCharAsFieldData]
|
||||
.collect {
|
||||
case If(_, _, s: CreateNamedStruct) => s
|
||||
}.head
|
||||
val deserializer = deserializerFor[SpecialCharAsFieldData]
|
||||
assert(serializer.dataType(0).name == "field.1")
|
||||
assert(serializer.dataType(1).name == "field 2")
|
||||
|
||||
val argumentsFields = deserializer.asInstanceOf[NewInstance].arguments.flatMap { _.collect {
|
||||
case UpCast(u: UnresolvedAttribute, _, _) => u.nameParts
|
||||
val newInstance = deserializer.collect { case n: NewInstance => n }.head
|
||||
|
||||
val argumentsFields = newInstance.arguments.flatMap { _.collect {
|
||||
case UpCast(u: UnresolvedExtractValue, _, _) => u.extraction.toString
|
||||
}}
|
||||
assert(argumentsFields(0) == Seq("field.1"))
|
||||
assert(argumentsFields(1) == Seq("field 2"))
|
||||
assert(argumentsFields(0) == "field.1")
|
||||
assert(argumentsFields(1) == "field 2")
|
||||
}
|
||||
|
||||
test("SPARK-22472: add null check for top-level primitive values") {
|
||||
|
@ -351,8 +355,8 @@ class ScalaReflectionSuite extends SparkFunSuite {
|
|||
|
||||
test("SPARK-23835: add null check to non-nullable types in Tuples") {
|
||||
def numberOfCheckedArguments(deserializer: Expression): Int = {
|
||||
assert(deserializer.isInstanceOf[NewInstance])
|
||||
deserializer.asInstanceOf[NewInstance].arguments.count(_.isInstanceOf[AssertNotNull])
|
||||
val newInstance = deserializer.collect { case n: NewInstance => n}.head
|
||||
newInstance.arguments.count(_.isInstanceOf[AssertNotNull])
|
||||
}
|
||||
assert(numberOfCheckedArguments(deserializerFor[(Double, Double)]) == 2)
|
||||
assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)]) == 1)
|
||||
|
|
|
@ -28,9 +28,9 @@ import org.apache.spark.sql.{Encoder, Encoders}
|
|||
import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData}
|
||||
import org.apache.spark.sql.catalyst.analysis.AnalysisTest
|
||||
import org.apache.spark.sql.catalyst.dsl.plans._
|
||||
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
|
||||
import org.apache.spark.sql.catalyst.expressions.AttributeReference
|
||||
import org.apache.spark.sql.catalyst.plans.CodegenInterpretedPlanTest
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
|
||||
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
|
||||
import org.apache.spark.sql.catalyst.util.ArrayData
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
|
@ -348,7 +348,7 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes
|
|||
|
||||
test("nullable of encoder serializer") {
|
||||
def checkNullable[T: Encoder](nullable: Boolean): Unit = {
|
||||
assert(encoderFor[T].serializer.forall(_.nullable === nullable))
|
||||
assert(encoderFor[T].objSerializer.nullable === nullable)
|
||||
}
|
||||
|
||||
// test for flat encoders
|
||||
|
|
|
@ -239,7 +239,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
|
|||
val encoder = RowEncoder(schema)
|
||||
val e = intercept[RuntimeException](encoder.toRow(null))
|
||||
assert(e.getMessage.contains("Null value appeared in non-nullable field"))
|
||||
assert(e.getMessage.contains("top level row object"))
|
||||
assert(e.getMessage.contains("top level Product or row object"))
|
||||
}
|
||||
|
||||
test("RowEncoder should validate external type") {
|
||||
|
|
|
@ -1087,7 +1087,7 @@ class Dataset[T] private[sql](
|
|||
// Note that we do this before joining them, to enable the join operator to return null for one
|
||||
// side, in cases like outer-join.
|
||||
val left = {
|
||||
val combined = if (this.exprEnc.flat) {
|
||||
val combined = if (!this.exprEnc.isSerializedAsStruct) {
|
||||
assert(joined.left.output.length == 1)
|
||||
Alias(joined.left.output.head, "_1")()
|
||||
} else {
|
||||
|
@ -1097,7 +1097,7 @@ class Dataset[T] private[sql](
|
|||
}
|
||||
|
||||
val right = {
|
||||
val combined = if (other.exprEnc.flat) {
|
||||
val combined = if (!other.exprEnc.isSerializedAsStruct) {
|
||||
assert(joined.right.output.length == 1)
|
||||
Alias(joined.right.output.head, "_2")()
|
||||
} else {
|
||||
|
@ -1110,14 +1110,14 @@ class Dataset[T] private[sql](
|
|||
// combine the outputs of each join side.
|
||||
val conditionExpr = joined.condition.get transformUp {
|
||||
case a: Attribute if joined.left.outputSet.contains(a) =>
|
||||
if (this.exprEnc.flat) {
|
||||
if (!this.exprEnc.isSerializedAsStruct) {
|
||||
left.output.head
|
||||
} else {
|
||||
val index = joined.left.output.indexWhere(_.exprId == a.exprId)
|
||||
GetStructField(left.output.head, index)
|
||||
}
|
||||
case a: Attribute if joined.right.outputSet.contains(a) =>
|
||||
if (other.exprEnc.flat) {
|
||||
if (!other.exprEnc.isSerializedAsStruct) {
|
||||
right.output.head
|
||||
} else {
|
||||
val index = joined.right.output.indexWhere(_.exprId == a.exprId)
|
||||
|
@ -1390,7 +1390,7 @@ class Dataset[T] private[sql](
|
|||
implicit val encoder = c1.encoder
|
||||
val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, logicalPlan)
|
||||
|
||||
if (encoder.flat) {
|
||||
if (!encoder.isSerializedAsStruct) {
|
||||
new Dataset[U1](sparkSession, project, encoder)
|
||||
} else {
|
||||
// Flattens inner fields of U1
|
||||
|
|
|
@ -457,7 +457,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
|
|||
val encoders = columns.map(_.encoder)
|
||||
val namedColumns =
|
||||
columns.map(_.withInputType(vExprEnc, dataAttributes).named)
|
||||
val keyColumn = if (kExprEnc.flat) {
|
||||
val keyColumn = if (!kExprEnc.isSerializedAsStruct) {
|
||||
assert(groupingAttributes.length == 1)
|
||||
groupingAttributes.head
|
||||
} else {
|
||||
|
|
|
@ -38,18 +38,14 @@ object TypedAggregateExpression {
|
|||
val bufferSerializer = bufferEncoder.namedExpressions
|
||||
|
||||
val outputEncoder = encoderFor[OUT]
|
||||
val outputType = if (outputEncoder.flat) {
|
||||
outputEncoder.schema.head.dataType
|
||||
} else {
|
||||
outputEncoder.schema
|
||||
}
|
||||
val outputType = outputEncoder.objSerializer.dataType
|
||||
|
||||
// Checks if the buffer object is simple, i.e. the buffer encoder is flat and the serializer
|
||||
// expression is an alias of `BoundReference`, which means the buffer object doesn't need
|
||||
// serialization.
|
||||
val isSimpleBuffer = {
|
||||
bufferSerializer.head match {
|
||||
case Alias(_: BoundReference, _) if bufferEncoder.flat => true
|
||||
case Alias(_: BoundReference, _) if !bufferEncoder.isSerializedAsStruct => true
|
||||
case _ => false
|
||||
}
|
||||
}
|
||||
|
@ -71,7 +67,7 @@ object TypedAggregateExpression {
|
|||
outputEncoder.serializer,
|
||||
outputEncoder.deserializer.dataType,
|
||||
outputType,
|
||||
!outputEncoder.flat || outputEncoder.schema.head.nullable)
|
||||
outputEncoder.objSerializer.nullable)
|
||||
} else {
|
||||
ComplexTypedAggregateExpression(
|
||||
aggregator.asInstanceOf[Aggregator[Any, Any, Any]],
|
||||
|
@ -82,7 +78,7 @@ object TypedAggregateExpression {
|
|||
bufferEncoder.resolveAndBind().deserializer,
|
||||
outputEncoder.serializer,
|
||||
outputType,
|
||||
!outputEncoder.flat || outputEncoder.schema.head.nullable)
|
||||
outputEncoder.objSerializer.nullable)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1065,7 +1065,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
|
|||
test("Dataset should throw RuntimeException if top-level product input object is null") {
|
||||
val e = intercept[RuntimeException](Seq(ClassData("a", 1), null).toDS())
|
||||
assert(e.getMessage.contains("Null value appeared in non-nullable field"))
|
||||
assert(e.getMessage.contains("top level Product input object"))
|
||||
assert(e.getMessage.contains("top level Product or row object"))
|
||||
}
|
||||
|
||||
test("dropDuplicates") {
|
||||
|
|
Loading…
Reference in a new issue