Compare commits

...

4 Commits

Author          SHA1        Message                    Date
Nick Brown      25a138139a  merge of formatting.       2023-08-24 09:31:23 -04:00
Nick Brown      8db885c970  slight update to wording.  2023-08-24 09:28:58 -04:00
Nick Brown      c90b0a4953  initial documentation      2023-08-24 09:28:57 -04:00
Oliver Kennedy  fe18eeac08  Final tests                2023-07-20 10:08:06 -04:00
12 changed files with 310 additions and 52 deletions

View File

@@ -14,19 +14,19 @@ You can install Scala and Mill using [Coursier](https://get-coursier.io/docs/cli
## Quick Guide
To clean the project:
`make clean` or `mill clean`
`make clean`, or `mill clean`
To compile the ASTral Catalyst Optimizer:
`make`, `make compile` or `mill astral.catalyst.impl.compile`
`make`, or `make compile`, or `mill astral.catalyst.impl.compile`
To run a basic sanity check on the ASTral Catalyst Optimizer:
`make check` or `mill astral.catalyst.impl.run`
`make check`, or `mill astral.catalyst.impl.run`
To run a subset of the TPCH Benchmark across ASTral, Spark with ASTral's rule set, and Spark:
`make test` or `mill astral.catalyst.impl.runMain com.astraldb.catalyst.TPCHTest`
`make test`, or `mill astral.catalyst.impl.runMain com.astraldb.catalyst.TPCHTest`
*Append `-w` to `mill` in any of the above commands to put `mill` into watch mode, rebuilding incrementally as sources change*
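For example, `mill -w astral.catalyst.impl.compile` recompiles the optimizer whenever a source file changes.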
## Language
- [Specification](docs/LANG.md)
- [Specification](docs/LANG.md) (TODO)

View File

@@ -512,4 +512,37 @@ object SparkMethods
case j: ExistenceJoin => true
case _ => false
}
def combineReferences(left: LogicalPlan, right: LogicalPlan): AttributeSet =
left.references ++ right.references
def inputSetIsSubsetOfRefs(child: LogicalPlan, required: AttributeSet): Boolean =
child.inputSet.subsetOf(required)
def ColumnPruningTwentyReplaceChildren(child: LogicalPlan, required: AttributeSet): LogicalPlan =
{
val newChildren = child.children.map(c => prunedChild(c, required))
// if(child.isInstanceOf[Join] && (child.references.map { _.exprId }.toSeq contains ExprId(51))){
// println(s"Pruning Join: ${child.references}: $required\n${newChildren.mkString("\n-------\n")}\n========")}
child.withNewChildren(newChildren)
}
def outputSetIsSubset(p2: LogicalPlan, child: LogicalPlan): Boolean =
p2.outputSet.subsetOf(child.outputSet)
def projectIsOnlyOfAttributes(p2: Project): Boolean =
p2.projectList.forall(_.isInstanceOf[AttributeReference])
def hasConflictingAttrsWithSubquery(predicate: Expression, child: LogicalPlan): Boolean =
predicate.find {
case s: SubqueryExpression if s.plan.outputSet.intersect(child.outputSet).nonEmpty => true
case _ => false
}.isDefined
def LeftSemiOrAntiJoin(t: JoinType): Boolean =
t match {
case LeftSemi | LeftAnti => true
case _ => false
}
}
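These helpers are registered as `Function`s in the Catalyst spec further down; a hedged usage sketch (the variable names `left`, `right`, and `child` are assumed here):

    val required = combineReferences(left, right)
    if (!inputSetIsSubsetOfRefs(child, required))
      ColumnPruningTwentyReplaceChildren(child, required)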

View File

@@ -52,18 +52,27 @@ object TPCHTest
Tester.dumpStats()
for((label, query) <- QUERIES.drop(1).take(1))
{
println("===================================")
println(s"TPCH Query $label")
println("===================================")
val result = Tester.test(spark.sql(query))
println(result.summary)
val targets = QUERIES
// val targets = Seq(QUERIES(0))
// Tester.testAlignment(spark.sql(query))
val allResults =
for((label, query) <- targets)
yield {
println("===================================")
println(s"TPCH Query $label")
println("===================================")
val result = Tester.test(spark.sql(query))
println(result.summary)
result.validate(false)
}
// Tester.testAlignment(spark.sql(query))
// result.validate(false)
(s"tpch_$label", result)
}
Tester.writeResultCSV("tpch.csv", allResults)
}
}

View File

@@ -5,27 +5,37 @@ import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import com.astraldb.catalyst.lesserspark.{ Optimizer => LesserSparkOptimizer }
import java.io._
object Tester
{
val LINE_SEP = "----------------"
val BURN_IN_ROUNDS = 2
val BURN_IN_ROUNDS = 4
val TRIALS = 10
case class Result(plan: LogicalPlan, runtime: Double)
case class Result(plan: LogicalPlan, iterations: Int, runtimes: Seq[Double])
{
val runtime = runtimes.sum / runtimes.size
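// population standard deviation, via the identity Var(X) = E[X²] − E[X]²
// (one pass over runtimes; can lose precision when the variance is small
// relative to the mean)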
val stddev =
Math.sqrt(
runtimes.map { x => x * x }.sum / runtimes.size - runtime * runtime
)
override def toString(): String =
s"$LINE_SEP\nTime: $runtime s\nPlan:\n$plan\n$LINE_SEP"
s"$LINE_SEP\nTime: $runtime s (σ^2 = $stddev)\nIterations: $iterations\nPlan:\n$plan\n$LINE_SEP"
}
object Result
{
def apply(x: (LogicalPlan, Double)): Result =
Result(x._1, x._2)
def time(x: => LogicalPlan): Result =
def time(x: => (LogicalPlan, Int)): Result =
{
// warm up for BURN_IN_ROUNDS rounds, then measure TRIALS runs
for(i <- 0 until BURN_IN_ROUNDS) { x }
Result(Time.measure(x))
val results = (0 until TRIALS).map { _ => Time.measure(x) }
Result(
plan = results.head._1._1,
iterations = results.head._1._2,
runtimes = results.map { _._2 }
)
}
}
@@ -38,7 +48,7 @@ object Tester
val spark: Result,
)
{
def tests = Seq(
"Astral-Raw" -> astral,
"Astral-BDD" -> bdd,
@@ -59,7 +69,7 @@ object Tester
s"\nInput\n$input\n$LINE_SEP\n\nANALYZED\n$LINE_SEP\n$analyzed\n$LINE_SEP" +
tests.map { case (name, result) => s"\n\n$name\n$result" }.mkString
def summary: String =
tests.map { case (name, result) => s"$name: ${result.runtime * 1000} ms" }.mkString("\n")
tests.map { case (name, result) => s"$name: ${result.runtime * 1000} ms (stddev = ${result.stddev * 1000}, ${result.iterations} iterations)" }.mkString("\n")
}
def diffPlans(plan1: LogicalPlan, plan2: LogicalPlan): (String, String, String) =
@@ -68,11 +78,11 @@ object Tester
val b = plan2.toString.split("\n")
val same = a.zip(b).takeWhile { case (a, b) => a == b }
(
same.mkString("\n"),
same.map { _._1 }.mkString("\n"),
a.drop(same.size).mkString("\n"),
b.drop(same.size).mkString("\n")
)
}
}
def diffPlanSummary(label1: String, label2: String, plan1: LogicalPlan, plan2: LogicalPlan): String =
{
@@ -87,6 +97,23 @@ object Tester
println(s"Using Fair Spark Optimizer with ${LesserSparkOptimizer.rules.size} rules")
}
def writeResultCSV(file: String, results: Seq[(String, AllResults)]): Unit =
{
val out = new BufferedWriter(new FileWriter(file))
try {
out.write("experiment,"+results.head._2.tests.map { x =>
s"${x._1.toLowerCase}-avg,${x._1.toLowerCase}-stddev,${x._1.toLowerCase}-iter"
}.mkString(",")+"\n")
for( (label, result) <- results )
{
out.write(label+","+result.tests.map { x =>
s"${x._2.runtime},${x._2.stddev},${x._2.iterations}"
}.mkString(",")+"\n")
}
} finally {
out.close()
}
}
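// Sketch of the resulting layout (derived from the header construction
// above), one row per experiment:
//   experiment,astral-raw-avg,astral-raw-stddev,astral-raw-iter,astral-bdd-avg,...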
def test(df: DataFrame): AllResults =
{
@@ -107,11 +134,25 @@ object Tester
LesserSparkOptimizer.rewrite(analyzed)
},
spark = Result.time {
optimizer.execute(analyzed)
optimizer.execute(analyzed) -> 0
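// the iteration count is pinned to 0 here: Spark's optimizer.execute does
// not report how many passes it took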
}
)
}
def fixpoint(target: LogicalPlan)(op: LogicalPlan => LogicalPlan): LogicalPlan =
{
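// Repeatedly apply op until the plan stops changing (compared with ==, not
// fastEquals), capped at 30 rounds; once converged, the remaining rounds no-op.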
var curr = target
var last:LogicalPlan = null
for(i <- (0 until 30))
{
if(curr != last){
last = curr
curr = op(curr)
}
}
return curr
}
def testAlignment(df: DataFrame): Unit =
{
val baseRulesByGroup =
@@ -128,8 +169,12 @@
{
for((group, rule) <- nativeRules)
{
val native = rule(curr)
val base = baseRulesByGroup(group).foldLeft(curr) { (curr, rule) => rule(curr) }
val native =
fixpoint(curr)(rule(_))
val base =
fixpoint(curr)(
baseRulesByGroup(group).foldLeft(_) { (curr, rule) => rule(curr) }
)
if(!native.fastEquals(base))
{

View File

@@ -6,6 +6,14 @@ import scala.collection.View.Filter
object Catalyst extends HardcodedDefinition
{
// Only one rule from here
// SkipRules("CollapseProject")
// ColumnPruning introduces Projects, which interact a bit oddly with
// BDDs.
SkipRules("ColumnPruning")
Ast("LogicalPlan")(
Node("Filter")(
"condition" -> Type.Native("Expression"),
@@ -94,6 +102,9 @@ object Catalyst extends HardcodedDefinition
"generatorOutput" -> Type.Array(Type.Native("Attribute")),
"child" -> Type.AST("LogicalPlan")
)
).withSubtypes(
"SetOperation",
"Distinct"
)
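// SetOperation and Distinct have no concrete nodes defined here; declaring
// them as extra subtypes lets OfType tests (e.g. in the commented-out
// ColumnPruning-20 rule below) still refer to them.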
//////////////////////////////////////////////////////
@@ -332,6 +343,30 @@ object Catalyst extends HardcodedDefinition
Type.Node("Generate"),
)
Function("LeftExistenceJoin", Type.Bool)(Type.Native("JoinType"))
Function("combineReferences", Type.Array(Type.Native("Attribute")))(
Type.AST("LogicalPlan"),
Type.AST("LogicalPlan")
)
Function("inputSetIsSubsetOfRefs", Type.Bool)(
Type.AST("LogicalPlan"),
Type.Array(Type.Native("Attribute"))
)
Function("ColumnPruningTwentyReplaceChildren", Type.AST("LogicalPlan"))(
Type.AST("LogicalPlan"),
Type.Array(Type.Native("Attribute"))
)
Function("outputSetIsSubset", Type.Bool)(
Type.AST("LogicalPlan"),
Type.AST("LogicalPlan"),
)
Function("projectIsOnlyOfAttributes", Type.Bool)(
Type.Node("Project"),
)
Function("hasConflictingAttrsWithSubquery", Type.Bool)(
Type.Native("Expression"),
Type.AST("LogicalPlan"),
)
Function("LeftSemiOrAntiJoin", Type.Bool)(Type.Native("JoinType"))
Global("JoinHint.NONE", Type.Native("JoinHint"))
Global("RightOuter", Type.Native("JoinType"))
@@ -821,7 +856,7 @@ object Catalyst extends HardcodedDefinition
Bind("gChild"),
)),
Bind("rightOp"),
Bind("joinType"),
Bind("joinType") and Test(Apply("LeftSemiOrAntiJoin")(Ref("joinType"))),
Bind("joinCond"),
Bind("hint"),
)) and Test(
@@ -865,23 +900,38 @@ object Catalyst extends HardcodedDefinition
))
)) and Test(
Apply("referencesASubset")(Ref("p2"), Ref("p1")).not
) and Test(
Apply("canCollapseExpressions")(
Ref("p1").structField("projectList"),
Ref("p2").structField("projectList"),
Bool(true)
).not
)
)(
Construct("Project")(
Ref("target"),
Construct("Project")(
Let("newTargetList2" ->
Apply("filterNamedExpressions")(
Ref("target2"),
Ref("p1")
),
Ref("child")
)
)(
// workaround for chaining issue: CollapseProject-1 won't fire
// If(
// Apply("canCollapseExpressions")(
// Ref("target"),
// Ref("newTargetList2"),
// Bool(true)
// )
// )(
// Construct("Project")(
// Apply("buildCleanedProjectList")(
// Ref("target"),
// Ref("newTargetList2"),
// ),
// Ref("child")
// )
// )(
Construct("Project")(
Ref("target"),
Construct("Project")(
Ref("newTargetList2"),
Ref("child")
)
)
// )
)
)
@@ -1138,6 +1188,9 @@ object Catalyst extends HardcodedDefinition
// // all the columns will be used to compare, so we can't prune them
// case p @ Project(_, _: SetOperation) => p
// case p @ Project(_, _: Distinct) => p
// The above two exist solely to prevent the subsequent rules from firing. Skip.
// // Eliminate unneeded attributes from children of Union.
// case p @ Project(_, u: Union) =>
// if (!u.outputSet.subsetOf(p.references)) {
@@ -1172,6 +1225,8 @@ object Catalyst extends HardcodedDefinition
// // Can't prune the columns on LeafNode
// case p @ Project(_, _: LeafNode) => p
// This exists solely to prevent the subsequent rules from firing. Skip.
// case NestedColumnAliasing(rewrittenPlan) => rewrittenPlan
// // for all other logical plans that inherits the output from it's children
@@ -1184,6 +1239,64 @@ object Catalyst extends HardcodedDefinition
// } else {
// p
// }
// Rule("ColumnPruning-20", "LogicalPlan")(
// Bind("p", Match("Project")(
// MatchAny,
// Bind("child",
// (OfType("Project").not)
// //the following are used to preclude cases explicitly checked for above
// and (OfType("Window").not)
// and (OfType("Union").not)
// and (OfType(Type.ASTSubtype("SetOperation")).not)
// and (OfType(Type.ASTSubtype("Distinct")).not)
// )
// )) and BindExpression("required", Apply("combineReferences")(Ref("p"), Ref("child")))
// and Test( Apply("inputSetIsSubsetOfRefs")(Ref("child"), Ref("required")).not )
// )(
// Construct("Project")(
// Ref("p").structField("projectList"),
// Apply("ColumnPruningTwentyReplaceChildren")(
// Ref("child"), Ref("required")
// )
// )
// )
// private def removeProjectBeforeFilter(plan: LogicalPlan): LogicalPlan = plan transformUp {
// case p1 @ Project(_, f @ Filter(e, p2 @ Project(_, child)))
// if p2.outputSet.subsetOf(child.outputSet) &&
// // We only remove attribute-only project.
// p2.projectList.forall(_.isInstanceOf[AttributeReference]) &&
// // We can't remove project when the child has conflicting attributes
// // with the subquery in filter predicate
// !hasConflictingAttrsWithSubquery(e, child) =>
// p1.copy(child = f.copy(child = child))
// }
// Rule("ColumnPruning-Post", "LogicalPlan")(
// Bind("p1", Match("Project")(
// MatchAny,
// Bind("f", Match("Filter")(
// Bind("e"),
// Bind("p2", Match("Project")(
// MatchAny,
// Bind("child")
// ))
// ))
// )) and Test(
// Apply("outputSetIsSubset")(Ref("p2"), Ref("child"))
// ) and Test(
// Apply("projectIsOnlyOfAttributes")(Ref("p2"))
// ) and Test(
// Apply("hasConflictingAttrsWithSubquery")(Ref("e"), Ref("child")).not
// )
// )(
// Construct("Project")(
// Ref("p1").structField("projectList"),
// Construct("Filter")(
// Ref("e"),
// Ref("child")
// )
// )
// )
//////////////////////////////////////////////////////

View File

@@ -97,6 +97,37 @@ object Match
)
}(scope)
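// Path matchers: an empty path simply delegates to the inner pattern; a
// nonempty path descends one node field per step, extending targetPath and
// narrowing targetType as it recurses.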
case Path(Seq(), pattern) =>
apply(
schema = schema,
pattern = pattern,
target = target,
targetPath = targetPath,
targetType = targetType,
onSuccess = onSuccess,
onFail = onFail,
name = name,
scope = scope
)
case Path(path, pattern) =>
targetType match {
case Type.Node(label) =>
val node = schema.nodesByName(label)
apply(
schema = schema,
pattern = Path(path.tail, pattern),
target = Code.BinOp(target, ".", Code.Literal(node.fields(path.head).name)),
targetPath = targetPath :+ path.head,
targetType = node.fields(path.head).t,
onSuccess = onSuccess,
onFail = onFail,
name = name,
scope = scope
)
case t => assert(false, s"Rendering a path matcher for $target, but it has non-node type $t")
}
case Exact(pattern) =>
val selectedName =
name.getOrElse { "genericNode" }+"_t"

View File

@@ -2,7 +2,8 @@ package com.astraldb.spec
case class ASTDefinition(
family: Type.AST,
nodes: Set[Node]
nodes: Set[Node],
extraSubtypes: Set[Type.ASTSubtype] = Set.empty
)
{
val subtypes:Map[Type.ASTSubtype, Set[Node]] =
@@ -12,4 +13,5 @@ case class ASTDefinition(
}
}
.groupMap { _._1 } { _._2 }
++ (extraSubtypes.map { _ -> Set() }.toMap)
}

View File

@@ -41,21 +41,36 @@ class HardcodedDefinition
val family = Type.AST(f)
family -> ASTDefinition(
family = family,
nodes = n.map { _.copy(family = family) }.toSet
nodes = n.map { _.copy(family = family) }.toSet,
extraSubtypes = subtypes.get(family).toSet.flatten
)}.toMap,
rules = rules.toSeq,
rules = rules.toSeq.filterNot { rule =>
skiprules.exists { rule.label startsWith _ }
},
globals = globals.toMap
)
val nodes = mutable.Map[String, mutable.Buffer[Node]]()
val rules = mutable.Buffer[Rule]()
val globals = mutable.Map[String,Type]()
val subtypes = mutable.Map[Type.AST, mutable.Buffer[Type.ASTSubtype]]()
val skiprules = mutable.Buffer[String]()
import FieldConversions._
def Ast(label: String)(newNodes: => Node*): Unit =
class ASTBuilder(family: Type.AST)
{
def withSubtypes(labels: String*): Unit =
subtypes.getOrElseUpdate(family, { mutable.Buffer() })
.appendAll(labels.map { Type.ASTSubtype(_) })
}
def Ast(label: String)(newNodes: => Node*): ASTBuilder =
{
nodes.getOrElseUpdate(label, mutable.Buffer.empty)
.appendAll(newNodes)
new ASTBuilder(Type.AST(label))
}
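// Usage (see the Catalyst spec above):
//   Ast("LogicalPlan")( /* nodes */ ).withSubtypes("SetOperation", "Distinct")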
def Node(label: String)(fields: (String, Type)*): Node =
com.astraldb.spec.Node(label, fields, supertypes = Set.empty, family = null)
@@ -75,6 +90,12 @@ class HardcodedDefinition
globals(label) = t
}
def SkipRules(rules: String*): Unit =
{
skiprules.appendAll(rules)
}
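// Note: matching is by prefix (rule.label startsWith ...), so
// SkipRules("ColumnPruning") also drops variants like "ColumnPruning-20".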
//////////////////////// Matchers
def Match(node: String)(fields: Match*): Match =
com.astraldb.spec.Match.Node(node, fields)

View File

@@ -37,6 +37,9 @@ sealed trait Match
def or(other: Match) =
Match.Or(orSeq ++ other.orSeq)
def not =
Match.Not(this)
def transform(f: PartialFunction[Match, Match]): Match =
transformDown(f)

View File

@@ -272,6 +272,7 @@ object TypecheckMatch
pattern match {
case Match.Not(child) =>
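// typecheck the negated pattern, but return the incoming state unchanged:
// bindings made under a Not must not escape it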
check(child, targetPath, state)
state
case Match.And(children) =>
children.foldLeft(state) { (state, child) =>

View File

@@ -11,7 +11,7 @@ object Optimizer
def MAX_ITERATIONS = 100
def rewrite(plan: LogicalPlan): LogicalPlan =
def rewrite(plan: LogicalPlan): (LogicalPlan, Int) =
{
var current = plan
var last = plan
@@ -24,10 +24,10 @@
}
if(last.fastEquals(current))
{
return current
return (current, i+1)
}
last = current
}
return current
return (current, MAX_ITERATIONS)
}
}
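A hedged caller sketch (the `analyzed` plan name is assumed): `rewrite` now also reports how many passes it took, which `Tester.Result` threads into its timing summary.

    val (optimized, iterations) = Optimizer.rewrite(analyzed)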

View File

@@ -11,7 +11,7 @@ object @clazz
def MAX_ITERATIONS = 100
def rewrite(plan: LogicalPlan): LogicalPlan =
def rewrite(plan: LogicalPlan): (LogicalPlan, Int) =
{
var current = plan
var last = plan
@@ -24,10 +24,10 @@ object @clazz
}
if(last.fastEquals(current))
{
return current
return (current, i+1)
}
last = current
}
return current
return (current, MAX_ITERATIONS)
}
}