[SPARK-26952][SQL] Row count statics should respect the data reported by data source
## What changes were proposed in this pull request? In data source v2, if the data source scan implemented `SupportsReportStatistics`. `DataSourceV2Relation` should respect the row count reported by the data source. ## How was this patch tested? New UT test. Closes #23853 from ConeyLiu/report-row-count. Authored-by: Xianyang Liu <xianyang.liu@intel.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
52a180f25f
commit
bc03c8b3fa
|
@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
|
||||||
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
|
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
|
||||||
import org.apache.spark.sql.catalyst.util.truncatedString
|
import org.apache.spark.sql.catalyst.util.truncatedString
|
||||||
import org.apache.spark.sql.sources.v2._
|
import org.apache.spark.sql.sources.v2._
|
||||||
import org.apache.spark.sql.sources.v2.reader._
|
import org.apache.spark.sql.sources.v2.reader.{Statistics => V2Statistics, _}
|
||||||
import org.apache.spark.sql.sources.v2.reader.streaming.{Offset, SparkDataStream}
|
import org.apache.spark.sql.sources.v2.reader.streaming.{Offset, SparkDataStream}
|
||||||
import org.apache.spark.sql.sources.v2.writer._
|
import org.apache.spark.sql.sources.v2.writer._
|
||||||
|
|
||||||
|
@ -56,7 +56,7 @@ case class DataSourceV2Relation(
|
||||||
scan match {
|
scan match {
|
||||||
case r: SupportsReportStatistics =>
|
case r: SupportsReportStatistics =>
|
||||||
val statistics = r.estimateStatistics()
|
val statistics = r.estimateStatistics()
|
||||||
Statistics(sizeInBytes = statistics.sizeInBytes().orElse(conf.defaultSizeInBytes))
|
DataSourceV2Relation.transformV2Stats(statistics, None, conf.defaultSizeInBytes)
|
||||||
case _ =>
|
case _ =>
|
||||||
Statistics(sizeInBytes = conf.defaultSizeInBytes)
|
Statistics(sizeInBytes = conf.defaultSizeInBytes)
|
||||||
}
|
}
|
||||||
|
@ -89,7 +89,7 @@ case class StreamingDataSourceV2Relation(
|
||||||
override def computeStats(): Statistics = scan match {
|
override def computeStats(): Statistics = scan match {
|
||||||
case r: SupportsReportStatistics =>
|
case r: SupportsReportStatistics =>
|
||||||
val statistics = r.estimateStatistics()
|
val statistics = r.estimateStatistics()
|
||||||
Statistics(sizeInBytes = statistics.sizeInBytes().orElse(conf.defaultSizeInBytes))
|
DataSourceV2Relation.transformV2Stats(statistics, None, conf.defaultSizeInBytes)
|
||||||
case _ =>
|
case _ =>
|
||||||
Statistics(sizeInBytes = conf.defaultSizeInBytes)
|
Statistics(sizeInBytes = conf.defaultSizeInBytes)
|
||||||
}
|
}
|
||||||
|
@ -100,4 +100,21 @@ object DataSourceV2Relation {
|
||||||
val output = table.schema().toAttributes
|
val output = table.schema().toAttributes
|
||||||
DataSourceV2Relation(table, output, options)
|
DataSourceV2Relation(table, output, options)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This is used to transform data source v2 statistics to logical.Statistics.
|
||||||
|
*/
|
||||||
|
def transformV2Stats(
|
||||||
|
v2Statistics: V2Statistics,
|
||||||
|
defaultRowCount: Option[BigInt],
|
||||||
|
defaultSizeInBytes: Long): Statistics = {
|
||||||
|
val numRows: Option[BigInt] = if (v2Statistics.numRows().isPresent) {
|
||||||
|
Some(v2Statistics.numRows().getAsLong)
|
||||||
|
} else {
|
||||||
|
defaultRowCount
|
||||||
|
}
|
||||||
|
Statistics(
|
||||||
|
sizeInBytes = v2Statistics.sizeInBytes().orElse(defaultSizeInBytes),
|
||||||
|
rowCount = numRows)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,65 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package test.org.apache.spark.sql.sources.v2;
|
||||||
|
|
||||||
|
import java.util.OptionalLong;
|
||||||
|
|
||||||
|
import org.apache.spark.sql.sources.v2.DataSourceOptions;
|
||||||
|
import org.apache.spark.sql.sources.v2.Table;
|
||||||
|
import org.apache.spark.sql.sources.v2.TableProvider;
|
||||||
|
import org.apache.spark.sql.sources.v2.reader.InputPartition;
|
||||||
|
import org.apache.spark.sql.sources.v2.reader.ScanBuilder;
|
||||||
|
import org.apache.spark.sql.sources.v2.reader.Statistics;
|
||||||
|
import org.apache.spark.sql.sources.v2.reader.SupportsReportStatistics;
|
||||||
|
|
||||||
|
public class JavaReportStatisticsDataSource implements TableProvider {
|
||||||
|
class MyScanBuilder extends JavaSimpleScanBuilder implements SupportsReportStatistics {
|
||||||
|
@Override
|
||||||
|
public Statistics estimateStatistics() {
|
||||||
|
return new Statistics() {
|
||||||
|
@Override
|
||||||
|
public OptionalLong sizeInBytes() {
|
||||||
|
return OptionalLong.of(80);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public OptionalLong numRows() {
|
||||||
|
return OptionalLong.of(10);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public InputPartition[] planInputPartitions() {
|
||||||
|
InputPartition[] partitions = new InputPartition[2];
|
||||||
|
partitions[0] = new JavaRangeInputPartition(0, 5);
|
||||||
|
partitions[1] = new JavaRangeInputPartition(5, 10);
|
||||||
|
return partitions;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Table getTable(DataSourceOptions options) {
|
||||||
|
return new JavaSimpleBatchTable() {
|
||||||
|
@Override
|
||||||
|
public ScanBuilder newScanBuilder(DataSourceOptions options) {
|
||||||
|
return new MyScanBuilder();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
|
@ -18,6 +18,7 @@
|
||||||
package org.apache.spark.sql.sources.v2
|
package org.apache.spark.sql.sources.v2
|
||||||
|
|
||||||
import java.io.File
|
import java.io.File
|
||||||
|
import java.util.OptionalLong
|
||||||
|
|
||||||
import test.org.apache.spark.sql.sources.v2._
|
import test.org.apache.spark.sql.sources.v2._
|
||||||
|
|
||||||
|
@ -182,6 +183,24 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test ("statistics report data source") {
|
||||||
|
Seq(classOf[ReportStatisticsDataSource], classOf[JavaReportStatisticsDataSource]).foreach {
|
||||||
|
cls =>
|
||||||
|
withClue(cls.getName) {
|
||||||
|
val df = spark.read.format(cls.getName).load()
|
||||||
|
val logical = df.queryExecution.optimizedPlan.collect {
|
||||||
|
case d: DataSourceV2Relation => d
|
||||||
|
}.head
|
||||||
|
|
||||||
|
val statics = logical.computeStats()
|
||||||
|
assert(statics.rowCount.isDefined && statics.rowCount.get === 10,
|
||||||
|
"Row count statics should be reported by data source")
|
||||||
|
assert(statics.sizeInBytes === 80,
|
||||||
|
"Size in bytes statics should be reported by data source")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
test("SPARK-23574: no shuffle exchange with single partition") {
|
test("SPARK-23574: no shuffle exchange with single partition") {
|
||||||
val df = spark.read.format(classOf[SimpleSinglePartitionSource].getName).load().agg(count("*"))
|
val df = spark.read.format(classOf[SimpleSinglePartitionSource].getName).load().agg(count("*"))
|
||||||
assert(df.queryExecution.executedPlan.collect { case e: Exchange => e }.isEmpty)
|
assert(df.queryExecution.executedPlan.collect { case e: Exchange => e }.isEmpty)
|
||||||
|
@ -621,7 +640,6 @@ object ColumnarReaderFactory extends PartitionReaderFactory {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class PartitionAwareDataSource extends TableProvider {
|
class PartitionAwareDataSource extends TableProvider {
|
||||||
|
|
||||||
class MyScanBuilder extends SimpleScanBuilder
|
class MyScanBuilder extends SimpleScanBuilder
|
||||||
|
@ -689,3 +707,29 @@ class SimpleWriteOnlyDataSource extends SimpleWritableDataSource {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class ReportStatisticsDataSource extends TableProvider {
|
||||||
|
|
||||||
|
class MyScanBuilder extends SimpleScanBuilder
|
||||||
|
with SupportsReportStatistics {
|
||||||
|
override def estimateStatistics(): Statistics = {
|
||||||
|
new Statistics {
|
||||||
|
override def sizeInBytes(): OptionalLong = OptionalLong.of(80)
|
||||||
|
|
||||||
|
override def numRows(): OptionalLong = OptionalLong.of(10)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override def planInputPartitions(): Array[InputPartition] = {
|
||||||
|
Array(RangeInputPartition(0, 5), RangeInputPartition(5, 10))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override def getTable(options: DataSourceOptions): Table = {
|
||||||
|
new SimpleBatchTable {
|
||||||
|
override def newScanBuilder(options: DataSourceOptions): ScanBuilder = {
|
||||||
|
new MyScanBuilder
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue