From 8c203272de047eb3382a2b78d68308de91e6d02d Mon Sep 17 00:00:00 2001
From: Chao Sun
Date: Thu, 29 Jul 2021 17:18:23 -0700
Subject: [PATCH] [SPARK-36136][SQL][TESTS] Refactor PruneFileSourcePartitionsSuite etc to a different package

### What changes were proposed in this pull request?

Move both `PruneFileSourcePartitionsSuite` and `PrunePartitionSuiteBase` to the package `org.apache.spark.sql.execution.datasources`. Did some minor refactoring to enable this.

### Why are the changes needed?

Currently both `PruneFileSourcePartitionsSuite` and `PrunePartitionSuiteBase` are in the package `org.apache.spark.sql.hive.execution`, which is misleading since these tests are not specific to Hive. Therefore, it's better to move them into `org.apache.spark.sql.execution.datasources`, the same package where the rule `PruneFileSourcePartitions` is defined.

### Does this PR introduce _any_ user-facing change?

No, this is test-only refactoring.

### How was this patch tested?

Using existing tests:

```
build/sbt "sql/testOnly *PruneFileSourcePartitionsSuite"
```

and

```
build/sbt "hive/testOnly *PruneHiveTablePartitionsSuite"
```

Closes #33564 from sunchao/SPARK-36136-partitions-suite.

Authored-by: Chao Sun
Signed-off-by: Liang-Chi Hsieh
(cherry picked from commit 0ece865ea4b78f8144defcadd143fccf3dc99743)
Signed-off-by: Liang-Chi Hsieh
---
 .../PruneFileSourcePartitionsSuite.scala      | 70 +++++++------------
 .../PrunePartitionSuiteBase.scala             | 17 ++---
 .../PruneHiveTablePartitionsSuite.scala       | 26 ++++++-
 3 files changed, 59 insertions(+), 54 deletions(-)
 rename sql/{hive/src/test/scala/org/apache/spark/sql/hive/execution => core/src/test/scala/org/apache/spark/sql/execution/datasources}/PruneFileSourcePartitionsSuite.scala (66%)
 rename sql/{hive/src/test/scala/org/apache/spark/sql/hive/execution => core/src/test/scala/org/apache/spark/sql/execution/datasources}/PrunePartitionSuiteBase.scala (90%)

diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitionsSuite.scala
similarity index 66%
rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala
rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitionsSuite.scala
index a669b803f0..98d3d65bef 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitionsSuite.scala
@@ -15,27 +15,26 @@
  * limitations under the License.
  */

-package org.apache.spark.sql.hive.execution
+package org.apache.spark.sql.execution.datasources

 import org.scalatest.matchers.should.Matchers._

-import org.apache.spark.metrics.source.HiveCatalogMetrics
-import org.apache.spark.sql.{QueryTest, Row}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
-import org.apache.spark.sql.execution.datasources.{CatalogFileIndex, HadoopFsRelation, LogicalRelation, PruneFileSourcePartitions}
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
 import org.apache.spark.sql.functions.broadcast
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.StructType

-class PruneFileSourcePartitionsSuite extends PrunePartitionSuiteBase {
+class PruneFileSourcePartitionsSuite extends PrunePartitionSuiteBase with SharedSparkSession {

   override def format: String = "parquet"

@@ -45,35 +44,27 @@ class PruneFileSourcePartitionsSuite extends PrunePartitionSuiteBase {

   test("PruneFileSourcePartitions should not change the output of LogicalRelation") {
     withTable("test") {
-      withTempDir { dir =>
-        sql(
-          s"""
-            |CREATE EXTERNAL TABLE test(i int)
-            |PARTITIONED BY (p int)
-            |STORED AS parquet
-            |LOCATION '${dir.toURI}'""".stripMargin)
+      spark.range(10).selectExpr("id", "id % 3 as p").write.partitionBy("p").saveAsTable("test")
+      val tableMeta = spark.sharedState.externalCatalog.getTable("default", "test")
+      val catalogFileIndex = new CatalogFileIndex(spark, tableMeta, 0)

-        val tableMeta = spark.sharedState.externalCatalog.getTable("default", "test")
-        val catalogFileIndex = new CatalogFileIndex(spark, tableMeta, 0)
+      val dataSchema = StructType(tableMeta.schema.filterNot { f =>
+        tableMeta.partitionColumnNames.contains(f.name)
+      })
+      val relation = HadoopFsRelation(
+        location = catalogFileIndex,
+        partitionSchema = tableMeta.partitionSchema,
+        dataSchema = dataSchema,
+        bucketSpec = None,
+        fileFormat = new ParquetFileFormat(),
+        options = Map.empty)(sparkSession = spark)

-        val dataSchema = StructType(tableMeta.schema.filterNot { f =>
-          tableMeta.partitionColumnNames.contains(f.name)
-        })
-        val relation = HadoopFsRelation(
-          location = catalogFileIndex,
-          partitionSchema = tableMeta.partitionSchema,
-          dataSchema = dataSchema,
-          bucketSpec = None,
-          fileFormat = new ParquetFileFormat(),
-          options = Map.empty)(sparkSession = spark)
+      val logicalRelation = LogicalRelation(relation, tableMeta)
+      val query = Project(Seq(Symbol("id"), Symbol("p")),
+        Filter(Symbol("p") === 1, logicalRelation)).analyze

-        val logicalRelation = LogicalRelation(relation, tableMeta)
-        val query = Project(Seq(Symbol("i"), Symbol("p")),
-          Filter(Symbol("p") === 1, logicalRelation)).analyze
-
-        val optimized = Optimize.execute(query)
-        assert(optimized.missingInput.isEmpty)
-      }
+      val optimized = Optimize.execute(query)
+      assert(optimized.missingInput.isEmpty)
     }
   }

@@ -116,7 +107,7 @@ class PruneFileSourcePartitionsSuite extends PrunePartitionSuiteBase {
     // Force datasource v2 for parquet
     withSQLConf((SQLConf.USE_V1_SOURCE_LIST.key, "")) {
       withTempPath { dir =>
-        spark.range(10).selectExpr("id", "id % 3 as p")
+        spark.range(10).coalesce(1).selectExpr("id", "id % 3 as p")
           .write.partitionBy("p").parquet(dir.getCanonicalPath)
         withTempView("tmp") {
           spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp");
@@ -127,19 +118,8 @@ class PruneFileSourcePartitionsSuite extends PrunePartitionSuiteBase {
     }
   }

-  test("SPARK-36128: spark.sql.hive.metastorePartitionPruning should work for file data sources") {
-    Seq(true, false).foreach { enablePruning =>
-      withTable("tbl") {
-        withSQLConf(SQLConf.HIVE_METASTORE_PARTITION_PRUNING.key -> enablePruning.toString) {
-          spark.range(10).selectExpr("id", "id % 3 as p").write.partitionBy("p").saveAsTable("tbl")
-          HiveCatalogMetrics.reset()
-          QueryTest.checkAnswer(sql("SELECT id FROM tbl WHERE p = 1"),
-            Seq(1, 4, 7).map(Row.apply(_)), checkToRDD = false) // avoid analyzing the query twice
-          val expectedCount = if (enablePruning) 1 else 3
-          assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount == expectedCount)
-        }
-      }
-    }
+  protected def collectPartitionFiltersFn(): PartialFunction[SparkPlan, Seq[Expression]] = {
+    case scan: FileSourceScanExec => scan.partitionFilters
   }

   override def getScanExecPartitionSize(plan: SparkPlan): Long = {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PrunePartitionSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PrunePartitionSuiteBase.scala
similarity index 90%
rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PrunePartitionSuiteBase.scala
rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PrunePartitionSuiteBase.scala
index 2a690a8105..9909996059 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PrunePartitionSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PrunePartitionSuiteBase.scala
@@ -15,16 +15,15 @@
  * limitations under the License.
  */

-package org.apache.spark.sql.hive.execution
+package org.apache.spark.sql.execution.datasources

 import org.apache.spark.sql.StatisticsCollectionTestBase
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BinaryOperator, Expression, IsNotNull, Literal}
-import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
+import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, FileScan}
-import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.internal.SQLConf.ADAPTIVE_EXECUTION_ENABLED

-abstract class PrunePartitionSuiteBase extends StatisticsCollectionTestBase with TestHiveSingleton {
+abstract class PrunePartitionSuiteBase extends StatisticsCollectionTestBase {

   protected def format: String

@@ -95,11 +94,11 @@ abstract class PrunePartitionSuiteBase extends StatisticsCollectionTestBase with
     val plan = qe.sparkPlan
     assert(getScanExecPartitionSize(plan) == expectedPartitionCount)

-    val pushedDownPartitionFilters = plan.collectFirst {
-      case scan: FileSourceScanExec => scan.partitionFilters
-      case scan: HiveTableScanExec => scan.partitionPruningPred
+    val collectFn: PartialFunction[SparkPlan, Seq[Expression]] = collectPartitionFiltersFn orElse {
       case BatchScanExec(_, scan: FileScan, _) => scan.partitionFilters
-    }.map(exps => exps.filterNot(e => e.isInstanceOf[IsNotNull]))
+    }
+    val pushedDownPartitionFilters = plan.collectFirst(collectFn)
+      .map(exps => exps.filterNot(e => e.isInstanceOf[IsNotNull]))
     val pushedFilters = pushedDownPartitionFilters.map(filters => {
       filters.foldLeft("")((currentStr, exp) => {
         if (currentStr == "") {
@@ -113,5 +112,7 @@ abstract class PrunePartitionSuiteBase extends StatisticsCollectionTestBase with
     assert(pushedFilters == Some(expectedPushedDownFilters))
   }

+  protected def collectPartitionFiltersFn(): PartialFunction[SparkPlan, Seq[Expression]]
+
   protected def getScanExecPartitionSize(plan: SparkPlan): Long
 }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala
index 677b250960..95a02d5517 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala
@@ -17,14 +17,19 @@

 package org.apache.spark.sql.hive.execution

+import org.apache.spark.metrics.source.HiveCatalogMetrics
+import org.apache.spark.sql.{QueryTest, Row}
 import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.datasources.PrunePartitionSuiteBase
+import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.LongType

-class PruneHiveTablePartitionsSuite extends PrunePartitionSuiteBase {
+class PruneHiveTablePartitionsSuite extends PrunePartitionSuiteBase with TestHiveSingleton {

   override def format(): String = "hive"

@@ -131,6 +136,25 @@ class PruneHiveTablePartitionsSuite extends PrunePartitionSuiteBase {
     }
   }

+  test("SPARK-36128: spark.sql.hive.metastorePartitionPruning should work for file data sources") {
+    Seq(true, false).foreach { enablePruning =>
+      withTable("tbl") {
+        withSQLConf(SQLConf.HIVE_METASTORE_PARTITION_PRUNING.key -> enablePruning.toString) {
+          spark.range(10).selectExpr("id", "id % 3 as p").write.partitionBy("p").saveAsTable("tbl")
+          HiveCatalogMetrics.reset()
+          QueryTest.checkAnswer(sql("SELECT id FROM tbl WHERE p = 1"),
+            Seq(1, 4, 7).map(Row.apply(_)), checkToRDD = false) // avoid analyzing the query twice
+          val expectedCount = if (enablePruning) 1 else 3
+          assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount == expectedCount)
+        }
+      }
+    }
+  }
+
+  protected def collectPartitionFiltersFn(): PartialFunction[SparkPlan, Seq[Expression]] = {
+    case scan: HiveTableScanExec => scan.partitionPruningPred
+  }
+
   override def getScanExecPartitionSize(plan: SparkPlan): Long = {
     plan.collectFirst {
       case p: HiveTableScanExec => p
     }.get.prunedPartitions.size
   }
 }
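Note: the heart of this refactoring is the new template-method hook `collectPartitionFiltersFn`, which replaces the Hive-specific `collectFirst` cases that used to live in the base suite. Below is a minimal, self-contained sketch of that pattern, not Spark code: the `Plan` hierarchy, node names, and suite names are stand-ins invented for illustration.

```scala
// Stand-in plan nodes; in Spark these roles are played by
// FileSourceScanExec, HiveTableScanExec and BatchScanExec.
sealed trait Plan
case class FileScanNode(partitionFilters: Seq[String]) extends Plan
case class HiveScanNode(partitionPruningPred: Seq[String]) extends Plan
case class BatchScanNode(partitionFilters: Seq[String]) extends Plan

abstract class PrunePartitionBase {
  // Hook: each concrete suite contributes only the scan cases it knows about.
  protected def collectPartitionFiltersFn: PartialFunction[Plan, Seq[String]]

  // Shared logic: compose the subclass cases with the common v2 case via
  // `orElse`, then take the filters of the first matching node.
  def pushedDownPartitionFilters(nodes: Seq[Plan]): Option[Seq[String]] = {
    val collectFn: PartialFunction[Plan, Seq[String]] =
      collectPartitionFiltersFn orElse {
        case BatchScanNode(filters) => filters
      }
    nodes.collectFirst(collectFn)
  }
}

// The file-source suite only needs to know about FileScanNode...
class FileSourceSuiteLike extends PrunePartitionBase {
  override protected def collectPartitionFiltersFn: PartialFunction[Plan, Seq[String]] = {
    case FileScanNode(filters) => filters
  }
}

// ...and the Hive suite only about HiveScanNode.
class HiveSuiteLike extends PrunePartitionBase {
  override protected def collectPartitionFiltersFn: PartialFunction[Plan, Seq[String]] = {
    case HiveScanNode(preds) => preds
  }
}

object Demo extends App {
  val fileSuite = new FileSourceSuiteLike
  // Subclass case matches the file scan node: Some(List(p = 1))
  println(fileSuite.pushedDownPartitionFilters(
    Seq(HiveScanNode(Seq("x")), FileScanNode(Seq("p = 1")))))
  // Shared BatchScanNode case matches via orElse: Some(List(p = 2))
  println(fileSuite.pushedDownPartitionFilters(Seq(BatchScanNode(Seq("p = 2")))))
}
```

This composition is what lets `PrunePartitionSuiteBase` drop its reference to `HiveTableScanExec` and move into sql/core, while the Hive-only match stays with the Hive suite.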