[SPARK-11754][SQL] consolidate ExpressionEncoder.tuple
and Encoders.tuple
These 2 are very similar, we can consolidate them into one. Also add tests for it and fix a bug. Author: Wenchen Fan <wenchen@databricks.com> Closes #9729 from cloud-fan/tuple.
This commit is contained in:
parent
24477d2705
commit
b1a9662623
|
@ -19,10 +19,8 @@ package org.apache.spark.sql
|
||||||
|
|
||||||
import scala.reflect.ClassTag
|
import scala.reflect.ClassTag
|
||||||
|
|
||||||
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
|
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor}
|
||||||
import org.apache.spark.sql.catalyst.expressions._
|
import org.apache.spark.sql.types.StructType
|
||||||
import org.apache.spark.sql.types.{ObjectType, StructField, StructType}
|
|
||||||
import org.apache.spark.util.Utils
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Used to convert a JVM object of type `T` to and from the internal Spark SQL representation.
|
* Used to convert a JVM object of type `T` to and from the internal Spark SQL representation.
|
||||||
|
@ -49,83 +47,34 @@ object Encoders {
|
||||||
def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true)
|
def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true)
|
||||||
def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true)
|
def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true)
|
||||||
|
|
||||||
def tuple[T1, T2](enc1: Encoder[T1], enc2: Encoder[T2]): Encoder[(T1, T2)] = {
|
def tuple[T1, T2](
|
||||||
tuple(Seq(enc1, enc2).map(_.asInstanceOf[ExpressionEncoder[_]]))
|
e1: Encoder[T1],
|
||||||
.asInstanceOf[ExpressionEncoder[(T1, T2)]]
|
e2: Encoder[T2]): Encoder[(T1, T2)] = {
|
||||||
|
ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2))
|
||||||
}
|
}
|
||||||
|
|
||||||
def tuple[T1, T2, T3](
|
def tuple[T1, T2, T3](
|
||||||
enc1: Encoder[T1],
|
e1: Encoder[T1],
|
||||||
enc2: Encoder[T2],
|
e2: Encoder[T2],
|
||||||
enc3: Encoder[T3]): Encoder[(T1, T2, T3)] = {
|
e3: Encoder[T3]): Encoder[(T1, T2, T3)] = {
|
||||||
tuple(Seq(enc1, enc2, enc3).map(_.asInstanceOf[ExpressionEncoder[_]]))
|
ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3))
|
||||||
.asInstanceOf[ExpressionEncoder[(T1, T2, T3)]]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def tuple[T1, T2, T3, T4](
|
def tuple[T1, T2, T3, T4](
|
||||||
enc1: Encoder[T1],
|
e1: Encoder[T1],
|
||||||
enc2: Encoder[T2],
|
e2: Encoder[T2],
|
||||||
enc3: Encoder[T3],
|
e3: Encoder[T3],
|
||||||
enc4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = {
|
e4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = {
|
||||||
tuple(Seq(enc1, enc2, enc3, enc4).map(_.asInstanceOf[ExpressionEncoder[_]]))
|
ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4))
|
||||||
.asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4)]]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def tuple[T1, T2, T3, T4, T5](
|
def tuple[T1, T2, T3, T4, T5](
|
||||||
enc1: Encoder[T1],
|
e1: Encoder[T1],
|
||||||
enc2: Encoder[T2],
|
e2: Encoder[T2],
|
||||||
enc3: Encoder[T3],
|
e3: Encoder[T3],
|
||||||
enc4: Encoder[T4],
|
e4: Encoder[T4],
|
||||||
enc5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = {
|
e5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = {
|
||||||
tuple(Seq(enc1, enc2, enc3, enc4, enc5).map(_.asInstanceOf[ExpressionEncoder[_]]))
|
ExpressionEncoder.tuple(
|
||||||
.asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]]
|
encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4), encoderFor(e5))
|
||||||
}
|
|
||||||
|
|
||||||
private def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
|
|
||||||
assert(encoders.length > 1)
|
|
||||||
// make sure all encoders are resolved, i.e. `Attribute` has been resolved to `BoundReference`.
|
|
||||||
assert(encoders.forall(_.fromRowExpression.find(_.isInstanceOf[Attribute]).isEmpty))
|
|
||||||
|
|
||||||
val schema = StructType(encoders.zipWithIndex.map {
|
|
||||||
case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
|
|
||||||
})
|
|
||||||
|
|
||||||
val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
|
|
||||||
|
|
||||||
val extractExpressions = encoders.map {
|
|
||||||
case e if e.flat => e.toRowExpressions.head
|
|
||||||
case other => CreateStruct(other.toRowExpressions)
|
|
||||||
}.zipWithIndex.map { case (expr, index) =>
|
|
||||||
expr.transformUp {
|
|
||||||
case BoundReference(0, t: ObjectType, _) =>
|
|
||||||
Invoke(
|
|
||||||
BoundReference(0, ObjectType(cls), nullable = true),
|
|
||||||
s"_${index + 1}",
|
|
||||||
t)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
val constructExpressions = encoders.zipWithIndex.map { case (enc, index) =>
|
|
||||||
if (enc.flat) {
|
|
||||||
enc.fromRowExpression.transform {
|
|
||||||
case b: BoundReference => b.copy(ordinal = index)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
enc.fromRowExpression.transformUp {
|
|
||||||
case BoundReference(ordinal, dt, _) =>
|
|
||||||
GetInternalRowField(BoundReference(index, enc.schema, nullable = true), ordinal, dt)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
val constructExpression =
|
|
||||||
NewInstance(cls, constructExpressions, propagateNull = false, ObjectType(cls))
|
|
||||||
|
|
||||||
new ExpressionEncoder[Any](
|
|
||||||
schema,
|
|
||||||
flat = false,
|
|
||||||
extractExpressions,
|
|
||||||
constructExpression,
|
|
||||||
ClassTag(cls))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -67,47 +67,77 @@ object ExpressionEncoder {
|
||||||
def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
|
def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
|
||||||
encoders.foreach(_.assertUnresolved())
|
encoders.foreach(_.assertUnresolved())
|
||||||
|
|
||||||
val schema =
|
val schema = StructType(encoders.zipWithIndex.map {
|
||||||
StructType(
|
case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
|
||||||
encoders.zipWithIndex.map {
|
})
|
||||||
case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
|
|
||||||
})
|
|
||||||
val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
|
val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
|
||||||
|
|
||||||
// Rebind the encoders to the nested schema.
|
val toRowExpressions = encoders.map {
|
||||||
val newConstructExpressions = encoders.zipWithIndex.map {
|
case e if e.flat => e.toRowExpressions.head
|
||||||
case (e, i) if !e.flat => e.nested(i).fromRowExpression
|
case other => CreateStruct(other.toRowExpressions)
|
||||||
case (e, i) => e.shift(i).fromRowExpression
|
}.zipWithIndex.map { case (expr, index) =>
|
||||||
}
|
expr.transformUp {
|
||||||
|
case BoundReference(0, t, _) =>
|
||||||
val constructExpression =
|
Invoke(
|
||||||
NewInstance(cls, newConstructExpressions, false, ObjectType(cls))
|
BoundReference(0, ObjectType(cls), nullable = true),
|
||||||
|
s"_${index + 1}",
|
||||||
val input = BoundReference(0, ObjectType(cls), false)
|
t)
|
||||||
val extractExpressions = encoders.zipWithIndex.map {
|
|
||||||
case (e, i) if !e.flat => CreateStruct(e.toRowExpressions.map(_ transformUp {
|
|
||||||
case b: BoundReference =>
|
|
||||||
Invoke(input, s"_${i + 1}", b.dataType, Nil)
|
|
||||||
}))
|
|
||||||
case (e, i) => e.toRowExpressions.head transformUp {
|
|
||||||
case b: BoundReference =>
|
|
||||||
Invoke(input, s"_${i + 1}", b.dataType, Nil)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
val fromRowExpressions = encoders.zipWithIndex.map { case (enc, index) =>
|
||||||
|
if (enc.flat) {
|
||||||
|
enc.fromRowExpression.transform {
|
||||||
|
case b: BoundReference => b.copy(ordinal = index)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
val input = BoundReference(index, enc.schema, nullable = true)
|
||||||
|
enc.fromRowExpression.transformUp {
|
||||||
|
case UnresolvedAttribute(nameParts) =>
|
||||||
|
assert(nameParts.length == 1)
|
||||||
|
UnresolvedExtractValue(input, Literal(nameParts.head))
|
||||||
|
case BoundReference(ordinal, dt, _) => GetInternalRowField(input, ordinal, dt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val fromRowExpression =
|
||||||
|
NewInstance(cls, fromRowExpressions, propagateNull = false, ObjectType(cls))
|
||||||
|
|
||||||
new ExpressionEncoder[Any](
|
new ExpressionEncoder[Any](
|
||||||
schema,
|
schema,
|
||||||
false,
|
flat = false,
|
||||||
extractExpressions,
|
toRowExpressions,
|
||||||
constructExpression,
|
fromRowExpression,
|
||||||
ClassTag.apply(cls))
|
ClassTag(cls))
|
||||||
}
|
}
|
||||||
|
|
||||||
/** A helper for producing encoders of Tuple2 from other encoders. */
|
|
||||||
def tuple[T1, T2](
|
def tuple[T1, T2](
|
||||||
e1: ExpressionEncoder[T1],
|
e1: ExpressionEncoder[T1],
|
||||||
e2: ExpressionEncoder[T2]): ExpressionEncoder[(T1, T2)] =
|
e2: ExpressionEncoder[T2]): ExpressionEncoder[(T1, T2)] =
|
||||||
tuple(e1 :: e2 :: Nil).asInstanceOf[ExpressionEncoder[(T1, T2)]]
|
tuple(Seq(e1, e2)).asInstanceOf[ExpressionEncoder[(T1, T2)]]
|
||||||
|
|
||||||
|
def tuple[T1, T2, T3](
|
||||||
|
e1: ExpressionEncoder[T1],
|
||||||
|
e2: ExpressionEncoder[T2],
|
||||||
|
e3: ExpressionEncoder[T3]): ExpressionEncoder[(T1, T2, T3)] =
|
||||||
|
tuple(Seq(e1, e2, e3)).asInstanceOf[ExpressionEncoder[(T1, T2, T3)]]
|
||||||
|
|
||||||
|
def tuple[T1, T2, T3, T4](
|
||||||
|
e1: ExpressionEncoder[T1],
|
||||||
|
e2: ExpressionEncoder[T2],
|
||||||
|
e3: ExpressionEncoder[T3],
|
||||||
|
e4: ExpressionEncoder[T4]): ExpressionEncoder[(T1, T2, T3, T4)] =
|
||||||
|
tuple(Seq(e1, e2, e3, e4)).asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4)]]
|
||||||
|
|
||||||
|
def tuple[T1, T2, T3, T4, T5](
|
||||||
|
e1: ExpressionEncoder[T1],
|
||||||
|
e2: ExpressionEncoder[T2],
|
||||||
|
e3: ExpressionEncoder[T3],
|
||||||
|
e4: ExpressionEncoder[T4],
|
||||||
|
e5: ExpressionEncoder[T5]): ExpressionEncoder[(T1, T2, T3, T4, T5)] =
|
||||||
|
tuple(Seq(e1, e2, e3, e4, e5)).asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]]
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -208,26 +238,6 @@ case class ExpressionEncoder[T](
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns a copy of this encoder where the expressions used to create an object given an
|
|
||||||
* input row have been modified to pull the object out from a nested struct, instead of the
|
|
||||||
* top level fields.
|
|
||||||
*/
|
|
||||||
private def nested(i: Int): ExpressionEncoder[T] = {
|
|
||||||
// We don't always know our input type at this point since it might be unresolved.
|
|
||||||
// We fill in null and it will get unbound to the actual attribute at this position.
|
|
||||||
val input = BoundReference(i, NullType, nullable = true)
|
|
||||||
copy(fromRowExpression = fromRowExpression transformUp {
|
|
||||||
case u: Attribute =>
|
|
||||||
UnresolvedExtractValue(input, Literal(u.name))
|
|
||||||
case b: BoundReference =>
|
|
||||||
GetStructField(
|
|
||||||
input,
|
|
||||||
StructField(s"i[${b.ordinal}]", b.dataType),
|
|
||||||
b.ordinal)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
protected val attrs = toRowExpressions.flatMap(_.collect {
|
protected val attrs = toRowExpressions.flatMap(_.collect {
|
||||||
case _: UnresolvedAttribute => ""
|
case _: UnresolvedAttribute => ""
|
||||||
case a: Attribute => s"#${a.exprId}"
|
case a: Attribute => s"#${a.exprId}"
|
||||||
|
|
|
@ -117,6 +117,35 @@ class ProductEncoderSuite extends ExpressionEncoderSuite {
|
||||||
productTest(("Seq[Seq[(Int, Int)]]",
|
productTest(("Seq[Seq[(Int, Int)]]",
|
||||||
Seq(Seq((1, 2)))))
|
Seq(Seq((1, 2)))))
|
||||||
|
|
||||||
|
encodeDecodeTest(
|
||||||
|
1 -> 10L,
|
||||||
|
ExpressionEncoder.tuple(FlatEncoder[Int], FlatEncoder[Long]),
|
||||||
|
"tuple with 2 flat encoders")
|
||||||
|
|
||||||
|
encodeDecodeTest(
|
||||||
|
(PrimitiveData(1, 1, 1, 1, 1, 1, true), (3, 30L)),
|
||||||
|
ExpressionEncoder.tuple(ProductEncoder[PrimitiveData], ProductEncoder[(Int, Long)]),
|
||||||
|
"tuple with 2 product encoders")
|
||||||
|
|
||||||
|
encodeDecodeTest(
|
||||||
|
(PrimitiveData(1, 1, 1, 1, 1, 1, true), 3),
|
||||||
|
ExpressionEncoder.tuple(ProductEncoder[PrimitiveData], FlatEncoder[Int]),
|
||||||
|
"tuple with flat encoder and product encoder")
|
||||||
|
|
||||||
|
encodeDecodeTest(
|
||||||
|
(3, PrimitiveData(1, 1, 1, 1, 1, 1, true)),
|
||||||
|
ExpressionEncoder.tuple(FlatEncoder[Int], ProductEncoder[PrimitiveData]),
|
||||||
|
"tuple with product encoder and flat encoder")
|
||||||
|
|
||||||
|
encodeDecodeTest(
|
||||||
|
(1, (10, 100L)),
|
||||||
|
{
|
||||||
|
val intEnc = FlatEncoder[Int]
|
||||||
|
val longEnc = FlatEncoder[Long]
|
||||||
|
ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc))
|
||||||
|
},
|
||||||
|
"nested tuple encoder")
|
||||||
|
|
||||||
private def productTest[T <: Product : TypeTag](input: T): Unit = {
|
private def productTest[T <: Product : TypeTag](input: T): Unit = {
|
||||||
encodeDecodeTest(input, ProductEncoder[T], input.getClass.getSimpleName)
|
encodeDecodeTest(input, ProductEncoder[T], input.getClass.getSimpleName)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue