[SPARK-36463][SS] Prohibit update mode in streaming aggregation with session window

### What changes were proposed in this pull request?

This PR proposes to prohibit update mode in streaming aggregation with session window.

UnsupportedOperationChecker checks for and prohibits this case. As a side effect, this PR also simplifies the code, since the iterator implementation that produced outputs for update mode can now be removed.

This PR also cleans up the test code by deduplicating common query definitions.

### Why are the changes needed?

The semantics of "update" mode for session window based streaming aggregation are quite unclear.

For normal streaming aggregation, Spark provides outputs that can be "upsert"ed based on the grouping key. This relies on the fact that the grouping key never changes.

This doesn't hold for session window based streaming aggregation, because the session range changes over time.

If end users apply their knowledge of streaming aggregation, they will treat the key as grouping key + session (since that is what they specify in groupBy), and it is highly likely that the existing row is not updated (overwritten), leaving them with multiple different rows instead of a single updated one.

If end users instead treat the key as the grouping key alone, there is a small chance of upserting the session correctly, but only the last updated session would be stored, so it won't work with event-time processing, where there can be multiple active sessions per key.
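
To make the ambiguity concrete, here is a minimal sketch (not taken from this patch) of the kind of query this change now rejects at analysis time. The `rate` source and the derived `userId` column are placeholders chosen for illustration only:

```scala
import org.apache.spark.sql.functions.{count, session_window}
import spark.implicits._

// Hypothetical session-window streaming aggregation that requests "update" output mode.
val sessions = spark.readStream
  .format("rate")                        // assumed source; it emits `timestamp` and `value` columns
  .load()
  .withColumn("userId", $"value" % 10)   // hypothetical grouping column
  .groupBy(session_window($"timestamp", "10 seconds"), $"userId")
  .agg(count("*").as("numEvents"))

// With this PR, UnsupportedOperationChecker fails the query with an AnalysisException,
// because the session range (and thus the would-be upsert key) can change across batches.
sessions.writeStream
  .outputMode("update")
  .format("console")
  .start()
```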

### Does this PR introduce _any_ user-facing change?

No, as we haven't released this feature.

### How was this patch tested?

Updated tests.

Closes #33689 from HeartSaVioR/SPARK-36463.

Authored-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
(cherry picked from commit ed60aaa9f1)
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
Jungtaek Lim 2021-08-11 10:45:52 +09:00
parent c6b683e5a2
commit 161908c10d
4 changed files with 85 additions and 261 deletions


@@ -1134,6 +1134,13 @@ sessionizedCounts = events \
 </div>
 </div>
+
+Note that there are some restrictions when you use session window in streaming query, like below:
+
+- "Update mode" as output mode is not supported.
+- There should be at least one column in addition to `session_window` in grouping key.
+  For batch query, global window (only having `session_window` in grouping key) is supported.
+
 ##### Conditions for watermarking to clean aggregation state
 {:.no_toc}
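
(Aside, not part of the diff: a minimal sketch of the second restriction, assuming DataFrames `streamingEvents` and `batchEvents` that both carry `eventTime` and `userId` columns.)

```scala
import org.apache.spark.sql.functions.{count, session_window}

// Streaming query: the grouping key must contain at least one column besides session_window.
val streamingSessionCounts = streamingEvents     // assumed streaming DataFrame
  .groupBy(session_window($"eventTime", "5 minutes"), $"userId")
  .agg(count("*").as("numEvents"))

// Batch query only: a "global" session window, i.e. session_window as the sole grouping key.
val globalSessionCounts = batchEvents            // assumed batch DataFrame
  .groupBy(session_window($"eventTime", "5 minutes"))
  .agg(count("*").as("numEvents"))
```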


@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.expressions.{Attribute, CurrentDate, CurrentTimestampLike, GroupingSets, LocalTimestamp, MonotonicallyIncreasingID}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, CurrentDate, CurrentTimestampLike, GroupingSets, LocalTimestamp, MonotonicallyIncreasingID, SessionWindow}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -169,6 +169,21 @@ object UnsupportedOperationChecker extends Logging {
             s"streaming DataFrames/DataSets without watermark")(plan)
         }
 
+      case InternalOutputModes.Update if aggregates.nonEmpty =>
+        val aggregate = aggregates.head
+
+        val existingSessionWindow = aggregate.groupingExpressions.exists {
+          case attr: AttributeReference
+            if attr.metadata.contains(SessionWindow.marker) &&
+              attr.metadata.getBoolean(SessionWindow.marker) => true
+          case _ => false
+        }
+
+        if (existingSessionWindow) {
+          throwError(s"$outputMode output mode not supported for session window on " +
+            "streaming DataFrames/DataSets")(plan)
+        }
+
       case InternalOutputModes.Complete if aggregates.isEmpty =>
         throwError(
           s"$outputMode output mode not supported when there are no streaming aggregations on " +


@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.streaming
 import java.util.UUID
 import java.util.concurrent.TimeUnit._
 
-import scala.annotation.tailrec
 import scala.collection.JavaConverters._
 import scala.collection.mutable
@@ -672,46 +671,6 @@ case class SessionWindowStateStoreSaveExec(
           }
         }
 
-      case Some(Update) =>
-        val baseIterator = watermarkPredicateForData match {
-          case Some(predicate) => applyRemovingRowsOlderThanWatermark(iter, predicate)
-          case None => iter
-        }
-
-        val iterPutToStore = iteratorPutToStore(baseIterator, store,
-          returnOnlyUpdatedRows = true)
-        new NextIterator[InternalRow] {
-          private val updatesStartTimeNs = System.nanoTime
-
-          override protected def getNext(): InternalRow = {
-            if (iterPutToStore.hasNext) {
-              val row = iterPutToStore.next()
-              numOutputRows += 1
-              row
-            } else {
-              finished = true
-              null
-            }
-          }
-
-          override protected def close(): Unit = {
-            allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs)
-
-            allRemovalsTimeMs += timeTakenMs {
-              if (watermarkPredicateForData.nonEmpty) {
-                val removedIter = stateManager.removeByValueCondition(
-                  store, watermarkPredicateForData.get.eval)
-                while (removedIter.hasNext) {
-                  numRemovedStateRows += 1
-                  removedIter.next()
-                }
-              }
-            }
-            commitTimeMs += timeTakenMs { store.commit() }
-            setStoreMetrics(store)
-            setOperatorMetrics()
-          }
-        }
-
       case _ => throw QueryExecutionErrors.invalidStreamingOutputModeError(outputMode)
     }
   }
@@ -731,68 +690,38 @@ case class SessionWindowStateStoreSaveExec(
     newMetadata.batchWatermarkMs > eventTimeWatermark.get
   }
 
-  private def iteratorPutToStore(
-      iter: Iterator[InternalRow],
-      store: StateStore,
-      returnOnlyUpdatedRows: Boolean): Iterator[InternalRow] = {
+  private def putToStore(iter: Iterator[InternalRow], store: StateStore): Unit = {
     val numUpdatedStateRows = longMetric("numUpdatedStateRows")
     val numRemovedStateRows = longMetric("numRemovedStateRows")
 
-    new NextIterator[InternalRow] {
-      var curKey: UnsafeRow = null
-      val curValuesOnKey = new mutable.ArrayBuffer[UnsafeRow]()
-
-      private def applyChangesOnKey(): Unit = {
-        if (curValuesOnKey.nonEmpty) {
-          val (upserted, deleted) = stateManager.updateSessions(store, curKey, curValuesOnKey.toSeq)
-          numUpdatedStateRows += upserted
-          numRemovedStateRows += deleted
-          curValuesOnKey.clear
-        }
-      }
-
-      @tailrec
-      override protected def getNext(): InternalRow = {
-        if (!iter.hasNext) {
-          applyChangesOnKey()
-          finished = true
-          return null
-        }
-
-        val row = iter.next().asInstanceOf[UnsafeRow]
-        val key = stateManager.extractKeyWithoutSession(row)
-
-        if (curKey == null || curKey != key) {
-          // new group appears
-          applyChangesOnKey()
-          curKey = key.copy()
-        }
-
-        // must copy the row, for this row is a reference in iterator and
-        // will change when iter.next
-        curValuesOnKey += row.copy
-
-        if (!returnOnlyUpdatedRows) {
-          row
-        } else {
-          if (stateManager.newOrModified(store, row)) {
-            row
-          } else {
-            // current row isn't the "updated" row, continue to the next row
-            getNext()
-          }
-        }
-      }
-
-      override protected def close(): Unit = {}
-    }
-  }
-
-  private def putToStore(baseIter: Iterator[InternalRow], store: StateStore): Unit = {
-    val iterPutToStore = iteratorPutToStore(baseIter, store, returnOnlyUpdatedRows = false)
-    while (iterPutToStore.hasNext) {
-      iterPutToStore.next()
-    }
+    var curKey: UnsafeRow = null
+    val curValuesOnKey = new mutable.ArrayBuffer[UnsafeRow]()
+
+    def applyChangesOnKey(): Unit = {
+      if (curValuesOnKey.nonEmpty) {
+        val (upserted, deleted) = stateManager.updateSessions(store, curKey, curValuesOnKey.toSeq)
+        numUpdatedStateRows += upserted
+        numRemovedStateRows += deleted
+        curValuesOnKey.clear
+      }
+    }
+
+    while (iter.hasNext) {
+      val row = iter.next().asInstanceOf[UnsafeRow]
+      val key = stateManager.extractKeyWithoutSession(row)
+
+      if (curKey == null || curKey != key) {
+        // new group appears
+        applyChangesOnKey()
+        curKey = key.copy()
+      }
+
+      // must copy the row, for this row is a reference in iterator and
+      // will change when iter.next
+      curValuesOnKey += row.copy
+    }
+
+    applyChangesOnKey()
   }
 
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =


@@ -23,6 +23,7 @@ import org.scalatest.BeforeAndAfter
 import org.scalatest.matchers.must.Matchers
 
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{AnalysisException, DataFrame}
 import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider}
 import org.apache.spark.sql.functions.{count, session_window, sum}
@@ -75,21 +76,7 @@ class StreamingSessionWindowSuite extends StreamTest
     // always Unix timestamp 0
     val inputData = MemoryStream[(String, Long)]
 
-    // Split the lines into words, treat words as sessionId of events
-    val events = inputData.toDF()
-      .select($"_1".as("value"), $"_2".as("timestamp"))
-      .withColumn("eventTime", $"timestamp".cast("timestamp"))
-      .selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime")
-
-    val sessionUpdates = events
-      .groupBy(session_window($"eventTime", "10 seconds") as 'session, 'sessionId)
-      .agg(count("*").as("numEvents"))
-      .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)",
-        "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs",
-        "numEvents")
-
-    sessionUpdates.explain()
+    val sessionUpdates = sessionWindowQuery(inputData)
 
     testStream(sessionUpdates, OutputMode.Complete())(
       AddData(inputData,
@@ -160,14 +147,7 @@ class StreamingSessionWindowSuite extends StreamTest
     // always Unix timestamp 0
     val inputData = MemoryStream[Int]
 
-    val windowedAggregation = inputData.toDF()
-      .selectExpr("*")
-      .withColumn("eventTime", $"value".cast("timestamp"))
-      .groupBy(session_window($"eventTime", "5 seconds") as 'session)
-      .agg(count("*") as 'count, sum("value") as 'sum)
-      .select($"session".getField("start").cast("long").as[Long],
-        $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long])
+    val windowedAggregation = sessionWindowQueryOnGlobalKey(inputData)
 
     val e = intercept[StreamingQueryException] {
       testStream(windowedAggregation, OutputMode.Complete())(
@@ -185,20 +165,7 @@ class StreamingSessionWindowSuite extends StreamTest
     // as a test, to verify the sessionization works with simple example
     val inputData = MemoryStream[(String, Long)]
 
-    // Split the lines into words, treat words as sessionId of events
-    val events = inputData.toDF()
-      .select($"_1".as("value"), $"_2".as("timestamp"))
-      .withColumn("eventTime", $"timestamp".cast("timestamp"))
-      .selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime")
-      .withWatermark("eventTime", "30 seconds")
-
-    val sessionUpdates = events
-      .groupBy(session_window($"eventTime", "10 seconds") as 'session, 'sessionId)
-      .agg(count("*").as("numEvents"))
-      .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)",
-        "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs",
-        "numEvents")
+    val sessionUpdates = sessionWindowQuery(inputData)
 
     testStream(sessionUpdates, OutputMode.Append())(
       AddData(inputData,
@@ -291,15 +258,7 @@ class StreamingSessionWindowSuite extends StreamTest
   testWithAllOptions("append mode - session window - no key") {
     val inputData = MemoryStream[Int]
 
-    val windowedAggregation = inputData.toDF()
-      .selectExpr("*")
-      .withColumn("eventTime", $"value".cast("timestamp"))
-      .withWatermark("eventTime", "10 seconds")
-      .groupBy(session_window($"eventTime", "5 seconds") as 'session)
-      .agg(count("*") as 'count, sum("value") as 'sum)
-      .select($"session".getField("start").cast("long").as[Long],
-        $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long])
+    val windowedAggregation = sessionWindowQueryOnGlobalKey(inputData)
 
     val e = intercept[StreamingQueryException] {
       testStream(windowedAggregation)(
@@ -317,128 +276,52 @@ class StreamingSessionWindowSuite extends StreamTest
     // as a test, to verify the sessionization works with simple example
     val inputData = MemoryStream[(String, Long)]
 
+    val sessionUpdates = sessionWindowQuery(inputData)
+
+    val e = intercept[AnalysisException] {
+      testStream(sessionUpdates, OutputMode.Update())(
+        AddData(inputData, ("hello", 40L)),
+        CheckAnswer() // this is just to trigger the exception
+      )
+    }
+    Seq("Update output mode", "not supported", "for session window").foreach { m =>
+      assert(e.getMessage.toLowerCase(Locale.ROOT).contains(m.toLowerCase(Locale.ROOT)))
+    }
+  }
+
+  testWithAllOptions("update mode - session window - no key") {
+    val inputData = MemoryStream[Int]
+
+    val windowedAggregation = sessionWindowQueryOnGlobalKey(inputData)
+
+    val e = intercept[AnalysisException] {
+      testStream(windowedAggregation, OutputMode.Update())(
+        AddData(inputData, 40),
+        CheckAnswer() // this is just to trigger the exception
+      )
+    }
+    Seq("Update output mode", "not supported", "for session window").foreach { m =>
+      assert(e.getMessage.toLowerCase(Locale.ROOT).contains(m.toLowerCase(Locale.ROOT)))
+    }
+  }
+
+  private def sessionWindowQuery(input: MemoryStream[(String, Long)]): DataFrame = {
     // Split the lines into words, treat words as sessionId of events
-    val events = inputData.toDF()
+    val events = input.toDF()
       .select($"_1".as("value"), $"_2".as("timestamp"))
       .withColumn("eventTime", $"timestamp".cast("timestamp"))
+      .withWatermark("eventTime", "30 seconds")
       .selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime")
-      .withWatermark("eventTime", "10 seconds")
 
-    val sessionUpdates = events
+    events
       .groupBy(session_window($"eventTime", "10 seconds") as 'session, 'sessionId)
       .agg(count("*").as("numEvents"))
       .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)",
         "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs",
         "numEvents")
-
-    testStream(sessionUpdates, OutputMode.Update())(
-      AddData(inputData,
-        ("hello world spark streaming", 40L),
-        ("world hello structured streaming", 41L)
-      ),
-      // watermark: 11
-      // current sessions
-      // ("hello", 40, 51, 11, 2),
-      // ("world", 40, 51, 11, 2),
-      // ("streaming", 40, 51, 11, 2),
-      // ("spark", 40, 50, 10, 1),
-      // ("structured", 41, 51, 10, 1)
-      CheckNewAnswer(
-        ("hello", 40, 51, 11, 2),
-        ("world", 40, 51, 11, 2),
-        ("streaming", 40, 51, 11, 2),
-        ("spark", 40, 50, 10, 1),
-        ("structured", 41, 51, 10, 1)
-      ),
-      // placing new sessions "before" previous sessions
-      AddData(inputData, ("spark streaming", 25L)),
-      // watermark: 11
-      // current sessions
-      // ("spark", 25, 35, 10, 1),
-      // ("streaming", 25, 35, 10, 1),
-      // ("hello", 40, 51, 11, 2),
-      // ("world", 40, 51, 11, 2),
-      // ("streaming", 40, 51, 11, 2),
-      // ("spark", 40, 50, 10, 1),
-      // ("structured", 41, 51, 10, 1)
-      CheckNewAnswer(
-        ("spark", 25, 35, 10, 1),
-        ("streaming", 25, 35, 10, 1)
-      ),
-      // late event which session's end 10 would be later than watermark 11: should be dropped
-      AddData(inputData, ("spark streaming", 0L)),
-      // watermark: 11
-      // current sessions
-      // ("spark", 25, 35, 10, 1),
-      // ("streaming", 25, 35, 10, 1),
-      // ("hello", 40, 51, 11, 2),
-      // ("world", 40, 51, 11, 2),
-      // ("streaming", 40, 51, 11, 2),
-      // ("spark", 40, 50, 10, 1),
-      // ("structured", 41, 51, 10, 1)
-      CheckNewAnswer(
-      ),
-      // concatenating multiple previous sessions into one
-      AddData(inputData, ("spark streaming", 30L)),
-      // watermark: 11
-      // current sessions
-      // ("spark", 25, 50, 25, 3),
-      // ("streaming", 25, 51, 26, 4),
-      // ("hello", 40, 51, 11, 2),
-      // ("world", 40, 51, 11, 2),
-      // ("structured", 41, 51, 10, 1)
-      CheckNewAnswer(
-        ("spark", 25, 50, 25, 3),
-        ("streaming", 25, 51, 26, 4)
-      ),
-      // placing new sessions after previous sessions
-      AddData(inputData, ("hello apache spark", 60L)),
-      // watermark: 30
-      // current sessions
-      // ("spark", 25, 50, 25, 3),
-      // ("streaming", 25, 51, 26, 4),
-      // ("hello", 40, 51, 11, 2),
-      // ("world", 40, 51, 11, 2),
-      // ("structured", 41, 51, 10, 1),
-      // ("hello", 60, 70, 10, 1),
-      // ("apache", 60, 70, 10, 1),
-      // ("spark", 60, 70, 10, 1)
-      CheckNewAnswer(
-        ("hello", 60, 70, 10, 1),
-        ("apache", 60, 70, 10, 1),
-        ("spark", 60, 70, 10, 1)
-      ),
-      AddData(inputData, ("structured streaming", 90L)),
-      // watermark: 60
-      // current sessions
-      // ("hello", 60, 70, 10, 1),
-      // ("apache", 60, 70, 10, 1),
-      // ("spark", 60, 70, 10, 1),
-      // ("structured", 90, 100, 10, 1),
-      // ("streaming", 90, 100, 10, 1)
-      // evicted
-      // ("spark", 25, 50, 25, 3),
-      // ("streaming", 25, 51, 26, 4),
-      // ("hello", 40, 51, 11, 2),
-      // ("world", 40, 51, 11, 2),
-      // ("structured", 41, 51, 10, 1)
-      CheckNewAnswer(
-        ("structured", 90, 100, 10, 1),
-        ("streaming", 90, 100, 10, 1)
-      )
-    )
   }
 
-  testWithAllOptions("update mode - session window - no key") {
-    val inputData = MemoryStream[Int]
-    val windowedAggregation = inputData.toDF()
+  private def sessionWindowQueryOnGlobalKey(input: MemoryStream[Int]): DataFrame = {
+    input.toDF()
       .selectExpr("*")
       .withColumn("eventTime", $"value".cast("timestamp"))
       .withWatermark("eventTime", "10 seconds")
@@ -446,15 +329,5 @@ class StreamingSessionWindowSuite extends StreamTest
       .agg(count("*") as 'count, sum("value") as 'sum)
       .select($"session".getField("start").cast("long").as[Long],
         $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long])
-
-    val e = intercept[StreamingQueryException] {
-      testStream(windowedAggregation, OutputMode.Update())(
-        AddData(inputData, 40),
-        CheckAnswer() // this is just to trigger the exception
-      )
-    }
-
-    Seq("Global aggregation with session window", "not supported").foreach { m =>
-      assert(e.getMessage.toLowerCase(Locale.ROOT).contains(m.toLowerCase(Locale.ROOT)))
-    }
   }
 }