[SPARK-12720][SQL] SQL Generation Support for Cube, Rollup, and Grouping Sets
#### What changes were proposed in this pull request? This PR is for supporting SQL generation for cube, rollup and grouping sets. For example, a query using rollup: ```SQL SELECT count(*) as cnt, key % 5, grouping_id() FROM t1 GROUP BY key % 5 WITH ROLLUP ``` Original logical plan: ``` Aggregate [(key#17L % cast(5 as bigint))#47L,grouping__id#46], [(count(1),mode=Complete,isDistinct=false) AS cnt#43L, (key#17L % cast(5 as bigint))#47L AS _c1#45L, grouping__id#46 AS _c2#44] +- Expand [List(key#17L, value#18, (key#17L % cast(5 as bigint))#47L, 0), List(key#17L, value#18, null, 1)], [key#17L,value#18,(key#17L % cast(5 as bigint))#47L,grouping__id#46] +- Project [key#17L, value#18, (key#17L % cast(5 as bigint)) AS (key#17L % cast(5 as bigint))#47L] +- Subquery t1 +- Relation[key#17L,value#18] ParquetRelation ``` Converted SQL: ```SQL SELECT count( 1) AS `cnt`, (`t1`.`key` % CAST(5 AS BIGINT)), grouping_id() AS `_c2` FROM `default`.`t1` GROUP BY (`t1`.`key` % CAST(5 AS BIGINT)) GROUPING SETS (((`t1`.`key` % CAST(5 AS BIGINT))), ()) ``` #### How was the this patch tested? Added eight test cases in `LogicalPlanToSQLSuite`. Author: gatorsmile <gatorsmile@gmail.com> Author: xiaoli <lixiao1983@gmail.com> Author: Xiao Li <xiaoli@Xiaos-MacBook-Pro.local> Closes #11283 from gatorsmile/groupingSetsToSQL.
This commit is contained in:
parent
f19228eed8
commit
adce5ee721
|
@ -348,13 +348,13 @@ def grouping_id(*cols):
|
|||
grouping columns).
|
||||
|
||||
>>> df.cube("name").agg(grouping_id(), sum("age")).orderBy("name").show()
|
||||
+-----+------------+--------+
|
||||
| name|groupingid()|sum(age)|
|
||||
+-----+------------+--------+
|
||||
+-----+-------------+--------+
|
||||
| name|grouping_id()|sum(age)|
|
||||
+-----+-------------+--------+
|
||||
| null| 1| 7|
|
||||
|Alice| 0| 2|
|
||||
| Bob| 0| 5|
|
||||
+-----+------------+--------+
|
||||
+-----+-------------+--------+
|
||||
"""
|
||||
sc = SparkContext._active_spark_context
|
||||
jc = sc._jvm.functions.grouping_id(_to_seq(sc, cols, _to_java_column))
|
||||
|
|
|
@ -63,4 +63,5 @@ case class GroupingID(groupByExprs: Seq[Expression]) extends Expression with Une
|
|||
override def children: Seq[Expression] = groupByExprs
|
||||
override def dataType: DataType = IntegerType
|
||||
override def nullable: Boolean = false
|
||||
override def prettyName: String = "grouping_id"
|
||||
}
|
||||
|
|
|
@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
|
|||
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
|
||||
import org.apache.spark.sql.catalyst.util.quoteIdentifier
|
||||
import org.apache.spark.sql.execution.datasources.LogicalRelation
|
||||
import org.apache.spark.sql.types.{DataType, NullType}
|
||||
import org.apache.spark.sql.types.{ByteType, DataType, IntegerType, NullType}
|
||||
|
||||
/**
|
||||
* A place holder for generated SQL for subquery expression.
|
||||
|
@ -118,6 +118,9 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
|
|||
case p: Project =>
|
||||
projectToSQL(p, isDistinct = false)
|
||||
|
||||
case a @ Aggregate(_, _, e @ Expand(_, _, p: Project)) if isGroupingSet(a, e, p) =>
|
||||
groupingSetToSQL(a, e, p)
|
||||
|
||||
case p: Aggregate =>
|
||||
aggregateToSQL(p)
|
||||
|
||||
|
@ -244,6 +247,77 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
|
|||
)
|
||||
}
|
||||
|
||||
private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean =
|
||||
output1.size == output2.size &&
|
||||
output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2))
|
||||
|
||||
private def isGroupingSet(a: Aggregate, e: Expand, p: Project): Boolean = {
|
||||
assert(a.child == e && e.child == p)
|
||||
a.groupingExpressions.forall(_.isInstanceOf[Attribute]) &&
|
||||
sameOutput(e.output, p.child.output ++ a.groupingExpressions.map(_.asInstanceOf[Attribute]))
|
||||
}
|
||||
|
||||
private def groupingSetToSQL(
|
||||
agg: Aggregate,
|
||||
expand: Expand,
|
||||
project: Project): String = {
|
||||
assert(agg.groupingExpressions.length > 1)
|
||||
|
||||
// The last column of Expand is always grouping ID
|
||||
val gid = expand.output.last
|
||||
|
||||
val numOriginalOutput = project.child.output.length
|
||||
// Assumption: Aggregate's groupingExpressions is composed of
|
||||
// 1) the attributes of aliased group by expressions
|
||||
// 2) gid, which is always the last one
|
||||
val groupByAttributes = agg.groupingExpressions.dropRight(1).map(_.asInstanceOf[Attribute])
|
||||
// Assumption: Project's projectList is composed of
|
||||
// 1) the original output (Project's child.output),
|
||||
// 2) the aliased group by expressions.
|
||||
val groupByExprs = project.projectList.drop(numOriginalOutput).map(_.asInstanceOf[Alias].child)
|
||||
val groupingSQL = groupByExprs.map(_.sql).mkString(", ")
|
||||
|
||||
// a map from group by attributes to the original group by expressions.
|
||||
val groupByAttrMap = AttributeMap(groupByAttributes.zip(groupByExprs))
|
||||
|
||||
val groupingSet: Seq[Seq[Expression]] = expand.projections.map { project =>
|
||||
// Assumption: expand.projections is composed of
|
||||
// 1) the original output (Project's child.output),
|
||||
// 2) group by attributes(or null literal)
|
||||
// 3) gid, which is always the last one in each project in Expand
|
||||
project.drop(numOriginalOutput).dropRight(1).collect {
|
||||
case attr: Attribute if groupByAttrMap.contains(attr) => groupByAttrMap(attr)
|
||||
}
|
||||
}
|
||||
val groupingSetSQL =
|
||||
"GROUPING SETS(" +
|
||||
groupingSet.map(e => s"(${e.map(_.sql).mkString(", ")})").mkString(", ") + ")"
|
||||
|
||||
val aggExprs = agg.aggregateExpressions.map { case expr =>
|
||||
expr.transformDown {
|
||||
// grouping_id() is converted to VirtualColumn.groupingIdName by Analyzer. Revert it back.
|
||||
case ar: AttributeReference if ar == gid => GroupingID(Nil)
|
||||
case ar: AttributeReference if groupByAttrMap.contains(ar) => groupByAttrMap(ar)
|
||||
case a @ Cast(BitwiseAnd(
|
||||
ShiftRight(ar: AttributeReference, Literal(value: Any, IntegerType)),
|
||||
Literal(1, IntegerType)), ByteType) if ar == gid =>
|
||||
// for converting an expression to its original SQL format grouping(col)
|
||||
val idx = groupByExprs.length - 1 - value.asInstanceOf[Int]
|
||||
groupByExprs.lift(idx).map(Grouping).getOrElse(a)
|
||||
}
|
||||
}
|
||||
|
||||
build(
|
||||
"SELECT",
|
||||
aggExprs.map(_.sql).mkString(", "),
|
||||
if (agg.child == OneRowRelation) "" else "FROM",
|
||||
toSQL(project.child),
|
||||
"GROUP BY",
|
||||
groupingSQL,
|
||||
groupingSetSQL
|
||||
)
|
||||
}
|
||||
|
||||
object Canonicalizer extends RuleExecutor[LogicalPlan] {
|
||||
override protected def batches: Seq[Batch] = Seq(
|
||||
Batch("Canonicalizer", FixedPoint(100),
|
||||
|
|
|
@ -218,6 +218,149 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
|
|||
checkHiveQl("SELECT DISTINCT id FROM parquet_t0")
|
||||
}
|
||||
|
||||
test("rollup/cube #1") {
|
||||
// Original logical plan:
|
||||
// Aggregate [(key#17L % cast(5 as bigint))#47L,grouping__id#46],
|
||||
// [(count(1),mode=Complete,isDistinct=false) AS cnt#43L,
|
||||
// (key#17L % cast(5 as bigint))#47L AS _c1#45L,
|
||||
// grouping__id#46 AS _c2#44]
|
||||
// +- Expand [List(key#17L, value#18, (key#17L % cast(5 as bigint))#47L, 0),
|
||||
// List(key#17L, value#18, null, 1)],
|
||||
// [key#17L,value#18,(key#17L % cast(5 as bigint))#47L,grouping__id#46]
|
||||
// +- Project [key#17L,
|
||||
// value#18,
|
||||
// (key#17L % cast(5 as bigint)) AS (key#17L % cast(5 as bigint))#47L]
|
||||
// +- Subquery t1
|
||||
// +- Relation[key#17L,value#18] ParquetRelation
|
||||
// Converted SQL:
|
||||
// SELECT count( 1) AS `cnt`,
|
||||
// (`t1`.`key` % CAST(5 AS BIGINT)),
|
||||
// grouping_id() AS `_c2`
|
||||
// FROM `default`.`t1`
|
||||
// GROUP BY (`t1`.`key` % CAST(5 AS BIGINT))
|
||||
// GROUPING SETS (((`t1`.`key` % CAST(5 AS BIGINT))), ())
|
||||
checkHiveQl(
|
||||
"SELECT count(*) as cnt, key%5, grouping_id() FROM parquet_t1 GROUP BY key % 5 WITH ROLLUP")
|
||||
checkHiveQl(
|
||||
"SELECT count(*) as cnt, key%5, grouping_id() FROM parquet_t1 GROUP BY key % 5 WITH CUBE")
|
||||
}
|
||||
|
||||
test("rollup/cube #2") {
|
||||
checkHiveQl("SELECT key, value, count(value) FROM parquet_t1 GROUP BY key, value WITH ROLLUP")
|
||||
checkHiveQl("SELECT key, value, count(value) FROM parquet_t1 GROUP BY key, value WITH CUBE")
|
||||
}
|
||||
|
||||
test("rollup/cube #3") {
|
||||
checkHiveQl(
|
||||
"SELECT key, count(value), grouping_id() FROM parquet_t1 GROUP BY key, value WITH ROLLUP")
|
||||
checkHiveQl(
|
||||
"SELECT key, count(value), grouping_id() FROM parquet_t1 GROUP BY key, value WITH CUBE")
|
||||
}
|
||||
|
||||
test("rollup/cube #4") {
|
||||
checkHiveQl(
|
||||
s"""
|
||||
|SELECT count(*) as cnt, key % 5 as k1, key - 5 as k2, grouping_id() FROM parquet_t1
|
||||
|GROUP BY key % 5, key - 5 WITH ROLLUP
|
||||
""".stripMargin)
|
||||
checkHiveQl(
|
||||
s"""
|
||||
|SELECT count(*) as cnt, key % 5 as k1, key - 5 as k2, grouping_id() FROM parquet_t1
|
||||
|GROUP BY key % 5, key - 5 WITH CUBE
|
||||
""".stripMargin)
|
||||
}
|
||||
|
||||
test("rollup/cube #5") {
|
||||
checkHiveQl(
|
||||
s"""
|
||||
|SELECT count(*) AS cnt, key % 5 AS k1, key - 5 AS k2, grouping_id(key % 5, key - 5) AS k3
|
||||
|FROM (SELECT key, key%2, key - 5 FROM parquet_t1) t GROUP BY key%5, key-5
|
||||
|WITH ROLLUP
|
||||
""".stripMargin)
|
||||
checkHiveQl(
|
||||
s"""
|
||||
|SELECT count(*) AS cnt, key % 5 AS k1, key - 5 AS k2, grouping_id(key % 5, key - 5) AS k3
|
||||
|FROM (SELECT key, key % 2, key - 5 FROM parquet_t1) t GROUP BY key % 5, key - 5
|
||||
|WITH CUBE
|
||||
""".stripMargin)
|
||||
}
|
||||
|
||||
test("rollup/cube #6") {
|
||||
checkHiveQl("SELECT a, b, sum(c) FROM parquet_t2 GROUP BY ROLLUP(a, b) ORDER BY a, b")
|
||||
checkHiveQl("SELECT a, b, sum(c) FROM parquet_t2 GROUP BY CUBE(a, b) ORDER BY a, b")
|
||||
checkHiveQl("SELECT a, b, sum(a) FROM parquet_t2 GROUP BY ROLLUP(a, b) ORDER BY a, b")
|
||||
checkHiveQl("SELECT a, b, sum(a) FROM parquet_t2 GROUP BY CUBE(a, b) ORDER BY a, b")
|
||||
checkHiveQl("SELECT a + b, b, sum(a - b) FROM parquet_t2 GROUP BY a + b, b WITH ROLLUP")
|
||||
checkHiveQl("SELECT a + b, b, sum(a - b) FROM parquet_t2 GROUP BY a + b, b WITH CUBE")
|
||||
}
|
||||
|
||||
test("rollup/cube #7") {
|
||||
checkHiveQl("SELECT a, b, grouping_id(a, b) FROM parquet_t2 GROUP BY cube(a, b)")
|
||||
checkHiveQl("SELECT a, b, grouping(b) FROM parquet_t2 GROUP BY cube(a, b)")
|
||||
checkHiveQl("SELECT a, b, grouping(a) FROM parquet_t2 GROUP BY cube(a, b)")
|
||||
}
|
||||
|
||||
test("rollup/cube #8") {
|
||||
// grouping_id() is part of another expression
|
||||
checkHiveQl(
|
||||
s"""
|
||||
|SELECT hkey AS k1, value - 5 AS k2, hash(grouping_id()) AS hgid
|
||||
|FROM (SELECT hash(key) as hkey, key as value FROM parquet_t1) t GROUP BY hkey, value-5
|
||||
|WITH ROLLUP
|
||||
""".stripMargin)
|
||||
checkHiveQl(
|
||||
s"""
|
||||
|SELECT hkey AS k1, value - 5 AS k2, hash(grouping_id()) AS hgid
|
||||
|FROM (SELECT hash(key) as hkey, key as value FROM parquet_t1) t GROUP BY hkey, value-5
|
||||
|WITH CUBE
|
||||
""".stripMargin)
|
||||
}
|
||||
|
||||
test("rollup/cube #9") {
|
||||
// self join is used as the child node of ROLLUP/CUBE with replaced quantifiers
|
||||
checkHiveQl(
|
||||
s"""
|
||||
|SELECT t.key - 5, cnt, SUM(cnt)
|
||||
|FROM (SELECT x.key, COUNT(*) as cnt
|
||||
|FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key GROUP BY x.key) t
|
||||
|GROUP BY cnt, t.key - 5
|
||||
|WITH ROLLUP
|
||||
""".stripMargin)
|
||||
checkHiveQl(
|
||||
s"""
|
||||
|SELECT t.key - 5, cnt, SUM(cnt)
|
||||
|FROM (SELECT x.key, COUNT(*) as cnt
|
||||
|FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key GROUP BY x.key) t
|
||||
|GROUP BY cnt, t.key - 5
|
||||
|WITH CUBE
|
||||
""".stripMargin)
|
||||
}
|
||||
|
||||
test("grouping sets #1") {
|
||||
checkHiveQl(
|
||||
s"""
|
||||
|SELECT count(*) AS cnt, key % 5 AS k1, key - 5 AS k2, grouping_id() AS k3
|
||||
|FROM (SELECT key, key % 2, key - 5 FROM parquet_t1) t GROUP BY key % 5, key - 5
|
||||
|GROUPING SETS (key % 5, key - 5)
|
||||
""".stripMargin)
|
||||
}
|
||||
|
||||
test("grouping sets #2") {
|
||||
checkHiveQl(
|
||||
"SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (a, b) ORDER BY a, b")
|
||||
checkHiveQl(
|
||||
"SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (a) ORDER BY a, b")
|
||||
checkHiveQl(
|
||||
"SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (b) ORDER BY a, b")
|
||||
checkHiveQl(
|
||||
"SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (()) ORDER BY a, b")
|
||||
checkHiveQl(
|
||||
s"""
|
||||
|SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b
|
||||
|GROUPING SETS ((), (a), (a, b)) ORDER BY a, b
|
||||
""".stripMargin)
|
||||
}
|
||||
|
||||
test("cluster by") {
|
||||
checkHiveQl("SELECT id FROM parquet_t0 CLUSTER BY id")
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue