[SPARK-35987][SQL] The ANSI flags of Sum and Avg should be kept after being copied
### What changes were proposed in this pull request?
Make the ANSI flag part of the parameter lists of the expressions `Sum` and `Average`, instead of fetching it from the session's SQLConf.
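A minimal standalone sketch (not Spark code; all names here are hypothetical) of the Scala pitfall behind the title: a flag resolved from ambient configuration at instantiation time is re-evaluated whenever `copy` builds a new instance, while a constructor parameter survives `copy` unchanged:

```scala
// Hypothetical stand-in for SQLConf.get.ansiEnabled.
object AmbientConf { var ansiEnabled: Boolean = false }

// Flag resolved in the class body: re-read on every instantiation,
// so copy() can silently flip it when the ambient config changes.
case class BodyFlag(child: String) {
  val failOnError: Boolean = AmbientConf.ansiEnabled
}

// Flag as a defaulted constructor parameter: copy() carries it over verbatim.
case class ParamFlag(child: String, failOnError: Boolean = AmbientConf.ansiEnabled)

object CopyDemo extends App {
  AmbientConf.ansiEnabled = true
  val body = BodyFlag("x")
  val param = ParamFlag("x")
  AmbientConf.ansiEnabled = false                // the "session" config changes
  println(body.copy(child = "y").failOnError)    // false: the ANSI flag was lost
  println(param.copy(child = "y").failOnError)   // true: the ANSI flag is kept
}
```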
### Why are the changes needed?
For views, it is important to show consistent results even when the ANSI configuration differs in the running session. This is why many expressions, such as `Add`/`Divide`, make the ANSI flag part of their case class parameter lists.
We should make the expressions `Sum` and `Average` consistent with them.
### Does this PR introduce _any_ user-facing change?
Yes, `Sum` and `Average` inside a view always behave the same, independent of the ANSI mode SQL configuration in the current session.
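A hedged spark-shell sketch of the intended behavior (the table `t` and its decimal column `a` are hypothetical, chosen so that `sum(a)` overflows):

```scala
// Sketch only, assuming a table `t` with a DECIMAL column `a` whose sum overflows.
spark.conf.set("spark.sql.ansi.enabled", true)
spark.sql("CREATE OR REPLACE TEMP VIEW v AS SELECT sum(a) AS s FROM t")

spark.conf.set("spark.sql.ansi.enabled", false)
// With the flag captured in Sum's parameter list, the view should keep the
// ANSI overflow behavior it was created with (an error on decimal overflow)
// rather than silently returning null under the session's non-ANSI setting.
spark.sql("SELECT * FROM v").show()
```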
### How was this patch tested?
Existing unit tests.
Closes #33186 from gengliangwang/sumAndAvg.
Authored-by: Gengliang Wang <gengliang@apache.org>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit 51103cdcdd)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
parent 873f6b9d97
commit ac1c6aa45c
```diff
@@ -421,8 +421,8 @@ abstract class TypeCoercionBase {
       m.copy(newKeys.zip(newValues).flatMap { case (k, v) => Seq(k, v) })

       // Hive lets you do aggregation of timestamps... for some reason
-      case Sum(e @ TimestampType()) => Sum(Cast(e, DoubleType))
-      case Average(e @ TimestampType()) => Average(Cast(e, DoubleType))
+      case Sum(e @ TimestampType(), _) => Sum(Cast(e, DoubleType))
+      case Average(e @ TimestampType(), _) => Average(Cast(e, DoubleType))

       // Coalesce should return the first non-null value, which could be any column
       // from the list. So we need to make sure the return type is deterministic and
@@ -1091,8 +1091,8 @@ object TypeCoercion extends TypeCoercionBase {
       p.makeCopy(Array(castExpr(left, commonType), castExpr(right, commonType)))

     case Abs(e @ StringType(), failOnError) => Abs(Cast(e, DoubleType), failOnError)
-    case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
-    case Average(e @ StringType()) => Average(Cast(e, DoubleType))
+    case Sum(e @ StringType(), _) => Sum(Cast(e, DoubleType))
+    case Average(e @ StringType(), _) => Average(Cast(e, DoubleType))
     case s @ StddevPop(e @ StringType(), _) =>
       s.withNewChildren(Seq(Cast(e, DoubleType)))
     case s @ StddevSamp(e @ StringType(), _) =>
```
```diff
@@ -37,9 +37,15 @@ import org.apache.spark.sql.types._
   """,
   group = "agg_funcs",
   since = "1.0.0")
-case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes
+case class Average(
+    child: Expression,
+    failOnError: Boolean = SQLConf.get.ansiEnabled)
+  extends DeclarativeAggregate
+  with ImplicitCastInputTypes
   with UnaryLike[Expression] {

+  def this(child: Expression) = this(child, failOnError = SQLConf.get.ansiEnabled)
+
   override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("avg")

   override def inputTypes: Seq[AbstractDataType] =
@@ -91,7 +97,7 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
     case d: DecimalType =>
       DecimalPrecision.decimalAndDecimal()(
         Divide(
-          CheckOverflowInSum(sum, d, !SQLConf.get.ansiEnabled),
+          CheckOverflowInSum(sum, d, !failOnError),
           count.cast(DecimalType.LongDecimal), failOnError = false)).cast(resultType)
     case _: YearMonthIntervalType =>
       If(EqualTo(count, Literal(0L)),
@@ -113,4 +119,7 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit

   override protected def withNewChildInternal(newChild: Expression): Average =
     copy(child = newChild)
+
+  // The flag `failOnError` won't be shown in the `toString` or `toAggString` methods
+  override def flatArguments: Iterator[Any] = Iterator(child)
 }
```
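For context on the `flatArguments` override: Catalyst's `Expression.toString` renders roughly as `prettyName` plus the `flatArguments` joined in parentheses, so excluding the flag keeps plan strings like `avg(a)` stable across ANSI settings. A simplified sketch of the idea (hypothetical, non-Spark types):

```scala
// Simplified sketch of why flatArguments is overridden: the default string
// form renders every "argument", so hiding the flag keeps it out of plan strings.
abstract class Expr {
  def prettyName: String
  def flatArguments: Iterator[Any]
  override def toString: String = prettyName + flatArguments.mkString("(", ", ", ")")
}

case class MiniSum(child: String, failOnError: Boolean = false) extends Expr {
  def prettyName: String = "sum"
  // Without this override the flag would leak into plan strings: sum(a, false).
  def flatArguments: Iterator[Any] = Iterator(child)
}

object ToStringDemo extends App {
  println(MiniSum("a"))                      // sum(a)
  println(MiniSum("a", failOnError = true))  // sum(a): flag hidden either way
}
```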
```diff
@@ -39,9 +39,15 @@ import org.apache.spark.sql.types._
   """,
   group = "agg_funcs",
   since = "1.0.0")
-case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes
+case class Sum(
+    child: Expression,
+    failOnError: Boolean = SQLConf.get.ansiEnabled)
+  extends DeclarativeAggregate
+  with ImplicitCastInputTypes
   with UnaryLike[Expression] {

+  def this(child: Expression) = this(child, failOnError = SQLConf.get.ansiEnabled)
+
   override def nullable: Boolean = true

   // Return data type.
@@ -151,9 +157,12 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
   override lazy val evaluateExpression: Expression = resultType match {
     case d: DecimalType =>
       If(isEmpty, Literal.create(null, resultType),
-        CheckOverflowInSum(sum, d, !SQLConf.get.ansiEnabled))
+        CheckOverflowInSum(sum, d, !failOnError))
     case _ => sum
   }

   override protected def withNewChildInternal(newChild: Expression): Sum = copy(child = newChild)
+
+  // The flag `failOnError` won't be shown in the `toString` or `toAggString` methods
+  override def flatArguments: Iterator[Any] = Iterator(child)
 }
```
```diff
@@ -1709,11 +1709,11 @@ object DecimalAggregates extends Rule[LogicalPlan] {
     case q: LogicalPlan => q.transformExpressionsDownWithPruning(
       _.containsAnyPattern(SUM, AVERAGE), ruleId) {
       case we @ WindowExpression(ae @ AggregateExpression(af, _, _, _, _), _) => af match {
-        case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS =>
+        case Sum(e @ DecimalType.Expression(prec, scale), _) if prec + 10 <= MAX_LONG_DIGITS =>
           MakeDecimal(we.copy(windowFunction = ae.copy(aggregateFunction = Sum(UnscaledValue(e)))),
             prec + 10, scale)

-        case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS =>
+        case Average(e @ DecimalType.Expression(prec, scale), _) if prec + 4 <= MAX_DOUBLE_DIGITS =>
           val newAggExpr =
             we.copy(windowFunction = ae.copy(aggregateFunction = Average(UnscaledValue(e))))
           Cast(
@@ -1723,10 +1723,10 @@ object DecimalAggregates extends Rule[LogicalPlan] {
         case _ => we
       }
       case ae @ AggregateExpression(af, _, _, _, _) => af match {
-        case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS =>
+        case Sum(e @ DecimalType.Expression(prec, scale), _) if prec + 10 <= MAX_LONG_DIGITS =>
           MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale)

-        case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS =>
+        case Average(e @ DecimalType.Expression(prec, scale), _) if prec + 4 <= MAX_DOUBLE_DIGITS =>
           val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e)))
           Cast(
             Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
```
```diff
@@ -276,7 +276,7 @@ class RelationalGroupedDataset protected[sql](
    */
   @scala.annotation.varargs
   def mean(colNames: String*): DataFrame = {
-    aggregateNumericColumns(colNames : _*)(Average)
+    aggregateNumericColumns(colNames : _*)(Average(_))
   }

   /**
@@ -300,7 +300,7 @@ class RelationalGroupedDataset protected[sql](
    */
   @scala.annotation.varargs
   def avg(colNames: String*): DataFrame = {
-    aggregateNumericColumns(colNames : _*)(Average)
+    aggregateNumericColumns(colNames : _*)(Average(_))
   }

   /**
@@ -324,7 +324,7 @@ class RelationalGroupedDataset protected[sql](
    */
   @scala.annotation.varargs
   def sum(colNames: String*): DataFrame = {
-    aggregateNumericColumns(colNames : _*)(Sum)
+    aggregateNumericColumns(colNames : _*)(Sum(_))
   }

   /**
```
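These last call-site changes are forced by the new signature: once the case class gains a second (defaulted) parameter, its companion object no longer conforms to the one-argument function type that `aggregateNumericColumns` expects, so the call sites eta-expand explicitly with `(_)`. A minimal standalone sketch (hypothetical names):

```scala
// One-parameter case class: the companion conforms to String => Avg1.
case class Avg1(child: String)
// Two parameters, second defaulted: the companion is no longer a Function1.
case class Avg2(child: String, failOnError: Boolean = false)

object EtaDemo extends App {
  val f1: String => Avg1 = Avg1       // fine before the change
  val f2: String => Avg2 = Avg2(_)    // placeholder form: x => Avg2(x), default fills in
  // val bad: String => Avg2 = Avg2   // does not compile after the change
  println(f2("x"))                    // Avg2(x,false)
}
```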