diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index 889477b9bc..60ad318e9f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -32,6 +32,7 @@ import org.apache.spark.{SparkContext, SparkEnv}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.catalyst.util.UnsafeRowUtils
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.{ThreadUtils, Utils}
@@ -182,16 +183,36 @@ object StateStoreMetrics {
 
 /**
  * Name and description of custom implementation-specific metrics that a
- * state store may wish to expose.
+ * state store may wish to expose. Also provides a [[SQLMetric]] instance to
+ * show the metric in the UI and accumulate it at the query level.
  */
 trait StateStoreCustomMetric {
   def name: String
   def desc: String
+  def withNewDesc(desc: String): StateStoreCustomMetric
+  def createSQLMetric(sparkContext: SparkContext): SQLMetric
 }
 
-case class StateStoreCustomSumMetric(name: String, desc: String) extends StateStoreCustomMetric
-case class StateStoreCustomSizeMetric(name: String, desc: String) extends StateStoreCustomMetric
-case class StateStoreCustomTimingMetric(name: String, desc: String) extends StateStoreCustomMetric
+case class StateStoreCustomSumMetric(name: String, desc: String) extends StateStoreCustomMetric {
+  override def withNewDesc(newDesc: String): StateStoreCustomSumMetric = copy(desc = newDesc)
+
+  override def createSQLMetric(sparkContext: SparkContext): SQLMetric =
+    SQLMetrics.createMetric(sparkContext, desc)
+}
+
+case class StateStoreCustomSizeMetric(name: String, desc: String) extends StateStoreCustomMetric {
+  override def withNewDesc(desc: String): StateStoreCustomSizeMetric = copy(desc = desc)
+
+  override def createSQLMetric(sparkContext: SparkContext): SQLMetric =
+    SQLMetrics.createSizeMetric(sparkContext, desc)
+}
+
+case class StateStoreCustomTimingMetric(name: String, desc: String) extends StateStoreCustomMetric {
+  override def withNewDesc(desc: String): StateStoreCustomTimingMetric = copy(desc = desc)
+
+  override def createSQLMetric(sparkContext: SparkContext): SQLMetric =
+    SQLMetrics.createTimingMetric(sparkContext, desc)
+}
 
 /**
  * An exception thrown when an invalid UnsafeRow is detected in state store.
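For reference, a minimal usage sketch of the new trait methods (not part of the patch; the metric name and descriptions are invented for illustration, and a live SparkContext is assumed):

    import org.apache.spark.SparkContext
    import org.apache.spark.sql.execution.metric.SQLMetric
    import org.apache.spark.sql.execution.streaming.state.StateStoreCustomSizeMetric

    // "exampleMapSize" is a hypothetical metric name used only for this sketch.
    val metric = StateStoreCustomSizeMetric("exampleMapSize", "estimated size of state map")

    // withNewDesc returns a copy with only the description replaced; the name
    // (used as the lookup key) is preserved.
    val prefixed = metric.withNewDesc("right side: " + metric.desc)

    // createSQLMetric produces the UI-facing accumulator; the size-typed
    // subclass delegates to SQLMetrics.createSizeMetric.
    def uiMetric(sc: SparkContext): SQLMetric = prefixed.createSQLMetric(sc)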
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
index 5c74811040..d342c83d5f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
@@ -319,15 +319,7 @@ class SymmetricHashJoinStateManager(
       keyWithIndexToValueMetrics.numKeys, // represent each buffered row only once
       keyToNumValuesMetrics.memoryUsedBytes + keyWithIndexToValueMetrics.memoryUsedBytes,
       keyWithIndexToValueMetrics.customMetrics.map {
-        case (s @ StateStoreCustomSumMetric(_, desc), value) =>
-          s.copy(desc = newDesc(desc)) -> value
-        case (s @ StateStoreCustomSizeMetric(_, desc), value) =>
-          s.copy(desc = newDesc(desc)) -> value
-        case (s @ StateStoreCustomTimingMetric(_, desc), value) =>
-          s.copy(desc = newDesc(desc)) -> value
-        case (s, _) =>
-          throw new IllegalArgumentException(
-            s"Unknown state store custom metric is found at metrics: $s")
+        case (metric, value) => (metric.withNewDesc(desc = newDesc(metric.desc)), value)
       }
     )
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index cef0c01e22..f9b639b0d9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -126,12 +126,7 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan =>
   private def stateStoreCustomMetrics: Map[String, SQLMetric] = {
     val provider = StateStoreProvider.create(sqlContext.conf.stateStoreProviderClass)
     provider.supportedCustomMetrics.map {
-      case StateStoreCustomSumMetric(name, desc) =>
-        name -> SQLMetrics.createMetric(sparkContext, desc)
-      case StateStoreCustomSizeMetric(name, desc) =>
-        name -> SQLMetrics.createSizeMetric(sparkContext, desc)
-      case StateStoreCustomTimingMetric(name, desc) =>
-        name -> SQLMetrics.createTimingMetric(sparkContext, desc)
+      metric => (metric.name, metric.createSQLMetric(sparkContext))
     }.toMap
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index af5e9bb5c6..4323725df9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -1013,6 +1013,20 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
     assert(err.getMessage.contains("Cannot put a null value"))
   }
 
+  test("SPARK-35763: StateStoreCustomMetric withNewDesc and createSQLMetric") {
+    val metric = StateStoreCustomSizeMetric(name = "m1", desc = "desc1")
+    val metricNew = metric.withNewDesc("new desc")
+    assert(metricNew.desc === "new desc", "incorrect description in copied instance")
+    assert(metricNew.name === "m1", "incorrect name in copied instance")
+
+    val conf = new SparkConf().setMaster("local").setAppName("SPARK-35763").set(RPC_NUM_RETRIES, 1)
+    withSpark(new SparkContext(conf)) { sc =>
+      val sqlMetric = metric.createSQLMetric(sc)
+      assert(sqlMetric != null)
+      assert(sqlMetric.name === Some("desc1"))
+    }
+  }
+
   /** Return a new provider with a random id */
   def newStoreProvider(): ProviderClass