[SPARK-35897][SS] Support user defined initial state with flatMapGroupsWithState in Structured Streaming

### What changes were proposed in this pull request?
This PR adds support for specifying a user-defined initial state for arbitrary stateful processing in Structured Streaming using the [flat]MapGroupsWithState operator.

### Why are the changes needed?
Users can load the previous state of their stateful processing as an initial state, instead of recomputing it from scratch.

### Does this PR introduce _any_ user-facing change?

Yes, this PR introduces new APIs (Scala signatures shown below; matching Java overloads that take explicit encoders are also added):
```
  def mapGroupsWithState[S: Encoder, U: Encoder](
      timeoutConf: GroupStateTimeout,
      initialState: KeyValueGroupedDataset[K, S])(
      func: (K, Iterator[V], GroupState[S]) => U): Dataset[U]

  def flatMapGroupsWithState[S: Encoder, U: Encoder](
      outputMode: OutputMode,
      timeoutConf: GroupStateTimeout,
      initialState: KeyValueGroupedDataset[K, S])(
      func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U]

```
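
For example, a minimal sketch of the Scala API (names such as `inputData` and `RunningCount` are illustrative; it assumes `spark.implicits._` is in scope and that `inputData` is a streaming source such as a `MemoryStream[String]`, mirroring the pattern used in `FlatMapGroupsWithStateSuite`):
```
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}

case class RunningCount(count: Long)

// Seed state: convert a Dataset[(K, S)] into a KeyValueGroupedDataset[K, S].
val initialState = Seq(("a", RunningCount(1)), ("b", RunningCount(2)))
  .toDS()
  .groupByKey(_._1)
  .mapValues(_._2)

val result = inputData.toDS()
  .groupByKey(x => x)
  .flatMapGroupsWithState(
    OutputMode.Update(), GroupStateTimeout.NoTimeout(), initialState) {
    (key: String, values: Iterator[String], state: GroupState[RunningCount]) =>
      // Keys present only in the initial state still reach the function,
      // with `values` empty.
      val count = state.getOption.map(_.count).getOrElse(0L) + values.size
      state.update(RunningCount(count))
      Iterator((key, count))
  }
```
The initial state only seeds the first batch of the query; on restart from a checkpoint it is not applied again (see the `currentBatchId == 0L` check in `IncrementalExecution` below).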

### How was this patch tested?

Through unit tests in `FlatMapGroupsWithStateSuite` and new Java tests in `JavaDatasetSuite`.

Closes #33093 from rahulsmahadev/flatMapGroupsWithState.

Authored-by: Rahul Mahadev <rahul.mahadev@databricks.com>
Signed-off-by: Gengliang Wang <gengliang@apache.org>


@ -37,6 +37,12 @@ object UnsupportedOperationChecker extends Logging {
case p if p.isStreaming =>
throwError("Queries with streaming sources must be executed with writeStream.start()")(p)
case f: FlatMapGroupsWithState =>
if (f.hasInitialState) {
throwError("Initial state is not supported in [flatMap|map]GroupsWithState" +
" operation on a batch DataFrame/Dataset")(f)
}
case _ =>
}
}
@ -232,6 +238,12 @@ object UnsupportedOperationChecker extends Logging {
// Check compatibility with output modes and aggregations in query
val aggsInQuery = collectStreamingAggregates(plan)
if (m.initialState.isStreaming) {
// initial state has to be a batch relation
throwError("Non-streaming DataFrame/Dataset is not supported as the" +
" initial state in [flatMap|map]GroupsWithState operation on a streaming" +
" DataFrame/Dataset")
}
if (m.isMapGroupsWithState) { // check mapGroupsWithState
// allowed only in update query output mode and without aggregation
if (aggsInQuery.nonEmpty) {


@ -440,7 +440,7 @@ object FlatMapGroupsWithState {
isMapGroupsWithState: Boolean,
timeout: GroupStateTimeout,
child: LogicalPlan): LogicalPlan = {
val encoder = encoderFor[S]
val stateEncoder = encoderFor[S]
val mapped = new FlatMapGroupsWithState(
func,
@ -449,10 +449,49 @@ object FlatMapGroupsWithState {
groupingAttributes,
dataAttributes,
CatalystSerde.generateObjAttr[U],
encoder.asInstanceOf[ExpressionEncoder[Any]],
stateEncoder.asInstanceOf[ExpressionEncoder[Any]],
outputMode,
isMapGroupsWithState,
timeout,
hasInitialState = false,
groupingAttributes,
dataAttributes,
UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes),
LocalRelation(stateEncoder.schema.toAttributes), // empty data set
child
)
CatalystSerde.serialize[U](mapped)
}
def apply[K: Encoder, V: Encoder, S: Encoder, U: Encoder](
func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any],
groupingAttributes: Seq[Attribute],
dataAttributes: Seq[Attribute],
outputMode: OutputMode,
isMapGroupsWithState: Boolean,
timeout: GroupStateTimeout,
child: LogicalPlan,
initialStateGroupAttrs: Seq[Attribute],
initialStateDataAttrs: Seq[Attribute],
initialState: LogicalPlan): LogicalPlan = {
val stateEncoder = encoderFor[S]
val mapped = new FlatMapGroupsWithState(
func,
UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes),
UnresolvedDeserializer(encoderFor[V].deserializer, dataAttributes),
groupingAttributes,
dataAttributes,
CatalystSerde.generateObjAttr[U],
stateEncoder.asInstanceOf[ExpressionEncoder[Any]],
outputMode,
isMapGroupsWithState,
timeout,
hasInitialState = true,
initialStateGroupAttrs,
initialStateDataAttrs,
UnresolvedDeserializer(encoderFor[S].deserializer, initialStateDataAttrs),
initialState,
child)
CatalystSerde.serialize[U](mapped)
}
@ -474,6 +513,12 @@ object FlatMapGroupsWithState {
* @param outputMode the output mode of `func`
* @param isMapGroupsWithState whether it is created by the `mapGroupsWithState` method
* @param timeout used to timeout groups that have not received data in a while
* @param hasInitialState Indicates whether initial state needs to be applied or not.
* @param initialStateGroupAttrs grouping attributes for the initial state
* @param initialStateDataAttrs used to read the initial state
* @param initialStateDeserializer used to extract the initial state objects.
* @param initialState user defined initial state that is applied in the first batch.
* @param child logical plan of the underlying data
*/
case class FlatMapGroupsWithState(
func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any],
@ -486,14 +531,24 @@ case class FlatMapGroupsWithState(
outputMode: OutputMode,
isMapGroupsWithState: Boolean = false,
timeout: GroupStateTimeout,
child: LogicalPlan) extends UnaryNode with ObjectProducer {
hasInitialState: Boolean = false,
initialStateGroupAttrs: Seq[Attribute],
initialStateDataAttrs: Seq[Attribute],
initialStateDeserializer: Expression,
initialState: LogicalPlan,
child: LogicalPlan) extends BinaryNode with ObjectProducer {
if (isMapGroupsWithState) {
assert(outputMode == OutputMode.Update)
}
override protected def withNewChildInternal(newChild: LogicalPlan): FlatMapGroupsWithState =
copy(child = newChild)
override def left: LogicalPlan = child
override def right: LogicalPlan = initialState
override protected def withNewChildrenInternal(
newLeft: LogicalPlan, newRight: LogicalPlan): FlatMapGroupsWithState =
copy(child = newLeft, initialState = newRight)
}
/** Factory for constructing new `FlatMapGroupsInR` nodes. */


@ -24,13 +24,13 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, MonotonicallyIncreasingID, NamedExpression}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, MonotonicallyIncreasingID, NamedExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{FlatMapGroupsWithState, _}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
import org.apache.spark.sql.types.{IntegerType, LongType, MetadataBuilder}
/** A dummy command for testing unsupported operations. */
@ -145,15 +145,15 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
for (funcMode <- Seq(Append, Update)) {
assertSupportedInBatchPlan(
s"flatMapGroupsWithState - flatMapGroupsWithState($funcMode) on batch relation",
FlatMapGroupsWithState(
TestFlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, null, funcMode, isMapGroupsWithState = false, null,
batchRelation))
assertSupportedInBatchPlan(
s"flatMapGroupsWithState - multiple flatMapGroupsWithState($funcMode)s on batch relation",
FlatMapGroupsWithState(
TestFlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, null, funcMode, isMapGroupsWithState = false, null,
FlatMapGroupsWithState(
TestFlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, null, funcMode, isMapGroupsWithState = false,
null, batchRelation)))
}
@ -162,7 +162,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
assertSupportedInStreamingPlan(
"flatMapGroupsWithState - flatMapGroupsWithState(Update) " +
"on streaming relation without aggregation in update mode",
FlatMapGroupsWithState(
TestFlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null,
streamRelation),
outputMode = Update)
@ -170,7 +170,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
assertNotSupportedInStreamingPlan(
"flatMapGroupsWithState - flatMapGroupsWithState(Update) " +
"on streaming relation without aggregation in append mode",
FlatMapGroupsWithState(
TestFlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null,
streamRelation),
outputMode = Append,
@ -179,7 +179,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
assertNotSupportedInStreamingPlan(
"flatMapGroupsWithState - flatMapGroupsWithState(Update) " +
"on streaming relation without aggregation in complete mode",
FlatMapGroupsWithState(
TestFlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null,
streamRelation),
outputMode = Complete,
@ -192,7 +192,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
assertNotSupportedInStreamingPlan(
"flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation " +
s"with aggregation in $outputMode mode",
FlatMapGroupsWithState(
TestFlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null,
Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation)),
outputMode = outputMode,
@ -203,7 +203,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
assertSupportedInStreamingPlan(
"flatMapGroupsWithState - flatMapGroupsWithState(Append) " +
"on streaming relation without aggregation in append mode",
FlatMapGroupsWithState(
TestFlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null,
streamRelation),
outputMode = Append)
@ -211,7 +211,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
assertNotSupportedInStreamingPlan(
"flatMapGroupsWithState - flatMapGroupsWithState(Append) " +
"on streaming relation without aggregation in update mode",
FlatMapGroupsWithState(
TestFlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null,
streamRelation),
outputMode = Update,
@ -226,7 +226,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
Aggregate(
Seq(attributeWithWatermark),
aggExprs("c"),
FlatMapGroupsWithState(
TestFlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null,
streamRelation)),
outputMode = outputMode,
@ -237,7 +237,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
assertNotSupportedInStreamingPlan(
"flatMapGroupsWithState - flatMapGroupsWithState(Append) " +
s"on streaming relation after aggregation in $outputMode mode",
FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append,
TestFlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append,
isMapGroupsWithState = false, null,
Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation)),
outputMode = outputMode,
@ -247,7 +247,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
assertNotSupportedInStreamingPlan(
"flatMapGroupsWithState - " +
"flatMapGroupsWithState(Update) on streaming relation in complete mode",
FlatMapGroupsWithState(
TestFlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null,
streamRelation),
outputMode = Complete,
@ -261,7 +261,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
assertSupportedInStreamingPlan(
s"flatMapGroupsWithState - flatMapGroupsWithState($funcMode) on batch relation inside " +
s"streaming relation in $outputMode output mode",
FlatMapGroupsWithState(
TestFlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, null, funcMode, isMapGroupsWithState = false,
null, batchRelation),
outputMode = outputMode
@ -274,9 +274,9 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
assertSupportedInStreamingPlan(
"flatMapGroupsWithState - multiple flatMapGroupsWithStates on streaming relation and all are " +
"in append mode",
FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append,
TestFlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append,
isMapGroupsWithState = false, null,
FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append,
TestFlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append,
isMapGroupsWithState = false, null, streamRelation)),
outputMode = Append,
SQLConf.STATEFUL_OPERATOR_CHECK_CORRECTNESS_ENABLED.key -> "false")
@ -284,9 +284,9 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
assertNotSupportedInStreamingPlan(
"flatMapGroupsWithState - multiple flatMapGroupsWithStates on s streaming relation but some" +
" are not in append mode",
FlatMapGroupsWithState(
TestFlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null,
FlatMapGroupsWithState(
TestFlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null,
streamRelation)),
outputMode = Append,
@ -296,7 +296,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
assertNotSupportedInStreamingPlan(
"mapGroupsWithState - mapGroupsWithState " +
"on streaming relation without aggregation in append mode",
FlatMapGroupsWithState(
TestFlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, null,
streamRelation),
outputMode = Append,
@ -307,7 +307,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
assertNotSupportedInStreamingPlan(
"mapGroupsWithState - mapGroupsWithState " +
"on streaming relation without aggregation in complete mode",
FlatMapGroupsWithState(
TestFlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, null,
streamRelation),
outputMode = Complete,
@ -319,7 +319,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
assertNotSupportedInStreamingPlan(
"mapGroupsWithState - mapGroupsWithState on streaming relation " +
s"with aggregation in $outputMode mode",
FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Update,
TestFlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Update,
isMapGroupsWithState = true, null,
Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation)),
outputMode = outputMode,
@ -330,9 +330,9 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
assertNotSupportedInStreamingPlan(
"mapGroupsWithState - multiple mapGroupsWithStates on streaming relation and all are " +
"in append mode",
FlatMapGroupsWithState(
TestFlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, null,
FlatMapGroupsWithState(
TestFlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, null,
streamRelation)),
outputMode = Append,
@ -342,9 +342,9 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
assertNotSupportedInStreamingPlan(
"mapGroupsWithState - " +
"mixing mapGroupsWithStates and flatMapGroupsWithStates on streaming relation",
FlatMapGroupsWithState(
TestFlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, null,
FlatMapGroupsWithState(
TestFlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null,
streamRelation)
),
@ -354,7 +354,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
// mapGroupsWithState with event time timeout + watermark
assertNotSupportedInStreamingPlan(
"mapGroupsWithState - mapGroupsWithState with event time timeout without watermark",
FlatMapGroupsWithState(
TestFlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true,
EventTimeTimeout, streamRelation),
outputMode = Update,
@ -362,7 +362,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
assertSupportedInStreamingPlan(
"mapGroupsWithState - mapGroupsWithState with event time timeout with watermark",
FlatMapGroupsWithState(
TestFlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true,
EventTimeTimeout, new TestStreamingRelation(attributeWithWatermark)),
outputMode = Update)
@ -532,7 +532,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
testGlobalWatermarkLimit(
s"FlatMapGroupsWithState after stream-stream $joinType join in Append mode",
FlatMapGroupsWithState(
TestFlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, null, Append,
isMapGroupsWithState = false, null,
streamRelation.join(streamRelation, joinType = joinType,
@ -664,7 +664,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
assertFailOnGlobalWatermarkLimit(
"FlatMapGroupsWithState after streaming aggregation in Append mode",
FlatMapGroupsWithState(
TestFlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, null, Append,
isMapGroupsWithState = false, null,
streamRelation.groupBy("a")(count("*"))),
@ -675,14 +675,14 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
{
assertPassOnGlobalWatermarkLimit(
"single FlatMapGroupsWithState in Append mode",
FlatMapGroupsWithState(
TestFlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, null, Append,
isMapGroupsWithState = false, null, streamRelation),
OutputMode.Append())
assertFailOnGlobalWatermarkLimit(
"streaming aggregation after FlatMapGroupsWithState in Append mode",
FlatMapGroupsWithState(
TestFlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, null, Append,
isMapGroupsWithState = false, null, streamRelation).groupBy("*")(count("*")),
OutputMode.Append())
@ -691,7 +691,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
assertFailOnGlobalWatermarkLimit(
s"stream-stream $joinType after FlatMapGroupsWithState in Append mode",
streamRelation.join(
FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append,
TestFlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append,
isMapGroupsWithState = false, null, streamRelation), joinType = joinType,
condition = Some(attributeWithWatermark === attribute)),
OutputMode.Append())
@ -699,16 +699,16 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
assertFailOnGlobalWatermarkLimit(
"FlatMapGroupsWithState after FlatMapGroupsWithState in Append mode",
FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append,
TestFlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append,
isMapGroupsWithState = false, null,
FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append,
TestFlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append,
isMapGroupsWithState = false, null, streamRelation)),
OutputMode.Append())
assertFailOnGlobalWatermarkLimit(
s"deduplicate after FlatMapGroupsWithState in Append mode",
Deduplicate(Seq(attribute),
FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append,
TestFlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append,
isMapGroupsWithState = false, null, streamRelation)),
OutputMode.Append())
}
@ -730,7 +730,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
assertPassOnGlobalWatermarkLimit(
"FlatMapGroupsWithState after deduplicate in Append mode",
FlatMapGroupsWithState(
TestFlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, null, Append,
isMapGroupsWithState = false, null,
Deduplicate(Seq(attribute), streamRelation)),
@ -1015,3 +1015,43 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
override def nodeName: String = "StreamingRelationV2"
}
}
object TestFlatMapGroupsWithState {
// scalastyle:off
// Creating an apply constructor here as we changed the class by adding more fields
def apply(func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any],
keyDeserializer: Expression,
valueDeserializer: Expression,
groupingAttributes: Seq[Attribute],
dataAttributes: Seq[Attribute],
outputObjAttr: Attribute,
stateEncoder: ExpressionEncoder[Any],
outputMode: OutputMode,
isMapGroupsWithState: Boolean = false,
timeout: GroupStateTimeout,
child: LogicalPlan): FlatMapGroupsWithState = {
val attribute = AttributeReference("a", IntegerType, nullable = true)()
val batchRelation = LocalRelation(attribute)
new FlatMapGroupsWithState(
func,
keyDeserializer,
valueDeserializer,
groupingAttributes,
dataAttributes,
outputObjAttr,
stateEncoder,
outputMode,
isMapGroupsWithState,
timeout,
false,
groupingAttributes,
dataAttributes,
valueDeserializer,
batchRelation,
child
)
}
// scalastyle:on
}


@ -280,6 +280,51 @@ class KeyValueGroupedDataset[K, V] private[sql](
child = logicalPlan))
}
/**
* (Scala-specific)
* Applies the given function to each group of data, while maintaining a user-defined per-group
* state. The result Dataset will represent the objects returned by the function.
* For a static batch Dataset, the function will be invoked once per group. For a streaming
* Dataset, the function will be invoked for each group repeatedly in every trigger, and
* updates to each group's state will be saved across invocations.
* See [[org.apache.spark.sql.streaming.GroupState]] for more details.
*
* @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
* @tparam U The type of the output objects. Must be encodable to Spark SQL types.
* @param func Function to be called on every group.
* @param timeoutConf Timeout Conf, see GroupStateTimeout for more details
* @param initialState The user provided state that will be initialized when the first batch
* of data is processed in the streaming query. The user defined function
* will be called on the state data even if there are no other values in
* the group. To convert a Dataset `ds` of type `Dataset[(K, S)]` to a
* `KeyValueGroupedDataset[K, S]`, use
* {{{ ds.groupByKey(x => x._1).mapValues(_._2) }}}
*
* See [[Encoder]] for more details on what types are encodable to Spark SQL.
* @since 3.2.0
*/
def mapGroupsWithState[S: Encoder, U: Encoder](
timeoutConf: GroupStateTimeout,
initialState: KeyValueGroupedDataset[K, S])(
func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = {
val flatMapFunc = (key: K, it: Iterator[V], s: GroupState[S]) => Iterator(func(key, it, s))
Dataset[U](
sparkSession,
FlatMapGroupsWithState[K, V, S, U](
flatMapFunc.asInstanceOf[(Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any]],
groupingAttributes,
dataAttributes,
OutputMode.Update,
isMapGroupsWithState = true,
timeoutConf,
child = logicalPlan,
initialState.groupingAttributes,
initialState.dataAttributes,
initialState.queryExecution.analyzed
))
}
/**
* (Java-specific)
* Applies the given function to each group of data, while maintaining a user-defined per-group
@ -336,6 +381,40 @@ class KeyValueGroupedDataset[K, V] private[sql](
)(stateEncoder, outputEncoder)
}
/**
* (Java-specific)
* Applies the given function to each group of data, while maintaining a user-defined per-group
* state. The result Dataset will represent the objects returned by the function.
* For a static batch Dataset, the function will be invoked once per group. For a streaming
* Dataset, the function will be invoked for each group repeatedly in every trigger, and
* updates to each group's state will be saved across invocations.
* See `GroupState` for more details.
*
* @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
* @tparam U The type of the output objects. Must be encodable to Spark SQL types.
* @param func Function to be called on every group.
* @param stateEncoder Encoder for the state type.
* @param outputEncoder Encoder for the output type.
* @param timeoutConf Timeout configuration for groups that do not receive data for a while.
* @param initialState The user provided state that will be initialized when the first batch
* of data is processed in the streaming query. The user defined function
* will be called on the state data even if there are no other values in
* the group.
*
* See [[Encoder]] for more details on what types are encodable to Spark SQL.
* @since 3.2.0
*/
def mapGroupsWithState[S, U](
func: MapGroupsWithStateFunction[K, V, S, U],
stateEncoder: Encoder[S],
outputEncoder: Encoder[U],
timeoutConf: GroupStateTimeout,
initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = {
mapGroupsWithState[S, U](timeoutConf, initialState)(
(key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s)
)(stateEncoder, outputEncoder)
}
/**
* (Scala-specific)
* Applies the given function to each group of data, while maintaining a user-defined per-group
@ -373,6 +452,53 @@ class KeyValueGroupedDataset[K, V] private[sql](
child = logicalPlan))
}
/**
* (Scala-specific)
* Applies the given function to each group of data, while maintaining a user-defined per-group
* state. The result Dataset will represent the objects returned by the function.
* For a static batch Dataset, the function will be invoked once per group. For a streaming
* Dataset, the function will be invoked for each group repeatedly in every trigger, and
* updates to each group's state will be saved across invocations.
* See `GroupState` for more details.
*
* @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
* @tparam U The type of the output objects. Must be encodable to Spark SQL types.
* @param func Function to be called on every group.
* @param outputMode The output mode of the function.
* @param timeoutConf Timeout configuration for groups that do not receive data for a while.
* @param initialState The user provided state that will be initialized when the first batch
* of data is processed in the streaming query. The user defined function
* will be called on the state data even if there are no other values in
* the group. To convert a Dataset `ds` of type `Dataset[(K, S)]`
* to a `KeyValueGroupedDataset[K, S]`, use
* {{{ ds.groupByKey(x => x._1).mapValues(_._2) }}}
* See [[Encoder]] for more details on what types are encodable to Spark SQL.
* @since 3.2.0
*/
def flatMapGroupsWithState[S: Encoder, U: Encoder](
outputMode: OutputMode,
timeoutConf: GroupStateTimeout,
initialState: KeyValueGroupedDataset[K, S])(
func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U] = {
if (outputMode != OutputMode.Append && outputMode != OutputMode.Update) {
throw new IllegalArgumentException("The output mode of function should be append or update")
}
Dataset[U](
sparkSession,
FlatMapGroupsWithState[K, V, S, U](
func.asInstanceOf[(Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any]],
groupingAttributes,
dataAttributes,
outputMode,
isMapGroupsWithState = false,
timeoutConf,
child = logicalPlan,
initialState.groupingAttributes,
initialState.dataAttributes,
initialState.queryExecution.analyzed
))
}
/**
* (Java-specific)
* Applies the given function to each group of data, while maintaining a user-defined per-group
@ -403,6 +529,44 @@ class KeyValueGroupedDataset[K, V] private[sql](
flatMapGroupsWithState[S, U](outputMode, timeoutConf)(f)(stateEncoder, outputEncoder)
}
/**
* (Java-specific)
* Applies the given function to each group of data, while maintaining a user-defined per-group
* state. The result Dataset will represent the objects returned by the function.
* For a static batch Dataset, the function will be invoked once per group. For a streaming
* Dataset, the function will be invoked for each group repeatedly in every trigger, and
* updates to each group's state will be saved across invocations.
* See `GroupState` for more details.
*
* @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
* @tparam U The type of the output objects. Must be encodable to Spark SQL types.
* @param func Function to be called on every group.
* @param outputMode The output mode of the function.
* @param stateEncoder Encoder for the state type.
* @param outputEncoder Encoder for the output type.
* @param timeoutConf Timeout configuration for groups that do not receive data for a while.
* @param initialState The user provided state that will be initialized when the first batch
* of data is processed in the streaming query. The user defined function
* will be called on the state data even if there are no other values in
* the group. To convert a Dataset `ds` of type `Dataset[(K, S)]`
* to a `KeyValueGroupedDataset[K, S]`, use
* {{{ ds.groupByKey(x => x._1).mapValues(_._2) }}}
*
* See [[Encoder]] for more details on what types are encodable to Spark SQL.
* @since 3.2.0
*/
def flatMapGroupsWithState[S, U](
func: FlatMapGroupsWithStateFunction[K, V, S, U],
outputMode: OutputMode,
stateEncoder: Encoder[S],
outputEncoder: Encoder[U],
timeoutConf: GroupStateTimeout,
initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = {
val f = (key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s).asScala
flatMapGroupsWithState[S, U](
outputMode, timeoutConf, initialState)(f)(stateEncoder, outputEncoder)
}
/**
* (Scala-specific)
* Reduces the elements of each group of data using the specified binary function.


@ -561,11 +561,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case FlatMapGroupsWithState(
func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateEnc, outputMode, _,
timeout, child) =>
timeout, hasInitialState, stateGroupAttr, sda, sDeser, initialState, child) =>
val stateVersion = conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION)
val execPlan = FlatMapGroupsWithStateExec(
func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateEnc, stateVersion,
outputMode, timeout, batchTimestampMs = None, eventTimeWatermark = None, planLater(child))
func, keyDeser, valueDeser, sDeser, groupAttr, stateGroupAttr, dataAttr, sda, outputAttr,
None, stateEnc, stateVersion, outputMode, timeout, batchTimestampMs = None,
eventTimeWatermark = None, planLater(initialState), hasInitialState, planLater(child)
)
execPlan :: Nil
case _ =>
Nil
@ -669,7 +671,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.MapGroups(f, key, value, grouping, data, objAttr, child) =>
execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil
case logical.FlatMapGroupsWithState(
f, key, value, grouping, data, output, _, _, _, timeout, child) =>
f, key, value, grouping, data, output, _, _, _, timeout, _, _, _, _, _, child) =>
execution.MapGroupsExec(
f, key, value, grouping, data, output, timeout, planLater(child)) :: Nil
case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) =>


@ -25,9 +25,11 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expressi
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
import org.apache.spark.util.CompletionIterator
import org.apache.spark.sql.streaming.GroupStateTimeout.NoTimeout
import org.apache.spark.util.{CompletionIterator, SerializableConfiguration}
/**
* Physical operator for executing `FlatMapGroupsWithState`
@ -35,6 +37,7 @@ import org.apache.spark.util.CompletionIterator
* @param func function called on each group
* @param keyDeserializer used to extract the key object for each group.
* @param valueDeserializer used to extract the items in the iterator from an input row.
* @param initialStateDeserializer used to extract the state object from the initialState dataset
* @param groupingAttributes used to group the data
* @param dataAttributes used to read the data
* @param outputObjAttr Defines the output object
@ -42,13 +45,20 @@ import org.apache.spark.util.CompletionIterator
* @param outputMode the output mode of `func`
* @param timeoutConf used to timeout groups that have not received data in a while
* @param batchTimestampMs processing timestamp of the current batch.
* @param eventTimeWatermark event time watermark for the current batch
* @param initialState the user specified initial state
* @param hasInitialState indicates whether the initial state is provided or not
* @param child the physical plan for the underlying data
*/
case class FlatMapGroupsWithStateExec(
func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any],
keyDeserializer: Expression,
valueDeserializer: Expression,
initialStateDeserializer: Expression,
groupingAttributes: Seq[Attribute],
initialStateGroupAttrs: Seq[Attribute],
dataAttributes: Seq[Attribute],
initialStateDataAttrs: Seq[Attribute],
outputObjAttr: Attribute,
stateInfo: Option[StatefulOperatorStateInfo],
stateEncoder: ExpressionEncoder[Any],
@ -57,27 +67,45 @@ case class FlatMapGroupsWithStateExec(
timeoutConf: GroupStateTimeout,
batchTimestampMs: Option[Long],
eventTimeWatermark: Option[Long],
initialState: SparkPlan,
hasInitialState: Boolean,
child: SparkPlan
) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter with WatermarkSupport {
) extends BinaryExecNode with ObjectProducerExec with StateStoreWriter with WatermarkSupport {
import FlatMapGroupsWithStateExecHelper._
import GroupStateImpl._
override def left: SparkPlan = child
override def right: SparkPlan = initialState
private val isTimeoutEnabled = timeoutConf != NoTimeout
private val watermarkPresent = child.output.exists {
case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true
case _ => false
}
private[sql] val stateManager =
createStateManager(stateEncoder, isTimeoutEnabled, stateFormatVersion)
/** Distribute by grouping attributes */
override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(groupingAttributes, stateInfo.map(_.numPartitions)) :: Nil
/**
* Distribute by grouping attributes - We need the underlying data and the initial state data
* to have the same grouping so that the data are co-located on the same task.
*/
override def requiredChildDistribution: Seq[Distribution] = {
ClusteredDistribution(groupingAttributes, stateInfo.map(_.numPartitions)) ::
ClusteredDistribution(initialStateGroupAttrs, stateInfo.map(_.numPartitions)) ::
Nil
}
/** Ordering needed for using GroupingIterator */
override def requiredChildOrdering: Seq[Seq[SortOrder]] =
Seq(groupingAttributes.map(SortOrder(_, Ascending)))
/**
* Ordering needed for using GroupingIterator.
* We need the initial state to use the same ordering as the data so that we can co-locate the
* keys from the underlying data and the initial state.
*/
override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq(
groupingAttributes.map(SortOrder(_, Ascending)),
initialStateGroupAttrs.map(SortOrder(_, Ascending)))
override def keyExpressions: Seq[Attribute] = groupingAttributes
@ -95,6 +123,77 @@ case class FlatMapGroupsWithStateExec(
}
}
/**
* Process data by applying the user defined function on a per partition basis.
*
* @param iter - Iterator of the data rows
* @param store - associated state store for this partition
* @param processor - handle to the input processor object.
* @param initialStateIterOption - optional initial state iterator
*/
def processDataWithPartition(
iter: Iterator[InternalRow],
store: StateStore,
processor: InputProcessor,
initialStateIterOption: Option[Iterator[InternalRow]] = None
): CompletionIterator[InternalRow, Iterator[InternalRow]] = {
val allUpdatesTimeMs = longMetric("allUpdatesTimeMs")
val commitTimeMs = longMetric("commitTimeMs")
val timeoutLatencyMs = longMetric("allRemovalsTimeMs")
val currentTimeNs = System.nanoTime
val updatesStartTimeNs = currentTimeNs
var timeoutProcessingStartTimeNs = currentTimeNs
// If timeout is based on event time, then filter late data based on watermark
val filteredIter = watermarkPredicateForData match {
case Some(predicate) if timeoutConf == EventTimeTimeout =>
applyRemovingRowsOlderThanWatermark(iter, predicate)
case _ =>
iter
}
val processedOutputIterator = initialStateIterOption match {
case Some(initStateIter) if initStateIter.hasNext =>
processor.processNewDataWithInitialState(filteredIter, initStateIter)
case _ => processor.processNewData(filteredIter)
}
val newDataProcessorIter =
CompletionIterator[InternalRow, Iterator[InternalRow]](
processedOutputIterator, {
// Once the input is processed, mark the start time for timeout processing to measure
// it separately from the overall processing time.
timeoutProcessingStartTimeNs = System.nanoTime
})
val timeoutProcessorIter =
CompletionIterator[InternalRow, Iterator[InternalRow]](processor.processTimedOutState(), {
// Note: `timeoutLatencyMs` also includes the time the parent operator took for
// processing output returned through iterator.
timeoutLatencyMs += NANOSECONDS.toMillis(System.nanoTime - timeoutProcessingStartTimeNs)
})
// Generate an iterator that returns the rows grouped by the grouping function
// Note that this code ensures that the filtering for timeout occurs only after
// all the data has been processed. This is to ensure that the timeout information of all
// the keys with data is updated before they are processed for timeouts.
val outputIterator = newDataProcessorIter ++ timeoutProcessorIter
// Return an iterator of all the rows generated by all the keys, such that when fully
// consumed, all the state updates will be committed by the state store
CompletionIterator[InternalRow, Iterator[InternalRow]](outputIterator, {
// Note: Due to the iterator lazy execution, this metric also captures the time taken
// by the upstream (consumer) operators in addition to the processing in this operator.
allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs)
commitTimeMs += timeTakenMs {
store.commit()
}
setStoreMetrics(store)
setOperatorMetrics()
})
}
override protected def doExecute(): RDD[InternalRow] = {
metrics // force lazy init at driver
@ -103,69 +202,50 @@ case class FlatMapGroupsWithStateExec(
case ProcessingTimeTimeout =>
require(batchTimestampMs.nonEmpty)
case EventTimeTimeout =>
require(eventTimeWatermark.nonEmpty) // watermark value has been populated
require(eventTimeWatermark.nonEmpty) // watermark value has been populated
require(watermarkExpression.nonEmpty) // input schema has watermark attribute
case _ =>
}
child.execute().mapPartitionsWithStateStore[InternalRow](
getStateInfo,
groupingAttributes.toStructType,
stateManager.stateSchema,
indexOrdinal = None,
session.sessionState,
Some(session.streams.stateStoreCoordinator)) { case (store, iter) =>
val allUpdatesTimeMs = longMetric("allUpdatesTimeMs")
val commitTimeMs = longMetric("commitTimeMs")
val timeoutLatencyMs = longMetric("allRemovalsTimeMs")
if (hasInitialState) {
// If the user provided an initial state, we need the initial state and the
// data in the same partition so that we can still have just one commit at the end.
val storeConf = new StateStoreConf(session.sqlContext.sessionState.conf)
val hadoopConfBroadcast = sparkContext.broadcast(
new SerializableConfiguration(session.sqlContext.sessionState.newHadoopConf()))
child.execute().stateStoreAwareZipPartitions(
initialState.execute(),
getStateInfo,
storeNames = Seq(),
session.sqlContext.streams.stateStoreCoordinator) {
// The state store aware zip partitions will provide us with two iterators,
// child data iterator and the initial state iterator per partition.
case (partitionId, childDataIterator, initStateIterator) =>
val stateStoreId = StateStoreId(
stateInfo.get.checkpointLocation, stateInfo.get.operatorId, partitionId)
val storeProviderId = StateStoreProviderId(stateStoreId, stateInfo.get.queryRunId)
val store = StateStore.get(
storeProviderId,
groupingAttributes.toStructType,
stateManager.stateSchema,
indexOrdinal = None,
stateInfo.get.storeVersion, storeConf, hadoopConfBroadcast.value.value)
val processor = new InputProcessor(store)
processDataWithPartition(childDataIterator, store, processor, Some(initStateIterator))
}
} else {
child.execute().mapPartitionsWithStateStore[InternalRow](
getStateInfo,
groupingAttributes.toStructType,
stateManager.stateSchema,
indexOrdinal = None,
session.sqlContext.sessionState,
Some(session.sqlContext.streams.stateStoreCoordinator)
) { case (store: StateStore, singleIterator: Iterator[InternalRow]) =>
val processor = new InputProcessor(store)
val currentTimeNs = System.nanoTime
val updatesStartTimeNs = currentTimeNs
var timeoutProcessingStartTimeNs = currentTimeNs
// If timeout is based on event time, then filter late data based on watermark
val filteredIter = watermarkPredicateForData match {
case Some(predicate) if timeoutConf == EventTimeTimeout =>
applyRemovingRowsOlderThanWatermark(iter, predicate)
case _ =>
iter
}
val newDataProcessorIter =
CompletionIterator[InternalRow, Iterator[InternalRow]](
processor.processNewData(filteredIter), {
// Once the input is processed, mark the start time for timeout processing to measure
// it separately from the overall processing time.
timeoutProcessingStartTimeNs = System.nanoTime
})
val timeoutProcessorIter =
CompletionIterator[InternalRow, Iterator[InternalRow]](processor.processTimedOutState(), {
// Note: `timeoutLatencyMs` also includes the time the parent operator took for
// processing output returned through iterator.
timeoutLatencyMs += NANOSECONDS.toMillis(System.nanoTime - timeoutProcessingStartTimeNs)
})
// Generate an iterator that returns the rows grouped by the grouping function
// Note that this code ensures that the filtering for timeout occurs only after
// all the data has been processed. This is to ensure that the timeout information of all
// the keys with data is updated before they are processed for timeouts.
val outputIterator = newDataProcessorIter ++ timeoutProcessorIter
// Return an iterator of all the rows generated by all the keys, such that when fully
// consumed, all the state updates will be committed by the state store
CompletionIterator[InternalRow, Iterator[InternalRow]](outputIterator, {
// Note: Due to the iterator lazy execution, this metric also captures the time taken
// by the upstream (consumer) operators in addition to the processing in this operator.
allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs)
commitTimeMs += timeTakenMs {
store.commit()
}
setStoreMetrics(store)
setOperatorMetrics()
}
)
processDataWithPartition(singleIterator, store, processor)
}
}
}
@ -178,6 +258,11 @@ case class FlatMapGroupsWithStateExec(
private val getValueObj =
ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjectType)
private val getStateObj = if (hasInitialState) {
Some(ObjectOperator.deserializeRowToObject(initialStateDeserializer, initialStateDataAttrs))
} else {
None
}
// Metrics
private val numUpdatedStateRows = longMetric("numUpdatedStateRows")
@ -199,6 +284,49 @@ case class FlatMapGroupsWithStateExec(
}
}
/**
* Process the new data iterator along with the initial state. The initial state is applied
* before processing the new data for every key. The user defined function is called only
* once for every key that has either initial state or data or both.
*/
def processNewDataWithInitialState(
childDataIter: Iterator[InternalRow],
initStateIter: Iterator[InternalRow]
): Iterator[InternalRow] = {
if (!childDataIter.hasNext && !initStateIter.hasNext) return Iterator.empty
// Create iterators for the child data and the initial state grouped by their grouping
// attributes.
val groupedChildDataIter = GroupedIterator(childDataIter, groupingAttributes, child.output)
val groupedInitialStateIter =
GroupedIterator(initStateIter, initialStateGroupAttrs, initialState.output)
// Create a CoGroupedIterator that will group the two iterators together for every key group.
new CoGroupedIterator(
groupedChildDataIter, groupedInitialStateIter, groupingAttributes).flatMap {
case (keyRow, valueRowIter, initialStateRowIter) =>
val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow]
var foundInitialStateForKey = false
initialStateRowIter.foreach { initialStateRow =>
if (foundInitialStateForKey) {
throw new IllegalArgumentException("The initial state provided contained " +
"multiple rows(state) with the same key. Make sure to de-duplicate the " +
"initial state before passing it.")
}
foundInitialStateForKey = true
val initStateObj = getStateObj.get(initialStateRow)
stateManager.putState(store, keyUnsafeRow, initStateObj, NO_TIMESTAMP)
}
// We apply the values for the key after applying the initial state.
callFunctionAndUpdateState(
stateManager.getState(store, keyUnsafeRow),
valueRowIter,
hasTimedOut = false
)
}
}
/** Find the groups that have timeout set and are timing out right now, and call the function */
def processTimedOutState(): Iterator[InternalRow] = {
if (isTimeoutEnabled) {
@ -271,6 +399,8 @@ case class FlatMapGroupsWithStateExec(
}
}
override protected def withNewChildInternal(newChild: SparkPlan): FlatMapGroupsWithStateExec =
copy(child = newChild)
override protected def withNewChildrenInternal(
newLeft: SparkPlan, newRight: SparkPlan): FlatMapGroupsWithStateExec =
copy(child = newLeft, initialState = newRight)
}


@ -157,10 +157,14 @@ class IncrementalExecution(
Some(offsetSeqMetadata.batchWatermarkMs))
case m: FlatMapGroupsWithStateExec =>
// We set this to true only for the first batch of the streaming query.
val hasInitialState = (currentBatchId == 0L && m.hasInitialState)
m.copy(
stateInfo = Some(nextStatefulOperationStateInfo),
batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs),
eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs))
eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs),
hasInitialState = hasInitialState
)
case j: StreamingSymmetricHashJoinExec =>
j.copy(


@ -203,7 +203,9 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan =>
}
/** An operator that supports watermark. */
trait WatermarkSupport extends UnaryExecNode {
trait WatermarkSupport extends SparkPlan {
def child: SparkPlan
/** The keys that may have a watermark attribute. */
def keyExpressions: Seq[Attribute]


@ -111,6 +111,11 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState
* when the group has new data, or the group has timed out. So the user has to set the timeout
* duration every time the function is called, otherwise, there will not be any timeout set.
*
* `[map/flatMap]GroupsWithState` can take a user defined initial state as an additional argument.
* This state will be applied when the first batch of the streaming query is processed. If there
* are no matching rows in the data for the keys present in the initial state, the state is still
* applied and the function will be invoked with the values being an empty iterator.
*
* Scala example of using GroupState in `mapGroupsWithState`:
* {{{
* // A mapping function that maintains an integer state for string keys and returns a string.

View file

@ -46,6 +46,7 @@ import org.apache.spark.sql.catalyst.expressions.GenericRow;
import org.apache.spark.sql.test.TestSparkSession;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.util.LongAccumulator;
import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.expr;
import static org.apache.spark.sql.types.DataTypes.*;
@ -160,6 +161,71 @@ public class JavaDatasetSuite implements Serializable {
Assert.assertEquals(6, reduced);
}
@Test
public void testInitialStateFlatMapGroupsWithState() {
List<String> data = Arrays.asList("a", "foo", "bar");
Dataset<String> ds = spark.createDataset(data, Encoders.STRING());
Dataset<Tuple2<Integer, Long>> initialStateDS = spark.createDataset(
Arrays.asList(new Tuple2<Integer, Long>(2, 2L)),
Encoders.tuple(Encoders.INT(), Encoders.LONG())
);
KeyValueGroupedDataset<Integer, Tuple2<Integer, Long>> kvInitStateDS =
initialStateDS.groupByKey(
(MapFunction<Tuple2<Integer, Long>, Integer>) f -> f._1, Encoders.INT());
KeyValueGroupedDataset<Integer, Long> kvInitStateMappedDS = kvInitStateDS.mapValues(
(MapFunction<Tuple2<Integer, Long>, Long>) f -> f._2,
Encoders.LONG()
);
KeyValueGroupedDataset<Integer, String> grouped =
ds.groupByKey((MapFunction<String, Integer>) String::length, Encoders.INT());
Dataset<String> flatMapped2 = grouped.flatMapGroupsWithState(
(FlatMapGroupsWithStateFunction<Integer, String, Long, String>) (key, values, s) -> {
StringBuilder sb = new StringBuilder(key.toString());
while (values.hasNext()) {
sb.append(values.next());
}
return Collections.singletonList(sb.toString()).iterator();
},
OutputMode.Append(),
Encoders.LONG(),
Encoders.STRING(),
GroupStateTimeout.NoTimeout(),
kvInitStateMappedDS);
Assert.assertThrows(
"Initial state is not supported in [flatMap|map]GroupsWithState " +
"operation on a batch DataFrame/Dataset",
AnalysisException.class,
() -> {
flatMapped2.collectAsList();
}
);
Dataset<String> mapped2 = grouped.mapGroupsWithState(
(MapGroupsWithStateFunction<Integer, String, Long, String>) (key, values, s) -> {
StringBuilder sb = new StringBuilder(key.toString());
while (values.hasNext()) {
sb.append(values.next());
}
return sb.toString();
},
Encoders.LONG(),
Encoders.STRING(),
GroupStateTimeout.NoTimeout(),
kvInitStateMappedDS);
Assert.assertThrows(
"Initial state is not supported in [flatMap|map]GroupsWithState " +
"operation on a batch DataFrame/Dataset",
AnalysisException.class,
() -> {
mapped2.collectAsList();
}
);
}
@Test
public void testIllegalTestGroupStateCreations() {
// SPARK-35800: test code throws upon illegal TestGroupState create() calls


@ -26,7 +26,7 @@ import org.scalatest.exceptions.TestFailedException
import org.apache.spark.SparkException
import org.apache.spark.api.java.Optional
import org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction
import org.apache.spark.sql.{DataFrame, Encoder}
import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, Encoder, KeyValueGroupedDataset}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsWithState
@ -1042,7 +1042,6 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
)
}
test("mapGroupsWithState - streaming") {
// Function to maintain running count up to 2, and then remove the count
// Returns the data and the count (-1 if count reached beyond 2 and state was just removed)
@ -1268,12 +1267,281 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
assert(e.getMessage === "The output mode of function should be append or update")
}
import testImplicits._
/**
* FlatMapGroupsWithState function that returns the key and values as passed to it,
* along with the updated state count. The state is incremented for every value.
*/
val flatMapGroupsWithStateFunc =
(key: String, values: Iterator[String], state: GroupState[RunningCount]) => {
val valList = values.toSeq
if (valList.isEmpty) {
// When the function is called on just the initial state, make sure the other fields
// are set correctly
assert(state.exists)
}
assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 }
assertCannotGetWatermark { state.getCurrentWatermarkMs() }
assert(!state.hasTimedOut)
val count = state.getOption.map(_.count).getOrElse(0L) + valList.size
// We need to check whether the init state is still saved when update is not explicitly called
if (!key.contains("NoUpdate")) {
// this is not reached when valList is empty and the state count is 2
state.update(new RunningCount(count))
}
Iterator((key, valList, count.toString))
}
Seq("1", "2", "6").foreach { shufflePartitions =>
testWithAllStateVersions(s"flatMapGroupsWithState - initial " +
s"state - all cases - shuffle partitions ${shufflePartitions}") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> shufflePartitions) {
// We test with different shuffle partition configurations to make sure that
// grouping by key still works. With a higher number of shuffle partitions it's
// possible that all keys end up on different partitions.
val initialState: Dataset[(String, RunningCount)] = Seq(
("keyInStateAndData-1", new RunningCount(1)),
("keyInStateAndData-2", new RunningCount(2)),
("keyNoUpdate", new RunningCount(2)), // state.update will not be called
("keyOnlyInState-1", new RunningCount(1))
).toDS()
val it = initialState.groupByKey(x => x._1).mapValues(_._2)
val inputData = MemoryStream[String]
val result =
inputData.toDS()
.groupByKey(x => x)
.flatMapGroupsWithState(
Update, GroupStateTimeout.NoTimeout, it)(flatMapGroupsWithStateFunc)
testStream(result, Update)(
AddData(inputData, "keyOnlyInData", "keyInStateAndData-2"),
CheckNewAnswer(
("keyOnlyInState-1", Seq[String](), "1"),
("keyNoUpdate", Seq[String](), "2"), // update will not be called
("keyInStateAndData-2", Seq[String]("keyInStateAndData-2"), "3"), // inc by 1
("keyInStateAndData-1", Seq[String](), "1"),
("keyOnlyInData", Seq[String]("keyOnlyInData"), "1") // inc by 1
),
assertNumStateRows(total = 5, updated = 4),
// Stop and Start stream to make sure initial state doesn't get applied again.
StopStream,
StartStream(),
AddData(inputData, "keyInStateAndData-1"),
CheckNewAnswer(
// state incremented by 1
("keyInStateAndData-1", Seq[String]("keyInStateAndData-1"), "2")
),
assertNumStateRows(total = 5, updated = 1),
StopStream
)
}
}
}
testWithAllStateVersions("flatMapGroupsWithState - initial state - case class key") {
val stateFunc = (key: User, values: Iterator[User], state: GroupState[Long]) => {
val valList = values.toSeq
if (valList.isEmpty) {
// When the function is called on just the initial state, make sure the other fields
// are set correctly
assert(state.exists)
}
assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 }
assertCannotGetWatermark { state.getCurrentWatermarkMs() }
assert(!state.hasTimedOut)
val count = state.getOption.getOrElse(0L) + valList.size
// We need to check whether the state is still saved when update is not explicitly called
if (!key.name.contains("NoUpdate")) {
// this is not reached when valList is empty and the state count is 2
state.update(count)
}
Iterator((key, valList.map(_.name), count.toString))
}
val ds = Seq(
(User("keyInStateAndData", "1"), (1L)),
(User("keyOnlyInState", "1"), (1L)),
(User("keyNoUpdate", "2"), (2L)) // state.update will not be called on this in the function
).toDS().groupByKey(_._1).mapValues(_._2)
val inputData = MemoryStream[User]
val result =
inputData.toDS()
.groupByKey(x => x)
.flatMapGroupsWithState(Update, NoTimeout(), ds)(stateFunc)
testStream(result, Update)(
AddData(inputData, User("keyInStateAndData", "1"), User("keyOnlyInData", "1")),
CheckNewAnswer(
(("keyInStateAndData", "1"), Seq[String]("keyInStateAndData"), "2"),
(("keyOnlyInState", "1"), Seq[String](), "1"),
(("keyNoUpdate", "2"), Seq[String](), "2"),
(("keyOnlyInData", "1"), Seq[String]("keyOnlyInData"), "1")
),
assertNumStateRows(total = 4, updated = 3), // (keyNoUpdate, 2) does not call update()
// Stop and Start stream to make sure initial state doesn't get applied again.
StopStream,
StartStream(),
AddData(inputData, User("keyOnlyInData", "1")),
CheckNewAnswer(
(("keyOnlyInData", "1"), Seq[String]("keyOnlyInData"), "2")
),
assertNumStateRows(total = 4, updated = 1),
StopStream
)
}
testQuietly("flatMapGroupsWithState - initial state - duplicate keys") {
val initialState = Seq(
("a", new RunningCount(2)),
("a", new RunningCount(1))
).toDS().groupByKey(_._1).mapValues(_._2)
val inputData = MemoryStream[String]
val result =
inputData.toDS()
.groupByKey(x => x)
.flatMapGroupsWithState(Update, NoTimeout(), initialState)(flatMapGroupsWithStateFunc)
testStream(result, Update)(
AddData(inputData, "a"),
ExpectFailure[SparkException] { e =>
assert(e.getCause.getMessage.contains("The initial state provided contained " +
"multiple rows(state) with the same key"))
}
)
}
testQuietly("flatMapGroupsWithState - initial state - streaming initial state") {
val initialStateData = MemoryStream[(String, RunningCount)]
initialStateData.addData(("a", new RunningCount(1)))
val inputData = MemoryStream[String]
val result =
inputData.toDS()
.groupByKey(x => x)
.flatMapGroupsWithState(
Update, NoTimeout(), initialStateData.toDS().groupByKey(_._1).mapValues(_._2)
)(flatMapGroupsWithStateFunc)
val e = intercept[AnalysisException] {
result.writeStream
.format("console")
.start()
}
val expectedError = "Non-streaming DataFrame/Dataset is not supported" +
" as the initial state in [flatMap|map]GroupsWithState" +
" operation on a streaming DataFrame/Dataset"
assert(e.message.contains(expectedError))
}
test("flatMapGroupsWithState - initial state - initial state has flatMapGroupsWithState") {
val initialStateDS = Seq(("keyInStateAndData", new RunningCount(1))).toDS()
val initialState: KeyValueGroupedDataset[String, RunningCount] =
initialStateDS.groupByKey(_._1).mapValues(_._2)
.mapGroupsWithState(
GroupStateTimeout.NoTimeout())(
(key: String, values: Iterator[RunningCount], state: GroupState[Boolean]) => {
(key, values.next())
}
).groupByKey(_._1).mapValues(_._2)
val inputData = MemoryStream[String]
val result =
inputData.toDS()
.groupByKey(x => x)
.flatMapGroupsWithState(
Update, NoTimeout(), initialState
)(flatMapGroupsWithStateFunc)
testStream(result, Update)(
AddData(inputData, "keyInStateAndData"),
CheckNewAnswer(
("keyInStateAndData", Seq[String]("keyInStateAndData"), "2")
),
StopStream
)
}
testWithAllStateVersions("mapGroupsWithState - initial state - null key") {
val mapGroupsWithStateFunc =
(key: String, values: Iterator[String], state: GroupState[RunningCount]) => {
val valList = values.toList
val count = state.getOption.map(_.count).getOrElse(0L) + valList.size
state.update(new RunningCount(count))
(key, state.get.count.toString)
}
val initialState = Seq(
("key", new RunningCount(5)),
(null, new RunningCount(2))
).toDS().groupByKey(_._1).mapValues(_._2)
val inputData = MemoryStream[String]
val result =
inputData.toDS()
.groupByKey(x => x)
.mapGroupsWithState(NoTimeout(), initialState)(mapGroupsWithStateFunc)
testStream(result, Update)(
AddData(inputData, "key", null),
CheckNewAnswer(
("key", "6"), // state is incremented by 1
(null, "3") // incremented by 1
),
assertNumStateRows(total = 2, updated = 2),
StopStream
)
}
testWithAllStateVersions("flatMapGroupsWithState - initial state - processing time timeout") {
// The function returns -1 on timeout and the state count otherwise
val stateFunc =
(key: String, values: Iterator[(String, Long)], state: GroupState[RunningCount]) => {
if (state.hasTimedOut) {
state.remove()
Iterator((key, "-1"))
} else {
val count = state.getOption.map(_.count).getOrElse(0L) + values.size
state.update(RunningCount(count))
state.setTimeoutDuration("10 seconds")
Iterator((key, count.toString))
}
}
val clock = new StreamManualClock
val inputData = MemoryStream[(String, Long)]
val initialState = Seq(
("c", new RunningCount(2))
).toDS().groupByKey(_._1).mapValues(_._2)
val result =
inputData.toDF().toDF("key", "time")
.selectExpr("key", "timestamp_seconds(time) as timestamp")
.withWatermark("timestamp", "10 second")
.as[(String, Long)]
.groupByKey(x => x._1)
.flatMapGroupsWithState(Update, ProcessingTimeTimeout(), initialState)(stateFunc)
testStream(result, Update)(
StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock),
AddData(inputData, ("a", 1L)),
AdvanceManualClock(1 * 1000), // a and c are processed here for the first time.
CheckNewAnswer(("a", "1"), ("c", "2")),
AdvanceManualClock(10 * 1000),
AddData(inputData, ("b", 1L)), // this will trigger c and a to get timed out
AdvanceManualClock(1 * 1000),
CheckNewAnswer(("a", "-1"), ("b", "1"), ("c", "-1"))
)
}
def testWithTimeout(timeoutConf: GroupStateTimeout): Unit = {
test("SPARK-20714: watermark does not fail query when timeout = " + timeoutConf) {
// Function to maintain running count up to 2, and then remove the count
// Returns the data and the count (-1 if count reached beyond 2 and state was just removed)
val stateFunc =
(key: String, values: Iterator[(String, Long)], state: GroupState[RunningCount]) => {
(key: String, values: Iterator[(String, Long)], state: GroupState[RunningCount]) => {
if (state.hasTimedOut) {
state.remove()
Iterator((key, "-1"))
@ -1411,10 +1679,13 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
.groupByKey(x => x)
.flatMapGroupsWithState[Int, Int](Append, timeoutConf = timeoutType)(func)
.logicalPlan.collectFirst {
case FlatMapGroupsWithState(f, k, v, g, d, o, s, m, _, t, _) =>
case FlatMapGroupsWithState(f, k, v, g, d, o, s, m, _, t,
hasInitialState, sga, sda, se, i, c) =>
FlatMapGroupsWithStateExec(
f, k, v, g, d, o, None, s, stateFormatVersion, m, t,
f, k, v, se, g, sga, d, sda, o, None, s, stateFormatVersion, m, t,
Some(currentBatchTimestamp), Some(currentBatchWatermark),
RDDScanExec(g, emptyRdd, "rdd"),
hasInitialState,
RDDScanExec(g, emptyRdd, "rdd"))
}.get
}
@ -1461,6 +1732,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
}
}
case class User(name: String, id: String)
object FlatMapGroupsWithStateSuite {
var failInTask = true