diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
new file mode 100644
index 0000000000..3ebaa8c87a
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
@@ -0,0 +1,331 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import java.io._
+
+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.{SparkConf, SparkEnv}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.unsafe.Platform
+import org.apache.spark.util.Utils
+
+private[state] class RocksDBStateStoreProvider
+  extends StateStoreProvider with Logging with Closeable {
+  import RocksDBStateStoreProvider._
+
+  class RocksDBStateStore(lastVersion: Long) extends StateStore {
+    /** Trait and classes representing the internal state of the store */
+    trait STATE
+    case object UPDATING extends STATE
+    case object COMMITTED extends STATE
+    case object ABORTED extends STATE
+
+    @volatile private var state: STATE = UPDATING
+    @volatile private var isValidated = false
+
+    override def id: StateStoreId = RocksDBStateStoreProvider.this.stateStoreId
+
+    override def version: Long = lastVersion
+
+    override def get(key: UnsafeRow): UnsafeRow = {
+      verify(key != null, "Key cannot be null")
+      val value = encoder.decodeValue(rocksDB.get(encoder.encode(key)))
+      if (!isValidated && value != null) {
+        StateStoreProvider.validateStateRowFormat(
+          key, keySchema, value, valueSchema, storeConf)
+        isValidated = true
+      }
+      value
+    }
+
+    override def put(key: UnsafeRow, value: UnsafeRow): Unit = {
+      verify(state == UPDATING, "Cannot put after already committed or aborted")
+      verify(key != null, "Key cannot be null")
+      require(value != null, "Cannot put a null value")
+      rocksDB.put(encoder.encode(key), encoder.encode(value))
+    }
+
+    override def remove(key: UnsafeRow): Unit = {
+      verify(state == UPDATING, "Cannot remove after already committed or aborted")
+      verify(key != null, "Key cannot be null")
+      rocksDB.remove(encoder.encode(key))
+    }
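+
+    // Note: getRange() below ignores the start/end bounds and conservatively falls
+    // back to a full iterator() scan of the store.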
+    override def getRange(
+        start: Option[UnsafeRow],
+        end: Option[UnsafeRow]): Iterator[UnsafeRowPair] = {
+      verify(state == UPDATING, "Cannot call getRange() after already committed or aborted")
+      iterator()
+    }
+
+    override def iterator(): Iterator[UnsafeRowPair] = {
+      rocksDB.iterator().map { kv =>
+        val rowPair = encoder.decode(kv)
+        if (!isValidated && rowPair.value != null) {
+          StateStoreProvider.validateStateRowFormat(
+            rowPair.key, keySchema, rowPair.value, valueSchema, storeConf)
+          isValidated = true
+        }
+        rowPair
+      }
+    }
+
+    override def commit(): Long = synchronized {
+      verify(state == UPDATING, "Cannot commit after already committed or aborted")
+      val newVersion = rocksDB.commit()
+      state = COMMITTED
+      logInfo(s"Committed $newVersion for $id")
+      newVersion
+    }
+
+    override def abort(): Unit = {
+      verify(state == UPDATING || state == ABORTED, "Cannot abort after already committed")
+      logInfo(s"Aborting ${version + 1} for $id")
+      rocksDB.rollback()
+      state = ABORTED
+    }
+
+    override def metrics: StateStoreMetrics = {
+      val rocksDBMetrics = rocksDB.metrics
+      def commitLatencyMs(typ: String): Long = rocksDBMetrics.lastCommitLatencyMs.getOrElse(typ, 0L)
+      def avgNativeOpsLatencyMs(typ: String): Long = {
+        rocksDBMetrics.nativeOpsLatencyMicros.get(typ).map(_.avg).getOrElse(0.0).toLong
+      }
+
+      val stateStoreCustomMetrics = Map[StateStoreCustomMetric, Long](
+        CUSTOM_METRIC_SST_FILE_SIZE -> rocksDBMetrics.totalSSTFilesBytes,
+        CUSTOM_METRIC_GET_TIME -> avgNativeOpsLatencyMs("get"),
+        CUSTOM_METRIC_PUT_TIME -> avgNativeOpsLatencyMs("put"),
+        CUSTOM_METRIC_WRITEBATCH_TIME -> commitLatencyMs("writeBatch"),
+        CUSTOM_METRIC_FLUSH_TIME -> commitLatencyMs("flush"),
+        CUSTOM_METRIC_PAUSE_TIME -> commitLatencyMs("pause"),
+        CUSTOM_METRIC_CHECKPOINT_TIME -> commitLatencyMs("checkpoint"),
+        CUSTOM_METRIC_FILESYNC_TIME -> commitLatencyMs("fileSync"),
+        CUSTOM_METRIC_BYTES_COPIED -> rocksDBMetrics.bytesCopied,
+        CUSTOM_METRIC_FILES_COPIED -> rocksDBMetrics.filesCopied,
+        CUSTOM_METRIC_FILES_REUSED -> rocksDBMetrics.filesReused
+      ) ++ rocksDBMetrics.zipFileBytesUncompressed.map(bytes =>
+        Map(CUSTOM_METRIC_ZIP_FILE_BYTES_UNCOMPRESSED -> bytes)).getOrElse(Map())
+
+      StateStoreMetrics(
+        rocksDBMetrics.numUncommittedKeys,
+        rocksDBMetrics.memUsageBytes,
+        stateStoreCustomMetrics)
+    }
+
+    override def hasCommitted: Boolean = state == COMMITTED
+
+    override def toString: String = {
+      s"RocksDBStateStore[id=(op=${id.operatorId},part=${id.partitionId})," +
+        s"dir=${id.storeCheckpointLocation()}]"
+    }
+
+    /** Return the [[RocksDB]] instance in this store. This is exposed mainly for testing. */
+    def dbInstance(): RocksDB = rocksDB
+  }
+
+  override def init(
+      stateStoreId: StateStoreId,
+      keySchema: StructType,
+      valueSchema: StructType,
+      indexOrdinal: Option[Int],
+      storeConf: StateStoreConf,
+      hadoopConf: Configuration): Unit = {
+    this.stateStoreId_ = stateStoreId
+    this.keySchema = keySchema
+    this.valueSchema = valueSchema
+    this.storeConf = storeConf
+    this.hadoopConf = hadoopConf
+    rocksDB // lazy initialization
+  }
+
+  override def stateStoreId: StateStoreId = stateStoreId_
+
+  override def getStore(version: Long): StateStore = {
+    require(version >= 0, "Version cannot be less than 0")
+    rocksDB.load(version)
+    new RocksDBStateStore(version)
+  }
+
+  override def doMaintenance(): Unit = {
+    rocksDB.cleanup()
+  }
+
+  override def close(): Unit = {
+    rocksDB.close()
+  }
+
+  override def supportedCustomMetrics: Seq[StateStoreCustomMetric] = ALL_CUSTOM_METRICS
+
+  private[state] def latestVersion: Long = rocksDB.getLatestVersion()
+
+  /** Internal fields and methods */
+
+  @volatile private var stateStoreId_ : StateStoreId = _
+  @volatile private var keySchema: StructType = _
+  @volatile private var valueSchema: StructType = _
+  @volatile private var storeConf: StateStoreConf = _
+  @volatile private var hadoopConf: Configuration = _
+
+  private[sql] lazy val rocksDB = {
+    val dfsRootDir = stateStoreId.storeCheckpointLocation().toString
+    val storeIdStr = s"StateStoreId(opId=${stateStoreId.operatorId}," +
+      s"partId=${stateStoreId.partitionId},name=${stateStoreId.storeName})"
+    val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf)
+    val localRootDir = Utils.createTempDir(Utils.getLocalDir(sparkConf), storeIdStr)
+    new RocksDB(dfsRootDir, RocksDBConf(storeConf), localRootDir, hadoopConf, storeIdStr)
+  }
+
+  private lazy val encoder = new StateEncoder
+
+  private def verify(condition: => Boolean, msg: String): Unit = {
+    if (!condition) { throw new IllegalStateException(msg) }
+  }
+
+  /**
+   * Encodes/decodes UnsafeRows to versioned byte arrays.
+   * It uses the first byte of the generated byte array to store the version that describes
+   * how the row is encoded in the rest of the byte array. Currently, the default version is 0.
+   *
+   * VERSION 0: [ VERSION (1 byte) | ROW (N bytes) ]
+   *    The bytes of an UnsafeRow are written unmodified starting from offset 1
+   *    (offset 0 is the version byte of value 0). That is, if the unsafe row has N bytes,
+   *    then the generated byte array will be N+1 bytes.
+   */
+  class StateEncoder {
+    import RocksDBStateStoreProvider._
+
+    // Reusable objects
+    private val keyRow = new UnsafeRow(keySchema.size)
+    private val valueRow = new UnsafeRow(valueSchema.size)
+    private val rowTuple = new UnsafeRowPair()
+
+    /**
+     * Encode the UnsafeRow of N bytes as an N+1 byte array.
+     * @note This creates a new byte array and memcopies the UnsafeRow to the new array.
+     */
+    def encode(row: UnsafeRow): Array[Byte] = {
+      val bytesToEncode = row.getBytes
+      val encodedBytes = new Array[Byte](bytesToEncode.length + STATE_ENCODING_NUM_VERSION_BYTES)
+      Platform.putByte(encodedBytes, Platform.BYTE_ARRAY_OFFSET, STATE_ENCODING_VERSION)
+      // Platform.BYTE_ARRAY_OFFSET is the recommended way to memcopy between byte arrays.
+      // See Platform.
+      Platform.copyMemory(
+        bytesToEncode, Platform.BYTE_ARRAY_OFFSET,
+        encodedBytes, Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES,
+        bytesToEncode.length)
+      encodedBytes
+    }
+
+    /**
+     * Decode byte array for a key to an UnsafeRow.
+     * @note The UnsafeRow returned is reused across calls, and the UnsafeRow just points to
+     *       the given byte array.
+     */
+    def decodeKey(keyBytes: Array[Byte]): UnsafeRow = {
+      if (keyBytes != null) {
+        // Platform.BYTE_ARRAY_OFFSET is the recommended way to refer to the 1st offset.
+        // See Platform.
+        keyRow.pointTo(
+          keyBytes,
+          Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES,
+          keyBytes.length - STATE_ENCODING_NUM_VERSION_BYTES)
+        keyRow
+      } else {
+        null
+      }
+    }
+
+    /**
+     * Decode byte array for a value to an UnsafeRow.
+     *
+     * @note The UnsafeRow returned is reused across calls, and the UnsafeRow just points to
+     *       the given byte array.
+     */
+    def decodeValue(valueBytes: Array[Byte]): UnsafeRow = {
+      if (valueBytes != null) {
+        // Platform.BYTE_ARRAY_OFFSET is the recommended way to refer to the 1st offset.
+        // See Platform.
+        valueRow.pointTo(
+          valueBytes,
+          Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES,
+          valueBytes.size - STATE_ENCODING_NUM_VERSION_BYTES)
+        valueRow
+      } else {
+        null
+      }
+    }
+
+    /**
+     * Decode a pair of key-value byte arrays into a pair of key-value UnsafeRows.
+     *
+     * @note The UnsafeRows returned are reused across calls, and they just point to
+     *       the given byte arrays.
+     */
+    def decode(byteArrayTuple: ByteArrayPair): UnsafeRowPair = {
+      rowTuple.withRows(decodeKey(byteArrayTuple.key), decodeValue(byteArrayTuple.value))
+    }
+  }
+}
+
+object RocksDBStateStoreProvider {
+  // Version as a single byte that specifies the encoding of the row data in RocksDB
+  val STATE_ENCODING_NUM_VERSION_BYTES = 1
+  val STATE_ENCODING_VERSION: Byte = 0
+
+  // Native operation latencies are reported as latency per 1000 calls,
+  // as SQLMetrics supports ms latency whereas RocksDB reports it in microseconds.
+  val CUSTOM_METRIC_GET_TIME = StateStoreCustomTimingMetric(
+    "rocksdbGetLatency", "RocksDB: avg get latency (per 1000 calls)")
+  val CUSTOM_METRIC_PUT_TIME = StateStoreCustomTimingMetric(
+    "rocksdbPutLatency", "RocksDB: avg put latency (per 1000 calls)")
+
+  // Commit latency detailed breakdown
+  val CUSTOM_METRIC_WRITEBATCH_TIME = StateStoreCustomTimingMetric(
+    "rocksdbCommitWriteBatchLatency", "RocksDB: commit - write batch time")
+  val CUSTOM_METRIC_FLUSH_TIME = StateStoreCustomTimingMetric(
+    "rocksdbCommitFlushLatency", "RocksDB: commit - flush time")
+  val CUSTOM_METRIC_PAUSE_TIME = StateStoreCustomTimingMetric(
+    "rocksdbCommitPauseLatency", "RocksDB: commit - pause bg time")
+  val CUSTOM_METRIC_CHECKPOINT_TIME = StateStoreCustomTimingMetric(
+    "rocksdbCommitCheckpointLatency", "RocksDB: commit - checkpoint time")
+  val CUSTOM_METRIC_FILESYNC_TIME = StateStoreCustomTimingMetric(
+    "rocksdbFileSyncTime", "RocksDB: commit - file sync time")
+  val CUSTOM_METRIC_FILES_COPIED = StateStoreCustomSizeMetric(
+    "rocksdbFilesCopied", "RocksDB: file manager - files copied")
+  val CUSTOM_METRIC_BYTES_COPIED = StateStoreCustomSizeMetric(
+    "rocksdbBytesCopied", "RocksDB: file manager - bytes copied")
+  val CUSTOM_METRIC_FILES_REUSED = StateStoreCustomSizeMetric(
+    "rocksdbFilesReused", "RocksDB: file manager - files reused")
+  val CUSTOM_METRIC_ZIP_FILE_BYTES_UNCOMPRESSED = StateStoreCustomSizeMetric(
+    "rocksdbZipFileBytesUncompressed", "RocksDB: file manager - uncompressed zip file bytes")
+
+  // Total SST file size
+  val CUSTOM_METRIC_SST_FILE_SIZE = StateStoreCustomSizeMetric(
+    "rocksdbSstFileSize", "RocksDB: size of all SST files")
+
+  val ALL_CUSTOM_METRICS = Seq(
+    CUSTOM_METRIC_SST_FILE_SIZE, CUSTOM_METRIC_GET_TIME, CUSTOM_METRIC_PUT_TIME,
+    CUSTOM_METRIC_WRITEBATCH_TIME, CUSTOM_METRIC_FLUSH_TIME, CUSTOM_METRIC_PAUSE_TIME,
+    CUSTOM_METRIC_CHECKPOINT_TIME, CUSTOM_METRIC_FILESYNC_TIME,
+    CUSTOM_METRIC_BYTES_COPIED, CUSTOM_METRIC_FILES_COPIED, CUSTOM_METRIC_FILES_REUSED,
+    CUSTOM_METRIC_ZIP_FILE_BYTES_UNCOMPRESSED
+  )
+}
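For reference, the VERSION 0 layout that StateEncoder documents above boils down to a byte array of [ VERSION (1 byte) | ROW (N bytes) ]. Below is a minimal standalone sketch of that layout; EncodingSketch and its members are illustrative names (not part of this patch), and plain JDK array copies stand in for the Platform.putByte/copyMemory calls used by the encoder:

// Minimal sketch of the VERSION 0 layout: [ VERSION (1 byte) | ROW (N bytes) ]
object EncodingSketch {
  val StateEncodingVersion: Byte = 0

  def encode(rowBytes: Array[Byte]): Array[Byte] = {
    val encoded = new Array[Byte](rowBytes.length + 1)
    encoded(0) = StateEncodingVersion // version byte at offset 0
    System.arraycopy(rowBytes, 0, encoded, 1, rowBytes.length) // row bytes from offset 1
    encoded
  }

  def decode(encoded: Array[Byte]): Array[Byte] = {
    require(encoded(0) == StateEncodingVersion, s"unexpected version byte: ${encoded(0)}")
    java.util.Arrays.copyOfRange(encoded, 1, encoded.length) // drop the version byte
  }
}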
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
new file mode 100644
index 0000000000..bf4bd3e105
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import java.io.File
+
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming._
+
+
+class RocksDBStateStoreIntegrationSuite extends StreamTest {
+  import testImplicits._
+
+  test("RocksDBStateStore") {
+    withTempDir { dir =>
+      val input = MemoryStream[Int]
+      val conf = Map(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+        classOf[RocksDBStateStoreProvider].getName)
+
+      testStream(input.toDF.groupBy().count(), outputMode = OutputMode.Update)(
+        StartStream(checkpointLocation = dir.getAbsolutePath, additionalConfs = conf),
+        AddData(input, 1, 2, 3),
+        CheckAnswer(3),
+        AssertOnQuery { q =>
+          // Verify that RocksDBStateStore is used by verifying that the state
+          // checkpoints are [version].zip files
+          val storeCheckpointDir = StateStoreId(
+            dir.getAbsolutePath + "/state", 0, 0).storeCheckpointLocation()
+          val storeCheckpointFile = storeCheckpointDir + "/1.zip"
+          new File(storeCheckpointFile).exists()
+        }
+      )
+    }
+  }
+}
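As the integration test above shows, the provider is selected purely through session configuration. A minimal sketch of enabling and tuning it, assuming an existing SparkSession named spark; the literal "rocksdb" conf prefix below is an assumption standing in for RocksDBConf.ROCKSDB_CONF_NAME_PREFIX, which the unit suite below references symbolically:

// Route streaming state to the RocksDB provider
spark.conf.set(
  "spark.sql.streaming.stateStore.providerClass",
  "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")
// Hypothetical tuning knobs mirroring the conf keys exercised in RocksDBStateStoreSuite
spark.conf.set("spark.sql.streaming.stateStore.rocksdb.compactOnCommit", "true")
spark.conf.set("spark.sql.streaming.stateStore.rocksdb.lockAcquireTimeoutMs", "10000")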
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
new file mode 100644
index 0000000000..b9cc844319
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import java.util.UUID
+
+import scala.util.Random
+
+import org.apache.hadoop.conf.Configuration
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.LocalSparkSession.withSparkSession
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.Platform
+import org.apache.spark.util.Utils
+
+class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvider]
+  with BeforeAndAfter {
+
+  import StateStoreTestsHelper._
+
+  test("version encoding") {
+    import RocksDBStateStoreProvider._
+
+    val provider = newStoreProvider()
+    val store = provider.getStore(0)
+    val keyRow = stringToRow("a")
+    val valueRow = intToRow(1)
+    store.put(keyRow, valueRow)
+    val iter = provider.rocksDB.iterator()
+    assert(iter.hasNext)
+    val kv = iter.next()
+
+    // Verify the version encoded in first byte of the key and value byte arrays
+    assert(Platform.getByte(kv.key, Platform.BYTE_ARRAY_OFFSET) === STATE_ENCODING_VERSION)
+    assert(Platform.getByte(kv.value, Platform.BYTE_ARRAY_OFFSET) === STATE_ENCODING_VERSION)
+  }
+
+  test("RocksDB confs are passed correctly from SparkSession to db instance") {
+    val sparkConf = new SparkConf().setMaster("local").setAppName(this.getClass.getSimpleName)
+    withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark =>
+      // Set the session confs that should be passed into RocksDB
+      val testConfs = Seq(
+        ("spark.sql.streaming.stateStore.providerClass",
+          classOf[RocksDBStateStoreProvider].getName),
+        (RocksDBConf.ROCKSDB_CONF_NAME_PREFIX + ".compactOnCommit", "true"),
+        (RocksDBConf.ROCKSDB_CONF_NAME_PREFIX + ".lockAcquireTimeoutMs", "10")
+      )
+      testConfs.foreach { case (k, v) => spark.conf.set(k, v) }
+
+      // Prepare test objects for running task on state store
+      val testRDD = spark.sparkContext.makeRDD[String](Seq("a"), 1)
+      val testSchema = StructType(Seq(StructField("key", StringType, true)))
+      val testStateInfo = StatefulOperatorStateInfo(
+        checkpointLocation = Utils.createTempDir().getAbsolutePath,
+        queryRunId = UUID.randomUUID, operatorId = 0, storeVersion = 0, numPartitions = 5)
+
+      // Create state store in a task and get the RocksDBConf from the instantiated
+      // RocksDB instance
+      val rocksDBConfInTask: RocksDBConf = testRDD.mapPartitionsWithStateStore[RocksDBConf](
+        spark.sqlContext, testStateInfo, testSchema, testSchema, None) {
+          (store: StateStore, _: Iterator[String]) =>
+            // Use reflection to get the RocksDB instance
+            val dbInstanceMethod =
+              store.getClass.getMethods.filter(_.getName.contains("dbInstance")).head
+            Iterator(dbInstanceMethod.invoke(store).asInstanceOf[RocksDB].conf)
+        }.collect().head
+
+      // Verify the confs are the same as those configured in the session conf
+      assert(rocksDBConfInTask.compactOnCommit == true)
+      assert(rocksDBConfInTask.lockAcquireTimeoutMs == 10L)
+    }
+  }
+
+  test("rocksdb file manager metrics exposed") {
+    import RocksDBStateStoreProvider._
+    def getCustomMetric(metrics: StateStoreMetrics, customMetric: StateStoreCustomMetric): Long = {
+      val metricPair = metrics.customMetrics.find(_._1.name == customMetric.name)
+      assert(metricPair.isDefined)
+      metricPair.get._2
+    }
+
+    val provider = newStoreProvider()
+    val store = provider.getStore(0)
+    // Verify state after updating
+    put(store, "a", 1)
+    assert(get(store, "a") === Some(1))
+    assert(store.commit() === 1)
+    assert(store.hasCommitted)
+    val storeMetrics = store.metrics
+    assert(storeMetrics.numKeys === 1)
+    assert(getCustomMetric(storeMetrics, CUSTOM_METRIC_FILES_COPIED) > 0L)
+    assert(getCustomMetric(storeMetrics, CUSTOM_METRIC_FILES_REUSED) == 0L)
+    assert(getCustomMetric(storeMetrics, CUSTOM_METRIC_BYTES_COPIED) > 0L)
+    assert(getCustomMetric(storeMetrics, CUSTOM_METRIC_ZIP_FILE_BYTES_UNCOMPRESSED) > 0L)
+  }
+
+  override def newStoreProvider(): RocksDBStateStoreProvider = {
+    newStoreProvider(StateStoreId(newDir(), Random.nextInt(), 0))
+  }
+
+  def newStoreProvider(storeId: StateStoreId): RocksDBStateStoreProvider = {
+    val keySchema = StructType(Seq(StructField("key", StringType, true)))
+    val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
+    val provider = new RocksDBStateStoreProvider()
+    provider.init(
+      storeId, keySchema, valueSchema, indexOrdinal = None, new StateStoreConf, new Configuration)
+    provider
+  }
+
+  override def getLatestData(storeProvider: RocksDBStateStoreProvider): Set[(String, Int)] = {
+    getData(storeProvider, version = -1)
+  }
+
+  override def getData(
+      provider: RocksDBStateStoreProvider,
+      version: Int = -1): Set[(String, Int)] = {
+    val reloadedProvider = newStoreProvider(provider.stateStoreId)
+    val versionToRead = if (version < 0) reloadedProvider.latestVersion else version
+    reloadedProvider.getStore(versionToRead).iterator().map(rowsToStringInt).toSet
+  }
+
+  override protected val keySchema = StructType(Seq(StructField("key", StringType, true)))
+  override protected val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
+
+  override def newStoreProvider(
+      minDeltasForSnapshot: Int,
+      numOfVersToRetainInMemory: Int): RocksDBStateStoreProvider = newStoreProvider()
+
+  override def getDefaultSQLConf(
+      minDeltasForSnapshot: Int,
+      numOfVersToRetainInMemory: Int): SQLConf = new SQLConf()
+}
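The metrics test above reads values out of StateStoreMetrics.customMetrics, a Map[StateStoreCustomMetric, Long]. A small sketch of a name-based lookup in the same spirit as the test's getCustomMetric helper (the helper name here is illustrative, not from the patch):

// Look up a custom metric value by metric name on a committed store's metrics
def customMetricByName(metrics: StateStoreMetrics, name: String): Option[Long] =
  metrics.customMetrics.collectFirst { case (metric, value) if metric.name == name => value }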
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index 4323725df9..2990860bc7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -49,6 +49,7 @@ import org.apache.spark.util.Utils
 class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
   with BeforeAndAfter {
   import StateStoreTestsHelper._
+  import StateStoreCoordinatorSuite._
 
   override val keySchema = StructType(Seq(StructField("key", StringType, true)))
   override val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
@@ -235,6 +236,162 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     assert(getSizeOfStateForCurrentVersion(store.metrics) > noDataMemoryUsed)
   }
 
+  test("maintenance") {
+    val conf = new SparkConf()
+      .setMaster("local")
+      .setAppName("test")
+      // Make sure that when SparkContext stops, the StateStore maintenance thread 'quickly'
+      // fails to talk to the StateStoreCoordinator and unloads all the StateStores
+      .set(RPC_NUM_RETRIES, 1)
+    val opId = 0
+    val dir1 = newDir()
+    val storeProviderId1 = StateStoreProviderId(StateStoreId(dir1, opId, 0), UUID.randomUUID)
+    val dir2 = newDir()
+    val storeProviderId2 = StateStoreProviderId(StateStoreId(dir2, opId, 1), UUID.randomUUID)
+    val sqlConf = getDefaultSQLConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get,
+      SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY.defaultValue.get)
+    sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2)
+    // Make the maintenance thread do snapshots and cleanups very fast
+    sqlConf.setConf(SQLConf.STREAMING_MAINTENANCE_INTERVAL, 10L)
+    val storeConf = StateStoreConf(sqlConf)
+    val hadoopConf = new Configuration()
+    val provider = newStoreProvider(storeProviderId1.storeId)
+
+    var latestStoreVersion = 0
+
+    def generateStoreVersions(): Unit = {
+      for (i <- 1 to 20) {
+        val store = StateStore.get(storeProviderId1, keySchema, valueSchema, None,
+          latestStoreVersion, storeConf, hadoopConf)
+        put(store, "a", i)
+        store.commit()
+        latestStoreVersion += 1
+      }
+    }
+
+    val timeoutDuration = 1.minute
+
+    quietly {
+      withSpark(new SparkContext(conf)) { sc =>
+        withCoordinatorRef(sc) { coordinatorRef =>
+          require(!StateStore.isMaintenanceRunning, "StateStore is unexpectedly running")
+
+          // Generate sufficient versions of store for snapshots
+          generateStoreVersions()
+
+          eventually(timeout(timeoutDuration)) {
+            // Store should have been reported to the coordinator
+            assert(coordinatorRef.getLocation(storeProviderId1).nonEmpty,
+              "active instance was not reported")
+
+            // Background maintenance should clean up and generate snapshots
+            assert(StateStore.isMaintenanceRunning, "Maintenance task is not running")
+
+            // Some snapshots should have been generated
+            val snapshotVersions = (1 to latestStoreVersion).filter { version =>
+              fileExists(provider, version, isSnapshot = true)
+            }
+            assert(snapshotVersions.nonEmpty, "no snapshot file found")
+          }
+
+          // Generate more versions such that there is another snapshot and
+          // the earliest delta file will be cleaned up
+          generateStoreVersions()
+
+          // Earliest delta file should get cleaned up
+          eventually(timeout(timeoutDuration)) {
+            assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted")
+          }
+
+          // If the driver decides to deactivate all stores related to a query run,
+          // then this instance should be unloaded
+          coordinatorRef.deactivateInstances(storeProviderId1.queryRunId)
+          eventually(timeout(timeoutDuration)) {
+            assert(!StateStore.isLoaded(storeProviderId1))
+          }
+
+          // Reload the store and verify
+          StateStore.get(storeProviderId1, keySchema, valueSchema, indexOrdinal = None,
+            latestStoreVersion, storeConf, hadoopConf)
+          assert(StateStore.isLoaded(storeProviderId1))
+
+          // If some other executor loads the store, then this instance should be unloaded
+          coordinatorRef
+            .reportActiveInstance(storeProviderId1, "other-host", "other-exec", Seq.empty)
+          eventually(timeout(timeoutDuration)) {
+            assert(!StateStore.isLoaded(storeProviderId1))
+          }
+
+          // Reload the store and verify
+          StateStore.get(storeProviderId1, keySchema, valueSchema, indexOrdinal = None,
+            latestStoreVersion, storeConf, hadoopConf)
+          assert(StateStore.isLoaded(storeProviderId1))
+
+          // If some other executor loads the store, and then this executor loads another store,
+          // this executor should unload the inactive instances immediately.
+          coordinatorRef
+            .reportActiveInstance(storeProviderId1, "other-host", "other-exec", Seq.empty)
+          StateStore.get(storeProviderId2, keySchema, valueSchema, indexOrdinal = None,
+            0, storeConf, hadoopConf)
+          assert(!StateStore.isLoaded(storeProviderId1))
+          assert(StateStore.isLoaded(storeProviderId2))
+        }
+      }
+
+      // Verify that the instance is unloaded if the SparkContext is stopped
+      eventually(timeout(timeoutDuration)) {
+        require(SparkEnv.get === null)
+        assert(!StateStore.isLoaded(storeProviderId1))
+        assert(!StateStore.isLoaded(storeProviderId2))
+        assert(!StateStore.isMaintenanceRunning)
+      }
+    }
+  }
+
+  test("snapshotting") {
+    val provider = newStoreProvider(minDeltasForSnapshot = 5, numOfVersToRetainInMemory = 2)
+
+    var currentVersion = 0
+
+    currentVersion = updateVersionTo(provider, currentVersion, 2)
+    require(getLatestData(provider) === Set("a" -> 2))
+    provider.doMaintenance() // should not generate snapshot files
+    assert(getLatestData(provider) === Set("a" -> 2))
+
+    for (i <- 1 to currentVersion) {
+      assert(fileExists(provider, i, isSnapshot = false)) // all delta files present
+      assert(!fileExists(provider, i, isSnapshot = true)) // no snapshot files present
+    }
+
+    // After version 6, snapshotting should generate one snapshot file
+    currentVersion = updateVersionTo(provider, currentVersion, 6)
+    require(getLatestData(provider) === Set("a" -> 6), "store not updated correctly")
+    provider.doMaintenance() // should generate snapshot files
+
+    val snapshotVersion = (0 to 6).find(version => fileExists(provider, version, isSnapshot = true))
+    assert(snapshotVersion.nonEmpty, "snapshot file not generated")
+    deleteFilesEarlierThanVersion(provider, snapshotVersion.get)
+    assert(
+      getData(provider, snapshotVersion.get) === Set("a" -> snapshotVersion.get),
+      "snapshotting messed up the data of the snapshotted version")
+    assert(
+      getLatestData(provider) === Set("a" -> 6),
+      "snapshotting messed up the data of the final version")
+
+    // After version 20, snapshotting should generate newer snapshot files
+    currentVersion = updateVersionTo(provider, currentVersion, 20)
+    require(getLatestData(provider) === Set("a" -> 20), "store not updated correctly")
+    provider.doMaintenance() // do snapshot
+
+    val latestSnapshotVersion = (0 to 20).filter(version =>
+      fileExists(provider, version, isSnapshot = true)).lastOption
+    assert(latestSnapshotVersion.nonEmpty, "no snapshot file found")
+    assert(latestSnapshotVersion.get > snapshotVersion.get, "newer snapshot not generated")
+
+    deleteFilesEarlierThanVersion(provider, latestSnapshotVersion.get)
+    assert(getLatestData(provider) === Set("a" -> 20), "snapshotting messed up the data")
+  }
+
   testQuietly("SPARK-18342: commit fails when rename fails") {
     import RenameReturnsFalseFileSystem._
     val dir = scheme + "://" + newDir()
@@ -582,7 +739,6 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
 
 abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
   extends StateStoreCodecsTest with PrivateMethodTester {
   import StateStoreTestsHelper._
-  import StateStoreCoordinatorSuite._
 
   type MapType = mutable.HashMap[UnsafeRow, UnsafeRow]
   type ProviderMapType = java.util.concurrent.ConcurrentHashMap[UnsafeRow, UnsafeRow]
@@ -761,118 +917,6 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
     assert(rowsToSet(finalStore.iterator()) === Set(key -> 2))
   }
 
-  test("maintenance") {
-    val conf = new SparkConf()
-      .setMaster("local")
-      .setAppName("test")
-      // Make sure that when SparkContext stops, the StateStore maintenance thread 'quickly'
-      // fails to talk to the StateStoreCoordinator and unloads all the StateStores
-      .set(RPC_NUM_RETRIES, 1)
-    val opId = 0
-    val dir1 = newDir()
-    val storeProviderId1 = StateStoreProviderId(StateStoreId(dir1, opId, 0), UUID.randomUUID)
-    val dir2 = newDir()
-    val storeProviderId2 = StateStoreProviderId(StateStoreId(dir2, opId, 1), UUID.randomUUID)
-    val sqlConf = getDefaultSQLConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get,
-      SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY.defaultValue.get)
-    sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2)
-    // Make maintenance thread do snapshots and cleanups very fast
-    sqlConf.setConf(SQLConf.STREAMING_MAINTENANCE_INTERVAL, 10L)
-    val storeConf = StateStoreConf(sqlConf)
-    val hadoopConf = new Configuration()
-    val provider = newStoreProvider(storeProviderId1.storeId)
-
-    var latestStoreVersion = 0
-
-    def generateStoreVersions(): Unit = {
-      for (i <- 1 to 20) {
-        val store = StateStore.get(storeProviderId1, keySchema, valueSchema, None,
-          latestStoreVersion, storeConf, hadoopConf)
-        put(store, "a", i)
-        store.commit()
-        latestStoreVersion += 1
-      }
-    }
-
-    val timeoutDuration = 1.minute
-
-    quietly {
-      withSpark(new SparkContext(conf)) { sc =>
-        withCoordinatorRef(sc) { coordinatorRef =>
-          require(!StateStore.isMaintenanceRunning, "StateStore is unexpectedly running")
-
-          // Generate sufficient versions of store for snapshots
-          generateStoreVersions()
-
-          eventually(timeout(timeoutDuration)) {
-            // Store should have been reported to the coordinator
-            assert(coordinatorRef.getLocation(storeProviderId1).nonEmpty,
-              "active instance was not reported")
-
-            // Background maintenance should clean up and generate snapshots
-            assert(StateStore.isMaintenanceRunning, "Maintenance task is not running")
-
-            // Some snapshots should have been generated
-            val snapshotVersions = (1 to latestStoreVersion).filter { version =>
-              fileExists(provider, version, isSnapshot = true)
-            }
-            assert(snapshotVersions.nonEmpty, "no snapshot file found")
-          }
-
-          // Generate more versions such that there is another snapshot and
-          // the earliest delta file will be cleaned up
-          generateStoreVersions()
-
-          // Earliest delta file should get cleaned up
-          eventually(timeout(timeoutDuration)) {
-            assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted")
-          }
-
-          // If driver decides to deactivate all stores related to a query run,
-          // then this instance should be unloaded
-          coordinatorRef.deactivateInstances(storeProviderId1.queryRunId)
-          eventually(timeout(timeoutDuration)) {
-            assert(!StateStore.isLoaded(storeProviderId1))
-          }
-
-          // Reload the store and verify
-          StateStore.get(storeProviderId1, keySchema, valueSchema, indexOrdinal = None,
-            latestStoreVersion, storeConf, hadoopConf)
-          assert(StateStore.isLoaded(storeProviderId1))
-
-          // If some other executor loads the store, then this instance should be unloaded
-          coordinatorRef
-            .reportActiveInstance(storeProviderId1, "other-host", "other-exec", Seq.empty)
-          eventually(timeout(timeoutDuration)) {
-            assert(!StateStore.isLoaded(storeProviderId1))
-          }
-
-          // Reload the store and verify
-          StateStore.get(storeProviderId1, keySchema, valueSchema, indexOrdinal = None,
-            latestStoreVersion, storeConf, hadoopConf)
-          assert(StateStore.isLoaded(storeProviderId1))
-
-          // If some other executor loads the store, and when this executor loads other store,
-          // then this executor should unload inactive instances immediately.
-          coordinatorRef
-            .reportActiveInstance(storeProviderId1, "other-host", "other-exec", Seq.empty)
-          StateStore.get(storeProviderId2, keySchema, valueSchema, indexOrdinal = None,
-            0, storeConf, hadoopConf)
-          assert(!StateStore.isLoaded(storeProviderId1))
-          assert(StateStore.isLoaded(storeProviderId2))
-        }
-      }
-
-      // Verify if instance is unloaded if SparkContext is stopped
-      eventually(timeout(timeoutDuration)) {
-        require(SparkEnv.get === null)
-        assert(!StateStore.isLoaded(storeProviderId1))
-        assert(!StateStore.isLoaded(storeProviderId2))
-        assert(!StateStore.isMaintenanceRunning)
-      }
-    }
-  }
-
   test("StateStore.get") {
     quietly {
       val dir = newDir()
@@ -925,50 +969,6 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
     }
   }
 
-  test("snapshotting") {
-    val provider = newStoreProvider(minDeltasForSnapshot = 5, numOfVersToRetainInMemory = 2)
-
-    var currentVersion = 0
-
-    currentVersion = updateVersionTo(provider, currentVersion, 2)
-    require(getLatestData(provider) === Set("a" -> 2))
-    provider.doMaintenance() // should not generate snapshot files
-    assert(getLatestData(provider) === Set("a" -> 2))
-
-    for (i <- 1 to currentVersion) {
-      assert(fileExists(provider, i, isSnapshot = false)) // all delta files present
-      assert(!fileExists(provider, i, isSnapshot = true)) // no snapshot files present
-    }
-
-    // After version 6, snapshotting should generate one snapshot file
-    currentVersion = updateVersionTo(provider, currentVersion, 6)
-    require(getLatestData(provider) === Set("a" -> 6), "store not updated correctly")
-    provider.doMaintenance() // should generate snapshot files
-
-    val snapshotVersion = (0 to 6).find(version => fileExists(provider, version, isSnapshot = true))
-    assert(snapshotVersion.nonEmpty, "snapshot file not generated")
-    deleteFilesEarlierThanVersion(provider, snapshotVersion.get)
-    assert(
-      getData(provider, snapshotVersion.get) === Set("a" -> snapshotVersion.get),
-      "snapshotting messed up the data of the snapshotted version")
-    assert(
-      getLatestData(provider) === Set("a" -> 6),
-      "snapshotting messed up the data of the final version")
-
-    // After version 20, snapshotting should generate newer snapshot files
-    currentVersion = updateVersionTo(provider, currentVersion, 20)
-    require(getLatestData(provider) === Set("a" -> 20), "store not updated correctly")
-    provider.doMaintenance() // do snapshot
-
-    val latestSnapshotVersion = (0 to 20).filter(version =>
-      fileExists(provider, version, isSnapshot = true)).lastOption
-    assert(latestSnapshotVersion.nonEmpty, "no snapshot file found")
-    assert(latestSnapshotVersion.get > snapshotVersion.get, "newer snapshot not generated")
-
-    deleteFilesEarlierThanVersion(provider, latestSnapshotVersion.get)
-    assert(getLatestData(provider) === Set("a" -> 20), "snapshotting messed up the data")
-  }
-
   test("reports memory usage") {
     val provider = newStoreProvider()
     val store = provider.getStore(0)