[SPARK-11423] remove MapPartitionsWithPreparationRDD
Since we do not need to preserve a page before calling compute(), MapPartitionsWithPreparationRDD is not needed anymore. This PR basically reverts #8543, #8511, #8038, and #8011. Author: Davies Liu <davies@databricks.com> Closes #9381 from davies/remove_prepare2.
This commit is contained in:
parent
bb5a2af034
commit
45029bfdea
|
@ -1,66 +0,0 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.rdd
|
||||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
import scala.reflect.ClassTag
|
||||
|
||||
import org.apache.spark.{Partition, Partitioner, TaskContext}
|
||||
|
||||
/**
|
||||
* An RDD that applies a user provided function to every partition of the parent RDD, and
|
||||
* additionally allows the user to prepare each partition before computing the parent partition.
|
||||
*/
|
||||
private[spark] class MapPartitionsWithPreparationRDD[U: ClassTag, T: ClassTag, M: ClassTag](
|
||||
// The parent RDD whose partition iterators are handed to `executePartition`.
prev: RDD[T],
|
||||
// Produces the prepared value of type M; called from prepare(), or from compute() as a fallback.
preparePartition: () => M,
|
||||
// Given the task context, partition index, prepared value and parent iterator, yields the output.
executePartition: (TaskContext, Int, M, Iterator[T]) => Iterator[U],
|
||||
// When true this RDD reports the parent's partitioner; when false it reports None.
preservesPartitioning: Boolean = false)
|
||||
extends RDD[U](prev) {
|
||||
|
||||
override val partitioner: Option[Partitioner] = {
|
||||
if (preservesPartitioning) firstParent[T].partitioner else None
|
||||
}
|
||||
|
||||
// This RDD has exactly the partitions of its parent.
override def getPartitions: Array[Partition] = firstParent[T].partitions
|
||||
|
||||
// In certain join operations, prepare can be called on the same partition multiple times.
|
||||
// In this case, we need to ensure that each call to compute gets a separate prepare argument.
|
||||
// Values are appended by prepare() and removed from the front by compute(), i.e. FIFO order.
private[this] val preparedArguments: ArrayBuffer[M] = new ArrayBuffer[M]
|
||||
|
||||
/**
|
||||
* Prepare a partition for a single call to compute.
|
||||
*/
|
||||
def prepare(): Unit = {
|
||||
preparedArguments += preparePartition()
|
||||
}
|
||||
|
||||
/**
|
||||
* Prepare a partition before computing it from its parent.
|
||||
*/
|
||||
override def compute(partition: Partition, context: TaskContext): Iterator[U] = {
|
||||
// Take the oldest value queued by prepare() if there is one; otherwise prepare right here.
val prepared =
|
||||
if (preparedArguments.isEmpty) {
|
||||
preparePartition()
|
||||
} else {
|
||||
preparedArguments.remove(0)
|
||||
}
|
||||
// The parent iterator is requested only after the prepared value has been obtained.
val parentIterator = firstParent[T].iterator(partition, context)
|
||||
executePartition(context, partition.index, prepared, parentIterator)
|
||||
}
|
||||
}
|
|
@ -73,16 +73,6 @@ private[spark] abstract class ZippedPartitionsBaseRDD[V: ClassTag](
|
|||
super.clearDependencies()
|
||||
rdds = null
|
||||
}
|
||||
|
||||
/**
|
||||
* Call the prepare method of every parent that has one.
|
||||
* This is needed for reserving execution memory in advance.
|
||||
*/
|
||||
protected def tryPrepareParents(): Unit = {
|
||||
// `collect` with a partial function: only parents that are MapPartitionsWithPreparationRDD
// instances match and have prepare() called on them; every other parent is left untouched.
rdds.collect {
|
||||
case rdd: MapPartitionsWithPreparationRDD[_, _, _] => rdd.prepare()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private[spark] class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag](
|
||||
|
@ -94,7 +84,6 @@ private[spark] class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag]
|
|||
extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2), preservesPartitioning) {
|
||||
|
||||
override def compute(s: Partition, context: TaskContext): Iterator[V] = {
|
||||
// Runs before either parent iterator below is requested.
tryPrepareParents()
|
||||
// The split is cast to ZippedPartitionsPartition to read its `partitions` sequence.
val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
|
||||
// partitions(0) is used with rdd1 and partitions(1) with rdd2; `f` combines the two iterators.
f(rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context))
|
||||
}
|
||||
|
@ -118,7 +107,6 @@ private[spark] class ZippedPartitionsRDD3
|
|||
extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3), preservesPartitioning) {
|
||||
|
||||
override def compute(s: Partition, context: TaskContext): Iterator[V] = {
|
||||
tryPrepareParents()
|
||||
val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
|
||||
f(rdd1.iterator(partitions(0), context),
|
||||
rdd2.iterator(partitions(1), context),
|
||||
|
@ -146,7 +134,6 @@ private[spark] class ZippedPartitionsRDD4
|
|||
extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4), preservesPartitioning) {
|
||||
|
||||
override def compute(s: Partition, context: TaskContext): Iterator[V] = {
|
||||
tryPrepareParents()
|
||||
val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
|
||||
f(rdd1.iterator(partitions(0), context),
|
||||
rdd2.iterator(partitions(1), context),
|
||||
|
|
|
@ -1,66 +0,0 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.rdd
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite, TaskContext}
|
||||
|
||||
// Checks the order in which the prepare step, the parent computation and the execute step run,
// by having each of them append a distinct number to the shared TestObject.things buffer.
class MapPartitionsWithPreparationRDDSuite extends SparkFunSuite with LocalSparkContext {
|
||||
|
||||
test("prepare called before parent partition is computed") {
|
||||
sc = new SparkContext("local", "test")
|
||||
|
||||
// Have the parent partition push a number to the list
|
||||
val parent = sc.parallelize(1 to 100, 1).mapPartitions { iter =>
|
||||
TestObject.things.append(20)
|
||||
iter
|
||||
}
|
||||
|
||||
// Push a different number during the prepare phase
|
||||
val preparePartition = () => { TestObject.things.append(10) }
|
||||
|
||||
// Push yet another number during the execution phase
|
||||
val executePartition = (
|
||||
taskContext: TaskContext,
|
||||
partitionIndex: Int,
|
||||
notUsed: Unit,
|
||||
parentIterator: Iterator[Int]) => {
|
||||
TestObject.things.append(30)
|
||||
// The returned iterator walks the shared buffer itself, so the collected result is the
// sequence of recorded markers rather than the parent's data.
TestObject.things.iterator
|
||||
}
|
||||
|
||||
// Verify that the numbers are pushed in the order expected
|
||||
val rdd = new MapPartitionsWithPreparationRDD[Int, Int, Unit](
|
||||
parent, preparePartition, executePartition)
|
||||
val result = rdd.collect()
|
||||
// 10 = prepare, 20 = parent partition, 30 = execute.
assert(result === Array(10, 20, 30))
|
||||
|
||||
// Reset the recorded markers before the second scenario.
TestObject.things.clear()
|
||||
// Zip two of these RDDs, both should be prepared before the parent is executed
|
||||
val rdd2 = new MapPartitionsWithPreparationRDD[Int, Int, Unit](
|
||||
parent, preparePartition, executePartition)
|
||||
val result2 = rdd.zipPartitions(rdd2)((a, b) => a).collect()
|
||||
// Expected order: both prepares (10, 10) first, then a parent/execute pair (20, 30) for each
// of the two zipped RDDs.
assert(result2 === Array(10, 10, 20, 30, 20, 30))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Holder for the mutable buffer that MapPartitionsWithPreparationRDDSuite's closures append
// to and that the suite reads and clears.
private object TestObject {
|
||||
val things = new mutable.ListBuffer[Int]
|
||||
}
|
|
@ -107,7 +107,11 @@ object MimaExcludes {
|
|||
"org.apache.spark.sql.SQLContext.createSession")
|
||||
) ++ Seq(
|
||||
ProblemFilters.exclude[MissingMethodProblem](
|
||||
"org.apache.spark.SparkContext.preferredNodeLocationData_=")
|
||||
"org.apache.spark.SparkContext.preferredNodeLocationData_="),
|
||||
ProblemFilters.exclude[MissingClassProblem](
|
||||
"org.apache.spark.rdd.MapPartitionsWithPreparationRDD"),
|
||||
ProblemFilters.exclude[MissingClassProblem](
|
||||
"org.apache.spark.rdd.MapPartitionsWithPreparationRDD$")
|
||||
)
|
||||
case v if v.startsWith("1.5") =>
|
||||
Seq(
|
||||
|
|
|
@ -19,9 +19,8 @@ package org.apache.spark.sql.execution;
|
|||
|
||||
import java.io.IOException;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
|
||||
import org.apache.spark.SparkEnv;
|
||||
import org.apache.spark.memory.TaskMemoryManager;
|
||||
import org.apache.spark.sql.catalyst.InternalRow;
|
||||
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
|
||||
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
|
||||
|
@ -31,7 +30,6 @@ import org.apache.spark.unsafe.KVIterator;
|
|||
import org.apache.spark.unsafe.Platform;
|
||||
import org.apache.spark.unsafe.map.BytesToBytesMap;
|
||||
import org.apache.spark.unsafe.memory.MemoryLocation;
|
||||
import org.apache.spark.memory.TaskMemoryManager;
|
||||
|
||||
/**
|
||||
* Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width.
|
||||
|
@ -218,11 +216,6 @@ public final class UnsafeFixedWidthAggregationMap {
|
|||
return map.getPeakMemoryUsedBytes();
|
||||
}
|
||||
|
||||
/**
* Returns the number of data pages reported by the underlying map. Marked as visible for testing.
*/
@VisibleForTesting
|
||||
public int getNumDataPages() {
|
||||
return map.getNumDataPages();
|
||||
}
|
||||
|
||||
/**
|
||||
* Free the memory associated with this map. This is idempotent and can be called multiple times.
|
||||
*/
|
||||
|
|
|
@ -17,15 +17,14 @@
|
|||
|
||||
package org.apache.spark.sql.execution.aggregate
|
||||
|
||||
import org.apache.spark.TaskContext
|
||||
import org.apache.spark.rdd.{MapPartitionsWithPreparationRDD, RDD}
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.errors._
|
||||
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
|
||||
import org.apache.spark.sql.catalyst.plans.physical._
|
||||
import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, UnaryNode, SparkPlan}
|
||||
import org.apache.spark.sql.execution.metric.SQLMetrics
|
||||
import org.apache.spark.sql.execution.{SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap}
|
||||
import org.apache.spark.sql.types.StructType
|
||||
|
||||
case class TungstenAggregate(
|
||||
|
@ -84,59 +83,39 @@ case class TungstenAggregate(
|
|||
val dataSize = longMetric("dataSize")
|
||||
val spillSize = longMetric("spillSize")
|
||||
|
||||
/**
|
||||
* Set up the underlying unsafe data structures used before computing the parent partition.
|
||||
* This makes sure our iterator is not starved by other operators in the same task.
|
||||
*/
|
||||
def preparePartition(): TungstenAggregationIterator = {
|
||||
new TungstenAggregationIterator(
|
||||
groupingExpressions,
|
||||
nonCompleteAggregateExpressions,
|
||||
nonCompleteAggregateAttributes,
|
||||
completeAggregateExpressions,
|
||||
completeAggregateAttributes,
|
||||
initialInputBufferOffset,
|
||||
resultExpressions,
|
||||
newMutableProjection,
|
||||
child.output,
|
||||
testFallbackStartsAt,
|
||||
numInputRows,
|
||||
numOutputRows,
|
||||
dataSize,
|
||||
spillSize)
|
||||
}
|
||||
child.execute().mapPartitions { iter =>
|
||||
|
||||
/** Compute a partition using the iterator already set up previously. */
|
||||
def executePartition(
|
||||
context: TaskContext,
|
||||
partitionIndex: Int,
|
||||
aggregationIterator: TungstenAggregationIterator,
|
||||
parentIterator: Iterator[InternalRow]): Iterator[UnsafeRow] = {
|
||||
val hasInput = parentIterator.hasNext
|
||||
if (!hasInput) {
|
||||
// We're not using the underlying map, so we just can free it here
|
||||
aggregationIterator.free()
|
||||
if (groupingExpressions.isEmpty) {
|
||||
val hasInput = iter.hasNext
|
||||
if (!hasInput && groupingExpressions.nonEmpty) {
|
||||
// This is a grouped aggregate and the input iterator is empty,
|
||||
// so return an empty iterator.
|
||||
Iterator.empty
|
||||
} else {
|
||||
val aggregationIterator =
|
||||
new TungstenAggregationIterator(
|
||||
groupingExpressions,
|
||||
nonCompleteAggregateExpressions,
|
||||
nonCompleteAggregateAttributes,
|
||||
completeAggregateExpressions,
|
||||
completeAggregateAttributes,
|
||||
initialInputBufferOffset,
|
||||
resultExpressions,
|
||||
newMutableProjection,
|
||||
child.output,
|
||||
iter,
|
||||
testFallbackStartsAt,
|
||||
numInputRows,
|
||||
numOutputRows,
|
||||
dataSize,
|
||||
spillSize)
|
||||
if (!hasInput && groupingExpressions.isEmpty) {
|
||||
numOutputRows += 1
|
||||
Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput())
|
||||
} else {
|
||||
// This is a grouped aggregate and the input iterator is empty,
|
||||
// so return an empty iterator.
|
||||
Iterator.empty
|
||||
aggregationIterator
|
||||
}
|
||||
} else {
|
||||
aggregationIterator.start(parentIterator)
|
||||
aggregationIterator
|
||||
}
|
||||
}
|
||||
|
||||
// Note: we need to set up the iterator in each partition before computing the
|
||||
// parent partition, so we cannot simply use `mapPartitions` here (SPARK-9747).
|
||||
val resultRdd = {
|
||||
new MapPartitionsWithPreparationRDD[UnsafeRow, InternalRow, TungstenAggregationIterator](
|
||||
child.execute(), preparePartition, executePartition, preservesPartitioning = true)
|
||||
}
|
||||
resultRdd.asInstanceOf[RDD[InternalRow]]
|
||||
}
|
||||
|
||||
override def simpleString: String = {
|
||||
|
|
|
@ -74,6 +74,8 @@ import org.apache.spark.sql.types.StructType
|
|||
* the function used to create mutable projections.
|
||||
* @param originalInputAttributes
|
||||
* attributes representing the input rows from `inputIter`.
|
||||
* @param inputIter
|
||||
* the iterator containing input [[UnsafeRow]]s.
|
||||
*/
|
||||
class TungstenAggregationIterator(
|
||||
groupingExpressions: Seq[NamedExpression],
|
||||
|
@ -85,6 +87,7 @@ class TungstenAggregationIterator(
|
|||
resultExpressions: Seq[NamedExpression],
|
||||
newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
|
||||
originalInputAttributes: Seq[Attribute],
|
||||
inputIter: Iterator[InternalRow],
|
||||
testFallbackStartsAt: Option[Int],
|
||||
numInputRows: LongSQLMetric,
|
||||
numOutputRows: LongSQLMetric,
|
||||
|
@ -92,9 +95,6 @@ class TungstenAggregationIterator(
|
|||
spillSize: LongSQLMetric)
|
||||
extends Iterator[UnsafeRow] with Logging {
|
||||
|
||||
// The parent partition iterator, to be initialized later in `start`
|
||||
private[this] var inputIter: Iterator[InternalRow] = null
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
// Part 1: Initializing aggregate functions.
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
|
@ -486,15 +486,11 @@ class TungstenAggregationIterator(
|
|||
false // disable tracking of performance metrics
|
||||
)
|
||||
|
||||
// Exposed for testing
|
||||
private[aggregate] def getHashMap: UnsafeFixedWidthAggregationMap = hashMap
|
||||
|
||||
// The function used to read and process input rows. When processing input rows,
|
||||
// it first uses hash-based aggregation by putting groups and their buffers in
|
||||
// hashMap. If we could not allocate more memory for the map, we switch to
|
||||
// sort-based aggregation (by calling switchToSortBasedAggregation).
|
||||
private def processInputs(): Unit = {
|
||||
assert(inputIter != null, "attempted to process input when iterator was null")
|
||||
if (groupingExpressions.isEmpty) {
|
||||
// If there is no grouping expressions, we can just reuse the same buffer over and over again.
|
||||
// Note that it would be better to eliminate the hash map entirely in the future.
|
||||
|
@ -526,7 +522,6 @@ class TungstenAggregationIterator(
|
|||
// that it switch to sort-based aggregation after `fallbackStartsAt` input rows have
|
||||
// been processed.
|
||||
private def processInputsWithControlledFallback(fallbackStartsAt: Int): Unit = {
|
||||
assert(inputIter != null, "attempted to process input when iterator was null")
|
||||
var i = 0
|
||||
while (!sortBased && inputIter.hasNext) {
|
||||
val newInput = inputIter.next()
|
||||
|
@ -567,15 +562,11 @@ class TungstenAggregationIterator(
|
|||
* Switch to sort-based aggregation when the hash-based approach is unable to acquire memory.
|
||||
*/
|
||||
private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: InternalRow): Unit = {
|
||||
assert(inputIter != null, "attempted to process input when iterator was null")
|
||||
logInfo("falling back to sort based aggregation.")
|
||||
// Step 1: Get the ExternalSorter containing sorted entries of the map.
|
||||
externalSorter = hashMap.destructAndCreateExternalSorter()
|
||||
|
||||
// Step 2: Free the memory used by the map.
|
||||
hashMap.free()
|
||||
|
||||
// Step 3: If we have aggregate function with mode Partial or Complete,
|
||||
// Step 2: If we have aggregate function with mode Partial or Complete,
|
||||
// we need to process input rows to get aggregation buffer.
|
||||
// So, later in the sort-based aggregation iterator, we can do merge.
|
||||
// If aggregate functions are with mode Final and PartialMerge,
|
||||
|
@ -770,31 +761,27 @@ class TungstenAggregationIterator(
|
|||
|
||||
/**
|
||||
* Start processing input rows.
|
||||
* Only after this method is called will this iterator be non-empty.
|
||||
*/
|
||||
def start(parentIter: Iterator[InternalRow]): Unit = {
|
||||
inputIter = parentIter
|
||||
testFallbackStartsAt match {
|
||||
case None =>
|
||||
processInputs()
|
||||
case Some(fallbackStartsAt) =>
|
||||
// This is the testing path. processInputsWithControlledFallback is same as processInputs
|
||||
// except that it switches to sort-based aggregation after `fallbackStartsAt` input rows
|
||||
// have been processed.
|
||||
processInputsWithControlledFallback(fallbackStartsAt)
|
||||
}
|
||||
testFallbackStartsAt match {
|
||||
case None =>
|
||||
processInputs()
|
||||
case Some(fallbackStartsAt) =>
|
||||
// This is the testing path. processInputsWithControlledFallback is same as processInputs
|
||||
// except that it switches to sort-based aggregation after `fallbackStartsAt` input rows
|
||||
// have been processed.
|
||||
processInputsWithControlledFallback(fallbackStartsAt)
|
||||
}
|
||||
|
||||
// If we did not switch to sort-based aggregation in processInputs,
|
||||
// we pre-load the first key-value pair from the map (to make hasNext idempotent).
|
||||
if (!sortBased) {
|
||||
// First, set aggregationBufferMapIterator.
|
||||
aggregationBufferMapIterator = hashMap.iterator()
|
||||
// Pre-load the first key-value pair from the aggregationBufferMapIterator.
|
||||
mapIteratorHasNext = aggregationBufferMapIterator.next()
|
||||
// If the map is empty, we just free it.
|
||||
if (!mapIteratorHasNext) {
|
||||
hashMap.free()
|
||||
}
|
||||
// If we did not switch to sort-based aggregation in processInputs,
|
||||
// we pre-load the first key-value pair from the map (to make hasNext idempotent).
|
||||
if (!sortBased) {
|
||||
// First, set aggregationBufferMapIterator.
|
||||
aggregationBufferMapIterator = hashMap.iterator()
|
||||
// Pre-load the first key-value pair from the aggregationBufferMapIterator.
|
||||
mapIteratorHasNext = aggregationBufferMapIterator.next()
|
||||
// If the map is empty, we just free it.
|
||||
if (!mapIteratorHasNext) {
|
||||
hashMap.free()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -868,13 +855,16 @@ class TungstenAggregationIterator(
|
|||
* Generate an output row when there is no input and there is no grouping expression.
|
||||
*/
|
||||
def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
|
||||
assert(groupingExpressions.isEmpty)
|
||||
assert(inputIter == null)
|
||||
generateOutput(UnsafeRow.createFromByteArray(0, 0), initialAggregationBuffer)
|
||||
}
|
||||
|
||||
/** Free memory used in the underlying map. */
|
||||
def free(): Unit = {
|
||||
hashMap.free()
|
||||
if (groupingExpressions.isEmpty) {
|
||||
sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer)
|
||||
// We create a output row and copy it. So, we can free the map.
|
||||
val resultCopy =
|
||||
generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer).copy()
|
||||
hashMap.free()
|
||||
resultCopy
|
||||
} else {
|
||||
throw new IllegalStateException(
|
||||
"This method should not be called when groupingExpressions is not empty.")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
|
||||
package org.apache.spark.sql.execution
|
||||
|
||||
import org.apache.spark.rdd.{MapPartitionsWithPreparationRDD, RDD}
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.errors._
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
|
@ -26,7 +26,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
|
|||
import org.apache.spark.sql.types.StructType
|
||||
import org.apache.spark.util.CompletionIterator
|
||||
import org.apache.spark.util.collection.ExternalSorter
|
||||
import org.apache.spark.{SparkEnv, InternalAccumulator, TaskContext}
|
||||
import org.apache.spark.{InternalAccumulator, SparkEnv, TaskContext}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// This file defines various sort operators.
|
||||
|
@ -77,6 +77,7 @@ case class Sort(
|
|||
* @param testSpillFrequency Method for configuring periodic spilling in unit tests. If set, will
|
||||
* spill every `frequency` records.
|
||||
*/
|
||||
|
||||
case class TungstenSort(
|
||||
sortOrder: Seq[SortOrder],
|
||||
global: Boolean,
|
||||
|
@ -106,11 +107,7 @@ case class TungstenSort(
|
|||
val dataSize = longMetric("dataSize")
|
||||
val spillSize = longMetric("spillSize")
|
||||
|
||||
/**
|
||||
* Set up the sorter in each partition before computing the parent partition.
|
||||
* This makes sure our sorter is not starved by other sorters used in the same task.
|
||||
*/
|
||||
def preparePartition(): UnsafeExternalRowSorter = {
|
||||
child.execute().mapPartitions { iter =>
|
||||
val ordering = newOrdering(sortOrder, childOutput)
|
||||
|
||||
// The comparator for comparing prefix
|
||||
|
@ -131,33 +128,20 @@ case class TungstenSort(
|
|||
if (testSpillFrequency > 0) {
|
||||
sorter.setTestSpillFrequency(testSpillFrequency)
|
||||
}
|
||||
sorter
|
||||
}
|
||||
|
||||
/** Compute a partition using the sorter already set up previously. */
|
||||
def executePartition(
|
||||
taskContext: TaskContext,
|
||||
partitionIndex: Int,
|
||||
sorter: UnsafeExternalRowSorter,
|
||||
parentIterator: Iterator[InternalRow]): Iterator[InternalRow] = {
|
||||
// Remember spill data size of this task before execute this operator so that we can
|
||||
// figure out how many bytes we spilled for this operator.
|
||||
val spillSizeBefore = TaskContext.get().taskMetrics().memoryBytesSpilled
|
||||
|
||||
val sortedIterator = sorter.sort(parentIterator.asInstanceOf[Iterator[UnsafeRow]])
|
||||
val sortedIterator = sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]])
|
||||
|
||||
dataSize += sorter.getPeakMemoryUsage
|
||||
spillSize += TaskContext.get().taskMetrics().memoryBytesSpilled - spillSizeBefore
|
||||
|
||||
taskContext.internalMetricsToAccumulators(
|
||||
TaskContext.get().internalMetricsToAccumulators(
|
||||
InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.getPeakMemoryUsage)
|
||||
sortedIterator
|
||||
}
|
||||
|
||||
// Note: we need to set up the external sorter in each partition before computing
|
||||
// the parent partition, so we cannot simply use `mapPartitions` here (SPARK-9709).
|
||||
new MapPartitionsWithPreparationRDD[InternalRow, InternalRow, UnsafeExternalRowSorter](
|
||||
child.execute(), preparePartition, executePartition, preservesPartitioning = true)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue