[SPARK-23092][SQL] Migrate MemoryStream to DataSourceV2 APIs

## What changes were proposed in this pull request?

This PR migrates the MemoryStream to DataSourceV2 APIs.

One additional change is to the keys reported in `StreamingQueryProgress.durationMs`: "getOffset" and "getBatch" are replaced with "setOffsetRange" and "getEndOffset", since those are the operations actually timed in the v2 read path. Unit tests are updated accordingly.
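
For illustration, the renamed keys can be observed as follows — a minimal sketch assuming a local SparkSession (exact timings will vary):

```scala
import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.execution.streaming.MemoryStream

val spark = SparkSession.builder().master("local[2]").appName("durationMs-demo").getOrCreate()
import spark.implicits._
implicit val sqlContext: SQLContext = spark.sqlContext

val input = MemoryStream[Int]
val query = input.toDS().writeStream.format("memory").queryName("demo").start()
input.addData(1, 2, 3)
query.processAllAvailable()
// v1 sources reported "getOffset"/"getBatch"; the v2 path reports these instead:
println(query.lastProgress.durationMs.get("setOffsetRange"))
println(query.lastProgress.durationMs.get("getEndOffset"))
query.stop()
spark.stop()
```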

## How was this patch tested?
Existing unit tests, plus a few updated unit tests.

Author: Tathagata Das <tathagata.das1565@gmail.com>
Author: Burak Yavuz <brkyvz@gmail.com>

Closes #20445 from tdas/SPARK-23092.
Committed by Tathagata Das on 2018-02-07 15:22:53 -08:00
parent 9841ae0313
commit 30295bf5a6
9 changed files with 171 additions and 134 deletions


@@ -17,10 +17,12 @@
package org.apache.spark.sql.execution.streaming
import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2}
/**
* A simple offset for sources that produce a single linear stream of data.
*/
case class LongOffset(offset: Long) extends Offset {
case class LongOffset(offset: Long) extends OffsetV2 {
override val json = offset.toString
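
As a sanity check of the new supertype, the JSON round-trip the offset log relies on looks like this (a sketch; MemoryStream's `deserializeOffset` below does exactly this for LongOffset):

```scala
import org.apache.spark.sql.execution.streaming.LongOffset

val off = LongOffset(42L)
assert(off.json == "42")                    // an OffsetV2 serializes itself as JSON
assert(LongOffset(off.json.toLong) == off)  // and is rebuilt from that JSON on recovery
```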


@@ -270,16 +270,17 @@ class MicroBatchExecution(
}
case s: MicroBatchReader =>
updateStatusMessage(s"Getting offsets from $s")
reportTimeTaken("getOffset") {
// Once v1 streaming source execution is gone, we can refactor this away.
// For now, we set the range here to get the source to infer the available end offset,
// get that offset, and then set the range again when we later execute.
s.setOffsetRange(
toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))),
Optional.empty())
(s, Some(s.getEndOffset))
reportTimeTaken("setOffsetRange") {
// Once v1 streaming source execution is gone, we can refactor this away.
// For now, we set the range here to get the source to infer the available end offset,
// get that offset, and then set the range again when we later execute.
s.setOffsetRange(
toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))),
Optional.empty())
}
val currentOffset = reportTimeTaken("getEndOffset") { s.getEndOffset() }
(s, Option(currentOffset))
}.toMap
availableOffsets ++= latestOffsets.filter { case (_, o) => o.nonEmpty }.mapValues(_.get)
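
For reviewers new to the v2 reader API, the two-phase protocol this hunk now times separately can be sketched as follows (a hypothetical helper, not the literal engine code):

```scala
import java.util.Optional
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2}

// Phase 1: setOffsetRange with an empty end asks the source to infer the latest
// available offset. Phase 2: getEndOffset reads it back. Each phase now gets its
// own durationMs entry ("setOffsetRange" and "getEndOffset").
def fetchLatestOffset(reader: MicroBatchReader, start: Optional[OffsetV2]): Option[OffsetV2] = {
  reader.setOffsetRange(start, Optional.empty())
  Option(reader.getEndOffset())
}
```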
@@ -401,10 +402,14 @@ class MicroBatchExecution(
case (reader: MicroBatchReader, available)
if committedOffsets.get(reader).map(_ != available).getOrElse(true) =>
val current = committedOffsets.get(reader).map(off => reader.deserializeOffset(off.json))
val availableV2: OffsetV2 = available match {
case v1: SerializedOffset => reader.deserializeOffset(v1.json)
case v2: OffsetV2 => v2
}
reader.setOffsetRange(
toJava(current),
Optional.of(available.asInstanceOf[OffsetV2]))
logDebug(s"Retrieving data from $reader: $current -> $available")
Optional.of(availableV2))
logDebug(s"Retrieving data from $reader: $current -> $availableV2")
Some(reader ->
new StreamingDataSourceV2Relation(reader.readSchema().toAttributes, reader))
case _ => None
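
The new `availableV2` conversion exists because offsets restored from the offset log arrive as `SerializedOffset` (raw JSON) rather than the reader's concrete offset class. The same pattern as a standalone sketch:

```scala
import org.apache.spark.sql.execution.streaming.{Offset, SerializedOffset}
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2}

// An offset recovered from the WAL is only JSON; the reader itself knows how to
// turn that JSON back into its concrete OffsetV2 type.
def toV2(reader: MicroBatchReader, available: Offset): OffsetV2 = available match {
  case v1: SerializedOffset => reader.deserializeOffset(v1.json)
  case v2: OffsetV2         => v2
}
```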


@@ -17,21 +17,23 @@
package org.apache.spark.sql.execution.streaming
import java.{util => ju}
import java.util.Optional
import java.util.concurrent.atomic.AtomicInteger
import javax.annotation.concurrent.GuardedBy
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, ListBuffer}
import scala.util.control.NonFatal
import org.apache.spark.internal.Logging
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, Statistics}
import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow}
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils
@@ -51,9 +53,10 @@ object MemoryStream {
* available.
*/
case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
extends Source with Logging {
extends MicroBatchReader with SupportsScanUnsafeRow with Logging {
protected val encoder = encoderFor[A]
protected val logicalPlan = StreamingExecutionRelation(this, sqlContext.sparkSession)
private val attributes = encoder.schema.toAttributes
protected val logicalPlan = StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession)
protected val output = logicalPlan.output
/**
@@ -61,11 +64,17 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
* Stored in a ListBuffer to facilitate removing committed batches.
*/
@GuardedBy("this")
protected val batches = new ListBuffer[Dataset[A]]
protected val batches = new ListBuffer[Array[UnsafeRow]]
@GuardedBy("this")
protected var currentOffset: LongOffset = new LongOffset(-1)
@GuardedBy("this")
private var startOffset = new LongOffset(-1)
@GuardedBy("this")
private var endOffset = new LongOffset(-1)
/**
* Last offset that was discarded, or -1 if no commits have occurred. Note that the value
* -1 is used in calculations below and isn't just an arbitrary constant.
@@ -73,8 +82,6 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
@GuardedBy("this")
protected var lastOffsetCommitted : LongOffset = new LongOffset(-1)
def schema: StructType = encoder.schema
def toDS(): Dataset[A] = {
Dataset(sqlContext.sparkSession, logicalPlan)
}
@@ -88,72 +95,69 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
}
def addData(data: TraversableOnce[A]): Offset = {
val encoded = data.toVector.map(d => encoder.toRow(d).copy())
val plan = new LocalRelation(schema.toAttributes, encoded, isStreaming = true)
val ds = Dataset[A](sqlContext.sparkSession, plan)
logDebug(s"Adding ds: $ds")
val objects = data.toSeq
val rows = objects.iterator.map(d => encoder.toRow(d).copy().asInstanceOf[UnsafeRow]).toArray
logDebug(s"Adding: $objects")
this.synchronized {
currentOffset = currentOffset + 1
batches += ds
batches += rows
currentOffset
}
}
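
`addData` now eagerly encodes elements into copied `UnsafeRow`s; the copy is required because the encoder reuses its row buffer across `toRow` calls. A standalone sketch of the round-trip (assumes a SparkSession named `spark` for the implicit `Encoder[String]`):

```scala
import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import spark.implicits._  // assumed: provides Encoder[String]

val enc = encoderFor[String]
val rows: Array[UnsafeRow] =
  Seq("a", "b").iterator.map(s => enc.toRow(s).copy().asInstanceOf[UnsafeRow]).toArray

// generateDebugString below does the inverse with a resolved-and-bound encoder:
val decode = enc.resolveAndBind().fromRow _
assert(rows.map(decode).toSeq == Seq("a", "b"))
```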
override def toString: String = s"MemoryStream[${Utils.truncatedString(output, ",")}]"
override def getOffset: Option[Offset] = synchronized {
if (currentOffset.offset == -1) {
None
} else {
Some(currentOffset)
override def setOffsetRange(start: Optional[OffsetV2], end: Optional[OffsetV2]): Unit = {
synchronized {
startOffset = start.orElse(LongOffset(-1)).asInstanceOf[LongOffset]
endOffset = end.orElse(currentOffset).asInstanceOf[LongOffset]
}
}
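
The defaulting above means an empty start reads from the beginning and an empty end reads everything added so far:

```scala
// Behavior sketch with hypothetical state: after three addData calls,
// currentOffset == LongOffset(2), so
//   setOffsetRange(Optional.empty(), Optional.empty())
// leaves startOffset == LongOffset(-1) and endOffset == LongOffset(2),
// i.e. "read every batch added so far".
```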
override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
// Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal)
val startOrdinal =
start.flatMap(LongOffset.convert).getOrElse(LongOffset(-1)).offset.toInt + 1
val endOrdinal = LongOffset.convert(end).getOrElse(LongOffset(-1)).offset.toInt + 1
override def readSchema(): StructType = encoder.schema
// Internal buffer only holds the batches after lastCommittedOffset.
val newBlocks = synchronized {
val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1
val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1
assert(sliceStart <= sliceEnd, s"sliceStart: $sliceStart sliceEnd: $sliceEnd")
batches.slice(sliceStart, sliceEnd)
}
override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong)
if (newBlocks.isEmpty) {
return sqlContext.internalCreateDataFrame(
sqlContext.sparkContext.emptyRDD, schema, isStreaming = true)
}
override def getStartOffset: OffsetV2 = synchronized {
if (startOffset.offset == -1) null else startOffset
}
logDebug(generateDebugString(newBlocks, startOrdinal, endOrdinal))
override def getEndOffset: OffsetV2 = synchronized {
if (endOffset.offset == -1) null else endOffset
}
newBlocks
.map(_.toDF())
.reduceOption(_ union _)
.getOrElse {
sys.error("No data selected!")
override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = {
synchronized {
// Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal)
val startOrdinal = startOffset.offset.toInt + 1
val endOrdinal = endOffset.offset.toInt + 1
// Internal buffer only holds the batches after lastCommittedOffset.
val newBlocks = synchronized {
val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1
val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1
assert(sliceStart <= sliceEnd, s"sliceStart: $sliceStart sliceEnd: $sliceEnd")
batches.slice(sliceStart, sliceEnd)
}
logDebug(generateDebugString(newBlocks.flatten, startOrdinal, endOrdinal))
newBlocks.map { block =>
new MemoryStreamDataReaderFactory(block).asInstanceOf[DataReaderFactory[UnsafeRow]]
}.asJava
}
}
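
A worked example of the ordinal arithmetic above, with hypothetical values:

```scala
// Hypothetical state: lastOffsetCommitted = 1, startOffset = 1, endOffset = 3.
val lastCommitted = 1
val startOrdinal  = 1 + 1                            // startOffset.offset + 1 == 2
val endOrdinal    = 3 + 1                            // endOffset.offset + 1   == 4
val sliceStart    = startOrdinal - lastCommitted - 1 // == 0
val sliceEnd      = endOrdinal - lastCommitted - 1   // == 2
// batches.slice(0, 2) selects the two uncommitted batches, i.e. offsets 2 and 3.
```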
private def generateDebugString(
blocks: TraversableOnce[Dataset[A]],
rows: Seq[UnsafeRow],
startOrdinal: Int,
endOrdinal: Int): String = {
val originalUnsupportedCheck =
sqlContext.getConf("spark.sql.streaming.unsupportedOperationCheck")
try {
sqlContext.setConf("spark.sql.streaming.unsupportedOperationCheck", "false")
s"MemoryBatch [$startOrdinal, $endOrdinal]: " +
s"${blocks.flatMap(_.collect()).mkString(", ")}"
} finally {
sqlContext.setConf("spark.sql.streaming.unsupportedOperationCheck", originalUnsupportedCheck)
}
val fromRow = encoder.resolveAndBind().fromRow _
s"MemoryBatch [$startOrdinal, $endOrdinal]: " +
s"${rows.map(row => fromRow(row)).mkString(", ")}"
}
override def commit(end: Offset): Unit = synchronized {
override def commit(end: OffsetV2): Unit = synchronized {
def check(newOffset: LongOffset): Unit = {
val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt
@@ -176,11 +180,33 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
def reset(): Unit = synchronized {
batches.clear()
startOffset = LongOffset(-1)
endOffset = LongOffset(-1)
currentOffset = new LongOffset(-1)
lastOffsetCommitted = new LongOffset(-1)
}
}
class MemoryStreamDataReaderFactory(records: Array[UnsafeRow])
extends DataReaderFactory[UnsafeRow] {
override def createDataReader(): DataReader[UnsafeRow] = {
new DataReader[UnsafeRow] {
private var currentIndex = -1
override def next(): Boolean = {
// Return true as long as the new index is in the array.
currentIndex += 1
currentIndex < records.length
}
override def get(): UnsafeRow = records(currentIndex)
override def close(): Unit = {}
}
}
}
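
On the executor side, the engine consumes such a factory roughly as follows (a sketch of the DataReader contract, not actual engine code):

```scala
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory}

def drain(factory: DataReaderFactory[UnsafeRow]): Seq[UnsafeRow] = {
  val reader: DataReader[UnsafeRow] = factory.createDataReader()
  val out = Seq.newBuilder[UnsafeRow]
  try {
    while (reader.next()) out += reader.get()  // next() advances; get() returns the current row
  } finally {
    reader.close()
  }
  out.result()
}
```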
/**
* A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit
* tests and does not provide durability.


@@ -151,7 +151,7 @@ case class RateStreamBatchTask(vals: Seq[(Long, Long)]) extends DataReaderFactory[Row] {
}
class RateStreamBatchReader(vals: Seq[(Long, Long)]) extends DataReader[Row] {
var currentIndex = -1
private var currentIndex = -1
override def next(): Boolean = {
// Return true as long as the new index is in the seq.


@@ -46,49 +46,34 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAfter {
.foreach(new TestForeachWriter())
.start()
def verifyOutput(expectedVersion: Int, expectedData: Seq[Int]): Unit = {
import ForeachSinkSuite._
val events = ForeachSinkSuite.allEvents()
assert(events.size === 2) // one seq of events for each of the 2 partitions
// Verify both seq of events have an Open event as the first event
assert(events.map(_.head).toSet === Set(0, 1).map(p => Open(p, expectedVersion)))
// Verify all the Process event correspond to the expected data
val allProcessEvents = events.flatMap(_.filter(_.isInstanceOf[Process[_]]))
assert(allProcessEvents.toSet === expectedData.map { data => Process(data) }.toSet)
// Verify both seq of events have a Close event as the last event
assert(events.map(_.last).toSet === Set(Close(None), Close(None)))
}
// -- batch 0 ---------------------------------------
ForeachSinkSuite.clear()
input.addData(1, 2, 3, 4)
query.processAllAvailable()
var expectedEventsForPartition0 = Seq(
ForeachSinkSuite.Open(partition = 0, version = 0),
ForeachSinkSuite.Process(value = 2),
ForeachSinkSuite.Process(value = 3),
ForeachSinkSuite.Close(None)
)
var expectedEventsForPartition1 = Seq(
ForeachSinkSuite.Open(partition = 1, version = 0),
ForeachSinkSuite.Process(value = 1),
ForeachSinkSuite.Process(value = 4),
ForeachSinkSuite.Close(None)
)
var allEvents = ForeachSinkSuite.allEvents()
assert(allEvents.size === 2)
assert(allEvents.toSet === Set(expectedEventsForPartition0, expectedEventsForPartition1))
ForeachSinkSuite.clear()
verifyOutput(expectedVersion = 0, expectedData = 1 to 4)
// -- batch 1 ---------------------------------------
ForeachSinkSuite.clear()
input.addData(5, 6, 7, 8)
query.processAllAvailable()
expectedEventsForPartition0 = Seq(
ForeachSinkSuite.Open(partition = 0, version = 1),
ForeachSinkSuite.Process(value = 5),
ForeachSinkSuite.Process(value = 7),
ForeachSinkSuite.Close(None)
)
expectedEventsForPartition1 = Seq(
ForeachSinkSuite.Open(partition = 1, version = 1),
ForeachSinkSuite.Process(value = 6),
ForeachSinkSuite.Process(value = 8),
ForeachSinkSuite.Close(None)
)
allEvents = ForeachSinkSuite.allEvents()
assert(allEvents.size === 2)
assert(allEvents.toSet === Set(expectedEventsForPartition0, expectedEventsForPartition1))
verifyOutput(expectedVersion = 1, expectedData = 5 to 8)
query.stop()
}
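
The refactored `verifyOutput` checks the per-partition event sequence guaranteed by the `ForeachWriter` contract: one `open(partitionId, version)`, then `process(value)` for each row, then `close(errorOrNull)`. A minimal writer for reference (the suite's `TestForeachWriter` additionally records these calls as `Open`/`Process`/`Close` events):

```scala
import org.apache.spark.sql.ForeachWriter

class LoggingWriter extends ForeachWriter[Int] {
  override def open(partitionId: Long, version: Long): Boolean = {
    println(s"open($partitionId, $version)")
    true  // returning false would skip this partition for this epoch
  }
  override def process(value: Int): Unit = println(s"process($value)")
  override def close(errorOrNull: Throwable): Unit = println(s"close($errorOrNull)")
}
```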


@@ -492,16 +492,16 @@ class StreamSuite extends StreamTest {
val explainWithoutExtended = q.explainInternal(false)
// `extended = false` only displays the physical plan.
assert("LocalRelation".r.findAllMatchIn(explainWithoutExtended).size === 0)
assert("LocalTableScan".r.findAllMatchIn(explainWithoutExtended).size === 1)
assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithoutExtended).size === 0)
assert("DataSourceV2Scan".r.findAllMatchIn(explainWithoutExtended).size === 1)
// Use "StateStoreRestore" to verify that it does output a streaming physical plan
assert(explainWithoutExtended.contains("StateStoreRestore"))
val explainWithExtended = q.explainInternal(true)
// `extended = true` displays 3 logical plans (Parsed/Analyzed/Optimized) and 1 physical
// plan.
assert("LocalRelation".r.findAllMatchIn(explainWithExtended).size === 3)
assert("LocalTableScan".r.findAllMatchIn(explainWithExtended).size === 1)
assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithExtended).size === 3)
assert("DataSourceV2Scan".r.findAllMatchIn(explainWithExtended).size === 1)
// Use "StateStoreRestore" to verify that it does output a streaming physical plan
assert(explainWithExtended.contains("StateStoreRestore"))
} finally {


@@ -120,7 +120,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with BeforeAndAfterAll {
case class AddDataMemory[A](source: MemoryStream[A], data: Seq[A]) extends AddData {
override def toString: String = s"AddData to $source: ${data.mkString(",")}"
override def addData(query: Option[StreamExecution]): (Source, Offset) = {
override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = {
(source, source.addData(data))
}
}


@@ -33,6 +33,7 @@ import org.apache.spark.scheduler._
import org.apache.spark.sql.{Encoder, SparkSession}
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2}
import org.apache.spark.sql.streaming.StreamingQueryListener._
import org.apache.spark.sql.streaming.util.StreamManualClock
import org.apache.spark.util.JsonProtocol
@@ -298,9 +299,9 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
try {
val input = new MemoryStream[Int](0, sqlContext) {
@volatile var numTriggers = 0
override def getOffset: Option[Offset] = {
override def getEndOffset: OffsetV2 = {
numTriggers += 1
super.getOffset
super.getEndOffset
}
}
val clock = new StreamManualClock()


@@ -17,25 +17,27 @@
package org.apache.spark.sql.streaming
import java.{util => ju}
import java.util.Optional
import java.util.concurrent.CountDownLatch
import org.apache.commons.lang3.RandomStringUtils
import org.mockito.Mockito._
import org.scalactic.TolerantNumerics
import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.Eventually._
import org.scalatest.concurrent.PatienceConfiguration.Timeout
import org.scalatest.mockito.MockitoSugar
import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.v2.reader.DataReaderFactory
import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2}
import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider, StreamManualClock}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.ManualClock
class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging with MockitoSugar {
@@ -206,19 +208,29 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging with MockitoSugar {
/** Custom MemoryStream that waits for manual clock to reach a time */
val inputData = new MemoryStream[Int](0, sqlContext) {
// getOffset should take 50 ms the first time it is called
override def getOffset: Option[Offset] = {
val offset = super.getOffset
if (offset.nonEmpty) {
clock.waitTillTime(1050)
private def dataAdded: Boolean = currentOffset.offset != -1
// setOffsetRange should take 50 ms the first time it is called after data is added
override def setOffsetRange(start: Optional[OffsetV2], end: Optional[OffsetV2]): Unit = {
synchronized {
if (dataAdded) clock.waitTillTime(1050)
super.setOffsetRange(start, end)
}
offset
}
// getEndOffset should take 100 ms the first time it is called after data is added
override def getEndOffset(): OffsetV2 = synchronized {
if (dataAdded) clock.waitTillTime(1150)
super.getEndOffset()
}
// getBatch should take 100 ms the first time it is called
override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
if (start.isEmpty) clock.waitTillTime(1150)
super.getBatch(start, end)
override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = {
synchronized {
clock.waitTillTime(1350)
super.createUnsafeRowReaderFactories()
}
}
}
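
To summarize the manual-clock choreography the overrides above set up (derived from the waits here and the assertions below):

```scala
// t=1000  trigger starts; setOffsetRange blocks until t=1050  -> "setOffsetRange"   = 50 ms
// t=1050  getEndOffset blocks until t=1150                    -> "getEndOffset"     = 100 ms
// t=1150  createUnsafeRowReaderFactories blocks until t=1350
// t=1350  the map task blocks until t=1500; trigger completes -> "triggerExecution" = 500 ms
```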
@@ -258,39 +270,44 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging with MockitoSugar {
AssertOnQuery(_.status.message === "Waiting for next trigger"),
AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0),
// Test status and progress while offset is being fetched
// Test status and progress when setOffsetRange is being called
AddData(inputData, 1, 2),
AdvanceManualClock(1000), // time = 1000 to start new trigger, will block on getOffset
AdvanceManualClock(1000), // time = 1000 to start new trigger, will block on setOffsetRange
AssertStreamExecThreadIsWaitingForTime(1050),
AssertOnQuery(_.status.isDataAvailable === false),
AssertOnQuery(_.status.isTriggerActive === true),
AssertOnQuery(_.status.message.startsWith("Getting offsets from")),
AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0),
// Test status and progress while batch is being fetched
AdvanceManualClock(50), // time = 1050 to unblock getOffset
AdvanceManualClock(50), // time = 1050 to unblock setOffsetRange
AssertClockTime(1050),
AssertStreamExecThreadIsWaitingForTime(1150), // will block on getBatch that needs 1150
AssertStreamExecThreadIsWaitingForTime(1150), // will block on getEndOffset that needs 1150
AssertOnQuery(_.status.isDataAvailable === false),
AssertOnQuery(_.status.isTriggerActive === true),
AssertOnQuery(_.status.message.startsWith("Getting offsets from")),
AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0),
AdvanceManualClock(100), // time = 1150 to unblock getEndOffset
AssertClockTime(1150),
AssertStreamExecThreadIsWaitingForTime(1350), // will block on createReadTasks that needs 1350
AssertOnQuery(_.status.isDataAvailable === true),
AssertOnQuery(_.status.isTriggerActive === true),
AssertOnQuery(_.status.message === "Processing new data"),
AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0),
// Test status and progress while batch is being processed
AdvanceManualClock(100), // time = 1150 to unblock getBatch
AssertClockTime(1150),
AssertStreamExecThreadIsWaitingForTime(1500), // will block in Spark job that needs 1500
AdvanceManualClock(200), // time = 1350 to unblock createReadTasks
AssertClockTime(1350),
AssertStreamExecThreadIsWaitingForTime(1500), // will block on map task that needs 1500
AssertOnQuery(_.status.isDataAvailable === true),
AssertOnQuery(_.status.isTriggerActive === true),
AssertOnQuery(_.status.message === "Processing new data"),
AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0),
// Test status and progress while batch processing has completed
AssertOnQuery { _ => clock.getTimeMillis() === 1150 },
AdvanceManualClock(350), // time = 1500 to unblock job
AdvanceManualClock(150), // time = 1500 to unblock map task
AssertClockTime(1500),
CheckAnswer(2),
AssertStreamExecThreadIsWaitingForTime(2000),
AssertStreamExecThreadIsWaitingForTime(2000), // will block until the next trigger
AssertOnQuery(_.status.isDataAvailable === true),
AssertOnQuery(_.status.isTriggerActive === false),
AssertOnQuery(_.status.message === "Waiting for next trigger"),
@@ -307,10 +324,11 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging with MockitoSugar {
assert(progress.numInputRows === 2)
assert(progress.processedRowsPerSecond === 4.0)
assert(progress.durationMs.get("getOffset") === 50)
assert(progress.durationMs.get("getBatch") === 100)
assert(progress.durationMs.get("setOffsetRange") === 50)
assert(progress.durationMs.get("getEndOffset") === 100)
assert(progress.durationMs.get("queryPlanning") === 0)
assert(progress.durationMs.get("walCommit") === 0)
assert(progress.durationMs.get("addBatch") === 350)
assert(progress.durationMs.get("triggerExecution") === 500)
assert(progress.sources.length === 1)