diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala index df3da3e8a9..2f95f242c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala @@ -98,12 +98,8 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] { val newBranches = cw.branches.map { case (cond, value) => replaceNullWithFalse(cond) -> replaceNullWithFalse(value) } - if (newBranches.forall(_._2 == FalseLiteral) && cw.elseValue.isEmpty) { - FalseLiteral - } else { - val newElseValue = cw.elseValue.map(replaceNullWithFalse) - CaseWhen(newBranches, newElseValue) - } + val newElseValue = cw.elseValue.map(replaceNullWithFalse).getOrElse(FalseLiteral) + CaseWhen(newBranches, newElseValue) case i @ If(pred, trueVal, falseVal) if i.dataType == BooleanType => If(replaceNullWithFalse(pred), replaceNullWithFalse(trueVal), replaceNullWithFalse(falseVal)) case e if e.dataType == BooleanType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 1b93d51496..819bffeafb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -515,8 +515,9 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { val (h, t) = branches.span(_._1 != TrueLiteral) CaseWhen( h :+ t.head, None) - case e @ CaseWhen(branches, Some(elseValue)) - if branches.forall(_._2.semanticEquals(elseValue)) => + case e @ CaseWhen(branches, elseOpt) + if branches.forall(_._2.semanticEquals(elseOpt.getOrElse(Literal(null, e.dataType)))) => + val elseValue = elseOpt.getOrElse(Literal(null, e.dataType)) // For non-deterministic conditions with side effect, we can not remove it, or change // the ordering. As a result, we try to remove the deterministic conditions from the tail. var hitNonDeterministicCond = false @@ -532,10 +533,6 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { } else { e.copy(branches = branches.take(i).map(branch => (branch._1, elseValue))) } - - case e @ CaseWhen(branches, None) - if branches.forall(_._2.semanticEquals(Literal(null, e.dataType))) => - Literal(null, e.dataType) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala index 0d5218ac62..cb90a39860 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala @@ -260,14 +260,13 @@ class PushFoldableIntoBranchesSuite } test("SPARK-33847: Remove the CaseWhen if elseValue is empty and other outputs are null") { - Seq(a, LessThan(Rand(1), Literal(0.5))).foreach { condition => - assertEquivalent( - EqualTo(CaseWhen(Seq((condition, Literal.create(null, IntegerType)))), Literal(2)), - Literal.create(null, BooleanType)) - assertEquivalent( - EqualTo(CaseWhen(Seq((condition, Literal("str")))).cast(IntegerType), Literal(2)), - Literal.create(null, BooleanType)) - } + assertEquivalent( + EqualTo(CaseWhen(Seq((a, Literal.create(null, IntegerType)))), Literal(2)), + Literal.create(null, BooleanType)) + assertEquivalent( + EqualTo(CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), Literal("str")))).cast(IntegerType), + Literal(2)), + CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), Literal.create(null, BooleanType))))) } test("SPARK-33884: simplify CaseWhen clauses with (true and false) and (false and true)") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala index ae97d53256..ffab358721 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala @@ -114,7 +114,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { val expectedBranches = Seq( (UnresolvedAttribute("i") < Literal(10)) -> FalseLiteral, (UnresolvedAttribute("i") > Literal(40)) -> TrueLiteral) - val expectedCond = CaseWhen(expectedBranches) + val expectedCond = CaseWhen(expectedBranches, FalseLiteral) testFilter(originalCond, expectedCond) testJoin(originalCond, expectedCond) @@ -135,7 +135,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { (UnresolvedAttribute("i") < Literal(10)) -> TrueLiteral, (UnresolvedAttribute("i") > Literal(10)) -> FalseLiteral, TrueLiteral -> TrueLiteral) - val expectedCond = CaseWhen(expectedBranches) + val expectedCond = CaseWhen(expectedBranches, FalseLiteral) testFilter(originalCond, expectedCond) testJoin(originalCond, expectedCond) @@ -238,7 +238,8 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { FalseLiteral) val condition = CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue)) val expectedCond = CaseWhen(Seq( - (UnresolvedAttribute("i") > Literal(10), (Literal(2) === nestedCaseWhen) <=> TrueLiteral))) + (UnresolvedAttribute("i") > Literal(10), (Literal(2) === nestedCaseWhen) <=> TrueLiteral)), + FalseLiteral) testFilter(originalCond = condition, expectedCond = expectedCond) testJoin(originalCond = condition, expectedCond = expectedCond) testDelete(originalCond = condition, expectedCond = expectedCond) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index f3edd70bcf..2a685bfeef 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -237,11 +237,13 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P } test("SPARK-33847: Remove the CaseWhen if elseValue is empty and other outputs are null") { - Seq(GreaterThan('a, 1), GreaterThan(Rand(0), 1)).foreach { condition => - assertEquivalent( - CaseWhen((condition, Literal.create(null, IntegerType)) :: Nil, None), - Literal.create(null, IntegerType)) - } + assertEquivalent( + CaseWhen((GreaterThan('a, 1), Literal.create(null, IntegerType)) :: Nil, None), + Literal.create(null, IntegerType)) + + assertEquivalent( + CaseWhen((GreaterThan(Rand(0), 0.5), Literal.create(null, IntegerType)) :: Nil, None), + CaseWhen((GreaterThan(Rand(0), 0.5), Literal.create(null, IntegerType)) :: Nil, None)) } test("SPARK-33884: simplify CaseWhen clauses with (true and false) and (false and true)") {