more PushDownPredicate rules, and UnaryNode issues.

nicksrules
Nick Brown 2023-07-13 14:36:15 -04:00
parent 789b3c83c9
commit 6169f4d614
Signed by: bicknrown
GPG Key ID: 47AF495B3DCCE9C3
2 changed files with 326 additions and 2 deletions

View File

@ -129,6 +129,25 @@ object SparkMethods
def namedExpressionsAreDeterministic(nes: Seq[NamedExpression]): Boolean =
nes.forall(_.deterministic)
def unaryNodeIsDeterministic(u: UnaryNode): Boolean =
u.expressions.forall(_.deterministic)
def canPushThrough(p: UnaryNode): Boolean = p match {
case _: AppendColumns => true
case _: Distinct => true
case _: Generate => true
case _: Pivot => true
case _: RepartitionByExpression => true
case _: Repartition => true
case _: RebalancePartitions => true
case _: ScriptTransformation => true
case _: Sort => true
case _: BatchEvalPython => true
case _: ArrowEvalPython => true
case _: Expand => true
case _ => false
}
def canPushThroughCondition(plan: LogicalPlan, condition: Expression): Boolean =
{
val attributes = plan.outputSet
@ -138,6 +157,18 @@ object SparkMethods
}
}
def deterministicAggregateExpressions(aggregate: Aggregate): Boolean = {
aggregate.aggregateExpressions.forall(_.deterministic)
}
def groupingAggregateExpressionsNonEmpty(aggregate: Aggregate): Boolean = {
aggregate.groupingExpressions.nonEmpty
}
def windowPartitionSpecsAreInstancesOf(w: Window): Boolean = {
w.partitionSpec.forall(_.isInstanceOf[AttributeReference])
}
def combineFiltersApplyLocally(fc: Expression, nf: Filter, nc: Expression, grandChild: LogicalPlan): LogicalPlan =
{
val (combineCandidates, nonDeterministic) =
@ -152,9 +183,143 @@ object SparkMethods
nonDeterministic.reduceOption(And).map(c => Filter(c, mergedFilter)).getOrElse(mergedFilter)
}
def pushPredicateThroughNonJoinApplyLocallyOne(condition: Expression, project: Project, fields: Seq[NamedExpression], grandChild: LogicalPlan ): LogicalPlan =
def pushDownPredicate(
filter: Filter,
grandchild: LogicalPlan)(insertFilter: Expression => LogicalPlan): LogicalPlan = {
val (candidates, nonDeterministic) =
splitConjunctivePredicates(filter.condition).partition(_.deterministic)
val (pushDown, rest) = candidates.partition { cond =>
cond.references.subsetOf(grandchild.outputSet)
}
val stayUp = rest ++ nonDeterministic
if (pushDown.nonEmpty) {
val newChild = insertFilter(pushDown.reduceLeft(And))
if (stayUp.nonEmpty) {
Filter(stayUp.reduceLeft(And), newChild)
} else {
newChild
}
} else {
filter
}
}
def pushPredicateThroughNonJoinApplyLocallyOne(
condition: Expression, project: Project,
fields: Seq[NamedExpression], grandChild: LogicalPlan ): LogicalPlan =
{
val aliasMap = getAliasMap(project)
project.copy(child = Filter(replaceAlias(condition, getAliasMap(project)), grandChild))
}
def pushPredicateThroughNonJoinApplyLocallyTwo(
filter: Filter, condition: Expression, aggregate: Aggregate) : LogicalPlan =
{
val aliasMap = getAliasMap(aggregate)
// For each filter, expand the alias and check if the filter can be evaluated using
// attributes produced by the aggregate operator's child operator.
val (candidates, nonDeterministic) =
splitConjunctivePredicates(condition).partition(_.deterministic)
val (pushDown, rest) = candidates.partition { cond =>
val replaced = replaceAlias(cond, aliasMap)
cond.references.nonEmpty && replaced.references.subsetOf(aggregate.child.outputSet)
}
val stayUp = rest ++ nonDeterministic
if (pushDown.nonEmpty) {
val pushDownPredicate = pushDown.reduce(And)
val replaced = replaceAlias(pushDownPredicate, aliasMap)
val newAggregate = aggregate.copy(child = Filter(replaced, aggregate.child))
// If there is no more filter to stay up, just eliminate the filter.
// Otherwise, create "Filter(stayUp) <- Aggregate <- Filter(pushDownPredicate)".
if (stayUp.isEmpty) newAggregate else Filter(stayUp.reduce(And), newAggregate)
} else {
filter
}
}
def pushPredicateThroughNonJoinApplyLocallyThree(
filter: Filter, condition: Expression, w: Window): LogicalPlan =
{
val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references))
val (candidates, nonDeterministic) =
splitConjunctivePredicates(condition).partition(_.deterministic)
val (pushDown, rest) = candidates.partition { cond =>
cond.references.subsetOf(partitionAttrs)
}
val stayUp = rest ++ nonDeterministic
if (pushDown.nonEmpty) {
val pushDownPredicate = pushDown.reduce(And)
val newWindow = w.copy(child = Filter(pushDownPredicate, w.child))
if (stayUp.isEmpty) newWindow else Filter(stayUp.reduce(And), newWindow)
} else {
filter
}
}
def pushPredicateThroughNonJoinApplyLocallyFour(
filter: Filter, condition: Expression, union: Union
): LogicalPlan =
{
val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition(_.deterministic)
if (pushDown.nonEmpty) {
val pushDownCond = pushDown.reduceLeft(And)
val output = union.output
val newGrandChildren = union.children.map { grandchild =>
val newCond = pushDownCond transform {
case e if output.exists(_.semanticEquals(e)) =>
grandchild.output(output.indexWhere(_.semanticEquals(e)))
}
assert(newCond.references.subsetOf(grandchild.outputSet))
Filter(newCond, grandchild)
}
val newUnion = union.withNewChildren(newGrandChildren)
if (stayUp.nonEmpty) {
Filter(stayUp.reduceLeft(And), newUnion)
} else {
newUnion
}
} else {
filter
}
}
def pushPredicateThroughNonJoinApplyLocallyFive(
filter: Filter, condition: Expression, watermark: EventTimeWatermark
): LogicalPlan =
{
val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { p =>
p.deterministic && !p.references.contains(watermark.eventTime)
}
if (pushDown.nonEmpty) {
val pushDownPredicate = pushDown.reduceLeft(And)
val newWatermark = watermark.copy(child = Filter(pushDownPredicate, watermark.child))
// If there is no more filter to stay up, just eliminate the filter.
// Otherwise, create "Filter(stayUp) <- watermark <- Filter(pushDownPredicate)".
if (stayUp.isEmpty) newWatermark else Filter(stayUp.reduceLeft(And), newWatermark)
} else {
filter
}
}
def pushPredicateThroughNonJoinApplyLocallySix(
filter: Filter, u: UnaryNode
): LogicalPlan =
{
pushDownPredicate(filter, u.child) { predicate =>
u.withNewChildren(Seq(Filter(predicate, u.child)))
}
}
}

View File

@ -2,6 +2,7 @@ package com.astraldb.catalyst
import com.astraldb.spec.HardcodedDefinition
import com.astraldb.spec.Type
import scala.collection.View.Filter
object Catalyst extends HardcodedDefinition
{
@ -38,7 +39,18 @@ object Catalyst extends HardcodedDefinition
"groupingExpressions" -> Type.Array(Type.Native("Expression")),
"aggregateExpressions" -> Type.Array(Type.Native("NamedExpression")),
"child" -> Type.AST("LogicalPlan")
)
),
Node("Window")(
"windowExpressions" -> Type.Array(Type.Native("NamedExpression")),
"partitionSpec" -> Type.Array(Type.Native("Expression")),
"orderSpec" -> Type.Array(Type.Native("SortOrder")),
"child" -> Type.AST("LogicalPlan")
),
Node("EventTimeWatermark")(
"eventTime" -> Type.Native("Attribute"),
"delay" -> Type.Native("CalendarInterval"),
"child" -> Type.AST("LogicalPlan"),
),
)
//////////////////////////////////////////////////////
@ -121,11 +133,31 @@ object Catalyst extends HardcodedDefinition
Type.Array(Type.Native("NamedExpression"))
)
Function("canPushThrough", Type.Bool)(
Type.Node("UnaryNode")
)
Function("canPushThroughCondition", Type.Bool)(
Type.AST("LogicalPlan"),
Type.Native("Expression"),
)
Function("unaryNodeIsDeterministic", Type.Bool)(
Type.ASTSubtype("UnaryNode")
)
Function("deterministicAggregateExpressions", Type.Bool)(
Type.Node("Aggregate")
)
Function("groupingAggregateExpressionsNonEmpty", Type.Bool)(
Type.Node("Aggregate")
)
Function("windowPartitionSpecsAreInstancesOf", Type.Bool)(
Type.Node("Window")
)
Function("combineFiltersApplyLocally", Type.AST("LogicalPlan"))(
Type.Native("Expression"),
Type.Node("Filter"),
@ -140,6 +172,35 @@ object Catalyst extends HardcodedDefinition
Type.AST("LogicalPlan"),
)
Function("pushPredicateThroughNonJoinApplyLocallyTwo", Type.AST("LogicalPlan"))(
Type.Node("Filter"),
Type.Native("Expression"),
Type.Node("Aggregate"),
)
Function("pushPredicateThroughNonJoinApplyLocallyThree", Type.AST("LogicalPlan"))(
Type.Node("Filter"),
Type.Native("Expression"),
Type.Node("Window"),
)
Function("pushPredicateThroughNonJoinApplyLocallyFour", Type.AST("LogicalPlan"))(
Type.Node("Filter"),
Type.Native("Expression"),
Type.Node("Union"),
)
Function("pushPredicateThroughNonJoinApplyLocallyFive", Type.AST("LogicalPlan"))(
Type.Node("Filter"),
Type.Native("Expression"),
Type.Node("EventTimeWatermark"),
)
Function("pushPredicateThroughNonJoinApplyLocallySix", Type.AST("LogicalPlan"))(
Type.Node("Filter"),
Type.ASTSubtype("UnaryNode"),
)
Global("JoinHint.NONE", Type.Native("JoinHint"))
Global("RightOuter", Type.Native("JoinType"))
Global("LeftOuter", Type.Native("JoinType"))
@ -426,6 +487,7 @@ object Catalyst extends HardcodedDefinition
)
)
// PushDownPredicates and its many cases.
Rule("PushDownPredicates-1", "LogicalPlan")(
Match("Filter")(
Bind("fc"),
@ -468,4 +530,101 @@ object Catalyst extends HardcodedDefinition
Ref("grandChild"),
)
)
Rule("PushDownPredicates-2-2", "LogicalPlan")(
Bind("filter", Match("Filter")(
Bind("condition"),
Bind("aggregate", Match("Aggregate")(
Bind("unused1"),
Bind("unused2"),
Bind("unused3"),
)),
)) and Test(
Apply("deterministicAggregateExpressions")(
Ref("aggregate")
) and
Apply("groupingAggregateExpressionsNonEmpty")(
Ref("aggregate")
)
)
)(
Apply("pushPredicateThroughNonJoinApplyLocallyTwo")(
Ref("filter"),
Ref("condition"),
Ref("aggregate"),
)
)
Rule("PushDownPredicates-2-3", "LogicalPlan")(
Bind("filter", Match("Filter")(
Bind("condition"),
Bind("w", Match("Window")(
Bind("unused1"),
Bind("unused2"),
Bind("unused3"),
Bind("unused4"),
)),
)) and Test(
Apply("windowPartitionSpecsAreInstancesOf")(
Ref("w")
))
)(
Apply("pushPredicateThroughNonJoinApplyLocallyThree")(
Ref("filter"),
Ref("condition"),
Ref("w"),
)
)
Rule("PushDownPredicates-2-4", "LogicalPlan")(
Bind("filter", Match("Filter")(
Bind("condition"),
Bind("union", Match ("Union")(
Bind("unused1"),
Bind("unused2"),
Bind("unused3"),
)),
))
)(
Apply("pushPredicateThroughNonJoinApplyLocallyFour")(
Ref("filter"),
Ref("condition"),
Ref("union"),
)
)
Rule("PushDownPredicates-2-5", "LogicalPlan")(
Bind("filter", Match("Filter")(
Bind("condition"),
Bind("watermark", Match("EventTimeWatermark")(
Bind("unused1"),
Bind("unused2"),
Bind("unused3"),
)),
))
)(
Apply("pushPredicateThroughNonJoinApplyLocallyFive")(
Ref("filter"),
Ref("condition"),
Ref("watermark"),
)
)
Rule("PushDownPredicates-2-6", "LogicalPlan")(
Bind("filter", Match("Filter")(
Bind("unusedCondition"),
Bind("u", Match("UnaryNode")), // here be dragons.
)) and Test(
Apply("canPushThrough")(
Ref("u")
) and Apply("unaryNodeIsDeterministic")(
Ref("u")
)
)
)(
Apply("pushPredicateThroughNonJoinApplyLocallySix")(
Ref("filter"),
Ref("u"),
)
)
}