diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 7455e68ee8..586bf3d4dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -381,12 +381,12 @@ object ColumnPruning extends Rule[LogicalPlan] { p } - // Can't prune the columns on LeafNode - case p @ Project(_, l: LeafNode) => p - // Eliminate no-op Projects case p @ Project(projectList, child) if sameOutput(child.output, p.output) => child + // Can't prune the columns on LeafNode + case p @ Project(_, l: LeafNode) => p + // for all other logical plans that inherits the output from it's children case p @ Project(_, child) => val required = child.references ++ p.references diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index d09601e034..409e92238e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -157,6 +157,22 @@ class ColumnPruningSuite extends PlanTest { comparePlans(Optimize.execute(query), expected) } + test("Eliminate the Project with an empty projectList") { + val input = OneRowRelation + val expected = Project(Literal(1).as("1") :: Nil, input).analyze + + val query1 = + Project(Literal(1).as("1") :: Nil, Project(Literal(1).as("1") :: Nil, input)).analyze + comparePlans(Optimize.execute(query1), expected) + + val query2 = + Project(Literal(1).as("1") :: Nil, Project(Nil, input)).analyze + comparePlans(Optimize.execute(query2), expected) + + // to make sure the top Project will not be removed. + comparePlans(Optimize.execute(expected), expected) + } + test("column pruning for group") { val testRelation = LocalRelation('a.int, 'b.int, 'c.int) val originalQuery =