From 86d1fb469892f878fca6622527a19801f0e45ca8 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Tue, 20 Jul 2021 20:20:35 -0700
Subject: [PATCH] [SPARK-36030][SQL] Support DS v2 metrics at writing path

### What changes were proposed in this pull request?

We added the interface for DS v2 metrics in SPARK-34366, but only for the reading path. This patch extends the metrics interface to the writing path. A short connector-side sketch of the new hooks is appended after the patch.

### Why are the changes needed?

To complete DS v2 metrics interface support on the writing path.

### Does this PR introduce _any_ user-facing change?

No. For developers, yes, as this adds metrics support to the DS v2 writing path.

### How was this patch tested?

Added tests.

Closes #33239 from viirya/v2-write-metrics.

Authored-by: Liang-Chi Hsieh
Signed-off-by: Liang-Chi Hsieh
(cherry picked from commit 2653201b0a50a651ebc0c4e1fabb47d32dee77c4)
Signed-off-by: Liang-Chi Hsieh
---
 .../apache/spark/sql/connector/read/Scan.java | 3 +-
 .../spark/sql/connector/write/DataWriter.java | 9 ++
 .../spark/sql/connector/write/Write.java | 9 ++
 .../sql/connector/catalog/InMemoryTable.scala | 22 ++++
 .../datasources/FileFormatDataWriter.scala | 51 ++++++---
 .../datasources/v2/DataSourceV2Strategy.scala | 8 +-
 .../v2/WriteToDataSourceV2Exec.scala | 36 ++++++-
 .../sql/execution/metric/CustomMetrics.scala | 5 +-
 .../streaming/MicroBatchExecution.scala | 4 +-
 .../execution/streaming/StreamExecution.scala | 6 +-
 .../continuous/ContinuousExecution.scala | 6 +-
 .../continuous/ContinuousWriteRDD.scala | 11 +-
 .../WriteToContinuousDataSource.scala | 4 +-
 .../WriteToContinuousDataSourceExec.scala | 11 +-
 .../sources/WriteToMicroBatchDataSource.scala | 6 +-
 .../connector/SimpleWritableDataSource.scala | 8 +-
 .../FileFormatDataWriterMetricSuite.scala | 96 +++++++++++++++++
 .../InMemoryTableMetricSuite.scala | 96 +++++++++++++++++
 .../execution/metric/CustomMetricsSuite.scala | 10 +-
 .../ui/SQLAppStatusListenerSuite.scala | 102 +++++++++++++++++-
 20 files changed, 456 insertions(+), 47 deletions(-)
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriterMetricSuite.scala
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java
index 78684b3705..d161de92eb 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java
@@ -113,7 +113,6 @@ public interface Scan {
   * By default it returns empty array.
*/ default CustomMetric[] supportedCustomMetrics() { - CustomMetric[] NO_METRICS = {}; - return NO_METRICS; + return new CustomMetric[]{}; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DataWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DataWriter.java index 1c07480148..6a1cee181b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DataWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DataWriter.java @@ -21,6 +21,7 @@ import java.io.Closeable; import java.io.IOException; import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.metric.CustomTaskMetric; /** * A data writer returned by {@link DataWriterFactory#createWriter(int, long)} and is @@ -104,4 +105,12 @@ public interface DataWriter extends Closeable { * @throws IOException if failure happens during disk/network IO like writing files. */ void abort() throws IOException; + + /** + * Returns an array of custom task metrics. By default it returns empty array. Note that it is + * not recommended to put heavy logic in this method as it may affect writing performance. + */ + default CustomTaskMetric[] currentMetricsValues() { + return new CustomTaskMetric[]{}; + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/Write.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/Write.java index 873680415d..7da5d0c83f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/Write.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/Write.java @@ -20,6 +20,7 @@ package org.apache.spark.sql.connector.write; import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.connector.catalog.Table; import org.apache.spark.sql.connector.catalog.TableCapability; +import org.apache.spark.sql.connector.metric.CustomMetric; import org.apache.spark.sql.connector.write.streaming.StreamingWrite; /** @@ -62,4 +63,12 @@ public interface Write { default StreamingWrite toStreaming() { throw new UnsupportedOperationException(description() + ": Streaming write is not supported"); } + + /** + * Returns an array of supported custom metrics with name and description. + * By default it returns empty array. 
+ */ + default CustomMetric[] supportedCustomMetrics() { + return new CustomMetric[]{}; + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala index 6f7f761146..2f3c5a3853 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow} import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils} import org.apache.spark.sql.connector.distributions.{Distribution, Distributions} import org.apache.spark.sql.connector.expressions._ +import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric} import org.apache.spark.sql.connector.read._ import org.apache.spark.sql.connector.write._ import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite} @@ -344,6 +345,10 @@ class InMemoryTable( case exc: StreamingNotSupportedOperation => exc.throwsException() case s => s } + + override def supportedCustomMetrics(): Array[CustomMetric] = { + Array(new InMemorySimpleCustomMetric) + } } } } @@ -604,4 +609,21 @@ private class BufferWriter extends DataWriter[InternalRow] { override def abort(): Unit = {} override def close(): Unit = {} + + override def currentMetricsValues(): Array[CustomTaskMetric] = { + val metric = new CustomTaskMetric { + override def name(): String = "in_memory_buffer_rows" + + override def value(): Long = buffer.rows.size + } + Array(metric) + } +} + +class InMemorySimpleCustomMetric extends CustomMetric { + override def name(): String = "in_memory_buffer_rows" + override def description(): String = "number of rows in buffer" + override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = { + s"in-memory rows: ${taskMetrics.sum}" + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala index 7e5a8cce27..365a9036e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage} import org.apache.spark.sql.execution.datasources.FileFormatWriter.ConcurrentOutputWriterSpec +import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StringType import org.apache.spark.util.SerializableConfiguration @@ -41,7 +42,8 @@ import org.apache.spark.util.SerializableConfiguration abstract class FileFormatDataWriter( description: WriteJobDescription, taskAttemptContext: TaskAttemptContext, - committer: FileCommitProtocol) extends DataWriter[InternalRow] { + committer: FileCommitProtocol, + customMetrics: Map[String, SQLMetric]) extends DataWriter[InternalRow] { /** * Max number of files a single task writes out due to file size. In most cases the number of * files written should be very small. 
This is just a safe guard to protect some really bad @@ -76,12 +78,21 @@ abstract class FileFormatDataWriter( /** Writes a record. */ def write(record: InternalRow): Unit + def writeWithMetrics(record: InternalRow, count: Long): Unit = { + if (count % CustomMetrics.NUM_ROWS_PER_UPDATE == 0) { + CustomMetrics.updateMetrics(currentMetricsValues, customMetrics) + } + write(record) + } /** Write an iterator of records. */ def writeWithIterator(iterator: Iterator[InternalRow]): Unit = { + var count = 0L while (iterator.hasNext) { - write(iterator.next()) + writeWithMetrics(iterator.next(), count) + count += 1 } + CustomMetrics.updateMetrics(currentMetricsValues, customMetrics) } /** @@ -113,8 +124,9 @@ abstract class FileFormatDataWriter( class EmptyDirectoryDataWriter( description: WriteJobDescription, taskAttemptContext: TaskAttemptContext, - committer: FileCommitProtocol -) extends FileFormatDataWriter(description, taskAttemptContext, committer) { + committer: FileCommitProtocol, + customMetrics: Map[String, SQLMetric] = Map.empty +) extends FileFormatDataWriter(description, taskAttemptContext, committer, customMetrics) { override def write(record: InternalRow): Unit = {} } @@ -122,8 +134,9 @@ class EmptyDirectoryDataWriter( class SingleDirectoryDataWriter( description: WriteJobDescription, taskAttemptContext: TaskAttemptContext, - committer: FileCommitProtocol) - extends FileFormatDataWriter(description, taskAttemptContext, committer) { + committer: FileCommitProtocol, + customMetrics: Map[String, SQLMetric] = Map.empty) + extends FileFormatDataWriter(description, taskAttemptContext, committer, customMetrics) { private var fileCounter: Int = _ private var recordsInFile: Long = _ // Initialize currentWriter and statsTrackers @@ -169,8 +182,9 @@ class SingleDirectoryDataWriter( abstract class BaseDynamicPartitionDataWriter( description: WriteJobDescription, taskAttemptContext: TaskAttemptContext, - committer: FileCommitProtocol) - extends FileFormatDataWriter(description, taskAttemptContext, committer) { + committer: FileCommitProtocol, + customMetrics: Map[String, SQLMetric]) + extends FileFormatDataWriter(description, taskAttemptContext, committer, customMetrics) { /** Flag saying whether or not the data to be written out is partitioned. */ protected val isPartitioned = description.partitionColumns.nonEmpty @@ -314,8 +328,10 @@ abstract class BaseDynamicPartitionDataWriter( class DynamicPartitionDataSingleWriter( description: WriteJobDescription, taskAttemptContext: TaskAttemptContext, - committer: FileCommitProtocol) - extends BaseDynamicPartitionDataWriter(description, taskAttemptContext, committer) { + committer: FileCommitProtocol, + customMetrics: Map[String, SQLMetric] = Map.empty) + extends BaseDynamicPartitionDataWriter(description, taskAttemptContext, committer, + customMetrics) { private var currentPartitionValues: Option[UnsafeRow] = None private var currentBucketId: Option[Int] = None @@ -361,8 +377,9 @@ class DynamicPartitionDataConcurrentWriter( description: WriteJobDescription, taskAttemptContext: TaskAttemptContext, committer: FileCommitProtocol, - concurrentOutputWriterSpec: ConcurrentOutputWriterSpec) - extends BaseDynamicPartitionDataWriter(description, taskAttemptContext, committer) + concurrentOutputWriterSpec: ConcurrentOutputWriterSpec, + customMetrics: Map[String, SQLMetric] = Map.empty) + extends BaseDynamicPartitionDataWriter(description, taskAttemptContext, committer, customMetrics) with Logging { /** Wrapper class to index a unique concurrent output writer. 
*/ @@ -452,17 +469,23 @@ class DynamicPartitionDataConcurrentWriter( * Write iterator of records with concurrent writers. */ override def writeWithIterator(iterator: Iterator[InternalRow]): Unit = { + var count = 0L while (iterator.hasNext && !sorted) { - write(iterator.next()) + writeWithMetrics(iterator.next(), count) + count += 1 } + CustomMetrics.updateMetrics(currentMetricsValues, customMetrics) if (iterator.hasNext) { + count = 0L clearCurrentWriterStatus() val sorter = concurrentOutputWriterSpec.createSorter() val sortIterator = sorter.sort(iterator.asInstanceOf[Iterator[UnsafeRow]]) while (sortIterator.hasNext) { - write(sortIterator.next()) + writeWithMetrics(sortIterator.next(), count) + count += 1 } + CustomMetrics.updateMetrics(currentMetricsValues, customMetrics) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 7be13791ce..1ab554f520 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -140,12 +140,12 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat // Add a Project here to make sure we produce unsafe rows. withProjectAndFilter(p, f, scanExec, !scanExec.supportsColumnar) :: Nil - case WriteToDataSourceV2(relationOpt, writer, query) => + case WriteToDataSourceV2(relationOpt, writer, query, customMetrics) => val invalidateCacheFunc: () => Unit = () => relationOpt match { case Some(r) => session.sharedState.cacheManager.uncacheQuery(session, r, cascade = true) case None => () } - WriteToDataSourceV2Exec(writer, invalidateCacheFunc, planLater(query)) :: Nil + WriteToDataSourceV2Exec(writer, invalidateCacheFunc, planLater(query), customMetrics) :: Nil case CreateV2Table(catalog, ident, schema, parts, props, ifNotExists) => val propsWithOwner = CatalogV2Util.withDefaultOwnership(props) @@ -260,8 +260,8 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat throw QueryCompilationErrors.deleteOnlySupportedWithV2TablesError() } - case WriteToContinuousDataSource(writer, query) => - WriteToContinuousDataSourceExec(writer, planLater(query)) :: Nil + case WriteToContinuousDataSource(writer, query, customMetrics) => + WriteToContinuousDataSourceExec(writer, planLater(query), customMetrics) :: Nil case DescribeNamespace(ResolvedNamespace(catalog, ns), extended, output) => DescribeNamespaceExec(output, catalog.asNamespaceCatalog, ns, extended) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 7179eebbd4..ab9d93805e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -32,9 +32,11 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode} import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.catalog.{Identifier, StagedTable, StagingTableCatalog, SupportsWrite, Table, TableCatalog} import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.connector.metric.CustomMetric 
import org.apache.spark.sql.connector.write.{BatchWrite, DataWriterFactory, LogicalWriteInfoImpl, PhysicalWriteInfoImpl, V1Write, Write, WriterCommitMessage} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric, SQLMetrics} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.{LongAccumulator, Utils} @@ -46,7 +48,8 @@ import org.apache.spark.util.{LongAccumulator, Utils} case class WriteToDataSourceV2( relation: Option[DataSourceV2Relation], batchWrite: BatchWrite, - query: LogicalPlan) extends UnaryNode { + query: LogicalPlan, + customMetrics: Seq[CustomMetric]) extends UnaryNode { override def child: LogicalPlan = query override def output: Seq[Attribute] = Nil override protected def withNewChildInternal(newChild: LogicalPlan): WriteToDataSourceV2 = @@ -276,7 +279,12 @@ case class OverwritePartitionsDynamicExec( case class WriteToDataSourceV2Exec( batchWrite: BatchWrite, refreshCache: () => Unit, - query: SparkPlan) extends V2TableWriteExec { + query: SparkPlan, + writeMetrics: Seq[CustomMetric]) extends V2TableWriteExec { + + override val customMetrics: Map[String, SQLMetric] = writeMetrics.map { customMetric => + customMetric.name() -> SQLMetrics.createV2CustomMetric(sparkContext, customMetric) + }.toMap override protected def run(): Seq[InternalRow] = { val writtenRows = writeWithV2(batchWrite) @@ -292,6 +300,11 @@ trait V2ExistingTableWriteExec extends V2TableWriteExec { def refreshCache: () => Unit def write: Write + override val customMetrics: Map[String, SQLMetric] = + write.supportedCustomMetrics().map { customMetric => + customMetric.name() -> SQLMetrics.createV2CustomMetric(sparkContext, customMetric) + }.toMap + override protected def run(): Seq[InternalRow] = { val writtenRows = writeWithV2(write.toBatch) refreshCache() @@ -310,6 +323,10 @@ trait V2TableWriteExec extends V2CommandExec with UnaryExecNode { override def child: SparkPlan = query override def output: Seq[Attribute] = Nil + protected val customMetrics: Map[String, SQLMetric] = Map.empty + + override lazy val metrics = customMetrics + protected def writeWithV2(batchWrite: BatchWrite): Seq[InternalRow] = { val rdd: RDD[InternalRow] = { val tempRdd = query.execute() @@ -330,11 +347,15 @@ trait V2TableWriteExec extends V2CommandExec with UnaryExecNode { logInfo(s"Start processing data source write support: $batchWrite. " + s"The input RDD has ${messages.length} partitions.") + // Avoid object not serializable issue. 
+ val writeMetrics: Map[String, SQLMetric] = customMetrics + try { sparkContext.runJob( rdd, (context: TaskContext, iter: Iterator[InternalRow]) => - DataWritingSparkTask.run(writerFactory, context, iter, useCommitCoordinator), + DataWritingSparkTask.run(writerFactory, context, iter, useCommitCoordinator, + writeMetrics), rdd.partitions.indices, (index, result: DataWritingSparkTaskResult) => { val commitMessage = result.writerCommitMessage @@ -376,7 +397,8 @@ object DataWritingSparkTask extends Logging { writerFactory: DataWriterFactory, context: TaskContext, iter: Iterator[InternalRow], - useCommitCoordinator: Boolean): DataWritingSparkTaskResult = { + useCommitCoordinator: Boolean, + customMetrics: Map[String, SQLMetric]): DataWritingSparkTaskResult = { val stageId = context.stageId() val stageAttempt = context.stageAttemptNumber() val partId = context.partitionId() @@ -388,11 +410,17 @@ object DataWritingSparkTask extends Logging { // write the data and commit this writer. Utils.tryWithSafeFinallyAndFailureCallbacks(block = { while (iter.hasNext) { + if (count % CustomMetrics.NUM_ROWS_PER_UPDATE == 0) { + CustomMetrics.updateMetrics(dataWriter.currentMetricsValues, customMetrics) + } + // Count is here. count += 1 dataWriter.write(iter.next()) } + CustomMetrics.updateMetrics(dataWriter.currentMetricsValues, customMetrics) + val msg = if (useCommitCoordinator) { val coordinator = SparkEnv.get.outputCommitCoordinator val commitAuthorized = coordinator.canCommit(stageId, stageAttempt, partId, attemptId) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/CustomMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/CustomMetrics.scala index 3e6cad2676..222a705631 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/CustomMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/CustomMetrics.scala @@ -45,13 +45,14 @@ object CustomMetrics { } /** - * Updates given custom metrics. + * Updates given custom metrics. If `currentMetricsValues` has a metric that does not exist + * in the `customMetrics` map, it is a no-op. */ def updateMetrics( currentMetricsValues: Seq[CustomTaskMetric], customMetrics: Map[String, SQLMetric]): Unit = { currentMetricsValues.foreach { metric => - customMetrics(metric.name()).set(metric.value()) + customMetrics.get(metric.name()).map(_.set(metric.value())) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index c31307f5ac..47888b70ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -137,11 +137,11 @@ class MicroBatchExecution( // TODO (SPARK-27484): we should add the writing node before the plan is analyzed.
sink match { case s: SupportsWrite => - val streamingWrite = createStreamingWrite(s, extraOptions, _logicalPlan) + val (streamingWrite, customMetrics) = createStreamingWrite(s, extraOptions, _logicalPlan) val relationOpt = plan.catalogAndIdent.map { case (catalog, ident) => DataSourceV2Relation.create(s, Some(catalog), Some(ident)) } - WriteToMicroBatchDataSource(relationOpt, streamingWrite, _logicalPlan) + WriteToMicroBatchDataSource(relationOpt, streamingWrite, _logicalPlan, customMetrics) case _ => _logicalPlan } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 400906bb79..624043d4a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table} +import org.apache.spark.sql.connector.metric.CustomMetric import org.apache.spark.sql.connector.read.streaming.{Offset => OffsetV2, ReadLimit, SparkDataStream} import org.apache.spark.sql.connector.write.{LogicalWriteInfoImpl, SupportsTruncate} import org.apache.spark.sql.connector.write.streaming.StreamingWrite @@ -581,7 +582,7 @@ abstract class StreamExecution( protected def createStreamingWrite( table: SupportsWrite, options: Map[String, String], - inputPlan: LogicalPlan): StreamingWrite = { + inputPlan: LogicalPlan): (StreamingWrite, Seq[CustomMetric]) = { val info = LogicalWriteInfoImpl( queryId = id.toString, inputPlan.schema, @@ -602,7 +603,8 @@ abstract class StreamExecution( table.name + " does not support Update mode.") writeBuilder.asInstanceOf[SupportsStreamingUpdateAsAppend].build() } - write.toStreaming + + (write.toStreaming, write.supportedCustomMetrics().toSeq) } protected def purge(threshold: Long): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 5e40860c66..c4bef706bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -85,9 +85,9 @@ class ContinuousExecution( uniqueSources = sources.distinct.map(s => s -> ReadLimit.allAvailable()).toMap // TODO (SPARK-27484): we should add the writing node before the plan is analyzed. 
- WriteToContinuousDataSource( - createStreamingWrite( - plan.sink.asInstanceOf[SupportsWrite], extraOptions, _logicalPlan), _logicalPlan) + val (streamingWrite, customMetrics) = createStreamingWrite( + plan.sink.asInstanceOf[SupportsWrite], extraOptions, _logicalPlan) + WriteToContinuousDataSource(streamingWrite, _logicalPlan, customMetrics) } private val triggerExecutor = trigger match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala index 909dda57ee..e2a1f412dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala @@ -22,6 +22,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.write.DataWriter import org.apache.spark.sql.connector.write.streaming.StreamingDataWriterFactory +import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric} import org.apache.spark.util.Utils /** @@ -32,8 +33,8 @@ import org.apache.spark.util.Utils * * We keep repeating prev.compute() and writing new epochs until the query is shut down. */ -class ContinuousWriteRDD(var prev: RDD[InternalRow], writerFactory: StreamingDataWriterFactory) - extends RDD[Unit](prev) { +class ContinuousWriteRDD(var prev: RDD[InternalRow], writerFactory: StreamingDataWriterFactory, + customMetrics: Map[String, SQLMetric]) extends RDD[Unit](prev) { override val partitioner = prev.partitioner @@ -55,9 +56,15 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writerFactory: StreamingDat context.partitionId(), context.taskAttemptId(), EpochTracker.getCurrentEpoch.get) + var count = 0L while (dataIterator.hasNext) { + if (count % CustomMetrics.NUM_ROWS_PER_UPDATE == 0) { + CustomMetrics.updateMetrics(dataWriter.currentMetricsValues, customMetrics) + } + count += 1 dataWriter.write(dataIterator.next()) } + CustomMetrics.updateMetrics(dataWriter.currentMetricsValues, customMetrics) logInfo(s"Writer for partition ${context.partitionId()} " + s"in epoch ${EpochTracker.getCurrentEpoch.get} is committing.") val msg = dataWriter.commit() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala index ceb52f520d..d5e6bab05c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala @@ -19,12 +19,14 @@ package org.apache.spark.sql.execution.streaming.continuous import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode} +import org.apache.spark.sql.connector.metric.CustomMetric import org.apache.spark.sql.connector.write.streaming.StreamingWrite /** * The logical plan for writing data in a continuous stream. 
*/ -case class WriteToContinuousDataSource(write: StreamingWrite, query: LogicalPlan) +case class WriteToContinuousDataSource(write: StreamingWrite, query: LogicalPlan, + customMetrics: Seq[CustomMetric]) extends UnaryNode { override def child: LogicalPlan = query override def output: Seq[Attribute] = Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala index a7c59ce831..45f50d2527 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala @@ -23,26 +23,33 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.metric.CustomMetric import org.apache.spark.sql.connector.write.PhysicalWriteInfoImpl import org.apache.spark.sql.connector.write.streaming.StreamingWrite import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.streaming.StreamExecution /** * The physical plan for writing data into a continuous processing [[StreamingWrite]]. */ -case class WriteToContinuousDataSourceExec(write: StreamingWrite, query: SparkPlan) +case class WriteToContinuousDataSourceExec(write: StreamingWrite, query: SparkPlan, + customMetrics: Seq[CustomMetric]) extends UnaryExecNode with Logging { override def child: SparkPlan = query override def output: Seq[Attribute] = Nil + override lazy val metrics = customMetrics.map { customMetric => + customMetric.name() -> SQLMetrics.createV2CustomMetric(sparkContext, customMetric) + }.toMap + override protected def doExecute(): RDD[InternalRow] = { val queryRdd = query.execute() val writerFactory = write.createStreamingWriterFactory( PhysicalWriteInfoImpl(queryRdd.getNumPartitions)) - val rdd = new ContinuousWriteRDD(queryRdd, writerFactory) + val rdd = new ContinuousWriteRDD(queryRdd, writerFactory, metrics) logInfo(s"Start processing data source write support: $write. 
" + s"The input RDD has ${rdd.partitions.length} partitions.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala index f0422fb4bd..b8b85a7ded 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode} +import org.apache.spark.sql.connector.metric.CustomMetric import org.apache.spark.sql.connector.write.streaming.StreamingWrite import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2} @@ -31,13 +32,14 @@ import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, Writ case class WriteToMicroBatchDataSource( relation: Option[DataSourceV2Relation], write: StreamingWrite, - query: LogicalPlan) + query: LogicalPlan, + customMetrics: Seq[CustomMetric]) extends UnaryNode { override def child: LogicalPlan = query override def output: Seq[Attribute] = Nil def createPlan(batchId: Long): WriteToDataSourceV2 = { - WriteToDataSourceV2(relation, new MicroBatchWrite(batchId, write), query) + WriteToDataSourceV2(relation, new MicroBatchWrite(batchId, write), query, customMetrics) } override protected def withNewChildInternal(newChild: LogicalPlan): WriteToMicroBatchDataSource = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala index 49a6742a85..bb2acecc78 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala @@ -65,8 +65,8 @@ class SimpleWritableDataSource extends TestingV2Source { class MyWriteBuilder(path: String, info: LogicalWriteInfo) extends WriteBuilder with SupportsTruncate { - private val queryId: String = info.queryId() - private var needTruncate = false + protected val queryId: String = info.queryId() + protected var needTruncate = false override def truncate(): WriteBuilder = { this.needTruncate = true @@ -127,8 +127,8 @@ class SimpleWritableDataSource extends TestingV2Source { class MyTable(options: CaseInsensitiveStringMap) extends SimpleBatchTable with SupportsWrite { - private val path = options.get("path") - private val conf = SparkContext.getActive.get.hadoopConfiguration + protected val path = options.get("path") + protected val conf = SparkContext.getActive.get.hadoopConfiguration override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MyScanBuilder(new Path(path).toUri.toString, conf) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriterMetricSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriterMetricSuite.scala new file mode 100644 index 0000000000..0a0a27e1f8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriterMetricSuite.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor 
license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources + +import java.util.Collections + +import org.scalatest.BeforeAndAfter +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog} +import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.StructType + +class FileFormatDataWriterMetricSuite + extends QueryTest with SharedSparkSession with BeforeAndAfter { + import testImplicits._ + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + before { + spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName) + } + + after { + spark.sessionState.catalogManager.reset() + spark.sessionState.conf.clear() + } + + private def testMetricOnDSv2(func: String => Unit, checker: Map[Long, String] => Unit) { + withTable("testcat.table_name") { + val statusStore = spark.sharedState.statusStore + val oldCount = statusStore.executionsList().size + + val testCatalog = spark.sessionState.catalogManager.catalog("testcat").asTableCatalog + + testCatalog.createTable( + Identifier.of(Array(), "table_name"), + new StructType().add("i", "int"), + Array.empty, Collections.emptyMap[String, String]) + + func("testcat.table_name") + + // Wait until the new execution is started and being tracked. + eventually(timeout(10.seconds), interval(10.milliseconds)) { + assert(statusStore.executionsCount() >= oldCount) + } + + // Wait for listener to finish computing the metrics for the execution. 
+ eventually(timeout(10.seconds), interval(10.milliseconds)) { + assert(statusStore.executionsList().nonEmpty && + statusStore.executionsList().last.metricValues != null) + } + + val execId = statusStore.executionsList().last.executionId + val metrics = statusStore.executionMetrics(execId) + checker(metrics) + } + } + + test("Report metrics from Datasource v2 write: append") { + testMetricOnDSv2(table => { + val df = sql("select 1 as i") + val v2Writer = df.writeTo(table) + v2Writer.append() + }, metrics => { + val customMetric = metrics.find(_._2 == "in-memory rows: 1") + assert(customMetric.isDefined) + }) + } + + test("Report metrics from Datasource v2 write: overwrite") { + testMetricOnDSv2(table => { + val df = Seq(1, 2, 3).toDF("i") + val v2Writer = df.writeTo(table) + v2Writer.overwrite(lit(true)) + }, metrics => { + val customMetric = metrics.find(_._2 == "in-memory rows: 3") + assert(customMetric.isDefined) + }) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala new file mode 100644 index 0000000000..4a3c1141a7 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources + +import java.util.Collections + +import org.scalatest.BeforeAndAfter +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog} +import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.StructType + +class InMemoryTableMetricSuite + extends QueryTest with SharedSparkSession with BeforeAndAfter { + import testImplicits._ + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + before { + spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName) + } + + after { + spark.sessionState.catalogManager.reset() + spark.sessionState.conf.clear() + } + + private def testMetricOnDSv2(func: String => Unit, checker: Map[Long, String] => Unit) { + withTable("testcat.table_name") { + val statusStore = spark.sharedState.statusStore + val oldCount = statusStore.executionsList().size + + val testCatalog = spark.sessionState.catalogManager.catalog("testcat").asTableCatalog + + testCatalog.createTable( + Identifier.of(Array(), "table_name"), + new StructType().add("i", "int"), + Array.empty, Collections.emptyMap[String, String]) + + func("testcat.table_name") + + // Wait until the new execution is started and being tracked. 
+ eventually(timeout(10.seconds), interval(10.milliseconds)) { + assert(statusStore.executionsCount() >= oldCount) + } + + // Wait for listener to finish computing the metrics for the execution. + eventually(timeout(10.seconds), interval(10.milliseconds)) { + assert(statusStore.executionsList().nonEmpty && + statusStore.executionsList().last.metricValues != null) + } + + val execId = statusStore.executionsList().last.executionId + val metrics = statusStore.executionMetrics(execId) + checker(metrics) + } + } + + test("Report metrics from Datasource v2 write: append") { + testMetricOnDSv2(table => { + val df = sql("select 1 as i") + val v2Writer = df.writeTo(table) + v2Writer.append() + }, metrics => { + val customMetric = metrics.find(_._2 == "in-memory rows: 1") + assert(customMetric.isDefined) + }) + } + + test("Report metrics from Datasource v2 write: overwrite") { + testMetricOnDSv2(table => { + val df = Seq(1, 2, 3).toDF("i") + val v2Writer = df.writeTo(table) + v2Writer.overwrite(lit(true)) + }, metrics => { + val customMetric = metrics.find(_._2 == "in-memory rows: 3") + assert(customMetric.isDefined) + }) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/CustomMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/CustomMetricsSuite.scala index 440b0dc08e..f182499bae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/CustomMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/CustomMetricsSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.metric import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.connector.metric.{CustomAvgMetric, CustomSumMetric} +import org.apache.spark.sql.connector.metric.{CustomAvgMetric, CustomSumMetric, CustomTaskMetric} class CustomMetricsSuite extends SparkFunSuite { @@ -52,6 +52,14 @@ class CustomMetricsSuite extends SparkFunSuite { val metricValues2 = Array.empty[Long] assert(metric.aggregateTaskMetrics(metricValues2) == "0") } + + test("Report unsupported metrics should be non-op") { + val taskMetric = new CustomTaskMetric { + override def name(): String = "custom_metric" + override def value(): Long = 1L + } + CustomMetrics.updateMetrics(Seq(taskMetric), Map.empty) + } } private[spark] class TestCustomSumMetric extends CustomSumMetric { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala index c9609e8402..bcb5892c50 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala @@ -21,8 +21,11 @@ import java.util.Properties import scala.collection.mutable.ListBuffer +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} import org.json4s.jackson.JsonMethods._ import org.scalatest.BeforeAndAfter +import org.scalatest.time.SpanSugar._ import org.apache.spark._ import org.apache.spark.LocalSparkContext._ @@ -37,9 +40,11 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.util.quietly -import org.apache.spark.sql.connector.{RangeInputPartition, SimpleScanBuilder} +import org.apache.spark.sql.connector.{CSVDataWriter, CSVDataWriterFactory, 
RangeInputPartition, SimpleScanBuilder, SimpleWritableDataSource} +import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric} import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} +import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, LogicalWriteInfo, PhysicalWriteInfo, Write, WriteBuilder} import org.apache.spark.sql.execution.{LeafExecNode, QueryExecution, SparkPlanInfo, SQLExecution} import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecution import org.apache.spark.sql.execution.datasources.v2.BatchScanExec @@ -49,8 +54,9 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.UI_RETAINED_EXECUTIONS import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.status.ElementTrackingStore -import org.apache.spark.util.{AccumulatorMetadata, JsonProtocol, LongAccumulator} +import org.apache.spark.util.{AccumulatorMetadata, JsonProtocol, LongAccumulator, SerializableConfiguration} import org.apache.spark.util.kvstore.InMemoryStore @@ -852,6 +858,33 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils assert(metrics.contains(expectedMetric.id)) assert(metrics(expectedMetric.id) === expectedValue) } + + test("SPARK-36030: Report metrics from Datasource v2 write") { + withTempDir { dir => + val statusStore = spark.sharedState.statusStore + val oldCount = statusStore.executionsList().size + + val cls = classOf[CustomMetricsDataSource].getName + spark.range(10).select('id as 'i, -'id as 'j).write.format(cls) + .option("path", dir.getCanonicalPath).mode("append").save() + + // Wait until the new execution is started and being tracked. + eventually(timeout(10.seconds), interval(10.milliseconds)) { + assert(statusStore.executionsCount() >= oldCount) + } + + // Wait for listener to finish computing the metrics for the execution. 
+ eventually(timeout(10.seconds), interval(10.milliseconds)) { + assert(statusStore.executionsList().nonEmpty && + statusStore.executionsList().last.metricValues != null) + } + + val execId = statusStore.executionsList().last.executionId + val metrics = statusStore.executionMetrics(execId) + val customMetric = metrics.find(_._2 == "custom_metric: 12345, 12345") + assert(customMetric.isDefined) + } + } } @@ -973,3 +1006,68 @@ class CustomMetricScanBuilder extends SimpleScanBuilder { override def createReaderFactory(): PartitionReaderFactory = CustomMetricReaderFactory } + +class CustomMetricsCSVDataWriter(fs: FileSystem, file: Path) extends CSVDataWriter(fs, file) { + override def currentMetricsValues(): Array[CustomTaskMetric] = { + val metric = new CustomTaskMetric { + override def name(): String = "custom_metric" + override def value(): Long = 12345 + } + Array(metric) + } +} + +class CustomMetricsWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) + extends CSVDataWriterFactory(path, jobId, conf) { + + override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = { + val jobPath = new Path(new Path(path, "_temporary"), jobId) + val filePath = new Path(jobPath, s"$jobId-$partitionId-$taskId") + val fs = filePath.getFileSystem(conf.value) + new CustomMetricsCSVDataWriter(fs, filePath) + } +} + +class CustomMetricsDataSource extends SimpleWritableDataSource { + + class CustomMetricBatchWrite(queryId: String, path: String, conf: Configuration) + extends MyBatchWrite(queryId, path, conf) { + override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = { + new CustomMetricsWriterFactory(path, queryId, new SerializableConfiguration(conf)) + } + } + + class CustomMetricWriteBuilder(path: String, info: LogicalWriteInfo) + extends MyWriteBuilder(path, info) { + override def build(): Write = { + new Write { + override def toBatch: BatchWrite = { + val hadoopPath = new Path(path) + val hadoopConf = SparkContext.getActive.get.hadoopConfiguration + val fs = hadoopPath.getFileSystem(hadoopConf) + + if (needTruncate) { + fs.delete(hadoopPath, true) + } + + val pathStr = hadoopPath.toUri.toString + new CustomMetricBatchWrite(queryId, pathStr, hadoopConf) + } + + override def supportedCustomMetrics(): Array[CustomMetric] = { + Array(new SimpleCustomMetric) + } + } + } + } + + class CustomMetricTable(options: CaseInsensitiveStringMap) extends MyTable(options) { + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { + new CustomMetricWriteBuilder(path, info) + } + } + + override def getTable(options: CaseInsensitiveStringMap): Table = { + new CustomMetricTable(options) + } +}
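
### Appendix: connector-side usage sketch

A minimal sketch (not part of the patch) of how a DS v2 connector can hook into the write-path metrics added above. The class names `RowsWrittenMetric` and `CountingDataWriter` are made up for illustration; the interfaces (`CustomMetric`, `CustomTaskMetric`, `DataWriter.currentMetricsValues()`, `Write.supportedCustomMetrics()`) and the update cadence (`CustomMetrics.NUM_ROWS_PER_UPDATE`) are the ones touched in this patch.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}
import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage}

// Driver-side declaration: name, description, and how per-task values are folded into the
// string shown in the SQL UI. A connector returns this from Write.supportedCustomMetrics(),
// and the write exec node registers it via SQLMetrics.createV2CustomMetric.
class RowsWrittenMetric extends CustomMetric {
  override def name(): String = "rows_written"
  override def description(): String = "number of rows written"
  override def aggregateTaskMetrics(taskMetrics: Array[Long]): String =
    s"rows written: ${taskMetrics.sum}"
}

// Executor-side writer: the write task polls currentMetricsValues() every
// CustomMetrics.NUM_ROWS_PER_UPDATE rows and once more before commit. The task metric
// name must match a declared CustomMetric, otherwise CustomMetrics.updateMetrics ignores it.
class CountingDataWriter extends DataWriter[InternalRow] {
  private var rows = 0L

  // A real writer would also persist the record here.
  override def write(record: InternalRow): Unit = rows += 1

  override def commit(): WriterCommitMessage = new WriterCommitMessage {}
  override def abort(): Unit = {}
  override def close(): Unit = {}

  override def currentMetricsValues(): Array[CustomTaskMetric] = {
    val metric = new CustomTaskMetric {
      override def name(): String = "rows_written"
      override def value(): Long = rows
    }
    Array(metric)
  }
}
```

The same pattern carries over to streaming writes: the metrics returned by `Write.supportedCustomMetrics()` are threaded through `WriteToMicroBatchDataSource` and `WriteToContinuousDataSource` as shown in the patch.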