From 44a71741d510484b787855986cec970ac0cb5da8 Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Wed, 26 Sep 2018 21:34:18 +0800
Subject: [PATCH] [SPARK-25379][SQL] Improve AttributeSet and ColumnPruning
 performance

## What changes were proposed in this pull request?

This PR contains 3 optimizations:

1) It significantly improves the `--` operation on `AttributeSet`. The following code has been run as a benchmark for `--`:

```
test("AttributeSet -- benchmark") {
  val attrSetA = AttributeSet((1 to 100).map { i => AttributeReference(s"c$i", IntegerType)() })
  val attrSetB = AttributeSet(attrSetA.take(80).toSeq)
  val attrSetC = AttributeSet((1 to 100).map { i => AttributeReference(s"c2_$i", IntegerType)() })
  val attrSetD = AttributeSet((attrSetA.take(50) ++ attrSetC.take(50)).toSeq)
  val attrSetE = AttributeSet((attrSetC.take(50) ++ attrSetA.take(50)).toSeq)
  val n_iter = 1000000
  val t0 = System.nanoTime()
  (1 to n_iter) foreach { _ =>
    val r1 = attrSetA -- attrSetB
    val r2 = attrSetA -- attrSetC
    val r3 = attrSetA -- attrSetD
    val r4 = attrSetA -- attrSetE
  }
  val t1 = System.nanoTime()
  val totalTime = t1 - t0
  println(s"Average time: ${totalTime / n_iter} us")
}
```

The results are:

```
Before PR - Average time: 67674 us (100 %)
After PR  - Average time: 28827 us (42.6 %)
```

2) In `ColumnPruning`, it replaces the occurrences of `(attributeSet1 -- attributeSet2).nonEmpty` with the equivalent `!attributeSet1.subsetOf(attributeSet2)`, which is orders of magnitude more efficient, especially when many attributes are involved (a standalone sketch of this equivalence follows this list). Running the previous benchmark with `--` replaced by `subsetOf` returns:

```
Average time: 67 us (0.1 %)
```

3) It provides a more efficient way of building `AttributeSet`s, which can greatly improve the performance of the `references` and `outputSet` methods of `Expression` and `QueryPlan`. This avoids unneeded operations, e.g. creating many intermediate `AttributeEquals` wrapper objects (a standalone sketch of this idea appears below, after the benchmark results).
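For intuition, the equivalence exploited in (2) can be sketched with plain Scala `Set`s rather than Spark's `AttributeSet`; the object and value names below are illustrative only, not part of this patch:

```
// `(a -- b).nonEmpty` and `!a.subsetOf(b)` are logically equivalent, but the
// former materializes a whole intermediate Set before checking nonEmpty, while
// subsetOf (a forall under the hood) stops at the first element of `a` that is
// missing from `b` and allocates nothing.
object SubsetOfSketch extends App {
  val a: Set[Int] = (1 to 100).toSet
  val b: Set[Int] = (1 to 80).toSet

  val viaDifference = (a -- b).nonEmpty // allocates a new 20-element Set
  val viaSubsetOf   = !a.subsetOf(b)    // short-circuits, no allocation

  assert(viaDifference == viaSubsetOf)
  println(s"viaDifference = $viaDifference, viaSubsetOf = $viaSubsetOf") // both true
}
```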
The overall effect of those optimizations has been tested on `ColumnPruning` with the following benchmark:

```
test("ColumnPruning benchmark") {
  val attrSetA = (1 to 100).map { i => AttributeReference(s"c$i", IntegerType)() }
  val attrSetB = attrSetA.take(80)
  val attrSetC = attrSetA.take(20).map(a => Alias(Add(a, Literal(1)), s"${a.name}_1")())
  val input = LocalRelation(attrSetA)
  val query1 = Project(attrSetB, Project(attrSetA, input)).analyze
  val query2 = Project(attrSetC, Project(attrSetA, input)).analyze
  val query3 = Project(attrSetA, Project(attrSetA, input)).analyze
  val nIter = 100000
  val t0 = System.nanoTime()
  (1 to nIter).foreach { _ =>
    ColumnPruning(query1)
    ColumnPruning(query2)
    ColumnPruning(query3)
  }
  val t1 = System.nanoTime()
  val totalTime = t1 - t0
  println(s"Average time: ${totalTime / nIter} us")
}
```

The output of the test is:

```
Before PR - Average time: 733471 us (100 %)
After PR  - Average time: 362455 us (49.4 %)
```

The performance improvement has also been evaluated on the `SQLQueryTestSuite` queries:

```
(before) org.apache.spark.sql.catalyst.optimizer.ColumnPruning   518413198 / 1377707172   2756 / 15717
(after)  org.apache.spark.sql.catalyst.optimizer.ColumnPruning   415432579 / 1121147950   2756 / 15717
% Running time                                                   80.1% / 81.3%
```

Other rules also benefit, especially from (3), although the impact there is lower, e.g.:

```
(before) org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveReferences   307341442 / 623436806   2154 / 16480
(after)  org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveReferences   290511312 / 560962495   2154 / 16480
% Running time                                                               94.5% / 90.0%
```

The impact on the `SQLQueryTestSuite` queries is lower than in the benchmark above because the optimizations become more significant as the number of attributes involved grows; since the test queries often involve very few attributes, the effect there is smaller.
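To give standalone intuition for (3), here is a sketch of the construction strategy outside Catalyst. `RefSet` and `fromSets` are illustrative stand-ins (with `String` elements instead of `Attribute`s), not Spark's actual API; the real change is in `AttributeSet.fromAttributeSets` in the diff below:

```
import scala.collection.mutable

// Illustrative stand-in for AttributeSet: instead of flatMapping every child's
// elements into one big Seq and re-wrapping each of them, fold the children's
// already-built sets into a single mutable LinkedHashSet (which deduplicates as
// it goes and preserves insertion order), then freeze it once at the end.
class RefSet private (val baseSet: Set[String]) {
  override def toString: String = baseSet.mkString("RefSet(", ", ", ")")
}

object RefSet {
  def apply(elems: String*): RefSet = new RefSet(elems.toSet)

  def fromSets(sets: Iterable[RefSet]): RefSet = {
    val base = sets.foldLeft(new mutable.LinkedHashSet[String]())(_ ++= _.baseSet)
    new RefSet(base.toSet) // a single immutable copy, built once
  }
}

object RefSetDemo extends App {
  // Merging the reference sets of three "children" in one pass.
  println(RefSet.fromSets(Seq(RefSet("a", "b"), RefSet("b", "c"), RefSet("d"))))
  // prints: RefSet(a, b, c, d)
}
```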
## How was this patch tested?

run benchmarks + existing UTs

Closes #22364 from mgaido91/SPARK-25379.

Authored-by: Marco Gaido
Signed-off-by: Wenchen Fan
---
 .../catalyst/expressions/AttributeSet.scala   | 23 +++++++++++++-----
 .../sql/catalyst/expressions/Expression.scala |  2 +-
 .../sql/catalyst/optimizer/Optimizer.scala    | 24 +++++++++----------
 .../spark/sql/catalyst/plans/QueryPlan.scala  |  2 +-
 4 files changed, 31 insertions(+), 20 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
index 7420b6b57d..a7e09eee61 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import scala.collection.mutable
+
 protected class AttributeEquals(val a: Attribute) {
   override def hashCode(): Int = a match {
@@ -39,10 +41,13 @@ object AttributeSet {
 
   /** Constructs a new [[AttributeSet]] given a sequence of [[Expression Expressions]]. */
   def apply(baseSet: Iterable[Expression]): AttributeSet = {
-    new AttributeSet(
-      baseSet
-        .flatMap(_.references)
-        .map(new AttributeEquals(_)).toSet)
+    fromAttributeSets(baseSet.map(_.references))
+  }
+
+  /** Constructs a new [[AttributeSet]] given a sequence of [[AttributeSet]]s. */
+  def fromAttributeSets(sets: Iterable[AttributeSet]): AttributeSet = {
+    val baseSet = sets.foldLeft(new mutable.LinkedHashSet[AttributeEquals]())( _ ++= _.baseSet)
+    new AttributeSet(baseSet.toSet)
   }
 }
@@ -94,8 +99,14 @@ class AttributeSet private (val baseSet: Set[AttributeEquals])
    * Returns a new [[AttributeSet]] that does not contain any of the [[Attribute Attributes]] found
    * in `other`.
    */
-  def --(other: Traversable[NamedExpression]): AttributeSet =
-    new AttributeSet(baseSet -- other.map(a => new AttributeEquals(a.toAttribute)))
+  def --(other: Traversable[NamedExpression]): AttributeSet = {
+    other match {
+      case otherSet: AttributeSet =>
+        new AttributeSet(baseSet -- otherSet.baseSet)
+      case _ =>
+        new AttributeSet(baseSet -- other.map(a => new AttributeEquals(a.toAttribute)))
+    }
+  }
 
   /**
    * Returns a new [[AttributeSet]] that contains all of the [[Attribute Attributes]] found
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 773aefc0ac..c215735ab1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -85,7 +85,7 @@ abstract class Expression extends TreeNode[Expression] {
 
   def nullable: Boolean
 
-  def references: AttributeSet = AttributeSet(children.flatMap(_.references.iterator))
+  def references: AttributeSet = AttributeSet.fromAttributeSets(children.map(_.references))
 
   /** Returns the result of evaluating this expression on a given input Row */
   def eval(input: InternalRow = null): Any
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 7c461895c5..07a653f3b5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -532,12 +532,12 @@ object ColumnPruning extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = removeProjectBeforeFilter(plan transform {
     // Prunes the unused columns from project list of Project/Aggregate/Expand
-    case p @ Project(_, p2: Project) if (p2.outputSet -- p.references).nonEmpty =>
+    case p @ Project(_, p2: Project) if !p2.outputSet.subsetOf(p.references) =>
       p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains)))
-    case p @ Project(_, a: Aggregate) if (a.outputSet -- p.references).nonEmpty =>
+    case p @ Project(_, a: Aggregate) if !a.outputSet.subsetOf(p.references) =>
       p.copy(
         child = a.copy(aggregateExpressions = a.aggregateExpressions.filter(p.references.contains)))
-    case a @ Project(_, e @ Expand(_, _, grandChild)) if (e.outputSet -- a.references).nonEmpty =>
+    case a @ Project(_, e @ Expand(_, _, grandChild)) if !e.outputSet.subsetOf(a.references) =>
       val newOutput = e.output.filter(a.references.contains(_))
       val newProjects = e.projections.map { proj =>
         proj.zip(e.output).filter { case (_, a) =>
@@ -547,18 +547,18 @@ object ColumnPruning extends Rule[LogicalPlan] {
       a.copy(child = Expand(newProjects, newOutput, grandChild))
 
     // Prunes the unused columns from child of `DeserializeToObject`
-    case d @ DeserializeToObject(_, _, child) if (child.outputSet -- d.references).nonEmpty =>
+    case d @ DeserializeToObject(_, _, child) if !child.outputSet.subsetOf(d.references) =>
       d.copy(child = prunedChild(child, d.references))
 
     // Prunes the unused columns from child of Aggregate/Expand/Generate/ScriptTransformation
-    case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
+    case a @ Aggregate(_, _, child) if !child.outputSet.subsetOf(a.references) =>
       a.copy(child = prunedChild(child, a.references))
-    case f @ FlatMapGroupsInPandas(_, _, _, child) if (child.outputSet -- f.references).nonEmpty =>
+    case f @ FlatMapGroupsInPandas(_, _, _, child) if !child.outputSet.subsetOf(f.references) =>
      f.copy(child = prunedChild(child, f.references))
-    case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty =>
+    case e @ Expand(_, _, child) if !child.outputSet.subsetOf(e.references) =>
       e.copy(child = prunedChild(child, e.references))
     case s @ ScriptTransformation(_, _, _, child, _)
-        if (child.outputSet -- s.references).nonEmpty =>
+        if !child.outputSet.subsetOf(s.references) =>
       s.copy(child = prunedChild(child, s.references))
 
     // prune unrequired references
@@ -579,7 +579,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
     case p @ Project(_, _: Distinct) => p
     // Eliminate unneeded attributes from children of Union.
     case p @ Project(_, u: Union) =>
-      if ((u.outputSet -- p.references).nonEmpty) {
+      if (!u.outputSet.subsetOf(p.references)) {
        val firstChild = u.children.head
         val newOutput = prunedChild(firstChild, p.references).output
         // pruning the columns of all children based on the pruned first child.
@@ -595,7 +595,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
     }
 
     // Prune unnecessary window expressions
-    case p @ Project(_, w: Window) if (w.windowOutputSet -- p.references).nonEmpty =>
+    case p @ Project(_, w: Window) if !w.windowOutputSet.subsetOf(p.references) =>
       p.copy(child = w.copy(
         windowExpressions = w.windowExpressions.filter(p.references.contains)))
@@ -611,7 +611,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
     // for all other logical plans that inherits the output from it's children
     case p @ Project(_, child) =>
       val required = child.references ++ p.references
-      if ((child.inputSet -- required).nonEmpty) {
+      if (!child.inputSet.subsetOf(required)) {
         val newChildren = child.children.map(c => prunedChild(c, required))
         p.copy(child = child.withNewChildren(newChildren))
       } else {
@@ -621,7 +621,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
 
   /** Applies a projection only when the child is producing unnecessary attributes */
   private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) =
-    if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) {
+    if (!c.outputSet.subsetOf(allReferences)) {
       Project(c.output.filter(allReferences.contains), c)
     } else {
       c
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index b1ffdca091..ca0cea6ba7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -42,7 +42,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
    * All Attributes that appear in expressions from this operator. Note that this set does not
    * include attributes that are implicitly referenced by being passed through to the output tuple.
    */
-  def references: AttributeSet = AttributeSet(expressions.flatMap(_.references))
+  def references: AttributeSet = AttributeSet.fromAttributeSets(expressions.map(_.references))
 
   /**
    * The set of all attributes that are input to this operator by its children.