[SPARK-36132][SS][SQL] Support initial state for batch mode of flatMapGroupsWithState
### What changes were proposed in this pull request? Add support for accepting an initial state with flatMapGroupsWithState in batch mode. ### Why are the changes needed? SPARK-35897 added support for accepting an initial state for streaming queries using flatMapGroupsWithState. The code flow is separate for batch and streaming, so batch support required a different PR. ### Does this PR introduce _any_ user-facing change? Yes. As discussed above, flatMapGroupsWithState in batch mode can now accept an initialState; previously this was rejected by UnsupportedOperationChecker with an AnalysisException. ### How was this patch tested? Added relevant unit tests in `FlatMapGroupsWithStateSuite` and modified the tests in `JavaDatasetSuite`. Closes #33336 from rahulsmahadev/flatMapGroupsWithStateBatch. Authored-by: Rahul Mahadev <rahul.mahadev@databricks.com> Signed-off-by: Tathagata Das <tathagata.das1565@gmail.com>
This commit is contained in:
parent
df798ed301
commit
efcce23b91
|
@ -37,12 +37,6 @@ object UnsupportedOperationChecker extends Logging {
|
|||
case p if p.isStreaming =>
|
||||
throwError("Queries with streaming sources must be executed with writeStream.start()")(p)
|
||||
|
||||
case f: FlatMapGroupsWithState =>
|
||||
if (f.hasInitialState) {
|
||||
throwError("Initial state is not supported in [flatMap|map]GroupsWithState" +
|
||||
" operation on a batch DataFrame/Dataset")(f)
|
||||
}
|
||||
|
||||
case _ =>
|
||||
}
|
||||
}
|
||||
|
|
|
@ -690,9 +690,14 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
|
|||
case logical.MapGroups(f, key, value, grouping, data, objAttr, child) =>
|
||||
execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil
|
||||
case logical.FlatMapGroupsWithState(
|
||||
f, key, value, grouping, data, output, _, _, _, timeout, _, _, _, _, _, child) =>
|
||||
execution.MapGroupsExec(
|
||||
f, key, value, grouping, data, output, timeout, planLater(child)) :: Nil
|
||||
f, keyDeserializer, valueDeserializer, grouping, data, output, stateEncoder, outputMode,
|
||||
isFlatMapGroupsWithState, timeout, hasInitialState, initialStateGroupAttrs,
|
||||
initialStateDataAttrs, initialStateDeserializer, initialState, child) =>
|
||||
FlatMapGroupsWithStateExec.generateSparkPlanForBatchQueries(
|
||||
f, keyDeserializer, valueDeserializer, initialStateDeserializer, grouping,
|
||||
initialStateGroupAttrs, data, initialStateDataAttrs, output, timeout,
|
||||
hasInitialState, planLater(initialState), planLater(child)
|
||||
) :: Nil
|
||||
case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) =>
|
||||
execution.CoGroupExec(
|
||||
f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr,
|
||||
|
|
|
@ -309,9 +309,7 @@ case class FlatMapGroupsWithStateExec(
|
|||
var foundInitialStateForKey = false
|
||||
initialStateRowIter.foreach { initialStateRow =>
|
||||
if (foundInitialStateForKey) {
|
||||
throw new IllegalArgumentException("The initial state provided contained " +
|
||||
"multiple rows(state) with the same key. Make sure to de-duplicate the " +
|
||||
"initial state before passing it.")
|
||||
FlatMapGroupsWithStateExec.foundDuplicateInitialKeyException()
|
||||
}
|
||||
foundInitialStateForKey = true
|
||||
val initStateObj = getStateObj.get(initialStateRow)
|
||||
|
@ -403,3 +401,70 @@ case class FlatMapGroupsWithStateExec(
|
|||
copy(child = newLeft, initialState = newRight)
|
||||
}
|
||||
|
||||
object FlatMapGroupsWithStateExec {

  /**
   * Raises the error reported when the provided initial state holds more than one row
   * (state) for the same key.
   *
   * Despite the declared `Exception` result type, this method never returns normally: it
   * always throws, so callers invoke it as a plain statement.
   *
   * @throws IllegalArgumentException always
   */
  def foundDuplicateInitialKeyException(): Exception = {
    val message = "The initial state provided contained " +
      "multiple rows(state) with the same key. Make sure to de-duplicate the " +
      "initial state before passing it."
    throw new IllegalArgumentException(message)
  }

  /**
   * Plans a logical flatMapGroupsWithState for batch queries.
   *
   *  - Without an initial state, the result is a `MapGroupsExec` over `child`.
   *  - With an initial state, the result is a `CoGroupExec` over `child` and `initialState`.
   *    The co-group function checks that each key has at most one initial-state row, builds
   *    the group state from it, and then calls `userFunc`.
   *
   * @param userFunc user function invoked with the key, the values and the group state
   * @param hasInitialState selects between the `CoGroupExec` and `MapGroupsExec` plans
   * @param initialState physical plan of the initial state; only used when
   *                     `hasInitialState` is true
   * @param child physical plan of the input data
   * @throws IllegalArgumentException at execution time (from the co-group function) if the
   *                                  initial state has several rows for one key
   */
  // scalastyle:off argcount
  def generateSparkPlanForBatchQueries(
      userFunc: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any],
      keyDeserializer: Expression,
      valueDeserializer: Expression,
      initialStateDeserializer: Expression,
      groupingAttributes: Seq[Attribute],
      initialStateGroupAttrs: Seq[Attribute],
      dataAttributes: Seq[Attribute],
      initialStateDataAttrs: Seq[Attribute],
      outputObjAttr: Attribute,
      timeoutConf: GroupStateTimeout,
      hasInitialState: Boolean,
      initialState: SparkPlan,
      child: SparkPlan): SparkPlan = {
    if (!hasInitialState) {
      // No initial state: a plain per-group map over the input is enough.
      MapGroupsExec(
        userFunc, keyDeserializer, valueDeserializer, groupingAttributes,
        dataAttributes, outputObjAttr, timeoutConf, child)
    } else {
      // A watermark is considered present when any output attribute of the child carries
      // the event-time watermark delay key in its metadata.
      val watermarkPresent =
        child.output.exists(_.metadata.contains(EventTimeWatermark.delayKey))

      val coGroupFunc = (keyRow: Any, values: Iterator[Any], states: Iterator[Any]) => {
        // Walk every state row of this key; a second row means the initial state was not
        // de-duplicated, which is reported as an error.
        val initialStateForKey =
          states.foldLeft(Option.empty[Any]) { (seenState, stateValue) =>
            if (seenState.isDefined) {
              foundDuplicateInitialKeyException()
            }
            Some(stateValue)
          }

        // Build the group state object seeded with the (optional) initial state.
        val stateForGroup = GroupStateImpl.createForStreaming(
          initialStateForKey,
          System.currentTimeMillis,
          GroupStateImpl.NO_TIMESTAMP,
          timeoutConf,
          hasTimedOut = false,
          watermarkPresent)

        // Hand the key, its values and the state over to the user function.
        userFunc(keyRow, values, stateForGroup)
      }

      CoGroupExec(
        coGroupFunc, keyDeserializer, valueDeserializer, initialStateDeserializer,
        groupingAttributes, initialStateGroupAttrs, dataAttributes, initialStateDataAttrs,
        outputObjAttr, child, initialState)
    }
  }
}
|
||||
|
|
|
@ -196,14 +196,7 @@ public class JavaDatasetSuite implements Serializable {
|
|||
GroupStateTimeout.NoTimeout(),
|
||||
kvInitStateMappedDS);
|
||||
|
||||
Assert.assertThrows(
|
||||
"Initial state is not supported in [flatMap|map]GroupsWithState " +
|
||||
"operation on a batch DataFrame/Dataset",
|
||||
AnalysisException.class,
|
||||
() -> {
|
||||
flatMapped2.collectAsList();
|
||||
}
|
||||
);
|
||||
Assert.assertEquals(asSet("1a", "2", "3foobar"), toSet(flatMapped2.collectAsList()));
|
||||
Dataset<String> mapped2 = grouped.mapGroupsWithState(
|
||||
(MapGroupsWithStateFunction<Integer, String, Long, String>) (key, values, s) -> {
|
||||
StringBuilder sb = new StringBuilder(key.toString());
|
||||
|
@ -216,14 +209,7 @@ public class JavaDatasetSuite implements Serializable {
|
|||
Encoders.STRING(),
|
||||
GroupStateTimeout.NoTimeout(),
|
||||
kvInitStateMappedDS);
|
||||
Assert.assertThrows(
|
||||
"Initial state is not supported in [flatMap|map]GroupsWithState " +
|
||||
"operation on a batch DataFrame/Dataset",
|
||||
AnalysisException.class,
|
||||
() -> {
|
||||
mapped2.collectAsList();
|
||||
}
|
||||
);
|
||||
Assert.assertEquals(asSet("1a", "2", "3foobar"), toSet(mapped2.collectAsList()));
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
|
@ -1284,6 +1284,12 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
|
|||
assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 }
|
||||
assertCannotGetWatermark { state.getCurrentWatermarkMs() }
|
||||
assert(!state.hasTimedOut)
|
||||
if (key.contains("EventTime")) {
|
||||
state.setTimeoutTimestamp(0, "1 hour")
|
||||
}
|
||||
if (key.contains("ProcessingTime")) {
|
||||
state.setTimeoutDuration("1 hour")
|
||||
}
|
||||
val count = state.getOption.map(_.count).getOrElse(0L) + valList.size
|
||||
// We need to check if not explicitly calling update will still save the init state or not
|
||||
if (!key.contains("NoUpdate")) {
|
||||
|
@ -1413,6 +1419,52 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
|
|||
)
|
||||
}
|
||||
|
||||
Seq(NoTimeout(), EventTimeTimeout(), ProcessingTimeTimeout()).foreach { timeout =>
|
||||
test(s"flatMapGroupsWithState - initial state - batch mode - timeout ${timeout}") {
|
||||
// We will test them on different shuffle partition configuration to make sure the
|
||||
// grouping by key will still work. On higher number of shuffle partitions its possible
|
||||
// that all keys end up on different partitions.
|
||||
val initialState = Seq(
|
||||
(s"keyInStateAndData-1-$timeout", new RunningCount(1)),
|
||||
("keyInStateAndData-2", new RunningCount(2)),
|
||||
("keyNoUpdate", new RunningCount(2)), // state.update will not be called
|
||||
("keyOnlyInState-1", new RunningCount(1))
|
||||
).toDS().groupByKey(x => x._1).mapValues(_._2)
|
||||
|
||||
val inputData = Seq(
|
||||
("keyOnlyInData"), ("keyInStateAndData-2")
|
||||
)
|
||||
val result = inputData.toDS().groupByKey(x => x)
|
||||
.flatMapGroupsWithState(
|
||||
Update, timeout, initialState)(flatMapGroupsWithStateFunc)
|
||||
|
||||
val expected = Seq(
|
||||
("keyOnlyInState-1", Seq[String](), "1"),
|
||||
("keyNoUpdate", Seq[String](), "2"), // update will not be called
|
||||
("keyInStateAndData-2", Seq[String]("keyInStateAndData-2"), "3"), // inc by 1
|
||||
(s"keyInStateAndData-1-$timeout", Seq[String](), "1"),
|
||||
("keyOnlyInData", Seq[String]("keyOnlyInData"), "1") // inc by 1
|
||||
).toDF()
|
||||
checkAnswer(result.toDF(), expected)
|
||||
}
|
||||
}
|
||||
|
||||
testQuietly("flatMapGroupsWithState - initial state - batch mode - duplicate state") {
|
||||
val initialState = Seq(
|
||||
("a", new RunningCount(1)),
|
||||
("a", new RunningCount(2))
|
||||
).toDS().groupByKey(x => x._1).mapValues(_._2)
|
||||
|
||||
val e = intercept[SparkException] {
|
||||
Seq("a", "b").toDS().groupByKey(x => x)
|
||||
.flatMapGroupsWithState(Update, NoTimeout(), initialState)(flatMapGroupsWithStateFunc)
|
||||
.show()
|
||||
}
|
||||
assert(e.getMessage.contains(
|
||||
"The initial state provided contained multiple rows(state) with the same key." +
|
||||
" Make sure to de-duplicate the initial state before passing it."))
|
||||
}
|
||||
|
||||
testQuietly("flatMapGroupsWithState - initial state - streaming initial state") {
|
||||
val initialStateData = MemoryStream[(String, RunningCount)]
|
||||
initialStateData.addData(("a", new RunningCount(1)))
|
||||
|
|
Loading…
Reference in a new issue