[SPARK-35763][SS] Remove the StateStoreCustomMetric subclass enumeration dependency

### What changes were proposed in this pull request?

Remove the dependency on enumerating the subclasses of `StateStoreCustomMetric`.

To achieve this, add a couple of utility methods to `StateStoreCustomMetric` (sketched below):
* `withNewDesc(desc: String)` for cloning the instance with a new `desc` (currently used in `SymmetricHashJoinStateManager`)
* `createSQLMetric(sparkContext: SparkContext): SQLMetric` for creating a corresponding `SQLMetric` to show the metric in the UI and accumulate it at the query level (currently used in `StateStoreWriter.stateStoreCustomMetrics`)
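For reference, a sketch of the extended trait (the full definitions, including the three concrete subclasses, appear in the `StateStore.scala` diff below):

```scala
trait StateStoreCustomMetric {
  def name: String
  def desc: String

  // Clone this metric with a new description, without the caller needing to know
  // the concrete subclass (used by SymmetricHashJoinStateManager).
  def withNewDesc(desc: String): StateStoreCustomMetric

  // Create the SQLMetric that shows this metric in the UI and accumulates it at
  // the query level (used by StateStoreWriter.stateStoreCustomMetrics).
  def createSQLMetric(sparkContext: SparkContext): SQLMetric
}
```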

### Why are the changes needed?

The code in [SymmetricHashJoinStateManager](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala#L321) and [StateStoreWriter](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala#L129) relies on enumerating the subclass implementations of [StateStoreCustomMetric](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala#L187).

Whenever a new subclass of `StateStoreCustomMetric` is added, `SymmetricHashJoinStateManager` and `StateStoreWriter` must be updated as well, and that update can easily be missed if there is no existing test coverage.

To prevent these issues, add the utility methods described above to `StateStoreCustomMetric`, so call sites dispatch on the metric itself instead of matching on concrete subclasses (see the before/after sketch below).
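A condensed before/after view of the `StateStoreWriter.stateStoreCustomMetrics` call site, taken from the `statefulOperators.scala` diff below, shows how the subclass enumeration drops away:

```scala
// Before: every known StateStoreCustomMetric subclass must be matched explicitly,
// so adding a subclass forces a change here, and a missed update only shows up at runtime.
provider.supportedCustomMetrics.map {
  case StateStoreCustomSumMetric(name, desc) =>
    name -> SQLMetrics.createMetric(sparkContext, desc)
  case StateStoreCustomSizeMetric(name, desc) =>
    name -> SQLMetrics.createSizeMetric(sparkContext, desc)
  case StateStoreCustomTimingMetric(name, desc) =>
    name -> SQLMetrics.createTimingMetric(sparkContext, desc)
}.toMap

// After: each metric creates its own SQLMetric, so new subclasses need no changes here.
provider.supportedCustomMetrics.map {
  metric => (metric.name, metric.createSQLMetric(sparkContext))
}.toMap
```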

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Existing UT and a new UT

Closes #32914 from vkorukanti/SPARK-35763.

Authored-by: Venki Korukanti <venki.korukanti@gmail.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala

@@ -32,6 +32,7 @@ import org.apache.spark.{SparkContext, SparkEnv}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.catalyst.util.UnsafeRowUtils
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.{ThreadUtils, Utils}
@@ -182,16 +183,36 @@ object StateStoreMetrics {
 /**
  * Name and description of custom implementation-specific metrics that a
- * state store may wish to expose.
+ * state store may wish to expose. Also provides [[SQLMetric]] instance to
+ * show the metric in UI and accumulate it at the query level.
  */
 trait StateStoreCustomMetric {
   def name: String
   def desc: String
+  def withNewDesc(desc: String): StateStoreCustomMetric
+  def createSQLMetric(sparkContext: SparkContext): SQLMetric
 }
 
-case class StateStoreCustomSumMetric(name: String, desc: String) extends StateStoreCustomMetric
-case class StateStoreCustomSizeMetric(name: String, desc: String) extends StateStoreCustomMetric
-case class StateStoreCustomTimingMetric(name: String, desc: String) extends StateStoreCustomMetric
+case class StateStoreCustomSumMetric(name: String, desc: String) extends StateStoreCustomMetric {
+  override def withNewDesc(newDesc: String): StateStoreCustomSumMetric = copy(desc = newDesc)
+
+  override def createSQLMetric(sparkContext: SparkContext): SQLMetric =
+    SQLMetrics.createMetric(sparkContext, desc)
+}
+
+case class StateStoreCustomSizeMetric(name: String, desc: String) extends StateStoreCustomMetric {
+  override def withNewDesc(desc: String): StateStoreCustomSizeMetric = copy(desc = desc)
+
+  override def createSQLMetric(sparkContext: SparkContext): SQLMetric =
+    SQLMetrics.createSizeMetric(sparkContext, desc)
+}
+
+case class StateStoreCustomTimingMetric(name: String, desc: String) extends StateStoreCustomMetric {
+  override def withNewDesc(desc: String): StateStoreCustomTimingMetric = copy(desc = desc)
+
+  override def createSQLMetric(sparkContext: SparkContext): SQLMetric =
+    SQLMetrics.createTimingMetric(sparkContext, desc)
+}
 
 /**
  * An exception thrown when an invalid UnsafeRow is detected in state store.

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala

@@ -319,15 +319,7 @@ class SymmetricHashJoinStateManager(
       keyWithIndexToValueMetrics.numKeys, // represent each buffered row only once
       keyToNumValuesMetrics.memoryUsedBytes + keyWithIndexToValueMetrics.memoryUsedBytes,
       keyWithIndexToValueMetrics.customMetrics.map {
-        case (s @ StateStoreCustomSumMetric(_, desc), value) =>
-          s.copy(desc = newDesc(desc)) -> value
-        case (s @ StateStoreCustomSizeMetric(_, desc), value) =>
-          s.copy(desc = newDesc(desc)) -> value
-        case (s @ StateStoreCustomTimingMetric(_, desc), value) =>
-          s.copy(desc = newDesc(desc)) -> value
-        case (s, _) =>
-          throw new IllegalArgumentException(
-            s"Unknown state store custom metric is found at metrics: $s")
+        case (metric, value) => (metric.withNewDesc(desc = newDesc(metric.desc)), value)
       }
     )
   }

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala

@@ -126,12 +126,7 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan =>
   private def stateStoreCustomMetrics: Map[String, SQLMetric] = {
     val provider = StateStoreProvider.create(sqlContext.conf.stateStoreProviderClass)
     provider.supportedCustomMetrics.map {
-      case StateStoreCustomSumMetric(name, desc) =>
-        name -> SQLMetrics.createMetric(sparkContext, desc)
-      case StateStoreCustomSizeMetric(name, desc) =>
-        name -> SQLMetrics.createSizeMetric(sparkContext, desc)
-      case StateStoreCustomTimingMetric(name, desc) =>
-        name -> SQLMetrics.createTimingMetric(sparkContext, desc)
+      metric => (metric.name, metric.createSQLMetric(sparkContext))
     }.toMap
   }

sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala

@@ -1013,6 +1013,20 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
     assert(err.getMessage.contains("Cannot put a null value"))
   }
 
+  test("SPARK-35763: StateStoreCustomMetric withNewDesc and createSQLMetric") {
+    val metric = StateStoreCustomSizeMetric(name = "m1", desc = "desc1")
+    val metricNew = metric.withNewDesc("new desc")
+    assert(metricNew.desc === "new desc", "incorrect description in copied instance")
+    assert(metricNew.name === "m1", "incorrect name in copied instance")
+
+    val conf = new SparkConf().setMaster("local").setAppName("SPARK-35763").set(RPC_NUM_RETRIES, 1)
+    withSpark(new SparkContext(conf)) { sc =>
+      val sqlMetric = metric.createSQLMetric(sc)
+      assert(sqlMetric != null)
+      assert(sqlMetric.name === Some("desc1"))
+    }
+  }
+
   /** Return a new provider with a random id */
   def newStoreProvider(): ProviderClass