Merge branch 'main' of git.odin.cse.buffalo.edu:Astral/astral-compiler
This commit is contained in:
commit
9e38df1761
2
TODOs.md
2
TODOs.md
|
@ -54,7 +54,7 @@
|
||||||
- [ ] CombineConcats,
|
- [ ] CombineConcats,
|
||||||
- [ ] PushdownPredicatesAndPruneColumnsForCTEDef
|
- [ ] PushdownPredicatesAndPruneColumnsForCTEDef
|
||||||
- [ ] Codegen: Output a Spark-like optimizer based on the above rules
|
- [ ] Codegen: Output a Spark-like optimizer based on the above rules
|
||||||
- [ ] Logic generation
|
- [x] Logic generation
|
||||||
- [ ] Compilation pipeline to streamline testing
|
- [ ] Compilation pipeline to streamline testing
|
||||||
- [ ] Test: Do these rules generate the same results as spark?
|
- [ ] Test: Do these rules generate the same results as spark?
|
||||||
- [ ] Apply the rule merge optimization
|
- [ ] Apply the rule merge optimization
|
||||||
|
|
|
@ -0,0 +1,9 @@
|
||||||
|
package com.astraldb.catalyst
|
||||||
|
|
||||||
|
object OptimizerTest
|
||||||
|
{
|
||||||
|
def main(args: Array[String]): Unit =
|
||||||
|
{
|
||||||
|
println("Hello world")
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,5 +1,7 @@
|
||||||
package com.astraldb.catalyst
|
package com.astraldb.catalyst
|
||||||
|
|
||||||
|
import com.astraldb.codegen.Render
|
||||||
|
|
||||||
object Astral
|
object Astral
|
||||||
{
|
{
|
||||||
def main(args: Array[String]): Unit =
|
def main(args: Array[String]): Unit =
|
||||||
|
@ -12,6 +14,7 @@ object Astral
|
||||||
System.exit(-1)
|
System.exit(-1)
|
||||||
|
|
||||||
}
|
}
|
||||||
println(Catalyst.definition)
|
// println(Catalyst.definition)
|
||||||
|
Render.print(Catalyst.definition)
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -30,6 +30,52 @@ object Catalyst extends HardcodedDefinition
|
||||||
|
|
||||||
//////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
Function("PushProjectionThroughUnion.projectListIsDeterministic",
|
||||||
|
Type.Bool,
|
||||||
|
)(
|
||||||
|
Type.Any
|
||||||
|
)
|
||||||
|
|
||||||
|
Function("PushProjectionThroughUnion.unionHasChildren",
|
||||||
|
Type.Bool,
|
||||||
|
)(
|
||||||
|
Type.Any
|
||||||
|
)
|
||||||
|
|
||||||
|
Function("PushProjectionThroughUnion.pushProjectionThroughUnion",
|
||||||
|
Type.Array(Type.AST("LogicalPlan"))
|
||||||
|
)(
|
||||||
|
Type.Any,
|
||||||
|
Type.Any
|
||||||
|
)
|
||||||
|
|
||||||
|
Function("ExtractFiltersAndInnerJoins",
|
||||||
|
Type.Struct(Map(
|
||||||
|
"_1" -> Type.Struct(Map("size" -> Type.Int)),
|
||||||
|
"_2" -> Type.Struct(Map("nonEmpty" -> Type.Bool)),
|
||||||
|
))
|
||||||
|
)(
|
||||||
|
Type.AST("LogicalPlan")
|
||||||
|
)
|
||||||
|
|
||||||
|
Function("ReorderJoin.rewrite", Type.Any)(
|
||||||
|
Type.AST("LogicalPlan")
|
||||||
|
)
|
||||||
|
|
||||||
|
Function("EliminateOuterJoin.buildNewJoinType",
|
||||||
|
Type.Union(Seq(
|
||||||
|
Type.Native("RightOuter"),
|
||||||
|
Type.Native("LeftOuter"),
|
||||||
|
Type.Native("FullOuter"),
|
||||||
|
))
|
||||||
|
)(
|
||||||
|
Type.AST("LogicalPlan")
|
||||||
|
)
|
||||||
|
|
||||||
|
Global("JoinHint.NONE", Type.Native("JoinHint"))
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////
|
||||||
|
|
||||||
Rule("PushProjectionThroughUnion", "LogicalPlan")(
|
Rule("PushProjectionThroughUnion", "LogicalPlan")(
|
||||||
Match("Project")(
|
Match("Project")(
|
||||||
Bind("projectList", MatchAny),
|
Bind("projectList", MatchAny),
|
||||||
|
|
9
astral/catalyst/src/com/astraldb/catalyst/Generate.scala
Normal file
9
astral/catalyst/src/com/astraldb/catalyst/Generate.scala
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
package com.astraldb.catalyst
|
||||||
|
|
||||||
|
object Generate
|
||||||
|
{
|
||||||
|
def main(args: Array[String]): Unit =
|
||||||
|
{
|
||||||
|
println("Optimizer.sql")
|
||||||
|
}
|
||||||
|
}
|
276
astral/compiler/src/com/astraldb/codegen/Code.scala
Normal file
276
astral/compiler/src/com/astraldb/codegen/Code.scala
Normal file
|
@ -0,0 +1,276 @@
|
||||||
|
package com.astraldb.codegen
|
||||||
|
|
||||||
|
|
||||||
|
sealed trait Code
|
||||||
|
{
|
||||||
|
def render(indent: Int): Code.PaddedString|Code.Lines
|
||||||
|
|
||||||
|
def renderLines(indent: Int): Code.Lines =
|
||||||
|
render(indent) match {
|
||||||
|
case x:Code.PaddedString =>
|
||||||
|
Code.Lines.ofLine(indent, x)
|
||||||
|
case x:Code.Lines => x
|
||||||
|
}
|
||||||
|
|
||||||
|
def block =
|
||||||
|
this match {
|
||||||
|
case Code.Block(lines) => lines
|
||||||
|
case _ => Seq(this)
|
||||||
|
}
|
||||||
|
|
||||||
|
override def toString(): String =
|
||||||
|
renderLines(0).body.mkString("\n")
|
||||||
|
def toString(indent: Int): String =
|
||||||
|
renderLines(indent).body.mkString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
object Code
|
||||||
|
{
|
||||||
|
class PaddedString(val body: String, val lPad: Int = 0, val rPad: Int = 0)
|
||||||
|
|
||||||
|
object PaddedString
|
||||||
|
{
|
||||||
|
def apply(elems: (String|PaddedString)*): PaddedString =
|
||||||
|
{
|
||||||
|
if(elems.isEmpty){ new PaddedString("") }
|
||||||
|
else {
|
||||||
|
val (buffer, lPad, rPad) =
|
||||||
|
elems.head match {
|
||||||
|
case s: String =>
|
||||||
|
(StringBuffer(s), 0, 0)
|
||||||
|
case p: PaddedString =>
|
||||||
|
(StringBuffer(p.body), p.lPad, p.rPad)
|
||||||
|
}
|
||||||
|
var pad = rPad
|
||||||
|
for(e <- elems.tail)
|
||||||
|
{
|
||||||
|
if(pad > 0){ buffer.append(" " * pad) }
|
||||||
|
e match {
|
||||||
|
case s:String =>
|
||||||
|
buffer.append(s)
|
||||||
|
pad = 0
|
||||||
|
case p:PaddedString =>
|
||||||
|
buffer.append(" "*Math.max(0, p.lPad - pad))
|
||||||
|
buffer.append(p.body)
|
||||||
|
pad = p.rPad
|
||||||
|
}
|
||||||
|
}
|
||||||
|
new PaddedString(buffer.toString, lPad, pad)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def leftPad(pad: Int, body: String): PaddedString =
|
||||||
|
new PaddedString(body, lPad = pad)
|
||||||
|
def rightPad(pad: Int, body: String): PaddedString =
|
||||||
|
new PaddedString(body, rPad = pad)
|
||||||
|
def pad(pad: Int, body: String): PaddedString =
|
||||||
|
new PaddedString(body, lPad = pad, rPad = pad)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
case class Lines(body: Seq[String])
|
||||||
|
{
|
||||||
|
def ++(other: Lines): Lines =
|
||||||
|
new Lines(body ++ other.body)
|
||||||
|
}
|
||||||
|
object Lines
|
||||||
|
{
|
||||||
|
def ofLine(indent: Int, elems: (String|PaddedString)*): Lines =
|
||||||
|
{
|
||||||
|
val p = PaddedString(elems:_*)
|
||||||
|
new Lines(Seq(
|
||||||
|
" " * Math.max(indent, p.lPad) + p.body
|
||||||
|
))
|
||||||
|
}
|
||||||
|
def ofLines(elems: (String|Lines)*): Lines =
|
||||||
|
new Lines(
|
||||||
|
elems.flatMap {
|
||||||
|
case s: String => Seq(s)
|
||||||
|
case Lines(l) => l
|
||||||
|
}
|
||||||
|
)
|
||||||
|
val empty = new Lines(Seq.empty)
|
||||||
|
def flatten(elems: Seq[Lines]): Lines =
|
||||||
|
elems.foldLeft(empty){ _ ++ _ }
|
||||||
|
|
||||||
|
def apply(body: Seq[String]): Lines =
|
||||||
|
new Lines(body)
|
||||||
|
}
|
||||||
|
|
||||||
|
case object Empty extends Code
|
||||||
|
{
|
||||||
|
def render(indent: Int): PaddedString|Lines =
|
||||||
|
PaddedString("")
|
||||||
|
}
|
||||||
|
|
||||||
|
case class Block(
|
||||||
|
lines: Seq[Code],
|
||||||
|
) extends Code
|
||||||
|
{
|
||||||
|
def render(indent: Int): PaddedString|Lines =
|
||||||
|
Lines.flatten(
|
||||||
|
lines.map { _.renderLines(indent) }
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
case class Literal(txt: String) extends Code
|
||||||
|
{
|
||||||
|
def render(indent: Int): PaddedString|Lines =
|
||||||
|
PaddedString(txt)
|
||||||
|
}
|
||||||
|
|
||||||
|
case class BinOp(left: Code, op: String|PaddedString, right: Code) extends Code
|
||||||
|
{
|
||||||
|
def render(indent: Int): PaddedString|Lines =
|
||||||
|
(left.render(indent), right.render(indent)) match {
|
||||||
|
case (l:PaddedString, r:PaddedString) =>
|
||||||
|
PaddedString(l, op, r)
|
||||||
|
case (l:PaddedString, r:Lines) =>
|
||||||
|
Lines.ofLine(indent, l, op) ++ r
|
||||||
|
case (l:Lines, r:PaddedString) =>
|
||||||
|
l ++ Lines.ofLine(indent, op, r)
|
||||||
|
case (l:Lines, r:Lines) =>
|
||||||
|
l ++ Lines.ofLine(indent+2, op) ++ r
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case class Parens(left: String|PaddedString, body: Code, right: String|PaddedString) extends Code
|
||||||
|
{
|
||||||
|
def render(indent: Int): PaddedString|Lines =
|
||||||
|
body.render(indent+2) match {
|
||||||
|
case b:PaddedString =>
|
||||||
|
PaddedString(left, b, right)
|
||||||
|
case b:Lines =>
|
||||||
|
Lines.ofLine(indent, left) ++
|
||||||
|
b ++
|
||||||
|
Lines.ofLine(indent, right)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A "block" with a prefix
|
||||||
|
*
|
||||||
|
* Used to represent if-then, while, for, etc... blocks
|
||||||
|
*/
|
||||||
|
case class PrefixedBlock(prefix: Code, lParen: String|PaddedString, rParen: String|PaddedString, body: Code) extends Code
|
||||||
|
{
|
||||||
|
def render(indent: Int): PaddedString|Lines =
|
||||||
|
body.render(indent+2) match {
|
||||||
|
case b:PaddedString => prefix.render(indent) match {
|
||||||
|
case p:PaddedString =>
|
||||||
|
PaddedString(p, lParen, b, rParen)
|
||||||
|
case p:Lines =>
|
||||||
|
p ++ Lines.ofLine(indent, lParen, b, rParen)
|
||||||
|
}
|
||||||
|
case b:Lines =>
|
||||||
|
prefix.renderLines(indent) ++
|
||||||
|
Lines.ofLine(indent, lParen) ++
|
||||||
|
b ++
|
||||||
|
Lines.ofLine(indent, rParen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Two "block"s with a prefix and a joiner
|
||||||
|
*
|
||||||
|
* Used to represent if-then-else
|
||||||
|
*/
|
||||||
|
case class PrefixedBlockPair(prefix: Code, lParen: String|PaddedString, joiner: String|PaddedString, rParen: String|PaddedString, lBody: Code, rBody: Code) extends Code
|
||||||
|
{
|
||||||
|
def render(indent: Int): PaddedString|Lines =
|
||||||
|
{
|
||||||
|
val p = prefix.render(indent)
|
||||||
|
val l = lBody.render(indent+2)
|
||||||
|
val r = rBody.render(indent+2)
|
||||||
|
|
||||||
|
val leftNeedsTrailingRParen =
|
||||||
|
p.isInstanceOf[Lines] || l.isInstanceOf[Lines]
|
||||||
|
|
||||||
|
(p match {
|
||||||
|
case pp: PaddedString =>
|
||||||
|
l match {
|
||||||
|
case lp: PaddedString =>
|
||||||
|
Lines.ofLine(indent, pp, lParen, lp, rParen)
|
||||||
|
case ll: Lines =>
|
||||||
|
Lines.ofLine(indent, pp, lParen) ++
|
||||||
|
ll
|
||||||
|
}
|
||||||
|
case pl: Lines =>
|
||||||
|
l match {
|
||||||
|
case lp: PaddedString =>
|
||||||
|
pl ++
|
||||||
|
Lines.ofLine(indent, lParen) ++
|
||||||
|
Lines.ofLine(indent, lp)
|
||||||
|
case ll: Lines =>
|
||||||
|
pl ++
|
||||||
|
Lines.ofLine(indent, lParen) ++
|
||||||
|
ll
|
||||||
|
}
|
||||||
|
}) ++ ((r, leftNeedsTrailingRParen) match {
|
||||||
|
case (rp: PaddedString, false) =>
|
||||||
|
Lines.ofLine(indent, joiner, lParen, rp, rParen)
|
||||||
|
case (rp: PaddedString, true) =>
|
||||||
|
Lines.ofLine(indent, rParen, joiner, lParen, rp, rParen)
|
||||||
|
case (rp: Lines, false) =>
|
||||||
|
Lines.ofLine(indent, joiner, lParen) ++
|
||||||
|
rp ++
|
||||||
|
Lines.ofLine(indent, rParen)
|
||||||
|
case (rp: Lines, true) =>
|
||||||
|
Lines.ofLine(indent, rParen, joiner, lParen) ++
|
||||||
|
rp ++
|
||||||
|
Lines.ofLine(indent, rParen)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def IfThenElse(condition: Code, thenBlock: Code, elseBlock: Code): Code =
|
||||||
|
Code.PrefixedBlockPair(
|
||||||
|
prefix = Code.Parens("if(",condition,")"),
|
||||||
|
lParen = PaddedString.rightPad(1, "{"),
|
||||||
|
rParen = PaddedString.leftPad(1, "}"),
|
||||||
|
joiner = PaddedString.pad(1, "else"),
|
||||||
|
lBody = thenBlock,
|
||||||
|
rBody = elseBlock
|
||||||
|
)
|
||||||
|
|
||||||
|
case class List(sep: String|PaddedString, elems: Seq[Code]) extends Code
|
||||||
|
{
|
||||||
|
def render(indent: Int): PaddedString|Lines =
|
||||||
|
{
|
||||||
|
val renderedElems = elems.map { _.render(indent+2) }
|
||||||
|
val multiLine =
|
||||||
|
renderedElems.exists { _.isInstanceOf[Lines] }
|
||||||
|
|| renderedElems.map { _.asInstanceOf[PaddedString].body.size }.sum > 100
|
||||||
|
if(multiLine){
|
||||||
|
Lines(
|
||||||
|
renderedElems.flatMap {
|
||||||
|
case p:PaddedString =>
|
||||||
|
Lines.ofLine(indent, p, sep).body
|
||||||
|
case Lines(Seq()) =>
|
||||||
|
Lines.ofLine(indent, sep).body
|
||||||
|
case Lines(Seq(l)) =>
|
||||||
|
Seq(PaddedString(l, sep).body)
|
||||||
|
case Lines(l) =>
|
||||||
|
Seq(l.head) ++
|
||||||
|
l.tail.dropRight(1).map { " " + _ } ++
|
||||||
|
Seq(PaddedString(" ",l.last,sep).body)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
renderedElems.map { _.asInstanceOf[PaddedString] } match {
|
||||||
|
case Seq() => PaddedString("")
|
||||||
|
case Seq(a) => a
|
||||||
|
case a =>
|
||||||
|
PaddedString( (
|
||||||
|
Seq(a.head) ++ a.tail.flatMap { Seq(sep, _) }
|
||||||
|
).toSeq:_*)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def Parenthesize(body: Code): Code =
|
||||||
|
Parens("(", body, ")")
|
||||||
|
}
|
64
astral/compiler/src/com/astraldb/codegen/Expression.scala
Normal file
64
astral/compiler/src/com/astraldb/codegen/Expression.scala
Normal file
|
@ -0,0 +1,64 @@
|
||||||
|
package com.astraldb.codegen
|
||||||
|
|
||||||
|
import com.astraldb.spec
|
||||||
|
import com.astraldb.expression._
|
||||||
|
import com.astraldb.expression.{ Expression => Expr}
|
||||||
|
import com.astraldb.codegen.Code.PaddedString
|
||||||
|
import com.astraldb.typecheck.TypecheckExpression
|
||||||
|
import com.astraldb.spec.Type
|
||||||
|
|
||||||
|
object Expression
|
||||||
|
{
|
||||||
|
def apply(schema: spec.Definition, op: Expr, scope: Map[String, spec.Type]): Code =
|
||||||
|
{
|
||||||
|
op match {
|
||||||
|
case c:SimpleConstant => Code.Literal(c.asScala)
|
||||||
|
case Var(v) => Code.Literal(v)
|
||||||
|
case Arith(t, lhs, rhs) =>
|
||||||
|
Code.BinOp(
|
||||||
|
Code.Parenthesize(apply(schema, lhs, scope)),
|
||||||
|
PaddedString.pad(1, ArithTypes.opString(t)),
|
||||||
|
Code.Parenthesize(apply(schema, rhs, scope))
|
||||||
|
)
|
||||||
|
case Cmp(t, lhs, rhs) =>
|
||||||
|
Code.BinOp(
|
||||||
|
Code.Parenthesize(apply(schema, lhs, scope)),
|
||||||
|
PaddedString.pad(1, CmpTypes.opString(t)),
|
||||||
|
Code.Parenthesize(apply(schema, rhs, scope))
|
||||||
|
)
|
||||||
|
case FunctionCall(fn, args) =>
|
||||||
|
Code.Literal(op.toString)
|
||||||
|
Code.PrefixedBlock(
|
||||||
|
prefix = Code.Literal(fn),
|
||||||
|
lParen = "(",
|
||||||
|
rParen = ")",
|
||||||
|
body = Code.List(
|
||||||
|
sep = PaddedString.rightPad(1, ","),
|
||||||
|
elems = args.map { apply(schema, _, scope) }
|
||||||
|
)
|
||||||
|
)
|
||||||
|
case MakeNode(node, fields) =>
|
||||||
|
apply(schema, FunctionCall(node, fields), scope)
|
||||||
|
case NodeSubscript(target, index) =>
|
||||||
|
Code.BinOp(
|
||||||
|
apply(schema, target, scope),
|
||||||
|
".",
|
||||||
|
TypecheckExpression(target, schema, scope) match {
|
||||||
|
case Type.Node(nodeType) =>
|
||||||
|
Code.Literal(
|
||||||
|
schema.nodesByName(nodeType).fields(index).name
|
||||||
|
)
|
||||||
|
case c =>
|
||||||
|
assert(false, s"Node subscript on something not a node: $op (in ${scope}): ${c.getClass.getSimpleName}")
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
case StructSubscript(target, name) =>
|
||||||
|
Code.BinOp(
|
||||||
|
apply(schema, target, scope),
|
||||||
|
".",
|
||||||
|
Code.Literal(name)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
140
astral/compiler/src/com/astraldb/codegen/Match.scala
Normal file
140
astral/compiler/src/com/astraldb/codegen/Match.scala
Normal file
|
@ -0,0 +1,140 @@
|
||||||
|
package com.astraldb.codegen
|
||||||
|
|
||||||
|
import com.astraldb.spec
|
||||||
|
|
||||||
|
import spec.Match._
|
||||||
|
import com.astraldb.codegen.Code.PaddedString
|
||||||
|
import com.astraldb.typecheck.TypecheckMatch
|
||||||
|
import com.astraldb.spec.Type
|
||||||
|
|
||||||
|
object Match
|
||||||
|
{
|
||||||
|
def apply(
|
||||||
|
schema: spec.Definition,
|
||||||
|
pattern: spec.Match,
|
||||||
|
target: Code,
|
||||||
|
targetType: spec.Type,
|
||||||
|
onSuccess: Code,
|
||||||
|
onFail: Code,
|
||||||
|
name: Option[String],
|
||||||
|
scope: Map[String, spec.Type]
|
||||||
|
): Code =
|
||||||
|
{
|
||||||
|
pattern match {
|
||||||
|
case And(Seq()) => onSuccess
|
||||||
|
case And(Seq(a)) => apply(schema, a, target, targetType, onSuccess, onFail, name, scope)
|
||||||
|
case And(a) =>
|
||||||
|
apply(schema, a.head,
|
||||||
|
target = target,
|
||||||
|
targetType = targetType,
|
||||||
|
onSuccess =
|
||||||
|
apply(schema, And(a.tail), target, targetType, onSuccess, onFail, name,
|
||||||
|
scope = TypecheckMatch(a.head, name, targetType, schema, scope)
|
||||||
|
),
|
||||||
|
onFail = onFail,
|
||||||
|
name = name,
|
||||||
|
scope = scope
|
||||||
|
)
|
||||||
|
case Not(a) => apply(schema, a, target, targetType, onFail, onSuccess, name, scope)
|
||||||
|
case Or(Seq()) => onFail
|
||||||
|
case Or(Seq(a)) =>
|
||||||
|
apply(schema, a, target, targetType, onSuccess, onFail, name, scope)
|
||||||
|
case Or(a) =>
|
||||||
|
apply(schema, a.head,
|
||||||
|
target = target,
|
||||||
|
targetType = targetType,
|
||||||
|
onSuccess = onSuccess,
|
||||||
|
onFail = apply(schema, Or(a.tail), target, targetType, onSuccess, onFail, name, scope),
|
||||||
|
name = name,
|
||||||
|
scope = scope
|
||||||
|
)
|
||||||
|
case Bind(symbol, pattern) =>
|
||||||
|
Code.Block(
|
||||||
|
Code.BinOp(
|
||||||
|
Code.Literal(s"val $symbol"),
|
||||||
|
Code.PaddedString.pad(1, "="),
|
||||||
|
target
|
||||||
|
) +:
|
||||||
|
apply(
|
||||||
|
schema = schema,
|
||||||
|
pattern = pattern,
|
||||||
|
target = Code.Literal(symbol),
|
||||||
|
targetType = targetType,
|
||||||
|
onSuccess = onSuccess,
|
||||||
|
onFail = onFail,
|
||||||
|
name = Some(symbol),
|
||||||
|
scope = scope ++ Map(symbol -> targetType)
|
||||||
|
).block
|
||||||
|
)
|
||||||
|
case BindExpression(symbol, op) =>
|
||||||
|
Code.Block(
|
||||||
|
Seq(
|
||||||
|
Code.BinOp(
|
||||||
|
Code.Literal(s"val $symbol"),
|
||||||
|
Code.PaddedString.pad(1, "="),
|
||||||
|
Expression(schema, op, scope)
|
||||||
|
)
|
||||||
|
) ++ onSuccess.block
|
||||||
|
)
|
||||||
|
case Node(nodeLabel, children) =>
|
||||||
|
{
|
||||||
|
val selectedName =
|
||||||
|
name.getOrElse { "genericNode" }
|
||||||
|
val node = schema.nodesByName(nodeLabel)
|
||||||
|
Code.IfThenElse(
|
||||||
|
condition =
|
||||||
|
Code.BinOp(
|
||||||
|
target,
|
||||||
|
".",
|
||||||
|
Code.Literal(s"isInstanceOf[$nodeLabel]")
|
||||||
|
),
|
||||||
|
thenBlock =
|
||||||
|
Code.Block(
|
||||||
|
Seq(
|
||||||
|
Code.BinOp(
|
||||||
|
Code.Literal(s"val $selectedName"),
|
||||||
|
Code.PaddedString.pad(1, "="),
|
||||||
|
Code.BinOp(
|
||||||
|
target,
|
||||||
|
".",
|
||||||
|
Code.Literal(s"asInstanceOf[$nodeLabel]")
|
||||||
|
)
|
||||||
|
)
|
||||||
|
) ++
|
||||||
|
children
|
||||||
|
.zip(node.fields)
|
||||||
|
.foldRight(onSuccess) { case ((child, field), andThen) =>
|
||||||
|
apply(
|
||||||
|
schema = schema,
|
||||||
|
pattern = child,
|
||||||
|
target = Code.Literal(s"$selectedName.${field.name}"),
|
||||||
|
targetType = Type.Node(nodeLabel),
|
||||||
|
onSuccess = andThen,
|
||||||
|
onFail = onFail,
|
||||||
|
name = Some(s"${selectedName}_${field.name}"),
|
||||||
|
scope = scope ++ Map(selectedName -> Type.Node(nodeLabel))
|
||||||
|
)
|
||||||
|
}
|
||||||
|
.block
|
||||||
|
),
|
||||||
|
elseBlock = onFail
|
||||||
|
)
|
||||||
|
}
|
||||||
|
case Test(op) =>
|
||||||
|
Code.IfThenElse(
|
||||||
|
condition = Expression(schema, op, scope),
|
||||||
|
thenBlock = onSuccess,
|
||||||
|
elseBlock = onFail
|
||||||
|
)
|
||||||
|
case Any =>
|
||||||
|
onSuccess
|
||||||
|
case OfType(nodeType) =>
|
||||||
|
Code.IfThenElse(
|
||||||
|
condition = Code.BinOp(target, ".", Code.Literal(s"isInstanceOf[${nodeType.scalaType}]")),
|
||||||
|
thenBlock = onSuccess,
|
||||||
|
elseBlock = onFail
|
||||||
|
)
|
||||||
|
case _ => Code.Literal(s"??${pattern.getClass.getSimpleName}??")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
11
astral/compiler/src/com/astraldb/codegen/Optimizer.scala
Normal file
11
astral/compiler/src/com/astraldb/codegen/Optimizer.scala
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
package com.astraldb.codegen
|
||||||
|
|
||||||
|
import com.astraldb.spec.Definition
|
||||||
|
|
||||||
|
object Optimizer
|
||||||
|
{
|
||||||
|
def apply(definition: Definition): String =
|
||||||
|
{
|
||||||
|
scala.Optimizer(definition).toString
|
||||||
|
}
|
||||||
|
}
|
25
astral/compiler/src/com/astraldb/codegen/Render.scala
Normal file
25
astral/compiler/src/com/astraldb/codegen/Render.scala
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
package com.astraldb.codegen
|
||||||
|
|
||||||
|
import com.astraldb.spec.Definition
|
||||||
|
|
||||||
|
object Render
|
||||||
|
{
|
||||||
|
def apply(schema: Definition): Map[String, String] =
|
||||||
|
{
|
||||||
|
Map(
|
||||||
|
"Optimizer.scala" -> Optimizer(schema)
|
||||||
|
) ++ schema.rules.map { rule =>
|
||||||
|
s"${rule.safeLabel}.scala" -> Rule(schema, rule)
|
||||||
|
}.toMap
|
||||||
|
}
|
||||||
|
|
||||||
|
def print(schema: Definition): Unit =
|
||||||
|
{
|
||||||
|
|
||||||
|
for( (file,content) <- apply(schema) )
|
||||||
|
{
|
||||||
|
println(s"//////////////////// $file")
|
||||||
|
println(content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
11
astral/compiler/src/com/astraldb/codegen/Rule.scala
Normal file
11
astral/compiler/src/com/astraldb/codegen/Rule.scala
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
package com.astraldb.codegen
|
||||||
|
|
||||||
|
import com.astraldb.spec
|
||||||
|
|
||||||
|
object Rule
|
||||||
|
{
|
||||||
|
def apply(schema: spec.Definition, rule: spec.Rule): String =
|
||||||
|
{
|
||||||
|
scala.Rule(schema, rule).toString
|
||||||
|
}
|
||||||
|
}
|
|
@ -9,6 +9,7 @@ object Eval
|
||||||
class TypeMismatch(problem: Expression, expr: Expression, lhsType: Type, rhsType: Type) extends EvalException(expr)
|
class TypeMismatch(problem: Expression, expr: Expression, lhsType: Type, rhsType: Type) extends EvalException(expr)
|
||||||
class TypeError(problem: Expression, expr: Expression, problemType: Type) extends EvalException(expr)
|
class TypeError(problem: Expression, expr: Expression, problemType: Type) extends EvalException(expr)
|
||||||
class UnsupportedExternalFeature(problem: Expression, expr: Expression) extends EvalException(expr)
|
class UnsupportedExternalFeature(problem: Expression, expr: Expression) extends EvalException(expr)
|
||||||
|
class MissingFieldError(expr: Expression, field: String) extends EvalException(expr)
|
||||||
|
|
||||||
def apply(base: Expression, scope: Map[String, Constant] = Map.empty): Constant =
|
def apply(base: Expression, scope: Map[String, Constant] = Map.empty): Constant =
|
||||||
{
|
{
|
||||||
|
@ -117,12 +118,24 @@ object Eval
|
||||||
elements(index)
|
elements(index)
|
||||||
case c => throw new TypeError(e, base, c.t)
|
case c => throw new TypeError(e, base, c.t)
|
||||||
}
|
}
|
||||||
|
case StructSubscript(target, name) =>
|
||||||
|
rcr(target, scope) match {
|
||||||
|
case StructConstant(elements) =>
|
||||||
|
elements.find { _._1 == name }
|
||||||
|
.getOrElse {
|
||||||
|
throw new MissingFieldError(e, name)
|
||||||
|
}
|
||||||
|
._2
|
||||||
|
case c => throw new TypeError(e, base, c.t)
|
||||||
|
}
|
||||||
case FunctionCall(name, args) =>
|
case FunctionCall(name, args) =>
|
||||||
throw new UnsupportedExternalFeature(e, base)
|
throw new UnsupportedExternalFeature(e, base)
|
||||||
case MakeNode(nodeType, fields) =>
|
case MakeNode(nodeType, fields) =>
|
||||||
NodeConstant(nodeType, fields.map { rcr(_, scope) })
|
NodeConstant(nodeType, fields.map { rcr(_, scope) })
|
||||||
case MakeArray(elements) =>
|
case MakeArray(elements) =>
|
||||||
ArrayConstant(elements.map { rcr(_, scope) })
|
ArrayConstant(elements.map { rcr(_, scope) })
|
||||||
|
case MakeStruct(elements) =>
|
||||||
|
StructConstant(elements.map { case (f, e) => f -> rcr(e, scope) })
|
||||||
case Let(symbol, value, rest) =>
|
case Let(symbol, value, rest) =>
|
||||||
rcr(rest, scope ++ Map(symbol -> rcr(value, scope)))
|
rcr(rest, scope ++ Map(symbol -> rcr(value, scope)))
|
||||||
}
|
}
|
||||||
|
|
|
@ -109,6 +109,7 @@ sealed trait SimpleConstant extends Constant
|
||||||
{
|
{
|
||||||
def fields: Seq[Constant] = Seq.empty
|
def fields: Seq[Constant] = Seq.empty
|
||||||
def reassembleFields(in: Seq[Constant]): Expression = this
|
def reassembleFields(in: Seq[Constant]): Expression = this
|
||||||
|
def asScala: String = toString
|
||||||
}
|
}
|
||||||
/**
|
/**
|
||||||
* A constant with components (Nodes, Arrays)
|
* A constant with components (Nodes, Arrays)
|
||||||
|
|
|
@ -5,7 +5,8 @@ import com.astraldb.expression._
|
||||||
|
|
||||||
case class Definition(
|
case class Definition(
|
||||||
nodes:Map[String, Map[String, Node]],
|
nodes:Map[String, Map[String, Node]],
|
||||||
rules:Seq[Rule]
|
rules:Seq[Rule],
|
||||||
|
globals: Map[String, Type],
|
||||||
) {
|
) {
|
||||||
override def toString =
|
override def toString =
|
||||||
s"""/////// ASTs //////
|
s"""/////// ASTs //////
|
||||||
|
@ -18,6 +19,8 @@ case class Definition(
|
||||||
nodes.flatMap { case (family, elements) => elements.keys.map { _ -> family } }
|
nodes.flatMap { case (family, elements) => elements.keys.map { _ -> family } }
|
||||||
.toMap
|
.toMap
|
||||||
|
|
||||||
|
val nodesByName: Map[String, Node] =
|
||||||
|
nodes.values.flatten.toMap
|
||||||
|
|
||||||
def validate() =
|
def validate() =
|
||||||
{
|
{
|
||||||
|
@ -30,11 +33,13 @@ class HardcodedDefinition
|
||||||
lazy val definition: Definition =
|
lazy val definition: Definition =
|
||||||
Definition(
|
Definition(
|
||||||
nodes = nodes.mapValues { _.map { n => n.name -> n }.toMap }.toMap,
|
nodes = nodes.mapValues { _.map { n => n.name -> n }.toMap }.toMap,
|
||||||
rules = rules.toSeq
|
rules = rules.toSeq,
|
||||||
|
globals = globals.toMap
|
||||||
)
|
)
|
||||||
|
|
||||||
val nodes = mutable.Map[String, mutable.Buffer[Node]]()
|
val nodes = mutable.Map[String, mutable.Buffer[Node]]()
|
||||||
val rules = mutable.Buffer[Rule]()
|
val rules = mutable.Buffer[Rule]()
|
||||||
|
val globals = mutable.Map[String,Type]()
|
||||||
|
|
||||||
import FieldConversions._
|
import FieldConversions._
|
||||||
|
|
||||||
|
@ -51,6 +56,16 @@ class HardcodedDefinition
|
||||||
com.astraldb.spec.Rule(label, family, pattern, replacement)
|
com.astraldb.spec.Rule(label, family, pattern, replacement)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def Function(label: String, ret: Type = Type.Unit)(args: Type*): Unit =
|
||||||
|
{
|
||||||
|
globals(label) = Type.Function(args,ret)
|
||||||
|
}
|
||||||
|
|
||||||
|
def Global(label: String, t: Type): Unit =
|
||||||
|
{
|
||||||
|
globals(label) = t
|
||||||
|
}
|
||||||
|
|
||||||
//////////////////////// Matchers
|
//////////////////////// Matchers
|
||||||
def Match(node: String)(fields: Match*): Match =
|
def Match(node: String)(fields: Match*): Match =
|
||||||
com.astraldb.spec.Match.Node(node, fields)
|
com.astraldb.spec.Match.Node(node, fields)
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
package com.astraldb.spec
|
package com.astraldb.spec
|
||||||
|
|
||||||
import com.astraldb.ast._
|
|
||||||
import com.astraldb.spec.Match.Scope
|
import com.astraldb.spec.Match.Scope
|
||||||
import com.astraldb.expression._
|
import com.astraldb.expression._
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,11 @@ case class Rule(
|
||||||
def validate(schema: Definition): Unit =
|
def validate(schema: Definition): Unit =
|
||||||
{
|
{
|
||||||
val scope =
|
val scope =
|
||||||
TypecheckMatch(pattern, Type.AST(family), schema, Map.empty)
|
TypecheckMatch(pattern, None, Type.AST(family), schema, schema.globals)
|
||||||
TypecheckExpression(rewrite, schema, scope)
|
TypecheckExpression(rewrite, schema, scope)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def safeLabel =
|
||||||
|
label.replaceAll("[^A-Za-z0-9]+", "_")
|
||||||
|
.replace("^([0-9])","_\\1")
|
||||||
}
|
}
|
|
@ -8,6 +8,8 @@ sealed trait Type {
|
||||||
case Type.Union(of) => of
|
case Type.Union(of) => of
|
||||||
case _ => Seq(this)
|
case _ => Seq(this)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def scalaType: String
|
||||||
}
|
}
|
||||||
|
|
||||||
object Type
|
object Type
|
||||||
|
@ -15,42 +17,52 @@ object Type
|
||||||
case class Native(name: String) extends Type
|
case class Native(name: String) extends Type
|
||||||
{
|
{
|
||||||
override def toString: String = s"Native[${name}]"
|
override def toString: String = s"Native[${name}]"
|
||||||
|
def scalaType: String = name
|
||||||
}
|
}
|
||||||
case class AST(family: String) extends Type
|
case class AST(family: String) extends Type
|
||||||
{
|
{
|
||||||
override def toString: String = s"Ast[${family}]"
|
override def toString: String = s"Ast[${family}]"
|
||||||
|
def scalaType: String = family
|
||||||
}
|
}
|
||||||
case class Node(nodeType: String) extends Type
|
case class Node(nodeType: String) extends Type
|
||||||
{
|
{
|
||||||
override def toString: String = s"Node[${nodeType}]"
|
override def toString: String = s"Node[${nodeType}]"
|
||||||
|
def scalaType: String = nodeType
|
||||||
}
|
}
|
||||||
case class Struct(fields: Map[String, Type]) extends Type
|
case class Struct(fields: Map[String, Type]) extends Type
|
||||||
{
|
{
|
||||||
override def toString: String = s"Struct {${fields.map { x => s"${x._1}: ${x._2}"}.mkString(", ")}}"
|
override def toString: String = s"Struct {${fields.map { x => s"${x._1}: ${x._2}"}.mkString(", ")}}"
|
||||||
|
def scalaType: String = ???
|
||||||
}
|
}
|
||||||
case class Option(base: Type) extends Type
|
case class Option(base: Type) extends Type
|
||||||
{
|
{
|
||||||
override def toString: String = s"Option[$base]"
|
override def toString: String = s"Option[$base]"
|
||||||
|
def scalaType: String = s"Option[${base.scalaType}]"
|
||||||
}
|
}
|
||||||
case class Array(base: Type) extends Type
|
case class Array(base: Type) extends Type
|
||||||
{
|
{
|
||||||
override def toString: String = s"Array[${base}]"
|
override def toString: String = s"Array[${base}]"
|
||||||
|
def scalaType: String = s"Array[${base.scalaType}]"
|
||||||
}
|
}
|
||||||
case class Function(args: Seq[Type], ret: Type) extends Type
|
case class Function(args: Seq[Type], ret: Type) extends Type
|
||||||
{
|
{
|
||||||
override def toString: String = s"(${args.mkString(", ")}) => $ret"
|
override def toString: String = s"(${args.mkString(", ")}) => $ret"
|
||||||
|
def scalaType: String = s"(${args.map { _.scalaType }.mkString(", ")}) => ${ret.scalaType}"
|
||||||
}
|
}
|
||||||
case object Unit extends Type
|
case object Unit extends Type
|
||||||
{
|
{
|
||||||
override def toString: String = "Unit"
|
override def toString: String = "Unit"
|
||||||
|
def scalaType: String = "Unit"
|
||||||
}
|
}
|
||||||
case object Any extends Type
|
case object Any extends Type
|
||||||
{
|
{
|
||||||
override def toString: String = "Any"
|
override def toString: String = "Any"
|
||||||
|
def scalaType: String = "Any"
|
||||||
}
|
}
|
||||||
case class Union(of: Seq[Type]) extends Type
|
case class Union(of: Seq[Type]) extends Type
|
||||||
{
|
{
|
||||||
override def toString: String = s"Union(${of.mkString(" | ")})"
|
override def toString: String = s"Union(${of.mkString(" | ")})"
|
||||||
|
def scalaType: String = of.map { _.scalaType }.mkString("|")
|
||||||
}
|
}
|
||||||
|
|
||||||
sealed trait PrimType extends Type
|
sealed trait PrimType extends Type
|
||||||
|
@ -58,18 +70,22 @@ object Type
|
||||||
case object Int extends PrimType
|
case object Int extends PrimType
|
||||||
{
|
{
|
||||||
override def toString: String = "int"
|
override def toString: String = "int"
|
||||||
|
def scalaType: String = "Int"
|
||||||
}
|
}
|
||||||
case object Float extends PrimType
|
case object Float extends PrimType
|
||||||
{
|
{
|
||||||
override def toString: String = "float"
|
override def toString: String = "float"
|
||||||
|
def scalaType: String = "Double"
|
||||||
}
|
}
|
||||||
case object Bool extends PrimType
|
case object Bool extends PrimType
|
||||||
{
|
{
|
||||||
override def toString: String = "bool"
|
override def toString: String = "bool"
|
||||||
|
def scalaType: String = "Boolean"
|
||||||
}
|
}
|
||||||
case object String extends PrimType
|
case object String extends PrimType
|
||||||
{
|
{
|
||||||
override def toString: String = "string"
|
override def toString: String = "string"
|
||||||
|
def scalaType: String = "String"
|
||||||
}
|
}
|
||||||
|
|
||||||
def union(ts: Seq[Type]) =
|
def union(ts: Seq[Type]) =
|
||||||
|
|
26
astral/compiler/src/com/astraldb/typecheck/Typecheck.scala
Normal file
26
astral/compiler/src/com/astraldb/typecheck/Typecheck.scala
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
package com.astraldb.typecheck
|
||||||
|
|
||||||
|
import com.astraldb.spec._
|
||||||
|
import com.astraldb.expression._
|
||||||
|
|
||||||
|
object Typecheck
|
||||||
|
{
|
||||||
|
|
||||||
|
def escalatesTo(source: Type, target: Type, schema: Definition): Boolean =
|
||||||
|
{
|
||||||
|
(source, target) match {
|
||||||
|
case (a, b) if a == b => true
|
||||||
|
case (Type.Node(label), Type.AST(family)) =>
|
||||||
|
schema.nodes.get(family)
|
||||||
|
.map { _ contains label }
|
||||||
|
.getOrElse { false }
|
||||||
|
case (Type.Union(elems), a) =>
|
||||||
|
elems.forall { escalatesTo(_, a, schema) }
|
||||||
|
case (a, Type.Union(elems)) =>
|
||||||
|
elems.forall { escalatesTo(a, _, schema) }
|
||||||
|
case (Type.Native(_), Type.Native(_)) => true // don't try to encode native inheritance
|
||||||
|
case (_, Type.Any) => true
|
||||||
|
case (_, _) => false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -5,9 +5,123 @@ import com.astraldb.expression._
|
||||||
|
|
||||||
object TypecheckExpression
|
object TypecheckExpression
|
||||||
{
|
{
|
||||||
|
class TypecheckException(expr: Expression) extends Exception
|
||||||
|
class TypeMismatch(expr: Expression, a: Type, b: Type) extends TypecheckException(expr)
|
||||||
|
|
||||||
|
|
||||||
def apply(expr: Expression, schema: Definition, scope: Map[String, Type]): Type =
|
def apply(expr: Expression, schema: Definition, scope: Map[String, Type]): Type =
|
||||||
{
|
{
|
||||||
// TODO
|
// TODO
|
||||||
return Type.Any
|
expr match {
|
||||||
|
case c:Constant => c.t
|
||||||
|
|
||||||
|
case Arith(op, lhs, rhs) =>
|
||||||
|
op match {
|
||||||
|
case ArithTypes.Add | ArithTypes.Mul | ArithTypes.Sub | ArithTypes.Div =>
|
||||||
|
(
|
||||||
|
apply(lhs, schema, scope),
|
||||||
|
apply(rhs, schema, scope),
|
||||||
|
) match {
|
||||||
|
case (Type.Int, Type.Int) => Type.Int
|
||||||
|
case (Type.Float, Type.Float) => Type.Float
|
||||||
|
case (Type.String, Type.String) if op == ArithTypes.Add => Type.String
|
||||||
|
case (a, b) => assert(false, s"Mismatched types $a and $b: $expr")
|
||||||
|
}
|
||||||
|
case ArithTypes.And | ArithTypes.Or =>
|
||||||
|
assert(apply(lhs, schema, scope) == Type.Bool, s"Mismatched types $lhs is not a boolean: $expr")
|
||||||
|
assert(apply(rhs, schema, scope) == Type.Bool, s"Mismatched types $rhs is not a boolean: $expr")
|
||||||
|
Type.Bool
|
||||||
|
}
|
||||||
|
|
||||||
|
case Cmp(op, lhs, rhs) =>
|
||||||
|
op match {
|
||||||
|
case CmpTypes.Eq | CmpTypes.Neq =>
|
||||||
|
val l = apply(lhs, schema, scope)
|
||||||
|
val r = apply(rhs, schema, scope)
|
||||||
|
assert(l == r, s"Comparison between mismatched types: $l and $r: $expr")
|
||||||
|
Type.Bool
|
||||||
|
case _ =>
|
||||||
|
(
|
||||||
|
apply(lhs, schema, scope),
|
||||||
|
apply(rhs, schema, scope),
|
||||||
|
) match {
|
||||||
|
case (Type.Int, Type.Int) => Type.Bool
|
||||||
|
case (Type.Float, Type.Float) => Type.Bool
|
||||||
|
case (a, b) => assert(false, s"Comparison between mismatched types $a and $b: $expr")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case Var(v) => scope(v)
|
||||||
|
|
||||||
|
case FunctionCall(fn, args) =>
|
||||||
|
scope.get(fn) match {
|
||||||
|
case Some(Type.Function(argTypes, ret)) =>
|
||||||
|
for( (arg, argType) <- args.zip(argTypes) )
|
||||||
|
{
|
||||||
|
val actualType = apply(arg, schema, scope)
|
||||||
|
assert(
|
||||||
|
Typecheck.escalatesTo(actualType, argType, schema),
|
||||||
|
s"Mismatched argument $arg ($actualType doesn't escalate to $argType): $expr"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
ret
|
||||||
|
case Some(c) =>
|
||||||
|
assert(false, s"Trying to call $fn, which is a $c: $expr")
|
||||||
|
case None =>
|
||||||
|
assert(false, s"Trying to call $fn, which is not defined (in ${scope.keys.mkString(", ")}): $expr")
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
case MakeNode(nodeType, fields) =>
|
||||||
|
val node = schema.nodesByName(nodeType)
|
||||||
|
for( (fieldConstructor, fieldType) <- fields.zip(node.fields) )
|
||||||
|
{
|
||||||
|
val actualType = apply(fieldConstructor, schema, scope)
|
||||||
|
assert(Typecheck.escalatesTo(actualType, fieldType.t, schema),
|
||||||
|
s"Mismatched node constructor argument $fieldConstructor ($actualType doesn't escalate to ${fieldType.t}): $expr"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
Type.Node(nodeType)
|
||||||
|
|
||||||
|
case NodeSubscript(target, index) =>
|
||||||
|
apply(target, schema, scope) match {
|
||||||
|
case Type.Node(nodeType) =>
|
||||||
|
val node = schema.nodesByName(nodeType)
|
||||||
|
node.fields(index).t
|
||||||
|
case _ =>
|
||||||
|
assert(false, s"Node subscript on something not a node: $expr")
|
||||||
|
}
|
||||||
|
|
||||||
|
case StructSubscript(target, name) =>
|
||||||
|
apply(target, schema, scope) match {
|
||||||
|
case Type.Struct(fields) =>
|
||||||
|
assert(fields contains name,
|
||||||
|
s"Struct subscript on a struct that doesn't contain $name ($fields): $expr"
|
||||||
|
)
|
||||||
|
fields(name)
|
||||||
|
case Type.Node(nodeType) =>
|
||||||
|
val node = schema.nodesByName(nodeType)
|
||||||
|
node.fields.find { _._1 == name }
|
||||||
|
.getOrElse {
|
||||||
|
assert(false,
|
||||||
|
s"Struct subscript on a node that doesn't contain $name (${node.fields}): $expr"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
._2
|
||||||
|
case _ =>
|
||||||
|
assert(false, s"Struct subscript on something not a struct: $expr")
|
||||||
|
}
|
||||||
|
|
||||||
|
case FunctionalIfThenElse(condition, ifTrue, ifFalse) =>
|
||||||
|
assert(apply(condition, schema, scope) == Type.Bool,
|
||||||
|
s"If then else with a non-boolean condition ($condition): $expr"
|
||||||
|
)
|
||||||
|
val a = apply(ifTrue, schema, scope)
|
||||||
|
val b = apply(ifFalse, schema, scope)
|
||||||
|
assert(a == b,
|
||||||
|
s"If then else with mismatched return types $a and $b: $expr"
|
||||||
|
)
|
||||||
|
a
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -6,12 +6,12 @@ import com.astraldb.expression._
|
||||||
object TypecheckMatch
|
object TypecheckMatch
|
||||||
{
|
{
|
||||||
def apply(pattern: Match, family: String, schema: Definition, scope: Map[String, Type]): Map[String, Type] =
|
def apply(pattern: Match, family: String, schema: Definition, scope: Map[String, Type]): Map[String, Type] =
|
||||||
apply(pattern, Type.AST(family), schema, scope)
|
apply(pattern, None, Type.AST(family), schema, scope)
|
||||||
|
|
||||||
def apply(pattern: Match, expectedType: Type, schema: Definition, scope: Map[String, Type]): Map[String, Type] =
|
def apply(pattern: Match, target: Option[String], expectedType: Type, schema: Definition, scope: Map[String, Type]): Map[String, Type] =
|
||||||
{
|
{
|
||||||
def checkAllChildren() =
|
def checkAllChildren() =
|
||||||
pattern.children.foreach { apply(_, expectedType, schema, scope) }
|
pattern.children.foreach { apply(_, target, expectedType, schema, scope) }
|
||||||
|
|
||||||
def astSchema(label: String): Node =
|
def astSchema(label: String): Node =
|
||||||
{
|
{
|
||||||
|
@ -29,15 +29,15 @@ object TypecheckMatch
|
||||||
}
|
}
|
||||||
|
|
||||||
pattern match {
|
pattern match {
|
||||||
case Match.Not(child) => apply(child, expectedType, schema, scope)
|
case Match.Not(child) => apply(child, target, expectedType, schema, scope)
|
||||||
|
|
||||||
case Match.And(children) =>
|
case Match.And(children) =>
|
||||||
children.foldLeft(scope) { (scope, child) =>
|
children.foldLeft(scope) { (scope, child) =>
|
||||||
apply(child, expectedType, schema, scope)
|
apply(child, target, expectedType, schema, scope)
|
||||||
}
|
}
|
||||||
|
|
||||||
case Match.Or(children) =>
|
case Match.Or(children) =>
|
||||||
children.flatMap { apply(_, expectedType, schema, scope).toSeq }
|
children.flatMap { apply(_, target, expectedType, schema, scope).toSeq }
|
||||||
.groupBy { _._1 }
|
.groupBy { _._1 }
|
||||||
.mapValues { childTypes =>
|
.mapValues { childTypes =>
|
||||||
val base =
|
val base =
|
||||||
|
@ -49,13 +49,13 @@ object TypecheckMatch
|
||||||
|
|
||||||
|
|
||||||
case Match.OfType(t) =>
|
case Match.OfType(t) =>
|
||||||
assert(escalatesTo(t, expectedType, schema),
|
assert(Typecheck.escalatesTo(t, expectedType, schema),
|
||||||
s"Matching a node by type; Type $t can not possibly matched by an element of type $expectedType: $pattern"
|
s"Matching a node by type; Type $t can not possibly matched by an element of type $expectedType: $pattern"
|
||||||
)
|
)
|
||||||
scope
|
scope ++ target.map { _ -> t }.toMap
|
||||||
|
|
||||||
case Match.Exact(child) =>
|
case Match.Exact(child) =>
|
||||||
assert(escalatesTo(child.t, expectedType, schema),
|
assert(Typecheck.escalatesTo(child.t, expectedType, schema),
|
||||||
s"Matching a node by value; Value $child can not possibly match an element of type $expectedType: $pattern"
|
s"Matching a node by value; Value $child can not possibly match an element of type $expectedType: $pattern"
|
||||||
)
|
)
|
||||||
scope
|
scope
|
||||||
|
@ -74,7 +74,7 @@ object TypecheckMatch
|
||||||
assert(false, s"Can't path into a simple type: $t: $pattern")
|
assert(false, s"Can't path into a simple type: $t: $pattern")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
apply(child, targetType, schema, scope)
|
apply(child, target, targetType, schema, scope)
|
||||||
|
|
||||||
case Match.Recursive(_) =>
|
case Match.Recursive(_) =>
|
||||||
// not entirely sure how to typecheck this...
|
// not entirely sure how to typecheck this...
|
||||||
|
@ -82,7 +82,7 @@ object TypecheckMatch
|
||||||
???
|
???
|
||||||
|
|
||||||
case Match.Bind(symbol, child) =>
|
case Match.Bind(symbol, child) =>
|
||||||
apply(child, expectedType, schema, scope ++ Map(symbol -> expectedType))
|
apply(child, Some(symbol), expectedType, schema, scope ++ Map(symbol -> expectedType))
|
||||||
|
|
||||||
case Match.Any =>
|
case Match.Any =>
|
||||||
scope
|
scope
|
||||||
|
@ -92,23 +92,23 @@ object TypecheckMatch
|
||||||
|
|
||||||
case Match.Lookup(symbol) =>
|
case Match.Lookup(symbol) =>
|
||||||
assert(scope contains symbol, s"$symbol is not bound: $pattern")
|
assert(scope contains symbol, s"$symbol is not bound: $pattern")
|
||||||
assert(escalatesTo(scope(symbol), expectedType, schema), s"$symbol (of type ${scope(symbol)}) can not possibly match an element of type $expectedType")
|
assert(Typecheck.escalatesTo(scope(symbol), expectedType, schema), s"$symbol (of type ${scope(symbol)}) can not possibly match an element of type $expectedType")
|
||||||
scope
|
scope
|
||||||
|
|
||||||
case Match.ApplyToScope(symbol, child) =>
|
case Match.ApplyToScope(symbol, child) =>
|
||||||
assert(scope contains symbol, s"$symbol is not bound: $pattern")
|
assert(scope contains symbol, s"$symbol is not bound: $pattern")
|
||||||
apply(child, scope(symbol), schema, scope)
|
apply(child, target, scope(symbol), schema, scope)
|
||||||
|
|
||||||
case Match.Forall(child) =>
|
case Match.Forall(child) =>
|
||||||
expectedType match {
|
expectedType match {
|
||||||
case Type.Array(base) =>
|
case Type.Array(base) =>
|
||||||
apply(child, base, schema, scope)
|
apply(child, target, base, schema, scope)
|
||||||
case _ =>
|
case _ =>
|
||||||
assert(false, s"Forall match on something other than an array: $pattern")
|
assert(false, s"Forall match on something other than an array: $pattern")
|
||||||
}
|
}
|
||||||
|
|
||||||
case Match.Test(op) =>
|
case Match.Test(op) =>
|
||||||
assert(escalatesTo(TypecheckExpression(op, schema, scope), Type.Bool, schema),
|
assert(Typecheck.escalatesTo(TypecheckExpression(op, schema, scope), Type.Bool, schema),
|
||||||
s"Non-boolean test expression: $op"
|
s"Non-boolean test expression: $op"
|
||||||
)
|
)
|
||||||
scope
|
scope
|
||||||
|
@ -126,23 +126,10 @@ object TypecheckMatch
|
||||||
var runningScope = scope
|
var runningScope = scope
|
||||||
for( (Field(fieldName, fieldType), fieldMatch) <- nodeSchema.fields.zip(fields) )
|
for( (Field(fieldName, fieldType), fieldMatch) <- nodeSchema.fields.zip(fields) )
|
||||||
{
|
{
|
||||||
runningScope = apply(fieldMatch, fieldType, schema, runningScope)
|
runningScope = apply(fieldMatch, target, fieldType, schema, runningScope)
|
||||||
}
|
}
|
||||||
runningScope
|
runningScope ++ target.map { _ -> Type.Node(label) }.toMap
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
def escalatesTo(source: Type, target: Type, schema: Definition): Boolean =
|
|
||||||
{
|
|
||||||
(source, target) match {
|
|
||||||
case (a, b) if a == b => true
|
|
||||||
case (Type.Node(label), Type.AST(family)) =>
|
|
||||||
schema.nodes.get(family)
|
|
||||||
.map { _ contains label }
|
|
||||||
.getOrElse { false }
|
|
||||||
case (Type.Native(_), Type.Native(_)) => true // don't try to encode native inheritance
|
|
||||||
case (Type.Any, _) => true
|
|
||||||
case (_, _) => false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
|
@ -0,0 +1,33 @@
|
||||||
|
@import com.astraldb.spec.Definition
|
||||||
|
|
||||||
|
@(ctx:Definition)
|
||||||
|
|
||||||
|
object Optimizer
|
||||||
|
{
|
||||||
|
val rules = Seq[Rule[LogicalPlan]](
|
||||||
|
@for(rule <- ctx.rules){
|
||||||
|
@rule.safeLabel,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def MAX_ITERATIONS = 100
|
||||||
|
|
||||||
|
def rewrite(plan: LogicalPlan): LogicalPlan =
|
||||||
|
{
|
||||||
|
var current = plan
|
||||||
|
var last = plan
|
||||||
|
|
||||||
|
for(i <- 0 until MAX_ITERATIONS)
|
||||||
|
{
|
||||||
|
for(rule <- rules)
|
||||||
|
{
|
||||||
|
plan = rule(plan)
|
||||||
|
}
|
||||||
|
if(last.fastEquals(plan))
|
||||||
|
{
|
||||||
|
return plan
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return plan
|
||||||
|
}
|
||||||
|
}
|
34
astral/compiler/views/com/astraldb/codegen/Rule.scala.scala
Normal file
34
astral/compiler/views/com/astraldb/codegen/Rule.scala.scala
Normal file
|
@ -0,0 +1,34 @@
|
||||||
|
@import com.astraldb.spec.Rule
|
||||||
|
@import com.astraldb.spec.Type
|
||||||
|
@import com.astraldb.spec.Definition
|
||||||
|
@import com.astraldb.codegen.Match
|
||||||
|
@import com.astraldb.codegen.Expression
|
||||||
|
@import com.astraldb.codegen.Code
|
||||||
|
@import com.astraldb.typecheck.TypecheckMatch
|
||||||
|
|
||||||
|
@(schema: Definition, rule: Rule)
|
||||||
|
|
||||||
|
object @{rule.safeLabel} extends Rule[LogicalPlan]
|
||||||
|
{
|
||||||
|
def apply(plan: LogicalPlan): LogicalPlan =
|
||||||
|
{
|
||||||
|
@Match(
|
||||||
|
schema = schema,
|
||||||
|
pattern = rule.pattern,
|
||||||
|
target = Code.Literal("plan"),
|
||||||
|
targetType = Type.AST(rule.family),
|
||||||
|
onSuccess = Expression(schema, rule.rewrite,
|
||||||
|
TypecheckMatch(
|
||||||
|
rule.pattern,
|
||||||
|
Some("plan"),
|
||||||
|
Type.AST(rule.family),
|
||||||
|
schema,
|
||||||
|
schema.globals
|
||||||
|
)
|
||||||
|
),
|
||||||
|
onFail = Code.Literal("plan"),
|
||||||
|
name = Some("plan"),
|
||||||
|
scope = schema.globals
|
||||||
|
).toString(4).stripPrefix(" ")
|
||||||
|
}
|
||||||
|
}
|
74
build.sc
74
build.sc
|
@ -1,22 +1,40 @@
|
||||||
import mill._
|
import mill._
|
||||||
import mill.scalalib._
|
import mill.scalalib._
|
||||||
import mill.scalalib.publish._
|
import mill.scalalib.publish._
|
||||||
|
import $ivy.`com.lihaoyi::mill-contrib-twirllib:`, mill.twirllib._
|
||||||
|
|
||||||
object astral extends Module {
|
object astral extends Module
|
||||||
def scalaVersion = "3.2.1"
|
{
|
||||||
|
def scalaVersion = "3.3.0"
|
||||||
|
|
||||||
object compiler extends ScalaModule with PublishModule {
|
def compile = compiler.compile
|
||||||
|
|
||||||
|
object compiler extends ScalaModule
|
||||||
|
with PublishModule
|
||||||
|
with TwirlModule
|
||||||
|
{
|
||||||
val VERSION = "0.0.1-SNAPSHOT"
|
val VERSION = "0.0.1-SNAPSHOT"
|
||||||
|
|
||||||
def scalaVersion = astral.scalaVersion
|
def scalaVersion = astral.scalaVersion
|
||||||
|
def twirlScalaVersion = scalaVersion
|
||||||
|
def twirlVersion = "1.6.0-RC4"
|
||||||
|
|
||||||
def mainClass = Some("com.astraldb.Astral")
|
def mainClass = Some("com.astraldb.Astral")
|
||||||
|
def generatedSources = T{ Seq(compileTwirl().classes) }
|
||||||
|
|
||||||
|
/*************************************************
|
||||||
|
*** Twirl Config
|
||||||
|
*************************************************/
|
||||||
|
def twirlFormats = super.twirlFormats() ++ Map(
|
||||||
|
"scala" -> "play.twirl.api.TxtFormat"
|
||||||
|
)
|
||||||
|
|
||||||
/*************************************************
|
/*************************************************
|
||||||
*** Backend Dependencies
|
*** Backend Dependencies
|
||||||
*************************************************/
|
*************************************************/
|
||||||
// def ivyDeps = Agg(
|
def ivyDeps = Agg(
|
||||||
// )
|
ivy"com.typesafe.play::twirl-api::${twirlVersion()}"
|
||||||
|
)
|
||||||
|
|
||||||
def publishVersion = VERSION
|
def publishVersion = VERSION
|
||||||
override def pomSettings = PomSettings(
|
override def pomSettings = PomSettings(
|
||||||
|
@ -35,9 +53,53 @@ object astral extends Module {
|
||||||
{
|
{
|
||||||
def scalaVersion = astral.scalaVersion
|
def scalaVersion = astral.scalaVersion
|
||||||
|
|
||||||
def mainClass = Some("com.astraldb.catalyst.Astral")
|
def mainClass = Some("com.astraldb.catalyst.Generate")
|
||||||
|
|
||||||
def moduleDeps = Seq(astral.compiler)
|
def moduleDeps = Seq(astral.compiler)
|
||||||
|
|
||||||
|
def classPath = T{ Seq[PathRef](compile().classes) ++ resources() }
|
||||||
|
|
||||||
|
def rendered = T {
|
||||||
|
val target = T.dest
|
||||||
|
val files = scala.collection.mutable.Buffer[String]()
|
||||||
|
os.proc(
|
||||||
|
"scala",
|
||||||
|
"-cp",
|
||||||
|
classPath().mkString(":"),
|
||||||
|
mainClass().get,
|
||||||
|
target
|
||||||
|
).call(
|
||||||
|
cwd = target,
|
||||||
|
stdout = os.ProcessOutput.Readlines(
|
||||||
|
line => files += line
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
for(f <- files)
|
||||||
|
{
|
||||||
|
println(s"GOT : $f")
|
||||||
|
}
|
||||||
|
|
||||||
|
/* return */
|
||||||
|
files.map {
|
||||||
|
file => PathRef(target / file)
|
||||||
|
}.toSeq
|
||||||
|
}
|
||||||
|
|
||||||
|
def render(args: String*) = T.command {
|
||||||
|
println(rendered())
|
||||||
|
}
|
||||||
|
|
||||||
|
object impl extends ScalaModule
|
||||||
|
{
|
||||||
|
def scalaVersion = "2.13.8"
|
||||||
|
|
||||||
|
def generatedSources = T{ astral.catalyst.rendered() }
|
||||||
|
|
||||||
|
def ivyDeps = Agg(
|
||||||
|
ivy"org.apache.spark::spark-sql::3.4.1",
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue