wrapping up pushDownPredicates and start of PushDownLeftSemiAntiJoin.

nicksrules
Nick Brown 2023-07-17 09:44:53 -04:00
parent f50de1abb0
commit b912a41523
Signed by: bicknrown
GPG Key ID: 47AF495B3DCCE9C3
3 changed files with 276 additions and 5 deletions

View File

@ -6,7 +6,7 @@
- [x] PushProjectionThroughLimit,
- [x] ReorderJoin,
- [x] EliminateOuterJoin,
- [ ] PushDownPredicates,
- [x] PushDownPredicates,
- [ ] PushDownLeftSemiAntiJoin,
- [ ] PushLeftSemiLeftAntiThroughJoin,
- [ ] LimitPushDown,

View File

@ -132,7 +132,7 @@ object SparkMethods
def unaryNodeIsDeterministic(u: UnaryNode): Boolean =
u.expressions.forall(_.deterministic)
def canPushThrough(p: UnaryNode): Boolean = p match {
def canPushThroughNonJoin(p: UnaryNode): Boolean = p match {
case _: AppendColumns => true
case _: Distinct => true
case _: Generate => true
@ -148,6 +148,24 @@ object SparkMethods
case _ => false
}
def canPushThroughJoin(joinType: JoinType): Boolean = joinType match {
case _: InnerLike | LeftSemi | RightOuter | LeftOuter | LeftAnti | ExistenceJoin(_) => true
case _ => false
}
def canPushThroughConditionSemiAntiJoin(
plans: LogicalPlan,
condition: Option[Expression],
rightOp: LogicalPlan): Boolean = {
val attributes = AttributeSet(Seq(plans).flatMap(_.output))
if (condition.isDefined) {
val matched = condition.get.references.intersect(rightOp.outputSet).intersect(attributes)
matched.isEmpty
} else {
true
}
}
def canPushThroughCondition(plan: LogicalPlan, condition: Expression): Boolean =
{
val attributes = plan.outputSet
@ -322,4 +340,128 @@ object SparkMethods
u.withNewChildren(Seq(Filter(predicate, u.child)))
}
}
def splitJoinAL(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = {
val (pushDownCandidates, nonDeterministic) = condition.partition(_.deterministic)
val (leftEvaluateCondition, rest) =
pushDownCandidates.partition(_.references.subsetOf(left.outputSet))
val (rightEvaluateCondition, commonCondition) =
rest.partition(expr => expr.references.subsetOf(right.outputSet))
(leftEvaluateCondition, rightEvaluateCondition, commonCondition ++ nonDeterministic)
}
def pushPredicateThroughJoinAL1(
f: Filter, filterCondition: Expression, left: LogicalPlan, right: LogicalPlan, joinType: JoinType, joinCondition: Option[Expression], hint: JoinHint
): LogicalPlan =
{
val (leftFilterConditions, rightFilterConditions, commonFilterCondition) =
splitJoinAL(splitConjunctivePredicates(filterCondition), left, right)
joinType match {
case _: InnerLike =>
// push down the single side `where` condition into respective sides
val newLeft = leftFilterConditions.
reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
val newRight = rightFilterConditions.
reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
val (newJoinConditions, others) =
commonFilterCondition.partition(canEvaluateWithinJoin)
val newJoinCond = (newJoinConditions ++ joinCondition).reduceLeftOption(And)
val join = Join(newLeft, newRight, joinType, newJoinCond, hint)
if (others.nonEmpty) {
Filter(others.reduceLeft(And), join)
} else {
join
}
case RightOuter =>
// push down the right side only `where` condition
val newLeft = left
val newRight = rightFilterConditions.
reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
val newJoinCond = joinCondition
val newJoin = Join(newLeft, newRight, RightOuter, newJoinCond, hint)
(leftFilterConditions ++ commonFilterCondition).
reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin)
case LeftOuter | LeftExistence(_) =>
// push down the left side only `where` condition
val newLeft = leftFilterConditions.
reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
val newRight = right
val newJoinCond = joinCondition
val newJoin = Join(newLeft, newRight, joinType, newJoinCond, hint)
(rightFilterConditions ++ commonFilterCondition).
reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin)
case other =>
throw new IllegalStateException(s"Unexpected join type: $other")
}
}
def pushPredicateThroughJoinAL2(
j: Join, left: LogicalPlan, right: LogicalPlan, joinType: JoinType, joinCondition: Option[Expression], hint: JoinHint
): LogicalPlan =
{
val (leftJoinConditions, rightJoinConditions, commonJoinCondition) =
splitJoinAL(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right)
joinType match {
case _: InnerLike | LeftSemi =>
// push down the single side only join filter for both sides sub queries
val newLeft = leftJoinConditions.
reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
val newRight = rightJoinConditions.
reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
val newJoinCond = commonJoinCondition.reduceLeftOption(And)
Join(newLeft, newRight, joinType, newJoinCond, hint)
case RightOuter =>
// push down the left side only join filter for left side sub query
val newLeft = leftJoinConditions.
reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
val newRight = right
val newJoinCond = (rightJoinConditions ++ commonJoinCondition).reduceLeftOption(And)
Join(newLeft, newRight, RightOuter, newJoinCond, hint)
case LeftOuter | LeftAnti | ExistenceJoin(_) =>
// push down the right side only join filter for right sub query
val newLeft = left
val newRight = rightJoinConditions.
reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
val newJoinCond = (leftJoinConditions ++ commonJoinCondition).reduceLeftOption(And)
Join(newLeft, newRight, joinType, newJoinCond, hint)
case other =>
throw new IllegalStateException(s"Unexpected join type: $other")
}
}
def negatedPListExistsHasCorrelatedScalarSubquery(pList: Seq[NamedExpression]): Boolean =
!pList.exists(ScalarSubquery.hasCorrelatedScalarSubquery)
def pushDownLeftSemiAntiJoin1(
j: Join, p: Project, pList: Seq[NamedExpression], gChild: LogicalPlan, rightOp: LogicalPlan, joinType: JoinType, joinCond: Option[Expression], hint: JoinHint
): LogicalPlan =
{
if (joinCond.isEmpty) {
// No join condition, just push down the Join below Project
p.copy(child = Join(gChild, rightOp, joinType, joinCond, hint))
} else {
val aliasMap = getAliasMap(p)
// Do not push complex join condition
if (aliasMap.forall(_._2.child.children.isEmpty)) {
val newJoinCond = if (aliasMap.nonEmpty) {
Option(replaceAlias(joinCond.get, aliasMap))
} else {
joinCond
}
p.copy(child = Join(gChild, rightOp, joinType, newJoinCond, hint))
} else {
j
}
}
}
}

View File

@ -133,15 +133,25 @@ object Catalyst extends HardcodedDefinition
Type.Array(Type.Native("NamedExpression"))
)
Function("canPushThrough", Type.Bool)(
Function("canPushThroughNonJoin", Type.Bool)(
Type.ASTSubtype("UnaryNode")
)
Function("canPushThroughJoin", Type.Bool)(
Type.Native("JoinType")
)
Function("canPushThroughCondition", Type.Bool)(
Type.AST("LogicalPlan"),
Type.Native("Expression"),
)
Function("canPushThroughConditionSemiAntiJoin", Type.Bool)(
Type.AST("LogicalPlan"),
Type.Option(Type.Native("Expression")),
Type.AST("LogicalPlan"),
)
Function("unaryNodeIsDeterministic", Type.Bool)(
Type.ASTSubtype("UnaryNode")
)
@ -201,6 +211,39 @@ object Catalyst extends HardcodedDefinition
Type.ASTSubtype("UnaryNode"),
)
Function("pushPredicateThroughJoinAL1", Type.AST("LogicalPlan"))(
Type.Node("Filter"),
Type.Native("Expression"),
Type.AST("LogicalPlan"),
Type.AST("LogicalPlan"),
Type.Native("JoinType"),
Type.Option(Type.Native("Expression")),
Type.Native("JoinHint"),
)
Function("pushPredicateThroughJoinAL2", Type.AST("LogicalPlan"))(
Type.Node("Join"),
Type.AST("LogicalPlan"),
Type.AST("LogicalPlan"),
Type.Native("JoinType"),
Type.Option(Type.Native("Expression")),
Type.Native("JoinHint"),
)
Function("pushDownLeftSemiAntiJoin1", Type.AST("LogicalPlan"))(
Type.Node("Join"),
Type.Node("Project"),
Type.Array(Type.Native("NamedExpression")),
Type.AST("LogicalPlan"),
Type.AST("LogicalPlan"),
Type.Native("JoinType"),
Type.Option(Type.Native("Expression")),
Type.Native("JoinHint"),
)
Function("negatedPListExistsHasCorrelatedScalarSubquery", Type.Bool)(
Type.Array(Type.Native("NamedExpression"))
)
Global("JoinHint.NONE", Type.Native("JoinHint"))
Global("RightOuter", Type.Native("JoinType"))
Global("LeftOuter", Type.Native("JoinType"))
@ -613,9 +656,9 @@ object Catalyst extends HardcodedDefinition
Rule("PushDownPredicates-2-6", "LogicalPlan")(
Bind("filter", Match("Filter")(
Bind("unusedCondition"),
Bind("u", OfType(Type.ASTSubtype("UnaryNode"))), // here be dragons.
Bind("u", OfType(Type.ASTSubtype("UnaryNode"))),
)) and Test(
Apply("canPushThrough")(
Apply("canPushThroughNonJoin")(
Ref("u")
) and Apply("unaryNodeIsDeterministic")(
Ref("u")
@ -627,4 +670,90 @@ object Catalyst extends HardcodedDefinition
Ref("u"),
)
)
Rule("PushDownPredicates-3-1", "LogicalPlan")(
Bind("f", Match("Filter")(
Bind("filterCondition"),
Bind("j", Match("Join")(
Bind("left"),
Bind("right"),
Bind("joinType"),
Bind("joinCondition"),
Bind("hint"),
)),
)) and Test(
Apply("canPushThroughJoin")(
Ref("joinType")
)
)
)(
Apply("pushPredicateThroughJoinAL1")(
Ref("f"),
Ref("filterCondition"),
Ref("left"),
Ref("right"),
Ref("joinType"),
Ref("joinCondition"),
Ref("hint"),
)
)
Rule("PushDownPredicates-3-2", "LogicalPlan")(
Bind("j", Match("Join")(
Bind("left"),
Bind("right"),
Bind("joinType"),
Bind("joinCondition"),
Bind("hint"),
)) and Test(
Apply("canPushThroughJoin")(
Ref("joinType")
)
)
)(
Apply("pushPredicateThroughJoinAL2")(
Ref("j"),
Ref("left"),
Ref("right"),
Ref("joinType"),
Ref("joinCondition"),
Ref("hint"),
)
)
Rule("PushDownLeftSemiAntiJoin-1", "LogicalPlan")(
Bind("j", Match("Join")(
Bind("p", Match("Project")(
Bind("pList"),
Bind("gChild"),
)),
Bind("rightOp"),
Bind("joinType"),
Bind("joinCond"),
Bind("hint"),
)) and Test(
Apply("namedExpressionsAreDeterministic")(
Ref("pList")
)) and Test(
Apply("negatedPListExistsHasCorrelatedScalarSubquery")(
Ref("pList")
)) and Test(
Apply ("canPushThroughConditionSemiAntiJoin")(
Ref("gChild"),
Ref("joinCond"),
Ref("rightOp"),
)
)
)(
Apply("pushDownLeftSemiAntiJoin1")(
Ref("j"),
Ref("p"),
Ref("pList"),
Ref("gChild"),
Ref("rightOp"),
Ref("joinType"),
Ref("joinCond"),
Ref("hint"),
)
)
}