[SPARK-36030][SQL] Support DS v2 metrics at writing path

### What changes were proposed in this pull request?

We added the interface for DS v2 metrics in SPARK-34366, but only for the reading path. This patch extends the metrics interface to the writing path.
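
As a rough illustration (not part of the patch; the class and metric names below are hypothetical), a connector declares its write metrics on `Write` and reports per-task values from `DataWriter`; the driver matches and aggregates them by metric name:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.metric.{CustomMetric, CustomSumMetric, CustomTaskMetric}
import org.apache.spark.sql.connector.write.{DataWriter, Write, WriterCommitMessage}

// Driver side: declare the metrics this write supports. Spark registers them as
// SQL metrics and shows the aggregated result in the SQL UI.
class MyWrite extends Write {
  override def supportedCustomMetrics(): Array[CustomMetric] = Array(
    new CustomSumMetric {
      override def name(): String = "rows_written"
      override def description(): String = "number of rows written"
    })
  // toBatch()/toStreaming() omitted; a usable Write overrides at least one of them.
}

// Executor side: each task's DataWriter reports its current values. Spark polls
// currentMetricsValues() periodically during the write and once more at the end,
// then aggregates the per-task values by metric name on the driver.
class MyDataWriter extends DataWriter[InternalRow] {
  private var rowsWritten = 0L

  override def write(record: InternalRow): Unit = {
    // ... write the record to the external sink ...
    rowsWritten += 1
  }

  override def currentMetricsValues(): Array[CustomTaskMetric] = Array(
    new CustomTaskMetric {
      override def name(): String = "rows_written" // must match the driver-side metric name
      override def value(): Long = rowsWritten
    })

  override def commit(): WriterCommitMessage = new WriterCommitMessage {}
  override def abort(): Unit = {}
  override def close(): Unit = {}
}
```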

### Why are the changes needed?

This completes DS v2 metrics interface support on the writing path.

### Does this PR introduce _any_ user-facing change?

No for end users. For developers, yes: this adds metrics support to the DS v2 writing path.

### How was this patch tested?

Added tests.

Closes #33239 from viirya/v2-write-metrics.

Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: Liang-Chi Hsieh <viirya@gmail.com>
(cherry picked from commit 2653201b0a)
Signed-off-by: Liang-Chi Hsieh <viirya@gmail.com>
Liang-Chi Hsieh 2021-07-20 20:20:35 -07:00
parent ab80d3c167
commit 86d1fb4698
20 changed files with 456 additions and 47 deletions

View file

@ -113,7 +113,6 @@ public interface Scan {
* By default it returns empty array.
*/
default CustomMetric[] supportedCustomMetrics() {
CustomMetric[] NO_METRICS = {};
return NO_METRICS;
return new CustomMetric[]{};
}
}

View file

@ -21,6 +21,7 @@ import java.io.Closeable;
import java.io.IOException;
import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.metric.CustomTaskMetric;
/**
* A data writer returned by {@link DataWriterFactory#createWriter(int, long)} and is
@ -104,4 +105,12 @@ public interface DataWriter<T> extends Closeable {
* @throws IOException if failure happens during disk/network IO like writing files.
*/
void abort() throws IOException;
/**
* Returns an array of custom task metrics. By default it returns empty array. Note that it is
* not recommended to put heavy logic in this method as it may affect writing performance.
*/
default CustomTaskMetric[] currentMetricsValues() {
return new CustomTaskMetric[]{};
}
}

View file

@ -20,6 +20,7 @@ package org.apache.spark.sql.connector.write;
import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.catalog.Table;
import org.apache.spark.sql.connector.catalog.TableCapability;
import org.apache.spark.sql.connector.metric.CustomMetric;
import org.apache.spark.sql.connector.write.streaming.StreamingWrite;
/**
@ -62,4 +63,12 @@ public interface Write {
default StreamingWrite toStreaming() {
throw new UnsupportedOperationException(description() + ": Streaming write is not supported");
}
/**
* Returns an array of supported custom metrics with name and description.
* By default it returns empty array.
*/
default CustomMetric[] supportedCustomMetrics() {
return new CustomMetric[]{};
}
}

View file

@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow}
import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils}
import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
import org.apache.spark.sql.connector.expressions._
import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}
import org.apache.spark.sql.connector.read._
import org.apache.spark.sql.connector.write._
import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}
@ -344,6 +345,10 @@ class InMemoryTable(
case exc: StreamingNotSupportedOperation => exc.throwsException()
case s => s
}
override def supportedCustomMetrics(): Array[CustomMetric] = {
Array(new InMemorySimpleCustomMetric)
}
}
}
}
@ -604,4 +609,21 @@ private class BufferWriter extends DataWriter[InternalRow] {
override def abort(): Unit = {}
override def close(): Unit = {}
override def currentMetricsValues(): Array[CustomTaskMetric] = {
val metric = new CustomTaskMetric {
override def name(): String = "in_memory_buffer_rows"
override def value(): Long = buffer.rows.size
}
Array(metric)
}
}
class InMemorySimpleCustomMetric extends CustomMetric {
override def name(): String = "in_memory_buffer_rows"
override def description(): String = "number of rows in buffer"
override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = {
s"in-memory rows: ${taskMetrics.sum}"
}
}

View file

@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage}
import org.apache.spark.sql.execution.datasources.FileFormatWriter.ConcurrentOutputWriterSpec
import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StringType
import org.apache.spark.util.SerializableConfiguration
@ -41,7 +42,8 @@ import org.apache.spark.util.SerializableConfiguration
abstract class FileFormatDataWriter(
description: WriteJobDescription,
taskAttemptContext: TaskAttemptContext,
committer: FileCommitProtocol) extends DataWriter[InternalRow] {
committer: FileCommitProtocol,
customMetrics: Map[String, SQLMetric]) extends DataWriter[InternalRow] {
/**
* Max number of files a single task writes out due to file size. In most cases the number of
* files written should be very small. This is just a safe guard to protect some really bad
@ -76,12 +78,21 @@ abstract class FileFormatDataWriter(
/** Writes a record. */
def write(record: InternalRow): Unit
def writeWithMetrics(record: InternalRow, count: Long): Unit = {
if (count % CustomMetrics.NUM_ROWS_PER_UPDATE == 0) {
CustomMetrics.updateMetrics(currentMetricsValues, customMetrics)
}
write(record)
}
/** Write an iterator of records. */
def writeWithIterator(iterator: Iterator[InternalRow]): Unit = {
var count = 0L
while (iterator.hasNext) {
write(iterator.next())
writeWithMetrics(iterator.next(), count)
count += 1
}
CustomMetrics.updateMetrics(currentMetricsValues, customMetrics)
}
/**
@ -113,8 +124,9 @@ abstract class FileFormatDataWriter(
class EmptyDirectoryDataWriter(
description: WriteJobDescription,
taskAttemptContext: TaskAttemptContext,
committer: FileCommitProtocol
) extends FileFormatDataWriter(description, taskAttemptContext, committer) {
committer: FileCommitProtocol,
customMetrics: Map[String, SQLMetric] = Map.empty
) extends FileFormatDataWriter(description, taskAttemptContext, committer, customMetrics) {
override def write(record: InternalRow): Unit = {}
}
@ -122,8 +134,9 @@ class EmptyDirectoryDataWriter(
class SingleDirectoryDataWriter(
description: WriteJobDescription,
taskAttemptContext: TaskAttemptContext,
committer: FileCommitProtocol)
extends FileFormatDataWriter(description, taskAttemptContext, committer) {
committer: FileCommitProtocol,
customMetrics: Map[String, SQLMetric] = Map.empty)
extends FileFormatDataWriter(description, taskAttemptContext, committer, customMetrics) {
private var fileCounter: Int = _
private var recordsInFile: Long = _
// Initialize currentWriter and statsTrackers
@ -169,8 +182,9 @@ class SingleDirectoryDataWriter(
abstract class BaseDynamicPartitionDataWriter(
description: WriteJobDescription,
taskAttemptContext: TaskAttemptContext,
committer: FileCommitProtocol)
extends FileFormatDataWriter(description, taskAttemptContext, committer) {
committer: FileCommitProtocol,
customMetrics: Map[String, SQLMetric])
extends FileFormatDataWriter(description, taskAttemptContext, committer, customMetrics) {
/** Flag saying whether or not the data to be written out is partitioned. */
protected val isPartitioned = description.partitionColumns.nonEmpty
@ -314,8 +328,10 @@ abstract class BaseDynamicPartitionDataWriter(
class DynamicPartitionDataSingleWriter(
description: WriteJobDescription,
taskAttemptContext: TaskAttemptContext,
committer: FileCommitProtocol)
extends BaseDynamicPartitionDataWriter(description, taskAttemptContext, committer) {
committer: FileCommitProtocol,
customMetrics: Map[String, SQLMetric] = Map.empty)
extends BaseDynamicPartitionDataWriter(description, taskAttemptContext, committer,
customMetrics) {
private var currentPartitionValues: Option[UnsafeRow] = None
private var currentBucketId: Option[Int] = None
@ -361,8 +377,9 @@ class DynamicPartitionDataConcurrentWriter(
description: WriteJobDescription,
taskAttemptContext: TaskAttemptContext,
committer: FileCommitProtocol,
concurrentOutputWriterSpec: ConcurrentOutputWriterSpec)
extends BaseDynamicPartitionDataWriter(description, taskAttemptContext, committer)
concurrentOutputWriterSpec: ConcurrentOutputWriterSpec,
customMetrics: Map[String, SQLMetric] = Map.empty)
extends BaseDynamicPartitionDataWriter(description, taskAttemptContext, committer, customMetrics)
with Logging {
/** Wrapper class to index a unique concurrent output writer. */
@ -452,17 +469,23 @@ class DynamicPartitionDataConcurrentWriter(
* Write iterator of records with concurrent writers.
*/
override def writeWithIterator(iterator: Iterator[InternalRow]): Unit = {
var count = 0L
while (iterator.hasNext && !sorted) {
write(iterator.next())
writeWithMetrics(iterator.next(), count)
count += 1
}
CustomMetrics.updateMetrics(currentMetricsValues, customMetrics)
if (iterator.hasNext) {
count = 0L
clearCurrentWriterStatus()
val sorter = concurrentOutputWriterSpec.createSorter()
val sortIterator = sorter.sort(iterator.asInstanceOf[Iterator[UnsafeRow]])
while (sortIterator.hasNext) {
write(sortIterator.next())
writeWithMetrics(sortIterator.next(), count)
count += 1
}
CustomMetrics.updateMetrics(currentMetricsValues, customMetrics)
}
}

View file

@ -140,12 +140,12 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
// Add a Project here to make sure we produce unsafe rows.
withProjectAndFilter(p, f, scanExec, !scanExec.supportsColumnar) :: Nil
case WriteToDataSourceV2(relationOpt, writer, query) =>
case WriteToDataSourceV2(relationOpt, writer, query, customMetrics) =>
val invalidateCacheFunc: () => Unit = () => relationOpt match {
case Some(r) => session.sharedState.cacheManager.uncacheQuery(session, r, cascade = true)
case None => ()
}
WriteToDataSourceV2Exec(writer, invalidateCacheFunc, planLater(query)) :: Nil
WriteToDataSourceV2Exec(writer, invalidateCacheFunc, planLater(query), customMetrics) :: Nil
case CreateV2Table(catalog, ident, schema, parts, props, ifNotExists) =>
val propsWithOwner = CatalogV2Util.withDefaultOwnership(props)
@ -260,8 +260,8 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
throw QueryCompilationErrors.deleteOnlySupportedWithV2TablesError()
}
case WriteToContinuousDataSource(writer, query) =>
WriteToContinuousDataSourceExec(writer, planLater(query)) :: Nil
case WriteToContinuousDataSource(writer, query, customMetrics) =>
WriteToContinuousDataSourceExec(writer, planLater(query), customMetrics) :: Nil
case DescribeNamespace(ResolvedNamespace(catalog, ns), extended, output) =>
DescribeNamespaceExec(output, catalog.asNamespaceCatalog, ns, extended) :: Nil

View file

@ -32,9 +32,11 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.connector.catalog.{Identifier, StagedTable, StagingTableCatalog, SupportsWrite, Table, TableCatalog}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.connector.metric.CustomMetric
import org.apache.spark.sql.connector.write.{BatchWrite, DataWriterFactory, LogicalWriteInfoImpl, PhysicalWriteInfoImpl, V1Write, Write, WriterCommitMessage}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric, SQLMetrics}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.{LongAccumulator, Utils}
@ -46,7 +48,8 @@ import org.apache.spark.util.{LongAccumulator, Utils}
case class WriteToDataSourceV2(
relation: Option[DataSourceV2Relation],
batchWrite: BatchWrite,
query: LogicalPlan) extends UnaryNode {
query: LogicalPlan,
customMetrics: Seq[CustomMetric]) extends UnaryNode {
override def child: LogicalPlan = query
override def output: Seq[Attribute] = Nil
override protected def withNewChildInternal(newChild: LogicalPlan): WriteToDataSourceV2 =
@ -276,7 +279,12 @@ case class OverwritePartitionsDynamicExec(
case class WriteToDataSourceV2Exec(
batchWrite: BatchWrite,
refreshCache: () => Unit,
query: SparkPlan) extends V2TableWriteExec {
query: SparkPlan,
writeMetrics: Seq[CustomMetric]) extends V2TableWriteExec {
override val customMetrics: Map[String, SQLMetric] = writeMetrics.map { customMetric =>
customMetric.name() -> SQLMetrics.createV2CustomMetric(sparkContext, customMetric)
}.toMap
override protected def run(): Seq[InternalRow] = {
val writtenRows = writeWithV2(batchWrite)
@ -292,6 +300,11 @@ trait V2ExistingTableWriteExec extends V2TableWriteExec {
def refreshCache: () => Unit
def write: Write
override val customMetrics: Map[String, SQLMetric] =
write.supportedCustomMetrics().map { customMetric =>
customMetric.name() -> SQLMetrics.createV2CustomMetric(sparkContext, customMetric)
}.toMap
override protected def run(): Seq[InternalRow] = {
val writtenRows = writeWithV2(write.toBatch)
refreshCache()
@ -310,6 +323,10 @@ trait V2TableWriteExec extends V2CommandExec with UnaryExecNode {
override def child: SparkPlan = query
override def output: Seq[Attribute] = Nil
protected val customMetrics: Map[String, SQLMetric] = Map.empty
override lazy val metrics = customMetrics
protected def writeWithV2(batchWrite: BatchWrite): Seq[InternalRow] = {
val rdd: RDD[InternalRow] = {
val tempRdd = query.execute()
@ -330,11 +347,15 @@ trait V2TableWriteExec extends V2CommandExec with UnaryExecNode {
logInfo(s"Start processing data source write support: $batchWrite. " +
s"The input RDD has ${messages.length} partitions.")
// Avoid object not serializable issue.
val writeMetrics: Map[String, SQLMetric] = customMetrics
try {
sparkContext.runJob(
rdd,
(context: TaskContext, iter: Iterator[InternalRow]) =>
DataWritingSparkTask.run(writerFactory, context, iter, useCommitCoordinator),
DataWritingSparkTask.run(writerFactory, context, iter, useCommitCoordinator,
writeMetrics),
rdd.partitions.indices,
(index, result: DataWritingSparkTaskResult) => {
val commitMessage = result.writerCommitMessage
@ -376,7 +397,8 @@ object DataWritingSparkTask extends Logging {
writerFactory: DataWriterFactory,
context: TaskContext,
iter: Iterator[InternalRow],
useCommitCoordinator: Boolean): DataWritingSparkTaskResult = {
useCommitCoordinator: Boolean,
customMetrics: Map[String, SQLMetric]): DataWritingSparkTaskResult = {
val stageId = context.stageId()
val stageAttempt = context.stageAttemptNumber()
val partId = context.partitionId()
@ -388,11 +410,17 @@ object DataWritingSparkTask extends Logging {
// write the data and commit this writer.
Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
while (iter.hasNext) {
if (count % CustomMetrics.NUM_ROWS_PER_UPDATE == 0) {
CustomMetrics.updateMetrics(dataWriter.currentMetricsValues, customMetrics)
}
// Count is here.
count += 1
dataWriter.write(iter.next())
}
CustomMetrics.updateMetrics(dataWriter.currentMetricsValues, customMetrics)
val msg = if (useCommitCoordinator) {
val coordinator = SparkEnv.get.outputCommitCoordinator
val commitAuthorized = coordinator.canCommit(stageId, stageAttempt, partId, attemptId)

View file

@ -45,13 +45,14 @@ object CustomMetrics {
}
/**
* Updates given custom metrics.
* Updates given custom metrics. If `currentMetricsValues` has metric which does not exist
* in `customMetrics` map, it is non-op.
*/
def updateMetrics(
currentMetricsValues: Seq[CustomTaskMetric],
customMetrics: Map[String, SQLMetric]): Unit = {
currentMetricsValues.foreach { metric =>
customMetrics(metric.name()).set(metric.value())
customMetrics.get(metric.name()).map(_.set(metric.value()))
}
}
}

View file

@ -137,11 +137,11 @@ class MicroBatchExecution(
// TODO (SPARK-27484): we should add the writing node before the plan is analyzed.
sink match {
case s: SupportsWrite =>
val streamingWrite = createStreamingWrite(s, extraOptions, _logicalPlan)
val (streamingWrite, customMetrics) = createStreamingWrite(s, extraOptions, _logicalPlan)
val relationOpt = plan.catalogAndIdent.map {
case (catalog, ident) => DataSourceV2Relation.create(s, Some(catalog), Some(ident))
}
WriteToMicroBatchDataSource(relationOpt, streamingWrite, _logicalPlan)
WriteToMicroBatchDataSource(relationOpt, streamingWrite, _logicalPlan, customMetrics)
case _ => _logicalPlan
}

View file

@ -37,6 +37,7 @@ import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table}
import org.apache.spark.sql.connector.metric.CustomMetric
import org.apache.spark.sql.connector.read.streaming.{Offset => OffsetV2, ReadLimit, SparkDataStream}
import org.apache.spark.sql.connector.write.{LogicalWriteInfoImpl, SupportsTruncate}
import org.apache.spark.sql.connector.write.streaming.StreamingWrite
@ -581,7 +582,7 @@ abstract class StreamExecution(
protected def createStreamingWrite(
table: SupportsWrite,
options: Map[String, String],
inputPlan: LogicalPlan): StreamingWrite = {
inputPlan: LogicalPlan): (StreamingWrite, Seq[CustomMetric]) = {
val info = LogicalWriteInfoImpl(
queryId = id.toString,
inputPlan.schema,
@ -602,7 +603,8 @@ abstract class StreamExecution(
table.name + " does not support Update mode.")
writeBuilder.asInstanceOf[SupportsStreamingUpdateAsAppend].build()
}
write.toStreaming
(write.toStreaming, write.supportedCustomMetrics().toSeq)
}
protected def purge(threshold: Long): Unit = {

View file

@ -85,9 +85,9 @@ class ContinuousExecution(
uniqueSources = sources.distinct.map(s => s -> ReadLimit.allAvailable()).toMap
// TODO (SPARK-27484): we should add the writing node before the plan is analyzed.
WriteToContinuousDataSource(
createStreamingWrite(
plan.sink.asInstanceOf[SupportsWrite], extraOptions, _logicalPlan), _logicalPlan)
val (streamingWrite, customMetrics) = createStreamingWrite(
plan.sink.asInstanceOf[SupportsWrite], extraOptions, _logicalPlan)
WriteToContinuousDataSource(streamingWrite, _logicalPlan, customMetrics)
}
private val triggerExecutor = trigger match {

View file

@ -22,6 +22,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write.DataWriter
import org.apache.spark.sql.connector.write.streaming.StreamingDataWriterFactory
import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric}
import org.apache.spark.util.Utils
/**
@ -32,8 +33,8 @@ import org.apache.spark.util.Utils
*
* We keep repeating prev.compute() and writing new epochs until the query is shut down.
*/
class ContinuousWriteRDD(var prev: RDD[InternalRow], writerFactory: StreamingDataWriterFactory)
extends RDD[Unit](prev) {
class ContinuousWriteRDD(var prev: RDD[InternalRow], writerFactory: StreamingDataWriterFactory,
customMetrics: Map[String, SQLMetric]) extends RDD[Unit](prev) {
override val partitioner = prev.partitioner
@ -55,9 +56,15 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writerFactory: StreamingDat
context.partitionId(),
context.taskAttemptId(),
EpochTracker.getCurrentEpoch.get)
var count = 0L
while (dataIterator.hasNext) {
if (count % CustomMetrics.NUM_ROWS_PER_UPDATE == 0) {
CustomMetrics.updateMetrics(dataWriter.currentMetricsValues, customMetrics)
}
count += 1
dataWriter.write(dataIterator.next())
}
CustomMetrics.updateMetrics(dataWriter.currentMetricsValues, customMetrics)
logInfo(s"Writer for partition ${context.partitionId()} " +
s"in epoch ${EpochTracker.getCurrentEpoch.get} is committing.")
val msg = dataWriter.commit()

View file

@ -19,12 +19,14 @@ package org.apache.spark.sql.execution.streaming.continuous
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
import org.apache.spark.sql.connector.metric.CustomMetric
import org.apache.spark.sql.connector.write.streaming.StreamingWrite
/**
* The logical plan for writing data in a continuous stream.
*/
case class WriteToContinuousDataSource(write: StreamingWrite, query: LogicalPlan)
case class WriteToContinuousDataSource(write: StreamingWrite, query: LogicalPlan,
customMetrics: Seq[CustomMetric])
extends UnaryNode {
override def child: LogicalPlan = query
override def output: Seq[Attribute] = Nil

View file

@ -23,26 +23,33 @@ import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.connector.metric.CustomMetric
import org.apache.spark.sql.connector.write.PhysicalWriteInfoImpl
import org.apache.spark.sql.connector.write.streaming.StreamingWrite
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.execution.streaming.StreamExecution
/**
* The physical plan for writing data into a continuous processing [[StreamingWrite]].
*/
case class WriteToContinuousDataSourceExec(write: StreamingWrite, query: SparkPlan)
case class WriteToContinuousDataSourceExec(write: StreamingWrite, query: SparkPlan,
customMetrics: Seq[CustomMetric])
extends UnaryExecNode with Logging {
override def child: SparkPlan = query
override def output: Seq[Attribute] = Nil
override lazy val metrics = customMetrics.map { customMetric =>
customMetric.name() -> SQLMetrics.createV2CustomMetric(sparkContext, customMetric)
}.toMap
override protected def doExecute(): RDD[InternalRow] = {
val queryRdd = query.execute()
val writerFactory = write.createStreamingWriterFactory(
PhysicalWriteInfoImpl(queryRdd.getNumPartitions))
val rdd = new ContinuousWriteRDD(queryRdd, writerFactory)
val rdd = new ContinuousWriteRDD(queryRdd, writerFactory, metrics)
logInfo(s"Start processing data source write support: $write. " +
s"The input RDD has ${rdd.partitions.length} partitions.")

View file

@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming.sources
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
import org.apache.spark.sql.connector.metric.CustomMetric
import org.apache.spark.sql.connector.write.streaming.StreamingWrite
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2}
@ -31,13 +32,14 @@ import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, Writ
case class WriteToMicroBatchDataSource(
relation: Option[DataSourceV2Relation],
write: StreamingWrite,
query: LogicalPlan)
query: LogicalPlan,
customMetrics: Seq[CustomMetric])
extends UnaryNode {
override def child: LogicalPlan = query
override def output: Seq[Attribute] = Nil
def createPlan(batchId: Long): WriteToDataSourceV2 = {
WriteToDataSourceV2(relation, new MicroBatchWrite(batchId, write), query)
WriteToDataSourceV2(relation, new MicroBatchWrite(batchId, write), query, customMetrics)
}
override protected def withNewChildInternal(newChild: LogicalPlan): WriteToMicroBatchDataSource =

View file

@ -65,8 +65,8 @@ class SimpleWritableDataSource extends TestingV2Source {
class MyWriteBuilder(path: String, info: LogicalWriteInfo)
extends WriteBuilder with SupportsTruncate {
private val queryId: String = info.queryId()
private var needTruncate = false
protected val queryId: String = info.queryId()
protected var needTruncate = false
override def truncate(): WriteBuilder = {
this.needTruncate = true
@ -127,8 +127,8 @@ class SimpleWritableDataSource extends TestingV2Source {
class MyTable(options: CaseInsensitiveStringMap)
extends SimpleBatchTable with SupportsWrite {
private val path = options.get("path")
private val conf = SparkContext.getActive.get.hadoopConfiguration
protected val path = options.get("path")
protected val conf = SparkContext.getActive.get.hadoopConfiguration
override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
new MyScanBuilder(new Path(path).toUri.toString, conf)

View file

@ -0,0 +1,96 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources
import java.util.Collections
import org.scalatest.BeforeAndAfter
import org.scalatest.time.SpanSugar._
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog}
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.StructType
class FileFormatDataWriterMetricSuite
extends QueryTest with SharedSparkSession with BeforeAndAfter {
import testImplicits._
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
before {
spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)
}
after {
spark.sessionState.catalogManager.reset()
spark.sessionState.conf.clear()
}
private def testMetricOnDSv2(func: String => Unit, checker: Map[Long, String] => Unit) {
withTable("testcat.table_name") {
val statusStore = spark.sharedState.statusStore
val oldCount = statusStore.executionsList().size
val testCatalog = spark.sessionState.catalogManager.catalog("testcat").asTableCatalog
testCatalog.createTable(
Identifier.of(Array(), "table_name"),
new StructType().add("i", "int"),
Array.empty, Collections.emptyMap[String, String])
func("testcat.table_name")
// Wait until the new execution is started and being tracked.
eventually(timeout(10.seconds), interval(10.milliseconds)) {
assert(statusStore.executionsCount() >= oldCount)
}
// Wait for listener to finish computing the metrics for the execution.
eventually(timeout(10.seconds), interval(10.milliseconds)) {
assert(statusStore.executionsList().nonEmpty &&
statusStore.executionsList().last.metricValues != null)
}
val execId = statusStore.executionsList().last.executionId
val metrics = statusStore.executionMetrics(execId)
checker(metrics)
}
}
test("Report metrics from Datasource v2 write: append") {
testMetricOnDSv2(table => {
val df = sql("select 1 as i")
val v2Writer = df.writeTo(table)
v2Writer.append()
}, metrics => {
val customMetric = metrics.find(_._2 == "in-memory rows: 1")
assert(customMetric.isDefined)
})
}
test("Report metrics from Datasource v2 write: overwrite") {
testMetricOnDSv2(table => {
val df = Seq(1, 2, 3).toDF("i")
val v2Writer = df.writeTo(table)
v2Writer.overwrite(lit(true))
}, metrics => {
val customMetric = metrics.find(_._2 == "in-memory rows: 3")
assert(customMetric.isDefined)
})
}
}

View file

@ -0,0 +1,96 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources
import java.util.Collections
import org.scalatest.BeforeAndAfter
import org.scalatest.time.SpanSugar._
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog}
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.StructType
class InMemoryTableMetricSuite
extends QueryTest with SharedSparkSession with BeforeAndAfter {
import testImplicits._
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
before {
spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)
}
after {
spark.sessionState.catalogManager.reset()
spark.sessionState.conf.clear()
}
private def testMetricOnDSv2(func: String => Unit, checker: Map[Long, String] => Unit) {
withTable("testcat.table_name") {
val statusStore = spark.sharedState.statusStore
val oldCount = statusStore.executionsList().size
val testCatalog = spark.sessionState.catalogManager.catalog("testcat").asTableCatalog
testCatalog.createTable(
Identifier.of(Array(), "table_name"),
new StructType().add("i", "int"),
Array.empty, Collections.emptyMap[String, String])
func("testcat.table_name")
// Wait until the new execution is started and being tracked.
eventually(timeout(10.seconds), interval(10.milliseconds)) {
assert(statusStore.executionsCount() >= oldCount)
}
// Wait for listener to finish computing the metrics for the execution.
eventually(timeout(10.seconds), interval(10.milliseconds)) {
assert(statusStore.executionsList().nonEmpty &&
statusStore.executionsList().last.metricValues != null)
}
val execId = statusStore.executionsList().last.executionId
val metrics = statusStore.executionMetrics(execId)
checker(metrics)
}
}
test("Report metrics from Datasource v2 write: append") {
testMetricOnDSv2(table => {
val df = sql("select 1 as i")
val v2Writer = df.writeTo(table)
v2Writer.append()
}, metrics => {
val customMetric = metrics.find(_._2 == "in-memory rows: 1")
assert(customMetric.isDefined)
})
}
test("Report metrics from Datasource v2 write: overwrite") {
testMetricOnDSv2(table => {
val df = Seq(1, 2, 3).toDF("i")
val v2Writer = df.writeTo(table)
v2Writer.overwrite(lit(true))
}, metrics => {
val customMetric = metrics.find(_._2 == "in-memory rows: 3")
assert(customMetric.isDefined)
})
}
}

View file

@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.metric
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.connector.metric.{CustomAvgMetric, CustomSumMetric}
import org.apache.spark.sql.connector.metric.{CustomAvgMetric, CustomSumMetric, CustomTaskMetric}
class CustomMetricsSuite extends SparkFunSuite {
@ -52,6 +52,14 @@ class CustomMetricsSuite extends SparkFunSuite {
val metricValues2 = Array.empty[Long]
assert(metric.aggregateTaskMetrics(metricValues2) == "0")
}
test("Report unsupported metrics should be non-op") {
val taskMetric = new CustomTaskMetric {
override def name(): String = "custom_metric"
override def value(): Long = 1L
}
CustomMetrics.updateMetrics(Seq(taskMetric), Map.empty)
}
}
private[spark] class TestCustomSumMetric extends CustomSumMetric {

View file

@ -21,8 +21,11 @@ import java.util.Properties
import scala.collection.mutable.ListBuffer
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.json4s.jackson.JsonMethods._
import org.scalatest.BeforeAndAfter
import org.scalatest.time.SpanSugar._
import org.apache.spark._
import org.apache.spark.LocalSparkContext._
@ -37,9 +40,11 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.catalyst.util.quietly
import org.apache.spark.sql.connector.{RangeInputPartition, SimpleScanBuilder}
import org.apache.spark.sql.connector.{CSVDataWriter, CSVDataWriterFactory, RangeInputPartition, SimpleScanBuilder, SimpleWritableDataSource}
import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, LogicalWriteInfo, PhysicalWriteInfo, Write, WriteBuilder}
import org.apache.spark.sql.execution.{LeafExecNode, QueryExecution, SparkPlanInfo, SQLExecution}
import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecution
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
@ -49,8 +54,9 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.StaticSQLConf.UI_RETAINED_EXECUTIONS
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.status.ElementTrackingStore
import org.apache.spark.util.{AccumulatorMetadata, JsonProtocol, LongAccumulator}
import org.apache.spark.util.{AccumulatorMetadata, JsonProtocol, LongAccumulator, SerializableConfiguration}
import org.apache.spark.util.kvstore.InMemoryStore
@ -852,6 +858,33 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils
assert(metrics.contains(expectedMetric.id))
assert(metrics(expectedMetric.id) === expectedValue)
}
test("SPARK-36030: Report metrics from Datasource v2 write") {
withTempDir { dir =>
val statusStore = spark.sharedState.statusStore
val oldCount = statusStore.executionsList().size
val cls = classOf[CustomMetricsDataSource].getName
spark.range(10).select('id as 'i, -'id as 'j).write.format(cls)
.option("path", dir.getCanonicalPath).mode("append").save()
// Wait until the new execution is started and being tracked.
eventually(timeout(10.seconds), interval(10.milliseconds)) {
assert(statusStore.executionsCount() >= oldCount)
}
// Wait for listener to finish computing the metrics for the execution.
eventually(timeout(10.seconds), interval(10.milliseconds)) {
assert(statusStore.executionsList().nonEmpty &&
statusStore.executionsList().last.metricValues != null)
}
val execId = statusStore.executionsList().last.executionId
val metrics = statusStore.executionMetrics(execId)
val customMetric = metrics.find(_._2 == "custom_metric: 12345, 12345")
assert(customMetric.isDefined)
}
}
}
@ -973,3 +1006,68 @@ class CustomMetricScanBuilder extends SimpleScanBuilder {
override def createReaderFactory(): PartitionReaderFactory = CustomMetricReaderFactory
}
class CustomMetricsCSVDataWriter(fs: FileSystem, file: Path) extends CSVDataWriter(fs, file) {
override def currentMetricsValues(): Array[CustomTaskMetric] = {
val metric = new CustomTaskMetric {
override def name(): String = "custom_metric"
override def value(): Long = 12345
}
Array(metric)
}
}
class CustomMetricsWriterFactory(path: String, jobId: String, conf: SerializableConfiguration)
extends CSVDataWriterFactory(path, jobId, conf) {
override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = {
val jobPath = new Path(new Path(path, "_temporary"), jobId)
val filePath = new Path(jobPath, s"$jobId-$partitionId-$taskId")
val fs = filePath.getFileSystem(conf.value)
new CustomMetricsCSVDataWriter(fs, filePath)
}
}
class CustomMetricsDataSource extends SimpleWritableDataSource {
class CustomMetricBatchWrite(queryId: String, path: String, conf: Configuration)
extends MyBatchWrite(queryId, path, conf) {
override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = {
new CustomMetricsWriterFactory(path, queryId, new SerializableConfiguration(conf))
}
}
class CustomMetricWriteBuilder(path: String, info: LogicalWriteInfo)
extends MyWriteBuilder(path, info) {
override def build(): Write = {
new Write {
override def toBatch: BatchWrite = {
val hadoopPath = new Path(path)
val hadoopConf = SparkContext.getActive.get.hadoopConfiguration
val fs = hadoopPath.getFileSystem(hadoopConf)
if (needTruncate) {
fs.delete(hadoopPath, true)
}
val pathStr = hadoopPath.toUri.toString
new CustomMetricBatchWrite(queryId, pathStr, hadoopConf)
}
override def supportedCustomMetrics(): Array[CustomMetric] = {
Array(new SimpleCustomMetric)
}
}
}
}
class CustomMetricTable(options: CaseInsensitiveStringMap) extends MyTable(options) {
override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
new CustomMetricWriteBuilder(path, info)
}
}
override def getTable(options: CaseInsensitiveStringMap): Table = {
new CustomMetricTable(options)
}
}