more PushDownPredicate rules, and UnaryNode issues.
parent
789b3c83c9
commit
6169f4d614
|
@ -129,6 +129,25 @@ object SparkMethods
|
|||
/** Returns true when no element of `nes` reports itself as non-deterministic. */
def namedExpressionsAreDeterministic(nes: Seq[NamedExpression]): Boolean =
  !nes.exists(ne => !ne.deterministic)
|
||||
|
||||
/** Returns true when none of `u.expressions` reports itself as non-deterministic. */
def unaryNodeIsDeterministic(u: UnaryNode): Boolean =
  !u.expressions.exists(expr => !expr.deterministic)
|
||||
|
||||
/**
 * Reports whether `p` is one of the unary operator types enumerated below.
 * Any operator type that is not listed yields false.
 * Used as a guard by the "PushDownPredicates-2-6" rule.
 */
def canPushThrough(p: UnaryNode): Boolean = p match {
  case _: AppendColumns | _: Distinct | _: Generate | _: Pivot => true
  case _: RepartitionByExpression | _: Repartition | _: RebalancePartitions => true
  case _: ScriptTransformation | _: Sort => true
  case _: BatchEvalPython | _: ArrowEvalPython | _: Expand => true
  case _ => false
}
|
||||
|
||||
def canPushThroughCondition(plan: LogicalPlan, condition: Expression): Boolean =
|
||||
{
|
||||
val attributes = plan.outputSet
|
||||
|
@ -138,6 +157,18 @@ object SparkMethods
|
|||
}
|
||||
}
|
||||
|
||||
/** Returns true when every entry of `aggregate.aggregateExpressions` is deterministic. */
def deterministicAggregateExpressions(aggregate: Aggregate): Boolean =
  !aggregate.aggregateExpressions.exists(expr => !expr.deterministic)
|
||||
|
||||
/** Returns true when `aggregate` has at least one grouping expression. */
def groupingAggregateExpressionsNonEmpty(aggregate: Aggregate): Boolean =
  !aggregate.groupingExpressions.isEmpty
|
||||
|
||||
/** Returns true when every partition-spec expression of `w` is an `AttributeReference`. */
def windowPartitionSpecsAreInstancesOf(w: Window): Boolean =
  w.partitionSpec.forall {
    case _: AttributeReference => true
    case _ => false
  }
|
||||
|
||||
def combineFiltersApplyLocally(fc: Expression, nf: Filter, nc: Expression, grandChild: LogicalPlan): LogicalPlan =
|
||||
{
|
||||
val (combineCandidates, nonDeterministic) =
|
||||
|
@ -152,9 +183,143 @@ object SparkMethods
|
|||
nonDeterministic.reduceOption(And).map(c => Filter(c, mergedFilter)).getOrElse(mergedFilter)
|
||||
}
|
||||
|
||||
def pushPredicateThroughNonJoinApplyLocallyOne(condition: Expression, project: Project, fields: Seq[NamedExpression], grandChild: LogicalPlan ): LogicalPlan =
|
||||
/**
 * Splits `filter.condition` into conjuncts and pushes the eligible ones below an operator.
 *
 * A conjunct is eligible when it is deterministic and all of its references are contained
 * in `grandchild.outputSet`. The eligible conjuncts are AND-ed together and handed to
 * `insertFilter`, which builds the replacement plan. Every other conjunct (including all
 * non-deterministic ones) is kept in a `Filter` on top of that replacement plan.
 *
 * @param filter       the filter whose condition is being split
 * @param grandchild   the plan whose output decides which conjuncts are eligible
 * @param insertFilter builds the new plan given the combined pushed-down predicate
 * @return the rewritten plan, or `filter` itself when no conjunct is eligible
 */
def pushDownPredicate(
    filter: Filter,
    grandchild: LogicalPlan)(insertFilter: Expression => LogicalPlan): LogicalPlan = {
  val (deterministicConjuncts, nonDeterministicConjuncts) =
    splitConjunctivePredicates(filter.condition).partition(_.deterministic)

  val (pushable, blocked) =
    deterministicConjuncts.partition(_.references.subsetOf(grandchild.outputSet))

  val remaining = blocked ++ nonDeterministicConjuncts

  pushable.reduceLeftOption(And) match {
    case None => filter
    case Some(pushed) =>
      val rewrittenChild = insertFilter(pushed)
      remaining.reduceLeftOption(And) match {
        case Some(kept) => Filter(kept, rewrittenChild)
        case None => rewrittenChild
      }
  }
}
|
||||
|
||||
/**
 * Rewrites a filter over a `Project` so that the filter sits directly on `grandChild`.
 *
 * Aliases appearing in `condition` are substituted using the alias map of `project`
 * before the condition is placed in the new `Filter`.
 *
 * @param condition  the filter condition to move below the projection
 * @param project    the projection whose child is replaced
 * @param fields     accepted for interface compatibility; not read by this method
 * @param grandChild the plan the new `Filter` is placed on
 * @return a copy of `project` whose child is `Filter(<alias-expanded condition>, grandChild)`
 */
def pushPredicateThroughNonJoinApplyLocallyOne(
    condition: Expression, project: Project,
    fields: Seq[NamedExpression], grandChild: LogicalPlan): LogicalPlan =
{
  // Compute the alias map once and reuse it (it was previously built twice).
  val aliasMap = getAliasMap(project)
  project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild))
}
|
||||
|
||||
/**
 * Pushes part of a filter condition below an `Aggregate`.
 *
 * The condition is split into conjuncts. A deterministic conjunct is pushed when it has at
 * least one reference and, after alias expansion with the aggregate's alias map, refers
 * only to attributes produced by the aggregate's child. Pushed conjuncts are AND-ed,
 * alias-expanded and placed in a `Filter` under the aggregate; all other conjuncts stay in
 * a `Filter` above it.
 *
 * @param filter    the original filter, returned unchanged when nothing can be pushed
 * @param condition the condition to split
 * @param aggregate the aggregate operator below the filter
 * @return the rewritten plan, or `filter` when no conjunct qualifies
 */
def pushPredicateThroughNonJoinApplyLocallyTwo(
    filter: Filter, condition: Expression, aggregate: Aggregate): LogicalPlan =
{
  val aliases = getAliasMap(aggregate)

  val (deterministicConjuncts, nonDeterministicConjuncts) =
    splitConjunctivePredicates(condition).partition(_.deterministic)

  // Expand aliases first, then test whether the aggregate's child can evaluate the conjunct.
  val (pushable, blocked) = deterministicConjuncts.partition { conjunct =>
    val expanded = replaceAlias(conjunct, aliases)
    conjunct.references.nonEmpty && expanded.references.subsetOf(aggregate.child.outputSet)
  }

  val remaining = blocked ++ nonDeterministicConjuncts

  pushable.reduceOption(And) match {
    case None => filter
    case Some(combined) =>
      val expandedPredicate = replaceAlias(combined, aliases)
      val pushedAggregate = aggregate.copy(child = Filter(expandedPredicate, aggregate.child))
      // Result shape: [Filter(remaining) <-] Aggregate <- Filter(expandedPredicate).
      remaining.reduceOption(And) match {
        case Some(kept) => Filter(kept, pushedAggregate)
        case None => pushedAggregate
      }
  }
}
|
||||
|
||||
/**
 * Pushes part of a filter condition below a `Window`.
 *
 * A deterministic conjunct is pushed when all of its references are among the attributes
 * referenced by the window's partition spec. Pushed conjuncts are AND-ed into a `Filter`
 * under the window; the remaining conjuncts stay in a `Filter` above it.
 *
 * @param filter    the original filter, returned unchanged when nothing can be pushed
 * @param condition the condition to split
 * @param w         the window operator below the filter
 * @return the rewritten plan, or `filter` when no conjunct qualifies
 */
def pushPredicateThroughNonJoinApplyLocallyThree(
    filter: Filter, condition: Expression, w: Window): LogicalPlan =
{
  val partitionColumns = AttributeSet(w.partitionSpec.flatMap(_.references))

  val (deterministicConjuncts, nonDeterministicConjuncts) =
    splitConjunctivePredicates(condition).partition(_.deterministic)

  val (pushable, blocked) =
    deterministicConjuncts.partition(_.references.subsetOf(partitionColumns))

  val remaining = blocked ++ nonDeterministicConjuncts

  pushable.reduceOption(And) match {
    case None => filter
    case Some(pushed) =>
      val pushedWindow = w.copy(child = Filter(pushed, w.child))
      remaining.reduceOption(And) match {
        case Some(kept) => Filter(kept, pushedWindow)
        case None => pushedWindow
      }
  }
}
|
||||
|
||||
/**
 * Pushes the deterministic part of a filter condition into every child of a `Union`.
 *
 * All deterministic conjuncts are AND-ed together. For each child of the union, any
 * sub-expression that is semantically equal to one of the union's output attributes is
 * replaced by the child's output attribute at the same position, and the result is placed
 * in a `Filter` over that child. Non-deterministic conjuncts stay above the new union.
 *
 * @param filter    the original filter, returned unchanged when nothing can be pushed
 * @param condition the condition to split
 * @param union     the union operator below the filter
 * @return the rewritten plan, or `filter` when there is no deterministic conjunct
 */
def pushPredicateThroughNonJoinApplyLocallyFour(
    filter: Filter, condition: Expression, union: Union): LogicalPlan =
{
  val (deterministicConjuncts, remaining) =
    splitConjunctivePredicates(condition).partition(_.deterministic)

  deterministicConjuncts.reduceLeftOption(And) match {
    case None => filter
    case Some(pushed) =>
      val unionOutput = union.output

      // Index of `e` within the union output, or -1 when `e` is not an output attribute.
      def outputPosition(e: Expression): Int = unionOutput.indexWhere(_.semanticEquals(e))

      val filteredBranches = union.children.map { branch =>
        val branchCondition = pushed transform {
          case e if outputPosition(e) >= 0 => branch.output(outputPosition(e))
        }
        // After the rewrite the condition must only mention the branch's own attributes.
        assert(branchCondition.references.subsetOf(branch.outputSet))
        Filter(branchCondition, branch)
      }

      val pushedUnion = union.withNewChildren(filteredBranches)
      remaining.reduceLeftOption(And) match {
        case Some(kept) => Filter(kept, pushedUnion)
        case None => pushedUnion
      }
  }
}
|
||||
|
||||
/**
 * Pushes part of a filter condition below an `EventTimeWatermark`.
 *
 * A conjunct is pushed when it is deterministic and does not reference
 * `watermark.eventTime`. Pushed conjuncts are AND-ed into a `Filter` under the watermark;
 * every other conjunct stays in a `Filter` above it.
 *
 * @param filter    the original filter, returned unchanged when nothing can be pushed
 * @param condition the condition to split
 * @param watermark the watermark operator below the filter
 * @return the rewritten plan, or `filter` when no conjunct qualifies
 */
def pushPredicateThroughNonJoinApplyLocallyFive(
    filter: Filter, condition: Expression, watermark: EventTimeWatermark): LogicalPlan =
{
  val (pushable, remaining) = splitConjunctivePredicates(condition).partition { conjunct =>
    conjunct.deterministic && !conjunct.references.contains(watermark.eventTime)
  }

  pushable.reduceLeftOption(And) match {
    case None => filter
    case Some(pushed) =>
      val pushedWatermark = watermark.copy(child = Filter(pushed, watermark.child))
      // Result shape: [Filter(remaining) <-] watermark <- Filter(pushed).
      remaining.reduceLeftOption(And) match {
        case Some(kept) => Filter(kept, pushedWatermark)
        case None => pushedWatermark
      }
  }
}
|
||||
|
||||
/**
 * Pushes part of `filter` below the unary node `u` by delegating to `pushDownPredicate`.
 *
 * Eligibility of conjuncts is judged against `u.child`; the pushed-down predicate is
 * wrapped in a `Filter` over `u.child`, which then becomes the sole child of `u`.
 *
 * @param filter the filter sitting above `u`
 * @param u      the unary operator to push through
 * @return the plan produced by `pushDownPredicate`
 */
def pushPredicateThroughNonJoinApplyLocallySix(
    filter: Filter, u: UnaryNode): LogicalPlan =
{
  def underneath(pushed: Expression): LogicalPlan =
    u.withNewChildren(Seq(Filter(pushed, u.child)))

  pushDownPredicate(filter, u.child)(underneath)
}
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package com.astraldb.catalyst
|
|||
|
||||
import com.astraldb.spec.HardcodedDefinition
|
||||
import com.astraldb.spec.Type
|
||||
import scala.collection.View.Filter
|
||||
|
||||
object Catalyst extends HardcodedDefinition
|
||||
{
|
||||
|
@ -38,7 +39,18 @@ object Catalyst extends HardcodedDefinition
|
|||
"groupingExpressions" -> Type.Array(Type.Native("Expression")),
|
||||
"aggregateExpressions" -> Type.Array(Type.Native("NamedExpression")),
|
||||
"child" -> Type.AST("LogicalPlan")
|
||||
)
|
||||
),
|
||||
Node("Window")(
|
||||
"windowExpressions" -> Type.Array(Type.Native("NamedExpression")),
|
||||
"partitionSpec" -> Type.Array(Type.Native("Expression")),
|
||||
"orderSpec" -> Type.Array(Type.Native("SortOrder")),
|
||||
"child" -> Type.AST("LogicalPlan")
|
||||
),
|
||||
Node("EventTimeWatermark")(
|
||||
"eventTime" -> Type.Native("Attribute"),
|
||||
"delay" -> Type.Native("CalendarInterval"),
|
||||
"child" -> Type.AST("LogicalPlan"),
|
||||
),
|
||||
)
|
||||
|
||||
//////////////////////////////////////////////////////
|
||||
|
@ -121,11 +133,31 @@ object Catalyst extends HardcodedDefinition
|
|||
Type.Array(Type.Native("NamedExpression"))
|
||||
)
|
||||
|
||||
// Signatures of the Boolean helper predicates defined in SparkMethods.
// Each declaration gives the helper's name, its return type, and its parameter types.

// NOTE(review): this parameter is declared as Type.Node("UnaryNode"), whereas
// unaryNodeIsDeterministic below uses Type.ASTSubtype("UnaryNode") for the same kind of
// argument — confirm which of the two is intended.
Function("canPushThrough", Type.Bool)(Type.Node("UnaryNode"))

Function("canPushThroughCondition", Type.Bool)(
  Type.AST("LogicalPlan"),
  Type.Native("Expression")
)

Function("unaryNodeIsDeterministic", Type.Bool)(Type.ASTSubtype("UnaryNode"))

Function("deterministicAggregateExpressions", Type.Bool)(Type.Node("Aggregate"))

Function("groupingAggregateExpressionsNonEmpty", Type.Bool)(Type.Node("Aggregate"))

Function("windowPartitionSpecsAreInstancesOf", Type.Bool)(Type.Node("Window"))
|
||||
|
||||
Function("combineFiltersApplyLocally", Type.AST("LogicalPlan"))(
|
||||
Type.Native("Expression"),
|
||||
Type.Node("Filter"),
|
||||
|
@ -140,6 +172,35 @@ object Catalyst extends HardcodedDefinition
|
|||
Type.AST("LogicalPlan"),
|
||||
)
|
||||
|
||||
// Signatures of the plan-rewriting helpers used by the PushDownPredicates-2-* rules.
// Every helper returns a LogicalPlan; the first parameter is always the matched Filter.

Function("pushPredicateThroughNonJoinApplyLocallyTwo", Type.AST("LogicalPlan"))(
  Type.Node("Filter"),
  Type.Native("Expression"),
  Type.Node("Aggregate")
)

Function("pushPredicateThroughNonJoinApplyLocallyThree", Type.AST("LogicalPlan"))(
  Type.Node("Filter"),
  Type.Native("Expression"),
  Type.Node("Window")
)

Function("pushPredicateThroughNonJoinApplyLocallyFour", Type.AST("LogicalPlan"))(
  Type.Node("Filter"),
  Type.Native("Expression"),
  Type.Node("Union")
)

Function("pushPredicateThroughNonJoinApplyLocallyFive", Type.AST("LogicalPlan"))(
  Type.Node("Filter"),
  Type.Native("Expression"),
  Type.Node("EventTimeWatermark")
)

// Unlike the helpers above, this one takes no separate condition argument.
Function("pushPredicateThroughNonJoinApplyLocallySix", Type.AST("LogicalPlan"))(
  Type.Node("Filter"),
  Type.ASTSubtype("UnaryNode")
)

// Global constants made available to rule definitions.
Global("JoinHint.NONE", Type.Native("JoinHint"))
Global("RightOuter", Type.Native("JoinType"))
Global("LeftOuter", Type.Native("JoinType"))
|
||||
|
@ -426,6 +487,7 @@ object Catalyst extends HardcodedDefinition
|
|||
)
|
||||
)
|
||||
|
||||
// PushDownPredicates and its many cases.
|
||||
Rule("PushDownPredicates-1", "LogicalPlan")(
|
||||
Match("Filter")(
|
||||
Bind("fc"),
|
||||
|
@ -468,4 +530,101 @@ object Catalyst extends HardcodedDefinition
|
|||
Ref("grandChild"),
|
||||
)
|
||||
)
|
||||
|
||||
// Filter directly over an Aggregate.
// Fires only when the aggregate's aggregate expressions are all deterministic and it has
// at least one grouping expression; the rewrite itself is performed by
// pushPredicateThroughNonJoinApplyLocallyTwo. The aggregate's three fields are bound to
// placeholder names because only the node as a whole is used.
Rule("PushDownPredicates-2-2", "LogicalPlan")(
  Bind("filter", Match("Filter")(
    Bind("condition"),
    Bind("aggregate", Match("Aggregate")(
      Bind("unused1"),
      Bind("unused2"),
      Bind("unused3"),
    )),
  )) and Test(
    Apply("deterministicAggregateExpressions")(Ref("aggregate")) and
    Apply("groupingAggregateExpressionsNonEmpty")(Ref("aggregate"))
  )
)(
  Apply("pushPredicateThroughNonJoinApplyLocallyTwo")(
    Ref("filter"),
    Ref("condition"),
    Ref("aggregate"),
  )
)
|
||||
|
||||
// Filter directly over a Window.
// Fires only when windowPartitionSpecsAreInstancesOf holds for the window; the rewrite is
// performed by pushPredicateThroughNonJoinApplyLocallyThree. The window's four fields are
// bound to placeholder names because only the node as a whole is used.
Rule("PushDownPredicates-2-3", "LogicalPlan")(
  Bind("filter", Match("Filter")(
    Bind("condition"),
    Bind("w", Match("Window")(
      Bind("unused1"),
      Bind("unused2"),
      Bind("unused3"),
      Bind("unused4"),
    )),
  )) and Test(
    Apply("windowPartitionSpecsAreInstancesOf")(Ref("w"))
  )
)(
  Apply("pushPredicateThroughNonJoinApplyLocallyThree")(
    Ref("filter"),
    Ref("condition"),
    Ref("w"),
  )
)
|
||||
|
||||
// Filter directly over a Union.
// No additional test is applied; the rewrite is performed by
// pushPredicateThroughNonJoinApplyLocallyFour. The union's three fields are bound to
// placeholder names because only the node as a whole is used.
Rule("PushDownPredicates-2-4", "LogicalPlan")(
  Bind("filter", Match("Filter")(
    Bind("condition"),
    Bind("union", Match("Union")(
      Bind("unused1"),
      Bind("unused2"),
      Bind("unused3"),
    )),
  ))
)(
  Apply("pushPredicateThroughNonJoinApplyLocallyFour")(
    Ref("filter"),
    Ref("condition"),
    Ref("union"),
  )
)
|
||||
|
||||
// Filter directly over an EventTimeWatermark.
// No additional test is applied; the rewrite is performed by
// pushPredicateThroughNonJoinApplyLocallyFive. The watermark's three fields are bound to
// placeholder names because only the node as a whole is used.
Rule("PushDownPredicates-2-5", "LogicalPlan")(
  Bind("filter", Match("Filter")(
    Bind("condition"),
    Bind("watermark", Match("EventTimeWatermark")(
      Bind("unused1"),
      Bind("unused2"),
      Bind("unused3"),
    )),
  ))
)(
  Apply("pushPredicateThroughNonJoinApplyLocallyFive")(
    Ref("filter"),
    Ref("condition"),
    Ref("watermark"),
  )
)
|
||||
|
||||
// Filter directly over a UnaryNode.
// Fires only when canPushThrough and unaryNodeIsDeterministic both hold for the node; the
// rewrite is performed by pushPredicateThroughNonJoinApplyLocallySix, which receives the
// filter and the node but not the condition (hence the "unusedCondition" binding).
Rule("PushDownPredicates-2-6", "LogicalPlan")(
  Bind("filter", Match("Filter")(
    Bind("unusedCondition"),
    // here be dragons: "UnaryNode" is matched without any field patterns, unlike the
    // other Match(...) uses in these rules.
    Bind("u", Match("UnaryNode")),
  )) and Test(
    Apply("canPushThrough")(Ref("u")) and
    Apply("unaryNodeIsDeterministic")(Ref("u"))
  )
)(
  Apply("pushPredicateThroughNonJoinApplyLocallySix")(
    Ref("filter"),
    Ref("u"),
  )
)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue