diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala index 24c89a7718..c93d0f0351 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala @@ -36,23 +36,34 @@ import org.apache.spark.util.Utils class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvider] with BeforeAndAfter { + before { + StateStore.stop() + require(!StateStore.isMaintenanceRunning) + } + + after { + StateStore.stop() + require(!StateStore.isMaintenanceRunning) + } + import StateStoreTestsHelper._ test("version encoding") { import RocksDBStateStoreProvider._ - val provider = newStoreProvider() - val store = provider.getStore(0) - val keyRow = dataToKeyRow("a", 0) - val valueRow = dataToValueRow(1) - store.put(keyRow, valueRow) - val iter = provider.rocksDB.iterator() - assert(iter.hasNext) - val kv = iter.next() + tryWithProviderResource(newStoreProvider()) { provider => + val store = provider.getStore(0) + val keyRow = dataToKeyRow("a", 0) + val valueRow = dataToValueRow(1) + store.put(keyRow, valueRow) + val iter = provider.rocksDB.iterator() + assert(iter.hasNext) + val kv = iter.next() - // Verify the version encoded in first byte of the key and value byte arrays - assert(Platform.getByte(kv.key, Platform.BYTE_ARRAY_OFFSET) === STATE_ENCODING_VERSION) - assert(Platform.getByte(kv.value, Platform.BYTE_ARRAY_OFFSET) === STATE_ENCODING_VERSION) + // Verify the version encoded in first byte of the key and value byte arrays + assert(Platform.getByte(kv.key, Platform.BYTE_ARRAY_OFFSET) === STATE_ENCODING_VERSION) + assert(Platform.getByte(kv.value, Platform.BYTE_ARRAY_OFFSET) === STATE_ENCODING_VERSION) + } } test("RocksDB confs are passed correctly from SparkSession to db instance") { @@ -100,19 +111,20 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid metricPair.get._2 } - val provider = newStoreProvider() - val store = provider.getStore(0) - // Verify state after updating - put(store, "a", 0, 1) - assert(get(store, "a", 0) === Some(1)) - assert(store.commit() === 1) - assert(store.hasCommitted) - val storeMetrics = store.metrics - assert(storeMetrics.numKeys === 1) - assert(getCustomMetric(storeMetrics, CUSTOM_METRIC_FILES_COPIED) > 0L) - assert(getCustomMetric(storeMetrics, CUSTOM_METRIC_FILES_REUSED) == 0L) - assert(getCustomMetric(storeMetrics, CUSTOM_METRIC_BYTES_COPIED) > 0L) - assert(getCustomMetric(storeMetrics, CUSTOM_METRIC_ZIP_FILE_BYTES_UNCOMPRESSED) > 0L) + tryWithProviderResource(newStoreProvider()) { provider => + val store = provider.getStore(0) + // Verify state after updating + put(store, "a", 0, 1) + assert(get(store, "a", 0) === Some(1)) + assert(store.commit() === 1) + assert(store.hasCommitted) + val storeMetrics = store.metrics + assert(storeMetrics.numKeys === 1) + assert(getCustomMetric(storeMetrics, CUSTOM_METRIC_FILES_COPIED) > 0L) + assert(getCustomMetric(storeMetrics, CUSTOM_METRIC_FILES_REUSED) == 0L) + assert(getCustomMetric(storeMetrics, CUSTOM_METRIC_BYTES_COPIED) > 0L) + assert(getCustomMetric(storeMetrics, CUSTOM_METRIC_ZIP_FILE_BYTES_UNCOMPRESSED) > 0L) + } } override def newStoreProvider(): RocksDBStateStoreProvider = { @@ -145,9 +157,10 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid override def getData( provider: RocksDBStateStoreProvider, version: Int = -1): Set[((String, Int), Int)] = { - val reloadedProvider = newStoreProvider(provider.stateStoreId) - val versionToRead = if (version < 0) reloadedProvider.latestVersion else version - reloadedProvider.getStore(versionToRead).iterator().map(rowPairToDataPair).toSet + tryWithProviderResource(newStoreProvider(provider.stateStoreId)) { reloadedProvider => + val versionToRead = if (version < 0) reloadedProvider.latestVersion else version + reloadedProvider.getStore(versionToRead).iterator().map(rowPairToDataPair).toSet + } } override def newStoreProvider( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 8a6f66da72..601b62bd81 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -62,112 +62,118 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] } test("retaining only two latest versions when MAX_BATCHES_TO_RETAIN_IN_MEMORY set to 2") { - val provider = newStoreProvider(minDeltasForSnapshot = 10, numOfVersToRetainInMemory = 2) + tryWithProviderResource( + newStoreProvider(minDeltasForSnapshot = 10, numOfVersToRetainInMemory = 2)) { provider => - var currentVersion = 0 + var currentVersion = 0 - // commit the ver 1 : cache will have one element - currentVersion = incrementVersion(provider, currentVersion) - assert(getLatestData(provider) === Set(("a", 0) -> 1)) - var loadedMaps = provider.getLoadedMaps() - checkLoadedVersions(loadedMaps, count = 1, earliestKey = 1, latestKey = 1) - checkVersion(loadedMaps, 1, Map(("a", 0) -> 1)) + // commit the ver 1 : cache will have one element + currentVersion = incrementVersion(provider, currentVersion) + assert(getLatestData(provider) === Set(("a", 0) -> 1)) + var loadedMaps = provider.getLoadedMaps() + checkLoadedVersions(loadedMaps, count = 1, earliestKey = 1, latestKey = 1) + checkVersion(loadedMaps, 1, Map(("a", 0) -> 1)) - // commit the ver 2 : cache will have two elements - currentVersion = incrementVersion(provider, currentVersion) - assert(getLatestData(provider) === Set(("a", 0) -> 2)) - loadedMaps = provider.getLoadedMaps() - checkLoadedVersions(loadedMaps, count = 2, earliestKey = 2, latestKey = 1) - checkVersion(loadedMaps, 2, Map(("a", 0) -> 2)) - checkVersion(loadedMaps, 1, Map(("a", 0) -> 1)) + // commit the ver 2 : cache will have two elements + currentVersion = incrementVersion(provider, currentVersion) + assert(getLatestData(provider) === Set(("a", 0) -> 2)) + loadedMaps = provider.getLoadedMaps() + checkLoadedVersions(loadedMaps, count = 2, earliestKey = 2, latestKey = 1) + checkVersion(loadedMaps, 2, Map(("a", 0) -> 2)) + checkVersion(loadedMaps, 1, Map(("a", 0) -> 1)) - // commit the ver 3 : cache has already two elements and adding ver 3 incurs exceeding cache, - // and ver 3 will be added but ver 1 will be evicted - currentVersion = incrementVersion(provider, currentVersion) - assert(getLatestData(provider) === Set(("a", 0) -> 3)) - loadedMaps = provider.getLoadedMaps() - checkLoadedVersions(loadedMaps, count = 2, earliestKey = 3, latestKey = 2) - checkVersion(loadedMaps, 3, Map(("a", 0) -> 3)) - checkVersion(loadedMaps, 2, Map(("a", 0) -> 2)) + // commit the ver 3 : cache has already two elements and adding ver 3 incurs exceeding cache, + // and ver 3 will be added but ver 1 will be evicted + currentVersion = incrementVersion(provider, currentVersion) + assert(getLatestData(provider) === Set(("a", 0) -> 3)) + loadedMaps = provider.getLoadedMaps() + checkLoadedVersions(loadedMaps, count = 2, earliestKey = 3, latestKey = 2) + checkVersion(loadedMaps, 3, Map(("a", 0) -> 3)) + checkVersion(loadedMaps, 2, Map(("a", 0) -> 2)) + } } test("failure after committing with MAX_BATCHES_TO_RETAIN_IN_MEMORY set to 1") { - val provider = newStoreProvider(opId = Random.nextInt, partition = 0, - numOfVersToRetainInMemory = 1) + tryWithProviderResource(newStoreProvider(opId = Random.nextInt, partition = 0, + numOfVersToRetainInMemory = 1)) { provider => - var currentVersion = 0 + var currentVersion = 0 - // commit the ver 1 : cache will have one element - currentVersion = incrementVersion(provider, currentVersion) - assert(getLatestData(provider) === Set(("a", 0) -> 1)) - var loadedMaps = provider.getLoadedMaps() - checkLoadedVersions(loadedMaps, count = 1, earliestKey = 1, latestKey = 1) - checkVersion(loadedMaps, 1, Map(("a", 0) -> 1)) + // commit the ver 1 : cache will have one element + currentVersion = incrementVersion(provider, currentVersion) + assert(getLatestData(provider) === Set(("a", 0) -> 1)) + var loadedMaps = provider.getLoadedMaps() + checkLoadedVersions(loadedMaps, count = 1, earliestKey = 1, latestKey = 1) + checkVersion(loadedMaps, 1, Map(("a", 0) -> 1)) - // commit the ver 2 : cache has already one elements and adding ver 2 incurs exceeding cache, - // and ver 2 will be added but ver 1 will be evicted - // this fact ensures cache miss will occur when this partition succeeds commit - // but there's a failure afterwards so have to reprocess previous batch - currentVersion = incrementVersion(provider, currentVersion) - assert(getLatestData(provider) === Set(("a", 0) -> 2)) - loadedMaps = provider.getLoadedMaps() - checkLoadedVersions(loadedMaps, count = 1, earliestKey = 2, latestKey = 2) - checkVersion(loadedMaps, 2, Map(("a", 0) -> 2)) + // commit the ver 2 : cache has already one elements and adding ver 2 incurs exceeding cache, + // and ver 2 will be added but ver 1 will be evicted + // this fact ensures cache miss will occur when this partition succeeds commit + // but there's a failure afterwards so have to reprocess previous batch + currentVersion = incrementVersion(provider, currentVersion) + assert(getLatestData(provider) === Set(("a", 0) -> 2)) + loadedMaps = provider.getLoadedMaps() + checkLoadedVersions(loadedMaps, count = 1, earliestKey = 2, latestKey = 2) + checkVersion(loadedMaps, 2, Map(("a", 0) -> 2)) - // suppose there has been failure after committing, and it decided to reprocess previous batch - currentVersion = 1 + // suppose there has been failure after committing, and it decided to reprocess previous batch + currentVersion = 1 - // committing to existing version which is committed partially but abandoned globally - val store = provider.getStore(currentVersion) - // negative value to represent reprocessing - put(store, "a", 0, -2) - store.commit() - currentVersion += 1 + // committing to existing version which is committed partially but abandoned globally + val store = provider.getStore(currentVersion) + // negative value to represent reprocessing + put(store, "a", 0, -2) + store.commit() + currentVersion += 1 - // make sure newly committed version is reflected to the cache (overwritten) - assert(getLatestData(provider) === Set(("a", 0) -> -2)) - loadedMaps = provider.getLoadedMaps() - checkLoadedVersions(loadedMaps, count = 1, earliestKey = 2, latestKey = 2) - checkVersion(loadedMaps, 2, Map(("a", 0) -> -2)) + // make sure newly committed version is reflected to the cache (overwritten) + assert(getLatestData(provider) === Set(("a", 0) -> -2)) + loadedMaps = provider.getLoadedMaps() + checkLoadedVersions(loadedMaps, count = 1, earliestKey = 2, latestKey = 2) + checkVersion(loadedMaps, 2, Map(("a", 0) -> -2)) + } } test("no cache data with MAX_BATCHES_TO_RETAIN_IN_MEMORY set to 0") { - val provider = newStoreProvider(opId = Random.nextInt, partition = 0, - numOfVersToRetainInMemory = 0) + tryWithProviderResource(newStoreProvider(opId = Random.nextInt, partition = 0, + numOfVersToRetainInMemory = 0)) { provider => - var currentVersion = 0 + var currentVersion = 0 - // commit the ver 1 : never cached - currentVersion = incrementVersion(provider, currentVersion) - assert(getLatestData(provider) === Set(("a", 0) -> 1)) - var loadedMaps = provider.getLoadedMaps() - assert(loadedMaps.size() === 0) + // commit the ver 1 : never cached + currentVersion = incrementVersion(provider, currentVersion) + assert(getLatestData(provider) === Set(("a", 0) -> 1)) + var loadedMaps = provider.getLoadedMaps() + assert(loadedMaps.size() === 0) - // commit the ver 2 : never cached - currentVersion = incrementVersion(provider, currentVersion) - assert(getLatestData(provider) === Set(("a", 0) -> 2)) - loadedMaps = provider.getLoadedMaps() - assert(loadedMaps.size() === 0) + // commit the ver 2 : never cached + currentVersion = incrementVersion(provider, currentVersion) + assert(getLatestData(provider) === Set(("a", 0) -> 2)) + loadedMaps = provider.getLoadedMaps() + assert(loadedMaps.size() === 0) + } } test("cleaning") { - val provider = newStoreProvider(opId = Random.nextInt, partition = 0, minDeltasForSnapshot = 5) + tryWithProviderResource(newStoreProvider(opId = Random.nextInt, partition = 0, + minDeltasForSnapshot = 5)) { provider => - for (i <- 1 to 20) { - val store = provider.getStore(i - 1) - put(store, "a", 0, i) - store.commit() - provider.doMaintenance() // do cleanup + for (i <- 1 to 20) { + val store = provider.getStore(i - 1) + put(store, "a", 0, i) + store.commit() + provider.doMaintenance() // do cleanup + } + require( + rowPairsToDataSet(provider.latestIterator()) === Set(("a", 0) -> 20), + "store not updated correctly") + + assert(!fileExists(provider, version = 1, isSnapshot = false)) // first file should be deleted + + // last couple of versions should be retrievable + assert(getData(provider, 20) === Set(("a", 0) -> 20)) + assert(getData(provider, 19) === Set(("a", 0) -> 19)) } - require( - rowPairsToDataSet(provider.latestIterator()) === Set(("a", 0) -> 20), - "store not updated correctly") - - assert(!fileExists(provider, version = 1, isSnapshot = false)) // first file should be deleted - - // last couple of versions should be retrievable - assert(getData(provider, 20) === Set(("a", 0) -> 20)) - assert(getData(provider, 19) === Set(("a", 0) -> 19)) } testQuietly("SPARK-19677: Committing a delta file atop an existing one should not fail on HDFS") { @@ -175,45 +181,51 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] conf.set("fs.fake.impl", classOf[RenameLikeHDFSFileSystem].getName) conf.set("fs.defaultFS", "fake:///") - val provider = newStoreProvider(opId = Random.nextInt, partition = 0, hadoopConf = conf) - provider.getStore(0).commit() - provider.getStore(0).commit() + tryWithProviderResource( + newStoreProvider(opId = Random.nextInt, partition = 0, hadoopConf = conf)) { provider => - // Verify we don't leak temp files - val tempFiles = FileUtils.listFiles(new File(provider.stateStoreId.checkpointRootLocation), - null, true).asScala.filter(_.getName.startsWith("temp-")) - assert(tempFiles.isEmpty) + provider.getStore(0).commit() + provider.getStore(0).commit() + + // Verify we don't leak temp files + val tempFiles = FileUtils.listFiles(new File(provider.stateStoreId.checkpointRootLocation), + null, true).asScala.filter(_.getName.startsWith("temp-")) + assert(tempFiles.isEmpty) + } } test("corrupted file handling") { - val provider = newStoreProvider(opId = Random.nextInt, partition = 0, minDeltasForSnapshot = 5) - for (i <- 1 to 6) { - val store = provider.getStore(i - 1) - put(store, "a", 0, i) - store.commit() - provider.doMaintenance() // do cleanup - } - val snapshotVersion = (0 to 10).find( version => - fileExists(provider, version, isSnapshot = true)).getOrElse(fail("snapshot file not found")) + tryWithProviderResource(newStoreProvider(opId = Random.nextInt, partition = 0, + minDeltasForSnapshot = 5)) { provider => - // Corrupt snapshot file and verify that it throws error - assert(getData(provider, snapshotVersion) === Set(("a", 0) -> snapshotVersion)) - corruptFile(provider, snapshotVersion, isSnapshot = true) - intercept[Exception] { - getData(provider, snapshotVersion) - } + for (i <- 1 to 6) { + val store = provider.getStore(i - 1) + put(store, "a", 0, i) + store.commit() + provider.doMaintenance() // do cleanup + } + val snapshotVersion = (0 to 10).find( version => + fileExists(provider, version, isSnapshot = true)).getOrElse(fail("snapshot file not found")) - // Corrupt delta file and verify that it throws error - assert(getData(provider, snapshotVersion - 1) === Set(("a", 0) -> (snapshotVersion - 1))) - corruptFile(provider, snapshotVersion - 1, isSnapshot = false) - intercept[Exception] { - getData(provider, snapshotVersion - 1) - } + // Corrupt snapshot file and verify that it throws error + assert(getData(provider, snapshotVersion) === Set(("a", 0) -> snapshotVersion)) + corruptFile(provider, snapshotVersion, isSnapshot = true) + intercept[Exception] { + getData(provider, snapshotVersion) + } - // Delete delta file and verify that it throws error - deleteFilesEarlierThanVersion(provider, snapshotVersion) - intercept[Exception] { - getData(provider, snapshotVersion - 1) + // Corrupt delta file and verify that it throws error + assert(getData(provider, snapshotVersion - 1) === Set(("a", 0) -> (snapshotVersion - 1))) + corruptFile(provider, snapshotVersion - 1, isSnapshot = false) + intercept[Exception] { + getData(provider, snapshotVersion - 1) + } + + // Delete delta file and verify that it throws error + deleteFilesEarlierThanVersion(provider, snapshotVersion) + intercept[Exception] { + getData(provider, snapshotVersion - 1) + } } } @@ -224,13 +236,14 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] metricPair.get._2 } - val provider = newStoreProvider() - val store = provider.getStore(0) - val noDataMemoryUsed = getSizeOfStateForCurrentVersion(store.metrics) + tryWithProviderResource(newStoreProvider()) { provider => + val store = provider.getStore(0) + val noDataMemoryUsed = getSizeOfStateForCurrentVersion(store.metrics) - put(store, "a", 0, 1) - store.commit() - assert(getSizeOfStateForCurrentVersion(store.metrics) > noDataMemoryUsed) + put(store, "a", 0, 1) + store.commit() + assert(getSizeOfStateForCurrentVersion(store.metrics) > noDataMemoryUsed) + } } test("maintenance") { @@ -252,7 +265,6 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] sqlConf.setConf(SQLConf.STREAMING_MAINTENANCE_INTERVAL, 10L) val storeConf = StateStoreConf(sqlConf) val hadoopConf = new Configuration() - val provider = newStoreProvider(storeProviderId1.storeId) var latestStoreVersion = 0 @@ -285,10 +297,12 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] assert(StateStore.isMaintenanceRunning, "Maintenance task is not running") // Some snapshots should have been generated - val snapshotVersions = (1 to latestStoreVersion).filter { version => - fileExists(provider, version, isSnapshot = true) + tryWithProviderResource(newStoreProvider(storeProviderId1.storeId)) { provider => + val snapshotVersions = (1 to latestStoreVersion).filter { version => + fileExists(provider, version, isSnapshot = true) + } + assert(snapshotVersions.nonEmpty, "no snapshot file found") } - assert(snapshotVersions.nonEmpty, "no snapshot file found") } // Generate more versions such that there is another snapshot and @@ -296,8 +310,10 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] generateStoreVersions() // Earliest delta file should get cleaned up - eventually(timeout(timeoutDuration)) { - assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted") + tryWithProviderResource(newStoreProvider(storeProviderId1.storeId)) { provider => + eventually(timeout(timeoutDuration)) { + assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted") + } } // If driver decides to deactivate all stores related to a query run, @@ -346,47 +362,51 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] } test("snapshotting") { - val provider = newStoreProvider(minDeltasForSnapshot = 5, numOfVersToRetainInMemory = 2) + tryWithProviderResource( + newStoreProvider(minDeltasForSnapshot = 5, numOfVersToRetainInMemory = 2)) { provider => - var currentVersion = 0 + var currentVersion = 0 - currentVersion = updateVersionTo(provider, currentVersion, 2) - require(getLatestData(provider) === Set(("a", 0) -> 2)) - provider.doMaintenance() // should not generate snapshot files - assert(getLatestData(provider) === Set(("a", 0) -> 2)) + currentVersion = updateVersionTo(provider, currentVersion, 2) + require(getLatestData(provider) === Set(("a", 0) -> 2)) + provider.doMaintenance() // should not generate snapshot files + assert(getLatestData(provider) === Set(("a", 0) -> 2)) - for (i <- 1 to currentVersion) { - assert(fileExists(provider, i, isSnapshot = false)) // all delta files present - assert(!fileExists(provider, i, isSnapshot = true)) // no snapshot files present + for (i <- 1 to currentVersion) { + assert(fileExists(provider, i, isSnapshot = false)) // all delta files present + assert(!fileExists(provider, i, isSnapshot = true)) // no snapshot files present + } + + // After version 6, snapshotting should generate one snapshot file + currentVersion = updateVersionTo(provider, currentVersion, 6) + require(getLatestData(provider) === Set(("a", 0) -> 6), "store not updated correctly") + provider.doMaintenance() // should generate snapshot files + + val snapshotVersion = (0 to 6).find { version => + fileExists(provider, version, isSnapshot = true) + } + assert(snapshotVersion.nonEmpty, "snapshot file not generated") + deleteFilesEarlierThanVersion(provider, snapshotVersion.get) + assert( + getData(provider, snapshotVersion.get) === Set(("a", 0) -> snapshotVersion.get), + "snapshotting messed up the data of the snapshotted version") + assert( + getLatestData(provider) === Set(("a", 0) -> 6), + "snapshotting messed up the data of the final version") + + // After version 20, snapshotting should generate newer snapshot files + currentVersion = updateVersionTo(provider, currentVersion, 20) + require(getLatestData(provider) === Set(("a", 0) -> 20), "store not updated correctly") + provider.doMaintenance() // do snapshot + + val latestSnapshotVersion = (0 to 20).filter(version => + fileExists(provider, version, isSnapshot = true)).lastOption + assert(latestSnapshotVersion.nonEmpty, "no snapshot file found") + assert(latestSnapshotVersion.get > snapshotVersion.get, "newer snapshot not generated") + + deleteFilesEarlierThanVersion(provider, latestSnapshotVersion.get) + assert(getLatestData(provider) === Set(("a", 0) -> 20), "snapshotting messed up the data") } - - // After version 6, snapshotting should generate one snapshot file - currentVersion = updateVersionTo(provider, currentVersion, 6) - require(getLatestData(provider) === Set(("a", 0) -> 6), "store not updated correctly") - provider.doMaintenance() // should generate snapshot files - - val snapshotVersion = (0 to 6).find(version => fileExists(provider, version, isSnapshot = true)) - assert(snapshotVersion.nonEmpty, "snapshot file not generated") - deleteFilesEarlierThanVersion(provider, snapshotVersion.get) - assert( - getData(provider, snapshotVersion.get) === Set(("a", 0) -> snapshotVersion.get), - "snapshotting messed up the data of the snapshotted version") - assert( - getLatestData(provider) === Set(("a", 0) -> 6), - "snapshotting messed up the data of the final version") - - // After version 20, snapshotting should generate newer snapshot files - currentVersion = updateVersionTo(provider, currentVersion, 20) - require(getLatestData(provider) === Set(("a", 0) -> 20), "store not updated correctly") - provider.doMaintenance() // do snapshot - - val latestSnapshotVersion = (0 to 20).filter(version => - fileExists(provider, version, isSnapshot = true)).lastOption - assert(latestSnapshotVersion.nonEmpty, "no snapshot file found") - assert(latestSnapshotVersion.get > snapshotVersion.get, "newer snapshot not generated") - - deleteFilesEarlierThanVersion(provider, latestSnapshotVersion.get) - assert(getLatestData(provider) === Set(("a", 0) -> 20), "snapshotting messed up the data") } testQuietly("SPARK-18342: commit fails when rename fails") { @@ -394,12 +414,14 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] val dir = scheme + "://" + newDir() val conf = new Configuration() conf.set(s"fs.$scheme.impl", classOf[RenameReturnsFalseFileSystem].getName) - val provider = newStoreProvider( - opId = Random.nextInt, partition = 0, dir = dir, hadoopConf = conf) - val store = provider.getStore(0) - put(store, "a", 0, 0) - val e = intercept[IllegalStateException](store.commit()) - assert(e.getCause.getMessage.contains("Failed to rename")) + tryWithProviderResource(newStoreProvider( + opId = Random.nextInt, partition = 0, dir = dir, hadoopConf = conf)) { provider => + + val store = provider.getStore(0) + put(store, "a", 0, 0) + val e = intercept[IllegalStateException](store.commit()) + assert(e.getCause.getMessage.contains("Failed to rename")) + } } test("SPARK-18416: do not create temp delta file until the store is updated") { @@ -528,33 +550,34 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] classOf[CreateAtomicTestManager].getName) val remoteDir = Utils.createTempDir().getAbsolutePath - val provider = newStoreProvider( - opId = Random.nextInt, partition = 0, dir = remoteDir, hadoopConf = hadoopConf) + tryWithProviderResource(newStoreProvider(opId = Random.nextInt, partition = 0, + dir = remoteDir, hadoopConf = hadoopConf)) { provider => - // Disable failure of output stream and generate versions - CreateAtomicTestManager.shouldFailInCreateAtomic = false - for (version <- 1 to 10) { - val store = provider.getStore(version - 1) - put(store, version.toString, 0, version) // update "1" -> 1, "2" -> 2, ... - store.commit() + // Disable failure of output stream and generate versions + CreateAtomicTestManager.shouldFailInCreateAtomic = false + for (version <- 1 to 10) { + val store = provider.getStore(version - 1) + put(store, version.toString, 0, version) // update "1" -> 1, "2" -> 2, ... + store.commit() + } + val version10Data = (1L to 10).map(_.toString).map(x => x -> x).toSet + + CreateAtomicTestManager.cancelCalledInCreateAtomic = false + val store = provider.getStore(10) + // Fail commit for next version and verify that reloading resets the files + CreateAtomicTestManager.shouldFailInCreateAtomic = true + put(store, "11", 0, 11) + val e = intercept[IllegalStateException] { quietly { store.commit() } } + assert(e.getCause.isInstanceOf[IOException]) + CreateAtomicTestManager.shouldFailInCreateAtomic = false + + // Abort commit for next version and verify that reloading resets the files + CreateAtomicTestManager.cancelCalledInCreateAtomic = false + val store2 = provider.getStore(10) + put(store2, "11", 0, 11) + store2.abort() + assert(CreateAtomicTestManager.cancelCalledInCreateAtomic) } - val version10Data = (1L to 10).map(_.toString).map(x => x -> x).toSet - - CreateAtomicTestManager.cancelCalledInCreateAtomic = false - val store = provider.getStore(10) - // Fail commit for next version and verify that reloading resets the files - CreateAtomicTestManager.shouldFailInCreateAtomic = true - put(store, "11", 0, 11) - val e = intercept[IllegalStateException] { quietly { store.commit() } } - assert(e.getCause.isInstanceOf[IOException]) - CreateAtomicTestManager.shouldFailInCreateAtomic = false - - // Abort commit for next version and verify that reloading resets the files - CreateAtomicTestManager.cancelCalledInCreateAtomic = false - val store2 = provider.getStore(10) - put(store2, "11", 0, 11) - store2.abort() - assert(CreateAtomicTestManager.cancelCalledInCreateAtomic) } test("expose metrics with custom metrics to StateStoreMetrics") { @@ -578,66 +601,69 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] assert(cacheMissCount === expectedCacheMissCount) } - val provider = newStoreProvider() + var store: StateStore = null + var loadedMapSizeForVersion1: Long = -1L + tryWithProviderResource(newStoreProvider()) { provider => + // Verify state before starting a new set of updates + assert(getLatestData(provider).isEmpty) - // Verify state before starting a new set of updates - assert(getLatestData(provider).isEmpty) + store = provider.getStore(0) + assert(!store.hasCommitted) - val store = provider.getStore(0) - assert(!store.hasCommitted) + assert(store.metrics.numKeys === 0) - assert(store.metrics.numKeys === 0) + val initialLoadedMapSize = getLoadedMapSizeMetric(store.metrics) + assert(initialLoadedMapSize >= 0) + assertCacheHitAndMiss(store.metrics, expectedCacheHitCount = 0, expectedCacheMissCount = 0) - val initialLoadedMapSize = getLoadedMapSizeMetric(store.metrics) - assert(initialLoadedMapSize >= 0) - assertCacheHitAndMiss(store.metrics, expectedCacheHitCount = 0, expectedCacheMissCount = 0) + put(store, "a", 0, 1) + assert(store.metrics.numKeys === 1) - put(store, "a", 0, 1) - assert(store.metrics.numKeys === 1) + put(store, "b", 0, 2) + put(store, "aa", 0, 3) + assert(store.metrics.numKeys === 3) + remove(store, _._1.startsWith("a")) + assert(store.metrics.numKeys === 1) + assert(store.commit() === 1) - put(store, "b", 0, 2) - put(store, "aa", 0, 3) - assert(store.metrics.numKeys === 3) - remove(store, _._1.startsWith("a")) - assert(store.metrics.numKeys === 1) - assert(store.commit() === 1) + assert(store.hasCommitted) - assert(store.hasCommitted) + loadedMapSizeForVersion1 = getLoadedMapSizeMetric(store.metrics) + assert(loadedMapSizeForVersion1 > initialLoadedMapSize) + assertCacheHitAndMiss(store.metrics, expectedCacheHitCount = 0, expectedCacheMissCount = 0) - val loadedMapSizeForVersion1 = getLoadedMapSizeMetric(store.metrics) - assert(loadedMapSizeForVersion1 > initialLoadedMapSize) - assertCacheHitAndMiss(store.metrics, expectedCacheHitCount = 0, expectedCacheMissCount = 0) + val storeV2 = provider.getStore(1) + assert(!storeV2.hasCommitted) + assert(storeV2.metrics.numKeys === 1) - val storeV2 = provider.getStore(1) - assert(!storeV2.hasCommitted) - assert(storeV2.metrics.numKeys === 1) + put(storeV2, "cc", 0, 4) + assert(storeV2.metrics.numKeys === 2) + assert(storeV2.commit() === 2) - put(storeV2, "cc", 0, 4) - assert(storeV2.metrics.numKeys === 2) - assert(storeV2.commit() === 2) + assert(storeV2.hasCommitted) - assert(storeV2.hasCommitted) + val loadedMapSizeForVersion1And2 = getLoadedMapSizeMetric(storeV2.metrics) + assert(loadedMapSizeForVersion1And2 > loadedMapSizeForVersion1) + assertCacheHitAndMiss(storeV2.metrics, expectedCacheHitCount = 1, expectedCacheMissCount = 0) + } - val loadedMapSizeForVersion1And2 = getLoadedMapSizeMetric(storeV2.metrics) - assert(loadedMapSizeForVersion1And2 > loadedMapSizeForVersion1) - assertCacheHitAndMiss(storeV2.metrics, expectedCacheHitCount = 1, expectedCacheMissCount = 0) + tryWithProviderResource(newStoreProvider(store.id)) { reloadedProvider => + // intended to load version 2 instead of 1 + // version 2 will not be loaded to the cache in provider + val reloadedStore = reloadedProvider.getStore(1) + assert(reloadedStore.metrics.numKeys === 1) - val reloadedProvider = newStoreProvider(store.id) - // intended to load version 2 instead of 1 - // version 2 will not be loaded to the cache in provider - val reloadedStore = reloadedProvider.getStore(1) - assert(reloadedStore.metrics.numKeys === 1) + assertCacheHitAndMiss(reloadedStore.metrics, expectedCacheHitCount = 0, + expectedCacheMissCount = 1) - assertCacheHitAndMiss(reloadedStore.metrics, expectedCacheHitCount = 0, - expectedCacheMissCount = 1) + // now we are loading version 2 + val reloadedStoreV2 = reloadedProvider.getStore(2) + assert(reloadedStoreV2.metrics.numKeys === 2) - // now we are loading version 2 - val reloadedStoreV2 = reloadedProvider.getStore(2) - assert(reloadedStoreV2.metrics.numKeys === 2) - - assert(getLoadedMapSizeMetric(reloadedStoreV2.metrics) > loadedMapSizeForVersion1) - assertCacheHitAndMiss(reloadedStoreV2.metrics, expectedCacheHitCount = 0, - expectedCacheMissCount = 2) + assert(getLoadedMapSizeMetric(reloadedStoreV2.metrics) > loadedMapSizeForVersion1) + assertCacheHitAndMiss(reloadedStoreV2.metrics, expectedCacheHitCount = 0, + expectedCacheMissCount = 2) + } } override def newStoreProvider(): HDFSBackedStateStoreProvider = { @@ -662,13 +688,14 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] } override def getData( - provider: HDFSBackedStateStoreProvider, - version: Int): Set[((String, Int), Int)] = { - val reloadedProvider = newStoreProvider(provider.stateStoreId) - if (version < 0) { - reloadedProvider.latestIterator().map(rowPairToDataPair).toSet - } else { - reloadedProvider.getStore(version).iterator().map(rowPairToDataPair).toSet + provider: HDFSBackedStateStoreProvider, + version: Int): Set[((String, Int), Int)] = { + tryWithProviderResource(newStoreProvider(provider.stateStoreId)) { reloadedProvider => + if (version < 0) { + reloadedProvider.latestIterator().map(rowPairToDataPair).toSet + } else { + reloadedProvider.getStore(version).iterator().map(rowPairToDataPair).toSet + } } } @@ -729,9 +756,9 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] } def corruptFile( - provider: HDFSBackedStateStoreProvider, - version: Long, - isSnapshot: Boolean): Unit = { + provider: HDFSBackedStateStoreProvider, + version: Long, + isSnapshot: Boolean): Unit = { val method = PrivateMethod[Path](Symbol("baseDir")) val basePath = provider invokePrivate method() val fileName = if (isSnapshot) s"$version.snapshot" else s"$version.delta" @@ -751,204 +778,206 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] protected val valueSchema: StructType = StateStoreTestsHelper.valueSchema testWithAllCodec("get, put, remove, commit, and all data iterator") { - val provider = newStoreProvider() + tryWithProviderResource(newStoreProvider()) { provider => + // Verify state before starting a new set of updates + assert(getLatestData(provider).isEmpty) - // Verify state before starting a new set of updates - assert(getLatestData(provider).isEmpty) + val store = provider.getStore(0) + assert(!store.hasCommitted) + assert(get(store, "a", 0) === None) + assert(store.iterator().isEmpty) + assert(store.metrics.numKeys === 0) - val store = provider.getStore(0) - assert(!store.hasCommitted) - assert(get(store, "a", 0) === None) - assert(store.iterator().isEmpty) - assert(store.metrics.numKeys === 0) + // Verify state after updating + put(store, "a", 0, 1) + assert(get(store, "a", 0) === Some(1)) - // Verify state after updating - put(store, "a", 0, 1) - assert(get(store, "a", 0) === Some(1)) + assert(store.iterator().nonEmpty) + assert(getLatestData(provider).isEmpty) - assert(store.iterator().nonEmpty) - assert(getLatestData(provider).isEmpty) + // Make updates, commit and then verify state + put(store, "b", 0, 2) + put(store, "aa", 0, 3) + remove(store, _._1.startsWith("a")) + assert(store.commit() === 1) - // Make updates, commit and then verify state - put(store, "b", 0, 2) - put(store, "aa", 0, 3) - remove(store, _._1.startsWith("a")) - assert(store.commit() === 1) + assert(store.hasCommitted) + assert(rowPairsToDataSet(store.iterator()) === Set(("b", 0) -> 2)) + assert(getLatestData(provider) === Set(("b", 0) -> 2)) - assert(store.hasCommitted) - assert(rowPairsToDataSet(store.iterator()) === Set(("b", 0) -> 2)) - assert(getLatestData(provider) === Set(("b", 0) -> 2)) + // Trying to get newer versions should fail + intercept[Exception] { + provider.getStore(2) + } + intercept[Exception] { + getData(provider, 2) + } - // Trying to get newer versions should fail - intercept[Exception] { - provider.getStore(2) + // New updates to the reloaded store with new version, and does not change old version + tryWithProviderResource(newStoreProvider(store.id)) { reloadedProvider => + val reloadedStore = reloadedProvider.getStore(1) + put(reloadedStore, "c", 0, 4) + assert(reloadedStore.commit() === 2) + assert(rowPairsToDataSet(reloadedStore.iterator()) === Set(("b", 0) -> 2, ("c", 0) -> 4)) + assert(getLatestData(provider) === Set(("b", 0) -> 2, ("c", 0) -> 4)) + assert(getData(provider, version = 1) === Set(("b", 0) -> 2)) + } } - intercept[Exception] { - getData(provider, 2) - } - - // New updates to the reloaded store with new version, and does not change old version - val reloadedProvider = newStoreProvider(store.id) - val reloadedStore = reloadedProvider.getStore(1) - put(reloadedStore, "c", 0, 4) - assert(reloadedStore.commit() === 2) - assert(rowPairsToDataSet(reloadedStore.iterator()) === Set(("b", 0) -> 2, ("c", 0) -> 4)) - assert(getLatestData(provider) === Set(("b", 0) -> 2, ("c", 0) -> 4)) - assert(getData(provider, version = 1) === Set(("b", 0) -> 2)) } testWithAllCodec("prefix scan") { - val provider = newStoreProvider(numPrefixCols = 1) + tryWithProviderResource(newStoreProvider(numPrefixCols = 1)) { provider => + // Verify state before starting a new set of updates + assert(getLatestData(provider).isEmpty) - // Verify state before starting a new set of updates - assert(getLatestData(provider).isEmpty) + var store = provider.getStore(0) - var store = provider.getStore(0) - - def putCompositeKeys(keys: Seq[(String, Int)]): Unit = { - val randomizedKeys = scala.util.Random.shuffle(keys.toList) - randomizedKeys.foreach { case (key1, key2) => - put(store, key1, key2, key2) + def putCompositeKeys(keys: Seq[(String, Int)]): Unit = { + val randomizedKeys = scala.util.Random.shuffle(keys.toList) + randomizedKeys.foreach { case (key1, key2) => + put(store, key1, key2, key2) + } } - } - def verifyScan(key1: Seq[String], key2: Seq[Int]): Unit = { - key1.foreach { k1 => - val keyValueSet = store.prefixScan(dataToPrefixKeyRow(k1)).map { pair => - rowPairToDataPair(pair.withRows(pair.key.copy(), pair.value.copy())) - }.toSet + def verifyScan(key1: Seq[String], key2: Seq[Int]): Unit = { + key1.foreach { k1 => + val keyValueSet = store.prefixScan(dataToPrefixKeyRow(k1)).map { pair => + rowPairToDataPair(pair.withRows(pair.key.copy(), pair.value.copy())) + }.toSet - assert(keyValueSet === key2.map(k2 => ((k1, k2), k2)).toSet) + assert(keyValueSet === key2.map(k2 => ((k1, k2), k2)).toSet) + } } + + val key1AtVersion0 = Seq("a", "b", "c") + val key2AtVersion0 = Seq(1, 2, 3) + val keysAtVersion0 = for (k1 <- key1AtVersion0; k2 <- key2AtVersion0) yield (k1, k2) + + putCompositeKeys(keysAtVersion0) + verifyScan(key1AtVersion0, key2AtVersion0) + + assert(store.prefixScan(dataToPrefixKeyRow("non-exist")).isEmpty) + + // committing and loading the version 1 (the version being committed) + store.commit() + store = provider.getStore(1) + + // before putting the new key-value pairs, verify prefix scan works for existing keys + verifyScan(key1AtVersion0, key2AtVersion0) + + val key1AtVersion1 = Seq("c", "d") + val key2AtVersion1 = Seq(4, 5, 6) + val keysAtVersion1 = for (k1 <- key1AtVersion1; k2 <- key2AtVersion1) yield (k1, k2) + + // put a new key-value pairs, and verify that prefix scan reflects the changes + putCompositeKeys(keysAtVersion1) + verifyScan(Seq("c"), Seq(1, 2, 3, 4, 5, 6)) + verifyScan(Seq("d"), Seq(4, 5, 6)) + + // aborting and loading the version 1 again (keysAtVersion1 should be rolled back) + store.abort() + store = provider.getStore(1) + + // prefix scan should not reflect the uncommitted changes + verifyScan(key1AtVersion0, key2AtVersion0) + verifyScan(Seq("d"), Seq.empty) } - - val key1AtVersion0 = Seq("a", "b", "c") - val key2AtVersion0 = Seq(1, 2, 3) - val keysAtVersion0 = for (k1 <- key1AtVersion0; k2 <- key2AtVersion0) yield (k1, k2) - - putCompositeKeys(keysAtVersion0) - verifyScan(key1AtVersion0, key2AtVersion0) - - assert(store.prefixScan(dataToPrefixKeyRow("non-exist")).isEmpty) - - // committing and loading the version 1 (the version being committed) - store.commit() - store = provider.getStore(1) - - // before putting the new key-value pairs, verify prefix scan works for existing keys - verifyScan(key1AtVersion0, key2AtVersion0) - - val key1AtVersion1 = Seq("c", "d") - val key2AtVersion1 = Seq(4, 5, 6) - val keysAtVersion1 = for (k1 <- key1AtVersion1; k2 <- key2AtVersion1) yield (k1, k2) - - // put a new key-value pairs, and verify that prefix scan reflects the changes - putCompositeKeys(keysAtVersion1) - verifyScan(Seq("c"), Seq(1, 2, 3, 4, 5, 6)) - verifyScan(Seq("d"), Seq(4, 5, 6)) - - // aborting and loading the version 1 again (keysAtVersion1 should be rolled back) - store.abort() - store = provider.getStore(1) - - // prefix scan should not reflect the uncommitted changes - verifyScan(key1AtVersion0, key2AtVersion0) - verifyScan(Seq("d"), Seq.empty) } testWithAllCodec("numKeys metrics") { - val provider = newStoreProvider() + tryWithProviderResource(newStoreProvider()) { provider => + // Verify state before starting a new set of updates + assert(getLatestData(provider).isEmpty) - // Verify state before starting a new set of updates - assert(getLatestData(provider).isEmpty) + val store = provider.getStore(0) + put(store, "a", 0, 1) + put(store, "b", 0, 2) + put(store, "c", 0, 3) + put(store, "d", 0, 4) + put(store, "e", 0, 5) + assert(store.commit() === 1) + assert(store.metrics.numKeys === 5) + assert(rowPairsToDataSet(store.iterator()) === + Set(("a", 0) -> 1, ("b", 0) -> 2, ("c", 0) -> 3, ("d", 0) -> 4, ("e", 0) -> 5)) - val store = provider.getStore(0) - put(store, "a", 0, 1) - put(store, "b", 0, 2) - put(store, "c", 0, 3) - put(store, "d", 0, 4) - put(store, "e", 0, 5) - assert(store.commit() === 1) - assert(store.metrics.numKeys === 5) - assert(rowPairsToDataSet(store.iterator()) === - Set(("a", 0) -> 1, ("b", 0) -> 2, ("c", 0) -> 3, ("d", 0) -> 4, ("e", 0) -> 5)) - - val reloadedProvider = newStoreProvider(store.id) - val reloadedStore = reloadedProvider.getStore(1) - remove(reloadedStore, _._1 == "b") - assert(reloadedStore.commit() === 2) - assert(reloadedStore.metrics.numKeys === 4) - assert(rowPairsToDataSet(reloadedStore.iterator()) === - Set(("a", 0) -> 1, ("c", 0) -> 3, ("d", 0) -> 4, ("e", 0) -> 5)) + val reloadedProvider = newStoreProvider(store.id) + val reloadedStore = reloadedProvider.getStore(1) + remove(reloadedStore, _._1 == "b") + assert(reloadedStore.commit() === 2) + assert(reloadedStore.metrics.numKeys === 4) + assert(rowPairsToDataSet(reloadedStore.iterator()) === + Set(("a", 0) -> 1, ("c", 0) -> 3, ("d", 0) -> 4, ("e", 0) -> 5)) + } } testWithAllCodec("removing while iterating") { - val provider = newStoreProvider() + tryWithProviderResource(newStoreProvider()) { provider => + // Verify state before starting a new set of updates + assert(getLatestData(provider).isEmpty) + val store = provider.getStore(0) + put(store, "a", 0, 1) + put(store, "b", 0, 2) - // Verify state before starting a new set of updates - assert(getLatestData(provider).isEmpty) - val store = provider.getStore(0) - put(store, "a", 0, 1) - put(store, "b", 0, 2) + // Updates should work while iterating of filtered entries + val filtered = store.iterator.filter { tuple => keyRowToData(tuple.key) == ("a", 0) } + filtered.foreach { tuple => + store.put(tuple.key, dataToValueRow(valueRowToData(tuple.value) + 1)) + } + assert(get(store, "a", 0) === Some(2)) - // Updates should work while iterating of filtered entries - val filtered = store.iterator.filter { tuple => keyRowToData(tuple.key) == ("a", 0) } - filtered.foreach { tuple => - store.put(tuple.key, dataToValueRow(valueRowToData(tuple.value) + 1)) + // Removes should work while iterating of filtered entries + val filtered2 = store.iterator.filter { tuple => keyRowToData(tuple.key) == ("b", 0) } + filtered2.foreach { tuple => store.remove(tuple.key) } + assert(get(store, "b", 0) === None) } - assert(get(store, "a", 0) === Some(2)) - - // Removes should work while iterating of filtered entries - val filtered2 = store.iterator.filter { tuple => keyRowToData(tuple.key) == ("b", 0) } - filtered2.foreach { tuple => store.remove(tuple.key) } - assert(get(store, "b", 0) === None) } testWithAllCodec("abort") { - val provider = newStoreProvider() - val store = provider.getStore(0) - put(store, "a", 0, 1) - store.commit() - assert(rowPairsToDataSet(store.iterator()) === Set(("a", 0) -> 1)) + tryWithProviderResource(newStoreProvider()) { provider => + val store = provider.getStore(0) + put(store, "a", 0, 1) + store.commit() + assert(rowPairsToDataSet(store.iterator()) === Set(("a", 0) -> 1)) - // cancelUpdates should not change the data in the files - val store1 = provider.getStore(1) - put(store1, "b", 0, 1) - store1.abort() + // cancelUpdates should not change the data in the files + val store1 = provider.getStore(1) + put(store1, "b", 0, 1) + store1.abort() + } } testWithAllCodec("getStore with invalid versions") { - val provider = newStoreProvider() - - def checkInvalidVersion(version: Int): Unit = { - intercept[Exception] { - provider.getStore(version) + tryWithProviderResource(newStoreProvider()) { provider => + def checkInvalidVersion(version: Int): Unit = { + intercept[Exception] { + provider.getStore(version) + } } + + checkInvalidVersion(-1) + checkInvalidVersion(1) + + val store = provider.getStore(0) + put(store, "a", 0, 1) + assert(store.commit() === 1) + assert(rowPairsToDataSet(store.iterator()) === Set(("a", 0) -> 1)) + + val store1_ = provider.getStore(1) + assert(rowPairsToDataSet(store1_.iterator()) === Set(("a", 0) -> 1)) + + checkInvalidVersion(-1) + checkInvalidVersion(2) + + // Update store version with some data + val store1 = provider.getStore(1) + assert(rowPairsToDataSet(store1.iterator()) === Set(("a", 0) -> 1)) + put(store1, "b", 0, 1) + assert(store1.commit() === 2) + assert(rowPairsToDataSet(store1.iterator()) === Set(("a", 0) -> 1, ("b", 0) -> 1)) + + checkInvalidVersion(-1) + checkInvalidVersion(3) } - - checkInvalidVersion(-1) - checkInvalidVersion(1) - - val store = provider.getStore(0) - put(store, "a", 0, 1) - assert(store.commit() === 1) - assert(rowPairsToDataSet(store.iterator()) === Set(("a", 0) -> 1)) - - val store1_ = provider.getStore(1) - assert(rowPairsToDataSet(store1_.iterator()) === Set(("a", 0) -> 1)) - - checkInvalidVersion(-1) - checkInvalidVersion(2) - - // Update store version with some data - val store1 = provider.getStore(1) - assert(rowPairsToDataSet(store1.iterator()) === Set(("a", 0) -> 1)) - put(store1, "b", 0, 1) - assert(store1.commit() === 2) - assert(rowPairsToDataSet(store1.iterator()) === Set(("a", 0) -> 1, ("b", 0) -> 1)) - - checkInvalidVersion(-1) - checkInvalidVersion(3) } testWithAllCodec("two concurrent StateStores - one for read-only and one for read-write") { @@ -959,28 +988,33 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] // accidentally lead to the deletion of state. val dir = newDir() val storeId = StateStoreId(dir, 0L, 1) - val provider0 = newStoreProvider(storeId) - // prime state - val store = provider0.getStore(0) val key1 = "a" val key2 = 0 - put(store, key1, key2, 1) - store.commit() - assert(rowPairsToDataSet(store.iterator()) === Set((key1, key2) -> 1)) + + tryWithProviderResource(newStoreProvider(storeId)) { provider0 => + // prime state + val store = provider0.getStore(0) + + put(store, key1, key2, 1) + store.commit() + assert(rowPairsToDataSet(store.iterator()) === Set((key1, key2) -> 1)) + } // two state stores - val provider1 = newStoreProvider(storeId) - val restoreStore = provider1.getReadStore(1) - val saveStore = provider1.getStore(1) + tryWithProviderResource(newStoreProvider(storeId)) { provider1 => + val restoreStore = provider1.getReadStore(1) + val saveStore = provider1.getStore(1) - put(saveStore, key1, key2, get(restoreStore, key1, key2).get + 1) - saveStore.commit() - restoreStore.abort() + put(saveStore, key1, key2, get(restoreStore, key1, key2).get + 1) + saveStore.commit() + restoreStore.abort() + } // check that state is correct for next batch - val provider2 = newStoreProvider(storeId) - val finalStore = provider2.getStore(2) - assert(rowPairsToDataSet(finalStore.iterator()) === Set((key1, key2) -> 2)) + tryWithProviderResource(newStoreProvider(storeId)) { provider2 => + val finalStore = provider2.getStore(2) + assert(rowPairsToDataSet(finalStore.iterator()) === Set((key1, key2) -> 2)) + } } test("StateStore.get") { @@ -1036,12 +1070,13 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] } test("reports memory usage") { - val provider = newStoreProvider() - val store = provider.getStore(0) - val noDataMemoryUsed = store.metrics.memoryUsedBytes - put(store, "a", 0, 1) - store.commit() - assert(store.metrics.memoryUsedBytes > noDataMemoryUsed) + tryWithProviderResource(newStoreProvider()) { provider => + val store = provider.getStore(0) + val noDataMemoryUsed = store.metrics.memoryUsedBytes + put(store, "a", 0, 1) + store.commit() + assert(store.metrics.memoryUsedBytes > noDataMemoryUsed) + } } test("SPARK-34270: StateStoreMetrics.combine should not override individual metrics") { @@ -1067,16 +1102,16 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] } test("SPARK-35659: StateStore.put cannot put null value") { - val provider = newStoreProvider() + tryWithProviderResource(newStoreProvider()) { provider => + // Verify state before starting a new set of updates + assert(getLatestData(provider).isEmpty) - // Verify state before starting a new set of updates - assert(getLatestData(provider).isEmpty) - - val store = provider.getStore(0) - val err = intercept[IllegalArgumentException] { - store.put(dataToKeyRow("key", 0), null) + val store = provider.getStore(0) + val err = intercept[IllegalArgumentException] { + store.put(dataToKeyRow("key", 0), null) + } + assert(err.getMessage.contains("Cannot put a null value")) } - assert(err.getMessage.contains("Cannot put a null value")) } test("SPARK-35763: StateStoreCustomMetric withNewDesc and createSQLMetric") { @@ -1122,6 +1157,14 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] } } + protected def tryWithProviderResource[T](provider: ProviderClass)(f: ProviderClass => T): T = { + try { + f(provider) + } finally { + provider.close() + } + } + /** Get the `SQLConf` by the given minimum delta and version to retain in memory */ def getDefaultSQLConf(minDeltasForSnapshot: Int, numOfVersToRetainInMemory: Int): SQLConf