[SPARK-35779][SQL] Dynamic filtering for Data Source V2

### What changes were proposed in this pull request?

This PR implements the proposal described in the [design doc](https://docs.google.com/document/d/1RfFn2e9o_1uHJ8jFGsSakp-BZMizX1uRrJSybMe2a6M) for SPARK-35779.

### Why are the changes needed?

Spark supports dynamic partition pruning, which reuses parts of the query to skip unnecessary partitions of the larger table during joins. This optimization has proven beneficial for star-schema queries, which are common in the industry. Unfortunately, dynamic pruning is currently limited to partition pruning during joins and is only supported for built-in v1 sources. As more and more Spark users migrate to Data Source V2, it is important to generalize dynamic filtering and expose it to all v2 connectors.

Please see the design doc for more information on this effort.
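
For illustration, the kind of star-schema query this targets looks like the sketch below (modelled on the `fact`/`dim` tables used in the updated test suite; the table and column names are illustrative only):

```scala
// Minimal sketch: assumes a large `fact` table that is partitioned (or otherwise
// filterable) by `fid`, and a small `dim` table with a selective predicate on `d1`.
val df = spark.sql(
  """SELECT SUM(f.f1)
    |FROM fact f JOIN dim d ON f.fid = d.did
    |WHERE d.d1 = 6""".stripMargin)

// With dynamic filtering, the `did` values that survive `d.d1 = 6` become a runtime
// filter on `fid`, so partitions of `fact` that cannot match the join are never read.
df.collect()
```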

### Does this PR introduce _any_ user-facing change?

Yes, this PR adds `SupportsRuntimeFiltering`, a new optional mix-in interface for `Scan` in Data Source V2.
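
As a rough sketch, a v2 connector could opt in as follows (modelled on the `InMemoryBatchScan` test scan added in this PR; `MySplit`, `part_col`, and the reader factory are hypothetical connector details, not part of the new API):

```scala
import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference}
import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan, SupportsRuntimeFiltering}
import org.apache.spark.sql.sources.{Filter, In}
import org.apache.spark.sql.types.StructType

// Hypothetical split type that remembers which partition value it belongs to.
case class MySplit(partitionKey: String) extends InputPartition

class MyScan(schema: StructType, private var splits: Seq[MySplit])
  extends Scan with Batch with SupportsRuntimeFiltering {

  override def readSchema(): StructType = schema

  // Spark may call toBatch/planInputPartitions again after filter(), so keep this cheap.
  override def toBatch: Batch = this
  override def planInputPartitions(): Array[InputPartition] = splits.toArray
  override def createReaderFactory(): PartitionReaderFactory = ??? // connector-specific

  // Advertise the columns for which Spark may derive runtime predicates.
  override def filterAttributes(): Array[NamedReference] = Array(FieldReference("part_col"))

  // Drop initially planned splits that cannot match the runtime filters (ANDed together).
  override def filter(filters: Array[Filter]): Unit = filters.foreach {
    case In("part_col", values) =>
      val matchingKeys = values.map(_.toString).toSet
      splits = splits.filter(s => matchingKeys.contains(s.partitionKey))
    case _ => // ignore filters this source cannot use
  }
}
```

Spark plans the scan as usual and only calls `filter` (followed by another `toBatch`) when it can derive a runtime predicate for one of the reported filter attributes.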

### How was this patch tested?

This PR comes with tests.

Closes #32921 from aokolnychyi/dynamic-filtering-wip.

Authored-by: Anton Okolnychyi <aokolnychyi@apple.com>
Signed-off-by: Liang-Chi Hsieh <viirya@gmail.com>
parent a643076d4e
commit fceabe2372
14 changed files with 352 additions and 120 deletions

@@ -58,7 +58,7 @@ class AvroRowReaderSuite
       val df = spark.read.format("avro").load(dir.getCanonicalPath)
 
       val fileScan = df.queryExecution.executedPlan collectFirst {
-        case BatchScanExec(_, f: AvroScan) => f
+        case BatchScanExec(_, f: AvroScan, _) => f
       }
       val filePath = fileScan.get.fileIndex.inputFiles(0)
       val fileSize = new File(new URI(filePath)).length

@@ -2199,7 +2199,7 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper {
       }.isEmpty)
 
       val fileScan = df.queryExecution.executedPlan collectFirst {
-        case BatchScanExec(_, f: AvroScan) => f
+        case BatchScanExec(_, f: AvroScan, _) => f
       }
       assert(fileScan.nonEmpty)
       assert(fileScan.get.partitionFilters.nonEmpty)
@@ -2232,7 +2232,7 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper {
       assert(filterCondition.isDefined)
 
       val fileScan = df.queryExecution.executedPlan collectFirst {
-        case BatchScanExec(_, f: AvroScan) => f
+        case BatchScanExec(_, f: AvroScan, _) => f
       }
       assert(fileScan.nonEmpty)
       assert(fileScan.get.partitionFilters.isEmpty)
@@ -2313,7 +2313,7 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper {
        .where("value = 'a'")
 
      val fileScan = df.queryExecution.executedPlan collectFirst {
-       case BatchScanExec(_, f: AvroScan) => f
+       case BatchScanExec(_, f: AvroScan, _) => f
      }
      assert(fileScan.nonEmpty)
      if (filtersPushdown) {

@@ -65,6 +65,10 @@ public interface Scan {
    * exception, data sources must overwrite this method to provide an implementation, if the
    * {@link Table} that creates this scan returns {@link TableCapability#BATCH_READ} support in its
    * {@link Table#capabilities()}.
+   * <p>
+   * If the scan supports runtime filtering and implements {@link SupportsRuntimeFiltering},
+   * this method may be called multiple times. Therefore, implementations can cache some state
+   * to avoid planning the job twice.
    *
    * @throws UnsupportedOperationException
    */

@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.read;
+
+import org.apache.spark.annotation.Experimental;
+import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.sources.Filter;
+
+/**
+ * A mix-in interface for {@link Scan}. Data sources can implement this interface if they can
+ * filter initially planned {@link InputPartition}s using predicates Spark infers at runtime.
+ * <p>
+ * Note that Spark will push runtime filters only if they are beneficial.
+ *
+ * @since 3.2.0
+ */
+@Experimental
+public interface SupportsRuntimeFiltering extends Scan {
+  /**
+   * Returns attributes this scan can be filtered by at runtime.
+   * <p>
+   * Spark will call {@link #filter(Filter[])} if it can derive a runtime
+   * predicate for any of the filter attributes.
+   */
+  NamedReference[] filterAttributes();
+
+  /**
+   * Filters this scan using runtime filters.
+   * <p>
+   * The provided expressions must be interpreted as a set of filters that are ANDed together.
+   * Implementations may use the filters to prune initially planned {@link InputPartition}s.
+   * <p>
+   * If the scan also implements {@link SupportsReportPartitioning}, it must preserve
+   * the originally reported partitioning during runtime filtering. While applying runtime filters,
+   * the scan may detect that some {@link InputPartition}s have no matching data. It can omit
+   * such partitions entirely only if it does not report a specific partitioning. Otherwise,
+   * the scan can replace the initially planned {@link InputPartition}s that have no matching
+   * data with empty {@link InputPartition}s but must preserve the overall number of partitions.
+   * <p>
+   * Note that Spark will call {@link Scan#toBatch()} again after filtering the scan at runtime.
+   *
+   * @param filters data source filters used to filter the scan at runtime
+   */
+  void filter(Filter[] filters);
+}

@@ -20,6 +20,7 @@ package org.apache.spark.sql.connector.catalog
 import java.time.{Instant, ZoneId}
 import java.time.temporal.ChronoUnit
 import java.util
+import java.util.OptionalLong
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
@@ -245,21 +246,58 @@ class InMemoryTable(
     }
   }
 
-  class InMemoryBatchScan(
-      data: Array[InputPartition],
+  case class InMemoryStats(sizeInBytes: OptionalLong, numRows: OptionalLong) extends Statistics
+
+  case class InMemoryBatchScan(
+      var data: Seq[InputPartition],
       readSchema: StructType,
-      tableSchema: StructType) extends Scan with Batch {
-    override def readSchema(): StructType = readSchema
+      tableSchema: StructType)
+    extends Scan with Batch with SupportsRuntimeFiltering with SupportsReportStatistics {
 
     override def toBatch: Batch = this
 
-    override def planInputPartitions(): Array[InputPartition] = data
+    override def estimateStatistics(): Statistics = {
+      if (data.isEmpty) {
+        return InMemoryStats(OptionalLong.of(0L), OptionalLong.of(0L))
+      }
+
+      val inputPartitions = data.map(_.asInstanceOf[BufferedRows])
+      val numRows = inputPartitions.map(_.rows.size).sum
+      // we assume an average object header is 12 bytes
+      val objectHeaderSizeInBytes = 12L
+      val rowSizeInBytes = objectHeaderSizeInBytes + schema.defaultSize
+      val sizeInBytes = numRows * rowSizeInBytes
+      InMemoryStats(OptionalLong.of(sizeInBytes), OptionalLong.of(numRows))
+    }
+
+    override def planInputPartitions(): Array[InputPartition] = data.toArray
 
     override def createReaderFactory(): PartitionReaderFactory = {
       val metadataColumns = readSchema.map(_.name).filter(metadataColumnNames.contains)
       val nonMetadataColumns = readSchema.filterNot(f => metadataColumns.contains(f.name))
       new BufferedRowsReaderFactory(metadataColumns, nonMetadataColumns, tableSchema)
     }
+
+    override def filterAttributes(): Array[NamedReference] = {
+      val scanFields = readSchema.fields.map(_.name).toSet
+      partitioning.flatMap(_.references)
+        .filter(ref => scanFields.contains(ref.fieldNames.mkString(".")))
+    }
+
+    override def filter(filters: Array[Filter]): Unit = {
+      if (partitioning.length == 1) {
+        filters.foreach {
+          case In(attrName, values) if attrName == partitioning.head.name =>
+            val matchingKeys = values.map(_.toString).toSet
+            data = data.filter(partition => {
+              val key = partition.asInstanceOf[BufferedRows].key
+              matchingKeys.contains(key)
+            })
+          case _ => // skip
+        }
+      }
+    }
   }
 
   override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {

@@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
 import org.apache.spark.sql.connector.catalog.SupportsRead
 import org.apache.spark.sql.connector.catalog.TableCapability._
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan}
+import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan}
 import org.apache.spark.sql.execution.command._
 import org.apache.spark.sql.execution.streaming.StreamingRelation
 import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy
@@ -631,6 +631,25 @@ object DataSourceStrategy
     }
   }
 
+  /**
+   * Translates a runtime filter into a data source filter.
+   *
+   * Runtime filters usually contain a subquery that must be evaluated before the translation.
+   * If the underlying subquery hasn't completed yet, this method will throw an exception.
+   */
+  protected[sql] def translateRuntimeFilter(expr: Expression): Option[Filter] = expr match {
+    case in @ InSubqueryExec(e @ PushableColumnAndNestedColumn(name), _, _, _) =>
+      val values = in.values().getOrElse {
+        throw new IllegalStateException(s"Can't translate $in to source filter, no subquery result")
+      }
+
+      val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType)
+      Some(sources.In(name, values.map(toScala)))
+
+    case other =>
+      logWarning(s"Can't translate $other to source filter, unsupported expression")
+      None
+  }
+
   /**
    * Selects Catalyst predicate [[Expression]]s which are convertible into data source [[Filter]]s
    * and can be handled by `relation`.

@@ -17,38 +17,96 @@
 package org.apache.spark.sql.execution.datasources.v2
 
+import com.google.common.base.Objects
+
+import org.apache.spark.SparkException
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory, Scan}
+import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
+import org.apache.spark.sql.catalyst.util.truncatedString
+import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory, Scan, SupportsRuntimeFiltering}
+import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 
 /**
  * Physical plan node for scanning a batch of data from a data source v2.
  */
 case class BatchScanExec(
     output: Seq[AttributeReference],
-    @transient scan: Scan) extends DataSourceV2ScanExecBase {
+    @transient scan: Scan,
+    runtimeFilters: Seq[Expression]) extends DataSourceV2ScanExecBase {
 
   @transient lazy val batch = scan.toBatch
 
   // TODO: unify the equal/hashCode implementation for all data source v2 query plans.
   override def equals(other: Any): Boolean = other match {
-    case other: BatchScanExec => this.batch == other.batch
-    case _ => false
+    case other: BatchScanExec =>
+      this.batch == other.batch && this.runtimeFilters == other.runtimeFilters
+    case _ =>
+      false
   }
 
-  override def hashCode(): Int = batch.hashCode()
+  override def hashCode(): Int = Objects.hashCode(batch, runtimeFilters)
 
   @transient override lazy val partitions: Seq[InputPartition] = batch.planInputPartitions()
 
+  @transient private lazy val filteredPartitions: Seq[InputPartition] = {
+    val dataSourceFilters = runtimeFilters.flatMap {
+      case DynamicPruningExpression(e) => DataSourceStrategy.translateRuntimeFilter(e)
+      case _ => None
+    }
+
+    if (dataSourceFilters.nonEmpty) {
+      val originalPartitioning = outputPartitioning
+
+      // the cast is safe as runtime filters are only assigned if the scan can be filtered
+      val filterableScan = scan.asInstanceOf[SupportsRuntimeFiltering]
+      filterableScan.filter(dataSourceFilters.toArray)
+
+      // call toBatch again to get filtered partitions
+      val newPartitions = scan.toBatch.planInputPartitions()
+
+      originalPartitioning match {
+        case p: DataSourcePartitioning if p.numPartitions != newPartitions.size =>
+          throw new SparkException(
+            "Data source must have preserved the original partitioning during runtime filtering; " +
+            s"reported num partitions: ${p.numPartitions}, " +
+            s"num partitions after runtime filtering: ${newPartitions.size}")
+        case _ =>
+          // no validation is needed as the data source did not report any specific partitioning
+      }
+
+      newPartitions
+    } else {
+      partitions
+    }
+  }
+
   override lazy val readerFactory: PartitionReaderFactory = batch.createReaderFactory()
 
   override lazy val inputRDD: RDD[InternalRow] = {
-    new DataSourceRDD(sparkContext, partitions, readerFactory, supportsColumnar, customMetrics)
+    if (filteredPartitions.isEmpty && outputPartitioning == SinglePartition) {
+      // return an empty RDD with 1 partition if dynamic filtering removed the only split
+      sparkContext.parallelize(Array.empty[InternalRow], 1)
+    } else {
+      new DataSourceRDD(
+        sparkContext, filteredPartitions, readerFactory, supportsColumnar, customMetrics)
+    }
   }
 
   override def doCanonicalize(): BatchScanExec = {
-    this.copy(output = output.map(QueryPlan.normalizeExpressions(_, output)))
+    this.copy(
+      output = output.map(QueryPlan.normalizeExpressions(_, output)),
+      runtimeFilters = QueryPlan.normalizePredicates(
+        runtimeFilters.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral)),
+        output))
+  }
+
+  override def simpleString(maxFields: Int): String = {
+    val truncatedOutputString = truncatedString(output, "[", ", ", "]", maxFields)
+    val runtimeFiltersString = s"RuntimeFilters: ${runtimeFilters.mkString("[", ",", "]")}"
+    val result = s"$nodeName$truncatedOutputString ${scan.description()} $runtimeFiltersString"
+    redact(result)
   }
 }

@@ -21,7 +21,7 @@ import scala.collection.JavaConverters._
 
 import org.apache.spark.sql.{SparkSession, Strategy}
 import org.apache.spark.sql.catalyst.analysis.{ResolvedNamespace, ResolvedPartitionSpec, ResolvedTable}
-import org.apache.spark.sql.catalyst.expressions.{And, Attribute, Expression, NamedExpression, PredicateHelper, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, DynamicPruning, Expression, NamedExpression, PredicateHelper, SubqueryExpression}
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.util.toPrettySQL
@@ -114,8 +114,12 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
       // projection and filters were already pushed down in the optimizer.
       // this uses PhysicalOperation to get the projection and ensure that if the batch scan does
       // not support columnar, a projection is added to convert the rows to UnsafeRow.
-      val batchExec = BatchScanExec(relation.output, relation.scan)
-      withProjectAndFilter(project, filters, batchExec, !batchExec.supportsColumnar) :: Nil
+      val (runtimeFilters, postScanFilters) = filters.partition {
+        case _: DynamicPruning => true
+        case _ => false
+      }
+      val batchExec = BatchScanExec(relation.output, relation.scan, runtimeFilters)
+      withProjectAndFilter(project, postScanFilters, batchExec, !batchExec.supportsColumnar) :: Nil
 
     case PhysicalOperation(p, f, r: StreamingDataSourceV2Relation)
         if r.startOffset.isDefined && r.endOffset.isDefined =>

@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
 
 /**
  * Removes the filter nodes with dynamic pruning that were not pushed down to the scan.
@@ -42,6 +43,7 @@ object CleanupDynamicPruningFilters extends Rule[LogicalPlan] with PredicateHelp
         _.containsAnyPattern(DYNAMIC_PRUNING_EXPRESSION, DYNAMIC_PRUNING_SUBQUERY)) {
       // pass through anything that is pushed down into PhysicalOperation
       case p @ PhysicalOperation(_, _, LogicalRelation(_: HadoopFsRelation, _, _, _)) => p
+      case p @ PhysicalOperation(_, _, _: DataSourceV2ScanRelation) => p
       // remove any Filters with DynamicPruning that didn't get pushed down to PhysicalOperation.
       case f @ Filter(condition, _) =>
         val newCondition = condition.transformWithPruning(

@@ -23,17 +23,19 @@ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.connector.read.SupportsRuntimeFiltering
 import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
 
 /**
  * Dynamic partition pruning optimization is performed based on the type and
  * selectivity of the join operation. During query optimization, we insert a
- * predicate on the partitioned table using the filter from the other side of
+ * predicate on the filterable table using the filter from the other side of
  * the join and a custom wrapper called DynamicPruning.
  *
  * The basic mechanism for DPP inserts a duplicated subquery with the filter from the other side,
 * when the following conditions are met:
- *    (1) the table to prune is partitioned by the JOIN key
+ *    (1) the table to prune is filterable by the JOIN key
 *    (2) the join operation is one of the following types: INNER, LEFT SEMI,
 *    LEFT OUTER (partitioned on right), or RIGHT OUTER (partitioned on left)
 *
@@ -49,9 +51,12 @@ import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRela
 object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper with JoinSelectionHelper {
 
   /**
-   * Search the partitioned table scan for a given partition column in a logical plan
+   * Searches for a table scan that can be filtered for a given column in a logical plan.
+   *
+   * This method tries to find either a v1 partitioned scan for a given partition column or
+   * a v2 scan that supports runtime filtering on a given attribute.
    */
-  def getPartitionTableScan(a: Expression, plan: LogicalPlan): Option[LogicalRelation] = {
+  def getFilterableTableScan(a: Expression, plan: LogicalPlan): Option[LogicalPlan] = {
     val srcInfo: Option[(Expression, LogicalPlan)] = findExpressionAndTrackLineageDown(a, plan)
     srcInfo.flatMap {
       case (resExp, l: LogicalRelation) =>
@@ -66,6 +71,13 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper with Join
           }
          case _ => None
        }
+      case (resExp, r @ DataSourceV2ScanRelation(_, scan: SupportsRuntimeFiltering, _)) =>
+        val filterAttrs = V2ExpressionUtils.resolveRefs[Attribute](scan.filterAttributes, r)
+        if (resExp.references.subsetOf(AttributeSet(filterAttrs))) {
+          Some(r)
+        } else {
+          None
+        }
      case _ => None
    }
  }
@@ -85,7 +97,7 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper with Join
      filteringKey: Expression,
      filteringPlan: LogicalPlan,
      joinKeys: Seq[Expression],
-      partScan: LogicalRelation,
+      partScan: LogicalPlan,
      canBuildBroadcast: Boolean): LogicalPlan = {
    val reuseEnabled = conf.exchangeReuseEnabled
    val index = joinKeys.indexOf(filteringKey)
@@ -245,16 +257,16 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper with Join
          // there should be a partitioned table and a filter on the dimension table,
          // otherwise the pruning will not trigger
-          var partScan = getPartitionTableScan(l, left)
-          if (partScan.isDefined && canPruneLeft(joinType) &&
+          var filterableScan = getFilterableTableScan(l, left)
+          if (filterableScan.isDefined && canPruneLeft(joinType) &&
              hasPartitionPruningFilter(right)) {
-            newLeft = insertPredicate(l, newLeft, r, right, rightKeys, partScan.get,
+            newLeft = insertPredicate(l, newLeft, r, right, rightKeys, filterableScan.get,
              canBuildBroadcastRight(joinType))
          } else {
-            partScan = getPartitionTableScan(r, right)
-            if (partScan.isDefined && canPruneRight(joinType) &&
+            filterableScan = getFilterableTableScan(r, right)
+            if (filterableScan.isDefined && canPruneRight(joinType) &&
              hasPartitionPruningFilter(left) ) {
-              newRight = insertPredicate(r, newRight, l, left, leftKeys, partScan.get,
+              newRight = insertPredicate(r, newRight, l, left, leftKeys, filterableScan.get,
              canBuildBroadcastLeft(joinType))
            }
          }

@@ -22,8 +22,10 @@ import org.scalatest.GivenWhenThen
 
 import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, Expression}
 import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._
 import org.apache.spark.sql.catalyst.plans.ExistenceJoin
+import org.apache.spark.sql.connector.catalog.InMemoryTableCatalog
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive._
+import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ReusedExchangeExec}
 import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
 import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper}
@@ -44,9 +46,14 @@ abstract class DynamicPartitionPruningSuiteBase
 
   import testImplicits._
 
+  protected def initState(): Unit = {}
+  protected def runAnalyzeColumnCommands: Boolean = true
+
   override def beforeAll(): Unit = {
     super.beforeAll()
 
+    initState()
+
     val factData = Seq[(Int, Int, Int, Int)](
       (1000, 1, 1, 10),
       (1010, 2, 1, 10),
@@ -140,9 +147,11 @@ abstract class DynamicPartitionPruningSuiteBase
       .format(tableFormat)
       .saveAsTable("code_stats")
 
-    sql("ANALYZE TABLE fact_stats COMPUTE STATISTICS FOR COLUMNS store_id")
-    sql("ANALYZE TABLE dim_stats COMPUTE STATISTICS FOR COLUMNS store_id")
-    sql("ANALYZE TABLE code_stats COMPUTE STATISTICS FOR COLUMNS store_id")
+    if (runAnalyzeColumnCommands) {
+      sql("ANALYZE TABLE fact_stats COMPUTE STATISTICS FOR COLUMNS store_id")
+      sql("ANALYZE TABLE dim_stats COMPUTE STATISTICS FOR COLUMNS store_id")
+      sql("ANALYZE TABLE code_stats COMPUTE STATISTICS FOR COLUMNS store_id")
+    }
   }
 
   override def afterAll(): Unit = {
@@ -244,6 +253,9 @@ abstract class DynamicPartitionPruningSuiteBase
       case s: FileSourceScanExec => s.partitionFilters.collect {
         case d: DynamicPruningExpression => d.child
       }
+      case s: BatchScanExec => s.runtimeFilters.collect {
+        case d: DynamicPruningExpression => d.child
+      }
      case _ => Nil
    }
  }
@@ -314,86 +326,6 @@ abstract class DynamicPartitionPruningSuiteBase
     }
   }
 
-  /**
-   * Check the static scan metrics with and without DPP
-   */
-  test("static scan metrics",
-    DisableAdaptiveExecution("DPP in AQE must reuse broadcast")) {
-    withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true",
-      SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false",
-      SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false") {
-      withTable("fact", "dim") {
-        val numPartitions = 10
-
-        spark.range(10)
-          .map { x => Tuple3(x, x + 1, 0) }
-          .toDF("did", "d1", "d2")
-          .write
-          .format(tableFormat)
-          .mode("overwrite")
-          .saveAsTable("dim")
-
-        spark.range(100)
-          .map { x => Tuple2(x, x % numPartitions) }
-          .toDF("f1", "fid")
-          .write.partitionBy("fid")
-          .format(tableFormat)
-          .mode("overwrite")
-          .saveAsTable("fact")
-
-        def getFactScan(plan: SparkPlan): SparkPlan = {
-          val scanOption =
-            find(plan) {
-              case s: FileSourceScanExec =>
-                s.output.exists(_.find(_.argString(maxFields = 100).contains("fid")).isDefined)
-              case _ => false
-            }
-          assert(scanOption.isDefined)
-          scanOption.get
-        }
-
-        // No dynamic partition pruning, so no static metrics
-        // All files in fact table are scanned
-        val df1 = sql("SELECT sum(f1) FROM fact")
-        df1.collect()
-        val scan1 = getFactScan(df1.queryExecution.executedPlan)
-        assert(!scan1.metrics.contains("staticFilesNum"))
-        assert(!scan1.metrics.contains("staticFilesSize"))
-        val allFilesNum = scan1.metrics("numFiles").value
-        val allFilesSize = scan1.metrics("filesSize").value
-        assert(scan1.metrics("numPartitions").value === numPartitions)
-        assert(scan1.metrics("pruningTime").value === -1)
-
-        // No dynamic partition pruning, so no static metrics
-        // Only files from fid = 5 partition are scanned
-        val df2 = sql("SELECT sum(f1) FROM fact WHERE fid = 5")
-        df2.collect()
-        val scan2 = getFactScan(df2.queryExecution.executedPlan)
-        assert(!scan2.metrics.contains("staticFilesNum"))
-        assert(!scan2.metrics.contains("staticFilesSize"))
-        val partFilesNum = scan2.metrics("numFiles").value
-        val partFilesSize = scan2.metrics("filesSize").value
-        assert(0 < partFilesNum && partFilesNum < allFilesNum)
-        assert(0 < partFilesSize && partFilesSize < allFilesSize)
-        assert(scan2.metrics("numPartitions").value === 1)
-        assert(scan2.metrics("pruningTime").value === -1)
-
-        // Dynamic partition pruning is used
-        // Static metrics are as-if reading the whole fact table
-        // "Regular" metrics are as-if reading only the "fid = 5" partition
-        val df3 = sql("SELECT sum(f1) FROM fact, dim WHERE fid = did AND d1 = 6")
-        df3.collect()
-        val scan3 = getFactScan(df3.queryExecution.executedPlan)
-        assert(scan3.metrics("staticFilesNum").value == allFilesNum)
-        assert(scan3.metrics("staticFilesSize").value == allFilesSize)
-        assert(scan3.metrics("numFiles").value == partFilesNum)
-        assert(scan3.metrics("filesSize").value == partFilesSize)
-        assert(scan3.metrics("numPartitions").value === 1)
-        assert(scan3.metrics("pruningTime").value !== -1)
-      }
-    }
-  }
-
   test("DPP should not be rewritten as an existential join") {
     withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true",
       SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "1.5",
@@ -1567,8 +1499,111 @@ abstract class DynamicPartitionPruningSuiteBase
   }
 }
 
-class DynamicPartitionPruningSuiteAEOff extends DynamicPartitionPruningSuiteBase
+abstract class DynamicPartitionPruningV1Suite extends DynamicPartitionPruningSuiteBase {
+
+  import testImplicits._
+
+  /**
+   * Check the static scan metrics with and without DPP
+   */
+  test("static scan metrics",
+    DisableAdaptiveExecution("DPP in AQE must reuse broadcast")) {
+    withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true",
+      SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false",
+      SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false") {
+      withTable("fact", "dim") {
+        val numPartitions = 10
+
+        spark.range(10)
+          .map { x => Tuple3(x, x + 1, 0) }
+          .toDF("did", "d1", "d2")
+          .write
+          .format(tableFormat)
+          .mode("overwrite")
+          .saveAsTable("dim")
+
+        spark.range(100)
+          .map { x => Tuple2(x, x % numPartitions) }
+          .toDF("f1", "fid")
+          .write.partitionBy("fid")
+          .format(tableFormat)
+          .mode("overwrite")
+          .saveAsTable("fact")
+
+        def getFactScan(plan: SparkPlan): SparkPlan = {
+          val scanOption =
+            find(plan) {
+              case s: FileSourceScanExec =>
+                s.output.exists(_.find(_.argString(maxFields = 100).contains("fid")).isDefined)
+              case s: BatchScanExec =>
+                // we use f1 col for v2 tables due to schema pruning
+                s.output.exists(_.find(_.argString(maxFields = 100).contains("f1")).isDefined)
+              case _ => false
+            }
+          assert(scanOption.isDefined)
+          scanOption.get
+        }
+
+        // No dynamic partition pruning, so no static metrics
+        // All files in fact table are scanned
+        val df1 = sql("SELECT sum(f1) FROM fact")
+        df1.collect()
+        val scan1 = getFactScan(df1.queryExecution.executedPlan)
+        assert(!scan1.metrics.contains("staticFilesNum"))
+        assert(!scan1.metrics.contains("staticFilesSize"))
+        val allFilesNum = scan1.metrics("numFiles").value
+        val allFilesSize = scan1.metrics("filesSize").value
+        assert(scan1.metrics("numPartitions").value === numPartitions)
+        assert(scan1.metrics("pruningTime").value === -1)
+
+        // No dynamic partition pruning, so no static metrics
+        // Only files from fid = 5 partition are scanned
+        val df2 = sql("SELECT sum(f1) FROM fact WHERE fid = 5")
+        df2.collect()
+        val scan2 = getFactScan(df2.queryExecution.executedPlan)
+        assert(!scan2.metrics.contains("staticFilesNum"))
+        assert(!scan2.metrics.contains("staticFilesSize"))
+        val partFilesNum = scan2.metrics("numFiles").value
+        val partFilesSize = scan2.metrics("filesSize").value
+        assert(0 < partFilesNum && partFilesNum < allFilesNum)
+        assert(0 < partFilesSize && partFilesSize < allFilesSize)
+        assert(scan2.metrics("numPartitions").value === 1)
+        assert(scan2.metrics("pruningTime").value === -1)
+
+        // Dynamic partition pruning is used
+        // Static metrics are as-if reading the whole fact table
+        // "Regular" metrics are as-if reading only the "fid = 5" partition
+        val df3 = sql("SELECT sum(f1) FROM fact, dim WHERE fid = did AND d1 = 6")
+        df3.collect()
+        val scan3 = getFactScan(df3.queryExecution.executedPlan)
+        assert(scan3.metrics("staticFilesNum").value == allFilesNum)
+        assert(scan3.metrics("staticFilesSize").value == allFilesSize)
+        assert(scan3.metrics("numFiles").value == partFilesNum)
+        assert(scan3.metrics("filesSize").value == partFilesSize)
+        assert(scan3.metrics("numPartitions").value === 1)
+        assert(scan3.metrics("pruningTime").value !== -1)
+      }
+    }
+  }
+}
+
+class DynamicPartitionPruningV1SuiteAEOff extends DynamicPartitionPruningV1Suite
   with DisableAdaptiveExecutionSuite
 
-class DynamicPartitionPruningSuiteAEOn extends DynamicPartitionPruningSuiteBase
+class DynamicPartitionPruningV1SuiteAEOn extends DynamicPartitionPruningV1Suite
+  with EnableAdaptiveExecutionSuite
+
+abstract class DynamicPartitionPruningV2Suite extends DynamicPartitionPruningSuiteBase {
+  override protected def runAnalyzeColumnCommands: Boolean = false
+
+  override protected def initState(): Unit = {
+    spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)
+    spark.conf.set("spark.sql.defaultCatalog", "testcat")
+  }
+}
+
+class DynamicPartitionPruningV2SuiteAEOff extends DynamicPartitionPruningV2Suite
+  with DisableAdaptiveExecutionSuite
+
+class DynamicPartitionPruningV2SuiteAEOn extends DynamicPartitionPruningV2Suite
   with EnableAdaptiveExecutionSuite

@@ -759,7 +759,7 @@ class FileBasedDataSourceSuite extends QueryTest
         }.isEmpty)
 
         val fileScan = df.queryExecution.executedPlan collectFirst {
-          case BatchScanExec(_, f: FileScan) => f
+          case BatchScanExec(_, f: FileScan, _) => f
         }
         assert(fileScan.nonEmpty)
         assert(fileScan.get.partitionFilters.nonEmpty)
@@ -799,7 +799,7 @@ class FileBasedDataSourceSuite extends QueryTest
         assert(filterCondition.isDefined)
 
         val fileScan = df.queryExecution.executedPlan collectFirst {
-          case BatchScanExec(_, f: FileScan) => f
+          case BatchScanExec(_, f: FileScan, _) => f
         }
         assert(fileScan.nonEmpty)
         assert(fileScan.get.partitionFilters.isEmpty)

@@ -38,7 +38,7 @@ class OrcV2SchemaPruningSuite extends SchemaPruningSuite with AdaptiveSparkPlanH
   override def checkScanSchemata(df: DataFrame, expectedSchemaCatalogStrings: String*): Unit = {
     val fileSourceScanSchemata =
       collect(df.queryExecution.executedPlan) {
-        case BatchScanExec(_, scan: OrcScan) => scan.readDataSchema
+        case BatchScanExec(_, scan: OrcScan, _) => scan.readDataSchema
       }
     assert(fileSourceScanSchemata.size === expectedSchemaCatalogStrings.size,
       s"Found ${fileSourceScanSchemata.size} file sources in dataframe, " +

@@ -823,7 +823,7 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils
     val oldCount = statusStore.executionsList().size
 
     val schema = new StructType().add("i", "int").add("j", "int")
-    val physicalPlan = BatchScanExec(schema.toAttributes, new CustomMetricScanBuilder())
+    val physicalPlan = BatchScanExec(schema.toAttributes, new CustomMetricScanBuilder(), Seq.empty)
     val dummyQueryExecution = new QueryExecution(spark, LocalRelation()) {
       override lazy val sparkPlan = physicalPlan
       override lazy val executedPlan = physicalPlan