General Testing

relalg
Oliver Kennedy 2023-07-18 23:11:39 -04:00
parent 9270b4d6c7
commit 6811aa74ec
Signed by: okennedy
GPG Key ID: 3E5F9B3ABD3FDB60
17 changed files with 177 additions and 66 deletions

View File

@ -1,56 +0,0 @@
package com.astraldb.catalyst
import org.apache.spark.sql.{ Row, SparkSession }
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._
/**
 * Manual smoke test that runs one trivial query through three optimizers and
 * prints the resulting plans so they can be compared by eye.
 *
 * The query is `SELECT * FROM R`, where `R` is a temporary view with the two
 * literal columns `A` and `B`.
 */
object OptimizerTest
{
  /**
   * Builds a local Spark session, registers the view `R`, then prints the
   * parsed and analyzed plans followed by the output of the Astral optimizer,
   * the BDD optimizer, and Spark's own optimizer. Each optimizer call is
   * wrapped in `Time(...)` with a descriptive label.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit =
  {
    val spark: SparkSession =
      SparkSession.builder
        .appName("OptimizerTest")
        .master("local")
        .getOrCreate()
    spark.sparkContext.setLogLevel("ERROR")

    val r = spark.emptyDataFrame
      .select(lit(1) as "A", lit(2) as "B")
    r.createOrReplaceTempView("R")

    val df = spark.sql("SELECT * FROM R")
    val analyzed = df.queryExecution.analyzed
    println(df.queryExecution.logical)
    println(analyzed)
    println("------------------")

    val optimized =
      Time("Astral Optimizer") {
        Optimizer.rewrite(analyzed)
      }
    println("------------------\nAstral-Naive Optimized Query:\n")
    println(optimized)
    println("------------------")

    val bddOptimized =
      Time("BDD Optimizer") {
        LogicalPlanBDDOptimizer.rewrite(analyzed)
      }
    println("------------------\nAstral-BDD Optimized Query:\n")
    // Print the BDD optimizer's own result; this section previously re-printed
    // the naive optimizer's plan by mistake.
    println(bddOptimized)
    println("------------------")

    val sparkOptimized =
      Time("Spark Optimizer") {
        df.queryExecution.optimizedPlan
      }
    println("------------------\nSpark Optimized Query:\n")
    println(sparkOptimized)
    println("------------------")
  }
}

View File

@ -0,0 +1,65 @@
package com.astraldb.catalyst
/**
 * Driver that registers the TPC-H tables with Spark and feeds TPC-H queries
 * through [[Tester]] to compare optimizer runtimes.
 */
object TPCHTest
{
  /**
   * One `CREATE TABLE` statement per TPC-H relation. Every table is declared
   * `USING csv` with a `|` delimiter and points at a `.tbl` file under
   * `./test_data/tpch/`.
   */
  val INIT_TABLES = Seq[String](
    "CREATE TABLE LINEITEM (l_orderkey INT,l_partkey INT,l_suppkey INT,l_linenumber INT,l_quantity DECIMAL,l_extendedprice DECIMAL,l_discount DECIMAL,l_tax DECIMAL,l_returnflag CHAR(1),l_linestatus CHAR(1),l_shipdate DATE,l_commitdate DATE,l_receiptdate DATE,l_shipinstruct CHAR(25),l_shipmode CHAR(10),l_comment VARCHAR(44)) USING csv OPTIONS(path './test_data/tpch/lineitem.tbl', delimiter '|')",
    "CREATE TABLE ORDERS (o_orderkey INT,o_custkey INT,o_orderstatus CHAR(1),o_totalprice DECIMAL,o_orderdate DATE,o_orderpriority CHAR(15),o_clerk CHAR(15),o_shippriority INT,o_comment VARCHAR(79)) USING csv OPTIONS(path './test_data/tpch/orders.tbl', delimiter '|')",
    "CREATE TABLE PART (p_partkey INT,p_name VARCHAR(55),p_mfgr CHAR(25),p_brand CHAR(10),p_type VARCHAR(25),p_size INT,p_container CHAR(10),p_retailprice DECIMAL,p_comment VARCHAR(23)) USING csv OPTIONS(path './test_data/tpch/part.tbl', delimiter '|')",
    "CREATE TABLE CUSTOMER (c_custkey INT,c_name VARCHAR(25),c_address VARCHAR(40),c_nationkey INT,c_phone CHAR(15),c_acctbal DECIMAL,c_mktsegment CHAR(10),c_comment VARCHAR(117)) USING csv OPTIONS(path './test_data/tpch/customer.tbl', delimiter '|')",
    "CREATE TABLE SUPPLIER (s_suppkey INT,s_name CHAR(25),s_address VARCHAR(40),s_nationkey INT,s_phone CHAR(15),s_acctbal DECIMAL,s_comment VARCHAR(101)) USING csv OPTIONS(path './test_data/tpch/supplier.tbl', delimiter '|')",
    "CREATE TABLE PARTSUPP (ps_partkey INT,ps_suppkey INT,ps_availqty INT,ps_supplycost DECIMAL,ps_comment VARCHAR(199)) USING csv OPTIONS(path './test_data/tpch/partsupp.tbl', delimiter '|')",
    "CREATE TABLE NATION (n_nationkey INT,n_name CHAR(25),n_regionkey INT,n_comment VARCHAR(152)) USING csv OPTIONS(path './test_data/tpch/nation.tbl', delimiter '|')",
    "CREATE TABLE REGION (r_regionkey INT,r_name CHAR(25),r_comment VARCHAR(152)) USING csv OPTIONS(path './test_data/tpch/region.tbl', delimiter '|')",
  )

  /**
   * The benchmark statements as `(query number, SQL text)` pairs, in
   * execution order. Number 15 appears twice: the first entry creates the
   * temporary view `revenue0`, the second is the query that reads from it.
   */
  val QUERIES = Seq[(Int,String)](
    1 -> "select l_returnflag, l_linestatus, sum(l_quantity) as sum_qty, sum(l_extendedprice) as sum_base_price, sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, avg(l_quantity) as avg_qty, avg(l_extendedprice) as avg_price, avg(l_discount) as avg_disc, count(*) as count_order from lineitem where l_shipdate <= date '1998-12-01' - interval '90' day group by l_returnflag, l_linestatus order by l_returnflag, l_linestatus",
    2 -> "select s_acctbal, s_name, n_name, p_partkey, p_mfgr, s_address, s_phone, s_comment from part, supplier, partsupp, nation, region where p_partkey = ps_partkey and s_suppkey = ps_suppkey and p_size = 15 and p_type like '%BRASS' and s_nationkey = n_nationkey and n_regionkey = r_regionkey and r_name = 'EUROPE' and ps_supplycost = ( select min(ps_supplycost) from partsupp, supplier, nation, region where p_partkey = ps_partkey and s_suppkey = ps_suppkey and s_nationkey = n_nationkey and n_regionkey = r_regionkey and r_name = 'EUROPE' ) order by s_acctbal desc, n_name, s_name, p_partkey",
    3 -> "select l_orderkey, sum(l_extendedprice * (1 - l_discount)) as revenue, o_orderdate, o_shippriority from customer, orders, lineitem where c_mktsegment = 'BUILDING' and c_custkey = o_custkey and l_orderkey = o_orderkey and o_orderdate < date '1995-03-15' and l_shipdate > date '1995-03-15' group by l_orderkey, o_orderdate, o_shippriority order by revenue desc, o_orderdate",
    4 -> "select o_orderpriority, count(*) as order_count from orders where o_orderdate >= date '1993-07-01' and o_orderdate < date '1993-07-01' + interval '3' month and exists ( select * from lineitem where l_orderkey = o_orderkey and l_commitdate < l_receiptdate ) group by o_orderpriority order by o_orderpriority",
    5 -> "select n_name, sum(l_extendedprice * (1 - l_discount)) as revenue from customer, orders, lineitem, supplier, nation, region where c_custkey = o_custkey and l_orderkey = o_orderkey and l_suppkey = s_suppkey and c_nationkey = s_nationkey and s_nationkey = n_nationkey and n_regionkey = r_regionkey and r_name = 'ASIA' and o_orderdate >= date '1994-01-01' and o_orderdate < date '1994-01-01' + interval '1' year group by n_name order by revenue desc",
    6 -> "select sum(l_extendedprice * l_discount) as revenue from lineitem where l_shipdate >= date '1994-01-01' and l_shipdate < date '1994-01-01' + interval '1' year and l_discount between .06 - 0.01 and .06 + 0.01 and l_quantity < 24",
    7 -> "select supp_nation, cust_nation, l_year, sum(volume) as revenue from ( select n1.n_name as supp_nation, n2.n_name as cust_nation, extract(year from l_shipdate) as l_year, l_extendedprice * (1 - l_discount) as volume from supplier, lineitem, orders, customer, nation n1, nation n2 where s_suppkey = l_suppkey and o_orderkey = l_orderkey and c_custkey = o_custkey and s_nationkey = n1.n_nationkey and c_nationkey = n2.n_nationkey and ( (n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY') or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE') ) and l_shipdate between date '1995-01-01' and date '1996-12-31' ) as shipping group by supp_nation, cust_nation, l_year order by supp_nation, cust_nation, l_year",
    8 -> "select o_year, sum(case when nation = 'BRAZIL' then volume else 0 end) / sum(volume) as mkt_share from ( select extract(year from o_orderdate) as o_year, l_extendedprice * (1 - l_discount) as volume, n2.n_name as nation from part, supplier, lineitem, orders, customer, nation n1, nation n2, region where p_partkey = l_partkey and s_suppkey = l_suppkey and l_orderkey = o_orderkey and o_custkey = c_custkey and c_nationkey = n1.n_nationkey and n1.n_regionkey = r_regionkey and r_name = 'AMERICA' and s_nationkey = n2.n_nationkey and o_orderdate between date '1995-01-01' and date '1996-12-31' and p_type = 'ECONOMY ANODIZED STEEL' ) as all_nations group by o_year order by o_year; ",
    9 -> "select nation, o_year, sum(amount) as sum_profit from ( select n_name as nation, extract(year from o_orderdate) as o_year, l_extendedprice * (1 - l_discount) - ps_supplycost * l_quantity as amount from part, supplier, lineitem, partsupp, orders, nation where s_suppkey = l_suppkey and ps_suppkey = l_suppkey and ps_partkey = l_partkey and p_partkey = l_partkey and o_orderkey = l_orderkey and s_nationkey = n_nationkey and p_name like '%green%' ) as profit group by nation, o_year order by nation, o_year desc",
    10 -> "select c_custkey, c_name, sum(l_extendedprice * (1 - l_discount)) as revenue, c_acctbal, n_name, c_address, c_phone, c_comment from customer, orders, lineitem, nation where c_custkey = o_custkey and l_orderkey = o_orderkey and o_orderdate >= date '1993-10-01' and o_orderdate < date '1993-10-01' + interval '3' month and l_returnflag = 'R' and c_nationkey = n_nationkey group by c_custkey, c_name, c_acctbal, c_phone, n_name, c_address, c_comment order by revenue desc",
    11 -> "select ps_partkey, sum(ps_supplycost * ps_availqty) as value from partsupp, supplier, nation where ps_suppkey = s_suppkey and s_nationkey = n_nationkey and n_name = 'GERMANY' group by ps_partkey having sum(ps_supplycost * ps_availqty) > ( select sum(ps_supplycost * ps_availqty) * 0.0001000000 from partsupp, supplier, nation where ps_suppkey = s_suppkey and s_nationkey = n_nationkey and n_name = 'GERMANY' ) order by value desc",
    12 -> "select l_shipmode, sum(case when o_orderpriority = '1-URGENT' or o_orderpriority = '2-HIGH' then 1 else 0 end) as high_line_count, sum(case when o_orderpriority <> '1-URGENT' and o_orderpriority <> '2-HIGH' then 1 else 0 end) as low_line_count from orders, lineitem where o_orderkey = l_orderkey and l_shipmode in ('MAIL', 'SHIP') and l_commitdate < l_receiptdate and l_shipdate < l_commitdate and l_receiptdate >= date '1994-01-01' and l_receiptdate < date '1994-01-01' + interval '1' year group by l_shipmode order by l_shipmode",
    13 -> "select c_count, count(*) as custdist from ( select c_custkey, count(o_orderkey) from customer left outer join orders on c_custkey = o_custkey and o_comment not like '%special%requests%' group by c_custkey ) as c_orders (c_custkey, c_count) group by c_count order by custdist desc, c_count desc",
    14 -> "select 100.00 * sum(case when p_type like 'PROMO%' then l_extendedprice * (1 - l_discount) else 0 end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue from lineitem, part where l_partkey = p_partkey and l_shipdate >= date '1995-09-01' and l_shipdate < date '1995-09-01' + interval '1' month",
    15 -> "create temporary view revenue0 (supplier_no, total_revenue) as select l_suppkey, sum(l_extendedprice * (1 - l_discount)) from lineitem where l_shipdate >= date '1996-01-01' and l_shipdate < date '1996-01-01' + interval '3' month group by l_suppkey",
    15 -> "select s_suppkey, s_name, s_address, s_phone, total_revenue from supplier, revenue0 where s_suppkey = supplier_no and total_revenue = ( select max(total_revenue) from revenue0 ) order by s_suppkey",
    16 -> "select p_brand, p_type, p_size, count(distinct ps_suppkey) as supplier_cnt from partsupp, part where p_partkey = ps_partkey and p_brand <> 'Brand#45' and p_type not like 'MEDIUM POLISHED%' and p_size in (49, 14, 23, 45, 19, 3, 36, 9) and ps_suppkey not in ( select s_suppkey from supplier where s_comment like '%Customer%Complaints%' ) group by p_brand, p_type, p_size order by supplier_cnt desc, p_brand, p_type, p_size",
    17 -> "select sum(l_extendedprice) / 7.0 as avg_yearly from lineitem, part where p_partkey = l_partkey and p_brand = 'Brand#23' and p_container = 'MED BOX' and l_quantity < ( select 0.2 * avg(l_quantity) from lineitem where l_partkey = p_partkey )",
    18 -> "select c_name, c_custkey, o_orderkey, o_orderdate, o_totalprice, sum(l_quantity) from customer, orders, lineitem where o_orderkey in ( select l_orderkey from lineitem group by l_orderkey having sum(l_quantity) > 300 ) and c_custkey = o_custkey and o_orderkey = l_orderkey group by c_name, c_custkey, o_orderkey, o_orderdate, o_totalprice order by o_totalprice desc, o_orderdate",
    19 -> "select sum(l_extendedprice* (1 - l_discount)) as revenue from lineitem, part where ( p_partkey = l_partkey and p_brand = 'Brand#12' and p_container in ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') and l_quantity >= 1 and l_quantity <= 1 + 10 and p_size between 1 and 5 and l_shipmode in ('AIR', 'AIR REG') and l_shipinstruct = 'DELIVER IN PERSON' ) or ( p_partkey = l_partkey and p_brand = 'Brand#23' and p_container in ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') and l_quantity >= 10 and l_quantity <= 10 + 10 and p_size between 1 and 10 and l_shipmode in ('AIR', 'AIR REG') and l_shipinstruct = 'DELIVER IN PERSON' ) or ( p_partkey = l_partkey and p_brand = 'Brand#34' and p_container in ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG') and l_quantity >= 20 and l_quantity <= 20 + 10 and p_size between 1 and 15 and l_shipmode in ('AIR', 'AIR REG') and l_shipinstruct = 'DELIVER IN PERSON' )",
    20 -> "select s_name, s_address from supplier, nation where s_suppkey in ( select ps_suppkey from partsupp where ps_partkey in ( select p_partkey from part where p_name like 'forest%' ) and ps_availqty > ( select 0.5 * sum(l_quantity) from lineitem where l_partkey = ps_partkey and l_suppkey = ps_suppkey and l_shipdate >= date '1994-01-01' and l_shipdate < date '1994-01-01' + interval '1' year ) ) and s_nationkey = n_nationkey and n_name = 'CANADA' order by s_name",
    21 -> "select s_name, count(*) as numwait from supplier, lineitem l1, orders, nation where s_suppkey = l1.l_suppkey and o_orderkey = l1.l_orderkey and o_orderstatus = 'F' and l1.l_receiptdate > l1.l_commitdate and exists ( select * from lineitem l2 where l2.l_orderkey = l1.l_orderkey and l2.l_suppkey <> l1.l_suppkey ) and not exists ( select * from lineitem l3 where l3.l_orderkey = l1.l_orderkey and l3.l_suppkey <> l1.l_suppkey and l3.l_receiptdate > l3.l_commitdate ) and s_nationkey = n_nationkey and n_name = 'SAUDI ARABIA' group by s_name order by numwait desc, s_name",
    22 -> "select cntrycode, count(*) as numcust, sum(c_acctbal) as totacctbal from ( select substring(c_phone from 1 for 2) as cntrycode, c_acctbal from customer where substring(c_phone from 1 for 2) in ('13', '31', '23', '29', '30', '18', '17') and c_acctbal > ( select avg(c_acctbal) from customer where c_acctbal > 0.00 and substring(c_phone from 1 for 2) in ('13', '31', '23', '29', '30', '18', '17') ) and not exists ( select * from orders where o_custkey = c_custkey ) ) as custsale group by cntrycode order by cntrycode",
  )

  /**
   * Creates the default Spark session, executes every statement in
   * [[INIT_TABLES]], prints optimizer statistics, and then runs the selected
   * queries through `Tester.test`, printing a per-optimizer timing summary
   * under a banner for each.
   *
   * Only the first entry of [[QUERIES]] is currently exercised (`take(1)`).
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit =
  {
    val spark = Tester.initDefaultSpark()
    INIT_TABLES.foreach { ddl => spark.sql(ddl) }
    Tester.dumpStats()

    val banner = "==================================="
    QUERIES.take(1).foreach { case (label, query) =>
      println(banner)
      println(s"TPCH Query $label")
      println(banner)
      val result = Tester.test(spark.sql(query))
      println(result.summary)
    }
  }
}

View File

@ -0,0 +1,95 @@
package com.astraldb.catalyst
import org.apache.spark.sql.{ Row, SparkSession, DataFrame }
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
/**
 * Shared harness that runs a query's analyzed plan through the Astral
 * optimizer, the BDD optimizer, and Spark's optimizer, recording the plan
 * each one produces and how long it took.
 */
object Tester
{
  /** Horizontal rule used when rendering results as text. */
  val LINE_SEP = "----------------"

  /**
   * Outcome of one optimizer run.
   *
   * @param plan    the plan the optimizer produced
   * @param runtime elapsed time, rendered with an "s" suffix by `toString`
   */
  case class Result(plan: LogicalPlan, runtime: Double)
  {
    /** Renders the runtime and plan between two separator lines. */
    override def toString(): String =
      Seq(LINE_SEP, s"Time: $runtime s", "Plan:", s"$plan", LINE_SEP).mkString("\n")
  }

  object Result
  {
    /** Builds a [[Result]] from a `(plan, runtime)` pair. */
    def apply(x: (LogicalPlan, Double)): Result =
      new Result(x._1, x._2)

    /** Evaluates `x` under `Time.measure` and wraps the outcome. */
    def time(x: => LogicalPlan): Result =
    {
      val measured = Time.measure(x)
      Result(measured)
    }
  }

  /**
   * Everything collected for one query: the parsed and analyzed plans plus
   * one [[Result]] per optimizer.
   */
  case class AllResults(
    val input: LogicalPlan,
    val analyzed: LogicalPlan,
    val astral: Result,
    val bdd: Result,
    val spark: Result
    // val lesserSpark: Result
  )
  {
    /** The optimizer results paired with their display names, in report order. */
    def tests = Seq(
      "Astral-Raw" -> astral,
      "Astral-BDD" -> bdd,
      "Spark-Full" -> spark
    )

    /** Full report: input plan, analyzed plan, then each optimizer's result. */
    override def toString(): String =
    {
      val header = s"\nInput\n$input\n$LINE_SEP\n\nANALYZED\n$LINE_SEP\n$analyzed\n$LINE_SEP"
      val sections = for ((name, result) <- tests) yield s"\n\n$name\n$result"
      header + sections.mkString
    }

    /** One line per optimizer giving its runtime multiplied by 1000, labelled ms. */
    def summary: String =
    {
      val lines = for ((name, result) <- tests) yield s"$name: ${result.runtime * 1000} ms"
      lines.mkString("\n")
    }
  }

  /**
   * Reduces a rule class's simple name to its leading alphabetic part: all
   * `$` characters are removed, then everything from the first non-letter
   * that has at least one character after it is dropped.
   */
  private def baseRuleName(simpleName: String): String =
    simpleName.replaceAll("\\$", "").replaceAll("[^a-zA-Z].+", "")

  /**
   * Prints, for each of the two generated optimizers, the number of rule
   * fragments and the number of distinct base rule names among them.
   */
  def dumpStats(): Unit =
  {
    val astralFragments = Optimizer.rules.size
    val astralRules = Optimizer.rules.map { r => baseRuleName(r.getClass.getSimpleName()) }.toSet.size
    println(s"Using Optimizer with ${astralFragments} rule fragments based on ${astralRules} rules")

    val bddFragments = LogicalPlanBDDOptimizer.rules.size
    val bddRules = LogicalPlanBDDOptimizer.rules.map { r => baseRuleName(r.getClass.getSimpleName()) }.toSet.size
    println(s"Using BDD Optimizer with ${bddFragments} rule fragments based on ${bddRules} rules")
  }

  /**
   * Runs all three optimizers on `df`'s analyzed plan and times each one.
   *
   * @param df the query to test
   * @return the parsed plan, analyzed plan, and the three timed results
   */
  def test(df: DataFrame): AllResults =
  {
    val execution = df.queryExecution
    val analyzed = execution.analyzed
    AllResults(
      input = execution.logical,
      analyzed = analyzed,
      astral = Result.time { Optimizer.rewrite(analyzed) },
      bdd = Result.time { LogicalPlanBDDOptimizer.rewrite(analyzed) },
      spark = Result.time { execution.optimizedPlan }
    )
  }

  /** Gets or creates a Spark session named "OptimizerTest" with master "local". */
  def initDefaultSpark() =
  {
    val builder = SparkSession.builder
      .appName("OptimizerTest")
      .master("local")
    builder.getOrCreate()
  }

  /**
   * Smoke test: registers a view `R` with literal columns `A` and `B`, then
   * prints the full report for `SELECT * FROM R`.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit =
  {
    val spark: SparkSession = initDefaultSpark()
    spark.sparkContext.setLogLevel("ERROR")

    spark.emptyDataFrame
      .select(lit(1) as "A", lit(2) as "B")
      .createOrReplaceTempView("R")

    val report = test(spark.sql("SELECT * FROM R"))
    println(report)
  }
}

View File

@ -2,14 +2,12 @@ package com.astraldb.catalyst
/**
 * Small wall-clock timing helpers built on `System.nanoTime`.
 */
object Time
{
  /** Number of nanoseconds in one second, as a Double for exact division. */
  private val NanosPerSecond: Double = 1000000000.0

  /**
   * Runs `body`, prints `"<label>: <seconds>s"` to standard output, and
   * returns whatever `body` produced.
   *
   * @param label text printed in front of the elapsed time
   * @param body  the computation to time; evaluated exactly once
   * @tparam T    result type of `body`
   * @return the value produced by `body`
   */
  def apply[T](label: String)(body: => T): T =
  {
    val (ret, seconds) = measure(body)
    println(s"$label: ${seconds}s")
    ret
  }

  /**
   * Runs `body` and reports how long it took.
   *
   * @param body the computation to time; evaluated exactly once
   * @tparam T   result type of `body`
   * @return the value produced by `body` paired with the elapsed time in seconds
   */
  def measure[T](body: => T): (T, Double) =
  {
    val start = System.nanoTime()
    val ret = body
    val end = System.nanoTime()
    // Convert via Double: a Float cannot represent large nanosecond counts exactly.
    (ret, (end - start).toDouble / NanosPerSecond)
  }
}

View File

@ -9,7 +9,7 @@ import org.apache.spark.sql.types._
import org.apache.spark.sql.{ SparkSession, DataFrame, Row }
import org.mimirdb.caveats.Caveats
object LoadVizierTest
object VizierTest
{
val MIMIR_LOAD = "org.mimirdb.data.LoadConstructor$"
@ -136,7 +136,7 @@ object LoadVizierTest
Caveats.registerAllUDFs(spark)
loadMimirDump("test_data/vizier/nanostructures.json", spark)
val dfs = loadMimirDump("test_data/vizier/nanostructures.json", spark)
}
}

View File

@ -380,7 +380,7 @@ class AnnotateWithRowIds(
}
/*********************************************************/
case Sample(
case org.apache.spark.sql.catalyst.plans.logical.Sample(
lowerBound: Double,
upperBound: Double,
withReplacement: Boolean,

View File

@ -107,7 +107,7 @@ object astral extends Module
def scalaVersion = "2.12.15"
def generatedSources = T{ astral.catalyst.rendered() }
def mainClass = Some("com.astraldb.catalyst.OptimizerTest")
def mainClass = Some("com.astraldb.catalyst.Tester")
def ivyDeps = Agg(
ivy"org.apache.spark::spark-sql::3.4.1",

1
spark-warehouse/test_data Symbolic link
View File

@ -0,0 +1 @@
../test_data/

View File

View File

View File

View File

0
test_data/tpch/part.tbl Normal file
View File

View File

View File

View File

8
tmp Normal file
View File

@ -0,0 +1,8 @@
lineitem.tbl
orders.tbl
part.tbl
customer.tbl
supplier.tbl
partsupp.tbl
nation.tbl
region.tbl