Compiler compiles

This commit is contained in:
Oliver Kennedy 2023-07-08 18:03:20 -04:00
parent 9348b1cbc1
commit 8ac8bfec8b
Signed by: okennedy
GPG key ID: 3E5F9B3ABD3FDB60
8 changed files with 205 additions and 72 deletions

View file

@ -1,9 +1,43 @@
package com.astraldb.catalyst
import org.apache.spark.sql.{ Row, SparkSession }
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._
object OptimizerTest
{
def main(args: Array[String]): Unit =
{
println("Hello world")
val spark: SparkSession =
SparkSession.builder
.appName("OptimizerTest")
.master("local")
.getOrCreate()
val r = spark.emptyDataFrame
.select(lit(1) as "A", lit(2) as "B")
r.createOrReplaceTempView("R")
val df = spark.sql("SELECT * FROM R")
println(df.queryExecution.logical)
println(df.queryExecution.analyzed)
println("------------------")
val optimized =
Time("Astral Optimizer") {
Optimizer.rewrite(df.queryExecution.analyzed)
}
println("------------------\nAstral Optimized Query:\n")
println(optimized)
println("------------------")
val sparkOptimized =
Time("Spark Optimizer") {
df.queryExecution.optimizedPlan
}
println("------------------\nSpark Optimized Query:\n")
println(sparkOptimized)
println("------------------")
}
}

View file

@ -0,0 +1,15 @@
package com.astraldb.catalyst
object Time
{
def apply[T](label: String)(body: => T): T =
{
val start = System.currentTimeMillis()
val ret = body
val end = System.currentTimeMillis()
println(s"$label: ${(end - start).toFloat / 1000.0}s")
return ret
}
}

View file

@ -242,10 +242,10 @@ object Code
val renderedElems = elems.map { _.render(indent+2) }
val multiLine =
renderedElems.exists { _.isInstanceOf[Lines] }
|| renderedElems.map { _.asInstanceOf[PaddedString].body.size }.sum > 100
|| renderedElems.map { _.asInstanceOf[PaddedString].body.size }.sum > 60
if(multiLine){
Lines(
renderedElems.flatMap {
renderedElems.dropRight(1).flatMap {
case p:PaddedString =>
Lines.ofLine(indent, p, sep).body
case Lines(Seq()) =>
@ -253,10 +253,20 @@ object Code
case Lines(Seq(l)) =>
Seq(PaddedString(l, sep).body)
case Lines(l) =>
Seq(l.head) ++
l.tail.dropRight(1).map { " " + _ } ++
Seq(PaddedString(" ",l.last,sep).body)
}
l.dropRight(1) ++
Seq(PaddedString(l.last,sep).body)
} ++
(renderedElems.last match {
case p:PaddedString =>
Lines.ofLine(indent, p).body
case Lines(Seq()) =>
Lines.ofLine(indent).body
case Lines(Seq(l)) =>
Seq(PaddedString(l).body)
case Lines(l) =>
l.dropRight(1) ++
Seq(PaddedString(l.last,sep).body)
})
)
} else {
renderedElems.map { _.asInstanceOf[PaddedString] } match {

View file

@ -0,0 +1,45 @@
package com.astraldb.codegen
import com.astraldb.spec
class CodeScope(vars: Map[String, spec.Type | (Code, spec.Type)])
{
def typeOf(name: String): spec.Type =
vars(name) match {
case t:spec.Type => t
case c:(Code,spec.Type) => c._2
}
def codeOf(name: String): Code =
vars(name) match {
case t:spec.Type => Code.Literal(name)
case c:(Code,spec.Type) => c._1
}
def withVars(nameVal: (String, spec.Type|(Code, spec.Type))*): CodeScope =
new CodeScope(vars ++ nameVal.toMap)
def refine(code: Code, updatedType: spec.Type): CodeScope =
new CodeScope(vars.mapValues {
case c:(Code,spec.Type) if c._1 == code => c._1 -> updatedType
case v => v
}.toMap)
def refine(code: Code, updatedCode: Code, updatedType: spec.Type): CodeScope =
new CodeScope(vars.mapValues {
case c:(Code,spec.Type) if c._1 == code => updatedCode -> updatedType
case v => v
}.toMap)
def flatten: Map[String, spec.Type] =
vars.mapValues {
case t:spec.Type => t
case c:(Code,spec.Type) => c._2
}.toMap
override def toString: String =
"{ " + vars.map { v => v._1 + "->" + (v._2 match {
case t:spec.Type => t.toString
case c:(Code,spec.Type) => c._1.toString+":"+c._2.toString
})}.mkString(", ") + " }"
}

View file

@ -9,16 +9,24 @@ import com.astraldb.spec.Type
object Expression
{
def apply(schema: spec.Definition, op: Expr, scope: Map[String, spec.Type]): Code =
def assoc(op: Expr, t: ArithTypes.T): Seq[Expr] =
op match {
case Arith(at, l, r) if at == t =>
assoc(l, t) ++ assoc(r, t)
case _ => Seq(op)
}
def apply(schema: spec.Definition, op: Expr, scope: CodeScope): Code =
{
op match {
case c:SimpleConstant => Code.Literal(c.asScala)
case Var(v) => Code.Literal(v)
case Var(v) => scope.codeOf(v)
case Arith(t, lhs, rhs) =>
Code.BinOp(
Code.Parenthesize(apply(schema, lhs, scope)),
Code.List(
PaddedString.pad(1, ArithTypes.opString(t)),
Code.Parenthesize(apply(schema, rhs, scope))
assoc(op, t)
.map { e => Code.Parenthesize(apply(schema, e, scope)) }
)
case Cmp(t, lhs, rhs) =>
Code.BinOp(
@ -28,10 +36,9 @@ object Expression
)
case FunctionCall(fn, args) =>
Code.Literal(op.toString)
Code.PrefixedBlock(
prefix = Code.Literal(fn),
lParen = "(",
rParen = ")",
Code.Parens(
left = s"$fn(",
right = ")",
body = Code.List(
sep = PaddedString.rightPad(1, ","),
elems = args.map { apply(schema, _, scope) }
@ -43,10 +50,14 @@ object Expression
Code.BinOp(
apply(schema, target, scope),
".",
TypecheckExpression(target, schema, scope) match {
TypecheckExpression(target, schema, scope.flatten) match {
case Type.Node(nodeType) =>
val node =
schema.nodesByName.get(nodeType)
.getOrElse { assert(false, s"No node of type $nodeType: $op")}
assert(node.fields.size > index, s"Node type $nodeType only has ${node.fields.size} fields (field $index requested): $op in $scope")
Code.Literal(
schema.nodesByName(nodeType).fields(index).name
node.fields(index).name
)
case c =>
assert(false, s"Node subscript on something not a node: $op (in ${scope}): ${c.getClass.getSimpleName}")

View file

@ -6,6 +6,7 @@ import spec.Match._
import com.astraldb.codegen.Code.PaddedString
import com.astraldb.typecheck.TypecheckMatch
import com.astraldb.spec.Type
import com.astraldb.typecheck.TypecheckExpression
object Match
{
@ -14,29 +15,36 @@ object Match
pattern: spec.Match,
target: Code,
targetType: spec.Type,
onSuccess: Code,
onFail: Code,
onSuccess: CodeScope => Code,
onFail: CodeScope => Code,
name: Option[String],
scope: Map[String, spec.Type]
scope: CodeScope
): Code =
{
pattern match {
case And(Seq()) => onSuccess
case And(Seq()) => onSuccess(scope)
case And(Seq(a)) => apply(schema, a, target, targetType, onSuccess, onFail, name, scope)
case And(a) =>
apply(schema, a.head,
target = target,
targetType = targetType,
onSuccess =
apply(schema, And(a.tail), target, targetType, onSuccess, onFail, name,
scope = TypecheckMatch(a.head, name, targetType, schema, scope)
),
scope =>
apply(schema, And(a.tail),
target = target,
targetType = targetType,
onSuccess = onSuccess,
onFail = onFail,
name = name,
scope = scope
),
onFail = onFail,
name = name,
scope = scope
)
case Not(a) => apply(schema, a, target, targetType, onFail, onSuccess, name, scope)
case Or(Seq()) => onFail
case Not(a) =>
apply(schema, a, target, targetType, onFail, onSuccess, name, scope)
case Or(Seq()) => onFail(scope)
case Or(Seq(a)) =>
apply(schema, a, target, targetType, onSuccess, onFail, name, scope)
case Or(a) =>
@ -44,27 +52,28 @@ object Match
target = target,
targetType = targetType,
onSuccess = onSuccess,
onFail = apply(schema, Or(a.tail), target, targetType, onSuccess, onFail, name, scope),
onFail =
_ => apply(schema, Or(a.tail),
target = target,
targetType = targetType,
onSuccess = onSuccess,
onFail = onFail,
name = name,
scope = scope
),
name = name,
scope = scope
)
case Bind(symbol, pattern) =>
Code.Block(
Code.BinOp(
Code.Literal(s"val $symbol"),
Code.PaddedString.pad(1, "="),
target
) +:
apply(
schema = schema,
pattern = pattern,
target = Code.Literal(symbol),
targetType = targetType,
onSuccess = onSuccess,
onFail = onFail,
name = Some(symbol),
scope = scope ++ Map(symbol -> targetType)
).block
apply(
schema = schema,
pattern = pattern,
target = target,
targetType = targetType,
onSuccess = onSuccess,
onFail = onFail,
name = Some(symbol),
scope = scope.withVars(symbol -> (target, targetType))
)
case BindExpression(symbol, op) =>
Code.Block(
@ -74,13 +83,19 @@ object Match
Code.PaddedString.pad(1, "="),
Expression(schema, op, scope)
)
) ++ onSuccess.block
) ++ onSuccess(
scope.withVars(symbol -> (Code.Literal(symbol),
TypecheckExpression(op, schema, scope.flatten)
))
).block
)
case Node(nodeLabel, children) =>
{
val selectedName =
name.getOrElse { "genericNode" }
val node = schema.nodesByName(nodeLabel)
val selectedName =
name.getOrElse { "genericNode" }+"_t"
val updatedScope =
scope.refine(target, Code.Literal(selectedName), Type.Node(nodeLabel))
Code.IfThenElse(
condition =
Code.BinOp(
@ -104,33 +119,33 @@ object Match
children
.zip(node.fields)
.foldRight(onSuccess) { case ((child, field), andThen) =>
apply(
nextScope => apply(
schema = schema,
pattern = child,
target = Code.Literal(s"$selectedName.${field.name}"),
targetType = Type.Node(nodeLabel),
targetType = field.t,
onSuccess = andThen,
onFail = onFail,
name = Some(s"${selectedName}_${field.name}"),
scope = scope ++ Map(selectedName -> Type.Node(nodeLabel))
scope = nextScope
)
}
}(updatedScope)
.block
),
elseBlock = onFail
elseBlock = onFail(scope)
)
}
case Test(op) =>
Code.IfThenElse(
condition = Expression(schema, op, scope),
thenBlock = onSuccess,
elseBlock = onFail
thenBlock = onSuccess(scope),
elseBlock = onFail(scope)
)
case Any =>
onSuccess
onSuccess(scope)
case OfType(nodeType) =>
val selectedName =
name.getOrElse { "genericNode" }
name.getOrElse { "genericNode" } + "_t"
Code.IfThenElse(
condition = Code.BinOp(target, ".", Code.Literal(s"isInstanceOf[${nodeType.scalaType}]")),
thenBlock =
@ -145,11 +160,13 @@ object Match
Code.Literal(s"asInstanceOf[${nodeType.scalaType}]")
)
)
)++onSuccess.block
)++onSuccess(scope.refine(target, Code.Literal(selectedName), nodeType)).block
),
elseBlock = onFail
elseBlock = onFail(scope)
)
case _ => Code.Literal(s"??${pattern.getClass.getSimpleName}??")
}
}
}

View file

@ -5,9 +5,8 @@
object Optimizer
{
val rules = Seq[Rule[LogicalPlan]](
@for(rule <- ctx.rules){
@rule.safeLabel,
}
@for(rule <- ctx.rules){
@rule.safeLabel, }
)
def MAX_ITERATIONS = 100

View file

@ -4,6 +4,7 @@
@import com.astraldb.codegen.Match
@import com.astraldb.codegen.Expression
@import com.astraldb.codegen.Code
@import com.astraldb.codegen.CodeScope
@import com.astraldb.typecheck.TypecheckMatch
@(schema: Definition, rule: Rule)
@ -12,23 +13,24 @@ object @{rule.safeLabel} extends Rule[LogicalPlan]
{
def apply(plan: LogicalPlan): LogicalPlan =
{
@Match(
schema = schema,
pattern = rule.pattern,
target = Code.Literal("plan"),
targetType = Type.AST(rule.family),
onSuccess = Expression(schema, rule.rewrite,
TypecheckMatch(
@{
val matchSchema = TypecheckMatch(
rule.pattern,
Some("plan"),
Type.AST(rule.family),
schema,
schema.globals
)
),
onFail = Code.Literal("plan"),
}
@Match(
schema = schema,
pattern = rule.pattern,
target = Code.Literal("plan"),
targetType = Type.AST(rule.family),
onSuccess = Expression(schema, rule.rewrite, _),
onFail = { _ => Code.Literal("plan") },
name = Some("plan"),
scope = schema.globals
scope = CodeScope(schema.globals)
).toString(4).stripPrefix(" ")
}
}