[SPARK-12705] [SQL] push missing attributes for Sort

The current implementation of ResolveSortReferences can only push one missing attributes into it's child, it failed to analyze TPCDS Q98, because of there are two missing attributes in that (one from Window, another from Aggregate).

Author: Davies Liu <davies@databricks.com>

Closes #11153 from davies/resolve_sort.
This commit is contained in:
Davies Liu 2016-02-12 09:34:18 -08:00 committed by Davies Liu
parent 64515e5fbf
commit 5b805df279
3 changed files with 67 additions and 83 deletions

View file

@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst.analysis
import scala.annotation.tailrec
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.AnalysisException
@ -598,98 +597,69 @@ class Analyzer(
// Skip sort with aggregate. This will be handled in ResolveAggregateFunctions
case sa @ Sort(_, _, child: Aggregate) => sa
case s @ Sort(_, _, child) if !s.resolved && child.resolved =>
val (newOrdering, missingResolvableAttrs) = collectResolvableMissingAttrs(s.order, child)
if (missingResolvableAttrs.isEmpty) {
val unresolvableAttrs = s.order.filterNot(_.resolved)
logDebug(s"Failed to find $unresolvableAttrs in ${child.output.mkString(", ")}")
s // Nothing we can do here. Return original plan.
} else {
// Add the missing attributes into projectList of Project/Window or
// aggregateExpressions of Aggregate, if they are in the inputSet
// but not in the outputSet of the plan.
val newChild = child transformUp {
case p: Project =>
p.copy(projectList = p.projectList ++
missingResolvableAttrs.filter((p.inputSet -- p.outputSet).contains))
case w: Window =>
w.copy(projectList = w.projectList ++
missingResolvableAttrs.filter((w.inputSet -- w.outputSet).contains))
case a: Aggregate =>
val resolvableAttrs = missingResolvableAttrs.filter(a.groupingExpressions.contains)
val notResolvedAttrs = resolvableAttrs.filterNot(a.aggregateExpressions.contains)
val newAggregateExpressions = a.aggregateExpressions ++ notResolvedAttrs
a.copy(aggregateExpressions = newAggregateExpressions)
case o => o
}
case s @ Sort(order, _, child) if !s.resolved && child.resolved =>
val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder])
val requiredAttrs = AttributeSet(newOrder).filter(_.resolved)
val missingAttrs = requiredAttrs -- child.outputSet
if (missingAttrs.nonEmpty) {
// Add missing attributes and then project them away after the sort.
Project(child.output,
Sort(newOrdering, s.global, newChild))
Sort(newOrder, s.global, addMissingAttr(child, missingAttrs)))
} else if (newOrder != order) {
s.copy(order = newOrder)
} else {
s
}
}
/**
* Traverse the tree until resolving the sorting attributes
* Return all the resolvable missing sorting attributes
*/
@tailrec
private def collectResolvableMissingAttrs(
ordering: Seq[SortOrder],
plan: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
* Add the missing attributes into projectList of Project/Window or aggregateExpressions of
* Aggregate.
*/
private def addMissingAttr(plan: LogicalPlan, missingAttrs: AttributeSet): LogicalPlan = {
if (missingAttrs.isEmpty) {
return plan
}
plan match {
// Only Windows and Project have projectList-like attribute.
case un: UnaryNode if un.isInstanceOf[Project] || un.isInstanceOf[Window] =>
val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, un, un.child)
// If missingAttrs is non empty, that means we got it and return it;
// Otherwise, continue to traverse the tree.
if (missingAttrs.nonEmpty) {
(newOrdering, missingAttrs)
} else {
collectResolvableMissingAttrs(ordering, un.child)
}
case p: Project =>
val missing = missingAttrs -- p.child.outputSet
Project(p.projectList ++ missingAttrs, addMissingAttr(p.child, missing))
case w: Window =>
val missing = missingAttrs -- w.child.outputSet
w.copy(projectList = w.projectList ++ missingAttrs,
child = addMissingAttr(w.child, missing))
case a: Aggregate =>
val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, a, a.child)
// For Aggregate, all the order by columns must be specified in group by clauses
if (missingAttrs.nonEmpty &&
missingAttrs.forall(ar => a.groupingExpressions.exists(_.semanticEquals(ar)))) {
(newOrdering, missingAttrs)
} else {
// If missingAttrs is empty, we are unable to resolve any unresolved missing attributes
(Seq.empty[SortOrder], Seq.empty[Attribute])
// all the missing attributes should be grouping expressions
// TODO: push down AggregateExpression
missingAttrs.foreach { attr =>
if (!a.groupingExpressions.exists(_.semanticEquals(attr))) {
throw new AnalysisException(s"Can't add $attr to ${a.simpleString}")
}
}
// Jump over the following UnaryNode types
// The output of these types is the same as their child's output
case _: Distinct |
_: Filter |
_: RepartitionByExpression =>
collectResolvableMissingAttrs(ordering, plan.asInstanceOf[UnaryNode].child)
// If hitting the other unsupported operators, we are unable to resolve it.
case other => (Seq.empty[SortOrder], Seq.empty[Attribute])
val newAggregateExpressions = a.aggregateExpressions ++ missingAttrs
a.copy(aggregateExpressions = newAggregateExpressions)
case u: UnaryNode =>
u.withNewChildren(addMissingAttr(u.child, missingAttrs) :: Nil)
case other =>
throw new AnalysisException(s"Can't add $missingAttrs to $other")
}
}
/**
* Try to resolve the sort ordering and returns it with a list of attributes that are missing
* from the plan but are present in the child.
*/
private def resolveAndFindMissing(
ordering: Seq[SortOrder],
plan: LogicalPlan,
child: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
val newOrdering =
ordering.map(order => resolveExpression(order, child).asInstanceOf[SortOrder])
// Construct a set that contains all of the attributes that we need to evaluate the
// ordering.
val requiredAttributes = AttributeSet(newOrdering).filter(_.resolved)
// Figure out which ones are missing from the projection, so that we can add them and
// remove them after the sort.
val missingInProject = requiredAttributes -- plan.outputSet
// It is important to return the new SortOrders here, instead of waiting for the standard
// resolving process as adding attributes to the project below can actually introduce
// ambiguity that was not present before.
(newOrdering, missingInProject.toSeq)
* Resolve the expression on a specified logical plan and it's child (recursively), until
* the expression is resolved or meet a non-unary node or Subquery.
*/
private def resolveExpressionRecursively(expr: Expression, plan: LogicalPlan): Expression = {
val resolved = resolveExpression(expr, plan)
if (resolved.resolved) {
resolved
} else {
plan match {
case u: UnaryNode if !u.isInstanceOf[Subquery] =>
resolveExpressionRecursively(resolved, u.child)
case other => resolved
}
}
}
}
@ -782,8 +752,7 @@ class Analyzer(
filter
}
case sort @ Sort(sortOrder, global, aggregate: Aggregate)
if aggregate.resolved =>
case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved =>
// Try resolving the ordering as though it is in the aggregate clause.
try {

View file

@ -90,7 +90,7 @@ class AnalysisSuite extends AnalysisTest {
.where(a > "str").select(a, b, c)
.where(b > "str").select(a, b, c)
.sortBy(b.asc, c.desc)
.select(a, b).select(a)
.select(a)
checkAnalysis(plan1, expected1)
// Case 2: all the missing attributes are in the leaf node

View file

@ -978,6 +978,21 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
("d", 1),
("c", 2)
).map(i => Row(i._1, i._2)))
checkAnswer(
sql(
"""
|select area, sum(product) / sum(sum(product)) over (partition by area) as c1
|from windowData group by area, month order by month, c1
""".stripMargin),
Seq(
("d", 1.0),
("a", 1.0),
("b", 0.4666666666666667),
("b", 0.5333333333333333),
("c", 0.45),
("c", 0.55)
).map(i => Row(i._1, i._2)))
}
// todo: fix this test case by reimplementing the function ResolveAggregateFunctions