[SPARK-24890][SQL] Short circuiting the if condition when trueValue and falseValue are the same

## What changes were proposed in this pull request?

When `trueValue` and `falseValue` are semantically equivalent and the condition is deterministic, the `if` expression can be replaced by `trueValue`, which avoids evaluating the condition at runtime.

## How was this patch tested?

Test added.

Author: DB Tsai <d_tsai@apple.com>

Closes #21848 from dbtsai/short-circuit-if.
This commit is contained in:
DB Tsai 2018-07-24 20:21:11 -07:00 committed by Xiao Li
parent c26b092169
commit d4c3415894
3 changed files with 29 additions and 4 deletions

View file

@ -390,6 +390,8 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
case If(TrueLiteral, trueValue, _) => trueValue
case If(FalseLiteral, _, falseValue) => falseValue
case If(Literal(null, _), _, falseValue) => falseValue
case If(cond, trueValue, falseValue)
if cond.deterministic && trueValue.semanticEquals(falseValue) => trueValue
case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) =>
// If there are branches that are always false, remove them.
@ -403,14 +405,14 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
e.copy(branches = newBranches)
}
case e @ CaseWhen(branches, _) if branches.headOption.map(_._1) == Some(TrueLiteral) =>
case CaseWhen(branches, _) if branches.headOption.map(_._1).contains(TrueLiteral) =>
// If the first branch is a true literal, remove the entire CaseWhen and use the value
// from that. Note that CaseWhen.branches should never be empty, and as a result the
// headOption (rather than head) added above is just an extra (and unnecessary) safeguard.
branches.head._2
case CaseWhen(branches, _) if branches.exists(_._1 == TrueLiteral) =>
// a branc with a TRue condition eliminates all following branches,
// a branch with a true condition eliminates all following branches,
// these branches can be pruned away
val (h, t) = branches.span(_._1 != TrueLiteral)
CaseWhen( h :+ t.head, None)
@ -651,6 +653,7 @@ object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] {
}
}
/**
* Combine nested [[Concat]] expressions.
*/

View file

@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
@ -29,7 +31,8 @@ import org.apache.spark.sql.types.{IntegerType, NullType}
class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches = Batch("SimplifyConditionals", FixedPoint(50), SimplifyConditionals) :: Nil
val batches = Batch("SimplifyConditionals", FixedPoint(50),
BooleanSimplification, ConstantFolding, SimplifyConditionals) :: Nil
}
protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
@ -43,6 +46,8 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
private val unreachableBranch = (FalseLiteral, Literal(20))
private val nullBranch = (Literal.create(null, NullType), Literal(30))
private val testRelation = LocalRelation('a.int)
test("simplify if") {
assertEquivalent(
If(TrueLiteral, Literal(10), Literal(20)),
@ -57,6 +62,23 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
Literal(20))
}
// Checks that an `If` whose two branches are semantically equal is replaced by the branch
// value when its condition is deterministic, and is left in place when it is not.
test("remove unnecessary if when the outputs are semantic equivalence") {
// Deterministic condition: `10 - 1` and `6 + 3` are both expected to reduce to 9, so the
// expected result is the bare literal with no `If` around it.
assertEquivalent(
If(IsNotNull(UnresolvedAttribute("a")),
Subtract(Literal(10), Literal(1)),
Add(Literal(6), Literal(3))),
Literal(9))
// For non-deterministic condition, we don't remove the `If` statement.
// The branches are still expected to reduce to Literal(9); only the `If` wrapper remains.
assertEquivalent(
If(GreaterThan(Rand(0), Literal(0.5)),
Subtract(Literal(10), Literal(1)),
Add(Literal(6), Literal(3))),
If(GreaterThan(Rand(0), Literal(0.5)),
Literal(9),
Literal(9)))
}
test("remove unreachable branches") {
// i.e. removing branches whose conditions are always false
assertEquivalent(

View file

@ -393,7 +393,7 @@ private[sql] trait SQLTestUtilsBase
}
/**
* Returns full path to the given file in the resouce folder
* Returns full path to the given file in the resource folder
*/
protected def testFile(fileName: String): String = {
Thread.currentThread().getContextClassLoader.getResource(fileName).toString