Compiler compiles

This commit is contained in:
Oliver Kennedy 2023-07-08 18:03:20 -04:00
parent 9348b1cbc1
commit 8ac8bfec8b
Signed by: okennedy
GPG key ID: 3E5F9B3ABD3FDB60
8 changed files with 205 additions and 72 deletions

View file

@ -1,9 +1,43 @@
package com.astraldb.catalyst package com.astraldb.catalyst
import org.apache.spark.sql.{ Row, SparkSession }
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._
object OptimizerTest object OptimizerTest
{ {
def main(args: Array[String]): Unit = def main(args: Array[String]): Unit =
{ {
println("Hello world") val spark: SparkSession =
SparkSession.builder
.appName("OptimizerTest")
.master("local")
.getOrCreate()
val r = spark.emptyDataFrame
.select(lit(1) as "A", lit(2) as "B")
r.createOrReplaceTempView("R")
val df = spark.sql("SELECT * FROM R")
println(df.queryExecution.logical)
println(df.queryExecution.analyzed)
println("------------------")
val optimized =
Time("Astral Optimizer") {
Optimizer.rewrite(df.queryExecution.analyzed)
}
println("------------------\nAstral Optimized Query:\n")
println(optimized)
println("------------------")
val sparkOptimized =
Time("Spark Optimizer") {
df.queryExecution.optimizedPlan
}
println("------------------\nSpark Optimized Query:\n")
println(sparkOptimized)
println("------------------")
} }
} }

View file

@ -0,0 +1,15 @@
package com.astraldb.catalyst
/**
 * Minimal wall-time reporting helper for ad-hoc benchmarking.
 *
 * Usage: `val result = Time("label") { expensiveComputation() }`
 */
object Time
{
  /**
   * Evaluate `body` exactly once, print how long the evaluation took, and
   * return its result.
   *
   * The report is written to standard output as `"<label>: <seconds>s"`.
   * If `body` throws, the exception propagates and nothing is printed.
   *
   * @param label  Text printed in front of the measured duration
   * @param body   The by-name computation to time
   * @return       Whatever `body` evaluates to
   */
  def apply[T](label: String)(body: => T): T =
  {
    // nanoTime is monotonic, so the measured interval cannot be skewed (or
    // go negative) when the system clock is adjusted while `body` runs.
    val start = System.nanoTime()
    val ret = body
    // Truncate to whole milliseconds so the printed value keeps the same
    // resolution as before (e.g. "0.012s").
    val elapsedMillis = (System.nanoTime() - start) / 1000000L
    println(s"$label: ${elapsedMillis / 1000.0}s")
    ret
  }
}

View file

@ -242,10 +242,10 @@ object Code
val renderedElems = elems.map { _.render(indent+2) } val renderedElems = elems.map { _.render(indent+2) }
val multiLine = val multiLine =
renderedElems.exists { _.isInstanceOf[Lines] } renderedElems.exists { _.isInstanceOf[Lines] }
|| renderedElems.map { _.asInstanceOf[PaddedString].body.size }.sum > 100 || renderedElems.map { _.asInstanceOf[PaddedString].body.size }.sum > 60
if(multiLine){ if(multiLine){
Lines( Lines(
renderedElems.flatMap { renderedElems.dropRight(1).flatMap {
case p:PaddedString => case p:PaddedString =>
Lines.ofLine(indent, p, sep).body Lines.ofLine(indent, p, sep).body
case Lines(Seq()) => case Lines(Seq()) =>
@ -253,10 +253,20 @@ object Code
case Lines(Seq(l)) => case Lines(Seq(l)) =>
Seq(PaddedString(l, sep).body) Seq(PaddedString(l, sep).body)
case Lines(l) => case Lines(l) =>
Seq(l.head) ++ l.dropRight(1) ++
l.tail.dropRight(1).map { " " + _ } ++ Seq(PaddedString(l.last,sep).body)
Seq(PaddedString(" ",l.last,sep).body) } ++
} (renderedElems.last match {
case p:PaddedString =>
Lines.ofLine(indent, p).body
case Lines(Seq()) =>
Lines.ofLine(indent).body
case Lines(Seq(l)) =>
Seq(PaddedString(l).body)
case Lines(l) =>
l.dropRight(1) ++
Seq(PaddedString(l.last,sep).body)
})
) )
} else { } else {
renderedElems.map { _.asInstanceOf[PaddedString] } match { renderedElems.map { _.asInstanceOf[PaddedString] } match {

View file

@ -0,0 +1,45 @@
package com.astraldb.codegen
import com.astraldb.spec
/**
 * An immutable mapping from variable names to what code generation knows
 * about them.
 *
 * Each variable is bound to one of:
 *  - a bare `spec.Type`: the variable is referenced in generated code by its
 *    own name (see [[codeOf]]), or
 *  - a `(Code, spec.Type)` pair: the variable stands for the given code
 *    fragment, which has the given type.
 *
 * Every "modifying" operation returns a new scope; the receiver is never
 * changed.
 *
 * @param vars  The bindings visible in this scope
 */
class CodeScope(vars: Map[String, spec.Type | (Code, spec.Type)])
{
  /** The type component of a single binding. */
  private def bindingType(binding: spec.Type | (Code, spec.Type)): spec.Type =
    binding match {
      case t:spec.Type => t
      case c:(Code,spec.Type) => c._2
    }

  /**
   * Shared implementation of the `refine` overloads: rewrite every
   * `(Code, Type)` binding whose code equals `code`, leaving all other
   * bindings as they are.
   *
   * @param code         The code fragment identifying the bindings to rewrite
   * @param replacement  Builds the new binding from the matched binding's code
   */
  private def rebind(code: Code)(replacement: Code => (Code, spec.Type)): CodeScope =
  {
    def update(binding: spec.Type | (Code, spec.Type)): spec.Type | (Code, spec.Type) =
      binding match {
        case c:(Code,spec.Type) if c._1 == code => replacement(c._1)
        case other => other
      }
    // `Map.mapValues` is deprecated since Scala 2.13; go through a view and
    // materialize with `toMap` to get a strict Map without the warning.
    new CodeScope(vars.view.mapValues(update).toMap)
  }

  /**
   * The type of the named variable.
   *
   * Fails with the underlying map's missing-key error (normally a
   * `NoSuchElementException`) when `name` is not in scope.
   */
  def typeOf(name: String): spec.Type =
    bindingType(vars(name))

  /**
   * The code fragment that references the named variable: the bound code
   * when one was recorded, otherwise a literal of the variable's own name.
   *
   * Fails with the underlying map's missing-key error (normally a
   * `NoSuchElementException`) when `name` is not in scope.
   */
  def codeOf(name: String): Code =
    vars(name) match {
      case t:spec.Type => Code.Literal(name)
      case c:(Code,spec.Type) => c._1
    }

  /**
   * A scope extended with the given bindings; a new binding replaces any
   * existing binding of the same name.
   */
  def withVars(nameVal: (String, spec.Type|(Code, spec.Type))*): CodeScope =
    new CodeScope(vars ++ nameVal.toMap)

  /**
   * A scope in which every variable bound to `code` has type `updatedType`;
   * the bound code is kept.
   *
   * NOTE(review): only `(Code, Type)` bindings are considered; a type-only
   * variable is never refined, even though `codeOf` would render it as
   * `Code.Literal(name)` -- confirm this is intended.
   */
  def refine(code: Code, updatedType: spec.Type): CodeScope =
    rebind(code) { existing => existing -> updatedType }

  /**
   * A scope in which every variable bound to `code` is rebound to
   * `updatedCode` with type `updatedType`.
   */
  def refine(code: Code, updatedCode: Code, updatedType: spec.Type): CodeScope =
    rebind(code) { _ => updatedCode -> updatedType }

  /** The scope reduced to a plain name-to-type map, dropping bound code. */
  def flatten: Map[String, spec.Type] =
    vars.view.mapValues(bindingType).toMap

  /** Renders as `{ name->Type, name->code:Type, ... }`. */
  override def toString: String =
    vars.map { case (name, binding) =>
      name + "->" + (binding match {
        case t:spec.Type => t.toString
        case c:(Code,spec.Type) => c._1.toString+":"+c._2.toString
      })
    }.mkString("{ ", ", ", " }")
}

View file

@ -9,16 +9,24 @@ import com.astraldb.spec.Type
object Expression object Expression
{ {
def apply(schema: spec.Definition, op: Expr, scope: Map[String, spec.Type]): Code =
def assoc(op: Expr, t: ArithTypes.T): Seq[Expr] =
op match {
case Arith(at, l, r) if at == t =>
assoc(l, t) ++ assoc(r, t)
case _ => Seq(op)
}
def apply(schema: spec.Definition, op: Expr, scope: CodeScope): Code =
{ {
op match { op match {
case c:SimpleConstant => Code.Literal(c.asScala) case c:SimpleConstant => Code.Literal(c.asScala)
case Var(v) => Code.Literal(v) case Var(v) => scope.codeOf(v)
case Arith(t, lhs, rhs) => case Arith(t, lhs, rhs) =>
Code.BinOp( Code.List(
Code.Parenthesize(apply(schema, lhs, scope)),
PaddedString.pad(1, ArithTypes.opString(t)), PaddedString.pad(1, ArithTypes.opString(t)),
Code.Parenthesize(apply(schema, rhs, scope)) assoc(op, t)
.map { e => Code.Parenthesize(apply(schema, e, scope)) }
) )
case Cmp(t, lhs, rhs) => case Cmp(t, lhs, rhs) =>
Code.BinOp( Code.BinOp(
@ -28,10 +36,9 @@ object Expression
) )
case FunctionCall(fn, args) => case FunctionCall(fn, args) =>
Code.Literal(op.toString) Code.Literal(op.toString)
Code.PrefixedBlock( Code.Parens(
prefix = Code.Literal(fn), left = s"$fn(",
lParen = "(", right = ")",
rParen = ")",
body = Code.List( body = Code.List(
sep = PaddedString.rightPad(1, ","), sep = PaddedString.rightPad(1, ","),
elems = args.map { apply(schema, _, scope) } elems = args.map { apply(schema, _, scope) }
@ -43,10 +50,14 @@ object Expression
Code.BinOp( Code.BinOp(
apply(schema, target, scope), apply(schema, target, scope),
".", ".",
TypecheckExpression(target, schema, scope) match { TypecheckExpression(target, schema, scope.flatten) match {
case Type.Node(nodeType) => case Type.Node(nodeType) =>
val node =
schema.nodesByName.get(nodeType)
.getOrElse { assert(false, s"No node of type $nodeType: $op")}
assert(node.fields.size > index, s"Node type $nodeType only has ${node.fields.size} fields (field $index requested): $op in $scope")
Code.Literal( Code.Literal(
schema.nodesByName(nodeType).fields(index).name node.fields(index).name
) )
case c => case c =>
assert(false, s"Node subscript on something not a node: $op (in ${scope}): ${c.getClass.getSimpleName}") assert(false, s"Node subscript on something not a node: $op (in ${scope}): ${c.getClass.getSimpleName}")

View file

@ -6,6 +6,7 @@ import spec.Match._
import com.astraldb.codegen.Code.PaddedString import com.astraldb.codegen.Code.PaddedString
import com.astraldb.typecheck.TypecheckMatch import com.astraldb.typecheck.TypecheckMatch
import com.astraldb.spec.Type import com.astraldb.spec.Type
import com.astraldb.typecheck.TypecheckExpression
object Match object Match
{ {
@ -14,29 +15,36 @@ object Match
pattern: spec.Match, pattern: spec.Match,
target: Code, target: Code,
targetType: spec.Type, targetType: spec.Type,
onSuccess: Code, onSuccess: CodeScope => Code,
onFail: Code, onFail: CodeScope => Code,
name: Option[String], name: Option[String],
scope: Map[String, spec.Type] scope: CodeScope
): Code = ): Code =
{ {
pattern match { pattern match {
case And(Seq()) => onSuccess case And(Seq()) => onSuccess(scope)
case And(Seq(a)) => apply(schema, a, target, targetType, onSuccess, onFail, name, scope) case And(Seq(a)) => apply(schema, a, target, targetType, onSuccess, onFail, name, scope)
case And(a) => case And(a) =>
apply(schema, a.head, apply(schema, a.head,
target = target, target = target,
targetType = targetType, targetType = targetType,
onSuccess = onSuccess =
apply(schema, And(a.tail), target, targetType, onSuccess, onFail, name, scope =>
scope = TypecheckMatch(a.head, name, targetType, schema, scope) apply(schema, And(a.tail),
target = target,
targetType = targetType,
onSuccess = onSuccess,
onFail = onFail,
name = name,
scope = scope
), ),
onFail = onFail, onFail = onFail,
name = name, name = name,
scope = scope scope = scope
) )
case Not(a) => apply(schema, a, target, targetType, onFail, onSuccess, name, scope) case Not(a) =>
case Or(Seq()) => onFail apply(schema, a, target, targetType, onFail, onSuccess, name, scope)
case Or(Seq()) => onFail(scope)
case Or(Seq(a)) => case Or(Seq(a)) =>
apply(schema, a, target, targetType, onSuccess, onFail, name, scope) apply(schema, a, target, targetType, onSuccess, onFail, name, scope)
case Or(a) => case Or(a) =>
@ -44,27 +52,28 @@ object Match
target = target, target = target,
targetType = targetType, targetType = targetType,
onSuccess = onSuccess, onSuccess = onSuccess,
onFail = apply(schema, Or(a.tail), target, targetType, onSuccess, onFail, name, scope), onFail =
_ => apply(schema, Or(a.tail),
target = target,
targetType = targetType,
onSuccess = onSuccess,
onFail = onFail,
name = name,
scope = scope
),
name = name, name = name,
scope = scope scope = scope
) )
case Bind(symbol, pattern) => case Bind(symbol, pattern) =>
Code.Block(
Code.BinOp(
Code.Literal(s"val $symbol"),
Code.PaddedString.pad(1, "="),
target
) +:
apply( apply(
schema = schema, schema = schema,
pattern = pattern, pattern = pattern,
target = Code.Literal(symbol), target = target,
targetType = targetType, targetType = targetType,
onSuccess = onSuccess, onSuccess = onSuccess,
onFail = onFail, onFail = onFail,
name = Some(symbol), name = Some(symbol),
scope = scope ++ Map(symbol -> targetType) scope = scope.withVars(symbol -> (target, targetType))
).block
) )
case BindExpression(symbol, op) => case BindExpression(symbol, op) =>
Code.Block( Code.Block(
@ -74,13 +83,19 @@ object Match
Code.PaddedString.pad(1, "="), Code.PaddedString.pad(1, "="),
Expression(schema, op, scope) Expression(schema, op, scope)
) )
) ++ onSuccess.block ) ++ onSuccess(
scope.withVars(symbol -> (Code.Literal(symbol),
TypecheckExpression(op, schema, scope.flatten)
))
).block
) )
case Node(nodeLabel, children) => case Node(nodeLabel, children) =>
{ {
val selectedName =
name.getOrElse { "genericNode" }
val node = schema.nodesByName(nodeLabel) val node = schema.nodesByName(nodeLabel)
val selectedName =
name.getOrElse { "genericNode" }+"_t"
val updatedScope =
scope.refine(target, Code.Literal(selectedName), Type.Node(nodeLabel))
Code.IfThenElse( Code.IfThenElse(
condition = condition =
Code.BinOp( Code.BinOp(
@ -104,33 +119,33 @@ object Match
children children
.zip(node.fields) .zip(node.fields)
.foldRight(onSuccess) { case ((child, field), andThen) => .foldRight(onSuccess) { case ((child, field), andThen) =>
apply( nextScope => apply(
schema = schema, schema = schema,
pattern = child, pattern = child,
target = Code.Literal(s"$selectedName.${field.name}"), target = Code.Literal(s"$selectedName.${field.name}"),
targetType = Type.Node(nodeLabel), targetType = field.t,
onSuccess = andThen, onSuccess = andThen,
onFail = onFail, onFail = onFail,
name = Some(s"${selectedName}_${field.name}"), name = Some(s"${selectedName}_${field.name}"),
scope = scope ++ Map(selectedName -> Type.Node(nodeLabel)) scope = nextScope
) )
} }(updatedScope)
.block .block
), ),
elseBlock = onFail elseBlock = onFail(scope)
) )
} }
case Test(op) => case Test(op) =>
Code.IfThenElse( Code.IfThenElse(
condition = Expression(schema, op, scope), condition = Expression(schema, op, scope),
thenBlock = onSuccess, thenBlock = onSuccess(scope),
elseBlock = onFail elseBlock = onFail(scope)
) )
case Any => case Any =>
onSuccess onSuccess(scope)
case OfType(nodeType) => case OfType(nodeType) =>
val selectedName = val selectedName =
name.getOrElse { "genericNode" } name.getOrElse { "genericNode" } + "_t"
Code.IfThenElse( Code.IfThenElse(
condition = Code.BinOp(target, ".", Code.Literal(s"isInstanceOf[${nodeType.scalaType}]")), condition = Code.BinOp(target, ".", Code.Literal(s"isInstanceOf[${nodeType.scalaType}]")),
thenBlock = thenBlock =
@ -145,11 +160,13 @@ object Match
Code.Literal(s"asInstanceOf[${nodeType.scalaType}]") Code.Literal(s"asInstanceOf[${nodeType.scalaType}]")
) )
) )
)++onSuccess.block )++onSuccess(scope.refine(target, Code.Literal(selectedName), nodeType)).block
), ),
elseBlock = onFail elseBlock = onFail(scope)
) )
case _ => Code.Literal(s"??${pattern.getClass.getSimpleName}??") }
}
} }
} }

View file

@ -6,8 +6,7 @@ object Optimizer
{ {
val rules = Seq[Rule[LogicalPlan]]( val rules = Seq[Rule[LogicalPlan]](
@for(rule <- ctx.rules){ @for(rule <- ctx.rules){
@rule.safeLabel, @rule.safeLabel, }
}
) )
def MAX_ITERATIONS = 100 def MAX_ITERATIONS = 100

View file

@ -4,6 +4,7 @@
@import com.astraldb.codegen.Match @import com.astraldb.codegen.Match
@import com.astraldb.codegen.Expression @import com.astraldb.codegen.Expression
@import com.astraldb.codegen.Code @import com.astraldb.codegen.Code
@import com.astraldb.codegen.CodeScope
@import com.astraldb.typecheck.TypecheckMatch @import com.astraldb.typecheck.TypecheckMatch
@(schema: Definition, rule: Rule) @(schema: Definition, rule: Rule)
@ -12,23 +13,24 @@ object @{rule.safeLabel} extends Rule[LogicalPlan]
{ {
def apply(plan: LogicalPlan): LogicalPlan = def apply(plan: LogicalPlan): LogicalPlan =
{ {
@Match( @{
schema = schema, val matchSchema = TypecheckMatch(
pattern = rule.pattern,
target = Code.Literal("plan"),
targetType = Type.AST(rule.family),
onSuccess = Expression(schema, rule.rewrite,
TypecheckMatch(
rule.pattern, rule.pattern,
Some("plan"), Some("plan"),
Type.AST(rule.family), Type.AST(rule.family),
schema, schema,
schema.globals schema.globals
) )
), }
onFail = Code.Literal("plan"), @Match(
schema = schema,
pattern = rule.pattern,
target = Code.Literal("plan"),
targetType = Type.AST(rule.family),
onSuccess = Expression(schema, rule.rewrite, _),
onFail = { _ => Code.Literal("plan") },
name = Some("plan"), name = Some("plan"),
scope = schema.globals scope = CodeScope(schema.globals)
).toString(4).stripPrefix(" ") ).toString(4).stripPrefix(" ")
} }
} }