From 38b6fbd9b8c621dc2de447d4a3ef65ea28510e5e Mon Sep 17 00:00:00 2001
From: Huaxin Gao
Date: Thu, 2 Sep 2021 19:11:43 -0700
Subject: [PATCH] [SPARK-36351][SQL] Refactor filter push down in file source v2

### What changes were proposed in this pull request?
Currently in `V2ScanRelationPushDown`, we push all the filters (partition filters + data filters) to the file source, and then pass all of them as post-scan filters to the v2 Scan. Later, in `PruneFileSourcePartitions`, we separate the partition filters from the data filters and set them, as `Expression`s, on the file source.

Changes in this PR: when we push filters to file sources in `V2ScanRelationPushDown`, we already know which columns are partition columns, so we separate partition filters and data filters right there. The benefits of doing this:
- All the filter-related work for v2 file sources is handled in one place instead of two (`V2ScanRelationPushDown` and `PruneFileSourcePartitions`), so the code is cleaner and easier to maintain.
- We actually have to separate partition filters and data filters in `V2ScanRelationPushDown`; otherwise there is no way to tell which filters are partition filters, and we can't push down aggregates for Parquet even when only partition filters are present.
- By separating the filters early in `V2ScanRelationPushDown`, we only need to check the data filters to find out which ones should be converted to data source filters (e.g. Parquet predicates, ORC predicates) and pushed down to the file source. Right now we check all the filters (both partition filters and data filters).
- Similarly, we only pass data filters as post-scan filters to the v2 Scan, because partition filters are used for partition pruning only; there is no need to pass them as post-scan filters.

To achieve this, the PR makes the following changes (a minimal sketch of the resulting builder shape follows at the end of this description):
- Add `pushFilters` to file source v2. In this method:
  - Both partition filters and data filters are pushed to the file source as `Expression`s. We have to use `Expression` filters because they are needed for partition pruning.
  - Data filters are used for filter push down. If a file source needs to push down data filters, it translates them from `Expression` to `sources.Filter` and then decides which filters to push down.
  - Partition filters are used for partition pruning.
- File source v2 no longer needs to implement `SupportsPushDownFilters`, because once the two kinds of filters are separated, they are already set on the file data sources; using `SupportsPushDownFilters` to set the filters again would be redundant.

### Why are the changes needed?
See the section above.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Existing tests

Closes #33650 from huaxingao/partition_filter.
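For illustration only, here is a minimal sketch of what a file source builder looks like under the new API. `MyFormatScanBuilder` is a hypothetical format; the real builders changed in this patch (Avro, CSV, JSON, ORC, Parquet, Text) follow the same shape in the diff below.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.StructFilters
import org.apache.spark.sql.connector.read.Scan
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

// Hypothetical "MyFormat" source, mirroring the shape of the builders in this patch.
class MyFormatScanBuilder(
    sparkSession: SparkSession,
    fileIndex: PartitioningAwareFileIndex,
    schema: StructType,
    dataSchema: StructType,
    options: CaseInsensitiveStringMap)
  extends FileScanBuilder(sparkSession, fileIndex, dataSchema) {

  // No SupportsPushDownFilters mix-in needed: FileScanBuilder.pushFilters has already
  // split the catalyst filters into partitionFilters (kept as Expressions for partition
  // pruning) and dataFilters, and translated the data filters to sources.Filter before
  // calling this hook. Only data filters ever reach it.
  override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = {
    // Keep only the filters this format can evaluate while reading data.
    StructFilters.pushedFilters(dataFilters, dataSchema)
  }

  override def build(): Scan = {
    // A real source builds its FileScan here from readDataSchema(), readPartitionSchema(),
    // pushedDataFilters, partitionFilters and dataFilters, as the builders below do.
    ???
  }
}
```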
Authored-by: Huaxin Gao Signed-off-by: Liang-Chi Hsieh --- .../apache/spark/sql/v2/avro/AvroScan.scala | 4 -- .../spark/sql/v2/avro/AvroScanBuilder.scala | 19 +++---- .../SupportsPushDownCatalystFilters.scala | 41 ++++++++++++++ .../datasources/DataSourceUtils.scala | 21 ++++++- .../PruneFileSourcePartitions.scala | 56 ++----------------- .../execution/datasources/v2/FileScan.scala | 6 -- .../datasources/v2/FileScanBuilder.scala | 44 +++++++++++++-- .../datasources/v2/PushDownUtils.scala | 7 ++- .../datasources/v2/csv/CSVScan.scala | 6 +- .../datasources/v2/csv/CSVScanBuilder.scala | 19 +++---- .../datasources/v2/json/JsonScan.scala | 6 +- .../datasources/v2/json/JsonScanBuilder.scala | 19 +++---- .../datasources/v2/orc/OrcScan.scala | 4 -- .../datasources/v2/orc/OrcScanBuilder.scala | 19 +++---- .../datasources/v2/parquet/ParquetScan.scala | 4 -- .../v2/parquet/ParquetScanBuilder.scala | 15 ++--- .../datasources/v2/text/TextScan.scala | 6 +- .../datasources/v2/text/TextScanBuilder.scala | 3 +- .../datasources/json/JsonSuite.scala | 6 +- .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 20 ++++++- 20 files changed, 177 insertions(+), 148 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/SupportsPushDownCatalystFilters.scala diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala index 144e9ad129..d0f38c1242 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala @@ -62,10 +62,6 @@ case class AvroScan( pushedFilters) } - override def withFilters( - partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = - this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) - override def equals(obj: Any): Boolean = obj match { case a: AvroScan => super.equals(a) && dataSchema == a.dataSchema && options == a.options && equivalentFilters(pushedFilters, a.pushedFilters) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala index 9420608bb2..8fae89a945 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.v2.avro import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.StructFilters -import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters} +import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder import org.apache.spark.sql.sources.Filter @@ -31,7 +31,7 @@ class AvroScanBuilder ( schema: StructType, dataSchema: StructType, options: CaseInsensitiveStringMap) - extends FileScanBuilder(sparkSession, fileIndex, dataSchema) with SupportsPushDownFilters { + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { override def build(): Scan = { AvroScan( @@ -41,17 +41,16 @@ class AvroScanBuilder ( readDataSchema(), readPartitionSchema(), options, - pushedFilters()) + pushedDataFilters, + partitionFilters, + dataFilters) } - private var _pushedFilters: Array[Filter] = Array.empty - - override def pushFilters(filters: Array[Filter]): Array[Filter] = { + 
override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = { if (sparkSession.sessionState.conf.avroFilterPushDown) { - _pushedFilters = StructFilters.pushedFilters(filters, dataSchema) + StructFilters.pushedFilters(dataFilters, dataSchema) + } else { + Array.empty[Filter] } - filters } - - override def pushedFilters(): Array[Filter] = _pushedFilters } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/SupportsPushDownCatalystFilters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/SupportsPushDownCatalystFilters.scala new file mode 100644 index 0000000000..9c2a4ac78a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/SupportsPushDownCatalystFilters.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.internal.connector + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.sources.Filter + +/** + * A mix-in interface for {@link FileScanBuilder}. File sources can implement this interface to + * push down filters to the file source. The pushed down filters will be separated into partition + * filters and data filters. Partition filters are used for partition pruning and data filters are + * used to reduce the size of the data to be read. + */ +trait SupportsPushDownCatalystFilters { + + /** + * Pushes down catalyst Expression filters (which will be separated into partition filters and + * data filters), and returns data filters that need to be evaluated after scanning. + */ + def pushFilters(filters: Seq[Expression]): Seq[Expression] + + /** + * Returns the data filters that are pushed to the data source via + * {@link #pushFilters(Expression[])}. 
+ */ + def pushedFilters: Array[Filter] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala index fcd95a27bf..67d03998a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala @@ -28,6 +28,7 @@ import org.json4s.jackson.Serialization import org.apache.spark.SparkUpgradeException import org.apache.spark.sql.{SPARK_LEGACY_DATETIME, SPARK_LEGACY_INT96, SPARK_VERSION_METADATA_KEY} import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogUtils} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, ExpressionSet, PredicateHelper} import org.apache.spark.sql.catalyst.util.RebaseDateTime import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions @@ -39,7 +40,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.Utils -object DataSourceUtils { +object DataSourceUtils extends PredicateHelper { /** * The key to use for storing partitionBy columns as options. */ @@ -242,4 +243,22 @@ object DataSourceUtils { options } } + + def getPartitionFiltersAndDataFilters( + partitionSchema: StructType, + normalizedFilters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { + val partitionColumns = normalizedFilters.flatMap { expr => + expr.collect { + case attr: AttributeReference if partitionSchema.names.contains(attr.name) => + attr + } + } + val partitionSet = AttributeSet(partitionColumns) + val (partitionFilters, dataFilters) = normalizedFilters.partition(f => + f.references.subsetOf(partitionSet) + ) + val extraPartitionFilter = + dataFilters.flatMap(extractPredicatesWithinOutputSet(_, partitionSet)) + (ExpressionSet(partitionFilters ++ extraPartitionFilter).toSeq, dataFilters) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala index 0927027bee..2e8e5426d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -17,52 +17,24 @@ package org.apache.spark.sql.execution.datasources -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.catalog.CatalogStatistics import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.{Filter, LeafNode, LogicalPlan, Project} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.FilterEstimation import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, FileScan} -import org.apache.spark.sql.types.StructType /** * Prune the partitions of file source based table using partition filters. Currently, this rule - * is applied to [[HadoopFsRelation]] with [[CatalogFileIndex]] and [[DataSourceV2ScanRelation]] - * with [[FileScan]]. + * is applied to [[HadoopFsRelation]] with [[CatalogFileIndex]]. 
* * For [[HadoopFsRelation]], the location will be replaced by pruned file index, and corresponding * statistics will be updated. And the partition filters will be kept in the filters of returned * logical plan. - * - * For [[DataSourceV2ScanRelation]], both partition filters and data filters will be added to - * its underlying [[FileScan]]. And the partition filters will be removed in the filters of - * returned logical plan. */ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] with PredicateHelper { - private def getPartitionKeyFiltersAndDataFilters( - sparkSession: SparkSession, - relation: LeafNode, - partitionSchema: StructType, - filters: Seq[Expression], - output: Seq[AttributeReference]): (ExpressionSet, Seq[Expression]) = { - val normalizedFilters = DataSourceStrategy.normalizeExprs( - filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)), output) - val partitionColumns = - relation.resolve(partitionSchema, sparkSession.sessionState.analyzer.resolver) - val partitionSet = AttributeSet(partitionColumns) - val (partitionFilters, dataFilters) = normalizedFilters.partition(f => - f.references.subsetOf(partitionSet) - ) - val extraPartitionFilter = - dataFilters.flatMap(extractPredicatesWithinOutputSet(_, partitionSet)) - - (ExpressionSet(partitionFilters ++ extraPartitionFilter), dataFilters) - } - private def rebuildPhysicalOperation( projects: Seq[NamedExpression], filters: Seq[Expression], @@ -91,12 +63,14 @@ private[sql] object PruneFileSourcePartitions _, _)) if filters.nonEmpty && fsRelation.partitionSchemaOption.isDefined => - val (partitionKeyFilters, _) = getPartitionKeyFiltersAndDataFilters( - fsRelation.sparkSession, logicalRelation, partitionSchema, filters, + val normalizedFilters = DataSourceStrategy.normalizeExprs( + filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)), logicalRelation.output) + val (partitionKeyFilters, _) = DataSourceUtils + .getPartitionFiltersAndDataFilters(partitionSchema, normalizedFilters) if (partitionKeyFilters.nonEmpty) { - val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq) + val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters) val prunedFsRelation = fsRelation.copy(location = prunedFileIndex)(fsRelation.sparkSession) // Change table stats based on the sizeInBytes of pruned files @@ -117,23 +91,5 @@ private[sql] object PruneFileSourcePartitions } else { op } - - case op @ PhysicalOperation(projects, filters, - v2Relation @ DataSourceV2ScanRelation(_, scan: FileScan, output)) - if filters.nonEmpty => - val (partitionKeyFilters, dataFilters) = - getPartitionKeyFiltersAndDataFilters(scan.sparkSession, v2Relation, - scan.readPartitionSchema, filters, output) - // The dataFilters are pushed down only once - if (partitionKeyFilters.nonEmpty || (dataFilters.nonEmpty && scan.dataFilters.isEmpty)) { - val prunedV2Relation = - v2Relation.copy(scan = scan.withFilters(partitionKeyFilters.toSeq, dataFilters)) - // The pushed down partition filters don't need to be reevaluated. 
- val afterScanFilters = - ExpressionSet(filters) -- partitionKeyFilters.filter(_.references.nonEmpty) - rebuildPhysicalOperation(projects, afterScanFilters.toSeq, prunedV2Relation) - } else { - op - } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala index b20270275d..8b0328cabc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala @@ -71,12 +71,6 @@ trait FileScan extends Scan */ def dataFilters: Seq[Expression] - /** - * Create a new `FileScan` instance from the current one - * with different `partitionFilters` and `dataFilters` - */ - def withFilters(partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan - /** * If a file with `path` is unsplittable, return the unsplittable reason, * otherwise return `None`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala index 97874e8f49..309f045201 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala @@ -16,19 +16,30 @@ */ package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark.sql.SparkSession +import scala.collection.mutable + +import org.apache.spark.sql.{sources, SparkSession} +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.connector.read.{ScanBuilder, SupportsPushDownRequiredColumns} -import org.apache.spark.sql.execution.datasources.{PartitioningAwareFileIndex, PartitioningUtils} +import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, DataSourceUtils, PartitioningAwareFileIndex, PartitioningUtils} +import org.apache.spark.sql.internal.connector.SupportsPushDownCatalystFilters +import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType abstract class FileScanBuilder( sparkSession: SparkSession, fileIndex: PartitioningAwareFileIndex, - dataSchema: StructType) extends ScanBuilder with SupportsPushDownRequiredColumns { + dataSchema: StructType) + extends ScanBuilder + with SupportsPushDownRequiredColumns + with SupportsPushDownCatalystFilters { private val partitionSchema = fileIndex.partitionSchema private val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis protected val supportsNestedSchemaPruning = false protected var requiredSchema = StructType(dataSchema.fields ++ partitionSchema.fields) + protected var partitionFilters = Seq.empty[Expression] + protected var dataFilters = Seq.empty[Expression] + protected var pushedDataFilters = Array.empty[Filter] override def pruneColumns(requiredSchema: StructType): Unit = { // [SPARK-30107] While `requiredSchema` might have pruned nested columns, @@ -48,7 +59,7 @@ abstract class FileScanBuilder( StructType(fields) } - protected def readPartitionSchema(): StructType = { + def readPartitionSchema(): StructType = { val requiredNameSet = createRequiredNameSet() val fields = partitionSchema.fields.filter { field => val colName = PartitioningUtils.getColName(field, isCaseSensitive) @@ -57,6 +68,31 @@ abstract class FileScanBuilder( StructType(fields) } + override def pushFilters(filters: 
Seq[Expression]): Seq[Expression] = { + val (partitionFilters, dataFilters) = + DataSourceUtils.getPartitionFiltersAndDataFilters(partitionSchema, filters) + this.partitionFilters = partitionFilters + this.dataFilters = dataFilters + val translatedFilters = mutable.ArrayBuffer.empty[sources.Filter] + for (filterExpr <- dataFilters) { + val translated = DataSourceStrategy.translateFilter(filterExpr, true) + if (translated.nonEmpty) { + translatedFilters += translated.get + } + } + pushedDataFilters = pushDataFilters(translatedFilters.toArray) + dataFilters + } + + override def pushedFilters: Array[Filter] = pushedDataFilters + + /* + * Push down data filters to the file source, so the data filters can be evaluated there to + * reduce the size of the data to be read. By default, data filters are not pushed down. + * File source needs to implement this method to push down data filters. + */ + protected def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = Array.empty[Filter] + private def createRequiredNameSet(): Set[String] = requiredSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index acc6457418..7229488026 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -25,9 +25,7 @@ import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.expressions.FieldReference import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns} -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns} -import org.apache.spark.sql.execution.datasources.DataSourceStrategy -import org.apache.spark.sql.execution.datasources.PushableColumnWithoutNestedColumn +import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumnWithoutNestedColumn} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources import org.apache.spark.sql.types.StructType @@ -71,6 +69,9 @@ object PushDownUtils extends PredicateHelper { } (r.pushedFilters(), (untranslatableExprs ++ postScanFilters).toSeq) + case f: FileScanBuilder => + val postScanFilters = f.pushFilters(filters) + (f.pushedFilters, postScanFilters) case _ => (Nil, filters) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala index 3f77b2147f..cc3c146106 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.connector.read.PartitionReaderFactory import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.csv.CSVDataSource -import org.apache.spark.sql.execution.datasources.v2.{FileScan, TextBasedFileScan} +import 
org.apache.spark.sql.execution.datasources.v2.TextBasedFileScan import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -84,10 +84,6 @@ case class CSVScan( dataSchema, readDataSchema, readPartitionSchema, parsedOptions, pushedFilters) } - override def withFilters( - partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = - this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) - override def equals(obj: Any): Boolean = obj match { case c: CSVScan => super.equals(c) && dataSchema == c.dataSchema && options == c.options && equivalentFilters(pushedFilters, c.pushedFilters) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala index f7a79bf319..2b6edd4f35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2.csv import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.StructFilters -import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters} +import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder import org.apache.spark.sql.sources.Filter @@ -32,7 +32,7 @@ case class CSVScanBuilder( schema: StructType, dataSchema: StructType, options: CaseInsensitiveStringMap) - extends FileScanBuilder(sparkSession, fileIndex, dataSchema) with SupportsPushDownFilters { + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { override def build(): Scan = { CSVScan( @@ -42,17 +42,16 @@ case class CSVScanBuilder( readDataSchema(), readPartitionSchema(), options, - pushedFilters()) + pushedDataFilters, + partitionFilters, + dataFilters) } - private var _pushedFilters: Array[Filter] = Array.empty - - override def pushFilters(filters: Array[Filter]): Array[Filter] = { + override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = { if (sparkSession.sessionState.conf.csvFilterPushDown) { - _pushedFilters = StructFilters.pushedFilters(filters, dataSchema) + StructFilters.pushedFilters(dataFilters, dataSchema) + } else { + Array.empty[Filter] } - filters } - - override def pushedFilters(): Array[Filter] = _pushedFilters } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala index 29eb8bec9a..9ab367136f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.connector.read.PartitionReaderFactory import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.json.JsonDataSource -import org.apache.spark.sql.execution.datasources.v2.{FileScan, TextBasedFileScan} +import org.apache.spark.sql.execution.datasources.v2.TextBasedFileScan import org.apache.spark.sql.sources.Filter import 
org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -83,10 +83,6 @@ case class JsonScan( dataSchema, readDataSchema, readPartitionSchema, parsedOptions, pushedFilters) } - override def withFilters( - partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = - this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) - override def equals(obj: Any): Boolean = obj match { case j: JsonScan => super.equals(j) && dataSchema == j.dataSchema && options == j.options && equivalentFilters(pushedFilters, j.pushedFilters) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScanBuilder.scala index cf1204566d..c581617a4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScanBuilder.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2.json import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.StructFilters -import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters} +import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder import org.apache.spark.sql.sources.Filter @@ -31,7 +31,7 @@ class JsonScanBuilder ( schema: StructType, dataSchema: StructType, options: CaseInsensitiveStringMap) - extends FileScanBuilder(sparkSession, fileIndex, dataSchema) with SupportsPushDownFilters { + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { override def build(): Scan = { JsonScan( sparkSession, @@ -40,17 +40,16 @@ class JsonScanBuilder ( readDataSchema(), readPartitionSchema(), options, - pushedFilters()) + pushedDataFilters, + partitionFilters, + dataFilters) } - private var _pushedFilters: Array[Filter] = Array.empty - - override def pushFilters(filters: Array[Filter]): Array[Filter] = { + override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = { if (sparkSession.sessionState.conf.jsonFilterPushDown) { - _pushedFilters = StructFilters.pushedFilters(filters, dataSchema) + StructFilters.pushedFilters(dataFilters, dataSchema) + } else { + Array.empty[Filter] } - filters } - - override def pushedFilters(): Array[Filter] = _pushedFilters } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala index 8fa7f8dc41..7619e3c503 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala @@ -68,8 +68,4 @@ case class OrcScan( override def getMetaData(): Map[String, String] = { super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) } - - override def withFilters( - partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = - this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala index dc59526bb3..cfa396f548 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2.orc import scala.collection.JavaConverters._ import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters} +import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.orc.OrcFilters import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder @@ -35,7 +35,7 @@ case class OrcScanBuilder( schema: StructType, dataSchema: StructType, options: CaseInsensitiveStringMap) - extends FileScanBuilder(sparkSession, fileIndex, dataSchema) with SupportsPushDownFilters { + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { lazy val hadoopConf = { val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap // Hadoop Configurations are case sensitive. @@ -45,20 +45,17 @@ case class OrcScanBuilder( override protected val supportsNestedSchemaPruning: Boolean = true override def build(): Scan = { - OrcScan(sparkSession, hadoopConf, fileIndex, dataSchema, - readDataSchema(), readPartitionSchema(), options, pushedFilters()) + OrcScan(sparkSession, hadoopConf, fileIndex, dataSchema, readDataSchema(), + readPartitionSchema(), options, pushedDataFilters, partitionFilters, dataFilters) } - private var _pushedFilters: Array[Filter] = Array.empty - - override def pushFilters(filters: Array[Filter]): Array[Filter] = { + override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = { if (sparkSession.sessionState.conf.orcFilterPushDown) { val dataTypeMap = OrcFilters.getSearchableTypeMap( readDataSchema(), SQLConf.get.caseSensitiveAnalysis) - _pushedFilters = OrcFilters.convertibleFilters(dataTypeMap, filters).toArray + OrcFilters.convertibleFilters(dataTypeMap, dataFilters).toArray + } else { + Array.empty[Filter] } - filters } - - override def pushedFilters(): Array[Filter] = _pushedFilters } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala index 60573ba10c..e277e33484 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala @@ -105,8 +105,4 @@ case class ParquetScan( override def getMetaData(): Map[String, String] = { super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) } - - override def withFilters( - partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = - this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala index 4b3f4e7edc..ff5137e928 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2.parquet import 
scala.collection.JavaConverters._ import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters} +import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, SparkToParquetSchemaConverter} import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder @@ -35,7 +35,7 @@ case class ParquetScanBuilder( schema: StructType, dataSchema: StructType, options: CaseInsensitiveStringMap) - extends FileScanBuilder(sparkSession, fileIndex, dataSchema) with SupportsPushDownFilters { + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { lazy val hadoopConf = { val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap // Hadoop Configurations are case sensitive. @@ -63,17 +63,12 @@ case class ParquetScanBuilder( // The rebase mode doesn't matter here because the filters are used to determine // whether they is convertible. LegacyBehaviorPolicy.CORRECTED) - parquetFilters.convertibleFilters(this.filters).toArray + parquetFilters.convertibleFilters(pushedDataFilters).toArray } override protected val supportsNestedSchemaPruning: Boolean = true - private var filters: Array[Filter] = Array.empty - - override def pushFilters(filters: Array[Filter]): Array[Filter] = { - this.filters = filters - this.filters - } + override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = dataFilters // Note: for Parquet, the actual filter push down happens in [[ParquetPartitionReaderFactory]]. // It requires the Parquet physical schema to determine whether a filter is convertible. @@ -82,6 +77,6 @@ case class ParquetScanBuilder( override def build(): Scan = { ParquetScan(sparkSession, hadoopConf, fileIndex, dataSchema, readDataSchema(), - readPartitionSchema(), pushedParquetFilters, options) + readPartitionSchema(), pushedParquetFilters, options, partitionFilters, dataFilters) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala index a401d296d3..c7b0fec34b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.connector.read.PartitionReaderFactory import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.text.TextOptions -import org.apache.spark.sql.execution.datasources.v2.{FileScan, TextBasedFileScan} +import org.apache.spark.sql.execution.datasources.v2.TextBasedFileScan import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.SerializableConfiguration @@ -72,10 +72,6 @@ case class TextScan( readPartitionSchema, textOptions) } - override def withFilters( - partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = - this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) - override def equals(obj: Any): Boolean = obj match { case t: TextScan => super.equals(t) && options == t.options diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScanBuilder.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScanBuilder.scala index d929468b1b..0ebb098bfc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScanBuilder.scala @@ -33,6 +33,7 @@ case class TextScanBuilder( extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { override def build(): Scan = { - TextScan(sparkSession, fileIndex, dataSchema, readDataSchema(), readPartitionSchema(), options) + TextScan(sparkSession, fileIndex, dataSchema, readDataSchema(), readPartitionSchema(), options, + partitionFilters, dataFilters) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index e5c82603d8..f7f1d0b847 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -3023,16 +3023,14 @@ class JsonV2Suite extends JsonSuite { withSQLConf(SQLConf.JSON_FILTER_PUSHDOWN_ENABLED.key -> "true") { withTempPath { file => val scanBuilder = getBuilder(file.getCanonicalPath) - assert(scanBuilder.pushFilters(filters) === filters) - assert(scanBuilder.pushedFilters() === filters) + assert(scanBuilder.pushDataFilters(filters) === filters) } } withSQLConf(SQLConf.JSON_FILTER_PUSHDOWN_ENABLED.key -> "false") { withTempPath { file => val scanBuilder = getBuilder(file.getCanonicalPath) - assert(scanBuilder.pushFilters(filters) === filters) - assert(scanBuilder.pushedFilters() === Array.empty[sources.Filter]) + assert(scanBuilder.pushDataFilters(filters) === Array.empty[sources.Filter]) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 526dad91e5..02f10aa0af 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -447,7 +447,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel query.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = - "PushedAggregates: [SUM(SALARY), SUM(BONUS)" + "PushedAggregates: [SUM(SALARY), SUM(BONUS)]" checkKeywordsExistsInExplain(query, expected_plan_fragment) } checkAnswer(query, Seq(Row(47100.0))) @@ -465,4 +465,22 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel } checkAnswer(df2, Seq(Row(53000.00))) } + + test("scan with aggregate push-down: aggregate with partially pushed down filters" + + "will NOT push down") { + val df = spark.table("h2.test.employee") + val name = udf { (x: String) => x.matches("cat|dav|amy") } + val sub = udf { (x: String) => x.substring(0, 3) } + val query = df.select($"SALARY", $"BONUS", sub($"NAME").as("shortName")) + .filter("SALARY > 100") + .filter(name($"shortName")) + .agg(sum($"SALARY").as("sum_salary")) + query.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: []" + checkKeywordsExistsInExplain(query, expected_plan_fragment) + } + checkAnswer(query, Seq(Row(29000.0))) + } }