[SPARK-36686][SQL] Fix SimplifyConditionalsInPredicate to be null-safe
### What changes were proposed in this pull request?

Fix `SimplifyConditionalsInPredicate` to be null-safe. When `cond` is NULL, `IF(cond, FALSE, falseVal)` evaluates to `falseVal`, but the previous rewrite `AND(NOT(cond), falseVal)` can never evaluate to TRUE, so matching rows were dropped. The rule now rewrites `NOT(cond)` as `NOT(COALESCE(cond, FALSE))`.

Reproducible:

```
import org.apache.spark.sql.types.{StructField, BooleanType, StructType}
import org.apache.spark.sql.Row

val schema = List(
  StructField("b", BooleanType, true)
)
val data = Seq(
  Row(true),
  Row(false),
  Row(null)
)
val df = spark.createDataFrame(
  spark.sparkContext.parallelize(data),
  StructType(schema)
)

// cartesian product of true / false / null
val df2 = df.select(col("b") as "cond").crossJoin(df.select(col("b") as "falseVal"))
df2.createOrReplaceTempView("df2")

spark.sql("SELECT * FROM df2 WHERE IF(cond, FALSE, falseVal)").show()
// actual:
// +-----+--------+
// | cond|falseVal|
// +-----+--------+
// |false|    true|
// +-----+--------+

spark.sql("SET spark.sql.optimizer.excludedRules=org.apache.spark.sql.catalyst.optimizer.SimplifyConditionalsInPredicate")
spark.sql("SELECT * FROM df2 WHERE IF(cond, FALSE, falseVal)").show()
// expected:
// +-----+--------+
// | cond|falseVal|
// +-----+--------+
// |false|    true|
// | null|    true|
// +-----+--------+
```

### Why are the changes needed?

This is a regression that leads to incorrect results.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing tests.

Closes #33928 from hypercubestart/fix-SimplifyConditionalsInPredicate.

Authored-by: Andrew Liu <andrewlliu@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
67421d80b8
commit
9b633f2075
|
@ -17,7 +17,7 @@
|
|||
|
||||
package org.apache.spark.sql.catalyst.optimizer
|
||||
|
||||
import org.apache.spark.sql.catalyst.expressions.{And, CaseWhen, Expression, If, Literal, Not, Or}
|
||||
import org.apache.spark.sql.catalyst.expressions.{And, CaseWhen, Coalesce, Expression, If, Literal, Not, Or}
|
||||
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
|
||||
import org.apache.spark.sql.catalyst.plans.logical._
|
||||
import org.apache.spark.sql.catalyst.rules.Rule
|
||||
|
@ -28,13 +28,13 @@ import org.apache.spark.sql.types.BooleanType
|
|||
* A rule that converts conditional expressions to predicate expressions, if possible, in the
|
||||
* search condition of the WHERE/HAVING/ON(JOIN) clauses, which contain an implicit Boolean operator
|
||||
* "(search condition) = TRUE". After this converting, we can potentially push the filter down to
|
||||
* the data source.
|
||||
* the data source. This rule is null-safe.
|
||||
*
|
||||
* Supported cases are:
|
||||
* - IF(cond, trueVal, false) => AND(cond, trueVal)
|
||||
* - IF(cond, trueVal, true) => OR(NOT(COALESCE(cond, false)), trueVal)
|
||||
* - IF(cond, false, falseVal) => AND(NOT(cond), elseVal)
|
||||
* - IF(cond, true, falseVal) => OR(cond, elseVal)
|
||||
* - IF(cond, false, falseVal) => AND(NOT(COALESCE(cond, false)), falseVal)
|
||||
* - IF(cond, true, falseVal) => OR(cond, falseVal)
|
||||
* - CASE WHEN cond THEN trueVal ELSE false END => AND(cond, trueVal)
|
||||
* - CASE WHEN cond THEN trueVal END => AND(cond, trueVal)
|
||||
* - CASE WHEN cond THEN trueVal ELSE null END => AND(cond, trueVal)
|
||||
|
@ -56,16 +56,17 @@ object SimplifyConditionalsInPredicate extends Rule[LogicalPlan] {
|
|||
case And(left, right) => And(simplifyConditional(left), simplifyConditional(right))
|
||||
case Or(left, right) => Or(simplifyConditional(left), simplifyConditional(right))
|
||||
case If(cond, trueValue, FalseLiteral) => And(cond, trueValue)
|
||||
case If(cond, trueValue, TrueLiteral) => Or(Not(cond), trueValue)
|
||||
case If(cond, FalseLiteral, falseValue) => And(Not(cond), falseValue)
|
||||
case If(cond, trueValue, TrueLiteral) => Or(Not(Coalesce(Seq(cond, FalseLiteral))), trueValue)
|
||||
case If(cond, FalseLiteral, falseValue) =>
|
||||
And(Not(Coalesce(Seq(cond, FalseLiteral))), falseValue)
|
||||
case If(cond, TrueLiteral, falseValue) => Or(cond, falseValue)
|
||||
case CaseWhen(Seq((cond, trueValue)),
|
||||
Some(FalseLiteral) | Some(Literal(null, BooleanType)) | None) =>
|
||||
And(cond, trueValue)
|
||||
case CaseWhen(Seq((cond, trueValue)), Some(TrueLiteral)) =>
|
||||
Or(Not(cond), trueValue)
|
||||
Or(Not(Coalesce(Seq(cond, FalseLiteral))), trueValue)
|
||||
case CaseWhen(Seq((cond, FalseLiteral)), Some(elseValue)) =>
|
||||
And(Not(cond), elseValue)
|
||||
And(Not(Coalesce(Seq(cond, FalseLiteral))), elseValue)
|
||||
case CaseWhen(Seq((cond, TrueLiteral)), Some(elseValue)) =>
|
||||
Or(cond, elseValue)
|
||||
case e if e.dataType == BooleanType => e
|
||||
|
|
|
@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException
|
|||
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
|
||||
import org.apache.spark.sql.catalyst.dsl.expressions._
|
||||
import org.apache.spark.sql.catalyst.dsl.plans._
|
||||
import org.apache.spark.sql.catalyst.expressions.{And, CaseWhen, Expression, If, IsNotNull, Literal, Or, Rand}
|
||||
import org.apache.spark.sql.catalyst.expressions.{And, CaseWhen, Coalesce, Expression, If, IsNotNull, Literal, Not, Or, Rand}
|
||||
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
|
||||
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, LocalRelation, LogicalPlan, UpdateTable}
|
||||
|
@ -65,7 +65,7 @@ class SimplifyConditionalsInPredicateSuite extends PlanTest {
|
|||
UnresolvedAttribute("b"),
|
||||
TrueLiteral)
|
||||
val expectedCond = Or(
|
||||
UnresolvedAttribute("i") <= Literal(10),
|
||||
Not(Coalesce(Seq(UnresolvedAttribute("i") > Literal(10), FalseLiteral))),
|
||||
UnresolvedAttribute("b"))
|
||||
testFilter(originalCond, expectedCond = expectedCond)
|
||||
testJoin(originalCond, expectedCond = expectedCond)
|
||||
|
@ -80,7 +80,7 @@ class SimplifyConditionalsInPredicateSuite extends PlanTest {
|
|||
FalseLiteral,
|
||||
UnresolvedAttribute("b"))
|
||||
val expectedCond = And(
|
||||
UnresolvedAttribute("i") <= Literal(10),
|
||||
Not(Coalesce(Seq(UnresolvedAttribute("i") > Literal(10), FalseLiteral))),
|
||||
UnresolvedAttribute("b"))
|
||||
testFilter(originalCond, expectedCond = expectedCond)
|
||||
testJoin(originalCond, expectedCond = expectedCond)
|
||||
|
@ -125,7 +125,7 @@ class SimplifyConditionalsInPredicateSuite extends PlanTest {
|
|||
Seq((UnresolvedAttribute("i") > Literal(10), UnresolvedAttribute("b"))),
|
||||
TrueLiteral)
|
||||
val expectedCond = Or(
|
||||
UnresolvedAttribute("i") <= Literal(10),
|
||||
Not(Coalesce(Seq(UnresolvedAttribute("i") > Literal(10), FalseLiteral))),
|
||||
UnresolvedAttribute("b"))
|
||||
testFilter(originalCond, expectedCond = expectedCond)
|
||||
testJoin(originalCond, expectedCond = expectedCond)
|
||||
|
@ -139,7 +139,7 @@ class SimplifyConditionalsInPredicateSuite extends PlanTest {
|
|||
Seq((UnresolvedAttribute("i") > Literal(10), FalseLiteral)),
|
||||
UnresolvedAttribute("b"))
|
||||
val expectedCond = And(
|
||||
UnresolvedAttribute("i") <= Literal(10),
|
||||
Not(Coalesce(Seq(UnresolvedAttribute("i") > Literal(10), FalseLiteral))),
|
||||
UnresolvedAttribute("b"))
|
||||
testFilter(originalCond, expectedCond = expectedCond)
|
||||
testJoin(originalCond, expectedCond = expectedCond)
|
||||
|
|
|
@ -149,7 +149,7 @@ Results [5]: [w_warehouse_sk#10, i_item_sk#8, d_moy#7, stddev_samp(cast(inv_quan
|
|||
|
||||
(22) Filter [codegen id : 5]
|
||||
Input [5]: [w_warehouse_sk#10, i_item_sk#8, d_moy#7, stdev#26, mean#27]
|
||||
Condition : ((isnotnull(mean#27) AND isnotnull(stdev#26)) AND (NOT (mean#27 = 0.0) AND ((stdev#26 / mean#27) > 1.0)))
|
||||
Condition : ((isnotnull(stdev#26) AND isnotnull(mean#27)) AND (NOT coalesce((mean#27 = 0.0), false) AND ((stdev#26 / mean#27) > 1.0)))
|
||||
|
||||
(23) Project [codegen id : 5]
|
||||
Output [5]: [w_warehouse_sk#10, i_item_sk#8, d_moy#7, mean#27, CASE WHEN (mean#27 = 0.0) THEN null ELSE (stdev#26 / mean#27) END AS cov#28]
|
||||
|
@ -234,7 +234,7 @@ Results [5]: [w_warehouse_sk#38, i_item_sk#37, d_moy#36, stddev_samp(cast(inv_qu
|
|||
|
||||
(41) Filter [codegen id : 11]
|
||||
Input [5]: [w_warehouse_sk#38, i_item_sk#37, d_moy#36, stdev#26, mean#27]
|
||||
Condition : ((isnotnull(mean#27) AND isnotnull(stdev#26)) AND (NOT (mean#27 = 0.0) AND ((stdev#26 / mean#27) > 1.0)))
|
||||
Condition : ((isnotnull(stdev#26) AND isnotnull(mean#27)) AND (NOT coalesce((mean#27 = 0.0), false) AND ((stdev#26 / mean#27) > 1.0)))
|
||||
|
||||
(42) Project [codegen id : 11]
|
||||
Output [5]: [w_warehouse_sk#38, i_item_sk#37, d_moy#36, mean#27 AS mean#51, CASE WHEN (mean#27 = 0.0) THEN null ELSE (stdev#26 / mean#27) END AS cov#52]
|
||||
|
|
|
@ -11,7 +11,7 @@ WholeStageCodegen (14)
|
|||
Exchange [i_item_sk,w_warehouse_sk] #2
|
||||
WholeStageCodegen (5)
|
||||
Project [w_warehouse_sk,i_item_sk,d_moy,mean,stdev]
|
||||
Filter [mean,stdev]
|
||||
Filter [stdev,mean]
|
||||
HashAggregate [w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy,n,avg,m2,sum,count] [stddev_samp(cast(inv_quantity_on_hand as double)),avg(inv_quantity_on_hand),stdev,mean,n,avg,m2,sum,count]
|
||||
InputAdapter
|
||||
Exchange [w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy] #3
|
||||
|
@ -58,7 +58,7 @@ WholeStageCodegen (14)
|
|||
Exchange [i_item_sk,w_warehouse_sk] #7
|
||||
WholeStageCodegen (11)
|
||||
Project [w_warehouse_sk,i_item_sk,d_moy,mean,stdev]
|
||||
Filter [mean,stdev]
|
||||
Filter [stdev,mean]
|
||||
HashAggregate [w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy,n,avg,m2,sum,count] [stddev_samp(cast(inv_quantity_on_hand as double)),avg(inv_quantity_on_hand),stdev,mean,n,avg,m2,sum,count]
|
||||
InputAdapter
|
||||
Exchange [w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy] #8
|
||||
|
|
|
@ -146,7 +146,7 @@ Results [5]: [w_warehouse_sk#8, i_item_sk#6, d_moy#12, stddev_samp(cast(inv_quan
|
|||
|
||||
(22) Filter [codegen id : 10]
|
||||
Input [5]: [w_warehouse_sk#8, i_item_sk#6, d_moy#12, stdev#26, mean#27]
|
||||
Condition : ((isnotnull(mean#27) AND isnotnull(stdev#26)) AND (NOT (mean#27 = 0.0) AND ((stdev#26 / mean#27) > 1.0)))
|
||||
Condition : ((isnotnull(stdev#26) AND isnotnull(mean#27)) AND (NOT coalesce((mean#27 = 0.0), false) AND ((stdev#26 / mean#27) > 1.0)))
|
||||
|
||||
(23) Project [codegen id : 10]
|
||||
Output [5]: [w_warehouse_sk#8, i_item_sk#6, d_moy#12, mean#27, CASE WHEN (mean#27 = 0.0) THEN null ELSE (stdev#26 / mean#27) END AS cov#28]
|
||||
|
@ -223,7 +223,7 @@ Results [5]: [w_warehouse_sk#35, i_item_sk#34, d_moy#38, stddev_samp(cast(inv_qu
|
|||
|
||||
(39) Filter [codegen id : 9]
|
||||
Input [5]: [w_warehouse_sk#35, i_item_sk#34, d_moy#38, stdev#26, mean#27]
|
||||
Condition : ((isnotnull(mean#27) AND isnotnull(stdev#26)) AND (NOT (mean#27 = 0.0) AND ((stdev#26 / mean#27) > 1.0)))
|
||||
Condition : ((isnotnull(stdev#26) AND isnotnull(mean#27)) AND (NOT coalesce((mean#27 = 0.0), false) AND ((stdev#26 / mean#27) > 1.0)))
|
||||
|
||||
(40) Project [codegen id : 9]
|
||||
Output [5]: [w_warehouse_sk#35, i_item_sk#34, d_moy#38, mean#27 AS mean#50, CASE WHEN (mean#27 = 0.0) THEN null ELSE (stdev#26 / mean#27) END AS cov#51]
|
||||
|
|
|
@ -5,7 +5,7 @@ WholeStageCodegen (11)
|
|||
WholeStageCodegen (10)
|
||||
BroadcastHashJoin [i_item_sk,w_warehouse_sk,i_item_sk,w_warehouse_sk]
|
||||
Project [w_warehouse_sk,i_item_sk,d_moy,mean,stdev]
|
||||
Filter [mean,stdev]
|
||||
Filter [stdev,mean]
|
||||
HashAggregate [w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy,n,avg,m2,sum,count] [stddev_samp(cast(inv_quantity_on_hand as double)),avg(inv_quantity_on_hand),stdev,mean,n,avg,m2,sum,count]
|
||||
InputAdapter
|
||||
Exchange [w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy] #2
|
||||
|
@ -49,7 +49,7 @@ WholeStageCodegen (11)
|
|||
BroadcastExchange #6
|
||||
WholeStageCodegen (9)
|
||||
Project [w_warehouse_sk,i_item_sk,d_moy,mean,stdev]
|
||||
Filter [mean,stdev]
|
||||
Filter [stdev,mean]
|
||||
HashAggregate [w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy,n,avg,m2,sum,count] [stddev_samp(cast(inv_quantity_on_hand as double)),avg(inv_quantity_on_hand),stdev,mean,n,avg,m2,sum,count]
|
||||
InputAdapter
|
||||
Exchange [w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy] #7
|
||||
|
|
|
@ -149,7 +149,7 @@ Results [5]: [w_warehouse_sk#10, i_item_sk#8, d_moy#7, stddev_samp(cast(inv_quan
|
|||
|
||||
(22) Filter [codegen id : 5]
|
||||
Input [5]: [w_warehouse_sk#10, i_item_sk#8, d_moy#7, stdev#26, mean#27]
|
||||
Condition : ((isnotnull(mean#27) AND isnotnull(stdev#26)) AND ((NOT (mean#27 = 0.0) AND ((stdev#26 / mean#27) > 1.0)) AND ((stdev#26 / mean#27) > 1.5)))
|
||||
Condition : ((isnotnull(stdev#26) AND isnotnull(mean#27)) AND ((NOT coalesce((mean#27 = 0.0), false) AND ((stdev#26 / mean#27) > 1.0)) AND ((stdev#26 / mean#27) > 1.5)))
|
||||
|
||||
(23) Project [codegen id : 5]
|
||||
Output [5]: [w_warehouse_sk#10, i_item_sk#8, d_moy#7, mean#27, CASE WHEN (mean#27 = 0.0) THEN null ELSE (stdev#26 / mean#27) END AS cov#28]
|
||||
|
@ -234,7 +234,7 @@ Results [5]: [w_warehouse_sk#38, i_item_sk#37, d_moy#36, stddev_samp(cast(inv_qu
|
|||
|
||||
(41) Filter [codegen id : 11]
|
||||
Input [5]: [w_warehouse_sk#38, i_item_sk#37, d_moy#36, stdev#26, mean#27]
|
||||
Condition : ((isnotnull(mean#27) AND isnotnull(stdev#26)) AND (NOT (mean#27 = 0.0) AND ((stdev#26 / mean#27) > 1.0)))
|
||||
Condition : ((isnotnull(stdev#26) AND isnotnull(mean#27)) AND (NOT coalesce((mean#27 = 0.0), false) AND ((stdev#26 / mean#27) > 1.0)))
|
||||
|
||||
(42) Project [codegen id : 11]
|
||||
Output [5]: [w_warehouse_sk#38, i_item_sk#37, d_moy#36, mean#27 AS mean#51, CASE WHEN (mean#27 = 0.0) THEN null ELSE (stdev#26 / mean#27) END AS cov#52]
|
||||
|
|
|
@ -11,7 +11,7 @@ WholeStageCodegen (14)
|
|||
Exchange [i_item_sk,w_warehouse_sk] #2
|
||||
WholeStageCodegen (5)
|
||||
Project [w_warehouse_sk,i_item_sk,d_moy,mean,stdev]
|
||||
Filter [mean,stdev]
|
||||
Filter [stdev,mean]
|
||||
HashAggregate [w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy,n,avg,m2,sum,count] [stddev_samp(cast(inv_quantity_on_hand as double)),avg(inv_quantity_on_hand),stdev,mean,n,avg,m2,sum,count]
|
||||
InputAdapter
|
||||
Exchange [w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy] #3
|
||||
|
@ -58,7 +58,7 @@ WholeStageCodegen (14)
|
|||
Exchange [i_item_sk,w_warehouse_sk] #7
|
||||
WholeStageCodegen (11)
|
||||
Project [w_warehouse_sk,i_item_sk,d_moy,mean,stdev]
|
||||
Filter [mean,stdev]
|
||||
Filter [stdev,mean]
|
||||
HashAggregate [w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy,n,avg,m2,sum,count] [stddev_samp(cast(inv_quantity_on_hand as double)),avg(inv_quantity_on_hand),stdev,mean,n,avg,m2,sum,count]
|
||||
InputAdapter
|
||||
Exchange [w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy] #8
|
||||
|
|
|
@ -146,7 +146,7 @@ Results [5]: [w_warehouse_sk#8, i_item_sk#6, d_moy#12, stddev_samp(cast(inv_quan
|
|||
|
||||
(22) Filter [codegen id : 10]
|
||||
Input [5]: [w_warehouse_sk#8, i_item_sk#6, d_moy#12, stdev#26, mean#27]
|
||||
Condition : ((isnotnull(mean#27) AND isnotnull(stdev#26)) AND ((NOT (mean#27 = 0.0) AND ((stdev#26 / mean#27) > 1.0)) AND ((stdev#26 / mean#27) > 1.5)))
|
||||
Condition : ((isnotnull(stdev#26) AND isnotnull(mean#27)) AND ((NOT coalesce((mean#27 = 0.0), false) AND ((stdev#26 / mean#27) > 1.0)) AND ((stdev#26 / mean#27) > 1.5)))
|
||||
|
||||
(23) Project [codegen id : 10]
|
||||
Output [5]: [w_warehouse_sk#8, i_item_sk#6, d_moy#12, mean#27, CASE WHEN (mean#27 = 0.0) THEN null ELSE (stdev#26 / mean#27) END AS cov#28]
|
||||
|
@ -223,7 +223,7 @@ Results [5]: [w_warehouse_sk#35, i_item_sk#34, d_moy#38, stddev_samp(cast(inv_qu
|
|||
|
||||
(39) Filter [codegen id : 9]
|
||||
Input [5]: [w_warehouse_sk#35, i_item_sk#34, d_moy#38, stdev#26, mean#27]
|
||||
Condition : ((isnotnull(mean#27) AND isnotnull(stdev#26)) AND (NOT (mean#27 = 0.0) AND ((stdev#26 / mean#27) > 1.0)))
|
||||
Condition : ((isnotnull(stdev#26) AND isnotnull(mean#27)) AND (NOT coalesce((mean#27 = 0.0), false) AND ((stdev#26 / mean#27) > 1.0)))
|
||||
|
||||
(40) Project [codegen id : 9]
|
||||
Output [5]: [w_warehouse_sk#35, i_item_sk#34, d_moy#38, mean#27 AS mean#50, CASE WHEN (mean#27 = 0.0) THEN null ELSE (stdev#26 / mean#27) END AS cov#51]
|
||||
|
|
|
@ -5,7 +5,7 @@ WholeStageCodegen (11)
|
|||
WholeStageCodegen (10)
|
||||
BroadcastHashJoin [i_item_sk,w_warehouse_sk,i_item_sk,w_warehouse_sk]
|
||||
Project [w_warehouse_sk,i_item_sk,d_moy,mean,stdev]
|
||||
Filter [mean,stdev]
|
||||
Filter [stdev,mean]
|
||||
HashAggregate [w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy,n,avg,m2,sum,count] [stddev_samp(cast(inv_quantity_on_hand as double)),avg(inv_quantity_on_hand),stdev,mean,n,avg,m2,sum,count]
|
||||
InputAdapter
|
||||
Exchange [w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy] #2
|
||||
|
@ -49,7 +49,7 @@ WholeStageCodegen (11)
|
|||
BroadcastExchange #6
|
||||
WholeStageCodegen (9)
|
||||
Project [w_warehouse_sk,i_item_sk,d_moy,mean,stdev]
|
||||
Filter [mean,stdev]
|
||||
Filter [stdev,mean]
|
||||
HashAggregate [w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy,n,avg,m2,sum,count] [stddev_samp(cast(inv_quantity_on_hand as double)),avg(inv_quantity_on_hand),stdev,mean,n,avg,m2,sum,count]
|
||||
InputAdapter
|
||||
Exchange [w_warehouse_name,w_warehouse_sk,i_item_sk,d_moy] #7
|
||||
|
|
Loading…
Reference in a new issue