[SPARK-34837][SQL] Support ANSI SQL intervals by the aggregate function avg

### What changes were proposed in this pull request?
Extend the `Average` expression to support `DayTimeIntervalType` and `YearMonthIntervalType` added by #31614.

Note: the expressions can throw the overflow exception independently of the SQL config `spark.sql.ansi.enabled`; in this way, the modified expressions always behave in ANSI mode for intervals.
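For illustration, a minimal spark-shell sketch (not part of the PR; the column names are mine, and the data mirrors the new test in `DataFrameAggregateSuite` below):

```scala
import java.time.{Duration, Period}
import org.apache.spark.sql.functions.avg
import spark.implicits._

val df = Seq(
  (Period.ofMonths(10), Duration.ofDays(10)),
  (Period.ofMonths(1), Duration.ofDays(1))
).toDF("ym", "dt")

// avg now resolves over ANSI interval columns; the result types stay
// YearMonthIntervalType and DayTimeIntervalType.
df.select(avg($"ym"), avg($"dt")).show()

// Overflow surfaces as java.lang.ArithmeticException regardless of
// spark.sql.ansi.enabled, since the sum buffer for a year-month interval
// holds an Int month count:
Seq(Period.ofMonths(Int.MaxValue), Period.ofMonths(10))
  .toDF("ym")
  .select(avg($"ym"))
  .collect() // SparkException caused by "integer overflow"
```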

### Why are the changes needed?
Extend `org.apache.spark.sql.catalyst.expressions.aggregate.Average` to support `DayTimeIntervalType` and `YearMonthIntervalType`.

### Does this PR introduce _any_ user-facing change?
No. It should not, since the new interval types have not been released yet.

### How was this patch tested?
Jenkins tests, including the new unit test in `DataFrameAggregateSuite` and the updated expectation in `ExpressionTypeCheckingSuite` (see the diffs below).

Closes #32229 from beliefer/SPARK-34837.

Authored-by: gengjiaan <gengjiaan@360.cn>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
Commit 8dc455bba8 (parent 70b606ffdd), authored by gengjiaan on 2021-04-19 15:56:56 +03:00 and committed by Max Gekk.
5 changed files with 60 additions and 9 deletions.

**Average.scala**

```diff
@@ -40,10 +40,11 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
   override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("avg")
 
-  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(TypeCollection(NumericType, YearMonthIntervalType, DayTimeIntervalType))
 
   override def checkInputDataTypes(): TypeCheckResult =
-    TypeUtils.checkForNumericExpr(child.dataType, "function average")
+    TypeUtils.checkForAnsiIntervalOrNumericType(child.dataType, "average")
 
   override def nullable: Boolean = true
@@ -53,11 +54,15 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
   private lazy val resultType = child.dataType match {
     case DecimalType.Fixed(p, s) =>
       DecimalType.bounded(p + 4, s + 4)
+    case _: YearMonthIntervalType => YearMonthIntervalType
+    case _: DayTimeIntervalType => DayTimeIntervalType
     case _ => DoubleType
   }
 
   private lazy val sumDataType = child.dataType match {
     case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s)
+    case _: YearMonthIntervalType => YearMonthIntervalType
+    case _: DayTimeIntervalType => DayTimeIntervalType
     case _ => DoubleType
   }
@@ -82,6 +87,8 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
     case _: DecimalType =>
       DecimalPrecision.decimalAndDecimal(
         Divide(sum, count.cast(DecimalType.LongDecimal), failOnError = false)).cast(resultType)
+    case _: YearMonthIntervalType => DivideYMInterval(sum, count)
+    case _: DayTimeIntervalType => DivideDTInterval(sum, count)
     case _ =>
       Divide(sum.cast(resultType), count.cast(resultType), failOnError = false)
   }
```
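The two interval branches delegate to the `DivideYMInterval` and `DivideDTInterval` expressions introduced by #31614. As a rough model of the arithmetic (my sketch, assuming a year-month interval is stored as an `Int` month count and a day-time interval as a `Long` microsecond count; this is not Spark's actual implementation, and the exact rounding mode is not shown in this diff):

```scala
import java.time.{Duration, Period}
import java.time.temporal.ChronoUnit

// Year-month: sum the month counts, divide by the row count.
// Plain integer division truncates toward zero; overflow of the final Int
// month count surfaces as java.lang.ArithmeticException, as in the test.
def avgYearMonth(ps: Seq[Period]): Period =
  Period.ofMonths(Math.toIntExact(ps.map(_.toTotalMonths).sum / ps.size))

// Day-time: sum the microsecond counts, divide by the row count
// (sub-second parts are ignored in this sketch).
def avgDayTime(ds: Seq[Duration]): Duration =
  Duration.of(ds.map(_.getSeconds * 1000000L).sum / ds.size, ChronoUnit.MICROS)

// Reproduces the grouped expectations of the new test, e.g. for class 3:
avgYearMonth(Seq(Period.ofMonths(-3), Period.ofMonths(21))) // P9M
avgDayTime(Seq(Duration.ofDays(-6), Duration.ofDays(-5)))   // PT-132H, i.e. -5 days -12 hours
```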

**Sum.scala**

```diff
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.trees.UnaryLike
+import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -48,12 +49,8 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
   override def inputTypes: Seq[AbstractDataType] =
     Seq(TypeCollection(NumericType, YearMonthIntervalType, DayTimeIntervalType))
 
-  override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
-    case YearMonthIntervalType | DayTimeIntervalType | NullType => TypeCheckResult.TypeCheckSuccess
-    case dt if dt.isInstanceOf[NumericType] => TypeCheckResult.TypeCheckSuccess
-    case other => TypeCheckResult.TypeCheckFailure(
-      s"function sum requires numeric or interval types, not ${other.catalogString}")
-  }
+  override def checkInputDataTypes(): TypeCheckResult =
+    TypeUtils.checkForAnsiIntervalOrNumericType(child.dataType, "sum")
 
   private lazy val resultType = child.dataType match {
     case DecimalType.Fixed(precision, scale) =>
```

**TypeUtils.scala**

```diff
@@ -61,6 +61,14 @@ object TypeUtils {
     }
   }
 
+  def checkForAnsiIntervalOrNumericType(
+      dt: DataType, funcName: String): TypeCheckResult = dt match {
+    case YearMonthIntervalType | DayTimeIntervalType | NullType => TypeCheckResult.TypeCheckSuccess
+    case dt if dt.isInstanceOf[NumericType] => TypeCheckResult.TypeCheckSuccess
+    case other => TypeCheckResult.TypeCheckFailure(
+      s"function $funcName requires numeric or interval types, not ${other.catalogString}")
+  }
+
   def getNumeric(t: DataType, exactNumericRequired: Boolean = false): Numeric[Any] = {
     if (exactNumericRequired) {
       t.asInstanceOf[NumericType].exactNumeric.asInstanceOf[Numeric[Any]]
```
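The extracted helper behaves exactly like the check previously inlined in `Sum`: numeric types, the two ANSI interval types, and `NULL` pass; everything else fails with a uniform message. Illustrative REPL calls (mine, not from the PR):

```scala
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._

TypeUtils.checkForAnsiIntervalOrNumericType(IntegerType, "average")
// TypeCheckSuccess
TypeUtils.checkForAnsiIntervalOrNumericType(YearMonthIntervalType, "average")
// TypeCheckSuccess
TypeUtils.checkForAnsiIntervalOrNumericType(BooleanType, "average")
// TypeCheckFailure: function average requires numeric or interval types, not boolean
```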

**ExpressionTypeCheckingSuite.scala**

```diff
@@ -159,7 +159,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
     assertError(Min(Symbol("mapField")), "min does not support ordering on type")
     assertError(Max(Symbol("mapField")), "max does not support ordering on type")
     assertError(Sum(Symbol("booleanField")), "function sum requires numeric or interval types")
-    assertError(Average(Symbol("booleanField")), "function average requires numeric type")
+    assertError(Average(Symbol("booleanField")),
+      "function average requires numeric or interval types")
   }
 
   test("check types for others") {
```

**DataFrameAggregateSuite.scala**

```diff
@@ -1151,6 +1151,44 @@ class DataFrameAggregateSuite extends QueryTest
     }
     assert(error2.toString contains "java.lang.ArithmeticException: long overflow")
   }
+
+  test("SPARK-34837: Support ANSI SQL intervals by the aggregate function `avg`") {
+    val df = Seq((1, Period.ofMonths(10), Duration.ofDays(10)),
+      (2, Period.ofMonths(1), Duration.ofDays(1)),
+      (2, null, null),
+      (3, Period.ofMonths(-3), Duration.ofDays(-6)),
+      (3, Period.ofMonths(21), Duration.ofDays(-5)))
+      .toDF("class", "year-month", "day-time")
+
+    val df2 = Seq((Period.ofMonths(Int.MaxValue), Duration.ofDays(106751991)),
+      (Period.ofMonths(10), Duration.ofDays(10)))
+      .toDF("year-month", "day-time")
+
+    val avgDF = df.select(avg($"year-month"), avg($"day-time"))
+    checkAnswer(avgDF, Row(Period.ofMonths(7), Duration.ofDays(0)))
+    assert(find(avgDF.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
+    assert(avgDF.schema == StructType(Seq(StructField("avg(year-month)", YearMonthIntervalType),
+      StructField("avg(day-time)", DayTimeIntervalType))))
+
+    val avgDF2 = df.groupBy($"class").agg(avg($"year-month"), avg($"day-time"))
+    checkAnswer(avgDF2, Row(1, Period.ofMonths(10), Duration.ofDays(10)) ::
+      Row(2, Period.ofMonths(1), Duration.ofDays(1)) ::
+      Row(3, Period.ofMonths(9), Duration.ofDays(-5).plusHours(-12)) :: Nil)
+    assert(find(avgDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
+    assert(avgDF2.schema == StructType(Seq(StructField("class", IntegerType, false),
+      StructField("avg(year-month)", YearMonthIntervalType),
+      StructField("avg(day-time)", DayTimeIntervalType))))
+
+    val error = intercept[SparkException] {
+      checkAnswer(df2.select(avg($"year-month")), Nil)
+    }
+    assert(error.toString contains "java.lang.ArithmeticException: integer overflow")
+
+    val error2 = intercept[SparkException] {
+      checkAnswer(df2.select(avg($"day-time")), Nil)
+    }
+    assert(error2.toString contains "java.lang.ArithmeticException: long overflow")
+  }
 }
 
 case class B(c: Option[Double])
```