[SPARK-36747][SQL] Do not collapse Project with Aggregate when correlated subqueries are present in the project list
### What changes were proposed in this pull request?
This PR adds a check in the optimizer rule `CollapseProject` to avoid combining Project with Aggregate when the project list contains one or more correlated scalar subqueries that reference the output of the aggregate. Combining Project with Aggregate can lead to an invalid plan after correlated subquery rewrite. This is because correlated scalar subqueries' references are used as join conditions, which cannot host aggregate expressions.
For example, the following query is optimized into the invalid plan shown below it:
```sql
select (select sum(c2) from t where c1 = cast(s as int)) from (select sum(c2) s from t)
```
```
== Optimized Logical Plan ==
Aggregate [sum(c2)#10L AS scalarsubquery(s)#11L] <--- Aggregate has neither grouping nor aggregate expressions.
+- Project [sum(c2)#10L]
   +- Join LeftOuter, (c1#2 = cast(sum(c2#3) as int)) <--- Aggregate expression in join condition
      :- LocalRelation [c2#3]
      +- Aggregate [c1#2], [sum(c2#3) AS sum(c2)#10L, c1#2]
         +- LocalRelation [c1#2, c2#3]
java.lang.UnsupportedOperationException: Cannot generate code for expression: sum(input[0, int, false])
```
Currently, a correlated scalar subquery is only allowed in an Aggregate if it is also part of the grouping expressions; see
`sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala` (L661-L666) at commit 079a9c5292.
### Why are the changes needed?
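To fix an existing optimizer issue: `CollapseProject` could merge a Project containing a correlated scalar subquery into an Aggregate, and after the subquery rewrite the plan ended up with an aggregate expression inside a join condition, failing with `java.lang.UnsupportedOperationException: Cannot generate code for expression`.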
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added a unit test to `SubquerySuite`: "SPARK-36747: should not combine Project with Aggregate".
Closes #33990 from allisonwang-db/spark-36747-collapse-agg.
Authored-by: allisonwang-db <allison.wang@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
8a1a91bd71
commit
4a8dc5f7a3
|
@ -924,7 +924,8 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper {
|
||||||
if canCollapseExpressions(p1.projectList, p2.projectList, alwaysInline) =>
|
if canCollapseExpressions(p1.projectList, p2.projectList, alwaysInline) =>
|
||||||
p2.copy(projectList = buildCleanedProjectList(p1.projectList, p2.projectList))
|
p2.copy(projectList = buildCleanedProjectList(p1.projectList, p2.projectList))
|
||||||
case p @ Project(_, agg: Aggregate)
|
case p @ Project(_, agg: Aggregate)
|
||||||
if canCollapseExpressions(p.projectList, agg.aggregateExpressions, alwaysInline) =>
|
if canCollapseExpressions(p.projectList, agg.aggregateExpressions, alwaysInline) &&
|
||||||
|
canCollapseAggregate(p, agg) =>
|
||||||
agg.copy(aggregateExpressions = buildCleanedProjectList(
|
agg.copy(aggregateExpressions = buildCleanedProjectList(
|
||||||
p.projectList, agg.aggregateExpressions))
|
p.projectList, agg.aggregateExpressions))
|
||||||
case Project(l1, g @ GlobalLimit(_, limit @ LocalLimit(_, p2 @ Project(l2, _))))
|
case Project(l1, g @ GlobalLimit(_, limit @ LocalLimit(_, p2 @ Project(l2, _))))
|
||||||
|
@ -982,6 +983,18 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper {
|
||||||
case _ => false
|
case _ => false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Decides whether a [[Project]] may be merged into the [[Aggregate]] directly beneath it.
 *
 * The merge is refused as soon as any expression of the project list contains a correlated
 * scalar subquery, i.e. a [[ScalarSubquery]] whose `outerAttrs` is non-empty. Correlated
 * subqueries are currently only allowed in an aggregate when they are also part of its
 * grouping expressions, so collapsing would leave an invalid plan after subquery rewrite.
 *
 * @param p the upper project whose project list is inspected
 * @param a the lower aggregate; the check itself does not consult it
 * @return true when no correlated scalar subquery occurs anywhere in the project list
 */
private def canCollapseAggregate(p: Project, a: Aggregate): Boolean = {
  val hasCorrelatedScalarSubquery = p.projectList.exists { expr =>
    expr.collect {
      case subquery: ScalarSubquery if subquery.outerAttrs.nonEmpty => subquery
    }.nonEmpty
  }
  !hasCorrelatedScalarSubquery
}
|
||||||
|
|
||||||
private def buildCleanedProjectList(
|
private def buildCleanedProjectList(
|
||||||
upper: Seq[NamedExpression],
|
upper: Seq[NamedExpression],
|
||||||
lower: Seq[NamedExpression]): Seq[NamedExpression] = {
|
lower: Seq[NamedExpression]): Seq[NamedExpression] = {
|
||||||
|
|
|
@ -1902,4 +1902,22 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
|
||||||
assert(exchanges.size === 1)
|
assert(exchanges.size === 1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("SPARK-36747: should not combine Project with Aggregate") {
  withTempView("t") {
    Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t")

    // Correlated scalar subquery projected on top of a global aggregate (no GROUP BY).
    val overGlobalAggregate =
      """
        |SELECT m, (SELECT SUM(c2) FROM t WHERE c1 = m)
        |FROM (SELECT MIN(c2) AS m FROM t)
        |""".stripMargin
    // Correlated scalar subquery projected on top of a grouping aggregate.
    val overGroupingAggregate =
      """
        |SELECT c, (SELECT SUM(c2) FROM t WHERE c1 = c)
        |FROM (SELECT c1 AS c FROM t GROUP BY c1)
        |""".stripMargin

    // Both queries must still run and return the expected rows.
    checkAnswer(sql(overGlobalAggregate), Row(1, 2) :: Nil)
    checkAnswer(sql(overGroupingAggregate), Row(0, 1) :: Row(1, 2) :: Nil)
  }
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue