[SPARK-35987][SQL] The ANSI flags of Sum and Avg should be kept after being copied

### What changes were proposed in this pull request?

Make the ANSI flag part of the parameter lists of the expressions `Sum` and `Average`, instead of fetching it from the session SQLConf at evaluation time.
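The point of moving the flag into the parameter list is that a Scala default parameter value is evaluated once, at construction time, so `copy()` and tree transformations preserve the captured value instead of re-reading the conf. A minimal sketch of the pattern (stub names, not the actual Spark classes):

```scala
// Stub stand-ins for SQLConf and an aggregate expression, for illustration only.
object ConfStub {
  var ansiEnabled: Boolean = false
}

case class SumStub(child: String, failOnError: Boolean = ConfStub.ansiEnabled)

object Demo extends App {
  ConfStub.ansiEnabled = true
  val s = SumStub("col")             // captures failOnError = true at construction
  ConfStub.ansiEnabled = false       // a later session-level change...
  val copied = s.copy(child = "c2")  // ...does not affect copies of the expression
  assert(copied.failOnError)         // the ANSI flag survives the copy
}
```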

### Why are the changes needed?

For views, it is important to show consistent results even when the ANSI configuration differs in the running session. This is why many expressions, such as `Add` and `Divide`, make the ANSI flag part of their case class parameter lists.

We should do the same for the expressions `Sum` and `Average`.

### Does this PR introduce _any_ user-facing change?

Yes. `Sum` and `Average` inside a view now always behave the same, independent of the ANSI mode SQL configuration in the current session.
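A hypothetical session illustrating the intended behavior (`spark` is an active `SparkSession`; the table, view, and column names are made up):

```scala
// Create the view while ANSI mode is on; the resolved plan's Sum captures
// failOnError = true at that point.
spark.conf.set("spark.sql.ansi.enabled", "true")
spark.sql("CREATE OR REPLACE VIEW v AS SELECT sum(amount) AS total FROM sales")

// Flipping the conf in the current session no longer changes the view's
// behavior: the captured flag is kept when the plan is copied, so decimal
// overflow in the view still raises an error instead of yielding null.
spark.conf.set("spark.sql.ansi.enabled", "false")
spark.sql("SELECT total FROM v").show()
```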

### How was this patch tested?

Existing unit tests.

Closes #33186 from gengliangwang/sumAndAvg.

Authored-by: Gengliang Wang <gengliang@apache.org>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit 51103cdcdd)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Gengliang Wang authored 2021-07-05 12:34:21 +08:00, committed by Wenchen Fan
parent 873f6b9d97
commit ac1c6aa45c
5 changed files with 33 additions and 15 deletions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala

@@ -421,8 +421,8 @@ abstract class TypeCoercionBase {
       m.copy(newKeys.zip(newValues).flatMap { case (k, v) => Seq(k, v) })
     // Hive lets you do aggregation of timestamps... for some reason
-    case Sum(e @ TimestampType()) => Sum(Cast(e, DoubleType))
-    case Average(e @ TimestampType()) => Average(Cast(e, DoubleType))
+    case Sum(e @ TimestampType(), _) => Sum(Cast(e, DoubleType))
+    case Average(e @ TimestampType(), _) => Average(Cast(e, DoubleType))
     // Coalesce should return the first non-null value, which could be any column
     // from the list. So we need to make sure the return type is deterministic and

@@ -1091,8 +1091,8 @@ object TypeCoercion extends TypeCoercionBase {
       p.makeCopy(Array(castExpr(left, commonType), castExpr(right, commonType)))
     case Abs(e @ StringType(), failOnError) => Abs(Cast(e, DoubleType), failOnError)
-    case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
-    case Average(e @ StringType()) => Average(Cast(e, DoubleType))
+    case Sum(e @ StringType(), _) => Sum(Cast(e, DoubleType))
+    case Average(e @ StringType(), _) => Average(Cast(e, DoubleType))
     case s @ StddevPop(e @ StringType(), _) =>
       s.withNewChildren(Seq(Cast(e, DoubleType)))
     case s @ StddevSamp(e @ StringType(), _) =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala

@@ -37,9 +37,15 @@ import org.apache.spark.sql.types._
   """,
   group = "agg_funcs",
   since = "1.0.0")
-case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes
+case class Average(
+    child: Expression,
+    failOnError: Boolean = SQLConf.get.ansiEnabled)
+  extends DeclarativeAggregate
+  with ImplicitCastInputTypes
   with UnaryLike[Expression] {
 
+  def this(child: Expression) = this(child, failOnError = SQLConf.get.ansiEnabled)
+
   override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("avg")
 
   override def inputTypes: Seq[AbstractDataType] =

@@ -91,7 +97,7 @@ case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes
     case d: DecimalType =>
       DecimalPrecision.decimalAndDecimal()(
         Divide(
-          CheckOverflowInSum(sum, d, !SQLConf.get.ansiEnabled),
+          CheckOverflowInSum(sum, d, !failOnError),
           count.cast(DecimalType.LongDecimal), failOnError = false)).cast(resultType)
     case _: YearMonthIntervalType =>
       If(EqualTo(count, Literal(0L)),

@@ -113,4 +119,7 @@ case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes
   override protected def withNewChildInternal(newChild: Expression): Average =
     copy(child = newChild)
+
+  // The flag `failOnError` won't be shown in the `toString` or `toAggString` methods
+  override def flatArguments: Iterator[Any] = Iterator(child)
 }

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala

@@ -39,9 +39,15 @@ import org.apache.spark.sql.types._
   """,
   group = "agg_funcs",
   since = "1.0.0")
-case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes
+case class Sum(
+    child: Expression,
+    failOnError: Boolean = SQLConf.get.ansiEnabled)
+  extends DeclarativeAggregate
+  with ImplicitCastInputTypes
   with UnaryLike[Expression] {
 
+  def this(child: Expression) = this(child, failOnError = SQLConf.get.ansiEnabled)
+
   override def nullable: Boolean = true
 
   // Return data type.

@@ -151,9 +157,12 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes
   override lazy val evaluateExpression: Expression = resultType match {
     case d: DecimalType =>
       If(isEmpty, Literal.create(null, resultType),
-        CheckOverflowInSum(sum, d, !SQLConf.get.ansiEnabled))
+        CheckOverflowInSum(sum, d, !failOnError))
     case _ => sum
   }
 
   override protected def withNewChildInternal(newChild: Expression): Sum = copy(child = newChild)
+
+  // The flag `failOnError` won't be shown in the `toString` or `toAggString` methods
+  override def flatArguments: Iterator[Any] = Iterator(child)
 }

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

@@ -1709,11 +1709,11 @@ object DecimalAggregates extends Rule[LogicalPlan] {
     case q: LogicalPlan => q.transformExpressionsDownWithPruning(
       _.containsAnyPattern(SUM, AVERAGE), ruleId) {
       case we @ WindowExpression(ae @ AggregateExpression(af, _, _, _, _), _) => af match {
-        case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS =>
+        case Sum(e @ DecimalType.Expression(prec, scale), _) if prec + 10 <= MAX_LONG_DIGITS =>
           MakeDecimal(we.copy(windowFunction = ae.copy(aggregateFunction = Sum(UnscaledValue(e)))),
             prec + 10, scale)
-        case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS =>
+        case Average(e @ DecimalType.Expression(prec, scale), _) if prec + 4 <= MAX_DOUBLE_DIGITS =>
           val newAggExpr =
             we.copy(windowFunction = ae.copy(aggregateFunction = Average(UnscaledValue(e))))
           Cast(

@@ -1723,10 +1723,10 @@ object DecimalAggregates extends Rule[LogicalPlan] {
         case _ => we
       }
     case ae @ AggregateExpression(af, _, _, _, _) => af match {
-      case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS =>
+      case Sum(e @ DecimalType.Expression(prec, scale), _) if prec + 10 <= MAX_LONG_DIGITS =>
        MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale)
-      case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS =>
+      case Average(e @ DecimalType.Expression(prec, scale), _) if prec + 4 <= MAX_DOUBLE_DIGITS =>
        val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e)))
        Cast(
          Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),

sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala

@@ -276,7 +276,7 @@ class RelationalGroupedDataset protected[sql](
    */
   @scala.annotation.varargs
   def mean(colNames: String*): DataFrame = {
-    aggregateNumericColumns(colNames : _*)(Average)
+    aggregateNumericColumns(colNames : _*)(Average(_))
   }
 
   /**

@@ -300,7 +300,7 @@ class RelationalGroupedDataset protected[sql](
    */
   @scala.annotation.varargs
   def avg(colNames: String*): DataFrame = {
-    aggregateNumericColumns(colNames : _*)(Average)
+    aggregateNumericColumns(colNames : _*)(Average(_))
   }
 
   /**

@@ -324,7 +324,7 @@ class RelationalGroupedDataset protected[sql](
    */
   @scala.annotation.varargs
   def sum(colNames: String*): DataFrame = {
-    aggregateNumericColumns(colNames : _*)(Sum)
+    aggregateNumericColumns(colNames : _*)(Sum(_))
   }
 
   /**