[SPARK-36686][SQL] Fix SimplifyConditionalsInPredicate to be null-safe

### What changes were proposed in this pull request?

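Fix `SimplifyConditionalsInPredicate` to be null-safe. When `cond` evaluates to null, `IF(cond, FALSE, falseVal)` returns `falseVal`, but the rewritten form `AND(NOT(cond), falseVal)` evaluates to null, so the row is wrongly filtered out. The same problem affects `IF(cond, trueVal, TRUE) => OR(NOT(cond), trueVal)` and the equivalent `CASE WHEN` rewrites. This change negates `COALESCE(cond, FALSE)` instead of `cond` in those rewrites, so a null condition is treated as false.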

To reproduce:

```
import org.apache.spark.sql.types.{StructField, BooleanType, StructType}
import org.apache.spark.sql.Row

val schema = List(
  StructField("b", BooleanType, true)
)
val data = Seq(
  Row(true),
  Row(false),
  Row(null)
)
val df = spark.createDataFrame(
  spark.sparkContext.parallelize(data),
  StructType(schema)
)

// cartesian product of true / false / null
val df2 = df.select(col("b") as "cond").crossJoin(df.select(col("b") as "falseVal"))
df2.createOrReplaceTempView("df2")

spark.sql("SELECT * FROM df2 WHERE IF(cond, FALSE, falseVal)").show()
// actual (rule enabled; the row with cond = null, falseVal = true is missing):
// +-----+--------+
// | cond|falseVal|
// +-----+--------+
// |false|    true|
// +-----+--------+
spark.sql("SET spark.sql.optimizer.excludedRules=org.apache.spark.sql.catalyst.optimizer.SimplifyConditionalsInPredicate")
spark.sql("SELECT * FROM df2 WHERE IF(cond, FALSE, falseVal)").show()
// expected (same query with the rule excluded):
// +-----+--------+
// | cond|falseVal|
// +-----+--------+
// |false|    true|
// | null|    true|
// +-----+--------+
```

### Why are the changes needed?

This is a regression that leads to incorrect query results.

### Does this PR introduce _any_ user-facing change?

no

### How was this patch tested?

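Existing tests, with the expected conditions in `SimplifyConditionalsInPredicateSuite` and the affected expected query-plan outputs updated to the new `NOT(COALESCE(cond, FALSE))` form.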

Closes #33928 from hypercubestart/fix-SimplifyConditionalsInPredicate.

Authored-by: Andrew Liu <andrewlliu@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit 9b633f2075)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
Andrew Liu 2021-09-09 11:32:40 +08:00 committed by Wenchen Fan
parent 88bba0c94b
commit 6cb23c163c
10 changed files with 30 additions and 29 deletions

View file

@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.expressions.{And, CaseWhen, Expression, If, Literal, Not, Or}
import org.apache.spark.sql.catalyst.expressions.{And, CaseWhen, Coalesce, Expression, If, Literal, Not, Or}
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
@ -28,13 +28,13 @@ import org.apache.spark.sql.types.BooleanType
* A rule that converts conditional expressions to predicate expressions, if possible, in the
* search condition of the WHERE/HAVING/ON(JOIN) clauses, which contain an implicit Boolean operator
* "(search condition) = TRUE". After this converting, we can potentially push the filter down to
* the data source.
* the data source. This rule is null-safe.
*
* Supported cases are:
* - IF(cond, trueVal, false) => AND(cond, trueVal)
* - IF(cond, trueVal, true) => OR(NOT(cond), trueVal)
* - IF(cond, false, falseVal) => AND(NOT(cond), elseVal)
* - IF(cond, true, falseVal) => OR(cond, elseVal)
* - IF(cond, false, falseVal) => AND(NOT(cond), falseVal)
* - IF(cond, true, falseVal) => OR(cond, falseVal)
* - CASE WHEN cond THEN trueVal ELSE false END => AND(cond, trueVal)
* - CASE WHEN cond THEN trueVal END => AND(cond, trueVal)
* - CASE WHEN cond THEN trueVal ELSE null END => AND(cond, trueVal)
@ -56,16 +56,17 @@ object SimplifyConditionalsInPredicate extends Rule[LogicalPlan] {
case And(left, right) => And(simplifyConditional(left), simplifyConditional(right))
case Or(left, right) => Or(simplifyConditional(left), simplifyConditional(right))
case If(cond, trueValue, FalseLiteral) => And(cond, trueValue)
case If(cond, trueValue, TrueLiteral) => Or(Not(cond), trueValue)
case If(cond, FalseLiteral, falseValue) => And(Not(cond), falseValue)
case If(cond, trueValue, TrueLiteral) => Or(Not(Coalesce(Seq(cond, FalseLiteral))), trueValue)
case If(cond, FalseLiteral, falseValue) =>
And(Not(Coalesce(Seq(cond, FalseLiteral))), falseValue)
case If(cond, TrueLiteral, falseValue) => Or(cond, falseValue)
case CaseWhen(Seq((cond, trueValue)),
Some(FalseLiteral) | Some(Literal(null, BooleanType)) | None) =>
And(cond, trueValue)
case CaseWhen(Seq((cond, trueValue)), Some(TrueLiteral)) =>
Or(Not(cond), trueValue)
Or(Not(Coalesce(Seq(cond, FalseLiteral))), trueValue)
case CaseWhen(Seq((cond, FalseLiteral)), Some(elseValue)) =>
And(Not(cond), elseValue)
And(Not(Coalesce(Seq(cond, FalseLiteral))), elseValue)
case CaseWhen(Seq((cond, TrueLiteral)), Some(elseValue)) =>
Or(cond, elseValue)
case e if e.dataType == BooleanType => e

View file

@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{And, CaseWhen, Expression, If, IsNotNull, Literal, Or, Rand}
import org.apache.spark.sql.catalyst.expressions.{And, CaseWhen, Coalesce, Expression, If, IsNotNull, Literal, Not, Or, Rand}
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, LocalRelation, LogicalPlan, UpdateTable}
@ -65,7 +65,7 @@ class SimplifyConditionalsInPredicateSuite extends PlanTest {
UnresolvedAttribute("b"),
TrueLiteral)
val expectedCond = Or(
UnresolvedAttribute("i") <= Literal(10),
Not(Coalesce(Seq(UnresolvedAttribute("i") > Literal(10), FalseLiteral))),
UnresolvedAttribute("b"))
testFilter(originalCond, expectedCond = expectedCond)
testJoin(originalCond, expectedCond = expectedCond)
@ -80,7 +80,7 @@ class SimplifyConditionalsInPredicateSuite extends PlanTest {
FalseLiteral,
UnresolvedAttribute("b"))
val expectedCond = And(
UnresolvedAttribute("i") <= Literal(10),
Not(Coalesce(Seq(UnresolvedAttribute("i") > Literal(10), FalseLiteral))),
UnresolvedAttribute("b"))
testFilter(originalCond, expectedCond = expectedCond)
testJoin(originalCond, expectedCond = expectedCond)
@ -125,7 +125,7 @@ class SimplifyConditionalsInPredicateSuite extends PlanTest {
Seq((UnresolvedAttribute("i") > Literal(10), UnresolvedAttribute("b"))),
TrueLiteral)
val expectedCond = Or(
UnresolvedAttribute("i") <= Literal(10),
Not(Coalesce(Seq(UnresolvedAttribute("i") > Literal(10), FalseLiteral))),
UnresolvedAttribute("b"))
testFilter(originalCond, expectedCond = expectedCond)
testJoin(originalCond, expectedCond = expectedCond)
@ -139,7 +139,7 @@ class SimplifyConditionalsInPredicateSuite extends PlanTest {
Seq((UnresolvedAttribute("i") > Literal(10), FalseLiteral)),
UnresolvedAttribute("b"))
val expectedCond = And(
UnresolvedAttribute("i") <= Literal(10),
Not(Coalesce(Seq(UnresolvedAttribute("i") > Literal(10), FalseLiteral))),
UnresolvedAttribute("b"))
testFilter(originalCond, expectedCond = expectedCond)
testJoin(originalCond, expectedCond = expectedCond)

View file

@ -149,7 +149,7 @@ Results [5]: [w_warehouse_sk#10, i_item_sk#8, d_moy#7, stddev_samp(cast(inv_quan
(22) Filter [codegen id : 5]
Input [5]: [w_warehouse_sk#10, i_item_sk#8, d_moy#7, stdev#26, mean#27]
Condition : ((isnotnull(mean#27) AND isnotnull(stdev#26)) AND (NOT (mean#27 = 0.0) AND ((stdev#26 / mean#27) > 1.0)))
Condition : ((isnotnull(stdev#26) AND isnotnull(mean#27)) AND (NOT coalesce((mean#27 = 0.0), false) AND ((stdev#26 / mean#27) > 1.0)))
(23) Project [codegen id : 5]
Output [5]: [w_warehouse_sk#10, i_item_sk#8, d_moy#7, mean#27, CASE WHEN (mean#27 = 0.0) THEN null ELSE (stdev#26 / mean#27) END AS cov#28]
@ -234,7 +234,7 @@ Results [5]: [w_warehouse_sk#38, i_item_sk#37, d_moy#36, stddev_samp(cast(inv_qu
(41) Filter [codegen id : 11]
Input [5]: [w_warehouse_sk#38, i_item_sk#37, d_moy#36, stdev#26, mean#27]
Condition : ((isnotnull(mean#27) AND isnotnull(stdev#26)) AND (NOT (mean#27 = 0.0) AND ((stdev#26 / mean#27) > 1.0)))
Condition : ((isnotnull(stdev#26) AND isnotnull(mean#27)) AND (NOT coalesce((mean#27 = 0.0), false) AND ((stdev#26 / mean#27) > 1.0)))
(42) Project [codegen id : 11]
Output [5]: [w_warehouse_sk#38, i_item_sk#37, d_moy#36, mean#27 AS mean#51, CASE WHEN (mean#27 = 0.0) THEN null ELSE (stdev#26 / mean#27) END AS cov#52]

View file

@ -11,7 +11,7 @@ WholeStageCodegen (14)
Exchange [i_item_sk,w_warehouse_sk] #2
WholeStageCodegen (5)
Project [w_warehouse_sk,i_item_sk,d_moy,mean,stdev]
Filter [mean,stdev]
Filter [stdev,mean]
HashAggregate [w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy,n,avg,m2,sum,count] [stddev_samp(cast(inv_quantity_on_hand as double)),avg(inv_quantity_on_hand),stdev,mean,n,avg,m2,sum,count]
InputAdapter
Exchange [w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy] #3
@ -58,7 +58,7 @@ WholeStageCodegen (14)
Exchange [i_item_sk,w_warehouse_sk] #7
WholeStageCodegen (11)
Project [w_warehouse_sk,i_item_sk,d_moy,mean,stdev]
Filter [mean,stdev]
Filter [stdev,mean]
HashAggregate [w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy,n,avg,m2,sum,count] [stddev_samp(cast(inv_quantity_on_hand as double)),avg(inv_quantity_on_hand),stdev,mean,n,avg,m2,sum,count]
InputAdapter
Exchange [w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy] #8

View file

@ -146,7 +146,7 @@ Results [5]: [w_warehouse_sk#8, i_item_sk#6, d_moy#12, stddev_samp(cast(inv_quan
(22) Filter [codegen id : 10]
Input [5]: [w_warehouse_sk#8, i_item_sk#6, d_moy#12, stdev#26, mean#27]
Condition : ((isnotnull(mean#27) AND isnotnull(stdev#26)) AND (NOT (mean#27 = 0.0) AND ((stdev#26 / mean#27) > 1.0)))
Condition : ((isnotnull(stdev#26) AND isnotnull(mean#27)) AND (NOT coalesce((mean#27 = 0.0), false) AND ((stdev#26 / mean#27) > 1.0)))
(23) Project [codegen id : 10]
Output [5]: [w_warehouse_sk#8, i_item_sk#6, d_moy#12, mean#27, CASE WHEN (mean#27 = 0.0) THEN null ELSE (stdev#26 / mean#27) END AS cov#28]
@ -223,7 +223,7 @@ Results [5]: [w_warehouse_sk#35, i_item_sk#34, d_moy#38, stddev_samp(cast(inv_qu
(39) Filter [codegen id : 9]
Input [5]: [w_warehouse_sk#35, i_item_sk#34, d_moy#38, stdev#26, mean#27]
Condition : ((isnotnull(mean#27) AND isnotnull(stdev#26)) AND (NOT (mean#27 = 0.0) AND ((stdev#26 / mean#27) > 1.0)))
Condition : ((isnotnull(stdev#26) AND isnotnull(mean#27)) AND (NOT coalesce((mean#27 = 0.0), false) AND ((stdev#26 / mean#27) > 1.0)))
(40) Project [codegen id : 9]
Output [5]: [w_warehouse_sk#35, i_item_sk#34, d_moy#38, mean#27 AS mean#50, CASE WHEN (mean#27 = 0.0) THEN null ELSE (stdev#26 / mean#27) END AS cov#51]

View file

@ -5,7 +5,7 @@ WholeStageCodegen (11)
WholeStageCodegen (10)
BroadcastHashJoin [i_item_sk,w_warehouse_sk,i_item_sk,w_warehouse_sk]
Project [w_warehouse_sk,i_item_sk,d_moy,mean,stdev]
Filter [mean,stdev]
Filter [stdev,mean]
HashAggregate [w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy,n,avg,m2,sum,count] [stddev_samp(cast(inv_quantity_on_hand as double)),avg(inv_quantity_on_hand),stdev,mean,n,avg,m2,sum,count]
InputAdapter
Exchange [w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy] #2
@ -49,7 +49,7 @@ WholeStageCodegen (11)
BroadcastExchange #6
WholeStageCodegen (9)
Project [w_warehouse_sk,i_item_sk,d_moy,mean,stdev]
Filter [mean,stdev]
Filter [stdev,mean]
HashAggregate [w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy,n,avg,m2,sum,count] [stddev_samp(cast(inv_quantity_on_hand as double)),avg(inv_quantity_on_hand),stdev,mean,n,avg,m2,sum,count]
InputAdapter
Exchange [w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy] #7

View file

@ -149,7 +149,7 @@ Results [5]: [w_warehouse_sk#10, i_item_sk#8, d_moy#7, stddev_samp(cast(inv_quan
(22) Filter [codegen id : 5]
Input [5]: [w_warehouse_sk#10, i_item_sk#8, d_moy#7, stdev#26, mean#27]
Condition : ((isnotnull(mean#27) AND isnotnull(stdev#26)) AND ((NOT (mean#27 = 0.0) AND ((stdev#26 / mean#27) > 1.0)) AND ((stdev#26 / mean#27) > 1.5)))
Condition : ((isnotnull(stdev#26) AND isnotnull(mean#27)) AND ((NOT coalesce((mean#27 = 0.0), false) AND ((stdev#26 / mean#27) > 1.0)) AND ((stdev#26 / mean#27) > 1.5)))
(23) Project [codegen id : 5]
Output [5]: [w_warehouse_sk#10, i_item_sk#8, d_moy#7, mean#27, CASE WHEN (mean#27 = 0.0) THEN null ELSE (stdev#26 / mean#27) END AS cov#28]
@ -234,7 +234,7 @@ Results [5]: [w_warehouse_sk#38, i_item_sk#37, d_moy#36, stddev_samp(cast(inv_qu
(41) Filter [codegen id : 11]
Input [5]: [w_warehouse_sk#38, i_item_sk#37, d_moy#36, stdev#26, mean#27]
Condition : ((isnotnull(mean#27) AND isnotnull(stdev#26)) AND (NOT (mean#27 = 0.0) AND ((stdev#26 / mean#27) > 1.0)))
Condition : ((isnotnull(stdev#26) AND isnotnull(mean#27)) AND (NOT coalesce((mean#27 = 0.0), false) AND ((stdev#26 / mean#27) > 1.0)))
(42) Project [codegen id : 11]
Output [5]: [w_warehouse_sk#38, i_item_sk#37, d_moy#36, mean#27 AS mean#51, CASE WHEN (mean#27 = 0.0) THEN null ELSE (stdev#26 / mean#27) END AS cov#52]

View file

@ -11,7 +11,7 @@ WholeStageCodegen (14)
Exchange [i_item_sk,w_warehouse_sk] #2
WholeStageCodegen (5)
Project [w_warehouse_sk,i_item_sk,d_moy,mean,stdev]
Filter [mean,stdev]
Filter [stdev,mean]
HashAggregate [w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy,n,avg,m2,sum,count] [stddev_samp(cast(inv_quantity_on_hand as double)),avg(inv_quantity_on_hand),stdev,mean,n,avg,m2,sum,count]
InputAdapter
Exchange [w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy] #3
@ -58,7 +58,7 @@ WholeStageCodegen (14)
Exchange [i_item_sk,w_warehouse_sk] #7
WholeStageCodegen (11)
Project [w_warehouse_sk,i_item_sk,d_moy,mean,stdev]
Filter [mean,stdev]
Filter [stdev,mean]
HashAggregate [w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy,n,avg,m2,sum,count] [stddev_samp(cast(inv_quantity_on_hand as double)),avg(inv_quantity_on_hand),stdev,mean,n,avg,m2,sum,count]
InputAdapter
Exchange [w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy] #8

View file

@ -146,7 +146,7 @@ Results [5]: [w_warehouse_sk#8, i_item_sk#6, d_moy#12, stddev_samp(cast(inv_quan
(22) Filter [codegen id : 10]
Input [5]: [w_warehouse_sk#8, i_item_sk#6, d_moy#12, stdev#26, mean#27]
Condition : ((isnotnull(mean#27) AND isnotnull(stdev#26)) AND ((NOT (mean#27 = 0.0) AND ((stdev#26 / mean#27) > 1.0)) AND ((stdev#26 / mean#27) > 1.5)))
Condition : ((isnotnull(stdev#26) AND isnotnull(mean#27)) AND ((NOT coalesce((mean#27 = 0.0), false) AND ((stdev#26 / mean#27) > 1.0)) AND ((stdev#26 / mean#27) > 1.5)))
(23) Project [codegen id : 10]
Output [5]: [w_warehouse_sk#8, i_item_sk#6, d_moy#12, mean#27, CASE WHEN (mean#27 = 0.0) THEN null ELSE (stdev#26 / mean#27) END AS cov#28]
@ -223,7 +223,7 @@ Results [5]: [w_warehouse_sk#35, i_item_sk#34, d_moy#38, stddev_samp(cast(inv_qu
(39) Filter [codegen id : 9]
Input [5]: [w_warehouse_sk#35, i_item_sk#34, d_moy#38, stdev#26, mean#27]
Condition : ((isnotnull(mean#27) AND isnotnull(stdev#26)) AND (NOT (mean#27 = 0.0) AND ((stdev#26 / mean#27) > 1.0)))
Condition : ((isnotnull(stdev#26) AND isnotnull(mean#27)) AND (NOT coalesce((mean#27 = 0.0), false) AND ((stdev#26 / mean#27) > 1.0)))
(40) Project [codegen id : 9]
Output [5]: [w_warehouse_sk#35, i_item_sk#34, d_moy#38, mean#27 AS mean#50, CASE WHEN (mean#27 = 0.0) THEN null ELSE (stdev#26 / mean#27) END AS cov#51]

View file

@ -5,7 +5,7 @@ WholeStageCodegen (11)
WholeStageCodegen (10)
BroadcastHashJoin [i_item_sk,w_warehouse_sk,i_item_sk,w_warehouse_sk]
Project [w_warehouse_sk,i_item_sk,d_moy,mean,stdev]
Filter [mean,stdev]
Filter [stdev,mean]
HashAggregate [w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy,n,avg,m2,sum,count] [stddev_samp(cast(inv_quantity_on_hand as double)),avg(inv_quantity_on_hand),stdev,mean,n,avg,m2,sum,count]
InputAdapter
Exchange [w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy] #2
@ -49,7 +49,7 @@ WholeStageCodegen (11)
BroadcastExchange #6
WholeStageCodegen (9)
Project [w_warehouse_sk,i_item_sk,d_moy,mean,stdev]
Filter [mean,stdev]
Filter [stdev,mean]
HashAggregate [w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy,n,avg,m2,sum,count] [stddev_samp(cast(inv_quantity_on_hand as double)),avg(inv_quantity_on_hand),stdev,mean,n,avg,m2,sum,count]
InputAdapter
Exchange [w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy] #7