Merge branch 'main' of git.odin.cse.buffalo.edu:Astral/astral-compiler

nicksrules
Oliver Kennedy 2023-07-13 15:29:13 -04:00
commit 88760113ed
Signed by: okennedy
GPG Key ID: 3E5F9B3ABD3FDB60
21 changed files with 760 additions and 153 deletions

View File

@ -1,7 +1,9 @@
package com.astraldb.catalyst
import com.astraldb.codegen.Render
import com.astraldb.codegen.Rule
import com.astraldb.bdd.{ Compiler => BDDCompiler }
import com.astraldb.spec.Type
object Astral
{
@ -18,9 +20,15 @@ object Astral
}
println(
BDDCompiler.bdd(definition)
)
val family = Type.AST("LogicalPlan")
val bdd =
BDDCompiler.bdd(definition, family)
print(bdd)
println(
Rule(definition, bdd, family)
)
}
}

View File

@ -7,6 +7,7 @@ import com.astraldb.typecheck.TypecheckMatch.MatchTypecheckError
import com.astraldb.typecheck.TypecheckExpression.ExpressionTypecheckError
import com.astraldb.codegen.Render
import com.astraldb.spec.Type
object Generate
{
@ -20,7 +21,12 @@ object Generate
// err.printStackTrace()
System.exit(-1)
}
for( (file, content) <- Render(Catalyst.definition) )
val families = Seq[Type.AST](
Type.AST("LogicalPlan")
)
for( (file, content) <- Render(Catalyst.definition, families) )
{
val of = new BufferedWriter(new FileWriter(file))
of.write("""package com.astraldb.catalyst

View File

@ -35,6 +35,7 @@ case class PickByType(
case class PickByMatch(
path: Seq[Int],
matcher: Match,
bindings: Seq[Pathed[(String, Type)]],
ifMatched: BDD,
ifNotMatched: BDD,
) extends BDD
@ -50,6 +51,8 @@ case class PickByMatch(
case class BindAnExpression(
symbol: String,
expression: Expression,
exprType: Type,
bindings: Seq[Pathed[(String, Type)]],
andThen: BDD
) extends BDD
{
@ -63,11 +66,13 @@ case class BindAnExpression(
case class Rewrite(
label: String,
rewrite: Expression
rewrite: Expression,
bindings: Seq[Pathed[(String, Type)]],
) extends BDD
{
val schema = rewrite.references.toSeq
def code =
Code.Literal(s"Rewrite with $label")
Code.Literal(s"Rewrite with $label($schema)")
}
case object NoRewrite extends BDD

View File

@ -3,6 +3,8 @@ package com.astraldb.bdd
import com.astraldb.spec.Match
import com.astraldb.expression.Expression
import com.astraldb.spec.Type
import com.astraldb.spec.Definition
import com.astraldb.typecheck.TypecheckExpression
sealed trait CandidateStep
@ -13,16 +15,21 @@ object CheckTypeStep
CheckTypeStep(pathed.path, pathed.value)
}
case class CheckMatchStep(path: Pathed.Path, matcher: Match) extends CandidateStep
case class CheckMatchStep(path: Pathed.Path, matcher: Match, bindings: Seq[Pathed[(String, Type)]]) extends CandidateStep
object CheckMatchStep
{
def apply(pathed: Pathed[Match]): CheckMatchStep =
CheckMatchStep(pathed.path, pathed.value)
def apply(pathed: Pathed[Match], target: Target): CheckMatchStep =
{
val refs = pathed.value.references.map { _.v }.toSet
CheckMatchStep(pathed.path, pathed.value,
target.bindings.filter { b => refs(b.value._1) }
)
}
}
case class BindExpressionStep(symbol: String, expression: Expression) extends CandidateStep
case class BindExpressionStep(symbol: String, expression: Expression, exprType: Type, bindings: Seq[Pathed[(String, Type)]]) extends CandidateStep
object BindExpressionStep
{
def apply(binding: (String, Expression)): BindExpressionStep =
BindExpressionStep(binding._1, binding._2)
def apply(binding: (String, Expression, Type), target: Target): BindExpressionStep =
BindExpressionStep(binding._1, binding._2, binding._3, target.bindings)
}

View File

@ -4,6 +4,10 @@ import com.astraldb.spec.Definition
import com.astraldb.spec.Match
import com.astraldb.spec.Type
import com.astraldb.expression.Expression
import com.astraldb.spec.Rule
import com.astraldb.expression.Var
import com.astraldb.expression.IfIsDefined
import com.astraldb.typecheck.TypecheckMatch
object Compiler
{
@ -93,24 +97,67 @@ object Compiler
disjunctify(matcher)
}
def getTargets(schema: Definition): Seq[Target] =
def getClauseTargets(schema: Definition, rule: Rule): Seq[Target] =
{
schema.rules.flatMap { rule =>
val clauses =
normalize(schema, rule.pattern).clauses
clauses.zipWithIndex.map {
case (clause, idx) =>
val label =
rule.label + (if(clauses.size > 1) { "-"+idx } else { "" })
Target.of(label, clause, rule.rewrite)
.withoutUnusedBindings
}
val clauses:Seq[(Conjunction, Set[String])] =
normalize(schema, rule.pattern).clauses
.map { c =>
c -> TypecheckMatch(c.asMatch, rule.family, schema, schema.globals)
.keySet
}
val optionalBindings:Set[String] =
clauses.flatMap { _._2 }
.groupBy { x => x }
.filter { _._2.size < clauses.size }
.keySet
clauses.zipWithIndex.map {
case ((clause, bindings), idx) =>
val label =
rule.label + (if(clauses.size > 1) { "-"+idx } else { "" })
val optionalReferencesInClause =
bindings & optionalBindings
// if(!optionalBindings.isEmpty)
// {
// println(s"Optional: $optionalBindings")
// println(s"Optional Bound: $optionalReferencesInClause")
// }
val sanitizedRewrite =
rule.rewrite.transform {
case IfIsDefined(symbol, ifYes, ifNo)
if optionalBindings(symbol)=>
if(optionalReferencesInClause(symbol)) { ifYes }
else { ifNo }
}
// if(sanitizedRewrite != rule.rewrite)
// {
// println(rule.rewrite)
// println("into")
// println(sanitizedRewrite)
// }
Target.of(schema, label, rule.family, clause, sanitizedRewrite)
.withoutUnusedBindings
}
}
def getTargets(schema: Definition, family: Type.AST): Seq[Target] =
{
schema.rules.flatMap {
case rule if rule.family != family => None
case rule => getClauseTargets(schema, rule)
}
}
def getSanitizedTargets(schema: Definition): Seq[Target] =
def getSanitizedTargets(schema: Definition, family: Type.AST): Seq[Target] =
{
var targets = getTargets(schema)
var targets = getTargets(schema, family)
// Sanitize expression bindings via alpha renaming and
// common subexpression elimination
@ -146,7 +193,7 @@ object Compiler
.toSet.toSeq
.groupBy { _._1 }
.filter { _._2.size > 1 }
.map { case (name:String, exprs:Seq[(String, Expression)]) =>
.map { case (name:String, exprs:Seq[(String, Expression, Type)]) =>
name ->
exprs.map { _._2 }
.zipWithIndex
@ -159,7 +206,7 @@ object Compiler
targets = targets.map { target =>
val repair =
target.exprBindings
.flatMap { case (name, expr) =>
.flatMap { case (name, expr, t) =>
replacementNames.get(name)
.flatMap { _.get(expr).map { newName =>
name -> newName
@ -173,16 +220,15 @@ object Compiler
}
def bdd(schema: Definition): BDD =
def bdd(schema: Definition, family: Type.AST): BDD =
{
println(getSanitizedTargets(schema).mkString("\n---\n"))
def stepBdd(targets: Seq[Target], state: State): BDD =
{
val success = targets.find { _.successfulOn(state) }
if(success.isDefined)
{
Rewrite(success.get.label, success.get.rewrite)
Rewrite(success.get.label, success.get.rewrite, success.get.bindings)
} else {
val activeTargets =
targets.filter { _.canBeSuccessfulOn(state) }
@ -194,7 +240,7 @@ object Compiler
val candidates:Seq[CandidateStep] =
activeTargets
.flatMap { t =>
val c = t.candidates(state)
val c = t.candidates(schema, state)
assert(!c.isEmpty, s"Target ${t.label} in state \n$state\n is not successful or failed, but has no candidates for progress.")
/* return */ c
}
@ -215,7 +261,7 @@ object Compiler
val types:Set[Type] =
bestCheckTypeStep.get._2.map { _._2.check }.toSet
println(s"Checking path: [${path.mkString(", ")}] for ${types.mkString(", ")}")
System.err.println(s"Checking path: [${path.mkString(", ")}] for ${types.mkString(", ")}")
PickByType(
path,
@ -241,10 +287,11 @@ object Compiler
bestOtherCandidate match {
case _:CheckTypeStep => assert(false, "We filtered out all the check type steps by this point")
case CheckMatchStep(path, matcher) =>
case CheckMatchStep(path, matcher, bindings) =>
PickByMatch(
path = path,
matcher = matcher,
bindings = bindings,
ifMatched =
stepBdd(activeTargets,
state.withSuccessfulMatch(path, matcher)),
@ -252,8 +299,10 @@ object Compiler
stepBdd(activeTargets,
state.withFailedMatch(path, matcher)),
)
case BindExpressionStep(symbol, expression) =>
case BindExpressionStep(symbol, expression, exprType, bindings) =>
BindAnExpression(symbol, expression,
exprType,
bindings,
stepBdd(activeTargets,
state.withBinding(symbol)
)
@ -265,7 +314,7 @@ object Compiler
}
}
stepBdd(getSanitizedTargets(schema), State.empty(schema))
stepBdd(getSanitizedTargets(schema, family), State.empty(schema))
}
}

View File

@ -3,21 +3,38 @@ package com.astraldb.bdd
import com.astraldb.expression._
import com.astraldb.spec._
import scala.util.Try
import com.astraldb.typecheck.TypecheckExpression
case class Target(
label: String,
family: Type.AST,
rewrite: Expression,
bindings: Seq[Pathed[String]] = Seq.empty,
exprBindings: Seq[(String, Expression)] = Seq.empty,
bindings: Seq[Pathed[(String, Type)]] = Seq.empty,
exprBindings: Seq[(String, Expression, Type)] = Seq.empty,
types: Seq[Pathed[Type]] = Seq.empty,
matchers: Seq[Pathed[Match]] = Seq.empty,
)
{
def withBinding(path: Seq[Int], symbol: String) =
copy(bindings = bindings :+ Pathed(path, symbol))
{
val varType =
types.find { _.path == path }
.map { _.value }
.getOrElse { family }
copy(
bindings = bindings :+ Pathed(path, (symbol, varType))
)
}
def withType(path: Seq[Int], pathType: Type) =
copy(types = types :+ Pathed(path, pathType))
copy(
types = types :+ Pathed(path, pathType),
bindings = bindings.map {
case b if b.path == path =>
b.copy( value = b.value._1 -> pathType )
case b => b
}
)
def withMatcher(path: Seq[Int], matcher: Match) =
if(matcher == Match.Any){ this }
@ -25,8 +42,8 @@ case class Target(
copy(matchers = matchers :+ Pathed(path, matcher))
}
def withBoundExpression(symbol: String, expr: Expression) =
copy(exprBindings = exprBindings :+ (symbol, expr))
def withBoundExpression(symbol: String, expr: Expression, exprType: Type) =
copy(exprBindings = exprBindings :+ (symbol, expr, exprType))
def withoutUnusedBindings =
{
@ -36,11 +53,11 @@ case class Target(
exprBindings.flatMap { _._2.references }.map { _.v}
copy(
bindings = bindings.filter { b => refs(b.value) }
bindings = bindings.filter { b => refs(b.value._1) }
)
}
def candidates(state: State): Seq[CandidateStep] =
def candidates(schema: Definition, state: State): Seq[CandidateStep] =
{
val requestedTypedPaths = types.map { _.path }.toSet
val typedPaths = state.typeOf.keySet
@ -57,7 +74,7 @@ case class Target(
if(requestedTypedPaths(b.path)) { b.pathInSet(typedPaths) }
// If the binding itself is not typed, then we just need its ancestors bound
else { b.ancestorsInSet(typedPaths) }
}.map { b => Var(b.value) }
}.map { b => Var(b.value._1) }
).toSet
// println(s"Valid bindings: $validBindings")
@ -101,8 +118,8 @@ case class Target(
// println(s"Candidate Matchers: $candidateMatchers")
(candidateTypings.map { CheckTypeStep(_) }: Seq[CandidateStep]) ++
candidateMatchers.map { CheckMatchStep(_) } ++
candidateBindings.map { BindExpressionStep(_) }
candidateMatchers.map { CheckMatchStep(_, this) } ++
candidateBindings.map { BindExpressionStep(_, this) }
}
def canBeSuccessfulOn(state: State): Boolean =
@ -131,13 +148,14 @@ case class Target(
copy(
rewrite = rewrite.rename(replacementMap),
bindings = bindings.map { b =>
replacementMap.get(b.value)
.map { r => b.copy(value = r) }
replacementMap.get(b.value._1)
.map { r => b.copy(value = (r, b.value._2)) }
.getOrElse(b)
},
exprBindings = exprBindings.map { case (name, expr) =>
exprBindings = exprBindings.map { case (name, expr, exprType) =>
( replacementMap.getOrElse(name, name),
expr.rename(replacementMap)
expr.rename(replacementMap),
exprType
)
},
matchers = matchers.map { m =>
@ -146,6 +164,12 @@ case class Target(
)
}
def scope: Map[String, Type] =
(
bindings.map { _.value } ++
exprBindings.map { b => b._1 -> b._3 }
).toMap
override def toString(): String =
label+
(if(types.isEmpty){ "" } else {
@ -158,7 +182,7 @@ case class Target(
})+
(if(exprBindings.isEmpty){ "" } else {
"\n Expression Bindings\n"+
exprBindings.map { case (s, e) => " "+s+" <- "+e }.mkString("\n")
exprBindings.map { case (s, e, t) => " "+s+" <- "+e }.mkString("\n")
})+
(if(matchers.isEmpty){ "" } else {
"\n Matchers\n"+
@ -169,9 +193,9 @@ case class Target(
object Target
{
def of(label: String, clause: Conjunction, rewrite: Expression): Target =
def of(schema: Definition, label: String, family: Type.AST, clause: Conjunction, rewrite: Expression): Target =
{
clause.atoms.foldLeft(Target(label, rewrite)){ (target, atom) =>
clause.atoms.foldLeft(Target(label, family, rewrite)){ (target, atom) =>
atom match {
case Match.Path(path, Match.Bind(symbol, child)) =>
target.withBinding(path, symbol)
@ -186,7 +210,9 @@ object Target
case Match.Path(path, child) =>
target.withMatcher(path, child)
case Match.BindExpression(symbol, expr) =>
target.withBoundExpression(symbol, expr)
target.withBoundExpression(symbol, expr,
TypecheckExpression(expr, schema, target.scope++schema.globals)
)
case _ =>
target.withMatcher(Seq.empty, atom)
}

View File

@ -2,6 +2,7 @@ package com.astraldb.bdd
import com.astraldb.spec.Match
import com.astraldb.codegen.Code
import com.astraldb.expression.Var
case class Conjunction(atoms: Seq[Match])
{
@ -28,6 +29,10 @@ case class Conjunction(atoms: Seq[Match])
")"
)
}
def references:Set[Var] = atoms.flatMap { _.references }.toSet
def asMatch = Match.And(atoms)
}
object Conjunction
@ -78,6 +83,10 @@ case class Disjunction(clauses: Seq[Conjunction])
")"
)
}
def references:Set[Var] = clauses.flatMap { _.references }.toSet
def asMatch = Match.Or(clauses.map { _.asMatch })
}
object Disjunction

View File

@ -0,0 +1,160 @@
package com.astraldb.codegen
import com.astraldb.spec.Definition
import com.astraldb.spec.Type
import com.astraldb.bdd
import com.astraldb.codegen.Code.PaddedString
import com.astraldb.expression.Var
object BDD
{
def apply(
schema: Definition,
family: Type.AST,
rule: bdd.BDD,
onFail: Code,
root: Code,
): Code =
{
def codeScopeFor[T](refs: Set[Var], bindings: Seq[bdd.Pathed[(String, Type)]], pathNames: Map[bdd.Pathed.Path, (Code, Type)], boundVars: Map[String, Type], context: T): CodeScope =
CodeScope(
(
(refs.map { _.v }
-- boundVars.keys
-- schema.globals.keys
).toSeq.map { ref =>
val pathRef =
bindings.find { _.value._1 == ref }
.getOrElse {
assert(false, s"Reference to a variable that hasn't been bound yet: $ref (in $bindings):\n$context")
}
val varDescription =
pathNames.find { _._1 == pathRef.path }
.getOrElse {
assert(false, s"Reference to a variable at an unbound path: ${pathRef.path} (in $pathNames):\n$context")
}
._2
ref -> varDescription
}.toMap:Map[String, Type|(Code,Type)]
) ++ (boundVars:Map[String, Type|(Code,Type)])
++ (schema.globals:Map[String, Type|(Code,Type)])
)
def recur(rule: bdd.BDD, pathNames: Map[bdd.Pathed.Path, (Code, Type)], boundVars: Map[String, Type]): Code =
rule match {
case bdd.NoRewrite => onFail
case bdd.PickByType(path, types) =>
// println(s"Generating code for $path")
val target = pathNames.get(path)
.getOrElse {
assert(false, s"No name available for $path (in $pathNames):\n$rule")
}
._1
Code.IfElifElse(
types.map { case (targetType, andThen) =>
{
val nodeName = "root"+path.map { "_child"+_ }.mkString+"_as_"+targetType.scalaType
val newPathNames: Map[bdd.Pathed.Path, (Code, Type)] =
targetType match {
case Type.Node(node) =>
schema.nodesByName(node)
.fields
.zipWithIndex
.map { case (field, idx) =>
(path :+ idx) -> (
Code.BinOp(
Code.Literal(nodeName),
".",
Code.Literal(field.name)
),
field.t
)
}
.toMap
case _ => Map.empty
}
Code.Pair(target, Code.Literal(s".isInstanceOf[${targetType.scalaType}]"))
-> Code.Block(Seq(
Code.BinOp(
Code.Literal(s"val $nodeName"),
PaddedString.pad(1, "="),
Code.Pair(target,
Code.Literal(s".asInstanceOf[${targetType.scalaType}]")
)
)
)++recur(
andThen,
pathNames ++ newPathNames ++ Map(path -> (Code.Literal(nodeName), targetType)),
boundVars,
).block
)
}
}.toSeq,
onFail
)
case bdd.PickByMatch(path, pattern, bindings, onMatch, onFail) =>
val target = pathNames.get(path)
.getOrElse {
assert(false, s"No name available for $path (in $pathNames):\n$rule")
}
val scope =
codeScopeFor(pattern.references, bindings, pathNames, boundVars, pattern)
Match(
schema = schema,
pattern = pattern,
target = target._1,
targetPath = path,
targetType = target._2,
onSuccess =
_ => recur(
onMatch,
pathNames,
boundVars,
),
onFail =
_ => recur(
onFail,
pathNames,
boundVars,
),
name = None,
scope = scope
)
case bdd.Rewrite(label, op, bindings) =>
Code.Block(
Seq(
Code.Literal(s"// Rewrite by $label"),
) ++ Expression(
schema,
op,
codeScopeFor(op.references, bindings, pathNames, boundVars, op)
).block
)
case bdd.BindAnExpression(symbol, expression, exprType, bindings, andThen) =>
Code.Block(
Seq(
Code.Pair(
Code.Literal(s"val $symbol ="),
Expression(
schema, expression,
codeScopeFor(expression.references, bindings, pathNames, boundVars, expression)
)
)
)++recur(
andThen,
pathNames,
boundVars ++ Map(symbol -> exprType)
).block
)
}
recur(
rule,
Map(
Seq.empty -> (root, family)
),
Map.empty,
)
}
}

View File

@ -321,4 +321,34 @@ object Code
def Parenthesize(body: Code): Code =
Parens("(", body, ")")
def IndentedPair(left: Code, right: Code): Code =
Pair(left, Indent(2, right))
def IfElifElse(ifThenClauses: Seq[(Code, Code)], elseClause: Code): Code =
{
Block(
Seq(
IndentedPair(
Parens("if(", ifThenClauses.head._1, PaddedString.rightPad(1, ") {")),
ifThenClauses.head._2,
)
) ++ ifThenClauses.tail.map { it =>
IndentedPair(
Parens(
PaddedString.leftPad(1, "} else if("),
it._1,
PaddedString.rightPad(1, ") {")
),
it._2,
)
} ++ Seq(
Parens(
PaddedString.pad(1, "} else {"),
elseClause,
PaddedString.leftPad(1, "}")
)
)
)
}
}

View File

@ -2,7 +2,7 @@ package com.astraldb.codegen
import com.astraldb.spec
class CodeScope(vars: Map[String, spec.Type | (Code, spec.Type)])
case class CodeScope(vars: Map[String, spec.Type | (Code, spec.Type)])
{
type Element = spec.Type | (Code, spec.Type)

View File

@ -16,6 +16,7 @@ object Match
schema: spec.Definition,
pattern: spec.Match,
target: Code,
targetPath: Seq[Int],
targetType: spec.Type,
onSuccess: CodeScope => Code,
onFail: CodeScope => Code,
@ -25,16 +26,18 @@ object Match
{
pattern match {
case And(Seq()) => onSuccess(scope)
case And(Seq(a)) => apply(schema, a, target, targetType, onSuccess, onFail, name, scope)
case And(Seq(a)) => apply(schema, a, target, targetPath, targetType, onSuccess, onFail, name, scope)
case And(a) =>
apply(schema, a.head,
target = target,
targetPath = targetPath,
targetType = targetType,
onSuccess =
scope =>
apply(schema, And(a.tail),
target = target,
targetType = targetType,
targetPath = targetPath,
onSuccess = onSuccess,
onFail = onFail,
name = name,
@ -45,35 +48,45 @@ object Match
scope = scope
)
case Not(a) =>
apply(schema, a, target, targetType, onFail, onSuccess, name, scope)
apply(schema, a, target, targetPath, targetType, onFail, onSuccess, name, scope)
case Or(Seq()) => onFail(scope)
case Or(Seq(a)) =>
apply(schema, a, target, targetType, onSuccess, onFail, name, scope)
apply(schema, a, target, targetPath, targetType, onSuccess, onFail, name, scope)
case Or(a) =>
val optional:Map[String, Type] =
a.flatMap {
TypecheckMatch(_, name, targetType, schema, scope.flatten).toSeq
val optional:Map[Var|TypecheckMatch.PathVar, Type] =
a.flatMap { m =>
TypecheckMatch(
m,
targetType,
schema,
scope.flatten
).toSeq
}.groupBy { _._1 }
.filter { _._2.size < a.size }
.mapValues { case v => Typecheck.glb(v.map { _._2 }, schema) }
.map { case (v, t) =>
Var(v) ->
Typecheck.glb(t.map { _._2 }, schema)
}
.toMap
a.foldRight(onFail) { (matcher, nextOnFail) =>
_ =>
apply(schema, matcher,
target = target,
targetPath = targetPath,
targetType = targetType,
onSuccess = {
(scope: CodeScope) => onSuccess(
scope.map {
case (name, t: Type) if optional contains name =>
case (name, t: Type) if optional contains Var(name) =>
name -> (Code.Literal(s"(Some($name):${Type.Option(t).scalaType})"), Type.Option(t))
case (name, c: (Code, Type)) if optional contains name =>
case (name, c: (Code, Type)) if optional contains Var(name) =>
name -> (Code.Parens(s"(Some(", c._1, s"):${Type.Option(c._2).scalaType})"), Type.Option(c._2))
case x => x
}.withVars(
(optional.keySet -- scope.keys).toSeq.map { k =>
k -> (Code.Literal(s"(None:${Type.Option(optional(k)).scalaType})"):Code, Type.Option(optional(k)):Type)
(optional.keySet -- scope.keys.map { Var(_) }).toSeq.collect {
case k:Var =>
k.v -> (Code.Literal(s"(None:${Type.Option(optional(k)).scalaType})"):Code, Type.Option(optional(k)):Type)
}:_*
)
)
@ -93,6 +106,7 @@ object Match
Var(selectedName) eq pattern
)),
target = target,
targetPath = targetPath,
targetType = targetType,
onSuccess = onSuccess,
onFail = onFail,
@ -105,6 +119,7 @@ object Match
schema = schema,
pattern = pattern,
target = target,
targetPath = targetPath,
targetType = targetType,
onSuccess = onSuccess,
onFail = onFail,
@ -154,11 +169,13 @@ object Match
) ++
children
.zip(node.fields)
.foldRight(onSuccess) { case ((child, field), andThen) =>
.zipWithIndex
.foldRight(onSuccess) { case (((child, field), idx), andThen) =>
nextScope => apply(
schema = schema,
pattern = child,
target = Code.Literal(s"$selectedName.${field.name}"),
targetPath = targetPath :+ idx,
targetType = field.t,
onSuccess = andThen,
onFail = onFail,

View File

@ -1,14 +1,23 @@
package com.astraldb.codegen
import com.astraldb.spec.Definition
import com.astraldb.spec.Type
import com.astraldb.bdd
object Render
{
def apply(schema: Definition): Map[String, String] =
def apply(schema: Definition, bddFamilies: Seq[Type.AST] = Seq.empty): Map[String, String] =
{
val bddRule: Seq[(Type.AST, bdd.BDD)] =
bddFamilies.map { f =>
f -> bdd.Compiler.bdd(schema, f)
}
Map(
"Optimizer.scala" -> Optimizer(schema)
) ++ schema.rules.map { rule =>
"Optimizer.scala" -> Optimizer(schema),
) ++ (bddRule.map { case (family, bdd) =>
s"${family.scalaType}BDD.scala" -> Rule(schema, bdd, family)
}.toMap) ++ schema.rules.map { rule =>
s"${rule.safeLabel}.scala" -> Rule(schema, rule)
}.toMap
}

View File

@ -1,6 +1,7 @@
package com.astraldb.codegen
import com.astraldb.spec
import com.astraldb.bdd
object Rule
{
@ -8,4 +9,9 @@ object Rule
{
scala.Rule(schema, rule).toString
}
def apply(schema: spec.Definition, rule: bdd.BDD, family: spec.Type.AST) =
{
scala.BDDBatch(schema, family, rule).toString
}
}

View File

@ -360,6 +360,8 @@ case class Let(symbol: String, value: Expression, rest: Expression) extends Expr
def children = Seq(value, rest)
def reassemble(in: Seq[Expression]): Expression = Let(symbol, in(0), in(1))
override def toString = s"let $symbol = $value in $rest"
override def references: Set[Var] =
value.references ++ (rest.references -- Set(Var(symbol)))
}
/**
* Branch on whether a symbol is defined in scope
@ -370,4 +372,6 @@ case class IfIsDefined(symbol: String, ifYes: Expression, ifNo: Expression) exte
def reassemble(in: Seq[Expression]): Expression =
copy(ifYes = in(0), ifNo = in(1))
override def toString = s"if($symbol exists){ $ifYes } else { $ifNo }"
override def references: Set[Var] =
super.references ++ Set(Var(symbol))
}

View File

@ -4,7 +4,7 @@ import scala.collection.mutable
import com.astraldb.expression._
case class Definition(
asts: Map[String, ASTDefinition],
asts: Map[Type.AST, ASTDefinition],
rules:Seq[Rule],
globals: Map[String, Type],
) {
@ -20,7 +20,7 @@ case class Definition(
|${rules.mkString("\n\n")}
""".stripMargin
val familyOfNode: Map[String, String] =
val familyOfNode: Map[String, Type.AST] =
nodes.flatMap { case (family, elements) => elements.map { _.name -> family } }
.toMap
@ -39,7 +39,7 @@ class HardcodedDefinition
Definition(
asts = nodes.map { case (f, n) =>
val family = Type.AST(f)
f -> ASTDefinition(
family -> ASTDefinition(
family = family,
nodes = n.map { _.copy(family = family) }.toSet
)}.toMap,
@ -62,7 +62,7 @@ class HardcodedDefinition
def Rule(label: String, family: String)(pattern: Match)(replacement: Expression): Unit =
rules.append(
com.astraldb.spec.Rule(label, family, pattern, replacement)
com.astraldb.spec.Rule(label, Type.AST(family), pattern, replacement)
)
def Function(label: String, ret: Type = Type.Unit)(args: Type*): Unit =

View File

@ -6,7 +6,7 @@ import com.astraldb.typecheck.TypecheckExpression
case class Rule(
label: String,
family: String,
family: Type.AST,
pattern: Match,
rewrite: Expression
)
@ -17,11 +17,11 @@ case class Rule(
def validate(schema: Definition): Unit =
{
val scope =
TypecheckMatch(pattern, None, Type.AST(family), schema, schema.globals)
TypecheckMatch(pattern, family, schema, schema.globals)
TypecheckExpression(rewrite, schema, scope) match {
case Type.Node(n) =>
assert(schema.familyOfNode(n) == family, s"Rule $label constructs a $n of family ${schema.familyOfNode(n)}, while the rule is for famil $family")
case Type.AST(f) =>
case f:Type.AST =>
assert(f == family, s"Rule $label constructs a $f, while the rule is for family $family")
case c =>
assert(false, s"Rule $label generates a non-AST type $c")

View File

@ -27,6 +27,13 @@ object Typecheck
}
}
/**
* Compute the greatest lower bound.
*
* This is the *more* precise type of the two arguments.
*
* e.g., glb(Node, AST) = Node
*/
def glb(a: Type, b: Type, schema: Definition): Type =
{
if(a == b){ return a }
@ -49,12 +56,31 @@ object Typecheck
}
}
/**
* Compute the greatest lower bound.
*
* This is the *more* precise type of the two arguments.
*
* e.g., glb(Node, AST) = Node
*/
def glb(elems: Seq[Type], schema: Definition): Type =
elems.foldLeft(Type.Any: Type){ glb(_, _, schema) }
/**
* Compute the least upper bound.
*
* This is the *less* precise type of the two arguments.
*
* e.g., glb(Node, AST) = AST
*/
def lub(a: Type, b: Type, schema: Definition): Type =
lubOption(a, b, schema).getOrElse {
assert(false, s"Types $a and $b don't unify")
}
def lubOption(a: Type, b: Type, schema: Definition): Option[Type] =
{
if(a == b){ return a }
if(a == b){ return Some(a) }
(a, b) match {
case (n:Type.Node, o:Type.ASTType) if
schema.nodesByName(n.nodeType)

View File

@ -117,8 +117,8 @@ object TypecheckExpression
case Type.Node(nodeType) =>
val node = schema.nodesByName(nodeType)
node.fields(index).t
case _ =>
assert(false, s"Node subscript on something not a node")
case t =>
assert(false, s"Node subscript on something not a node: $t")
}
case StructSubscript(target, name) =>

View File

@ -6,19 +6,247 @@ import com.astraldb.expression._
object TypecheckMatch
{
def apply(pattern: Match, targetType: Type, schema: Definition, scope: Map[String, Type]): Map[String, Type] =
{
val finalState =
check(
pattern = pattern,
targetPath = Seq.empty,
State(
scope = scope ++ schema.globals,
pathTypes = Map(Seq.empty -> targetType),
optionalBindings = Set.empty,
schema = schema
)
)
finalState.asScope
}
case class PathVar(path: Seq[Int])
case class State(
scope: Map[String, PathVar|Type],
pathTypes: Map[Seq[Int], Type],
optionalBindings: Set[String],
schema: Definition
)
{
assert(pathTypes.contains(Seq.empty), s"We should never create a state without a root path: $pathTypes")
override def toString(): String =
"---- Scope ----\n"+
scope.map {
case (name, t:Type) if optionalBindings(name) => s"$name:$t (optional)\n"
case (name, t:Type) => s"$name:$t\n"
case (name, PathVar(path)) if optionalBindings(name) => s"$name @ [${path.mkString(", ")}] (optional)\n"
case (name, PathVar(path)) => s"$name @ [${path.mkString(", ")}]\n"
}.mkString +
"---- PathTypes ----\n"+
pathTypes.map {
case (path, t) => s"@[${path.mkString(",")}]:$t\n"
}.mkString
def typeAtPath(path: Seq[Int]): Type =
path.foldLeft( (
Seq[Int](),
pathTypes.get(Seq.empty)
.getOrElse {
System.err.println(this)
assert(false, s"Trying to look up the type of a path when the root is untyped")
}
) ) { case ((subPath, t), field) =>
val fieldPath = subPath :+ field
val fieldType =
t match {
case Type.Node(label) =>
assert(schema.nodesByName contains label, s"No AST node of type $label defined")
val nodeSchema = schema.nodesByName(label)
assert(nodeSchema.fields.size > field, s"Can't path into the $field'th element of a $t")
nodeSchema.fields(field).t
case Type.Array(base) =>
base
case _ =>
// System.err.println(this)
assert(false, s"Can't path into a simple type: $t @[${subPath.mkString(", ")}] / [${path.mkString(", ")}]")
}
(
fieldPath,
pathTypes.get(fieldPath)
.map { Typecheck.glb(_, fieldType, schema) }
.getOrElse { fieldType }
)
}
._2
def typeOfVar(v: String): Type =
{
val t =
scope.get(v)
.getOrElse {
assert(false, s"Unbound variable: $v")
} match {
case PathVar(path) => typeAtPath(path)
case t:Type => t
}
if(optionalBindings(v)){ Type.Option(t) } else { t }
}
def asScope: Map[String, Type] =
scope.map { case (v, t) =>
val base = t match {
case t:Type => t
case PathVar(path) => typeAtPath(path)
}
v -> (if(optionalBindings(v)){ Type.Option(base) } else { base })
}.toMap
def pathOfVar(v: String): Seq[Int] =
scope.get(v)
.getOrElse {
assert(false, s"Unbound variable: $v")
} match {
case PathVar(path) => path
case t:Type =>
assert(false, s"Expected $v to be bound to a path")
}
def bindPath(path: Seq[Int], t: Type): State =
copy(
pathTypes =
pathTypes ++ Map(
path ->
pathTypes.get(path)
.map { Typecheck.glb(_, t, schema) }
.getOrElse(t)
)
)
def bindVarToPath(v: String, path: Seq[Int]): State =
copy(
scope = scope ++ Map(
v -> PathVar(path)
)
)
def bindVarToType(v: String, t: Type): State =
copy(
scope = scope ++ Map(
v -> t
)
)
def or(other: State): State =
{
State(
scope =
(scope.keySet ++ other.scope.keySet).toSeq.flatMap { v =>
(scope.get(v), other.scope.get(v)) match {
case (Some(myT:Type), Some(otherT:Type)) =>
Some(v -> Typecheck.lub(myT, otherT, schema))
case (Some(myT:Type), Some(PathVar(otherVPath))) =>
Some(v -> Typecheck.lub(myT, other.typeAtPath(otherVPath), schema))
case (Some(PathVar(myVPath)), Some(otherT:Type)) =>
Some(v -> Typecheck.lub(typeAtPath(myVPath), otherT, schema))
case (Some(PathVar(myVPath)), Some(PathVar(otherVPath))) =>
if(myVPath != otherVPath){
Some(v -> Typecheck.lub(typeAtPath(myVPath), other.typeAtPath(otherVPath), schema))
} else {
Some(v -> PathVar(myVPath))
}
case (Some(t:Type), None) =>
Some(v -> t) // optional fields get combined later
case (Some(PathVar(path)), None) =>
Some(v -> typeAtPath(path)) // optional fields get combined later
case (None, Some(t:Type)) =>
Some(v -> t) // optional fields get combined later
case (None, Some(PathVar(path))) =>
Some(v -> other.typeAtPath(path)) // optional fields get combined later
case (None, None) =>
None
}
}.toMap,
pathTypes =
(pathTypes.keySet ++ other.pathTypes.keySet).toSeq.flatMap { p =>
(pathTypes.get(p), other.pathTypes.get(p)) match {
case (Some(t), Some(ot)) =>
Typecheck.lubOption(t, ot, schema).map { p -> _ }
case (_, None) =>
None
case (None, _) =>
None
}
}.toMap,
optionalBindings =
optionalBindings ++ other.optionalBindings ++
((scope.keySet ++ other.scope.keySet)
.filter { v =>
!(scope contains v) || !(other.scope contains v)
}
),
schema = schema
)
}
def narrowToPrefix(prefix: Seq[Int]): State =
copy(
scope =
scope.mapValues {
case t:Type => t
case PathVar(p) =>
(
if(p startsWith prefix) { PathVar(p.drop(prefix.length)) }
else { typeAtPath(p) }
):(PathVar|Type)
}.toMap[String, PathVar|Type],
pathTypes =
pathTypes.filterKeys { _ startsWith prefix }
.map { case (p, v) => p.drop(prefix.length) -> v }
.toMap
)
def expandPrefix(prefix: Seq[Int], target: State): State =
{
val cmp = narrowToPrefix(prefix)
copy(
scope =
scope ++ ((
// include only modifications to the schema
target.scope.toSet -- cmp.scope.toSet
).toMap.mapValues {
case t:Type => t:(Type|PathVar)
case PathVar(p) => PathVar(prefix ++ p):(Type|PathVar)
}),
pathTypes =
pathTypes ++
target.pathTypes.map {
case (path, t) =>
(prefix ++ path) -> t
},
optionalBindings =
target.optionalBindings
)
}
}
object State
{
def empty(schema: Definition) =
new State(schema.globals, Map.empty, Set.empty, schema)
}
case class MatchTypecheckError(msg: String, stack: List[Match]) extends Exception
{
override def getMessage(): String =
stack.map { _.toString }.mkString("\n ,--in--^\n")+"\n ,--in--^\n"+msg
}
def apply(pattern: Match, family: String, schema: Definition, scope: Map[String, Type]): Map[String, Type] =
apply(pattern, None, Type.AST(family), schema, scope)
def apply(pattern: Match, target: Option[String], expectedType: Type, schema: Definition, scope: Map[String, Type]): Map[String, Type] =
def check(
pattern: Match,
targetPath: Seq[Int],
state: State,
): State =
{
def checkAllChildren() =
pattern.children.foreach { apply(_, target, expectedType, schema, scope) }
def schema = state.schema
def astSchema(label: String): Node =
{
@ -27,7 +255,7 @@ object TypecheckMatch
}
def astSchemaWithFamily(family: String, label: String): Node =
def astSchemaWithFamily(family: Type.AST, label: String): Node =
{
assert(schema.nodes contains family, s"AST Family '$family' is not defined")
val astSchema = schema.nodes(family)
@ -36,54 +264,44 @@ object TypecheckMatch
node.get
}
def typeOfCurrentNode: Type =
state.typeAtPath(targetPath)
try {
pattern match {
case Match.Not(child) => apply(child, target, expectedType, schema, scope)
case Match.Not(child) =>
check(child, targetPath, state)
case Match.And(children) =>
children.foldLeft(scope) { (scope, child) =>
apply(child, target, expectedType, schema, scope)
children.foldLeft(state) { (state, child) =>
check(child, targetPath, state)
}
case Match.Or(children) =>
children.flatMap { apply(_, target, expectedType, schema, scope).toSeq }
.groupBy { _._1 }
.mapValues { childTypes =>
val base =
Type.union(childTypes.map { _._2 })
if(childTypes.size == children.size){ base }
else { Type.Option(base) }
}
.toMap
case Match.Or(Seq()) =>
state
case Match.Or(children) =>
children.tail
.foldLeft(check(children.head, targetPath, state)) {
(stateAccum, child) =>
stateAccum or check(child, targetPath, state)
}
case Match.OfType(t) =>
assert(Typecheck.escalatesTo(t, expectedType, schema),
s"Matching a node by type; Type $t can not possibly matched by an element of type $expectedType: $pattern"
assert(Typecheck.escalatesTo(t, typeOfCurrentNode, schema),
s"Matching a node by type; Type $t can not possibly matched by an element of type $typeOfCurrentNode: @[${targetPath.mkString(", ")}]: $pattern"
)
scope ++ target.map { _ -> t }.toMap
state.bindPath(targetPath, t)
case Match.Exact(child) =>
assert(Typecheck.escalatesTo(child.t, expectedType, schema),
s"Matching a node by value; Value $child can not possibly match an element of type $expectedType: $pattern"
assert(Typecheck.escalatesTo(child.t, typeOfCurrentNode, schema),
s"Matching a node by value; Value $child can not possibly match an element of type $typeOfCurrentNode: @[${targetPath.mkString(", ")}]: $pattern"
)
scope
state
case Match.Path(path, child) =>
val targetType =
path.foldLeft(expectedType) { (t, field) =>
t match {
case Type.Node(label) =>
val nodeSchema = astSchema(label)
assert(nodeSchema.fields.size > field, s"Can't path into the $field'th element of a $t: $pattern")
nodeSchema.fields(field).t
case Type.Array(base) =>
base
case _ =>
assert(false, s"Can't path into a simple type: $t: $pattern")
}
}
apply(child, target, targetType, schema, scope)
check(child, targetPath ++ path, state)
case Match.Recursive(_) =>
// not entirely sure how to typecheck this...
@ -91,56 +309,63 @@ object TypecheckMatch
???
case Match.Bind(symbol, child) =>
apply(child, Some(symbol), expectedType, schema, scope ++ Map(symbol -> expectedType))
check(child, targetPath, state.bindVarToPath(symbol, targetPath))
case Match.Any =>
scope
state
case Match.Fail =>
Map.empty
State.empty(schema)
case Match.Lookup(symbol) =>
assert(scope contains symbol, s"$symbol is not bound: $pattern")
assert(Typecheck.escalatesTo(scope(symbol), expectedType, schema), s"$symbol (of type ${scope(symbol)}) can not possibly match an element of type $expectedType")
scope
assert(Typecheck.escalatesTo(
state.typeOfVar(symbol),
typeOfCurrentNode,
schema
), s"$symbol (of type ${state.typeOfVar(symbol)}) can not possibly match an element of type $typeOfCurrentNode")
state
case Match.ApplyToScope(symbol, child) =>
assert(scope contains symbol, s"$symbol is not bound: $pattern")
apply(child, target, scope(symbol), schema, scope)
state.expandPrefix(
state.pathOfVar(symbol),
check(
child,
Seq.empty,
state.narrowToPrefix(targetPath)
)
)
case Match.Forall(child) =>
expectedType match {
case Type.Array(base) =>
apply(child, target, base, schema, scope)
case _ =>
assert(false, s"Forall match on something other than an array: $pattern")
}
check(child, targetPath :+ 0, state)
case Match.Test(op) =>
assert(Typecheck.escalatesTo(TypecheckExpression(op, schema, scope), Type.Bool, schema),
assert(Typecheck.escalatesTo(TypecheckExpression(op, schema, state.asScope), Type.Bool, schema),
s"Non-boolean test expression: $op"
)
scope
state
case Match.BindExpression(symbol, op) =>
scope ++ Map(
symbol -> TypecheckExpression(op, schema, scope)
)
val exprType = TypecheckExpression(op, schema, state.asScope)
state.bindVarToType(symbol, exprType)
case Match.Node(label, fields) =>
assert(expectedType.isInstanceOf[Type.AST], s"Matching a node when $expectedType was expected: $pattern")
val family = expectedType.asInstanceOf[Type.AST].family
val nodeSchema = astSchemaWithFamily(family, label)
assert(nodeSchema.fields.size == fields.size, s"Node type $family.$label expects ${nodeSchema.fields.size} but pattern has ${fields.size}: $pattern")
var runningScope = scope
for( (Field(fieldName, fieldType), fieldMatch) <- nodeSchema.fields.zip(fields) )
{
runningScope = apply(fieldMatch, target, fieldType, schema, runningScope)
assert(
Typecheck.escalatesTo(
Type.Node(label),
typeOfCurrentNode,
schema
),
s"Can't possibly match a Node[$label] with an item of type $typeOfCurrentNode: $pattern"
)
fields.zipWithIndex.foldLeft(
state.bindPath(targetPath, Type.Node(label))
) { case (state, (fieldPattern, fieldIdx)) =>
check(fieldPattern, targetPath :+ fieldIdx, state)
}
runningScope ++ target.map { _ -> Type.Node(label) }.toMap
}
} catch {
case e: AssertionError =>
// e.printStackTrace()
throw new MatchTypecheckError(e.getMessage(), pattern::Nil)
case TypecheckExpression.ExpressionTypecheckError(msg, stack) =>
throw new MatchTypecheckError(s"${stack.head}\n ,--in--^\n$msg", pattern::Nil)

View File

@ -0,0 +1,20 @@
@import com.astraldb.spec.Definition
@import com.astraldb.spec.Type
@import com.astraldb.bdd.BDD
@import com.astraldb.codegen
@import com.astraldb.codegen.Code
@(schema: Definition, family: Type.AST, bdd: BDD)
object @{family.scalaType}BDD extends Rule[@{family.scalaType}]
{
def apply(plan: @{family.scalaType}): @{family.scalaType} =
{
@{codegen.BDD(schema, family, bdd,
onFail = Code.Literal("plan"),
root = Code.Literal("plan")
)
.toString(indent = 4)
.stripPrefix(" ")}
}
}

View File

@ -9,15 +9,14 @@
@(schema: Definition, rule: Rule)
object @{rule.safeLabel} extends Rule[LogicalPlan]
object @{rule.safeLabel} extends Rule[@{rule.family.scalaType}]
{
def apply(plan: LogicalPlan): LogicalPlan =
def apply(plan: @{rule.family.scalaType}): @{rule.family.scalaType} =
{
@{
val matchSchema = TypecheckMatch(
rule.pattern,
Some("plan"),
Type.AST(rule.family),
rule.family,
schema,
schema.globals
)
@ -26,7 +25,8 @@ object @{rule.safeLabel} extends Rule[LogicalPlan]
schema = schema,
pattern = rule.pattern,
target = Code.Literal("plan"),
targetType = Type.AST(rule.family),
targetPath = Seq.empty,
targetType = rule.family,
onSuccess = Expression(schema, rule.rewrite, _),
onFail = { _ => Code.Literal("plan") },
name = Some("plan"),