[SPARK-17653][SQL] Remove unnecessary distincts in multiple unions
## What changes were proposed in this pull request?

Currently for `Union [Distinct]`, a `Distinct` operator is necessary on top of the `Union`. Once there are adjacent `Union [Distinct]` operators, there will be multiple `Distinct` nodes in the query plan. E.g., for a query like:

    select 1 a union select 2 b union select 3 c

Before this patch, its physical plan looks like:

    *HashAggregate(keys=[a#13], functions=[])
    +- Exchange hashpartitioning(a#13, 200)
       +- *HashAggregate(keys=[a#13], functions=[])
          +- Union
             :- *HashAggregate(keys=[a#13], functions=[])
             :  +- Exchange hashpartitioning(a#13, 200)
             :     +- *HashAggregate(keys=[a#13], functions=[])
             :        +- Union
             :           :- *Project [1 AS a#13]
             :           :  +- Scan OneRowRelation[]
             :           +- *Project [2 AS b#14]
             :              +- Scan OneRowRelation[]
             +- *Project [3 AS c#15]
                +- Scan OneRowRelation[]

Only the top `Distinct` is necessary. After this patch, the physical plan looks like:

    *HashAggregate(keys=[a#221], functions=[], output=[a#221])
    +- Exchange hashpartitioning(a#221, 5)
       +- *HashAggregate(keys=[a#221], functions=[], output=[a#221])
          +- Union
             :- *Project [1 AS a#221]
             :  +- Scan OneRowRelation[]
             :- *Project [2 AS b#222]
             :  +- Scan OneRowRelation[]
             +- *Project [3 AS c#223]
                +- Scan OneRowRelation[]

## How was this patch tested?

Jenkins tests.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #15238 from viirya/remove-extra-distinct-union.
This commit is contained in:
parent
fe33121a53
commit
566d7f2827
|
@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
|
|||
|
||||
import scala.annotation.tailrec
|
||||
import scala.collection.immutable.HashSet
|
||||
import scala.collection.mutable
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
import org.apache.spark.api.java.function.FilterFunction
|
||||
|
@ -29,7 +30,7 @@ import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
|
|||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.aggregate._
|
||||
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
|
||||
import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, Unions}
|
||||
import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins
|
||||
import org.apache.spark.sql.catalyst.plans._
|
||||
import org.apache.spark.sql.catalyst.plans.logical._
|
||||
import org.apache.spark.sql.catalyst.rules._
|
||||
|
@ -579,8 +580,25 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe
|
|||
/**
 * Combines all adjacent [[Union]] operators into a single [[Union]].
 *
 * When the union is wrapped in a [[Distinct]], any [[Distinct]] sitting directly on an
 * adjacent nested [[Union]] is redundant (the topmost one already de-duplicates the whole
 * result), so those inner [[Distinct]] operators are flattened away as well.
 */
object CombineUnions extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
    // A bare Union: flatten nested Unions but keep any inner Distinct operators,
    // because without a top-level Distinct they still affect the result.
    case u: Union => flattenUnion(u, flattenDistinct = false)
    // A Distinct on top of a Union: inner Distincts are unnecessary; remove them too.
    case Distinct(u: Union) => Distinct(flattenUnion(u, flattenDistinct = true))
  }

  /**
   * Collects the children of nested [[Union]] operators (and, when `flattenDistinct` is
   * true, of nested `Distinct(Union(...))` operators) into one flat [[Union]].
   *
   * Uses an explicit stack rather than recursion to avoid stack overflow on deeply
   * nested unions; children are pushed in reverse so the original left-to-right
   * ordering of the union branches is preserved in the flattened result.
   */
  private def flattenUnion(union: Union, flattenDistinct: Boolean): Union = {
    val stack = mutable.Stack[LogicalPlan](union)
    val flattened = mutable.ArrayBuffer.empty[LogicalPlan]
    while (stack.nonEmpty) {
      stack.pop() match {
        case Distinct(Union(children)) if flattenDistinct =>
          stack.pushAll(children.reverse)
        case Union(children) =>
          stack.pushAll(children.reverse)
        case child =>
          flattened += child
      }
    }
    Union(flattened)
  }
}
|
||||
|
||||
|
|
|
@ -188,33 +188,6 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
/**
 * A pattern that collects all adjacent [[Union]] operators and returns their
 * children as a single flat Seq, in the original left-to-right order.
 */
object Unions {
  def unapply(plan: LogicalPlan): Option[Seq[LogicalPlan]] = plan match {
    case u: Union => Some(collectUnionChildren(mutable.Stack(u), Seq.empty[LogicalPlan]))
    case _ => None
  }

  // Doing a depth-first tree traversal to combine all the union children.
  // Tail-recursive with an explicit stack so deeply nested unions cannot
  // overflow the call stack.
  @tailrec
  private def collectUnionChildren(
      plans: mutable.Stack[LogicalPlan],
      children: Seq[LogicalPlan]): Seq[LogicalPlan] = {
    if (plans.isEmpty) {
      children
    } else {
      plans.pop() match {
        case Union(grandchildren) =>
          // Push in reverse so the leftmost child is popped (and emitted) first.
          // foreach, not reverseMap: we only want the push side effect, not a
          // discarded mapped collection.
          grandchildren.reverse.foreach(plans.push(_))
          collectUnionChildren(plans, children)
        case other => collectUnionChildren(plans, children :+ other)
      }
    }
  }
}
|
||||
|
||||
/**
|
||||
* An extractor used when planning the physical execution of an aggregation. Compared with a logical
|
||||
* aggregation, the following transformations are performed:
|
||||
|
|
|
@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
|
|||
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
|
||||
import org.apache.spark.sql.catalyst.dsl.expressions._
|
||||
import org.apache.spark.sql.catalyst.dsl.plans._
|
||||
import org.apache.spark.sql.catalyst.expressions.Literal
|
||||
import org.apache.spark.sql.catalyst.plans.PlanTest
|
||||
import org.apache.spark.sql.catalyst.plans.logical._
|
||||
import org.apache.spark.sql.catalyst.rules._
|
||||
|
@ -76,4 +77,71 @@ class SetOperationSuite extends PlanTest {
|
|||
testRelation3.select('g) :: Nil).analyze
|
||||
comparePlans(unionOptimized, unionCorrectAnswer)
|
||||
}
|
||||
|
||||
test("Remove unnecessary distincts in multiple unions") {
  val query1 = OneRowRelation.select(Literal(1).as('a))
  val query2 = OneRowRelation.select(Literal(2).as('b))
  val query3 = OneRowRelation.select(Literal(3).as('c))

  // D - U - D - U - query1
  //     |       |
  //   query3  query2
  // The inner Distinct is redundant: the outer one already de-duplicates
  // the combined result, so everything collapses to one Distinct(Union).
  val nestedPlan = Distinct(Union(Distinct(Union(query1, query2)), query3)).analyze
  val optimizedNested = Optimize.execute(nestedPlan)
  val expectedNested =
    Distinct(Union(query1 :: query2 :: query3 :: Nil)).analyze
  comparePlans(expectedNested, optimizedNested)

  //         query1
  //         |
  // D - U - U - query2
  //     |
  //     D - U - query2
  //         |
  //         query3
  // A bare Union and a Distinct(Union(...)) side by side under a top-level
  // Distinct: both flatten into a single Union. query2 appears twice in the
  // flattened children; the remaining top Distinct removes the duplicate rows.
  val mixedPlan = Distinct(Union(Union(query1, query2),
    Distinct(Union(query2, query3)))).analyze
  val optimizedMixed = Optimize.execute(mixedPlan)
  val expectedMixed =
    Distinct(Union(query1 :: query2 :: query2 :: query3 :: Nil)).analyze
  comparePlans(expectedMixed, optimizedMixed)
}
|
||||
|
||||
test("Keep necessary distincts in multiple unions") {
  val query1 = OneRowRelation.select(Literal(1).as('a))
  val query2 = OneRowRelation.select(Literal(2).as('b))
  val query3 = OneRowRelation.select(Literal(3).as('c))
  val query4 = OneRowRelation.select(Literal(4).as('d))

  // U - D - U - query1
  //     |       |
  //   query3  query2
  // No Distinct on top, so the inner Distinct changes the result and must
  // be preserved; only the unions themselves may be flattened.
  val planWithInnerDistinct = Union(Distinct(Union(query1, query2)), query3).analyze
  val optimizedInner = Optimize.execute(planWithInnerDistinct)
  val expectedInner =
    Union(Distinct(Union(query1 :: query2 :: Nil)) :: query3 :: Nil).analyze
  comparePlans(expectedInner, optimizedInner)

  //         query1
  //         |
  // U - D - U - query2
  //     |
  //     D - U - query3
  //         |
  //         query4
  // Two sibling Distinct(Union(...)) branches under a bare Union: each
  // Distinct is still necessary and must survive optimization.
  val planWithTwoDistincts =
    Union(Distinct(Union(query1, query2)), Distinct(Union(query3, query4))).analyze
  val optimizedTwo = Optimize.execute(planWithTwoDistincts)
  val expectedTwo = Union(Distinct(Union(query1 :: query2 :: Nil)),
    Distinct(Union(query3 :: query4 :: Nil))).analyze
  comparePlans(expectedTwo, optimizedTwo)
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue