[SPARK-38647][SQL] Add SupportsReportOrdering mix in interface for Scan (DataSourceV2)

### What changes were proposed in this pull request?
As `SupportsReportPartitioning` allows implementations of `Scan` to provide Spark with information about the existing partitioning of the data read by a `DataSourceV2`, a similar mix-in interface `SupportsReportOrdering` should provide ordering information.
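
A minimal connector-side sketch of what an implementation might look like (illustrative only: `EventLogScan`, the column name `ts`, and the anonymous `SortOrder` are assumptions, mirroring the pattern used in the new test sources):

```scala
import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, NullOrdering, SortDirection, SortOrder}
import org.apache.spark.sql.connector.read.{Scan, SupportsReportOrdering}
import org.apache.spark.sql.types.StructType

// Hypothetical scan over data that is already sorted by `ts` within each input partition.
class EventLogScan(schema: StructType) extends Scan with SupportsReportOrdering {
  override def readSchema(): StructType = schema

  // Report the per-partition order: `ts` ascending, nulls first.
  override def outputOrdering(): Array[SortOrder] = Array(
    new SortOrder {
      override def expression(): Expression = FieldReference("ts")
      override def direction(): SortDirection = SortDirection.ASCENDING
      override def nullOrdering(): NullOrdering = NullOrdering.NULLS_FIRST
    }
  )
}
```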

### Why are the changes needed?
This prevents Spark from sorting data that already exhibits an order reported by the source.
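
For example (mirroring the matrix tested in `DataSourceV2Suite` below, and assuming a `SparkSession` named `spark` with the test source on the classpath), a source that reports partitioning by `i` and per-partition ordering by `i,j` lets a window query run without an extra shuffle or sort:

```scala
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, row_number}

// As in the tests, spark.sql.sources.v2.bucketing.enabled is assumed to be true.
val df = spark.read
  .format(classOf[OrderAndPartitionAwareDataSource].getName)
  .option("partitionKeys", "i")
  .option("orderKeys", "i,j")
  .load()

// The window needs data clustered by `i` and sorted by (`i`, `j`) within each
// partition; both requirements are met by what the scan reports, so the plan
// contains neither a ShuffleExchangeExec nor a SortExec.
val windowed = df.withColumn("no",
  row_number().over(Window.partitionBy(col("i")).orderBy(col("j"))))
```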

### Does this PR introduce _any_ user-facing change?
Yes, it adds the `SupportsReportOrdering` mix-in interface.

### How was this patch tested?
This adds tests to `DataSourceV2Suite`, similar to the test for `SupportsReportPartitioning`.

Closes #35965 from EnricoMi/branch-datasourcev2-output-ordering.

Authored-by: Enrico Minack <github@enrico.minack.dev>
Signed-off-by: Chao Sun <sunchao@apple.com>
parent db0e972c09
commit b588d070eb
22 changed files with 427 additions and 50 deletions

@ -59,7 +59,7 @@ class AvroRowReaderSuite
val df = spark.read.format("avro").load(dir.getCanonicalPath)
val fileScan = df.queryExecution.executedPlan collectFirst {
case BatchScanExec(_, f: AvroScan, _, _) => f
case BatchScanExec(_, f: AvroScan, _, _, _) => f
}
val filePath = fileScan.get.fileIndex.inputFiles(0)
val fileSize = new File(new URI(filePath)).length

@ -2335,7 +2335,7 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper {
})
val fileScan = df.queryExecution.executedPlan collectFirst {
case BatchScanExec(_, f: AvroScan, _, _) => f
case BatchScanExec(_, f: AvroScan, _, _, _) => f
}
assert(fileScan.nonEmpty)
assert(fileScan.get.partitionFilters.nonEmpty)
@ -2368,7 +2368,7 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper {
assert(filterCondition.isDefined)
val fileScan = df.queryExecution.executedPlan collectFirst {
case BatchScanExec(_, f: AvroScan, _, _) => f
case BatchScanExec(_, f: AvroScan, _, _, _) => f
}
assert(fileScan.nonEmpty)
assert(fileScan.get.partitionFilters.isEmpty)
@ -2449,7 +2449,7 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper {
.where("value = 'a'")
val fileScan = df.queryExecution.executedPlan collectFirst {
case BatchScanExec(_, f: AvroScan, _, _) => f
case BatchScanExec(_, f: AvroScan, _, _, _) => f
}
assert(fileScan.nonEmpty)
if (filtersPushdown) {

@ -372,7 +372,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
private def checkAggregatePushed(df: DataFrame, funcName: String): Unit = {
df.queryExecution.optimizedPlan.collect {
case DataSourceV2ScanRelation(_, scan, _, _) =>
case DataSourceV2ScanRelation(_, scan, _, _, _) =>
assert(scan.isInstanceOf[V1ScanWrapper])
val wrapper = scan.asInstanceOf[V1ScanWrapper]
assert(wrapper.pushedDownOperators.aggregation.isDefined)

@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.connector.read;
import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.SortOrder;
/**
* A mix in interface for {@link Scan}. Data sources can implement this interface to
* report the order of data in each partition to Spark.
* Global order is part of the partitioning, see {@link SupportsReportPartitioning}.
* <p>
* Spark uses ordering information to exploit existing order to avoid sorting required by
* subsequent operations.
*
* @since 3.4.0
*/
@Evolving
public interface SupportsReportOrdering extends Scan {
/**
* Returns the order in each partition of this data source scan.
*/
SortOrder[] outputOrdering();
}

@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.datasources.v2
import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{ExposesMetadataColumns, LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.catalyst.util.{truncatedString, CharVarcharUtils}
import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Identifier, MetadataColumn, SupportsMetadataColumns, Table, TableCapability}
@ -115,12 +115,14 @@ case class DataSourceV2Relation(
* @param output the output attributes of this relation
* @param keyGroupedPartitioning if set, the partitioning expressions that are used to split the
* rows in the scan across different partitions
* @param ordering if set, the ordering provided by the scan
*/
case class DataSourceV2ScanRelation(
relation: DataSourceV2Relation,
scan: Scan,
output: Seq[AttributeReference],
keyGroupedPartitioning: Option[Seq[Expression]] = None) extends LeafNode with NamedRelation {
keyGroupedPartitioning: Option[Seq[Expression]] = None,
ordering: Option[Seq[SortOrder]] = None) extends LeafNode with NamedRelation {
override def name: String = relation.table.name()

@ -3689,7 +3689,7 @@ class Dataset[T] private[sql](
case r: HiveTableRelation =>
r.tableMeta.storage.locationUri.map(_.toString).toArray
case DataSourceV2ScanRelation(DataSourceV2Relation(table: FileTable, _, _, _, _),
_, _, _) =>
_, _, _, _) =>
table.fileIndex.inputFiles
}.flatten
files.toSet.toArray

@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions
import org.apache.spark.sql.execution.datasources.SchemaPruning
import org.apache.spark.sql.execution.datasources.v2.{GroupBasedRowLevelOperationScanPlanning, OptimizeMetadataOnlyDeleteFromTable, V2ScanPartitioning, V2ScanRelationPushDown, V2Writes}
import org.apache.spark.sql.execution.datasources.v2.{GroupBasedRowLevelOperationScanPlanning, OptimizeMetadataOnlyDeleteFromTable, V2ScanPartitioningAndOrdering, V2ScanRelationPushDown, V2Writes}
import org.apache.spark.sql.execution.dynamicpruning.{CleanupDynamicPruningFilters, PartitionPruning}
import org.apache.spark.sql.execution.python.{ExtractGroupingPythonUDFFromAggregate, ExtractPythonUDFFromAggregate, ExtractPythonUDFs}
@ -40,7 +40,7 @@ class SparkOptimizer(
Seq(SchemaPruning) :+
GroupBasedRowLevelOperationScanPlanning :+
V2ScanRelationPushDown :+
V2ScanPartitioning :+
V2ScanPartitioningAndOrdering :+
V2Writes :+
PruneFileSourcePartitions
@ -86,7 +86,7 @@ class SparkOptimizer(
ExtractPythonUDFs.ruleName :+
GroupBasedRowLevelOperationScanPlanning.ruleName :+
V2ScanRelationPushDown.ruleName :+
V2ScanPartitioning.ruleName :+
V2ScanPartitioningAndOrdering.ruleName :+
V2Writes.ruleName :+
ReplaceCTERefWithRepartition.ruleName

@ -37,7 +37,8 @@ case class BatchScanExec(
output: Seq[AttributeReference],
@transient scan: Scan,
runtimeFilters: Seq[Expression],
keyGroupedPartitioning: Option[Seq[Expression]] = None) extends DataSourceV2ScanExecBase {
keyGroupedPartitioning: Option[Seq[Expression]] = None,
ordering: Option[Seq[SortOrder]] = None) extends DataSourceV2ScanExecBase {
@transient lazy val batch = scan.toBatch

@ -32,7 +32,8 @@ case class ContinuousScanExec(
@transient scan: Scan,
@transient stream: ContinuousStream,
@transient start: Offset,
keyGroupedPartitioning: Option[Seq[Expression]] = None) extends DataSourceV2ScanExecBase {
keyGroupedPartitioning: Option[Seq[Expression]] = None,
ordering: Option[Seq[SortOrder]] = None) extends DataSourceV2ScanExecBase {
// TODO: unify the equal/hashCode implementation for all data source v2 query plans.
override def equals(other: Any): Boolean = other match {

@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, RowOrdering}
import org.apache.spark.sql.catalyst.expressions.{Expression, RowOrdering, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical
import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, SinglePartition}
import org.apache.spark.sql.catalyst.util.truncatedString
@ -50,6 +50,10 @@ trait DataSourceV2ScanExecBase extends LeafExecNode {
* `SupportsReportPartitioning` */
def keyGroupedPartitioning: Option[Seq[Expression]]
/** Optional ordering expressions provided by the V2 data sources, through
* `SupportsReportOrdering` */
def ordering: Option[Seq[SortOrder]]
protected def inputPartitions: Seq[InputPartition]
override def simpleString(maxFields: Int): String = {
@ -138,6 +142,12 @@ trait DataSourceV2ScanExecBase extends LeafExecNode {
}
}
override def outputOrdering: Seq[SortOrder] = {
// when multiple partitions are grouped together, ordering inside partitions is not preserved
val partitioningPreservesOrdering = groupedPartitions.forall(_.forall(_._2.length <= 1))
ordering.filter(_ => partitioningPreservesOrdering).getOrElse(super.outputOrdering)
}
override def supportsColumnar: Boolean = {
require(inputPartitions.forall(readerFactory.supportColumnarReads) ||
!inputPartitions.exists(readerFactory.supportColumnarReads),

@ -106,7 +106,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case PhysicalOperation(project, filters, DataSourceV2ScanRelation(
_, V1ScanWrapper(scan, pushed, pushedDownOperators), output, _)) =>
_, V1ScanWrapper(scan, pushed, pushedDownOperators), output, _, _)) =>
val v1Relation = scan.toV1TableScan[BaseRelation with TableScan](session.sqlContext)
if (v1Relation.schema != scan.readSchema()) {
throw QueryExecutionErrors.fallbackV1RelationReportsInconsistentSchemaError(
@ -127,7 +127,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
withProjectAndFilter(project, filters, dsScan, needsUnsafeConversion = false) :: Nil
case PhysicalOperation(project, filters,
DataSourceV2ScanRelation(_, scan: LocalScan, output, _)) =>
DataSourceV2ScanRelation(_, scan: LocalScan, output, _, _)) =>
val localScanExec = LocalTableScanExec(output, scan.rows().toSeq)
withProjectAndFilter(project, filters, localScanExec, needsUnsafeConversion = false) :: Nil
@ -140,7 +140,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
case _ => false
}
val batchExec = BatchScanExec(relation.output, relation.scan, runtimeFilters,
relation.keyGroupedPartitioning)
relation.keyGroupedPartitioning, relation.ordering)
withProjectAndFilter(project, postScanFilters, batchExec, !batchExec.supportsColumnar) :: Nil
case PhysicalOperation(p, f, r: StreamingDataSourceV2Relation)
@ -267,7 +267,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
case DeleteFromTable(relation, condition) =>
relation match {
case DataSourceV2ScanRelation(r, _, output, _) =>
case DataSourceV2ScanRelation(r, _, output, _, _) =>
val table = r.table
if (SubqueryExpression.hasSubquery(condition)) {
throw QueryCompilationErrors.unsupportedDeleteByConditionWithSubqueryError(condition)

@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, SortOrder}
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory, Scan}
import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset}
@ -32,7 +32,8 @@ case class MicroBatchScanExec(
@transient stream: MicroBatchStream,
@transient start: Offset,
@transient end: Offset,
keyGroupedPartitioning: Option[Seq[Expression]] = None) extends DataSourceV2ScanExecBase {
keyGroupedPartitioning: Option[Seq[Expression]] = None,
ordering: Option[Seq[SortOrder]] = None) extends DataSourceV2ScanExecBase {
// TODO: unify the equal/hashCode implementation for all data source v2 query plans.
override def equals(other: Any): Boolean = other match {

@ -21,18 +21,26 @@ import org.apache.spark.sql.catalyst.expressions.V2ExpressionUtils
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.FunctionCatalog
import org.apache.spark.sql.connector.read.SupportsReportPartitioning
import org.apache.spark.sql.connector.read.{SupportsReportOrdering, SupportsReportPartitioning}
import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, UnknownPartitioning}
import org.apache.spark.util.collection.Utils.sequenceToOption
/**
* Extracts [[DataSourceV2ScanRelation]] from the input logical plan, converts any V2 partitioning
* reported by data sources to their catalyst counterparts. Then, annotates the plan with the
* result.
* and ordering reported by data sources to their catalyst counterparts. Then, annotates the plan
* with the partitioning and ordering result.
*/
object V2ScanPartitioning extends Rule[LogicalPlan] with SQLConfHelper {
override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
case d @ DataSourceV2ScanRelation(relation, scan: SupportsReportPartitioning, _, None) =>
object V2ScanPartitioningAndOrdering extends Rule[LogicalPlan] with SQLConfHelper {
override def apply(plan: LogicalPlan): LogicalPlan = {
val scanRules = Seq[LogicalPlan => LogicalPlan] (partitioning, ordering)
scanRules.foldLeft(plan) { (newPlan, scanRule) =>
scanRule(newPlan)
}
}
private def partitioning(plan: LogicalPlan) = plan.transformDown {
case d @ DataSourceV2ScanRelation(relation, scan: SupportsReportPartitioning, _, None, _) =>
val funCatalogOpt = relation.catalog.flatMap {
case c: FunctionCatalog => Some(c)
case _ => None
@ -48,4 +56,10 @@ object V2ScanPartitioning extends Rule[LogicalPlan] with SQLConfHelper {
d.copy(keyGroupedPartitioning = catalystPartitioning)
}
private def ordering(plan: LogicalPlan) = plan.transformDown {
case d @ DataSourceV2ScanRelation(relation, scan: SupportsReportOrdering, _, _, _) =>
val ordering = V2ExpressionUtils.toCatalystOrdering(scan.outputOrdering(), relation)
d.copy(ordering = Some(ordering))
}
}

@ -78,7 +78,7 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper with Join
} else {
None
}
case (resExp, r @ DataSourceV2ScanRelation(_, scan: SupportsRuntimeFiltering, _, _)) =>
case (resExp, r @ DataSourceV2ScanRelation(_, scan: SupportsRuntimeFiltering, _, _, _)) =>
val filterAttrs = V2ExpressionUtils.resolveRefs[Attribute](scan.filterAttributes, r)
if (resExp.references.subsetOf(AttributeSet(filterAttrs))) {
Some(r)

@ -0,0 +1,160 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package test.org.apache.spark.sql.connector;
import java.util.Arrays;
import org.apache.spark.sql.connector.catalog.Table;
import org.apache.spark.sql.connector.expressions.*;
import org.apache.spark.sql.connector.read.*;
import org.apache.spark.sql.connector.read.partitioning.KeyGroupedPartitioning;
import org.apache.spark.sql.connector.read.partitioning.Partitioning;
import org.apache.spark.sql.connector.read.partitioning.UnknownPartitioning;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;
public class JavaOrderAndPartitionAwareDataSource extends JavaPartitionAwareDataSource {
static class MyScanBuilder extends JavaPartitionAwareDataSource.MyScanBuilder
implements SupportsReportOrdering {
private final Partitioning partitioning;
private final SortOrder[] ordering;
MyScanBuilder(String partitionKeys, String orderKeys) {
if (partitionKeys != null) {
String[] keys = partitionKeys.split(",");
Expression[] clustering = new Transform[keys.length];
for (int i = 0; i < keys.length; i++) {
clustering[i] = Expressions.identity(keys[i]);
}
this.partitioning = new KeyGroupedPartitioning(clustering, 2);
} else {
this.partitioning = new UnknownPartitioning(2);
}
if (orderKeys != null) {
String[] keys = orderKeys.split(",");
this.ordering = new SortOrder[keys.length];
for (int i = 0; i < keys.length; i++) {
this.ordering[i] = new MySortOrder(keys[i]);
}
} else {
this.ordering = new SortOrder[0];
}
}
@Override
public InputPartition[] planInputPartitions() {
InputPartition[] partitions = new InputPartition[2];
partitions[0] = new SpecificInputPartition(new int[]{1, 1, 3}, new int[]{4, 5, 5});
partitions[1] = new SpecificInputPartition(new int[]{2, 4, 4}, new int[]{6, 1, 2});
return partitions;
}
@Override
public Partitioning outputPartitioning() {
return this.partitioning;
}
@Override
public SortOrder[] outputOrdering() {
return this.ordering;
}
}
@Override
public Table getTable(CaseInsensitiveStringMap options) {
return new JavaSimpleBatchTable() {
@Override
public Transform[] partitioning() {
String partitionKeys = options.get("partitionKeys");
if (partitionKeys == null) {
return new Transform[0];
} else {
return (Transform[]) Arrays.stream(partitionKeys.split(","))
.map(Expressions::identity).toArray();
}
}
@Override
public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) {
return new MyScanBuilder(options.get("partitionKeys"), options.get("orderKeys"));
}
};
}
static class MySortOrder implements SortOrder {
private final Expression expression;
MySortOrder(String columnName) {
this.expression = new MyIdentityTransform(new MyNamedReference(columnName));
}
@Override
public Expression expression() {
return expression;
}
@Override
public SortDirection direction() {
return SortDirection.ASCENDING;
}
@Override
public NullOrdering nullOrdering() {
return NullOrdering.NULLS_FIRST;
}
}
static class MyNamedReference implements NamedReference {
private final String[] parts;
MyNamedReference(String part) {
this.parts = new String[] { part };
}
@Override
public String[] fieldNames() {
return this.parts;
}
}
static class MyIdentityTransform implements Transform {
private final Expression[] args;
MyIdentityTransform(NamedReference namedReference) {
this.args = new Expression[] { namedReference };
}
@Override
public String name() {
return "identity";
}
@Override
public NamedReference[] references() {
return new NamedReference[0];
}
@Override
public Expression[] arguments() {
return this.args;
}
}
}

@ -849,7 +849,7 @@ class FileBasedDataSourceSuite extends QueryTest
})
val fileScan = df.queryExecution.executedPlan collectFirst {
case BatchScanExec(_, f: FileScan, _, _) => f
case BatchScanExec(_, f: FileScan, _, _, _) => f
}
assert(fileScan.nonEmpty)
assert(fileScan.get.partitionFilters.nonEmpty)
@ -889,7 +889,7 @@ class FileBasedDataSourceSuite extends QueryTest
assert(filterCondition.isDefined)
val fileScan = df.queryExecution.executedPlan collectFirst {
case BatchScanExec(_, f: FileScan, _, _) => f
case BatchScanExec(_, f: FileScan, _, _, _) => f
}
assert(fileScan.nonEmpty)
assert(fileScan.get.partitionFilters.isEmpty)

@ -26,14 +26,16 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider}
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.connector.expressions.{FieldReference, Literal, Transform}
import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, Literal, NamedReference, NullOrdering, SortDirection, SortOrder, Transform}
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read._
import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, Partitioning}
import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, Partitioning, UnknownPartitioning}
import org.apache.spark.sql.execution.SortExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, DataSourceV2ScanRelation}
import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec}
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.{Filter, GreaterThan}
@ -251,27 +253,27 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
val df = spark.read.format(cls.getName).load()
checkAnswer(df, Seq(Row(1, 4), Row(1, 4), Row(3, 6), Row(2, 6), Row(4, 2), Row(4, 2)))
val groupByColA = df.groupBy($"i").agg(sum($"j"))
checkAnswer(groupByColA, Seq(Row(1, 8), Row(2, 6), Row(3, 6), Row(4, 4)))
assert(collectFirst(groupByColA.queryExecution.executedPlan) {
val groupByColI = df.groupBy($"i").agg(sum($"j"))
checkAnswer(groupByColI, Seq(Row(1, 8), Row(2, 6), Row(3, 6), Row(4, 4)))
assert(collectFirst(groupByColI.queryExecution.executedPlan) {
case e: ShuffleExchangeExec => e
}.isEmpty)
val groupByColAB = df.groupBy($"i", $"j").agg(count("*"))
checkAnswer(groupByColAB, Seq(Row(1, 4, 2), Row(2, 6, 1), Row(3, 6, 1), Row(4, 2, 2)))
assert(collectFirst(groupByColAB.queryExecution.executedPlan) {
val groupByColIJ = df.groupBy($"i", $"j").agg(count("*"))
checkAnswer(groupByColIJ, Seq(Row(1, 4, 2), Row(2, 6, 1), Row(3, 6, 1), Row(4, 2, 2)))
assert(collectFirst(groupByColIJ.queryExecution.executedPlan) {
case e: ShuffleExchangeExec => e
}.isEmpty)
val groupByColB = df.groupBy($"j").agg(sum($"i"))
checkAnswer(groupByColB, Seq(Row(2, 8), Row(4, 2), Row(6, 5)))
assert(collectFirst(groupByColB.queryExecution.executedPlan) {
val groupByColJ = df.groupBy($"j").agg(sum($"i"))
checkAnswer(groupByColJ, Seq(Row(2, 8), Row(4, 2), Row(6, 5)))
assert(collectFirst(groupByColJ.queryExecution.executedPlan) {
case e: ShuffleExchangeExec => e
}.isDefined)
val groupByAPlusB = df.groupBy($"i" + $"j").agg(count("*"))
checkAnswer(groupByAPlusB, Seq(Row(5, 2), Row(6, 2), Row(8, 1), Row(9, 1)))
assert(collectFirst(groupByAPlusB.queryExecution.executedPlan) {
val groupByIPlusJ = df.groupBy($"i" + $"j").agg(count("*"))
checkAnswer(groupByIPlusJ, Seq(Row(5, 2), Row(6, 2), Row(8, 1), Row(9, 1)))
assert(collectFirst(groupByIPlusJ.queryExecution.executedPlan) {
case e: ShuffleExchangeExec => e
}.isDefined)
}
@ -279,6 +281,90 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
}
}
test("ordering and partitioning reporting") {
withSQLConf(SQLConf.V2_BUCKETING_ENABLED.key -> "true") {
Seq(
classOf[OrderAndPartitionAwareDataSource],
classOf[JavaOrderAndPartitionAwareDataSource]
).foreach { cls =>
withClue(cls.getName) {
// we test report ordering (together with report partitioning) with these transformations:
// - groupBy("i").flatMapGroups:
// hash-partitions by "i" and sorts each partition by "i"
// requires partitioning and sort by "i"
// - aggregation function over window partitioned by "i" and ordered by "j":
// hash-partitions by "i" and sorts each partition by "j"
// requires partitioning by "i" and sort by "i" and "j"
Seq(
// with no partitioning and no order, we expect shuffling AND sorting
(None, None, (true, true), (true, true)),
// partitioned by i and no order, we expect NO shuffling BUT sorting
(Some("i"), None, (false, true), (false, true)),
// partitioned by i and in-partition sorted by i,
// we expect NO shuffling AND sorting for groupBy but sorting for window function
(Some("i"), Some("i"), (false, false), (false, true)),
// partitioned by i and in-partition sorted by j, we expect NO shuffling BUT sorting
(Some("i"), Some("j"), (false, true), (false, true)),
// partitioned by i and in-partition sorted by i,j, we expect NO shuffling NOR sorting
(Some("i"), Some("i,j"), (false, false), (false, false)),
// partitioned by j and in-partition sorted by i, we expect shuffling AND sorting
(Some("j"), Some("i"), (true, true), (true, true)),
// partitioned by j and in-partition sorted by i,j, we expect shuffling and sorting
(Some("j"), Some("i,j"), (true, true), (true, true))
).foreach { testParams =>
val (partitionKeys, orderKeys, groupByExpects, windowFuncExpects) = testParams
withClue(f"${partitionKeys.orNull} ${orderKeys.orNull}") {
val df = spark.read
.option("partitionKeys", partitionKeys.orNull)
.option("orderKeys", orderKeys.orNull)
.format(cls.getName)
.load()
checkAnswer(df, Seq(Row(1, 4), Row(1, 5), Row(3, 5), Row(2, 6), Row(4, 1), Row(4, 2)))
// groupBy(i).flatMapGroups
{
val groupBy = df.groupBy($"i").as[Int, (Int, Int)]
.flatMapGroups { (i: Int, it: Iterator[(Int, Int)]) =>
Iterator.single((i, it.length)) }
checkAnswer(
groupBy.toDF(),
Seq(Row(1, 2), Row(2, 1), Row(3, 1), Row(4, 2))
)
val (shuffleExpected, sortExpected) = groupByExpects
assert(collectFirst(groupBy.queryExecution.executedPlan) {
case e: ShuffleExchangeExec => e
}.isDefined === shuffleExpected)
assert(collectFirst(groupBy.queryExecution.executedPlan) {
case e: SortExec => e
}.isDefined === sortExpected)
}
// aggregation function over window partitioned by i and ordered by j
{
val windowPartByColIOrderByColJ = df.withColumn("no",
row_number() over Window.partitionBy(Symbol("i")).orderBy(Symbol("j"))
)
checkAnswer(windowPartByColIOrderByColJ, Seq(
Row(1, 4, 1), Row(1, 5, 2), Row(2, 6, 1), Row(3, 5, 1), Row(4, 1, 1), Row(4, 2, 2)
))
val (shuffleExpected, sortExpected) = windowFuncExpects
assert(collectFirst(windowPartByColIOrderByColJ.queryExecution.executedPlan) {
case e: ShuffleExchangeExec => e
}.isDefined === shuffleExpected)
assert(collectFirst(windowPartByColIOrderByColJ.queryExecution.executedPlan) {
case e: SortExec => e
}.isDefined === sortExpected)
}
}
}
}
}
}
}
test ("statistics report data source") {
Seq(classOf[ReportStatisticsDataSource], classOf[JavaReportStatisticsDataSource]).foreach {
cls =>
@ -862,10 +948,10 @@ object ColumnarReaderFactory extends PartitionReaderFactory {
class PartitionAwareDataSource extends TestingV2Source {
class MyScanBuilder extends SimpleScanBuilder
with SupportsReportPartitioning{
with SupportsReportPartitioning {
override def planInputPartitions(): Array[InputPartition] = {
// Note that we don't have same value of column `a` across partitions.
// Note that we don't have same value of column `i` across partitions.
Array(
SpecificInputPartition(Array(1, 1, 3), Array(4, 4, 6)),
SpecificInputPartition(Array(2, 4, 4), Array(6, 2, 2)))
@ -886,6 +972,68 @@ class PartitionAwareDataSource extends TestingV2Source {
}
}
class OrderAndPartitionAwareDataSource extends PartitionAwareDataSource {
class MyScanBuilder(
val partitionKeys: Option[Seq[String]],
val orderKeys: Seq[String])
extends SimpleScanBuilder
with SupportsReportPartitioning with SupportsReportOrdering {
override def planInputPartitions(): Array[InputPartition] = {
// data are partitioned by column `i` or `j`, so we can report any partitioning
// column `i` is not ordered globally, but within partitions, together with `j`
// this allows us to report ordering by [i] and [i, j]
Array(
SpecificInputPartition(Array(1, 1, 3), Array(4, 5, 5)),
SpecificInputPartition(Array(2, 4, 4), Array(6, 1, 2)))
}
override def createReaderFactory(): PartitionReaderFactory = {
SpecificReaderFactory
}
override def outputPartitioning(): Partitioning = {
partitionKeys.map(keys =>
new KeyGroupedPartitioning(keys.map(FieldReference(_)).toArray, 2)
).getOrElse(
new UnknownPartitioning(2)
)
}
override def outputOrdering(): Array[SortOrder] = orderKeys.map(
new MySortOrder(_)
).toArray
}
override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable {
override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
new MyScanBuilder(
Option(options.get("partitionKeys")).map(_.split(",")),
Option(options.get("orderKeys")).map(_.split(",").toSeq).getOrElse(Seq.empty)
)
}
}
class MySortOrder(columnName: String) extends SortOrder {
override def expression(): Expression = new MyIdentityTransform(
new MyNamedReference(columnName)
)
override def direction(): SortDirection = SortDirection.ASCENDING
override def nullOrdering(): NullOrdering = NullOrdering.NULLS_FIRST
}
class MyNamedReference(parts: String*) extends NamedReference {
override def fieldNames(): Array[String] = parts.toArray
}
class MyIdentityTransform(namedReference: NamedReference) extends Transform {
override def name(): String = "identity"
override def references(): Array[NamedReference] = Array.empty
override def arguments(): Array[Expression] = Seq(namedReference).toArray
}
}
case class SpecificInputPartition(
i: Array[Int],
j: Array[Int]) extends InputPartition with HasPartitionKey {

@ -95,7 +95,7 @@ abstract class PrunePartitionSuiteBase extends StatisticsCollectionTestBase {
assert(getScanExecPartitionSize(plan) == expectedPartitionCount)
val collectFn: PartialFunction[SparkPlan, Seq[Expression]] = collectPartitionFiltersFn orElse {
case BatchScanExec(_, scan: FileScan, _, _) => scan.partitionFilters
case BatchScanExec(_, scan: FileScan, _, _, _) => scan.partitionFilters
}
val pushedDownPartitionFilters = plan.collectFirst(collectFn)
.map(exps => exps.filterNot(e => e.isInstanceOf[IsNotNull]))

@ -59,7 +59,7 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession {
.where(Column(predicate))
query.queryExecution.optimizedPlan match {
case PhysicalOperation(_, filters, DataSourceV2ScanRelation(_, o: OrcScan, _, _)) =>
case PhysicalOperation(_, filters, DataSourceV2ScanRelation(_, o: OrcScan, _, _, _)) =>
assert(filters.nonEmpty, "No filter is analyzed from the given query")
assert(o.pushedFilters.nonEmpty, "No filter is pushed down")
val maybeFilter = OrcFilters.createFilter(query.schema, o.pushedFilters)

@ -120,7 +120,7 @@ trait OrcTest extends QueryTest with FileBasedDataSourceTest with BeforeAndAfter
.where(Column(predicate))
query.queryExecution.optimizedPlan match {
case PhysicalOperation(_, filters, DataSourceV2ScanRelation(_, o: OrcScan, _, _)) =>
case PhysicalOperation(_, filters, DataSourceV2ScanRelation(_, o: OrcScan, _, _, _)) =>
assert(filters.nonEmpty, "No filter is analyzed from the given query")
if (noneSupported) {
assert(o.pushedFilters.isEmpty, "Unsupported filters should not show in pushed filters")

@ -40,7 +40,7 @@ class OrcV2SchemaPruningSuite extends SchemaPruningSuite with AdaptiveSparkPlanH
override def checkScanSchemata(df: DataFrame, expectedSchemaCatalogStrings: String*): Unit = {
val fileSourceScanSchemata =
collect(df.queryExecution.executedPlan) {
case BatchScanExec(_, scan: OrcScan, _, _) => scan.readDataSchema
case BatchScanExec(_, scan: OrcScan, _, _, _) => scan.readDataSchema
}
assert(fileSourceScanSchemata.size === expectedSchemaCatalogStrings.size,
s"Found ${fileSourceScanSchemata.size} file sources in dataframe, " +

@ -2155,7 +2155,7 @@ class ParquetV2FilterSuite extends ParquetFilterSuite {
query.queryExecution.optimizedPlan.collectFirst {
case PhysicalOperation(_, filters,
DataSourceV2ScanRelation(_, scan: ParquetScan, _, _)) =>
DataSourceV2ScanRelation(_, scan: ParquetScan, _, _, _)) =>
assert(filters.nonEmpty, "No filter is analyzed from the given query")
val sourceFilters = filters.flatMap(DataSourceStrategy.translateFilter(_, true)).toArray
val pushedFilters = scan.pushedFilters