[SPARK-24890][SQL] Short circuiting the `if` condition when `trueValue` and `falseValue` are the same
## What changes were proposed in this pull request?

When `trueValue` and `falseValue` are semantically equivalent, the condition expression in `if` can be removed (as long as the condition is deterministic) to avoid extra computation at runtime.

## How was this patch tested?

Test added.

Author: DB Tsai <d_tsai@apple.com>

Closes #21848 from dbtsai/short-circuit-if.
This commit is contained in:
parent
c26b092169
commit
d4c3415894
|
@ -390,6 +390,8 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
|
|||
case If(TrueLiteral, trueValue, _) => trueValue
|
||||
case If(FalseLiteral, _, falseValue) => falseValue
|
||||
case If(Literal(null, _), _, falseValue) => falseValue
|
||||
case If(cond, trueValue, falseValue)
|
||||
if cond.deterministic && trueValue.semanticEquals(falseValue) => trueValue
|
||||
|
||||
case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) =>
|
||||
// If there are branches that are always false, remove them.
|
||||
|
@ -403,14 +405,14 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
|
|||
e.copy(branches = newBranches)
|
||||
}
|
||||
|
||||
case e @ CaseWhen(branches, _) if branches.headOption.map(_._1) == Some(TrueLiteral) =>
|
||||
case CaseWhen(branches, _) if branches.headOption.map(_._1).contains(TrueLiteral) =>
|
||||
// If the first branch is a true literal, remove the entire CaseWhen and use the value
|
||||
// from that. Note that CaseWhen.branches should never be empty, and as a result the
|
||||
// headOption (rather than head) added above is just an extra (and unnecessary) safeguard.
|
||||
branches.head._2
|
||||
|
||||
case CaseWhen(branches, _) if branches.exists(_._1 == TrueLiteral) =>
|
||||
// a branc with a TRue condition eliminates all following branches,
|
||||
// a branch with a true condition eliminates all following branches,
|
||||
// these branches can be pruned away
|
||||
val (h, t) = branches.span(_._1 != TrueLiteral)
|
||||
CaseWhen( h :+ t.head, None)
|
||||
|
@ -651,6 +653,7 @@ object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Combine nested [[Concat]] expressions.
|
||||
*/
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
|
||||
package org.apache.spark.sql.catalyst.optimizer
|
||||
|
||||
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
|
||||
import org.apache.spark.sql.catalyst.dsl.expressions._
|
||||
import org.apache.spark.sql.catalyst.dsl.plans._
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
|
||||
|
@ -29,7 +31,8 @@ import org.apache.spark.sql.types.{IntegerType, NullType}
|
|||
class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
|
||||
|
||||
object Optimize extends RuleExecutor[LogicalPlan] {
|
||||
val batches = Batch("SimplifyConditionals", FixedPoint(50), SimplifyConditionals) :: Nil
|
||||
val batches = Batch("SimplifyConditionals", FixedPoint(50),
|
||||
BooleanSimplification, ConstantFolding, SimplifyConditionals) :: Nil
|
||||
}
|
||||
|
||||
protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
|
||||
|
@ -43,6 +46,8 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
|
|||
private val unreachableBranch = (FalseLiteral, Literal(20))
|
||||
private val nullBranch = (Literal.create(null, NullType), Literal(30))
|
||||
|
||||
private val testRelation = LocalRelation('a.int)
|
||||
|
||||
test("simplify if") {
|
||||
assertEquivalent(
|
||||
If(TrueLiteral, Literal(10), Literal(20)),
|
||||
|
@ -57,6 +62,23 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
|
|||
Literal(20))
|
||||
}
|
||||
|
||||
test("remove unnecessary if when the outputs are semantic equivalence") {
|
||||
assertEquivalent(
|
||||
If(IsNotNull(UnresolvedAttribute("a")),
|
||||
Subtract(Literal(10), Literal(1)),
|
||||
Add(Literal(6), Literal(3))),
|
||||
Literal(9))
|
||||
|
||||
// For non-deterministic condition, we don't remove the `If` statement.
|
||||
assertEquivalent(
|
||||
If(GreaterThan(Rand(0), Literal(0.5)),
|
||||
Subtract(Literal(10), Literal(1)),
|
||||
Add(Literal(6), Literal(3))),
|
||||
If(GreaterThan(Rand(0), Literal(0.5)),
|
||||
Literal(9),
|
||||
Literal(9)))
|
||||
}
|
||||
|
||||
test("remove unreachable branches") {
|
||||
// i.e. removing branches whose conditions are always false
|
||||
assertEquivalent(
|
||||
|
|
|
@ -393,7 +393,7 @@ private[sql] trait SQLTestUtilsBase
|
|||
}
|
||||
|
||||
/**
|
||||
* Returns full path to the given file in the resouce folder
|
||||
* Returns full path to the given file in the resource folder
|
||||
*/
|
||||
protected def testFile(fileName: String): String = {
|
||||
Thread.currentThread().getContextClassLoader.getResource(fileName).toString
|
||||
|
|
Loading…
Reference in a new issue