[SPARK-35955][SQL] Check for overflow in Average in ANSI mode
### What changes were proposed in this pull request? Fixes decimal overflow issues for decimal average in ANSI mode, so that overflows throw an exception rather than returning null. ### Why are the changes needed? Query: ``` scala> import org.apache.spark.sql.functions._ import org.apache.spark.sql.functions._ scala> spark.conf.set("spark.sql.ansi.enabled", true) scala> val df = Seq( | (BigDecimal("10000000000000000000"), 1), | (BigDecimal("10000000000000000000"), 1), | (BigDecimal("10000000000000000000"), 2), | (BigDecimal("10000000000000000000"), 2), | (BigDecimal("10000000000000000000"), 2), | (BigDecimal("10000000000000000000"), 2), | (BigDecimal("10000000000000000000"), 2), | (BigDecimal("10000000000000000000"), 2), | (BigDecimal("10000000000000000000"), 2), | (BigDecimal("10000000000000000000"), 2), | (BigDecimal("10000000000000000000"), 2), | (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum") df: org.apache.spark.sql.DataFrame = [decNum: decimal(38,18), intNum: int] scala> val df2 = df.withColumnRenamed("decNum", "decNum2").join(df, "intNum").agg(mean("decNum")) df2: org.apache.spark.sql.DataFrame = [avg(decNum): decimal(38,22)] scala> df2.show(40,false) ``` Before: ``` +-----------+ |avg(decNum)| +-----------+ |null | +-----------+ ``` After: ``` 21/07/01 19:48:31 ERROR Executor: Exception in task 0.0 in stage 3.0 (TID 24) java.lang.ArithmeticException: Overflow in sum of decimals. at org.apache.spark.sql.errors.QueryExecutionErrors$.overflowInSumOfDecimalError(QueryExecutionErrors.scala:162) at org.apache.spark.sql.errors.QueryExecutionErrors.overflowInSumOfDecimalError(QueryExecutionErrors.scala) at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source) at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43) at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:759) at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:349) at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:898) at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:898) at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52) at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373) at org.apache.spark.rdd.RDD.iterator(RDD.scala:337) at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90) at org.apache.spark.scheduler.Task.run(Task.scala:131) at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:499) at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1462) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:502) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624) at java.lang.Thread.run(Thread.java:748) ``` ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Unit test Closes #33177 from karenfeng/SPARK-35955. Authored-by: Karen Feng <karen.feng@databricks.com> Signed-off-by: Gengliang Wang <gengliang@apache.org>
This commit is contained in:
parent
47485a3c2d
commit
1fda011d71
|
@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
|
|||
import org.apache.spark.sql.catalyst.trees.TreePattern.{AVERAGE, TreePattern}
|
||||
import org.apache.spark.sql.catalyst.trees.UnaryLike
|
||||
import org.apache.spark.sql.catalyst.util.TypeUtils
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
@ExpressionDescription(
|
||||
|
@ -87,9 +88,11 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
|
|||
// If all input are nulls, count will be 0 and we will get null after the division.
|
||||
// We can't directly use `/` as it throws an exception under ansi mode.
|
||||
override lazy val evaluateExpression = child.dataType match {
|
||||
case _: DecimalType =>
|
||||
case d: DecimalType =>
|
||||
DecimalPrecision.decimalAndDecimal()(
|
||||
Divide(sum, count.cast(DecimalType.LongDecimal), failOnError = false)).cast(resultType)
|
||||
Divide(
|
||||
CheckOverflowInSum(sum, d, !SQLConf.get.ansiEnabled),
|
||||
count.cast(DecimalType.LongDecimal), failOnError = false)).cast(resultType)
|
||||
case _: YearMonthIntervalType =>
|
||||
If(EqualTo(count, Literal(0L)),
|
||||
Literal(null, YearMonthIntervalType()), DivideYMInterval(sum, count))
|
||||
|
|
|
@ -235,7 +235,7 @@ class DataFrameSuite extends QueryTest
|
|||
}
|
||||
}
|
||||
|
||||
test("SPARK-28067: Aggregate sum should not return wrong results for decimal overflow") {
|
||||
def checkAggResultsForDecimalOverflow(aggFn: Column => Column): Unit = {
|
||||
Seq("true", "false").foreach { wholeStageEnabled =>
|
||||
withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStageEnabled)) {
|
||||
Seq(true, false).foreach { ansiEnabled =>
|
||||
|
@ -256,22 +256,22 @@ class DataFrameSuite extends QueryTest
|
|||
(BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
|
||||
val df = df0.union(df1)
|
||||
val df2 = df.withColumnRenamed("decNum", "decNum2").
|
||||
join(df, "intNum").agg(sum("decNum"))
|
||||
join(df, "intNum").agg(aggFn($"decNum"))
|
||||
|
||||
val expectedAnswer = Row(null)
|
||||
assertDecimalSumOverflow(df2, ansiEnabled, expectedAnswer)
|
||||
|
||||
val decStr = "1" + "0" * 19
|
||||
val d1 = spark.range(0, 12, 1, 1)
|
||||
val d2 = d1.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(sum($"d"))
|
||||
val d2 = d1.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(aggFn($"d"))
|
||||
assertDecimalSumOverflow(d2, ansiEnabled, expectedAnswer)
|
||||
|
||||
val d3 = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1))
|
||||
val d4 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(sum($"d"))
|
||||
val d4 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(aggFn($"d"))
|
||||
assertDecimalSumOverflow(d4, ansiEnabled, expectedAnswer)
|
||||
|
||||
val d5 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d"),
|
||||
lit(1).as("key")).groupBy("key").agg(sum($"d").alias("sumd")).select($"sumd")
|
||||
lit(1).as("key")).groupBy("key").agg(aggFn($"d").alias("aggd")).select($"aggd")
|
||||
assertDecimalSumOverflow(d5, ansiEnabled, expectedAnswer)
|
||||
|
||||
val nullsDf = spark.range(1, 4, 1).select(expr(s"cast(null as decimal(38,18)) as d"))
|
||||
|
@ -279,7 +279,7 @@ class DataFrameSuite extends QueryTest
|
|||
val largeDecimals = Seq(BigDecimal("1"* 20 + ".123"), BigDecimal("9"* 20 + ".123")).
|
||||
toDF("d")
|
||||
assertDecimalSumOverflow(
|
||||
nullsDf.union(largeDecimals).agg(sum($"d")), ansiEnabled, expectedAnswer)
|
||||
nullsDf.union(largeDecimals).agg(aggFn($"d")), ansiEnabled, expectedAnswer)
|
||||
|
||||
val df3 = Seq(
|
||||
(BigDecimal("10000000000000000000"), 1),
|
||||
|
@ -306,6 +306,14 @@ class DataFrameSuite extends QueryTest
|
|||
}
|
||||
}
|
||||
|
||||
test("SPARK-28067: Aggregate sum should not return wrong results for decimal overflow") {
|
||||
checkAggResultsForDecimalOverflow(c => sum(c))
|
||||
}
|
||||
|
||||
test("SPARK-35955: Aggregate avg should not return wrong results for decimal overflow") {
|
||||
checkAggResultsForDecimalOverflow(c => avg(c))
|
||||
}
|
||||
|
||||
test("Star Expansion - ds.explode should fail with a meaningful message if it takes a star") {
|
||||
val df = Seq(("1", "1,2"), ("2", "4"), ("3", "7,8,9")).toDF("prefix", "csv")
|
||||
val e = intercept[AnalysisException] {
|
||||
|
|
Loading…
Reference in a new issue