[SPARK-35955][SQL] Check for overflow in Average in ANSI mode

### What changes were proposed in this pull request?

Fixes a decimal overflow issue in `Average` in ANSI mode, so that an overflow in the intermediate sum throws an exception rather than causing the average to return null.

### Why are the changes needed?

In ANSI mode, an average whose intermediate sum overflows silently returns null instead of raising an error. In the query below, the self-join on `intNum` yields 2×2 + 10×10 = 104 rows of 10^19, so the intermediate sum (~1.04 × 10^21) exceeds the 20 integral digits that `decimal(38,18)` can hold:

```
scala> import org.apache.spark.sql.functions._
import org.apache.spark.sql.functions._

scala> spark.conf.set("spark.sql.ansi.enabled", true)

scala> val df = Seq(
     |  (BigDecimal("10000000000000000000"), 1),
     |  (BigDecimal("10000000000000000000"), 1),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
df: org.apache.spark.sql.DataFrame = [decNum: decimal(38,18), intNum: int]

scala> val df2 = df.withColumnRenamed("decNum", "decNum2").join(df, "intNum").agg(mean("decNum"))
df2: org.apache.spark.sql.DataFrame = [avg(decNum): decimal(38,22)]

scala> df2.show(40,false)
```

Before:
```
+-----------+
|avg(decNum)|
+-----------+
|null       |
+-----------+
```

After:
```
21/07/01 19:48:31 ERROR Executor: Exception in task 0.0 in stage 3.0 (TID 24)
java.lang.ArithmeticException: Overflow in sum of decimals.
	at org.apache.spark.sql.errors.QueryExecutionErrors$.overflowInSumOfDecimalError(QueryExecutionErrors.scala:162)
	at org.apache.spark.sql.errors.QueryExecutionErrors.overflowInSumOfDecimalError(QueryExecutionErrors.scala)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:759)
	at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:349)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:898)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:898)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:131)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:499)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1462)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:502)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)
```
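For contrast, with ANSI mode disabled (the default) the aggregation keeps the legacy behavior and degrades to null on overflow. A quick sketch, assuming the same `df` defined above:

```scala
// Legacy (non-ANSI) behavior is preserved: the overflowing average is null.
spark.conf.set("spark.sql.ansi.enabled", false)
val df3 = df.withColumnRenamed("decNum", "decNum2").join(df, "intNum").agg(mean("decNum"))
df3.show(40, false)  // prints a single null row instead of throwing
```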

### Does this PR introduce _any_ user-facing change?

Yes. Under ANSI mode (`spark.sql.ansi.enabled=true`), an overflowing decimal average now throws `java.lang.ArithmeticException: Overflow in sum of decimals.` instead of returning null. Behavior with ANSI mode disabled is unchanged.

### How was this patch tested?

Unit tests in `DataFrameSuite`: the existing SPARK-28067 decimal-overflow test for `sum` is refactored into a shared `checkAggResultsForDecimalOverflow` helper and run for both `sum` and `avg` (see the diff below).
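A minimal sketch of the kind of ANSI-mode assertion involved, assuming a DataFrame `df2` whose average overflows as in the example above (the actual suite goes through its `assertDecimalSumOverflow` helper):

```scala
import org.apache.spark.sql.internal.SQLConf

withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
  // Collecting forces the aggregation; the overflowing sum must now fail
  // loudly rather than surface as a null average.
  val e = intercept[ArithmeticException] {
    df2.collect()
  }
  assert(e.getMessage.contains("Overflow in sum of decimals"))
}
```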

Closes #33177 from karenfeng/SPARK-35955.

Authored-by: Karen Feng <karen.feng@databricks.com>
Signed-off-by: Gengliang Wang <gengliang@apache.org>
commit 1fda011d71 (parent 47485a3c2d)
Karen Feng, 2021-07-02 12:41:24 +08:00, committed by Gengliang Wang
2 changed files with 19 additions and 8 deletions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala:

```diff
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.trees.TreePattern.{AVERAGE, TreePattern}
 import org.apache.spark.sql.catalyst.trees.UnaryLike
 import org.apache.spark.sql.catalyst.util.TypeUtils
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._

 @ExpressionDescription(
@@ -87,9 +88,11 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
   // If all input are nulls, count will be 0 and we will get null after the division.
   // We can't directly use `/` as it throws an exception under ansi mode.
   override lazy val evaluateExpression = child.dataType match {
-    case _: DecimalType =>
+    case d: DecimalType =>
       DecimalPrecision.decimalAndDecimal()(
-        Divide(sum, count.cast(DecimalType.LongDecimal), failOnError = false)).cast(resultType)
+        Divide(
+          CheckOverflowInSum(sum, d, !SQLConf.get.ansiEnabled),
+          count.cast(DecimalType.LongDecimal), failOnError = false)).cast(resultType)
     case _: YearMonthIntervalType =>
       If(EqualTo(count, Literal(0L)),
         Literal(null, YearMonthIntervalType()), DivideYMInterval(sum, count))
```
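Conceptually, `CheckOverflowInSum` guards the aggregation buffer before the division: under ANSI mode its `nullOnOverflow` argument is `false`, so an out-of-range sum raises instead of quietly becoming null. A standalone sketch of that check (illustrative only, not Spark's actual implementation; the helper name is hypothetical):

```scala
import java.math.{BigDecimal => JBigDecimal, RoundingMode}

// A value that no longer fits Decimal(precision, scale) either becomes
// null or throws, depending on nullOnOverflow.
def checkOverflow(
    value: JBigDecimal,
    precision: Int,
    scale: Int,
    nullOnOverflow: Boolean): JBigDecimal = {
  val rounded = value.setScale(scale, RoundingMode.HALF_UP)
  if (rounded.precision > precision) {
    if (nullOnOverflow) null
    else throw new ArithmeticException("Overflow in sum of decimals.")
  } else {
    rounded
  }
}

// decimal(38, 18) leaves 20 integral digits, so a sum of ~1.04e21 overflows:
// checkOverflow(new JBigDecimal("1.04e21"), 38, 18, nullOnOverflow = false)  // throws
```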

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala:

```diff
@@ -235,7 +235,7 @@ class DataFrameSuite extends QueryTest
     }
   }

-  test("SPARK-28067: Aggregate sum should not return wrong results for decimal overflow") {
+  def checkAggResultsForDecimalOverflow(aggFn: Column => Column): Unit = {
     Seq("true", "false").foreach { wholeStageEnabled =>
       withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStageEnabled)) {
         Seq(true, false).foreach { ansiEnabled =>
@@ -256,22 +256,22 @@ class DataFrameSuite extends QueryTest
           (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
         val df = df0.union(df1)
         val df2 = df.withColumnRenamed("decNum", "decNum2").
-          join(df, "intNum").agg(sum("decNum"))
+          join(df, "intNum").agg(aggFn($"decNum"))

         val expectedAnswer = Row(null)
         assertDecimalSumOverflow(df2, ansiEnabled, expectedAnswer)

         val decStr = "1" + "0" * 19
         val d1 = spark.range(0, 12, 1, 1)
-        val d2 = d1.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(sum($"d"))
+        val d2 = d1.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(aggFn($"d"))
         assertDecimalSumOverflow(d2, ansiEnabled, expectedAnswer)

         val d3 = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1))
-        val d4 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(sum($"d"))
+        val d4 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(aggFn($"d"))
         assertDecimalSumOverflow(d4, ansiEnabled, expectedAnswer)

         val d5 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d"),
-          lit(1).as("key")).groupBy("key").agg(sum($"d").alias("sumd")).select($"sumd")
+          lit(1).as("key")).groupBy("key").agg(aggFn($"d").alias("aggd")).select($"aggd")
         assertDecimalSumOverflow(d5, ansiEnabled, expectedAnswer)

         val nullsDf = spark.range(1, 4, 1).select(expr(s"cast(null as decimal(38,18)) as d"))
@@ -279,7 +279,7 @@ class DataFrameSuite extends QueryTest
         val largeDecimals = Seq(BigDecimal("1"* 20 + ".123"), BigDecimal("9"* 20 + ".123")).
           toDF("d")
         assertDecimalSumOverflow(
-          nullsDf.union(largeDecimals).agg(sum($"d")), ansiEnabled, expectedAnswer)
+          nullsDf.union(largeDecimals).agg(aggFn($"d")), ansiEnabled, expectedAnswer)

         val df3 = Seq(
           (BigDecimal("10000000000000000000"), 1),
@@ -306,6 +306,14 @@ class DataFrameSuite extends QueryTest
     }
   }

+  test("SPARK-28067: Aggregate sum should not return wrong results for decimal overflow") {
+    checkAggResultsForDecimalOverflow(c => sum(c))
+  }
+
+  test("SPARK-35955: Aggregate avg should not return wrong results for decimal overflow") {
+    checkAggResultsForDecimalOverflow(c => avg(c))
+  }
+
   test("Star Expansion - ds.explode should fail with a meaningful message if it takes a star") {
     val df = Seq(("1", "1,2"), ("2", "4"), ("3", "7,8,9")).toDF("prefix", "csv")
     val e = intercept[AnalysisException] {
```