[SPARK-13221] [SQL] Fixing GroupingSets when Aggregate Functions Containing GroupBy Columns
Using GroupingSets will generate a wrong result when Aggregate Functions containing GroupBy columns. This PR is to fix it. Since the code changes are very small. Maybe we also can merge it to 1.6 For example, the following query returns a wrong result: ```scala sql("select course, sum(earnings) as sum from courseSales group by course, earnings" + " grouping sets((), (course), (course, earnings))" + " order by course, sum").show() ``` Before the fix, the results are like ``` [null,null] [Java,null] [Java,20000.0] [Java,30000.0] [dotNET,null] [dotNET,5000.0] [dotNET,10000.0] [dotNET,48000.0] ``` After the fix, the results become correct: ``` [null,113000.0] [Java,20000.0] [Java,30000.0] [Java,50000.0] [dotNET,5000.0] [dotNET,10000.0] [dotNET,48000.0] [dotNET,63000.0] ``` UPDATE: This PR also deprecated the external column: GROUPING__ID. Author: gatorsmile <gatorsmile@gmail.com> Closes #11100 from gatorsmile/groupingSets.
This commit is contained in:
parent
e4675c2402
commit
fee739f07b
|
@ -209,13 +209,23 @@ class Analyzer(
|
|||
Seq.tabulate(1 << c.groupByExprs.length)(i => i)
|
||||
}
|
||||
|
||||
private def hasGroupingId(expr: Seq[Expression]): Boolean = {
|
||||
expr.exists(_.collectFirst {
|
||||
case u: UnresolvedAttribute if resolver(u.name, VirtualColumn.groupingIdName) => u
|
||||
}.isDefined)
|
||||
}
|
||||
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
|
||||
case a if !a.childrenResolved => a // be sure all of the children are resolved.
|
||||
case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) =>
|
||||
GroupingSets(bitmasks(c), groupByExprs, child, aggregateExpressions)
|
||||
case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) =>
|
||||
GroupingSets(bitmasks(r), groupByExprs, child, aggregateExpressions)
|
||||
case x: GroupingSets =>
|
||||
case g: GroupingSets if g.expressions.exists(!_.resolved) && hasGroupingId(g.expressions) =>
|
||||
failAnalysis(
|
||||
s"${VirtualColumn.groupingIdName} is deprecated; use grouping_id() instead")
|
||||
// Ensure all the expressions have been resolved.
|
||||
case x: GroupingSets if x.expressions.forall(_.resolved) =>
|
||||
val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()
|
||||
|
||||
// Expand works by setting grouping expressions to null as determined by the bitmasks. To
|
||||
|
|
|
@ -2040,6 +2040,36 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
|
|||
)
|
||||
}
|
||||
|
||||
test("grouping sets when aggregate functions containing groupBy columns") {
|
||||
checkAnswer(
|
||||
sql("select course, sum(earnings) as sum from courseSales group by course, earnings " +
|
||||
"grouping sets((), (course), (course, earnings)) " +
|
||||
"order by course, sum"),
|
||||
Row(null, 113000.0) ::
|
||||
Row("Java", 20000.0) ::
|
||||
Row("Java", 30000.0) ::
|
||||
Row("Java", 50000.0) ::
|
||||
Row("dotNET", 5000.0) ::
|
||||
Row("dotNET", 10000.0) ::
|
||||
Row("dotNET", 48000.0) ::
|
||||
Row("dotNET", 63000.0) :: Nil
|
||||
)
|
||||
|
||||
checkAnswer(
|
||||
sql("select course, sum(earnings) as sum, grouping_id(course, earnings) from courseSales " +
|
||||
"group by course, earnings grouping sets((), (course), (course, earnings)) " +
|
||||
"order by course, sum"),
|
||||
Row(null, 113000.0, 3) ::
|
||||
Row("Java", 20000.0, 0) ::
|
||||
Row("Java", 30000.0, 0) ::
|
||||
Row("Java", 50000.0, 1) ::
|
||||
Row("dotNET", 5000.0, 0) ::
|
||||
Row("dotNET", 10000.0, 0) ::
|
||||
Row("dotNET", 48000.0, 0) ::
|
||||
Row("dotNET", 63000.0, 1) :: Nil
|
||||
)
|
||||
}
|
||||
|
||||
test("cube") {
|
||||
checkAnswer(
|
||||
sql("select course, year, sum(earnings) from courseSales group by cube(course, year)"),
|
||||
|
@ -2103,6 +2133,10 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
|
|||
sql("select course, year, grouping_id(course, year) from courseSales group by course, year")
|
||||
}
|
||||
assert(error.getMessage contains "grouping_id() can only be used with GroupingSets/Cube/Rollup")
|
||||
error = intercept[AnalysisException] {
|
||||
sql("select course, year, grouping__id from courseSales group by cube(course, year)")
|
||||
}
|
||||
assert(error.getMessage contains "grouping__id is deprecated; use grouping_id() instead")
|
||||
}
|
||||
|
||||
test("SPARK-13056: Null in map value causes NPE") {
|
||||
|
|
|
@ -1,6 +0,0 @@
|
|||
500 NULL 1
|
||||
91 0 0
|
||||
84 1 0
|
||||
105 2 0
|
||||
113 3 0
|
||||
107 4 0
|
|
@ -1,10 +0,0 @@
|
|||
1 NULL -3 2
|
||||
1 NULL -1 2
|
||||
1 NULL 3 2
|
||||
1 NULL 4 2
|
||||
1 NULL 5 2
|
||||
1 NULL 6 2
|
||||
1 NULL 12 2
|
||||
1 NULL 14 2
|
||||
1 NULL 15 2
|
||||
1 NULL 22 2
|
|
@ -1,10 +0,0 @@
|
|||
1 NULL -3 2
|
||||
1 NULL -1 2
|
||||
1 NULL 3 2
|
||||
1 NULL 4 2
|
||||
1 NULL 5 2
|
||||
1 NULL 6 2
|
||||
1 NULL 12 2
|
||||
1 NULL 14 2
|
||||
1 NULL 15 2
|
||||
1 NULL 22 2
|
|
@ -1,6 +0,0 @@
|
|||
500 NULL 1
|
||||
91 0 0
|
||||
84 1 0
|
||||
105 2 0
|
||||
113 3 0
|
||||
107 4 0
|
|
@ -1,10 +0,0 @@
|
|||
1 0 5 0
|
||||
1 0 15 0
|
||||
1 0 25 0
|
||||
1 0 60 0
|
||||
1 0 75 0
|
||||
1 0 80 0
|
||||
1 0 100 0
|
||||
1 0 140 0
|
||||
1 0 145 0
|
||||
1 0 150 0
|
|
@ -1,10 +0,0 @@
|
|||
1 0 5 0
|
||||
1 0 15 0
|
||||
1 0 25 0
|
||||
1 0 60 0
|
||||
1 0 75 0
|
||||
1 0 80 0
|
||||
1 0 100 0
|
||||
1 0 140 0
|
||||
1 0 145 0
|
||||
1 0 150 0
|
|
@ -123,60 +123,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
|
|||
assertBroadcastNestedLoopJoin(spark_10484_4)
|
||||
}
|
||||
|
||||
createQueryTest("SPARK-8976 Wrong Result for Rollup #1",
|
||||
"""
|
||||
SELECT count(*) AS cnt, key % 5,GROUPING__ID FROM src group by key%5 WITH ROLLUP
|
||||
""".stripMargin)
|
||||
|
||||
createQueryTest("SPARK-8976 Wrong Result for Rollup #2",
|
||||
"""
|
||||
SELECT
|
||||
count(*) AS cnt,
|
||||
key % 5 as k1,
|
||||
key-5 as k2,
|
||||
GROUPING__ID as k3
|
||||
FROM src group by key%5, key-5
|
||||
WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10
|
||||
""".stripMargin)
|
||||
|
||||
createQueryTest("SPARK-8976 Wrong Result for Rollup #3",
|
||||
"""
|
||||
SELECT
|
||||
count(*) AS cnt,
|
||||
key % 5 as k1,
|
||||
key-5 as k2,
|
||||
GROUPING__ID as k3
|
||||
FROM (SELECT key, key%2, key - 5 FROM src) t group by key%5, key-5
|
||||
WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10
|
||||
""".stripMargin)
|
||||
|
||||
createQueryTest("SPARK-8976 Wrong Result for CUBE #1",
|
||||
"""
|
||||
SELECT count(*) AS cnt, key % 5,GROUPING__ID FROM src group by key%5 WITH CUBE
|
||||
""".stripMargin)
|
||||
|
||||
createQueryTest("SPARK-8976 Wrong Result for CUBE #2",
|
||||
"""
|
||||
SELECT
|
||||
count(*) AS cnt,
|
||||
key % 5 as k1,
|
||||
key-5 as k2,
|
||||
GROUPING__ID as k3
|
||||
FROM (SELECT key, key%2, key - 5 FROM src) t group by key%5, key-5
|
||||
WITH CUBE ORDER BY cnt, k1, k2, k3 LIMIT 10
|
||||
""".stripMargin)
|
||||
|
||||
createQueryTest("SPARK-8976 Wrong Result for GroupingSet",
|
||||
"""
|
||||
SELECT
|
||||
count(*) AS cnt,
|
||||
key % 5 as k1,
|
||||
key-5 as k2,
|
||||
GROUPING__ID as k3
|
||||
FROM (SELECT key, key%2, key - 5 FROM src) t group by key%5, key-5
|
||||
GROUPING SETS (key%5, key-5) ORDER BY cnt, k1, k2, k3 LIMIT 10
|
||||
""".stripMargin)
|
||||
|
||||
createQueryTest("insert table with generator with column name",
|
||||
"""
|
||||
| CREATE TABLE gen_tmp (key Int);
|
||||
|
|
|
@ -1551,6 +1551,116 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
|
|||
}
|
||||
}
|
||||
|
||||
test("SPARK-8976 Wrong Result for Rollup #1") {
|
||||
checkAnswer(sql(
|
||||
"SELECT count(*) AS cnt, key % 5, grouping_id() FROM src GROUP BY key%5 WITH ROLLUP"),
|
||||
Seq(
|
||||
(113, 3, 0),
|
||||
(91, 0, 0),
|
||||
(500, null, 1),
|
||||
(84, 1, 0),
|
||||
(105, 2, 0),
|
||||
(107, 4, 0)
|
||||
).map(i => Row(i._1, i._2, i._3)))
|
||||
}
|
||||
|
||||
test("SPARK-8976 Wrong Result for Rollup #2") {
|
||||
checkAnswer(sql(
|
||||
"""
|
||||
|SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3
|
||||
|FROM src GROUP BY key%5, key-5
|
||||
|WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10
|
||||
""".stripMargin),
|
||||
Seq(
|
||||
(1, 0, 5, 0),
|
||||
(1, 0, 15, 0),
|
||||
(1, 0, 25, 0),
|
||||
(1, 0, 60, 0),
|
||||
(1, 0, 75, 0),
|
||||
(1, 0, 80, 0),
|
||||
(1, 0, 100, 0),
|
||||
(1, 0, 140, 0),
|
||||
(1, 0, 145, 0),
|
||||
(1, 0, 150, 0)
|
||||
).map(i => Row(i._1, i._2, i._3, i._4)))
|
||||
}
|
||||
|
||||
test("SPARK-8976 Wrong Result for Rollup #3") {
|
||||
checkAnswer(sql(
|
||||
"""
|
||||
|SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3
|
||||
|FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5
|
||||
|WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10
|
||||
""".stripMargin),
|
||||
Seq(
|
||||
(1, 0, 5, 0),
|
||||
(1, 0, 15, 0),
|
||||
(1, 0, 25, 0),
|
||||
(1, 0, 60, 0),
|
||||
(1, 0, 75, 0),
|
||||
(1, 0, 80, 0),
|
||||
(1, 0, 100, 0),
|
||||
(1, 0, 140, 0),
|
||||
(1, 0, 145, 0),
|
||||
(1, 0, 150, 0)
|
||||
).map(i => Row(i._1, i._2, i._3, i._4)))
|
||||
}
|
||||
|
||||
test("SPARK-8976 Wrong Result for CUBE #1") {
|
||||
checkAnswer(sql(
|
||||
"SELECT count(*) AS cnt, key % 5, grouping_id() FROM src GROUP BY key%5 WITH CUBE"),
|
||||
Seq(
|
||||
(113, 3, 0),
|
||||
(91, 0, 0),
|
||||
(500, null, 1),
|
||||
(84, 1, 0),
|
||||
(105, 2, 0),
|
||||
(107, 4, 0)
|
||||
).map(i => Row(i._1, i._2, i._3)))
|
||||
}
|
||||
|
||||
test("SPARK-8976 Wrong Result for CUBE #2") {
|
||||
checkAnswer(sql(
|
||||
"""
|
||||
|SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3
|
||||
|FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5
|
||||
|WITH CUBE ORDER BY cnt, k1, k2, k3 LIMIT 10
|
||||
""".stripMargin),
|
||||
Seq(
|
||||
(1, null, -3, 2),
|
||||
(1, null, -1, 2),
|
||||
(1, null, 3, 2),
|
||||
(1, null, 4, 2),
|
||||
(1, null, 5, 2),
|
||||
(1, null, 6, 2),
|
||||
(1, null, 12, 2),
|
||||
(1, null, 14, 2),
|
||||
(1, null, 15, 2),
|
||||
(1, null, 22, 2)
|
||||
).map(i => Row(i._1, i._2, i._3, i._4)))
|
||||
}
|
||||
|
||||
test("SPARK-8976 Wrong Result for GroupingSet") {
|
||||
checkAnswer(sql(
|
||||
"""
|
||||
|SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3
|
||||
|FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5
|
||||
|GROUPING SETS (key%5, key-5) ORDER BY cnt, k1, k2, k3 LIMIT 10
|
||||
""".stripMargin),
|
||||
Seq(
|
||||
(1, null, -3, 2),
|
||||
(1, null, -1, 2),
|
||||
(1, null, 3, 2),
|
||||
(1, null, 4, 2),
|
||||
(1, null, 5, 2),
|
||||
(1, null, 6, 2),
|
||||
(1, null, 12, 2),
|
||||
(1, null, 14, 2),
|
||||
(1, null, 15, 2),
|
||||
(1, null, 22, 2)
|
||||
).map(i => Row(i._1, i._2, i._3, i._4)))
|
||||
}
|
||||
|
||||
test("SPARK-10562: partition by column with mixed case name") {
|
||||
withTable("tbl10562") {
|
||||
val df = Seq(2012 -> "a").toDF("Year", "val")
|
||||
|
|
Loading…
Reference in a new issue