Merge branch 'main' of git.odin.cse.buffalo.edu:Astral/astral-compiler

This commit is contained in:
Oliver Kennedy 2023-07-07 18:42:06 -04:00
commit 9e38df1761
Signed by: okennedy
GPG key ID: 3E5F9B3ABD3FDB60
23 changed files with 941 additions and 43 deletions

View file

@ -54,7 +54,7 @@
- [ ] CombineConcats,
- [ ] PushdownPredicatesAndPruneColumnsForCTEDef
- [ ] Codegen: Output a Spark-like optimizer based on the above rules
- [ ] Logic generation
- [x] Logic generation
- [ ] Compilation pipeline to streamline testing
- [ ] Test: Do these rules generate the same results as spark?
- [ ] Apply the rule merge optimization

View file

@ -0,0 +1,9 @@
package com.astraldb.catalyst
/**
 * Entry point that prints a fixed greeting to standard output.
 */
object OptimizerTest
{
  /**
   * Print the greeting.
   *
   * @param args  command-line arguments; not read
   */
  def main(args: Array[String]): Unit =
    println("Hello world")
}

View file

@ -1,5 +1,7 @@
package com.astraldb.catalyst
import com.astraldb.codegen.Render
object Astral
{
def main(args: Array[String]): Unit =
@ -12,6 +14,7 @@ object Astral
System.exit(-1)
}
println(Catalyst.definition)
// println(Catalyst.definition)
Render.print(Catalyst.definition)
}
}

View file

@ -30,6 +30,52 @@ object Catalyst extends HardcodedDefinition
//////////////////////////////////////////////////////
Function("PushProjectionThroughUnion.projectListIsDeterministic",
Type.Bool,
)(
Type.Any
)
Function("PushProjectionThroughUnion.unionHasChildren",
Type.Bool,
)(
Type.Any
)
Function("PushProjectionThroughUnion.pushProjectionThroughUnion",
Type.Array(Type.AST("LogicalPlan"))
)(
Type.Any,
Type.Any
)
Function("ExtractFiltersAndInnerJoins",
Type.Struct(Map(
"_1" -> Type.Struct(Map("size" -> Type.Int)),
"_2" -> Type.Struct(Map("nonEmpty" -> Type.Bool)),
))
)(
Type.AST("LogicalPlan")
)
Function("ReorderJoin.rewrite", Type.Any)(
Type.AST("LogicalPlan")
)
Function("EliminateOuterJoin.buildNewJoinType",
Type.Union(Seq(
Type.Native("RightOuter"),
Type.Native("LeftOuter"),
Type.Native("FullOuter"),
))
)(
Type.AST("LogicalPlan")
)
Global("JoinHint.NONE", Type.Native("JoinHint"))
//////////////////////////////////////////////////////
Rule("PushProjectionThroughUnion", "LogicalPlan")(
Match("Project")(
Bind("projectList", MatchAny),

View file

@ -0,0 +1,9 @@
package com.astraldb.catalyst
/**
 * Entry point that prints the fixed string `Optimizer.sql` to standard output.
 */
object Generate
{
  /**
   * Print the string.
   *
   * @param args  command-line arguments; not read
   */
  def main(args: Array[String]): Unit =
    println("Optimizer.sql")
}

View file

@ -0,0 +1,276 @@
package com.astraldb.codegen
/**
 * A fragment of generated source text.
 *
 * Implementations supply [[render]]; every other member is derived from it.
 */
sealed trait Code
{
  /**
   * Render this fragment, either as one inline piece or as complete lines.
   *
   * @param indent  the indentation handed to the renderer
   */
  def render(indent: Int): Code.PaddedString|Code.Lines

  /**
   * Render this fragment as lines.  An inline result is wrapped into a
   * single line via `Code.Lines.ofLine(indent, ...)`.
   */
  def renderLines(indent: Int): Code.Lines =
  {
    val rendered = render(indent)
    rendered match {
      case multi: Code.Lines => multi
      case inline: Code.PaddedString => Code.Lines.ofLine(indent, inline)
    }
  }

  /**
   * The contents of this fragment when it is a [[Code.Block]]; otherwise a
   * one-element sequence holding this fragment.
   */
  def block =
    this match {
      case Code.Block(statements) => statements
      case _ => Seq(this)
    }

  /** The rendered text at indentation 0, lines joined with newlines. */
  override def toString(): String =
    toString(0)

  /** The rendered text at the given indentation, lines joined with newlines. */
  def toString(indent: Int): String =
  {
    val rendered = renderLines(indent)
    rendered.body.mkString("\n")
  }
}
object Code
{
/**
 * A piece of inline text together with the number of spaces it asks for on
 * each side.
 *
 * @param body  the text itself
 * @param lPad  spaces requested to the left of `body`
 * @param rPad  spaces requested to the right of `body`
 */
class PaddedString(val body: String, val lPad: Int = 0, val rPad: Int = 0)

object PaddedString
{
  /**
   * Concatenate several pieces into one.
   *
   * The pending right padding of the previous piece is written first (when
   * positive); a following padded piece then adds whatever its own left
   * padding still lacks.  A plain string resets the pending padding to 0.
   * The result carries the first piece's left padding and the last piece's
   * right padding.  With no pieces the result is an empty, unpadded string.
   */
  def apply(elems: (String|PaddedString)*): PaddedString =
    elems.headOption match {
      case None => new PaddedString("")
      case Some(first) =>
        val (firstBody, leftPad, firstRightPad) =
          first match {
            case plain: String => (plain, 0, 0)
            case padded: PaddedString => (padded.body, padded.lPad, padded.rPad)
          }
        val out = new StringBuilder(firstBody)
        // Thread the pending right padding through the remaining pieces.
        val trailingPad =
          elems.tail.foldLeft(firstRightPad) { (pending, piece) =>
            if(pending > 0){ out.append(" " * pending) }
            piece match {
              case plain: String =>
                out.append(plain)
                0
              case padded: PaddedString =>
                out.append(" " * Math.max(0, padded.lPad - pending))
                out.append(padded.body)
                padded.rPad
            }
          }
        new PaddedString(out.toString, leftPad, trailingPad)
    }

  /** `body` requesting `pad` spaces on its left only. */
  def leftPad(pad: Int, body: String): PaddedString =
    new PaddedString(body, pad, 0)

  /** `body` requesting `pad` spaces on its right only. */
  def rightPad(pad: Int, body: String): PaddedString =
    new PaddedString(body, 0, pad)

  /** `body` requesting `pad` spaces on both sides. */
  def pad(pad: Int, body: String): PaddedString =
    new PaddedString(body, pad, pad)
}
/**
 * A rendered fragment consisting of complete lines of text.
 *
 * @param body  the lines, in order
 */
case class Lines(body: Seq[String])
{
  /** The lines of this fragment followed by the lines of `other`. */
  def ++(other: Lines): Lines =
  {
    val combined = body ++ other.body
    new Lines(combined)
  }
}

object Lines
{
  /**
   * Build a single line from inline pieces.
   *
   * The pieces are joined with `PaddedString(...)`; the line is prefixed by
   * the larger of `indent` and the joined piece's left padding, in spaces.
   */
  def ofLine(indent: Int, elems: (String|PaddedString)*): Lines =
  {
    val joined = PaddedString(elems:_*)
    val margin = " " * Math.max(indent, joined.lPad)
    new Lines(Seq(margin + joined.body))
  }

  /** Concatenate plain strings (one line each) and existing line groups, in order. */
  def ofLines(elems: (String|Lines)*): Lines =
    new Lines(
      for {
        elem <- elems
        line <- (elem match {
                  case single: String => Seq(single)
                  case group: Lines => group.body
                })
      } yield line
    )

  /** A fragment with no lines. */
  val empty = new Lines(Seq.empty)

  /** Concatenate a sequence of line groups into one. */
  def flatten(elems: Seq[Lines]): Lines =
    elems.foldLeft(empty) { (soFar, next) => soFar ++ next }

  /** Wrap a sequence of lines. */
  def apply(body: Seq[String]): Lines =
    new Lines(body)
}
/** A fragment that renders as an empty, unpadded inline string. */
case object Empty extends Code
{
  def render(indent: Int): PaddedString|Lines =
    new PaddedString("")
}
/**
 * A sequence of fragments rendered one after another as lines.
 *
 * @param lines  the fragments, in output order
 */
case class Block(
  lines: Seq[Code],
) extends Code
{
  /** Render every fragment as lines at `indent` and concatenate the results. */
  def render(indent: Int): PaddedString|Lines =
    lines.foldLeft(Lines.empty) { (soFar, fragment) =>
      soFar ++ fragment.renderLines(indent)
    }
}
/**
 * A fragment that renders as the given text, inline and unpadded.
 *
 * @param txt  the text to emit verbatim
 */
case class Literal(txt: String) extends Code
{
  def render(indent: Int): PaddedString|Lines =
    new PaddedString(txt)
}
/**
 * Two fragments joined by an operator.
 *
 * @param left   the left operand
 * @param op     the operator text placed between the operands
 * @param right  the right operand
 */
case class BinOp(left: Code, op: String|PaddedString, right: Code) extends Code
{
  /**
   * Render both operands at `indent`.  When both are inline the result is
   * inline; otherwise the operator is attached to whichever operand is
   * inline, or placed on its own line (at `indent+2`) when neither is.
   */
  def render(indent: Int): PaddedString|Lines =
  {
    val lhs = left.render(indent)
    val rhs = right.render(indent)
    lhs match {
      case l: PaddedString =>
        rhs match {
          case r: PaddedString => PaddedString(l, op, r)
          case r: Lines => Lines.ofLine(indent, l, op) ++ r
        }
      case l: Lines =>
        rhs match {
          case r: PaddedString => l ++ Lines.ofLine(indent, op, r)
          case r: Lines => l ++ Lines.ofLine(indent+2, op) ++ r
        }
    }
  }
}
/**
 * A fragment wrapped between an opening and a closing delimiter.
 *
 * @param left   the opening delimiter
 * @param body   the wrapped fragment, rendered at `indent+2`
 * @param right  the closing delimiter
 */
case class Parens(left: String|PaddedString, body: Code, right: String|PaddedString) extends Code
{
  /**
   * An inline body stays inline between the delimiters; a multi-line body
   * gets each delimiter on a line of its own at `indent`.
   */
  def render(indent: Int): PaddedString|Lines =
  {
    val inner = body.render(indent+2)
    inner match {
      case inline: PaddedString =>
        PaddedString(left, inline, right)
      case multi: Lines =>
        Lines.flatten(Seq(
          Lines.ofLine(indent, left),
          multi,
          Lines.ofLine(indent, right)
        ))
    }
  }
}
/**
 * A delimited body introduced by a prefix.
 *
 * Used to represent if-then, while, for, etc... blocks
 *
 * @param prefix  the fragment that introduces the body
 * @param lParen  the opening delimiter
 * @param rParen  the closing delimiter
 * @param body    the enclosed fragment, rendered at `indent+2`
 */
case class PrefixedBlock(prefix: Code, lParen: String|PaddedString, rParen: String|PaddedString, body: Code) extends Code
{
  /**
   * The result is inline only when both the body and the prefix render
   * inline.  A multi-line body places each delimiter on its own line.
   */
  def render(indent: Int): PaddedString|Lines =
  {
    val inner = body.render(indent+2)
    inner match {
      case multi: Lines =>
        Lines.flatten(Seq(
          prefix.renderLines(indent),
          Lines.ofLine(indent, lParen),
          multi,
          Lines.ofLine(indent, rParen)
        ))
      case inline: PaddedString =>
        prefix.render(indent) match {
          case head: PaddedString =>
            PaddedString(head, lParen, inline, rParen)
          case head: Lines =>
            head ++ Lines.ofLine(indent, lParen, inline, rParen)
        }
    }
  }
}
/**
* Two "block"s with a prefix and a joiner
*
* Used to represent if-then-else
*
* Rendering always produces [[Lines]].  The prefix is rendered at `indent`,
* both bodies at `indent+2`.
*
* @param prefix  the fragment that introduces the first body
* @param lParen  the opening delimiter, used for both bodies
* @param joiner  the text placed between the two bodies
* @param rParen  the closing delimiter, used for both bodies
* @param lBody   the first body
* @param rBody   the second body
*/
case class PrefixedBlockPair(prefix: Code, lParen: String|PaddedString, joiner: String|PaddedString, rParen: String|PaddedString, lBody: Code, rBody: Code) extends Code
{
def render(indent: Int): PaddedString|Lines =
{
val p = prefix.render(indent)
val l = lBody.render(indent+2)
val r = rBody.render(indent+2)
// The first half only emits its own closing delimiter when prefix and
// first body are both inline; in every other case the closing delimiter
// is deferred to the start of the joiner line below.
val leftNeedsTrailingRParen =
p.isInstanceOf[Lines] || l.isInstanceOf[Lines]
// First half: prefix, opening delimiter and first body.
(p match {
case pp: PaddedString =>
l match {
case lp: PaddedString =>
Lines.ofLine(indent, pp, lParen, lp, rParen)
case ll: Lines =>
Lines.ofLine(indent, pp, lParen) ++
ll
}
case pl: Lines =>
l match {
case lp: PaddedString =>
pl ++
Lines.ofLine(indent, lParen) ++
Lines.ofLine(indent, lp)
case ll: Lines =>
pl ++
Lines.ofLine(indent, lParen) ++
ll
}
// Second half: (deferred closing delimiter,) joiner, opening delimiter,
// second body and its closing delimiter.
}) ++ ((r, leftNeedsTrailingRParen) match {
case (rp: PaddedString, false) =>
Lines.ofLine(indent, joiner, lParen, rp, rParen)
case (rp: PaddedString, true) =>
Lines.ofLine(indent, rParen, joiner, lParen, rp, rParen)
case (rp: Lines, false) =>
Lines.ofLine(indent, joiner, lParen) ++
rp ++
Lines.ofLine(indent, rParen)
case (rp: Lines, true) =>
Lines.ofLine(indent, rParen, joiner, lParen) ++
rp ++
Lines.ofLine(indent, rParen)
})
}
}
/**
* Build an if / else fragment.
*
* The result is a [[PrefixedBlockPair]] whose prefix is `condition` wrapped
* in `if(` ... `)`, whose delimiters are `{` and `}`, and whose joiner is
* `else`.
*
* @param condition  the fragment placed inside `if( ... )`
* @param thenBlock  the first body
* @param elseBlock  the second body
*/
def IfThenElse(condition: Code, thenBlock: Code, elseBlock: Code): Code =
Code.PrefixedBlockPair(
prefix = Code.Parens("if(",condition,")"),
lParen = PaddedString.rightPad(1, "{"),
rParen = PaddedString.leftPad(1, "}"),
joiner = PaddedString.pad(1, "else"),
lBody = thenBlock,
rBody = elseBlock
)
/**
* A separated list of fragments.
*
* @param sep    the separator text
* @param elems  the list elements, each rendered at `indent+2`
*/
case class List(sep: String|PaddedString, elems: Seq[Code]) extends Code
{
def render(indent: Int): PaddedString|Lines =
{
val renderedElems = elems.map { _.render(indent+2) }
// Go multi-line when any element is itself multi-line, or when the
// inline element bodies add up to more than 100 characters.  The cast
// on the second operand is only reached when no element is a Lines.
val multiLine =
renderedElems.exists { _.isInstanceOf[Lines] }
|| renderedElems.map { _.asInstanceOf[PaddedString].body.size }.sum > 100
if(multiLine){
// One entry per element; the separator is appended to every element,
// including the last one.
Lines(
renderedElems.flatMap {
case p:PaddedString =>
Lines.ofLine(indent, p, sep).body
case Lines(Seq()) =>
Lines.ofLine(indent, sep).body
case Lines(Seq(l)) =>
Seq(PaddedString(l, sep).body)
case Lines(l) =>
// First line as-is, later lines prefixed with the literal string
// below, separator attached to the final line.
Seq(l.head) ++
l.tail.dropRight(1).map { " " + _ } ++
Seq(PaddedString(" ",l.last,sep).body)
}
)
} else {
// Single line: elements with the separator between neighbours only.
renderedElems.map { _.asInstanceOf[PaddedString] } match {
case Seq() => PaddedString("")
case Seq(a) => a
case a =>
PaddedString( (
Seq(a.head) ++ a.tail.flatMap { Seq(sep, _) }
).toSeq:_*)
}
}
}
}
/**
* Wrap `body` in round parentheses, i.e. `Parens("(", body, ")")`.
*/
def Parenthesize(body: Code): Code =
Parens("(", body, ")")
}

View file

@ -0,0 +1,64 @@
package com.astraldb.codegen
import com.astraldb.spec
import com.astraldb.expression._
import com.astraldb.expression.{ Expression => Expr}
import com.astraldb.codegen.Code.PaddedString
import com.astraldb.typecheck.TypecheckExpression
import com.astraldb.spec.Type
/**
 * Translates expression trees into [[Code]] fragments.
 */
object Expression
{
  /**
   * Generate code for an expression.
   *
   * @param schema  the definition used to look up node field names
   * @param op      the expression to translate
   * @param scope   types of the symbols visible to `op`; consulted when
   *                typechecking the target of a node subscript
   * @return        the generated fragment
   */
  def apply(schema: spec.Definition, op: Expr, scope: Map[String, spec.Type]): Code =
  {
    op match {
      case c:SimpleConstant => Code.Literal(c.asScala)
      case Var(v) => Code.Literal(v)
      case Arith(t, lhs, rhs) =>
        // Both operands are parenthesized so precedence never matters.
        Code.BinOp(
          Code.Parenthesize(apply(schema, lhs, scope)),
          PaddedString.pad(1, ArithTypes.opString(t)),
          Code.Parenthesize(apply(schema, rhs, scope))
        )
      case Cmp(t, lhs, rhs) =>
        Code.BinOp(
          Code.Parenthesize(apply(schema, lhs, scope)),
          PaddedString.pad(1, CmpTypes.opString(t)),
          Code.Parenthesize(apply(schema, rhs, scope))
        )
      case FunctionCall(fn, args) =>
        // A stray `Code.Literal(op.toString)` used to be built and
        // discarded here; it had no effect and has been removed.
        Code.PrefixedBlock(
          prefix = Code.Literal(fn),
          lParen = "(",
          rParen = ")",
          body = Code.List(
            sep = PaddedString.rightPad(1, ","),
            elems = args.map { apply(schema, _, scope) }
          )
        )
      case MakeNode(node, fields) =>
        // Node construction is emitted as a call named after the node.
        apply(schema, FunctionCall(node, fields), scope)
      case NodeSubscript(target, index) =>
        Code.BinOp(
          apply(schema, target, scope),
          ".",
          // The positional index is turned into the field's name, which
          // requires knowing the node type of the target.
          TypecheckExpression(target, schema, scope) match {
            case Type.Node(nodeType) =>
              Code.Literal(
                schema.nodesByName(nodeType).fields(index).name
              )
            case c =>
              assert(false, s"Node subscript on something not a node: $op (in ${scope}): ${c.getClass.getSimpleName}")
          }
        )
      case StructSubscript(target, name) =>
        Code.BinOp(
          apply(schema, target, scope),
          ".",
          Code.Literal(name)
        )
    }
  }
}

View file

@ -0,0 +1,140 @@
package com.astraldb.codegen
import com.astraldb.spec
import spec.Match._
import com.astraldb.codegen.Code.PaddedString
import com.astraldb.typecheck.TypecheckMatch
import com.astraldb.spec.Type
/**
 * Translates match patterns into [[Code]] fragments.
 */
object Match
{
  /**
   * Generate code that tests `target` against `pattern`.
   *
   * @param schema      the definition used to look up node descriptions
   * @param pattern     the pattern to translate
   * @param target      code that evaluates to the value being matched
   * @param targetType  the type of the value being matched
   * @param onSuccess   code to run when the pattern matches
   * @param onFail      code to run when the pattern does not match
   * @param name        the symbol currently naming the target, if any
   * @param scope       types of the symbols bound so far
   * @return            the generated fragment; pattern kinds without a
   *                    translation yield a `??ClassName??` literal
   */
  def apply(
    schema: spec.Definition,
    pattern: spec.Match,
    target: Code,
    targetType: spec.Type,
    onSuccess: Code,
    onFail: Code,
    name: Option[String],
    scope: Map[String, spec.Type]
  ): Code =
  {
    pattern match {
      case And(Seq()) => onSuccess
      case And(Seq(a)) => apply(schema, a, target, targetType, onSuccess, onFail, name, scope)
      case And(a) =>
        // Match the head; on success continue with the remaining
        // conjuncts, which see the bindings the head introduced.
        apply(schema, a.head,
          target = target,
          targetType = targetType,
          onSuccess =
            apply(schema, And(a.tail), target, targetType, onSuccess, onFail, name,
              scope = TypecheckMatch(a.head, name, targetType, schema, scope)
            ),
          onFail = onFail,
          name = name,
          scope = scope
        )
      // Negation swaps the two continuations.
      case Not(a) => apply(schema, a, target, targetType, onFail, onSuccess, name, scope)
      case Or(Seq()) => onFail
      case Or(Seq(a)) =>
        apply(schema, a, target, targetType, onSuccess, onFail, name, scope)
      case Or(a) =>
        // Match the head; on failure fall through to the alternatives.
        apply(schema, a.head,
          target = target,
          targetType = targetType,
          onSuccess = onSuccess,
          onFail = apply(schema, Or(a.tail), target, targetType, onSuccess, onFail, name, scope),
          name = name,
          scope = scope
        )
      case Bind(symbol, pattern) =>
        // Emit `val symbol = target`, then match the child against it.
        Code.Block(
          Code.BinOp(
            Code.Literal(s"val $symbol"),
            Code.PaddedString.pad(1, "="),
            target
          ) +:
          apply(
            schema = schema,
            pattern = pattern,
            target = Code.Literal(symbol),
            targetType = targetType,
            onSuccess = onSuccess,
            onFail = onFail,
            name = Some(symbol),
            scope = scope ++ Map(symbol -> targetType)
          ).block
        )
      case BindExpression(symbol, op) =>
        Code.Block(
          Seq(
            Code.BinOp(
              Code.Literal(s"val $symbol"),
              Code.PaddedString.pad(1, "="),
              Expression(schema, op, scope)
            )
          ) ++ onSuccess.block
        )
      case Node(nodeLabel, children) =>
        {
          val selectedName =
            name.getOrElse { "genericNode" }
          val node = schema.nodesByName(nodeLabel)
          Code.IfThenElse(
            condition =
              Code.BinOp(
                target,
                ".",
                Code.Literal(s"isInstanceOf[$nodeLabel]")
              ),
            thenBlock =
              Code.Block(
                Seq(
                  Code.BinOp(
                    Code.Literal(s"val $selectedName"),
                    Code.PaddedString.pad(1, "="),
                    Code.BinOp(
                      target,
                      ".",
                      Code.Literal(s"asInstanceOf[$nodeLabel]")
                    )
                  )
                ) ++
                // Fold from the last field backwards so that each child's
                // success continuation is the match of the next child.
                children
                  .zip(node.fields)
                  .foldRight(onSuccess) { case ((child, field), andThen) =>
                    apply(
                      schema = schema,
                      pattern = child,
                      target = Code.Literal(s"$selectedName.${field.name}"),
                      // The child is matched against the field's value, so
                      // its type is the field's declared type rather than
                      // the type of the enclosing node.
                      targetType = field.t,
                      onSuccess = andThen,
                      onFail = onFail,
                      name = Some(s"${selectedName}_${field.name}"),
                      scope = scope ++ Map(selectedName -> Type.Node(nodeLabel))
                    )
                  }
                  .block
              ),
            elseBlock = onFail
          )
        }
      case Test(op) =>
        Code.IfThenElse(
          condition = Expression(schema, op, scope),
          thenBlock = onSuccess,
          elseBlock = onFail
        )
      case Any =>
        onSuccess
      case OfType(nodeType) =>
        Code.IfThenElse(
          condition = Code.BinOp(target, ".", Code.Literal(s"isInstanceOf[${nodeType.scalaType}]")),
          thenBlock = onSuccess,
          elseBlock = onFail
        )
      // Pattern kinds without a translation yet: emit a visible marker.
      case _ => Code.Literal(s"??${pattern.getClass.getSimpleName}??")
    }
  }
}

View file

@ -0,0 +1,11 @@
package com.astraldb.codegen
import com.astraldb.spec.Definition
/**
* Produces the text of the optimizer source file for a definition.
*/
object Optimizer
{
/**
* Apply `scala.Optimizer` to the definition and return the result's
* string form.
*
* NOTE(review): `scala.Optimizer` is presumably the compiled Twirl
* template of the same name — confirm against the build.
*
* @param definition  the definition to generate an optimizer for
* @return            the rendered text
*/
def apply(definition: Definition): String =
{
scala.Optimizer(definition).toString
}
}

View file

@ -0,0 +1,25 @@
package com.astraldb.codegen
import com.astraldb.spec.Definition
/**
 * Renders every generated file for a definition.
 */
object Render
{
  /**
   * Render all files.
   *
   * @param schema  the definition to render
   * @return        file name -> file content; one `Optimizer.scala` entry
   *                plus one `<safeLabel>.scala` entry per rule
   */
  def apply(schema: Definition): Map[String, String] =
  {
    val optimizerFile = "Optimizer.scala" -> Optimizer(schema)
    val ruleFiles =
      schema.rules.map { rule =>
        s"${rule.safeLabel}.scala" -> Rule(schema, rule)
      }
    Map(optimizerFile) ++ ruleFiles
  }

  /**
   * Print every rendered file to standard output, each preceded by a
   * banner line carrying the file name.
   */
  def print(schema: Definition): Unit =
    apply(schema).foreach { case (file, content) =>
      println(s"//////////////////// $file")
      println(content)
    }
}

View file

@ -0,0 +1,11 @@
package com.astraldb.codegen
import com.astraldb.spec
/**
* Produces the text of the source file for a single rule.
*/
object Rule
{
/**
* Apply `scala.Rule` to the definition and the rule and return the result's
* string form.
*
* NOTE(review): `scala.Rule` is presumably the compiled Twirl template of
* the same name — confirm against the build.
*
* @param schema  the definition the rule belongs to
* @param rule    the rule to render
* @return        the rendered text
*/
def apply(schema: spec.Definition, rule: spec.Rule): String =
{
scala.Rule(schema, rule).toString
}
}

View file

@ -9,6 +9,7 @@ object Eval
class TypeMismatch(problem: Expression, expr: Expression, lhsType: Type, rhsType: Type) extends EvalException(expr)
class TypeError(problem: Expression, expr: Expression, problemType: Type) extends EvalException(expr)
class UnsupportedExternalFeature(problem: Expression, expr: Expression) extends EvalException(expr)
class MissingFieldError(expr: Expression, field: String) extends EvalException(expr)
def apply(base: Expression, scope: Map[String, Constant] = Map.empty): Constant =
{
@ -117,12 +118,24 @@ object Eval
elements(index)
case c => throw new TypeError(e, base, c.t)
}
case StructSubscript(target, name) =>
rcr(target, scope) match {
case StructConstant(elements) =>
elements.find { _._1 == name }
.getOrElse {
throw new MissingFieldError(e, name)
}
._2
case c => throw new TypeError(e, base, c.t)
}
case FunctionCall(name, args) =>
throw new UnsupportedExternalFeature(e, base)
case MakeNode(nodeType, fields) =>
NodeConstant(nodeType, fields.map { rcr(_, scope) })
case MakeArray(elements) =>
ArrayConstant(elements.map { rcr(_, scope) })
case MakeStruct(elements) =>
StructConstant(elements.map { case (f, e) => f -> rcr(e, scope) })
case Let(symbol, value, rest) =>
rcr(rest, scope ++ Map(symbol -> rcr(value, scope)))
}

View file

@ -109,6 +109,7 @@ sealed trait SimpleConstant extends Constant
{
def fields: Seq[Constant] = Seq.empty
def reassembleFields(in: Seq[Constant]): Expression = this
def asScala: String = toString
}
/**
* A constant with components (Nodes, Arrays)

View file

@ -5,7 +5,8 @@ import com.astraldb.expression._
case class Definition(
nodes:Map[String, Map[String, Node]],
rules:Seq[Rule]
rules:Seq[Rule],
globals: Map[String, Type],
) {
override def toString =
s"""/////// ASTs //////
@ -18,6 +19,8 @@ case class Definition(
nodes.flatMap { case (family, elements) => elements.keys.map { _ -> family } }
.toMap
val nodesByName: Map[String, Node] =
nodes.values.flatten.toMap
def validate() =
{
@ -30,11 +33,13 @@ class HardcodedDefinition
lazy val definition: Definition =
Definition(
nodes = nodes.mapValues { _.map { n => n.name -> n }.toMap }.toMap,
rules = rules.toSeq
rules = rules.toSeq,
globals = globals.toMap
)
val nodes = mutable.Map[String, mutable.Buffer[Node]]()
val rules = mutable.Buffer[Rule]()
val globals = mutable.Map[String,Type]()
import FieldConversions._
@ -51,6 +56,16 @@ class HardcodedDefinition
com.astraldb.spec.Rule(label, family, pattern, replacement)
)
/**
 * Declare a global function: stores `Type.Function(args, ret)` under
 * `label` in `globals`, replacing any earlier entry with that label.
 *
 * @param label  the name the function is registered under
 * @param ret    the function's return type
 * @param args   the function's argument types, in order
 */
def Function(label: String, ret: Type = Type.Unit)(args: Type*): Unit =
  globals.update(label, Type.Function(args, ret))
/**
 * Declare a global symbol: stores `t` under `label` in `globals`,
 * replacing any earlier entry with that label.
 *
 * @param label  the name of the global
 * @param t      the type of the global
 */
def Global(label: String, t: Type): Unit =
  globals.update(label, t)
//////////////////////// Matchers
def Match(node: String)(fields: Match*): Match =
com.astraldb.spec.Match.Node(node, fields)

View file

@ -1,6 +1,5 @@
package com.astraldb.spec
import com.astraldb.ast._
import com.astraldb.spec.Match.Scope
import com.astraldb.expression._

View file

@ -17,7 +17,11 @@ case class Rule(
def validate(schema: Definition): Unit =
{
val scope =
TypecheckMatch(pattern, Type.AST(family), schema, Map.empty)
TypecheckMatch(pattern, None, Type.AST(family), schema, schema.globals)
TypecheckExpression(rewrite, schema, scope)
}
/**
 * The rule's label turned into a string usable as an identifier or file
 * name: every run of characters outside `A-Z`, `a-z`, `0-9` becomes a
 * single underscore, and a leading digit is prefixed with an underscore.
 */
def safeLabel: String =
  label.replaceAll("[^A-Za-z0-9]+", "_")
       // `replaceAll` (regex) rather than `replace` (literal text), and `$1`
       // as the group reference: the previous literal replace never matched.
       .replaceAll("^([0-9])", "_$1")
}

View file

@ -8,6 +8,8 @@ sealed trait Type {
case Type.Union(of) => of
case _ => Seq(this)
}
def scalaType: String
}
object Type
@ -15,42 +17,52 @@ object Type
case class Native(name: String) extends Type
{
override def toString: String = s"Native[${name}]"
def scalaType: String = name
}
case class AST(family: String) extends Type
{
override def toString: String = s"Ast[${family}]"
def scalaType: String = family
}
case class Node(nodeType: String) extends Type
{
override def toString: String = s"Node[${nodeType}]"
def scalaType: String = nodeType
}
case class Struct(fields: Map[String, Type]) extends Type
{
override def toString: String = s"Struct {${fields.map { x => s"${x._1}: ${x._2}"}.mkString(", ")}}"
def scalaType: String = ???
}
case class Option(base: Type) extends Type
{
override def toString: String = s"Option[$base]"
def scalaType: String = s"Option[${base.scalaType}]"
}
case class Array(base: Type) extends Type
{
override def toString: String = s"Array[${base}]"
def scalaType: String = s"Array[${base.scalaType}]"
}
case class Function(args: Seq[Type], ret: Type) extends Type
{
override def toString: String = s"(${args.mkString(", ")}) => $ret"
def scalaType: String = s"(${args.map { _.scalaType }.mkString(", ")}) => ${ret.scalaType}"
}
case object Unit extends Type
{
override def toString: String = "Unit"
def scalaType: String = "Unit"
}
case object Any extends Type
{
override def toString: String = "Any"
def scalaType: String = "Any"
}
case class Union(of: Seq[Type]) extends Type
{
override def toString: String = s"Union(${of.mkString(" | ")})"
def scalaType: String = of.map { _.scalaType }.mkString("|")
}
sealed trait PrimType extends Type
@ -58,18 +70,22 @@ object Type
case object Int extends PrimType
{
override def toString: String = "int"
def scalaType: String = "Int"
}
case object Float extends PrimType
{
override def toString: String = "float"
def scalaType: String = "Double"
}
case object Bool extends PrimType
{
override def toString: String = "bool"
def scalaType: String = "Boolean"
}
case object String extends PrimType
{
override def toString: String = "string"
def scalaType: String = "String"
}
def union(ts: Seq[Type]) =

View file

@ -0,0 +1,26 @@
package com.astraldb.typecheck
import com.astraldb.spec._
import com.astraldb.expression._
/**
 * Type relations shared by the typecheckers.
 */
object Typecheck
{
  /**
   * Decide whether a value of type `source` is acceptable where `target`
   * is expected.
   *
   * @param source  the type of the value at hand
   * @param target  the type that is expected
   * @param schema  the definition used to resolve node / AST-family membership
   * @return        true when `source` escalates to `target`
   */
  def escalatesTo(source: Type, target: Type, schema: Definition): Boolean =
  {
    (source, target) match {
      case (a, b) if a == b => true
      // A node escalates to the AST family that declares it.
      case (Type.Node(label), Type.AST(family)) =>
        schema.nodes.get(family)
              .map { _ contains label }
              .getOrElse { false }
      // A union source fits only when every alternative fits.
      case (Type.Union(elems), a) =>
        elems.forall { escalatesTo(_, a, schema) }
      // A union target is satisfied when the source fits at least one
      // alternative (requiring all of them would reject e.g. Int against
      // Int | String).
      case (a, Type.Union(elems)) =>
        elems.exists { escalatesTo(a, _, schema) }
      case (Type.Native(_), Type.Native(_)) => true // don't try to encode native inheritance
      case (_, Type.Any) => true
      case (_, _) => false
    }
  }
}

View file

@ -5,9 +5,123 @@ import com.astraldb.expression._
object TypecheckExpression
{
class TypecheckException(expr: Expression) extends Exception
class TypeMismatch(expr: Expression, a: Type, b: Type) extends TypecheckException(expr)
def apply(expr: Expression, schema: Definition, scope: Map[String, Type]): Type =
{
// TODO
return Type.Any
expr match {
case c:Constant => c.t
case Arith(op, lhs, rhs) =>
op match {
case ArithTypes.Add | ArithTypes.Mul | ArithTypes.Sub | ArithTypes.Div =>
(
apply(lhs, schema, scope),
apply(rhs, schema, scope),
) match {
case (Type.Int, Type.Int) => Type.Int
case (Type.Float, Type.Float) => Type.Float
case (Type.String, Type.String) if op == ArithTypes.Add => Type.String
case (a, b) => assert(false, s"Mismatched types $a and $b: $expr")
}
case ArithTypes.And | ArithTypes.Or =>
assert(apply(lhs, schema, scope) == Type.Bool, s"Mismatched types $lhs is not a boolean: $expr")
assert(apply(rhs, schema, scope) == Type.Bool, s"Mismatched types $rhs is not a boolean: $expr")
Type.Bool
}
case Cmp(op, lhs, rhs) =>
op match {
case CmpTypes.Eq | CmpTypes.Neq =>
val l = apply(lhs, schema, scope)
val r = apply(rhs, schema, scope)
assert(l == r, s"Comparison between mismatched types: $l and $r: $expr")
Type.Bool
case _ =>
(
apply(lhs, schema, scope),
apply(rhs, schema, scope),
) match {
case (Type.Int, Type.Int) => Type.Bool
case (Type.Float, Type.Float) => Type.Bool
case (a, b) => assert(false, s"Comparison between mismatched types $a and $b: $expr")
}
}
case Var(v) => scope(v)
case FunctionCall(fn, args) =>
scope.get(fn) match {
case Some(Type.Function(argTypes, ret)) =>
for( (arg, argType) <- args.zip(argTypes) )
{
val actualType = apply(arg, schema, scope)
assert(
Typecheck.escalatesTo(actualType, argType, schema),
s"Mismatched argument $arg ($actualType doesn't escalate to $argType): $expr"
)
}
ret
case Some(c) =>
assert(false, s"Trying to call $fn, which is a $c: $expr")
case None =>
assert(false, s"Trying to call $fn, which is not defined (in ${scope.keys.mkString(", ")}): $expr")
}
case MakeNode(nodeType, fields) =>
val node = schema.nodesByName(nodeType)
for( (fieldConstructor, fieldType) <- fields.zip(node.fields) )
{
val actualType = apply(fieldConstructor, schema, scope)
assert(Typecheck.escalatesTo(actualType, fieldType.t, schema),
s"Mismatched node constructor argument $fieldConstructor ($actualType doesn't escalate to ${fieldType.t}): $expr"
)
}
Type.Node(nodeType)
case NodeSubscript(target, index) =>
apply(target, schema, scope) match {
case Type.Node(nodeType) =>
val node = schema.nodesByName(nodeType)
node.fields(index).t
case _ =>
assert(false, s"Node subscript on something not a node: $expr")
}
case StructSubscript(target, name) =>
apply(target, schema, scope) match {
case Type.Struct(fields) =>
assert(fields contains name,
s"Struct subscript on a struct that doesn't contain $name ($fields): $expr"
)
fields(name)
case Type.Node(nodeType) =>
val node = schema.nodesByName(nodeType)
node.fields.find { _._1 == name }
.getOrElse {
assert(false,
s"Struct subscript on a node that doesn't contain $name (${node.fields}): $expr"
)
}
._2
case _ =>
assert(false, s"Struct subscript on something not a struct: $expr")
}
case FunctionalIfThenElse(condition, ifTrue, ifFalse) =>
assert(apply(condition, schema, scope) == Type.Bool,
s"If then else with a non-boolean condition ($condition): $expr"
)
val a = apply(ifTrue, schema, scope)
val b = apply(ifFalse, schema, scope)
assert(a == b,
s"If then else with mismatched return types $a and $b: $expr"
)
a
}
}
}

View file

@ -6,12 +6,12 @@ import com.astraldb.expression._
object TypecheckMatch
{
def apply(pattern: Match, family: String, schema: Definition, scope: Map[String, Type]): Map[String, Type] =
apply(pattern, Type.AST(family), schema, scope)
apply(pattern, None, Type.AST(family), schema, scope)
def apply(pattern: Match, expectedType: Type, schema: Definition, scope: Map[String, Type]): Map[String, Type] =
def apply(pattern: Match, target: Option[String], expectedType: Type, schema: Definition, scope: Map[String, Type]): Map[String, Type] =
{
def checkAllChildren() =
pattern.children.foreach { apply(_, expectedType, schema, scope) }
pattern.children.foreach { apply(_, target, expectedType, schema, scope) }
def astSchema(label: String): Node =
{
@ -29,15 +29,15 @@ object TypecheckMatch
}
pattern match {
case Match.Not(child) => apply(child, expectedType, schema, scope)
case Match.Not(child) => apply(child, target, expectedType, schema, scope)
case Match.And(children) =>
children.foldLeft(scope) { (scope, child) =>
apply(child, expectedType, schema, scope)
apply(child, target, expectedType, schema, scope)
}
case Match.Or(children) =>
children.flatMap { apply(_, expectedType, schema, scope).toSeq }
children.flatMap { apply(_, target, expectedType, schema, scope).toSeq }
.groupBy { _._1 }
.mapValues { childTypes =>
val base =
@ -49,13 +49,13 @@ object TypecheckMatch
case Match.OfType(t) =>
assert(escalatesTo(t, expectedType, schema),
assert(Typecheck.escalatesTo(t, expectedType, schema),
s"Matching a node by type; Type $t can not possibly matched by an element of type $expectedType: $pattern"
)
scope
scope ++ target.map { _ -> t }.toMap
case Match.Exact(child) =>
assert(escalatesTo(child.t, expectedType, schema),
assert(Typecheck.escalatesTo(child.t, expectedType, schema),
s"Matching a node by value; Value $child can not possibly match an element of type $expectedType: $pattern"
)
scope
@ -74,7 +74,7 @@ object TypecheckMatch
assert(false, s"Can't path into a simple type: $t: $pattern")
}
}
apply(child, targetType, schema, scope)
apply(child, target, targetType, schema, scope)
case Match.Recursive(_) =>
// not entirely sure how to typecheck this...
@ -82,7 +82,7 @@ object TypecheckMatch
???
case Match.Bind(symbol, child) =>
apply(child, expectedType, schema, scope ++ Map(symbol -> expectedType))
apply(child, Some(symbol), expectedType, schema, scope ++ Map(symbol -> expectedType))
case Match.Any =>
scope
@ -92,23 +92,23 @@ object TypecheckMatch
case Match.Lookup(symbol) =>
assert(scope contains symbol, s"$symbol is not bound: $pattern")
assert(escalatesTo(scope(symbol), expectedType, schema), s"$symbol (of type ${scope(symbol)}) can not possibly match an element of type $expectedType")
assert(Typecheck.escalatesTo(scope(symbol), expectedType, schema), s"$symbol (of type ${scope(symbol)}) can not possibly match an element of type $expectedType")
scope
case Match.ApplyToScope(symbol, child) =>
assert(scope contains symbol, s"$symbol is not bound: $pattern")
apply(child, scope(symbol), schema, scope)
apply(child, target, scope(symbol), schema, scope)
case Match.Forall(child) =>
expectedType match {
case Type.Array(base) =>
apply(child, base, schema, scope)
apply(child, target, base, schema, scope)
case _ =>
assert(false, s"Forall match on something other than an array: $pattern")
}
case Match.Test(op) =>
assert(escalatesTo(TypecheckExpression(op, schema, scope), Type.Bool, schema),
assert(Typecheck.escalatesTo(TypecheckExpression(op, schema, scope), Type.Bool, schema),
s"Non-boolean test expression: $op"
)
scope
@ -126,23 +126,10 @@ object TypecheckMatch
var runningScope = scope
for( (Field(fieldName, fieldType), fieldMatch) <- nodeSchema.fields.zip(fields) )
{
runningScope = apply(fieldMatch, fieldType, schema, runningScope)
runningScope = apply(fieldMatch, target, fieldType, schema, runningScope)
}
runningScope
runningScope ++ target.map { _ -> Type.Node(label) }.toMap
}
}
def escalatesTo(source: Type, target: Type, schema: Definition): Boolean =
{
(source, target) match {
case (a, b) if a == b => true
case (Type.Node(label), Type.AST(family)) =>
schema.nodes.get(family)
.map { _ contains label }
.getOrElse { false }
case (Type.Native(_), Type.Native(_)) => true // don't try to encode native inheritance
case (Type.Any, _) => true
case (_, _) => false
}
}
}

View file

@ -0,0 +1,33 @@
@import com.astraldb.spec.Definition
@(ctx:Definition)
object Optimizer
{
  // One entry per rule of the definition, in definition order.
  val rules = Seq[Rule[LogicalPlan]](
    @for(rule <- ctx.rules){
      @rule.safeLabel,
    }
  )

  // Upper bound on the number of passes over the rule list.
  def MAX_ITERATIONS = 100

  // Apply every rule in order, repeatedly, until a full pass leaves the
  // plan unchanged or MAX_ITERATIONS passes have run.
  def rewrite(plan: LogicalPlan): LogicalPlan =
  {
    var current = plan
    for(i <- 0 until MAX_ITERATIONS)
    {
      // Remember the plan as it was before this pass.
      val last = current
      for(rule <- rules)
      {
        current = rule(current)
      }
      if(last.fastEquals(current))
      {
        return current
      }
    }
    return current
  }
}

View file

@ -0,0 +1,34 @@
@import com.astraldb.spec.Rule
@import com.astraldb.spec.Type
@import com.astraldb.spec.Definition
@import com.astraldb.codegen.Match
@import com.astraldb.codegen.Expression
@import com.astraldb.codegen.Code
@import com.astraldb.typecheck.TypecheckMatch
@(schema: Definition, rule: Rule)
@*
 * Emits one object per rule, named after the rule's safeLabel.  The body of
 * its apply method is the code generated by Match for the rule's pattern
 * against the symbol plan: on success the code generated by Expression for
 * the rule's rewrite (with the scope computed by TypecheckMatch, starting
 * from the definition's globals), on failure the unchanged plan.
 *@
object @{rule.safeLabel} extends Rule[LogicalPlan]
{
def apply(plan: LogicalPlan): LogicalPlan =
{
@Match(
schema = schema,
pattern = rule.pattern,
target = Code.Literal("plan"),
targetType = Type.AST(rule.family),
onSuccess = Expression(schema, rule.rewrite,
TypecheckMatch(
rule.pattern,
Some("plan"),
Type.AST(rule.family),
schema,
schema.globals
)
),
onFail = Code.Literal("plan"),
name = Some("plan"),
scope = schema.globals
).toString(4).stripPrefix(" ")
}
}

View file

@ -1,22 +1,40 @@
import mill._
import mill.scalalib._
import mill.scalalib.publish._
import $ivy.`com.lihaoyi::mill-contrib-twirllib:`, mill.twirllib._
object astral extends Module {
def scalaVersion = "3.2.1"
object astral extends Module
{
def scalaVersion = "3.3.0"
object compiler extends ScalaModule with PublishModule {
def compile = compiler.compile
object compiler extends ScalaModule
with PublishModule
with TwirlModule
{
val VERSION = "0.0.1-SNAPSHOT"
def scalaVersion = astral.scalaVersion
def twirlScalaVersion = scalaVersion
def twirlVersion = "1.6.0-RC4"
def mainClass = Some("com.astraldb.Astral")
def generatedSources = T{ Seq(compileTwirl().classes) }
/*************************************************
*** Twirl Config
*************************************************/
def twirlFormats = super.twirlFormats() ++ Map(
"scala" -> "play.twirl.api.TxtFormat"
)
/*************************************************
*** Backend Dependencies
*************************************************/
// def ivyDeps = Agg(
// )
def ivyDeps = Agg(
ivy"com.typesafe.play::twirl-api::${twirlVersion()}"
)
def publishVersion = VERSION
override def pomSettings = PomSettings(
@ -35,9 +53,53 @@ object astral extends Module {
{
def scalaVersion = astral.scalaVersion
def mainClass = Some("com.astraldb.catalyst.Astral")
def mainClass = Some("com.astraldb.catalyst.Generate")
def moduleDeps = Seq(astral.compiler)
def classPath = T{ Seq[PathRef](compile().classes) ++ resources() }
def rendered = T {
val target = T.dest
val files = scala.collection.mutable.Buffer[String]()
os.proc(
"scala",
"-cp",
classPath().mkString(":"),
mainClass().get,
target
).call(
cwd = target,
stdout = os.ProcessOutput.Readlines(
line => files += line
)
)
for(f <- files)
{
println(s"GOT : $f")
}
/* return */
files.map {
file => PathRef(target / file)
}.toSeq
}
def render(args: String*) = T.command {
println(rendered())
}
object impl extends ScalaModule
{
def scalaVersion = "2.13.8"
def generatedSources = T{ astral.catalyst.rendered() }
def ivyDeps = Agg(
ivy"org.apache.spark::spark-sql::3.4.1",
)
}
}
}