[SPARK-26952][SQL] Row count statistics should respect the data reported by data source

## What changes were proposed in this pull request?

In data source v2, if the data source scan implements `SupportsReportStatistics`, `DataSourceV2Relation` should respect the row count reported by the data source.
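
For illustration, a minimal sketch of the intended effect (assuming a running `SparkSession` named `spark`, and a hypothetical source `MyReportingSource` whose scan implements `SupportsReportStatistics` and reports 80 bytes and 10 rows; the tests below add concrete Scala and Java equivalents):

```scala
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation

// MyReportingSource is hypothetical; see ReportStatisticsDataSource and
// JavaReportStatisticsDataSource in the test changes below for real ones.
val df = spark.read.format(classOf[MyReportingSource].getName).load()

// Pull the v2 relation out of the optimized plan.
val relation = df.queryExecution.optimizedPlan.collect {
  case d: DataSourceV2Relation => d
}.head

val stats = relation.computeStats()
assert(stats.sizeInBytes == BigInt(80))
// Before this patch rowCount was always None, even when the source
// reported it; now the reported value is propagated to the optimizer.
assert(stats.rowCount.contains(BigInt(10)))
```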

## How was this patch tested?

New unit tests.

Closes #23853 from ConeyLiu/report-row-count.

Authored-by: Xianyang Liu <xianyang.liu@intel.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
3 changed files with 130 additions and 4 deletions

#### DataSourceV2Relation.scala

@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.sources.v2._
-import org.apache.spark.sql.sources.v2.reader._
+import org.apache.spark.sql.sources.v2.reader.{Statistics => V2Statistics, _}
 import org.apache.spark.sql.sources.v2.reader.streaming.{Offset, SparkDataStream}
 import org.apache.spark.sql.sources.v2.writer._
@@ -56,7 +56,7 @@ case class DataSourceV2Relation(
     scan match {
       case r: SupportsReportStatistics =>
         val statistics = r.estimateStatistics()
-        Statistics(sizeInBytes = statistics.sizeInBytes().orElse(conf.defaultSizeInBytes))
+        DataSourceV2Relation.transformV2Stats(statistics, None, conf.defaultSizeInBytes)
       case _ =>
         Statistics(sizeInBytes = conf.defaultSizeInBytes)
     }
@@ -89,7 +89,7 @@ case class StreamingDataSourceV2Relation(
   override def computeStats(): Statistics = scan match {
     case r: SupportsReportStatistics =>
       val statistics = r.estimateStatistics()
-      Statistics(sizeInBytes = statistics.sizeInBytes().orElse(conf.defaultSizeInBytes))
+      DataSourceV2Relation.transformV2Stats(statistics, None, conf.defaultSizeInBytes)
     case _ =>
       Statistics(sizeInBytes = conf.defaultSizeInBytes)
   }
@@ -100,4 +100,21 @@ object DataSourceV2Relation {
     val output = table.schema().toAttributes
     DataSourceV2Relation(table, output, options)
   }
+
+  /**
+   * This is used to transform data source v2 statistics to logical.Statistics.
+   */
+  def transformV2Stats(
+      v2Statistics: V2Statistics,
+      defaultRowCount: Option[BigInt],
+      defaultSizeInBytes: Long): Statistics = {
+    val numRows: Option[BigInt] = if (v2Statistics.numRows().isPresent) {
+      Some(v2Statistics.numRows().getAsLong)
+    } else {
+      defaultRowCount
+    }
+    Statistics(
+      sizeInBytes = v2Statistics.sizeInBytes().orElse(defaultSizeInBytes),
+      rowCount = numRows)
+  }
 }
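
The heart of this helper is the bridge from Java's `OptionalLong` to Catalyst's `Option[BigInt]`. A dependency-free sketch of that fallback logic (the `toRowCount` name is illustrative only, not part of the patch):

```scala
import java.util.OptionalLong

// Mirrors the fallback in transformV2Stats: prefer the value the source
// reports; otherwise use the caller-supplied default.
def toRowCount(numRows: OptionalLong, defaultRowCount: Option[BigInt]): Option[BigInt] =
  if (numRows.isPresent) Some(BigInt(numRows.getAsLong)) else defaultRowCount

assert(toRowCount(OptionalLong.of(10L), None) == Some(BigInt(10)))
assert(toRowCount(OptionalLong.empty(), None).isEmpty)
```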

#### JavaReportStatisticsDataSource.java

@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.sources.v2;
+
+import java.util.OptionalLong;
+
+import org.apache.spark.sql.sources.v2.DataSourceOptions;
+import org.apache.spark.sql.sources.v2.Table;
+import org.apache.spark.sql.sources.v2.TableProvider;
+import org.apache.spark.sql.sources.v2.reader.InputPartition;
+import org.apache.spark.sql.sources.v2.reader.ScanBuilder;
+import org.apache.spark.sql.sources.v2.reader.Statistics;
+import org.apache.spark.sql.sources.v2.reader.SupportsReportStatistics;
+
+public class JavaReportStatisticsDataSource implements TableProvider {
+  class MyScanBuilder extends JavaSimpleScanBuilder implements SupportsReportStatistics {
+    @Override
+    public Statistics estimateStatistics() {
+      return new Statistics() {
+        @Override
+        public OptionalLong sizeInBytes() {
+          return OptionalLong.of(80);
+        }
+
+        @Override
+        public OptionalLong numRows() {
+          return OptionalLong.of(10);
+        }
+      };
+    }
+
+    @Override
+    public InputPartition[] planInputPartitions() {
+      InputPartition[] partitions = new InputPartition[2];
+      partitions[0] = new JavaRangeInputPartition(0, 5);
+      partitions[1] = new JavaRangeInputPartition(5, 10);
+      return partitions;
+    }
+  }
+
+  @Override
+  public Table getTable(DataSourceOptions options) {
+    return new JavaSimpleBatchTable() {
+      @Override
+      public ScanBuilder newScanBuilder(DataSourceOptions options) {
+        return new MyScanBuilder();
+      }
+    };
+  }
+}

#### DataSourceV2Suite.scala

@@ -18,6 +18,7 @@
 package org.apache.spark.sql.sources.v2

 import java.io.File
+import java.util.OptionalLong

 import test.org.apache.spark.sql.sources.v2._
@@ -182,6 +183,24 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
     }
   }

+  test("statistics report data source") {
+    Seq(classOf[ReportStatisticsDataSource], classOf[JavaReportStatisticsDataSource]).foreach {
+      cls =>
+        withClue(cls.getName) {
+          val df = spark.read.format(cls.getName).load()
+          val logical = df.queryExecution.optimizedPlan.collect {
+            case d: DataSourceV2Relation => d
+          }.head
+
+          val stats = logical.computeStats()
+          assert(stats.rowCount.isDefined && stats.rowCount.get === 10,
+            "Row count statistics should be reported by data source")
+          assert(stats.sizeInBytes === 80,
+            "Size in bytes statistics should be reported by data source")
+        }
+    }
+  }
+
   test("SPARK-23574: no shuffle exchange with single partition") {
     val df = spark.read.format(classOf[SimpleSinglePartitionSource].getName).load().agg(count("*"))
     assert(df.queryExecution.executedPlan.collect { case e: Exchange => e }.isEmpty)
@@ -621,7 +640,6 @@ object ColumnarReaderFactory extends PartitionReaderFactory {
   }
 }

 class PartitionAwareDataSource extends TableProvider {

   class MyScanBuilder extends SimpleScanBuilder
@@ -689,3 +707,29 @@ class SimpleWriteOnlyDataSource extends SimpleWritableDataSource {
     }
   }
 }
+
+class ReportStatisticsDataSource extends TableProvider {
+
+  class MyScanBuilder extends SimpleScanBuilder
+    with SupportsReportStatistics {
+    override def estimateStatistics(): Statistics = {
+      new Statistics {
+        override def sizeInBytes(): OptionalLong = OptionalLong.of(80)
+
+        override def numRows(): OptionalLong = OptionalLong.of(10)
+      }
+    }
+
+    override def planInputPartitions(): Array[InputPartition] = {
+      Array(RangeInputPartition(0, 5), RangeInputPartition(5, 10))
+    }
+  }
+
+  override def getTable(options: DataSourceOptions): Table = {
+    new SimpleBatchTable {
+      override def newScanBuilder(options: DataSourceOptions): ScanBuilder = {
+        new MyScanBuilder
+      }
+    }
+  }
+}