[SPARK-35775][SQL][TESTS] Check all year-month interval types in aggregate expressions
### What changes were proposed in this pull request? This PR adds test to check `sum` and `avg` works with all the `YearMonthInterval` types. ### Why are the changes needed? To ensure the results of aggregations are what is expected. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? New test. Closes #32988 from sarutak/check-interval-agg-ym. Authored-by: Kousuke Saruta <sarutak@oss.nttdata.com> Signed-off-by: Max Gekk <max.gekk@gmail.com>
This commit is contained in:
parent
844f10c742
commit
2c91672259
|
@ -34,6 +34,7 @@ import org.apache.spark.sql.internal.SQLConf
|
|||
import org.apache.spark.sql.test.SharedSparkSession
|
||||
import org.apache.spark.sql.test.SQLTestData.DecimalData
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.sql.types.YearMonthIntervalType.{MONTH, YEAR}
|
||||
|
||||
case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Double)
|
||||
|
||||
|
@ -1115,34 +1116,47 @@ class DataFrameAggregateSuite extends QueryTest
|
|||
}
|
||||
|
||||
test("SPARK-34716: Support ANSI SQL intervals by the aggregate function `sum`") {
|
||||
val df = Seq((1, Period.ofMonths(10), Duration.ofDays(10)),
|
||||
(2, Period.ofMonths(1), Duration.ofDays(1)),
|
||||
(2, null, null),
|
||||
(3, Period.ofMonths(-3), Duration.ofDays(-6)),
|
||||
(3, Period.ofMonths(21), Duration.ofDays(-5)))
|
||||
.toDF("class", "year-month", "day-time")
|
||||
val df = Seq(
|
||||
(1, Period.ofMonths(10), Period.ofYears(8), Period.ofMonths(10), Duration.ofDays(10)),
|
||||
(2, Period.ofMonths(1), Period.ofYears(1), Period.ofMonths(1), Duration.ofDays(1)),
|
||||
(2, null, null, null, null),
|
||||
(3, Period.ofMonths(-3), Period.ofYears(-12), Period.ofMonths(-3), Duration.ofDays(-6)),
|
||||
(3, Period.ofMonths(21), Period.ofYears(30), Period.ofMonths(5), Duration.ofDays(-5)))
|
||||
.toDF("class", "year-month", "year", "month", "day-time")
|
||||
.select(
|
||||
$"class",
|
||||
$"year-month",
|
||||
$"year" cast YearMonthIntervalType(YEAR) as "year",
|
||||
$"month" cast YearMonthIntervalType(MONTH) as "month",
|
||||
$"day-time")
|
||||
|
||||
val df2 = Seq((Period.ofMonths(Int.MaxValue), Duration.ofDays(106751991)),
|
||||
(Period.ofMonths(10), Duration.ofDays(10)))
|
||||
.toDF("year-month", "day-time")
|
||||
|
||||
val sumDF = df.select(sum($"year-month"), sum($"day-time"))
|
||||
checkAnswer(sumDF, Row(Period.of(2, 5, 0), Duration.ofDays(0)))
|
||||
val sumDF = df.select(sum($"year-month"), sum($"year"), sum($"month"), sum($"day-time"))
|
||||
checkAnswer(sumDF,
|
||||
Row(Period.of(2, 5, 0), Period.ofYears(27), Period.of(1, 1, 0), Duration.ofDays(0)))
|
||||
assert(find(sumDF.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
|
||||
assert(sumDF.schema == StructType(Seq(
|
||||
// TODO(SPARK-35775): Check all year-month interval types in aggregate expressions
|
||||
StructField("sum(year-month)", YearMonthIntervalType()),
|
||||
StructField("sum(year)", YearMonthIntervalType(YEAR)),
|
||||
StructField("sum(month)", YearMonthIntervalType(MONTH)),
|
||||
// TODO(SPARK-35729): Check all day-time interval types in aggregate expressions
|
||||
StructField("sum(day-time)", DayTimeIntervalType()))))
|
||||
|
||||
val sumDF2 = df.groupBy($"class").agg(sum($"year-month"), sum($"day-time"))
|
||||
checkAnswer(sumDF2, Row(1, Period.ofMonths(10), Duration.ofDays(10)) ::
|
||||
Row(2, Period.ofMonths(1), Duration.ofDays(1)) ::
|
||||
Row(3, Period.of(1, 6, 0), Duration.ofDays(-11)) :: Nil)
|
||||
val sumDF2 =
|
||||
df.groupBy($"class").agg(sum($"year-month"), sum($"year"), sum($"month"), sum($"day-time"))
|
||||
checkAnswer(sumDF2,
|
||||
Row(1, Period.ofMonths(10), Period.ofYears(8), Period.ofMonths(10), Duration.ofDays(10)) ::
|
||||
Row(2, Period.ofMonths(1), Period.ofYears(1), Period.ofMonths(1), Duration.ofDays(1)) ::
|
||||
Row(3, Period.of(1, 6, 0), Period.ofYears(18), Period.ofMonths(2), Duration.ofDays(-11)) ::
|
||||
Nil)
|
||||
assert(find(sumDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
|
||||
assert(sumDF2.schema == StructType(Seq(StructField("class", IntegerType, false),
|
||||
// TODO(SPARK-35775): Check all year-month interval types in aggregate expressions
|
||||
StructField("sum(year-month)", YearMonthIntervalType()),
|
||||
StructField("sum(year)", YearMonthIntervalType(YEAR)),
|
||||
StructField("sum(month)", YearMonthIntervalType(MONTH)),
|
||||
// TODO(SPARK-35729): Check all day-time interval types in aggregate expressions
|
||||
StructField("sum(day-time)", DayTimeIntervalType()))))
|
||||
|
||||
|
@ -1158,34 +1172,48 @@ class DataFrameAggregateSuite extends QueryTest
|
|||
}
|
||||
|
||||
test("SPARK-34837: Support ANSI SQL intervals by the aggregate function `avg`") {
|
||||
val df = Seq((1, Period.ofMonths(10), Duration.ofDays(10)),
|
||||
(2, Period.ofMonths(1), Duration.ofDays(1)),
|
||||
(2, null, null),
|
||||
(3, Period.ofMonths(-3), Duration.ofDays(-6)),
|
||||
(3, Period.ofMonths(21), Duration.ofDays(-5)))
|
||||
.toDF("class", "year-month", "day-time")
|
||||
val df = Seq(
|
||||
(1, Period.ofMonths(10), Period.ofYears(8), Period.ofMonths(10), Duration.ofDays(10)),
|
||||
(2, Period.ofMonths(1), Period.ofYears(1), Period.ofMonths(1), Duration.ofDays(1)),
|
||||
(2, null, null, null, null),
|
||||
(3, Period.ofMonths(-3), Period.ofYears(-12), Period.ofMonths(-3), Duration.ofDays(-6)),
|
||||
(3, Period.ofMonths(21), Period.ofYears(30), Period.ofMonths(5), Duration.ofDays(-5)),
|
||||
(3, null, Period.ofYears(1), null, null))
|
||||
.toDF("class", "year-month", "year", "month", "day-time")
|
||||
.select(
|
||||
$"class",
|
||||
$"year-month",
|
||||
$"year" cast YearMonthIntervalType(YEAR) as "year",
|
||||
$"month" cast YearMonthIntervalType(MONTH) as "month",
|
||||
$"day-time")
|
||||
|
||||
val df2 = Seq((Period.ofMonths(Int.MaxValue), Duration.ofDays(106751991)),
|
||||
(Period.ofMonths(10), Duration.ofDays(10)))
|
||||
.toDF("year-month", "day-time")
|
||||
|
||||
val avgDF = df.select(avg($"year-month"), avg($"day-time"))
|
||||
checkAnswer(avgDF, Row(Period.ofMonths(7), Duration.ofDays(0)))
|
||||
val avgDF = df.select(avg($"year-month"), avg($"year"), avg($"month"), avg($"day-time"))
|
||||
checkAnswer(avgDF,
|
||||
Row(Period.ofMonths(7), Period.of(5, 7, 0), Period.ofMonths(3), Duration.ofDays(0)))
|
||||
assert(find(avgDF.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
|
||||
assert(avgDF.schema == StructType(Seq(
|
||||
// TODO(SPARK-35775): Check all year-month interval types in aggregate expressions
|
||||
StructField("avg(year-month)", YearMonthIntervalType()),
|
||||
StructField("avg(year)", YearMonthIntervalType()),
|
||||
StructField("avg(month)", YearMonthIntervalType()),
|
||||
// TODO(SPARK-35729): Check all day-time interval types in aggregate expressions
|
||||
StructField("avg(day-time)", DayTimeIntervalType()))))
|
||||
|
||||
val avgDF2 = df.groupBy($"class").agg(avg($"year-month"), avg($"day-time"))
|
||||
checkAnswer(avgDF2, Row(1, Period.ofMonths(10), Duration.ofDays(10)) ::
|
||||
Row(2, Period.ofMonths(1), Duration.ofDays(1)) ::
|
||||
Row(3, Period.ofMonths(9), Duration.ofDays(-5).plusHours(-12)) :: Nil)
|
||||
val avgDF2 =
|
||||
df.groupBy($"class").agg(avg($"year-month"), avg($"year"), avg($"month"), avg($"day-time"))
|
||||
checkAnswer(avgDF2,
|
||||
Row(1, Period.ofMonths(10), Period.ofYears(8), Period.ofMonths(10), Duration.ofDays(10)) ::
|
||||
Row(2, Period.ofMonths(1), Period.ofYears(1), Period.ofMonths(1), Duration.ofDays(1)) ::
|
||||
Row(3, Period.ofMonths(9), Period.of(6, 4, 0), Period.ofMonths(1),
|
||||
Duration.ofDays(-5).plusHours(-12)) :: Nil)
|
||||
assert(find(avgDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
|
||||
assert(avgDF2.schema == StructType(Seq(StructField("class", IntegerType, false),
|
||||
// TODO(SPARK-35775): Check all year-month interval types in aggregate expressions
|
||||
StructField("avg(year-month)", YearMonthIntervalType()),
|
||||
StructField("avg(year)", YearMonthIntervalType()),
|
||||
StructField("avg(month)", YearMonthIntervalType()),
|
||||
// TODO(SPARK-35729): Check all day-time interval types in aggregate expressions
|
||||
StructField("avg(day-time)", DayTimeIntervalType()))))
|
||||
|
||||
|
|
Loading…
Reference in a new issue