[SPARK-22748][SQL] Analyze __grouping__id as a literal function

### What changes were proposed in this pull request?

This PR intends to refactor the logic to resolve `__grouping_id` in the `Analyzer`; it moves the logic from `ResolveFunctions` to `ResolveReferences` (`resolveLiteralFunction`).

The original author of this PR is sqlwindspeaker (#30781).

Closes #30781.

### Why are the changes needed?

Code refactoring.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added tests in `AnalysisSuite`.

Closes #31751 from maropu/SPARK-22748.

Authored-by: suqilong <suqilong@qiyi.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
suqilong 2021-03-05 07:40:58 +00:00 committed by Wenchen Fan
parent 358697b386
commit ca326c4bb3
2 changed files with 83 additions and 10 deletions

View file

@ -1838,6 +1838,13 @@ class Analyzer(override val catalogManager: CatalogManager)
exprs.exists(_.find(_.isInstanceOf[UnresolvedDeserializer]).isDefined)
}
// support CURRENT_DATE, CURRENT_TIMESTAMP, and grouping__id
private val literalFunctions: Seq[(String, () => Expression, Expression => String)] = Seq(
(CurrentDate().prettyName, () => CurrentDate(), toPrettySQL(_)),
(CurrentTimestamp().prettyName, () => CurrentTimestamp(), toPrettySQL(_)),
(VirtualColumn.hiveGroupingIdName, () => GroupingID(Nil), _ => VirtualColumn.hiveGroupingIdName)
)
/**
* Literal functions do not require the user to specify braces when calling them
* When an attributes is not resolvable, we try to resolve it as a literal function.
@ -1849,17 +1856,19 @@ class Analyzer(override val catalogManager: CatalogManager)
if (nameParts.length != 1) return None
val isNamedExpression = plan match {
case Aggregate(_, aggregateExpressions, _) => aggregateExpressions.contains(attribute)
case GroupingSets(_, _, _, aggregateExpressions) => aggregateExpressions.contains(attribute)
case Project(projectList, _) => projectList.contains(attribute)
case Window(windowExpressions, _, _, _) => windowExpressions.contains(attribute)
case _ => false
}
val wrapper: Expression => Expression =
if (isNamedExpression) f => Alias(f, toPrettySQL(f))() else identity
// support CURRENT_DATE and CURRENT_TIMESTAMP
val literalFunctions = Seq(CurrentDate(), CurrentTimestamp())
val wrapper: (Expression, String) => Expression =
if (isNamedExpression) (f, n) => Alias(f, n)() else (f, _) => f
val name = nameParts.head
val func = literalFunctions.find(e => caseInsensitiveResolution(e.prettyName, name))
func.map(wrapper)
val func = literalFunctions.find { case (fn, _, _) => caseInsensitiveResolution(fn, name) }
func.map { case (_, f, fn) =>
val funcExpr = f()
wrapper(funcExpr, fn(funcExpr))
}
}
/**
@ -2218,10 +2227,6 @@ class Analyzer(override val catalogManager: CatalogManager)
case q: LogicalPlan =>
q transformExpressions {
case u if !u.childrenResolved => u // Skip until children are resolved.
case u: UnresolvedAttribute if resolver(u.name, VirtualColumn.hiveGroupingIdName) =>
withPosition(u) {
Alias(GroupingID(Nil), VirtualColumn.hiveGroupingIdName)()
}
case u @ UnresolvedGenerator(name, children) =>
withPosition(u) {
v1SessionCatalog.lookupFunction(name, children) match {

View file

@ -1036,4 +1036,72 @@ class AnalysisSuite extends AnalysisTest with Matchers {
)
}
}
test("SPARK-22748: Analyze __grouping__id as a literal function") {
assertAnalysisSuccess(parsePlan(
"""
|SELECT grouping__id FROM (
| SELECT grouping__id FROM (
| SELECT a, b, count(1), grouping__id FROM TaBlE2
| GROUP BY a, b WITH ROLLUP
| )
|)
""".stripMargin), false)
assertAnalysisSuccess(parsePlan(
"""
|SELECT grouping__id FROM (
| SELECT a, b, count(1), grouping__id FROM TaBlE2
| GROUP BY a, b WITH CUBE
|)
""".stripMargin), false)
assertAnalysisSuccess(parsePlan(
"""
|SELECT grouping__id FROM (
| SELECT a, b, count(1), grouping__id FROM TaBlE2
| GROUP BY a, b GROUPING SETS ((a, b), ())
|)
""".stripMargin), false)
assertAnalysisSuccess(parsePlan(
"""
|SELECT a, b, count(1) FROM TaBlE2
| GROUP BY CUBE(a, b) HAVING grouping__id > 0
""".stripMargin), false)
assertAnalysisSuccess(parsePlan(
"""
|SELECT * FROM (
| SELECT a, b, count(1) FROM TaBlE2
| GROUP BY a, b GROUPING SETS ((a, b), ())
|) WHERE grouping__id > 0
""".stripMargin), false)
assertAnalysisSuccess(parsePlan(
"""
|SELECT * FROM (
| SELECT a, b, count(1) FROM TaBlE2
| GROUP BY a, b GROUPING SETS ((a, b), ())
|) ORDER BY grouping__id > 0
""".stripMargin), false)
assertAnalysisSuccess(parsePlan(
"""
|SELECT a, b, count(1) FROM TaBlE2
| GROUP BY a, b GROUPING SETS ((a, b), ())
| ORDER BY grouping__id > 0
""".stripMargin), false)
assertAnalysisError(parsePlan(
"""
|SELECT grouping__id FROM (
| SELECT a, b, count(1), grouping__id FROM TaBlE2
| GROUP BY a, b
|)
""".stripMargin),
Seq("grouping_id() can only be used with GroupingSets/Cube/Rollup"),
false)
}
}