[SPARK-23574][SQL] Report SinglePartition in DataSourceV2ScanExec when there's exactly 1 data reader factory.
## What changes were proposed in this pull request?

Report SinglePartition in DataSourceV2ScanExec when there's exactly 1 data reader factory. Note that this means reader factories end up being constructed as partitioning is checked; let me know if you think that could be a problem.

## How was this patch tested?

Existing unit tests.

Author: Jose Torres <jose@databricks.com>
Author: Jose Torres <torres.joseph.f+github@gmail.com>

Closes #20726 from jose-torres/SPARK-23574.
Parent: 7f5e8aa260
Commit: 2c4b9962fd
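Before the diff hunks below, a hedged end-to-end sketch of what the change buys (illustration only, not part of the commit). `SimpleSinglePartitionSource` is the test source this PR adds to `DataSourceV2Suite`, so as written this only runs inside that test module; outside it, substitute any DataSourceV2 implementation that creates exactly one `DataReaderFactory`:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.exchange.Exchange
import org.apache.spark.sql.functions.count

object SinglePartitionDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("SPARK-23574 demo").getOrCreate()

    // A scan over a source with exactly one reader factory now reports
    // SinglePartition, so this global aggregate needs no shuffle.
    val df = spark.read
      .format("org.apache.spark.sql.sources.v2.SimpleSinglePartitionSource")
      .load()
      .agg(count("*"))

    // With the fix, the executed plan contains no Exchange node at all.
    val exchanges = df.queryExecution.executedPlan.collect { case e: Exchange => e }
    assert(exchanges.isEmpty, "expected no shuffle for a single-partition scan")

    spark.stop()
  }
}
```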
SupportsReportPartitioning.java
@@ -23,6 +23,9 @@ import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning;
 /**
  * A mix in interface for {@link DataSourceReader}. Data source readers can implement this
  * interface to report data partitioning and try to avoid shuffle at Spark side.
+ *
+ * Note that, when the reader creates exactly one {@link DataReaderFactory}, Spark may avoid
+ * adding a shuffle even if the reader does not implement this interface.
  */
 @InterfaceStability.Evolving
 public interface SupportsReportPartitioning extends DataSourceReader {
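The new note above describes a shortcut around `SupportsReportPartitioning`; for contrast, here is a hedged sketch (hypothetical class and column names, Spark 2.3-era API) of the path the interface itself provides: a reader with many factories reports how its rows are clustered, and Spark can skip a shuffle when that clustering satisfies the required distribution.

```scala
import org.apache.spark.sql.sources.v2.reader.partitioning.{ClusteredDistribution, Distribution, Partitioning}

// Hypothetical report: rows sharing a value of column "i" never span two
// input partitions, so grouping on "i" needs no shuffle.
class KeyClusteredPartitioning(numParts: Int) extends Partitioning {
  override def numPartitions(): Int = numParts

  override def satisfy(distribution: Distribution): Boolean = distribution match {
    // ClusteredDistribution names the columns a plan node needs co-located.
    case c: ClusteredDistribution => c.clusteredColumns.contains("i")
    case _ => false
  }
}
```

A `SupportsReportPartitioning` reader would return such an object from its `outputPartitioning()` method.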
DataSourceRDD.scala
@@ -29,11 +29,11 @@ class DataSourceRDDPartition[T : ClassTag](val index: Int, val readerFactory: Da

 class DataSourceRDD[T: ClassTag](
     sc: SparkContext,
-    @transient private val readerFactories: java.util.List[DataReaderFactory[T]])
+    @transient private val readerFactories: Seq[DataReaderFactory[T]])
   extends RDD[T](sc, Nil) {

   override protected def getPartitions: Array[Partition] = {
-    readerFactories.asScala.zipWithIndex.map {
+    readerFactories.zipWithIndex.map {
       case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory)
     }.toArray
   }
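A side note on the `java.util.List` → `Seq` change above (an observation about the refactor, not commit text): the `JavaConverters` conversion now happens once, where `DataSourceV2ScanExec` materializes the factories, instead of inside each RDD's `getPartitions`. The conversion itself is a cheap wrapper, as this minimal sketch shows:

```scala
import scala.collection.JavaConverters._

object AsScalaSketch {
  def main(args: Array[String]): Unit = {
    val javaList: java.util.List[Int] = java.util.Arrays.asList(1, 2, 3)
    // asScala wraps the Java list without copying; the wrapper conforms to
    // Seq, so zipWithIndex works on it directly, as in getPartitions above.
    val scalaSeq: Seq[Int] = javaList.asScala
    assert(scalaSeq.zipWithIndex == Seq((1, 0), (2, 1), (3, 2)))
  }
}
```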
DataSourceV2ScanExec.scala
@@ -25,12 +25,14 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical
+import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
 import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.streaming.continuous._
 import org.apache.spark.sql.sources.v2.DataSourceV2
 import org.apache.spark.sql.sources.v2.reader._
 import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.vectorized.ColumnarBatch

 /**
  * Physical plan node for scanning data from a data source.
@@ -56,6 +58,15 @@ case class DataSourceV2ScanExec(
   }

   override def outputPartitioning: physical.Partitioning = reader match {
+    case r: SupportsScanColumnarBatch if r.enableBatchRead() && batchReaderFactories.size == 1 =>
+      SinglePartition
+
+    case r: SupportsScanColumnarBatch if !r.enableBatchRead() && readerFactories.size == 1 =>
+      SinglePartition
+
+    case r if !r.isInstanceOf[SupportsScanColumnarBatch] && readerFactories.size == 1 =>
+      SinglePartition
+
     case s: SupportsReportPartitioning =>
       new DataSourcePartitioning(
         s.outputPartitioning(), AttributeMap(output.map(a => a -> a.name)))
@@ -63,29 +74,33 @@ case class DataSourceV2ScanExec(
     case _ => super.outputPartitioning
   }

-  private lazy val readerFactories: java.util.List[DataReaderFactory[UnsafeRow]] = reader match {
-    case r: SupportsScanUnsafeRow => r.createUnsafeRowReaderFactories()
+  private lazy val readerFactories: Seq[DataReaderFactory[UnsafeRow]] = reader match {
+    case r: SupportsScanUnsafeRow => r.createUnsafeRowReaderFactories().asScala
     case _ =>
       reader.createDataReaderFactories().asScala.map {
         new RowToUnsafeRowDataReaderFactory(_, reader.readSchema()): DataReaderFactory[UnsafeRow]
-      }.asJava
+      }
   }

-  private lazy val inputRDD: RDD[InternalRow] = reader match {
+  private lazy val batchReaderFactories: Seq[DataReaderFactory[ColumnarBatch]] = reader match {
     case r: SupportsScanColumnarBatch if r.enableBatchRead() =>
       assert(!reader.isInstanceOf[ContinuousReader],
         "continuous stream reader does not support columnar read yet.")
-      new DataSourceRDD(sparkContext, r.createBatchDataReaderFactories())
-        .asInstanceOf[RDD[InternalRow]]
+      r.createBatchDataReaderFactories().asScala
+  }

+  private lazy val inputRDD: RDD[InternalRow] = reader match {
     case _: ContinuousReader =>
       EpochCoordinatorRef.get(
           sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
           sparkContext.env)
-        .askSync[Unit](SetReaderPartitions(readerFactories.size()))
+        .askSync[Unit](SetReaderPartitions(readerFactories.size))
       new ContinuousDataSourceRDD(sparkContext, sqlContext, readerFactories)
         .asInstanceOf[RDD[InternalRow]]

+    case r: SupportsScanColumnarBatch if r.enableBatchRead() =>
+      new DataSourceRDD(sparkContext, batchReaderFactories).asInstanceOf[RDD[InternalRow]]
+
     case _ =>
       new DataSourceRDD(sparkContext, readerFactories).asInstanceOf[RDD[InternalRow]]
   }
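Two hedged observations on the hunk above (reasoning about the planner, not commit text). First, this is where the caveat in the description bites: `readerFactories` and `batchReaderFactories` are lazy vals consulted by `outputPartitioning`, so the factories are now constructed the first time the planner checks partitioning, i.e. during planning rather than execution. Second, reporting `SinglePartition` removes the exchange because `EnsureRequirements` only inserts a `ShuffleExchangeExec` when the child's partitioning fails to satisfy the required distribution, and a single partition trivially satisfies the `AllTuples` requirement of a global aggregate:

```scala
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, SinglePartition}

object SinglePartitionSatisfies {
  def main(args: Array[String]): Unit = {
    // A global aggregate requires AllTuples; one partition trivially holds
    // them all, so EnsureRequirements adds no ShuffleExchangeExec above it.
    assert(SinglePartition.satisfies(AllTuples))
  }
}
```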
ContinuousDataSourceRDD.scala
@@ -35,14 +35,14 @@ import org.apache.spark.util.ThreadUtils
 class ContinuousDataSourceRDD(
     sc: SparkContext,
     sqlContext: SQLContext,
-    @transient private val readerFactories: java.util.List[DataReaderFactory[UnsafeRow]])
+    @transient private val readerFactories: Seq[DataReaderFactory[UnsafeRow]])
   extends RDD[UnsafeRow](sc, Nil) {

   private val dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize
   private val epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs

   override protected def getPartitions: Array[Partition] = {
-    readerFactories.asScala.zipWithIndex.map {
+    readerFactories.zipWithIndex.map {
       case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory)
     }.toArray
   }
DataSourceV2Suite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.SparkException
 import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanExec}
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.sources.{Filter, GreaterThan}
@@ -191,6 +191,11 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
     }
   }

+  test("SPARK-23574: no shuffle exchange with single partition") {
+    val df = spark.read.format(classOf[SimpleSinglePartitionSource].getName).load().agg(count("*"))
+    assert(df.queryExecution.executedPlan.collect { case e: Exchange => e }.isEmpty)
+  }
+
   test("simple writable data source") {
     // TODO: java implementation.
     Seq(classOf[SimpleWritableDataSource]).foreach { cls =>
@@ -336,6 +341,19 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
   }
 }

+class SimpleSinglePartitionSource extends DataSourceV2 with ReadSupport {
+
+  class Reader extends DataSourceReader {
+    override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int")
+
+    override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = {
+      java.util.Arrays.asList(new SimpleDataReaderFactory(0, 5))
+    }
+  }
+
+  override def createReader(options: DataSourceOptions): DataSourceReader = new Reader
+}
+
 class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport {

   class Reader extends DataSourceReader {
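Note that `SimpleSinglePartitionSource` reuses `SimpleDataReaderFactory(0, 5)`, the row-producing helper already defined in this suite; what makes it single-partition is simply that `createDataReaderFactories` returns a one-element list, which is exactly the condition the new `outputPartitioning` cases check.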
StreamingQuerySuite.scala
@@ -326,9 +326,9 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi

       assert(progress.durationMs.get("setOffsetRange") === 50)
       assert(progress.durationMs.get("getEndOffset") === 100)
-      assert(progress.durationMs.get("queryPlanning") === 0)
+      assert(progress.durationMs.get("queryPlanning") === 200)
       assert(progress.durationMs.get("walCommit") === 0)
-      assert(progress.durationMs.get("addBatch") === 350)
+      assert(progress.durationMs.get("addBatch") === 150)
       assert(progress.durationMs.get("triggerExecution") === 500)

       assert(progress.sources.length === 1)
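The reassigned durations above are consistent with the caveat in the PR description: since reader factories are now constructed while partitioning is checked, this test's mocked clock attributes 200 ms to queryPlanning that previously fell under addBatch (0 → 200 and 350 → 150), leaving triggerExecution unchanged at 500.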