[SPARK-22675][SQL] Refactoring PropagateTypes in TypeCoercion

## What changes were proposed in this pull request?
PropagateTypes is called twice in TypeCoercion. We do not need to call it twice; instead, we should call it after each rule that changes the types.
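
For reference, the resulting control flow, excerpted from the `TypeCoercionRule` trait added in this diff (the inline comments are editorial, not part of the change):

```scala
def apply(plan: LogicalPlan): LogicalPlan = {
  val newPlan = coerceTypes(plan)  // each rule's own coercion logic
  if (plan.fastEquals(newPlan)) {
    plan                           // the rule changed nothing: skip propagation
  } else {
    propagateTypes(newPlan)        // types changed: update attribute references above
  }
}
```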

## How was this patch tested?
The existing tests

Author: gatorsmile <gatorsmile@gmail.com>

Closes #19874 from gatorsmile/deduplicatePropagateTypes.
Authored by gatorsmile on 2017-12-05 20:43:02 +08:00; committed by Wenchen Fan
parent a8af4da12c
commit 53e5251bb3
6 changed files with 71 additions and 58 deletions

File: DecimalPrecision.scala

@@ -58,7 +58,7 @@ import org.apache.spark.sql.types._
  *   - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE
  */
 // scalastyle:on
-object DecimalPrecision extends Rule[LogicalPlan] {
+object DecimalPrecision extends TypeCoercionRule {
   import scala.math.{max, min}
 
   private def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType
@@ -78,7 +78,7 @@ object DecimalPrecision extends Rule[LogicalPlan] {
     PromotePrecision(Cast(e, dataType))
   }
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+  override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
     // fix decimal precision for expressions
     case q => q.transformExpressionsUp(
       decimalAndDecimal.orElse(integralAndDecimalLiteral).orElse(nondecimalAndDecimal))

File: TypeCoercion.scala

@@ -22,6 +22,7 @@ import javax.annotation.Nullable
 import scala.annotation.tailrec
 import scala.collection.mutable
 
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -45,8 +46,7 @@ import org.apache.spark.sql.types._
 object TypeCoercion {
 
   val typeCoercionRules =
-    PropagateTypes ::
     InConversion ::
     WidenSetOperationTypes ::
     PromoteStrings ::
     DecimalPrecision ::
@@ -56,7 +56,6 @@ object TypeCoercion {
     IfCoercion ::
     StackCoercion ::
     Division ::
-    PropagateTypes ::
     ImplicitTypeCasts ::
     DateTimeOperations ::
     WindowFrameCoercion ::
@@ -220,38 +219,6 @@ object TypeCoercion {
   private def haveSameType(exprs: Seq[Expression]): Boolean =
     exprs.map(_.dataType).distinct.length == 1
 
-  /**
-   * Applies any changes to [[AttributeReference]] data types that are made by other rules to
-   * instances higher in the query tree.
-   */
-  object PropagateTypes extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-
-      // No propagation required for leaf nodes.
-      case q: LogicalPlan if q.children.isEmpty => q
-
-      // Don't propagate types from unresolved children.
-      case q: LogicalPlan if !q.childrenResolved => q
-
-      case q: LogicalPlan =>
-        val inputMap = q.inputSet.toSeq.map(a => (a.exprId, a)).toMap
-        q transformExpressions {
-          case a: AttributeReference =>
-            inputMap.get(a.exprId) match {
-              // This can happen when an Attribute reference is born in a non-leaf node, for
-              // example due to a call to an external script like in the Transform operator.
-              // TODO: Perhaps those should actually be aliases?
-              case None => a
-              // Leave the same if the dataTypes match.
-              case Some(newType) if a.dataType == newType.dataType => a
-              case Some(newType) =>
-                logDebug(s"Promoting $a to $newType in ${q.simpleString}")
-                newType
-            }
-        }
-    }
-  }
-
   /**
    * Widens numeric types and converts strings to numbers when appropriate.
    *
@@ -345,7 +312,7 @@
   /**
    * Promotes strings that appear in arithmetic expressions.
    */
-  object PromoteStrings extends Rule[LogicalPlan] {
+  object PromoteStrings extends TypeCoercionRule {
     private def castExpr(expr: Expression, targetType: DataType): Expression = {
       (expr.dataType, targetType) match {
         case (NullType, dt) => Literal.create(null, targetType)
@@ -354,7 +321,7 @@
       }
     }
 
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
+    override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       // Skip nodes who's children have not been resolved yet.
       case e if !e.childrenResolved => e
@@ -403,7 +370,7 @@
    * operator type is found the original expression will be returned and an
    * Analysis Exception will be raised at the type checking phase.
    */
-  object InConversion extends Rule[LogicalPlan] {
+  object InConversion extends TypeCoercionRule {
     private def flattenExpr(expr: Expression): Seq[Expression] = {
       expr match {
         // Multi columns in IN clause is represented as a CreateNamedStruct.
@@ -413,7 +380,7 @@
       }
     }
 
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
+    override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       // Skip nodes who's children have not been resolved yet.
       case e if !e.childrenResolved => e
@@ -512,8 +479,8 @@
   /**
    * This ensure that the types for various functions are as expected.
    */
-  object FunctionArgumentConversion extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
+  object FunctionArgumentConversion extends TypeCoercionRule {
+    override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       // Skip nodes who's children have not been resolved yet.
       case e if !e.childrenResolved => e
@@ -602,8 +569,8 @@
    * Hive only performs integral division with the DIV operator. The arguments to / are always
    * converted to fractional types.
    */
-  object Division extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
+  object Division extends TypeCoercionRule {
+    override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       // Skip nodes who has not been resolved yet,
       // as this is an extra rule which should be applied at last.
       case e if !e.childrenResolved => e
@@ -624,8 +591,8 @@
   /**
    * Coerces the type of different branches of a CASE WHEN statement to a common type.
    */
-  object CaseWhenCoercion extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
+  object CaseWhenCoercion extends TypeCoercionRule {
+    override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual =>
         val maybeCommonType = findWiderCommonType(c.valueTypes)
         maybeCommonType.map { commonType =>
@@ -654,8 +621,8 @@
   /**
    * Coerces the type of different branches of If statement to a common type.
    */
-  object IfCoercion extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
+  object IfCoercion extends TypeCoercionRule {
+    override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       case e if !e.childrenResolved => e
       // Find tightest common type for If, if the true value and false value have different types.
       case i @ If(pred, left, right) if left.dataType != right.dataType =>
@@ -674,8 +641,8 @@
   /**
    * Coerces NullTypes in the Stack expression to the column types of the corresponding positions.
    */
-  object StackCoercion extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+  object StackCoercion extends TypeCoercionRule {
+    override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
       case s @ Stack(children) if s.childrenResolved && s.hasFoldableNumRows =>
         Stack(children.zipWithIndex.map {
           // The first child is the number of rows for stack.
@@ -711,8 +678,8 @@
   /**
    * Casts types according to the expected input types for [[Expression]]s.
    */
-  object ImplicitTypeCasts extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
+  object ImplicitTypeCasts extends TypeCoercionRule {
+    override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       // Skip nodes who's children have not been resolved yet.
       case e if !e.childrenResolved => e
@@ -828,8 +795,8 @@
   /**
    * Cast WindowFrame boundaries to the type they operate upon.
    */
-  object WindowFrameCoercion extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
+  object WindowFrameCoercion extends TypeCoercionRule {
+    override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       case s @ WindowSpecDefinition(_, Seq(order), SpecifiedWindowFrame(RangeFrame, lower, upper))
           if order.resolved =>
         s.copy(frameSpecification = SpecifiedWindowFrame(
@@ -850,3 +817,46 @@ object TypeCoercion {
     }
   }
 }
+
+trait TypeCoercionRule extends Rule[LogicalPlan] with Logging {
+  /**
+   * Applies any changes to [[AttributeReference]] data types that are made by the transform method
+   * to instances higher in the query tree.
+   */
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    val newPlan = coerceTypes(plan)
+    if (plan.fastEquals(newPlan)) {
+      plan
+    } else {
+      propagateTypes(newPlan)
+    }
+  }
+
+  protected def coerceTypes(plan: LogicalPlan): LogicalPlan
+
+  private def propagateTypes(plan: LogicalPlan): LogicalPlan = plan transformUp {
+    // No propagation required for leaf nodes.
+    case q: LogicalPlan if q.children.isEmpty => q
+
+    // Don't propagate types from unresolved children.
+    case q: LogicalPlan if !q.childrenResolved => q
+
+    case q: LogicalPlan =>
+      val inputMap = q.inputSet.toSeq.map(a => (a.exprId, a)).toMap
+      q transformExpressions {
+        case a: AttributeReference =>
+          inputMap.get(a.exprId) match {
+            // This can happen when an Attribute reference is born in a non-leaf node, for
+            // example due to a call to an external script like in the Transform operator.
+            // TODO: Perhaps those should actually be aliases?
+            case None => a
+            // Leave the same if the dataTypes match.
+            case Some(newType) if a.dataType == newType.dataType => a
+            case Some(newType) =>
+              logDebug(
+                s"Promoting $a from ${a.dataType} to ${newType.dataType} in ${q.simpleString}")
+              newType
+          }
+      }
+  }
+}
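
For illustration, a rule written against the new trait would look roughly like this; `ExampleCoercion` and its rewrite are invented for this sketch, and only `TypeCoercionRule` and `coerceTypes` come from this diff:

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.types.{NullType, StringType}

object ExampleCoercion extends TypeCoercionRule {
  override protected def coerceTypes(plan: LogicalPlan): LogicalPlan =
    plan resolveExpressions {
      // Skip expressions whose children are not resolved yet.
      case e if !e.childrenResolved => e
      // Invented rewrite: give untyped null literals a concrete type.
      case Literal(null, NullType) => Literal.create(null, StringType)
    }
}
```

The inherited `apply` decides whether `propagateTypes` needs to run, so individual rules no longer need a separate `PropagateTypes` pass scheduled after them.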

File: Canonicalize.scala

@@ -31,7 +31,7 @@ package org.apache.spark.sql.catalyst.expressions
  * - [[EqualTo]] and [[EqualNullSafe]] are reordered by `hashCode`.
  * - Other comparisons ([[GreaterThan]], [[LessThan]]) are reversed by `hashCode`.
  */
-object Canonicalize extends {
+object Canonicalize {
   def execute(e: Expression): Expression = {
     expressionReorder(ignoreNamesTypes(e))
   }

File: SQLQueryTestSuite.scala

@@ -300,7 +300,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext {
     Locale.setDefault(originalLocale)
 
     // For debugging dump some statistics about how much time was spent in various optimizer rules
-    logWarning(RuleExecutor.dumpTimeSpent())
+    logInfo(RuleExecutor.dumpTimeSpent())
   } finally {
     super.afterAll()
   }

File: TPCDSQuerySuite.scala

@@ -19,6 +19,7 @@ package org.apache.spark.sql
 import org.scalatest.BeforeAndAfterAll
 
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.catalyst.util.resourceToString
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
@@ -39,6 +40,8 @@ class TPCDSQuerySuite extends QueryTest with SharedSQLContext with BeforeAndAfterAll {
    */
   protected override def afterAll(): Unit = {
     try {
+      // For debugging dump some statistics about how much time was spent in various optimizer rules
+      logInfo(RuleExecutor.dumpTimeSpent())
       spark.sessionState.catalog.reset()
     } finally {
       super.afterAll()
File: HiveCompatibilitySuite.scala

@@ -76,7 +76,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
     TestHive.setConf(SQLConf.SESSION_LOCAL_TIMEZONE, originalSessionLocalTimeZone)
 
     // For debugging dump some statistics about how much time was spent in various optimizer rules
-    logWarning(RuleExecutor.dumpTimeSpent())
+    logInfo(RuleExecutor.dumpTimeSpent())
   } finally {
     super.afterAll()
   }