diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
index 72771e0aad..1cec7ed3a5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, BoundReference, Expression, Predicate}
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.internal.SQLConf
 
 object ExternalCatalogUtils {
   // This duplicates default value of Hive `ConfVars.DEFAULTPARTITIONNAME`, since catalyst doesn't
@@ -132,6 +133,19 @@ object ExternalCatalogUtils {
     escapePathName(col) + "=" + partitionString
   }
 
+  def listPartitionsByFilter(
+      conf: SQLConf,
+      catalog: SessionCatalog,
+      table: CatalogTable,
+      partitionFilters: Seq[Expression]): Seq[CatalogTablePartition] = {
+    if (conf.metastorePartitionPruning) {
+      catalog.listPartitionsByFilter(table.identifier, partitionFilters)
+    } else {
+      ExternalCatalogUtils.prunePartitionsByFilter(table, catalog.listPartitions(table.identifier),
+        partitionFilters, conf.sessionLocalTimeZone)
+    }
+  }
+
   def prunePartitionsByFilter(
       catalogTable: CatalogTable,
       inputPartitions: Seq[CatalogTablePartition],
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index a1c344a25d..b9663bb380 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -979,9 +979,7 @@ object SQLConf {
   val HIVE_METASTORE_PARTITION_PRUNING =
     buildConf("spark.sql.hive.metastorePartitionPruning")
       .doc("When true, some predicates will be pushed down into the Hive metastore so that " +
-        "unmatching partitions can be eliminated earlier. This only affects Hive tables " +
-        "not converted to filesource relations (see HiveUtils.CONVERT_METASTORE_PARQUET and " +
-        "HiveUtils.CONVERT_METASTORE_ORC for more information).")
+        "unmatching partitions can be eliminated earlier.")
       .version("1.5.0")
       .booleanConf
       .createWithDefault(true)
@@ -1005,7 +1003,8 @@ object SQLConf {
       .doc("When true, enable metastore partition management for file source tables as well. " +
         "This includes both datasource and converted Hive tables. When partition management " +
         "is enabled, datasource tables store partition in the Hive metastore, and use the " +
-        "metastore to prune partitions during query planning.")
+        s"metastore to prune partitions during query planning when " +
+        s"${HIVE_METASTORE_PARTITION_PRUNING.key} is set to true.")
       .version("2.1.1")
       .booleanConf
       .createWithDefault(true)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala
index 91313f33a7..727b33018f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala
@@ -23,7 +23,7 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.catalog.CatalogTable
+import org.apache.spark.sql.catalyst.catalog.{CatalogTable, ExternalCatalogUtils}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types.StructType
 
@@ -70,8 +70,8 @@ class CatalogFileIndex(
   def filterPartitions(filters: Seq[Expression]): InMemoryFileIndex = {
     if (table.partitionColumnNames.nonEmpty) {
       val startTime = System.nanoTime()
-      val selectedPartitions = sparkSession.sessionState.catalog.listPartitionsByFilter(
-        table.identifier, filters)
+      val selectedPartitions = ExternalCatalogUtils.listPartitionsByFilter(
+        sparkSession.sessionState.conf, sparkSession.sessionState.catalog, table, filters)
       val partitions = selectedPartitions.map { p =>
         val path = new Path(p.location)
         val fs = path.getFileSystem(hadoopConf)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala
index a831e8b898..1bd47d7d7a 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala
@@ -56,22 +56,6 @@ private[sql] class PruneHiveTablePartitions(session: SparkSession)
       normalizedFilters.flatMap(extractPredicatesWithinOutputSet(_, partitionColumnSet)))
   }
 
-  /**
-   * Prune the hive table using filters on the partitions of the table.
-   */
-  private def prunePartitions(
-      relation: HiveTableRelation,
-      partitionFilters: ExpressionSet): Seq[CatalogTablePartition] = {
-    if (conf.metastorePartitionPruning) {
-      session.sessionState.catalog.listPartitionsByFilter(
-        relation.tableMeta.identifier, partitionFilters.toSeq)
-    } else {
-      ExternalCatalogUtils.prunePartitionsByFilter(relation.tableMeta,
-        session.sessionState.catalog.listPartitions(relation.tableMeta.identifier),
-        partitionFilters.toSeq, conf.sessionLocalTimeZone)
-    }
-  }
-
   /**
    * Update the statistics of the table.
    */
@@ -111,7 +95,8 @@ private[sql] class PruneHiveTablePartitions(session: SparkSession)
         if filters.nonEmpty && relation.isPartitioned && relation.prunedPartitions.isEmpty =>
       val partitionKeyFilters = getPartitionKeyFilters(filters, relation)
       if (partitionKeyFilters.nonEmpty) {
-        val newPartitions = prunePartitions(relation, partitionKeyFilters)
+        val newPartitions = ExternalCatalogUtils.listPartitionsByFilter(conf,
+          session.sessionState.catalog, relation.tableMeta, partitionKeyFilters.toSeq)
         val newTableMeta = updateTableMeta(relation, newPartitions, partitionKeyFilters)
         val newRelation = relation.copy(
           tableMeta = newTableMeta, prunedPartitions = Some(newPartitions))
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala
index a16545a742..a669b803f0 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.hive.execution
 
 import org.scalatest.matchers.should.Matchers._
 
+import org.apache.spark.metrics.source.HiveCatalogMetrics
+import org.apache.spark.sql.{QueryTest, Row}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
@@ -115,7 +117,7 @@ class PruneFileSourcePartitionsSuite extends PrunePartitionSuiteBase {
     withSQLConf((SQLConf.USE_V1_SOURCE_LIST.key, "")) {
       withTempPath { dir =>
         spark.range(10).selectExpr("id", "id % 3 as p")
-            .write.partitionBy("p").parquet(dir.getCanonicalPath)
+          .write.partitionBy("p").parquet(dir.getCanonicalPath)
         withTempView("tmp") {
           spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp");
           assertPrunedPartitions("SELECT COUNT(*) FROM tmp WHERE p = 0", 1, "(tmp.p = 0)")
@@ -125,6 +127,21 @@ class PruneFileSourcePartitionsSuite extends PrunePartitionSuiteBase {
     }
   }
 
+  test("SPARK-36128: spark.sql.hive.metastorePartitionPruning should work for file data sources") {
+    Seq(true, false).foreach { enablePruning =>
+      withTable("tbl") {
+        withSQLConf(SQLConf.HIVE_METASTORE_PARTITION_PRUNING.key -> enablePruning.toString) {
+          spark.range(10).selectExpr("id", "id % 3 as p").write.partitionBy("p").saveAsTable("tbl")
+          HiveCatalogMetrics.reset()
+          QueryTest.checkAnswer(sql("SELECT id FROM tbl WHERE p = 1"),
+            Seq(1, 4, 7).map(Row.apply(_)), checkToRDD = false) // avoid analyzing the query twice
+          val expectedCount = if (enablePruning) 1 else 3
+          assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount == expectedCount)
+        }
+      }
+    }
+  }
+
   override def getScanExecPartitionSize(plan: SparkPlan): Long = {
     plan.collectFirst {
       case p: FileSourceScanExec => p.selectedPartitions.length