Compare commits

...

3 Commits

7 changed files with 411 additions and 30 deletions

View File

@ -129,6 +129,25 @@ object SparkMethods
def namedExpressionsAreDeterministic(nes: Seq[NamedExpression]): Boolean =
nes.forall(_.deterministic)
def unaryNodeIsDeterministic(u: UnaryNode): Boolean =
u.expressions.forall(_.deterministic)
def canPushThrough(p: UnaryNode): Boolean = p match {
case _: AppendColumns => true
case _: Distinct => true
case _: Generate => true
case _: Pivot => true
case _: RepartitionByExpression => true
case _: Repartition => true
case _: RebalancePartitions => true
case _: ScriptTransformation => true
case _: Sort => true
case _: BatchEvalPython => true
case _: ArrowEvalPython => true
case _: Expand => true
case _ => false
}
def canPushThroughCondition(plan: LogicalPlan, condition: Expression): Boolean =
{
val attributes = plan.outputSet
@ -138,6 +157,18 @@ object SparkMethods
}
}
def deterministicAggregateExpressions(aggregate: Aggregate): Boolean = {
aggregate.aggregateExpressions.forall(_.deterministic)
}
def groupingAggregateExpressionsNonEmpty(aggregate: Aggregate): Boolean = {
aggregate.groupingExpressions.nonEmpty
}
def windowPartitionSpecsAreInstancesOf(w: Window): Boolean = {
w.partitionSpec.forall(_.isInstanceOf[AttributeReference])
}
def combineFiltersApplyLocally(fc: Expression, nf: Filter, nc: Expression, grandChild: LogicalPlan): LogicalPlan =
{
val (combineCandidates, nonDeterministic) =
@ -152,9 +183,143 @@ object SparkMethods
nonDeterministic.reduceOption(And).map(c => Filter(c, mergedFilter)).getOrElse(mergedFilter)
}
def pushPredicateThroughNonJoinApplyLocallyOne(condition: Expression, project: Project, fields: Seq[NamedExpression], grandChild: LogicalPlan ): LogicalPlan =
def pushDownPredicate(
filter: Filter,
grandchild: LogicalPlan)(insertFilter: Expression => LogicalPlan): LogicalPlan = {
val (candidates, nonDeterministic) =
splitConjunctivePredicates(filter.condition).partition(_.deterministic)
val (pushDown, rest) = candidates.partition { cond =>
cond.references.subsetOf(grandchild.outputSet)
}
val stayUp = rest ++ nonDeterministic
if (pushDown.nonEmpty) {
val newChild = insertFilter(pushDown.reduceLeft(And))
if (stayUp.nonEmpty) {
Filter(stayUp.reduceLeft(And), newChild)
} else {
newChild
}
} else {
filter
}
}
def pushPredicateThroughNonJoinApplyLocallyOne(
condition: Expression, project: Project,
fields: Seq[NamedExpression], grandChild: LogicalPlan ): LogicalPlan =
{
val aliasMap = getAliasMap(project)
project.copy(child = Filter(replaceAlias(condition, getAliasMap(project)), grandChild))
}
def pushPredicateThroughNonJoinApplyLocallyTwo(
filter: Filter, condition: Expression, aggregate: Aggregate) : LogicalPlan =
{
val aliasMap = getAliasMap(aggregate)
// For each filter, expand the alias and check if the filter can be evaluated using
// attributes produced by the aggregate operator's child operator.
val (candidates, nonDeterministic) =
splitConjunctivePredicates(condition).partition(_.deterministic)
val (pushDown, rest) = candidates.partition { cond =>
val replaced = replaceAlias(cond, aliasMap)
cond.references.nonEmpty && replaced.references.subsetOf(aggregate.child.outputSet)
}
val stayUp = rest ++ nonDeterministic
if (pushDown.nonEmpty) {
val pushDownPredicate = pushDown.reduce(And)
val replaced = replaceAlias(pushDownPredicate, aliasMap)
val newAggregate = aggregate.copy(child = Filter(replaced, aggregate.child))
// If there is no more filter to stay up, just eliminate the filter.
// Otherwise, create "Filter(stayUp) <- Aggregate <- Filter(pushDownPredicate)".
if (stayUp.isEmpty) newAggregate else Filter(stayUp.reduce(And), newAggregate)
} else {
filter
}
}
def pushPredicateThroughNonJoinApplyLocallyThree(
filter: Filter, condition: Expression, w: Window): LogicalPlan =
{
val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references))
val (candidates, nonDeterministic) =
splitConjunctivePredicates(condition).partition(_.deterministic)
val (pushDown, rest) = candidates.partition { cond =>
cond.references.subsetOf(partitionAttrs)
}
val stayUp = rest ++ nonDeterministic
if (pushDown.nonEmpty) {
val pushDownPredicate = pushDown.reduce(And)
val newWindow = w.copy(child = Filter(pushDownPredicate, w.child))
if (stayUp.isEmpty) newWindow else Filter(stayUp.reduce(And), newWindow)
} else {
filter
}
}
def pushPredicateThroughNonJoinApplyLocallyFour(
filter: Filter, condition: Expression, union: Union
): LogicalPlan =
{
val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition(_.deterministic)
if (pushDown.nonEmpty) {
val pushDownCond = pushDown.reduceLeft(And)
val output = union.output
val newGrandChildren = union.children.map { grandchild =>
val newCond = pushDownCond transform {
case e if output.exists(_.semanticEquals(e)) =>
grandchild.output(output.indexWhere(_.semanticEquals(e)))
}
assert(newCond.references.subsetOf(grandchild.outputSet))
Filter(newCond, grandchild)
}
val newUnion = union.withNewChildren(newGrandChildren)
if (stayUp.nonEmpty) {
Filter(stayUp.reduceLeft(And), newUnion)
} else {
newUnion
}
} else {
filter
}
}
def pushPredicateThroughNonJoinApplyLocallyFive(
filter: Filter, condition: Expression, watermark: EventTimeWatermark
): LogicalPlan =
{
val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { p =>
p.deterministic && !p.references.contains(watermark.eventTime)
}
if (pushDown.nonEmpty) {
val pushDownPredicate = pushDown.reduceLeft(And)
val newWatermark = watermark.copy(child = Filter(pushDownPredicate, watermark.child))
// If there is no more filter to stay up, just eliminate the filter.
// Otherwise, create "Filter(stayUp) <- watermark <- Filter(pushDownPredicate)".
if (stayUp.isEmpty) newWatermark else Filter(stayUp.reduceLeft(And), newWatermark)
} else {
filter
}
}
def pushPredicateThroughNonJoinApplyLocallySix(
filter: Filter, u: UnaryNode
): LogicalPlan =
{
pushDownPredicate(filter, u.child) { predicate =>
u.withNewChildren(Seq(Filter(predicate, u.child)))
}
}
}

View File

@ -2,6 +2,7 @@ package com.astraldb.catalyst
import com.astraldb.spec.HardcodedDefinition
import com.astraldb.spec.Type
import scala.collection.View.Filter
object Catalyst extends HardcodedDefinition
{
@ -38,7 +39,18 @@ object Catalyst extends HardcodedDefinition
"groupingExpressions" -> Type.Array(Type.Native("Expression")),
"aggregateExpressions" -> Type.Array(Type.Native("NamedExpression")),
"child" -> Type.AST("LogicalPlan")
)
),
Node("Window")(
"windowExpressions" -> Type.Array(Type.Native("NamedExpression")),
"partitionSpec" -> Type.Array(Type.Native("Expression")),
"orderSpec" -> Type.Array(Type.Native("SortOrder")),
"child" -> Type.AST("LogicalPlan")
),
Node("EventTimeWatermark")(
"eventTime" -> Type.Native("Attribute"),
"delay" -> Type.Native("CalendarInterval"),
"child" -> Type.AST("LogicalPlan"),
),
)
//////////////////////////////////////////////////////
@ -121,11 +133,31 @@ object Catalyst extends HardcodedDefinition
Type.Array(Type.Native("NamedExpression"))
)
Function("canPushThrough", Type.Bool)(
Type.Node("UnaryNode")
)
Function("canPushThroughCondition", Type.Bool)(
Type.AST("LogicalPlan"),
Type.Native("Expression"),
)
Function("unaryNodeIsDeterministic", Type.Bool)(
Type.ASTSubtype("UnaryNode")
)
Function("deterministicAggregateExpressions", Type.Bool)(
Type.Node("Aggregate")
)
Function("groupingAggregateExpressionsNonEmpty", Type.Bool)(
Type.Node("Aggregate")
)
Function("windowPartitionSpecsAreInstancesOf", Type.Bool)(
Type.Node("Window")
)
Function("combineFiltersApplyLocally", Type.AST("LogicalPlan"))(
Type.Native("Expression"),
Type.Node("Filter"),
@ -140,6 +172,35 @@ object Catalyst extends HardcodedDefinition
Type.AST("LogicalPlan"),
)
Function("pushPredicateThroughNonJoinApplyLocallyTwo", Type.AST("LogicalPlan"))(
Type.Node("Filter"),
Type.Native("Expression"),
Type.Node("Aggregate"),
)
Function("pushPredicateThroughNonJoinApplyLocallyThree", Type.AST("LogicalPlan"))(
Type.Node("Filter"),
Type.Native("Expression"),
Type.Node("Window"),
)
Function("pushPredicateThroughNonJoinApplyLocallyFour", Type.AST("LogicalPlan"))(
Type.Node("Filter"),
Type.Native("Expression"),
Type.Node("Union"),
)
Function("pushPredicateThroughNonJoinApplyLocallyFive", Type.AST("LogicalPlan"))(
Type.Node("Filter"),
Type.Native("Expression"),
Type.Node("EventTimeWatermark"),
)
Function("pushPredicateThroughNonJoinApplyLocallySix", Type.AST("LogicalPlan"))(
Type.Node("Filter"),
Type.ASTSubtype("UnaryNode"),
)
Global("JoinHint.NONE", Type.Native("JoinHint"))
Global("RightOuter", Type.Native("JoinType"))
Global("LeftOuter", Type.Native("JoinType"))
@ -426,6 +487,7 @@ object Catalyst extends HardcodedDefinition
)
)
// PushDownPredicates and its many cases.
Rule("PushDownPredicates-1", "LogicalPlan")(
Match("Filter")(
Bind("fc"),
@ -468,4 +530,101 @@ object Catalyst extends HardcodedDefinition
Ref("grandChild"),
)
)
Rule("PushDownPredicates-2-2", "LogicalPlan")(
Bind("filter", Match("Filter")(
Bind("condition"),
Bind("aggregate", Match("Aggregate")(
Bind("unused1"),
Bind("unused2"),
Bind("unused3"),
)),
)) and Test(
Apply("deterministicAggregateExpressions")(
Ref("aggregate")
) and
Apply("groupingAggregateExpressionsNonEmpty")(
Ref("aggregate")
)
)
)(
Apply("pushPredicateThroughNonJoinApplyLocallyTwo")(
Ref("filter"),
Ref("condition"),
Ref("aggregate"),
)
)
Rule("PushDownPredicates-2-3", "LogicalPlan")(
Bind("filter", Match("Filter")(
Bind("condition"),
Bind("w", Match("Window")(
Bind("unused1"),
Bind("unused2"),
Bind("unused3"),
Bind("unused4"),
)),
)) and Test(
Apply("windowPartitionSpecsAreInstancesOf")(
Ref("w")
))
)(
Apply("pushPredicateThroughNonJoinApplyLocallyThree")(
Ref("filter"),
Ref("condition"),
Ref("w"),
)
)
Rule("PushDownPredicates-2-4", "LogicalPlan")(
Bind("filter", Match("Filter")(
Bind("condition"),
Bind("union", Match ("Union")(
Bind("unused1"),
Bind("unused2"),
Bind("unused3"),
)),
))
)(
Apply("pushPredicateThroughNonJoinApplyLocallyFour")(
Ref("filter"),
Ref("condition"),
Ref("union"),
)
)
Rule("PushDownPredicates-2-5", "LogicalPlan")(
Bind("filter", Match("Filter")(
Bind("condition"),
Bind("watermark", Match("EventTimeWatermark")(
Bind("unused1"),
Bind("unused2"),
Bind("unused3"),
)),
))
)(
Apply("pushPredicateThroughNonJoinApplyLocallyFive")(
Ref("filter"),
Ref("condition"),
Ref("watermark"),
)
)
Rule("PushDownPredicates-2-6", "LogicalPlan")(
Bind("filter", Match("Filter")(
Bind("unusedCondition"),
Bind("u", Match("UnaryNode")), // here be dragons.
)) and Test(
Apply("canPushThrough")(
Ref("u")
) and Apply("unaryNodeIsDeterministic")(
Ref("u")
)
)
)(
Apply("pushPredicateThroughNonJoinApplyLocallySix")(
Ref("filter"),
Ref("u"),
)
)
}

View File

@ -0,0 +1,15 @@
package com.astraldb.spec
case class ASTDefinition(
family: Type.AST,
nodes: Set[Node]
)
{
val subtypes:Map[Type.ASTSubtype, Set[Node]] =
nodes.flatMap { node =>
node.supertypes.map { st =>
st -> node
}
}
.groupMap { _._1 } { _._2 }
}

View File

@ -4,10 +4,14 @@ import scala.collection.mutable
import com.astraldb.expression._
case class Definition(
nodes:Map[String, Seq[Node]],
asts: Map[String, ASTDefinition],
rules:Seq[Rule],
globals: Map[String, Type],
) {
def nodes =
asts.mapValues { _.nodes }
override def toString =
s"""/////// ASTs //////
|${nodes.map { case (family, nodeTypes) => s"Ast(${family})(\n ${nodeTypes.mkString(",\n ")}\n)" }.mkString("\n\n")}
@ -15,6 +19,7 @@ case class Definition(
|/////// Rules /////
|${rules.mkString("\n\n")}
""".stripMargin
val familyOfNode: Map[String, String] =
nodes.flatMap { case (family, elements) => elements.map { _.name -> family } }
.toMap
@ -32,7 +37,12 @@ class HardcodedDefinition
{
lazy val definition: Definition =
Definition(
nodes = nodes.mapValues { _.toSeq }.toMap,
asts = nodes.map { case (f, n) =>
val family = Type.AST(f)
f -> ASTDefinition(
family = family,
nodes = n.map { _.copy(family = family) }.toSet
)}.toMap,
rules = rules.toSeq,
globals = globals.toMap
)
@ -43,13 +53,12 @@ class HardcodedDefinition
import FieldConversions._
def Ast(label: String)(newNodes: Node*): Unit =
def Ast(label: String)(newNodes: => Node*): Unit =
nodes.getOrElseUpdate(label, mutable.Buffer.empty)
.appendAll(newNodes)
def Node(label: String)(fields: (String, Type)*): Node =
com.astraldb.spec.Node(label, fields)
com.astraldb.spec.Node(label, fields, supertypes = Set.empty, family = null)
def Rule(label: String, family: String)(pattern: Match)(replacement: Expression): Unit =
rules.append(

View File

@ -1,6 +1,11 @@
package com.astraldb.spec;
case class Node(val name:String, val fields:Seq[Field])
case class Node(
val name: String,
val fields: Seq[Field],
val family: Type.AST,
val supertypes: Set[Type.ASTSubtype]
)
{
def renderName = name+"Node"
def enumName = "JITD_NODE_"+name
@ -9,4 +14,9 @@ case class Node(val name:String, val fields:Seq[Field])
override def toString =
name + "(" + fields.map { _.toString }.mkString(", ") + ")"
def withSupertypes(supertypes: String*): Node =
copy(supertypes = this.supertypes ++ supertypes.map { Type.ASTSubtype(_) })
def allSupertypes: Set[Type.ASTType] = supertypes ++ Set(family)
}

View File

@ -19,12 +19,20 @@ object Type
override def toString: String = s"Native[${name}]"
def scalaType: String = name
}
case class AST(family: String) extends Type
sealed trait ASTType extends Type
case class AST(family: String) extends ASTType
{
override def toString: String = s"Ast[${family}]"
def scalaType: String = family
}
case class Node(nodeType: String) extends Type
case class ASTSubtype(typeName: String) extends ASTType
{
override def toString: String = s"ASTSubtype[${typeName}]"
def scalaType: String = typeName
}
case class Node(nodeType: String) extends ASTType
{
override def toString: String = s"Node[${nodeType}]"
def scalaType: String = nodeType

View File

@ -12,10 +12,8 @@ object Typecheck
{
(source, target) match {
case (a, b) if a == b => true
case (Type.Node(label), Type.AST(family)) =>
schema.nodes.get(family)
.map { _.exists { _.name == label } }
.getOrElse { false }
case (Type.Node(label), o:Type.ASTType) =>
schema.nodesByName(label).allSupertypes contains o
case (Type.Union(elems), a) =>
elems.forall { escalatesTo(_, a, schema) }
case (a, Type.Union(elems)) =>
@ -30,14 +28,16 @@ object Typecheck
{
if(a == b){ return a }
(a, b) match {
case (Type.Node(label), Type.AST(family)) if
schema.nodes.get(family)
.map { _.exists { _.name == label } }
.getOrElse { false } => a
case (Type.AST(family), Type.Node(label)) if
schema.nodes.get(family)
.map { _.exists { _.name == label } }
.getOrElse { false } => b
case (n@Type.Node(label), o:Type.ASTType) if
schema.nodesByName(label)
.allSupertypes contains o => n
case (o:Type.ASTType, n@Type.Node(label)) if
schema.nodesByName(label)
.allSupertypes contains o => n
case (n:Type.ASTSubtype, a:Type.AST) if
schema.asts(a.family).subtypes contains n => n
case (a:Type.AST, n:Type.ASTSubtype) if
schema.asts(a.family).subtypes contains n => n
case (Type.Union(elems), _) if elems contains b => a
case (_, Type.Union(elems)) if elems contains a => b
case (_, Type.Any) => a
@ -53,14 +53,29 @@ object Typecheck
{
if(a == b){ return a }
(a, b) match {
case (Type.Node(label), Type.AST(family)) if
schema.nodes.get(family)
.map { _.exists { _.name == label } }
.getOrElse { false } => b
case (Type.AST(family), Type.Node(label)) if
schema.nodes.get(family)
.map { _.exists { _.name == label } }
.getOrElse { false } => a
case (n:Type.Node, o:Type.ASTType) if
schema.nodesByName(n.nodeType)
.allSupertypes contains o => o
case (o:Type.ASTType, n@Type.Node(label)) if
schema.nodesByName(label)
.allSupertypes contains o => o
case (n:Type.ASTSubtype, o:Type.AST) if
schema.asts(o.family).subtypes contains n => o
case (o:Type.AST, n:Type.ASTSubtype) if
schema.asts(o.family).subtypes contains n => o
case (a:Type.Node, b:Type.Node) =>
val sharedSuperTypes =
schema.nodesByName(a.nodeType).allSupertypes &
schema.nodesByName(b.nodeType).allSupertypes
if(sharedSuperTypes.isEmpty){
assert(false, s"Node types $a and $b have nothing in common")
} else {
sharedSuperTypes.find { _.isInstanceOf[Type.ASTSubtype] }
.orElse { sharedSuperTypes.find { _.isInstanceOf[Type.AST] }}
.getOrElse {
assert(false, "A node can't inherit from another node")
}
}
case (Type.Union(elems), _) if elems contains b => a
case (_, Type.Union(elems)) if elems contains a => b
case (_, Type.Any) => Type.Any