diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala index a2521f8b98..9bf2e569f7 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala @@ -58,7 +58,7 @@ class AvroRowReaderSuite val df = spark.read.format("avro").load(dir.getCanonicalPath) val fileScan = df.queryExecution.executedPlan collectFirst { - case BatchScanExec(_, f: AvroScan) => f + case BatchScanExec(_, f: AvroScan, _) => f } val filePath = fileScan.get.fileIndex.inputFiles(0) val fileSize = new File(new URI(filePath)).length diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index d68460b57b..ffad851132 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -2199,7 +2199,7 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper { }.isEmpty) val fileScan = df.queryExecution.executedPlan collectFirst { - case BatchScanExec(_, f: AvroScan) => f + case BatchScanExec(_, f: AvroScan, _) => f } assert(fileScan.nonEmpty) assert(fileScan.get.partitionFilters.nonEmpty) @@ -2232,7 +2232,7 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper { assert(filterCondition.isDefined) val fileScan = df.queryExecution.executedPlan collectFirst { - case BatchScanExec(_, f: AvroScan) => f + case BatchScanExec(_, f: AvroScan, _) => f } assert(fileScan.nonEmpty) assert(fileScan.get.partitionFilters.isEmpty) @@ -2313,7 +2313,7 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper { .where("value = 'a'") val fileScan = df.queryExecution.executedPlan collectFirst { - case BatchScanExec(_, f: AvroScan) => f + case BatchScanExec(_, f: AvroScan, _) => f } assert(fileScan.nonEmpty) if (filtersPushdown) { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java index 0c009f5c56..78684b3705 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java @@ -65,6 +65,10 @@ public interface Scan { * exception, data sources must overwrite this method to provide an implementation, if the * {@link Table} that creates this scan returns {@link TableCapability#BATCH_READ} support in its * {@link Table#capabilities()}. + *
<p>
+ * If the scan supports runtime filtering and implements {@link SupportsRuntimeFiltering}, + * this method may be called multiple times. Therefore, implementations can cache some state + * to avoid planning the job twice. * * @throws UnsupportedOperationException */ diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeFiltering.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeFiltering.java new file mode 100644 index 0000000000..65d029dc30 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeFiltering.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.read; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.sources.Filter; + +/** + * A mix-in interface for {@link Scan}. Data sources can implement this interface if they can + * filter initially planned {@link InputPartition}s using predicates Spark infers at runtime. + *
<p>
+ * Note that Spark will push runtime filters only if they are beneficial. + * + * @since 3.2.0 + */ +@Experimental +public interface SupportsRuntimeFiltering extends Scan { + /** + * Returns attributes this scan can be filtered by at runtime. + *
<p>
+ * Spark will call {@link #filter(Filter[])} if it can derive a runtime + * predicate for any of the filter attributes. + */ + NamedReference[] filterAttributes(); + + /** + * Filters this scan using runtime filters. + *
<p>
+ * The provided expressions must be interpreted as a set of filters that are ANDed together. + * Implementations may use the filters to prune initially planned {@link InputPartition}s. + *
<p>
+ * If the scan also implements {@link SupportsReportPartitioning}, it must preserve + * the originally reported partitioning during runtime filtering. While applying runtime filters, + * the scan may detect that some {@link InputPartition}s have no matching data. It can omit + * such partitions entirely only if it does not report a specific partitioning. Otherwise, + * the scan can replace the initially planned {@link InputPartition}s that have no matching + * data with empty {@link InputPartition}s but must preserve the overall number of partitions. + *
<p>
+ * Note that Spark will call {@link Scan#toBatch()} again after filtering the scan at runtime. + * + * @param filters data source filters used to filter the scan at runtime + */ + void filter(Filter[] filters); +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala index 9293dbcdce..6f7f761146 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.connector.catalog import java.time.{Instant, ZoneId} import java.time.temporal.ChronoUnit import java.util +import java.util.OptionalLong import scala.collection.JavaConverters._ import scala.collection.mutable @@ -245,21 +246,58 @@ class InMemoryTable( } } - class InMemoryBatchScan( - data: Array[InputPartition], + case class InMemoryStats(sizeInBytes: OptionalLong, numRows: OptionalLong) extends Statistics + + case class InMemoryBatchScan( + var data: Seq[InputPartition], readSchema: StructType, - tableSchema: StructType) extends Scan with Batch { - override def readSchema(): StructType = readSchema + tableSchema: StructType) + extends Scan with Batch with SupportsRuntimeFiltering with SupportsReportStatistics { override def toBatch: Batch = this - override def planInputPartitions(): Array[InputPartition] = data + override def estimateStatistics(): Statistics = { + if (data.isEmpty) { + return InMemoryStats(OptionalLong.of(0L), OptionalLong.of(0L)) + } + + val inputPartitions = data.map(_.asInstanceOf[BufferedRows]) + val numRows = inputPartitions.map(_.rows.size).sum + // we assume an average object header is 12 bytes + val objectHeaderSizeInBytes = 12L + val rowSizeInBytes = objectHeaderSizeInBytes + schema.defaultSize + val sizeInBytes = numRows * rowSizeInBytes + InMemoryStats(OptionalLong.of(sizeInBytes), OptionalLong.of(numRows)) + } + + override def planInputPartitions(): Array[InputPartition] = data.toArray override def createReaderFactory(): PartitionReaderFactory = { val metadataColumns = readSchema.map(_.name).filter(metadataColumnNames.contains) val nonMetadataColumns = readSchema.filterNot(f => metadataColumns.contains(f.name)) new BufferedRowsReaderFactory(metadataColumns, nonMetadataColumns, tableSchema) } + + override def filterAttributes(): Array[NamedReference] = { + val scanFields = readSchema.fields.map(_.name).toSet + partitioning.flatMap(_.references) + .filter(ref => scanFields.contains(ref.fieldNames.mkString("."))) + } + + override def filter(filters: Array[Filter]): Unit = { + if (partitioning.length == 1) { + filters.foreach { + case In(attrName, values) if attrName == partitioning.head.name => + val matchingKeys = values.map(_.toString).toSet + data = data.filter(partition => { + val key = partition.asInstanceOf[BufferedRows].key + matchingKeys.contains(key) + }) + + case _ => // skip + } + } + } } override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 9ecdf97e55..9e33723e5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -40,7 
+40,7 @@ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.connector.catalog.SupportsRead import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan} +import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.streaming.StreamingRelation import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy @@ -631,6 +631,25 @@ object DataSourceStrategy } } + /** + * Translates a runtime filter into a data source filter. + * + * Runtime filters usually contain a subquery that must be evaluated before the translation. + * If the underlying subquery hasn't completed yet, this method will throw an exception. + */ + protected[sql] def translateRuntimeFilter(expr: Expression): Option[Filter] = expr match { + case in @ InSubqueryExec(e @ PushableColumnAndNestedColumn(name), _, _, _) => + val values = in.values().getOrElse { + throw new IllegalStateException(s"Can't translate $in to source filter, no subquery result") + } + val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType) + Some(sources.In(name, values.map(toScala))) + + case other => + logWarning(s"Can't translate $other to source filter, unsupported expression") + None + } + /** * Selects Catalyst predicate [[Expression]]s which are convertible into data source [[Filter]]s * and can be handled by `relation`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index 1987c9e63a..937d18d9eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -17,38 +17,96 @@ package org.apache.spark.sql.execution.datasources.v2 +import com.google.common.base.Objects + +import org.apache.spark.SparkException import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory, Scan} +import org.apache.spark.sql.catalyst.plans.physical.SinglePartition +import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory, Scan, SupportsRuntimeFiltering} +import org.apache.spark.sql.execution.datasources.DataSourceStrategy /** * Physical plan node for scanning a batch of data from a data source v2. */ case class BatchScanExec( output: Seq[AttributeReference], - @transient scan: Scan) extends DataSourceV2ScanExecBase { + @transient scan: Scan, + runtimeFilters: Seq[Expression]) extends DataSourceV2ScanExecBase { @transient lazy val batch = scan.toBatch // TODO: unify the equal/hashCode implementation for all data source v2 query plans. 
override def equals(other: Any): Boolean = other match { - case other: BatchScanExec => this.batch == other.batch - case _ => false + case other: BatchScanExec => + this.batch == other.batch && this.runtimeFilters == other.runtimeFilters + case _ => + false } - override def hashCode(): Int = batch.hashCode() + override def hashCode(): Int = Objects.hashCode(batch, runtimeFilters) @transient override lazy val partitions: Seq[InputPartition] = batch.planInputPartitions() + @transient private lazy val filteredPartitions: Seq[InputPartition] = { + val dataSourceFilters = runtimeFilters.flatMap { + case DynamicPruningExpression(e) => DataSourceStrategy.translateRuntimeFilter(e) + case _ => None + } + + if (dataSourceFilters.nonEmpty) { + val originalPartitioning = outputPartitioning + + // the cast is safe as runtime filters are only assigned if the scan can be filtered + val filterableScan = scan.asInstanceOf[SupportsRuntimeFiltering] + filterableScan.filter(dataSourceFilters.toArray) + + // call toBatch again to get filtered partitions + val newPartitions = scan.toBatch.planInputPartitions() + + originalPartitioning match { + case p: DataSourcePartitioning if p.numPartitions != newPartitions.size => + throw new SparkException( + "Data source must have preserved the original partitioning during runtime filtering; " + + s"reported num partitions: ${p.numPartitions}, " + + s"num partitions after runtime filtering: ${newPartitions.size}") + case _ => + // no validation is needed as the data source did not report any specific partitioning + } + + newPartitions + } else { + partitions + } + } + override lazy val readerFactory: PartitionReaderFactory = batch.createReaderFactory() override lazy val inputRDD: RDD[InternalRow] = { - new DataSourceRDD(sparkContext, partitions, readerFactory, supportsColumnar, customMetrics) + if (filteredPartitions.isEmpty && outputPartitioning == SinglePartition) { + // return an empty RDD with 1 partition if dynamic filtering removed the only split + sparkContext.parallelize(Array.empty[InternalRow], 1) + } else { + new DataSourceRDD( + sparkContext, filteredPartitions, readerFactory, supportsColumnar, customMetrics) + } } override def doCanonicalize(): BatchScanExec = { - this.copy(output = output.map(QueryPlan.normalizeExpressions(_, output))) + this.copy( + output = output.map(QueryPlan.normalizeExpressions(_, output)), + runtimeFilters = QueryPlan.normalizePredicates( + runtimeFilters.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral)), + output)) + } + + override def simpleString(maxFields: Int): String = { + val truncatedOutputString = truncatedString(output, "[", ", ", "]", maxFields) + val runtimeFiltersString = s"RuntimeFilters: ${runtimeFilters.mkString("[", ",", "]")}" + val result = s"$nodeName$truncatedOutputString ${scan.description()} $runtimeFiltersString" + redact(result) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 7a01488f25..7be13791ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.{SparkSession, Strategy} import org.apache.spark.sql.catalyst.analysis.{ResolvedNamespace, ResolvedPartitionSpec, ResolvedTable} -import 
org.apache.spark.sql.catalyst.expressions.{And, Attribute, Expression, NamedExpression, PredicateHelper, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, DynamicPruning, Expression, NamedExpression, PredicateHelper, SubqueryExpression} import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.toPrettySQL @@ -114,8 +114,12 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat // projection and filters were already pushed down in the optimizer. // this uses PhysicalOperation to get the projection and ensure that if the batch scan does // not support columnar, a projection is added to convert the rows to UnsafeRow. - val batchExec = BatchScanExec(relation.output, relation.scan) - withProjectAndFilter(project, filters, batchExec, !batchExec.supportsColumnar) :: Nil + val (runtimeFilters, postScanFilters) = filters.partition { + case _: DynamicPruning => true + case _ => false + } + val batchExec = BatchScanExec(relation.output, relation.scan, runtimeFilters) + withProjectAndFilter(project, postScanFilters, batchExec, !batchExec.supportsColumnar) :: Nil case PhysicalOperation(p, f, r: StreamingDataSourceV2Relation) if r.startOffset.isDefined && r.endOffset.isDefined => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/CleanupDynamicPruningFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/CleanupDynamicPruningFilters.scala index 29bed4d18e..bcaed524a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/CleanupDynamicPruningFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/CleanupDynamicPruningFilters.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation /** * Removes the filter nodes with dynamic pruning that were not pushed down to the scan. @@ -42,6 +43,7 @@ object CleanupDynamicPruningFilters extends Rule[LogicalPlan] with PredicateHelp _.containsAnyPattern(DYNAMIC_PRUNING_EXPRESSION, DYNAMIC_PRUNING_SUBQUERY)) { // pass through anything that is pushed down into PhysicalOperation case p @ PhysicalOperation(_, _, LogicalRelation(_: HadoopFsRelation, _, _, _)) => p + case p @ PhysicalOperation(_, _, _: DataSourceV2ScanRelation) => p // remove any Filters with DynamicPruning that didn't get pushed down to PhysicalOperation. 
case f @ Filter(condition, _) => val newCondition = condition.transformWithPruning( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala index 1d828b1396..3014e7a4d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala @@ -23,17 +23,19 @@ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.read.SupportsRuntimeFiltering import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation /** * Dynamic partition pruning optimization is performed based on the type and * selectivity of the join operation. During query optimization, we insert a - * predicate on the partitioned table using the filter from the other side of + * predicate on the filterable table using the filter from the other side of * the join and a custom wrapper called DynamicPruning. * * The basic mechanism for DPP inserts a duplicated subquery with the filter from the other side, * when the following conditions are met: - * (1) the table to prune is partitioned by the JOIN key + * (1) the table to prune is filterable by the JOIN key * (2) the join operation is one of the following types: INNER, LEFT SEMI, * LEFT OUTER (partitioned on right), or RIGHT OUTER (partitioned on left) * @@ -49,9 +51,12 @@ import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRela object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper with JoinSelectionHelper { /** - * Search the partitioned table scan for a given partition column in a logical plan + * Searches for a table scan that can be filtered for a given column in a logical plan. + * + * This method tries to find either a v1 partitioned scan for a given partition column or + * a v2 scan that supports runtime filtering on a given attribute.
*/ - def getPartitionTableScan(a: Expression, plan: LogicalPlan): Option[LogicalRelation] = { + def getFilterableTableScan(a: Expression, plan: LogicalPlan): Option[LogicalPlan] = { val srcInfo: Option[(Expression, LogicalPlan)] = findExpressionAndTrackLineageDown(a, plan) srcInfo.flatMap { case (resExp, l: LogicalRelation) => @@ -66,6 +71,13 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper with Join } case _ => None } + case (resExp, r @ DataSourceV2ScanRelation(_, scan: SupportsRuntimeFiltering, _)) => + val filterAttrs = V2ExpressionUtils.resolveRefs[Attribute](scan.filterAttributes, r) + if (resExp.references.subsetOf(AttributeSet(filterAttrs))) { + Some(r) + } else { + None + } case _ => None } } @@ -85,7 +97,7 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper with Join filteringKey: Expression, filteringPlan: LogicalPlan, joinKeys: Seq[Expression], - partScan: LogicalRelation, + partScan: LogicalPlan, canBuildBroadcast: Boolean): LogicalPlan = { val reuseEnabled = conf.exchangeReuseEnabled val index = joinKeys.indexOf(filteringKey) @@ -245,16 +257,16 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper with Join // there should be a partitioned table and a filter on the dimension table, // otherwise the pruning will not trigger - var partScan = getPartitionTableScan(l, left) - if (partScan.isDefined && canPruneLeft(joinType) && + var filterableScan = getFilterableTableScan(l, left) + if (filterableScan.isDefined && canPruneLeft(joinType) && hasPartitionPruningFilter(right)) { - newLeft = insertPredicate(l, newLeft, r, right, rightKeys, partScan.get, + newLeft = insertPredicate(l, newLeft, r, right, rightKeys, filterableScan.get, canBuildBroadcastRight(joinType)) } else { - partScan = getPartitionTableScan(r, right) - if (partScan.isDefined && canPruneRight(joinType) && + filterableScan = getFilterableTableScan(r, right) + if (filterableScan.isDefined && canPruneRight(joinType) && hasPartitionPruningFilter(left) ) { - newRight = insertPredicate(r, newRight, l, left, leftKeys, partScan.get, + newRight = insertPredicate(r, newRight, l, left, leftKeys, filterableScan.get, canBuildBroadcastLeft(joinType)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala index 1f823a0734..b175701ac8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala @@ -22,8 +22,10 @@ import org.scalatest.GivenWhenThen import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, Expression} import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._ import org.apache.spark.sql.catalyst.plans.ExistenceJoin +import org.apache.spark.sql.connector.catalog.InMemoryTableCatalog import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive._ +import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ReusedExchangeExec} import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper} @@ -44,9 +46,14 @@ abstract class DynamicPartitionPruningSuiteBase import testImplicits._ + protected def initState(): Unit = {} + protected def runAnalyzeColumnCommands: Boolean = true + override def 
beforeAll(): Unit = { super.beforeAll() + initState() + val factData = Seq[(Int, Int, Int, Int)]( (1000, 1, 1, 10), (1010, 2, 1, 10), @@ -140,9 +147,11 @@ abstract class DynamicPartitionPruningSuiteBase .format(tableFormat) .saveAsTable("code_stats") - sql("ANALYZE TABLE fact_stats COMPUTE STATISTICS FOR COLUMNS store_id") - sql("ANALYZE TABLE dim_stats COMPUTE STATISTICS FOR COLUMNS store_id") - sql("ANALYZE TABLE code_stats COMPUTE STATISTICS FOR COLUMNS store_id") + if (runAnalyzeColumnCommands) { + sql("ANALYZE TABLE fact_stats COMPUTE STATISTICS FOR COLUMNS store_id") + sql("ANALYZE TABLE dim_stats COMPUTE STATISTICS FOR COLUMNS store_id") + sql("ANALYZE TABLE code_stats COMPUTE STATISTICS FOR COLUMNS store_id") + } } override def afterAll(): Unit = { @@ -244,6 +253,9 @@ abstract class DynamicPartitionPruningSuiteBase case s: FileSourceScanExec => s.partitionFilters.collect { case d: DynamicPruningExpression => d.child } + case s: BatchScanExec => s.runtimeFilters.collect { + case d: DynamicPruningExpression => d.child + } case _ => Nil } } @@ -314,86 +326,6 @@ abstract class DynamicPartitionPruningSuiteBase } } - /** - * Check the static scan metrics with and without DPP - */ - test("static scan metrics", - DisableAdaptiveExecution("DPP in AQE must reuse broadcast")) { - withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false", - SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false") { - withTable("fact", "dim") { - val numPartitions = 10 - - spark.range(10) - .map { x => Tuple3(x, x + 1, 0) } - .toDF("did", "d1", "d2") - .write - .format(tableFormat) - .mode("overwrite") - .saveAsTable("dim") - - spark.range(100) - .map { x => Tuple2(x, x % numPartitions) } - .toDF("f1", "fid") - .write.partitionBy("fid") - .format(tableFormat) - .mode("overwrite") - .saveAsTable("fact") - - def getFactScan(plan: SparkPlan): SparkPlan = { - val scanOption = - find(plan) { - case s: FileSourceScanExec => - s.output.exists(_.find(_.argString(maxFields = 100).contains("fid")).isDefined) - case _ => false - } - assert(scanOption.isDefined) - scanOption.get - } - - // No dynamic partition pruning, so no static metrics - // All files in fact table are scanned - val df1 = sql("SELECT sum(f1) FROM fact") - df1.collect() - val scan1 = getFactScan(df1.queryExecution.executedPlan) - assert(!scan1.metrics.contains("staticFilesNum")) - assert(!scan1.metrics.contains("staticFilesSize")) - val allFilesNum = scan1.metrics("numFiles").value - val allFilesSize = scan1.metrics("filesSize").value - assert(scan1.metrics("numPartitions").value === numPartitions) - assert(scan1.metrics("pruningTime").value === -1) - - // No dynamic partition pruning, so no static metrics - // Only files from fid = 5 partition are scanned - val df2 = sql("SELECT sum(f1) FROM fact WHERE fid = 5") - df2.collect() - val scan2 = getFactScan(df2.queryExecution.executedPlan) - assert(!scan2.metrics.contains("staticFilesNum")) - assert(!scan2.metrics.contains("staticFilesSize")) - val partFilesNum = scan2.metrics("numFiles").value - val partFilesSize = scan2.metrics("filesSize").value - assert(0 < partFilesNum && partFilesNum < allFilesNum) - assert(0 < partFilesSize && partFilesSize < allFilesSize) - assert(scan2.metrics("numPartitions").value === 1) - assert(scan2.metrics("pruningTime").value === -1) - - // Dynamic partition pruning is used - // Static metrics are as-if reading the whole fact table - // "Regular" metrics are as-if reading only the "fid = 5" partition - 
val df3 = sql("SELECT sum(f1) FROM fact, dim WHERE fid = did AND d1 = 6") - df3.collect() - val scan3 = getFactScan(df3.queryExecution.executedPlan) - assert(scan3.metrics("staticFilesNum").value == allFilesNum) - assert(scan3.metrics("staticFilesSize").value == allFilesSize) - assert(scan3.metrics("numFiles").value == partFilesNum) - assert(scan3.metrics("filesSize").value == partFilesSize) - assert(scan3.metrics("numPartitions").value === 1) - assert(scan3.metrics("pruningTime").value !== -1) - } - } - } - test("DPP should not be rewritten as an existential join") { withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "1.5", @@ -1567,8 +1499,111 @@ abstract class DynamicPartitionPruningSuiteBase } } -class DynamicPartitionPruningSuiteAEOff extends DynamicPartitionPruningSuiteBase +abstract class DynamicPartitionPruningV1Suite extends DynamicPartitionPruningSuiteBase { + + import testImplicits._ + + /** + * Check the static scan metrics with and without DPP + */ + test("static scan metrics", + DisableAdaptiveExecution("DPP in AQE must reuse broadcast")) { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false", + SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false") { + withTable("fact", "dim") { + val numPartitions = 10 + + spark.range(10) + .map { x => Tuple3(x, x + 1, 0) } + .toDF("did", "d1", "d2") + .write + .format(tableFormat) + .mode("overwrite") + .saveAsTable("dim") + + spark.range(100) + .map { x => Tuple2(x, x % numPartitions) } + .toDF("f1", "fid") + .write.partitionBy("fid") + .format(tableFormat) + .mode("overwrite") + .saveAsTable("fact") + + def getFactScan(plan: SparkPlan): SparkPlan = { + val scanOption = + find(plan) { + case s: FileSourceScanExec => + s.output.exists(_.find(_.argString(maxFields = 100).contains("fid")).isDefined) + case s: BatchScanExec => + // we use f1 col for v2 tables due to schema pruning + s.output.exists(_.find(_.argString(maxFields = 100).contains("f1")).isDefined) + case _ => false + } + assert(scanOption.isDefined) + scanOption.get + } + + // No dynamic partition pruning, so no static metrics + // All files in fact table are scanned + val df1 = sql("SELECT sum(f1) FROM fact") + df1.collect() + val scan1 = getFactScan(df1.queryExecution.executedPlan) + assert(!scan1.metrics.contains("staticFilesNum")) + assert(!scan1.metrics.contains("staticFilesSize")) + val allFilesNum = scan1.metrics("numFiles").value + val allFilesSize = scan1.metrics("filesSize").value + assert(scan1.metrics("numPartitions").value === numPartitions) + assert(scan1.metrics("pruningTime").value === -1) + + // No dynamic partition pruning, so no static metrics + // Only files from fid = 5 partition are scanned + val df2 = sql("SELECT sum(f1) FROM fact WHERE fid = 5") + df2.collect() + val scan2 = getFactScan(df2.queryExecution.executedPlan) + assert(!scan2.metrics.contains("staticFilesNum")) + assert(!scan2.metrics.contains("staticFilesSize")) + val partFilesNum = scan2.metrics("numFiles").value + val partFilesSize = scan2.metrics("filesSize").value + assert(0 < partFilesNum && partFilesNum < allFilesNum) + assert(0 < partFilesSize && partFilesSize < allFilesSize) + assert(scan2.metrics("numPartitions").value === 1) + assert(scan2.metrics("pruningTime").value === -1) + + // Dynamic partition pruning is used + // Static metrics are as-if reading the whole fact table + // "Regular" metrics are as-if reading only the 
"fid = 5" partition + val df3 = sql("SELECT sum(f1) FROM fact, dim WHERE fid = did AND d1 = 6") + df3.collect() + val scan3 = getFactScan(df3.queryExecution.executedPlan) + assert(scan3.metrics("staticFilesNum").value == allFilesNum) + assert(scan3.metrics("staticFilesSize").value == allFilesSize) + assert(scan3.metrics("numFiles").value == partFilesNum) + assert(scan3.metrics("filesSize").value == partFilesSize) + assert(scan3.metrics("numPartitions").value === 1) + assert(scan3.metrics("pruningTime").value !== -1) + } + } + } +} + +class DynamicPartitionPruningV1SuiteAEOff extends DynamicPartitionPruningV1Suite with DisableAdaptiveExecutionSuite -class DynamicPartitionPruningSuiteAEOn extends DynamicPartitionPruningSuiteBase +class DynamicPartitionPruningV1SuiteAEOn extends DynamicPartitionPruningV1Suite + with EnableAdaptiveExecutionSuite + +abstract class DynamicPartitionPruningV2Suite extends DynamicPartitionPruningSuiteBase { + override protected def runAnalyzeColumnCommands: Boolean = false + + override protected def initState(): Unit = { + spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName) + spark.conf.set("spark.sql.defaultCatalog", "testcat") + } +} + +class DynamicPartitionPruningV2SuiteAEOff extends DynamicPartitionPruningV2Suite + with DisableAdaptiveExecutionSuite + +class DynamicPartitionPruningV2SuiteAEOn extends DynamicPartitionPruningV2Suite with EnableAdaptiveExecutionSuite diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 220800b4cb..6452e6778e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -759,7 +759,7 @@ class FileBasedDataSourceSuite extends QueryTest }.isEmpty) val fileScan = df.queryExecution.executedPlan collectFirst { - case BatchScanExec(_, f: FileScan) => f + case BatchScanExec(_, f: FileScan, _) => f } assert(fileScan.nonEmpty) assert(fileScan.get.partitionFilters.nonEmpty) @@ -799,7 +799,7 @@ class FileBasedDataSourceSuite extends QueryTest assert(filterCondition.isDefined) val fileScan = df.queryExecution.executedPlan collectFirst { - case BatchScanExec(_, f: FileScan) => f + case BatchScanExec(_, f: FileScan, _) => f } assert(fileScan.nonEmpty) assert(fileScan.get.partitionFilters.isEmpty) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV2SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV2SchemaPruningSuite.scala index 378b52f9c6..47254f4231 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV2SchemaPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV2SchemaPruningSuite.scala @@ -38,7 +38,7 @@ class OrcV2SchemaPruningSuite extends SchemaPruningSuite with AdaptiveSparkPlanH override def checkScanSchemata(df: DataFrame, expectedSchemaCatalogStrings: String*): Unit = { val fileSourceScanSchemata = collect(df.queryExecution.executedPlan) { - case BatchScanExec(_, scan: OrcScan) => scan.readDataSchema + case BatchScanExec(_, scan: OrcScan, _) => scan.readDataSchema } assert(fileSourceScanSchemata.size === expectedSchemaCatalogStrings.size, s"Found ${fileSourceScanSchemata.size} file sources in dataframe, " + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala index 612b74a661..c9609e8402 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala @@ -823,7 +823,7 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils val oldCount = statusStore.executionsList().size val schema = new StructType().add("i", "int").add("j", "int") - val physicalPlan = BatchScanExec(schema.toAttributes, new CustomMetricScanBuilder()) + val physicalPlan = BatchScanExec(schema.toAttributes, new CustomMetricScanBuilder(), Seq.empty) val dummyQueryExecution = new QueryExecution(spark, LocalRelation()) { override lazy val sparkPlan = physicalPlan override lazy val executedPlan = physicalPlan
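
For reference, below is a minimal connector-side sketch of the new SupportsRuntimeFiltering contract, modeled on the InMemoryBatchScan test implementation in this patch. The KeyedScan and KeyedPartition names, the part_col column, and the reader-factory wiring are hypothetical placeholders, not part of this change; a real source would substitute its own split and reader types.

import org.apache.spark.sql.connector.expressions.{Expressions, NamedReference}
import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan, SupportsRuntimeFiltering}
import org.apache.spark.sql.sources.{Filter, In}
import org.apache.spark.sql.types.StructType

// Hypothetical split type that records the value of its partition key column.
case class KeyedPartition(key: String) extends InputPartition

// Sketch of a scan whose splits are keyed by a single column ("part_col"):
// Spark may call filter() with runtime predicates and then call
// toBatch/planInputPartitions again on the same scan instance.
class KeyedScan(
    schema: StructType,
    readerFactory: PartitionReaderFactory,
    initialPartitions: Seq[KeyedPartition])
  extends Scan with Batch with SupportsRuntimeFiltering {

  // mutable because partitions may be re-planned after runtime filtering
  private var partitions: Seq[KeyedPartition] = initialPartitions

  override def readSchema(): StructType = schema
  override def toBatch: Batch = this
  override def planInputPartitions(): Array[InputPartition] = partitions.toArray[InputPartition]
  override def createReaderFactory(): PartitionReaderFactory = readerFactory

  // advertise the partition key so Spark knows it can derive runtime predicates for it
  override def filterAttributes(): Array[NamedReference] =
    Array(Expressions.column("part_col"))

  // filters are ANDed together; drop initially planned partitions with no matching key
  override def filter(filters: Array[Filter]): Unit = filters.foreach {
    case In("part_col", values) =>
      val matchingKeys = values.map(_.toString).toSet
      partitions = partitions.filter(p => matchingKeys.contains(p.key))
    case _ => // ignore filters this source cannot use for pruning
  }
}

Spark re-invokes Scan#toBatch after filter(), which is why the pruned partition list is kept as mutable state on the scan, mirroring how InMemoryBatchScan holds its data as a var.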