[SPARK-11754][SQL] consolidate ExpressionEncoder.tuple and Encoders.tuple

`ExpressionEncoder.tuple` and `Encoders.tuple` are very similar, so we can consolidate them into one.

Also add tests for `ExpressionEncoder.tuple` and fix a bug.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9729 from cloud-fan/tuple.
This commit is contained in:
Wenchen Fan 2015-11-16 12:45:34 -08:00 committed by Michael Armbrust
parent 24477d2705
commit b1a9662623
3 changed files with 110 additions and 122 deletions

View file

@ -19,10 +19,8 @@ package org.apache.spark.sql
import scala.reflect.ClassTag import scala.reflect.ClassTag
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor}
import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{ObjectType, StructField, StructType}
import org.apache.spark.util.Utils
/** /**
* Used to convert a JVM object of type `T` to and from the internal Spark SQL representation. * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation.
@ -49,83 +47,34 @@ object Encoders {
def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true) def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true)
def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true) def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true)
def tuple[T1, T2](enc1: Encoder[T1], enc2: Encoder[T2]): Encoder[(T1, T2)] = { def tuple[T1, T2](
tuple(Seq(enc1, enc2).map(_.asInstanceOf[ExpressionEncoder[_]])) e1: Encoder[T1],
.asInstanceOf[ExpressionEncoder[(T1, T2)]] e2: Encoder[T2]): Encoder[(T1, T2)] = {
ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2))
} }
def tuple[T1, T2, T3]( def tuple[T1, T2, T3](
enc1: Encoder[T1], e1: Encoder[T1],
enc2: Encoder[T2], e2: Encoder[T2],
enc3: Encoder[T3]): Encoder[(T1, T2, T3)] = { e3: Encoder[T3]): Encoder[(T1, T2, T3)] = {
tuple(Seq(enc1, enc2, enc3).map(_.asInstanceOf[ExpressionEncoder[_]])) ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3))
.asInstanceOf[ExpressionEncoder[(T1, T2, T3)]]
} }
def tuple[T1, T2, T3, T4]( def tuple[T1, T2, T3, T4](
enc1: Encoder[T1], e1: Encoder[T1],
enc2: Encoder[T2], e2: Encoder[T2],
enc3: Encoder[T3], e3: Encoder[T3],
enc4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = { e4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = {
tuple(Seq(enc1, enc2, enc3, enc4).map(_.asInstanceOf[ExpressionEncoder[_]])) ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4))
.asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4)]]
} }
def tuple[T1, T2, T3, T4, T5]( def tuple[T1, T2, T3, T4, T5](
enc1: Encoder[T1], e1: Encoder[T1],
enc2: Encoder[T2], e2: Encoder[T2],
enc3: Encoder[T3], e3: Encoder[T3],
enc4: Encoder[T4], e4: Encoder[T4],
enc5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = { e5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = {
tuple(Seq(enc1, enc2, enc3, enc4, enc5).map(_.asInstanceOf[ExpressionEncoder[_]])) ExpressionEncoder.tuple(
.asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]] encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4), encoderFor(e5))
}
private def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
assert(encoders.length > 1)
// make sure all encoders are resolved, i.e. `Attribute` has been resolved to `BoundReference`.
assert(encoders.forall(_.fromRowExpression.find(_.isInstanceOf[Attribute]).isEmpty))
val schema = StructType(encoders.zipWithIndex.map {
case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
})
val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
val extractExpressions = encoders.map {
case e if e.flat => e.toRowExpressions.head
case other => CreateStruct(other.toRowExpressions)
}.zipWithIndex.map { case (expr, index) =>
expr.transformUp {
case BoundReference(0, t: ObjectType, _) =>
Invoke(
BoundReference(0, ObjectType(cls), nullable = true),
s"_${index + 1}",
t)
}
}
val constructExpressions = encoders.zipWithIndex.map { case (enc, index) =>
if (enc.flat) {
enc.fromRowExpression.transform {
case b: BoundReference => b.copy(ordinal = index)
}
} else {
enc.fromRowExpression.transformUp {
case BoundReference(ordinal, dt, _) =>
GetInternalRowField(BoundReference(index, enc.schema, nullable = true), ordinal, dt)
}
}
}
val constructExpression =
NewInstance(cls, constructExpressions, propagateNull = false, ObjectType(cls))
new ExpressionEncoder[Any](
schema,
flat = false,
extractExpressions,
constructExpression,
ClassTag(cls))
} }
} }

View file

@ -67,47 +67,77 @@ object ExpressionEncoder {
def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = { def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
encoders.foreach(_.assertUnresolved()) encoders.foreach(_.assertUnresolved())
val schema = val schema = StructType(encoders.zipWithIndex.map {
StructType(
encoders.zipWithIndex.map {
case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema) case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
}) })
val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}") val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
// Rebind the encoders to the nested schema. val toRowExpressions = encoders.map {
val newConstructExpressions = encoders.zipWithIndex.map { case e if e.flat => e.toRowExpressions.head
case (e, i) if !e.flat => e.nested(i).fromRowExpression case other => CreateStruct(other.toRowExpressions)
case (e, i) => e.shift(i).fromRowExpression }.zipWithIndex.map { case (expr, index) =>
expr.transformUp {
case BoundReference(0, t, _) =>
Invoke(
BoundReference(0, ObjectType(cls), nullable = true),
s"_${index + 1}",
t)
}
} }
val constructExpression = val fromRowExpressions = encoders.zipWithIndex.map { case (enc, index) =>
NewInstance(cls, newConstructExpressions, false, ObjectType(cls)) if (enc.flat) {
enc.fromRowExpression.transform {
case b: BoundReference => b.copy(ordinal = index)
}
} else {
val input = BoundReference(index, enc.schema, nullable = true)
enc.fromRowExpression.transformUp {
case UnresolvedAttribute(nameParts) =>
assert(nameParts.length == 1)
UnresolvedExtractValue(input, Literal(nameParts.head))
case BoundReference(ordinal, dt, _) => GetInternalRowField(input, ordinal, dt)
}
}
}
val input = BoundReference(0, ObjectType(cls), false) val fromRowExpression =
val extractExpressions = encoders.zipWithIndex.map { NewInstance(cls, fromRowExpressions, propagateNull = false, ObjectType(cls))
case (e, i) if !e.flat => CreateStruct(e.toRowExpressions.map(_ transformUp {
case b: BoundReference =>
Invoke(input, s"_${i + 1}", b.dataType, Nil)
}))
case (e, i) => e.toRowExpressions.head transformUp {
case b: BoundReference =>
Invoke(input, s"_${i + 1}", b.dataType, Nil)
}
}
new ExpressionEncoder[Any]( new ExpressionEncoder[Any](
schema, schema,
false, flat = false,
extractExpressions, toRowExpressions,
constructExpression, fromRowExpression,
ClassTag.apply(cls)) ClassTag(cls))
} }
/** A helper for producing encoders of Tuple2 from other encoders. */
def tuple[T1, T2]( def tuple[T1, T2](
e1: ExpressionEncoder[T1], e1: ExpressionEncoder[T1],
e2: ExpressionEncoder[T2]): ExpressionEncoder[(T1, T2)] = e2: ExpressionEncoder[T2]): ExpressionEncoder[(T1, T2)] =
tuple(e1 :: e2 :: Nil).asInstanceOf[ExpressionEncoder[(T1, T2)]] tuple(Seq(e1, e2)).asInstanceOf[ExpressionEncoder[(T1, T2)]]
def tuple[T1, T2, T3](
e1: ExpressionEncoder[T1],
e2: ExpressionEncoder[T2],
e3: ExpressionEncoder[T3]): ExpressionEncoder[(T1, T2, T3)] =
tuple(Seq(e1, e2, e3)).asInstanceOf[ExpressionEncoder[(T1, T2, T3)]]
def tuple[T1, T2, T3, T4](
e1: ExpressionEncoder[T1],
e2: ExpressionEncoder[T2],
e3: ExpressionEncoder[T3],
e4: ExpressionEncoder[T4]): ExpressionEncoder[(T1, T2, T3, T4)] =
tuple(Seq(e1, e2, e3, e4)).asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4)]]
def tuple[T1, T2, T3, T4, T5](
e1: ExpressionEncoder[T1],
e2: ExpressionEncoder[T2],
e3: ExpressionEncoder[T3],
e4: ExpressionEncoder[T4],
e5: ExpressionEncoder[T5]): ExpressionEncoder[(T1, T2, T3, T4, T5)] =
tuple(Seq(e1, e2, e3, e4, e5)).asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]]
} }
/** /**
@ -208,26 +238,6 @@ case class ExpressionEncoder[T](
}) })
} }
/**
* Returns a copy of this encoder where the expressions used to create an object given an
* input row have been modified to pull the object out from a nested struct, instead of the
* top level fields.
*/
private def nested(i: Int): ExpressionEncoder[T] = {
// We don't always know our input type at this point since it might be unresolved.
// We fill in null and it will get unbound to the actual attribute at this position.
val input = BoundReference(i, NullType, nullable = true)
copy(fromRowExpression = fromRowExpression transformUp {
case u: Attribute =>
UnresolvedExtractValue(input, Literal(u.name))
case b: BoundReference =>
GetStructField(
input,
StructField(s"i[${b.ordinal}]", b.dataType),
b.ordinal)
})
}
protected val attrs = toRowExpressions.flatMap(_.collect { protected val attrs = toRowExpressions.flatMap(_.collect {
case _: UnresolvedAttribute => "" case _: UnresolvedAttribute => ""
case a: Attribute => s"#${a.exprId}" case a: Attribute => s"#${a.exprId}"

View file

@ -117,6 +117,35 @@ class ProductEncoderSuite extends ExpressionEncoderSuite {
productTest(("Seq[Seq[(Int, Int)]]", productTest(("Seq[Seq[(Int, Int)]]",
Seq(Seq((1, 2))))) Seq(Seq((1, 2)))))
encodeDecodeTest(
1 -> 10L,
ExpressionEncoder.tuple(FlatEncoder[Int], FlatEncoder[Long]),
"tuple with 2 flat encoders")
encodeDecodeTest(
(PrimitiveData(1, 1, 1, 1, 1, 1, true), (3, 30L)),
ExpressionEncoder.tuple(ProductEncoder[PrimitiveData], ProductEncoder[(Int, Long)]),
"tuple with 2 product encoders")
encodeDecodeTest(
(PrimitiveData(1, 1, 1, 1, 1, 1, true), 3),
ExpressionEncoder.tuple(ProductEncoder[PrimitiveData], FlatEncoder[Int]),
"tuple with flat encoder and product encoder")
encodeDecodeTest(
(3, PrimitiveData(1, 1, 1, 1, 1, 1, true)),
ExpressionEncoder.tuple(FlatEncoder[Int], ProductEncoder[PrimitiveData]),
"tuple with product encoder and flat encoder")
encodeDecodeTest(
(1, (10, 100L)),
{
val intEnc = FlatEncoder[Int]
val longEnc = FlatEncoder[Long]
ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc))
},
"nested tuple encoder")
private def productTest[T <: Product : TypeTag](input: T): Unit = { private def productTest[T <: Product : TypeTag](input: T): Unit = {
encodeDecodeTest(input, ProductEncoder[T], input.getClass.getSimpleName) encodeDecodeTest(input, ProductEncoder[T], input.getClass.getSimpleName)
} }