diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index b020a6d992..fda33cd829 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -37,12 +37,11 @@ import org.slf4j.LoggerFactory;
 import org.apache.spark.Partitioner;
 import org.apache.spark.ShuffleDependency;
 import org.apache.spark.SparkConf;
-import org.apache.spark.TaskContext;
-import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.scheduler.MapStatus;
 import org.apache.spark.scheduler.MapStatus$;
 import org.apache.spark.serializer.Serializer;
 import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
 import org.apache.spark.shuffle.IndexShuffleBlockResolver;
 import org.apache.spark.shuffle.ShuffleWriter;
 import org.apache.spark.storage.*;
@@ -79,7 +78,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
   private final int numPartitions;
   private final BlockManager blockManager;
   private final Partitioner partitioner;
-  private final ShuffleWriteMetrics writeMetrics;
+  private final ShuffleWriteMetricsReporter writeMetrics;
   private final int shuffleId;
   private final int mapId;
   private final Serializer serializer;
@@ -103,8 +102,8 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
       IndexShuffleBlockResolver shuffleBlockResolver,
       BypassMergeSortShuffleHandle<K, V> handle,
       int mapId,
-      TaskContext taskContext,
-      SparkConf conf) {
+      SparkConf conf,
+      ShuffleWriteMetricsReporter writeMetrics) {
     // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
     this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
     this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true);
@@ -114,7 +113,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
     this.shuffleId = dep.shuffleId();
     this.partitioner = dep.partitioner();
     this.numPartitions = partitioner.numPartitions();
-    this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics();
+    this.writeMetrics = writeMetrics;
     this.serializer = dep.serializer();
     this.shuffleBlockResolver = shuffleBlockResolver;
   }
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
index 1c0d664afb..6ee9d5f0ee 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
@@ -38,6 +38,7 @@ import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.memory.TooLargePageException;
 import org.apache.spark.serializer.DummySerializerInstance;
 import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
 import org.apache.spark.storage.BlockManager;
 import org.apache.spark.storage.DiskBlockObjectWriter;
 import org.apache.spark.storage.FileSegment;
@@ -75,7 +76,7 @@ final class ShuffleExternalSorter extends MemoryConsumer {
   private final TaskMemoryManager taskMemoryManager;
   private final BlockManager blockManager;
   private final TaskContext taskContext;
-  private final ShuffleWriteMetrics writeMetrics;
+  private final ShuffleWriteMetricsReporter writeMetrics;
 
   /**
    * Force this sorter to spill when there are this many elements in memory.
@@ -113,7 +114,7 @@ final class ShuffleExternalSorter extends MemoryConsumer {
       int initialSize,
       int numPartitions,
       SparkConf conf,
-      ShuffleWriteMetrics writeMetrics) {
+      ShuffleWriteMetricsReporter writeMetrics) {
     super(memoryManager,
       (int) Math.min(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, memoryManager.pageSizeBytes()),
       memoryManager.getTungstenMemoryMode());
@@ -144,7 +145,7 @@ final class ShuffleExternalSorter extends MemoryConsumer {
    */
   private void writeSortedFile(boolean isLastFile) {
 
-    final ShuffleWriteMetrics writeMetricsToUse;
+    final ShuffleWriteMetricsReporter writeMetricsToUse;
 
     if (isLastFile) {
       // We're writing the final non-spill file, so we _do_ want to count this as shuffle bytes.
@@ -241,9 +242,14 @@
       //
       // Note that we intentionally ignore the value of `writeMetricsToUse.shuffleWriteTime()`.
       // Consistent with ExternalSorter, we do not count this IO towards shuffle write time.
-      // This means that this IO time is not accounted for anywhere; SPARK-3577 will fix this.
-      writeMetrics.incRecordsWritten(writeMetricsToUse.recordsWritten());
-      taskContext.taskMetrics().incDiskBytesSpilled(writeMetricsToUse.bytesWritten());
+      // SPARK-3577 tracks the spill time separately.
+
+      // This is guaranteed to be a ShuffleWriteMetrics based on the if check in the beginning
+      // of this method.
+      writeMetrics.incRecordsWritten(
+        ((ShuffleWriteMetrics)writeMetricsToUse).recordsWritten());
+      taskContext.taskMetrics().incDiskBytesSpilled(
+        ((ShuffleWriteMetrics)writeMetricsToUse).bytesWritten());
     }
   }
 
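Not part of the patch: the writeSortedFile hunks above rely on an invariant that is easy to miss in diff form. A spill file is always written against a freshly created, concrete ShuffleWriteMetrics, and only the final output file reports directly into the caller-supplied reporter, which is what makes the downcast safe. A rough Scala restatement of that rule follows; the object and method names are invented for illustration, and it assumes an org.apache.spark package so the private[spark] constructor is visible.

package org.apache.spark.shuffle.sort

import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter

private[spark] object SpillMetricsSketch {
  // Mirrors the if/else at the top of ShuffleExternalSorter.writeSortedFile.
  def metricsForFile(
      isLastFile: Boolean,
      taskLevelReporter: ShuffleWriteMetricsReporter): ShuffleWriteMetricsReporter = {
    if (isLastFile) {
      // Final map output: bytes and records count directly as shuffle write metrics.
      taskLevelReporter
    } else {
      // Spill file: accumulate into a throwaway ShuffleWriteMetrics so the totals can later be
      // folded into recordsWritten and diskBytesSpilled, hence the cast in the hunk above.
      new ShuffleWriteMetrics()
    }
  }
}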
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index 4839d04522..4b0c743415 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -37,7 +37,6 @@ import org.slf4j.LoggerFactory;
 
 import org.apache.spark.*;
 import org.apache.spark.annotation.Private;
-import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.io.CompressionCodec;
 import org.apache.spark.io.CompressionCodec$;
 import org.apache.spark.io.NioBufferedFileInputStream;
@@ -47,6 +46,7 @@ import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.network.util.LimitedInputStream;
 import org.apache.spark.scheduler.MapStatus;
 import org.apache.spark.scheduler.MapStatus$;
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
 import org.apache.spark.serializer.SerializationStream;
 import org.apache.spark.serializer.SerializerInstance;
 import org.apache.spark.shuffle.IndexShuffleBlockResolver;
@@ -73,7 +73,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
   private final TaskMemoryManager memoryManager;
   private final SerializerInstance serializer;
   private final Partitioner partitioner;
-  private final ShuffleWriteMetrics writeMetrics;
+  private final ShuffleWriteMetricsReporter writeMetrics;
   private final int shuffleId;
   private final int mapId;
   private final TaskContext taskContext;
@@ -122,7 +122,8 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
       SerializedShuffleHandle<K, V> handle,
       int mapId,
       TaskContext taskContext,
-      SparkConf sparkConf) throws IOException {
+      SparkConf sparkConf,
+      ShuffleWriteMetricsReporter writeMetrics) throws IOException {
     final int numPartitions = handle.dependency().partitioner().numPartitions();
     if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) {
       throw new IllegalArgumentException(
@@ -138,7 +139,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
     this.shuffleId = dep.shuffleId();
     this.serializer = dep.serializer().newInstance();
     this.partitioner = dep.partitioner();
-    this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics();
+    this.writeMetrics = writeMetrics;
     this.taskContext = taskContext;
     this.sparkConf = sparkConf;
     this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);
diff --git a/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java b/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java
index 5d0555a8c2..fcba3b7344 100644
--- a/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java
+++ b/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java
@@ -21,7 +21,7 @@ import java.io.IOException;
 import java.io.OutputStream;
 
 import org.apache.spark.annotation.Private;
-import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
 
 /**
  * Intercepts write calls and tracks total time spent writing in order to update shuffle write
@@ -30,10 +30,11 @@ import org.apache.spark.executor.ShuffleWriteMetrics;
 @Private
 public final class TimeTrackingOutputStream extends OutputStream {
 
-  private final ShuffleWriteMetrics writeMetrics;
+  private final ShuffleWriteMetricsReporter writeMetrics;
   private final OutputStream outputStream;
 
-  public TimeTrackingOutputStream(ShuffleWriteMetrics writeMetrics, OutputStream outputStream) {
+  public TimeTrackingOutputStream(
+      ShuffleWriteMetricsReporter writeMetrics, OutputStream outputStream) {
     this.writeMetrics = writeMetrics;
     this.outputStream = outputStream;
   }
diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala
index 0c9da657c2..d0b0e7da07 100644
--- a/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.executor
 
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
 import org.apache.spark.util.LongAccumulator
 
 
@@ -27,7 +28,7 @@ import org.apache.spark.util.LongAccumulator
  * Operations are not thread-safe.
  */
 @DeveloperApi
-class ShuffleWriteMetrics private[spark] () extends Serializable {
+class ShuffleWriteMetrics private[spark] () extends ShuffleWriteMetricsReporter with Serializable {
   private[executor] val _bytesWritten = new LongAccumulator
   private[executor] val _recordsWritten = new LongAccumulator
   private[executor] val _writeTime = new LongAccumulator
@@ -47,13 +48,13 @@ class ShuffleWriteMetrics private[spark] () extends Serializable {
    */
   def writeTime: Long = _writeTime.sum
 
-  private[spark] def incBytesWritten(v: Long): Unit = _bytesWritten.add(v)
-  private[spark] def incRecordsWritten(v: Long): Unit = _recordsWritten.add(v)
-  private[spark] def incWriteTime(v: Long): Unit = _writeTime.add(v)
-  private[spark] def decBytesWritten(v: Long): Unit = {
+  private[spark] override def incBytesWritten(v: Long): Unit = _bytesWritten.add(v)
+  private[spark] override def incRecordsWritten(v: Long): Unit = _recordsWritten.add(v)
+  private[spark] override def incWriteTime(v: Long): Unit = _writeTime.add(v)
+  private[spark] override def decBytesWritten(v: Long): Unit = {
     _bytesWritten.setValue(bytesWritten - v)
   }
-  private[spark] def decRecordsWritten(v: Long): Unit = {
+  private[spark] override def decRecordsWritten(v: Long): Unit = {
     _recordsWritten.setValue(recordsWritten - v)
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index f2cd65fd52..5412717d61 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -95,7 +95,8 @@ private[spark] class ShuffleMapTask(
     var writer: ShuffleWriter[Any, Any] = null
     try {
       val manager = SparkEnv.get.shuffleManager
-      writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
+      writer = manager.getWriter[Any, Any](
+        dep.shuffleHandle, partitionId, context, context.taskMetrics().shuffleWriteMetrics)
       writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
       writer.stop(success = true).get
     } catch {
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
index df601cbdb2..18a743fbfa 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
@@ -38,7 +38,11 @@ private[spark] trait ShuffleManager {
       dependency: ShuffleDependency[K, V, C]): ShuffleHandle
 
   /** Get a writer for a given partition. Called on executors by map tasks. */
-  def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext): ShuffleWriter[K, V]
+  def getWriter[K, V](
+      handle: ShuffleHandle,
+      mapId: Int,
+      context: TaskContext,
+      metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V]
 
   /**
    * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive).
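Not part of the patch shown here: the ShuffleWriteMetricsReporter type that getWriter now accepts is defined in a new file that this excerpt omits. Judging from the members ShuffleWriteMetrics overrides above, the interface is roughly the following sketch (inferred, not the authoritative definition):

package org.apache.spark.shuffle

private[spark] trait ShuffleWriteMetricsReporter {
  private[spark] def incBytesWritten(v: Long): Unit
  private[spark] def incRecordsWritten(v: Long): Unit
  private[spark] def incWriteTime(v: Long): Unit
  private[spark] def decBytesWritten(v: Long): Unit
  private[spark] def decRecordsWritten(v: Long): Unit
}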
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index 4f8be198e4..b51a843a31 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -125,7 +125,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
   override def getWriter[K, V](
       handle: ShuffleHandle,
       mapId: Int,
-      context: TaskContext): ShuffleWriter[K, V] = {
+      context: TaskContext,
+      metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
     numMapsForShuffle.putIfAbsent(
       handle.shuffleId, handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps)
     val env = SparkEnv.get
@@ -138,15 +139,16 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
           unsafeShuffleHandle,
           mapId,
           context,
-          env.conf)
+          env.conf,
+          metrics)
       case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
         new BypassMergeSortShuffleWriter(
           env.blockManager,
           shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
           bypassMergeSortHandle,
           mapId,
-          context,
-          env.conf)
+          env.conf,
+          metrics)
       case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
         new SortShuffleWriter(shuffleBlockResolver, other, mapId, context)
     }
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index edae2f95fc..1b617297e0 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -33,10 +33,9 @@ import scala.util.Random
 import scala.util.control.NonFatal
 
 import com.codahale.metrics.{MetricRegistry, MetricSet}
-import com.google.common.io.CountingOutputStream
 
 import org.apache.spark._
-import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics}
+import org.apache.spark.executor.DataReadMethod
 import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.memory.{MemoryManager, MemoryMode}
 import org.apache.spark.metrics.source.Source
@@ -50,7 +49,7 @@ import org.apache.spark.network.util.TransportConf
 import org.apache.spark.rpc.RpcEnv
 import org.apache.spark.scheduler.ExecutorCacheTaskLocation
 import org.apache.spark.serializer.{SerializerInstance, SerializerManager}
-import org.apache.spark.shuffle.ShuffleManager
+import org.apache.spark.shuffle.{ShuffleManager, ShuffleWriteMetricsReporter}
 import org.apache.spark.storage.memory._
 import org.apache.spark.unsafe.Platform
 import org.apache.spark.util._
@@ -932,7 +931,7 @@ private[spark] class BlockManager(
       file: File,
       serializerInstance: SerializerInstance,
       bufferSize: Int,
-      writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = {
+      writeMetrics: ShuffleWriteMetricsReporter): DiskBlockObjectWriter = {
     val syncWrites = conf.getBoolean("spark.shuffle.sync", false)
     new DiskBlockObjectWriter(file, serializerManager, serializerInstance, bufferSize, syncWrites,
       writeMetrics, blockId)
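Not part of the patch: because BlockManager.getDiskWriter now takes the reporter interface, any ShuffleWriteMetricsReporter can back a DiskBlockObjectWriter, not only the task-owned ShuffleWriteMetrics. A minimal usage sketch; imports are omitted, and blockId, file and serializerInstance are assumed to already be in scope inside an org.apache.spark package.

val reporter: ShuffleWriteMetricsReporter =
  TaskContext.get().taskMetrics().shuffleWriteMetrics
val diskWriter: DiskBlockObjectWriter =
  SparkEnv.get.blockManager.getDiskWriter(blockId, file, serializerInstance, 32 * 1024, reporter)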
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
index a024c83d8d..17390f9c60 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
@@ -20,9 +20,9 @@ package org.apache.spark.storage
 import java.io.{BufferedOutputStream, File, FileOutputStream, OutputStream}
 import java.nio.channels.FileChannel
 
-import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.internal.Logging
 import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager}
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
 import org.apache.spark.util.Utils
 
 /**
@@ -43,7 +43,7 @@ private[spark] class DiskBlockObjectWriter(
     syncWrites: Boolean,
     // These write metrics concurrently shared with other active DiskBlockObjectWriters who
     // are themselves performing writes. All updates must be relative.
-    writeMetrics: ShuffleWriteMetrics,
+    writeMetrics: ShuffleWriteMetricsReporter,
     val blockId: BlockId = null)
   extends OutputStream
   with Logging {
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index b159200d79..eac3db0115 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -793,8 +793,8 @@ private[spark] class ExternalSorter[K, V, C](
       def nextPartition(): Int = cur._1._1
     }
 
-    logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " +
-      s" it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory")
+    logInfo(s"Task ${TaskContext.get().taskAttemptId} force spilling in-memory map to disk " +
+      s"and it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory")
     val spillFile = spillMemoryIteratorToDisk(inMemoryIterator)
     forceSpillFiles += spillFile
     val spillReader = new SpillReader(spillFile)
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index a07d0e84ea..30ad3f5575 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -162,7 +162,8 @@ public class UnsafeShuffleWriterSuite {
       new SerializedShuffleHandle<>(0, 1, shuffleDep),
       0, // map id
       taskContext,
-      conf
+      conf,
+      taskContext.taskMetrics().shuffleWriteMetrics()
     );
   }
 
@@ -521,7 +522,8 @@ public class UnsafeShuffleWriterSuite {
       new SerializedShuffleHandle<>(0, 1, shuffleDep),
       0, // map id
       taskContext,
-      conf);
+      conf,
+      taskContext.taskMetrics().shuffleWriteMetrics());
 
     // Peak memory should be monotonically increasing. More specifically, every time
     // we allocate a new page it should increase by exactly the size of the page.
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 419a26b857..35f728cd57 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -362,15 +362,19 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
     mapTrackerMaster.registerShuffle(0, 1)
 
     // first attempt -- its successful
-    val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0,
-      new TaskContextImpl(0, 0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem))
+    val context1 =
+      new TaskContextImpl(0, 0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem)
+    val writer1 = manager.getWriter[Int, Int](
+      shuffleHandle, 0, context1, context1.taskMetrics.shuffleWriteMetrics)
     val data1 = (1 to 10).map { x => x -> x}
 
     // second attempt -- also successful. We'll write out different data,
     // just to simulate the fact that the records may get written differently
     // depending on what gets spilled, what gets combined, etc.
-    val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0,
-      new TaskContextImpl(0, 0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem))
+    val context2 =
+      new TaskContextImpl(0, 0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem)
+    val writer2 = manager.getWriter[Int, Int](
+      shuffleHandle, 0, context2, context2.taskMetrics.shuffleWriteMetrics)
     val data2 = (11 to 20).map { x => x -> x}
 
     // interleave writes of both attempts -- we want to test that both attempts can occur
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
index 85ccb33471..4467c3241a 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
@@ -136,8 +136,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
       blockResolver,
       shuffleHandle,
       0, // MapId
-      taskContext,
-      conf
+      conf,
+      taskContext.taskMetrics().shuffleWriteMetrics
     )
     writer.write(Iterator.empty)
     writer.stop( /* success = */ true)
@@ -160,8 +160,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
       blockResolver,
       shuffleHandle,
       0, // MapId
-      taskContext,
-      conf
+      conf,
+      taskContext.taskMetrics().shuffleWriteMetrics
     )
     writer.write(records)
     writer.stop( /* success = */ true)
@@ -195,8 +195,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
       blockResolver,
       shuffleHandle,
       0, // MapId
-      taskContext,
-      conf
+      conf,
+      taskContext.taskMetrics().shuffleWriteMetrics
    )
 
     intercept[SparkException] {
@@ -217,8 +217,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
       blockResolver,
       shuffleHandle,
       0, // MapId
-      taskContext,
-      conf
+      conf,
+      taskContext.taskMetrics().shuffleWriteMetrics
     )
     intercept[SparkException] {
       writer.write((0 until 100000).iterator.map(i => {
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 333adb0c84..3fabec0f60 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -226,7 +226,12 @@
     ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader"),
     ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.writer.DataSourceWriter"),
     ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.v2.writer.DataWriterFactory.createWriter"),
-    ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter")
+    ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter"),
+
+    // [SPARK-26141] Enable custom metrics implementation in shuffle write
+    // Following are Java private classes
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.sort.UnsafeShuffleWriter.this"),
+    ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.TimeTrackingOutputStream.this")
   )
 
   // Exclude rules for 2.4.x
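Not part of the patch: the end-to-end effect of threading ShuffleWriteMetricsReporter through ShuffleManager.getWriter is that a shuffle implementation no longer has to reach into TaskContext for its write metrics, so the reporter can be swapped or decorated. A hypothetical delegating manager, sketched against the trait signature above (the class name is invented; it is abstract so the ShuffleManager members not shown can stay unimplemented):

package org.apache.spark.shuffle

import org.apache.spark.TaskContext

private[spark] abstract class MetricsWrappingShuffleManager(delegate: ShuffleManager)
  extends ShuffleManager {

  override def getWriter[K, V](
      handle: ShuffleHandle,
      mapId: Int,
      context: TaskContext,
      metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
    // A real implementation could decorate `metrics` here, for example forwarding every
    // incBytesWritten / incRecordsWritten call to an external monitoring sink, before handing
    // it to the underlying shuffle machinery.
    delegate.getWriter[K, V](handle, mapId, context, metrics)
  }
}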