From e858cd568a74123f7fd8fe4c3d2917a7e5bbb685 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Mon, 13 Sep 2021 21:47:43 +0800 Subject: [PATCH] [SPARK-36724][SQL] Support timestamp_ntz as a type of time column for SessionWindow ### What changes were proposed in this pull request? This PR proposes to support `timestamp_ntz` as a type of time column for `SessionWindow` like `TimeWindow` does. ### Why are the changes needed? For better usability. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? New test. Closes #33965 from sarutak/session-window-ntz. Authored-by: Kousuke Saruta Signed-off-by: Gengliang Wang --- .../sql/catalyst/analysis/Analyzer.scala | 9 ++--- .../catalyst/expressions/SessionWindow.scala | 6 ++-- .../sql/DataFrameSessionWindowingSuite.scala | 33 +++++++++++++++++-- 3 files changed, 39 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 340b859590..0f90159d58 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -3999,7 +3999,8 @@ object SessionWindowing extends Rule[LogicalPlan] { val sessionAttr = AttributeReference( SESSION_COL_NAME, session.dataType, metadata = newMetadata)() - val sessionStart = PreciseTimestampConversion(session.timeColumn, TimestampType, LongType) + val sessionStart = + PreciseTimestampConversion(session.timeColumn, session.timeColumn.dataType, LongType) val gapDuration = session.gapDuration match { case expr if Cast.canCast(expr.dataType, CalendarIntervalType) => Cast(expr, CalendarIntervalType) case other => throw QueryCompilationErrors.sessionWindowGapDurationDataTypeError(other.dataType) } val sessionEnd = 
PreciseTimestampConversion(session.timeColumn + gapDuration, - TimestampType, LongType) + session.timeColumn.dataType, LongType) val literalSessionStruct = CreateNamedStruct( Literal(SESSION_START) :: - PreciseTimestampConversion(sessionStart, LongType, TimestampType) :: + PreciseTimestampConversion(sessionStart, LongType, session.timeColumn.dataType) :: Literal(SESSION_END) :: - PreciseTimestampConversion(sessionEnd, LongType, TimestampType) :: + PreciseTimestampConversion(sessionEnd, LongType, session.timeColumn.dataType) :: Nil) val sessionStruct = Alias(literalSessionStruct, SESSION_COL_NAME)( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala index 796ea27efc..77e8dfde87 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala @@ -69,10 +69,10 @@ case class SessionWindow(timeColumn: Expression, gapDuration: Expression) extend with NonSQLExpression { override def children: Seq[Expression] = Seq(timeColumn, gapDuration) - override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, AnyDataType) + override def inputTypes: Seq[AbstractDataType] = Seq(AnyTimestampType, AnyDataType) override def dataType: DataType = new StructType() - .add(StructField("start", TimestampType)) - .add(StructField("end", TimestampType)) + .add(StructField("start", timeColumn.dataType)) + .add(StructField("end", timeColumn.dataType)) // This expression is replaced in the analyzer. 
override lazy val resolved = false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala index 7a0cd420d4..b3d212716d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql +import java.time.LocalDateTime + import org.scalatest.BeforeAndAfterEach -import org.apache.spark.sql.catalyst.plans.logical.Expand +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types._ class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession with BeforeAndAfterEach { @@ -377,4 +380,30 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession ) } } + + test("SPARK-36724: Support timestamp_ntz as a type of time column for SessionWindow") { + val df = Seq((LocalDateTime.parse("2016-03-27T19:39:30"), 1, "a"), + (LocalDateTime.parse("2016-03-27T19:39:25"), 2, "a")).toDF("time", "value", "id") + val aggDF = + df.groupBy(session_window($"time", "10 seconds")) + .agg(count("*").as("counts")) + .orderBy($"session_window.start".asc) + .select($"session_window.start".cast("string"), + $"session_window.end".cast("string"), $"counts") + + val aggregate = aggDF.queryExecution.analyzed.children(0).children(0) + assert(aggregate.isInstanceOf[Aggregate]) + + val timeWindow = aggregate.asInstanceOf[Aggregate].groupingExpressions(0) + assert(timeWindow.isInstanceOf[AttributeReference]) + + val attributeReference = timeWindow.asInstanceOf[AttributeReference] + assert(attributeReference.name == "session_window") + + 
val expectedSchema = StructType( + Seq(StructField("start", TimestampNTZType), StructField("end", TimestampNTZType))) + assert(attributeReference.dataType == expectedSchema) + + checkAnswer(aggDF, Seq(Row("2016-03-27 19:39:25", "2016-03-27 19:39:40", 2))) + } }