[SPARK-25860][SPARK-26107][FOLLOW-UP] Rule ReplaceNullWithFalseInPredicate

## What changes were proposed in this pull request?

Based on https://github.com/apache/spark/pull/22857 and https://github.com/apache/spark/pull/23079, this PR did a few updates

- Limit the data types of NULL to Boolean.
- Limit the input data type of replaceNullWithFalse to Boolean; throw an exception in the testing mode.
- Create a new file for the rule ReplaceNullWithFalseInPredicate
- Update the description of this rule.

## How was this patch tested?
Added a test case

Closes #23139 from gatorsmile/followupSpark-25860.

Authored-by: gatorsmile <gatorsmile@gmail.com>
Signed-off-by: DB Tsai <d_tsai@apple.com>
This commit is contained in:
gatorsmile 2018-11-27 04:51:32 +00:00 committed by DB Tsai
parent 1c487f7d14
commit 85383d29ed
No known key found for this signature in database
GPG key ID: E6FD79DA81FE14FD
3 changed files with 119 additions and 68 deletions

View file

@ -0,0 +1,110 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, CaseWhen, Expression, If}
import org.apache.spark.sql.catalyst.expressions.{LambdaFunction, Literal, MapFilter, Or}
import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral
import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.BooleanType
import org.apache.spark.util.Utils
/**
* A rule that replaces `Literal(null, BooleanType)` with `FalseLiteral`, if possible, in the search
* condition of the WHERE/HAVING/ON(JOIN) clauses, which contain an implicit Boolean operator
* "(search condition) = TRUE". The replacement is only valid when `Literal(null, BooleanType)` is
* semantically equivalent to `FalseLiteral` when evaluating the whole search condition.
*
* Please note that FALSE and NULL are not exchangeable in most cases, when the search condition
* contains NOT and NULL-tolerant expressions. Thus, the rule is very conservative and applicable
* in very limited cases.
*
* For example, `Filter(Literal(null, BooleanType))` is equal to `Filter(FalseLiteral)`.
*
* Another example containing branches is `Filter(If(cond, FalseLiteral, Literal(null, _)))`;
* this can be optimized to `Filter(If(cond, FalseLiteral, FalseLiteral))`, and eventually
* `Filter(FalseLiteral)`.
*
* Moreover, this rule also transforms predicates in all [[If]] expressions as well as branch
* conditions in all [[CaseWhen]] expressions, even if they are not part of the search conditions.
*
* For example, `Project(If(And(cond, Literal(null)), Literal(1), Literal(2)))` can be simplified
* into `Project(Literal(2))`.
*/
object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond))
case j @ Join(_, _, _, Some(cond)) => j.copy(condition = Some(replaceNullWithFalse(cond)))
case p: LogicalPlan => p transformExpressions {
case i @ If(pred, _, _) => i.copy(predicate = replaceNullWithFalse(pred))
case cw @ CaseWhen(branches, _) =>
val newBranches = branches.map { case (cond, value) =>
replaceNullWithFalse(cond) -> value
}
cw.copy(branches = newBranches)
case af @ ArrayFilter(_, lf @ LambdaFunction(func, _, _)) =>
val newLambda = lf.copy(function = replaceNullWithFalse(func))
af.copy(function = newLambda)
case ae @ ArrayExists(_, lf @ LambdaFunction(func, _, _)) =>
val newLambda = lf.copy(function = replaceNullWithFalse(func))
ae.copy(function = newLambda)
case mf @ MapFilter(_, lf @ LambdaFunction(func, _, _)) =>
val newLambda = lf.copy(function = replaceNullWithFalse(func))
mf.copy(function = newLambda)
}
}
/**
* Recursively traverse the Boolean-type expression to replace
* `Literal(null, BooleanType)` with `FalseLiteral`, if possible.
*
* Note that `transformExpressionsDown` can not be used here as we must stop as soon as we hit
* an expression that is not [[CaseWhen]], [[If]], [[And]], [[Or]] or
* `Literal(null, BooleanType)`.
*/
private def replaceNullWithFalse(e: Expression): Expression = e match {
case Literal(null, BooleanType) =>
FalseLiteral
case And(left, right) =>
And(replaceNullWithFalse(left), replaceNullWithFalse(right))
case Or(left, right) =>
Or(replaceNullWithFalse(left), replaceNullWithFalse(right))
case cw: CaseWhen if cw.dataType == BooleanType =>
val newBranches = cw.branches.map { case (cond, value) =>
replaceNullWithFalse(cond) -> replaceNullWithFalse(value)
}
val newElseValue = cw.elseValue.map(replaceNullWithFalse)
CaseWhen(newBranches, newElseValue)
case i @ If(pred, trueVal, falseVal) if i.dataType == BooleanType =>
If(replaceNullWithFalse(pred), replaceNullWithFalse(trueVal), replaceNullWithFalse(falseVal))
case e if e.dataType == BooleanType =>
e
case e =>
val message = "Expected a Boolean type expression in replaceNullWithFalse, " +
s"but got the type `${e.dataType.catalogString}` in `${e.sql}`."
if (Utils.isTesting) {
throw new IllegalArgumentException(message)
} else {
logWarning(message)
e
}
}
}

View file

@ -736,69 +736,3 @@ object CombineConcats extends Rule[LogicalPlan] {
flattenConcats(concat)
}
}
/**
* A rule that replaces `Literal(null, _)` with `FalseLiteral` for further optimizations.
*
* This rule applies to conditions in [[Filter]] and [[Join]]. Moreover, it transforms predicates
* in all [[If]] expressions as well as branch conditions in all [[CaseWhen]] expressions.
*
* For example, `Filter(Literal(null, _))` is equal to `Filter(FalseLiteral)`.
*
* Another example containing branches is `Filter(If(cond, FalseLiteral, Literal(null, _)))`;
* this can be optimized to `Filter(If(cond, FalseLiteral, FalseLiteral))`, and eventually
* `Filter(FalseLiteral)`.
*
* As this rule is not limited to conditions in [[Filter]] and [[Join]], arbitrary plans can
* benefit from it. For example, `Project(If(And(cond, Literal(null)), Literal(1), Literal(2)))`
* can be simplified into `Project(Literal(2))`.
*
* As a result, many unnecessary computations can be removed in the query optimization phase.
*/
object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond))
case j @ Join(_, _, _, Some(cond)) => j.copy(condition = Some(replaceNullWithFalse(cond)))
case p: LogicalPlan => p transformExpressions {
case i @ If(pred, _, _) => i.copy(predicate = replaceNullWithFalse(pred))
case cw @ CaseWhen(branches, _) =>
val newBranches = branches.map { case (cond, value) =>
replaceNullWithFalse(cond) -> value
}
cw.copy(branches = newBranches)
case af @ ArrayFilter(_, lf @ LambdaFunction(func, _, _)) =>
val newLambda = lf.copy(function = replaceNullWithFalse(func))
af.copy(function = newLambda)
case ae @ ArrayExists(_, lf @ LambdaFunction(func, _, _)) =>
val newLambda = lf.copy(function = replaceNullWithFalse(func))
ae.copy(function = newLambda)
case mf @ MapFilter(_, lf @ LambdaFunction(func, _, _)) =>
val newLambda = lf.copy(function = replaceNullWithFalse(func))
mf.copy(function = newLambda)
}
}
/**
* Recursively replaces `Literal(null, _)` with `FalseLiteral`.
*
* Note that `transformExpressionsDown` can not be used here as we must stop as soon as we hit
* an expression that is not [[CaseWhen]], [[If]], [[And]], [[Or]] or `Literal(null, _)`.
*/
private def replaceNullWithFalse(e: Expression): Expression = e match {
case cw: CaseWhen if cw.dataType == BooleanType =>
val newBranches = cw.branches.map { case (cond, value) =>
replaceNullWithFalse(cond) -> replaceNullWithFalse(value)
}
val newElseValue = cw.elseValue.map(replaceNullWithFalse)
CaseWhen(newBranches, newElseValue)
case i @ If(pred, trueVal, falseVal) if i.dataType == BooleanType =>
If(replaceNullWithFalse(pred), replaceNullWithFalse(trueVal), replaceNullWithFalse(falseVal))
case And(left, right) =>
And(replaceNullWithFalse(left), replaceNullWithFalse(right))
case Or(left, right) =>
Or(replaceNullWithFalse(left), replaceNullWithFalse(right))
case Literal(null, _) => FalseLiteral
case _ => e
}
}

View file

@ -44,8 +44,15 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
private val anotherTestRelation = LocalRelation('d.int)
test("replace null inside filter and join conditions") {
testFilter(originalCond = Literal(null), expectedCond = FalseLiteral)
testJoin(originalCond = Literal(null), expectedCond = FalseLiteral)
testFilter(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral)
testJoin(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral)
}
test("Not expected type - replaceNullWithFalse") {
val e = intercept[IllegalArgumentException] {
testFilter(originalCond = Literal(null, IntegerType), expectedCond = FalseLiteral)
}.getMessage
assert(e.contains("but got the type `int` in `CAST(NULL AS INT)"))
}
test("replace null in branches of If") {