[SPARK-36465][SS] Dynamic gap duration in session window

### What changes were proposed in this pull request?

This patch adds support for a dynamic gap duration in session windows.

### Why are the changes needed?

The gap duration currently used in session windows is a static value. To support more complex usage, it is better to support a dynamic gap duration that is determined by looking at the current data. For example, in our use case, the gap may differ depending on a certain column in the input rows.

### Does this PR introduce _any_ user-facing change?

Yes, users can specify a dynamic gap duration.
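
A minimal sketch of the new usage (mirroring the `session_window(Column, Column)` overload added in this patch; the `events` DataFrame and column names are illustrative):

```scala
import org.apache.spark.sql.functions.{col, session_window, when}

val events = ... // streaming DataFrame of schema { timestamp: Timestamp, userId: String }

// Pick the gap per row: "user1" sessions time out after 5 seconds,
// everyone else after 5 minutes.
val dynamicGap = when(col("userId") === "user1", "5 seconds").otherwise("5 minutes")

val sessionized = events
  .withWatermark("timestamp", "10 minutes")
  .groupBy(session_window(col("timestamp"), dynamicGap), col("userId"))
  .count()
```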

### How was this patch tested?

Modified existing tests and added a new test.

Closes #33691 from viirya/dynamic-session-window-gap.

Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>

@@ -1080,8 +1080,8 @@ Tumbling and sliding window use `window` function, which has been described on a
Session windows have a different characteristic compared to the previous two types. A session window has a dynamic
window length, depending on the inputs. A session window starts with an input, and expands itself
if following input has been received within gap duration. A session window closes when there's no input
received within gap duration after receiving the latest input.
if a following input is received within the gap duration. With a static gap duration, a session window closes when
there is no input received within the gap duration after receiving the latest input.
Session windows use the `session_window` function. The usage of the function is similar to the `window` function.
@@ -1134,6 +1134,77 @@ sessionizedCounts = events \
</div>
</div>
Instead of a static value, we can also provide an expression that specifies the gap duration
dynamically based on the input row. Note that rows with a negative or zero gap duration are
filtered out of the aggregation.
With a dynamic gap duration, the closing of a session window no longer depends on the latest
input. A session window's range is the union of all events' ranges, where each event's range is
determined by the event's start time and the gap duration evaluated for it during query
execution. For example, an event at 19:39:40 with an evaluated gap of 2 seconds covers
[19:39:40, 19:39:42), and overlapping ranges merge into a single session (a standalone sketch
of this merging logic follows the language examples below).
<div class="codetabs">
<div data-lang="scala" markdown="1">
{% highlight scala %}
import spark.implicits._
val events = ... // streaming DataFrame of schema { timestamp: Timestamp, userId: String }
val sessionWindow = session_window($"timestamp", when($"userId" === "user1", "5 seconds")
.when($"userId" === "user2", "20 seconds")
.otherwise("5 minutes"))
// Group the data by session window and userId, and compute the count of each group
val sessionizedCounts = events
.withWatermark("timestamp", "10 minutes")
.groupBy(
sessionWindow,
$"userId")
.count()
{% endhighlight %}
</div>
<div data-lang="java" markdown="1">
{% highlight java %}
Dataset<Row> events = ... // streaming DataFrame of schema { timestamp: Timestamp, userId: String }
Column sessionWindow = session_window(col("timestamp"), when(col("userId").equalTo("user1"), "5 seconds")
.when(col("userId").equalTo("user2"), "20 seconds")
.otherwise("5 minutes"));
// Group the data by session window and userId, and compute the count of each group
Dataset<Row> sessionizedCounts = events
.withWatermark("timestamp", "10 minutes")
.groupBy(
sessionWindow,
col("userId"))
.count();
{% endhighlight %}
</div>
<div data-lang="python" markdown="1">
{% highlight python %}
from pyspark.sql import functions as F
events = ... # streaming DataFrame of schema { timestamp: Timestamp, userId: String }
session_window = F.session_window(events.timestamp, \
    F.when(events.userId == "user1", "5 seconds") \
    .when(events.userId == "user2", "20 seconds").otherwise("5 minutes"))
# Group the data by session window and userId, and compute the count of each group
sessionizedCounts = events \
.withWatermark("timestamp", "10 minutes") \
.groupBy(
session_window,
events.userId) \
.count()
{% endhighlight %}
</div>
</div>
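
To make the union-of-ranges semantics concrete, below is a minimal standalone sketch (plain Scala, not the Spark API; the event times and gaps are made up) of how per-event ranges with dynamic gaps merge into sessions:

{% highlight scala %}
// Each event contributes the half-open range [time, time + gap);
// a session is the union of overlapping (or touching) ranges.
case class Event(time: Long, gapSec: Long)

def sessions(events: Seq[Event]): Seq[(Long, Long)] =
  events
    .filter(_.gapSec > 0)                                // zero/negative gaps are dropped
    .map(e => (e.time, e.time + e.gapSec))
    .sortBy(_._1)
    .foldLeft(List.empty[(Long, Long)]) {
      case ((start, end) :: rest, (s, e)) if s <= end => // range touches the current session,
        (start, math.max(end, e)) :: rest                // so extend the session
      case (acc, range) => range :: acc                  // otherwise start a new session
    }.reverse

// Events at t=40 (gap 2s) and t=41 (gap 10s) merge into one session [40, 51).
println(sessions(Seq(Event(40, 2), Event(41, 10))))     // List((40,51))
{% endhighlight %}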
Note that there are some restrictions when you use session windows in a streaming query, as listed below:
- "Update mode" as output mode is not supported.

@@ -2346,6 +2346,8 @@ def session_window(timeColumn, gapDuration):
processing time.
gapDuration is provided as strings, e.g. '1 second', '1 day 12 hours', '2 minutes'. Valid
interval strings are 'week', 'day', 'hour', 'minute', 'second', 'millisecond', 'microsecond'.
It can also be a Column, in which case the gap duration is evaluated dynamically based on the
input row.
The output column will be a struct called 'session_window' by default with the nested columns
'start' and 'end', where 'start' and 'end' will be of :class:`pyspark.sql.types.TimestampType`.
.. versionadded:: 3.2.0
@@ -2356,15 +2358,24 @@ def session_window(timeColumn, gapDuration):
>>> w.select(w.session_window.start.cast("string").alias("start"),
... w.session_window.end.cast("string").alias("end"), "sum").collect()
[Row(start='2016-03-11 09:00:07', end='2016-03-11 09:00:12', sum=1)]
>>> w = df.groupBy(session_window("date", lit("5 seconds"))).agg(sum("val").alias("sum"))
>>> w.select(w.session_window.start.cast("string").alias("start"),
... w.session_window.end.cast("string").alias("end"), "sum").collect()
[Row(start='2016-03-11 09:00:07', end='2016-03-11 09:00:12', sum=1)]
"""
def check_string_field(field, fieldName):
if not field or type(field) is not str:
raise TypeError("%s should be provided as a string" % fieldName)
def check_field(field, fieldName):
if field is None or not isinstance(field, (str, Column)):
raise TypeError("%s should be provided as a string or Column" % fieldName)
sc = SparkContext._active_spark_context
time_col = _to_java_column(timeColumn)
check_string_field(gapDuration, "gapDuration")
res = sc._jvm.functions.session_window(time_col, gapDuration)
check_field(gapDuration, "gapDuration")
gap_duration = (
gapDuration
if isinstance(gapDuration, str)
else _to_java_column(gapDuration)
)
res = sc._jvm.functions.session_window(time_col, gap_duration)
return Column(res)

@@ -136,7 +136,7 @@ def window(
slideDuration: Optional[str] = ...,
startTime: Optional[str] = ...,
) -> Column: ...
def session_window(timeColumn: ColumnOrName, gapDuration: str) -> Column: ...
def session_window(timeColumn: ColumnOrName, gapDuration: Union[Column, str]) -> Column: ...
def crc32(col: ColumnOrName) -> Column: ...
def md5(col: ColumnOrName) -> Column: ...
def sha1(col: ColumnOrName) -> Column: ...

@@ -3984,7 +3984,14 @@ object SessionWindowing extends Rule[LogicalPlan] {
SESSION_COL_NAME, session.dataType, metadata = newMetadata)()
val sessionStart = PreciseTimestampConversion(session.timeColumn, TimestampType, LongType)
val sessionEnd = sessionStart + session.gapDuration
val gapDuration = session.gapDuration match {
case expr if Cast.canCast(expr.dataType, CalendarIntervalType) =>
Cast(expr, CalendarIntervalType)
case other =>
throw QueryCompilationErrors.sessionWindowGapDurationDataTypeError(other.dataType)
}
val sessionEnd = PreciseTimestampConversion(session.timeColumn + gapDuration,
TimestampType, LongType)
val literalSessionStruct = CreateNamedStruct(
Literal(SESSION_START) ::
@@ -4001,11 +4008,13 @@ }
}
// As with the tumbling window, we add a filter to filter out nulls.
val filterExpr = IsNotNull(session.timeColumn)
// And we also filter out events with negative or zero gap duration.
val filterExpr = IsNotNull(session.timeColumn) &&
(sessionAttr.getField(SESSION_END) > sessionAttr.getField(SESSION_START))
replacedPlan.withNewChildren(
Project(sessionStruct +: child.output,
Filter(filterExpr, child)) :: Nil)
Filter(filterExpr,
Project(sessionStruct +: child.output, child)) :: Nil)
} else if (numWindowExpr > 1) {
throw QueryCompilationErrors.multiTimeWindowExpressionsNotSupportedError(p)
} else {
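
Schematically (a simplified per-row view with hypothetical names, not the actual Catalyst code), the rewritten plan now evaluates the gap for each row and keeps only rows whose session range is non-empty:

```scala
// Per-row behavior of the rewrite above: end = time + cast(gap as interval),
// and the added filter session.end > session.start drops rows whose
// evaluated gap duration is zero or negative.
def sessionCandidate(timeMicros: Long, gapMicros: Long): Option[(Long, Long)] = {
  val start = timeMicros
  val end   = timeMicros + gapMicros
  if (end > start) Some((start, end)) else None
}
```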

@@ -17,32 +17,31 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
/**
* Represent the session window.
*
* @param timeColumn the start time of session window
* @param gapDuration the duration of session gap, meaning the session will close if there is
* no new element appeared within "the last element in session + gap".
* @param gapDuration the duration of the session gap. With a static gap duration, the session
*                    will close if no new element appears within "the last element in the
*                    session + gap". Besides a static value, users can also provide an
*                    expression to specify the gap duration dynamically based on the input row.
*                    With a dynamic gap duration, the closing of a session window no longer
*                    depends on the latest input. A session window's range is the union of all
*                    events' ranges, which are determined by the event start time and the
*                    evaluated gap duration during query execution. Note that rows with a
*                    negative or zero gap duration will be filtered out of the aggregation.
*/
case class SessionWindow(timeColumn: Expression, gapDuration: Long) extends UnaryExpression
case class SessionWindow(timeColumn: Expression, gapDuration: Expression) extends Expression
with ImplicitCastInputTypes
with Unevaluable
with NonSQLExpression {
//////////////////////////
// SQL Constructors
//////////////////////////
def this(timeColumn: Expression, gapDuration: Expression) = {
this(timeColumn, TimeWindow.parseExpression(gapDuration))
}
override def child: Expression = timeColumn
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
override def children: Seq[Expression] = Seq(timeColumn, gapDuration)
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, AnyDataType)
override def dataType: DataType = new StructType()
.add(StructField("start", TimestampType))
.add(StructField("end", TimestampType))
@@ -50,19 +49,10 @@ case class SessionWindow(timeColumn: Expression, gapDuration: Long) extends Unar
// This expression is replaced in the analyzer.
override lazy val resolved = false
/** Validate the inputs for the gap duration in addition to the input data type. */
override def checkInputDataTypes(): TypeCheckResult = {
val dataTypeCheck = super.checkInputDataTypes()
if (dataTypeCheck.isSuccess) {
if (gapDuration <= 0) {
return TypeCheckFailure(s"The window duration ($gapDuration) must be greater than 0.")
}
}
dataTypeCheck
}
override def nullable: Boolean = false
override protected def withNewChildInternal(newChild: Expression): Expression =
copy(timeColumn = newChild)
override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
copy(timeColumn = newChildren(0), gapDuration = newChildren(1))
}
object SessionWindow {
@@ -72,6 +62,7 @@ object SessionWindow {
timeColumn: Expression,
gapDuration: String): SessionWindow = {
SessionWindow(timeColumn,
TimeWindow.getIntervalInMicroSeconds(gapDuration))
Literal(IntervalUtils.safeStringToInterval(UTF8String.fromString(gapDuration)),
CalendarIntervalType))
}
}
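
As the constructor above shows, a string gap such as `"5 seconds"` is now wrapped in a `CalendarIntervalType` literal rather than pre-converted to microseconds. A small sketch of the parsing call it relies on (the same call used in the diff; `safeStringToInterval` returns null instead of throwing on malformed input):

```scala
import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.unsafe.types.UTF8String

// Parses a Spark interval string into a CalendarInterval (null if invalid).
val interval = IntervalUtils.safeStringToInterval(UTF8String.fromString("1 day 12 hours"))
```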

@@ -371,6 +371,11 @@ private[spark] object QueryCompilationErrors {
t.origin.startPosition)
}
def sessionWindowGapDurationDataTypeError(dt: DataType): Throwable = {
new AnalysisException("Gap duration expression used in session window must be " +
s"CalendarIntervalType, but got ${dt}")
}
def viewOutputNumberMismatchQueryColumnNamesError(
output: Seq[Attribute], queryColumnNames: Seq[String]): Throwable = {
new AnalysisException(

@@ -3661,6 +3661,43 @@ object functions {
}.as("session_window")
}
/**
* Generates a session window given a timestamp specifying column.
*
* A session window is one of the dynamic windows, which means the length of the window varies
* according to the given inputs. With a static gap duration, the length of the session window
* is defined as "the timestamp of the latest input of the session + gap duration", so when new
* inputs are bound to the current session window, the end time of the session window can be
* expanded according to the new inputs.
*
* Besides a static gap duration value, users can also provide an expression to specify the
* gap duration dynamically based on the input row. With a dynamic gap duration, the closing
* of a session window no longer depends on the latest input. A session window's range is the
* union of all events' ranges, which are determined by the event start time and the evaluated
* gap duration during query execution. Note that rows with a negative or zero gap duration
* will be filtered out of the aggregation.
*
* Windows can support microsecond precision. A gapDuration on the order of months is not
* supported.
*
* For a streaming query, you may use the function `current_timestamp` to generate windows on
* processing time.
*
* @param timeColumn The column or the expression to use as the timestamp for windowing by time.
* The time column must be of TimestampType.
* @param gapDuration A column specifying the timeout of the session. It can be a static value,
*                    e.g. `10 minutes`, `1 second`, or an expression/UDF that specifies the
*                    gap duration dynamically based on the input row.
*
* @group datetime_funcs
* @since 3.2.0
*/
def session_window(timeColumn: Column, gapDuration: Column): Column = {
withExpr {
SessionWindow(timeColumn.expr, gapDuration.expr)
}.as("session_window")
}
/**
* Creates timestamp from the number of seconds since UTC epoch.
* @group datetime_funcs

@@ -263,9 +263,10 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession
private def withTempTable(f: String => Unit): Unit = {
val tableName = "temp"
Seq(
("2016-03-27 19:39:34", 1),
("2016-03-27 19:39:56", 2),
("2016-03-27 19:39:27", 4)).toDF("time", "value").createOrReplaceTempView(tableName)
("2016-03-27 19:39:34", 1, "10 seconds"),
("2016-03-27 19:39:56", 2, "20 seconds"),
("2016-03-27 19:39:27", 4, "30 seconds")).toDF("time", "value", "duration")
.createOrReplaceTempView(tableName)
try {
f(tableName)
} finally {
@@ -287,4 +288,93 @@
)
}
}
test("SPARK-36465: time window in SQL with dynamic string expression") {
withTempTable { table =>
checkAnswer(
spark.sql(s"""select session_window(time, duration), value from $table""")
.select($"session_window.start".cast(StringType), $"session_window.end".cast(StringType),
$"value"),
Seq(
Row("2016-03-27 19:39:27", "2016-03-27 19:39:57", 4),
Row("2016-03-27 19:39:34", "2016-03-27 19:39:44", 1),
Row("2016-03-27 19:39:56", "2016-03-27 19:40:16", 2)
)
)
}
}
test("SPARK-36465: Unsupported dynamic gap datatype") {
withTempTable { table =>
val err = intercept[AnalysisException] {
spark.sql(s"""select session_window(time, 1.0), value from $table""")
.select($"session_window.start".cast(StringType), $"session_window.end".cast(StringType),
$"value")
}
assert(err.message.contains("Gap duration expression used in session window must be " +
"CalendarIntervalType, but got DecimalType(2,1)"))
}
}
test("SPARK-36465: time window in SQL with UDF as gap duration") {
withTempTable { table =>
spark.udf.register("gapDuration",
(i: java.lang.Integer) => s"${i * 10} seconds")
checkAnswer(
spark.sql(s"""select session_window(time, gapDuration(value)), value from $table""")
.select($"session_window.start".cast(StringType), $"session_window.end".cast(StringType),
$"value"),
Seq(
Row("2016-03-27 19:39:27", "2016-03-27 19:40:07", 4),
Row("2016-03-27 19:39:34", "2016-03-27 19:39:44", 1),
Row("2016-03-27 19:39:56", "2016-03-27 19:40:16", 2)
)
)
}
}
test("SPARK-36465: time window in SQL with conditional expression as gap duration") {
withTempTable { table =>
checkAnswer(
spark.sql("select session_window(time, " +
"""case when value = 1 then "2 seconds" when value = 2 then "10 seconds" """ +
s"""else "20 seconds" end), value from $table""")
.select($"session_window.start".cast(StringType), $"session_window.end".cast(StringType),
$"value"),
Seq(
Row("2016-03-27 19:39:27", "2016-03-27 19:39:47", 4),
Row("2016-03-27 19:39:34", "2016-03-27 19:39:36", 1),
Row("2016-03-27 19:39:56", "2016-03-27 19:40:06", 2)
)
)
}
}
test("SPARK-36465: filter out events with negative/zero gap duration") {
withTempTable { table =>
spark.udf.register("gapDuration",
(i: java.lang.Integer) => {
if (i == 1) {
"0 seconds"
} else if (i == 2) {
"-10 seconds"
} else {
"5 seconds"
}
})
checkAnswer(
spark.sql(s"""select session_window(time, gapDuration(value)), value from $table""")
.groupBy($"session_window")
.agg(count("*").as("counts"))
.select($"session_window.start".cast("string"), $"session_window.end".cast("string"),
$"counts"),
Seq(Row("2016-03-27 19:39:27", "2016-03-27 19:39:32", 1))
)
}
}
}

@@ -23,7 +23,7 @@ import org.scalatest.BeforeAndAfter
import org.scalatest.matchers.must.Matchers
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, DataFrame}
import org.apache.spark.sql.{AnalysisException, Column, DataFrame}
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider}
import org.apache.spark.sql.functions.{count, session_window, sum}
@@ -256,6 +256,102 @@ class StreamingSessionWindowSuite extends StreamTest
)
}
testWithAllOptions("SPARK-36465: dynamic gap duration") {
val inputData = MemoryStream[(String, Long)]
val udf = spark.udf.register("gapDuration", (s: String) => {
if (s == "hello") {
"1 second"
} else if (s == "structured") {
// zero gap duration will be filtered out from aggregation
"0 second"
} else if (s == "world") {
// negative gap duration will be filtered out from aggregation
"-10 seconds"
} else {
"10 seconds"
}
})
val sessionUpdates = sessionWindowQuery(inputData,
session_window($"eventTime", udf($"sessionId")))
testStream(sessionUpdates, OutputMode.Append())(
AddData(inputData,
("hello world spark streaming", 40L),
("world hello structured streaming", 41L)
),
// watermark: 11
// current sessions
// ("hello", 40, 42, 2, 2),
// ("streaming", 40, 51, 11, 2),
// ("spark", 40, 50, 10, 1),
CheckNewAnswer(
),
// placing new sessions "before" previous sessions
AddData(inputData, ("spark streaming", 25L)),
// watermark: 11
// current sessions
// ("spark", 25, 35, 10, 1),
// ("streaming", 25, 35, 10, 1),
// ("hello", 40, 42, 2, 2),
// ("streaming", 40, 51, 11, 2),
// ("spark", 40, 50, 10, 1),
CheckNewAnswer(
),
// late event whose session end (10) is earlier than the watermark (11): should be dropped
AddData(inputData, ("spark streaming", 0L)),
// watermark: 11
// current sessions
// ("spark", 25, 35, 10, 1),
// ("streaming", 25, 35, 10, 1),
// ("hello", 40, 42, 2, 2),
// ("streaming", 40, 51, 11, 2),
// ("spark", 40, 50, 10, 1),
CheckNewAnswer(
),
// concatenating multiple previous sessions into one
AddData(inputData, ("spark streaming", 30L)),
// watermark: 11
// current sessions
// ("spark", 25, 50, 25, 3),
// ("streaming", 25, 51, 26, 4),
// ("hello", 40, 42, 2, 2),
CheckNewAnswer(
),
// placing new sessions after previous sessions
AddData(inputData, ("hello apache spark", 60L)),
// watermark: 30
// current sessions
// ("spark", 25, 50, 25, 3),
// ("streaming", 25, 51, 26, 4),
// ("hello", 40, 42, 2, 2),
// ("hello", 60, 61, 1, 1),
// ("apache", 60, 70, 10, 1),
// ("spark", 60, 70, 10, 1)
CheckNewAnswer(
),
AddData(inputData, ("structured streaming", 90L)),
// watermark: 60
// current sessions
// ("hello", 60, 61, 1, 1),
// ("apache", 60, 70, 10, 1),
// ("spark", 60, 70, 10, 1),
// ("streaming", 90, 100, 10, 1)
CheckNewAnswer(
("spark", 25, 50, 25, 3),
("streaming", 25, 51, 26, 4),
("hello", 40, 42, 2, 2)
)
)
}
testWithAllOptions("append mode - session window - no key") {
val inputData = MemoryStream[Int]
val windowedAggregation = sessionWindowQueryOnGlobalKey(inputData)
@@ -304,7 +400,9 @@ class StreamingSessionWindowSuite extends StreamTest
}
}
private def sessionWindowQuery(input: MemoryStream[(String, Long)]): DataFrame = {
private def sessionWindowQuery(
input: MemoryStream[(String, Long)],
sessionWindow: Column = session_window($"eventTime", "10 seconds")): DataFrame = {
// Split the lines into words, treat words as sessionId of events
val events = input.toDF()
.select($"_1".as("value"), $"_2".as("timestamp"))
@@ -313,7 +411,7 @@ events
.selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime")
events
.groupBy(session_window($"eventTime", "10 seconds") as 'session, 'sessionId)
.groupBy(sessionWindow as 'session, 'sessionId)
.agg(count("*").as("numEvents"))
.selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)",
"CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs",