[SPARK-11954][SQL] Encoder for JavaBeans
create java version of `constructorFor` and `extractorFor` in `JavaTypeInference` Author: Wenchen Fan <wenchen@databricks.com> This patch had conflicts when merged, resolved by Committer: Michael Armbrust <michael@databricks.com> Closes #9937 from cloud-fan/pojo.
This commit is contained in:
parent
9df24624af
commit
fd95eeaf49
|
@ -97,6 +97,24 @@ object Encoders {
|
|||
*/
|
||||
def STRING: Encoder[java.lang.String] = ExpressionEncoder()
|
||||
|
||||
/**
|
||||
* Creates an encoder for Java Bean of type T.
|
||||
*
|
||||
* T must be publicly accessible.
|
||||
*
|
||||
* supported types for java bean field:
|
||||
* - primitive types: boolean, int, double, etc.
|
||||
* - boxed types: Boolean, Integer, Double, etc.
|
||||
* - String
|
||||
* - java.math.BigDecimal
|
||||
* - time related: java.sql.Date, java.sql.Timestamp
|
||||
* - collection types: only array and java.util.List currently, map support is in progress
|
||||
* - nested java bean.
|
||||
*
|
||||
* @since 1.6.0
|
||||
*/
|
||||
def bean[T](beanClass: Class[T]): Encoder[T] = ExpressionEncoder.javaBean(beanClass)
|
||||
|
||||
/**
|
||||
* (Scala-specific) Creates an encoder that serializes objects of type T using Kryo.
|
||||
* This encoder maps T into a single byte array (binary) field.
|
||||
|
|
|
@ -17,14 +17,20 @@
|
|||
|
||||
package org.apache.spark.sql.catalyst
|
||||
|
||||
import java.beans.Introspector
|
||||
import java.beans.{PropertyDescriptor, Introspector}
|
||||
import java.lang.{Iterable => JIterable}
|
||||
import java.util.{Iterator => JIterator, Map => JMap}
|
||||
import java.util.{Iterator => JIterator, Map => JMap, List => JList}
|
||||
|
||||
import scala.language.existentials
|
||||
|
||||
import com.google.common.reflect.TypeToken
|
||||
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue}
|
||||
import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils}
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
|
||||
|
||||
/**
|
||||
* Type-inference utilities for POJOs and Java collections.
|
||||
|
@ -33,13 +39,14 @@ object JavaTypeInference {
|
|||
|
||||
private val iterableType = TypeToken.of(classOf[JIterable[_]])
|
||||
private val mapType = TypeToken.of(classOf[JMap[_, _]])
|
||||
private val listType = TypeToken.of(classOf[JList[_]])
|
||||
private val iteratorReturnType = classOf[JIterable[_]].getMethod("iterator").getGenericReturnType
|
||||
private val nextReturnType = classOf[JIterator[_]].getMethod("next").getGenericReturnType
|
||||
private val keySetReturnType = classOf[JMap[_, _]].getMethod("keySet").getGenericReturnType
|
||||
private val valuesReturnType = classOf[JMap[_, _]].getMethod("values").getGenericReturnType
|
||||
|
||||
/**
|
||||
* Infers the corresponding SQL data type of a JavaClean class.
|
||||
* Infers the corresponding SQL data type of a JavaBean class.
|
||||
* @param beanClass Java type
|
||||
* @return (SQL data type, nullable)
|
||||
*/
|
||||
|
@ -58,6 +65,8 @@ object JavaTypeInference {
|
|||
(c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)
|
||||
|
||||
case c: Class[_] if c == classOf[java.lang.String] => (StringType, true)
|
||||
case c: Class[_] if c == classOf[Array[Byte]] => (BinaryType, true)
|
||||
|
||||
case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false)
|
||||
case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false)
|
||||
case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false)
|
||||
|
@ -87,15 +96,14 @@ object JavaTypeInference {
|
|||
(ArrayType(dataType, nullable), true)
|
||||
|
||||
case _ if mapType.isAssignableFrom(typeToken) =>
|
||||
val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]]
|
||||
val mapSupertype = typeToken2.getSupertype(classOf[JMap[_, _]])
|
||||
val keyType = elementType(mapSupertype.resolveType(keySetReturnType))
|
||||
val valueType = elementType(mapSupertype.resolveType(valuesReturnType))
|
||||
val (keyType, valueType) = mapKeyValueType(typeToken)
|
||||
val (keyDataType, _) = inferDataType(keyType)
|
||||
val (valueDataType, nullable) = inferDataType(valueType)
|
||||
(MapType(keyDataType, valueDataType, nullable), true)
|
||||
|
||||
case _ =>
|
||||
// TODO: we should only collect properties that have getter and setter. However, some tests
|
||||
// pass in scala case class as java bean class which doesn't have getter and setter.
|
||||
val beanInfo = Introspector.getBeanInfo(typeToken.getRawType)
|
||||
val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
|
||||
val fields = properties.map { property =>
|
||||
|
@ -107,11 +115,294 @@ object JavaTypeInference {
|
|||
}
|
||||
}
|
||||
|
||||
private def getJavaBeanProperties(beanClass: Class[_]): Array[PropertyDescriptor] = {
|
||||
val beanInfo = Introspector.getBeanInfo(beanClass)
|
||||
beanInfo.getPropertyDescriptors
|
||||
.filter(p => p.getReadMethod != null && p.getWriteMethod != null)
|
||||
}
|
||||
|
||||
private def elementType(typeToken: TypeToken[_]): TypeToken[_] = {
|
||||
val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JIterable[_]]]
|
||||
val iterableSupertype = typeToken2.getSupertype(classOf[JIterable[_]])
|
||||
val iteratorType = iterableSupertype.resolveType(iteratorReturnType)
|
||||
val itemType = iteratorType.resolveType(nextReturnType)
|
||||
itemType
|
||||
val iterableSuperType = typeToken2.getSupertype(classOf[JIterable[_]])
|
||||
val iteratorType = iterableSuperType.resolveType(iteratorReturnType)
|
||||
iteratorType.resolveType(nextReturnType)
|
||||
}
|
||||
|
||||
private def mapKeyValueType(typeToken: TypeToken[_]): (TypeToken[_], TypeToken[_]) = {
|
||||
val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]]
|
||||
val mapSuperType = typeToken2.getSupertype(classOf[JMap[_, _]])
|
||||
val keyType = elementType(mapSuperType.resolveType(keySetReturnType))
|
||||
val valueType = elementType(mapSuperType.resolveType(valuesReturnType))
|
||||
keyType -> valueType
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the Spark SQL DataType for a given java class. Where this is not an exact mapping
|
||||
* to a native type, an ObjectType is returned.
|
||||
*
|
||||
* Unlike `inferDataType`, this function doesn't do any massaging of types into the Spark SQL type
|
||||
* system. As a result, ObjectType will be returned for things like boxed Integers.
|
||||
*/
|
||||
private def inferExternalType(cls: Class[_]): DataType = cls match {
|
||||
case c if c == java.lang.Boolean.TYPE => BooleanType
|
||||
case c if c == java.lang.Byte.TYPE => ByteType
|
||||
case c if c == java.lang.Short.TYPE => ShortType
|
||||
case c if c == java.lang.Integer.TYPE => IntegerType
|
||||
case c if c == java.lang.Long.TYPE => LongType
|
||||
case c if c == java.lang.Float.TYPE => FloatType
|
||||
case c if c == java.lang.Double.TYPE => DoubleType
|
||||
case c if c == classOf[Array[Byte]] => BinaryType
|
||||
case _ => ObjectType(cls)
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an expression that can be used to construct an object of java bean `T` given an input
|
||||
* row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes
|
||||
* of the same name as the constructor arguments. Nested classes will have their fields accessed
|
||||
* using UnresolvedExtractValue.
|
||||
*/
|
||||
def constructorFor(beanClass: Class[_]): Expression = {
|
||||
constructorFor(TypeToken.of(beanClass), None)
|
||||
}
|
||||
|
||||
private def constructorFor(typeToken: TypeToken[_], path: Option[Expression]): Expression = {
|
||||
/** Returns the current path with a sub-field extracted. */
|
||||
def addToPath(part: String): Expression = path
|
||||
.map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
|
||||
.getOrElse(UnresolvedAttribute(part))
|
||||
|
||||
/** Returns the current path or `BoundReference`. */
|
||||
def getPath: Expression = path.getOrElse(BoundReference(0, inferDataType(typeToken)._1, true))
|
||||
|
||||
typeToken.getRawType match {
|
||||
case c if !inferExternalType(c).isInstanceOf[ObjectType] => getPath
|
||||
|
||||
case c if c == classOf[java.lang.Short] =>
|
||||
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
|
||||
case c if c == classOf[java.lang.Integer] =>
|
||||
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
|
||||
case c if c == classOf[java.lang.Long] =>
|
||||
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
|
||||
case c if c == classOf[java.lang.Double] =>
|
||||
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
|
||||
case c if c == classOf[java.lang.Byte] =>
|
||||
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
|
||||
case c if c == classOf[java.lang.Float] =>
|
||||
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
|
||||
case c if c == classOf[java.lang.Boolean] =>
|
||||
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
|
||||
|
||||
case c if c == classOf[java.sql.Date] =>
|
||||
StaticInvoke(
|
||||
DateTimeUtils,
|
||||
ObjectType(c),
|
||||
"toJavaDate",
|
||||
getPath :: Nil,
|
||||
propagateNull = true)
|
||||
|
||||
case c if c == classOf[java.sql.Timestamp] =>
|
||||
StaticInvoke(
|
||||
DateTimeUtils,
|
||||
ObjectType(c),
|
||||
"toJavaTimestamp",
|
||||
getPath :: Nil,
|
||||
propagateNull = true)
|
||||
|
||||
case c if c == classOf[java.lang.String] =>
|
||||
Invoke(getPath, "toString", ObjectType(classOf[String]))
|
||||
|
||||
case c if c == classOf[java.math.BigDecimal] =>
|
||||
Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]))
|
||||
|
||||
case c if c.isArray =>
|
||||
val elementType = c.getComponentType
|
||||
val primitiveMethod = elementType match {
|
||||
case c if c == java.lang.Boolean.TYPE => Some("toBooleanArray")
|
||||
case c if c == java.lang.Byte.TYPE => Some("toByteArray")
|
||||
case c if c == java.lang.Short.TYPE => Some("toShortArray")
|
||||
case c if c == java.lang.Integer.TYPE => Some("toIntArray")
|
||||
case c if c == java.lang.Long.TYPE => Some("toLongArray")
|
||||
case c if c == java.lang.Float.TYPE => Some("toFloatArray")
|
||||
case c if c == java.lang.Double.TYPE => Some("toDoubleArray")
|
||||
case _ => None
|
||||
}
|
||||
|
||||
primitiveMethod.map { method =>
|
||||
Invoke(getPath, method, ObjectType(c))
|
||||
}.getOrElse {
|
||||
Invoke(
|
||||
MapObjects(
|
||||
p => constructorFor(typeToken.getComponentType, Some(p)),
|
||||
getPath,
|
||||
inferDataType(elementType)._1),
|
||||
"array",
|
||||
ObjectType(c))
|
||||
}
|
||||
|
||||
case c if listType.isAssignableFrom(typeToken) =>
|
||||
val et = elementType(typeToken)
|
||||
val array =
|
||||
Invoke(
|
||||
MapObjects(
|
||||
p => constructorFor(et, Some(p)),
|
||||
getPath,
|
||||
inferDataType(et)._1),
|
||||
"array",
|
||||
ObjectType(classOf[Array[Any]]))
|
||||
|
||||
StaticInvoke(classOf[java.util.Arrays], ObjectType(c), "asList", array :: Nil)
|
||||
|
||||
case _ if mapType.isAssignableFrom(typeToken) =>
|
||||
val (keyType, valueType) = mapKeyValueType(typeToken)
|
||||
val keyDataType = inferDataType(keyType)._1
|
||||
val valueDataType = inferDataType(valueType)._1
|
||||
|
||||
val keyData =
|
||||
Invoke(
|
||||
MapObjects(
|
||||
p => constructorFor(keyType, Some(p)),
|
||||
Invoke(getPath, "keyArray", ArrayType(keyDataType)),
|
||||
keyDataType),
|
||||
"array",
|
||||
ObjectType(classOf[Array[Any]]))
|
||||
|
||||
val valueData =
|
||||
Invoke(
|
||||
MapObjects(
|
||||
p => constructorFor(valueType, Some(p)),
|
||||
Invoke(getPath, "valueArray", ArrayType(valueDataType)),
|
||||
valueDataType),
|
||||
"array",
|
||||
ObjectType(classOf[Array[Any]]))
|
||||
|
||||
StaticInvoke(
|
||||
ArrayBasedMapData,
|
||||
ObjectType(classOf[JMap[_, _]]),
|
||||
"toJavaMap",
|
||||
keyData :: valueData :: Nil)
|
||||
|
||||
case other =>
|
||||
val properties = getJavaBeanProperties(other)
|
||||
assert(properties.length > 0)
|
||||
|
||||
val setters = properties.map { p =>
|
||||
val fieldName = p.getName
|
||||
val fieldType = typeToken.method(p.getReadMethod).getReturnType
|
||||
p.getWriteMethod.getName -> constructorFor(fieldType, Some(addToPath(fieldName)))
|
||||
}.toMap
|
||||
|
||||
val newInstance = NewInstance(other, Nil, propagateNull = false, ObjectType(other))
|
||||
val result = InitializeJavaBean(newInstance, setters)
|
||||
|
||||
if (path.nonEmpty) {
|
||||
expressions.If(
|
||||
IsNull(getPath),
|
||||
expressions.Literal.create(null, ObjectType(other)),
|
||||
result
|
||||
)
|
||||
} else {
|
||||
result
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns expressions for extracting all the fields from the given type.
|
||||
*/
|
||||
def extractorsFor(beanClass: Class[_]): CreateNamedStruct = {
|
||||
val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true)
|
||||
extractorFor(inputObject, TypeToken.of(beanClass)).asInstanceOf[CreateNamedStruct]
|
||||
}
|
||||
|
||||
private def extractorFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = {
|
||||
|
||||
def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = {
|
||||
val (dataType, nullable) = inferDataType(elementType)
|
||||
if (ScalaReflection.isNativeType(dataType)) {
|
||||
NewInstance(
|
||||
classOf[GenericArrayData],
|
||||
input :: Nil,
|
||||
dataType = ArrayType(dataType, nullable))
|
||||
} else {
|
||||
MapObjects(extractorFor(_, elementType), input, ObjectType(elementType.getRawType))
|
||||
}
|
||||
}
|
||||
|
||||
if (!inputObject.dataType.isInstanceOf[ObjectType]) {
|
||||
inputObject
|
||||
} else {
|
||||
typeToken.getRawType match {
|
||||
case c if c == classOf[String] =>
|
||||
StaticInvoke(
|
||||
classOf[UTF8String],
|
||||
StringType,
|
||||
"fromString",
|
||||
inputObject :: Nil)
|
||||
|
||||
case c if c == classOf[java.sql.Timestamp] =>
|
||||
StaticInvoke(
|
||||
DateTimeUtils,
|
||||
TimestampType,
|
||||
"fromJavaTimestamp",
|
||||
inputObject :: Nil)
|
||||
|
||||
case c if c == classOf[java.sql.Date] =>
|
||||
StaticInvoke(
|
||||
DateTimeUtils,
|
||||
DateType,
|
||||
"fromJavaDate",
|
||||
inputObject :: Nil)
|
||||
|
||||
case c if c == classOf[java.math.BigDecimal] =>
|
||||
StaticInvoke(
|
||||
Decimal,
|
||||
DecimalType.SYSTEM_DEFAULT,
|
||||
"apply",
|
||||
inputObject :: Nil)
|
||||
|
||||
case c if c == classOf[java.lang.Boolean] =>
|
||||
Invoke(inputObject, "booleanValue", BooleanType)
|
||||
case c if c == classOf[java.lang.Byte] =>
|
||||
Invoke(inputObject, "byteValue", ByteType)
|
||||
case c if c == classOf[java.lang.Short] =>
|
||||
Invoke(inputObject, "shortValue", ShortType)
|
||||
case c if c == classOf[java.lang.Integer] =>
|
||||
Invoke(inputObject, "intValue", IntegerType)
|
||||
case c if c == classOf[java.lang.Long] =>
|
||||
Invoke(inputObject, "longValue", LongType)
|
||||
case c if c == classOf[java.lang.Float] =>
|
||||
Invoke(inputObject, "floatValue", FloatType)
|
||||
case c if c == classOf[java.lang.Double] =>
|
||||
Invoke(inputObject, "doubleValue", DoubleType)
|
||||
|
||||
case _ if typeToken.isArray =>
|
||||
toCatalystArray(inputObject, typeToken.getComponentType)
|
||||
|
||||
case _ if listType.isAssignableFrom(typeToken) =>
|
||||
toCatalystArray(inputObject, elementType(typeToken))
|
||||
|
||||
case _ if mapType.isAssignableFrom(typeToken) =>
|
||||
// TODO: for java map, if we get the keys and values by `keySet` and `values`, we can
|
||||
// not guarantee they have same iteration order(which is different from scala map).
|
||||
// A possible solution is creating a new `MapObjects` that can iterate a map directly.
|
||||
throw new UnsupportedOperationException("map type is not supported currently")
|
||||
|
||||
case other =>
|
||||
val properties = getJavaBeanProperties(other)
|
||||
if (properties.length > 0) {
|
||||
CreateNamedStruct(properties.flatMap { p =>
|
||||
val fieldName = p.getName
|
||||
val fieldType = typeToken.method(p.getReadMethod).getReturnType
|
||||
val fieldValue = Invoke(
|
||||
inputObject,
|
||||
p.getReadMethod.getName,
|
||||
inferExternalType(fieldType.getRawType))
|
||||
expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil
|
||||
})
|
||||
} else {
|
||||
throw new UnsupportedOperationException(s"no encoder found for ${other.getName}")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -29,8 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
|
|||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
|
||||
import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.ScalaReflection
|
||||
import org.apache.spark.sql.catalyst.{JavaTypeInference, InternalRow, ScalaReflection}
|
||||
import org.apache.spark.sql.types.{StructField, ObjectType, StructType}
|
||||
|
||||
/**
|
||||
|
@ -68,6 +67,22 @@ object ExpressionEncoder {
|
|||
ClassTag[T](cls))
|
||||
}
|
||||
|
||||
// TODO: improve error message for java bean encoder.
|
||||
def javaBean[T](beanClass: Class[T]): ExpressionEncoder[T] = {
|
||||
val schema = JavaTypeInference.inferDataType(beanClass)._1
|
||||
assert(schema.isInstanceOf[StructType])
|
||||
|
||||
val toRowExpression = JavaTypeInference.extractorsFor(beanClass)
|
||||
val fromRowExpression = JavaTypeInference.constructorFor(beanClass)
|
||||
|
||||
new ExpressionEncoder[T](
|
||||
schema.asInstanceOf[StructType],
|
||||
flat = false,
|
||||
toRowExpression.flatten,
|
||||
fromRowExpression,
|
||||
ClassTag[T](beanClass))
|
||||
}
|
||||
|
||||
/**
|
||||
* Given a set of N encoders, constructs a new encoder that produce objects as items in an
|
||||
* N-tuple. Note that these encoders should be unresolved so that information about
|
||||
|
@ -216,7 +231,7 @@ case class ExpressionEncoder[T](
|
|||
*/
|
||||
def assertUnresolved(): Unit = {
|
||||
(fromRowExpression +: toRowExpressions).foreach(_.foreach {
|
||||
case a: AttributeReference =>
|
||||
case a: AttributeReference if a.name != "loopVar" =>
|
||||
sys.error(s"Unresolved encoder expected, but $a was found.")
|
||||
case _ =>
|
||||
})
|
||||
|
|
|
@ -346,7 +346,8 @@ case class LambdaVariable(value: String, isNull: String, dataType: DataType) ext
|
|||
* as an ArrayType. This is similar to a typical map operation, but where the lambda function
|
||||
* is expressed using catalyst expressions.
|
||||
*
|
||||
* The following collection ObjectTypes are currently supported: Seq, Array, ArrayData
|
||||
* The following collection ObjectTypes are currently supported:
|
||||
* Seq, Array, ArrayData, java.util.List
|
||||
*
|
||||
* @param function A function that returns an expression, given an attribute that can be used
|
||||
* to access the current value. This is does as a lambda function so that
|
||||
|
@ -386,6 +387,8 @@ case class MapObjects(
|
|||
(".size()", (i: String) => s".apply($i)", false)
|
||||
case ObjectType(cls) if cls.isArray =>
|
||||
(".length", (i: String) => s"[$i]", false)
|
||||
case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
|
||||
(".size()", (i: String) => s".get($i)", false)
|
||||
case ArrayType(t, _) =>
|
||||
val (sqlType, primitiveElement) = t match {
|
||||
case m: MapType => (m, false)
|
||||
|
@ -596,3 +599,40 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B
|
|||
|
||||
override def dataType: DataType = ObjectType(tag.runtimeClass)
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize a Java Bean instance by setting its field values via setters.
|
||||
*/
|
||||
case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Expression])
|
||||
extends Expression {
|
||||
|
||||
override def nullable: Boolean = beanInstance.nullable
|
||||
override def children: Seq[Expression] = beanInstance +: setters.values.toSeq
|
||||
override def dataType: DataType = beanInstance.dataType
|
||||
|
||||
override def eval(input: InternalRow): Any =
|
||||
throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
|
||||
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
|
||||
val instanceGen = beanInstance.gen(ctx)
|
||||
|
||||
val initialize = setters.map {
|
||||
case (setterMethod, fieldValue) =>
|
||||
val fieldGen = fieldValue.gen(ctx)
|
||||
s"""
|
||||
${fieldGen.code}
|
||||
${instanceGen.value}.$setterMethod(${fieldGen.value});
|
||||
"""
|
||||
}
|
||||
|
||||
ev.isNull = instanceGen.isNull
|
||||
ev.value = instanceGen.value
|
||||
|
||||
s"""
|
||||
${instanceGen.code}
|
||||
if (!${instanceGen.isNull}) {
|
||||
${initialize.mkString("\n")}
|
||||
}
|
||||
"""
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
|
||||
package org.apache.spark.sql.catalyst.trees
|
||||
|
||||
import scala.collection.Map
|
||||
|
||||
import org.apache.spark.sql.catalyst.errors._
|
||||
import org.apache.spark.sql.types.{StructType, DataType}
|
||||
|
||||
|
@ -191,6 +193,19 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
|
|||
case nonChild: AnyRef => nonChild
|
||||
case null => null
|
||||
}
|
||||
case m: Map[_, _] => m.mapValues {
|
||||
case arg: TreeNode[_] if containsChild(arg) =>
|
||||
val newChild = remainingNewChildren.remove(0)
|
||||
val oldChild = remainingOldChildren.remove(0)
|
||||
if (newChild fastEquals oldChild) {
|
||||
oldChild
|
||||
} else {
|
||||
changed = true
|
||||
newChild
|
||||
}
|
||||
case nonChild: AnyRef => nonChild
|
||||
case null => null
|
||||
}.view.force // `mapValues` is lazy and we need to force it to materialize
|
||||
case arg: TreeNode[_] if containsChild(arg) =>
|
||||
val newChild = remainingNewChildren.remove(0)
|
||||
val oldChild = remainingOldChildren.remove(0)
|
||||
|
@ -262,7 +277,17 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
|
|||
} else {
|
||||
Some(arg)
|
||||
}
|
||||
case m: Map[_, _] => m
|
||||
case m: Map[_, _] => m.mapValues {
|
||||
case arg: TreeNode[_] if containsChild(arg) =>
|
||||
val newChild = nextOperation(arg.asInstanceOf[BaseType], rule)
|
||||
if (!(newChild fastEquals arg)) {
|
||||
changed = true
|
||||
newChild
|
||||
} else {
|
||||
arg
|
||||
}
|
||||
case other => other
|
||||
}.view.force // `mapValues` is lazy and we need to force it to materialize
|
||||
case d: DataType => d // Avoid unpacking Structs
|
||||
case args: Traversable[_] => args.map {
|
||||
case arg: TreeNode[_] if containsChild(arg) =>
|
||||
|
|
|
@ -70,4 +70,9 @@ object ArrayBasedMapData {
|
|||
def toScalaMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = {
|
||||
keys.zip(values).toMap
|
||||
}
|
||||
|
||||
def toJavaMap(keys: Array[Any], values: Array[Any]): java.util.Map[Any, Any] = {
|
||||
import scala.collection.JavaConverters._
|
||||
keys.zip(values).toMap.asJava
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
|
||||
package org.apache.spark.sql.catalyst.util
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.types.{DataType, Decimal}
|
||||
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
|
||||
|
@ -24,6 +26,7 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
|
|||
class GenericArrayData(val array: Array[Any]) extends ArrayData {
|
||||
|
||||
def this(seq: Seq[Any]) = this(seq.toArray)
|
||||
def this(list: java.util.List[Any]) = this(list.asScala)
|
||||
|
||||
// TODO: This is boxing. We should specialize.
|
||||
def this(primitiveArray: Array[Int]) = this(primitiveArray.toSeq)
|
||||
|
|
|
@ -38,6 +38,13 @@ case class ComplexPlan(exprs: Seq[Seq[Expression]])
|
|||
override def output: Seq[Attribute] = Nil
|
||||
}
|
||||
|
||||
case class ExpressionInMap(map: Map[String, Expression]) extends Expression with Unevaluable {
|
||||
override def children: Seq[Expression] = map.values.toSeq
|
||||
override def nullable: Boolean = true
|
||||
override def dataType: NullType = NullType
|
||||
override lazy val resolved = true
|
||||
}
|
||||
|
||||
class TreeNodeSuite extends SparkFunSuite {
|
||||
test("top node changed") {
|
||||
val after = Literal(1) transform { case Literal(1, _) => Literal(2) }
|
||||
|
@ -236,4 +243,22 @@ class TreeNodeSuite extends SparkFunSuite {
|
|||
val expected = ComplexPlan(Seq(Seq(Literal("1")), Seq(Literal("2"))))
|
||||
assert(expected === actual)
|
||||
}
|
||||
|
||||
test("expressions inside a map") {
|
||||
val expression = ExpressionInMap(Map("1" -> Literal(1), "2" -> Literal(2)))
|
||||
|
||||
{
|
||||
val actual = expression.transform {
|
||||
case Literal(i: Int, _) => Literal(i + 1)
|
||||
}
|
||||
val expected = ExpressionInMap(Map("1" -> Literal(2), "2" -> Literal(3)))
|
||||
assert(actual === expected)
|
||||
}
|
||||
|
||||
{
|
||||
val actual = expression.withNewChildren(Seq(Literal(2), Literal(3)))
|
||||
val expected = ExpressionInMap(Map("1" -> Literal(2), "2" -> Literal(3)))
|
||||
assert(actual === expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -31,14 +31,15 @@ import org.apache.spark.Accumulator;
|
|||
import org.apache.spark.SparkContext;
|
||||
import org.apache.spark.api.java.function.*;
|
||||
import org.apache.spark.api.java.JavaSparkContext;
|
||||
import org.apache.spark.sql.Encoder;
|
||||
import org.apache.spark.sql.Encoders;
|
||||
import org.apache.spark.sql.Dataset;
|
||||
import org.apache.spark.sql.GroupedDataset;
|
||||
import org.apache.spark.sql.*;
|
||||
import org.apache.spark.sql.expressions.Aggregator;
|
||||
import org.apache.spark.sql.test.TestSQLContext;
|
||||
import org.apache.spark.sql.catalyst.encoders.OuterScopes;
|
||||
import org.apache.spark.sql.catalyst.expressions.GenericRow;
|
||||
import org.apache.spark.sql.types.StructType;
|
||||
|
||||
import static org.apache.spark.sql.functions.*;
|
||||
import static org.apache.spark.sql.types.DataTypes.*;
|
||||
|
||||
public class JavaDatasetSuite implements Serializable {
|
||||
private transient JavaSparkContext jsc;
|
||||
|
@ -506,4 +507,169 @@ public class JavaDatasetSuite implements Serializable {
|
|||
public void testKryoEncoderErrorMessageForPrivateClass() {
|
||||
Encoders.kryo(PrivateClassTest.class);
|
||||
}
|
||||
|
||||
public class SimpleJavaBean implements Serializable {
|
||||
private boolean a;
|
||||
private int b;
|
||||
private byte[] c;
|
||||
private String[] d;
|
||||
private List<String> e;
|
||||
private List<Long> f;
|
||||
|
||||
public boolean isA() {
|
||||
return a;
|
||||
}
|
||||
|
||||
public void setA(boolean a) {
|
||||
this.a = a;
|
||||
}
|
||||
|
||||
public int getB() {
|
||||
return b;
|
||||
}
|
||||
|
||||
public void setB(int b) {
|
||||
this.b = b;
|
||||
}
|
||||
|
||||
public byte[] getC() {
|
||||
return c;
|
||||
}
|
||||
|
||||
public void setC(byte[] c) {
|
||||
this.c = c;
|
||||
}
|
||||
|
||||
public String[] getD() {
|
||||
return d;
|
||||
}
|
||||
|
||||
public void setD(String[] d) {
|
||||
this.d = d;
|
||||
}
|
||||
|
||||
public List<String> getE() {
|
||||
return e;
|
||||
}
|
||||
|
||||
public void setE(List<String> e) {
|
||||
this.e = e;
|
||||
}
|
||||
|
||||
public List<Long> getF() {
|
||||
return f;
|
||||
}
|
||||
|
||||
public void setF(List<Long> f) {
|
||||
this.f = f;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
|
||||
SimpleJavaBean that = (SimpleJavaBean) o;
|
||||
|
||||
if (a != that.a) return false;
|
||||
if (b != that.b) return false;
|
||||
if (!Arrays.equals(c, that.c)) return false;
|
||||
if (!Arrays.equals(d, that.d)) return false;
|
||||
if (!e.equals(that.e)) return false;
|
||||
return f.equals(that.f);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
int result = (a ? 1 : 0);
|
||||
result = 31 * result + b;
|
||||
result = 31 * result + Arrays.hashCode(c);
|
||||
result = 31 * result + Arrays.hashCode(d);
|
||||
result = 31 * result + e.hashCode();
|
||||
result = 31 * result + f.hashCode();
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
public class NestedJavaBean implements Serializable {
|
||||
private SimpleJavaBean a;
|
||||
|
||||
public SimpleJavaBean getA() {
|
||||
return a;
|
||||
}
|
||||
|
||||
public void setA(SimpleJavaBean a) {
|
||||
this.a = a;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
|
||||
NestedJavaBean that = (NestedJavaBean) o;
|
||||
|
||||
return a.equals(that.a);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return a.hashCode();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testJavaBeanEncoder() {
|
||||
OuterScopes.addOuterScope(this);
|
||||
SimpleJavaBean obj1 = new SimpleJavaBean();
|
||||
obj1.setA(true);
|
||||
obj1.setB(3);
|
||||
obj1.setC(new byte[]{1, 2});
|
||||
obj1.setD(new String[]{"hello", null});
|
||||
obj1.setE(Arrays.asList("a", "b"));
|
||||
obj1.setF(Arrays.asList(100L, null, 200L));
|
||||
SimpleJavaBean obj2 = new SimpleJavaBean();
|
||||
obj2.setA(false);
|
||||
obj2.setB(30);
|
||||
obj2.setC(new byte[]{3, 4});
|
||||
obj2.setD(new String[]{null, "world"});
|
||||
obj2.setE(Arrays.asList("x", "y"));
|
||||
obj2.setF(Arrays.asList(300L, null, 400L));
|
||||
|
||||
List<SimpleJavaBean> data = Arrays.asList(obj1, obj2);
|
||||
Dataset<SimpleJavaBean> ds = context.createDataset(data, Encoders.bean(SimpleJavaBean.class));
|
||||
Assert.assertEquals(data, ds.collectAsList());
|
||||
|
||||
NestedJavaBean obj3 = new NestedJavaBean();
|
||||
obj3.setA(obj1);
|
||||
|
||||
List<NestedJavaBean> data2 = Arrays.asList(obj3);
|
||||
Dataset<NestedJavaBean> ds2 = context.createDataset(data2, Encoders.bean(NestedJavaBean.class));
|
||||
Assert.assertEquals(data2, ds2.collectAsList());
|
||||
|
||||
Row row1 = new GenericRow(new Object[]{
|
||||
true,
|
||||
3,
|
||||
new byte[]{1, 2},
|
||||
new String[]{"hello", null},
|
||||
Arrays.asList("a", "b"),
|
||||
Arrays.asList(100L, null, 200L)});
|
||||
Row row2 = new GenericRow(new Object[]{
|
||||
false,
|
||||
30,
|
||||
new byte[]{3, 4},
|
||||
new String[]{null, "world"},
|
||||
Arrays.asList("x", "y"),
|
||||
Arrays.asList(300L, null, 400L)});
|
||||
StructType schema = new StructType()
|
||||
.add("a", BooleanType, false)
|
||||
.add("b", IntegerType, false)
|
||||
.add("c", BinaryType)
|
||||
.add("d", createArrayType(StringType))
|
||||
.add("e", createArrayType(StringType))
|
||||
.add("f", createArrayType(LongType));
|
||||
Dataset<SimpleJavaBean> ds3 = context.createDataFrame(Arrays.asList(row1, row2), schema)
|
||||
.as(Encoders.bean(SimpleJavaBean.class));
|
||||
Assert.assertEquals(data, ds3.collectAsList());
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue