[SPARK-26141] Enable custom metrics implementation in shuffle write

## What changes were proposed in this pull request?
This is the write-side counterpart to https://github.com/apache/spark/pull/23105
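
For readers coming to this cold: the interface being threaded through here is `ShuffleWriteMetricsReporter`. Its surface can be read off the `ShuffleWriteMetrics` hunk below — five package-private inc/dec methods. A minimal sketch of what an alternative implementation could look like (the class name, the atomic counters, and the `snapshot` helper are illustrative, not part of this patch):

```scala
package org.apache.spark.shuffle

import java.util.concurrent.atomic.AtomicLong

// Hypothetical reporter that mirrors write metrics into atomic counters, e.g.
// to feed an external metrics sink. The five overridden methods are the ones
// ShuffleWriteMetrics implements in this patch; because they are private[spark],
// a custom reporter has to live under the org.apache.spark package tree.
private[spark] class AtomicWriteMetricsReporter extends ShuffleWriteMetricsReporter {
  private val bytes = new AtomicLong(0L)
  private val records = new AtomicLong(0L)
  private val writeTimeNanos = new AtomicLong(0L)

  private[spark] override def incBytesWritten(v: Long): Unit = bytes.addAndGet(v)
  private[spark] override def incRecordsWritten(v: Long): Unit = records.addAndGet(v)
  private[spark] override def incWriteTime(v: Long): Unit = writeTimeNanos.addAndGet(v)
  private[spark] override def decBytesWritten(v: Long): Unit = bytes.addAndGet(-v)
  private[spark] override def decRecordsWritten(v: Long): Unit = records.addAndGet(-v)

  def snapshot: (Long, Long, Long) = (bytes.get(), records.get(), writeTimeNanos.get())
}
```

With this patch the reporter is handed to the shuffle writers by the `ShuffleManager` (see `getWriter` below), so a custom shuffle implementation can route write metrics wherever it wants while the built-in `ShuffleWriteMetrics` remains the default.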

## How was this patch tested?
No behavior change is expected, as this is a straightforward refactoring. Updated all existing test cases.

Closes #23106 from rxin/SPARK-26141.

Authored-by: Reynold Xin <rxin@databricks.com>
Signed-off-by: Reynold Xin <rxin@databricks.com>
Committed by Reynold Xin on 2018-11-26 22:35:52 -08:00
parent 85383d29ed
commit 6a064ba8f2
15 changed files with 79 additions and 54 deletions


@@ -37,12 +37,11 @@ import org.slf4j.LoggerFactory;
import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.storage.*;
@@ -79,7 +78,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
private final int numPartitions;
private final BlockManager blockManager;
private final Partitioner partitioner;
private final ShuffleWriteMetrics writeMetrics;
private final ShuffleWriteMetricsReporter writeMetrics;
private final int shuffleId;
private final int mapId;
private final Serializer serializer;
@@ -103,8 +102,8 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
IndexShuffleBlockResolver shuffleBlockResolver,
BypassMergeSortShuffleHandle<K, V> handle,
int mapId,
TaskContext taskContext,
SparkConf conf) {
SparkConf conf,
ShuffleWriteMetricsReporter writeMetrics) {
// Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true);
@@ -114,7 +113,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
this.shuffleId = dep.shuffleId();
this.partitioner = dep.partitioner();
this.numPartitions = partitioner.numPartitions();
this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics();
this.writeMetrics = writeMetrics;
this.serializer = dep.serializer();
this.shuffleBlockResolver = shuffleBlockResolver;
}
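
Worth calling out in the hunk above: `BypassMergeSortShuffleWriter` no longer reaches into `taskContext.taskMetrics().shuffleWriteMetrics()`; the caller now injects the reporter. A sketch (names assumed, not in this patch) of a composite reporter a custom `ShuffleManager` could inject here, forwarding to the task's default metrics plus any extra sink:

```scala
package org.apache.spark.shuffle

// Illustrative only: forwards every update to several underlying reporters,
// e.g. the task's default ShuffleWriteMetrics plus a custom sink. With this
// patch such a wrapper can be handed straight to the writer's constructor.
private[spark] class ForwardingWriteMetricsReporter(
    delegates: Seq[ShuffleWriteMetricsReporter]) extends ShuffleWriteMetricsReporter {
  private[spark] override def incBytesWritten(v: Long): Unit =
    delegates.foreach(_.incBytesWritten(v))
  private[spark] override def incRecordsWritten(v: Long): Unit =
    delegates.foreach(_.incRecordsWritten(v))
  private[spark] override def incWriteTime(v: Long): Unit =
    delegates.foreach(_.incWriteTime(v))
  private[spark] override def decBytesWritten(v: Long): Unit =
    delegates.foreach(_.decBytesWritten(v))
  private[spark] override def decRecordsWritten(v: Long): Unit =
    delegates.foreach(_.decRecordsWritten(v))
}
```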


@@ -38,6 +38,7 @@ import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.memory.TooLargePageException;
import org.apache.spark.serializer.DummySerializerInstance;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.DiskBlockObjectWriter;
import org.apache.spark.storage.FileSegment;
@@ -75,7 +76,7 @@ final class ShuffleExternalSorter extends MemoryConsumer {
private final TaskMemoryManager taskMemoryManager;
private final BlockManager blockManager;
private final TaskContext taskContext;
private final ShuffleWriteMetrics writeMetrics;
private final ShuffleWriteMetricsReporter writeMetrics;
/**
* Force this sorter to spill when there are this many elements in memory.
@@ -113,7 +114,7 @@ final class ShuffleExternalSorter extends MemoryConsumer {
int initialSize,
int numPartitions,
SparkConf conf,
ShuffleWriteMetrics writeMetrics) {
ShuffleWriteMetricsReporter writeMetrics) {
super(memoryManager,
(int) Math.min(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, memoryManager.pageSizeBytes()),
memoryManager.getTungstenMemoryMode());
@@ -144,7 +145,7 @@ final class ShuffleExternalSorter extends MemoryConsumer {
*/
private void writeSortedFile(boolean isLastFile) {
final ShuffleWriteMetrics writeMetricsToUse;
final ShuffleWriteMetricsReporter writeMetricsToUse;
if (isLastFile) {
// We're writing the final non-spill file, so we _do_ want to count this as shuffle bytes.
@@ -241,9 +242,14 @@ final class ShuffleExternalSorter extends MemoryConsumer {
//
// Note that we intentionally ignore the value of `writeMetricsToUse.shuffleWriteTime()`.
// Consistent with ExternalSorter, we do not count this IO towards shuffle write time.
// This means that this IO time is not accounted for anywhere; SPARK-3577 will fix this.
writeMetrics.incRecordsWritten(writeMetricsToUse.recordsWritten());
taskContext.taskMetrics().incDiskBytesSpilled(writeMetricsToUse.bytesWritten());
// SPARK-3577 tracks the spill time separately.
// This is guaranteed to be a ShuffleWriteMetrics based on the if check in the beginning
// of this method.
writeMetrics.incRecordsWritten(
((ShuffleWriteMetrics)writeMetricsToUse).recordsWritten());
taskContext.taskMetrics().incDiskBytesSpilled(
((ShuffleWriteMetrics)writeMetricsToUse).bytesWritten());
}
}
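
The cast introduced above is worth a closer look: spill files are measured with a throwaway `ShuffleWriteMetrics` (so they do not count as shuffle output), and their totals are then folded into records written and disk bytes spilled, which is why the downcast is safe. A Scala paraphrase of that branch, with `doWrites` standing in for the actual partition-writing loop (the real code is Java):

```scala
package org.apache.spark.shuffle

import org.apache.spark.TaskContext
import org.apache.spark.executor.ShuffleWriteMetrics

// Scala paraphrase of the branch in ShuffleExternalSorter.writeSortedFile.
// Only the last file counts as shuffle write output; spill files use a
// throwaway ShuffleWriteMetrics whose totals are folded into the task's spill
// metrics afterwards, so the cast is safe by construction.
private[spark] object SpillMetricsSketch {
  def writeSortedFile(
      isLastFile: Boolean,
      writeMetrics: ShuffleWriteMetricsReporter,
      taskContext: TaskContext)(doWrites: ShuffleWriteMetricsReporter => Unit): Unit = {
    val writeMetricsToUse: ShuffleWriteMetricsReporter =
      if (isLastFile) writeMetrics else new ShuffleWriteMetrics()

    doWrites(writeMetricsToUse)

    if (!isLastFile) {
      val spillMetrics = writeMetricsToUse.asInstanceOf[ShuffleWriteMetrics]
      writeMetrics.incRecordsWritten(spillMetrics.recordsWritten)
      taskContext.taskMetrics().incDiskBytesSpilled(spillMetrics.bytesWritten)
    }
  }
}
```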


@@ -37,7 +37,6 @@ import org.slf4j.LoggerFactory;
import org.apache.spark.*;
import org.apache.spark.annotation.Private;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.io.CompressionCodec;
import org.apache.spark.io.CompressionCodec$;
import org.apache.spark.io.NioBufferedFileInputStream;
@@ -47,6 +46,7 @@ import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.network.util.LimitedInputStream;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.serializer.SerializationStream;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
@@ -73,7 +73,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
private final TaskMemoryManager memoryManager;
private final SerializerInstance serializer;
private final Partitioner partitioner;
private final ShuffleWriteMetrics writeMetrics;
private final ShuffleWriteMetricsReporter writeMetrics;
private final int shuffleId;
private final int mapId;
private final TaskContext taskContext;
@@ -122,7 +122,8 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
SerializedShuffleHandle<K, V> handle,
int mapId,
TaskContext taskContext,
SparkConf sparkConf) throws IOException {
SparkConf sparkConf,
ShuffleWriteMetricsReporter writeMetrics) throws IOException {
final int numPartitions = handle.dependency().partitioner().numPartitions();
if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) {
throw new IllegalArgumentException(
@@ -138,7 +139,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
this.shuffleId = dep.shuffleId();
this.serializer = dep.serializer().newInstance();
this.partitioner = dep.partitioner();
this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics();
this.writeMetrics = writeMetrics;
this.taskContext = taskContext;
this.sparkConf = sparkConf;
this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);


@@ -21,7 +21,7 @@ import java.io.IOException;
import java.io.OutputStream;
import org.apache.spark.annotation.Private;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
/**
* Intercepts write calls and tracks total time spent writing in order to update shuffle write
@@ -30,10 +30,11 @@ import org.apache.spark.executor.ShuffleWriteMetrics;
@Private
public final class TimeTrackingOutputStream extends OutputStream {
private final ShuffleWriteMetrics writeMetrics;
private final ShuffleWriteMetricsReporter writeMetrics;
private final OutputStream outputStream;
public TimeTrackingOutputStream(ShuffleWriteMetrics writeMetrics, OutputStream outputStream) {
public TimeTrackingOutputStream(
ShuffleWriteMetricsReporter writeMetrics, OutputStream outputStream) {
this.writeMetrics = writeMetrics;
this.outputStream = outputStream;
}
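
`TimeTrackingOutputStream` now needs only the reporter interface. Per its class comment, the idea is to wrap an `OutputStream` and charge the elapsed time of each call to `incWriteTime`. A Scala sketch of the same idea; the names and the nanosecond unit are from the sketch, not guaranteed by the Java class:

```scala
package org.apache.spark.shuffle

import java.io.OutputStream

// Scala sketch of the TimeTrackingOutputStream idea against the new interface:
// wrap an OutputStream and charge the elapsed time of every call to
// incWriteTime. Class and helper names are illustrative.
private[spark] class TimedOutputStream(
    metrics: ShuffleWriteMetricsReporter,
    out: OutputStream) extends OutputStream {

  override def write(b: Int): Unit = timed(out.write(b))
  override def write(b: Array[Byte]): Unit = timed(out.write(b))
  override def write(b: Array[Byte], off: Int, len: Int): Unit = timed(out.write(b, off, len))
  override def flush(): Unit = timed(out.flush())
  override def close(): Unit = timed(out.close())

  private def timed(op: => Unit): Unit = {
    val start = System.nanoTime()
    op
    metrics.incWriteTime(System.nanoTime() - start)
  }
}
```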


@@ -18,6 +18,7 @@
package org.apache.spark.executor
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
import org.apache.spark.util.LongAccumulator
@@ -27,7 +28,7 @@ import org.apache.spark.util.LongAccumulator
* Operations are not thread-safe.
*/
@DeveloperApi
class ShuffleWriteMetrics private[spark] () extends Serializable {
class ShuffleWriteMetrics private[spark] () extends ShuffleWriteMetricsReporter with Serializable {
private[executor] val _bytesWritten = new LongAccumulator
private[executor] val _recordsWritten = new LongAccumulator
private[executor] val _writeTime = new LongAccumulator
@@ -47,13 +48,13 @@ class ShuffleWriteMetrics private[spark] () extends Serializable {
*/
def writeTime: Long = _writeTime.sum
private[spark] def incBytesWritten(v: Long): Unit = _bytesWritten.add(v)
private[spark] def incRecordsWritten(v: Long): Unit = _recordsWritten.add(v)
private[spark] def incWriteTime(v: Long): Unit = _writeTime.add(v)
private[spark] def decBytesWritten(v: Long): Unit = {
private[spark] override def incBytesWritten(v: Long): Unit = _bytesWritten.add(v)
private[spark] override def incRecordsWritten(v: Long): Unit = _recordsWritten.add(v)
private[spark] override def incWriteTime(v: Long): Unit = _writeTime.add(v)
private[spark] override def decBytesWritten(v: Long): Unit = {
_bytesWritten.setValue(bytesWritten - v)
}
private[spark] def decRecordsWritten(v: Long): Unit = {
private[spark] override def decRecordsWritten(v: Long): Unit = {
_recordsWritten.setValue(recordsWritten - v)
}
}
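
With this change `ShuffleWriteMetrics` is simply the default, accumulator-backed reporter; the writers above only see the interface, while `TaskMetrics` keeps exposing the concrete class (which, as the `ShuffleMapTask` hunk below shows, is what gets passed in by default). A small check of the semantics the overrides preserve, placed under the org.apache.spark package tree because the methods are `private[spark]`:

```scala
package org.apache.spark.shuffle

import org.apache.spark.executor.ShuffleWriteMetrics

// Tiny demo of the default reporter's behaviour after this patch: inc/dec are
// relative updates on top of the existing accumulators, exactly as before.
object DefaultReporterDemo {
  def main(args: Array[String]): Unit = {
    val m = new ShuffleWriteMetrics()
    m.incBytesWritten(4096L)
    m.incRecordsWritten(10L)
    m.decBytesWritten(1024L) // e.g. after reverting a partially written block
    assert(m.bytesWritten == 3072L)
    assert(m.recordsWritten == 10L)
  }
}
```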


@@ -95,7 +95,8 @@ private[spark] class ShuffleMapTask(
var writer: ShuffleWriter[Any, Any] = null
try {
val manager = SparkEnv.get.shuffleManager
writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
writer = manager.getWriter[Any, Any](
dep.shuffleHandle, partitionId, context, context.taskMetrics().shuffleWriteMetrics)
writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
writer.stop(success = true).get
} catch {


@@ -38,7 +38,11 @@ private[spark] trait ShuffleManager {
dependency: ShuffleDependency[K, V, C]): ShuffleHandle
/** Get a writer for a given partition. Called on executors by map tasks. */
def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext): ShuffleWriter[K, V]
def getWriter[K, V](
handle: ShuffleHandle,
mapId: Int,
context: TaskContext,
metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V]
/**
* Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive).
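
This is the actual extension point: `ShuffleManager` implementations now receive the reporter and decide what to do with it. A sketch of how a delegating manager could hook in, kept abstract so only the member relevant to this patch is spelled out; `wrap()` is a hypothetical hook, and the remaining `ShuffleManager` members would be forwarded to the delegate in the same way:

```scala
package org.apache.spark.shuffle

import org.apache.spark.TaskContext

// Illustrative sketch only: a manager that wraps whatever reporter the task
// hands in before delegating writer creation.
private[spark] abstract class MeteredShuffleManager(
    delegate: ShuffleManager) extends ShuffleManager {

  /** Hypothetical hook, e.g. returning a forwarding/composite reporter. */
  protected def wrap(metrics: ShuffleWriteMetricsReporter): ShuffleWriteMetricsReporter

  override def getWriter[K, V](
      handle: ShuffleHandle,
      mapId: Int,
      context: TaskContext,
      metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] =
    delegate.getWriter(handle, mapId, context, wrap(metrics))
}
```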


@@ -125,7 +125,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
override def getWriter[K, V](
handle: ShuffleHandle,
mapId: Int,
context: TaskContext): ShuffleWriter[K, V] = {
context: TaskContext,
metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
numMapsForShuffle.putIfAbsent(
handle.shuffleId, handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps)
val env = SparkEnv.get
@@ -138,15 +139,16 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
unsafeShuffleHandle,
mapId,
context,
env.conf)
env.conf,
metrics)
case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
new BypassMergeSortShuffleWriter(
env.blockManager,
shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
bypassMergeSortHandle,
mapId,
context,
env.conf)
env.conf,
metrics)
case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
new SortShuffleWriter(shuffleBlockResolver, other, mapId, context)
}


@@ -33,10 +33,9 @@ import scala.util.Random
import scala.util.control.NonFatal
import com.codahale.metrics.{MetricRegistry, MetricSet}
import com.google.common.io.CountingOutputStream
import org.apache.spark._
import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics}
import org.apache.spark.executor.DataReadMethod
import org.apache.spark.internal.{config, Logging}
import org.apache.spark.memory.{MemoryManager, MemoryMode}
import org.apache.spark.metrics.source.Source
@@ -50,7 +49,7 @@ import org.apache.spark.network.util.TransportConf
import org.apache.spark.rpc.RpcEnv
import org.apache.spark.scheduler.ExecutorCacheTaskLocation
import org.apache.spark.serializer.{SerializerInstance, SerializerManager}
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.shuffle.{ShuffleManager, ShuffleWriteMetricsReporter}
import org.apache.spark.storage.memory._
import org.apache.spark.unsafe.Platform
import org.apache.spark.util._
@@ -932,7 +931,7 @@ private[spark] class BlockManager(
file: File,
serializerInstance: SerializerInstance,
bufferSize: Int,
writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = {
writeMetrics: ShuffleWriteMetricsReporter): DiskBlockObjectWriter = {
val syncWrites = conf.getBoolean("spark.shuffle.sync", false)
new DiskBlockObjectWriter(file, serializerManager, serializerInstance, bufferSize,
syncWrites, writeMetrics, blockId)


@@ -20,9 +20,9 @@ package org.apache.spark.storage
import java.io.{BufferedOutputStream, File, FileOutputStream, OutputStream}
import java.nio.channels.FileChannel
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager}
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
import org.apache.spark.util.Utils
/**
@@ -43,7 +43,7 @@ private[spark] class DiskBlockObjectWriter(
syncWrites: Boolean,
// These write metrics concurrently shared with other active DiskBlockObjectWriters who
// are themselves performing writes. All updates must be relative.
writeMetrics: ShuffleWriteMetrics,
writeMetrics: ShuffleWriteMetricsReporter,
val blockId: BlockId = null)
extends OutputStream
with Logging {
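
`DiskBlockObjectWriter` is the lowest level that touches the reporter; note the existing comment above about the metrics object being concurrently shared by other active writers, which is why all updates must be relative. For completeness, a sketch of opening one against an arbitrary reporter, mirroring the `getDiskWriter` call in the `BlockManager` hunk; every value here is supplied by the caller and nothing is new API:

```scala
package org.apache.spark.shuffle

import java.io.File

import org.apache.spark.serializer.{SerializerInstance, SerializerManager}
import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter}

// Sketch only: open a DiskBlockObjectWriter against any reporter. Argument
// order follows the constructor call shown in BlockManager.getDiskWriter.
private[spark] object DiskWriterSketch {
  def open(
      file: File,
      serializerManager: SerializerManager,
      serializerInstance: SerializerInstance,
      bufferSize: Int,
      syncWrites: Boolean,
      reporter: ShuffleWriteMetricsReporter,
      blockId: BlockId): DiskBlockObjectWriter = {
    new DiskBlockObjectWriter(
      file, serializerManager, serializerInstance, bufferSize, syncWrites, reporter, blockId)
  }
}
```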


@@ -793,8 +793,8 @@ private[spark] class ExternalSorter[K, V, C](
def nextPartition(): Int = cur._1._1
}
logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " +
s" it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory")
logInfo(s"Task ${TaskContext.get().taskAttemptId} force spilling in-memory map to disk " +
s"and it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory")
val spillFile = spillMemoryIteratorToDisk(inMemoryIterator)
forceSpillFiles += spillFile
val spillReader = new SpillReader(spillFile)


@@ -162,7 +162,8 @@ public class UnsafeShuffleWriterSuite {
new SerializedShuffleHandle<>(0, 1, shuffleDep),
0, // map id
taskContext,
conf
conf,
taskContext.taskMetrics().shuffleWriteMetrics()
);
}
@@ -521,7 +522,8 @@ public class UnsafeShuffleWriterSuite {
new SerializedShuffleHandle<>(0, 1, shuffleDep),
0, // map id
taskContext,
conf);
conf,
taskContext.taskMetrics().shuffleWriteMetrics());
// Peak memory should be monotonically increasing. More specifically, every time
// we allocate a new page it should increase by exactly the size of the page.


@@ -362,15 +362,19 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
mapTrackerMaster.registerShuffle(0, 1)
// first attempt -- its successful
val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0,
new TaskContextImpl(0, 0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem))
val context1 =
new TaskContextImpl(0, 0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem)
val writer1 = manager.getWriter[Int, Int](
shuffleHandle, 0, context1, context1.taskMetrics.shuffleWriteMetrics)
val data1 = (1 to 10).map { x => x -> x}
// second attempt -- also successful. We'll write out different data,
// just to simulate the fact that the records may get written differently
// depending on what gets spilled, what gets combined, etc.
val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0,
new TaskContextImpl(0, 0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem))
val context2 =
new TaskContextImpl(0, 0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem)
val writer2 = manager.getWriter[Int, Int](
shuffleHandle, 0, context2, context2.taskMetrics.shuffleWriteMetrics)
val data2 = (11 to 20).map { x => x -> x}
// interleave writes of both attempts -- we want to test that both attempts can occur


@@ -136,8 +136,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
blockResolver,
shuffleHandle,
0, // MapId
taskContext,
conf
conf,
taskContext.taskMetrics().shuffleWriteMetrics
)
writer.write(Iterator.empty)
writer.stop( /* success = */ true)
@@ -160,8 +160,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
blockResolver,
shuffleHandle,
0, // MapId
taskContext,
conf
conf,
taskContext.taskMetrics().shuffleWriteMetrics
)
writer.write(records)
writer.stop( /* success = */ true)
@@ -195,8 +195,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
blockResolver,
shuffleHandle,
0, // MapId
taskContext,
conf
conf,
taskContext.taskMetrics().shuffleWriteMetrics
)
intercept[SparkException] {
@@ -217,8 +217,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
blockResolver,
shuffleHandle,
0, // MapId
taskContext,
conf
conf,
taskContext.taskMetrics().shuffleWriteMetrics
)
intercept[SparkException] {
writer.write((0 until 100000).iterator.map(i => {


@@ -226,7 +226,12 @@ object MimaExcludes {
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.writer.DataSourceWriter"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.v2.writer.DataWriterFactory.createWriter"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter")
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter"),
// [SPARK-26141] Enable custom metrics implementation in shuffle write
// Following are Java private classes
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.sort.UnsafeShuffleWriter.this"),
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.TimeTrackingOutputStream.this")
)
// Exclude rules for 2.4.x