[SPARK-25379][SQL] Improve AttributeSet and ColumnPruning performance
## What changes were proposed in this pull request? This PR contains 3 optimizations: 1) it improves significantly the operation `--` on `AttributeSet`. As a benchmark for the `--` operation, the following code has been run ``` test("AttributeSet -- benchmark") { val attrSetA = AttributeSet((1 to 100).map { i => AttributeReference(s"c$i", IntegerType)() }) val attrSetB = AttributeSet(attrSetA.take(80).toSeq) val attrSetC = AttributeSet((1 to 100).map { i => AttributeReference(s"c2_$i", IntegerType)() }) val attrSetD = AttributeSet((attrSetA.take(50) ++ attrSetC.take(50)).toSeq) val attrSetE = AttributeSet((attrSetC.take(50) ++ attrSetA.take(50)).toSeq) val n_iter = 1000000 val t0 = System.nanoTime() (1 to n_iter) foreach { _ => val r1 = attrSetA -- attrSetB val r2 = attrSetA -- attrSetC val r3 = attrSetA -- attrSetD val r4 = attrSetA -- attrSetE } val t1 = System.nanoTime() val totalTime = t1 - t0 println(s"Average time: ${totalTime / n_iter} us") } ``` The results are: ``` Before PR - Average time: 67674 us (100 %) After PR - Average time: 28827 us (42.6 %) ``` 2) In `ColumnPruning`, it replaces the occurrences of `(attributeSet1 -- attributeSet2).nonEmpty` with `attributeSet1.subsetOf(attributeSet2)` which is order of magnitudes more efficient (especially where there are many attributes). Running the previous benchmark replacing `--` with `subsetOf` returns: ``` Average time: 67 us (0.1 %) ``` 3) Provides a more efficient way of building `AttributeSet`s, which can greatly improve the performance of the methods `references` and `outputSet` of `Expression` and `QueryPlan`. This basically avoids unneeded operations (eg. creating many `AttributeEqual` wrapper classes which could be avoided) The overall effect of those optimizations has been tested on `ColumnPruning` with the following benchmark: ``` test("ColumnPruning benchmark") { val attrSetA = (1 to 100).map { i => AttributeReference(s"c$i", IntegerType)() } val attrSetB = attrSetA.take(80) val attrSetC = attrSetA.take(20).map(a => Alias(Add(a, Literal(1)), s"${a.name}_1")()) val input = LocalRelation(attrSetA) val query1 = Project(attrSetB, Project(attrSetA, input)).analyze val query2 = Project(attrSetC, Project(attrSetA, input)).analyze val query3 = Project(attrSetA, Project(attrSetA, input)).analyze val nIter = 100000 val t0 = System.nanoTime() (1 to nIter).foreach { _ => ColumnPruning(query1) ColumnPruning(query2) ColumnPruning(query3) } val t1 = System.nanoTime() val totalTime = t1 - t0 println(s"Average time: ${totalTime / nIter} us") } ``` The output of the test is: ``` Before PR - Average time: 733471 us (100 %) After PR - Average time: 362455 us (49.4 %) ``` The performance improvement has been evaluated also on the `SQLQueryTestSuite`'s queries: ``` (before) org.apache.spark.sql.catalyst.optimizer.ColumnPruning 518413198 / 1377707172 2756 / 15717 (after) org.apache.spark.sql.catalyst.optimizer.ColumnPruning 415432579 / 1121147950 2756 / 15717 % Running time 80.1% / 81.3% ``` Also other rules benefit especially from (3), despite the impact is lower, eg: ``` (before) org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveReferences 307341442 / 623436806 2154 / 16480 (after) org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveReferences 290511312 / 560962495 2154 / 16480 % Running time 94.5% / 90.0% ``` The reason why the impact on the `SQLQueryTestSuite`'s queries is lower compared to the other benchmark is that the optimizations are more significant when the number of attributes involved is higher. Since in the tests we often have very few attributes, the effect there is lower. ## How was this patch tested? run benchmarks + existing UTs Closes #22364 from mgaido91/SPARK-25379. Authored-by: Marco Gaido <marcogaido91@gmail.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
b39e228ce8
commit
44a71741d5
|
@ -17,6 +17,8 @@
|
|||
|
||||
package org.apache.spark.sql.catalyst.expressions
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
|
||||
protected class AttributeEquals(val a: Attribute) {
|
||||
override def hashCode(): Int = a match {
|
||||
|
@ -39,10 +41,13 @@ object AttributeSet {
|
|||
|
||||
/** Constructs a new [[AttributeSet]] given a sequence of [[Expression Expressions]]. */
|
||||
def apply(baseSet: Iterable[Expression]): AttributeSet = {
|
||||
new AttributeSet(
|
||||
baseSet
|
||||
.flatMap(_.references)
|
||||
.map(new AttributeEquals(_)).toSet)
|
||||
fromAttributeSets(baseSet.map(_.references))
|
||||
}
|
||||
|
||||
/** Constructs a new [[AttributeSet]] given a sequence of [[AttributeSet]]s. */
|
||||
def fromAttributeSets(sets: Iterable[AttributeSet]): AttributeSet = {
|
||||
val baseSet = sets.foldLeft(new mutable.LinkedHashSet[AttributeEquals]())( _ ++= _.baseSet)
|
||||
new AttributeSet(baseSet.toSet)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -94,8 +99,14 @@ class AttributeSet private (val baseSet: Set[AttributeEquals])
|
|||
* Returns a new [[AttributeSet]] that does not contain any of the [[Attribute Attributes]] found
|
||||
* in `other`.
|
||||
*/
|
||||
def --(other: Traversable[NamedExpression]): AttributeSet =
|
||||
def --(other: Traversable[NamedExpression]): AttributeSet = {
|
||||
other match {
|
||||
case otherSet: AttributeSet =>
|
||||
new AttributeSet(baseSet -- otherSet.baseSet)
|
||||
case _ =>
|
||||
new AttributeSet(baseSet -- other.map(a => new AttributeEquals(a.toAttribute)))
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a new [[AttributeSet]] that contains all of the [[Attribute Attributes]] found
|
||||
|
|
|
@ -85,7 +85,7 @@ abstract class Expression extends TreeNode[Expression] {
|
|||
|
||||
def nullable: Boolean
|
||||
|
||||
def references: AttributeSet = AttributeSet(children.flatMap(_.references.iterator))
|
||||
def references: AttributeSet = AttributeSet.fromAttributeSets(children.map(_.references))
|
||||
|
||||
/** Returns the result of evaluating this expression on a given input Row */
|
||||
def eval(input: InternalRow = null): Any
|
||||
|
|
|
@ -532,12 +532,12 @@ object ColumnPruning extends Rule[LogicalPlan] {
|
|||
|
||||
def apply(plan: LogicalPlan): LogicalPlan = removeProjectBeforeFilter(plan transform {
|
||||
// Prunes the unused columns from project list of Project/Aggregate/Expand
|
||||
case p @ Project(_, p2: Project) if (p2.outputSet -- p.references).nonEmpty =>
|
||||
case p @ Project(_, p2: Project) if !p2.outputSet.subsetOf(p.references) =>
|
||||
p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains)))
|
||||
case p @ Project(_, a: Aggregate) if (a.outputSet -- p.references).nonEmpty =>
|
||||
case p @ Project(_, a: Aggregate) if !a.outputSet.subsetOf(p.references) =>
|
||||
p.copy(
|
||||
child = a.copy(aggregateExpressions = a.aggregateExpressions.filter(p.references.contains)))
|
||||
case a @ Project(_, e @ Expand(_, _, grandChild)) if (e.outputSet -- a.references).nonEmpty =>
|
||||
case a @ Project(_, e @ Expand(_, _, grandChild)) if !e.outputSet.subsetOf(a.references) =>
|
||||
val newOutput = e.output.filter(a.references.contains(_))
|
||||
val newProjects = e.projections.map { proj =>
|
||||
proj.zip(e.output).filter { case (_, a) =>
|
||||
|
@ -547,18 +547,18 @@ object ColumnPruning extends Rule[LogicalPlan] {
|
|||
a.copy(child = Expand(newProjects, newOutput, grandChild))
|
||||
|
||||
// Prunes the unused columns from child of `DeserializeToObject`
|
||||
case d @ DeserializeToObject(_, _, child) if (child.outputSet -- d.references).nonEmpty =>
|
||||
case d @ DeserializeToObject(_, _, child) if !child.outputSet.subsetOf(d.references) =>
|
||||
d.copy(child = prunedChild(child, d.references))
|
||||
|
||||
// Prunes the unused columns from child of Aggregate/Expand/Generate/ScriptTransformation
|
||||
case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
|
||||
case a @ Aggregate(_, _, child) if !child.outputSet.subsetOf(a.references) =>
|
||||
a.copy(child = prunedChild(child, a.references))
|
||||
case f @ FlatMapGroupsInPandas(_, _, _, child) if (child.outputSet -- f.references).nonEmpty =>
|
||||
case f @ FlatMapGroupsInPandas(_, _, _, child) if !child.outputSet.subsetOf(f.references) =>
|
||||
f.copy(child = prunedChild(child, f.references))
|
||||
case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty =>
|
||||
case e @ Expand(_, _, child) if !child.outputSet.subsetOf(e.references) =>
|
||||
e.copy(child = prunedChild(child, e.references))
|
||||
case s @ ScriptTransformation(_, _, _, child, _)
|
||||
if (child.outputSet -- s.references).nonEmpty =>
|
||||
if !child.outputSet.subsetOf(s.references) =>
|
||||
s.copy(child = prunedChild(child, s.references))
|
||||
|
||||
// prune unrequired references
|
||||
|
@ -579,7 +579,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
|
|||
case p @ Project(_, _: Distinct) => p
|
||||
// Eliminate unneeded attributes from children of Union.
|
||||
case p @ Project(_, u: Union) =>
|
||||
if ((u.outputSet -- p.references).nonEmpty) {
|
||||
if (!u.outputSet.subsetOf(p.references)) {
|
||||
val firstChild = u.children.head
|
||||
val newOutput = prunedChild(firstChild, p.references).output
|
||||
// pruning the columns of all children based on the pruned first child.
|
||||
|
@ -595,7 +595,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
|
|||
}
|
||||
|
||||
// Prune unnecessary window expressions
|
||||
case p @ Project(_, w: Window) if (w.windowOutputSet -- p.references).nonEmpty =>
|
||||
case p @ Project(_, w: Window) if !w.windowOutputSet.subsetOf(p.references) =>
|
||||
p.copy(child = w.copy(
|
||||
windowExpressions = w.windowExpressions.filter(p.references.contains)))
|
||||
|
||||
|
@ -611,7 +611,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
|
|||
// for all other logical plans that inherits the output from it's children
|
||||
case p @ Project(_, child) =>
|
||||
val required = child.references ++ p.references
|
||||
if ((child.inputSet -- required).nonEmpty) {
|
||||
if (!child.inputSet.subsetOf(required)) {
|
||||
val newChildren = child.children.map(c => prunedChild(c, required))
|
||||
p.copy(child = child.withNewChildren(newChildren))
|
||||
} else {
|
||||
|
@ -621,7 +621,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
|
|||
|
||||
/** Applies a projection only when the child is producing unnecessary attributes */
|
||||
private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) =
|
||||
if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) {
|
||||
if (!c.outputSet.subsetOf(allReferences)) {
|
||||
Project(c.output.filter(allReferences.contains), c)
|
||||
} else {
|
||||
c
|
||||
|
|
|
@ -42,7 +42,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
|
|||
* All Attributes that appear in expressions from this operator. Note that this set does not
|
||||
* include attributes that are implicitly referenced by being passed through to the output tuple.
|
||||
*/
|
||||
def references: AttributeSet = AttributeSet(expressions.flatMap(_.references))
|
||||
def references: AttributeSet = AttributeSet.fromAttributeSets(expressions.map(_.references))
|
||||
|
||||
/**
|
||||
* The set of all attributes that are input to this operator by its children.
|
||||
|
|
Loading…
Reference in a new issue