diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index da23b96afa..9811199a85 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -920,7 +920,8 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper { p2.copy(projectList = buildCleanedProjectList(p1.projectList, p2.projectList)) } case p @ Project(_, agg: Aggregate) => - if (haveCommonNonDeterministicOutput(p.projectList, agg.aggregateExpressions)) { + if (haveCommonNonDeterministicOutput(p.projectList, agg.aggregateExpressions) || + !canCollapseAggregate(p, agg)) { p } else { agg.copy(aggregateExpressions = buildCleanedProjectList( @@ -950,6 +951,18 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper { }.exists(!_.deterministic)) } + /** + * A project cannot be collapsed with an aggregate when there are correlated scalar + * subqueries in the project list, because currently we only allow correlated subqueries + * in aggregate if they are also part of the grouping expressions. Otherwise the plan + * after subquery rewrite will not be valid. + */ + private def canCollapseAggregate(p: Project, a: Aggregate): Boolean = { + p.projectList.forall(_.collect { + case s: ScalarSubquery if s.outerAttrs.nonEmpty => s + }.isEmpty) + } + private def buildCleanedProjectList( upper: Seq[NamedExpression], lower: Seq[NamedExpression]): Seq[NamedExpression] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index c3362b377e..448909f602 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -1877,4 +1877,22 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark "ReusedSubqueryExec should reuse an existing subquery") } } + + test("SPARK-36747: should not combine Project with Aggregate") { + withTempView("t") { + Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t") + checkAnswer( + sql(""" + |SELECT m, (SELECT SUM(c2) FROM t WHERE c1 = m) + |FROM (SELECT MIN(c2) AS m FROM t) + |""".stripMargin), + Row(1, 2) :: Nil) + checkAnswer( + sql(""" + |SELECT c, (SELECT SUM(c2) FROM t WHERE c1 = c) + |FROM (SELECT c1 AS c FROM t GROUP BY c1) + |""".stripMargin), + Row(0, 1) :: Row(1, 2) :: Nil) + } + } }