[SPARK-35449][SQL] Only extract common expressions from CaseWhen values if elseValue is set
### What changes were proposed in this pull request? This PR fixes a bug with subexpression elimination for CaseWhen statements. https://github.com/apache/spark/pull/30245 added support for creating subexpressions that are present in all branches of conditional statements. However, for a statement to be in "all branches" of a CaseWhen statement, it must also be in the elseValue. ### Why are the changes needed? Fix a bug where a subexpression can be created and run for branches of a conditional that don't pass. This can cause issues especially with a UDF in a branch that gets executed assuming the condition is true. ### Does this PR introduce _any_ user-facing change? Yes, fixes a potential bug where a UDF could be eagerly executed even though it might expect to have already passed some form of validation. For example: ``` val col = when($"id" < 0, myUdf($"id")) spark.range(1).select(when(col > 0, col)).show() ``` `myUdf($"id")` is considered a subexpression and eagerly evaluated, because it is pulled out as a common expression from both executions of the when clause, but if `id >= 0` it should never actually be run. ### How was this patch tested? Updated existing test with new case. Closes #32595 from Kimahriman/bug-case-subexpr-elimination. Authored-by: Adam Binford <adamq43@gmail.com> Signed-off-by: Liang-Chi Hsieh <viirya@gmail.com>
This commit is contained in:
parent
1530876615
commit
6c0c617bd0
|
@ -143,7 +143,13 @@ class EquivalentExpressions {
|
|||
// a subexpression among values doesn't need to be in conditions because no matter which
|
||||
// condition is true, it will be evaluated.
|
||||
val conditions = c.branches.tail.map(_._1)
|
||||
val values = c.branches.map(_._2) ++ c.elseValue
|
||||
// For an expression to be in all branch values of a CaseWhen statement, it must also be in
|
||||
// the elseValue.
|
||||
val values = if (c.elseValue.nonEmpty) {
|
||||
c.branches.map(_._2) ++ c.elseValue
|
||||
} else {
|
||||
Nil
|
||||
}
|
||||
Seq(conditions, values)
|
||||
case c: Coalesce => Seq(c.children.tail)
|
||||
case _ => Nil
|
||||
|
|
|
@ -209,7 +209,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
|
|||
(GreaterThan(add2, Literal(4)), add1) ::
|
||||
(GreaterThan(add2, Literal(5)), add1) :: Nil
|
||||
|
||||
val caseWhenExpr2 = CaseWhen(conditions2, None)
|
||||
val caseWhenExpr2 = CaseWhen(conditions2, add1)
|
||||
val equivalence2 = new EquivalentExpressions
|
||||
equivalence2.addExprTree(caseWhenExpr2)
|
||||
|
||||
|
@ -317,7 +317,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
|
|||
val add3 = Add(add1, add2)
|
||||
val condition = (GreaterThan(add3, Literal(3)), add3) :: Nil
|
||||
|
||||
val caseWhenExpr = CaseWhen(condition, None)
|
||||
val caseWhenExpr = CaseWhen(condition, Add(add3, Literal(1)))
|
||||
val equivalence = new EquivalentExpressions
|
||||
equivalence.addExprTree(caseWhenExpr)
|
||||
|
||||
|
@ -354,6 +354,22 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
|
|||
assert(equivalence2.getAllEquivalentExprs() ===
|
||||
Seq(Seq(add, add), Seq(Add(Literal(3), add), Add(Literal(3), add))))
|
||||
}
|
||||
|
||||
test("SPARK-35449: Subexpressions should only be extracted from CaseWhen values with an "
|
||||
+ "elseValue") {
|
||||
val add1 = Add(Literal(1), Literal(2))
|
||||
val add2 = Add(Literal(2), Literal(3))
|
||||
val conditions = (GreaterThan(add1, Literal(3)), add1) ::
|
||||
(GreaterThan(add2, Literal(4)), add1) ::
|
||||
(GreaterThan(add2, Literal(5)), add1) :: Nil
|
||||
|
||||
val caseWhenExpr = CaseWhen(conditions, None)
|
||||
val equivalence = new EquivalentExpressions
|
||||
equivalence.addExprTree(caseWhenExpr)
|
||||
|
||||
// `add1` is not in the elseValue, so we can't extract it from the branches
|
||||
assert(equivalence.getAllEquivalentExprs().count(_.size == 2) == 0)
|
||||
}
|
||||
}
|
||||
|
||||
case class CodegenFallbackExpression(child: Expression)
|
||||
|
|
|
@ -2870,13 +2870,15 @@ class DataFrameSuite extends QueryTest
|
|||
s
|
||||
})
|
||||
val df1 = spark.range(5).select(when(functions.length(simpleUDF($"id")) > 0,
|
||||
functions.length(simpleUDF($"id"))))
|
||||
functions.length(simpleUDF($"id"))).otherwise(
|
||||
functions.length(simpleUDF($"id")) + 1))
|
||||
df1.collect()
|
||||
assert(accum.value == 5)
|
||||
|
||||
val nondeterministicUDF = simpleUDF.asNondeterministic()
|
||||
val df2 = spark.range(5).select(when(functions.length(nondeterministicUDF($"id")) > 0,
|
||||
functions.length(nondeterministicUDF($"id"))))
|
||||
functions.length(nondeterministicUDF($"id"))).otherwise(
|
||||
functions.length(nondeterministicUDF($"id")) + 1))
|
||||
df2.collect()
|
||||
assert(accum.value == 15)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue