[SPARK-17273][SQL] Move expression optimizer rules into a separate file
## What changes were proposed in this pull request? As part of breaking Optimizer.scala apart, this patch moves various expression optimization rules into a single file. ## How was this patch tested? This should be covered by existing tests. Author: Reynold Xin <rxin@databricks.com> Closes #14845 from rxin/SPARK-17273.
This commit is contained in:
parent
0243b32873
commit
5aad4509c1
|
@ -533,176 +533,6 @@ object CollapseRepartition extends Rule[LogicalPlan] {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Simplifies LIKE expressions that do not need full regular expressions to evaluate the condition.
|
||||
* For example, when the expression is just checking to see if a string starts with a given
|
||||
* pattern.
|
||||
*/
|
||||
object LikeSimplification extends Rule[LogicalPlan] {
|
||||
// if guards below protect from escapes on trailing %.
|
||||
// Cases like "something\%" are not optimized, but this does not affect correctness.
|
||||
private val startsWith = "([^_%]+)%".r
|
||||
private val endsWith = "%([^_%]+)".r
|
||||
private val startsAndEndsWith = "([^_%]+)%([^_%]+)".r
|
||||
private val contains = "%([^_%]+)%".r
|
||||
private val equalTo = "([^_%]*)".r
|
||||
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
|
||||
case Like(input, Literal(pattern, StringType)) =>
|
||||
pattern.toString match {
|
||||
case startsWith(prefix) if !prefix.endsWith("\\") =>
|
||||
StartsWith(input, Literal(prefix))
|
||||
case endsWith(postfix) =>
|
||||
EndsWith(input, Literal(postfix))
|
||||
// 'a%a' pattern is basically same with 'a%' && '%a'.
|
||||
// However, the additional `Length` condition is required to prevent 'a' match 'a%a'.
|
||||
case startsAndEndsWith(prefix, postfix) if !prefix.endsWith("\\") =>
|
||||
And(GreaterThanOrEqual(Length(input), Literal(prefix.size + postfix.size)),
|
||||
And(StartsWith(input, Literal(prefix)), EndsWith(input, Literal(postfix))))
|
||||
case contains(infix) if !infix.endsWith("\\") =>
|
||||
Contains(input, Literal(infix))
|
||||
case equalTo(str) =>
|
||||
EqualTo(input, Literal(str))
|
||||
case _ =>
|
||||
Like(input, Literal.create(pattern, StringType))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Replaces [[Expression Expressions]] that can be statically evaluated with
|
||||
* equivalent [[Literal]] values. This rule is more specific with
|
||||
* Null value propagation from bottom to top of the expression tree.
|
||||
*/
|
||||
object NullPropagation extends Rule[LogicalPlan] {
|
||||
private def nonNullLiteral(e: Expression): Boolean = e match {
|
||||
case Literal(null, _) => false
|
||||
case _ => true
|
||||
}
|
||||
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
|
||||
case q: LogicalPlan => q transformExpressionsUp {
|
||||
case e @ WindowExpression(Cast(Literal(0L, _), _), _) =>
|
||||
Cast(Literal(0L), e.dataType)
|
||||
case e @ AggregateExpression(Count(exprs), _, _, _) if !exprs.exists(nonNullLiteral) =>
|
||||
Cast(Literal(0L), e.dataType)
|
||||
case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
|
||||
case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)
|
||||
case e @ GetArrayItem(Literal(null, _), _) => Literal.create(null, e.dataType)
|
||||
case e @ GetArrayItem(_, Literal(null, _)) => Literal.create(null, e.dataType)
|
||||
case e @ GetMapValue(Literal(null, _), _) => Literal.create(null, e.dataType)
|
||||
case e @ GetMapValue(_, Literal(null, _)) => Literal.create(null, e.dataType)
|
||||
case e @ GetStructField(Literal(null, _), _, _) => Literal.create(null, e.dataType)
|
||||
case e @ GetArrayStructFields(Literal(null, _), _, _, _, _) =>
|
||||
Literal.create(null, e.dataType)
|
||||
case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
|
||||
case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
|
||||
case ae @ AggregateExpression(Count(exprs), _, false, _) if !exprs.exists(_.nullable) =>
|
||||
// This rule should be only triggered when isDistinct field is false.
|
||||
ae.copy(aggregateFunction = Count(Literal(1)))
|
||||
|
||||
// For Coalesce, remove null literals.
|
||||
case e @ Coalesce(children) =>
|
||||
val newChildren = children.filter(nonNullLiteral)
|
||||
if (newChildren.isEmpty) {
|
||||
Literal.create(null, e.dataType)
|
||||
} else if (newChildren.length == 1) {
|
||||
newChildren.head
|
||||
} else {
|
||||
Coalesce(newChildren)
|
||||
}
|
||||
|
||||
case e @ Substring(Literal(null, _), _, _) => Literal.create(null, e.dataType)
|
||||
case e @ Substring(_, Literal(null, _), _) => Literal.create(null, e.dataType)
|
||||
case e @ Substring(_, _, Literal(null, _)) => Literal.create(null, e.dataType)
|
||||
|
||||
// Put exceptional cases above if any
|
||||
case e @ BinaryArithmetic(Literal(null, _), _) => Literal.create(null, e.dataType)
|
||||
case e @ BinaryArithmetic(_, Literal(null, _)) => Literal.create(null, e.dataType)
|
||||
|
||||
case e @ BinaryComparison(Literal(null, _), _) => Literal.create(null, e.dataType)
|
||||
case e @ BinaryComparison(_, Literal(null, _)) => Literal.create(null, e.dataType)
|
||||
|
||||
case e: StringRegexExpression => e.children match {
|
||||
case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType)
|
||||
case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType)
|
||||
case _ => e
|
||||
}
|
||||
|
||||
case e: StringPredicate => e.children match {
|
||||
case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType)
|
||||
case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType)
|
||||
case _ => e
|
||||
}
|
||||
|
||||
// If the value expression is NULL then transform the In expression to
|
||||
// Literal(null)
|
||||
case In(Literal(null, _), list) => Literal.create(null, BooleanType)
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Propagate foldable expressions:
|
||||
* Replace attributes with aliases of the original foldable expressions if possible.
|
||||
* Other optimizations will take advantage of the propagated foldable expressions.
|
||||
*
|
||||
* {{{
|
||||
* SELECT 1.0 x, 'abc' y, Now() z ORDER BY x, y, 3
|
||||
* ==> SELECT 1.0 x, 'abc' y, Now() z ORDER BY 1.0, 'abc', Now()
|
||||
* }}}
|
||||
*/
|
||||
object FoldablePropagation extends Rule[LogicalPlan] {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = {
|
||||
val foldableMap = AttributeMap(plan.flatMap {
|
||||
case Project(projectList, _) => projectList.collect {
|
||||
case a: Alias if a.child.foldable => (a.toAttribute, a)
|
||||
}
|
||||
case _ => Nil
|
||||
})
|
||||
|
||||
if (foldableMap.isEmpty) {
|
||||
plan
|
||||
} else {
|
||||
var stop = false
|
||||
CleanupAliases(plan.transformUp {
|
||||
case u: Union =>
|
||||
stop = true
|
||||
u
|
||||
case c: Command =>
|
||||
stop = true
|
||||
c
|
||||
// For outer join, although its output attributes are derived from its children, they are
|
||||
// actually different attributes: the output of outer join is not always picked from its
|
||||
// children, but can also be null.
|
||||
// TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes
|
||||
// of outer join.
|
||||
case j @ Join(_, _, LeftOuter | RightOuter | FullOuter, _) =>
|
||||
stop = true
|
||||
j
|
||||
|
||||
// These 3 operators take attributes as constructor parameters, and these attributes
|
||||
// can't be replaced by alias.
|
||||
case m: MapGroups =>
|
||||
stop = true
|
||||
m
|
||||
case f: FlatMapGroupsInR =>
|
||||
stop = true
|
||||
f
|
||||
case c: CoGroup =>
|
||||
stop = true
|
||||
c
|
||||
|
||||
case p: LogicalPlan if !stop => p.transformExpressions {
|
||||
case a: AttributeReference if foldableMap.contains(a) =>
|
||||
foldableMap(a)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a list of additional filters from an operator's existing constraint but remove those
|
||||
* that are either already part of the operator's condition or are part of the operator's child
|
||||
|
@ -742,261 +572,6 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Reorder associative integral-type operators and fold all constants into one.
|
||||
*/
|
||||
object ReorderAssociativeOperator extends Rule[LogicalPlan] {
|
||||
private def flattenAdd(e: Expression): Seq[Expression] = e match {
|
||||
case Add(l, r) => flattenAdd(l) ++ flattenAdd(r)
|
||||
case other => other :: Nil
|
||||
}
|
||||
|
||||
private def flattenMultiply(e: Expression): Seq[Expression] = e match {
|
||||
case Multiply(l, r) => flattenMultiply(l) ++ flattenMultiply(r)
|
||||
case other => other :: Nil
|
||||
}
|
||||
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
|
||||
case q: LogicalPlan => q transformExpressionsDown {
|
||||
case a: Add if a.deterministic && a.dataType.isInstanceOf[IntegralType] =>
|
||||
val (foldables, others) = flattenAdd(a).partition(_.foldable)
|
||||
if (foldables.size > 1) {
|
||||
val foldableExpr = foldables.reduce((x, y) => Add(x, y))
|
||||
val c = Literal.create(foldableExpr.eval(EmptyRow), a.dataType)
|
||||
if (others.isEmpty) c else Add(others.reduce((x, y) => Add(x, y)), c)
|
||||
} else {
|
||||
a
|
||||
}
|
||||
case m: Multiply if m.deterministic && m.dataType.isInstanceOf[IntegralType] =>
|
||||
val (foldables, others) = flattenMultiply(m).partition(_.foldable)
|
||||
if (foldables.size > 1) {
|
||||
val foldableExpr = foldables.reduce((x, y) => Multiply(x, y))
|
||||
val c = Literal.create(foldableExpr.eval(EmptyRow), m.dataType)
|
||||
if (others.isEmpty) c else Multiply(others.reduce((x, y) => Multiply(x, y)), c)
|
||||
} else {
|
||||
m
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Replaces [[Expression Expressions]] that can be statically evaluated with
|
||||
* equivalent [[Literal]] values.
|
||||
*/
|
||||
object ConstantFolding extends Rule[LogicalPlan] {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
|
||||
case q: LogicalPlan => q transformExpressionsDown {
|
||||
// Skip redundant folding of literals. This rule is technically not necessary. Placing this
|
||||
// here avoids running the next rule for Literal values, which would create a new Literal
|
||||
// object and running eval unnecessarily.
|
||||
case l: Literal => l
|
||||
|
||||
// Fold expressions that are foldable.
|
||||
case e if e.foldable => Literal.create(e.eval(EmptyRow), e.dataType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Optimize IN predicates:
|
||||
* 1. Removes literal repetitions.
|
||||
* 2. Replaces [[In (value, seq[Literal])]] with optimized version
|
||||
* [[InSet (value, HashSet[Literal])]] which is much faster.
|
||||
*/
|
||||
case class OptimizeIn(conf: CatalystConf) extends Rule[LogicalPlan] {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
|
||||
case q: LogicalPlan => q transformExpressionsDown {
|
||||
case expr @ In(v, list) if expr.inSetConvertible =>
|
||||
val newList = ExpressionSet(list).toSeq
|
||||
if (newList.size > conf.optimizerInSetConversionThreshold) {
|
||||
val hSet = newList.map(e => e.eval(EmptyRow))
|
||||
InSet(v, HashSet() ++ hSet)
|
||||
} else if (newList.size < list.size) {
|
||||
expr.copy(list = newList)
|
||||
} else { // newList.length == list.length
|
||||
expr
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Simplifies boolean expressions:
|
||||
* 1. Simplifies expressions whose answer can be determined without evaluating both sides.
|
||||
* 2. Eliminates / extracts common factors.
|
||||
* 3. Merge same expressions
|
||||
* 4. Removes `Not` operator.
|
||||
*/
|
||||
object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
|
||||
case q: LogicalPlan => q transformExpressionsUp {
|
||||
case TrueLiteral And e => e
|
||||
case e And TrueLiteral => e
|
||||
case FalseLiteral Or e => e
|
||||
case e Or FalseLiteral => e
|
||||
|
||||
case FalseLiteral And _ => FalseLiteral
|
||||
case _ And FalseLiteral => FalseLiteral
|
||||
case TrueLiteral Or _ => TrueLiteral
|
||||
case _ Or TrueLiteral => TrueLiteral
|
||||
|
||||
case a And b if a.semanticEquals(b) => a
|
||||
case a Or b if a.semanticEquals(b) => a
|
||||
|
||||
case a And (b Or c) if Not(a).semanticEquals(b) => And(a, c)
|
||||
case a And (b Or c) if Not(a).semanticEquals(c) => And(a, b)
|
||||
case (a Or b) And c if a.semanticEquals(Not(c)) => And(b, c)
|
||||
case (a Or b) And c if b.semanticEquals(Not(c)) => And(a, c)
|
||||
|
||||
case a Or (b And c) if Not(a).semanticEquals(b) => Or(a, c)
|
||||
case a Or (b And c) if Not(a).semanticEquals(c) => Or(a, b)
|
||||
case (a And b) Or c if a.semanticEquals(Not(c)) => Or(b, c)
|
||||
case (a And b) Or c if b.semanticEquals(Not(c)) => Or(a, c)
|
||||
|
||||
// Common factor elimination for conjunction
|
||||
case and @ (left And right) =>
|
||||
// 1. Split left and right to get the disjunctive predicates,
|
||||
// i.e. lhs = (a, b), rhs = (a, c)
|
||||
// 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a)
|
||||
// 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c)
|
||||
// 4. Apply the formula, get the optimized predicate: common || (ldiff && rdiff)
|
||||
val lhs = splitDisjunctivePredicates(left)
|
||||
val rhs = splitDisjunctivePredicates(right)
|
||||
val common = lhs.filter(e => rhs.exists(e.semanticEquals))
|
||||
if (common.isEmpty) {
|
||||
// No common factors, return the original predicate
|
||||
and
|
||||
} else {
|
||||
val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals))
|
||||
val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals))
|
||||
if (ldiff.isEmpty || rdiff.isEmpty) {
|
||||
// (a || b || c || ...) && (a || b) => (a || b)
|
||||
common.reduce(Or)
|
||||
} else {
|
||||
// (a || b || c || ...) && (a || b || d || ...) =>
|
||||
// ((c || ...) && (d || ...)) || a || b
|
||||
(common :+ And(ldiff.reduce(Or), rdiff.reduce(Or))).reduce(Or)
|
||||
}
|
||||
}
|
||||
|
||||
// Common factor elimination for disjunction
|
||||
case or @ (left Or right) =>
|
||||
// 1. Split left and right to get the conjunctive predicates,
|
||||
// i.e. lhs = (a, b), rhs = (a, c)
|
||||
// 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a)
|
||||
// 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c)
|
||||
// 4. Apply the formula, get the optimized predicate: common && (ldiff || rdiff)
|
||||
val lhs = splitConjunctivePredicates(left)
|
||||
val rhs = splitConjunctivePredicates(right)
|
||||
val common = lhs.filter(e => rhs.exists(e.semanticEquals))
|
||||
if (common.isEmpty) {
|
||||
// No common factors, return the original predicate
|
||||
or
|
||||
} else {
|
||||
val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals))
|
||||
val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals))
|
||||
if (ldiff.isEmpty || rdiff.isEmpty) {
|
||||
// (a && b) || (a && b && c && ...) => a && b
|
||||
common.reduce(And)
|
||||
} else {
|
||||
// (a && b && c && ...) || (a && b && d && ...) =>
|
||||
// ((c && ...) || (d && ...)) && a && b
|
||||
(common :+ Or(ldiff.reduce(And), rdiff.reduce(And))).reduce(And)
|
||||
}
|
||||
}
|
||||
|
||||
case Not(TrueLiteral) => FalseLiteral
|
||||
case Not(FalseLiteral) => TrueLiteral
|
||||
|
||||
case Not(a GreaterThan b) => LessThanOrEqual(a, b)
|
||||
case Not(a GreaterThanOrEqual b) => LessThan(a, b)
|
||||
|
||||
case Not(a LessThan b) => GreaterThanOrEqual(a, b)
|
||||
case Not(a LessThanOrEqual b) => GreaterThan(a, b)
|
||||
|
||||
case Not(a Or b) => And(Not(a), Not(b))
|
||||
case Not(a And b) => Or(Not(a), Not(b))
|
||||
|
||||
case Not(Not(e)) => e
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Simplifies binary comparisons with semantically-equal expressions:
|
||||
* 1) Replace '<=>' with 'true' literal.
|
||||
* 2) Replace '=', '<=', and '>=' with 'true' literal if both operands are non-nullable.
|
||||
* 3) Replace '<' and '>' with 'false' literal if both operands are non-nullable.
|
||||
*/
|
||||
object SimplifyBinaryComparison extends Rule[LogicalPlan] with PredicateHelper {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
|
||||
case q: LogicalPlan => q transformExpressionsUp {
|
||||
// True with equality
|
||||
case a EqualNullSafe b if a.semanticEquals(b) => TrueLiteral
|
||||
case a EqualTo b if !a.nullable && !b.nullable && a.semanticEquals(b) => TrueLiteral
|
||||
case a GreaterThanOrEqual b if !a.nullable && !b.nullable && a.semanticEquals(b) =>
|
||||
TrueLiteral
|
||||
case a LessThanOrEqual b if !a.nullable && !b.nullable && a.semanticEquals(b) => TrueLiteral
|
||||
|
||||
// False with inequality
|
||||
case a GreaterThan b if !a.nullable && !b.nullable && a.semanticEquals(b) => FalseLiteral
|
||||
case a LessThan b if !a.nullable && !b.nullable && a.semanticEquals(b) => FalseLiteral
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Simplifies conditional expressions (if / case).
|
||||
*/
|
||||
object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
|
||||
private def falseOrNullLiteral(e: Expression): Boolean = e match {
|
||||
case FalseLiteral => true
|
||||
case Literal(null, _) => true
|
||||
case _ => false
|
||||
}
|
||||
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
|
||||
case q: LogicalPlan => q transformExpressionsUp {
|
||||
case If(TrueLiteral, trueValue, _) => trueValue
|
||||
case If(FalseLiteral, _, falseValue) => falseValue
|
||||
case If(Literal(null, _), _, falseValue) => falseValue
|
||||
|
||||
case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) =>
|
||||
// If there are branches that are always false, remove them.
|
||||
// If there are no more branches left, just use the else value.
|
||||
// Note that these two are handled together here in a single case statement because
|
||||
// otherwise we cannot determine the data type for the elseValue if it is None (i.e. null).
|
||||
val newBranches = branches.filter(x => !falseOrNullLiteral(x._1))
|
||||
if (newBranches.isEmpty) {
|
||||
elseValue.getOrElse(Literal.create(null, e.dataType))
|
||||
} else {
|
||||
e.copy(branches = newBranches)
|
||||
}
|
||||
|
||||
case e @ CaseWhen(branches, _) if branches.headOption.map(_._1) == Some(TrueLiteral) =>
|
||||
// If the first branch is a true literal, remove the entire CaseWhen and use the value
|
||||
// from that. Note that CaseWhen.branches should never be empty, and as a result the
|
||||
// headOption (rather than head) added above is just an extra (and unnecessary) safeguard.
|
||||
branches.head._2
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Optimizes expressions by replacing according to CodeGen configuration.
|
||||
*/
|
||||
case class OptimizeCodegen(conf: CatalystConf) extends Rule[LogicalPlan] {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
|
||||
case e: CaseWhen if canCodegen(e) => e.toCodegen()
|
||||
}
|
||||
|
||||
private def canCodegen(e: CaseWhen): Boolean = {
|
||||
val numBranches = e.branches.size + e.elseValue.size
|
||||
numBranches <= conf.maxCaseBranchesForCodegen
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Combines all adjacent [[Union]] operators into a single [[Union]].
|
||||
*/
|
||||
|
@ -1026,7 +601,7 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper {
|
|||
/**
|
||||
* Removes no-op SortOrder from Sort
|
||||
*/
|
||||
object EliminateSorts extends Rule[LogicalPlan] {
|
||||
object EliminateSorts extends Rule[LogicalPlan] {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
|
||||
case s @ Sort(orders, _, child) if orders.isEmpty || orders.exists(_.child.foldable) =>
|
||||
val newOrders = orders.filterNot(_.child.foldable)
|
||||
|
@ -1448,25 +1023,6 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Removes [[Cast Casts]] that are unnecessary because the input is already the correct type.
|
||||
*/
|
||||
object SimplifyCasts extends Rule[LogicalPlan] {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
|
||||
case Cast(e, dataType) if e.dataType == dataType => e
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Removes nodes that are not necessary.
|
||||
*/
|
||||
object RemoveDispensableExpressions extends Rule[LogicalPlan] {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
|
||||
case UnaryPositive(child) => child
|
||||
case PromotePrecision(child) => child
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Combines two adjacent [[Limit]] operators into one, merging the
|
||||
* expressions into one single expression.
|
||||
|
@ -1482,21 +1038,6 @@ object CombineLimits extends Rule[LogicalPlan] {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Removes the inner case conversion expressions that are unnecessary because
|
||||
* the inner conversion is overwritten by the outer one.
|
||||
*/
|
||||
object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
|
||||
case q: LogicalPlan => q transformExpressionsUp {
|
||||
case Upper(Upper(child)) => Upper(child)
|
||||
case Upper(Lower(child)) => Upper(child)
|
||||
case Lower(Upper(child)) => Lower(child)
|
||||
case Lower(Lower(child)) => Lower(child)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Speeds up aggregates on fixed-precision decimals by executing them on unscaled Long values.
|
||||
*
|
||||
|
|
|
@ -0,0 +1,506 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.catalyst.optimizer
|
||||
|
||||
import scala.collection.immutable.HashSet
|
||||
|
||||
import org.apache.spark.sql.catalyst.CatalystConf
|
||||
import org.apache.spark.sql.catalyst.analysis._
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.aggregate._
|
||||
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
|
||||
import org.apache.spark.sql.catalyst.plans._
|
||||
import org.apache.spark.sql.catalyst.plans.logical._
|
||||
import org.apache.spark.sql.catalyst.rules._
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
/*
|
||||
* Optimization rules defined in this file should not affect the structure of the logical plan.
|
||||
*/
|
||||
|
||||
|
||||
/**
|
||||
* Replaces [[Expression Expressions]] that can be statically evaluated with
|
||||
* equivalent [[Literal]] values.
|
||||
*/
|
||||
object ConstantFolding extends Rule[LogicalPlan] {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
|
||||
case q: LogicalPlan => q transformExpressionsDown {
|
||||
// Skip redundant folding of literals. This rule is technically not necessary. Placing this
|
||||
// here avoids running the next rule for Literal values, which would create a new Literal
|
||||
// object and running eval unnecessarily.
|
||||
case l: Literal => l
|
||||
|
||||
// Fold expressions that are foldable.
|
||||
case e if e.foldable => Literal.create(e.eval(EmptyRow), e.dataType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Reorder associative integral-type operators and fold all constants into one.
|
||||
*/
|
||||
object ReorderAssociativeOperator extends Rule[LogicalPlan] {
|
||||
private def flattenAdd(e: Expression): Seq[Expression] = e match {
|
||||
case Add(l, r) => flattenAdd(l) ++ flattenAdd(r)
|
||||
case other => other :: Nil
|
||||
}
|
||||
|
||||
private def flattenMultiply(e: Expression): Seq[Expression] = e match {
|
||||
case Multiply(l, r) => flattenMultiply(l) ++ flattenMultiply(r)
|
||||
case other => other :: Nil
|
||||
}
|
||||
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
|
||||
case q: LogicalPlan => q transformExpressionsDown {
|
||||
case a: Add if a.deterministic && a.dataType.isInstanceOf[IntegralType] =>
|
||||
val (foldables, others) = flattenAdd(a).partition(_.foldable)
|
||||
if (foldables.size > 1) {
|
||||
val foldableExpr = foldables.reduce((x, y) => Add(x, y))
|
||||
val c = Literal.create(foldableExpr.eval(EmptyRow), a.dataType)
|
||||
if (others.isEmpty) c else Add(others.reduce((x, y) => Add(x, y)), c)
|
||||
} else {
|
||||
a
|
||||
}
|
||||
case m: Multiply if m.deterministic && m.dataType.isInstanceOf[IntegralType] =>
|
||||
val (foldables, others) = flattenMultiply(m).partition(_.foldable)
|
||||
if (foldables.size > 1) {
|
||||
val foldableExpr = foldables.reduce((x, y) => Multiply(x, y))
|
||||
val c = Literal.create(foldableExpr.eval(EmptyRow), m.dataType)
|
||||
if (others.isEmpty) c else Multiply(others.reduce((x, y) => Multiply(x, y)), c)
|
||||
} else {
|
||||
m
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Optimize IN predicates:
|
||||
* 1. Removes literal repetitions.
|
||||
* 2. Replaces [[In (value, seq[Literal])]] with optimized version
|
||||
* [[InSet (value, HashSet[Literal])]] which is much faster.
|
||||
*/
|
||||
case class OptimizeIn(conf: CatalystConf) extends Rule[LogicalPlan] {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
|
||||
case q: LogicalPlan => q transformExpressionsDown {
|
||||
case expr @ In(v, list) if expr.inSetConvertible =>
|
||||
val newList = ExpressionSet(list).toSeq
|
||||
if (newList.size > conf.optimizerInSetConversionThreshold) {
|
||||
val hSet = newList.map(e => e.eval(EmptyRow))
|
||||
InSet(v, HashSet() ++ hSet)
|
||||
} else if (newList.size < list.size) {
|
||||
expr.copy(list = newList)
|
||||
} else { // newList.length == list.length
|
||||
expr
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Simplifies boolean expressions:
|
||||
* 1. Simplifies expressions whose answer can be determined without evaluating both sides.
|
||||
* 2. Eliminates / extracts common factors.
|
||||
* 3. Merge same expressions
|
||||
* 4. Removes `Not` operator.
|
||||
*/
|
||||
object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
|
||||
case q: LogicalPlan => q transformExpressionsUp {
|
||||
case TrueLiteral And e => e
|
||||
case e And TrueLiteral => e
|
||||
case FalseLiteral Or e => e
|
||||
case e Or FalseLiteral => e
|
||||
|
||||
case FalseLiteral And _ => FalseLiteral
|
||||
case _ And FalseLiteral => FalseLiteral
|
||||
case TrueLiteral Or _ => TrueLiteral
|
||||
case _ Or TrueLiteral => TrueLiteral
|
||||
|
||||
case a And b if a.semanticEquals(b) => a
|
||||
case a Or b if a.semanticEquals(b) => a
|
||||
|
||||
case a And (b Or c) if Not(a).semanticEquals(b) => And(a, c)
|
||||
case a And (b Or c) if Not(a).semanticEquals(c) => And(a, b)
|
||||
case (a Or b) And c if a.semanticEquals(Not(c)) => And(b, c)
|
||||
case (a Or b) And c if b.semanticEquals(Not(c)) => And(a, c)
|
||||
|
||||
case a Or (b And c) if Not(a).semanticEquals(b) => Or(a, c)
|
||||
case a Or (b And c) if Not(a).semanticEquals(c) => Or(a, b)
|
||||
case (a And b) Or c if a.semanticEquals(Not(c)) => Or(b, c)
|
||||
case (a And b) Or c if b.semanticEquals(Not(c)) => Or(a, c)
|
||||
|
||||
// Common factor elimination for conjunction
|
||||
case and @ (left And right) =>
|
||||
// 1. Split left and right to get the disjunctive predicates,
|
||||
// i.e. lhs = (a, b), rhs = (a, c)
|
||||
// 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a)
|
||||
// 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c)
|
||||
// 4. Apply the formula, get the optimized predicate: common || (ldiff && rdiff)
|
||||
val lhs = splitDisjunctivePredicates(left)
|
||||
val rhs = splitDisjunctivePredicates(right)
|
||||
val common = lhs.filter(e => rhs.exists(e.semanticEquals))
|
||||
if (common.isEmpty) {
|
||||
// No common factors, return the original predicate
|
||||
and
|
||||
} else {
|
||||
val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals))
|
||||
val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals))
|
||||
if (ldiff.isEmpty || rdiff.isEmpty) {
|
||||
// (a || b || c || ...) && (a || b) => (a || b)
|
||||
common.reduce(Or)
|
||||
} else {
|
||||
// (a || b || c || ...) && (a || b || d || ...) =>
|
||||
// ((c || ...) && (d || ...)) || a || b
|
||||
(common :+ And(ldiff.reduce(Or), rdiff.reduce(Or))).reduce(Or)
|
||||
}
|
||||
}
|
||||
|
||||
// Common factor elimination for disjunction
|
||||
case or @ (left Or right) =>
|
||||
// 1. Split left and right to get the conjunctive predicates,
|
||||
// i.e. lhs = (a, b), rhs = (a, c)
|
||||
// 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a)
|
||||
// 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c)
|
||||
// 4. Apply the formula, get the optimized predicate: common && (ldiff || rdiff)
|
||||
val lhs = splitConjunctivePredicates(left)
|
||||
val rhs = splitConjunctivePredicates(right)
|
||||
val common = lhs.filter(e => rhs.exists(e.semanticEquals))
|
||||
if (common.isEmpty) {
|
||||
// No common factors, return the original predicate
|
||||
or
|
||||
} else {
|
||||
val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals))
|
||||
val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals))
|
||||
if (ldiff.isEmpty || rdiff.isEmpty) {
|
||||
// (a && b) || (a && b && c && ...) => a && b
|
||||
common.reduce(And)
|
||||
} else {
|
||||
// (a && b && c && ...) || (a && b && d && ...) =>
|
||||
// ((c && ...) || (d && ...)) && a && b
|
||||
(common :+ Or(ldiff.reduce(And), rdiff.reduce(And))).reduce(And)
|
||||
}
|
||||
}
|
||||
|
||||
case Not(TrueLiteral) => FalseLiteral
|
||||
case Not(FalseLiteral) => TrueLiteral
|
||||
|
||||
case Not(a GreaterThan b) => LessThanOrEqual(a, b)
|
||||
case Not(a GreaterThanOrEqual b) => LessThan(a, b)
|
||||
|
||||
case Not(a LessThan b) => GreaterThanOrEqual(a, b)
|
||||
case Not(a LessThanOrEqual b) => GreaterThan(a, b)
|
||||
|
||||
case Not(a Or b) => And(Not(a), Not(b))
|
||||
case Not(a And b) => Or(Not(a), Not(b))
|
||||
|
||||
case Not(Not(e)) => e
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Simplifies binary comparisons with semantically-equal expressions:
|
||||
* 1) Replace '<=>' with 'true' literal.
|
||||
* 2) Replace '=', '<=', and '>=' with 'true' literal if both operands are non-nullable.
|
||||
* 3) Replace '<' and '>' with 'false' literal if both operands are non-nullable.
|
||||
*/
|
||||
object SimplifyBinaryComparison extends Rule[LogicalPlan] with PredicateHelper {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
|
||||
case q: LogicalPlan => q transformExpressionsUp {
|
||||
// True with equality
|
||||
case a EqualNullSafe b if a.semanticEquals(b) => TrueLiteral
|
||||
case a EqualTo b if !a.nullable && !b.nullable && a.semanticEquals(b) => TrueLiteral
|
||||
case a GreaterThanOrEqual b if !a.nullable && !b.nullable && a.semanticEquals(b) =>
|
||||
TrueLiteral
|
||||
case a LessThanOrEqual b if !a.nullable && !b.nullable && a.semanticEquals(b) => TrueLiteral
|
||||
|
||||
// False with inequality
|
||||
case a GreaterThan b if !a.nullable && !b.nullable && a.semanticEquals(b) => FalseLiteral
|
||||
case a LessThan b if !a.nullable && !b.nullable && a.semanticEquals(b) => FalseLiteral
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Simplifies conditional expressions (if / case).
|
||||
*/
|
||||
object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
|
||||
private def falseOrNullLiteral(e: Expression): Boolean = e match {
|
||||
case FalseLiteral => true
|
||||
case Literal(null, _) => true
|
||||
case _ => false
|
||||
}
|
||||
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
|
||||
case q: LogicalPlan => q transformExpressionsUp {
|
||||
case If(TrueLiteral, trueValue, _) => trueValue
|
||||
case If(FalseLiteral, _, falseValue) => falseValue
|
||||
case If(Literal(null, _), _, falseValue) => falseValue
|
||||
|
||||
case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) =>
|
||||
// If there are branches that are always false, remove them.
|
||||
// If there are no more branches left, just use the else value.
|
||||
// Note that these two are handled together here in a single case statement because
|
||||
// otherwise we cannot determine the data type for the elseValue if it is None (i.e. null).
|
||||
val newBranches = branches.filter(x => !falseOrNullLiteral(x._1))
|
||||
if (newBranches.isEmpty) {
|
||||
elseValue.getOrElse(Literal.create(null, e.dataType))
|
||||
} else {
|
||||
e.copy(branches = newBranches)
|
||||
}
|
||||
|
||||
case e @ CaseWhen(branches, _) if branches.headOption.map(_._1) == Some(TrueLiteral) =>
|
||||
// If the first branch is a true literal, remove the entire CaseWhen and use the value
|
||||
// from that. Note that CaseWhen.branches should never be empty, and as a result the
|
||||
// headOption (rather than head) added above is just an extra (and unnecessary) safeguard.
|
||||
branches.head._2
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Simplifies LIKE expressions that do not need full regular expressions to evaluate the condition.
|
||||
* For example, when the expression is just checking to see if a string starts with a given
|
||||
* pattern.
|
||||
*/
|
||||
object LikeSimplification extends Rule[LogicalPlan] {
|
||||
// if guards below protect from escapes on trailing %.
|
||||
// Cases like "something\%" are not optimized, but this does not affect correctness.
|
||||
private val startsWith = "([^_%]+)%".r
|
||||
private val endsWith = "%([^_%]+)".r
|
||||
private val startsAndEndsWith = "([^_%]+)%([^_%]+)".r
|
||||
private val contains = "%([^_%]+)%".r
|
||||
private val equalTo = "([^_%]*)".r
|
||||
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
|
||||
case Like(input, Literal(pattern, StringType)) =>
|
||||
pattern.toString match {
|
||||
case startsWith(prefix) if !prefix.endsWith("\\") =>
|
||||
StartsWith(input, Literal(prefix))
|
||||
case endsWith(postfix) =>
|
||||
EndsWith(input, Literal(postfix))
|
||||
// 'a%a' pattern is basically same with 'a%' && '%a'.
|
||||
// However, the additional `Length` condition is required to prevent 'a' match 'a%a'.
|
||||
case startsAndEndsWith(prefix, postfix) if !prefix.endsWith("\\") =>
|
||||
And(GreaterThanOrEqual(Length(input), Literal(prefix.size + postfix.size)),
|
||||
And(StartsWith(input, Literal(prefix)), EndsWith(input, Literal(postfix))))
|
||||
case contains(infix) if !infix.endsWith("\\") =>
|
||||
Contains(input, Literal(infix))
|
||||
case equalTo(str) =>
|
||||
EqualTo(input, Literal(str))
|
||||
case _ =>
|
||||
Like(input, Literal.create(pattern, StringType))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Replaces [[Expression Expressions]] that can be statically evaluated with
|
||||
* equivalent [[Literal]] values. This rule is more specific with
|
||||
* Null value propagation from bottom to top of the expression tree.
|
||||
*/
|
||||
object NullPropagation extends Rule[LogicalPlan] {
|
||||
private def nonNullLiteral(e: Expression): Boolean = e match {
|
||||
case Literal(null, _) => false
|
||||
case _ => true
|
||||
}
|
||||
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
|
||||
case q: LogicalPlan => q transformExpressionsUp {
|
||||
case e @ WindowExpression(Cast(Literal(0L, _), _), _) =>
|
||||
Cast(Literal(0L), e.dataType)
|
||||
case e @ AggregateExpression(Count(exprs), _, _, _) if !exprs.exists(nonNullLiteral) =>
|
||||
Cast(Literal(0L), e.dataType)
|
||||
case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
|
||||
case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)
|
||||
case e @ GetArrayItem(Literal(null, _), _) => Literal.create(null, e.dataType)
|
||||
case e @ GetArrayItem(_, Literal(null, _)) => Literal.create(null, e.dataType)
|
||||
case e @ GetMapValue(Literal(null, _), _) => Literal.create(null, e.dataType)
|
||||
case e @ GetMapValue(_, Literal(null, _)) => Literal.create(null, e.dataType)
|
||||
case e @ GetStructField(Literal(null, _), _, _) => Literal.create(null, e.dataType)
|
||||
case e @ GetArrayStructFields(Literal(null, _), _, _, _, _) =>
|
||||
Literal.create(null, e.dataType)
|
||||
case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
|
||||
case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
|
||||
case ae @ AggregateExpression(Count(exprs), _, false, _) if !exprs.exists(_.nullable) =>
|
||||
// This rule should be only triggered when isDistinct field is false.
|
||||
ae.copy(aggregateFunction = Count(Literal(1)))
|
||||
|
||||
// For Coalesce, remove null literals.
|
||||
case e @ Coalesce(children) =>
|
||||
val newChildren = children.filter(nonNullLiteral)
|
||||
if (newChildren.isEmpty) {
|
||||
Literal.create(null, e.dataType)
|
||||
} else if (newChildren.length == 1) {
|
||||
newChildren.head
|
||||
} else {
|
||||
Coalesce(newChildren)
|
||||
}
|
||||
|
||||
case e @ Substring(Literal(null, _), _, _) => Literal.create(null, e.dataType)
|
||||
case e @ Substring(_, Literal(null, _), _) => Literal.create(null, e.dataType)
|
||||
case e @ Substring(_, _, Literal(null, _)) => Literal.create(null, e.dataType)
|
||||
|
||||
// Put exceptional cases above if any
|
||||
case e @ BinaryArithmetic(Literal(null, _), _) => Literal.create(null, e.dataType)
|
||||
case e @ BinaryArithmetic(_, Literal(null, _)) => Literal.create(null, e.dataType)
|
||||
|
||||
case e @ BinaryComparison(Literal(null, _), _) => Literal.create(null, e.dataType)
|
||||
case e @ BinaryComparison(_, Literal(null, _)) => Literal.create(null, e.dataType)
|
||||
|
||||
case e: StringRegexExpression => e.children match {
|
||||
case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType)
|
||||
case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType)
|
||||
case _ => e
|
||||
}
|
||||
|
||||
case e: StringPredicate => e.children match {
|
||||
case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType)
|
||||
case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType)
|
||||
case _ => e
|
||||
}
|
||||
|
||||
// If the value expression is NULL then transform the In expression to
|
||||
// Literal(null)
|
||||
case In(Literal(null, _), list) => Literal.create(null, BooleanType)
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Propagate foldable expressions:
|
||||
* Replace attributes with aliases of the original foldable expressions if possible.
|
||||
* Other optimizations will take advantage of the propagated foldable expressions.
|
||||
*
|
||||
* {{{
|
||||
* SELECT 1.0 x, 'abc' y, Now() z ORDER BY x, y, 3
|
||||
* ==> SELECT 1.0 x, 'abc' y, Now() z ORDER BY 1.0, 'abc', Now()
|
||||
* }}}
|
||||
*/
|
||||
object FoldablePropagation extends Rule[LogicalPlan] {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = {
|
||||
val foldableMap = AttributeMap(plan.flatMap {
|
||||
case Project(projectList, _) => projectList.collect {
|
||||
case a: Alias if a.child.foldable => (a.toAttribute, a)
|
||||
}
|
||||
case _ => Nil
|
||||
})
|
||||
|
||||
if (foldableMap.isEmpty) {
|
||||
plan
|
||||
} else {
|
||||
var stop = false
|
||||
CleanupAliases(plan.transformUp {
|
||||
case u: Union =>
|
||||
stop = true
|
||||
u
|
||||
case c: Command =>
|
||||
stop = true
|
||||
c
|
||||
// For outer join, although its output attributes are derived from its children, they are
|
||||
// actually different attributes: the output of outer join is not always picked from its
|
||||
// children, but can also be null.
|
||||
// TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes
|
||||
// of outer join.
|
||||
case j @ Join(_, _, LeftOuter | RightOuter | FullOuter, _) =>
|
||||
stop = true
|
||||
j
|
||||
|
||||
// These 3 operators take attributes as constructor parameters, and these attributes
|
||||
// can't be replaced by alias.
|
||||
case m: MapGroups =>
|
||||
stop = true
|
||||
m
|
||||
case f: FlatMapGroupsInR =>
|
||||
stop = true
|
||||
f
|
||||
case c: CoGroup =>
|
||||
stop = true
|
||||
c
|
||||
|
||||
case p: LogicalPlan if !stop => p.transformExpressions {
|
||||
case a: AttributeReference if foldableMap.contains(a) =>
|
||||
foldableMap(a)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Optimizes expressions by replacing according to CodeGen configuration.
|
||||
*/
|
||||
case class OptimizeCodegen(conf: CatalystConf) extends Rule[LogicalPlan] {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
|
||||
case e: CaseWhen if canCodegen(e) => e.toCodegen()
|
||||
}
|
||||
|
||||
private def canCodegen(e: CaseWhen): Boolean = {
|
||||
val numBranches = e.branches.size + e.elseValue.size
|
||||
numBranches <= conf.maxCaseBranchesForCodegen
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Removes [[Cast Casts]] that are unnecessary because the input is already the correct type.
|
||||
*/
|
||||
object SimplifyCasts extends Rule[LogicalPlan] {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
|
||||
case Cast(e, dataType) if e.dataType == dataType => e
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Removes nodes that are not necessary.
|
||||
*/
|
||||
object RemoveDispensableExpressions extends Rule[LogicalPlan] {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
|
||||
case UnaryPositive(child) => child
|
||||
case PromotePrecision(child) => child
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Removes the inner case conversion expressions that are unnecessary because
|
||||
* the inner conversion is overwritten by the outer one.
|
||||
*/
|
||||
object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
|
||||
case q: LogicalPlan => q transformExpressionsUp {
|
||||
case Upper(Upper(child)) => Upper(child)
|
||||
case Upper(Lower(child)) => Upper(child)
|
||||
case Lower(Upper(child)) => Lower(child)
|
||||
case Lower(Lower(child)) => Lower(child)
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue