diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 9ded7b1f99..ed16185ae2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -924,7 +924,8 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper {
           if canCollapseExpressions(p1.projectList, p2.projectList, alwaysInline) =>
         p2.copy(projectList = buildCleanedProjectList(p1.projectList, p2.projectList))
       case p @ Project(_, agg: Aggregate)
-          if canCollapseExpressions(p.projectList, agg.aggregateExpressions, alwaysInline) =>
+          if canCollapseExpressions(p.projectList, agg.aggregateExpressions, alwaysInline) &&
+            canCollapseAggregate(p, agg) =>
         agg.copy(aggregateExpressions = buildCleanedProjectList(
           p.projectList, agg.aggregateExpressions))
       case Project(l1, g @ GlobalLimit(_, limit @ LocalLimit(_, p2 @ Project(l2, _))))
@@ -982,6 +983,27 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper {
     case _ => false
   }
 
+  /**
+   * Decides whether a [[Project]] on top of an [[Aggregate]] may be merged into the aggregate.
+   *
+   * The merge is refused when any expression in the project list contains a correlated scalar
+   * subquery, i.e. a [[ScalarSubquery]] whose `outerAttrs` is non-empty. Correlated subqueries
+   * are currently only allowed in an aggregate when they are also part of the grouping
+   * expressions; otherwise the plan after subquery rewrite will not be valid.
+   *
+   * @param p the upper project whose project list is inspected
+   * @param a the aggregate below `p`; not consulted by the current check
+   * @return true if no expression in `p.projectList` contains a correlated scalar subquery
+   */
+  private def canCollapseAggregate(p: Project, a: Aggregate): Boolean = {
+    val hasCorrelatedScalarSubquery = p.projectList.exists { expr =>
+      expr.collectFirst {
+        case sub: ScalarSubquery if sub.outerAttrs.nonEmpty => sub
+      }.isDefined
+    }
+    !hasCorrelatedScalarSubquery
+  }
+
   private def buildCleanedProjectList(
       upper: Seq[NamedExpression],
       lower: Seq[NamedExpression]): Seq[NamedExpression] = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 1e3a61cfbd..e4d5bd305f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -1902,4 +1902,27 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
       assert(exchanges.size === 1)
     }
   }
+
+  test("SPARK-36747: should not combine Project with Aggregate") {
+    withTempView("t") {
+      val source = Seq((0, 1), (1, 2)).toDF("c1", "c2")
+      source.createOrReplaceTempView("t")
+
+      // Correlated scalar subquery projected over a global aggregate (no grouping expressions).
+      val overGlobalAgg = sql(
+        """
+          |SELECT m, (SELECT SUM(c2) FROM t WHERE c1 = m)
+          |FROM (SELECT MIN(c2) AS m FROM t)
+          |""".stripMargin)
+      checkAnswer(overGlobalAgg, Row(1, 2) :: Nil)
+
+      // Correlated scalar subquery projected over a grouped aggregate.
+      val overGroupedAgg = sql(
+        """
+          |SELECT c, (SELECT SUM(c2) FROM t WHERE c1 = c)
+          |FROM (SELECT c1 AS c FROM t GROUP BY c1)
+          |""".stripMargin)
+      checkAnswer(overGroupedAgg, Row(0, 1) :: Row(1, 2) :: Nil)
+    }
+  }
 }