Code Factorization

main
dbalakri 2023-12-07 08:58:46 -05:00
parent 643571fe8a
commit 59df52e66d
8 changed files with 474 additions and 43 deletions

View File

@ -40,7 +40,7 @@ object BDD
++ (schema.globals:Map[String, Type|(Code,Type)])
)
def recur(rule: bdd.BDD, pathNames: Map[bdd.Pathed.Path, (Code, Type)], boundVars: Map[String, Type], symCount: Int): Code =
def recur(rule: bdd.BDD, pathNames: Map[bdd.Pathed.Path, (Code, Type)], boundVars: Map[String, Type]): Code =
rule match {
case bdd.NoRewrite => onFail
case bdd.PickByType(path, types, elseBranch) =>
@ -53,21 +53,7 @@ object BDD
Code.IfElifElse(
types.map { case (targetType, andThen) =>
{
// val newSymCount = symCount
// val nodeName = "root"+path.map { "_child"+_ }.mkString+"_as_"+targetType.scalaType
val nodeNameString = "root"+path.map { "_child"+_ }.mkString+"_as_"+targetType.scalaType
// println("Type nodeNameString: "+nodeNameString)
val nameCount = if (nodeNameString.length() >= 30){
("Sym_"+symCount, symCount + 1)
} else {
(nodeNameString, symCount)
}
val nodeName = nameCount._1
val newSymCount = nameCount._2
// println("Type nodeName being used: "+nodeName)
val nodeName = "root"+path.map { "_child"+_ }.mkString+"_as_"+targetType.scalaType
val newPathNames: Map[bdd.Pathed.Path, (Code, Type)] =
targetType match {
case Type.Node(node) =>
@ -99,12 +85,12 @@ object BDD
)++recur(
andThen,
pathNames ++ newPathNames ++ Map(path -> (Code.Literal(nodeName), targetType)),
boundVars, newSymCount
boundVars
).block
)
}
}.toSeq,
recur(elseBranch, pathNames, boundVars, symCount)
recur(elseBranch, pathNames, boundVars)
)
case bdd.PickByMatch(path, pattern, bindings, onMatch, onFail) =>
val target = pathNames.get(path)
@ -124,14 +110,12 @@ object BDD
onMatch,
pathNames,
boundVars,
symCount,
),
onFail =
_ => recur(
onFail,
pathNames,
boundVars,
symCount,
),
name = None,
scope = scope
@ -160,8 +144,7 @@ object BDD
)++recur(
andThen,
pathNames,
boundVars ++ Map(symbol -> exprType),
symCount
boundVars ++ Map(symbol -> exprType)
).block
)
}
@ -171,8 +154,7 @@ object BDD
Map(
Seq.empty -> (root, family)
),
Map.empty,
1
Map.empty
)
}
}

View File

@ -0,0 +1,155 @@
package com.astraldb.codegen
import com.astraldb.spec.Definition
import com.astraldb.spec.Type
import com.astraldb.spec
import com.astraldb.bdd
import com.astraldb.codegen.Code.PaddedString
import com.astraldb.expression.Var
object BDDFactorize
{
  /**
   * Scan a BDD and extract reusable code "factors" (generated helper functions)
   * that the emitted rewriter (see [[BenchmarkBDD]]) can call instead of
   * inlining the corresponding subtree.
   *
   * Returns a map from generated function name to (parameter-list source, body):
   *  - "RewriteBy&lt;label&gt;" entries hold the code of each rewrite leaf;
   *  - "Func&lt;hash&gt;" entries hold shared match-failure continuations, keyed by
   *    the hash of the BDD subtree they implement.
   *
   * @param schema the AST schema (node definitions and globals)
   * @param family the AST family being rewritten
   * @param rule   the BDD to factorize
   * @param onFail code to emit when no rewrite applies
   * @param root   code naming the root node being matched
   */
  def apply(
    schema: Definition,
    family: Type.AST,
    rule: bdd.BDD,
    onFail: Code,
    root: Code,
  ): Map[String, (String,Code)] =
  {
    // Build the scope visible to a generated code fragment: references that are
    // neither locally bound nor schema globals must have been bound at a path
    // that already has a name in pathNames (asserts otherwise).
    def codeScopeFor[T](refs: Set[Var], bindings: Seq[bdd.Pathed[(String, Type)]], pathNames: Map[bdd.Pathed.Path, (Code, Type)], boundVars: Map[String, Type], context: T): CodeScope =
      CodeScope(
        (
          (refs.map { _.v }
            -- boundVars.keys
            -- schema.globals.keys
          ).toSeq.map { ref =>
            // Locate the path at which this reference was bound...
            val pathRef =
              bindings.find { _.value._1 == ref }
                .getOrElse {
                  assert(false, s"Reference to a variable that hasn't been bound yet: $ref (in $bindings):\n$context")
                }
            // ...and the code/type pair registered for that path.
            val varDescription =
              pathNames.find { _._1 == pathRef.path }
                .getOrElse {
                  assert(false, s"Reference to a variable at an unbound path: ${pathRef.path} (in $pathNames):\n$context")
                }
                ._2
            ref -> varDescription
          }.toMap:Map[String, Type|(Code,Type)]
        ) ++ (boundVars:Map[String, Type|(Code,Type)])
          ++ (schema.globals:Map[String, Type|(Code,Type)])
      )

    // Walk the BDD, accumulating factor definitions.
    //  - symCount numbers fallback symbols ("Sym_<n>") used when a structural
    //    name would be >= 30 chars;
    //  - syms is the typed parameter list ("name: Type") that any factor
    //    produced below this point must accept.
    // NOTE(review): sibling branches of a PickByType all start from the same
    // symCount, so "Sym_<n>" can repeat across branches; this mirrors the naming
    // scheme in BenchmarkBDD.recur and the two MUST stay in sync.
    def recur(rule: bdd.BDD, pathNames: Map[bdd.Pathed.Path, (Code, Type)], boundVars: Map[String, Type], symCount: Int, syms: Seq[String]): Map[String, (String,Code)] =
      rule match {
        case bdd.NoRewrite => Map.empty

        case bdd.Rewrite(label, op, bindings) =>
          // Each rewrite leaf becomes a RewriteBy<label> factor taking every
          // symbol bound along the path that reaches it.
          val rewriteCode = Code.Block(
            Seq(
              Code.Literal(s"// Rewrite by $label")
            ) ++ Expression(
              schema,
              op,
              codeScopeFor(op.references, bindings, pathNames, boundVars, op)
            ).block
          )
          Map(s"RewriteBy$label" -> (s"(${syms.mkString(",")})",rewriteCode))

        case bdd.PickByType(path, types, elseBranch) =>
          types.flatMap { case (targetType, andThen) =>
            // Derive a variable name for the node at `path` cast to targetType;
            // fall back to a numbered symbol when the structural name is long.
            val nodeNameString = "root"+path.map { "_child"+_ }.mkString+"_as_"+targetType.scalaType
            val (nodeName, newSymCount) =
              if (nodeNameString.length() >= 30) ("Sym_"+symCount, symCount + 1)
              else                               (nodeNameString, symCount)
            // Casting to a node type exposes its fields as addressable children.
            val newPathNames: Map[bdd.Pathed.Path, (Code, Type)] =
              targetType match {
                case Type.Node(node) =>
                  schema.nodesByName(node)
                    .fields
                    .zipWithIndex
                    .map { case (field, idx) =>
                      (path :+ idx) -> (
                        Code.BinOp(
                          Code.Literal(nodeName),
                          ".",
                          Code.Literal(field.name)
                        ),
                        field.t
                      )
                    }
                    .toMap
                case _ => Map.empty
              }
            recur(
              andThen,
              pathNames ++ newPathNames ++ Map(path -> (Code.Literal(nodeName), targetType)),
              boundVars,
              newSymCount,
              syms ++ Seq(nodeName+": "+targetType.scalaType)
            )
          }.toMap ++ recur(elseBranch, pathNames, boundVars, symCount, syms)

        case bdd.PickByMatch(fpath, fpattern, fbindings, fonMatch, fonFail) =>
          // `target` and `scope` are retained for their validation side effects
          // (the asserts inside the lookups fire on unbound paths/references);
          // their values are not otherwise used while factorizing.
          val target = pathNames.get(fpath)
            .getOrElse {
              assert(false, s"No name available for $fpath (in $pathNames):\n$rule")
            }
          val scope =
            codeScopeFor(fpattern.references, fbindings, pathNames, boundVars, fpattern)
          // Collect factors from both continuations of the match.
          val return_factors =
            recur(fonMatch, pathNames, boundVars, symCount, syms) ++
            recur(fonFail, pathNames, boundVars, symCount, syms)
          fonFail match {
            case failbranch: bdd.PickByMatch =>
              // A match failing into another match: emit the failure side's
              // ifMatched subtree as a shared Func<hash> factor so BenchmarkBDD
              // can call it instead of regenerating the subtree inline.
              // NOTE(review): keyed by failbranch.ifMatched's hash — confirm this
              // matches the predefs lookup key computed in BenchmarkBDD.recur.
              val code = BenchmarkBDD.recur(schema,family, onFail, failbranch.ifMatched, pathNames, boundVars, symCount, syms, Map.empty)
              return_factors ++ Map(s"Func${failbranch.ifMatched.toString.hashCode.abs.toString}" -> (s"(${syms.mkString(",")}, $onFail: ${family.scalaType})", code) )
            case _ =>
              return_factors
          }

        case bdd.BindAnExpression(symbol, expression, exprType, bindings, andThen) =>
          // An expression binding introduces a local variable but no factor of
          // its own; record the binding and continue below it.
          recur(
            andThen,
            pathNames,
            boundVars ++ Map(symbol -> exprType),
            symCount,
            syms
          )
      }

    recur(
      rule,
      Map(
        Seq.empty -> (root, family)
      ),
      Map.empty,
      1,
      Seq.empty
    )
  }
}

View File

@ -0,0 +1,255 @@
package com.astraldb.codegen
import com.astraldb.spec.Definition
import com.astraldb.spec.Type
import com.astraldb.spec
import com.astraldb.bdd
import com.astraldb.codegen.Code.PaddedString
import com.astraldb.expression.Var
/**
 * Generates benchmark rewriter code from a BDD, optionally calling
 * pre-factorized helper functions (`predefs`, produced by BDDFactorize)
 * instead of inlining repeated subtrees.
 */
object BenchmarkBDD
{
// Build the scope visible to a generated code fragment: references that are
// neither locally bound nor schema globals must resolve through `bindings`
// to a path that already has a name in `pathNames` (asserts otherwise).
def codeScopeFor[T](
schema: Definition,
refs: Set[Var],
bindings: Seq[bdd.Pathed[(String, Type)]],
pathNames: Map[bdd.Pathed.Path, (Code, Type)],
boundVars: Map[String, Type],
context: T): CodeScope =
CodeScope(
(
(refs.map { _.v }
-- boundVars.keys
-- schema.globals.keys
).toSeq.map { ref =>
// Locate the path at which this reference was bound...
val pathRef =
bindings.find { _.value._1 == ref }
.getOrElse {
assert(false, s"Reference to a variable that hasn't been bound yet: $ref (in $bindings):\n$context")
}
// ...and the code/type pair registered for that path.
val varDescription =
pathNames.find { _._1 == pathRef.path }
.getOrElse {
assert(false, s"Reference to a variable at an unbound path: ${pathRef.path} (in $pathNames):\n$context")
}
._2
ref -> varDescription
}.toMap:Map[String, Type|(Code,Type)]
) ++ (boundVars:Map[String, Type|(Code,Type)])
++ (schema.globals:Map[String, Type|(Code,Type)])
)
// Recursively emit code for `rule`.
//  - symCount numbers fallback symbols ("Sym_<n>") for over-long names;
//  - syms carries the argument names passed to generated helper calls;
//  - predefs maps helper names to (param list, body) — when a PickByMatch
//    node's key is present there, a call is emitted instead of inline code.
// NOTE(review): the Sym_<n> naming must stay in lock-step with
// BDDFactorize.recur, which builds the matching parameter lists.
def recur(
schema: Definition,
family: Type.AST,
onFail: Code,
rule: bdd.BDD,
pathNames: Map[bdd.Pathed.Path, (Code, Type)],
boundVars: Map[String, Type],
symCount: Int,
syms: Seq[String],
predefs: Map[String, (String,Code)]): Code =
rule match {
// No rewrite possible below this point: fall back to the caller's code.
case bdd.NoRewrite => onFail
case bdd.PickByType(path, types, elseBranch) =>
// println(s"Generating code for $path")
val target = pathNames.get(path)
.getOrElse {
assert(false, s"No name available for $path (in $pathNames):\n$rule")
}
._1
// One if/elif arm per candidate runtime type, else-branch otherwise.
Code.IfElifElse(
types.map { case (targetType, andThen) =>
{
// val newSymCount = symCount
// val nodeName = "root"+path.map { "_child"+_ }.mkString+"_as_"+targetType.scalaType
// Structural name for the node at `path` cast to targetType; replaced
// by a numbered Sym_<n> when it would be 30+ characters.
val nodeNameString = "root"+path.map { "_child"+_ }.mkString+"_as_"+targetType.scalaType
// println("Type nodeNameString: "+nodeNameString)
val nameCount = if (nodeNameString.length() >= 30){
("Sym_"+symCount, symCount + 1)
} else {
(nodeNameString, symCount)
}
val nodeName = nameCount._1
val newSymCount = nameCount._2
// println("Type nodeName being used: "+nodeName)
// Casting to a node type exposes its fields as addressable children.
val newPathNames: Map[bdd.Pathed.Path, (Code, Type)] =
targetType match {
case Type.Node(node) =>
schema.nodesByName(node)
.fields
.zipWithIndex
.map { case (field, idx) =>
(path :+ idx) -> (
Code.BinOp(
Code.Literal(nodeName),
".",
Code.Literal(field.name)
),
field.t
)
}
.toMap
case _ => Map.empty
}
// Condition: runtime type test; body: cast into a fresh val, then recurse.
Code.Pair(target, Code.Literal(s".isInstanceOf[${targetType.scalaType}]"))
-> Code.Block(Seq(
Code.BinOp(
Code.Literal(s"val $nodeName"),
PaddedString.pad(1, "="),
Code.Pair(target,
Code.Literal(s".asInstanceOf[${targetType.scalaType}]")
)
)
)++recur(
schema,
family,
onFail,
andThen,
pathNames ++ newPathNames ++ Map(path -> (Code.Literal(nodeName), targetType)),
boundVars, newSymCount, syms ++ Seq(nodeName), predefs
).block
)
}
}.toSeq,
recur(schema,family,onFail, elseBranch, pathNames, boundVars, symCount, syms, predefs)
)
case me@bdd.PickByMatch(fpath, fpattern, fbindings, fonMatch, fonFail) =>
val target = pathNames.get(fpath)
.getOrElse {
assert(false, s"No name available for $fpath (in $pathNames):\n$rule")
}
val scope =
codeScopeFor(schema, fpattern.references, fbindings, pathNames, boundVars, fpattern)
// Code for the match-success continuation.
val onMatchRecur = recur(
schema, family, onFail,
fonMatch,
pathNames,
boundVars,
symCount,
syms,
predefs
)
// Code for the match-failure continuation.
val onFailRecur = recur(
schema, family, onFail,
fonFail,
pathNames,
boundVars,
symCount,
syms,
predefs
)
// if(fonFail.isInstanceOf[bdd.PickByMatch]){
// System.err.println("Me: "+fpattern)
// }
// If this whole match subtree was factorized out, call the helper instead
// of emitting the match inline.
// NOTE(review): keyed by the hash of the whole PickByMatch node `me`, while
// BDDFactorize registers factors under the hash of a failure branch's
// ifMatched subtree — confirm the two key schemes actually coincide.
val key = s"Func${me.toString.hashCode.abs.toString}"
if(predefs.contains(key))
{
Code.Literal(s"$key(${syms.mkString(",")}, $onFail)")
}
else
{
Match(
schema = schema,
pattern = fpattern,
target = target._1,
targetPath = fpath,
targetType = target._2,
onSuccess =
_ => onMatchRecur,
onFail =
_ => onFailRecur,
name = None,
scope = scope
)
}
// val match_code:Option[(String,Option[(String, Code)])] = if(fonFail.isInstanceOf[bdd.PickByMatch])
// {
// val failbranch = fonFail.asInstanceOf[bdd.PickByMatch]
// System.err.println("====================")
// System.err.println("Gen fail pattern: "+failbranch.matcher)
// System.err.println("Gen fail ifmatched hash: "+failbranch.ifMatched.toString.hashCode.abs)
// val key = s"Func${failbranch.ifMatched.toString.hashCode.abs.toString}"
// Some(key,predefs.get(key))
// } else {
// None
// }
// val whattodo = if(match_code.isDefined)
// {
// Code.Literal(s"${match_code.get._1}(${syms.mkString(",")})")
// } else
// {
// onMatchRecur
// }
// System.err.println(whattodo)
case bdd.Rewrite(label, op, bindings) =>
// Rewrite leaves are always emitted as calls to the RewriteBy<label>
// helper produced by BDDFactorize (the inline expansion is kept below,
// commented out, for reference).
Code.Block(
Seq(
Code.Literal(s"RewriteBy$label(${syms.mkString(",")})"),
)
// ++ Expression(
// schema,
// op,
// codeScopeFor(op.references, bindings, pathNames, boundVars, op)
// ).block
)
case bdd.BindAnExpression(symbol, expression, exprType, bindings, andThen) =>
// Emit a local val for the bound expression, then the code below it.
Code.Block(
Seq(
Code.Pair(
Code.Literal(s"val $symbol ="),
Expression(
schema, expression,
codeScopeFor(schema, expression.references, bindings, pathNames, boundVars, expression)
)
)
)++recur(
schema,
family,
onFail,
andThen,
pathNames,
boundVars ++ Map(symbol -> exprType),
symCount, syms, predefs
).block
)
}
/**
 * Entry point: generate the rewriter body for `rule`, rooted at `root`
 * (typed as `family`), falling back to `onFail` and calling helpers from
 * `predefs` where available.
 */
def apply(
schema: Definition,
family: Type.AST,
rule: bdd.BDD,
onFail: Code,
root: Code,
predefs: Map[String, (String,Code)],
): Code =
{
recur(
schema,
family,
onFail,
rule,
Map(
Seq.empty -> (root, family)
),
Map.empty,
1,
Seq.empty,
predefs
)
}
}

View File

@ -2,6 +2,7 @@ package com.astraldb.codegen
import com.astraldb.spec
import com.astraldb.bdd
import com.astraldb.codegen
object Rule
{
@ -20,7 +21,8 @@ object Rule
{
if(benchmark)
{
scala.BDDBatch(schema, family, rule, "Benchmark").toString
val funcs = codegen.BDDFactorize(schema, family, rule, Code.Literal("plan"),Code.Literal("plan"))
scala.BDDBatchBenchmarks(schema, family, rule, "Benchmark", funcs).toString
} else
{
scala.BDDBatch(schema, family, rule).toString

View File

@ -8,11 +8,12 @@
object @{clazz}@{family.scalaType}BDD extends Rule[@{family.scalaType}]
{
def apply(plan: @{family.scalaType}): @{family.scalaType} =
plan.transform { case plan =>
@{codegen.BDD(schema, family, bdd,
onFail = Code.Literal("plan"),
root = Code.Literal("plan")
root = Code.Literal("plan"),
)
.toString(indent = 4)
.stripPrefix(" ")}

View File

@ -0,0 +1,28 @@
@import com.astraldb.spec.Definition
@import com.astraldb.spec.Type
@import com.astraldb.bdd.BDD
@import com.astraldb.codegen
@import com.astraldb.codegen.Code
@(schema: Definition, family: Type.AST, bdd: BDD, clazz: String = "", funcs: Map[String, (String,Code)] = Map[String, (String, Code)]())
object @{clazz}@{family.scalaType}BDD extends Rule[@{family.scalaType}]
{
@for(fun <- funcs){def @fun._1@fun._2._1 =
{
@{fun._2._2
.toString(indent = 4)}
}
}
def apply(plan: @{family.scalaType}): @{family.scalaType} =
plan.transform { case plan =>
@{codegen.BenchmarkBDD(schema, family, bdd,
onFail = Code.Literal("plan"),
root = Code.Literal("plan"),
predefs = funcs
)
.toString(indent = 4)
.stripPrefix(" ")}
}
}

View File

@ -15,8 +15,8 @@ with open('generated_benchmarks/shareability/'+sys.argv[1]+'.crule', 'r') as rul
with open('astral/catalyst/src/com/astraldb/catalyst/ExtendedDSL.scala', 'w') as dsl_file:
with open('generated_benchmarks/rewrite_depth/dsl.scala') as header_file:
header = header_file.read()
# footer = "\n rulesInit1()\n rulesInit2()\n rulesInit3()\n rulesInit4()\n}\n"
footer = "\n rulesInit1()\n rulesInit2()\n}\n"
footer = "\n rulesInit1()\n rulesInit2()\n rulesInit3()\n rulesInit4()\n}\n"
# footer = "\n rulesInit1()\n rulesInit2()\n}\n"
dsl_file.write(header)
@ -28,18 +28,14 @@ with open('astral/catalyst/src/com/astraldb/catalyst/ExtendedDSL.scala', 'w') as
for j in range(25,50):
dsl_file.write('\t\t' + rules[j])
dsl_file.write("\t}")
# dsl_file.write("\n def rulesInit3(): Unit =\n {\n")
# for j in range(50,75):
# dsl_file.write('\t\t' + rules[j])
# dsl_file.write("\t}")
# dsl_file.write("\n def rulesInit4(): Unit =\n {\n")
# for j in range(75,100):
# dsl_file.write('\t\t' + rules[j])
# dsl_file.write("\t}")
# dsl_file.write("\n def rulesInit5(): Unit =\n {\n")
# for j in range(100,105):
# dsl_file.write('\t\t' + rules[j])
# dsl_file.write("\t}")
dsl_file.write("\n def rulesInit3(): Unit =\n {\n")
for j in range(50,75):
dsl_file.write('\t\t' + rules[j])
dsl_file.write("\t}")
dsl_file.write("\n def rulesInit4(): Unit =\n {\n")
for j in range(75,100):
dsl_file.write('\t\t' + rules[j])
dsl_file.write("\t}")
dsl_file.write(footer)
with open('generated_benchmarks/data_size/query_1000.in','r') as data_file:

View File

@ -29,11 +29,23 @@ def rule_gen_recursive(random_match_string, rule_name):
for elem in random_match_string:
rule_string_recursive +='))'
for elem in random_match_string:
rule_string_recursive += 'and Test( Ref("' +elem+'") eq String("'+elem+'"))'
rule_string_recursive += 'and Test(Ref("' +elem+'") eq String("'+elem+'"))'
rule_string_recursive += ')'
# print("==============")
# for elem1, elem2 in zip(random_match_string[::2], random_match_string[1::2]):
# print(elem1, elem2)
# rule_string_recursive += 'and Test(((Ref("' +elem1+'") eq String("'+elem1+'"))) and ((Ref("' +elem2+'") eq String("'+elem2+'"))))'
# rule_string_recursive += 'and Test(Ref("' +random_match_string[-1]+'") eq String("'+random_match_string[-1]+'"))'
# rule_string_recursive += ')'
# rule_string_recursive += 'and Test(((Ref("' +random_match_string[0]+'") eq String("'+random_match_string[0]+'"))) and ((Ref("' +random_match_string[1]+'") eq String("'+random_match_string[1]+'"))))'
# for elem in random_match_string[2:]:
# rule_string_recursive += 'and Test(Ref("' +elem+'") eq String("'+elem+'"))'
# rule_string_recursive += ')'
rule_string_recursive += '('
for elem in subs_string:
@ -48,7 +60,7 @@ def rule_gen_recursive(random_match_string, rule_name):
# print(rule_string_recursive, file = f)
return rule_string_recursive.strip() + "\n"
rule_count = 49
rule_count = 99
for i in range(0, len(match_string)):
count_recursive = 1
with open('generated_benchmarks/shareability/rules_'+str(i)+'_inorder.crule', 'w') as f: