diff --git a/TODOs.md b/TODOs.md
index 913a6d0..ae9954a 100644
--- a/TODOs.md
+++ b/TODOs.md
@@ -54,7 +54,7 @@
   - [ ] CombineConcats,
   - [ ] PushdownPredicatesAndPruneColumnsForCTEDef
 - [ ] Codegen: Output a Spark-like optimizer based on the above rules
-  - [ ] Logic generation
+  - [x] Logic generation
   - [ ] Compilation pipeline to streamline testing
 - [ ] Test: Do these rules generate the same results as spark?
 - [ ] Apply the rule merge optimization
diff --git a/astral/catalyst/impl/src/com/astraldb/catalyst/OptimizerTest.scala b/astral/catalyst/impl/src/com/astraldb/catalyst/OptimizerTest.scala
new file mode 100644
index 0000000..ecc8d75
--- /dev/null
+++ b/astral/catalyst/impl/src/com/astraldb/catalyst/OptimizerTest.scala
@@ -0,0 +1,9 @@
+package com.astraldb.catalyst
+
+object OptimizerTest
+{
+  def main(args: Array[String]): Unit =
+  {
+    println("Hello world")
+  }
+}
\ No newline at end of file
diff --git a/astral/catalyst/src/com/astraldb/catalyst/Astral.scala b/astral/catalyst/src/com/astraldb/catalyst/Astral.scala
index 4dbe3e4..d734f54 100644
--- a/astral/catalyst/src/com/astraldb/catalyst/Astral.scala
+++ b/astral/catalyst/src/com/astraldb/catalyst/Astral.scala
@@ -1,5 +1,7 @@
 package com.astraldb.catalyst
 
+import com.astraldb.codegen.Render
+
 object Astral
 {
   def main(args: Array[String]): Unit =
@@ -12,6 +14,7 @@ object Astral
       System.exit(-1)
     }
 
-    println(Catalyst.definition)
+    // println(Catalyst.definition)
+    Render.print(Catalyst.definition)
   }
 }
\ No newline at end of file
diff --git a/astral/catalyst/src/com/astraldb/catalyst/Catalyst.scala b/astral/catalyst/src/com/astraldb/catalyst/Catalyst.scala
index 8337659..0d2edd0 100644
--- a/astral/catalyst/src/com/astraldb/catalyst/Catalyst.scala
+++ b/astral/catalyst/src/com/astraldb/catalyst/Catalyst.scala
@@ -30,6 +30,52 @@ object Catalyst extends HardcodedDefinition
 
   //////////////////////////////////////////////////////
 
+  Function("PushProjectionThroughUnion.projectListIsDeterministic",
+    Type.Bool,
+  )(
+    Type.Any
+  )
+
+  Function("PushProjectionThroughUnion.unionHasChildren",
+    Type.Bool,
+  )(
+    Type.Any
+  )
+
+  Function("PushProjectionThroughUnion.pushProjectionThroughUnion",
+    Type.Array(Type.AST("LogicalPlan"))
+  )(
+    Type.Any,
+    Type.Any
+  )
+
+  Function("ExtractFiltersAndInnerJoins",
+    Type.Struct(Map(
+      "_1" -> Type.Struct(Map("size" -> Type.Int)),
+      "_2" -> Type.Struct(Map("nonEmpty" -> Type.Bool)),
+    ))
+  )(
+    Type.AST("LogicalPlan")
+  )
+
+  Function("ReorderJoin.rewrite", Type.Any)(
+    Type.AST("LogicalPlan")
+  )
+
+  Function("EliminateOuterJoin.buildNewJoinType",
+    Type.Union(Seq(
+      Type.Native("RightOuter"),
+      Type.Native("LeftOuter"),
+      Type.Native("FullOuter"),
+    ))
+  )(
+    Type.AST("LogicalPlan")
+  )
+
+  Global("JoinHint.NONE", Type.Native("JoinHint"))
+
+  //////////////////////////////////////////////////////
+
   Rule("PushProjectionThroughUnion", "LogicalPlan")(
     Match("Project")(
       Bind("projectList", MatchAny),
diff --git a/astral/catalyst/src/com/astraldb/catalyst/Generate.scala b/astral/catalyst/src/com/astraldb/catalyst/Generate.scala
new file mode 100644
index 0000000..adb7c37
--- /dev/null
+++ b/astral/catalyst/src/com/astraldb/catalyst/Generate.scala
@@ -0,0 +1,9 @@
+package com.astraldb.catalyst
+
+object Generate
+{
+  def main(args: Array[String]): Unit =
+  {
+    println("Optimizer.sql")
+  }
+}
\ No newline at end of file
diff --git a/astral/compiler/src/com/astraldb/codegen/Code.scala b/astral/compiler/src/com/astraldb/codegen/Code.scala
new file mode 100644
index 0000000..60a3c97
--- /dev/null
+++ b/astral/compiler/src/com/astraldb/codegen/Code.scala
@@ -0,0 +1,276 @@
+package com.astraldb.codegen
+
+
+sealed trait Code
+{
+  def render(indent: Int): Code.PaddedString|Code.Lines
+
+  def renderLines(indent: Int): Code.Lines =
+    render(indent) match {
+      case x:Code.PaddedString =>
+        Code.Lines.ofLine(indent, x)
+      case x:Code.Lines => x
+    }
+
+  def block =
+    this match {
+      case Code.Block(lines) => lines
+      case _ => Seq(this)
+    }
+
+  override def toString(): String =
+    renderLines(0).body.mkString("\n")
+  def toString(indent: Int): String =
+    renderLines(indent).body.mkString("\n")
+}
+
+
+object Code
+{
+  class PaddedString(val body: String, val lPad: Int = 0, val rPad: Int = 0)
+
+  object PaddedString
+  {
+    def apply(elems: (String|PaddedString)*): PaddedString =
+    {
+      if(elems.isEmpty){ new PaddedString("") }
+      else {
+        val (buffer, lPad, rPad) =
+          elems.head match {
+            case s: String =>
+              (StringBuffer(s), 0, 0)
+            case p: PaddedString =>
+              (StringBuffer(p.body), p.lPad, p.rPad)
+          }
+        var pad = rPad
+        for(e <- elems.tail)
+        {
+          if(pad > 0){ buffer.append(" " * pad) }
+          e match {
+            case s:String =>
+              buffer.append(s)
+              pad = 0
+            case p:PaddedString =>
+              buffer.append(" "*Math.max(0, p.lPad - pad))
+              buffer.append(p.body)
+              pad = p.rPad
+          }
+        }
+        new PaddedString(buffer.toString, lPad, pad)
+      }
+    }
+
+    def leftPad(pad: Int, body: String): PaddedString =
+      new PaddedString(body, lPad = pad)
+    def rightPad(pad: Int, body: String): PaddedString =
+      new PaddedString(body, rPad = pad)
+    def pad(pad: Int, body: String): PaddedString =
+      new PaddedString(body, lPad = pad, rPad = pad)
+  }
+
+
+  case class Lines(body: Seq[String])
+  {
+    def ++(other: Lines): Lines =
+      new Lines(body ++ other.body)
+  }
+  object Lines
+  {
+    def ofLine(indent: Int, elems: (String|PaddedString)*): Lines =
+    {
+      val p = PaddedString(elems:_*)
+      new Lines(Seq(
+        " " * Math.max(indent, p.lPad) + p.body
+      ))
+    }
+    def ofLines(elems: (String|Lines)*): Lines =
+      new Lines(
+        elems.flatMap {
+          case s: String => Seq(s)
+          case Lines(l) => l
+        }
+      )
+    val empty = new Lines(Seq.empty)
+    def flatten(elems: Seq[Lines]): Lines =
+      elems.foldLeft(empty){ _ ++ _ }
+
+    def apply(body: Seq[String]): Lines =
+      new Lines(body)
+  }
+
+  case object Empty extends Code
+  {
+    def render(indent: Int): PaddedString|Lines =
+      PaddedString("")
+  }
+
+  case class Block(
+    lines: Seq[Code],
+  ) extends Code
+  {
+    def render(indent: Int): PaddedString|Lines =
+      Lines.flatten(
+        lines.map { _.renderLines(indent) }
+      )
+  }
+
+  case class Literal(txt: String) extends Code
+  {
+    def render(indent: Int): PaddedString|Lines =
+      PaddedString(txt)
+  }
+
+  case class BinOp(left: Code, op: String|PaddedString, right: Code) extends Code
+  {
+    def render(indent: Int): PaddedString|Lines =
+      (left.render(indent), right.render(indent)) match {
+        case (l:PaddedString, r:PaddedString) =>
+          PaddedString(l, op, r)
+        case (l:PaddedString, r:Lines) =>
+          Lines.ofLine(indent, l, op) ++ r
+        case (l:Lines, r:PaddedString) =>
+          l ++ Lines.ofLine(indent, op, r)
+        case (l:Lines, r:Lines) =>
+          l ++ Lines.ofLine(indent+2, op) ++ r
+      }
+  }
+
+  case class Parens(left: String|PaddedString, body: Code, right: String|PaddedString) extends Code
+  {
+    def render(indent: Int): PaddedString|Lines =
+      body.render(indent+2) match {
+        case b:PaddedString =>
+          PaddedString(left, b, right)
+        case b:Lines =>
+          Lines.ofLine(indent, left) ++
+          b ++
+          Lines.ofLine(indent, right)
+      }
+  }
+
blocks + */ + case class PrefixedBlock(prefix: Code, lParen: String|PaddedString, rParen: String|PaddedString, body: Code) extends Code + { + def render(indent: Int): PaddedString|Lines = + body.render(indent+2) match { + case b:PaddedString => prefix.render(indent) match { + case p:PaddedString => + PaddedString(p, lParen, b, rParen) + case p:Lines => + p ++ Lines.ofLine(indent, lParen, b, rParen) + } + case b:Lines => + prefix.renderLines(indent) ++ + Lines.ofLine(indent, lParen) ++ + b ++ + Lines.ofLine(indent, rParen) + } + } + + /** + * Two "block"s with a prefix and a joiner + * + * Used to represent if-then-else + */ + case class PrefixedBlockPair(prefix: Code, lParen: String|PaddedString, joiner: String|PaddedString, rParen: String|PaddedString, lBody: Code, rBody: Code) extends Code + { + def render(indent: Int): PaddedString|Lines = + { + val p = prefix.render(indent) + val l = lBody.render(indent+2) + val r = rBody.render(indent+2) + + val leftNeedsTrailingRParen = + p.isInstanceOf[Lines] || l.isInstanceOf[Lines] + + (p match { + case pp: PaddedString => + l match { + case lp: PaddedString => + Lines.ofLine(indent, pp, lParen, lp, rParen) + case ll: Lines => + Lines.ofLine(indent, pp, lParen) ++ + ll + } + case pl: Lines => + l match { + case lp: PaddedString => + pl ++ + Lines.ofLine(indent, lParen) ++ + Lines.ofLine(indent, lp) + case ll: Lines => + pl ++ + Lines.ofLine(indent, lParen) ++ + ll + } + }) ++ ((r, leftNeedsTrailingRParen) match { + case (rp: PaddedString, false) => + Lines.ofLine(indent, joiner, lParen, rp, rParen) + case (rp: PaddedString, true) => + Lines.ofLine(indent, rParen, joiner, lParen, rp, rParen) + case (rp: Lines, false) => + Lines.ofLine(indent, joiner, lParen) ++ + rp ++ + Lines.ofLine(indent, rParen) + case (rp: Lines, true) => + Lines.ofLine(indent, rParen, joiner, lParen) ++ + rp ++ + Lines.ofLine(indent, rParen) + }) + } + } + + def IfThenElse(condition: Code, thenBlock: Code, elseBlock: Code): Code = + Code.PrefixedBlockPair( + prefix = Code.Parens("if(",condition,")"), + lParen = PaddedString.rightPad(1, "{"), + rParen = PaddedString.leftPad(1, "}"), + joiner = PaddedString.pad(1, "else"), + lBody = thenBlock, + rBody = elseBlock + ) + + case class List(sep: String|PaddedString, elems: Seq[Code]) extends Code + { + def render(indent: Int): PaddedString|Lines = + { + val renderedElems = elems.map { _.render(indent+2) } + val multiLine = + renderedElems.exists { _.isInstanceOf[Lines] } + || renderedElems.map { _.asInstanceOf[PaddedString].body.size }.sum > 100 + if(multiLine){ + Lines( + renderedElems.flatMap { + case p:PaddedString => + Lines.ofLine(indent, p, sep).body + case Lines(Seq()) => + Lines.ofLine(indent, sep).body + case Lines(Seq(l)) => + Seq(PaddedString(l, sep).body) + case Lines(l) => + Seq(l.head) ++ + l.tail.dropRight(1).map { " " + _ } ++ + Seq(PaddedString(" ",l.last,sep).body) + } + ) + } else { + renderedElems.map { _.asInstanceOf[PaddedString] } match { + case Seq() => PaddedString("") + case Seq(a) => a + case a => + PaddedString( ( + Seq(a.head) ++ a.tail.flatMap { Seq(sep, _) } + ).toSeq:_*) + } + } + } + } + + def Parenthesize(body: Code): Code = + Parens("(", body, ")") +} \ No newline at end of file diff --git a/astral/compiler/src/com/astraldb/codegen/Expression.scala b/astral/compiler/src/com/astraldb/codegen/Expression.scala new file mode 100644 index 0000000..69f269c --- /dev/null +++ b/astral/compiler/src/com/astraldb/codegen/Expression.scala @@ -0,0 +1,64 @@ +package com.astraldb.codegen + +import 
+
+  case class List(sep: String|PaddedString, elems: Seq[Code]) extends Code
+  {
+    def render(indent: Int): PaddedString|Lines =
+    {
+      val renderedElems = elems.map { _.render(indent+2) }
+      val multiLine =
+        renderedElems.exists { _.isInstanceOf[Lines] }
+        || renderedElems.map { _.asInstanceOf[PaddedString].body.size }.sum > 100
+      if(multiLine){
+        Lines(
+          renderedElems.flatMap {
+            case p:PaddedString =>
+              Lines.ofLine(indent, p, sep).body
+            case Lines(Seq()) =>
+              Lines.ofLine(indent, sep).body
+            case Lines(Seq(l)) =>
+              Seq(PaddedString(l, sep).body)
+            case Lines(l) =>
+              Seq(l.head) ++
+              l.tail.dropRight(1).map { " " + _ } ++
+              Seq(PaddedString(" ",l.last,sep).body)
+          }
+        )
+      } else {
+        renderedElems.map { _.asInstanceOf[PaddedString] } match {
+          case Seq() => PaddedString("")
+          case Seq(a) => a
+          case a =>
+            PaddedString( (
+              Seq(a.head) ++ a.tail.flatMap { Seq(sep, _) }
+            ).toSeq:_*)
+        }
+      }
+    }
+  }
+
+  def Parenthesize(body: Code): Code =
+    Parens("(", body, ")")
+}
\ No newline at end of file
diff --git a/astral/compiler/src/com/astraldb/codegen/Expression.scala b/astral/compiler/src/com/astraldb/codegen/Expression.scala
new file mode 100644
index 0000000..69f269c
--- /dev/null
+++ b/astral/compiler/src/com/astraldb/codegen/Expression.scala
@@ -0,0 +1,64 @@
+package com.astraldb.codegen
+
+import com.astraldb.spec
+import com.astraldb.expression._
+import com.astraldb.expression.{ Expression => Expr}
+import com.astraldb.codegen.Code.PaddedString
+import com.astraldb.typecheck.TypecheckExpression
+import com.astraldb.spec.Type
+
+object Expression
+{
+  def apply(schema: spec.Definition, op: Expr, scope: Map[String, spec.Type]): Code =
+  {
+    op match {
+      case c:SimpleConstant => Code.Literal(c.asScala)
+      case Var(v) => Code.Literal(v)
+      case Arith(t, lhs, rhs) =>
+        Code.BinOp(
+          Code.Parenthesize(apply(schema, lhs, scope)),
+          PaddedString.pad(1, ArithTypes.opString(t)),
+          Code.Parenthesize(apply(schema, rhs, scope))
+        )
+      case Cmp(t, lhs, rhs) =>
+        Code.BinOp(
+          Code.Parenthesize(apply(schema, lhs, scope)),
+          PaddedString.pad(1, CmpTypes.opString(t)),
+          Code.Parenthesize(apply(schema, rhs, scope))
+        )
+      case FunctionCall(fn, args) =>
+        Code.Literal(op.toString)
+        Code.PrefixedBlock(
+          prefix = Code.Literal(fn),
+          lParen = "(",
+          rParen = ")",
+          body = Code.List(
+            sep = PaddedString.rightPad(1, ","),
+            elems = args.map { apply(schema, _, scope) }
+          )
+        )
+      case MakeNode(node, fields) =>
+        apply(schema, FunctionCall(node, fields), scope)
+      case NodeSubscript(target, index) =>
+        Code.BinOp(
+          apply(schema, target, scope),
+          ".",
+          TypecheckExpression(target, schema, scope) match {
+            case Type.Node(nodeType) =>
+              Code.Literal(
+                schema.nodesByName(nodeType).fields(index).name
+              )
+            case c =>
+              assert(false, s"Node subscript on something not a node: $op (in ${scope}): ${c.getClass.getSimpleName}")
+          }
+        )
+
+      case StructSubscript(target, name) =>
+        Code.BinOp(
+          apply(schema, target, scope),
+          ".",
+          Code.Literal(name)
+        )
+    }
+  }
+}
\ No newline at end of file
diff --git a/astral/compiler/src/com/astraldb/codegen/Match.scala b/astral/compiler/src/com/astraldb/codegen/Match.scala
new file mode 100644
index 0000000..7882926
--- /dev/null
+++ b/astral/compiler/src/com/astraldb/codegen/Match.scala
@@ -0,0 +1,140 @@
+package com.astraldb.codegen
+
+import com.astraldb.spec
+
+import spec.Match._
+import com.astraldb.codegen.Code.PaddedString
+import com.astraldb.typecheck.TypecheckMatch
+import com.astraldb.spec.Type
+
+object Match
+{
+  def apply(
+    schema: spec.Definition,
+    pattern: spec.Match,
+    target: Code,
+    targetType: spec.Type,
+    onSuccess: Code,
+    onFail: Code,
+    name: Option[String],
+    scope: Map[String, spec.Type]
+  ): Code =
+  {
+    pattern match {
+      case And(Seq()) => onSuccess
+      case And(Seq(a)) => apply(schema, a, target, targetType, onSuccess, onFail, name, scope)
+      case And(a) =>
+        apply(schema, a.head,
+          target = target,
+          targetType = targetType,
+          onSuccess =
+            apply(schema, And(a.tail), target, targetType, onSuccess, onFail, name,
+              scope = TypecheckMatch(a.head, name, targetType, schema, scope)
+            ),
+          onFail = onFail,
+          name = name,
+          scope = scope
+        )
+      case Not(a) => apply(schema, a, target, targetType, onFail, onSuccess, name, scope)
+      case Or(Seq()) => onFail
+      case Or(Seq(a)) =>
+        apply(schema, a, target, targetType, onSuccess, onFail, name, scope)
+      case Or(a) =>
+        apply(schema, a.head,
+          target = target,
+          targetType = targetType,
+          onSuccess = onSuccess,
+          onFail = apply(schema, Or(a.tail), target, targetType, onSuccess, onFail, name, scope),
+          name = name,
+          scope = scope
+        )
+      case Bind(symbol, pattern) =>
+        Code.Block(
+          Code.BinOp(
+            Code.Literal(s"val $symbol"),
+            Code.PaddedString.pad(1, "="),
+            target
+          ) +:
+          apply(
+            schema = schema,
+            pattern = pattern,
+            target = Code.Literal(symbol),
+            targetType = targetType,
+            onSuccess = onSuccess,
+            onFail = onFail,
+            name = Some(symbol),
+            scope = scope ++ Map(symbol -> targetType)
+          ).block
+        )
+      case BindExpression(symbol, op) =>
+        Code.Block(
+          Seq(
+            Code.BinOp(
+              Code.Literal(s"val $symbol"),
+              Code.PaddedString.pad(1, "="),
+              Expression(schema, op, scope)
+            )
+          ) ++ onSuccess.block
+        )
+      case Node(nodeLabel, children) =>
+      {
+        val selectedName =
+          name.getOrElse { "genericNode" }
+        val node = schema.nodesByName(nodeLabel)
+        Code.IfThenElse(
+          condition =
+            Code.BinOp(
+              target,
+              ".",
+              Code.Literal(s"isInstanceOf[$nodeLabel]")
+            ),
+          thenBlock =
+            Code.Block(
+              Seq(
+                Code.BinOp(
+                  Code.Literal(s"val $selectedName"),
+                  Code.PaddedString.pad(1, "="),
+                  Code.BinOp(
+                    target,
+                    ".",
+                    Code.Literal(s"asInstanceOf[$nodeLabel]")
+                  )
+                )
+              ) ++
+              children
+                .zip(node.fields)
+                .foldRight(onSuccess) { case ((child, field), andThen) =>
+                  apply(
+                    schema = schema,
+                    pattern = child,
+                    target = Code.Literal(s"$selectedName.${field.name}"),
+                    targetType = Type.Node(nodeLabel),
+                    onSuccess = andThen,
+                    onFail = onFail,
+                    name = Some(s"${selectedName}_${field.name}"),
+                    scope = scope ++ Map(selectedName -> Type.Node(nodeLabel))
+                  )
+                }
+                .block
+            ),
+          elseBlock = onFail
+        )
+      }
+      case Test(op) =>
+        Code.IfThenElse(
+          condition = Expression(schema, op, scope),
+          thenBlock = onSuccess,
+          elseBlock = onFail
+        )
+      case Any =>
+        onSuccess
+      case OfType(nodeType) =>
+        Code.IfThenElse(
+          condition = Code.BinOp(target, ".", Code.Literal(s"isInstanceOf[${nodeType.scalaType}]")),
+          thenBlock = onSuccess,
+          elseBlock = onFail
+        )
+      case _ => Code.Literal(s"??${pattern.getClass.getSimpleName}??")
+    }
+  }
+}
\ No newline at end of file
diff --git a/astral/compiler/src/com/astraldb/codegen/Optimizer.scala b/astral/compiler/src/com/astraldb/codegen/Optimizer.scala
new file mode 100644
index 0000000..1495cd1
--- /dev/null
+++ b/astral/compiler/src/com/astraldb/codegen/Optimizer.scala
@@ -0,0 +1,11 @@
+package com.astraldb.codegen
+
+import com.astraldb.spec.Definition
+
+object Optimizer
+{
+  def apply(definition: Definition): String =
+  {
+    scala.Optimizer(definition).toString
+  }
+}
\ No newline at end of file
diff --git a/astral/compiler/src/com/astraldb/codegen/Render.scala b/astral/compiler/src/com/astraldb/codegen/Render.scala
new file mode 100644
index 0000000..e9e1a35
--- /dev/null
+++ b/astral/compiler/src/com/astraldb/codegen/Render.scala
@@ -0,0 +1,25 @@
+package com.astraldb.codegen
+
+import com.astraldb.spec.Definition
+
+object Render
+{
+  def apply(schema: Definition): Map[String, String] =
+  {
+    Map(
+      "Optimizer.scala" -> Optimizer(schema)
+    ) ++ schema.rules.map { rule =>
+      s"${rule.safeLabel}.scala" -> Rule(schema, rule)
+    }.toMap
+  }
+
+  def print(schema: Definition): Unit =
+  {
+
+    for( (file,content) <- apply(schema) )
+    {
+      println(s"//////////////////// $file")
+      println(content)
+    }
+  }
+}
\ No newline at end of file
diff --git a/astral/compiler/src/com/astraldb/codegen/Rule.scala b/astral/compiler/src/com/astraldb/codegen/Rule.scala
new file mode 100644
index 0000000..92c63cd
--- /dev/null
+++ b/astral/compiler/src/com/astraldb/codegen/Rule.scala
@@ -0,0 +1,11 @@
+package com.astraldb.codegen
+
+import com.astraldb.spec
+
+object Rule
+{
+  def apply(schema: spec.Definition, rule: spec.Rule): String =
+  {
+    scala.Rule(schema, rule).toString
+  }
+}
\ No newline at end of file
diff --git a/astral/compiler/src/com/astraldb/expression/Eval.scala b/astral/compiler/src/com/astraldb/expression/Eval.scala
index 9f5029c..18fc803 100644
--- a/astral/compiler/src/com/astraldb/expression/Eval.scala
+++ b/astral/compiler/src/com/astraldb/expression/Eval.scala
@@ -9,6 +9,7 @@ object Eval
   class TypeMismatch(problem: Expression, expr: Expression, lhsType: Type, rhsType: Type) extends EvalException(expr)
   class TypeError(problem: Expression, expr: Expression, problemType: Type) extends EvalException(expr)
   class UnsupportedExternalFeature(problem: Expression, expr: Expression) extends EvalException(expr)
+  class MissingFieldError(expr: Expression, field: String) extends EvalException(expr)
 
   def apply(base: Expression, scope: Map[String, Constant] = Map.empty): Constant =
   {
@@ -117,12 +118,24 @@
             elements(index)
           case c => throw new TypeError(e, base, c.t)
         }
+      case StructSubscript(target, name) =>
+        rcr(target, scope) match {
+          case StructConstant(elements) =>
+            elements.find { _._1 == name }
+              .getOrElse {
+                throw new MissingFieldError(e, name)
+              }
+              ._2
+          case c => throw new TypeError(e, base, c.t)
+        }
      case FunctionCall(name, args) =>
        throw new UnsupportedExternalFeature(e, base)
      case MakeNode(nodeType, fields) =>
        NodeConstant(nodeType, fields.map { rcr(_, scope) })
      case MakeArray(elements) =>
        ArrayConstant(elements.map { rcr(_, scope) })
+      case MakeStruct(elements) =>
+        StructConstant(elements.map { case (f, e) => f -> rcr(e, scope) })
      case Let(symbol, value, rest) =>
        rcr(rest, scope ++ Map(symbol -> rcr(value, scope)))
    }
diff --git a/astral/compiler/src/com/astraldb/expression/Expression.scala b/astral/compiler/src/com/astraldb/expression/Expression.scala
index a5eab63..0531a9e 100755
--- a/astral/compiler/src/com/astraldb/expression/Expression.scala
+++ b/astral/compiler/src/com/astraldb/expression/Expression.scala
@@ -109,6 +109,7 @@ sealed trait SimpleConstant extends Constant
 {
   def fields: Seq[Constant] = Seq.empty
   def reassembleFields(in: Seq[Constant]): Expression = this
+  def asScala: String = toString
 }
 /**
  * A constant with components (Nodes, Arrays)
diff --git a/astral/compiler/src/com/astraldb/spec/Definition.scala b/astral/compiler/src/com/astraldb/spec/Definition.scala
index 731bb25..77d12b5 100755
--- a/astral/compiler/src/com/astraldb/spec/Definition.scala
+++ b/astral/compiler/src/com/astraldb/spec/Definition.scala
@@ -5,7 +5,8 @@ import com.astraldb.expression._
 case class Definition(
   nodes:Map[String, Map[String, Node]],
-  rules:Seq[Rule]
+  rules:Seq[Rule],
+  globals: Map[String, Type],
 )
 {
   override def toString =
     s"""/////// ASTs //////
@@ -18,6 +19,8 @@ case class Definition(
     nodes.flatMap { case (family, elements) =>
       elements.keys.map { _ -> family } }
       .toMap
+  val nodesByName: Map[String, Node] =
+    nodes.values.flatten.toMap
 
   def validate() =
   {
@@ -30,11 +33,13 @@ class HardcodedDefinition
   lazy val definition: Definition =
     Definition(
       nodes = nodes.mapValues { _.map { n => n.name -> n }.toMap }.toMap,
-      rules = rules.toSeq
+      rules = rules.toSeq,
+      globals = globals.toMap
     )
 
   val nodes = mutable.Map[String, mutable.Buffer[Node]]()
   val rules = mutable.Buffer[Rule]()
+  val globals = mutable.Map[String,Type]()
 
   import FieldConversions._
 
@@ -51,6 +56,16 @@ class HardcodedDefinition
       com.astraldb.spec.Rule(label, family, pattern, replacement)
     )
 
+  def Function(label: String, ret: Type = Type.Unit)(args: Type*): Unit =
+  {
+    globals(label) = Type.Function(args,ret)
+  }
+
+  def Global(label: String, t: Type): Unit =
+  {
+    globals(label) = t
+  }
+
   //////////////////////// Matchers
   def Match(node: String)(fields: Match*): Match =
     com.astraldb.spec.Match.Node(node, fields)
diff --git a/astral/compiler/src/com/astraldb/spec/Match.scala b/astral/compiler/src/com/astraldb/spec/Match.scala
index 30e3294..056d094 100644
--- a/astral/compiler/src/com/astraldb/spec/Match.scala
+++ b/astral/compiler/src/com/astraldb/spec/Match.scala
@@ -1,6 +1,5 @@
 package com.astraldb.spec
 
-import com.astraldb.ast._
 import com.astraldb.spec.Match.Scope
 import com.astraldb.expression._
 
diff --git a/astral/compiler/src/com/astraldb/spec/Rule.scala b/astral/compiler/src/com/astraldb/spec/Rule.scala
index 0217ccc..e9fe6e1 100644
--- a/astral/compiler/src/com/astraldb/spec/Rule.scala
+++ b/astral/compiler/src/com/astraldb/spec/Rule.scala
@@ -17,7 +17,11 @@ case class Rule(
   def validate(schema: Definition): Unit =
   {
     val scope =
-      TypecheckMatch(pattern, Type.AST(family), schema, Map.empty)
+      TypecheckMatch(pattern, None, Type.AST(family), schema, schema.globals)
     TypecheckExpression(rewrite, schema, scope)
   }
+
+  def safeLabel =
+    label.replaceAll("[^A-Za-z0-9]+", "_")
+         .replaceAll("^([0-9])", "_$1")
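+  // e.g. a hypothetical label like "Push Projection (v2)" becomes
+  // "Push_Projection_v2_", and a label starting with a digit is prefixed
+  // with an underscore so the result stays a legal Scala identifier.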
 }
\ No newline at end of file
diff --git a/astral/compiler/src/com/astraldb/spec/Type.scala b/astral/compiler/src/com/astraldb/spec/Type.scala
index 1fb2b3d..b3f8c49 100755
--- a/astral/compiler/src/com/astraldb/spec/Type.scala
+++ b/astral/compiler/src/com/astraldb/spec/Type.scala
@@ -8,6 +8,8 @@ sealed trait Type
     case Type.Union(of) => of
     case _ => Seq(this)
   }
+
+  def scalaType: String
 }
 
 object Type
@@ -15,42 +17,52 @@ object Type
   case class Native(name: String) extends Type
   {
     override def toString: String = s"Native[${name}]"
+    def scalaType: String = name
   }
   case class AST(family: String) extends Type
   {
     override def toString: String = s"Ast[${family}]"
+    def scalaType: String = family
   }
   case class Node(nodeType: String) extends Type
  {
    override def toString: String = s"Node[${nodeType}]"
+    def scalaType: String = nodeType
  }
  case class Struct(fields: Map[String, Type]) extends Type
  {
    override def toString: String = s"Struct {${fields.map { x => s"${x._1}: ${x._2}"}.mkString(", ")}}"
+    def scalaType: String = ???
  }
  case class Option(base: Type) extends Type
  {
    override def toString: String = s"Option[$base]"
+    def scalaType: String = s"Option[${base.scalaType}]"
  }
  case class Array(base: Type) extends Type
  {
    override def toString: String = s"Array[${base}]"
+    def scalaType: String = s"Array[${base.scalaType}]"
  }
  case class Function(args: Seq[Type], ret: Type) extends Type
  {
    override def toString: String = s"(${args.mkString(", ")}) => $ret"
+    def scalaType: String = s"(${args.map { _.scalaType }.mkString(", ")}) => ${ret.scalaType}"
  }
  case object Unit extends Type
  {
    override def toString: String = "Unit"
+    def scalaType: String = "Unit"
  }
  case object Any extends Type
  {
    override def toString: String = "Any"
+    def scalaType: String = "Any"
  }
  case class Union(of: Seq[Type]) extends Type
  {
    override def toString: String = s"Union(${of.mkString(" | ")})"
+    def scalaType: String = of.map { _.scalaType }.mkString("|")
  }
 
  sealed trait PrimType extends Type
@@ -58,18 +70,22 @@ object Type
   case object Int extends PrimType
   {
     override def toString: String = "int"
+    def scalaType: String = "Int"
   }
   case object Float extends PrimType
   {
     override def toString: String = "float"
+    def scalaType: String = "Double"
   }
   case object Bool extends PrimType
   {
     override def toString: String = "bool"
+    def scalaType: String = "Boolean"
   }
   case object String extends PrimType
   {
     override def toString: String = "string"
+    def scalaType: String = "String"
   }
 
   def union(ts: Seq[Type]) =
diff --git a/astral/compiler/src/com/astraldb/typecheck/Typecheck.scala b/astral/compiler/src/com/astraldb/typecheck/Typecheck.scala
new file mode 100644
index 0000000..cf2b628
--- /dev/null
+++ b/astral/compiler/src/com/astraldb/typecheck/Typecheck.scala
@@ -0,0 +1,26 @@
+package com.astraldb.typecheck
+
+import com.astraldb.spec._
+import com.astraldb.expression._
+
+object Typecheck
+{
+
+  def escalatesTo(source: Type, target: Type, schema: Definition): Boolean =
+  {
+    (source, target) match {
+      case (a, b) if a == b => true
+      case (Type.Node(label), Type.AST(family)) =>
+        schema.nodes.get(family)
+          .map { _ contains label }
+          .getOrElse { false }
+      case (Type.Union(elems), a) =>
+        elems.forall { escalatesTo(_, a, schema) }
+      case (a, Type.Union(elems)) =>
+        elems.forall { escalatesTo(a, _, schema) }
+      case (Type.Native(_), Type.Native(_)) => true // don't try to encode native inheritance
+      case (_, Type.Any) => true
+      case (_, _) => false
+    }
+  }
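+
+  // For example, given the Catalyst definition above, a (hypothetical) call
+  //   escalatesTo(Type.Node("Project"), Type.AST("LogicalPlan"), schema)
+  // holds whenever the "LogicalPlan" family declares a "Project" node, and
+  //   escalatesTo(Type.Union(Seq(Type.Int, Type.Bool)), Type.Any, schema)
+  // holds because every branch of the union escalates to Type.Any.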
+}
\ No newline at end of file
diff --git a/astral/compiler/src/com/astraldb/typecheck/TypecheckExpression.scala b/astral/compiler/src/com/astraldb/typecheck/TypecheckExpression.scala
index a64a38e..bb5f114 100644
--- a/astral/compiler/src/com/astraldb/typecheck/TypecheckExpression.scala
+++ b/astral/compiler/src/com/astraldb/typecheck/TypecheckExpression.scala
@@ -5,9 +5,123 @@ import com.astraldb.expression._
 
 object TypecheckExpression
 {
+  class TypecheckException(expr: Expression) extends Exception
+  class TypeMismatch(expr: Expression, a: Type, b: Type) extends TypecheckException(expr)
+
+
   def apply(expr: Expression, schema: Definition, scope: Map[String, Type]): Type =
   {
     // TODO
-    return Type.Any
+    expr match {
+      case c:Constant => c.t
+
+      case Arith(op, lhs, rhs) =>
+        op match {
+          case ArithTypes.Add | ArithTypes.Mul | ArithTypes.Sub | ArithTypes.Div =>
+            (
+              apply(lhs, schema, scope),
+              apply(rhs, schema, scope),
+            ) match {
+              case (Type.Int, Type.Int) => Type.Int
+              case (Type.Float, Type.Float) => Type.Float
+              case (Type.String, Type.String) if op == ArithTypes.Add => Type.String
+              case (a, b) => assert(false, s"Mismatched types $a and $b: $expr")
+            }
+          case ArithTypes.And | ArithTypes.Or =>
+            assert(apply(lhs, schema, scope) == Type.Bool, s"Mismatched types $lhs is not a boolean: $expr")
+            assert(apply(rhs, schema, scope) == Type.Bool, s"Mismatched types $rhs is not a boolean: $expr")
+            Type.Bool
+        }
+
+      case Cmp(op, lhs, rhs) =>
+        op match {
+          case CmpTypes.Eq | CmpTypes.Neq =>
+            val l = apply(lhs, schema, scope)
+            val r = apply(rhs, schema, scope)
+            assert(l == r, s"Comparison between mismatched types: $l and $r: $expr")
+            Type.Bool
+          case _ =>
+            (
+              apply(lhs, schema, scope),
+              apply(rhs, schema, scope),
+            ) match {
+              case (Type.Int, Type.Int) => Type.Bool
+              case (Type.Float, Type.Float) => Type.Bool
+              case (a, b) => assert(false, s"Comparison between mismatched types $a and $b: $expr")
+            }
+        }
+
+      case Var(v) => scope(v)
+
+      case FunctionCall(fn, args) =>
+        scope.get(fn) match {
+          case Some(Type.Function(argTypes, ret)) =>
+            for( (arg, argType) <- args.zip(argTypes) )
+            {
+              val actualType = apply(arg, schema, scope)
+              assert(
+                Typecheck.escalatesTo(actualType, argType, schema),
+                s"Mismatched argument $arg ($actualType doesn't escalate to $argType): $expr"
+              )
+            }
+            ret
+          case Some(c) =>
+            assert(false, s"Trying to call $fn, which is a $c: $expr")
+          case None =>
+            assert(false, s"Trying to call $fn, which is not defined (in ${scope.keys.mkString(", ")}): $expr")
+
+        }
+
+      case MakeNode(nodeType, fields) =>
+        val node = schema.nodesByName(nodeType)
+        for( (fieldConstructor, fieldType) <- fields.zip(node.fields) )
+        {
+          val actualType = apply(fieldConstructor, schema, scope)
+          assert(Typecheck.escalatesTo(actualType, fieldType.t, schema),
+            s"Mismatched node constructor argument $fieldConstructor ($actualType doesn't escalate to ${fieldType.t}): $expr"
+          )
+        }
+        Type.Node(nodeType)
+
+      case NodeSubscript(target, index) =>
+        apply(target, schema, scope) match {
+          case Type.Node(nodeType) =>
+            val node = schema.nodesByName(nodeType)
+            node.fields(index).t
+          case _ =>
+            assert(false, s"Node subscript on something not a node: $expr")
+        }
+
+      case StructSubscript(target, name) =>
+        apply(target, schema, scope) match {
+          case Type.Struct(fields) =>
+            assert(fields contains name,
+              s"Struct subscript on a struct that doesn't contain $name ($fields): $expr"
+            )
+            fields(name)
+          case Type.Node(nodeType) =>
+            val node = schema.nodesByName(nodeType)
+            node.fields.find { _._1 == name }
+              .getOrElse {
+                assert(false,
+                  s"Struct subscript on a node that doesn't contain $name (${node.fields}): $expr"
+                )
+              }
+              ._2
+          case _ =>
+            assert(false, s"Struct subscript on something not a struct: $expr")
+        }
+
+      case FunctionalIfThenElse(condition, ifTrue, ifFalse) =>
+        assert(apply(condition, schema, scope) == Type.Bool,
+          s"If then else with a non-boolean condition ($condition): $expr"
+        )
+        val a = apply(ifTrue, schema, scope)
+        val b = apply(ifFalse, schema, scope)
+        assert(a == b,
+          s"If then else with mismatched return types $a and $b: $expr"
+        )
+        a
+    }
   }
 }
\ No newline at end of file
diff --git a/astral/compiler/src/com/astraldb/typecheck/TypecheckMatch.scala b/astral/compiler/src/com/astraldb/typecheck/TypecheckMatch.scala
index 3c1aae6..19d5a20 100644
--- a/astral/compiler/src/com/astraldb/typecheck/TypecheckMatch.scala
+++ b/astral/compiler/src/com/astraldb/typecheck/TypecheckMatch.scala
@@ -6,12 +6,12 @@ import com.astraldb.expression._
 object TypecheckMatch
 {
   def apply(pattern: Match, family: String, schema: Definition, scope: Map[String, Type]): Map[String, Type] =
-    apply(pattern, Type.AST(family), schema, scope)
+    apply(pattern, None, Type.AST(family), schema, scope)
 
-  def apply(pattern: Match, expectedType: Type, schema: Definition, scope: Map[String, Type]): Map[String, Type] =
+  def apply(pattern: Match, target: Option[String], expectedType: Type, schema: Definition, scope: Map[String, Type]): Map[String, Type] =
   {
     def checkAllChildren() =
-      pattern.children.foreach { apply(_, expectedType, schema, scope) }
+      pattern.children.foreach { apply(_, target, expectedType, schema, scope) }
 
     def astSchema(label: String): Node =
     {
@@ -29,15 +29,15 @@
     }
 
     pattern match {
-      case Match.Not(child) => apply(child, expectedType, schema, scope)
+      case Match.Not(child) => apply(child, target, expectedType, schema, scope)
 
       case Match.And(children) =>
         children.foldLeft(scope) { (scope, child) =>
-          apply(child, expectedType, schema, scope)
+          apply(child, target, expectedType, schema, scope)
         }
 
      case Match.Or(children) =>
-        children.flatMap { apply(_, expectedType, schema, scope).toSeq }
+        children.flatMap { apply(_, target, expectedType, schema, scope).toSeq }
          .groupBy { _._1 }
          .mapValues { childTypes =>
            val base =
@@ -49,13 +49,13 @@
 
      case Match.OfType(t) =>
-        assert(escalatesTo(t, expectedType, schema),
+        assert(Typecheck.escalatesTo(t, expectedType, schema),
          s"Matching a node by type; Type $t can not possibly matched by an element of type $expectedType: $pattern"
        )
-        scope
+        scope ++ target.map { _ -> t }.toMap
 
      case Match.Exact(child) =>
-        assert(escalatesTo(child.t, expectedType, schema),
+        assert(Typecheck.escalatesTo(child.t, expectedType, schema),
          s"Matching a node by value; Value $child can not possibly match an element of type $expectedType: $pattern"
        )
        scope
@@ -74,7 +74,7 @@
            assert(false, s"Can't path into a simple type: $t: $pattern")
        }
      }
-      apply(child, targetType, schema, scope)
+      apply(child, target, targetType, schema, scope)
 
      case Match.Recursive(_) =>
        // not entirely sure how to typecheck this...
        ???
 
      case Match.Bind(symbol, child) =>
-        apply(child, expectedType, schema, scope ++ Map(symbol -> expectedType))
+        apply(child, Some(symbol), expectedType, schema, scope ++ Map(symbol -> expectedType))
 
      case Match.Any =>
        scope
 
@@ -92,23 +92,23 @@
      case Match.Lookup(symbol) =>
        assert(scope contains symbol, s"$symbol is not bound: $pattern")
-        assert(escalatesTo(scope(symbol), expectedType, schema), s"$symbol (of type ${scope(symbol)}) can not possibly match an element of type $expectedType")
+        assert(Typecheck.escalatesTo(scope(symbol), expectedType, schema), s"$symbol (of type ${scope(symbol)}) can not possibly match an element of type $expectedType")
        scope
 
      case Match.ApplyToScope(symbol, child) =>
        assert(scope contains symbol, s"$symbol is not bound: $pattern")
-        apply(child, scope(symbol), schema, scope)
+        apply(child, target, scope(symbol), schema, scope)
 
      case Match.Forall(child) =>
        expectedType match {
          case Type.Array(base) =>
-            apply(child, base, schema, scope)
+            apply(child, target, base, schema, scope)
          case _ =>
            assert(false, s"Forall match on something other than an array: $pattern")
        }
 
      case Match.Test(op) =>
-        assert(escalatesTo(TypecheckExpression(op, schema, scope), Type.Bool, schema),
+        assert(Typecheck.escalatesTo(TypecheckExpression(op, schema, scope), Type.Bool, schema),
          s"Non-boolean test expression: $op"
        )
        scope
@@ -126,23 +126,10 @@
        var runningScope = scope
        for( (Field(fieldName, fieldType), fieldMatch) <- nodeSchema.fields.zip(fields) )
        {
-          runningScope = apply(fieldMatch, fieldType, schema, runningScope)
+          runningScope = apply(fieldMatch, target, fieldType, schema, runningScope)
        }
-        runningScope
+        runningScope ++ target.map { _ -> Type.Node(label) }.toMap
    }
  }
 
-  def escalatesTo(source: Type, target: Type, schema: Definition): Boolean =
-  {
-    (source, target) match {
-      case (a, b) if a == b => true
-      case (Type.Node(label), Type.AST(family)) =>
-        schema.nodes.get(family)
-          .map { _ contains label }
-          .getOrElse { false }
-      case (Type.Native(_), Type.Native(_)) => true // don't try to encode native inheritance
-      case (Type.Any, _) => true
-      case (_, _) => false
-    }
-  }
 }
\ No newline at end of file
diff --git a/astral/compiler/views/com/astraldb/codegen/Optimizer.scala.scala b/astral/compiler/views/com/astraldb/codegen/Optimizer.scala.scala
new file mode 100644
index 0000000..ac15f6b
--- /dev/null
+++ b/astral/compiler/views/com/astraldb/codegen/Optimizer.scala.scala
@@ -0,0 +1,33 @@
+@import com.astraldb.spec.Definition
+
+@(ctx:Definition)
+
+object Optimizer
+{
+  val rules = Seq[Rule[LogicalPlan]](
+    @for(rule <- ctx.rules){
+      @rule.safeLabel,
+    }
+  )
+
+  def MAX_ITERATIONS = 100
+
+  def rewrite(plan: LogicalPlan): LogicalPlan =
+  {
+    var current = plan
+
+    for(i <- 0 until MAX_ITERATIONS)
+    {
+      val last = current
+      for(rule <- rules)
+      {
+        current = rule(current)
+      }
+      if(last.fastEquals(current))
+      {
+        return current
+      }
+    }
+    return current
+  }
+}
\ No newline at end of file
diff --git a/astral/compiler/views/com/astraldb/codegen/Rule.scala.scala b/astral/compiler/views/com/astraldb/codegen/Rule.scala.scala
new file mode 100644
index 0000000..86401ab
--- /dev/null
+++ b/astral/compiler/views/com/astraldb/codegen/Rule.scala.scala
@@ -0,0 +1,34 @@
+@import com.astraldb.spec.Rule
+@import com.astraldb.spec.Type
+@import com.astraldb.spec.Definition
+@import com.astraldb.codegen.Match
+@import com.astraldb.codegen.Expression
+@import com.astraldb.codegen.Code
+@import com.astraldb.typecheck.TypecheckMatch
+
+@(schema: Definition, rule: Rule)
+
+object @{rule.safeLabel} extends Rule[LogicalPlan]
+{
+  def apply(plan: LogicalPlan): LogicalPlan =
+  {
+    @Match(
+      schema = schema,
+      pattern = rule.pattern,
+      target = Code.Literal("plan"),
+      targetType = Type.AST(rule.family),
+      onSuccess = Expression(schema, rule.rewrite,
+        TypecheckMatch(
+          rule.pattern,
+          Some("plan"),
+          Type.AST(rule.family),
+          schema,
+          schema.globals
+        )
+      ),
+      onFail = Code.Literal("plan"),
+      name = Some("plan"),
+      scope = schema.globals
+    ).toString(4).stripPrefix(" ")
+  }
+}
\ No newline at end of file
diff --git a/build.sc b/build.sc
index a05f8cc..93fe37c 100644
--- a/build.sc
+++ b/build.sc
@@ -1,22 +1,40 @@
 import mill._
 import mill.scalalib._
 import mill.scalalib.publish._
+import $ivy.`com.lihaoyi::mill-contrib-twirllib:`, mill.twirllib._
 
-object astral extends Module {
-  def scalaVersion = "3.2.1"
+object astral extends Module
+{
+  def scalaVersion = "3.3.0"
 
-  object compiler extends ScalaModule with PublishModule {
+  def compile = compiler.compile
+
+  object compiler extends ScalaModule
+                  with PublishModule
+                  with TwirlModule
+  {
     val VERSION = "0.0.1-SNAPSHOT"
 
     def scalaVersion = astral.scalaVersion
+    def twirlScalaVersion = scalaVersion
+    def twirlVersion = "1.6.0-RC4"
     def mainClass = Some("com.astraldb.Astral")
 
+    def generatedSources = T{ Seq(compileTwirl().classes) }
+
+    /*************************************************
+     *** Twirl Config
+     *************************************************/
+    def twirlFormats = super.twirlFormats() ++ Map(
+      "scala" -> "play.twirl.api.TxtFormat"
+    )
 
     /*************************************************
      *** Backend Dependencies
      *************************************************/
-    // def ivyDeps = Agg(
-    // )
+    def ivyDeps = Agg(
+      ivy"com.typesafe.play::twirl-api::${twirlVersion()}"
+    )
 
     def publishVersion = VERSION
     override def pomSettings = PomSettings(
@@ -35,9 +53,53 @@
   {
     def scalaVersion = astral.scalaVersion
 
-    def mainClass = Some("com.astraldb.catalyst.Astral")
+    def mainClass = Some("com.astraldb.catalyst.Generate")
 
     def moduleDeps = Seq(astral.compiler)
+
+    def classPath = T{ Seq[PathRef](compile().classes) ++ resources() }
+
+    def rendered = T {
+      val target = T.dest
+      val files = scala.collection.mutable.Buffer[String]()
+      os.proc(
+        "scala",
+        "-cp",
+        classPath().mkString(":"),
+        mainClass().get,
+        target
+      ).call(
+        cwd = target,
+        stdout = os.ProcessOutput.Readlines(
+          line => files += line
+        )
+      )
+
+      for(f <- files)
+      {
+        println(s"GOT : $f")
+      }
+
+      /* return */
+      files.map {
+        file => PathRef(target / file)
+      }.toSeq
+    }
+
+    def render(args: String*) = T.command {
+      println(rendered())
+    }
+
+    object impl extends ScalaModule
+    {
+      def scalaVersion = "2.13.8"
+
+      def generatedSources = T{ astral.catalyst.rendered() }
+
+      def ivyDeps = Agg(
+        ivy"org.apache.spark::spark-sql::3.4.1",
+      )
+    }
 
   }
 }