[SPARK-35449][SQL] Only extract common expressions from CaseWhen values if elseValue is set

### What changes were proposed in this pull request?

This PR fixes a bug in subexpression elimination for CaseWhen statements. https://github.com/apache/spark/pull/30245 added support for creating subexpressions from expressions that are present in all branches of conditional statements. However, for an expression to be in "all branches" of a CaseWhen statement, it must also be in the elseValue; without an elseValue, the statement can evaluate to NULL without evaluating any branch value.
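To make the semantics concrete, here is a minimal sketch using the DataFrame API (the `expensive` UDF is hypothetical, standing in for any costly expression): an expression can appear in every branch value and still not be evaluated on every path unless an `otherwise` is present.
```scala
import org.apache.spark.sql.functions.{col, udf, when}

// Hypothetical stand-in for any costly expression.
val expensive = udf((i: Long) => i * 2)

// `expensive(col("id"))` appears in every branch value, but there is no
// otherwise(...): if neither condition matches, the result is NULL and
// the expression must not be evaluated at all.
val noElse = when(col("id") > 0, expensive(col("id")) + 1)
  .when(col("id") < 0, expensive(col("id")) - 1)

// With an otherwise(...), the expression is evaluated on every path, so
// extracting it as a common subexpression is safe.
val withElse = noElse.otherwise(expensive(col("id")))
```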

### Why are the changes needed?

Fixes a bug where a subexpression could be created and evaluated for branches of a conditional that are never taken. This is especially problematic with a UDF in a branch that is only meant to execute when its condition is true.

### Does this PR introduce _any_ user-facing change?

Yes, it fixes a potential bug where a UDF could be eagerly executed even though it might expect to have already passed some form of validation. For example:
```scala
val col = when($"id" < 0, myUdf($"id"))
spark.range(1).select(when(col > 0, col)).show()
```

`myUdf($"id")` is considered a subexpression and eagerly evaluated, because it is pulled out as a common expression from both executions of the when clause, but if `id >= 0` it should never actually be run.

### How was this patch tested?

Updated existing test with new case.

Closes #32595 from Kimahriman/bug-case-subexpr-elimination.

Authored-by: Adam Binford <adamq43@gmail.com>
Signed-off-by: Liang-Chi Hsieh <viirya@gmail.com>
3 changed files with 29 additions and 5 deletions

`EquivalentExpressions.scala`:
```diff
@@ -143,7 +143,13 @@ class EquivalentExpressions {
       // a subexpression among values doesn't need to be in conditions because no matter which
       // condition is true, it will be evaluated.
       val conditions = c.branches.tail.map(_._1)
-      val values = c.branches.map(_._2) ++ c.elseValue
+      // For an expression to be in all branch values of a CaseWhen statement, it must also be in
+      // the elseValue.
+      val values = if (c.elseValue.nonEmpty) {
+        c.branches.map(_._2) ++ c.elseValue
+      } else {
+        Nil
+      }
       Seq(conditions, values)
     case c: Coalesce => Seq(c.children.tail)
     case _ => Nil
```
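To make the new behavior concrete, here is a hedged sketch in the style of the test suite below; the no-else assertion mirrors the new test, while the `withElse` expectation is my reading of the counting logic rather than something the suite asserts verbatim:
```scala
import org.apache.spark.sql.catalyst.expressions._

val add1 = Add(Literal(1), Literal(2))
val add2 = Add(Literal(2), Literal(3))
val conditions = (GreaterThan(add1, Literal(3)), add1) ::
  (GreaterThan(add2, Literal(4)), add1) ::
  (GreaterThan(add2, Literal(5)), add1) :: Nil

// No elseValue: if every condition is false, the CaseWhen returns NULL
// without evaluating any branch value, so `add1` is not common to all
// paths and nothing is extracted.
val noElse = new EquivalentExpressions
noElse.addExprTree(CaseWhen(conditions, None))
assert(noElse.getAllEquivalentExprs().count(_.size == 2) == 0)

// With `add1` as the elseValue, every path evaluates it, so it can again
// be extracted as a common subexpression.
val withElse = new EquivalentExpressions
withElse.addExprTree(CaseWhen(conditions, Some(add1)))
assert(withElse.getAllEquivalentExprs().exists(_.size == 2))
```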

`SubexpressionEliminationSuite.scala`:
```diff
@@ -209,7 +209,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
       (GreaterThan(add2, Literal(4)), add1) ::
       (GreaterThan(add2, Literal(5)), add1) :: Nil
 
-    val caseWhenExpr2 = CaseWhen(conditions2, None)
+    val caseWhenExpr2 = CaseWhen(conditions2, add1)
     val equivalence2 = new EquivalentExpressions
     equivalence2.addExprTree(caseWhenExpr2)
@@ -317,7 +317,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
     val add3 = Add(add1, add2)
     val condition = (GreaterThan(add3, Literal(3)), add3) :: Nil
 
-    val caseWhenExpr = CaseWhen(condition, None)
+    val caseWhenExpr = CaseWhen(condition, Add(add3, Literal(1)))
     val equivalence = new EquivalentExpressions
     equivalence.addExprTree(caseWhenExpr)
@@ -354,6 +354,22 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
     assert(equivalence2.getAllEquivalentExprs() ===
       Seq(Seq(add, add), Seq(Add(Literal(3), add), Add(Literal(3), add))))
   }
+
+  test("SPARK-35449: Subexpressions should only be extracted from CaseWhen values with an " +
+    "elseValue") {
+    val add1 = Add(Literal(1), Literal(2))
+    val add2 = Add(Literal(2), Literal(3))
+    val conditions = (GreaterThan(add1, Literal(3)), add1) ::
+      (GreaterThan(add2, Literal(4)), add1) ::
+      (GreaterThan(add2, Literal(5)), add1) :: Nil
+
+    val caseWhenExpr = CaseWhen(conditions, None)
+    val equivalence = new EquivalentExpressions
+    equivalence.addExprTree(caseWhenExpr)
+
+    // `add1` is not in the elseValue, so we can't extract it from the branches
+    assert(equivalence.getAllEquivalentExprs().count(_.size == 2) == 0)
+  }
 }
 
 case class CodegenFallbackExpression(child: Expression)
```

`DataFrameSuite.scala`:
```diff
@@ -2870,13 +2870,15 @@ class DataFrameSuite extends QueryTest
       s
     })
     val df1 = spark.range(5).select(when(functions.length(simpleUDF($"id")) > 0,
-      functions.length(simpleUDF($"id"))))
+      functions.length(simpleUDF($"id"))).otherwise(
+        functions.length(simpleUDF($"id")) + 1))
     df1.collect()
     assert(accum.value == 5)
 
     val nondeterministicUDF = simpleUDF.asNondeterministic()
     val df2 = spark.range(5).select(when(functions.length(nondeterministicUDF($"id")) > 0,
-      functions.length(nondeterministicUDF($"id"))))
+      functions.length(nondeterministicUDF($"id"))).otherwise(
+        functions.length(nondeterministicUDF($"id")) + 1))
     df2.collect()
     assert(accum.value == 15)
   }
```
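A note on the expected counts (my reading of the updated test): with the `otherwise` branch added, the deterministic UDF is a legitimate common subexpression and runs once per row, so 5 rows leave the shared accumulator at 5. The nondeterministic variant is excluded from subexpression elimination, so the condition and the taken branch each invoke it, adding 2 per row for a cumulative total of 5 + 10 = 15.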