Compare commits

...

7 Commits

25 changed files with 1081 additions and 205 deletions

View File

@ -6,7 +6,7 @@
- [x] PushProjectionThroughLimit,
- [x] ReorderJoin,
- [x] EliminateOuterJoin,
- [ ] PushDownPredicates,
- [x] PushDownPredicates,
- [ ] PushDownLeftSemiAntiJoin,
- [ ] PushLeftSemiLeftAntiThroughJoin,
- [ ] LimitPushDown,

View File

@ -132,7 +132,7 @@ object SparkMethods
/** True when every expression in `u.expressions` reports itself as deterministic. */
def unaryNodeIsDeterministic(u: UnaryNode): Boolean =
u.expressions.forall(_.deterministic)
def canPushThrough(p: UnaryNode): Boolean = p match {
def canPushThroughNonJoin(p: UnaryNode): Boolean = p match {
case _: AppendColumns => true
case _: Distinct => true
case _: Generate => true
@ -148,6 +148,24 @@ object SparkMethods
case _ => false
}
/**
 * Whether a join of the given type is eligible for predicate push-down.
 *
 * Inner-like, left semi, left anti, existence, left outer and right outer
 * joins are accepted; every other join type is rejected.
 */
def canPushThroughJoin(joinType: JoinType): Boolean =
  joinType match {
    case _: InnerLike => true
    case LeftSemi | LeftAnti | ExistenceJoin(_) => true
    case LeftOuter | RightOuter => true
    case _ => false
  }
/**
 * Checks that the join condition does not reference an attribute that appears
 * in the output of both `rightOp` and `plans`.
 *
 * @param plans     plan whose output attributes are checked against the condition
 * @param condition optional join condition; an absent condition always passes
 * @param rightOp   right-hand operand of the join
 * @return true when there is no condition, or when no attribute referenced by
 *         the condition is produced by both `rightOp` and `plans`
 */
def canPushThroughConditionSemiAntiJoin(
    plans: LogicalPlan,
    condition: Option[Expression],
    rightOp: LogicalPlan): Boolean = {
  val planAttributes = AttributeSet(plans.output)
  condition.forall { cond =>
    cond.references
      .intersect(rightOp.outputSet)
      .intersect(planAttributes)
      .isEmpty
  }
}
def canPushThroughCondition(plan: LogicalPlan, condition: Expression): Boolean =
{
val attributes = plan.outputSet
@ -322,4 +340,128 @@ object SparkMethods
u.withNewChildren(Seq(Filter(predicate, u.child)))
}
}
/**
 * Splits a list of predicates into three groups:
 *  1. deterministic predicates that only reference the output of `left`,
 *  2. deterministic predicates (not in group 1) that only reference the output of `right`,
 *  3. everything else: the remaining deterministic predicates followed by all
 *     non-deterministic ones.
 */
def splitJoinAL(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = {
  // True when every attribute referenced by `e` is produced by `plan`.
  def coveredBy(plan: LogicalPlan)(e: Expression): Boolean =
    e.references.subsetOf(plan.outputSet)

  val (deterministic, nonDeterministic) = condition.partition(_.deterministic)
  val (leftOnly, notLeftOnly) = deterministic.partition(coveredBy(left))
  val (rightOnly, bothSides) = notLeftOnly.partition(coveredBy(right))

  (leftOnly, rightOnly, bothSides ++ nonDeterministic)
}
/**
 * Pushes the conjuncts of a filter sitting on top of a join into the join's
 * inputs or into its condition, depending on the join type.
 *
 * The filter condition is split (see `splitJoinAL`) into left-only, right-only
 * and remaining conjuncts:
 *  - inner-like joins: left-only / right-only conjuncts become filters on the
 *    respective input; remaining conjuncts accepted by `canEvaluateWithinJoin`
 *    are merged into the join condition, the rest stay in a filter above the join;
 *  - right outer joins: only right-only conjuncts are pushed; the rest stay above;
 *  - left outer / left existence joins: only left-only conjuncts are pushed;
 *    the rest stay above.
 *
 * @param f the matched filter node (not used by the rewrite itself)
 * @throws IllegalStateException for any other join type
 */
def pushPredicateThroughJoinAL1(
    f: Filter, filterCondition: Expression, left: LogicalPlan, right: LogicalPlan, joinType: JoinType, joinCondition: Option[Expression], hint: JoinHint
): LogicalPlan =
{
  // Wraps `plan` in a Filter over the conjunction of `conjuncts`;
  // returns `plan` untouched when there is nothing to filter on.
  def filtered(conjuncts: Seq[Expression], plan: LogicalPlan): LogicalPlan =
    conjuncts.reduceLeftOption(And) match {
      case Some(predicate) => Filter(predicate, plan)
      case None => plan
    }

  val (leftOnly, rightOnly, shared) =
    splitJoinAL(splitConjunctivePredicates(filterCondition), left, right)

  joinType match {
    case _: InnerLike =>
      val (joinable, remaining) = shared.partition(canEvaluateWithinJoin)
      val mergedCondition = (joinable ++ joinCondition).reduceLeftOption(And)
      val pushedJoin =
        Join(filtered(leftOnly, left), filtered(rightOnly, right), joinType, mergedCondition, hint)
      filtered(remaining, pushedJoin)

    case RightOuter =>
      val pushedJoin = Join(left, filtered(rightOnly, right), RightOuter, joinCondition, hint)
      filtered(leftOnly ++ shared, pushedJoin)

    case LeftOuter | LeftExistence(_) =>
      val pushedJoin = Join(filtered(leftOnly, left), right, joinType, joinCondition, hint)
      filtered(rightOnly ++ shared, pushedJoin)

    case other =>
      throw new IllegalStateException(s"Unexpected join type: $other")
  }
}
/**
 * Pushes single-sided conjuncts of a join condition down into the join inputs.
 *
 * The join condition is split (see `splitJoinAL`) into left-only, right-only
 * and remaining conjuncts:
 *  - inner-like / left semi joins: both single-sided groups become filters on
 *    their input; only the remaining conjuncts stay in the join condition;
 *  - right outer joins: left-only conjuncts are pushed to the left input;
 *  - left outer / left anti / existence joins: right-only conjuncts are pushed
 *    to the right input.
 *
 * @param j the matched join node (not used by the rewrite itself)
 * @throws IllegalStateException for any other join type
 */
def pushPredicateThroughJoinAL2(
    j: Join, left: LogicalPlan, right: LogicalPlan, joinType: JoinType, joinCondition: Option[Expression], hint: JoinHint
): LogicalPlan =
{
  // Wraps `plan` in a Filter over the conjunction of `conjuncts`;
  // returns `plan` untouched when there is nothing to filter on.
  def filtered(conjuncts: Seq[Expression], plan: LogicalPlan): LogicalPlan =
    conjuncts.reduceLeftOption(And) match {
      case Some(predicate) => Filter(predicate, plan)
      case None => plan
    }

  val conjuncts = joinCondition.map(splitConjunctivePredicates).getOrElse(Nil)
  val (leftOnly, rightOnly, shared) = splitJoinAL(conjuncts, left, right)

  joinType match {
    case _: InnerLike | LeftSemi =>
      val remainingCondition = shared.reduceLeftOption(And)
      Join(filtered(leftOnly, left), filtered(rightOnly, right), joinType, remainingCondition, hint)

    case RightOuter =>
      val remainingCondition = (rightOnly ++ shared).reduceLeftOption(And)
      Join(filtered(leftOnly, left), right, RightOuter, remainingCondition, hint)

    case LeftOuter | LeftAnti | ExistenceJoin(_) =>
      val remainingCondition = (leftOnly ++ shared).reduceLeftOption(And)
      Join(left, filtered(rightOnly, right), joinType, remainingCondition, hint)

    case other =>
      throw new IllegalStateException(s"Unexpected join type: $other")
  }
}
/** True when no expression in `pList` satisfies `ScalarSubquery.hasCorrelatedScalarSubquery`. */
def negatedPListExistsHasCorrelatedScalarSubquery(pList: Seq[NamedExpression]): Boolean =
  pList.forall(expr => !ScalarSubquery.hasCorrelatedScalarSubquery(expr))
/**
 * Pushes a left semi / left anti join below a Project on its left input,
 * i.e. turns Join(Project(pList, gChild), rightOp) into
 * Project(pList, Join(gChild, rightOp)).
 *
 * The rewrite is only performed for `LeftSemi` and `LeftAnti` joins. The
 * matching rule does not constrain the join type, and moving any other join
 * type below the projection would change the plan's output, so all other
 * join types are returned unchanged.
 *
 * @param j        the matched join; returned as-is whenever the rewrite does not apply
 * @param p        the project on the join's left input
 * @param pList    the project list of `p` (not used by the rewrite itself)
 * @param gChild   the child of `p`
 * @param rightOp  the join's right input
 * @param joinCond the join condition, if any; aliases defined by `p` are
 *                 substituted into it before it is moved below the project
 * @return the rewritten plan, or `j` when the join type is not left semi/anti
 *         or an alias of `p` wraps a non-leaf expression
 */
def pushDownLeftSemiAntiJoin1(
    j: Join, p: Project, pList: Seq[NamedExpression], gChild: LogicalPlan, rightOp: LogicalPlan, joinType: JoinType, joinCond: Option[Expression], hint: JoinHint
): LogicalPlan =
{
  // Re-creates the project on top of a join between its former child and rightOp.
  def joinBelowProject(cond: Option[Expression]): LogicalPlan =
    p.copy(child = Join(gChild, rightOp, joinType, cond, hint))

  joinType match {
    case LeftSemi | LeftAnti =>
      joinCond match {
        case None =>
          // No join condition, just push down the Join below Project
          joinBelowProject(joinCond)
        case Some(cond) =>
          val aliasMap = getAliasMap(p)
          // Do not push complex join condition
          if (aliasMap.forall(_._2.child.children.isEmpty)) {
            val rewrittenCond =
              if (aliasMap.nonEmpty) Option(replaceAlias(cond, aliasMap)) else joinCond
            joinBelowProject(rewrittenCond)
          } else {
            j
          }
      }
    case _ =>
      // Not a left semi/anti join: leave the plan untouched.
      j
  }
}
}

View File

@ -1,7 +1,9 @@
package com.astraldb.catalyst
import com.astraldb.codegen.Render
import com.astraldb.codegen.Rule
import com.astraldb.bdd.{ Compiler => BDDCompiler }
import com.astraldb.spec.Type
object Astral
{
@ -18,9 +20,15 @@ object Astral
}
println(
BDDCompiler.bdd(definition)
)
val family = Type.AST("LogicalPlan")
val bdd =
BDDCompiler.bdd(definition, family)
println("============= Definition =================")
println(definition)
println("================ BDD =====================")
println(bdd)
}
}

View File

@ -10,11 +10,11 @@ object Catalyst extends HardcodedDefinition
Node("Filter")(
"condition" -> Type.Native("Expression"),
"child" -> Type.AST("LogicalPlan")
),
).withSupertypes("UnaryNode"),
Node("Project")(
"projectList" -> Type.Array(Type.Native("NamedExpression")),
"child" -> Type.AST("LogicalPlan")
),
).withSupertypes("UnaryNode"),
Node("Union")(
"children" -> Type.Array(Type.AST("LogicalPlan")),
"byName" -> Type.Bool,
@ -30,27 +30,27 @@ object Catalyst extends HardcodedDefinition
Node("LocalLimit")(
"limitExpr" -> Type.Native("Expression"),
"child" -> Type.AST("LogicalPlan")
),
).withSupertypes("UnaryNode"),
Node("GlobalLimit")(
"limitExpr" -> Type.Native("Expression"),
"child" -> Type.AST("LogicalPlan")
),
).withSupertypes("UnaryNode"),
Node("Aggregate")(
"groupingExpressions" -> Type.Array(Type.Native("Expression")),
"aggregateExpressions" -> Type.Array(Type.Native("NamedExpression")),
"child" -> Type.AST("LogicalPlan")
),
).withSupertypes("UnaryNode"),
Node("Window")(
"windowExpressions" -> Type.Array(Type.Native("NamedExpression")),
"partitionSpec" -> Type.Array(Type.Native("Expression")),
"orderSpec" -> Type.Array(Type.Native("SortOrder")),
"child" -> Type.AST("LogicalPlan")
),
).withSupertypes("UnaryNode"),
Node("EventTimeWatermark")(
"eventTime" -> Type.Native("Attribute"),
"delay" -> Type.Native("CalendarInterval"),
"child" -> Type.AST("LogicalPlan"),
),
).withSupertypes("UnaryNode"),
)
//////////////////////////////////////////////////////
@ -133,8 +133,12 @@ object Catalyst extends HardcodedDefinition
Type.Array(Type.Native("NamedExpression"))
)
Function("canPushThrough", Type.Bool)(
Type.Node("UnaryNode")
Function("canPushThroughNonJoin", Type.Bool)(
Type.ASTSubtype("UnaryNode")
)
Function("canPushThroughJoin", Type.Bool)(
Type.Native("JoinType")
)
Function("canPushThroughCondition", Type.Bool)(
@ -142,6 +146,12 @@ object Catalyst extends HardcodedDefinition
Type.Native("Expression"),
)
Function("canPushThroughConditionSemiAntiJoin", Type.Bool)(
Type.AST("LogicalPlan"),
Type.Option(Type.Native("Expression")),
Type.AST("LogicalPlan"),
)
Function("unaryNodeIsDeterministic", Type.Bool)(
Type.ASTSubtype("UnaryNode")
)
@ -201,6 +211,39 @@ object Catalyst extends HardcodedDefinition
Type.ASTSubtype("UnaryNode"),
)
// Declarations of the join-related helper functions that the rules below
// invoke through Apply(...). Each declaration lists the return type first,
// followed by the argument types in call order.

// (filter, filterCondition, left, right, joinType, joinCondition, hint) => LogicalPlan
Function("pushPredicateThroughJoinAL1", Type.AST("LogicalPlan"))(
Type.Node("Filter"),
Type.Native("Expression"),
Type.AST("LogicalPlan"),
Type.AST("LogicalPlan"),
Type.Native("JoinType"),
Type.Option(Type.Native("Expression")),
Type.Native("JoinHint"),
)
// (join, left, right, joinType, joinCondition, hint) => LogicalPlan
Function("pushPredicateThroughJoinAL2", Type.AST("LogicalPlan"))(
Type.Node("Join"),
Type.AST("LogicalPlan"),
Type.AST("LogicalPlan"),
Type.Native("JoinType"),
Type.Option(Type.Native("Expression")),
Type.Native("JoinHint"),
)
// (join, project, projectList, projectChild, rightOp, joinType, joinCond, hint) => LogicalPlan
Function("pushDownLeftSemiAntiJoin1", Type.AST("LogicalPlan"))(
Type.Node("Join"),
Type.Node("Project"),
Type.Array(Type.Native("NamedExpression")),
Type.AST("LogicalPlan"),
Type.AST("LogicalPlan"),
Type.Native("JoinType"),
Type.Option(Type.Native("Expression")),
Type.Native("JoinHint"),
)
// (projectList) => Bool
Function("negatedPListExistsHasCorrelatedScalarSubquery", Type.Bool)(
Type.Array(Type.Native("NamedExpression"))
)
Global("JoinHint.NONE", Type.Native("JoinHint"))
Global("RightOuter", Type.Native("JoinType"))
Global("LeftOuter", Type.Native("JoinType"))
@ -613,9 +656,9 @@ object Catalyst extends HardcodedDefinition
Rule("PushDownPredicates-2-6", "LogicalPlan")(
Bind("filter", Match("Filter")(
Bind("unusedCondition"),
Bind("u", Match("UnaryNode")), // here be dragons.
Bind("u", OfType(Type.ASTSubtype("UnaryNode"))),
)) and Test(
Apply("canPushThrough")(
Apply("canPushThroughNonJoin")(
Ref("u")
) and Apply("unaryNodeIsDeterministic")(
Ref("u")
@ -627,4 +670,90 @@ object Catalyst extends HardcodedDefinition
Ref("u"),
)
)
// PushDownPredicates-3-1: matches a Filter whose child is a Join, provided
// canPushThroughJoin accepts the join type, and rewrites the pair by calling
// pushPredicateThroughJoinAL1 with the bound filter, its condition and the
// join's fields. ("j" is bound by the pattern but not passed to the rewrite.)
Rule("PushDownPredicates-3-1", "LogicalPlan")(
Bind("f", Match("Filter")(
Bind("filterCondition"),
Bind("j", Match("Join")(
Bind("left"),
Bind("right"),
Bind("joinType"),
Bind("joinCondition"),
Bind("hint"),
)),
)) and Test(
Apply("canPushThroughJoin")(
Ref("joinType")
)
)
)(
Apply("pushPredicateThroughJoinAL1")(
Ref("f"),
Ref("filterCondition"),
Ref("left"),
Ref("right"),
Ref("joinType"),
Ref("joinCondition"),
Ref("hint"),
)
)
// PushDownPredicates-3-2: matches any Join whose type is accepted by
// canPushThroughJoin and rewrites it by calling pushPredicateThroughJoinAL2
// with the join and its fields.
Rule("PushDownPredicates-3-2", "LogicalPlan")(
Bind("j", Match("Join")(
Bind("left"),
Bind("right"),
Bind("joinType"),
Bind("joinCondition"),
Bind("hint"),
)) and Test(
Apply("canPushThroughJoin")(
Ref("joinType")
)
)
)(
Apply("pushPredicateThroughJoinAL2")(
Ref("j"),
Ref("left"),
Ref("right"),
Ref("joinType"),
Ref("joinCondition"),
Ref("hint"),
)
)
// PushDownLeftSemiAntiJoin-1: matches a Join whose first child is a Project,
// subject to three tests on the bound values:
//   - namedExpressionsAreDeterministic(pList)
//   - negatedPListExistsHasCorrelatedScalarSubquery(pList)
//   - canPushThroughConditionSemiAntiJoin(gChild, joinCond, rightOp)
// and rewrites it by calling pushDownLeftSemiAntiJoin1.
// NOTE(review): "joinType" is bound here but no Test in this pattern constrains
// it, even though the rule name suggests left semi/anti joins only — confirm
// that the join-type restriction is enforced elsewhere.
Rule("PushDownLeftSemiAntiJoin-1", "LogicalPlan")(
Bind("j", Match("Join")(
Bind("p", Match("Project")(
Bind("pList"),
Bind("gChild"),
)),
Bind("rightOp"),
Bind("joinType"),
Bind("joinCond"),
Bind("hint"),
)) and Test(
Apply("namedExpressionsAreDeterministic")(
Ref("pList")
)) and Test(
Apply("negatedPListExistsHasCorrelatedScalarSubquery")(
Ref("pList")
)) and Test(
Apply ("canPushThroughConditionSemiAntiJoin")(
Ref("gChild"),
Ref("joinCond"),
Ref("rightOp"),
)
)
)(
Apply("pushDownLeftSemiAntiJoin1")(
Ref("j"),
Ref("p"),
Ref("pList"),
Ref("gChild"),
Ref("rightOp"),
Ref("joinType"),
Ref("joinCond"),
Ref("hint"),
)
)
}

View File

@ -7,6 +7,7 @@ import com.astraldb.typecheck.TypecheckMatch.MatchTypecheckError
import com.astraldb.typecheck.TypecheckExpression.ExpressionTypecheckError
import com.astraldb.codegen.Render
import com.astraldb.spec.Type
object Generate
{
@ -20,7 +21,12 @@ object Generate
// err.printStackTrace()
System.exit(-1)
}
for( (file, content) <- Render(Catalyst.definition) )
val families = Seq[Type.AST](
Type.AST("LogicalPlan")
)
for( (file, content) <- Render(Catalyst.definition, families) )
{
val of = new BufferedWriter(new FileWriter(file))
of.write("""package com.astraldb.catalyst

View File

@ -35,6 +35,7 @@ case class PickByType(
case class PickByMatch(
path: Seq[Int],
matcher: Match,
bindings: Seq[Pathed[(String, Type)]],
ifMatched: BDD,
ifNotMatched: BDD,
) extends BDD
@ -50,6 +51,8 @@ case class PickByMatch(
case class BindAnExpression(
symbol: String,
expression: Expression,
exprType: Type,
bindings: Seq[Pathed[(String, Type)]],
andThen: BDD
) extends BDD
{
@ -63,11 +66,13 @@ case class BindAnExpression(
case class Rewrite(
label: String,
rewrite: Expression
rewrite: Expression,
bindings: Seq[Pathed[(String, Type)]],
) extends BDD
{
val schema = rewrite.references.toSeq
def code =
Code.Literal(s"Rewrite with $label")
Code.Literal(s"Rewrite with $label($schema)")
}
case object NoRewrite extends BDD

View File

@ -3,6 +3,8 @@ package com.astraldb.bdd
import com.astraldb.spec.Match
import com.astraldb.expression.Expression
import com.astraldb.spec.Type
import com.astraldb.spec.Definition
import com.astraldb.typecheck.TypecheckExpression
sealed trait CandidateStep
@ -13,16 +15,21 @@ object CheckTypeStep
CheckTypeStep(pathed.path, pathed.value)
}
case class CheckMatchStep(path: Pathed.Path, matcher: Match) extends CandidateStep
case class CheckMatchStep(path: Pathed.Path, matcher: Match, bindings: Seq[Pathed[(String, Type)]]) extends CandidateStep
object CheckMatchStep
{
def apply(pathed: Pathed[Match]): CheckMatchStep =
CheckMatchStep(pathed.path, pathed.value)
def apply(pathed: Pathed[Match], target: Target): CheckMatchStep =
{
val refs = pathed.value.references.map { _.v }.toSet
CheckMatchStep(pathed.path, pathed.value,
target.bindings.filter { b => refs(b.value._1) }
)
}
}
case class BindExpressionStep(symbol: String, expression: Expression) extends CandidateStep
case class BindExpressionStep(symbol: String, expression: Expression, exprType: Type, bindings: Seq[Pathed[(String, Type)]]) extends CandidateStep
object BindExpressionStep
{
def apply(binding: (String, Expression)): BindExpressionStep =
BindExpressionStep(binding._1, binding._2)
def apply(binding: (String, Expression, Type), target: Target): BindExpressionStep =
BindExpressionStep(binding._1, binding._2, binding._3, target.bindings)
}

View File

@ -4,6 +4,10 @@ import com.astraldb.spec.Definition
import com.astraldb.spec.Match
import com.astraldb.spec.Type
import com.astraldb.expression.Expression
import com.astraldb.spec.Rule
import com.astraldb.expression.Var
import com.astraldb.expression.IfIsDefined
import com.astraldb.typecheck.TypecheckMatch
object Compiler
{
@ -93,24 +97,67 @@ object Compiler
disjunctify(matcher)
}
def getTargets(schema: Definition): Seq[Target] =
def getClauseTargets(schema: Definition, rule: Rule): Seq[Target] =
{
schema.rules.flatMap { rule =>
val clauses =
normalize(schema, rule.pattern).clauses
clauses.zipWithIndex.map {
case (clause, idx) =>
val label =
rule.label + (if(clauses.size > 1) { "-"+idx } else { "" })
Target.of(label, clause, rule.rewrite)
.withoutUnusedBindings
}
val clauses:Seq[(Conjunction, Set[String])] =
normalize(schema, rule.pattern).clauses
.map { c =>
c -> TypecheckMatch(c.asMatch, rule.family, schema, schema.globals)
.keySet
}
val optionalBindings:Set[String] =
clauses.flatMap { _._2 }
.groupBy { x => x }
.filter { _._2.size < clauses.size }
.keySet
clauses.zipWithIndex.map {
case ((clause, bindings), idx) =>
val label =
rule.label + (if(clauses.size > 1) { "-"+idx } else { "" })
val optionalReferencesInClause =
bindings & optionalBindings
// if(!optionalBindings.isEmpty)
// {
// println(s"Optional: $optionalBindings")
// println(s"Optional Bound: $optionalReferencesInClause")
// }
val sanitizedRewrite =
rule.rewrite.transform {
case IfIsDefined(symbol, ifYes, ifNo)
if optionalBindings(symbol)=>
if(optionalReferencesInClause(symbol)) { ifYes }
else { ifNo }
}
// if(sanitizedRewrite != rule.rewrite)
// {
// println(rule.rewrite)
// println("into")
// println(sanitizedRewrite)
// }
Target.of(schema, label, rule.family, clause, sanitizedRewrite)
.withoutUnusedBindings
}
}
/**
 * Collects the clause targets (see `getClauseTargets`) of every rule in
 * `schema` whose family equals `family`; rules of other families contribute
 * nothing.
 */
def getTargets(schema: Definition, family: Type.AST): Seq[Target] =
  schema.rules
    .filter { rule => rule.family == family }
    .flatMap { rule => getClauseTargets(schema, rule) }
def getSanitizedTargets(schema: Definition): Seq[Target] =
def getSanitizedTargets(schema: Definition, family: Type.AST): Seq[Target] =
{
var targets = getTargets(schema)
var targets = getTargets(schema, family)
// Sanitize expression bindings via alpha renaming and
// common subexpression elimination
@ -146,7 +193,7 @@ object Compiler
.toSet.toSeq
.groupBy { _._1 }
.filter { _._2.size > 1 }
.map { case (name:String, exprs:Seq[(String, Expression)]) =>
.map { case (name:String, exprs:Seq[(String, Expression, Type)]) =>
name ->
exprs.map { _._2 }
.zipWithIndex
@ -159,7 +206,7 @@ object Compiler
targets = targets.map { target =>
val repair =
target.exprBindings
.flatMap { case (name, expr) =>
.flatMap { case (name, expr, t) =>
replacementNames.get(name)
.flatMap { _.get(expr).map { newName =>
name -> newName
@ -173,16 +220,15 @@ object Compiler
}
def bdd(schema: Definition): BDD =
def bdd(schema: Definition, family: Type.AST): BDD =
{
println(getSanitizedTargets(schema).mkString("\n---\n"))
def stepBdd(targets: Seq[Target], state: State): BDD =
{
val success = targets.find { _.successfulOn(state) }
if(success.isDefined)
{
Rewrite(success.get.label, success.get.rewrite)
Rewrite(success.get.label, success.get.rewrite, success.get.bindings)
} else {
val activeTargets =
targets.filter { _.canBeSuccessfulOn(state) }
@ -194,7 +240,7 @@ object Compiler
val candidates:Seq[CandidateStep] =
activeTargets
.flatMap { t =>
val c = t.candidates(state)
val c = t.candidates(schema, state)
assert(!c.isEmpty, s"Target ${t.label} in state \n$state\n is not successful or failed, but has no candidates for progress.")
/* return */ c
}
@ -215,7 +261,7 @@ object Compiler
val types:Set[Type] =
bestCheckTypeStep.get._2.map { _._2.check }.toSet
println(s"Checking path: [${path.mkString(", ")}] for ${types.mkString(", ")}")
System.err.println(s"Checking path: [${path.mkString(", ")}] for ${types.mkString(", ")}")
PickByType(
path,
@ -241,10 +287,11 @@ object Compiler
bestOtherCandidate match {
case _:CheckTypeStep => assert(false, "We filtered out all the check type steps by this point")
case CheckMatchStep(path, matcher) =>
case CheckMatchStep(path, matcher, bindings) =>
PickByMatch(
path = path,
matcher = matcher,
bindings = bindings,
ifMatched =
stepBdd(activeTargets,
state.withSuccessfulMatch(path, matcher)),
@ -252,8 +299,10 @@ object Compiler
stepBdd(activeTargets,
state.withFailedMatch(path, matcher)),
)
case BindExpressionStep(symbol, expression) =>
case BindExpressionStep(symbol, expression, exprType, bindings) =>
BindAnExpression(symbol, expression,
exprType,
bindings,
stepBdd(activeTargets,
state.withBinding(symbol)
)
@ -265,7 +314,7 @@ object Compiler
}
}
stepBdd(getSanitizedTargets(schema), State.empty(schema))
stepBdd(getSanitizedTargets(schema, family), State.empty(schema))
}
}

View File

@ -3,21 +3,38 @@ package com.astraldb.bdd
import com.astraldb.expression._
import com.astraldb.spec._
import scala.util.Try
import com.astraldb.typecheck.TypecheckExpression
case class Target(
label: String,
family: Type.AST,
rewrite: Expression,
bindings: Seq[Pathed[String]] = Seq.empty,
exprBindings: Seq[(String, Expression)] = Seq.empty,
bindings: Seq[Pathed[(String, Type)]] = Seq.empty,
exprBindings: Seq[(String, Expression, Type)] = Seq.empty,
types: Seq[Pathed[Type]] = Seq.empty,
matchers: Seq[Pathed[Match]] = Seq.empty,
)
{
def withBinding(path: Seq[Int], symbol: String) =
copy(bindings = bindings :+ Pathed(path, symbol))
{
val varType =
types.find { _.path == path }
.map { _.value }
.getOrElse { family }
copy(
bindings = bindings :+ Pathed(path, (symbol, varType))
)
}
def withType(path: Seq[Int], pathType: Type) =
copy(types = types :+ Pathed(path, pathType))
copy(
types = types :+ Pathed(path, pathType),
bindings = bindings.map {
case b if b.path == path =>
b.copy( value = b.value._1 -> pathType )
case b => b
}
)
def withMatcher(path: Seq[Int], matcher: Match) =
if(matcher == Match.Any){ this }
@ -25,8 +42,8 @@ case class Target(
copy(matchers = matchers :+ Pathed(path, matcher))
}
def withBoundExpression(symbol: String, expr: Expression) =
copy(exprBindings = exprBindings :+ (symbol, expr))
def withBoundExpression(symbol: String, expr: Expression, exprType: Type) =
copy(exprBindings = exprBindings :+ (symbol, expr, exprType))
def withoutUnusedBindings =
{
@ -36,11 +53,11 @@ case class Target(
exprBindings.flatMap { _._2.references }.map { _.v}
copy(
bindings = bindings.filter { b => refs(b.value) }
bindings = bindings.filter { b => refs(b.value._1) }
)
}
def candidates(state: State): Seq[CandidateStep] =
def candidates(schema: Definition, state: State): Seq[CandidateStep] =
{
val requestedTypedPaths = types.map { _.path }.toSet
val typedPaths = state.typeOf.keySet
@ -57,7 +74,7 @@ case class Target(
if(requestedTypedPaths(b.path)) { b.pathInSet(typedPaths) }
// If the binding itself is not typed, then we just need its ancestors bound
else { b.ancestorsInSet(typedPaths) }
}.map { b => Var(b.value) }
}.map { b => Var(b.value._1) }
).toSet
// println(s"Valid bindings: $validBindings")
@ -101,8 +118,8 @@ case class Target(
// println(s"Candidate Matchers: $candidateMatchers")
(candidateTypings.map { CheckTypeStep(_) }: Seq[CandidateStep]) ++
candidateMatchers.map { CheckMatchStep(_) } ++
candidateBindings.map { BindExpressionStep(_) }
candidateMatchers.map { CheckMatchStep(_, this) } ++
candidateBindings.map { BindExpressionStep(_, this) }
}
def canBeSuccessfulOn(state: State): Boolean =
@ -131,13 +148,14 @@ case class Target(
copy(
rewrite = rewrite.rename(replacementMap),
bindings = bindings.map { b =>
replacementMap.get(b.value)
.map { r => b.copy(value = r) }
replacementMap.get(b.value._1)
.map { r => b.copy(value = (r, b.value._2)) }
.getOrElse(b)
},
exprBindings = exprBindings.map { case (name, expr) =>
exprBindings = exprBindings.map { case (name, expr, exprType) =>
( replacementMap.getOrElse(name, name),
expr.rename(replacementMap)
expr.rename(replacementMap),
exprType
)
},
matchers = matchers.map { m =>
@ -146,6 +164,12 @@ case class Target(
)
}
def scope: Map[String, Type] =
(
bindings.map { _.value } ++
exprBindings.map { b => b._1 -> b._3 }
).toMap
override def toString(): String =
label+
(if(types.isEmpty){ "" } else {
@ -158,7 +182,7 @@ case class Target(
})+
(if(exprBindings.isEmpty){ "" } else {
"\n Expression Bindings\n"+
exprBindings.map { case (s, e) => " "+s+" <- "+e }.mkString("\n")
exprBindings.map { case (s, e, t) => " "+s+" <- "+e }.mkString("\n")
})+
(if(matchers.isEmpty){ "" } else {
"\n Matchers\n"+
@ -169,9 +193,9 @@ case class Target(
object Target
{
def of(label: String, clause: Conjunction, rewrite: Expression): Target =
def of(schema: Definition, label: String, family: Type.AST, clause: Conjunction, rewrite: Expression): Target =
{
clause.atoms.foldLeft(Target(label, rewrite)){ (target, atom) =>
clause.atoms.foldLeft(Target(label, family, rewrite)){ (target, atom) =>
atom match {
case Match.Path(path, Match.Bind(symbol, child)) =>
target.withBinding(path, symbol)
@ -186,7 +210,9 @@ object Target
case Match.Path(path, child) =>
target.withMatcher(path, child)
case Match.BindExpression(symbol, expr) =>
target.withBoundExpression(symbol, expr)
target.withBoundExpression(symbol, expr,
TypecheckExpression(expr, schema, target.scope++schema.globals)
)
case _ =>
target.withMatcher(Seq.empty, atom)
}

View File

@ -2,6 +2,7 @@ package com.astraldb.bdd
import com.astraldb.spec.Match
import com.astraldb.codegen.Code
import com.astraldb.expression.Var
case class Conjunction(atoms: Seq[Match])
{
@ -28,6 +29,10 @@ case class Conjunction(atoms: Seq[Match])
")"
)
}
/** The union of the variables referenced by each atom of this conjunction. */
def references:Set[Var] = atoms.flatMap { _.references }.toSet
/** This conjunction as a single matcher: `Match.And` over its atoms. */
def asMatch = Match.And(atoms)
}
object Conjunction
@ -78,6 +83,10 @@ case class Disjunction(clauses: Seq[Conjunction])
")"
)
}
/** The union of the variables referenced by each clause of this disjunction. */
def references:Set[Var] = clauses.flatMap { _.references }.toSet
/** This disjunction as a single matcher: `Match.Or` over each clause's `asMatch`. */
def asMatch = Match.Or(clauses.map { _.asMatch })
}
object Disjunction

View File

@ -0,0 +1,160 @@
package com.astraldb.codegen
import com.astraldb.spec.Definition
import com.astraldb.spec.Type
import com.astraldb.bdd
import com.astraldb.codegen.Code.PaddedString
import com.astraldb.expression.Var
/**
 * Code generator that turns a compiled decision diagram (`bdd.BDD`) into a
 * `Code` tree.
 */
object BDD
{
/**
 * Generates code for `rule`.
 *
 * @param schema  definition providing node field layouts (`nodesByName`) and globals
 * @param family  type recorded for the root path (the empty path)
 * @param rule    the decision diagram to render
 * @param onFail  code emitted wherever the diagram ends in `NoRewrite`, and as
 *                the final else-branch of every type dispatch
 * @param root    code fragment that refers to the value at the empty path
 */
def apply(
schema: Definition,
family: Type.AST,
rule: bdd.BDD,
onFail: Code,
root: Code,
): Code =
{
// Builds the CodeScope needed to render an expression or matcher that
// references `refs`. References that are neither expression-bound
// (`boundVars`) nor globals are resolved through `bindings` (name -> path)
// and then `pathNames` (path -> code/type); an assertion fails when either
// lookup misses. `context` is only used in the assertion messages.
def codeScopeFor[T](refs: Set[Var], bindings: Seq[bdd.Pathed[(String, Type)]], pathNames: Map[bdd.Pathed.Path, (Code, Type)], boundVars: Map[String, Type], context: T): CodeScope =
CodeScope(
(
(refs.map { _.v }
-- boundVars.keys
-- schema.globals.keys
).toSeq.map { ref =>
val pathRef =
bindings.find { _.value._1 == ref }
.getOrElse {
assert(false, s"Reference to a variable that hasn't been bound yet: $ref (in $bindings):\n$context")
}
val varDescription =
pathNames.find { _._1 == pathRef.path }
.getOrElse {
assert(false, s"Reference to a variable at an unbound path: ${pathRef.path} (in $pathNames):\n$context")
}
._2
ref -> varDescription
}.toMap:Map[String, Type|(Code,Type)]
) ++ (boundVars:Map[String, Type|(Code,Type)])
++ (schema.globals:Map[String, Type|(Code,Type)])
)
// Renders one diagram node.
//   pathNames: code fragment and type known for each path visited so far
//   boundVars: symbols introduced by BindAnExpression nodes, with their types
def recur(rule: bdd.BDD, pathNames: Map[bdd.Pathed.Path, (Code, Type)], boundVars: Map[String, Type]): Code =
rule match {
case bdd.NoRewrite => onFail
// Type dispatch: one `isInstanceOf` branch per candidate type. Each branch
// casts the value into a fresh local and, for node types, registers a
// `<local>.<field>` fragment for every field of the node as a child path.
// The trailing else-branch is the `onFail` parameter of `apply`.
case bdd.PickByType(path, types) =>
// println(s"Generating code for $path")
val target = pathNames.get(path)
.getOrElse {
assert(false, s"No name available for $path (in $pathNames):\n$rule")
}
._1
Code.IfElifElse(
types.map { case (targetType, andThen) =>
{
val nodeName = "root"+path.map { "_child"+_ }.mkString+"_as_"+targetType.scalaType
val newPathNames: Map[bdd.Pathed.Path, (Code, Type)] =
targetType match {
case Type.Node(node) =>
schema.nodesByName(node)
.fields
.zipWithIndex
.map { case (field, idx) =>
(path :+ idx) -> (
Code.BinOp(
Code.Literal(nodeName),
".",
Code.Literal(field.name)
),
field.t
)
}
.toMap
case _ => Map.empty
}
Code.Pair(target, Code.Literal(s".isInstanceOf[${targetType.scalaType}]"))
-> Code.Block(Seq(
Code.BinOp(
Code.Literal(s"val $nodeName"),
PaddedString.pad(1, "="),
Code.Pair(target,
Code.Literal(s".asInstanceOf[${targetType.scalaType}]")
)
)
)++recur(
andThen,
pathNames ++ newPathNames ++ Map(path -> (Code.Literal(nodeName), targetType)),
boundVars,
).block
)
}
}.toSeq,
onFail
)
// Matcher test: delegates to the Match code generator and continues with the
// matched / not-matched sub-diagrams.
// NOTE: within this case the pattern variable `onFail` (the not-matched
// sub-diagram) shadows the `onFail: Code` parameter of `apply`.
case bdd.PickByMatch(path, pattern, bindings, onMatch, onFail) =>
val target = pathNames.get(path)
.getOrElse {
assert(false, s"No name available for $path (in $pathNames):\n$rule")
}
val scope =
codeScopeFor(pattern.references, bindings, pathNames, boundVars, pattern)
Match(
schema = schema,
pattern = pattern,
target = target._1,
targetPath = path,
targetType = target._2,
onSuccess =
_ => recur(
onMatch,
pathNames,
boundVars,
),
onFail =
_ => recur(
onFail,
pathNames,
boundVars,
),
name = None,
scope = scope
)
// Leaf: emit a comment naming the rule followed by the rewrite expression.
case bdd.Rewrite(label, op, bindings) =>
Code.Block(
Seq(
Code.Literal(s"// Rewrite by $label"),
) ++ Expression(
schema,
op,
codeScopeFor(op.references, bindings, pathNames, boundVars, op)
).block
)
// Expression binding: emit `val <symbol> = <expression>` and continue with
// the symbol added to boundVars.
case bdd.BindAnExpression(symbol, expression, exprType, bindings, andThen) =>
Code.Block(
Seq(
Code.Pair(
Code.Literal(s"val $symbol ="),
Expression(
schema, expression,
codeScopeFor(expression.references, bindings, pathNames, boundVars, expression)
)
)
)++recur(
andThen,
pathNames,
boundVars ++ Map(symbol -> exprType)
).block
)
}
// Start at the root: only the empty path is known, and nothing is bound yet.
recur(
rule,
Map(
Seq.empty -> (root, family)
),
Map.empty,
)
}
}

View File

@ -321,4 +321,34 @@ object Code
def Parenthesize(body: Code): Code =
Parens("(", body, ")")
/** A `Pair` of `left` and `right`, with `right` wrapped in `Indent(2, ...)`. */
def IndentedPair(left: Code, right: Code): Code =
Pair(left, Indent(2, right))
/**
 * Builds an `if(..) { .. } else if(..) { .. } else { .. }` chain.
 *
 * @param ifThenClauses (condition, body) pairs in the order they are tested;
 *                      the first pair becomes the `if`, the rest `else if`s
 * @param elseClause    code for the final `else` branch
 * @return the chain as a `Block`; when `ifThenClauses` is empty there is
 *         nothing to test, so `elseClause` is returned on its own instead of
 *         failing on the missing first clause
 */
def IfElifElse(ifThenClauses: Seq[(Code, Code)], elseClause: Code): Code =
  ifThenClauses.headOption match {
    case None =>
      elseClause
    case Some((firstCondition, firstBody)) =>
      val ifBranch: Code =
        IndentedPair(
          Parens("if(", firstCondition, PaddedString.rightPad(1, ") {")),
          firstBody,
        )
      val elseIfBranches: Seq[Code] =
        ifThenClauses.tail.map { case (condition, body) =>
          IndentedPair(
            Parens(
              PaddedString.leftPad(1, "} else if("),
              condition,
              PaddedString.rightPad(1, ") {")
            ),
            body,
          )
        }
      val elseBranch: Code =
        Parens(
          PaddedString.pad(1, "} else {"),
          elseClause,
          PaddedString.leftPad(1, "}")
        )
      Block((ifBranch +: elseIfBranches) :+ elseBranch)
  }
}

View File

@ -2,7 +2,7 @@ package com.astraldb.codegen
import com.astraldb.spec
class CodeScope(vars: Map[String, spec.Type | (Code, spec.Type)])
case class CodeScope(vars: Map[String, spec.Type | (Code, spec.Type)])
{
type Element = spec.Type | (Code, spec.Type)

View File

@ -16,6 +16,7 @@ object Match
schema: spec.Definition,
pattern: spec.Match,
target: Code,
targetPath: Seq[Int],
targetType: spec.Type,
onSuccess: CodeScope => Code,
onFail: CodeScope => Code,
@ -25,16 +26,18 @@ object Match
{
pattern match {
case And(Seq()) => onSuccess(scope)
case And(Seq(a)) => apply(schema, a, target, targetType, onSuccess, onFail, name, scope)
case And(Seq(a)) => apply(schema, a, target, targetPath, targetType, onSuccess, onFail, name, scope)
case And(a) =>
apply(schema, a.head,
target = target,
targetPath = targetPath,
targetType = targetType,
onSuccess =
scope =>
apply(schema, And(a.tail),
target = target,
targetType = targetType,
targetPath = targetPath,
onSuccess = onSuccess,
onFail = onFail,
name = name,
@ -45,35 +48,45 @@ object Match
scope = scope
)
case Not(a) =>
apply(schema, a, target, targetType, onFail, onSuccess, name, scope)
apply(schema, a, target, targetPath, targetType, onFail, onSuccess, name, scope)
case Or(Seq()) => onFail(scope)
case Or(Seq(a)) =>
apply(schema, a, target, targetType, onSuccess, onFail, name, scope)
apply(schema, a, target, targetPath, targetType, onSuccess, onFail, name, scope)
case Or(a) =>
val optional:Map[String, Type] =
a.flatMap {
TypecheckMatch(_, name, targetType, schema, scope.flatten).toSeq
val optional:Map[Var|TypecheckMatch.PathVar, Type] =
a.flatMap { m =>
TypecheckMatch(
m,
targetType,
schema,
scope.flatten
).toSeq
}.groupBy { _._1 }
.filter { _._2.size < a.size }
.mapValues { case v => Typecheck.glb(v.map { _._2 }, schema) }
.map { case (v, t) =>
Var(v) ->
Typecheck.glb(t.map { _._2 }, schema)
}
.toMap
a.foldRight(onFail) { (matcher, nextOnFail) =>
_ =>
apply(schema, matcher,
target = target,
targetPath = targetPath,
targetType = targetType,
onSuccess = {
(scope: CodeScope) => onSuccess(
scope.map {
case (name, t: Type) if optional contains name =>
case (name, t: Type) if optional contains Var(name) =>
name -> (Code.Literal(s"(Some($name):${Type.Option(t).scalaType})"), Type.Option(t))
case (name, c: (Code, Type)) if optional contains name =>
case (name, c: (Code, Type)) if optional contains Var(name) =>
name -> (Code.Parens(s"(Some(", c._1, s"):${Type.Option(c._2).scalaType})"), Type.Option(c._2))
case x => x
}.withVars(
(optional.keySet -- scope.keys).toSeq.map { k =>
k -> (Code.Literal(s"(None:${Type.Option(optional(k)).scalaType})"):Code, Type.Option(optional(k)):Type)
(optional.keySet -- scope.keys.map { Var(_) }).toSeq.collect {
case k:Var =>
k.v -> (Code.Literal(s"(None:${Type.Option(optional(k)).scalaType})"):Code, Type.Option(optional(k)):Type)
}:_*
)
)
@ -93,6 +106,7 @@ object Match
Var(selectedName) eq pattern
)),
target = target,
targetPath = targetPath,
targetType = targetType,
onSuccess = onSuccess,
onFail = onFail,
@ -105,6 +119,7 @@ object Match
schema = schema,
pattern = pattern,
target = target,
targetPath = targetPath,
targetType = targetType,
onSuccess = onSuccess,
onFail = onFail,
@ -154,11 +169,13 @@ object Match
) ++
children
.zip(node.fields)
.foldRight(onSuccess) { case ((child, field), andThen) =>
.zipWithIndex
.foldRight(onSuccess) { case (((child, field), idx), andThen) =>
nextScope => apply(
schema = schema,
pattern = child,
target = Code.Literal(s"$selectedName.${field.name}"),
targetPath = targetPath :+ idx,
targetType = field.t,
onSuccess = andThen,
onFail = onFail,

View File

@ -1,14 +1,23 @@
package com.astraldb.codegen
import com.astraldb.spec.Definition
import com.astraldb.spec.Type
import com.astraldb.bdd
object Render
{
def apply(schema: Definition): Map[String, String] =
def apply(schema: Definition, bddFamilies: Seq[Type.AST] = Seq.empty): Map[String, String] =
{
val bddRule: Seq[(Type.AST, bdd.BDD)] =
bddFamilies.map { f =>
f -> bdd.Compiler.bdd(schema, f)
}
Map(
"Optimizer.scala" -> Optimizer(schema)
) ++ schema.rules.map { rule =>
"Optimizer.scala" -> Optimizer(schema),
) ++ (bddRule.map { case (family, bdd) =>
s"${family.scalaType}BDD.scala" -> Rule(schema, bdd, family)
}.toMap) ++ schema.rules.map { rule =>
s"${rule.safeLabel}.scala" -> Rule(schema, rule)
}.toMap
}

View File

@ -1,6 +1,7 @@
package com.astraldb.codegen
import com.astraldb.spec
import com.astraldb.bdd
object Rule
{
@ -8,4 +9,9 @@ object Rule
{
scala.Rule(schema, rule).toString
}
/**
 * Render the `BDDBatch` template for one AST family.
 *
 * @param schema  the language definition handed to the template
 * @param rule    the BDD handed to the template
 * @param family  the AST family handed to the template
 * @return        the text produced by the template
 */
def apply(schema: spec.Definition, rule: bdd.BDD, family: spec.Type.AST) =
{
  // Note the argument order expected by the template: (schema, family, bdd)
  val rendered = scala.BDDBatch(schema, family, rule)
  rendered.toString
}
}

View File

@ -360,6 +360,8 @@ case class Let(symbol: String, value: Expression, rest: Expression) extends Expr
def children = Seq(value, rest)
def reassemble(in: Seq[Expression]): Expression = Let(symbol, in(0), in(1))
override def toString = s"let $symbol = $value in $rest"
/**
 * Variables referenced by this expression: everything `value` references,
 * plus whatever `rest` references apart from `symbol` itself.
 */
override def references: Set[Var] =
{
  val fromValue = value.references
  val fromRest = rest.references - Var(symbol)
  fromValue ++ fromRest
}
}
/**
* Branch on whether a symbol is defined in scope
@ -370,4 +372,6 @@ case class IfIsDefined(symbol: String, ifYes: Expression, ifNo: Expression) exte
def reassemble(in: Seq[Expression]): Expression =
copy(ifYes = in(0), ifNo = in(1))
override def toString = s"if($symbol exists){ $ifYes } else { $ifNo }"
/**
 * The inherited set of references, extended with the symbol whose
 * presence in scope is being tested.
 */
override def references: Set[Var] =
  super.references + Var(symbol)
}

View File

@ -4,7 +4,7 @@ import scala.collection.mutable
import com.astraldb.expression._
case class Definition(
asts: Map[String, ASTDefinition],
asts: Map[Type.AST, ASTDefinition],
rules:Seq[Rule],
globals: Map[String, Type],
) {
@ -20,7 +20,7 @@ case class Definition(
|${rules.mkString("\n\n")}
""".stripMargin
val familyOfNode: Map[String, String] =
val familyOfNode: Map[String, Type.AST] =
nodes.flatMap { case (family, elements) => elements.map { _.name -> family } }
.toMap
@ -39,7 +39,7 @@ class HardcodedDefinition
Definition(
asts = nodes.map { case (f, n) =>
val family = Type.AST(f)
f -> ASTDefinition(
family -> ASTDefinition(
family = family,
nodes = n.map { _.copy(family = family) }.toSet
)}.toMap,
@ -62,7 +62,7 @@ class HardcodedDefinition
def Rule(label: String, family: String)(pattern: Match)(replacement: Expression): Unit =
rules.append(
com.astraldb.spec.Rule(label, family, pattern, replacement)
com.astraldb.spec.Rule(label, Type.AST(family), pattern, replacement)
)
def Function(label: String, ret: Type = Type.Unit)(args: Type*): Unit =

View File

@ -6,7 +6,7 @@ import com.astraldb.typecheck.TypecheckExpression
case class Rule(
label: String,
family: String,
family: Type.AST,
pattern: Match,
rewrite: Expression
)
@ -17,11 +17,11 @@ case class Rule(
def validate(schema: Definition): Unit =
{
val scope =
TypecheckMatch(pattern, None, Type.AST(family), schema, schema.globals)
TypecheckMatch(pattern, family, schema, schema.globals)
TypecheckExpression(rewrite, schema, scope) match {
case Type.Node(n) =>
assert(schema.familyOfNode(n) == family, s"Rule $label constructs a $n of family ${schema.familyOfNode(n)}, while the rule is for famil $family")
case Type.AST(f) =>
case f:Type.AST =>
assert(f == family, s"Rule $label constructs a $f, while the rule is for family $family")
case c =>
assert(false, s"Rule $label generates a non-AST type $c")

View File

@ -14,6 +14,8 @@ object Typecheck
case (a, b) if a == b => true
case (Type.Node(label), o:Type.ASTType) =>
schema.nodesByName(label).allSupertypes contains o
case (n:Type.ASTSubtype, o:Type.AST) =>
schema.asts(o).subtypes contains n
case (Type.Union(elems), a) =>
elems.forall { escalatesTo(_, a, schema) }
case (a, Type.Union(elems)) =>
@ -24,6 +26,13 @@ object Typecheck
}
}
/**
* Compute the greatest lower bound.
*
* This is the *more* precise type of the two arguments.
*
* e.g., glb(Node, AST) = Node
*/
def glb(a: Type, b: Type, schema: Definition): Type =
{
if(a == b){ return a }
@ -35,9 +44,9 @@ object Typecheck
schema.nodesByName(label)
.allSupertypes contains o => n
case (n:Type.ASTSubtype, a:Type.AST) if
schema.asts(a.family).subtypes contains n => n
schema.asts(a).subtypes contains n => n
case (a:Type.AST, n:Type.ASTSubtype) if
schema.asts(a.family).subtypes contains n => n
schema.asts(a).subtypes contains n => n
case (Type.Union(elems), _) if elems contains b => a
case (_, Type.Union(elems)) if elems contains a => b
case (_, Type.Any) => a
@ -46,23 +55,42 @@ object Typecheck
}
}
/**
 * Compute the greatest lower bound of a sequence of types.
 *
 * The elements are combined left to right with the pairwise `glb`,
 * starting from `Type.Any`; an empty sequence therefore yields `Type.Any`.
 *
 * e.g., glb(Seq(Node, AST)) = Node
 */
def glb(elems: Seq[Type], schema: Definition): Type =
  elems.foldLeft[Type](Type.Any) { (tightest, next) =>
    glb(tightest, next, schema)
  }
/**
* Compute the least upper bound.
*
* This is the *less* precise type of the two arguments.
*
* e.g., lub(Node, AST) = AST
*/
def lub(a: Type, b: Type, schema: Definition): Type =
  lubOption(a, b, schema) match {
    case Some(unified) => unified
    // lubOption found no common type: report it as an assertion failure
    case None => assert(false, s"Types $a and $b don't unify")
  }
def lubOption(a: Type, b: Type, schema: Definition): Option[Type] =
{
if(a == b){ return a }
if(a == b){ return Some(a) }
(a, b) match {
case (n:Type.Node, o:Type.ASTType) if
schema.nodesByName(n.nodeType)
.allSupertypes contains o => o
.allSupertypes contains o => Some(o)
case (o:Type.ASTType, n@Type.Node(label)) if
schema.nodesByName(label)
.allSupertypes contains o => o
.allSupertypes contains o => Some(o)
case (n:Type.ASTSubtype, o:Type.AST) if
schema.asts(o.family).subtypes contains n => o
schema.asts(o).subtypes contains n => Some(o)
case (o:Type.AST, n:Type.ASTSubtype) if
schema.asts(o.family).subtypes contains n => o
schema.asts(o).subtypes contains n => Some(o)
case (a:Type.Node, b:Type.Node) =>
val sharedSuperTypes =
schema.nodesByName(a.nodeType).allSupertypes &
@ -72,15 +100,12 @@ object Typecheck
} else {
sharedSuperTypes.find { _.isInstanceOf[Type.ASTSubtype] }
.orElse { sharedSuperTypes.find { _.isInstanceOf[Type.AST] }}
.getOrElse {
assert(false, "A node can't inherit from another node")
}
}
case (Type.Union(elems), _) if elems contains b => a
case (_, Type.Union(elems)) if elems contains a => b
case (_, Type.Any) => Type.Any
case (Type.Any, _) => Type.Any
case _ => assert(false, s"Types $a and $b don't unify")
case (Type.Union(elems), _) if elems contains b => Some(a)
case (_, Type.Union(elems)) if elems contains a => Some(b)
case (_, Type.Any) => Some(Type.Any)
case (Type.Any, _) => Some(Type.Any)
case _ => None
}
}

View File

@ -117,8 +117,8 @@ object TypecheckExpression
case Type.Node(nodeType) =>
val node = schema.nodesByName(nodeType)
node.fields(index).t
case _ =>
assert(false, s"Node subscript on something not a node")
case t =>
assert(false, s"Node subscript on something not a node: $t")
}
case StructSubscript(target, name) =>

View File

@ -6,19 +6,247 @@ import com.astraldb.expression._
object TypecheckMatch
{
/**
 * Typecheck `pattern` against a value of type `targetType`.
 *
 * @param pattern     the match pattern to check
 * @param targetType  the type recorded for the root path
 * @param schema      the language definition
 * @param scope       variables already in scope; entries in `schema.globals`
 *                    take precedence over entries with the same name
 * @return            the scope produced by the check, as plain types
 */
def apply(pattern: Match, targetType: Type, schema: Definition, scope: Map[String, Type]): Map[String, Type] =
{
  val rootPath = Seq.empty[Int]
  // The initial state knows only the root's type and has no optional bindings.
  val initialState =
    State(
      scope = scope ++ schema.globals,
      pathTypes = Map(rootPath -> targetType),
      optionalBindings = Set.empty,
      schema = schema
    )
  check(pattern, rootPath, initialState).asScope
}
/**
 * A reference to a position in the matched structure, given as the sequence
 * of child indices leading to it from the root (the empty sequence is the root).
 */
case class PathVar(path: Seq[Int])
/**
 * The typechecker's working state while a match pattern is being checked.
 *
 * @param scope             variable name -> either a concrete type, or a
 *                          PathVar naming the position the variable is bound to
 * @param pathTypes         types recorded for specific paths; must always
 *                          contain an entry for the root path (Seq.empty)
 * @param optionalBindings  names of variables whose type is wrapped in
 *                          Type.Option when read through typeOfVar / asScope
 * @param schema            the language definition used for type lookups
 */
case class State(
scope: Map[String, PathVar|Type],
pathTypes: Map[Seq[Int], Type],
optionalBindings: Set[String],
schema: Definition
)
{
// Class invariant: checked on every construction, including via copy(...)
assert(pathTypes.contains(Seq.empty), s"We should never create a state without a root path: $pathTypes")
/**
 * Human-readable dump of the scope and the recorded path types.
 */
override def toString(): String =
"---- Scope ----\n"+
scope.map {
case (name, t:Type) if optionalBindings(name) => s"$name:$t (optional)\n"
case (name, t:Type) => s"$name:$t\n"
case (name, PathVar(path)) if optionalBindings(name) => s"$name @ [${path.mkString(", ")}] (optional)\n"
case (name, PathVar(path)) => s"$name @ [${path.mkString(", ")}]\n"
}.mkString +
"---- PathTypes ----\n"+
pathTypes.map {
case (path, t) => s"@[${path.mkString(",")}]:$t\n"
}.mkString
/**
 * Compute the type found at `path`.
 *
 * Starting from the root's recorded type, each step derives the type of the
 * next position from the current type: for a Type.Node it is the declared
 * type of the field with that index; for a Type.Array it is the element type
 * (the index is not otherwise used).  If a type has been recorded for the
 * position reached so far, it is combined with the derived type using
 * Typecheck.glb.
 *
 * Fails an assertion if the root is untyped, a node label is unknown, a field
 * index is out of range, or the path descends into any other kind of type.
 */
def typeAtPath(path: Seq[Int]): Type =
path.foldLeft( (
Seq[Int](),
pathTypes.get(Seq.empty)
.getOrElse {
System.err.println(this)
assert(false, s"Trying to look up the type of a path when the root is untyped")
}
) ) { case ((subPath, t), field) =>
val fieldPath = subPath :+ field
val fieldType =
t match {
case Type.Node(label) =>
assert(schema.nodesByName contains label, s"No AST node of type $label defined")
val nodeSchema = schema.nodesByName(label)
assert(nodeSchema.fields.size > field, s"Can't path into the $field'th element of a $t")
nodeSchema.fields(field).t
case Type.Array(base) =>
base
case _ =>
// System.err.println(this)
assert(false, s"Can't path into a simple type: $t @[${subPath.mkString(", ")}] / [${path.mkString(", ")}]")
}
(
fieldPath,
pathTypes.get(fieldPath)
.map { Typecheck.glb(_, fieldType, schema) }
.getOrElse { fieldType }
)
}
._2
/**
 * The type of variable `v`: its bound type, or the type at the path it is
 * bound to.  Wrapped in Type.Option if `v` is an optional binding.
 * Fails an assertion if `v` is not in scope.
 */
def typeOfVar(v: String): Type =
{
val t =
scope.get(v)
.getOrElse {
assert(false, s"Unbound variable: $v")
} match {
case PathVar(path) => typeAtPath(path)
case t:Type => t
}
if(optionalBindings(v)){ Type.Option(t) } else { t }
}
/**
 * Flatten the scope to plain types: path-bound variables are resolved with
 * typeAtPath, and optional bindings are wrapped in Type.Option.
 */
def asScope: Map[String, Type] =
scope.map { case (v, t) =>
val base = t match {
case t:Type => t
case PathVar(path) => typeAtPath(path)
}
v -> (if(optionalBindings(v)){ Type.Option(base) } else { base })
}.toMap
/**
 * The path that variable `v` is bound to.  Fails an assertion if `v` is
 * unbound, or is bound to a type rather than a path.
 */
def pathOfVar(v: String): Seq[Int] =
scope.get(v)
.getOrElse {
assert(false, s"Unbound variable: $v")
} match {
case PathVar(path) => path
case t:Type =>
assert(false, s"Expected $v to be bound to a path")
}
/**
 * Record type `t` for `path`.  If a type is already recorded there, the two
 * are combined with Typecheck.glb.
 */
def bindPath(path: Seq[Int], t: Type): State =
copy(
pathTypes =
pathTypes ++ Map(
path ->
pathTypes.get(path)
.map { Typecheck.glb(_, t, schema) }
.getOrElse(t)
)
)
/**
 * Bind variable `v` to the position `path`, replacing any existing binding.
 */
def bindVarToPath(v: String, path: Seq[Int]): State =
copy(
scope = scope ++ Map(
v -> PathVar(path)
)
)
/**
 * Bind variable `v` directly to type `t`, replacing any existing binding.
 */
def bindVarToType(v: String, t: Type): State =
copy(
scope = scope ++ Map(
v -> t
)
)
/**
 * Merge this state with `other`, where the two come from alternative
 * branches of a pattern.
 *
 * NOTE(review): if the two root types have no lubOption, the root entry is
 * dropped from pathTypes and the constructor assertion would fire -- confirm
 * whether that is the intended failure mode.
 */
def or(other: State): State =
{
State(
// Variables bound on both sides get the Typecheck.lub of their types,
// except when both are bound to the same path, in which case the path
// binding is kept.  Variables bound on one side only keep their
// (resolved) type and are marked optional below.
scope =
(scope.keySet ++ other.scope.keySet).toSeq.flatMap { v =>
(scope.get(v), other.scope.get(v)) match {
case (Some(myT:Type), Some(otherT:Type)) =>
Some(v -> Typecheck.lub(myT, otherT, schema))
case (Some(myT:Type), Some(PathVar(otherVPath))) =>
Some(v -> Typecheck.lub(myT, other.typeAtPath(otherVPath), schema))
case (Some(PathVar(myVPath)), Some(otherT:Type)) =>
Some(v -> Typecheck.lub(typeAtPath(myVPath), otherT, schema))
case (Some(PathVar(myVPath)), Some(PathVar(otherVPath))) =>
if(myVPath != otherVPath){
Some(v -> Typecheck.lub(typeAtPath(myVPath), other.typeAtPath(otherVPath), schema))
} else {
Some(v -> PathVar(myVPath))
}
case (Some(t:Type), None) =>
Some(v -> t) // optional fields get combined later
case (Some(PathVar(path)), None) =>
Some(v -> typeAtPath(path)) // optional fields get combined later
case (None, Some(t:Type)) =>
Some(v -> t) // optional fields get combined later
case (None, Some(PathVar(path))) =>
Some(v -> other.typeAtPath(path)) // optional fields get combined later
case (None, None) =>
None
}
}.toMap,
// A path keeps a recorded type only if both sides recorded one and the
// two types unify; every other path is dropped.
pathTypes =
(pathTypes.keySet ++ other.pathTypes.keySet).toSeq.flatMap { p =>
(pathTypes.get(p), other.pathTypes.get(p)) match {
case (Some(t), Some(ot)) =>
Typecheck.lubOption(t, ot, schema).map { p -> _ }
case (_, None) =>
None
case (None, _) =>
None
}
}.toMap,
// Optional on either side stays optional; a variable missing from one
// side becomes optional.
optionalBindings =
optionalBindings ++ other.optionalBindings ++
((scope.keySet ++ other.scope.keySet)
.filter { v =>
!(scope contains v) || !(other.scope contains v)
}
),
schema = schema
)
}
/**
 * Re-root this state at `prefix`.
 *
 * Path-bound variables under `prefix` become relative to it; path-bound
 * variables elsewhere are replaced by their resolved type.  Only the
 * recorded path types that start with `prefix` are kept, with the prefix
 * stripped.
 *
 * NOTE(review): if no recorded path equals `prefix` exactly, the resulting
 * pathTypes has no root entry and the constructor assertion would fire --
 * confirm that callers always record a type at `prefix` first.
 */
def narrowToPrefix(prefix: Seq[Int]): State =
copy(
scope =
scope.mapValues {
case t:Type => t
case PathVar(p) =>
(
if(p startsWith prefix) { PathVar(p.drop(prefix.length)) }
else { typeAtPath(p) }
):(PathVar|Type)
}.toMap[String, PathVar|Type],
pathTypes =
pathTypes.filterKeys { _ startsWith prefix }
.map { case (p, v) => p.drop(prefix.length) -> v }
.toMap
)
/**
 * Fold `target`, a state expressed relative to `prefix`, back into this one.
 *
 * Scope entries of `target` that differ from this state's own view narrowed
 * to `prefix` are added (path bindings are re-prefixed); `target`'s recorded
 * path types are re-prefixed and take precedence over existing entries; the
 * optional bindings are replaced by `target`'s.
 */
def expandPrefix(prefix: Seq[Int], target: State): State =
{
val cmp = narrowToPrefix(prefix)
copy(
scope =
scope ++ ((
// include only entries that differ from the narrowed view of this scope
target.scope.toSet -- cmp.scope.toSet
).toMap.mapValues {
case t:Type => t:(Type|PathVar)
case PathVar(p) => PathVar(prefix ++ p):(Type|PathVar)
}),
pathTypes =
pathTypes ++
target.pathTypes.map {
case (path, t) =>
(prefix ++ path) -> t
},
optionalBindings =
target.optionalBindings
)
}
}
object State
{
  /**
   * A state with only the schema's globals in scope and no optional bindings.
   *
   * Every State must carry a type for the root path (the constructor asserts
   * this), so the root is typed as `Type.Any`, the least specific type.
   * Passing an empty `pathTypes` map here made every call fail that assertion.
   *
   * @param schema  the language definition supplying the globals
   * @return        a fresh state whose only recorded path is the root
   */
  def empty(schema: Definition) =
    new State(
      scope = schema.globals,
      pathTypes = Map(Seq.empty[Int] -> (Type.Any: Type)),
      optionalBindings = Set.empty,
      schema = schema
    )
}
/**
 * Raised when a match pattern fails to typecheck.
 *
 * @param msg    description of the failure
 * @param stack  the patterns reported alongside the failure
 */
case class MatchTypecheckError(msg: String, stack: List[Match]) extends Exception
{
  /**
   * The stack entries, one per block, each followed by an "in" marker,
   * with the failure description last.
   */
  override def getMessage(): String =
  {
    val separator = "\n ,--in--^\n"
    val frames = stack.map(_.toString).mkString(separator)
    s"$frames$separator$msg"
  }
}
def apply(pattern: Match, family: String, schema: Definition, scope: Map[String, Type]): Map[String, Type] =
apply(pattern, None, Type.AST(family), schema, scope)
def apply(pattern: Match, target: Option[String], expectedType: Type, schema: Definition, scope: Map[String, Type]): Map[String, Type] =
def check(
pattern: Match,
targetPath: Seq[Int],
state: State,
): State =
{
def checkAllChildren() =
pattern.children.foreach { apply(_, target, expectedType, schema, scope) }
def schema = state.schema
def astSchema(label: String): Node =
{
@ -27,7 +255,7 @@ object TypecheckMatch
}
def astSchemaWithFamily(family: String, label: String): Node =
def astSchemaWithFamily(family: Type.AST, label: String): Node =
{
assert(schema.nodes contains family, s"AST Family '$family' is not defined")
val astSchema = schema.nodes(family)
@ -36,54 +264,44 @@ object TypecheckMatch
node.get
}
def typeOfCurrentNode: Type =
state.typeAtPath(targetPath)
try {
pattern match {
case Match.Not(child) => apply(child, target, expectedType, schema, scope)
case Match.Not(child) =>
check(child, targetPath, state)
case Match.And(children) =>
children.foldLeft(scope) { (scope, child) =>
apply(child, target, expectedType, schema, scope)
children.foldLeft(state) { (state, child) =>
check(child, targetPath, state)
}
case Match.Or(children) =>
children.flatMap { apply(_, target, expectedType, schema, scope).toSeq }
.groupBy { _._1 }
.mapValues { childTypes =>
val base =
Type.union(childTypes.map { _._2 })
if(childTypes.size == children.size){ base }
else { Type.Option(base) }
}
.toMap
case Match.Or(Seq()) =>
state
case Match.Or(children) =>
children.tail
.foldLeft(check(children.head, targetPath, state)) {
(stateAccum, child) =>
stateAccum or check(child, targetPath, state)
}
case Match.OfType(t) =>
assert(Typecheck.escalatesTo(t, expectedType, schema),
s"Matching a node by type; Type $t can not possibly matched by an element of type $expectedType: $pattern"
assert(Typecheck.escalatesTo(t, typeOfCurrentNode, schema),
s"Matching a node by type; Type $t can not possibly matched by an element of type $typeOfCurrentNode: @[${targetPath.mkString(", ")}]: $pattern"
)
scope ++ target.map { _ -> t }.toMap
state.bindPath(targetPath, t)
case Match.Exact(child) =>
assert(Typecheck.escalatesTo(child.t, expectedType, schema),
s"Matching a node by value; Value $child can not possibly match an element of type $expectedType: $pattern"
assert(Typecheck.escalatesTo(child.t, typeOfCurrentNode, schema),
s"Matching a node by value; Value $child can not possibly match an element of type $typeOfCurrentNode: @[${targetPath.mkString(", ")}]: $pattern"
)
scope
state
case Match.Path(path, child) =>
val targetType =
path.foldLeft(expectedType) { (t, field) =>
t match {
case Type.Node(label) =>
val nodeSchema = astSchema(label)
assert(nodeSchema.fields.size > field, s"Can't path into the $field'th element of a $t: $pattern")
nodeSchema.fields(field).t
case Type.Array(base) =>
base
case _ =>
assert(false, s"Can't path into a simple type: $t: $pattern")
}
}
apply(child, target, targetType, schema, scope)
check(child, targetPath ++ path, state)
case Match.Recursive(_) =>
// not entirely sure how to typecheck this...
@ -91,56 +309,63 @@ object TypecheckMatch
???
case Match.Bind(symbol, child) =>
apply(child, Some(symbol), expectedType, schema, scope ++ Map(symbol -> expectedType))
check(child, targetPath, state.bindVarToPath(symbol, targetPath))
case Match.Any =>
scope
state
case Match.Fail =>
Map.empty
State.empty(schema)
case Match.Lookup(symbol) =>
assert(scope contains symbol, s"$symbol is not bound: $pattern")
assert(Typecheck.escalatesTo(scope(symbol), expectedType, schema), s"$symbol (of type ${scope(symbol)}) can not possibly match an element of type $expectedType")
scope
assert(Typecheck.escalatesTo(
state.typeOfVar(symbol),
typeOfCurrentNode,
schema
), s"$symbol (of type ${state.typeOfVar(symbol)}) can not possibly match an element of type $typeOfCurrentNode")
state
case Match.ApplyToScope(symbol, child) =>
assert(scope contains symbol, s"$symbol is not bound: $pattern")
apply(child, target, scope(symbol), schema, scope)
state.expandPrefix(
state.pathOfVar(symbol),
check(
child,
Seq.empty,
state.narrowToPrefix(targetPath)
)
)
case Match.Forall(child) =>
expectedType match {
case Type.Array(base) =>
apply(child, target, base, schema, scope)
case _ =>
assert(false, s"Forall match on something other than an array: $pattern")
}
check(child, targetPath :+ 0, state)
case Match.Test(op) =>
assert(Typecheck.escalatesTo(TypecheckExpression(op, schema, scope), Type.Bool, schema),
assert(Typecheck.escalatesTo(TypecheckExpression(op, schema, state.asScope), Type.Bool, schema),
s"Non-boolean test expression: $op"
)
scope
state
case Match.BindExpression(symbol, op) =>
scope ++ Map(
symbol -> TypecheckExpression(op, schema, scope)
)
val exprType = TypecheckExpression(op, schema, state.asScope)
state.bindVarToType(symbol, exprType)
case Match.Node(label, fields) =>
assert(expectedType.isInstanceOf[Type.AST], s"Matching a node when $expectedType was expected: $pattern")
val family = expectedType.asInstanceOf[Type.AST].family
val nodeSchema = astSchemaWithFamily(family, label)
assert(nodeSchema.fields.size == fields.size, s"Node type $family.$label expects ${nodeSchema.fields.size} but pattern has ${fields.size}: $pattern")
var runningScope = scope
for( (Field(fieldName, fieldType), fieldMatch) <- nodeSchema.fields.zip(fields) )
{
runningScope = apply(fieldMatch, target, fieldType, schema, runningScope)
assert(
Typecheck.escalatesTo(
Type.Node(label),
typeOfCurrentNode,
schema
),
s"Can't possibly match a Node[$label] with an item of type $typeOfCurrentNode: $pattern"
)
fields.zipWithIndex.foldLeft(
state.bindPath(targetPath, Type.Node(label))
) { case (state, (fieldPattern, fieldIdx)) =>
check(fieldPattern, targetPath :+ fieldIdx, state)
}
runningScope ++ target.map { _ -> Type.Node(label) }.toMap
}
} catch {
case e: AssertionError =>
// e.printStackTrace()
throw new MatchTypecheckError(e.getMessage(), pattern::Nil)
case TypecheckExpression.ExpressionTypecheckError(msg, stack) =>
throw new MatchTypecheckError(s"${stack.head}\n ,--in--^\n$msg", pattern::Nil)

View File

@ -0,0 +1,20 @@
@import com.astraldb.spec.Definition
@import com.astraldb.spec.Type
@import com.astraldb.bdd.BDD
@import com.astraldb.codegen
@import com.astraldb.codegen.Code
@(schema: Definition, family: Type.AST, bdd: BDD)
object @{family.scalaType}BDD extends Rule[@{family.scalaType}]
{
def apply(plan: @{family.scalaType}): @{family.scalaType} =
{
@{codegen.BDD(schema, family, bdd,
onFail = Code.Literal("plan"),
root = Code.Literal("plan")
)
.toString(indent = 4)
.stripPrefix(" ")}
}
}

View File

@ -9,15 +9,14 @@
@(schema: Definition, rule: Rule)
object @{rule.safeLabel} extends Rule[LogicalPlan]
object @{rule.safeLabel} extends Rule[@{rule.family.scalaType}]
{
def apply(plan: LogicalPlan): LogicalPlan =
def apply(plan: @{rule.family.scalaType}): @{rule.family.scalaType} =
{
@{
val matchSchema = TypecheckMatch(
rule.pattern,
Some("plan"),
Type.AST(rule.family),
rule.family,
schema,
schema.globals
)
@ -26,7 +25,8 @@ object @{rule.safeLabel} extends Rule[LogicalPlan]
schema = schema,
pattern = rule.pattern,
target = Code.Literal("plan"),
targetType = Type.AST(rule.family),
targetPath = Seq.empty,
targetType = rule.family,
onSuccess = Expression(schema, rule.rewrite, _),
onFail = { _ => Code.Literal("plan") },
name = Some("plan"),

View File

@ -106,38 +106,37 @@ object astral extends Module
def scalaVersion = "2.13.8"
def generatedSources = T{ astral.catalyst.rendered() }
def moduleDeps = Seq(astral.compiler, astral.catalyst)
def ivyDeps = Agg(
ivy"org.apache.spark::spark-sql::3.4.1",
)
def internalJavaVersion = T {
try {
val jvm = System.getProperties().getProperty("java.version")
println(f"Running Vizier with `${jvm}`")
jvm.split("\\.")(0).toInt
} catch {
case _:NumberFormatException | _:ArrayIndexOutOfBoundsException =>
println("Unable to retrieve java version. Guessing 11+")
11
}
}
def internalJavaVersion = T {
try {
val jvm = System.getProperties().getProperty("java.version")
println(f"Running Catalyst lite with `${jvm}`")
jvm.split("\\.")(0).toInt
} catch {
case _:NumberFormatException | _:ArrayIndexOutOfBoundsException =>
println("Unable to retrieve java version. Guessing 11+")
11
}
}
def forkArgs = T {
if(internalJavaVersion() >= 11){
Seq(
// Required on Java 11+ for Arrow compatibility
// per: https://spark.apache.org/docs/latest/index.html
"-Dio.netty.tryReflectionSetAccessible=true",
def forkArgs = T {
if(internalJavaVersion() >= 11){
Seq(
// Required on Java 11+ for Arrow compatibility
// per: https://spark.apache.org/docs/latest/index.html
"-Dio.netty.tryReflectionSetAccessible=true",
// Required for Spark on java 11+
// per: https://stackoverflow.com/questions/72230174/java-17-solution-for-spark-java-lang-noclassdeffounderror-could-not-initializ
"--add-exports", "java.base/sun.nio.ch=ALL-UNNAMED",
"--add-opens", "java.base/sun.nio.ch=ALL-UNNAMED",
)
} else { Seq[String]() }
}
// Required for Spark on java 11+
// per: https://stackoverflow.com/questions/72230174/java-17-solution-for-spark-java-lang-noclassdeffounderror-could-not-initializ
"--add-exports", "java.base/sun.nio.ch=ALL-UNNAMED",
"--add-opens", "java.base/sun.nio.ch=ALL-UNNAMED",
)
} else { Seq[String]() }
}
}
}
}