diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 2cfa29899a..2f82c0738a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.test.SQLTestData.DecimalData import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.YearMonthIntervalType.{MONTH, YEAR} case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Double) @@ -1115,34 +1116,47 @@ class DataFrameAggregateSuite extends QueryTest } test("SPARK-34716: Support ANSI SQL intervals by the aggregate function `sum`") { - val df = Seq((1, Period.ofMonths(10), Duration.ofDays(10)), - (2, Period.ofMonths(1), Duration.ofDays(1)), - (2, null, null), - (3, Period.ofMonths(-3), Duration.ofDays(-6)), - (3, Period.ofMonths(21), Duration.ofDays(-5))) - .toDF("class", "year-month", "day-time") + val df = Seq( + (1, Period.ofMonths(10), Period.ofYears(8), Period.ofMonths(10), Duration.ofDays(10)), + (2, Period.ofMonths(1), Period.ofYears(1), Period.ofMonths(1), Duration.ofDays(1)), + (2, null, null, null, null), + (3, Period.ofMonths(-3), Period.ofYears(-12), Period.ofMonths(-3), Duration.ofDays(-6)), + (3, Period.ofMonths(21), Period.ofYears(30), Period.ofMonths(5), Duration.ofDays(-5))) + .toDF("class", "year-month", "year", "month", "day-time") + .select( + $"class", + $"year-month", + $"year" cast YearMonthIntervalType(YEAR) as "year", + $"month" cast YearMonthIntervalType(MONTH) as "month", + $"day-time") val df2 = Seq((Period.ofMonths(Int.MaxValue), Duration.ofDays(106751991)), (Period.ofMonths(10), Duration.ofDays(10))) .toDF("year-month", "day-time") - val sumDF = df.select(sum($"year-month"), sum($"day-time")) - checkAnswer(sumDF, Row(Period.of(2, 5, 0), Duration.ofDays(0))) + val sumDF = df.select(sum($"year-month"), sum($"year"), sum($"month"), sum($"day-time")) + checkAnswer(sumDF, + Row(Period.of(2, 5, 0), Period.ofYears(27), Period.of(1, 1, 0), Duration.ofDays(0))) assert(find(sumDF.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined) assert(sumDF.schema == StructType(Seq( - // TODO(SPARK-35775): Check all year-month interval types in aggregate expressions StructField("sum(year-month)", YearMonthIntervalType()), + StructField("sum(year)", YearMonthIntervalType(YEAR)), + StructField("sum(month)", YearMonthIntervalType(MONTH)), // TODO(SPARK-35729): Check all day-time interval types in aggregate expressions StructField("sum(day-time)", DayTimeIntervalType())))) - val sumDF2 = df.groupBy($"class").agg(sum($"year-month"), sum($"day-time")) - checkAnswer(sumDF2, Row(1, Period.ofMonths(10), Duration.ofDays(10)) :: - Row(2, Period.ofMonths(1), Duration.ofDays(1)) :: - Row(3, Period.of(1, 6, 0), Duration.ofDays(-11)) :: Nil) + val sumDF2 = + df.groupBy($"class").agg(sum($"year-month"), sum($"year"), sum($"month"), sum($"day-time")) + checkAnswer(sumDF2, + Row(1, Period.ofMonths(10), Period.ofYears(8), Period.ofMonths(10), Duration.ofDays(10)) :: + Row(2, Period.ofMonths(1), Period.ofYears(1), Period.ofMonths(1), Duration.ofDays(1)) :: + Row(3, Period.of(1, 6, 0), Period.ofYears(18), Period.ofMonths(2), Duration.ofDays(-11)) :: + Nil) assert(find(sumDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined) assert(sumDF2.schema == StructType(Seq(StructField("class", IntegerType, false), - // TODO(SPARK-35775): Check all year-month interval types in aggregate expressions StructField("sum(year-month)", YearMonthIntervalType()), + StructField("sum(year)", YearMonthIntervalType(YEAR)), + StructField("sum(month)", YearMonthIntervalType(MONTH)), // TODO(SPARK-35729): Check all day-time interval types in aggregate expressions StructField("sum(day-time)", DayTimeIntervalType())))) @@ -1158,34 +1172,48 @@ class DataFrameAggregateSuite extends QueryTest } test("SPARK-34837: Support ANSI SQL intervals by the aggregate function `avg`") { - val df = Seq((1, Period.ofMonths(10), Duration.ofDays(10)), - (2, Period.ofMonths(1), Duration.ofDays(1)), - (2, null, null), - (3, Period.ofMonths(-3), Duration.ofDays(-6)), - (3, Period.ofMonths(21), Duration.ofDays(-5))) - .toDF("class", "year-month", "day-time") + val df = Seq( + (1, Period.ofMonths(10), Period.ofYears(8), Period.ofMonths(10), Duration.ofDays(10)), + (2, Period.ofMonths(1), Period.ofYears(1), Period.ofMonths(1), Duration.ofDays(1)), + (2, null, null, null, null), + (3, Period.ofMonths(-3), Period.ofYears(-12), Period.ofMonths(-3), Duration.ofDays(-6)), + (3, Period.ofMonths(21), Period.ofYears(30), Period.ofMonths(5), Duration.ofDays(-5)), + (3, null, Period.ofYears(1), null, null)) + .toDF("class", "year-month", "year", "month", "day-time") + .select( + $"class", + $"year-month", + $"year" cast YearMonthIntervalType(YEAR) as "year", + $"month" cast YearMonthIntervalType(MONTH) as "month", + $"day-time") val df2 = Seq((Period.ofMonths(Int.MaxValue), Duration.ofDays(106751991)), (Period.ofMonths(10), Duration.ofDays(10))) .toDF("year-month", "day-time") - val avgDF = df.select(avg($"year-month"), avg($"day-time")) - checkAnswer(avgDF, Row(Period.ofMonths(7), Duration.ofDays(0))) + val avgDF = df.select(avg($"year-month"), avg($"year"), avg($"month"), avg($"day-time")) + checkAnswer(avgDF, + Row(Period.ofMonths(7), Period.of(5, 7, 0), Period.ofMonths(3), Duration.ofDays(0))) assert(find(avgDF.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined) assert(avgDF.schema == StructType(Seq( - // TODO(SPARK-35775): Check all year-month interval types in aggregate expressions StructField("avg(year-month)", YearMonthIntervalType()), + StructField("avg(year)", YearMonthIntervalType()), + StructField("avg(month)", YearMonthIntervalType()), // TODO(SPARK-35729): Check all day-time interval types in aggregate expressions StructField("avg(day-time)", DayTimeIntervalType())))) - val avgDF2 = df.groupBy($"class").agg(avg($"year-month"), avg($"day-time")) - checkAnswer(avgDF2, Row(1, Period.ofMonths(10), Duration.ofDays(10)) :: - Row(2, Period.ofMonths(1), Duration.ofDays(1)) :: - Row(3, Period.ofMonths(9), Duration.ofDays(-5).plusHours(-12)) :: Nil) + val avgDF2 = + df.groupBy($"class").agg(avg($"year-month"), avg($"year"), avg($"month"), avg($"day-time")) + checkAnswer(avgDF2, + Row(1, Period.ofMonths(10), Period.ofYears(8), Period.ofMonths(10), Duration.ofDays(10)) :: + Row(2, Period.ofMonths(1), Period.ofYears(1), Period.ofMonths(1), Duration.ofDays(1)) :: + Row(3, Period.ofMonths(9), Period.of(6, 4, 0), Period.ofMonths(1), + Duration.ofDays(-5).plusHours(-12)) :: Nil) assert(find(avgDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined) assert(avgDF2.schema == StructType(Seq(StructField("class", IntegerType, false), - // TODO(SPARK-35775): Check all year-month interval types in aggregate expressions StructField("avg(year-month)", YearMonthIntervalType()), + StructField("avg(year)", YearMonthIntervalType()), + StructField("avg(month)", YearMonthIntervalType()), // TODO(SPARK-35729): Check all day-time interval types in aggregate expressions StructField("avg(day-time)", DayTimeIntervalType()))))