[SPARK-35775][SQL][TESTS] Check all year-month interval types in aggregate expressions

### What changes were proposed in this pull request?

This PR adds tests to check that `sum` and `avg` work with all the `YearMonthInterval` types.

### Why are the changes needed?

To ensure that the results of these aggregations are as expected.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

By extending the existing `sum` and `avg` interval tests in `DataFrameAggregateSuite`.

Closes #32988 from sarutak/check-interval-agg-ym.

Authored-by: Kousuke Saruta <sarutak@oss.nttdata.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
This commit is contained in:
Kousuke Saruta 2021-06-21 16:47:29 +03:00 committed by Max Gekk
parent 844f10c742
commit 2c91672259

View file

@ -34,6 +34,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SQLTestData.DecimalData
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.YearMonthIntervalType.{MONTH, YEAR}
case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Double)
@ -1115,34 +1116,47 @@ class DataFrameAggregateSuite extends QueryTest
}
test("SPARK-34716: Support ANSI SQL intervals by the aggregate function `sum`") {
val df = Seq((1, Period.ofMonths(10), Duration.ofDays(10)),
(2, Period.ofMonths(1), Duration.ofDays(1)),
(2, null, null),
(3, Period.ofMonths(-3), Duration.ofDays(-6)),
(3, Period.ofMonths(21), Duration.ofDays(-5)))
.toDF("class", "year-month", "day-time")
val df = Seq(
(1, Period.ofMonths(10), Period.ofYears(8), Period.ofMonths(10), Duration.ofDays(10)),
(2, Period.ofMonths(1), Period.ofYears(1), Period.ofMonths(1), Duration.ofDays(1)),
(2, null, null, null, null),
(3, Period.ofMonths(-3), Period.ofYears(-12), Period.ofMonths(-3), Duration.ofDays(-6)),
(3, Period.ofMonths(21), Period.ofYears(30), Period.ofMonths(5), Duration.ofDays(-5)))
.toDF("class", "year-month", "year", "month", "day-time")
.select(
$"class",
$"year-month",
$"year" cast YearMonthIntervalType(YEAR) as "year",
$"month" cast YearMonthIntervalType(MONTH) as "month",
$"day-time")
val df2 = Seq((Period.ofMonths(Int.MaxValue), Duration.ofDays(106751991)),
(Period.ofMonths(10), Duration.ofDays(10)))
.toDF("year-month", "day-time")
val sumDF = df.select(sum($"year-month"), sum($"day-time"))
checkAnswer(sumDF, Row(Period.of(2, 5, 0), Duration.ofDays(0)))
val sumDF = df.select(sum($"year-month"), sum($"year"), sum($"month"), sum($"day-time"))
checkAnswer(sumDF,
Row(Period.of(2, 5, 0), Period.ofYears(27), Period.of(1, 1, 0), Duration.ofDays(0)))
assert(find(sumDF.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
assert(sumDF.schema == StructType(Seq(
// TODO(SPARK-35775): Check all year-month interval types in aggregate expressions
StructField("sum(year-month)", YearMonthIntervalType()),
StructField("sum(year)", YearMonthIntervalType(YEAR)),
StructField("sum(month)", YearMonthIntervalType(MONTH)),
// TODO(SPARK-35729): Check all day-time interval types in aggregate expressions
StructField("sum(day-time)", DayTimeIntervalType()))))
val sumDF2 = df.groupBy($"class").agg(sum($"year-month"), sum($"day-time"))
checkAnswer(sumDF2, Row(1, Period.ofMonths(10), Duration.ofDays(10)) ::
Row(2, Period.ofMonths(1), Duration.ofDays(1)) ::
Row(3, Period.of(1, 6, 0), Duration.ofDays(-11)) :: Nil)
val sumDF2 =
df.groupBy($"class").agg(sum($"year-month"), sum($"year"), sum($"month"), sum($"day-time"))
checkAnswer(sumDF2,
Row(1, Period.ofMonths(10), Period.ofYears(8), Period.ofMonths(10), Duration.ofDays(10)) ::
Row(2, Period.ofMonths(1), Period.ofYears(1), Period.ofMonths(1), Duration.ofDays(1)) ::
Row(3, Period.of(1, 6, 0), Period.ofYears(18), Period.ofMonths(2), Duration.ofDays(-11)) ::
Nil)
assert(find(sumDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
assert(sumDF2.schema == StructType(Seq(StructField("class", IntegerType, false),
// TODO(SPARK-35775): Check all year-month interval types in aggregate expressions
StructField("sum(year-month)", YearMonthIntervalType()),
StructField("sum(year)", YearMonthIntervalType(YEAR)),
StructField("sum(month)", YearMonthIntervalType(MONTH)),
// TODO(SPARK-35729): Check all day-time interval types in aggregate expressions
StructField("sum(day-time)", DayTimeIntervalType()))))
@ -1158,34 +1172,48 @@ class DataFrameAggregateSuite extends QueryTest
}
test("SPARK-34837: Support ANSI SQL intervals by the aggregate function `avg`") {
val df = Seq((1, Period.ofMonths(10), Duration.ofDays(10)),
(2, Period.ofMonths(1), Duration.ofDays(1)),
(2, null, null),
(3, Period.ofMonths(-3), Duration.ofDays(-6)),
(3, Period.ofMonths(21), Duration.ofDays(-5)))
.toDF("class", "year-month", "day-time")
val df = Seq(
(1, Period.ofMonths(10), Period.ofYears(8), Period.ofMonths(10), Duration.ofDays(10)),
(2, Period.ofMonths(1), Period.ofYears(1), Period.ofMonths(1), Duration.ofDays(1)),
(2, null, null, null, null),
(3, Period.ofMonths(-3), Period.ofYears(-12), Period.ofMonths(-3), Duration.ofDays(-6)),
(3, Period.ofMonths(21), Period.ofYears(30), Period.ofMonths(5), Duration.ofDays(-5)),
(3, null, Period.ofYears(1), null, null))
.toDF("class", "year-month", "year", "month", "day-time")
.select(
$"class",
$"year-month",
$"year" cast YearMonthIntervalType(YEAR) as "year",
$"month" cast YearMonthIntervalType(MONTH) as "month",
$"day-time")
val df2 = Seq((Period.ofMonths(Int.MaxValue), Duration.ofDays(106751991)),
(Period.ofMonths(10), Duration.ofDays(10)))
.toDF("year-month", "day-time")
val avgDF = df.select(avg($"year-month"), avg($"day-time"))
checkAnswer(avgDF, Row(Period.ofMonths(7), Duration.ofDays(0)))
val avgDF = df.select(avg($"year-month"), avg($"year"), avg($"month"), avg($"day-time"))
checkAnswer(avgDF,
Row(Period.ofMonths(7), Period.of(5, 7, 0), Period.ofMonths(3), Duration.ofDays(0)))
assert(find(avgDF.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
assert(avgDF.schema == StructType(Seq(
// TODO(SPARK-35775): Check all year-month interval types in aggregate expressions
StructField("avg(year-month)", YearMonthIntervalType()),
StructField("avg(year)", YearMonthIntervalType()),
StructField("avg(month)", YearMonthIntervalType()),
// TODO(SPARK-35729): Check all day-time interval types in aggregate expressions
StructField("avg(day-time)", DayTimeIntervalType()))))
val avgDF2 = df.groupBy($"class").agg(avg($"year-month"), avg($"day-time"))
checkAnswer(avgDF2, Row(1, Period.ofMonths(10), Duration.ofDays(10)) ::
Row(2, Period.ofMonths(1), Duration.ofDays(1)) ::
Row(3, Period.ofMonths(9), Duration.ofDays(-5).plusHours(-12)) :: Nil)
val avgDF2 =
df.groupBy($"class").agg(avg($"year-month"), avg($"year"), avg($"month"), avg($"day-time"))
checkAnswer(avgDF2,
Row(1, Period.ofMonths(10), Period.ofYears(8), Period.ofMonths(10), Duration.ofDays(10)) ::
Row(2, Period.ofMonths(1), Period.ofYears(1), Period.ofMonths(1), Duration.ofDays(1)) ::
Row(3, Period.ofMonths(9), Period.of(6, 4, 0), Period.ofMonths(1),
Duration.ofDays(-5).plusHours(-12)) :: Nil)
assert(find(avgDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
assert(avgDF2.schema == StructType(Seq(StructField("class", IntegerType, false),
// TODO(SPARK-35775): Check all year-month interval types in aggregate expressions
StructField("avg(year-month)", YearMonthIntervalType()),
StructField("avg(year)", YearMonthIntervalType()),
StructField("avg(month)", YearMonthIntervalType()),
// TODO(SPARK-35729): Check all day-time interval types in aggregate expressions
StructField("avg(day-time)", DayTimeIntervalType()))))