From 1fda011d71d970996680fdef4a109805f9d3d385 Mon Sep 17 00:00:00 2001
From: Karen Feng
Date: Fri, 2 Jul 2021 12:41:24 +0800
Subject: [PATCH] [SPARK-35955][SQL] Check for overflow in Average in ANSI mode

### What changes were proposed in this pull request?

Fixes overflow handling for decimal average in ANSI mode, so that an overflowing sum throws an exception rather than producing a null average. (Plain-Scala sketches of the intended semantics and of the result-type rule follow the diff below.)

### Why are the changes needed?

Query:

```
scala> import org.apache.spark.sql.functions._
import org.apache.spark.sql.functions._

scala> spark.conf.set("spark.sql.ansi.enabled", true)

scala> val df = Seq(
     | (BigDecimal("10000000000000000000"), 1),
     | (BigDecimal("10000000000000000000"), 1),
     | (BigDecimal("10000000000000000000"), 2),
     | (BigDecimal("10000000000000000000"), 2),
     | (BigDecimal("10000000000000000000"), 2),
     | (BigDecimal("10000000000000000000"), 2),
     | (BigDecimal("10000000000000000000"), 2),
     | (BigDecimal("10000000000000000000"), 2),
     | (BigDecimal("10000000000000000000"), 2),
     | (BigDecimal("10000000000000000000"), 2),
     | (BigDecimal("10000000000000000000"), 2),
     | (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
df: org.apache.spark.sql.DataFrame = [decNum: decimal(38,18), intNum: int]

scala> val df2 = df.withColumnRenamed("decNum", "decNum2").join(df, "intNum").agg(mean("decNum"))
df2: org.apache.spark.sql.DataFrame = [avg(decNum): decimal(38,22)]

scala> df2.show(40,false)
```

Before:

```
+-----------+
|avg(decNum)|
+-----------+
|null       |
+-----------+
```

After:

```
21/07/01 19:48:31 ERROR Executor: Exception in task 0.0 in stage 3.0 (TID 24)
java.lang.ArithmeticException: Overflow in sum of decimals.
	at org.apache.spark.sql.errors.QueryExecutionErrors$.overflowInSumOfDecimalError(QueryExecutionErrors.scala:162)
	at org.apache.spark.sql.errors.QueryExecutionErrors.overflowInSumOfDecimalError(QueryExecutionErrors.scala)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:759)
	at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:349)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:898)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:898)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:131)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:499)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1462)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:502)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)
```

### Does this PR introduce _any_ user-facing change?

Yes. With `spark.sql.ansi.enabled=true`, a decimal average that overflows now throws `java.lang.ArithmeticException: Overflow in sum of decimals.` instead of returning null, as shown above.

### How was this patch tested?

Unit tests in `DataFrameSuite`.

Closes #33177 from karenfeng/SPARK-35955.
Authored-by: Karen Feng
Signed-off-by: Gengliang Wang
---
 .../expressions/aggregate/Average.scala       |  7 +++++--
 .../org/apache/spark/sql/DataFrameSuite.scala | 20 +++++++++++++------
 2 files changed, 19 insertions(+), 8 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index 77a6cf7594..4fae6dfc0d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.trees.TreePattern.{AVERAGE, TreePattern}
 import org.apache.spark.sql.catalyst.trees.UnaryLike
 import org.apache.spark.sql.catalyst.util.TypeUtils
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 @ExpressionDescription(
@@ -87,9 +88,11 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
   // If all input are nulls, count will be 0 and we will get null after the division.
   // We can't directly use `/` as it throws an exception under ansi mode.
   override lazy val evaluateExpression = child.dataType match {
-    case _: DecimalType =>
+    case d: DecimalType =>
       DecimalPrecision.decimalAndDecimal()(
-        Divide(sum, count.cast(DecimalType.LongDecimal), failOnError = false)).cast(resultType)
+        Divide(
+          CheckOverflowInSum(sum, d, !SQLConf.get.ansiEnabled),
+          count.cast(DecimalType.LongDecimal), failOnError = false)).cast(resultType)
     case _: YearMonthIntervalType =>
       If(EqualTo(count, Literal(0L)), Literal(null, YearMonthIntervalType()),
         DivideYMInterval(sum, count))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 37916661f9..bfb8f19bc1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -235,7 +235,7 @@ class DataFrameSuite extends QueryTest
     }
   }
 
-  test("SPARK-28067: Aggregate sum should not return wrong results for decimal overflow") {
+  def checkAggResultsForDecimalOverflow(aggFn: Column => Column): Unit = {
     Seq("true", "false").foreach { wholeStageEnabled =>
       withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStageEnabled)) {
         Seq(true, false).foreach { ansiEnabled =>
@@ -256,22 +256,22 @@ class DataFrameSuite extends QueryTest
               (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
             val df = df0.union(df1)
             val df2 = df.withColumnRenamed("decNum", "decNum2").
-              join(df, "intNum").agg(sum("decNum"))
+              join(df, "intNum").agg(aggFn($"decNum"))
 
             val expectedAnswer = Row(null)
             assertDecimalSumOverflow(df2, ansiEnabled, expectedAnswer)
 
             val decStr = "1" + "0" * 19
             val d1 = spark.range(0, 12, 1, 1)
-            val d2 = d1.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(sum($"d"))
+            val d2 = d1.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(aggFn($"d"))
             assertDecimalSumOverflow(d2, ansiEnabled, expectedAnswer)
 
             val d3 = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1))
-            val d4 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(sum($"d"))
+            val d4 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(aggFn($"d"))
             assertDecimalSumOverflow(d4, ansiEnabled, expectedAnswer)
 
             val d5 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d"),
-              lit(1).as("key")).groupBy("key").agg(sum($"d").alias("sumd")).select($"sumd")
+              lit(1).as("key")).groupBy("key").agg(aggFn($"d").alias("aggd")).select($"aggd")
             assertDecimalSumOverflow(d5, ansiEnabled, expectedAnswer)
 
             val nullsDf = spark.range(1, 4, 1).select(expr(s"cast(null as decimal(38,18)) as d"))
@@ -279,7 +279,7 @@ class DataFrameSuite extends QueryTest
             val largeDecimals = Seq(BigDecimal("1"* 20 + ".123"), BigDecimal("9"* 20 + ".123")).
               toDF("d")
             assertDecimalSumOverflow(
-              nullsDf.union(largeDecimals).agg(sum($"d")), ansiEnabled, expectedAnswer)
+              nullsDf.union(largeDecimals).agg(aggFn($"d")), ansiEnabled, expectedAnswer)
 
             val df3 = Seq(
               (BigDecimal("10000000000000000000"), 1),
@@ -306,6 +306,14 @@ class DataFrameSuite extends QueryTest
     }
   }
 
+  test("SPARK-28067: Aggregate sum should not return wrong results for decimal overflow") {
+    checkAggResultsForDecimalOverflow(c => sum(c))
+  }
+
+  test("SPARK-35955: Aggregate avg should not return wrong results for decimal overflow") {
+    checkAggResultsForDecimalOverflow(c => avg(c))
+  }
+
   test("Star Expansion - ds.explode should fail with a meaningful message if it takes a star") {
     val df = Seq(("1", "1,2"), ("2", "4"), ("3", "7,8,9")).toDF("prefix", "csv")
     val e = intercept[AnalysisException] {
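
What the `Average.scala` hunk does, in plain terms: the decimal `sum` buffer is wrapped in `CheckOverflowInSum(sum, d, !SQLConf.get.ansiEnabled)` before the divide, so an overflowed sum either throws (ANSI mode) or becomes null (legacy mode) instead of silently feeding the division. Below is a minimal, Spark-free sketch of those semantics over plain `BigDecimal`; `AvgOverflowSketch`, `fitsDecimal38`, and `averageDecimal` are illustrative names for this note, not Spark APIs:

```
// Sketch of the checked-average semantics for a decimal(38, 18) accumulator:
// the overflow check runs on the sum *before* the division, mirroring the
// CheckOverflowInSum wrapping added by this patch.
object AvgOverflowSketch {
  private val MaxPrecision = 38 // Spark's DecimalType.MAX_PRECISION

  // Would this value still fit in decimal(38, 18)?
  private def fitsDecimal38(d: BigDecimal): Boolean =
    d.setScale(18, BigDecimal.RoundingMode.HALF_UP).precision <= MaxPrecision

  // Some(avg) on success; None stands in for SQL null (empty input, or
  // overflow with ANSI off).
  def averageDecimal(values: Seq[BigDecimal], ansiEnabled: Boolean): Option[BigDecimal] =
    if (values.isEmpty) None
    else {
      val total = values.sum
      if (fitsDecimal38(total)) Some(total / values.size)
      else if (ansiEnabled) throw new ArithmeticException("Overflow in sum of decimals.")
      else None // legacy behavior: the overflowed sum nulls out, so avg is null
    }
}

// Twelve copies of 10^19 sum to 1.2e20, which needs precision 39 at scale 18:
val big = Seq.fill(12)(BigDecimal("10000000000000000000"))
AvgOverflowSketch.averageDecimal(big, ansiEnabled = false) // None ("null" average)
AvgOverflowSketch.averageDecimal(big, ansiEnabled = true)  // throws ArithmeticException
```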
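
A side note on the schema line in the repro (`avg(decNum): decimal(38,22)` from inputs of `decimal(38,18)`): Spark widens the result of `avg` over `decimal(p, s)` to `decimal(p + 4, s + 4)`, capped at the maximum precision of 38. Since the sum itself is already held at the 38-digit cap, the sum is exactly where the overflow has to be caught. A two-line sketch of that bound; `boundedAvgType` is a hand-rolled stand-in for the real rule in `Average.resultType`:

```
// avg over decimal(p, s) widens to decimal(p + 4, s + 4), capped at 38.
def boundedAvgType(p: Int, s: Int): (Int, Int) =
  (math.min(p + 4, 38), math.min(s + 4, 38))

assert(boundedAvgType(38, 18) == (38, 22)) // matches avg(decNum): decimal(38,22)
```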