diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 2f82c0738a..9cd743602b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -34,6 +34,7 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.test.SQLTestData.DecimalData
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.DayTimeIntervalType.{DAY, HOUR, MINUTE, SECOND}
 import org.apache.spark.sql.types.YearMonthIntervalType.{MONTH, YEAR}
 
 case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Double)
@@ -1116,122 +1117,278 @@ class DataFrameAggregateSuite extends QueryTest
   }
 
   test("SPARK-34716: Support ANSI SQL intervals by the aggregate function `sum`") {
-    val df = Seq(
-      (1, Period.ofMonths(10), Period.ofYears(8), Period.ofMonths(10), Duration.ofDays(10)),
-      (2, Period.ofMonths(1), Period.ofYears(1), Period.ofMonths(1), Duration.ofDays(1)),
-      (2, null, null, null, null),
-      (3, Period.ofMonths(-3), Period.ofYears(-12), Period.ofMonths(-3), Duration.ofDays(-6)),
-      (3, Period.ofMonths(21), Period.ofYears(30), Period.ofMonths(5), Duration.ofDays(-5)))
-      .toDF("class", "year-month", "year", "month", "day-time")
-      .select(
-        $"class",
-        $"year-month",
-        $"year" cast YearMonthIntervalType(YEAR) as "year",
-        $"month" cast YearMonthIntervalType(MONTH) as "month",
-        $"day-time")
-
-    val df2 = Seq((Period.ofMonths(Int.MaxValue), Duration.ofDays(106751991)),
-      (Period.ofMonths(10), Duration.ofDays(10)))
-      .toDF("year-month", "day-time")
-
-    val sumDF = df.select(sum($"year-month"), sum($"year"), sum($"month"), sum($"day-time"))
+    val sumDF = intervalData.select(
+      sum($"year-month"),
+      sum($"year"),
+      sum($"month"),
+      sum($"day-second"),
+      sum($"day-minute"),
+      sum($"day-hour"),
+      sum($"day"),
+      sum($"hour-second"),
+      sum($"hour-minute"),
+      sum($"hour"),
+      sum($"minute-second"),
+      sum($"minute"),
+      sum($"second"))
     checkAnswer(sumDF,
-      Row(Period.of(2, 5, 0), Period.ofYears(27), Period.of(1, 1, 0), Duration.ofDays(0)))
+      Row(
+        Period.of(2, 5, 0),
+        Period.ofYears(28),
+        Period.of(1, 1, 0),
+        Duration.ofDays(9).plusHours(23).plusMinutes(29).plusSeconds(4),
+        Duration.ofDays(23).plusHours(8).plusMinutes(27),
+        Duration.ofDays(-8).plusHours(-7),
+        Duration.ofDays(1),
+        Duration.ofDays(1).plusHours(12).plusMinutes(2).plusSeconds(33),
+        Duration.ofMinutes(43),
+        Duration.ofHours(12),
+        Duration.ofMinutes(18).plusSeconds(3),
+        Duration.ofMinutes(52),
+        Duration.ofSeconds(20)))
     assert(find(sumDF.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
     assert(sumDF.schema == StructType(Seq(
       StructField("sum(year-month)", YearMonthIntervalType()),
       StructField("sum(year)", YearMonthIntervalType(YEAR)),
       StructField("sum(month)", YearMonthIntervalType(MONTH)),
-      // TODO(SPARK-35729): Check all day-time interval types in aggregate expressions
-      StructField("sum(day-time)", DayTimeIntervalType()))))
+      StructField("sum(day-second)", DayTimeIntervalType()),
+      StructField("sum(day-minute)", DayTimeIntervalType(DAY, MINUTE)),
+      StructField("sum(day-hour)", DayTimeIntervalType(DAY, HOUR)),
+      StructField("sum(day)", DayTimeIntervalType(DAY)),
+      StructField("sum(hour-second)", DayTimeIntervalType(HOUR, SECOND)),
+      StructField("sum(hour-minute)", DayTimeIntervalType(HOUR, MINUTE)),
+      StructField("sum(hour)", DayTimeIntervalType(HOUR)),
+      StructField("sum(minute-second)", DayTimeIntervalType(MINUTE, SECOND)),
+      StructField("sum(minute)", DayTimeIntervalType(MINUTE)),
+      StructField("sum(second)", DayTimeIntervalType(SECOND)))))
 
     val sumDF2 =
-      df.groupBy($"class").agg(sum($"year-month"), sum($"year"), sum($"month"), sum($"day-time"))
+      intervalData.groupBy($"class").agg(
+        sum($"year-month"),
+        sum($"year"),
+        sum($"month"),
+        sum($"day-second"),
+        sum($"day-minute"),
+        sum($"day-hour"),
+        sum($"day"),
+        sum($"hour-second"),
+        sum($"hour-minute"),
+        sum($"hour"),
+        sum($"minute-second"),
+        sum($"minute"),
+        sum($"second"))
     checkAnswer(sumDF2,
-      Row(1, Period.ofMonths(10), Period.ofYears(8), Period.ofMonths(10), Duration.ofDays(10)) ::
-      Row(2, Period.ofMonths(1), Period.ofYears(1), Period.ofMonths(1), Duration.ofDays(1)) ::
-      Row(3, Period.of(1, 6, 0), Period.ofYears(18), Period.ofMonths(2), Duration.ofDays(-11)) ::
+      Row(1,
+        Period.ofMonths(10),
+        Period.ofYears(8),
+        Period.ofMonths(10),
+        Duration.ofDays(7).plusHours(13).plusMinutes(3).plusSeconds(18),
+        Duration.ofDays(5).plusHours(21).plusMinutes(12),
+        Duration.ofDays(1).plusHours(8),
+        Duration.ofDays(10),
+        Duration.ofHours(20).plusMinutes(11).plusSeconds(33),
+        Duration.ofHours(3).plusMinutes(18),
+        Duration.ofHours(13),
+        Duration.ofMinutes(2).plusSeconds(59),
+        Duration.ofMinutes(38),
+        Duration.ofSeconds(5)) ::
+      Row(2,
+        Period.ofMonths(1),
+        Period.ofYears(1),
+        Period.ofMonths(1),
+        Duration.ofSeconds(1),
+        Duration.ofMinutes(1),
+        Duration.ofHours(1),
+        Duration.ofDays(1),
+        Duration.ofSeconds(1),
+        Duration.ofMinutes(1),
+        Duration.ofHours(1),
+        Duration.ofSeconds(1),
+        Duration.ofMinutes(1),
+        Duration.ofSeconds(1)) ::
+      Row(3,
+        Period.of(1, 6, 0),
+        Period.ofYears(19),
+        Period.ofMonths(2),
+        Duration.ofDays(2).plusHours(10).plusMinutes(25).plusSeconds(45),
+        Duration.ofDays(17).plusHours(11).plusMinutes(14),
+        Duration.ofDays(-9).plusHours(-16),
+        Duration.ofDays(-10),
+        Duration.ofHours(15).plusMinutes(50).plusSeconds(59),
+        Duration.ofHours(-2).plusMinutes(-36),
+        Duration.ofHours(-2),
+        Duration.ofMinutes(15).plusSeconds(3),
+        Duration.ofMinutes(13),
+        Duration.ofSeconds(14)) ::
       Nil)
     assert(find(sumDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
     assert(sumDF2.schema == StructType(Seq(StructField("class", IntegerType, false),
       StructField("sum(year-month)", YearMonthIntervalType()),
       StructField("sum(year)", YearMonthIntervalType(YEAR)),
       StructField("sum(month)", YearMonthIntervalType(MONTH)),
-      // TODO(SPARK-35729): Check all day-time interval types in aggregate expressions
-      StructField("sum(day-time)", DayTimeIntervalType()))))
+      StructField("sum(day-second)", DayTimeIntervalType()),
+      StructField("sum(day-minute)", DayTimeIntervalType(DAY, MINUTE)),
+      StructField("sum(day-hour)", DayTimeIntervalType(DAY, HOUR)),
+      StructField("sum(day)", DayTimeIntervalType(DAY)),
+      StructField("sum(hour-second)", DayTimeIntervalType(HOUR, SECOND)),
+      StructField("sum(hour-minute)", DayTimeIntervalType(HOUR, MINUTE)),
+      StructField("sum(hour)", DayTimeIntervalType(HOUR)),
+      StructField("sum(minute-second)", DayTimeIntervalType(MINUTE, SECOND)),
+      StructField("sum(minute)", DayTimeIntervalType(MINUTE)),
+      StructField("sum(second)", DayTimeIntervalType(SECOND)))))
 
+    val df2 = Seq((Period.ofMonths(Int.MaxValue), Duration.ofDays(106751991)),
+      (Period.ofMonths(10), Duration.ofDays(10)))
+      .toDF("year-month", "day")
     val error = intercept[SparkException] {
       checkAnswer(df2.select(sum($"year-month")), Nil)
     }
     assert(error.toString contains "java.lang.ArithmeticException: integer overflow")
     val error2 = intercept[SparkException] {
-      checkAnswer(df2.select(sum($"day-time")), Nil)
+      checkAnswer(df2.select(sum($"day")), Nil)
     }
     assert(error2.toString contains "java.lang.ArithmeticException: long overflow")
   }
 
   test("SPARK-34837: Support ANSI SQL intervals by the aggregate function `avg`") {
-    val df = Seq(
-      (1, Period.ofMonths(10), Period.ofYears(8), Period.ofMonths(10), Duration.ofDays(10)),
-      (2, Period.ofMonths(1), Period.ofYears(1), Period.ofMonths(1), Duration.ofDays(1)),
-      (2, null, null, null, null),
-      (3, Period.ofMonths(-3), Period.ofYears(-12), Period.ofMonths(-3), Duration.ofDays(-6)),
-      (3, Period.ofMonths(21), Period.ofYears(30), Period.ofMonths(5), Duration.ofDays(-5)),
-      (3, null, Period.ofYears(1), null, null))
-      .toDF("class", "year-month", "year", "month", "day-time")
-      .select(
-        $"class",
-        $"year-month",
-        $"year" cast YearMonthIntervalType(YEAR) as "year",
-        $"month" cast YearMonthIntervalType(MONTH) as "month",
-        $"day-time")
-
-    val df2 = Seq((Period.ofMonths(Int.MaxValue), Duration.ofDays(106751991)),
-      (Period.ofMonths(10), Duration.ofDays(10)))
-      .toDF("year-month", "day-time")
-
-    val avgDF = df.select(avg($"year-month"), avg($"year"), avg($"month"), avg($"day-time"))
+    val avgDF = intervalData.select(
+      avg($"year-month"),
+      avg($"year"),
+      avg($"month"),
+      avg($"day-second"),
+      avg($"day-minute"),
+      avg($"day-hour"),
+      avg($"day"),
+      avg($"hour-second"),
+      avg($"hour-minute"),
+      avg($"hour"),
+      avg($"minute-second"),
+      avg($"minute"),
+      avg($"second"))
     checkAnswer(avgDF,
-      Row(Period.ofMonths(7), Period.of(5, 7, 0), Period.ofMonths(3), Duration.ofDays(0)))
+      Row(Period.ofMonths(7),
+        Period.of(5, 7, 0),
+        Period.ofMonths(3),
+        Duration.ofDays(2).plusHours(11).plusMinutes(52).plusSeconds(16),
+        Duration.ofDays(4).plusHours(16).plusMinutes(5).plusSeconds(24),
+        Duration.ofDays(-1).plusHours(-15).plusMinutes(-48),
+        Duration.ofHours(4).plusMinutes(48),
+        Duration.ofHours(9).plusSeconds(38).plusMillis(250),
+        Duration.ofMinutes(8).plusSeconds(36),
+        Duration.ofHours(2).plusMinutes(24),
+        Duration.ofMinutes(4).plusSeconds(30).plusMillis(750),
+        Duration.ofMinutes(10).plusSeconds(24),
+        Duration.ofSeconds(5)))
     assert(find(avgDF.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
     assert(avgDF.schema == StructType(Seq(
       StructField("avg(year-month)", YearMonthIntervalType()),
       StructField("avg(year)", YearMonthIntervalType()),
       StructField("avg(month)", YearMonthIntervalType()),
-      // TODO(SPARK-35729): Check all day-time interval types in aggregate expressions
-      StructField("avg(day-time)", DayTimeIntervalType()))))
+      StructField("avg(day-second)", DayTimeIntervalType()),
+      StructField("avg(day-minute)", DayTimeIntervalType()),
+      StructField("avg(day-hour)", DayTimeIntervalType()),
+      StructField("avg(day)", DayTimeIntervalType()),
+      StructField("avg(hour-second)", DayTimeIntervalType()),
+      StructField("avg(hour-minute)", DayTimeIntervalType()),
+      StructField("avg(hour)", DayTimeIntervalType()),
+      StructField("avg(minute-second)", DayTimeIntervalType()),
+      StructField("avg(minute)", DayTimeIntervalType()),
+      StructField("avg(second)", DayTimeIntervalType()))))
 
     val avgDF2 =
-      df.groupBy($"class").agg(avg($"year-month"), avg($"year"), avg($"month"), avg($"day-time"))
+      intervalData.groupBy($"class").agg(
+        avg($"year-month"),
+        avg($"year"),
+        avg($"month"),
+        avg($"day-second"),
+        avg($"day-minute"),
+        avg($"day-hour"),
+        avg($"day"),
+        avg($"hour-second"),
+        avg($"hour-minute"),
+        avg($"hour"),
+        avg($"minute-second"),
+        avg($"minute"),
+        avg($"second"))
     checkAnswer(avgDF2,
-      Row(1, Period.ofMonths(10), Period.ofYears(8), Period.ofMonths(10), Duration.ofDays(10)) ::
-      Row(2, Period.ofMonths(1), Period.ofYears(1), Period.ofMonths(1), Duration.ofDays(1)) ::
-      Row(3, Period.ofMonths(9), Period.of(6, 4, 0), Period.ofMonths(1),
-        Duration.ofDays(-5).plusHours(-12)) :: Nil)
+      Row(1,
+        Period.ofMonths(10),
+        Period.ofYears(8),
+        Period.ofMonths(10),
+        Duration.ofDays(7).plusHours(13).plusMinutes(3).plusSeconds(18),
+        Duration.ofDays(5).plusHours(21).plusMinutes(12),
+        Duration.ofDays(1).plusHours(8),
+        Duration.ofDays(10),
+        Duration.ofHours(20).plusMinutes(11).plusSeconds(33),
+        Duration.ofHours(3).plusMinutes(18),
+        Duration.ofHours(13),
+        Duration.ofMinutes(2).plusSeconds(59),
+        Duration.ofMinutes(38),
+        Duration.ofSeconds(5)) ::
+      Row(2,
+        Period.ofMonths(1),
+        Period.ofYears(1),
+        Period.ofMonths(1),
+        Duration.ofSeconds(1),
+        Duration.ofMinutes(1),
+        Duration.ofHours(1),
+        Duration.ofDays(1),
+        Duration.ofSeconds(1),
+        Duration.ofMinutes(1),
+        Duration.ofHours(1),
+        Duration.ofSeconds(1),
+        Duration.ofMinutes(1),
+        Duration.ofSeconds(1)) ::
+      Row(3,
+        Period.ofMonths(9),
+        Period.of(6, 4, 0),
+        Period.ofMonths(1),
+        Duration.ofDays(1).plusHours(5).plusMinutes(12).plusSeconds(52).plusMillis(500),
+        Duration.ofDays(5).plusHours(19).plusMinutes(44).plusSeconds(40),
+        Duration.ofDays(-3).plusHours(-5).plusMinutes(-20),
+        Duration.ofDays(-3).plusHours(-8),
+        Duration.ofHours(7).plusMinutes(55).plusSeconds(29).plusMillis(500),
+        Duration.ofMinutes(-52),
+        Duration.ofMinutes(-40),
+        Duration.ofMinutes(7).plusSeconds(31).plusMillis(500),
+        Duration.ofMinutes(4).plusSeconds(20),
+        Duration.ofSeconds(7)) :: Nil)
     assert(find(avgDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
-    assert(avgDF2.schema == StructType(Seq(StructField("class", IntegerType, false),
+    assert(avgDF2.schema == StructType(Seq(
+      StructField("class", IntegerType, false),
       StructField("avg(year-month)", YearMonthIntervalType()),
       StructField("avg(year)", YearMonthIntervalType()),
       StructField("avg(month)", YearMonthIntervalType()),
-      // TODO(SPARK-35729): Check all day-time interval types in aggregate expressions
-      StructField("avg(day-time)", DayTimeIntervalType()))))
+      StructField("avg(day-second)", DayTimeIntervalType()),
+      StructField("avg(day-minute)", DayTimeIntervalType()),
+      StructField("avg(day-hour)", DayTimeIntervalType()),
+      StructField("avg(day)", DayTimeIntervalType()),
+      StructField("avg(hour-second)", DayTimeIntervalType()),
+      StructField("avg(hour-minute)", DayTimeIntervalType()),
+      StructField("avg(hour)", DayTimeIntervalType()),
+      StructField("avg(minute-second)", DayTimeIntervalType()),
+      StructField("avg(minute)", DayTimeIntervalType()),
+      StructField("avg(second)", DayTimeIntervalType()))))
 
+    val df2 = Seq((Period.ofMonths(Int.MaxValue), Duration.ofDays(106751991)),
+      (Period.ofMonths(10), Duration.ofDays(10)))
+      .toDF("year-month", "day")
     val error = intercept[SparkException] {
       checkAnswer(df2.select(avg($"year-month")), Nil)
     }
     assert(error.toString contains "java.lang.ArithmeticException: integer overflow")
     val error2 = intercept[SparkException] {
-      checkAnswer(df2.select(avg($"day-time")), Nil)
+      checkAnswer(df2.select(avg($"day")), Nil)
    }
     assert(error2.toString contains "java.lang.ArithmeticException: long overflow")
 
-    val df3 = df.filter($"class" > 4)
-    val avgDF3 =
-      df3.select(avg($"year-month"), avg($"day-time"))
+    val df3 = intervalData.filter($"class" > 4)
+    val avgDF3 = df3.select(avg($"year-month"), avg($"day"))
     checkAnswer(avgDF3, Row(null, null) :: Nil)
 
-    val avgDF4 = df3.groupBy($"class").agg(avg($"year-month"), avg($"day-time"))
+    val avgDF4 = df3.groupBy($"class").agg(avg($"year-month"), avg($"day"))
     checkAnswer(avgDF4, Nil)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
index a1fd4a0215..307c4f33b2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -18,9 +18,13 @@
 package org.apache.spark.sql.test
 
 import java.nio.charset.StandardCharsets
+import java.time.{Duration, Period}
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext, SQLImplicits}
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.DayTimeIntervalType.{DAY, HOUR, MINUTE, SECOND}
+import org.apache.spark.sql.types.YearMonthIntervalType.{MONTH, YEAR}
 import org.apache.spark.unsafe.types.CalendarInterval
 
 /**
@@ -287,6 +291,108 @@ private[sql] trait SQLTestData { self =>
     df
   }
 
+  protected lazy val intervalData: DataFrame = Seq(
+    (1,
+      Period.ofMonths(10),
+      Period.ofYears(8),
+      Period.ofMonths(10),
+      Duration.ofDays(7).plusHours(13).plusMinutes(3).plusSeconds(18),
+      Duration.ofDays(5).plusHours(21).plusMinutes(12),
+      Duration.ofDays(1).plusHours(8),
+      Duration.ofDays(10),
+      Duration.ofHours(20).plusMinutes(11).plusSeconds(33),
+      Duration.ofHours(3).plusMinutes(18),
+      Duration.ofHours(13),
+      Duration.ofMinutes(2).plusSeconds(59),
+      Duration.ofMinutes(38),
+      Duration.ofSeconds(5)),
+    (2,
+      Period.ofMonths(1),
+      Period.ofYears(1),
+      Period.ofMonths(1),
+      Duration.ofSeconds(1),
+      Duration.ofMinutes(1),
+      Duration.ofHours(1),
+      Duration.ofDays(1),
+      Duration.ofSeconds(1),
+      Duration.ofMinutes(1),
+      Duration.ofHours(1),
+      Duration.ofSeconds(1),
+      Duration.ofMinutes(1),
+      Duration.ofSeconds(1)),
+    (2, null, null, null, null, null, null, null, null, null, null, null, null, null),
+    (3,
+      Period.ofMonths(-3),
+      Period.ofYears(-12),
+      Period.ofMonths(-3),
+      Duration.ofDays(-8).plusHours(-21).plusMinutes(-10).plusSeconds(-32),
+      Duration.ofDays(-2).plusHours(-1).plusMinutes(-12),
+      Duration.ofDays(-11).plusHours(-7),
+      Duration.ofDays(-6),
+      Duration.ofHours(-6).plusMinutes(-17).plusSeconds(-38),
+      Duration.ofHours(-12).plusMinutes(-53),
+      Duration.ofHours(-8),
+      Duration.ofMinutes(-30).plusSeconds(-2),
+      Duration.ofMinutes(-15),
+      Duration.ofSeconds(-36)),
+    (3,
+      Period.ofMonths(21),
+      Period.ofYears(30),
+      Period.ofMonths(5),
+      Duration.ofDays(11).plusHours(7).plusMinutes(36).plusSeconds(17),
+      Duration.ofDays(19).plusHours(12).plusMinutes(25),
+      Duration.ofDays(1).plusHours(14),
+      Duration.ofDays(-5),
+      Duration.ofHours(22).plusMinutes(8).plusSeconds(37),
+      Duration.ofHours(10).plusMinutes(16),
+      Duration.ofHours(5),
+      Duration.ofMinutes(45).plusSeconds(5),
+      Duration.ofMinutes(27),
+      Duration.ofSeconds(50)),
+    (3,
+      null,
+      Period.ofYears(1),
+      null,
+      null,
+      Duration.ofMinutes(1),
+      Duration.ofHours(1),
+      Duration.ofDays(1),
+      null,
+      Duration.ofMinutes(1),
+      Duration.ofHours(1),
+      null,
+      Duration.ofMinutes(1),
+      null))
+    .toDF("class",
+      "year-month",
+      "year",
+      "month",
+      "day-second",
+      "day-minute",
+      "day-hour",
+      "day",
+      "hour-second",
+      "hour-minute",
+      "hour",
+      "minute-second",
+      "minute",
+      "second")
+    .select(
+      $"class",
+      $"year-month",
+      $"year" cast YearMonthIntervalType(YEAR) as "year",
+      $"month" cast YearMonthIntervalType(MONTH) as "month",
+      $"day-second",
+      $"day-minute" cast DayTimeIntervalType(DAY, MINUTE) as "day-minute",
+      $"day-hour" cast DayTimeIntervalType(DAY, HOUR) as "day-hour",
+      $"day" cast DayTimeIntervalType(DAY) as "day",
+      $"hour-second" cast DayTimeIntervalType(HOUR, SECOND) as "hour-second",
+      $"hour-minute" cast DayTimeIntervalType(HOUR, MINUTE) as "hour-minute",
+      $"hour" cast DayTimeIntervalType(HOUR) as "hour",
+      $"minute-second" cast DayTimeIntervalType(MINUTE, SECOND) as "minute-second",
+      $"minute" cast DayTimeIntervalType(MINUTE) as "minute",
+      $"second" cast DayTimeIntervalType(SECOND) as "second")
+
   /**
    * Initialize all test data such that all temp tables are properly registered.
    */
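
A note on the `sum` schema assertions above: `sum` preserves the exact interval type of its input, so each narrowed day-time column yields the matching `DayTimeIntervalType(start, end)` result field. A minimal standalone sketch of that behaviour (assuming an active `SparkSession` named `spark`; the data and column name are illustrative, not taken from the patch):

    import java.time.Duration
    import org.apache.spark.sql.functions.sum
    import org.apache.spark.sql.types.DayTimeIntervalType
    import org.apache.spark.sql.types.DayTimeIntervalType.{DAY, MINUTE}

    import spark.implicits._  // `spark` is an assumed active SparkSession

    // Durations are encoded as the default DAY TO SECOND interval; the cast
    // narrows the column to DAY TO MINUTE before aggregating.
    val df = Seq(Duration.ofDays(5).plusHours(21).plusMinutes(12), Duration.ofMinutes(1))
      .toDF("d")
      .select($"d" cast DayTimeIntervalType(DAY, MINUTE) as "day-minute")

    // As the test asserts, the result keeps the narrowed type
    // (INTERVAL DAY TO MINUTE) instead of widening back to DAY TO SECOND.
    val summed = df.select(sum($"day-minute"))
    assert(summed.schema.head.dataType == DayTimeIntervalType(DAY, MINUTE))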
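By contrast, the `avg` assertions expect the default `DayTimeIntervalType()` for every day-time input, and the `df2` blocks exercise overflow of the underlying primitives (months held in an `Int`, microseconds in a `Long`). A sketch of both behaviours, under the same assumed `spark` session:

    import java.time.Duration
    import org.apache.spark.SparkException
    import org.apache.spark.sql.functions.{avg, sum}
    import org.apache.spark.sql.types.DayTimeIntervalType
    import org.apache.spark.sql.types.DayTimeIntervalType.HOUR

    import spark.implicits._  // `spark` is an assumed active SparkSession

    // avg widens any day-time interval input to the default DAY TO SECOND
    // result type, matching the avg(...) schema assertions in the patch.
    val hourly = Seq(Duration.ofHours(1), Duration.ofHours(2))
      .toDF("d")
      .select($"d" cast DayTimeIntervalType(HOUR) as "hour")
    assert(hourly.select(avg($"hour")).schema.head.dataType == DayTimeIntervalType())

    // Long.MaxValue microseconds is roughly 106751991 days, so adding ten
    // more days overflows the long accumulator, as `error2` asserts above.
    val nearMax = Seq(Duration.ofDays(106751991), Duration.ofDays(10)).toDF("day")
    try nearMax.select(sum($"day")).collect()
    catch {
      case e: SparkException =>
        // The cause chain ends in java.lang.ArithmeticException: long overflow.
        assert(e.toString contains "long overflow")
    }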
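The `intervalData` fixture reaches the narrowed column types the same way the old inline data did: the Scala encoders only produce `YearMonthIntervalType()` for `Period` and `DayTimeIntervalType()` for `Duration`, so every other start/end combination comes from an explicit cast in the trailing `.select`. The same pattern in isolation (again assuming an active `spark` session; names are illustrative):

    import java.time.{Duration, Period}
    import org.apache.spark.sql.types.{DayTimeIntervalType, YearMonthIntervalType}
    import org.apache.spark.sql.types.DayTimeIntervalType.{DAY, HOUR}
    import org.apache.spark.sql.types.YearMonthIntervalType.YEAR

    import spark.implicits._  // `spark` is an assumed active SparkSession

    // Encoders map Period to INTERVAL YEAR TO MONTH and Duration to
    // INTERVAL DAY TO SECOND; narrower types require an explicit cast.
    val narrowed = Seq((Period.ofYears(8), Duration.ofDays(1).plusHours(8)))
      .toDF("p", "d")
      .select(
        $"p" cast YearMonthIntervalType(YEAR) as "year",
        $"d" cast DayTimeIntervalType(DAY, HOUR) as "day-hour")

    narrowed.printSchema()
    // root
    //  |-- year: interval year (nullable = true)
    //  |-- day-hour: interval day to hour (nullable = true)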