diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 2b804976f3..1f05f2065c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -313,85 +313,97 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
  */
 object ColumnPruning extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    // Prunes the unused columns from project list of Project/Aggregate/Window/Expand
-    case p @ Project(_, p2: Project) if (p2.outputSet -- p.references).nonEmpty =>
-      p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains)))
-    case p @ Project(_, a: Aggregate) if (a.outputSet -- p.references).nonEmpty =>
-      p.copy(
-        child = a.copy(aggregateExpressions = a.aggregateExpressions.filter(p.references.contains)))
-    case p @ Project(_, w: Window) if (w.outputSet -- p.references).nonEmpty =>
-      p.copy(child = w.copy(
-        projectList = w.projectList.filter(p.references.contains),
-        windowExpressions = w.windowExpressions.filter(p.references.contains)))
-    case a @ Project(_, e @ Expand(_, _, grandChild)) if (e.outputSet -- a.references).nonEmpty =>
-      val newOutput = e.output.filter(a.references.contains(_))
-      val newProjects = e.projections.map { proj =>
-        proj.zip(e.output).filter { case (e, a) =>
+    case a @ Aggregate(_, _, e @ Expand(projects, output, child))
+      if (e.outputSet -- a.references).nonEmpty =>
+      val newOutput = output.filter(a.references.contains(_))
+      val newProjects = projects.map { proj =>
+        proj.zip(output).filter { case (e, a) =>
           newOutput.contains(a)
         }.unzip._1
       }
-      a.copy(child = Expand(newProjects, newOutput, grandChild))
-      // TODO: support some logical plan for Dataset
+      a.copy(child = Expand(newProjects, newOutput, child))
 
-    // Prunes the unused columns from child of Aggregate/Window/Expand/Generate
+    case a @ Aggregate(_, _, e @ Expand(_, _, child))
+      if (child.outputSet -- e.references -- a.references).nonEmpty =>
+      a.copy(child = e.copy(child = prunedChild(child, e.references ++ a.references)))
+
+    // Eliminate attributes that are not needed to calculate the specified aggregates.
     case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
-      a.copy(child = prunedChild(child, a.references))
-    case w @ Window(_, _, _, _, child) if (child.outputSet -- w.references).nonEmpty =>
-      w.copy(child = prunedChild(child, w.references))
-    case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty =>
-      e.copy(child = prunedChild(child, e.references))
-    case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty =>
-      g.copy(child = prunedChild(g.child, g.references))
+      a.copy(child = Project(a.references.toSeq, child))
+
+    // Eliminate attributes that are not needed to calculate the Generate.
+    case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty =>
+      g.copy(child = Project(g.references.toSeq, g.child))
 
-    // Turn off `join` for Generate if no column from it's child is used
     case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) =>
       p.copy(child = g.copy(join = false))
 
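Note on the hunk above: the Aggregate and Generate cases now insert an explicit Project built from the operator's reference set, replacing the blanket per-operator prunedChild cases that are removed here. A minimal sketch of the Aggregate rewrite using the Catalyst test DSL (the commented plan shape is the expected result, not captured output):

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.optimizer.ColumnPruning
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

val relation = LocalRelation('a.int, 'b.int, 'c.int)
// SELECT a, count(b) FROM relation GROUP BY a; column c is never referenced.
val query = relation.groupBy('a)('a, count('b)).analyze
// The Aggregate case should insert a Project keeping only a and b under the
// aggregate, so c is dropped before aggregation.
val pruned = ColumnPruning(query)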
-    // Eliminate unneeded attributes from right side of a LeftSemiJoin.
-    case j @ Join(left, right, LeftSemi, condition) =>
-      j.copy(right = prunedChild(right, j.references))
-
-    // all the columns will be used to compare, so we can't prune them
-    case p @ Project(_, _: SetOperation) => p
-    case p @ Project(_, _: Distinct) => p
-    // Eliminate unneeded attributes from children of Union.
-    case p @ Project(_, u: Union) =>
-      if ((u.outputSet -- p.references).nonEmpty) {
-        val firstChild = u.children.head
-        val newOutput = prunedChild(firstChild, p.references).output
-        // pruning the columns of all children based on the pruned first child.
-        val newChildren = u.children.map { p =>
-          val selected = p.output.zipWithIndex.filter { case (a, i) =>
-            newOutput.contains(firstChild.output(i))
-          }.map(_._1)
-          Project(selected, p)
-        }
-        p.copy(child = u.withNewChildren(newChildren))
-      } else {
+    case p @ Project(projectList, g: Generate) if g.join =>
+      val neededChildOutput = p.references -- g.generatorOutput ++ g.references
+      if (neededChildOutput == g.child.outputSet) {
         p
+      } else {
+        Project(projectList, g.copy(child = Project(neededChildOutput.toSeq, g.child)))
       }
 
-    // Can't prune the columns on LeafNode
-    case p @ Project(_, l: LeafNode) => p
+    case p @ Project(projectList, a @ Aggregate(groupingExpressions, aggregateExpressions, child))
+      if (a.outputSet -- p.references).nonEmpty =>
+      Project(
+        projectList,
+        Aggregate(
+          groupingExpressions,
+          aggregateExpressions.filter(e => p.references.contains(e)),
+          child))
+
+    // Eliminate unneeded attributes from either side of a Join.
+    case Project(projectList, Join(left, right, joinType, condition)) =>
+      // Collect the list of all references required either above or to evaluate the condition.
+      val allReferences: AttributeSet =
+        AttributeSet(
+          projectList.flatMap(_.references.iterator)) ++
+          condition.map(_.references).getOrElse(AttributeSet(Seq.empty))
+
+      /** Applies a projection only when the child is producing unnecessary attributes */
+      def pruneJoinChild(c: LogicalPlan): LogicalPlan = prunedChild(c, allReferences)
+
+      Project(projectList, Join(pruneJoinChild(left), pruneJoinChild(right), joinType, condition))
+
+    // Eliminate unneeded attributes from right side of a LeftSemiJoin.
+    case Join(left, right, LeftSemi, condition) =>
+      // Collect the list of all references required to evaluate the condition.
+      val allReferences: AttributeSet =
+        condition.map(_.references).getOrElse(AttributeSet(Seq.empty))
+
+      Join(left, prunedChild(right, allReferences), LeftSemi, condition)
+
+    // Push down project through limit, so that we may have chance to push it further.
+    case Project(projectList, Limit(exp, child)) =>
+      Limit(exp, Project(projectList, child))
+
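The Limit case above only swaps the two operators: the Project is kept, but moving it below the Limit gives later rules a chance to push it all the way down, and CollapseProject (which runs in the same batch in the FilterPushdownSuite changes later in this patch) merges the stacked Projects. A sketch of the intended interplay, under the same test-DSL assumptions as above:

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.optimizer.ColumnPruning
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

val relation = LocalRelation('a.int, 'b.int, 'c.int)
// Plan shape: Project('a, Limit(2, Project(Seq('a, 'b), relation)))
val query = relation.select('a, 'b).limit(2).select('a).analyze
// One ColumnPruning pass should yield
// Limit(2, Project('a, Project(Seq('a, 'b), relation)));
// CollapseProject then reduces that to Limit(2, Project('a :: Nil, relation)).
val pushed = ColumnPruning(query)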
+    // Push down project if possible when the child is sort.
+    case p @ Project(projectList, s @ Sort(_, _, grandChild)) =>
+      if (s.references.subsetOf(p.outputSet)) {
+        s.copy(child = Project(projectList, grandChild))
+      } else {
+        val neededReferences = s.references ++ p.references
+        if (neededReferences == grandChild.outputSet) {
+          // No column we can prune, return the original plan.
+          p
+        } else {
+          // Do not use neededReferences.toSeq directly, should respect grandChild's output order.
+          val newProjectList = grandChild.output.filter(neededReferences.contains)
+          p.copy(child = s.copy(child = Project(newProjectList, grandChild)))
+        }
+      }
 
     // Eliminate no-op Projects
-    case p @ Project(projectList, child) if child.output == p.output => child
-
-    // for all other logical plans that inherits the output from it's children
-    case p @ Project(_, child) =>
-      val required = child.references ++ p.references
-      if ((child.inputSet -- required).nonEmpty) {
-        val newChildren = child.children.map(c => prunedChild(c, required))
-        p.copy(child = child.withNewChildren(newChildren))
-      } else {
-        p
-      }
+    case Project(projectList, child) if child.output == projectList => child
   }
 
   /** Applies a projection only when the child is producing unnecessary attributes */
   private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) =
     if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) {
-      Project(c.output.filter(allReferences.contains), c)
+      Project(allReferences.filter(c.outputSet.contains).toSeq, c)
     } else {
       c
     }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index 715d01a3cd..c890fffc40 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -17,10 +17,9 @@
 package org.apache.spark.sql.catalyst.optimizer
 
-import org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Ascending, Explode, Literal, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Explode, Literal}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -120,134 +119,11 @@ class ColumnPruningSuite extends PlanTest {
           Seq('c, Literal.create(null, StringType), 1),
           Seq('c, 'a, 2)),
         Seq('c, 'aa.int, 'gid.int),
-        Project(Seq('a, 'c),
+        Project(Seq('c, 'a),
           input))).analyze
 
     comparePlans(optimized, expected)
   }
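The one-line change to prunedChild above also flips the order of the projection it builds: the old version kept the child's output order, while the new version follows the iteration order of the AttributeSet, which is why the expected plan in the Expand test changes from Project(Seq('a, 'c), input) to Project(Seq('c, 'a), input). A side-by-side sketch of just that difference (standalone helpers for illustration, with the emptiness guard omitted; not part of the patch):

import org.apache.spark.sql.catalyst.expressions.AttributeSet
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}

// Old behavior: keep the child's output order, e.g. Seq(a, c) for output [a, b, c].
def pruneInChildOrder(c: LogicalPlan, refs: AttributeSet): LogicalPlan =
  Project(c.output.filter(refs.contains), c)

// New behavior: follow the reference set's iteration order, which may differ
// from the child's output order, e.g. Seq(c, a) for the same child.
def pruneInReferenceOrder(c: LogicalPlan, refs: AttributeSet): LogicalPlan =
  Project(refs.filter(c.outputSet.contains).toSeq, c)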
 
-  test("Column pruning on Filter") {
-    val input = LocalRelation('a.int, 'b.string, 'c.double)
-    val query = Project('a :: Nil, Filter('c > Literal(0.0), input)).analyze
-    val expected =
-      Project('a :: Nil,
-        Filter('c > Literal(0.0),
-          Project(Seq('a, 'c), input))).analyze
-    comparePlans(Optimize.execute(query), expected)
-  }
-
-  test("Column pruning on except/intersect/distinct") {
-    val input = LocalRelation('a.int, 'b.string, 'c.double)
-    val query = Project('a :: Nil, Except(input, input)).analyze
-    comparePlans(Optimize.execute(query), query)
-
-    val query2 = Project('a :: Nil, Intersect(input, input)).analyze
-    comparePlans(Optimize.execute(query2), query2)
-    val query3 = Project('a :: Nil, Distinct(input)).analyze
-    comparePlans(Optimize.execute(query3), query3)
-  }
-
-  test("Column pruning on Project") {
-    val input = LocalRelation('a.int, 'b.string, 'c.double)
-    val query = Project('a :: Nil, Project(Seq('a, 'b), input)).analyze
-    val expected = Project(Seq('a), input).analyze
-    comparePlans(Optimize.execute(query), expected)
-  }
-
-  test("column pruning for group") {
-    val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
-    val originalQuery =
-      testRelation
-        .groupBy('a)('a, count('b))
-        .select('a)
-
-    val optimized = Optimize.execute(originalQuery.analyze)
-    val correctAnswer =
-      testRelation
-        .select('a)
-        .groupBy('a)('a).analyze
-
-    comparePlans(optimized, correctAnswer)
-  }
-
-  test("column pruning for group with alias") {
-    val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
-
-    val originalQuery =
-      testRelation
-        .groupBy('a)('a as 'c, count('b))
-        .select('c)
-
-    val optimized = Optimize.execute(originalQuery.analyze)
-    val correctAnswer =
-      testRelation
-        .select('a)
-        .groupBy('a)('a as 'c).analyze
-
-    comparePlans(optimized, correctAnswer)
-  }
-
-  test("column pruning for Project(ne, Limit)") {
-    val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
-
-    val originalQuery =
-      testRelation
-        .select('a, 'b)
-        .limit(2)
-        .select('a)
-
-    val optimized = Optimize.execute(originalQuery.analyze)
-    val correctAnswer =
-      testRelation
-        .select('a)
-        .limit(2).analyze
-
-    comparePlans(optimized, correctAnswer)
-  }
-
-  test("push down project past sort") {
-    val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
-    val x = testRelation.subquery('x)
-
-    // push down valid
-    val originalQuery = {
-      x.select('a, 'b)
-        .sortBy(SortOrder('a, Ascending))
-        .select('a)
-    }
-
-    val optimized = Optimize.execute(originalQuery.analyze)
-    val correctAnswer =
-      x.select('a)
-        .sortBy(SortOrder('a, Ascending)).analyze
-
-    comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer))
-
-    // push down invalid
-    val originalQuery1 = {
-      x.select('a, 'b)
-        .sortBy(SortOrder('a, Ascending))
-        .select('b)
-    }
-
-    val optimized1 = Optimize.execute(originalQuery1.analyze)
-    val correctAnswer1 =
-      x.select('a, 'b)
-        .sortBy(SortOrder('a, Ascending))
-        .select('b).analyze
-
-    comparePlans(optimized1, analysis.EliminateSubqueryAliases(correctAnswer1))
-  }
-
-  test("Column pruning on Union") {
-    val input1 = LocalRelation('a.int, 'b.string, 'c.double)
-    val input2 = LocalRelation('c.int, 'd.string, 'e.double)
-    val query = Project('b :: Nil,
-      Union(input1 :: input2 :: Nil)).analyze
-    val expected = Project('b :: Nil,
-      Union(Project('b :: Nil, input1) :: Project('d :: Nil, input2) :: Nil)).analyze
-    comparePlans(Optimize.execute(query), expected)
-  }
-
   // todo: add more tests for column pruning
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 7d60862f5a..70b34cbb24 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -41,6 +41,7 @@ class FilterPushdownSuite extends PlanTest {
         PushPredicateThroughJoin,
         PushPredicateThroughGenerate,
         PushPredicateThroughAggregate,
+        ColumnPruning,
         CollapseProject) :: Nil
   }
 
@@ -64,6 +65,52 @@ class FilterPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
+  test("column pruning for group") {
+    val originalQuery =
+      testRelation
+        .groupBy('a)('a, count('b))
+        .select('a)
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer =
+      testRelation
+        .select('a)
+        .groupBy('a)('a).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("column pruning for group with alias") {
+    val originalQuery =
+      testRelation
+        .groupBy('a)('a as 'c, count('b))
+        .select('c)
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer =
+      testRelation
+        .select('a)
+        .groupBy('a)('a as 'c).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
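With ColumnPruning added to the same FixedPoint batch as the predicate-pushdown rules and CollapseProject, the tests above now exercise the interaction of the rules rather than ColumnPruning in isolation. A condensed sketch of such a test executor (the batch mirrors the hunk above; the predicate-pushdown rules are elided for brevity):

import org.apache.spark.sql.catalyst.optimizer.{CollapseProject, ColumnPruning}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.RuleExecutor

object OptimizeSketch extends RuleExecutor[LogicalPlan] {
  val batches =
    Batch("Filter Pushdown", FixedPoint(10),
      // ...the PushPredicateThrough* rules listed in the diff would go here...
      ColumnPruning,
      CollapseProject) :: Nil
}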
+
+  test("column pruning for Project(ne, Limit)") {
+    val originalQuery =
+      testRelation
+        .select('a, 'b)
+        .limit(2)
+        .select('a)
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer =
+      testRelation
+        .select('a)
+        .limit(2).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
   test("simple push down") {
     val originalQuery =
@@ -557,6 +604,39 @@ class FilterPushdownSuite extends PlanTest {
     comparePlans(optimized, originalQuery)
   }
 
+  test("push down project past sort") {
+    val x = testRelation.subquery('x)
+
+    // push down valid
+    val originalQuery = {
+      x.select('a, 'b)
+        .sortBy(SortOrder('a, Ascending))
+        .select('a)
+    }
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer =
+      x.select('a)
+        .sortBy(SortOrder('a, Ascending)).analyze
+
+    comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer))
+
+    // push down invalid
+    val originalQuery1 = {
+      x.select('a, 'b)
+        .sortBy(SortOrder('a, Ascending))
+        .select('b)
+    }
+
+    val optimized1 = Optimize.execute(originalQuery1.analyze)
+    val correctAnswer1 =
+      x.select('a, 'b)
+        .sortBy(SortOrder('a, Ascending))
+        .select('b).analyze
+
+    comparePlans(optimized1, analysis.EliminateSubqueryAliases(correctAnswer1))
+  }
+
   test("push project and filter down into sample") {
     val x = testRelation.subquery('x)
     val originalQuery =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
index 22d4278085..4858140229 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
@@ -26,8 +26,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.Statistics
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution.{LeafNode, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -64,7 +63,7 @@ private[sql] case class InMemoryRelation(
     @transient private[sql] var _cachedColumnBuffers: RDD[CachedBatch] = null,
     @transient private[sql] var _statistics: Statistics = null,
     private[sql] var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null)
-  extends logical.LeafNode with MultiInstanceRelation {
+  extends LogicalPlan with MultiInstanceRelation {
 
   override def producedAttributes: AttributeSet = outputSet
 
@@ -185,6 +184,8 @@ private[sql] case class InMemoryRelation(
       _cachedColumnBuffers, statisticsToBePropagated, batchStats)
   }
 
+  override def children: Seq[LogicalPlan] = Seq.empty
+
   override def newInstance(): this.type = {
     new InMemoryRelation(
       output.map(_.newInstance()),
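Finally, InMemoryRelation stops extending logical.LeafNode and extends LogicalPlan directly with an explicit empty children list; per the import hunk, the unqualified LeafNode in this file then refers only to the physical execution.LeafNode. A hypothetical, trimmed-down analogue of the resulting shape (not from the patch):

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

// A minimal leaf-like relation in the post-patch style: a LogicalPlan that
// simply reports no children instead of mixing in the logical LeafNode trait.
case class CachedRelationSketch(output: Seq[Attribute]) extends LogicalPlan {
  override def children: Seq[LogicalPlan] = Seq.empty
  // Everything this node outputs is produced by the node itself.
  override def producedAttributes: AttributeSet = outputSet
}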