Progress on BDDs: Generating BDD structure (still need codegen)

This commit is contained in:
Oliver Kennedy 2023-07-11 19:23:38 -04:00
parent b87d3a01b0
commit 4bb865e6df
Signed by: okennedy
GPG key ID: 3E5F9B3ABD3FDB60
9 changed files with 482 additions and 10 deletions

View file

@ -19,7 +19,7 @@ object Astral
}
println(
BDDCompiler.targets(definition).mkString("\n-----\n")
BDDCompiler.bdd(definition)
)
}

View file

@ -2,13 +2,35 @@ package com.astraldb.bdd
import com.astraldb.expression._
import com.astraldb.spec._
import com.astraldb.codegen.Code
sealed trait BDD
{
def code: Code
override def toString(): String =
code.toString
}
case class PickByType(
path: Seq[Int],
subtrees: Map[Type, BDD]
) extends BDD
{
def code =
Code.Parens(
left = s"@[${path.mkString(",")}]:PickByType {",
right = s"}",
body = Code.Block(
subtrees.map { case (t, andThen) =>
Code.Parens(
left = s"case $t",
right = "",
body = andThen.code
)
}.toSeq
)
)
}
case class PickByMatch(
path: Seq[Int],
@ -16,7 +38,41 @@ case class PickByMatch(
ifMatched: BDD,
ifNotMatched: BDD,
) extends BDD
{
def code =
Code.IfThenElse(
Code.Literal(Match.Path(path, matcher).toString),
ifMatched.code,
ifNotMatched.code
)
}
case class BindAnExpression(
symbol: String,
expression: Expression,
andThen: BDD
) extends BDD
{
def code =
Code.Parens(
left = s"let $symbol <- $expression",
right = "",
body = andThen.code
)
}
case class Rewrite(
label: String,
rewrite: Expression
) extends BDD
{
def code =
Code.Literal(s"Rewrite with $label")
}
case object NoRewrite extends BDD
{
def code =
Code.Literal(s"No Rewrite Possible")
}

View file

@ -0,0 +1,28 @@
package com.astraldb.bdd
import com.astraldb.spec.Match
import com.astraldb.expression.Expression
import com.astraldb.spec.Type
sealed trait CandidateStep
case class CheckTypeStep(path: Pathed.Path, check: Type) extends CandidateStep
object CheckTypeStep
{
def apply(pathed: Pathed[Type]): CheckTypeStep =
CheckTypeStep(pathed.path, pathed.value)
}
case class CheckMatchStep(path: Pathed.Path, matcher: Match) extends CandidateStep
object CheckMatchStep
{
def apply(pathed: Pathed[Match]): CheckMatchStep =
CheckMatchStep(pathed.path, pathed.value)
}
case class BindExpressionStep(symbol: String, expression: Expression) extends CandidateStep
object BindExpressionStep
{
def apply(binding: (String, Expression)): BindExpressionStep =
BindExpressionStep(binding._1, binding._2)
}

View file

@ -3,6 +3,7 @@ package com.astraldb.bdd
import com.astraldb.spec.Definition
import com.astraldb.spec.Match
import com.astraldb.spec.Type
import com.astraldb.expression.Expression
object Compiler
{
@ -92,7 +93,7 @@ object Compiler
disjunctify(matcher)
}
def targets(schema: Definition): Seq[Target] =
def getTargets(schema: Definition): Seq[Target] =
{
schema.rules.flatMap { rule =>
val clauses =
@ -107,4 +108,164 @@ object Compiler
}
}
def getSanitizedTargets(schema: Definition): Seq[Target] =
{
var targets = getTargets(schema)
// Sanitize expression bindings via alpha renaming and
// common subexpression elimination
// * If two expressions are the same, they should
// have the same name.
val duplicateExprs: Map[Expression, String] =
targets.flatMap { _.exprBindings }
.toSet.toSeq
.groupBy { _._2 }
.filter { _._2.size > 1 }
.mapValues { _.map { _._1 }.head }
.toMap
targets = targets.map { target =>
val repair =
target.exprBindings
.filter { b => duplicateExprs contains b._2 }
.map { b =>
b._1 -> duplicateExprs(b._2)
}
.filterNot { a => a._1 != a._2 }
if(repair.isEmpty){ target }
else {
target.rename(repair:_*)
}
}
// * If two expressions are different, they should
// have different names.
val replacementNames: Map[String, Map[Expression, String]] =
targets.flatMap { _.exprBindings }
.toSet.toSeq
.groupBy { _._1 }
.filter { _._2.size > 1 }
.map { case (name:String, exprs:Seq[(String, Expression)]) =>
name ->
exprs.map { _._2 }
.zipWithIndex
.toMap
.mapValues { name + "_" + _ }
.toMap
}
.toMap
targets = targets.map { target =>
val repair =
target.exprBindings
.flatMap { case (name, expr) =>
replacementNames.get(name)
.flatMap { _.get(expr).map { newName =>
name -> newName
} }
}
if(repair.isEmpty) { target }
else { target.rename(repair:_*) }
}
return targets
}
def bdd(schema: Definition): BDD =
{
println(getSanitizedTargets(schema).mkString("\n---\n"))
def stepBdd(targets: Seq[Target], state: State): BDD =
{
val success = targets.find { _.successfulOn(state) }
if(success.isDefined)
{
Rewrite(success.get.label, success.get.rewrite)
} else {
val activeTargets =
targets.filter { _.canBeSuccessfulOn(state) }
if(activeTargets.isEmpty){
NoRewrite
} else {
val candidates:Seq[CandidateStep] =
activeTargets
.flatMap { t =>
val c = t.candidates(state)
assert(!c.isEmpty, s"Target ${t.label} in state \n$state\n is not successful or failed, but has no candidates for progress.")
/* return */ c
}
val bestCheckTypeStep =
candidates.collect {
case check:CheckTypeStep =>
check.path -> check
}
.groupBy { _._1 }
.toSeq
.maxByOption { _._2.size }
if(bestCheckTypeStep.isDefined)
{
val path: Pathed.Path =
bestCheckTypeStep.get._1
val types:Set[Type] =
bestCheckTypeStep.get._2.map { _._2.check }.toSet
println(s"Checking path: [${path.mkString(", ")}] for ${types.mkString(", ")}")
PickByType(
path,
types.toSeq.map { t =>
t -> stepBdd(activeTargets, state.withType(path, t))
}.toMap
)
} else { // no advancement through a best check type
val nonCheckTypeCandidates =
candidates.filterNot { _.isInstanceOf[CheckTypeStep] }
assert(!nonCheckTypeCandidates.isEmpty,
s"No possible candidates to make progress: \nstate=$state\ntargets=\n${activeTargets.mkString("\n---\n")}\ncandidates=$candidates"
)
val bestOtherCandidate =
nonCheckTypeCandidates
.groupBy { x => x }
.maxBy { _._2.size }
._1
bestOtherCandidate match {
case _:CheckTypeStep => assert(false, "We filtered out all the check type steps by this point")
case CheckMatchStep(path, matcher) =>
PickByMatch(
path = path,
matcher = matcher,
ifMatched =
stepBdd(activeTargets,
state.withSuccessfulMatch(path, matcher)),
ifNotMatched =
stepBdd(activeTargets,
state.withFailedMatch(path, matcher)),
)
case BindExpressionStep(symbol, expression) =>
BindAnExpression(symbol, expression,
stepBdd(activeTargets,
state.withBinding(symbol)
)
)
}
}
}
}
}
stepBdd(getSanitizedTargets(schema), State.empty(schema))
}
}

View file

@ -1,8 +1,41 @@
package com.astraldb.bdd
case class Pathed[T](path: Seq[Int], value: T)
case class Pathed[T](path: Pathed.Path, value: T)
{
override def toString(): String =
s"@[${path.mkString(", ")}]: $value"
def tuple = (path, value)
def pathInSet(set: Set[Pathed.Path]) =
Pathed.selfAndAncestorsInSet(path, set)
def ancestorsInSet(set: Set[Pathed.Path]) =
Pathed.ancestorsInSet(path, set)
}
object Pathed
{
type Path = Seq[Int]
def selfAndAncestorsInSet(path: Path, set: Set[Path]): Boolean =
{
if(!set(Seq.empty)){ false }
else {
path.foldLeft(Some(Seq[Int]()):Option[Seq[Int]]) {
case (Some(seq), elem) =>
val newSeq = seq :+ elem
if(set contains (newSeq)) { Some(newSeq) } else { None }
case (None, _) => None
}.isDefined
}
}
def ancestorsInSet(path:Path, set: Set[Path]): Boolean =
if(path.isEmpty){ true }
else {
selfAndAncestorsInSet(path.dropRight(1), set)
}
}

View file

@ -4,6 +4,42 @@ import com.astraldb.expression._
import com.astraldb.spec._
case class State(
types: Pathed[Type],
matchers: Pathed[Match]
types: Set[Pathed[Type]],
passedMatchers: Set[Pathed[Match]],
failedMatchers: Set[Pathed[Match]],
bindings: Set[String]
)
{
assert((passedMatchers & failedMatchers).isEmpty, "A matcher can't be both successful and failed.")
lazy val typeOf = types.map { _.tuple }.toMap
def withType(path: Pathed.Path, value: Type): State =
copy(types = types ++ Set(Pathed(path, value)))
def withSuccessfulMatch(path: Pathed.Path, matcher: Match): State =
copy(passedMatchers = passedMatchers ++ Set(Pathed(path, matcher)))
def withFailedMatch(path: Pathed.Path, matcher: Match): State =
copy(failedMatchers = failedMatchers ++ Set(Pathed(path, matcher)))
def withBinding(symbol: String): State =
copy(bindings = bindings ++ Set(symbol))
override def toString(): String =
"=== Types ===\n"++
types.mkString("\n")++
"\n=== Passed Matchers ===\n"++
passedMatchers.mkString("\n")++
"\n=== Failed Matchers ===\n"++
failedMatchers.mkString("\n")++
"\n=== Bindings ===\n"++
bindings.mkString(",")
}
object State
{
def empty(schema: Definition) =
State(Set.empty, Set.empty, Set.empty, schema.globals.keySet)
}

View file

@ -2,6 +2,7 @@ package com.astraldb.bdd
import com.astraldb.expression._
import com.astraldb.spec._
import scala.util.Try
case class Target(
label: String,
@ -29,12 +30,122 @@ case class Target(
def withoutUnusedBindings =
{
val refs = (rewrite.references ++ matchers.flatMap { _.value.references }).map { _.v }
val refs: Set[String] = (
rewrite.references ++
matchers.flatMap { _.value.references }).map { _.v } ++
exprBindings.flatMap { _._2.references }.map { _.v}
copy(
bindings = bindings.filter { b => refs(b.value) }
)
}
def candidates(state: State): Seq[CandidateStep] =
{
val requestedTypedPaths = types.map { _.path }.toSet
val typedPaths = state.typeOf.keySet
/**
* Variable bindings that have had all of their prerequisites filled.
*/
val validBindings:Set[Var] = (
// bindings that were computed by a prior step
state.bindings.map { Var(_) } ++
// path bindings that have been fully typed
bindings.filter { b =>
// If we want the binding to be typed, we need to check up to the type of the binding
if(requestedTypedPaths(b.path)) { b.pathInSet(typedPaths) }
// If the binding itself is not typed, then we just need its ancestors bound
else { b.ancestorsInSet(typedPaths) }
}.map { b => Var(b.value) }
).toSet
// println(s"Valid bindings: $validBindings")
/**
* Expression bindings that could potentially be run right now.
*/
val candidateBindings =
exprBindings
// Don't re-compute bindings that are already available
.filterNot { b => validBindings(Var(b._1)) }
// A binding needs all of its referenced variables to be bound
.filter { _._2.references.forall { validBindings(_) } }
// println(s"Typed Paths: $typedPaths")
val candidateTypings =
types
// Don't re-check typings that are already available
.filterNot { t => typedPaths(t.path) }
// All ancestors need to be typed
.filter { _.ancestorsInSet(typedPaths) }
// println(s"Candidate Typings: $candidateTypings")
val candidateMatchers =
matchers
// Don't re-check matchers that have already been computed
.filterNot { m =>
// println("m1:"+m);
state.passedMatchers(m) }
// All ancestors need to be typed
.filter { m =>
// println("m2:"+m);
m.ancestorsInSet(typedPaths) }
// A matcher needs all of its referenced variables to be bound
.filter { m =>
// println("m3:"+m+"\n"+m.value.references);
m.value.references.forall { validBindings(_) } }
// println(s"Candidate Matchers: $candidateMatchers")
(candidateTypings.map { CheckTypeStep(_) }: Seq[CandidateStep]) ++
candidateMatchers.map { CheckMatchStep(_) } ++
candidateBindings.map { BindExpressionStep(_) }
}
def canBeSuccessfulOn(state: State): Boolean =
types.forall { t =>
state.typeOf.get(t.path)
.map { _ == t.value }
.getOrElse { true }
} && matchers.forall { m =>
!(state.failedMatchers contains m)
}
def successfulOn(state: State): Boolean =
types.forall { t =>
state.typeOf.get(t.path)
.map { _ == t.value }
.getOrElse { false }
} && matchers.forall { m =>
(state.passedMatchers contains m)
} && exprBindings.forall { e =>
state.bindings.contains(e._1)
}
def rename(replacements: (String, String)*): Target =
{
val replacementMap = replacements.toMap
copy(
rewrite = rewrite.rename(replacementMap),
bindings = bindings.map { b =>
replacementMap.get(b.value)
.map { r => b.copy(value = r) }
.getOrElse(b)
},
exprBindings = exprBindings.map { case (name, expr) =>
( replacementMap.getOrElse(name, name),
expr.rename(replacementMap)
)
},
matchers = matchers.map { m =>
m.copy(value = m.value.rename(replacementMap))
}
)
}
override def toString(): String =
label+
(if(types.isEmpty){ "" } else {

View file

@ -79,11 +79,18 @@ sealed abstract class Expression
def children: Seq[Expression]
def reassemble(in: Seq[Expression]): Expression
def rebuild(fn:(Expression => Expression)): Expression =
reassemble(children.map { fn(_) })
def transform(fn: PartialFunction[Expression,Expression]): Expression =
reassemble(children.map { _.transform(fn) }).map(fn)
def map(fn: PartialFunction[Expression,Expression]): Expression =
fn.applyOrElse(this, _ => this)
def references: Set[Var] =
children.flatMap { _.references }.toSet
def rename(newNames: Map[String, String]): Expression =
transform {
case Var(v) if newNames contains v => Var(newNames(v))
}
}
///////////////////////// Constants ////////////////////////////

View file

@ -26,8 +26,11 @@ sealed trait Match
def orSeq: Seq[Match] =
this match { case Match.Or(children) => children; case _ => Seq(this) }
def expressions: Seq[Expression]
def references: Set[Var] =
children.flatMap { _.references }.toSet
children.flatMap { _.references }.toSet ++
expressions.flatMap { _.references }.toSet
def and(other: Match) =
Match.And(andSeq ++ other.andSeq)
@ -48,6 +51,9 @@ sealed trait Match
def mapChildren(f: Function[Match, Match]): Match =
reassemble(children.map(f))
def rename(newNames: Map[String, String]): Match =
reassemble(children.map { _.rename(newNames) })
}
object Match
{
@ -73,6 +79,7 @@ object Match
s"Not(\n$prefix${child.toString(prefix+" ")}\n$prefix)"
def children = Seq(child)
def reassemble(in: Seq[Match]) = copy(child = in(0))
def expressions: Seq[Expression] = Seq.empty
}
/**
@ -87,6 +94,7 @@ object Match
def toString(prefix: String): String =
s"And(\n$prefix ${children.map { _.toString(prefix+" ")}.mkString(",\n"+prefix+" ")}\n$prefix)"
def reassemble(in: Seq[Match]) = copy(children = in)
def expressions: Seq[Expression] = Seq.empty
}
/**
@ -104,6 +112,7 @@ object Match
def toString(prefix: String): String =
s"Or(\n$prefix ${children.map { _.toString(prefix+" ")}.mkString(",\n"+prefix+" ")}\n$prefix)"
def reassemble(in: Seq[Match]) = copy(children = in)
def expressions: Seq[Expression] = Seq.empty
}
/**
@ -128,6 +137,7 @@ object Match
def toString(prefix: String): String =
s"{$nodeLabel}(\n$prefix ${children.map { _.toString(prefix+" ")}.mkString(",\n"+prefix+" ")}\n$prefix)"
def reassemble(in: Seq[Match]) = copy(children = in)
def expressions: Seq[Expression] = Seq.empty
}
/**
@ -145,6 +155,7 @@ object Match
s"{$nodeType}"
def children: Seq[Match] = Seq.empty
def reassemble(in: Seq[Match]): Match = this
def expressions: Seq[Expression] = Seq.empty
}
object OfType
{
@ -167,6 +178,7 @@ object Match
s"Exact(${pattern})"
def children: Seq[Match] = Seq.empty
def reassemble(in: Seq[Match]): Match = this
def expressions: Seq[Expression] = Seq.empty
}
/**
@ -194,7 +206,7 @@ object Match
s"@[${path.mkString(",")}](\n$prefix ${pattern.toString(prefix+" ")}\n$prefix)"
def children: Seq[Match] = Seq(pattern)
def reassemble(in: Seq[Match]): Match = copy(pattern = in(0))
def expressions: Seq[Expression] = Seq.empty
}
/**
@ -211,6 +223,7 @@ object Match
s"Recursive(\n$prefix ${pattern.toString(prefix+" ")}\n$prefix)"
def children: Seq[Match] = Seq(pattern)
def reassemble(in: Seq[Match]): Match = copy(pattern = in(0))
def expressions: Seq[Expression] = Seq.empty
}
/**
@ -231,6 +244,12 @@ object Match
}
def children: Seq[Match] = Seq(pattern)
def reassemble(in: Seq[Match]): Match = copy(pattern = in(0))
def expressions: Seq[Expression] = Seq.empty
override def rename(newNames: Map[String, String]): Match =
Bind(
symbol = newNames.getOrElse(symbol, symbol),
pattern.rename(newNames)
)
}
/**
@ -244,6 +263,7 @@ object Match
"_"
def children: Seq[Match] = Seq.empty
def reassemble(in: Seq[Match]): Match = this
def expressions: Seq[Expression] = Seq.empty
}
/**
@ -257,6 +277,7 @@ object Match
"Fail!"
def children: Seq[Match] = Seq.empty
def reassemble(in: Seq[Match]): Match = this
def expressions: Seq[Expression] = Seq.empty
}
/**
@ -276,6 +297,9 @@ object Match
def children: Seq[Match] = Seq.empty
def reassemble(in: Seq[Match]): Match = this
override def references = super.references ++ Set(Var(symbol))
def expressions: Seq[Expression] = Seq.empty
override def rename(newNames: Map[String, String]): Match =
copy(symbol = newNames.getOrElse(symbol, symbol))
}
/**
@ -295,6 +319,12 @@ object Match
def children: Seq[Match] = Seq(pattern)
def reassemble(in: Seq[Match]): Match = copy(pattern = in(0))
override def references = super.references ++ Set(Var(symbol))
def expressions: Seq[Expression] = Seq.empty
override def rename(newNames: Map[String, String]): Match =
copy(
symbol = newNames.getOrElse(symbol, symbol),
pattern = pattern.rename(newNames)
)
}
/**
@ -317,6 +347,7 @@ object Match
s"@[*](\n$prefix ${pattern.toString(prefix+" ")}\n$prefix)"
def children: Seq[Match] = Seq(pattern)
def reassemble(in: Seq[Match]): Match = copy(pattern = in(0))
def expressions: Seq[Expression] = Seq.empty
}
/**
@ -334,6 +365,9 @@ object Match
def children: Seq[Match] = Seq.empty
def reassemble(in: Seq[Match]): Match = this
override def references = super.references ++ op.references
def expressions: Seq[Expression] = Seq(op)
override def rename(newNames: Map[String, String]): Match =
Test(op.rename(newNames))
}
/**
@ -348,5 +382,11 @@ object Match
def children: Seq[Match] = Seq.empty
def reassemble(in: Seq[Match]): Match = this
override def references = super.references ++ op.references
def expressions: Seq[Expression] = Seq(op)
override def rename(newNames: Map[String, String]): Match =
BindExpression(
symbol = newNames.getOrElse(symbol, symbol),
op = op.rename(newNames)
)
}
}