[SPARK-33736][SQL] Handle MERGE in ReplaceNullWithFalseInPredicate

### What changes were proposed in this pull request?

This PR handles merge operations in `ReplaceNullWithFalseInPredicate`.

### Why are the changes needed?

These changes are needed to match what we already do for delete and update operations.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

This PR extends existing tests to cover merge operations.

Closes #31579 from aokolnychyi/spark-33736.

Authored-by: Anton Okolnychyi <aokolnychyi@apple.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
This commit is contained in:
Anton Okolnychyi 2021-02-17 17:27:21 -08:00 committed by Dongjoon Hyun
parent 44a9aed0d7
commit 1ad343238c
2 changed files with 51 additions and 2 deletions

View file

@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, CaseWhen, EqualNullSafe, Expression, If, LambdaFunction, Literal, MapFilter, Or} import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, CaseWhen, EqualNullSafe, Expression, If, LambdaFunction, Literal, MapFilter, Or}
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, Filter, Join, LogicalPlan, UpdateTable} import org.apache.spark.sql.catalyst.plans.logical.{DeleteAction, DeleteFromTable, Filter, InsertAction, Join, LogicalPlan, MergeAction, MergeIntoTable, UpdateAction, UpdateTable}
import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.BooleanType import org.apache.spark.sql.types.BooleanType
import org.apache.spark.util.Utils import org.apache.spark.util.Utils
@ -54,6 +54,11 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] {
case j @ Join(_, _, _, Some(cond), _) => j.copy(condition = Some(replaceNullWithFalse(cond))) case j @ Join(_, _, _, Some(cond), _) => j.copy(condition = Some(replaceNullWithFalse(cond)))
case d @ DeleteFromTable(_, Some(cond)) => d.copy(condition = Some(replaceNullWithFalse(cond))) case d @ DeleteFromTable(_, Some(cond)) => d.copy(condition = Some(replaceNullWithFalse(cond)))
case u @ UpdateTable(_, _, Some(cond)) => u.copy(condition = Some(replaceNullWithFalse(cond))) case u @ UpdateTable(_, _, Some(cond)) => u.copy(condition = Some(replaceNullWithFalse(cond)))
case m @ MergeIntoTable(_, _, mergeCond, matchedActions, notMatchedActions) =>
m.copy(
mergeCondition = replaceNullWithFalse(mergeCond),
matchedActions = replaceNullWithFalse(matchedActions),
notMatchedActions = replaceNullWithFalse(notMatchedActions))
case p: LogicalPlan => p transformExpressions { case p: LogicalPlan => p transformExpressions {
// For `EqualNullSafe` with a `TrueLiteral`, whether the other side is null or false has no // For `EqualNullSafe` with a `TrueLiteral`, whether the other side is null or false has no
// difference, as `null <=> true` and `false <=> true` both return false. // difference, as `null <=> true` and `false <=> true` both return false.
@ -114,4 +119,13 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] {
e e
} }
} }
private def replaceNullWithFalse(mergeActions: Seq[MergeAction]): Seq[MergeAction] = {
mergeActions.map {
case u @ UpdateAction(Some(cond), _) => u.copy(condition = Some(replaceNullWithFalse(cond)))
case d @ DeleteAction(Some(cond)) => d.copy(condition = Some(replaceNullWithFalse(cond)))
case i @ InsertAction(Some(cond), _) => i.copy(condition = Some(replaceNullWithFalse(cond)))
case other => other
}
}
} }

View file

@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable} import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable}
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, LocalRelation, LogicalPlan, UpdateTable} import org.apache.spark.sql.catalyst.plans.logical.{Assignment, DeleteAction, DeleteFromTable, InsertAction, LocalRelation, LogicalPlan, MergeIntoTable, UpdateAction, UpdateTable}
import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BooleanType, IntegerType} import org.apache.spark.sql.types.{BooleanType, IntegerType}
@ -50,6 +50,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
testJoin(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral) testJoin(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral)
testDelete(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral) testDelete(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral)
testUpdate(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral) testUpdate(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral)
testMerge(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral)
} }
test("Not expected type - replaceNullWithFalse") { test("Not expected type - replaceNullWithFalse") {
@ -68,6 +69,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
testJoin(originalCond, expectedCond = FalseLiteral) testJoin(originalCond, expectedCond = FalseLiteral)
testDelete(originalCond, expectedCond = FalseLiteral) testDelete(originalCond, expectedCond = FalseLiteral)
testUpdate(originalCond, expectedCond = FalseLiteral) testUpdate(originalCond, expectedCond = FalseLiteral)
testMerge(originalCond, expectedCond = FalseLiteral)
} }
test("replace nulls in nested expressions in branches of If") { test("replace nulls in nested expressions in branches of If") {
@ -79,6 +81,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
testJoin(originalCond, expectedCond = FalseLiteral) testJoin(originalCond, expectedCond = FalseLiteral)
testDelete(originalCond, expectedCond = FalseLiteral) testDelete(originalCond, expectedCond = FalseLiteral)
testUpdate(originalCond, expectedCond = FalseLiteral) testUpdate(originalCond, expectedCond = FalseLiteral)
testMerge(originalCond, expectedCond = FalseLiteral)
} }
test("replace null in elseValue of CaseWhen") { test("replace null in elseValue of CaseWhen") {
@ -91,6 +94,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
testJoin(originalCond, expectedCond) testJoin(originalCond, expectedCond)
testDelete(originalCond, expectedCond) testDelete(originalCond, expectedCond)
testUpdate(originalCond, expectedCond) testUpdate(originalCond, expectedCond)
testMerge(originalCond, expectedCond)
} }
test("replace null in branch values of CaseWhen") { test("replace null in branch values of CaseWhen") {
@ -102,6 +106,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
testJoin(originalCond, expectedCond = FalseLiteral) testJoin(originalCond, expectedCond = FalseLiteral)
testDelete(originalCond, expectedCond = FalseLiteral) testDelete(originalCond, expectedCond = FalseLiteral)
testUpdate(originalCond, expectedCond = FalseLiteral) testUpdate(originalCond, expectedCond = FalseLiteral)
testMerge(originalCond, expectedCond = FalseLiteral)
} }
test("replace null in branches of If inside CaseWhen") { test("replace null in branches of If inside CaseWhen") {
@ -120,6 +125,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
testJoin(originalCond, expectedCond) testJoin(originalCond, expectedCond)
testDelete(originalCond, expectedCond) testDelete(originalCond, expectedCond)
testUpdate(originalCond, expectedCond) testUpdate(originalCond, expectedCond)
testMerge(originalCond, expectedCond)
} }
test("replace null in complex CaseWhen expressions") { test("replace null in complex CaseWhen expressions") {
@ -141,6 +147,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
testJoin(originalCond, expectedCond) testJoin(originalCond, expectedCond)
testDelete(originalCond, expectedCond) testDelete(originalCond, expectedCond)
testUpdate(originalCond, expectedCond) testUpdate(originalCond, expectedCond)
testMerge(originalCond, expectedCond)
} }
test("replace null in Or") { test("replace null in Or") {
@ -150,6 +157,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
testJoin(originalCond, expectedCond) testJoin(originalCond, expectedCond)
testDelete(originalCond, expectedCond) testDelete(originalCond, expectedCond)
testUpdate(originalCond, expectedCond) testUpdate(originalCond, expectedCond)
testMerge(originalCond, expectedCond)
} }
test("replace null in And") { test("replace null in And") {
@ -158,6 +166,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
testJoin(originalCond, expectedCond = FalseLiteral) testJoin(originalCond, expectedCond = FalseLiteral)
testDelete(originalCond, expectedCond = FalseLiteral) testDelete(originalCond, expectedCond = FalseLiteral)
testUpdate(originalCond, expectedCond = FalseLiteral) testUpdate(originalCond, expectedCond = FalseLiteral)
testMerge(originalCond, expectedCond = FalseLiteral)
} }
test("replace nulls in nested And/Or expressions") { test("replace nulls in nested And/Or expressions") {
@ -168,6 +177,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
testJoin(originalCond, expectedCond = FalseLiteral) testJoin(originalCond, expectedCond = FalseLiteral)
testDelete(originalCond, expectedCond = FalseLiteral) testDelete(originalCond, expectedCond = FalseLiteral)
testUpdate(originalCond, expectedCond = FalseLiteral) testUpdate(originalCond, expectedCond = FalseLiteral)
testMerge(originalCond, expectedCond = FalseLiteral)
} }
test("replace null in And inside branches of If") { test("replace null in And inside branches of If") {
@ -179,6 +189,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
testJoin(originalCond, expectedCond = FalseLiteral) testJoin(originalCond, expectedCond = FalseLiteral)
testDelete(originalCond, expectedCond = FalseLiteral) testDelete(originalCond, expectedCond = FalseLiteral)
testUpdate(originalCond, expectedCond = FalseLiteral) testUpdate(originalCond, expectedCond = FalseLiteral)
testMerge(originalCond, expectedCond = FalseLiteral)
} }
test("replace null in branches of If inside And") { test("replace null in branches of If inside And") {
@ -192,6 +203,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
testJoin(originalCond, expectedCond = FalseLiteral) testJoin(originalCond, expectedCond = FalseLiteral)
testDelete(originalCond, expectedCond = FalseLiteral) testDelete(originalCond, expectedCond = FalseLiteral)
testUpdate(originalCond, expectedCond = FalseLiteral) testUpdate(originalCond, expectedCond = FalseLiteral)
testMerge(originalCond, expectedCond = FalseLiteral)
} }
test("replace null in branches of If inside another If") { test("replace null in branches of If inside another If") {
@ -203,6 +215,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
testJoin(originalCond, expectedCond = FalseLiteral) testJoin(originalCond, expectedCond = FalseLiteral)
testDelete(originalCond, expectedCond = FalseLiteral) testDelete(originalCond, expectedCond = FalseLiteral)
testUpdate(originalCond, expectedCond = FalseLiteral) testUpdate(originalCond, expectedCond = FalseLiteral)
testMerge(originalCond, expectedCond = FalseLiteral)
} }
test("replace null in CaseWhen inside another CaseWhen") { test("replace null in CaseWhen inside another CaseWhen") {
@ -212,6 +225,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
testJoin(originalCond, expectedCond = FalseLiteral) testJoin(originalCond, expectedCond = FalseLiteral)
testDelete(originalCond, expectedCond = FalseLiteral) testDelete(originalCond, expectedCond = FalseLiteral)
testUpdate(originalCond, expectedCond = FalseLiteral) testUpdate(originalCond, expectedCond = FalseLiteral)
testMerge(originalCond, expectedCond = FalseLiteral)
} }
test("inability to replace null in non-boolean branches of If") { test("inability to replace null in non-boolean branches of If") {
@ -226,6 +240,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
testJoin(originalCond = condition, expectedCond = condition) testJoin(originalCond = condition, expectedCond = condition)
testDelete(originalCond = condition, expectedCond = condition) testDelete(originalCond = condition, expectedCond = condition)
testUpdate(originalCond = condition, expectedCond = condition) testUpdate(originalCond = condition, expectedCond = condition)
testMerge(originalCond = condition, expectedCond = condition)
} }
test("inability to replace null in non-boolean values of CaseWhen") { test("inability to replace null in non-boolean values of CaseWhen") {
@ -244,6 +259,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
testJoin(originalCond = condition, expectedCond = expectedCond) testJoin(originalCond = condition, expectedCond = expectedCond)
testDelete(originalCond = condition, expectedCond = expectedCond) testDelete(originalCond = condition, expectedCond = expectedCond)
testUpdate(originalCond = condition, expectedCond = expectedCond) testUpdate(originalCond = condition, expectedCond = expectedCond)
testMerge(originalCond = condition, expectedCond = expectedCond)
} }
test("inability to replace null in non-boolean branches of If inside another If") { test("inability to replace null in non-boolean branches of If inside another If") {
@ -262,6 +278,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
testJoin(originalCond = condition, expectedCond = expectedCond) testJoin(originalCond = condition, expectedCond = expectedCond)
testDelete(originalCond = condition, expectedCond = expectedCond) testDelete(originalCond = condition, expectedCond = expectedCond)
testUpdate(originalCond = condition, expectedCond = expectedCond) testUpdate(originalCond = condition, expectedCond = expectedCond)
testMerge(originalCond = condition, expectedCond = expectedCond)
} }
test("replace null in If used as a join condition") { test("replace null in If used as a join condition") {
@ -396,11 +413,13 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
testJoin(allFalseCond, FalseLiteral) testJoin(allFalseCond, FalseLiteral)
testDelete(allFalseCond, FalseLiteral) testDelete(allFalseCond, FalseLiteral)
testUpdate(allFalseCond, FalseLiteral) testUpdate(allFalseCond, FalseLiteral)
testMerge(allFalseCond, FalseLiteral)
testFilter(nonAllFalseCond, nonAllFalseCond) testFilter(nonAllFalseCond, nonAllFalseCond)
testJoin(nonAllFalseCond, nonAllFalseCond) testJoin(nonAllFalseCond, nonAllFalseCond)
testDelete(nonAllFalseCond, nonAllFalseCond) testDelete(nonAllFalseCond, nonAllFalseCond)
testUpdate(nonAllFalseCond, nonAllFalseCond) testUpdate(nonAllFalseCond, nonAllFalseCond)
testMerge(nonAllFalseCond, nonAllFalseCond)
} }
test("replace None of elseValue inside CaseWhen if all branches are null") { test("replace None of elseValue inside CaseWhen if all branches are null") {
@ -412,6 +431,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
testJoin(allFalseCond, FalseLiteral) testJoin(allFalseCond, FalseLiteral)
testDelete(allFalseCond, FalseLiteral) testDelete(allFalseCond, FalseLiteral)
testUpdate(allFalseCond, FalseLiteral) testUpdate(allFalseCond, FalseLiteral)
testMerge(allFalseCond, FalseLiteral)
} }
private def testFilter(originalCond: Expression, expectedCond: Expression): Unit = { private def testFilter(originalCond: Expression, expectedCond: Expression): Unit = {
@ -434,6 +454,21 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
test((rel, expr) => UpdateTable(rel, Seq.empty, Some(expr)), originalCond, expectedCond) test((rel, expr) => UpdateTable(rel, Seq.empty, Some(expr)), originalCond, expectedCond)
} }
private def testMerge(originalCond: Expression, expectedCond: Expression): Unit = {
val func = (rel: LogicalPlan, expr: Expression) => {
val assignments = Seq(
Assignment('i, 'i),
Assignment('b, 'b),
Assignment('a, 'a),
Assignment('m, 'm)
)
val matchedActions = UpdateAction(Some(expr), assignments) :: DeleteAction(Some(expr)) :: Nil
val notMatchedActions = InsertAction(Some(expr), assignments) :: Nil
MergeIntoTable(rel, rel, mergeCondition = expr, matchedActions, notMatchedActions)
}
test(func, originalCond, expectedCond)
}
private def testHigherOrderFunc( private def testHigherOrderFunc(
argument: Expression, argument: Expression,
createExpr: (Expression, Expression) => Expression, createExpr: (Expression, Expression) => Expression,