From 2d59ca464eb12d4de03bfa383a3efcfe0bc0441d Mon Sep 17 00:00:00 2001
From: Guy Khazma
Date: Mon, 20 Jan 2020 20:20:37 -0800
Subject: [PATCH] [SPARK-30475][SQL] File source V2: Push data filters for file listing

### What changes were proposed in this pull request?
Follow-up to [SPARK-30428](https://github.com/apache/spark/pull/27112), which added support for partition pruning in File source V2.
This PR implements the changes needed to pass the `dataFilters` to the `listFiles` method. This enables `FileIndex` implementations that use the `dataFilters` to further prune the file listing (see the discussion [here](https://github.com/apache/spark/pull/27112#discussion_r364757217)).

### Why are the changes needed?
Data sources such as `csv` and `json` do not implement the `SupportsPushDownFilters` trait. In order to support data skipping uniformly for all file-based data sources, one can override the `listFiles` method in a `FileIndex` implementation, which consults external metadata and prunes the list of files.

### Does this PR introduce any user-facing change?
No

### How was this patch tested?
Modified the unit tests for v2 file sources to verify that the `dataFilters` are passed.

Closes #27157 from guykhazma/PushdataFiltersInFileListing.

Authored-by: Guy Khazma
Signed-off-by: Gengliang Wang
---
 .../apache/spark/sql/v2/avro/AvroScan.scala   |  8 +++--
 .../org/apache/spark/sql/avro/AvroSuite.scala | 29 +++++++++++++++
 .../PruneFileSourcePartitions.scala           | 22 +++++++-----
 .../execution/datasources/v2/FileScan.scala   | 16 ++++++---
 .../datasources/v2/csv/CSVScan.scala          |  8 +++--
 .../datasources/v2/json/JsonScan.scala        |  8 +++--
 .../datasources/v2/orc/OrcScan.scala          |  8 +++--
 .../datasources/v2/parquet/ParquetScan.scala  |  8 +++--
 .../datasources/v2/text/TextScan.scala        |  8 +++--
 .../spark/sql/FileBasedDataSourceSuite.scala  | 36 +++++++++++++++++++
 10 files changed, 120 insertions(+), 31 deletions(-)

diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala
index d5a29124a2..fe7315c739 100644
--- a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala
+++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala
@@ -37,7 +37,8 @@ case class AvroScan(
     readDataSchema: StructType,
     readPartitionSchema: StructType,
     options: CaseInsensitiveStringMap,
-    partitionFilters: Seq[Expression] = Seq.empty) extends FileScan {
+    partitionFilters: Seq[Expression] = Seq.empty,
+    dataFilters: Seq[Expression] = Seq.empty) extends FileScan {
   override def isSplitable(path: Path): Boolean = true
 
   override def createReaderFactory(): PartitionReaderFactory = {
@@ -53,8 +54,9 @@ case class AvroScan(
       dataSchema, readDataSchema, readPartitionSchema, parsedOptions)
   }
 
-  override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan =
-    this.copy(partitionFilters = partitionFilters)
+  override def withFilters(
+      partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan =
+    this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters)
 
   override def equals(obj: Any): Boolean = obj match {
     case a: AvroScan => super.equals(a) && dataSchema == a.dataSchema && options == a.options
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
index 5d2f78deb4..360160c9c9 100644
--- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
+++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
@@ -1566,6 +1566,7 @@ class AvroV2Suite extends AvroSuite {
       }
       assert(fileScan.nonEmpty)
       assert(fileScan.get.partitionFilters.nonEmpty)
+      assert(fileScan.get.dataFilters.nonEmpty)
       assert(fileScan.get.planInputPartitions().forall { partition =>
         partition.asInstanceOf[FilePartition].files.forall { file =>
           file.filePath.contains("p1=1") && file.filePath.contains("p2=2")
@@ -1575,6 +1576,34 @@ class AvroV2Suite extends AvroSuite {
     }
   }
 
+  test("Avro source v2: support passing data filters to FileScan without partitionFilters") {
+    withTempPath { dir =>
+      Seq(("a", 1, 2), ("b", 1, 2), ("c", 2, 1))
+        .toDF("value", "p1", "p2")
+        .write
+        .format("avro")
+        .save(dir.getCanonicalPath)
+      val df = spark
+        .read
+        .format("avro")
+        .load(dir.getCanonicalPath)
+        .where("value = 'a'")
+
+      val filterCondition = df.queryExecution.optimizedPlan.collectFirst {
+        case f: Filter => f.condition
+      }
+      assert(filterCondition.isDefined)
+
+      val fileScan = df.queryExecution.executedPlan collectFirst {
+        case BatchScanExec(_, f: AvroScan) => f
+      }
+      assert(fileScan.nonEmpty)
+      assert(fileScan.get.partitionFilters.isEmpty)
+      assert(fileScan.get.dataFilters.nonEmpty)
+      checkAnswer(df, Row("a", 1, 2))
+    }
+  }
+
   private def getBatchScanExec(plan: SparkPlan): BatchScanExec = {
     plan.find(_.isInstanceOf[BatchScanExec]).get.asInstanceOf[BatchScanExec]
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
index 7fd154ccac..59c55c161b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
@@ -28,20 +28,22 @@ import org.apache.spark.sql.types.StructType
 
 private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
 
-  private def getPartitionKeyFilters(
+  private def getPartitionKeyFiltersAndDataFilters(
       sparkSession: SparkSession,
       relation: LeafNode,
       partitionSchema: StructType,
       filters: Seq[Expression],
-      output: Seq[AttributeReference]): ExpressionSet = {
+      output: Seq[AttributeReference]): (ExpressionSet, Seq[Expression]) = {
     val normalizedFilters = DataSourceStrategy.normalizeExprs(
       filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)), output)
     val partitionColumns =
       relation.resolve(partitionSchema, sparkSession.sessionState.analyzer.resolver)
     val partitionSet = AttributeSet(partitionColumns)
-    ExpressionSet(normalizedFilters.filter { f =>
+    val (partitionFilters, dataFilters) = normalizedFilters.partition(f =>
       f.references.subsetOf(partitionSet)
-    })
+    )
+
+    (ExpressionSet(partitionFilters), dataFilters)
   }
 
   private def rebuildPhysicalOperation(
@@ -72,7 +74,7 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
         _,
         _))
         if filters.nonEmpty && fsRelation.partitionSchemaOption.isDefined =>
-      val partitionKeyFilters = getPartitionKeyFilters(
+      val (partitionKeyFilters, _) = getPartitionKeyFiltersAndDataFilters(
         fsRelation.sparkSession, logicalRelation, partitionSchema, filters, logicalRelation.output)
       if (partitionKeyFilters.nonEmpty) {
         val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq)
@@ -92,11 +94,13 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
     case op @ PhysicalOperation(projects, filters,
         v2Relation @ DataSourceV2ScanRelation(_, scan: FileScan, output))
         if filters.nonEmpty && scan.readDataSchema.nonEmpty =>
-      val partitionKeyFilters = getPartitionKeyFilters(scan.sparkSession,
-        v2Relation, scan.readPartitionSchema, filters, output)
-      if (partitionKeyFilters.nonEmpty) {
+      val (partitionKeyFilters, dataFilters) =
+        getPartitionKeyFiltersAndDataFilters(scan.sparkSession, v2Relation,
+          scan.readPartitionSchema, filters, output)
+      // The dataFilters are pushed down only once
+      if (partitionKeyFilters.nonEmpty || (dataFilters.nonEmpty && scan.dataFilters.isEmpty)) {
         val prunedV2Relation =
-          v2Relation.copy(scan = scan.withPartitionFilters(partitionKeyFilters.toSeq))
+          v2Relation.copy(scan = scan.withFilters(partitionKeyFilters.toSeq, dataFilters))
         // The pushed down partition filters don't need to be reevaluated.
         val afterScanFilters =
           ExpressionSet(filters) -- partitionKeyFilters.filter(_.references.nonEmpty)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
index a22e1ccfe4..6e05aa56f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
@@ -61,9 +61,15 @@ trait FileScan extends Scan with Batch with SupportsReportStatistics with Loggin
   def partitionFilters: Seq[Expression]
 
   /**
-   * Create a new `FileScan` instance from the current one with different `partitionFilters`.
+   * Returns the data filters that can be used for file listing.
    */
-  def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan
+  def dataFilters: Seq[Expression]
+
+  /**
+   * Create a new `FileScan` instance from the current one
+   * with different `partitionFilters` and `dataFilters`.
+   */
+  def withFilters(partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan
 
   /**
   * If a file with `path` is unsplittable, return the unsplittable reason,
@@ -79,7 +85,8 @@ trait FileScan extends Scan with Batch with SupportsReportStatistics with Loggin
   override def equals(obj: Any): Boolean = obj match {
     case f: FileScan => fileIndex == f.fileIndex && readSchema == f.readSchema &&
-      ExpressionSet(partitionFilters) == ExpressionSet(f.partitionFilters)
+      ExpressionSet(partitionFilters) == ExpressionSet(f.partitionFilters) &&
+      ExpressionSet(dataFilters) == ExpressionSet(f.dataFilters)
     case _ => false
   }
@@ -92,6 +99,7 @@ trait FileScan extends Scan with Batch with SupportsReportStatistics with Loggin
     val metadata: Map[String, String] = Map(
       "ReadSchema" -> readDataSchema.catalogString,
       "PartitionFilters" -> seqToString(partitionFilters),
+      "DataFilters" -> seqToString(dataFilters),
       "Location" -> locationDesc)
     val metadataStr = metadata.toSeq.sorted.map {
       case (key, value) =>
@@ -103,7 +111,7 @@
   }
 
   protected def partitions: Seq[FilePartition] = {
-    val selectedPartitions = fileIndex.listFiles(partitionFilters, Seq.empty)
+    val selectedPartitions = fileIndex.listFiles(partitionFilters, dataFilters)
     val maxSplitBytes = FilePartition.maxSplitBytes(sparkSession, selectedPartitions)
     val partitionAttributes = fileIndex.partitionSchema.toAttributes
     val attributeMap = partitionAttributes.map(a => normalizeName(a.name) -> a).toMap
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala
index 690d66908e..4f51032281 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala
@@ -40,7 +40,8 @@ case class CSVScan(
     readPartitionSchema: StructType,
     options: CaseInsensitiveStringMap,
     pushedFilters: Array[Filter],
-    partitionFilters: Seq[Expression] = Seq.empty)
+    partitionFilters: Seq[Expression] = Seq.empty,
+    dataFilters: Seq[Expression] = Seq.empty)
   extends TextBasedFileScan(sparkSession, options) {
 
   private lazy val parsedOptions: CSVOptions = new CSVOptions(
@@ -91,8 +92,9 @@ case class CSVScan(
       dataSchema, readDataSchema, readPartitionSchema, parsedOptions, pushedFilters)
   }
 
-  override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan =
-    this.copy(partitionFilters = partitionFilters)
+  override def withFilters(
+      partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan =
+    this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters)
 
   override def equals(obj: Any): Boolean = obj match {
     case c: CSVScan => super.equals(c) && dataSchema == c.dataSchema && options == c.options &&
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala
index 153b402476..7523162567 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala
@@ -39,7 +39,8 @@ case class JsonScan(
     readDataSchema: StructType,
     readPartitionSchema: StructType,
     options: CaseInsensitiveStringMap,
-    partitionFilters: Seq[Expression] = Seq.empty)
+    partitionFilters: Seq[Expression] = Seq.empty,
+    dataFilters: Seq[Expression] = Seq.empty)
   extends TextBasedFileScan(sparkSession, options) {
 
   private val parsedOptions = new JSONOptionsInRead(
@@ -88,8 +89,9 @@ case class JsonScan(
       dataSchema, readDataSchema, readPartitionSchema, parsedOptions)
   }
 
-  override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan =
-    this.copy(partitionFilters = partitionFilters)
+  override def withFilters(
+      partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan =
+    this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters)
 
   override def equals(obj: Any): Boolean = obj match {
     case j: JsonScan => super.equals(j) && dataSchema == j.dataSchema && options == j.options
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
index f0595cb6d0..62894fa7a2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
@@ -38,7 +38,8 @@ case class OrcScan(
     readPartitionSchema: StructType,
     options: CaseInsensitiveStringMap,
     pushedFilters: Array[Filter],
-    partitionFilters: Seq[Expression] = Seq.empty) extends FileScan {
+    partitionFilters: Seq[Expression] = Seq.empty,
+    dataFilters: Seq[Expression] = Seq.empty) extends FileScan {
   override def isSplitable(path: Path): Boolean = true
 
   override def createReaderFactory(): PartitionReaderFactory = {
@@ -64,6 +65,7 @@ case class OrcScan(
     super.description() + ", PushedFilters: " + seqToString(pushedFilters)
   }
 
-  override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan =
-    this.copy(partitionFilters = partitionFilters)
+  override def withFilters(
+      partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan =
+    this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala
index 44179e2e42..bb315262a8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala
@@ -41,7 +41,8 @@ case class ParquetScan(
     readPartitionSchema: StructType,
     pushedFilters: Array[Filter],
     options: CaseInsensitiveStringMap,
-    partitionFilters: Seq[Expression] = Seq.empty) extends FileScan {
+    partitionFilters: Seq[Expression] = Seq.empty,
+    dataFilters: Seq[Expression] = Seq.empty) extends FileScan {
   override def isSplitable(path: Path): Boolean = true
 
   override def createReaderFactory(): PartitionReaderFactory = {
@@ -92,6 +93,7 @@ case class ParquetScan(
     super.description() + ", PushedFilters: " + seqToString(pushedFilters)
   }
 
-  override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan =
-    this.copy(partitionFilters = partitionFilters)
+  override def withFilters(
+      partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan =
+    this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala
index cf6595e5c1..e75de2c4a4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala
@@ -36,7 +36,8 @@ case class TextScan(
     readDataSchema: StructType,
     readPartitionSchema: StructType,
     options: CaseInsensitiveStringMap,
-    partitionFilters: Seq[Expression] = Seq.empty)
+    partitionFilters: Seq[Expression] = Seq.empty,
+    dataFilters: Seq[Expression] = Seq.empty)
   extends TextBasedFileScan(sparkSession, options) {
 
   private val optionsAsScala = options.asScala.toMap
@@ -70,8 +71,9 @@ case class TextScan(
       readPartitionSchema, textOptions)
   }
 
-  override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan =
-    this.copy(partitionFilters = partitionFilters)
+  override def withFilters(
+      partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan =
+    this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters)
 
   override def equals(obj: Any): Boolean = obj match {
     case t: TextScan => super.equals(t) && options == t.options
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
index 972683512d..c870958128 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
@@ -775,6 +775,7 @@ class FileBasedDataSourceSuite extends QueryTest
       }
       assert(fileScan.nonEmpty)
       assert(fileScan.get.partitionFilters.nonEmpty)
+      assert(fileScan.get.dataFilters.nonEmpty)
       assert(fileScan.get.planInputPartitions().forall { partition =>
         partition.asInstanceOf[FilePartition].files.forall { file =>
           file.filePath.contains("p1=1") && file.filePath.contains("p2=2")
@@ -786,6 +787,41 @@ class FileBasedDataSourceSuite extends QueryTest
     }
   }
 
+  test("File source v2: support passing data filters to FileScan without partitionFilters") {
+    withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") {
+      allFileBasedDataSources.foreach { format =>
+        withTempPath { dir =>
+          Seq(("a", 1, 2), ("b", 1, 2), ("c", 2, 1))
+            .toDF("value", "p1", "p2")
+            .write
+            .format(format)
+            .partitionBy("p1", "p2")
+            .option("header", true)
+            .save(dir.getCanonicalPath)
+          val df = spark
+            .read
+            .format(format)
+            .option("header", true)
+            .load(dir.getCanonicalPath)
+            .where("value = 'a'")
+
+          val filterCondition = df.queryExecution.optimizedPlan.collectFirst {
+            case f: Filter => f.condition
+          }
+          assert(filterCondition.isDefined)
+
+          val fileScan = df.queryExecution.executedPlan collectFirst {
+            case BatchScanExec(_, f: FileScan) => f
+          }
+          assert(fileScan.nonEmpty)
+          assert(fileScan.get.partitionFilters.isEmpty)
+          assert(fileScan.get.dataFilters.nonEmpty)
+          checkAnswer(df, Row("a", 1, 2))
+        }
+      }
+    }
+  }
+
   test("File table location should include both values of option `path` and `paths`") {
     withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") {
       withTempPaths(3) { paths =>
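For context, below is a minimal sketch of the kind of `FileIndex` implementation this change enables: one that overrides `listFiles` and uses the now-propagated `dataFilters` to prune files based on external metadata. The class name `MetadataSkippingFileIndex` and the helper `canSkipByMetadata` are illustrative assumptions, not part of this patch or of Spark.

```scala
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.execution.datasources.{InMemoryFileIndex, PartitionDirectory}

// Hypothetical FileIndex that consults external metadata to skip files.
// With this patch, FileScan.partitions calls listFiles(partitionFilters, dataFilters),
// so the dataFilters argument below is no longer always Seq.empty for v2 file sources.
class MetadataSkippingFileIndex(
    sparkSession: SparkSession,
    rootPaths: Seq[Path])
  extends InMemoryFileIndex(sparkSession, rootPaths, Map.empty, userSpecifiedSchema = None) {

  override def listFiles(
      partitionFilters: Seq[Expression],
      dataFilters: Seq[Expression]): Seq[PartitionDirectory] = {
    super.listFiles(partitionFilters, dataFilters).map { partition =>
      // Drop files that external metadata (e.g. per-file min/max stats) proves
      // cannot satisfy the data filters.
      partition.copy(files =
        partition.files.filterNot(f => canSkipByMetadata(f.getPath, dataFilters)))
    }
  }

  // Placeholder for a lookup against an external metadata store; keeps every file here.
  private def canSkipByMetadata(file: Path, dataFilters: Seq[Expression]): Boolean = false
}
```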