From 4a8dc5f7a37b4ea84682b3ee67243bb5f030f302 Mon Sep 17 00:00:00 2001 From: allisonwang-db Date: Thu, 23 Sep 2021 12:50:27 +0800 Subject: [PATCH] [SPARK-36747][SQL] Do not collapse Project with Aggregate when correlated subqueries are present in the project list ### What changes were proposed in this pull request? This PR adds a check in the optimizer rule `CollapseProject` to avoid combining Project with Aggregate when the project list contains one or more correlated scalar subqueries that reference the output of the aggregate. Combining Project with Aggregate can lead to an invalid plan after correlated subquery rewrite. This is because correlated scalar subqueries' references are used as join conditions, which cannot host aggregate expressions. For example ```sql select (select sum(c2) from t where c1 = cast(s as int)) from (select sum(c2) s from t) ``` ``` == Optimized Logical Plan == Aggregate [sum(c2)#10L AS scalarsubquery(s)#11L] <--- Aggregate has neither grouping nor aggregate expressions. +- Project [sum(c2)#10L] +- Join LeftOuter, (c1#2 = cast(sum(c2#3) as int)) <--- Aggregate expression in join condition :- LocalRelation [c2#3] +- Aggregate [c1#2], [sum(c2#3) AS sum(c2)#10L, c1#2] +- LocalRelation [c1#2, c2#3] java.lang.UnsupportedOperationException: Cannot generate code for expression: sum(input[0, int, false]) ``` Currently, we only allow a correlated scalar subquery in Aggregate if it is also in the grouping expressions. https://github.com/apache/spark/blob/079a9c52925818532b57c9cec1ddd31be723885e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala#L661-L666 ### Why are the changes needed? To fix an existing optimizer issue. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Unit test. Closes #33990 from allisonwang-db/spark-36747-collapse-agg. Authored-by: allisonwang-db Signed-off-by: Wenchen Fan --- .../sql/catalyst/optimizer/Optimizer.scala | 15 ++++++++++++++- .../org/apache/spark/sql/SubquerySuite.scala | 18 ++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 9ded7b1f99..ed16185ae2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -924,7 +924,8 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper { if canCollapseExpressions(p1.projectList, p2.projectList, alwaysInline) => p2.copy(projectList = buildCleanedProjectList(p1.projectList, p2.projectList)) case p @ Project(_, agg: Aggregate) - if canCollapseExpressions(p.projectList, agg.aggregateExpressions, alwaysInline) => + if canCollapseExpressions(p.projectList, agg.aggregateExpressions, alwaysInline) && + canCollapseAggregate(p, agg) => agg.copy(aggregateExpressions = buildCleanedProjectList( p.projectList, agg.aggregateExpressions)) case Project(l1, g @ GlobalLimit(_, limit @ LocalLimit(_, p2 @ Project(l2, _)))) @@ -982,6 +983,18 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper { case _ => false } + /** + * A project cannot be collapsed with an aggregate when there are correlated scalar + * subqueries in the project list, because currently we only allow correlated subqueries + * in aggregate if they are also part of the grouping expressions. Otherwise the plan + * after subquery rewrite will not be valid. + */ + private def canCollapseAggregate(p: Project, a: Aggregate): Boolean = { + p.projectList.forall(_.collect { + case s: ScalarSubquery if s.outerAttrs.nonEmpty => s + }.isEmpty) + } + private def buildCleanedProjectList( upper: Seq[NamedExpression], lower: Seq[NamedExpression]): Seq[NamedExpression] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 1e3a61cfbd..e4d5bd305f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -1902,4 +1902,22 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark assert(exchanges.size === 1) } } + + test("SPARK-36747: should not combine Project with Aggregate") { + withTempView("t") { + Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t") + checkAnswer( + sql(""" + |SELECT m, (SELECT SUM(c2) FROM t WHERE c1 = m) + |FROM (SELECT MIN(c2) AS m FROM t) + |""".stripMargin), + Row(1, 2) :: Nil) + checkAnswer( + sql(""" + |SELECT c, (SELECT SUM(c2) FROM t WHERE c1 = c) + |FROM (SELECT c1 AS c FROM t GROUP BY c1) + |""".stripMargin), + Row(0, 1) :: Row(1, 2) :: Nil) + } + } }