[SPARK-33847][SQL][FOLLOWUP] Removing the CaseWhen should take determinism into account

### What changes were proposed in this pull request?

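This PR fixes the optimization that removes a `CaseWhen` when its `elseValue` is empty and all other outputs are null: the optimization must take determinism into account, so a `CaseWhen` whose branch conditions are non-deterministic is no longer removed.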

### Why are the changes needed?

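Bug fix: non-deterministic conditions with side effects cannot be removed or reordered, so replacing such a `CaseWhen` with a null literal is incorrect.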

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Updated the existing unit tests.

Closes #30960 from wangyum/SPARK-33847-2.

Authored-by: Yuming Wang <yumwang@ebay.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
Yuming Wang 2020-12-29 14:35:01 +00:00 committed by Wenchen Fan
parent 16c594de79
commit c42502493a
5 changed files with 23 additions and 28 deletions

View file

@ -98,12 +98,8 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] {
val newBranches = cw.branches.map { case (cond, value) =>
replaceNullWithFalse(cond) -> replaceNullWithFalse(value)
}
if (newBranches.forall(_._2 == FalseLiteral) && cw.elseValue.isEmpty) {
FalseLiteral
} else {
val newElseValue = cw.elseValue.map(replaceNullWithFalse)
CaseWhen(newBranches, newElseValue)
}
val newElseValue = cw.elseValue.map(replaceNullWithFalse).getOrElse(FalseLiteral)
CaseWhen(newBranches, newElseValue)
case i @ If(pred, trueVal, falseVal) if i.dataType == BooleanType =>
If(replaceNullWithFalse(pred), replaceNullWithFalse(trueVal), replaceNullWithFalse(falseVal))
case e if e.dataType == BooleanType =>

View file

@ -515,8 +515,9 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
val (h, t) = branches.span(_._1 != TrueLiteral)
CaseWhen( h :+ t.head, None)
case e @ CaseWhen(branches, Some(elseValue))
if branches.forall(_._2.semanticEquals(elseValue)) =>
case e @ CaseWhen(branches, elseOpt)
if branches.forall(_._2.semanticEquals(elseOpt.getOrElse(Literal(null, e.dataType)))) =>
val elseValue = elseOpt.getOrElse(Literal(null, e.dataType))
// For non-deterministic conditions with side effect, we can not remove it, or change
// the ordering. As a result, we try to remove the deterministic conditions from the tail.
var hitNonDeterministicCond = false
@ -532,10 +533,6 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
} else {
e.copy(branches = branches.take(i).map(branch => (branch._1, elseValue)))
}
case e @ CaseWhen(branches, None)
if branches.forall(_._2.semanticEquals(Literal(null, e.dataType))) =>
Literal(null, e.dataType)
}
}
}

View file

@ -260,14 +260,13 @@ class PushFoldableIntoBranchesSuite
}
test("SPARK-33847: Remove the CaseWhen if elseValue is empty and other outputs are null") {
Seq(a, LessThan(Rand(1), Literal(0.5))).foreach { condition =>
assertEquivalent(
EqualTo(CaseWhen(Seq((condition, Literal.create(null, IntegerType)))), Literal(2)),
Literal.create(null, BooleanType))
assertEquivalent(
EqualTo(CaseWhen(Seq((condition, Literal("str")))).cast(IntegerType), Literal(2)),
Literal.create(null, BooleanType))
}
assertEquivalent(
EqualTo(CaseWhen(Seq((a, Literal.create(null, IntegerType)))), Literal(2)),
Literal.create(null, BooleanType))
assertEquivalent(
EqualTo(CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), Literal("str")))).cast(IntegerType),
Literal(2)),
CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), Literal.create(null, BooleanType)))))
}
test("SPARK-33884: simplify CaseWhen clauses with (true and false) and (false and true)") {

View file

@ -114,7 +114,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
val expectedBranches = Seq(
(UnresolvedAttribute("i") < Literal(10)) -> FalseLiteral,
(UnresolvedAttribute("i") > Literal(40)) -> TrueLiteral)
val expectedCond = CaseWhen(expectedBranches)
val expectedCond = CaseWhen(expectedBranches, FalseLiteral)
testFilter(originalCond, expectedCond)
testJoin(originalCond, expectedCond)
@ -135,7 +135,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
(UnresolvedAttribute("i") < Literal(10)) -> TrueLiteral,
(UnresolvedAttribute("i") > Literal(10)) -> FalseLiteral,
TrueLiteral -> TrueLiteral)
val expectedCond = CaseWhen(expectedBranches)
val expectedCond = CaseWhen(expectedBranches, FalseLiteral)
testFilter(originalCond, expectedCond)
testJoin(originalCond, expectedCond)
@ -238,7 +238,8 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
FalseLiteral)
val condition = CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue))
val expectedCond = CaseWhen(Seq(
(UnresolvedAttribute("i") > Literal(10), (Literal(2) === nestedCaseWhen) <=> TrueLiteral)))
(UnresolvedAttribute("i") > Literal(10), (Literal(2) === nestedCaseWhen) <=> TrueLiteral)),
FalseLiteral)
testFilter(originalCond = condition, expectedCond = expectedCond)
testJoin(originalCond = condition, expectedCond = expectedCond)
testDelete(originalCond = condition, expectedCond = expectedCond)

View file

@ -237,11 +237,13 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P
}
test("SPARK-33847: Remove the CaseWhen if elseValue is empty and other outputs are null") {
Seq(GreaterThan('a, 1), GreaterThan(Rand(0), 1)).foreach { condition =>
assertEquivalent(
CaseWhen((condition, Literal.create(null, IntegerType)) :: Nil, None),
Literal.create(null, IntegerType))
}
assertEquivalent(
CaseWhen((GreaterThan('a, 1), Literal.create(null, IntegerType)) :: Nil, None),
Literal.create(null, IntegerType))
assertEquivalent(
CaseWhen((GreaterThan(Rand(0), 0.5), Literal.create(null, IntegerType)) :: Nil, None),
CaseWhen((GreaterThan(Rand(0), 0.5), Literal.create(null, IntegerType)) :: Nil, None))
}
test("SPARK-33884: simplify CaseWhen clauses with (true and false) and (false and true)") {