From ac1c6aa45cf1dbfed9d5a5573548f0eb04cc8af6 Mon Sep 17 00:00:00 2001
From: Gengliang Wang
Date: Mon, 5 Jul 2021 12:34:21 +0800
Subject: [PATCH] [SPARK-35987][SQL] The ANSI flags of Sum and Avg should be
 kept after being copied

### What changes were proposed in this pull request?

Make the ANSI flag part of the parameter lists of the expressions `Sum` and `Average`, instead of fetching it from the session SQLConf.

### Why are the changes needed?

For views, it is important to show consistent results even if the ANSI configuration differs in the running session. This is why many expressions, such as `Add` and `Divide`, make the ANSI flag part of their case class parameter lists. We should make the expressions `Sum` and `Average` consistent with them.

### Does this PR introduce _any_ user-facing change?

Yes, `Sum` and `Average` inside a view always behave the same, independent of the ANSI mode SQL configuration in the current session.

### How was this patch tested?

Existing UT

Closes #33186 from gengliangwang/sumAndAvg.

Authored-by: Gengliang Wang
Signed-off-by: Wenchen Fan
(cherry picked from commit 51103cdcddfb62e6a066e53576ddb170c6784f54)
Signed-off-by: Wenchen Fan
---
 .../spark/sql/catalyst/analysis/TypeCoercion.scala   |  8 ++++----
 .../catalyst/expressions/aggregate/Average.scala     | 13 +++++++++++--
 .../sql/catalyst/expressions/aggregate/Sum.scala     | 13 +++++++++++--
 .../spark/sql/catalyst/optimizer/Optimizer.scala     |  8 ++++----
 .../apache/spark/sql/RelationalGroupedDataset.scala  |  6 +++---
 5 files changed, 33 insertions(+), 15 deletions(-)
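The core technique is worth spelling out before the diff. A Scala default parameter such as `failOnError: Boolean = SQLConf.get.ansiEnabled` is evaluated once per constructor call, so the flag is captured when the expression is built (for example, when a view's plan is created) and is then preserved by `copy`, whereas a `def` that reads the conf is re-evaluated in whatever session happens to be current. A standalone sketch of that difference (`SessionConf`, `AvgBefore`, and `AvgAfter` are hypothetical stand-ins, not the real Spark classes):

```scala
// Standalone sketch of the technique; not the real Spark classes.
object SessionConf {
  var ansiEnabled: Boolean = false // plays the role of SQLConf.get.ansiEnabled
}

// Before the patch: the flag is re-read from the session conf at use time.
case class AvgBefore(child: String) {
  def failOnError: Boolean = SessionConf.ansiEnabled
}

// After the patch: the default is evaluated when the instance is created,
// so the flag is fixed from then on and survives copy().
case class AvgAfter(child: String, failOnError: Boolean = SessionConf.ansiEnabled)

object Demo extends App {
  SessionConf.ansiEnabled = true
  val captured = AvgAfter("c")                    // captures failOnError = true
  SessionConf.ansiEnabled = false                 // e.g. a different session
  println(captured.copy(child = "d").failOnError) // true: kept after being copied
  println(AvgBefore("c").failOnError)             // false: tracks the session
}
```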
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index ad541f611e..da0489499f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -421,8 +421,8 @@ abstract class TypeCoercionBase {
         m.copy(newKeys.zip(newValues).flatMap { case (k, v) => Seq(k, v) })
 
       // Hive lets you do aggregation of timestamps... for some reason
-      case Sum(e @ TimestampType()) => Sum(Cast(e, DoubleType))
-      case Average(e @ TimestampType()) => Average(Cast(e, DoubleType))
+      case Sum(e @ TimestampType(), _) => Sum(Cast(e, DoubleType))
+      case Average(e @ TimestampType(), _) => Average(Cast(e, DoubleType))
 
       // Coalesce should return the first non-null value, which could be any column
       // from the list. So we need to make sure the return type is deterministic and
@@ -1091,8 +1091,8 @@ object TypeCoercion extends TypeCoercionBase {
           p.makeCopy(Array(castExpr(left, commonType), castExpr(right, commonType)))
 
       case Abs(e @ StringType(), failOnError) => Abs(Cast(e, DoubleType), failOnError)
-      case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
-      case Average(e @ StringType()) => Average(Cast(e, DoubleType))
+      case Sum(e @ StringType(), _) => Sum(Cast(e, DoubleType))
+      case Average(e @ StringType(), _) => Average(Cast(e, DoubleType))
       case s @ StddevPop(e @ StringType(), _) =>
         s.withNewChildren(Seq(Cast(e, DoubleType)))
       case s @ StddevSamp(e @ StringType(), _) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index 4fae6dfc0d..7ede3fc50a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -37,9 +37,15 @@ import org.apache.spark.sql.types._
   """,
   group = "agg_funcs",
   since = "1.0.0")
-case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes
+case class Average(
+    child: Expression,
+    failOnError: Boolean = SQLConf.get.ansiEnabled)
+  extends DeclarativeAggregate
+  with ImplicitCastInputTypes
   with UnaryLike[Expression] {
 
+  def this(child: Expression) = this(child, failOnError = SQLConf.get.ansiEnabled)
+
   override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("avg")
 
   override def inputTypes: Seq[AbstractDataType] =
@@ -91,7 +97,7 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
     case d: DecimalType =>
       DecimalPrecision.decimalAndDecimal()(
         Divide(
-          CheckOverflowInSum(sum, d, !SQLConf.get.ansiEnabled),
+          CheckOverflowInSum(sum, d, !failOnError),
           count.cast(DecimalType.LongDecimal), failOnError = false)).cast(resultType)
     case _: YearMonthIntervalType =>
       If(EqualTo(count, Literal(0L)),
@@ -113,4 +119,7 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
 
   override protected def withNewChildInternal(newChild: Expression): Average =
     copy(child = newChild)
+
+  // The flag `failOnError` won't be shown in the `toString` or `toAggString` methods
+  override def flatArguments: Iterator[Any] = Iterator(child)
 }
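A note on the `flatArguments` override above: roughly speaking, Catalyst derives an expression's string rendering from its argument iterator, so returning only `child` keeps the new flag out of plan strings and keeps `explain` output identical across ANSI settings. A minimal standalone sketch of the idea (a simplified `PrettyExpr` trait standing in for Spark's TreeNode machinery):

```scala
// Simplified sketch, not Spark's actual TreeNode machinery.
trait PrettyExpr extends Product {
  def prettyName: String
  // By default, every constructor argument appears in the rendered string.
  def flatArguments: Iterator[Any] = productIterator
  override def toString: String = flatArguments.mkString(s"$prettyName(", ", ", ")")
}

case class SumLike(child: String, failOnError: Boolean = false) extends PrettyExpr {
  override def prettyName: String = "sum"
  // Hide the flag, as the patch does for Sum and Average.
  override def flatArguments: Iterator[Any] = Iterator(child)
}

// SumLike("price").toString == "sum(price)", not "sum(price, false)"
```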
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index 80dda69b7a..ec7479af96 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -39,9 +39,15 @@ import org.apache.spark.sql.types._
   """,
   group = "agg_funcs",
   since = "1.0.0")
-case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes
+case class Sum(
+    child: Expression,
+    failOnError: Boolean = SQLConf.get.ansiEnabled)
+  extends DeclarativeAggregate
+  with ImplicitCastInputTypes
   with UnaryLike[Expression] {
 
+  def this(child: Expression) = this(child, failOnError = SQLConf.get.ansiEnabled)
+
   override def nullable: Boolean = true
 
   // Return data type.
@@ -151,9 +157,12 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
   override lazy val evaluateExpression: Expression = resultType match {
     case d: DecimalType =>
       If(isEmpty, Literal.create(null, resultType),
-        CheckOverflowInSum(sum, d, !SQLConf.get.ansiEnabled))
+        CheckOverflowInSum(sum, d, !failOnError))
     case _ => sum
   }
 
   override protected def withNewChildInternal(newChild: Expression): Sum = copy(child = newChild)
+
+  // The flag `failOnError` won't be shown in the `toString` or `toAggString` methods
+  override def flatArguments: Iterator[Any] = Iterator(child)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 934682ccc5..fd3b7a1d3a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1709,11 +1709,11 @@ object DecimalAggregates extends Rule[LogicalPlan] {
     case q: LogicalPlan => q.transformExpressionsDownWithPruning(
       _.containsAnyPattern(SUM, AVERAGE), ruleId) {
       case we @ WindowExpression(ae @ AggregateExpression(af, _, _, _, _), _) => af match {
-        case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS =>
+        case Sum(e @ DecimalType.Expression(prec, scale), _) if prec + 10 <= MAX_LONG_DIGITS =>
          MakeDecimal(we.copy(windowFunction = ae.copy(aggregateFunction = Sum(UnscaledValue(e)))),
            prec + 10, scale)
 
-        case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS =>
+        case Average(e @ DecimalType.Expression(prec, scale), _) if prec + 4 <= MAX_DOUBLE_DIGITS =>
           val newAggExpr =
             we.copy(windowFunction = ae.copy(aggregateFunction = Average(UnscaledValue(e))))
           Cast(
@@ -1723,10 +1723,10 @@ object DecimalAggregates extends Rule[LogicalPlan] {
         case _ => we
       }
      case ae @ AggregateExpression(af, _, _, _, _) => af match {
-        case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS =>
+        case Sum(e @ DecimalType.Expression(prec, scale), _) if prec + 10 <= MAX_LONG_DIGITS =>
          MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale)
 
-        case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS =>
+        case Average(e @ DecimalType.Expression(prec, scale), _) if prec + 4 <= MAX_DOUBLE_DIGITS =>
           val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e)))
           Cast(
             Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
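The `RelationalGroupedDataset` changes below follow directly from the new signature. With a single parameter list of one argument, a case class's synthetic companion extends `Function1`, so the bare name `Sum` or `Average` could be passed where a function over expressions was expected. Once the class gains a second parameter, even a defaulted one, the companion no longer extends `Function1`, and Scala will not eta-expand `apply` down to one argument, so the call sites switch to the placeholder form `Sum(_)`, which expands to `e => Sum(e)` and picks up the default flag. A standalone illustration with hypothetical stand-in types (`Avg1`, `Avg2`, `aggregateAll` are not Spark's classes):

```scala
// Stand-in types only; Avg1/Avg2 are hypothetical, not Spark's Average.
case class Avg1(child: String)                               // old shape
case class Avg2(child: String, failOnError: Boolean = false) // patched shape

def aggregateAll(cols: Seq[String])(f: String => Any): Seq[Any] = cols.map(f)

aggregateAll(Seq("a", "b"))(Avg1)     // ok: Avg1's companion is a Function1
// aggregateAll(Seq("a", "b"))(Avg2)  // no longer compiles after the change
aggregateAll(Seq("a", "b"))(Avg2(_))  // ok: expands to `c => Avg2(c)`
```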
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index bd51837c17..96bb1b3027 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -276,7 +276,7 @@ class RelationalGroupedDataset protected[sql](
    */
   @scala.annotation.varargs
   def mean(colNames: String*): DataFrame = {
-    aggregateNumericColumns(colNames : _*)(Average)
+    aggregateNumericColumns(colNames : _*)(Average(_))
   }
 
   /**
@@ -300,7 +300,7 @@ class RelationalGroupedDataset protected[sql](
    */
   @scala.annotation.varargs
   def avg(colNames: String*): DataFrame = {
-    aggregateNumericColumns(colNames : _*)(Average)
+    aggregateNumericColumns(colNames : _*)(Average(_))
   }
 
   /**
@@ -324,7 +324,7 @@ class RelationalGroupedDataset protected[sql](
    */
   @scala.annotation.varargs
   def sum(colNames: String*): DataFrame = {
-    aggregateNumericColumns(colNames : _*)(Sum)
+    aggregateNumericColumns(colNames : _*)(Sum(_))
   }
 
   /**
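Finally, the behavior in the patch title can be checked directly against the patched classes. This is a hedged REPL sketch against a build with this patch applied; `Sum` and `Literal` are Catalyst-internal classes, not stable user API:

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.aggregate.Sum

// The flag is set explicitly here; by default it would capture
// SQLConf.get.ansiEnabled at construction time.
val sum = Sum(Literal(1L), failOnError = true)

// withNewChildren goes through withNewChildInternal -> copy(child = ...),
// and with this patch the copy keeps the captured ANSI flag.
val copied = sum.withNewChildren(Seq(Literal(2L))).asInstanceOf[Sum]
assert(copied.failOnError) // kept after being copied
```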