From 55dea2d937a375d9929937ee66aa9bfed158b883 Mon Sep 17 00:00:00 2001
From: gengjiaan
Date: Tue, 27 Apr 2021 10:52:12 +0300
Subject: [PATCH] [SPARK-34837][SQL][FOLLOWUP] Fix division by zero in the avg
 function over ANSI intervals

### What changes were proposed in this pull request?
https://github.com/apache/spark/pull/32229 added support for ANSI SQL intervals to the aggregate function `avg`, but it did not handle an input of zero rows. As a result, `avg` over an empty input fails with:
```
Caused by: java.lang.ArithmeticException: / by zero
	at com.google.common.math.LongMath.divide(LongMath.java:367)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:759)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
	at org.apache.spark.util.Utils$.getIteratorSize(Utils.scala:1864)
	at org.apache.spark.rdd.RDD.$anonfun$count$1(RDD.scala:1253)
	at org.apache.spark.rdd.RDD.$anonfun$count$1$adapted(RDD.scala:1253)
	at org.apache.spark.SparkContext.$anonfun$runJob$5(SparkContext.scala:2248)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:131)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:498)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1437)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:501)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)
```
This PR guards the interval division with a zero check on `count`, returning NULL when no rows were aggregated.

### Why are the changes needed?
Fix a bug: `avg` over ANSI intervals throws an `ArithmeticException` instead of returning NULL when the input is empty.

### Does this PR introduce _any_ user-facing change?
No. It only affects a new feature.

### How was this patch tested?
New tests.

Closes #32358 from beliefer/SPARK-34837-followup.
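For context, a minimal sketch reproducing the failure (not part of the patch; assumes a Spark build with ANSI interval support, an active `SparkSession` bound to `spark`, and illustrative column names mirroring the suite's fixture):
```scala
import java.time.Period

import org.apache.spark.sql.functions.avg

import spark.implicits._  // assumes `spark` is an active SparkSession

val df = Seq((1, Period.ofMonths(10))).toDF("class", "year-month")

// Filtering out every row makes `count` zero inside the Average aggregate,
// so DivideYMInterval(sum, count) divided by zero before this fix;
// with the fix, the result is a single NULL.
df.filter($"class" > 4).select(avg($"year-month")).show()
```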
Authored-by: gengjiaan
Signed-off-by: Max Gekk
---
 .../sql/catalyst/expressions/aggregate/Average.scala |  8 ++++++--
 .../apache/spark/sql/DataFrameAggregateSuite.scala   | 11 +++++++++--
 2 files changed, 15 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index 4fc0256bce..8ae24e5135 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -87,8 +87,12 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
     case _: DecimalType => DecimalPrecision.decimalAndDecimal()(
       Divide(sum, count.cast(DecimalType.LongDecimal), failOnError = false)).cast(resultType)
-    case _: YearMonthIntervalType => DivideYMInterval(sum, count)
-    case _: DayTimeIntervalType => DivideDTInterval(sum, count)
+    case _: YearMonthIntervalType =>
+      If(EqualTo(count, Literal(0L)),
+        Literal(null, YearMonthIntervalType), DivideYMInterval(sum, count))
+    case _: DayTimeIntervalType =>
+      If(EqualTo(count, Literal(0L)),
+        Literal(null, DayTimeIntervalType), DivideDTInterval(sum, count))
     case _ =>
       Divide(sum.cast(resultType), count.cast(resultType), failOnError = false)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index c53bcf045d..c6f6cbdbf0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -1135,7 +1135,7 @@ class DataFrameAggregateSuite extends QueryTest
     val sumDF2 = df.groupBy($"class").agg(sum($"year-month"), sum($"day-time"))
     checkAnswer(sumDF2, Row(1, Period.ofMonths(10), Duration.ofDays(10)) ::
       Row(2, Period.ofMonths(1), Duration.ofDays(1)) ::
-      Row(3, Period.of(1, 6, 0), Duration.ofDays(-11)) ::Nil)
+      Row(3, Period.of(1, 6, 0), Duration.ofDays(-11)) :: Nil)
     assert(find(sumDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
     assert(sumDF2.schema == StructType(Seq(StructField("class", IntegerType, false),
       StructField("sum(year-month)", YearMonthIntervalType),
@@ -1173,7 +1173,7 @@ class DataFrameAggregateSuite extends QueryTest
     val avgDF2 = df.groupBy($"class").agg(avg($"year-month"), avg($"day-time"))
     checkAnswer(avgDF2, Row(1, Period.ofMonths(10), Duration.ofDays(10)) ::
       Row(2, Period.ofMonths(1), Duration.ofDays(1)) ::
-      Row(3, Period.ofMonths(9), Duration.ofDays(-5).plusHours(-12)) ::Nil)
+      Row(3, Period.ofMonths(9), Duration.ofDays(-5).plusHours(-12)) :: Nil)
     assert(find(avgDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
     assert(avgDF2.schema == StructType(Seq(StructField("class", IntegerType, false),
       StructField("avg(year-month)", YearMonthIntervalType),
@@ -1188,6 +1188,13 @@
       checkAnswer(df2.select(avg($"day-time")), Nil)
     }
     assert(error2.toString contains "java.lang.ArithmeticException: long overflow")
+
+    val df3 = df.filter($"class" > 4)
+    val avgDF3 = df3.select(avg($"year-month"), avg($"day-time"))
+    checkAnswer(avgDF3, Row(null, null) :: Nil)
+
+    val avgDF4 = df3.groupBy($"class").agg(avg($"year-month"), avg($"day-time"))
+    checkAnswer(avgDF4, Nil)
   }
 }
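For review, a hedged sketch of the behavior the new tests pin down (not part of the patch; assumes an active `SparkSession` bound to `spark`, with data mirroring the suite's fixture):
```scala
import java.time.{Duration, Period}

import org.apache.spark.sql.functions.avg

import spark.implicits._  // assumes `spark` is an active SparkSession

val df = Seq(
  (1, Period.ofMonths(10), Duration.ofDays(10)),
  (2, Period.ofMonths(1), Duration.ofDays(1))
).toDF("class", "year-month", "day-time")

// No row survives this filter, so the Average aggregate sees count == 0.
val empty = df.filter($"class" > 4)

// Global aggregate over zero rows: one all-NULL row instead of an exception.
empty.select(avg($"year-month"), avg($"day-time")).collect()
// => a single Row(null, null)

// Grouped aggregate over zero rows: no groups exist, hence no result rows.
empty.groupBy($"class").agg(avg($"year-month"), avg($"day-time")).count()
// => 0
```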