[SPARK-34837][SQL] Support ANSI SQL intervals by the aggregate function avg
### What changes were proposed in this pull request? Extend the `Average` expression to support `DayTimeIntervalType` and `YearMonthIntervalType` added by #31614. Note: the expressions can throw the overflow exception independently from the SQL config `spark.sql.ansi.enabled`. In this way, the modified expressions always behave in the ANSI mode for the intervals. ### Why are the changes needed? Extend `org.apache.spark.sql.catalyst.expressions.aggregate.Average` to support `DayTimeIntervalType` and `YearMonthIntervalType`. ### Does this PR introduce _any_ user-facing change? 'No'. Should not since new types have not been released yet. ### How was this patch tested? Jenkins test Closes #32229 from beliefer/SPARK-34837. Authored-by: gengjiaan <gengjiaan@360.cn> Signed-off-by: Max Gekk <max.gekk@gmail.com>
This commit is contained in:
parent
70b606ffdd
commit
8dc455bba8
|
@ -40,10 +40,11 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
|
|||
|
||||
override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("avg")
|
||||
|
||||
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
|
||||
override def inputTypes: Seq[AbstractDataType] =
|
||||
Seq(TypeCollection(NumericType, YearMonthIntervalType, DayTimeIntervalType))
|
||||
|
||||
override def checkInputDataTypes(): TypeCheckResult =
|
||||
TypeUtils.checkForNumericExpr(child.dataType, "function average")
|
||||
TypeUtils.checkForAnsiIntervalOrNumericType(child.dataType, "average")
|
||||
|
||||
override def nullable: Boolean = true
|
||||
|
||||
|
@ -53,11 +54,15 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
|
|||
private lazy val resultType = child.dataType match {
|
||||
case DecimalType.Fixed(p, s) =>
|
||||
DecimalType.bounded(p + 4, s + 4)
|
||||
case _: YearMonthIntervalType => YearMonthIntervalType
|
||||
case _: DayTimeIntervalType => DayTimeIntervalType
|
||||
case _ => DoubleType
|
||||
}
|
||||
|
||||
private lazy val sumDataType = child.dataType match {
|
||||
case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s)
|
||||
case _: YearMonthIntervalType => YearMonthIntervalType
|
||||
case _: DayTimeIntervalType => DayTimeIntervalType
|
||||
case _ => DoubleType
|
||||
}
|
||||
|
||||
|
@ -82,6 +87,8 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
|
|||
case _: DecimalType =>
|
||||
DecimalPrecision.decimalAndDecimal(
|
||||
Divide(sum, count.cast(DecimalType.LongDecimal), failOnError = false)).cast(resultType)
|
||||
case _: YearMonthIntervalType => DivideYMInterval(sum, count)
|
||||
case _: DayTimeIntervalType => DivideDTInterval(sum, count)
|
||||
case _ =>
|
||||
Divide(sum.cast(resultType), count.cast(resultType), failOnError = false)
|
||||
}
|
||||
|
|
|
@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
|
|||
import org.apache.spark.sql.catalyst.dsl.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.trees.UnaryLike
|
||||
import org.apache.spark.sql.catalyst.util.TypeUtils
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
|
@ -48,12 +49,8 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
|
|||
override def inputTypes: Seq[AbstractDataType] =
|
||||
Seq(TypeCollection(NumericType, YearMonthIntervalType, DayTimeIntervalType))
|
||||
|
||||
override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
|
||||
case YearMonthIntervalType | DayTimeIntervalType | NullType => TypeCheckResult.TypeCheckSuccess
|
||||
case dt if dt.isInstanceOf[NumericType] => TypeCheckResult.TypeCheckSuccess
|
||||
case other => TypeCheckResult.TypeCheckFailure(
|
||||
s"function sum requires numeric or interval types, not ${other.catalogString}")
|
||||
}
|
||||
override def checkInputDataTypes(): TypeCheckResult =
|
||||
TypeUtils.checkForAnsiIntervalOrNumericType(child.dataType, "sum")
|
||||
|
||||
private lazy val resultType = child.dataType match {
|
||||
case DecimalType.Fixed(precision, scale) =>
|
||||
|
|
|
@ -61,6 +61,14 @@ object TypeUtils {
|
|||
}
|
||||
}
|
||||
|
||||
def checkForAnsiIntervalOrNumericType(
|
||||
dt: DataType, funcName: String): TypeCheckResult = dt match {
|
||||
case YearMonthIntervalType | DayTimeIntervalType | NullType => TypeCheckResult.TypeCheckSuccess
|
||||
case dt if dt.isInstanceOf[NumericType] => TypeCheckResult.TypeCheckSuccess
|
||||
case other => TypeCheckResult.TypeCheckFailure(
|
||||
s"function $funcName requires numeric or interval types, not ${other.catalogString}")
|
||||
}
|
||||
|
||||
def getNumeric(t: DataType, exactNumericRequired: Boolean = false): Numeric[Any] = {
|
||||
if (exactNumericRequired) {
|
||||
t.asInstanceOf[NumericType].exactNumeric.asInstanceOf[Numeric[Any]]
|
||||
|
|
|
@ -159,7 +159,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
|
|||
assertError(Min(Symbol("mapField")), "min does not support ordering on type")
|
||||
assertError(Max(Symbol("mapField")), "max does not support ordering on type")
|
||||
assertError(Sum(Symbol("booleanField")), "function sum requires numeric or interval types")
|
||||
assertError(Average(Symbol("booleanField")), "function average requires numeric type")
|
||||
assertError(Average(Symbol("booleanField")),
|
||||
"function average requires numeric or interval types")
|
||||
}
|
||||
|
||||
test("check types for others") {
|
||||
|
|
|
@ -1151,6 +1151,44 @@ class DataFrameAggregateSuite extends QueryTest
|
|||
}
|
||||
assert(error2.toString contains "java.lang.ArithmeticException: long overflow")
|
||||
}
|
||||
|
||||
test("SPARK-34837: Support ANSI SQL intervals by the aggregate function `avg`") {
|
||||
val df = Seq((1, Period.ofMonths(10), Duration.ofDays(10)),
|
||||
(2, Period.ofMonths(1), Duration.ofDays(1)),
|
||||
(2, null, null),
|
||||
(3, Period.ofMonths(-3), Duration.ofDays(-6)),
|
||||
(3, Period.ofMonths(21), Duration.ofDays(-5)))
|
||||
.toDF("class", "year-month", "day-time")
|
||||
|
||||
val df2 = Seq((Period.ofMonths(Int.MaxValue), Duration.ofDays(106751991)),
|
||||
(Period.ofMonths(10), Duration.ofDays(10)))
|
||||
.toDF("year-month", "day-time")
|
||||
|
||||
val avgDF = df.select(avg($"year-month"), avg($"day-time"))
|
||||
checkAnswer(avgDF, Row(Period.ofMonths(7), Duration.ofDays(0)))
|
||||
assert(find(avgDF.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
|
||||
assert(avgDF.schema == StructType(Seq(StructField("avg(year-month)", YearMonthIntervalType),
|
||||
StructField("avg(day-time)", DayTimeIntervalType))))
|
||||
|
||||
val avgDF2 = df.groupBy($"class").agg(avg($"year-month"), avg($"day-time"))
|
||||
checkAnswer(avgDF2, Row(1, Period.ofMonths(10), Duration.ofDays(10)) ::
|
||||
Row(2, Period.ofMonths(1), Duration.ofDays(1)) ::
|
||||
Row(3, Period.ofMonths(9), Duration.ofDays(-5).plusHours(-12)) ::Nil)
|
||||
assert(find(avgDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
|
||||
assert(avgDF2.schema == StructType(Seq(StructField("class", IntegerType, false),
|
||||
StructField("avg(year-month)", YearMonthIntervalType),
|
||||
StructField("avg(day-time)", DayTimeIntervalType))))
|
||||
|
||||
val error = intercept[SparkException] {
|
||||
checkAnswer(df2.select(avg($"year-month")), Nil)
|
||||
}
|
||||
assert(error.toString contains "java.lang.ArithmeticException: integer overflow")
|
||||
|
||||
val error2 = intercept[SparkException] {
|
||||
checkAnswer(df2.select(avg($"day-time")), Nil)
|
||||
}
|
||||
assert(error2.toString contains "java.lang.ArithmeticException: long overflow")
|
||||
}
|
||||
}
|
||||
|
||||
case class B(c: Option[Double])
|
||||
|
|
Loading…
Reference in a new issue