[SPARK-35897][SS] Support user defined initial state with flatMapGroupsWithState in Structured Streaming
### What changes were proposed in this pull request? This PR aims to add support for specifying a user defined initial state for arbitrary structured streaming stateful processing using [flat]MapGroupsWithState operator. ### Why are the changes needed? Users can load previous state of their stateful processing as an initial state instead of redoing the entire processing once again. ### Does this PR introduce _any_ user-facing change? Yes, this PR introduces new APIs: ``` def mapGroupsWithState[S: Encoder, U: Encoder]( timeoutConf: GroupStateTimeout, initialState: KeyValueGroupedDataset[K, S])( func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] def flatMapGroupsWithState[S: Encoder, U: Encoder]( outputMode: OutputMode, timeoutConf: GroupStateTimeout, initialState: KeyValueGroupedDataset[K, S])( func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U] ``` ### How was this patch tested? Through unit tests in FlatMapGroupsWithStateSuite Closes #33093 from rahulsmahadev/flatMapGroupsWithState. Authored-by: Rahul Mahadev <rahul.mahadev@databricks.com> Signed-off-by: Gengliang Wang <gengliang@apache.org>
This commit is contained in:
parent
95d94948c5
commit
47485a3c2d
|
@ -37,6 +37,12 @@ object UnsupportedOperationChecker extends Logging {
|
|||
case p if p.isStreaming =>
|
||||
throwError("Queries with streaming sources must be executed with writeStream.start()")(p)
|
||||
|
||||
case f: FlatMapGroupsWithState =>
|
||||
if (f.hasInitialState) {
|
||||
throwError("Initial state is not supported in [flatMap|map]GroupsWithState" +
|
||||
" operation on a batch DataFrame/Dataset")(f)
|
||||
}
|
||||
|
||||
case _ =>
|
||||
}
|
||||
}
|
||||
|
@ -232,6 +238,12 @@ object UnsupportedOperationChecker extends Logging {
|
|||
// Check compatibility with output modes and aggregations in query
|
||||
val aggsInQuery = collectStreamingAggregates(plan)
|
||||
|
||||
if (m.initialState.isStreaming) {
|
||||
// initial state has to be a batch relation
|
||||
throwError("Non-streaming DataFrame/Dataset is not supported as the" +
|
||||
" initial state in [flatMap|map]GroupsWithState operation on a streaming" +
|
||||
" DataFrame/Dataset")
|
||||
}
|
||||
if (m.isMapGroupsWithState) { // check mapGroupsWithState
|
||||
// allowed only in update query output mode and without aggregation
|
||||
if (aggsInQuery.nonEmpty) {
|
||||
|
|
|
@ -440,7 +440,7 @@ object FlatMapGroupsWithState {
|
|||
isMapGroupsWithState: Boolean,
|
||||
timeout: GroupStateTimeout,
|
||||
child: LogicalPlan): LogicalPlan = {
|
||||
val encoder = encoderFor[S]
|
||||
val stateEncoder = encoderFor[S]
|
||||
|
||||
val mapped = new FlatMapGroupsWithState(
|
||||
func,
|
||||
|
@ -449,10 +449,49 @@ object FlatMapGroupsWithState {
|
|||
groupingAttributes,
|
||||
dataAttributes,
|
||||
CatalystSerde.generateObjAttr[U],
|
||||
encoder.asInstanceOf[ExpressionEncoder[Any]],
|
||||
stateEncoder.asInstanceOf[ExpressionEncoder[Any]],
|
||||
outputMode,
|
||||
isMapGroupsWithState,
|
||||
timeout,
|
||||
hasInitialState = false,
|
||||
groupingAttributes,
|
||||
dataAttributes,
|
||||
UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes),
|
||||
LocalRelation(stateEncoder.schema.toAttributes), // empty data set
|
||||
child
|
||||
)
|
||||
CatalystSerde.serialize[U](mapped)
|
||||
}
|
||||
|
||||
def apply[K: Encoder, V: Encoder, S: Encoder, U: Encoder](
|
||||
func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any],
|
||||
groupingAttributes: Seq[Attribute],
|
||||
dataAttributes: Seq[Attribute],
|
||||
outputMode: OutputMode,
|
||||
isMapGroupsWithState: Boolean,
|
||||
timeout: GroupStateTimeout,
|
||||
child: LogicalPlan,
|
||||
initialStateGroupAttrs: Seq[Attribute],
|
||||
initialStateDataAttrs: Seq[Attribute],
|
||||
initialState: LogicalPlan): LogicalPlan = {
|
||||
val stateEncoder = encoderFor[S]
|
||||
|
||||
val mapped = new FlatMapGroupsWithState(
|
||||
func,
|
||||
UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes),
|
||||
UnresolvedDeserializer(encoderFor[V].deserializer, dataAttributes),
|
||||
groupingAttributes,
|
||||
dataAttributes,
|
||||
CatalystSerde.generateObjAttr[U],
|
||||
stateEncoder.asInstanceOf[ExpressionEncoder[Any]],
|
||||
outputMode,
|
||||
isMapGroupsWithState,
|
||||
timeout,
|
||||
hasInitialState = true,
|
||||
initialStateGroupAttrs,
|
||||
initialStateDataAttrs,
|
||||
UnresolvedDeserializer(encoderFor[S].deserializer, initialStateDataAttrs),
|
||||
initialState,
|
||||
child)
|
||||
CatalystSerde.serialize[U](mapped)
|
||||
}
|
||||
|
@ -474,6 +513,12 @@ object FlatMapGroupsWithState {
|
|||
* @param outputMode the output mode of `func`
|
||||
* @param isMapGroupsWithState whether it is created by the `mapGroupsWithState` method
|
||||
* @param timeout used to timeout groups that have not received data in a while
|
||||
* @param hasInitialState Indicates whether initial state needs to be applied or not.
|
||||
* @param initialStateGroupAttrs grouping attributes for the initial state
|
||||
* @param initialStateDataAttrs used to read the initial state
|
||||
* @param initialStateDeserializer used to extract the initial state objects.
|
||||
* @param initialState user defined initial state that is applied in the first batch.
|
||||
* @param child logical plan of the underlying data
|
||||
*/
|
||||
case class FlatMapGroupsWithState(
|
||||
func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any],
|
||||
|
@ -486,14 +531,24 @@ case class FlatMapGroupsWithState(
|
|||
outputMode: OutputMode,
|
||||
isMapGroupsWithState: Boolean = false,
|
||||
timeout: GroupStateTimeout,
|
||||
child: LogicalPlan) extends UnaryNode with ObjectProducer {
|
||||
hasInitialState: Boolean = false,
|
||||
initialStateGroupAttrs: Seq[Attribute],
|
||||
initialStateDataAttrs: Seq[Attribute],
|
||||
initialStateDeserializer: Expression,
|
||||
initialState: LogicalPlan,
|
||||
child: LogicalPlan) extends BinaryNode with ObjectProducer {
|
||||
|
||||
if (isMapGroupsWithState) {
|
||||
assert(outputMode == OutputMode.Update)
|
||||
}
|
||||
|
||||
override protected def withNewChildInternal(newChild: LogicalPlan): FlatMapGroupsWithState =
|
||||
copy(child = newChild)
|
||||
override def left: LogicalPlan = child
|
||||
|
||||
override def right: LogicalPlan = initialState
|
||||
|
||||
override protected def withNewChildrenInternal(
|
||||
newLeft: LogicalPlan, newRight: LogicalPlan): FlatMapGroupsWithState =
|
||||
copy(child = newLeft, initialState = newRight)
|
||||
}
|
||||
|
||||
/** Factory for constructing new `FlatMapGroupsInR` nodes. */
|
||||
|
|
|
@ -24,13 +24,13 @@ import org.apache.spark.sql.AnalysisException
|
|||
import org.apache.spark.sql.catalyst.dsl.expressions._
|
||||
import org.apache.spark.sql.catalyst.dsl.plans._
|
||||
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
|
||||
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, MonotonicallyIncreasingID, NamedExpression}
|
||||
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, MonotonicallyIncreasingID, NamedExpression}
|
||||
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
|
||||
import org.apache.spark.sql.catalyst.plans._
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{FlatMapGroupsWithState, _}
|
||||
import org.apache.spark.sql.catalyst.plans.logical._
|
||||
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.streaming.OutputMode
|
||||
import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
|
||||
import org.apache.spark.sql.types.{IntegerType, LongType, MetadataBuilder}
|
||||
|
||||
/** A dummy command for testing unsupported operations. */
|
||||
|
@ -145,15 +145,15 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
|
|||
for (funcMode <- Seq(Append, Update)) {
|
||||
assertSupportedInBatchPlan(
|
||||
s"flatMapGroupsWithState - flatMapGroupsWithState($funcMode) on batch relation",
|
||||
FlatMapGroupsWithState(
|
||||
TestFlatMapGroupsWithState(
|
||||
null, att, att, Seq(att), Seq(att), att, null, funcMode, isMapGroupsWithState = false, null,
|
||||
batchRelation))
|
||||
|
||||
assertSupportedInBatchPlan(
|
||||
s"flatMapGroupsWithState - multiple flatMapGroupsWithState($funcMode)s on batch relation",
|
||||
FlatMapGroupsWithState(
|
||||
TestFlatMapGroupsWithState(
|
||||
null, att, att, Seq(att), Seq(att), att, null, funcMode, isMapGroupsWithState = false, null,
|
||||
FlatMapGroupsWithState(
|
||||
TestFlatMapGroupsWithState(
|
||||
null, att, att, Seq(att), Seq(att), att, null, funcMode, isMapGroupsWithState = false,
|
||||
null, batchRelation)))
|
||||
}
|
||||
|
@ -162,7 +162,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
|
|||
assertSupportedInStreamingPlan(
|
||||
"flatMapGroupsWithState - flatMapGroupsWithState(Update) " +
|
||||
"on streaming relation without aggregation in update mode",
|
||||
FlatMapGroupsWithState(
|
||||
TestFlatMapGroupsWithState(
|
||||
null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null,
|
||||
streamRelation),
|
||||
outputMode = Update)
|
||||
|
@ -170,7 +170,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
|
|||
assertNotSupportedInStreamingPlan(
|
||||
"flatMapGroupsWithState - flatMapGroupsWithState(Update) " +
|
||||
"on streaming relation without aggregation in append mode",
|
||||
FlatMapGroupsWithState(
|
||||
TestFlatMapGroupsWithState(
|
||||
null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null,
|
||||
streamRelation),
|
||||
outputMode = Append,
|
||||
|
@ -179,7 +179,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
|
|||
assertNotSupportedInStreamingPlan(
|
||||
"flatMapGroupsWithState - flatMapGroupsWithState(Update) " +
|
||||
"on streaming relation without aggregation in complete mode",
|
||||
FlatMapGroupsWithState(
|
||||
TestFlatMapGroupsWithState(
|
||||
null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null,
|
||||
streamRelation),
|
||||
outputMode = Complete,
|
||||
|
@ -192,7 +192,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
|
|||
assertNotSupportedInStreamingPlan(
|
||||
"flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation " +
|
||||
s"with aggregation in $outputMode mode",
|
||||
FlatMapGroupsWithState(
|
||||
TestFlatMapGroupsWithState(
|
||||
null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null,
|
||||
Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation)),
|
||||
outputMode = outputMode,
|
||||
|
@ -203,7 +203,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
|
|||
assertSupportedInStreamingPlan(
|
||||
"flatMapGroupsWithState - flatMapGroupsWithState(Append) " +
|
||||
"on streaming relation without aggregation in append mode",
|
||||
FlatMapGroupsWithState(
|
||||
TestFlatMapGroupsWithState(
|
||||
null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null,
|
||||
streamRelation),
|
||||
outputMode = Append)
|
||||
|
@ -211,7 +211,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
|
|||
assertNotSupportedInStreamingPlan(
|
||||
"flatMapGroupsWithState - flatMapGroupsWithState(Append) " +
|
||||
"on streaming relation without aggregation in update mode",
|
||||
FlatMapGroupsWithState(
|
||||
TestFlatMapGroupsWithState(
|
||||
null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null,
|
||||
streamRelation),
|
||||
outputMode = Update,
|
||||
|
@ -226,7 +226,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
|
|||
Aggregate(
|
||||
Seq(attributeWithWatermark),
|
||||
aggExprs("c"),
|
||||
FlatMapGroupsWithState(
|
||||
TestFlatMapGroupsWithState(
|
||||
null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null,
|
||||
streamRelation)),
|
||||
outputMode = outputMode,
|
||||
|
@ -237,7 +237,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
|
|||
assertNotSupportedInStreamingPlan(
|
||||
"flatMapGroupsWithState - flatMapGroupsWithState(Append) " +
|
||||
s"on streaming relation after aggregation in $outputMode mode",
|
||||
FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append,
|
||||
TestFlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append,
|
||||
isMapGroupsWithState = false, null,
|
||||
Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation)),
|
||||
outputMode = outputMode,
|
||||
|
@ -247,7 +247,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
|
|||
assertNotSupportedInStreamingPlan(
|
||||
"flatMapGroupsWithState - " +
|
||||
"flatMapGroupsWithState(Update) on streaming relation in complete mode",
|
||||
FlatMapGroupsWithState(
|
||||
TestFlatMapGroupsWithState(
|
||||
null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null,
|
||||
streamRelation),
|
||||
outputMode = Complete,
|
||||
|
@ -261,7 +261,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
|
|||
assertSupportedInStreamingPlan(
|
||||
s"flatMapGroupsWithState - flatMapGroupsWithState($funcMode) on batch relation inside " +
|
||||
s"streaming relation in $outputMode output mode",
|
||||
FlatMapGroupsWithState(
|
||||
TestFlatMapGroupsWithState(
|
||||
null, att, att, Seq(att), Seq(att), att, null, funcMode, isMapGroupsWithState = false,
|
||||
null, batchRelation),
|
||||
outputMode = outputMode
|
||||
|
@ -274,9 +274,9 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
|
|||
assertSupportedInStreamingPlan(
|
||||
"flatMapGroupsWithState - multiple flatMapGroupsWithStates on streaming relation and all are " +
|
||||
"in append mode",
|
||||
FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append,
|
||||
TestFlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append,
|
||||
isMapGroupsWithState = false, null,
|
||||
FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append,
|
||||
TestFlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append,
|
||||
isMapGroupsWithState = false, null, streamRelation)),
|
||||
outputMode = Append,
|
||||
SQLConf.STATEFUL_OPERATOR_CHECK_CORRECTNESS_ENABLED.key -> "false")
|
||||
|
@ -284,9 +284,9 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
|
|||
assertNotSupportedInStreamingPlan(
|
||||
"flatMapGroupsWithState - multiple flatMapGroupsWithStates on s streaming relation but some" +
|
||||
" are not in append mode",
|
||||
FlatMapGroupsWithState(
|
||||
TestFlatMapGroupsWithState(
|
||||
null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null,
|
||||
FlatMapGroupsWithState(
|
||||
TestFlatMapGroupsWithState(
|
||||
null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null,
|
||||
streamRelation)),
|
||||
outputMode = Append,
|
||||
|
@ -296,7 +296,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
|
|||
assertNotSupportedInStreamingPlan(
|
||||
"mapGroupsWithState - mapGroupsWithState " +
|
||||
"on streaming relation without aggregation in append mode",
|
||||
FlatMapGroupsWithState(
|
||||
TestFlatMapGroupsWithState(
|
||||
null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, null,
|
||||
streamRelation),
|
||||
outputMode = Append,
|
||||
|
@ -307,7 +307,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
|
|||
assertNotSupportedInStreamingPlan(
|
||||
"mapGroupsWithState - mapGroupsWithState " +
|
||||
"on streaming relation without aggregation in complete mode",
|
||||
FlatMapGroupsWithState(
|
||||
TestFlatMapGroupsWithState(
|
||||
null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, null,
|
||||
streamRelation),
|
||||
outputMode = Complete,
|
||||
|
@ -319,7 +319,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
|
|||
assertNotSupportedInStreamingPlan(
|
||||
"mapGroupsWithState - mapGroupsWithState on streaming relation " +
|
||||
s"with aggregation in $outputMode mode",
|
||||
FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Update,
|
||||
TestFlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Update,
|
||||
isMapGroupsWithState = true, null,
|
||||
Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation)),
|
||||
outputMode = outputMode,
|
||||
|
@ -330,9 +330,9 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
|
|||
assertNotSupportedInStreamingPlan(
|
||||
"mapGroupsWithState - multiple mapGroupsWithStates on streaming relation and all are " +
|
||||
"in append mode",
|
||||
FlatMapGroupsWithState(
|
||||
TestFlatMapGroupsWithState(
|
||||
null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, null,
|
||||
FlatMapGroupsWithState(
|
||||
TestFlatMapGroupsWithState(
|
||||
null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, null,
|
||||
streamRelation)),
|
||||
outputMode = Append,
|
||||
|
@ -342,9 +342,9 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
|
|||
assertNotSupportedInStreamingPlan(
|
||||
"mapGroupsWithState - " +
|
||||
"mixing mapGroupsWithStates and flatMapGroupsWithStates on streaming relation",
|
||||
FlatMapGroupsWithState(
|
||||
TestFlatMapGroupsWithState(
|
||||
null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, null,
|
||||
FlatMapGroupsWithState(
|
||||
TestFlatMapGroupsWithState(
|
||||
null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null,
|
||||
streamRelation)
|
||||
),
|
||||
|
@ -354,7 +354,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
|
|||
// mapGroupsWithState with event time timeout + watermark
|
||||
assertNotSupportedInStreamingPlan(
|
||||
"mapGroupsWithState - mapGroupsWithState with event time timeout without watermark",
|
||||
FlatMapGroupsWithState(
|
||||
TestFlatMapGroupsWithState(
|
||||
null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true,
|
||||
EventTimeTimeout, streamRelation),
|
||||
outputMode = Update,
|
||||
|
@ -362,7 +362,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
|
|||
|
||||
assertSupportedInStreamingPlan(
|
||||
"mapGroupsWithState - mapGroupsWithState with event time timeout with watermark",
|
||||
FlatMapGroupsWithState(
|
||||
TestFlatMapGroupsWithState(
|
||||
null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true,
|
||||
EventTimeTimeout, new TestStreamingRelation(attributeWithWatermark)),
|
||||
outputMode = Update)
|
||||
|
@ -532,7 +532,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
|
|||
|
||||
testGlobalWatermarkLimit(
|
||||
s"FlatMapGroupsWithState after stream-stream $joinType join in Append mode",
|
||||
FlatMapGroupsWithState(
|
||||
TestFlatMapGroupsWithState(
|
||||
null, att, att, Seq(att), Seq(att), att, null, Append,
|
||||
isMapGroupsWithState = false, null,
|
||||
streamRelation.join(streamRelation, joinType = joinType,
|
||||
|
@ -664,7 +664,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
|
|||
|
||||
assertFailOnGlobalWatermarkLimit(
|
||||
"FlatMapGroupsWithState after streaming aggregation in Append mode",
|
||||
FlatMapGroupsWithState(
|
||||
TestFlatMapGroupsWithState(
|
||||
null, att, att, Seq(att), Seq(att), att, null, Append,
|
||||
isMapGroupsWithState = false, null,
|
||||
streamRelation.groupBy("a")(count("*"))),
|
||||
|
@ -675,14 +675,14 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
|
|||
{
|
||||
assertPassOnGlobalWatermarkLimit(
|
||||
"single FlatMapGroupsWithState in Append mode",
|
||||
FlatMapGroupsWithState(
|
||||
TestFlatMapGroupsWithState(
|
||||
null, att, att, Seq(att), Seq(att), att, null, Append,
|
||||
isMapGroupsWithState = false, null, streamRelation),
|
||||
OutputMode.Append())
|
||||
|
||||
assertFailOnGlobalWatermarkLimit(
|
||||
"streaming aggregation after FlatMapGroupsWithState in Append mode",
|
||||
FlatMapGroupsWithState(
|
||||
TestFlatMapGroupsWithState(
|
||||
null, att, att, Seq(att), Seq(att), att, null, Append,
|
||||
isMapGroupsWithState = false, null, streamRelation).groupBy("*")(count("*")),
|
||||
OutputMode.Append())
|
||||
|
@ -691,7 +691,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
|
|||
assertFailOnGlobalWatermarkLimit(
|
||||
s"stream-stream $joinType after FlatMapGroupsWithState in Append mode",
|
||||
streamRelation.join(
|
||||
FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append,
|
||||
TestFlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append,
|
||||
isMapGroupsWithState = false, null, streamRelation), joinType = joinType,
|
||||
condition = Some(attributeWithWatermark === attribute)),
|
||||
OutputMode.Append())
|
||||
|
@ -699,16 +699,16 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
|
|||
|
||||
assertFailOnGlobalWatermarkLimit(
|
||||
"FlatMapGroupsWithState after FlatMapGroupsWithState in Append mode",
|
||||
FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append,
|
||||
TestFlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append,
|
||||
isMapGroupsWithState = false, null,
|
||||
FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append,
|
||||
TestFlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append,
|
||||
isMapGroupsWithState = false, null, streamRelation)),
|
||||
OutputMode.Append())
|
||||
|
||||
assertFailOnGlobalWatermarkLimit(
|
||||
s"deduplicate after FlatMapGroupsWithState in Append mode",
|
||||
Deduplicate(Seq(attribute),
|
||||
FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append,
|
||||
TestFlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append,
|
||||
isMapGroupsWithState = false, null, streamRelation)),
|
||||
OutputMode.Append())
|
||||
}
|
||||
|
@ -730,7 +730,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
|
|||
|
||||
assertPassOnGlobalWatermarkLimit(
|
||||
"FlatMapGroupsWithState after deduplicate in Append mode",
|
||||
FlatMapGroupsWithState(
|
||||
TestFlatMapGroupsWithState(
|
||||
null, att, att, Seq(att), Seq(att), att, null, Append,
|
||||
isMapGroupsWithState = false, null,
|
||||
Deduplicate(Seq(attribute), streamRelation)),
|
||||
|
@ -1015,3 +1015,43 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
|
|||
override def nodeName: String = "StreamingRelationV2"
|
||||
}
|
||||
}
|
||||
|
||||
object TestFlatMapGroupsWithState {
|
||||
|
||||
// scalastyle:off
|
||||
// Creating an apply constructor here as we changed the class by adding more fields
|
||||
def apply(func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any],
|
||||
keyDeserializer: Expression,
|
||||
valueDeserializer: Expression,
|
||||
groupingAttributes: Seq[Attribute],
|
||||
dataAttributes: Seq[Attribute],
|
||||
outputObjAttr: Attribute,
|
||||
stateEncoder: ExpressionEncoder[Any],
|
||||
outputMode: OutputMode,
|
||||
isMapGroupsWithState: Boolean = false,
|
||||
timeout: GroupStateTimeout,
|
||||
child: LogicalPlan): FlatMapGroupsWithState = {
|
||||
|
||||
val attribute = AttributeReference("a", IntegerType, nullable = true)()
|
||||
val batchRelation = LocalRelation(attribute)
|
||||
new FlatMapGroupsWithState(
|
||||
func,
|
||||
keyDeserializer,
|
||||
valueDeserializer,
|
||||
groupingAttributes,
|
||||
dataAttributes,
|
||||
outputObjAttr,
|
||||
stateEncoder,
|
||||
outputMode,
|
||||
isMapGroupsWithState,
|
||||
timeout,
|
||||
false,
|
||||
groupingAttributes,
|
||||
dataAttributes,
|
||||
valueDeserializer,
|
||||
batchRelation,
|
||||
child
|
||||
)
|
||||
}
|
||||
// scalastyle:on
|
||||
}
|
||||
|
|
|
@ -280,6 +280,51 @@ class KeyValueGroupedDataset[K, V] private[sql](
|
|||
child = logicalPlan))
|
||||
}
|
||||
|
||||
/**
|
||||
* (Scala-specific)
|
||||
* Applies the given function to each group of data, while maintaining a user-defined per-group
|
||||
* state. The result Dataset will represent the objects returned by the function.
|
||||
* For a static batch Dataset, the function will be invoked once per group. For a streaming
|
||||
* Dataset, the function will be invoked for each group repeatedly in every trigger, and
|
||||
* updates to each group's state will be saved across invocations.
|
||||
* See [[org.apache.spark.sql.streaming.GroupState]] for more details.
|
||||
*
|
||||
* @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
|
||||
* @tparam U The type of the output objects. Must be encodable to Spark SQL types.
|
||||
* @param func Function to be called on every group.
|
||||
* @param timeoutConf Timeout Conf, see GroupStateTimeout for more details
|
||||
* @param initialState The user provided state that will be initialized when the first batch
|
||||
* of data is processed in the streaming query. The user defined function
|
||||
* will be called on the state data even if there are no other values in
|
||||
* the group. To convert a Dataset ds of type Dataset[(K, S)] to a
|
||||
* KeyValueGroupedDataset[K, S]
|
||||
* do {{{ ds.groupByKey(x => x._1).mapValues(_._2) }}}
|
||||
*
|
||||
* See [[Encoder]] for more details on what types are encodable to Spark SQL.
|
||||
* @since 3.2.0
|
||||
*/
|
||||
def mapGroupsWithState[S: Encoder, U: Encoder](
|
||||
timeoutConf: GroupStateTimeout,
|
||||
initialState: KeyValueGroupedDataset[K, S])(
|
||||
func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = {
|
||||
val flatMapFunc = (key: K, it: Iterator[V], s: GroupState[S]) => Iterator(func(key, it, s))
|
||||
|
||||
Dataset[U](
|
||||
sparkSession,
|
||||
FlatMapGroupsWithState[K, V, S, U](
|
||||
flatMapFunc.asInstanceOf[(Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any]],
|
||||
groupingAttributes,
|
||||
dataAttributes,
|
||||
OutputMode.Update,
|
||||
isMapGroupsWithState = true,
|
||||
timeoutConf,
|
||||
child = logicalPlan,
|
||||
initialState.groupingAttributes,
|
||||
initialState.dataAttributes,
|
||||
initialState.queryExecution.analyzed
|
||||
))
|
||||
}
|
||||
|
||||
/**
|
||||
* (Java-specific)
|
||||
* Applies the given function to each group of data, while maintaining a user-defined per-group
|
||||
|
@ -336,6 +381,40 @@ class KeyValueGroupedDataset[K, V] private[sql](
|
|||
)(stateEncoder, outputEncoder)
|
||||
}
|
||||
|
||||
/**
|
||||
* (Java-specific)
|
||||
* Applies the given function to each group of data, while maintaining a user-defined per-group
|
||||
* state. The result Dataset will represent the objects returned by the function.
|
||||
* For a static batch Dataset, the function will be invoked once per group. For a streaming
|
||||
* Dataset, the function will be invoked for each group repeatedly in every trigger, and
|
||||
* updates to each group's state will be saved across invocations.
|
||||
* See `GroupState` for more details.
|
||||
*
|
||||
* @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
|
||||
* @tparam U The type of the output objects. Must be encodable to Spark SQL types.
|
||||
* @param func Function to be called on every group.
|
||||
* @param stateEncoder Encoder for the state type.
|
||||
* @param outputEncoder Encoder for the output type.
|
||||
* @param timeoutConf Timeout configuration for groups that do not receive data for a while.
|
||||
* @param initialState The user provided state that will be initialized when the first batch
|
||||
* of data is processed in the streaming query. The user defined function
|
||||
* will be called on the state data even if there are no other values in
|
||||
* the group.
|
||||
*
|
||||
* See [[Encoder]] for more details on what types are encodable to Spark SQL.
|
||||
* @since 3.2.0
|
||||
*/
|
||||
def mapGroupsWithState[S, U](
|
||||
func: MapGroupsWithStateFunction[K, V, S, U],
|
||||
stateEncoder: Encoder[S],
|
||||
outputEncoder: Encoder[U],
|
||||
timeoutConf: GroupStateTimeout,
|
||||
initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = {
|
||||
mapGroupsWithState[S, U](timeoutConf, initialState)(
|
||||
(key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s)
|
||||
)(stateEncoder, outputEncoder)
|
||||
}
|
||||
|
||||
/**
|
||||
* (Scala-specific)
|
||||
* Applies the given function to each group of data, while maintaining a user-defined per-group
|
||||
|
@ -373,6 +452,53 @@ class KeyValueGroupedDataset[K, V] private[sql](
|
|||
child = logicalPlan))
|
||||
}
|
||||
|
||||
/**
|
||||
* (Scala-specific)
|
||||
* Applies the given function to each group of data, while maintaining a user-defined per-group
|
||||
* state. The result Dataset will represent the objects returned by the function.
|
||||
* For a static batch Dataset, the function will be invoked once per group. For a streaming
|
||||
* Dataset, the function will be invoked for each group repeatedly in every trigger, and
|
||||
* updates to each group's state will be saved across invocations.
|
||||
* See `GroupState` for more details.
|
||||
*
|
||||
* @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
|
||||
* @tparam U The type of the output objects. Must be encodable to Spark SQL types.
|
||||
* @param func Function to be called on every group.
|
||||
* @param outputMode The output mode of the function.
|
||||
* @param timeoutConf Timeout configuration for groups that do not receive data for a while.
|
||||
* @param initialState The user provided state that will be initialized when the first batch
|
||||
* of data is processed in the streaming query. The user defined function
|
||||
* will be called on the state data even if there are no other values in
|
||||
* the group. To convert a Dataset `ds` of type `Dataset[(K, S)]`
|
||||
* to a `KeyValueGroupedDataset[K, S]`, use
|
||||
* {{{ ds.groupByKey(x => x._1).mapValues(_._2) }}}
|
||||
* See [[Encoder]] for more details on what types are encodable to Spark SQL.
|
||||
* @since 3.2.0
|
||||
*/
|
||||
def flatMapGroupsWithState[S: Encoder, U: Encoder](
|
||||
outputMode: OutputMode,
|
||||
timeoutConf: GroupStateTimeout,
|
||||
initialState: KeyValueGroupedDataset[K, S])(
|
||||
func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U] = {
|
||||
if (outputMode != OutputMode.Append && outputMode != OutputMode.Update) {
|
||||
throw new IllegalArgumentException("The output mode of function should be append or update")
|
||||
}
|
||||
Dataset[U](
|
||||
sparkSession,
|
||||
FlatMapGroupsWithState[K, V, S, U](
|
||||
func.asInstanceOf[(Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any]],
|
||||
groupingAttributes,
|
||||
dataAttributes,
|
||||
outputMode,
|
||||
isMapGroupsWithState = false,
|
||||
timeoutConf,
|
||||
child = logicalPlan,
|
||||
initialState.groupingAttributes,
|
||||
initialState.dataAttributes,
|
||||
initialState.queryExecution.analyzed
|
||||
))
|
||||
}
|
||||
|
||||
/**
|
||||
* (Java-specific)
|
||||
* Applies the given function to each group of data, while maintaining a user-defined per-group
|
||||
|
@ -403,6 +529,44 @@ class KeyValueGroupedDataset[K, V] private[sql](
|
|||
flatMapGroupsWithState[S, U](outputMode, timeoutConf)(f)(stateEncoder, outputEncoder)
|
||||
}
|
||||
|
||||
/**
|
||||
* (Java-specific)
|
||||
* Applies the given function to each group of data, while maintaining a user-defined per-group
|
||||
* state. The result Dataset will represent the objects returned by the function.
|
||||
* For a static batch Dataset, the function will be invoked once per group. For a streaming
|
||||
* Dataset, the function will be invoked for each group repeatedly in every trigger, and
|
||||
* updates to each group's state will be saved across invocations.
|
||||
* See `GroupState` for more details.
|
||||
*
|
||||
* @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
|
||||
* @tparam U The type of the output objects. Must be encodable to Spark SQL types.
|
||||
* @param func Function to be called on every group.
|
||||
* @param outputMode The output mode of the function.
|
||||
* @param stateEncoder Encoder for the state type.
|
||||
* @param outputEncoder Encoder for the output type.
|
||||
* @param timeoutConf Timeout configuration for groups that do not receive data for a while.
|
||||
* @param initialState The user provided state that will be initialized when the first batch
|
||||
* of data is processed in the streaming query. The user defined function
|
||||
* will be called on the state data even if there are no other values in
|
||||
* the group. To convert a Dataset `ds` of type `Dataset[(K, S)]`
|
||||
* to a `KeyValueGroupedDataset[K, S]`, use
|
||||
* {{{ ds.groupByKey(x => x._1).mapValues(_._2) }}}
|
||||
*
|
||||
* See [[Encoder]] for more details on what types are encodable to Spark SQL.
|
||||
* @since 3.2.0
|
||||
*/
|
||||
def flatMapGroupsWithState[S, U](
|
||||
func: FlatMapGroupsWithStateFunction[K, V, S, U],
|
||||
outputMode: OutputMode,
|
||||
stateEncoder: Encoder[S],
|
||||
outputEncoder: Encoder[U],
|
||||
timeoutConf: GroupStateTimeout,
|
||||
initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = {
|
||||
val f = (key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s).asScala
|
||||
flatMapGroupsWithState[S, U](
|
||||
outputMode, timeoutConf, initialState)(f)(stateEncoder, outputEncoder)
|
||||
}
|
||||
|
||||
/**
|
||||
* (Scala-specific)
|
||||
* Reduces the elements of each group of data using the specified binary function.
|
||||
|
|
|
@ -561,11 +561,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
|
|||
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
|
||||
case FlatMapGroupsWithState(
|
||||
func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateEnc, outputMode, _,
|
||||
timeout, child) =>
|
||||
timeout, hasInitialState, stateGroupAttr, sda, sDeser, initialState, child) =>
|
||||
val stateVersion = conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION)
|
||||
val execPlan = FlatMapGroupsWithStateExec(
|
||||
func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateEnc, stateVersion,
|
||||
outputMode, timeout, batchTimestampMs = None, eventTimeWatermark = None, planLater(child))
|
||||
func, keyDeser, valueDeser, sDeser, groupAttr, stateGroupAttr, dataAttr, sda, outputAttr,
|
||||
None, stateEnc, stateVersion, outputMode, timeout, batchTimestampMs = None,
|
||||
eventTimeWatermark = None, planLater(initialState), hasInitialState, planLater(child)
|
||||
)
|
||||
execPlan :: Nil
|
||||
case _ =>
|
||||
Nil
|
||||
|
@ -669,7 +671,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
|
|||
case logical.MapGroups(f, key, value, grouping, data, objAttr, child) =>
|
||||
execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil
|
||||
case logical.FlatMapGroupsWithState(
|
||||
f, key, value, grouping, data, output, _, _, _, timeout, child) =>
|
||||
f, key, value, grouping, data, output, _, _, _, timeout, _, _, _, _, _, child) =>
|
||||
execution.MapGroupsExec(
|
||||
f, key, value, grouping, data, output, timeout, planLater(child)) :: Nil
|
||||
case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) =>
|
||||
|
|
|
@ -25,9 +25,11 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expressi
|
|||
import org.apache.spark.sql.catalyst.plans.logical._
|
||||
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution}
|
||||
import org.apache.spark.sql.execution._
|
||||
import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._
|
||||
import org.apache.spark.sql.execution.streaming.state._
|
||||
import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
|
||||
import org.apache.spark.util.CompletionIterator
|
||||
import org.apache.spark.sql.streaming.GroupStateTimeout.NoTimeout
|
||||
import org.apache.spark.util.{CompletionIterator, SerializableConfiguration}
|
||||
|
||||
/**
|
||||
* Physical operator for executing `FlatMapGroupsWithState`
|
||||
|
@ -35,6 +37,7 @@ import org.apache.spark.util.CompletionIterator
|
|||
* @param func function called on each group
|
||||
* @param keyDeserializer used to extract the key object for each group.
|
||||
* @param valueDeserializer used to extract the items in the iterator from an input row.
|
||||
* @param initialStateDeserializer used to extract the state object from the initialState dataset
|
||||
* @param groupingAttributes used to group the data
|
||||
* @param dataAttributes used to read the data
|
||||
* @param outputObjAttr Defines the output object
|
||||
|
@ -42,13 +45,20 @@ import org.apache.spark.util.CompletionIterator
|
|||
* @param outputMode the output mode of `func`
|
||||
* @param timeoutConf used to timeout groups that have not received data in a while
|
||||
* @param batchTimestampMs processing timestamp of the current batch.
|
||||
* @param eventTimeWatermark event time watermark for the current batch
|
||||
* @param initialState the user specified initial state
|
||||
* @param hasInitialState indicates whether the initial state is provided or not
|
||||
* @param child the physical plan for the underlying data
|
||||
*/
|
||||
case class FlatMapGroupsWithStateExec(
|
||||
func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any],
|
||||
keyDeserializer: Expression,
|
||||
valueDeserializer: Expression,
|
||||
initialStateDeserializer: Expression,
|
||||
groupingAttributes: Seq[Attribute],
|
||||
initialStateGroupAttrs: Seq[Attribute],
|
||||
dataAttributes: Seq[Attribute],
|
||||
initialStateDataAttrs: Seq[Attribute],
|
||||
outputObjAttr: Attribute,
|
||||
stateInfo: Option[StatefulOperatorStateInfo],
|
||||
stateEncoder: ExpressionEncoder[Any],
|
||||
|
@ -57,27 +67,45 @@ case class FlatMapGroupsWithStateExec(
|
|||
timeoutConf: GroupStateTimeout,
|
||||
batchTimestampMs: Option[Long],
|
||||
eventTimeWatermark: Option[Long],
|
||||
initialState: SparkPlan,
|
||||
hasInitialState: Boolean,
|
||||
child: SparkPlan
|
||||
) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter with WatermarkSupport {
|
||||
) extends BinaryExecNode with ObjectProducerExec with StateStoreWriter with WatermarkSupport {
|
||||
|
||||
import FlatMapGroupsWithStateExecHelper._
|
||||
import GroupStateImpl._
|
||||
|
||||
override def left: SparkPlan = child
|
||||
|
||||
override def right: SparkPlan = initialState
|
||||
|
||||
private val isTimeoutEnabled = timeoutConf != NoTimeout
|
||||
private val watermarkPresent = child.output.exists {
|
||||
case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true
|
||||
case _ => false
|
||||
}
|
||||
|
||||
private[sql] val stateManager =
|
||||
createStateManager(stateEncoder, isTimeoutEnabled, stateFormatVersion)
|
||||
|
||||
/** Distribute by grouping attributes */
|
||||
override def requiredChildDistribution: Seq[Distribution] =
|
||||
ClusteredDistribution(groupingAttributes, stateInfo.map(_.numPartitions)) :: Nil
|
||||
/**
|
||||
* Distribute by grouping attributes - We need the underlying data and the initial state data
|
||||
* to have the same grouping so that the data are co-located on the same task.
|
||||
*/
|
||||
override def requiredChildDistribution: Seq[Distribution] = {
|
||||
ClusteredDistribution(groupingAttributes, stateInfo.map(_.numPartitions)) ::
|
||||
ClusteredDistribution(initialStateGroupAttrs, stateInfo.map(_.numPartitions)) ::
|
||||
Nil
|
||||
}
|
||||
|
||||
/** Ordering needed for using GroupingIterator */
|
||||
override def requiredChildOrdering: Seq[Seq[SortOrder]] =
|
||||
Seq(groupingAttributes.map(SortOrder(_, Ascending)))
|
||||
/**
|
||||
* Ordering needed for using GroupingIterator.
|
||||
* We need the initial state to also use the same ordering as the data so that we can co-locate the
|
||||
* keys from the underlying data and the initial state.
|
||||
*/
|
||||
override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq(
|
||||
groupingAttributes.map(SortOrder(_, Ascending)),
|
||||
initialStateGroupAttrs.map(SortOrder(_, Ascending)))
|
||||
|
||||
override def keyExpressions: Seq[Attribute] = groupingAttributes
|
||||
|
||||
|
@ -95,6 +123,77 @@ case class FlatMapGroupsWithStateExec(
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Process data by applying the user defined function on a per partition basis.
|
||||
*
|
||||
* @param iter - Iterator of the data rows
|
||||
* @param store - associated state store for this partition
|
||||
* @param processor - handle to the input processor object.
|
||||
* @param initialStateIterOption - optional initial state iterator
|
||||
*/
|
||||
def processDataWithPartition(
|
||||
iter: Iterator[InternalRow],
|
||||
store: StateStore,
|
||||
processor: InputProcessor,
|
||||
initialStateIterOption: Option[Iterator[InternalRow]] = None
|
||||
): CompletionIterator[InternalRow, Iterator[InternalRow]] = {
|
||||
val allUpdatesTimeMs = longMetric("allUpdatesTimeMs")
|
||||
val commitTimeMs = longMetric("commitTimeMs")
|
||||
val timeoutLatencyMs = longMetric("allRemovalsTimeMs")
|
||||
|
||||
val currentTimeNs = System.nanoTime
|
||||
val updatesStartTimeNs = currentTimeNs
|
||||
var timeoutProcessingStartTimeNs = currentTimeNs
|
||||
|
||||
// If timeout is based on event time, then filter late data based on watermark
|
||||
val filteredIter = watermarkPredicateForData match {
|
||||
case Some(predicate) if timeoutConf == EventTimeTimeout =>
|
||||
applyRemovingRowsOlderThanWatermark(iter, predicate)
|
||||
case _ =>
|
||||
iter
|
||||
}
|
||||
|
||||
val processedOutputIterator = initialStateIterOption match {
|
||||
case Some(initStateIter) if initStateIter.hasNext =>
|
||||
processor.processNewDataWithInitialState(filteredIter, initStateIter)
|
||||
case _ => processor.processNewData(filteredIter)
|
||||
}
|
||||
|
||||
val newDataProcessorIter =
|
||||
CompletionIterator[InternalRow, Iterator[InternalRow]](
|
||||
processedOutputIterator, {
|
||||
// Once the input is processed, mark the start time for timeout processing to measure
|
||||
// it separately from the overall processing time.
|
||||
timeoutProcessingStartTimeNs = System.nanoTime
|
||||
})
|
||||
|
||||
val timeoutProcessorIter =
|
||||
CompletionIterator[InternalRow, Iterator[InternalRow]](processor.processTimedOutState(), {
|
||||
// Note: `timeoutLatencyMs` also includes the time the parent operator took for
|
||||
// processing output returned through iterator.
|
||||
timeoutLatencyMs += NANOSECONDS.toMillis(System.nanoTime - timeoutProcessingStartTimeNs)
|
||||
})
|
||||
|
||||
// Generate an iterator that returns the rows grouped by the grouping function
|
||||
// Note that this code ensures that the filtering for timeout occurs only after
|
||||
// all the data has been processed. This is to ensure that the timeout information of all
|
||||
// the keys with data is updated before they are processed for timeouts.
|
||||
val outputIterator = newDataProcessorIter ++ timeoutProcessorIter
|
||||
|
||||
// Return an iterator of all the rows generated by all the keys, such that when fully
|
||||
// consumed, all the state updates will be committed by the state store
|
||||
CompletionIterator[InternalRow, Iterator[InternalRow]](outputIterator, {
|
||||
// Note: Due to the iterator lazy execution, this metric also captures the time taken
|
||||
// by the upstream (consumer) operators in addition to the processing in this operator.
|
||||
allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs)
|
||||
commitTimeMs += timeTakenMs {
|
||||
store.commit()
|
||||
}
|
||||
setStoreMetrics(store)
|
||||
setOperatorMetrics()
|
||||
})
|
||||
}
|
||||
|
||||
override protected def doExecute(): RDD[InternalRow] = {
|
||||
metrics // force lazy init at driver
|
||||
|
||||
|
@ -103,69 +202,50 @@ case class FlatMapGroupsWithStateExec(
|
|||
case ProcessingTimeTimeout =>
|
||||
require(batchTimestampMs.nonEmpty)
|
||||
case EventTimeTimeout =>
|
||||
require(eventTimeWatermark.nonEmpty) // watermark value has been populated
|
||||
require(eventTimeWatermark.nonEmpty) // watermark value has been populated
|
||||
require(watermarkExpression.nonEmpty) // input schema has watermark attribute
|
||||
case _ =>
|
||||
}
|
||||
|
||||
child.execute().mapPartitionsWithStateStore[InternalRow](
|
||||
getStateInfo,
|
||||
groupingAttributes.toStructType,
|
||||
stateManager.stateSchema,
|
||||
indexOrdinal = None,
|
||||
session.sessionState,
|
||||
Some(session.streams.stateStoreCoordinator)) { case (store, iter) =>
|
||||
val allUpdatesTimeMs = longMetric("allUpdatesTimeMs")
|
||||
val commitTimeMs = longMetric("commitTimeMs")
|
||||
val timeoutLatencyMs = longMetric("allRemovalsTimeMs")
|
||||
if (hasInitialState) {
|
||||
// If the user provided initial state we need to have the initial state and the
|
||||
// data in the same partition so that we can still have just one commit at the end.
|
||||
val storeConf = new StateStoreConf(session.sqlContext.sessionState.conf)
|
||||
val hadoopConfBroadcast = sparkContext.broadcast(
|
||||
new SerializableConfiguration(session.sqlContext.sessionState.newHadoopConf()))
|
||||
child.execute().stateStoreAwareZipPartitions(
|
||||
initialState.execute(),
|
||||
getStateInfo,
|
||||
storeNames = Seq(),
|
||||
session.sqlContext.streams.stateStoreCoordinator) {
|
||||
// The state store aware zip partitions will provide us with two iterators,
|
||||
// child data iterator and the initial state iterator per partition.
|
||||
case (partitionId, childDataIterator, initStateIterator) =>
|
||||
|
||||
val stateStoreId = StateStoreId(
|
||||
stateInfo.get.checkpointLocation, stateInfo.get.operatorId, partitionId)
|
||||
val storeProviderId = StateStoreProviderId(stateStoreId, stateInfo.get.queryRunId)
|
||||
val store = StateStore.get(
|
||||
storeProviderId,
|
||||
groupingAttributes.toStructType,
|
||||
stateManager.stateSchema,
|
||||
indexOrdinal = None,
|
||||
stateInfo.get.storeVersion, storeConf, hadoopConfBroadcast.value.value)
|
||||
val processor = new InputProcessor(store)
|
||||
processDataWithPartition(childDataIterator, store, processor, Some(initStateIterator))
|
||||
}
|
||||
} else {
|
||||
child.execute().mapPartitionsWithStateStore[InternalRow](
|
||||
getStateInfo,
|
||||
groupingAttributes.toStructType,
|
||||
stateManager.stateSchema,
|
||||
indexOrdinal = None,
|
||||
session.sqlContext.sessionState,
|
||||
Some(session.sqlContext.streams.stateStoreCoordinator)
|
||||
) { case (store: StateStore, singleIterator: Iterator[InternalRow]) =>
|
||||
val processor = new InputProcessor(store)
|
||||
|
||||
val currentTimeNs = System.nanoTime
|
||||
val updatesStartTimeNs = currentTimeNs
|
||||
var timeoutProcessingStartTimeNs = currentTimeNs
|
||||
|
||||
// If timeout is based on event time, then filter late data based on watermark
|
||||
val filteredIter = watermarkPredicateForData match {
|
||||
case Some(predicate) if timeoutConf == EventTimeTimeout =>
|
||||
applyRemovingRowsOlderThanWatermark(iter, predicate)
|
||||
case _ =>
|
||||
iter
|
||||
}
|
||||
|
||||
val newDataProcessorIter =
|
||||
CompletionIterator[InternalRow, Iterator[InternalRow]](
|
||||
processor.processNewData(filteredIter), {
|
||||
// Once the input is processed, mark the start time for timeout processing to measure
|
||||
// it separately from the overall processing time.
|
||||
timeoutProcessingStartTimeNs = System.nanoTime
|
||||
})
|
||||
|
||||
val timeoutProcessorIter =
|
||||
CompletionIterator[InternalRow, Iterator[InternalRow]](processor.processTimedOutState(), {
|
||||
// Note: `timeoutLatencyMs` also includes the time the parent operator took for
|
||||
// processing output returned through iterator.
|
||||
timeoutLatencyMs += NANOSECONDS.toMillis(System.nanoTime - timeoutProcessingStartTimeNs)
|
||||
})
|
||||
|
||||
// Generate an iterator that returns the rows grouped by the grouping function
|
||||
// Note that this code ensures that the filtering for timeout occurs only after
|
||||
// all the data has been processed. This is to ensure that the timeout information of all
|
||||
// the keys with data is updated before they are processed for timeouts.
|
||||
val outputIterator = newDataProcessorIter ++ timeoutProcessorIter
|
||||
|
||||
// Return an iterator of all the rows generated by all the keys, such that when fully
|
||||
// consumed, all the state updates will be committed by the state store
|
||||
CompletionIterator[InternalRow, Iterator[InternalRow]](outputIterator, {
|
||||
// Note: Due to the iterator lazy execution, this metric also captures the time taken
|
||||
// by the upstream (consumer) operators in addition to the processing in this operator.
|
||||
allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs)
|
||||
commitTimeMs += timeTakenMs {
|
||||
store.commit()
|
||||
}
|
||||
setStoreMetrics(store)
|
||||
setOperatorMetrics()
|
||||
}
|
||||
)
|
||||
processDataWithPartition(singleIterator, store, processor)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -178,6 +258,11 @@ case class FlatMapGroupsWithStateExec(
|
|||
private val getValueObj =
|
||||
ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
|
||||
private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjectType)
|
||||
private val getStateObj = if (hasInitialState) {
|
||||
Some(ObjectOperator.deserializeRowToObject(initialStateDeserializer, initialStateDataAttrs))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
|
||||
// Metrics
|
||||
private val numUpdatedStateRows = longMetric("numUpdatedStateRows")
|
||||
|
@ -199,6 +284,49 @@ case class FlatMapGroupsWithStateExec(
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Process the new data iterator along with the initial state. The initial state is applied
|
||||
* before processing the new data for every key. The user defined function is called only
|
||||
* once for every key that has either initial state or data or both.
|
||||
*/
|
||||
def processNewDataWithInitialState(
|
||||
childDataIter: Iterator[InternalRow],
|
||||
initStateIter: Iterator[InternalRow]
|
||||
): Iterator[InternalRow] = {
|
||||
|
||||
if (!childDataIter.hasNext && !initStateIter.hasNext) return Iterator.empty
|
||||
|
||||
// Create iterators for the child data and the initial state grouped by their grouping
|
||||
// attributes.
|
||||
val groupedChildDataIter = GroupedIterator(childDataIter, groupingAttributes, child.output)
|
||||
val groupedInitialStateIter =
|
||||
GroupedIterator(initStateIter, initialStateGroupAttrs, initialState.output)
|
||||
|
||||
// Create a CoGroupedIterator that will group the two iterators together for every key group.
|
||||
new CoGroupedIterator(
|
||||
groupedChildDataIter, groupedInitialStateIter, groupingAttributes).flatMap {
|
||||
case (keyRow, valueRowIter, initialStateRowIter) =>
|
||||
val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow]
|
||||
var foundInitialStateForKey = false
|
||||
initialStateRowIter.foreach { initialStateRow =>
|
||||
if (foundInitialStateForKey) {
|
||||
throw new IllegalArgumentException("The initial state provided contained " +
|
||||
"multiple rows(state) with the same key. Make sure to de-duplicate the " +
|
||||
"initial state before passing it.")
|
||||
}
|
||||
foundInitialStateForKey = true
|
||||
val initStateObj = getStateObj.get(initialStateRow)
|
||||
stateManager.putState(store, keyUnsafeRow, initStateObj, NO_TIMESTAMP)
|
||||
}
|
||||
// We apply the values for the key after applying the initial state.
|
||||
callFunctionAndUpdateState(
|
||||
stateManager.getState(store, keyUnsafeRow),
|
||||
valueRowIter,
|
||||
hasTimedOut = false
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/** Find the groups that have timeout set and are timing out right now, and call the function */
|
||||
def processTimedOutState(): Iterator[InternalRow] = {
|
||||
if (isTimeoutEnabled) {
|
||||
|
@ -271,6 +399,8 @@ case class FlatMapGroupsWithStateExec(
|
|||
}
|
||||
}
|
||||
|
||||
override protected def withNewChildInternal(newChild: SparkPlan): FlatMapGroupsWithStateExec =
|
||||
copy(child = newChild)
|
||||
override protected def withNewChildrenInternal(
|
||||
newLeft: SparkPlan, newRight: SparkPlan): FlatMapGroupsWithStateExec =
|
||||
copy(child = newLeft, initialState = newRight)
|
||||
}
|
||||
|
||||
|
|
|
@ -157,10 +157,14 @@ class IncrementalExecution(
|
|||
Some(offsetSeqMetadata.batchWatermarkMs))
|
||||
|
||||
case m: FlatMapGroupsWithStateExec =>
|
||||
// We set this to true only for the first batch of the streaming query.
|
||||
val hasInitialState = (currentBatchId == 0L && m.hasInitialState)
|
||||
m.copy(
|
||||
stateInfo = Some(nextStatefulOperationStateInfo),
|
||||
batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs),
|
||||
eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs))
|
||||
eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs),
|
||||
hasInitialState = hasInitialState
|
||||
)
|
||||
|
||||
case j: StreamingSymmetricHashJoinExec =>
|
||||
j.copy(
|
||||
|
|
|
@ -203,7 +203,9 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan =>
|
|||
}
|
||||
|
||||
/** An operator that supports watermark. */
|
||||
trait WatermarkSupport extends UnaryExecNode {
|
||||
trait WatermarkSupport extends SparkPlan {
|
||||
|
||||
def child: SparkPlan
|
||||
|
||||
/** The keys that may have a watermark attribute. */
|
||||
def keyExpressions: Seq[Attribute]
|
||||
|
|
|
@ -111,6 +111,11 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState
|
|||
* when the group has new data, or the group has timed out. So the user has to set the timeout
|
||||
* duration every time the function is called, otherwise, there will not be any timeout set.
|
||||
*
|
||||
* `[map/flatMap]GroupsWithState` can take a user defined initial state as an additional argument.
|
||||
* This state will be applied when the first batch of the streaming query is processed. If there
|
||||
* are no matching rows in the data for the keys present in the initial state, the state is still
|
||||
* applied and the function will be invoked with the values being an empty iterator.
|
||||
*
|
||||
* Scala example of using GroupState in `mapGroupsWithState`:
|
||||
* {{{
|
||||
* // A mapping function that maintains an integer state for string keys and returns a string.
|
||||
|
|
|
@ -46,6 +46,7 @@ import org.apache.spark.sql.catalyst.expressions.GenericRow;
|
|||
import org.apache.spark.sql.test.TestSparkSession;
|
||||
import org.apache.spark.sql.types.StructType;
|
||||
import org.apache.spark.util.LongAccumulator;
|
||||
|
||||
import static org.apache.spark.sql.functions.col;
|
||||
import static org.apache.spark.sql.functions.expr;
|
||||
import static org.apache.spark.sql.types.DataTypes.*;
|
||||
|
@ -160,6 +161,71 @@ public class JavaDatasetSuite implements Serializable {
|
|||
Assert.assertEquals(6, reduced);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInitialStateFlatMapGroupsWithState() {
|
||||
List<String> data = Arrays.asList("a", "foo", "bar");
|
||||
Dataset<String> ds = spark.createDataset(data, Encoders.STRING());
|
||||
Dataset<Tuple2<Integer, Long>> initialStateDS = spark.createDataset(
|
||||
Arrays.asList(new Tuple2<Integer, Long>(2, 2L)),
|
||||
Encoders.tuple(Encoders.INT(), Encoders.LONG())
|
||||
);
|
||||
|
||||
KeyValueGroupedDataset<Integer, Tuple2<Integer, Long>> kvInitStateDS =
|
||||
initialStateDS.groupByKey(
|
||||
(MapFunction<Tuple2<Integer, Long>, Integer>) f -> f._1, Encoders.INT());
|
||||
|
||||
KeyValueGroupedDataset<Integer, Long> kvInitStateMappedDS = kvInitStateDS.mapValues(
|
||||
(MapFunction<Tuple2<Integer, Long>, Long>) f -> f._2,
|
||||
Encoders.LONG()
|
||||
);
|
||||
|
||||
KeyValueGroupedDataset<Integer, String> grouped =
|
||||
ds.groupByKey((MapFunction<String, Integer>) String::length, Encoders.INT());
|
||||
|
||||
Dataset<String> flatMapped2 = grouped.flatMapGroupsWithState(
|
||||
(FlatMapGroupsWithStateFunction<Integer, String, Long, String>) (key, values, s) -> {
|
||||
StringBuilder sb = new StringBuilder(key.toString());
|
||||
while (values.hasNext()) {
|
||||
sb.append(values.next());
|
||||
}
|
||||
return Collections.singletonList(sb.toString()).iterator();
|
||||
},
|
||||
OutputMode.Append(),
|
||||
Encoders.LONG(),
|
||||
Encoders.STRING(),
|
||||
GroupStateTimeout.NoTimeout(),
|
||||
kvInitStateMappedDS);
|
||||
|
||||
Assert.assertThrows(
|
||||
"Initial state is not supported in [flatMap|map]GroupsWithState " +
|
||||
"operation on a batch DataFrame/Dataset",
|
||||
AnalysisException.class,
|
||||
() -> {
|
||||
flatMapped2.collectAsList();
|
||||
}
|
||||
);
|
||||
Dataset<String> mapped2 = grouped.mapGroupsWithState(
|
||||
(MapGroupsWithStateFunction<Integer, String, Long, String>) (key, values, s) -> {
|
||||
StringBuilder sb = new StringBuilder(key.toString());
|
||||
while (values.hasNext()) {
|
||||
sb.append(values.next());
|
||||
}
|
||||
return sb.toString();
|
||||
},
|
||||
Encoders.LONG(),
|
||||
Encoders.STRING(),
|
||||
GroupStateTimeout.NoTimeout(),
|
||||
kvInitStateMappedDS);
|
||||
Assert.assertThrows(
|
||||
"Initial state is not supported in [flatMap|map]GroupsWithState " +
|
||||
"operation on a batch DataFrame/Dataset",
|
||||
AnalysisException.class,
|
||||
() -> {
|
||||
mapped2.collectAsList();
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIllegalTestGroupStateCreations() {
|
||||
// SPARK-35800: test code throws upon illegal TestGroupState create() calls
|
||||
|
|
|
@ -26,7 +26,7 @@ import org.scalatest.exceptions.TestFailedException
|
|||
import org.apache.spark.SparkException
|
||||
import org.apache.spark.api.java.Optional
|
||||
import org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction
|
||||
import org.apache.spark.sql.{DataFrame, Encoder}
|
||||
import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, Encoder, KeyValueGroupedDataset}
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
|
||||
import org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsWithState
|
||||
|
@ -1042,7 +1042,6 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
|
|||
)
|
||||
}
|
||||
|
||||
|
||||
test("mapGroupsWithState - streaming") {
|
||||
// Function to maintain running count up to 2, and then remove the count
|
||||
// Returns the data and the count (-1 if count reached beyond 2 and state was just removed)
|
||||
|
@ -1268,12 +1267,281 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
|
|||
assert(e.getMessage === "The output mode of function should be append or update")
|
||||
}
|
||||
|
||||
import testImplicits._
|
||||
|
||||
/**
|
||||
* FlatMapGroupsWithState function that returns the key, value as passed to it
|
||||
* along with the updated state. The state is incremented for every value.
|
||||
*/
|
||||
val flatMapGroupsWithStateFunc =
|
||||
(key: String, values: Iterator[String], state: GroupState[RunningCount]) => {
|
||||
val valList = values.toSeq
|
||||
if (valList.isEmpty) {
|
||||
// When the function is called on just the initial state make sure the other fields
|
||||
// are set correctly
|
||||
assert(state.exists)
|
||||
}
|
||||
assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 }
|
||||
assertCannotGetWatermark { state.getCurrentWatermarkMs() }
|
||||
assert(!state.hasTimedOut)
|
||||
val count = state.getOption.map(_.count).getOrElse(0L) + valList.size
|
||||
// We need to check if not explicitly calling update will still save the init state or not
|
||||
if (!key.contains("NoUpdate")) {
|
||||
// this is not reached when valList is empty and the state count is 2
|
||||
state.update(new RunningCount(count))
|
||||
}
|
||||
Iterator((key, valList, count.toString))
|
||||
}
|
||||
|
||||
Seq("1", "2", "6").foreach { shufflePartitions =>
|
||||
testWithAllStateVersions(s"flatMapGroupsWithState - initial " +
|
||||
s"state - all cases - shuffle partitions ${shufflePartitions}") {
|
||||
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> shufflePartitions) {
|
||||
// We will test them on different shuffle partition configuration to make sure the
|
||||
// grouping by key will still work. On higher number of shuffle partitions its possible
|
||||
// that all keys end up on different partitions.
|
||||
val initialState: Dataset[(String, RunningCount)] = Seq(
|
||||
("keyInStateAndData-1", new RunningCount(1)),
|
||||
("keyInStateAndData-2", new RunningCount(2)),
|
||||
("keyNoUpdate", new RunningCount(2)), // state.update will not be called
|
||||
("keyOnlyInState-1", new RunningCount(1))
|
||||
).toDS()
|
||||
|
||||
val it = initialState.groupByKey(x => x._1).mapValues(_._2)
|
||||
val inputData = MemoryStream[String]
|
||||
val result =
|
||||
inputData.toDS()
|
||||
.groupByKey(x => x)
|
||||
.flatMapGroupsWithState(
|
||||
Update, GroupStateTimeout.NoTimeout, it)(flatMapGroupsWithStateFunc)
|
||||
|
||||
testStream(result, Update)(
|
||||
AddData(inputData, "keyOnlyInData", "keyInStateAndData-2"),
|
||||
CheckNewAnswer(
|
||||
("keyOnlyInState-1", Seq[String](), "1"),
|
||||
("keyNoUpdate", Seq[String](), "2"), // update will not be called
|
||||
("keyInStateAndData-2", Seq[String]("keyInStateAndData-2"), "3"), // inc by 1
|
||||
("keyInStateAndData-1", Seq[String](), "1"),
|
||||
("keyOnlyInData", Seq[String]("keyOnlyInData"), "1") // inc by 1
|
||||
),
|
||||
assertNumStateRows(total = 5, updated = 4),
|
||||
// Stop and Start stream to make sure initial state doesn't get applied again.
|
||||
StopStream,
|
||||
StartStream(),
|
||||
AddData(inputData, "keyInStateAndData-1"),
|
||||
CheckNewAnswer(
|
||||
// state incremented by 1
|
||||
("keyInStateAndData-1", Seq[String]("keyInStateAndData-1"), "2")
|
||||
),
|
||||
assertNumStateRows(total = 5, updated = 1),
|
||||
StopStream
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
testWithAllStateVersions("flatMapGroupsWithState - initial state - case class key") {
|
||||
val stateFunc = (key: User, values: Iterator[User], state: GroupState[Long]) => {
|
||||
val valList = values.toSeq
|
||||
if (valList.isEmpty) {
|
||||
// When the function is called on just the initial state make sure the other fields
|
||||
// are set correctly
|
||||
assert(state.exists)
|
||||
}
|
||||
assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 }
|
||||
assertCannotGetWatermark { state.getCurrentWatermarkMs() }
|
||||
assert(!state.hasTimedOut)
|
||||
val count = state.getOption.getOrElse(0L) + valList.size
|
||||
// We need to check if not explicitly calling update will still save the state or not
|
||||
if (!key.name.contains("NoUpdate")) {
|
||||
// this is not reached when valList is empty and the state count is 2
|
||||
state.update(count)
|
||||
}
|
||||
Iterator((key, valList.map(_.name), count.toString))
|
||||
}
|
||||
|
||||
val ds = Seq(
|
||||
(User("keyInStateAndData", "1"), (1L)),
|
||||
(User("keyOnlyInState", "1"), (1L)),
|
||||
(User("keyNoUpdate", "2"), (2L)) // state.update will not be called on this in the function
|
||||
).toDS().groupByKey(_._1).mapValues(_._2)
|
||||
|
||||
val inputData = MemoryStream[User]
|
||||
val result =
|
||||
inputData.toDS()
|
||||
.groupByKey(x => x)
|
||||
.flatMapGroupsWithState(Update, NoTimeout(), ds)(stateFunc)
|
||||
|
||||
testStream(result, Update)(
|
||||
AddData(inputData, User("keyInStateAndData", "1"), User("keyOnlyInData", "1")),
|
||||
CheckNewAnswer(
|
||||
(("keyInStateAndData", "1"), Seq[String]("keyInStateAndData"), "2"),
|
||||
(("keyOnlyInState", "1"), Seq[String](), "1"),
|
||||
(("keyNoUpdate", "2"), Seq[String](), "2"),
|
||||
(("keyOnlyInData", "1"), Seq[String]("keyOnlyInData"), "1")
|
||||
),
|
||||
assertNumStateRows(total = 4, updated = 3), // (keyOnlyInState, 2) does not call update()
|
||||
// Stop and Start stream to make sure initial state doesn't get applied again.
|
||||
StopStream,
|
||||
StartStream(),
|
||||
AddData(inputData, User("keyOnlyInData", "1")),
|
||||
CheckNewAnswer(
|
||||
(("keyOnlyInData", "1"), Seq[String]("keyOnlyInData"), "2")
|
||||
),
|
||||
assertNumStateRows(total = 4, updated = 1),
|
||||
StopStream
|
||||
)
|
||||
}
|
||||
|
||||
testQuietly("flatMapGroupsWithState - initial state - duplicate keys") {
|
||||
val initialState = Seq(
|
||||
("a", new RunningCount(2)),
|
||||
("a", new RunningCount(1))
|
||||
).toDS().groupByKey(_._1).mapValues(_._2)
|
||||
|
||||
val inputData = MemoryStream[String]
|
||||
val result =
|
||||
inputData.toDS()
|
||||
.groupByKey(x => x)
|
||||
.flatMapGroupsWithState(Update, NoTimeout(), initialState)(flatMapGroupsWithStateFunc)
|
||||
testStream(result, Update)(
|
||||
AddData(inputData, "a"),
|
||||
ExpectFailure[SparkException] { e =>
|
||||
assert(e.getCause.getMessage.contains("The initial state provided contained " +
|
||||
"multiple rows(state) with the same key"))
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
testQuietly("flatMapGroupsWithState - initial state - streaming initial state") {
|
||||
val initialStateData = MemoryStream[(String, RunningCount)]
|
||||
initialStateData.addData(("a", new RunningCount(1)))
|
||||
|
||||
val inputData = MemoryStream[String]
|
||||
|
||||
val result =
|
||||
inputData.toDS()
|
||||
.groupByKey(x => x)
|
||||
.flatMapGroupsWithState(
|
||||
Update, NoTimeout(), initialStateData.toDS().groupByKey(_._1).mapValues(_._2)
|
||||
)(flatMapGroupsWithStateFunc)
|
||||
|
||||
val e = intercept[AnalysisException] {
|
||||
result.writeStream
|
||||
.format("console")
|
||||
.start()
|
||||
}
|
||||
|
||||
val expectedError = "Non-streaming DataFrame/Dataset is not supported" +
|
||||
" as the initial state in [flatMap|map]GroupsWithState" +
|
||||
" operation on a streaming DataFrame/Dataset"
|
||||
assert(e.message.contains(expectedError))
|
||||
}
|
||||
|
||||
test("flatMapGroupsWithState - initial state - initial state has flatMapGroupsWithState") {
|
||||
val initialStateDS = Seq(("keyInStateAndData", new RunningCount(1))).toDS()
|
||||
val initialState: KeyValueGroupedDataset[String, RunningCount] =
|
||||
initialStateDS.groupByKey(_._1).mapValues(_._2)
|
||||
.mapGroupsWithState(
|
||||
GroupStateTimeout.NoTimeout())(
|
||||
(key: String, values: Iterator[RunningCount], state: GroupState[Boolean]) => {
|
||||
(key, values.next())
|
||||
}
|
||||
).groupByKey(_._1).mapValues(_._2)
|
||||
|
||||
val inputData = MemoryStream[String]
|
||||
|
||||
val result =
|
||||
inputData.toDS()
|
||||
.groupByKey(x => x)
|
||||
.flatMapGroupsWithState(
|
||||
Update, NoTimeout(), initialState
|
||||
)(flatMapGroupsWithStateFunc)
|
||||
|
||||
testStream(result, Update)(
|
||||
AddData(inputData, "keyInStateAndData"),
|
||||
CheckNewAnswer(
|
||||
("keyInStateAndData", Seq[String]("keyInStateAndData"), "2")
|
||||
),
|
||||
StopStream
|
||||
)
|
||||
}
|
||||
|
||||
testWithAllStateVersions("mapGroupsWithState - initial state - null key") {
|
||||
val mapGroupsWithStateFunc =
|
||||
(key: String, values: Iterator[String], state: GroupState[RunningCount]) => {
|
||||
val valList = values.toList
|
||||
val count = state.getOption.map(_.count).getOrElse(0L) + valList.size
|
||||
state.update(new RunningCount(count))
|
||||
(key, state.get.count.toString)
|
||||
}
|
||||
val initialState = Seq(
|
||||
("key", new RunningCount(5)),
|
||||
(null, new RunningCount(2))
|
||||
).toDS().groupByKey(_._1).mapValues(_._2)
|
||||
|
||||
val inputData = MemoryStream[String]
|
||||
val result =
|
||||
inputData.toDS()
|
||||
.groupByKey(x => x)
|
||||
.mapGroupsWithState(NoTimeout(), initialState)(mapGroupsWithStateFunc)
|
||||
testStream(result, Update)(
|
||||
AddData(inputData, "key", null),
|
||||
CheckNewAnswer(
|
||||
("key", "6"), // state is incremented by 1
|
||||
(null, "3") // incremented by 1
|
||||
),
|
||||
assertNumStateRows(total = 2, updated = 2),
|
||||
StopStream
|
||||
)
|
||||
}
|
||||
|
||||
testWithAllStateVersions("flatMapGroupsWithState - initial state - processing time timeout") {
|
||||
// function will return -1 on timeout and returns count of the state otherwise
|
||||
val stateFunc =
|
||||
(key: String, values: Iterator[(String, Long)], state: GroupState[RunningCount]) => {
|
||||
if (state.hasTimedOut) {
|
||||
state.remove()
|
||||
Iterator((key, "-1"))
|
||||
} else {
|
||||
val count = state.getOption.map(_.count).getOrElse(0L) + values.size
|
||||
state.update(RunningCount(count))
|
||||
state.setTimeoutDuration("10 seconds")
|
||||
Iterator((key, count.toString))
|
||||
}
|
||||
}
|
||||
|
||||
val clock = new StreamManualClock
|
||||
val inputData = MemoryStream[(String, Long)]
|
||||
val initialState = Seq(
|
||||
("c", new RunningCount(2))
|
||||
).toDS().groupByKey(_._1).mapValues(_._2)
|
||||
val result =
|
||||
inputData.toDF().toDF("key", "time")
|
||||
.selectExpr("key", "timestamp_seconds(time) as timestamp")
|
||||
.withWatermark("timestamp", "10 second")
|
||||
.as[(String, Long)]
|
||||
.groupByKey(x => x._1)
|
||||
.flatMapGroupsWithState(Update, ProcessingTimeTimeout(), initialState)(stateFunc)
|
||||
|
||||
testStream(result, Update)(
|
||||
StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock),
|
||||
AddData(inputData, ("a", 1L)),
|
||||
AdvanceManualClock(1 * 1000), // a and c are processed here for the first time.
|
||||
CheckNewAnswer(("a", "1"), ("c", "2")),
|
||||
AdvanceManualClock(10 * 1000),
|
||||
AddData(inputData, ("b", 1L)), // this will trigger c and a to get timed out
|
||||
AdvanceManualClock(1 * 1000),
|
||||
CheckNewAnswer(("a", "-1"), ("b", "1"), ("c", "-1"))
|
||||
)
|
||||
}
|
||||
|
||||
def testWithTimeout(timeoutConf: GroupStateTimeout): Unit = {
|
||||
test("SPARK-20714: watermark does not fail query when timeout = " + timeoutConf) {
|
||||
// Function to maintain running count up to 2, and then remove the count
|
||||
// Returns the data and the count (-1 if count reached beyond 2 and state was just removed)
|
||||
val stateFunc =
|
||||
(key: String, values: Iterator[(String, Long)], state: GroupState[RunningCount]) => {
|
||||
(key: String, values: Iterator[(String, Long)], state: GroupState[RunningCount]) => {
|
||||
if (state.hasTimedOut) {
|
||||
state.remove()
|
||||
Iterator((key, "-1"))
|
||||
|
@ -1411,10 +1679,13 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
|
|||
.groupByKey(x => x)
|
||||
.flatMapGroupsWithState[Int, Int](Append, timeoutConf = timeoutType)(func)
|
||||
.logicalPlan.collectFirst {
|
||||
case FlatMapGroupsWithState(f, k, v, g, d, o, s, m, _, t, _) =>
|
||||
case FlatMapGroupsWithState(f, k, v, g, d, o, s, m, _, t,
|
||||
hasInitialState, sga, sda, se, i, c) =>
|
||||
FlatMapGroupsWithStateExec(
|
||||
f, k, v, g, d, o, None, s, stateFormatVersion, m, t,
|
||||
f, k, v, se, g, sga, d, sda, o, None, s, stateFormatVersion, m, t,
|
||||
Some(currentBatchTimestamp), Some(currentBatchWatermark),
|
||||
RDDScanExec(g, emptyRdd, "rdd"),
|
||||
hasInitialState,
|
||||
RDDScanExec(g, emptyRdd, "rdd"))
|
||||
}.get
|
||||
}
|
||||
|
@ -1461,6 +1732,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
|
|||
}
|
||||
}
|
||||
|
||||
case class User(name: String, id: String)
|
||||
|
||||
object FlatMapGroupsWithStateSuite {
|
||||
|
||||
var failInTask = true
|
||||
|
|
Loading…
Reference in a new issue