diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index d730586a7b..3785262f08 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -490,7 +490,7 @@ abstract class HashExpression[E] extends Expression {
       case BooleanType => genHashBoolean(input, result)
       case ByteType | ShortType | IntegerType | DateType => genHashInt(input, result)
       case LongType => genHashLong(input, result)
-      case TimestampType => genHashTimestamp(input, result)
+      case TimestampType | TimestampNTZType => genHashTimestamp(input, result)
       case FloatType => genHashFloat(input, result)
       case DoubleType => genHashDouble(input, result)
       case d: DecimalType => genHashDecimal(ctx, d, input, result)
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
index 7da5a28771..b4b6903cda 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
@@ -553,7 +553,8 @@ public final class OffHeapColumnVector extends WritableColumnVector {
         type instanceof DateType || DecimalType.is32BitDecimalType(type)) {
       this.data = Platform.reallocateMemory(data, oldCapacity * 4L, newCapacity * 4L);
     } else if (type instanceof LongType || type instanceof DoubleType ||
-        DecimalType.is64BitDecimalType(type) || type instanceof TimestampType) {
+        DecimalType.is64BitDecimalType(type) || type instanceof TimestampType ||
+        type instanceof TimestampNTZType) {
       this.data = Platform.reallocateMemory(data, oldCapacity * 8L, newCapacity * 8L);
     } else if (childColumns != null) {
       // Nothing to store.
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
index 5942c5f00a..3fb96d872c 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
@@ -547,7 +547,8 @@ public final class OnHeapColumnVector extends WritableColumnVector {
         if (intData != null) System.arraycopy(intData, 0, newData, 0, capacity);
         intData = newData;
       }
-    } else if (type instanceof LongType || type instanceof TimestampType ||
+    } else if (type instanceof LongType ||
+        type instanceof TimestampType || type instanceof TimestampNTZType ||
         DecimalType.is64BitDecimalType(type) || type instanceof DayTimeIntervalType) {
       if (longData == null || longData.length < newCapacity) {
         long[] newData = new long[newCapacity];
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala
index b3f5e341f6..713e7db4cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala
@@ -160,7 +160,7 @@ abstract class HashMapGenerator(
       case BooleanType => hashInt(s"$input ? 1 : 0")
       case ByteType | ShortType | IntegerType | DateType | _: YearMonthIntervalType =>
         hashInt(input)
-      case LongType | TimestampType | _: DayTimeIntervalType => hashLong(input)
+      case LongType | TimestampType | TimestampNTZType | _: DayTimeIntervalType => hashLong(input)
       case FloatType => hashInt(s"Float.floatToIntBits($input)")
       case DoubleType => hashLong(s"Double.doubleToLongBits($input)")
       case d: DecimalType =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 9cd743602b..62d68b85c1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql
 
-import java.time.{Duration, Period}
+import java.time.{Duration, LocalDateTime, Period}
 
 import scala.util.Random
 
@@ -1398,6 +1398,17 @@ class DataFrameAggregateSuite extends QueryTest
     val df2 = Seq(Period.ofYears(1)).toDF("a").groupBy("a").count()
     checkAnswer(df2, Row(Period.ofYears(1), 1))
   }
+
+  test("SPARK-36054: Support group by TimestampNTZ column") {
+    val ts1 = "2021-01-01T00:00:00"
+    val ts2 = "2021-01-01T00:00:01"
+    val localDateTime = Seq(ts1, ts1, ts2).map(LocalDateTime.parse)
+    val df = localDateTime.toDF("ts").groupBy("ts").count().orderBy("ts")
+    val expectedSchema =
+      new StructType().add(StructField("ts", TimestampNTZType)).add("count", LongType, false)
+    assert(df.schema == expectedSchema)
+    checkAnswer(df, Seq(Row(LocalDateTime.parse(ts1), 2), Row(LocalDateTime.parse(ts2), 1)))
+  }
 }
 
 case class B(c: Option[Double])