From 1ad343238cb82a79e51ae9d46ae704bc482ff437 Mon Sep 17 00:00:00 2001 From: Anton Okolnychyi Date: Wed, 17 Feb 2021 17:27:21 -0800 Subject: [PATCH] [SPARK-33736][SQL] Handle MERGE in ReplaceNullWithFalseInPredicate ### What changes were proposed in this pull request? This PR handles merge operations in `ReplaceNullWithFalseInPredicate`. ### Why are the changes needed? These changes are needed to match what we already do for delete and update operations. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? This PR extends existing tests to cover merge operations. Closes #31579 from aokolnychyi/spark-33736. Authored-by: Anton Okolnychyi Signed-off-by: Dongjoon Hyun --- .../ReplaceNullWithFalseInPredicate.scala | 16 +++++++- ...ReplaceNullWithFalseInPredicateSuite.scala | 37 ++++++++++++++++++- 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala index 2f95f242c8..327856956c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, CaseWhen, EqualNullSafe, Expression, If, LambdaFunction, Literal, MapFilter, Or} import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} -import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, Filter, Join, LogicalPlan, UpdateTable} +import org.apache.spark.sql.catalyst.plans.logical.{DeleteAction, DeleteFromTable, Filter, InsertAction, Join, LogicalPlan, MergeAction, MergeIntoTable, UpdateAction, UpdateTable} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types.BooleanType import org.apache.spark.util.Utils @@ -54,6 +54,11 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] { case j @ Join(_, _, _, Some(cond), _) => j.copy(condition = Some(replaceNullWithFalse(cond))) case d @ DeleteFromTable(_, Some(cond)) => d.copy(condition = Some(replaceNullWithFalse(cond))) case u @ UpdateTable(_, _, Some(cond)) => u.copy(condition = Some(replaceNullWithFalse(cond))) + case m @ MergeIntoTable(_, _, mergeCond, matchedActions, notMatchedActions) => + m.copy( + mergeCondition = replaceNullWithFalse(mergeCond), + matchedActions = replaceNullWithFalse(matchedActions), + notMatchedActions = replaceNullWithFalse(notMatchedActions)) case p: LogicalPlan => p transformExpressions { // For `EqualNullSafe` with a `TrueLiteral`, whether the other side is null or false has no // difference, as `null <=> true` and `false <=> true` both return false. @@ -114,4 +119,13 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] { e } } + + private def replaceNullWithFalse(mergeActions: Seq[MergeAction]): Seq[MergeAction] = { + mergeActions.map { + case u @ UpdateAction(Some(cond), _) => u.copy(condition = Some(replaceNullWithFalse(cond))) + case d @ DeleteAction(Some(cond)) => d.copy(condition = Some(replaceNullWithFalse(cond))) + case i @ InsertAction(Some(cond), _) => i.copy(condition = Some(replaceNullWithFalse(cond))) + case other => other + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala index ffab358721..5183cca1eb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable} import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} -import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, LocalRelation, LogicalPlan, UpdateTable} +import org.apache.spark.sql.catalyst.plans.logical.{Assignment, DeleteAction, DeleteFromTable, InsertAction, LocalRelation, LogicalPlan, MergeIntoTable, UpdateAction, UpdateTable} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{BooleanType, IntegerType} @@ -50,6 +50,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testJoin(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral) testDelete(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral) testUpdate(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral) + testMerge(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral) } test("Not expected type - replaceNullWithFalse") { @@ -68,6 +69,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testJoin(originalCond, expectedCond = FalseLiteral) testDelete(originalCond, expectedCond = FalseLiteral) testUpdate(originalCond, expectedCond = FalseLiteral) + testMerge(originalCond, expectedCond = FalseLiteral) } test("replace nulls in nested expressions in branches of If") { @@ -79,6 +81,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testJoin(originalCond, expectedCond = FalseLiteral) testDelete(originalCond, expectedCond = FalseLiteral) testUpdate(originalCond, expectedCond = FalseLiteral) + testMerge(originalCond, expectedCond = FalseLiteral) } test("replace null in elseValue of CaseWhen") { @@ -91,6 +94,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testJoin(originalCond, expectedCond) testDelete(originalCond, expectedCond) testUpdate(originalCond, expectedCond) + testMerge(originalCond, expectedCond) } test("replace null in branch values of CaseWhen") { @@ -102,6 +106,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testJoin(originalCond, expectedCond = FalseLiteral) testDelete(originalCond, expectedCond = FalseLiteral) testUpdate(originalCond, expectedCond = FalseLiteral) + testMerge(originalCond, expectedCond = FalseLiteral) } test("replace null in branches of If inside CaseWhen") { @@ -120,6 +125,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testJoin(originalCond, expectedCond) testDelete(originalCond, expectedCond) testUpdate(originalCond, expectedCond) + testMerge(originalCond, expectedCond) } test("replace null in complex CaseWhen expressions") { @@ -141,6 +147,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testJoin(originalCond, expectedCond) testDelete(originalCond, expectedCond) testUpdate(originalCond, expectedCond) + testMerge(originalCond, expectedCond) } test("replace null in Or") { @@ -150,6 +157,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testJoin(originalCond, expectedCond) testDelete(originalCond, expectedCond) testUpdate(originalCond, expectedCond) + testMerge(originalCond, expectedCond) } test("replace null in And") { @@ -158,6 +166,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testJoin(originalCond, expectedCond = FalseLiteral) testDelete(originalCond, expectedCond = FalseLiteral) testUpdate(originalCond, expectedCond = FalseLiteral) + testMerge(originalCond, expectedCond = FalseLiteral) } test("replace nulls in nested And/Or expressions") { @@ -168,6 +177,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testJoin(originalCond, expectedCond = FalseLiteral) testDelete(originalCond, expectedCond = FalseLiteral) testUpdate(originalCond, expectedCond = FalseLiteral) + testMerge(originalCond, expectedCond = FalseLiteral) } test("replace null in And inside branches of If") { @@ -179,6 +189,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testJoin(originalCond, expectedCond = FalseLiteral) testDelete(originalCond, expectedCond = FalseLiteral) testUpdate(originalCond, expectedCond = FalseLiteral) + testMerge(originalCond, expectedCond = FalseLiteral) } test("replace null in branches of If inside And") { @@ -192,6 +203,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testJoin(originalCond, expectedCond = FalseLiteral) testDelete(originalCond, expectedCond = FalseLiteral) testUpdate(originalCond, expectedCond = FalseLiteral) + testMerge(originalCond, expectedCond = FalseLiteral) } test("replace null in branches of If inside another If") { @@ -203,6 +215,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testJoin(originalCond, expectedCond = FalseLiteral) testDelete(originalCond, expectedCond = FalseLiteral) testUpdate(originalCond, expectedCond = FalseLiteral) + testMerge(originalCond, expectedCond = FalseLiteral) } test("replace null in CaseWhen inside another CaseWhen") { @@ -212,6 +225,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testJoin(originalCond, expectedCond = FalseLiteral) testDelete(originalCond, expectedCond = FalseLiteral) testUpdate(originalCond, expectedCond = FalseLiteral) + testMerge(originalCond, expectedCond = FalseLiteral) } test("inability to replace null in non-boolean branches of If") { @@ -226,6 +240,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testJoin(originalCond = condition, expectedCond = condition) testDelete(originalCond = condition, expectedCond = condition) testUpdate(originalCond = condition, expectedCond = condition) + testMerge(originalCond = condition, expectedCond = condition) } test("inability to replace null in non-boolean values of CaseWhen") { @@ -244,6 +259,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testJoin(originalCond = condition, expectedCond = expectedCond) testDelete(originalCond = condition, expectedCond = expectedCond) testUpdate(originalCond = condition, expectedCond = expectedCond) + testMerge(originalCond = condition, expectedCond = expectedCond) } test("inability to replace null in non-boolean branches of If inside another If") { @@ -262,6 +278,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testJoin(originalCond = condition, expectedCond = expectedCond) testDelete(originalCond = condition, expectedCond = expectedCond) testUpdate(originalCond = condition, expectedCond = expectedCond) + testMerge(originalCond = condition, expectedCond = expectedCond) } test("replace null in If used as a join condition") { @@ -396,11 +413,13 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testJoin(allFalseCond, FalseLiteral) testDelete(allFalseCond, FalseLiteral) testUpdate(allFalseCond, FalseLiteral) + testMerge(allFalseCond, FalseLiteral) testFilter(nonAllFalseCond, nonAllFalseCond) testJoin(nonAllFalseCond, nonAllFalseCond) testDelete(nonAllFalseCond, nonAllFalseCond) testUpdate(nonAllFalseCond, nonAllFalseCond) + testMerge(nonAllFalseCond, nonAllFalseCond) } test("replace None of elseValue inside CaseWhen if all branches are null") { @@ -412,6 +431,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testJoin(allFalseCond, FalseLiteral) testDelete(allFalseCond, FalseLiteral) testUpdate(allFalseCond, FalseLiteral) + testMerge(allFalseCond, FalseLiteral) } private def testFilter(originalCond: Expression, expectedCond: Expression): Unit = { @@ -434,6 +454,21 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { test((rel, expr) => UpdateTable(rel, Seq.empty, Some(expr)), originalCond, expectedCond) } + private def testMerge(originalCond: Expression, expectedCond: Expression): Unit = { + val func = (rel: LogicalPlan, expr: Expression) => { + val assignments = Seq( + Assignment('i, 'i), + Assignment('b, 'b), + Assignment('a, 'a), + Assignment('m, 'm) + ) + val matchedActions = UpdateAction(Some(expr), assignments) :: DeleteAction(Some(expr)) :: Nil + val notMatchedActions = InsertAction(Some(expr), assignments) :: Nil + MergeIntoTable(rel, rel, mergeCondition = expr, matchedActions, notMatchedActions) + } + test(func, originalCond, expectedCond) + } + private def testHigherOrderFunc( argument: Expression, createExpr: (Expression, Expression) => Expression,