Merge branch 'main' of git.odin.cse.buffalo.edu:Astral/astral-compiler
commit
88760113ed
|
@ -1,7 +1,9 @@
|
|||
package com.astraldb.catalyst
|
||||
|
||||
import com.astraldb.codegen.Render
|
||||
import com.astraldb.codegen.Rule
|
||||
import com.astraldb.bdd.{ Compiler => BDDCompiler }
|
||||
import com.astraldb.spec.Type
|
||||
|
||||
object Astral
|
||||
{
|
||||
|
@ -18,9 +20,15 @@ object Astral
|
|||
|
||||
}
|
||||
|
||||
println(
|
||||
BDDCompiler.bdd(definition)
|
||||
)
|
||||
val family = Type.AST("LogicalPlan")
|
||||
|
||||
val bdd =
|
||||
BDDCompiler.bdd(definition, family)
|
||||
|
||||
print(bdd)
|
||||
|
||||
println(
|
||||
Rule(definition, bdd, family)
|
||||
)
|
||||
}
|
||||
}
|
|
@ -7,6 +7,7 @@ import com.astraldb.typecheck.TypecheckMatch.MatchTypecheckError
|
|||
import com.astraldb.typecheck.TypecheckExpression.ExpressionTypecheckError
|
||||
|
||||
import com.astraldb.codegen.Render
|
||||
import com.astraldb.spec.Type
|
||||
|
||||
object Generate
|
||||
{
|
||||
|
@ -20,7 +21,12 @@ object Generate
|
|||
// err.printStackTrace()
|
||||
System.exit(-1)
|
||||
}
|
||||
for( (file, content) <- Render(Catalyst.definition) )
|
||||
|
||||
val families = Seq[Type.AST](
|
||||
Type.AST("LogicalPlan")
|
||||
)
|
||||
|
||||
for( (file, content) <- Render(Catalyst.definition, families) )
|
||||
{
|
||||
val of = new BufferedWriter(new FileWriter(file))
|
||||
of.write("""package com.astraldb.catalyst
|
||||
|
|
|
@ -35,6 +35,7 @@ case class PickByType(
|
|||
case class PickByMatch(
|
||||
path: Seq[Int],
|
||||
matcher: Match,
|
||||
bindings: Seq[Pathed[(String, Type)]],
|
||||
ifMatched: BDD,
|
||||
ifNotMatched: BDD,
|
||||
) extends BDD
|
||||
|
@ -50,6 +51,8 @@ case class PickByMatch(
|
|||
case class BindAnExpression(
|
||||
symbol: String,
|
||||
expression: Expression,
|
||||
exprType: Type,
|
||||
bindings: Seq[Pathed[(String, Type)]],
|
||||
andThen: BDD
|
||||
) extends BDD
|
||||
{
|
||||
|
@ -63,11 +66,13 @@ case class BindAnExpression(
|
|||
|
||||
case class Rewrite(
|
||||
label: String,
|
||||
rewrite: Expression
|
||||
rewrite: Expression,
|
||||
bindings: Seq[Pathed[(String, Type)]],
|
||||
) extends BDD
|
||||
{
|
||||
val schema = rewrite.references.toSeq
|
||||
def code =
|
||||
Code.Literal(s"Rewrite with $label")
|
||||
Code.Literal(s"Rewrite with $label($schema)")
|
||||
}
|
||||
|
||||
case object NoRewrite extends BDD
|
||||
|
|
|
@ -3,6 +3,8 @@ package com.astraldb.bdd
|
|||
import com.astraldb.spec.Match
|
||||
import com.astraldb.expression.Expression
|
||||
import com.astraldb.spec.Type
|
||||
import com.astraldb.spec.Definition
|
||||
import com.astraldb.typecheck.TypecheckExpression
|
||||
|
||||
sealed trait CandidateStep
|
||||
|
||||
|
@ -13,16 +15,21 @@ object CheckTypeStep
|
|||
CheckTypeStep(pathed.path, pathed.value)
|
||||
}
|
||||
|
||||
case class CheckMatchStep(path: Pathed.Path, matcher: Match) extends CandidateStep
|
||||
case class CheckMatchStep(path: Pathed.Path, matcher: Match, bindings: Seq[Pathed[(String, Type)]]) extends CandidateStep
|
||||
object CheckMatchStep
|
||||
{
|
||||
def apply(pathed: Pathed[Match]): CheckMatchStep =
|
||||
CheckMatchStep(pathed.path, pathed.value)
|
||||
def apply(pathed: Pathed[Match], target: Target): CheckMatchStep =
|
||||
{
|
||||
val refs = pathed.value.references.map { _.v }.toSet
|
||||
CheckMatchStep(pathed.path, pathed.value,
|
||||
target.bindings.filter { b => refs(b.value._1) }
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
case class BindExpressionStep(symbol: String, expression: Expression) extends CandidateStep
|
||||
case class BindExpressionStep(symbol: String, expression: Expression, exprType: Type, bindings: Seq[Pathed[(String, Type)]]) extends CandidateStep
|
||||
object BindExpressionStep
|
||||
{
|
||||
def apply(binding: (String, Expression)): BindExpressionStep =
|
||||
BindExpressionStep(binding._1, binding._2)
|
||||
def apply(binding: (String, Expression, Type), target: Target): BindExpressionStep =
|
||||
BindExpressionStep(binding._1, binding._2, binding._3, target.bindings)
|
||||
}
|
||||
|
|
|
@ -4,6 +4,10 @@ import com.astraldb.spec.Definition
|
|||
import com.astraldb.spec.Match
|
||||
import com.astraldb.spec.Type
|
||||
import com.astraldb.expression.Expression
|
||||
import com.astraldb.spec.Rule
|
||||
import com.astraldb.expression.Var
|
||||
import com.astraldb.expression.IfIsDefined
|
||||
import com.astraldb.typecheck.TypecheckMatch
|
||||
|
||||
object Compiler
|
||||
{
|
||||
|
@ -93,24 +97,67 @@ object Compiler
|
|||
disjunctify(matcher)
|
||||
}
|
||||
|
||||
def getTargets(schema: Definition): Seq[Target] =
|
||||
def getClauseTargets(schema: Definition, rule: Rule): Seq[Target] =
|
||||
{
|
||||
schema.rules.flatMap { rule =>
|
||||
val clauses =
|
||||
normalize(schema, rule.pattern).clauses
|
||||
clauses.zipWithIndex.map {
|
||||
case (clause, idx) =>
|
||||
val label =
|
||||
rule.label + (if(clauses.size > 1) { "-"+idx } else { "" })
|
||||
Target.of(label, clause, rule.rewrite)
|
||||
.withoutUnusedBindings
|
||||
}
|
||||
val clauses:Seq[(Conjunction, Set[String])] =
|
||||
normalize(schema, rule.pattern).clauses
|
||||
.map { c =>
|
||||
c -> TypecheckMatch(c.asMatch, rule.family, schema, schema.globals)
|
||||
.keySet
|
||||
}
|
||||
|
||||
val optionalBindings:Set[String] =
|
||||
clauses.flatMap { _._2 }
|
||||
.groupBy { x => x }
|
||||
.filter { _._2.size < clauses.size }
|
||||
.keySet
|
||||
|
||||
clauses.zipWithIndex.map {
|
||||
case ((clause, bindings), idx) =>
|
||||
val label =
|
||||
rule.label + (if(clauses.size > 1) { "-"+idx } else { "" })
|
||||
|
||||
val optionalReferencesInClause =
|
||||
bindings & optionalBindings
|
||||
|
||||
// if(!optionalBindings.isEmpty)
|
||||
// {
|
||||
// println(s"Optional: $optionalBindings")
|
||||
// println(s"Optional Bound: $optionalReferencesInClause")
|
||||
// }
|
||||
|
||||
val sanitizedRewrite =
|
||||
rule.rewrite.transform {
|
||||
case IfIsDefined(symbol, ifYes, ifNo)
|
||||
if optionalBindings(symbol)=>
|
||||
if(optionalReferencesInClause(symbol)) { ifYes }
|
||||
else { ifNo }
|
||||
}
|
||||
|
||||
// if(sanitizedRewrite != rule.rewrite)
|
||||
// {
|
||||
// println(rule.rewrite)
|
||||
// println("into")
|
||||
// println(sanitizedRewrite)
|
||||
// }
|
||||
|
||||
Target.of(schema, label, rule.family, clause, sanitizedRewrite)
|
||||
.withoutUnusedBindings
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
def getTargets(schema: Definition, family: Type.AST): Seq[Target] =
|
||||
{
|
||||
schema.rules.flatMap {
|
||||
case rule if rule.family != family => None
|
||||
case rule => getClauseTargets(schema, rule)
|
||||
}
|
||||
}
|
||||
|
||||
def getSanitizedTargets(schema: Definition): Seq[Target] =
|
||||
def getSanitizedTargets(schema: Definition, family: Type.AST): Seq[Target] =
|
||||
{
|
||||
var targets = getTargets(schema)
|
||||
var targets = getTargets(schema, family)
|
||||
|
||||
// Sanitize expression bindings via alpha renaming and
|
||||
// common subexpression elimination
|
||||
|
@ -146,7 +193,7 @@ object Compiler
|
|||
.toSet.toSeq
|
||||
.groupBy { _._1 }
|
||||
.filter { _._2.size > 1 }
|
||||
.map { case (name:String, exprs:Seq[(String, Expression)]) =>
|
||||
.map { case (name:String, exprs:Seq[(String, Expression, Type)]) =>
|
||||
name ->
|
||||
exprs.map { _._2 }
|
||||
.zipWithIndex
|
||||
|
@ -159,7 +206,7 @@ object Compiler
|
|||
targets = targets.map { target =>
|
||||
val repair =
|
||||
target.exprBindings
|
||||
.flatMap { case (name, expr) =>
|
||||
.flatMap { case (name, expr, t) =>
|
||||
replacementNames.get(name)
|
||||
.flatMap { _.get(expr).map { newName =>
|
||||
name -> newName
|
||||
|
@ -173,16 +220,15 @@ object Compiler
|
|||
}
|
||||
|
||||
|
||||
def bdd(schema: Definition): BDD =
|
||||
def bdd(schema: Definition, family: Type.AST): BDD =
|
||||
{
|
||||
println(getSanitizedTargets(schema).mkString("\n---\n"))
|
||||
def stepBdd(targets: Seq[Target], state: State): BDD =
|
||||
{
|
||||
val success = targets.find { _.successfulOn(state) }
|
||||
|
||||
if(success.isDefined)
|
||||
{
|
||||
Rewrite(success.get.label, success.get.rewrite)
|
||||
Rewrite(success.get.label, success.get.rewrite, success.get.bindings)
|
||||
} else {
|
||||
val activeTargets =
|
||||
targets.filter { _.canBeSuccessfulOn(state) }
|
||||
|
@ -194,7 +240,7 @@ object Compiler
|
|||
val candidates:Seq[CandidateStep] =
|
||||
activeTargets
|
||||
.flatMap { t =>
|
||||
val c = t.candidates(state)
|
||||
val c = t.candidates(schema, state)
|
||||
assert(!c.isEmpty, s"Target ${t.label} in state \n$state\n is not successful or failed, but has no candidates for progress.")
|
||||
/* return */ c
|
||||
}
|
||||
|
@ -215,7 +261,7 @@ object Compiler
|
|||
val types:Set[Type] =
|
||||
bestCheckTypeStep.get._2.map { _._2.check }.toSet
|
||||
|
||||
println(s"Checking path: [${path.mkString(", ")}] for ${types.mkString(", ")}")
|
||||
System.err.println(s"Checking path: [${path.mkString(", ")}] for ${types.mkString(", ")}")
|
||||
|
||||
PickByType(
|
||||
path,
|
||||
|
@ -241,10 +287,11 @@ object Compiler
|
|||
|
||||
bestOtherCandidate match {
|
||||
case _:CheckTypeStep => assert(false, "We filtered out all the check type steps by this point")
|
||||
case CheckMatchStep(path, matcher) =>
|
||||
case CheckMatchStep(path, matcher, bindings) =>
|
||||
PickByMatch(
|
||||
path = path,
|
||||
matcher = matcher,
|
||||
bindings = bindings,
|
||||
ifMatched =
|
||||
stepBdd(activeTargets,
|
||||
state.withSuccessfulMatch(path, matcher)),
|
||||
|
@ -252,8 +299,10 @@ object Compiler
|
|||
stepBdd(activeTargets,
|
||||
state.withFailedMatch(path, matcher)),
|
||||
)
|
||||
case BindExpressionStep(symbol, expression) =>
|
||||
case BindExpressionStep(symbol, expression, exprType, bindings) =>
|
||||
BindAnExpression(symbol, expression,
|
||||
exprType,
|
||||
bindings,
|
||||
stepBdd(activeTargets,
|
||||
state.withBinding(symbol)
|
||||
)
|
||||
|
@ -265,7 +314,7 @@ object Compiler
|
|||
}
|
||||
}
|
||||
|
||||
stepBdd(getSanitizedTargets(schema), State.empty(schema))
|
||||
stepBdd(getSanitizedTargets(schema, family), State.empty(schema))
|
||||
}
|
||||
|
||||
}
|
|
@ -3,21 +3,38 @@ package com.astraldb.bdd
|
|||
import com.astraldb.expression._
|
||||
import com.astraldb.spec._
|
||||
import scala.util.Try
|
||||
import com.astraldb.typecheck.TypecheckExpression
|
||||
|
||||
case class Target(
|
||||
label: String,
|
||||
family: Type.AST,
|
||||
rewrite: Expression,
|
||||
bindings: Seq[Pathed[String]] = Seq.empty,
|
||||
exprBindings: Seq[(String, Expression)] = Seq.empty,
|
||||
bindings: Seq[Pathed[(String, Type)]] = Seq.empty,
|
||||
exprBindings: Seq[(String, Expression, Type)] = Seq.empty,
|
||||
types: Seq[Pathed[Type]] = Seq.empty,
|
||||
matchers: Seq[Pathed[Match]] = Seq.empty,
|
||||
)
|
||||
{
|
||||
def withBinding(path: Seq[Int], symbol: String) =
|
||||
copy(bindings = bindings :+ Pathed(path, symbol))
|
||||
{
|
||||
val varType =
|
||||
types.find { _.path == path }
|
||||
.map { _.value }
|
||||
.getOrElse { family }
|
||||
copy(
|
||||
bindings = bindings :+ Pathed(path, (symbol, varType))
|
||||
)
|
||||
}
|
||||
|
||||
def withType(path: Seq[Int], pathType: Type) =
|
||||
copy(types = types :+ Pathed(path, pathType))
|
||||
copy(
|
||||
types = types :+ Pathed(path, pathType),
|
||||
bindings = bindings.map {
|
||||
case b if b.path == path =>
|
||||
b.copy( value = b.value._1 -> pathType )
|
||||
case b => b
|
||||
}
|
||||
)
|
||||
|
||||
def withMatcher(path: Seq[Int], matcher: Match) =
|
||||
if(matcher == Match.Any){ this }
|
||||
|
@ -25,8 +42,8 @@ case class Target(
|
|||
copy(matchers = matchers :+ Pathed(path, matcher))
|
||||
}
|
||||
|
||||
def withBoundExpression(symbol: String, expr: Expression) =
|
||||
copy(exprBindings = exprBindings :+ (symbol, expr))
|
||||
def withBoundExpression(symbol: String, expr: Expression, exprType: Type) =
|
||||
copy(exprBindings = exprBindings :+ (symbol, expr, exprType))
|
||||
|
||||
def withoutUnusedBindings =
|
||||
{
|
||||
|
@ -36,11 +53,11 @@ case class Target(
|
|||
exprBindings.flatMap { _._2.references }.map { _.v}
|
||||
|
||||
copy(
|
||||
bindings = bindings.filter { b => refs(b.value) }
|
||||
bindings = bindings.filter { b => refs(b.value._1) }
|
||||
)
|
||||
}
|
||||
|
||||
def candidates(state: State): Seq[CandidateStep] =
|
||||
def candidates(schema: Definition, state: State): Seq[CandidateStep] =
|
||||
{
|
||||
val requestedTypedPaths = types.map { _.path }.toSet
|
||||
val typedPaths = state.typeOf.keySet
|
||||
|
@ -57,7 +74,7 @@ case class Target(
|
|||
if(requestedTypedPaths(b.path)) { b.pathInSet(typedPaths) }
|
||||
// If the binding itself is not typed, then we just need its ancestors bound
|
||||
else { b.ancestorsInSet(typedPaths) }
|
||||
}.map { b => Var(b.value) }
|
||||
}.map { b => Var(b.value._1) }
|
||||
).toSet
|
||||
|
||||
// println(s"Valid bindings: $validBindings")
|
||||
|
@ -101,8 +118,8 @@ case class Target(
|
|||
// println(s"Candidate Matchers: $candidateMatchers")
|
||||
|
||||
(candidateTypings.map { CheckTypeStep(_) }: Seq[CandidateStep]) ++
|
||||
candidateMatchers.map { CheckMatchStep(_) } ++
|
||||
candidateBindings.map { BindExpressionStep(_) }
|
||||
candidateMatchers.map { CheckMatchStep(_, this) } ++
|
||||
candidateBindings.map { BindExpressionStep(_, this) }
|
||||
}
|
||||
|
||||
def canBeSuccessfulOn(state: State): Boolean =
|
||||
|
@ -131,13 +148,14 @@ case class Target(
|
|||
copy(
|
||||
rewrite = rewrite.rename(replacementMap),
|
||||
bindings = bindings.map { b =>
|
||||
replacementMap.get(b.value)
|
||||
.map { r => b.copy(value = r) }
|
||||
replacementMap.get(b.value._1)
|
||||
.map { r => b.copy(value = (r, b.value._2)) }
|
||||
.getOrElse(b)
|
||||
},
|
||||
exprBindings = exprBindings.map { case (name, expr) =>
|
||||
exprBindings = exprBindings.map { case (name, expr, exprType) =>
|
||||
( replacementMap.getOrElse(name, name),
|
||||
expr.rename(replacementMap)
|
||||
expr.rename(replacementMap),
|
||||
exprType
|
||||
)
|
||||
},
|
||||
matchers = matchers.map { m =>
|
||||
|
@ -146,6 +164,12 @@ case class Target(
|
|||
)
|
||||
}
|
||||
|
||||
def scope: Map[String, Type] =
|
||||
(
|
||||
bindings.map { _.value } ++
|
||||
exprBindings.map { b => b._1 -> b._3 }
|
||||
).toMap
|
||||
|
||||
override def toString(): String =
|
||||
label+
|
||||
(if(types.isEmpty){ "" } else {
|
||||
|
@ -158,7 +182,7 @@ case class Target(
|
|||
})+
|
||||
(if(exprBindings.isEmpty){ "" } else {
|
||||
"\n Expression Bindings\n"+
|
||||
exprBindings.map { case (s, e) => " "+s+" <- "+e }.mkString("\n")
|
||||
exprBindings.map { case (s, e, t) => " "+s+" <- "+e }.mkString("\n")
|
||||
})+
|
||||
(if(matchers.isEmpty){ "" } else {
|
||||
"\n Matchers\n"+
|
||||
|
@ -169,9 +193,9 @@ case class Target(
|
|||
|
||||
object Target
|
||||
{
|
||||
def of(label: String, clause: Conjunction, rewrite: Expression): Target =
|
||||
def of(schema: Definition, label: String, family: Type.AST, clause: Conjunction, rewrite: Expression): Target =
|
||||
{
|
||||
clause.atoms.foldLeft(Target(label, rewrite)){ (target, atom) =>
|
||||
clause.atoms.foldLeft(Target(label, family, rewrite)){ (target, atom) =>
|
||||
atom match {
|
||||
case Match.Path(path, Match.Bind(symbol, child)) =>
|
||||
target.withBinding(path, symbol)
|
||||
|
@ -186,7 +210,9 @@ object Target
|
|||
case Match.Path(path, child) =>
|
||||
target.withMatcher(path, child)
|
||||
case Match.BindExpression(symbol, expr) =>
|
||||
target.withBoundExpression(symbol, expr)
|
||||
target.withBoundExpression(symbol, expr,
|
||||
TypecheckExpression(expr, schema, target.scope++schema.globals)
|
||||
)
|
||||
case _ =>
|
||||
target.withMatcher(Seq.empty, atom)
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package com.astraldb.bdd
|
|||
|
||||
import com.astraldb.spec.Match
|
||||
import com.astraldb.codegen.Code
|
||||
import com.astraldb.expression.Var
|
||||
|
||||
case class Conjunction(atoms: Seq[Match])
|
||||
{
|
||||
|
@ -28,6 +29,10 @@ case class Conjunction(atoms: Seq[Match])
|
|||
")"
|
||||
)
|
||||
}
|
||||
|
||||
def references:Set[Var] = atoms.flatMap { _.references }.toSet
|
||||
|
||||
def asMatch = Match.And(atoms)
|
||||
}
|
||||
|
||||
object Conjunction
|
||||
|
@ -78,6 +83,10 @@ case class Disjunction(clauses: Seq[Conjunction])
|
|||
")"
|
||||
)
|
||||
}
|
||||
|
||||
def references:Set[Var] = clauses.flatMap { _.references }.toSet
|
||||
|
||||
def asMatch = Match.Or(clauses.map { _.asMatch })
|
||||
}
|
||||
|
||||
object Disjunction
|
||||
|
|
|
@ -0,0 +1,160 @@
|
|||
package com.astraldb.codegen
|
||||
|
||||
import com.astraldb.spec.Definition
|
||||
import com.astraldb.spec.Type
|
||||
import com.astraldb.bdd
|
||||
import com.astraldb.codegen.Code.PaddedString
|
||||
import com.astraldb.expression.Var
|
||||
|
||||
object BDD
|
||||
{
|
||||
def apply(
|
||||
schema: Definition,
|
||||
family: Type.AST,
|
||||
rule: bdd.BDD,
|
||||
onFail: Code,
|
||||
root: Code,
|
||||
): Code =
|
||||
{
|
||||
def codeScopeFor[T](refs: Set[Var], bindings: Seq[bdd.Pathed[(String, Type)]], pathNames: Map[bdd.Pathed.Path, (Code, Type)], boundVars: Map[String, Type], context: T): CodeScope =
|
||||
CodeScope(
|
||||
(
|
||||
(refs.map { _.v }
|
||||
-- boundVars.keys
|
||||
-- schema.globals.keys
|
||||
).toSeq.map { ref =>
|
||||
val pathRef =
|
||||
bindings.find { _.value._1 == ref }
|
||||
.getOrElse {
|
||||
assert(false, s"Reference to a variable that hasn't been bound yet: $ref (in $bindings):\n$context")
|
||||
}
|
||||
val varDescription =
|
||||
pathNames.find { _._1 == pathRef.path }
|
||||
.getOrElse {
|
||||
assert(false, s"Reference to a variable at an unbound path: ${pathRef.path} (in $pathNames):\n$context")
|
||||
}
|
||||
._2
|
||||
ref -> varDescription
|
||||
}.toMap:Map[String, Type|(Code,Type)]
|
||||
) ++ (boundVars:Map[String, Type|(Code,Type)])
|
||||
++ (schema.globals:Map[String, Type|(Code,Type)])
|
||||
)
|
||||
|
||||
def recur(rule: bdd.BDD, pathNames: Map[bdd.Pathed.Path, (Code, Type)], boundVars: Map[String, Type]): Code =
|
||||
rule match {
|
||||
case bdd.NoRewrite => onFail
|
||||
case bdd.PickByType(path, types) =>
|
||||
// println(s"Generating code for $path")
|
||||
val target = pathNames.get(path)
|
||||
.getOrElse {
|
||||
assert(false, s"No name available for $path (in $pathNames):\n$rule")
|
||||
}
|
||||
._1
|
||||
Code.IfElifElse(
|
||||
types.map { case (targetType, andThen) =>
|
||||
{
|
||||
val nodeName = "root"+path.map { "_child"+_ }.mkString+"_as_"+targetType.scalaType
|
||||
val newPathNames: Map[bdd.Pathed.Path, (Code, Type)] =
|
||||
targetType match {
|
||||
case Type.Node(node) =>
|
||||
schema.nodesByName(node)
|
||||
.fields
|
||||
.zipWithIndex
|
||||
.map { case (field, idx) =>
|
||||
(path :+ idx) -> (
|
||||
Code.BinOp(
|
||||
Code.Literal(nodeName),
|
||||
".",
|
||||
Code.Literal(field.name)
|
||||
),
|
||||
field.t
|
||||
)
|
||||
}
|
||||
.toMap
|
||||
case _ => Map.empty
|
||||
}
|
||||
Code.Pair(target, Code.Literal(s".isInstanceOf[${targetType.scalaType}]"))
|
||||
-> Code.Block(Seq(
|
||||
Code.BinOp(
|
||||
Code.Literal(s"val $nodeName"),
|
||||
PaddedString.pad(1, "="),
|
||||
Code.Pair(target,
|
||||
Code.Literal(s".asInstanceOf[${targetType.scalaType}]")
|
||||
)
|
||||
)
|
||||
)++recur(
|
||||
andThen,
|
||||
pathNames ++ newPathNames ++ Map(path -> (Code.Literal(nodeName), targetType)),
|
||||
boundVars,
|
||||
).block
|
||||
)
|
||||
}
|
||||
}.toSeq,
|
||||
onFail
|
||||
)
|
||||
case bdd.PickByMatch(path, pattern, bindings, onMatch, onFail) =>
|
||||
val target = pathNames.get(path)
|
||||
.getOrElse {
|
||||
assert(false, s"No name available for $path (in $pathNames):\n$rule")
|
||||
}
|
||||
val scope =
|
||||
codeScopeFor(pattern.references, bindings, pathNames, boundVars, pattern)
|
||||
Match(
|
||||
schema = schema,
|
||||
pattern = pattern,
|
||||
target = target._1,
|
||||
targetPath = path,
|
||||
targetType = target._2,
|
||||
onSuccess =
|
||||
_ => recur(
|
||||
onMatch,
|
||||
pathNames,
|
||||
boundVars,
|
||||
),
|
||||
onFail =
|
||||
_ => recur(
|
||||
onFail,
|
||||
pathNames,
|
||||
boundVars,
|
||||
),
|
||||
name = None,
|
||||
scope = scope
|
||||
)
|
||||
case bdd.Rewrite(label, op, bindings) =>
|
||||
Code.Block(
|
||||
Seq(
|
||||
Code.Literal(s"// Rewrite by $label"),
|
||||
) ++ Expression(
|
||||
schema,
|
||||
op,
|
||||
codeScopeFor(op.references, bindings, pathNames, boundVars, op)
|
||||
).block
|
||||
)
|
||||
|
||||
case bdd.BindAnExpression(symbol, expression, exprType, bindings, andThen) =>
|
||||
Code.Block(
|
||||
Seq(
|
||||
Code.Pair(
|
||||
Code.Literal(s"val $symbol ="),
|
||||
Expression(
|
||||
schema, expression,
|
||||
codeScopeFor(expression.references, bindings, pathNames, boundVars, expression)
|
||||
)
|
||||
)
|
||||
)++recur(
|
||||
andThen,
|
||||
pathNames,
|
||||
boundVars ++ Map(symbol -> exprType)
|
||||
).block
|
||||
)
|
||||
}
|
||||
|
||||
recur(
|
||||
rule,
|
||||
Map(
|
||||
Seq.empty -> (root, family)
|
||||
),
|
||||
Map.empty,
|
||||
)
|
||||
}
|
||||
}
|
|
@ -321,4 +321,34 @@ object Code
|
|||
|
||||
def Parenthesize(body: Code): Code =
|
||||
Parens("(", body, ")")
|
||||
|
||||
def IndentedPair(left: Code, right: Code): Code =
|
||||
Pair(left, Indent(2, right))
|
||||
|
||||
def IfElifElse(ifThenClauses: Seq[(Code, Code)], elseClause: Code): Code =
|
||||
{
|
||||
Block(
|
||||
Seq(
|
||||
IndentedPair(
|
||||
Parens("if(", ifThenClauses.head._1, PaddedString.rightPad(1, ") {")),
|
||||
ifThenClauses.head._2,
|
||||
)
|
||||
) ++ ifThenClauses.tail.map { it =>
|
||||
IndentedPair(
|
||||
Parens(
|
||||
PaddedString.leftPad(1, "} else if("),
|
||||
it._1,
|
||||
PaddedString.rightPad(1, ") {")
|
||||
),
|
||||
it._2,
|
||||
)
|
||||
} ++ Seq(
|
||||
Parens(
|
||||
PaddedString.pad(1, "} else {"),
|
||||
elseClause,
|
||||
PaddedString.leftPad(1, "}")
|
||||
)
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
|
@ -2,7 +2,7 @@ package com.astraldb.codegen
|
|||
|
||||
import com.astraldb.spec
|
||||
|
||||
class CodeScope(vars: Map[String, spec.Type | (Code, spec.Type)])
|
||||
case class CodeScope(vars: Map[String, spec.Type | (Code, spec.Type)])
|
||||
{
|
||||
type Element = spec.Type | (Code, spec.Type)
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@ object Match
|
|||
schema: spec.Definition,
|
||||
pattern: spec.Match,
|
||||
target: Code,
|
||||
targetPath: Seq[Int],
|
||||
targetType: spec.Type,
|
||||
onSuccess: CodeScope => Code,
|
||||
onFail: CodeScope => Code,
|
||||
|
@ -25,16 +26,18 @@ object Match
|
|||
{
|
||||
pattern match {
|
||||
case And(Seq()) => onSuccess(scope)
|
||||
case And(Seq(a)) => apply(schema, a, target, targetType, onSuccess, onFail, name, scope)
|
||||
case And(Seq(a)) => apply(schema, a, target, targetPath, targetType, onSuccess, onFail, name, scope)
|
||||
case And(a) =>
|
||||
apply(schema, a.head,
|
||||
target = target,
|
||||
targetPath = targetPath,
|
||||
targetType = targetType,
|
||||
onSuccess =
|
||||
scope =>
|
||||
apply(schema, And(a.tail),
|
||||
target = target,
|
||||
targetType = targetType,
|
||||
targetPath = targetPath,
|
||||
onSuccess = onSuccess,
|
||||
onFail = onFail,
|
||||
name = name,
|
||||
|
@ -45,35 +48,45 @@ object Match
|
|||
scope = scope
|
||||
)
|
||||
case Not(a) =>
|
||||
apply(schema, a, target, targetType, onFail, onSuccess, name, scope)
|
||||
apply(schema, a, target, targetPath, targetType, onFail, onSuccess, name, scope)
|
||||
case Or(Seq()) => onFail(scope)
|
||||
case Or(Seq(a)) =>
|
||||
apply(schema, a, target, targetType, onSuccess, onFail, name, scope)
|
||||
apply(schema, a, target, targetPath, targetType, onSuccess, onFail, name, scope)
|
||||
case Or(a) =>
|
||||
val optional:Map[String, Type] =
|
||||
a.flatMap {
|
||||
TypecheckMatch(_, name, targetType, schema, scope.flatten).toSeq
|
||||
val optional:Map[Var|TypecheckMatch.PathVar, Type] =
|
||||
a.flatMap { m =>
|
||||
TypecheckMatch(
|
||||
m,
|
||||
targetType,
|
||||
schema,
|
||||
scope.flatten
|
||||
).toSeq
|
||||
}.groupBy { _._1 }
|
||||
.filter { _._2.size < a.size }
|
||||
.mapValues { case v => Typecheck.glb(v.map { _._2 }, schema) }
|
||||
.map { case (v, t) =>
|
||||
Var(v) ->
|
||||
Typecheck.glb(t.map { _._2 }, schema)
|
||||
}
|
||||
.toMap
|
||||
|
||||
a.foldRight(onFail) { (matcher, nextOnFail) =>
|
||||
_ =>
|
||||
apply(schema, matcher,
|
||||
target = target,
|
||||
targetPath = targetPath,
|
||||
targetType = targetType,
|
||||
onSuccess = {
|
||||
(scope: CodeScope) => onSuccess(
|
||||
scope.map {
|
||||
case (name, t: Type) if optional contains name =>
|
||||
case (name, t: Type) if optional contains Var(name) =>
|
||||
name -> (Code.Literal(s"(Some($name):${Type.Option(t).scalaType})"), Type.Option(t))
|
||||
case (name, c: (Code, Type)) if optional contains name =>
|
||||
case (name, c: (Code, Type)) if optional contains Var(name) =>
|
||||
name -> (Code.Parens(s"(Some(", c._1, s"):${Type.Option(c._2).scalaType})"), Type.Option(c._2))
|
||||
case x => x
|
||||
}.withVars(
|
||||
(optional.keySet -- scope.keys).toSeq.map { k =>
|
||||
k -> (Code.Literal(s"(None:${Type.Option(optional(k)).scalaType})"):Code, Type.Option(optional(k)):Type)
|
||||
(optional.keySet -- scope.keys.map { Var(_) }).toSeq.collect {
|
||||
case k:Var =>
|
||||
k.v -> (Code.Literal(s"(None:${Type.Option(optional(k)).scalaType})"):Code, Type.Option(optional(k)):Type)
|
||||
}:_*
|
||||
)
|
||||
)
|
||||
|
@ -93,6 +106,7 @@ object Match
|
|||
Var(selectedName) eq pattern
|
||||
)),
|
||||
target = target,
|
||||
targetPath = targetPath,
|
||||
targetType = targetType,
|
||||
onSuccess = onSuccess,
|
||||
onFail = onFail,
|
||||
|
@ -105,6 +119,7 @@ object Match
|
|||
schema = schema,
|
||||
pattern = pattern,
|
||||
target = target,
|
||||
targetPath = targetPath,
|
||||
targetType = targetType,
|
||||
onSuccess = onSuccess,
|
||||
onFail = onFail,
|
||||
|
@ -154,11 +169,13 @@ object Match
|
|||
) ++
|
||||
children
|
||||
.zip(node.fields)
|
||||
.foldRight(onSuccess) { case ((child, field), andThen) =>
|
||||
.zipWithIndex
|
||||
.foldRight(onSuccess) { case (((child, field), idx), andThen) =>
|
||||
nextScope => apply(
|
||||
schema = schema,
|
||||
pattern = child,
|
||||
target = Code.Literal(s"$selectedName.${field.name}"),
|
||||
targetPath = targetPath :+ idx,
|
||||
targetType = field.t,
|
||||
onSuccess = andThen,
|
||||
onFail = onFail,
|
||||
|
|
|
@ -1,14 +1,23 @@
|
|||
package com.astraldb.codegen
|
||||
|
||||
import com.astraldb.spec.Definition
|
||||
import com.astraldb.spec.Type
|
||||
import com.astraldb.bdd
|
||||
|
||||
object Render
|
||||
{
|
||||
def apply(schema: Definition): Map[String, String] =
|
||||
def apply(schema: Definition, bddFamilies: Seq[Type.AST] = Seq.empty): Map[String, String] =
|
||||
{
|
||||
val bddRule: Seq[(Type.AST, bdd.BDD)] =
|
||||
bddFamilies.map { f =>
|
||||
f -> bdd.Compiler.bdd(schema, f)
|
||||
}
|
||||
|
||||
Map(
|
||||
"Optimizer.scala" -> Optimizer(schema)
|
||||
) ++ schema.rules.map { rule =>
|
||||
"Optimizer.scala" -> Optimizer(schema),
|
||||
) ++ (bddRule.map { case (family, bdd) =>
|
||||
s"${family.scalaType}BDD.scala" -> Rule(schema, bdd, family)
|
||||
}.toMap) ++ schema.rules.map { rule =>
|
||||
s"${rule.safeLabel}.scala" -> Rule(schema, rule)
|
||||
}.toMap
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package com.astraldb.codegen
|
||||
|
||||
import com.astraldb.spec
|
||||
import com.astraldb.bdd
|
||||
|
||||
object Rule
|
||||
{
|
||||
|
@ -8,4 +9,9 @@ object Rule
|
|||
{
|
||||
scala.Rule(schema, rule).toString
|
||||
}
|
||||
|
||||
def apply(schema: spec.Definition, rule: bdd.BDD, family: spec.Type.AST) =
|
||||
{
|
||||
scala.BDDBatch(schema, family, rule).toString
|
||||
}
|
||||
}
|
|
@ -360,6 +360,8 @@ case class Let(symbol: String, value: Expression, rest: Expression) extends Expr
|
|||
def children = Seq(value, rest)
|
||||
def reassemble(in: Seq[Expression]): Expression = Let(symbol, in(0), in(1))
|
||||
override def toString = s"let $symbol = $value in $rest"
|
||||
override def references: Set[Var] =
|
||||
value.references ++ (rest.references -- Set(Var(symbol)))
|
||||
}
|
||||
/**
|
||||
* Branch on whether a symbol is defined in scope
|
||||
|
@ -370,4 +372,6 @@ case class IfIsDefined(symbol: String, ifYes: Expression, ifNo: Expression) exte
|
|||
def reassemble(in: Seq[Expression]): Expression =
|
||||
copy(ifYes = in(0), ifNo = in(1))
|
||||
override def toString = s"if($symbol exists){ $ifYes } else { $ifNo }"
|
||||
override def references: Set[Var] =
|
||||
super.references ++ Set(Var(symbol))
|
||||
}
|
|
@ -4,7 +4,7 @@ import scala.collection.mutable
|
|||
import com.astraldb.expression._
|
||||
|
||||
case class Definition(
|
||||
asts: Map[String, ASTDefinition],
|
||||
asts: Map[Type.AST, ASTDefinition],
|
||||
rules:Seq[Rule],
|
||||
globals: Map[String, Type],
|
||||
) {
|
||||
|
@ -20,7 +20,7 @@ case class Definition(
|
|||
|${rules.mkString("\n\n")}
|
||||
""".stripMargin
|
||||
|
||||
val familyOfNode: Map[String, String] =
|
||||
val familyOfNode: Map[String, Type.AST] =
|
||||
nodes.flatMap { case (family, elements) => elements.map { _.name -> family } }
|
||||
.toMap
|
||||
|
||||
|
@ -39,7 +39,7 @@ class HardcodedDefinition
|
|||
Definition(
|
||||
asts = nodes.map { case (f, n) =>
|
||||
val family = Type.AST(f)
|
||||
f -> ASTDefinition(
|
||||
family -> ASTDefinition(
|
||||
family = family,
|
||||
nodes = n.map { _.copy(family = family) }.toSet
|
||||
)}.toMap,
|
||||
|
@ -62,7 +62,7 @@ class HardcodedDefinition
|
|||
|
||||
def Rule(label: String, family: String)(pattern: Match)(replacement: Expression): Unit =
|
||||
rules.append(
|
||||
com.astraldb.spec.Rule(label, family, pattern, replacement)
|
||||
com.astraldb.spec.Rule(label, Type.AST(family), pattern, replacement)
|
||||
)
|
||||
|
||||
def Function(label: String, ret: Type = Type.Unit)(args: Type*): Unit =
|
||||
|
|
|
@ -6,7 +6,7 @@ import com.astraldb.typecheck.TypecheckExpression
|
|||
|
||||
case class Rule(
|
||||
label: String,
|
||||
family: String,
|
||||
family: Type.AST,
|
||||
pattern: Match,
|
||||
rewrite: Expression
|
||||
)
|
||||
|
@ -17,11 +17,11 @@ case class Rule(
|
|||
def validate(schema: Definition): Unit =
|
||||
{
|
||||
val scope =
|
||||
TypecheckMatch(pattern, None, Type.AST(family), schema, schema.globals)
|
||||
TypecheckMatch(pattern, family, schema, schema.globals)
|
||||
TypecheckExpression(rewrite, schema, scope) match {
|
||||
case Type.Node(n) =>
|
||||
assert(schema.familyOfNode(n) == family, s"Rule $label constructs a $n of family ${schema.familyOfNode(n)}, while the rule is for famil $family")
|
||||
case Type.AST(f) =>
|
||||
case f:Type.AST =>
|
||||
assert(f == family, s"Rule $label constructs a $f, while the rule is for family $family")
|
||||
case c =>
|
||||
assert(false, s"Rule $label generates a non-AST type $c")
|
||||
|
|
|
@ -27,6 +27,13 @@ object Typecheck
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute the greatest lower bound.
|
||||
*
|
||||
* This is the *more* precise type of the two arguments.
|
||||
*
|
||||
* e.g., glb(Node, AST) = Node
|
||||
*/
|
||||
def glb(a: Type, b: Type, schema: Definition): Type =
|
||||
{
|
||||
if(a == b){ return a }
|
||||
|
@ -49,12 +56,31 @@ object Typecheck
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute the greatest lower bound.
|
||||
*
|
||||
* This is the *more* precise type of the two arguments.
|
||||
*
|
||||
* e.g., glb(Node, AST) = Node
|
||||
*/
|
||||
def glb(elems: Seq[Type], schema: Definition): Type =
|
||||
elems.foldLeft(Type.Any: Type){ glb(_, _, schema) }
|
||||
|
||||
/**
|
||||
* Compute the least upper bound.
|
||||
*
|
||||
* This is the *less* precise type of the two arguments.
|
||||
*
|
||||
* e.g., glb(Node, AST) = AST
|
||||
*/
|
||||
def lub(a: Type, b: Type, schema: Definition): Type =
|
||||
lubOption(a, b, schema).getOrElse {
|
||||
assert(false, s"Types $a and $b don't unify")
|
||||
}
|
||||
|
||||
def lubOption(a: Type, b: Type, schema: Definition): Option[Type] =
|
||||
{
|
||||
if(a == b){ return a }
|
||||
if(a == b){ return Some(a) }
|
||||
(a, b) match {
|
||||
case (n:Type.Node, o:Type.ASTType) if
|
||||
schema.nodesByName(n.nodeType)
|
||||
|
|
|
@ -117,8 +117,8 @@ object TypecheckExpression
|
|||
case Type.Node(nodeType) =>
|
||||
val node = schema.nodesByName(nodeType)
|
||||
node.fields(index).t
|
||||
case _ =>
|
||||
assert(false, s"Node subscript on something not a node")
|
||||
case t =>
|
||||
assert(false, s"Node subscript on something not a node: $t")
|
||||
}
|
||||
|
||||
case StructSubscript(target, name) =>
|
||||
|
|
|
@ -6,19 +6,247 @@ import com.astraldb.expression._
|
|||
object TypecheckMatch
|
||||
{
|
||||
|
||||
def apply(pattern: Match, targetType: Type, schema: Definition, scope: Map[String, Type]): Map[String, Type] =
|
||||
{
|
||||
val finalState =
|
||||
check(
|
||||
pattern = pattern,
|
||||
targetPath = Seq.empty,
|
||||
State(
|
||||
scope = scope ++ schema.globals,
|
||||
pathTypes = Map(Seq.empty -> targetType),
|
||||
optionalBindings = Set.empty,
|
||||
schema = schema
|
||||
)
|
||||
)
|
||||
finalState.asScope
|
||||
}
|
||||
|
||||
case class PathVar(path: Seq[Int])
|
||||
|
||||
case class State(
|
||||
scope: Map[String, PathVar|Type],
|
||||
pathTypes: Map[Seq[Int], Type],
|
||||
optionalBindings: Set[String],
|
||||
schema: Definition
|
||||
)
|
||||
{
|
||||
assert(pathTypes.contains(Seq.empty), s"We should never create a state without a root path: $pathTypes")
|
||||
|
||||
override def toString(): String =
|
||||
"---- Scope ----\n"+
|
||||
scope.map {
|
||||
case (name, t:Type) if optionalBindings(name) => s"$name:$t (optional)\n"
|
||||
case (name, t:Type) => s"$name:$t\n"
|
||||
case (name, PathVar(path)) if optionalBindings(name) => s"$name @ [${path.mkString(", ")}] (optional)\n"
|
||||
case (name, PathVar(path)) => s"$name @ [${path.mkString(", ")}]\n"
|
||||
}.mkString +
|
||||
"---- PathTypes ----\n"+
|
||||
pathTypes.map {
|
||||
case (path, t) => s"@[${path.mkString(",")}]:$t\n"
|
||||
}.mkString
|
||||
|
||||
|
||||
def typeAtPath(path: Seq[Int]): Type =
|
||||
path.foldLeft( (
|
||||
Seq[Int](),
|
||||
pathTypes.get(Seq.empty)
|
||||
.getOrElse {
|
||||
System.err.println(this)
|
||||
assert(false, s"Trying to look up the type of a path when the root is untyped")
|
||||
}
|
||||
) ) { case ((subPath, t), field) =>
|
||||
val fieldPath = subPath :+ field
|
||||
val fieldType =
|
||||
t match {
|
||||
case Type.Node(label) =>
|
||||
assert(schema.nodesByName contains label, s"No AST node of type $label defined")
|
||||
val nodeSchema = schema.nodesByName(label)
|
||||
assert(nodeSchema.fields.size > field, s"Can't path into the $field'th element of a $t")
|
||||
nodeSchema.fields(field).t
|
||||
case Type.Array(base) =>
|
||||
base
|
||||
case _ =>
|
||||
// System.err.println(this)
|
||||
assert(false, s"Can't path into a simple type: $t @[${subPath.mkString(", ")}] / [${path.mkString(", ")}]")
|
||||
}
|
||||
(
|
||||
fieldPath,
|
||||
pathTypes.get(fieldPath)
|
||||
.map { Typecheck.glb(_, fieldType, schema) }
|
||||
.getOrElse { fieldType }
|
||||
)
|
||||
}
|
||||
._2
|
||||
|
||||
def typeOfVar(v: String): Type =
|
||||
{
|
||||
val t =
|
||||
scope.get(v)
|
||||
.getOrElse {
|
||||
assert(false, s"Unbound variable: $v")
|
||||
} match {
|
||||
case PathVar(path) => typeAtPath(path)
|
||||
case t:Type => t
|
||||
}
|
||||
if(optionalBindings(v)){ Type.Option(t) } else { t }
|
||||
}
|
||||
|
||||
def asScope: Map[String, Type] =
|
||||
scope.map { case (v, t) =>
|
||||
val base = t match {
|
||||
case t:Type => t
|
||||
case PathVar(path) => typeAtPath(path)
|
||||
}
|
||||
v -> (if(optionalBindings(v)){ Type.Option(base) } else { base })
|
||||
}.toMap
|
||||
|
||||
def pathOfVar(v: String): Seq[Int] =
|
||||
scope.get(v)
|
||||
.getOrElse {
|
||||
assert(false, s"Unbound variable: $v")
|
||||
} match {
|
||||
case PathVar(path) => path
|
||||
case t:Type =>
|
||||
assert(false, s"Expected $v to be bound to a path")
|
||||
}
|
||||
|
||||
def bindPath(path: Seq[Int], t: Type): State =
|
||||
copy(
|
||||
pathTypes =
|
||||
pathTypes ++ Map(
|
||||
path ->
|
||||
pathTypes.get(path)
|
||||
.map { Typecheck.glb(_, t, schema) }
|
||||
.getOrElse(t)
|
||||
)
|
||||
)
|
||||
|
||||
def bindVarToPath(v: String, path: Seq[Int]): State =
|
||||
copy(
|
||||
scope = scope ++ Map(
|
||||
v -> PathVar(path)
|
||||
)
|
||||
)
|
||||
|
||||
def bindVarToType(v: String, t: Type): State =
|
||||
copy(
|
||||
scope = scope ++ Map(
|
||||
v -> t
|
||||
)
|
||||
)
|
||||
|
||||
def or(other: State): State =
|
||||
{
|
||||
State(
|
||||
scope =
|
||||
(scope.keySet ++ other.scope.keySet).toSeq.flatMap { v =>
|
||||
(scope.get(v), other.scope.get(v)) match {
|
||||
case (Some(myT:Type), Some(otherT:Type)) =>
|
||||
Some(v -> Typecheck.lub(myT, otherT, schema))
|
||||
case (Some(myT:Type), Some(PathVar(otherVPath))) =>
|
||||
Some(v -> Typecheck.lub(myT, other.typeAtPath(otherVPath), schema))
|
||||
case (Some(PathVar(myVPath)), Some(otherT:Type)) =>
|
||||
Some(v -> Typecheck.lub(typeAtPath(myVPath), otherT, schema))
|
||||
case (Some(PathVar(myVPath)), Some(PathVar(otherVPath))) =>
|
||||
if(myVPath != otherVPath){
|
||||
Some(v -> Typecheck.lub(typeAtPath(myVPath), other.typeAtPath(otherVPath), schema))
|
||||
} else {
|
||||
Some(v -> PathVar(myVPath))
|
||||
}
|
||||
case (Some(t:Type), None) =>
|
||||
Some(v -> t) // optional fields get combined later
|
||||
case (Some(PathVar(path)), None) =>
|
||||
Some(v -> typeAtPath(path)) // optional fields get combined later
|
||||
case (None, Some(t:Type)) =>
|
||||
Some(v -> t) // optional fields get combined later
|
||||
case (None, Some(PathVar(path))) =>
|
||||
Some(v -> other.typeAtPath(path)) // optional fields get combined later
|
||||
case (None, None) =>
|
||||
None
|
||||
}
|
||||
}.toMap,
|
||||
pathTypes =
|
||||
(pathTypes.keySet ++ other.pathTypes.keySet).toSeq.flatMap { p =>
|
||||
(pathTypes.get(p), other.pathTypes.get(p)) match {
|
||||
case (Some(t), Some(ot)) =>
|
||||
Typecheck.lubOption(t, ot, schema).map { p -> _ }
|
||||
case (_, None) =>
|
||||
None
|
||||
case (None, _) =>
|
||||
None
|
||||
}
|
||||
}.toMap,
|
||||
optionalBindings =
|
||||
optionalBindings ++ other.optionalBindings ++
|
||||
((scope.keySet ++ other.scope.keySet)
|
||||
.filter { v =>
|
||||
!(scope contains v) || !(other.scope contains v)
|
||||
}
|
||||
),
|
||||
schema = schema
|
||||
)
|
||||
}
|
||||
def narrowToPrefix(prefix: Seq[Int]): State =
|
||||
copy(
|
||||
scope =
|
||||
scope.mapValues {
|
||||
case t:Type => t
|
||||
case PathVar(p) =>
|
||||
(
|
||||
if(p startsWith prefix) { PathVar(p.drop(prefix.length)) }
|
||||
else { typeAtPath(p) }
|
||||
):(PathVar|Type)
|
||||
}.toMap[String, PathVar|Type],
|
||||
pathTypes =
|
||||
pathTypes.filterKeys { _ startsWith prefix }
|
||||
.map { case (p, v) => p.drop(prefix.length) -> v }
|
||||
.toMap
|
||||
)
|
||||
|
||||
def expandPrefix(prefix: Seq[Int], target: State): State =
|
||||
{
|
||||
val cmp = narrowToPrefix(prefix)
|
||||
copy(
|
||||
scope =
|
||||
scope ++ ((
|
||||
// include only modifications to the schema
|
||||
target.scope.toSet -- cmp.scope.toSet
|
||||
).toMap.mapValues {
|
||||
case t:Type => t:(Type|PathVar)
|
||||
case PathVar(p) => PathVar(prefix ++ p):(Type|PathVar)
|
||||
}),
|
||||
pathTypes =
|
||||
pathTypes ++
|
||||
target.pathTypes.map {
|
||||
case (path, t) =>
|
||||
(prefix ++ path) -> t
|
||||
},
|
||||
optionalBindings =
|
||||
target.optionalBindings
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
object State
|
||||
{
|
||||
def empty(schema: Definition) =
|
||||
new State(schema.globals, Map.empty, Set.empty, schema)
|
||||
}
|
||||
|
||||
case class MatchTypecheckError(msg: String, stack: List[Match]) extends Exception
|
||||
{
|
||||
override def getMessage(): String =
|
||||
stack.map { _.toString }.mkString("\n ,--in--^\n")+"\n ,--in--^\n"+msg
|
||||
}
|
||||
|
||||
def apply(pattern: Match, family: String, schema: Definition, scope: Map[String, Type]): Map[String, Type] =
|
||||
apply(pattern, None, Type.AST(family), schema, scope)
|
||||
|
||||
def apply(pattern: Match, target: Option[String], expectedType: Type, schema: Definition, scope: Map[String, Type]): Map[String, Type] =
|
||||
def check(
|
||||
pattern: Match,
|
||||
targetPath: Seq[Int],
|
||||
state: State,
|
||||
): State =
|
||||
{
|
||||
def checkAllChildren() =
|
||||
pattern.children.foreach { apply(_, target, expectedType, schema, scope) }
|
||||
def schema = state.schema
|
||||
|
||||
def astSchema(label: String): Node =
|
||||
{
|
||||
|
@ -27,7 +255,7 @@ object TypecheckMatch
|
|||
}
|
||||
|
||||
|
||||
def astSchemaWithFamily(family: String, label: String): Node =
|
||||
def astSchemaWithFamily(family: Type.AST, label: String): Node =
|
||||
{
|
||||
assert(schema.nodes contains family, s"AST Family '$family' is not defined")
|
||||
val astSchema = schema.nodes(family)
|
||||
|
@ -36,54 +264,44 @@ object TypecheckMatch
|
|||
node.get
|
||||
}
|
||||
|
||||
def typeOfCurrentNode: Type =
|
||||
state.typeAtPath(targetPath)
|
||||
|
||||
|
||||
try {
|
||||
pattern match {
|
||||
case Match.Not(child) => apply(child, target, expectedType, schema, scope)
|
||||
case Match.Not(child) =>
|
||||
check(child, targetPath, state)
|
||||
|
||||
case Match.And(children) =>
|
||||
children.foldLeft(scope) { (scope, child) =>
|
||||
apply(child, target, expectedType, schema, scope)
|
||||
children.foldLeft(state) { (state, child) =>
|
||||
check(child, targetPath, state)
|
||||
}
|
||||
|
||||
case Match.Or(children) =>
|
||||
children.flatMap { apply(_, target, expectedType, schema, scope).toSeq }
|
||||
.groupBy { _._1 }
|
||||
.mapValues { childTypes =>
|
||||
val base =
|
||||
Type.union(childTypes.map { _._2 })
|
||||
if(childTypes.size == children.size){ base }
|
||||
else { Type.Option(base) }
|
||||
}
|
||||
.toMap
|
||||
case Match.Or(Seq()) =>
|
||||
state
|
||||
|
||||
case Match.Or(children) =>
|
||||
children.tail
|
||||
.foldLeft(check(children.head, targetPath, state)) {
|
||||
(stateAccum, child) =>
|
||||
stateAccum or check(child, targetPath, state)
|
||||
}
|
||||
|
||||
case Match.OfType(t) =>
|
||||
assert(Typecheck.escalatesTo(t, expectedType, schema),
|
||||
s"Matching a node by type; Type $t can not possibly matched by an element of type $expectedType: $pattern"
|
||||
assert(Typecheck.escalatesTo(t, typeOfCurrentNode, schema),
|
||||
s"Matching a node by type; Type $t can not possibly matched by an element of type $typeOfCurrentNode: @[${targetPath.mkString(", ")}]: $pattern"
|
||||
)
|
||||
scope ++ target.map { _ -> t }.toMap
|
||||
state.bindPath(targetPath, t)
|
||||
|
||||
case Match.Exact(child) =>
|
||||
assert(Typecheck.escalatesTo(child.t, expectedType, schema),
|
||||
s"Matching a node by value; Value $child can not possibly match an element of type $expectedType: $pattern"
|
||||
assert(Typecheck.escalatesTo(child.t, typeOfCurrentNode, schema),
|
||||
s"Matching a node by value; Value $child can not possibly match an element of type $typeOfCurrentNode: @[${targetPath.mkString(", ")}]: $pattern"
|
||||
)
|
||||
scope
|
||||
state
|
||||
|
||||
case Match.Path(path, child) =>
|
||||
val targetType =
|
||||
path.foldLeft(expectedType) { (t, field) =>
|
||||
t match {
|
||||
case Type.Node(label) =>
|
||||
val nodeSchema = astSchema(label)
|
||||
assert(nodeSchema.fields.size > field, s"Can't path into the $field'th element of a $t: $pattern")
|
||||
nodeSchema.fields(field).t
|
||||
case Type.Array(base) =>
|
||||
base
|
||||
case _ =>
|
||||
assert(false, s"Can't path into a simple type: $t: $pattern")
|
||||
}
|
||||
}
|
||||
apply(child, target, targetType, schema, scope)
|
||||
check(child, targetPath ++ path, state)
|
||||
|
||||
case Match.Recursive(_) =>
|
||||
// not entirely sure how to typecheck this...
|
||||
|
@ -91,56 +309,63 @@ object TypecheckMatch
|
|||
???
|
||||
|
||||
case Match.Bind(symbol, child) =>
|
||||
apply(child, Some(symbol), expectedType, schema, scope ++ Map(symbol -> expectedType))
|
||||
check(child, targetPath, state.bindVarToPath(symbol, targetPath))
|
||||
|
||||
case Match.Any =>
|
||||
scope
|
||||
state
|
||||
|
||||
case Match.Fail =>
|
||||
Map.empty
|
||||
State.empty(schema)
|
||||
|
||||
case Match.Lookup(symbol) =>
|
||||
assert(scope contains symbol, s"$symbol is not bound: $pattern")
|
||||
assert(Typecheck.escalatesTo(scope(symbol), expectedType, schema), s"$symbol (of type ${scope(symbol)}) can not possibly match an element of type $expectedType")
|
||||
scope
|
||||
assert(Typecheck.escalatesTo(
|
||||
state.typeOfVar(symbol),
|
||||
typeOfCurrentNode,
|
||||
schema
|
||||
), s"$symbol (of type ${state.typeOfVar(symbol)}) can not possibly match an element of type $typeOfCurrentNode")
|
||||
state
|
||||
|
||||
case Match.ApplyToScope(symbol, child) =>
|
||||
assert(scope contains symbol, s"$symbol is not bound: $pattern")
|
||||
apply(child, target, scope(symbol), schema, scope)
|
||||
state.expandPrefix(
|
||||
state.pathOfVar(symbol),
|
||||
check(
|
||||
child,
|
||||
Seq.empty,
|
||||
state.narrowToPrefix(targetPath)
|
||||
)
|
||||
)
|
||||
|
||||
case Match.Forall(child) =>
|
||||
expectedType match {
|
||||
case Type.Array(base) =>
|
||||
apply(child, target, base, schema, scope)
|
||||
case _ =>
|
||||
assert(false, s"Forall match on something other than an array: $pattern")
|
||||
}
|
||||
check(child, targetPath :+ 0, state)
|
||||
|
||||
case Match.Test(op) =>
|
||||
assert(Typecheck.escalatesTo(TypecheckExpression(op, schema, scope), Type.Bool, schema),
|
||||
assert(Typecheck.escalatesTo(TypecheckExpression(op, schema, state.asScope), Type.Bool, schema),
|
||||
s"Non-boolean test expression: $op"
|
||||
)
|
||||
scope
|
||||
state
|
||||
|
||||
case Match.BindExpression(symbol, op) =>
|
||||
scope ++ Map(
|
||||
symbol -> TypecheckExpression(op, schema, scope)
|
||||
)
|
||||
val exprType = TypecheckExpression(op, schema, state.asScope)
|
||||
state.bindVarToType(symbol, exprType)
|
||||
|
||||
case Match.Node(label, fields) =>
|
||||
assert(expectedType.isInstanceOf[Type.AST], s"Matching a node when $expectedType was expected: $pattern")
|
||||
val family = expectedType.asInstanceOf[Type.AST].family
|
||||
val nodeSchema = astSchemaWithFamily(family, label)
|
||||
assert(nodeSchema.fields.size == fields.size, s"Node type $family.$label expects ${nodeSchema.fields.size} but pattern has ${fields.size}: $pattern")
|
||||
var runningScope = scope
|
||||
for( (Field(fieldName, fieldType), fieldMatch) <- nodeSchema.fields.zip(fields) )
|
||||
{
|
||||
runningScope = apply(fieldMatch, target, fieldType, schema, runningScope)
|
||||
assert(
|
||||
Typecheck.escalatesTo(
|
||||
Type.Node(label),
|
||||
typeOfCurrentNode,
|
||||
schema
|
||||
),
|
||||
s"Can't possibly match a Node[$label] with an item of type $typeOfCurrentNode: $pattern"
|
||||
)
|
||||
fields.zipWithIndex.foldLeft(
|
||||
state.bindPath(targetPath, Type.Node(label))
|
||||
) { case (state, (fieldPattern, fieldIdx)) =>
|
||||
check(fieldPattern, targetPath :+ fieldIdx, state)
|
||||
}
|
||||
runningScope ++ target.map { _ -> Type.Node(label) }.toMap
|
||||
}
|
||||
} catch {
|
||||
case e: AssertionError =>
|
||||
// e.printStackTrace()
|
||||
throw new MatchTypecheckError(e.getMessage(), pattern::Nil)
|
||||
case TypecheckExpression.ExpressionTypecheckError(msg, stack) =>
|
||||
throw new MatchTypecheckError(s"${stack.head}\n ,--in--^\n$msg", pattern::Nil)
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
@import com.astraldb.spec.Definition
|
||||
@import com.astraldb.spec.Type
|
||||
@import com.astraldb.bdd.BDD
|
||||
@import com.astraldb.codegen
|
||||
@import com.astraldb.codegen.Code
|
||||
|
||||
@(schema: Definition, family: Type.AST, bdd: BDD)
|
||||
|
||||
object @{family.scalaType}BDD extends Rule[@{family.scalaType}]
|
||||
{
|
||||
def apply(plan: @{family.scalaType}): @{family.scalaType} =
|
||||
{
|
||||
@{codegen.BDD(schema, family, bdd,
|
||||
onFail = Code.Literal("plan"),
|
||||
root = Code.Literal("plan")
|
||||
)
|
||||
.toString(indent = 4)
|
||||
.stripPrefix(" ")}
|
||||
}
|
||||
}
|
|
@ -9,15 +9,14 @@
|
|||
|
||||
@(schema: Definition, rule: Rule)
|
||||
|
||||
object @{rule.safeLabel} extends Rule[LogicalPlan]
|
||||
object @{rule.safeLabel} extends Rule[@{rule.family.scalaType}]
|
||||
{
|
||||
def apply(plan: LogicalPlan): LogicalPlan =
|
||||
def apply(plan: @{rule.family.scalaType}): @{rule.family.scalaType} =
|
||||
{
|
||||
@{
|
||||
val matchSchema = TypecheckMatch(
|
||||
rule.pattern,
|
||||
Some("plan"),
|
||||
Type.AST(rule.family),
|
||||
rule.family,
|
||||
schema,
|
||||
schema.globals
|
||||
)
|
||||
|
@ -26,7 +25,8 @@ object @{rule.safeLabel} extends Rule[LogicalPlan]
|
|||
schema = schema,
|
||||
pattern = rule.pattern,
|
||||
target = Code.Literal("plan"),
|
||||
targetType = Type.AST(rule.family),
|
||||
targetPath = Seq.empty,
|
||||
targetType = rule.family,
|
||||
onSuccess = Expression(schema, rule.rewrite, _),
|
||||
onFail = { _ => Code.Literal("plan") },
|
||||
name = Some("plan"),
|
||||
|
|
Loading…
Reference in New Issue