diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index 13c7f75275..321725d8dc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -37,12 +37,6 @@ object UnsupportedOperationChecker extends Logging {
       case p if p.isStreaming =>
         throwError("Queries with streaming sources must be executed with writeStream.start()")(p)
 
-      case f: FlatMapGroupsWithState =>
-        if (f.hasInitialState) {
-          throwError("Initial state is not supported in [flatMap|map]GroupsWithState" +
-            " operation on a batch DataFrame/Dataset")(f)
-        }
-
       case _ =>
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 6d10fa83f4..7624b157e7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -690,9 +690,14 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.MapGroups(f, key, value, grouping, data, objAttr, child) =>
         execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil
       case logical.FlatMapGroupsWithState(
-          f, key, value, grouping, data, output, _, _, _, timeout, _, _, _, _, _, child) =>
-        execution.MapGroupsExec(
-          f, key, value, grouping, data, output, timeout, planLater(child)) :: Nil
+          f, keyDeserializer, valueDeserializer, grouping, data, output, stateEncoder, outputMode,
+          isFlatMapGroupsWithState, timeout, hasInitialState, initialStateGroupAttrs,
+          initialStateDataAttrs, initialStateDeserializer, initialState, child) =>
+        FlatMapGroupsWithStateExec.generateSparkPlanForBatchQueries(
+          f, keyDeserializer, valueDeserializer, initialStateDeserializer, grouping,
+          initialStateGroupAttrs, data, initialStateDataAttrs, output, timeout,
+          hasInitialState, planLater(initialState), planLater(child)
+        ) :: Nil
       case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) =>
         execution.CoGroupExec(
           f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr,
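For context, this is the user-facing call that the reworked strategy now plans on a batch Dataset instead of rejecting it at analysis time. A minimal sketch, assuming a SparkSession `spark` with `spark.implicits._` imported; the `RunningTotal` state type, the function body, and the sample data are illustrative, not part of the patch:

```scala
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}

case class RunningTotal(count: Long) // illustrative state type

def countFunc(
    key: String,
    values: Iterator[String],
    state: GroupState[RunningTotal]): Iterator[(String, Long)] = {
  val total = state.getOption.map(_.count).getOrElse(0L) + values.size
  state.update(RunningTotal(total))
  Iterator((key, total))
}

// Initial state keyed like the input; seeds key "a" with a count of 1.
val initialState = Seq(("a", RunningTotal(1))).toDS()
  .groupByKey(_._1).mapValues(_._2)

// On a batch Dataset this used to fail with an AnalysisException; it is now
// planned through FlatMapGroupsWithStateExec.generateSparkPlanForBatchQueries.
val result = Seq("a", "b").toDS()
  .groupByKey(identity)
  .flatMapGroupsWithState(
    OutputMode.Update(), GroupStateTimeout.NoTimeout(), initialState)(countFunc)
// result.collect() => Array(("a", 2), ("b", 1))
```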
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index 03694d4ad3..a00a62216f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -309,9 +309,7 @@ case class FlatMapGroupsWithStateExec(
           var foundInitialStateForKey = false
           initialStateRowIter.foreach { initialStateRow =>
             if (foundInitialStateForKey) {
-              throw new IllegalArgumentException("The initial state provided contained " +
-                "multiple rows(state) with the same key. Make sure to de-duplicate the " +
-                "initial state before passing it.")
+              FlatMapGroupsWithStateExec.foundDuplicateInitialKeyException()
             }
             foundInitialStateForKey = true
             val initStateObj = getStateObj.get(initialStateRow)
@@ -403,3 +401,70 @@ case class FlatMapGroupsWithStateExec(
     copy(child = newLeft, initialState = newRight)
 }
 
+object FlatMapGroupsWithStateExec {
+
+  def foundDuplicateInitialKeyException(): Exception = {
+    throw new IllegalArgumentException("The initial state provided contained " +
+      "multiple rows(state) with the same key. Make sure to de-duplicate the " +
+      "initial state before passing it.")
+  }
+
+  /**
+   * Plan logical flatMapGroupsWithState for batch queries. If the initial state is
+   * provided, we create an instance of CoGroupExec; otherwise we create an instance
+   * of MapGroupsExec.
+   */
+  // scalastyle:off argcount
+  def generateSparkPlanForBatchQueries(
+      userFunc: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any],
+      keyDeserializer: Expression,
+      valueDeserializer: Expression,
+      initialStateDeserializer: Expression,
+      groupingAttributes: Seq[Attribute],
+      initialStateGroupAttrs: Seq[Attribute],
+      dataAttributes: Seq[Attribute],
+      initialStateDataAttrs: Seq[Attribute],
+      outputObjAttr: Attribute,
+      timeoutConf: GroupStateTimeout,
+      hasInitialState: Boolean,
+      initialState: SparkPlan,
+      child: SparkPlan): SparkPlan = {
+    if (hasInitialState) {
+      val watermarkPresent = child.output.exists {
+        case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true
+        case _ => false
+      }
+      val func = (keyRow: Any, values: Iterator[Any], states: Iterator[Any]) => {
+        // Ensure there is at most one state row for each key.
+        var foundInitialStateForKey = false
+        val optionalStates = states.map { stateValue =>
+          if (foundInitialStateForKey) {
+            foundDuplicateInitialKeyException()
+          }
+          foundInitialStateForKey = true
+          stateValue
+        }.toArray
+
+        // Create group state object
+        val groupState = GroupStateImpl.createForStreaming(
+          optionalStates.headOption,
+          System.currentTimeMillis,
+          GroupStateImpl.NO_TIMESTAMP,
+          timeoutConf,
+          hasTimedOut = false,
+          watermarkPresent)
+
+        // Call user function with the state and values for this key
+        userFunc(keyRow, values, groupState)
+      }
+      CoGroupExec(
+        func, keyDeserializer, valueDeserializer, initialStateDeserializer, groupingAttributes,
+        initialStateGroupAttrs, dataAttributes, initialStateDataAttrs, outputObjAttr,
+        child, initialState)
+    } else {
+      MapGroupsExec(
+        userFunc, keyDeserializer, valueDeserializer, groupingAttributes,
+        dataAttributes, outputObjAttr, timeoutConf, child)
+    }
+  }
+}
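The CoGroupExec wiring above is easier to see at the Dataset API level: with an initial state, the batch semantics amount to a cogroup of the input with the state, where the state side may contribute at most one row per key and seeds the GroupState before the user function runs. A rough, self-contained sketch of that equivalence, not the patch's actual code path; all names and data are illustrative:

```scala
import org.apache.spark.sql.SparkSession

object BatchInitialStateSketch {
  case class RunningCount(count: Long) // illustrative state type

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]")
      .appName("batch-initial-state-sketch").getOrCreate()
    import spark.implicits._

    val data = Seq("a", "b", "b").toDS().groupByKey(identity)
    val initState = Seq(("a", RunningCount(5))).toDS()
      .groupByKey(_._1).mapValues(_._2)

    // Each key sees its input rows plus the (at most one) initial-state row.
    val merged = data.cogroup(initState) { (key, values, states) =>
      val stateSeq = states.toSeq
      // Mirrors foundDuplicateInitialKeyException: one state row per key, max.
      require(stateSeq.size <= 1,
        "The initial state provided contained multiple rows(state) with the same key.")
      val count = stateSeq.headOption.map(_.count).getOrElse(0L) + values.size
      Iterator((key, count))
    }
    merged.show() // (a, 6), (b, 2)
    spark.stop()
  }
}
```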
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 0500c5217b..28439f27ff 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -196,14 +196,7 @@ public class JavaDatasetSuite implements Serializable {
         GroupStateTimeout.NoTimeout(),
         kvInitStateMappedDS);
 
-    Assert.assertThrows(
-        "Initial state is not supported in [flatMap|map]GroupsWithState " +
-            "operation on a batch DataFrame/Dataset",
-        AnalysisException.class,
-        () -> {
-          flatMapped2.collectAsList();
-        }
-    );
+    Assert.assertEquals(asSet("1a", "2", "3foobar"), toSet(flatMapped2.collectAsList()));
     Dataset mapped2 = grouped.mapGroupsWithState(
         (MapGroupsWithStateFunction) (key, values, s) -> {
           StringBuilder sb = new StringBuilder(key.toString());
@@ -216,14 +209,7 @@ public class JavaDatasetSuite implements Serializable {
         Encoders.STRING(),
         GroupStateTimeout.NoTimeout(),
         kvInitStateMappedDS);
-    Assert.assertThrows(
-        "Initial state is not supported in [flatMap|map]GroupsWithState " +
-            "operation on a batch DataFrame/Dataset",
-        AnalysisException.class,
-        () -> {
-          mapped2.collectAsList();
-        }
-    );
+    Assert.assertEquals(asSet("1a", "2", "3foobar"), toSet(mapped2.collectAsList()));
   }
 
   @Test
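In Scala, the behavior the updated Java assertions expect looks roughly like this. A sketch assuming `spark.implicits._` in scope; the exact fixtures in JavaDatasetSuite may differ, but the asserted outputs are the same: keys present only in the initial state still produce output, and keys present in both sides are merged:

```scala
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout}

// Group strings by length; key 2 exists only in the initial state.
val grouped = Seq("a", "foo", "bar").toDS().groupByKey(_.length)
val init = Seq((2, "")).toDS().groupByKey(_._1).mapValues(_._2)

val mapped = grouped.mapGroupsWithState(GroupStateTimeout.NoTimeout(), init) {
  (key: Int, values: Iterator[String], state: GroupState[String]) =>
    val sb = new StringBuilder(key.toString)
    values.foreach(v => sb.append(v))
    sb.toString
}
// Batch execution now succeeds: key 1 -> "1a", key 2 (state only) -> "2",
// key 3 -> "3foobar", matching the asSet("1a", "2", "3foobar") assertions.
```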
+ + " Make sure to de-duplicate the initial state before passing it.")) + } + testQuietly("flatMapGroupsWithState - initial state - streaming initial state") { val initialStateData = MemoryStream[(String, RunningCount)] initialStateData.addData(("a", new RunningCount(1)))