From 52c5ff20ca132653f505040a4dff522b136d2626 Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Tue, 14 Sep 2021 09:26:11 +0900 Subject: [PATCH] [SPARK-36715][SQL] InferFiltersFromGenerate should not infer filter for udf ### What changes were proposed in this pull request? Fix InferFiltersFromGenerate bug, InferFiltersFromGenerate should not infer filter for generate when the children contain an expression which is instance of `org.apache.spark.sql.catalyst.expressions.UserDefinedExpression`. Before this pr, the following case will throw an exception. ```scala spark.udf.register("vec", (i: Int) => (0 until i).toArray) sql("select explode(vec(8)) as c1").show ``` ``` Once strategy's idempotence is broken for batch Infer Filters GlobalLimit 21 GlobalLimit 21 +- LocalLimit 21 +- LocalLimit 21 +- Project [cast(c1#3 as string) AS c1#12] +- Project [cast(c1#3 as string) AS c1#12] +- Generate explode(vec(8)), false, [c1#3] +- Generate explode(vec(8)), false, [c1#3] +- Filter ((size(vec(8), true) > 0) AND isnotnull(vec(8))) +- Filter ((size(vec(8), true) > 0) AND isnotnull(vec(8))) ! +- OneRowRelation +- Filter ((size(vec(8), true) > 0) AND isnotnull(vec(8))) ! +- OneRowRelation java.lang.RuntimeException: Once strategy's idempotence is broken for batch Infer Filters GlobalLimit 21 GlobalLimit 21 +- LocalLimit 21 +- LocalLimit 21 +- Project [cast(c1#3 as string) AS c1#12] +- Project [cast(c1#3 as string) AS c1#12] +- Generate explode(vec(8)), false, [c1#3] +- Generate explode(vec(8)), false, [c1#3] +- Filter ((size(vec(8), true) > 0) AND isnotnull(vec(8))) +- Filter ((size(vec(8), true) > 0) AND isnotnull(vec(8))) ! +- OneRowRelation +- Filter ((size(vec(8), true) > 0) AND isnotnull(vec(8))) ! +- OneRowRelation at org.apache.spark.sql.errors.QueryExecutionErrors$.onceStrategyIdempotenceIsBrokenForBatchError(QueryExecutionErrors.scala:1200) at org.apache.spark.sql.catalyst.rules.RuleExecutor.checkBatchIdempotence(RuleExecutor.scala:168) at org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$execute$1(RuleExecutor.scala:254) at org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$execute$1$adapted(RuleExecutor.scala:200) at scala.collection.immutable.List.foreach(List.scala:431) at org.apache.spark.sql.catalyst.rules.RuleExecutor.execute(RuleExecutor.scala:200) at org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$executeAndTrack$1(RuleExecutor.scala:179) at org.apache.spark.sql.catalyst.QueryPlanningTracker$.withTracker(QueryPlanningTracker.scala:88) at org.apache.spark.sql.catalyst.rules.RuleExecutor.executeAndTrack(RuleExecutor.scala:179) at org.apache.spark.sql.execution.QueryExecution.$anonfun$optimizedPlan$1(QueryExecution.scala:138) at org.apache.spark.sql.catalyst.QueryPlanningTracker.measurePhase(QueryPlanningTracker.scala:111) at org.apache.spark.sql.execution.QueryExecution.$anonfun$executePhase$1(QueryExecution.scala:196) at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:775) at org.apache.spark.sql.execution.QueryExecution.executePhase(QueryExecution.scala:196) at org.apache.spark.sql.execution.QueryExecution.optimizedPlan$lzycompute(QueryExecution.scala:134) at org.apache.spark.sql.execution.QueryExecution.optimizedPlan(QueryExecution.scala:130) at org.apache.spark.sql.execution.QueryExecution.assertOptimized(QueryExecution.scala:148) at org.apache.spark.sql.execution.QueryExecution.$anonfun$executedPlan$1(QueryExecution.scala:166) at org.apache.spark.sql.execution.QueryExecution.withCteMap(QueryExecution.scala:73) at org.apache.spark.sql.execution.QueryExecution.executedPlan$lzycompute(QueryExecution.scala:163) at org.apache.spark.sql.execution.QueryExecution.executedPlan(QueryExecution.scala:163) at org.apache.spark.sql.execution.QueryExecution.simpleString(QueryExecution.scala:214) at org.apache.spark.sql.execution.QueryExecution.org$apache$spark$sql$execution$QueryExecution$$explainString(QueryExecution.scala:259) at org.apache.spark.sql.execution.QueryExecution.explainString(QueryExecution.scala:228) at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$5(SQLExecution.scala:98) at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:163) at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:90) at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:775) at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:64) at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3731) at org.apache.spark.sql.Dataset.head(Dataset.scala:2755) at org.apache.spark.sql.Dataset.take(Dataset.scala:2962) at org.apache.spark.sql.Dataset.getRows(Dataset.scala:288) at org.apache.spark.sql.Dataset.showString(Dataset.scala:327) at org.apache.spark.sql.Dataset.show(Dataset.scala:807) ``` ### Does this PR introduce _any_ user-facing change? No, only bug fix. ### How was this patch tested? Unit test. Closes #33956 from cfmcgrady/SPARK-36715. Authored-by: Fu Chen Signed-off-by: Hyukjin Kwon --- .../sql/catalyst/optimizer/Optimizer.scala | 3 ++- .../InferFiltersFromGenerateSuite.scala | 24 ++++++++++++++++++- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index ed4ab9cb16..bf974c73d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1099,7 +1099,8 @@ object InferFiltersFromGenerate extends Rule[LogicalPlan] { // like 'size([1, 2, 3]) > 0'. These do not show up in child's constraints and // then the idempotence will break. case generate @ Generate(e, _, _, _, _, _) - if !e.deterministic || e.children.forall(_.foldable) => generate + if !e.deterministic || e.children.forall(_.foldable) || + e.children.exists(_.isInstanceOf[UserDefinedExpression]) => generate case generate @ Generate(g, _, false, _, _, _) if canInferFilters(g) => // Exclude child's constraints to guarantee idempotency diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromGenerateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromGenerateSuite.scala index 93a1d414ed..800d37eaa0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromGenerateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromGenerateSuite.scala @@ -17,14 +17,16 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, StructField, StructType} class InferFiltersFromGenerateSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { @@ -111,4 +113,24 @@ class InferFiltersFromGenerateSuite extends PlanTest { comparePlans(optimized, originalQuery) } } + + test("SPARK-36715: Don't infer filters from udf") { + Seq(Explode(_), PosExplode(_), Inline(_)).foreach { f => + val returnSchema = ArrayType(StructType(Seq( + StructField("x", IntegerType), + StructField("y", StringType) + ))) + val fakeUDF = ScalaUDF( + (i: Int) => Array(Row.fromSeq(Seq(1, "a")), Row.fromSeq(Seq(2, "b"))), + returnSchema, Literal(8) :: Nil, + Option(ExpressionEncoder[Int]().resolveAndBind()) :: Nil) + val generator = f(fakeUDF) + val originalQuery = OneRowRelation().generate(generator).analyze + val optimized = OptimizeInferAndConstantFold.execute(originalQuery) + val correctAnswer = OneRowRelation() + .generate(generator) + .analyze + comparePlans(optimized, correctAnswer) + } + } }