[SPARK-33736][SQL] Handle MERGE in ReplaceNullWithFalseInPredicate
### What changes were proposed in this pull request?

This PR handles merge operations in `ReplaceNullWithFalseInPredicate`.

### Why are the changes needed?

These changes are needed to match what we already do for delete and update operations.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

This PR extends existing tests to cover merge operations.

Closes #31579 from aokolnychyi/spark-33736.

Authored-by: Anton Okolnychyi <aokolnychyi@apple.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
This commit is contained in:
parent
44a9aed0d7
commit
1ad343238c
|
@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
|
||||||
|
|
||||||
import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, CaseWhen, EqualNullSafe, Expression, If, LambdaFunction, Literal, MapFilter, Or}
|
import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, CaseWhen, EqualNullSafe, Expression, If, LambdaFunction, Literal, MapFilter, Or}
|
||||||
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
|
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
|
||||||
import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, Filter, Join, LogicalPlan, UpdateTable}
|
import org.apache.spark.sql.catalyst.plans.logical.{DeleteAction, DeleteFromTable, Filter, InsertAction, Join, LogicalPlan, MergeAction, MergeIntoTable, UpdateAction, UpdateTable}
|
||||||
import org.apache.spark.sql.catalyst.rules.Rule
|
import org.apache.spark.sql.catalyst.rules.Rule
|
||||||
import org.apache.spark.sql.types.BooleanType
|
import org.apache.spark.sql.types.BooleanType
|
||||||
import org.apache.spark.util.Utils
|
import org.apache.spark.util.Utils
|
||||||
|
@ -54,6 +54,11 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] {
|
||||||
case j @ Join(_, _, _, Some(cond), _) => j.copy(condition = Some(replaceNullWithFalse(cond)))
|
case j @ Join(_, _, _, Some(cond), _) => j.copy(condition = Some(replaceNullWithFalse(cond)))
|
||||||
case d @ DeleteFromTable(_, Some(cond)) => d.copy(condition = Some(replaceNullWithFalse(cond)))
|
case d @ DeleteFromTable(_, Some(cond)) => d.copy(condition = Some(replaceNullWithFalse(cond)))
|
||||||
case u @ UpdateTable(_, _, Some(cond)) => u.copy(condition = Some(replaceNullWithFalse(cond)))
|
case u @ UpdateTable(_, _, Some(cond)) => u.copy(condition = Some(replaceNullWithFalse(cond)))
|
||||||
|
case m @ MergeIntoTable(_, _, mergeCond, matchedActions, notMatchedActions) =>
|
||||||
|
m.copy(
|
||||||
|
mergeCondition = replaceNullWithFalse(mergeCond),
|
||||||
|
matchedActions = replaceNullWithFalse(matchedActions),
|
||||||
|
notMatchedActions = replaceNullWithFalse(notMatchedActions))
|
||||||
case p: LogicalPlan => p transformExpressions {
|
case p: LogicalPlan => p transformExpressions {
|
||||||
// For `EqualNullSafe` with a `TrueLiteral`, whether the other side is null or false has no
|
// For `EqualNullSafe` with a `TrueLiteral`, whether the other side is null or false has no
|
||||||
// difference, as `null <=> true` and `false <=> true` both return false.
|
// difference, as `null <=> true` and `false <=> true` both return false.
|
||||||
|
@ -114,4 +119,13 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] {
|
||||||
e
|
e
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Applies the expression-level `replaceNullWithFalse` to the condition of every
 * conditional MERGE action in `mergeActions`.
 *
 * Update, delete and insert actions that carry a condition are copied with the
 * rewritten condition; every other action is returned as-is.
 *
 * @param mergeActions the matched or not-matched actions of a merge operation
 * @return the actions in the same order, with their conditions rewritten
 */
private def replaceNullWithFalse(mergeActions: Seq[MergeAction]): Seq[MergeAction] = {
  for (action <- mergeActions) yield action match {
    case update @ UpdateAction(Some(predicate), _) =>
      update.copy(condition = Some(replaceNullWithFalse(predicate)))
    case delete @ DeleteAction(Some(predicate)) =>
      delete.copy(condition = Some(replaceNullWithFalse(predicate)))
    case insert @ InsertAction(Some(predicate), _) =>
      insert.copy(condition = Some(replaceNullWithFalse(predicate)))
    // No condition to rewrite: keep the action untouched.
    case unchanged => unchanged
  }
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._
|
||||||
import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable}
|
import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable}
|
||||||
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
|
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
|
||||||
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
|
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
|
||||||
import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, LocalRelation, LogicalPlan, UpdateTable}
|
import org.apache.spark.sql.catalyst.plans.logical.{Assignment, DeleteAction, DeleteFromTable, InsertAction, LocalRelation, LogicalPlan, MergeIntoTable, UpdateAction, UpdateTable}
|
||||||
import org.apache.spark.sql.catalyst.rules.RuleExecutor
|
import org.apache.spark.sql.catalyst.rules.RuleExecutor
|
||||||
import org.apache.spark.sql.internal.SQLConf
|
import org.apache.spark.sql.internal.SQLConf
|
||||||
import org.apache.spark.sql.types.{BooleanType, IntegerType}
|
import org.apache.spark.sql.types.{BooleanType, IntegerType}
|
||||||
|
@ -50,6 +50,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
|
||||||
testJoin(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral)
|
testJoin(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral)
|
||||||
testDelete(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral)
|
testDelete(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral)
|
||||||
testUpdate(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral)
|
testUpdate(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral)
|
||||||
|
testMerge(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("Not expected type - replaceNullWithFalse") {
|
test("Not expected type - replaceNullWithFalse") {
|
||||||
|
@ -68,6 +69,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
|
||||||
testJoin(originalCond, expectedCond = FalseLiteral)
|
testJoin(originalCond, expectedCond = FalseLiteral)
|
||||||
testDelete(originalCond, expectedCond = FalseLiteral)
|
testDelete(originalCond, expectedCond = FalseLiteral)
|
||||||
testUpdate(originalCond, expectedCond = FalseLiteral)
|
testUpdate(originalCond, expectedCond = FalseLiteral)
|
||||||
|
testMerge(originalCond, expectedCond = FalseLiteral)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("replace nulls in nested expressions in branches of If") {
|
test("replace nulls in nested expressions in branches of If") {
|
||||||
|
@ -79,6 +81,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
|
||||||
testJoin(originalCond, expectedCond = FalseLiteral)
|
testJoin(originalCond, expectedCond = FalseLiteral)
|
||||||
testDelete(originalCond, expectedCond = FalseLiteral)
|
testDelete(originalCond, expectedCond = FalseLiteral)
|
||||||
testUpdate(originalCond, expectedCond = FalseLiteral)
|
testUpdate(originalCond, expectedCond = FalseLiteral)
|
||||||
|
testMerge(originalCond, expectedCond = FalseLiteral)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("replace null in elseValue of CaseWhen") {
|
test("replace null in elseValue of CaseWhen") {
|
||||||
|
@ -91,6 +94,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
|
||||||
testJoin(originalCond, expectedCond)
|
testJoin(originalCond, expectedCond)
|
||||||
testDelete(originalCond, expectedCond)
|
testDelete(originalCond, expectedCond)
|
||||||
testUpdate(originalCond, expectedCond)
|
testUpdate(originalCond, expectedCond)
|
||||||
|
testMerge(originalCond, expectedCond)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("replace null in branch values of CaseWhen") {
|
test("replace null in branch values of CaseWhen") {
|
||||||
|
@ -102,6 +106,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
|
||||||
testJoin(originalCond, expectedCond = FalseLiteral)
|
testJoin(originalCond, expectedCond = FalseLiteral)
|
||||||
testDelete(originalCond, expectedCond = FalseLiteral)
|
testDelete(originalCond, expectedCond = FalseLiteral)
|
||||||
testUpdate(originalCond, expectedCond = FalseLiteral)
|
testUpdate(originalCond, expectedCond = FalseLiteral)
|
||||||
|
testMerge(originalCond, expectedCond = FalseLiteral)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("replace null in branches of If inside CaseWhen") {
|
test("replace null in branches of If inside CaseWhen") {
|
||||||
|
@ -120,6 +125,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
|
||||||
testJoin(originalCond, expectedCond)
|
testJoin(originalCond, expectedCond)
|
||||||
testDelete(originalCond, expectedCond)
|
testDelete(originalCond, expectedCond)
|
||||||
testUpdate(originalCond, expectedCond)
|
testUpdate(originalCond, expectedCond)
|
||||||
|
testMerge(originalCond, expectedCond)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("replace null in complex CaseWhen expressions") {
|
test("replace null in complex CaseWhen expressions") {
|
||||||
|
@ -141,6 +147,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
|
||||||
testJoin(originalCond, expectedCond)
|
testJoin(originalCond, expectedCond)
|
||||||
testDelete(originalCond, expectedCond)
|
testDelete(originalCond, expectedCond)
|
||||||
testUpdate(originalCond, expectedCond)
|
testUpdate(originalCond, expectedCond)
|
||||||
|
testMerge(originalCond, expectedCond)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("replace null in Or") {
|
test("replace null in Or") {
|
||||||
|
@ -150,6 +157,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
|
||||||
testJoin(originalCond, expectedCond)
|
testJoin(originalCond, expectedCond)
|
||||||
testDelete(originalCond, expectedCond)
|
testDelete(originalCond, expectedCond)
|
||||||
testUpdate(originalCond, expectedCond)
|
testUpdate(originalCond, expectedCond)
|
||||||
|
testMerge(originalCond, expectedCond)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("replace null in And") {
|
test("replace null in And") {
|
||||||
|
@ -158,6 +166,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
|
||||||
testJoin(originalCond, expectedCond = FalseLiteral)
|
testJoin(originalCond, expectedCond = FalseLiteral)
|
||||||
testDelete(originalCond, expectedCond = FalseLiteral)
|
testDelete(originalCond, expectedCond = FalseLiteral)
|
||||||
testUpdate(originalCond, expectedCond = FalseLiteral)
|
testUpdate(originalCond, expectedCond = FalseLiteral)
|
||||||
|
testMerge(originalCond, expectedCond = FalseLiteral)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("replace nulls in nested And/Or expressions") {
|
test("replace nulls in nested And/Or expressions") {
|
||||||
|
@ -168,6 +177,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
|
||||||
testJoin(originalCond, expectedCond = FalseLiteral)
|
testJoin(originalCond, expectedCond = FalseLiteral)
|
||||||
testDelete(originalCond, expectedCond = FalseLiteral)
|
testDelete(originalCond, expectedCond = FalseLiteral)
|
||||||
testUpdate(originalCond, expectedCond = FalseLiteral)
|
testUpdate(originalCond, expectedCond = FalseLiteral)
|
||||||
|
testMerge(originalCond, expectedCond = FalseLiteral)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("replace null in And inside branches of If") {
|
test("replace null in And inside branches of If") {
|
||||||
|
@ -179,6 +189,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
|
||||||
testJoin(originalCond, expectedCond = FalseLiteral)
|
testJoin(originalCond, expectedCond = FalseLiteral)
|
||||||
testDelete(originalCond, expectedCond = FalseLiteral)
|
testDelete(originalCond, expectedCond = FalseLiteral)
|
||||||
testUpdate(originalCond, expectedCond = FalseLiteral)
|
testUpdate(originalCond, expectedCond = FalseLiteral)
|
||||||
|
testMerge(originalCond, expectedCond = FalseLiteral)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("replace null in branches of If inside And") {
|
test("replace null in branches of If inside And") {
|
||||||
|
@ -192,6 +203,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
|
||||||
testJoin(originalCond, expectedCond = FalseLiteral)
|
testJoin(originalCond, expectedCond = FalseLiteral)
|
||||||
testDelete(originalCond, expectedCond = FalseLiteral)
|
testDelete(originalCond, expectedCond = FalseLiteral)
|
||||||
testUpdate(originalCond, expectedCond = FalseLiteral)
|
testUpdate(originalCond, expectedCond = FalseLiteral)
|
||||||
|
testMerge(originalCond, expectedCond = FalseLiteral)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("replace null in branches of If inside another If") {
|
test("replace null in branches of If inside another If") {
|
||||||
|
@ -203,6 +215,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
|
||||||
testJoin(originalCond, expectedCond = FalseLiteral)
|
testJoin(originalCond, expectedCond = FalseLiteral)
|
||||||
testDelete(originalCond, expectedCond = FalseLiteral)
|
testDelete(originalCond, expectedCond = FalseLiteral)
|
||||||
testUpdate(originalCond, expectedCond = FalseLiteral)
|
testUpdate(originalCond, expectedCond = FalseLiteral)
|
||||||
|
testMerge(originalCond, expectedCond = FalseLiteral)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("replace null in CaseWhen inside another CaseWhen") {
|
test("replace null in CaseWhen inside another CaseWhen") {
|
||||||
|
@ -212,6 +225,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
|
||||||
testJoin(originalCond, expectedCond = FalseLiteral)
|
testJoin(originalCond, expectedCond = FalseLiteral)
|
||||||
testDelete(originalCond, expectedCond = FalseLiteral)
|
testDelete(originalCond, expectedCond = FalseLiteral)
|
||||||
testUpdate(originalCond, expectedCond = FalseLiteral)
|
testUpdate(originalCond, expectedCond = FalseLiteral)
|
||||||
|
testMerge(originalCond, expectedCond = FalseLiteral)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("inability to replace null in non-boolean branches of If") {
|
test("inability to replace null in non-boolean branches of If") {
|
||||||
|
@ -226,6 +240,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
|
||||||
testJoin(originalCond = condition, expectedCond = condition)
|
testJoin(originalCond = condition, expectedCond = condition)
|
||||||
testDelete(originalCond = condition, expectedCond = condition)
|
testDelete(originalCond = condition, expectedCond = condition)
|
||||||
testUpdate(originalCond = condition, expectedCond = condition)
|
testUpdate(originalCond = condition, expectedCond = condition)
|
||||||
|
testMerge(originalCond = condition, expectedCond = condition)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("inability to replace null in non-boolean values of CaseWhen") {
|
test("inability to replace null in non-boolean values of CaseWhen") {
|
||||||
|
@ -244,6 +259,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
|
||||||
testJoin(originalCond = condition, expectedCond = expectedCond)
|
testJoin(originalCond = condition, expectedCond = expectedCond)
|
||||||
testDelete(originalCond = condition, expectedCond = expectedCond)
|
testDelete(originalCond = condition, expectedCond = expectedCond)
|
||||||
testUpdate(originalCond = condition, expectedCond = expectedCond)
|
testUpdate(originalCond = condition, expectedCond = expectedCond)
|
||||||
|
testMerge(originalCond = condition, expectedCond = expectedCond)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("inability to replace null in non-boolean branches of If inside another If") {
|
test("inability to replace null in non-boolean branches of If inside another If") {
|
||||||
|
@ -262,6 +278,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
|
||||||
testJoin(originalCond = condition, expectedCond = expectedCond)
|
testJoin(originalCond = condition, expectedCond = expectedCond)
|
||||||
testDelete(originalCond = condition, expectedCond = expectedCond)
|
testDelete(originalCond = condition, expectedCond = expectedCond)
|
||||||
testUpdate(originalCond = condition, expectedCond = expectedCond)
|
testUpdate(originalCond = condition, expectedCond = expectedCond)
|
||||||
|
testMerge(originalCond = condition, expectedCond = expectedCond)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("replace null in If used as a join condition") {
|
test("replace null in If used as a join condition") {
|
||||||
|
@ -396,11 +413,13 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
|
||||||
testJoin(allFalseCond, FalseLiteral)
|
testJoin(allFalseCond, FalseLiteral)
|
||||||
testDelete(allFalseCond, FalseLiteral)
|
testDelete(allFalseCond, FalseLiteral)
|
||||||
testUpdate(allFalseCond, FalseLiteral)
|
testUpdate(allFalseCond, FalseLiteral)
|
||||||
|
testMerge(allFalseCond, FalseLiteral)
|
||||||
|
|
||||||
testFilter(nonAllFalseCond, nonAllFalseCond)
|
testFilter(nonAllFalseCond, nonAllFalseCond)
|
||||||
testJoin(nonAllFalseCond, nonAllFalseCond)
|
testJoin(nonAllFalseCond, nonAllFalseCond)
|
||||||
testDelete(nonAllFalseCond, nonAllFalseCond)
|
testDelete(nonAllFalseCond, nonAllFalseCond)
|
||||||
testUpdate(nonAllFalseCond, nonAllFalseCond)
|
testUpdate(nonAllFalseCond, nonAllFalseCond)
|
||||||
|
testMerge(nonAllFalseCond, nonAllFalseCond)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("replace None of elseValue inside CaseWhen if all branches are null") {
|
test("replace None of elseValue inside CaseWhen if all branches are null") {
|
||||||
|
@ -412,6 +431,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
|
||||||
testJoin(allFalseCond, FalseLiteral)
|
testJoin(allFalseCond, FalseLiteral)
|
||||||
testDelete(allFalseCond, FalseLiteral)
|
testDelete(allFalseCond, FalseLiteral)
|
||||||
testUpdate(allFalseCond, FalseLiteral)
|
testUpdate(allFalseCond, FalseLiteral)
|
||||||
|
testMerge(allFalseCond, FalseLiteral)
|
||||||
}
|
}
|
||||||
|
|
||||||
private def testFilter(originalCond: Expression, expectedCond: Expression): Unit = {
|
private def testFilter(originalCond: Expression, expectedCond: Expression): Unit = {
|
||||||
|
@ -434,6 +454,21 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
|
||||||
test((rel, expr) => UpdateTable(rel, Seq.empty, Some(expr)), originalCond, expectedCond)
|
test((rel, expr) => UpdateTable(rel, Seq.empty, Some(expr)), originalCond, expectedCond)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Checks the rule on a MERGE plan in which `originalCond` is used everywhere a
 * condition can appear: as the merge condition, on the matched update and
 * delete actions, and on the not-matched insert action.
 *
 * @param originalCond the condition placed into the plan before optimization
 * @param expectedCond the condition passed to `test` as the expected result
 */
private def testMerge(originalCond: Expression, expectedCond: Expression): Unit = {
  val buildMergePlan = (relation: LogicalPlan, cond: Expression) => {
    // Identity assignments; only the action conditions matter for this rule.
    val identityAssignments = Seq(
      Assignment('i, 'i),
      Assignment('b, 'b),
      Assignment('a, 'a),
      Assignment('m, 'm))
    val whenMatched = List(
      UpdateAction(Some(cond), identityAssignments),
      DeleteAction(Some(cond)))
    val whenNotMatched = List(InsertAction(Some(cond), identityAssignments))
    // The same relation serves as both sides of the merge.
    MergeIntoTable(relation, relation, mergeCondition = cond, whenMatched, whenNotMatched)
  }
  test(buildMergePlan, originalCond, expectedCond)
}
|
||||||
|
|
||||||
private def testHigherOrderFunc(
|
private def testHigherOrderFunc(
|
||||||
argument: Expression,
|
argument: Expression,
|
||||||
createExpr: (Expression, Expression) => Expression,
|
createExpr: (Expression, Expression) => Expression,
|
||||||
|
|
Loading…
Reference in a new issue