From 0902e20288366db6270f3a444e66114b1b63a3e2 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 10 Feb 2016 16:45:06 -0800 Subject: [PATCH] [SPARK-13146][SQL] Management API for continuous queries ### Management API for Continuous Queries **API for getting status of each query** - Whether active or not - Unique name of each query - Status of the sources and sinks - Exceptions **API for managing each query** - Immediately stop an active query - Waiting for a query to be terminated, correctly or with error **API for managing multiple queries** - Listing all active queries - Getting an active query by name - Waiting for any one of the active queries to be terminated **API for listening to query life cycle events** - ContinuousQueryListener API for query start, progress and termination events. Author: Tathagata Das Closes #11030 from tdas/streaming-df-management-api. --- .../apache/spark/sql/ContinuousQuery.scala | 72 ++++- .../spark/sql/ContinuousQueryException.scala | 54 ++++ .../spark/sql/ContinuousQueryManager.scala | 193 +++++++++++ .../apache/spark/sql/DataFrameWriter.scala | 14 +- .../org/apache/spark/sql/SQLContext.scala | 12 + .../org/apache/spark/sql/SinkStatus.scala | 34 ++ .../org/apache/spark/sql/SourceStatus.scala | 34 ++ .../ContinuousQueryListenerBus.scala | 82 +++++ .../execution/streaming/StreamExecution.scala | 217 ++++++++++--- .../execution/streaming/StreamProgress.scala | 4 + .../sql/execution/streaming/memory.scala | 20 +- .../sql/util/ContinuousQueryListener.scala | 67 ++++ .../org/apache/spark/sql/StreamTest.scala | 262 ++++++++++++--- .../ContinuousQueryManagerSuite.scala | 306 ++++++++++++++++++ .../sql/streaming/ContinuousQuerySuite.scala | 139 ++++++++ .../DataFrameReaderWriterSuite.scala | 69 +++- .../util/ContinuousQueryListenerSuite.scala | 222 +++++++++++++ 17 files changed, 1686 insertions(+), 115 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryException.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/SinkStatus.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/SourceStatus.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousQueryListenerBus.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQuerySuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala index 1c2c0290fc..eb69804c39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala @@ -17,14 +17,84 @@ package org.apache.spark.sql +import org.apache.spark.annotation.Experimental + /** + * :: Experimental :: * A handle to a query that is executing continuously in the background as new data arrives. + * All these methods are thread-safe. + * @since 2.0.0 */ +@Experimental trait ContinuousQuery { /** - * Stops the execution of this query if it is running. This method blocks until the threads + * Returns the name of the query. 
+ * @since 2.0.0 + */ + def name: String + + /** + * Returns the SQLContext associated with `this` query + * @since 2.0.0 + */ + def sqlContext: SQLContext + + /** + * Whether the query is currently active or not + * @since 2.0.0 + */ + def isActive: Boolean + + /** + * Returns the [[ContinuousQueryException]] if the query was terminated by an exception. + * @since 2.0.0 + */ + def exception: Option[ContinuousQueryException] + + /** + * Returns current status of all the sources. + * @since 2.0.0 + */ + def sourceStatuses: Array[SourceStatus] + + /** Returns current status of the sink. */ + def sinkStatus: SinkStatus + + /** + * Waits for the termination of `this` query, either by `query.stop()` or by an exception. + * If the query has terminated with an exception, then the exception will be thrown. + * + * If the query has terminated, then all subsequent calls to this method will either return + * immediately (if the query was terminated by `stop()`), or throw the exception + * immediately (if the query has terminated with exception). + * + * @throws ContinuousQueryException, if `this` query has terminated with an exception. + * + * @since 2.0.0 + */ + def awaitTermination(): Unit + + /** + * Waits for the termination of `this` query, either by `query.stop()` or by an exception. + * If the query has terminated with an exception, then the exception will be throw. + * Otherwise, it returns whether the query has terminated or not within the `timeoutMs` + * milliseconds. + * + * If the query has terminated, then all subsequent calls to this method will either return + * `true` immediately (if the query was terminated by `stop()`), or throw the exception + * immediately (if the query has terminated with exception). + * + * @throws ContinuousQueryException, if `this` query has terminated with an exception + * + * @since 2.0.0 + */ + def awaitTermination(timeoutMs: Long): Boolean + + /** + * Stops the execution of this query if it is running. This method blocks until the threads * performing execution has stopped. + * @since 2.0.0 */ def stop(): Unit } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryException.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryException.scala new file mode 100644 index 0000000000..67dd9dbe23 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryException.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.execution.streaming.{Offset, StreamExecution} + +/** + * :: Experimental :: + * Exception that stopped a [[ContinuousQuery]]. 
+ * @param query       Query that caused the exception
+ * @param message     Message of this exception
+ * @param cause       Internal cause of this exception
+ * @param startOffset Starting offset (if known) of the range of data in which the exception occurred
+ * @param endOffset   Ending offset (if known) of the range of data in which the exception occurred
+ * @since 2.0.0
+ */
+@Experimental
+class ContinuousQueryException private[sql](
+    val query: ContinuousQuery,
+    val message: String,
+    val cause: Throwable,
+    val startOffset: Option[Offset] = None,
+    val endOffset: Option[Offset] = None
+  ) extends Exception(message, cause) {
+
+  /** Time when the exception occurred */
+  val time: Long = System.currentTimeMillis
+
+  override def toString(): String = {
+    val causeStr =
+      s"${cause.getMessage} ${cause.getStackTrace.take(10).mkString("", "\n|\t", "\n")}"
+    s"""
+       |$causeStr
+       |
+       |${query.asInstanceOf[StreamExecution].toDebugString}
+       """.stripMargin
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
new file mode 100644
index 0000000000..13142d0e61
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.collection.mutable
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.execution.streaming.{ContinuousQueryListenerBus, Sink, StreamExecution}
+import org.apache.spark.sql.util.ContinuousQueryListener
+
+/**
+ * :: Experimental ::
+ * A class to manage all the [[org.apache.spark.sql.ContinuousQuery ContinuousQueries]] active
+ * on a [[SQLContext]].
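+ *
+ * An illustrative usage sketch (the query name `"myQuery"` and the `sqlContext` handle are
+ * assumed to already exist; they are not defined by this patch):
+ * {{{
+ *   val queries = sqlContext.streams.active          // all queries active on this context
+ *   val query = sqlContext.streams.get("myQuery")    // look up an active query by its name
+ *   query.stop()                                     // stop it and wait for its thread to exit
+ * }}}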
+ * + * @since 2.0.0 + */ +@Experimental +class ContinuousQueryManager(sqlContext: SQLContext) { + + private val listenerBus = new ContinuousQueryListenerBus(sqlContext.sparkContext.listenerBus) + private val activeQueries = new mutable.HashMap[String, ContinuousQuery] + private val activeQueriesLock = new Object + private val awaitTerminationLock = new Object + + private var lastTerminatedQuery: ContinuousQuery = null + + /** + * Returns a list of active queries associated with this SQLContext + * + * @since 2.0.0 + */ + def active: Array[ContinuousQuery] = activeQueriesLock.synchronized { + activeQueries.values.toArray + } + + /** + * Returns an active query from this SQLContext or throws exception if bad name + * + * @since 2.0.0 + */ + def get(name: String): ContinuousQuery = activeQueriesLock.synchronized { + activeQueries.get(name).getOrElse { + throw new IllegalArgumentException(s"There is no active query with name $name") + } + } + + /** + * Wait until any of the queries on the associated SQLContext has terminated since the + * creation of the context, or since `resetTerminated()` was called. If any query was terminated + * with an exception, then the exception will be thrown. + * + * If a query has terminated, then subsequent calls to `awaitAnyTermination()` will either + * return immediately (if the query was terminated by `query.stop()`), + * or throw the exception immediately (if the query was terminated with exception). Use + * `resetTerminated()` to clear past terminations and wait for new terminations. + * + * In the case where multiple queries have terminated since `resetTermination()` was called, + * if any query has terminated with exception, then `awaitAnyTermination()` will + * throw any of the exception. For correctly documenting exceptions across multiple queries, + * users need to stop all of them after any of them terminates with exception, and then check the + * `query.exception()` for each query. + * + * @throws ContinuousQueryException, if any query has terminated with an exception + * + * @since 2.0.0 + */ + def awaitAnyTermination(): Unit = { + awaitTerminationLock.synchronized { + while (lastTerminatedQuery == null) { + awaitTerminationLock.wait(10) + } + if (lastTerminatedQuery != null && lastTerminatedQuery.exception.nonEmpty) { + throw lastTerminatedQuery.exception.get + } + } + } + + /** + * Wait until any of the queries on the associated SQLContext has terminated since the + * creation of the context, or since `resetTerminated()` was called. Returns whether any query + * has terminated or not (multiple may have terminated). If any query has terminated with an + * exception, then the exception will be thrown. + * + * If a query has terminated, then subsequent calls to `awaitAnyTermination()` will either + * return `true` immediately (if the query was terminated by `query.stop()`), + * or throw the exception immediately (if the query was terminated with exception). Use + * `resetTerminated()` to clear past terminations and wait for new terminations. + * + * In the case where multiple queries have terminated since `resetTermination()` was called, + * if any query has terminated with exception, then `awaitAnyTermination()` will + * throw any of the exception. For correctly documenting exceptions across multiple queries, + * users need to stop all of them after any of them terminates with exception, and then check the + * `query.exception()` for each query. 
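+ *
+ * An illustrative sketch of this timeout variant (the 5000 ms value is only an example):
+ * {{{
+ *   sqlContext.streams.resetTerminated()
+ *   // true if some query terminated within 5 seconds, false if the call timed out
+ *   val anyTerminated = sqlContext.streams.awaitAnyTermination(5000)
+ * }}}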
+ * + * @throws ContinuousQueryException, if any query has terminated with an exception + * + * @since 2.0.0 + */ + def awaitAnyTermination(timeoutMs: Long): Boolean = { + + val startTime = System.currentTimeMillis + def isTimedout = System.currentTimeMillis - startTime >= timeoutMs + + awaitTerminationLock.synchronized { + while (!isTimedout && lastTerminatedQuery == null) { + awaitTerminationLock.wait(10) + } + if (lastTerminatedQuery != null && lastTerminatedQuery.exception.nonEmpty) { + throw lastTerminatedQuery.exception.get + } + lastTerminatedQuery != null + } + } + + /** + * Forget about past terminated queries so that `awaitAnyTermination()` can be used again to + * wait for new terminations. + * + * @since 2.0.0 + */ + def resetTerminated(): Unit = { + awaitTerminationLock.synchronized { + lastTerminatedQuery = null + } + } + + /** + * Register a [[ContinuousQueryListener]] to receive up-calls for life cycle events of + * [[org.apache.spark.sql.ContinuousQuery ContinuousQueries]]. + * + * @since 2.0.0 + */ + def addListener(listener: ContinuousQueryListener): Unit = { + listenerBus.addListener(listener) + } + + /** + * Deregister a [[ContinuousQueryListener]]. + * + * @since 2.0.0 + */ + def removeListener(listener: ContinuousQueryListener): Unit = { + listenerBus.removeListener(listener) + } + + /** Post a listener event */ + private[sql] def postListenerEvent(event: ContinuousQueryListener.Event): Unit = { + listenerBus.post(event) + } + + /** Start a query */ + private[sql] def startQuery(name: String, df: DataFrame, sink: Sink): ContinuousQuery = { + activeQueriesLock.synchronized { + if (activeQueries.contains(name)) { + throw new IllegalArgumentException( + s"Cannot start query with name $name as a query with that name is already active") + } + val query = new StreamExecution(sqlContext, name, df.logicalPlan, sink) + query.start() + activeQueries.put(name, query) + query + } + } + + /** Notify (by the ContinuousQuery) that the query has been terminated */ + private[sql] def notifyQueryTermination(terminatedQuery: ContinuousQuery): Unit = { + activeQueriesLock.synchronized { + activeQueries -= terminatedQuery.name + } + awaitTerminationLock.synchronized { + if (lastTerminatedQuery == null || terminatedQuery.exception.nonEmpty) { + lastTerminatedQuery = terminatedQuery + } + awaitTerminationLock.notifyAll() + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 8060198968..d6bdd3d825 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -205,6 +205,17 @@ final class DataFrameWriter private[sql](df: DataFrame) { df) } + /** + * Specifies the name of the [[ContinuousQuery]] that can be started with `stream()`. + * This name must be unique among all the currently active queries in the associated SQLContext. + * + * @since 2.0.0 + */ + def queryName(queryName: String): DataFrameWriter = { + this.extraOptions += ("queryName" -> queryName) + this + } + /** * Starts the execution of the streaming query, which will continually output results to the given * path as new data arrives. 
The returned [[ContinuousQuery]] object can be used to interact with @@ -230,7 +241,8 @@ final class DataFrameWriter private[sql](df: DataFrame) { extraOptions.toMap, normalizedParCols.getOrElse(Nil)) - new StreamExecution(df.sqlContext, df.logicalPlan, sink) + df.sqlContext.continuousQueryManager.startQuery( + extraOptions.getOrElse("queryName", StreamExecution.nextName), df, sink) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 1661fdbec5..050a1031c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -181,6 +181,8 @@ class SQLContext private[sql]( @transient lazy val listenerManager: ExecutionListenerManager = new ExecutionListenerManager + protected[sql] lazy val continuousQueryManager = new ContinuousQueryManager(this) + @transient protected[sql] lazy val catalog: Catalog = new SimpleCatalog(conf) @@ -835,6 +837,16 @@ class SQLContext private[sql]( DataFrame(this, ShowTablesCommand(Some(databaseName))) } + /** + * Returns a [[ContinuousQueryManager]] that allows managing all the + * [[org.apache.spark.sql.ContinuousQuery ContinuousQueries]] active on `this` context. + * + * @since 2.0.0 + */ + def streams: ContinuousQueryManager = { + continuousQueryManager + } + /** * Returns the names of tables in the current database as an array. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SinkStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/SinkStatus.scala new file mode 100644 index 0000000000..ce21451b2c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/SinkStatus.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.execution.streaming.{Offset, Sink} + +/** + * :: Experimental :: + * Status and metrics of a streaming [[Sink]]. + * + * @param description Description of the source corresponding to this status + * @param offset Current offset up to which data has been written by the sink + * @since 2.0.0 + */ +@Experimental +class SinkStatus private[sql]( + val description: String, + val offset: Option[Offset]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SourceStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/SourceStatus.scala new file mode 100644 index 0000000000..2479e67e36 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/SourceStatus.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.execution.streaming.{Offset, Source} + +/** + * :: Experimental :: + * Status and metrics of a streaming [[Source]]. + * + * @param description Description of the source corresponding to this status + * @param offset Current offset of the source, if known + * @since 2.0.0 + */ +@Experimental +class SourceStatus private[sql] ( + val description: String, + val offset: Option[Offset]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousQueryListenerBus.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousQueryListenerBus.scala new file mode 100644 index 0000000000..b1d24b6cfc --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousQueryListenerBus.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerEvent} +import org.apache.spark.sql.util.ContinuousQueryListener +import org.apache.spark.sql.util.ContinuousQueryListener._ +import org.apache.spark.util.ListenerBus + +/** + * A bus to forward events to [[ContinuousQueryListener]]s. This one will wrap received + * [[ContinuousQueryListener.Event]]s as WrappedContinuousQueryListenerEvents and send them to the + * Spark listener bus. It also registers itself with Spark listener bus, so that it can receive + * WrappedContinuousQueryListenerEvents, unwrap them as ContinuousQueryListener.Events and + * dispatch them to ContinuousQueryListener. + */ +class ContinuousQueryListenerBus(sparkListenerBus: LiveListenerBus) + extends SparkListener with ListenerBus[ContinuousQueryListener, ContinuousQueryListener.Event] { + + sparkListenerBus.addListener(this) + + /** + * Post a ContinuousQueryListener event to the Spark listener bus asynchronously. This event will + * be dispatched to all ContinuousQueryListener in the thread of the Spark listener bus. 
+ */ + def post(event: ContinuousQueryListener.Event) { + event match { + case s: QueryStarted => + postToAll(s) + case _ => + sparkListenerBus.post(new WrappedContinuousQueryListenerEvent(event)) + } + } + + override def onOtherEvent(event: SparkListenerEvent): Unit = { + event match { + case WrappedContinuousQueryListenerEvent(e) => + postToAll(e) + case _ => + } + } + + override protected def doPostEvent( + listener: ContinuousQueryListener, + event: ContinuousQueryListener.Event): Unit = { + event match { + case queryStarted: QueryStarted => + listener.onQueryStarted(queryStarted) + case queryProgress: QueryProgress => + listener.onQueryProgress(queryProgress) + case queryTerminated: QueryTerminated => + listener.onQueryTerminated(queryTerminated) + case _ => + } + } + + /** + * Wrapper for StreamingListenerEvent as SparkListenerEvent so that it can be posted to Spark + * listener bus. + */ + private case class WrappedContinuousQueryListenerEvent( + streamingListenerEvent: ContinuousQueryListener.Event) extends SparkListenerEvent { + + // Do not log streaming events in event log as history server does not support these events. + protected[spark] override def logEvent: Boolean = false + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index ebebb82971..bc7c520930 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -17,16 +17,20 @@ package org.apache.spark.sql.execution.streaming -import java.lang.Thread.UncaughtExceptionHandler +import java.util.concurrent.{CountDownLatch, TimeUnit} +import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.ArrayBuffer +import scala.util.control.NonFatal import org.apache.spark.Logging -import org.apache.spark.sql.{ContinuousQuery, DataFrame, SQLContext} +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.util.ContinuousQueryListener +import org.apache.spark.sql.util.ContinuousQueryListener._ /** * Manages the execution of a streaming Spark SQL query that is occurring in a separate thread. @@ -35,15 +39,15 @@ import org.apache.spark.sql.execution.QueryExecution * and the results are committed transactionally to the given [[Sink]]. */ class StreamExecution( - sqlContext: SQLContext, + val sqlContext: SQLContext, + override val name: String, private[sql] val logicalPlan: LogicalPlan, val sink: Sink) extends ContinuousQuery with Logging { /** An monitor used to wait/notify when batches complete. */ private val awaitBatchLock = new Object - - @volatile - private var batchRun = false + private val startLatch = new CountDownLatch(1) + private val terminationLatch = new CountDownLatch(1) /** Minimum amount of time in between the start of each batch. */ private val minBatchTime = 10 @@ -55,9 +59,92 @@ class StreamExecution( private val sources = logicalPlan.collect { case s: StreamingRelation => s.source } - // Start the execution at the current offsets stored in the sink. (i.e. avoid reprocessing data - // that we have already processed). 
- { + /** Defines the internal state of execution */ + @volatile + private var state: State = INITIALIZED + + @volatile + private[sql] var lastExecution: QueryExecution = null + + @volatile + private[sql] var streamDeathCause: ContinuousQueryException = null + + /** The thread that runs the micro-batches of this stream. */ + private[sql] val microBatchThread = new Thread(s"stream execution thread for $name") { + override def run(): Unit = { runBatches() } + } + + /** Whether the query is currently active or not */ + override def isActive: Boolean = state == ACTIVE + + /** Returns current status of all the sources. */ + override def sourceStatuses: Array[SourceStatus] = { + sources.map(s => new SourceStatus(s.toString, streamProgress.get(s))).toArray + } + + /** Returns current status of the sink. */ + override def sinkStatus: SinkStatus = new SinkStatus(sink.toString, sink.currentOffset) + + /** Returns the [[ContinuousQueryException]] if the query was terminated by an exception. */ + override def exception: Option[ContinuousQueryException] = Option(streamDeathCause) + + /** + * Starts the execution. This returns only after the thread has started and [[QueryStarted]] event + * has been posted to all the listeners. + */ + private[sql] def start(): Unit = { + microBatchThread.setDaemon(true) + microBatchThread.start() + startLatch.await() // Wait until thread started and QueryStart event has been posted + } + + /** + * Repeatedly attempts to run batches as data arrives. + * + * Note that this method ensures that [[QueryStarted]] and [[QueryTerminated]] events are posted + * so that listeners are guaranteed to get former event before the latter. Furthermore, this + * method also ensures that [[QueryStarted]] event is posted before the `start()` method returns. + */ + private def runBatches(): Unit = { + try { + // Mark ACTIVE and then post the event. QueryStarted event is synchronously sent to listeners, + // so must mark this as ACTIVE first. + state = ACTIVE + postEvent(new QueryStarted(this)) // Assumption: Does not throw exception. + + // Unblock starting thread + startLatch.countDown() + + // While active, repeatedly attempt to run batches. + SQLContext.setActive(sqlContext) + populateStartOffsets() + logInfo(s"Stream running at $streamProgress") + while (isActive) { + attemptBatch() + Thread.sleep(minBatchTime) // TODO: Could be tighter + } + } catch { + case _: InterruptedException if state == TERMINATED => // interrupted by stop() + case NonFatal(e) => + streamDeathCause = new ContinuousQueryException( + this, + s"Query $name terminated with exception: ${e.getMessage}", + e, + Some(streamProgress.toCompositeOffset(sources))) + logError(s"Query $name terminated with error", e) + } finally { + state = TERMINATED + sqlContext.streams.notifyQueryTermination(StreamExecution.this) + postEvent(new QueryTerminated(this)) + terminationLatch.countDown() + } + } + + /** + * Populate the start offsets to start the execution at the current offsets stored in the sink + * (i.e. avoid reprocessing data that we have already processed). + */ + private def populateStartOffsets(): Unit = { sink.currentOffset match { case Some(c: CompositeOffset) => val storedProgress = c.offsets @@ -74,37 +161,8 @@ class StreamExecution( } } - logInfo(s"Stream running at $streamProgress") - - /** When false, signals to the microBatchThread that it should stop running. */ - @volatile private var shouldRun = true - - /** The thread that runs the micro-batches of this stream. 
*/ - private[sql] val microBatchThread = new Thread("stream execution thread") { - override def run(): Unit = { - SQLContext.setActive(sqlContext) - while (shouldRun) { - attemptBatch() - Thread.sleep(minBatchTime) // TODO: Could be tighter - } - } - } - microBatchThread.setDaemon(true) - microBatchThread.setUncaughtExceptionHandler( - new UncaughtExceptionHandler { - override def uncaughtException(t: Thread, e: Throwable): Unit = { - streamDeathCause = e - } - }) - microBatchThread.start() - - @volatile - private[sql] var lastExecution: QueryExecution = null - @volatile - private[sql] var streamDeathCause: Throwable = null - /** - * Checks to see if any new data is present in any of the sources. When new data is available, + * Checks to see if any new data is present in any of the sources. When new data is available, * a batch is executed and passed to the sink, updating the currentOffsets. */ private def attemptBatch(): Unit = { @@ -150,36 +208,43 @@ class StreamExecution( streamProgress.synchronized { // Update the offsets and calculate a new composite offset newOffsets.foreach(streamProgress.update) - val newStreamProgress = logicalPlan.collect { - case StreamingRelation(source, _) => streamProgress.get(source) - } - val batchOffset = CompositeOffset(newStreamProgress) // Construct the batch and send it to the sink. + val batchOffset = streamProgress.toCompositeOffset(sources) val nextBatch = new Batch(batchOffset, new DataFrame(sqlContext, newPlan)) sink.addBatch(nextBatch) } - batchRun = true awaitBatchLock.synchronized { // Wake up any threads that are waiting for the stream to progress. awaitBatchLock.notifyAll() } val batchTime = (System.nanoTime() - startTime).toDouble / 1000000 - logInfo(s"Compete up to $newOffsets in ${batchTime}ms") + logInfo(s"Completed up to $newOffsets in ${batchTime}ms") + postEvent(new QueryProgress(this)) } logDebug(s"Waiting for data, current: $streamProgress") } + private def postEvent(event: ContinuousQueryListener.Event) { + sqlContext.streams.postListenerEvent(event) + } + /** * Signals to the thread executing micro-batches that it should stop running after the next * batch. This method blocks until the thread stops running. 
*/ - def stop(): Unit = { - shouldRun = false - if (microBatchThread.isAlive) { microBatchThread.join() } + override def stop(): Unit = { + // Set the state to TERMINATED so that the batching thread knows that it was interrupted + // intentionally + state = TERMINATED + if (microBatchThread.isAlive) { + microBatchThread.interrupt() + microBatchThread.join() + } + logInfo(s"Query $name was stopped") } /** @@ -198,14 +263,60 @@ class StreamExecution( logDebug(s"Unblocked at $newOffset for $source") } - override def toString: String = + override def awaitTermination(): Unit = { + if (state == INITIALIZED) { + throw new IllegalStateException("Cannot wait for termination on a query that has not started") + } + terminationLatch.await() + if (streamDeathCause != null) { + throw streamDeathCause + } + } + + override def awaitTermination(timeoutMs: Long): Boolean = { + if (state == INITIALIZED) { + throw new IllegalStateException("Cannot wait for termination on a query that has not started") + } + require(timeoutMs > 0, "Timeout has to be positive") + terminationLatch.await(timeoutMs, TimeUnit.MILLISECONDS) + if (streamDeathCause != null) { + throw streamDeathCause + } else { + !isActive + } + } + + override def toString: String = { + s"Continuous Query - $name [state = $state]" + } + + def toDebugString: String = { + val deathCauseStr = if (streamDeathCause != null) { + "Error:\n" + stackTraceToString(streamDeathCause.cause) + } else "" s""" - |=== Streaming Query === - |CurrentOffsets: $streamProgress - |Thread State: ${microBatchThread.getState} - |${if (streamDeathCause != null) stackTraceToString(streamDeathCause) else ""} + |=== Continuous Query === + |Name: $name + |Current Offsets: $streamProgress | + |Current State: $state + |Thread State: ${microBatchThread.getState} + | + |Logical Plan: |$logicalPlan + | + |$deathCauseStr """.stripMargin + } + + trait State + case object INITIALIZED extends State + case object ACTIVE extends State + case object TERMINATED extends State } +private[sql] object StreamExecution { + private val nextId = new AtomicInteger() + + def nextName: String = s"query-${nextId.getAndIncrement}" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala index 0ded1d7152..d45b9bd983 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala @@ -55,6 +55,10 @@ class StreamProgress { copied } + private[sql] def toCompositeOffset(source: Seq[Source]): CompositeOffset = { + CompositeOffset(source.map(get)) + } + override def toString: String = currentOffsets.map { case (k, v) => s"$k: $v"}.mkString("{", ",", "}") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index e6a0842936..8124df15af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -20,11 +20,12 @@ package org.apache.spark.sql.execution.streaming import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.ArrayBuffer +import scala.util.control.NonFatal import org.apache.spark.{Logging, SparkEnv} import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row, SQLContext} -import 
org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.{encoderFor, RowEncoder} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types.StructType object MemoryStream { @@ -46,14 +47,13 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) protected val logicalPlan = StreamingRelation(this) protected val output = logicalPlan.output protected val batches = new ArrayBuffer[Dataset[A]] + protected var currentOffset: LongOffset = new LongOffset(-1) protected def blockManager = SparkEnv.get.blockManager def schema: StructType = encoder.schema - def getCurrentOffset: Offset = currentOffset - def toDS()(implicit sqlContext: SQLContext): Dataset[A] = { new Dataset(sqlContext, logicalPlan) } @@ -62,6 +62,10 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) new DataFrame(sqlContext, logicalPlan) } + def addData(data: A*): Offset = { + addData(data.toTraversable) + } + def addData(data: TraversableOnce[A]): Offset = { import sqlContext.implicits._ this.synchronized { @@ -110,6 +114,7 @@ class MemorySink(schema: StructType) extends Sink with Logging { } override def addBatch(nextBatch: Batch): Unit = synchronized { + nextBatch.data.collect() // 'compute' the batch's data and record the batch batches.append(nextBatch) } @@ -131,8 +136,13 @@ class MemorySink(schema: StructType) extends Sink with Logging { batches.dropRight(num) } - override def toString: String = synchronized { - batches.map(b => s"${b.end}: ${b.data.collect().mkString(" ")}").mkString("\n") + def toDebugString: String = synchronized { + batches.map { b => + val dataStr = try b.data.collect().mkString(" ") catch { + case NonFatal(e) => "[Error converting to string]" + } + s"${b.end}: $dataStr" + }.mkString("\n") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala new file mode 100644 index 0000000000..73c78d1b62 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.util + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.ContinuousQuery +import org.apache.spark.sql.util.ContinuousQueryListener._ + +/** + * :: Experimental :: + * Interface for listening to events related to [[ContinuousQuery ContinuousQueries]]. + * @note The methods are not thread-safe as they may be called from different threads. + */ +@Experimental +abstract class ContinuousQueryListener { + + /** + * Called when a query is started. 
+ * @note This is called synchronously with
+ * [[org.apache.spark.sql.DataFrameWriter `DataFrameWriter.stream()`]],
+ * that is, `onQueryStarted` will be called on all listeners before `DataFrameWriter.stream()`
+ * returns the corresponding [[ContinuousQuery]].
+ */
+  def onQueryStarted(queryStarted: QueryStarted)
+
+  /** Called when there is some status update (ingestion rate updated, etc.). */
+  def onQueryProgress(queryProgress: QueryProgress)
+
+  /** Called when a query is stopped, with or without error */
+  def onQueryTerminated(queryTerminated: QueryTerminated)
+}
+
+
+/**
+ * :: Experimental ::
+ * Companion object of [[ContinuousQueryListener]] that defines the listener events.
+ */
+@Experimental
+object ContinuousQueryListener {
+
+  /** Base type of [[ContinuousQueryListener]] events */
+  trait Event
+
+  /** Event representing the start of a query */
+  class QueryStarted private[sql](val query: ContinuousQuery) extends Event
+
+  /** Event representing any progress updates in a query */
+  class QueryProgress private[sql](val query: ContinuousQuery) extends Event
+
+  /** Event representing the termination of a query */
+  class QueryTerminated private[sql](val query: ContinuousQuery) extends Event
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
index 7e388ea602..62710e72fb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
@@ -21,9 +21,16 @@ import java.lang.Thread.UncaughtExceptionHandler
 
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
+import scala.language.experimental.macros
+import scala.reflect.ClassTag
 import scala.util.Random
+import scala.util.control.NonFatal
 
-import org.scalatest.concurrent.Timeouts
+import org.scalatest.Assertions
+import org.scalatest.concurrent.{Eventually, Timeouts}
+import org.scalatest.concurrent.PatienceConfiguration.Timeout
+import org.scalatest.exceptions.TestFailedDueToTimeoutException
+import org.scalatest.time.Span
 import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder}
@@ -64,7 +71,7 @@ trait StreamTest extends QueryTest with Timeouts {
   }
 
   /** How long to wait for an active stream to catch up when checking a result. */
-  val streamingTimout = 10.seconds
+  val streamingTimeout = 10.seconds
 
   /** A trait for actions that can be performed while testing a streaming DataFrame. */
   trait StreamAction
@@ -128,7 +135,38 @@ trait StreamTest extends QueryTest with Timeouts {
   case object StartStream extends StreamAction
 
   /** Signals that a failure is expected and should not kill the test.
*/ - case object ExpectFailure extends StreamAction + case class ExpectFailure[T <: Throwable : ClassTag]() extends StreamAction { + val causeClass: Class[T] = implicitly[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]] + override def toString(): String = s"ExpectFailure[${causeClass.getCanonicalName}]" + } + + /** Assert that a body is true */ + class Assert(condition: => Boolean, val message: String = "") extends StreamAction { + def run(): Unit = { Assertions.assert(condition) } + override def toString: String = s"Assert(, $message)" + } + + object Assert { + def apply(condition: => Boolean, message: String = ""): Assert = new Assert(condition, message) + def apply(message: String)(body: => Unit): Assert = new Assert( { body; true }, message) + def apply(body: => Unit): Assert = new Assert( { body; true }, "") + } + + /** Assert that a condition on the active query is true */ + class AssertOnQuery(val condition: StreamExecution => Boolean, val message: String) + extends StreamAction { + override def toString: String = s"AssertOnQuery(, $message)" + } + + object AssertOnQuery { + def apply(condition: StreamExecution => Boolean, message: String = ""): AssertOnQuery = { + new AssertOnQuery(condition, message) + } + + def apply(message: String)(condition: StreamExecution => Boolean): AssertOnQuery = { + new AssertOnQuery(condition, message) + } + } /** A helper for running actions on a Streaming Dataset. See `checkAnswer(DataFrame)`. */ def testStream(stream: Dataset[_])(actions: StreamAction*): Unit = @@ -145,6 +183,7 @@ trait StreamTest extends QueryTest with Timeouts { var pos = 0 var currentPlan: LogicalPlan = stream.logicalPlan var currentStream: StreamExecution = null + var lastStream: StreamExecution = null val awaiting = new mutable.HashMap[Source, Offset]() val sink = new MemorySink(stream.schema) @@ -170,6 +209,7 @@ trait StreamTest extends QueryTest with Timeouts { def threadState = if (currentStream != null && currentStream.microBatchThread.isAlive) "alive" else "dead" + def testState = s""" |== Progress == @@ -181,16 +221,49 @@ trait StreamTest extends QueryTest with Timeouts { |${if (streamDeathCause != null) stackTraceToString(streamDeathCause) else ""} | |== Sink == - |$sink + |${sink.toDebugString} | |== Plan == |${if (currentStream != null) currentStream.lastExecution else ""} - """ + """.stripMargin - def checkState(check: Boolean, error: String) = if (!check) { + def verify(condition: => Boolean, message: String): Unit = { + try { + Assertions.assert(condition) + } catch { + case NonFatal(e) => + failTest(message, e) + } + } + + def eventually[T](message: String)(func: => T): T = { + try { + Eventually.eventually(Timeout(streamingTimeout)) { + func + } + } catch { + case NonFatal(e) => + failTest(message, e) + } + } + + def failTest(message: String, cause: Throwable = null) = { + + // Recursively pretty print a exception with truncated stacktrace and internal cause + def exceptionToString(e: Throwable, prefix: String = ""): String = { + val base = s"$prefix${e.getMessage}" + + e.getStackTrace.take(10).mkString(s"\n$prefix", s"\n$prefix\t", "\n") + if (e.getCause != null) { + base + s"\n$prefix\tCaused by: " + exceptionToString(e.getCause, s"$prefix\t") + } else { + base + } + } + val c = Option(cause).map(exceptionToString(_)) + val m = if (message != null && message.size > 0) Some(message) else None fail( s""" - |Invalid State: $error + |${(m ++ c).mkString(": ")} |$testState """.stripMargin) } @@ -201,9 +274,13 @@ trait StreamTest extends QueryTest with Timeouts { 
startedTest.foreach { action => action match { case StartStream => - checkState(currentStream == null, "stream already running") - - currentStream = new StreamExecution(sqlContext, stream.logicalPlan, sink) + verify(currentStream == null, "stream already running") + lastStream = currentStream + currentStream = + sqlContext + .streams + .startQuery(StreamExecution.nextName, stream, sink) + .asInstanceOf[StreamExecution] currentStream.microBatchThread.setUncaughtExceptionHandler( new UncaughtExceptionHandler { override def uncaughtException(t: Thread, e: Throwable): Unit = { @@ -213,77 +290,100 @@ trait StreamTest extends QueryTest with Timeouts { }) case StopStream => - checkState(currentStream != null, "can not stop a stream that is not running") - currentStream.stop() - currentStream = null - - case DropBatches(num) => - checkState(currentStream == null, "dropping batches while running leads to corruption") - sink.dropBatches(num) - - case ExpectFailure => - try failAfter(streamingTimout) { - while (streamDeathCause == null) { - Thread.sleep(100) - } + verify(currentStream != null, "can not stop a stream that is not running") + try failAfter(streamingTimeout) { + currentStream.stop() + verify(!currentStream.microBatchThread.isAlive, + s"microbatch thread not stopped") + verify(!currentStream.isActive, + "query.isActive() is false even after stopping") + verify(currentStream.exception.isEmpty, + s"query.exception() is not empty after clean stop: " + + currentStream.exception.map(_.toString()).getOrElse("")) } catch { case _: InterruptedException => case _: org.scalatest.exceptions.TestFailedDueToTimeoutException => - fail( - s""" - |Timed out while waiting for failure. - |$testState - """.stripMargin) + failTest("Timed out while stopping and waiting for microbatchthread to terminate.") + case t: Throwable => + failTest("Error while stopping stream", t) + } finally { + lastStream = currentStream + currentStream = null } - currentStream = null - streamDeathCause = null + case DropBatches(num) => + verify(currentStream == null, "dropping batches while running leads to corruption") + sink.dropBatches(num) + + case ef: ExpectFailure[_] => + verify(currentStream != null, "can not expect failure when stream is not running") + try failAfter(streamingTimeout) { + val thrownException = intercept[ContinuousQueryException] { + currentStream.awaitTermination() + } + eventually("microbatch thread not stopped after termination with failure") { + assert(!currentStream.microBatchThread.isAlive) + } + verify(thrownException.query.eq(currentStream), + s"incorrect query reference in exception") + verify(currentStream.exception === Some(thrownException), + s"incorrect exception returned by query.exception()") + + val exception = currentStream.exception.get + verify(exception.cause.getClass === ef.causeClass, + "incorrect cause in exception returned by query.exception()\n" + + s"\tExpected: ${ef.causeClass}\n\tReturned: ${exception.cause.getClass}") + } catch { + case _: InterruptedException => + case _: org.scalatest.exceptions.TestFailedDueToTimeoutException => + failTest("Timed out while waiting for failure") + case t: Throwable => + failTest("Error while checking stream failure", t) + } finally { + lastStream = currentStream + currentStream = null + streamDeathCause = null + } + + case a: AssertOnQuery => + verify(currentStream != null || lastStream != null, + "cannot assert when not stream has been started") + val streamToAssert = Option(currentStream).getOrElse(lastStream) + 
verify(a.condition(streamToAssert), s"Assert on query failed: ${a.message}") + + case a: Assert => + val streamToAssert = Option(currentStream).getOrElse(lastStream) + verify({ a.run(); true }, s"Assert failed: ${a.message}") case a: AddData => awaiting.put(a.source, a.addData()) case CheckAnswerRows(expectedAnswer) => - checkState(currentStream != null, "stream not running") + verify(currentStream != null, "stream not running") // Block until all data added has been processed awaiting.foreach { case (source, offset) => - failAfter(streamingTimout) { + failAfter(streamingTimeout) { currentStream.awaitOffset(source, offset) } } val allData = try sink.allData catch { case e: Exception => - fail( - s""" - |Exception while getting data from sink $e - |$testState - """.stripMargin) + failTest("Exception while getting data from sink", e) } QueryTest.sameRows(expectedAnswer, allData).foreach { - error => fail( - s""" - |$error - |$testState - """.stripMargin) + error => failTest(error) } } pos += 1 } } catch { case _: InterruptedException if streamDeathCause != null => - fail( - s""" - |Stream Thread Died - |$testState - """.stripMargin) + failTest("Stream Thread Died") case _: org.scalatest.exceptions.TestFailedDueToTimeoutException => - fail( - s""" - |Timed out waiting for stream - |$testState - """.stripMargin) + failTest("Timed out waiting for stream") } finally { if (currentStream != null && currentStream.microBatchThread.isAlive) { currentStream.stop() @@ -335,7 +435,8 @@ trait StreamTest extends QueryTest with Timeouts { case r if r < 0.7 => // AddData addRandomData() - case _ => // StartStream + case _ => // StopStream + addCheck() actions += StopStream running = false } @@ -345,4 +446,59 @@ trait StreamTest extends QueryTest with Timeouts { addCheck() testStream(ds)(actions: _*) } + + + object AwaitTerminationTester { + + trait ExpectedBehavior + + /** Expect awaitTermination to not be blocked */ + case object ExpectNotBlocked extends ExpectedBehavior + + /** Expect awaitTermination to get blocked */ + case object ExpectBlocked extends ExpectedBehavior + + /** Expect awaitTermination to throw an exception */ + case class ExpectException[E <: Exception]()(implicit val t: ClassTag[E]) + extends ExpectedBehavior + + private val DEFAULT_TEST_TIMEOUT = 1 second + + def test( + expectedBehavior: ExpectedBehavior, + awaitTermFunc: () => Unit, + testTimeout: Span = DEFAULT_TEST_TIMEOUT + ): Unit = { + + expectedBehavior match { + case ExpectNotBlocked => + withClue("Got blocked when expected non-blocking.") { + failAfter(testTimeout) { + awaitTermFunc() + } + } + + case ExpectBlocked => + withClue("Was not blocked when expected.") { + intercept[TestFailedDueToTimeoutException] { + failAfter(testTimeout) { + awaitTermFunc() + } + } + } + + case e: ExpectException[_] => + val thrownException = + withClue(s"Did not throw ${e.t.runtimeClass.getSimpleName} when expected.") { + intercept[ContinuousQueryException] { + failAfter(testTimeout) { + awaitTermFunc() + } + } + } + assert(thrownException.cause.getClass === e.t.runtimeClass, + "exception of incorrect type was throw") + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala new file mode 100644 index 0000000000..daf08efca4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala @@ -0,0 +1,306 @@ +/* + * Licensed to the Apache Software Foundation 
(ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import scala.concurrent.Future +import scala.util.Random +import scala.util.control.NonFatal + +import org.scalatest.BeforeAndAfter +import org.scalatest.concurrent.Eventually._ +import org.scalatest.concurrent.PatienceConfiguration.Timeout +import org.scalatest.time.Span +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.SparkException +import org.apache.spark.sql.{ContinuousQuery, Dataset, StreamTest} +import org.apache.spark.sql.execution.streaming.{MemorySink, MemoryStream, StreamExecution, StreamingRelation} +import org.apache.spark.sql.test.SharedSQLContext + +class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with BeforeAndAfter { + + import AwaitTerminationTester._ + import testImplicits._ + + override val streamingTimeout = 20.seconds + + before { + assert(sqlContext.streams.active.isEmpty) + sqlContext.streams.resetTerminated() + } + + after { + assert(sqlContext.streams.active.isEmpty) + sqlContext.streams.resetTerminated() + } + + test("listing") { + val (m1, ds1) = makeDataset + val (m2, ds2) = makeDataset + val (m3, ds3) = makeDataset + + withQueriesOn(ds1, ds2, ds3) { queries => + require(queries.size === 3) + assert(sqlContext.streams.active.toSet === queries.toSet) + val (q1, q2, q3) = (queries(0), queries(1), queries(2)) + + assert(sqlContext.streams.get(q1.name).eq(q1)) + assert(sqlContext.streams.get(q2.name).eq(q2)) + assert(sqlContext.streams.get(q3.name).eq(q3)) + intercept[IllegalArgumentException] { + sqlContext.streams.get("non-existent-name") + } + + q1.stop() + + assert(sqlContext.streams.active.toSet === Set(q2, q3)) + val ex1 = withClue("no error while getting non-active query") { + intercept[IllegalArgumentException] { + sqlContext.streams.get(q1.name) + } + } + assert(ex1.getMessage.contains(q1.name), "error does not contain name of query to be fetched") + assert(sqlContext.streams.get(q2.name).eq(q2)) + + m2.addData(0) // q2 should terminate with error + + eventually(Timeout(streamingTimeout)) { + require(!q2.isActive) + require(q2.exception.isDefined) + } + val ex2 = withClue("no error while getting non-active query") { + intercept[IllegalArgumentException] { + sqlContext.streams.get(q2.name).eq(q2) + } + } + + assert(sqlContext.streams.active.toSet === Set(q3)) + } + } + + test("awaitAnyTermination without timeout and resetTerminated") { + val datasets = Seq.fill(5)(makeDataset._2) + withQueriesOn(datasets: _*) { queries => + require(queries.size === datasets.size) + assert(sqlContext.streams.active.toSet === queries.toSet) + + // awaitAnyTermination should be blocking + testAwaitAnyTermination(ExpectBlocked) + + // Stop a query asynchronously and see if it is reported through awaitAnyTermination + val q1 = 
stopRandomQueryAsync(stopAfter = 100 milliseconds, withError = false) + testAwaitAnyTermination(ExpectNotBlocked) + require(!q1.isActive) // should be inactive by the time the prev awaitAnyTerm returned + + // All subsequent calls to awaitAnyTermination should be non-blocking + testAwaitAnyTermination(ExpectNotBlocked) + + // Resetting termination should make awaitAnyTermination() blocking again + sqlContext.streams.resetTerminated() + testAwaitAnyTermination(ExpectBlocked) + + // Terminate a query asynchronously with exception and see awaitAnyTermination throws + // the exception + val q2 = stopRandomQueryAsync(100 milliseconds, withError = true) + testAwaitAnyTermination(ExpectException[SparkException]) + require(!q2.isActive) // should be inactive by the time the prev awaitAnyTerm returned + + // All subsequent calls to awaitAnyTermination should throw the exception + testAwaitAnyTermination(ExpectException[SparkException]) + + // Resetting termination should make awaitAnyTermination() blocking again + sqlContext.streams.resetTerminated() + testAwaitAnyTermination(ExpectBlocked) + + // Terminate multiple queries, one with failure and see whether awaitAnyTermination throws + // the exception + val q3 = stopRandomQueryAsync(10 milliseconds, withError = false) + testAwaitAnyTermination(ExpectNotBlocked) + require(!q3.isActive) + val q4 = stopRandomQueryAsync(10 milliseconds, withError = true) + eventually(Timeout(streamingTimeout)) { require(!q4.isActive) } + // After q4 terminates with exception, awaitAnyTerm should start throwing exception + testAwaitAnyTermination(ExpectException[SparkException]) + } + } + + test("awaitAnyTermination with timeout and resetTerminated") { + val datasets = Seq.fill(6)(makeDataset._2) + withQueriesOn(datasets: _*) { queries => + require(queries.size === datasets.size) + assert(sqlContext.streams.active.toSet === queries.toSet) + + // awaitAnyTermination should be blocking or non-blocking depending on timeout values + testAwaitAnyTermination( + ExpectBlocked, + awaitTimeout = 2 seconds, + expectedReturnedValue = false, + testBehaviorFor = 1 second) + + testAwaitAnyTermination( + ExpectNotBlocked, + awaitTimeout = 50 milliseconds, + expectedReturnedValue = false, + testBehaviorFor = 1 second) + + // Stop a query asynchronously within timeout and awaitAnyTerm should be unblocked + val q1 = stopRandomQueryAsync(stopAfter = 100 milliseconds, withError = false) + testAwaitAnyTermination( + ExpectNotBlocked, + awaitTimeout = 1 second, + expectedReturnedValue = true, + testBehaviorFor = 2 seconds) + require(!q1.isActive) // should be inactive by the time the prev awaitAnyTerm returned + + // All subsequent calls to awaitAnyTermination should be non-blocking even if timeout is high + testAwaitAnyTermination( + ExpectNotBlocked, awaitTimeout = 2 seconds, expectedReturnedValue = true) + + // Resetting termination should make awaitAnyTermination() blocking again + sqlContext.streams.resetTerminated() + testAwaitAnyTermination( + ExpectBlocked, + awaitTimeout = 2 seconds, + expectedReturnedValue = false, + testBehaviorFor = 1 second) + + // Terminate a query asynchronously with exception within timeout, awaitAnyTermination should + // throws the exception + val q2 = stopRandomQueryAsync(100 milliseconds, withError = true) + testAwaitAnyTermination( + ExpectException[SparkException], + awaitTimeout = 1 second, + testBehaviorFor = 2 seconds) + require(!q2.isActive) // should be inactive by the time the prev awaitAnyTerm returned + + // All subsequent calls to 
+      testAwaitAnyTermination(
+        ExpectException[SparkException],
+        awaitTimeout = 1 second,
+        testBehaviorFor = 2 seconds)
+
+      // Terminate a query asynchronously outside the timeout; awaitAnyTerm should be blocked
+      sqlContext.streams.resetTerminated()
+      val q3 = stopRandomQueryAsync(1 second, withError = true)
+      testAwaitAnyTermination(
+        ExpectNotBlocked,
+        awaitTimeout = 100 milliseconds,
+        expectedReturnedValue = false,
+        testBehaviorFor = 2 seconds)
+
+      // After that query is stopped, awaitAnyTerm should throw the exception
+      eventually(Timeout(streamingTimeout)) { require(!q3.isActive) } // wait for query to stop
+      testAwaitAnyTermination(
+        ExpectException[SparkException],
+        awaitTimeout = 100 milliseconds,
+        testBehaviorFor = 2 seconds)
+
+      // Terminate multiple queries, one with failure, and see whether awaitAnyTermination throws
+      // the exception
+      sqlContext.streams.resetTerminated()
+
+      val q4 = stopRandomQueryAsync(10 milliseconds, withError = false)
+      testAwaitAnyTermination(
+        ExpectNotBlocked, awaitTimeout = 1 second, expectedReturnedValue = true)
+      require(!q4.isActive)
+      val q5 = stopRandomQueryAsync(10 milliseconds, withError = true)
+      eventually(Timeout(streamingTimeout)) { require(!q5.isActive) }
+      // After q5 terminates with an exception, awaitAnyTerm should start throwing the exception
+      testAwaitAnyTermination(ExpectException[SparkException], awaitTimeout = 100 milliseconds)
+    }
+  }
+
+  /** Run a body of code by defining a query on each of the given datasets */
+  private def withQueriesOn(datasets: Dataset[_]*)(body: Seq[ContinuousQuery] => Unit): Unit = {
+    failAfter(streamingTimeout) {
+      val queries = withClue("Error starting queries") {
+        datasets.map { ds =>
+          @volatile var query: StreamExecution = null
+          try {
+            val df = ds.toDF
+            query = sqlContext
+              .streams
+              .startQuery(StreamExecution.nextName, df, new MemorySink(df.schema))
+              .asInstanceOf[StreamExecution]
+          } catch {
+            case NonFatal(e) =>
+              if (query != null) query.stop()
+              throw e
+          }
+          query
+        }
+      }
+      try {
+        body(queries)
+      } finally {
+        queries.foreach(_.stop())
+      }
+    }
+  }
+
+  /** Test the behavior of awaitAnyTermination */
+  private def testAwaitAnyTermination(
+      expectedBehavior: ExpectedBehavior,
+      expectedReturnedValue: Boolean = false,
+      awaitTimeout: Span = null,
+      testBehaviorFor: Span = 2 seconds
+    ): Unit = {
+
+    def awaitTermFunc(): Unit = {
+      if (awaitTimeout != null && awaitTimeout.toMillis > 0) {
+        val returnedValue = sqlContext.streams.awaitAnyTermination(awaitTimeout.toMillis)
+        assert(returnedValue === expectedReturnedValue, "Returned value does not match expected")
+      } else {
+        sqlContext.streams.awaitAnyTermination()
+      }
+    }
+
+    AwaitTerminationTester.test(expectedBehavior, awaitTermFunc, testBehaviorFor)
+  }
+
+  /** Stop a random active query either with `stop()` or with an error */
+  private def stopRandomQueryAsync(stopAfter: Span, withError: Boolean): ContinuousQuery = {
+
+    import scala.concurrent.ExecutionContext.Implicits.global
+
+    val activeQueries = sqlContext.streams.active
+    val queryToStop = activeQueries(Random.nextInt(activeQueries.length))
+    Future {
+      Thread.sleep(stopAfter.toMillis)
+      if (withError) {
+        logDebug(s"Terminating query ${queryToStop.name} with error")
+        queryToStop.asInstanceOf[StreamExecution].logicalPlan.collect {
+          case StreamingRelation(memoryStream, _) =>
+            memoryStream.asInstanceOf[MemoryStream[Int]].addData(0)
+        }
+      } else {
+        logDebug(s"Stopping query ${queryToStop.name}")
+        queryToStop.stop()
+      }
+    }
+    queryToStop
+  }
+
+  private def makeDataset: (MemoryStream[Int], Dataset[Int]) = {
+    val inputData = MemoryStream[Int]
+    val mapped = inputData.toDS.map(6 / _)
+    (inputData, mapped)
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQuerySuite.scala
new file mode 100644
index 0000000000..dac1a398ff
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQuerySuite.scala
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.StreamTest
+import org.apache.spark.sql.execution.streaming.{CompositeOffset, LongOffset, MemoryStream, StreamExecution}
+import org.apache.spark.sql.test.SharedSQLContext
+
+class ContinuousQuerySuite extends StreamTest with SharedSQLContext {
+
+  import AwaitTerminationTester._
+  import testImplicits._
+
+  test("lifecycle states and awaitTermination") {
+    val inputData = MemoryStream[Int]
+    val mapped = inputData.toDS().map { 6 / _ }
+
+    testStream(mapped)(
+      AssertOnQuery(_.isActive === true),
+      AssertOnQuery(_.exception.isEmpty),
+      AddData(inputData, 1, 2),
+      CheckAnswer(6, 3),
+      TestAwaitTermination(ExpectBlocked),
+      TestAwaitTermination(ExpectBlocked, timeoutMs = 2000),
+      TestAwaitTermination(ExpectNotBlocked, timeoutMs = 10, expectedReturnValue = false),
+      StopStream,
+      AssertOnQuery(_.isActive === false),
+      AssertOnQuery(_.exception.isEmpty),
+      TestAwaitTermination(ExpectNotBlocked),
+      TestAwaitTermination(ExpectNotBlocked, timeoutMs = 2000, expectedReturnValue = true),
+      TestAwaitTermination(ExpectNotBlocked, timeoutMs = 10, expectedReturnValue = true),
+      StartStream,
+      AssertOnQuery(_.isActive === true),
+      AddData(inputData, 0),
+      ExpectFailure[SparkException],
+      AssertOnQuery(_.isActive === false),
+      TestAwaitTermination(ExpectException[SparkException]),
+      TestAwaitTermination(ExpectException[SparkException], timeoutMs = 2000),
+      TestAwaitTermination(ExpectException[SparkException], timeoutMs = 10),
+      AssertOnQuery(
+        q => q.exception.get.startOffset.get === q.streamProgress.toCompositeOffset(Seq(inputData)),
+        "incorrect start offset on exception")
+    )
+  }
+
+  test("source and sink statuses") {
+    val inputData = MemoryStream[Int]
+    val mapped = inputData.toDS().map(6 / _)
+
+    testStream(mapped)(
+      AssertOnQuery(_.sourceStatuses.length === 1),
+      AssertOnQuery(_.sourceStatuses(0).description.contains("Memory")),
+      AssertOnQuery(_.sourceStatuses(0).offset === None),
+      AssertOnQuery(_.sinkStatus.description.contains("Memory")),
+      AssertOnQuery(_.sinkStatus.offset === None),
+      AddData(inputData, 1, 2),
+      CheckAnswer(6, 3),
+      AssertOnQuery(_.sourceStatuses(0).offset === Some(LongOffset(0))),
+      AssertOnQuery(_.sinkStatus.offset === Some(CompositeOffset.fill(LongOffset(0)))),
+      AddData(inputData, 1, 2),
+      CheckAnswer(6, 3, 6, 3),
+      AssertOnQuery(_.sourceStatuses(0).offset === Some(LongOffset(1))),
+      AssertOnQuery(_.sinkStatus.offset === Some(CompositeOffset.fill(LongOffset(1)))),
+      AddData(inputData, 0),
+      ExpectFailure[SparkException],
+      AssertOnQuery(_.sourceStatuses(0).offset === Some(LongOffset(2))),
+      AssertOnQuery(_.sinkStatus.offset === Some(CompositeOffset.fill(LongOffset(1))))
+    )
+  }
+
+  /**
+   * A [[StreamAction]] to test the behavior of `ContinuousQuery.awaitTermination()`.
+   *
+   * @param expectedBehavior Expected behavior (not blocked, blocked, or exception thrown)
+   * @param timeoutMs Timeout in milliseconds
+   *                  When timeoutMs <= 0, awaitTermination() is tested (i.e. w/o timeout)
+   *                  When timeoutMs > 0, awaitTermination(timeoutMs) is tested
+   * @param expectedReturnValue Expected return value when awaitTermination(timeoutMs) is used
+   */
+  case class TestAwaitTermination(
+      expectedBehavior: ExpectedBehavior,
+      timeoutMs: Int = -1,
+      expectedReturnValue: Boolean = false
+    ) extends AssertOnQuery(
+      TestAwaitTermination.assertOnQueryCondition(expectedBehavior, timeoutMs, expectedReturnValue),
+      "Error testing awaitTermination behavior"
+    ) {
+    override def toString(): String = {
+      s"TestAwaitTermination($expectedBehavior, timeoutMs = $timeoutMs, " +
+        s"expectedReturnValue = $expectedReturnValue)"
+    }
+  }
+
+  object TestAwaitTermination {
+
+    /**
+     * Tests the behavior of `ContinuousQuery.awaitTermination`.
+     *
+     * @param expectedBehavior Expected behavior (not blocked, blocked, or exception thrown)
+     * @param timeoutMs Timeout in milliseconds
+     *                  When timeoutMs <= 0, awaitTermination() is tested (i.e. w/o timeout)
+     *                  When timeoutMs > 0, awaitTermination(timeoutMs) is tested
+     * @param expectedReturnValue Expected return value when awaitTermination(timeoutMs) is used
+     */
+    def assertOnQueryCondition(
+        expectedBehavior: ExpectedBehavior,
+        timeoutMs: Int,
+        expectedReturnValue: Boolean
+      )(q: StreamExecution): Boolean = {
+
+      def awaitTermFunc(): Unit = {
+        if (timeoutMs <= 0) {
+          q.awaitTermination()
+        } else {
+          val returnedValue = q.awaitTermination(timeoutMs)
+          assert(returnedValue === expectedReturnValue, "Returned value does not match expected")
+        }
+      }
+      AwaitTerminationTester.test(expectedBehavior, awaitTermFunc)
+      true // If control reached here, then everything worked as expected
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
index b762f9b90e..f060c6f623 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
@@ -17,7 +17,9 @@
 
 package org.apache.spark.sql.streaming.test
 
-import org.apache.spark.sql.{AnalysisException, SQLContext, StreamTest}
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.sql.{AnalysisException, ContinuousQuery, SQLContext, StreamTest}
 import org.apache.spark.sql.execution.streaming.{Batch, Offset, Sink, Source}
 import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider}
 import org.apache.spark.sql.test.SharedSQLContext
@@ -57,9 +59,13 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider {
   }
 }
 
-class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext {
+class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with BeforeAndAfter {
   import testImplicits._
 
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
   test("resolve default source") {
     sqlContext.read
       .format("org.apache.spark.sql.streaming.test")
@@ -188,4 +194,63 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext {
     assert(LastOptions.parameters("boolOpt") == "false")
     assert(LastOptions.parameters("doubleOpt") == "6.7")
   }
+
+  test("unique query names") {
+
+    /** Start a query with a specific name */
+    def startQueryWithName(name: String = ""): ContinuousQuery = {
+      sqlContext.read
+        .format("org.apache.spark.sql.streaming.test")
+        .stream("/test")
+        .write
+        .format("org.apache.spark.sql.streaming.test")
+        .queryName(name)
+        .stream()
+    }
+
+    /** Start a query without specifying a name */
+    def startQueryWithoutName(): ContinuousQuery = {
+      sqlContext.read
+        .format("org.apache.spark.sql.streaming.test")
+        .stream("/test")
+        .write
+        .format("org.apache.spark.sql.streaming.test")
+        .stream()
+    }
+
+    /** Get the names of active streams */
+    def activeStreamNames: Set[String] = {
+      val streams = sqlContext.streams.active
+      val names = streams.map(_.name).toSet
+      assert(streams.length === names.size, s"names of active queries are not unique: $names")
+      names
+    }
+
+    val q1 = startQueryWithName("name")
+
+    // Should not be able to start another query with the same name
+    intercept[IllegalArgumentException] {
+      startQueryWithName("name")
+    }
+    assert(activeStreamNames === Set("name"))
+
+    // Should be able to start queries with other names
+    val q3 = startQueryWithName("another-name")
+    assert(activeStreamNames === Set("name", "another-name"))
+
+    // Should be able to start queries with auto-generated names
+    val q4 = startQueryWithoutName()
+    assert(activeStreamNames.contains(q4.name))
+
+    // Should not be able to start a query with the same auto-generated name
+    intercept[IllegalArgumentException] {
+      startQueryWithName(q4.name)
+    }
+
+    // Should be able to start a query with that name after stopping the previous query
+    q1.stop()
+    val q5 = startQueryWithName("name")
+    assert(activeStreamNames.contains("name"))
+    sqlContext.streams.active.foreach(_.stop())
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala
new file mode 100644
index 0000000000..d6cc6ad86b
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.util
+
+import java.util.concurrent.ConcurrentLinkedQueue
+
+import scala.util.control.NonFatal
+
+import org.scalatest.BeforeAndAfter
+import org.scalatest.PrivateMethodTester._
+import org.scalatest.concurrent.AsyncAssertions.Waiter
+import org.scalatest.concurrent.Eventually._
+import org.scalatest.concurrent.PatienceConfiguration.Timeout
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.util.ContinuousQueryListener.{QueryProgress, QueryStarted, QueryTerminated}
+
+class ContinuousQueryListenerSuite extends StreamTest with SharedSQLContext with BeforeAndAfter {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+    assert(sqlContext.streams.active.isEmpty)
+    assert(addedListeners.isEmpty)
+  }
+
+  test("single listener") {
+    val listener = new QueryStatusCollector
+    val input = MemoryStream[Int]
+    withListenerAdded(listener) {
+      testStream(input.toDS)(
+        StartStream,
+        Assert("Incorrect query status in onQueryStarted") {
+          val status = listener.startStatus
+          assert(status != null)
+          assert(status.active == true)
+          assert(status.sourceStatuses.size === 1)
+          assert(status.sourceStatuses(0).description.contains("Memory"))
+
+          // The source and sink offsets must be None as this must be called before the
+          // batches have started
+          assert(status.sourceStatuses(0).offset === None)
+          assert(status.sinkStatus.offset === None)
+
+          // No progress events or termination events
+          assert(listener.progressStatuses.isEmpty)
+          assert(listener.terminationStatus === null)
+        },
+        AddDataMemory(input, Seq(1, 2, 3)),
+        CheckAnswer(1, 2, 3),
+        Assert("Incorrect query status in onQueryProgress") {
+          eventually(Timeout(streamingTimeout)) {
+
+            // There should be only one progress event, as the batch has been processed
+            assert(listener.progressStatuses.size === 1)
+            val status = listener.progressStatuses.peek()
+            assert(status != null)
+            assert(status.active == true)
+            assert(status.sourceStatuses(0).offset === Some(LongOffset(0)))
+            assert(status.sinkStatus.offset === Some(CompositeOffset.fill(LongOffset(0))))
+
+            // No termination events
+            assert(listener.terminationStatus === null)
+          }
+        },
+        StopStream,
+        Assert("Incorrect query status in onQueryTerminated") {
+          eventually(Timeout(streamingTimeout)) {
+            val status = listener.terminationStatus
+            assert(status != null)
+
+            assert(status.active === false) // must be inactive by the time onQueryTerm is called
+            assert(status.sourceStatuses(0).offset === Some(LongOffset(0)))
+            assert(status.sinkStatus.offset === Some(CompositeOffset.fill(LongOffset(0))))
+          }
+          listener.checkAsyncErrors()
+        }
+      )
+    }
+  }
+
+  test("adding and removing listener") {
+    def isListenerActive(listener: QueryStatusCollector): Boolean = {
+      listener.reset()
+      testStream(MemoryStream[Int].toDS)(
+        StartStream,
+        StopStream
+      )
+      listener.startStatus != null
+    }
+
+    try {
+      val listener1 = new QueryStatusCollector
+      val listener2 = new QueryStatusCollector
+
+      sqlContext.streams.addListener(listener1)
+      assert(isListenerActive(listener1) === true)
+      assert(isListenerActive(listener2) === false)
+      sqlContext.streams.addListener(listener2)
+      assert(isListenerActive(listener1) === true)
+      assert(isListenerActive(listener2) === true)
+      sqlContext.streams.removeListener(listener1)
+      assert(isListenerActive(listener1) === false)
+      assert(isListenerActive(listener2) === true)
+    } finally {
+      addedListeners.foreach(sqlContext.streams.removeListener)
+    }
+  }
+
+  test("event ordering") {
+    val listener = new QueryStatusCollector
+    withListenerAdded(listener) {
+      for (i <- 1 to 100) {
+        listener.reset()
+        require(listener.startStatus === null)
+        testStream(MemoryStream[Int].toDS)(
+          StartStream,
+          Assert(listener.startStatus !== null, "onQueryStarted not called before query returned"),
+          StopStream,
+          Assert { listener.checkAsyncErrors() }
+        )
+      }
+    }
+  }
+
+  private def withListenerAdded(listener: ContinuousQueryListener)(body: => Unit): Unit = {
+    @volatile var query: StreamExecution = null
+    try {
+      failAfter(1 minute) {
+        sqlContext.streams.addListener(listener)
+        body
+      }
+    } finally {
+      sqlContext.streams.removeListener(listener)
+    }
+  }
+
+  private def addedListeners(): Array[ContinuousQueryListener] = {
+    val listenerBusMethod =
+      PrivateMethod[ContinuousQueryListenerBus]('listenerBus)
+    val listenerBus = sqlContext.streams invokePrivate listenerBusMethod()
+    listenerBus.listeners.toArray.map(_.asInstanceOf[ContinuousQueryListener])
+  }
+
+  class QueryStatusCollector extends ContinuousQueryListener {
+
+    private val asyncTestWaiter = new Waiter // to catch errors in the async listener events
+
+    @volatile var startStatus: QueryStatus = null
+    @volatile var terminationStatus: QueryStatus = null
+    val progressStatuses = new ConcurrentLinkedQueue[QueryStatus]
+
+    def reset(): Unit = {
+      startStatus = null
+      terminationStatus = null
+      progressStatuses.clear()
+
+      // To reset the waiter
+      try asyncTestWaiter.await(timeout(1 milliseconds)) catch {
+        case NonFatal(e) =>
+      }
+    }
+
+    def checkAsyncErrors(): Unit = {
+      asyncTestWaiter.await(timeout(streamingTimeout))
+    }
+
+    override def onQueryStarted(queryStarted: QueryStarted): Unit = {
+      asyncTestWaiter {
+        startStatus = QueryStatus(queryStarted.query)
+      }
+    }
+
+    override def onQueryProgress(queryProgress: QueryProgress): Unit = {
+      asyncTestWaiter {
+        assert(startStatus != null, "onQueryProgress called before onQueryStarted")
+        progressStatuses.add(QueryStatus(queryProgress.query))
+      }
+    }
+
+    override def onQueryTerminated(queryTerminated: QueryTerminated): Unit = {
+      asyncTestWaiter {
+        assert(startStatus != null, "onQueryTerminated called before onQueryStarted")
+        terminationStatus = QueryStatus(queryTerminated.query)
+      }
+      asyncTestWaiter.dismiss()
+    }
+  }
+
+  case class QueryStatus(
+      active: Boolean,
+      exception: Option[Exception],
+      sourceStatuses: Array[SourceStatus],
+      sinkStatus: SinkStatus)
+
+  object QueryStatus {
+    def apply(query: ContinuousQuery): QueryStatus = {
+      QueryStatus(query.isActive, query.exception, query.sourceStatuses, query.sinkStatus)
+    }
+  }
+}