diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index aeb236ea5a..da23b96afa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1093,7 +1093,8 @@ object InferFiltersFromGenerate extends Rule[LogicalPlan] {
     // like 'size([1, 2, 3]) > 0'. These do not show up in child's constraints and
     // then the idempotence will break.
     case generate @ Generate(e, _, _, _, _, _)
-      if !e.deterministic || e.children.forall(_.foldable) => generate
+      if !e.deterministic || e.children.forall(_.foldable) ||
+        e.children.exists(_.isInstanceOf[UserDefinedExpression]) => generate
 
     case generate @ Generate(g, _, false, _, _, _) if canInferFilters(g) =>
       // Exclude child's constraints to guarantee idempotency
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromGenerateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromGenerateSuite.scala
index 93a1d414ed..800d37eaa0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromGenerateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromGenerateSuite.scala
@@ -17,14 +17,16 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
-import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, StructField, StructType}
 
 class InferFiltersFromGenerateSuite extends PlanTest {
   object Optimize extends RuleExecutor[LogicalPlan] {
@@ -111,4 +111,24 @@ class InferFiltersFromGenerateSuite extends PlanTest {
       comparePlans(optimized, originalQuery)
     }
   }
+
+  test("SPARK-36715: Don't infer filters from udf") {
+    Seq(Explode(_), PosExplode(_), Inline(_)).foreach { f =>
+      val returnSchema = ArrayType(StructType(Seq(
+        StructField("x", IntegerType),
+        StructField("y", StringType)
+      )))
+      val fakeUDF = ScalaUDF(
+        (i: Int) => Array(Row.fromSeq(Seq(1, "a")), Row.fromSeq(Seq(2, "b"))),
+        returnSchema, Literal(8) :: Nil,
+        Option(ExpressionEncoder[Int]().resolveAndBind()) :: Nil)
+      val generator = f(fakeUDF)
+      val originalQuery = OneRowRelation().generate(generator).analyze
+      val optimized = OptimizeInferAndConstantFold.execute(originalQuery)
+      val correctAnswer = OneRowRelation()
+        .generate(generator)
+        .analyze
+      comparePlans(optimized, correctAnswer)
+    }
+  }
 }