[SPARK-13146][SQL] Management API for continuous queries

### Management API for Continuous Queries

**API for getting the status of each query** (example below)
- Whether active or not
- Unique name of each query
- Status of the sources and sinks
- Exceptions
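
A rough sketch of reading this status from Scala, assuming a handle `query` returned by `df.write.queryName("events").stream()` (the variable names and the query name are illustrative, not part of this patch):

```scala
// `query` is a ContinuousQuery handle for a running streaming query
println(s"name = ${query.name}, active = ${query.isActive}")

// Current status of every source and of the sink
query.sourceStatuses.foreach { s =>
  println(s"source: ${s.description}, offset: ${s.offset}")
}
println(s"sink: ${query.sinkStatus.description}, offset: ${query.sinkStatus.offset}")

// The exception that terminated the query, if any
query.exception.foreach(e => println(s"query failed: ${e.message}"))
```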

**API for managing each query** (example below)
- Stopping an active query immediately
- Waiting for a query to terminate, either cleanly or with an error
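
A minimal sketch of managing a single query (same assumed `query` handle as above):

```scala
// Stop the query immediately; blocks until the execution thread has stopped
query.stop()

// Block until the query terminates, cleanly or with an error; if it terminated
// with an error, the ContinuousQueryException is rethrown here
query.awaitTermination()

// Timed variant: returns true if the query terminated within 10 seconds
val finished: Boolean = query.awaitTermination(10000L)
```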

**API for managing multiple queries** (example below)
- Listing all active queries
- Getting an active query by name
- Waiting for any one of the active queries to be terminated
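
A sketch of the multi-query API exposed through `SQLContext.streams` (the query name `"events"` is carried over from the earlier sketch):

```scala
val streams = sqlContext.streams  // ContinuousQueryManager

// List all active queries on this SQLContext
streams.active.foreach(q => println(q.name))

// Look up an active query by name; throws IllegalArgumentException if none matches
val events = streams.get("events")

// Block until any active query terminates; rethrows the error if one failed.
// resetTerminated() clears past terminations so the call can be used again.
streams.awaitAnyTermination()
streams.resetTerminated()
```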

**API for listening to query life cycle events** (example below)
- ContinuousQueryListener API for query start, progress and termination events.
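
One way to hook into these life cycle events, sketched against the listener interface added in this patch (the printed messages are illustrative):

```scala
import org.apache.spark.sql.util.ContinuousQueryListener
import org.apache.spark.sql.util.ContinuousQueryListener.{QueryProgress, QueryStarted, QueryTerminated}

sqlContext.streams.addListener(new ContinuousQueryListener {
  override def onQueryStarted(queryStarted: QueryStarted): Unit =
    println(s"started: ${queryStarted.query.name}")
  override def onQueryProgress(queryProgress: QueryProgress): Unit =
    println(s"progress: ${queryProgress.query.sourceStatuses.map(_.offset).mkString(", ")}")
  override def onQueryTerminated(queryTerminated: QueryTerminated): Unit =
    println(s"terminated: ${queryTerminated.query.name}")
})
```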

Author: Tathagata Das <tathagata.das1565@gmail.com>

Closes #11030 from tdas/streaming-df-management-api.
Tathagata Das 2016-02-10 16:45:06 -08:00 committed by Shixiong Zhu
parent 29c547303f
commit 0902e20288
17 changed files with 1686 additions and 115 deletions


@ -17,14 +17,84 @@
package org.apache.spark.sql
import org.apache.spark.annotation.Experimental
/**
* :: Experimental ::
* A handle to a query that is executing continuously in the background as new data arrives.
* All these methods are thread-safe.
* @since 2.0.0
*/
@Experimental
trait ContinuousQuery {
/**
* Stops the execution of this query if it is running. This method blocks until the threads
* Returns the name of the query.
* @since 2.0.0
*/
def name: String
/**
* Returns the SQLContext associated with `this` query
* @since 2.0.0
*/
def sqlContext: SQLContext
/**
* Whether the query is currently active or not
* @since 2.0.0
*/
def isActive: Boolean
/**
* Returns the [[ContinuousQueryException]] if the query was terminated by an exception.
* @since 2.0.0
*/
def exception: Option[ContinuousQueryException]
/**
* Returns current status of all the sources.
* @since 2.0.0
*/
def sourceStatuses: Array[SourceStatus]
/** Returns current status of the sink. */
def sinkStatus: SinkStatus
/**
* Waits for the termination of `this` query, either by `query.stop()` or by an exception.
* If the query has terminated with an exception, then the exception will be thrown.
*
* If the query has terminated, then all subsequent calls to this method will either return
* immediately (if the query was terminated by `stop()`), or throw the exception
* immediately (if the query has terminated with exception).
*
* @throws ContinuousQueryException, if `this` query has terminated with an exception.
*
* @since 2.0.0
*/
def awaitTermination(): Unit
/**
* Waits for the termination of `this` query, either by `query.stop()` or by an exception.
* If the query has terminated with an exception, then the exception will be thrown.
* Otherwise, it returns whether the query has terminated or not within the `timeoutMs`
* milliseconds.
*
* If the query has terminated, then all subsequent calls to this method will either return
* `true` immediately (if the query was terminated by `stop()`), or throw the exception
* immediately (if the query has terminated with exception).
*
* @throws ContinuousQueryException, if `this` query has terminated with an exception
*
* @since 2.0.0
*/
def awaitTermination(timeoutMs: Long): Boolean
/**
* Stops the execution of this query if it is running. This method blocks until the threads
* performing execution has stopped.
* @since 2.0.0
*/
def stop(): Unit
}


@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.execution.streaming.{Offset, StreamExecution}
/**
* :: Experimental ::
* Exception that stopped a [[ContinuousQuery]].
* @param query Query that caused the exception
* @param message Message of this exception
* @param cause Internal cause of this exception
* @param startOffset Starting offset (if known) of the range of data in which the exception occurred
* @param endOffset Ending offset (if known) of the range of data in which the exception occurred
* @since 2.0.0
*/
@Experimental
class ContinuousQueryException private[sql](
val query: ContinuousQuery,
val message: String,
val cause: Throwable,
val startOffset: Option[Offset] = None,
val endOffset: Option[Offset] = None
) extends Exception(message, cause) {
/** Time when the exception occurred */
val time: Long = System.currentTimeMillis
override def toString(): String = {
val causeStr =
s"${cause.getMessage} ${cause.getStackTrace.take(10).mkString("", "\n|\t", "\n")}"
s"""
|$causeStr
|
|${query.asInstanceOf[StreamExecution].toDebugString}
""".stripMargin
}
}


@ -0,0 +1,193 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import scala.collection.mutable
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.execution.streaming.{ContinuousQueryListenerBus, Sink, StreamExecution}
import org.apache.spark.sql.util.ContinuousQueryListener
/**
* :: Experimental ::
* A class to manage all the [[org.apache.spark.sql.ContinuousQuery ContinuousQueries]] active
* on a [[SQLContext]].
*
* @since 2.0.0
*/
@Experimental
class ContinuousQueryManager(sqlContext: SQLContext) {
private val listenerBus = new ContinuousQueryListenerBus(sqlContext.sparkContext.listenerBus)
private val activeQueries = new mutable.HashMap[String, ContinuousQuery]
private val activeQueriesLock = new Object
private val awaitTerminationLock = new Object
private var lastTerminatedQuery: ContinuousQuery = null
/**
* Returns a list of active queries associated with this SQLContext
*
* @since 2.0.0
*/
def active: Array[ContinuousQuery] = activeQueriesLock.synchronized {
activeQueries.values.toArray
}
/**
* Returns the active query in this SQLContext with the given name, or throws an exception if no such query exists
*
* @since 2.0.0
*/
def get(name: String): ContinuousQuery = activeQueriesLock.synchronized {
activeQueries.get(name).getOrElse {
throw new IllegalArgumentException(s"There is no active query with name $name")
}
}
/**
* Wait until any of the queries on the associated SQLContext has terminated since the
* creation of the context, or since `resetTerminated()` was called. If any query was terminated
* with an exception, then the exception will be thrown.
*
* If a query has terminated, then subsequent calls to `awaitAnyTermination()` will either
* return immediately (if the query was terminated by `query.stop()`),
* or throw the exception immediately (if the query was terminated with exception). Use
* `resetTerminated()` to clear past terminations and wait for new terminations.
*
* In the case where multiple queries have terminated since `resetTerminated()` was called,
* if any query has terminated with an exception, then `awaitAnyTermination()` will
* throw any one of those exceptions. To correctly handle exceptions across multiple queries,
* users need to stop all of them after any one terminates with an exception, and then check
* `query.exception()` for each query.
*
* @throws ContinuousQueryException, if any query has terminated with an exception
*
* @since 2.0.0
*/
def awaitAnyTermination(): Unit = {
awaitTerminationLock.synchronized {
while (lastTerminatedQuery == null) {
awaitTerminationLock.wait(10)
}
if (lastTerminatedQuery != null && lastTerminatedQuery.exception.nonEmpty) {
throw lastTerminatedQuery.exception.get
}
}
}
/**
* Wait until any of the queries on the associated SQLContext has terminated since the
* creation of the context, or since `resetTerminated()` was called. Returns whether any query
* has terminated or not (multiple may have terminated). If any query has terminated with an
* exception, then the exception will be thrown.
*
* If a query has terminated, then subsequent calls to `awaitAnyTermination()` will either
* return `true` immediately (if the query was terminated by `query.stop()`),
* or throw the exception immediately (if the query was terminated with exception). Use
* `resetTerminated()` to clear past terminations and wait for new terminations.
*
* In the case where multiple queries have terminated since `resetTerminated()` was called,
* if any query has terminated with an exception, then `awaitAnyTermination()` will
* throw any one of those exceptions. To correctly handle exceptions across multiple queries,
* users need to stop all of them after any one terminates with an exception, and then check
* `query.exception()` for each query.
*
* @throws ContinuousQueryException, if any query has terminated with an exception
*
* @since 2.0.0
*/
def awaitAnyTermination(timeoutMs: Long): Boolean = {
val startTime = System.currentTimeMillis
def isTimedout = System.currentTimeMillis - startTime >= timeoutMs
awaitTerminationLock.synchronized {
while (!isTimedout && lastTerminatedQuery == null) {
awaitTerminationLock.wait(10)
}
if (lastTerminatedQuery != null && lastTerminatedQuery.exception.nonEmpty) {
throw lastTerminatedQuery.exception.get
}
lastTerminatedQuery != null
}
}
/**
* Forget about past terminated queries so that `awaitAnyTermination()` can be used again to
* wait for new terminations.
*
* @since 2.0.0
*/
def resetTerminated(): Unit = {
awaitTerminationLock.synchronized {
lastTerminatedQuery = null
}
}
/**
* Register a [[ContinuousQueryListener]] to receive up-calls for life cycle events of
* [[org.apache.spark.sql.ContinuousQuery ContinuousQueries]].
*
* @since 2.0.0
*/
def addListener(listener: ContinuousQueryListener): Unit = {
listenerBus.addListener(listener)
}
/**
* Deregister a [[ContinuousQueryListener]].
*
* @since 2.0.0
*/
def removeListener(listener: ContinuousQueryListener): Unit = {
listenerBus.removeListener(listener)
}
/** Post a listener event */
private[sql] def postListenerEvent(event: ContinuousQueryListener.Event): Unit = {
listenerBus.post(event)
}
/** Start a query */
private[sql] def startQuery(name: String, df: DataFrame, sink: Sink): ContinuousQuery = {
activeQueriesLock.synchronized {
if (activeQueries.contains(name)) {
throw new IllegalArgumentException(
s"Cannot start query with name $name as a query with that name is already active")
}
val query = new StreamExecution(sqlContext, name, df.logicalPlan, sink)
query.start()
activeQueries.put(name, query)
query
}
}
/** Notify (by the ContinuousQuery) that the query has been terminated */
private[sql] def notifyQueryTermination(terminatedQuery: ContinuousQuery): Unit = {
activeQueriesLock.synchronized {
activeQueries -= terminatedQuery.name
}
awaitTerminationLock.synchronized {
if (lastTerminatedQuery == null || terminatedQuery.exception.nonEmpty) {
lastTerminatedQuery = terminatedQuery
}
awaitTerminationLock.notifyAll()
}
}
}


@ -205,6 +205,17 @@ final class DataFrameWriter private[sql](df: DataFrame) {
df)
}
/**
* Specifies the name of the [[ContinuousQuery]] that can be started with `stream()`.
* This name must be unique among all the currently active queries in the associated SQLContext.
*
* @since 2.0.0
*/
def queryName(queryName: String): DataFrameWriter = {
this.extraOptions += ("queryName" -> queryName)
this
}
/**
* Starts the execution of the streaming query, which will continually output results to the given
* path as new data arrives. The returned [[ContinuousQuery]] object can be used to interact with
@ -230,7 +241,8 @@ final class DataFrameWriter private[sql](df: DataFrame) {
extraOptions.toMap,
normalizedParCols.getOrElse(Nil))
new StreamExecution(df.sqlContext, df.logicalPlan, sink)
df.sqlContext.continuousQueryManager.startQuery(
extraOptions.getOrElse("queryName", StreamExecution.nextName), df, sink)
}
/**


@ -181,6 +181,8 @@ class SQLContext private[sql](
@transient
lazy val listenerManager: ExecutionListenerManager = new ExecutionListenerManager
protected[sql] lazy val continuousQueryManager = new ContinuousQueryManager(this)
@transient
protected[sql] lazy val catalog: Catalog = new SimpleCatalog(conf)
@ -835,6 +837,16 @@ class SQLContext private[sql](
DataFrame(this, ShowTablesCommand(Some(databaseName)))
}
/**
* Returns a [[ContinuousQueryManager]] that allows managing all the
* [[org.apache.spark.sql.ContinuousQuery ContinuousQueries]] active on `this` context.
*
* @since 2.0.0
*/
def streams: ContinuousQueryManager = {
continuousQueryManager
}
/**
* Returns the names of tables in the current database as an array.
*


@ -0,0 +1,34 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.execution.streaming.{Offset, Sink}
/**
* :: Experimental ::
* Status and metrics of a streaming [[Sink]].
*
* @param description Description of the sink corresponding to this status
* @param offset Current offset up to which data has been written by the sink
* @since 2.0.0
*/
@Experimental
class SinkStatus private[sql](
val description: String,
val offset: Option[Offset])


@ -0,0 +1,34 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.execution.streaming.{Offset, Source}
/**
* :: Experimental ::
* Status and metrics of a streaming [[Source]].
*
* @param description Description of the source corresponding to this status
* @param offset Current offset of the source, if known
* @since 2.0.0
*/
@Experimental
class SourceStatus private[sql] (
val description: String,
val offset: Option[Offset])


@ -0,0 +1,82 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.streaming
import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerEvent}
import org.apache.spark.sql.util.ContinuousQueryListener
import org.apache.spark.sql.util.ContinuousQueryListener._
import org.apache.spark.util.ListenerBus
/**
* A bus to forward events to [[ContinuousQueryListener]]s. This one will wrap received
* [[ContinuousQueryListener.Event]]s as WrappedContinuousQueryListenerEvents and send them to the
* Spark listener bus. It also registers itself with Spark listener bus, so that it can receive
* WrappedContinuousQueryListenerEvents, unwrap them as ContinuousQueryListener.Events and
* dispatch them to ContinuousQueryListener.
*/
class ContinuousQueryListenerBus(sparkListenerBus: LiveListenerBus)
extends SparkListener with ListenerBus[ContinuousQueryListener, ContinuousQueryListener.Event] {
sparkListenerBus.addListener(this)
/**
* Post a ContinuousQueryListener event to the Spark listener bus asynchronously. This event will
* be dispatched to all ContinuousQueryListener in the thread of the Spark listener bus.
*/
def post(event: ContinuousQueryListener.Event) {
event match {
case s: QueryStarted =>
postToAll(s)
case _ =>
sparkListenerBus.post(new WrappedContinuousQueryListenerEvent(event))
}
}
override def onOtherEvent(event: SparkListenerEvent): Unit = {
event match {
case WrappedContinuousQueryListenerEvent(e) =>
postToAll(e)
case _ =>
}
}
override protected def doPostEvent(
listener: ContinuousQueryListener,
event: ContinuousQueryListener.Event): Unit = {
event match {
case queryStarted: QueryStarted =>
listener.onQueryStarted(queryStarted)
case queryProgress: QueryProgress =>
listener.onQueryProgress(queryProgress)
case queryTerminated: QueryTerminated =>
listener.onQueryTerminated(queryTerminated)
case _ =>
}
}
/**
* Wrapper for ContinuousQueryListener.Event as SparkListenerEvent so that it can be posted to the Spark
* listener bus.
*/
private case class WrappedContinuousQueryListenerEvent(
streamingListenerEvent: ContinuousQueryListener.Event) extends SparkListenerEvent {
// Do not log streaming events in event log as history server does not support these events.
protected[spark] override def logEvent: Boolean = false
}
}


@ -17,16 +17,20 @@
package org.apache.spark.sql.execution.streaming
import java.lang.Thread.UncaughtExceptionHandler
import java.util.concurrent.{CountDownLatch, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal
import org.apache.spark.Logging
import org.apache.spark.sql.{ContinuousQuery, DataFrame, SQLContext}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.ContinuousQueryListener
import org.apache.spark.sql.util.ContinuousQueryListener._
/**
* Manages the execution of a streaming Spark SQL query that is occurring in a separate thread.
@ -35,15 +39,15 @@ import org.apache.spark.sql.execution.QueryExecution
* and the results are committed transactionally to the given [[Sink]].
*/
class StreamExecution(
sqlContext: SQLContext,
val sqlContext: SQLContext,
override val name: String,
private[sql] val logicalPlan: LogicalPlan,
val sink: Sink) extends ContinuousQuery with Logging {
/** A monitor used to wait/notify when batches complete. */
private val awaitBatchLock = new Object
@volatile
private var batchRun = false
private val startLatch = new CountDownLatch(1)
private val terminationLatch = new CountDownLatch(1)
/** Minimum amount of time in between the start of each batch. */
private val minBatchTime = 10
@ -55,9 +59,92 @@ class StreamExecution(
private val sources =
logicalPlan.collect { case s: StreamingRelation => s.source }
// Start the execution at the current offsets stored in the sink. (i.e. avoid reprocessing data
// that we have already processed).
{
/** Defines the internal state of execution */
@volatile
private var state: State = INITIALIZED
@volatile
private[sql] var lastExecution: QueryExecution = null
@volatile
private[sql] var streamDeathCause: ContinuousQueryException = null
/** The thread that runs the micro-batches of this stream. */
private[sql] val microBatchThread = new Thread(s"stream execution thread for $name") {
override def run(): Unit = { runBatches() }
}
/** Whether the query is currently active or not */
override def isActive: Boolean = state == ACTIVE
/** Returns current status of all the sources. */
override def sourceStatuses: Array[SourceStatus] = {
sources.map(s => new SourceStatus(s.toString, streamProgress.get(s))).toArray
}
/** Returns current status of the sink. */
override def sinkStatus: SinkStatus = new SinkStatus(sink.toString, sink.currentOffset)
/** Returns the [[ContinuousQueryException]] if the query was terminated by an exception. */
override def exception: Option[ContinuousQueryException] = Option(streamDeathCause)
/**
* Starts the execution. This returns only after the thread has started and [[QueryStarted]] event
* has been posted to all the listeners.
*/
private[sql] def start(): Unit = {
microBatchThread.setDaemon(true)
microBatchThread.start()
startLatch.await() // Wait until the thread has started and the QueryStarted event has been posted
}
/**
* Repeatedly attempts to run batches as data arrives.
*
* Note that this method ensures that [[QueryStarted]] and [[QueryTerminated]] events are posted
* so that listeners are guaranteed to get the former event before the latter. Furthermore, this
* method also ensures that the [[QueryStarted]] event is posted before the `start()` method returns.
*/
private def runBatches(): Unit = {
try {
// Mark ACTIVE and then post the event. QueryStarted event is synchronously sent to listeners,
// so must mark this as ACTIVE first.
state = ACTIVE
postEvent(new QueryStarted(this)) // Assumption: Does not throw exception.
// Unblock starting thread
startLatch.countDown()
// While active, repeatedly attempt to run batches.
SQLContext.setActive(sqlContext)
populateStartOffsets()
logInfo(s"Stream running at $streamProgress")
while (isActive) {
attemptBatch()
Thread.sleep(minBatchTime) // TODO: Could be tighter
}
} catch {
case _: InterruptedException if state == TERMINATED => // interrupted by stop()
case NonFatal(e) =>
streamDeathCause = new ContinuousQueryException(
this,
s"Query $name terminated with exception: ${e.getMessage}",
e,
Some(streamProgress.toCompositeOffset(sources)))
logError(s"Query $name terminated with error", e)
} finally {
state = TERMINATED
sqlContext.streams.notifyQueryTermination(StreamExecution.this)
postEvent(new QueryTerminated(this))
terminationLatch.countDown()
}
}
/**
* Populate the start offsets to start the execution at the current offsets stored in the sink
* (i.e. avoid reprocessing data that we have already processed).
*/
private def populateStartOffsets(): Unit = {
sink.currentOffset match {
case Some(c: CompositeOffset) =>
val storedProgress = c.offsets
@ -74,37 +161,8 @@ class StreamExecution(
}
}
logInfo(s"Stream running at $streamProgress")
/** When false, signals to the microBatchThread that it should stop running. */
@volatile private var shouldRun = true
/** The thread that runs the micro-batches of this stream. */
private[sql] val microBatchThread = new Thread("stream execution thread") {
override def run(): Unit = {
SQLContext.setActive(sqlContext)
while (shouldRun) {
attemptBatch()
Thread.sleep(minBatchTime) // TODO: Could be tighter
}
}
}
microBatchThread.setDaemon(true)
microBatchThread.setUncaughtExceptionHandler(
new UncaughtExceptionHandler {
override def uncaughtException(t: Thread, e: Throwable): Unit = {
streamDeathCause = e
}
})
microBatchThread.start()
@volatile
private[sql] var lastExecution: QueryExecution = null
@volatile
private[sql] var streamDeathCause: Throwable = null
/**
* Checks to see if any new data is present in any of the sources. When new data is available,
* Checks to see if any new data is present in any of the sources. When new data is available,
* a batch is executed and passed to the sink, updating the currentOffsets.
*/
private def attemptBatch(): Unit = {
@ -150,36 +208,43 @@ class StreamExecution(
streamProgress.synchronized {
// Update the offsets and calculate a new composite offset
newOffsets.foreach(streamProgress.update)
val newStreamProgress = logicalPlan.collect {
case StreamingRelation(source, _) => streamProgress.get(source)
}
val batchOffset = CompositeOffset(newStreamProgress)
// Construct the batch and send it to the sink.
val batchOffset = streamProgress.toCompositeOffset(sources)
val nextBatch = new Batch(batchOffset, new DataFrame(sqlContext, newPlan))
sink.addBatch(nextBatch)
}
batchRun = true
awaitBatchLock.synchronized {
// Wake up any threads that are waiting for the stream to progress.
awaitBatchLock.notifyAll()
}
val batchTime = (System.nanoTime() - startTime).toDouble / 1000000
logInfo(s"Compete up to $newOffsets in ${batchTime}ms")
logInfo(s"Completed up to $newOffsets in ${batchTime}ms")
postEvent(new QueryProgress(this))
}
logDebug(s"Waiting for data, current: $streamProgress")
}
private def postEvent(event: ContinuousQueryListener.Event) {
sqlContext.streams.postListenerEvent(event)
}
/**
* Signals to the thread executing micro-batches that it should stop running after the next
* batch. This method blocks until the thread stops running.
*/
def stop(): Unit = {
shouldRun = false
if (microBatchThread.isAlive) { microBatchThread.join() }
override def stop(): Unit = {
// Set the state to TERMINATED so that the batching thread knows that it was interrupted
// intentionally
state = TERMINATED
if (microBatchThread.isAlive) {
microBatchThread.interrupt()
microBatchThread.join()
}
logInfo(s"Query $name was stopped")
}
/**
@ -198,14 +263,60 @@ class StreamExecution(
logDebug(s"Unblocked at $newOffset for $source")
}
override def toString: String =
override def awaitTermination(): Unit = {
if (state == INITIALIZED) {
throw new IllegalStateException("Cannot wait for termination on a query that has not started")
}
terminationLatch.await()
if (streamDeathCause != null) {
throw streamDeathCause
}
}
override def awaitTermination(timeoutMs: Long): Boolean = {
if (state == INITIALIZED) {
throw new IllegalStateException("Cannot wait for termination on a query that has not started")
}
require(timeoutMs > 0, "Timeout has to be positive")
terminationLatch.await(timeoutMs, TimeUnit.MILLISECONDS)
if (streamDeathCause != null) {
throw streamDeathCause
} else {
!isActive
}
}
override def toString: String = {
s"Continuous Query - $name [state = $state]"
}
def toDebugString: String = {
val deathCauseStr = if (streamDeathCause != null) {
"Error:\n" + stackTraceToString(streamDeathCause.cause)
} else ""
s"""
|=== Streaming Query ===
|CurrentOffsets: $streamProgress
|Thread State: ${microBatchThread.getState}
|${if (streamDeathCause != null) stackTraceToString(streamDeathCause) else ""}
|=== Continuous Query ===
|Name: $name
|Current Offsets: $streamProgress
|
|Current State: $state
|Thread State: ${microBatchThread.getState}
|
|Logical Plan:
|$logicalPlan
|
|$deathCauseStr
""".stripMargin
}
trait State
case object INITIALIZED extends State
case object ACTIVE extends State
case object TERMINATED extends State
}
private[sql] object StreamExecution {
private val nextId = new AtomicInteger()
def nextName: String = s"query-${nextId.getAndIncrement}"
}


@ -55,6 +55,10 @@ class StreamProgress {
copied
}
private[sql] def toCompositeOffset(source: Seq[Source]): CompositeOffset = {
CompositeOffset(source.map(get))
}
override def toString: String =
currentOffsets.map { case (k, v) => s"$k: $v"}.mkString("{", ",", "}")


@ -20,11 +20,12 @@ package org.apache.spark.sql.execution.streaming
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal
import org.apache.spark.{Logging, SparkEnv}
import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.{encoderFor, RowEncoder}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.StructType
object MemoryStream {
@ -46,14 +47,13 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
protected val logicalPlan = StreamingRelation(this)
protected val output = logicalPlan.output
protected val batches = new ArrayBuffer[Dataset[A]]
protected var currentOffset: LongOffset = new LongOffset(-1)
protected def blockManager = SparkEnv.get.blockManager
def schema: StructType = encoder.schema
def getCurrentOffset: Offset = currentOffset
def toDS()(implicit sqlContext: SQLContext): Dataset[A] = {
new Dataset(sqlContext, logicalPlan)
}
@ -62,6 +62,10 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
new DataFrame(sqlContext, logicalPlan)
}
def addData(data: A*): Offset = {
addData(data.toTraversable)
}
def addData(data: TraversableOnce[A]): Offset = {
import sqlContext.implicits._
this.synchronized {
@ -110,6 +114,7 @@ class MemorySink(schema: StructType) extends Sink with Logging {
}
override def addBatch(nextBatch: Batch): Unit = synchronized {
nextBatch.data.collect() // 'compute' the batch's data and record the batch
batches.append(nextBatch)
}
@ -131,8 +136,13 @@ class MemorySink(schema: StructType) extends Sink with Logging {
batches.dropRight(num)
}
override def toString: String = synchronized {
batches.map(b => s"${b.end}: ${b.data.collect().mkString(" ")}").mkString("\n")
def toDebugString: String = synchronized {
batches.map { b =>
val dataStr = try b.data.collect().mkString(" ") catch {
case NonFatal(e) => "[Error converting to string]"
}
s"${b.end}: $dataStr"
}.mkString("\n")
}
}


@ -0,0 +1,67 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.util
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.ContinuousQuery
import org.apache.spark.sql.util.ContinuousQueryListener._
/**
* :: Experimental ::
* Interface for listening to events related to [[ContinuousQuery ContinuousQueries]].
* @note The methods are not thread-safe as they may be called from different threads.
*/
@Experimental
abstract class ContinuousQueryListener {
/**
* Called when a query is started.
* @note This is called synchronously with
* [[org.apache.spark.sql.DataFrameWriter `DataFrameWriter.stream()`]],
* that is, `onQueryStarted` will be called on all listeners before `DataFrameWriter.stream()`
* returns the corresponding [[ContinuousQuery]].
*/
def onQueryStarted(queryStarted: QueryStarted)
/** Called when there is some status update (ingestion rate updated, etc.) */
def onQueryProgress(queryProgress: QueryProgress)
/** Called when a query is stopped, with or without error */
def onQueryTerminated(queryTerminated: QueryTerminated)
}
/**
* :: Experimental ::
* Companion object of [[ContinuousQueryListener]] that defines the listener events.
*/
@Experimental
object ContinuousQueryListener {
/** Base type of [[ContinuousQueryListener]] events */
trait Event
/** Event representing the start of a query */
class QueryStarted private[sql](val query: ContinuousQuery) extends Event
/** Event representing any progress updates in a query */
class QueryProgress private[sql](val query: ContinuousQuery) extends Event
/** Event representing the termination of a query */
class QueryTerminated private[sql](val query: ContinuousQuery) extends Event
}


@ -21,9 +21,16 @@ import java.lang.Thread.UncaughtExceptionHandler
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.language.experimental.macros
import scala.reflect.ClassTag
import scala.util.Random
import scala.util.control.NonFatal
import org.scalatest.concurrent.Timeouts
import org.scalatest.Assertions
import org.scalatest.concurrent.{Eventually, Timeouts}
import org.scalatest.concurrent.PatienceConfiguration.Timeout
import org.scalatest.exceptions.TestFailedDueToTimeoutException
import org.scalatest.time.Span
import org.scalatest.time.SpanSugar._
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder}
@ -64,7 +71,7 @@ trait StreamTest extends QueryTest with Timeouts {
}
/** How long to wait for an active stream to catch up when checking a result. */
val streamingTimout = 10.seconds
val streamingTimeout = 10.seconds
/** A trait for actions that can be performed while testing a streaming DataFrame. */
trait StreamAction
@ -128,7 +135,38 @@ trait StreamTest extends QueryTest with Timeouts {
case object StartStream extends StreamAction
/** Signals that a failure is expected and should not kill the test. */
case object ExpectFailure extends StreamAction
case class ExpectFailure[T <: Throwable : ClassTag]() extends StreamAction {
val causeClass: Class[T] = implicitly[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]]
override def toString(): String = s"ExpectFailure[${causeClass.getCanonicalName}]"
}
/** Assert that a body is true */
class Assert(condition: => Boolean, val message: String = "") extends StreamAction {
def run(): Unit = { Assertions.assert(condition) }
override def toString: String = s"Assert(<condition>, $message)"
}
object Assert {
def apply(condition: => Boolean, message: String = ""): Assert = new Assert(condition, message)
def apply(message: String)(body: => Unit): Assert = new Assert( { body; true }, message)
def apply(body: => Unit): Assert = new Assert( { body; true }, "")
}
/** Assert that a condition on the active query is true */
class AssertOnQuery(val condition: StreamExecution => Boolean, val message: String)
extends StreamAction {
override def toString: String = s"AssertOnQuery(<condition>, $message)"
}
object AssertOnQuery {
def apply(condition: StreamExecution => Boolean, message: String = ""): AssertOnQuery = {
new AssertOnQuery(condition, message)
}
def apply(message: String)(condition: StreamExecution => Boolean): AssertOnQuery = {
new AssertOnQuery(condition, message)
}
}
/** A helper for running actions on a Streaming Dataset. See `checkAnswer(DataFrame)`. */
def testStream(stream: Dataset[_])(actions: StreamAction*): Unit =
@ -145,6 +183,7 @@ trait StreamTest extends QueryTest with Timeouts {
var pos = 0
var currentPlan: LogicalPlan = stream.logicalPlan
var currentStream: StreamExecution = null
var lastStream: StreamExecution = null
val awaiting = new mutable.HashMap[Source, Offset]()
val sink = new MemorySink(stream.schema)
@ -170,6 +209,7 @@ trait StreamTest extends QueryTest with Timeouts {
def threadState =
if (currentStream != null && currentStream.microBatchThread.isAlive) "alive" else "dead"
def testState =
s"""
|== Progress ==
@ -181,16 +221,49 @@ trait StreamTest extends QueryTest with Timeouts {
|${if (streamDeathCause != null) stackTraceToString(streamDeathCause) else ""}
|
|== Sink ==
|$sink
|${sink.toDebugString}
|
|== Plan ==
|${if (currentStream != null) currentStream.lastExecution else ""}
"""
""".stripMargin
def checkState(check: Boolean, error: String) = if (!check) {
def verify(condition: => Boolean, message: String): Unit = {
try {
Assertions.assert(condition)
} catch {
case NonFatal(e) =>
failTest(message, e)
}
}
def eventually[T](message: String)(func: => T): T = {
try {
Eventually.eventually(Timeout(streamingTimeout)) {
func
}
} catch {
case NonFatal(e) =>
failTest(message, e)
}
}
def failTest(message: String, cause: Throwable = null) = {
// Recursively pretty print an exception with a truncated stack trace and internal cause
def exceptionToString(e: Throwable, prefix: String = ""): String = {
val base = s"$prefix${e.getMessage}" +
e.getStackTrace.take(10).mkString(s"\n$prefix", s"\n$prefix\t", "\n")
if (e.getCause != null) {
base + s"\n$prefix\tCaused by: " + exceptionToString(e.getCause, s"$prefix\t")
} else {
base
}
}
val c = Option(cause).map(exceptionToString(_))
val m = if (message != null && message.size > 0) Some(message) else None
fail(
s"""
|Invalid State: $error
|${(m ++ c).mkString(": ")}
|$testState
""".stripMargin)
}
@ -201,9 +274,13 @@ trait StreamTest extends QueryTest with Timeouts {
startedTest.foreach { action =>
action match {
case StartStream =>
checkState(currentStream == null, "stream already running")
currentStream = new StreamExecution(sqlContext, stream.logicalPlan, sink)
verify(currentStream == null, "stream already running")
lastStream = currentStream
currentStream =
sqlContext
.streams
.startQuery(StreamExecution.nextName, stream, sink)
.asInstanceOf[StreamExecution]
currentStream.microBatchThread.setUncaughtExceptionHandler(
new UncaughtExceptionHandler {
override def uncaughtException(t: Thread, e: Throwable): Unit = {
@ -213,77 +290,100 @@ trait StreamTest extends QueryTest with Timeouts {
})
case StopStream =>
checkState(currentStream != null, "can not stop a stream that is not running")
currentStream.stop()
currentStream = null
case DropBatches(num) =>
checkState(currentStream == null, "dropping batches while running leads to corruption")
sink.dropBatches(num)
case ExpectFailure =>
try failAfter(streamingTimout) {
while (streamDeathCause == null) {
Thread.sleep(100)
}
verify(currentStream != null, "can not stop a stream that is not running")
try failAfter(streamingTimeout) {
currentStream.stop()
verify(!currentStream.microBatchThread.isAlive,
s"microbatch thread not stopped")
verify(!currentStream.isActive,
"query.isActive() is false even after stopping")
verify(currentStream.exception.isEmpty,
s"query.exception() is not empty after clean stop: " +
currentStream.exception.map(_.toString()).getOrElse(""))
} catch {
case _: InterruptedException =>
case _: org.scalatest.exceptions.TestFailedDueToTimeoutException =>
fail(
s"""
|Timed out while waiting for failure.
|$testState
""".stripMargin)
failTest("Timed out while stopping and waiting for microbatchthread to terminate.")
case t: Throwable =>
failTest("Error while stopping stream", t)
} finally {
lastStream = currentStream
currentStream = null
}
currentStream = null
streamDeathCause = null
case DropBatches(num) =>
verify(currentStream == null, "dropping batches while running leads to corruption")
sink.dropBatches(num)
case ef: ExpectFailure[_] =>
verify(currentStream != null, "can not expect failure when stream is not running")
try failAfter(streamingTimeout) {
val thrownException = intercept[ContinuousQueryException] {
currentStream.awaitTermination()
}
eventually("microbatch thread not stopped after termination with failure") {
assert(!currentStream.microBatchThread.isAlive)
}
verify(thrownException.query.eq(currentStream),
s"incorrect query reference in exception")
verify(currentStream.exception === Some(thrownException),
s"incorrect exception returned by query.exception()")
val exception = currentStream.exception.get
verify(exception.cause.getClass === ef.causeClass,
"incorrect cause in exception returned by query.exception()\n" +
s"\tExpected: ${ef.causeClass}\n\tReturned: ${exception.cause.getClass}")
} catch {
case _: InterruptedException =>
case _: org.scalatest.exceptions.TestFailedDueToTimeoutException =>
failTest("Timed out while waiting for failure")
case t: Throwable =>
failTest("Error while checking stream failure", t)
} finally {
lastStream = currentStream
currentStream = null
streamDeathCause = null
}
case a: AssertOnQuery =>
verify(currentStream != null || lastStream != null,
"cannot assert when not stream has been started")
val streamToAssert = Option(currentStream).getOrElse(lastStream)
verify(a.condition(streamToAssert), s"Assert on query failed: ${a.message}")
case a: Assert =>
val streamToAssert = Option(currentStream).getOrElse(lastStream)
verify({ a.run(); true }, s"Assert failed: ${a.message}")
case a: AddData =>
awaiting.put(a.source, a.addData())
case CheckAnswerRows(expectedAnswer) =>
checkState(currentStream != null, "stream not running")
verify(currentStream != null, "stream not running")
// Block until all data added has been processed
awaiting.foreach { case (source, offset) =>
failAfter(streamingTimout) {
failAfter(streamingTimeout) {
currentStream.awaitOffset(source, offset)
}
}
val allData = try sink.allData catch {
case e: Exception =>
fail(
s"""
|Exception while getting data from sink $e
|$testState
""".stripMargin)
failTest("Exception while getting data from sink", e)
}
QueryTest.sameRows(expectedAnswer, allData).foreach {
error => fail(
s"""
|$error
|$testState
""".stripMargin)
error => failTest(error)
}
}
pos += 1
}
} catch {
case _: InterruptedException if streamDeathCause != null =>
fail(
s"""
|Stream Thread Died
|$testState
""".stripMargin)
failTest("Stream Thread Died")
case _: org.scalatest.exceptions.TestFailedDueToTimeoutException =>
fail(
s"""
|Timed out waiting for stream
|$testState
""".stripMargin)
failTest("Timed out waiting for stream")
} finally {
if (currentStream != null && currentStream.microBatchThread.isAlive) {
currentStream.stop()
@ -335,7 +435,8 @@ trait StreamTest extends QueryTest with Timeouts {
case r if r < 0.7 => // AddData
addRandomData()
case _ => // StartStream
case _ => // StopStream
addCheck()
actions += StopStream
running = false
}
@ -345,4 +446,59 @@ trait StreamTest extends QueryTest with Timeouts {
addCheck()
testStream(ds)(actions: _*)
}
object AwaitTerminationTester {
trait ExpectedBehavior
/** Expect awaitTermination to not be blocked */
case object ExpectNotBlocked extends ExpectedBehavior
/** Expect awaitTermination to get blocked */
case object ExpectBlocked extends ExpectedBehavior
/** Expect awaitTermination to throw an exception */
case class ExpectException[E <: Exception]()(implicit val t: ClassTag[E])
extends ExpectedBehavior
private val DEFAULT_TEST_TIMEOUT = 1 second
def test(
expectedBehavior: ExpectedBehavior,
awaitTermFunc: () => Unit,
testTimeout: Span = DEFAULT_TEST_TIMEOUT
): Unit = {
expectedBehavior match {
case ExpectNotBlocked =>
withClue("Got blocked when expected non-blocking.") {
failAfter(testTimeout) {
awaitTermFunc()
}
}
case ExpectBlocked =>
withClue("Was not blocked when expected.") {
intercept[TestFailedDueToTimeoutException] {
failAfter(testTimeout) {
awaitTermFunc()
}
}
}
case e: ExpectException[_] =>
val thrownException =
withClue(s"Did not throw ${e.t.runtimeClass.getSimpleName} when expected.") {
intercept[ContinuousQueryException] {
failAfter(testTimeout) {
awaitTermFunc()
}
}
}
assert(thrownException.cause.getClass === e.t.runtimeClass,
"exception of incorrect type was throw")
}
}
}
}


@ -0,0 +1,306 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.streaming
import scala.concurrent.Future
import scala.util.Random
import scala.util.control.NonFatal
import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.Eventually._
import org.scalatest.concurrent.PatienceConfiguration.Timeout
import org.scalatest.time.Span
import org.scalatest.time.SpanSugar._
import org.apache.spark.SparkException
import org.apache.spark.sql.{ContinuousQuery, Dataset, StreamTest}
import org.apache.spark.sql.execution.streaming.{MemorySink, MemoryStream, StreamExecution, StreamingRelation}
import org.apache.spark.sql.test.SharedSQLContext
class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with BeforeAndAfter {
import AwaitTerminationTester._
import testImplicits._
override val streamingTimeout = 20.seconds
before {
assert(sqlContext.streams.active.isEmpty)
sqlContext.streams.resetTerminated()
}
after {
assert(sqlContext.streams.active.isEmpty)
sqlContext.streams.resetTerminated()
}
test("listing") {
val (m1, ds1) = makeDataset
val (m2, ds2) = makeDataset
val (m3, ds3) = makeDataset
withQueriesOn(ds1, ds2, ds3) { queries =>
require(queries.size === 3)
assert(sqlContext.streams.active.toSet === queries.toSet)
val (q1, q2, q3) = (queries(0), queries(1), queries(2))
assert(sqlContext.streams.get(q1.name).eq(q1))
assert(sqlContext.streams.get(q2.name).eq(q2))
assert(sqlContext.streams.get(q3.name).eq(q3))
intercept[IllegalArgumentException] {
sqlContext.streams.get("non-existent-name")
}
q1.stop()
assert(sqlContext.streams.active.toSet === Set(q2, q3))
val ex1 = withClue("no error while getting non-active query") {
intercept[IllegalArgumentException] {
sqlContext.streams.get(q1.name)
}
}
assert(ex1.getMessage.contains(q1.name), "error does not contain name of query to be fetched")
assert(sqlContext.streams.get(q2.name).eq(q2))
m2.addData(0) // q2 should terminate with error
eventually(Timeout(streamingTimeout)) {
require(!q2.isActive)
require(q2.exception.isDefined)
}
val ex2 = withClue("no error while getting non-active query") {
intercept[IllegalArgumentException] {
sqlContext.streams.get(q2.name).eq(q2)
}
}
assert(sqlContext.streams.active.toSet === Set(q3))
}
}
test("awaitAnyTermination without timeout and resetTerminated") {
val datasets = Seq.fill(5)(makeDataset._2)
withQueriesOn(datasets: _*) { queries =>
require(queries.size === datasets.size)
assert(sqlContext.streams.active.toSet === queries.toSet)
// awaitAnyTermination should be blocking
testAwaitAnyTermination(ExpectBlocked)
// Stop a query asynchronously and see if it is reported through awaitAnyTermination
val q1 = stopRandomQueryAsync(stopAfter = 100 milliseconds, withError = false)
testAwaitAnyTermination(ExpectNotBlocked)
require(!q1.isActive) // should be inactive by the time the prev awaitAnyTerm returned
// All subsequent calls to awaitAnyTermination should be non-blocking
testAwaitAnyTermination(ExpectNotBlocked)
// Resetting termination should make awaitAnyTermination() blocking again
sqlContext.streams.resetTerminated()
testAwaitAnyTermination(ExpectBlocked)
// Terminate a query asynchronously with exception and see awaitAnyTermination throws
// the exception
val q2 = stopRandomQueryAsync(100 milliseconds, withError = true)
testAwaitAnyTermination(ExpectException[SparkException])
require(!q2.isActive) // should be inactive by the time the prev awaitAnyTerm returned
// All subsequent calls to awaitAnyTermination should throw the exception
testAwaitAnyTermination(ExpectException[SparkException])
// Resetting termination should make awaitAnyTermination() blocking again
sqlContext.streams.resetTerminated()
testAwaitAnyTermination(ExpectBlocked)
// Terminate multiple queries, one with failure and see whether awaitAnyTermination throws
// the exception
val q3 = stopRandomQueryAsync(10 milliseconds, withError = false)
testAwaitAnyTermination(ExpectNotBlocked)
require(!q3.isActive)
val q4 = stopRandomQueryAsync(10 milliseconds, withError = true)
eventually(Timeout(streamingTimeout)) { require(!q4.isActive) }
// After q4 terminates with exception, awaitAnyTerm should start throwing exception
testAwaitAnyTermination(ExpectException[SparkException])
}
}
test("awaitAnyTermination with timeout and resetTerminated") {
val datasets = Seq.fill(6)(makeDataset._2)
withQueriesOn(datasets: _*) { queries =>
require(queries.size === datasets.size)
assert(sqlContext.streams.active.toSet === queries.toSet)
// awaitAnyTermination should be blocking or non-blocking depending on timeout values
testAwaitAnyTermination(
ExpectBlocked,
awaitTimeout = 2 seconds,
expectedReturnedValue = false,
testBehaviorFor = 1 second)
testAwaitAnyTermination(
ExpectNotBlocked,
awaitTimeout = 50 milliseconds,
expectedReturnedValue = false,
testBehaviorFor = 1 second)
// Stop a query asynchronously within timeout and awaitAnyTerm should be unblocked
val q1 = stopRandomQueryAsync(stopAfter = 100 milliseconds, withError = false)
testAwaitAnyTermination(
ExpectNotBlocked,
awaitTimeout = 1 second,
expectedReturnedValue = true,
testBehaviorFor = 2 seconds)
require(!q1.isActive) // should be inactive by the time the prev awaitAnyTerm returned
// All subsequent calls to awaitAnyTermination should be non-blocking even if timeout is high
testAwaitAnyTermination(
ExpectNotBlocked, awaitTimeout = 2 seconds, expectedReturnedValue = true)
// Resetting termination should make awaitAnyTermination() blocking again
sqlContext.streams.resetTerminated()
testAwaitAnyTermination(
ExpectBlocked,
awaitTimeout = 2 seconds,
expectedReturnedValue = false,
testBehaviorFor = 1 second)
// Terminate a query asynchronously with exception within timeout, awaitAnyTermination should
// throw the exception
val q2 = stopRandomQueryAsync(100 milliseconds, withError = true)
testAwaitAnyTermination(
ExpectException[SparkException],
awaitTimeout = 1 second,
testBehaviorFor = 2 seconds)
require(!q2.isActive) // should be inactive by the time the prev awaitAnyTerm returned
// All subsequent calls to awaitAnyTermination should throw the exception
testAwaitAnyTermination(
ExpectException[SparkException],
awaitTimeout = 1 second,
testBehaviorFor = 2 seconds)
// Terminate a query asynchronously outside the timeout, awaitAnyTerm should be blocked
sqlContext.streams.resetTerminated()
val q3 = stopRandomQueryAsync(1 second, withError = true)
testAwaitAnyTermination(
ExpectNotBlocked,
awaitTimeout = 100 milliseconds,
expectedReturnedValue = false,
testBehaviorFor = 2 seconds)
// After that query is stopped, awaitAnyTerm should throw exception
eventually(Timeout(streamingTimeout)) { require(!q3.isActive) } // wait for query to stop
testAwaitAnyTermination(
ExpectException[SparkException],
awaitTimeout = 100 milliseconds,
testBehaviorFor = 2 seconds)
// Terminate multiple queries, one with failure and see whether awaitAnyTermination throws
// the exception
sqlContext.streams.resetTerminated()
val q4 = stopRandomQueryAsync(10 milliseconds, withError = false)
testAwaitAnyTermination(
ExpectNotBlocked, awaitTimeout = 1 second, expectedReturnedValue = true)
require(!q4.isActive)
val q5 = stopRandomQueryAsync(10 milliseconds, withError = true)
eventually(Timeout(streamingTimeout)) { require(!q5.isActive) }
// After q5 terminates with exception, awaitAnyTerm should start throwing exception
testAwaitAnyTermination(ExpectException[SparkException], awaitTimeout = 100 milliseconds)
}
}
/** Run a body of code by defining a query each on multiple datasets */
private def withQueriesOn(datasets: Dataset[_]*)(body: Seq[ContinuousQuery] => Unit): Unit = {
failAfter(streamingTimeout) {
val queries = withClue("Error starting queries") {
datasets.map { ds =>
@volatile var query: StreamExecution = null
try {
val df = ds.toDF
query = sqlContext
.streams
.startQuery(StreamExecution.nextName, df, new MemorySink(df.schema))
.asInstanceOf[StreamExecution]
} catch {
case NonFatal(e) =>
if (query != null) query.stop()
throw e
}
query
}
}
try {
body(queries)
} finally {
queries.foreach(_.stop())
}
}
}
/** Test the behavior of awaitAnyTermination */
private def testAwaitAnyTermination(
expectedBehavior: ExpectedBehavior,
expectedReturnedValue: Boolean = false,
awaitTimeout: Span = null,
testBehaviorFor: Span = 2 seconds
): Unit = {
def awaitTermFunc(): Unit = {
if (awaitTimeout != null && awaitTimeout.toMillis > 0) {
val returnedValue = sqlContext.streams.awaitAnyTermination(awaitTimeout.toMillis)
assert(returnedValue === expectedReturnedValue, "Returned value does not match expected")
} else {
sqlContext.streams.awaitAnyTermination()
}
}
AwaitTerminationTester.test(expectedBehavior, awaitTermFunc, testBehaviorFor)
}
/** Stop a random active query either with `stop()` or with an error */
private def stopRandomQueryAsync(stopAfter: Span, withError: Boolean): ContinuousQuery = {
import scala.concurrent.ExecutionContext.Implicits.global
val activeQueries = sqlContext.streams.active
val queryToStop = activeQueries(Random.nextInt(activeQueries.length))
Future {
Thread.sleep(stopAfter.toMillis)
if (withError) {
logDebug(s"Terminating query ${queryToStop.name} with error")
queryToStop.asInstanceOf[StreamExecution].logicalPlan.collect {
case StreamingRelation(memoryStream, _) =>
memoryStream.asInstanceOf[MemoryStream[Int]].addData(0)
}
} else {
logDebug(s"Stopping query ${queryToStop.name}")
queryToStop.stop()
}
}
queryToStop
}
private def makeDataset: (MemoryStream[Int], Dataset[Int]) = {
val inputData = MemoryStream[Int]
val mapped = inputData.toDS.map(6 / _)
(inputData, mapped)
}
}


@ -0,0 +1,139 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.streaming
import org.apache.spark.SparkException
import org.apache.spark.sql.StreamTest
import org.apache.spark.sql.execution.streaming.{CompositeOffset, LongOffset, MemoryStream, StreamExecution}
import org.apache.spark.sql.test.SharedSQLContext
class ContinuousQuerySuite extends StreamTest with SharedSQLContext {
import AwaitTerminationTester._
import testImplicits._
test("lifecycle states and awaitTermination") {
val inputData = MemoryStream[Int]
val mapped = inputData.toDS().map { 6 / _}
testStream(mapped)(
AssertOnQuery(_.isActive === true),
AssertOnQuery(_.exception.isEmpty),
AddData(inputData, 1, 2),
CheckAnswer(6, 3),
TestAwaitTermination(ExpectBlocked),
TestAwaitTermination(ExpectBlocked, timeoutMs = 2000),
TestAwaitTermination(ExpectNotBlocked, timeoutMs = 10, expectedReturnValue = false),
StopStream,
AssertOnQuery(_.isActive === false),
AssertOnQuery(_.exception.isEmpty),
TestAwaitTermination(ExpectNotBlocked),
TestAwaitTermination(ExpectNotBlocked, timeoutMs = 2000, expectedReturnValue = true),
TestAwaitTermination(ExpectNotBlocked, timeoutMs = 10, expectedReturnValue = true),
StartStream,
AssertOnQuery(_.isActive === true),
AddData(inputData, 0),
ExpectFailure[SparkException],
AssertOnQuery(_.isActive === false),
TestAwaitTermination(ExpectException[SparkException]),
TestAwaitTermination(ExpectException[SparkException], timeoutMs = 2000),
TestAwaitTermination(ExpectException[SparkException], timeoutMs = 10),
AssertOnQuery(
q => q.exception.get.startOffset.get === q.streamProgress.toCompositeOffset(Seq(inputData)),
"incorrect start offset on exception")
)
}
test("source and sink statuses") {
val inputData = MemoryStream[Int]
val mapped = inputData.toDS().map(6 / _)
testStream(mapped)(
AssertOnQuery(_.sourceStatuses.length === 1),
AssertOnQuery(_.sourceStatuses(0).description.contains("Memory")),
AssertOnQuery(_.sourceStatuses(0).offset === None),
AssertOnQuery(_.sinkStatus.description.contains("Memory")),
AssertOnQuery(_.sinkStatus.offset === None),
AddData(inputData, 1, 2),
CheckAnswer(6, 3),
AssertOnQuery(_.sourceStatuses(0).offset === Some(LongOffset(0))),
AssertOnQuery(_.sinkStatus.offset === Some(CompositeOffset.fill(LongOffset(0)))),
AddData(inputData, 1, 2),
CheckAnswer(6, 3, 6, 3),
AssertOnQuery(_.sourceStatuses(0).offset === Some(LongOffset(1))),
AssertOnQuery(_.sinkStatus.offset === Some(CompositeOffset.fill(LongOffset(1)))),
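// After the failing batch below, the source offset advances but the sink offset does not,
// because the failed batch is never committed to the sink.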
AddData(inputData, 0),
ExpectFailure[SparkException],
AssertOnQuery(_.sourceStatuses(0).offset === Some(LongOffset(2))),
AssertOnQuery(_.sinkStatus.offset === Some(CompositeOffset.fill(LongOffset(1))))
)
}
/**
* A [[StreamAction]] to test the behavior of `ContinuousQuery.awaitTermination()`.
*
* @param expectedBehavior Expected behavior (not blocked, blocked, or exception thrown)
* @param timeoutMs Timeout in milliseconds
* When timeoutMs <= 0, awaitTermination() is tested (i.e. w/o timeout)
* When timeoutMs > 0, awaitTermination(timeoutMs) is tested
* @param expectedReturnValue Expected return value when awaitTermination(timeoutMs) is used
*/
case class TestAwaitTermination(
expectedBehavior: ExpectedBehavior,
timeoutMs: Int = -1,
expectedReturnValue: Boolean = false
) extends AssertOnQuery(
TestAwaitTermination.assertOnQueryCondition(expectedBehavior, timeoutMs, expectedReturnValue),
"Error testing awaitTermination behavior"
) {
override def toString(): String = {
s"TestAwaitTermination($expectedBehavior, timeoutMs = $timeoutMs, " +
s"expectedReturnValue = $expectedReturnValue)"
}
}
object TestAwaitTermination {
/**
* Tests the behavior of `ContinuousQuery.awaitTermination`.
*
* @param expectedBehavior Expected behavior (not blocked, blocked, or exception thrown)
* @param timeoutMs Timeout in milliseconds
* When timeoutMs <= 0, awaitTermination() is tested (i.e. w/o timeout)
* When timeoutMs > 0, awaitTermination(timeoutMs) is tested
* @param expectedReturnValue Expected return value when awaitTermination(timeoutMs) is used
*/
def assertOnQueryCondition(
expectedBehavior: ExpectedBehavior,
timeoutMs: Int,
expectedReturnValue: Boolean
)(q: StreamExecution): Boolean = {
def awaitTermFunc(): Unit = {
if (timeoutMs <= 0) {
q.awaitTermination()
} else {
val returnedValue = q.awaitTermination(timeoutMs)
assert(returnedValue === expectedReturnValue, "Returned value does not match expected")
}
}
AwaitTerminationTester.test(expectedBehavior, awaitTermFunc)
true // If the control reached here, then everything worked as expected
}
}
}

View file

@ -17,7 +17,9 @@
package org.apache.spark.sql.streaming.test
import org.apache.spark.sql.{AnalysisException, SQLContext, StreamTest}
import org.scalatest.BeforeAndAfter
import org.apache.spark.sql.{AnalysisException, ContinuousQuery, SQLContext, StreamTest}
import org.apache.spark.sql.execution.streaming.{Batch, Offset, Sink, Source}
import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider}
import org.apache.spark.sql.test.SharedSQLContext
@ -57,9 +59,13 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider {
}
}
class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext {
class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with BeforeAndAfter {
import testImplicits._
after {
sqlContext.streams.active.foreach(_.stop())
}
test("resolve default source") {
sqlContext.read
.format("org.apache.spark.sql.streaming.test")
@ -188,4 +194,63 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext {
assert(LastOptions.parameters("boolOpt") == "false")
assert(LastOptions.parameters("doubleOpt") == "6.7")
}
test("unique query names") {
/** Start a query with a specific name */
def startQueryWithName(name: String = ""): ContinuousQuery = {
sqlContext.read
.format("org.apache.spark.sql.streaming.test")
.stream("/test")
.write
.format("org.apache.spark.sql.streaming.test")
.queryName(name)
.stream()
}
/** Start a query without specifying a name */
def startQueryWithoutName(): ContinuousQuery = {
sqlContext.read
.format("org.apache.spark.sql.streaming.test")
.stream("/test")
.write
.format("org.apache.spark.sql.streaming.test")
.stream()
}
/** Get the names of active streams */
def activeStreamNames: Set[String] = {
val streams = sqlContext.streams.active
val names = streams.map(_.name).toSet
assert(streams.length === names.size, s"names of active queries are not unique: $names")
names
}
val q1 = startQueryWithName("name")
// Should not be able to start another query with the same name
intercept[IllegalArgumentException] {
startQueryWithName("name")
}
assert(activeStreamNames === Set("name"))
// Should be able to start queries with other names
val q3 = startQueryWithName("another-name")
assert(activeStreamNames === Set("name", "another-name"))
// Should be able to start queries with auto-generated names
val q4 = startQueryWithoutName()
assert(activeStreamNames.contains(q4.name))
// Should not be able to start a query with same auto-generated name
intercept[IllegalArgumentException] {
startQueryWithName(q4.name)
}
// Should be able to start query with that name after stopping the previous query
q1.stop()
val q5 = startQueryWithName("name")
assert(activeStreamNames.contains("name"))
sqlContext.streams.active.foreach(_.stop())
}
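/**
 * Illustrative sketch only (not referenced by the tests above): looking up a running query by
 * name using the `active` listing exercised in this suite.
 */
private def findQueryByName(name: String): Option[ContinuousQuery] = {
  sqlContext.streams.active.find(_.name == name)
}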
}

View file

@ -0,0 +1,222 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.util
import java.util.concurrent.ConcurrentLinkedQueue
import scala.util.control.NonFatal
import org.scalatest.BeforeAndAfter
import org.scalatest.PrivateMethodTester._
import org.scalatest.concurrent.AsyncAssertions.Waiter
import org.scalatest.concurrent.Eventually._
import org.scalatest.concurrent.PatienceConfiguration.Timeout
import org.scalatest.time.SpanSugar._
import org.apache.spark.sql._
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.util.ContinuousQueryListener.{QueryProgress, QueryStarted, QueryTerminated}
class ContinuousQueryListenerSuite extends StreamTest with SharedSQLContext with BeforeAndAfter {
import testImplicits._
after {
sqlContext.streams.active.foreach(_.stop())
assert(sqlContext.streams.active.isEmpty)
assert(addedListeners.isEmpty)
}
test("single listener") {
val listener = new QueryStatusCollector
val input = MemoryStream[Int]
withListenerAdded(listener) {
testStream(input.toDS)(
StartStream,
Assert("Incorrect query status in onQueryStarted") {
val status = listener.startStatus
assert(status != null)
assert(status.active == true)
assert(status.sourceStatuses.size === 1)
assert(status.sourceStatuses(0).description.contains("Memory"))
// The source and sink offsets must be None, as this callback is invoked before any
// batches have started
assert(status.sourceStatuses(0).offset === None)
assert(status.sinkStatus.offset === None)
// No progress events or termination events
assert(listener.progressStatuses.isEmpty)
assert(listener.terminationStatus === null)
},
AddDataMemory(input, Seq(1, 2, 3)),
CheckAnswer(1, 2, 3),
Assert("Incorrect query status in onQueryProgress") {
eventually(Timeout(streamingTimeout)) {
// There should be only one progress event as the batch has been processed
assert(listener.progressStatuses.size === 1)
val status = listener.progressStatuses.peek()
assert(status != null)
assert(status.active == true)
assert(status.sourceStatuses(0).offset === Some(LongOffset(0)))
assert(status.sinkStatus.offset === Some(CompositeOffset.fill(LongOffset(0))))
// No termination events
assert(listener.terminationStatus === null)
}
},
StopStream,
Assert("Incorrect query status in onQueryTerminated") {
eventually(Timeout(streamingTimeout)) {
val status = listener.terminationStatus
assert(status != null)
assert(status.active === false) // must be inactive by the time onQueryTerminated is called
assert(status.sourceStatuses(0).offset === Some(LongOffset(0)))
assert(status.sinkStatus.offset === Some(CompositeOffset.fill(LongOffset(0))))
}
listener.checkAsyncErrors()
}
)
}
}
test("adding and removing listener") {
def isListenerActive(listener: QueryStatusCollector): Boolean = {
listener.reset()
testStream(MemoryStream[Int].toDS)(
StartStream,
StopStream
)
listener.startStatus != null
}
try {
val listener1 = new QueryStatusCollector
val listener2 = new QueryStatusCollector
sqlContext.streams.addListener(listener1)
assert(isListenerActive(listener1) === true)
assert(isListenerActive(listener2) === false)
sqlContext.streams.addListener(listener2)
assert(isListenerActive(listener1) === true)
assert(isListenerActive(listener2) === true)
sqlContext.streams.removeListener(listener1)
assert(isListenerActive(listener1) === false)
assert(isListenerActive(listener2) === true)
} finally {
addedListeners.foreach(sqlContext.streams.removeListener)
}
}
test("event ordering") {
val listener = new QueryStatusCollector
withListenerAdded(listener) {
for (i <- 1 to 100) {
listener.reset()
require(listener.startStatus === null)
testStream(MemoryStream[Int].toDS)(
StartStream,
Assert(listener.startStatus !== null, "onQueryStarted not called before query returned"),
StopStream,
Assert { listener.checkAsyncErrors() }
)
}
}
}
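/** Registers `listener`, runs `body`, and always deregisters the listener afterwards. */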
private def withListenerAdded(listener: ContinuousQueryListener)(body: => Unit): Unit = {
try {
failAfter(1 minute) {
sqlContext.streams.addListener(listener)
body
}
} finally {
sqlContext.streams.removeListener(listener)
}
}
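/** Returns the listeners currently registered with the manager, read via its private listener bus. */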
private def addedListeners(): Array[ContinuousQueryListener] = {
val listenerBusMethod =
PrivateMethod[ContinuousQueryListenerBus]('listenerBus)
val listenerBus = sqlContext.streams invokePrivate listenerBusMethod()
listenerBus.listeners.toArray.map(_.asInstanceOf[ContinuousQueryListener])
}
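/**
 * Collects the query status seen in each listener callback so tests can assert on lifecycle
 * ordering; failures inside the asynchronous callbacks are surfaced through checkAsyncErrors().
 */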
class QueryStatusCollector extends ContinuousQueryListener {
private val asyncTestWaiter = new Waiter // to catch errors in the async listener events
@volatile var startStatus: QueryStatus = null
@volatile var terminationStatus: QueryStatus = null
val progressStatuses = new ConcurrentLinkedQueue[QueryStatus]
def reset(): Unit = {
startStatus = null
terminationStatus = null
progressStatuses.clear()
// To reset the waiter
try asyncTestWaiter.await(timeout(1 milliseconds)) catch {
case NonFatal(e) =>
}
}
def checkAsyncErrors(): Unit = {
asyncTestWaiter.await(timeout(streamingTimeout))
}
override def onQueryStarted(queryStarted: QueryStarted): Unit = {
asyncTestWaiter {
startStatus = QueryStatus(queryStarted.query)
}
}
override def onQueryProgress(queryProgress: QueryProgress): Unit = {
asyncTestWaiter {
assert(startStatus != null, "onQueryProgress called before onQueryStarted")
progressStatuses.add(QueryStatus(queryProgress.query))
}
}
override def onQueryTerminated(queryTerminated: QueryTerminated): Unit = {
asyncTestWaiter {
assert(startStatus != null, "onQueryTerminated called before onQueryStarted")
terminationStatus = QueryStatus(queryTerminated.query)
}
asyncTestWaiter.dismiss()
}
}
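/** Immutable snapshot of a query's externally visible state, captured at listener-callback time. */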
case class QueryStatus(
active: Boolean,
exception: Option[Exception],
sourceStatuses: Array[SourceStatus],
sinkStatus: SinkStatus)
object QueryStatus {
def apply(query: ContinuousQuery): QueryStatus = {
QueryStatus(query.isActive, query.exception, query.sourceStatuses, query.sinkStatus)
}
}
}