[SPARK-34837][SQL][FOLLOWUP] Fix division by zero in the avg function over ANSI intervals
### What changes were proposed in this pull request? https://github.com/apache/spark/pull/32229 added support for ANSI SQL intervals in the aggregate function `avg`, but it did not handle the case of zero input rows, which leads to: ``` Caused by: java.lang.ArithmeticException: / by zero at com.google.common.math.LongMath.divide(LongMath.java:367) at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source) at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43) at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:759) at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458) at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458) at org.apache.spark.util.Utils$.getIteratorSize(Utils.scala:1864) at org.apache.spark.rdd.RDD.$anonfun$count$1(RDD.scala:1253) at org.apache.spark.rdd.RDD.$anonfun$count$1$adapted(RDD.scala:1253) at org.apache.spark.SparkContext.$anonfun$runJob$5(SparkContext.scala:2248) at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90) at org.apache.spark.scheduler.Task.run(Task.scala:131) at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:498) at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1437) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:501) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624) at java.lang.Thread.run(Thread.java:748) ``` ### Why are the changes needed? Fix a bug. ### Does this PR introduce _any_ user-facing change? No, this only affects a new, not-yet-released feature. ### How was this patch tested? New tests. Closes #32358 from beliefer/SPARK-34837-followup. Authored-by: gengjiaan <gengjiaan@360.cn> Signed-off-by: Max Gekk <max.gekk@gmail.com>
This commit is contained in:
parent
2d2f467831
commit
55dea2d937
|
@ -87,8 +87,12 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
|
|||
case _: DecimalType =>
|
||||
DecimalPrecision.decimalAndDecimal()(
|
||||
Divide(sum, count.cast(DecimalType.LongDecimal), failOnError = false)).cast(resultType)
|
||||
case _: YearMonthIntervalType => DivideYMInterval(sum, count)
|
||||
case _: DayTimeIntervalType => DivideDTInterval(sum, count)
|
||||
case _: YearMonthIntervalType =>
|
||||
If(EqualTo(count, Literal(0L)),
|
||||
Literal(null, YearMonthIntervalType), DivideYMInterval(sum, count))
|
||||
case _: DayTimeIntervalType =>
|
||||
If(EqualTo(count, Literal(0L)),
|
||||
Literal(null, DayTimeIntervalType), DivideDTInterval(sum, count))
|
||||
case _ =>
|
||||
Divide(sum.cast(resultType), count.cast(resultType), failOnError = false)
|
||||
}
|
||||
|
|
|
@ -1188,6 +1188,13 @@ class DataFrameAggregateSuite extends QueryTest
|
|||
checkAnswer(df2.select(avg($"day-time")), Nil)
|
||||
}
|
||||
assert(error2.toString contains "java.lang.ArithmeticException: long overflow")
|
||||
|
||||
val df3 = df.filter($"class" > 4)
|
||||
val avgDF3 = df3.select(avg($"year-month"), avg($"day-time"))
|
||||
checkAnswer(avgDF3, Row(null, null) :: Nil)
|
||||
|
||||
val avgDF4 = df3.groupBy($"class").agg(avg($"year-month"), avg($"day-time"))
|
||||
checkAnswer(avgDF4, Nil)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue