From 8dc455bba8254fca583f0a3d6acf7730862251a7 Mon Sep 17 00:00:00 2001
From: gengjiaan
Date: Mon, 19 Apr 2021 15:56:56 +0300
Subject: [PATCH] [SPARK-34837][SQL] Support ANSI SQL intervals by the aggregate function `avg`

### What changes were proposed in this pull request?
Extend the `Average` expression to support `DayTimeIntervalType` and `YearMonthIntervalType` added by #31614.

Note: the expressions can throw an overflow exception independently of the SQL config `spark.sql.ansi.enabled`. In this way, the modified expressions always behave in the ANSI mode for the intervals.

### Why are the changes needed?
Extend `org.apache.spark.sql.catalyst.expressions.aggregate.Average` to support `DayTimeIntervalType` and `YearMonthIntervalType`.

### Does this PR introduce _any_ user-facing change?
No. It should not, since the new types have not been released yet.

### How was this patch tested?
Jenkins tests.

Closes #32229 from beliefer/SPARK-34837.

Authored-by: gengjiaan
Signed-off-by: Max Gekk
---
 .../expressions/aggregate/Average.scala       | 11 +++++-
 .../catalyst/expressions/aggregate/Sum.scala  |  9 ++---
 .../spark/sql/catalyst/util/TypeUtils.scala   |  8 ++++
 .../ExpressionTypeCheckingSuite.scala         |  3 +-
 .../spark/sql/DataFrameAggregateSuite.scala   | 38 +++++++++++++++++++
 5 files changed, 60 insertions(+), 9 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index 90e91ae418..b53f87c8b1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -40,10 +40,11 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
 
   override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("avg")
 
-  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(TypeCollection(NumericType, YearMonthIntervalType, DayTimeIntervalType))
 
   override def checkInputDataTypes(): TypeCheckResult =
-    TypeUtils.checkForNumericExpr(child.dataType, "function average")
+    TypeUtils.checkForAnsiIntervalOrNumericType(child.dataType, "average")
 
   override def nullable: Boolean = true
 
@@ -53,11 +54,15 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
 
   private lazy val resultType = child.dataType match {
     case DecimalType.Fixed(p, s) =>
       DecimalType.bounded(p + 4, s + 4)
+    case _: YearMonthIntervalType => YearMonthIntervalType
+    case _: DayTimeIntervalType => DayTimeIntervalType
     case _ => DoubleType
   }
 
   private lazy val sumDataType = child.dataType match {
     case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s)
+    case _: YearMonthIntervalType => YearMonthIntervalType
+    case _: DayTimeIntervalType => DayTimeIntervalType
     case _ => DoubleType
   }
 
@@ -82,6 +87,8 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
     case _: DecimalType =>
       DecimalPrecision.decimalAndDecimal(
         Divide(sum, count.cast(DecimalType.LongDecimal), failOnError = false)).cast(resultType)
+    case _: YearMonthIntervalType => DivideYMInterval(sum, count)
+    case _: DayTimeIntervalType => DivideDTInterval(sum, count)
     case _ =>
       Divide(sum.cast(resultType), count.cast(resultType), failOnError = false)
   }
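Aside, not part of the patch: a minimal standalone sketch of the averaging semantics the new branches implement. Year-month intervals average over total months and day-time intervals over the sub-day time line, using whole-number division. The helper names here are hypothetical; in Spark the actual division (including its rounding behavior) is performed by the `DivideYMInterval` and `DivideDTInterval` expressions used above.

import java.time.{Duration, Period}

object AvgIntervalSketch {
  // Hypothetical helper: average a year-month interval as total months / count.
  def avgYearMonth(xs: Seq[Period]): Period =
    Period.ofMonths(Math.toIntExact(xs.map(_.toTotalMonths).sum / xs.size))

  // Hypothetical helper: average a day-time interval on the time line.
  // (This sketch ignores overflow; the Spark expressions throw
  // java.lang.ArithmeticException on it, per the commit message.)
  def avgDayTime(xs: Seq[Duration]): Duration =
    Duration.ofNanos(xs.map(_.toNanos).sum / xs.size)

  def main(args: Array[String]): Unit = {
    // Same months as the ungrouped case in the new test below: (10+1-3+21)/4 = 7.
    println(avgYearMonth(Seq(10, 1, -3, 21).map(Period.ofMonths)))   // P7M
    // Days 10, 1, -6, -5 sum to 0, so the average duration is zero.
    println(avgDayTime(Seq(10L, 1L, -6L, -5L).map(Duration.ofDays))) // PT0S
  }
}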
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index 8ea687d78a..31150fc31b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.trees.UnaryLike
+import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
@@ -48,12 +49,8 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
   override def inputTypes: Seq[AbstractDataType] =
     Seq(TypeCollection(NumericType, YearMonthIntervalType, DayTimeIntervalType))
 
-  override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
-    case YearMonthIntervalType | DayTimeIntervalType | NullType => TypeCheckResult.TypeCheckSuccess
-    case dt if dt.isInstanceOf[NumericType] => TypeCheckResult.TypeCheckSuccess
-    case other => TypeCheckResult.TypeCheckFailure(
-      s"function sum requires numeric or interval types, not ${other.catalogString}")
-  }
+  override def checkInputDataTypes(): TypeCheckResult =
+    TypeUtils.checkForAnsiIntervalOrNumericType(child.dataType, "sum")
 
   private lazy val resultType = child.dataType match {
     case DecimalType.Fixed(precision, scale) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
index 6212e8f48c..3fa4732651 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -61,6 +61,14 @@ object TypeUtils {
     }
   }
 
+  def checkForAnsiIntervalOrNumericType(
+      dt: DataType, funcName: String): TypeCheckResult = dt match {
+    case YearMonthIntervalType | DayTimeIntervalType | NullType => TypeCheckResult.TypeCheckSuccess
+    case dt if dt.isInstanceOf[NumericType] => TypeCheckResult.TypeCheckSuccess
+    case other => TypeCheckResult.TypeCheckFailure(
+      s"function $funcName requires numeric or interval types, not ${other.catalogString}")
+  }
+
   def getNumeric(t: DataType, exactNumericRequired: Boolean = false): Numeric[Any] = {
     if (exactNumericRequired) {
       t.asInstanceOf[NumericType].exactNumeric.asInstanceOf[Numeric[Any]]
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 1b9135eef6..a9b22ad21a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -159,7 +159,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
     assertError(Min(Symbol("mapField")), "min does not support ordering on type")
     assertError(Max(Symbol("mapField")), "max does not support ordering on type")
     assertError(Sum(Symbol("booleanField")), "function sum requires numeric or interval types")
-    assertError(Average(Symbol("booleanField")), "function average requires numeric type")
+    assertError(Average(Symbol("booleanField")),
+      "function average requires numeric or interval types")
   }
 
   test("check types for others") {
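Aside, not part of the patch: the behavior of the shared helper added to TypeUtils.scala above, shown as a hypothetical REPL session (TypeUtils is a Catalyst-internal object, so this assumes the catalyst classes are on the classpath; the rendered results are illustrative).

import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._

TypeUtils.checkForAnsiIntervalOrNumericType(IntegerType, "average")
// TypeCheckSuccess

TypeUtils.checkForAnsiIntervalOrNumericType(YearMonthIntervalType, "average")
// TypeCheckSuccess

TypeUtils.checkForAnsiIntervalOrNumericType(BooleanType, "average")
// TypeCheckFailure: function average requires numeric or interval types, not boolean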
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 92d3dc6fb8..c53bcf045d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -1151,6 +1151,44 @@ class DataFrameAggregateSuite extends QueryTest
     }
     assert(error2.toString contains "java.lang.ArithmeticException: long overflow")
   }
+
+  test("SPARK-34837: Support ANSI SQL intervals by the aggregate function `avg`") {
+    val df = Seq((1, Period.ofMonths(10), Duration.ofDays(10)),
+      (2, Period.ofMonths(1), Duration.ofDays(1)),
+      (2, null, null),
+      (3, Period.ofMonths(-3), Duration.ofDays(-6)),
+      (3, Period.ofMonths(21), Duration.ofDays(-5)))
+      .toDF("class", "year-month", "day-time")
+
+    val df2 = Seq((Period.ofMonths(Int.MaxValue), Duration.ofDays(106751991)),
+      (Period.ofMonths(10), Duration.ofDays(10)))
+      .toDF("year-month", "day-time")
+
+    val avgDF = df.select(avg($"year-month"), avg($"day-time"))
+    checkAnswer(avgDF, Row(Period.ofMonths(7), Duration.ofDays(0)))
+    assert(find(avgDF.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
+    assert(avgDF.schema == StructType(Seq(StructField("avg(year-month)", YearMonthIntervalType),
+      StructField("avg(day-time)", DayTimeIntervalType))))
+
+    val avgDF2 = df.groupBy($"class").agg(avg($"year-month"), avg($"day-time"))
+    checkAnswer(avgDF2, Row(1, Period.ofMonths(10), Duration.ofDays(10)) ::
+      Row(2, Period.ofMonths(1), Duration.ofDays(1)) ::
+      Row(3, Period.ofMonths(9), Duration.ofDays(-5).plusHours(-12)) :: Nil)
+    assert(find(avgDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
+    assert(avgDF2.schema == StructType(Seq(StructField("class", IntegerType, false),
+      StructField("avg(year-month)", YearMonthIntervalType),
+      StructField("avg(day-time)", DayTimeIntervalType))))
+
+    val error = intercept[SparkException] {
+      checkAnswer(df2.select(avg($"year-month")), Nil)
+    }
+    assert(error.toString contains "java.lang.ArithmeticException: integer overflow")
+
+    val error2 = intercept[SparkException] {
+      checkAnswer(df2.select(avg($"day-time")), Nil)
+    }
+    assert(error2.toString contains "java.lang.ArithmeticException: long overflow")
+  }
 }
 
 case class B(c: Option[Double])
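Aside, not part of the patch: what the change looks like from the user side, as a hypothetical spark-shell session. The column name and data are illustrative; the expected value mirrors the ungrouped case in the test above.

import java.time.Period
import org.apache.spark.sql.functions.avg
import spark.implicits._ // spark-shell defines `spark`

val df = Seq(Period.ofMonths(10), Period.ofMonths(1),
  Period.ofMonths(-3), Period.ofMonths(21)).toDF("ym")

// YearMonthIntervalType maps to java.time.Period externally,
// so avg over the column collects back as a Period.
val result = df.select(avg($"ym")).first().getAs[Period](0) // Period.ofMonths(7)

// Note: overflow raises java.lang.ArithmeticException regardless of
// spark.sql.ansi.enabled, per the commit message above.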