From 5bc06fd7d9e756a7d40011e3cda57494859f692a Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Wed, 14 Jul 2021 15:57:10 +0800 Subject: [PATCH] [SPARK-36130][SQL] UnwrapCastInBinaryComparison should skip In expression when in.list contains an expression that is not literal ### What changes were proposed in this pull request? Fix [comment](https://github.com/apache/spark/pull/32488#issuecomment-879315179) This PR fixes a bug in the rule `UnwrapCastInBinaryComparison`. The rule should skip an `In` expression when `in.list` contains an expression that is not a literal. - In Before this PR, the following example throws an exception. ```scala withTable("tbl") { sql("CREATE TABLE tbl (d decimal(33, 27)) USING PARQUET") sql("SELECT d FROM tbl WHERE d NOT IN (d + 1)") } ``` - InSet Since the analyzer guarantees that all the elements in `inSet.hset` are literals, this is not an issue for `InSet`. https://github.com/apache/spark/blob/fbf53dee37129a493a4e5d5a007625b35f44fbda/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala#L264-L279 ### Does this PR introduce _any_ user-facing change? No, it is only a bug fix. ### How was this patch tested? New test. Closes #33335 from cfmcgrady/SPARK-36130. 
Authored-by: Fu Chen Signed-off-by: Wenchen Fan (cherry picked from commit 103d16e868e3caaa08401e0398c20b4a4574c6b7) Signed-off-by: Wenchen Fan --- .../UnwrapCastInBinaryComparison.scala | 3 ++- .../UnwrapCastInBinaryComparisonSuite.scala | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala index d5ff0fc349..08c4cbfe77 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala @@ -141,8 +141,9 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] { // values. // 2. this rule only handles the case when both `fromExp` and value in `in.list` are of numeric // type. + // 3. this rule doesn't optimize In when `in.list` contains an expression that is not literal. case in @ In(Cast(fromExp, toType: NumericType, _, _), list @ Seq(firstLit, _*)) - if canImplicitlyCast(fromExp, toType, firstLit.dataType) => + if canImplicitlyCast(fromExp, toType, firstLit.dataType) && in.inSetConvertible => // There are 3 kinds of literals in the list: // 1. 
null literals diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala index e5df1abf00..31f62cf28e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala @@ -283,6 +283,25 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp ) } + test("SPARK-36130: unwrap In should skip when in.list contains an expression that " + + "is not literal") { + val add = Cast(f2, DoubleType) + 1.0d + val doubleLit = Literal.create(null, DoubleType) + assertEquivalent(In(Cast(f2, DoubleType), Seq(add)), In(Cast(f2, DoubleType), Seq(add))) + assertEquivalent( + In(Cast(f2, DoubleType), Seq(doubleLit, add)), + In(Cast(f2, DoubleType), Seq(doubleLit, add))) + assertEquivalent( + In(Cast(f2, DoubleType), Seq(doubleLit, 1.0d, add)), + In(Cast(f2, DoubleType), Seq(doubleLit, 1.0d, add))) + assertEquivalent( + In(Cast(f2, DoubleType), Seq(1.0d, add)), + In(Cast(f2, DoubleType), Seq(1.0d, add))) + assertEquivalent( + In(Cast(f2, DoubleType), Seq(0.0d, 1.0d, add)), + In(Cast(f2, DoubleType), Seq(0.0d, 1.0d, add))) + } + private def castInt(e: Expression): Expression = Cast(e, IntegerType) private def castDouble(e: Expression): Expression = Cast(e, DoubleType) private def castDecimal2(e: Expression): Expression = Cast(e, DecimalType(10, 4))