[SPARK-35988][SS] The implementation for RocksDBStateStoreProvider

### What changes were proposed in this pull request?
Add the implementation for RocksDBStateStoreProvider, a subclass of StateStoreProvider that leverages the functionality implemented in the RocksDB instance.

### Why are the changes needed?
It provides the interface for end users to use the RocksDB state store.

### Does this PR introduce _any_ user-facing change?
Yes. Users can use the new RocksDBStateStore in their applications.
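
A minimal sketch of how an application opts into the new provider, based on the `spark.sql.streaming.stateStore.providerClass` conf key exercised in the new tests; the session setup is illustrative only:

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical local session; only the providerClass conf below comes from this PR's tests.
val spark = SparkSession.builder()
  .master("local[2]")
  .appName("rocksdb-state-store-example")
  .config("spark.sql.streaming.stateStore.providerClass",
    "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")
  .getOrCreate()

// Any stateful streaming query (for example a streaming aggregation) now keeps its
// per-partition state in executor-local RocksDB instances checkpointed to the query's
// checkpoint location, instead of the default HDFS-backed state store.
```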

### How was this patch tested?
New unit tests added.

Closes #33187 from xuanyuanking/SPARK-35988.

Authored-by: Yuanjian Li <yuanjian.li@databricks.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
(cherry picked from commit 0621e78b5f)
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
Authored by Yuanjian Li on 2021-07-08 21:02:37 +09:00, committed by Jungtaek Lim
parent cafb829c42
commit 097b667db7
4 changed files with 691 additions and 157 deletions

RocksDBStateStoreProvider.scala (new file)

@@ -0,0 +1,331 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.streaming.state
import java.io._
import org.apache.hadoop.conf.Configuration
import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.Platform
import org.apache.spark.util.Utils
private[state] class RocksDBStateStoreProvider
extends StateStoreProvider with Logging with Closeable {
import RocksDBStateStoreProvider._
class RocksDBStateStore(lastVersion: Long) extends StateStore {
/** Trait and classes representing the internal state of the store */
trait STATE
case object UPDATING extends STATE
case object COMMITTED extends STATE
case object ABORTED extends STATE
@volatile private var state: STATE = UPDATING
@volatile private var isValidated = false
override def id: StateStoreId = RocksDBStateStoreProvider.this.stateStoreId
override def version: Long = lastVersion
override def get(key: UnsafeRow): UnsafeRow = {
verify(key != null, "Key cannot be null")
val value = encoder.decodeValue(rocksDB.get(encoder.encode(key)))
if (!isValidated && value != null) {
StateStoreProvider.validateStateRowFormat(
key, keySchema, value, valueSchema, storeConf)
isValidated = true
}
value
}
override def put(key: UnsafeRow, value: UnsafeRow): Unit = {
verify(state == UPDATING, "Cannot put after already committed or aborted")
verify(key != null, "Key cannot be null")
require(value != null, "Cannot put a null value")
rocksDB.put(encoder.encode(key), encoder.encode(value))
}
override def remove(key: UnsafeRow): Unit = {
verify(state == UPDATING, "Cannot remove after already committed or aborted")
verify(key != null, "Key cannot be null")
rocksDB.remove(encoder.encode(key))
}
override def getRange(
start: Option[UnsafeRow],
end: Option[UnsafeRow]): Iterator[UnsafeRowPair] = {
verify(state == UPDATING, "Cannot call getRange() after already committed or aborted")
iterator()
}
override def iterator(): Iterator[UnsafeRowPair] = {
rocksDB.iterator().map { kv =>
val rowPair = encoder.decode(kv)
if (!isValidated && rowPair.value != null) {
StateStoreProvider.validateStateRowFormat(
rowPair.key, keySchema, rowPair.value, valueSchema, storeConf)
isValidated = true
}
rowPair
}
}
override def commit(): Long = synchronized {
verify(state == UPDATING, "Cannot commit after already committed or aborted")
val newVersion = rocksDB.commit()
state = COMMITTED
logInfo(s"Committed $newVersion for $id")
newVersion
}
override def abort(): Unit = {
verify(state == UPDATING || state == ABORTED, "Cannot abort after already committed")
logInfo(s"Aborting ${version + 1} for $id")
rocksDB.rollback()
state = ABORTED
}
override def metrics: StateStoreMetrics = {
val rocksDBMetrics = rocksDB.metrics
def commitLatencyMs(typ: String): Long = rocksDBMetrics.lastCommitLatencyMs.getOrElse(typ, 0L)
def avgNativeOpsLatencyMs(typ: String): Long = {
rocksDBMetrics.nativeOpsLatencyMicros.get(typ).map(_.avg).getOrElse(0.0).toLong
}
val stateStoreCustomMetrics = Map[StateStoreCustomMetric, Long](
CUSTOM_METRIC_SST_FILE_SIZE -> rocksDBMetrics.totalSSTFilesBytes,
CUSTOM_METRIC_GET_TIME -> avgNativeOpsLatencyMs("get"),
CUSTOM_METRIC_PUT_TIME -> avgNativeOpsLatencyMs("put"),
CUSTOM_METRIC_WRITEBATCH_TIME -> commitLatencyMs("writeBatch"),
CUSTOM_METRIC_FLUSH_TIME -> commitLatencyMs("flush"),
CUSTOM_METRIC_PAUSE_TIME -> commitLatencyMs("pause"),
CUSTOM_METRIC_CHECKPOINT_TIME -> commitLatencyMs("checkpoint"),
CUSTOM_METRIC_FILESYNC_TIME -> commitLatencyMs("fileSync"),
CUSTOM_METRIC_BYTES_COPIED -> rocksDBMetrics.bytesCopied,
CUSTOM_METRIC_FILES_COPIED -> rocksDBMetrics.filesCopied,
CUSTOM_METRIC_FILES_REUSED -> rocksDBMetrics.filesReused
) ++ rocksDBMetrics.zipFileBytesUncompressed.map(bytes =>
Map(CUSTOM_METRIC_ZIP_FILE_BYTES_UNCOMPRESSED -> bytes)).getOrElse(Map())
StateStoreMetrics(
rocksDBMetrics.numUncommittedKeys,
rocksDBMetrics.memUsageBytes,
stateStoreCustomMetrics)
}
override def hasCommitted: Boolean = state == COMMITTED
override def toString: String = {
s"RocksDBStateStore[id=(op=${id.operatorId},part=${id.partitionId})," +
s"dir=${id.storeCheckpointLocation()}]"
}
/** Return the [[RocksDB]] instance in this store. This is exposed mainly for testing. */
def dbInstance(): RocksDB = rocksDB
}
override def init(
stateStoreId: StateStoreId,
keySchema: StructType,
valueSchema: StructType,
indexOrdinal: Option[Int],
storeConf: StateStoreConf,
hadoopConf: Configuration): Unit = {
this.stateStoreId_ = stateStoreId
this.keySchema = keySchema
this.valueSchema = valueSchema
this.storeConf = storeConf
this.hadoopConf = hadoopConf
rocksDB // lazy initialization
}
override def stateStoreId: StateStoreId = stateStoreId_
override def getStore(version: Long): StateStore = {
require(version >= 0, "Version cannot be less than 0")
rocksDB.load(version)
new RocksDBStateStore(version)
}
override def doMaintenance(): Unit = {
rocksDB.cleanup()
}
override def close(): Unit = {
rocksDB.close()
}
override def supportedCustomMetrics: Seq[StateStoreCustomMetric] = ALL_CUSTOM_METRICS
private[state] def latestVersion: Long = rocksDB.getLatestVersion()
/** Internal fields and methods */
@volatile private var stateStoreId_ : StateStoreId = _
@volatile private var keySchema: StructType = _
@volatile private var valueSchema: StructType = _
@volatile private var storeConf: StateStoreConf = _
@volatile private var hadoopConf: Configuration = _
private[sql] lazy val rocksDB = {
val dfsRootDir = stateStoreId.storeCheckpointLocation().toString
val storeIdStr = s"StateStoreId(opId=${stateStoreId.operatorId}," +
s"partId=${stateStoreId.partitionId},name=${stateStoreId.storeName})"
val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf)
val localRootDir = Utils.createTempDir(Utils.getLocalDir(sparkConf), storeIdStr)
new RocksDB(dfsRootDir, RocksDBConf(storeConf), localRootDir, hadoopConf, storeIdStr)
}
private lazy val encoder = new StateEncoder
private def verify(condition: => Boolean, msg: String): Unit = {
if (!condition) { throw new IllegalStateException(msg) }
}
/**
* Encodes/decodes UnsafeRows to versioned byte arrays.
* It uses the first byte of the generated byte array to store the version that describes how the
* row is encoded in the rest of the byte array. Currently, the default version is 0.
*
* VERSION 0: [ VERSION (1 byte) | ROW (N bytes) ]
* The bytes of an UnsafeRow are written unmodified starting from offset 1
* (offset 0 is the version byte of value 0). That is, if the unsafe row has N bytes,
* then the generated byte array will be N+1 bytes.
*/
class StateEncoder {
import RocksDBStateStoreProvider._
// Reusable objects
private val keyRow = new UnsafeRow(keySchema.size)
private val valueRow = new UnsafeRow(valueSchema.size)
private val rowTuple = new UnsafeRowPair()
/**
* Encode the UnsafeRow of N bytes as an N+1 byte array.
* @note This creates a new byte array and memcopies the UnsafeRow to the new array.
*/
def encode(row: UnsafeRow): Array[Byte] = {
val bytesToEncode = row.getBytes
val encodedBytes = new Array[Byte](bytesToEncode.length + STATE_ENCODING_NUM_VERSION_BYTES)
Platform.putByte(encodedBytes, Platform.BYTE_ARRAY_OFFSET, STATE_ENCODING_VERSION)
// Platform.BYTE_ARRAY_OFFSET is the recommended way to memcopy between byte arrays. See Platform.
Platform.copyMemory(
bytesToEncode, Platform.BYTE_ARRAY_OFFSET,
encodedBytes, Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES,
bytesToEncode.length)
encodedBytes
}
/**
* Decode a key byte array into an UnsafeRow.
* @note The UnsafeRow returned is reused across calls, and the UnsafeRow just points to
* the given byte array.
*/
def decodeKey(keyBytes: Array[Byte]): UnsafeRow = {
if (keyBytes != null) {
// Platform.BYTE_ARRAY_OFFSET is the recommended way to refer to the 1st offset. See Platform.
keyRow.pointTo(
keyBytes,
Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES,
keyBytes.length - STATE_ENCODING_NUM_VERSION_BYTES)
keyRow
} else {
null
}
}
/**
* Decode a value byte array into an UnsafeRow.
*
* @note The UnsafeRow returned is reused across calls, and the UnsafeRow just points to
* the given byte array.
*/
def decodeValue(valueBytes: Array[Byte]): UnsafeRow = {
if (valueBytes != null) {
// Platform.BYTE_ARRAY_OFFSET is the recommended way to refer to the 1st offset. See Platform.
valueRow.pointTo(
valueBytes,
Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES,
valueBytes.size - STATE_ENCODING_NUM_VERSION_BYTES)
valueRow
} else {
null
}
}
/**
* Decode a pair of key-value byte arrays into a pair of key-value UnsafeRows.
*
* @note The UnsafeRow returned is reused across calls, and the UnsafeRow just points to
* the given byte array.
*/
def decode(byteArrayTuple: ByteArrayPair): UnsafeRowPair = {
rowTuple.withRows(decodeKey(byteArrayTuple.key), decodeValue(byteArrayTuple.value))
}
}
}
object RocksDBStateStoreProvider {
// Version as a single byte that specifies the encoding of the row data in RocksDB
val STATE_ENCODING_NUM_VERSION_BYTES = 1
val STATE_ENCODING_VERSION: Byte = 0
// Native operation latencies are reported as latency per 1000 calls,
// since SQLMetrics supports ms latency whereas RocksDB reports it in microseconds.
val CUSTOM_METRIC_GET_TIME = StateStoreCustomTimingMetric(
"rocksdbGetLatency", "RocksDB: avg get latency (per 1000 calls)")
val CUSTOM_METRIC_PUT_TIME = StateStoreCustomTimingMetric(
"rocksdbPutLatency", "RocksDB: avg put latency (per 1000 calls)")
// Commit latency detailed breakdown
val CUSTOM_METRIC_WRITEBATCH_TIME = StateStoreCustomTimingMetric(
"rocksdbCommitWriteBatchLatency", "RocksDB: commit - write batch time")
val CUSTOM_METRIC_FLUSH_TIME = StateStoreCustomTimingMetric(
"rocksdbCommitFlushLatency", "RocksDB: commit - flush time")
val CUSTOM_METRIC_PAUSE_TIME = StateStoreCustomTimingMetric(
"rocksdbCommitPauseLatency", "RocksDB: commit - pause bg time")
val CUSTOM_METRIC_CHECKPOINT_TIME = StateStoreCustomTimingMetric(
"rocksdbCommitCheckpointLatency", "RocksDB: commit - checkpoint time")
val CUSTOM_METRIC_FILESYNC_TIME = StateStoreCustomTimingMetric(
"rocksdbFileSyncTime", "RocksDB: commit - file sync time")
val CUSTOM_METRIC_FILES_COPIED = StateStoreCustomSizeMetric(
"rocksdbFilesCopied", "RocksDB: file manager - files copied")
val CUSTOM_METRIC_BYTES_COPIED = StateStoreCustomSizeMetric(
"rocksdbBytesCopied", "RocksDB: file manager - bytes copied")
val CUSTOM_METRIC_FILES_REUSED = StateStoreCustomSizeMetric(
"rocksdbFilesReused", "RocksDB: file manager - files reused")
val CUSTOM_METRIC_ZIP_FILE_BYTES_UNCOMPRESSED = StateStoreCustomSizeMetric(
"rocksdbZipFileBytesUncompressed", "RocksDB: file manager - uncompressed zip file bytes")
// Total SST file size
val CUSTOM_METRIC_SST_FILE_SIZE = StateStoreCustomSizeMetric(
"rocksdbSstFileSize", "RocksDB: size of all SST files")
val ALL_CUSTOM_METRICS = Seq(
CUSTOM_METRIC_SST_FILE_SIZE, CUSTOM_METRIC_GET_TIME, CUSTOM_METRIC_PUT_TIME,
CUSTOM_METRIC_WRITEBATCH_TIME, CUSTOM_METRIC_FLUSH_TIME, CUSTOM_METRIC_PAUSE_TIME,
CUSTOM_METRIC_CHECKPOINT_TIME, CUSTOM_METRIC_FILESYNC_TIME,
CUSTOM_METRIC_BYTES_COPIED, CUSTOM_METRIC_FILES_COPIED, CUSTOM_METRIC_FILES_REUSED,
CUSTOM_METRIC_ZIP_FILE_BYTES_UNCOMPRESSED
)
}
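
For readers skimming the diff, the byte layout documented in the `StateEncoder` scaladoc above can be summarized with a small standalone sketch (plain byte arrays stand in for `UnsafeRow`/`Platform`; all names here are illustrative, not part of the change):

```scala
// VERSION 0 layout: [ VERSION (1 byte) | ROW (N bytes) ] -- a row of N bytes becomes
// an encoded array of N + 1 bytes, with the version byte (value 0) at offset 0.
object StateEncodingSketch {
  val StateEncodingVersion: Byte = 0

  def encode(rowBytes: Array[Byte]): Array[Byte] = StateEncodingVersion +: rowBytes

  def decode(encoded: Array[Byte]): Array[Byte] = encoded.drop(1) // skip the version byte

  def main(args: Array[String]): Unit = {
    val row = Array[Byte](1, 2, 3)
    assert(encode(row).length == row.length + 1)    // N + 1 bytes
    assert(encode(row)(0) == StateEncodingVersion)  // version byte comes first
    assert(decode(encode(row)).sameElements(row))   // round-trips unchanged
  }
}
```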

RocksDBStateStoreIntegrationSuite.scala (new file)

@@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.streaming.state
import java.io.File
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming._
class RocksDBStateStoreIntegrationSuite extends StreamTest {
import testImplicits._
test("RocksDBStateStore") {
withTempDir { dir =>
val input = MemoryStream[Int]
val conf = Map(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName)
testStream(input.toDF.groupBy().count(), outputMode = OutputMode.Update)(
StartStream(checkpointLocation = dir.getAbsolutePath, additionalConfs = conf),
AddData(input, 1, 2, 3),
CheckAnswer(3),
AssertOnQuery { q =>
// Verify RocksDBStateStore is used by checking that the state checkpoints are [version].zip files
val storeCheckpointDir = StateStoreId(
dir.getAbsolutePath + "/state", 0, 0).storeCheckpointLocation()
val storeCheckpointFile = storeCheckpointDir + "/1.zip"
new File(storeCheckpointFile).exists()
}
)
}
}
}

RocksDBStateStoreSuite.scala (new file)

@@ -0,0 +1,152 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.streaming.state
import java.util.UUID
import scala.util.Random
import org.apache.hadoop.conf.Configuration
import org.scalatest.BeforeAndAfter
import org.apache.spark.SparkConf
import org.apache.spark.sql.LocalSparkSession.withSparkSession
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
import org.apache.spark.util.Utils
class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvider]
with BeforeAndAfter {
import StateStoreTestsHelper._
test("version encoding") {
import RocksDBStateStoreProvider._
val provider = newStoreProvider()
val store = provider.getStore(0)
val keyRow = stringToRow("a")
val valueRow = intToRow(1)
store.put(keyRow, valueRow)
val iter = provider.rocksDB.iterator()
assert(iter.hasNext)
val kv = iter.next()
// Verify the version encoded in the first byte of the key and value byte arrays
assert(Platform.getByte(kv.key, Platform.BYTE_ARRAY_OFFSET) === STATE_ENCODING_VERSION)
assert(Platform.getByte(kv.value, Platform.BYTE_ARRAY_OFFSET) === STATE_ENCODING_VERSION)
}
test("RocksDB confs are passed correctly from SparkSession to db instance") {
val sparkConf = new SparkConf().setMaster("local").setAppName(this.getClass.getSimpleName)
withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark =>
// Set the session confs that should be passed into RocksDB
val testConfs = Seq(
("spark.sql.streaming.stateStore.providerClass",
classOf[RocksDBStateStoreProvider].getName),
(RocksDBConf.ROCKSDB_CONF_NAME_PREFIX + ".compactOnCommit", "true"),
(RocksDBConf.ROCKSDB_CONF_NAME_PREFIX + ".lockAcquireTimeoutMs", "10")
)
testConfs.foreach { case (k, v) => spark.conf.set(k, v) }
// Prepare test objects for running task on state store
val testRDD = spark.sparkContext.makeRDD[String](Seq("a"), 1)
val testSchema = StructType(Seq(StructField("key", StringType, true)))
val testStateInfo = StatefulOperatorStateInfo(
checkpointLocation = Utils.createTempDir().getAbsolutePath,
queryRunId = UUID.randomUUID, operatorId = 0, storeVersion = 0, numPartitions = 5)
// Create state store in a task and get the RocksDBConf from the instantiated RocksDB instance
val rocksDBConfInTask: RocksDBConf = testRDD.mapPartitionsWithStateStore[RocksDBConf](
spark.sqlContext, testStateInfo, testSchema, testSchema, None) {
(store: StateStore, _: Iterator[String]) =>
// Use reflection to get RocksDB instance
val dbInstanceMethod =
store.getClass.getMethods.filter(_.getName.contains("dbInstance")).head
Iterator(dbInstanceMethod.invoke(store).asInstanceOf[RocksDB].conf)
}.collect().head
// Verify the confs are the same as those configured in the session conf
assert(rocksDBConfInTask.compactOnCommit == true)
assert(rocksDBConfInTask.lockAcquireTimeoutMs == 10L)
}
}
test("rocksdb file manager metrics exposed") {
import RocksDBStateStoreProvider._
def getCustomMetric(metrics: StateStoreMetrics, customMetric: StateStoreCustomMetric): Long = {
val metricPair = metrics.customMetrics.find(_._1.name == customMetric.name)
assert(metricPair.isDefined)
metricPair.get._2
}
val provider = newStoreProvider()
val store = provider.getStore(0)
// Verify state after updating
put(store, "a", 1)
assert(get(store, "a") === Some(1))
assert(store.commit() === 1)
assert(store.hasCommitted)
val storeMetrics = store.metrics
assert(storeMetrics.numKeys === 1)
assert(getCustomMetric(storeMetrics, CUSTOM_METRIC_FILES_COPIED) > 0L)
assert(getCustomMetric(storeMetrics, CUSTOM_METRIC_FILES_REUSED) == 0L)
assert(getCustomMetric(storeMetrics, CUSTOM_METRIC_BYTES_COPIED) > 0L)
assert(getCustomMetric(storeMetrics, CUSTOM_METRIC_ZIP_FILE_BYTES_UNCOMPRESSED) > 0L)
}
override def newStoreProvider(): RocksDBStateStoreProvider = {
newStoreProvider(StateStoreId(newDir(), Random.nextInt(), 0))
}
def newStoreProvider(storeId: StateStoreId): RocksDBStateStoreProvider = {
val keySchema = StructType(Seq(StructField("key", StringType, true)))
val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
val provider = new RocksDBStateStoreProvider()
provider.init(
storeId, keySchema, valueSchema, indexOrdinal = None, new StateStoreConf, new Configuration)
provider
}
override def getLatestData(storeProvider: RocksDBStateStoreProvider): Set[(String, Int)] = {
getData(storeProvider, version = -1)
}
override def getData(
provider: RocksDBStateStoreProvider,
version: Int = -1): Set[(String, Int)] = {
val reloadedProvider = newStoreProvider(provider.stateStoreId)
val versionToRead = if (version < 0) reloadedProvider.latestVersion else version
reloadedProvider.getStore(versionToRead).iterator().map(rowsToStringInt).toSet
}
override protected val keySchema = StructType(Seq(StructField("key", StringType, true)))
override protected val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
override def newStoreProvider(
minDeltasForSnapshot: Int,
numOfVersToRetainInMemory: Int): RocksDBStateStoreProvider = newStoreProvider()
override def getDefaultSQLConf(
minDeltasForSnapshot: Int,
numOfVersToRetainInMemory: Int): SQLConf = new SQLConf()
}
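
The conf-propagation test above also shows the pattern for tuning RocksDB itself: options are read from the session conf under `RocksDBConf.ROCKSDB_CONF_NAME_PREFIX`. A hedged sketch reusing the exact keys from that test (values are illustrative; the constants live in a Spark-internal package, so visibility from application code may differ):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.streaming.state.{RocksDBConf, RocksDBStateStoreProvider}

// Sketch only: mirrors the session confs set in the conf-propagation test above.
def configureRocksDBStateStore(spark: SparkSession): Unit = {
  spark.conf.set("spark.sql.streaming.stateStore.providerClass",
    classOf[RocksDBStateStoreProvider].getName)
  spark.conf.set(RocksDBConf.ROCKSDB_CONF_NAME_PREFIX + ".compactOnCommit", "true")
  spark.conf.set(RocksDBConf.ROCKSDB_CONF_NAME_PREFIX + ".lockAcquireTimeoutMs", "10")
}
```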

StateStoreSuite.scala

@@ -49,6 +49,7 @@ import org.apache.spark.util.Utils
class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
with BeforeAndAfter {
import StateStoreTestsHelper._
import StateStoreCoordinatorSuite._
override val keySchema = StructType(Seq(StructField("key", StringType, true)))
override val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
@@ -235,6 +236,162 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
assert(getSizeOfStateForCurrentVersion(store.metrics) > noDataMemoryUsed)
}
test("maintenance") {
val conf = new SparkConf()
.setMaster("local")
.setAppName("test")
// Make sure that when SparkContext stops, the StateStore maintenance thread 'quickly'
// fails to talk to the StateStoreCoordinator and unloads all the StateStores
.set(RPC_NUM_RETRIES, 1)
val opId = 0
val dir1 = newDir()
val storeProviderId1 = StateStoreProviderId(StateStoreId(dir1, opId, 0), UUID.randomUUID)
val dir2 = newDir()
val storeProviderId2 = StateStoreProviderId(StateStoreId(dir2, opId, 1), UUID.randomUUID)
val sqlConf = getDefaultSQLConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get,
SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY.defaultValue.get)
sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2)
// Make maintenance thread do snapshots and cleanups very fast
sqlConf.setConf(SQLConf.STREAMING_MAINTENANCE_INTERVAL, 10L)
val storeConf = StateStoreConf(sqlConf)
val hadoopConf = new Configuration()
val provider = newStoreProvider(storeProviderId1.storeId)
var latestStoreVersion = 0
def generateStoreVersions(): Unit = {
for (i <- 1 to 20) {
val store = StateStore.get(storeProviderId1, keySchema, valueSchema, None,
latestStoreVersion, storeConf, hadoopConf)
put(store, "a", i)
store.commit()
latestStoreVersion += 1
}
}
val timeoutDuration = 1.minute
quietly {
withSpark(new SparkContext(conf)) { sc =>
withCoordinatorRef(sc) { coordinatorRef =>
require(!StateStore.isMaintenanceRunning, "StateStore is unexpectedly running")
// Generate sufficient versions of store for snapshots
generateStoreVersions()
eventually(timeout(timeoutDuration)) {
// Store should have been reported to the coordinator
assert(coordinatorRef.getLocation(storeProviderId1).nonEmpty,
"active instance was not reported")
// Background maintenance should clean up and generate snapshots
assert(StateStore.isMaintenanceRunning, "Maintenance task is not running")
// Some snapshots should have been generated
val snapshotVersions = (1 to latestStoreVersion).filter { version =>
fileExists(provider, version, isSnapshot = true)
}
assert(snapshotVersions.nonEmpty, "no snapshot file found")
}
// Generate more versions such that there is another snapshot and
// the earliest delta file will be cleaned up
generateStoreVersions()
// Earliest delta file should get cleaned up
eventually(timeout(timeoutDuration)) {
assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted")
}
// If driver decides to deactivate all stores related to a query run,
// then this instance should be unloaded
coordinatorRef.deactivateInstances(storeProviderId1.queryRunId)
eventually(timeout(timeoutDuration)) {
assert(!StateStore.isLoaded(storeProviderId1))
}
// Reload the store and verify
StateStore.get(storeProviderId1, keySchema, valueSchema, indexOrdinal = None,
latestStoreVersion, storeConf, hadoopConf)
assert(StateStore.isLoaded(storeProviderId1))
// If some other executor loads the store, then this instance should be unloaded
coordinatorRef
.reportActiveInstance(storeProviderId1, "other-host", "other-exec", Seq.empty)
eventually(timeout(timeoutDuration)) {
assert(!StateStore.isLoaded(storeProviderId1))
}
// Reload the store and verify
StateStore.get(storeProviderId1, keySchema, valueSchema, indexOrdinal = None,
latestStoreVersion, storeConf, hadoopConf)
assert(StateStore.isLoaded(storeProviderId1))
// If some other executor loads the store, and when this executor loads other store,
// then this executor should unload inactive instances immediately.
coordinatorRef
.reportActiveInstance(storeProviderId1, "other-host", "other-exec", Seq.empty)
StateStore.get(storeProviderId2, keySchema, valueSchema, indexOrdinal = None,
0, storeConf, hadoopConf)
assert(!StateStore.isLoaded(storeProviderId1))
assert(StateStore.isLoaded(storeProviderId2))
}
}
// Verify if instance is unloaded if SparkContext is stopped
eventually(timeout(timeoutDuration)) {
require(SparkEnv.get === null)
assert(!StateStore.isLoaded(storeProviderId1))
assert(!StateStore.isLoaded(storeProviderId2))
assert(!StateStore.isMaintenanceRunning)
}
}
}
test("snapshotting") {
val provider = newStoreProvider(minDeltasForSnapshot = 5, numOfVersToRetainInMemory = 2)
var currentVersion = 0
currentVersion = updateVersionTo(provider, currentVersion, 2)
require(getLatestData(provider) === Set("a" -> 2))
provider.doMaintenance() // should not generate snapshot files
assert(getLatestData(provider) === Set("a" -> 2))
for (i <- 1 to currentVersion) {
assert(fileExists(provider, i, isSnapshot = false)) // all delta files present
assert(!fileExists(provider, i, isSnapshot = true)) // no snapshot files present
}
// After version 6, snapshotting should generate one snapshot file
currentVersion = updateVersionTo(provider, currentVersion, 6)
require(getLatestData(provider) === Set("a" -> 6), "store not updated correctly")
provider.doMaintenance() // should generate snapshot files
val snapshotVersion = (0 to 6).find(version => fileExists(provider, version, isSnapshot = true))
assert(snapshotVersion.nonEmpty, "snapshot file not generated")
deleteFilesEarlierThanVersion(provider, snapshotVersion.get)
assert(
getData(provider, snapshotVersion.get) === Set("a" -> snapshotVersion.get),
"snapshotting messed up the data of the snapshotted version")
assert(
getLatestData(provider) === Set("a" -> 6),
"snapshotting messed up the data of the final version")
// After version 20, snapshotting should generate newer snapshot files
currentVersion = updateVersionTo(provider, currentVersion, 20)
require(getLatestData(provider) === Set("a" -> 20), "store not updated correctly")
provider.doMaintenance() // do snapshot
val latestSnapshotVersion = (0 to 20).filter(version =>
fileExists(provider, version, isSnapshot = true)).lastOption
assert(latestSnapshotVersion.nonEmpty, "no snapshot file found")
assert(latestSnapshotVersion.get > snapshotVersion.get, "newer snapshot not generated")
deleteFilesEarlierThanVersion(provider, latestSnapshotVersion.get)
assert(getLatestData(provider) === Set("a" -> 20), "snapshotting messed up the data")
}
testQuietly("SPARK-18342: commit fails when rename fails") { testQuietly("SPARK-18342: commit fails when rename fails") {
import RenameReturnsFalseFileSystem._ import RenameReturnsFalseFileSystem._
val dir = scheme + "://" + newDir() val dir = scheme + "://" + newDir()
@@ -582,7 +739,6 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
extends StateStoreCodecsTest with PrivateMethodTester {
import StateStoreTestsHelper._
import StateStoreCoordinatorSuite._
type MapType = mutable.HashMap[UnsafeRow, UnsafeRow]
type ProviderMapType = java.util.concurrent.ConcurrentHashMap[UnsafeRow, UnsafeRow]
@@ -761,118 +917,6 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
assert(rowsToSet(finalStore.iterator()) === Set(key -> 2))
}
test("maintenance") {
val conf = new SparkConf()
.setMaster("local")
.setAppName("test")
// Make sure that when SparkContext stops, the StateStore maintenance thread 'quickly'
// fails to talk to the StateStoreCoordinator and unloads all the StateStores
.set(RPC_NUM_RETRIES, 1)
val opId = 0
val dir1 = newDir()
val storeProviderId1 = StateStoreProviderId(StateStoreId(dir1, opId, 0), UUID.randomUUID)
val dir2 = newDir()
val storeProviderId2 = StateStoreProviderId(StateStoreId(dir2, opId, 1), UUID.randomUUID)
val sqlConf = getDefaultSQLConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get,
SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY.defaultValue.get)
sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2)
// Make maintenance thread do snapshots and cleanups very fast
sqlConf.setConf(SQLConf.STREAMING_MAINTENANCE_INTERVAL, 10L)
val storeConf = StateStoreConf(sqlConf)
val hadoopConf = new Configuration()
val provider = newStoreProvider(storeProviderId1.storeId)
var latestStoreVersion = 0
def generateStoreVersions(): Unit = {
for (i <- 1 to 20) {
val store = StateStore.get(storeProviderId1, keySchema, valueSchema, None,
latestStoreVersion, storeConf, hadoopConf)
put(store, "a", i)
store.commit()
latestStoreVersion += 1
}
}
val timeoutDuration = 1.minute
quietly {
withSpark(new SparkContext(conf)) { sc =>
withCoordinatorRef(sc) { coordinatorRef =>
require(!StateStore.isMaintenanceRunning, "StateStore is unexpectedly running")
// Generate sufficient versions of store for snapshots
generateStoreVersions()
eventually(timeout(timeoutDuration)) {
// Store should have been reported to the coordinator
assert(coordinatorRef.getLocation(storeProviderId1).nonEmpty,
"active instance was not reported")
// Background maintenance should clean up and generate snapshots
assert(StateStore.isMaintenanceRunning, "Maintenance task is not running")
// Some snapshots should have been generated
val snapshotVersions = (1 to latestStoreVersion).filter { version =>
fileExists(provider, version, isSnapshot = true)
}
assert(snapshotVersions.nonEmpty, "no snapshot file found")
}
// Generate more versions such that there is another snapshot and
// the earliest delta file will be cleaned up
generateStoreVersions()
// Earliest delta file should get cleaned up
eventually(timeout(timeoutDuration)) {
assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted")
}
// If driver decides to deactivate all stores related to a query run,
// then this instance should be unloaded
coordinatorRef.deactivateInstances(storeProviderId1.queryRunId)
eventually(timeout(timeoutDuration)) {
assert(!StateStore.isLoaded(storeProviderId1))
}
// Reload the store and verify
StateStore.get(storeProviderId1, keySchema, valueSchema, indexOrdinal = None,
latestStoreVersion, storeConf, hadoopConf)
assert(StateStore.isLoaded(storeProviderId1))
// If some other executor loads the store, then this instance should be unloaded
coordinatorRef
.reportActiveInstance(storeProviderId1, "other-host", "other-exec", Seq.empty)
eventually(timeout(timeoutDuration)) {
assert(!StateStore.isLoaded(storeProviderId1))
}
// Reload the store and verify
StateStore.get(storeProviderId1, keySchema, valueSchema, indexOrdinal = None,
latestStoreVersion, storeConf, hadoopConf)
assert(StateStore.isLoaded(storeProviderId1))
// If some other executor loads the store, and when this executor loads other store,
// then this executor should unload inactive instances immediately.
coordinatorRef
.reportActiveInstance(storeProviderId1, "other-host", "other-exec", Seq.empty)
StateStore.get(storeProviderId2, keySchema, valueSchema, indexOrdinal = None,
0, storeConf, hadoopConf)
assert(!StateStore.isLoaded(storeProviderId1))
assert(StateStore.isLoaded(storeProviderId2))
}
}
// Verify if instance is unloaded if SparkContext is stopped
eventually(timeout(timeoutDuration)) {
require(SparkEnv.get === null)
assert(!StateStore.isLoaded(storeProviderId1))
assert(!StateStore.isLoaded(storeProviderId2))
assert(!StateStore.isMaintenanceRunning)
}
}
}
test("StateStore.get") { test("StateStore.get") {
quietly { quietly {
val dir = newDir() val dir = newDir()
@ -925,50 +969,6 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
} }
} }
test("snapshotting") {
val provider = newStoreProvider(minDeltasForSnapshot = 5, numOfVersToRetainInMemory = 2)
var currentVersion = 0
currentVersion = updateVersionTo(provider, currentVersion, 2)
require(getLatestData(provider) === Set("a" -> 2))
provider.doMaintenance() // should not generate snapshot files
assert(getLatestData(provider) === Set("a" -> 2))
for (i <- 1 to currentVersion) {
assert(fileExists(provider, i, isSnapshot = false)) // all delta files present
assert(!fileExists(provider, i, isSnapshot = true)) // no snapshot files present
}
// After version 6, snapshotting should generate one snapshot file
currentVersion = updateVersionTo(provider, currentVersion, 6)
require(getLatestData(provider) === Set("a" -> 6), "store not updated correctly")
provider.doMaintenance() // should generate snapshot files
val snapshotVersion = (0 to 6).find(version => fileExists(provider, version, isSnapshot = true))
assert(snapshotVersion.nonEmpty, "snapshot file not generated")
deleteFilesEarlierThanVersion(provider, snapshotVersion.get)
assert(
getData(provider, snapshotVersion.get) === Set("a" -> snapshotVersion.get),
"snapshotting messed up the data of the snapshotted version")
assert(
getLatestData(provider) === Set("a" -> 6),
"snapshotting messed up the data of the final version")
// After version 20, snapshotting should generate newer snapshot files
currentVersion = updateVersionTo(provider, currentVersion, 20)
require(getLatestData(provider) === Set("a" -> 20), "store not updated correctly")
provider.doMaintenance() // do snapshot
val latestSnapshotVersion = (0 to 20).filter(version =>
fileExists(provider, version, isSnapshot = true)).lastOption
assert(latestSnapshotVersion.nonEmpty, "no snapshot file found")
assert(latestSnapshotVersion.get > snapshotVersion.get, "newer snapshot not generated")
deleteFilesEarlierThanVersion(provider, latestSnapshotVersion.get)
assert(getLatestData(provider) === Set("a" -> 20), "snapshotting messed up the data")
}
test("reports memory usage") { test("reports memory usage") {
val provider = newStoreProvider() val provider = newStoreProvider()
val store = provider.getStore(0) val store = provider.getStore(0)