[SPARK-34891][SS] Introduce state store manager for session window in streaming query

Introduction: this PR is a part of SPARK-10816 (`EventTime based sessionization (session window)`). Please refer #31937 to see the overall view of the code change. (Note that code diff could be diverged a bit.)

### What changes were proposed in this pull request?

This PR introduces state store manager for session window in streaming query. Session window in batch query wouldn't need to leverage state store manager.

This PR ensures versioning on state format for state store manager, so that we can apply further optimization after releasing Spark version. StreamingSessionWindowStateManager is a trait defining the available methods in session window state store manager. Its subclasses are classes implementing the trait with versioning.

The format of version 1 leverages the new feature of "prefix match scan" to represent the session windows:

* full key : [ group keys, start time in session window ]
* prefix key [ group keys ]

### Why are the changes needed?

This part is a one of required on implementing SPARK-10816.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New test suite added

Closes #31989 from HeartSaVioR/SPARK-34891-SPARK-10816-PR-31570-part-3.

Authored-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
Signed-off-by: Liang-Chi Hsieh <viirya@gmail.com>
(cherry picked from commit 0fe2d809d6)
Signed-off-by: Liang-Chi Hsieh <viirya@gmail.com>
This commit is contained in:
Jungtaek Lim 2021-07-13 08:58:31 -07:00 committed by Liang-Chi Hsieh
parent 3ace01b25b
commit fa8c37acb1
2 changed files with 464 additions and 0 deletions

View file

@ -0,0 +1,269 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.streaming.state
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.types.{StructType, TimestampType}
import org.apache.spark.util.NextIterator
sealed trait StreamingSessionWindowStateManager extends Serializable {
/**
* Returns the schema for key of the state.
*/
def getStateKeySchema: StructType
/**
* Returns the schema for value of the state.
*/
def getStateValueSchema: StructType
/**
* Returns the number of columns for `prefix key` in key schema.
*/
def getNumColsForPrefixKey: Int
/**
* Extracts the key without session window from the row.
* This can be used to group session windows by key.
*/
def extractKeyWithoutSession(value: UnsafeRow): UnsafeRow
/**
* Returns true if the session of the given value doesn't exist in the store, or the value
* in the session is different to the stored value of the session in the store.
* This can be used to control the output in UPDATE mode.
*/
def newOrModified(store: ReadStateStore, value: UnsafeRow): Boolean
/**
* Returns all sessions for the key.
*
* @param key The key without session, which can be retrieved from
* {@code extractKeyWithoutSession}.
*/
def getSessions(store: ReadStateStore, key: UnsafeRow): Iterator[UnsafeRow]
/**
* Replaces all sessions for the key to given one.
*
* @param key The key without session, which can be retrieved from
* {@code extractKeyWithoutSession}.
* @param sessions The all sessions including existing sessions if it's active.
* Existing sessions which aren't included in this parameter will be removed.
*/
def updateSessions(store: StateStore, key: UnsafeRow, sessions: Seq[UnsafeRow]): Unit
/**
* Removes using a predicate on values, with returning removed values via iterator.
*
* At a high level, this produces an iterator over the (key, value, matched) tuples such that
* value satisfies the predicate, where producing an element removes the value from the
* state store and producing all elements with a given key updates it accordingly.
*
* This implies the iterator must be consumed fully without any other operations on this manager
* or the underlying store being interleaved.
*
* @param removalCondition The predicate on removing the key-value.
*/
def removeByValueCondition(
store: StateStore,
removalCondition: UnsafeRow => Boolean): Iterator[UnsafeRow]
/**
* Return an iterator containing all the sessions. Implementations must ensure that updates
* (puts, removes) can be made while iterating over this iterator.
*/
def iterator(store: ReadStateStore): Iterator[UnsafeRow]
/**
* Commits the change.
*/
def commit(store: StateStore): Long
/**
* Aborts the change.
*/
def abortIfNeeded(store: StateStore): Unit
}
object StreamingSessionWindowStateManager {
val supportedVersions = Seq(1)
def createStateManager(
keyWithoutSessionExpressions: Seq[Attribute],
sessionExpression: Attribute,
inputRowAttributes: Seq[Attribute],
stateFormatVersion: Int): StreamingSessionWindowStateManager = {
stateFormatVersion match {
case 1 => new StreamingSessionWindowStateManagerImplV1(
keyWithoutSessionExpressions, sessionExpression, inputRowAttributes)
case _ => throw new IllegalArgumentException(s"Version $stateFormatVersion is invalid")
}
}
}
class StreamingSessionWindowStateManagerImplV1(
keyWithoutSessionExpressions: Seq[Attribute],
sessionExpression: Attribute,
valueAttributes: Seq[Attribute])
extends StreamingSessionWindowStateManager with Logging {
private val stateKeyStructType = keyWithoutSessionExpressions.toStructType
.add("sessionStartTime", TimestampType, nullable = false)
private val stateKeyExprs = keyWithoutSessionExpressions :+ Literal(1L)
private val indexOrdinalInSessionStart = keyWithoutSessionExpressions.size
@transient private lazy val keyRowGenerator = UnsafeProjection.create(
keyWithoutSessionExpressions, valueAttributes)
@transient private lazy val stateKeyRowGenerator = UnsafeProjection.create(stateKeyExprs,
keyWithoutSessionExpressions)
@transient private lazy val helper = new StreamingSessionWindowHelper(
sessionExpression, valueAttributes)
override def getStateKeySchema: StructType = stateKeyStructType
override def getStateValueSchema: StructType = valueAttributes.toStructType
override def getNumColsForPrefixKey: Int = keyWithoutSessionExpressions.length
override def extractKeyWithoutSession(value: UnsafeRow): UnsafeRow = {
keyRowGenerator(value)
}
override def newOrModified(store: ReadStateStore, value: UnsafeRow): Boolean = {
val sessionStart = helper.extractTimePair(value)._1
val stateKey = getStateKey(getKey(value), sessionStart)
val stateRow = store.get(stateKey)
stateRow == null || !stateRow.equals(value)
}
override def getSessions(store: ReadStateStore, key: UnsafeRow): Iterator[UnsafeRow] =
getSessionsWithKeys(store, key).map(_.value)
private def getSessionsWithKeys(
store: ReadStateStore,
key: UnsafeRow): Iterator[UnsafeRowPair] = {
store.prefixScan(key)
}
override def updateSessions(
store: StateStore,
key: UnsafeRow,
sessions: Seq[UnsafeRow]): Unit = {
// Below two will be used multiple times - need to make sure this is not a stream or iterator.
val newValues = sessions.toList
val savedStates = getSessionsWithKeys(store, key)
.map(pair => (pair.key.copy(), pair.value.copy())).toList
putRows(store, key, savedStates, newValues)
}
override def commit(store: StateStore): Long = store.commit()
override def iterator(store: ReadStateStore): Iterator[UnsafeRow] =
iteratorWithKeys(store).map(_.value)
private def iteratorWithKeys(store: ReadStateStore): Iterator[UnsafeRowPair] = {
store.iterator()
}
override def removeByValueCondition(
store: StateStore,
removalCondition: UnsafeRow => Boolean): Iterator[UnsafeRow] = {
new NextIterator[UnsafeRow] {
private val rangeIter = iteratorWithKeys(store)
override protected def getNext(): UnsafeRow = {
var removedValueRow: UnsafeRow = null
while (rangeIter.hasNext && removedValueRow == null) {
val rowPair = rangeIter.next()
if (removalCondition(rowPair.value)) {
store.remove(rowPair.key)
removedValueRow = rowPair.value
}
}
if (removedValueRow == null) {
finished = true
null
} else {
removedValueRow
}
}
override protected def close(): Unit = {}
}
}
private def getKey(value: UnsafeRow): UnsafeRow = keyRowGenerator(value)
private def getStateKey(key: UnsafeRow, sessionStart: Long): UnsafeRow = {
val stateKey = stateKeyRowGenerator(key)
stateKey.setLong(indexOrdinalInSessionStart, sessionStart)
stateKey.copy()
}
private def putRows(
store: StateStore,
key: UnsafeRow,
oldValues: List[(UnsafeRow, UnsafeRow)],
values: List[UnsafeRow]): Unit = {
// Here the key doesn't represent the state key - we need to construct the key for state
val keyAndValues = values.map { row =>
val sessionStart = helper.extractTimePair(row)._1
val stateKey = getStateKey(key, sessionStart)
(stateKey, row)
}
val keysForValues = keyAndValues.map(_._1)
val keysForOldValues = oldValues.map(_._1)
// We should "replace" the value instead of "delete" and "put" if the start time
// equals to. This will remove unnecessary tombstone being written to the delta, which is
// implementation details on state store implementations.
keysForOldValues.filterNot(keysForValues.contains).foreach { oldKey =>
store.remove(oldKey)
}
keyAndValues.foreach { case (key, value) =>
store.put(key, value)
}
}
override def abortIfNeeded(store: StateStore): Unit = {
if (!store.hasCommitted) {
logInfo(s"Aborted store ${store.id}")
store.abort()
}
}
}
class StreamingSessionWindowHelper(sessionExpression: Attribute, inputSchema: Seq[Attribute]) {
private val sessionProjection: UnsafeProjection =
GenerateUnsafeProjection.generate(Seq(sessionExpression), inputSchema)
/** extract session_window (start, end) from UnsafeRow */
def extractTimePair(value: InternalRow): (Long, Long) = {
val window = sessionProjection(value).getStruct(0, 2)
(window.getLong(0), window.getLong(1))
}
}

View file

@ -0,0 +1,195 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.streaming.state
import java.util.UUID
import org.apache.hadoop.conf.Configuration
import org.scalatest.BeforeAndAfter
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.StreamTest
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String
class StreamingSessionWindowStateManagerSuite extends StreamTest with BeforeAndAfter {
private val rowSchema = new StructType().add("key1", StringType).add("key2", IntegerType)
.add("session", new StructType().add("start", LongType).add("end", LongType))
.add("value", LongType)
private val rowAttributes = rowSchema.toAttributes
private val keysWithoutSessionAttributes = rowAttributes.filter { attr =>
List("key1", "key2").contains(attr.name)
}
private val sessionAttribute = rowAttributes.filter(_.name == "session").head
private val inputValueGen = UnsafeProjection.create(rowAttributes.map(_.dataType).toArray)
private val inputKeyGen = UnsafeProjection.create(
keysWithoutSessionAttributes.map(_.dataType).toArray)
before {
SparkSession.setActiveSession(spark)
spark.streams.stateStoreCoordinator // initialize the lazy coordinator
}
private val providerOptions = Seq(
classOf[HDFSBackedStateStoreProvider].getCanonicalName,
classOf[RocksDBStateStoreProvider].getCanonicalName).map { value =>
(SQLConf.STATE_STORE_PROVIDER_CLASS.key, value.stripSuffix("$"))
}
private val availableOptions = for (
opt1 <- providerOptions;
opt2 <- StreamingSessionWindowStateManager.supportedVersions
) yield (opt1, opt2)
availableOptions.foreach { case (providerOpt, version) =>
withSQLConf(providerOpt) {
test(s"StreamingSessionWindowStateManager " +
s"provider ${providerOpt._2} state version v${version} - extract field operations") {
testExtractFieldOperations(version)
}
test(s"StreamingSessionWindowStateManager " +
s"provider ${providerOpt._2} state version v${version} - CRUD operations") {
testAllOperations(version)
}
}
}
private def createRow(
key1: String,
key2: Int,
sessionStart: Long,
sessionEnd: Long,
value: Long): UnsafeRow = {
val sessionRow = new GenericInternalRow(Array[Any](sessionStart, sessionEnd))
val row = new GenericInternalRow(
Array[Any](UTF8String.fromString(key1), key2, sessionRow, value))
inputValueGen.apply(row).copy()
}
private def createKeyRow(key1: String, key2: Int): UnsafeRow = {
val row = new GenericInternalRow(Array[Any](UTF8String.fromString(key1), key2))
inputKeyGen.apply(row).copy()
}
private def testExtractFieldOperations(stateFormatVersion: Int): Unit = {
withStateManager(stateFormatVersion) { case (stateManager, _) =>
val testRow = createRow("a", 1, 100, 150, 1)
val expectedKeyRow = createKeyRow("a", 1)
val keyWithoutSessionRow = stateManager.extractKeyWithoutSession(testRow)
assert(expectedKeyRow === keyWithoutSessionRow)
}
}
private def testAllOperations(stateFormatVersion: Int): Unit = {
withStateManager(stateFormatVersion) { case (stateManager, store) =>
def updateAndVerify(keyRow: UnsafeRow, rows: Seq[UnsafeRow]): Unit = {
stateManager.updateSessions(store, keyRow, rows)
val expectedValues = stateManager.getSessions(store, keyRow).map(_.copy()).toList
assert(expectedValues.toSet === rows.toSet)
rows.foreach { row =>
assert(!stateManager.newOrModified(store, row))
}
}
val key1Row = createKeyRow("a", 1)
val key1Values = Seq(
createRow("a", 1, 100, 110, 1),
createRow("a", 1, 120, 130, 2),
createRow("a", 1, 140, 150, 3))
updateAndVerify(key1Row, key1Values)
val key2Row = createKeyRow("a", 2)
val key2Values = Seq(
createRow("a", 2, 70, 80, 1),
createRow("a", 2, 100, 110, 2))
updateAndVerify(key2Row, key2Values)
val key2NewValues = Seq(
createRow("a", 2, 70, 80, 2),
createRow("a", 2, 80, 90, 3),
createRow("a", 2, 90, 120, 4),
createRow("a", 2, 140, 150, 5))
updateAndVerify(key2Row, key2NewValues)
val key3Row = createKeyRow("a", 3)
val key3Values = Seq(
createRow("a", 3, 100, 110, 1),
createRow("a", 3, 120, 130, 2))
updateAndVerify(key3Row, key3Values)
val valuesOnComparison = Seq(
// new
(createRow("a", 3, 10, 20, 1), true),
// modified
(createRow("a", 3, 100, 110, 3), true),
// exist and not modified
(createRow("a", 3, 120, 130, 2), false))
valuesOnComparison.foreach { case (row, expected) =>
assert(expected === stateManager.newOrModified(store, row))
}
val existingRows = stateManager.iterator(store).map(_.copy()).toSet
val removedRows = stateManager.removeByValueCondition(store,
_.getLong(3) <= 1).map(_.copy()).toSet
val expectedRemovedRows = Set(key1Values(0), key3Values(0))
assert(removedRows == expectedRemovedRows)
val afterRemovingRows = stateManager.iterator(store).map(_.copy()).toSet
assert(existingRows.diff(afterRemovingRows) === expectedRemovedRows)
}
}
private def withStateManager(
stateFormatVersion: Int)(
f: (StreamingSessionWindowStateManager, StateStore) => Unit): Unit = {
withTempDir { file =>
val storeConf = new StateStoreConf()
val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0, 5)
val manager = StreamingSessionWindowStateManager.createStateManager(
keysWithoutSessionAttributes,
sessionAttribute,
rowAttributes,
stateFormatVersion)
val storeProviderId = StateStoreProviderId(stateInfo, 0, StateStoreId.DEFAULT_STORE_NAME)
val store = StateStore.get(
storeProviderId, manager.getStateKeySchema, manager.getStateValueSchema,
manager.getNumColsForPrefixKey, stateInfo.storeVersion, storeConf, new Configuration)
try {
f(manager, store)
} finally {
manager.abortIfNeeded(store)
}
}
}
}