[SPARK-35855][SQL] Unify reuse map data structures in non-AQE and AQE rules
### What changes were proposed in this pull request? This PR unifies reuse map data structures in non-AQE and AQE rules to a simple `Map[<canonicalized plan>, <plan>]` based on the discussion here: https://github.com/apache/spark/pull/28885#discussion_r655073897 ### Why are the changes needed? The proposed `Map[<canonicalized plan>, <plan>]` is simpler than the currently used `Map[<schema>, ArrayBuffer[<plan>]]` in `ReuseMap`/`ReuseExchangeAndSubquery` (non-AQE) and consistent with the `ReuseAdaptiveSubquery` (AQE) subquery reuse rule. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing UTs. Closes #33021 from peter-toth/SPARK-35855-unify-reuse-map-data-structures. Authored-by: Peter Toth <peter.toth@gmail.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
20edfdd39a
commit
79e3d0d98f
|
@ -1,73 +0,0 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.util
|
||||
|
||||
import scala.collection.mutable.{ArrayBuffer, Map}
|
||||
|
||||
import org.apache.spark.sql.catalyst.plans.QueryPlan
|
||||
import org.apache.spark.sql.types.StructType
|
||||
|
||||
/**
|
||||
* Map of canonicalized plans that can be used to find reuse possibilities.
|
||||
*
|
||||
* To avoid costly canonicalization of a plan:
|
||||
* - we use its schema first to check if it can be replaced to a reused one at all
|
||||
* - we insert it into the map of canonicalized plans only when at least 2 have the same schema
|
||||
*
|
||||
* @tparam T the type of the node we want to reuse
|
||||
* @tparam T2 the type of the canonicalized node
|
||||
*/
|
||||
class ReuseMap[T <: T2, T2 <: QueryPlan[T2]] {
|
||||
private val map = Map[StructType, ArrayBuffer[T]]()
|
||||
|
||||
/**
|
||||
* Find a matching plan with the same canonicalized form in the map or add the new plan to the
|
||||
* map otherwise.
|
||||
*
|
||||
* @param plan the input plan
|
||||
* @return the matching plan or the input plan
|
||||
*/
|
||||
private def lookupOrElseAdd(plan: T): T = {
|
||||
val sameSchema = map.getOrElseUpdate(plan.schema, ArrayBuffer())
|
||||
val samePlan = sameSchema.find(plan.sameResult)
|
||||
if (samePlan.isDefined) {
|
||||
samePlan.get
|
||||
} else {
|
||||
sameSchema += plan
|
||||
plan
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Find a matching plan with the same canonicalized form in the map and apply `f` on it or add
|
||||
* the new plan to the map otherwise.
|
||||
*
|
||||
* @param plan the input plan
|
||||
* @param f the function to apply
|
||||
* @tparam T2 the type of the reuse node
|
||||
* @return the matching plan with `f` applied or the input plan
|
||||
*/
|
||||
def reuseOrElseAdd[T2 >: T](plan: T, f: T => T2): T2 = {
|
||||
val found = lookupOrElseAdd(plan)
|
||||
if (found eq plan) {
|
||||
plan
|
||||
} else {
|
||||
f(found)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,73 +0,0 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.util
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
|
||||
import org.apache.spark.sql.types.IntegerType
|
||||
|
||||
case class TestNode(children: Seq[TestNode], output: Seq[Attribute]) extends LogicalPlan {
|
||||
override protected def withNewChildrenInternal(
|
||||
newChildren: IndexedSeq[LogicalPlan]): LogicalPlan = copy(children = children)
|
||||
}
|
||||
case class TestReuseNode(child: LogicalPlan) extends UnaryNode {
|
||||
override def output: Seq[Attribute] = child.output
|
||||
|
||||
override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
|
||||
copy(child = newChild)
|
||||
}
|
||||
|
||||
class ReuseMapSuite extends SparkFunSuite {
|
||||
private val leafNode1 = TestNode(Nil, Seq(AttributeReference("a", IntegerType)()))
|
||||
private val leafNode2 = TestNode(Nil, Seq(AttributeReference("b", IntegerType)()))
|
||||
private val parentNode1 = TestNode(Seq(leafNode1), Seq(AttributeReference("a", IntegerType)()))
|
||||
private val parentNode2 = TestNode(Seq(leafNode2), Seq(AttributeReference("b", IntegerType)()))
|
||||
|
||||
private def reuse(testNode: TestNode) = TestReuseNode(testNode)
|
||||
|
||||
test("no reuse if same instance") {
|
||||
val reuseMap = new ReuseMap[TestNode, LogicalPlan]()
|
||||
|
||||
reuseMap.reuseOrElseAdd(leafNode1, reuse)
|
||||
reuseMap.reuseOrElseAdd(parentNode1, reuse)
|
||||
|
||||
assert(reuseMap.reuseOrElseAdd(leafNode1, reuse) == leafNode1)
|
||||
assert(reuseMap.reuseOrElseAdd(parentNode1, reuse) == parentNode1)
|
||||
}
|
||||
|
||||
test("reuse if different instance with same canonicalized plan") {
|
||||
val reuseMap = new ReuseMap[TestNode, LogicalPlan]()
|
||||
reuseMap.reuseOrElseAdd(leafNode1, reuse)
|
||||
reuseMap.reuseOrElseAdd(parentNode1, reuse)
|
||||
|
||||
assert(reuseMap.reuseOrElseAdd(leafNode1.clone.asInstanceOf[TestNode], reuse) ==
|
||||
reuse(leafNode1))
|
||||
assert(reuseMap.reuseOrElseAdd(parentNode1.clone.asInstanceOf[TestNode], reuse) ==
|
||||
reuse(parentNode1))
|
||||
}
|
||||
|
||||
test("no reuse if different canonicalized plan") {
|
||||
val reuseMap = new ReuseMap[TestNode, LogicalPlan]()
|
||||
reuseMap.reuseOrElseAdd(leafNode1, reuse)
|
||||
reuseMap.reuseOrElseAdd(parentNode1, reuse)
|
||||
|
||||
assert(reuseMap.reuseOrElseAdd(leafNode2, reuse) == leafNode2)
|
||||
assert(reuseMap.reuseOrElseAdd(parentNode2, reuse) == parentNode2)
|
||||
}
|
||||
}
|
|
@ -17,11 +17,12 @@
|
|||
|
||||
package org.apache.spark.sql.execution.reuse
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
import org.apache.spark.sql.catalyst.rules.Rule
|
||||
import org.apache.spark.sql.catalyst.trees.TreePattern._
|
||||
import org.apache.spark.sql.execution.{BaseSubqueryExec, ExecSubqueryExpression, ReusedSubqueryExec, SparkPlan}
|
||||
import org.apache.spark.sql.execution.exchange.{Exchange, ReusedExchangeExec}
|
||||
import org.apache.spark.sql.util.ReuseMap
|
||||
|
||||
/**
|
||||
* Find out duplicated exchanges and subqueries in the whole spark plan including subqueries, then
|
||||
|
@ -36,24 +37,34 @@ case object ReuseExchangeAndSubquery extends Rule[SparkPlan] {
|
|||
|
||||
def apply(plan: SparkPlan): SparkPlan = {
|
||||
if (conf.exchangeReuseEnabled || conf.subqueryReuseEnabled) {
|
||||
val exchanges = new ReuseMap[Exchange, SparkPlan]()
|
||||
val subqueries = new ReuseMap[BaseSubqueryExec, SparkPlan]()
|
||||
val exchanges = mutable.Map.empty[SparkPlan, Exchange]
|
||||
val subqueries = mutable.Map.empty[SparkPlan, BaseSubqueryExec]
|
||||
|
||||
def reuse(plan: SparkPlan): SparkPlan = {
|
||||
plan.transformUpWithPruning(_.containsAnyPattern(EXCHANGE, PLAN_EXPRESSION)) {
|
||||
case exchange: Exchange if conf.exchangeReuseEnabled =>
|
||||
exchanges.reuseOrElseAdd(exchange, ReusedExchangeExec(exchange.output, _))
|
||||
val cachedExchange = exchanges.getOrElseUpdate(exchange.canonicalized, exchange)
|
||||
if (cachedExchange.ne(exchange)) {
|
||||
ReusedExchangeExec(exchange.output, cachedExchange)
|
||||
} else {
|
||||
cachedExchange
|
||||
}
|
||||
|
||||
case other =>
|
||||
other.transformExpressionsUpWithPruning(_.containsPattern(PLAN_EXPRESSION)) {
|
||||
case sub: ExecSubqueryExpression =>
|
||||
val subquery = reuse(sub.plan).asInstanceOf[BaseSubqueryExec]
|
||||
sub.withNewPlan(
|
||||
if (conf.subqueryReuseEnabled) {
|
||||
subqueries.reuseOrElseAdd(subquery, ReusedSubqueryExec(_))
|
||||
val newSubquery = if (conf.subqueryReuseEnabled) {
|
||||
val cachedSubquery = subqueries.getOrElseUpdate(subquery.canonicalized, subquery)
|
||||
if (cachedSubquery.ne(subquery)) {
|
||||
ReusedSubqueryExec(cachedSubquery)
|
||||
} else {
|
||||
cachedSubquery
|
||||
}
|
||||
} else {
|
||||
subquery
|
||||
})
|
||||
}
|
||||
sub.withNewPlan(newSubquery)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -479,9 +479,9 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils
|
|||
test("broadcast join where streamed side's output partitioning is PartitioningCollection") {
|
||||
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") {
|
||||
val t1 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i1", "j1")
|
||||
val t2 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i2", "j2")
|
||||
val t2 = (0 until 100).map(i => (i % 5, i % 14)).toDF("i2", "j2")
|
||||
val t3 = (0 until 20).map(i => (i % 7, i % 11)).toDF("i3", "j3")
|
||||
val t4 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i4", "j4")
|
||||
val t4 = (0 until 100).map(i => (i % 5, i % 15)).toDF("i4", "j4")
|
||||
|
||||
// join1 is a sort merge join (shuffle on the both sides).
|
||||
val join1 = t1.join(t2, t1("i1") === t2("i2"))
|
||||
|
|
Loading…
Reference in a new issue