[SPARK-26952][SQL] Row count statistics should respect the data reported by data source
## What changes were proposed in this pull request? In data source v2, if the data source scan implements `SupportsReportStatistics`, `DataSourceV2Relation` should respect the row count reported by the data source. ## How was this patch tested? New unit test. Closes #23853 from ConeyLiu/report-row-count. Authored-by: Xianyang Liu <xianyang.liu@intel.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
52a180f25f
commit
bc03c8b3fa
|
@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
|
|||
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
|
||||
import org.apache.spark.sql.catalyst.util.truncatedString
|
||||
import org.apache.spark.sql.sources.v2._
|
||||
import org.apache.spark.sql.sources.v2.reader._
|
||||
import org.apache.spark.sql.sources.v2.reader.{Statistics => V2Statistics, _}
|
||||
import org.apache.spark.sql.sources.v2.reader.streaming.{Offset, SparkDataStream}
|
||||
import org.apache.spark.sql.sources.v2.writer._
|
||||
|
||||
|
@ -56,7 +56,7 @@ case class DataSourceV2Relation(
|
|||
scan match {
|
||||
case r: SupportsReportStatistics =>
|
||||
val statistics = r.estimateStatistics()
|
||||
Statistics(sizeInBytes = statistics.sizeInBytes().orElse(conf.defaultSizeInBytes))
|
||||
DataSourceV2Relation.transformV2Stats(statistics, None, conf.defaultSizeInBytes)
|
||||
case _ =>
|
||||
Statistics(sizeInBytes = conf.defaultSizeInBytes)
|
||||
}
|
||||
|
@ -89,7 +89,7 @@ case class StreamingDataSourceV2Relation(
|
|||
/**
 * Computes the logical statistics for this relation.
 *
 * When the scan implements `SupportsReportStatistics`, the statistics it reports are
 * converted through `DataSourceV2Relation.transformV2Stats`, with no default row count and
 * `conf.defaultSizeInBytes` as the fallback size. Otherwise only the default size is used.
 */
override def computeStats(): Statistics = scan match {
  case r: SupportsReportStatistics =>
    val statistics = r.estimateStatistics()
    // The conversion is the only result of this branch; no intermediate Statistics is built.
    DataSourceV2Relation.transformV2Stats(statistics, None, conf.defaultSizeInBytes)
  case _ =>
    Statistics(sizeInBytes = conf.defaultSizeInBytes)
}
|
||||
|
@ -100,4 +100,21 @@ object DataSourceV2Relation {
|
|||
val output = table.schema().toAttributes
|
||||
DataSourceV2Relation(table, output, options)
|
||||
}
|
||||
|
||||
/**
 * This is used to transform data source v2 statistics to logical.Statistics.
 *
 * @param v2Statistics statistics reported by the data source v2 scan
 * @param defaultRowCount row count to use when `v2Statistics.numRows()` is not present
 * @param defaultSizeInBytes size to use when `v2Statistics.sizeInBytes()` is not present
 * @return logical statistics carrying the resolved size in bytes and row count
 */
def transformV2Stats(
    v2Statistics: V2Statistics,
    defaultRowCount: Option[BigInt],
    defaultSizeInBytes: Long): Statistics = {
  // Read the reported row count once so that the presence check and the value that is
  // extracted always come from the same OptionalLong instance.
  val reportedRows = v2Statistics.numRows()
  val numRows: Option[BigInt] = if (reportedRows.isPresent) {
    Some(reportedRows.getAsLong)
  } else {
    defaultRowCount
  }
  Statistics(
    sizeInBytes = v2Statistics.sizeInBytes().orElse(defaultSizeInBytes),
    rowCount = numRows)
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package test.org.apache.spark.sql.sources.v2;
|
||||
|
||||
import java.util.OptionalLong;
|
||||
|
||||
import org.apache.spark.sql.sources.v2.DataSourceOptions;
|
||||
import org.apache.spark.sql.sources.v2.Table;
|
||||
import org.apache.spark.sql.sources.v2.TableProvider;
|
||||
import org.apache.spark.sql.sources.v2.reader.InputPartition;
|
||||
import org.apache.spark.sql.sources.v2.reader.ScanBuilder;
|
||||
import org.apache.spark.sql.sources.v2.reader.Statistics;
|
||||
import org.apache.spark.sql.sources.v2.reader.SupportsReportStatistics;
|
||||
|
||||
public class JavaReportStatisticsDataSource implements TableProvider {
|
||||
class MyScanBuilder extends JavaSimpleScanBuilder implements SupportsReportStatistics {
|
||||
@Override
|
||||
public Statistics estimateStatistics() {
|
||||
return new Statistics() {
|
||||
@Override
|
||||
public OptionalLong sizeInBytes() {
|
||||
return OptionalLong.of(80);
|
||||
}
|
||||
|
||||
@Override
|
||||
public OptionalLong numRows() {
|
||||
return OptionalLong.of(10);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public InputPartition[] planInputPartitions() {
|
||||
InputPartition[] partitions = new InputPartition[2];
|
||||
partitions[0] = new JavaRangeInputPartition(0, 5);
|
||||
partitions[1] = new JavaRangeInputPartition(5, 10);
|
||||
return partitions;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Table getTable(DataSourceOptions options) {
|
||||
return new JavaSimpleBatchTable() {
|
||||
@Override
|
||||
public ScanBuilder newScanBuilder(DataSourceOptions options) {
|
||||
return new MyScanBuilder();
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
|
@ -18,6 +18,7 @@
|
|||
package org.apache.spark.sql.sources.v2
|
||||
|
||||
import java.io.File
|
||||
import java.util.OptionalLong
|
||||
|
||||
import test.org.apache.spark.sql.sources.v2._
|
||||
|
||||
|
@ -182,6 +183,24 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
|
|||
}
|
||||
}
|
||||
|
||||
// Checks, for both the Scala and the Java reporting sources, that the optimized plan's
// DataSourceV2Relation exposes the row count and size that the source reports.
test("statistics report data source") {
  val sources = Seq(classOf[ReportStatisticsDataSource], classOf[JavaReportStatisticsDataSource])
  for (cls <- sources) {
    withClue(cls.getName) {
      val relations = spark.read.format(cls.getName).load()
        .queryExecution.optimizedPlan
        .collect { case d: DataSourceV2Relation => d }

      val stats = relations.head.computeStats()
      assert(stats.rowCount.exists(_ === 10),
        "Row count statics should be reported by data source")
      assert(stats.sizeInBytes === 80,
        "Size in bytes statics should be reported by data source")
    }
  }
}
|
||||
|
||||
test("SPARK-23574: no shuffle exchange with single partition") {
|
||||
val df = spark.read.format(classOf[SimpleSinglePartitionSource].getName).load().agg(count("*"))
|
||||
assert(df.queryExecution.executedPlan.collect { case e: Exchange => e }.isEmpty)
|
||||
|
@ -621,7 +640,6 @@ object ColumnarReaderFactory extends PartitionReaderFactory {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
class PartitionAwareDataSource extends TableProvider {
|
||||
|
||||
class MyScanBuilder extends SimpleScanBuilder
|
||||
|
@ -689,3 +707,29 @@ class SimpleWriteOnlyDataSource extends SimpleWritableDataSource {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Table provider used by tests: every scan built from its table reports fixed statistics
 * (80 bytes, 10 rows) through `SupportsReportStatistics`.
 */
class ReportStatisticsDataSource extends TableProvider {

  /** Scan builder that reports constant statistics and plans two range partitions. */
  class MyScanBuilder extends SimpleScanBuilder with SupportsReportStatistics {

    override def estimateStatistics(): Statistics = new Statistics {
      override def sizeInBytes(): OptionalLong = OptionalLong.of(80)

      override def numRows(): OptionalLong = OptionalLong.of(10)
    }

    // Two partitions, constructed with the bounds (0, 5) and (5, 10).
    override def planInputPartitions(): Array[InputPartition] =
      Array(RangeInputPartition(0, 5), RangeInputPartition(5, 10))
  }

  /** Returns a batch table whose scan builder is a `MyScanBuilder`. */
  override def getTable(options: DataSourceOptions): Table = new SimpleBatchTable {
    override def newScanBuilder(options: DataSourceOptions): ScanBuilder = new MyScanBuilder
  }
}
|
||||
|
|
Loading…
Reference in a new issue