Revert "[SPARK-13376] [SQL] improve column pruning"
This reverts commit e9533b419e.

Parent: 382b27babf
Commit: d563c8fa01
@@ -313,85 +313,97 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
 */
object ColumnPruning extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    // Prunes the unused columns from project list of Project/Aggregate/Window/Expand
    case p @ Project(_, p2: Project) if (p2.outputSet -- p.references).nonEmpty =>
      p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains)))
    case p @ Project(_, a: Aggregate) if (a.outputSet -- p.references).nonEmpty =>
      p.copy(
        child = a.copy(aggregateExpressions = a.aggregateExpressions.filter(p.references.contains)))
    case p @ Project(_, w: Window) if (w.outputSet -- p.references).nonEmpty =>
      p.copy(child = w.copy(
        projectList = w.projectList.filter(p.references.contains),
        windowExpressions = w.windowExpressions.filter(p.references.contains)))
    case a @ Project(_, e @ Expand(_, _, grandChild)) if (e.outputSet -- a.references).nonEmpty =>
      val newOutput = e.output.filter(a.references.contains(_))
      val newProjects = e.projections.map { proj =>
        proj.zip(e.output).filter { case (e, a) =>
    case a @ Aggregate(_, _, e @ Expand(projects, output, child))
        if (e.outputSet -- a.references).nonEmpty =>
      val newOutput = output.filter(a.references.contains(_))
      val newProjects = projects.map { proj =>
        proj.zip(output).filter { case (e, a) =>
          newOutput.contains(a)
        }.unzip._1
      }
      a.copy(child = Expand(newProjects, newOutput, grandChild))
      // TODO: support some logical plan for Dataset
      a.copy(child = Expand(newProjects, newOutput, child))

    // Prunes the unused columns from child of Aggregate/Window/Expand/Generate
    case a @ Aggregate(_, _, e @ Expand(_, _, child))
        if (child.outputSet -- e.references -- a.references).nonEmpty =>
      a.copy(child = e.copy(child = prunedChild(child, e.references ++ a.references)))

    // Eliminate attributes that are not needed to calculate the specified aggregates.
    case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
      a.copy(child = prunedChild(child, a.references))
    case w @ Window(_, _, _, _, child) if (child.outputSet -- w.references).nonEmpty =>
      w.copy(child = prunedChild(child, w.references))
    case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty =>
      e.copy(child = prunedChild(child, e.references))
    case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty =>
      g.copy(child = prunedChild(g.child, g.references))
      a.copy(child = Project(a.references.toSeq, child))

    // Eliminate attributes that are not needed to calculate the Generate.
    case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty =>
      g.copy(child = Project(g.references.toSeq, g.child))

    // Turn off `join` for Generate if no column from it's child is used
    case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) =>
      p.copy(child = g.copy(join = false))

    // Eliminate unneeded attributes from right side of a LeftSemiJoin.
    case j @ Join(left, right, LeftSemi, condition) =>
      j.copy(right = prunedChild(right, j.references))

    // all the columns will be used to compare, so we can't prune them
    case p @ Project(_, _: SetOperation) => p
    case p @ Project(_, _: Distinct) => p
    // Eliminate unneeded attributes from children of Union.
    case p @ Project(_, u: Union) =>
      if ((u.outputSet -- p.references).nonEmpty) {
        val firstChild = u.children.head
        val newOutput = prunedChild(firstChild, p.references).output
        // pruning the columns of all children based on the pruned first child.
        val newChildren = u.children.map { p =>
          val selected = p.output.zipWithIndex.filter { case (a, i) =>
            newOutput.contains(firstChild.output(i))
          }.map(_._1)
          Project(selected, p)
        }
        p.copy(child = u.withNewChildren(newChildren))
      } else {
    case p @ Project(projectList, g: Generate) if g.join =>
      val neededChildOutput = p.references -- g.generatorOutput ++ g.references
      if (neededChildOutput == g.child.outputSet) {
        p
      } else {
        Project(projectList, g.copy(child = Project(neededChildOutput.toSeq, g.child)))
      }

    // Can't prune the columns on LeafNode
    case p @ Project(_, l: LeafNode) => p
    case p @ Project(projectList, a @ Aggregate(groupingExpressions, aggregateExpressions, child))
        if (a.outputSet -- p.references).nonEmpty =>
      Project(
        projectList,
        Aggregate(
          groupingExpressions,
          aggregateExpressions.filter(e => p.references.contains(e)),
          child))

    // Eliminate unneeded attributes from either side of a Join.
    case Project(projectList, Join(left, right, joinType, condition)) =>
      // Collect the list of all references required either above or to evaluate the condition.
      val allReferences: AttributeSet =
        AttributeSet(
          projectList.flatMap(_.references.iterator)) ++
          condition.map(_.references).getOrElse(AttributeSet(Seq.empty))

      /** Applies a projection only when the child is producing unnecessary attributes */
      def pruneJoinChild(c: LogicalPlan): LogicalPlan = prunedChild(c, allReferences)

      Project(projectList, Join(pruneJoinChild(left), pruneJoinChild(right), joinType, condition))

    // Eliminate unneeded attributes from right side of a LeftSemiJoin.
    case Join(left, right, LeftSemi, condition) =>
      // Collect the list of all references required to evaluate the condition.
      val allReferences: AttributeSet =
        condition.map(_.references).getOrElse(AttributeSet(Seq.empty))

      Join(left, prunedChild(right, allReferences), LeftSemi, condition)

    // Push down project through limit, so that we may have chance to push it further.
    case Project(projectList, Limit(exp, child)) =>
      Limit(exp, Project(projectList, child))

    // Push down project if possible when the child is sort.
    case p @ Project(projectList, s @ Sort(_, _, grandChild)) =>
      if (s.references.subsetOf(p.outputSet)) {
        s.copy(child = Project(projectList, grandChild))
      } else {
        val neededReferences = s.references ++ p.references
        if (neededReferences == grandChild.outputSet) {
          // No column we can prune, return the original plan.
          p
        } else {
          // Do not use neededReferences.toSeq directly, should respect grandChild's output order.
          val newProjectList = grandChild.output.filter(neededReferences.contains)
          p.copy(child = s.copy(child = Project(newProjectList, grandChild)))
        }
      }

    // Eliminate no-op Projects
    case p @ Project(projectList, child) if child.output == p.output => child

    // for all other logical plans that inherits the output from it's children
    case p @ Project(_, child) =>
      val required = child.references ++ p.references
      if ((child.inputSet -- required).nonEmpty) {
        val newChildren = child.children.map(c => prunedChild(c, required))
        p.copy(child = child.withNewChildren(newChildren))
      } else {
        p
      }
    case Project(projectList, child) if child.output == projectList => child
  }

  /** Applies a projection only when the child is producing unnecessary attributes */
  private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) =
    if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) {
      Project(c.output.filter(allReferences.contains), c)
      Project(allReferences.filter(c.outputSet.contains).toSeq, c)
    } else {
      c
    }
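For orientation before the test changes below, here is a minimal sketch, written in the same Catalyst test DSL that the suites in this commit use, of the behaviour both versions of the rule aim to produce: when a query keeps only the grouping column of an aggregate, the unused columns are projected away beneath the Aggregate and the unused aggregate expression is dropped. The object name and standalone layout are illustrative, not part of this commit.

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

object ColumnPruningSketch {
  val relation = LocalRelation('a.int, 'b.int, 'c.int)

  // Before optimization: the Aggregate sees all of 'a, 'b, 'c, then only 'a is selected.
  val original = relation.groupBy('a)('a, count('b)).select('a).analyze

  // After ColumnPruning: 'b and 'c are pruned below the Aggregate and the
  // unused count('b) is removed from the aggregate expressions.
  val expectedAfterPruning = relation.select('a).groupBy('a)('a).analyze
}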
@@ -17,10 +17,9 @@
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Ascending, Explode, Literal, SortOrder}
import org.apache.spark.sql.catalyst.expressions.{Explode, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -120,134 +119,11 @@ class ColumnPruningSuite extends PlanTest {
        Seq('c, Literal.create(null, StringType), 1),
        Seq('c, 'a, 2)),
      Seq('c, 'aa.int, 'gid.int),
      Project(Seq('a, 'c),
      Project(Seq('c, 'a),
        input))).analyze

    comparePlans(optimized, expected)
  }

  test("Column pruning on Filter") {
    val input = LocalRelation('a.int, 'b.string, 'c.double)
    val query = Project('a :: Nil, Filter('c > Literal(0.0), input)).analyze
    val expected =
      Project('a :: Nil,
        Filter('c > Literal(0.0),
          Project(Seq('a, 'c), input))).analyze
    comparePlans(Optimize.execute(query), expected)
  }

  test("Column pruning on except/intersect/distinct") {
    val input = LocalRelation('a.int, 'b.string, 'c.double)
    val query = Project('a :: Nil, Except(input, input)).analyze
    comparePlans(Optimize.execute(query), query)

    val query2 = Project('a :: Nil, Intersect(input, input)).analyze
    comparePlans(Optimize.execute(query2), query2)
    val query3 = Project('a :: Nil, Distinct(input)).analyze
    comparePlans(Optimize.execute(query3), query3)
  }

  test("Column pruning on Project") {
    val input = LocalRelation('a.int, 'b.string, 'c.double)
    val query = Project('a :: Nil, Project(Seq('a, 'b), input)).analyze
    val expected = Project(Seq('a), input).analyze
    comparePlans(Optimize.execute(query), expected)
  }

  test("column pruning for group") {
    val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
    val originalQuery =
      testRelation
        .groupBy('a)('a, count('b))
        .select('a)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select('a)
        .groupBy('a)('a).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("column pruning for group with alias") {
    val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

    val originalQuery =
      testRelation
        .groupBy('a)('a as 'c, count('b))
        .select('c)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select('a)
        .groupBy('a)('a as 'c).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("column pruning for Project(ne, Limit)") {
    val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

    val originalQuery =
      testRelation
        .select('a, 'b)
        .limit(2)
        .select('a)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select('a)
        .limit(2).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("push down project past sort") {
    val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
    val x = testRelation.subquery('x)

    // push down valid
    val originalQuery = {
      x.select('a, 'b)
        .sortBy(SortOrder('a, Ascending))
        .select('a)
    }

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      x.select('a)
        .sortBy(SortOrder('a, Ascending)).analyze

    comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer))

    // push down invalid
    val originalQuery1 = {
      x.select('a, 'b)
        .sortBy(SortOrder('a, Ascending))
        .select('b)
    }

    val optimized1 = Optimize.execute(originalQuery1.analyze)
    val correctAnswer1 =
      x.select('a, 'b)
        .sortBy(SortOrder('a, Ascending))
        .select('b).analyze

    comparePlans(optimized1, analysis.EliminateSubqueryAliases(correctAnswer1))
  }

  test("Column pruning on Union") {
    val input1 = LocalRelation('a.int, 'b.string, 'c.double)
    val input2 = LocalRelation('c.int, 'd.string, 'e.double)
    val query = Project('b :: Nil,
      Union(input1 :: input2 :: Nil)).analyze
    val expected = Project('b :: Nil,
      Union(Project('b :: Nil, input1) :: Project('d :: Nil, input2) :: Nil)).analyze
    comparePlans(Optimize.execute(query), expected)
  }

  // todo: add more tests for column pruning
}
@@ -41,6 +41,7 @@ class FilterPushdownSuite extends PlanTest {
        PushPredicateThroughJoin,
        PushPredicateThroughGenerate,
        PushPredicateThroughAggregate,
        ColumnPruning,
        CollapseProject) :: Nil
  }

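The rule list above is how the suite plugs ColumnPruning back into its local optimizer. A rough sketch of the surrounding wiring is below; the batch name, the FixedPoint limit, and the assumption that this sits in the org.apache.spark.sql.catalyst.optimizer package (so the rule objects are in scope) are illustrative, since the full batch definition is not shown in this hunk.

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.RuleExecutor

object Optimize extends RuleExecutor[LogicalPlan] {
  // Run the pushdown/pruning rules to fixed point, as the suite's batch does.
  val batches =
    Batch("Filter Pushdown and Pruning", FixedPoint(10),
      PushPredicateThroughJoin,
      PushPredicateThroughGenerate,
      PushPredicateThroughAggregate,
      ColumnPruning,
      CollapseProject) :: Nil
}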
@@ -64,6 +65,52 @@ class FilterPushdownSuite extends PlanTest {
    comparePlans(optimized, correctAnswer)
  }

  test("column pruning for group") {
    val originalQuery =
      testRelation
        .groupBy('a)('a, count('b))
        .select('a)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select('a)
        .groupBy('a)('a).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("column pruning for group with alias") {
    val originalQuery =
      testRelation
        .groupBy('a)('a as 'c, count('b))
        .select('c)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select('a)
        .groupBy('a)('a as 'c).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("column pruning for Project(ne, Limit)") {
    val originalQuery =
      testRelation
        .select('a, 'b)
        .limit(2)
        .select('a)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select('a)
        .limit(2).analyze

    comparePlans(optimized, correctAnswer)
  }

  // After this line is unimplemented.
  test("simple push down") {
    val originalQuery =
@@ -557,6 +604,39 @@ class FilterPushdownSuite extends PlanTest {
    comparePlans(optimized, originalQuery)
  }

  test("push down project past sort") {
    val x = testRelation.subquery('x)

    // push down valid
    val originalQuery = {
      x.select('a, 'b)
        .sortBy(SortOrder('a, Ascending))
        .select('a)
    }

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      x.select('a)
        .sortBy(SortOrder('a, Ascending)).analyze

    comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer))

    // push down invalid
    val originalQuery1 = {
      x.select('a, 'b)
        .sortBy(SortOrder('a, Ascending))
        .select('b)
    }

    val optimized1 = Optimize.execute(originalQuery1.analyze)
    val correctAnswer1 =
      x.select('a, 'b)
        .sortBy(SortOrder('a, Ascending))
        .select('b).analyze

    comparePlans(optimized1, analysis.EliminateSubqueryAliases(correctAnswer1))
  }

  test("push project and filter down into sample") {
    val x = testRelation.subquery('x)
    val originalQuery =
@@ -26,8 +26,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.{LeafNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -64,7 +63,7 @@ private[sql] case class InMemoryRelation(
    @transient private[sql] var _cachedColumnBuffers: RDD[CachedBatch] = null,
    @transient private[sql] var _statistics: Statistics = null,
    private[sql] var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null)
  extends logical.LeafNode with MultiInstanceRelation {
  extends LogicalPlan with MultiInstanceRelation {

  override def producedAttributes: AttributeSet = outputSet

@@ -185,6 +184,8 @@ private[sql] case class InMemoryRelation(
      _cachedColumnBuffers, statisticsToBePropagated, batchStats)
  }

  override def children: Seq[LogicalPlan] = Seq.empty

  override def newInstance(): this.type = {
    new InMemoryRelation(
      output.map(_.newInstance()),
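The last two hunks capture the InMemoryRelation side of the revert: one version extends logical.LeafNode, while the other extends LogicalPlan directly and therefore has to supply producedAttributes and an explicit children override. A hedged sketch of that difference, using a hypothetical DummyRelation that is not part of this commit:

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LeafNode

// A logical LeafNode already fixes children to Nil, so no children override is
// needed; extending LogicalPlan directly would require one (e.g. Seq.empty),
// as the code in the hunk above shows.
case class DummyRelation(output: Seq[Attribute]) extends LeafNode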