[SPARK-36136][SQL][TESTS] Refactor PruneFileSourcePartitionsSuite etc to a different package
### What changes were proposed in this pull request? Move both `PruneFileSourcePartitionsSuite` and `PrunePartitionSuiteBase` to the package `org.apache.spark.sql.execution.datasources`. Did a few refactoring to enable this. ### Why are the changes needed? Currently both `PruneFileSourcePartitionsSuite` and `PrunePartitionSuiteBase` are in package `org.apache.spark.sql.hive.execution` which doesn't look correct as these tests are not specific to Hive. Therefore, it's better to move them into `org.apache.spark.sql.execution.datasources`, the same place where the rule `PruneFileSourcePartitions` is at. ### Does this PR introduce _any_ user-facing change? No, it's just test refactoring. ### How was this patch tested? Using existing tests: ``` build/sbt "sql/testOnly *PruneFileSourcePartitionsSuite" ``` and ``` build/sbt "hive/testOnly *PruneHiveTablePartitionsSuite" ``` Closes #33564 from sunchao/SPARK-36136-partitions-suite. Authored-by: Chao Sun <sunchao@apple.com> Signed-off-by: Liang-Chi Hsieh <viirya@gmail.com>
This commit is contained in:
parent
900b38d5fa
commit
0ece865ea4
|
@ -15,27 +15,26 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.hive.execution
|
||||
package org.apache.spark.sql.execution.datasources
|
||||
|
||||
import org.scalatest.matchers.should.Matchers._
|
||||
|
||||
import org.apache.spark.metrics.source.HiveCatalogMetrics
|
||||
import org.apache.spark.sql.{QueryTest, Row}
|
||||
import org.apache.spark.sql.catalyst.TableIdentifier
|
||||
import org.apache.spark.sql.catalyst.dsl.expressions._
|
||||
import org.apache.spark.sql.catalyst.dsl.plans._
|
||||
import org.apache.spark.sql.catalyst.expressions.Expression
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
|
||||
import org.apache.spark.sql.catalyst.rules.RuleExecutor
|
||||
import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
|
||||
import org.apache.spark.sql.execution.datasources.{CatalogFileIndex, HadoopFsRelation, LogicalRelation, PruneFileSourcePartitions}
|
||||
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
|
||||
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
|
||||
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
|
||||
import org.apache.spark.sql.functions.broadcast
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.test.SharedSparkSession
|
||||
import org.apache.spark.sql.types.StructType
|
||||
|
||||
class PruneFileSourcePartitionsSuite extends PrunePartitionSuiteBase {
|
||||
class PruneFileSourcePartitionsSuite extends PrunePartitionSuiteBase with SharedSparkSession {
|
||||
|
||||
override def format: String = "parquet"
|
||||
|
||||
|
@ -45,35 +44,27 @@ class PruneFileSourcePartitionsSuite extends PrunePartitionSuiteBase {
|
|||
|
||||
test("PruneFileSourcePartitions should not change the output of LogicalRelation") {
|
||||
withTable("test") {
|
||||
withTempDir { dir =>
|
||||
sql(
|
||||
s"""
|
||||
|CREATE EXTERNAL TABLE test(i int)
|
||||
|PARTITIONED BY (p int)
|
||||
|STORED AS parquet
|
||||
|LOCATION '${dir.toURI}'""".stripMargin)
|
||||
spark.range(10).selectExpr("id", "id % 3 as p").write.partitionBy("p").saveAsTable("test")
|
||||
val tableMeta = spark.sharedState.externalCatalog.getTable("default", "test")
|
||||
val catalogFileIndex = new CatalogFileIndex(spark, tableMeta, 0)
|
||||
|
||||
val tableMeta = spark.sharedState.externalCatalog.getTable("default", "test")
|
||||
val catalogFileIndex = new CatalogFileIndex(spark, tableMeta, 0)
|
||||
val dataSchema = StructType(tableMeta.schema.filterNot { f =>
|
||||
tableMeta.partitionColumnNames.contains(f.name)
|
||||
})
|
||||
val relation = HadoopFsRelation(
|
||||
location = catalogFileIndex,
|
||||
partitionSchema = tableMeta.partitionSchema,
|
||||
dataSchema = dataSchema,
|
||||
bucketSpec = None,
|
||||
fileFormat = new ParquetFileFormat(),
|
||||
options = Map.empty)(sparkSession = spark)
|
||||
|
||||
val dataSchema = StructType(tableMeta.schema.filterNot { f =>
|
||||
tableMeta.partitionColumnNames.contains(f.name)
|
||||
})
|
||||
val relation = HadoopFsRelation(
|
||||
location = catalogFileIndex,
|
||||
partitionSchema = tableMeta.partitionSchema,
|
||||
dataSchema = dataSchema,
|
||||
bucketSpec = None,
|
||||
fileFormat = new ParquetFileFormat(),
|
||||
options = Map.empty)(sparkSession = spark)
|
||||
val logicalRelation = LogicalRelation(relation, tableMeta)
|
||||
val query = Project(Seq(Symbol("id"), Symbol("p")),
|
||||
Filter(Symbol("p") === 1, logicalRelation)).analyze
|
||||
|
||||
val logicalRelation = LogicalRelation(relation, tableMeta)
|
||||
val query = Project(Seq(Symbol("i"), Symbol("p")),
|
||||
Filter(Symbol("p") === 1, logicalRelation)).analyze
|
||||
|
||||
val optimized = Optimize.execute(query)
|
||||
assert(optimized.missingInput.isEmpty)
|
||||
}
|
||||
val optimized = Optimize.execute(query)
|
||||
assert(optimized.missingInput.isEmpty)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -116,7 +107,7 @@ class PruneFileSourcePartitionsSuite extends PrunePartitionSuiteBase {
|
|||
// Force datasource v2 for parquet
|
||||
withSQLConf((SQLConf.USE_V1_SOURCE_LIST.key, "")) {
|
||||
withTempPath { dir =>
|
||||
spark.range(10).selectExpr("id", "id % 3 as p")
|
||||
spark.range(10).coalesce(1).selectExpr("id", "id % 3 as p")
|
||||
.write.partitionBy("p").parquet(dir.getCanonicalPath)
|
||||
withTempView("tmp") {
|
||||
spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp");
|
||||
|
@ -127,19 +118,8 @@ class PruneFileSourcePartitionsSuite extends PrunePartitionSuiteBase {
|
|||
}
|
||||
}
|
||||
|
||||
test("SPARK-36128: spark.sql.hive.metastorePartitionPruning should work for file data sources") {
|
||||
Seq(true, false).foreach { enablePruning =>
|
||||
withTable("tbl") {
|
||||
withSQLConf(SQLConf.HIVE_METASTORE_PARTITION_PRUNING.key -> enablePruning.toString) {
|
||||
spark.range(10).selectExpr("id", "id % 3 as p").write.partitionBy("p").saveAsTable("tbl")
|
||||
HiveCatalogMetrics.reset()
|
||||
QueryTest.checkAnswer(sql("SELECT id FROM tbl WHERE p = 1"),
|
||||
Seq(1, 4, 7).map(Row.apply(_)), checkToRDD = false) // avoid analyzing the query twice
|
||||
val expectedCount = if (enablePruning) 1 else 3
|
||||
assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount == expectedCount)
|
||||
}
|
||||
}
|
||||
}
|
||||
protected def collectPartitionFiltersFn(): PartialFunction[SparkPlan, Seq[Expression]] = {
|
||||
case scan: FileSourceScanExec => scan.partitionFilters
|
||||
}
|
||||
|
||||
override def getScanExecPartitionSize(plan: SparkPlan): Long = {
|
|
@ -15,16 +15,15 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.hive.execution
|
||||
package org.apache.spark.sql.execution.datasources
|
||||
|
||||
import org.apache.spark.sql.StatisticsCollectionTestBase
|
||||
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BinaryOperator, Expression, IsNotNull, Literal}
|
||||
import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
|
||||
import org.apache.spark.sql.execution.SparkPlan
|
||||
import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, FileScan}
|
||||
import org.apache.spark.sql.hive.test.TestHiveSingleton
|
||||
import org.apache.spark.sql.internal.SQLConf.ADAPTIVE_EXECUTION_ENABLED
|
||||
|
||||
abstract class PrunePartitionSuiteBase extends StatisticsCollectionTestBase with TestHiveSingleton {
|
||||
abstract class PrunePartitionSuiteBase extends StatisticsCollectionTestBase {
|
||||
|
||||
protected def format: String
|
||||
|
||||
|
@ -95,11 +94,11 @@ abstract class PrunePartitionSuiteBase extends StatisticsCollectionTestBase with
|
|||
val plan = qe.sparkPlan
|
||||
assert(getScanExecPartitionSize(plan) == expectedPartitionCount)
|
||||
|
||||
val pushedDownPartitionFilters = plan.collectFirst {
|
||||
case scan: FileSourceScanExec => scan.partitionFilters
|
||||
case scan: HiveTableScanExec => scan.partitionPruningPred
|
||||
val collectFn: PartialFunction[SparkPlan, Seq[Expression]] = collectPartitionFiltersFn orElse {
|
||||
case BatchScanExec(_, scan: FileScan, _) => scan.partitionFilters
|
||||
}.map(exps => exps.filterNot(e => e.isInstanceOf[IsNotNull]))
|
||||
}
|
||||
val pushedDownPartitionFilters = plan.collectFirst(collectFn)
|
||||
.map(exps => exps.filterNot(e => e.isInstanceOf[IsNotNull]))
|
||||
val pushedFilters = pushedDownPartitionFilters.map(filters => {
|
||||
filters.foldLeft("")((currentStr, exp) => {
|
||||
if (currentStr == "") {
|
||||
|
@ -113,5 +112,7 @@ abstract class PrunePartitionSuiteBase extends StatisticsCollectionTestBase with
|
|||
assert(pushedFilters == Some(expectedPushedDownFilters))
|
||||
}
|
||||
|
||||
protected def collectPartitionFiltersFn(): PartialFunction[SparkPlan, Seq[Expression]]
|
||||
|
||||
protected def getScanExecPartitionSize(plan: SparkPlan): Long
|
||||
}
|
|
@ -17,14 +17,19 @@
|
|||
|
||||
package org.apache.spark.sql.hive.execution
|
||||
|
||||
import org.apache.spark.metrics.source.HiveCatalogMetrics
|
||||
import org.apache.spark.sql.{QueryTest, Row}
|
||||
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
|
||||
import org.apache.spark.sql.catalyst.expressions.Expression
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan}
|
||||
import org.apache.spark.sql.catalyst.rules.RuleExecutor
|
||||
import org.apache.spark.sql.execution.SparkPlan
|
||||
import org.apache.spark.sql.execution.datasources.PrunePartitionSuiteBase
|
||||
import org.apache.spark.sql.hive.test.TestHiveSingleton
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.types.LongType
|
||||
|
||||
class PruneHiveTablePartitionsSuite extends PrunePartitionSuiteBase {
|
||||
class PruneHiveTablePartitionsSuite extends PrunePartitionSuiteBase with TestHiveSingleton {
|
||||
|
||||
override def format(): String = "hive"
|
||||
|
||||
|
@ -131,6 +136,25 @@ class PruneHiveTablePartitionsSuite extends PrunePartitionSuiteBase {
|
|||
}
|
||||
}
|
||||
|
||||
test("SPARK-36128: spark.sql.hive.metastorePartitionPruning should work for file data sources") {
|
||||
Seq(true, false).foreach { enablePruning =>
|
||||
withTable("tbl") {
|
||||
withSQLConf(SQLConf.HIVE_METASTORE_PARTITION_PRUNING.key -> enablePruning.toString) {
|
||||
spark.range(10).selectExpr("id", "id % 3 as p").write.partitionBy("p").saveAsTable("tbl")
|
||||
HiveCatalogMetrics.reset()
|
||||
QueryTest.checkAnswer(sql("SELECT id FROM tbl WHERE p = 1"),
|
||||
Seq(1, 4, 7).map(Row.apply(_)), checkToRDD = false) // avoid analyzing the query twice
|
||||
val expectedCount = if (enablePruning) 1 else 3
|
||||
assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount == expectedCount)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
protected def collectPartitionFiltersFn(): PartialFunction[SparkPlan, Seq[Expression]] = {
|
||||
case scan: HiveTableScanExec => scan.partitionPruningPred
|
||||
}
|
||||
|
||||
override def getScanExecPartitionSize(plan: SparkPlan): Long = {
|
||||
plan.collectFirst {
|
||||
case p: HiveTableScanExec => p
|
||||
|
|
Loading…
Reference in a new issue