[SPARK-35412][SQL] Fix a bug in groupBy of year-month/day-time intervals
### What changes were proposed in this pull request? To fix a bug below in groupBy of year-month/day-time intervals, this PR proposes to make `HashMapGenerator` handle the two types for hash-aggregates; ``` scala> Seq(java.time.Duration.ofDays(1)).toDF("a").groupBy("a").count().show() scala.MatchError: DayTimeIntervalType (of class org.apache.spark.sql.types.DayTimeIntervalType$) at org.apache.spark.sql.execution.aggregate.HashMapGenerator.genComputeHash(HashMapGenerator.scala:159) at org.apache.spark.sql.execution.aggregate.HashMapGenerator.$anonfun$generateHashFunction$1(HashMapGenerator.scala:102) at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:238) at scala.collection.immutable.List.foreach(List.scala:392) at scala.collection.TraversableLike.map(TraversableLike.scala:238) at scala.collection.TraversableLike.map$(TraversableLike.scala:231) at scala.collection.immutable.List.map(List.scala:298) at org.apache.spark.sql.execution.aggregate.HashMapGenerator.genHashForKeys$1(HashMapGenerator.scala:99) at org.apache.spark.sql.execution.aggregate.HashMapGenerator.generateHashFunction(HashMapGenerator.scala:111) ``` ### Why are the changes needed? Bugfix. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Added a unit test. Closes #32560 from maropu/FixIntervalIssue. Authored-by: Takeshi Yamamuro <yamamuro@apache.org> Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
This commit is contained in:
parent
5c1567ba97
commit
2eef2f9035
|
@ -158,8 +158,8 @@ abstract class HashMapGenerator(
|
|||
|
||||
dataType match {
|
||||
case BooleanType => hashInt(s"$input ? 1 : 0")
|
||||
case ByteType | ShortType | IntegerType | DateType => hashInt(input)
|
||||
case LongType | TimestampType => hashLong(input)
|
||||
case ByteType | ShortType | IntegerType | DateType | YearMonthIntervalType => hashInt(input)
|
||||
case LongType | TimestampType | DayTimeIntervalType => hashLong(input)
|
||||
case FloatType => hashInt(s"Float.floatToIntBits($input)")
|
||||
case DoubleType => hashLong(s"Double.doubleToLongBits($input)")
|
||||
case d: DecimalType =>
|
||||
|
|
|
@ -1196,6 +1196,13 @@ class DataFrameAggregateSuite extends QueryTest
|
|||
val avgDF4 = df3.groupBy($"class").agg(avg($"year-month"), avg($"day-time"))
|
||||
checkAnswer(avgDF4, Nil)
|
||||
}
|
||||
|
||||
test("SPARK-35412: groupBy of year-month/day-time intervals should work") {
|
||||
val df1 = Seq(Duration.ofDays(1)).toDF("a").groupBy("a").count()
|
||||
checkAnswer(df1, Row(Duration.ofDays(1), 1))
|
||||
val df2 = Seq(Period.ofYears(1)).toDF("a").groupBy("a").count()
|
||||
checkAnswer(df2, Row(Period.ofYears(1), 1))
|
||||
}
|
||||
}
|
||||
|
||||
case class B(c: Option[Double])
|
||||
|
|
Loading…
Reference in a new issue