From 3257a30e5399d4f366e4aae60b04371b31514fb4 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Tue, 29 Jun 2021 17:46:45 -0700 Subject: [PATCH] [SPARK-35784][SS] Implementation for RocksDB instance ### What changes were proposed in this pull request? The implementation for the RocksDB instance, which is used in the RocksDB state store. It plays a role as a handler for the RocksDB instance and RocksDBFileManager. ### Why are the changes needed? Part of the RocksDB state store implementation. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? New UT added. Closes #32928 from xuanyuanking/SPARK-35784. Authored-by: Yuanjian Li Signed-off-by: Liang-Chi Hsieh --- dev/deps/spark-deps-hadoop-2.7-hive-2.3 | 1 + dev/deps/spark-deps-hadoop-3.2-hive-2.3 | 1 + sql/core/pom.xml | 5 + .../execution/streaming/state/RocksDB.scala | 452 ++++++++++++++++++ .../streaming/state/RocksDBLoader.scala | 60 +++ .../streaming/state/RocksDBSuite.scala | 97 ++++ 6 files changed, 616 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBLoader.scala diff --git a/dev/deps/spark-deps-hadoop-2.7-hive-2.3 b/dev/deps/spark-deps-hadoop-2.7-hive-2.3 index b18df074cf..0d8e0323bf 100644 --- a/dev/deps/spark-deps-hadoop-2.7-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-2.7-hive-2.3 @@ -211,6 +211,7 @@ parquet-jackson/1.12.0//parquet-jackson-1.12.0.jar protobuf-java/2.5.0//protobuf-java-2.5.0.jar py4j/0.10.9.2//py4j-0.10.9.2.jar pyrolite/4.30//pyrolite-4.30.jar +rocksdbjni/6.2.2//rocksdbjni-6.2.2.jar scala-collection-compat_2.12/2.1.1//scala-collection-compat_2.12-2.1.1.jar scala-compiler/2.12.14//scala-compiler-2.12.14.jar scala-library/2.12.14//scala-library-2.12.14.jar diff --git a/dev/deps/spark-deps-hadoop-3.2-hive-2.3 b/dev/deps/spark-deps-hadoop-3.2-hive-2.3 index bc77aa66d1..b7d49384c3 100644 --- a/dev/deps/spark-deps-hadoop-3.2-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3.2-hive-2.3 @@ -182,6 +182,7 @@ parquet-jackson/1.12.0//parquet-jackson-1.12.0.jar protobuf-java/2.5.0//protobuf-java-2.5.0.jar py4j/0.10.9.2//py4j-0.10.9.2.jar pyrolite/4.30//pyrolite-4.30.jar +rocksdbjni/6.2.2//rocksdbjni-6.2.2.jar scala-collection-compat_2.12/2.1.1//scala-collection-compat_2.12-2.1.1.jar scala-compiler/2.12.14//scala-compiler-2.12.14.jar scala-library/2.12.14//scala-library-2.12.14.jar diff --git a/sql/core/pom.xml b/sql/core/pom.xml index f7bbe807ca..149b58b1f3 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -35,6 +35,11 @@ + + org.rocksdb + rocksdbjni + 6.2.2 + com.univocity univocity-parsers diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala new file mode 100644 index 0000000000..82aa1663f8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -0,0 +1,452 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import java.io.File +import java.util.Locale +import javax.annotation.concurrent.GuardedBy + +import scala.collection.{mutable, Map} +import scala.ref.WeakReference +import scala.util.Try + +import org.apache.hadoop.conf.Configuration +import org.rocksdb.{RocksDB => NativeRocksDB, _} + +import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.util.{NextIterator, Utils} + +/** + * Class representing a RocksDB instance that checkpoints version of data to DFS. + * After a set of updates, a new version can be committed by calling `commit()`. + * Any past version can be loaded by calling `load(version)`. + * + * @note This class is not thread-safe, so use it only from one thread. + * @see [[RocksDBFileManager]] to see how the files are laid out in local disk and DFS. + * @param dfsRootDir Remote directory where checkpoints are going to be written + * @param conf Configuration for RocksDB + * @param localRootDir Root directory in local disk that is used to working and checkpointing dirs + * @param hadoopConf Hadoop configuration for talking to the remote file system + * @param loggingId Id that will be prepended in logs for isolating concurrent RocksDBs + */ +class RocksDB( + dfsRootDir: String, + val conf: RocksDBConf, + localRootDir: File = Utils.createTempDir(), + hadoopConf: Configuration = new Configuration, + loggingId: String = "") extends Logging { + + RocksDBLoader.loadLibrary() + + // Java wrapper objects linking to native RocksDB objects + private val readOptions = new ReadOptions() // used for gets + private val writeOptions = new WriteOptions().setSync(true) // wait for batched write to complete + private val flushOptions = new FlushOptions().setWaitForFlush(true) // wait for flush to complete + private val writeBatch = new WriteBatchWithIndex(true) // overwrite multiple updates to a key + + private val bloomFilter = new BloomFilter() + private val tableFormatConfig = new BlockBasedTableConfig() + tableFormatConfig.setBlockSize(conf.blockSizeKB * 1024) + tableFormatConfig.setBlockCache(new LRUCache(conf.blockCacheSizeMB * 1024 * 1024)) + tableFormatConfig.setFilterPolicy(bloomFilter) + + private val dbOptions = new Options() // options to open the RocksDB + dbOptions.setCreateIfMissing(true) + dbOptions.setTableFormatConfig(tableFormatConfig) + private val dbLogger = createLogger() // for forwarding RocksDB native logs to log4j + dbOptions.setStatistics(new Statistics()) + + private val workingDir = createTempDir("workingDir") + private val fileManager = new RocksDBFileManager( + dfsRootDir, createTempDir("fileManager"), hadoopConf, loggingId = loggingId) + private val byteArrayPair = new ByteArrayPair() + private val commitLatencyMs = new mutable.HashMap[String, Long]() + private val acquireLock = new Object + + @volatile private var db: NativeRocksDB = _ + @volatile private var loadedVersion = -1L // -1 = nothing valid is loaded + @volatile private var numKeysOnLoadedVersion = 0L + @volatile 
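For orientation, a minimal, hypothetical sketch of the load-update-commit cycle the class scaladoc above describes; the DFS path and key/value bytes are made up, and in practice the instance is driven by the RocksDB state store provider rather than called directly:

    // Hypothetical usage of the lifecycle described in the class scaladoc.
    val db = new RocksDB(
      dfsRootDir = "hdfs://nn/checkpoints/state/0/0",  // hypothetical checkpoint location
      conf = RocksDBConf())
    db.load(10)                                // restore version 10 from DFS
    db.put("key".getBytes, "value".getBytes)   // buffered in the WriteBatchWithIndex
    val newVersion = db.commit()               // returns 11 and syncs the checkpoint to DFS
    db.close()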
private var numKeysOnWritingVersion = 0L + + @GuardedBy("acquireLock") + @volatile private var acquiredThreadInfo: AcquiredThreadInfo = _ + + /** + * Load the given version of data in a native RocksDB instance. + * Note that this will copy all the necessary file from DFS to local disk as needed, + * and possibly restart the native RocksDB instance. + */ + def load(version: Long): RocksDB = { + assert(version >= 0) + acquire() + logInfo(s"Loading $version") + try { + if (loadedVersion != version) { + closeDB() + val metadata = fileManager.loadCheckpointFromDfs(version, workingDir) + openDB() + numKeysOnWritingVersion = metadata.numKeys + numKeysOnLoadedVersion = metadata.numKeys + loadedVersion = version + } + writeBatch.clear() + logInfo(s"Loaded $version") + } catch { + case t: Throwable => + loadedVersion = -1 // invalidate loaded data + throw t + } + this + } + + /** + * Get the value for the given key if present, or null. + * @note This will return the last written value even if it was uncommitted. + */ + def get(key: Array[Byte]): Array[Byte] = { + writeBatch.getFromBatchAndDB(db, readOptions, key) + } + + /** + * Put the given value for the given key and return the last written value. + * @note This update is not committed to disk until commit() is called. + */ + def put(key: Array[Byte], value: Array[Byte]): Array[Byte] = { + val oldValue = writeBatch.getFromBatchAndDB(db, readOptions, key) + writeBatch.put(key, value) + if (oldValue == null) { + numKeysOnWritingVersion += 1 + } + oldValue + } + + /** + * Remove the key if present, and return the previous value if it was present (null otherwise). + * @note This update is not committed to disk until commit() is called. + */ + def remove(key: Array[Byte]): Array[Byte] = { + val value = writeBatch.getFromBatchAndDB(db, readOptions, key) + if (value != null) { + writeBatch.remove(key) + numKeysOnWritingVersion -= 1 + } + value + } + + /** + * Get an iterator of all committed and uncommitted key-value pairs. + */ + def iterator(): Iterator[ByteArrayPair] = { + val iter = writeBatch.newIteratorWithBase(db.newIterator()) + logInfo(s"Getting iterator from version $loadedVersion") + iter.seekToFirst() + + // Attempt to close this iterator if there is a task failure, or a task interruption. + // This is a hack because it assumes that the RocksDB is running inside a task. + Option(TaskContext.get()).foreach { tc => + tc.addTaskCompletionListener[Unit] { _ => iter.close() } + } + + new NextIterator[ByteArrayPair] { + override protected def getNext(): ByteArrayPair = { + if (iter.isValid) { + byteArrayPair.set(iter.key, iter.value) + iter.next() + byteArrayPair + } else { + finished = true + iter.close() + null + } + } + override protected def close(): Unit = { iter.close() } + } + } + + /** + * Commit all the updates made as a version to DFS. The steps it needs to do to commits are: + * - Write all the updates to the native RocksDB + * - Flush all changes to disk + * - Create a RocksDB checkpoint in a new local dir + * - Sync the checkpoint dir files to DFS + */ + def commit(): Long = { + val newVersion = loadedVersion + 1 + val checkpointDir = createTempDir("checkpoint") + try { + // Make sure the directory does not exist. Native RocksDB fails if the directory to + // checkpoint exists. 
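The accessors above give read-your-writes semantics within an uncommitted version, since get(), put(), remove() and iterator() all consult the WriteBatchWithIndex before the native DB. A hedged sketch, continuing the earlier lifecycle example (keys and values are made up):

    // Uncommitted updates are visible to get()/iterator() because reads go through
    // getFromBatchAndDB / newIteratorWithBase, checking the batch before the native DB.
    val previous = db.put("k".getBytes, "v1".getBytes)   // returns the prior value, null here
    assert(previous == null)
    assert(new String(db.get("k".getBytes)) == "v1")     // visible before commit()
    val removed = db.remove("k".getBytes)                // returns the "v1" bytes, drops the key
    assert(db.get("k".getBytes) == null)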
+ Utils.deleteRecursively(checkpointDir) + + logInfo(s"Writing updates for $newVersion") + val writeTimeMs = timeTakenMs { db.write(writeOptions, writeBatch) } + + logInfo(s"Flushing updates for $newVersion") + val flushTimeMs = timeTakenMs { db.flush(flushOptions) } + + val compactTimeMs = if (conf.compactOnCommit) { + logInfo("Compacting") + timeTakenMs { db.compactRange() } + } else 0 + logInfo("Pausing background work") + + val pauseTimeMs = timeTakenMs { + db.pauseBackgroundWork() // To avoid files being changed while committing + } + + logInfo(s"Creating checkpoint for $newVersion in $checkpointDir") + val checkpointTimeMs = timeTakenMs { + val cp = Checkpoint.create(db) + cp.createCheckpoint(checkpointDir.toString) + } + + logInfo(s"Syncing checkpoint for $newVersion to DFS") + val fileSyncTimeMs = timeTakenMs { + fileManager.saveCheckpointToDfs(checkpointDir, newVersion, numKeysOnWritingVersion) + } + numKeysOnLoadedVersion = numKeysOnWritingVersion + loadedVersion = newVersion + commitLatencyMs ++= Map( + "writeBatch" -> writeTimeMs, + "flush" -> flushTimeMs, + "compact" -> compactTimeMs, + "pause" -> pauseTimeMs, + "checkpoint" -> checkpointTimeMs, + "fileSync" -> fileSyncTimeMs + ) + loadedVersion + } catch { + case t: Throwable => + loadedVersion = -1 // invalidate loaded version + throw t + } finally { + db.continueBackgroundWork() + silentDeleteRecursively(checkpointDir, s"committing $newVersion") + release() + } + } + + /** + * Drop uncommitted changes, and roll back to previous version. + */ + def rollback(): Unit = { + writeBatch.clear() + numKeysOnWritingVersion = numKeysOnLoadedVersion + release() + logInfo(s"Rolled back to $loadedVersion") + } + + /** Release all resources */ + def close(): Unit = { + try { + closeDB() + + // Release all resources related to native RockDB objects + writeBatch.clear() + writeBatch.close() + readOptions.close() + writeOptions.close() + flushOptions.close() + dbOptions.close() + dbLogger.close() + silentDeleteRecursively(localRootDir, "closing RocksDB") + } catch { + case e: Exception => + logWarning("Error closing RocksDB", e) + } + } + + /** Get the latest version available in the DFS */ + def getLatestVersion(): Long = fileManager.getLatestVersion() + + private def acquire(): Unit = acquireLock.synchronized { + val newAcquiredThreadInfo = AcquiredThreadInfo() + val waitStartTime = System.currentTimeMillis + def timeWaitedMs = System.currentTimeMillis - waitStartTime + def isAcquiredByDifferentThread = acquiredThreadInfo != null && + acquiredThreadInfo.threadRef.get.isDefined && + newAcquiredThreadInfo.threadRef.get.get.getId != acquiredThreadInfo.threadRef.get.get.getId + + while (isAcquiredByDifferentThread && timeWaitedMs < conf.lockAcquireTimeoutMs) { + acquireLock.wait(10) + } + if (isAcquiredByDifferentThread) { + val stackTraceOutput = acquiredThreadInfo.threadRef.get.get.getStackTrace.mkString("\n") + val msg = s"RocksDB instance could not be acquired by $newAcquiredThreadInfo as it " + + s"was not released by $acquiredThreadInfo after $timeWaitedMs ms.\n" + + s"Thread holding the lock has trace: $stackTraceOutput" + logError(msg) + throw new IllegalStateException(s"$loggingId: $msg") + } else { + acquiredThreadInfo = newAcquiredThreadInfo + // Add a listener to always release the lock when the task (if active) completes + Option(TaskContext.get).foreach(_.addTaskCompletionListener[Unit] { _ => this.release() }) + logInfo(s"RocksDB instance was acquired by $acquiredThreadInfo") + } + } + + private def release(): Unit = 
acquireLock.synchronized { + acquiredThreadInfo = null + acquireLock.notifyAll() + } + + private def openDB(): Unit = { + assert(db == null) + db = NativeRocksDB.open(dbOptions, workingDir.toString) + logInfo(s"Opened DB with conf ${conf}") + } + + private def closeDB(): Unit = { + if (db != null) { + db.close() + db = null + } + } + + /** Create a native RocksDB logger that forwards native logs to log4j with correct log levels. */ + private def createLogger(): Logger = { + val dbLogger = new Logger(dbOptions) { + override def log(infoLogLevel: InfoLogLevel, logMsg: String) = { + // Map DB log level to log4j levels + // Warn is mapped to info because RocksDB warn is too verbose + // (e.g. dumps non-warning stuff like stats) + val loggingFunc: ( => String) => Unit = infoLogLevel match { + case InfoLogLevel.FATAL_LEVEL | InfoLogLevel.ERROR_LEVEL => logError(_) + case InfoLogLevel.WARN_LEVEL | InfoLogLevel.INFO_LEVEL => logInfo(_) + case InfoLogLevel.DEBUG_LEVEL => logDebug(_) + case _ => logTrace(_) + } + loggingFunc(s"[NativeRocksDB-${infoLogLevel.getValue}] $logMsg") + } + } + + var dbLogLevel = InfoLogLevel.ERROR_LEVEL + if (log.isWarnEnabled) dbLogLevel = InfoLogLevel.WARN_LEVEL + if (log.isInfoEnabled) dbLogLevel = InfoLogLevel.INFO_LEVEL + if (log.isDebugEnabled) dbLogLevel = InfoLogLevel.DEBUG_LEVEL + dbOptions.setLogger(dbLogger) + dbOptions.setInfoLogLevel(dbLogLevel) + logInfo(s"Set RocksDB native logging level to $dbLogLevel") + dbLogger + } + + /** Create a temp directory inside the local root directory */ + private def createTempDir(prefix: String): File = { + Utils.createDirectory(localRootDir.getAbsolutePath, prefix) + } + + /** Attempt to delete recursively, and log the error if any */ + private def silentDeleteRecursively(file: File, msg: String): Unit = { + try { + Utils.deleteRecursively(file) + } catch { + case e: Exception => + logWarning(s"Error recursively deleting local dir $file while $msg", e) + } + } + + /** Records the duration of running `body` for the next query progress update. 
*/ + protected def timeTakenMs(body: => Unit): Long = Utils.timeTakenMs(body)._2 + + override protected def logName: String = s"${super.logName} $loggingId" +} + + +/** Mutable and reusable pair of byte arrays */ +class ByteArrayPair(var key: Array[Byte] = null, var value: Array[Byte] = null) { + def set(key: Array[Byte], value: Array[Byte]): ByteArrayPair = { + this.key = key + this.value = value + this + } +} + + +/** + * Configurations for optimizing RocksDB + * + * @param compactOnCommit Whether to compact RocksDB data before commit / checkpointing + */ +case class RocksDBConf( + minVersionsToRetain: Int, + compactOnCommit: Boolean, + pauseBackgroundWorkForCommit: Boolean, + blockSizeKB: Long, + blockCacheSizeMB: Long, + lockAcquireTimeoutMs: Long) + +object RocksDBConf { + /** Common prefix of all confs in SQLConf that affects RocksDB */ + val ROCKSDB_CONF_NAME_PREFIX = "spark.sql.streaming.stateStore.rocksdb" + + private case class ConfEntry(name: String, default: String) { + def fullName: String = s"$ROCKSDB_CONF_NAME_PREFIX.${name}".toLowerCase(Locale.ROOT) + } + + // Configuration that specifies whether to compact the RocksDB data every time data is committed + private val COMPACT_ON_COMMIT_CONF = ConfEntry("compactOnCommit", "false") + private val PAUSE_BG_WORK_FOR_COMMIT_CONF = ConfEntry("pauseBackgroundWorkForCommit", "true") + private val BLOCK_SIZE_KB_CONF = ConfEntry("blockSizeKB", "4") + private val BLOCK_CACHE_SIZE_MB_CONF = ConfEntry("blockCacheSizeMB", "8") + private val LOCK_ACQUIRE_TIMEOUT_MS_CONF = ConfEntry("lockAcquireTimeoutMs", "60000") + + def apply(storeConf: StateStoreConf): RocksDBConf = { + val confs = CaseInsensitiveMap[String](storeConf.confs) + + def getBooleanConf(conf: ConfEntry): Boolean = { + Try { confs.getOrElse(conf.fullName, conf.default).toBoolean } getOrElse { + throw new IllegalArgumentException(s"Invalid value for '${conf.fullName}', must be boolean") + } + } + + def getPositiveLongConf(conf: ConfEntry): Long = { + Try { confs.getOrElse(conf.fullName, conf.default).toLong } filter { _ >= 0 } getOrElse { + throw new IllegalArgumentException( + s"Invalid value for '${conf.fullName}', must be a positive integer") + } + } + + RocksDBConf( + storeConf.minVersionsToRetain, + getBooleanConf(COMPACT_ON_COMMIT_CONF), + getBooleanConf(PAUSE_BG_WORK_FOR_COMMIT_CONF), + getPositiveLongConf(BLOCK_SIZE_KB_CONF), + getPositiveLongConf(BLOCK_CACHE_SIZE_MB_CONF), + getPositiveLongConf(LOCK_ACQUIRE_TIMEOUT_MS_CONF)) + } + + def apply(): RocksDBConf = apply(new StateStoreConf()) +} + +case class AcquiredThreadInfo() { + val threadRef: WeakReference[Thread] = new WeakReference[Thread](Thread.currentThread()) + val tc: TaskContext = TaskContext.get() + + override def toString(): String = { + val taskStr = if (tc != null) { + val taskDetails = + s"${tc.partitionId}.${tc.attemptNumber} in stage ${tc.stageId}, TID ${tc.taskAttemptId}" + s", task: $taskDetails" + } else "" + + s"[ThreadId: ${threadRef.get.map(_.getId)}$taskStr]" + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBLoader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBLoader.scala new file mode 100644 index 0000000000..cc51819243 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBLoader.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
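For reference, a hedged sketch of how the RocksDBConf entries above resolve; the defaults are the ones declared in the ConfEntry definitions, while the exact SQLConf-to-StateStoreConf plumbing is outside this patch:

    // Keys are built as "spark.sql.streaming.stateStore.rocksdb.<name>", lowercased, and looked
    // up through CaseInsensitiveMap, so differently cased spellings select the same entry:
    //   spark.sql.streaming.stateStore.rocksdb.compactOnCommit = true
    //   spark.sql.streaming.stateStore.rocksdb.COMPACTONCOMMIT = true
    // With no overrides, RocksDBConf() yields the declared defaults:
    val conf = RocksDBConf()
    assert(!conf.compactOnCommit && conf.blockSizeKB == 4 &&
      conf.blockCacheSizeMB == 8 && conf.lockAcquireTimeoutMs == 60000)
    // Individual fields can be tweaked via copy(), as the test suite does:
    val tuned = conf.copy(compactOnCommit = true)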
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import org.rocksdb.{RocksDB => NativeRocksDB} + +import org.apache.spark.internal.Logging +import org.apache.spark.util.UninterruptibleThread + +/** + * A wrapper for RocksDB library loading using an uninterruptible thread, as the native RocksDB + * code will throw an error when interrupted. + */ +object RocksDBLoader extends Logging { + /** + * Keep tracks of the exception thrown from the loading thread, if any. + */ + private var exception: Option[Throwable] = null + + private val loadLibraryThread = new UninterruptibleThread("RocksDBLoader") { + override def run(): Unit = { + try { + runUninterruptibly { + NativeRocksDB.loadLibrary() + exception = None + } + } catch { + case e: Throwable => + exception = Some(e) + } + } + } + + def loadLibrary(): Unit = synchronized { + if (exception == null) { + loadLibraryThread.start() + logInfo("RocksDB library loading thread started") + loadLibraryThread.join() + exception.foreach(throw _) + logInfo("RocksDB library loading thread finished successfully") + } else { + exception.foreach(throw _) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala index c75eed2ae1..a11eb8a33d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala @@ -33,6 +33,75 @@ import org.apache.spark.util.Utils class RocksDBSuite extends SparkFunSuite { + test("RocksDB: get, put, iterator, commit, load") { + def testOps(compactOnCommit: Boolean): Unit = { + val remoteDir = Utils.createTempDir().toString + new File(remoteDir).delete() // to make sure that the directory gets created + + val conf = RocksDBConf().copy(compactOnCommit = compactOnCommit) + withDB(remoteDir, conf = conf) { db => + assert(db.get("a") === null) + assert(iterator(db).isEmpty) + + db.put("a", "1") + assert(toStr(db.get("a")) === "1") + db.commit() + } + + withDB(remoteDir, conf = conf, version = 0) { db => + // version 0 can be loaded again + assert(toStr(db.get("a")) === null) + assert(iterator(db).isEmpty) + } + + withDB(remoteDir, conf = conf, version = 1) { db => + // version 1 data recovered correctly + assert(toStr(db.get("a")) === "1") + assert(db.iterator().map(toStr).toSet === Set(("a", "1"))) + + // make changes but do not commit version 2 + db.put("b", "2") + assert(toStr(db.get("b")) === "2") + assert(db.iterator().map(toStr).toSet === Set(("a", "1"), ("b", "2"))) + } + + withDB(remoteDir, conf = conf, version = 1) { db => + // version 1 data not changed + assert(toStr(db.get("a")) === "1") + assert(db.get("b") === null) + 
assert(db.iterator().map(toStr).toSet === Set(("a", "1"))) + + // commit version 2 + db.put("b", "2") + assert(toStr(db.get("b")) === "2") + db.commit() + assert(db.iterator().map(toStr).toSet === Set(("a", "1"), ("b", "2"))) + } + + withDB(remoteDir, conf = conf, version = 1) { db => + // version 1 data not changed + assert(toStr(db.get("a")) === "1") + assert(db.get("b") === null) + } + + withDB(remoteDir, conf = conf, version = 2) { db => + // version 2 can be loaded again + assert(toStr(db.get("b")) === "2") + assert(db.iterator().map(toStr).toSet === Set(("a", "1"), ("b", "2"))) + + db.load(1) + assert(toStr(db.get("b")) === null) + assert(db.iterator().map(toStr).toSet === Set(("a", "1"))) + } + } + + for (compactOnCommit <- Seq(false, true)) { + withClue(s"compactOnCommit = $compactOnCommit") { + testOps(compactOnCommit) + } + } + } + test("RocksDBFileManager: upload only new immutable files") { withTempDir { dir => val dfsRootDir = dir.getAbsolutePath @@ -167,6 +236,26 @@ class RocksDBSuite extends SparkFunSuite { // scalastyle:on line.size.limit } + def withDB[T]( + remoteDir: String, + version: Int = 0, + conf: RocksDBConf = RocksDBConf().copy(compactOnCommit = false, minVersionsToRetain = 100), + hadoopConf: Configuration = new Configuration())( + func: RocksDB => T): T = { + var db: RocksDB = null + try { + db = new RocksDB( + remoteDir, conf = conf, hadoopConf = hadoopConf, + loggingId = s"[Thread-${Thread.currentThread.getId}]") + db.load(version) + func(db) + } finally { + if (db != null) { + db.close() + } + } + } + def generateFiles(dir: String, fileToLengths: Seq[(String, Int)]): Unit = { fileToLengths.foreach { case (fileName, length) => val file = new File(dir, fileName) @@ -200,6 +289,14 @@ class RocksDBSuite extends SparkFunSuite { implicit def toFile(path: String): File = new File(path) + implicit def toArray(str: String): Array[Byte] = if (str != null) str.getBytes else null + + implicit def toStr(bytes: Array[Byte]): String = if (bytes != null) new String(bytes) else null + + def toStr(kv: ByteArrayPair): (String, String) = (toStr(kv.key), toStr(kv.value)) + + def iterator(db: RocksDB): Iterator[(String, String)] = db.iterator().map(toStr) + def listFiles(file: File): Seq[File] = { if (!file.exists()) return Seq.empty file.listFiles.filter(file => !file.getName.endsWith("crc") && !file.isDirectory)
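The withDB, toStr and implicit String/Array[Byte] helpers above compose naturally if further coverage is added later. As an illustration only (the test name and expectations are not part of this PR), a hedged sketch of a rollback test built solely from the APIs and helpers in this patch:

    test("RocksDB: rollback discards uncommitted changes") {
      val remoteDir = Utils.createTempDir().toString
      new File(remoteDir).delete()      // as above, let the file manager create the directory
      withDB(remoteDir) { db =>
        db.put("a", "1")
        db.commit()                     // becomes version 1
        db.put("b", "2")                // uncommitted
        db.rollback()                   // drops the pending write batch
        assert(db.get("b") === null)
        assert(db.iterator().map(toStr).toSet === Set(("a", "1")))
      }
    }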