Compiler compiles
This commit is contained in:
parent
9348b1cbc1
commit
8ac8bfec8b
|
@ -1,9 +1,43 @@
|
||||||
package com.astraldb.catalyst
|
package com.astraldb.catalyst
|
||||||
|
|
||||||
|
import org.apache.spark.sql.{ Row, SparkSession }
|
||||||
|
import org.apache.spark.sql.types._
|
||||||
|
import org.apache.spark.sql.functions._
|
||||||
|
|
||||||
object OptimizerTest
|
object OptimizerTest
|
||||||
{
|
{
|
||||||
def main(args: Array[String]): Unit =
|
def main(args: Array[String]): Unit =
|
||||||
{
|
{
|
||||||
println("Hello world")
|
val spark: SparkSession =
|
||||||
|
SparkSession.builder
|
||||||
|
.appName("OptimizerTest")
|
||||||
|
.master("local")
|
||||||
|
.getOrCreate()
|
||||||
|
val r = spark.emptyDataFrame
|
||||||
|
.select(lit(1) as "A", lit(2) as "B")
|
||||||
|
r.createOrReplaceTempView("R")
|
||||||
|
val df = spark.sql("SELECT * FROM R")
|
||||||
|
|
||||||
|
println(df.queryExecution.logical)
|
||||||
|
println(df.queryExecution.analyzed)
|
||||||
|
|
||||||
|
println("------------------")
|
||||||
|
|
||||||
|
val optimized =
|
||||||
|
Time("Astral Optimizer") {
|
||||||
|
Optimizer.rewrite(df.queryExecution.analyzed)
|
||||||
|
}
|
||||||
|
|
||||||
|
println("------------------\nAstral Optimized Query:\n")
|
||||||
|
println(optimized)
|
||||||
|
println("------------------")
|
||||||
|
|
||||||
|
val sparkOptimized =
|
||||||
|
Time("Spark Optimizer") {
|
||||||
|
df.queryExecution.optimizedPlan
|
||||||
|
}
|
||||||
|
println("------------------\nSpark Optimized Query:\n")
|
||||||
|
println(sparkOptimized)
|
||||||
|
println("------------------")
|
||||||
}
|
}
|
||||||
}
|
}
|
15
astral/catalyst/impl/src/com/astraldb/catalyst/Time.scala
Normal file
15
astral/catalyst/impl/src/com/astraldb/catalyst/Time.scala
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
package com.astraldb.catalyst
|
||||||
|
|
||||||
|
object Time
|
||||||
|
{
|
||||||
|
def apply[T](label: String)(body: => T): T =
|
||||||
|
{
|
||||||
|
val start = System.currentTimeMillis()
|
||||||
|
val ret = body
|
||||||
|
val end = System.currentTimeMillis()
|
||||||
|
|
||||||
|
println(s"$label: ${(end - start).toFloat / 1000.0}s")
|
||||||
|
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
}
|
|
@ -242,10 +242,10 @@ object Code
|
||||||
val renderedElems = elems.map { _.render(indent+2) }
|
val renderedElems = elems.map { _.render(indent+2) }
|
||||||
val multiLine =
|
val multiLine =
|
||||||
renderedElems.exists { _.isInstanceOf[Lines] }
|
renderedElems.exists { _.isInstanceOf[Lines] }
|
||||||
|| renderedElems.map { _.asInstanceOf[PaddedString].body.size }.sum > 100
|
|| renderedElems.map { _.asInstanceOf[PaddedString].body.size }.sum > 60
|
||||||
if(multiLine){
|
if(multiLine){
|
||||||
Lines(
|
Lines(
|
||||||
renderedElems.flatMap {
|
renderedElems.dropRight(1).flatMap {
|
||||||
case p:PaddedString =>
|
case p:PaddedString =>
|
||||||
Lines.ofLine(indent, p, sep).body
|
Lines.ofLine(indent, p, sep).body
|
||||||
case Lines(Seq()) =>
|
case Lines(Seq()) =>
|
||||||
|
@ -253,10 +253,20 @@ object Code
|
||||||
case Lines(Seq(l)) =>
|
case Lines(Seq(l)) =>
|
||||||
Seq(PaddedString(l, sep).body)
|
Seq(PaddedString(l, sep).body)
|
||||||
case Lines(l) =>
|
case Lines(l) =>
|
||||||
Seq(l.head) ++
|
l.dropRight(1) ++
|
||||||
l.tail.dropRight(1).map { " " + _ } ++
|
Seq(PaddedString(l.last,sep).body)
|
||||||
Seq(PaddedString(" ",l.last,sep).body)
|
} ++
|
||||||
}
|
(renderedElems.last match {
|
||||||
|
case p:PaddedString =>
|
||||||
|
Lines.ofLine(indent, p).body
|
||||||
|
case Lines(Seq()) =>
|
||||||
|
Lines.ofLine(indent).body
|
||||||
|
case Lines(Seq(l)) =>
|
||||||
|
Seq(PaddedString(l).body)
|
||||||
|
case Lines(l) =>
|
||||||
|
l.dropRight(1) ++
|
||||||
|
Seq(PaddedString(l.last,sep).body)
|
||||||
|
})
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
renderedElems.map { _.asInstanceOf[PaddedString] } match {
|
renderedElems.map { _.asInstanceOf[PaddedString] } match {
|
||||||
|
|
45
astral/compiler/src/com/astraldb/codegen/CodeScope.scala
Normal file
45
astral/compiler/src/com/astraldb/codegen/CodeScope.scala
Normal file
|
@ -0,0 +1,45 @@
|
||||||
|
package com.astraldb.codegen
|
||||||
|
|
||||||
|
import com.astraldb.spec
|
||||||
|
|
||||||
|
class CodeScope(vars: Map[String, spec.Type | (Code, spec.Type)])
|
||||||
|
{
|
||||||
|
def typeOf(name: String): spec.Type =
|
||||||
|
vars(name) match {
|
||||||
|
case t:spec.Type => t
|
||||||
|
case c:(Code,spec.Type) => c._2
|
||||||
|
}
|
||||||
|
|
||||||
|
def codeOf(name: String): Code =
|
||||||
|
vars(name) match {
|
||||||
|
case t:spec.Type => Code.Literal(name)
|
||||||
|
case c:(Code,spec.Type) => c._1
|
||||||
|
}
|
||||||
|
|
||||||
|
def withVars(nameVal: (String, spec.Type|(Code, spec.Type))*): CodeScope =
|
||||||
|
new CodeScope(vars ++ nameVal.toMap)
|
||||||
|
|
||||||
|
def refine(code: Code, updatedType: spec.Type): CodeScope =
|
||||||
|
new CodeScope(vars.mapValues {
|
||||||
|
case c:(Code,spec.Type) if c._1 == code => c._1 -> updatedType
|
||||||
|
case v => v
|
||||||
|
}.toMap)
|
||||||
|
|
||||||
|
def refine(code: Code, updatedCode: Code, updatedType: spec.Type): CodeScope =
|
||||||
|
new CodeScope(vars.mapValues {
|
||||||
|
case c:(Code,spec.Type) if c._1 == code => updatedCode -> updatedType
|
||||||
|
case v => v
|
||||||
|
}.toMap)
|
||||||
|
|
||||||
|
def flatten: Map[String, spec.Type] =
|
||||||
|
vars.mapValues {
|
||||||
|
case t:spec.Type => t
|
||||||
|
case c:(Code,spec.Type) => c._2
|
||||||
|
}.toMap
|
||||||
|
|
||||||
|
override def toString: String =
|
||||||
|
"{ " + vars.map { v => v._1 + "->" + (v._2 match {
|
||||||
|
case t:spec.Type => t.toString
|
||||||
|
case c:(Code,spec.Type) => c._1.toString+":"+c._2.toString
|
||||||
|
})}.mkString(", ") + " }"
|
||||||
|
}
|
|
@ -9,16 +9,24 @@ import com.astraldb.spec.Type
|
||||||
|
|
||||||
object Expression
|
object Expression
|
||||||
{
|
{
|
||||||
def apply(schema: spec.Definition, op: Expr, scope: Map[String, spec.Type]): Code =
|
|
||||||
|
def assoc(op: Expr, t: ArithTypes.T): Seq[Expr] =
|
||||||
|
op match {
|
||||||
|
case Arith(at, l, r) if at == t =>
|
||||||
|
assoc(l, t) ++ assoc(r, t)
|
||||||
|
case _ => Seq(op)
|
||||||
|
}
|
||||||
|
|
||||||
|
def apply(schema: spec.Definition, op: Expr, scope: CodeScope): Code =
|
||||||
{
|
{
|
||||||
op match {
|
op match {
|
||||||
case c:SimpleConstant => Code.Literal(c.asScala)
|
case c:SimpleConstant => Code.Literal(c.asScala)
|
||||||
case Var(v) => Code.Literal(v)
|
case Var(v) => scope.codeOf(v)
|
||||||
case Arith(t, lhs, rhs) =>
|
case Arith(t, lhs, rhs) =>
|
||||||
Code.BinOp(
|
Code.List(
|
||||||
Code.Parenthesize(apply(schema, lhs, scope)),
|
|
||||||
PaddedString.pad(1, ArithTypes.opString(t)),
|
PaddedString.pad(1, ArithTypes.opString(t)),
|
||||||
Code.Parenthesize(apply(schema, rhs, scope))
|
assoc(op, t)
|
||||||
|
.map { e => Code.Parenthesize(apply(schema, e, scope)) }
|
||||||
)
|
)
|
||||||
case Cmp(t, lhs, rhs) =>
|
case Cmp(t, lhs, rhs) =>
|
||||||
Code.BinOp(
|
Code.BinOp(
|
||||||
|
@ -28,10 +36,9 @@ object Expression
|
||||||
)
|
)
|
||||||
case FunctionCall(fn, args) =>
|
case FunctionCall(fn, args) =>
|
||||||
Code.Literal(op.toString)
|
Code.Literal(op.toString)
|
||||||
Code.PrefixedBlock(
|
Code.Parens(
|
||||||
prefix = Code.Literal(fn),
|
left = s"$fn(",
|
||||||
lParen = "(",
|
right = ")",
|
||||||
rParen = ")",
|
|
||||||
body = Code.List(
|
body = Code.List(
|
||||||
sep = PaddedString.rightPad(1, ","),
|
sep = PaddedString.rightPad(1, ","),
|
||||||
elems = args.map { apply(schema, _, scope) }
|
elems = args.map { apply(schema, _, scope) }
|
||||||
|
@ -43,10 +50,14 @@ object Expression
|
||||||
Code.BinOp(
|
Code.BinOp(
|
||||||
apply(schema, target, scope),
|
apply(schema, target, scope),
|
||||||
".",
|
".",
|
||||||
TypecheckExpression(target, schema, scope) match {
|
TypecheckExpression(target, schema, scope.flatten) match {
|
||||||
case Type.Node(nodeType) =>
|
case Type.Node(nodeType) =>
|
||||||
|
val node =
|
||||||
|
schema.nodesByName.get(nodeType)
|
||||||
|
.getOrElse { assert(false, s"No node of type $nodeType: $op")}
|
||||||
|
assert(node.fields.size > index, s"Node type $nodeType only has ${node.fields.size} fields (field $index requested): $op in $scope")
|
||||||
Code.Literal(
|
Code.Literal(
|
||||||
schema.nodesByName(nodeType).fields(index).name
|
node.fields(index).name
|
||||||
)
|
)
|
||||||
case c =>
|
case c =>
|
||||||
assert(false, s"Node subscript on something not a node: $op (in ${scope}): ${c.getClass.getSimpleName}")
|
assert(false, s"Node subscript on something not a node: $op (in ${scope}): ${c.getClass.getSimpleName}")
|
||||||
|
|
|
@ -6,6 +6,7 @@ import spec.Match._
|
||||||
import com.astraldb.codegen.Code.PaddedString
|
import com.astraldb.codegen.Code.PaddedString
|
||||||
import com.astraldb.typecheck.TypecheckMatch
|
import com.astraldb.typecheck.TypecheckMatch
|
||||||
import com.astraldb.spec.Type
|
import com.astraldb.spec.Type
|
||||||
|
import com.astraldb.typecheck.TypecheckExpression
|
||||||
|
|
||||||
object Match
|
object Match
|
||||||
{
|
{
|
||||||
|
@ -14,29 +15,36 @@ object Match
|
||||||
pattern: spec.Match,
|
pattern: spec.Match,
|
||||||
target: Code,
|
target: Code,
|
||||||
targetType: spec.Type,
|
targetType: spec.Type,
|
||||||
onSuccess: Code,
|
onSuccess: CodeScope => Code,
|
||||||
onFail: Code,
|
onFail: CodeScope => Code,
|
||||||
name: Option[String],
|
name: Option[String],
|
||||||
scope: Map[String, spec.Type]
|
scope: CodeScope
|
||||||
): Code =
|
): Code =
|
||||||
{
|
{
|
||||||
pattern match {
|
pattern match {
|
||||||
case And(Seq()) => onSuccess
|
case And(Seq()) => onSuccess(scope)
|
||||||
case And(Seq(a)) => apply(schema, a, target, targetType, onSuccess, onFail, name, scope)
|
case And(Seq(a)) => apply(schema, a, target, targetType, onSuccess, onFail, name, scope)
|
||||||
case And(a) =>
|
case And(a) =>
|
||||||
apply(schema, a.head,
|
apply(schema, a.head,
|
||||||
target = target,
|
target = target,
|
||||||
targetType = targetType,
|
targetType = targetType,
|
||||||
onSuccess =
|
onSuccess =
|
||||||
apply(schema, And(a.tail), target, targetType, onSuccess, onFail, name,
|
scope =>
|
||||||
scope = TypecheckMatch(a.head, name, targetType, schema, scope)
|
apply(schema, And(a.tail),
|
||||||
|
target = target,
|
||||||
|
targetType = targetType,
|
||||||
|
onSuccess = onSuccess,
|
||||||
|
onFail = onFail,
|
||||||
|
name = name,
|
||||||
|
scope = scope
|
||||||
),
|
),
|
||||||
onFail = onFail,
|
onFail = onFail,
|
||||||
name = name,
|
name = name,
|
||||||
scope = scope
|
scope = scope
|
||||||
)
|
)
|
||||||
case Not(a) => apply(schema, a, target, targetType, onFail, onSuccess, name, scope)
|
case Not(a) =>
|
||||||
case Or(Seq()) => onFail
|
apply(schema, a, target, targetType, onFail, onSuccess, name, scope)
|
||||||
|
case Or(Seq()) => onFail(scope)
|
||||||
case Or(Seq(a)) =>
|
case Or(Seq(a)) =>
|
||||||
apply(schema, a, target, targetType, onSuccess, onFail, name, scope)
|
apply(schema, a, target, targetType, onSuccess, onFail, name, scope)
|
||||||
case Or(a) =>
|
case Or(a) =>
|
||||||
|
@ -44,27 +52,28 @@ object Match
|
||||||
target = target,
|
target = target,
|
||||||
targetType = targetType,
|
targetType = targetType,
|
||||||
onSuccess = onSuccess,
|
onSuccess = onSuccess,
|
||||||
onFail = apply(schema, Or(a.tail), target, targetType, onSuccess, onFail, name, scope),
|
onFail =
|
||||||
|
_ => apply(schema, Or(a.tail),
|
||||||
|
target = target,
|
||||||
|
targetType = targetType,
|
||||||
|
onSuccess = onSuccess,
|
||||||
|
onFail = onFail,
|
||||||
|
name = name,
|
||||||
|
scope = scope
|
||||||
|
),
|
||||||
name = name,
|
name = name,
|
||||||
scope = scope
|
scope = scope
|
||||||
)
|
)
|
||||||
case Bind(symbol, pattern) =>
|
case Bind(symbol, pattern) =>
|
||||||
Code.Block(
|
|
||||||
Code.BinOp(
|
|
||||||
Code.Literal(s"val $symbol"),
|
|
||||||
Code.PaddedString.pad(1, "="),
|
|
||||||
target
|
|
||||||
) +:
|
|
||||||
apply(
|
apply(
|
||||||
schema = schema,
|
schema = schema,
|
||||||
pattern = pattern,
|
pattern = pattern,
|
||||||
target = Code.Literal(symbol),
|
target = target,
|
||||||
targetType = targetType,
|
targetType = targetType,
|
||||||
onSuccess = onSuccess,
|
onSuccess = onSuccess,
|
||||||
onFail = onFail,
|
onFail = onFail,
|
||||||
name = Some(symbol),
|
name = Some(symbol),
|
||||||
scope = scope ++ Map(symbol -> targetType)
|
scope = scope.withVars(symbol -> (target, targetType))
|
||||||
).block
|
|
||||||
)
|
)
|
||||||
case BindExpression(symbol, op) =>
|
case BindExpression(symbol, op) =>
|
||||||
Code.Block(
|
Code.Block(
|
||||||
|
@ -74,13 +83,19 @@ object Match
|
||||||
Code.PaddedString.pad(1, "="),
|
Code.PaddedString.pad(1, "="),
|
||||||
Expression(schema, op, scope)
|
Expression(schema, op, scope)
|
||||||
)
|
)
|
||||||
) ++ onSuccess.block
|
) ++ onSuccess(
|
||||||
|
scope.withVars(symbol -> (Code.Literal(symbol),
|
||||||
|
TypecheckExpression(op, schema, scope.flatten)
|
||||||
|
))
|
||||||
|
).block
|
||||||
)
|
)
|
||||||
case Node(nodeLabel, children) =>
|
case Node(nodeLabel, children) =>
|
||||||
{
|
{
|
||||||
val selectedName =
|
|
||||||
name.getOrElse { "genericNode" }
|
|
||||||
val node = schema.nodesByName(nodeLabel)
|
val node = schema.nodesByName(nodeLabel)
|
||||||
|
val selectedName =
|
||||||
|
name.getOrElse { "genericNode" }+"_t"
|
||||||
|
val updatedScope =
|
||||||
|
scope.refine(target, Code.Literal(selectedName), Type.Node(nodeLabel))
|
||||||
Code.IfThenElse(
|
Code.IfThenElse(
|
||||||
condition =
|
condition =
|
||||||
Code.BinOp(
|
Code.BinOp(
|
||||||
|
@ -104,33 +119,33 @@ object Match
|
||||||
children
|
children
|
||||||
.zip(node.fields)
|
.zip(node.fields)
|
||||||
.foldRight(onSuccess) { case ((child, field), andThen) =>
|
.foldRight(onSuccess) { case ((child, field), andThen) =>
|
||||||
apply(
|
nextScope => apply(
|
||||||
schema = schema,
|
schema = schema,
|
||||||
pattern = child,
|
pattern = child,
|
||||||
target = Code.Literal(s"$selectedName.${field.name}"),
|
target = Code.Literal(s"$selectedName.${field.name}"),
|
||||||
targetType = Type.Node(nodeLabel),
|
targetType = field.t,
|
||||||
onSuccess = andThen,
|
onSuccess = andThen,
|
||||||
onFail = onFail,
|
onFail = onFail,
|
||||||
name = Some(s"${selectedName}_${field.name}"),
|
name = Some(s"${selectedName}_${field.name}"),
|
||||||
scope = scope ++ Map(selectedName -> Type.Node(nodeLabel))
|
scope = nextScope
|
||||||
)
|
)
|
||||||
}
|
}(updatedScope)
|
||||||
.block
|
.block
|
||||||
),
|
),
|
||||||
elseBlock = onFail
|
elseBlock = onFail(scope)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
case Test(op) =>
|
case Test(op) =>
|
||||||
Code.IfThenElse(
|
Code.IfThenElse(
|
||||||
condition = Expression(schema, op, scope),
|
condition = Expression(schema, op, scope),
|
||||||
thenBlock = onSuccess,
|
thenBlock = onSuccess(scope),
|
||||||
elseBlock = onFail
|
elseBlock = onFail(scope)
|
||||||
)
|
)
|
||||||
case Any =>
|
case Any =>
|
||||||
onSuccess
|
onSuccess(scope)
|
||||||
case OfType(nodeType) =>
|
case OfType(nodeType) =>
|
||||||
val selectedName =
|
val selectedName =
|
||||||
name.getOrElse { "genericNode" }
|
name.getOrElse { "genericNode" } + "_t"
|
||||||
Code.IfThenElse(
|
Code.IfThenElse(
|
||||||
condition = Code.BinOp(target, ".", Code.Literal(s"isInstanceOf[${nodeType.scalaType}]")),
|
condition = Code.BinOp(target, ".", Code.Literal(s"isInstanceOf[${nodeType.scalaType}]")),
|
||||||
thenBlock =
|
thenBlock =
|
||||||
|
@ -145,11 +160,13 @@ object Match
|
||||||
Code.Literal(s"asInstanceOf[${nodeType.scalaType}]")
|
Code.Literal(s"asInstanceOf[${nodeType.scalaType}]")
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)++onSuccess.block
|
)++onSuccess(scope.refine(target, Code.Literal(selectedName), nodeType)).block
|
||||||
),
|
),
|
||||||
elseBlock = onFail
|
elseBlock = onFail(scope)
|
||||||
)
|
)
|
||||||
case _ => Code.Literal(s"??${pattern.getClass.getSimpleName}??")
|
}
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -6,8 +6,7 @@ object Optimizer
|
||||||
{
|
{
|
||||||
val rules = Seq[Rule[LogicalPlan]](
|
val rules = Seq[Rule[LogicalPlan]](
|
||||||
@for(rule <- ctx.rules){
|
@for(rule <- ctx.rules){
|
||||||
@rule.safeLabel,
|
@rule.safeLabel, }
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def MAX_ITERATIONS = 100
|
def MAX_ITERATIONS = 100
|
||||||
|
|
|
@ -4,6 +4,7 @@
|
||||||
@import com.astraldb.codegen.Match
|
@import com.astraldb.codegen.Match
|
||||||
@import com.astraldb.codegen.Expression
|
@import com.astraldb.codegen.Expression
|
||||||
@import com.astraldb.codegen.Code
|
@import com.astraldb.codegen.Code
|
||||||
|
@import com.astraldb.codegen.CodeScope
|
||||||
@import com.astraldb.typecheck.TypecheckMatch
|
@import com.astraldb.typecheck.TypecheckMatch
|
||||||
|
|
||||||
@(schema: Definition, rule: Rule)
|
@(schema: Definition, rule: Rule)
|
||||||
|
@ -12,23 +13,24 @@ object @{rule.safeLabel} extends Rule[LogicalPlan]
|
||||||
{
|
{
|
||||||
def apply(plan: LogicalPlan): LogicalPlan =
|
def apply(plan: LogicalPlan): LogicalPlan =
|
||||||
{
|
{
|
||||||
@Match(
|
@{
|
||||||
schema = schema,
|
val matchSchema = TypecheckMatch(
|
||||||
pattern = rule.pattern,
|
|
||||||
target = Code.Literal("plan"),
|
|
||||||
targetType = Type.AST(rule.family),
|
|
||||||
onSuccess = Expression(schema, rule.rewrite,
|
|
||||||
TypecheckMatch(
|
|
||||||
rule.pattern,
|
rule.pattern,
|
||||||
Some("plan"),
|
Some("plan"),
|
||||||
Type.AST(rule.family),
|
Type.AST(rule.family),
|
||||||
schema,
|
schema,
|
||||||
schema.globals
|
schema.globals
|
||||||
)
|
)
|
||||||
),
|
}
|
||||||
onFail = Code.Literal("plan"),
|
@Match(
|
||||||
|
schema = schema,
|
||||||
|
pattern = rule.pattern,
|
||||||
|
target = Code.Literal("plan"),
|
||||||
|
targetType = Type.AST(rule.family),
|
||||||
|
onSuccess = Expression(schema, rule.rewrite, _),
|
||||||
|
onFail = { _ => Code.Literal("plan") },
|
||||||
name = Some("plan"),
|
name = Some("plan"),
|
||||||
scope = schema.globals
|
scope = CodeScope(schema.globals)
|
||||||
).toString(4).stripPrefix(" ")
|
).toString(4).stripPrefix(" ")
|
||||||
}
|
}
|
||||||
}
|
}
|
Loading…
Reference in a new issue