Roll forward "[SPARK-23096][SS] Migrate rate source to V2"
## What changes were proposed in this pull request?
Roll forward c68ec4e (#20688).
There are two minor test changes required:
* An error which used to be TreeNodeException[ArithmeticException] is no longer wrapped and is now just ArithmeticException.
* The test framework does not set the active Spark session early enough (it appears to be set only when a query is analyzed), so the required logic has been added to SQLTestUtils.
## How was this patch tested?
existing tests
Author: Jose Torres <torres.joseph.f+github@gmail.com>
Author: jerryshao <sshao@hortonworks.com>
Closes #20922 from jose-torres/ratefix.
This commit is contained in:
parent
b02e76cbff
commit
5b5a36ed6d
|
@ -5,6 +5,5 @@ org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
|
|||
org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
|
||||
org.apache.spark.sql.execution.datasources.text.TextFileFormat
|
||||
org.apache.spark.sql.execution.streaming.ConsoleSinkProvider
|
||||
org.apache.spark.sql.execution.streaming.RateSourceProvider
|
||||
org.apache.spark.sql.execution.streaming.sources.RateStreamProvider
|
||||
org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider
|
||||
org.apache.spark.sql.execution.streaming.sources.RateSourceProviderV2
|
||||
|
|
|
@ -41,7 +41,7 @@ import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
|
|||
import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
|
||||
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
|
||||
import org.apache.spark.sql.execution.streaming._
|
||||
import org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider
|
||||
import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, TextSocketSourceProvider}
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.sources._
|
||||
import org.apache.spark.sql.streaming.OutputMode
|
||||
|
@ -566,6 +566,7 @@ object DataSource extends Logging {
|
|||
val orc = "org.apache.spark.sql.hive.orc.OrcFileFormat"
|
||||
val nativeOrc = classOf[OrcFileFormat].getCanonicalName
|
||||
val socket = classOf[TextSocketSourceProvider].getCanonicalName
|
||||
val rate = classOf[RateStreamProvider].getCanonicalName
|
||||
|
||||
Map(
|
||||
"org.apache.spark.sql.jdbc" -> jdbc,
|
||||
|
@ -587,7 +588,8 @@ object DataSource extends Logging {
|
|||
"org.apache.spark.ml.source.libsvm.DefaultSource" -> libsvm,
|
||||
"org.apache.spark.ml.source.libsvm" -> libsvm,
|
||||
"com.databricks.spark.csv" -> csv,
|
||||
"org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket
|
||||
"org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket,
|
||||
"org.apache.spark.sql.execution.streaming.RateSourceProvider" -> rate
|
||||
)
|
||||
}
|
||||
|
||||
|
|
|
@ -1,262 +0,0 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.execution.streaming
|
||||
|
||||
import java.io._
|
||||
import java.nio.charset.StandardCharsets
|
||||
import java.util.Optional
|
||||
import java.util.concurrent.TimeUnit
|
||||
|
||||
import org.apache.commons.io.IOUtils
|
||||
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.network.util.JavaUtils
|
||||
import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext}
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
|
||||
import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader
|
||||
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider}
|
||||
import org.apache.spark.sql.sources.v2._
|
||||
import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.util.{ManualClock, SystemClock}
|
||||
|
||||
/**
 * Provider of the "rate" source: a stream of auto-incrementing long values paired with the
 * timestamp at which each value was generated. Every row has a `timestamp` column and a
 * `value` column, and values start at 0L.
 *
 * Supported options:
 *  - `rowsPerSecond` (e.g. 100, default: 1): number of rows to generate per second.
 *  - `rampUpTime` (e.g. 5s, default: 0s): how long to ramp up before the generating speed
 *    reaches `rowsPerSecond`. Granularities finer than seconds are truncated to whole seconds.
 *  - `numPartitions` (e.g. 10, default: Spark's default parallelism): number of partitions of
 *    the generated rows. The source tries its best to reach `rowsPerSecond`, but the query may
 *    be resource constrained; tweaking `numPartitions` can help reach the desired speed.
 */
class RateSourceProvider extends StreamSourceProvider with DataSourceRegister
  with DataSourceV2 with ContinuousReadSupport {

  /**
   * Returns the source's short name together with its fixed schema.
   *
   * @throws AnalysisException if the caller supplied a schema of their own
   */
  override def sourceSchema(
      sqlContext: SQLContext,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): (String, StructType) = {
    if (schema.isDefined) {
      throw new AnalysisException("The rate source does not support a user-specified schema.")
    }

    (shortName(), RateSourceProvider.SCHEMA)
  }

  /**
   * Validates the options (in the order rowsPerSecond, rampUpTime, numPartitions) and builds
   * the V1 source.
   *
   * @throws IllegalArgumentException if `rowsPerSecond` or `numPartitions` is not positive, or
   *                                  if `rampUpTime` is negative
   */
  override def createSource(
      sqlContext: SQLContext,
      metadataPath: String,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): Source = {
    // Option keys are looked up case-insensitively.
    val opts = CaseInsensitiveMap(parameters)

    val rate = opts.get("rowsPerSecond").fold(1L)(_.toLong)
    if (rate <= 0) {
      throw new IllegalArgumentException(
        s"Invalid value '${opts("rowsPerSecond")}'. The option 'rowsPerSecond' " +
          "must be positive")
    }

    val rampUpSeconds = opts.get("rampUpTime").fold(0L)(JavaUtils.timeStringAsSec(_))
    if (rampUpSeconds < 0) {
      throw new IllegalArgumentException(
        s"Invalid value '${opts("rampUpTime")}'. The option 'rampUpTime' " +
          "must not be negative")
    }

    // The default parallelism is only consulted when the option is absent.
    val partitions = opts.get("numPartitions") match {
      case Some(text) => text.toInt
      case None => sqlContext.sparkContext.defaultParallelism
    }
    if (partitions <= 0) {
      throw new IllegalArgumentException(
        s"Invalid value '${opts("numPartitions")}'. The option 'numPartitions' " +
          "must be positive")
    }

    // Only for testing
    val manualClock = opts.get("useManualClock").exists(_.toBoolean)

    new RateStreamSource(
      sqlContext,
      metadataPath,
      rate,
      rampUpSeconds,
      partitions,
      manualClock)
  }

  /** Builds the continuous-mode reader; only `options` is forwarded to it. */
  override def createContinuousReader(
      schema: Optional[StructType],
      checkpointLocation: String,
      options: DataSourceOptions): ContinuousReader = new RateStreamContinuousReader(options)

  override def shortName(): String = "rate"
}
|
||||
|
||||
/** Constants shared by the rate source classes. */
object RateSourceProvider {
  /** Fixed output schema: a `timestamp` column followed by a long `value` column. */
  val SCHEMA =
    StructType(Seq(StructField("timestamp", TimestampType), StructField("value", LongType)))

  /** Format version of the rate source. */
  val VERSION = 1
}
|
||||
|
||||
/**
 * V1 streaming source behind the "rate" format. Offsets are the number of whole seconds that
 * have elapsed since the start time recorded in the metadata log; the values emitted for an
 * offset range are derived with `RateStreamSource.valueAtSecond`.
 *
 * @param sqlContext        context used to create RDDs and DataFrames
 * @param metadataPath      location of the metadata log that persists the start time
 * @param rowsPerSecond     target number of rows per second once ramp-up has finished
 * @param rampUpTimeSeconds number of seconds over which the rate ramps up
 * @param numPartitions     number of partitions of each generated batch
 * @param useManualClock    use a `ManualClock` instead of the system clock
 * @throws ArithmeticException if `rampUpTimeSeconds` is larger than the biggest offset that
 *                             can be represented at the given `rowsPerSecond`
 */
class RateStreamSource(
    sqlContext: SQLContext,
    metadataPath: String,
    rowsPerSecond: Long,
    rampUpTimeSeconds: Long,
    numPartitions: Int,
    useManualClock: Boolean) extends Source with Logging {

  import RateSourceProvider._
  import RateStreamSource._

  val clock = if (useManualClock) new ManualClock else new SystemClock

  // Largest offset (in seconds) for which `seconds * rowsPerSecond` still fits in a Long.
  private val maxSeconds = Long.MaxValue / rowsPerSecond

  if (rampUpTimeSeconds > maxSeconds) {
    throw new ArithmeticException(
      s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" +
        s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.")
  }

  // Start time in milliseconds: read from entry 0 of the metadata log when it exists,
  // otherwise taken from the clock and written to the log.
  private val startTimeMs = {
    val startTimeLog =
      new HDFSMetadataLog[LongOffset](sqlContext.sparkSession, metadataPath) {
        // Layout: a "v<VERSION>" line followed by the offset's JSON.
        override def serialize(metadata: LongOffset, out: OutputStream): Unit = {
          val sink = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8))
          sink.write("v" + VERSION + "\n")
          sink.write(metadata.json)
          sink.flush
        }

        override def deserialize(in: InputStream): LongOffset = {
          val text = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8))
          // HDFSMetadataLog guarantees that it never creates a partial file.
          assert(text.length != 0)
          // A well-formed file starts with 'v' and has a newline terminating the version line.
          val versionLineEnd = if (text(0) == 'v') text.indexOf("\n") else -1
          if (versionLineEnd <= 0) {
            throw new IllegalStateException(
              s"Log file was malformed: failed to detect the log file version line.")
          }
          // Called for its validation of the version line; the parsed number is not needed.
          parseVersion(text.substring(0, versionLineEnd), VERSION)
          LongOffset(SerializedOffset(text.substring(versionLineEnd + 1)))
        }
      }

    val recorded = startTimeLog.get(0) match {
      case Some(existing) => existing
      case None =>
        val offset = LongOffset(clock.getTimeMillis())
        startTimeLog.add(0, offset)
        logInfo(s"Start time: $offset")
        offset
    }
    recorded.offset
  }

  /** When the system time runs backward, "lastTimeMs" will make sure we are still monotonic. */
  @volatile private var lastTimeMs = startTimeMs

  override def schema: StructType = RateSourceProvider.SCHEMA

  /** Returns the number of whole seconds between the start time and the latest observed time. */
  override def getOffset: Option[Offset] = {
    val now = clock.getTimeMillis()
    if (now > lastTimeMs) {
      lastTimeMs = now
    }
    Some(LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - startTimeMs)))
  }

  /**
   * Produces the rows whose values lie between the value at `start` and the value at `end`,
   * spreading their timestamps evenly over the covered time span.
   *
   * @throws ArithmeticException if `end` is beyond the largest representable offset
   */
  override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
    // Offsets that cannot be converted to a LongOffset count as second 0.
    def secondsOf(offset: Offset): Long = LongOffset.convert(offset).map(_.offset).getOrElse(0L)

    val startSeconds = start.map(secondsOf).getOrElse(0L)
    val endSeconds = secondsOf(end)
    assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)")
    if (endSeconds > maxSeconds) {
      throw new ArithmeticException("Integer overflow. Max offset with " +
        s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.")
    }

    // Fix "lastTimeMs" for recovery
    val endTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs
    if (lastTimeMs < endTimeMs) {
      lastTimeMs = endTimeMs
    }

    val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds)
    val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds)
    logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " +
      s"rangeStart: $rangeStart, rangeEnd: $rangeEnd")

    if (rangeStart == rangeEnd) {
      // Nothing to emit for this offset range.
      sqlContext.internalCreateDataFrame(
        sqlContext.sparkContext.emptyRDD, schema, isStreaming = true)
    } else {
      val localStartTimeMs = startTimeMs + TimeUnit.SECONDS.toMillis(startSeconds)
      val relativeMsPerValue =
        TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart)

      val rows = sqlContext.sparkContext.range(rangeStart, rangeEnd, 1, numPartitions).map { v =>
        val relative = math.round((v - rangeStart) * relativeMsPerValue)
        InternalRow(DateTimeUtils.fromMillis(relative + localStartTimeMs), v)
      }
      sqlContext.internalCreateDataFrame(rows, schema, isStreaming = true)
    }
  }

  override def stop(): Unit = {}

  override def toString: String = s"RateSource[rowsPerSecond=$rowsPerSecond, " +
    s"rampUpTimeSeconds=$rampUpTimeSeconds, numPartitions=$numPartitions]"
}
|
||||
|
||||
object RateStreamSource {

  /**
   * Calculate the end value we will emit at the time `seconds`.
   *
   * E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10 gives a per-second speed increase of 2:
   *
   * {{{
   *   seconds   = 0 1 2  3  4  5  6
   *   speed     = 0 2 4  6  8 10 10
   *   end value = 0 2 6 12 20 30 40
   * }}}
   *
   * @param seconds           offset, in seconds, at which to evaluate the end value
   * @param rowsPerSecond     rate reached once the ramp-up is over
   * @param rampUpTimeSeconds length of the ramp-up, in seconds
   * @return the end value emitted at `seconds`
   */
  def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = {
    val speedIncreasePerSecond = rowsPerSecond / (rampUpTimeSeconds + 1)

    // Value reached `t` seconds into the ramp-up, i.e.
    // "(0 + speedIncreasePerSecond * t) * (t + 1) / 2". Whichever of `t` and `t + 1` is even
    // is halved first, so the division is exact and the intermediate product stays small.
    def rampValue(t: Long): Long = {
      val (halved, other) = if (t % 2 == 1) ((t + 1) / 2, t) else (t / 2, t + 1)
      halved * speedIncreasePerSecond * other
    }

    if (seconds > rampUpTimeSeconds) {
      // Full ramp-up contribution plus constant-rate rows for the remaining seconds.
      rampValue(rampUpTimeSeconds) + (seconds - rampUpTimeSeconds) * rowsPerSecond
    } else {
      rampValue(seconds)
    }
  }
}
|
|
@ -24,8 +24,8 @@ import org.json4s.jackson.Serialization
|
|||
|
||||
import org.apache.spark.sql.Row
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeUtils
|
||||
import org.apache.spark.sql.execution.streaming.{RateSourceProvider, RateStreamOffset, ValueRunTimeMsPair}
|
||||
import org.apache.spark.sql.execution.streaming.sources.RateStreamSourceV2
|
||||
import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair}
|
||||
import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider
|
||||
import org.apache.spark.sql.sources.v2.DataSourceOptions
|
||||
import org.apache.spark.sql.sources.v2.reader._
|
||||
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset}
|
||||
|
@ -40,8 +40,8 @@ class RateStreamContinuousReader(options: DataSourceOptions)
|
|||
|
||||
val creationTime = System.currentTimeMillis()
|
||||
|
||||
val numPartitions = options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt
|
||||
val rowsPerSecond = options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong
|
||||
val numPartitions = options.get(RateStreamProvider.NUM_PARTITIONS).orElse("5").toInt
|
||||
val rowsPerSecond = options.get(RateStreamProvider.ROWS_PER_SECOND).orElse("6").toLong
|
||||
val perPartitionRate = rowsPerSecond.toDouble / numPartitions.toDouble
|
||||
|
||||
override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = {
|
||||
|
@ -57,12 +57,12 @@ class RateStreamContinuousReader(options: DataSourceOptions)
|
|||
RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json))
|
||||
}
|
||||
|
||||
override def readSchema(): StructType = RateSourceProvider.SCHEMA
|
||||
override def readSchema(): StructType = RateStreamProvider.SCHEMA
|
||||
|
||||
private var offset: Offset = _
|
||||
|
||||
override def setStartOffset(offset: java.util.Optional[Offset]): Unit = {
|
||||
this.offset = offset.orElse(RateStreamSourceV2.createInitialOffset(numPartitions, creationTime))
|
||||
this.offset = offset.orElse(createInitialOffset(numPartitions, creationTime))
|
||||
}
|
||||
|
||||
override def getStartOffset(): Offset = offset
|
||||
|
@ -98,6 +98,19 @@ class RateStreamContinuousReader(options: DataSourceOptions)
|
|||
override def commit(end: Offset): Unit = {}
|
||||
override def stop(): Unit = {}
|
||||
|
||||
/**
 * Builds the offset used when no start offset is supplied: one entry per partition index in
 * `[0, numPartitions)`, each stamped with `creationTimeMs`.
 */
private def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = {
  val perPartition = (0 until numPartitions).map { partition =>
    // The starting offset is exclusive, so each partition starts one increment
    // (numPartitions) below its own index; the first row output in each partition then has
    // a value equal to the partition index.
    val startValue = (partition - numPartitions).toLong
    partition -> ValueRunTimeMsPair(startValue, creationTimeMs)
  }
  RateStreamOffset(perPartition.toMap)
}
|
||||
|
||||
}
|
||||
|
||||
case class RateStreamContinuousDataReaderFactory(
|
||||
|
|
|
@ -0,0 +1,222 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.execution.streaming.sources
|
||||
|
||||
import java.io._
|
||||
import java.nio.charset.StandardCharsets
|
||||
import java.util.Optional
|
||||
import java.util.concurrent.TimeUnit
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import org.apache.commons.io.IOUtils
|
||||
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.network.util.JavaUtils
|
||||
import org.apache.spark.sql.{Row, SparkSession}
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeUtils
|
||||
import org.apache.spark.sql.execution.streaming._
|
||||
import org.apache.spark.sql.sources.v2.DataSourceOptions
|
||||
import org.apache.spark.sql.sources.v2.reader._
|
||||
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset}
|
||||
import org.apache.spark.sql.types.StructType
|
||||
import org.apache.spark.util.{ManualClock, SystemClock}
|
||||
|
||||
/**
 * Micro-batch reader for the rate source. Offsets are the number of whole seconds elapsed
 * since the creation time persisted in the checkpoint metadata log; the values emitted for an
 * offset range are derived with `RateStreamProvider.valueAtSecond`.
 *
 * @param options            source options: `rowsPerSecond`, `rampUpTime`, `numPartitions`,
 *                           and the test-only `useManualClock`
 * @param checkpointLocation location of the metadata log that persists the creation time
 * @throws ArithmeticException if the ramp-up time is larger than the biggest offset that can
 *                             be represented at the configured `rowsPerSecond`
 */
class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: String)
  extends MicroBatchReader with Logging {
  import RateStreamProvider._

  private[sources] val clock = {
    // The option to use a manual clock is provided only for unit testing purposes.
    if (options.getBoolean("useManualClock", false)) new ManualClock else new SystemClock
  }

  private val rowsPerSecond =
    options.get(ROWS_PER_SECOND).orElse("1").toLong

  private val rampUpTimeSeconds =
    Option(options.get(RAMP_UP_TIME).orElse(null.asInstanceOf[String]))
      .map(JavaUtils.timeStringAsSec(_))
      .getOrElse(0L)

  // Largest offset (in seconds) for which `seconds * rowsPerSecond` still fits in a Long.
  private val maxSeconds = Long.MaxValue / rowsPerSecond

  if (rampUpTimeSeconds > maxSeconds) {
    throw new ArithmeticException(
      s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" +
        s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.")
  }

  /**
   * Returns the active session, falling back to the default session.
   *
   * @throws IllegalArgumentException if neither session is available
   */
  private def currentSession: SparkSession = {
    val session = SparkSession.getActiveSession.orElse(SparkSession.getDefaultSession)
    require(session.isDefined, "The rate source requires an active or default SparkSession")
    session.get
  }

  // Creation time in milliseconds: read from entry 0 of the metadata log when it exists,
  // otherwise taken from the clock and written to the log.
  private[sources] val creationTimeMs = {
    val metadataLog =
      new HDFSMetadataLog[LongOffset](currentSession, checkpointLocation) {
        // Layout: a "v<VERSION>" line followed by the offset's JSON.
        override def serialize(metadata: LongOffset, out: OutputStream): Unit = {
          val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8))
          writer.write("v" + VERSION + "\n")
          writer.write(metadata.json)
          writer.flush()
        }

        override def deserialize(in: InputStream): LongOffset = {
          val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8))
          // HDFSMetadataLog guarantees that it never creates a partial file.
          assert(content.length != 0)
          // A well-formed file starts with 'v' and has a newline terminating the version line.
          val indexOfNewLine = if (content(0) == 'v') content.indexOf("\n") else -1
          if (indexOfNewLine <= 0) {
            throw new IllegalStateException(
              "Log file was malformed: failed to detect the log file version line.")
          }
          // Called for its validation of the version line; the parsed number is not needed.
          parseVersion(content.substring(0, indexOfNewLine), VERSION)
          LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1)))
        }
      }

    metadataLog.get(0).getOrElse {
      val offset = LongOffset(clock.getTimeMillis())
      metadataLog.add(0, offset)
      logInfo(s"Start time: $offset")
      offset
    }.offset
  }

  // Latest time observed; only ever moved forward so offsets stay monotonic.
  @volatile private var lastTimeMs: Long = creationTimeMs

  // Offset range of the current batch; null until setOffsetRange has been called.
  private var start: LongOffset = _
  private var end: LongOffset = _

  override def readSchema(): StructType = SCHEMA

  /**
   * Records the offset range of the next batch. A missing start defaults to offset 0; a
   * missing end defaults to the number of whole seconds elapsed since the creation time.
   */
  override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = {
    this.start = start.orElse(LongOffset(0L)).asInstanceOf[LongOffset]
    // Optional.orElse takes a value, so this block (including the clock refresh) is evaluated
    // whether or not `end` is present.
    this.end = end.orElse {
      val now = clock.getTimeMillis()
      if (lastTimeMs < now) {
        lastTimeMs = now
      }
      LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - creationTimeMs))
    }.asInstanceOf[LongOffset]
  }

  /** @throws IllegalStateException if no offset range has been set yet */
  override def getStartOffset(): Offset = {
    if (start == null) throw new IllegalStateException("start offset not set")
    start
  }

  /** @throws IllegalStateException if no offset range has been set yet */
  override def getEndOffset(): Offset = {
    if (end == null) throw new IllegalStateException("end offset not set")
    end
  }

  override def deserializeOffset(json: String): Offset = {
    LongOffset(json.toLong)
  }

  /**
   * Creates one reader factory per partition for the current offset range, or an empty list
   * when the range contains no values.
   *
   * @throws ArithmeticException if the end offset is beyond the largest representable offset
   */
  override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = {
    val startSeconds = LongOffset.convert(start).map(_.offset).getOrElse(0L)
    val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L)
    assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)")
    if (endSeconds > maxSeconds) {
      throw new ArithmeticException("Integer overflow. Max offset with " +
        s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.")
    }
    // Fix "lastTimeMs" for recovery
    if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs) {
      lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs
    }
    val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds)
    val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds)
    logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " +
      s"rangeStart: $rangeStart, rangeEnd: $rangeEnd")

    if (rangeStart == rangeEnd) {
      List.empty[DataReaderFactory[Row]].asJava
    } else {
      val localStartTimeMs = creationTimeMs + TimeUnit.SECONDS.toMillis(startSeconds)
      val relativeMsPerValue =
        TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart)
      // A session is only needed to look up the default parallelism, so it is consulted
      // solely when `numPartitions` was not configured.
      val numPartitions =
        Option(options.get(NUM_PARTITIONS).orElse(null.asInstanceOf[String]))
          .map(_.toInt)
          .getOrElse(currentSession.sparkContext.defaultParallelism)

      (0 until numPartitions).map { p =>
        new RateStreamMicroBatchDataReaderFactory(
          p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue)
          : DataReaderFactory[Row]
      }.toList.asJava
    }
  }

  override def commit(end: Offset): Unit = {}

  override def stop(): Unit = {}

  override def toString: String = s"MicroBatchRateSource[rowsPerSecond=$rowsPerSecond, " +
    s"rampUpTimeSeconds=$rampUpTimeSeconds, " +
    s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}]"
}
|
||||
|
||||
/**
 * Factory that carries the description of one partition of a rate-source micro-batch and
 * creates the matching [[RateStreamMicroBatchDataReader]] from it.
 */
class RateStreamMicroBatchDataReaderFactory(
    partitionId: Int,
    numPartitions: Int,
    rangeStart: Long,
    rangeEnd: Long,
    localStartTimeMs: Long,
    relativeMsPerValue: Double) extends DataReaderFactory[Row] {

  /** Creates a reader configured with exactly the arguments this factory was built with. */
  override def createDataReader(): DataReader[Row] = {
    new RateStreamMicroBatchDataReader(
      partitionId = partitionId,
      numPartitions = numPartitions,
      rangeStart = rangeStart,
      rangeEnd = rangeEnd,
      localStartTimeMs = localStartTimeMs,
      relativeMsPerValue = relativeMsPerValue)
  }
}
|
||||
|
||||
/**
 * Reads one partition of a rate-source micro-batch. The partition emits the values
 * `rangeStart + partitionId + numPartitions * k` (k = 0, 1, 2, ...) that are below `rangeEnd`,
 * each paired with a timestamp derived from its position in the range.
 *
 * @param partitionId        index of this partition
 * @param numPartitions      stride between consecutive values of this partition
 * @param rangeStart         first value of the batch (inclusive)
 * @param rangeEnd           end of the batch (exclusive)
 * @param localStartTimeMs   timestamp, in milliseconds, assigned to `rangeStart`
 * @param relativeMsPerValue milliseconds by which the timestamp advances per value
 */
class RateStreamMicroBatchDataReader(
    partitionId: Int,
    numPartitions: Int,
    rangeStart: Long,
    rangeEnd: Long,
    localStartTimeMs: Long,
    relativeMsPerValue: Double) extends DataReader[Row] {
  // Number of rows handed out so far. Kept as a Long so that `numPartitions * count` is
  // computed in 64-bit arithmetic; as an Int the product could overflow before being widened.
  private var count = 0L

  /** Reports whether another value is available; does not move the cursor. */
  override def next(): Boolean = {
    rangeStart + partitionId + numPartitions * count < rangeEnd
  }

  /**
   * Returns the current row and moves the cursor to the next one, so it must be called
   * exactly once after each successful `next()`.
   */
  override def get(): Row = {
    val currValue = rangeStart + partitionId + numPartitions * count
    count += 1
    // Timestamps are spread evenly across the batch according to the value's position.
    val relative = math.round((currValue - rangeStart) * relativeMsPerValue)
    Row(
      DateTimeUtils.toJavaTimestamp(
        DateTimeUtils.fromMillis(relative + localStartTimeMs)),
      currValue
    )
  }

  override def close(): Unit = {}
}
|
|
@ -0,0 +1,125 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.execution.streaming.sources
|
||||
|
||||
import java.util.Optional
|
||||
|
||||
import org.apache.spark.network.util.JavaUtils
|
||||
import org.apache.spark.sql.AnalysisException
|
||||
import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader
|
||||
import org.apache.spark.sql.sources.DataSourceRegister
|
||||
import org.apache.spark.sql.sources.v2._
|
||||
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader}
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
/**
 * Provider of the "rate" source: a stream of auto-incrementing long values paired with the
 * timestamp at which each value was generated. Every row has a `timestamp` column and a
 * `value` column, and values start at 0L.
 *
 * Supported options:
 *  - `rowsPerSecond` (e.g. 100, default: 1): number of rows to generate per second.
 *  - `rampUpTime` (e.g. 5s, default: 0s): how long to ramp up before the generating speed
 *    reaches `rowsPerSecond`. Granularities finer than seconds are truncated to whole seconds.
 *  - `numPartitions` (e.g. 10, default: Spark's default parallelism): number of partitions of
 *    the generated rows. The source tries its best to reach `rowsPerSecond`, but the query may
 *    be resource constrained; tweaking `numPartitions` can help reach the desired speed.
 */
class RateStreamProvider extends DataSourceV2
  with MicroBatchReadSupport with ContinuousReadSupport with DataSourceRegister {
  import RateStreamProvider._

  /**
   * Validates the options that are present (in the order rowsPerSecond, rampUpTime,
   * numPartitions), rejects a user-specified schema, and creates the micro-batch reader.
   *
   * @throws IllegalArgumentException if `rowsPerSecond` or `numPartitions` is not positive, or
   *                                  if `rampUpTime` is negative
   * @throws AnalysisException        if a schema was supplied
   */
  override def createMicroBatchReader(
      schema: Optional[StructType],
      checkpointLocation: String,
      options: DataSourceOptions): MicroBatchReader = {
    // Scala view of an option that may or may not have been configured.
    def configured(key: String): Option[String] = {
      val raw = options.get(key)
      if (raw.isPresent) Some(raw.get()) else None
    }

    configured(ROWS_PER_SECOND).map(_.toLong).foreach { rowsPerSecond =>
      if (rowsPerSecond <= 0) {
        throw new IllegalArgumentException(
          s"Invalid value '$rowsPerSecond'. The option 'rowsPerSecond' must be positive")
      }
    }

    configured(RAMP_UP_TIME).map(JavaUtils.timeStringAsSec(_)).foreach { rampUpTimeSeconds =>
      if (rampUpTimeSeconds < 0) {
        throw new IllegalArgumentException(
          s"Invalid value '$rampUpTimeSeconds'. The option 'rampUpTime' must not be negative")
      }
    }

    configured(NUM_PARTITIONS).map(_.toInt).foreach { numPartitions =>
      if (numPartitions <= 0) {
        throw new IllegalArgumentException(
          s"Invalid value '$numPartitions'. The option 'numPartitions' must be positive")
      }
    }

    if (schema.isPresent) {
      throw new AnalysisException("The rate source does not support a user-specified schema.")
    }

    new RateStreamMicroBatchReader(options, checkpointLocation)
  }

  /** Creates the continuous-mode reader; only `options` is forwarded to it. */
  override def createContinuousReader(
      schema: Optional[StructType],
      checkpointLocation: String,
      options: DataSourceOptions): ContinuousReader = {
    new RateStreamContinuousReader(options)
  }

  override def shortName(): String = "rate"
}
|
||||
|
||||
/** Constants and helpers shared by the rate source readers. */
object RateStreamProvider {
  /** Fixed output schema: a `timestamp` column followed by a long `value` column. */
  val SCHEMA =
    StructType(Seq(StructField("timestamp", TimestampType), StructField("value", LongType)))

  /** Format version of the rate source. */
  val VERSION = 1

  // Option keys understood by the rate source.
  val NUM_PARTITIONS = "numPartitions"
  val ROWS_PER_SECOND = "rowsPerSecond"
  val RAMP_UP_TIME = "rampUpTime"

  /**
   * Calculate the end value we will emit at the time `seconds`.
   *
   * E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10 gives a per-second speed increase of 2:
   *
   * {{{
   *   seconds   = 0 1 2  3  4  5  6
   *   speed     = 0 2 4  6  8 10 10
   *   end value = 0 2 6 12 20 30 40
   * }}}
   *
   * @param seconds           offset, in seconds, at which to evaluate the end value
   * @param rowsPerSecond     rate reached once the ramp-up is over
   * @param rampUpTimeSeconds length of the ramp-up, in seconds
   * @return the end value emitted at `seconds`
   */
  def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = {
    val speedIncreasePerSecond = rowsPerSecond / (rampUpTimeSeconds + 1)

    // Value reached `t` seconds into the ramp-up, i.e.
    // "(0 + speedIncreasePerSecond * t) * (t + 1) / 2". Whichever of `t` and `t + 1` is even
    // is halved first, so the division is exact and the intermediate product stays small.
    def duringRampUp(t: Long): Long = {
      if (t % 2 == 1) {
        (t + 1) / 2 * speedIncreasePerSecond * t
      } else {
        t / 2 * speedIncreasePerSecond * (t + 1)
      }
    }

    if (seconds > rampUpTimeSeconds) {
      // Full ramp-up contribution plus constant-rate rows for the remaining seconds.
      duringRampUp(rampUpTimeSeconds) + (seconds - rampUpTimeSeconds) * rowsPerSecond
    } else {
      duringRampUp(seconds)
    }
  }
}
|
|
@ -1,187 +0,0 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.execution.streaming.sources
|
||||
|
||||
import java.util.Optional
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.mutable
|
||||
|
||||
import org.json4s.DefaultFormats
|
||||
import org.json4s.jackson.Serialization
|
||||
|
||||
import org.apache.spark.sql.Row
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeUtils
|
||||
import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair}
|
||||
import org.apache.spark.sql.sources.DataSourceRegister
|
||||
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport}
|
||||
import org.apache.spark.sql.sources.v2.reader._
|
||||
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset}
|
||||
import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType}
|
||||
import org.apache.spark.util.{ManualClock, SystemClock}
|
||||
|
||||
/**
 * This is a temporary register as we build out v2 migration. Microbatch read support should
 * be implemented in the same register as v1.
 */
class RateSourceProviderV2
  extends DataSourceV2
  with MicroBatchReadSupport
  with DataSourceRegister {

  /** The short name under which this provider is registered. */
  override def shortName(): String = "ratev2"

  /**
   * Creates the micro-batch reader. Only `options` is forwarded; `schema` and
   * `checkpointLocation` are not used by this implementation.
   */
  override def createMicroBatchReader(
      schema: Optional[StructType],
      checkpointLocation: String,
      options: DataSourceOptions): MicroBatchReader = new RateStreamMicroBatchReader(options)
}
|
||||
|
||||
/**
 * Micro-batch reader that emits consecutive long values, spread over `numPartitions`
 * partitions, at a target rate of `rowsPerSecond` rows per second in total.
 *
 * Offsets are [[RateStreamOffset]]s that map each partition to the last value emitted and the
 * time (in ms) associated with it.
 *
 * @param options source options; reads "numPartitions" (default 5), "rowsPerSecond" (default 6)
 *                and the test-only "useManualClock" flag
 * @throws IllegalArgumentException if "rowsPerSecond" is not positive
 */
class RateStreamMicroBatchReader(options: DataSourceOptions)
  extends MicroBatchReader {
  implicit val defaultFormats: DefaultFormats = DefaultFormats

  val clock = {
    // The option to use a manual clock is provided only for unit testing purposes.
    if (options.get("useManualClock").orElse("false").toBoolean) new ManualClock
    else new SystemClock
  }

  private val numPartitions =
    options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt
  private val rowsPerSecond =
    options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong

  // Fail with a descriptive error instead of the bare "/ by zero" ArithmeticException the
  // interval computation below would otherwise raise.
  require(rowsPerSecond > 0,
    s"Invalid value '$rowsPerSecond'. The option 'rowsPerSecond' must be positive")

  // The interval (in milliseconds) between rows in each partition.
  // e.g. if there are 4 global rows per second, and 2 partitions, each partition
  // should output rows every (1000 * 2 / 4) = 500 ms.
  // The product is computed as a Long so that a large partition count cannot overflow Int.
  private val msPerPartitionBetweenRows = (1000L * numPartitions) / rowsPerSecond

  override def readSchema(): StructType = {
    StructType(
      StructField("timestamp", TimestampType, false) ::
      StructField("value", LongType, false) :: Nil)
  }

  val creationTimeMs = clock.getTimeMillis()

  private var start: RateStreamOffset = _
  private var end: RateStreamOffset = _

  /**
   * Fixes the offset range of the next batch.
   *
   * A missing `start` defaults to the initial offset at reader creation time; a missing `end`
   * is inferred from the current clock time. The inference only runs when `end` is absent
   * (`Optional.orElse` would evaluate it eagerly even for an explicit end offset).
   */
  override def setOffsetRange(
      start: Optional[Offset],
      end: Optional[Offset]): Unit = {
    this.start =
      if (start.isPresent) start.get.asInstanceOf[RateStreamOffset]
      else RateStreamSourceV2.createInitialOffset(numPartitions, creationTimeMs)

    this.end =
      if (end.isPresent) end.get.asInstanceOf[RateStreamOffset]
      else inferEndOffset(this.start)
  }

  /**
   * Computes, per partition, how far the offset advances between the partition's last read
   * time in `from` and the current clock time.
   */
  private def inferEndOffset(from: RateStreamOffset): RateStreamOffset = {
    val currentTime = clock.getTimeMillis()
    RateStreamOffset(
      from.partitionToValueAndRunTimeMs.map {
        case startOffset @ (part, ValueRunTimeMsPair(currentVal, currentReadTime)) =>
          if (msPerPartitionBetweenRows == 0) {
            // The interval rounds down to 0 ms when rowsPerSecond > 1000 * numPartitions;
            // report that explicitly rather than dividing by zero below.
            throw new IllegalArgumentException(
              s"The option 'rowsPerSecond' ($rowsPerSecond) is too high for " +
                s"'numPartitions' ($numPartitions): the interval between rows in a " +
                "partition rounds down to 0 ms")
          }
          // Calculate the number of rows we should advance in this partition (based on the
          // current time), and output a corresponding offset.
          val readInterval = currentTime - currentReadTime
          val numNewRows = readInterval / msPerPartitionBetweenRows
          if (numNewRows <= 0) {
            startOffset
          } else {
            (part, ValueRunTimeMsPair(
              currentVal + (numNewRows * numPartitions),
              currentReadTime + (numNewRows * msPerPartitionBetweenRows)))
          }
      })
  }

  override def getStartOffset(): Offset = {
    if (start == null) throw new IllegalStateException("start offset not set")
    start
  }

  override def getEndOffset(): Offset = {
    if (end == null) throw new IllegalStateException("end offset not set")
    end
  }

  /** Parses the JSON form of a [[RateStreamOffset]] (partition -> value/run-time pair). */
  override def deserializeOffset(json: String): Offset = {
    RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json))
  }

  /**
   * Creates one task per partition of the end offset, each pre-packed with the
   * (timestamp ms, value) rows lying between the start offset (exclusive) and the end offset
   * (inclusive) of that partition.
   *
   * @throws IllegalStateException if the offset range has not been set yet
   */
  override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = {
    // Same failure mode as getStartOffset/getEndOffset, instead of a NullPointerException.
    if (start == null) throw new IllegalStateException("start offset not set")
    if (end == null) throw new IllegalStateException("end offset not set")

    val startMap = start.partitionToValueAndRunTimeMs
    val endMap = end.partitionToValueAndRunTimeMs
    endMap.keys.toSeq.map { part =>
      val ValueRunTimeMsPair(endVal, _) = endMap(part)
      val ValueRunTimeMsPair(startVal, startTimeMs) = startMap(part)

      // The start value is exclusive: the first row is one stride (numPartitions) past it and
      // carries the start time; each following row is one stride and one interval later.
      // An immutable Vector is handed to the task so it cannot be changed afterwards and
      // supports fast positional access.
      val packedRows = Iterator
        .iterate((startTimeMs, startVal + numPartitions)) { case (timeMs, value) =>
          (timeMs + msPerPartitionBetweenRows, value + numPartitions)
        }
        .takeWhile { case (_, value) => value <= endVal }
        .toVector

      RateStreamBatchTask(packedRows): DataReaderFactory[Row]
    }.toList.asJava
  }

  override def commit(end: Offset): Unit = {}
  override def stop(): Unit = {}
}
|
||||
|
||||
/**
 * Reader factory for one partition of a batch; `vals` holds that partition's rows as
 * (timestamp in ms, value) pairs.
 */
case class RateStreamBatchTask(vals: Seq[(Long, Long)]) extends DataReaderFactory[Row] {
  /** Creates a reader over the pre-packed rows. */
  override def createDataReader(): DataReader[Row] = {
    new RateStreamBatchReader(vals)
  }
}
|
||||
|
||||
/**
 * Reads a pre-computed sequence of (timestamp in ms, value) pairs as rows.
 *
 * @param vals the rows of one partition, in emission order
 */
class RateStreamBatchReader(vals: Seq[(Long, Long)]) extends DataReader[Row] {
  // `vals` may be a linear sequence (e.g. a ListBuffer), where positional access costs O(i) and
  // reading the whole batch would be quadratic. Index it once up front.
  private val rows: IndexedSeq[(Long, Long)] = vals.toIndexedSeq

  // Position of the current row; -1 until the first call to next().
  private var currentIndex = -1

  /** Advances to the next row; returns false once the rows are exhausted. */
  override def next(): Boolean = {
    // Return true as long as the new index is in the seq.
    currentIndex += 1
    currentIndex < rows.size
  }

  /** Returns the current row as (timestamp, value). Only valid after next() returned true. */
  override def get(): Row = {
    val (timeMs, value) = rows(currentIndex)
    Row(DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromMillis(timeMs)), value)
  }

  override def close(): Unit = {}
}
|
||||
|
||||
/** Option names and offset helpers for the v2 rate source. */
object RateStreamSourceV2 {
  val NUM_PARTITIONS = "numPartitions"
  val ROWS_PER_SECOND = "rowsPerSecond"

  /**
   * Builds the offset a freshly created source starts from: one entry per partition index in
   * `[0, numPartitions)`, all stamped with `creationTimeMs`.
   */
  private[sql] def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = {
    val initialEntries = (0 until numPartitions).map { partition =>
      // Note that the starting offset is exclusive, so we have to decrement the starting value
      // by the increment that will later be applied. The first row output in each
      // partition will have a value equal to the partition index.
      val startValue = (partition - numPartitions).toLong
      partition -> ValueRunTimeMsPair(startValue, creationTimeMs)
    }
    RateStreamOffset(initialEntries.toMap)
  }
}
|
|
@ -1,191 +0,0 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.execution.streaming
|
||||
|
||||
import java.util.Optional
|
||||
import java.util.concurrent.TimeUnit
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import org.apache.spark.sql.Row
|
||||
import org.apache.spark.sql.execution.datasources.DataSource
|
||||
import org.apache.spark.sql.execution.streaming.continuous._
|
||||
import org.apache.spark.sql.execution.streaming.sources.{RateStreamBatchTask, RateStreamMicroBatchReader, RateStreamSourceV2}
|
||||
import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, MicroBatchReadSupport}
|
||||
import org.apache.spark.sql.sources.v2.DataSourceOptions
|
||||
import org.apache.spark.sql.streaming.StreamTest
|
||||
import org.apache.spark.util.ManualClock
|
||||
|
||||
/** Tests for the v2 rate source: registry lookup, the micro-batch reader and the continuous reader. */
class RateSourceV2Suite extends StreamTest {
  import testImplicits._

  /**
   * Stream action that advances the rate reader's manual clock by `seconds`, lets the reader
   * infer a new offset range, and reports the reader together with its new end offset.
   */
  case class AdvanceRateManualClock(seconds: Long) extends AddData {
    override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = {
      assert(query.nonEmpty)
      val rateReaders = query.get.logicalPlan.collect {
        case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source
      }
      val rateReader = rateReaders.head
      val advanceMs = TimeUnit.SECONDS.toMillis(seconds)
      rateReader.clock.asInstanceOf[ManualClock].advance(advanceMs)
      rateReader.setOffsetRange(Optional.empty(), Optional.empty())
      (rateReader, rateReader.getEndOffset())
    }
  }

  test("microbatch in registry") {
    val provider = DataSource.lookupDataSource("ratev2", spark.sqlContext.conf).newInstance()
    provider match {
      case readSupport: MicroBatchReadSupport =>
        val created =
          readSupport.createMicroBatchReader(Optional.empty(), "", DataSourceOptions.empty())
        assert(created.isInstanceOf[RateStreamMicroBatchReader])
      case _ =>
        throw new IllegalStateException("Could not find v2 read support for rate")
    }
  }

  test("basic microbatch execution") {
    val input = spark.readStream
      .format("rateV2")
      .option("numPartitions", "1")
      .option("rowsPerSecond", "10")
      .option("useManualClock", "true")
      .load()

    // With 10 rows per second in one partition, value v is stamped v * 100 ms.
    def expectedRows(values: Range) = values.map(v => new java.sql.Timestamp(v * 100L) -> v)

    testStream(input, useV2Sink = true)(
      AdvanceRateManualClock(seconds = 1),
      CheckLastBatch(expectedRows(0 until 10): _*),
      StopStream,
      StartStream(),
      // Advance 2 seconds because creating a new RateSource will also create a new ManualClock
      AdvanceRateManualClock(seconds = 2),
      CheckLastBatch(expectedRows(10 until 20): _*)
    )
  }

  test("microbatch - numPartitions propagated") {
    val sourceOptions =
      new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)
    val reader = new RateStreamMicroBatchReader(sourceOptions)
    reader.setOffsetRange(Optional.empty(), Optional.empty())
    assert(reader.createDataReaderFactories().size == 11)
  }

  test("microbatch - set offset") {
    // Explicitly supplied offsets must be returned unchanged by the getters.
    val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty())
    val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000))))
    val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 2000))))
    reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset))
    assert(reader.getStartOffset() == startOffset)
    assert(reader.getEndOffset() == endOffset)
  }

  test("microbatch - infer offsets") {
    val sourceOptions =
      new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "100").asJava)
    val reader = new RateStreamMicroBatchReader(sourceOptions)
    val wakeUpAt = reader.clock.getTimeMillis() + 100
    reader.clock.waitTillTime(wakeUpAt)
    reader.setOffsetRange(Optional.empty(), Optional.empty())

    reader.getStartOffset() match {
      case inferredStart: RateStreamOffset =>
        assert(inferredStart.partitionToValueAndRunTimeMs(0).runTimeMs == reader.creationTimeMs)
      case _ => throw new IllegalStateException("unexpected offset type")
    }

    reader.getEndOffset() match {
      case inferredEnd: RateStreamOffset =>
        // End offset may be a bit beyond 100 ms/9 rows after creation if the wait lasted
        // longer than 100ms. It should never be early.
        val partitionZero = inferredEnd.partitionToValueAndRunTimeMs(0)
        assert(partitionZero.value >= 9)
        assert(partitionZero.runTimeMs >= reader.creationTimeMs + 100)
      case _ => throw new IllegalStateException("unexpected offset type")
    }
  }

  test("microbatch - predetermined batch size") {
    val sourceOptions =
      new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava)
    val reader = new RateStreamMicroBatchReader(sourceOptions)
    val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000))))
    val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(20, 2000))))
    reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset))

    val tasks = reader.createDataReaderFactories()
    assert(tasks.size == 1)
    val onlyTask = tasks.get(0).asInstanceOf[RateStreamBatchTask]
    assert(onlyTask.vals.size == 20)
  }

  test("microbatch - data read") {
    val sourceOptions =
      new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)
    val reader = new RateStreamMicroBatchReader(sourceOptions)
    val startOffset = RateStreamSourceV2.createInitialOffset(11, reader.creationTimeMs)
    // Move every partition forward by 33 values and one second.
    val advancedEntries = startOffset.partitionToValueAndRunTimeMs.toSeq.map {
      case (part, ValueRunTimeMsPair(currentVal, currentReadTime)) =>
        (part, ValueRunTimeMsPair(currentVal + 33, currentReadTime + 1000))
    }
    val endOffset = RateStreamOffset(advancedEntries.toMap)

    reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset))
    val tasks = reader.createDataReaderFactories()
    assert(tasks.size == 11)

    val readData = tasks.asScala
      .map(_.createDataReader())
      .flatMap { partitionReader =>
        val rows = scala.collection.mutable.ListBuffer[Row]()
        while (partitionReader.next()) {
          rows.append(partitionReader.get())
        }
        rows
      }

    assert(readData.map(_.getLong(1)).sorted == Range(0, 33))
  }

  test("continuous in registry") {
    val provider = DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance()
    provider match {
      case readSupport: ContinuousReadSupport =>
        val created =
          readSupport.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty())
        assert(created.isInstanceOf[RateStreamContinuousReader])
      case _ =>
        throw new IllegalStateException("Could not find v2 read support for rate")
    }
  }

  test("continuous data") {
    val sourceOptions =
      new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)
    val reader = new RateStreamContinuousReader(sourceOptions)
    reader.setStartOffset(Optional.empty())
    val tasks = reader.createDataReaderFactories()
    assert(tasks.size == 2)

    val data = scala.collection.mutable.ListBuffer[Row]()
    tasks.asScala.foreach {
      case t: RateStreamContinuousDataReaderFactory =>
        val startTimeMs = reader.getStartOffset()
          .asInstanceOf[RateStreamOffset]
          .partitionToValueAndRunTimeMs(t.partitionIndex)
          .runTimeMs
        val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader]
        // Read ten rows from this partition and check the offset reported after each one.
        (0 to 9).foreach { rowIndex =>
          r.next()
          data.append(r.get())
          val expectedOffset = RateStreamPartitionOffset(
            t.partitionIndex,
            t.partitionIndex + rowIndex * 2,
            startTimeMs + (rowIndex + 1) * 100)
          assert(r.getOffset() == expectedOffset)
        }
        assert(System.currentTimeMillis() >= startTimeMs + 1000)

      case _ => throw new IllegalStateException("Unexpected task type")
    }

    assert(data.map(_.getLong(1)).toSeq.sorted == Range(0, 20))
  }
}
|
|
@ -15,13 +15,24 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.execution.streaming
|
||||
package org.apache.spark.sql.execution.streaming.sources
|
||||
|
||||
import java.nio.file.Files
|
||||
import java.util.Optional
|
||||
import java.util.concurrent.TimeUnit
|
||||
|
||||
import org.apache.spark.sql.AnalysisException
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
|
||||
import org.apache.spark.sql.catalyst.errors.TreeNodeException
|
||||
import org.apache.spark.sql.execution.datasources.DataSource
|
||||
import org.apache.spark.sql.execution.streaming._
|
||||
import org.apache.spark.sql.execution.streaming.continuous._
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest}
|
||||
import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport}
|
||||
import org.apache.spark.sql.sources.v2.reader.streaming.Offset
|
||||
import org.apache.spark.sql.streaming.StreamTest
|
||||
import org.apache.spark.util.ManualClock
|
||||
|
||||
class RateSourceSuite extends StreamTest {
|
||||
|
@ -29,18 +40,40 @@ class RateSourceSuite extends StreamTest {
|
|||
import testImplicits._
|
||||
|
||||
case class AdvanceRateManualClock(seconds: Long) extends AddData {
|
||||
override def addData(query: Option[StreamExecution]): (Source, Offset) = {
|
||||
override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = {
|
||||
assert(query.nonEmpty)
|
||||
val rateSource = query.get.logicalPlan.collect {
|
||||
case StreamingExecutionRelation(source, _) if source.isInstanceOf[RateStreamSource] =>
|
||||
source.asInstanceOf[RateStreamSource]
|
||||
case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source
|
||||
}.head
|
||||
|
||||
rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds))
|
||||
(rateSource, rateSource.getOffset.get)
|
||||
val offset = LongOffset(TimeUnit.MILLISECONDS.toSeconds(
|
||||
rateSource.clock.getTimeMillis() - rateSource.creationTimeMs))
|
||||
(rateSource, offset)
|
||||
}
|
||||
}
|
||||
|
||||
test("basic") {
|
||||
test("microbatch in registry") {
  // The "rate" alias must resolve to a provider that can create the micro-batch reader.
  val provider = DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance()
  provider match {
    case readSupport: MicroBatchReadSupport =>
      val created =
        readSupport.createMicroBatchReader(Optional.empty(), "dummy", DataSourceOptions.empty())
      assert(created.isInstanceOf[RateStreamMicroBatchReader])
    case _ =>
      throw new IllegalStateException("Could not find read support for rate")
  }
}
|
||||
|
||||
test("compatible with old path in registry") {
  // The fully qualified name of the former provider class must still resolve to the
  // current provider.
  val legacyClassName = "org.apache.spark.sql.execution.streaming.RateSourceProvider"
  val provider = DataSource.lookupDataSource(legacyClassName, spark.sqlContext.conf).newInstance()
  provider match {
    case readSupport: MicroBatchReadSupport =>
      assert(readSupport.isInstanceOf[RateStreamProvider])
    case _ =>
      throw new IllegalStateException("Could not find read support for rate")
  }
}
|
||||
|
||||
test("microbatch - basic") {
|
||||
val input = spark.readStream
|
||||
.format("rate")
|
||||
.option("rowsPerSecond", "10")
|
||||
|
@ -57,7 +90,7 @@ class RateSourceSuite extends StreamTest {
|
|||
)
|
||||
}
|
||||
|
||||
test("uniform distribution of event timestamps") {
|
||||
test("microbatch - uniform distribution of event timestamps") {
|
||||
val input = spark.readStream
|
||||
.format("rate")
|
||||
.option("rowsPerSecond", "1500")
|
||||
|
@ -74,8 +107,74 @@ class RateSourceSuite extends StreamTest {
|
|||
)
|
||||
}
|
||||
|
||||
test("microbatch - set offset") {
  // Explicitly supplied offsets must be returned unchanged by the getters.
  val checkpointDir = Files.createTempDirectory("dummy").toString
  val microBatchReader = new RateStreamMicroBatchReader(DataSourceOptions.empty(), checkpointDir)
  val from = LongOffset(0L)
  val to = LongOffset(1L)
  microBatchReader.setOffsetRange(Optional.of(from), Optional.of(to))
  assert(microBatchReader.getStartOffset() == from)
  assert(microBatchReader.getEndOffset() == to)
}
|
||||
|
||||
test("microbatch - infer offsets") {
  val checkpointDir = Files.createTempDirectory("dummy").toString
  val sourceOptions = new DataSourceOptions(
    Map("numPartitions" -> "1", "rowsPerSecond" -> "100", "useManualClock" -> "true").asJava)
  val reader = new RateStreamMicroBatchReader(sourceOptions, checkpointDir)

  // Advance the manual clock by 100000 ms before letting the reader infer both offsets.
  reader.clock.asInstanceOf[ManualClock].advance(100000)
  reader.setOffsetRange(Optional.empty(), Optional.empty())

  reader.getStartOffset() match {
    case inferredStart: LongOffset =>
      assert(inferredStart.offset === 0L)
    case _ =>
      throw new IllegalStateException("unexpected offset type")
  }
  reader.getEndOffset() match {
    case inferredEnd: LongOffset =>
      assert(inferredEnd.offset >= 100)
    case _ =>
      throw new IllegalStateException("unexpected offset type")
  }
}
|
||||
|
||||
test("microbatch - predetermined batch size") {
  val checkpointDir = Files.createTempDirectory("dummy").toString
  val sourceOptions =
    new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava)
  val reader = new RateStreamMicroBatchReader(sourceOptions, checkpointDir)

  // One offset step at 20 rows per second must yield exactly 20 rows in a single task.
  val from = LongOffset(0L)
  val to = LongOffset(1L)
  reader.setOffsetRange(Optional.of(from), Optional.of(to))
  val tasks = reader.createDataReaderFactories()
  assert(tasks.size == 1)

  val partitionReader = tasks.get(0).createDataReader()
  val rows = ArrayBuffer[Row]()
  while (partitionReader.next()) {
    rows.append(partitionReader.get())
  }
  assert(rows.size === 20)
}
|
||||
|
||||
test("microbatch - data read") {
  val checkpointDir = Files.createTempDirectory("dummy").toString
  val sourceOptions =
    new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)
  val reader = new RateStreamMicroBatchReader(sourceOptions, checkpointDir)

  val from = LongOffset(0L)
  val to = LongOffset(1L)
  reader.setOffsetRange(Optional.of(from), Optional.of(to))
  val tasks = reader.createDataReaderFactories()
  assert(tasks.size == 11)

  // Drain every partition; together they must contain the values 0 to 32.
  val readData = tasks.asScala
    .map(_.createDataReader())
    .flatMap { partitionReader =>
      val rows = scala.collection.mutable.ListBuffer[Row]()
      while (partitionReader.next()) {
        rows.append(partitionReader.get())
      }
      rows
    }

  assert(readData.map(_.getLong(1)).sorted == Range(0, 33))
}
|
||||
|
||||
test("valueAtSecond") {
|
||||
import RateStreamSource._
|
||||
import RateStreamProvider._
|
||||
|
||||
assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 0)
|
||||
assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 5)
|
||||
|
@ -161,7 +260,7 @@ class RateSourceSuite extends StreamTest {
|
|||
option: String,
|
||||
value: String,
|
||||
expectedMessages: Seq[String]): Unit = {
|
||||
val e = intercept[StreamingQueryException] {
|
||||
val e = intercept[IllegalArgumentException] {
|
||||
spark.readStream
|
||||
.format("rate")
|
||||
.option(option, value)
|
||||
|
@ -171,9 +270,8 @@ class RateSourceSuite extends StreamTest {
|
|||
.start()
|
||||
.awaitTermination()
|
||||
}
|
||||
assert(e.getCause.isInstanceOf[IllegalArgumentException])
|
||||
for (msg <- expectedMessages) {
|
||||
assert(e.getCause.getMessage.contains(msg))
|
||||
assert(e.getMessage.contains(msg))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -191,4 +289,46 @@ class RateSourceSuite extends StreamTest {
|
|||
assert(exception.getMessage.contains(
|
||||
"rate source does not support a user-specified schema"))
|
||||
}
|
||||
|
||||
test("continuous in registry") {
  // The "rate" alias must resolve to a provider that can create the continuous reader.
  val provider = DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance()
  provider match {
    case readSupport: ContinuousReadSupport =>
      val created =
        readSupport.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty())
      assert(created.isInstanceOf[RateStreamContinuousReader])
    case _ =>
      throw new IllegalStateException("Could not find read support for continuous rate")
  }
}
|
||||
|
||||
test("continuous data") {
  val sourceOptions =
    new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)
  val reader = new RateStreamContinuousReader(sourceOptions)
  reader.setStartOffset(Optional.empty())
  val tasks = reader.createDataReaderFactories()
  assert(tasks.size == 2)

  val data = scala.collection.mutable.ListBuffer[Row]()
  tasks.asScala.foreach {
    case t: RateStreamContinuousDataReaderFactory =>
      val startTimeMs = reader.getStartOffset()
        .asInstanceOf[RateStreamOffset]
        .partitionToValueAndRunTimeMs(t.partitionIndex)
        .runTimeMs
      val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader]
      // Read ten rows from this partition and check the offset reported after each one.
      (0 to 9).foreach { rowIndex =>
        r.next()
        data.append(r.get())
        val expectedOffset = RateStreamPartitionOffset(
          t.partitionIndex,
          t.partitionIndex + rowIndex * 2,
          startTimeMs + (rowIndex + 1) * 100)
        assert(r.getOffset() == expectedOffset)
      }
      assert(System.currentTimeMillis() >= startTimeMs + 1000)

    case _ => throw new IllegalStateException("Unexpected task type")
  }

  assert(data.map(_.getLong(1)).toSeq.sorted == Range(0, 20))
}
|
||||
}
|
Loading…
Reference in a new issue