diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 73105f881b..9208a527d2 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -497,6 +497,26 @@ class DataFrameWriter(object): self._jwrite = self._jwrite.mode(saveMode) return self + @since(2.0) + def outputMode(self, outputMode): + """Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink. + + Options include: + + * `append`:Only the new rows in the streaming DataFrame/Dataset will be written to + the sink + * `complete`:All the rows in the streaming DataFrame/Dataset will be written to the sink + every time these is some updates + + .. note:: Experimental. + + >>> writer = sdf.write.outputMode('append') + """ + if not outputMode or type(outputMode) != str or len(outputMode.strip()) == 0: + raise ValueError('The output mode must be a non-empty string. Got: %s' % outputMode) + self._jwrite = self._jwrite.outputMode(outputMode) + return self + @since(1.4) def format(self, source): """Specifies the underlying output data source. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1790432edd..0d9dd5ea2a 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -926,7 +926,7 @@ class SQLTests(ReusedPySparkTestCase): out = os.path.join(tmpPath, 'out') chk = os.path.join(tmpPath, 'chk') cq = df.write.option('checkpointLocation', chk).queryName('this_query') \ - .format('parquet').option('path', out).startStream() + .format('parquet').outputMode('append').option('path', out).startStream() try: self.assertEqual(cq.name, 'this_query') self.assertTrue(cq.isActive) @@ -952,8 +952,9 @@ class SQLTests(ReusedPySparkTestCase): fake1 = os.path.join(tmpPath, 'fake1') fake2 = os.path.join(tmpPath, 'fake2') cq = df.write.option('checkpointLocation', fake1).format('memory').option('path', fake2) \ - .queryName('fake_query').startStream(path=out, format='parquet', queryName='this_query', - checkpointLocation=chk) + .queryName('fake_query').outputMode('append') \ + .startStream(path=out, format='parquet', queryName='this_query', checkpointLocation=chk) + try: self.assertEqual(cq.name, 'this_query') self.assertTrue(cq.isActive) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/OutputMode.java b/sql/catalyst/src/main/java/org/apache/spark/sql/OutputMode.java new file mode 100644 index 0000000000..1936d53e5e --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/OutputMode.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql; + +import org.apache.spark.annotation.Experimental; + +/** + * :: Experimental :: + * + * OutputMode is used to what data will be written to a streaming sink when there is + * new data available in a streaming DataFrame/Dataset. + * + * @since 2.0.0 + */ +@Experimental +public class OutputMode { + + /** + * OutputMode in which only the new rows in the streaming DataFrame/Dataset will be + * written to the sink. This output mode can be only be used in queries that do not + * contain any aggregation. + * + * @since 2.0.0 + */ + public static OutputMode Append() { + return InternalOutputModes.Append$.MODULE$; + } + + /** + * OutputMode in which all the rows in the streaming DataFrame/Dataset will be written + * to the sink every time these is some updates. This output mode can only be used in queries + * that contain aggregations. + * + * @since 2.0.0 + */ + public static OutputMode Complete() { + return InternalOutputModes.Complete$.MODULE$; + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/InternalOutputModes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/InternalOutputModes.scala new file mode 100644 index 0000000000..8ef5d9a653 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/InternalOutputModes.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +/** + * Internal helper class to generate objects representing various [[OutputMode]]s, + */ +private[sql] object InternalOutputModes { + + /** + * OutputMode in which only the new rows in the streaming DataFrame/Dataset will be + * written to the sink. This output mode can be only be used in queries that do not + * contain any aggregation. + */ + case object Append extends OutputMode + + /** + * OutputMode in which all the rows in the streaming DataFrame/Dataset will be written + * to the sink every time these is some updates. This output mode can only be used in queries + * that contain aggregations. + */ + case object Complete extends OutputMode + + /** + * OutputMode in which only the rows in the streaming DataFrame/Dataset that were updated will be + * written to the sink every time these is some updates. This output mode can only be used in + * queries that contain aggregations. + */ + case object Update extends OutputMode +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index 0e08bf013c..f4c0347609 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.{AnalysisException, InternalOutputModes, OutputMode} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -29,8 +29,7 @@ object UnsupportedOperationChecker { def checkForBatch(plan: LogicalPlan): Unit = { plan.foreachUp { case p if p.isStreaming => - throwError( - "Queries with streaming sources must be executed with write.startStream()")(p) + throwError("Queries with streaming sources must be executed with write.startStream()")(p) case _ => } @@ -43,10 +42,10 @@ object UnsupportedOperationChecker { "Queries without streaming sources cannot be executed with write.startStream()")(plan) } - plan.foreachUp { implicit plan => + plan.foreachUp { implicit subPlan => // Operations that cannot exists anywhere in a streaming plan - plan match { + subPlan match { case _: Command => throwError("Commands like CreateTable*, AlterTable*, Show* are not supported with " + @@ -55,21 +54,6 @@ object UnsupportedOperationChecker { case _: InsertIntoTable => throwError("InsertIntoTable is not supported with streaming DataFrames/Datasets") - case Aggregate(_, _, child) if child.isStreaming => - if (outputMode == Append) { - throwError( - "Aggregations are not supported on streaming DataFrames/Datasets in " + - "Append output mode. Consider changing output mode to Update.") - } - val moreStreamingAggregates = child.find { - case Aggregate(_, _, grandchild) if grandchild.isStreaming => true - case _ => false - } - if (moreStreamingAggregates.nonEmpty) { - throwError("Multiple streaming aggregations are not supported with " + - "streaming DataFrames/Datasets") - } - case Join(left, right, joinType, _) => joinType match { @@ -119,10 +103,10 @@ object UnsupportedOperationChecker { case GroupingSets(_, _, child, _) if child.isStreaming => throwError("GroupingSets is not supported on streaming DataFrames/Datasets") - case GlobalLimit(_, _) | LocalLimit(_, _) if plan.children.forall(_.isStreaming) => + case GlobalLimit(_, _) | LocalLimit(_, _) if subPlan.children.forall(_.isStreaming) => throwError("Limits are not supported on streaming DataFrames/Datasets") - case Sort(_, _, _) | SortPartitions(_, _) if plan.children.forall(_.isStreaming) => + case Sort(_, _, _) | SortPartitions(_, _) if subPlan.children.forall(_.isStreaming) => throwError("Sorting is not supported on streaming DataFrames/Datasets") case Sample(_, _, _, _, child) if child.isStreaming => @@ -138,6 +122,27 @@ object UnsupportedOperationChecker { case _ => } } + + // Checks related to aggregations + val aggregates = plan.collect { case a @ Aggregate(_, _, _) if a.isStreaming => a } + outputMode match { + case InternalOutputModes.Append if aggregates.nonEmpty => + throwError( + s"$outputMode output mode not supported when there are streaming aggregations on " + + s"streaming DataFrames/DataSets")(plan) + + case InternalOutputModes.Complete | InternalOutputModes.Update if aggregates.isEmpty => + throwError( + s"$outputMode output mode not supported when there are no streaming aggregations on " + + s"streaming DataFrames/Datasets")(plan) + + case _ => + } + if (aggregates.size > 1) { + throwError( + "Multiple streaming aggregations are not supported with " + + "streaming DataFrames/Datasets")(plan) + } } private def throwErrorIf( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/OutputMode.scala b/sql/catalyst/src/test/java/org/apache/spark/sql/JavaOutputModeSuite.java similarity index 70% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/OutputMode.scala rename to sql/catalyst/src/test/java/org/apache/spark/sql/JavaOutputModeSuite.java index a4d387eae3..1764f3348d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/OutputMode.scala +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/JavaOutputModeSuite.java @@ -15,9 +15,17 @@ * limitations under the License. */ -package org.apache.spark.sql.catalyst.analysis +package org.apache.spark.sql; -sealed trait OutputMode +import org.junit.Test; -case object Append extends OutputMode -case object Update extends OutputMode +public class JavaOutputModeSuite { + + @Test + public void testOutputModes() { + OutputMode o1 = OutputMode.Append(); + assert(o1.toString().toLowerCase().contains("append")); + OutputMode o2 = OutputMode.Complete(); + assert (o2.toString().toLowerCase().contains("complete")); + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index aaeee0f2a4..c2e3d47450 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.{AnalysisException, OutputMode} +import org.apache.spark.sql.InternalOutputModes._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -79,35 +80,13 @@ class UnsupportedOperationsSuite extends SparkFunSuite { outputMode = Append, expectedMsgs = "commands" :: Nil) - // Aggregates: Not supported on streams in Append mode - assertSupportedInStreamingPlan( - "aggregate - batch with update output mode", - batchRelation.groupBy("a")("count(*)"), - outputMode = Update) - - assertSupportedInStreamingPlan( - "aggregate - batch with append output mode", - batchRelation.groupBy("a")("count(*)"), - outputMode = Append) - - assertSupportedInStreamingPlan( - "aggregate - stream with update output mode", - streamRelation.groupBy("a")("count(*)"), - outputMode = Update) - - assertNotSupportedInStreamingPlan( - "aggregate - stream with append output mode", - streamRelation.groupBy("a")("count(*)"), - outputMode = Append, - Seq("aggregation", "append output mode")) - // Multiple streaming aggregations not supported def aggExprs(name: String): Seq[NamedExpression] = Seq(Count("*").as(name)) assertSupportedInStreamingPlan( "aggregate - multiple batch aggregations", Aggregate(Nil, aggExprs("c"), Aggregate(Nil, aggExprs("d"), batchRelation)), - Update) + Append) assertSupportedInStreamingPlan( "aggregate - multiple aggregations but only one streaming aggregation", @@ -209,7 +188,6 @@ class UnsupportedOperationsSuite extends SparkFunSuite { _.intersect(_), streamStreamSupported = false) - // Unary operations testUnaryOperatorInStreamingPlan("sort", Sort(Nil, true, _)) testUnaryOperatorInStreamingPlan("sort partitions", SortPartitions(Nil, _), expectedMsg = "sort") @@ -218,6 +196,10 @@ class UnsupportedOperationsSuite extends SparkFunSuite { testUnaryOperatorInStreamingPlan( "window", Window(Nil, Nil, Nil, _), expectedMsg = "non-time-based windows") + // Output modes with aggregation and non-aggregation plans + testOutputMode(Append, shouldSupportAggregation = false) + testOutputMode(Update, shouldSupportAggregation = true) + testOutputMode(Complete, shouldSupportAggregation = true) /* ======================================================================================= @@ -316,6 +298,37 @@ class UnsupportedOperationsSuite extends SparkFunSuite { outputMode) } + def testOutputMode( + outputMode: OutputMode, + shouldSupportAggregation: Boolean): Unit = { + + // aggregation + if (shouldSupportAggregation) { + assertNotSupportedInStreamingPlan( + s"$outputMode output mode - no aggregation", + streamRelation.where($"a" > 1), + outputMode = outputMode, + Seq("aggregation", s"$outputMode output mode")) + + assertSupportedInStreamingPlan( + s"$outputMode output mode - aggregation", + streamRelation.groupBy("a")("count(*)"), + outputMode = outputMode) + + } else { + assertSupportedInStreamingPlan( + s"$outputMode output mode - no aggregation", + streamRelation.where($"a" > 1), + outputMode = outputMode) + + assertNotSupportedInStreamingPlan( + s"$outputMode output mode - aggregation", + streamRelation.groupBy("a")("count(*)"), + outputMode = outputMode, + Seq("aggregation", s"$outputMode output mode")) + } + } + /** * Assert that the logical plan is supported as subplan insider a streaming plan. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala index eab557443d..c686400150 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import scala.collection.mutable import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.analysis.{Append, OutputMode, UnsupportedOperationChecker} +import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf @@ -35,7 +35,7 @@ import org.apache.spark.util.{Clock, SystemClock} * @since 2.0.0 */ @Experimental -class ContinuousQueryManager(sparkSession: SparkSession) { +class ContinuousQueryManager private[sql] (sparkSession: SparkSession) { private[sql] val stateStoreCoordinator = StateStoreCoordinatorRef.forDriver(sparkSession.sparkContext.env) @@ -175,9 +175,9 @@ class ContinuousQueryManager(sparkSession: SparkSession) { checkpointLocation: String, df: DataFrame, sink: Sink, + outputMode: OutputMode, trigger: Trigger = ProcessingTime(0), - triggerClock: Clock = new SystemClock(), - outputMode: OutputMode = Append): ContinuousQuery = { + triggerClock: Clock = new SystemClock()): ContinuousQuery = { activeQueriesLock.synchronized { if (activeQueries.contains(name)) { throw new IllegalArgumentException( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index f2ba2dfc08..291b8250c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -77,7 +77,47 @@ final class DataFrameWriter private[sql](df: DataFrame) { case "ignore" => SaveMode.Ignore case "error" | "default" => SaveMode.ErrorIfExists case _ => throw new IllegalArgumentException(s"Unknown save mode: $saveMode. " + - "Accepted modes are 'overwrite', 'append', 'ignore', 'error'.") + "Accepted save modes are 'overwrite', 'append', 'ignore', 'error'.") + } + this + } + + /** + * Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink. + * - `OutputMode.Append()`: only the new rows in the streaming DataFrame/Dataset will be + * written to the sink + * - `OutputMode.Complete()`: all the rows in the streaming DataFrame/Dataset will be written + * to the sink every time these is some updates + * + * @since 2.0.0 + */ + @Experimental + def outputMode(outputMode: OutputMode): DataFrameWriter = { + assertStreaming("outputMode() can only be called on continuous queries") + this.outputMode = outputMode + this + } + + /** + * Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink. + * - `append`: only the new rows in the streaming DataFrame/Dataset will be written to + * the sink + * - `complete`: all the rows in the streaming DataFrame/Dataset will be written to the sink + * every time these is some updates + * + * @since 2.0.0 + */ + @Experimental + def outputMode(outputMode: String): DataFrameWriter = { + assertStreaming("outputMode() can only be called on continuous queries") + this.outputMode = outputMode.toLowerCase match { + case "append" => + OutputMode.Append + case "complete" => + OutputMode.Complete + case _ => + throw new IllegalArgumentException(s"Unknown output mode $outputMode. " + + "Accepted output modes are 'append' and 'complete'") } this } @@ -319,7 +359,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { checkpointPath.toUri.toString } - val sink = new MemorySink(df.schema) + val sink = new MemorySink(df.schema, outputMode) val resultDf = Dataset.ofRows(df.sparkSession, new MemoryPlan(sink)) resultDf.createOrReplaceTempView(queryName) val continuousQuery = df.sparkSession.sessionState.continuousQueryManager.startQuery( @@ -327,6 +367,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { checkpointLocation, df, sink, + outputMode, trigger) continuousQuery } else { @@ -352,7 +393,8 @@ final class DataFrameWriter private[sql](df: DataFrame) { queryName, checkpointLocation, df, - dataSource.createSink(), + dataSource.createSink(outputMode), + outputMode, trigger) } } @@ -708,6 +750,8 @@ final class DataFrameWriter private[sql](df: DataFrame) { private var mode: SaveMode = SaveMode.ErrorIfExists + private var outputMode: OutputMode = OutputMode.Append + private var trigger: Trigger = ProcessingTime(0L) private var extraOptions = new scala.collection.mutable.HashMap[String, String] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index f93c446007..d617a04813 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -311,8 +311,10 @@ object Utils { aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), child = restored) } - - val saved = StateStoreSaveExec(groupingAttributes, None, partialMerged2) + // Note: stateId and returnAllStates are filled in later with preparation rules + // in IncrementalExecution. + val saved = StateStoreSaveExec( + groupingAttributes, stateId = None, returnAllStates = None, partialMerged2) val finalAndCompleteAggregate: SparkPlan = { val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index b3beb6c85f..814880b0e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -248,15 +248,20 @@ case class DataSource( } /** Returns a sink that can be used to continually write data. */ - def createSink(): Sink = { + def createSink(outputMode: OutputMode): Sink = { providingClass.newInstance() match { - case s: StreamSinkProvider => s.createSink(sparkSession.sqlContext, options, partitionColumns) + case s: StreamSinkProvider => + s.createSink(sparkSession.sqlContext, options, partitionColumns, outputMode) case parquet: parquet.ParquetFileFormat => val caseInsensitiveOptions = new CaseInsensitiveMap(options) val path = caseInsensitiveOptions.getOrElse("path", { throw new IllegalArgumentException("'path' is not specified") }) + if (outputMode != OutputMode.Append) { + throw new IllegalArgumentException( + s"Data source $className does not support $outputMode output mode") + } new FileStreamSink(sparkSession, path, parquet, partitionColumns, options) case _ => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index fe5f36e1cd..5c86049851 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.execution.streaming -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.analysis.OutputMode +import org.apache.spark.sql.{InternalOutputModes, OutputMode, SparkSession} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode} @@ -54,16 +53,19 @@ class IncrementalExecution private[sql]( /** Locates save/restore pairs surrounding aggregation. */ val state = new Rule[SparkPlan] { + override def apply(plan: SparkPlan): SparkPlan = plan transform { - case StateStoreSaveExec(keys, None, + case StateStoreSaveExec(keys, None, None, UnaryExecNode(agg, StateStoreRestoreExec(keys2, None, child))) => val stateId = OperatorStateId(checkpointLocation, operatorId, currentBatchId) + val returnAllStates = if (outputMode == InternalOutputModes.Complete) true else false operatorId += 1 StateStoreSaveExec( keys, Some(stateId), + Some(returnAllStates), agg.withNewChildren( StateStoreRestoreExec( keys, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala index d5e4dd8f78..4d0283fbef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala @@ -82,10 +82,14 @@ case class StateStoreRestoreExec( case class StateStoreSaveExec( keyExpressions: Seq[Attribute], stateId: Option[OperatorStateId], + returnAllStates: Option[Boolean], child: SparkPlan) extends execution.UnaryExecNode with StatefulOperator { override protected def doExecute(): RDD[InternalRow] = { + assert(returnAllStates.nonEmpty, + "Incorrect planning in IncrementalExecution, returnAllStates have not been set") + val saveAndReturnFunc = if (returnAllStates.get) saveAndReturnAll _ else saveAndReturnUpdated _ child.execute().mapPartitionsWithStateStore( getStateId.checkpointLocation, operatorId = getStateId.operatorId, @@ -93,29 +97,57 @@ case class StateStoreSaveExec( keyExpressions.toStructType, child.output.toStructType, sqlContext.sessionState, - Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => - new Iterator[InternalRow] { - private[this] val baseIterator = iter - private[this] val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) - - override def hasNext: Boolean = { - if (!baseIterator.hasNext) { - store.commit() - false - } else { - true - } - } - - override def next(): InternalRow = { - val row = baseIterator.next().asInstanceOf[UnsafeRow] - val key = getKey(row) - store.put(key.copy(), row.copy()) - row - } - } - } + Some(sqlContext.streams.stateStoreCoordinator) + )(saveAndReturnFunc) } override def output: Seq[Attribute] = child.output + + /** + * Save all the rows to the state store, and return all the rows in the state store. + * Note that this returns an iterator that pipelines the saving to store with downstream + * processing. + */ + private def saveAndReturnUpdated( + store: StateStore, + iter: Iterator[InternalRow]): Iterator[InternalRow] = { + new Iterator[InternalRow] { + private[this] val baseIterator = iter + private[this] val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) + + override def hasNext: Boolean = { + if (!baseIterator.hasNext) { + store.commit() + false + } else { + true + } + } + + override def next(): InternalRow = { + val row = baseIterator.next().asInstanceOf[UnsafeRow] + val key = getKey(row) + store.put(key.copy(), row.copy()) + row + } + } + } + + /** + * Save all the rows to the state store, and return all the rows in the state store. + * Note that the saving to store is blocking; only after all the rows have been saved + * is the iterator on the update store data is generated. + */ + private def saveAndReturnAll( + store: StateStore, + iter: Iterator[InternalRow]): Iterator[InternalRow] = { + val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) + while (iter.hasNext) { + val row = iter.next().asInstanceOf[UnsafeRow] + val key = getKey(row) + store.put(key.copy(), row.copy()) + } + store.commit() + store.iterator().map(_._2.asInstanceOf[InternalRow]) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 7d09bdcebd..ab0900d7f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -28,7 +28,6 @@ import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.OutputMode import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index f11a3fb969..391f1e54b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.{DataFrame, OutputMode, SQLContext} import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} class ConsoleSink(options: Map[String, String]) extends Sink with Logging { @@ -52,7 +52,8 @@ class ConsoleSinkProvider extends StreamSinkProvider with DataSourceRegister { def createSink( sqlContext: SQLContext, parameters: Map[String, String], - partitionColumns: Seq[String]): Sink = { + partitionColumns: Seq[String], + outputMode: OutputMode): Sink = { new ConsoleSink(parameters) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index bcc33ae8c8..e4a95e7335 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -24,7 +24,7 @@ import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row, SQLContext} +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LeafNode @@ -114,35 +114,49 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. */ -class MemorySink(val schema: StructType) extends Sink with Logging { +class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink with Logging { + + private case class AddedData(batchId: Long, data: Array[Row]) + /** An order list of batches that have been written to this [[Sink]]. */ @GuardedBy("this") - private val batches = new ArrayBuffer[Array[Row]]() + private val batches = new ArrayBuffer[AddedData]() /** Returns all rows that are stored in this [[Sink]]. */ def allData: Seq[Row] = synchronized { - batches.flatten + batches.map(_.data).flatten } - def latestBatchId: Option[Int] = synchronized { - if (batches.size == 0) None else Some(batches.size - 1) + def latestBatchId: Option[Long] = synchronized { + batches.lastOption.map(_.batchId) } - def lastBatch: Seq[Row] = synchronized { batches.last } + def latestBatchData: Seq[Row] = synchronized { batches.lastOption.toSeq.flatten(_.data) } def toDebugString: String = synchronized { - batches.zipWithIndex.map { case (b, i) => - val dataStr = try b.mkString(" ") catch { + batches.map { case AddedData(batchId, data) => + val dataStr = try data.mkString(" ") catch { case NonFatal(e) => "[Error converting to string]" } - s"$i: $dataStr" + s"$batchId: $dataStr" }.mkString("\n") } override def addBatch(batchId: Long, data: DataFrame): Unit = synchronized { - if (batchId == batches.size) { - logDebug(s"Committing batch $batchId") - batches.append(data.collect()) + if (latestBatchId.isEmpty || batchId > latestBatchId.get) { + logDebug(s"Committing batch $batchId to $this") + outputMode match { + case InternalOutputModes.Append | InternalOutputModes.Update => + batches.append(AddedData(batchId, data.collect())) + + case InternalOutputModes.Complete => + batches.clear() + batches.append(AddedData(batchId, data.collect())) + + case _ => + throw new IllegalArgumentException( + s"Output mode $outputMode is not supported by MemorySink") + } } else { logDebug(s"Skipping already committed batch: $batchId") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 26285bde31..3d4edbb93d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -137,7 +137,8 @@ trait StreamSinkProvider { def createSink( sqlContext: SQLContext, parameters: Map[String, String], - partitionColumns: Seq[String]): Sink + partitionColumns: Seq[String], + outputMode: OutputMode): Sink } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala index 1ab562f873..b033725f18 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala @@ -33,7 +33,6 @@ import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.Span import org.scalatest.time.SpanSugar._ -import org.apache.spark.sql.catalyst.analysis.{Append, OutputMode} import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ @@ -69,8 +68,6 @@ trait StreamTest extends QueryTest with Timeouts { /** How long to wait for an active stream to catch up when checking a result. */ val streamingTimeout = 10.seconds - val outputMode: OutputMode = Append - /** A trait for actions that can be performed while testing a streaming DataFrame. */ trait StreamAction @@ -191,14 +188,17 @@ trait StreamTest extends QueryTest with Timeouts { * Note that if the stream is not explicitly started before an action that requires it to be * running then it will be automatically started before performing any other actions. */ - def testStream(_stream: Dataset[_])(actions: StreamAction*): Unit = { + def testStream( + _stream: Dataset[_], + outputMode: OutputMode = OutputMode.Append)(actions: StreamAction*): Unit = { + val stream = _stream.toDF() var pos = 0 var currentPlan: LogicalPlan = stream.logicalPlan var currentStream: StreamExecution = null var lastStream: StreamExecution = null val awaiting = new mutable.HashMap[Int, Offset]() // source index -> offset to wait for - val sink = new MemorySink(stream.schema) + val sink = new MemorySink(stream.schema, outputMode) @volatile var streamDeathCause: Throwable = null @@ -297,9 +297,9 @@ trait StreamTest extends QueryTest with Timeouts { metadataRoot, stream, sink, + outputMode, trigger, - triggerClock, - outputMode = outputMode) + triggerClock) .asInstanceOf[StreamExecution] currentStream.microBatchThread.setUncaughtExceptionHandler( new UncaughtExceptionHandler { @@ -429,7 +429,7 @@ trait StreamTest extends QueryTest with Timeouts { } } - val sparkAnswer = try if (lastOnly) sink.lastBatch else sink.allData catch { + val sparkAnswer = try if (lastOnly) sink.latestBatchData else sink.allData catch { case e: Exception => failTest("Exception while getting data from sink", e) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala index a743cdde40..b75c3ea106 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala @@ -28,7 +28,7 @@ import org.scalatest.time.Span import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkException -import org.apache.spark.sql.{ContinuousQuery, Dataset, StreamTest} +import org.apache.spark.sql.{ContinuousQuery, Dataset, OutputMode, StreamTest} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils @@ -232,20 +232,20 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with private def withQueriesOn(datasets: Dataset[_]*)(body: Seq[ContinuousQuery] => Unit): Unit = { failAfter(streamingTimeout) { val queries = withClue("Error starting queries") { - datasets.map { ds => + datasets.zipWithIndex.map { case (ds, i) => @volatile var query: StreamExecution = null try { val df = ds.toDF val metadataRoot = - Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath - query = spark - .streams - .startQuery( - StreamExecution.nextName, - metadataRoot, - df, - new MemorySink(df.schema)) - .asInstanceOf[StreamExecution] + Utils.createTempDir(namePrefix = "streaming.checkpoint").getCanonicalPath + query = + df.write + .format("memory") + .queryName(s"query$i") + .option("checkpointLocation", metadataRoot) + .outputMode("append") + .startStream() + .asInstanceOf[StreamExecution] } catch { case NonFatal(e) => if (query != null) query.stop() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala index 09c35bbf2c..e5bd0b4744 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala @@ -17,27 +17,132 @@ package org.apache.spark.sql.streaming -import org.apache.spark.sql.{AnalysisException, Row, StreamTest} +import scala.language.implicitConversions + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.Utils -class MemorySinkSuite extends StreamTest with SharedSQLContext { +class MemorySinkSuite extends StreamTest with SharedSQLContext with BeforeAndAfter { + import testImplicits._ - test("registering as a table") { - testRegisterAsTable() + after { + sqlContext.streams.active.foreach(_.stop()) } - ignore("stress test") { - // Ignore the stress test as it takes several minutes to run - (0 until 1000).foreach(_ => testRegisterAsTable()) + test("directly add data in Append output mode") { + implicit val schema = new StructType().add(new StructField("value", IntegerType)) + val sink = new MemorySink(schema, InternalOutputModes.Append) + + // Before adding data, check output + assert(sink.latestBatchId === None) + checkAnswer(sink.latestBatchData, Seq.empty) + checkAnswer(sink.allData, Seq.empty) + + // Add batch 0 and check outputs + sink.addBatch(0, 1 to 3) + assert(sink.latestBatchId === Some(0)) + checkAnswer(sink.latestBatchData, 1 to 3) + checkAnswer(sink.allData, 1 to 3) + + // Add batch 1 and check outputs + sink.addBatch(1, 4 to 6) + assert(sink.latestBatchId === Some(1)) + checkAnswer(sink.latestBatchData, 4 to 6) + checkAnswer(sink.allData, 1 to 6) // new data should get appended to old data + + // Re-add batch 1 with different data, should not be added and outputs should not be changed + sink.addBatch(1, 7 to 9) + assert(sink.latestBatchId === Some(1)) + checkAnswer(sink.latestBatchData, 4 to 6) + checkAnswer(sink.allData, 1 to 6) + + // Add batch 2 and check outputs + sink.addBatch(2, 7 to 9) + assert(sink.latestBatchId === Some(2)) + checkAnswer(sink.latestBatchData, 7 to 9) + checkAnswer(sink.allData, 1 to 9) } - private def testRegisterAsTable(): Unit = { + test("directly add data in Update output mode") { + implicit val schema = new StructType().add(new StructField("value", IntegerType)) + val sink = new MemorySink(schema, InternalOutputModes.Update) + + // Before adding data, check output + assert(sink.latestBatchId === None) + checkAnswer(sink.latestBatchData, Seq.empty) + checkAnswer(sink.allData, Seq.empty) + + // Add batch 0 and check outputs + sink.addBatch(0, 1 to 3) + assert(sink.latestBatchId === Some(0)) + checkAnswer(sink.latestBatchData, 1 to 3) + checkAnswer(sink.allData, 1 to 3) + + // Add batch 1 and check outputs + sink.addBatch(1, 4 to 6) + assert(sink.latestBatchId === Some(1)) + checkAnswer(sink.latestBatchData, 4 to 6) + checkAnswer(sink.allData, 1 to 6) // new data should get appended to old data + + // Re-add batch 1 with different data, should not be added and outputs should not be changed + sink.addBatch(1, 7 to 9) + assert(sink.latestBatchId === Some(1)) + checkAnswer(sink.latestBatchData, 4 to 6) + checkAnswer(sink.allData, 1 to 6) + + // Add batch 2 and check outputs + sink.addBatch(2, 7 to 9) + assert(sink.latestBatchId === Some(2)) + checkAnswer(sink.latestBatchData, 7 to 9) + checkAnswer(sink.allData, 1 to 9) + } + + test("directly add data in Complete output mode") { + implicit val schema = new StructType().add(new StructField("value", IntegerType)) + val sink = new MemorySink(schema, InternalOutputModes.Complete) + + // Before adding data, check output + assert(sink.latestBatchId === None) + checkAnswer(sink.latestBatchData, Seq.empty) + checkAnswer(sink.allData, Seq.empty) + + // Add batch 0 and check outputs + sink.addBatch(0, 1 to 3) + assert(sink.latestBatchId === Some(0)) + checkAnswer(sink.latestBatchData, 1 to 3) + checkAnswer(sink.allData, 1 to 3) + + // Add batch 1 and check outputs + sink.addBatch(1, 4 to 6) + assert(sink.latestBatchId === Some(1)) + checkAnswer(sink.latestBatchData, 4 to 6) + checkAnswer(sink.allData, 4 to 6) // new data should replace old data + + // Re-add batch 1 with different data, should not be added and outputs should not be changed + sink.addBatch(1, 7 to 9) + assert(sink.latestBatchId === Some(1)) + checkAnswer(sink.latestBatchData, 4 to 6) + checkAnswer(sink.allData, 4 to 6) + + // Add batch 2 and check outputs + sink.addBatch(2, 7 to 9) + assert(sink.latestBatchId === Some(2)) + checkAnswer(sink.latestBatchData, 7 to 9) + checkAnswer(sink.allData, 7 to 9) + } + + + test("registering as a table in Append output mode") { val input = MemoryStream[Int] val query = input.toDF().write .format("memory") + .outputMode("append") .queryName("memStream") .startStream() input.addData(1, 2, 3) @@ -56,6 +161,57 @@ class MemorySinkSuite extends StreamTest with SharedSQLContext { query.stop() } + test("registering as a table in Complete output mode") { + val input = MemoryStream[Int] + val query = input.toDF() + .groupBy("value") + .count() + .write + .format("memory") + .outputMode("complete") + .queryName("memStream") + .startStream() + input.addData(1, 2, 3) + query.processAllAvailable() + + checkDataset( + spark.table("memStream").as[(Int, Long)], + (1, 1L), (2, 1L), (3, 1L)) + + input.addData(4, 5, 6) + query.processAllAvailable() + checkDataset( + spark.table("memStream").as[(Int, Long)], + (1, 1L), (2, 1L), (3, 1L), (4, 1L), (5, 1L), (6, 1L)) + + query.stop() + } + + ignore("stress test") { + // Ignore the stress test as it takes several minutes to run + (0 until 1000).foreach { _ => + val input = MemoryStream[Int] + val query = input.toDF().write + .format("memory") + .queryName("memStream") + .startStream() + input.addData(1, 2, 3) + query.processAllAvailable() + + checkDataset( + spark.table("memStream").as[Int], + 1, 2, 3) + + input.addData(4, 5, 6) + query.processAllAvailable() + checkDataset( + spark.table("memStream").as[Int], + 1, 2, 3, 4, 5, 6) + + query.stop() + } + } + test("error when no name is specified") { val error = intercept[AnalysisException] { val input = MemoryStream[Int] @@ -88,4 +244,15 @@ class MemorySinkSuite extends StreamTest with SharedSQLContext { .startStream() } } + + private def checkAnswer(rows: Seq[Row], expected: Seq[Int])(implicit schema: StructType): Unit = { + checkAnswer( + sqlContext.createDataFrame(sparkContext.makeRDD(rows), schema), + intsToDF(expected)(schema)) + } + + private implicit def intsToDF(seq: Seq[Int])(implicit schema: StructType): DataFrame = { + require(schema.fields.size === 1) + sqlContext.createDataset(seq).toDF(schema.fieldNames.head) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index ae89a6887a..c17cb1de6c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -235,6 +235,13 @@ class StreamSuite extends StreamTest with SharedSQLContext { spark.experimental.extraStrategies = Nil } } + + test("output mode API in Scala") { + val o1 = OutputMode.Append + assert(o1 === InternalOutputModes.Append) + val o2 = OutputMode.Complete + assert(o2 === InternalOutputModes.Complete) + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 7104d01c4a..322bbb9ea0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.streaming import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException -import org.apache.spark.sql.StreamTest -import org.apache.spark.sql.catalyst.analysis.Update +import org.apache.spark.sql.{AnalysisException, StreamTest} +import org.apache.spark.sql.InternalOutputModes._ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.expressions.scalalang.typed @@ -41,9 +41,7 @@ class StreamingAggregationSuite extends StreamTest with SharedSQLContext with Be import testImplicits._ - override val outputMode = Update - - test("simple count") { + test("simple count, update mode") { val inputData = MemoryStream[Int] val aggregated = @@ -52,7 +50,7 @@ class StreamingAggregationSuite extends StreamTest with SharedSQLContext with Be .agg(count("*")) .as[(Int, Long)] - testStream(aggregated)( + testStream(aggregated, Update)( AddData(inputData, 3), CheckLastBatch((3, 1)), AddData(inputData, 3, 2), @@ -67,6 +65,46 @@ class StreamingAggregationSuite extends StreamTest with SharedSQLContext with Be ) } + test("simple count, complete mode") { + val inputData = MemoryStream[Int] + + val aggregated = + inputData.toDF() + .groupBy($"value") + .agg(count("*")) + .as[(Int, Long)] + + testStream(aggregated, Complete)( + AddData(inputData, 3), + CheckLastBatch((3, 1)), + AddData(inputData, 2), + CheckLastBatch((3, 1), (2, 1)), + StopStream, + StartStream(), + AddData(inputData, 3, 2, 1), + CheckLastBatch((3, 2), (2, 2), (1, 1)), + AddData(inputData, 4, 4, 4, 4), + CheckLastBatch((4, 4), (3, 2), (2, 2), (1, 1)) + ) + } + + test("simple count, append mode") { + val inputData = MemoryStream[Int] + + val aggregated = + inputData.toDF() + .groupBy($"value") + .agg(count("*")) + .as[(Int, Long)] + + val e = intercept[AnalysisException] { + testStream(aggregated, Append)() + } + Seq("append", "not supported").foreach { m => + assert(e.getMessage.toLowerCase.contains(m.toLowerCase)) + } + } + test("multiple keys") { val inputData = MemoryStream[Int] @@ -76,7 +114,7 @@ class StreamingAggregationSuite extends StreamTest with SharedSQLContext with Be .agg(count("*")) .as[(Int, Int, Long)] - testStream(aggregated)( + testStream(aggregated, Update)( AddData(inputData, 1, 2), CheckLastBatch((1, 2, 1), (2, 3, 1)), AddData(inputData, 1, 2), @@ -101,7 +139,7 @@ class StreamingAggregationSuite extends StreamTest with SharedSQLContext with Be .agg(count("*")) .as[(Int, Long)] - testStream(aggregated)( + testStream(aggregated, Update)( StartStream(), AddData(inputData, 1, 2, 3, 4), ExpectFailure[SparkException](), @@ -114,7 +152,7 @@ class StreamingAggregationSuite extends StreamTest with SharedSQLContext with Be val inputData = MemoryStream[(String, Int)] val aggregated = inputData.toDS().groupByKey(_._1).agg(typed.sumLong(_._2)) - testStream(aggregated)( + testStream(aggregated, Update)( AddData(inputData, ("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)), CheckLastBatch(("a", 30), ("b", 3), ("c", 1)) ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataFrameReaderWriterSuite.scala index 288f6dc597..38a0534ab6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataFrameReaderWriterSuite.scala @@ -90,10 +90,11 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider { override def createSink( spark: SQLContext, parameters: Map[String, String], - partitionColumns: Seq[String]): Sink = { + partitionColumns: Seq[String], + outputMode: OutputMode): Sink = { LastOptions.parameters = parameters LastOptions.partitionColumns = partitionColumns - LastOptions.mockStreamSinkProvider.createSink(spark, parameters, partitionColumns) + LastOptions.mockStreamSinkProvider.createSink(spark, parameters, partitionColumns, outputMode) new Sink { override def addBatch(batchId: Long, data: DataFrame): Unit = {} } @@ -416,6 +417,39 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B assert(e.getMessage == "mode() can only be called on non-continuous queries;") } + test("check outputMode(OutputMode) can only be called on continuous queries") { + val df = spark.read.text(newTextInput) + val w = df.write.option("checkpointLocation", newMetadataDir) + val e = intercept[AnalysisException](w.outputMode(OutputMode.Append)) + Seq("outputmode", "continuous queries").foreach { s => + assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + } + } + + test("check outputMode(string) can only be called on continuous queries") { + val df = spark.read.text(newTextInput) + val w = df.write.option("checkpointLocation", newMetadataDir) + val e = intercept[AnalysisException](w.outputMode("append")) + Seq("outputmode", "continuous queries").foreach { s => + assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + } + } + + test("check outputMode(string) throws exception on unsupported modes") { + def testError(outputMode: String): Unit = { + val df = spark.read + .format("org.apache.spark.sql.streaming.test") + .stream() + val w = df.write + val e = intercept[IllegalArgumentException](w.outputMode(outputMode)) + Seq("output mode", "unknown", outputMode).foreach { s => + assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + } + } + testError("Update") + testError("Xyz") + } + test("check bucketBy() can only be called on non-continuous queries") { val df = spark.read .format("org.apache.spark.sql.streaming.test")