[SPARK-20758][SQL] Add Constant propagation optimization
## What changes were proposed in this pull request? See class doc of `ConstantPropagation` for the approach used. ## How was this patch tested? - Added unit tests Author: Tejas Patil <tejasp@fb.com> Closes #17993 from tejasapatil/SPARK-20758_const_propagation.
This commit is contained in:
parent
9d0db5a7f8
commit
f9b59abeae
|
@ -92,6 +92,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf)
|
|||
CombineUnions,
|
||||
// Constant folding and strength reduction
|
||||
NullPropagation(conf),
|
||||
ConstantPropagation,
|
||||
FoldablePropagation,
|
||||
OptimizeIn(conf),
|
||||
ConstantFolding,
|
||||
|
|
|
@ -54,6 +54,62 @@ object ConstantFolding extends Rule[LogicalPlan] {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Substitutes [[Attribute Attributes]] which can be statically evaluated with their corresponding
|
||||
* value in conjunctive [[Expression Expressions]]
|
||||
* eg.
|
||||
* {{{
|
||||
* SELECT * FROM table WHERE i = 5 AND j = i + 3
|
||||
* ==> SELECT * FROM table WHERE i = 5 AND j = 8
|
||||
* }}}
|
||||
*
|
||||
* Approach used:
|
||||
* - Start from AND operator as the root
|
||||
* - Get all the children conjunctive predicates which are EqualTo / EqualNullSafe such that they
|
||||
* don't have a `NOT` or `OR` operator in them
|
||||
* - Populate a mapping of attribute => constant value by looking at all the equals predicates
|
||||
* - Using this mapping, replace occurrence of the attributes with the corresponding constant values
|
||||
* in the AND node.
|
||||
*/
|
||||
object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper {
|
||||
private def containsNonConjunctionPredicates(expression: Expression): Boolean = expression.find {
|
||||
case _: Not | _: Or => true
|
||||
case _ => false
|
||||
}.isDefined
|
||||
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
|
||||
case f: Filter => f transformExpressionsUp {
|
||||
case and: And =>
|
||||
val conjunctivePredicates =
|
||||
splitConjunctivePredicates(and)
|
||||
.filter(expr => expr.isInstanceOf[EqualTo] || expr.isInstanceOf[EqualNullSafe])
|
||||
.filterNot(expr => containsNonConjunctionPredicates(expr))
|
||||
|
||||
val equalityPredicates = conjunctivePredicates.collect {
|
||||
case e @ EqualTo(left: AttributeReference, right: Literal) => ((left, right), e)
|
||||
case e @ EqualTo(left: Literal, right: AttributeReference) => ((right, left), e)
|
||||
case e @ EqualNullSafe(left: AttributeReference, right: Literal) => ((left, right), e)
|
||||
case e @ EqualNullSafe(left: Literal, right: AttributeReference) => ((right, left), e)
|
||||
}
|
||||
|
||||
val constantsMap = AttributeMap(equalityPredicates.map(_._1))
|
||||
val predicates = equalityPredicates.map(_._2).toSet
|
||||
|
||||
def replaceConstants(expression: Expression) = expression transform {
|
||||
case a: AttributeReference =>
|
||||
constantsMap.get(a) match {
|
||||
case Some(literal) => literal
|
||||
case None => a
|
||||
}
|
||||
}
|
||||
|
||||
and transform {
|
||||
case e @ EqualTo(_, _) if !predicates.contains(e) => replaceConstants(e)
|
||||
case e @ EqualNullSafe(_, _) if !predicates.contains(e) => replaceConstants(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Reorder associative integral-type operators and fold all constants into one.
|
||||
|
|
|
@ -0,0 +1,167 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.catalyst.optimizer
|
||||
|
||||
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
|
||||
import org.apache.spark.sql.catalyst.dsl.expressions._
|
||||
import org.apache.spark.sql.catalyst.dsl.plans._
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.plans.PlanTest
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
|
||||
import org.apache.spark.sql.catalyst.rules.RuleExecutor
|
||||
|
||||
/**
|
||||
* Unit tests for constant propagation in expressions.
|
||||
*/
|
||||
class ConstantPropagationSuite extends PlanTest {
|
||||
|
||||
object Optimize extends RuleExecutor[LogicalPlan] {
|
||||
val batches =
|
||||
Batch("AnalysisNodes", Once,
|
||||
EliminateSubqueryAliases) ::
|
||||
Batch("ConstantPropagation", FixedPoint(10),
|
||||
ColumnPruning,
|
||||
ConstantPropagation,
|
||||
ConstantFolding,
|
||||
BooleanSimplification) :: Nil
|
||||
}
|
||||
|
||||
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
|
||||
|
||||
private val columnA = 'a.int
|
||||
private val columnB = 'b.int
|
||||
private val columnC = 'c.int
|
||||
|
||||
test("basic test") {
|
||||
val query = testRelation
|
||||
.select(columnA)
|
||||
.where(columnA === Add(columnB, Literal(1)) && columnB === Literal(10))
|
||||
|
||||
val correctAnswer =
|
||||
testRelation
|
||||
.select(columnA)
|
||||
.where(columnA === Literal(11) && columnB === Literal(10)).analyze
|
||||
|
||||
comparePlans(Optimize.execute(query.analyze), correctAnswer)
|
||||
}
|
||||
|
||||
test("with combination of AND and OR predicates") {
|
||||
val query = testRelation
|
||||
.select(columnA)
|
||||
.where(
|
||||
columnA === Add(columnB, Literal(1)) &&
|
||||
columnB === Literal(10) &&
|
||||
(columnA === Add(columnC, Literal(3)) || columnB === columnC))
|
||||
.analyze
|
||||
|
||||
val correctAnswer =
|
||||
testRelation
|
||||
.select(columnA)
|
||||
.where(
|
||||
columnA === Literal(11) &&
|
||||
columnB === Literal(10) &&
|
||||
(Literal(11) === Add(columnC, Literal(3)) || Literal(10) === columnC))
|
||||
.analyze
|
||||
|
||||
comparePlans(Optimize.execute(query), correctAnswer)
|
||||
}
|
||||
|
||||
test("equality predicates outside a `NOT` can be propagated within a `NOT`") {
|
||||
val query = testRelation
|
||||
.select(columnA)
|
||||
.where(Not(columnA === Add(columnB, Literal(1))) && columnB === Literal(10))
|
||||
.analyze
|
||||
|
||||
val correctAnswer =
|
||||
testRelation
|
||||
.select(columnA)
|
||||
.where(Not(columnA === Literal(11)) && columnB === Literal(10))
|
||||
.analyze
|
||||
|
||||
comparePlans(Optimize.execute(query), correctAnswer)
|
||||
}
|
||||
|
||||
test("equality predicates inside a `NOT` should not be picked for propagation") {
|
||||
val query = testRelation
|
||||
.select(columnA)
|
||||
.where(Not(columnB === Literal(10)) && columnA === Add(columnB, Literal(1)))
|
||||
.analyze
|
||||
|
||||
comparePlans(Optimize.execute(query), query)
|
||||
}
|
||||
|
||||
test("equality predicates outside a `OR` can be propagated within a `OR`") {
|
||||
val query = testRelation
|
||||
.select(columnA)
|
||||
.where(
|
||||
columnA === Literal(2) &&
|
||||
(columnA === Add(columnB, Literal(3)) || columnB === Literal(9)))
|
||||
.analyze
|
||||
|
||||
val correctAnswer = testRelation
|
||||
.select(columnA)
|
||||
.where(
|
||||
columnA === Literal(2) &&
|
||||
(Literal(2) === Add(columnB, Literal(3)) || columnB === Literal(9)))
|
||||
.analyze
|
||||
|
||||
comparePlans(Optimize.execute(query), correctAnswer)
|
||||
}
|
||||
|
||||
test("equality predicates inside a `OR` should not be picked for propagation") {
|
||||
val query = testRelation
|
||||
.select(columnA)
|
||||
.where(
|
||||
columnA === Add(columnB, Literal(2)) &&
|
||||
(columnA === Add(columnB, Literal(3)) || columnB === Literal(9)))
|
||||
.analyze
|
||||
|
||||
comparePlans(Optimize.execute(query), query)
|
||||
}
|
||||
|
||||
test("equality operator not immediate child of root `AND` should not be used for propagation") {
|
||||
val query = testRelation
|
||||
.select(columnA)
|
||||
.where(
|
||||
columnA === Literal(0) &&
|
||||
((columnB === columnA) === (columnB === Literal(0))))
|
||||
.analyze
|
||||
|
||||
val correctAnswer = testRelation
|
||||
.select(columnA)
|
||||
.where(
|
||||
columnA === Literal(0) &&
|
||||
((columnB === Literal(0)) === (columnB === Literal(0))))
|
||||
.analyze
|
||||
|
||||
comparePlans(Optimize.execute(query), correctAnswer)
|
||||
}
|
||||
|
||||
test("conflicting equality predicates") {
|
||||
val query = testRelation
|
||||
.select(columnA)
|
||||
.where(
|
||||
columnA === Literal(1) && columnA === Literal(2) && columnB === Add(columnA, Literal(3)))
|
||||
|
||||
val correctAnswer = testRelation
|
||||
.select(columnA)
|
||||
.where(columnA === Literal(1) && columnA === Literal(2) && columnB === Literal(5))
|
||||
|
||||
comparePlans(Optimize.execute(query.analyze), correctAnswer)
|
||||
}
|
||||
}
|
|
@ -190,7 +190,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
|
|||
checkDataFilters(Set.empty)
|
||||
|
||||
// Only one file should be read.
|
||||
checkScan(table.where("p1 = 1 AND c1 = 1 AND (p1 + c1) = 1")) { partitions =>
|
||||
checkScan(table.where("p1 = 1 AND c1 = 1 AND (p1 + c1) = 2")) { partitions =>
|
||||
assert(partitions.size == 1, "when checking partitions")
|
||||
assert(partitions.head.files.size == 1, "when checking files in partition 1")
|
||||
assert(partitions.head.files.head.partitionValues.getInt(0) == 1,
|
||||
|
@ -217,7 +217,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
|
|||
checkDataFilters(Set.empty)
|
||||
|
||||
// Only one file should be read.
|
||||
checkScan(table.where("P1 = 1 AND C1 = 1 AND (P1 + C1) = 1")) { partitions =>
|
||||
checkScan(table.where("P1 = 1 AND C1 = 1 AND (P1 + C1) = 2")) { partitions =>
|
||||
assert(partitions.size == 1, "when checking partitions")
|
||||
assert(partitions.head.files.size == 1, "when checking files in partition 1")
|
||||
assert(partitions.head.files.head.partitionValues.getInt(0) == 1,
|
||||
|
@ -235,13 +235,17 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
|
|||
"p1=1/file1" -> 10,
|
||||
"p1=2/file2" -> 10))
|
||||
|
||||
val df = table.where("p1 = 1 AND (p1 + c1) = 2 AND c1 = 1")
|
||||
val df1 = table.where("p1 = 1 AND (p1 + c1) = 2 AND c1 = 1")
|
||||
// Filter on data only are advisory so we have to reevaluate.
|
||||
assert(getPhysicalFilters(df) contains resolve(df, "c1 = 1"))
|
||||
// Need to evalaute filters that are not pushed down.
|
||||
assert(getPhysicalFilters(df) contains resolve(df, "(p1 + c1) = 2"))
|
||||
assert(getPhysicalFilters(df1) contains resolve(df1, "c1 = 1"))
|
||||
// Don't reevaluate partition only filters.
|
||||
assert(!(getPhysicalFilters(df) contains resolve(df, "p1 = 1")))
|
||||
assert(!(getPhysicalFilters(df1) contains resolve(df1, "p1 = 1")))
|
||||
|
||||
val df2 = table.where("(p1 + c2) = 2 AND c1 = 1")
|
||||
// Filter on data only are advisory so we have to reevaluate.
|
||||
assert(getPhysicalFilters(df2) contains resolve(df2, "c1 = 1"))
|
||||
// Need to evalaute filters that are not pushed down.
|
||||
assert(getPhysicalFilters(df2) contains resolve(df2, "(p1 + c2) = 2"))
|
||||
}
|
||||
|
||||
test("bucketed table") {
|
||||
|
|
Loading…
Reference in a new issue