From adce5ee721c6a844ff21dfcd8515859458fe611d Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Sat, 5 Mar 2016 19:25:03 +0800
Subject: [PATCH] [SPARK-12720][SQL] SQL Generation Support for Cube, Rollup, and Grouping Sets

#### What changes were proposed in this pull request?

This PR adds SQL generation support for cube, rollup, and grouping sets. For example, a query using rollup:

```SQL
SELECT count(*) as cnt, key % 5, grouping_id() FROM t1 GROUP BY key % 5 WITH ROLLUP
```

Original logical plan:

```
Aggregate [(key#17L % cast(5 as bigint))#47L,grouping__id#46],
          [(count(1),mode=Complete,isDistinct=false) AS cnt#43L,
           (key#17L % cast(5 as bigint))#47L AS _c1#45L,
           grouping__id#46 AS _c2#44]
+- Expand [List(key#17L, value#18, (key#17L % cast(5 as bigint))#47L, 0),
           List(key#17L, value#18, null, 1)],
          [key#17L,value#18,(key#17L % cast(5 as bigint))#47L,grouping__id#46]
   +- Project [key#17L,
               value#18,
               (key#17L % cast(5 as bigint)) AS (key#17L % cast(5 as bigint))#47L]
      +- Subquery t1
         +- Relation[key#17L,value#18] ParquetRelation
```

Converted SQL:

```SQL
SELECT count( 1) AS `cnt`,
       (`t1`.`key` % CAST(5 AS BIGINT)),
       grouping_id() AS `_c2`
FROM `default`.`t1`
GROUP BY (`t1`.`key` % CAST(5 AS BIGINT))
GROUPING SETS (((`t1`.`key` % CAST(5 AS BIGINT))), ())
```

#### How was this patch tested?

Added eleven test cases in `LogicalPlanToSQLSuite`. Two illustrative sketches (a round-trip example and the grouping-ID bit arithmetic) follow the patch.

Author: gatorsmile
Author: xiaoli
Author: Xiao Li

Closes #11283 from gatorsmile/groupingSetsToSQL.
---
 python/pyspark/sql/functions.py                   |  14 +-
 .../sql/catalyst/expressions/grouping.scala       |   1 +
 .../apache/spark/sql/hive/SQLBuilder.scala        |  76 +++++++++-
 .../sql/hive/LogicalPlanToSQLSuite.scala          | 143 ++++++++++++++++++
 4 files changed, 226 insertions(+), 8 deletions(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 92e724fef4..88924e2981 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -348,13 +348,13 @@ def grouping_id(*cols):
         grouping columns).
 
     >>> df.cube("name").agg(grouping_id(), sum("age")).orderBy("name").show()
-    +-----+------------+--------+
-    | name|groupingid()|sum(age)|
-    +-----+------------+--------+
-    | null|           1|       7|
-    |Alice|           0|       2|
-    |  Bob|           0|       5|
-    +-----+------------+--------+
+    +-----+-------------+--------+
+    | name|grouping_id()|sum(age)|
+    +-----+-------------+--------+
+    | null|            1|       7|
+    |Alice|            0|       2|
+    |  Bob|            0|       5|
+    +-----+-------------+--------+
     """
     sc = SparkContext._active_spark_context
     jc = sc._jvm.functions.grouping_id(_to_seq(sc, cols, _to_java_column))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
index a204060630..437e417266 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
@@ -63,4 +63,5 @@ case class GroupingID(groupByExprs: Seq[Expression]) extends Expression with Une
   override def children: Seq[Expression] = groupByExprs
   override def dataType: DataType = IntegerType
   override def nullable: Boolean = false
+  override def prettyName: String = "grouping_id"
 }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
index 9a14ccff57..8d411a9a40 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
 import org.apache.spark.sql.catalyst.util.quoteIdentifier
 import org.apache.spark.sql.execution.datasources.LogicalRelation
-import org.apache.spark.sql.types.{DataType, NullType}
+import org.apache.spark.sql.types.{ByteType, DataType, IntegerType, NullType}
 
 /**
  * A place holder for generated SQL for subquery expression.
@@ -118,6 +118,9 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
     case p: Project => projectToSQL(p, isDistinct = false)
 
+    case a @ Aggregate(_, _, e @ Expand(_, _, p: Project)) if isGroupingSet(a, e, p) =>
+      groupingSetToSQL(a, e, p)
+
     case p: Aggregate =>
       aggregateToSQL(p)
 
@@ -244,6 +247,77 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
     )
   }
 
+  private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean =
+    output1.size == output2.size &&
+      output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2))
+
+  private def isGroupingSet(a: Aggregate, e: Expand, p: Project): Boolean = {
+    assert(a.child == e && e.child == p)
+    a.groupingExpressions.forall(_.isInstanceOf[Attribute]) &&
+      sameOutput(e.output, p.child.output ++ a.groupingExpressions.map(_.asInstanceOf[Attribute]))
+  }
+
+  private def groupingSetToSQL(
+      agg: Aggregate,
+      expand: Expand,
+      project: Project): String = {
+    assert(agg.groupingExpressions.length > 1)
+
+    // The last column of Expand is always grouping ID
+    val gid = expand.output.last
+
+    val numOriginalOutput = project.child.output.length
+    // Assumption: Aggregate's groupingExpressions is composed of
+    //   1) the attributes of aliased group by expressions
+    //   2) gid, which is always the last one
+    val groupByAttributes = agg.groupingExpressions.dropRight(1).map(_.asInstanceOf[Attribute])
+    // Assumption: Project's projectList is composed of
+    //   1) the original output (Project's child.output),
+    //   2) the aliased group by expressions.
+    val groupByExprs = project.projectList.drop(numOriginalOutput).map(_.asInstanceOf[Alias].child)
+    val groupingSQL = groupByExprs.map(_.sql).mkString(", ")
+
+    // a map from group by attributes to the original group by expressions.
+    val groupByAttrMap = AttributeMap(groupByAttributes.zip(groupByExprs))
+
+    val groupingSet: Seq[Seq[Expression]] = expand.projections.map { project =>
+      // Assumption: expand.projections is composed of
+      //   1) the original output (Project's child.output),
+      //   2) group by attributes(or null literal)
+      //   3) gid, which is always the last one in each project in Expand
+      project.drop(numOriginalOutput).dropRight(1).collect {
+        case attr: Attribute if groupByAttrMap.contains(attr) => groupByAttrMap(attr)
+      }
+    }
+    val groupingSetSQL =
+      "GROUPING SETS(" +
+        groupingSet.map(e => s"(${e.map(_.sql).mkString(", ")})").mkString(", ") + ")"
+
+    val aggExprs = agg.aggregateExpressions.map { case expr =>
+      expr.transformDown {
+        // grouping_id() is converted to VirtualColumn.groupingIdName by Analyzer. Revert it back.
+        case ar: AttributeReference if ar == gid => GroupingID(Nil)
+        case ar: AttributeReference if groupByAttrMap.contains(ar) => groupByAttrMap(ar)
+        case a @ Cast(BitwiseAnd(
+            ShiftRight(ar: AttributeReference, Literal(value: Any, IntegerType)),
+            Literal(1, IntegerType)), ByteType) if ar == gid =>
+          // for converting an expression to its original SQL format grouping(col)
+          val idx = groupByExprs.length - 1 - value.asInstanceOf[Int]
+          groupByExprs.lift(idx).map(Grouping).getOrElse(a)
+      }
+    }
+
+    build(
+      "SELECT",
+      aggExprs.map(_.sql).mkString(", "),
+      if (agg.child == OneRowRelation) "" else "FROM",
+      toSQL(project.child),
+      "GROUP BY",
+      groupingSQL,
+      groupingSetSQL
+    )
+  }
+
   object Canonicalizer extends RuleExecutor[LogicalPlan] {
     override protected def batches: Seq[Batch] = Seq(
       Batch("Canonicalizer", FixedPoint(100),
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
index d708fcf8dd..f457d43e19 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
@@ -218,6 +218,149 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
     checkHiveQl("SELECT DISTINCT id FROM parquet_t0")
   }
 
+  test("rollup/cube #1") {
+    // Original logical plan:
+    //   Aggregate [(key#17L % cast(5 as bigint))#47L,grouping__id#46],
+    //             [(count(1),mode=Complete,isDistinct=false) AS cnt#43L,
+    //              (key#17L % cast(5 as bigint))#47L AS _c1#45L,
+    //              grouping__id#46 AS _c2#44]
+    //   +- Expand [List(key#17L, value#18, (key#17L % cast(5 as bigint))#47L, 0),
+    //              List(key#17L, value#18, null, 1)],
+    //             [key#17L,value#18,(key#17L % cast(5 as bigint))#47L,grouping__id#46]
+    //      +- Project [key#17L,
+    //                  value#18,
+    //                  (key#17L % cast(5 as bigint)) AS (key#17L % cast(5 as bigint))#47L]
+    //         +- Subquery t1
+    //            +- Relation[key#17L,value#18] ParquetRelation
+    // Converted SQL:
+    //   SELECT count( 1) AS `cnt`,
+    //          (`t1`.`key` % CAST(5 AS BIGINT)),
+    //          grouping_id() AS `_c2`
+    //   FROM `default`.`t1`
+    //   GROUP BY (`t1`.`key` % CAST(5 AS BIGINT))
+    //   GROUPING SETS (((`t1`.`key` % CAST(5 AS BIGINT))), ())
+    checkHiveQl(
+      "SELECT count(*) as cnt, key%5, grouping_id() FROM parquet_t1 GROUP BY key % 5 WITH ROLLUP")
+    checkHiveQl(
+      "SELECT count(*) as cnt, key%5, grouping_id() FROM parquet_t1 GROUP BY key % 5 WITH CUBE")
+  }
+
+  test("rollup/cube #2") {
+    checkHiveQl("SELECT key, value, count(value) FROM parquet_t1 GROUP BY key, value WITH ROLLUP")
+    checkHiveQl("SELECT key, value, count(value) FROM parquet_t1 GROUP BY key, value WITH CUBE")
+  }
+
+  test("rollup/cube #3") {
+    checkHiveQl(
+      "SELECT key, count(value), grouping_id() FROM parquet_t1 GROUP BY key, value WITH ROLLUP")
+    checkHiveQl(
+      "SELECT key, count(value), grouping_id() FROM parquet_t1 GROUP BY key, value WITH CUBE")
+  }
+
+  test("rollup/cube #4") {
+    checkHiveQl(
+      s"""
+        |SELECT count(*) as cnt, key % 5 as k1, key - 5 as k2, grouping_id() FROM parquet_t1
+        |GROUP BY key % 5, key - 5 WITH ROLLUP
+      """.stripMargin)
+    checkHiveQl(
+      s"""
+        |SELECT count(*) as cnt, key % 5 as k1, key - 5 as k2, grouping_id() FROM parquet_t1
+        |GROUP BY key % 5, key - 5 WITH CUBE
+      """.stripMargin)
+  }
+
+  test("rollup/cube #5") {
+    checkHiveQl(
+      s"""
+        |SELECT count(*) AS cnt, key % 5 AS k1, key - 5 AS k2, grouping_id(key % 5, key - 5) AS k3
+        |FROM (SELECT key, key%2, key - 5 FROM parquet_t1) t GROUP BY key%5, key-5
+        |WITH ROLLUP
""".stripMargin) + checkHiveQl( + s""" + |SELECT count(*) AS cnt, key % 5 AS k1, key - 5 AS k2, grouping_id(key % 5, key - 5) AS k3 + |FROM (SELECT key, key % 2, key - 5 FROM parquet_t1) t GROUP BY key % 5, key - 5 + |WITH CUBE + """.stripMargin) + } + + test("rollup/cube #6") { + checkHiveQl("SELECT a, b, sum(c) FROM parquet_t2 GROUP BY ROLLUP(a, b) ORDER BY a, b") + checkHiveQl("SELECT a, b, sum(c) FROM parquet_t2 GROUP BY CUBE(a, b) ORDER BY a, b") + checkHiveQl("SELECT a, b, sum(a) FROM parquet_t2 GROUP BY ROLLUP(a, b) ORDER BY a, b") + checkHiveQl("SELECT a, b, sum(a) FROM parquet_t2 GROUP BY CUBE(a, b) ORDER BY a, b") + checkHiveQl("SELECT a + b, b, sum(a - b) FROM parquet_t2 GROUP BY a + b, b WITH ROLLUP") + checkHiveQl("SELECT a + b, b, sum(a - b) FROM parquet_t2 GROUP BY a + b, b WITH CUBE") + } + + test("rollup/cube #7") { + checkHiveQl("SELECT a, b, grouping_id(a, b) FROM parquet_t2 GROUP BY cube(a, b)") + checkHiveQl("SELECT a, b, grouping(b) FROM parquet_t2 GROUP BY cube(a, b)") + checkHiveQl("SELECT a, b, grouping(a) FROM parquet_t2 GROUP BY cube(a, b)") + } + + test("rollup/cube #8") { + // grouping_id() is part of another expression + checkHiveQl( + s""" + |SELECT hkey AS k1, value - 5 AS k2, hash(grouping_id()) AS hgid + |FROM (SELECT hash(key) as hkey, key as value FROM parquet_t1) t GROUP BY hkey, value-5 + |WITH ROLLUP + """.stripMargin) + checkHiveQl( + s""" + |SELECT hkey AS k1, value - 5 AS k2, hash(grouping_id()) AS hgid + |FROM (SELECT hash(key) as hkey, key as value FROM parquet_t1) t GROUP BY hkey, value-5 + |WITH CUBE + """.stripMargin) + } + + test("rollup/cube #9") { + // self join is used as the child node of ROLLUP/CUBE with replaced quantifiers + checkHiveQl( + s""" + |SELECT t.key - 5, cnt, SUM(cnt) + |FROM (SELECT x.key, COUNT(*) as cnt + |FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key GROUP BY x.key) t + |GROUP BY cnt, t.key - 5 + |WITH ROLLUP + """.stripMargin) + checkHiveQl( + s""" + |SELECT t.key - 5, cnt, SUM(cnt) + |FROM (SELECT x.key, COUNT(*) as cnt + |FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key GROUP BY x.key) t + |GROUP BY cnt, t.key - 5 + |WITH CUBE + """.stripMargin) + } + + test("grouping sets #1") { + checkHiveQl( + s""" + |SELECT count(*) AS cnt, key % 5 AS k1, key - 5 AS k2, grouping_id() AS k3 + |FROM (SELECT key, key % 2, key - 5 FROM parquet_t1) t GROUP BY key % 5, key - 5 + |GROUPING SETS (key % 5, key - 5) + """.stripMargin) + } + + test("grouping sets #2") { + checkHiveQl( + "SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (a, b) ORDER BY a, b") + checkHiveQl( + "SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (a) ORDER BY a, b") + checkHiveQl( + "SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (b) ORDER BY a, b") + checkHiveQl( + "SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (()) ORDER BY a, b") + checkHiveQl( + s""" + |SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b + |GROUPING SETS ((), (a), (a, b)) ORDER BY a, b + """.stripMargin) + } + test("cluster by") { checkHiveQl("SELECT id FROM parquet_t0 CLUSTER BY id") }