[SPARK-22748][SQL] Analyze __grouping__id as a literal function
### What changes were proposed in this pull request? This PR intends to refactor the logic to resolve `__grouping_id` in the `Analyzer`; it moves the logic from `ResolveFunctions` to `ResolveReferences` (`resolveLiteralFunction`). The original author of this PR is sqlwindspeaker (#30781). Closes #30781. ### Why are the changes needed? Code refactoring. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Added tests in `AnalysisSuite`. Closes #31751 from maropu/SPARK-22748. Authored-by: suqilong <suqilong@qiyi.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
358697b386
commit
ca326c4bb3
|
@ -1838,6 +1838,13 @@ class Analyzer(override val catalogManager: CatalogManager)
|
|||
exprs.exists(_.find(_.isInstanceOf[UnresolvedDeserializer]).isDefined)
|
||||
}
|
||||
|
||||
// support CURRENT_DATE, CURRENT_TIMESTAMP, and grouping__id
|
||||
private val literalFunctions: Seq[(String, () => Expression, Expression => String)] = Seq(
|
||||
(CurrentDate().prettyName, () => CurrentDate(), toPrettySQL(_)),
|
||||
(CurrentTimestamp().prettyName, () => CurrentTimestamp(), toPrettySQL(_)),
|
||||
(VirtualColumn.hiveGroupingIdName, () => GroupingID(Nil), _ => VirtualColumn.hiveGroupingIdName)
|
||||
)
|
||||
|
||||
/**
|
||||
* Literal functions do not require the user to specify braces when calling them
|
||||
* When an attributes is not resolvable, we try to resolve it as a literal function.
|
||||
|
@ -1849,17 +1856,19 @@ class Analyzer(override val catalogManager: CatalogManager)
|
|||
if (nameParts.length != 1) return None
|
||||
val isNamedExpression = plan match {
|
||||
case Aggregate(_, aggregateExpressions, _) => aggregateExpressions.contains(attribute)
|
||||
case GroupingSets(_, _, _, aggregateExpressions) => aggregateExpressions.contains(attribute)
|
||||
case Project(projectList, _) => projectList.contains(attribute)
|
||||
case Window(windowExpressions, _, _, _) => windowExpressions.contains(attribute)
|
||||
case _ => false
|
||||
}
|
||||
val wrapper: Expression => Expression =
|
||||
if (isNamedExpression) f => Alias(f, toPrettySQL(f))() else identity
|
||||
// support CURRENT_DATE and CURRENT_TIMESTAMP
|
||||
val literalFunctions = Seq(CurrentDate(), CurrentTimestamp())
|
||||
val wrapper: (Expression, String) => Expression =
|
||||
if (isNamedExpression) (f, n) => Alias(f, n)() else (f, _) => f
|
||||
val name = nameParts.head
|
||||
val func = literalFunctions.find(e => caseInsensitiveResolution(e.prettyName, name))
|
||||
func.map(wrapper)
|
||||
val func = literalFunctions.find { case (fn, _, _) => caseInsensitiveResolution(fn, name) }
|
||||
func.map { case (_, f, fn) =>
|
||||
val funcExpr = f()
|
||||
wrapper(funcExpr, fn(funcExpr))
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -2218,10 +2227,6 @@ class Analyzer(override val catalogManager: CatalogManager)
|
|||
case q: LogicalPlan =>
|
||||
q transformExpressions {
|
||||
case u if !u.childrenResolved => u // Skip until children are resolved.
|
||||
case u: UnresolvedAttribute if resolver(u.name, VirtualColumn.hiveGroupingIdName) =>
|
||||
withPosition(u) {
|
||||
Alias(GroupingID(Nil), VirtualColumn.hiveGroupingIdName)()
|
||||
}
|
||||
case u @ UnresolvedGenerator(name, children) =>
|
||||
withPosition(u) {
|
||||
v1SessionCatalog.lookupFunction(name, children) match {
|
||||
|
|
|
@ -1036,4 +1036,72 @@ class AnalysisSuite extends AnalysisTest with Matchers {
|
|||
)
|
||||
}
|
||||
}
|
||||
|
||||
test("SPARK-22748: Analyze __grouping__id as a literal function") {
|
||||
assertAnalysisSuccess(parsePlan(
|
||||
"""
|
||||
|SELECT grouping__id FROM (
|
||||
| SELECT grouping__id FROM (
|
||||
| SELECT a, b, count(1), grouping__id FROM TaBlE2
|
||||
| GROUP BY a, b WITH ROLLUP
|
||||
| )
|
||||
|)
|
||||
""".stripMargin), false)
|
||||
|
||||
|
||||
assertAnalysisSuccess(parsePlan(
|
||||
"""
|
||||
|SELECT grouping__id FROM (
|
||||
| SELECT a, b, count(1), grouping__id FROM TaBlE2
|
||||
| GROUP BY a, b WITH CUBE
|
||||
|)
|
||||
""".stripMargin), false)
|
||||
|
||||
assertAnalysisSuccess(parsePlan(
|
||||
"""
|
||||
|SELECT grouping__id FROM (
|
||||
| SELECT a, b, count(1), grouping__id FROM TaBlE2
|
||||
| GROUP BY a, b GROUPING SETS ((a, b), ())
|
||||
|)
|
||||
""".stripMargin), false)
|
||||
|
||||
assertAnalysisSuccess(parsePlan(
|
||||
"""
|
||||
|SELECT a, b, count(1) FROM TaBlE2
|
||||
| GROUP BY CUBE(a, b) HAVING grouping__id > 0
|
||||
""".stripMargin), false)
|
||||
|
||||
assertAnalysisSuccess(parsePlan(
|
||||
"""
|
||||
|SELECT * FROM (
|
||||
| SELECT a, b, count(1) FROM TaBlE2
|
||||
| GROUP BY a, b GROUPING SETS ((a, b), ())
|
||||
|) WHERE grouping__id > 0
|
||||
""".stripMargin), false)
|
||||
|
||||
assertAnalysisSuccess(parsePlan(
|
||||
"""
|
||||
|SELECT * FROM (
|
||||
| SELECT a, b, count(1) FROM TaBlE2
|
||||
| GROUP BY a, b GROUPING SETS ((a, b), ())
|
||||
|) ORDER BY grouping__id > 0
|
||||
""".stripMargin), false)
|
||||
|
||||
assertAnalysisSuccess(parsePlan(
|
||||
"""
|
||||
|SELECT a, b, count(1) FROM TaBlE2
|
||||
| GROUP BY a, b GROUPING SETS ((a, b), ())
|
||||
| ORDER BY grouping__id > 0
|
||||
""".stripMargin), false)
|
||||
|
||||
assertAnalysisError(parsePlan(
|
||||
"""
|
||||
|SELECT grouping__id FROM (
|
||||
| SELECT a, b, count(1), grouping__id FROM TaBlE2
|
||||
| GROUP BY a, b
|
||||
|)
|
||||
""".stripMargin),
|
||||
Seq("grouping_id() can only be used with GroupingSets/Cube/Rollup"),
|
||||
false)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue