Roll forward "[SPARK-23096][SS] Migrate rate source to V2"

## What changes were proposed in this pull request?

Roll forward c68ec4e (#20688).

There are two minor test changes required:

* An error that used to be wrapped as TreeNodeException[ArithmeticException] is no longer wrapped and now surfaces as a plain ArithmeticException.
* The test framework simply does not set the active Spark session. (Or rather, it doesn't do so early enough; I think it only happens when a query is analyzed.) I've added the required logic to SQLTestUtils; a rough sketch of the idea is shown below.
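
A minimal sketch of the kind of session setup this refers to, assuming a ScalaTest-style harness. The trait and member names here are illustrative rather than the actual SQLTestUtils change; only `SparkSession.setActiveSession` is the real Spark API.

```scala
import org.apache.spark.sql.SparkSession
import org.scalatest.{BeforeAndAfterAll, Suite}

// Illustrative helper: eagerly mark the shared test session as the active one,
// instead of relying on query analysis to set it later.
trait EagerActiveSession extends BeforeAndAfterAll { self: Suite =>
  def spark: SparkSession

  override def beforeAll(): Unit = {
    super.beforeAll()
    SparkSession.setActiveSession(spark)
  }
}
```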

## How was this patch tested?

Existing tests.

Author: Jose Torres <torres.joseph.f+github@gmail.com>
Author: jerryshao <sshao@hortonworks.com>

Closes #20922 from jose-torres/ratefix.
Jose Torres 2018-03-30 21:54:26 +08:00 committed by Wenchen Fan
parent b02e76cbff
commit 5b5a36ed6d
9 changed files with 524 additions and 663 deletions

View file

@@ -5,6 +5,5 @@ org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
org.apache.spark.sql.execution.datasources.text.TextFileFormat
org.apache.spark.sql.execution.streaming.ConsoleSinkProvider
org.apache.spark.sql.execution.streaming.RateSourceProvider
org.apache.spark.sql.execution.streaming.sources.RateStreamProvider
org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider
org.apache.spark.sql.execution.streaming.sources.RateSourceProviderV2

View file

@@ -41,7 +41,7 @@ import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider
import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, TextSocketSourceProvider}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources._
import org.apache.spark.sql.streaming.OutputMode
@@ -566,6 +566,7 @@ object DataSource extends Logging {
val orc = "org.apache.spark.sql.hive.orc.OrcFileFormat"
val nativeOrc = classOf[OrcFileFormat].getCanonicalName
val socket = classOf[TextSocketSourceProvider].getCanonicalName
val rate = classOf[RateStreamProvider].getCanonicalName
Map(
"org.apache.spark.sql.jdbc" -> jdbc,
@@ -587,7 +588,8 @@
"org.apache.spark.ml.source.libsvm.DefaultSource" -> libsvm,
"org.apache.spark.ml.source.libsvm" -> libsvm,
"com.databricks.spark.csv" -> csv,
"org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket
"org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket,
"org.apache.spark.sql.execution.streaming.RateSourceProvider" -> rate
)
}
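
The practical effect of the extra map entry is that streams written against the old v1 class name keep resolving to the new provider, matching the "compatible with old path in registry" test below. A quick illustration, assuming an active SparkSession named `spark` (not part of this patch):

```scala
// Both of these now resolve to RateStreamProvider.
val byShortName = spark.readStream.format("rate").load()
val byOldClassName = spark.readStream
  .format("org.apache.spark.sql.execution.streaming.RateSourceProvider")
  .load()
```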

View file

@@ -1,262 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.streaming
import java.io._
import java.nio.charset.StandardCharsets
import java.util.Optional
import java.util.concurrent.TimeUnit
import org.apache.commons.io.IOUtils
import org.apache.spark.internal.Logging
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider}
import org.apache.spark.sql.sources.v2._
import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader
import org.apache.spark.sql.types._
import org.apache.spark.util.{ManualClock, SystemClock}
/**
* A source that generates increment long values with timestamps. Each generated row has two
* columns: a timestamp column for the generated time and an auto increment long column starting
* with 0L.
*
* This source supports the following options:
* - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second.
* - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed
* becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer
* seconds.
* - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the
* generated rows. The source will try its best to reach `rowsPerSecond`, but the query may
* be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed.
*/
class RateSourceProvider extends StreamSourceProvider with DataSourceRegister
with DataSourceV2 with ContinuousReadSupport {
override def sourceSchema(
sqlContext: SQLContext,
schema: Option[StructType],
providerName: String,
parameters: Map[String, String]): (String, StructType) = {
if (schema.nonEmpty) {
throw new AnalysisException("The rate source does not support a user-specified schema.")
}
(shortName(), RateSourceProvider.SCHEMA)
}
override def createSource(
sqlContext: SQLContext,
metadataPath: String,
schema: Option[StructType],
providerName: String,
parameters: Map[String, String]): Source = {
val params = CaseInsensitiveMap(parameters)
val rowsPerSecond = params.get("rowsPerSecond").map(_.toLong).getOrElse(1L)
if (rowsPerSecond <= 0) {
throw new IllegalArgumentException(
s"Invalid value '${params("rowsPerSecond")}'. The option 'rowsPerSecond' " +
"must be positive")
}
val rampUpTimeSeconds =
params.get("rampUpTime").map(JavaUtils.timeStringAsSec(_)).getOrElse(0L)
if (rampUpTimeSeconds < 0) {
throw new IllegalArgumentException(
s"Invalid value '${params("rampUpTime")}'. The option 'rampUpTime' " +
"must not be negative")
}
val numPartitions = params.get("numPartitions").map(_.toInt).getOrElse(
sqlContext.sparkContext.defaultParallelism)
if (numPartitions <= 0) {
throw new IllegalArgumentException(
s"Invalid value '${params("numPartitions")}'. The option 'numPartitions' " +
"must be positive")
}
new RateStreamSource(
sqlContext,
metadataPath,
rowsPerSecond,
rampUpTimeSeconds,
numPartitions,
params.get("useManualClock").map(_.toBoolean).getOrElse(false) // Only for testing
)
}
override def createContinuousReader(
schema: Optional[StructType],
checkpointLocation: String,
options: DataSourceOptions): ContinuousReader = {
new RateStreamContinuousReader(options)
}
override def shortName(): String = "rate"
}
object RateSourceProvider {
val SCHEMA =
StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil)
val VERSION = 1
}
class RateStreamSource(
sqlContext: SQLContext,
metadataPath: String,
rowsPerSecond: Long,
rampUpTimeSeconds: Long,
numPartitions: Int,
useManualClock: Boolean) extends Source with Logging {
import RateSourceProvider._
import RateStreamSource._
val clock = if (useManualClock) new ManualClock else new SystemClock
private val maxSeconds = Long.MaxValue / rowsPerSecond
if (rampUpTimeSeconds > maxSeconds) {
throw new ArithmeticException(
s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" +
s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.")
}
private val startTimeMs = {
val metadataLog =
new HDFSMetadataLog[LongOffset](sqlContext.sparkSession, metadataPath) {
override def serialize(metadata: LongOffset, out: OutputStream): Unit = {
val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8))
writer.write("v" + VERSION + "\n")
writer.write(metadata.json)
writer.flush
}
override def deserialize(in: InputStream): LongOffset = {
val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8))
// HDFSMetadataLog guarantees that it never creates a partial file.
assert(content.length != 0)
if (content(0) == 'v') {
val indexOfNewLine = content.indexOf("\n")
if (indexOfNewLine > 0) {
val version = parseVersion(content.substring(0, indexOfNewLine), VERSION)
LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1)))
} else {
throw new IllegalStateException(
s"Log file was malformed: failed to detect the log file version line.")
}
} else {
throw new IllegalStateException(
s"Log file was malformed: failed to detect the log file version line.")
}
}
}
metadataLog.get(0).getOrElse {
val offset = LongOffset(clock.getTimeMillis())
metadataLog.add(0, offset)
logInfo(s"Start time: $offset")
offset
}.offset
}
/** When the system time runs backward, "lastTimeMs" will make sure we are still monotonic. */
@volatile private var lastTimeMs = startTimeMs
override def schema: StructType = RateSourceProvider.SCHEMA
override def getOffset: Option[Offset] = {
val now = clock.getTimeMillis()
if (lastTimeMs < now) {
lastTimeMs = now
}
Some(LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - startTimeMs)))
}
override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
val startSeconds = start.flatMap(LongOffset.convert(_).map(_.offset)).getOrElse(0L)
val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L)
assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)")
if (endSeconds > maxSeconds) {
throw new ArithmeticException("Integer overflow. Max offset with " +
s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.")
}
// Fix "lastTimeMs" for recovery
if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs) {
lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs
}
val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds)
val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds)
logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " +
s"rangeStart: $rangeStart, rangeEnd: $rangeEnd")
if (rangeStart == rangeEnd) {
return sqlContext.internalCreateDataFrame(
sqlContext.sparkContext.emptyRDD, schema, isStreaming = true)
}
val localStartTimeMs = startTimeMs + TimeUnit.SECONDS.toMillis(startSeconds)
val relativeMsPerValue =
TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart)
val rdd = sqlContext.sparkContext.range(rangeStart, rangeEnd, 1, numPartitions).map { v =>
val relative = math.round((v - rangeStart) * relativeMsPerValue)
InternalRow(DateTimeUtils.fromMillis(relative + localStartTimeMs), v)
}
sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true)
}
override def stop(): Unit = {}
override def toString: String = s"RateSource[rowsPerSecond=$rowsPerSecond, " +
s"rampUpTimeSeconds=$rampUpTimeSeconds, numPartitions=$numPartitions]"
}
object RateStreamSource {
/** Calculate the end value we will emit at the time `seconds`. */
def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = {
// E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10
// Then speedDeltaPerSecond = 2
//
// seconds = 0 1 2 3 4 5 6
// speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds)
// end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2
val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1)
if (seconds <= rampUpTimeSeconds) {
// Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to
// avoid overflow
if (seconds % 2 == 1) {
(seconds + 1) / 2 * speedDeltaPerSecond * seconds
} else {
seconds / 2 * speedDeltaPerSecond * (seconds + 1)
}
} else {
// rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds
val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds)
rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond
}
}
}
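
A quick sanity check of the ramp-up formula against the worked table in the comment above (rowsPerSecond = 10, rampUpTimeSeconds = 4). This is an illustrative snippet only; after this commit the same function lives in RateStreamProvider.

```scala
// Expected end values from the comment: seconds 0..6 -> 0, 2, 6, 12, 20, 30, 40
val expected = Seq(0L, 2L, 6L, 12L, 20L, 30L, 40L)
val actual = (0L to 6L).map { s =>
  RateStreamSource.valueAtSecond(s, rowsPerSecond = 10, rampUpTimeSeconds = 4)
}
assert(actual == expected)
```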

View file

@@ -24,8 +24,8 @@ import org.json4s.jackson.Serialization
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.streaming.{RateSourceProvider, RateStreamOffset, ValueRunTimeMsPair}
import org.apache.spark.sql.execution.streaming.sources.RateStreamSourceV2
import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair}
import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider
import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset}
@@ -40,8 +40,8 @@ class RateStreamContinuousReader(options: DataSourceOptions)
val creationTime = System.currentTimeMillis()
val numPartitions = options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt
val rowsPerSecond = options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong
val numPartitions = options.get(RateStreamProvider.NUM_PARTITIONS).orElse("5").toInt
val rowsPerSecond = options.get(RateStreamProvider.ROWS_PER_SECOND).orElse("6").toLong
val perPartitionRate = rowsPerSecond.toDouble / numPartitions.toDouble
override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = {
@@ -57,12 +57,12 @@ class RateStreamContinuousReader(options: DataSourceOptions)
RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json))
}
override def readSchema(): StructType = RateSourceProvider.SCHEMA
override def readSchema(): StructType = RateStreamProvider.SCHEMA
private var offset: Offset = _
override def setStartOffset(offset: java.util.Optional[Offset]): Unit = {
this.offset = offset.orElse(RateStreamSourceV2.createInitialOffset(numPartitions, creationTime))
this.offset = offset.orElse(createInitialOffset(numPartitions, creationTime))
}
override def getStartOffset(): Offset = offset
@@ -98,6 +98,19 @@ class RateStreamContinuousReader(options: DataSourceOptions)
override def commit(end: Offset): Unit = {}
override def stop(): Unit = {}
private def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = {
RateStreamOffset(
Range(0, numPartitions).map { i =>
// Note that the starting offset is exclusive, so we have to decrement the starting value
// by the increment that will later be applied. The first row output in each
// partition will have a value equal to the partition index.
(i,
ValueRunTimeMsPair(
(i - numPartitions).toLong,
creationTimeMs))
}.toMap)
}
}
case class RateStreamContinuousDataReaderFactory(
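
To make the exclusive-start convention in createInitialOffset concrete: with two partitions, the initial values are negative, so the first row each partition actually emits carries a value equal to its partition index. A standalone sketch of that arithmetic (not part of the patch):

```scala
// numPartitions = 2: initial (exclusive) start values are -2 and -1.
val numPartitions = 2
val initialValues = (0 until numPartitions).map(i => i -> (i - numPartitions).toLong).toMap
// initialValues == Map(0 -> -2L, 1 -> -1L)

// Each value is later advanced by numPartitions before being emitted,
// so partition 0 starts at 0 and partition 1 starts at 1.
val firstEmitted = initialValues.map { case (part, v) => part -> (v + numPartitions) }
// firstEmitted == Map(0 -> 0L, 1 -> 1L)
```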

View file

@@ -0,0 +1,222 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.streaming.sources
import java.io._
import java.nio.charset.StandardCharsets
import java.util.Optional
import java.util.concurrent.TimeUnit
import scala.collection.JavaConverters._
import org.apache.commons.io.IOUtils
import org.apache.spark.internal.Logging
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.{ManualClock, SystemClock}
class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: String)
extends MicroBatchReader with Logging {
import RateStreamProvider._
private[sources] val clock = {
// The option to use a manual clock is provided only for unit testing purposes.
if (options.getBoolean("useManualClock", false)) new ManualClock else new SystemClock
}
private val rowsPerSecond =
options.get(ROWS_PER_SECOND).orElse("1").toLong
private val rampUpTimeSeconds =
Option(options.get(RAMP_UP_TIME).orElse(null.asInstanceOf[String]))
.map(JavaUtils.timeStringAsSec(_))
.getOrElse(0L)
private val maxSeconds = Long.MaxValue / rowsPerSecond
if (rampUpTimeSeconds > maxSeconds) {
throw new ArithmeticException(
s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" +
s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.")
}
private[sources] val creationTimeMs = {
val session = SparkSession.getActiveSession.orElse(SparkSession.getDefaultSession)
require(session.isDefined)
val metadataLog =
new HDFSMetadataLog[LongOffset](session.get, checkpointLocation) {
override def serialize(metadata: LongOffset, out: OutputStream): Unit = {
val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8))
writer.write("v" + VERSION + "\n")
writer.write(metadata.json)
writer.flush
}
override def deserialize(in: InputStream): LongOffset = {
val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8))
// HDFSMetadataLog guarantees that it never creates a partial file.
assert(content.length != 0)
if (content(0) == 'v') {
val indexOfNewLine = content.indexOf("\n")
if (indexOfNewLine > 0) {
parseVersion(content.substring(0, indexOfNewLine), VERSION)
LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1)))
} else {
throw new IllegalStateException(
s"Log file was malformed: failed to detect the log file version line.")
}
} else {
throw new IllegalStateException(
s"Log file was malformed: failed to detect the log file version line.")
}
}
}
metadataLog.get(0).getOrElse {
val offset = LongOffset(clock.getTimeMillis())
metadataLog.add(0, offset)
logInfo(s"Start time: $offset")
offset
}.offset
}
@volatile private var lastTimeMs: Long = creationTimeMs
private var start: LongOffset = _
private var end: LongOffset = _
override def readSchema(): StructType = SCHEMA
override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = {
this.start = start.orElse(LongOffset(0L)).asInstanceOf[LongOffset]
this.end = end.orElse {
val now = clock.getTimeMillis()
if (lastTimeMs < now) {
lastTimeMs = now
}
LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - creationTimeMs))
}.asInstanceOf[LongOffset]
}
override def getStartOffset(): Offset = {
if (start == null) throw new IllegalStateException("start offset not set")
start
}
override def getEndOffset(): Offset = {
if (end == null) throw new IllegalStateException("end offset not set")
end
}
override def deserializeOffset(json: String): Offset = {
LongOffset(json.toLong)
}
override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = {
val startSeconds = LongOffset.convert(start).map(_.offset).getOrElse(0L)
val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L)
assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)")
if (endSeconds > maxSeconds) {
throw new ArithmeticException("Integer overflow. Max offset with " +
s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.")
}
// Fix "lastTimeMs" for recovery
if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs) {
lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs
}
val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds)
val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds)
logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " +
s"rangeStart: $rangeStart, rangeEnd: $rangeEnd")
if (rangeStart == rangeEnd) {
return List.empty.asJava
}
val localStartTimeMs = creationTimeMs + TimeUnit.SECONDS.toMillis(startSeconds)
val relativeMsPerValue =
TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart)
val numPartitions = {
val activeSession = SparkSession.getActiveSession
require(activeSession.isDefined)
Option(options.get(NUM_PARTITIONS).orElse(null.asInstanceOf[String]))
.map(_.toInt)
.getOrElse(activeSession.get.sparkContext.defaultParallelism)
}
(0 until numPartitions).map { p =>
new RateStreamMicroBatchDataReaderFactory(
p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue)
: DataReaderFactory[Row]
}.toList.asJava
}
override def commit(end: Offset): Unit = {}
override def stop(): Unit = {}
override def toString: String = s"MicroBatchRateSource[rowsPerSecond=$rowsPerSecond, " +
s"rampUpTimeSeconds=$rampUpTimeSeconds, " +
s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}"
}
class RateStreamMicroBatchDataReaderFactory(
partitionId: Int,
numPartitions: Int,
rangeStart: Long,
rangeEnd: Long,
localStartTimeMs: Long,
relativeMsPerValue: Double) extends DataReaderFactory[Row] {
override def createDataReader(): DataReader[Row] = new RateStreamMicroBatchDataReader(
partitionId, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue)
}
class RateStreamMicroBatchDataReader(
partitionId: Int,
numPartitions: Int,
rangeStart: Long,
rangeEnd: Long,
localStartTimeMs: Long,
relativeMsPerValue: Double) extends DataReader[Row] {
private var count = 0
override def next(): Boolean = {
rangeStart + partitionId + numPartitions * count < rangeEnd
}
override def get(): Row = {
val currValue = rangeStart + partitionId + numPartitions * count
count += 1
val relative = math.round((currValue - rangeStart) * relativeMsPerValue)
Row(
DateTimeUtils.toJavaTimestamp(
DateTimeUtils.fromMillis(relative + localStartTimeMs)),
currValue
)
}
override def close(): Unit = {}
}
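
The reader above stripes the value range across partitions: partition p out of numPartitions produces rangeStart + p, then steps by numPartitions until rangeEnd (exclusive). A small standalone sketch mirroring the next()/get() logic (not part of the patch):

```scala
// Values produced by one partition of a micro-batch range.
def partitionValues(partitionId: Int, numPartitions: Int, rangeStart: Long, rangeEnd: Long): Seq[Long] =
  Iterator.from(0)
    .map(count => rangeStart + partitionId + numPartitions * count.toLong)
    .takeWhile(_ < rangeEnd)
    .toSeq

// e.g. rangeStart = 0, rangeEnd = 10, two partitions:
// partitionValues(0, 2, 0L, 10L) == Seq(0L, 2L, 4L, 6L, 8L)
// partitionValues(1, 2, 0L, 10L) == Seq(1L, 3L, 5L, 7L, 9L)
```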

View file

@@ -0,0 +1,125 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.streaming.sources
import java.util.Optional
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.sources.v2._
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader}
import org.apache.spark.sql.types._
/**
* A source that generates increment long values with timestamps. Each generated row has two
* columns: a timestamp column for the generated time and an auto increment long column starting
* with 0L.
*
* This source supports the following options:
* - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second.
* - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed
* becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer
* seconds.
* - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the
* generated rows. The source will try its best to reach `rowsPerSecond`, but the query may
* be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed.
*/
class RateStreamProvider extends DataSourceV2
with MicroBatchReadSupport with ContinuousReadSupport with DataSourceRegister {
import RateStreamProvider._
override def createMicroBatchReader(
schema: Optional[StructType],
checkpointLocation: String,
options: DataSourceOptions): MicroBatchReader = {
if (options.get(ROWS_PER_SECOND).isPresent) {
val rowsPerSecond = options.get(ROWS_PER_SECOND).get().toLong
if (rowsPerSecond <= 0) {
throw new IllegalArgumentException(
s"Invalid value '$rowsPerSecond'. The option 'rowsPerSecond' must be positive")
}
}
if (options.get(RAMP_UP_TIME).isPresent) {
val rampUpTimeSeconds =
JavaUtils.timeStringAsSec(options.get(RAMP_UP_TIME).get())
if (rampUpTimeSeconds < 0) {
throw new IllegalArgumentException(
s"Invalid value '$rampUpTimeSeconds'. The option 'rampUpTime' must not be negative")
}
}
if (options.get(NUM_PARTITIONS).isPresent) {
val numPartitions = options.get(NUM_PARTITIONS).get().toInt
if (numPartitions <= 0) {
throw new IllegalArgumentException(
s"Invalid value '$numPartitions'. The option 'numPartitions' must be positive")
}
}
if (schema.isPresent) {
throw new AnalysisException("The rate source does not support a user-specified schema.")
}
new RateStreamMicroBatchReader(options, checkpointLocation)
}
override def createContinuousReader(
schema: Optional[StructType],
checkpointLocation: String,
options: DataSourceOptions): ContinuousReader = new RateStreamContinuousReader(options)
override def shortName(): String = "rate"
}
object RateStreamProvider {
val SCHEMA =
StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil)
val VERSION = 1
val NUM_PARTITIONS = "numPartitions"
val ROWS_PER_SECOND = "rowsPerSecond"
val RAMP_UP_TIME = "rampUpTime"
/** Calculate the end value we will emit at the time `seconds`. */
def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = {
// E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10
// Then speedDeltaPerSecond = 2
//
// seconds = 0 1 2 3 4 5 6
// speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds)
// end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2
val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1)
if (seconds <= rampUpTimeSeconds) {
// Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to
// avoid overflow
if (seconds % 2 == 1) {
(seconds + 1) / 2 * speedDeltaPerSecond * seconds
} else {
seconds / 2 * speedDeltaPerSecond * (seconds + 1)
}
} else {
// rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds
val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds)
rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond
}
}
}
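
For reference, the options documented above are used exactly as with the old v1 source; a minimal example, assuming an active SparkSession named `spark` (not part of this patch):

```scala
// The short name "rate" now resolves to RateStreamProvider.
val rateStream = spark.readStream
  .format("rate")
  .option("rowsPerSecond", "100")
  .option("rampUpTime", "5s")
  .option("numPartitions", "10")
  .load()
// Resulting schema: timestamp: TimestampType, value: LongType
```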

View file

@@ -1,187 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.streaming.sources
import java.util.Optional
import scala.collection.JavaConverters._
import scala.collection.mutable
import org.json4s.DefaultFormats
import org.json4s.jackson.Serialization
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair}
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport}
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset}
import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType}
import org.apache.spark.util.{ManualClock, SystemClock}
/**
* This is a temporary register as we build out v2 migration. Microbatch read support should
* be implemented in the same register as v1.
*/
class RateSourceProviderV2 extends DataSourceV2 with MicroBatchReadSupport with DataSourceRegister {
override def createMicroBatchReader(
schema: Optional[StructType],
checkpointLocation: String,
options: DataSourceOptions): MicroBatchReader = {
new RateStreamMicroBatchReader(options)
}
override def shortName(): String = "ratev2"
}
class RateStreamMicroBatchReader(options: DataSourceOptions)
extends MicroBatchReader {
implicit val defaultFormats: DefaultFormats = DefaultFormats
val clock = {
// The option to use a manual clock is provided only for unit testing purposes.
if (options.get("useManualClock").orElse("false").toBoolean) new ManualClock
else new SystemClock
}
private val numPartitions =
options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt
private val rowsPerSecond =
options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong
// The interval (in milliseconds) between rows in each partition.
// e.g. if there are 4 global rows per second, and 2 partitions, each partition
// should output rows every (1000 * 2 / 4) = 500 ms.
private val msPerPartitionBetweenRows = (1000 * numPartitions) / rowsPerSecond
override def readSchema(): StructType = {
StructType(
StructField("timestamp", TimestampType, false) ::
StructField("value", LongType, false) :: Nil)
}
val creationTimeMs = clock.getTimeMillis()
private var start: RateStreamOffset = _
private var end: RateStreamOffset = _
override def setOffsetRange(
start: Optional[Offset],
end: Optional[Offset]): Unit = {
this.start = start.orElse(
RateStreamSourceV2.createInitialOffset(numPartitions, creationTimeMs))
.asInstanceOf[RateStreamOffset]
this.end = end.orElse {
val currentTime = clock.getTimeMillis()
RateStreamOffset(
this.start.partitionToValueAndRunTimeMs.map {
case startOffset @ (part, ValueRunTimeMsPair(currentVal, currentReadTime)) =>
// Calculate the number of rows we should advance in this partition (based on the
// current time), and output a corresponding offset.
val readInterval = currentTime - currentReadTime
val numNewRows = readInterval / msPerPartitionBetweenRows
if (numNewRows <= 0) {
startOffset
} else {
(part, ValueRunTimeMsPair(
currentVal + (numNewRows * numPartitions),
currentReadTime + (numNewRows * msPerPartitionBetweenRows)))
}
}
)
}.asInstanceOf[RateStreamOffset]
}
override def getStartOffset(): Offset = {
if (start == null) throw new IllegalStateException("start offset not set")
start
}
override def getEndOffset(): Offset = {
if (end == null) throw new IllegalStateException("end offset not set")
end
}
override def deserializeOffset(json: String): Offset = {
RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json))
}
override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = {
val startMap = start.partitionToValueAndRunTimeMs
val endMap = end.partitionToValueAndRunTimeMs
endMap.keys.toSeq.map { part =>
val ValueRunTimeMsPair(endVal, _) = endMap(part)
val ValueRunTimeMsPair(startVal, startTimeMs) = startMap(part)
val packedRows = mutable.ListBuffer[(Long, Long)]()
var outVal = startVal + numPartitions
var outTimeMs = startTimeMs
while (outVal <= endVal) {
packedRows.append((outTimeMs, outVal))
outVal += numPartitions
outTimeMs += msPerPartitionBetweenRows
}
RateStreamBatchTask(packedRows).asInstanceOf[DataReaderFactory[Row]]
}.toList.asJava
}
override def commit(end: Offset): Unit = {}
override def stop(): Unit = {}
}
case class RateStreamBatchTask(vals: Seq[(Long, Long)]) extends DataReaderFactory[Row] {
override def createDataReader(): DataReader[Row] = new RateStreamBatchReader(vals)
}
class RateStreamBatchReader(vals: Seq[(Long, Long)]) extends DataReader[Row] {
private var currentIndex = -1
override def next(): Boolean = {
// Return true as long as the new index is in the seq.
currentIndex += 1
currentIndex < vals.size
}
override def get(): Row = {
Row(
DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromMillis(vals(currentIndex)._1)),
vals(currentIndex)._2)
}
override def close(): Unit = {}
}
object RateStreamSourceV2 {
val NUM_PARTITIONS = "numPartitions"
val ROWS_PER_SECOND = "rowsPerSecond"
private[sql] def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = {
RateStreamOffset(
Range(0, numPartitions).map { i =>
// Note that the starting offset is exclusive, so we have to decrement the starting value
// by the increment that will later be applied. The first row output in each
// partition will have a value equal to the partition index.
(i,
ValueRunTimeMsPair(
(i - numPartitions).toLong,
creationTimeMs))
}.toMap)
}
}
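
The pacing in this (now removed) reader follows directly from its comment: with rowsPerSecond global rows spread over numPartitions partitions, each partition emits one row every 1000 * numPartitions / rowsPerSecond milliseconds. A worked version of the numbers in that comment, illustrative only:

```scala
// 4 global rows per second across 2 partitions -> one row per partition every 500 ms.
val numPartitions = 2
val rowsPerSecond = 4L
val msPerPartitionBetweenRows = (1000 * numPartitions) / rowsPerSecond  // 500
// Over a 2000 ms read interval each partition advances by 2000 / 500 = 4 rows,
// so its value increases by 4 * numPartitions = 8.
```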

View file

@@ -1,191 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.streaming
import java.util.Optional
import java.util.concurrent.TimeUnit
import scala.collection.JavaConverters._
import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.streaming.continuous._
import org.apache.spark.sql.execution.streaming.sources.{RateStreamBatchTask, RateStreamMicroBatchReader, RateStreamSourceV2}
import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, MicroBatchReadSupport}
import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.streaming.StreamTest
import org.apache.spark.util.ManualClock
class RateSourceV2Suite extends StreamTest {
import testImplicits._
case class AdvanceRateManualClock(seconds: Long) extends AddData {
override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = {
assert(query.nonEmpty)
val rateSource = query.get.logicalPlan.collect {
case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source
}.head
rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds))
rateSource.setOffsetRange(Optional.empty(), Optional.empty())
(rateSource, rateSource.getEndOffset())
}
}
test("microbatch in registry") {
DataSource.lookupDataSource("ratev2", spark.sqlContext.conf).newInstance() match {
case ds: MicroBatchReadSupport =>
val reader = ds.createMicroBatchReader(Optional.empty(), "", DataSourceOptions.empty())
assert(reader.isInstanceOf[RateStreamMicroBatchReader])
case _ =>
throw new IllegalStateException("Could not find v2 read support for rate")
}
}
test("basic microbatch execution") {
val input = spark.readStream
.format("rateV2")
.option("numPartitions", "1")
.option("rowsPerSecond", "10")
.option("useManualClock", "true")
.load()
testStream(input, useV2Sink = true)(
AdvanceRateManualClock(seconds = 1),
CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*),
StopStream,
StartStream(),
// Advance 2 seconds because creating a new RateSource will also create a new ManualClock
AdvanceRateManualClock(seconds = 2),
CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*)
)
}
test("microbatch - numPartitions propagated") {
val reader = new RateStreamMicroBatchReader(
new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava))
reader.setOffsetRange(Optional.empty(), Optional.empty())
val tasks = reader.createDataReaderFactories()
assert(tasks.size == 11)
}
test("microbatch - set offset") {
val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty())
val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000))))
val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 2000))))
reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset))
assert(reader.getStartOffset() == startOffset)
assert(reader.getEndOffset() == endOffset)
}
test("microbatch - infer offsets") {
val reader = new RateStreamMicroBatchReader(
new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "100").asJava))
reader.clock.waitTillTime(reader.clock.getTimeMillis() + 100)
reader.setOffsetRange(Optional.empty(), Optional.empty())
reader.getStartOffset() match {
case r: RateStreamOffset =>
assert(r.partitionToValueAndRunTimeMs(0).runTimeMs == reader.creationTimeMs)
case _ => throw new IllegalStateException("unexpected offset type")
}
reader.getEndOffset() match {
case r: RateStreamOffset =>
// End offset may be a bit beyond 100 ms/9 rows after creation if the wait lasted
// longer than 100ms. It should never be early.
assert(r.partitionToValueAndRunTimeMs(0).value >= 9)
assert(r.partitionToValueAndRunTimeMs(0).runTimeMs >= reader.creationTimeMs + 100)
case _ => throw new IllegalStateException("unexpected offset type")
}
}
test("microbatch - predetermined batch size") {
val reader = new RateStreamMicroBatchReader(
new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava))
val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000))))
val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(20, 2000))))
reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset))
val tasks = reader.createDataReaderFactories()
assert(tasks.size == 1)
assert(tasks.get(0).asInstanceOf[RateStreamBatchTask].vals.size == 20)
}
test("microbatch - data read") {
val reader = new RateStreamMicroBatchReader(
new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava))
val startOffset = RateStreamSourceV2.createInitialOffset(11, reader.creationTimeMs)
val endOffset = RateStreamOffset(startOffset.partitionToValueAndRunTimeMs.toSeq.map {
case (part, ValueRunTimeMsPair(currentVal, currentReadTime)) =>
(part, ValueRunTimeMsPair(currentVal + 33, currentReadTime + 1000))
}.toMap)
reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset))
val tasks = reader.createDataReaderFactories()
assert(tasks.size == 11)
val readData = tasks.asScala
.map(_.createDataReader())
.flatMap { reader =>
val buf = scala.collection.mutable.ListBuffer[Row]()
while (reader.next()) buf.append(reader.get())
buf
}
assert(readData.map(_.getLong(1)).sorted == Range(0, 33))
}
test("continuous in registry") {
DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match {
case ds: ContinuousReadSupport =>
val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty())
assert(reader.isInstanceOf[RateStreamContinuousReader])
case _ =>
throw new IllegalStateException("Could not find v2 read support for rate")
}
}
test("continuous data") {
val reader = new RateStreamContinuousReader(
new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava))
reader.setStartOffset(Optional.empty())
val tasks = reader.createDataReaderFactories()
assert(tasks.size == 2)
val data = scala.collection.mutable.ListBuffer[Row]()
tasks.asScala.foreach {
case t: RateStreamContinuousDataReaderFactory =>
val startTimeMs = reader.getStartOffset()
.asInstanceOf[RateStreamOffset]
.partitionToValueAndRunTimeMs(t.partitionIndex)
.runTimeMs
val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader]
for (rowIndex <- 0 to 9) {
r.next()
data.append(r.get())
assert(r.getOffset() ==
RateStreamPartitionOffset(
t.partitionIndex,
t.partitionIndex + rowIndex * 2,
startTimeMs + (rowIndex + 1) * 100))
}
assert(System.currentTimeMillis() >= startTimeMs + 1000)
case _ => throw new IllegalStateException("Unexpected task type")
}
assert(data.map(_.getLong(1)).toSeq.sorted == Range(0, 20))
}
}

View file

@@ -15,13 +15,24 @@
* limitations under the License.
*/
package org.apache.spark.sql.execution.streaming
package org.apache.spark.sql.execution.streaming.sources
import java.nio.file.Files
import java.util.Optional
import java.util.concurrent.TimeUnit
import org.apache.spark.sql.AnalysisException
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.continuous._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest}
import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport}
import org.apache.spark.sql.sources.v2.reader.streaming.Offset
import org.apache.spark.sql.streaming.StreamTest
import org.apache.spark.util.ManualClock
class RateSourceSuite extends StreamTest {
@@ -29,18 +29,40 @@ class RateSourceSuite extends StreamTest {
import testImplicits._
case class AdvanceRateManualClock(seconds: Long) extends AddData {
override def addData(query: Option[StreamExecution]): (Source, Offset) = {
override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = {
assert(query.nonEmpty)
val rateSource = query.get.logicalPlan.collect {
case StreamingExecutionRelation(source, _) if source.isInstanceOf[RateStreamSource] =>
source.asInstanceOf[RateStreamSource]
case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source
}.head
rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds))
(rateSource, rateSource.getOffset.get)
val offset = LongOffset(TimeUnit.MILLISECONDS.toSeconds(
rateSource.clock.getTimeMillis() - rateSource.creationTimeMs))
(rateSource, offset)
}
}
test("basic") {
test("microbatch in registry") {
DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match {
case ds: MicroBatchReadSupport =>
val reader = ds.createMicroBatchReader(Optional.empty(), "dummy", DataSourceOptions.empty())
assert(reader.isInstanceOf[RateStreamMicroBatchReader])
case _ =>
throw new IllegalStateException("Could not find read support for rate")
}
}
test("compatible with old path in registry") {
DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.RateSourceProvider",
spark.sqlContext.conf).newInstance() match {
case ds: MicroBatchReadSupport =>
assert(ds.isInstanceOf[RateStreamProvider])
case _ =>
throw new IllegalStateException("Could not find read support for rate")
}
}
test("microbatch - basic") {
val input = spark.readStream
.format("rate")
.option("rowsPerSecond", "10")
@@ -57,7 +90,7 @@ class RateSourceSuite extends StreamTest {
)
}
test("uniform distribution of event timestamps") {
test("microbatch - uniform distribution of event timestamps") {
val input = spark.readStream
.format("rate")
.option("rowsPerSecond", "1500")
@@ -74,8 +107,74 @@ class RateSourceSuite extends StreamTest {
)
}
test("microbatch - set offset") {
val temp = Files.createTempDirectory("dummy").toString
val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty(), temp)
val startOffset = LongOffset(0L)
val endOffset = LongOffset(1L)
reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset))
assert(reader.getStartOffset() == startOffset)
assert(reader.getEndOffset() == endOffset)
}
test("microbatch - infer offsets") {
val tempFolder = Files.createTempDirectory("dummy").toString
val reader = new RateStreamMicroBatchReader(
new DataSourceOptions(
Map("numPartitions" -> "1", "rowsPerSecond" -> "100", "useManualClock" -> "true").asJava),
tempFolder)
reader.clock.asInstanceOf[ManualClock].advance(100000)
reader.setOffsetRange(Optional.empty(), Optional.empty())
reader.getStartOffset() match {
case r: LongOffset => assert(r.offset === 0L)
case _ => throw new IllegalStateException("unexpected offset type")
}
reader.getEndOffset() match {
case r: LongOffset => assert(r.offset >= 100)
case _ => throw new IllegalStateException("unexpected offset type")
}
}
test("microbatch - predetermined batch size") {
val temp = Files.createTempDirectory("dummy").toString
val reader = new RateStreamMicroBatchReader(
new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava), temp)
val startOffset = LongOffset(0L)
val endOffset = LongOffset(1L)
reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset))
val tasks = reader.createDataReaderFactories()
assert(tasks.size == 1)
val dataReader = tasks.get(0).createDataReader()
val data = ArrayBuffer[Row]()
while (dataReader.next()) {
data.append(dataReader.get())
}
assert(data.size === 20)
}
test("microbatch - data read") {
val temp = Files.createTempDirectory("dummy").toString
val reader = new RateStreamMicroBatchReader(
new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava), temp)
val startOffset = LongOffset(0L)
val endOffset = LongOffset(1L)
reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset))
val tasks = reader.createDataReaderFactories()
assert(tasks.size == 11)
val readData = tasks.asScala
.map(_.createDataReader())
.flatMap { reader =>
val buf = scala.collection.mutable.ListBuffer[Row]()
while (reader.next()) buf.append(reader.get())
buf
}
assert(readData.map(_.getLong(1)).sorted == Range(0, 33))
}
test("valueAtSecond") {
import RateStreamSource._
import RateStreamProvider._
assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 0)
assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 5)
@@ -161,7 +260,7 @@ class RateSourceSuite extends StreamTest {
option: String,
value: String,
expectedMessages: Seq[String]): Unit = {
val e = intercept[StreamingQueryException] {
val e = intercept[IllegalArgumentException] {
spark.readStream
.format("rate")
.option(option, value)
@@ -171,9 +270,8 @@ class RateSourceSuite extends StreamTest {
.start()
.awaitTermination()
}
assert(e.getCause.isInstanceOf[IllegalArgumentException])
for (msg <- expectedMessages) {
assert(e.getCause.getMessage.contains(msg))
assert(e.getMessage.contains(msg))
}
}
@@ -191,4 +289,46 @@ class RateSourceSuite extends StreamTest {
assert(exception.getMessage.contains(
"rate source does not support a user-specified schema"))
}
test("continuous in registry") {
DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match {
case ds: ContinuousReadSupport =>
val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty())
assert(reader.isInstanceOf[RateStreamContinuousReader])
case _ =>
throw new IllegalStateException("Could not find read support for continuous rate")
}
}
test("continuous data") {
val reader = new RateStreamContinuousReader(
new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava))
reader.setStartOffset(Optional.empty())
val tasks = reader.createDataReaderFactories()
assert(tasks.size == 2)
val data = scala.collection.mutable.ListBuffer[Row]()
tasks.asScala.foreach {
case t: RateStreamContinuousDataReaderFactory =>
val startTimeMs = reader.getStartOffset()
.asInstanceOf[RateStreamOffset]
.partitionToValueAndRunTimeMs(t.partitionIndex)
.runTimeMs
val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader]
for (rowIndex <- 0 to 9) {
r.next()
data.append(r.get())
assert(r.getOffset() ==
RateStreamPartitionOffset(
t.partitionIndex,
t.partitionIndex + rowIndex * 2,
startTimeMs + (rowIndex + 1) * 100))
}
assert(System.currentTimeMillis() >= startTimeMs + 1000)
case _ => throw new IllegalStateException("Unexpected task type")
}
assert(data.map(_.getLong(1)).toSeq.sorted == Range(0, 20))
}
}