[SPARK-33494][SQL][AQE] Do not use local shuffle reader for repartition
### What changes were proposed in this pull request? This PR updates `ShuffleExchangeExec` to carry more information about how much we can change the partitioning. For `repartition(col)`, we should preserve the user-specified partitioning and not apply the AQE local shuffle reader. ### Why are the changes needed? Similar to `repartition(number, col)`, we should respect the user-specified partitioning. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? A new test. Closes #30432 from cloud-fan/aqe. Authored-by: Wenchen Fan <wenchen@databricks.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
01321bc0fe
commit
d1b4f06179
|
@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.streaming.{InternalOutputModes, StreamingRe
|
|||
import org.apache.spark.sql.execution.aggregate.AggUtils
|
||||
import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
|
||||
import org.apache.spark.sql.execution.command._
|
||||
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
|
||||
import org.apache.spark.sql.execution.exchange.{REPARTITION, REPARTITION_WITH_NUM, ShuffleExchangeExec}
|
||||
import org.apache.spark.sql.execution.python._
|
||||
import org.apache.spark.sql.execution.streaming._
|
||||
import org.apache.spark.sql.execution.streaming.sources.MemoryPlan
|
||||
|
@ -670,7 +670,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
|
|||
case logical.Repartition(numPartitions, shuffle, child) =>
|
||||
if (shuffle) {
|
||||
ShuffleExchangeExec(RoundRobinPartitioning(numPartitions),
|
||||
planLater(child), noUserSpecifiedNumPartition = false) :: Nil
|
||||
planLater(child), REPARTITION_WITH_NUM) :: Nil
|
||||
} else {
|
||||
execution.CoalesceExec(numPartitions, planLater(child)) :: Nil
|
||||
}
|
||||
|
@ -703,10 +703,12 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
|
|||
case r: logical.Range =>
|
||||
execution.RangeExec(r) :: Nil
|
||||
case r: logical.RepartitionByExpression =>
|
||||
exchange.ShuffleExchangeExec(
|
||||
r.partitioning,
|
||||
planLater(r.child),
|
||||
noUserSpecifiedNumPartition = r.optNumPartitions.isEmpty) :: Nil
|
||||
val shuffleOrigin = if (r.optNumPartitions.isEmpty) {
|
||||
REPARTITION
|
||||
} else {
|
||||
REPARTITION_WITH_NUM
|
||||
}
|
||||
exchange.ShuffleExchangeExec(r.partitioning, planLater(r.child), shuffleOrigin) :: Nil
|
||||
case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil
|
||||
case r: LogicalRDD =>
|
||||
RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil
|
||||
|
|
|
@ -18,8 +18,10 @@
|
|||
package org.apache.spark.sql.execution.adaptive
|
||||
|
||||
import org.apache.spark.sql.SparkSession
|
||||
import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
|
||||
import org.apache.spark.sql.catalyst.rules.Rule
|
||||
import org.apache.spark.sql.execution.SparkPlan
|
||||
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, REPARTITION, ShuffleExchangeLike}
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
|
||||
/**
|
||||
|
@ -47,7 +49,7 @@ case class CoalesceShufflePartitions(session: SparkSession) extends Rule[SparkPl
|
|||
val shuffleStages = collectShuffleStages(plan)
|
||||
// ShuffleExchanges from repartition with a user-specified partition number can't change it.
|
||||
// We change the number of partitions in the stage only if all the ShuffleExchanges support it.
|
||||
if (!shuffleStages.forall(_.shuffle.canChangeNumPartitions)) {
|
||||
if (!shuffleStages.forall(s => supportCoalesce(s.shuffle))) {
|
||||
plan
|
||||
} else {
|
||||
// `ShuffleQueryStageExec#mapStats` returns None when the input RDD has 0 partitions,
|
||||
|
@ -82,4 +84,9 @@ case class CoalesceShufflePartitions(session: SparkSession) extends Rule[SparkPl
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Whether the partitions of the given shuffle may be coalesced: its output partitioning must not
 * be `SinglePartition`, and its origin must be `ENSURE_REQUIREMENTS` or `REPARTITION`. A shuffle
 * whose origin is `REPARTITION_WITH_NUM` matches neither allowed origin and is rejected.
 */
private def supportCoalesce(s: ShuffleExchangeLike): Boolean = {
|
||||
s.outputPartitioning != SinglePartition &&
|
||||
(s.shuffleOrigin == ENSURE_REQUIREMENTS || s.shuffleOrigin == REPARTITION)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,9 +18,10 @@
|
|||
package org.apache.spark.sql.execution.adaptive
|
||||
|
||||
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
|
||||
import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
|
||||
import org.apache.spark.sql.catalyst.rules.Rule
|
||||
import org.apache.spark.sql.execution._
|
||||
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec}
|
||||
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, EnsureRequirements, ShuffleExchangeExec, ShuffleExchangeLike}
|
||||
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
|
||||
|
@ -136,9 +137,13 @@ object OptimizeLocalShuffleReader extends Rule[SparkPlan] {
|
|||
|
||||
/**
 * Whether a local shuffle reader can be used on top of the given plan. Only a
 * `ShuffleQueryStageExec`, either bare or wrapped in a `CustomShuffleReaderExec` with non-empty
 * partition specs, can qualify, and its `mapStats` must be defined; any other plan yields false.
 *
 * NOTE(review): this span is a rendered diff, so each case shows the removed
 * `canChangeNumPartitions` condition followed by its `supportLocalReader` replacement.
 */
def canUseLocalShuffleReader(plan: SparkPlan): Boolean = plan match {
|
||||
case s: ShuffleQueryStageExec =>
|
||||
s.shuffle.canChangeNumPartitions && s.mapStats.isDefined
|
||||
s.mapStats.isDefined && supportLocalReader(s.shuffle)
|
||||
case CustomShuffleReaderExec(s: ShuffleQueryStageExec, partitionSpecs) =>
|
||||
s.shuffle.canChangeNumPartitions && s.mapStats.isDefined && partitionSpecs.nonEmpty
|
||||
s.mapStats.isDefined && partitionSpecs.nonEmpty && supportLocalReader(s.shuffle)
|
||||
case _ => false
|
||||
}
|
||||
|
||||
/**
 * Whether the given shuffle may be read through a local shuffle reader: its output partitioning
 * must not be `SinglePartition` and its origin must be `ENSURE_REQUIREMENTS`. Shuffles with any
 * other origin (e.g. a user-specified repartition) are therefore never read locally.
 */
private def supportLocalReader(s: ShuffleExchangeLike): Boolean = {
|
||||
s.outputPartitioning != SinglePartition && s.shuffleOrigin == ENSURE_REQUIREMENTS
|
||||
}
|
||||
}
|
||||
|
|
|
@ -57,9 +57,9 @@ trait ShuffleExchangeLike extends Exchange {
|
|||
def numPartitions: Int
|
||||
|
||||
/**
|
||||
* Returns whether the shuffle partition number can be changed.
|
||||
* The origin of this shuffle operator.
|
||||
*/
|
||||
def canChangeNumPartitions: Boolean
|
||||
def shuffleOrigin: ShuffleOrigin
|
||||
|
||||
/**
|
||||
* The asynchronous job that materializes the shuffle.
|
||||
|
@ -77,18 +77,30 @@ trait ShuffleExchangeLike extends Exchange {
|
|||
def runtimeStatistics: Statistics
|
||||
}
|
||||
|
||||
// Describes where the shuffle operator comes from.
|
||||
sealed trait ShuffleOrigin
|
||||
|
||||
// Indicates that the shuffle operator was added by the internal `EnsureRequirements` rule. It
|
||||
// means that the shuffle operator is used to ensure internal data partitioning requirements and
|
||||
// Spark is free to optimize it as long as the requirements are still ensured.
|
||||
case object ENSURE_REQUIREMENTS extends ShuffleOrigin
|
||||
|
||||
// Indicates that the shuffle operator was added by the user-specified repartition operator. Spark
|
||||
// can still optimize it via changing shuffle partition number, as data partitioning won't change.
// The local shuffle reader is not applied to it, as that would break the user-specified
// partitioning.
|
||||
case object REPARTITION extends ShuffleOrigin
|
||||
|
||||
// Indicates that the shuffle operator was added by the user-specified repartition operator with
|
||||
// a certain partition number. Spark can't optimize it.
|
||||
case object REPARTITION_WITH_NUM extends ShuffleOrigin
|
||||
|
||||
/**
|
||||
* Performs a shuffle that will result in the desired partitioning.
|
||||
*/
|
||||
case class ShuffleExchangeExec(
|
||||
override val outputPartitioning: Partitioning,
|
||||
child: SparkPlan,
|
||||
noUserSpecifiedNumPartition: Boolean = true) extends ShuffleExchangeLike {
|
||||
|
||||
// If users specify the num partitions via APIs like `repartition`, we shouldn't change it.
|
||||
// For `SinglePartition`, it requires exactly one partition and we can't change it either.
|
||||
override def canChangeNumPartitions: Boolean =
|
||||
noUserSpecifiedNumPartition && outputPartitioning != SinglePartition
|
||||
shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS)
|
||||
extends ShuffleExchangeLike {
|
||||
|
||||
private lazy val writeMetrics =
|
||||
SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
-- Automatically generated by SQLQueryTestSuite
|
||||
-- Number of queries: 23
|
||||
-- Number of queries: 24
|
||||
|
||||
|
||||
-- !query
|
||||
|
@ -67,10 +67,10 @@ Aggregate [sum(distinct cast(val#x as bigint)) AS sum(DISTINCT val)#xL]
|
|||
== Physical Plan ==
|
||||
AdaptiveSparkPlan isFinalPlan=false
|
||||
+- HashAggregate(keys=[], functions=[sum(distinct cast(val#x as bigint)#xL)], output=[sum(DISTINCT val)#xL])
|
||||
+- Exchange SinglePartition, true, [id=#x]
|
||||
+- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
|
||||
+- HashAggregate(keys=[], functions=[partial_sum(distinct cast(val#x as bigint)#xL)], output=[sum#xL])
|
||||
+- HashAggregate(keys=[cast(val#x as bigint)#xL], functions=[], output=[cast(val#x as bigint)#xL])
|
||||
+- Exchange hashpartitioning(cast(val#x as bigint)#xL, 4), true, [id=#x]
|
||||
+- Exchange hashpartitioning(cast(val#x as bigint)#xL, 4), ENSURE_REQUIREMENTS, [id=#x]
|
||||
+- HashAggregate(keys=[cast(val#x as bigint) AS cast(val#x as bigint)#xL], functions=[], output=[cast(val#x as bigint)#xL])
|
||||
+- FileScan parquet default.explain_temp1[val#x] Batched: true, DataFilters: [], Format: Parquet, Location [not included in comparison]/{warehouse_dir}/explain_temp1], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<val:int>
|
||||
|
||||
|
@ -116,7 +116,7 @@ Results [2]: [key#x, max#x]
|
|||
|
||||
(4) Exchange
|
||||
Input [2]: [key#x, max#x]
|
||||
Arguments: hashpartitioning(key#x, 4), true, [id=#x]
|
||||
Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]
|
||||
|
||||
(5) HashAggregate
|
||||
Input [2]: [key#x, max#x]
|
||||
|
@ -127,7 +127,7 @@ Results [2]: [key#x, max(val#x)#x AS max(val)#x]
|
|||
|
||||
(6) Exchange
|
||||
Input [2]: [key#x, max(val)#x]
|
||||
Arguments: rangepartitioning(key#x ASC NULLS FIRST, 4), true, [id=#x]
|
||||
Arguments: rangepartitioning(key#x ASC NULLS FIRST, 4), ENSURE_REQUIREMENTS, [id=#x]
|
||||
|
||||
(7) Sort
|
||||
Input [2]: [key#x, max(val)#x]
|
||||
|
@ -179,7 +179,7 @@ Results [2]: [key#x, max#x]
|
|||
|
||||
(4) Exchange
|
||||
Input [2]: [key#x, max#x]
|
||||
Arguments: hashpartitioning(key#x, 4), true, [id=#x]
|
||||
Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]
|
||||
|
||||
(5) HashAggregate
|
||||
Input [2]: [key#x, max#x]
|
||||
|
@ -254,7 +254,7 @@ Results [2]: [key#x, val#x]
|
|||
|
||||
(7) Exchange
|
||||
Input [2]: [key#x, val#x]
|
||||
Arguments: hashpartitioning(key#x, val#x, 4), true, [id=#x]
|
||||
Arguments: hashpartitioning(key#x, val#x, 4), ENSURE_REQUIREMENTS, [id=#x]
|
||||
|
||||
(8) HashAggregate
|
||||
Input [2]: [key#x, val#x]
|
||||
|
@ -576,7 +576,7 @@ Results [2]: [key#x, max#x]
|
|||
|
||||
(4) Exchange
|
||||
Input [2]: [key#x, max#x]
|
||||
Arguments: hashpartitioning(key#x, 4), true, [id=#x]
|
||||
Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]
|
||||
|
||||
(5) HashAggregate
|
||||
Input [2]: [key#x, max#x]
|
||||
|
@ -605,7 +605,7 @@ Results [2]: [key#x, max#x]
|
|||
|
||||
(9) Exchange
|
||||
Input [2]: [key#x, max#x]
|
||||
Arguments: hashpartitioning(key#x, 4), true, [id=#x]
|
||||
Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]
|
||||
|
||||
(10) HashAggregate
|
||||
Input [2]: [key#x, max#x]
|
||||
|
@ -687,7 +687,7 @@ Results [3]: [count#xL, sum#xL, count#xL]
|
|||
|
||||
(3) Exchange
|
||||
Input [3]: [count#xL, sum#xL, count#xL]
|
||||
Arguments: SinglePartition, true, [id=#x]
|
||||
Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
|
||||
|
||||
(4) HashAggregate
|
||||
Input [3]: [count#xL, sum#xL, count#xL]
|
||||
|
@ -732,7 +732,7 @@ Results [2]: [key#x, buf#x]
|
|||
|
||||
(3) Exchange
|
||||
Input [2]: [key#x, buf#x]
|
||||
Arguments: hashpartitioning(key#x, 4), true, [id=#x]
|
||||
Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]
|
||||
|
||||
(4) ObjectHashAggregate
|
||||
Input [2]: [key#x, buf#x]
|
||||
|
@ -783,7 +783,7 @@ Results [2]: [key#x, min#x]
|
|||
|
||||
(4) Exchange
|
||||
Input [2]: [key#x, min#x]
|
||||
Arguments: hashpartitioning(key#x, 4), true, [id=#x]
|
||||
Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]
|
||||
|
||||
(5) Sort
|
||||
Input [2]: [key#x, min#x]
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
-- Automatically generated by SQLQueryTestSuite
|
||||
-- Number of queries: 23
|
||||
-- Number of queries: 24
|
||||
|
||||
|
||||
-- !query
|
||||
|
@ -66,10 +66,10 @@ Aggregate [sum(distinct cast(val#x as bigint)) AS sum(DISTINCT val)#xL]
|
|||
|
||||
== Physical Plan ==
|
||||
*HashAggregate(keys=[], functions=[sum(distinct cast(val#x as bigint)#xL)], output=[sum(DISTINCT val)#xL])
|
||||
+- Exchange SinglePartition, true, [id=#x]
|
||||
+- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
|
||||
+- *HashAggregate(keys=[], functions=[partial_sum(distinct cast(val#x as bigint)#xL)], output=[sum#xL])
|
||||
+- *HashAggregate(keys=[cast(val#x as bigint)#xL], functions=[], output=[cast(val#x as bigint)#xL])
|
||||
+- Exchange hashpartitioning(cast(val#x as bigint)#xL, 4), true, [id=#x]
|
||||
+- Exchange hashpartitioning(cast(val#x as bigint)#xL, 4), ENSURE_REQUIREMENTS, [id=#x]
|
||||
+- *HashAggregate(keys=[cast(val#x as bigint) AS cast(val#x as bigint)#xL], functions=[], output=[cast(val#x as bigint)#xL])
|
||||
+- *ColumnarToRow
|
||||
+- FileScan parquet default.explain_temp1[val#x] Batched: true, DataFilters: [], Format: Parquet, Location [not included in comparison]/{warehouse_dir}/explain_temp1], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<val:int>
|
||||
|
@ -119,7 +119,7 @@ Results [2]: [key#x, max#x]
|
|||
|
||||
(5) Exchange
|
||||
Input [2]: [key#x, max#x]
|
||||
Arguments: hashpartitioning(key#x, 4), true, [id=#x]
|
||||
Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]
|
||||
|
||||
(6) HashAggregate [codegen id : 2]
|
||||
Input [2]: [key#x, max#x]
|
||||
|
@ -130,7 +130,7 @@ Results [2]: [key#x, max(val#x)#x AS max(val)#x]
|
|||
|
||||
(7) Exchange
|
||||
Input [2]: [key#x, max(val)#x]
|
||||
Arguments: rangepartitioning(key#x ASC NULLS FIRST, 4), true, [id=#x]
|
||||
Arguments: rangepartitioning(key#x ASC NULLS FIRST, 4), ENSURE_REQUIREMENTS, [id=#x]
|
||||
|
||||
(8) Sort [codegen id : 3]
|
||||
Input [2]: [key#x, max(val)#x]
|
||||
|
@ -181,7 +181,7 @@ Results [2]: [key#x, max#x]
|
|||
|
||||
(5) Exchange
|
||||
Input [2]: [key#x, max#x]
|
||||
Arguments: hashpartitioning(key#x, 4), true, [id=#x]
|
||||
Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]
|
||||
|
||||
(6) HashAggregate [codegen id : 2]
|
||||
Input [2]: [key#x, max#x]
|
||||
|
@ -259,7 +259,7 @@ Results [2]: [key#x, val#x]
|
|||
|
||||
(9) Exchange
|
||||
Input [2]: [key#x, val#x]
|
||||
Arguments: hashpartitioning(key#x, val#x, 4), true, [id=#x]
|
||||
Arguments: hashpartitioning(key#x, val#x, 4), ENSURE_REQUIREMENTS, [id=#x]
|
||||
|
||||
(10) HashAggregate [codegen id : 4]
|
||||
Input [2]: [key#x, val#x]
|
||||
|
@ -452,7 +452,7 @@ Results [1]: [max#x]
|
|||
|
||||
(9) Exchange
|
||||
Input [1]: [max#x]
|
||||
Arguments: SinglePartition, true, [id=#x]
|
||||
Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
|
||||
|
||||
(10) HashAggregate [codegen id : 2]
|
||||
Input [1]: [max#x]
|
||||
|
@ -498,7 +498,7 @@ Results [1]: [max#x]
|
|||
|
||||
(16) Exchange
|
||||
Input [1]: [max#x]
|
||||
Arguments: SinglePartition, true, [id=#x]
|
||||
Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
|
||||
|
||||
(17) HashAggregate [codegen id : 2]
|
||||
Input [1]: [max#x]
|
||||
|
@ -580,7 +580,7 @@ Results [1]: [max#x]
|
|||
|
||||
(9) Exchange
|
||||
Input [1]: [max#x]
|
||||
Arguments: SinglePartition, true, [id=#x]
|
||||
Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
|
||||
|
||||
(10) HashAggregate [codegen id : 2]
|
||||
Input [1]: [max#x]
|
||||
|
@ -626,7 +626,7 @@ Results [2]: [sum#x, count#xL]
|
|||
|
||||
(16) Exchange
|
||||
Input [2]: [sum#x, count#xL]
|
||||
Arguments: SinglePartition, true, [id=#x]
|
||||
Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
|
||||
|
||||
(17) HashAggregate [codegen id : 2]
|
||||
Input [2]: [sum#x, count#xL]
|
||||
|
@ -690,7 +690,7 @@ Results [2]: [sum#x, count#xL]
|
|||
|
||||
(7) Exchange
|
||||
Input [2]: [sum#x, count#xL]
|
||||
Arguments: SinglePartition, true, [id=#x]
|
||||
Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
|
||||
|
||||
(8) HashAggregate [codegen id : 2]
|
||||
Input [2]: [sum#x, count#xL]
|
||||
|
@ -810,7 +810,7 @@ Results [2]: [key#x, max#x]
|
|||
|
||||
(5) Exchange
|
||||
Input [2]: [key#x, max#x]
|
||||
Arguments: hashpartitioning(key#x, 4), true, [id=#x]
|
||||
Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]
|
||||
|
||||
(6) HashAggregate [codegen id : 4]
|
||||
Input [2]: [key#x, max#x]
|
||||
|
@ -901,7 +901,7 @@ Results [3]: [count#xL, sum#xL, count#xL]
|
|||
|
||||
(4) Exchange
|
||||
Input [3]: [count#xL, sum#xL, count#xL]
|
||||
Arguments: SinglePartition, true, [id=#x]
|
||||
Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
|
||||
|
||||
(5) HashAggregate [codegen id : 2]
|
||||
Input [3]: [count#xL, sum#xL, count#xL]
|
||||
|
@ -945,7 +945,7 @@ Results [2]: [key#x, buf#x]
|
|||
|
||||
(4) Exchange
|
||||
Input [2]: [key#x, buf#x]
|
||||
Arguments: hashpartitioning(key#x, 4), true, [id=#x]
|
||||
Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]
|
||||
|
||||
(5) ObjectHashAggregate
|
||||
Input [2]: [key#x, buf#x]
|
||||
|
@ -995,7 +995,7 @@ Results [2]: [key#x, min#x]
|
|||
|
||||
(5) Exchange
|
||||
Input [2]: [key#x, min#x]
|
||||
Arguments: hashpartitioning(key#x, 4), true, [id=#x]
|
||||
Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]
|
||||
|
||||
(6) Sort [codegen id : 2]
|
||||
Input [2]: [key#x, min#x]
|
||||
|
|
|
@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
|
|||
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
|
||||
import org.apache.spark.sql.execution._
|
||||
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
|
||||
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ShuffleExchangeExec, ShuffleExchangeLike}
|
||||
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin}
|
||||
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.internal.SQLConf.COLUMN_BATCH_SIZE
|
||||
|
@ -766,7 +766,9 @@ case class PreRuleReplaceAddWithBrokenVersion() extends Rule[SparkPlan] {
|
|||
case class MyShuffleExchangeExec(delegate: ShuffleExchangeExec) extends ShuffleExchangeLike {
|
||||
override def numMappers: Int = delegate.numMappers
|
||||
override def numPartitions: Int = delegate.numPartitions
|
||||
override def canChangeNumPartitions: Boolean = delegate.canChangeNumPartitions
|
||||
// Forwards the shuffle origin of the wrapped `ShuffleExchangeExec` delegate unchanged.
override def shuffleOrigin: ShuffleOrigin = {
|
||||
delegate.shuffleOrigin
|
||||
}
|
||||
override def mapOutputStatisticsFuture: Future[MapOutputStatistics] =
|
||||
delegate.mapOutputStatisticsFuture
|
||||
override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[_] =
|
||||
|
|
|
@ -1307,4 +1307,14 @@ class AdaptiveQueryExecSuite
|
|||
spark.listenerManager.unregister(listener)
|
||||
}
|
||||
}
|
||||
|
||||
// Regression test for SPARK-33494: with AQE enabled, the shuffle introduced by a user-specified
// `repartition` on a column must be left without a local shuffle reader.
test("SPARK-33494: Do not use local shuffle reader for repartition") {
|
||||
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
|
||||
val df = spark.table("testData").repartition('key)
|
||||
// Run the query before inspecting the plan; presumably this is what lets AQE produce the
// final `executedPlan` -- TODO confirm.
df.collect()
|
||||
// local shuffle reader breaks partitioning and shouldn't be used for repartition operation
|
||||
// which is specified by users.
|
||||
checkNumLocalShuffleReaders(df.queryExecution.executedPlan, numShufflesWithoutLocalReader = 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue