[SPARK-35545][SQL] Split SubqueryExpression's children field into outer attributes and join conditions
### What changes were proposed in this pull request?
This PR refactors `SubqueryExpression` class. It removes the children field from SubqueryExpression's constructor and adds `outerAttrs` and `joinCond`.
### Why are the changes needed?
Currently, the children field of a subquery expression is used to store both collected outer references in the subquery plan and join conditions after correlated predicates are pulled up.
For example:
`SELECT (SELECT max(c1) FROM t1 WHERE t1.c1 = t2.c1) FROM t2`
During the analysis phase, outer references in the subquery are stored in the children field: `scalar-subquery [t2.c1]`, but after the optimizer rule `PullupCorrelatedPredicates`, the children field will be used to store the join conditions, which contain both the inner and the outer references: `scalar-subquery [t1.c1 = t2.c1]`. This is why the references of SubqueryExpression excludes the inner plan's output:
29ed1a2de4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala (L68-L69)
This can be confusing and error-prone. The references for a subquery expression should always be defined as outer attribute references.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing tests.
Closes #32687 from allisonwang-db/refactor-subquery-expr.
Authored-by: allisonwang-db <66282705+allisonwang-db@users.noreply.github.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
1a55019b1f
commit
806da9d6fa
|
@ -2343,11 +2343,11 @@ class Analyzer(override val catalogManager: CatalogManager)
|
|||
private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = {
|
||||
plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY,
|
||||
EXISTS_SUBQUERY, IN_SUBQUERY), ruleId) {
|
||||
case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved =>
|
||||
case s @ ScalarSubquery(sub, _, exprId, _) if !sub.resolved =>
|
||||
resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId))
|
||||
case e @ Exists(sub, _, exprId) if !sub.resolved =>
|
||||
case e @ Exists(sub, _, exprId, _) if !sub.resolved =>
|
||||
resolveSubQuery(e, plans)(Exists(_, _, exprId))
|
||||
case InSubquery(values, l @ ListQuery(_, _, exprId, _))
|
||||
case InSubquery(values, l @ ListQuery(_, _, exprId, _, _))
|
||||
if values.forall(_.resolved) && !l.resolved =>
|
||||
val expr = resolveSubQuery(l, plans)((plan, exprs) => {
|
||||
ListQuery(plan, exprs, exprId, plan.output)
|
||||
|
|
|
@ -744,17 +744,17 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
|
|||
checkAnalysis(expr.plan)
|
||||
|
||||
expr match {
|
||||
case ScalarSubquery(query, conditions, _) =>
|
||||
case ScalarSubquery(query, outerAttrs, _, _) =>
|
||||
// Scalar subquery must return one column as output.
|
||||
if (query.output.size != 1) {
|
||||
failAnalysis(
|
||||
s"Scalar subquery must return only one column, but got ${query.output.size}")
|
||||
}
|
||||
|
||||
if (conditions.nonEmpty) {
|
||||
if (outerAttrs.nonEmpty) {
|
||||
cleanQueryInScalarSubquery(query) match {
|
||||
case a: Aggregate => checkAggregateInScalarSubquery(conditions, query, a)
|
||||
case Filter(_, a: Aggregate) => checkAggregateInScalarSubquery(conditions, query, a)
|
||||
case a: Aggregate => checkAggregateInScalarSubquery(outerAttrs, query, a)
|
||||
case Filter(_, a: Aggregate) => checkAggregateInScalarSubquery(outerAttrs, query, a)
|
||||
case p: LogicalPlan if p.maxRows.exists(_ <= 1) => // Ok
|
||||
case fail => failAnalysis(s"Correlated scalar subqueries must be aggregated: $fail")
|
||||
}
|
||||
|
|
|
@ -322,7 +322,7 @@ abstract class TypeCoercionBase {
|
|||
|
||||
// Handle type casting required between value expression and subquery output
|
||||
// in IN subquery.
|
||||
case i @ InSubquery(lhs, ListQuery(sub, children, exprId, _))
|
||||
case i @ InSubquery(lhs, ListQuery(sub, children, exprId, _, conditions))
|
||||
if !i.resolved && lhs.length == sub.output.length =>
|
||||
// LHS is the value expressions of IN subquery.
|
||||
// RHS is the subquery output.
|
||||
|
@ -345,7 +345,7 @@ abstract class TypeCoercionBase {
|
|||
}
|
||||
|
||||
val newSub = Project(castedRhs, sub)
|
||||
InSubquery(newLhs, ListQuery(newSub, children, exprId, newSub.output))
|
||||
InSubquery(newLhs, ListQuery(newSub, children, exprId, newSub.output, conditions))
|
||||
} else {
|
||||
i
|
||||
}
|
||||
|
|
|
@ -59,14 +59,22 @@ abstract class PlanExpression[T <: QueryPlan[_]] extends Expression {
|
|||
|
||||
/**
|
||||
* A base interface for expressions that contain a [[LogicalPlan]].
|
||||
*
|
||||
* @param plan: the subquery plan
|
||||
* @param outerAttrs: the outer references in the subquery plan
|
||||
* @param exprId: ID of the expression
|
||||
* @param joinCond: the join conditions with the outer query. It contains both inner and outer
|
||||
* query references.
|
||||
*/
|
||||
abstract class SubqueryExpression(
|
||||
plan: LogicalPlan,
|
||||
children: Seq[Expression],
|
||||
exprId: ExprId) extends PlanExpression[LogicalPlan] {
|
||||
outerAttrs: Seq[Expression],
|
||||
exprId: ExprId,
|
||||
joinCond: Seq[Expression] = Nil) extends PlanExpression[LogicalPlan] {
|
||||
override lazy val resolved: Boolean = childrenResolved && plan.resolved
|
||||
override lazy val references: AttributeSet =
|
||||
if (plan.resolved) super.references -- plan.outputSet else super.references
|
||||
AttributeSet.fromAttributeSets(outerAttrs.map(_.references))
|
||||
override def children: Seq[Expression] = outerAttrs ++ joinCond
|
||||
override def withNewPlan(plan: LogicalPlan): SubqueryExpression
|
||||
override def semanticEquals(o: Expression): Boolean = o match {
|
||||
case p: SubqueryExpression =>
|
||||
|
@ -240,9 +248,10 @@ object SubExprUtils extends PredicateHelper {
|
|||
*/
|
||||
case class ScalarSubquery(
|
||||
plan: LogicalPlan,
|
||||
children: Seq[Expression] = Seq.empty,
|
||||
exprId: ExprId = NamedExpression.newExprId)
|
||||
extends SubqueryExpression(plan, children, exprId) with Unevaluable {
|
||||
outerAttrs: Seq[Expression] = Seq.empty,
|
||||
exprId: ExprId = NamedExpression.newExprId,
|
||||
joinCond: Seq[Expression] = Seq.empty)
|
||||
extends SubqueryExpression(plan, outerAttrs, exprId, joinCond) with Unevaluable {
|
||||
override def dataType: DataType = {
|
||||
assert(plan.schema.fields.nonEmpty, "Scalar subquery should have only one column")
|
||||
plan.schema.fields.head.dataType
|
||||
|
@ -253,12 +262,16 @@ case class ScalarSubquery(
|
|||
override lazy val canonicalized: Expression = {
|
||||
ScalarSubquery(
|
||||
plan.canonicalized,
|
||||
children.map(_.canonicalized),
|
||||
ExprId(0))
|
||||
outerAttrs.map(_.canonicalized),
|
||||
ExprId(0),
|
||||
joinCond.map(_.canonicalized))
|
||||
}
|
||||
|
||||
override protected def withNewChildrenInternal(
|
||||
newChildren: IndexedSeq[Expression]): ScalarSubquery = copy(children = newChildren)
|
||||
newChildren: IndexedSeq[Expression]): ScalarSubquery =
|
||||
copy(
|
||||
outerAttrs = newChildren.take(outerAttrs.size),
|
||||
joinCond = newChildren.drop(outerAttrs.size))
|
||||
|
||||
final override def nodePatternsInternal: Seq[TreePattern] = Seq(SCALAR_SUBQUERY)
|
||||
}
|
||||
|
@ -286,10 +299,11 @@ object ScalarSubquery {
|
|||
*/
|
||||
case class ListQuery(
|
||||
plan: LogicalPlan,
|
||||
children: Seq[Expression] = Seq.empty,
|
||||
outerAttrs: Seq[Expression] = Seq.empty,
|
||||
exprId: ExprId = NamedExpression.newExprId,
|
||||
childOutputs: Seq[Attribute] = Seq.empty)
|
||||
extends SubqueryExpression(plan, children, exprId) with Unevaluable {
|
||||
childOutputs: Seq[Attribute] = Seq.empty,
|
||||
joinCond: Seq[Expression] = Seq.empty)
|
||||
extends SubqueryExpression(plan, outerAttrs, exprId, joinCond) with Unevaluable {
|
||||
override def dataType: DataType = if (childOutputs.length > 1) {
|
||||
childOutputs.toStructType
|
||||
} else {
|
||||
|
@ -302,13 +316,16 @@ case class ListQuery(
|
|||
override lazy val canonicalized: Expression = {
|
||||
ListQuery(
|
||||
plan.canonicalized,
|
||||
children.map(_.canonicalized),
|
||||
outerAttrs.map(_.canonicalized),
|
||||
ExprId(0),
|
||||
childOutputs.map(_.canonicalized.asInstanceOf[Attribute]))
|
||||
childOutputs.map(_.canonicalized.asInstanceOf[Attribute]),
|
||||
joinCond.map(_.canonicalized))
|
||||
}
|
||||
|
||||
override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): ListQuery =
|
||||
copy(children = newChildren)
|
||||
copy(
|
||||
outerAttrs = newChildren.take(outerAttrs.size),
|
||||
joinCond = newChildren.drop(outerAttrs.size))
|
||||
|
||||
final override def nodePatternsInternal: Seq[TreePattern] = Seq(LIST_SUBQUERY)
|
||||
}
|
||||
|
@ -341,21 +358,25 @@ case class ListQuery(
|
|||
*/
|
||||
case class Exists(
|
||||
plan: LogicalPlan,
|
||||
children: Seq[Expression] = Seq.empty,
|
||||
exprId: ExprId = NamedExpression.newExprId)
|
||||
extends SubqueryExpression(plan, children, exprId) with Predicate with Unevaluable {
|
||||
outerAttrs: Seq[Expression] = Seq.empty,
|
||||
exprId: ExprId = NamedExpression.newExprId,
|
||||
joinCond: Seq[Expression] = Seq.empty)
|
||||
extends SubqueryExpression(plan, outerAttrs, exprId, joinCond) with Predicate with Unevaluable {
|
||||
override def nullable: Boolean = false
|
||||
override def withNewPlan(plan: LogicalPlan): Exists = copy(plan = plan)
|
||||
override def toString: String = s"exists#${exprId.id} $conditionString"
|
||||
override lazy val canonicalized: Expression = {
|
||||
Exists(
|
||||
plan.canonicalized,
|
||||
children.map(_.canonicalized),
|
||||
ExprId(0))
|
||||
outerAttrs.map(_.canonicalized),
|
||||
ExprId(0),
|
||||
joinCond.map(_.canonicalized))
|
||||
}
|
||||
|
||||
override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Exists =
|
||||
copy(children = newChildren)
|
||||
copy(
|
||||
outerAttrs = newChildren.take(outerAttrs.size),
|
||||
joinCond = newChildren.drop(outerAttrs.size))
|
||||
|
||||
final override def nodePatternsInternal: Seq[TreePattern] = Seq(EXISTS_SUBQUERY)
|
||||
}
|
||||
|
|
|
@ -110,19 +110,19 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
|
|||
|
||||
// Filter the plan by applying left semi and left anti joins.
|
||||
withSubquery.foldLeft(newFilter) {
|
||||
case (p, Exists(sub, conditions, _)) =>
|
||||
case (p, Exists(sub, _, _, conditions)) =>
|
||||
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
|
||||
buildJoin(outerPlan, sub, LeftSemi, joinCond)
|
||||
case (p, Not(Exists(sub, conditions, _))) =>
|
||||
case (p, Not(Exists(sub, _, _, conditions))) =>
|
||||
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
|
||||
buildJoin(outerPlan, sub, LeftAnti, joinCond)
|
||||
case (p, InSubquery(values, ListQuery(sub, conditions, _, _))) =>
|
||||
case (p, InSubquery(values, ListQuery(sub, _, _, _, conditions))) =>
|
||||
// Deduplicate conflicting attributes if any.
|
||||
val newSub = dedupSubqueryOnSelfJoin(p, sub, Some(values))
|
||||
val inConditions = values.zip(newSub.output).map(EqualTo.tupled)
|
||||
val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
|
||||
Join(outerPlan, newSub, LeftSemi, joinCond, JoinHint.NONE)
|
||||
case (p, Not(InSubquery(values, ListQuery(sub, conditions, _, _)))) =>
|
||||
case (p, Not(InSubquery(values, ListQuery(sub, _, _, _, conditions)))) =>
|
||||
// This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
|
||||
// Construct the condition. A NULL in one of the conditions is regarded as a positive
|
||||
// result; such a row will be filtered out by the Anti-Join operator.
|
||||
|
@ -166,12 +166,12 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
|
|||
var newPlan = plan
|
||||
val newExprs = exprs.map { e =>
|
||||
e.transformDownWithPruning(_.containsAnyPattern(EXISTS_SUBQUERY, IN_SUBQUERY)) {
|
||||
case Exists(sub, conditions, _) =>
|
||||
case Exists(sub, _, _, conditions) =>
|
||||
val exists = AttributeReference("exists", BooleanType, nullable = false)()
|
||||
newPlan =
|
||||
buildJoin(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))
|
||||
exists
|
||||
case Not(InSubquery(values, ListQuery(sub, conditions, _, _))) =>
|
||||
case Not(InSubquery(values, ListQuery(sub, _, _, _, conditions))) =>
|
||||
val exists = AttributeReference("exists", BooleanType, nullable = false)()
|
||||
// Deduplicate conflicting attributes if any.
|
||||
val newSub = dedupSubqueryOnSelfJoin(newPlan, sub, Some(values))
|
||||
|
@ -192,7 +192,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
|
|||
val finalJoinCond = (nullAwareJoinConds ++ conditions).reduceLeft(And)
|
||||
newPlan = Join(newPlan, newSub, ExistenceJoin(exists), Some(finalJoinCond), JoinHint.NONE)
|
||||
Not(exists)
|
||||
case InSubquery(values, ListQuery(sub, conditions, _, _)) =>
|
||||
case InSubquery(values, ListQuery(sub, _, _, _, conditions)) =>
|
||||
val exists = AttributeReference("exists", BooleanType, nullable = false)()
|
||||
// Deduplicate conflicting attributes if any.
|
||||
val newSub = dedupSubqueryOnSelfJoin(newPlan, sub, Some(values))
|
||||
|
@ -306,15 +306,15 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
|
|||
|
||||
plan.transformExpressionsWithPruning(_.containsAnyPattern(
|
||||
SCALAR_SUBQUERY, EXISTS_SUBQUERY, LIST_SUBQUERY)) {
|
||||
case ScalarSubquery(sub, children, exprId) if children.nonEmpty =>
|
||||
case ScalarSubquery(sub, children, exprId, conditions) if children.nonEmpty =>
|
||||
val (newPlan, newCond) = decorrelate(sub, outerPlans)
|
||||
ScalarSubquery(newPlan, getJoinCondition(newCond, children), exprId)
|
||||
case Exists(sub, children, exprId) if children.nonEmpty =>
|
||||
ScalarSubquery(newPlan, children, exprId, getJoinCondition(newCond, conditions))
|
||||
case Exists(sub, children, exprId, conditions) if children.nonEmpty =>
|
||||
val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
|
||||
Exists(newPlan, getJoinCondition(newCond, children), exprId)
|
||||
case ListQuery(sub, children, exprId, childOutputs) if children.nonEmpty =>
|
||||
Exists(newPlan, children, exprId, getJoinCondition(newCond, conditions))
|
||||
case ListQuery(sub, children, exprId, childOutputs, conditions) if children.nonEmpty =>
|
||||
val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
|
||||
ListQuery(newPlan, getJoinCondition(newCond, children), exprId, childOutputs)
|
||||
ListQuery(newPlan, children, exprId, childOutputs, getJoinCondition(newCond, conditions))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -524,7 +524,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
|
|||
subqueries: ArrayBuffer[ScalarSubquery]): (LogicalPlan, AttributeMap[Attribute]) = {
|
||||
val subqueryAttrMapping = ArrayBuffer[(Attribute, Attribute)]()
|
||||
val newChild = subqueries.foldLeft(child) {
|
||||
case (currentChild, ScalarSubquery(sub, conditions, _)) =>
|
||||
case (currentChild, ScalarSubquery(sub, _, _, conditions)) =>
|
||||
val query = DecorrelateInnerQuery.rewriteDomainJoins(currentChild, sub, conditions)
|
||||
val origOutput = query.output.head
|
||||
|
||||
|
|
|
@ -118,14 +118,14 @@ case class InsertAdaptiveSparkPlan(
|
|||
return subqueryMap.toMap
|
||||
}
|
||||
plan.foreach(_.expressions.foreach(_.foreach {
|
||||
case expressions.ScalarSubquery(p, _, exprId)
|
||||
case expressions.ScalarSubquery(p, _, exprId, _)
|
||||
if !subqueryMap.contains(exprId.id) =>
|
||||
val executedPlan = compileSubquery(p)
|
||||
verifyAdaptivePlan(executedPlan, p)
|
||||
val subquery = SubqueryExec.createForScalarSubquery(
|
||||
s"subquery#${exprId.id}", executedPlan)
|
||||
subqueryMap.put(exprId.id, subquery)
|
||||
case expressions.InSubquery(_, ListQuery(query, _, exprId, _))
|
||||
case expressions.InSubquery(_, ListQuery(query, _, exprId, _, _))
|
||||
if !subqueryMap.contains(exprId.id) =>
|
||||
val executedPlan = compileSubquery(query)
|
||||
verifyAdaptivePlan(executedPlan, query)
|
||||
|
|
|
@ -31,9 +31,9 @@ case class PlanAdaptiveSubqueries(
|
|||
def apply(plan: SparkPlan): SparkPlan = {
|
||||
plan.transformAllExpressionsWithPruning(
|
||||
_.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY)) {
|
||||
case expressions.ScalarSubquery(_, _, exprId) =>
|
||||
case expressions.ScalarSubquery(_, _, exprId, _) =>
|
||||
execution.ScalarSubquery(subqueryMap(exprId.id), exprId)
|
||||
case expressions.InSubquery(values, ListQuery(_, _, exprId, _)) =>
|
||||
case expressions.InSubquery(values, ListQuery(_, _, exprId, _, _)) =>
|
||||
val expr = if (values.length == 1) {
|
||||
values.head
|
||||
} else {
|
||||
|
|
|
@ -185,7 +185,7 @@ case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] {
|
|||
SubqueryExec.createForScalarSubquery(
|
||||
s"scalar-subquery#${subquery.exprId.id}", executedPlan),
|
||||
subquery.exprId)
|
||||
case expressions.InSubquery(values, ListQuery(query, _, exprId, _)) =>
|
||||
case expressions.InSubquery(values, ListQuery(query, _, exprId, _, _)) =>
|
||||
val expr = if (values.length == 1) {
|
||||
values.head
|
||||
} else {
|
||||
|
|
|
@ -970,7 +970,7 @@ class PlanResolutionSuite extends AnalysisTest {
|
|||
query match {
|
||||
case ListQuery(Project(projects, SubqueryAlias(AliasIdentifier("s", Seq()),
|
||||
UnresolvedSubqueryColumnAliases(outputColumnNames, Project(_, _: OneRowRelation)))),
|
||||
_, _, _) =>
|
||||
_, _, _, _) =>
|
||||
assert(projects.size == 1 && projects.head.name == "s.name")
|
||||
assert(outputColumnNames.size == 1 && outputColumnNames.head == "name")
|
||||
case o => fail("Unexpected subquery: \n" + o.treeString)
|
||||
|
@ -1046,7 +1046,7 @@ class PlanResolutionSuite extends AnalysisTest {
|
|||
query match {
|
||||
case ListQuery(Project(projects, SubqueryAlias(AliasIdentifier("s", Seq()),
|
||||
UnresolvedSubqueryColumnAliases(outputColumnNames, Project(_, _: OneRowRelation)))),
|
||||
_, _, _) =>
|
||||
_, _, _, _) =>
|
||||
assert(projects.size == 1 && projects.head.name == "s.name")
|
||||
assert(outputColumnNames.size == 1 && outputColumnNames.head == "name")
|
||||
case o => fail("Unexpected subquery: \n" + o.treeString)
|
||||
|
|
Loading…
Reference in a new issue