diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExplainUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExplainUtils.scala index fc384fe117..d4fe272f8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExplainUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExplainUtils.scala @@ -193,14 +193,14 @@ object ExplainUtils { subqueries: ArrayBuffer[(SparkPlan, Expression, BaseSubqueryExec)]): Unit = { plan.foreach { case p: SparkPlan => - p.expressions.flatMap(_.collect { + p.expressions.foreach (_.collect { case e: PlanExpression[_] => e.plan match { case s: BaseSubqueryExec => subqueries += ((p, e, s)) getSubqueries(s, subqueries) + case _ => } - case other => }) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala index 125cff0e66..37183556d7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.StructType @@ -36,6 +37,19 @@ class ExplainSuite extends QueryTest with SharedSparkSession { f(normalizedOutput) } + /** + * Get the explain by running the sql. The explain mode should be part of the + * sql text itself. + */ + private def withNormalizedExplain(queryText: String)(f: String => Unit) = { + val output = new java.io.ByteArrayOutputStream() + Console.withOut(output) { + sql(queryText).show(false) + } + val normalizedOutput = output.toString.replaceAll("#\\d+", "#x") + f(normalizedOutput) + } + /** * Runs the plan and makes sure the plans contains all of the keywords. */ @@ -200,6 +214,41 @@ class ExplainSuite extends QueryTest with SharedSparkSession { } } } + + test("explain formatted - check presence of subquery in case of DPP") { + withTable("df1", "df2") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "false") { + withTable("df1", "df2") { + spark.range(1000).select(col("id"), col("id").as("k")) + .write + .partitionBy("k") + .format("parquet") + .mode("overwrite") + .saveAsTable("df1") + + spark.range(100) + .select(col("id"), col("id").as("k")) + .write + .partitionBy("k") + .format("parquet") + .mode("overwrite") + .saveAsTable("df2") + + val sqlText = + """ + |EXPLAIN FORMATTED SELECT df1.id, df2.k + |FROM df1 JOIN df2 ON df1.k = df2.k AND df2.id < 2 + |""".stripMargin + + val expected_pattern = "Subquery:1 Hosting operator id = 1 Hosting Expression = k#x" + withNormalizedExplain(sqlText) { normalizedOutput => + assert(expected_pattern.r.findAllMatchIn(normalizedOutput).length == 1) + } + } + } + } + } } case class ExplainSingleData(id: Int)