[SPARK-36136][SQL][TESTS] Refactor PruneFileSourcePartitionsSuite etc to a different package

### What changes were proposed in this pull request?

Move both `PruneFileSourcePartitionsSuite` and `PrunePartitionSuiteBase` to the package `org.apache.spark.sql.execution.datasources`. Some refactoring was needed to enable this.
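
The heart of that refactoring is that `PrunePartitionSuiteBase` no longer pattern-matches on the Hive-specific `HiveTableScanExec` itself; instead it declares an abstract `collectPartitionFiltersFn()` hook that each concrete suite implements, and composes that hook with its own cases via `PartialFunction.orElse`. Below is a minimal, self-contained sketch of that pattern — toy `Plan` types, string filters, and suite names stand in for Spark's `SparkPlan` nodes and `Expression`s, so this is an illustration of the mechanism rather than the PR's actual code; see the diffs further down for the real change.

```scala
// Toy stand-ins for the real Spark plan nodes (FileSourceScanExec,
// HiveTableScanExec, BatchScanExec); filters are plain strings here.
sealed trait Plan
case class FileScanNode(partitionFilters: Seq[String]) extends Plan
case class HiveScanNode(partitionPruningPred: Seq[String]) extends Plan
case class BatchScanNode(partitionFilters: Seq[String]) extends Plan

abstract class PrunePartitionSuiteBase {
  // Hook: each concrete suite knows how to extract partition filters from
  // its own scan node, so the base class stays free of Hive dependencies.
  protected def collectPartitionFiltersFn(): PartialFunction[Plan, Seq[String]]

  // The base composes the hook with the cases it still owns (DSv2 scans).
  def pushedDownPartitionFilters(plan: Plan): Option[Seq[String]] = {
    val collectFn: PartialFunction[Plan, Seq[String]] =
      collectPartitionFiltersFn() orElse {
        case BatchScanNode(filters) => filters
      }
    collectFn.lift(plan)
  }
}

class FileSourceSuite extends PrunePartitionSuiteBase {
  protected def collectPartitionFiltersFn(): PartialFunction[Plan, Seq[String]] = {
    case FileScanNode(filters) => filters
  }
}

class HiveSuite extends PrunePartitionSuiteBase {
  protected def collectPartitionFiltersFn(): PartialFunction[Plan, Seq[String]] = {
    case HiveScanNode(preds) => preds
  }
}

object Demo extends App {
  val hiveSuite = new HiveSuite
  println(hiveSuite.pushedDownPartitionFilters(HiveScanNode(Seq("p = 1"))))  // Some(List(p = 1))
  println(hiveSuite.pushedDownPartitionFilters(BatchScanNode(Seq("p = 2")))) // Some(List(p = 2))
  println(hiveSuite.pushedDownPartitionFilters(FileScanNode(Seq("p = 3"))))  // None: not a Hive scan
}
```

With this shape, only `PruneHiveTablePartitionsSuite` needs to know about Hive scan nodes, which is what allows the base suite and the file-source suite to move out of the Hive package.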

### Why are the changes needed?

Currently both `PruneFileSourcePartitionsSuite` and `PrunePartitionSuiteBase` are in the package `org.apache.spark.sql.hive.execution`, which is misleading since these tests are not specific to Hive. It is better to move them into `org.apache.spark.sql.execution.datasources`, the same package where the rule `PruneFileSourcePartitions` is defined.

### Does this PR introduce _any_ user-facing change?

No, it's just test refactoring.

### How was this patch tested?

Using existing tests:
```
build/sbt "sql/testOnly *PruneFileSourcePartitionsSuite"
```
and
```
build/sbt "hive/testOnly *PruneHiveTablePartitionsSuite"
```

Closes #33350 from sunchao/SPARK-36136-partitions-suite.

Authored-by: Chao Sun <sunchao@apple.com>
Signed-off-by: Liang-Chi Hsieh <viirya@gmail.com>
(cherry picked from commit 634f96dde4)
Signed-off-by: Liang-Chi Hsieh <viirya@gmail.com>
Commit ae7b32a9e8 (parent a77c9d6d17), authored by Chao Sun on 2021-07-26 13:03:50 -07:00 and committed by Liang-Chi Hsieh.
3 changed files with 43 additions and 38 deletions.

`PruneFileSourcePartitionsSuite.scala`:

```diff
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.hive.execution
+package org.apache.spark.sql.execution.datasources
 
 import org.scalatest.matchers.should.Matchers._
@@ -24,18 +24,19 @@ import org.apache.spark.sql.{QueryTest, Row}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
-import org.apache.spark.sql.execution.datasources.{CatalogFileIndex, HadoopFsRelation, LogicalRelation, PruneFileSourcePartitions}
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
 import org.apache.spark.sql.functions.broadcast
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.StructType
 
-class PruneFileSourcePartitionsSuite extends PrunePartitionSuiteBase {
+class PruneFileSourcePartitionsSuite extends PrunePartitionSuiteBase with SharedSparkSession {
 
   override def format: String = "parquet"
@@ -45,35 +46,27 @@ class PruneFileSourcePartitionsSuite extends PrunePartitionSuiteBase {
   test("PruneFileSourcePartitions should not change the output of LogicalRelation") {
     withTable("test") {
-      withTempDir { dir =>
-        sql(
-          s"""
-            |CREATE EXTERNAL TABLE test(i int)
-            |PARTITIONED BY (p int)
-            |STORED AS parquet
-            |LOCATION '${dir.toURI}'""".stripMargin)
-
-        val tableMeta = spark.sharedState.externalCatalog.getTable("default", "test")
-        val catalogFileIndex = new CatalogFileIndex(spark, tableMeta, 0)
-
-        val dataSchema = StructType(tableMeta.schema.filterNot { f =>
-          tableMeta.partitionColumnNames.contains(f.name)
-        })
-        val relation = HadoopFsRelation(
-          location = catalogFileIndex,
-          partitionSchema = tableMeta.partitionSchema,
-          dataSchema = dataSchema,
-          bucketSpec = None,
-          fileFormat = new ParquetFileFormat(),
-          options = Map.empty)(sparkSession = spark)
-
-        val logicalRelation = LogicalRelation(relation, tableMeta)
-        val query = Project(Seq(Symbol("i"), Symbol("p")),
-          Filter(Symbol("p") === 1, logicalRelation)).analyze
-
-        val optimized = Optimize.execute(query)
-        assert(optimized.missingInput.isEmpty)
-      }
+      spark.range(10).selectExpr("id", "id % 3 as p").write.partitionBy("p").saveAsTable("test")
+      val tableMeta = spark.sharedState.externalCatalog.getTable("default", "test")
+      val catalogFileIndex = new CatalogFileIndex(spark, tableMeta, 0)
+
+      val dataSchema = StructType(tableMeta.schema.filterNot { f =>
+        tableMeta.partitionColumnNames.contains(f.name)
+      })
+      val relation = HadoopFsRelation(
+        location = catalogFileIndex,
+        partitionSchema = tableMeta.partitionSchema,
+        dataSchema = dataSchema,
+        bucketSpec = None,
+        fileFormat = new ParquetFileFormat(),
+        options = Map.empty)(sparkSession = spark)
+
+      val logicalRelation = LogicalRelation(relation, tableMeta)
+      val query = Project(Seq(Symbol("id"), Symbol("p")),
+        Filter(Symbol("p") === 1, logicalRelation)).analyze
+
+      val optimized = Optimize.execute(query)
+      assert(optimized.missingInput.isEmpty)
     }
   }
@@ -142,6 +135,10 @@
     }
   }
 
+  protected def collectPartitionFiltersFn(): PartialFunction[SparkPlan, Seq[Expression]] = {
+    case scan: FileSourceScanExec => scan.partitionFilters
+  }
+
   override def getScanExecPartitionSize(plan: SparkPlan): Long = {
     plan.collectFirst {
       case p: FileSourceScanExec => p.selectedPartitions.length
```

`PrunePartitionSuiteBase.scala`:

```diff
@@ -15,16 +15,15 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.hive.execution
+package org.apache.spark.sql.execution.datasources
 
 import org.apache.spark.sql.StatisticsCollectionTestBase
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BinaryOperator, Expression, IsNotNull, Literal}
-import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
+import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, FileScan}
-import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.internal.SQLConf.ADAPTIVE_EXECUTION_ENABLED
 
-abstract class PrunePartitionSuiteBase extends StatisticsCollectionTestBase with TestHiveSingleton {
+abstract class PrunePartitionSuiteBase extends StatisticsCollectionTestBase {
 
   protected def format: String
@@ -95,11 +94,11 @@ abstract class PrunePartitionSuiteBase extends StatisticsCollectionTestBase with
     val plan = qe.sparkPlan
     assert(getScanExecPartitionSize(plan) == expectedPartitionCount)
 
-    val pushedDownPartitionFilters = plan.collectFirst {
-      case scan: FileSourceScanExec => scan.partitionFilters
-      case scan: HiveTableScanExec => scan.partitionPruningPred
+    val collectFn: PartialFunction[SparkPlan, Seq[Expression]] = collectPartitionFiltersFn orElse {
       case BatchScanExec(_, scan: FileScan, _) => scan.partitionFilters
-    }.map(exps => exps.filterNot(e => e.isInstanceOf[IsNotNull]))
+    }
+    val pushedDownPartitionFilters = plan.collectFirst(collectFn)
+      .map(exps => exps.filterNot(e => e.isInstanceOf[IsNotNull]))
     val pushedFilters = pushedDownPartitionFilters.map(filters => {
       filters.foldLeft("")((currentStr, exp) => {
         if (currentStr == "") {
@@ -113,5 +112,7 @@ abstract class PrunePartitionSuiteBase extends StatisticsCollectionTestBase with
     assert(pushedFilters == Some(expectedPushedDownFilters))
   }
 
+  protected def collectPartitionFiltersFn(): PartialFunction[SparkPlan, Seq[Expression]]
+
   protected def getScanExecPartitionSize(plan: SparkPlan): Long
 }
```

`PruneHiveTablePartitionsSuite.scala`:

```diff
@@ -18,13 +18,16 @@
 package org.apache.spark.sql.hive.execution
 
 import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.datasources.PrunePartitionSuiteBase
+import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.LongType
 
-class PruneHiveTablePartitionsSuite extends PrunePartitionSuiteBase {
+class PruneHiveTablePartitionsSuite extends PrunePartitionSuiteBase with TestHiveSingleton {
 
   override def format(): String = "hive"
@@ -131,6 +134,10 @@ class PruneHiveTablePartitionsSuite extends PrunePartitionSuiteBase {
     }
   }
 
+  protected def collectPartitionFiltersFn(): PartialFunction[SparkPlan, Seq[Expression]] = {
+    case scan: HiveTableScanExec => scan.partitionPruningPred
+  }
+
   override def getScanExecPartitionSize(plan: SparkPlan): Long = {
     plan.collectFirst {
       case p: HiveTableScanExec => p
```