[SPARK-36724][SQL] Support timestamp_ntz as a type of time column for SessionWindow

### What changes were proposed in this pull request?

This PR proposes to support `timestamp_ntz` as a type of time column for `SessionWindow`, like `TimeWindow` does.

### Why are the changes needed?

For better usability.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New test.

Closes #33965 from sarutak/session-window-ntz.

Authored-by: Kousuke Saruta <sarutak@oss.nttdata.com>
Signed-off-by: Gengliang Wang <gengliang@apache.org>
This commit is contained in:
Kousuke Saruta 2021-09-13 21:47:43 +08:00 committed by Gengliang Wang
parent 3747cfdb40
commit e858cd568a
3 changed files with 39 additions and 9 deletions

View file

@ -3999,7 +3999,8 @@ object SessionWindowing extends Rule[LogicalPlan] {
val sessionAttr = AttributeReference(
SESSION_COL_NAME, session.dataType, metadata = newMetadata)()
val sessionStart = PreciseTimestampConversion(session.timeColumn, TimestampType, LongType)
val sessionStart =
PreciseTimestampConversion(session.timeColumn, session.timeColumn.dataType, LongType)
val gapDuration = session.gapDuration match {
case expr if Cast.canCast(expr.dataType, CalendarIntervalType) =>
Cast(expr, CalendarIntervalType)
@ -4007,13 +4008,13 @@ object SessionWindowing extends Rule[LogicalPlan] {
throw QueryCompilationErrors.sessionWindowGapDurationDataTypeError(other.dataType)
}
val sessionEnd = PreciseTimestampConversion(session.timeColumn + gapDuration,
TimestampType, LongType)
session.timeColumn.dataType, LongType)
val literalSessionStruct = CreateNamedStruct(
Literal(SESSION_START) ::
PreciseTimestampConversion(sessionStart, LongType, TimestampType) ::
PreciseTimestampConversion(sessionStart, LongType, session.timeColumn.dataType) ::
Literal(SESSION_END) ::
PreciseTimestampConversion(sessionEnd, LongType, TimestampType) ::
PreciseTimestampConversion(sessionEnd, LongType, session.timeColumn.dataType) ::
Nil)
val sessionStruct = Alias(literalSessionStruct, SESSION_COL_NAME)(

View file

@ -69,10 +69,10 @@ case class SessionWindow(timeColumn: Expression, gapDuration: Expression) extend
with NonSQLExpression {
override def children: Seq[Expression] = Seq(timeColumn, gapDuration)
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, AnyDataType)
override def inputTypes: Seq[AbstractDataType] = Seq(AnyTimestampType, AnyDataType)
override def dataType: DataType = new StructType()
.add(StructField("start", TimestampType))
.add(StructField("end", TimestampType))
.add(StructField("start", timeColumn.dataType))
.add(StructField("end", timeColumn.dataType))
// This expression is replaced in the analyzer.
override lazy val resolved = false

View file

@ -17,12 +17,15 @@
package org.apache.spark.sql
import java.time.LocalDateTime
import org.scalatest.BeforeAndAfterEach
import org.apache.spark.sql.catalyst.plans.logical.Expand
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.types._
class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession
with BeforeAndAfterEach {
@ -377,4 +380,30 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession
)
}
}
test("SPARK-36724: Support timestamp_ntz as a type of time column for SessionWindow") {
  // Two LocalDateTime events 5 seconds apart; with a 10 second gap they are
  // expected to land in a single session (see the final checkAnswer).
  val laterEvent = (LocalDateTime.parse("2016-03-27T19:39:30"), 1, "a")
  val earlierEvent = (LocalDateTime.parse("2016-03-27T19:39:25"), 2, "a")
  val events = Seq(laterEvent, earlierEvent).toDF("time", "value", "id")

  val sessionized = events.groupBy(session_window($"time", "10 seconds"))
  val counted = sessionized.agg(count("*").as("counts"))
  val aggDF = counted
    .orderBy($"session_window.start".asc)
    .select(
      $"session_window.start".cast("string"),
      $"session_window.end".cast("string"),
      $"counts")

  // Walk two levels down the analyzed plan; the node found there must be the Aggregate.
  val analyzedPlan = aggDF.queryExecution.analyzed
  val aggregateNode = analyzedPlan.children(0).children(0)
  assert(aggregateNode.isInstanceOf[Aggregate])

  // The first grouping expression must be an attribute named "session_window".
  val groupingExpr = aggregateNode.asInstanceOf[Aggregate].groupingExpressions(0)
  assert(groupingExpr.isInstanceOf[AttributeReference])
  val sessionAttr = groupingExpr.asInstanceOf[AttributeReference]
  assert(sessionAttr.name == "session_window")

  // Both struct fields must carry the NTZ type of the input time column.
  val expectedSchema = new StructType()
    .add(StructField("start", TimestampNTZType))
    .add(StructField("end", TimestampNTZType))
  assert(sessionAttr.dataType == expectedSchema)

  val expectedRows = Seq(Row("2016-03-27 19:39:25", "2016-03-27 19:39:40", 2))
  checkAnswer(aggDF, expectedRows)
}
}