[SPARK-36715][SQL] InferFiltersFromGenerate should not infer filter for udf

### What changes were proposed in this pull request?

Fix InferFiltersFromGenerate bug, InferFiltersFromGenerate should not infer filter for generate when the children contain an expression which is instance of `org.apache.spark.sql.catalyst.expressions.UserDefinedExpression`.
Before this pr, the following case will throw an exception.

```scala
spark.udf.register("vec", (i: Int) => (0 until i).toArray)
sql("select explode(vec(8)) as c1").show
```

```
Once strategy's idempotence is broken for batch Infer Filters
 GlobalLimit 21                                                        GlobalLimit 21
 +- LocalLimit 21                                                      +- LocalLimit 21
    +- Project [cast(c1#3 as string) AS c1#12]                            +- Project [cast(c1#3 as string) AS c1#12]
       +- Generate explode(vec(8)), false, [c1#3]                            +- Generate explode(vec(8)), false, [c1#3]
          +- Filter ((size(vec(8), true) > 0) AND isnotnull(vec(8)))            +- Filter ((size(vec(8), true) > 0) AND isnotnull(vec(8)))
!            +- OneRowRelation                                                     +- Filter ((size(vec(8), true) > 0) AND isnotnull(vec(8)))
!                                                                                     +- OneRowRelation

java.lang.RuntimeException:
Once strategy's idempotence is broken for batch Infer Filters
 GlobalLimit 21                                                        GlobalLimit 21
 +- LocalLimit 21                                                      +- LocalLimit 21
    +- Project [cast(c1#3 as string) AS c1#12]                            +- Project [cast(c1#3 as string) AS c1#12]
       +- Generate explode(vec(8)), false, [c1#3]                            +- Generate explode(vec(8)), false, [c1#3]
          +- Filter ((size(vec(8), true) > 0) AND isnotnull(vec(8)))            +- Filter ((size(vec(8), true) > 0) AND isnotnull(vec(8)))
!            +- OneRowRelation                                                     +- Filter ((size(vec(8), true) > 0) AND isnotnull(vec(8)))
!                                                                                     +- OneRowRelation

	at org.apache.spark.sql.errors.QueryExecutionErrors$.onceStrategyIdempotenceIsBrokenForBatchError(QueryExecutionErrors.scala:1200)
	at org.apache.spark.sql.catalyst.rules.RuleExecutor.checkBatchIdempotence(RuleExecutor.scala:168)
	at org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$execute$1(RuleExecutor.scala:254)
	at org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$execute$1$adapted(RuleExecutor.scala:200)
	at scala.collection.immutable.List.foreach(List.scala:431)
	at org.apache.spark.sql.catalyst.rules.RuleExecutor.execute(RuleExecutor.scala:200)
	at org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$executeAndTrack$1(RuleExecutor.scala:179)
	at org.apache.spark.sql.catalyst.QueryPlanningTracker$.withTracker(QueryPlanningTracker.scala:88)
	at org.apache.spark.sql.catalyst.rules.RuleExecutor.executeAndTrack(RuleExecutor.scala:179)
	at org.apache.spark.sql.execution.QueryExecution.$anonfun$optimizedPlan$1(QueryExecution.scala:138)
	at org.apache.spark.sql.catalyst.QueryPlanningTracker.measurePhase(QueryPlanningTracker.scala:111)
	at org.apache.spark.sql.execution.QueryExecution.$anonfun$executePhase$1(QueryExecution.scala:196)
	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:775)
	at org.apache.spark.sql.execution.QueryExecution.executePhase(QueryExecution.scala:196)
	at org.apache.spark.sql.execution.QueryExecution.optimizedPlan$lzycompute(QueryExecution.scala:134)
	at org.apache.spark.sql.execution.QueryExecution.optimizedPlan(QueryExecution.scala:130)
	at org.apache.spark.sql.execution.QueryExecution.assertOptimized(QueryExecution.scala:148)
	at org.apache.spark.sql.execution.QueryExecution.$anonfun$executedPlan$1(QueryExecution.scala:166)
	at org.apache.spark.sql.execution.QueryExecution.withCteMap(QueryExecution.scala:73)
	at org.apache.spark.sql.execution.QueryExecution.executedPlan$lzycompute(QueryExecution.scala:163)
	at org.apache.spark.sql.execution.QueryExecution.executedPlan(QueryExecution.scala:163)
	at org.apache.spark.sql.execution.QueryExecution.simpleString(QueryExecution.scala:214)
	at org.apache.spark.sql.execution.QueryExecution.org$apache$spark$sql$execution$QueryExecution$$explainString(QueryExecution.scala:259)
	at org.apache.spark.sql.execution.QueryExecution.explainString(QueryExecution.scala:228)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$5(SQLExecution.scala:98)
	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:163)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:90)
	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:775)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:64)
	at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3731)
	at org.apache.spark.sql.Dataset.head(Dataset.scala:2755)
	at org.apache.spark.sql.Dataset.take(Dataset.scala:2962)
	at org.apache.spark.sql.Dataset.getRows(Dataset.scala:288)
	at org.apache.spark.sql.Dataset.showString(Dataset.scala:327)
	at org.apache.spark.sql.Dataset.show(Dataset.scala:807)
```

### Does this PR introduce _any_ user-facing change?

No, only bug fix.

### How was this patch tested?

Unit test.

Closes #33956 from cfmcgrady/SPARK-36715.

Authored-by: Fu Chen <cfmcgrady@gmail.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
This commit is contained in:
Fu Chen 2021-09-14 09:26:11 +09:00 committed by Hyukjin Kwon
parent a440025f08
commit 52c5ff20ca
2 changed files with 25 additions and 2 deletions

View file

@ -1099,7 +1099,8 @@ object InferFiltersFromGenerate extends Rule[LogicalPlan] {
// like 'size([1, 2, 3]) > 0'. These do not show up in child's constraints and
// then the idempotence will break.
case generate @ Generate(e, _, _, _, _, _)
if !e.deterministic || e.children.forall(_.foldable) => generate
if !e.deterministic || e.children.forall(_.foldable) ||
e.children.exists(_.isInstanceOf[UserDefinedExpression]) => generate
case generate @ Generate(g, _, false, _, _, _) if canInferFilters(g) =>
// Exclude child's constraints to guarantee idempotency

View file

@ -17,14 +17,16 @@
package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, StructField, StructType}
class InferFiltersFromGenerateSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
@ -111,4 +113,24 @@ class InferFiltersFromGenerateSuite extends PlanTest {
comparePlans(optimized, originalQuery)
}
}
test("SPARK-36715: Don't infer filters from udf") {
Seq(Explode(_), PosExplode(_), Inline(_)).foreach { f =>
val returnSchema = ArrayType(StructType(Seq(
StructField("x", IntegerType),
StructField("y", StringType)
)))
val fakeUDF = ScalaUDF(
(i: Int) => Array(Row.fromSeq(Seq(1, "a")), Row.fromSeq(Seq(2, "b"))),
returnSchema, Literal(8) :: Nil,
Option(ExpressionEncoder[Int]().resolveAndBind()) :: Nil)
val generator = f(fakeUDF)
val originalQuery = OneRowRelation().generate(generator).analyze
val optimized = OptimizeInferAndConstantFold.execute(originalQuery)
val correctAnswer = OneRowRelation()
.generate(generator)
.analyze
comparePlans(optimized, correctAnswer)
}
}
}