[SPARK-16898][SQL] Adds argument type information for typed logical plans like MapElements, TypedFilter, and AppendColumns

## What changes were proposed in this pull request?

This PR adds argument type information (the argument class and the argument schema) to typed logical plans like MapElements, TypedFilter, and AppendColumns, so that this information can be used in customized optimizer rules.
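
With the argument class and argument schema recorded on these plans, a customized rule can inspect the input type of the user function directly from the logical plan. Below is a minimal, hedged sketch of such a rule; the rule name `LogTypedFilterArguments` and the logging it performs are hypothetical and not part of this patch:

```scala
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, TypedFilter}
import org.apache.spark.sql.catalyst.rules.Rule

// Hypothetical rule: reads the argumentClass/argumentSchema fields added by this
// patch to report what each typed filter operates on; the plan is left unchanged.
object LogTypedFilterArguments extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case f: TypedFilter =>
      logInfo(s"TypedFilter over ${f.argumentClass.getName}, schema: ${f.argumentSchema.simpleString}")
      f
  }
}
```

Such a rule could then be registered through `spark.experimental.extraOptimizations`, which accepts additional `Rule[LogicalPlan]` instances.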

## How was this patch tested?

Existing tests.

Author: Sean Zhong <seanzhong@databricks.com>

Closes #14494 from clockfly/add_more_info_for_typed_operator.
commit bca43cd635 (parent df10658831)
Authored by Sean Zhong on 2016-08-09 08:36:50 +08:00, committed by Wenchen Fan
4 changed files with 31 additions and 11 deletions

@@ -214,7 +214,7 @@ object EliminateSerialization extends Rule[LogicalPlan] {
       val objAttr = Alias(s.inputObjAttr, s.inputObjAttr.name)(exprId = d.outputObjAttr.exprId)
       Project(objAttr :: Nil, s.child)
-    case a @ AppendColumns(_, _, _, s: SerializeFromObject)
+    case a @ AppendColumns(_, _, _, _, _, s: SerializeFromObject)
         if a.deserializer.dataType == s.inputObjAttr.dataType =>
       AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child)
@@ -223,7 +223,7 @@ object EliminateSerialization extends Rule[LogicalPlan] {
     // deserialization in condition, and push it down through `SerializeFromObject`.
     // e.g. `ds.map(...).filter(...)` can be optimized by this rule to save extra deserialization,
     // but `ds.map(...).as[AnotherType].filter(...)` can not be optimized.
-    case f @ TypedFilter(_, _, s: SerializeFromObject)
+    case f @ TypedFilter(_, _, _, _, s: SerializeFromObject)
         if f.deserializer.dataType == s.inputObjAttr.dataType =>
       s.copy(child = f.withObjectProducerChild(s.child))
@@ -1703,9 +1703,14 @@ case class GetCurrentDatabase(sessionCatalog: SessionCatalog) extends Rule[LogicalPlan] {
  */
 object CombineTypedFilters extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case t1 @ TypedFilter(_, _, t2 @ TypedFilter(_, _, child))
+    case t1 @ TypedFilter(_, _, _, _, t2 @ TypedFilter(_, _, _, _, child))
         if t1.deserializer.dataType == t2.deserializer.dataType =>
-      TypedFilter(combineFilterFunction(t2.func, t1.func), t1.deserializer, child)
+      TypedFilter(
+        combineFilterFunction(t2.func, t1.func),
+        t1.argumentClass,
+        t1.argumentSchema,
+        t1.deserializer,
+        child)
   }
   private def combineFilterFunction(func1: AnyRef, func2: AnyRef): Any => Boolean = {
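
For context, CombineTypedFilters merges adjacent typed filters that deserialize to the same type, so a chain of lambda filters costs a single deserialization per row. A hedged, shell-style illustration of the kind of query it targets, using a hypothetical `Person` case class and a local SparkSession:

```scala
import org.apache.spark.sql.SparkSession

case class Person(name: String, age: Int)  // hypothetical example type

val spark = SparkSession.builder().master("local[2]").appName("combine-typed-filters").getOrCreate()
import spark.implicits._

// Both lambdas produce TypedFilter nodes over Person; CombineTypedFilters folds
// them into one, and after this patch the combined node still carries Person's
// class and schema for downstream rules.
val ds = spark.createDataset(Seq(Person("a", 20), Person("b", 35)))
  .filter(_.age > 21)
  .filter(_.name.startsWith("b"))
ds.explain(true)
```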

@@ -155,6 +155,8 @@ object MapElements {
     val deserialized = CatalystSerde.deserialize[T](child)
     val mapped = MapElements(
       func,
+      implicitly[Encoder[T]].clsTag.runtimeClass,
+      implicitly[Encoder[T]].schema,
       CatalystSerde.generateObjAttr[U],
       deserialized)
     CatalystSerde.serialize[U](mapped)
@@ -166,12 +168,19 @@
  */
 case class MapElements(
     func: AnyRef,
+    argumentClass: Class[_],
+    argumentSchema: StructType,
     outputObjAttr: Attribute,
     child: LogicalPlan) extends ObjectConsumer with ObjectProducer
 object TypedFilter {
   def apply[T : Encoder](func: AnyRef, child: LogicalPlan): TypedFilter = {
-    TypedFilter(func, UnresolvedDeserializer(encoderFor[T].deserializer), child)
+    TypedFilter(
+      func,
+      implicitly[Encoder[T]].clsTag.runtimeClass,
+      implicitly[Encoder[T]].schema,
+      UnresolvedDeserializer(encoderFor[T].deserializer),
+      child)
   }
 }
@@ -186,6 +195,8 @@ object TypedFilter {
  */
 case class TypedFilter(
     func: AnyRef,
+    argumentClass: Class[_],
+    argumentSchema: StructType,
     deserializer: Expression,
     child: LogicalPlan) extends UnaryNode {
@@ -213,6 +224,8 @@ object AppendColumns {
       child: LogicalPlan): AppendColumns = {
     new AppendColumns(
       func.asInstanceOf[Any => Any],
+      implicitly[Encoder[T]].clsTag.runtimeClass,
+      implicitly[Encoder[T]].schema,
       UnresolvedDeserializer(encoderFor[T].deserializer),
       encoderFor[U].namedExpressions,
       child)
@@ -228,6 +241,8 @@ object AppendColumns {
  */
 case class AppendColumns(
     func: Any => Any,
+    argumentClass: Class[_],
+    argumentSchema: StructType,
     deserializer: Expression,
     serializer: Seq[NamedExpression],
     child: LogicalPlan) extends UnaryNode {

@@ -356,9 +356,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.FlatMapGroupsInR(f, p, b, is, os, key, value, grouping, data, objAttr, child) =>
         execution.FlatMapGroupsInRExec(f, p, b, is, os, key, value, grouping,
           data, objAttr, planLater(child)) :: Nil
-      case logical.MapElements(f, objAttr, child) =>
+      case logical.MapElements(f, _, _, objAttr, child) =>
         execution.MapElementsExec(f, objAttr, planLater(child)) :: Nil
-      case logical.AppendColumns(f, in, out, child) =>
+      case logical.AppendColumns(f, _, _, in, out, child) =>
         execution.AppendColumnsExec(f, in, out, planLater(child)) :: Nil
       case logical.AppendColumnsWithObject(f, childSer, newSer, child) =>
         execution.AppendColumnsWithObjectExec(f, childSer, newSer, planLater(child)) :: Nil

@@ -27,7 +27,7 @@ import org.apache.spark.sql.expressions.Aggregator
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-class TypedSumDouble[IN](f: IN => Double) extends Aggregator[IN, Double, Double] {
+class TypedSumDouble[IN](val f: IN => Double) extends Aggregator[IN, Double, Double] {
   override def zero: Double = 0.0
   override def reduce(b: Double, a: IN): Double = b + f(a)
   override def merge(b1: Double, b2: Double): Double = b1 + b2
@@ -45,7 +45,7 @@ class TypedSumDouble[IN](f: IN => Double) extends Aggregator[IN, Double, Double]
 }
-class TypedSumLong[IN](f: IN => Long) extends Aggregator[IN, Long, Long] {
+class TypedSumLong[IN](val f: IN => Long) extends Aggregator[IN, Long, Long] {
   override def zero: Long = 0L
   override def reduce(b: Long, a: IN): Long = b + f(a)
   override def merge(b1: Long, b2: Long): Long = b1 + b2
@@ -63,7 +63,7 @@ class TypedSumLong[IN](f: IN => Long) extends Aggregator[IN, Long, Long] {
 }
-class TypedCount[IN](f: IN => Any) extends Aggregator[IN, Long, Long] {
+class TypedCount[IN](val f: IN => Any) extends Aggregator[IN, Long, Long] {
   override def zero: Long = 0
   override def reduce(b: Long, a: IN): Long = {
     if (f(a) == null) b else b + 1
@@ -82,7 +82,7 @@ class TypedCount[IN](f: IN => Any) extends Aggregator[IN, Long, Long] {
 }
-class TypedAverage[IN](f: IN => Double) extends Aggregator[IN, (Double, Long), Double] {
+class TypedAverage[IN](val f: IN => Double) extends Aggregator[IN, (Double, Long), Double] {
   override def zero: (Double, Long) = (0.0, 0L)
   override def reduce(b: (Double, Long), a: IN): (Double, Long) = (f(a) + b._1, 1 + b._2)
   override def finish(reduction: (Double, Long)): Double = reduction._1 / reduction._2
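
The only change in this last file is turning the constructor parameters into `val`s, which makes the wrapped function readable as a field on these typed aggregators. A minimal sketch of what that enables, using a hypothetical `Sale` class and the `TypedSumDouble` shown in the diff above (assumed to be in scope):

```scala
case class Sale(amount: Double)  // hypothetical example type

// With `val f`, the function passed at construction time can be read back later,
// e.g. by code that wants to inspect or re-wrap it.
val agg = new TypedSumDouble[Sale](_.amount)
val applied: Double = agg.f(Sale(12.5))  // evaluates the wrapped function: 12.5
```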