diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 5e79232a20..bd8d5d7b43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -230,6 +230,20 @@ case class FlatMapGroupsWithStateExec( // When the iterator is consumed, then write changes to state def onIteratorCompletion: Unit = { + + val currentTimeoutTimestamp = keyedState.getTimeoutTimestamp + // If the state has not yet been set but timeout has been set, then + // we have to generate a row to save the timeout. However, attempting serialize + // null using case class encoder throws - + // java.lang.NullPointerException: Null value appeared in non-nullable field: + // If the schema is inferred from a Scala tuple / case class, or a Java bean, please + // try to use scala.Option[_] or other nullable types. + if (!keyedState.exists && currentTimeoutTimestamp != NO_TIMESTAMP) { + throw new IllegalStateException( + "Cannot set timeout when state is not defined, that is, state has not been" + + "initialized or has been removed") + } + if (keyedState.hasRemoved) { store.remove(keyRow) numUpdatedStateRows += 1 @@ -239,7 +253,6 @@ case class FlatMapGroupsWithStateExec( case Some(row) => getTimeoutTimestamp(row) case None => NO_TIMESTAMP } - val currentTimeoutTimestamp = keyedState.getTimeoutTimestamp val stateRowToWrite = if (keyedState.hasUpdated) { getStateRow(keyedState.get) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala index 148d92247d..d4606fd5a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala @@ -91,7 +91,6 @@ private[sql] class GroupStateImpl[S]( defined = false updated = false removed = true - timeoutTimestamp = NO_TIMESTAMP } override def setTimeoutDuration(durationMs: Long): Unit = { @@ -100,16 +99,10 @@ private[sql] class GroupStateImpl[S]( "Cannot set timeout duration without enabling processing time timeout in " + "map/flatMapGroupsWithState") } - if (!defined) { - throw new IllegalStateException( - "Cannot set timeout information without any state value, " + - "state has either not been initialized, or has already been removed") - } - if (durationMs <= 0) { throw new IllegalArgumentException("Timeout duration must be positive") } - if (!removed && batchProcessingTimeMs != NO_TIMESTAMP) { + if (batchProcessingTimeMs != NO_TIMESTAMP) { timeoutTimestamp = durationMs + batchProcessingTimeMs } else { // This is being called in a batch query, hence no processing timestamp. @@ -135,7 +128,7 @@ private[sql] class GroupStateImpl[S]( s"Timeout timestamp ($timestampMs) cannot be earlier than the " + s"current watermark ($eventTimeWatermarkMs)") } - if (!removed && batchProcessingTimeMs != NO_TIMESTAMP) { + if (batchProcessingTimeMs != NO_TIMESTAMP) { timeoutTimestamp = timestampMs } else { // This is being called in a batch query, hence no processing timestamp. @@ -213,11 +206,6 @@ private[sql] class GroupStateImpl[S]( "Cannot set timeout timestamp without enabling event time timeout in " + "map/flatMapGroupsWithState") } - if (!defined) { - throw new IllegalStateException( - "Cannot set timeout timestamp without any state value, " + - "state has either not been initialized, or has already been removed") - } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala index c659ac7fcf..04a956b70b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala @@ -212,7 +212,7 @@ trait GroupState[S] extends LogicalGroupState[S] { @throws[IllegalArgumentException]("when updating with null") def update(newState: S): Unit - /** Remove this state. Note that this resets any timeout configuration as well. */ + /** Remove this state. */ def remove(): Unit /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 89cfba6c55..10e91740eb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -112,11 +112,12 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf state = new GroupStateImpl[Int](None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) assert(state.getTimeoutTimestamp === NO_TIMESTAMP) - testTimeoutDurationNotAllowed[IllegalStateException](state) + state.setTimeoutDuration(500) + assert(state.getTimeoutTimestamp === 1500) // can be set without initializing state testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) state.update(5) - assert(state.getTimeoutTimestamp === NO_TIMESTAMP) + assert(state.getTimeoutTimestamp === 1500) // does not change state.setTimeoutDuration(1000) assert(state.getTimeoutTimestamp === 2000) state.setTimeoutDuration("2 second") @@ -124,8 +125,9 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) state.remove() - assert(state.getTimeoutTimestamp === NO_TIMESTAMP) - testTimeoutDurationNotAllowed[IllegalStateException](state) + assert(state.getTimeoutTimestamp === 3000) // does not change + state.setTimeoutDuration(500) // can still be set + assert(state.getTimeoutTimestamp === 1500) testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) } @@ -134,9 +136,11 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf None, 1000, 1000, EventTimeTimeout, hasTimedOut = false) assert(state.getTimeoutTimestamp === NO_TIMESTAMP) testTimeoutDurationNotAllowed[UnsupportedOperationException](state) - testTimeoutTimestampNotAllowed[IllegalStateException](state) + state.setTimeoutTimestamp(5000) + assert(state.getTimeoutTimestamp === 5000) // can be set without initializing state state.update(5) + assert(state.getTimeoutTimestamp === 5000) // does not change state.setTimeoutTimestamp(10000) assert(state.getTimeoutTimestamp === 10000) state.setTimeoutTimestamp(new Date(20000)) @@ -144,9 +148,10 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf testTimeoutDurationNotAllowed[UnsupportedOperationException](state) state.remove() - assert(state.getTimeoutTimestamp === NO_TIMESTAMP) + assert(state.getTimeoutTimestamp === 20000) + state.setTimeoutTimestamp(5000) + assert(state.getTimeoutTimestamp === 5000) // can be set after removing state testTimeoutDurationNotAllowed[UnsupportedOperationException](state) - testTimeoutTimestampNotAllowed[IllegalStateException](state) } test("GroupState - illegal params to setTimeout****") { @@ -154,26 +159,54 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf // Test setTimeout****() with illegal values def testIllegalTimeout(body: => Unit): Unit = { - intercept[IllegalArgumentException] { body } + intercept[IllegalArgumentException] { + body + } assert(state.getTimeoutTimestamp === NO_TIMESTAMP) } state = new GroupStateImpl(Some(5), 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) - testIllegalTimeout { state.setTimeoutDuration(-1000) } - testIllegalTimeout { state.setTimeoutDuration(0) } - testIllegalTimeout { state.setTimeoutDuration("-2 second") } - testIllegalTimeout { state.setTimeoutDuration("-1 month") } - testIllegalTimeout { state.setTimeoutDuration("1 month -1 day") } + testIllegalTimeout { + state.setTimeoutDuration(-1000) + } + testIllegalTimeout { + state.setTimeoutDuration(0) + } + testIllegalTimeout { + state.setTimeoutDuration("-2 second") + } + testIllegalTimeout { + state.setTimeoutDuration("-1 month") + } + testIllegalTimeout { + state.setTimeoutDuration("1 month -1 day") + } state = new GroupStateImpl(Some(5), 1000, 1000, EventTimeTimeout, hasTimedOut = false) - testIllegalTimeout { state.setTimeoutTimestamp(-10000) } - testIllegalTimeout { state.setTimeoutTimestamp(10000, "-3 second") } - testIllegalTimeout { state.setTimeoutTimestamp(10000, "-1 month") } - testIllegalTimeout { state.setTimeoutTimestamp(10000, "1 month -1 day") } - testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000)) } - testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000), "-3 second") } - testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000), "-1 month") } - testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000), "1 month -1 day") } + testIllegalTimeout { + state.setTimeoutTimestamp(-10000) + } + testIllegalTimeout { + state.setTimeoutTimestamp(10000, "-3 second") + } + testIllegalTimeout { + state.setTimeoutTimestamp(10000, "-1 month") + } + testIllegalTimeout { + state.setTimeoutTimestamp(10000, "1 month -1 day") + } + testIllegalTimeout { + state.setTimeoutTimestamp(new Date(-10000)) + } + testIllegalTimeout { + state.setTimeoutTimestamp(new Date(-10000), "-3 second") + } + testIllegalTimeout { + state.setTimeoutTimestamp(new Date(-10000), "-1 month") + } + testIllegalTimeout { + state.setTimeoutTimestamp(new Date(-10000), "1 month -1 day") + } } test("GroupState - hasTimedOut") { @@ -318,6 +351,44 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } } + // Currently disallowed cases for StateStoreUpdater.updateStateForKeysWithData(), + // Try to remove these cases in the future + for (priorTimeoutTimestamp <- Seq(NO_TIMESTAMP, 1000)) { + val testName = + if (priorTimeoutTimestamp != NO_TIMESTAMP) "prior timeout set" else "no prior timeout" + testStateUpdateWithData( + s"ProcessingTimeTimeout - $testName - setting timeout without init state not allowed", + stateUpdates = state => { state.setTimeoutDuration(5000) }, + timeoutConf = ProcessingTimeTimeout, + priorState = None, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedException = classOf[IllegalStateException]) + + testStateUpdateWithData( + s"ProcessingTimeTimeout - $testName - setting timeout with state removal not allowed", + stateUpdates = state => { state.remove(); state.setTimeoutDuration(5000) }, + timeoutConf = ProcessingTimeTimeout, + priorState = Some(5), + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedException = classOf[IllegalStateException]) + + testStateUpdateWithData( + s"EventTimeTimeout - $testName - setting timeout without init state not allowed", + stateUpdates = state => { state.setTimeoutTimestamp(10000) }, + timeoutConf = EventTimeTimeout, + priorState = None, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedException = classOf[IllegalStateException]) + + testStateUpdateWithData( + s"EventTimeTimeout - $testName - setting timeout with state removal not allowed", + stateUpdates = state => { state.remove(); state.setTimeoutTimestamp(10000) }, + timeoutConf = EventTimeTimeout, + priorState = Some(5), + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedException = classOf[IllegalStateException]) + } + // Tests for StateStoreUpdater.updateStateForTimedOutKeys() val preTimeoutState = Some(5) for (timeoutConf <- Seq(ProcessingTimeTimeout, EventTimeTimeout)) { @@ -806,7 +877,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf priorState: Option[Int], priorTimeoutTimestamp: Long = NO_TIMESTAMP, expectedState: Option[Int] = None, - expectedTimeoutTimestamp: Long = NO_TIMESTAMP): Unit = { + expectedTimeoutTimestamp: Long = NO_TIMESTAMP, + expectedException: Class[_ <: Exception] = null): Unit = { if (priorState.isEmpty && priorTimeoutTimestamp != NO_TIMESTAMP) { return // there can be no prior timestamp, when there is no prior state @@ -820,7 +892,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } testStateUpdate( testTimeoutUpdates = false, mapGroupsFunc, timeoutConf, - priorState, priorTimeoutTimestamp, expectedState, expectedTimeoutTimestamp) + priorState, priorTimeoutTimestamp, + expectedState, expectedTimeoutTimestamp, expectedException) } } @@ -839,9 +912,10 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf stateUpdates(state) Iterator.empty } + testStateUpdate( testTimeoutUpdates = true, mapGroupsFunc, timeoutConf = timeoutConf, - preTimeoutState, priorTimeoutTimestamp, expectedState, expectedTimeoutTimestamp) + preTimeoutState, priorTimeoutTimestamp, expectedState, expectedTimeoutTimestamp, null) } } @@ -852,7 +926,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf priorState: Option[Int], priorTimeoutTimestamp: Long, expectedState: Option[Int], - expectedTimeoutTimestamp: Long): Unit = { + expectedTimeoutTimestamp: Long, + expectedException: Class[_ <: Exception]): Unit = { val store = newStateStore() val mapGroupsSparkPlan = newFlatMapGroupsWithStateExec( @@ -867,22 +942,30 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } // Call updating function to update state store - val returnedIter = if (testTimeoutUpdates) { - updater.updateStateForTimedOutKeys() - } else { - updater.updateStateForKeysWithData(Iterator(key)) + def callFunction() = { + val returnedIter = if (testTimeoutUpdates) { + updater.updateStateForTimedOutKeys() + } else { + updater.updateStateForKeysWithData(Iterator(key)) + } + returnedIter.size // consume the iterator to force state updates } - returnedIter.size // consumer the iterator to force state updates - - // Verify updated state in store - val updatedStateRow = store.get(key) - assert( - updater.getStateObj(updatedStateRow).map(_.toString.toInt) === expectedState, - "final state not as expected") - if (updatedStateRow.nonEmpty) { + if (expectedException != null) { + // Call function and verify the exception type + val e = intercept[Exception] { callFunction() } + assert(e.getClass === expectedException, "Exception thrown but of the wrong type") + } else { + // Call function to update and verify updated state in store + callFunction() + val updatedStateRow = store.get(key) assert( - updater.getTimeoutTimestamp(updatedStateRow.get) === expectedTimeoutTimestamp, - "final timeout timestamp not as expected") + updater.getStateObj(updatedStateRow).map(_.toString.toInt) === expectedState, + "final state not as expected") + if (updatedStateRow.nonEmpty) { + assert( + updater.getTimeoutTimestamp(updatedStateRow.get) === expectedTimeoutTimestamp, + "final timeout timestamp not as expected") + } } }