[SPARK-18137][SQL] Fix RewriteDistinctAggregates UnresolvedException when a UDAF has a foldable TypeCheck
## What changes were proposed in this pull request? In RewriteDistinctAggregates rewrite funtion,after the UDAF's childs are mapped to AttributeRefference, If the UDAF(such as ApproximatePercentile) has a foldable TypeCheck for the input, It will failed because the AttributeRefference is not foldable,then the UDAF is not resolved, and then nullify on the unresolved object will throw a Exception. In this PR, only map Unfoldable child to AttributeRefference, this can avoid the UDAF's foldable TypeCheck. and then only Expand Unfoldable child, there is no need to Expand a static value(foldable value). **Before sql result** > select percentile_approxy(key,0.99999),count(distinct key),sume(distinc key) from src limit 1 > org.apache.spark.sql.catalyst.analysis.UnresolvedException: Invalid call to dataType on unresolved object, tree: 'percentile_approx(CAST(src.`key` AS DOUBLE), CAST(0.99999BD AS DOUBLE), 10000) > at org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute.dataType(unresolved.scala:92) > at org.apache.spark.sql.catalyst.optimizer.RewriteDistinctAggregates$.org$apache$spark$sql$catalyst$optimizer$RewriteDistinctAggregates$$nullify(RewriteDistinctAggregates.scala:261) **After sql result** > select percentile_approxy(key,0.99999),count(distinct key),sume(distinc key) from src limit 1 > [498.0,309,79136] ## How was this patch tested? Add a test case in HiveUDFSuit. Author: root <root@iZbp1gsnrlfzjxh82cz80vZ.(none)> Closes #15668 from windpiger/RewriteDistinctUDAFUnresolveExcep.
This commit is contained in:
parent
47731e1865
commit
c291bd2745
|
@ -115,9 +115,21 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
|
|||
}
|
||||
|
||||
// Extract distinct aggregate expressions.
|
||||
val distinctAggGroups = aggExpressions
|
||||
.filter(_.isDistinct)
|
||||
.groupBy(_.aggregateFunction.children.toSet)
|
||||
val distinctAggGroups = aggExpressions.filter(_.isDistinct).groupBy { e =>
|
||||
val unfoldableChildren = e.aggregateFunction.children.filter(!_.foldable).toSet
|
||||
if (unfoldableChildren.nonEmpty) {
|
||||
// Only expand the unfoldable children
|
||||
unfoldableChildren
|
||||
} else {
|
||||
// If aggregateFunction's children are all foldable
|
||||
// we must expand at least one of the children (here we take the first child),
|
||||
// or If we don't, we will get the wrong result, for example:
|
||||
// count(distinct 1) will be explained to count(1) after the rewrite function.
|
||||
// Generally, the distinct aggregateFunction should not run
|
||||
// foldable TypeCheck for the first child.
|
||||
e.aggregateFunction.children.take(1).toSet
|
||||
}
|
||||
}
|
||||
|
||||
// Check if the aggregates contains functions that do not support partial aggregation.
|
||||
val existsNonPartial = aggExpressions.exists(!_.aggregateFunction.supportsPartial)
|
||||
|
@ -136,8 +148,9 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
|
|||
def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e))
|
||||
def patchAggregateFunctionChildren(
|
||||
af: AggregateFunction)(
|
||||
attrs: Expression => Expression): AggregateFunction = {
|
||||
af.withNewChildren(af.children.map(attrs)).asInstanceOf[AggregateFunction]
|
||||
attrs: Expression => Option[Expression]): AggregateFunction = {
|
||||
val newChildren = af.children.map(c => attrs(c).getOrElse(c))
|
||||
af.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
|
||||
}
|
||||
|
||||
// Setup unique distinct aggregate children.
|
||||
|
@ -161,7 +174,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
|
|||
val operators = expressions.map { e =>
|
||||
val af = e.aggregateFunction
|
||||
val naf = patchAggregateFunctionChildren(af) { x =>
|
||||
evalWithinGroup(id, distinctAggChildAttrLookup(x))
|
||||
distinctAggChildAttrLookup.get(x).map(evalWithinGroup(id, _))
|
||||
}
|
||||
(e, e.copy(aggregateFunction = naf, isDistinct = false))
|
||||
}
|
||||
|
@ -170,8 +183,12 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
|
|||
}
|
||||
|
||||
// Setup expand for the 'regular' aggregate expressions.
|
||||
val regularAggExprs = aggExpressions.filter(!_.isDistinct)
|
||||
val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct
|
||||
// only expand unfoldable children
|
||||
val regularAggExprs = aggExpressions
|
||||
.filter(e => !e.isDistinct && e.children.exists(!_.foldable))
|
||||
val regularAggChildren = regularAggExprs
|
||||
.flatMap(_.aggregateFunction.children.filter(!_.foldable))
|
||||
.distinct
|
||||
val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair)
|
||||
|
||||
// Setup aggregates for 'regular' aggregate expressions.
|
||||
|
@ -179,7 +196,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
|
|||
val regularAggChildAttrLookup = regularAggChildAttrMap.toMap
|
||||
val regularAggOperatorMap = regularAggExprs.map { e =>
|
||||
// Perform the actual aggregation in the initial aggregate.
|
||||
val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup)
|
||||
val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup.get)
|
||||
val operator = Alias(e.copy(aggregateFunction = af), e.sql)()
|
||||
|
||||
// Select the result of the first aggregate in the last aggregate.
|
||||
|
|
|
@ -150,6 +150,41 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
|
|||
}
|
||||
|
||||
test("Generic UDAF aggregates") {
|
||||
|
||||
checkAnswer(sql(
|
||||
"""
|
||||
|SELECT percentile_approx(2, 0.99999),
|
||||
| sum(distinct 1),
|
||||
| count(distinct 1,2,3,4) FROM src LIMIT 1
|
||||
""".stripMargin), sql("SELECT 2, 1, 1 FROM src LIMIT 1").collect().toSeq)
|
||||
|
||||
checkAnswer(sql(
|
||||
"""
|
||||
|SELECT ceiling(percentile_approx(distinct key, 0.99999)),
|
||||
| count(distinct key),
|
||||
| sum(distinct key),
|
||||
| count(distinct 1),
|
||||
| sum(distinct 1),
|
||||
| sum(1) FROM src LIMIT 1
|
||||
""".stripMargin),
|
||||
sql(
|
||||
"""
|
||||
|SELECT max(key),
|
||||
| count(distinct key),
|
||||
| sum(distinct key),
|
||||
| 1, 1, sum(1) FROM src LIMIT 1
|
||||
""".stripMargin).collect().toSeq)
|
||||
|
||||
checkAnswer(sql(
|
||||
"""
|
||||
|SELECT ceiling(percentile_approx(distinct key, 0.9 + 0.09999)),
|
||||
| count(distinct key), sum(distinct key),
|
||||
| count(distinct 1), sum(distinct 1),
|
||||
| sum(1) FROM src LIMIT 1
|
||||
""".stripMargin),
|
||||
sql("SELECT max(key), count(distinct key), sum(distinct key), 1, 1, sum(1) FROM src LIMIT 1")
|
||||
.collect().toSeq)
|
||||
|
||||
checkAnswer(sql("SELECT ceiling(percentile_approx(key, 0.99999D)) FROM src LIMIT 1"),
|
||||
sql("SELECT max(key) FROM src LIMIT 1").collect().toSeq)
|
||||
|
||||
|
|
Loading…
Reference in a new issue