From 4bb865e6dfa0e03e5f1b121d867866c27580290b Mon Sep 17 00:00:00 2001 From: Oliver Kennedy Date: Tue, 11 Jul 2023 19:23:38 -0400 Subject: [PATCH] Progress on BDDs: Generating BDD structure (still need codegen) --- .../src/com/astraldb/catalyst/Astral.scala | 2 +- .../compiler/src/com/astraldb/bdd/BDD.scala | 56 ++++++ .../src/com/astraldb/bdd/Candidate.scala | 28 +++ .../src/com/astraldb/bdd/Compiler.scala | 163 +++++++++++++++++- .../src/com/astraldb/bdd/Pathed.scala | 35 +++- .../compiler/src/com/astraldb/bdd/State.scala | 40 ++++- .../src/com/astraldb/bdd/Target.scala | 113 +++++++++++- .../com/astraldb/expression/Expression.scala | 11 +- .../src/com/astraldb/spec/Match.scala | 44 ++++- 9 files changed, 482 insertions(+), 10 deletions(-) create mode 100644 astral/compiler/src/com/astraldb/bdd/Candidate.scala diff --git a/astral/catalyst/src/com/astraldb/catalyst/Astral.scala b/astral/catalyst/src/com/astraldb/catalyst/Astral.scala index 7578cbd..225e03c 100644 --- a/astral/catalyst/src/com/astraldb/catalyst/Astral.scala +++ b/astral/catalyst/src/com/astraldb/catalyst/Astral.scala @@ -19,7 +19,7 @@ object Astral } println( - BDDCompiler.targets(definition).mkString("\n-----\n") + BDDCompiler.bdd(definition) ) } diff --git a/astral/compiler/src/com/astraldb/bdd/BDD.scala b/astral/compiler/src/com/astraldb/bdd/BDD.scala index 1b94896..12fb67a 100644 --- a/astral/compiler/src/com/astraldb/bdd/BDD.scala +++ b/astral/compiler/src/com/astraldb/bdd/BDD.scala @@ -2,13 +2,35 @@ package com.astraldb.bdd import com.astraldb.expression._ import com.astraldb.spec._ +import com.astraldb.codegen.Code sealed trait BDD +{ + def code: Code + override def toString(): String = + code.toString +} case class PickByType( path: Seq[Int], subtrees: Map[Type, BDD] ) extends BDD +{ + def code = + Code.Parens( + left = s"@[${path.mkString(",")}]:PickByType {", + right = s"}", + body = Code.Block( + subtrees.map { case (t, andThen) => + Code.Parens( + left = s"case $t", + right = "", + body = andThen.code + ) + }.toSeq + ) + ) +} case class PickByMatch( path: Seq[Int], @@ -16,7 +38,41 @@ case class PickByMatch( ifMatched: BDD, ifNotMatched: BDD, ) extends BDD +{ + def code = + Code.IfThenElse( + Code.Literal(Match.Path(path, matcher).toString), + ifMatched.code, + ifNotMatched.code + ) +} + +case class BindAnExpression( + symbol: String, + expression: Expression, + andThen: BDD +) extends BDD +{ + def code = + Code.Parens( + left = s"let $symbol <- $expression", + right = "", + body = andThen.code + ) +} case class Rewrite( + label: String, rewrite: Expression ) extends BDD +{ + def code = + Code.Literal(s"Rewrite with $label") +} + +case object NoRewrite extends BDD +{ + def code = + Code.Literal(s"No Rewrite Possible") +} + diff --git a/astral/compiler/src/com/astraldb/bdd/Candidate.scala b/astral/compiler/src/com/astraldb/bdd/Candidate.scala new file mode 100644 index 0000000..817de91 --- /dev/null +++ b/astral/compiler/src/com/astraldb/bdd/Candidate.scala @@ -0,0 +1,28 @@ +package com.astraldb.bdd + +import com.astraldb.spec.Match +import com.astraldb.expression.Expression +import com.astraldb.spec.Type + +sealed trait CandidateStep + +case class CheckTypeStep(path: Pathed.Path, check: Type) extends CandidateStep +object CheckTypeStep +{ + def apply(pathed: Pathed[Type]): CheckTypeStep = + CheckTypeStep(pathed.path, pathed.value) +} + +case class CheckMatchStep(path: Pathed.Path, matcher: Match) extends CandidateStep +object CheckMatchStep +{ + def apply(pathed: Pathed[Match]): CheckMatchStep = + CheckMatchStep(pathed.path, pathed.value) +} + +case class BindExpressionStep(symbol: String, expression: Expression) extends CandidateStep +object BindExpressionStep +{ + def apply(binding: (String, Expression)): BindExpressionStep = + BindExpressionStep(binding._1, binding._2) +} diff --git a/astral/compiler/src/com/astraldb/bdd/Compiler.scala b/astral/compiler/src/com/astraldb/bdd/Compiler.scala index 332ed2a..da9c116 100644 --- a/astral/compiler/src/com/astraldb/bdd/Compiler.scala +++ b/astral/compiler/src/com/astraldb/bdd/Compiler.scala @@ -3,6 +3,7 @@ package com.astraldb.bdd import com.astraldb.spec.Definition import com.astraldb.spec.Match import com.astraldb.spec.Type +import com.astraldb.expression.Expression object Compiler { @@ -92,7 +93,7 @@ object Compiler disjunctify(matcher) } - def targets(schema: Definition): Seq[Target] = + def getTargets(schema: Definition): Seq[Target] = { schema.rules.flatMap { rule => val clauses = @@ -107,4 +108,164 @@ object Compiler } } + def getSanitizedTargets(schema: Definition): Seq[Target] = + { + var targets = getTargets(schema) + + // Sanitize expression bindings via alpha renaming and + // common subexpression elimination + // * If two expressions are the same, they should + // have the same name. + + val duplicateExprs: Map[Expression, String] = + targets.flatMap { _.exprBindings } + .toSet.toSeq + .groupBy { _._2 } + .filter { _._2.size > 1 } + .mapValues { _.map { _._1 }.head } + .toMap + + targets = targets.map { target => + val repair = + target.exprBindings + .filter { b => duplicateExprs contains b._2 } + .map { b => + b._1 -> duplicateExprs(b._2) + } + .filterNot { a => a._1 != a._2 } + if(repair.isEmpty){ target } + else { + target.rename(repair:_*) + } + } + + // * If two expressions are different, they should + // have different names. + val replacementNames: Map[String, Map[Expression, String]] = + targets.flatMap { _.exprBindings } + .toSet.toSeq + .groupBy { _._1 } + .filter { _._2.size > 1 } + .map { case (name:String, exprs:Seq[(String, Expression)]) => + name -> + exprs.map { _._2 } + .zipWithIndex + .toMap + .mapValues { name + "_" + _ } + .toMap + } + .toMap + + targets = targets.map { target => + val repair = + target.exprBindings + .flatMap { case (name, expr) => + replacementNames.get(name) + .flatMap { _.get(expr).map { newName => + name -> newName + } } + } + if(repair.isEmpty) { target } + else { target.rename(repair:_*) } + } + + return targets + } + + + def bdd(schema: Definition): BDD = + { + println(getSanitizedTargets(schema).mkString("\n---\n")) + def stepBdd(targets: Seq[Target], state: State): BDD = + { + val success = targets.find { _.successfulOn(state) } + + if(success.isDefined) + { + Rewrite(success.get.label, success.get.rewrite) + } else { + val activeTargets = + targets.filter { _.canBeSuccessfulOn(state) } + + if(activeTargets.isEmpty){ + NoRewrite + } else { + + val candidates:Seq[CandidateStep] = + activeTargets + .flatMap { t => + val c = t.candidates(state) + assert(!c.isEmpty, s"Target ${t.label} in state \n$state\n is not successful or failed, but has no candidates for progress.") + /* return */ c + } + + val bestCheckTypeStep = + candidates.collect { + case check:CheckTypeStep => + check.path -> check + } + .groupBy { _._1 } + .toSeq + .maxByOption { _._2.size } + + if(bestCheckTypeStep.isDefined) + { + val path: Pathed.Path = + bestCheckTypeStep.get._1 + val types:Set[Type] = + bestCheckTypeStep.get._2.map { _._2.check }.toSet + + println(s"Checking path: [${path.mkString(", ")}] for ${types.mkString(", ")}") + + PickByType( + path, + types.toSeq.map { t => + t -> stepBdd(activeTargets, state.withType(path, t)) + }.toMap + ) + } else { // no advancement through a best check type + + val nonCheckTypeCandidates = + candidates.filterNot { _.isInstanceOf[CheckTypeStep] } + + assert(!nonCheckTypeCandidates.isEmpty, + s"No possible candidates to make progress: \nstate=$state\ntargets=\n${activeTargets.mkString("\n---\n")}\ncandidates=$candidates" + ) + + val bestOtherCandidate = + nonCheckTypeCandidates + .groupBy { x => x } + .maxBy { _._2.size } + ._1 + + + bestOtherCandidate match { + case _:CheckTypeStep => assert(false, "We filtered out all the check type steps by this point") + case CheckMatchStep(path, matcher) => + PickByMatch( + path = path, + matcher = matcher, + ifMatched = + stepBdd(activeTargets, + state.withSuccessfulMatch(path, matcher)), + ifNotMatched = + stepBdd(activeTargets, + state.withFailedMatch(path, matcher)), + ) + case BindExpressionStep(symbol, expression) => + BindAnExpression(symbol, expression, + stepBdd(activeTargets, + state.withBinding(symbol) + ) + ) + } + + } + } + } + } + + stepBdd(getSanitizedTargets(schema), State.empty(schema)) + } + } \ No newline at end of file diff --git a/astral/compiler/src/com/astraldb/bdd/Pathed.scala b/astral/compiler/src/com/astraldb/bdd/Pathed.scala index 532331d..66b699e 100644 --- a/astral/compiler/src/com/astraldb/bdd/Pathed.scala +++ b/astral/compiler/src/com/astraldb/bdd/Pathed.scala @@ -1,8 +1,41 @@ package com.astraldb.bdd -case class Pathed[T](path: Seq[Int], value: T) +case class Pathed[T](path: Pathed.Path, value: T) { override def toString(): String = s"@[${path.mkString(", ")}]: $value" + + def tuple = (path, value) + + def pathInSet(set: Set[Pathed.Path]) = + Pathed.selfAndAncestorsInSet(path, set) + + def ancestorsInSet(set: Set[Pathed.Path]) = + Pathed.ancestorsInSet(path, set) } + +object Pathed +{ + type Path = Seq[Int] + + def selfAndAncestorsInSet(path: Path, set: Set[Path]): Boolean = + { + if(!set(Seq.empty)){ false } + else { + path.foldLeft(Some(Seq[Int]()):Option[Seq[Int]]) { + case (Some(seq), elem) => + val newSeq = seq :+ elem + if(set contains (newSeq)) { Some(newSeq) } else { None } + + case (None, _) => None + }.isDefined + } + } + + def ancestorsInSet(path:Path, set: Set[Path]): Boolean = + if(path.isEmpty){ true } + else { + selfAndAncestorsInSet(path.dropRight(1), set) + } +} \ No newline at end of file diff --git a/astral/compiler/src/com/astraldb/bdd/State.scala b/astral/compiler/src/com/astraldb/bdd/State.scala index 4088222..d0d8918 100644 --- a/astral/compiler/src/com/astraldb/bdd/State.scala +++ b/astral/compiler/src/com/astraldb/bdd/State.scala @@ -4,6 +4,42 @@ import com.astraldb.expression._ import com.astraldb.spec._ case class State( - types: Pathed[Type], - matchers: Pathed[Match] + types: Set[Pathed[Type]], + passedMatchers: Set[Pathed[Match]], + failedMatchers: Set[Pathed[Match]], + bindings: Set[String] ) +{ + assert((passedMatchers & failedMatchers).isEmpty, "A matcher can't be both successful and failed.") + + lazy val typeOf = types.map { _.tuple }.toMap + + def withType(path: Pathed.Path, value: Type): State = + copy(types = types ++ Set(Pathed(path, value))) + + def withSuccessfulMatch(path: Pathed.Path, matcher: Match): State = + copy(passedMatchers = passedMatchers ++ Set(Pathed(path, matcher))) + + def withFailedMatch(path: Pathed.Path, matcher: Match): State = + copy(failedMatchers = failedMatchers ++ Set(Pathed(path, matcher))) + + def withBinding(symbol: String): State = + copy(bindings = bindings ++ Set(symbol)) + + override def toString(): String = + "=== Types ===\n"++ + types.mkString("\n")++ + "\n=== Passed Matchers ===\n"++ + passedMatchers.mkString("\n")++ + "\n=== Failed Matchers ===\n"++ + failedMatchers.mkString("\n")++ + "\n=== Bindings ===\n"++ + bindings.mkString(",") + +} + +object State +{ + def empty(schema: Definition) = + State(Set.empty, Set.empty, Set.empty, schema.globals.keySet) +} \ No newline at end of file diff --git a/astral/compiler/src/com/astraldb/bdd/Target.scala b/astral/compiler/src/com/astraldb/bdd/Target.scala index 951def8..8aee8e4 100644 --- a/astral/compiler/src/com/astraldb/bdd/Target.scala +++ b/astral/compiler/src/com/astraldb/bdd/Target.scala @@ -2,6 +2,7 @@ package com.astraldb.bdd import com.astraldb.expression._ import com.astraldb.spec._ +import scala.util.Try case class Target( label: String, @@ -29,12 +30,122 @@ case class Target( def withoutUnusedBindings = { - val refs = (rewrite.references ++ matchers.flatMap { _.value.references }).map { _.v } + val refs: Set[String] = ( + rewrite.references ++ + matchers.flatMap { _.value.references }).map { _.v } ++ + exprBindings.flatMap { _._2.references }.map { _.v} + copy( bindings = bindings.filter { b => refs(b.value) } ) } + def candidates(state: State): Seq[CandidateStep] = + { + val requestedTypedPaths = types.map { _.path }.toSet + val typedPaths = state.typeOf.keySet + + /** + * Variable bindings that have had all of their prerequisites filled. + */ + val validBindings:Set[Var] = ( + // bindings that were computed by a prior step + state.bindings.map { Var(_) } ++ + // path bindings that have been fully typed + bindings.filter { b => + // If we want the binding to be typed, we need to check up to the type of the binding + if(requestedTypedPaths(b.path)) { b.pathInSet(typedPaths) } + // If the binding itself is not typed, then we just need its ancestors bound + else { b.ancestorsInSet(typedPaths) } + }.map { b => Var(b.value) } + ).toSet + + // println(s"Valid bindings: $validBindings") + + /** + * Expression bindings that could potentially be run right now. + */ + val candidateBindings = + exprBindings + // Don't re-compute bindings that are already available + .filterNot { b => validBindings(Var(b._1)) } + // A binding needs all of its referenced variables to be bound + .filter { _._2.references.forall { validBindings(_) } } + + // println(s"Typed Paths: $typedPaths") + + val candidateTypings = + types + // Don't re-check typings that are already available + .filterNot { t => typedPaths(t.path) } + // All ancestors need to be typed + .filter { _.ancestorsInSet(typedPaths) } + + // println(s"Candidate Typings: $candidateTypings") + + val candidateMatchers = + matchers + // Don't re-check matchers that have already been computed + .filterNot { m => + // println("m1:"+m); + state.passedMatchers(m) } + // All ancestors need to be typed + .filter { m => + // println("m2:"+m); + m.ancestorsInSet(typedPaths) } + // A matcher needs all of its referenced variables to be bound + .filter { m => + // println("m3:"+m+"\n"+m.value.references); + m.value.references.forall { validBindings(_) } } + + // println(s"Candidate Matchers: $candidateMatchers") + + (candidateTypings.map { CheckTypeStep(_) }: Seq[CandidateStep]) ++ + candidateMatchers.map { CheckMatchStep(_) } ++ + candidateBindings.map { BindExpressionStep(_) } + } + + def canBeSuccessfulOn(state: State): Boolean = + types.forall { t => + state.typeOf.get(t.path) + .map { _ == t.value } + .getOrElse { true } + } && matchers.forall { m => + !(state.failedMatchers contains m) + } + + def successfulOn(state: State): Boolean = + types.forall { t => + state.typeOf.get(t.path) + .map { _ == t.value } + .getOrElse { false } + } && matchers.forall { m => + (state.passedMatchers contains m) + } && exprBindings.forall { e => + state.bindings.contains(e._1) + } + + def rename(replacements: (String, String)*): Target = + { + val replacementMap = replacements.toMap + copy( + rewrite = rewrite.rename(replacementMap), + bindings = bindings.map { b => + replacementMap.get(b.value) + .map { r => b.copy(value = r) } + .getOrElse(b) + }, + exprBindings = exprBindings.map { case (name, expr) => + ( replacementMap.getOrElse(name, name), + expr.rename(replacementMap) + ) + }, + matchers = matchers.map { m => + m.copy(value = m.value.rename(replacementMap)) + } + ) + } + override def toString(): String = label+ (if(types.isEmpty){ "" } else { diff --git a/astral/compiler/src/com/astraldb/expression/Expression.scala b/astral/compiler/src/com/astraldb/expression/Expression.scala index a08772d..6f57fff 100755 --- a/astral/compiler/src/com/astraldb/expression/Expression.scala +++ b/astral/compiler/src/com/astraldb/expression/Expression.scala @@ -79,11 +79,18 @@ sealed abstract class Expression def children: Seq[Expression] def reassemble(in: Seq[Expression]): Expression - def rebuild(fn:(Expression => Expression)): Expression = - reassemble(children.map { fn(_) }) + def transform(fn: PartialFunction[Expression,Expression]): Expression = + reassemble(children.map { _.transform(fn) }).map(fn) + def map(fn: PartialFunction[Expression,Expression]): Expression = + fn.applyOrElse(this, _ => this) def references: Set[Var] = children.flatMap { _.references }.toSet + + def rename(newNames: Map[String, String]): Expression = + transform { + case Var(v) if newNames contains v => Var(newNames(v)) + } } ///////////////////////// Constants //////////////////////////// diff --git a/astral/compiler/src/com/astraldb/spec/Match.scala b/astral/compiler/src/com/astraldb/spec/Match.scala index b17c18f..7491cef 100644 --- a/astral/compiler/src/com/astraldb/spec/Match.scala +++ b/astral/compiler/src/com/astraldb/spec/Match.scala @@ -26,8 +26,11 @@ sealed trait Match def orSeq: Seq[Match] = this match { case Match.Or(children) => children; case _ => Seq(this) } + def expressions: Seq[Expression] + def references: Set[Var] = - children.flatMap { _.references }.toSet + children.flatMap { _.references }.toSet ++ + expressions.flatMap { _.references }.toSet def and(other: Match) = Match.And(andSeq ++ other.andSeq) @@ -48,6 +51,9 @@ sealed trait Match def mapChildren(f: Function[Match, Match]): Match = reassemble(children.map(f)) + + def rename(newNames: Map[String, String]): Match = + reassemble(children.map { _.rename(newNames) }) } object Match { @@ -73,6 +79,7 @@ object Match s"Not(\n$prefix${child.toString(prefix+" ")}\n$prefix)" def children = Seq(child) def reassemble(in: Seq[Match]) = copy(child = in(0)) + def expressions: Seq[Expression] = Seq.empty } /** @@ -87,6 +94,7 @@ object Match def toString(prefix: String): String = s"And(\n$prefix ${children.map { _.toString(prefix+" ")}.mkString(",\n"+prefix+" ")}\n$prefix)" def reassemble(in: Seq[Match]) = copy(children = in) + def expressions: Seq[Expression] = Seq.empty } /** @@ -104,6 +112,7 @@ object Match def toString(prefix: String): String = s"Or(\n$prefix ${children.map { _.toString(prefix+" ")}.mkString(",\n"+prefix+" ")}\n$prefix)" def reassemble(in: Seq[Match]) = copy(children = in) + def expressions: Seq[Expression] = Seq.empty } /** @@ -128,6 +137,7 @@ object Match def toString(prefix: String): String = s"{$nodeLabel}(\n$prefix ${children.map { _.toString(prefix+" ")}.mkString(",\n"+prefix+" ")}\n$prefix)" def reassemble(in: Seq[Match]) = copy(children = in) + def expressions: Seq[Expression] = Seq.empty } /** @@ -145,6 +155,7 @@ object Match s"{$nodeType}" def children: Seq[Match] = Seq.empty def reassemble(in: Seq[Match]): Match = this + def expressions: Seq[Expression] = Seq.empty } object OfType { @@ -167,6 +178,7 @@ object Match s"Exact(${pattern})" def children: Seq[Match] = Seq.empty def reassemble(in: Seq[Match]): Match = this + def expressions: Seq[Expression] = Seq.empty } /** @@ -194,7 +206,7 @@ object Match s"@[${path.mkString(",")}](\n$prefix ${pattern.toString(prefix+" ")}\n$prefix)" def children: Seq[Match] = Seq(pattern) def reassemble(in: Seq[Match]): Match = copy(pattern = in(0)) - + def expressions: Seq[Expression] = Seq.empty } /** @@ -211,6 +223,7 @@ object Match s"Recursive(\n$prefix ${pattern.toString(prefix+" ")}\n$prefix)" def children: Seq[Match] = Seq(pattern) def reassemble(in: Seq[Match]): Match = copy(pattern = in(0)) + def expressions: Seq[Expression] = Seq.empty } /** @@ -231,6 +244,12 @@ object Match } def children: Seq[Match] = Seq(pattern) def reassemble(in: Seq[Match]): Match = copy(pattern = in(0)) + def expressions: Seq[Expression] = Seq.empty + override def rename(newNames: Map[String, String]): Match = + Bind( + symbol = newNames.getOrElse(symbol, symbol), + pattern.rename(newNames) + ) } /** @@ -244,6 +263,7 @@ object Match "_" def children: Seq[Match] = Seq.empty def reassemble(in: Seq[Match]): Match = this + def expressions: Seq[Expression] = Seq.empty } /** @@ -257,6 +277,7 @@ object Match "Fail!" def children: Seq[Match] = Seq.empty def reassemble(in: Seq[Match]): Match = this + def expressions: Seq[Expression] = Seq.empty } /** @@ -276,6 +297,9 @@ object Match def children: Seq[Match] = Seq.empty def reassemble(in: Seq[Match]): Match = this override def references = super.references ++ Set(Var(symbol)) + def expressions: Seq[Expression] = Seq.empty + override def rename(newNames: Map[String, String]): Match = + copy(symbol = newNames.getOrElse(symbol, symbol)) } /** @@ -295,6 +319,12 @@ object Match def children: Seq[Match] = Seq(pattern) def reassemble(in: Seq[Match]): Match = copy(pattern = in(0)) override def references = super.references ++ Set(Var(symbol)) + def expressions: Seq[Expression] = Seq.empty + override def rename(newNames: Map[String, String]): Match = + copy( + symbol = newNames.getOrElse(symbol, symbol), + pattern = pattern.rename(newNames) + ) } /** @@ -317,6 +347,7 @@ object Match s"@[*](\n$prefix ${pattern.toString(prefix+" ")}\n$prefix)" def children: Seq[Match] = Seq(pattern) def reassemble(in: Seq[Match]): Match = copy(pattern = in(0)) + def expressions: Seq[Expression] = Seq.empty } /** @@ -334,6 +365,9 @@ object Match def children: Seq[Match] = Seq.empty def reassemble(in: Seq[Match]): Match = this override def references = super.references ++ op.references + def expressions: Seq[Expression] = Seq(op) + override def rename(newNames: Map[String, String]): Match = + Test(op.rename(newNames)) } /** @@ -348,5 +382,11 @@ object Match def children: Seq[Match] = Seq.empty def reassemble(in: Seq[Match]): Match = this override def references = super.references ++ op.references + def expressions: Seq[Expression] = Seq(op) + override def rename(newNames: Map[String, String]): Match = + BindExpression( + symbol = newNames.getOrElse(symbol, symbol), + op = op.rename(newNames) + ) } } \ No newline at end of file