diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index 1aa6ba4201..bf4eaa59ff 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -20,6 +20,7 @@ package org.apache.spark.shuffle.unsafe; import java.io.File; import java.io.IOException; import java.util.LinkedList; +import javax.annotation.Nullable; import scala.Tuple2; @@ -86,9 +87,12 @@ final class UnsafeShuffleExternalSorter { private final LinkedList spills = new LinkedList(); + /** Peak memory used by this sorter so far, in bytes. **/ + private long peakMemoryUsedBytes; + // These variables are reset after spilling: - private UnsafeShuffleInMemorySorter sorter; - private MemoryBlock currentPage = null; + @Nullable private UnsafeShuffleInMemorySorter sorter; + @Nullable private MemoryBlock currentPage = null; private long currentPagePosition = -1; private long freeSpaceInCurrentPage = 0; @@ -106,6 +110,7 @@ final class UnsafeShuffleExternalSorter { this.blockManager = blockManager; this.taskContext = taskContext; this.initialSize = initialSize; + this.peakMemoryUsedBytes = initialSize; this.numPartitions = numPartitions; // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; @@ -279,10 +284,26 @@ final class UnsafeShuffleExternalSorter { for (MemoryBlock page : allocatedPages) { totalPageSize += page.size(); } - return sorter.getMemoryUsage() + totalPageSize; + return ((sorter == null) ? 0 : sorter.getMemoryUsage()) + totalPageSize; + } + + private void updatePeakMemoryUsed() { + long mem = getMemoryUsage(); + if (mem > peakMemoryUsedBytes) { + peakMemoryUsedBytes = mem; + } + } + + /** + * Return the peak memory used so far, in bytes. + */ + long getPeakMemoryUsedBytes() { + updatePeakMemoryUsed(); + return peakMemoryUsedBytes; } private long freeMemory() { + updatePeakMemoryUsed(); long memoryFreed = 0; for (MemoryBlock block : allocatedPages) { memoryManager.freePage(block); diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index d47d6fc9c2..6e2eeb37c8 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -27,6 +27,7 @@ import scala.Product2; import scala.collection.JavaConversions; import scala.reflect.ClassTag; import scala.reflect.ClassTag$; +import scala.collection.immutable.Map; import com.google.common.annotations.VisibleForTesting; import com.google.common.io.ByteStreams; @@ -78,8 +79,9 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private final SparkConf sparkConf; private final boolean transferToEnabled; - private MapStatus mapStatus = null; - private UnsafeShuffleExternalSorter sorter = null; + @Nullable private MapStatus mapStatus; + @Nullable private UnsafeShuffleExternalSorter sorter; + private long peakMemoryUsedBytes = 0; /** Subclass of ByteArrayOutputStream that exposes `buf` directly. 
*/ private static final class MyByteArrayOutputStream extends ByteArrayOutputStream { @@ -131,9 +133,28 @@ public class UnsafeShuffleWriter extends ShuffleWriter { @VisibleForTesting public int maxRecordSizeBytes() { + assert(sorter != null); return sorter.maxRecordSizeBytes; } + private void updatePeakMemoryUsed() { + // sorter can be null if this writer is closed + if (sorter != null) { + long mem = sorter.getPeakMemoryUsedBytes(); + if (mem > peakMemoryUsedBytes) { + peakMemoryUsedBytes = mem; + } + } + } + + /** + * Return the peak memory used so far, in bytes. + */ + public long getPeakMemoryUsedBytes() { + updatePeakMemoryUsed(); + return peakMemoryUsedBytes; + } + /** * This convenience method should only be called in test code. */ @@ -144,7 +165,7 @@ public class UnsafeShuffleWriter extends ShuffleWriter { @Override public void write(scala.collection.Iterator> records) throws IOException { - // Keep track of success so we know if we ecountered an exception + // Keep track of success so we know if we encountered an exception // We do this rather than a standard try/catch/re-throw to handle // generic throwables. boolean success = false; @@ -189,6 +210,8 @@ public class UnsafeShuffleWriter extends ShuffleWriter { @VisibleForTesting void closeAndWriteOutput() throws IOException { + assert(sorter != null); + updatePeakMemoryUsed(); serBuffer = null; serOutputStream = null; final SpillInfo[] spills = sorter.closeAndGetSpills(); @@ -209,6 +232,7 @@ public class UnsafeShuffleWriter extends ShuffleWriter { @VisibleForTesting void insertRecordIntoSorter(Product2 record) throws IOException { + assert(sorter != null); final K key = record._1(); final int partitionId = partitioner.getPartition(key); serBuffer.reset(); @@ -431,6 +455,14 @@ public class UnsafeShuffleWriter extends ShuffleWriter { @Override public Option stop(boolean success) { try { + // Update task metrics from accumulators (null in UnsafeShuffleWriterSuite) + Map> internalAccumulators = + taskContext.internalMetricsToAccumulators(); + if (internalAccumulators != null) { + internalAccumulators.apply(InternalAccumulator.PEAK_EXECUTION_MEMORY()) + .add(getPeakMemoryUsedBytes()); + } + if (stopping) { return Option.apply(null); } else { diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 01a66084e9..20347433e1 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -505,7 +505,7 @@ public final class BytesToBytesMap { // Here, we'll copy the data into our data pages. Because we only store a relative offset from // the key address instead of storing the absolute address of the value, the key and value // must be stored in the same memory page. - // (8 byte key length) (key) (8 byte value length) (value) + // (8 byte key length) (key) (value) final long requiredSize = 8 + keyLengthBytes + valueLengthBytes; // --- Figure out where to insert the new record --------------------------------------------- @@ -655,7 +655,10 @@ public final class BytesToBytesMap { return pageSizeBytes; } - /** Returns the total amount of memory, in bytes, consumed by this map's managed structures. */ + /** + * Returns the total amount of memory, in bytes, consumed by this map's managed structures. + * Note that this is also the peak memory used by this map, since the map is append-only. 
+ */ public long getTotalMemoryConsumption() { long totalDataPagesSize = 0L; for (MemoryBlock dataPage : dataPages) { @@ -674,7 +677,6 @@ public final class BytesToBytesMap { return timeSpentResizingNs; } - /** * Returns the average number of probes per key lookup. */ diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index b984301cbb..bf5f965a9d 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -70,13 +70,14 @@ public final class UnsafeExternalSorter { private final LinkedList spillWriters = new LinkedList<>(); // These variables are reset after spilling: - private UnsafeInMemorySorter inMemSorter; + @Nullable private UnsafeInMemorySorter inMemSorter; // Whether the in-mem sorter is created internally, or passed in from outside. // If it is passed in from outside, we shouldn't release the in-mem sorter's memory. private boolean isInMemSorterExternal = false; private MemoryBlock currentPage = null; private long currentPagePosition = -1; private long freeSpaceInCurrentPage = 0; + private long peakMemoryUsedBytes = 0; public static UnsafeExternalSorter createWithExistingInMemorySorter( TaskMemoryManager taskMemoryManager, @@ -183,6 +184,7 @@ public final class UnsafeExternalSorter { * Sort and spill the current records in response to memory pressure. */ public void spill() throws IOException { + assert(inMemSorter != null); logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)", Thread.currentThread().getId(), Utils.bytesToString(getMemoryUsage()), @@ -219,7 +221,22 @@ public final class UnsafeExternalSorter { for (MemoryBlock page : allocatedPages) { totalPageSize += page.size(); } - return inMemSorter.getMemoryUsage() + totalPageSize; + return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + totalPageSize; + } + + private void updatePeakMemoryUsed() { + long mem = getMemoryUsage(); + if (mem > peakMemoryUsedBytes) { + peakMemoryUsedBytes = mem; + } + } + + /** + * Return the peak memory used so far, in bytes. + */ + public long getPeakMemoryUsedBytes() { + updatePeakMemoryUsed(); + return peakMemoryUsedBytes; } @VisibleForTesting @@ -233,6 +250,7 @@ public final class UnsafeExternalSorter { * @return the number of bytes freed. */ public long freeMemory() { + updatePeakMemoryUsed(); long memoryFreed = 0; for (MemoryBlock block : allocatedPages) { taskMemoryManager.freePage(block); @@ -277,7 +295,8 @@ public final class UnsafeExternalSorter { * @return true if the record can be inserted without requiring more allocations, false otherwise. */ private boolean haveSpaceForRecord(int requiredSpace) { - assert (requiredSpace > 0); + assert(requiredSpace > 0); + assert(inMemSorter != null); return (inMemSorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage)); } @@ -290,6 +309,7 @@ public final class UnsafeExternalSorter { * the record size. */ private void allocateSpaceForRecord(int requiredSpace) throws IOException { + assert(inMemSorter != null); // TODO: merge these steps to first calculate total memory requirements for this insert, // then try to acquire; no point in acquiring sort buffer only to spill due to no space in the // data page. 
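The peak-tracking pattern added to both unsafe sorters above is small enough to sketch on its own: keep a running maximum of the estimated memory footprint, refresh it whenever the peak is read, and take one final snapshot before memory is released (otherwise the last page allocations would be lost when freeMemory() drops them). The Scala sketch below is illustrative only; PeakTrackingBuffer and its members are hypothetical names, not classes in this patch.

import scala.collection.mutable.ArrayBuffer

// Hypothetical stand-in that mirrors updatePeakMemoryUsed(), getPeakMemoryUsedBytes()
// and freeMemory() in UnsafeShuffleExternalSorter / UnsafeExternalSorter.
class PeakTrackingBuffer {
  private val pages = ArrayBuffer[Array[Byte]]()
  private var peakMemoryUsedBytes = 0L

  private def getMemoryUsage: Long = pages.iterator.map(_.length.toLong).sum

  private def updatePeakMemoryUsed(): Unit = {
    val mem = getMemoryUsage
    if (mem > peakMemoryUsedBytes) {
      peakMemoryUsedBytes = mem
    }
  }

  /** Peak memory used so far, in bytes. */
  def getPeakMemoryUsedBytes: Long = {
    updatePeakMemoryUsed()
    peakMemoryUsedBytes
  }

  def allocatePage(sizeBytes: Int): Unit = pages += new Array[Byte](sizeBytes)

  /** Snapshot the peak one last time before releasing the pages. */
  def freeMemory(): Long = {
    updatePeakMemoryUsed()
    val freed = getMemoryUsage
    pages.clear()
    freed
  }
}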
@@ -350,6 +370,7 @@ public final class UnsafeExternalSorter { if (!haveSpaceForRecord(totalSpaceRequired)) { allocateSpaceForRecord(totalSpaceRequired); } + assert(inMemSorter != null); final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition); @@ -382,6 +403,7 @@ public final class UnsafeExternalSorter { if (!haveSpaceForRecord(totalSpaceRequired)) { allocateSpaceForRecord(totalSpaceRequired); } + assert(inMemSorter != null); final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition); @@ -405,6 +427,7 @@ public final class UnsafeExternalSorter { } public UnsafeSorterIterator getSortedIterator() throws IOException { + assert(inMemSorter != null); final UnsafeSorterIterator inMemoryIterator = inMemSorter.getSortedIterator(); int numIteratorsToMerge = spillWriters.size() + (inMemoryIterator.hasNext() ? 1 : 0); if (spillWriters.isEmpty()) { diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index b1cef47042..648cd1b104 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -207,7 +207,7 @@ span.additional-metric-title { /* Hide all additional metrics by default. This is done here rather than using JavaScript to * avoid slow page loads for stage pages with large numbers (e.g., thousands) of tasks. */ .scheduler_delay, .deserialization_time, .fetch_wait_time, .shuffle_read_remote, -.serialization_time, .getting_result_time { +.serialization_time, .getting_result_time, .peak_execution_memory { display: none; } diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index eb75f26718..b6a0119c69 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -152,8 +152,14 @@ class Accumulable[R, T] private[spark] ( in.defaultReadObject() value_ = zero deserialized = true - val taskContext = TaskContext.get() - taskContext.registerAccumulator(this) + // Automatically register the accumulator when it is deserialized with the task closure. + // Note that internal accumulators are deserialized before the TaskContext is created and + // are registered in the TaskContext constructor. 
+ if (!isInternal) { + val taskContext = TaskContext.get() + assume(taskContext != null, "Task context was null when deserializing user accumulators") + taskContext.registerAccumulator(this) + } } override def toString: String = if (value_ == null) "null" else value_.toString @@ -248,10 +254,20 @@ GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializa * @param param helper object defining how to add elements of type `T` * @tparam T result type */ -class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T], name: Option[String]) - extends Accumulable[T, T](initialValue, param, name) { +class Accumulator[T] private[spark] ( + @transient initialValue: T, + param: AccumulatorParam[T], + name: Option[String], + internal: Boolean) + extends Accumulable[T, T](initialValue, param, name, internal) { - def this(initialValue: T, param: AccumulatorParam[T]) = this(initialValue, param, None) + def this(initialValue: T, param: AccumulatorParam[T], name: Option[String]) = { + this(initialValue, param, name, false) + } + + def this(initialValue: T, param: AccumulatorParam[T]) = { + this(initialValue, param, None, false) + } } /** @@ -342,3 +358,37 @@ private[spark] object Accumulators extends Logging { } } + +private[spark] object InternalAccumulator { + val PEAK_EXECUTION_MEMORY = "peakExecutionMemory" + val TEST_ACCUMULATOR = "testAccumulator" + + // For testing only. + // This needs to be a def since we don't want to reuse the same accumulator across stages. + private def maybeTestAccumulator: Option[Accumulator[Long]] = { + if (sys.props.contains("spark.testing")) { + Some(new Accumulator( + 0L, AccumulatorParam.LongAccumulatorParam, Some(TEST_ACCUMULATOR), internal = true)) + } else { + None + } + } + + /** + * Accumulators for tracking internal metrics. + * + * These accumulators are created with the stage such that all tasks in the stage will + * add to the same set of accumulators. We do this to report the distribution of accumulator + * values across all tasks within each stage. + */ + def create(): Seq[Accumulator[Long]] = { + Seq( + // Execution memory refers to the memory used by internal data structures created + // during shuffles, aggregations and joins. The value of this accumulator should be + // approximately the sum of the peak sizes across all such data structures created + // in this task. For SQL jobs, this only tracks all unsafe operators and ExternalSort. 
+ new Accumulator( + 0L, AccumulatorParam.LongAccumulatorParam, Some(PEAK_EXECUTION_MEMORY), internal = true) + ) ++ maybeTestAccumulator.toSeq + } +} diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala index ceeb58075d..289aab9bd9 100644 --- a/core/src/main/scala/org/apache/spark/Aggregator.scala +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -58,12 +58,7 @@ case class Aggregator[K, V, C] ( } else { val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners) combiners.insertAll(iter) - // Update task metrics if context is not null - // TODO: Make context non optional in a future release - Option(context).foreach { c => - c.taskMetrics.incMemoryBytesSpilled(combiners.memoryBytesSpilled) - c.taskMetrics.incDiskBytesSpilled(combiners.diskBytesSpilled) - } + updateMetrics(context, combiners) combiners.iterator } } @@ -89,13 +84,18 @@ case class Aggregator[K, V, C] ( } else { val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners) combiners.insertAll(iter) - // Update task metrics if context is not null - // TODO: Make context non-optional in a future release - Option(context).foreach { c => - c.taskMetrics.incMemoryBytesSpilled(combiners.memoryBytesSpilled) - c.taskMetrics.incDiskBytesSpilled(combiners.diskBytesSpilled) - } + updateMetrics(context, combiners) combiners.iterator } } + + /** Update task metrics after populating the external map. */ + private def updateMetrics(context: TaskContext, map: ExternalAppendOnlyMap[_, _, _]): Unit = { + Option(context).foreach { c => + c.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled) + c.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled) + c.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes) + } + } } diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 5d2c551d58..63cca80b2d 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -61,12 +61,12 @@ object TaskContext { protected[spark] def unset(): Unit = taskContext.remove() /** - * Return an empty task context that is not actually used. - * Internal use only. + * An empty task context that does not represent an actual task. */ - private[spark] def empty(): TaskContext = { - new TaskContextImpl(0, 0, 0, 0, null, null) + private[spark] def empty(): TaskContextImpl = { + new TaskContextImpl(0, 0, 0, 0, null, null, Seq.empty) } + } @@ -187,4 +187,9 @@ abstract class TaskContext extends Serializable { * accumulator id and the value of the Map is the latest accumulator local value. */ private[spark] def collectAccumulators(): Map[Long, Any] + + /** + * Accumulators for tracking internal metrics indexed by the name. 
+ */ + private[spark] val internalMetricsToAccumulators: Map[String, Accumulator[Long]] } diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 9ee168ae01..5df94c6d3a 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -32,6 +32,7 @@ private[spark] class TaskContextImpl( override val attemptNumber: Int, override val taskMemoryManager: TaskMemoryManager, @transient private val metricsSystem: MetricsSystem, + internalAccumulators: Seq[Accumulator[Long]], val runningLocally: Boolean = false, val taskMetrics: TaskMetrics = TaskMetrics.empty) extends TaskContext @@ -114,4 +115,11 @@ private[spark] class TaskContextImpl( private[spark] override def collectAccumulators(): Map[Long, Any] = synchronized { accumulators.mapValues(_.localValue).toMap } + + private[spark] override val internalMetricsToAccumulators: Map[String, Accumulator[Long]] = { + // Explicitly register internal accumulators here because these are + // not captured in the task closure and are already deserialized + internalAccumulators.foreach(registerAccumulator) + internalAccumulators.map { a => (a.name.get, a) }.toMap + } } diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 130b58882d..9c617fc719 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -23,8 +23,7 @@ import java.io.{IOException, ObjectOutputStream} import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext} -import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency} +import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap, CompactBuffer} import org.apache.spark.util.Utils @@ -169,8 +168,10 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: for ((it, depNum) <- rddIterators) { map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum)))) } - context.taskMetrics.incMemoryBytesSpilled(map.memoryBytesSpilled) - context.taskMetrics.incDiskBytesSpilled(map.diskBytesSpilled) + context.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled) + context.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes) new InterruptibleIterator(context, map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]]) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala index e0edd7d4ae..11d123eec4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala @@ -24,11 +24,12 @@ import org.apache.spark.annotation.DeveloperApi * Information about an [[org.apache.spark.Accumulable]] modified during a task or stage. 
*/ @DeveloperApi -class AccumulableInfo ( +class AccumulableInfo private[spark] ( val id: Long, val name: String, val update: Option[String], // represents a partial update within a task - val value: String) { + val value: String, + val internal: Boolean) { override def equals(other: Any): Boolean = other match { case acc: AccumulableInfo => @@ -40,10 +41,10 @@ class AccumulableInfo ( object AccumulableInfo { def apply(id: Long, name: String, update: Option[String], value: String): AccumulableInfo = { - new AccumulableInfo(id, name, update, value) + new AccumulableInfo(id, name, update, value, internal = false) } def apply(id: Long, name: String, value: String): AccumulableInfo = { - new AccumulableInfo(id, name, None, value) + new AccumulableInfo(id, name, None, value, internal = false) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index c4fa277c21..bb489c6b6e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -773,16 +773,26 @@ class DAGScheduler( stage.pendingTasks.clear() // First figure out the indexes of partition ids to compute. - val partitionsToCompute: Seq[Int] = { + val (allPartitions: Seq[Int], partitionsToCompute: Seq[Int]) = { stage match { case stage: ShuffleMapStage => - (0 until stage.numPartitions).filter(id => stage.outputLocs(id).isEmpty) + val allPartitions = 0 until stage.numPartitions + val filteredPartitions = allPartitions.filter { id => stage.outputLocs(id).isEmpty } + (allPartitions, filteredPartitions) case stage: ResultStage => val job = stage.resultOfJob.get - (0 until job.numPartitions).filter(id => !job.finished(id)) + val allPartitions = 0 until job.numPartitions + val filteredPartitions = allPartitions.filter { id => !job.finished(id) } + (allPartitions, filteredPartitions) } } + // Reset internal accumulators only if this stage is not partially submitted + // Otherwise, we may override existing accumulator values from some tasks + if (allPartitions == partitionsToCompute) { + stage.resetInternalAccumulators() + } + val properties = jobIdToActiveJob.get(stage.firstJobId).map(_.properties).orNull runningStages += stage @@ -852,7 +862,8 @@ class DAGScheduler( partitionsToCompute.map { id => val locs = taskIdToLocations(id) val part = stage.rdd.partitions(id) - new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs) + new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, + taskBinary, part, locs, stage.internalAccumulators) } case stage: ResultStage => @@ -861,7 +872,8 @@ class DAGScheduler( val p: Int = job.partitions(id) val part = stage.rdd.partitions(p) val locs = taskIdToLocations(id) - new ResultTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs, id) + new ResultTask(stage.id, stage.latestInfo.attemptId, + taskBinary, part, locs, id, stage.internalAccumulators) } } } catch { @@ -916,9 +928,11 @@ class DAGScheduler( // To avoid UI cruft, ignore cases where value wasn't updated if (acc.name.isDefined && partialValue != acc.zero) { val name = acc.name.get - stage.latestInfo.accumulables(id) = AccumulableInfo(id, name, s"${acc.value}") + val value = s"${acc.value}" + stage.latestInfo.accumulables(id) = + new AccumulableInfo(id, name, None, value, acc.isInternal) event.taskInfo.accumulables += - AccumulableInfo(id, name, Some(s"$partialValue"), s"${acc.value}") + new AccumulableInfo(id, name, 
Some(s"$partialValue"), value, acc.isInternal) } } } catch { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 9c2606e278..c4dc080e2b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -45,8 +45,10 @@ private[spark] class ResultTask[T, U]( taskBinary: Broadcast[Array[Byte]], partition: Partition, @transient locs: Seq[TaskLocation], - val outputId: Int) - extends Task[U](stageId, stageAttemptId, partition.index) with Serializable { + val outputId: Int, + internalAccumulators: Seq[Accumulator[Long]]) + extends Task[U](stageId, stageAttemptId, partition.index, internalAccumulators) + with Serializable { @transient private[this] val preferredLocs: Seq[TaskLocation] = { if (locs == null) Nil else locs.toSet.toSeq diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 14c8c00961..f478f9982a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -43,12 +43,14 @@ private[spark] class ShuffleMapTask( stageAttemptId: Int, taskBinary: Broadcast[Array[Byte]], partition: Partition, - @transient private var locs: Seq[TaskLocation]) - extends Task[MapStatus](stageId, stageAttemptId, partition.index) with Logging { + @transient private var locs: Seq[TaskLocation], + internalAccumulators: Seq[Accumulator[Long]]) + extends Task[MapStatus](stageId, stageAttemptId, partition.index, internalAccumulators) + with Logging { /** A constructor used only in test suites. This does not require passing in an RDD. */ def this(partitionId: Int) { - this(0, 0, null, new Partition { override def index: Int = 0 }, null) + this(0, 0, null, new Partition { override def index: Int = 0 }, null, null) } @transient private val preferredLocs: Seq[TaskLocation] = { @@ -69,7 +71,7 @@ private[spark] class ShuffleMapTask( val manager = SparkEnv.get.shuffleManager writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context) writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]]) - return writer.stop(success = true).get + writer.stop(success = true).get } catch { case e: Exception => try { diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 40a333a3e0..de05ee256d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -68,6 +68,22 @@ private[spark] abstract class Stage( val name = callSite.shortForm val details = callSite.longForm + private var _internalAccumulators: Seq[Accumulator[Long]] = Seq.empty + + /** Internal accumulators shared across all tasks in this stage. */ + def internalAccumulators: Seq[Accumulator[Long]] = _internalAccumulators + + /** + * Re-initialize the internal accumulators associated with this stage. + * + * This is called every time the stage is submitted, *except* when a subset of tasks + * belonging to this stage has already finished. Otherwise, reinitializing the internal + * accumulators here again will override partial values from the finished tasks. 
+ */ + def resetInternalAccumulators(): Unit = { + _internalAccumulators = InternalAccumulator.create() + } + /** * Pointer to the [StageInfo] object for the most recent attempt. This needs to be initialized * here, before any attempts have actually been created, because the DAGScheduler uses this diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 1978305cfe..9edf9f048f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -23,7 +23,7 @@ import java.nio.ByteBuffer import scala.collection.mutable.HashMap import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.{SparkEnv, TaskContextImpl, TaskContext} +import org.apache.spark.{Accumulator, SparkEnv, TaskContextImpl, TaskContext} import org.apache.spark.executor.TaskMetrics import org.apache.spark.serializer.SerializerInstance import org.apache.spark.unsafe.memory.TaskMemoryManager @@ -47,7 +47,8 @@ import org.apache.spark.util.Utils private[spark] abstract class Task[T]( val stageId: Int, val stageAttemptId: Int, - var partitionId: Int) extends Serializable { + val partitionId: Int, + internalAccumulators: Seq[Accumulator[Long]]) extends Serializable { /** * The key of the Map is the accumulator id and the value of the Map is the latest accumulator @@ -68,12 +69,13 @@ private[spark] abstract class Task[T]( metricsSystem: MetricsSystem) : (T, AccumulatorUpdates) = { context = new TaskContextImpl( - stageId = stageId, - partitionId = partitionId, - taskAttemptId = taskAttemptId, - attemptNumber = attemptNumber, - taskMemoryManager = taskMemoryManager, - metricsSystem = metricsSystem, + stageId, + partitionId, + taskAttemptId, + attemptNumber, + taskMemoryManager, + metricsSystem, + internalAccumulators, runningLocally = false) TaskContext.setTaskContext(context) context.taskMetrics.setHostname(Utils.localHostName()) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index de79fa56f0..0c8f08f0f3 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -17,7 +17,7 @@ package org.apache.spark.shuffle.hash -import org.apache.spark.{InterruptibleIterator, Logging, MapOutputTracker, SparkEnv, TaskContext} +import org.apache.spark._ import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} @@ -100,8 +100,10 @@ private[spark] class HashShuffleReader[K, C]( // the ExternalSorter won't spill to disk. 
val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser)) sorter.insertAll(aggregatedIter) - context.taskMetrics.incMemoryBytesSpilled(sorter.memoryBytesSpilled) - context.taskMetrics.incDiskBytesSpilled(sorter.diskBytesSpilled) + context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) + context.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes) sorter.iterator case None => aggregatedIter diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala index e2d25e3636..cb122eaed8 100644 --- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala +++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala @@ -62,6 +62,13 @@ private[spark] object ToolTips { """Time that the executor spent paused for Java garbage collection while the task was running.""" + val PEAK_EXECUTION_MEMORY = + """Execution memory refers to the memory used by internal data structures created during + shuffles, aggregations and joins when Tungsten is enabled. The value of this accumulator + should be approximately the sum of the peak sizes across all such data structures created + in this task. For SQL jobs, this only tracks all unsafe operators, broadcast joins, and + external sort.""" + val JOB_TIMELINE = """Shows when jobs started and ended and when executors joined or left. Drag to scroll. Click Enable Zooming and use mouse wheel to zoom in/out.""" diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index cf04b5e592..3954c3d1ef 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -26,6 +26,7 @@ import scala.xml.{Elem, Node, Unparsed} import org.apache.commons.lang3.StringEscapeUtils +import org.apache.spark.{InternalAccumulator, SparkConf} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo} import org.apache.spark.ui._ @@ -67,6 +68,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { // if we find that it's okay. private val MAX_TIMELINE_TASKS = parent.conf.getInt("spark.ui.timeline.tasks.maximum", 1000) + private val displayPeakExecutionMemory = + parent.conf.getOption("spark.sql.unsafe.enabled").exists(_.toBoolean) def render(request: HttpServletRequest): Seq[Node] = { progressListener.synchronized { @@ -114,10 +117,11 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val stageData = stageDataOption.get val tasks = stageData.taskData.values.toSeq.sortBy(_.taskInfo.launchTime) - val numCompleted = tasks.count(_.taskInfo.finished) - val accumulables = progressListener.stageIdToData((stageId, stageAttemptId)).accumulables - val hasAccumulators = accumulables.size > 0 + + val allAccumulables = progressListener.stageIdToData((stageId, stageAttemptId)).accumulables + val externalAccumulables = allAccumulables.values.filter { acc => !acc.internal } + val hasAccumulators = externalAccumulables.size > 0 val summary =
@@ -221,6 +225,15 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { Getting Result Time + {if (displayPeakExecutionMemory) { +
+              <li>
+                <span data-toggle="tooltip"
+                      title={ToolTips.PEAK_EXECUTION_MEMORY} data-placement="right">
+                  <input type="checkbox" name={TaskDetailsClassNames.PEAK_EXECUTION_MEMORY}/>
+                  <span class="additional-metric-title">Peak Execution Memory</span>
+                </span>
+              </li>
+            }}
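The table and summary code further down reads each task's peak execution memory back out of its accumulator updates by name. A minimal Scala sketch of that lookup, assuming only the PEAK_EXECUTION_MEMORY key and AccumulableInfo from this patch (the helper name is illustrative, and InternalAccumulator is private[spark], so this only compiles inside Spark itself):

import org.apache.spark.InternalAccumulator
import org.apache.spark.scheduler.AccumulableInfo

// Illustrative helper: a task's peak execution memory in bytes, or 0 if the
// internal accumulator was never set for that task.
def peakExecutionMemory(accumulables: Iterable[AccumulableInfo]): Long = {
  accumulables
    .find(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY)
    .map(_.value.toLong)
    .getOrElse(0L)
}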
    @@ -241,11 +254,12 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val accumulableTable = UIUtils.listingTable( accumulableHeaders, accumulableRow, - accumulables.values.toSeq) + externalAccumulables.toSeq) val currentTime = System.currentTimeMillis() val (taskTable, taskTableHTML) = try { val _taskTable = new TaskPagedTable( + parent.conf, UIUtils.prependBaseUri(parent.basePath) + s"/stages/stage?id=${stageId}&attempt=${stageAttemptId}", tasks, @@ -294,12 +308,14 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { else { def getDistributionQuantiles(data: Seq[Double]): IndexedSeq[Double] = Distribution(data).get.getQuantiles() - def getFormattedTimeQuantiles(times: Seq[Double]): Seq[Node] = { getDistributionQuantiles(times).map { millis => {UIUtils.formatDuration(millis.toLong)} } } + def getFormattedSizeQuantiles(data: Seq[Double]): Seq[Elem] = { + getDistributionQuantiles(data).map(d => {Utils.bytesToString(d.toLong)}) + } val deserializationTimes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.executorDeserializeTime.toDouble @@ -349,6 +365,23 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { +: getFormattedTimeQuantiles(gettingResultTimes) + + val peakExecutionMemory = validTasks.map { case TaskUIData(info, _, _) => + info.accumulables + .find { acc => acc.name == InternalAccumulator.PEAK_EXECUTION_MEMORY } + .map { acc => acc.value.toLong } + .getOrElse(0L) + .toDouble + } + val peakExecutionMemoryQuantiles = { + + + Peak Execution Memory + + +: getFormattedSizeQuantiles(peakExecutionMemory) + } + // The scheduler delay includes the network delay to send the task to the worker // machine and to send back the result (but not the time to fetch the task result, // if it needed to be fetched from the block manager on the worker). @@ -359,10 +392,6 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { title={ToolTips.SCHEDULER_DELAY} data-placement="right">Scheduler Delay val schedulerDelayQuantiles = schedulerDelayTitle +: getFormattedTimeQuantiles(schedulerDelays) - - def getFormattedSizeQuantiles(data: Seq[Double]): Seq[Elem] = - getDistributionQuantiles(data).map(d => {Utils.bytesToString(d.toLong)}) - def getFormattedSizeQuantilesWithRecords(data: Seq[Double], records: Seq[Double]) : Seq[Elem] = { val recordDist = getDistributionQuantiles(records).iterator @@ -466,6 +495,13 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { {serializationQuantiles} , {gettingResultQuantiles}, + if (displayPeakExecutionMemory) { + + {peakExecutionMemoryQuantiles} + + } else { + Nil + }, if (stageData.hasInput) {inputQuantiles} else Nil, if (stageData.hasOutput) {outputQuantiles} else Nil, if (stageData.hasShuffleRead) { @@ -499,7 +535,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val executorTable = new ExecutorTable(stageId, stageAttemptId, parent) val maybeAccumulableTable: Seq[Node] = - if (accumulables.size > 0) {

      <h4>Accumulators</h4> ++ accumulableTable } else Seq()
+      if (hasAccumulators) { <h4>Accumulators</h4>

    ++ accumulableTable } else Seq() val content = summary ++ @@ -750,29 +786,30 @@ private[ui] case class TaskTableRowBytesSpilledData( * Contains all data that needs for sorting and generating HTML. Using this one rather than * TaskUIData to avoid creating duplicate contents during sorting the data. */ -private[ui] case class TaskTableRowData( - index: Int, - taskId: Long, - attempt: Int, - speculative: Boolean, - status: String, - taskLocality: String, - executorIdAndHost: String, - launchTime: Long, - duration: Long, - formatDuration: String, - schedulerDelay: Long, - taskDeserializationTime: Long, - gcTime: Long, - serializationTime: Long, - gettingResultTime: Long, - accumulators: Option[String], // HTML - input: Option[TaskTableRowInputData], - output: Option[TaskTableRowOutputData], - shuffleRead: Option[TaskTableRowShuffleReadData], - shuffleWrite: Option[TaskTableRowShuffleWriteData], - bytesSpilled: Option[TaskTableRowBytesSpilledData], - error: String) +private[ui] class TaskTableRowData( + val index: Int, + val taskId: Long, + val attempt: Int, + val speculative: Boolean, + val status: String, + val taskLocality: String, + val executorIdAndHost: String, + val launchTime: Long, + val duration: Long, + val formatDuration: String, + val schedulerDelay: Long, + val taskDeserializationTime: Long, + val gcTime: Long, + val serializationTime: Long, + val gettingResultTime: Long, + val peakExecutionMemoryUsed: Long, + val accumulators: Option[String], // HTML + val input: Option[TaskTableRowInputData], + val output: Option[TaskTableRowOutputData], + val shuffleRead: Option[TaskTableRowShuffleReadData], + val shuffleWrite: Option[TaskTableRowShuffleWriteData], + val bytesSpilled: Option[TaskTableRowBytesSpilledData], + val error: String) private[ui] class TaskDataSource( tasks: Seq[TaskUIData], @@ -816,10 +853,15 @@ private[ui] class TaskDataSource( val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L) val gettingResultTime = getGettingResultTime(info, currentTime) - val maybeAccumulators = info.accumulables - val accumulatorsReadable = maybeAccumulators.map { acc => + val (taskInternalAccumulables, taskExternalAccumulables) = + info.accumulables.partition(_.internal) + val externalAccumulableReadable = taskExternalAccumulables.map { acc => StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update.get}") } + val peakExecutionMemoryUsed = taskInternalAccumulables + .find { acc => acc.name == InternalAccumulator.PEAK_EXECUTION_MEMORY } + .map { acc => acc.value.toLong } + .getOrElse(0L) val maybeInput = metrics.flatMap(_.inputMetrics) val inputSortable = maybeInput.map(_.bytesRead).getOrElse(0L) @@ -923,7 +965,7 @@ private[ui] class TaskDataSource( None } - TaskTableRowData( + new TaskTableRowData( info.index, info.taskId, info.attempt, @@ -939,7 +981,8 @@ private[ui] class TaskDataSource( gcTime, serializationTime, gettingResultTime, - if (hasAccumulators) Some(accumulatorsReadable.mkString("
    ")) else None, + peakExecutionMemoryUsed, + if (hasAccumulators) Some(externalAccumulableReadable.mkString("
    ")) else None, input, output, shuffleRead, @@ -1006,6 +1049,10 @@ private[ui] class TaskDataSource( override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = Ordering.Long.compare(x.gettingResultTime, y.gettingResultTime) } + case "Peak Execution Memory" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.peakExecutionMemoryUsed, y.peakExecutionMemoryUsed) + } case "Accumulators" => if (hasAccumulators) { new Ordering[TaskTableRowData] { @@ -1132,6 +1179,7 @@ private[ui] class TaskDataSource( } private[ui] class TaskPagedTable( + conf: SparkConf, basePath: String, data: Seq[TaskUIData], hasAccumulators: Boolean, @@ -1143,7 +1191,11 @@ private[ui] class TaskPagedTable( currentTime: Long, pageSize: Int, sortColumn: String, - desc: Boolean) extends PagedTable[TaskTableRowData]{ + desc: Boolean) extends PagedTable[TaskTableRowData] { + + // We only track peak memory used for unsafe operators + private val displayPeakExecutionMemory = + conf.getOption("spark.sql.unsafe.enabled").exists(_.toBoolean) override def tableId: String = "" @@ -1195,6 +1247,13 @@ private[ui] class TaskPagedTable( ("GC Time", ""), ("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME), ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME)) ++ + { + if (displayPeakExecutionMemory) { + Seq(("Peak Execution Memory", TaskDetailsClassNames.PEAK_EXECUTION_MEMORY)) + } else { + Nil + } + } ++ {if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++ {if (hasInput) Seq(("Input Size / Records", "")) else Nil} ++ {if (hasOutput) Seq(("Output Size / Records", "")) else Nil} ++ @@ -1271,6 +1330,11 @@ private[ui] class TaskPagedTable( {UIUtils.formatDuration(task.gettingResultTime)} + {if (displayPeakExecutionMemory) { + + {Utils.bytesToString(task.peakExecutionMemoryUsed)} + + }} {if (task.accumulators.nonEmpty) { {Unparsed(task.accumulators.get)} }} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala index 9bf67db8ac..d2dfc5a329 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala @@ -31,4 +31,5 @@ private[spark] object TaskDetailsClassNames { val SHUFFLE_READ_REMOTE_SIZE = "shuffle_read_remote" val RESULT_SERIALIZATION_TIME = "serialization_time" val GETTING_RESULT_TIME = "getting_result_time" + val PEAK_EXECUTION_MEMORY = "peak_execution_memory" } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index d166037351..f929b12606 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -89,6 +89,7 @@ class ExternalAppendOnlyMap[K, V, C]( // Number of bytes spilled in total private var _diskBytesSpilled = 0L + def diskBytesSpilled: Long = _diskBytesSpilled // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided private val fileBufferSize = @@ -97,6 +98,10 @@ class ExternalAppendOnlyMap[K, V, C]( // Write metrics for current spill private var curWriteMetrics: ShuffleWriteMetrics = _ + // Peak size of the in-memory map observed so far, in bytes + private var _peakMemoryUsedBytes: Long = 0L + def 
peakMemoryUsedBytes: Long = _peakMemoryUsedBytes + private val keyComparator = new HashComparator[K] private val ser = serializer.newInstance() @@ -126,7 +131,11 @@ class ExternalAppendOnlyMap[K, V, C]( while (entries.hasNext) { curEntry = entries.next() - if (maybeSpill(currentMap, currentMap.estimateSize())) { + val estimatedSize = currentMap.estimateSize() + if (estimatedSize > _peakMemoryUsedBytes) { + _peakMemoryUsedBytes = estimatedSize + } + if (maybeSpill(currentMap, estimatedSize)) { currentMap = new SizeTrackingAppendOnlyMap[K, C] } currentMap.changeValue(curEntry._1, update) @@ -207,8 +216,6 @@ class ExternalAppendOnlyMap[K, V, C]( spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes)) } - def diskBytesSpilled: Long = _diskBytesSpilled - /** * Return an iterator that merges the in-memory map with the spilled maps. * If no spill has occurred, simply return the in-memory map's iterator. diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index ba7ec834d6..19287edbaf 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -152,6 +152,9 @@ private[spark] class ExternalSorter[K, V, C]( private var _diskBytesSpilled = 0L def diskBytesSpilled: Long = _diskBytesSpilled + // Peak size of the in-memory data structure observed so far, in bytes + private var _peakMemoryUsedBytes: Long = 0L + def peakMemoryUsedBytes: Long = _peakMemoryUsedBytes // A comparator for keys K that orders them within a partition to allow aggregation or sorting. // Can be a partial ordering by hash code if a total ordering is not provided through by the @@ -224,15 +227,22 @@ private[spark] class ExternalSorter[K, V, C]( return } + var estimatedSize = 0L if (usingMap) { - if (maybeSpill(map, map.estimateSize())) { + estimatedSize = map.estimateSize() + if (maybeSpill(map, estimatedSize)) { map = new PartitionedAppendOnlyMap[K, C] } } else { - if (maybeSpill(buffer, buffer.estimateSize())) { + estimatedSize = buffer.estimateSize() + if (maybeSpill(buffer, estimatedSize)) { buffer = newBuffer() } } + + if (estimatedSize > _peakMemoryUsedBytes) { + _peakMemoryUsedBytes = estimatedSize + } } /** @@ -684,8 +694,10 @@ private[spark] class ExternalSorter[K, V, C]( } } - context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled) - context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled) + context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled) + context.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(peakMemoryUsedBytes) lengths } diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index e948ca3347..ffe4b4baff 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -51,7 +51,6 @@ import org.junit.Test; import org.apache.spark.api.java.*; import org.apache.spark.api.java.function.*; -import org.apache.spark.executor.TaskMetrics; import org.apache.spark.input.PortableDataStream; import org.apache.spark.partial.BoundedDouble; import org.apache.spark.partial.PartialResult; @@ -1011,7 +1010,7 @@ public class JavaAPISuite implements Serializable { @Test public void iterator() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2); - 
TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, null, false, new TaskMetrics()); + TaskContext context = TaskContext$.MODULE$.empty(); Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue()); } diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java index 04fc09b323..98c32bbc29 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java @@ -190,6 +190,7 @@ public class UnsafeShuffleWriterSuite { }); when(taskContext.taskMetrics()).thenReturn(taskMetrics); + when(taskContext.internalMetricsToAccumulators()).thenReturn(null); when(shuffleDep.serializer()).thenReturn(Option.apply(serializer)); when(shuffleDep.partitioner()).thenReturn(hashPartitioner); @@ -542,4 +543,57 @@ public class UnsafeShuffleWriterSuite { writer.stop(false); assertSpillFilesWereCleanedUp(); } + + @Test + public void testPeakMemoryUsed() throws Exception { + final long recordLengthBytes = 8; + final long pageSizeBytes = 256; + final long numRecordsPerPage = pageSizeBytes / recordLengthBytes; + final SparkConf conf = new SparkConf().set("spark.buffer.pageSize", pageSizeBytes + "b"); + final UnsafeShuffleWriter writer = + new UnsafeShuffleWriter( + blockManager, + shuffleBlockResolver, + taskMemoryManager, + shuffleMemoryManager, + new UnsafeShuffleHandle(0, 1, shuffleDep), + 0, // map id + taskContext, + conf); + + // Peak memory should be monotonically increasing. More specifically, every time + // we allocate a new page it should increase by exactly the size of the page. + long previousPeakMemory = writer.getPeakMemoryUsedBytes(); + long newPeakMemory; + try { + for (int i = 0; i < numRecordsPerPage * 10; i++) { + writer.insertRecordIntoSorter(new Tuple2(1, 1)); + newPeakMemory = writer.getPeakMemoryUsedBytes(); + if (i % numRecordsPerPage == 0) { + // We allocated a new page for this record, so peak memory should change + assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory); + } else { + assertEquals(previousPeakMemory, newPeakMemory); + } + previousPeakMemory = newPeakMemory; + } + + // Spilling should not change peak memory + writer.forceSorterToSpill(); + newPeakMemory = writer.getPeakMemoryUsedBytes(); + assertEquals(previousPeakMemory, newPeakMemory); + for (int i = 0; i < numRecordsPerPage; i++) { + writer.insertRecordIntoSorter(new Tuple2(1, 1)); + } + newPeakMemory = writer.getPeakMemoryUsedBytes(); + assertEquals(previousPeakMemory, newPeakMemory); + + // Closing the writer should not change peak memory + writer.closeAndWriteOutput(); + newPeakMemory = writer.getPeakMemoryUsedBytes(); + assertEquals(previousPeakMemory, newPeakMemory); + } finally { + writer.stop(false); + } + } } diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index dbb7c662d7..0e23a64fb7 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -25,6 +25,7 @@ import org.junit.*; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import static org.hamcrest.Matchers.greaterThan; +import static org.junit.Assert.*; import static org.mockito.AdditionalMatchers.geq; 
import static org.mockito.Mockito.*; @@ -495,4 +496,42 @@ public abstract class AbstractBytesToBytesMapSuite { map.growAndRehash(); map.free(); } + + @Test + public void testTotalMemoryConsumption() { + final long recordLengthBytes = 24; + final long pageSizeBytes = 256 + 8; // 8 bytes for end-of-page marker + final long numRecordsPerPage = (pageSizeBytes - 8) / recordLengthBytes; + final BytesToBytesMap map = new BytesToBytesMap( + taskMemoryManager, shuffleMemoryManager, 1024, pageSizeBytes); + + // Since BytesToBytesMap is append-only, we expect the total memory consumption to be + // monotonically increasing. More specifically, every time we allocate a new page it + // should increase by exactly the size of the page. In this regard, the memory usage + // at any given time is also the peak memory used. + long previousMemory = map.getTotalMemoryConsumption(); + long newMemory; + try { + for (long i = 0; i < numRecordsPerPage * 10; i++) { + final long[] value = new long[]{i}; + map.lookup(value, PlatformDependent.LONG_ARRAY_OFFSET, 8).putNewKey( + value, + PlatformDependent.LONG_ARRAY_OFFSET, + 8, + value, + PlatformDependent.LONG_ARRAY_OFFSET, + 8); + newMemory = map.getTotalMemoryConsumption(); + if (i % numRecordsPerPage == 0) { + // We allocated a new page for this record, so peak memory should change + assertEquals(previousMemory + pageSizeBytes, newMemory); + } else { + assertEquals(previousMemory, newMemory); + } + previousMemory = newMemory; + } + } finally { + map.free(); + } + } } diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index 52fa8bcd57..c11949d57a 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -247,4 +247,50 @@ public class UnsafeExternalSorterSuite { assertSpillFilesWereCleanedUp(); } + @Test + public void testPeakMemoryUsed() throws Exception { + final long recordLengthBytes = 8; + final long pageSizeBytes = 256; + final long numRecordsPerPage = pageSizeBytes / recordLengthBytes; + final UnsafeExternalSorter sorter = UnsafeExternalSorter.create( + taskMemoryManager, + shuffleMemoryManager, + blockManager, + taskContext, + recordComparator, + prefixComparator, + 1024, + pageSizeBytes); + + // Peak memory should be monotonically increasing. More specifically, every time + // we allocate a new page it should increase by exactly the size of the page. 
+ long previousPeakMemory = sorter.getPeakMemoryUsedBytes(); + long newPeakMemory; + try { + for (int i = 0; i < numRecordsPerPage * 10; i++) { + insertNumber(sorter, i); + newPeakMemory = sorter.getPeakMemoryUsedBytes(); + if (i % numRecordsPerPage == 0) { + // We allocated a new page for this record, so peak memory should change + assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory); + } else { + assertEquals(previousPeakMemory, newPeakMemory); + } + previousPeakMemory = newPeakMemory; + } + + // Spilling should not change peak memory + sorter.spill(); + newPeakMemory = sorter.getPeakMemoryUsedBytes(); + assertEquals(previousPeakMemory, newPeakMemory); + for (int i = 0; i < numRecordsPerPage; i++) { + insertNumber(sorter, i); + } + newPeakMemory = sorter.getPeakMemoryUsedBytes(); + assertEquals(previousPeakMemory, newPeakMemory); + } finally { + sorter.freeMemory(); + } + } + } diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index e942d6579b..48f549575f 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -18,13 +18,17 @@ package org.apache.spark import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import scala.ref.WeakReference import org.scalatest.Matchers +import org.scalatest.exceptions.TestFailedException + +import org.apache.spark.scheduler._ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContext { - + import InternalAccumulator._ implicit def setAccum[A]: AccumulableParam[mutable.Set[A], A] = new AccumulableParam[mutable.Set[A], A] { @@ -155,4 +159,191 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex assert(!Accumulators.originals.get(accId).isDefined) } + test("internal accumulators in TaskContext") { + val accums = InternalAccumulator.create() + val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null, accums) + val internalMetricsToAccums = taskContext.internalMetricsToAccumulators + val collectedInternalAccums = taskContext.collectInternalAccumulators() + val collectedAccums = taskContext.collectAccumulators() + assert(internalMetricsToAccums.size > 0) + assert(internalMetricsToAccums.values.forall(_.isInternal)) + assert(internalMetricsToAccums.contains(TEST_ACCUMULATOR)) + val testAccum = internalMetricsToAccums(TEST_ACCUMULATOR) + assert(collectedInternalAccums.size === internalMetricsToAccums.size) + assert(collectedInternalAccums.size === collectedAccums.size) + assert(collectedInternalAccums.contains(testAccum.id)) + assert(collectedAccums.contains(testAccum.id)) + } + + test("internal accumulators in a stage") { + val listener = new SaveInfoListener + val numPartitions = 10 + sc = new SparkContext("local", "test") + sc.addSparkListener(listener) + // Have each task add 1 to the internal accumulator + sc.parallelize(1 to 100, numPartitions).mapPartitions { iter => + TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1 + iter + }.count() + val stageInfos = listener.getCompletedStageInfos + val taskInfos = listener.getCompletedTaskInfos + assert(stageInfos.size === 1) + assert(taskInfos.size === numPartitions) + // The accumulator values should be merged in the stage + val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR) + assert(stageAccum.value.toLong === numPartitions) + // The accumulator should be updated locally on each task + val 
taskAccumValues = taskInfos.map { taskInfo => + val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR) + assert(taskAccum.update.isDefined) + assert(taskAccum.update.get.toLong === 1) + taskAccum.value.toLong + } + // Each task should keep track of the partial value on the way, i.e. 1, 2, ... numPartitions + assert(taskAccumValues.sorted === (1L to numPartitions).toSeq) + } + + test("internal accumulators in multiple stages") { + val listener = new SaveInfoListener + val numPartitions = 10 + sc = new SparkContext("local", "test") + sc.addSparkListener(listener) + // Each stage creates its own set of internal accumulators so the + // values for the same metric should not be mixed up across stages + sc.parallelize(1 to 100, numPartitions) + .map { i => (i, i) } + .mapPartitions { iter => + TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1 + iter + } + .reduceByKey { case (x, y) => x + y } + .mapPartitions { iter => + TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 10 + iter + } + .repartition(numPartitions * 2) + .mapPartitions { iter => + TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 100 + iter + } + .count() + // We ran 3 stages, and the accumulator values should be distinct + val stageInfos = listener.getCompletedStageInfos + assert(stageInfos.size === 3) + val firstStageAccum = findAccumulableInfo(stageInfos(0).accumulables.values, TEST_ACCUMULATOR) + val secondStageAccum = findAccumulableInfo(stageInfos(1).accumulables.values, TEST_ACCUMULATOR) + val thirdStageAccum = findAccumulableInfo(stageInfos(2).accumulables.values, TEST_ACCUMULATOR) + assert(firstStageAccum.value.toLong === numPartitions) + assert(secondStageAccum.value.toLong === numPartitions * 10) + assert(thirdStageAccum.value.toLong === numPartitions * 2 * 100) + } + + test("internal accumulators in fully resubmitted stages") { + testInternalAccumulatorsWithFailedTasks((i: Int) => true) // fail all tasks + } + + test("internal accumulators in partially resubmitted stages") { + testInternalAccumulatorsWithFailedTasks((i: Int) => i % 2 == 0) // fail a subset + } + + /** + * Return the accumulable info that matches the specified name. + */ + private def findAccumulableInfo( + accums: Iterable[AccumulableInfo], + name: String): AccumulableInfo = { + accums.find { a => a.name == name }.getOrElse { + throw new TestFailedException(s"internal accumulator '$name' not found", 0) + } + } + + /** + * Test whether internal accumulators are merged properly if some tasks fail. 
+ */ + private def testInternalAccumulatorsWithFailedTasks(failCondition: (Int => Boolean)): Unit = { + val listener = new SaveInfoListener + val numPartitions = 10 + val numFailedPartitions = (0 until numPartitions).count(failCondition) + // This says use 1 core and retry tasks up to 2 times + sc = new SparkContext("local[1, 2]", "test") + sc.addSparkListener(listener) + sc.parallelize(1 to 100, numPartitions).mapPartitionsWithIndex { case (i, iter) => + val taskContext = TaskContext.get() + taskContext.internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1 + // Fail the first attempts of a subset of the tasks + if (failCondition(i) && taskContext.attemptNumber() == 0) { + throw new Exception("Failing a task intentionally.") + } + iter + }.count() + val stageInfos = listener.getCompletedStageInfos + val taskInfos = listener.getCompletedTaskInfos + assert(stageInfos.size === 1) + assert(taskInfos.size === numPartitions + numFailedPartitions) + val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR) + // We should not double count values in the merged accumulator + assert(stageAccum.value.toLong === numPartitions) + val taskAccumValues = taskInfos.flatMap { taskInfo => + if (!taskInfo.failed) { + // If a task succeeded, its update value should always be 1 + val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR) + assert(taskAccum.update.isDefined) + assert(taskAccum.update.get.toLong === 1) + Some(taskAccum.value.toLong) + } else { + // If a task failed, we should not get its accumulator values + assert(taskInfo.accumulables.isEmpty) + None + } + } + assert(taskAccumValues.sorted === (1L to numPartitions).toSeq) + } + +} + +private[spark] object AccumulatorSuite { + + /** + * Run one or more Spark jobs and verify that the peak execution memory accumulator + * is updated afterwards. + */ + def verifyPeakExecutionMemorySet( + sc: SparkContext, + testName: String)(testBody: => Unit): Unit = { + val listener = new SaveInfoListener + sc.addSparkListener(listener) + // Verify that the accumulator does not already exist + sc.parallelize(1 to 10).count() + val accums = listener.getCompletedStageInfos.flatMap(_.accumulables.values) + assert(!accums.exists(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY)) + testBody + // Verify that peak execution memory is updated + val accum = listener.getCompletedStageInfos + .flatMap(_.accumulables.values) + .find(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY) + .getOrElse { + throw new TestFailedException( + s"peak execution memory accumulator not set in '$testName'", 0) + } + assert(accum.value.toLong > 0) + } +} + +/** + * A simple listener that keeps track of the TaskInfos and StageInfos of all completed jobs. 
+ */ +private class SaveInfoListener extends SparkListener { + private val completedStageInfos: ArrayBuffer[StageInfo] = new ArrayBuffer[StageInfo] + private val completedTaskInfos: ArrayBuffer[TaskInfo] = new ArrayBuffer[TaskInfo] + + def getCompletedStageInfos: Seq[StageInfo] = completedStageInfos.toArray.toSeq + def getCompletedTaskInfos: Seq[TaskInfo] = completedTaskInfos.toArray.toSeq + + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { + completedStageInfos += stageCompleted.stageInfo + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + completedTaskInfos += taskEnd.taskInfo + } } diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index 618a5fb247..cb8bd04e49 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -21,7 +21,7 @@ import org.mockito.Mockito._ import org.scalatest.BeforeAndAfter import org.scalatest.mock.MockitoSugar -import org.apache.spark.executor.DataReadMethod +import org.apache.spark.executor.{DataReadMethod, TaskMetrics} import org.apache.spark.rdd.RDD import org.apache.spark.storage._ @@ -65,7 +65,7 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before // in blockManager.put is a losing battle. You have been warned. blockManager = sc.env.blockManager cacheManager = sc.env.cacheManager - val context = new TaskContextImpl(0, 0, 0, 0, null, null) + val context = TaskContext.empty() val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) val getValue = blockManager.get(RDDBlockId(rdd.id, split.index)) assert(computeValue.toList === List(1, 2, 3, 4)) @@ -77,7 +77,7 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before val result = new BlockResult(Array(5, 6, 7).iterator, DataReadMethod.Memory, 12) when(blockManager.get(RDDBlockId(0, 0))).thenReturn(Some(result)) - val context = new TaskContextImpl(0, 0, 0, 0, null, null) + val context = TaskContext.empty() val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(5, 6, 7)) } @@ -86,14 +86,14 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before // Local computation should not persist the resulting value, so don't expect a put(). 
when(blockManager.get(RDDBlockId(0, 0))).thenReturn(None) - val context = new TaskContextImpl(0, 0, 0, 0, null, null, true) + val context = new TaskContextImpl(0, 0, 0, 0, null, null, Seq.empty, runningLocally = true) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(1, 2, 3, 4)) } test("verify task metrics updated correctly") { cacheManager = sc.env.cacheManager - val context = new TaskContextImpl(0, 0, 0, 0, null, null) + val context = TaskContext.empty() cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY) assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2) } diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala index 3e8816a4c6..5f73ec8675 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala @@ -175,7 +175,7 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext { } val hadoopPart1 = generateFakeHadoopPartition() val pipedRdd = new PipedRDD(nums, "printenv " + varName) - val tContext = new TaskContextImpl(0, 0, 0, 0, null, null) + val tContext = TaskContext.empty() val rddIter = pipedRdd.compute(hadoopPart1, tContext) val arr = rddIter.toArray assert(arr(0) == "/some/path") diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala index b3ca150195..f7e16af9d3 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala @@ -19,9 +19,11 @@ package org.apache.spark.scheduler import org.apache.spark.TaskContext -class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0, 0) { +class FakeTask( + stageId: Int, + prefLocs: Seq[TaskLocation] = Nil) + extends Task[Int](stageId, 0, 0, Seq.empty) { override def runTask(context: TaskContext): Int = 0 - override def preferredLocations: Seq[TaskLocation] = prefLocs } diff --git a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala index 383855caef..f333247924 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala @@ -25,7 +25,7 @@ import org.apache.spark.TaskContext * A Task implementation that fails to serialize. 
*/ private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int) - extends Task[Array[Byte]](stageId, 0, 0) { + extends Task[Array[Byte]](stageId, 0, 0, Seq.empty) { override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte] override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]() diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 9201d1e1f3..450ab7b9fe 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -57,8 +57,9 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark } val closureSerializer = SparkEnv.get.closureSerializer.newInstance() val func = (c: TaskContext, i: Iterator[String]) => i.next() - val task = new ResultTask[String, String](0, 0, - sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0) + val taskBinary = sc.broadcast(closureSerializer.serialize((rdd, func)).array) + val task = new ResultTask[String, String]( + 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, Seq.empty) intercept[RuntimeException] { task.run(0, 0, null) } @@ -66,7 +67,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark } test("all TaskCompletionListeners should be called even if some fail") { - val context = new TaskContextImpl(0, 0, 0, 0, null, null) + val context = TaskContext.empty() val listener = mock(classOf[TaskCompletionListener]) context.addTaskCompletionListener(_ => throw new Exception("blah")) context.addTaskCompletionListener(listener) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 3abb99c4b2..f7cc4bb61d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -136,7 +136,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex /** * A Task implementation that results in a large serialized task. 
*/ -class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0) { +class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0, Seq.empty) { val randomBuffer = new Array[Byte](TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024) val random = new Random(0) random.nextBytes(randomBuffer) diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala index db718ecabb..05b3afef5b 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala @@ -138,7 +138,7 @@ class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { shuffleHandle, reduceId, reduceId + 1, - new TaskContextImpl(0, 0, 0, 0, null, null), + TaskContext.empty(), blockManager, mapOutputTracker) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index cf8bd8ae69..828153bdbf 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -29,7 +29,7 @@ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.PrivateMethodTester -import org.apache.spark.{SparkFunSuite, TaskContextImpl} +import org.apache.spark.{SparkFunSuite, TaskContext} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.shuffle.BlockFetchingListener @@ -95,7 +95,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT ) val iterator = new ShuffleBlockFetcherIterator( - new TaskContextImpl(0, 0, 0, 0, null, null), + TaskContext.empty(), transfer, blockManager, blocksByAddress, @@ -165,7 +165,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) - val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null) + val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( taskContext, transfer, @@ -227,7 +227,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) - val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null) + val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( taskContext, transfer, diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala new file mode 100644 index 0000000000..98f9314f31 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui + +import javax.servlet.http.HttpServletRequest + +import scala.xml.Node + +import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS} + +import org.apache.spark.{LocalSparkContext, SparkConf, SparkFunSuite, Success} +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.scheduler._ +import org.apache.spark.ui.jobs.{JobProgressListener, StagePage, StagesTab} +import org.apache.spark.ui.scope.RDDOperationGraphListener + +class StagePageSuite extends SparkFunSuite with LocalSparkContext { + + test("peak execution memory only displayed if unsafe is enabled") { + val unsafeConf = "spark.sql.unsafe.enabled" + val conf = new SparkConf().set(unsafeConf, "true") + val html = renderStagePage(conf).toString().toLowerCase + val targetString = "peak execution memory" + assert(html.contains(targetString)) + // Disable unsafe and make sure it's not there + val conf2 = new SparkConf().set(unsafeConf, "false") + val html2 = renderStagePage(conf2).toString().toLowerCase + assert(!html2.contains(targetString)) + } + + /** + * Render a stage page started with the given conf and return the HTML. + * This also runs a dummy stage to populate the page with useful content. + */ + private def renderStagePage(conf: SparkConf): Seq[Node] = { + val jobListener = new JobProgressListener(conf) + val graphListener = new RDDOperationGraphListener(conf) + val tab = mock(classOf[StagesTab], RETURNS_SMART_NULLS) + val request = mock(classOf[HttpServletRequest]) + when(tab.conf).thenReturn(conf) + when(tab.progressListener).thenReturn(jobListener) + when(tab.operationGraphListener).thenReturn(graphListener) + when(tab.appName).thenReturn("testing") + when(tab.headerTabs).thenReturn(Seq.empty) + when(request.getParameter("id")).thenReturn("0") + when(request.getParameter("attempt")).thenReturn("0") + val page = new StagePage(tab) + + // Simulate a stage in job progress listener + val stageInfo = new StageInfo(0, 0, "dummy", 1, Seq.empty, Seq.empty, "details") + val taskInfo = new TaskInfo(0, 0, 0, 0, "0", "localhost", TaskLocality.ANY, false) + jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo)) + jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo)) + taskInfo.markSuccessful() + jobListener.onTaskEnd( + SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, TaskMetrics.empty)) + jobListener.onStageCompleted(SparkListenerStageCompleted(stageInfo)) + page.render(request) + } + +} diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index 9c362f0de7..12e9bafcc9 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -399,4 +399,19 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { sc.stop() } + test("external aggregation updates peak execution memory") { + val conf = createSparkConf(loadDefaults = false) + 
.set("spark.shuffle.memoryFraction", "0.001") + .set("spark.shuffle.manager", "hash") // make sure we're not also using ExternalSorter + sc = new SparkContext("local", "test", conf) + // No spilling + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external map without spilling") { + sc.parallelize(1 to 10, 2).map { i => (i, i) }.reduceByKey(_ + _).count() + } + // With spilling + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external map with spilling") { + sc.parallelize(1 to 1000 * 1000, 2).map { i => (i, i) }.reduceByKey(_ + _).count() + } + } + } diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 986cd8623d..bdb0f4d507 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -692,7 +692,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { sortWithoutBreakingSortingContracts(createSparkConf(true, false)) } - def sortWithoutBreakingSortingContracts(conf: SparkConf) { + private def sortWithoutBreakingSortingContracts(conf: SparkConf) { conf.set("spark.shuffle.memoryFraction", "0.01") conf.set("spark.shuffle.manager", "sort") sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) @@ -743,5 +743,15 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { } sorter2.stop() - } + } + + test("sorting updates peak execution memory") { + val conf = createSparkConf(loadDefaults = false, kryo = false) + .set("spark.shuffle.manager", "sort") + sc = new SparkContext("local", "test", conf) + // Avoid aggregating here to make sure we're not also using ExternalAppendOnlyMap + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external sorter") { + sc.parallelize(1 to 1000, 2).repartition(100).count() + } + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 5e4c6232c9..193906d247 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -106,6 +106,13 @@ final class UnsafeExternalRowSorter { sorter.spill(); } + /** + * Return the peak memory used so far, in bytes. + */ + public long getPeakMemoryUsage() { + return sorter.getPeakMemoryUsedBytes(); + } + private void cleanupResources() { sorter.freeMemory(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index 9e2c9334a7..43d06ce9bd 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -208,6 +208,14 @@ public final class UnsafeFixedWidthAggregationMap { }; } + /** + * The memory used by this map's managed structures, in bytes. + * Note that this is also the peak memory used by this map, since the map is append-only. + */ + public long getMemoryUsage() { + return map.getTotalMemoryConsumption(); + } + /** * Free the memory associated with this map. This is idempotent and can be called multiple times. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index cd87b8deba..bf4905dc1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import java.io.IOException -import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.{InternalAccumulator, SparkEnv, TaskContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -263,11 +263,12 @@ case class GeneratedAggregate( assert(iter.hasNext, "There should be at least one row for this path") log.info("Using Unsafe-based aggregator") val pageSizeBytes = SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m") + val taskContext = TaskContext.get() val aggregationMap = new UnsafeFixedWidthAggregationMap( newAggregationBuffer(EmptyRow), aggregationBufferSchema, groupKeySchema, - TaskContext.get.taskMemoryManager(), + taskContext.taskMemoryManager(), SparkEnv.get.shuffleMemoryManager, 1024 * 16, // initial capacity pageSizeBytes, @@ -284,6 +285,10 @@ case class GeneratedAggregate( updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow)) } + // Record memory used in the process + taskContext.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(aggregationMap.getMemoryUsage) + new Iterator[InternalRow] { private[this] val mapIterator = aggregationMap.iterator() private[this] val resultProjection = resultProjectionBuilder() @@ -300,7 +305,7 @@ case class GeneratedAggregate( } else { // This is the last element in the iterator, so let's free the buffer. Before we do, // though, we need to make a defensive copy of the result so that we don't return an - // object that might contain dangling pointers to the freed memory + // object that might contain dangling pointers to the freed memory. 
val resultCopy = result.copy() aggregationMap.free() resultCopy diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 624efc1b1d..e73e2523a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins import scala.concurrent._ import scala.concurrent.duration._ +import org.apache.spark.{InternalAccumulator, TaskContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -70,7 +71,14 @@ case class BroadcastHashJoin( val broadcastRelation = Await.result(broadcastFuture, timeout) streamedPlan.execute().mapPartitions { streamedIter => - hashJoin(streamedIter, broadcastRelation.value) + val hashedRelation = broadcastRelation.value + hashedRelation match { + case unsafe: UnsafeHashedRelation => + TaskContext.get().internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize) + case _ => + } + hashJoin(streamedIter, hashedRelation) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index 309716a0ef..c35e439cc9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins import scala.concurrent._ import scala.concurrent.duration._ +import org.apache.spark.{InternalAccumulator, TaskContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -75,6 +76,13 @@ case class BroadcastHashOuterJoin( val hashTable = broadcastRelation.value val keyGenerator = streamedKeyGenerator + hashTable match { + case unsafe: UnsafeHashedRelation => + TaskContext.get().internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize) + case _ => + } + joinType match { case LeftOuter => streamedIter.flatMap(currentRow => { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index a60593911f..5bd06fbdca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.joins +import org.apache.spark.{InternalAccumulator, TaskContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -51,7 +52,14 @@ case class BroadcastLeftSemiJoinHash( val broadcastedRelation = sparkContext.broadcast(hashRelation) left.execute().mapPartitions { streamIter => - hashSemiJoin(streamIter, broadcastedRelation.value) + val hashedRelation = broadcastedRelation.value + hashedRelation match { + case unsafe: UnsafeHashedRelation => + TaskContext.get().internalMetricsToAccumulators( + 
InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize) + case _ => + } + hashSemiJoin(streamIter, hashedRelation) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index cc8bbfd2f8..58b4236f7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -183,8 +183,27 @@ private[joins] final class UnsafeHashedRelation( private[joins] def this() = this(null) // Needed for serialization // Use BytesToBytesMap in executor for better performance (it's created when deserialization) + // This is used in broadcast joins and distributed mode only @transient private[this] var binaryMap: BytesToBytesMap = _ + /** + * Return the size of the unsafe map on the executors. + * + * For broadcast joins, this hashed relation is bigger on the driver because it is + * represented as a Java hash map there. While serializing the map to the executors, + * however, we rehash the contents in a binary map to reduce the memory footprint on + * the executors. + * + * For non-broadcast joins or in local mode, return 0. + */ + def getUnsafeSize: Long = { + if (binaryMap != null) { + binaryMap.getTotalMemoryConsumption + } else { + 0 + } + } + override def get(key: InternalRow): Seq[InternalRow] = { val unsafeKey = key.asInstanceOf[UnsafeRow] @@ -214,7 +233,7 @@ private[joins] final class UnsafeHashedRelation( } } else { - // Use the JavaHashMap in Local mode or ShuffleHashJoin + // Use the Java HashMap in local mode or for non-broadcast joins (e.g. ShuffleHashJoin) hashTable.get(unsafeKey) } } @@ -316,6 +335,7 @@ private[joins] object UnsafeHashedRelation { keyGenerator: UnsafeProjection, sizeEstimate: Int): HashedRelation = { + // Use a Java hash table here because unsafe maps expect fixed size records val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate) // Create a mapping of buildKeys -> rows diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala index 92cf328c76..3192b6ebe9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import org.apache.spark.{InternalAccumulator, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ @@ -76,6 +77,11 @@ case class ExternalSort( val sorter = new ExternalSorter[InternalRow, Null, InternalRow](ordering = Some(ordering)) sorter.insertAll(iterator.map(r => (r.copy(), null))) val baseIterator = sorter.iterator.map(_._1) + val context = TaskContext.get() + context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) + context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) + context.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes) // TODO(marmbrus): The complex type signature below thwarts inference for no reason. 
CompletionIterator[InternalRow, Iterator[InternalRow]](baseIterator, sorter.stop()) }, preservesPartitioning = true) @@ -137,7 +143,11 @@ case class TungstenSort( if (testSpillFrequency > 0) { sorter.setTestSpillFrequency(testSpillFrequency) } - sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]]) + val sortedIterator = sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]]) + val taskContext = TaskContext.get() + taskContext.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.getPeakMemoryUsage) + sortedIterator }, preservesPartitioning = true) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index f1abae0720..29dfcf2575 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -21,6 +21,7 @@ import java.sql.Timestamp import org.scalatest.BeforeAndAfterAll +import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.DefaultParserDialect import org.apache.spark.sql.catalyst.errors.DialectException @@ -258,6 +259,23 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } } + private def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = { + val df = sql(sqlText) + // First, check if we have GeneratedAggregate. + val hasGeneratedAgg = df.queryExecution.executedPlan + .collect { case _: GeneratedAggregate | _: aggregate.Aggregate => true } + .nonEmpty + if (!hasGeneratedAgg) { + fail( + s""" + |Codegen is enabled, but query $sqlText does not have GeneratedAggregate in the plan. + |${df.queryExecution.simpleString} + """.stripMargin) + } + // Then, check results. + checkAnswer(df, expectedResults) + } + test("aggregation with codegen") { val originalValue = sqlContext.conf.codegenEnabled sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) @@ -267,26 +285,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { .unionAll(sqlContext.table("testData")) .registerTempTable("testData3x") - def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = { - val df = sql(sqlText) - // First, check if we have GeneratedAggregate. - var hasGeneratedAgg = false - df.queryExecution.executedPlan.foreach { - case generatedAgg: GeneratedAggregate => hasGeneratedAgg = true - case newAggregate: aggregate.Aggregate => hasGeneratedAgg = true - case _ => - } - if (!hasGeneratedAgg) { - fail( - s""" - |Codegen is enabled, but query $sqlText does not have GeneratedAggregate in the plan. - |${df.queryExecution.simpleString} - """.stripMargin) - } - // Then, check results. - checkAnswer(df, expectedResults) - } - try { // Just to group rows. 
testCodeGen( @@ -1605,6 +1603,28 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Row(new CalendarInterval(-(12 * 3 - 3), -(7L * MICROS_PER_WEEK + 123)))) } + test("aggregation with codegen updates peak execution memory") { + withSQLConf( + (SQLConf.CODEGEN_ENABLED.key, "true"), + (SQLConf.USE_SQL_AGGREGATE2.key, "false")) { + val sc = sqlContext.sparkContext + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "aggregation with codegen") { + testCodeGen( + "SELECT key, count(value) FROM testData GROUP BY key", + (1 to 100).map(i => Row(i, 1))) + } + } + } + + test("external sorting updates peak execution memory") { + withSQLConf((SQLConf.EXTERNAL_SORT.key, "true")) { + val sc = sqlContext.sparkContext + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external sort") { + sortTest() + } + } + } + test("SPARK-9511: error with table starting with number") { val df = sqlContext.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)) .toDF("num", "str") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala index c794984851..88bce0e319 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala @@ -21,6 +21,7 @@ import scala.util.Random import org.scalatest.BeforeAndAfterAll +import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.{RandomDataGenerator, Row, SQLConf} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.test.TestSQLContext @@ -59,6 +60,17 @@ class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll { ) } + test("sorting updates peak execution memory") { + val sc = TestSQLContext.sparkContext + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "unsafe external sort") { + checkThatPlansAgree( + (1 to 100).map(v => Tuple1(v)).toDF("a"), + (child: SparkPlan) => TungstenSort('a.asc :: Nil, true, child), + (child: SparkPlan) => Sort('a.asc :: Nil, global = true, child), + sortAnswers = false) + } + } + // Test sorting on different data types for ( dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 7c591f6143..ef827b0fe9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -69,7 +69,8 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers { taskAttemptId = Random.nextInt(10000), attemptNumber = 0, taskMemoryManager = taskMemoryManager, - metricsSystem = null)) + metricsSystem = null, + internalAccumulators = Seq.empty)) try { f diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 0282b25b9d..601a5a07ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -76,7 +76,8 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite { taskAttemptId = 98456, attemptNumber = 0, 
taskMemoryManager = taskMemMgr, - metricsSystem = null)) + metricsSystem = null, + internalAccumulators = Seq.empty)) // Create the data converters val kExternalConverter = CatalystTypeConverters.createToCatalystConverter(keySchema) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala new file mode 100644 index 0000000000..0554e11d25 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -0,0 +1,94 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +// TODO: uncomment the test here! It is currently failing due to +// bad interaction with org.apache.spark.sql.test.TestSQLContext. + +// scalastyle:off +//package org.apache.spark.sql.execution.joins +// +//import scala.reflect.ClassTag +// +//import org.scalatest.BeforeAndAfterAll +// +//import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext} +//import org.apache.spark.sql.functions._ +//import org.apache.spark.sql.{SQLConf, SQLContext, QueryTest} +// +///** +// * Test various broadcast join operators with unsafe enabled. +// * +// * This needs to be its own suite because [[org.apache.spark.sql.test.TestSQLContext]] runs +// * in local mode, but for tests in this suite we need to run Spark in local-cluster mode. +// * In particular, the use of [[org.apache.spark.unsafe.map.BytesToBytesMap]] in +// * [[org.apache.spark.sql.execution.joins.UnsafeHashedRelation]] is not triggered without +// * serializing the hashed relation, which does not happen in local mode. +// */ +//class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { +// private var sc: SparkContext = null +// private var sqlContext: SQLContext = null +// +// /** +// * Create a new [[SQLContext]] running in local-cluster mode with unsafe and codegen enabled. +// */ +// override def beforeAll(): Unit = { +// super.beforeAll() +// val conf = new SparkConf() +// .setMaster("local-cluster[2,1,1024]") +// .setAppName("testing") +// sc = new SparkContext(conf) +// sqlContext = new SQLContext(sc) +// sqlContext.setConf(SQLConf.UNSAFE_ENABLED, true) +// sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) +// } +// +// override def afterAll(): Unit = { +// sc.stop() +// sc = null +// sqlContext = null +// } +// +// /** +// * Test whether the specified broadcast join updates the peak execution memory accumulator. 
+// */ +// private def testBroadcastJoin[T: ClassTag](name: String, joinType: String): Unit = { +// AccumulatorSuite.verifyPeakExecutionMemorySet(sc, name) { +// val df1 = sqlContext.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") +// val df2 = sqlContext.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") +// // Comparison at the end is for broadcast left semi join +// val joinExpression = df1("key") === df2("key") && df1("value") > df2("value") +// val df3 = df1.join(broadcast(df2), joinExpression, joinType) +// val plan = df3.queryExecution.executedPlan +// assert(plan.collect { case p: T => p }.size === 1) +// plan.executeCollect() +// } +// } +// +// test("unsafe broadcast hash join updates peak execution memory") { +// testBroadcastJoin[BroadcastHashJoin]("unsafe broadcast hash join", "inner") +// } +// +// test("unsafe broadcast hash outer join updates peak execution memory") { +// testBroadcastJoin[BroadcastHashOuterJoin]("unsafe broadcast hash outer join", "left_outer") +// } +// +// test("unsafe broadcast left semi join updates peak execution memory") { +// testBroadcastJoin[BroadcastLeftSemiJoinHash]("unsafe broadcast left semi join", "leftsemi") +// } +// +//} +// scalastyle:on
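
Editor's sketch (not part of the patch): the tests above all exercise the same peak-memory "ratchet" pattern that this diff threads through the unsafe sorters and maps: a consumer reports its current usage, the recorded peak only ever moves upward, and freeing memory (spilling) records the peak first so the reported value never drops. The snippet below is a minimal, self-contained Scala illustration of that pattern under assumed names (PeakMemorySketch, PagedConsumer are hypothetical and do not exist in Spark); it is not Spark's implementation, only an analogue of updatePeakMemoryUsed()/getPeakMemoryUsedBytes().

object PeakMemorySketch {

  /** Hypothetical consumer that allocates fixed-size pages, loosely analogous to the
    * page-based allocation in UnsafeShuffleExternalSorter and BytesToBytesMap. */
  final class PagedConsumer(pageSizeBytes: Long) {
    private var allocatedPages: Long = 0L
    private var peakBytes: Long = 0L

    private def currentBytes: Long = allocatedPages * pageSizeBytes

    // Analogue of updatePeakMemoryUsed(): the recorded peak only ever increases.
    private def updatePeak(): Unit = {
      if (currentBytes > peakBytes) {
        peakBytes = currentBytes
      }
    }

    def allocatePage(): Unit = {
      allocatedPages += 1
      updatePeak()
    }

    // Freeing memory records the peak first, so spilling never lowers the reported value.
    def spill(): Unit = {
      updatePeak()
      allocatedPages = 0
    }

    def getPeakMemoryUsedBytes: Long = {
      updatePeak()
      peakBytes
    }
  }

  def main(args: Array[String]): Unit = {
    val consumer = new PagedConsumer(pageSizeBytes = 256)
    consumer.allocatePage()
    consumer.allocatePage()
    assert(consumer.getPeakMemoryUsedBytes == 512)

    consumer.spill()
    consumer.allocatePage()
    // Re-filling a single page after the spill stays below the old peak, so the
    // reported value is unchanged -- the property the new testPeakMemoryUsed
    // tests in this patch assert.
    assert(consumer.getPeakMemoryUsedBytes == 512)
    println(s"peak = ${consumer.getPeakMemoryUsedBytes} bytes")
  }
}

In the patch itself, each operator then publishes this per-task peak by adding it to the PEAK_EXECUTION_MEMORY internal accumulator obtained from TaskContext (as in UnsafeShuffleWriter.stop(), GeneratedAggregate, TungstenSort, ExternalSort, and the broadcast joins above), and AccumulatorSuite.verifyPeakExecutionMemorySet confirms the value surfaces in the completed stage's accumulables.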