[SPARK-19067][SS] Processing-time-based timeout in MapGroupsWithState

## What changes were proposed in this pull request?

When a key does not get any new data in `mapGroupsWithState`, the mapping function is never called on it. So we need a timeout feature that calls the function again in such cases, so that the user can decide whether to continue waiting or clean up (remove state, save stuff externally, etc.).
Timeouts can be either based on processing time or event time. This JIRA is for processing time, but defines the high level API design for both. The usage would look like this.
```
def stateFunction(key: K, value: Iterator[V], state: KeyedState[S]): U = {
  ...
  state.setTimeoutDuration(10000)
  ...
}

dataset					// type is Dataset[T]
  .groupByKey[K](keyingFunc)   // generates KeyValueGroupedDataset[K, T]
  .mapGroupsWithState[S, U](
     func = stateFunction,
     timeout = KeyedStateTimeout.withProcessingTime)	// returns Dataset[U]
```

Note the following design aspects.

- The timeout type is provided as a param in mapGroupsWithState as a parameter global to all the keys. This is so that the planner knows this at planning time, and accordingly optimize the execution based on whether to saves extra info in state or not (e.g. timeout durations or timestamps).

- The exact timeout duration is provided inside the function call so that it can be customized on a per key basis.

- When the timeout occurs for a key, the function is called with no values, and KeyedState.isTimingOut() set to true.

- The timeout is reset for key every time the function is called on the key, that is, when the key has new data, or the key has timed out. So the user has to set the timeout duration everytime the function is called, otherwise there will not be any timeout set.

Guarantees provided on timeout of key, when timeout duration is D ms:
- Timeout will never be called before real clock time has advanced by D ms
- Timeout will be called eventually when there is a trigger with any data in it (i.e. after D ms). So there is a no strict upper bound on when the timeout would occur. For example, if there is no data in the stream (for any key) for a while, then the timeout will not be hit.

Implementation details:
- Added new param to `mapGroupsWithState` for timeout
- Added new method to `StateStore` to filter data based on timeout timestamp
- Changed the internal map type of `HDFSBackedStateStore` from Java's `HashMap` to `ConcurrentHashMap` as the latter allows weakly-consistent fail-safe iterators on the map data. See comments in code for more details.
- Refactored logic of `MapGroupsWithStateExec` to
  - Save timeout info to state store for each key that has data.
  - Then, filter states that should be timed out based on the current batch processing timestamp.
- Moved KeyedState for `o.a.s.sql` to `o.a.s.sql.streaming`. I remember that this was a feedback in the MapGroupsWithState PR that I had forgotten to address.

## How was this patch tested?
New unit tests in
- MapGroupsWithStateSuite for timeouts.
- StateStoreSuite for new APIs in StateStore.

Author: Tathagata Das <tathagata.das1565@gmail.com>

Closes #17179 from tdas/mapgroupwithstate-timeout.
This commit is contained in:
Tathagata Das 2017-03-19 14:07:49 -07:00
parent 0ee9fbf51a
commit 990af630d0
22 changed files with 1370 additions and 446 deletions

View file

@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.streaming;
import org.apache.spark.annotation.Experimental;
import org.apache.spark.annotation.InterfaceStability;
import org.apache.spark.sql.catalyst.plans.logical.NoTimeout$;
import org.apache.spark.sql.catalyst.plans.logical.ProcessingTimeTimeout;
import org.apache.spark.sql.catalyst.plans.logical.ProcessingTimeTimeout$;
/**
* Represents the type of timeouts possible for the Dataset operations
* `mapGroupsWithState` and `flatMapGroupsWithState`. See documentation on
* `KeyedState` for more details.
*
* @since 2.2.0
*/
@Experimental
@InterfaceStability.Evolving
public class KeyedStateTimeout {
/** Timeout based on processing time. */
public static KeyedStateTimeout ProcessingTimeTimeout() { return ProcessingTimeTimeout$.MODULE$; }
/** No timeout */
public static KeyedStateTimeout NoTimeout() { return NoTimeout$.MODULE$; }
}

View file

@ -951,7 +951,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
override def eval(input: InternalRow): Any = {
val result = child.eval(input)
if (result == null) {
throw new RuntimeException(errMsg);
throw new RuntimeException(errMsg)
}
result
}

View file

@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.streaming.{KeyedStateTimeout, OutputMode }
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@ -353,6 +353,10 @@ case class MapGroups(
/** Internal class representing State */
trait LogicalKeyedState[S]
/** Possible types of timeouts used in FlatMapGroupsWithState */
case object NoTimeout extends KeyedStateTimeout
case object ProcessingTimeTimeout extends KeyedStateTimeout
/** Factory for constructing new `MapGroupsWithState` nodes. */
object FlatMapGroupsWithState {
def apply[K: Encoder, V: Encoder, S: Encoder, U: Encoder](
@ -361,7 +365,10 @@ object FlatMapGroupsWithState {
dataAttributes: Seq[Attribute],
outputMode: OutputMode,
isMapGroupsWithState: Boolean,
timeout: KeyedStateTimeout,
child: LogicalPlan): LogicalPlan = {
val encoder = encoderFor[S]
val mapped = new FlatMapGroupsWithState(
func,
UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes),
@ -369,11 +376,11 @@ object FlatMapGroupsWithState {
groupingAttributes,
dataAttributes,
CatalystSerde.generateObjAttr[U],
encoderFor[S].resolveAndBind().deserializer,
encoderFor[S].namedExpressions,
encoder.asInstanceOf[ExpressionEncoder[Any]],
outputMode,
child,
isMapGroupsWithState)
isMapGroupsWithState,
timeout,
child)
CatalystSerde.serialize[U](mapped)
}
}
@ -384,15 +391,16 @@ object FlatMapGroupsWithState {
* Func is invoked with an object representation of the grouping key an iterator containing the
* object representation of all the rows with that key.
*
* @param func function called on each group
* @param keyDeserializer used to extract the key object for each group.
* @param valueDeserializer used to extract the items in the iterator from an input row.
* @param groupingAttributes used to group the data
* @param dataAttributes used to read the data
* @param outputObjAttr used to define the output object
* @param stateDeserializer used to deserialize state before calling `func`
* @param stateSerializer used to serialize updated state after calling `func`
* @param stateEncoder used to serialize/deserialize state before calling `func`
* @param outputMode the output mode of `func`
* @param isMapGroupsWithState whether it is created by the `mapGroupsWithState` method
* @param timeout used to timeout groups that have not received data in a while
*/
case class FlatMapGroupsWithState(
func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any],
@ -401,11 +409,11 @@ case class FlatMapGroupsWithState(
groupingAttributes: Seq[Attribute],
dataAttributes: Seq[Attribute],
outputObjAttr: Attribute,
stateDeserializer: Expression,
stateSerializer: Seq[NamedExpression],
stateEncoder: ExpressionEncoder[Any],
outputMode: OutputMode,
child: LogicalPlan,
isMapGroupsWithState: Boolean = false) extends UnaryNode with ObjectProducer {
isMapGroupsWithState: Boolean = false,
timeout: KeyedStateTimeout,
child: LogicalPlan) extends UnaryNode with ObjectProducer {
if (isMapGroupsWithState) {
assert(outputMode == OutputMode.Update)

View file

@ -0,0 +1,29 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.streaming;
import org.apache.spark.sql.catalyst.plans.logical.ProcessingTimeTimeout$;
import org.junit.Test;
public class JavaKeyedStateTimeoutSuite {
@Test
public void testTimeouts() {
assert(KeyedStateTimeout.ProcessingTimeTimeout() == ProcessingTimeTimeout$.MODULE$);
}
}

View file

@ -144,14 +144,16 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
assertSupportedInBatchPlan(
s"flatMapGroupsWithState - flatMapGroupsWithState($funcMode) on batch relation",
FlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, att, Seq(att), funcMode, batchRelation))
null, att, att, Seq(att), Seq(att), att, null, funcMode, isMapGroupsWithState = false, null,
batchRelation))
assertSupportedInBatchPlan(
s"flatMapGroupsWithState - multiple flatMapGroupsWithState($funcMode)s on batch relation",
FlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, att, Seq(att), funcMode,
null, att, att, Seq(att), Seq(att), att, null, funcMode, isMapGroupsWithState = false, null,
FlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, att, Seq(att), funcMode, batchRelation)))
null, att, att, Seq(att), Seq(att), att, null, funcMode, isMapGroupsWithState = false,
null, batchRelation)))
}
// FlatMapGroupsWithState(Update) in streaming without aggregation
@ -159,14 +161,16 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
"flatMapGroupsWithState - flatMapGroupsWithState(Update) " +
"on streaming relation without aggregation in update mode",
FlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation),
null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null,
streamRelation),
outputMode = Update)
assertNotSupportedInStreamingPlan(
"flatMapGroupsWithState - flatMapGroupsWithState(Update) " +
"on streaming relation without aggregation in append mode",
FlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation),
null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null,
streamRelation),
outputMode = Append,
expectedMsgs = Seq("flatMapGroupsWithState in update mode", "Append"))
@ -174,7 +178,8 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
"flatMapGroupsWithState - flatMapGroupsWithState(Update) " +
"on streaming relation without aggregation in complete mode",
FlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation),
null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null,
streamRelation),
outputMode = Complete,
// Disallowed by the aggregation check but let's still keep this test in case it's broken in
// future.
@ -186,7 +191,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
"flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation " +
s"with aggregation in $outputMode mode",
FlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update,
null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null,
Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation)),
outputMode = outputMode,
expectedMsgs = Seq("flatMapGroupsWithState in update mode", "with aggregation"))
@ -197,14 +202,16 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
"flatMapGroupsWithState - flatMapGroupsWithState(Append) " +
"on streaming relation without aggregation in append mode",
FlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation),
null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null,
streamRelation),
outputMode = Append)
assertNotSupportedInStreamingPlan(
"flatMapGroupsWithState - flatMapGroupsWithState(Append) " +
"on streaming relation without aggregation in update mode",
FlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation),
null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null,
streamRelation),
outputMode = Update,
expectedMsgs = Seq("flatMapGroupsWithState in append mode", "update"))
@ -217,7 +224,8 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
Seq(attributeWithWatermark),
aggExprs("c"),
FlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation)),
null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null,
streamRelation)),
outputMode = outputMode)
}
@ -225,7 +233,8 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
assertNotSupportedInStreamingPlan(
"flatMapGroupsWithState - flatMapGroupsWithState(Append) " +
s"on streaming relation after aggregation in $outputMode mode",
FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append,
FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append,
isMapGroupsWithState = false, null,
Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation)),
outputMode = outputMode,
expectedMsgs = Seq("flatMapGroupsWithState", "after aggregation"))
@ -235,7 +244,8 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
"flatMapGroupsWithState - " +
"flatMapGroupsWithState(Update) on streaming relation in complete mode",
FlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation),
null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null,
streamRelation),
outputMode = Complete,
// Disallowed by the aggregation check but let's still keep this test in case it's broken in
// future.
@ -248,7 +258,8 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
s"flatMapGroupsWithState - flatMapGroupsWithState($funcMode) on batch relation inside " +
s"streaming relation in $outputMode output mode",
FlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, att, Seq(att), funcMode, batchRelation),
null, att, att, Seq(att), Seq(att), att, null, funcMode, isMapGroupsWithState = false,
null, batchRelation),
outputMode = outputMode
)
}
@ -258,19 +269,20 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
assertSupportedInStreamingPlan(
"flatMapGroupsWithState - multiple flatMapGroupsWithStates on streaming relation and all are " +
"in append mode",
FlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append,
FlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation)),
FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append,
isMapGroupsWithState = false, null,
FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append,
isMapGroupsWithState = false, null, streamRelation)),
outputMode = Append)
assertNotSupportedInStreamingPlan(
"flatMapGroupsWithState - multiple flatMapGroupsWithStates on s streaming relation but some" +
" are not in append mode",
FlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update,
null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null,
FlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation)),
null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null,
streamRelation)),
outputMode = Append,
expectedMsgs = Seq("multiple flatMapGroupsWithState", "append"))
@ -279,8 +291,8 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
"mapGroupsWithState - mapGroupsWithState " +
"on streaming relation without aggregation in append mode",
FlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation,
isMapGroupsWithState = true),
null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, null,
streamRelation),
outputMode = Append,
// Disallowed by the aggregation check but let's still keep this test in case it's broken in
// future.
@ -290,8 +302,8 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
"mapGroupsWithState - mapGroupsWithState " +
"on streaming relation without aggregation in complete mode",
FlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation,
isMapGroupsWithState = true),
null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, null,
streamRelation),
outputMode = Complete,
// Disallowed by the aggregation check but let's still keep this test in case it's broken in
// future.
@ -301,10 +313,9 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
assertNotSupportedInStreamingPlan(
"mapGroupsWithState - mapGroupsWithState on streaming relation " +
s"with aggregation in $outputMode mode",
FlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update,
Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation),
isMapGroupsWithState = true),
FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Update,
isMapGroupsWithState = true, null,
Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation)),
outputMode = outputMode,
expectedMsgs = Seq("mapGroupsWithState", "with aggregation"))
}
@ -314,11 +325,10 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
"mapGroupsWithState - multiple mapGroupsWithStates on streaming relation and all are " +
"in append mode",
FlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update,
null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, null,
FlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation,
isMapGroupsWithState = true),
isMapGroupsWithState = true),
null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, null,
streamRelation)),
outputMode = Append,
expectedMsgs = Seq("multiple mapGroupsWithStates"))
@ -327,11 +337,11 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
"mapGroupsWithState - " +
"mixing mapGroupsWithStates and flatMapGroupsWithStates on streaming relation",
FlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update,
null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, null,
FlatMapGroupsWithState(
null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation,
isMapGroupsWithState = false),
isMapGroupsWithState = true),
null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null,
streamRelation)
),
outputMode = Append,
expectedMsgs = Seq("Mixing mapGroupsWithStates and flatMapGroupsWithStates"))

View file

@ -22,7 +22,7 @@ import java.util.Iterator;
import org.apache.spark.annotation.Experimental;
import org.apache.spark.annotation.InterfaceStability;
import org.apache.spark.sql.KeyedState;
import org.apache.spark.sql.streaming.KeyedState;
/**
* ::Experimental::

View file

@ -22,7 +22,7 @@ import java.util.Iterator;
import org.apache.spark.annotation.Experimental;
import org.apache.spark.annotation.InterfaceStability;
import org.apache.spark.sql.KeyedState;
import org.apache.spark.sql.streaming.KeyedState;
/**
* ::Experimental::

View file

@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.expressions.ReduceAggregator
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.streaming.{KeyedState, KeyedStateTimeout, OutputMode}
/**
* :: Experimental ::
@ -228,13 +228,14 @@ class KeyValueGroupedDataset[K, V] private[sql](
* For a static batch Dataset, the function will be invoked once per group. For a streaming
* Dataset, the function will be invoked for each group repeatedly in every trigger, and
* updates to each group's state will be saved across invocations.
* See [[KeyedState]] for more details.
* See [[org.apache.spark.sql.streaming.KeyedState]] for more details.
*
* @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
* @tparam U The type of the output objects. Must be encodable to Spark SQL types.
* @param func Function to be called on every group.
*
* See [[Encoder]] for more details on what types are encodable to Spark SQL.
* @since 2.1.1
* @since 2.2.0
*/
@Experimental
@InterfaceStability.Evolving
@ -249,6 +250,43 @@ class KeyValueGroupedDataset[K, V] private[sql](
dataAttributes,
OutputMode.Update,
isMapGroupsWithState = true,
KeyedStateTimeout.NoTimeout,
child = logicalPlan))
}
/**
* ::Experimental::
* (Scala-specific)
* Applies the given function to each group of data, while maintaining a user-defined per-group
* state. The result Dataset will represent the objects returned by the function.
* For a static batch Dataset, the function will be invoked once per group. For a streaming
* Dataset, the function will be invoked for each group repeatedly in every trigger, and
* updates to each group's state will be saved across invocations.
* See [[org.apache.spark.sql.streaming.KeyedState]] for more details.
*
* @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
* @tparam U The type of the output objects. Must be encodable to Spark SQL types.
* @param func Function to be called on every group.
* @param timeoutConf Timeout configuration for groups that do not receive data for a while.
*
* See [[Encoder]] for more details on what types are encodable to Spark SQL.
* @since 2.2.0
*/
@Experimental
@InterfaceStability.Evolving
def mapGroupsWithState[S: Encoder, U: Encoder](
timeoutConf: KeyedStateTimeout)(
func: (K, Iterator[V], KeyedState[S]) => U): Dataset[U] = {
val flatMapFunc = (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func(key, it, s))
Dataset[U](
sparkSession,
FlatMapGroupsWithState[K, V, S, U](
flatMapFunc.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]],
groupingAttributes,
dataAttributes,
OutputMode.Update,
isMapGroupsWithState = true,
timeoutConf,
child = logicalPlan))
}
@ -269,7 +307,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
* @param outputEncoder Encoder for the output type.
*
* See [[Encoder]] for more details on what types are encodable to Spark SQL.
* @since 2.1.1
* @since 2.2.0
*/
@Experimental
@InterfaceStability.Evolving
@ -282,6 +320,38 @@ class KeyValueGroupedDataset[K, V] private[sql](
)(stateEncoder, outputEncoder)
}
/**
* ::Experimental::
* (Java-specific)
* Applies the given function to each group of data, while maintaining a user-defined per-group
* state. The result Dataset will represent the objects returned by the function.
* For a static batch Dataset, the function will be invoked once per group. For a streaming
* Dataset, the function will be invoked for each group repeatedly in every trigger, and
* updates to each group's state will be saved across invocations.
* See [[KeyedState]] for more details.
*
* @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
* @tparam U The type of the output objects. Must be encodable to Spark SQL types.
* @param func Function to be called on every group.
* @param stateEncoder Encoder for the state type.
* @param outputEncoder Encoder for the output type.
* @param timeoutConf Timeout configuration for groups that do not receive data for a while.
*
* See [[Encoder]] for more details on what types are encodable to Spark SQL.
* @since 2.2.0
*/
@Experimental
@InterfaceStability.Evolving
def mapGroupsWithState[S, U](
func: MapGroupsWithStateFunction[K, V, S, U],
stateEncoder: Encoder[S],
outputEncoder: Encoder[U],
timeoutConf: KeyedStateTimeout): Dataset[U] = {
mapGroupsWithState[S, U](
(key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s)
)(stateEncoder, outputEncoder)
}
/**
* ::Experimental::
* (Scala-specific)
@ -296,14 +366,17 @@ class KeyValueGroupedDataset[K, V] private[sql](
* @tparam U The type of the output objects. Must be encodable to Spark SQL types.
* @param func Function to be called on every group.
* @param outputMode The output mode of the function.
* @param timeoutConf Timeout configuration for groups that do not receive data for a while.
*
* See [[Encoder]] for more details on what types are encodable to Spark SQL.
* @since 2.1.1
* @since 2.2.0
*/
@Experimental
@InterfaceStability.Evolving
def flatMapGroupsWithState[S: Encoder, U: Encoder](
func: (K, Iterator[V], KeyedState[S]) => Iterator[U], outputMode: OutputMode): Dataset[U] = {
outputMode: OutputMode,
timeoutConf: KeyedStateTimeout)(
func: (K, Iterator[V], KeyedState[S]) => Iterator[U]): Dataset[U] = {
if (outputMode != OutputMode.Append && outputMode != OutputMode.Update) {
throw new IllegalArgumentException("The output mode of function should be append or update")
}
@ -315,34 +388,10 @@ class KeyValueGroupedDataset[K, V] private[sql](
dataAttributes,
outputMode,
isMapGroupsWithState = false,
timeoutConf,
child = logicalPlan))
}
/**
* ::Experimental::
* (Scala-specific)
* Applies the given function to each group of data, while maintaining a user-defined per-group
* state. The result Dataset will represent the objects returned by the function.
* For a static batch Dataset, the function will be invoked once per group. For a streaming
* Dataset, the function will be invoked for each group repeatedly in every trigger, and
* updates to each group's state will be saved across invocations.
* See [[KeyedState]] for more details.
*
* @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
* @tparam U The type of the output objects. Must be encodable to Spark SQL types.
* @param func Function to be called on every group.
* @param outputMode The output mode of the function.
*
* See [[Encoder]] for more details on what types are encodable to Spark SQL.
* @since 2.1.1
*/
@Experimental
@InterfaceStability.Evolving
def flatMapGroupsWithState[S: Encoder, U: Encoder](
func: (K, Iterator[V], KeyedState[S]) => Iterator[U], outputMode: String): Dataset[U] = {
flatMapGroupsWithState(func, InternalOutputModes(outputMode))
}
/**
* ::Experimental::
* (Java-specific)
@ -359,9 +408,10 @@ class KeyValueGroupedDataset[K, V] private[sql](
* @param outputMode The output mode of the function.
* @param stateEncoder Encoder for the state type.
* @param outputEncoder Encoder for the output type.
* @param timeoutConf Timeout configuration for groups that do not receive data for a while.
*
* See [[Encoder]] for more details on what types are encodable to Spark SQL.
* @since 2.1.1
* @since 2.2.0
*/
@Experimental
@InterfaceStability.Evolving
@ -369,41 +419,10 @@ class KeyValueGroupedDataset[K, V] private[sql](
func: FlatMapGroupsWithStateFunction[K, V, S, U],
outputMode: OutputMode,
stateEncoder: Encoder[S],
outputEncoder: Encoder[U]): Dataset[U] = {
flatMapGroupsWithState[S, U](
(key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s).asScala,
outputMode
)(stateEncoder, outputEncoder)
}
/**
* ::Experimental::
* (Java-specific)
* Applies the given function to each group of data, while maintaining a user-defined per-group
* state. The result Dataset will represent the objects returned by the function.
* For a static batch Dataset, the function will be invoked once per group. For a streaming
* Dataset, the function will be invoked for each group repeatedly in every trigger, and
* updates to each group's state will be saved across invocations.
* See [[KeyedState]] for more details.
*
* @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
* @tparam U The type of the output objects. Must be encodable to Spark SQL types.
* @param func Function to be called on every group.
* @param outputMode The output mode of the function.
* @param stateEncoder Encoder for the state type.
* @param outputEncoder Encoder for the output type.
*
* See [[Encoder]] for more details on what types are encodable to Spark SQL.
* @since 2.1.1
*/
@Experimental
@InterfaceStability.Evolving
def flatMapGroupsWithState[S, U](
func: FlatMapGroupsWithStateFunction[K, V, S, U],
outputMode: String,
stateEncoder: Encoder[S],
outputEncoder: Encoder[U]): Dataset[U] = {
flatMapGroupsWithState(func, InternalOutputModes(outputMode), stateEncoder, outputEncoder)
outputEncoder: Encoder[U],
timeoutConf: KeyedStateTimeout): Dataset[U] = {
val f = (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s).asScala
flatMapGroupsWithState[S, U](outputMode, timeoutConf)(f)(stateEncoder, outputEncoder)
}
/**

View file

@ -1,140 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import org.apache.spark.annotation.{Experimental, InterfaceStability}
import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState
/**
* :: Experimental ::
*
* Wrapper class for interacting with keyed state data in `mapGroupsWithState` and
* `flatMapGroupsWithState` operations on
* [[KeyValueGroupedDataset]].
*
* Detail description on `[map/flatMap]GroupsWithState` operation
* ------------------------------------------------------------
* Both, `mapGroupsWithState` and `flatMapGroupsWithState` in [[KeyValueGroupedDataset]]
* will invoke the user-given function on each group (defined by the grouping function in
* `Dataset.groupByKey()`) while maintaining user-defined per-group state between invocations.
* For a static batch Dataset, the function will be invoked once per group. For a streaming
* Dataset, the function will be invoked for each group repeatedly in every trigger.
* That is, in every batch of the `streaming.StreamingQuery`,
* the function will be invoked once for each group that has data in the batch.
*
* The function is invoked with following parameters.
* - The key of the group.
* - An iterator containing all the values for this key.
* - A user-defined state object set by previous invocations of the given function.
* In case of a batch Dataset, there is only one invocation and state object will be empty as
* there is no prior state. Essentially, for batch Datasets, `[map/flatMap]GroupsWithState`
* is equivalent to `[map/flatMap]Groups`.
*
* Important points to note about the function.
* - In a trigger, the function will be called only the groups present in the batch. So do not
* assume that the function will be called in every trigger for every group that has state.
* - There is no guaranteed ordering of values in the iterator in the function, neither with
* batch, nor with streaming Datasets.
* - All the data will be shuffled before applying the function.
*
* Important points to note about using KeyedState.
* - The value of the state cannot be null. So updating state with null will throw
* `IllegalArgumentException`.
* - Operations on `KeyedState` are not thread-safe. This is to avoid memory barriers.
* - If `remove()` is called, then `exists()` will return `false`,
* `get()` will throw `NoSuchElementException` and `getOption()` will return `None`
* - After that, if `update(newState)` is called, then `exists()` will again return `true`,
* `get()` and `getOption()`will return the updated value.
*
* Scala example of using KeyedState in `mapGroupsWithState`:
* {{{
* // A mapping function that maintains an integer state for string keys and returns a string.
* def mappingFunction(key: String, value: Iterator[Int], state: KeyedState[Int]): String = {
* // Check if state exists
* if (state.exists) {
* val existingState = state.get // Get the existing state
* val shouldRemove = ... // Decide whether to remove the state
* if (shouldRemove) {
* state.remove() // Remove the state
* } else {
* val newState = ...
* state.update(newState) // Set the new state
* }
* } else {
* val initialState = ...
* state.update(initialState) // Set the initial state
* }
* ... // return something
* }
*
* }}}
*
* Java example of using `KeyedState`:
* {{{
* // A mapping function that maintains an integer state for string keys and returns a string.
* MapGroupsWithStateFunction<String, Integer, Integer, String> mappingFunction =
* new MapGroupsWithStateFunction<String, Integer, Integer, String>() {
*
* @Override
* public String call(String key, Iterator<Integer> value, KeyedState<Integer> state) {
* if (state.exists()) {
* int existingState = state.get(); // Get the existing state
* boolean shouldRemove = ...; // Decide whether to remove the state
* if (shouldRemove) {
* state.remove(); // Remove the state
* } else {
* int newState = ...;
* state.update(newState); // Set the new state
* }
* } else {
* int initialState = ...; // Set the initial state
* state.update(initialState);
* }
* ... // return something
* }
* };
* }}}
*
* @tparam S User-defined type of the state to be stored for each key. Must be encodable into
* Spark SQL types (see [[Encoder]] for more details).
* @since 2.1.1
*/
@Experimental
@InterfaceStability.Evolving
trait KeyedState[S] extends LogicalKeyedState[S] {
/** Whether state exists or not. */
def exists: Boolean
/** Get the state value if it exists, or throw NoSuchElementException. */
@throws[NoSuchElementException]("when state does not exist")
def get: S
/** Get the state value as a scala Option. */
def getOption: Option[S]
/**
* Update the value of the state. Note that `null` is not a valid value, and it throws
* IllegalArgumentException.
*/
@throws[IllegalArgumentException]("when updating with null")
def update(newState: S): Unit
/** Remove this keyed state. */
def remove(): Unit
}

View file

@ -329,22 +329,14 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
* Strategy to convert [[FlatMapGroupsWithState]] logical operator to physical operator
* in streaming plans. Conversion for batch plans is handled by [[BasicOperators]].
*/
object MapGroupsWithStateStrategy extends Strategy {
object FlatMapGroupsWithStateStrategy extends Strategy {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case FlatMapGroupsWithState(
f,
keyDeser,
valueDeser,
groupAttr,
dataAttr,
outputAttr,
stateDeser,
stateSer,
outputMode,
child,
_) =>
func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateEnc, outputMode, _,
timeout, child) =>
val execPlan = FlatMapGroupsWithStateExec(
f, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateDeser, stateSer,
func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateEnc, outputMode,
timeout, batchTimestampMs = KeyedStateImpl.NO_BATCH_PROCESSING_TIMESTAMP,
planLater(child))
execPlan :: Nil
case _ =>
@ -392,7 +384,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.MapGroups(f, key, value, grouping, data, objAttr, child) =>
execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil
case logical.FlatMapGroupsWithState(
f, key, value, grouping, data, output, _, _, _, child, _) =>
f, key, value, grouping, data, output, _, _, _, _, child) =>
execution.MapGroupsExec(f, key, value, grouping, data, output, planLater(child)) :: Nil
case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) =>
execution.CoGroupExec(

View file

@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.debug._
import org.apache.spark.sql.execution.streaming.IncrementalExecution
import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types._
@ -106,7 +106,8 @@ case class ExplainCommand(
if (logicalPlan.isStreaming) {
// This is used only by explaining `Dataset/DataFrame` created by `spark.readStream`, so the
// output mode does not matter since there is no `Sink`.
new IncrementalExecution(sparkSession, logicalPlan, OutputMode.Append(), "<unknown>", 0, 0)
new IncrementalExecution(
sparkSession, logicalPlan, OutputMode.Append(), "<unknown>", 0, OffsetSeqMetadata(0, 0))
} else {
sparkSession.sessionState.executePlan(logicalPlan)
}

View file

@ -0,0 +1,258 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.streaming
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeReference, Expression, Literal, SortOrder, SpecificInternalRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalKeyedState, ProcessingTimeTimeout}
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.streaming.{KeyedStateTimeout, OutputMode}
import org.apache.spark.sql.types.{BooleanType, IntegerType}
import org.apache.spark.util.CompletionIterator
/**
* Physical operator for executing `FlatMapGroupsWithState.`
*
* @param func function called on each group
* @param keyDeserializer used to extract the key object for each group.
* @param valueDeserializer used to extract the items in the iterator from an input row.
* @param groupingAttributes used to group the data
* @param dataAttributes used to read the data
* @param outputObjAttr used to define the output object
* @param stateEncoder used to serialize/deserialize state before calling `func`
* @param outputMode the output mode of `func`
* @param timeout used to timeout groups that have not received data in a while
* @param batchTimestampMs processing timestamp of the current batch.
*/
case class FlatMapGroupsWithStateExec(
func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any],
keyDeserializer: Expression,
valueDeserializer: Expression,
groupingAttributes: Seq[Attribute],
dataAttributes: Seq[Attribute],
outputObjAttr: Attribute,
stateId: Option[OperatorStateId],
stateEncoder: ExpressionEncoder[Any],
outputMode: OutputMode,
timeout: KeyedStateTimeout,
batchTimestampMs: Long,
child: SparkPlan) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter {
private val isTimeoutEnabled = timeout == ProcessingTimeTimeout
private val timestampTimeoutAttribute =
AttributeReference("timeoutTimestamp", dataType = IntegerType, nullable = false)()
private val stateAttributes: Seq[Attribute] = {
val encSchemaAttribs = stateEncoder.schema.toAttributes
if (isTimeoutEnabled) encSchemaAttribs :+ timestampTimeoutAttribute else encSchemaAttribs
}
import KeyedStateImpl._
/** Distribute by grouping attributes */
override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(groupingAttributes) :: Nil
/** Ordering needed for using GroupingIterator */
override def requiredChildOrdering: Seq[Seq[SortOrder]] =
Seq(groupingAttributes.map(SortOrder(_, Ascending)))
override protected def doExecute(): RDD[InternalRow] = {
metrics // force lazy init at driver
child.execute().mapPartitionsWithStateStore[InternalRow](
getStateId.checkpointLocation,
getStateId.operatorId,
getStateId.batchId,
groupingAttributes.toStructType,
stateAttributes.toStructType,
sqlContext.sessionState,
Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iterator) =>
val updater = new StateStoreUpdater(store)
// Generate a iterator that returns the rows grouped by the grouping function
// Note that this code ensures that the filtering for timeout occurs only after
// all the data has been processed. This is to ensure that the timeout information of all
// the keys with data is updated before they are processed for timeouts.
val outputIterator =
updater.updateStateForKeysWithData(iterator) ++ updater.updateStateForTimedOutKeys()
// Return an iterator of all the rows generated by all the keys, such that when fully
// consumed, all the state updates will be committed by the state store
CompletionIterator[InternalRow, Iterator[InternalRow]](
outputIterator,
{
store.commit()
longMetric("numTotalStateRows") += store.numKeys()
}
)
}
}
/** Helper class to update the state store */
class StateStoreUpdater(store: StateStore) {
// Converters for translating input keys, values, output data between rows and Java objects
private val getKeyObj =
ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes)
private val getValueObj =
ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
// Converter for translating state rows to Java objects
private val getStateObjFromRow = ObjectOperator.deserializeRowToObject(
stateEncoder.resolveAndBind().deserializer, stateAttributes)
// Converter for translating state Java objects to rows
private val stateSerializer = {
val encoderSerializer = stateEncoder.namedExpressions
if (isTimeoutEnabled) {
encoderSerializer :+ Literal(KeyedStateImpl.TIMEOUT_TIMESTAMP_NOT_SET)
} else {
encoderSerializer
}
}
private val getStateRowFromObj = ObjectOperator.serializeObjectToRow(stateSerializer)
// Index of the additional metadata fields in the state row
private val timeoutTimestampIndex = stateAttributes.indexOf(timestampTimeoutAttribute)
// Metrics
private val numUpdatedStateRows = longMetric("numUpdatedStateRows")
private val numOutputRows = longMetric("numOutputRows")
/**
* For every group, get the key, values and corresponding state and call the function,
* and return an iterator of rows
*/
def updateStateForKeysWithData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = {
val groupedIter = GroupedIterator(dataIter, groupingAttributes, child.output)
groupedIter.flatMap { case (keyRow, valueRowIter) =>
val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow]
callFunctionAndUpdateState(
keyUnsafeRow,
valueRowIter,
store.get(keyUnsafeRow),
hasTimedOut = false)
}
}
/** Find the groups that have timeout set and are timing out right now, and call the function */
def updateStateForTimedOutKeys(): Iterator[InternalRow] = {
if (isTimeoutEnabled) {
val timingOutKeys = store.filter { case (_, stateRow) =>
val timeoutTimestamp = getTimeoutTimestamp(stateRow)
timeoutTimestamp != TIMEOUT_TIMESTAMP_NOT_SET && timeoutTimestamp < batchTimestampMs
}
timingOutKeys.flatMap { case (keyRow, stateRow) =>
callFunctionAndUpdateState(
keyRow,
Iterator.empty,
Some(stateRow),
hasTimedOut = true)
}
} else Iterator.empty
}
/**
* Call the user function on a key's data, update the state store, and return the return data
* iterator. Note that the store updating is lazy, that is, the store will be updated only
* after the returned iterator is fully consumed.
*/
private def callFunctionAndUpdateState(
keyRow: UnsafeRow,
valueRowIter: Iterator[InternalRow],
prevStateRowOption: Option[UnsafeRow],
hasTimedOut: Boolean): Iterator[InternalRow] = {
val keyObj = getKeyObj(keyRow) // convert key to objects
val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects
val stateObjOption = getStateObj(prevStateRowOption)
val keyedState = new KeyedStateImpl(
stateObjOption, batchTimestampMs, isTimeoutEnabled, hasTimedOut)
// Call function, get the returned objects and convert them to rows
val mappedIterator = func(keyObj, valueObjIter, keyedState).map { obj =>
numOutputRows += 1
getOutputRow(obj)
}
// When the iterator is consumed, then write changes to state
def onIteratorCompletion: Unit = {
// Has the timeout information changed
if (keyedState.hasRemoved) {
store.remove(keyRow)
numUpdatedStateRows += 1
} else {
val previousTimeoutTimestamp = prevStateRowOption match {
case Some(row) => getTimeoutTimestamp(row)
case None => TIMEOUT_TIMESTAMP_NOT_SET
}
val stateRowToWrite = if (keyedState.hasUpdated) {
getStateRow(keyedState.get)
} else {
prevStateRowOption.orNull
}
val hasTimeoutChanged = keyedState.getTimeoutTimestamp != previousTimeoutTimestamp
val shouldWriteState = keyedState.hasUpdated || hasTimeoutChanged
if (shouldWriteState) {
if (stateRowToWrite == null) {
// This should never happen because checks in KeyedStateImpl should avoid cases
// where empty state would need to be written
throw new IllegalStateException(
"Attempting to write empty state")
}
setTimeoutTimestamp(stateRowToWrite, keyedState.getTimeoutTimestamp)
store.put(keyRow.copy(), stateRowToWrite.copy())
numUpdatedStateRows += 1
}
}
}
// Return an iterator of rows such that fully consumed, the updated state value will be saved
CompletionIterator[InternalRow, Iterator[InternalRow]](mappedIterator, onIteratorCompletion)
}
/** Returns the state as Java object if defined */
def getStateObj(stateRowOption: Option[UnsafeRow]): Option[Any] = {
stateRowOption.map(getStateObjFromRow)
}
/** Returns the row for an updated state */
def getStateRow(obj: Any): UnsafeRow = {
getStateRowFromObj(obj)
}
/** Returns the timeout timestamp of a state row is set */
def getTimeoutTimestamp(stateRow: UnsafeRow): Long = {
if (isTimeoutEnabled) stateRow.getLong(timeoutTimestampIndex) else TIMEOUT_TIMESTAMP_NOT_SET
}
/** Set the timestamp in a state row */
def setTimeoutTimestamp(stateRow: UnsafeRow, timeoutTimestamps: Long): Unit = {
if (isTimeoutEnabled) stateRow.setLong(timeoutTimestampIndex, timeoutTimestamps)
}
}
}

View file

@ -37,13 +37,13 @@ class IncrementalExecution(
val outputMode: OutputMode,
val checkpointLocation: String,
val currentBatchId: Long,
val currentEventTimeWatermark: Long)
offsetSeqMetadata: OffsetSeqMetadata)
extends QueryExecution(sparkSession, logicalPlan) with Logging {
// TODO: make this always part of planning.
val streamingExtraStrategies =
sparkSession.sessionState.planner.StatefulAggregationStrategy +:
sparkSession.sessionState.planner.MapGroupsWithStateStrategy +:
sparkSession.sessionState.planner.FlatMapGroupsWithStateStrategy +:
sparkSession.sessionState.planner.StreamingRelationStrategy +:
sparkSession.sessionState.planner.StreamingDeduplicationStrategy +:
sparkSession.sessionState.experimentalMethods.extraStrategies
@ -88,12 +88,13 @@ class IncrementalExecution(
keys,
Some(stateId),
Some(outputMode),
Some(currentEventTimeWatermark),
Some(offsetSeqMetadata.batchWatermarkMs),
agg.withNewChildren(
StateStoreRestoreExec(
keys,
Some(stateId),
child) :: Nil))
case StreamingDeduplicateExec(keys, child, None, None) =>
val stateId =
OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId)
@ -102,13 +103,12 @@ class IncrementalExecution(
keys,
child,
Some(stateId),
Some(currentEventTimeWatermark))
case FlatMapGroupsWithStateExec(
f, kDeser, vDeser, group, data, output, None, stateDeser, stateSer, child) =>
Some(offsetSeqMetadata.batchWatermarkMs))
case m: FlatMapGroupsWithStateExec =>
val stateId =
OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId)
FlatMapGroupsWithStateExec(
f, kDeser, vDeser, group, data, output, Some(stateId), stateDeser, stateSer, child)
m.copy(stateId = Some(stateId), batchTimestampMs = offsetSeqMetadata.batchTimestampMs)
}
}

View file

@ -17,15 +17,37 @@
package org.apache.spark.sql.execution.streaming
import org.apache.spark.sql.KeyedState
import org.apache.commons.lang3.StringUtils
/** Internal implementation of the [[KeyedState]] interface. Methods are not thread-safe. */
private[sql] class KeyedStateImpl[S](optionalValue: Option[S]) extends KeyedState[S] {
import org.apache.spark.sql.streaming.KeyedState
import org.apache.spark.unsafe.types.CalendarInterval
/**
* Internal implementation of the [[KeyedState]] interface. Methods are not thread-safe.
* @param optionalValue Optional value of the state
* @param batchProcessingTimeMs Processing time of current batch, used to calculate timestamp
* for processing time timeouts
* @param isTimeoutEnabled Whether timeout is enabled. This will be used to check whether the user
* is allowed to configure timeouts.
* @param hasTimedOut Whether the key for which this state wrapped is being created is
* getting timed out or not.
*/
private[sql] class KeyedStateImpl[S](
optionalValue: Option[S],
batchProcessingTimeMs: Long,
isTimeoutEnabled: Boolean,
override val hasTimedOut: Boolean) extends KeyedState[S] {
import KeyedStateImpl._
// Constructor to create dummy state when using mapGroupsWithState in a batch query
def this(optionalValue: Option[S]) = this(
optionalValue, -1, isTimeoutEnabled = false, hasTimedOut = false)
private var value: S = optionalValue.getOrElse(null.asInstanceOf[S])
private var defined: Boolean = optionalValue.isDefined
private var updated: Boolean = false
// whether value has been updated (but not removed)
private var updated: Boolean = false // whether value has been updated (but not removed)
private var removed: Boolean = false // whether value has been removed
private var timeoutTimestamp: Long = TIMEOUT_TIMESTAMP_NOT_SET
// ========= Public API =========
override def exists: Boolean = defined
@ -60,6 +82,55 @@ private[sql] class KeyedStateImpl[S](optionalValue: Option[S]) extends KeyedStat
defined = false
updated = false
removed = true
timeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET
}
override def setTimeoutDuration(durationMs: Long): Unit = {
if (!isTimeoutEnabled) {
throw new UnsupportedOperationException(
"Cannot set timeout information without enabling timeout in map/flatMapGroupsWithState")
}
if (!defined) {
throw new IllegalStateException(
"Cannot set timeout information without any state value, " +
"state has either not been initialized, or has already been removed")
}
if (durationMs <= 0) {
throw new IllegalArgumentException("Timeout duration must be positive")
}
if (!removed && batchProcessingTimeMs != NO_BATCH_PROCESSING_TIMESTAMP) {
timeoutTimestamp = durationMs + batchProcessingTimeMs
} else {
// This is being called in a batch query, hence no processing timestamp.
// Just ignore any attempts to set timeout.
}
}
override def setTimeoutDuration(duration: String): Unit = {
if (StringUtils.isBlank(duration)) {
throw new IllegalArgumentException(
"The window duration, slide duration and start time cannot be null or blank.")
}
val intervalString = if (duration.startsWith("interval")) {
duration
} else {
"interval " + duration
}
val cal = CalendarInterval.fromString(intervalString)
if (cal == null) {
throw new IllegalArgumentException(
s"The provided duration ($duration) is not valid.")
}
if (cal.milliseconds < 0 || cal.months < 0) {
throw new IllegalArgumentException("Timeout duration must be positive")
}
val delayMs = {
val millisPerMonth = CalendarInterval.MICROS_PER_DAY / 1000 * 31
cal.milliseconds + cal.months * millisPerMonth
}
setTimeoutDuration(delayMs)
}
override def toString: String = {
@ -69,12 +140,21 @@ private[sql] class KeyedStateImpl[S](optionalValue: Option[S]) extends KeyedStat
// ========= Internal API =========
/** Whether the state has been marked for removing */
def isRemoved: Boolean = {
removed
}
def hasRemoved: Boolean = removed
/** Whether the state has been been updated */
def isUpdated: Boolean = {
updated
}
/** Whether the state has been updated */
def hasUpdated: Boolean = updated
/** Return timeout timestamp or `TIMEOUT_TIMESTAMP_NOT_SET` if not set */
def getTimeoutTimestamp: Long = timeoutTimestamp
}
private[sql] object KeyedStateImpl {
// Value used in the state row to represent the lack of any timeout timestamp
val TIMEOUT_TIMESTAMP_NOT_SET = -1L
// Value to represent that no batch processing timestamp is passed to KeyedStateImpl. This is
// used in batch queries where there are no streaming batches and timeouts.
val NO_BATCH_PROCESSING_TIMESTAMP = -1L
}

View file

@ -590,7 +590,7 @@ class StreamExecution(
outputMode,
checkpointFile("state"),
currentBatchId,
offsetSeqMetadata.batchWatermarkMs)
offsetSeqMetadata)
lastExecution.executedPlan // Force the lazy generation of execution plan
}

View file

@ -73,7 +73,12 @@ private[state] class HDFSBackedStateStoreProvider(
hadoopConf: Configuration
) extends StateStoreProvider with Logging {
type MapType = java.util.HashMap[UnsafeRow, UnsafeRow]
// ConcurrentHashMap is used because it generates fail-safe iterators on filtering
// - The iterator is weakly consistent with the map, i.e., iterator's data reflect the values in
// the map when the iterator was created
// - Any updates to the map while iterating through the filtered iterator does not throw
// java.util.ConcurrentModificationException
type MapType = java.util.concurrent.ConcurrentHashMap[UnsafeRow, UnsafeRow]
/** Implementation of [[StateStore]] API which is backed by a HDFS-compatible file system */
class HDFSBackedStateStore(val version: Long, mapToUpdate: MapType)
@ -99,6 +104,16 @@ private[state] class HDFSBackedStateStoreProvider(
Option(mapToUpdate.get(key))
}
override def filter(
condition: (UnsafeRow, UnsafeRow) => Boolean): Iterator[(UnsafeRow, UnsafeRow)] = {
mapToUpdate
.entrySet
.asScala
.iterator
.filter { entry => condition(entry.getKey, entry.getValue) }
.map { entry => (entry.getKey, entry.getValue) }
}
override def put(key: UnsafeRow, value: UnsafeRow): Unit = {
verify(state == UPDATING, "Cannot put after already committed or aborted")
@ -227,7 +242,7 @@ private[state] class HDFSBackedStateStoreProvider(
}
override def toString(): String = {
s"HDFSStateStore[id = (op=${id.operatorId}, part=${id.partitionId}), dir = $baseDir]"
s"HDFSStateStore[id=(op=${id.operatorId},part=${id.partitionId}),dir=$baseDir]"
}
}

View file

@ -50,6 +50,15 @@ trait StateStore {
/** Get the current value of a key. */
def get(key: UnsafeRow): Option[UnsafeRow]
/**
* Return an iterator of key-value pairs that satisfy a certain condition.
* Note that the iterator must be fail-safe towards modification to the store, that is,
* it must be based on the snapshot of store the time of this call, and any change made to the
* store while iterating through iterator should not cause the iterator to fail or have
* any affect on the values in the iterator.
*/
def filter(condition: (UnsafeRow, UnsafeRow) => Boolean): Iterator[(UnsafeRow, UnsafeRow)]
/** Put a new value for a key. */
def put(key: UnsafeRow, value: UnsafeRow): Unit

View file

@ -19,17 +19,18 @@ package org.apache.spark.sql.execution.streaming
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate}
import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalKeyedState}
import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalKeyedState, ProcessingTimeTimeout}
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.{DataType, NullType, StructType}
import org.apache.spark.sql.streaming.{KeyedStateTimeout, OutputMode}
import org.apache.spark.sql.types._
import org.apache.spark.util.CompletionIterator
@ -256,94 +257,6 @@ case class StateStoreSaveExec(
override def outputPartitioning: Partitioning = child.outputPartitioning
}
/** Physical operator for executing streaming flatMapGroupsWithState. */
case class FlatMapGroupsWithStateExec(
func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any],
keyDeserializer: Expression,
valueDeserializer: Expression,
groupingAttributes: Seq[Attribute],
dataAttributes: Seq[Attribute],
outputObjAttr: Attribute,
stateId: Option[OperatorStateId],
stateDeserializer: Expression,
stateSerializer: Seq[NamedExpression],
child: SparkPlan) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter {
override def outputPartitioning: Partitioning = child.outputPartitioning
/** Distribute by grouping attributes */
override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(groupingAttributes) :: Nil
/** Ordering needed for using GroupingIterator */
override def requiredChildOrdering: Seq[Seq[SortOrder]] =
Seq(groupingAttributes.map(SortOrder(_, Ascending)))
override protected def doExecute(): RDD[InternalRow] = {
child.execute().mapPartitionsWithStateStore[InternalRow](
getStateId.checkpointLocation,
getStateId.operatorId,
getStateId.batchId,
groupingAttributes.toStructType,
child.output.toStructType,
sqlContext.sessionState,
Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) =>
val numTotalStateRows = longMetric("numTotalStateRows")
val numUpdatedStateRows = longMetric("numUpdatedStateRows")
val numOutputRows = longMetric("numOutputRows")
// Generate a iterator that returns the rows grouped by the grouping function
val groupedIter = GroupedIterator(iter, groupingAttributes, child.output)
// Converters to and from object and rows
val getKeyObj = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes)
val getValueObj = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
val getStateObj =
ObjectOperator.deserializeRowToObject(stateDeserializer)
val outputStateObj = ObjectOperator.serializeObjectToRow(stateSerializer)
// For every group, get the key, values and corresponding state and call the function,
// and return an iterator of rows
val allRowsIterator = groupedIter.flatMap { case (keyRow, valueRowIter) =>
val key = keyRow.asInstanceOf[UnsafeRow]
val keyObj = getKeyObj(keyRow) // convert key to objects
val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects
val stateObjOption = store.get(key).map(getStateObj) // get existing state if any
val wrappedState = new KeyedStateImpl(stateObjOption)
val mappedIterator = func(keyObj, valueObjIter, wrappedState).map { obj =>
numOutputRows += 1
getOutputRow(obj) // convert back to rows
}
// Return an iterator of rows generated this key,
// such that fully consumed, the updated state value will be saved
CompletionIterator[InternalRow, Iterator[InternalRow]](
mappedIterator, {
// When the iterator is consumed, then write changes to state
if (wrappedState.isRemoved) {
store.remove(key)
numUpdatedStateRows += 1
} else if (wrappedState.isUpdated) {
store.put(key, outputStateObj(wrappedState.get))
numUpdatedStateRows += 1
}
})
}
// Return an iterator of all the rows generated by all the keys, such that when fully
// consumer, all the state updates will be committed by the state store
CompletionIterator[InternalRow, Iterator[InternalRow]](allRowsIterator, {
store.commit()
numTotalStateRows += store.numKeys()
})
}
}
}
/** Physical operator for executing streaming Deduplicate. */
case class StreamingDeduplicateExec(
keyExpressions: Seq[Attribute],

View file

@ -0,0 +1,214 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.streaming
import org.apache.spark.annotation.{Experimental, InterfaceStability}
import org.apache.spark.sql.{Encoder, KeyValueGroupedDataset}
import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState
/**
* :: Experimental ::
*
* Wrapper class for interacting with keyed state data in `mapGroupsWithState` and
* `flatMapGroupsWithState` operations on
* [[KeyValueGroupedDataset]].
*
* Detail description on `[map/flatMap]GroupsWithState` operation
* --------------------------------------------------------------
* Both, `mapGroupsWithState` and `flatMapGroupsWithState` in [[KeyValueGroupedDataset]]
* will invoke the user-given function on each group (defined by the grouping function in
* `Dataset.groupByKey()`) while maintaining user-defined per-group state between invocations.
* For a static batch Dataset, the function will be invoked once per group. For a streaming
* Dataset, the function will be invoked for each group repeatedly in every trigger.
* That is, in every batch of the `streaming.StreamingQuery`,
* the function will be invoked once for each group that has data in the trigger. Furthermore,
* if timeout is set, then the function will invoked on timed out keys (more detail below).
*
* The function is invoked with following parameters.
* - The key of the group.
* - An iterator containing all the values for this key.
* - A user-defined state object set by previous invocations of the given function.
* In case of a batch Dataset, there is only one invocation and state object will be empty as
* there is no prior state. Essentially, for batch Datasets, `[map/flatMap]GroupsWithState`
* is equivalent to `[map/flatMap]Groups` and any updates to the state and/or timeouts have
* no effect.
*
* Important points to note about the function.
* - In a trigger, the function will be called only the groups present in the batch. So do not
* assume that the function will be called in every trigger for every group that has state.
* - There is no guaranteed ordering of values in the iterator in the function, neither with
* batch, nor with streaming Datasets.
* - All the data will be shuffled before applying the function.
* - If timeout is set, then the function will also be called with no values.
* See more details on KeyedStateTimeout` below.
*
* Important points to note about using `KeyedState`.
* - The value of the state cannot be null. So updating state with null will throw
* `IllegalArgumentException`.
* - Operations on `KeyedState` are not thread-safe. This is to avoid memory barriers.
* - If `remove()` is called, then `exists()` will return `false`,
* `get()` will throw `NoSuchElementException` and `getOption()` will return `None`
* - After that, if `update(newState)` is called, then `exists()` will again return `true`,
* `get()` and `getOption()`will return the updated value.
*
* Important points to note about using `KeyedStateTimeout`.
* - The timeout type is a global param across all the keys (set as `timeout` param in
* `[map|flatMap]GroupsWithState`, but the exact timeout duration is configurable per key
* (by calling `setTimeout...()` in `KeyedState`).
* - When the timeout occurs for a key, the function is called with no values, and
* `KeyedState.hasTimedOut()` set to true.
* - The timeout is reset for key every time the function is called on the key, that is,
* when the key has new data, or the key has timed out. So the user has to set the timeout
* duration every time the function is called, otherwise there will not be any timeout set.
* - Guarantees provided on processing-time-based timeout of key, when timeout duration is D ms:
* - Timeout will never be called before real clock time has advanced by D ms
* - Timeout will be called eventually when there is a trigger in the query
* (i.e. after D ms). So there is a no strict upper bound on when the timeout would occur.
* For example, the trigger interval of the query will affect when the timeout is actually hit.
* If there is no data in the stream (for any key) for a while, then their will not be
* any trigger and timeout will not be hit until there is data.
*
* Scala example of using KeyedState in `mapGroupsWithState`:
* {{{
* // A mapping function that maintains an integer state for string keys and returns a string.
* // Additionally, it sets a timeout to remove the state if it has not received data for an hour.
* def mappingFunction(key: String, value: Iterator[Int], state: KeyedState[Int]): String = {
*
* if (state.hasTimedOut) { // If called when timing out, remove the state
* state.remove()
*
* } else if (state.exists) { // If state exists, use it for processing
* val existingState = state.get // Get the existing state
* val shouldRemove = ... // Decide whether to remove the state
* if (shouldRemove) {
* state.remove() // Remove the state
*
* } else {
* val newState = ...
* state.update(newState) // Set the new state
* state.setTimeoutDuration("1 hour") // Set the timeout
* }
*
* } else {
* val initialState = ...
* state.update(initialState) // Set the initial state
* state.setTimeoutDuration("1 hour") // Set the timeout
* }
* ...
* // return something
* }
*
* dataset
* .groupByKey(...)
* .mapGroupsWithState(KeyedStateTimeout.ProcessingTimeTimeout)(mappingFunction)
* }}}
*
* Java example of using `KeyedState`:
* {{{
* // A mapping function that maintains an integer state for string keys and returns a string.
* // Additionally, it sets a timeout to remove the state if it has not received data for an hour.
* MapGroupsWithStateFunction<String, Integer, Integer, String> mappingFunction =
* new MapGroupsWithStateFunction<String, Integer, Integer, String>() {
*
* @Override
* public String call(String key, Iterator<Integer> value, KeyedState<Integer> state) {
* if (state.hasTimedOut()) { // If called when timing out, remove the state
* state.remove();
*
* } else if (state.exists()) { // If state exists, use it for processing
* int existingState = state.get(); // Get the existing state
* boolean shouldRemove = ...; // Decide whether to remove the state
* if (shouldRemove) {
* state.remove(); // Remove the state
*
* } else {
* int newState = ...;
* state.update(newState); // Set the new state
* state.setTimeoutDuration("1 hour"); // Set the timeout
* }
*
* } else {
* int initialState = ...; // Set the initial state
* state.update(initialState);
* state.setTimeoutDuration("1 hour"); // Set the timeout
* }
* ...
* // return something
* }
* };
*
* dataset
* .groupByKey(...)
* .mapGroupsWithState(
* mappingFunction, Encoders.INT, Encoders.STRING, KeyedStateTimeout.ProcessingTimeTimeout);
* }}}
*
* @tparam S User-defined type of the state to be stored for each key. Must be encodable into
* Spark SQL types (see [[Encoder]] for more details).
* @since 2.2.0
*/
@Experimental
@InterfaceStability.Evolving
trait KeyedState[S] extends LogicalKeyedState[S] {
/** Whether state exists or not. */
def exists: Boolean
/** Get the state value if it exists, or throw NoSuchElementException. */
@throws[NoSuchElementException]("when state does not exist")
def get: S
/** Get the state value as a scala Option. */
def getOption: Option[S]
/**
* Update the value of the state. Note that `null` is not a valid value, and it throws
* IllegalArgumentException.
*/
@throws[IllegalArgumentException]("when updating with null")
def update(newState: S): Unit
/** Remove this keyed state. Note that this resets any timeout configuration as well. */
def remove(): Unit
/**
* Whether the function has been called because the key has timed out.
* @note This can return true only when timeouts are enabled in `[map/flatmap]GroupsWithStates`.
*/
def hasTimedOut: Boolean
/**
* Set the timeout duration in ms for this key.
* @note Timeouts must be enabled in `[map/flatmap]GroupsWithStates`.
*/
@throws[IllegalArgumentException]("if 'durationMs' is not positive")
@throws[IllegalStateException]("when state is either not initialized, or already removed")
@throws[UnsupportedOperationException](
"if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query")
def setTimeoutDuration(durationMs: Long): Unit
/**
* Set the timeout duration for this key as a string. For example, "1 hour", "2 days", etc.
* @note, Timeouts must be enabled in `[map/flatmap]GroupsWithStates`.
*/
@throws[IllegalArgumentException]("if 'duration' is not a valid duration")
@throws[IllegalStateException]("when state is either not initialized, or already removed")
@throws[UnsupportedOperationException](
"if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query")
def setTimeoutDuration(duration: String): Unit
}

View file

@ -23,6 +23,7 @@ import java.sql.Date;
import java.sql.Timestamp;
import java.util.*;
import org.apache.spark.sql.streaming.KeyedStateTimeout;
import org.apache.spark.sql.streaming.OutputMode;
import scala.Tuple2;
import scala.Tuple3;
@ -208,7 +209,8 @@ public class JavaDatasetSuite implements Serializable {
},
OutputMode.Append(),
Encoders.LONG(),
Encoders.STRING());
Encoders.STRING(),
KeyedStateTimeout.NoTimeout());
Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped2.collectAsList()));

View file

@ -123,6 +123,30 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
assert(getDataFromFiles(provider, version = 2) === Set("b" -> 2, "c" -> 4))
}
test("filter and concurrent updates") {
val provider = newStoreProvider()
// Verify state before starting a new set of updates
assert(provider.latestIterator.isEmpty)
val store = provider.getStore(0)
put(store, "a", 1)
put(store, "b", 2)
// Updates should work while iterating of filtered entries
val filtered = store.filter { case (keyRow, _) => rowToString(keyRow) == "a" }
filtered.foreach { case (keyRow, valueRow) =>
store.put(keyRow, intToRow(rowToInt(valueRow) + 1))
}
assert(get(store, "a") === Some(2))
// Removes should work while iterating of filtered entries
val filtered2 = store.filter { case (keyRow, _) => rowToString(keyRow) == "b" }
filtered2.foreach { case (keyRow, _) =>
store.remove(keyRow)
}
assert(get(store, "b") === None)
}
test("updates iterator with all combos of updates and removes") {
val provider = newStoreProvider()
var currentVersion: Int = 0

View file

@ -17,20 +17,33 @@
package org.apache.spark.sql.streaming
import java.util
import java.util.concurrent.ConcurrentHashMap
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.SparkException
import org.apache.spark.sql.KeyedState
import org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsWithState
import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.execution.streaming.{KeyedStateImpl, MemoryStream}
import org.apache.spark.sql.execution.streaming.state.StateStore
import org.apache.spark.sql.execution.RDDScanExec
import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, KeyedStateImpl, MemoryStream}
import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StoreUpdate}
import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite.MemoryStateStore
import org.apache.spark.sql.types.{DataType, IntegerType}
/** Class to check custom state types */
case class RunningCount(count: Long)
case class Result(key: Long, count: Int)
class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
import testImplicits._
import KeyedStateImpl._
override def afterAll(): Unit = {
super.afterAll()
@ -54,8 +67,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
}
}
assert(state.getOption === expectedData)
assert(state.isUpdated === shouldBeUpdated)
assert(state.isRemoved === shouldBeRemoved)
assert(state.hasUpdated === shouldBeUpdated)
assert(state.hasRemoved === shouldBeRemoved)
}
// Updating empty state
@ -83,6 +96,79 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
}
}
test("KeyedState - setTimeoutDuration, hasTimedOut") {
import KeyedStateImpl._
var state: KeyedStateImpl[Int] = null
// When isTimeoutEnabled = false, then setTimeoutDuration() is not allowed
for (initState <- Seq(None, Some(5))) {
// for different initial state
state = new KeyedStateImpl(initState, 1000, isTimeoutEnabled = false, hasTimedOut = false)
assert(state.hasTimedOut === false)
assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET)
intercept[UnsupportedOperationException] {
state.setTimeoutDuration(1000)
}
intercept[UnsupportedOperationException] {
state.setTimeoutDuration("1 day")
}
assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET)
}
def testTimeoutNotAllowed(): Unit = {
intercept[IllegalStateException] {
state.setTimeoutDuration(1000)
}
assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET)
intercept[IllegalStateException] {
state.setTimeoutDuration("2 second")
}
assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET)
}
// When isTimeoutEnabled = true, then setTimeoutDuration() is not allowed until the
// state is be defined
state = new KeyedStateImpl(None, 1000, isTimeoutEnabled = true, hasTimedOut = false)
assert(state.hasTimedOut === false)
assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET)
testTimeoutNotAllowed()
// After state has been set, setTimeoutDuration() is allowed, and
// getTimeoutTimestamp returned correct timestamp
state.update(5)
assert(state.hasTimedOut === false)
assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET)
state.setTimeoutDuration(1000)
assert(state.getTimeoutTimestamp === 2000)
state.setTimeoutDuration("2 second")
assert(state.getTimeoutTimestamp === 3000)
assert(state.hasTimedOut === false)
// setTimeoutDuration() with negative values or 0 is not allowed
def testIllegalTimeout(body: => Unit): Unit = {
intercept[IllegalArgumentException] { body }
assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET)
}
state = new KeyedStateImpl(Some(5), 1000, isTimeoutEnabled = true, hasTimedOut = false)
testIllegalTimeout { state.setTimeoutDuration(-1000) }
testIllegalTimeout { state.setTimeoutDuration(0) }
testIllegalTimeout { state.setTimeoutDuration("-2 second") }
testIllegalTimeout { state.setTimeoutDuration("-1 month") }
testIllegalTimeout { state.setTimeoutDuration("1 month -1 day") }
// Test remove() clear timeout timestamp, and setTimeoutDuration() is not allowed after that
state = new KeyedStateImpl(Some(5), 1000, isTimeoutEnabled = true, hasTimedOut = false)
state.remove()
assert(state.hasTimedOut === false)
assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET)
testTimeoutNotAllowed()
// Test hasTimedOut = true
state = new KeyedStateImpl(Some(5), 1000, isTimeoutEnabled = true, hasTimedOut = true)
assert(state.hasTimedOut === true)
assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET)
}
test("KeyedState - primitive type") {
var intState = new KeyedStateImpl[Int](None)
intercept[NoSuchElementException] {
@ -100,6 +186,151 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
}
}
// Values used for testing StateStoreUpdater
val currentTimestamp = 1000
val beforeCurrentTimestamp = 999
val afterCurrentTimestamp = 1001
// Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout is disabled
for (priorState <- Seq(None, Some(0))) {
val priorStateStr = if (priorState.nonEmpty) "prior state set" else "no prior state"
val testName = s"timeout disabled - $priorStateStr - "
testStateUpdateWithData(
testName + "no update",
stateUpdates = state => { /* do nothing */ },
timeoutType = KeyedStateTimeout.NoTimeout,
priorState = priorState,
expectedState = priorState) // should not change
testStateUpdateWithData(
testName + "state updated",
stateUpdates = state => { state.update(5) },
timeoutType = KeyedStateTimeout.NoTimeout,
priorState = priorState,
expectedState = Some(5)) // should change
testStateUpdateWithData(
testName + "state removed",
stateUpdates = state => { state.remove() },
timeoutType = KeyedStateTimeout.NoTimeout,
priorState = priorState,
expectedState = None) // should be removed
}
// Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout is enabled
for (priorState <- Seq(None, Some(0))) {
for (priorTimeoutTimestamp <- Seq(TIMEOUT_TIMESTAMP_NOT_SET, 1000)) {
var testName = s"timeout enabled - "
if (priorState.nonEmpty) {
testName += "prior state set, "
if (priorTimeoutTimestamp == 1000) {
testName += "prior timeout set - "
} else {
testName += "no prior timeout - "
}
} else {
testName += "no prior state - "
}
testStateUpdateWithData(
testName + "no update",
stateUpdates = state => { /* do nothing */ },
timeoutType = KeyedStateTimeout.ProcessingTimeTimeout,
priorState = priorState,
priorTimeoutTimestamp = priorTimeoutTimestamp,
expectedState = priorState, // state should not change
expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET) // timestamp should be reset
testStateUpdateWithData(
testName + "state updated",
stateUpdates = state => { state.update(5) },
timeoutType = KeyedStateTimeout.ProcessingTimeTimeout,
priorState = priorState,
priorTimeoutTimestamp = priorTimeoutTimestamp,
expectedState = Some(5), // state should change
expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET) // timestamp should be reset
testStateUpdateWithData(
testName + "state removed",
stateUpdates = state => { state.remove() },
timeoutType = KeyedStateTimeout.ProcessingTimeTimeout,
priorState = priorState,
priorTimeoutTimestamp = priorTimeoutTimestamp,
expectedState = None) // state should be removed
testStateUpdateWithData(
testName + "timeout and state updated",
stateUpdates = state => { state.update(5); state.setTimeoutDuration(5000) },
timeoutType = KeyedStateTimeout.ProcessingTimeTimeout,
priorState = priorState,
priorTimeoutTimestamp = priorTimeoutTimestamp,
expectedState = Some(5), // state should change
expectedTimeoutTimestamp = currentTimestamp + 5000) // timestamp should change
}
}
// Tests for StateStoreUpdater.updateStateForTimedOutKeys()
val preTimeoutState = Some(5)
testStateUpdateWithTimeout(
"should not timeout",
stateUpdates = state => { assert(false, "function called without timeout") },
priorTimeoutTimestamp = afterCurrentTimestamp,
expectedState = preTimeoutState, // state should not change
expectedTimeoutTimestamp = afterCurrentTimestamp) // timestamp should not change
testStateUpdateWithTimeout(
"should timeout - no update/remove",
stateUpdates = state => { /* do nothing */ },
priorTimeoutTimestamp = beforeCurrentTimestamp,
expectedState = preTimeoutState, // state should not change
expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET) // timestamp should be reset
testStateUpdateWithTimeout(
"should timeout - update state",
stateUpdates = state => { state.update(5) },
priorTimeoutTimestamp = beforeCurrentTimestamp,
expectedState = Some(5), // state should change
expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET) // timestamp should be reset
testStateUpdateWithTimeout(
"should timeout - remove state",
stateUpdates = state => { state.remove() },
priorTimeoutTimestamp = beforeCurrentTimestamp,
expectedState = None, // state should be removed
expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET)
testStateUpdateWithTimeout(
"should timeout - timeout updated",
stateUpdates = state => { state.setTimeoutDuration(2000) },
priorTimeoutTimestamp = beforeCurrentTimestamp,
expectedState = preTimeoutState, // state should not change
expectedTimeoutTimestamp = currentTimestamp + 2000) // timestamp should change
testStateUpdateWithTimeout(
"should timeout - timeout and state updated",
stateUpdates = state => { state.update(5); state.setTimeoutDuration(2000) },
priorTimeoutTimestamp = beforeCurrentTimestamp,
expectedState = Some(5), // state should change
expectedTimeoutTimestamp = currentTimestamp + 2000) // timestamp should change
test("StateStoreUpdater - rows are cloned before writing to StateStore") {
// function for running count
val func = (key: Int, values: Iterator[Int], state: KeyedState[Int]) => {
state.update(state.getOption.getOrElse(0) + values.size)
Iterator.empty
}
val store = newStateStore()
val plan = newFlatMapGroupsWithStateExec(func)
val updater = new plan.StateStoreUpdater(store)
val data = Seq(1, 1, 2)
val returnIter = updater.updateStateForKeysWithData(data.iterator.map(intToRow))
returnIter.size // consume the iterator to force store updates
val storeData = store.iterator.map { case (k, v) => (rowToInt(k), rowToInt(v)) }.toSet
assert(storeData === Set((1, 2), (2, 1)))
}
test("flatMapGroupsWithState - streaming") {
// Function to maintain running count up to 2, and then remove the count
// Returns the data and the count if state is defined, otherwise does not return anything
@ -119,7 +350,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
val result =
inputData.toDS()
.groupByKey(x => x)
.flatMapGroupsWithState(stateFunc, Update) // State: Int, Out: (Str, Str)
.flatMapGroupsWithState(Update, KeyedStateTimeout.NoTimeout)(stateFunc)
testStream(result, Update)(
AddData(inputData, "a"),
@ -162,8 +393,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
val result =
inputData.toDS()
.groupByKey(x => x)
.flatMapGroupsWithState(stateFunc, Update) // State: Int, Out: (Str, Str)
.flatMapGroupsWithState(Update, KeyedStateTimeout.NoTimeout)(stateFunc)
testStream(result, Update)(
AddData(inputData, "a", "a", "b"),
CheckLastBatch(("a", "1"), ("a", "2"), ("b", "1")),
@ -178,15 +408,115 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
)
}
test("flatMapGroupsWithState - streaming + aggregation") {
// Function to maintain running count up to 2, and then remove the count
// Returns the data and the count (-1 if count reached beyond 2 and state was just removed)
val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
val count = state.getOption.map(_.count).getOrElse(0L) + values.size
if (count == 3) {
state.remove()
Iterator(key -> "-1")
} else {
state.update(RunningCount(count))
Iterator(key -> count.toString)
}
}
val inputData = MemoryStream[String]
val result =
inputData.toDS()
.groupByKey(x => x)
.flatMapGroupsWithState(Append, KeyedStateTimeout.NoTimeout)(stateFunc)
.groupByKey(_._1)
.count()
testStream(result, Complete)(
AddData(inputData, "a"),
CheckLastBatch(("a", 1)),
AddData(inputData, "a", "b"),
// mapGroups generates ("a", "2"), ("b", "1"); so increases counts of a and b by 1
CheckLastBatch(("a", 2), ("b", 1)),
StopStream,
StartStream(),
AddData(inputData, "a", "b"),
// mapGroups should remove state for "a" and generate ("a", "-1"), ("b", "2") ;
// so increment a and b by 1
CheckLastBatch(("a", 3), ("b", 2)),
StopStream,
StartStream(),
AddData(inputData, "a", "c"),
// mapGroups should recreate state for "a" and generate ("a", "1"), ("c", "1") ;
// so increment a and c by 1
CheckLastBatch(("a", 4), ("b", 2), ("c", 1))
)
}
test("flatMapGroupsWithState - batch") {
// Function that returns running count only if its even, otherwise does not return
val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
if (state.exists) throw new IllegalArgumentException("state.exists should be false")
Iterator((key, values.size))
}
checkAnswer(
Seq("a", "a", "b").toDS.groupByKey(x => x).flatMapGroupsWithState(stateFunc, Update).toDF,
Seq(("a", 2), ("b", 1)).toDF)
val df = Seq("a", "a", "b").toDS
.groupByKey(x => x)
.flatMapGroupsWithState(Update, KeyedStateTimeout.NoTimeout)(stateFunc).toDF
checkAnswer(df, Seq(("a", 2), ("b", 1)).toDF)
}
test("flatMapGroupsWithState - streaming with processing time timeout") {
// Function to maintain running count up to 2, and then remove the count
// Returns the data and the count (-1 if count reached beyond 2 and state was just removed)
val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
if (state.hasTimedOut) {
state.remove()
Iterator((key, "-1"))
} else {
val count = state.getOption.map(_.count).getOrElse(0L) + values.size
state.update(RunningCount(count))
state.setTimeoutDuration("10 seconds")
Iterator((key, count.toString))
}
}
val clock = new StreamManualClock
val inputData = MemoryStream[String]
val timeout = KeyedStateTimeout.ProcessingTimeTimeout
val result =
inputData.toDS()
.groupByKey(x => x)
.flatMapGroupsWithState(Update, timeout)(stateFunc)
testStream(result, Update)(
StartStream(ProcessingTime("1 second"), triggerClock = clock),
AddData(inputData, "a"),
AdvanceManualClock(1 * 1000),
CheckLastBatch(("a", "1")),
assertNumStateRows(total = 1, updated = 1),
AddData(inputData, "b"),
AdvanceManualClock(1 * 1000),
CheckLastBatch(("b", "1")),
assertNumStateRows(total = 2, updated = 1),
AddData(inputData, "b"),
AdvanceManualClock(10 * 1000),
CheckLastBatch(("a", "-1"), ("b", "2")),
assertNumStateRows(total = 1, updated = 2),
StopStream,
StartStream(ProcessingTime("1 second"), triggerClock = clock),
AddData(inputData, "c"),
AdvanceManualClock(20 * 1000),
CheckLastBatch(("b", "-1"), ("c", "1")),
assertNumStateRows(total = 1, updated = 2),
AddData(inputData, "c"),
AdvanceManualClock(20 * 1000),
CheckLastBatch(("c", "2")),
assertNumStateRows(total = 1, updated = 1)
)
}
test("mapGroupsWithState - streaming") {
@ -230,50 +560,6 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
)
}
test("flatMapGroupsWithState - streaming + aggregation") {
// Function to maintain running count up to 2, and then remove the count
// Returns the data and the count (-1 if count reached beyond 2 and state was just removed)
val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
val count = state.getOption.map(_.count).getOrElse(0L) + values.size
if (count == 3) {
state.remove()
Iterator(key -> "-1")
} else {
state.update(RunningCount(count))
Iterator(key -> count.toString)
}
}
val inputData = MemoryStream[String]
val result =
inputData.toDS()
.groupByKey(x => x)
.flatMapGroupsWithState(stateFunc, Append) // Types = State: MyState, Out: (Str, Str)
.groupByKey(_._1)
.count()
testStream(result, Complete)(
AddData(inputData, "a"),
CheckLastBatch(("a", 1)),
AddData(inputData, "a", "b"),
// mapGroups generates ("a", "2"), ("b", "1"); so increases counts of a and b by 1
CheckLastBatch(("a", 2), ("b", 1)),
StopStream,
StartStream(),
AddData(inputData, "a", "b"),
// mapGroups should remove state for "a" and generate ("a", "-1"), ("b", "2") ;
// so increment a and b by 1
CheckLastBatch(("a", 3), ("b", 2)),
StopStream,
StartStream(),
AddData(inputData, "a", "c"),
// mapGroups should recreate state for "a" and generate ("a", "1"), ("c", "1") ;
// so increment a and c by 1
CheckLastBatch(("a", 4), ("b", 2), ("c", 1))
)
}
test("mapGroupsWithState - batch") {
val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
if (state.exists) throw new IllegalArgumentException("state.exists should be false")
@ -322,23 +608,185 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
)
}
test("output partitioning is unknown") {
val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => key
val inputData = MemoryStream[String]
val result = inputData.toDS.groupByKey(x => x).mapGroupsWithState(stateFunc)
result
testStream(result, Update)(
AddData(inputData, "a"),
CheckLastBatch("a"),
AssertOnQuery(_.lastExecution.executedPlan.outputPartitioning === UnknownPartitioning(0))
)
}
test("disallow complete mode") {
val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
val stateFunc = (key: String, values: Iterator[String], state: KeyedState[Int]) => {
Iterator[String]()
}
var e = intercept[IllegalArgumentException] {
MemoryStream[String].toDS().groupByKey(x => x).flatMapGroupsWithState(stateFunc, Complete)
MemoryStream[String].toDS().groupByKey(x => x).flatMapGroupsWithState(
OutputMode.Complete, KeyedStateTimeout.NoTimeout)(stateFunc)
}
assert(e.getMessage === "The output mode of function should be append or update")
val javaStateFunc = new FlatMapGroupsWithStateFunction[String, String, Int, String] {
import java.util.{Iterator => JIterator}
override def call(
key: String,
values: JIterator[String],
state: KeyedState[Int]): JIterator[String] = { null }
}
e = intercept[IllegalArgumentException] {
MemoryStream[String].toDS().groupByKey(x => x).flatMapGroupsWithState(stateFunc, "complete")
MemoryStream[String].toDS().groupByKey(x => x).flatMapGroupsWithState(
javaStateFunc, OutputMode.Complete,
implicitly[Encoder[Int]], implicitly[Encoder[String]], KeyedStateTimeout.NoTimeout)
}
assert(e.getMessage === "The output mode of function should be append or update")
}
def testStateUpdateWithData(
testName: String,
stateUpdates: KeyedState[Int] => Unit,
timeoutType: KeyedStateTimeout = KeyedStateTimeout.NoTimeout,
priorState: Option[Int],
priorTimeoutTimestamp: Long = TIMEOUT_TIMESTAMP_NOT_SET,
expectedState: Option[Int] = None,
expectedTimeoutTimestamp: Long = TIMEOUT_TIMESTAMP_NOT_SET): Unit = {
if (priorState.isEmpty && priorTimeoutTimestamp != TIMEOUT_TIMESTAMP_NOT_SET) {
return // there can be no prior timestamp, when there is no prior state
}
test(s"StateStoreUpdater - updates with data - $testName") {
val mapGroupsFunc = (key: Int, values: Iterator[Int], state: KeyedState[Int]) => {
assert(state.hasTimedOut === false, "hasTimedOut not false")
assert(values.nonEmpty, "Some value is expected")
stateUpdates(state)
Iterator.empty
}
testStateUpdate(
testTimeoutUpdates = false, mapGroupsFunc, timeoutType,
priorState, priorTimeoutTimestamp, expectedState, expectedTimeoutTimestamp)
}
}
def testStateUpdateWithTimeout(
testName: String,
stateUpdates: KeyedState[Int] => Unit,
priorTimeoutTimestamp: Long,
expectedState: Option[Int],
expectedTimeoutTimestamp: Long = TIMEOUT_TIMESTAMP_NOT_SET): Unit = {
test(s"StateStoreUpdater - updates for timeout - $testName") {
val mapGroupsFunc = (key: Int, values: Iterator[Int], state: KeyedState[Int]) => {
assert(state.hasTimedOut === true, "hasTimedOut not true")
assert(values.isEmpty, "values not empty")
stateUpdates(state)
Iterator.empty
}
testStateUpdate(
testTimeoutUpdates = true, mapGroupsFunc, KeyedStateTimeout.ProcessingTimeTimeout,
preTimeoutState, priorTimeoutTimestamp,
expectedState, expectedTimeoutTimestamp)
}
}
def testStateUpdate(
testTimeoutUpdates: Boolean,
mapGroupsFunc: (Int, Iterator[Int], KeyedState[Int]) => Iterator[Int],
timeoutType: KeyedStateTimeout,
priorState: Option[Int],
priorTimeoutTimestamp: Long,
expectedState: Option[Int],
expectedTimeoutTimestamp: Long): Unit = {
val store = newStateStore()
val mapGroupsSparkPlan = newFlatMapGroupsWithStateExec(
mapGroupsFunc, timeoutType, currentTimestamp)
val updater = new mapGroupsSparkPlan.StateStoreUpdater(store)
val key = intToRow(0)
// Prepare store with prior state configs
if (priorState.nonEmpty) {
val row = updater.getStateRow(priorState.get)
updater.setTimeoutTimestamp(row, priorTimeoutTimestamp)
store.put(key.copy(), row.copy())
}
// Call updating function to update state store
val returnedIter = if (testTimeoutUpdates) {
updater.updateStateForTimedOutKeys()
} else {
updater.updateStateForKeysWithData(Iterator(key))
}
returnedIter.size // consumer the iterator to force state updates
// Verify updated state in store
val updatedStateRow = store.get(key)
assert(
updater.getStateObj(updatedStateRow).map(_.toString.toInt) === expectedState,
"final state not as expected")
if (updatedStateRow.nonEmpty) {
assert(
updater.getTimeoutTimestamp(updatedStateRow.get) === expectedTimeoutTimestamp,
"final timeout timestamp not as expected")
}
}
def newFlatMapGroupsWithStateExec(
func: (Int, Iterator[Int], KeyedState[Int]) => Iterator[Int],
timeoutType: KeyedStateTimeout = KeyedStateTimeout.NoTimeout,
batchTimestampMs: Long = NO_BATCH_PROCESSING_TIMESTAMP): FlatMapGroupsWithStateExec = {
MemoryStream[Int]
.toDS
.groupByKey(x => x)
.flatMapGroupsWithState[Int, Int](Append, timeoutConf = timeoutType)(func)
.logicalPlan.collectFirst {
case FlatMapGroupsWithState(f, k, v, g, d, o, s, m, _, t, _) =>
FlatMapGroupsWithStateExec(
f, k, v, g, d, o, None, s, m, t, currentTimestamp,
RDDScanExec(g, null, "rdd"))
}.get
}
def newStateStore(): StateStore = new MemoryStateStore()
val intProj = UnsafeProjection.create(Array[DataType](IntegerType))
def intToRow(i: Int): UnsafeRow = {
intProj.apply(new GenericInternalRow(Array[Any](i))).copy()
}
def rowToInt(row: UnsafeRow): Int = row.getInt(0)
}
object FlatMapGroupsWithStateSuite {
var failInTask = true
class MemoryStateStore extends StateStore() {
import scala.collection.JavaConverters._
private val map = new ConcurrentHashMap[UnsafeRow, UnsafeRow]
override def iterator(): Iterator[(UnsafeRow, UnsafeRow)] = {
map.entrySet.iterator.asScala.map { case e => (e.getKey, e.getValue) }
}
override def filter(c: (UnsafeRow, UnsafeRow) => Boolean): Iterator[(UnsafeRow, UnsafeRow)] = {
iterator.filter { case (k, v) => c(k, v) }
}
override def get(key: UnsafeRow): Option[UnsafeRow] = Option(map.get(key))
override def put(key: UnsafeRow, newValue: UnsafeRow): Unit = map.put(key, newValue)
override def remove(key: UnsafeRow): Unit = { map.remove(key) }
override def remove(condition: (UnsafeRow) => Boolean): Unit = {
iterator.map(_._1).filter(condition).foreach(map.remove)
}
override def commit(): Long = version + 1
override def abort(): Unit = { }
override def id: StateStoreId = null
override def version: Long = 0
override def updates(): Iterator[StoreUpdate] = { throw new UnsupportedOperationException }
override def numKeys(): Long = map.size
override def hasCommitted: Boolean = true
}
}