[SPARK-34893][SS] Support session window natively

Introduction: this PR is the last part of SPARK-10816 (event-time based sessionization, i.e. session window). Please refer to #31937 for an overall view of the code change. (Note that the code diff may have diverged a bit.)

### What changes were proposed in this pull request?

This PR proposes native support for session windows. Please refer to the comments/design doc in SPARK-10816 for more details on the rationale and design (they may be slightly outdated compared to this PR).

The boundary of a session window is defined as [timestamp of the start event, timestamp of the last event + gap duration). Unlike a time window, a session window is dynamic: it can expand whenever a new input row is added to the session. To handle this expansion, Spark defines a session window per input row and merges windows whose boundaries overlap, as sketched below.
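
A minimal sketch of that merge rule (illustration only, not Spark's implementation; `SessionWindowBounds` and `mergeAll` are made-up names for this example):

```scala
// Illustrative only: each input row starts its own window [eventTime, eventTime + gap),
// and windows whose boundaries overlap are merged into one session.
case class SessionWindowBounds(start: Long, end: Long) {
  def overlaps(other: SessionWindowBounds): Boolean =
    start < other.end && other.start < end
  def merge(other: SessionWindowBounds): SessionWindowBounds =
    SessionWindowBounds(math.min(start, other.start), math.max(end, other.end))
}

// Assumes the per-row windows are already sorted by start time within a group key.
def mergeAll(windows: Seq[SessionWindowBounds]): Seq[SessionWindowBounds] =
  windows.foldLeft(List.empty[SessionWindowBounds]) {
    case (head :: tail, w) if head.overlaps(w) => head.merge(w) :: tail
    case (acc, w) => w :: acc
  }.reverse
```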

This PR leverages two different approaches to merging session windows:

1. merging session windows with Spark's aggregation logic (a variant of sort aggregation)
2. updating the session window for all rows bound to the same session, and applying the aggregation logic afterwards

The first approach is preferable since it outperforms the second, but it can only be used when merging session windows can be applied together with aggregation. Since that is not possible in all cases, the second approach covers the remaining ones.

This PR also optimizes the merge of input rows and existing sessions by retaining the sort order (group keys + session window start timestamp), leveraging the fact that the number of existing sessions per group key won't be huge; see the sketch below.
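
A hedged sketch of that ordered merge (illustration only, not the actual `MergingSortWithSessionWindowStateIterator`): since both new input rows and stored sessions arrive sorted by (group keys, session start), they can be combined with a streaming two-way merge instead of buffering all sessions per key.

```scala
// Illustrative two-way merge of two iterators that are each sorted by the same ordering,
// e.g. (group key, session start timestamp). The output preserves that ordering.
def mergeSorted[T](left: Iterator[T], right: Iterator[T])(implicit ord: Ordering[T]): Iterator[T] =
  new Iterator[T] {
    private val l = left.buffered
    private val r = right.buffered
    def hasNext: Boolean = l.hasNext || r.hasNext
    def next(): T =
      if (!r.hasNext || (l.hasNext && ord.lteq(l.head, r.head))) l.next() else r.next()
  }
```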

The state format is versioned, so that a new state format can be introduced if a better one is found.

### Why are the changes needed?

Currently, to implement sessionization, Spark requires end users to work with (flat)MapGroupsWithState directly, which has a couple of major drawbacks:

1. (flat)MapGroupsWithState is a lower-level API, and end users have to implement all the details of defining session windows and merging windows themselves
2. built-in aggregate functions cannot be used, and end users have to implement aggregation themselves
3. (flat)MapGroupsWithState is only available in Scala/Java.

With native support for session windows, end users simply use "session_window" the way they use "window" for tumbling/sliding windows, and can leverage built-in aggregate functions as well as UDAFs to define aggregations.

Quoting the query example from the test suite:

```scala
    val inputData = MemoryStream[(String, Long)]

    // Split the lines into words, treat words as sessionId of events
    val events = inputData.toDF()
      .select($"_1".as("value"), $"_2".as("timestamp"))
      .withColumn("eventTime", $"timestamp".cast("timestamp"))
      .selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime")
      .withWatermark("eventTime", "30 seconds")

    val sessionUpdates = events
      .groupBy(session_window($"eventTime", "10 seconds") as 'session, 'sessionId)
      .agg(count("*").as("numEvents"))
      .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)",
        "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs",
        "numEvents")
```

which is equivalent to StructuredSessionization (the native session window version is shorter and clearer, even ignoring the model classes).

39542bb81f/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredSessionization.scala (L66-L105)

(Worth noting that the code in StructuredSessionization only works with processing time. It doesn't consider that an old event can update the start time of an existing session.)

### Does this PR introduce _any_ user-facing change?

Yes. This PR adds support for session windows on both batch and streaming queries, introducing a new function "session_window" whose usage is similar to "window". A minimal illustrative example follows.
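
For illustration only, a minimal batch usage could look like the following (the DataFrame `df` and the columns `eventTime`/`userId` are placeholders, and `import spark.implicits._` is assumed):

```scala
import org.apache.spark.sql.functions._

// Group events into sessions that close after 5 minutes of inactivity per user.
val sessionized = df
  .groupBy(session_window($"eventTime", "5 minutes"), $"userId")
  .agg(count("*").as("numEvents"))
  .select($"session_window.start", $"session_window.end", $"userId", $"numEvents")
```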

### How was this patch tested?

New test suites. Also tested with benchmark code.

Closes #33081 from HeartSaVioR/SPARK-34893-SPARK-10816-PR-31570-part-5.

Lead-authored-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
Co-authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Co-authored-by: Yuanjian Li <yuanjian.li@databricks.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
Commit f2bf8b051b (parent c1b3f86c58), committed by Jungtaek Lim on 2021-07-16 20:38:16 +09:00.
23 changed files with 1608 additions and 39 deletions.


@@ -2333,6 +2333,41 @@ def window(timeColumn, windowDuration, slideDuration=None, startTime=None):
    return Column(res)


def session_window(timeColumn, gapDuration):
    """
    Generates session window given a timestamp specifying column.

    Session window is one of dynamic windows, which means the length of window is varying
    according to the given inputs. The length of session window is defined as "the timestamp
    of latest input of the session + gap duration", so when the new inputs are bound to the
    current session window, the end time of session window can be expanded according to the new
    inputs.

    Windows can support microsecond precision. Windows in the order of months are not supported.

    For a streaming query, you may use the function `current_timestamp` to generate windows on
    processing time.

    gapDuration is provided as strings, e.g. '1 second', '1 day 12 hours', '2 minutes'. Valid
    interval strings are 'week', 'day', 'hour', 'minute', 'second', 'millisecond', 'microsecond'.

    The output column will be a struct called 'session_window' by default with the nested columns
    'start' and 'end', where 'start' and 'end' will be of :class:`pyspark.sql.types.TimestampType`.

    .. versionadded:: 3.2.0

    Examples
    --------
    >>> df = spark.createDataFrame([("2016-03-11 09:00:07", 1)]).toDF("date", "val")
    >>> w = df.groupBy(session_window("date", "5 seconds")).agg(sum("val").alias("sum"))
    >>> w.select(w.session_window.start.cast("string").alias("start"),
    ...          w.session_window.end.cast("string").alias("end"), "sum").collect()
    [Row(start='2016-03-11 09:00:07', end='2016-03-11 09:00:12', sum=1)]
    """
    def check_string_field(field, fieldName):
        if not field or type(field) is not str:
            raise TypeError("%s should be provided as a string" % fieldName)

    sc = SparkContext._active_spark_context
    time_col = _to_java_column(timeColumn)
    check_string_field(gapDuration, "gapDuration")
    res = sc._jvm.functions.session_window(time_col, gapDuration)
    return Column(res)


# ---------------------------- misc functions ----------------------------------

def crc32(col):


@@ -135,6 +135,7 @@ def window(
    slideDuration: Optional[str] = ...,
    startTime: Optional[str] = ...,
) -> Column: ...
def session_window(timeColumn: ColumnOrName, gapDuration: str) -> Column: ...
def crc32(col: ColumnOrName) -> Column: ...
def md5(col: ColumnOrName) -> Column: ...
def sha1(col: ColumnOrName) -> Column: ...


@@ -296,6 +296,7 @@ class Analyzer(override val catalogManager: CatalogManager)
      GlobalAggregates ::
      ResolveAggregateFunctions ::
      TimeWindowing ::
      SessionWindowing ::
      ResolveInlineTables ::
      ResolveHigherOrderFunctions(catalogManager) ::
      ResolveLambdaVariables ::
@@ -3856,9 +3857,13 @@ object TimeWindowing extends Rule[LogicalPlan] {
      val windowExpressions =
        p.expressions.flatMap(_.collect { case t: TimeWindow => t }).toSet

      val numWindowExpr = p.expressions.flatMap(_.collect {
        case s: SessionWindow => s
        case t: TimeWindow => t
      }).toSet.size
      // Only support a single window expression for now
      if (numWindowExpr == 1 && windowExpressions.nonEmpty &&
          windowExpressions.head.timeColumn.resolved &&
          windowExpressions.head.checkInputDataTypes().isSuccess) {
@@ -3933,6 +3938,83 @@ object TimeWindowing extends Rule[LogicalPlan] {
    }
  }
/** Maps a time column to a session window. */
object SessionWindowing extends Rule[LogicalPlan] {
import org.apache.spark.sql.catalyst.dsl.expressions._
private final val SESSION_COL_NAME = "session_window"
private final val SESSION_START = "start"
private final val SESSION_END = "end"
/**
* Generates the logical plan for generating session window on a timestamp column.
* Each session window is initially defined as [timestamp, timestamp + gap).
*
* This also adds a marker to the session column so that downstream can easily find the column
* on session window.
*/
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case p: LogicalPlan if p.children.size == 1 =>
val child = p.children.head
val sessionExpressions =
p.expressions.flatMap(_.collect { case s: SessionWindow => s }).toSet
val numWindowExpr = p.expressions.flatMap(_.collect {
case s: SessionWindow => s
case t: TimeWindow => t
}).toSet.size
// Only support a single session expression for now
if (numWindowExpr == 1 && sessionExpressions.nonEmpty &&
sessionExpressions.head.timeColumn.resolved &&
sessionExpressions.head.checkInputDataTypes().isSuccess) {
val session = sessionExpressions.head
val metadata = session.timeColumn match {
case a: Attribute => a.metadata
case _ => Metadata.empty
}
val newMetadata = new MetadataBuilder()
.withMetadata(metadata)
.putBoolean(SessionWindow.marker, true)
.build()
val sessionAttr = AttributeReference(
SESSION_COL_NAME, session.dataType, metadata = newMetadata)()
val sessionStart = PreciseTimestampConversion(session.timeColumn, TimestampType, LongType)
val sessionEnd = sessionStart + session.gapDuration
val literalSessionStruct = CreateNamedStruct(
Literal(SESSION_START) ::
PreciseTimestampConversion(sessionStart, LongType, TimestampType) ::
Literal(SESSION_END) ::
PreciseTimestampConversion(sessionEnd, LongType, TimestampType) ::
Nil)
val sessionStruct = Alias(literalSessionStruct, SESSION_COL_NAME)(
exprId = sessionAttr.exprId, explicitMetadata = Some(newMetadata))
val replacedPlan = p transformExpressions {
case s: SessionWindow => sessionAttr
}
// As same as tumbling window, we add a filter to filter out nulls.
val filterExpr = IsNotNull(session.timeColumn)
replacedPlan.withNewChildren(
Project(sessionStruct +: child.output,
Filter(filterExpr, child)) :: Nil)
} else if (numWindowExpr > 1) {
throw QueryCompilationErrors.multiTimeWindowExpressionsNotSupportedError(p)
} else {
p // Return unchanged. Analyzer will throw exception later
}
}
}
/**
 * Resolve expressions if they contains [[NamePlaceholder]]s.
 */


@@ -552,6 +552,7 @@ object FunctionRegistry {
    expression[WeekOfYear]("weekofyear"),
    expression[Year]("year"),
    expression[TimeWindow]("window"),
    expression[SessionWindow]("session_window"),
    expression[MakeDate]("make_date"),
    expression[MakeTimestamp]("make_timestamp"),
    expression[MakeTimestampNTZ]("make_timestamp_ntz", true),


@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
import org.apache.spark.sql.types._
/**
* Represent the session window.
*
* @param timeColumn the start time of session window
* @param gapDuration the duration of session gap, meaning the session will close if there is
* no new element appeared within "the last element in session + gap".
*/
case class SessionWindow(timeColumn: Expression, gapDuration: Long) extends UnaryExpression
with ImplicitCastInputTypes
with Unevaluable
with NonSQLExpression {
//////////////////////////
// SQL Constructors
//////////////////////////
def this(timeColumn: Expression, gapDuration: Expression) = {
this(timeColumn, TimeWindow.parseExpression(gapDuration))
}
override def child: Expression = timeColumn
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
override def dataType: DataType = new StructType()
.add(StructField("start", TimestampType))
.add(StructField("end", TimestampType))
// This expression is replaced in the analyzer.
override lazy val resolved = false
/** Validate the inputs for the gap duration in addition to the input data type. */
override def checkInputDataTypes(): TypeCheckResult = {
val dataTypeCheck = super.checkInputDataTypes()
if (dataTypeCheck.isSuccess) {
if (gapDuration <= 0) {
return TypeCheckFailure(s"The window duration ($gapDuration) must be greater than 0.")
}
}
dataTypeCheck
}
override protected def withNewChildInternal(newChild: Expression): Expression =
copy(timeColumn = newChild)
}
object SessionWindow {
val marker = "spark.sessionWindow"
def apply(
timeColumn: Expression,
gapDuration: String): SessionWindow = {
SessionWindow(timeColumn,
TimeWindow.getIntervalInMicroSeconds(gapDuration))
}
}


@@ -109,7 +109,7 @@ object TimeWindow {
   * @return The interval duration in microseconds. SparkSQL casts TimestampType has microsecond
   *         precision.
   */
  def getIntervalInMicroSeconds(interval: String): Long = {
    val cal = IntervalUtils.stringToInterval(UTF8String.fromString(interval))
    if (cal.months != 0) {
      throw new IllegalArgumentException(
@@ -122,7 +122,7 @@ object TimeWindow {
   * Parses the duration expression to generate the long value for the original constructor so
   * that we can use `window` in SQL.
   */
  def parseExpression(expr: Expression): Long = expr match {
    case NonNullLiteral(s, StringType) => getIntervalInMicroSeconds(s.toString)
    case IntegerLiteral(i) => i.toLong
    case NonNullLiteral(l, LongType) => l.toString.toLong


@@ -366,8 +366,9 @@ private[spark] object QueryCompilationErrors {
  }

  def multiTimeWindowExpressionsNotSupportedError(t: TreeNode[_]): Throwable = {
    new AnalysisException("Multiple time/session window expressions would result in a cartesian " +
      "product of rows, therefore they are currently not supported.", t.origin.line,
      t.origin.startPosition)
  }

  def viewOutputNumberMismatchQueryColumnNamesError(


@@ -1610,6 +1610,27 @@ object SQLConf {
    .checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2")
    .createWithDefault(2)

  val STREAMING_SESSION_WINDOW_MERGE_SESSIONS_IN_LOCAL_PARTITION =
    buildConf("spark.sql.streaming.sessionWindow.merge.sessions.in.local.partition")
      .internal()
      .doc("When true, streaming session window sorts and merge sessions in local partition " +
        "prior to shuffle. This is to reduce the rows to shuffle, but only beneficial when " +
        "there're lots of rows in a batch being assigned to same sessions.")
      .version("3.2.0")
      .booleanConf
      .createWithDefault(false)

  val STREAMING_SESSION_WINDOW_STATE_FORMAT_VERSION =
    buildConf("spark.sql.streaming.sessionWindow.stateFormatVersion")
      .internal()
      .doc("State format version used by streaming session window in a streaming query. " +
        "State between versions are tend to be incompatible, so state format version shouldn't " +
        "be modified after running.")
      .version("3.2.0")
      .intConf
      .checkValue(v => Set(1).contains(v), "Valid version is 1")
      .createWithDefault(1)

  val UNSUPPORTED_OPERATION_CHECK_ENABLED =
    buildConf("spark.sql.streaming.unsupportedOperationCheck")
      .internal()
@@ -3676,6 +3697,9 @@ class SQLConf extends Serializable with Logging {
  def fastHashAggregateRowMaxCapacityBit: Int = getConf(FAST_HASH_AGGREGATE_MAX_ROWS_CAPACITY_BIT)

  def streamingSessionWindowMergeSessionInLocalPartition: Boolean =
    getConf(STREAMING_SESSION_WINDOW_MERGE_SESSIONS_IN_LOCAL_PARTITION)

  def datetimeJava8ApiEnabled: Boolean = getConf(DATETIME_JAVA8API_ENABLED)

  def uiExplainMode: String = getConf(UI_EXPLAIN_MODE)


@@ -324,7 +324,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
          throw QueryCompilationErrors.groupAggPandasUDFUnsupportedByStreamingAggError()
        }

        val sessionWindowOption = namedGroupingExpressions.find { p =>
          p.metadata.contains(SessionWindow.marker)
        }

        // Ideally this should be done in `NormalizeFloatingNumbers`, but we do it here because
        // `groupingExpressions` is not extracted during logical phase.
@@ -335,12 +337,29 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
          }
        }

        sessionWindowOption match {
          case Some(sessionWindow) =>
            val stateVersion = conf.getConf(SQLConf.STREAMING_SESSION_WINDOW_STATE_FORMAT_VERSION)

            AggUtils.planStreamingAggregationForSession(
              normalizedGroupingExpressions,
              sessionWindow,
              aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]),
              rewrittenResultExpressions,
              stateVersion,
              conf.streamingSessionWindowMergeSessionInLocalPartition,
              planLater(child))

          case None =>
            val stateVersion = conf.getConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION)

            AggUtils.planStreamingAggregation(
              normalizedGroupingExpressions,
              aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]),
              rewrittenResultExpressions,
              stateVersion,
              planLater(child))
        }

      case _ => Nil
    }


@@ -17,10 +17,11 @@
package org.apache.spark.sql.execution.aggregate

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.streaming._

/**
 * Utility functions used by the query planner to convert our plan to new aggregation code path.
@@ -113,6 +114,11 @@ object AggUtils {
        resultExpressions = partialResultExpressions,
        child = child)

    // If we have session window expression in aggregation, we add MergingSessionExec to
    // merge sessions with calculating aggregation values.
    val interExec: SparkPlan = mayAppendMergingSessionExec(groupingExpressions,
      aggregateExpressions, partialAggregate)

    // 2. Create an Aggregate Operator for final aggregations.
    val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final))
    // The attributes of the final aggregation buffer, which is presented as input to the result
@@ -126,7 +132,7 @@ object AggUtils {
        aggregateAttributes = finalAggregateAttributes,
        initialInputBufferOffset = groupingExpressions.length,
        resultExpressions = resultExpressions,
        child = interExec)

    finalAggregate :: Nil
  }
@@ -140,6 +146,11 @@ object AggUtils {
      resultExpressions: Seq[NamedExpression],
      child: SparkPlan): Seq[SparkPlan] = {

    // If we have session window expression in aggregation, we add UpdatingSessionsExec to
    // calculate sessions for input rows and update rows' session column, so that further
    // aggregations can aggregate input rows for the same session.
    val maySessionChild = mayAppendUpdatingSessionExec(groupingExpressions, child)

    val distinctAttributes = normalizedNamedDistinctExpressions.map(_.toAttribute)
    val groupingAttributes = groupingExpressions.map(_.toAttribute)
@@ -156,7 +167,7 @@ object AggUtils {
        aggregateAttributes = aggregateAttributes,
        resultExpressions = groupingAttributes ++ distinctAttributes ++
          aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
        child = maySessionChild)
    }

    // 2. Create an Aggregate Operator for partial merge aggregations.
@@ -345,4 +356,177 @@ object AggUtils {
    finalAndCompleteAggregate :: Nil
  }
/**
* Plans a streaming session aggregation using the following progression:
*
* - Partial Aggregation
* - all tuples will have aggregated columns with initial value
* - (If "spark.sql.streaming.sessionWindow.merge.sessions.in.local.partition" is enabled)
* - Sort within partition (sort: all keys)
* - MergingSessionExec
* - calculate session among tuples, and aggregate tuples in session with partial merge
* - Shuffle & Sort (distribution: keys "without" session, sort: all keys)
* - SessionWindowStateStoreRestore (group: keys "without" session)
* - merge input tuples with stored tuples (sessions) respecting sort order
* - MergingSessionExec
* - calculate session among tuples, and aggregate tuples in session with partial merge
* - NOTE: it leverages the fact that the output of SessionWindowStateStoreRestore is sorted
* - now there is at most 1 tuple per group, key with session
* - SessionWindowStateStoreSave (group: keys "without" session)
* - saves tuple(s) for the next batch (multiple sessions could co-exist at the same time)
* - Complete (output the current result of the aggregation)
*/
def planStreamingAggregationForSession(
groupingExpressions: Seq[NamedExpression],
sessionExpression: NamedExpression,
functionsWithoutDistinct: Seq[AggregateExpression],
resultExpressions: Seq[NamedExpression],
stateFormatVersion: Int,
mergeSessionsInLocalPartition: Boolean,
child: SparkPlan): Seq[SparkPlan] = {
val groupWithoutSessionExpression = groupingExpressions.filterNot { p =>
p.semanticEquals(sessionExpression)
}
if (groupWithoutSessionExpression.isEmpty) {
throw new AnalysisException("Global aggregation with session window in streaming query" +
" is not supported.")
}
val groupingWithoutSessionAttributes = groupWithoutSessionExpression.map(_.toAttribute)
val groupingAttributes = groupingExpressions.map(_.toAttribute)
// Here doing partial merge is to have aggregated columns with default value for each row.
val partialAggregate: SparkPlan = {
val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial))
val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
createAggregate(
groupingExpressions = groupingExpressions,
aggregateExpressions = aggregateExpressions,
aggregateAttributes = aggregateAttributes,
resultExpressions = groupingAttributes ++
aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
child = child)
}
val partialMerged1: SparkPlan = if (mergeSessionsInLocalPartition) {
val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
// sort happens here to merge sessions on each partition
// this is to reduce amount of rows to shuffle
MergingSessionsExec(
requiredChildDistributionExpressions = None,
requiredChildDistributionOption = None,
groupingExpressions = groupingAttributes,
sessionExpression = sessionExpression,
aggregateExpressions = aggregateExpressions,
aggregateAttributes = aggregateAttributes,
initialInputBufferOffset = groupingAttributes.length,
resultExpressions = groupingAttributes ++
aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
child = partialAggregate
)
} else {
partialAggregate
}
// shuffle & sort happens here: most of details are also handled in this physical plan
val restored = SessionWindowStateStoreRestoreExec(groupingWithoutSessionAttributes,
sessionExpression.toAttribute, stateInfo = None, eventTimeWatermark = None,
stateFormatVersion, partialMerged1)
val mergedSessions = {
val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
MergingSessionsExec(
requiredChildDistributionExpressions = None,
requiredChildDistributionOption = Some(restored.requiredChildDistribution),
groupingExpressions = groupingAttributes,
sessionExpression = sessionExpression,
aggregateExpressions = aggregateExpressions,
aggregateAttributes = aggregateAttributes,
initialInputBufferOffset = groupingAttributes.length,
resultExpressions = groupingAttributes ++
aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
child = restored
)
}
// Note: stateId and returnAllStates are filled in later with preparation rules
// in IncrementalExecution.
val saved = SessionWindowStateStoreSaveExec(
groupingWithoutSessionAttributes,
sessionExpression.toAttribute,
stateInfo = None,
outputMode = None,
eventTimeWatermark = None,
stateFormatVersion, mergedSessions)
val finalAndCompleteAggregate: SparkPlan = {
val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final))
// The attributes of the final aggregation buffer, which is presented as input to the result
// projection:
val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute)
createAggregate(
requiredChildDistributionExpressions = Some(groupingAttributes),
groupingExpressions = groupingAttributes,
aggregateExpressions = finalAggregateExpressions,
aggregateAttributes = finalAggregateAttributes,
initialInputBufferOffset = groupingAttributes.length,
resultExpressions = resultExpressions,
child = saved)
}
finalAndCompleteAggregate :: Nil
}
private def mayAppendUpdatingSessionExec(
groupingExpressions: Seq[NamedExpression],
maybeChildPlan: SparkPlan): SparkPlan = {
groupingExpressions.find(_.metadata.contains(SessionWindow.marker)) match {
case Some(sessionExpression) =>
UpdatingSessionsExec(
groupingExpressions.map(_.toAttribute),
sessionExpression.toAttribute,
maybeChildPlan)
case None => maybeChildPlan
}
}
private def mayAppendMergingSessionExec(
groupingExpressions: Seq[NamedExpression],
aggregateExpressions: Seq[AggregateExpression],
partialAggregate: SparkPlan): SparkPlan = {
groupingExpressions.find(_.metadata.contains(SessionWindow.marker)) match {
case Some(sessionExpression) =>
val aggExpressions = aggregateExpressions.map(_.copy(mode = PartialMerge))
val aggAttributes = aggregateExpressions.map(_.resultAttribute)
val groupingAttributes = groupingExpressions.map(_.toAttribute)
val groupingWithoutSessionExpressions = groupingExpressions.diff(Seq(sessionExpression))
val groupingWithoutSessionsAttributes = groupingWithoutSessionExpressions
.map(_.toAttribute)
MergingSessionsExec(
requiredChildDistributionExpressions = Some(groupingWithoutSessionsAttributes),
requiredChildDistributionOption = None,
groupingExpressions = groupingAttributes,
sessionExpression = sessionExpression,
aggregateExpressions = aggExpressions,
aggregateAttributes = aggAttributes,
initialInputBufferOffset = groupingAttributes.length,
resultExpressions = groupingAttributes ++
aggExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
child = partialAggregate
)
case None => partialAggregate
}
}
} }


@ -181,7 +181,8 @@ class UpdatingSessionsIterator(
private val valueProj = GenerateUnsafeProjection.generate(valuesExpressions, inputSchema) private val valueProj = GenerateUnsafeProjection.generate(valuesExpressions, inputSchema)
private val restoreProj = GenerateUnsafeProjection.generate(inputSchema, private val restoreProj = GenerateUnsafeProjection.generate(inputSchema,
groupingExpressions.map(_.toAttribute) ++ valuesExpressions.map(_.toAttribute)) groupingWithoutSession.map(_.toAttribute) ++ Seq(sessionExpression.toAttribute) ++
valuesExpressions.map(_.toAttribute))
private def generateGroupingKey(): InternalRow = { private def generateGroupingKey(): InternalRow = {
val newRow = new SpecificInternalRow(Seq(sessionExpression.toAttribute).toStructType) val newRow = new SpecificInternalRow(Seq(sessionExpression.toAttribute).toStructType)
@ -190,19 +191,21 @@ class UpdatingSessionsIterator(
} }
private def closeCurrentSession(keyChanged: Boolean): Unit = { private def closeCurrentSession(keyChanged: Boolean): Unit = {
assert(returnRowsIter == null || !returnRowsIter.hasNext)
returnRows = rowsForCurrentSession returnRows = rowsForCurrentSession
rowsForCurrentSession = null rowsForCurrentSession = null
val groupingKey = generateGroupingKey() val groupingKey = generateGroupingKey().copy()
val currentRowsIter = returnRows.generateIterator().map { internalRow => val currentRowsIter = returnRows.generateIterator().map { internalRow =>
val valueRow = valueProj(internalRow) val valueRow = valueProj(internalRow)
restoreProj(join2(groupingKey, valueRow)).copy() restoreProj(join2(groupingKey, valueRow)).copy()
} }
returnRowsIter = currentRowsIter if (returnRowsIter != null && returnRowsIter.hasNext) {
returnRowsIter = returnRowsIter ++ currentRowsIter
} else {
returnRowsIter = currentRowsIter
}
if (keyChanged) processedKeys.add(currentKeys) if (keyChanged) processedKeys.add(currentKeys)


@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.aggregate.UpdatingSessionsIterator
import org.apache.spark.sql.types.{DataType, StructField, StructType} import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.util.Utils import org.apache.spark.util.Utils
@ -53,6 +54,17 @@ case class AggregateInPandasExec(
override def producedAttributes: AttributeSet = AttributeSet(output) override def producedAttributes: AttributeSet = AttributeSet(output)
val sessionWindowOption = groupingExpressions.find { p =>
p.metadata.contains(SessionWindow.marker)
}
val groupingWithoutSessionExpressions = sessionWindowOption match {
case Some(sessionExpression) =>
groupingExpressions.filterNot { p => p.semanticEquals(sessionExpression) }
case None => groupingExpressions
}
override def requiredChildDistribution: Seq[Distribution] = { override def requiredChildDistribution: Seq[Distribution] = {
if (groupingExpressions.isEmpty) { if (groupingExpressions.isEmpty) {
AllTuples :: Nil AllTuples :: Nil
@ -61,6 +73,14 @@ case class AggregateInPandasExec(
} }
} }
override def requiredChildOrdering: Seq[Seq[SortOrder]] = sessionWindowOption match {
case Some(sessionExpression) =>
Seq((groupingWithoutSessionExpressions ++ Seq(sessionExpression))
.map(SortOrder(_, Ascending)))
case None => Seq(groupingExpressions.map(SortOrder(_, Ascending)))
}
private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = { private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
udf.children match { udf.children match {
case Seq(u: PythonUDF) => case Seq(u: PythonUDF) =>
@ -73,9 +93,6 @@ case class AggregateInPandasExec(
} }
} }
override def requiredChildOrdering: Seq[Seq[SortOrder]] =
Seq(groupingExpressions.map(SortOrder(_, Ascending)))
override protected def doExecute(): RDD[InternalRow] = { override protected def doExecute(): RDD[InternalRow] = {
val inputRDD = child.execute() val inputRDD = child.execute()
@ -107,13 +124,18 @@ case class AggregateInPandasExec(
// Map grouped rows to ArrowPythonRunner results, Only execute if partition is not empty // Map grouped rows to ArrowPythonRunner results, Only execute if partition is not empty
inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else { inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else {
// If we have session window expression in aggregation, we wrap iterator with
// UpdatingSessionIterator to calculate sessions for input rows and update
// rows' session column, so that further aggregations can aggregate input rows
// for the same session.
val newIter: Iterator[InternalRow] = mayAppendUpdatingSessionIterator(iter)
val prunedProj = UnsafeProjection.create(allInputs.toSeq, child.output) val prunedProj = UnsafeProjection.create(allInputs.toSeq, child.output)
val grouped = if (groupingExpressions.isEmpty) { val grouped = if (groupingExpressions.isEmpty) {
// Use an empty unsafe row as a place holder for the grouping key // Use an empty unsafe row as a place holder for the grouping key
Iterator((new UnsafeRow(), iter)) Iterator((new UnsafeRow(), newIter))
} else { } else {
GroupedIterator(iter, groupingExpressions, child.output) GroupedIterator(newIter, groupingExpressions, child.output)
}.map { case (key, rows) => }.map { case (key, rows) =>
(key, rows.map(prunedProj)) (key, rows.map(prunedProj))
} }
@ -157,4 +179,21 @@ case class AggregateInPandasExec(
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
copy(child = newChild) copy(child = newChild)
private def mayAppendUpdatingSessionIterator(
iter: Iterator[InternalRow]): Iterator[InternalRow] = {
val newIter = sessionWindowOption match {
case Some(sessionExpression) =>
val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold
val spillThreshold = conf.windowExecBufferSpillThreshold
new UpdatingSessionsIterator(iter, groupingWithoutSessionExpressions, sessionExpression,
child.output, inMemoryThreshold, spillThreshold)
case None => iter
}
newIter
}
} }


@ -149,6 +149,26 @@ class IncrementalExecution(
stateFormatVersion, stateFormatVersion,
child) :: Nil)) child) :: Nil))
case SessionWindowStateStoreSaveExec(keys, session, None, None, None, stateFormatVersion,
UnaryExecNode(agg,
SessionWindowStateStoreRestoreExec(_, _, None, None, _, child))) =>
val aggStateInfo = nextStatefulOperationStateInfo
SessionWindowStateStoreSaveExec(
keys,
session,
Some(aggStateInfo),
Some(outputMode),
Some(offsetSeqMetadata.batchWatermarkMs),
stateFormatVersion,
agg.withNewChildren(
SessionWindowStateStoreRestoreExec(
keys,
session,
Some(aggStateInfo),
Some(offsetSeqMetadata.batchWatermarkMs),
stateFormatVersion,
child) :: Nil))
case StreamingDeduplicateExec(keys, child, None, None) => case StreamingDeduplicateExec(keys, child, None, None) =>
StreamingDeduplicateExec( StreamingDeduplicateExec(
keys, keys,


@ -68,8 +68,11 @@ sealed trait StreamingSessionWindowStateManager extends Serializable {
* {@code extractKeyWithoutSession}. * {@code extractKeyWithoutSession}.
* @param sessions The all sessions including existing sessions if it's active. * @param sessions The all sessions including existing sessions if it's active.
* Existing sessions which aren't included in this parameter will be removed. * Existing sessions which aren't included in this parameter will be removed.
* @return A tuple having two elements
* 1. number of added/updated rows
* 2. number of deleted rows
*/ */
def updateSessions(store: StateStore, key: UnsafeRow, sessions: Seq[UnsafeRow]): Unit def updateSessions(store: StateStore, key: UnsafeRow, sessions: Seq[UnsafeRow]): (Long, Long)
/** /**
* Removes using a predicate on values, with returning removed values via iterator. * Removes using a predicate on values, with returning removed values via iterator.
@ -168,7 +171,7 @@ class StreamingSessionWindowStateManagerImplV1(
override def updateSessions( override def updateSessions(
store: StateStore, store: StateStore,
key: UnsafeRow, key: UnsafeRow,
sessions: Seq[UnsafeRow]): Unit = { sessions: Seq[UnsafeRow]): (Long, Long) = {
// Below two will be used multiple times - need to make sure this is not a stream or iterator. // Below two will be used multiple times - need to make sure this is not a stream or iterator.
val newValues = sessions.toList val newValues = sessions.toList
val savedStates = getSessionsWithKeys(store, key) val savedStates = getSessionsWithKeys(store, key)
@ -225,7 +228,7 @@ class StreamingSessionWindowStateManagerImplV1(
store: StateStore, store: StateStore,
key: UnsafeRow, key: UnsafeRow,
oldValues: List[(UnsafeRow, UnsafeRow)], oldValues: List[(UnsafeRow, UnsafeRow)],
values: List[UnsafeRow]): Unit = { values: List[UnsafeRow]): (Long, Long) = {
// Here the key doesn't represent the state key - we need to construct the key for state // Here the key doesn't represent the state key - we need to construct the key for state
val keyAndValues = values.map { row => val keyAndValues = values.map { row =>
val sessionStart = helper.extractTimePair(row)._1 val sessionStart = helper.extractTimePair(row)._1
@ -236,16 +239,24 @@ class StreamingSessionWindowStateManagerImplV1(
val keysForValues = keyAndValues.map(_._1) val keysForValues = keyAndValues.map(_._1)
val keysForOldValues = oldValues.map(_._1) val keysForOldValues = oldValues.map(_._1)
var upsertedRows = 0L
var deletedRows = 0L
// We should "replace" the value instead of "delete" and "put" if the start time // We should "replace" the value instead of "delete" and "put" if the start time
// equals to. This will remove unnecessary tombstone being written to the delta, which is // equals to. This will remove unnecessary tombstone being written to the delta, which is
// implementation details on state store implementations. // implementation details on state store implementations.
keysForOldValues.filterNot(keysForValues.contains).foreach { oldKey => keysForOldValues.filterNot(keysForValues.contains).foreach { oldKey =>
store.remove(oldKey) store.remove(oldKey)
deletedRows += 1
} }
keyAndValues.foreach { case (key, value) => keyAndValues.foreach { case (key, value) =>
store.put(key, value) store.put(key, value)
upsertedRows += 1
} }
(upsertedRows, deletedRows)
} }
override def abortIfNeeded(store: StateStore): Unit = { override def abortIfNeeded(store: StateStore): Unit = {


@ -20,7 +20,9 @@ package org.apache.spark.sql.execution.streaming
import java.util.UUID import java.util.UUID
import java.util.concurrent.TimeUnit._ import java.util.concurrent.TimeUnit._
import scala.annotation.tailrec
import scala.collection.JavaConverters._ import scala.collection.JavaConverters._
import scala.collection.mutable
import org.apache.spark.SparkContext import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD import org.apache.spark.rdd.RDD
@ -511,6 +513,293 @@ case class StateStoreSaveExec(
copy(child = newChild) copy(child = newChild)
} }
/**
* This class sorts input rows and existing sessions in state and provides output rows as
* sorted by "group keys + start time of session window".
*
* Refer [[MergingSortWithSessionWindowStateIterator]] for more details.
*/
case class SessionWindowStateStoreRestoreExec(
keyWithoutSessionExpressions: Seq[Attribute],
sessionExpression: Attribute,
stateInfo: Option[StatefulOperatorStateInfo],
eventTimeWatermark: Option[Long],
stateFormatVersion: Int,
child: SparkPlan)
extends UnaryExecNode with StateStoreReader with WatermarkSupport {
override def keyExpressions: Seq[Attribute] = keyWithoutSessionExpressions
assert(keyExpressions.nonEmpty, "Grouping key must be specified when using sessionWindow")
private val stateManager = StreamingSessionWindowStateManager.createStateManager(
keyWithoutSessionExpressions, sessionExpression, child.output, stateFormatVersion)
override protected def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
child.execute().mapPartitionsWithReadStateStore(
getStateInfo,
stateManager.getStateKeySchema,
stateManager.getStateValueSchema,
numColsPrefixKey = stateManager.getNumColsForPrefixKey,
session.sessionState,
Some(session.streams.stateStoreCoordinator)) { case (store, iter) =>
// We need to filter out outdated inputs
val filteredIterator = watermarkPredicateForData match {
case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row))
case None => iter
}
new MergingSortWithSessionWindowStateIterator(
filteredIterator,
stateManager,
store,
keyWithoutSessionExpressions,
sessionExpression,
child.output).map { row =>
numOutputRows += 1
row
}
}
}
override def output: Seq[Attribute] = child.output
override def outputPartitioning: Partitioning = child.outputPartitioning
override def outputOrdering: Seq[SortOrder] = {
(keyWithoutSessionExpressions ++ Seq(sessionExpression)).map(SortOrder(_, Ascending))
}
override def requiredChildDistribution: Seq[Distribution] = {
ClusteredDistribution(keyWithoutSessionExpressions, stateInfo.map(_.numPartitions)) :: Nil
}
override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
Seq((keyWithoutSessionExpressions ++ Seq(sessionExpression)).map(SortOrder(_, Ascending)))
}
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
copy(child = newChild)
}
/**
* For each input tuple, the key is calculated and the tuple is `put` into the [[StateStore]].
*/
case class SessionWindowStateStoreSaveExec(
keyWithoutSessionExpressions: Seq[Attribute],
sessionExpression: Attribute,
stateInfo: Option[StatefulOperatorStateInfo] = None,
outputMode: Option[OutputMode] = None,
eventTimeWatermark: Option[Long] = None,
stateFormatVersion: Int,
child: SparkPlan)
extends UnaryExecNode with StateStoreWriter with WatermarkSupport {
override def keyExpressions: Seq[Attribute] = keyWithoutSessionExpressions
private val stateManager = StreamingSessionWindowStateManager.createStateManager(
keyWithoutSessionExpressions, sessionExpression, child.output, stateFormatVersion)
override protected def doExecute(): RDD[InternalRow] = {
metrics // force lazy init at driver
assert(outputMode.nonEmpty,
"Incorrect planning in IncrementalExecution, outputMode has not been set")
assert(keyExpressions.nonEmpty,
"Grouping key must be specified when using sessionWindow")
child.execute().mapPartitionsWithStateStore(
getStateInfo,
stateManager.getStateKeySchema,
stateManager.getStateValueSchema,
numColsPrefixKey = stateManager.getNumColsForPrefixKey,
session.sessionState,
Some(session.streams.stateStoreCoordinator)) { case (store, iter) =>
val numOutputRows = longMetric("numOutputRows")
val numRemovedStateRows = longMetric("numRemovedStateRows")
val allUpdatesTimeMs = longMetric("allUpdatesTimeMs")
val allRemovalsTimeMs = longMetric("allRemovalsTimeMs")
val commitTimeMs = longMetric("commitTimeMs")
outputMode match {
// Update and output all rows in the StateStore.
case Some(Complete) =>
allUpdatesTimeMs += timeTakenMs {
putToStore(iter, store)
}
commitTimeMs += timeTakenMs {
stateManager.commit(store)
}
setStoreMetrics(store)
stateManager.iterator(store).map { row =>
numOutputRows += 1
row
}
// Update and output only rows being evicted from the StateStore
// Assumption: watermark predicates must be non-empty if append mode is allowed
case Some(Append) =>
allUpdatesTimeMs += timeTakenMs {
val filteredIter = applyRemovingRowsOlderThanWatermark(iter,
watermarkPredicateForData.get)
putToStore(filteredIter, store)
}
val removalStartTimeNs = System.nanoTime
new NextIterator[InternalRow] {
private val removedIter = stateManager.removeByValueCondition(
store, watermarkPredicateForData.get.eval)
override protected def getNext(): InternalRow = {
if (!removedIter.hasNext) {
finished = true
null
} else {
numRemovedStateRows += 1
numOutputRows += 1
removedIter.next()
}
}
override protected def close(): Unit = {
allRemovalsTimeMs += NANOSECONDS.toMillis(System.nanoTime - removalStartTimeNs)
commitTimeMs += timeTakenMs { store.commit() }
setStoreMetrics(store)
setOperatorMetrics()
}
}
case Some(Update) =>
val baseIterator = watermarkPredicateForData match {
case Some(predicate) => applyRemovingRowsOlderThanWatermark(iter, predicate)
case None => iter
}
val iterPutToStore = iteratorPutToStore(baseIterator, store,
returnOnlyUpdatedRows = true)
new NextIterator[InternalRow] {
private val updatesStartTimeNs = System.nanoTime
override protected def getNext(): InternalRow = {
if (iterPutToStore.hasNext) {
val row = iterPutToStore.next()
numOutputRows += 1
row
} else {
finished = true
null
}
}
override protected def close(): Unit = {
allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs)
allRemovalsTimeMs += timeTakenMs {
if (watermarkPredicateForData.nonEmpty) {
val removedIter = stateManager.removeByValueCondition(
store, watermarkPredicateForData.get.eval)
while (removedIter.hasNext) {
numRemovedStateRows += 1
removedIter.next()
}
}
}
commitTimeMs += timeTakenMs { store.commit() }
setStoreMetrics(store)
setOperatorMetrics()
}
}
case _ => throw QueryExecutionErrors.invalidStreamingOutputModeError(outputMode)
}
}
}
override def output: Seq[Attribute] = child.output
override def outputPartitioning: Partitioning = child.outputPartitioning
override def requiredChildDistribution: Seq[Distribution] = {
ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
}
override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = {
(outputMode.contains(Append) || outputMode.contains(Update)) &&
eventTimeWatermark.isDefined &&
newMetadata.batchWatermarkMs > eventTimeWatermark.get
}
private def iteratorPutToStore(
iter: Iterator[InternalRow],
store: StateStore,
returnOnlyUpdatedRows: Boolean): Iterator[InternalRow] = {
val numUpdatedStateRows = longMetric("numUpdatedStateRows")
val numRemovedStateRows = longMetric("numRemovedStateRows")
new NextIterator[InternalRow] {
var curKey: UnsafeRow = null
val curValuesOnKey = new mutable.ArrayBuffer[UnsafeRow]()
private def applyChangesOnKey(): Unit = {
if (curValuesOnKey.nonEmpty) {
val (upserted, deleted) = stateManager.updateSessions(store, curKey, curValuesOnKey.toSeq)
numUpdatedStateRows += upserted
numRemovedStateRows += deleted
curValuesOnKey.clear
}
}
@tailrec
override protected def getNext(): InternalRow = {
if (!iter.hasNext) {
applyChangesOnKey()
finished = true
return null
}
val row = iter.next().asInstanceOf[UnsafeRow]
val key = stateManager.extractKeyWithoutSession(row)
if (curKey == null || curKey != key) {
// new group appears
applyChangesOnKey()
curKey = key.copy()
}
// must copy the row, for this row is a reference in iterator and
// will change when iter.next
curValuesOnKey += row.copy
if (!returnOnlyUpdatedRows) {
row
} else {
if (stateManager.newOrModified(store, row)) {
row
} else {
// current row isn't the "updated" row, continue to the next row
getNext()
}
}
}
override protected def close(): Unit = {}
}
}
private def putToStore(baseIter: Iterator[InternalRow], store: StateStore): Unit = {
val iterPutToStore = iteratorPutToStore(baseIter, store, returnOnlyUpdatedRows = false)
while (iterPutToStore.hasNext) {
iterPutToStore.next()
}
}
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
copy(child = newChild)
}
/** Physical operator for executing streaming Deduplicate. */ /** Physical operator for executing streaming Deduplicate. */
case class StreamingDeduplicateExec( case class StreamingDeduplicateExec(
keyExpressions: Seq[Attribute], keyExpressions: Seq[Attribute],


@@ -3630,6 +3630,36 @@ object functions {
    window(timeColumn, windowDuration, windowDuration, "0 second")
  }

  /**
   * Generates session window given a timestamp specifying column.
   *
   * Session window is one of dynamic windows, which means the length of window is varying
   * according to the given inputs. The length of session window is defined as "the timestamp
   * of latest input of the session + gap duration", so when the new inputs are bound to the
   * current session window, the end time of session window can be expanded according to the new
   * inputs.
   *
   * Windows can support microsecond precision. gapDuration in the order of months are not
   * supported.
   *
   * For a streaming query, you may use the function `current_timestamp` to generate windows on
   * processing time.
   *
   * @param timeColumn The column or the expression to use as the timestamp for windowing by time.
   *                   The time column must be of TimestampType.
   * @param gapDuration A string specifying the timeout of the session, e.g. `10 minutes`,
   *                    `1 second`. Check `org.apache.spark.unsafe.types.CalendarInterval` for
   *                    valid duration identifiers.
   *
   * @group datetime_funcs
   * @since 3.2.0
   */
  def session_window(timeColumn: Column, gapDuration: String): Column = {
    withExpr {
      SessionWindow(timeColumn.expr, gapDuration)
    }.as("session_window")
  }

  /**
   * Creates timestamp from the number of seconds since UTC epoch.
   * @group datetime_funcs


@@ -1,8 +1,8 @@
<!-- Automatically generated by ExpressionsSchemaSuite -->
## Summary
- Number of queries: 361
- Number of expressions that missing example: 14
- Expressions missing examples: bigint,binary,boolean,date,decimal,double,float,int,smallint,string,timestamp,tinyint,session_window,window
## Schema of Built-in Functions
| Class name | Function name or alias | Query example | Output schema |
| ---------- | ---------------------- | ------------- | ------------- |
@@ -244,6 +244,7 @@
| org.apache.spark.sql.catalyst.expressions.SecondsToTimestamp | timestamp_seconds | SELECT timestamp_seconds(1230219000) | struct<timestamp_seconds(1230219000):timestamp> |
| org.apache.spark.sql.catalyst.expressions.Sentences | sentences | SELECT sentences('Hi there! Good morning.') | struct<sentences(Hi there! Good morning., , ):array<array<string>>> |
| org.apache.spark.sql.catalyst.expressions.Sequence | sequence | SELECT sequence(1, 5) | struct<sequence(1, 5):array<int>> |
| org.apache.spark.sql.catalyst.expressions.SessionWindow | session_window | N/A | N/A |
| org.apache.spark.sql.catalyst.expressions.Sha1 | sha | SELECT sha('Spark') | struct<sha(Spark):string> |
| org.apache.spark.sql.catalyst.expressions.Sha1 | sha1 | SELECT sha1('Spark') | struct<sha1(Spark):string> |
| org.apache.spark.sql.catalyst.expressions.Sha2 | sha2 | SELECT sha2('Spark', 256) | struct<sha2(Spark, 256):string> |


@ -0,0 +1,290 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import org.scalatest.BeforeAndAfterEach
import org.apache.spark.sql.catalyst.plans.logical.Expand
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.StringType
class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession
with BeforeAndAfterEach {
import testImplicits._
test("simple session window with record at window start") {
val df = Seq(
("2016-03-27 19:39:30", 1, "a")).toDF("time", "value", "id")
checkAnswer(
df.groupBy(session_window($"time", "10 seconds"))
.agg(count("*").as("counts"))
.orderBy($"session_window.start".asc)
.select($"session_window.start".cast("string"), $"session_window.end".cast("string"),
$"counts"),
Seq(
Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1)
)
)
}
test("session window groupBy statement") {
val df = Seq(
("2016-03-27 19:39:34", 1, "a"),
("2016-03-27 19:39:56", 2, "a"),
("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id")
// session window handles sort while applying group by
// whereas time window doesn't
checkAnswer(
df.groupBy(session_window($"time", "10 seconds"))
.agg(count("*").as("counts"))
.orderBy($"session_window.start".asc)
.select("counts"),
Seq(Row(2), Row(1))
)
}
test("session window groupBy with multiple keys statement") {
val df = Seq(
("2016-03-27 19:39:34", 1, "a"),
("2016-03-27 19:39:39", 1, "a"),
("2016-03-27 19:39:56", 2, "a"),
("2016-03-27 19:40:04", 2, "a"),
("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id")
// session window handles sort while applying group by
// whereas time window doesn't
// expected sessions
// key "a" => (19:39:34 ~ 19:39:49) (19:39:56 ~ 19:40:14)
// key "b" => (19:39:27 ~ 19:39:37)
checkAnswer(
df.groupBy(session_window($"time", "10 seconds"), 'id)
.agg(count("*").as("counts"), sum("value").as("sum"))
.orderBy($"session_window.start".asc)
.selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)",
"id", "counts", "sum"),
Seq(
Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", "b", 1, 4),
Row("2016-03-27 19:39:34", "2016-03-27 19:39:49", "a", 2, 2),
Row("2016-03-27 19:39:56", "2016-03-27 19:40:14", "a", 2, 4)
)
)
}
test("session window groupBy with multiple keys statement - one distinct") {
val df = Seq(
("2016-03-27 19:39:34", 1, "a"),
("2016-03-27 19:39:39", 1, "a"),
("2016-03-27 19:39:56", 2, "a"),
("2016-03-27 19:40:04", 2, "a"),
("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id")
// session window handles sort while applying group by
// whereas time window doesn't
// expected sessions
// key "a" => (19:39:34 ~ 19:39:49) (19:39:56 ~ 19:40:14)
// key "b" => (19:39:27 ~ 19:39:37)
checkAnswer(
df.groupBy(session_window($"time", "10 seconds"), 'id)
.agg(count("*").as("counts"), sum_distinct(col("value")).as("sum"))
.orderBy($"session_window.start".asc)
.selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)",
"id", "counts", "sum"),
Seq(
Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", "b", 1, 4),
Row("2016-03-27 19:39:34", "2016-03-27 19:39:49", "a", 2, 1),
Row("2016-03-27 19:39:56", "2016-03-27 19:40:14", "a", 2, 2)
)
)
}
test("session window groupBy with multiple keys statement - two distinct") {
val df = Seq(
("2016-03-27 19:39:34", 1, 2, "a"),
("2016-03-27 19:39:39", 1, 2, "a"),
("2016-03-27 19:39:56", 2, 4, "a"),
("2016-03-27 19:40:04", 2, 4, "a"),
("2016-03-27 19:39:27", 4, 8, "b")).toDF("time", "value", "value2", "id")
// session window handles sort while applying group by
// whereas time window doesn't
// expected sessions
// key "a" => (19:39:34 ~ 19:39:49) (19:39:56 ~ 19:40:14)
// key "b" => (19:39:27 ~ 19:39:37)
checkAnswer(
df.groupBy(session_window($"time", "10 seconds"), 'id)
.agg(sum_distinct(col("value")).as("sum"), sum_distinct(col("value2")).as("sum2"))
.orderBy($"session_window.start".asc)
.selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)",
"id", "sum", "sum2"),
Seq(
Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", "b", 4, 8),
Row("2016-03-27 19:39:34", "2016-03-27 19:39:49", "a", 1, 2),
Row("2016-03-27 19:39:56", "2016-03-27 19:40:14", "a", 2, 4)
)
)
}
test("session window groupBy with multiple keys statement - keys overlapped with sessions") {
val df = Seq(
("2016-03-27 19:39:34", 1, "a"),
("2016-03-27 19:39:39", 1, "b"),
("2016-03-27 19:39:40", 2, "a"),
("2016-03-27 19:39:45", 2, "b"),
("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id")
// session window handles sort while applying group by
// whereas time window doesn't
// expected sessions
// a => (19:39:34 ~ 19:39:50)
// b => (19:39:27 ~ 19:39:37), (19:39:39 ~ 19:39:55)
checkAnswer(
df.groupBy(session_window($"time", "10 seconds"), 'id)
.agg(count("*").as("counts"), sum("value").as("sum"))
.orderBy($"session_window.start".asc)
.selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)",
"id", "counts", "sum"),
Seq(
Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", "b", 1, 4),
Row("2016-03-27 19:39:34", "2016-03-27 19:39:50", "a", 2, 3),
Row("2016-03-27 19:39:39", "2016-03-27 19:39:55", "b", 2, 3)
)
)
}
test("session window with multi-column projection") {
val df = Seq(
("2016-03-27 19:39:34", 1, "a"),
("2016-03-27 19:39:56", 2, "a"),
("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id")
.select(session_window($"time", "10 seconds"), $"value")
.orderBy($"session_window.start".asc)
.select($"session_window.start".cast("string"), $"session_window.end".cast("string"),
$"value")
val expands = df.queryExecution.optimizedPlan.find(_.isInstanceOf[Expand])
assert(expands.isEmpty, "Session windows shouldn't require expand")
checkAnswer(
df,
Seq(
Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", 4),
Row("2016-03-27 19:39:34", "2016-03-27 19:39:44", 1),
Row("2016-03-27 19:39:56", "2016-03-27 19:40:06", 2)
)
)
}
test("session window combined with explode expression") {
val df = Seq(
("2016-03-27 19:39:34", 1, Seq("a", "b")),
("2016-03-27 19:39:56", 2, Seq("a", "c", "d"))).toDF("time", "value", "ids")
checkAnswer(
df.select(session_window($"time", "10 seconds"), $"value", explode($"ids"))
.orderBy($"session_window.start".asc).select("value"),
// first window exploded to two rows ("a" and "b"), second window exploded to three rows ("a", "c", "d")
Seq(Row(1), Row(1), Row(2), Row(2), Row(2))
)
}
test("null timestamps") {
val df = Seq(
("2016-03-27 09:00:05", 1),
("2016-03-27 09:00:32", 2),
(null, 3),
(null, 4)).toDF("time", "value")
checkDataset(
df.select(session_window($"time", "10 seconds"), $"value")
.orderBy($"session_window.start".asc)
.select("value")
.as[Int],
1, 2) // null columns are dropped
}
// NOTE: unlike time window, joining on session windows without grouping
// doesn't merge sessions, so two rows are joined only if their session ranges are exactly the same
test("multiple session windows in a single operator throws nice exception") {
val df = Seq(
("2016-03-27 09:00:02", 3),
("2016-03-27 09:00:35", 6)).toDF("time", "value")
val e = intercept[AnalysisException] {
df.select(session_window($"time", "10 second"), session_window($"time", "15 second"))
.collect()
}
assert(e.getMessage.contains(
"Multiple time/session window expressions would result in a cartesian product"))
}
test("aliased session windows") {
val df = Seq(
("2016-03-27 19:39:34", 1, Seq("a", "b")),
("2016-03-27 19:39:56", 2, Seq("a", "c", "d"))).toDF("time", "value", "ids")
checkAnswer(
df.select(session_window($"time", "10 seconds").as("session_window"), $"value")
.orderBy($"session_window.start".asc)
.select("value"),
Seq(Row(1), Row(2))
)
}
private def withTempTable(f: String => Unit): Unit = {
val tableName = "temp"
Seq(
("2016-03-27 19:39:34", 1),
("2016-03-27 19:39:56", 2),
("2016-03-27 19:39:27", 4)).toDF("time", "value").createOrReplaceTempView(tableName)
try {
f(tableName)
} finally {
spark.catalog.dropTempView(tableName)
}
}
test("time window in SQL with single string expression") {
withTempTable { table =>
checkAnswer(
spark.sql(s"""select session_window(time, "10 seconds"), value from $table""")
.select($"session_window.start".cast(StringType), $"session_window.end".cast(StringType),
$"value"),
Seq(
Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", 4),
Row("2016-03-27 19:39:34", "2016-03-27 19:39:44", 1),
Row("2016-03-27 19:39:56", "2016-03-27 19:40:06", 2)
)
)
}
}
}
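Every "expected sessions" comment in this suite follows from the same arithmetic: each event initially spans [eventTime, eventTime + gap), and per-key windows that overlap (or touch) are merged into one session. The snippet below is a self-contained sketch of that rule for sanity-checking the expectations; it is not Spark's merging code, and the helper names are made up:

```
// Hedged sketch of the boundary rule behind the "expected sessions" comments.
// Timestamps are plain seconds; this is illustration only, not Spark's implementation.
case class Session(start: Long, end: Long, count: Int)

def sessionize(eventTimes: Seq[Long], gapSec: Long): Seq[Session] =
  eventTimes.sorted.foldLeft(List.empty[Session]) { (acc, ts) =>
    acc match {
      // an event at or before the current session end extends that session
      case current :: older if ts <= current.end =>
        current.copy(end = math.max(current.end, ts + gapSec), count = current.count + 1) :: older
      // otherwise the event opens a new session [ts, ts + gap)
      case _ =>
        Session(ts, ts + gapSec, 1) :: acc
    }
  }.reverse

// Key "a" in the multi-key tests, expressed as seconds after 19:39:00, with a 10-second gap:
// sessionize(Seq(34L, 39L, 56L, 64L), 10L) == Seq(Session(34, 49, 2), Session(56, 74, 2))
// i.e. (19:39:34 ~ 19:39:49) and (19:39:56 ~ 19:40:14), matching the comments above.
```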


@@ -239,7 +239,7 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSparkSession {
       df.select(window($"time", "10 second"), window($"time", "15 second")).collect()
     }
     assert(e.getMessage.contains(
-      "Multiple time window expressions would result in a cartesian product"))
+      "Multiple time/session window expressions would result in a cartesian product"))
   }

   test("aliased windows") {


@@ -136,7 +136,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
   test("SPARK-14415: All functions should have own descriptions") {
     for (f <- spark.sessionState.functionRegistry.listFunction()) {
-      if (!Seq("cube", "grouping", "grouping_id", "rollup", "window").contains(f.unquotedString)) {
+      if (!Seq("cube", "grouping", "grouping_id", "rollup", "window",
+        "session_window").contains(f.unquotedString)) {
         checkKeywordsNotExist(sql(s"describe function $f"), "N/A.")
       }
     }


@@ -199,9 +199,9 @@ class UpdatingSessionsIteratorSuite extends SharedSparkSession {
     val row6 = createRow("a", 2, 115, 125, 20, 1.2)
     val rows3 = List(row5, row6)

+    // This is to test the edge case that the last input row creates a new session.
     val row7 = createRow("a", 2, 127, 137, 30, 1.3)
-    val row8 = createRow("a", 2, 135, 145, 40, 1.4)
-    val rows4 = List(row7, row8)
+    val rows4 = List(row7)

     val rowsAll = rows1 ++ rows2 ++ rows3 ++ rows4

@@ -244,8 +244,8 @@
     }

     retRows4.zip(rows4).foreach { case (retRow, expectedRow) =>
-      // session being expanded to (127 ~ 145)
-      assertRowsEqualsWithNewSession(expectedRow, retRow, 127, 145)
+      // session being expanded to (127 ~ 137)
+      assertRowsEqualsWithNewSession(expectedRow, retRow, 127, 137)
     }

     assert(iterator.hasNext === false)


@@ -133,6 +133,7 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession {
     val ignoreSet = Set(
       // Explicitly inherits NonSQLExpression, and has no ExpressionDescription
       "org.apache.spark.sql.catalyst.expressions.TimeWindow",
+      "org.apache.spark.sql.catalyst.expressions.SessionWindow",
      // Cast aliases do not need examples
      "org.apache.spark.sql.catalyst.expressions.Cast")


@@ -0,0 +1,460 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.streaming
import java.util.Locale
import org.scalatest.BeforeAndAfter
import org.scalatest.matchers.must.Matchers
import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider}
import org.apache.spark.sql.functions.{count, session_window, sum}
import org.apache.spark.sql.internal.SQLConf
class StreamingSessionWindowSuite extends StreamTest
with BeforeAndAfter with Matchers with Logging {
import testImplicits._
after {
sqlContext.streams.active.foreach(_.stop())
}
def testWithAllOptions(name: String, confPairs: (String, String)*)
(func: => Any): Unit = {
val mergingSessionOptions = Seq(true, false).map { value =>
(SQLConf.STREAMING_SESSION_WINDOW_MERGE_SESSIONS_IN_LOCAL_PARTITION.key, value)
}
val providerOptions = Seq(
classOf[HDFSBackedStateStoreProvider].getCanonicalName,
classOf[RocksDBStateStoreProvider].getCanonicalName
).map { value =>
(SQLConf.STATE_STORE_PROVIDER_CLASS.key, value.stripSuffix("$"))
}
val availableOptions = for (
opt1 <- mergingSessionOptions;
opt2 <- providerOptions
) yield (opt1, opt2)
for (option <- availableOptions) {
test(s"$name - merging sessions in local partition: ${option._1._2} / " +
s"provider: ${option._2._2}") {
withSQLConf(confPairs ++
Seq(
option._1._1 -> option._1._2.toString,
option._2._1 -> option._2._2): _*) {
func
}
}
}
}
testWithAllOptions("complete mode - session window") {
// Implements StructuredSessionization.scala leveraging the "session_window" function
// as a test, to verify that sessionization works with a simple example.
// Note that complete mode doesn't honor the watermark: even if one is specified,
// the watermark will always be Unix timestamp 0.
val inputData = MemoryStream[(String, Long)]
// Split the lines into words, treat words as sessionId of events
val events = inputData.toDF()
.select($"_1".as("value"), $"_2".as("timestamp"))
.withColumn("eventTime", $"timestamp".cast("timestamp"))
.selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime")
val sessionUpdates = events
.groupBy(session_window($"eventTime", "10 seconds") as 'session, 'sessionId)
.agg(count("*").as("numEvents"))
.selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)",
"CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs",
"numEvents")
sessionUpdates.explain()
testStream(sessionUpdates, OutputMode.Complete())(
AddData(inputData,
("hello world spark streaming", 40L),
("world hello structured streaming", 41L)
),
CheckNewAnswer(
("hello", 40, 51, 11, 2),
("world", 40, 51, 11, 2),
("streaming", 40, 51, 11, 2),
("spark", 40, 50, 10, 1),
("structured", 41, 51, 10, 1)
),
// placing new sessions "before" previous sessions
AddData(inputData, ("spark streaming", 25L)),
CheckNewAnswer(
("spark", 25, 35, 10, 1),
("streaming", 25, 35, 10, 1),
("hello", 40, 51, 11, 2),
("world", 40, 51, 11, 2),
("streaming", 40, 51, 11, 2),
("spark", 40, 50, 10, 1),
("structured", 41, 51, 10, 1)
),
// concatenating multiple previous sessions into one
AddData(inputData, ("spark streaming", 30L)),
CheckNewAnswer(
("spark", 25, 50, 25, 3),
("streaming", 25, 51, 26, 4),
("hello", 40, 51, 11, 2),
("world", 40, 51, 11, 2),
("structured", 41, 51, 10, 1)
),
// placing new sessions after previous sessions
AddData(inputData, ("hello apache spark", 60L)),
CheckNewAnswer(
("spark", 25, 50, 25, 3),
("streaming", 25, 51, 26, 4),
("hello", 40, 51, 11, 2),
("world", 40, 51, 11, 2),
("structured", 41, 51, 10, 1),
("hello", 60, 70, 10, 1),
("apache", 60, 70, 10, 1),
("spark", 60, 70, 10, 1)
),
AddData(inputData, ("structured streaming", 90L)),
CheckNewAnswer(
("spark", 25, 50, 25, 3),
("streaming", 25, 51, 26, 4),
("hello", 40, 51, 11, 2),
("world", 40, 51, 11, 2),
("structured", 41, 51, 10, 1),
("hello", 60, 70, 10, 1),
("apache", 60, 70, 10, 1),
("spark", 60, 70, 10, 1),
("structured", 90, 100, 10, 1),
("streaming", 90, 100, 10, 1)
)
)
}
testWithAllOptions("complete mode - session window - no key") {
// Complete mode doesn't honor the watermark: even if one is specified,
// the watermark will always be Unix timestamp 0.
val inputData = MemoryStream[Int]
val windowedAggregation = inputData.toDF()
.selectExpr("*")
.withColumn("eventTime", $"value".cast("timestamp"))
.groupBy(session_window($"eventTime", "5 seconds") as 'session)
.agg(count("*") as 'count, sum("value") as 'sum)
.select($"session".getField("start").cast("long").as[Long],
$"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long])
val e = intercept[StreamingQueryException] {
testStream(windowedAggregation, OutputMode.Complete())(
AddData(inputData, 40),
CheckAnswer() // this is just to trigger the exception
)
}
Seq("Global aggregation with session window", "not supported").foreach { m =>
assert(e.getMessage.toLowerCase(Locale.ROOT).contains(m.toLowerCase(Locale.ROOT)))
}
}
testWithAllOptions("append mode - session window") {
// Implements StructuredSessionization.scala leveraging the "session_window" function
// as a test, to verify that sessionization works with a simple example.
val inputData = MemoryStream[(String, Long)]
// Split the lines into words, treat words as sessionId of events
val events = inputData.toDF()
.select($"_1".as("value"), $"_2".as("timestamp"))
.withColumn("eventTime", $"timestamp".cast("timestamp"))
.selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime")
.withWatermark("eventTime", "30 seconds")
val sessionUpdates = events
.groupBy(session_window($"eventTime", "10 seconds") as 'session, 'sessionId)
.agg(count("*").as("numEvents"))
.selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)",
"CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs",
"numEvents")
testStream(sessionUpdates, OutputMode.Append())(
AddData(inputData,
("hello world spark streaming", 40L),
("world hello structured streaming", 41L)
),
// watermark: 11
// current sessions
// ("hello", 40, 51, 11, 2),
// ("world", 40, 51, 11, 2),
// ("streaming", 40, 51, 11, 2),
// ("spark", 40, 50, 10, 1),
// ("structured", 41, 51, 10, 1)
CheckNewAnswer(
),
// placing new sessions "before" previous sessions
AddData(inputData, ("spark streaming", 25L)),
// watermark: 11
// current sessions
// ("spark", 25, 35, 10, 1),
// ("streaming", 25, 35, 10, 1),
// ("hello", 40, 51, 11, 2),
// ("world", 40, 51, 11, 2),
// ("streaming", 40, 51, 11, 2),
// ("spark", 40, 50, 10, 1),
// ("structured", 41, 51, 10, 1)
CheckNewAnswer(
),
// late event whose session end (10) is earlier than the watermark (11): should be dropped
AddData(inputData, ("spark streaming", 0L)),
// watermark: 11
// current sessions
// ("spark", 25, 35, 10, 1),
// ("streaming", 25, 35, 10, 1),
// ("hello", 40, 51, 11, 2),
// ("world", 40, 51, 11, 2),
// ("streaming", 40, 51, 11, 2),
// ("spark", 40, 50, 10, 1),
// ("structured", 41, 51, 10, 1)
CheckNewAnswer(
),
// concatenating multiple previous sessions into one
AddData(inputData, ("spark streaming", 30L)),
// watermark: 11
// current sessions
// ("spark", 25, 50, 25, 3),
// ("streaming", 25, 51, 26, 4),
// ("hello", 40, 51, 11, 2),
// ("world", 40, 51, 11, 2),
// ("structured", 41, 51, 10, 1)
CheckNewAnswer(
),
// placing new sessions after previous sessions
AddData(inputData, ("hello apache spark", 60L)),
// watermark: 30
// current sessions
// ("spark", 25, 50, 25, 3),
// ("streaming", 25, 51, 26, 4),
// ("hello", 40, 51, 11, 2),
// ("world", 40, 51, 11, 2),
// ("structured", 41, 51, 10, 1),
// ("hello", 60, 70, 10, 1),
// ("apache", 60, 70, 10, 1),
// ("spark", 60, 70, 10, 1)
CheckNewAnswer(
),
AddData(inputData, ("structured streaming", 90L)),
// watermark: 60
// current sessions
// ("hello", 60, 70, 10, 1),
// ("apache", 60, 70, 10, 1),
// ("spark", 60, 70, 10, 1),
// ("structured", 90, 100, 10, 1),
// ("streaming", 90, 100, 10, 1)
CheckNewAnswer(
("spark", 25, 50, 25, 3),
("streaming", 25, 51, 26, 4),
("hello", 40, 51, 11, 2),
("world", 40, 51, 11, 2),
("structured", 41, 51, 10, 1)
)
)
}
testWithAllOptions("append mode - session window - no key") {
val inputData = MemoryStream[Int]
val windowedAggregation = inputData.toDF()
.selectExpr("*")
.withColumn("eventTime", $"value".cast("timestamp"))
.withWatermark("eventTime", "10 seconds")
.groupBy(session_window($"eventTime", "5 seconds") as 'session)
.agg(count("*") as 'count, sum("value") as 'sum)
.select($"session".getField("start").cast("long").as[Long],
$"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long])
val e = intercept[StreamingQueryException] {
testStream(windowedAggregation)(
AddData(inputData, 40),
CheckAnswer() // this is just to trigger the exception
)
}
Seq("Global aggregation with session window", "not supported").foreach { m =>
assert(e.getMessage.toLowerCase(Locale.ROOT).contains(m.toLowerCase(Locale.ROOT)))
}
}
testWithAllOptions("update mode - session window") {
// Implements StructuredSessionization.scala leveraging the "session_window" function
// as a test, to verify that sessionization works with a simple example.
val inputData = MemoryStream[(String, Long)]
// Split the lines into words, treat words as sessionId of events
val events = inputData.toDF()
.select($"_1".as("value"), $"_2".as("timestamp"))
.withColumn("eventTime", $"timestamp".cast("timestamp"))
.selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime")
.withWatermark("eventTime", "10 seconds")
val sessionUpdates = events
.groupBy(session_window($"eventTime", "10 seconds") as 'session, 'sessionId)
.agg(count("*").as("numEvents"))
.selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)",
"CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs",
"numEvents")
testStream(sessionUpdates, OutputMode.Update())(
AddData(inputData,
("hello world spark streaming", 40L),
("world hello structured streaming", 41L)
),
// watermark: 11
// current sessions
// ("hello", 40, 51, 11, 2),
// ("world", 40, 51, 11, 2),
// ("streaming", 40, 51, 11, 2),
// ("spark", 40, 50, 10, 1),
// ("structured", 41, 51, 10, 1)
CheckNewAnswer(
("hello", 40, 51, 11, 2),
("world", 40, 51, 11, 2),
("streaming", 40, 51, 11, 2),
("spark", 40, 50, 10, 1),
("structured", 41, 51, 10, 1)
),
// placing new sessions "before" previous sessions
AddData(inputData, ("spark streaming", 25L)),
// watermark: 11
// current sessions
// ("spark", 25, 35, 10, 1),
// ("streaming", 25, 35, 10, 1),
// ("hello", 40, 51, 11, 2),
// ("world", 40, 51, 11, 2),
// ("streaming", 40, 51, 11, 2),
// ("spark", 40, 50, 10, 1),
// ("structured", 41, 51, 10, 1)
CheckNewAnswer(
("spark", 25, 35, 10, 1),
("streaming", 25, 35, 10, 1)
),
// late event whose session end (10) is earlier than the watermark (11): should be dropped
AddData(inputData, ("spark streaming", 0L)),
// watermark: 11
// current sessions
// ("spark", 25, 35, 10, 1),
// ("streaming", 25, 35, 10, 1),
// ("hello", 40, 51, 11, 2),
// ("world", 40, 51, 11, 2),
// ("streaming", 40, 51, 11, 2),
// ("spark", 40, 50, 10, 1),
// ("structured", 41, 51, 10, 1)
CheckNewAnswer(
),
// concatenating multiple previous sessions into one
AddData(inputData, ("spark streaming", 30L)),
// watermark: 11
// current sessions
// ("spark", 25, 50, 25, 3),
// ("streaming", 25, 51, 26, 4),
// ("hello", 40, 51, 11, 2),
// ("world", 40, 51, 11, 2),
// ("structured", 41, 51, 10, 1)
CheckNewAnswer(
("spark", 25, 50, 25, 3),
("streaming", 25, 51, 26, 4)
),
// placing new sessions after previous sessions
AddData(inputData, ("hello apache spark", 60L)),
// watermark: 30
// current sessions
// ("spark", 25, 50, 25, 3),
// ("streaming", 25, 51, 26, 4),
// ("hello", 40, 51, 11, 2),
// ("world", 40, 51, 11, 2),
// ("structured", 41, 51, 10, 1),
// ("hello", 60, 70, 10, 1),
// ("apache", 60, 70, 10, 1),
// ("spark", 60, 70, 10, 1)
CheckNewAnswer(
("hello", 60, 70, 10, 1),
("apache", 60, 70, 10, 1),
("spark", 60, 70, 10, 1)
),
AddData(inputData, ("structured streaming", 90L)),
// watermark: 60
// current sessions
// ("hello", 60, 70, 10, 1),
// ("apache", 60, 70, 10, 1),
// ("spark", 60, 70, 10, 1),
// ("structured", 90, 100, 10, 1),
// ("streaming", 90, 100, 10, 1)
// evicted
// ("spark", 25, 50, 25, 3),
// ("streaming", 25, 51, 26, 4),
// ("hello", 40, 51, 11, 2),
// ("world", 40, 51, 11, 2),
// ("structured", 41, 51, 10, 1)
CheckNewAnswer(
("structured", 90, 100, 10, 1),
("streaming", 90, 100, 10, 1)
)
)
}
testWithAllOptions("update mode - session window - no key") {
val inputData = MemoryStream[Int]
val windowedAggregation = inputData.toDF()
.selectExpr("*")
.withColumn("eventTime", $"value".cast("timestamp"))
.withWatermark("eventTime", "10 seconds")
.groupBy(session_window($"eventTime", "5 seconds") as 'session)
.agg(count("*") as 'count, sum("value") as 'sum)
.select($"session".getField("start").cast("long").as[Long],
$"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long])
val e = intercept[StreamingQueryException] {
testStream(windowedAggregation, OutputMode.Update())(
AddData(inputData, 40),
CheckAnswer() // this is just to trigger the exception
)
}
Seq("Global aggregation with session window", "not supported").foreach { m =>
assert(e.getMessage.toLowerCase(Locale.ROOT).contains(m.toLowerCase(Locale.ROOT)))
}
}
}
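All the watermark bookkeeping annotated in the append- and update-mode tests boils down to one check: a session leaves the state store, and in append mode is finally emitted, once the watermark has advanced past its end, because no non-late event can extend it anymore. Below is a hedged, stand-alone sketch of that decision with made-up names; it is not the actual state store code:

```
// Hypothetical model of the eviction decision implied by the comments above.
case class SessionRow(sessionId: String, startSec: Long, endSec: Long, numEvents: Long)

// Sessions whose end is at or before the watermark can no longer be extended by
// non-late events, so they are evicted from state (and emitted in append mode);
// the rest stay in the state store and may still be expanded or merged.
def splitByWatermark(
    sessions: Seq[SessionRow],
    watermarkSec: Long): (Seq[SessionRow], Seq[SessionRow]) =
  sessions.partition(_.endSec <= watermarkSec)

// Example from the append-mode test: once the batch at event time 90 moves the watermark
// to 60, the sessions ending at 50 and 51 fall on the evicted/emitted side, while those
// ending at 70 and 100 remain in state.
```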