From 79e3d0d98f884dd1f87ad385c682ba380a60dbc8 Mon Sep 17 00:00:00 2001
From: Peter Toth
Date: Wed, 23 Jun 2021 07:20:47 +0000
Subject: [PATCH] [SPARK-35855][SQL] Unify reuse map data structures in
 non-AQE and AQE rules

### What changes were proposed in this pull request?

This PR unifies the reuse map data structures in the non-AQE and AQE rules to a simple `Map[<canonicalized plan>, <plan>]`, based on the discussion here: https://github.com/apache/spark/pull/28885#discussion_r655073897

### Why are the changes needed?

The proposed `Map[<canonicalized plan>, <plan>]` is simpler than the `Map[<plan schema>, ArrayBuffer[<plan>]]` currently used in `ReuseMap`/`ReuseExchangeAndSubquery` (non-AQE), and it is consistent with the `ReuseAdaptiveSubquery` (AQE) subquery reuse rule.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing UTs.

Closes #33021 from peter-toth/SPARK-35855-unify-reuse-map-data-structures.

Authored-by: Peter Toth
Signed-off-by: Wenchen Fan
---
 .../org/apache/spark/sql/util/ReuseMap.scala  | 73 -------------------
 .../apache/spark/sql/util/ReuseMapSuite.scala | 73 -------------------
 .../reuse/ReuseExchangeAndSubquery.scala      | 29 +++++---
 .../execution/joins/BroadcastJoinSuite.scala  |  4 +-
 4 files changed, 22 insertions(+), 157 deletions(-)
 delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/util/ReuseMap.scala
 delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/util/ReuseMapSuite.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ReuseMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ReuseMap.scala
deleted file mode 100644
index fbee4f0fc4..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ReuseMap.scala
+++ /dev/null
@@ -1,73 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.util
-
-import scala.collection.mutable.{ArrayBuffer, Map}
-
-import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.types.StructType
-
-/**
- * Map of canonicalized plans that can be used to find reuse possibilities.
- *
- * To avoid costly canonicalization of a plan:
- * - we use its schema first to check if it can be replaced to a reused one at all
- * - we insert it into the map of canonicalized plans only when at least 2 have the same schema
- *
- * @tparam T the type of the node we want to reuse
- * @tparam T2 the type of the canonicalized node
- */
-class ReuseMap[T <: T2, T2 <: QueryPlan[T2]] {
-  private val map = Map[StructType, ArrayBuffer[T]]()
-
-  /**
-   * Find a matching plan with the same canonicalized form in the map or add the new plan to the
-   * map otherwise.
-   *
-   * @param plan the input plan
-   * @return the matching plan or the input plan
-   */
-  private def lookupOrElseAdd(plan: T): T = {
-    val sameSchema = map.getOrElseUpdate(plan.schema, ArrayBuffer())
-    val samePlan = sameSchema.find(plan.sameResult)
-    if (samePlan.isDefined) {
-      samePlan.get
-    } else {
-      sameSchema += plan
-      plan
-    }
-  }
-
-  /**
-   * Find a matching plan with the same canonicalized form in the map and apply `f` on it or add
-   * the new plan to the map otherwise.
-   *
-   * @param plan the input plan
-   * @param f the function to apply
-   * @tparam T2 the type of the reuse node
-   * @return the matching plan with `f` applied or the input plan
-   */
-  def reuseOrElseAdd[T2 >: T](plan: T, f: T => T2): T2 = {
-    val found = lookupOrElseAdd(plan)
-    if (found eq plan) {
-      plan
-    } else {
-      f(found)
-    }
-  }
-}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ReuseMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ReuseMapSuite.scala
deleted file mode 100644
index 6a74aa46aa..0000000000
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ReuseMapSuite.scala
+++ /dev/null
@@ -1,73 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.util
-
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
-import org.apache.spark.sql.types.IntegerType
-
-case class TestNode(children: Seq[TestNode], output: Seq[Attribute]) extends LogicalPlan {
-  override protected def withNewChildrenInternal(
-      newChildren: IndexedSeq[LogicalPlan]): LogicalPlan = copy(children = children)
-}
-case class TestReuseNode(child: LogicalPlan) extends UnaryNode {
-  override def output: Seq[Attribute] = child.output
-
-  override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
-    copy(child = newChild)
-}
-
-class ReuseMapSuite extends SparkFunSuite {
-  private val leafNode1 = TestNode(Nil, Seq(AttributeReference("a", IntegerType)()))
-  private val leafNode2 = TestNode(Nil, Seq(AttributeReference("b", IntegerType)()))
-  private val parentNode1 = TestNode(Seq(leafNode1), Seq(AttributeReference("a", IntegerType)()))
-  private val parentNode2 = TestNode(Seq(leafNode2), Seq(AttributeReference("b", IntegerType)()))
-
-  private def reuse(testNode: TestNode) = TestReuseNode(testNode)
-
-  test("no reuse if same instance") {
-    val reuseMap = new ReuseMap[TestNode, LogicalPlan]()
-
-    reuseMap.reuseOrElseAdd(leafNode1, reuse)
-    reuseMap.reuseOrElseAdd(parentNode1, reuse)
-
-    assert(reuseMap.reuseOrElseAdd(leafNode1, reuse) == leafNode1)
-    assert(reuseMap.reuseOrElseAdd(parentNode1, reuse) == parentNode1)
-  }
-
-  test("reuse if different instance with same canonicalized plan") {
-    val reuseMap = new ReuseMap[TestNode, LogicalPlan]()
-    reuseMap.reuseOrElseAdd(leafNode1, reuse)
-    reuseMap.reuseOrElseAdd(parentNode1, reuse)
-
-    assert(reuseMap.reuseOrElseAdd(leafNode1.clone.asInstanceOf[TestNode], reuse) ==
-      reuse(leafNode1))
-    assert(reuseMap.reuseOrElseAdd(parentNode1.clone.asInstanceOf[TestNode], reuse) ==
-      reuse(parentNode1))
-  }
-
-  test("no reuse if different canonicalized plan") {
-    val reuseMap = new ReuseMap[TestNode, LogicalPlan]()
-    reuseMap.reuseOrElseAdd(leafNode1, reuse)
-    reuseMap.reuseOrElseAdd(parentNode1, reuse)
-
-    assert(reuseMap.reuseOrElseAdd(leafNode2, reuse) == leafNode2)
-    assert(reuseMap.reuseOrElseAdd(parentNode2, reuse) == parentNode2)
-  }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/reuse/ReuseExchangeAndSubquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/reuse/ReuseExchangeAndSubquery.scala
index 0de8178a9b..471b926dc0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/reuse/ReuseExchangeAndSubquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/reuse/ReuseExchangeAndSubquery.scala
@@ -17,11 +17,12 @@
 
 package org.apache.spark.sql.execution.reuse
 
+import scala.collection.mutable
+
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.execution.{BaseSubqueryExec, ExecSubqueryExpression, ReusedSubqueryExec, SparkPlan}
 import org.apache.spark.sql.execution.exchange.{Exchange, ReusedExchangeExec}
-import org.apache.spark.sql.util.ReuseMap
 
 /**
  * Find out duplicated exchanges and subqueries in the whole spark plan including subqueries, then
@@ -36,24 +37,34 @@ case object ReuseExchangeAndSubquery extends Rule[SparkPlan] {
 
   def apply(plan: SparkPlan): SparkPlan = {
     if (conf.exchangeReuseEnabled || conf.subqueryReuseEnabled) {
-      val exchanges = new ReuseMap[Exchange, SparkPlan]()
-      val subqueries = new ReuseMap[BaseSubqueryExec, SparkPlan]()
+      val exchanges = mutable.Map.empty[SparkPlan, Exchange]
+      val subqueries = mutable.Map.empty[SparkPlan, BaseSubqueryExec]
 
       def reuse(plan: SparkPlan): SparkPlan = {
         plan.transformUpWithPruning(_.containsAnyPattern(EXCHANGE, PLAN_EXPRESSION)) {
           case exchange: Exchange if conf.exchangeReuseEnabled =>
-            exchanges.reuseOrElseAdd(exchange, ReusedExchangeExec(exchange.output, _))
+            val cachedExchange = exchanges.getOrElseUpdate(exchange.canonicalized, exchange)
+            if (cachedExchange.ne(exchange)) {
+              ReusedExchangeExec(exchange.output, cachedExchange)
+            } else {
+              cachedExchange
+            }
 
           case other => other.transformExpressionsUpWithPruning(_.containsPattern(PLAN_EXPRESSION)) {
             case sub: ExecSubqueryExpression =>
               val subquery = reuse(sub.plan).asInstanceOf[BaseSubqueryExec]
-              sub.withNewPlan(
-                if (conf.subqueryReuseEnabled) {
-                  subqueries.reuseOrElseAdd(subquery, ReusedSubqueryExec(_))
+              val newSubquery = if (conf.subqueryReuseEnabled) {
+                val cachedSubquery = subqueries.getOrElseUpdate(subquery.canonicalized, subquery)
+                if (cachedSubquery.ne(subquery)) {
+                  ReusedSubqueryExec(cachedSubquery)
                 } else {
-                  subquery
-                })
+                  cachedSubquery
+                }
+              } else {
+                subquery
+              }
+              sub.withNewPlan(newSubquery)
           }
         }
       }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index 98a1089709..92c38ee228 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -479,9 +479,9 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils
   test("broadcast join where streamed side's output partitioning is PartitioningCollection") {
     withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") {
       val t1 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i1", "j1")
-      val t2 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i2", "j2")
+      val t2 = (0 until 100).map(i => (i % 5, i % 14)).toDF("i2", "j2")
       val t3 = (0 until 20).map(i => (i % 7, i % 11)).toDF("i3", "j3")
-      val t4 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i4", "j4")
+      val t4 = (0 until 100).map(i => (i % 5, i % 15)).toDF("i4", "j4")
       // join1 is a sort merge join (shuffle on the both sides).
       val join1 = t1.join(t2, t1("i1") === t2("i2"))
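
For reference, here is a minimal, self-contained sketch of the lookup pattern the rule uses after this change: a plain mutable map keyed by a plan's canonical form, where a single `getOrElseUpdate` both registers the first occurrence and surfaces any later semantically equal duplicate. `Plan`, `canonicalized`, and `reuseOrKeep` below are simplified stand-ins invented for this sketch, not Spark's real `SparkPlan` API.

```scala
import scala.collection.mutable

// Simplified stand-in for a physical plan; `canonicalized` plays the role of
// SparkPlan.canonicalized: two semantically equal plans share the same key.
case class Plan(id: Int, canonicalized: String)

object ReuseSketch {
  def main(args: Array[String]): Unit = {
    val reuseMap = mutable.Map.empty[String, Plan]

    // Mirrors the rule's logic: the first plan with a given canonical form is
    // kept as-is; any later plan with the same canonical form is replaced by a
    // reference to the cached one (a ReusedExchangeExec/ReusedSubqueryExec
    // wrapper in the real rule).
    def reuseOrKeep(plan: Plan): String = {
      val cached = reuseMap.getOrElseUpdate(plan.canonicalized, plan)
      if (cached ne plan) s"Reused(${cached.id})" else s"Kept(${plan.id})"
    }

    println(reuseOrKeep(Plan(1, "scan-t1"))) // Kept(1)
    println(reuseOrKeep(Plan(2, "scan-t1"))) // Reused(1): same canonical form
    println(reuseOrKeep(Plan(3, "scan-t2"))) // Kept(3): different canonical form
  }
}
```

Compared with the removed `ReuseMap`, which bucketed candidates by schema in a `Map[StructType, ArrayBuffer[T]]` and scanned each bucket linearly with `sameResult`, the unified approach needs only one hash lookup per plan and uses the same structure as the AQE-side `ReuseAdaptiveSubquery` rule.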