[SPARK-11754][SQL] consolidate ExpressionEncoder.tuple
and Encoders.tuple
These 2 are very similar, we can consolidate them into one. Also add tests for it and fix a bug. Author: Wenchen Fan <wenchen@databricks.com> Closes #9729 from cloud-fan/tuple.
This commit is contained in:
parent
24477d2705
commit
b1a9662623
|
@ -19,10 +19,8 @@ package org.apache.spark.sql
|
|||
|
||||
import scala.reflect.ClassTag
|
||||
|
||||
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.types.{ObjectType, StructField, StructType}
|
||||
import org.apache.spark.util.Utils
|
||||
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor}
|
||||
import org.apache.spark.sql.types.StructType
|
||||
|
||||
/**
|
||||
* Used to convert a JVM object of type `T` to and from the internal Spark SQL representation.
|
||||
|
@ -49,83 +47,34 @@ object Encoders {
|
|||
def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true)
|
||||
def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true)
|
||||
|
||||
def tuple[T1, T2](enc1: Encoder[T1], enc2: Encoder[T2]): Encoder[(T1, T2)] = {
|
||||
tuple(Seq(enc1, enc2).map(_.asInstanceOf[ExpressionEncoder[_]]))
|
||||
.asInstanceOf[ExpressionEncoder[(T1, T2)]]
|
||||
def tuple[T1, T2](
|
||||
e1: Encoder[T1],
|
||||
e2: Encoder[T2]): Encoder[(T1, T2)] = {
|
||||
ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2))
|
||||
}
|
||||
|
||||
def tuple[T1, T2, T3](
|
||||
enc1: Encoder[T1],
|
||||
enc2: Encoder[T2],
|
||||
enc3: Encoder[T3]): Encoder[(T1, T2, T3)] = {
|
||||
tuple(Seq(enc1, enc2, enc3).map(_.asInstanceOf[ExpressionEncoder[_]]))
|
||||
.asInstanceOf[ExpressionEncoder[(T1, T2, T3)]]
|
||||
e1: Encoder[T1],
|
||||
e2: Encoder[T2],
|
||||
e3: Encoder[T3]): Encoder[(T1, T2, T3)] = {
|
||||
ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3))
|
||||
}
|
||||
|
||||
def tuple[T1, T2, T3, T4](
|
||||
enc1: Encoder[T1],
|
||||
enc2: Encoder[T2],
|
||||
enc3: Encoder[T3],
|
||||
enc4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = {
|
||||
tuple(Seq(enc1, enc2, enc3, enc4).map(_.asInstanceOf[ExpressionEncoder[_]]))
|
||||
.asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4)]]
|
||||
e1: Encoder[T1],
|
||||
e2: Encoder[T2],
|
||||
e3: Encoder[T3],
|
||||
e4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = {
|
||||
ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4))
|
||||
}
|
||||
|
||||
def tuple[T1, T2, T3, T4, T5](
|
||||
enc1: Encoder[T1],
|
||||
enc2: Encoder[T2],
|
||||
enc3: Encoder[T3],
|
||||
enc4: Encoder[T4],
|
||||
enc5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = {
|
||||
tuple(Seq(enc1, enc2, enc3, enc4, enc5).map(_.asInstanceOf[ExpressionEncoder[_]]))
|
||||
.asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]]
|
||||
}
|
||||
|
||||
private def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
|
||||
assert(encoders.length > 1)
|
||||
// make sure all encoders are resolved, i.e. `Attribute` has been resolved to `BoundReference`.
|
||||
assert(encoders.forall(_.fromRowExpression.find(_.isInstanceOf[Attribute]).isEmpty))
|
||||
|
||||
val schema = StructType(encoders.zipWithIndex.map {
|
||||
case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
|
||||
})
|
||||
|
||||
val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
|
||||
|
||||
val extractExpressions = encoders.map {
|
||||
case e if e.flat => e.toRowExpressions.head
|
||||
case other => CreateStruct(other.toRowExpressions)
|
||||
}.zipWithIndex.map { case (expr, index) =>
|
||||
expr.transformUp {
|
||||
case BoundReference(0, t: ObjectType, _) =>
|
||||
Invoke(
|
||||
BoundReference(0, ObjectType(cls), nullable = true),
|
||||
s"_${index + 1}",
|
||||
t)
|
||||
}
|
||||
}
|
||||
|
||||
val constructExpressions = encoders.zipWithIndex.map { case (enc, index) =>
|
||||
if (enc.flat) {
|
||||
enc.fromRowExpression.transform {
|
||||
case b: BoundReference => b.copy(ordinal = index)
|
||||
}
|
||||
} else {
|
||||
enc.fromRowExpression.transformUp {
|
||||
case BoundReference(ordinal, dt, _) =>
|
||||
GetInternalRowField(BoundReference(index, enc.schema, nullable = true), ordinal, dt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val constructExpression =
|
||||
NewInstance(cls, constructExpressions, propagateNull = false, ObjectType(cls))
|
||||
|
||||
new ExpressionEncoder[Any](
|
||||
schema,
|
||||
flat = false,
|
||||
extractExpressions,
|
||||
constructExpression,
|
||||
ClassTag(cls))
|
||||
e1: Encoder[T1],
|
||||
e2: Encoder[T2],
|
||||
e3: Encoder[T3],
|
||||
e4: Encoder[T4],
|
||||
e5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = {
|
||||
ExpressionEncoder.tuple(
|
||||
encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4), encoderFor(e5))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -67,47 +67,77 @@ object ExpressionEncoder {
|
|||
def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
|
||||
encoders.foreach(_.assertUnresolved())
|
||||
|
||||
val schema =
|
||||
StructType(
|
||||
encoders.zipWithIndex.map {
|
||||
case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
|
||||
})
|
||||
val schema = StructType(encoders.zipWithIndex.map {
|
||||
case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
|
||||
})
|
||||
|
||||
val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
|
||||
|
||||
// Rebind the encoders to the nested schema.
|
||||
val newConstructExpressions = encoders.zipWithIndex.map {
|
||||
case (e, i) if !e.flat => e.nested(i).fromRowExpression
|
||||
case (e, i) => e.shift(i).fromRowExpression
|
||||
}
|
||||
|
||||
val constructExpression =
|
||||
NewInstance(cls, newConstructExpressions, false, ObjectType(cls))
|
||||
|
||||
val input = BoundReference(0, ObjectType(cls), false)
|
||||
val extractExpressions = encoders.zipWithIndex.map {
|
||||
case (e, i) if !e.flat => CreateStruct(e.toRowExpressions.map(_ transformUp {
|
||||
case b: BoundReference =>
|
||||
Invoke(input, s"_${i + 1}", b.dataType, Nil)
|
||||
}))
|
||||
case (e, i) => e.toRowExpressions.head transformUp {
|
||||
case b: BoundReference =>
|
||||
Invoke(input, s"_${i + 1}", b.dataType, Nil)
|
||||
val toRowExpressions = encoders.map {
|
||||
case e if e.flat => e.toRowExpressions.head
|
||||
case other => CreateStruct(other.toRowExpressions)
|
||||
}.zipWithIndex.map { case (expr, index) =>
|
||||
expr.transformUp {
|
||||
case BoundReference(0, t, _) =>
|
||||
Invoke(
|
||||
BoundReference(0, ObjectType(cls), nullable = true),
|
||||
s"_${index + 1}",
|
||||
t)
|
||||
}
|
||||
}
|
||||
|
||||
val fromRowExpressions = encoders.zipWithIndex.map { case (enc, index) =>
|
||||
if (enc.flat) {
|
||||
enc.fromRowExpression.transform {
|
||||
case b: BoundReference => b.copy(ordinal = index)
|
||||
}
|
||||
} else {
|
||||
val input = BoundReference(index, enc.schema, nullable = true)
|
||||
enc.fromRowExpression.transformUp {
|
||||
case UnresolvedAttribute(nameParts) =>
|
||||
assert(nameParts.length == 1)
|
||||
UnresolvedExtractValue(input, Literal(nameParts.head))
|
||||
case BoundReference(ordinal, dt, _) => GetInternalRowField(input, ordinal, dt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val fromRowExpression =
|
||||
NewInstance(cls, fromRowExpressions, propagateNull = false, ObjectType(cls))
|
||||
|
||||
new ExpressionEncoder[Any](
|
||||
schema,
|
||||
false,
|
||||
extractExpressions,
|
||||
constructExpression,
|
||||
ClassTag.apply(cls))
|
||||
flat = false,
|
||||
toRowExpressions,
|
||||
fromRowExpression,
|
||||
ClassTag(cls))
|
||||
}
|
||||
|
||||
/** A helper for producing encoders of Tuple2 from other encoders. */
|
||||
def tuple[T1, T2](
|
||||
e1: ExpressionEncoder[T1],
|
||||
e2: ExpressionEncoder[T2]): ExpressionEncoder[(T1, T2)] =
|
||||
tuple(e1 :: e2 :: Nil).asInstanceOf[ExpressionEncoder[(T1, T2)]]
|
||||
tuple(Seq(e1, e2)).asInstanceOf[ExpressionEncoder[(T1, T2)]]
|
||||
|
||||
def tuple[T1, T2, T3](
|
||||
e1: ExpressionEncoder[T1],
|
||||
e2: ExpressionEncoder[T2],
|
||||
e3: ExpressionEncoder[T3]): ExpressionEncoder[(T1, T2, T3)] =
|
||||
tuple(Seq(e1, e2, e3)).asInstanceOf[ExpressionEncoder[(T1, T2, T3)]]
|
||||
|
||||
def tuple[T1, T2, T3, T4](
|
||||
e1: ExpressionEncoder[T1],
|
||||
e2: ExpressionEncoder[T2],
|
||||
e3: ExpressionEncoder[T3],
|
||||
e4: ExpressionEncoder[T4]): ExpressionEncoder[(T1, T2, T3, T4)] =
|
||||
tuple(Seq(e1, e2, e3, e4)).asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4)]]
|
||||
|
||||
def tuple[T1, T2, T3, T4, T5](
|
||||
e1: ExpressionEncoder[T1],
|
||||
e2: ExpressionEncoder[T2],
|
||||
e3: ExpressionEncoder[T3],
|
||||
e4: ExpressionEncoder[T4],
|
||||
e5: ExpressionEncoder[T5]): ExpressionEncoder[(T1, T2, T3, T4, T5)] =
|
||||
tuple(Seq(e1, e2, e3, e4, e5)).asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]]
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -208,26 +238,6 @@ case class ExpressionEncoder[T](
|
|||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a copy of this encoder where the expressions used to create an object given an
|
||||
* input row have been modified to pull the object out from a nested struct, instead of the
|
||||
* top level fields.
|
||||
*/
|
||||
private def nested(i: Int): ExpressionEncoder[T] = {
|
||||
// We don't always know our input type at this point since it might be unresolved.
|
||||
// We fill in null and it will get unbound to the actual attribute at this position.
|
||||
val input = BoundReference(i, NullType, nullable = true)
|
||||
copy(fromRowExpression = fromRowExpression transformUp {
|
||||
case u: Attribute =>
|
||||
UnresolvedExtractValue(input, Literal(u.name))
|
||||
case b: BoundReference =>
|
||||
GetStructField(
|
||||
input,
|
||||
StructField(s"i[${b.ordinal}]", b.dataType),
|
||||
b.ordinal)
|
||||
})
|
||||
}
|
||||
|
||||
protected val attrs = toRowExpressions.flatMap(_.collect {
|
||||
case _: UnresolvedAttribute => ""
|
||||
case a: Attribute => s"#${a.exprId}"
|
||||
|
|
|
@ -117,6 +117,35 @@ class ProductEncoderSuite extends ExpressionEncoderSuite {
|
|||
productTest(("Seq[Seq[(Int, Int)]]",
|
||||
Seq(Seq((1, 2)))))
|
||||
|
||||
encodeDecodeTest(
|
||||
1 -> 10L,
|
||||
ExpressionEncoder.tuple(FlatEncoder[Int], FlatEncoder[Long]),
|
||||
"tuple with 2 flat encoders")
|
||||
|
||||
encodeDecodeTest(
|
||||
(PrimitiveData(1, 1, 1, 1, 1, 1, true), (3, 30L)),
|
||||
ExpressionEncoder.tuple(ProductEncoder[PrimitiveData], ProductEncoder[(Int, Long)]),
|
||||
"tuple with 2 product encoders")
|
||||
|
||||
encodeDecodeTest(
|
||||
(PrimitiveData(1, 1, 1, 1, 1, 1, true), 3),
|
||||
ExpressionEncoder.tuple(ProductEncoder[PrimitiveData], FlatEncoder[Int]),
|
||||
"tuple with flat encoder and product encoder")
|
||||
|
||||
encodeDecodeTest(
|
||||
(3, PrimitiveData(1, 1, 1, 1, 1, 1, true)),
|
||||
ExpressionEncoder.tuple(FlatEncoder[Int], ProductEncoder[PrimitiveData]),
|
||||
"tuple with product encoder and flat encoder")
|
||||
|
||||
encodeDecodeTest(
|
||||
(1, (10, 100L)),
|
||||
{
|
||||
val intEnc = FlatEncoder[Int]
|
||||
val longEnc = FlatEncoder[Long]
|
||||
ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc))
|
||||
},
|
||||
"nested tuple encoder")
|
||||
|
||||
private def productTest[T <: Product : TypeTag](input: T): Unit = {
|
||||
encodeDecodeTest(input, ProductEncoder[T], input.getClass.getSimpleName)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue