From 47485a3c2df3201c838b939e82d5b26332e2d858 Mon Sep 17 00:00:00 2001 From: Rahul Mahadev Date: Fri, 2 Jul 2021 11:53:17 +0800 Subject: [PATCH] [SPARK-35897][SS] Support user defined initial state with flatMapGroupsWithState in Structured Streaming ### What changes were proposed in this pull request? This PR adds support for specifying a user defined initial state for arbitrary structured streaming stateful processing using the [flat]MapGroupsWithState operator. ### Why are the changes needed? Users can load the previous state of their stateful processing as an initial state, instead of redoing the entire processing once again. ### Does this PR introduce _any_ user-facing change? Yes, this PR introduces new APIs: ``` def mapGroupsWithState[S: Encoder, U: Encoder]( timeoutConf: GroupStateTimeout, initialState: KeyValueGroupedDataset[K, S])( func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] def flatMapGroupsWithState[S: Encoder, U: Encoder]( outputMode: OutputMode, timeoutConf: GroupStateTimeout, initialState: KeyValueGroupedDataset[K, S])( func: (K, Iterator[V], GroupState[S]) => Iterator[U]) ``` ### How was this patch tested? Through unit tests in FlatMapGroupsWithStateSuite. Closes #33093 from rahulsmahadev/flatMapGroupsWithState. Authored-by: Rahul Mahadev Signed-off-by: Gengliang Wang --- .../UnsupportedOperationChecker.scala | 12 + .../sql/catalyst/plans/logical/object.scala | 65 +++- .../analysis/UnsupportedOperationsSuite.scala | 116 ++++--- .../spark/sql/KeyValueGroupedDataset.scala | 164 ++++++++++ .../spark/sql/execution/SparkStrategies.scala | 10 +- .../FlatMapGroupsWithStateExec.scala | 266 +++++++++++----- .../streaming/IncrementalExecution.scala | 6 +- .../streaming/statefulOperators.scala | 4 +- .../spark/sql/streaming/GroupState.scala | 5 + .../apache/spark/sql/JavaDatasetSuite.java | 66 ++++ .../FlatMapGroupsWithStateSuite.scala | 283 +++++++++++++++++- 11 files changed, 875 insertions(+), 122 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index a3a85cb120..86293000d6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -37,6 +37,12 @@ object UnsupportedOperationChecker extends Logging { case p if p.isStreaming => throwError("Queries with streaming sources must be executed with writeStream.start()")(p) + case f: FlatMapGroupsWithState => + if (f.hasInitialState) { + throwError("Initial state is not supported in [flatMap|map]GroupsWithState" + + " operation on a batch DataFrame/Dataset")(f) + } + case _ => } } @@ -232,6 +238,12 @@ object UnsupportedOperationChecker extends Logging { // Check compatibility with output modes and aggregations in query val aggsInQuery = collectStreamingAggregates(plan) + if (m.initialState.isStreaming) { + // initial state has to be a batch relation + throwError("Non-streaming DataFrame/Dataset is not supported as the" + + " initial state in [flatMap|map]GroupsWithState operation on a streaming" + + " DataFrame/Dataset") + } if (m.isMapGroupsWithState) { // check mapGroupsWithState // allowed only in update query output mode and without aggregation if (aggsInQuery.nonEmpty) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 1f7eb67bf1..e5fe07e2d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -440,7 +440,7 @@ object FlatMapGroupsWithState { isMapGroupsWithState: Boolean, timeout: GroupStateTimeout, child: LogicalPlan): LogicalPlan = { - val encoder = encoderFor[S] + val stateEncoder = encoderFor[S] val mapped = new FlatMapGroupsWithState( func, @@ -449,10 +449,49 @@ object FlatMapGroupsWithState { groupingAttributes, dataAttributes, CatalystSerde.generateObjAttr[U], - encoder.asInstanceOf[ExpressionEncoder[Any]], + stateEncoder.asInstanceOf[ExpressionEncoder[Any]], outputMode, isMapGroupsWithState, timeout, + hasInitialState = false, + groupingAttributes, + dataAttributes, + UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes), + LocalRelation(stateEncoder.schema.toAttributes), // empty data set + child + ) + CatalystSerde.serialize[U](mapped) + } + + def apply[K: Encoder, V: Encoder, S: Encoder, U: Encoder]( + func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any], + groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], + outputMode: OutputMode, + isMapGroupsWithState: Boolean, + timeout: GroupStateTimeout, + child: LogicalPlan, + initialStateGroupAttrs: Seq[Attribute], + initialStateDataAttrs: Seq[Attribute], + initialState: LogicalPlan): LogicalPlan = { + val stateEncoder = encoderFor[S] + + val mapped = new FlatMapGroupsWithState( + func, + UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes), + UnresolvedDeserializer(encoderFor[V].deserializer, dataAttributes), + groupingAttributes, + dataAttributes, + CatalystSerde.generateObjAttr[U], + stateEncoder.asInstanceOf[ExpressionEncoder[Any]], + outputMode, + isMapGroupsWithState, + timeout, + hasInitialState = true, + initialStateGroupAttrs, + initialStateDataAttrs, + UnresolvedDeserializer(encoderFor[S].deserializer, initialStateDataAttrs), + initialState, child) CatalystSerde.serialize[U](mapped) } @@ -474,6 +513,12 @@ object FlatMapGroupsWithState { * @param outputMode the output mode of `func` * @param isMapGroupsWithState whether it is created by the `mapGroupsWithState` method * @param timeout used to timeout groups that have not received data in a while + * @param hasInitialState Indicates whether initial state needs to be applied or not. + * @param initialStateGroupAttrs grouping attributes for the initial state + * @param initialStateDataAttrs used to read the initial state + * @param initialStateDeserializer used to extract the initial state objects. + * @param initialState user defined initial state that is applied in the first batch. 
+ * @param child logical plan of the underlying data */ case class FlatMapGroupsWithState( func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any], @@ -486,14 +531,24 @@ case class FlatMapGroupsWithState( outputMode: OutputMode, isMapGroupsWithState: Boolean = false, timeout: GroupStateTimeout, - child: LogicalPlan) extends UnaryNode with ObjectProducer { + hasInitialState: Boolean = false, + initialStateGroupAttrs: Seq[Attribute], + initialStateDataAttrs: Seq[Attribute], + initialStateDeserializer: Expression, + initialState: LogicalPlan, + child: LogicalPlan) extends BinaryNode with ObjectProducer { if (isMapGroupsWithState) { assert(outputMode == OutputMode.Update) } - override protected def withNewChildInternal(newChild: LogicalPlan): FlatMapGroupsWithState = - copy(child = newChild) + override def left: LogicalPlan = child + + override def right: LogicalPlan = initialState + + override protected def withNewChildrenInternal( + newLeft: LogicalPlan, newRight: LogicalPlan): FlatMapGroupsWithState = + copy(child = newLeft, initialState = newRight) } /** Factory for constructing new `FlatMapGroupsInR` nodes. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 296d0ee8f4..96262f5afb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -24,13 +24,13 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, MonotonicallyIncreasingID, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, MonotonicallyIncreasingID, NamedExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.Count import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{FlatMapGroupsWithState, _} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} import org.apache.spark.sql.types.{IntegerType, LongType, MetadataBuilder} /** A dummy command for testing unsupported operations. 
*/ @@ -145,15 +145,15 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { for (funcMode <- Seq(Append, Update)) { assertSupportedInBatchPlan( s"flatMapGroupsWithState - flatMapGroupsWithState($funcMode) on batch relation", - FlatMapGroupsWithState( + TestFlatMapGroupsWithState( null, att, att, Seq(att), Seq(att), att, null, funcMode, isMapGroupsWithState = false, null, batchRelation)) assertSupportedInBatchPlan( s"flatMapGroupsWithState - multiple flatMapGroupsWithState($funcMode)s on batch relation", - FlatMapGroupsWithState( + TestFlatMapGroupsWithState( null, att, att, Seq(att), Seq(att), att, null, funcMode, isMapGroupsWithState = false, null, - FlatMapGroupsWithState( + TestFlatMapGroupsWithState( null, att, att, Seq(att), Seq(att), att, null, funcMode, isMapGroupsWithState = false, null, batchRelation))) } @@ -162,7 +162,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { assertSupportedInStreamingPlan( "flatMapGroupsWithState - flatMapGroupsWithState(Update) " + "on streaming relation without aggregation in update mode", - FlatMapGroupsWithState( + TestFlatMapGroupsWithState( null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null, streamRelation), outputMode = Update) @@ -170,7 +170,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { assertNotSupportedInStreamingPlan( "flatMapGroupsWithState - flatMapGroupsWithState(Update) " + "on streaming relation without aggregation in append mode", - FlatMapGroupsWithState( + TestFlatMapGroupsWithState( null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null, streamRelation), outputMode = Append, @@ -179,7 +179,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { assertNotSupportedInStreamingPlan( "flatMapGroupsWithState - flatMapGroupsWithState(Update) " + "on streaming relation without aggregation in complete mode", - FlatMapGroupsWithState( + TestFlatMapGroupsWithState( null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null, streamRelation), outputMode = Complete, @@ -192,7 +192,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { assertNotSupportedInStreamingPlan( "flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation " + s"with aggregation in $outputMode mode", - FlatMapGroupsWithState( + TestFlatMapGroupsWithState( null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null, Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation)), outputMode = outputMode, @@ -203,7 +203,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { assertSupportedInStreamingPlan( "flatMapGroupsWithState - flatMapGroupsWithState(Append) " + "on streaming relation without aggregation in append mode", - FlatMapGroupsWithState( + TestFlatMapGroupsWithState( null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null, streamRelation), outputMode = Append) @@ -211,7 +211,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { assertNotSupportedInStreamingPlan( "flatMapGroupsWithState - flatMapGroupsWithState(Append) " + "on streaming relation without aggregation in update mode", - FlatMapGroupsWithState( + TestFlatMapGroupsWithState( null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null, streamRelation), outputMode = Update, @@ -226,7 +226,7 @@ class 
UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { Aggregate( Seq(attributeWithWatermark), aggExprs("c"), - FlatMapGroupsWithState( + TestFlatMapGroupsWithState( null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null, streamRelation)), outputMode = outputMode, @@ -237,7 +237,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { assertNotSupportedInStreamingPlan( "flatMapGroupsWithState - flatMapGroupsWithState(Append) " + s"on streaming relation after aggregation in $outputMode mode", - FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append, + TestFlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null, Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation)), outputMode = outputMode, @@ -247,7 +247,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { assertNotSupportedInStreamingPlan( "flatMapGroupsWithState - " + "flatMapGroupsWithState(Update) on streaming relation in complete mode", - FlatMapGroupsWithState( + TestFlatMapGroupsWithState( null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null, streamRelation), outputMode = Complete, @@ -261,7 +261,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { assertSupportedInStreamingPlan( s"flatMapGroupsWithState - flatMapGroupsWithState($funcMode) on batch relation inside " + s"streaming relation in $outputMode output mode", - FlatMapGroupsWithState( + TestFlatMapGroupsWithState( null, att, att, Seq(att), Seq(att), att, null, funcMode, isMapGroupsWithState = false, null, batchRelation), outputMode = outputMode @@ -274,9 +274,9 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { assertSupportedInStreamingPlan( "flatMapGroupsWithState - multiple flatMapGroupsWithStates on streaming relation and all are " + "in append mode", - FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append, + TestFlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null, - FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append, + TestFlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null, streamRelation)), outputMode = Append, SQLConf.STATEFUL_OPERATOR_CHECK_CORRECTNESS_ENABLED.key -> "false") @@ -284,9 +284,9 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { assertNotSupportedInStreamingPlan( "flatMapGroupsWithState - multiple flatMapGroupsWithStates on s streaming relation but some" + " are not in append mode", - FlatMapGroupsWithState( + TestFlatMapGroupsWithState( null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null, - FlatMapGroupsWithState( + TestFlatMapGroupsWithState( null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null, streamRelation)), outputMode = Append, @@ -296,7 +296,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { assertNotSupportedInStreamingPlan( "mapGroupsWithState - mapGroupsWithState " + "on streaming relation without aggregation in append mode", - FlatMapGroupsWithState( + TestFlatMapGroupsWithState( null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, null, streamRelation), outputMode = Append, @@ -307,7 +307,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with 
SQLHelper { assertNotSupportedInStreamingPlan( "mapGroupsWithState - mapGroupsWithState " + "on streaming relation without aggregation in complete mode", - FlatMapGroupsWithState( + TestFlatMapGroupsWithState( null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, null, streamRelation), outputMode = Complete, @@ -319,7 +319,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { assertNotSupportedInStreamingPlan( "mapGroupsWithState - mapGroupsWithState on streaming relation " + s"with aggregation in $outputMode mode", - FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Update, + TestFlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, null, Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation)), outputMode = outputMode, @@ -330,9 +330,9 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { assertNotSupportedInStreamingPlan( "mapGroupsWithState - multiple mapGroupsWithStates on streaming relation and all are " + "in append mode", - FlatMapGroupsWithState( + TestFlatMapGroupsWithState( null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, null, - FlatMapGroupsWithState( + TestFlatMapGroupsWithState( null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, null, streamRelation)), outputMode = Append, @@ -342,9 +342,9 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { assertNotSupportedInStreamingPlan( "mapGroupsWithState - " + "mixing mapGroupsWithStates and flatMapGroupsWithStates on streaming relation", - FlatMapGroupsWithState( + TestFlatMapGroupsWithState( null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, null, - FlatMapGroupsWithState( + TestFlatMapGroupsWithState( null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null, streamRelation) ), @@ -354,7 +354,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { // mapGroupsWithState with event time timeout + watermark assertNotSupportedInStreamingPlan( "mapGroupsWithState - mapGroupsWithState with event time timeout without watermark", - FlatMapGroupsWithState( + TestFlatMapGroupsWithState( null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, EventTimeTimeout, streamRelation), outputMode = Update, @@ -362,7 +362,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { assertSupportedInStreamingPlan( "mapGroupsWithState - mapGroupsWithState with event time timeout with watermark", - FlatMapGroupsWithState( + TestFlatMapGroupsWithState( null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, EventTimeTimeout, new TestStreamingRelation(attributeWithWatermark)), outputMode = Update) @@ -532,7 +532,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { testGlobalWatermarkLimit( s"FlatMapGroupsWithState after stream-stream $joinType join in Append mode", - FlatMapGroupsWithState( + TestFlatMapGroupsWithState( null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null, streamRelation.join(streamRelation, joinType = joinType, @@ -664,7 +664,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { assertFailOnGlobalWatermarkLimit( "FlatMapGroupsWithState after streaming aggregation in Append mode", - FlatMapGroupsWithState( + TestFlatMapGroupsWithState( null, att, att, 
Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null, streamRelation.groupBy("a")(count("*"))), @@ -675,14 +675,14 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { { assertPassOnGlobalWatermarkLimit( "single FlatMapGroupsWithState in Append mode", - FlatMapGroupsWithState( + TestFlatMapGroupsWithState( null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null, streamRelation), OutputMode.Append()) assertFailOnGlobalWatermarkLimit( "streaming aggregation after FlatMapGroupsWithState in Append mode", - FlatMapGroupsWithState( + TestFlatMapGroupsWithState( null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null, streamRelation).groupBy("*")(count("*")), OutputMode.Append()) @@ -691,7 +691,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { assertFailOnGlobalWatermarkLimit( s"stream-stream $joinType after FlatMapGroupsWithState in Append mode", streamRelation.join( - FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append, + TestFlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null, streamRelation), joinType = joinType, condition = Some(attributeWithWatermark === attribute)), OutputMode.Append()) @@ -699,16 +699,16 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { assertFailOnGlobalWatermarkLimit( "FlatMapGroupsWithState after FlatMapGroupsWithState in Append mode", - FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append, + TestFlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null, - FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append, + TestFlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null, streamRelation)), OutputMode.Append()) assertFailOnGlobalWatermarkLimit( s"deduplicate after FlatMapGroupsWithState in Append mode", Deduplicate(Seq(attribute), - FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append, + TestFlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null, streamRelation)), OutputMode.Append()) } @@ -730,7 +730,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { assertPassOnGlobalWatermarkLimit( "FlatMapGroupsWithState after deduplicate in Append mode", - FlatMapGroupsWithState( + TestFlatMapGroupsWithState( null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null, Deduplicate(Seq(attribute), streamRelation)), @@ -1015,3 +1015,43 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { override def nodeName: String = "StreamingRelationV2" } } + +object TestFlatMapGroupsWithState { + + // scalastyle:off + // Creating an apply constructor here as we changed the class by adding more fields + def apply(func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any], + keyDeserializer: Expression, + valueDeserializer: Expression, + groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], + outputObjAttr: Attribute, + stateEncoder: ExpressionEncoder[Any], + outputMode: OutputMode, + isMapGroupsWithState: Boolean = false, + timeout: GroupStateTimeout, + child: LogicalPlan): FlatMapGroupsWithState = { + + val attribute = AttributeReference("a", IntegerType, nullable = true)() + val batchRelation = 
LocalRelation(attribute) + new FlatMapGroupsWithState( + func, + keyDeserializer, + valueDeserializer, + groupingAttributes, + dataAttributes, + outputObjAttr, + stateEncoder, + outputMode, + isMapGroupsWithState, + timeout, + false, + groupingAttributes, + dataAttributes, + valueDeserializer, + batchRelation, + child + ) + } + // scalastyle:on +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 76ee297dfc..add692f57d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -280,6 +280,51 @@ class KeyValueGroupedDataset[K, V] private[sql]( child = logicalPlan)) } + /** + * (Scala-specific) + * Applies the given function to each group of data, while maintaining a user-defined per-group + * state. The result Dataset will represent the objects returned by the function. + * For a static batch Dataset, the function will be invoked once per group. For a streaming + * Dataset, the function will be invoked for each group repeatedly in every trigger, and + * updates to each group's state will be saved across invocations. + * See [[org.apache.spark.sql.streaming.GroupState]] for more details. + * + * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U The type of the output objects. Must be encodable to Spark SQL types. + * @param func Function to be called on every group. + * @param timeoutConf Timeout configuration for groups that do not receive data for a while; + * see GroupStateTimeout for more details. + * @param initialState The user provided state that will be initialized when the first batch + * of data is processed in the streaming query. The user defined function + * will be called on the state data even if there are no other values in + * the group. To convert a Dataset ds of type Dataset[(K, S)] to a + * KeyValueGroupedDataset[K, S] + * use {{{ ds.groupByKey(x => x._1).mapValues(_._2) }}} + * + * See [[Encoder]] for more details on what types are encodable to Spark SQL. + * @since 3.2.0 + */ + def mapGroupsWithState[S: Encoder, U: Encoder]( + timeoutConf: GroupStateTimeout, + initialState: KeyValueGroupedDataset[K, S])( + func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = { + val flatMapFunc = (key: K, it: Iterator[V], s: GroupState[S]) => Iterator(func(key, it, s)) + + Dataset[U]( + sparkSession, + FlatMapGroupsWithState[K, V, S, U]( + flatMapFunc.asInstanceOf[(Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any]], + groupingAttributes, + dataAttributes, + OutputMode.Update, + isMapGroupsWithState = true, + timeoutConf, + child = logicalPlan, + initialState.groupingAttributes, + initialState.dataAttributes, + initialState.queryExecution.analyzed + )) + } + /** * (Java-specific) * Applies the given function to each group of data, while maintaining a user-defined per-group @@ -336,6 +381,40 @@ class KeyValueGroupedDataset[K, V] private[sql]( )(stateEncoder, outputEncoder) } + /** + * (Java-specific) + * Applies the given function to each group of data, while maintaining a user-defined per-group + * state. The result Dataset will represent the objects returned by the function. + * For a static batch Dataset, the function will be invoked once per group. For a streaming + * Dataset, the function will be invoked for each group repeatedly in every trigger, and + * updates to each group's state will be saved across invocations. + * See `GroupState` for more details.
+ * + * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U The type of the output objects. Must be encodable to Spark SQL types. + * @param func Function to be called on every group. + * @param stateEncoder Encoder for the state type. + * @param outputEncoder Encoder for the output type. + * @param timeoutConf Timeout configuration for groups that do not receive data for a while. + * @param initialState The user provided state that will be initialized when the first batch + * of data is processed in the streaming query. The user defined function + * will be called on the state data even if there are no other values in + * the group. + * + * See [[Encoder]] for more details on what types are encodable to Spark SQL. + * @since 3.2.0 + */ + def mapGroupsWithState[S, U]( + func: MapGroupsWithStateFunction[K, V, S, U], + stateEncoder: Encoder[S], + outputEncoder: Encoder[U], + timeoutConf: GroupStateTimeout, + initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = { + mapGroupsWithState[S, U](timeoutConf, initialState)( + (key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s) + )(stateEncoder, outputEncoder) + } + /** * (Scala-specific) * Applies the given function to each group of data, while maintaining a user-defined per-group @@ -373,6 +452,53 @@ class KeyValueGroupedDataset[K, V] private[sql]( child = logicalPlan)) } + /** + * (Scala-specific) + * Applies the given function to each group of data, while maintaining a user-defined per-group + * state. The result Dataset will represent the objects returned by the function. + * For a static batch Dataset, the function will be invoked once per group. For a streaming + * Dataset, the function will be invoked for each group repeatedly in every trigger, and + * updates to each group's state will be saved across invocations. + * See `GroupState` for more details. + * + * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U The type of the output objects. Must be encodable to Spark SQL types. + * @param func Function to be called on every group. + * @param outputMode The output mode of the function. + * @param timeoutConf Timeout configuration for groups that do not receive data for a while. + * @param initialState The user provided state that will be initialized when the first batch + * of data is processed in the streaming query. The user defined function + * will be called on the state data even if there are no other values in + * the group. To convert a Dataset `ds` of type `Dataset[(K, S)]` + * to a `KeyValueGroupedDataset[K, S]`, use + * {{{ ds.groupByKey(x => x._1).mapValues(_._2) }}} + * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+ * @since 3.2.0 + */ + def flatMapGroupsWithState[S: Encoder, U: Encoder]( + outputMode: OutputMode, + timeoutConf: GroupStateTimeout, + initialState: KeyValueGroupedDataset[K, S])( + func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U] = { + if (outputMode != OutputMode.Append && outputMode != OutputMode.Update) { + throw new IllegalArgumentException("The output mode of function should be append or update") + } + Dataset[U]( + sparkSession, + FlatMapGroupsWithState[K, V, S, U]( + func.asInstanceOf[(Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any]], + groupingAttributes, + dataAttributes, + outputMode, + isMapGroupsWithState = false, + timeoutConf, + child = logicalPlan, + initialState.groupingAttributes, + initialState.dataAttributes, + initialState.queryExecution.analyzed + )) + } + /** * (Java-specific) * Applies the given function to each group of data, while maintaining a user-defined per-group @@ -403,6 +529,44 @@ class KeyValueGroupedDataset[K, V] private[sql]( flatMapGroupsWithState[S, U](outputMode, timeoutConf)(f)(stateEncoder, outputEncoder) } + /** + * (Java-specific) + * Applies the given function to each group of data, while maintaining a user-defined per-group + * state. The result Dataset will represent the objects returned by the function. + * For a static batch Dataset, the function will be invoked once per group. For a streaming + * Dataset, the function will be invoked for each group repeatedly in every trigger, and + * updates to each group's state will be saved across invocations. + * See `GroupState` for more details. + * + * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U The type of the output objects. Must be encodable to Spark SQL types. + * @param func Function to be called on every group. + * @param outputMode The output mode of the function. + * @param stateEncoder Encoder for the state type. + * @param outputEncoder Encoder for the output type. + * @param timeoutConf Timeout configuration for groups that do not receive data for a while. + * @param initialState The user provided state that will be initialized when the first batch + * of data is processed in the streaming query. The user defined function + * will be called on the state data even if there are no other values in + * the group. To convert a Dataset `ds` of type `Dataset[(K, S)]` + * to a `KeyValueGroupedDataset[K, S]`, use + * {{{ ds.groupByKey(x => x._1).mapValues(_._2) }}} + * + * See [[Encoder]] for more details on what types are encodable to Spark SQL. + * @since 3.2.0 + */ + def flatMapGroupsWithState[S, U]( + func: FlatMapGroupsWithStateFunction[K, V, S, U], + outputMode: OutputMode, + stateEncoder: Encoder[S], + outputEncoder: Encoder[U], + timeoutConf: GroupStateTimeout, + initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = { + val f = (key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s).asScala + flatMapGroupsWithState[S, U]( + outputMode, timeoutConf, initialState)(f)(stateEncoder, outputEncoder) + } + /** * (Scala-specific) * Reduces the elements of each group of data using the specified binary function.
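To make the new KeyValueGroupedDataset methods above concrete, here is a minimal sketch of the Scala API this patch adds. The names `RunningCount`, `inputStream`, and `spark` are illustrative and not part of the patch; it assumes `spark.implicits._` is in scope and `inputStream` is a streaming `Dataset[String]`.

```
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}

case class RunningCount(count: Long)

// Build the initial state as a KeyValueGroupedDataset[K, S], following the
// conversion the scaladoc above suggests: group by the key, then map the
// values down to just the state object.
val initialState = Seq(
  ("apple", RunningCount(1)),
  ("orange", RunningCount(2))
).toDS().groupByKey(_._1).mapValues(_._2)

val counts = inputStream
  .groupByKey(identity)
  .flatMapGroupsWithState(
    OutputMode.Update, GroupStateTimeout.NoTimeout, initialState) {
    (key: String, values: Iterator[String], state: GroupState[RunningCount]) =>
      // On the first batch, "apple" and "orange" start from the supplied
      // counts; all other keys start with no state.
      val newCount = state.getOption.map(_.count).getOrElse(0L) + values.size
      state.update(RunningCount(newCount))
      Iterator((key, newCount))
  }
```

Per the analysis rules added in UnsupportedOperationChecker above, the initial state must be built from a batch (non-streaming) Dataset, and this overload is only supported when the query itself is streaming.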
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 4076229d98..65a592302c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -561,11 +561,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case FlatMapGroupsWithState( func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateEnc, outputMode, _, - timeout, child) => + timeout, hasInitialState, stateGroupAttr, sda, sDeser, initialState, child) => val stateVersion = conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION) val execPlan = FlatMapGroupsWithStateExec( - func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateEnc, stateVersion, - outputMode, timeout, batchTimestampMs = None, eventTimeWatermark = None, planLater(child)) + func, keyDeser, valueDeser, sDeser, groupAttr, stateGroupAttr, dataAttr, sda, outputAttr, + None, stateEnc, stateVersion, outputMode, timeout, batchTimestampMs = None, + eventTimeWatermark = None, planLater(initialState), hasInitialState, planLater(child) + ) execPlan :: Nil case _ => Nil @@ -669,7 +671,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.MapGroups(f, key, value, grouping, data, objAttr, child) => execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil case logical.FlatMapGroupsWithState( - f, key, value, grouping, data, output, _, _, _, timeout, child) => + f, key, value, grouping, data, output, _, _, _, timeout, _, _, _, _, _, child) => execution.MapGroupsExec( f, key, value, grouping, data, output, timeout, planLater(child)) :: Nil case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index fda26b0a50..0e0fbe09f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -25,9 +25,11 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expressi import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution} import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._ import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} -import org.apache.spark.util.CompletionIterator +import org.apache.spark.sql.streaming.GroupStateTimeout.NoTimeout +import org.apache.spark.util.{CompletionIterator, SerializableConfiguration} /** * Physical operator for executing `FlatMapGroupsWithState` @@ -35,6 +37,7 @@ import org.apache.spark.util.CompletionIterator * @param func function called on each group * @param keyDeserializer used to extract the key object for each group. * @param valueDeserializer used to extract the items in the iterator from an input row. 
+ * @param initialStateDeserializer used to extract the state object from the initialState dataset * @param groupingAttributes used to group the data * @param dataAttributes used to read the data * @param outputObjAttr Defines the output object @@ -42,13 +45,20 @@ import org.apache.spark.util.CompletionIterator * @param outputMode the output mode of `func` * @param timeoutConf used to timeout groups that have not received data in a while * @param batchTimestampMs processing timestamp of the current batch. + * @param eventTimeWatermark event time watermark for the current batch + * @param initialState the user specified initial state + * @param hasInitialState indicates whether the initial state is provided or not + * @param child the physical plan for the underlying data */ case class FlatMapGroupsWithStateExec( func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any], keyDeserializer: Expression, valueDeserializer: Expression, + initialStateDeserializer: Expression, groupingAttributes: Seq[Attribute], + initialStateGroupAttrs: Seq[Attribute], dataAttributes: Seq[Attribute], + initialStateDataAttrs: Seq[Attribute], outputObjAttr: Attribute, stateInfo: Option[StatefulOperatorStateInfo], stateEncoder: ExpressionEncoder[Any], @@ -57,27 +67,45 @@ case class FlatMapGroupsWithStateExec( timeoutConf: GroupStateTimeout, batchTimestampMs: Option[Long], eventTimeWatermark: Option[Long], + initialState: SparkPlan, + hasInitialState: Boolean, child: SparkPlan - ) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter with WatermarkSupport { + ) extends BinaryExecNode with ObjectProducerExec with StateStoreWriter with WatermarkSupport { import FlatMapGroupsWithStateExecHelper._ import GroupStateImpl._ + override def left: SparkPlan = child + + override def right: SparkPlan = initialState + private val isTimeoutEnabled = timeoutConf != NoTimeout private val watermarkPresent = child.output.exists { case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true case _ => false } + private[sql] val stateManager = createStateManager(stateEncoder, isTimeoutEnabled, stateFormatVersion) - /** Distribute by grouping attributes */ - override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(groupingAttributes, stateInfo.map(_.numPartitions)) :: Nil + /** + * Distribute by grouping attributes - We need the underlying data and the initial state data + * to have the same grouping so that the data are co-located in the same task. + */ + override def requiredChildDistribution: Seq[Distribution] = { + ClusteredDistribution(groupingAttributes, stateInfo.map(_.numPartitions)) :: + ClusteredDistribution(initialStateGroupAttrs, stateInfo.map(_.numPartitions)) :: + Nil + } - /** Ordering needed for using GroupingIterator */ - override def requiredChildOrdering: Seq[Seq[SortOrder]] = - Seq(groupingAttributes.map(SortOrder(_, Ascending))) + /** + * Ordering needed for using GroupingIterator. + * We need the initial state to also use the same ordering as the data so that we can co-locate + * the keys from the underlying data and the initial state. + */ + override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq( + groupingAttributes.map(SortOrder(_, Ascending)), + initialStateGroupAttrs.map(SortOrder(_, Ascending))) override def keyExpressions: Seq[Attribute] = groupingAttributes @@ -95,6 +123,77 @@ case class FlatMapGroupsWithStateExec( } } + /** + * Process data by applying the user defined function on a per-partition basis.
+ * + * @param iter - Iterator of the data rows + * @param store - associated state store for this partition + * @param processor - handle to the input processor object. + * @param initialStateIterOption - optional initial state iterator + */ + def processDataWithPartition( + iter: Iterator[InternalRow], + store: StateStore, + processor: InputProcessor, + initialStateIterOption: Option[Iterator[InternalRow]] = None + ): CompletionIterator[InternalRow, Iterator[InternalRow]] = { + val allUpdatesTimeMs = longMetric("allUpdatesTimeMs") + val commitTimeMs = longMetric("commitTimeMs") + val timeoutLatencyMs = longMetric("allRemovalsTimeMs") + + val currentTimeNs = System.nanoTime + val updatesStartTimeNs = currentTimeNs + var timeoutProcessingStartTimeNs = currentTimeNs + + // If timeout is based on event time, then filter late data based on watermark + val filteredIter = watermarkPredicateForData match { + case Some(predicate) if timeoutConf == EventTimeTimeout => + applyRemovingRowsOlderThanWatermark(iter, predicate) + case _ => + iter + } + + val processedOutputIterator = initialStateIterOption match { + case Some(initStateIter) if initStateIter.hasNext => + processor.processNewDataWithInitialState(filteredIter, initStateIter) + case _ => processor.processNewData(filteredIter) + } + + val newDataProcessorIter = + CompletionIterator[InternalRow, Iterator[InternalRow]]( + processedOutputIterator, { + // Once the input is processed, mark the start time for timeout processing to measure + // it separately from the overall processing time. + timeoutProcessingStartTimeNs = System.nanoTime + }) + + val timeoutProcessorIter = + CompletionIterator[InternalRow, Iterator[InternalRow]](processor.processTimedOutState(), { + // Note: `timeoutLatencyMs` also includes the time the parent operator took for + // processing output returned through iterator. + timeoutLatencyMs += NANOSECONDS.toMillis(System.nanoTime - timeoutProcessingStartTimeNs) + }) + + // Generate an iterator that returns the rows grouped by the grouping function + // Note that this code ensures that the filtering for timeout occurs only after + // all the data has been processed. This is to ensure that the timeout information of all + // the keys with data is updated before they are processed for timeouts. + val outputIterator = newDataProcessorIter ++ timeoutProcessorIter + + // Return an iterator of all the rows generated by all the keys, such that when fully + // consumed, all the state updates will be committed by the state store + CompletionIterator[InternalRow, Iterator[InternalRow]](outputIterator, { + // Note: Due to the iterator lazy execution, this metric also captures the time taken + // by the upstream (consumer) operators in addition to the processing in this operator.
+ allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) + commitTimeMs += timeTakenMs { + store.commit() + } + setStoreMetrics(store) + setOperatorMetrics() + }) + } + override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver @@ -103,69 +202,50 @@ case class FlatMapGroupsWithStateExec( case ProcessingTimeTimeout => require(batchTimestampMs.nonEmpty) case EventTimeTimeout => - require(eventTimeWatermark.nonEmpty) // watermark value has been populated + require(eventTimeWatermark.nonEmpty) // watermark value has been populated require(watermarkExpression.nonEmpty) // input schema has watermark attribute case _ => } - child.execute().mapPartitionsWithStateStore[InternalRow]( - getStateInfo, - groupingAttributes.toStructType, - stateManager.stateSchema, - indexOrdinal = None, - session.sessionState, - Some(session.streams.stateStoreCoordinator)) { case (store, iter) => - val allUpdatesTimeMs = longMetric("allUpdatesTimeMs") - val commitTimeMs = longMetric("commitTimeMs") - val timeoutLatencyMs = longMetric("allRemovalsTimeMs") + if (hasInitialState) { + // If the user provided an initial state, we need the initial state and the + // data to be in the same partition so that we can still have just one commit at the end. + val storeConf = new StateStoreConf(session.sqlContext.sessionState.conf) + val hadoopConfBroadcast = sparkContext.broadcast( + new SerializableConfiguration(session.sqlContext.sessionState.newHadoopConf())) + child.execute().stateStoreAwareZipPartitions( + initialState.execute(), + getStateInfo, + storeNames = Seq(), + session.sqlContext.streams.stateStoreCoordinator) { + // The state store aware zip partitions will provide us with two iterators, + // the child data iterator and the initial state iterator per partition. + case (partitionId, childDataIterator, initStateIterator) => + + val stateStoreId = StateStoreId( + stateInfo.get.checkpointLocation, stateInfo.get.operatorId, partitionId) + val storeProviderId = StateStoreProviderId(stateStoreId, stateInfo.get.queryRunId) + val store = StateStore.get( + storeProviderId, + groupingAttributes.toStructType, + stateManager.stateSchema, + indexOrdinal = None, + stateInfo.get.storeVersion, storeConf, hadoopConfBroadcast.value.value) + val processor = new InputProcessor(store) + processDataWithPartition(childDataIterator, store, processor, Some(initStateIterator)) + } + } else { + child.execute().mapPartitionsWithStateStore[InternalRow]( + getStateInfo, + groupingAttributes.toStructType, + stateManager.stateSchema, + indexOrdinal = None, + session.sqlContext.sessionState, + Some(session.sqlContext.streams.stateStoreCoordinator) + ) { case (store: StateStore, singleIterator: Iterator[InternalRow]) => val processor = new InputProcessor(store) - - val currentTimeNs = System.nanoTime - val updatesStartTimeNs = currentTimeNs - var timeoutProcessingStartTimeNs = currentTimeNs - - // If timeout is based on event time, then filter late data based on watermark - val filteredIter = watermarkPredicateForData match { - case Some(predicate) if timeoutConf == EventTimeTimeout => - applyRemovingRowsOlderThanWatermark(iter, predicate) - case _ => - iter - } - - val newDataProcessorIter = - CompletionIterator[InternalRow, Iterator[InternalRow]]( - processor.processNewData(filteredIter), { - // Once the input is processed, mark the start time for timeout processing to measure - // it separately from the overall processing time.
- timeoutProcessingStartTimeNs = System.nanoTime - }) - - val timeoutProcessorIter = - CompletionIterator[InternalRow, Iterator[InternalRow]](processor.processTimedOutState(), { - // Note: `timeoutLatencyMs` also includes the time the parent operator took for - // processing output returned through iterator. - timeoutLatencyMs += NANOSECONDS.toMillis(System.nanoTime - timeoutProcessingStartTimeNs) - }) - - // Generate a iterator that returns the rows grouped by the grouping function - // Note that this code ensures that the filtering for timeout occurs only after - // all the data has been processed. This is to ensure that the timeout information of all - // the keys with data is updated before they are processed for timeouts. - val outputIterator = newDataProcessorIter ++ timeoutProcessorIter - - // Return an iterator of all the rows generated by all the keys, such that when fully - // consumed, all the state updates will be committed by the state store - CompletionIterator[InternalRow, Iterator[InternalRow]](outputIterator, { - // Note: Due to the iterator lazy execution, this metric also captures the time taken - // by the upstream (consumer) operators in addition to the processing in this operator. - allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) - commitTimeMs += timeTakenMs { - store.commit() - } - setStoreMetrics(store) - setOperatorMetrics() - } - ) + processDataWithPartition(singleIterator, store, processor) + } } } @@ -178,6 +258,11 @@ case class FlatMapGroupsWithStateExec( private val getValueObj = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjectType) + private val getStateObj = if (hasInitialState) { + Some(ObjectOperator.deserializeRowToObject(initialStateDeserializer, initialStateDataAttrs)) + } else { + None + } // Metrics private val numUpdatedStateRows = longMetric("numUpdatedStateRows") @@ -199,6 +284,49 @@ case class FlatMapGroupsWithStateExec( } } + /** + * Process the new data iterator along with the initial state. The initial state is applied + * before processing the new data for every key. The user defined function is called only + * once for every key that has either initial state or data or both. + */ + def processNewDataWithInitialState( + childDataIter: Iterator[InternalRow], + initStateIter: Iterator[InternalRow] + ): Iterator[InternalRow] = { + + if (!childDataIter.hasNext && !initStateIter.hasNext) return Iterator.empty + + // Create iterators for the child data and the initial state grouped by their grouping + // attributes. + val groupedChildDataIter = GroupedIterator(childDataIter, groupingAttributes, child.output) + val groupedInitialStateIter = + GroupedIterator(initStateIter, initialStateGroupAttrs, initialState.output) + + // Create a CoGroupedIterator that will group the two iterators together for every key group. + new CoGroupedIterator( + groupedChildDataIter, groupedInitialStateIter, groupingAttributes).flatMap { + case (keyRow, valueRowIter, initialStateRowIter) => + val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow] + var foundInitialStateForKey = false + initialStateRowIter.foreach { initialStateRow => + if (foundInitialStateForKey) { + throw new IllegalArgumentException("The initial state provided contained " + + "multiple rows(state) with the same key. 
Make sure to de-duplicate the " + + "initial state before passing it.") + } + foundInitialStateForKey = true + val initStateObj = getStateObj.get(initialStateRow) + stateManager.putState(store, keyUnsafeRow, initStateObj, NO_TIMESTAMP) + } + // We apply the values for the key after applying the initial state. + callFunctionAndUpdateState( + stateManager.getState(store, keyUnsafeRow), + valueRowIter, + hasTimedOut = false + ) + } + } + /** Find the groups that have timeout set and are timing out right now, and call the function */ def processTimedOutState(): Iterator[InternalRow] = { if (isTimeoutEnabled) { @@ -271,6 +399,8 @@ case class FlatMapGroupsWithStateExec( } } - override protected def withNewChildInternal(newChild: SparkPlan): FlatMapGroupsWithStateExec = - copy(child = newChild) + override protected def withNewChildrenInternal( + newLeft: SparkPlan, newRight: SparkPlan): FlatMapGroupsWithStateExec = + copy(child = newLeft, initialState = newRight) } + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 1a3e3a7729..e98996b8e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -157,10 +157,14 @@ class IncrementalExecution( Some(offsetSeqMetadata.batchWatermarkMs)) case m: FlatMapGroupsWithStateExec => + // We set this to true only for the first batch of the streaming query. + val hasInitialState = (currentBatchId == 0L && m.hasInitialState) m.copy( stateInfo = Some(nextStatefulOperationStateInfo), batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs), - eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs)) + eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs), + hasInitialState = hasInitialState + ) case j: StreamingSymmetricHashJoinExec => j.copy( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 53650092ed..22feff3710 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -203,7 +203,9 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan => } /** An operator that supports watermark. */ -trait WatermarkSupport extends UnaryExecNode { +trait WatermarkSupport extends SparkPlan { + + def child: SparkPlan /** The keys that may have a watermark attribute. */ def keyExpressions: Seq[Attribute] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala index f50727307d..71c6aaea8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala @@ -111,6 +111,11 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState * when the group has new data, or the group has timed out. So the user has to set the timeout * duration every time the function is called, otherwise, there will not be any timeout set. * + * `[map/flatMap]GroupsWithState` can take a user defined initial state as an additional argument. 
+ * This state will be applied when the first batch of the streaming query is processed. If there + * are no matching rows in the data for the keys present in the initial state, the state is still + * applied and the function will be invoked with the values being an empty iterator. + * * Scala example of using GroupState in `mapGroupsWithState`: * {{{ * // A mapping function that maintains an integer state for string keys and returns a string. diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 5e48dc653d..0500c5217b 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -46,6 +46,7 @@ import org.apache.spark.sql.catalyst.expressions.GenericRow; import org.apache.spark.sql.test.TestSparkSession; import org.apache.spark.sql.types.StructType; import org.apache.spark.util.LongAccumulator; + import static org.apache.spark.sql.functions.col; import static org.apache.spark.sql.functions.expr; import static org.apache.spark.sql.types.DataTypes.*; @@ -160,6 +161,71 @@ public class JavaDatasetSuite implements Serializable { Assert.assertEquals(6, reduced); } + @Test + public void testInitialStateFlatMapGroupsWithState() { + List<String> data = Arrays.asList("a", "foo", "bar"); + Dataset<String> ds = spark.createDataset(data, Encoders.STRING()); + Dataset<Tuple2<Integer, Long>> initialStateDS = spark.createDataset( + Arrays.asList(new Tuple2<>(2, 2L)), + Encoders.tuple(Encoders.INT(), Encoders.LONG()) + ); + + KeyValueGroupedDataset<Integer, Tuple2<Integer, Long>> kvInitStateDS = + initialStateDS.groupByKey( + (MapFunction<Tuple2<Integer, Long>, Integer>) f -> f._1, Encoders.INT()); + + KeyValueGroupedDataset<Integer, Long> kvInitStateMappedDS = kvInitStateDS.mapValues( + (MapFunction<Tuple2<Integer, Long>, Long>) f -> f._2, + Encoders.LONG() + ); + + KeyValueGroupedDataset<Integer, String> grouped = + ds.groupByKey((MapFunction<String, Integer>) String::length, Encoders.INT()); + + Dataset<String> flatMapped2 = grouped.flatMapGroupsWithState( + (FlatMapGroupsWithStateFunction<Integer, String, Long, String>) (key, values, s) -> { + StringBuilder sb = new StringBuilder(key.toString()); + while (values.hasNext()) { + sb.append(values.next()); + } + return Collections.singletonList(sb.toString()).iterator(); + }, + OutputMode.Append(), + Encoders.LONG(), + Encoders.STRING(), + GroupStateTimeout.NoTimeout(), + kvInitStateMappedDS); + + Assert.assertThrows( + "Initial state is not supported in [flatMap|map]GroupsWithState " + + "operation on a batch DataFrame/Dataset", + AnalysisException.class, + () -> { + flatMapped2.collectAsList(); + } + ); + Dataset<String> mapped2 = grouped.mapGroupsWithState( + (MapGroupsWithStateFunction<Integer, String, Long, String>) (key, values, s) -> { + StringBuilder sb = new StringBuilder(key.toString()); + while (values.hasNext()) { + sb.append(values.next()); + } + return sb.toString(); + }, + Encoders.LONG(), + Encoders.STRING(), + GroupStateTimeout.NoTimeout(), + kvInitStateMappedDS); + Assert.assertThrows( + "Initial state is not supported in [flatMap|map]GroupsWithState " + + "operation on a batch DataFrame/Dataset", + AnalysisException.class, + () -> { + mapped2.collectAsList(); + } + ); + } + @Test + public void testIllegalTestGroupStateCreations() { // SPARK-35800: test code throws upon illegal TestGroupState create() calls diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 171d330c20..152dd167fa 100644 ---
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -26,7 +26,7 @@ import org.scalatest.exceptions.TestFailedException import org.apache.spark.SparkException import org.apache.spark.api.java.Optional import org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction -import org.apache.spark.sql.{DataFrame, Encoder} +import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, Encoder, KeyValueGroupedDataset} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsWithState @@ -1042,7 +1042,6 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { ) } - test("mapGroupsWithState - streaming") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) @@ -1268,12 +1267,281 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { assert(e.getMessage === "The output mode of function should be append or update") } + import testImplicits._ + + /** + * FlatMapGroupsWithState function that returns the key and value as passed to it + * along with the updated state. The state is incremented for every value. + */ + val flatMapGroupsWithStateFunc = + (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { + val valList = values.toSeq + if (valList.isEmpty) { + // When the function is called on just the initial state make sure the other fields + // are set correctly + assert(state.exists) + } + assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } + assertCannotGetWatermark { state.getCurrentWatermarkMs() } + assert(!state.hasTimedOut) + val count = state.getOption.map(_.count).getOrElse(0L) + valList.size + // We need to check whether not explicitly calling update will still save the initial state + if (!key.contains("NoUpdate")) { + // this is not reached when valList is empty and the state count is 2 + state.update(new RunningCount(count)) + } + Iterator((key, valList, count.toString)) + } + + Seq("1", "2", "6").foreach { shufflePartitions => + testWithAllStateVersions(s"flatMapGroupsWithState - initial " + + s"state - all cases - shuffle partitions ${shufflePartitions}") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> shufflePartitions) { + // We will test on different shuffle partition configurations to make sure the + // grouping by key will still work. With a higher number of shuffle partitions it's + // possible that all keys end up on different partitions.
+        val initialState: Dataset[(String, RunningCount)] = Seq(
+          ("keyInStateAndData-1", new RunningCount(1)),
+          ("keyInStateAndData-2", new RunningCount(2)),
+          ("keyNoUpdate", new RunningCount(2)), // state.update will not be called
+          ("keyOnlyInState-1", new RunningCount(1))
+        ).toDS()
+
+        val it = initialState.groupByKey(x => x._1).mapValues(_._2)
+        val inputData = MemoryStream[String]
+        val result =
+          inputData.toDS()
+            .groupByKey(x => x)
+            .flatMapGroupsWithState(
+              Update, GroupStateTimeout.NoTimeout, it)(flatMapGroupsWithStateFunc)
+
+        testStream(result, Update)(
+          AddData(inputData, "keyOnlyInData", "keyInStateAndData-2"),
+          CheckNewAnswer(
+            ("keyOnlyInState-1", Seq[String](), "1"),
+            ("keyNoUpdate", Seq[String](), "2"), // update will not be called
+            ("keyInStateAndData-2", Seq[String]("keyInStateAndData-2"), "3"), // incremented by 1
+            ("keyInStateAndData-1", Seq[String](), "1"),
+            ("keyOnlyInData", Seq[String]("keyOnlyInData"), "1") // incremented by 1
+          ),
+          assertNumStateRows(total = 5, updated = 4),
+          // Stop and start the stream to make sure the initial state doesn't get applied again.
+          StopStream,
+          StartStream(),
+          AddData(inputData, "keyInStateAndData-1"),
+          CheckNewAnswer(
+            // state incremented by 1
+            ("keyInStateAndData-1", Seq[String]("keyInStateAndData-1"), "2")
+          ),
+          assertNumStateRows(total = 5, updated = 1),
+          StopStream
+        )
+      }
+    }
+  }
+
+  testWithAllStateVersions("flatMapGroupsWithState - initial state - case class key") {
+    val stateFunc = (key: User, values: Iterator[User], state: GroupState[Long]) => {
+      val valList = values.toSeq
+      if (valList.isEmpty) {
+        // When the function is called on just the initial state, make sure the other
+        // fields are set correctly.
+        assert(state.exists)
+      }
+      assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 }
+      assertCannotGetWatermark { state.getCurrentWatermarkMs() }
+      assert(!state.hasTimedOut)
+      val count = state.getOption.getOrElse(0L) + valList.size
+      // Check whether the state is still saved even when update() is not called explicitly.
+      if (!key.name.contains("NoUpdate")) {
+        // This is not reached when valList is empty and the state count is 2.
+        state.update(count)
+      }
+      Iterator((key, valList.map(_.name), count.toString))
+    }
+
+    val ds = Seq(
+      (User("keyInStateAndData", "1"), 1L),
+      (User("keyOnlyInState", "1"), 1L),
+      (User("keyNoUpdate", "2"), 2L) // state.update will not be called on this in the function
+    ).toDS().groupByKey(_._1).mapValues(_._2)
+
+    val inputData = MemoryStream[User]
+    val result =
+      inputData.toDS()
+        .groupByKey(x => x)
+        .flatMapGroupsWithState(Update, NoTimeout(), ds)(stateFunc)
+
+    testStream(result, Update)(
+      AddData(inputData, User("keyInStateAndData", "1"), User("keyOnlyInData", "1")),
+      CheckNewAnswer(
+        (("keyInStateAndData", "1"), Seq[String]("keyInStateAndData"), "2"),
+        (("keyOnlyInState", "1"), Seq[String](), "1"),
+        (("keyNoUpdate", "2"), Seq[String](), "2"),
+        (("keyOnlyInData", "1"), Seq[String]("keyOnlyInData"), "1")
+      ),
+      assertNumStateRows(total = 4, updated = 3), // (keyNoUpdate, 2) does not call update()
+      // Stop and start the stream to make sure the initial state doesn't get applied again.
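+      // (After the restart, the checkpointed state store is the source of truth; the
+      // initial state Dataset is only consulted for the very first batch of the query.)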
+      StopStream,
+      StartStream(),
+      AddData(inputData, User("keyOnlyInData", "1")),
+      CheckNewAnswer(
+        (("keyOnlyInData", "1"), Seq[String]("keyOnlyInData"), "2")
+      ),
+      assertNumStateRows(total = 4, updated = 1),
+      StopStream
+    )
+  }
+
+  testQuietly("flatMapGroupsWithState - initial state - duplicate keys") {
+    val initialState = Seq(
+      ("a", new RunningCount(2)),
+      ("a", new RunningCount(1))
+    ).toDS().groupByKey(_._1).mapValues(_._2)
+
+    val inputData = MemoryStream[String]
+    val result =
+      inputData.toDS()
+        .groupByKey(x => x)
+        .flatMapGroupsWithState(Update, NoTimeout(), initialState)(flatMapGroupsWithStateFunc)
+    testStream(result, Update)(
+      AddData(inputData, "a"),
+      ExpectFailure[SparkException] { e =>
+        assert(e.getCause.getMessage.contains("The initial state provided contained " +
+          "multiple rows(state) with the same key"))
+      }
+    )
+  }
+
+  testQuietly("flatMapGroupsWithState - initial state - streaming initial state") {
+    val initialStateData = MemoryStream[(String, RunningCount)]
+    initialStateData.addData(("a", new RunningCount(1)))
+
+    val inputData = MemoryStream[String]
+
+    val result =
+      inputData.toDS()
+        .groupByKey(x => x)
+        .flatMapGroupsWithState(
+          Update, NoTimeout(), initialStateData.toDS().groupByKey(_._1).mapValues(_._2)
+        )(flatMapGroupsWithStateFunc)
+
+    val e = intercept[AnalysisException] {
+      result.writeStream
+        .format("console")
+        .start()
+    }
+
+    val expectedError = "Non-streaming DataFrame/Dataset is not supported" +
+      " as the initial state in [flatMap|map]GroupsWithState" +
+      " operation on a streaming DataFrame/Dataset"
+    assert(e.message.contains(expectedError))
+  }
+
+  test("flatMapGroupsWithState - initial state - initial state has flatMapGroupsWithState") {
+    val initialStateDS = Seq(("keyInStateAndData", new RunningCount(1))).toDS()
+    val initialState: KeyValueGroupedDataset[String, RunningCount] =
+      initialStateDS.groupByKey(_._1).mapValues(_._2)
+        .mapGroupsWithState(
+          GroupStateTimeout.NoTimeout())(
+          (key: String, values: Iterator[RunningCount], state: GroupState[Boolean]) => {
+            (key, values.next())
+          }
+        ).groupByKey(_._1).mapValues(_._2)
+
+    val inputData = MemoryStream[String]
+
+    val result =
+      inputData.toDS()
+        .groupByKey(x => x)
+        .flatMapGroupsWithState(
+          Update, NoTimeout(), initialState
+        )(flatMapGroupsWithStateFunc)
+
+    testStream(result, Update)(
+      AddData(inputData, "keyInStateAndData"),
+      CheckNewAnswer(
+        ("keyInStateAndData", Seq[String]("keyInStateAndData"), "2")
+      ),
+      StopStream
+    )
+  }
+
+  testWithAllStateVersions("mapGroupsWithState - initial state - null key") {
+    val mapGroupsWithStateFunc =
+      (key: String, values: Iterator[String], state: GroupState[RunningCount]) => {
+        val valList = values.toList
+        val count = state.getOption.map(_.count).getOrElse(0L) + valList.size
+        state.update(new RunningCount(count))
+        (key, state.get.count.toString)
+      }
+    val initialState = Seq(
+      ("key", new RunningCount(5)),
+      (null, new RunningCount(2))
+    ).toDS().groupByKey(_._1).mapValues(_._2)
+
+    val inputData = MemoryStream[String]
+    val result =
+      inputData.toDS()
+        .groupByKey(x => x)
+        .mapGroupsWithState(NoTimeout(), initialState)(mapGroupsWithStateFunc)
+    testStream(result, Update)(
+      AddData(inputData, "key", null),
+      CheckNewAnswer(
+        ("key", "6"), // state incremented by 1
+        (null, "3") // incremented by 1
+      ),
+      assertNumStateRows(total = 2, updated = 2),
+      StopStream
+    )
+  }
+
+  testWithAllStateVersions("flatMapGroupsWithState - initial state - processing time timeout") {
+    // The function returns -1 on timeout, otherwise the current count of the state.
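+    // Note that "c" exists only in the initial state and never receives any data: the
+    // first invocation (with an empty iterator) still registers a timeout for it, so it
+    // expires exactly like a key whose state was created from data.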
+    val stateFunc =
+      (key: String, values: Iterator[(String, Long)], state: GroupState[RunningCount]) => {
+        if (state.hasTimedOut) {
+          state.remove()
+          Iterator((key, "-1"))
+        } else {
+          val count = state.getOption.map(_.count).getOrElse(0L) + values.size
+          state.update(RunningCount(count))
+          state.setTimeoutDuration("10 seconds")
+          Iterator((key, count.toString))
+        }
+      }
+
+    val clock = new StreamManualClock
+    val inputData = MemoryStream[(String, Long)]
+    val initialState = Seq(
+      ("c", new RunningCount(2))
+    ).toDS().groupByKey(_._1).mapValues(_._2)
+    val result =
+      inputData.toDF().toDF("key", "time")
+        .selectExpr("key", "timestamp_seconds(time) as timestamp")
+        .withWatermark("timestamp", "10 second")
+        .as[(String, Long)]
+        .groupByKey(x => x._1)
+        .flatMapGroupsWithState(Update, ProcessingTimeTimeout(), initialState)(stateFunc)
+
+    testStream(result, Update)(
+      StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock),
+      AddData(inputData, ("a", 1L)),
+      AdvanceManualClock(1 * 1000), // a and c are processed here for the first time
+      CheckNewAnswer(("a", "1"), ("c", "2")),
+      AdvanceManualClock(10 * 1000),
+      AddData(inputData, ("b", 1L)), // this will cause a and c to time out
+      AdvanceManualClock(1 * 1000),
+      CheckNewAnswer(("a", "-1"), ("b", "1"), ("c", "-1"))
+    )
+  }
+
   def testWithTimeout(timeoutConf: GroupStateTimeout): Unit = {
     test("SPARK-20714: watermark does not fail query when timeout = " + timeoutConf) {
       // Function to maintain running count up to 2, and then remove the count
       // Returns the data and the count (-1 if count reached beyond 2 and state was just removed)
       val stateFunc =
-        (key: String, values: Iterator[(String, Long)], state: GroupState[RunningCount]) => {
+          (key: String, values: Iterator[(String, Long)], state: GroupState[RunningCount]) => {
           if (state.hasTimedOut) {
             state.remove()
             Iterator((key, "-1"))
@@ -1411,10 +1679,13 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
         .groupByKey(x => x)
         .flatMapGroupsWithState[Int, Int](Append, timeoutConf = timeoutType)(func)
         .logicalPlan.collectFirst {
-          case FlatMapGroupsWithState(f, k, v, g, d, o, s, m, _, t, _) =>
+          case FlatMapGroupsWithState(f, k, v, g, d, o, s, m, _, t,
+            hasInitialState, sga, sda, se, i, c) =>
            FlatMapGroupsWithStateExec(
-             f, k, v, g, d, o, None, s, stateFormatVersion, m, t,
+             f, k, v, se, g, sga, d, sda, o, None, s, stateFormatVersion, m, t,
              Some(currentBatchTimestamp), Some(currentBatchWatermark),
+             RDDScanExec(g, emptyRdd, "rdd"),
+             hasInitialState,
              RDDScanExec(g, emptyRdd, "rdd"))
         }.get
@@ -1461,6 +1732,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
   }
 }
 
+case class User(name: String, id: String)
+
 object FlatMapGroupsWithStateSuite {
 
   var failInTask = true