[SPARK-35794][SQL] Allow custom plugin for AQE cost evaluator
### What changes were proposed in this pull request?
Current AQE has a cost evaluator to decide whether to use the new plan after replanning. The evaluator currently used is `SimpleCostEvaluator`, which makes the decision based on the number of shuffles in the query plan. This is not a perfect cost evaluator, and different production environments might want to use different custom evaluators. E.g., sometimes we might want to still do a skew join even though it might introduce an extra shuffle (trading off resources for better latency), and sometimes we might want to take sort into consideration for cost as well. Take our own setting as an example: we are using a custom remote shuffle service (Cosco), and the cost model is more complicated. So we want to make the cost evaluator pluggable, so that developers can implement their own `CostEvaluator` subclass and plug it in dynamically based on configuration.
The approach is to introduce a new config to allow defining the subclass name of `CostEvaluator` - `spark.sql.adaptive.customCostEvaluatorClass` - and to add `CostEvaluator.instantiate` to instantiate the cost evaluator class in `AdaptiveSparkPlanExec.costEvaluator`.
### Why are the changes needed?
Make AQE cost evaluation more flexible.
### Does this PR introduce _any_ user-facing change?
No, but an internal config is introduced - `spark.sql.adaptive.customCostEvaluatorClass` - to allow a custom implementation of `CostEvaluator`.
### How was this patch tested?
Added unit test in `AdaptiveQueryExecSuite.scala`.
Closes #32944 from c21/aqe-cost.
Authored-by: Cheng Su <chengsu@fb.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit 044dddf288)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
26bcf02833
commit
39b3a04bfe
|
@ -678,6 +678,14 @@ object SQLConf {
|
|||
.booleanConf
|
||||
.createWithDefault(true)
|
||||
|
||||
val ADAPTIVE_CUSTOM_COST_EVALUATOR_CLASS =
|
||||
buildConf("spark.sql.adaptive.customCostEvaluatorClass")
|
||||
.doc("The custom cost evaluator class to be used for adaptive execution. If not being set," +
|
||||
" Spark will use its own SimpleCostEvaluator by default.")
|
||||
.version("3.2.0")
|
||||
.stringConf
|
||||
.createOptional
|
||||
|
||||
val SUBEXPRESSION_ELIMINATION_ENABLED =
|
||||
buildConf("spark.sql.subexpressionElimination.enabled")
|
||||
.internal()
|
||||
|
|
|
@ -130,7 +130,11 @@ case class AdaptiveSparkPlanExec(
|
|||
}
|
||||
}
|
||||
|
||||
@transient private val costEvaluator = SimpleCostEvaluator
|
||||
@transient private val costEvaluator =
|
||||
conf.getConf(SQLConf.ADAPTIVE_CUSTOM_COST_EVALUATOR_CLASS) match {
|
||||
case Some(className) => CostEvaluator.instantiate(className, session.sparkContext.getConf)
|
||||
case _ => SimpleCostEvaluator
|
||||
}
|
||||
|
||||
@transient val initialPlan = context.session.withActive {
|
||||
applyPhysicalRules(
|
||||
|
|
|
@ -17,16 +17,42 @@
|
|||
|
||||
package org.apache.spark.sql.execution.adaptive
|
||||
|
||||
import org.apache.spark.SparkConf
|
||||
import org.apache.spark.annotation.Unstable
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.sql.execution.SparkPlan
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
/**
 * An interface to represent the cost of a plan.
 *
 * @note This class is subject to be changed and/or moved in the near future.
 */
@Unstable
trait Cost extends Ordered[Cost]

/**
 * An interface to evaluate the cost of a physical plan.
 *
 * @note This class is subject to be changed and/or moved in the near future.
 */
@Unstable
trait CostEvaluator {

  /** Computes the [[Cost]] of the given physical plan. */
  def evaluateCost(plan: SparkPlan): Cost
}

object CostEvaluator extends Logging {

  /**
   * Instantiates a [[CostEvaluator]] using the given className.
   *
   * Loads the class through `Utils.loadExtensions` and fails with an
   * `IllegalArgumentException` (via `require`) when the class cannot be used as a
   * [[CostEvaluator]].
   */
  def instantiate(className: String, conf: SparkConf): CostEvaluator = {
    logDebug(s"Creating CostEvaluator $className")
    val loaded = Utils.loadExtensions(classOf[CostEvaluator], Seq(className), conf)
    require(loaded.nonEmpty, "A valid AQE cost evaluator must be specified by config " +
      s"${SQLConf.ADAPTIVE_CUSTOM_COST_EVALUATOR_CLASS.key}, but $className resulted in zero " +
      "valid evaluator.")
    loaded.head
  }
}
|
||||
|
|
|
@ -1898,4 +1898,54 @@ class AdaptiveQueryExecSuite
|
|||
assert(coalesceReader.head.partitionSpecs.length == 1)
|
||||
}
|
||||
}
|
||||
|
||||
test("SPARK-35794: Allow custom plugin for cost evaluator") {
|
||||
CostEvaluator.instantiate(
|
||||
classOf[SimpleShuffleSortCostEvaluator].getCanonicalName, spark.sparkContext.getConf)
|
||||
intercept[IllegalArgumentException] {
|
||||
CostEvaluator.instantiate(
|
||||
classOf[InvalidCostEvaluator].getCanonicalName, spark.sparkContext.getConf)
|
||||
}
|
||||
|
||||
withSQLConf(
|
||||
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
|
||||
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
|
||||
val query = "SELECT * FROM testData join testData2 ON key = a where value = '1'"
|
||||
|
||||
withSQLConf(SQLConf.ADAPTIVE_CUSTOM_COST_EVALUATOR_CLASS.key ->
|
||||
"org.apache.spark.sql.execution.adaptive.SimpleShuffleSortCostEvaluator") {
|
||||
val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(query)
|
||||
val smj = findTopLevelSortMergeJoin(plan)
|
||||
assert(smj.size == 1)
|
||||
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
|
||||
assert(bhj.size == 1)
|
||||
checkNumLocalShuffleReaders(adaptivePlan)
|
||||
}
|
||||
|
||||
withSQLConf(SQLConf.ADAPTIVE_CUSTOM_COST_EVALUATOR_CLASS.key ->
|
||||
"org.apache.spark.sql.execution.adaptive.InvalidCostEvaluator") {
|
||||
intercept[IllegalArgumentException] {
|
||||
runAdaptiveAndVerifyResult(query)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Invalid implementation class for [[CostEvaluator]]: it deliberately does not extend the
 * trait, so `CostEvaluator.instantiate` is expected to reject it.
 */
private class InvalidCostEvaluator() {}
|
||||
|
||||
/**
 * A simple [[CostEvaluator]] to count number of [[ShuffleExchangeLike]] and [[SortExec]].
 */
private case class SimpleShuffleSortCostEvaluator() extends CostEvaluator {

  /** Cost = total count of shuffle-exchange and sort nodes in the plan tree. */
  override def evaluateCost(plan: SparkPlan): Cost = {
    val shufflesAndSorts = plan.collect {
      case node @ (_: ShuffleExchangeLike | _: SortExec) => node
    }
    SimpleCost(shufflesAndSorts.size)
  }
}
|
||||
|
|
Loading…
Reference in a new issue