[SPARK-36132][SS][SQL] Support initial state for batch mode of flatMapGroupsWithState

### What changes were proposed in this pull request?
Adding support for accepting an initial state with flatMapGroupsWithState in batch mode.

### Why are the changes needed?
SPARK-35897 added support for accepting an initial state for streaming queries using flatMapGroupsWithState. The code flow is separate for batch and streaming, so batch support required a different PR.

### Does this PR introduce _any_ user-facing change?

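Yes. As discussed above, flatMapGroupsWithState in batch mode can now accept an initialState; previously this would throw an AnalysisException ("Initial state is not supported in [flatMap|map]GroupsWithState operation on a batch DataFrame/Dataset").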

### How was this patch tested?

Added relevant unit tests in `FlatMapGroupsWithStateSuite` and modified the tests in `JavaDatasetSuite`.

Closes #33336 from rahulsmahadev/flatMapGroupsWithStateBatch.

Authored-by: Rahul Mahadev <rahul.mahadev@databricks.com>
Signed-off-by: Tathagata Das <tathagata.das1565@gmail.com>
This commit is contained in:
Rahul Mahadev 2021-07-21 01:48:58 -04:00 committed by Tathagata Das
parent df798ed301
commit efcce23b91
5 changed files with 130 additions and 28 deletions

View file

@ -37,12 +37,6 @@ object UnsupportedOperationChecker extends Logging {
case p if p.isStreaming =>
throwError("Queries with streaming sources must be executed with writeStream.start()")(p)
case f: FlatMapGroupsWithState =>
if (f.hasInitialState) {
throwError("Initial state is not supported in [flatMap|map]GroupsWithState" +
" operation on a batch DataFrame/Dataset")(f)
}
case _ =>
}
}

View file

@ -690,9 +690,14 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.MapGroups(f, key, value, grouping, data, objAttr, child) =>
execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil
case logical.FlatMapGroupsWithState(
f, key, value, grouping, data, output, _, _, _, timeout, _, _, _, _, _, child) =>
execution.MapGroupsExec(
f, key, value, grouping, data, output, timeout, planLater(child)) :: Nil
f, keyDeserializer, valueDeserializer, grouping, data, output, stateEncoder, outputMode,
isFlatMapGroupsWithState, timeout, hasInitialState, initialStateGroupAttrs,
initialStateDataAttrs, initialStateDeserializer, initialState, child) =>
FlatMapGroupsWithStateExec.generateSparkPlanForBatchQueries(
f, keyDeserializer, valueDeserializer, initialStateDeserializer, grouping,
initialStateGroupAttrs, data, initialStateDataAttrs, output, timeout,
hasInitialState, planLater(initialState), planLater(child)
) :: Nil
case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) =>
execution.CoGroupExec(
f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr,

View file

@ -309,9 +309,7 @@ case class FlatMapGroupsWithStateExec(
var foundInitialStateForKey = false
initialStateRowIter.foreach { initialStateRow =>
if (foundInitialStateForKey) {
throw new IllegalArgumentException("The initial state provided contained " +
"multiple rows(state) with the same key. Make sure to de-duplicate the " +
"initial state before passing it.")
FlatMapGroupsWithStateExec.foundDuplicateInitialKeyException()
}
foundInitialStateForKey = true
val initStateObj = getStateObj.get(initialStateRow)
@ -403,3 +401,70 @@ case class FlatMapGroupsWithStateExec(
copy(child = newLeft, initialState = newRight)
}
object FlatMapGroupsWithStateExec {

  /**
   * Raises the error reported when the initial state holds more than one row for a key.
   *
   * The result type is `Nothing` because this method always throws and never returns
   * normally. `Nothing` conforms to every type, so call sites written against the previous
   * `Exception` result type keep compiling.
   *
   * @throws IllegalArgumentException always
   */
  def foundDuplicateInitialKeyException(): Nothing = {
    throw new IllegalArgumentException("The initial state provided contained " +
      "multiple rows(state) with the same key. Make sure to de-duplicate the " +
      "initial state before passing it.")
  }

  /**
   * Plans a logical flatMapGroupsWithState for batch queries.
   *
   * If the initial state is provided, we create an instance of the CoGroupExec that cogroups
   * `child` with `initialState`; if the initial state is not provided we create an instance
   * of the MapGroupsExec over `child` only.
   *
   * @param userFunc user function called with the key, that key's values and its group state
   * @param keyDeserializer expression passed through to the physical operator for keys
   * @param valueDeserializer expression passed through to the physical operator for values
   * @param initialStateDeserializer expression passed to CoGroupExec for initial state rows
   * @param groupingAttributes grouping attributes of the data side
   * @param initialStateGroupAttrs grouping attributes of the initial state side
   * @param dataAttributes data attributes of the data side
   * @param initialStateDataAttrs data attributes of the initial state side
   * @param outputObjAttr output object attribute of the resulting operator
   * @param timeoutConf timeout configuration handed to the group state / MapGroupsExec
   * @param hasInitialState whether an initial state was supplied
   * @param initialState physical plan of the initial state (only used if `hasInitialState`)
   * @param child physical plan of the input data
   * @return a CoGroupExec when `hasInitialState` is true, otherwise a MapGroupsExec
   */
  // scalastyle:off argcount
  def generateSparkPlanForBatchQueries(
      userFunc: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any],
      keyDeserializer: Expression,
      valueDeserializer: Expression,
      initialStateDeserializer: Expression,
      groupingAttributes: Seq[Attribute],
      initialStateGroupAttrs: Seq[Attribute],
      dataAttributes: Seq[Attribute],
      initialStateDataAttrs: Seq[Attribute],
      outputObjAttr: Attribute,
      timeoutConf: GroupStateTimeout,
      hasInitialState: Boolean,
      initialState: SparkPlan,
      child: SparkPlan): SparkPlan = {
    if (hasInitialState) {
      // A watermark counts as present when any output attribute of the child carries the
      // event-time watermark delay key in its metadata.
      val watermarkPresent = child.output.exists {
        case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true
        case _ => false
      }
      val func = (keyRow: Any, values: Iterator[Any], states: Iterator[Any]) => {
        // At most one initial state is allowed per key: take the first one (if any) and
        // fail as soon as a second one is seen.
        val initialStateOpt: Option[Any] = if (states.hasNext) Some(states.next()) else None
        if (states.hasNext) {
          foundDuplicateInitialKeyException()
        }
        // Create group state object, seeded with the initial state when one exists. The
        // current wall-clock time and NO_TIMESTAMP are passed as the two timestamp arguments.
        val groupState = GroupStateImpl.createForStreaming(
          initialStateOpt,
          System.currentTimeMillis,
          GroupStateImpl.NO_TIMESTAMP,
          timeoutConf,
          hasTimedOut = false,
          watermarkPresent)
        // Call user function with the state and values for this key
        userFunc(keyRow, values, groupState)
      }
      CoGroupExec(
        func, keyDeserializer, valueDeserializer, initialStateDeserializer, groupingAttributes,
        initialStateGroupAttrs, dataAttributes, initialStateDataAttrs, outputObjAttr,
        child, initialState)
    } else {
      MapGroupsExec(
        userFunc, keyDeserializer, valueDeserializer, groupingAttributes,
        dataAttributes, outputObjAttr, timeoutConf, child)
    }
  }
}

View file

@ -196,14 +196,7 @@ public class JavaDatasetSuite implements Serializable {
GroupStateTimeout.NoTimeout(),
kvInitStateMappedDS);
Assert.assertThrows(
"Initial state is not supported in [flatMap|map]GroupsWithState " +
"operation on a batch DataFrame/Dataset",
AnalysisException.class,
() -> {
flatMapped2.collectAsList();
}
);
Assert.assertEquals(asSet("1a", "2", "3foobar"), toSet(flatMapped2.collectAsList()));
Dataset<String> mapped2 = grouped.mapGroupsWithState(
(MapGroupsWithStateFunction<Integer, String, Long, String>) (key, values, s) -> {
StringBuilder sb = new StringBuilder(key.toString());
@ -216,14 +209,7 @@ public class JavaDatasetSuite implements Serializable {
Encoders.STRING(),
GroupStateTimeout.NoTimeout(),
kvInitStateMappedDS);
Assert.assertThrows(
"Initial state is not supported in [flatMap|map]GroupsWithState " +
"operation on a batch DataFrame/Dataset",
AnalysisException.class,
() -> {
mapped2.collectAsList();
}
);
Assert.assertEquals(asSet("1a", "2", "3foobar"), toSet(mapped2.collectAsList()));
}
@Test

View file

@ -1284,6 +1284,12 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 }
assertCannotGetWatermark { state.getCurrentWatermarkMs() }
assert(!state.hasTimedOut)
if (key.contains("EventTime")) {
state.setTimeoutTimestamp(0, "1 hour")
}
if (key.contains("ProcessingTime")) {
state.setTimeoutDuration("1 hour")
}
val count = state.getOption.map(_.count).getOrElse(0L) + valList.size
// We need to check if not explicitly calling update will still save the init state or not
if (!key.contains("NoUpdate")) {
@ -1413,6 +1419,52 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
)
}
// Registers one batch-mode initial-state test per timeout configuration; the timeout value
// is interpolated into both the test name and one of the initial-state keys.
Seq(NoTimeout(), EventTimeTimeout(), ProcessingTimeTimeout()).foreach { timeout =>
test(s"flatMapGroupsWithState - initial state - batch mode - timeout ${timeout}") {
// We will test them on different shuffle partition configuration to make sure the
// grouping by key will still work. On higher number of shuffle partitions it's possible
// that all keys end up on different partitions.
// NOTE(review): no shuffle-partition setting is visible inside this test body - confirm
// whether the comment above still applies to the batch variant.
// Initial state grouped by key (the first tuple element), keeping only the RunningCount.
val initialState = Seq(
(s"keyInStateAndData-1-$timeout", new RunningCount(1)),
("keyInStateAndData-2", new RunningCount(2)),
("keyNoUpdate", new RunningCount(2)), // state.update will not be called
("keyOnlyInState-1", new RunningCount(1))
).toDS().groupByKey(x => x._1).mapValues(_._2)
// Batch input: one key that has no initial state and one key that also has initial state.
val inputData = Seq(
("keyOnlyInData"), ("keyInStateAndData-2")
)
// Each input string is used as its own grouping key.
val result = inputData.toDS().groupByKey(x => x)
.flatMapGroupsWithState(
Update, timeout, initialState)(flatMapGroupsWithStateFunc)
// One expected row per key found in either the initial state or the input data.
// Presumably the columns are (key, input values for the key, resulting count as a
// string) - verify against flatMapGroupsWithStateFunc.
val expected = Seq(
("keyOnlyInState-1", Seq[String](), "1"),
("keyNoUpdate", Seq[String](), "2"), // update will not be called
("keyInStateAndData-2", Seq[String]("keyInStateAndData-2"), "3"), // inc by 1
(s"keyInStateAndData-1-$timeout", Seq[String](), "1"),
("keyOnlyInData", Seq[String]("keyOnlyInData"), "1") // inc by 1
).toDF()
checkAnswer(result.toDF(), expected)
}
}
// Checks that a batch query fails when the initial state has two rows for the same key.
testQuietly("flatMapGroupsWithState - initial state - batch mode - duplicate state") {
// Both initial-state rows share the key "a".
val initialState = Seq(
("a", new RunningCount(1)),
("a", new RunningCount(2))
).toDS().groupByKey(x => x._1).mapValues(_._2)
// The failure is expected only once the query is executed by show(); it is intercepted
// as a SparkException (presumably wrapping the task-side error - confirm).
val e = intercept[SparkException] {
Seq("a", "b").toDS().groupByKey(x => x)
.flatMapGroupsWithState(Update, NoTimeout(), initialState)(flatMapGroupsWithStateFunc)
.show()
}
// The exception message must contain the duplicate-key error text.
assert(e.getMessage.contains(
"The initial state provided contained multiple rows(state) with the same key." +
" Make sure to de-duplicate the initial state before passing it."))
}
testQuietly("flatMapGroupsWithState - initial state - streaming initial state") {
val initialStateData = MemoryStream[(String, RunningCount)]
initialStateData.addData(("a", new RunningCount(1)))