[SPARK-35987][SQL] The ANSI flags of Sum and Avg should be kept after being copied

### What changes were proposed in this pull request?

Make the ANSI flag part of the parameter lists of the expressions `Sum` and `Average`, instead of fetching it from the session's SQLConf.

### Why are the changes needed?

For views, it is important to show consistent results even if the ANSI configuration is different in the running session. This is why many expressions like `Add`/`Divide` make the ANSI flag part of their case class parameter lists.

We should make it consistent for the expressions `Sum` and `Average`

### Does this PR introduce _any_ user-facing change?

Yes, the `Sum` and `Average` inside a view always behave the same, independent of the ANSI mode SQL configuration in the current session.

### How was this patch tested?

Existing UT

Closes #33186 from gengliangwang/sumAndAvg.

Authored-by: Gengliang Wang <gengliang@apache.org>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit 51103cdcdd)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
Gengliang Wang 2021-07-05 12:34:21 +08:00 committed by Wenchen Fan
parent 873f6b9d97
commit ac1c6aa45c
5 changed files with 33 additions and 15 deletions

View file

@ -421,8 +421,8 @@ abstract class TypeCoercionBase {
m.copy(newKeys.zip(newValues).flatMap { case (k, v) => Seq(k, v) })
// Hive lets you do aggregation of timestamps... for some reason
case Sum(e @ TimestampType()) => Sum(Cast(e, DoubleType))
case Average(e @ TimestampType()) => Average(Cast(e, DoubleType))
case Sum(e @ TimestampType(), _) => Sum(Cast(e, DoubleType))
case Average(e @ TimestampType(), _) => Average(Cast(e, DoubleType))
// Coalesce should return the first non-null value, which could be any column
// from the list. So we need to make sure the return type is deterministic and
@ -1091,8 +1091,8 @@ object TypeCoercion extends TypeCoercionBase {
p.makeCopy(Array(castExpr(left, commonType), castExpr(right, commonType)))
case Abs(e @ StringType(), failOnError) => Abs(Cast(e, DoubleType), failOnError)
case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
case Average(e @ StringType()) => Average(Cast(e, DoubleType))
case Sum(e @ StringType(), _) => Sum(Cast(e, DoubleType))
case Average(e @ StringType(), _) => Average(Cast(e, DoubleType))
case s @ StddevPop(e @ StringType(), _) =>
s.withNewChildren(Seq(Cast(e, DoubleType)))
case s @ StddevSamp(e @ StringType(), _) =>

View file

@ -37,9 +37,15 @@ import org.apache.spark.sql.types._
""",
group = "agg_funcs",
since = "1.0.0")
case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes
case class Average(
child: Expression,
failOnError: Boolean = SQLConf.get.ansiEnabled)
extends DeclarativeAggregate
with ImplicitCastInputTypes
with UnaryLike[Expression] {
def this(child: Expression) = this(child, failOnError = SQLConf.get.ansiEnabled)
override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("avg")
override def inputTypes: Seq[AbstractDataType] =
@ -91,7 +97,7 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
case d: DecimalType =>
DecimalPrecision.decimalAndDecimal()(
Divide(
CheckOverflowInSum(sum, d, !SQLConf.get.ansiEnabled),
CheckOverflowInSum(sum, d, !failOnError),
count.cast(DecimalType.LongDecimal), failOnError = false)).cast(resultType)
case _: YearMonthIntervalType =>
If(EqualTo(count, Literal(0L)),
@ -113,4 +119,7 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
override protected def withNewChildInternal(newChild: Expression): Average =
copy(child = newChild)
// The flag `failOnError` won't be shown in the `toString` or `toAggString` methods
override def flatArguments: Iterator[Any] = Iterator(child)
}

View file

@ -39,9 +39,15 @@ import org.apache.spark.sql.types._
""",
group = "agg_funcs",
since = "1.0.0")
case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes
case class Sum(
child: Expression,
failOnError: Boolean = SQLConf.get.ansiEnabled)
extends DeclarativeAggregate
with ImplicitCastInputTypes
with UnaryLike[Expression] {
def this(child: Expression) = this(child, failOnError = SQLConf.get.ansiEnabled)
override def nullable: Boolean = true
// Return data type.
@ -151,9 +157,12 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
override lazy val evaluateExpression: Expression = resultType match {
case d: DecimalType =>
If(isEmpty, Literal.create(null, resultType),
CheckOverflowInSum(sum, d, !SQLConf.get.ansiEnabled))
CheckOverflowInSum(sum, d, !failOnError))
case _ => sum
}
override protected def withNewChildInternal(newChild: Expression): Sum = copy(child = newChild)
// The flag `failOnError` won't be shown in the `toString` or `toAggString` methods
override def flatArguments: Iterator[Any] = Iterator(child)
}

View file

@ -1709,11 +1709,11 @@ object DecimalAggregates extends Rule[LogicalPlan] {
case q: LogicalPlan => q.transformExpressionsDownWithPruning(
_.containsAnyPattern(SUM, AVERAGE), ruleId) {
case we @ WindowExpression(ae @ AggregateExpression(af, _, _, _, _), _) => af match {
case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS =>
case Sum(e @ DecimalType.Expression(prec, scale), _) if prec + 10 <= MAX_LONG_DIGITS =>
MakeDecimal(we.copy(windowFunction = ae.copy(aggregateFunction = Sum(UnscaledValue(e)))),
prec + 10, scale)
case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS =>
case Average(e @ DecimalType.Expression(prec, scale), _) if prec + 4 <= MAX_DOUBLE_DIGITS =>
val newAggExpr =
we.copy(windowFunction = ae.copy(aggregateFunction = Average(UnscaledValue(e))))
Cast(
@ -1723,10 +1723,10 @@ object DecimalAggregates extends Rule[LogicalPlan] {
case _ => we
}
case ae @ AggregateExpression(af, _, _, _, _) => af match {
case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS =>
case Sum(e @ DecimalType.Expression(prec, scale), _) if prec + 10 <= MAX_LONG_DIGITS =>
MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale)
case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS =>
case Average(e @ DecimalType.Expression(prec, scale), _) if prec + 4 <= MAX_DOUBLE_DIGITS =>
val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e)))
Cast(
Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),

View file

@ -276,7 +276,7 @@ class RelationalGroupedDataset protected[sql](
*/
@scala.annotation.varargs
def mean(colNames: String*): DataFrame = {
aggregateNumericColumns(colNames : _*)(Average)
aggregateNumericColumns(colNames : _*)(Average(_))
}
/**
@ -300,7 +300,7 @@ class RelationalGroupedDataset protected[sql](
*/
@scala.annotation.varargs
def avg(colNames: String*): DataFrame = {
aggregateNumericColumns(colNames : _*)(Average)
aggregateNumericColumns(colNames : _*)(Average(_))
}
/**
@ -324,7 +324,7 @@ class RelationalGroupedDataset protected[sql](
*/
@scala.annotation.varargs
def sum(colNames: String*): DataFrame = {
aggregateNumericColumns(colNames : _*)(Sum)
aggregateNumericColumns(colNames : _*)(Sum(_))
}
/**