[SPARK-23219][SQL] Rename ReadTask to DataReaderFactory in data source v2

## What changes were proposed in this pull request?

Currently we have `ReadTask` in the data source v2 reader, while the writer has `DataWriterFactory`.
To make the naming consistent, this PR renames `ReadTask` to `DataReaderFactory`.
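
To illustrate the renamed API, here is a minimal sketch of a v2 source after this change, modeled on the `SimpleDataSourceV2` test source touched by this patch (the `Example*` names and the single-column schema are made up for illustration): the reader returns `DataReaderFactory` instances from `createDataReaderFactories()`, and each factory is serialized to executors, where it creates the `DataReader` that does the actual reading.

```scala
import java.util.{Arrays => JArrays, List => JList}

import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, DataSourceV2Reader}
import org.apache.spark.sql.types.StructType

// Hypothetical reader: returns one factory per RDD partition.
class ExampleReader extends DataSourceV2Reader {
  override def readSchema(): StructType = new StructType().add("i", "int")

  override def createDataReaderFactories(): JList[DataReaderFactory[Row]] =
    JArrays.asList(new ExampleDataReaderFactory(0, 5), new ExampleDataReaderFactory(5, 10))
}

// The factory is Serializable and is shipped to executors; the DataReader it
// creates there performs the actual row-by-row reading.
class ExampleDataReaderFactory(start: Int, end: Int)
  extends DataReaderFactory[Row] with DataReader[Row] {

  private var current = start - 1

  override def createDataReader(): DataReader[Row] = new ExampleDataReaderFactory(start, end)

  override def next(): Boolean = {
    current += 1
    current < end
  }

  override def get(): Row = Row(current)

  override def close(): Unit = {}
}
```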

## How was this patch tested?

Unit test

Author: Wang Gengliang <ltnwgl@gmail.com>

Closes #20397 from gengliangwang/rename.
Wang Gengliang 2018-01-30 00:50:49 +08:00 committed by Wenchen Fan
parent 39d2c6b034
commit badf0d0e0d
28 changed files with 172 additions and 156 deletions

View file

@@ -63,7 +63,7 @@ class KafkaContinuousReader(
   private val pollTimeoutMs = sourceOptions.getOrElse("kafkaConsumer.pollTimeoutMs", "512").toLong
-  // Initialized when creating read tasks. If this diverges from the partitions at the latest
+  // Initialized when creating reader factories. If this diverges from the partitions at the latest
   // offsets, we need to reconfigure.
   // Exposed outside this object only for unit tests.
   private[sql] var knownPartitions: Set[TopicPartition] = _
@@ -89,7 +89,7 @@ class KafkaContinuousReader(
     KafkaSourceOffset(JsonUtils.partitionOffsets(json))
   }
-  override def createUnsafeRowReadTasks(): ju.List[ReadTask[UnsafeRow]] = {
+  override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = {
     import scala.collection.JavaConverters._
     val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset)
@@ -109,9 +109,9 @@ class KafkaContinuousReader(
     startOffsets.toSeq.map {
       case (topicPartition, start) =>
-        KafkaContinuousReadTask(
+        KafkaContinuousDataReaderFactory(
           topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss)
-          .asInstanceOf[ReadTask[UnsafeRow]]
+          .asInstanceOf[DataReaderFactory[UnsafeRow]]
     }.asJava
   }
@@ -149,8 +149,8 @@ class KafkaContinuousReader(
 }
 /**
- * A read task for continuous Kafka processing. This will be serialized and transformed into a
- * full reader on executors.
+ * A data reader factory for continuous Kafka processing. This will be serialized and transformed
+ * into a full reader on executors.
  *
  * @param topicPartition The (topic, partition) pair this task is responsible for.
 * @param startOffset The offset to start reading from within the partition.
@@ -159,12 +159,12 @@ class KafkaContinuousReader(
  * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets
  *                       are skipped.
  */
-case class KafkaContinuousReadTask(
+case class KafkaContinuousDataReaderFactory(
     topicPartition: TopicPartition,
     startOffset: Long,
     kafkaParams: ju.Map[String, Object],
     pollTimeoutMs: Long,
-    failOnDataLoss: Boolean) extends ReadTask[UnsafeRow] {
+    failOnDataLoss: Boolean) extends DataReaderFactory[UnsafeRow] {
   override def createDataReader(): KafkaContinuousDataReader = {
     new KafkaContinuousDataReader(
       topicPartition, startOffset, kafkaParams, pollTimeoutMs, failOnDataLoss)

View file

@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution;
 import java.io.IOException;
 import java.util.function.Supplier;
-import org.apache.spark.sql.catalyst.util.TypeUtils;
 import scala.collection.AbstractIterator;
 import scala.collection.Iterator;
 import scala.math.Ordering;

View file

@@ -22,7 +22,7 @@ import org.apache.spark.annotation.InterfaceStability;
 /**
  * A concrete implementation of {@link Distribution}. Represents a distribution where records that
  * share the same values for the {@link #clusteredColumns} will be produced by the same
- * {@link ReadTask}.
+ * {@link DataReader}.
  */
 @InterfaceStability.Evolving
 public class ClusteredDistribution implements Distribution {

View file

@@ -23,7 +23,7 @@ import java.io.IOException;
 import org.apache.spark.annotation.InterfaceStability;
 /**
- * A data reader returned by {@link ReadTask#createDataReader()} and is responsible for
+ * A data reader returned by {@link DataReaderFactory#createDataReader()} and is responsible for
  * outputting data for a RDD partition.
  *
  * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data

View file

@@ -22,21 +22,23 @@ import java.io.Serializable;
 import org.apache.spark.annotation.InterfaceStability;
 /**
- * A read task returned by {@link DataSourceV2Reader#createReadTasks()} and is responsible for
- * creating the actual data reader. The relationship between {@link ReadTask} and {@link DataReader}
+ * A reader factory returned by {@link DataSourceV2Reader#createDataReaderFactories()} and is
+ * responsible for creating the actual data reader. The relationship between
+ * {@link DataReaderFactory} and {@link DataReader}
  * is similar to the relationship between {@link Iterable} and {@link java.util.Iterator}.
  *
- * Note that, the read task will be serialized and sent to executors, then the data reader will be
- * created on executors and do the actual reading. So {@link ReadTask} must be serializable and
- * {@link DataReader} doesn't need to be.
+ * Note that, the reader factory will be serialized and sent to executors, then the data reader
+ * will be created on executors and do the actual reading. So {@link DataReaderFactory} must be
+ * serializable and {@link DataReader} doesn't need to be.
  */
 @InterfaceStability.Evolving
-public interface ReadTask<T> extends Serializable {
+public interface DataReaderFactory<T> extends Serializable {
   /**
-   * The preferred locations where this read task can run faster, but Spark does not guarantee that
-   * this task will always run on these locations. The implementations should make sure that it can
-   * be run on any location. The location is a string representing the host name.
+   * The preferred locations where the data reader returned by this reader factory can run faster,
+   * but Spark does not guarantee to run the data reader on these locations.
+   * The implementations should make sure that it can be run on any location.
+   * The location is a string representing the host name.
    *
    * Note that if a host name cannot be recognized by Spark, it will be ignored as it was not in
    * the returned locations. By default this method returns empty string array, which means this
@@ -50,7 +52,7 @@ public interface ReadTask<T> extends Serializable {
   }
   /**
-   * Returns a data reader to do the actual reading work for this read task.
+   * Returns a data reader to do the actual reading work.
    *
    * If this method fails (by throwing an exception), the corresponding Spark task would fail and
    * get retried until hitting the maximum retry times.

View file

@@ -30,7 +30,8 @@ import org.apache.spark.sql.types.StructType;
  * {@link org.apache.spark.sql.sources.v2.ReadSupportWithSchema#createReader(
  * StructType, org.apache.spark.sql.sources.v2.DataSourceV2Options)}.
  * It can mix in various query optimization interfaces to speed up the data scan. The actual scan
- * logic is delegated to {@link ReadTask}s that are returned by {@link #createReadTasks()}.
+ * logic is delegated to {@link DataReaderFactory}s that are returned by
+ * {@link #createDataReaderFactories()}.
  *
  * There are mainly 3 kinds of query optimizations:
  * 1. Operators push-down. E.g., filter push-down, required columns push-down(aka column
@@ -63,9 +64,9 @@ public interface DataSourceV2Reader {
   StructType readSchema();
   /**
-   * Returns a list of read tasks. Each task is responsible for outputting data for one RDD
-   * partition. That means the number of tasks returned here is same as the number of RDD
-   * partitions this scan outputs.
+   * Returns a list of reader factories. Each factory is responsible for creating a data reader to
+   * output data for one RDD partition. That means the number of factories returned here is same as
+   * the number of RDD partitions this scan outputs.
    *
    * Note that, this may not be a full scan if the data source reader mixes in other optimization
    * interfaces like column pruning, filter push-down, etc. These optimizations are applied before
@@ -74,5 +75,5 @@ public interface DataSourceV2Reader {
    * If this method fails (by throwing an exception), the action would fail and no Spark job was
    * submitted.
    */
-  List<ReadTask<Row>> createReadTasks();
+  List<DataReaderFactory<Row>> createDataReaderFactories();
 }

View file

@@ -21,9 +21,9 @@ import org.apache.spark.annotation.InterfaceStability;
 /**
  * An interface to represent data distribution requirement, which specifies how the records should
- * be distributed among the {@link ReadTask}s that are returned by
- * {@link DataSourceV2Reader#createReadTasks()}. Note that this interface has nothing to do with
- * the data ordering inside one partition(the output records of a single {@link ReadTask}).
+ * be distributed among the data partitions(one {@link DataReader} outputs data for one partition).
+ * Note that this interface has nothing to do with the data ordering inside one
+ * partition(the output records of a single {@link DataReader}).
  *
  * The instance of this interface is created and provided by Spark, then consumed by
  * {@link Partitioning#satisfy(Distribution)}. This means data source developers don't need to

View file

@@ -29,7 +29,7 @@ import org.apache.spark.annotation.InterfaceStability;
 public interface Partitioning {
   /**
-   * Returns the number of partitions(i.e., {@link ReadTask}s) the data source outputs.
+   * Returns the number of partitions(i.e., {@link DataReaderFactory}s) the data source outputs.
    */
   int numPartitions();

View file

@@ -30,21 +30,22 @@ import org.apache.spark.sql.vectorized.ColumnarBatch;
 @InterfaceStability.Evolving
 public interface SupportsScanColumnarBatch extends DataSourceV2Reader {
   @Override
-  default List<ReadTask<Row>> createReadTasks() {
+  default List<DataReaderFactory<Row>> createDataReaderFactories() {
     throw new IllegalStateException(
-      "createReadTasks not supported by default within SupportsScanColumnarBatch.");
+      "createDataReaderFactories not supported by default within SupportsScanColumnarBatch.");
   }
   /**
-   * Similar to {@link DataSourceV2Reader#createReadTasks()}, but returns columnar data in batches.
+   * Similar to {@link DataSourceV2Reader#createDataReaderFactories()}, but returns columnar data
+   * in batches.
    */
-  List<ReadTask<ColumnarBatch>> createBatchReadTasks();
+  List<DataReaderFactory<ColumnarBatch>> createBatchDataReaderFactories();
   /**
    * Returns true if the concrete data source reader can read data in batch according to the scan
    * properties like required columns, pushes filters, etc. It's possible that the implementation
    * can only support some certain columns with certain types. Users can overwrite this method and
-   * {@link #createReadTasks()} to fallback to normal read path under some conditions.
+   * {@link #createDataReaderFactories()} to fallback to normal read path under some conditions.
    */
   default boolean enableBatchRead() {
     return true;

View file

@@ -33,13 +33,14 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
 public interface SupportsScanUnsafeRow extends DataSourceV2Reader {
   @Override
-  default List<ReadTask<Row>> createReadTasks() {
+  default List<DataReaderFactory<Row>> createDataReaderFactories() {
     throw new IllegalStateException(
-      "createReadTasks not supported by default within SupportsScanUnsafeRow");
+      "createDataReaderFactories not supported by default within SupportsScanUnsafeRow");
   }
   /**
-   * Similar to {@link DataSourceV2Reader#createReadTasks()}, but returns data in unsafe row format.
+   * Similar to {@link DataSourceV2Reader#createDataReaderFactories()},
+   * but returns data in unsafe row format.
    */
-  List<ReadTask<UnsafeRow>> createUnsafeRowReadTasks();
+  List<DataReaderFactory<UnsafeRow>> createUnsafeRowReaderFactories();
 }

View file

@@ -36,8 +36,8 @@ public interface MicroBatchReadSupport extends DataSourceV2 {
    * streaming query.
    *
    * The execution engine will create a micro-batch reader at the start of a streaming query,
-   * alternate calls to setOffsetRange and createReadTasks for each batch to process, and then
-   * call stop() when the execution is complete. Note that a single query may have multiple
+   * alternate calls to setOffsetRange and createDataReaderFactories for each batch to process, and
+   * then call stop() when the execution is complete. Note that a single query may have multiple
    * executions due to restart or failure recovery.
    *
    * @param schema the user provided schema, or empty() if none was provided

View file

@@ -27,7 +27,7 @@ import java.util.Optional;
  * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this
  * interface to allow reading in a continuous processing mode stream.
  *
- * Implementations must ensure each read task output is a {@link ContinuousDataReader}.
+ * Implementations must ensure each reader factory output is a {@link ContinuousDataReader}.
  *
  * Note: This class currently extends {@link BaseStreamingSource} to maintain compatibility with
  * DataSource V1 APIs. This extension will be removed once we get rid of V1 completely.
@@ -47,9 +47,9 @@ public interface ContinuousReader extends BaseStreamingSource, DataSourceV2Reade
   Offset deserializeOffset(String json);
   /**
-   * Set the desired start offset for read tasks created from this reader. The scan will start
-   * from the first record after the provided offset, or from an implementation-defined inferred
-   * starting point if no offset is provided.
+   * Set the desired start offset for reader factories created from this reader. The scan will
+   * start from the first record after the provided offset, or from an implementation-defined
+   * inferred starting point if no offset is provided.
    */
   void setOffset(Optional<Offset> start);
@@ -61,9 +61,9 @@ public interface ContinuousReader extends BaseStreamingSource, DataSourceV2Reade
   Offset getStartOffset();
   /**
-   * The execution engine will call this method in every epoch to determine if new read tasks need
-   * to be generated, which may be required if for example the underlying source system has had
-   * partitions added or removed.
+   * The execution engine will call this method in every epoch to determine if new reader
+   * factories need to be generated, which may be required if for example the underlying
+   * source system has had partitions added or removed.
    *
    * If true, the query will be shut down and restarted with a new reader.
    */

View file

@@ -33,9 +33,9 @@ import java.util.Optional;
 @InterfaceStability.Evolving
 public interface MicroBatchReader extends DataSourceV2Reader, BaseStreamingSource {
   /**
-   * Set the desired offset range for read tasks created from this reader. Read tasks will
-   * generate only data within (`start`, `end`]; that is, from the first record after `start` to
-   * the record with offset `end`.
+   * Set the desired offset range for reader factories created from this reader. Reader factories
+   * will generate only data within (`start`, `end`]; that is, from the first record after `start`
+   * to the record with offset `end`.
    *
    * @param start The initial offset to scan from. If not specified, scan from an
    *              implementation-specified start point, such as the earliest available record.

View file

@@ -22,24 +22,24 @@ import scala.reflect.ClassTag
 import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.sources.v2.reader.ReadTask
+import org.apache.spark.sql.sources.v2.reader.DataReaderFactory
-class DataSourceRDDPartition[T : ClassTag](val index: Int, val readTask: ReadTask[T])
+class DataSourceRDDPartition[T : ClassTag](val index: Int, val readerFactory: DataReaderFactory[T])
   extends Partition with Serializable
 class DataSourceRDD[T: ClassTag](
     sc: SparkContext,
-    @transient private val readTasks: java.util.List[ReadTask[T]])
+    @transient private val readerFactories: java.util.List[DataReaderFactory[T]])
   extends RDD[T](sc, Nil) {
   override protected def getPartitions: Array[Partition] = {
-    readTasks.asScala.zipWithIndex.map {
-      case (readTask, index) => new DataSourceRDDPartition(index, readTask)
+    readerFactories.asScala.zipWithIndex.map {
+      case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory)
     }.toArray
   }
   override def compute(split: Partition, context: TaskContext): Iterator[T] = {
-    val reader = split.asInstanceOf[DataSourceRDDPartition[T]].readTask.createDataReader()
+    val reader = split.asInstanceOf[DataSourceRDDPartition[T]].readerFactory.createDataReader()
     context.addTaskCompletionListener(_ => reader.close())
     val iter = new Iterator[T] {
       private[this] var valuePrepared = false
@@ -63,6 +63,6 @@ class DataSourceRDD[T: ClassTag](
   }
   override def getPreferredLocations(split: Partition): Seq[String] = {
-    split.asInstanceOf[DataSourceRDDPartition[T]].readTask.preferredLocations()
+    split.asInstanceOf[DataSourceRDDPartition[T]].readerFactory.preferredLocations()
   }
 }

View file

@@ -51,11 +51,11 @@ case class DataSourceV2ScanExec(
     case _ => super.outputPartitioning
   }
-  private lazy val readTasks: java.util.List[ReadTask[UnsafeRow]] = reader match {
-    case r: SupportsScanUnsafeRow => r.createUnsafeRowReadTasks()
+  private lazy val readerFactories: java.util.List[DataReaderFactory[UnsafeRow]] = reader match {
+    case r: SupportsScanUnsafeRow => r.createUnsafeRowReaderFactories()
     case _ =>
-      reader.createReadTasks().asScala.map {
-        new RowToUnsafeRowReadTask(_, reader.readSchema()): ReadTask[UnsafeRow]
+      reader.createDataReaderFactories().asScala.map {
+        new RowToUnsafeRowDataReaderFactory(_, reader.readSchema()): DataReaderFactory[UnsafeRow]
       }.asJava
   }
@@ -63,18 +63,19 @@ case class DataSourceV2ScanExec(
     case r: SupportsScanColumnarBatch if r.enableBatchRead() =>
       assert(!reader.isInstanceOf[ContinuousReader],
         "continuous stream reader does not support columnar read yet.")
-      new DataSourceRDD(sparkContext, r.createBatchReadTasks()).asInstanceOf[RDD[InternalRow]]
+      new DataSourceRDD(sparkContext, r.createBatchDataReaderFactories())
+        .asInstanceOf[RDD[InternalRow]]
     case _: ContinuousReader =>
       EpochCoordinatorRef.get(
           sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
           sparkContext.env)
-        .askSync[Unit](SetReaderPartitions(readTasks.size()))
-      new ContinuousDataSourceRDD(sparkContext, sqlContext, readTasks)
+        .askSync[Unit](SetReaderPartitions(readerFactories.size()))
+      new ContinuousDataSourceRDD(sparkContext, sqlContext, readerFactories)
         .asInstanceOf[RDD[InternalRow]]
     case _ =>
-      new DataSourceRDD(sparkContext, readTasks).asInstanceOf[RDD[InternalRow]]
+      new DataSourceRDD(sparkContext, readerFactories).asInstanceOf[RDD[InternalRow]]
   }
   override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD)
@@ -99,14 +100,14 @@ case class DataSourceV2ScanExec(
   }
 }
-class RowToUnsafeRowReadTask(rowReadTask: ReadTask[Row], schema: StructType)
-  extends ReadTask[UnsafeRow] {
+class RowToUnsafeRowDataReaderFactory(rowReaderFactory: DataReaderFactory[Row], schema: StructType)
+  extends DataReaderFactory[UnsafeRow] {
-  override def preferredLocations: Array[String] = rowReadTask.preferredLocations
+  override def preferredLocations: Array[String] = rowReaderFactory.preferredLocations
   override def createDataReader: DataReader[UnsafeRow] = {
     new RowToUnsafeDataReader(
-      rowReadTask.createDataReader, RowEncoder.apply(schema).resolveAndBind())
+      rowReaderFactory.createDataReader, RowEncoder.apply(schema).resolveAndBind())
   }
 }

View file

@@ -39,15 +39,15 @@ import org.apache.spark.util.{SystemClock, ThreadUtils}
 class ContinuousDataSourceRDD(
     sc: SparkContext,
     sqlContext: SQLContext,
-    @transient private val readTasks: java.util.List[ReadTask[UnsafeRow]])
+    @transient private val readerFactories: java.util.List[DataReaderFactory[UnsafeRow]])
   extends RDD[UnsafeRow](sc, Nil) {
   private val dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize
   private val epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs
   override protected def getPartitions: Array[Partition] = {
-    readTasks.asScala.zipWithIndex.map {
-      case (readTask, index) => new DataSourceRDDPartition(index, readTask)
+    readerFactories.asScala.zipWithIndex.map {
+      case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory)
     }.toArray
   }
@@ -57,7 +57,8 @@ class ContinuousDataSourceRDD(
       throw new ContinuousTaskRetryException()
     }
-    val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readTask.createDataReader()
+    val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]]
+      .readerFactory.createDataReader()
     val coordinatorId = context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)
@@ -136,7 +137,7 @@ class ContinuousDataSourceRDD(
   }
   override def getPreferredLocations(split: Partition): Seq[String] = {
-    split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readTask.preferredLocations()
+    split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readerFactory.preferredLocations()
   }
 }

View file

@@ -68,7 +68,7 @@ class RateStreamContinuousReader(options: DataSourceV2Options)
   override def getStartOffset(): Offset = offset
-  override def createReadTasks(): java.util.List[ReadTask[Row]] = {
+  override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = {
     val partitionStartMap = offset match {
       case off: RateStreamOffset => off.partitionToValueAndRunTimeMs
       case off =>
@@ -86,13 +86,13 @@ class RateStreamContinuousReader(options: DataSourceV2Options)
       val start = partitionStartMap(i)
       // Have each partition advance by numPartitions each row, with starting points staggered
       // by their partition index.
-      RateStreamContinuousReadTask(
+      RateStreamContinuousDataReaderFactory(
         start.value,
         start.runTimeMs,
         i,
         numPartitions,
         perPartitionRate)
-        .asInstanceOf[ReadTask[Row]]
+        .asInstanceOf[DataReaderFactory[Row]]
     }.asJava
   }
@@ -101,13 +101,13 @@ class RateStreamContinuousReader(options: DataSourceV2Options)
 }
-case class RateStreamContinuousReadTask(
+case class RateStreamContinuousDataReaderFactory(
     startValue: Long,
     startTimeMs: Long,
     partitionIndex: Int,
     increment: Long,
     rowsPerSecond: Double)
-  extends ReadTask[Row] {
+  extends DataReaderFactory[Row] {
   override def createDataReader(): DataReader[Row] =
     new RateStreamContinuousDataReader(
       startValue, startTimeMs, partitionIndex, increment, rowsPerSecond)

View file

@@ -123,7 +123,7 @@ class RateStreamMicroBatchReader(options: DataSourceV2Options)
     RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json))
   }
-  override def createReadTasks(): java.util.List[ReadTask[Row]] = {
+  override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = {
     val startMap = start.partitionToValueAndRunTimeMs
     val endMap = end.partitionToValueAndRunTimeMs
     endMap.keys.toSeq.map { part =>
@@ -139,7 +139,7 @@ class RateStreamMicroBatchReader(options: DataSourceV2Options)
         outTimeMs += msPerPartitionBetweenRows
       }
-      RateStreamBatchTask(packedRows).asInstanceOf[ReadTask[Row]]
+      RateStreamBatchTask(packedRows).asInstanceOf[DataReaderFactory[Row]]
     }.toList.asJava
   }
@@ -147,7 +147,7 @@ class RateStreamMicroBatchReader(options: DataSourceV2Options)
   override def stop(): Unit = {}
 }
-case class RateStreamBatchTask(vals: Seq[(Long, Long)]) extends ReadTask[Row] {
+case class RateStreamBatchTask(vals: Seq[(Long, Long)]) extends DataReaderFactory[Row] {
   override def createDataReader(): DataReader[Row] = new RateStreamBatchReader(vals)
 }

View file

@@ -60,8 +60,8 @@ public class JavaAdvancedDataSourceV2 implements DataSourceV2, ReadSupport {
     }
     @Override
-    public List<ReadTask<Row>> createReadTasks() {
-      List<ReadTask<Row>> res = new ArrayList<>();
+    public List<DataReaderFactory<Row>> createDataReaderFactories() {
+      List<DataReaderFactory<Row>> res = new ArrayList<>();
       Integer lowerBound = null;
       for (Filter filter : filters) {
@@ -75,25 +75,25 @@ public class JavaAdvancedDataSourceV2 implements DataSourceV2, ReadSupport {
       }
       if (lowerBound == null) {
-        res.add(new JavaAdvancedReadTask(0, 5, requiredSchema));
-        res.add(new JavaAdvancedReadTask(5, 10, requiredSchema));
+        res.add(new JavaAdvancedDataReaderFactory(0, 5, requiredSchema));
+        res.add(new JavaAdvancedDataReaderFactory(5, 10, requiredSchema));
       } else if (lowerBound < 4) {
-        res.add(new JavaAdvancedReadTask(lowerBound + 1, 5, requiredSchema));
-        res.add(new JavaAdvancedReadTask(5, 10, requiredSchema));
+        res.add(new JavaAdvancedDataReaderFactory(lowerBound + 1, 5, requiredSchema));
+        res.add(new JavaAdvancedDataReaderFactory(5, 10, requiredSchema));
       } else if (lowerBound < 9) {
-        res.add(new JavaAdvancedReadTask(lowerBound + 1, 10, requiredSchema));
+        res.add(new JavaAdvancedDataReaderFactory(lowerBound + 1, 10, requiredSchema));
       }
       return res;
     }
   }
-  static class JavaAdvancedReadTask implements ReadTask<Row>, DataReader<Row> {
+  static class JavaAdvancedDataReaderFactory implements DataReaderFactory<Row>, DataReader<Row> {
     private int start;
     private int end;
    private StructType requiredSchema;
-    JavaAdvancedReadTask(int start, int end, StructType requiredSchema) {
+    JavaAdvancedDataReaderFactory(int start, int end, StructType requiredSchema) {
      this.start = start;
      this.end = end;
      this.requiredSchema = requiredSchema;
@@ -101,7 +101,7 @@ public class JavaAdvancedDataSourceV2 implements DataSourceV2, ReadSupport {
     @Override
     public DataReader<Row> createDataReader() {
-      return new JavaAdvancedReadTask(start - 1, end, requiredSchema);
+      return new JavaAdvancedDataReaderFactory(start - 1, end, requiredSchema);
     }
     @Override

View file

@@ -42,12 +42,14 @@ public class JavaBatchDataSourceV2 implements DataSourceV2, ReadSupport {
     }
     @Override
-    public List<ReadTask<ColumnarBatch>> createBatchReadTasks() {
-      return java.util.Arrays.asList(new JavaBatchReadTask(0, 50), new JavaBatchReadTask(50, 90));
+    public List<DataReaderFactory<ColumnarBatch>> createBatchDataReaderFactories() {
+      return java.util.Arrays.asList(
+        new JavaBatchDataReaderFactory(0, 50), new JavaBatchDataReaderFactory(50, 90));
     }
   }
-  static class JavaBatchReadTask implements ReadTask<ColumnarBatch>, DataReader<ColumnarBatch> {
+  static class JavaBatchDataReaderFactory
+      implements DataReaderFactory<ColumnarBatch>, DataReader<ColumnarBatch> {
     private int start;
     private int end;
@@ -57,7 +59,7 @@ public class JavaBatchDataSourceV2 implements DataSourceV2, ReadSupport {
     private OnHeapColumnVector j;
     private ColumnarBatch batch;
-    JavaBatchReadTask(int start, int end) {
+    JavaBatchDataReaderFactory(int start, int end) {
       this.start = start;
       this.end = end;
     }

View file

@@ -40,10 +40,10 @@ public class JavaPartitionAwareDataSource implements DataSourceV2, ReadSupport {
     }
     @Override
-    public List<ReadTask<Row>> createReadTasks() {
+    public List<DataReaderFactory<Row>> createDataReaderFactories() {
       return java.util.Arrays.asList(
-        new SpecificReadTask(new int[]{1, 1, 3}, new int[]{4, 4, 6}),
-        new SpecificReadTask(new int[]{2, 4, 4}, new int[]{6, 2, 2}));
+        new SpecificDataReaderFactory(new int[]{1, 1, 3}, new int[]{4, 4, 6}),
+        new SpecificDataReaderFactory(new int[]{2, 4, 4}, new int[]{6, 2, 2}));
     }
     @Override
@@ -70,12 +70,12 @@ public class JavaPartitionAwareDataSource implements DataSourceV2, ReadSupport {
     }
   }
-  static class SpecificReadTask implements ReadTask<Row>, DataReader<Row> {
+  static class SpecificDataReaderFactory implements DataReaderFactory<Row>, DataReader<Row> {
     private int[] i;
     private int[] j;
     private int current = -1;
-    SpecificReadTask(int[] i, int[] j) {
+    SpecificDataReaderFactory(int[] i, int[] j) {
      assert i.length == j.length;
      this.i = i;
      this.j = j;

View file

@@ -24,7 +24,7 @@ import org.apache.spark.sql.sources.v2.DataSourceV2;
 import org.apache.spark.sql.sources.v2.DataSourceV2Options;
 import org.apache.spark.sql.sources.v2.ReadSupportWithSchema;
 import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader;
-import org.apache.spark.sql.sources.v2.reader.ReadTask;
+import org.apache.spark.sql.sources.v2.reader.DataReaderFactory;
 import org.apache.spark.sql.types.StructType;
 public class JavaSchemaRequiredDataSource implements DataSourceV2, ReadSupportWithSchema {
@@ -42,7 +42,7 @@ public class JavaSchemaRequiredDataSource implements DataSourceV2, ReadSupportWi
     }
     @Override
-    public List<ReadTask<Row>> createReadTasks() {
+    public List<DataReaderFactory<Row>> createDataReaderFactories() {
       return java.util.Collections.emptyList();
     }
   }

View file

@@ -26,7 +26,7 @@ import org.apache.spark.sql.sources.v2.DataSourceV2;
 import org.apache.spark.sql.sources.v2.DataSourceV2Options;
 import org.apache.spark.sql.sources.v2.ReadSupport;
 import org.apache.spark.sql.sources.v2.reader.DataReader;
-import org.apache.spark.sql.sources.v2.reader.ReadTask;
+import org.apache.spark.sql.sources.v2.reader.DataReaderFactory;
 import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader;
 import org.apache.spark.sql.types.StructType;
@@ -41,25 +41,25 @@ public class JavaSimpleDataSourceV2 implements DataSourceV2, ReadSupport {
     }
     @Override
-    public List<ReadTask<Row>> createReadTasks() {
+    public List<DataReaderFactory<Row>> createDataReaderFactories() {
       return java.util.Arrays.asList(
-        new JavaSimpleReadTask(0, 5),
-        new JavaSimpleReadTask(5, 10));
+        new JavaSimpleDataReaderFactory(0, 5),
+        new JavaSimpleDataReaderFactory(5, 10));
     }
   }
-  static class JavaSimpleReadTask implements ReadTask<Row>, DataReader<Row> {
+  static class JavaSimpleDataReaderFactory implements DataReaderFactory<Row>, DataReader<Row> {
     private int start;
     private int end;
-    JavaSimpleReadTask(int start, int end) {
+    JavaSimpleDataReaderFactory(int start, int end) {
       this.start = start;
       this.end = end;
     }
     @Override
     public DataReader<Row> createDataReader() {
-      return new JavaSimpleReadTask(start - 1, end);
+      return new JavaSimpleDataReaderFactory(start - 1, end);
     }
     @Override

View file

@@ -38,19 +38,20 @@ public class JavaUnsafeRowDataSourceV2 implements DataSourceV2, ReadSupport {
     }
     @Override
-    public List<ReadTask<UnsafeRow>> createUnsafeRowReadTasks() {
+    public List<DataReaderFactory<UnsafeRow>> createUnsafeRowReaderFactories() {
       return java.util.Arrays.asList(
-        new JavaUnsafeRowReadTask(0, 5),
-        new JavaUnsafeRowReadTask(5, 10));
+        new JavaUnsafeRowDataReaderFactory(0, 5),
+        new JavaUnsafeRowDataReaderFactory(5, 10));
     }
   }
-  static class JavaUnsafeRowReadTask implements ReadTask<UnsafeRow>, DataReader<UnsafeRow> {
+  static class JavaUnsafeRowDataReaderFactory
+      implements DataReaderFactory<UnsafeRow>, DataReader<UnsafeRow> {
     private int start;
     private int end;
     private UnsafeRow row;
-    JavaUnsafeRowReadTask(int start, int end) {
+    JavaUnsafeRowDataReaderFactory(int start, int end) {
       this.start = start;
       this.end = end;
       this.row = new UnsafeRow(2);
@@ -59,7 +60,7 @@ public class JavaUnsafeRowDataSourceV2 implements DataSourceV2, ReadSupport {
     @Override
     public DataReader<UnsafeRow> createDataReader() {
-      return new JavaUnsafeRowReadTask(start - 1, end);
+      return new JavaUnsafeRowDataReaderFactory(start - 1, end);
     }
     @Override

View file

@@ -78,7 +78,7 @@ class RateSourceV2Suite extends StreamTest {
     val reader = new RateStreamMicroBatchReader(
       new DataSourceV2Options(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava))
     reader.setOffsetRange(Optional.empty(), Optional.empty())
-    val tasks = reader.createReadTasks()
+    val tasks = reader.createDataReaderFactories()
     assert(tasks.size == 11)
   }
@@ -118,7 +118,7 @@ class RateSourceV2Suite extends StreamTest {
     val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000))))
     val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(20, 2000))))
     reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset))
-    val tasks = reader.createReadTasks()
+    val tasks = reader.createDataReaderFactories()
     assert(tasks.size == 1)
     assert(tasks.get(0).asInstanceOf[RateStreamBatchTask].vals.size == 20)
   }
@@ -133,7 +133,7 @@ class RateSourceV2Suite extends StreamTest {
     }.toMap)
     reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset))
-    val tasks = reader.createReadTasks()
+    val tasks = reader.createDataReaderFactories()
     assert(tasks.size == 11)
     val readData = tasks.asScala
@@ -161,12 +161,12 @@ class RateSourceV2Suite extends StreamTest {
     val reader = new RateStreamContinuousReader(
       new DataSourceV2Options(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava))
     reader.setOffset(Optional.empty())
-    val tasks = reader.createReadTasks()
+    val tasks = reader.createDataReaderFactories()
    assert(tasks.size == 2)
    val data = scala.collection.mutable.ListBuffer[Row]()
    tasks.asScala.foreach {
-      case t: RateStreamContinuousReadTask =>
+      case t: RateStreamContinuousDataReaderFactory =>
        val startTimeMs = reader.getStartOffset()
          .asInstanceOf[RateStreamOffset]
          .partitionToValueAndRunTimeMs(t.partitionIndex)

View file

@@ -204,18 +204,20 @@ class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport {
   class Reader extends DataSourceV2Reader {
     override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int")
-    override def createReadTasks(): JList[ReadTask[Row]] = {
-      java.util.Arrays.asList(new SimpleReadTask(0, 5), new SimpleReadTask(5, 10))
+    override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = {
+      java.util.Arrays.asList(new SimpleDataReaderFactory(0, 5), new SimpleDataReaderFactory(5, 10))
     }
   }
   override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader
 }
-class SimpleReadTask(start: Int, end: Int) extends ReadTask[Row] with DataReader[Row] {
+class SimpleDataReaderFactory(start: Int, end: Int)
+  extends DataReaderFactory[Row]
+  with DataReader[Row] {
   private var current = start - 1
-  override def createDataReader(): DataReader[Row] = new SimpleReadTask(start, end)
+  override def createDataReader(): DataReader[Row] = new SimpleDataReaderFactory(start, end)
   override def next(): Boolean = {
     current += 1
@@ -252,21 +254,21 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport {
       requiredSchema
     }
-    override def createReadTasks(): JList[ReadTask[Row]] = {
+    override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = {
       val lowerBound = filters.collect {
         case GreaterThan("i", v: Int) => v
       }.headOption
-      val res = new ArrayList[ReadTask[Row]]
+      val res = new ArrayList[DataReaderFactory[Row]]
       if (lowerBound.isEmpty) {
-        res.add(new AdvancedReadTask(0, 5, requiredSchema))
-        res.add(new AdvancedReadTask(5, 10, requiredSchema))
+        res.add(new AdvancedDataReaderFactory(0, 5, requiredSchema))
+        res.add(new AdvancedDataReaderFactory(5, 10, requiredSchema))
       } else if (lowerBound.get < 4) {
-        res.add(new AdvancedReadTask(lowerBound.get + 1, 5, requiredSchema))
-        res.add(new AdvancedReadTask(5, 10, requiredSchema))
+        res.add(new AdvancedDataReaderFactory(lowerBound.get + 1, 5, requiredSchema))
+        res.add(new AdvancedDataReaderFactory(5, 10, requiredSchema))
       } else if (lowerBound.get < 9) {
-        res.add(new AdvancedReadTask(lowerBound.get + 1, 10, requiredSchema))
+        res.add(new AdvancedDataReaderFactory(lowerBound.get + 1, 10, requiredSchema))
       }
       res
@@ -276,13 +278,13 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport {
   override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader
 }
-class AdvancedReadTask(start: Int, end: Int, requiredSchema: StructType)
-  extends ReadTask[Row] with DataReader[Row] {
+class AdvancedDataReaderFactory(start: Int, end: Int, requiredSchema: StructType)
+  extends DataReaderFactory[Row] with DataReader[Row] {
   private var current = start - 1
   override def createDataReader(): DataReader[Row] = {
-    new AdvancedReadTask(start, end, requiredSchema)
+    new AdvancedDataReaderFactory(start, end, requiredSchema)
   }
   override def close(): Unit = {}
@@ -307,16 +309,17 @@ class UnsafeRowDataSourceV2 extends DataSourceV2 with ReadSupport {
   class Reader extends DataSourceV2Reader with SupportsScanUnsafeRow {
     override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int")
-    override def createUnsafeRowReadTasks(): JList[ReadTask[UnsafeRow]] = {
-      java.util.Arrays.asList(new UnsafeRowReadTask(0, 5), new UnsafeRowReadTask(5, 10))
+    override def createUnsafeRowReaderFactories(): JList[DataReaderFactory[UnsafeRow]] = {
+      java.util.Arrays.asList(new UnsafeRowDataReaderFactory(0, 5),
+        new UnsafeRowDataReaderFactory(5, 10))
     }
   }
   override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader
 }
-class UnsafeRowReadTask(start: Int, end: Int)
-  extends ReadTask[UnsafeRow] with DataReader[UnsafeRow] {
+class UnsafeRowDataReaderFactory(start: Int, end: Int)
+  extends DataReaderFactory[UnsafeRow] with DataReader[UnsafeRow] {
   private val row = new UnsafeRow(2)
   row.pointTo(new Array[Byte](8 * 3), 8 * 3)
@@ -341,7 +344,7 @@ class UnsafeRowReadTask(start: Int, end: Int)
 class SchemaRequiredDataSource extends DataSourceV2 with ReadSupportWithSchema {
   class Reader(val readSchema: StructType) extends DataSourceV2Reader {
-    override def createReadTasks(): JList[ReadTask[Row]] =
+    override def createDataReaderFactories(): JList[DataReaderFactory[Row]] =
       java.util.Collections.emptyList()
   }
@@ -354,16 +357,16 @@ class BatchDataSourceV2 extends DataSourceV2 with ReadSupport {
   class Reader extends DataSourceV2Reader with SupportsScanColumnarBatch {
     override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int")
-    override def createBatchReadTasks(): JList[ReadTask[ColumnarBatch]] = {
-      java.util.Arrays.asList(new BatchReadTask(0, 50), new BatchReadTask(50, 90))
+    override def createBatchDataReaderFactories(): JList[DataReaderFactory[ColumnarBatch]] = {
+      java.util.Arrays.asList(new BatchDataReaderFactory(0, 50), new BatchDataReaderFactory(50, 90))
     }
   }
   override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader
 }
-class BatchReadTask(start: Int, end: Int)
-  extends ReadTask[ColumnarBatch] with DataReader[ColumnarBatch] {
+class BatchDataReaderFactory(start: Int, end: Int)
+  extends DataReaderFactory[ColumnarBatch] with DataReader[ColumnarBatch] {
   private final val BATCH_SIZE = 20
   private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType)
@@ -406,11 +409,11 @@ class PartitionAwareDataSource extends DataSourceV2 with ReadSupport {
   class Reader extends DataSourceV2Reader with SupportsReportPartitioning {
     override def readSchema(): StructType = new StructType().add("a", "int").add("b", "int")
-    override def createReadTasks(): JList[ReadTask[Row]] = {
+    override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = {
       // Note that we don't have same value of column `a` across partitions.
       java.util.Arrays.asList(
-        new SpecificReadTask(Array(1, 1, 3), Array(4, 4, 6)),
-        new SpecificReadTask(Array(2, 4, 4), Array(6, 2, 2)))
+        new SpecificDataReaderFactory(Array(1, 1, 3), Array(4, 4, 6)),
+        new SpecificDataReaderFactory(Array(2, 4, 4), Array(6, 2, 2)))
     }
     override def outputPartitioning(): Partitioning = new MyPartitioning
@@ -428,7 +431,9 @@ class PartitionAwareDataSource extends DataSourceV2 with ReadSupport {
   override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader
 }
-class SpecificReadTask(i: Array[Int], j: Array[Int]) extends ReadTask[Row] with DataReader[Row] {
+class SpecificDataReaderFactory(i: Array[Int], j: Array[Int])
+  extends DataReaderFactory[Row]
+  with DataReader[Row] {
   assert(i.length == j.length)
   private var current = -1

View file

@@ -28,7 +28,7 @@ import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path}
 import org.apache.spark.SparkContext
 import org.apache.spark.sql.{Row, SaveMode}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.sources.v2.reader.{DataReader, DataSourceV2Reader, ReadTask}
+import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, DataSourceV2Reader}
 import org.apache.spark.sql.sources.v2.writer._
 import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.util.SerializableConfiguration
@@ -45,7 +45,7 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS
   class Reader(path: String, conf: Configuration) extends DataSourceV2Reader {
     override def readSchema(): StructType = schema
-    override def createReadTasks(): JList[ReadTask[Row]] = {
+    override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = {
       val dataPath = new Path(path)
       val fs = dataPath.getFileSystem(conf)
       if (fs.exists(dataPath)) {
@@ -54,7 +54,9 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS
           name.startsWith("_") || name.startsWith(".")
         }.map { f =>
           val serializableConf = new SerializableConfiguration(conf)
-          new SimpleCSVReadTask(f.getPath.toUri.toString, serializableConf): ReadTask[Row]
+          new SimpleCSVDataReaderFactory(
+            f.getPath.toUri.toString,
+            serializableConf): DataReaderFactory[Row]
         }.toList.asJava
       } else {
         Collections.emptyList()
@@ -149,8 +151,8 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS
   }
 }
-class SimpleCSVReadTask(path: String, conf: SerializableConfiguration)
-  extends ReadTask[Row] with DataReader[Row] {
+class SimpleCSVDataReaderFactory(path: String, conf: SerializableConfiguration)
+  extends DataReaderFactory[Row] with DataReader[Row] {
   @transient private var lines: Iterator[String] = _
   @transient private var currentLine: String = _

View file

@@ -26,7 +26,7 @@ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider}
 import org.apache.spark.sql.sources.v2.DataSourceV2Options
-import org.apache.spark.sql.sources.v2.reader.ReadTask
+import org.apache.spark.sql.sources.v2.reader.DataReaderFactory
 import org.apache.spark.sql.sources.v2.streaming._
 import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset}
 import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter
@@ -45,7 +45,7 @@ case class FakeReader() extends MicroBatchReader with ContinuousReader {
   def mergeOffsets(offsets: Array[PartitionOffset]): Offset = RateStreamOffset(Map())
   def setOffset(start: Optional[Offset]): Unit = {}
-  def createReadTasks(): java.util.ArrayList[ReadTask[Row]] = {
+  def createDataReaderFactories(): java.util.ArrayList[DataReaderFactory[Row]] = {
     throw new IllegalStateException("fake source - cannot actually read")
   }
 }
} }