[SPARK-20057][SS] Renamed KeyedState to GroupState in mapGroupsWithState

## What changes were proposed in this pull request?

Since the state is tied to a "group" in the "mapGroupsWithState" operations, it is better to call the state "GroupState" instead of naming it after a key. This would also make the name more general if this operation is extended to RelationalGroupedDataset and the Python APIs.
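For context, a minimal sketch (not part of this patch) of how the renamed API reads at a call site; the running-count logic is illustrative only:

```scala
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout}

// The third parameter was previously typed as KeyedState[Long];
// only the name changes from a user's point of view.
def countFunc(key: String, values: Iterator[String], state: GroupState[Long]): (String, Long) = {
  val newCount = state.getOption.getOrElse(0L) + values.size
  state.update(newCount)
  (key, newCount)
}

// ds.groupByKey(identity).mapGroupsWithState(GroupStateTimeout.NoTimeout)(countFunc)
```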

## How was this patch tested?
Existing unit tests.

Author: Tathagata Das <tathagata.das1565@gmail.com>

Closes #17385 from tdas/SPARK-20057.
Committed by Tathagata Das on 2017-03-22 12:30:36 -07:00
parent 80fd070389 · commit 82b598b963
13 changed files with 172 additions and 167 deletions


@@ -24,31 +24,31 @@ import org.apache.spark.sql.catalyst.plans.logical.*;
/**
* Represents the type of timeouts possible for the Dataset operations
* `mapGroupsWithState` and `flatMapGroupsWithState`. See documentation on
* `KeyedState` for more details.
* `GroupState` for more details.
*
* @since 2.2.0
*/
@Experimental
@InterfaceStability.Evolving
public class KeyedStateTimeout {
public class GroupStateTimeout {
/**
* Timeout based on processing time. The duration of timeout can be set for each group in
* `map/flatMapGroupsWithState` by calling `KeyedState.setTimeoutDuration()`. See documentation
* on `KeyedState` for more details.
* `map/flatMapGroupsWithState` by calling `GroupState.setTimeoutDuration()`. See documentation
* on `GroupState` for more details.
*/
public static KeyedStateTimeout ProcessingTimeTimeout() { return ProcessingTimeTimeout$.MODULE$; }
public static GroupStateTimeout ProcessingTimeTimeout() { return ProcessingTimeTimeout$.MODULE$; }
/**
* Timeout based on event-time. The event-time timestamp for timeout can be set for each
* group in `map/flatMapGroupsWithState` by calling `KeyedState.setTimeoutTimestamp()`.
* group in `map/flatMapGroupsWithState` by calling `GroupState.setTimeoutTimestamp()`.
* In addition, you have to define the watermark in the query using `Dataset.withWatermark`.
* When the watermark advances beyond the set timestamp of a group and the group has not
* received any data, then the group times out. See documentation on
* `KeyedState` for more details.
* `GroupState` for more details.
*/
public static KeyedStateTimeout EventTimeTimeout() { return EventTimeTimeout$.MODULE$; }
public static GroupStateTimeout EventTimeTimeout() { return EventTimeTimeout$.MODULE$; }
/** No timeout. */
public static KeyedStateTimeout NoTimeout() { return NoTimeout$.MODULE$; }
public static GroupStateTimeout NoTimeout() { return NoTimeout$.MODULE$; }
}
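A hedged Scala sketch of how a caller selects among these three factories (`ds` and `stateFunc` are placeholders, not part of this diff):

```scala
import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}

// No timeout: the function fires only for groups that received data.
// ds.groupByKey(...).flatMapGroupsWithState(
//   OutputMode.Update, GroupStateTimeout.NoTimeout)(stateFunc)

// Processing-time timeout: also fires for idle groups once the clock
// advances past the duration set via GroupState.setTimeoutDuration().
// ds.groupByKey(...).flatMapGroupsWithState(
//   OutputMode.Update, GroupStateTimeout.ProcessingTimeTimeout)(stateFunc)

// Event-time timeout: requires Dataset.withWatermark(); fires once the
// watermark passes the timestamp set via GroupState.setTimeoutTimestamp().
// ds.groupByKey(...).flatMapGroupsWithState(
//   OutputMode.Update, GroupStateTimeout.EventTimeTimeout)(stateFunc)
```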


@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.streaming.{KeyedStateTimeout, OutputMode}
import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -351,22 +351,22 @@ case class MapGroups(
child: LogicalPlan) extends UnaryNode with ObjectProducer
/** Internal class representing State */
trait LogicalKeyedState[S]
trait LogicalGroupState[S]
/** Types of timeouts used in FlatMapGroupsWithState */
case object NoTimeout extends KeyedStateTimeout
case object ProcessingTimeTimeout extends KeyedStateTimeout
case object EventTimeTimeout extends KeyedStateTimeout
case object NoTimeout extends GroupStateTimeout
case object ProcessingTimeTimeout extends GroupStateTimeout
case object EventTimeTimeout extends GroupStateTimeout
/** Factory for constructing new `MapGroupsWithState` nodes. */
object FlatMapGroupsWithState {
def apply[K: Encoder, V: Encoder, S: Encoder, U: Encoder](
func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any],
func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any],
groupingAttributes: Seq[Attribute],
dataAttributes: Seq[Attribute],
outputMode: OutputMode,
isMapGroupsWithState: Boolean,
timeout: KeyedStateTimeout,
timeout: GroupStateTimeout,
child: LogicalPlan): LogicalPlan = {
val encoder = encoderFor[S]
@@ -404,7 +404,7 @@ object FlatMapGroupsWithState {
* @param timeout used to timeout groups that have not received data in a while
*/
case class FlatMapGroupsWithState(
func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any],
func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any],
keyDeserializer: Expression,
valueDeserializer: Expression,
groupingAttributes: Seq[Attribute],
@@ -413,7 +413,7 @@ case class FlatMapGroupsWithState(
stateEncoder: ExpressionEncoder[Any],
outputMode: OutputMode,
isMapGroupsWithState: Boolean = false,
timeout: KeyedStateTimeout,
timeout: GroupStateTimeout,
child: LogicalPlan) extends UnaryNode with ObjectProducer {
if (isMapGroupsWithState) {


@@ -17,13 +17,17 @@
package org.apache.spark.sql.streaming;
import org.apache.spark.sql.catalyst.plans.logical.EventTimeTimeout$;
import org.apache.spark.sql.catalyst.plans.logical.NoTimeout$;
import org.apache.spark.sql.catalyst.plans.logical.ProcessingTimeTimeout$;
import org.junit.Test;
public class JavaKeyedStateTimeoutSuite {
public class JavaGroupStateTimeoutSuite {
@Test
public void testTimeouts() {
assert(KeyedStateTimeout.ProcessingTimeTimeout() == ProcessingTimeTimeout$.MODULE$);
assert(GroupStateTimeout.ProcessingTimeTimeout() == ProcessingTimeTimeout$.MODULE$);
assert(GroupStateTimeout.EventTimeTimeout() == EventTimeTimeout$.MODULE$);
assert(GroupStateTimeout.NoTimeout() == NoTimeout$.MODULE$);
}
}
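The suite pins the Java factories to the catalyst case objects; for illustration, the same identities expressed from Scala (assumed equivalent, since `ProcessingTimeTimeout$.MODULE$` is just the Scala singleton as seen from Java):

```scala
import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, NoTimeout, ProcessingTimeTimeout}
import org.apache.spark.sql.streaming.GroupStateTimeout

// Reference equality: each factory returns the singleton case object.
assert(GroupStateTimeout.ProcessingTimeTimeout() eq ProcessingTimeTimeout)
assert(GroupStateTimeout.EventTimeTimeout() eq EventTimeTimeout)
assert(GroupStateTimeout.NoTimeout() eq NoTimeout)
```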


@@ -22,7 +22,7 @@ import java.util.Iterator;
import org.apache.spark.annotation.Experimental;
import org.apache.spark.annotation.InterfaceStability;
import org.apache.spark.sql.streaming.KeyedState;
import org.apache.spark.sql.streaming.GroupState;
/**
* ::Experimental::
@@ -35,5 +35,5 @@ import org.apache.spark.sql.streaming.KeyedState;
@Experimental
@InterfaceStability.Evolving
public interface FlatMapGroupsWithStateFunction<K, V, S, R> extends Serializable {
Iterator<R> call(K key, Iterator<V> values, KeyedState<S> state) throws Exception;
Iterator<R> call(K key, Iterator<V> values, GroupState<S> state) throws Exception;
}
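A sketch of implementing this interface from Scala rather than Java (the count logic and type parameters are illustrative; the imports assume the interface sits with Spark's other Java function interfaces):

```scala
import java.util.{Collections, Iterator => JIterator}
import scala.collection.JavaConverters._
import org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction
import org.apache.spark.sql.streaming.GroupState

val runningCount = new FlatMapGroupsWithStateFunction[String, String, java.lang.Long, String] {
  override def call(key: String, values: JIterator[String],
      state: GroupState[java.lang.Long]): JIterator[String] = {
    val count = state.getOption.map(_.longValue).getOrElse(0L) + values.asScala.size
    state.update(count)  // scala.Long boxes to java.lang.Long
    Collections.singletonList(s"$key:$count").iterator()
  }
}
```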


@@ -22,7 +22,7 @@ import java.util.Iterator;
import org.apache.spark.annotation.Experimental;
import org.apache.spark.annotation.InterfaceStability;
import org.apache.spark.sql.streaming.KeyedState;
import org.apache.spark.sql.streaming.GroupState;
/**
* ::Experimental::
@@ -34,5 +34,5 @@ import org.apache.spark.sql.streaming.KeyedState;
@Experimental
@InterfaceStability.Evolving
public interface MapGroupsWithStateFunction<K, V, S, R> extends Serializable {
R call(K key, Iterator<V> values, KeyedState<S> state) throws Exception;
R call(K key, Iterator<V> values, GroupState<S> state) throws Exception;
}


@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.expressions.ReduceAggregator
import org.apache.spark.sql.streaming.{KeyedState, KeyedStateTimeout, OutputMode}
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}
/**
* :: Experimental ::
@@ -228,7 +228,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
* For a static batch Dataset, the function will be invoked once per group. For a streaming
* Dataset, the function will be invoked for each group repeatedly in every trigger, and
* updates to each group's state will be saved across invocations.
* See [[org.apache.spark.sql.streaming.KeyedState]] for more details.
* See [[org.apache.spark.sql.streaming.GroupState]] for more details.
*
* @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
* @tparam U The type of the output objects. Must be encodable to Spark SQL types.
@@ -240,17 +240,17 @@ class KeyValueGroupedDataset[K, V] private[sql](
@Experimental
@InterfaceStability.Evolving
def mapGroupsWithState[S: Encoder, U: Encoder](
func: (K, Iterator[V], KeyedState[S]) => U): Dataset[U] = {
val flatMapFunc = (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func(key, it, s))
func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = {
val flatMapFunc = (key: K, it: Iterator[V], s: GroupState[S]) => Iterator(func(key, it, s))
Dataset[U](
sparkSession,
FlatMapGroupsWithState[K, V, S, U](
flatMapFunc.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]],
flatMapFunc.asInstanceOf[(Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any]],
groupingAttributes,
dataAttributes,
OutputMode.Update,
isMapGroupsWithState = true,
KeyedStateTimeout.NoTimeout,
GroupStateTimeout.NoTimeout,
child = logicalPlan))
}
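A usage sketch for the overload above; with no timeout argument, the plan is built with `OutputMode.Update` and `GroupStateTimeout.NoTimeout` as shown. `Event` and its fields are hypothetical:

```scala
import org.apache.spark.sql.streaming.GroupState

case class Event(user: String, action: String)  // illustrative input type

def seenCount(user: String, events: Iterator[Event], state: GroupState[Int]): (String, Int) = {
  val seen = state.getOption.getOrElse(0) + events.size
  state.update(seen)
  (user, seen)
}

// With spark.implicits._ in scope for the encoders:
// eventDs.groupByKey(_.user).mapGroupsWithState(seenCount _)
```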
@@ -262,7 +262,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
* For a static batch Dataset, the function will be invoked once per group. For a streaming
* Dataset, the function will be invoked for each group repeatedly in every trigger, and
* updates to each group's state will be saved across invocations.
* See [[org.apache.spark.sql.streaming.KeyedState]] for more details.
* See [[org.apache.spark.sql.streaming.GroupState]] for more details.
*
* @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
* @tparam U The type of the output objects. Must be encodable to Spark SQL types.
@@ -275,13 +275,13 @@ class KeyValueGroupedDataset[K, V] private[sql](
@Experimental
@InterfaceStability.Evolving
def mapGroupsWithState[S: Encoder, U: Encoder](
timeoutConf: KeyedStateTimeout)(
func: (K, Iterator[V], KeyedState[S]) => U): Dataset[U] = {
val flatMapFunc = (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func(key, it, s))
timeoutConf: GroupStateTimeout)(
func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = {
val flatMapFunc = (key: K, it: Iterator[V], s: GroupState[S]) => Iterator(func(key, it, s))
Dataset[U](
sparkSession,
FlatMapGroupsWithState[K, V, S, U](
flatMapFunc.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]],
flatMapFunc.asInstanceOf[(Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any]],
groupingAttributes,
dataAttributes,
OutputMode.Update,
@@ -298,7 +298,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
* For a static batch Dataset, the function will be invoked once per group. For a streaming
* Dataset, the function will be invoked for each group repeatedly in every trigger, and
* updates to each group's state will be saved across invocations.
* See [[KeyedState]] for more details.
* See [[GroupState]] for more details.
*
* @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
* @tparam U The type of the output objects. Must be encodable to Spark SQL types.
@@ -316,7 +316,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
stateEncoder: Encoder[S],
outputEncoder: Encoder[U]): Dataset[U] = {
mapGroupsWithState[S, U](
(key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s)
(key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s)
)(stateEncoder, outputEncoder)
}
@@ -328,7 +328,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
* For a static batch Dataset, the function will be invoked once per group. For a streaming
* Dataset, the function will be invoked for each group repeatedly in every trigger, and
* updates to each group's state will be saved across invocations.
* See [[KeyedState]] for more details.
* See [[GroupState]] for more details.
*
* @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
* @tparam U The type of the output objects. Must be encodable to Spark SQL types.
@@ -346,9 +346,9 @@ class KeyValueGroupedDataset[K, V] private[sql](
func: MapGroupsWithStateFunction[K, V, S, U],
stateEncoder: Encoder[S],
outputEncoder: Encoder[U],
timeoutConf: KeyedStateTimeout): Dataset[U] = {
timeoutConf: GroupStateTimeout): Dataset[U] = {
mapGroupsWithState[S, U](
(key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s)
(key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s)
)(stateEncoder, outputEncoder)
}
@@ -360,7 +360,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
* For a static batch Dataset, the function will be invoked once per group. For a streaming
* Dataset, the function will be invoked for each group repeatedly in every trigger, and
* updates to each group's state will be saved across invocations.
* See [[KeyedState]] for more details.
* See [[GroupState]] for more details.
*
* @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
* @tparam U The type of the output objects. Must be encodable to Spark SQL types.
@@ -375,15 +375,15 @@ class KeyValueGroupedDataset[K, V] private[sql](
@InterfaceStability.Evolving
def flatMapGroupsWithState[S: Encoder, U: Encoder](
outputMode: OutputMode,
timeoutConf: KeyedStateTimeout)(
func: (K, Iterator[V], KeyedState[S]) => Iterator[U]): Dataset[U] = {
timeoutConf: GroupStateTimeout)(
func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U] = {
if (outputMode != OutputMode.Append && outputMode != OutputMode.Update) {
throw new IllegalArgumentException("The output mode of function should be append or update")
}
Dataset[U](
sparkSession,
FlatMapGroupsWithState[K, V, S, U](
func.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]],
func.asInstanceOf[(Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any]],
groupingAttributes,
dataAttributes,
outputMode,
@@ -400,7 +400,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
* For a static batch Dataset, the function will be invoked once per group. For a streaming
* Dataset, the function will be invoked for each group repeatedly in every trigger, and
* updates to each group's state will be saved across invocations.
* See [[KeyedState]] for more details.
* See [[GroupState]] for more details.
*
* @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
* @tparam U The type of the output objects. Must be encodable to Spark SQL types.
@@ -420,8 +420,8 @@ class KeyValueGroupedDataset[K, V] private[sql](
outputMode: OutputMode,
stateEncoder: Encoder[S],
outputEncoder: Encoder[U],
timeoutConf: KeyedStateTimeout): Dataset[U] = {
val f = (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s).asScala
timeoutConf: GroupStateTimeout): Dataset[U] = {
val f = (key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s).asScala
flatMapGroupsWithState[S, U](outputMode, timeoutConf)(f)(stateEncoder, outputEncoder)
}
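A sketch of the timeout-handling pattern these overloads enable; it mirrors the class-level scaladoc example that appears later in this diff (the ten-minute duration and sum logic are illustrative):

```scala
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}

def expiringSum(key: String, values: Iterator[Int], state: GroupState[Int]): Iterator[(String, Int)] = {
  if (state.hasTimedOut) {                    // called with no values on timeout
    state.remove()
    Iterator.empty
  } else {
    val sum = state.getOption.getOrElse(0) + values.sum
    state.update(sum)
    state.setTimeoutDuration("10 minutes")    // must be re-set on every call
    Iterator((key, sum))
  }
}

// ds.groupByKey(...).flatMapGroupsWithState(
//   OutputMode.Update, GroupStateTimeout.ProcessingTimeTimeout)(expiringSum)
```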


@@ -31,8 +31,8 @@ import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.catalyst.plans.logical.FunctionUtils
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState
import org.apache.spark.sql.execution.streaming.KeyedStateImpl
import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState
import org.apache.spark.sql.execution.streaming.GroupStateImpl
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -355,14 +355,14 @@ case class MapGroupsExec(
object MapGroupsExec {
def apply(
func: (Any, Iterator[Any], LogicalKeyedState[Any]) => TraversableOnce[Any],
func: (Any, Iterator[Any], LogicalGroupState[Any]) => TraversableOnce[Any],
keyDeserializer: Expression,
valueDeserializer: Expression,
groupingAttributes: Seq[Attribute],
dataAttributes: Seq[Attribute],
outputObjAttr: Attribute,
child: SparkPlan): MapGroupsExec = {
val f = (key: Any, values: Iterator[Any]) => func(key, values, new KeyedStateImpl[Any](None))
val f = (key: Any, values: Iterator[Any]) => func(key, values, new GroupStateImpl[Any](None))
new MapGroupsExec(f, keyDeserializer, valueDeserializer,
groupingAttributes, dataAttributes, outputObjAttr, child)
}
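Put differently, the batch path adapts the three-argument user function into a two-argument one by closing over an empty, timeout-less state; a condensed restatement of the wiring above:

```scala
// Every group in a batch query sees a fresh, empty GroupStateImpl:
// no prior value, NoTimeout, hasTimedOut = false.
def toBatchFunc(
    userFunc: (Any, Iterator[Any], LogicalGroupState[Any]) => TraversableOnce[Any])
  : (Any, Iterator[Any]) => TraversableOnce[Any] =
  (key, values) => userFunc(key, values, new GroupStateImpl[Any](None))
```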


@@ -23,9 +23,9 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Attribut
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.streaming.KeyedStateImpl.NO_TIMESTAMP
import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.streaming.{KeyedStateTimeout, OutputMode}
import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.util.CompletionIterator
@@ -44,7 +44,7 @@ import org.apache.spark.util.CompletionIterator
* @param batchTimestampMs processing timestamp of the current batch.
*/
case class FlatMapGroupsWithStateExec(
func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any],
func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any],
keyDeserializer: Expression,
valueDeserializer: Expression,
groupingAttributes: Seq[Attribute],
@@ -53,13 +53,13 @@ case class FlatMapGroupsWithStateExec(
stateId: Option[OperatorStateId],
stateEncoder: ExpressionEncoder[Any],
outputMode: OutputMode,
timeoutConf: KeyedStateTimeout,
timeoutConf: GroupStateTimeout,
batchTimestampMs: Option[Long],
override val eventTimeWatermark: Option[Long],
child: SparkPlan
) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter with WatermarkSupport {
import KeyedStateImpl._
import GroupStateImpl._
private val isTimeoutEnabled = timeoutConf != NoTimeout
private val timestampTimeoutAttribute =
@@ -147,7 +147,7 @@ case class FlatMapGroupsWithStateExec(
private val stateSerializer = {
val encoderSerializer = stateEncoder.namedExpressions
if (isTimeoutEnabled) {
encoderSerializer :+ Literal(KeyedStateImpl.NO_TIMESTAMP)
encoderSerializer :+ Literal(GroupStateImpl.NO_TIMESTAMP)
} else {
encoderSerializer
}
@@ -211,7 +211,7 @@ case class FlatMapGroupsWithStateExec(
val keyObj = getKeyObj(keyRow) // convert key to objects
val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects
val stateObjOption = getStateObj(prevStateRowOption)
val keyedState = new KeyedStateImpl(
val keyedState = new GroupStateImpl(
stateObjOption,
batchTimestampMs.getOrElse(NO_TIMESTAMP),
eventTimeWatermark.getOrElse(NO_TIMESTAMP),
@@ -247,7 +247,7 @@ case class FlatMapGroupsWithStateExec(
if (shouldWriteState) {
if (stateRowToWrite == null) {
// This should never happen because checks in KeyedStateImpl should avoid cases
// This should never happen because checks in GroupStateImpl should avoid cases
// where empty state would need to be written
throw new IllegalStateException("Attempting to write empty state")
}


@@ -22,13 +22,14 @@ import java.sql.Date
import org.apache.commons.lang3.StringUtils
import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, ProcessingTimeTimeout}
import org.apache.spark.sql.execution.streaming.KeyedStateImpl._
import org.apache.spark.sql.streaming.{KeyedState, KeyedStateTimeout}
import org.apache.spark.sql.execution.streaming.GroupStateImpl._
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout}
import org.apache.spark.unsafe.types.CalendarInterval
/**
* Internal implementation of the [[KeyedState]] interface. Methods are not thread-safe.
* Internal implementation of the [[GroupState]] interface. Methods are not thread-safe.
*
* @param optionalValue Optional value of the state
* @param batchProcessingTimeMs Processing time of current batch, used to calculate timestamp
* for processing time timeouts
@@ -37,19 +38,19 @@ import org.apache.spark.unsafe.types.CalendarInterval
* @param hasTimedOut Whether the key for which this state is being created is
* getting timed out or not.
*/
private[sql] class KeyedStateImpl[S](
private[sql] class GroupStateImpl[S](
optionalValue: Option[S],
batchProcessingTimeMs: Long,
eventTimeWatermarkMs: Long,
timeoutConf: KeyedStateTimeout,
override val hasTimedOut: Boolean) extends KeyedState[S] {
timeoutConf: GroupStateTimeout,
override val hasTimedOut: Boolean) extends GroupState[S] {
// Constructor to create dummy state when using mapGroupsWithState in a batch query
def this(optionalValue: Option[S]) = this(
optionalValue,
batchProcessingTimeMs = NO_TIMESTAMP,
eventTimeWatermarkMs = NO_TIMESTAMP,
timeoutConf = KeyedStateTimeout.NoTimeout,
timeoutConf = GroupStateTimeout.NoTimeout,
hasTimedOut = false)
private var value: S = optionalValue.getOrElse(null.asInstanceOf[S])
private var defined: Boolean = optionalValue.isDefined
@@ -169,7 +170,7 @@ private[sql] class KeyedStateImpl[S](
}
override def toString: String = {
s"KeyedState(${getOption.map(_.toString).getOrElse("<undefined>")})"
s"GroupState(${getOption.map(_.toString).getOrElse("<undefined>")})"
}
// ========= Internal API =========
@@ -221,7 +222,7 @@ private[sql] class KeyedStateImpl[S](
}
private[sql] object KeyedStateImpl {
private[sql] object GroupStateImpl {
// Value used to represent the lack of a valid timestamp as a long
val NO_TIMESTAMP = -1L
}
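A short sketch, mirroring the test suite at the end of this diff, of the state lifecycle this internal class implements (internal API, batch-style constructor):

```scala
import org.apache.spark.sql.execution.streaming.GroupStateImpl

val s = new GroupStateImpl[String](None)  // empty state, NoTimeout, hasTimedOut = false
assert(!s.exists && s.getOption.isEmpty)

s.update("hello")                         // state becomes defined
assert(s.exists && s.get == "hello")

s.remove()                                // empty again; also resets timeout config
assert(s.getOption.isEmpty)
```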


@@ -23,13 +23,13 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate}
import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalKeyedState, ProcessingTimeTimeout}
import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalGroupState, ProcessingTimeTimeout}
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.streaming.{KeyedStateTimeout, OutputMode}
import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
import org.apache.spark.sql.types._
import org.apache.spark.util.CompletionIterator


@@ -19,14 +19,13 @@ package org.apache.spark.sql.streaming
import org.apache.spark.annotation.{Experimental, InterfaceStability}
import org.apache.spark.sql.{Encoder, KeyValueGroupedDataset}
import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState
import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState
/**
* :: Experimental ::
*
* Wrapper class for interacting with keyed state data in `mapGroupsWithState` and
* `flatMapGroupsWithState` operations on
* [[KeyValueGroupedDataset]].
* Wrapper class for interacting with per-group state data in `mapGroupsWithState` and
* `flatMapGroupsWithState` operations on [[KeyValueGroupedDataset]].
*
* Detailed description of `[map/flatMap]GroupsWithState` operation
* --------------------------------------------------------------
@@ -37,11 +36,11 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState
* Dataset, the function will be invoked for each group repeatedly in every trigger.
* That is, in every batch of the `streaming.StreamingQuery`,
* the function will be invoked once for each group that has data in the trigger. Furthermore,
* if timeout is set, then the function will be invoked on timed out keys (more detail below).
* if timeout is set, then the function will be invoked on timed out groups (more detail below).
*
* The function is invoked with the following parameters.
* - The key of the group.
* - An iterator containing all the values for this key.
* - An iterator containing all the values for this group.
* - A user-defined state object set by previous invocations of the given function.
* In case of a batch Dataset, there is only one invocation and state object will be empty as
* there is no prior state. Essentially, for batch Datasets, `[map/flatMap]GroupsWithState`
@@ -55,57 +54,58 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState
* batch, nor with streaming Datasets.
* - All the data will be shuffled before applying the function.
* - If timeout is set, then the function will also be called with no values.
* See more details on `KeyedStateTimeout` below.
* See more details on `GroupStateTimeout` below.
*
* Important points to note about using `KeyedState`.
* Important points to note about using `GroupState`.
* - The value of the state cannot be null. So updating state with null will throw
* `IllegalArgumentException`.
* - Operations on `KeyedState` are not thread-safe. This is to avoid memory barriers.
* - Operations on `GroupState` are not thread-safe. This is to avoid memory barriers.
* - If `remove()` is called, then `exists()` will return `false`,
* `get()` will throw `NoSuchElementException` and `getOption()` will return `None`
* - After that, if `update(newState)` is called, then `exists()` will again return `true`,
* `get()` and `getOption()` will return the updated value.
*
* Important points to note about using `KeyedStateTimeout`.
* - The timeout type is a global param across all the keys (set as `timeout` param in
* Important points to note about using `GroupStateTimeout`.
* - The timeout type is a global param across all the groups (set as `timeout` param in
* `[map|flatMap]GroupsWithState`), but the exact timeout duration/timestamp is configurable per
* key by calling `setTimeout...()` in `KeyedState`.
* group by calling `setTimeout...()` in `GroupState`.
* - Timeouts can be either based on processing time (i.e.
* [[KeyedStateTimeout.ProcessingTimeTimeout]]) or event time (i.e.
* [[KeyedStateTimeout.EventTimeTimeout]]).
* [[GroupStateTimeout.ProcessingTimeTimeout]]) or event time (i.e.
* [[GroupStateTimeout.EventTimeTimeout]]).
* - With `ProcessingTimeTimeout`, the timeout duration can be set by calling
* `KeyedState.setTimeoutDuration`. The timeout will occur when the clock has advanced by the set
* `GroupState.setTimeoutDuration`. The timeout will occur when the clock has advanced by the set
* duration. Guarantees provided by this timeout with a duration of D ms are as follows:
* - Timeout will never occur before the clock time has advanced by D ms
* - Timeout will occur eventually when there is a trigger in the query
* (i.e. after D ms). So there is no strict upper bound on when the timeout would occur.
* For example, the trigger interval of the query will affect when the timeout actually occurs.
* If there is no data in the stream (for any key) for a while, then there will not be
* If there is no data in the stream (for any group) for a while, then there will not be
* any trigger and the timeout function call will not occur until there is data.
* - Since the processing time timeout is based on the clock time, it is affected by the
* variations in the system clock (i.e. time zone changes, clock skew, etc.).
* - With `EventTimeTimeout`, the user also has to specify the event time watermark in
* the query using `Dataset.withWatermark()`. With this setting, data that is older than the
* watermark are filtered out. The timeout can be enabled for a key by setting a timestamp using
* `KeyedState.setTimeoutTimestamp()`, and the timeout would occur when the watermark advances
* beyond the set timestamp. You can control the timeout delay by two parameters - (i) watermark
* delay and an additional duration beyond the timestamp in the event (which is guaranteed to
* > watermark due to the filtering). Guarantees provided by this timeout are as follows:
* watermark are filtered out. The timeout can be set for a group by setting a timeout timestamp
* using `GroupState.setTimeoutTimestamp()`, and the timeout would occur when the watermark
* advances beyond the set timestamp. You can control the timeout delay by two parameters -
* (i) watermark delay and an additional duration beyond the timestamp in the event (which
* is guaranteed to be newer than watermark due to the filtering). Guarantees provided by this
* timeout are as follows:
* - Timeout will never occur before the watermark has exceeded the set timeout.
* - Similar to processing time timeouts, there is no strict upper bound on the delay when
* the timeout actually occurs. The watermark can advance only when there is data in the
* stream, and the event time of the data has actually advanced.
* - When the timeout occurs for a key, the function is called for that key with no values, and
* `KeyedState.hasTimedOut()` set to true.
* - The timeout is reset for key every time the function is called on the key, that is,
* when the key has new data, or the key has timed out. So the user has to set the timeout
* - When the timeout occurs for a group, the function is called for that group with no values, and
* `GroupState.hasTimedOut()` set to true.
* - The timeout is reset every time the function is called on a group, that is,
* when the group has new data, or the group has timed out. So the user has to set the timeout
* duration every time the function is called, otherwise there will not be any timeout set.
*
* Scala example of using KeyedState in `mapGroupsWithState`:
* Scala example of using GroupState in `mapGroupsWithState`:
* {{{
* // A mapping function that maintains an integer state for string keys and returns a string.
* // Additionally, it sets a timeout to remove the state if it has not received data for an hour.
* def mappingFunction(key: String, value: Iterator[Int], state: KeyedState[Int]): String = {
* def mappingFunction(key: String, value: Iterator[Int], state: GroupState[Int]): String = {
*
* if (state.hasTimedOut) { // If called when timing out, remove the state
* state.remove()
@@ -133,10 +133,10 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState
*
* dataset
* .groupByKey(...)
* .mapGroupsWithState(KeyedStateTimeout.ProcessingTimeTimeout)(mappingFunction)
* .mapGroupsWithState(GroupStateTimeout.ProcessingTimeTimeout)(mappingFunction)
* }}}
*
* Java example of using `KeyedState`:
* Java example of using `GroupState`:
* {{{
* // A mapping function that maintains an integer state for string keys and returns a string.
* // Additionally, it sets a timeout to remove the state if it has not received data for an hour.
@@ -144,7 +144,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState
* new MapGroupsWithStateFunction<String, Integer, Integer, String>() {
*
* @Override
* public String call(String key, Iterator<Integer> value, KeyedState<Integer> state) {
* public String call(String key, Iterator<Integer> value, GroupState<Integer> state) {
* if (state.hasTimedOut()) { // If called when timing out, remove the state
* state.remove();
*
@@ -173,16 +173,16 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState
* dataset
* .groupByKey(...)
* .mapGroupsWithState(
* mappingFunction, Encoders.INT, Encoders.STRING, KeyedStateTimeout.ProcessingTimeTimeout);
* mappingFunction, Encoders.INT, Encoders.STRING, GroupStateTimeout.ProcessingTimeTimeout);
* }}}
*
* @tparam S User-defined type of the state to be stored for each key. Must be encodable into
* @tparam S User-defined type of the state to be stored for each group. Must be encodable into
* Spark SQL types (see [[Encoder]] for more details).
* @since 2.2.0
*/
@Experimental
@InterfaceStability.Evolving
trait KeyedState[S] extends LogicalKeyedState[S] {
trait GroupState[S] extends LogicalGroupState[S] {
/** Whether state exists or not. */
def exists: Boolean
@@ -201,7 +201,7 @@ trait KeyedState[S] extends LogicalKeyedState[S] {
@throws[IllegalArgumentException]("when updating with null")
def update(newState: S): Unit
/** Remove this keyed state. Note that this resets any timeout configuration as well. */
/** Remove this state. Note that this resets any timeout configuration as well. */
def remove(): Unit
/**


@@ -23,7 +23,7 @@ import java.sql.Date;
import java.sql.Timestamp;
import java.util.*;
import org.apache.spark.sql.streaming.KeyedStateTimeout;
import org.apache.spark.sql.streaming.GroupStateTimeout;
import org.apache.spark.sql.streaming.OutputMode;
import scala.Tuple2;
import scala.Tuple3;
@@ -210,7 +210,7 @@ public class JavaDatasetSuite implements Serializable {
OutputMode.Append(),
Encoders.LONG(),
Encoders.STRING(),
KeyedStateTimeout.NoTimeout());
GroupStateTimeout.NoTimeout());
Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped2.collectAsList()));


@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsWithState
import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.execution.RDDScanExec
import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, KeyedStateImpl, MemoryStream}
import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream}
import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StoreUpdate}
import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite.MemoryStateStore
import org.apache.spark.sql.types.{DataType, IntegerType}
@@ -43,16 +43,16 @@ case class Result(key: Long, count: Int)
class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
import testImplicits._
import KeyedStateImpl._
import KeyedStateTimeout._
import GroupStateImpl._
import GroupStateTimeout._
override def afterAll(): Unit = {
super.afterAll()
StateStore.stop()
}
test("KeyedState - get, exists, update, remove") {
var state: KeyedStateImpl[String] = null
test("GroupState - get, exists, update, remove") {
var state: GroupStateImpl[String] = null
def testState(
expectedData: Option[String],
@@ -73,13 +73,13 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
}
// Updating empty state
state = new KeyedStateImpl[String](None)
state = new GroupStateImpl[String](None)
testState(None)
state.update("")
testState(Some(""), shouldBeUpdated = true)
// Updating existing state
state = new KeyedStateImpl[String](Some("2"))
state = new GroupStateImpl[String](Some("2"))
testState(Some("2"))
state.update("3")
testState(Some("3"), shouldBeUpdated = true)
@@ -97,19 +97,19 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
}
}
test("KeyedState - setTimeout**** with NoTimeout") {
test("GroupState - setTimeout**** with NoTimeout") {
for (initState <- Seq(None, Some(5))) {
// for different initial state
implicit val state = new KeyedStateImpl(initState, 1000, 1000, NoTimeout, hasTimedOut = false)
implicit val state = new GroupStateImpl(initState, 1000, 1000, NoTimeout, hasTimedOut = false)
testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
}
}
test("KeyedState - setTimeout**** with ProcessingTimeTimeout") {
implicit var state: KeyedStateImpl[Int] = null
test("GroupState - setTimeout**** with ProcessingTimeTimeout") {
implicit var state: GroupStateImpl[Int] = null
state = new KeyedStateImpl[Int](None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false)
state = new GroupStateImpl[Int](None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false)
assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
testTimeoutDurationNotAllowed[IllegalStateException](state)
testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
@@ -128,8 +128,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
}
test("KeyedState - setTimeout**** with EventTimeTimeout") {
implicit val state = new KeyedStateImpl[Int](
test("GroupState - setTimeout**** with EventTimeTimeout") {
implicit val state = new GroupStateImpl[Int](
None, 1000, 1000, EventTimeTimeout, hasTimedOut = false)
assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
@@ -148,8 +148,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
testTimeoutTimestampNotAllowed[IllegalStateException](state)
}
test("KeyedState - illegal params to setTimeout****") {
var state: KeyedStateImpl[Int] = null
test("GroupState - illegal params to setTimeout****") {
var state: GroupStateImpl[Int] = null
// Test setTimeout****() with illegal values
def testIllegalTimeout(body: => Unit): Unit = {
@@ -157,14 +157,14 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
}
state = new KeyedStateImpl(Some(5), 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false)
state = new GroupStateImpl(Some(5), 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false)
testIllegalTimeout { state.setTimeoutDuration(-1000) }
testIllegalTimeout { state.setTimeoutDuration(0) }
testIllegalTimeout { state.setTimeoutDuration("-2 second") }
testIllegalTimeout { state.setTimeoutDuration("-1 month") }
testIllegalTimeout { state.setTimeoutDuration("1 month -1 day") }
state = new KeyedStateImpl(Some(5), 1000, 1000, EventTimeTimeout, hasTimedOut = false)
state = new GroupStateImpl(Some(5), 1000, 1000, EventTimeTimeout, hasTimedOut = false)
testIllegalTimeout { state.setTimeoutTimestamp(-10000) }
testIllegalTimeout { state.setTimeoutTimestamp(10000, "-3 second") }
testIllegalTimeout { state.setTimeoutTimestamp(10000, "-1 month") }
@@ -175,25 +175,25 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000), "1 month -1 day") }
}
test("KeyedState - hasTimedOut") {
test("GroupState - hasTimedOut") {
for (timeoutConf <- Seq(NoTimeout, ProcessingTimeTimeout, EventTimeTimeout)) {
for (initState <- Seq(None, Some(5))) {
val state1 = new KeyedStateImpl(initState, 1000, 1000, timeoutConf, hasTimedOut = false)
val state1 = new GroupStateImpl(initState, 1000, 1000, timeoutConf, hasTimedOut = false)
assert(state1.hasTimedOut === false)
val state2 = new KeyedStateImpl(initState, 1000, 1000, timeoutConf, hasTimedOut = true)
val state2 = new GroupStateImpl(initState, 1000, 1000, timeoutConf, hasTimedOut = true)
assert(state2.hasTimedOut === true)
}
}
}
test("KeyedState - primitive type") {
var intState = new KeyedStateImpl[Int](None)
test("GroupState - primitive type") {
var intState = new GroupStateImpl[Int](None)
intercept[NoSuchElementException] {
intState.get
}
assert(intState.getOption === None)
intState = new KeyedStateImpl[Int](Some(10))
intState = new GroupStateImpl[Int](Some(10))
assert(intState.get == 10)
intState.update(0)
assert(intState.get == 0)
@ -218,21 +218,21 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
testStateUpdateWithData(
testName + "no update",
stateUpdates = state => { /* do nothing */ },
timeoutConf = KeyedStateTimeout.NoTimeout,
timeoutConf = GroupStateTimeout.NoTimeout,
priorState = priorState,
expectedState = priorState) // should not change
testStateUpdateWithData(
testName + "state updated",
stateUpdates = state => { state.update(5) },
timeoutConf = KeyedStateTimeout.NoTimeout,
timeoutConf = GroupStateTimeout.NoTimeout,
priorState = priorState,
expectedState = Some(5)) // should change
testStateUpdateWithData(
testName + "state removed",
stateUpdates = state => { state.remove() },
timeoutConf = KeyedStateTimeout.NoTimeout,
timeoutConf = GroupStateTimeout.NoTimeout,
priorState = priorState,
expectedState = None) // should be removed
}
@@ -283,7 +283,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
testStateUpdateWithData(
s"ProcessingTimeTimeout - $testName - state and timeout duration updated",
stateUpdates =
(state: KeyedState[Int]) => { state.update(5); state.setTimeoutDuration(5000) },
(state: GroupState[Int]) => { state.update(5); state.setTimeoutDuration(5000) },
timeoutConf = ProcessingTimeTimeout,
priorState = priorState,
priorTimeoutTimestamp = priorTimeoutTimestamp,
@@ -293,7 +293,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
testStateUpdateWithData(
s"EventTimeTimeout - $testName - state and timeout timestamp updated",
stateUpdates =
(state: KeyedState[Int]) => { state.update(5); state.setTimeoutTimestamp(5000) },
(state: GroupState[Int]) => { state.update(5); state.setTimeoutTimestamp(5000) },
timeoutConf = EventTimeTimeout,
priorState = priorState,
priorTimeoutTimestamp = priorTimeoutTimestamp,
@@ -303,7 +303,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
testStateUpdateWithData(
s"EventTimeTimeout - $testName - timeout timestamp updated to before watermark",
stateUpdates =
(state: KeyedState[Int]) => {
(state: GroupState[Int]) => {
state.update(5)
intercept[IllegalArgumentException] {
state.setTimeoutTimestamp(currentBatchWatermark - 1) // try to set to < watermark
@@ -387,7 +387,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
test("StateStoreUpdater - rows are cloned before writing to StateStore") {
// function for running count
val func = (key: Int, values: Iterator[Int], state: KeyedState[Int]) => {
val func = (key: Int, values: Iterator[Int], state: GroupState[Int]) => {
state.update(state.getOption.getOrElse(0) + values.size)
Iterator.empty
}
@@ -404,7 +404,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
test("flatMapGroupsWithState - streaming") {
// Function to maintain running count up to 2, and then remove the count
// Returns the data and the count if state is defined, otherwise does not return anything
val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => {
val count = state.getOption.map(_.count).getOrElse(0L) + values.size
if (count == 3) {
@@ -420,7 +420,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
val result =
inputData.toDS()
.groupByKey(x => x)
.flatMapGroupsWithState(Update, KeyedStateTimeout.NoTimeout)(stateFunc)
.flatMapGroupsWithState(Update, GroupStateTimeout.NoTimeout)(stateFunc)
testStream(result, Update)(
AddData(inputData, "a"),
@@ -446,7 +446,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
// Function to maintain running count up to 2, and then remove the count
// Returns the data and the count if state is defined, otherwise does not return anything
// Additionally, it updates state lazily as the returned iterator get consumed
val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => {
values.flatMap { _ =>
val count = state.getOption.map(_.count).getOrElse(0L) + 1
if (count == 3) {
@@ -463,7 +463,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
val result =
inputData.toDS()
.groupByKey(x => x)
.flatMapGroupsWithState(Update, KeyedStateTimeout.NoTimeout)(stateFunc)
.flatMapGroupsWithState(Update, GroupStateTimeout.NoTimeout)(stateFunc)
testStream(result, Update)(
AddData(inputData, "a", "a", "b"),
CheckLastBatch(("a", "1"), ("a", "2"), ("b", "1")),
@@ -481,7 +481,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
test("flatMapGroupsWithState - streaming + aggregation") {
// Function to maintain running count up to 2, and then remove the count
// Returns the data and the count (-1 if count reached beyond 2 and state was just removed)
val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => {
val count = state.getOption.map(_.count).getOrElse(0L) + values.size
if (count == 3) {
@@ -497,7 +497,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
val result =
inputData.toDS()
.groupByKey(x => x)
.flatMapGroupsWithState(Append, KeyedStateTimeout.NoTimeout)(stateFunc)
.flatMapGroupsWithState(Append, GroupStateTimeout.NoTimeout)(stateFunc)
.groupByKey(_._1)
.count()
@@ -524,20 +524,20 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
test("flatMapGroupsWithState - batch") {
// Function that returns running count only if it's even, otherwise does not return
val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => {
if (state.exists) throw new IllegalArgumentException("state.exists should be false")
Iterator((key, values.size))
}
val df = Seq("a", "a", "b").toDS
.groupByKey(x => x)
.flatMapGroupsWithState(Update, KeyedStateTimeout.NoTimeout)(stateFunc).toDF
.flatMapGroupsWithState(Update, GroupStateTimeout.NoTimeout)(stateFunc).toDF
checkAnswer(df, Seq(("a", 2), ("b", 1)).toDF)
}
test("flatMapGroupsWithState - streaming with processing time timeout") {
// Function to maintain running count up to 2, and then remove the count
// Returns the data and the count (-1 if count reached beyond 2 and state was just removed)
val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => {
if (state.hasTimedOut) {
state.remove()
Iterator((key, "-1"))
@@ -594,7 +594,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
val stateFunc = (
key: String,
values: Iterator[(String, Long)],
state: KeyedState[Long]) => {
state: GroupState[Long]) => {
val timeoutDelay = 5
if (key != "a") {
Iterator.empty
@@ -637,7 +637,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
test("mapGroupsWithState - streaming") {
// Function to maintain running count up to 2, and then remove the count
// Returns the data and the count (-1 if count reached beyond 2 and state was just removed)
val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => {
val count = state.getOption.map(_.count).getOrElse(0L) + values.size
if (count == 3) {
@@ -676,7 +676,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
}
test("mapGroupsWithState - batch") {
val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => {
if (state.exists) throw new IllegalArgumentException("state.exists should be false")
(key, values.size)
}
@@ -690,7 +690,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
}
testQuietly("StateStore.abort on task failure handling") {
val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => {
if (FlatMapGroupsWithStateSuite.failInTask) throw new Exception("expected failure")
val count = state.getOption.map(_.count).getOrElse(0L) + values.size
state.update(RunningCount(count))
@@ -724,7 +724,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
}
test("output partitioning is unknown") {
val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => key
val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => key
val inputData = MemoryStream[String]
val result = inputData.toDS.groupByKey(x => x).mapGroupsWithState(stateFunc)
testStream(result, Update)(
@@ -735,13 +735,13 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
}
test("disallow complete mode") {
val stateFunc = (key: String, values: Iterator[String], state: KeyedState[Int]) => {
val stateFunc = (key: String, values: Iterator[String], state: GroupState[Int]) => {
Iterator[String]()
}
var e = intercept[IllegalArgumentException] {
MemoryStream[String].toDS().groupByKey(x => x).flatMapGroupsWithState(
OutputMode.Complete, KeyedStateTimeout.NoTimeout)(stateFunc)
OutputMode.Complete, GroupStateTimeout.NoTimeout)(stateFunc)
}
assert(e.getMessage === "The output mode of function should be append or update")
@@ -750,20 +750,20 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
override def call(
key: String,
values: JIterator[String],
state: KeyedState[Int]): JIterator[String] = { null }
state: GroupState[Int]): JIterator[String] = { null }
}
e = intercept[IllegalArgumentException] {
MemoryStream[String].toDS().groupByKey(x => x).flatMapGroupsWithState(
javaStateFunc, OutputMode.Complete,
implicitly[Encoder[Int]], implicitly[Encoder[String]], KeyedStateTimeout.NoTimeout)
implicitly[Encoder[Int]], implicitly[Encoder[String]], GroupStateTimeout.NoTimeout)
}
assert(e.getMessage === "The output mode of function should be append or update")
}
def testStateUpdateWithData(
testName: String,
stateUpdates: KeyedState[Int] => Unit,
timeoutConf: KeyedStateTimeout,
stateUpdates: GroupState[Int] => Unit,
timeoutConf: GroupStateTimeout,
priorState: Option[Int],
priorTimeoutTimestamp: Long = NO_TIMESTAMP,
expectedState: Option[Int] = None,
@@ -773,7 +773,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
return // there can be no prior timestamp, when there is no prior state
}
test(s"StateStoreUpdater - updates with data - $testName") {
val mapGroupsFunc = (key: Int, values: Iterator[Int], state: KeyedState[Int]) => {
val mapGroupsFunc = (key: Int, values: Iterator[Int], state: GroupState[Int]) => {
assert(state.hasTimedOut === false, "hasTimedOut not false")
assert(values.nonEmpty, "Some value is expected")
stateUpdates(state)
@@ -787,14 +787,14 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
def testStateUpdateWithTimeout(
testName: String,
stateUpdates: KeyedState[Int] => Unit,
timeoutConf: KeyedStateTimeout,
stateUpdates: GroupState[Int] => Unit,
timeoutConf: GroupStateTimeout,
priorTimeoutTimestamp: Long,
expectedState: Option[Int],
expectedTimeoutTimestamp: Long = NO_TIMESTAMP): Unit = {
test(s"StateStoreUpdater - updates for timeout - $testName") {
val mapGroupsFunc = (key: Int, values: Iterator[Int], state: KeyedState[Int]) => {
val mapGroupsFunc = (key: Int, values: Iterator[Int], state: GroupState[Int]) => {
assert(state.hasTimedOut === true, "hasTimedOut not true")
assert(values.isEmpty, "values not empty")
stateUpdates(state)
@@ -808,8 +808,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
def testStateUpdate(
testTimeoutUpdates: Boolean,
mapGroupsFunc: (Int, Iterator[Int], KeyedState[Int]) => Iterator[Int],
timeoutConf: KeyedStateTimeout,
mapGroupsFunc: (Int, Iterator[Int], GroupState[Int]) => Iterator[Int],
timeoutConf: GroupStateTimeout,
priorState: Option[Int],
priorTimeoutTimestamp: Long,
expectedState: Option[Int],
@@ -848,8 +848,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
}
def newFlatMapGroupsWithStateExec(
func: (Int, Iterator[Int], KeyedState[Int]) => Iterator[Int],
timeoutType: KeyedStateTimeout = KeyedStateTimeout.NoTimeout,
func: (Int, Iterator[Int], GroupState[Int]) => Iterator[Int],
timeoutType: GroupStateTimeout = GroupStateTimeout.NoTimeout,
batchTimestampMs: Long = NO_TIMESTAMP): FlatMapGroupsWithStateExec = {
MemoryStream[Int]
.toDS
@@ -863,7 +863,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
}.get
}
def testTimeoutDurationNotAllowed[T <: Exception: Manifest](state: KeyedStateImpl[_]): Unit = {
def testTimeoutDurationNotAllowed[T <: Exception: Manifest](state: GroupStateImpl[_]): Unit = {
val prevTimestamp = state.getTimeoutTimestamp
intercept[T] { state.setTimeoutDuration(1000) }
assert(state.getTimeoutTimestamp === prevTimestamp)
@@ -871,7 +871,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
assert(state.getTimeoutTimestamp === prevTimestamp)
}
def testTimeoutTimestampNotAllowed[T <: Exception: Manifest](state: KeyedStateImpl[_]): Unit = {
def testTimeoutTimestampNotAllowed[T <: Exception: Manifest](state: GroupStateImpl[_]): Unit = {
val prevTimestamp = state.getTimeoutTimestamp
intercept[T] { state.setTimeoutTimestamp(2000) }
assert(state.getTimeoutTimestamp === prevTimestamp)