[SPARK-34837][SQL][FOLLOWUP] Fix division by zero in the avg function over ANSI intervals

### What changes were proposed in this pull request?
https://github.com/apache/spark/pull/32229 added support for ANSI SQL intervals to the aggregate function `avg`.
However, it did not handle the case where the input has zero rows, so the division by a zero count fails with:
```
Caused by: java.lang.ArithmeticException: / by zero
	at com.google.common.math.LongMath.divide(LongMath.java:367)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:759)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
	at org.apache.spark.util.Utils$.getIteratorSize(Utils.scala:1864)
	at org.apache.spark.rdd.RDD.$anonfun$count$1(RDD.scala:1253)
	at org.apache.spark.rdd.RDD.$anonfun$count$1$adapted(RDD.scala:1253)
	at org.apache.spark.SparkContext.$anonfun$runJob$5(SparkContext.scala:2248)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:131)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:498)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1437)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:501)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)
```

### Why are the changes needed?
Fix a bug.

### Does this PR introduce _any_ user-facing change?
No. The change only fixes a bug in a new feature.

### How was this patch tested?
By new tests added to `DataFrameAggregateSuite`.

Closes #32358 from beliefer/SPARK-34837-followup.

Authored-by: gengjiaan <gengjiaan@360.cn>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
This commit is contained in:
gengjiaan 2021-04-27 10:52:12 +03:00 committed by Max Gekk
parent 2d2f467831
commit 55dea2d937
2 changed files with 15 additions and 4 deletions

View file

@ -87,8 +87,12 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
case _: DecimalType =>
DecimalPrecision.decimalAndDecimal()(
Divide(sum, count.cast(DecimalType.LongDecimal), failOnError = false)).cast(resultType)
case _: YearMonthIntervalType => DivideYMInterval(sum, count)
case _: DayTimeIntervalType => DivideDTInterval(sum, count)
case _: YearMonthIntervalType =>
If(EqualTo(count, Literal(0L)),
Literal(null, YearMonthIntervalType), DivideYMInterval(sum, count))
case _: DayTimeIntervalType =>
If(EqualTo(count, Literal(0L)),
Literal(null, DayTimeIntervalType), DivideDTInterval(sum, count))
case _ =>
Divide(sum.cast(resultType), count.cast(resultType), failOnError = false)
}

View file

@ -1188,6 +1188,13 @@ class DataFrameAggregateSuite extends QueryTest
checkAnswer(df2.select(avg($"day-time")), Nil)
}
assert(error2.toString contains "java.lang.ArithmeticException: long overflow")
val df3 = df.filter($"class" > 4)
val avgDF3 = df3.select(avg($"year-month"), avg($"day-time"))
checkAnswer(avgDF3, Row(null, null) :: Nil)
val avgDF4 = df3.groupBy($"class").agg(avg($"year-month"), avg($"day-time"))
checkAnswer(avgDF4, Nil)
}
}