diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowStateIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowStateIterator.scala
new file mode 100644
index 0000000000..a923ebd798
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowStateIterator.scala
@@ -0,0 +1,168 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.execution.streaming.state.{ReadStateStore, StreamingSessionWindowStateManager}
+
+/**
+ * This class performs a merge sort between input rows and existing sessions in state, to avoid
+ * the cost of sorting "input rows + existing sessions" as a whole. It relies on the precondition
+ * that input rows are sorted by "group keys + start time of session window".
+ *
+ * Only the existing sessions are materialized into memory, and there tend to be few of them per
+ * group key. Under the same assumption, the cost of sorting the existing sessions is also minor.
+ *
+ * The output rows are sorted by "group keys + start time of session window", which is the same
+ * as the sort condition on input rows.
+ */
+class MergingSortWithSessionWindowStateIterator(
+    iter: Iterator[InternalRow],
+    stateManager: StreamingSessionWindowStateManager,
+    store: ReadStateStore,
+    groupWithoutSessionExpressions: Seq[Attribute],
+    sessionExpression: Attribute,
+    inputSchema: Seq[Attribute]) extends Iterator[InternalRow] with Logging {
+
+  private val keysProjection: UnsafeProjection = GenerateUnsafeProjection.generate(
+    groupWithoutSessionExpressions, inputSchema)
+  private val sessionProjection: UnsafeProjection =
+    GenerateUnsafeProjection.generate(Seq(sessionExpression), inputSchema)
+
+  private case class SessionRowInformation(
+      keys: UnsafeRow,
+      sessionStart: Long,
+      sessionEnd: Long,
+      row: InternalRow)
+
+  private object SessionRowInformation {
+    def of(row: InternalRow): SessionRowInformation = {
+      val keys = keysProjection(row).copy()
+      val session = sessionProjection(row).copy()
+      val sessionRow = session.getStruct(0, 2)
+      val sessionStart = sessionRow.getLong(0)
+      val sessionEnd = sessionRow.getLong(1)
+
+      SessionRowInformation(keys, sessionStart, sessionEnd, row)
+    }
+  }
+
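+  // Merge bookkeeping: at most one pending row is held from each side. next() compares the
+  // pending input row against the pending state row and emits whichever sorts first by
+  // (group keys, session start); the consumed side is refilled lazily on the following call.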
+  // Holds the latest fetched row from the input side iterator.
+  private var currentRowFromInput: SessionRowInformation = _
+
+  // Holds the latest fetched row from the state side iterator.
+  private var currentRowFromState: SessionRowInformation = _
+
+  // Holds the iterator of rows (sessions) in state for the current session key.
+  private var sessionIterFromState: Iterator[InternalRow] = _
+
+  // Holds the current session key.
+  private var currentSessionKey: UnsafeRow = _
+
+  override def hasNext: Boolean = {
+    currentRowFromInput != null || currentRowFromState != null ||
+      (sessionIterFromState != null && sessionIterFromState.hasNext) || iter.hasNext
+  }
+
+  override def next(): InternalRow = {
+    if (currentRowFromInput == null) {
+      mayFillCurrentRow()
+    }
+
+    if (currentRowFromState == null) {
+      mayFillCurrentStateRow()
+    }
+
+    if (currentRowFromInput == null && currentRowFromState == null) {
+      throw new IllegalStateException("No row to provide in next() which should not happen!")
+    }
+
+    // Pick between the current input row and the current state row: emit the row with the
+    // smaller key first and, for the same key, the one with the earlier session start.
+    val returnCurrentRow: Boolean = {
+      if (currentRowFromInput == null) {
+        false
+      } else if (currentRowFromState == null) {
+        true
+      } else {
+        if (currentRowFromInput.keys != currentRowFromState.keys) {
+          // The state row was fetched for an earlier group key than the current input row
+          // (state rows are only fetched for keys seen in the input), so it sorts first.
+          false
+        } else {
+          currentRowFromInput.sessionStart < currentRowFromState.sessionStart
+        }
+      }
+    }
+
+    val ret: SessionRowInformation = {
+      if (returnCurrentRow) {
+        val toRet = currentRowFromInput
+        currentRowFromInput = null
+        toRet
+      } else {
+        val toRet = currentRowFromState
+        currentRowFromState = null
+        toRet
+      }
+    }
+
+    ret.row
+  }
+
+  private def mayFillCurrentRow(): Unit = {
+    if (iter.hasNext) {
+      currentRowFromInput = SessionRowInformation.of(iter.next())
+    }
+  }
+
+  private def mayFillCurrentStateRow(): Unit = {
+    if (sessionIterFromState != null && sessionIterFromState.hasNext) {
+      currentRowFromState = SessionRowInformation.of(sessionIterFromState.next())
+    } else {
+      sessionIterFromState = null
+
+      if (currentRowFromInput != null && currentRowFromInput.keys != currentSessionKey) {
+        // We expect a small number of sessions per group key, so materializing and sorting
+        // them wouldn't hurt much. The important thing is that we shouldn't buffer input
+        // rows to sort with existing sessions.
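+        // The pending input row belongs to a group key whose sessions haven't been loaded yet:
+        // fetch that key's sessions from the state store and sort them by session start, so
+        // that the merge precondition also holds on the state side.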
+        val unsortedIter = stateManager.getSessions(store, currentRowFromInput.keys)
+        val unsortedList = unsortedIter.map(_.copy()).toList
+
+        val sortedList = unsortedList.sortWith((row1, row2) => {
+          def getSessionStart(r: InternalRow): Long = {
+            val session = sessionProjection(r)
+            val sessionRow = session.getStruct(0, 2)
+            sessionRow.getLong(0)
+          }
+
+          // Sorting by session start alone is sufficient here, as all rows share the same keys.
+          getSessionStart(row1).compareTo(getSessionStart(row2)) < 0
+        })
+        sessionIterFromState = sortedList.iterator
+
+        currentSessionKey = currentRowFromInput.keys
+        if (sessionIterFromState.hasNext) {
+          currentRowFromState = SessionRowInformation.of(sessionIterFromState.next())
+        }
+      }
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index c604021e9c..75b7daef57 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -70,7 +70,7 @@ import org.apache.spark.util.{SizeEstimator, Utils}
  * to ensure re-executed RDD operations re-apply updates on the correct past version of the
  * store.
  */
-private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider with Logging {
+private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with Logging {
 
   class HDFSBackedReadStateStore(val version: Long, map: HDFSBackedStateStoreMap)
     extends ReadStateStore {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowStateIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowStateIteratorSuite.scala
new file mode 100644
index 0000000000..81f1a3f785
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowStateIteratorSuite.scala
@@ -0,0 +1,231 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.UUID
+
+import org.apache.hadoop.conf.Configuration
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider, StateStore, StateStoreConf, StateStoreId, StateStoreProviderId, StreamingSessionWindowStateManager}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.StreamTest
+import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}
+import org.apache.spark.unsafe.types.UTF8String
+
+class MergingSortWithSessionWindowStateIteratorSuite extends StreamTest with BeforeAndAfter {
+
+  private val rowSchema = new StructType().add("key1", StringType).add("key2", IntegerType)
+    .add("session", new StructType().add("start", LongType).add("end", LongType))
+    .add("value", LongType)
+  private val rowAttributes = rowSchema.toAttributes
+
+  private val keysWithoutSessionAttributes = rowAttributes.filter { attr =>
+    List("key1", "key2").contains(attr.name)
+  }
+
+  private val sessionAttribute = rowAttributes.filter(_.name == "session").head
+
+  private val inputValueGen = UnsafeProjection.create(rowAttributes.map(_.dataType).toArray)
+  private val inputKeyGen = UnsafeProjection.create(
+    keysWithoutSessionAttributes.map(_.dataType).toArray)
+
+  before {
+    SparkSession.setActiveSession(spark)
+    spark.streams.stateStoreCoordinator // initialize the lazy coordinator
+  }
+
+  private val providerOptions = Seq(
+    classOf[HDFSBackedStateStoreProvider].getCanonicalName,
+    classOf[RocksDBStateStoreProvider].getCanonicalName).map { value =>
+    (SQLConf.STATE_STORE_PROVIDER_CLASS.key, value.stripSuffix("$"))
+  }
+
+  private val availableOptions = for (
+    opt1 <- providerOptions;
+    opt2 <- StreamingSessionWindowStateManager.supportedVersions
+  ) yield (opt1, opt2)
+
+  availableOptions.foreach { case (providerOpt, version) =>
+    withSQLConf(providerOpt) {
+      test(s"MergingSortWithSessionWindowStateIterator " +
+        s"provider ${providerOpt._2} state version v${version} - rows only in state") {
+        testRowsOnlyInState(version)
+      }
+
+      test(s"MergingSortWithSessionWindowStateIterator " +
+        s"provider ${providerOpt._2} state version v${version} - rows in both input and state") {
+        testRowsInBothInputAndState(version)
+      }
+
+      test(s"MergingSortWithSessionWindowStateIterator " +
+        s"provider ${providerOpt._2} state version v${version} - rows only in input") {
+        testRowsOnlyInInput(version)
+      }
+    }
+  }
+
+  private def testRowsOnlyInState(stateFormatVersion: Int): Unit = {
+    withStateManager(stateFormatVersion) { case (stateManager, store) =>
+      val key = createKeyRow("a", 1)
+      val values = Seq(
+        createRow("a", 1, 100, 110, 1),
+        createRow("a", 1, 120, 130, 2),
+        createRow("a", 1, 140, 150, 3))
+
+      stateManager.updateSessions(store, key, values)
+
+      val iter = new MergingSortWithSessionWindowStateIterator(
+        Iterator.empty,
+        stateManager,
+        store,
+        keysWithoutSessionAttributes,
+        sessionAttribute,
+        rowAttributes)
+
+      // State rows are only fetched for group keys that appear in the input, so an empty
+      // input produces an empty output even though the state has sessions.
+      val actual = iter.map(_.copy()).toList
+      assert(actual.isEmpty)
+    }
+  }
+
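+  // Core merge scenario: state holds sessions for key1/key2/key3, while the input carries new
+  // rows for key1 and key3 only. Each key's output must interleave state and input rows by
+  // session start, and key2's sessions must not be emitted since key2 never appears in input.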
+  private def testRowsInBothInputAndState(stateFormatVersion: Int): Unit = {
+    withStateManager(stateFormatVersion) { case (stateManager, store) =>
+      val key1 = createKeyRow("a", 1)
+      val key1Values = Seq(
+        createRow("a", 1, 100, 110, 1),
+        createRow("a", 1, 120, 130, 2),
+        createRow("a", 1, 140, 150, 3))
+
+      // key2 exists only in state; this ensures sessions are not populated for group keys
+      // that don't appear in the input
+      val key2 = createKeyRow("a", 2)
+      val key2Values = Seq(
+        createRow("a", 2, 100, 110, 1),
+        createRow("a", 2, 120, 130, 2),
+        createRow("a", 2, 140, 150, 3))
+
+      val key3 = createKeyRow("b", 1)
+      val key3Values = Seq(
+        createRow("b", 1, 100, 110, 1),
+        createRow("b", 1, 120, 130, 2),
+        createRow("b", 1, 140, 150, 3))
+
+      stateManager.updateSessions(store, key1, key1Values)
+      stateManager.updateSessions(store, key2, key2Values)
+      stateManager.updateSessions(store, key3, key3Values)
+
+      val inputsForKey1 = Seq(
+        createRow("a", 1, 90, 100, 1),
+        createRow("a", 1, 125, 135, 2))
+      val inputsForKey3 = Seq(
+        createRow("b", 1, 150, 160, 3))
+      val inputs = inputsForKey1 ++ inputsForKey3
+
+      val iter = new MergingSortWithSessionWindowStateIterator(
+        inputs.iterator,
+        stateManager,
+        store,
+        keysWithoutSessionAttributes,
+        sessionAttribute,
+        rowAttributes)
+
+      val actual = iter.map(_.copy()).toList
+      val expected = (key1Values ++ inputsForKey1).sortBy(getSessionStart) ++
+        (key3Values ++ inputsForKey3).sortBy(getSessionStart)
+      assert(actual === expected.toList)
+    }
+  }
+
+  private def testRowsOnlyInInput(stateFormatVersion: Int): Unit = {
+    withStateManager(stateFormatVersion) { case (stateManager, store) =>
+      // key1 exists only in state; this ensures sessions are not populated for group keys
+      // that don't appear in the input
+      val key1 = createKeyRow("a", 1)
+      val key1Values = Seq(
+        createRow("a", 1, 100, 110, 1),
+        createRow("a", 1, 120, 130, 2),
+        createRow("a", 1, 140, 150, 3))
+
+      stateManager.updateSessions(store, key1, key1Values)
+
+      val inputs = Seq(
+        createRow("b", 1, 100, 110, 1),
+        createRow("b", 1, 120, 130, 2),
+        createRow("b", 1, 140, 150, 3))
+
+      val iter = new MergingSortWithSessionWindowStateIterator(
+        inputs.iterator,
+        stateManager,
+        store,
+        keysWithoutSessionAttributes,
+        sessionAttribute,
+        rowAttributes)
+
+      val actual = iter.map(_.copy()).toList
+      assert(actual === inputs.toList)
+    }
+  }
+
+  private def createRow(
+      key1: String,
+      key2: Int,
+      sessionStart: Long,
+      sessionEnd: Long,
+      value: Long): UnsafeRow = {
+    val sessionRow = new GenericInternalRow(Array[Any](sessionStart, sessionEnd))
+    val row = new GenericInternalRow(
+      Array[Any](UTF8String.fromString(key1), key2, sessionRow, value))
+    inputValueGen.apply(row).copy()
+  }
+
+  private def createKeyRow(key1: String, key2: Int): UnsafeRow = {
+    val row = new GenericInternalRow(Array[Any](UTF8String.fromString(key1), key2))
+    inputKeyGen.apply(row).copy()
+  }
+
+  private def getSessionStart(row: UnsafeRow): Long = {
+    row.getStruct(2, 2).getLong(0)
+  }
+
+  private def withStateManager(
+      stateFormatVersion: Int)(
+      f: (StreamingSessionWindowStateManager, StateStore) => Unit): Unit = {
+    withTempDir { file =>
+      val storeConf = new StateStoreConf()
+      val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0, 5)
+
+      val manager = StreamingSessionWindowStateManager.createStateManager(
+        keysWithoutSessionAttributes,
+        sessionAttribute,
+        rowAttributes,
+        stateFormatVersion)
+
+      val storeProviderId = StateStoreProviderId(stateInfo, 0, StateStoreId.DEFAULT_STORE_NAME)
+      val store = StateStore.get(
+        storeProviderId, manager.getStateKeySchema, manager.getStateValueSchema,
+        manager.getNumColsForPrefixKey, stateInfo.storeVersion, storeConf, new Configuration)
+
+      try {
+        f(manager, store)
+      } finally {
+        manager.abortIfNeeded(store)
+      }
+    }
+  }
+}