Compiler compiles
This commit is contained in:
parent
9348b1cbc1
commit
8ac8bfec8b
|
@ -1,9 +1,43 @@
|
|||
package com.astraldb.catalyst
|
||||
|
||||
import org.apache.spark.sql.{ Row, SparkSession }
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.sql.functions._
|
||||
|
||||
/**
 * Manual smoke test for the generated optimizer.
 *
 * Builds a trivial two-column query on a local SparkSession, prints its
 * logical and analyzed plans, and then prints the plan produced by
 * `Optimizer.rewrite` followed by the plan produced by Spark's own
 * optimizer, timing each step with [[Time]].
 */
object OptimizerTest
{
  /**
   * Entry point.
   *
   * @param args command-line arguments (not read)
   */
  def main(args: Array[String]): Unit =
  {
    println("Hello world")
    val spark: SparkSession =
      SparkSession.builder
        .appName("OptimizerTest")
        .master("local")
        .getOrCreate()

    // Release the session even when planning or optimization throws, so a
    // failed run does not leave the local Spark context behind.
    try {
      val r = spark.emptyDataFrame
                   .select(lit(1) as "A", lit(2) as "B")
      r.createOrReplaceTempView("R")
      val df = spark.sql("SELECT * FROM R")

      println(df.queryExecution.logical)
      println(df.queryExecution.analyzed)

      println("------------------")

      val optimized =
        Time("Astral Optimizer") {
          Optimizer.rewrite(df.queryExecution.analyzed)
        }

      println("------------------\nAstral Optimized Query:\n")
      println(optimized)
      println("------------------")

      // The plan is requested inside the timed block so that whatever work
      // the access triggers is included in the measurement.
      val sparkOptimized =
        Time("Spark Optimizer") {
          df.queryExecution.optimizedPlan
        }
      println("------------------\nSpark Optimized Query:\n")
      println(sparkOptimized)
      println("------------------")
    } finally {
      spark.stop()
    }
  }
}
|
15
astral/catalyst/impl/src/com/astraldb/catalyst/Time.scala
Normal file
15
astral/catalyst/impl/src/com/astraldb/catalyst/Time.scala
Normal file
|
@ -0,0 +1,15 @@
|
|||
package com.astraldb.catalyst
|
||||
|
||||
/**
 * Minimal timing helper: runs a block of code and prints how long it took.
 */
object Time
{
  /**
   * Evaluates `body` exactly once, prints `"<label>: <seconds>s"` to
   * standard output, and returns the value `body` produced.
   *
   * If `body` throws, the exception propagates and nothing is printed.
   *
   * @param label text printed in front of the elapsed time
   * @param body  the computation to time (by-name, so it runs inside the
   *              measured window)
   * @tparam T    result type of `body`
   * @return      the result of evaluating `body`
   */
  def apply[T](label: String)(body: => T): T =
  {
    // nanoTime is monotonic; currentTimeMillis follows the wall clock and can
    // jump (even backwards) when the system time is adjusted mid-measurement.
    val start = System.nanoTime()
    val ret = body
    val elapsedMillis = (System.nanoTime() - start) / 1000000L

    // Same millisecond granularity and output format as before.
    println(s"$label: ${elapsedMillis.toFloat / 1000.0}s")

    ret
  }
}
|
|
@ -242,10 +242,10 @@ object Code
|
|||
val renderedElems = elems.map { _.render(indent+2) }
|
||||
val multiLine =
|
||||
renderedElems.exists { _.isInstanceOf[Lines] }
|
||||
|| renderedElems.map { _.asInstanceOf[PaddedString].body.size }.sum > 100
|
||||
|| renderedElems.map { _.asInstanceOf[PaddedString].body.size }.sum > 60
|
||||
if(multiLine){
|
||||
Lines(
|
||||
renderedElems.flatMap {
|
||||
renderedElems.dropRight(1).flatMap {
|
||||
case p:PaddedString =>
|
||||
Lines.ofLine(indent, p, sep).body
|
||||
case Lines(Seq()) =>
|
||||
|
@ -253,10 +253,20 @@ object Code
|
|||
case Lines(Seq(l)) =>
|
||||
Seq(PaddedString(l, sep).body)
|
||||
case Lines(l) =>
|
||||
Seq(l.head) ++
|
||||
l.tail.dropRight(1).map { " " + _ } ++
|
||||
Seq(PaddedString(" ",l.last,sep).body)
|
||||
}
|
||||
l.dropRight(1) ++
|
||||
Seq(PaddedString(l.last,sep).body)
|
||||
} ++
|
||||
(renderedElems.last match {
|
||||
case p:PaddedString =>
|
||||
Lines.ofLine(indent, p).body
|
||||
case Lines(Seq()) =>
|
||||
Lines.ofLine(indent).body
|
||||
case Lines(Seq(l)) =>
|
||||
Seq(PaddedString(l).body)
|
||||
case Lines(l) =>
|
||||
l.dropRight(1) ++
|
||||
Seq(PaddedString(l.last,sep).body)
|
||||
})
|
||||
)
|
||||
} else {
|
||||
renderedElems.map { _.asInstanceOf[PaddedString] } match {
|
||||
|
|
45
astral/compiler/src/com/astraldb/codegen/CodeScope.scala
Normal file
45
astral/compiler/src/com/astraldb/codegen/CodeScope.scala
Normal file
|
@ -0,0 +1,45 @@
|
|||
package com.astraldb.codegen
|
||||
|
||||
import com.astraldb.spec
|
||||
|
||||
/**
 * Immutable table of the variables visible while generating code.
 *
 * Each name maps either to just its type (the variable is referenced by its
 * own name), or to a pair of the code that should be emitted for the
 * variable together with its type.
 *
 * @param vars the bindings, keyed by variable name
 */
class CodeScope(vars: Map[String, spec.Type | (Code, spec.Type)])
{
  /** Extracts the type component of a binding. */
  private def typeOfBinding(binding: spec.Type | (Code, spec.Type)): spec.Type =
    binding match {
      case t:spec.Type => t
      case c:(Code,spec.Type) => c._2
    }

  /**
   * Builds a new scope by applying `f` to every binding.
   *
   * Uses `.view.mapValues(...).toMap` rather than the deprecated strict
   * `Map.mapValues`, and materializes the result so the new scope does not
   * re-run `f` on every lookup.
   */
  private def mapBindings(
    f: (spec.Type | (Code, spec.Type)) => (spec.Type | (Code, spec.Type))
  ): CodeScope =
    new CodeScope(vars.view.mapValues(f).toMap)

  /**
   * The type bound to `name`.
   *
   * @throws NoSuchElementException if `name` is not in scope
   */
  def typeOf(name: String): spec.Type =
    typeOfBinding(vars(name))

  /**
   * The code to emit for a reference to `name`: the bound code if there is
   * one, otherwise a literal consisting of the name itself.
   *
   * @throws NoSuchElementException if `name` is not in scope
   */
  def codeOf(name: String): Code =
    vars(name) match {
      case _:spec.Type => Code.Literal(name)
      case c:(Code,spec.Type) => c._1
    }

  /**
   * A scope extended with the given bindings; a new binding replaces an
   * existing one of the same name.
   */
  def withVars(nameVal: (String, spec.Type|(Code, spec.Type))*): CodeScope =
    new CodeScope(vars ++ nameVal.toMap)

  /**
   * A scope in which every binding whose code equals `code` keeps that code
   * but takes `updatedType`. Type-only bindings are left untouched.
   */
  def refine(code: Code, updatedType: spec.Type): CodeScope =
    mapBindings {
      case c:(Code,spec.Type) if c._1 == code => c._1 -> updatedType
      case v => v
    }

  /**
   * A scope in which every binding whose code equals `code` is replaced by
   * `updatedCode` with `updatedType`. Type-only bindings are left untouched.
   */
  def refine(code: Code, updatedCode: Code, updatedType: spec.Type): CodeScope =
    mapBindings {
      case c:(Code,spec.Type) if c._1 == code => updatedCode -> updatedType
      case v => v
    }

  /** The bindings reduced to name -> type, discarding any bound code. */
  def flatten: Map[String, spec.Type] =
    vars.view.mapValues(typeOfBinding).toMap

  /** Renders the scope as `{ name->type, name->code:type, ... }`. */
  override def toString: String =
    "{ " + vars.map { case (name, binding) =>
      val rendered = binding match {
        case t:spec.Type => t.toString
        case c:(Code,spec.Type) => c._1.toString+":"+c._2.toString
      }
      name + "->" + rendered
    }.mkString(", ") + " }"
}
|
|
@ -9,16 +9,24 @@ import com.astraldb.spec.Type
|
|||
|
||||
object Expression
|
||||
{
|
||||
def apply(schema: spec.Definition, op: Expr, scope: Map[String, spec.Type]): Code =
|
||||
|
||||
/**
 * Flattens a left-nested chain of the same arithmetic operator into its
 * operands, in order, so the chain can be rendered as `a op b op c`.
 *
 * Only the left operand is flattened. A flat `a op b op c` is read as
 * `(a op b) op c`, so unfolding the left spine never changes meaning, whereas
 * unfolding the right operand would turn `a op (b op c)` into
 * `(a op b) op c`, which is wrong for any non-associative operator.
 * NOTE(review): ArithTypes is not visible here; if every operator it defines
 * is associative, right-flattening would also be safe — confirm before
 * restoring it.
 *
 * @param op the expression to flatten
 * @param t  the operator whose chain is being collected
 * @return   the operands of the chain, or `Seq(op)` when `op` is not an
 *           `Arith` node of type `t`
 */
def assoc(op: Expr, t: ArithTypes.T): Seq[Expr] =
  op match {
    case Arith(at, l, r) if at == t =>
      assoc(l, t) :+ r
    case _ => Seq(op)
  }
|
||||
|
||||
def apply(schema: spec.Definition, op: Expr, scope: CodeScope): Code =
|
||||
{
|
||||
op match {
|
||||
case c:SimpleConstant => Code.Literal(c.asScala)
|
||||
case Var(v) => Code.Literal(v)
|
||||
case Var(v) => scope.codeOf(v)
|
||||
case Arith(t, lhs, rhs) =>
|
||||
Code.BinOp(
|
||||
Code.Parenthesize(apply(schema, lhs, scope)),
|
||||
Code.List(
|
||||
PaddedString.pad(1, ArithTypes.opString(t)),
|
||||
Code.Parenthesize(apply(schema, rhs, scope))
|
||||
assoc(op, t)
|
||||
.map { e => Code.Parenthesize(apply(schema, e, scope)) }
|
||||
)
|
||||
case Cmp(t, lhs, rhs) =>
|
||||
Code.BinOp(
|
||||
|
@ -28,10 +36,9 @@ object Expression
|
|||
)
|
||||
case FunctionCall(fn, args) =>
|
||||
Code.Literal(op.toString)
|
||||
Code.PrefixedBlock(
|
||||
prefix = Code.Literal(fn),
|
||||
lParen = "(",
|
||||
rParen = ")",
|
||||
Code.Parens(
|
||||
left = s"$fn(",
|
||||
right = ")",
|
||||
body = Code.List(
|
||||
sep = PaddedString.rightPad(1, ","),
|
||||
elems = args.map { apply(schema, _, scope) }
|
||||
|
@ -43,10 +50,14 @@ object Expression
|
|||
Code.BinOp(
|
||||
apply(schema, target, scope),
|
||||
".",
|
||||
TypecheckExpression(target, schema, scope) match {
|
||||
TypecheckExpression(target, schema, scope.flatten) match {
|
||||
case Type.Node(nodeType) =>
|
||||
val node =
|
||||
schema.nodesByName.get(nodeType)
|
||||
.getOrElse { assert(false, s"No node of type $nodeType: $op")}
|
||||
assert(node.fields.size > index, s"Node type $nodeType only has ${node.fields.size} fields (field $index requested): $op in $scope")
|
||||
Code.Literal(
|
||||
schema.nodesByName(nodeType).fields(index).name
|
||||
node.fields(index).name
|
||||
)
|
||||
case c =>
|
||||
assert(false, s"Node subscript on something not a node: $op (in ${scope}): ${c.getClass.getSimpleName}")
|
||||
|
|
|
@ -6,6 +6,7 @@ import spec.Match._
|
|||
import com.astraldb.codegen.Code.PaddedString
|
||||
import com.astraldb.typecheck.TypecheckMatch
|
||||
import com.astraldb.spec.Type
|
||||
import com.astraldb.typecheck.TypecheckExpression
|
||||
|
||||
object Match
|
||||
{
|
||||
|
@ -14,29 +15,36 @@ object Match
|
|||
pattern: spec.Match,
|
||||
target: Code,
|
||||
targetType: spec.Type,
|
||||
onSuccess: Code,
|
||||
onFail: Code,
|
||||
onSuccess: CodeScope => Code,
|
||||
onFail: CodeScope => Code,
|
||||
name: Option[String],
|
||||
scope: Map[String, spec.Type]
|
||||
scope: CodeScope
|
||||
): Code =
|
||||
{
|
||||
pattern match {
|
||||
case And(Seq()) => onSuccess
|
||||
case And(Seq()) => onSuccess(scope)
|
||||
case And(Seq(a)) => apply(schema, a, target, targetType, onSuccess, onFail, name, scope)
|
||||
case And(a) =>
|
||||
apply(schema, a.head,
|
||||
target = target,
|
||||
targetType = targetType,
|
||||
onSuccess =
|
||||
apply(schema, And(a.tail), target, targetType, onSuccess, onFail, name,
|
||||
scope = TypecheckMatch(a.head, name, targetType, schema, scope)
|
||||
),
|
||||
scope =>
|
||||
apply(schema, And(a.tail),
|
||||
target = target,
|
||||
targetType = targetType,
|
||||
onSuccess = onSuccess,
|
||||
onFail = onFail,
|
||||
name = name,
|
||||
scope = scope
|
||||
),
|
||||
onFail = onFail,
|
||||
name = name,
|
||||
scope = scope
|
||||
)
|
||||
case Not(a) => apply(schema, a, target, targetType, onFail, onSuccess, name, scope)
|
||||
case Or(Seq()) => onFail
|
||||
case Not(a) =>
|
||||
apply(schema, a, target, targetType, onFail, onSuccess, name, scope)
|
||||
case Or(Seq()) => onFail(scope)
|
||||
case Or(Seq(a)) =>
|
||||
apply(schema, a, target, targetType, onSuccess, onFail, name, scope)
|
||||
case Or(a) =>
|
||||
|
@ -44,27 +52,28 @@ object Match
|
|||
target = target,
|
||||
targetType = targetType,
|
||||
onSuccess = onSuccess,
|
||||
onFail = apply(schema, Or(a.tail), target, targetType, onSuccess, onFail, name, scope),
|
||||
onFail =
|
||||
_ => apply(schema, Or(a.tail),
|
||||
target = target,
|
||||
targetType = targetType,
|
||||
onSuccess = onSuccess,
|
||||
onFail = onFail,
|
||||
name = name,
|
||||
scope = scope
|
||||
),
|
||||
name = name,
|
||||
scope = scope
|
||||
)
|
||||
case Bind(symbol, pattern) =>
|
||||
Code.Block(
|
||||
Code.BinOp(
|
||||
Code.Literal(s"val $symbol"),
|
||||
Code.PaddedString.pad(1, "="),
|
||||
target
|
||||
) +:
|
||||
apply(
|
||||
schema = schema,
|
||||
pattern = pattern,
|
||||
target = Code.Literal(symbol),
|
||||
targetType = targetType,
|
||||
onSuccess = onSuccess,
|
||||
onFail = onFail,
|
||||
name = Some(symbol),
|
||||
scope = scope ++ Map(symbol -> targetType)
|
||||
).block
|
||||
apply(
|
||||
schema = schema,
|
||||
pattern = pattern,
|
||||
target = target,
|
||||
targetType = targetType,
|
||||
onSuccess = onSuccess,
|
||||
onFail = onFail,
|
||||
name = Some(symbol),
|
||||
scope = scope.withVars(symbol -> (target, targetType))
|
||||
)
|
||||
case BindExpression(symbol, op) =>
|
||||
Code.Block(
|
||||
|
@ -74,13 +83,19 @@ object Match
|
|||
Code.PaddedString.pad(1, "="),
|
||||
Expression(schema, op, scope)
|
||||
)
|
||||
) ++ onSuccess.block
|
||||
) ++ onSuccess(
|
||||
scope.withVars(symbol -> (Code.Literal(symbol),
|
||||
TypecheckExpression(op, schema, scope.flatten)
|
||||
))
|
||||
).block
|
||||
)
|
||||
case Node(nodeLabel, children) =>
|
||||
{
|
||||
val selectedName =
|
||||
name.getOrElse { "genericNode" }
|
||||
val node = schema.nodesByName(nodeLabel)
|
||||
val selectedName =
|
||||
name.getOrElse { "genericNode" }+"_t"
|
||||
val updatedScope =
|
||||
scope.refine(target, Code.Literal(selectedName), Type.Node(nodeLabel))
|
||||
Code.IfThenElse(
|
||||
condition =
|
||||
Code.BinOp(
|
||||
|
@ -104,33 +119,33 @@ object Match
|
|||
children
|
||||
.zip(node.fields)
|
||||
.foldRight(onSuccess) { case ((child, field), andThen) =>
|
||||
apply(
|
||||
nextScope => apply(
|
||||
schema = schema,
|
||||
pattern = child,
|
||||
target = Code.Literal(s"$selectedName.${field.name}"),
|
||||
targetType = Type.Node(nodeLabel),
|
||||
targetType = field.t,
|
||||
onSuccess = andThen,
|
||||
onFail = onFail,
|
||||
name = Some(s"${selectedName}_${field.name}"),
|
||||
scope = scope ++ Map(selectedName -> Type.Node(nodeLabel))
|
||||
scope = nextScope
|
||||
)
|
||||
}
|
||||
}(updatedScope)
|
||||
.block
|
||||
),
|
||||
elseBlock = onFail
|
||||
elseBlock = onFail(scope)
|
||||
)
|
||||
}
|
||||
case Test(op) =>
|
||||
Code.IfThenElse(
|
||||
condition = Expression(schema, op, scope),
|
||||
thenBlock = onSuccess,
|
||||
elseBlock = onFail
|
||||
thenBlock = onSuccess(scope),
|
||||
elseBlock = onFail(scope)
|
||||
)
|
||||
case Any =>
|
||||
onSuccess
|
||||
onSuccess(scope)
|
||||
case OfType(nodeType) =>
|
||||
val selectedName =
|
||||
name.getOrElse { "genericNode" }
|
||||
name.getOrElse { "genericNode" } + "_t"
|
||||
Code.IfThenElse(
|
||||
condition = Code.BinOp(target, ".", Code.Literal(s"isInstanceOf[${nodeType.scalaType}]")),
|
||||
thenBlock =
|
||||
|
@ -145,11 +160,13 @@ object Match
|
|||
Code.Literal(s"asInstanceOf[${nodeType.scalaType}]")
|
||||
)
|
||||
)
|
||||
)++onSuccess.block
|
||||
)++onSuccess(scope.refine(target, Code.Literal(selectedName), nodeType)).block
|
||||
),
|
||||
elseBlock = onFail
|
||||
elseBlock = onFail(scope)
|
||||
)
|
||||
case _ => Code.Literal(s"??${pattern.getClass.getSimpleName}??")
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
}
|
|
@ -5,9 +5,8 @@
|
|||
object Optimizer
|
||||
{
|
||||
val rules = Seq[Rule[LogicalPlan]](
|
||||
@for(rule <- ctx.rules){
|
||||
@rule.safeLabel,
|
||||
}
|
||||
@for(rule <- ctx.rules){
|
||||
@rule.safeLabel, }
|
||||
)
|
||||
|
||||
def MAX_ITERATIONS = 100
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
@import com.astraldb.codegen.Match
|
||||
@import com.astraldb.codegen.Expression
|
||||
@import com.astraldb.codegen.Code
|
||||
@import com.astraldb.codegen.CodeScope
|
||||
@import com.astraldb.typecheck.TypecheckMatch
|
||||
|
||||
@(schema: Definition, rule: Rule)
|
||||
|
@ -12,23 +13,24 @@ object @{rule.safeLabel} extends Rule[LogicalPlan]
|
|||
{
|
||||
def apply(plan: LogicalPlan): LogicalPlan =
|
||||
{
|
||||
@Match(
|
||||
schema = schema,
|
||||
pattern = rule.pattern,
|
||||
target = Code.Literal("plan"),
|
||||
targetType = Type.AST(rule.family),
|
||||
onSuccess = Expression(schema, rule.rewrite,
|
||||
TypecheckMatch(
|
||||
@{
|
||||
val matchSchema = TypecheckMatch(
|
||||
rule.pattern,
|
||||
Some("plan"),
|
||||
Type.AST(rule.family),
|
||||
schema,
|
||||
schema.globals
|
||||
)
|
||||
),
|
||||
onFail = Code.Literal("plan"),
|
||||
}
|
||||
@Match(
|
||||
schema = schema,
|
||||
pattern = rule.pattern,
|
||||
target = Code.Literal("plan"),
|
||||
targetType = Type.AST(rule.family),
|
||||
onSuccess = Expression(schema, rule.rewrite, _),
|
||||
onFail = { _ => Code.Literal("plan") },
|
||||
name = Some("plan"),
|
||||
scope = schema.globals
|
||||
scope = CodeScope(schema.globals)
|
||||
).toString(4).stripPrefix(" ")
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue