[SPARK-17653][SQL] Remove unnecessary distincts in multiple unions

## What changes were proposed in this pull request?

Currently for `Union [Distinct]`, a `Distinct` operator is necessary on top of `Union`. Once there are adjacent `Union [Distinct]` operators, there will be multiple `Distinct` operators in the query plan.

E.g.,

For a query like: select 1 a union select 2 b union select 3 c

Before this patch, its physical plan looks like:

    *HashAggregate(keys=[a#13], functions=[])
    +- Exchange hashpartitioning(a#13, 200)
       +- *HashAggregate(keys=[a#13], functions=[])
          +- Union
             :- *HashAggregate(keys=[a#13], functions=[])
             :  +- Exchange hashpartitioning(a#13, 200)
             :     +- *HashAggregate(keys=[a#13], functions=[])
             :        +- Union
             :           :- *Project [1 AS a#13]
             :           :  +- Scan OneRowRelation[]
             :           +- *Project [2 AS b#14]
             :              +- Scan OneRowRelation[]
             +- *Project [3 AS c#15]
                +- Scan OneRowRelation[]

Only the top distinct should be necessary.

After this patch, the physical plan looks like:

    *HashAggregate(keys=[a#221], functions=[], output=[a#221])
    +- Exchange hashpartitioning(a#221, 5)
       +- *HashAggregate(keys=[a#221], functions=[], output=[a#221])
          +- Union
             :- *Project [1 AS a#221]
             :  +- Scan OneRowRelation[]
             :- *Project [2 AS b#222]
             :  +- Scan OneRowRelation[]
             +- *Project [3 AS c#223]
                +- Scan OneRowRelation[]

## How was this patch tested?

Jenkins tests.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #15238 from viirya/remove-extra-distinct-union.
This commit is contained in:
Liang-Chi Hsieh 2016-09-29 14:30:23 -07:00 committed by Herman van Hovell
parent fe33121a53
commit 566d7f2827
3 changed files with 89 additions and 30 deletions

View file

@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
import scala.annotation.tailrec
import scala.collection.immutable.HashSet
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.api.java.function.FilterFunction
@@ -29,7 +30,7 @@ import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, Unions}
import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
@@ -579,8 +580,25 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe
* Combines all adjacent [[Union]] operators into a single [[Union]].
*/
object CombineUnions extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
    case u: Union => flattenUnion(u, flattenDistinct = false)
    case Distinct(u: Union) => Distinct(flattenUnion(u, flattenDistinct = true))
  }

  /**
   * Flattens a tree of nested [[Union]] operators into a single [[Union]] over
   * all of their leaf inputs, preserving the original left-to-right order.
   *
   * @param flattenDistinct when true, intermediate `Distinct(Union(...))` nodes are
   *                        also flattened (safe only when a `Distinct` already sits
   *                        on top of the result and will remove duplicates anyway).
   */
  private def flattenUnion(union: Union, flattenDistinct: Boolean): Union = {
    // Iterative preorder walk; an immutable List used as a stack.  Children are
    // prepended in order so that the leftmost input is visited (and emitted) first.
    var pending: List[LogicalPlan] = union :: Nil
    val leaves = mutable.ArrayBuffer.empty[LogicalPlan]
    while (pending.nonEmpty) {
      val current = pending.head
      pending = pending.tail
      current match {
        case Distinct(Union(children)) if flattenDistinct =>
          pending = children.toList ++ pending
        case Union(children) =>
          pending = children.toList ++ pending
        case leaf =>
          leaves += leaf
      }
    }
    Union(leaves)
  }
}

View file

@@ -188,33 +188,6 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper {
}
}
/**
 * A pattern that collects the children of all adjacent [[Union]] operators and
 * returns them as a single flat Seq.
 */
object Unions {
  def unapply(plan: LogicalPlan): Option[Seq[LogicalPlan]] = plan match {
    case u: Union => Some(collectUnionChildren(mutable.Stack(u), Seq.empty[LogicalPlan]))
    case _ => None
  }

  // Depth-first, left-to-right traversal accumulating every non-Union descendant
  // reachable through a chain of Union nodes.
  @tailrec
  private def collectUnionChildren(
      plans: mutable.Stack[LogicalPlan],
      children: Seq[LogicalPlan]): Seq[LogicalPlan] = {
    if (plans.isEmpty) {
      children
    } else {
      plans.pop() match {
        case Union(grandchildren) =>
          // Push in reverse so the leftmost grandchild is popped (and emitted) first.
          grandchildren.reverse.foreach(plans.push(_))
          collectUnionChildren(plans, children)
        case other =>
          collectUnionChildren(plans, children :+ other)
      }
    }
  }
}
/**
* An extractor used when planning the physical execution of an aggregation. Compared with a logical
* aggregation, the following transformations are performed:

View file

@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
@@ -76,4 +77,71 @@ class SetOperationSuite extends PlanTest {
testRelation3.select('g) :: Nil).analyze
comparePlans(unionOptimized, unionCorrectAnswer)
}
test("Remove unnecessary distincts in multiple unions") {
  // Three single-row inputs used to build the union plans below.
  val query1 = OneRowRelation.select(Literal(1).as('a))
  val query2 = OneRowRelation.select(Literal(2).as('b))
  val query3 = OneRowRelation.select(Literal(3).as('c))

  // A distinct union nested under another distinct union:
  //     query3
  //     |
  // D - U - D - U - query1
  //             |
  //             query2
  // collapses into a single Distinct over one flat Union.
  val unionQuery1 = Distinct(Union(Distinct(Union(query1, query2)), query3)).analyze
  val expectedAnswer1 = Distinct(Union(query1 :: query2 :: query3 :: Nil)).analyze
  comparePlans(expectedAnswer1, Optimize.execute(unionQuery1))

  // A plain union and a distinct union both nested under a distinct union:
  //     query1
  //     |
  // D - U - U - query2
  //     |
  //     D - U - query2
  //         |
  //         query3
  // also flattens; the duplicate query2 input is kept, since the top-level
  // Distinct removes duplicate rows anyway.
  val unionQuery2 = Distinct(Union(Union(query1, query2),
    Distinct(Union(query2, query3)))).analyze
  val expectedAnswer2 = Distinct(Union(query1 :: query2 :: query2 :: query3 :: Nil)).analyze
  comparePlans(expectedAnswer2, Optimize.execute(unionQuery2))
}
test("Keep necessary distincts in multiple unions") {
  // Four single-row inputs used to build the union plans below.
  val query1 = OneRowRelation.select(Literal(1).as('a))
  val query2 = OneRowRelation.select(Literal(2).as('b))
  val query3 = OneRowRelation.select(Literal(3).as('c))
  val query4 = OneRowRelation.select(Literal(4).as('d))

  // A distinct union nested under a plain (UNION ALL) union:
  //     query3
  //     |
  // U - D - U - query1
  //         |
  //         query2
  // The inner Distinct must survive: without a Distinct on top, flattening it
  // would change the result.
  val unionQuery1 = Union(Distinct(Union(query1, query2)), query3).analyze
  val expectedAnswer1 =
    Union(Distinct(Union(query1 :: query2 :: Nil)) :: query3 :: Nil).analyze
  comparePlans(expectedAnswer1, Optimize.execute(unionQuery1))

  // Two distinct unions nested under a plain union:
  //     query1
  //     |
  // U - D - U - query2
  // |
  // D - U - query3
  //     |
  //     query4
  // Both inner Distincts must survive.
  val unionQuery2 =
    Union(Distinct(Union(query1, query2)), Distinct(Union(query3, query4))).analyze
  val expectedAnswer2 =
    Union(Distinct(Union(query1 :: query2 :: Nil)),
      Distinct(Union(query3 :: query4 :: Nil))).analyze
  comparePlans(expectedAnswer2, Optimize.execute(unionQuery2))
}
}