[SPARK-35545][SQL] Split SubqueryExpression's children field into outer attributes and join conditions

### What changes were proposed in this pull request?
This PR refactors `SubqueryExpression` class. It removes the children field from SubqueryExpression's constructor and adds `outerAttrs` and `joinCond`.

### Why are the changes needed?
Currently, the children field of a subquery expression is used to store both collected outer references in the subquery plan and join conditions after correlated predicates are pulled up.

For example:
`SELECT (SELECT max(c1) FROM t1 WHERE t1.c1 = t2.c1) FROM t2`

During the analysis phase, outer references in the subquery are stored in the children field: `scalar-subquery [t2.c1]`, but after the optimizer rule `PullupCorrelatedPredicates`, the children field will be used to store the join conditions, which contain both the inner and the outer references: `scalar-subquery [t1.c1 = t2.c1]`. This is why the references of SubqueryExpression excludes the inner plan's output:
29ed1a2de4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala (L68-L69)

This can be confusing and error-prone. The references for a subquery expression should always be defined as outer attribute references.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing tests.

Closes #32687 from allisonwang-db/refactor-subquery-expr.

Authored-by: allisonwang-db <66282705+allisonwang-db@users.noreply.github.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
allisonwang-db 2021-05-31 04:57:24 +00:00 committed by Wenchen Fan
parent 1a55019b1f
commit 806da9d6fa
9 changed files with 72 additions and 51 deletions

View file

@ -2343,11 +2343,11 @@ class Analyzer(override val catalogManager: CatalogManager)
private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = {
plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY,
EXISTS_SUBQUERY, IN_SUBQUERY), ruleId) {
case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved =>
case s @ ScalarSubquery(sub, _, exprId, _) if !sub.resolved =>
resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId))
case e @ Exists(sub, _, exprId) if !sub.resolved =>
case e @ Exists(sub, _, exprId, _) if !sub.resolved =>
resolveSubQuery(e, plans)(Exists(_, _, exprId))
case InSubquery(values, l @ ListQuery(_, _, exprId, _))
case InSubquery(values, l @ ListQuery(_, _, exprId, _, _))
if values.forall(_.resolved) && !l.resolved =>
val expr = resolveSubQuery(l, plans)((plan, exprs) => {
ListQuery(plan, exprs, exprId, plan.output)

View file

@ -744,17 +744,17 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
checkAnalysis(expr.plan)
expr match {
case ScalarSubquery(query, conditions, _) =>
case ScalarSubquery(query, outerAttrs, _, _) =>
// Scalar subquery must return one column as output.
if (query.output.size != 1) {
failAnalysis(
s"Scalar subquery must return only one column, but got ${query.output.size}")
}
if (conditions.nonEmpty) {
if (outerAttrs.nonEmpty) {
cleanQueryInScalarSubquery(query) match {
case a: Aggregate => checkAggregateInScalarSubquery(conditions, query, a)
case Filter(_, a: Aggregate) => checkAggregateInScalarSubquery(conditions, query, a)
case a: Aggregate => checkAggregateInScalarSubquery(outerAttrs, query, a)
case Filter(_, a: Aggregate) => checkAggregateInScalarSubquery(outerAttrs, query, a)
case p: LogicalPlan if p.maxRows.exists(_ <= 1) => // Ok
case fail => failAnalysis(s"Correlated scalar subqueries must be aggregated: $fail")
}

View file

@ -322,7 +322,7 @@ abstract class TypeCoercionBase {
// Handle type casting required between value expression and subquery output
// in IN subquery.
case i @ InSubquery(lhs, ListQuery(sub, children, exprId, _))
case i @ InSubquery(lhs, ListQuery(sub, children, exprId, _, conditions))
if !i.resolved && lhs.length == sub.output.length =>
// LHS is the value expressions of IN subquery.
// RHS is the subquery output.
@ -345,7 +345,7 @@ abstract class TypeCoercionBase {
}
val newSub = Project(castedRhs, sub)
InSubquery(newLhs, ListQuery(newSub, children, exprId, newSub.output))
InSubquery(newLhs, ListQuery(newSub, children, exprId, newSub.output, conditions))
} else {
i
}

View file

@ -59,14 +59,22 @@ abstract class PlanExpression[T <: QueryPlan[_]] extends Expression {
/**
* A base interface for expressions that contain a [[LogicalPlan]].
*
* @param plan: the subquery plan
* @param outerAttrs: the outer references in the subquery plan
* @param exprId: ID of the expression
* @param joinCond: the join conditions with the outer query. It contains both inner and outer
* query references.
*/
abstract class SubqueryExpression(
plan: LogicalPlan,
children: Seq[Expression],
exprId: ExprId) extends PlanExpression[LogicalPlan] {
outerAttrs: Seq[Expression],
exprId: ExprId,
joinCond: Seq[Expression] = Nil) extends PlanExpression[LogicalPlan] {
override lazy val resolved: Boolean = childrenResolved && plan.resolved
override lazy val references: AttributeSet =
if (plan.resolved) super.references -- plan.outputSet else super.references
AttributeSet.fromAttributeSets(outerAttrs.map(_.references))
override def children: Seq[Expression] = outerAttrs ++ joinCond
override def withNewPlan(plan: LogicalPlan): SubqueryExpression
override def semanticEquals(o: Expression): Boolean = o match {
case p: SubqueryExpression =>
@ -240,9 +248,10 @@ object SubExprUtils extends PredicateHelper {
*/
case class ScalarSubquery(
plan: LogicalPlan,
children: Seq[Expression] = Seq.empty,
exprId: ExprId = NamedExpression.newExprId)
extends SubqueryExpression(plan, children, exprId) with Unevaluable {
outerAttrs: Seq[Expression] = Seq.empty,
exprId: ExprId = NamedExpression.newExprId,
joinCond: Seq[Expression] = Seq.empty)
extends SubqueryExpression(plan, outerAttrs, exprId, joinCond) with Unevaluable {
override def dataType: DataType = {
assert(plan.schema.fields.nonEmpty, "Scalar subquery should have only one column")
plan.schema.fields.head.dataType
@ -253,12 +262,16 @@ case class ScalarSubquery(
override lazy val canonicalized: Expression = {
ScalarSubquery(
plan.canonicalized,
children.map(_.canonicalized),
ExprId(0))
outerAttrs.map(_.canonicalized),
ExprId(0),
joinCond.map(_.canonicalized))
}
override protected def withNewChildrenInternal(
newChildren: IndexedSeq[Expression]): ScalarSubquery = copy(children = newChildren)
newChildren: IndexedSeq[Expression]): ScalarSubquery =
copy(
outerAttrs = newChildren.take(outerAttrs.size),
joinCond = newChildren.drop(outerAttrs.size))
final override def nodePatternsInternal: Seq[TreePattern] = Seq(SCALAR_SUBQUERY)
}
@ -286,10 +299,11 @@ object ScalarSubquery {
*/
case class ListQuery(
plan: LogicalPlan,
children: Seq[Expression] = Seq.empty,
outerAttrs: Seq[Expression] = Seq.empty,
exprId: ExprId = NamedExpression.newExprId,
childOutputs: Seq[Attribute] = Seq.empty)
extends SubqueryExpression(plan, children, exprId) with Unevaluable {
childOutputs: Seq[Attribute] = Seq.empty,
joinCond: Seq[Expression] = Seq.empty)
extends SubqueryExpression(plan, outerAttrs, exprId, joinCond) with Unevaluable {
override def dataType: DataType = if (childOutputs.length > 1) {
childOutputs.toStructType
} else {
@ -302,13 +316,16 @@ case class ListQuery(
override lazy val canonicalized: Expression = {
ListQuery(
plan.canonicalized,
children.map(_.canonicalized),
outerAttrs.map(_.canonicalized),
ExprId(0),
childOutputs.map(_.canonicalized.asInstanceOf[Attribute]))
childOutputs.map(_.canonicalized.asInstanceOf[Attribute]),
joinCond.map(_.canonicalized))
}
override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): ListQuery =
copy(children = newChildren)
copy(
outerAttrs = newChildren.take(outerAttrs.size),
joinCond = newChildren.drop(outerAttrs.size))
final override def nodePatternsInternal: Seq[TreePattern] = Seq(LIST_SUBQUERY)
}
@ -341,21 +358,25 @@ case class ListQuery(
*/
case class Exists(
plan: LogicalPlan,
children: Seq[Expression] = Seq.empty,
exprId: ExprId = NamedExpression.newExprId)
extends SubqueryExpression(plan, children, exprId) with Predicate with Unevaluable {
outerAttrs: Seq[Expression] = Seq.empty,
exprId: ExprId = NamedExpression.newExprId,
joinCond: Seq[Expression] = Seq.empty)
extends SubqueryExpression(plan, outerAttrs, exprId, joinCond) with Predicate with Unevaluable {
override def nullable: Boolean = false
override def withNewPlan(plan: LogicalPlan): Exists = copy(plan = plan)
override def toString: String = s"exists#${exprId.id} $conditionString"
override lazy val canonicalized: Expression = {
Exists(
plan.canonicalized,
children.map(_.canonicalized),
ExprId(0))
outerAttrs.map(_.canonicalized),
ExprId(0),
joinCond.map(_.canonicalized))
}
override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Exists =
copy(children = newChildren)
copy(
outerAttrs = newChildren.take(outerAttrs.size),
joinCond = newChildren.drop(outerAttrs.size))
final override def nodePatternsInternal: Seq[TreePattern] = Seq(EXISTS_SUBQUERY)
}

View file

@ -110,19 +110,19 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
// Filter the plan by applying left semi and left anti joins.
withSubquery.foldLeft(newFilter) {
case (p, Exists(sub, conditions, _)) =>
case (p, Exists(sub, _, _, conditions)) =>
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
buildJoin(outerPlan, sub, LeftSemi, joinCond)
case (p, Not(Exists(sub, conditions, _))) =>
case (p, Not(Exists(sub, _, _, conditions))) =>
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
buildJoin(outerPlan, sub, LeftAnti, joinCond)
case (p, InSubquery(values, ListQuery(sub, conditions, _, _))) =>
case (p, InSubquery(values, ListQuery(sub, _, _, _, conditions))) =>
// Deduplicate conflicting attributes if any.
val newSub = dedupSubqueryOnSelfJoin(p, sub, Some(values))
val inConditions = values.zip(newSub.output).map(EqualTo.tupled)
val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
Join(outerPlan, newSub, LeftSemi, joinCond, JoinHint.NONE)
case (p, Not(InSubquery(values, ListQuery(sub, conditions, _, _)))) =>
case (p, Not(InSubquery(values, ListQuery(sub, _, _, _, conditions)))) =>
// This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
// Construct the condition. A NULL in one of the conditions is regarded as a positive
// result; such a row will be filtered out by the Anti-Join operator.
@ -166,12 +166,12 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
var newPlan = plan
val newExprs = exprs.map { e =>
e.transformDownWithPruning(_.containsAnyPattern(EXISTS_SUBQUERY, IN_SUBQUERY)) {
case Exists(sub, conditions, _) =>
case Exists(sub, _, _, conditions) =>
val exists = AttributeReference("exists", BooleanType, nullable = false)()
newPlan =
buildJoin(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))
exists
case Not(InSubquery(values, ListQuery(sub, conditions, _, _))) =>
case Not(InSubquery(values, ListQuery(sub, _, _, _, conditions))) =>
val exists = AttributeReference("exists", BooleanType, nullable = false)()
// Deduplicate conflicting attributes if any.
val newSub = dedupSubqueryOnSelfJoin(newPlan, sub, Some(values))
@ -192,7 +192,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
val finalJoinCond = (nullAwareJoinConds ++ conditions).reduceLeft(And)
newPlan = Join(newPlan, newSub, ExistenceJoin(exists), Some(finalJoinCond), JoinHint.NONE)
Not(exists)
case InSubquery(values, ListQuery(sub, conditions, _, _)) =>
case InSubquery(values, ListQuery(sub, _, _, _, conditions)) =>
val exists = AttributeReference("exists", BooleanType, nullable = false)()
// Deduplicate conflicting attributes if any.
val newSub = dedupSubqueryOnSelfJoin(newPlan, sub, Some(values))
@ -306,15 +306,15 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
plan.transformExpressionsWithPruning(_.containsAnyPattern(
SCALAR_SUBQUERY, EXISTS_SUBQUERY, LIST_SUBQUERY)) {
case ScalarSubquery(sub, children, exprId) if children.nonEmpty =>
case ScalarSubquery(sub, children, exprId, conditions) if children.nonEmpty =>
val (newPlan, newCond) = decorrelate(sub, outerPlans)
ScalarSubquery(newPlan, getJoinCondition(newCond, children), exprId)
case Exists(sub, children, exprId) if children.nonEmpty =>
ScalarSubquery(newPlan, children, exprId, getJoinCondition(newCond, conditions))
case Exists(sub, children, exprId, conditions) if children.nonEmpty =>
val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
Exists(newPlan, getJoinCondition(newCond, children), exprId)
case ListQuery(sub, children, exprId, childOutputs) if children.nonEmpty =>
Exists(newPlan, children, exprId, getJoinCondition(newCond, conditions))
case ListQuery(sub, children, exprId, childOutputs, conditions) if children.nonEmpty =>
val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
ListQuery(newPlan, getJoinCondition(newCond, children), exprId, childOutputs)
ListQuery(newPlan, children, exprId, childOutputs, getJoinCondition(newCond, conditions))
}
}
@ -524,7 +524,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
subqueries: ArrayBuffer[ScalarSubquery]): (LogicalPlan, AttributeMap[Attribute]) = {
val subqueryAttrMapping = ArrayBuffer[(Attribute, Attribute)]()
val newChild = subqueries.foldLeft(child) {
case (currentChild, ScalarSubquery(sub, conditions, _)) =>
case (currentChild, ScalarSubquery(sub, _, _, conditions)) =>
val query = DecorrelateInnerQuery.rewriteDomainJoins(currentChild, sub, conditions)
val origOutput = query.output.head

View file

@ -118,14 +118,14 @@ case class InsertAdaptiveSparkPlan(
return subqueryMap.toMap
}
plan.foreach(_.expressions.foreach(_.foreach {
case expressions.ScalarSubquery(p, _, exprId)
case expressions.ScalarSubquery(p, _, exprId, _)
if !subqueryMap.contains(exprId.id) =>
val executedPlan = compileSubquery(p)
verifyAdaptivePlan(executedPlan, p)
val subquery = SubqueryExec.createForScalarSubquery(
s"subquery#${exprId.id}", executedPlan)
subqueryMap.put(exprId.id, subquery)
case expressions.InSubquery(_, ListQuery(query, _, exprId, _))
case expressions.InSubquery(_, ListQuery(query, _, exprId, _, _))
if !subqueryMap.contains(exprId.id) =>
val executedPlan = compileSubquery(query)
verifyAdaptivePlan(executedPlan, query)

View file

@ -31,9 +31,9 @@ case class PlanAdaptiveSubqueries(
def apply(plan: SparkPlan): SparkPlan = {
plan.transformAllExpressionsWithPruning(
_.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY)) {
case expressions.ScalarSubquery(_, _, exprId) =>
case expressions.ScalarSubquery(_, _, exprId, _) =>
execution.ScalarSubquery(subqueryMap(exprId.id), exprId)
case expressions.InSubquery(values, ListQuery(_, _, exprId, _)) =>
case expressions.InSubquery(values, ListQuery(_, _, exprId, _, _)) =>
val expr = if (values.length == 1) {
values.head
} else {

View file

@ -185,7 +185,7 @@ case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] {
SubqueryExec.createForScalarSubquery(
s"scalar-subquery#${subquery.exprId.id}", executedPlan),
subquery.exprId)
case expressions.InSubquery(values, ListQuery(query, _, exprId, _)) =>
case expressions.InSubquery(values, ListQuery(query, _, exprId, _, _)) =>
val expr = if (values.length == 1) {
values.head
} else {

View file

@ -970,7 +970,7 @@ class PlanResolutionSuite extends AnalysisTest {
query match {
case ListQuery(Project(projects, SubqueryAlias(AliasIdentifier("s", Seq()),
UnresolvedSubqueryColumnAliases(outputColumnNames, Project(_, _: OneRowRelation)))),
_, _, _) =>
_, _, _, _) =>
assert(projects.size == 1 && projects.head.name == "s.name")
assert(outputColumnNames.size == 1 && outputColumnNames.head == "name")
case o => fail("Unexpected subquery: \n" + o.treeString)
@ -1046,7 +1046,7 @@ class PlanResolutionSuite extends AnalysisTest {
query match {
case ListQuery(Project(projects, SubqueryAlias(AliasIdentifier("s", Seq()),
UnresolvedSubqueryColumnAliases(outputColumnNames, Project(_, _: OneRowRelation)))),
_, _, _) =>
_, _, _, _) =>
assert(projects.size == 1 && projects.head.name == "s.name")
assert(outputColumnNames.size == 1 && outputColumnNames.head == "name")
case o => fail("Unexpected subquery: \n" + o.treeString)