diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 6591559426..0e2fd43983 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1672,9 +1672,9 @@ object CleanupAliases extends Rule[LogicalPlan] { // Operators that operate on objects should only have expressions from encoders, which should // never have extra aliases. - case o: ObjectOperator => o - case d: DeserializeToObject => d - case s: SerializeFromObject => s + case o: ObjectConsumer => o + case o: ObjectProducer => o + case a: AppendColumns => a case other => var stop = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 958966328b..085e95f542 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -245,6 +245,10 @@ package object dsl { def struct(attrs: AttributeReference*): AttributeReference = struct(StructType.fromAttributes(attrs)) + /** Creates a new AttributeReference of object type */ + def obj(cls: Class[_]): AttributeReference = + AttributeReference(s, ObjectType(cls), nullable = true)() + /** Create a function. */ def function(exprs: Expression*): UnresolvedFunction = UnresolvedFunction(s, exprs, isDistinct = false) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index b806b725a8..0a5232b2d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -153,29 +153,16 @@ object SamplePushDown extends Rule[LogicalPlan] { * representation of data item. For example back to back map operations. */ object EliminateSerialization extends Rule[LogicalPlan] { - // TODO: find a more general way to do this optimization. def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case m @ MapPartitions(_, deserializer, _, child: ObjectOperator) - if !deserializer.isInstanceOf[Attribute] && - deserializer.dataType == child.outputObject.dataType => - val childWithoutSerialization = child.withObjectOutput - m.copy( - deserializer = childWithoutSerialization.output.head, - child = childWithoutSerialization) - - case m @ MapElements(_, deserializer, _, child: ObjectOperator) - if !deserializer.isInstanceOf[Attribute] && - deserializer.dataType == child.outputObject.dataType => - val childWithoutSerialization = child.withObjectOutput - m.copy( - deserializer = childWithoutSerialization.output.head, - child = childWithoutSerialization) - - case d @ DeserializeToObject(_, s: SerializeFromObject) + case d @ DeserializeToObject(_, _, s: SerializeFromObject) if d.outputObjectType == s.inputObjectType => // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`. 
val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId) Project(objAttr :: Nil, s.child) + + case a @ AppendColumns(_, _, _, s: SerializeFromObject) + if a.deserializer.dataType == s.inputObjectType => + AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child) } } @@ -366,9 +353,9 @@ object ColumnPruning extends Rule[LogicalPlan] { } a.copy(child = Expand(newProjects, newOutput, grandChild)) - // Prunes the unused columns from child of MapPartitions - case mp @ MapPartitions(_, _, _, child) if (child.outputSet -- mp.references).nonEmpty => - mp.copy(child = prunedChild(child, mp.references)) + // Prunes the unused columns from child of `DeserializeToObject` + case d @ DeserializeToObject(_, _, child) if (child.outputSet -- d.references).nonEmpty => + d.copy(child = prunedChild(child, d.references)) // Prunes the unused columns from child of Aggregate/Expand/Generate case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => @@ -1453,7 +1440,7 @@ object EmbedSerializerInFilter extends Rule[LogicalPlan] { s } else { val newCondition = condition transform { - case a: Attribute if a == d.output.head => d.deserializer.child + case a: Attribute if a == d.output.head => d.deserializer } Filter(newCondition, d.child) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 6df46189b6..4a1bdb0b8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -21,126 +21,111 @@ import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{DataType, ObjectType, StructType} +import org.apache.spark.sql.types.{DataType, StructType} object CatalystSerde { def deserialize[T : Encoder](child: LogicalPlan): DeserializeToObject = { val deserializer = UnresolvedDeserializer(encoderFor[T].deserializer) - DeserializeToObject(Alias(deserializer, "obj")(), child) + DeserializeToObject(deserializer, generateObjAttr[T], child) } def serialize[T : Encoder](child: LogicalPlan): SerializeFromObject = { SerializeFromObject(encoderFor[T].namedExpressions, child) } + + def generateObjAttr[T : Encoder]: Attribute = { + AttributeReference("obj", encoderFor[T].deserializer.dataType, nullable = false)() + } } /** - * Takes the input row from child and turns it into object using the given deserializer expression. - * The output of this operator is a single-field safe row containing the deserialized object. + * A trait for logical operators that produce domain objects as output. + * The output of this operator is a single-field safe row containing the produced object. */ -case class DeserializeToObject( - deserializer: Alias, - child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = deserializer.toAttribute :: Nil +trait ObjectProducer extends LogicalPlan { + // The attribute that references the single object field this operator outputs.
+ protected def outputObjAttr: Attribute - def outputObjectType: DataType = deserializer.dataType + override def output: Seq[Attribute] = outputObjAttr :: Nil + + override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr) + + def outputObjectType: DataType = outputObjAttr.dataType } /** - * Takes the input object from child and turns in into unsafe row using the given serializer - * expression. The output of its child must be a single-field row containing the input object. + * A trait for logical operators that consume domain objects as input. + * The output of its child must be a single-field row containing the input object. */ -case class SerializeFromObject( - serializer: Seq[NamedExpression], - child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = serializer.map(_.toAttribute) +trait ObjectConsumer extends UnaryNode { + assert(child.output.length == 1) + + // This operator always needs all columns of its child, even if it doesn't reference them. + override def references: AttributeSet = child.outputSet def inputObjectType: DataType = child.output.head.dataType } /** - * A trait for logical operators that apply user defined functions to domain objects. + * Takes the input row from child and turns it into object using the given deserializer expression. */ -trait ObjectOperator extends LogicalPlan { +case class DeserializeToObject( + deserializer: Expression, + outputObjAttr: Attribute, + child: LogicalPlan) extends UnaryNode with ObjectProducer - /** The serializer that is used to produce the output of this operator. */ - def serializer: Seq[NamedExpression] +/** + * Takes the input object from child and turns it into an unsafe row using the given serializer + * expression. + */ +case class SerializeFromObject( + serializer: Seq[NamedExpression], + child: LogicalPlan) extends UnaryNode with ObjectConsumer { override def output: Seq[Attribute] = serializer.map(_.toAttribute) - - /** - * The object type that is produced by the user defined function. Note that the return type here - * is the same whether or not the operator is output serialized data. - */ - def outputObject: NamedExpression = - Alias(serializer.head.collect { case b: BoundReference => b }.head, "obj")() - - /** - * Returns a copy of this operator that will produce an object instead of an encoded row. - * Used in the optimizer when transforming plans to remove unneeded serialization. - */ - def withObjectOutput: LogicalPlan = if (output.head.dataType.isInstanceOf[ObjectType]) { - this - } else { - withNewSerializer(outputObject :: Nil) - } - - /** Returns a copy of this operator with a different serializer. */ - def withNewSerializer(newSerializer: Seq[NamedExpression]): LogicalPlan = makeCopy { - productIterator.map { - case c if c == serializer => newSerializer - case other: AnyRef => other - }.toArray - } } object MapPartitions { def apply[T : Encoder, U : Encoder]( func: Iterator[T] => Iterator[U], - child: LogicalPlan): MapPartitions = { - MapPartitions( + child: LogicalPlan): LogicalPlan = { + val deserialized = CatalystSerde.deserialize[T](child) + val mapped = MapPartitions( func.asInstanceOf[Iterator[Any] => Iterator[Any]], - UnresolvedDeserializer(encoderFor[T].deserializer), - encoderFor[U].namedExpressions, - child) + CatalystSerde.generateObjAttr[U], + deserialized) + CatalystSerde.serialize[U](mapped) } } /** * A relation produced by applying `func` to each partition of the `child`. - * - * @param deserializer used to extract the input to `func` from an input row.
- * @param serializer use to serialize the output of `func`. */ case class MapPartitions( func: Iterator[Any] => Iterator[Any], - deserializer: Expression, - serializer: Seq[NamedExpression], - child: LogicalPlan) extends UnaryNode with ObjectOperator + outputObjAttr: Attribute, + child: LogicalPlan) extends UnaryNode with ObjectConsumer with ObjectProducer object MapElements { def apply[T : Encoder, U : Encoder]( func: AnyRef, - child: LogicalPlan): MapElements = { - MapElements( + child: LogicalPlan): LogicalPlan = { + val deserialized = CatalystSerde.deserialize[T](child) + val mapped = MapElements( func, - UnresolvedDeserializer(encoderFor[T].deserializer), - encoderFor[U].namedExpressions, - child) + CatalystSerde.generateObjAttr[U], + deserialized) + CatalystSerde.serialize[U](mapped) } } /** * A relation produced by applying `func` to each element of the `child`. - * - * @param deserializer used to extract the input to `func` from an input row. - * @param serializer use to serialize the output of `func`. */ case class MapElements( func: AnyRef, - deserializer: Expression, - serializer: Seq[NamedExpression], - child: LogicalPlan) extends UnaryNode with ObjectOperator + outputObjAttr: Attribute, + child: LogicalPlan) extends UnaryNode with ObjectConsumer with ObjectProducer /** Factory for constructing new `AppendColumn` nodes. */ object AppendColumns { @@ -156,7 +141,7 @@ object AppendColumns { } /** - * A relation produced by applying `func` to each partition of the `child`, concatenating the + * A relation produced by applying `func` to each element of the `child`, concatenating the * resulting columns at the end of the input row. * * @param deserializer used to extract the input to `func` from an input row. @@ -166,28 +151,41 @@ case class AppendColumns( func: Any => Any, deserializer: Expression, serializer: Seq[NamedExpression], - child: LogicalPlan) extends UnaryNode with ObjectOperator { + child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output ++ newColumns def newColumns: Seq[Attribute] = serializer.map(_.toAttribute) } +/** + * An optimized version of [[AppendColumns]] that can be executed directly on deserialized objects. + */ +case class AppendColumnsWithObject( + func: Any => Any, + childSerializer: Seq[NamedExpression], + newColumnsSerializer: Seq[NamedExpression], + child: LogicalPlan) extends UnaryNode with ObjectConsumer { + + override def output: Seq[Attribute] = (childSerializer ++ newColumnsSerializer).map(_.toAttribute) +} + /** Factory for constructing new `MapGroups` nodes. */ object MapGroups { def apply[K : Encoder, T : Encoder, U : Encoder]( func: (K, Iterator[T]) => TraversableOnce[U], groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], - child: LogicalPlan): MapGroups = { - new MapGroups( + child: LogicalPlan): LogicalPlan = { + val mapped = new MapGroups( func.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]], UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes), UnresolvedDeserializer(encoderFor[T].deserializer, dataAttributes), - encoderFor[U].namedExpressions, groupingAttributes, dataAttributes, + CatalystSerde.generateObjAttr[U], child) + CatalystSerde.serialize[U](mapped) } } @@ -198,43 +196,43 @@ object MapGroups { * * @param keyDeserializer used to extract the key object for each group. * @param valueDeserializer used to extract the items in the iterator from an input row. - * @param serializer use to serialize the output of `func`.
*/ case class MapGroups( func: (Any, Iterator[Any]) => TraversableOnce[Any], keyDeserializer: Expression, valueDeserializer: Expression, - serializer: Seq[NamedExpression], groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], - child: LogicalPlan) extends UnaryNode with ObjectOperator + outputObjAttr: Attribute, + child: LogicalPlan) extends UnaryNode with ObjectProducer /** Factory for constructing new `CoGroup` nodes. */ object CoGroup { - def apply[Key : Encoder, Left : Encoder, Right : Encoder, Result : Encoder]( - func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result], + def apply[K : Encoder, L : Encoder, R : Encoder, OUT : Encoder]( + func: (K, Iterator[L], Iterator[R]) => TraversableOnce[OUT], leftGroup: Seq[Attribute], rightGroup: Seq[Attribute], leftAttr: Seq[Attribute], rightAttr: Seq[Attribute], left: LogicalPlan, - right: LogicalPlan): CoGroup = { + right: LogicalPlan): LogicalPlan = { require(StructType.fromAttributes(leftGroup) == StructType.fromAttributes(rightGroup)) - CoGroup( + val cogrouped = CoGroup( func.asInstanceOf[(Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any]], // The `leftGroup` and `rightGroup` are guaranteed te be of same schema, so it's safe to // resolve the `keyDeserializer` based on either of them, here we pick the left one. - UnresolvedDeserializer(encoderFor[Key].deserializer, leftGroup), - UnresolvedDeserializer(encoderFor[Left].deserializer, leftAttr), - UnresolvedDeserializer(encoderFor[Right].deserializer, rightAttr), - encoderFor[Result].namedExpressions, + UnresolvedDeserializer(encoderFor[K].deserializer, leftGroup), + UnresolvedDeserializer(encoderFor[L].deserializer, leftAttr), + UnresolvedDeserializer(encoderFor[R].deserializer, rightAttr), leftGroup, rightGroup, leftAttr, rightAttr, + CatalystSerde.generateObjAttr[OUT], left, right) + CatalystSerde.serialize[OUT](cogrouped) } } @@ -247,10 +245,10 @@ case class CoGroup( keyDeserializer: Expression, leftDeserializer: Expression, rightDeserializer: Expression, - serializer: Seq[NamedExpression], leftGroup: Seq[Attribute], rightGroup: Seq[Attribute], leftAttr: Seq[Attribute], rightAttr: Seq[Attribute], + outputObjAttr: Attribute, left: LogicalPlan, - right: LogicalPlan) extends BinaryNode with ObjectOperator + right: LogicalPlan) extends BinaryNode with ObjectProducer diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala index 9177737560..3c033ddc37 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala @@ -22,8 +22,7 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.NewInstance -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, MapPartitions} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -37,40 +36,45 @@ class EliminateSerializationSuite extends PlanTest { } implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]() - private val func = 
identity[Iterator[(Int, Int)]] _ - private val func2 = identity[Iterator[OtherTuple]] _ + implicit private def intEncoder = ExpressionEncoder[Int]() - def assertObjectCreations(count: Int, plan: LogicalPlan): Unit = { - val newInstances = plan.flatMap(_.expressions.collect { - case n: NewInstance => n - }) - - if (newInstances.size != count) { - fail( - s""" - |Wrong number of object creations in plan: ${newInstances.size} != $count - |$plan - """.stripMargin) - } + test("back to back serialization") { + val input = LocalRelation('obj.obj(classOf[(Int, Int)])) + val plan = input.serialize[(Int, Int)].deserialize[(Int, Int)].analyze + val optimized = Optimize.execute(plan) + val expected = input.select('obj.as("obj")).analyze + comparePlans(optimized, expected) } - test("back to back MapPartitions") { - val input = LocalRelation('_1.int, '_2.int) - val plan = - MapPartitions(func, - MapPartitions(func, input)) - - val optimized = Optimize.execute(plan.analyze) - assertObjectCreations(1, optimized) + test("back to back serialization with object change") { + val input = LocalRelation('obj.obj(classOf[OtherTuple])) + val plan = input.serialize[OtherTuple].deserialize[(Int, Int)].analyze + val optimized = Optimize.execute(plan) + comparePlans(optimized, plan) } - test("back to back with object change") { - val input = LocalRelation('_1.int, '_2.int) - val plan = - MapPartitions(func, - MapPartitions(func2, input)) + test("back to back serialization in AppendColumns") { + val input = LocalRelation('obj.obj(classOf[(Int, Int)])) + val func = (item: (Int, Int)) => item._1 + val plan = AppendColumns(func, input.serialize[(Int, Int)]).analyze - val optimized = Optimize.execute(plan.analyze) - assertObjectCreations(2, optimized) + val optimized = Optimize.execute(plan) + + val expected = AppendColumnsWithObject( + func.asInstanceOf[Any => Any], + productEncoder[(Int, Int)].namedExpressions, + intEncoder.namedExpressions, + input).analyze + + comparePlans(optimized, expected) + } + + test("back to back serialization in AppendColumns with object change") { + val input = LocalRelation('obj.obj(classOf[OtherTuple])) + val func = (item: (Int, Int)) => item._1 + val plan = AppendColumns(func, input.serialize[OtherTuple]).analyze + + val optimized = Optimize.execute(plan) + comparePlans(optimized, plan) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 1a09d70fb9..3c708cbf29 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2251,16 +2251,16 @@ class Dataset[T] private[sql]( def unpersist(): this.type = unpersist(blocking = false) /** - * Represents the content of the [[Dataset]] as an [[RDD]] of [[Row]]s. Note that the RDD is - * memoized. Once called, it won't change even if you change any query planning related Spark SQL - * configurations (e.g. `spark.sql.shuffle.partitions`). + * Represents the content of the [[Dataset]] as an [[RDD]] of [[T]]. 
* * @group rdd * @since 1.6.0 */ lazy val rdd: RDD[T] = { - queryExecution.toRdd.mapPartitions { rows => - rows.map(boundTEncoder.fromRow) + val objectType = unresolvedTEncoder.deserializer.dataType + val deserialized = CatalystSerde.deserialize[T](logicalPlan) + sqlContext.executePlan(deserialized).toRdd.mapPartitions { rows => + rows.map(_.get(0, objectType).asInstanceOf[T]) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index c15aaed365..a4b0fa59db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -346,21 +346,23 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { throw new IllegalStateException( "logical intersect operator should have been replaced by semi-join in the optimizer") - case logical.DeserializeToObject(deserializer, child) => - execution.DeserializeToObject(deserializer, planLater(child)) :: Nil + case logical.DeserializeToObject(deserializer, objAttr, child) => + execution.DeserializeToObject(deserializer, objAttr, planLater(child)) :: Nil case logical.SerializeFromObject(serializer, child) => execution.SerializeFromObject(serializer, planLater(child)) :: Nil - case logical.MapPartitions(f, in, out, child) => - execution.MapPartitions(f, in, out, planLater(child)) :: Nil - case logical.MapElements(f, in, out, child) => - execution.MapElements(f, in, out, planLater(child)) :: Nil + case logical.MapPartitions(f, objAttr, child) => + execution.MapPartitions(f, objAttr, planLater(child)) :: Nil + case logical.MapElements(f, objAttr, child) => + execution.MapElements(f, objAttr, planLater(child)) :: Nil case logical.AppendColumns(f, in, out, child) => execution.AppendColumns(f, in, out, planLater(child)) :: Nil - case logical.MapGroups(f, key, in, out, grouping, data, child) => - execution.MapGroups(f, key, in, out, grouping, data, planLater(child)) :: Nil - case logical.CoGroup(f, keyObj, lObj, rObj, out, lGroup, rGroup, lAttr, rAttr, left, right) => + case logical.AppendColumnsWithObject(f, childSer, newSer, child) => + execution.AppendColumnsWithObject(f, childSer, newSer, planLater(child)) :: Nil + case logical.MapGroups(f, key, value, grouping, data, objAttr, child) => + execution.MapGroups(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil + case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) => execution.CoGroup( - f, keyObj, lObj, rObj, out, lGroup, rGroup, lAttr, rAttr, + f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, planLater(left), planLater(right)) :: Nil case logical.Repartition(numPartitions, shuffle, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 46eaede5e7..23b2eabd0c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -473,6 +473,10 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] { * Inserts a WholeStageCodegen on top of those that support codegen. 
*/ private def insertWholeStageCodegen(plan: SparkPlan): SparkPlan = plan match { + // For operators that will output domain objects, do not insert WholeStageCodegen for them, as + // domain objects cannot be written into unsafe rows. + case plan if plan.output.length == 1 && plan.output.head.dataType.isInstanceOf[ObjectType] => + plan.withNewChildren(plan.children.map(insertWholeStageCodegen)) case plan: CodegenSupport if supportCodegen(plan) => WholeStageCodegen(insertInputAdapter(plan)) case other => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index e7261fc512..7c8bc7fed8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -25,16 +25,19 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.types.ObjectType +import org.apache.spark.sql.types.{DataType, ObjectType} /** * Takes the input row from child and turns it into object using the given deserializer expression. * The output of this operator is a single-field safe row containing the deserialized object. */ case class DeserializeToObject( - deserializer: Alias, + deserializer: Expression, + outputObjAttr: Attribute, child: SparkPlan) extends UnaryNode with CodegenSupport { - override def output: Seq[Attribute] = deserializer.toAttribute :: Nil + + override def output: Seq[Attribute] = outputObjAttr :: Nil + override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr) override def inputRDDs(): Seq[RDD[InternalRow]] = { child.asInstanceOf[CodegenSupport].inputRDDs() @@ -67,6 +70,7 @@ case class DeserializeToObject( case class SerializeFromObject( serializer: Seq[NamedExpression], child: SparkPlan) extends UnaryNode with CodegenSupport { + override def output: Seq[Attribute] = serializer.map(_.toAttribute) override def inputRDDs(): Seq[RDD[InternalRow]] = { @@ -98,60 +102,71 @@ case class SerializeFromObject( * Helper functions for physical operators that work with user defined objects.
*/ trait ObjectOperator extends SparkPlan { - def generateToObject(objExpr: Expression, inputSchema: Seq[Attribute]): InternalRow => Any = { - val objectProjection = GenerateSafeProjection.generate(objExpr :: Nil, inputSchema) - (i: InternalRow) => objectProjection(i).get(0, objExpr.dataType) + def deserializeRowToObject( + deserializer: Expression, + inputSchema: Seq[Attribute]): InternalRow => Any = { + val proj = GenerateSafeProjection.generate(deserializer :: Nil, inputSchema) + (i: InternalRow) => proj(i).get(0, deserializer.dataType) } - def generateToRow(serializer: Seq[Expression]): Any => InternalRow = { - val outputProjection = if (serializer.head.dataType.isInstanceOf[ObjectType]) { - GenerateSafeProjection.generate(serializer) - } else { - GenerateUnsafeProjection.generate(serializer) + def serializeObjectToRow(serializer: Seq[Expression]): Any => UnsafeRow = { + val proj = GenerateUnsafeProjection.generate(serializer) + val objType = serializer.head.collect { case b: BoundReference => b.dataType }.head + val objRow = new SpecificMutableRow(objType :: Nil) + (o: Any) => { + objRow(0) = o + proj(objRow) } - val inputType = serializer.head.collect { case b: BoundReference => b.dataType }.head - val outputRow = new SpecificMutableRow(inputType :: Nil) + } + + def wrapObjectToRow(objType: DataType): Any => InternalRow = { + val outputRow = new SpecificMutableRow(objType :: Nil) (o: Any) => { outputRow(0) = o - outputProjection(outputRow) + outputRow } } + + def unwrapObjectFromRow(objType: DataType): InternalRow => Any = { + (i: InternalRow) => i.get(0, objType) + } } /** - * Applies the given function to each input row and encodes the result. + * Applies the given function to the input object iterator. + * The output of its child must be a single-field row containing the input object. */ case class MapPartitions( func: Iterator[Any] => Iterator[Any], - deserializer: Expression, - serializer: Seq[NamedExpression], + outputObjAttr: Attribute, child: SparkPlan) extends UnaryNode with ObjectOperator { - override def output: Seq[Attribute] = serializer.map(_.toAttribute) + + override def output: Seq[Attribute] = outputObjAttr :: Nil + override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr) override protected def doExecute(): RDD[InternalRow] = { child.execute().mapPartitionsInternal { iter => - val getObject = generateToObject(deserializer, child.output) - val outputObject = generateToRow(serializer) + val getObject = unwrapObjectFromRow(child.output.head.dataType) + val outputObject = wrapObjectToRow(outputObjAttr.dataType) func(iter.map(getObject)).map(outputObject) } } } /** - * Applies the given function to each input row and encodes the result. + * Applies the given function to each input object. + * The output of its child must be a single-field row containing the input object. * - * Note that, each serializer expression needs the result object which is returned by the given - * function, as input. This operator uses some tricks to make sure we only calculate the result - * object once. We don't use [[Project]] directly as subexpression elimination doesn't work with - * whole stage codegen and it's confusing to show the un-common-subexpression-eliminated version of - * a project while explain. + * This operator is kind of a safe version of [[Project]]: as its output is a custom object, we + * need to use a safe row to contain it.
*/ case class MapElements( func: AnyRef, - deserializer: Expression, - serializer: Seq[NamedExpression], + outputObjAttr: Attribute, child: SparkPlan) extends UnaryNode with ObjectOperator with CodegenSupport { - override def output: Seq[Attribute] = serializer.map(_.toAttribute) + + override def output: Seq[Attribute] = outputObjAttr :: Nil + override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr) override def inputRDDs(): Seq[RDD[InternalRow]] = { child.asInstanceOf[CodegenSupport].inputRDDs() @@ -167,23 +182,14 @@ case class MapElements( case _ => classOf[Any => Any] -> "apply" } val funcObj = Literal.create(func, ObjectType(funcClass)) - val resultObjType = serializer.head.collect { case b: BoundReference => b }.head.dataType - val callFunc = Invoke(funcObj, methodName, resultObjType, Seq(deserializer)) + val callFunc = Invoke(funcObj, methodName, outputObjAttr.dataType, child.output) val bound = ExpressionCanonicalizer.execute( BindReferences.bindReference(callFunc, child.output)) ctx.currentVars = input - val evaluated = bound.genCode(ctx) + val resultVars = bound.genCode(ctx) :: Nil - val resultObj = LambdaVariable(evaluated.value, evaluated.isNull, resultObjType) - val outputFields = serializer.map(_ transform { - case _: BoundReference => resultObj - }) - val resultVars = outputFields.map(_.genCode(ctx)) - s""" - ${evaluated.code} - ${consume(ctx, resultVars)} - """ + consume(ctx, resultVars) } override protected def doExecute(): RDD[InternalRow] = { @@ -191,9 +197,10 @@ case class MapElements( case m: MapFunction[_, _] => i => m.asInstanceOf[MapFunction[Any, Any]].call(i) case _ => func.asInstanceOf[Any => Any] } + child.execute().mapPartitionsInternal { iter => - val getObject = generateToObject(deserializer, child.output) - val outputObject = generateToRow(serializer) + val getObject = unwrapObjectFromRow(child.output.head.dataType) + val outputObject = wrapObjectToRow(outputObjAttr.dataType) iter.map(row => outputObject(callFunc(getObject(row)))) } } @@ -216,15 +223,43 @@ case class AppendColumns( override protected def doExecute(): RDD[InternalRow] = { child.execute().mapPartitionsInternal { iter => - val getObject = generateToObject(deserializer, child.output) + val getObject = deserializeRowToObject(deserializer, child.output) val combiner = GenerateUnsafeRowJoiner.create(child.schema, newColumnSchema) - val outputObject = generateToRow(serializer) + val outputObject = serializeObjectToRow(serializer) iter.map { row => val newColumns = outputObject(func(getObject(row))) + combiner.join(row.asInstanceOf[UnsafeRow], newColumns): InternalRow + } + } + } +} - // This operates on the assumption that we always serialize the result... - combiner.join(row.asInstanceOf[UnsafeRow], newColumns.asInstanceOf[UnsafeRow]): InternalRow +/** + * An optimized version of [[AppendColumns]] that can be executed directly on deserialized objects.
+ */ +case class AppendColumnsWithObject( + func: Any => Any, + inputSerializer: Seq[NamedExpression], + newColumnsSerializer: Seq[NamedExpression], + child: SparkPlan) extends UnaryNode with ObjectOperator { + + override def output: Seq[Attribute] = (inputSerializer ++ newColumnsSerializer).map(_.toAttribute) + + private def inputSchema = inputSerializer.map(_.toAttribute).toStructType + private def newColumnSchema = newColumnsSerializer.map(_.toAttribute).toStructType + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitionsInternal { iter => + val getChildObject = unwrapObjectFromRow(child.output.head.dataType) + val outputChildObject = serializeObjectToRow(inputSerializer) + val outputNewColumnObj = serializeObjectToRow(newColumnsSerializer) + val combiner = GenerateUnsafeRowJoiner.create(inputSchema, newColumnSchema) + + iter.map { row => + val childObj = getChildObject(row) + val newColumns = outputNewColumnObj(func(childObj)) + combiner.join(outputChildObject(childObj), newColumns): InternalRow } } } @@ -232,19 +267,19 @@ case class AppendColumns( /** * Groups the input rows together and calls the function with each group and an iterator containing - * all elements in the group. The result of this function is encoded and flattened before - * being output. + * all elements in the group. The result of this function is flattened before being output. */ case class MapGroups( func: (Any, Iterator[Any]) => TraversableOnce[Any], keyDeserializer: Expression, valueDeserializer: Expression, - serializer: Seq[NamedExpression], groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], + outputObjAttr: Attribute, child: SparkPlan) extends UnaryNode with ObjectOperator { - override def output: Seq[Attribute] = serializer.map(_.toAttribute) + override def output: Seq[Attribute] = outputObjAttr :: Nil + override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr) override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(groupingAttributes) :: Nil @@ -256,9 +291,9 @@ case class MapGroups( child.execute().mapPartitionsInternal { iter => val grouped = GroupedIterator(iter, groupingAttributes, child.output) - val getKey = generateToObject(keyDeserializer, groupingAttributes) - val getValue = generateToObject(valueDeserializer, dataAttributes) - val outputObject = generateToRow(serializer) + val getKey = deserializeRowToObject(keyDeserializer, groupingAttributes) + val getValue = deserializeRowToObject(valueDeserializer, dataAttributes) + val outputObject = wrapObjectToRow(outputObjAttr.dataType) grouped.flatMap { case (key, rowIter) => val result = func( @@ -273,22 +308,23 @@ case class MapGroups( /** * Co-groups the data from left and right children, and calls the function with each group and 2 * iterators containing all elements in the group from left and right side. - * The result of this function is encoded and flattened before being output. + * The result of this function is flattened before being output.
*/ case class CoGroup( func: (Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any], keyDeserializer: Expression, leftDeserializer: Expression, rightDeserializer: Expression, - serializer: Seq[NamedExpression], leftGroup: Seq[Attribute], rightGroup: Seq[Attribute], leftAttr: Seq[Attribute], rightAttr: Seq[Attribute], + outputObjAttr: Attribute, left: SparkPlan, right: SparkPlan) extends BinaryNode with ObjectOperator { - override def output: Seq[Attribute] = serializer.map(_.toAttribute) + override def output: Seq[Attribute] = outputObjAttr :: Nil + override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr) override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil @@ -301,10 +337,10 @@ case class CoGroup( val leftGrouped = GroupedIterator(leftData, leftGroup, left.output) val rightGrouped = GroupedIterator(rightData, rightGroup, right.output) - val getKey = generateToObject(keyDeserializer, leftGroup) - val getLeft = generateToObject(leftDeserializer, leftAttr) - val getRight = generateToObject(rightDeserializer, rightAttr) - val outputObject = generateToRow(serializer) + val getKey = deserializeRowToObject(keyDeserializer, leftGroup) + val getLeft = deserializeRowToObject(leftDeserializer, leftAttr) + val getRight = deserializeRowToObject(rightDeserializer, rightAttr) + val outputObject = wrapObjectToRow(outputObjAttr.dataType) new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap { case (key, leftResult, rightResult) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 23a0ce215f..2dca792c83 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -201,7 +201,9 @@ abstract class QueryTest extends PlanTest { val logicalPlan = df.queryExecution.analyzed // bypass some cases that we can't handle currently. logicalPlan.transform { - case _: ObjectOperator => return + case _: ObjectConsumer => return + case _: ObjectProducer => return + case _: AppendColumns => return case _: LogicalRelation => return case _: MemoryPlan => return }.transformAllExpressions { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 8efd9de29e..d7cf1dc6aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -79,7 +79,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { val plan = ds.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegen] && - p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[MapElements]).isDefined) + p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[SerializeFromObject]).isDefined) assert(ds.collect() === 0.until(10).map(_.toString).toArray) }
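For context, a minimal sketch of the plan shape the updated EliminateSerialization rule now collapses. It assumes the Catalyst test DSL used by EliminateSerializationSuite above, including the new `obj` expression helper and the `serialize`/`deserialize` plan helpers the new tests rely on; the encoder definition mirrors the one declared in that suite.

import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

// Encoder for product types, as declared in EliminateSerializationSuite.
implicit def productEncoder[T <: Product : TypeTag]: ExpressionEncoder[T] = ExpressionEncoder[T]()

// A relation with a single object-typed column, built with the `obj` DSL helper added in this patch.
val input = LocalRelation('obj.obj(classOf[(Int, Int)]))

// serialize[T] immediately followed by deserialize[T] of the same type is a useless
// round trip through the row format ...
val plan = input.serialize[(Int, Int)].deserialize[(Int, Int)].analyze

// ... so the rule reduces it to a Project that only preserves the expression id of the
// DeserializeToObject output, which is what the "back to back serialization" test asserts.
val expected = input.select('obj.as("obj")).analyze

At the Dataset level the same rule means that back-to-back typed operations, e.g. ds.map(f).map(g), no longer encode the intermediate objects produced by f back into rows before handing them to g.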