[SPARK-35988][SS] The implementation for RocksDBStateStoreProvider
### What changes were proposed in this pull request?
Add the implementation for the RocksDBStateStoreProvider. It's the subclass of StateStoreProvider that leverages all the functionalities implemented in the RocksDB instance.
### Why are the changes needed?
The interface for the end-user to use the RocksDB state store.
### Does this PR introduce _any_ user-facing change?
Yes. New RocksDBStateStore can be used in their applications.
### How was this patch tested?
New UT added.
Closes #33187 from xuanyuanking/SPARK-35988.
Authored-by: Yuanjian Li <yuanjian.li@databricks.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
(cherry picked from commit 0621e78b5f
)
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
This commit is contained in:
parent
cafb829c42
commit
097b667db7
|
@ -0,0 +1,331 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.execution.streaming.state
|
||||
|
||||
import java.io._
|
||||
|
||||
import org.apache.hadoop.conf.Configuration
|
||||
|
||||
import org.apache.spark.{SparkConf, SparkEnv}
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
|
||||
import org.apache.spark.sql.types.StructType
|
||||
import org.apache.spark.unsafe.Platform
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
private[state] class RocksDBStateStoreProvider
|
||||
extends StateStoreProvider with Logging with Closeable {
|
||||
import RocksDBStateStoreProvider._
|
||||
|
||||
class RocksDBStateStore(lastVersion: Long) extends StateStore {
|
||||
/** Trait and classes representing the internal state of the store */
|
||||
trait STATE
|
||||
case object UPDATING extends STATE
|
||||
case object COMMITTED extends STATE
|
||||
case object ABORTED extends STATE
|
||||
|
||||
@volatile private var state: STATE = UPDATING
|
||||
@volatile private var isValidated = false
|
||||
|
||||
override def id: StateStoreId = RocksDBStateStoreProvider.this.stateStoreId
|
||||
|
||||
override def version: Long = lastVersion
|
||||
|
||||
override def get(key: UnsafeRow): UnsafeRow = {
|
||||
verify(key != null, "Key cannot be null")
|
||||
val value = encoder.decodeValue(rocksDB.get(encoder.encode(key)))
|
||||
if (!isValidated && value != null) {
|
||||
StateStoreProvider.validateStateRowFormat(
|
||||
key, keySchema, value, valueSchema, storeConf)
|
||||
isValidated = true
|
||||
}
|
||||
value
|
||||
}
|
||||
|
||||
override def put(key: UnsafeRow, value: UnsafeRow): Unit = {
|
||||
verify(state == UPDATING, "Cannot put after already committed or aborted")
|
||||
verify(key != null, "Key cannot be null")
|
||||
require(value != null, "Cannot put a null value")
|
||||
rocksDB.put(encoder.encode(key), encoder.encode(value))
|
||||
}
|
||||
|
||||
override def remove(key: UnsafeRow): Unit = {
|
||||
verify(state == UPDATING, "Cannot remove after already committed or aborted")
|
||||
verify(key != null, "Key cannot be null")
|
||||
rocksDB.remove(encoder.encode(key))
|
||||
}
|
||||
|
||||
override def getRange(
|
||||
start: Option[UnsafeRow],
|
||||
end: Option[UnsafeRow]): Iterator[UnsafeRowPair] = {
|
||||
verify(state == UPDATING, "Cannot call getRange() after already committed or aborted")
|
||||
iterator()
|
||||
}
|
||||
|
||||
override def iterator(): Iterator[UnsafeRowPair] = {
|
||||
rocksDB.iterator().map { kv =>
|
||||
val rowPair = encoder.decode(kv)
|
||||
if (!isValidated && rowPair.value != null) {
|
||||
StateStoreProvider.validateStateRowFormat(
|
||||
rowPair.key, keySchema, rowPair.value, valueSchema, storeConf)
|
||||
isValidated = true
|
||||
}
|
||||
rowPair
|
||||
}
|
||||
}
|
||||
|
||||
override def commit(): Long = synchronized {
|
||||
verify(state == UPDATING, "Cannot commit after already committed or aborted")
|
||||
val newVersion = rocksDB.commit()
|
||||
state = COMMITTED
|
||||
logInfo(s"Committed $newVersion for $id")
|
||||
newVersion
|
||||
}
|
||||
|
||||
override def abort(): Unit = {
|
||||
verify(state == UPDATING || state == ABORTED, "Cannot abort after already committed")
|
||||
logInfo(s"Aborting ${version + 1} for $id")
|
||||
rocksDB.rollback()
|
||||
state = ABORTED
|
||||
}
|
||||
|
||||
override def metrics: StateStoreMetrics = {
|
||||
val rocksDBMetrics = rocksDB.metrics
|
||||
def commitLatencyMs(typ: String): Long = rocksDBMetrics.lastCommitLatencyMs.getOrElse(typ, 0L)
|
||||
def avgNativeOpsLatencyMs(typ: String): Long = {
|
||||
rocksDBMetrics.nativeOpsLatencyMicros.get(typ).map(_.avg).getOrElse(0.0).toLong
|
||||
}
|
||||
|
||||
val stateStoreCustomMetrics = Map[StateStoreCustomMetric, Long](
|
||||
CUSTOM_METRIC_SST_FILE_SIZE -> rocksDBMetrics.totalSSTFilesBytes,
|
||||
CUSTOM_METRIC_GET_TIME -> avgNativeOpsLatencyMs("get"),
|
||||
CUSTOM_METRIC_PUT_TIME -> avgNativeOpsLatencyMs("put"),
|
||||
CUSTOM_METRIC_WRITEBATCH_TIME -> commitLatencyMs("writeBatch"),
|
||||
CUSTOM_METRIC_FLUSH_TIME -> commitLatencyMs("flush"),
|
||||
CUSTOM_METRIC_PAUSE_TIME -> commitLatencyMs("pause"),
|
||||
CUSTOM_METRIC_CHECKPOINT_TIME -> commitLatencyMs("checkpoint"),
|
||||
CUSTOM_METRIC_FILESYNC_TIME -> commitLatencyMs("fileSync"),
|
||||
CUSTOM_METRIC_BYTES_COPIED -> rocksDBMetrics.bytesCopied,
|
||||
CUSTOM_METRIC_FILES_COPIED -> rocksDBMetrics.filesCopied,
|
||||
CUSTOM_METRIC_FILES_REUSED -> rocksDBMetrics.filesReused
|
||||
) ++ rocksDBMetrics.zipFileBytesUncompressed.map(bytes =>
|
||||
Map(CUSTOM_METRIC_ZIP_FILE_BYTES_UNCOMPRESSED -> bytes)).getOrElse(Map())
|
||||
|
||||
StateStoreMetrics(
|
||||
rocksDBMetrics.numUncommittedKeys,
|
||||
rocksDBMetrics.memUsageBytes,
|
||||
stateStoreCustomMetrics)
|
||||
}
|
||||
|
||||
override def hasCommitted: Boolean = state == COMMITTED
|
||||
|
||||
override def toString: String = {
|
||||
s"RocksDBStateStore[id=(op=${id.operatorId},part=${id.partitionId})," +
|
||||
s"dir=${id.storeCheckpointLocation()}]"
|
||||
}
|
||||
|
||||
/** Return the [[RocksDB]] instance in this store. This is exposed mainly for testing. */
|
||||
def dbInstance(): RocksDB = rocksDB
|
||||
}
|
||||
|
||||
override def init(
|
||||
stateStoreId: StateStoreId,
|
||||
keySchema: StructType,
|
||||
valueSchema: StructType,
|
||||
indexOrdinal: Option[Int],
|
||||
storeConf: StateStoreConf,
|
||||
hadoopConf: Configuration): Unit = {
|
||||
this.stateStoreId_ = stateStoreId
|
||||
this.keySchema = keySchema
|
||||
this.valueSchema = valueSchema
|
||||
this.storeConf = storeConf
|
||||
this.hadoopConf = hadoopConf
|
||||
rocksDB // lazy initialization
|
||||
}
|
||||
|
||||
override def stateStoreId: StateStoreId = stateStoreId_
|
||||
|
||||
override def getStore(version: Long): StateStore = {
|
||||
require(version >= 0, "Version cannot be less than 0")
|
||||
rocksDB.load(version)
|
||||
new RocksDBStateStore(version)
|
||||
}
|
||||
|
||||
override def doMaintenance(): Unit = {
|
||||
rocksDB.cleanup()
|
||||
}
|
||||
|
||||
override def close(): Unit = {
|
||||
rocksDB.close()
|
||||
}
|
||||
|
||||
override def supportedCustomMetrics: Seq[StateStoreCustomMetric] = ALL_CUSTOM_METRICS
|
||||
|
||||
private[state] def latestVersion: Long = rocksDB.getLatestVersion()
|
||||
|
||||
/** Internal fields and methods */
|
||||
|
||||
@volatile private var stateStoreId_ : StateStoreId = _
|
||||
@volatile private var keySchema: StructType = _
|
||||
@volatile private var valueSchema: StructType = _
|
||||
@volatile private var storeConf: StateStoreConf = _
|
||||
@volatile private var hadoopConf: Configuration = _
|
||||
|
||||
private[sql] lazy val rocksDB = {
|
||||
val dfsRootDir = stateStoreId.storeCheckpointLocation().toString
|
||||
val storeIdStr = s"StateStoreId(opId=${stateStoreId.operatorId}," +
|
||||
s"partId=${stateStoreId.partitionId},name=${stateStoreId.storeName})"
|
||||
val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf)
|
||||
val localRootDir = Utils.createTempDir(Utils.getLocalDir(sparkConf), storeIdStr)
|
||||
new RocksDB(dfsRootDir, RocksDBConf(storeConf), localRootDir, hadoopConf, storeIdStr)
|
||||
}
|
||||
|
||||
private lazy val encoder = new StateEncoder
|
||||
|
||||
private def verify(condition: => Boolean, msg: String): Unit = {
|
||||
if (!condition) { throw new IllegalStateException(msg) }
|
||||
}
|
||||
|
||||
/**
|
||||
* Encodes/decodes UnsafeRows to versioned byte arrays.
|
||||
* It uses the first byte of the generated byte array to store the version that describes how the
|
||||
* row is encoded in the rest of the byte array. Currently, the default version is 0,
|
||||
*
|
||||
* VERSION 0: [ VERSION (1 byte) | ROW (N bytes) ]
|
||||
* The bytes of a UnsafeRow is written unmodified to starting from offset 1
|
||||
* (offset 0 is the version byte of value 0). That is, if the unsafe row has N bytes,
|
||||
* then the generated array byte will be N+1 bytes.
|
||||
*/
|
||||
class StateEncoder {
|
||||
import RocksDBStateStoreProvider._
|
||||
|
||||
// Reusable objects
|
||||
private val keyRow = new UnsafeRow(keySchema.size)
|
||||
private val valueRow = new UnsafeRow(valueSchema.size)
|
||||
private val rowTuple = new UnsafeRowPair()
|
||||
|
||||
/**
|
||||
* Encode the UnsafeRow of N bytes as a N+1 byte array.
|
||||
* @note This creates a new byte array and memcopies the UnsafeRow to the new array.
|
||||
*/
|
||||
def encode(row: UnsafeRow): Array[Byte] = {
|
||||
val bytesToEncode = row.getBytes
|
||||
val encodedBytes = new Array[Byte](bytesToEncode.length + STATE_ENCODING_NUM_VERSION_BYTES)
|
||||
Platform.putByte(encodedBytes, Platform.BYTE_ARRAY_OFFSET, STATE_ENCODING_VERSION)
|
||||
// Platform.BYTE_ARRAY_OFFSET is the recommended way to memcopy b/w byte arrays. See Platform.
|
||||
Platform.copyMemory(
|
||||
bytesToEncode, Platform.BYTE_ARRAY_OFFSET,
|
||||
encodedBytes, Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES,
|
||||
bytesToEncode.length)
|
||||
encodedBytes
|
||||
}
|
||||
|
||||
/**
|
||||
* Decode byte array for a key to a UnsafeRow.
|
||||
* @note The UnsafeRow returned is reused across calls, and the UnsafeRow just points to
|
||||
* the given byte array.
|
||||
*/
|
||||
def decodeKey(keyBytes: Array[Byte]): UnsafeRow = {
|
||||
if (keyBytes != null) {
|
||||
// Platform.BYTE_ARRAY_OFFSET is the recommended way refer to the 1st offset. See Platform.
|
||||
keyRow.pointTo(
|
||||
keyBytes,
|
||||
Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES,
|
||||
keyBytes.length - STATE_ENCODING_NUM_VERSION_BYTES)
|
||||
keyRow
|
||||
} else {
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Decode byte array for a value to a UnsafeRow.
|
||||
*
|
||||
* @note The UnsafeRow returned is reused across calls, and the UnsafeRow just points to
|
||||
* the given byte array.
|
||||
*/
|
||||
def decodeValue(valueBytes: Array[Byte]): UnsafeRow = {
|
||||
if (valueBytes != null) {
|
||||
// Platform.BYTE_ARRAY_OFFSET is the recommended way refer to the 1st offset. See Platform.
|
||||
valueRow.pointTo(
|
||||
valueBytes,
|
||||
Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES,
|
||||
valueBytes.size - STATE_ENCODING_NUM_VERSION_BYTES)
|
||||
valueRow
|
||||
} else {
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Decode pair of key-value byte arrays in a pair of key-value UnsafeRows.
|
||||
*
|
||||
* @note The UnsafeRow returned is reused across calls, and the UnsafeRow just points to
|
||||
* the given byte array.
|
||||
*/
|
||||
def decode(byteArrayTuple: ByteArrayPair): UnsafeRowPair = {
|
||||
rowTuple.withRows(decodeKey(byteArrayTuple.key), decodeValue(byteArrayTuple.value))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
object RocksDBStateStoreProvider {
|
||||
// Version as a single byte that specifies the encoding of the row data in RocksDB
|
||||
val STATE_ENCODING_NUM_VERSION_BYTES = 1
|
||||
val STATE_ENCODING_VERSION: Byte = 0
|
||||
|
||||
// Native operation latencies report as latency per 1000 calls
|
||||
// as SQLMetrics support ms latency whereas RocksDB reports it in microseconds.
|
||||
val CUSTOM_METRIC_GET_TIME = StateStoreCustomTimingMetric(
|
||||
"rocksdbGetLatency", "RocksDB: avg get latency (per 1000 calls)")
|
||||
val CUSTOM_METRIC_PUT_TIME = StateStoreCustomTimingMetric(
|
||||
"rocksdbPutLatency", "RocksDB: avg put latency (per 1000 calls)")
|
||||
|
||||
// Commit latency detailed breakdown
|
||||
val CUSTOM_METRIC_WRITEBATCH_TIME = StateStoreCustomTimingMetric(
|
||||
"rocksdbCommitWriteBatchLatency", "RocksDB: commit - write batch time")
|
||||
val CUSTOM_METRIC_FLUSH_TIME = StateStoreCustomTimingMetric(
|
||||
"rocksdbCommitFlushLatency", "RocksDB: commit - flush time")
|
||||
val CUSTOM_METRIC_PAUSE_TIME = StateStoreCustomTimingMetric(
|
||||
"rocksdbCommitPauseLatency", "RocksDB: commit - pause bg time")
|
||||
val CUSTOM_METRIC_CHECKPOINT_TIME = StateStoreCustomTimingMetric(
|
||||
"rocksdbCommitCheckpointLatency", "RocksDB: commit - checkpoint time")
|
||||
val CUSTOM_METRIC_FILESYNC_TIME = StateStoreCustomTimingMetric(
|
||||
"rocksdbFileSyncTime", "RocksDB: commit - file sync time")
|
||||
val CUSTOM_METRIC_FILES_COPIED = StateStoreCustomSizeMetric(
|
||||
"rocksdbFilesCopied", "RocksDB: file manager - files copied")
|
||||
val CUSTOM_METRIC_BYTES_COPIED = StateStoreCustomSizeMetric(
|
||||
"rocksdbBytesCopied", "RocksDB: file manager - bytes copied")
|
||||
val CUSTOM_METRIC_FILES_REUSED = StateStoreCustomSizeMetric(
|
||||
"rocksdbFilesReused", "RocksDB: file manager - files reused")
|
||||
val CUSTOM_METRIC_ZIP_FILE_BYTES_UNCOMPRESSED = StateStoreCustomSizeMetric(
|
||||
"rocksdbZipFileBytesUncompressed", "RocksDB: file manager - uncompressed zip file bytes")
|
||||
|
||||
// Total SST file size
|
||||
val CUSTOM_METRIC_SST_FILE_SIZE = StateStoreCustomSizeMetric(
|
||||
"rocksdbSstFileSize", "RocksDB: size of all SST files")
|
||||
|
||||
val ALL_CUSTOM_METRICS = Seq(
|
||||
CUSTOM_METRIC_SST_FILE_SIZE, CUSTOM_METRIC_GET_TIME, CUSTOM_METRIC_PUT_TIME,
|
||||
CUSTOM_METRIC_WRITEBATCH_TIME, CUSTOM_METRIC_FLUSH_TIME, CUSTOM_METRIC_PAUSE_TIME,
|
||||
CUSTOM_METRIC_CHECKPOINT_TIME, CUSTOM_METRIC_FILESYNC_TIME,
|
||||
CUSTOM_METRIC_BYTES_COPIED, CUSTOM_METRIC_FILES_COPIED, CUSTOM_METRIC_FILES_REUSED,
|
||||
CUSTOM_METRIC_ZIP_FILE_BYTES_UNCOMPRESSED
|
||||
)
|
||||
}
|
||||
|
|
@ -0,0 +1,51 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.execution.streaming.state
|
||||
|
||||
import java.io.File
|
||||
|
||||
import org.apache.spark.sql.execution.streaming.MemoryStream
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.streaming._
|
||||
|
||||
|
||||
class RocksDBStateStoreIntegrationSuite extends StreamTest {
|
||||
import testImplicits._
|
||||
|
||||
test("RocksDBStateStore") {
|
||||
withTempDir { dir =>
|
||||
val input = MemoryStream[Int]
|
||||
val conf = Map(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
|
||||
classOf[RocksDBStateStoreProvider].getName)
|
||||
|
||||
testStream(input.toDF.groupBy().count(), outputMode = OutputMode.Update)(
|
||||
StartStream(checkpointLocation = dir.getAbsolutePath, additionalConfs = conf),
|
||||
AddData(input, 1, 2, 3),
|
||||
CheckAnswer(3),
|
||||
AssertOnQuery { q =>
|
||||
// Verify that RocksDBStateStore by verify the state checkpoints are [version].zip
|
||||
val storeCheckpointDir = StateStoreId(
|
||||
dir.getAbsolutePath + "/state", 0, 0).storeCheckpointLocation()
|
||||
val storeCheckpointFile = storeCheckpointDir + "/1.zip"
|
||||
new File(storeCheckpointFile).exists()
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,152 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.execution.streaming.state
|
||||
|
||||
import java.util.UUID
|
||||
|
||||
import scala.util.Random
|
||||
|
||||
import org.apache.hadoop.conf.Configuration
|
||||
import org.scalatest.BeforeAndAfter
|
||||
|
||||
import org.apache.spark.SparkConf
|
||||
import org.apache.spark.sql.LocalSparkSession.withSparkSession
|
||||
import org.apache.spark.sql.SparkSession
|
||||
import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.Platform
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvider]
|
||||
with BeforeAndAfter {
|
||||
|
||||
import StateStoreTestsHelper._
|
||||
|
||||
test("version encoding") {
|
||||
import RocksDBStateStoreProvider._
|
||||
|
||||
val provider = newStoreProvider()
|
||||
val store = provider.getStore(0)
|
||||
val keyRow = stringToRow("a")
|
||||
val valueRow = intToRow(1)
|
||||
store.put(keyRow, valueRow)
|
||||
val iter = provider.rocksDB.iterator()
|
||||
assert(iter.hasNext)
|
||||
val kv = iter.next()
|
||||
|
||||
// Verify the version encoded in first byte of the key and value byte arrays
|
||||
assert(Platform.getByte(kv.key, Platform.BYTE_ARRAY_OFFSET) === STATE_ENCODING_VERSION)
|
||||
assert(Platform.getByte(kv.value, Platform.BYTE_ARRAY_OFFSET) === STATE_ENCODING_VERSION)
|
||||
}
|
||||
|
||||
test("RocksDB confs are passed correctly from SparkSession to db instance") {
|
||||
val sparkConf = new SparkConf().setMaster("local").setAppName(this.getClass.getSimpleName)
|
||||
withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark =>
|
||||
// Set the session confs that should be passed into RocksDB
|
||||
val testConfs = Seq(
|
||||
("spark.sql.streaming.stateStore.providerClass",
|
||||
classOf[RocksDBStateStoreProvider].getName),
|
||||
(RocksDBConf.ROCKSDB_CONF_NAME_PREFIX + ".compactOnCommit", "true"),
|
||||
(RocksDBConf.ROCKSDB_CONF_NAME_PREFIX + ".lockAcquireTimeoutMs", "10")
|
||||
)
|
||||
testConfs.foreach { case (k, v) => spark.conf.set(k, v) }
|
||||
|
||||
// Prepare test objects for running task on state store
|
||||
val testRDD = spark.sparkContext.makeRDD[String](Seq("a"), 1)
|
||||
val testSchema = StructType(Seq(StructField("key", StringType, true)))
|
||||
val testStateInfo = StatefulOperatorStateInfo(
|
||||
checkpointLocation = Utils.createTempDir().getAbsolutePath,
|
||||
queryRunId = UUID.randomUUID, operatorId = 0, storeVersion = 0, numPartitions = 5)
|
||||
|
||||
// Create state store in a task and get the RocksDBConf from the instantiated RocksDB instance
|
||||
val rocksDBConfInTask: RocksDBConf = testRDD.mapPartitionsWithStateStore[RocksDBConf](
|
||||
spark.sqlContext, testStateInfo, testSchema, testSchema, None) {
|
||||
(store: StateStore, _: Iterator[String]) =>
|
||||
// Use reflection to get RocksDB instance
|
||||
val dbInstanceMethod =
|
||||
store.getClass.getMethods.filter(_.getName.contains("dbInstance")).head
|
||||
Iterator(dbInstanceMethod.invoke(store).asInstanceOf[RocksDB].conf)
|
||||
}.collect().head
|
||||
|
||||
// Verify the confs are same as those configured in the session conf
|
||||
assert(rocksDBConfInTask.compactOnCommit == true)
|
||||
assert(rocksDBConfInTask.lockAcquireTimeoutMs == 10L)
|
||||
}
|
||||
}
|
||||
|
||||
test("rocksdb file manager metrics exposed") {
|
||||
import RocksDBStateStoreProvider._
|
||||
def getCustomMetric(metrics: StateStoreMetrics, customMetric: StateStoreCustomMetric): Long = {
|
||||
val metricPair = metrics.customMetrics.find(_._1.name == customMetric.name)
|
||||
assert(metricPair.isDefined)
|
||||
metricPair.get._2
|
||||
}
|
||||
|
||||
val provider = newStoreProvider()
|
||||
val store = provider.getStore(0)
|
||||
// Verify state after updating
|
||||
put(store, "a", 1)
|
||||
assert(get(store, "a") === Some(1))
|
||||
assert(store.commit() === 1)
|
||||
assert(store.hasCommitted)
|
||||
val storeMetrics = store.metrics
|
||||
assert(storeMetrics.numKeys === 1)
|
||||
assert(getCustomMetric(storeMetrics, CUSTOM_METRIC_FILES_COPIED) > 0L)
|
||||
assert(getCustomMetric(storeMetrics, CUSTOM_METRIC_FILES_REUSED) == 0L)
|
||||
assert(getCustomMetric(storeMetrics, CUSTOM_METRIC_BYTES_COPIED) > 0L)
|
||||
assert(getCustomMetric(storeMetrics, CUSTOM_METRIC_ZIP_FILE_BYTES_UNCOMPRESSED) > 0L)
|
||||
}
|
||||
|
||||
override def newStoreProvider(): RocksDBStateStoreProvider = {
|
||||
newStoreProvider(StateStoreId(newDir(), Random.nextInt(), 0))
|
||||
}
|
||||
|
||||
def newStoreProvider(storeId: StateStoreId): RocksDBStateStoreProvider = {
|
||||
val keySchema = StructType(Seq(StructField("key", StringType, true)))
|
||||
val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
|
||||
val provider = new RocksDBStateStoreProvider()
|
||||
provider.init(
|
||||
storeId, keySchema, valueSchema, indexOrdinal = None, new StateStoreConf, new Configuration)
|
||||
provider
|
||||
}
|
||||
|
||||
override def getLatestData(storeProvider: RocksDBStateStoreProvider): Set[(String, Int)] = {
|
||||
getData(storeProvider, version = -1)
|
||||
}
|
||||
|
||||
override def getData(
|
||||
provider: RocksDBStateStoreProvider,
|
||||
version: Int = -1): Set[(String, Int)] = {
|
||||
val reloadedProvider = newStoreProvider(provider.stateStoreId)
|
||||
val versionToRead = if (version < 0) reloadedProvider.latestVersion else version
|
||||
reloadedProvider.getStore(versionToRead).iterator().map(rowsToStringInt).toSet
|
||||
}
|
||||
|
||||
override protected val keySchema = StructType(Seq(StructField("key", StringType, true)))
|
||||
override protected val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
|
||||
|
||||
override def newStoreProvider(
|
||||
minDeltasForSnapshot: Int,
|
||||
numOfVersToRetainInMemory: Int): RocksDBStateStoreProvider = newStoreProvider()
|
||||
|
||||
override def getDefaultSQLConf(
|
||||
minDeltasForSnapshot: Int,
|
||||
numOfVersToRetainInMemory: Int): SQLConf = new SQLConf()
|
||||
}
|
||||
|
|
@ -49,6 +49,7 @@ import org.apache.spark.util.Utils
|
|||
class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
|
||||
with BeforeAndAfter {
|
||||
import StateStoreTestsHelper._
|
||||
import StateStoreCoordinatorSuite._
|
||||
|
||||
override val keySchema = StructType(Seq(StructField("key", StringType, true)))
|
||||
override val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
|
||||
|
@ -235,6 +236,162 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
|
|||
assert(getSizeOfStateForCurrentVersion(store.metrics) > noDataMemoryUsed)
|
||||
}
|
||||
|
||||
test("maintenance") {
|
||||
val conf = new SparkConf()
|
||||
.setMaster("local")
|
||||
.setAppName("test")
|
||||
// Make sure that when SparkContext stops, the StateStore maintenance thread 'quickly'
|
||||
// fails to talk to the StateStoreCoordinator and unloads all the StateStores
|
||||
.set(RPC_NUM_RETRIES, 1)
|
||||
val opId = 0
|
||||
val dir1 = newDir()
|
||||
val storeProviderId1 = StateStoreProviderId(StateStoreId(dir1, opId, 0), UUID.randomUUID)
|
||||
val dir2 = newDir()
|
||||
val storeProviderId2 = StateStoreProviderId(StateStoreId(dir2, opId, 1), UUID.randomUUID)
|
||||
val sqlConf = getDefaultSQLConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get,
|
||||
SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY.defaultValue.get)
|
||||
sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2)
|
||||
// Make maintenance thread do snapshots and cleanups very fast
|
||||
sqlConf.setConf(SQLConf.STREAMING_MAINTENANCE_INTERVAL, 10L)
|
||||
val storeConf = StateStoreConf(sqlConf)
|
||||
val hadoopConf = new Configuration()
|
||||
val provider = newStoreProvider(storeProviderId1.storeId)
|
||||
|
||||
var latestStoreVersion = 0
|
||||
|
||||
def generateStoreVersions(): Unit = {
|
||||
for (i <- 1 to 20) {
|
||||
val store = StateStore.get(storeProviderId1, keySchema, valueSchema, None,
|
||||
latestStoreVersion, storeConf, hadoopConf)
|
||||
put(store, "a", i)
|
||||
store.commit()
|
||||
latestStoreVersion += 1
|
||||
}
|
||||
}
|
||||
|
||||
val timeoutDuration = 1.minute
|
||||
|
||||
quietly {
|
||||
withSpark(new SparkContext(conf)) { sc =>
|
||||
withCoordinatorRef(sc) { coordinatorRef =>
|
||||
require(!StateStore.isMaintenanceRunning, "StateStore is unexpectedly running")
|
||||
|
||||
// Generate sufficient versions of store for snapshots
|
||||
generateStoreVersions()
|
||||
|
||||
eventually(timeout(timeoutDuration)) {
|
||||
// Store should have been reported to the coordinator
|
||||
assert(coordinatorRef.getLocation(storeProviderId1).nonEmpty,
|
||||
"active instance was not reported")
|
||||
|
||||
// Background maintenance should clean up and generate snapshots
|
||||
assert(StateStore.isMaintenanceRunning, "Maintenance task is not running")
|
||||
|
||||
// Some snapshots should have been generated
|
||||
val snapshotVersions = (1 to latestStoreVersion).filter { version =>
|
||||
fileExists(provider, version, isSnapshot = true)
|
||||
}
|
||||
assert(snapshotVersions.nonEmpty, "no snapshot file found")
|
||||
}
|
||||
|
||||
// Generate more versions such that there is another snapshot and
|
||||
// the earliest delta file will be cleaned up
|
||||
generateStoreVersions()
|
||||
|
||||
// Earliest delta file should get cleaned up
|
||||
eventually(timeout(timeoutDuration)) {
|
||||
assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted")
|
||||
}
|
||||
|
||||
// If driver decides to deactivate all stores related to a query run,
|
||||
// then this instance should be unloaded
|
||||
coordinatorRef.deactivateInstances(storeProviderId1.queryRunId)
|
||||
eventually(timeout(timeoutDuration)) {
|
||||
assert(!StateStore.isLoaded(storeProviderId1))
|
||||
}
|
||||
|
||||
// Reload the store and verify
|
||||
StateStore.get(storeProviderId1, keySchema, valueSchema, indexOrdinal = None,
|
||||
latestStoreVersion, storeConf, hadoopConf)
|
||||
assert(StateStore.isLoaded(storeProviderId1))
|
||||
|
||||
// If some other executor loads the store, then this instance should be unloaded
|
||||
coordinatorRef
|
||||
.reportActiveInstance(storeProviderId1, "other-host", "other-exec", Seq.empty)
|
||||
eventually(timeout(timeoutDuration)) {
|
||||
assert(!StateStore.isLoaded(storeProviderId1))
|
||||
}
|
||||
|
||||
// Reload the store and verify
|
||||
StateStore.get(storeProviderId1, keySchema, valueSchema, indexOrdinal = None,
|
||||
latestStoreVersion, storeConf, hadoopConf)
|
||||
assert(StateStore.isLoaded(storeProviderId1))
|
||||
|
||||
// If some other executor loads the store, and when this executor loads other store,
|
||||
// then this executor should unload inactive instances immediately.
|
||||
coordinatorRef
|
||||
.reportActiveInstance(storeProviderId1, "other-host", "other-exec", Seq.empty)
|
||||
StateStore.get(storeProviderId2, keySchema, valueSchema, indexOrdinal = None,
|
||||
0, storeConf, hadoopConf)
|
||||
assert(!StateStore.isLoaded(storeProviderId1))
|
||||
assert(StateStore.isLoaded(storeProviderId2))
|
||||
}
|
||||
}
|
||||
|
||||
// Verify if instance is unloaded if SparkContext is stopped
|
||||
eventually(timeout(timeoutDuration)) {
|
||||
require(SparkEnv.get === null)
|
||||
assert(!StateStore.isLoaded(storeProviderId1))
|
||||
assert(!StateStore.isLoaded(storeProviderId2))
|
||||
assert(!StateStore.isMaintenanceRunning)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("snapshotting") {
|
||||
val provider = newStoreProvider(minDeltasForSnapshot = 5, numOfVersToRetainInMemory = 2)
|
||||
|
||||
var currentVersion = 0
|
||||
|
||||
currentVersion = updateVersionTo(provider, currentVersion, 2)
|
||||
require(getLatestData(provider) === Set("a" -> 2))
|
||||
provider.doMaintenance() // should not generate snapshot files
|
||||
assert(getLatestData(provider) === Set("a" -> 2))
|
||||
|
||||
for (i <- 1 to currentVersion) {
|
||||
assert(fileExists(provider, i, isSnapshot = false)) // all delta files present
|
||||
assert(!fileExists(provider, i, isSnapshot = true)) // no snapshot files present
|
||||
}
|
||||
|
||||
// After version 6, snapshotting should generate one snapshot file
|
||||
currentVersion = updateVersionTo(provider, currentVersion, 6)
|
||||
require(getLatestData(provider) === Set("a" -> 6), "store not updated correctly")
|
||||
provider.doMaintenance() // should generate snapshot files
|
||||
|
||||
val snapshotVersion = (0 to 6).find(version => fileExists(provider, version, isSnapshot = true))
|
||||
assert(snapshotVersion.nonEmpty, "snapshot file not generated")
|
||||
deleteFilesEarlierThanVersion(provider, snapshotVersion.get)
|
||||
assert(
|
||||
getData(provider, snapshotVersion.get) === Set("a" -> snapshotVersion.get),
|
||||
"snapshotting messed up the data of the snapshotted version")
|
||||
assert(
|
||||
getLatestData(provider) === Set("a" -> 6),
|
||||
"snapshotting messed up the data of the final version")
|
||||
|
||||
// After version 20, snapshotting should generate newer snapshot files
|
||||
currentVersion = updateVersionTo(provider, currentVersion, 20)
|
||||
require(getLatestData(provider) === Set("a" -> 20), "store not updated correctly")
|
||||
provider.doMaintenance() // do snapshot
|
||||
|
||||
val latestSnapshotVersion = (0 to 20).filter(version =>
|
||||
fileExists(provider, version, isSnapshot = true)).lastOption
|
||||
assert(latestSnapshotVersion.nonEmpty, "no snapshot file found")
|
||||
assert(latestSnapshotVersion.get > snapshotVersion.get, "newer snapshot not generated")
|
||||
|
||||
deleteFilesEarlierThanVersion(provider, latestSnapshotVersion.get)
|
||||
assert(getLatestData(provider) === Set("a" -> 20), "snapshotting messed up the data")
|
||||
}
|
||||
|
||||
testQuietly("SPARK-18342: commit fails when rename fails") {
|
||||
import RenameReturnsFalseFileSystem._
|
||||
val dir = scheme + "://" + newDir()
|
||||
|
@ -582,7 +739,6 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
|
|||
abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
|
||||
extends StateStoreCodecsTest with PrivateMethodTester {
|
||||
import StateStoreTestsHelper._
|
||||
import StateStoreCoordinatorSuite._
|
||||
|
||||
type MapType = mutable.HashMap[UnsafeRow, UnsafeRow]
|
||||
type ProviderMapType = java.util.concurrent.ConcurrentHashMap[UnsafeRow, UnsafeRow]
|
||||
|
@ -761,118 +917,6 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
|
|||
assert(rowsToSet(finalStore.iterator()) === Set(key -> 2))
|
||||
}
|
||||
|
||||
test("maintenance") {
|
||||
val conf = new SparkConf()
|
||||
.setMaster("local")
|
||||
.setAppName("test")
|
||||
// Make sure that when SparkContext stops, the StateStore maintenance thread 'quickly'
|
||||
// fails to talk to the StateStoreCoordinator and unloads all the StateStores
|
||||
.set(RPC_NUM_RETRIES, 1)
|
||||
val opId = 0
|
||||
val dir1 = newDir()
|
||||
val storeProviderId1 = StateStoreProviderId(StateStoreId(dir1, opId, 0), UUID.randomUUID)
|
||||
val dir2 = newDir()
|
||||
val storeProviderId2 = StateStoreProviderId(StateStoreId(dir2, opId, 1), UUID.randomUUID)
|
||||
val sqlConf = getDefaultSQLConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get,
|
||||
SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY.defaultValue.get)
|
||||
sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2)
|
||||
// Make maintenance thread do snapshots and cleanups very fast
|
||||
sqlConf.setConf(SQLConf.STREAMING_MAINTENANCE_INTERVAL, 10L)
|
||||
val storeConf = StateStoreConf(sqlConf)
|
||||
val hadoopConf = new Configuration()
|
||||
val provider = newStoreProvider(storeProviderId1.storeId)
|
||||
|
||||
var latestStoreVersion = 0
|
||||
|
||||
def generateStoreVersions(): Unit = {
|
||||
for (i <- 1 to 20) {
|
||||
val store = StateStore.get(storeProviderId1, keySchema, valueSchema, None,
|
||||
latestStoreVersion, storeConf, hadoopConf)
|
||||
put(store, "a", i)
|
||||
store.commit()
|
||||
latestStoreVersion += 1
|
||||
}
|
||||
}
|
||||
|
||||
val timeoutDuration = 1.minute
|
||||
|
||||
quietly {
|
||||
withSpark(new SparkContext(conf)) { sc =>
|
||||
withCoordinatorRef(sc) { coordinatorRef =>
|
||||
require(!StateStore.isMaintenanceRunning, "StateStore is unexpectedly running")
|
||||
|
||||
// Generate sufficient versions of store for snapshots
|
||||
generateStoreVersions()
|
||||
|
||||
eventually(timeout(timeoutDuration)) {
|
||||
// Store should have been reported to the coordinator
|
||||
assert(coordinatorRef.getLocation(storeProviderId1).nonEmpty,
|
||||
"active instance was not reported")
|
||||
|
||||
// Background maintenance should clean up and generate snapshots
|
||||
assert(StateStore.isMaintenanceRunning, "Maintenance task is not running")
|
||||
|
||||
// Some snapshots should have been generated
|
||||
val snapshotVersions = (1 to latestStoreVersion).filter { version =>
|
||||
fileExists(provider, version, isSnapshot = true)
|
||||
}
|
||||
assert(snapshotVersions.nonEmpty, "no snapshot file found")
|
||||
}
|
||||
|
||||
// Generate more versions such that there is another snapshot and
|
||||
// the earliest delta file will be cleaned up
|
||||
generateStoreVersions()
|
||||
|
||||
// Earliest delta file should get cleaned up
|
||||
eventually(timeout(timeoutDuration)) {
|
||||
assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted")
|
||||
}
|
||||
|
||||
// If driver decides to deactivate all stores related to a query run,
|
||||
// then this instance should be unloaded
|
||||
coordinatorRef.deactivateInstances(storeProviderId1.queryRunId)
|
||||
eventually(timeout(timeoutDuration)) {
|
||||
assert(!StateStore.isLoaded(storeProviderId1))
|
||||
}
|
||||
|
||||
// Reload the store and verify
|
||||
StateStore.get(storeProviderId1, keySchema, valueSchema, indexOrdinal = None,
|
||||
latestStoreVersion, storeConf, hadoopConf)
|
||||
assert(StateStore.isLoaded(storeProviderId1))
|
||||
|
||||
// If some other executor loads the store, then this instance should be unloaded
|
||||
coordinatorRef
|
||||
.reportActiveInstance(storeProviderId1, "other-host", "other-exec", Seq.empty)
|
||||
eventually(timeout(timeoutDuration)) {
|
||||
assert(!StateStore.isLoaded(storeProviderId1))
|
||||
}
|
||||
|
||||
// Reload the store and verify
|
||||
StateStore.get(storeProviderId1, keySchema, valueSchema, indexOrdinal = None,
|
||||
latestStoreVersion, storeConf, hadoopConf)
|
||||
assert(StateStore.isLoaded(storeProviderId1))
|
||||
|
||||
// If some other executor loads the store, and when this executor loads other store,
|
||||
// then this executor should unload inactive instances immediately.
|
||||
coordinatorRef
|
||||
.reportActiveInstance(storeProviderId1, "other-host", "other-exec", Seq.empty)
|
||||
StateStore.get(storeProviderId2, keySchema, valueSchema, indexOrdinal = None,
|
||||
0, storeConf, hadoopConf)
|
||||
assert(!StateStore.isLoaded(storeProviderId1))
|
||||
assert(StateStore.isLoaded(storeProviderId2))
|
||||
}
|
||||
}
|
||||
|
||||
// Verify if instance is unloaded if SparkContext is stopped
|
||||
eventually(timeout(timeoutDuration)) {
|
||||
require(SparkEnv.get === null)
|
||||
assert(!StateStore.isLoaded(storeProviderId1))
|
||||
assert(!StateStore.isLoaded(storeProviderId2))
|
||||
assert(!StateStore.isMaintenanceRunning)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("StateStore.get") {
|
||||
quietly {
|
||||
val dir = newDir()
|
||||
|
@ -925,50 +969,6 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
|
|||
}
|
||||
}
|
||||
|
||||
test("snapshotting") {
|
||||
val provider = newStoreProvider(minDeltasForSnapshot = 5, numOfVersToRetainInMemory = 2)
|
||||
|
||||
var currentVersion = 0
|
||||
|
||||
currentVersion = updateVersionTo(provider, currentVersion, 2)
|
||||
require(getLatestData(provider) === Set("a" -> 2))
|
||||
provider.doMaintenance() // should not generate snapshot files
|
||||
assert(getLatestData(provider) === Set("a" -> 2))
|
||||
|
||||
for (i <- 1 to currentVersion) {
|
||||
assert(fileExists(provider, i, isSnapshot = false)) // all delta files present
|
||||
assert(!fileExists(provider, i, isSnapshot = true)) // no snapshot files present
|
||||
}
|
||||
|
||||
// After version 6, snapshotting should generate one snapshot file
|
||||
currentVersion = updateVersionTo(provider, currentVersion, 6)
|
||||
require(getLatestData(provider) === Set("a" -> 6), "store not updated correctly")
|
||||
provider.doMaintenance() // should generate snapshot files
|
||||
|
||||
val snapshotVersion = (0 to 6).find(version => fileExists(provider, version, isSnapshot = true))
|
||||
assert(snapshotVersion.nonEmpty, "snapshot file not generated")
|
||||
deleteFilesEarlierThanVersion(provider, snapshotVersion.get)
|
||||
assert(
|
||||
getData(provider, snapshotVersion.get) === Set("a" -> snapshotVersion.get),
|
||||
"snapshotting messed up the data of the snapshotted version")
|
||||
assert(
|
||||
getLatestData(provider) === Set("a" -> 6),
|
||||
"snapshotting messed up the data of the final version")
|
||||
|
||||
// After version 20, snapshotting should generate newer snapshot files
|
||||
currentVersion = updateVersionTo(provider, currentVersion, 20)
|
||||
require(getLatestData(provider) === Set("a" -> 20), "store not updated correctly")
|
||||
provider.doMaintenance() // do snapshot
|
||||
|
||||
val latestSnapshotVersion = (0 to 20).filter(version =>
|
||||
fileExists(provider, version, isSnapshot = true)).lastOption
|
||||
assert(latestSnapshotVersion.nonEmpty, "no snapshot file found")
|
||||
assert(latestSnapshotVersion.get > snapshotVersion.get, "newer snapshot not generated")
|
||||
|
||||
deleteFilesEarlierThanVersion(provider, latestSnapshotVersion.get)
|
||||
assert(getLatestData(provider) === Set("a" -> 20), "snapshotting messed up the data")
|
||||
}
|
||||
|
||||
test("reports memory usage") {
|
||||
val provider = newStoreProvider()
|
||||
val store = provider.getStore(0)
|
||||
|
|
Loading…
Reference in a new issue