[SPARK-23574][SQL] Report SinglePartition in DataSourceV2ScanExec when there's exactly 1 data reader factory.

## What changes were proposed in this pull request?

Report SinglePartition in DataSourceV2ScanExec when there's exactly 1 data reader factory.

Note that this means the reader factories end up being constructed when partitioning is checked, i.e. during query planning rather than at execution time (this is also why the expected `queryPlanning` and `addBatch` durations shift in StreamingQuerySuite); let me know if you think that could be a problem.
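To make the effect concrete, here is a sketch condensed from the new test added to DataSourceV2Suite in this commit; `SimpleSinglePartitionSource` is the helper source defined in that suite, and `spark` is the session provided by the test's SharedSQLContext (or a spark-shell on a build containing this patch):

```scala
import org.apache.spark.sql.execution.exchange.Exchange
import org.apache.spark.sql.functions.count

// The source below creates exactly one DataReaderFactory, so the scan now reports
// SinglePartition instead of falling back to the default (unknown) partitioning.
val df = spark.read
  .format(classOf[SimpleSinglePartitionSource].getName)
  .load()
  .agg(count("*"))

// A global aggregate over a single-partition child needs no repartitioning,
// so the planner should not insert any Exchange node.
assert(df.queryExecution.executedPlan.collect { case e: Exchange => e }.isEmpty)
```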

## How was this patch tested?

Existing unit tests, plus a new DataSourceV2Suite test asserting that no Exchange is planned over a single-partition DataSourceV2 scan.

Author: Jose Torres <jose@databricks.com>
Author: Jose Torres <torres.joseph.f+github@gmail.com>

Closes #20726 from jose-torres/SPARK-23574.
Jose Torres 2018-03-20 11:46:51 -07:00 committed by Wenchen Fan
parent 7f5e8aa260
commit 2c4b9962fd
6 changed files with 50 additions and 14 deletions

SupportsReportPartitioning.java

@ -23,6 +23,9 @@ import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning;
 /**
  * A mix in interface for {@link DataSourceReader}. Data source readers can implement this
  * interface to report data partitioning and try to avoid shuffle at Spark side.
+ *
+ * Note that, when the reader creates exactly one {@link DataReaderFactory}, Spark may avoid
+ * adding a shuffle even if the reader does not implement this interface.
  */
 @InterfaceStability.Evolving
 public interface SupportsReportPartitioning extends DataSourceReader {
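For contrast with the new single-factory shortcut described in that note, a reader that exposes more than one factory still has to describe its layout through this interface to avoid a shuffle. A rough, hypothetical Scala sketch follows: the `ReportingReader` name, the two-factory split, and the clustering claim are invented for illustration, and `SimpleDataReaderFactory` is borrowed from the DataSourceV2Suite changed in this commit.

```scala
import java.util.{List => JList}

import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.reader.{DataReaderFactory, DataSourceReader, SupportsReportPartitioning}
import org.apache.spark.sql.sources.v2.reader.partitioning.{ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.types.StructType

// Hypothetical reader with two factories: it cannot hit the new size == 1 shortcut,
// so it reports its partitioning explicitly instead.
class ReportingReader extends DataSourceReader with SupportsReportPartitioning {
  override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int")

  override def createDataReaderFactories(): JList[DataReaderFactory[Row]] =
    java.util.Arrays.asList(new SimpleDataReaderFactory(0, 5), new SimpleDataReaderFactory(5, 10))

  // Claim (purely for illustration) that rows are already clustered by column "i".
  override def outputPartitioning(): Partitioning = new Partitioning {
    override def numPartitions(): Int = 2
    override def satisfy(distribution: Distribution): Boolean = distribution match {
      case c: ClusteredDistribution => c.clusteredColumns.contains("i")
      case _ => false
    }
  }
}
```

The point of comparison: with this patch, a reader that returns a single factory gets SinglePartition for free, while a multi-factory reader still has to make an explicit claim like the one above.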

DataSourceRDD.scala

@ -29,11 +29,11 @@ class DataSourceRDDPartition[T : ClassTag](val index: Int, val readerFactory: Da
 class DataSourceRDD[T: ClassTag](
     sc: SparkContext,
-    @transient private val readerFactories: java.util.List[DataReaderFactory[T]])
+    @transient private val readerFactories: Seq[DataReaderFactory[T]])
   extends RDD[T](sc, Nil) {
 
   override protected def getPartitions: Array[Partition] = {
-    readerFactories.asScala.zipWithIndex.map {
+    readerFactories.zipWithIndex.map {
       case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory)
     }.toArray
   }

DataSourceV2ScanExec.scala

@ -25,12 +25,14 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical
+import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
 import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.streaming.continuous._
 import org.apache.spark.sql.sources.v2.DataSourceV2
 import org.apache.spark.sql.sources.v2.reader._
 import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.vectorized.ColumnarBatch
 
 /**
  * Physical plan node for scanning data from a data source.
@ -56,6 +58,15 @@ case class DataSourceV2ScanExec(
   }
 
   override def outputPartitioning: physical.Partitioning = reader match {
+    case r: SupportsScanColumnarBatch if r.enableBatchRead() && batchReaderFactories.size == 1 =>
+      SinglePartition
+
+    case r: SupportsScanColumnarBatch if !r.enableBatchRead() && readerFactories.size == 1 =>
+      SinglePartition
+
+    case r if !r.isInstanceOf[SupportsScanColumnarBatch] && readerFactories.size == 1 =>
+      SinglePartition
+
     case s: SupportsReportPartitioning =>
       new DataSourcePartitioning(
         s.outputPartitioning(), AttributeMap(output.map(a => a -> a.name)))
@ -63,29 +74,33 @@ case class DataSourceV2ScanExec(
     case _ => super.outputPartitioning
   }
 
-  private lazy val readerFactories: java.util.List[DataReaderFactory[UnsafeRow]] = reader match {
-    case r: SupportsScanUnsafeRow => r.createUnsafeRowReaderFactories()
+  private lazy val readerFactories: Seq[DataReaderFactory[UnsafeRow]] = reader match {
+    case r: SupportsScanUnsafeRow => r.createUnsafeRowReaderFactories().asScala
     case _ =>
       reader.createDataReaderFactories().asScala.map {
         new RowToUnsafeRowDataReaderFactory(_, reader.readSchema()): DataReaderFactory[UnsafeRow]
-      }.asJava
+      }
   }
 
-  private lazy val inputRDD: RDD[InternalRow] = reader match {
+  private lazy val batchReaderFactories: Seq[DataReaderFactory[ColumnarBatch]] = reader match {
     case r: SupportsScanColumnarBatch if r.enableBatchRead() =>
       assert(!reader.isInstanceOf[ContinuousReader],
         "continuous stream reader does not support columnar read yet.")
-      new DataSourceRDD(sparkContext, r.createBatchDataReaderFactories())
-        .asInstanceOf[RDD[InternalRow]]
+      r.createBatchDataReaderFactories().asScala
+  }
 
+  private lazy val inputRDD: RDD[InternalRow] = reader match {
     case _: ContinuousReader =>
       EpochCoordinatorRef.get(
           sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
           sparkContext.env)
-        .askSync[Unit](SetReaderPartitions(readerFactories.size()))
+        .askSync[Unit](SetReaderPartitions(readerFactories.size))
       new ContinuousDataSourceRDD(sparkContext, sqlContext, readerFactories)
         .asInstanceOf[RDD[InternalRow]]
 
+    case r: SupportsScanColumnarBatch if r.enableBatchRead() =>
+      new DataSourceRDD(sparkContext, batchReaderFactories).asInstanceOf[RDD[InternalRow]]
+
     case _ =>
       new DataSourceRDD(sparkContext, readerFactories).asInstanceOf[RDD[InternalRow]]
   }
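For context on why reporting SinglePartition is enough to drop the exchange: EnsureRequirements only inserts a shuffle when the child's output partitioning fails to satisfy the required distribution, and catalyst's SinglePartition satisfies every non-broadcast distribution. A quick REPL-style check of that property (a sketch against the catalyst API, not code from this commit):

```scala
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, SinglePartition}

// SinglePartition has exactly one partition and satisfies any non-broadcast
// distribution, including AllTuples, which is what a global aggregate requires.
assert(SinglePartition.numPartitions == 1)
assert(SinglePartition.satisfies(AllTuples))
```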

ContinuousDataSourceRDD.scala

@ -35,14 +35,14 @@ import org.apache.spark.util.ThreadUtils
 class ContinuousDataSourceRDD(
     sc: SparkContext,
     sqlContext: SQLContext,
-    @transient private val readerFactories: java.util.List[DataReaderFactory[UnsafeRow]])
+    @transient private val readerFactories: Seq[DataReaderFactory[UnsafeRow]])
   extends RDD[UnsafeRow](sc, Nil) {
 
   private val dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize
   private val epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs
 
   override protected def getPartitions: Array[Partition] = {
-    readerFactories.asScala.zipWithIndex.map {
+    readerFactories.zipWithIndex.map {
       case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory)
     }.toArray
   }

DataSourceV2Suite.scala

@ -25,7 +25,7 @@ import org.apache.spark.SparkException
 import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanExec}
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.sources.{Filter, GreaterThan}
@ -191,6 +191,11 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("SPARK-23574: no shuffle exchange with single partition") {
+    val df = spark.read.format(classOf[SimpleSinglePartitionSource].getName).load().agg(count("*"))
+    assert(df.queryExecution.executedPlan.collect { case e: Exchange => e }.isEmpty)
+  }
+
   test("simple writable data source") {
     // TODO: java implementation.
     Seq(classOf[SimpleWritableDataSource]).foreach { cls =>
@ -336,6 +341,19 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
   }
 }
 
+class SimpleSinglePartitionSource extends DataSourceV2 with ReadSupport {
+
+  class Reader extends DataSourceReader {
+    override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int")
+
+    override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = {
+      java.util.Arrays.asList(new SimpleDataReaderFactory(0, 5))
+    }
+  }
+
+  override def createReader(options: DataSourceOptions): DataSourceReader = new Reader
+}
+
 class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport {
 
   class Reader extends DataSourceReader {

StreamingQuerySuite.scala

@ -326,9 +326,9 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
assert(progress.durationMs.get("setOffsetRange") === 50)
assert(progress.durationMs.get("getEndOffset") === 100)
assert(progress.durationMs.get("queryPlanning") === 0)
assert(progress.durationMs.get("queryPlanning") === 200)
assert(progress.durationMs.get("walCommit") === 0)
assert(progress.durationMs.get("addBatch") === 350)
assert(progress.durationMs.get("addBatch") === 150)
assert(progress.durationMs.get("triggerExecution") === 500)
assert(progress.sources.length === 1)