[SPARK-36463][SS] Prohibit update mode in streaming aggregation with session window
### What changes were proposed in this pull request?
This PR proposes to prohibit update mode in streaming aggregation with session window.
UnsupportedOperationChecker will check and prohibit the case. As a side effect, this PR also simplifies the code as we can remove the implementation of iterator to support outputs of update mode.
This PR also cleans up test code via deduplicating.
### Why are the changes needed?
The semantic of "update" mode for session window based streaming aggregation is quite unclear.
For normal streaming aggregation, Spark will provide the outputs which can be "upsert"ed based on the grouping key. This is based on the fact grouping key won't be changed.
This doesn't hold true for session window based streaming aggregation, as session range is changing.
If end users leverage their knowledge about streaming aggregation, they will consider the key as grouping key + session (since they'll specify these things in groupBy), and it's high likely possible that existing row is not updated (overwritten) and ended up with having different rows.
If end users consider the key as grouping key, there's a small chance for end users to upsert the session correctly, though only the last updated session will be stored so it won't work with event time processing which there could be multiple active sessions.
### Does this PR introduce _any_ user-facing change?
No, as we haven't released this feature.
### How was this patch tested?
Updated tests.
Closes #33689 from HeartSaVioR/SPARK-36463.
Authored-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
(cherry picked from commit ed60aaa9f1
)
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
This commit is contained in:
parent
c6b683e5a2
commit
161908c10d
|
@ -1134,6 +1134,13 @@ sessionizedCounts = events \
|
|||
</div>
|
||||
</div>
|
||||
|
||||
Note that there are some restrictions when you use session window in streaming query, like below:
|
||||
|
||||
- "Update mode" as output mode is not supported.
|
||||
- There should be at least one column in addition to `session_window` in grouping key.
|
||||
|
||||
For batch query, global window (only having `session_window` in grouping key) is supported.
|
||||
|
||||
##### Conditions for watermarking to clean aggregation state
|
||||
{:.no_toc}
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis
|
|||
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.sql.AnalysisException
|
||||
import org.apache.spark.sql.catalyst.expressions.{Attribute, CurrentDate, CurrentTimestampLike, GroupingSets, LocalTimestamp, MonotonicallyIncreasingID}
|
||||
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, CurrentDate, CurrentTimestampLike, GroupingSets, LocalTimestamp, MonotonicallyIncreasingID, SessionWindow}
|
||||
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
|
||||
import org.apache.spark.sql.catalyst.plans._
|
||||
import org.apache.spark.sql.catalyst.plans.logical._
|
||||
|
@ -169,6 +169,21 @@ object UnsupportedOperationChecker extends Logging {
|
|||
s"streaming DataFrames/DataSets without watermark")(plan)
|
||||
}
|
||||
|
||||
case InternalOutputModes.Update if aggregates.nonEmpty =>
|
||||
val aggregate = aggregates.head
|
||||
|
||||
val existingSessionWindow = aggregate.groupingExpressions.exists {
|
||||
case attr: AttributeReference
|
||||
if attr.metadata.contains(SessionWindow.marker) &&
|
||||
attr.metadata.getBoolean(SessionWindow.marker) => true
|
||||
case _ => false
|
||||
}
|
||||
|
||||
if (existingSessionWindow) {
|
||||
throwError(s"$outputMode output mode not supported for session window on " +
|
||||
"streaming DataFrames/DataSets")(plan)
|
||||
}
|
||||
|
||||
case InternalOutputModes.Complete if aggregates.isEmpty =>
|
||||
throwError(
|
||||
s"$outputMode output mode not supported when there are no streaming aggregations on " +
|
||||
|
|
|
@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.streaming
|
|||
import java.util.UUID
|
||||
import java.util.concurrent.TimeUnit._
|
||||
|
||||
import scala.annotation.tailrec
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.mutable
|
||||
|
||||
|
@ -672,46 +671,6 @@ case class SessionWindowStateStoreSaveExec(
|
|||
}
|
||||
}
|
||||
|
||||
case Some(Update) =>
|
||||
val baseIterator = watermarkPredicateForData match {
|
||||
case Some(predicate) => applyRemovingRowsOlderThanWatermark(iter, predicate)
|
||||
case None => iter
|
||||
}
|
||||
val iterPutToStore = iteratorPutToStore(baseIterator, store,
|
||||
returnOnlyUpdatedRows = true)
|
||||
new NextIterator[InternalRow] {
|
||||
private val updatesStartTimeNs = System.nanoTime
|
||||
|
||||
override protected def getNext(): InternalRow = {
|
||||
if (iterPutToStore.hasNext) {
|
||||
val row = iterPutToStore.next()
|
||||
numOutputRows += 1
|
||||
row
|
||||
} else {
|
||||
finished = true
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
override protected def close(): Unit = {
|
||||
allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs)
|
||||
|
||||
allRemovalsTimeMs += timeTakenMs {
|
||||
if (watermarkPredicateForData.nonEmpty) {
|
||||
val removedIter = stateManager.removeByValueCondition(
|
||||
store, watermarkPredicateForData.get.eval)
|
||||
while (removedIter.hasNext) {
|
||||
numRemovedStateRows += 1
|
||||
removedIter.next()
|
||||
}
|
||||
}
|
||||
}
|
||||
commitTimeMs += timeTakenMs { store.commit() }
|
||||
setStoreMetrics(store)
|
||||
setOperatorMetrics()
|
||||
}
|
||||
}
|
||||
|
||||
case _ => throw QueryExecutionErrors.invalidStreamingOutputModeError(outputMode)
|
||||
}
|
||||
}
|
||||
|
@ -731,68 +690,38 @@ case class SessionWindowStateStoreSaveExec(
|
|||
newMetadata.batchWatermarkMs > eventTimeWatermark.get
|
||||
}
|
||||
|
||||
private def iteratorPutToStore(
|
||||
iter: Iterator[InternalRow],
|
||||
store: StateStore,
|
||||
returnOnlyUpdatedRows: Boolean): Iterator[InternalRow] = {
|
||||
private def putToStore(iter: Iterator[InternalRow], store: StateStore): Unit = {
|
||||
val numUpdatedStateRows = longMetric("numUpdatedStateRows")
|
||||
val numRemovedStateRows = longMetric("numRemovedStateRows")
|
||||
|
||||
new NextIterator[InternalRow] {
|
||||
var curKey: UnsafeRow = null
|
||||
val curValuesOnKey = new mutable.ArrayBuffer[UnsafeRow]()
|
||||
var curKey: UnsafeRow = null
|
||||
val curValuesOnKey = new mutable.ArrayBuffer[UnsafeRow]()
|
||||
|
||||
private def applyChangesOnKey(): Unit = {
|
||||
if (curValuesOnKey.nonEmpty) {
|
||||
val (upserted, deleted) = stateManager.updateSessions(store, curKey, curValuesOnKey.toSeq)
|
||||
numUpdatedStateRows += upserted
|
||||
numRemovedStateRows += deleted
|
||||
curValuesOnKey.clear
|
||||
}
|
||||
def applyChangesOnKey(): Unit = {
|
||||
if (curValuesOnKey.nonEmpty) {
|
||||
val (upserted, deleted) = stateManager.updateSessions(store, curKey, curValuesOnKey.toSeq)
|
||||
numUpdatedStateRows += upserted
|
||||
numRemovedStateRows += deleted
|
||||
curValuesOnKey.clear
|
||||
}
|
||||
}
|
||||
|
||||
while (iter.hasNext) {
|
||||
val row = iter.next().asInstanceOf[UnsafeRow]
|
||||
val key = stateManager.extractKeyWithoutSession(row)
|
||||
|
||||
if (curKey == null || curKey != key) {
|
||||
// new group appears
|
||||
applyChangesOnKey()
|
||||
curKey = key.copy()
|
||||
}
|
||||
|
||||
@tailrec
|
||||
override protected def getNext(): InternalRow = {
|
||||
if (!iter.hasNext) {
|
||||
applyChangesOnKey()
|
||||
finished = true
|
||||
return null
|
||||
}
|
||||
|
||||
val row = iter.next().asInstanceOf[UnsafeRow]
|
||||
val key = stateManager.extractKeyWithoutSession(row)
|
||||
|
||||
if (curKey == null || curKey != key) {
|
||||
// new group appears
|
||||
applyChangesOnKey()
|
||||
curKey = key.copy()
|
||||
}
|
||||
|
||||
// must copy the row, for this row is a reference in iterator and
|
||||
// will change when iter.next
|
||||
curValuesOnKey += row.copy
|
||||
|
||||
if (!returnOnlyUpdatedRows) {
|
||||
row
|
||||
} else {
|
||||
if (stateManager.newOrModified(store, row)) {
|
||||
row
|
||||
} else {
|
||||
// current row isn't the "updated" row, continue to the next row
|
||||
getNext()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override protected def close(): Unit = {}
|
||||
// must copy the row, for this row is a reference in iterator and
|
||||
// will change when iter.next
|
||||
curValuesOnKey += row.copy
|
||||
}
|
||||
}
|
||||
|
||||
private def putToStore(baseIter: Iterator[InternalRow], store: StateStore): Unit = {
|
||||
val iterPutToStore = iteratorPutToStore(baseIter, store, returnOnlyUpdatedRows = false)
|
||||
while (iterPutToStore.hasNext) {
|
||||
iterPutToStore.next()
|
||||
}
|
||||
applyChangesOnKey()
|
||||
}
|
||||
|
||||
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
|
||||
|
|
|
@ -23,6 +23,7 @@ import org.scalatest.BeforeAndAfter
|
|||
import org.scalatest.matchers.must.Matchers
|
||||
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.sql.{AnalysisException, DataFrame}
|
||||
import org.apache.spark.sql.execution.streaming.MemoryStream
|
||||
import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider}
|
||||
import org.apache.spark.sql.functions.{count, session_window, sum}
|
||||
|
@ -75,21 +76,7 @@ class StreamingSessionWindowSuite extends StreamTest
|
|||
// always Unix timestamp 0
|
||||
|
||||
val inputData = MemoryStream[(String, Long)]
|
||||
|
||||
// Split the lines into words, treat words as sessionId of events
|
||||
val events = inputData.toDF()
|
||||
.select($"_1".as("value"), $"_2".as("timestamp"))
|
||||
.withColumn("eventTime", $"timestamp".cast("timestamp"))
|
||||
.selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime")
|
||||
|
||||
val sessionUpdates = events
|
||||
.groupBy(session_window($"eventTime", "10 seconds") as 'session, 'sessionId)
|
||||
.agg(count("*").as("numEvents"))
|
||||
.selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)",
|
||||
"CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs",
|
||||
"numEvents")
|
||||
|
||||
sessionUpdates.explain()
|
||||
val sessionUpdates = sessionWindowQuery(inputData)
|
||||
|
||||
testStream(sessionUpdates, OutputMode.Complete())(
|
||||
AddData(inputData,
|
||||
|
@ -160,14 +147,7 @@ class StreamingSessionWindowSuite extends StreamTest
|
|||
// always Unix timestamp 0
|
||||
|
||||
val inputData = MemoryStream[Int]
|
||||
|
||||
val windowedAggregation = inputData.toDF()
|
||||
.selectExpr("*")
|
||||
.withColumn("eventTime", $"value".cast("timestamp"))
|
||||
.groupBy(session_window($"eventTime", "5 seconds") as 'session)
|
||||
.agg(count("*") as 'count, sum("value") as 'sum)
|
||||
.select($"session".getField("start").cast("long").as[Long],
|
||||
$"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long])
|
||||
val windowedAggregation = sessionWindowQueryOnGlobalKey(inputData)
|
||||
|
||||
val e = intercept[StreamingQueryException] {
|
||||
testStream(windowedAggregation, OutputMode.Complete())(
|
||||
|
@ -185,20 +165,7 @@ class StreamingSessionWindowSuite extends StreamTest
|
|||
// as a test, to verify the sessionization works with simple example
|
||||
|
||||
val inputData = MemoryStream[(String, Long)]
|
||||
|
||||
// Split the lines into words, treat words as sessionId of events
|
||||
val events = inputData.toDF()
|
||||
.select($"_1".as("value"), $"_2".as("timestamp"))
|
||||
.withColumn("eventTime", $"timestamp".cast("timestamp"))
|
||||
.selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime")
|
||||
.withWatermark("eventTime", "30 seconds")
|
||||
|
||||
val sessionUpdates = events
|
||||
.groupBy(session_window($"eventTime", "10 seconds") as 'session, 'sessionId)
|
||||
.agg(count("*").as("numEvents"))
|
||||
.selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)",
|
||||
"CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs",
|
||||
"numEvents")
|
||||
val sessionUpdates = sessionWindowQuery(inputData)
|
||||
|
||||
testStream(sessionUpdates, OutputMode.Append())(
|
||||
AddData(inputData,
|
||||
|
@ -291,15 +258,7 @@ class StreamingSessionWindowSuite extends StreamTest
|
|||
|
||||
testWithAllOptions("append mode - session window - no key") {
|
||||
val inputData = MemoryStream[Int]
|
||||
|
||||
val windowedAggregation = inputData.toDF()
|
||||
.selectExpr("*")
|
||||
.withColumn("eventTime", $"value".cast("timestamp"))
|
||||
.withWatermark("eventTime", "10 seconds")
|
||||
.groupBy(session_window($"eventTime", "5 seconds") as 'session)
|
||||
.agg(count("*") as 'count, sum("value") as 'sum)
|
||||
.select($"session".getField("start").cast("long").as[Long],
|
||||
$"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long])
|
||||
val windowedAggregation = sessionWindowQueryOnGlobalKey(inputData)
|
||||
|
||||
val e = intercept[StreamingQueryException] {
|
||||
testStream(windowedAggregation)(
|
||||
|
@ -317,128 +276,52 @@ class StreamingSessionWindowSuite extends StreamTest
|
|||
// as a test, to verify the sessionization works with simple example
|
||||
|
||||
val inputData = MemoryStream[(String, Long)]
|
||||
val sessionUpdates = sessionWindowQuery(inputData)
|
||||
|
||||
val e = intercept[AnalysisException] {
|
||||
testStream(sessionUpdates, OutputMode.Update())(
|
||||
AddData(inputData, ("hello", 40L)),
|
||||
CheckAnswer() // this is just to trigger the exception
|
||||
)
|
||||
}
|
||||
Seq("Update output mode", "not supported", "for session window").foreach { m =>
|
||||
assert(e.getMessage.toLowerCase(Locale.ROOT).contains(m.toLowerCase(Locale.ROOT)))
|
||||
}
|
||||
}
|
||||
|
||||
testWithAllOptions("update mode - session window - no key") {
|
||||
val inputData = MemoryStream[Int]
|
||||
val windowedAggregation = sessionWindowQueryOnGlobalKey(inputData)
|
||||
|
||||
val e = intercept[AnalysisException] {
|
||||
testStream(windowedAggregation, OutputMode.Update())(
|
||||
AddData(inputData, 40),
|
||||
CheckAnswer() // this is just to trigger the exception
|
||||
)
|
||||
}
|
||||
Seq("Update output mode", "not supported", "for session window").foreach { m =>
|
||||
assert(e.getMessage.toLowerCase(Locale.ROOT).contains(m.toLowerCase(Locale.ROOT)))
|
||||
}
|
||||
}
|
||||
|
||||
private def sessionWindowQuery(input: MemoryStream[(String, Long)]): DataFrame = {
|
||||
// Split the lines into words, treat words as sessionId of events
|
||||
val events = inputData.toDF()
|
||||
val events = input.toDF()
|
||||
.select($"_1".as("value"), $"_2".as("timestamp"))
|
||||
.withColumn("eventTime", $"timestamp".cast("timestamp"))
|
||||
.withWatermark("eventTime", "30 seconds")
|
||||
.selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime")
|
||||
.withWatermark("eventTime", "10 seconds")
|
||||
|
||||
val sessionUpdates = events
|
||||
events
|
||||
.groupBy(session_window($"eventTime", "10 seconds") as 'session, 'sessionId)
|
||||
.agg(count("*").as("numEvents"))
|
||||
.selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)",
|
||||
"CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs",
|
||||
"numEvents")
|
||||
|
||||
testStream(sessionUpdates, OutputMode.Update())(
|
||||
AddData(inputData,
|
||||
("hello world spark streaming", 40L),
|
||||
("world hello structured streaming", 41L)
|
||||
),
|
||||
// watermark: 11
|
||||
// current sessions
|
||||
// ("hello", 40, 51, 11, 2),
|
||||
// ("world", 40, 51, 11, 2),
|
||||
// ("streaming", 40, 51, 11, 2),
|
||||
// ("spark", 40, 50, 10, 1),
|
||||
// ("structured", 41, 51, 10, 1)
|
||||
CheckNewAnswer(
|
||||
("hello", 40, 51, 11, 2),
|
||||
("world", 40, 51, 11, 2),
|
||||
("streaming", 40, 51, 11, 2),
|
||||
("spark", 40, 50, 10, 1),
|
||||
("structured", 41, 51, 10, 1)
|
||||
),
|
||||
|
||||
// placing new sessions "before" previous sessions
|
||||
AddData(inputData, ("spark streaming", 25L)),
|
||||
// watermark: 11
|
||||
// current sessions
|
||||
// ("spark", 25, 35, 10, 1),
|
||||
// ("streaming", 25, 35, 10, 1),
|
||||
// ("hello", 40, 51, 11, 2),
|
||||
// ("world", 40, 51, 11, 2),
|
||||
// ("streaming", 40, 51, 11, 2),
|
||||
// ("spark", 40, 50, 10, 1),
|
||||
// ("structured", 41, 51, 10, 1)
|
||||
CheckNewAnswer(
|
||||
("spark", 25, 35, 10, 1),
|
||||
("streaming", 25, 35, 10, 1)
|
||||
),
|
||||
|
||||
// late event which session's end 10 would be later than watermark 11: should be dropped
|
||||
AddData(inputData, ("spark streaming", 0L)),
|
||||
// watermark: 11
|
||||
// current sessions
|
||||
// ("spark", 25, 35, 10, 1),
|
||||
// ("streaming", 25, 35, 10, 1),
|
||||
// ("hello", 40, 51, 11, 2),
|
||||
// ("world", 40, 51, 11, 2),
|
||||
// ("streaming", 40, 51, 11, 2),
|
||||
// ("spark", 40, 50, 10, 1),
|
||||
// ("structured", 41, 51, 10, 1)
|
||||
CheckNewAnswer(
|
||||
),
|
||||
|
||||
// concatenating multiple previous sessions into one
|
||||
AddData(inputData, ("spark streaming", 30L)),
|
||||
// watermark: 11
|
||||
// current sessions
|
||||
// ("spark", 25, 50, 25, 3),
|
||||
// ("streaming", 25, 51, 26, 4),
|
||||
// ("hello", 40, 51, 11, 2),
|
||||
// ("world", 40, 51, 11, 2),
|
||||
// ("structured", 41, 51, 10, 1)
|
||||
CheckNewAnswer(
|
||||
("spark", 25, 50, 25, 3),
|
||||
("streaming", 25, 51, 26, 4)
|
||||
),
|
||||
|
||||
// placing new sessions after previous sessions
|
||||
AddData(inputData, ("hello apache spark", 60L)),
|
||||
// watermark: 30
|
||||
// current sessions
|
||||
// ("spark", 25, 50, 25, 3),
|
||||
// ("streaming", 25, 51, 26, 4),
|
||||
// ("hello", 40, 51, 11, 2),
|
||||
// ("world", 40, 51, 11, 2),
|
||||
// ("structured", 41, 51, 10, 1),
|
||||
// ("hello", 60, 70, 10, 1),
|
||||
// ("apache", 60, 70, 10, 1),
|
||||
// ("spark", 60, 70, 10, 1)
|
||||
CheckNewAnswer(
|
||||
("hello", 60, 70, 10, 1),
|
||||
("apache", 60, 70, 10, 1),
|
||||
("spark", 60, 70, 10, 1)
|
||||
),
|
||||
|
||||
AddData(inputData, ("structured streaming", 90L)),
|
||||
// watermark: 60
|
||||
// current sessions
|
||||
// ("hello", 60, 70, 10, 1),
|
||||
// ("apache", 60, 70, 10, 1),
|
||||
// ("spark", 60, 70, 10, 1),
|
||||
// ("structured", 90, 100, 10, 1),
|
||||
// ("streaming", 90, 100, 10, 1)
|
||||
// evicted
|
||||
// ("spark", 25, 50, 25, 3),
|
||||
// ("streaming", 25, 51, 26, 4),
|
||||
// ("hello", 40, 51, 11, 2),
|
||||
// ("world", 40, 51, 11, 2),
|
||||
// ("structured", 41, 51, 10, 1)
|
||||
CheckNewAnswer(
|
||||
("structured", 90, 100, 10, 1),
|
||||
("streaming", 90, 100, 10, 1)
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
testWithAllOptions("update mode - session window - no key") {
|
||||
val inputData = MemoryStream[Int]
|
||||
|
||||
val windowedAggregation = inputData.toDF()
|
||||
private def sessionWindowQueryOnGlobalKey(input: MemoryStream[Int]): DataFrame = {
|
||||
input.toDF()
|
||||
.selectExpr("*")
|
||||
.withColumn("eventTime", $"value".cast("timestamp"))
|
||||
.withWatermark("eventTime", "10 seconds")
|
||||
|
@ -446,15 +329,5 @@ class StreamingSessionWindowSuite extends StreamTest
|
|||
.agg(count("*") as 'count, sum("value") as 'sum)
|
||||
.select($"session".getField("start").cast("long").as[Long],
|
||||
$"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long])
|
||||
|
||||
val e = intercept[StreamingQueryException] {
|
||||
testStream(windowedAggregation, OutputMode.Update())(
|
||||
AddData(inputData, 40),
|
||||
CheckAnswer() // this is just to trigger the exception
|
||||
)
|
||||
}
|
||||
Seq("Global aggregation with session window", "not supported").foreach { m =>
|
||||
assert(e.getMessage.toLowerCase(Locale.ROOT).contains(m.toLowerCase(Locale.ROOT)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue