[SPARK-19497][SS] Implement streaming deduplication

## What changes were proposed in this pull request?

This PR adds a dedicated streaming deduplication operator to support `dropDuplicates` on streaming Datasets, with or without a watermark, and ahead of a downstream aggregation. It reuses the existing `dropDuplicates` API but introduces a new logical plan, `Deduplicate`, and a new physical plan, `StreamingDeduplicateExec`. A usage sketch follows the lists below.

The following cases are supported:

- one or multiple `dropDuplicates()` without aggregation (with or without watermark)
- `dropDuplicates` before aggregation

Not supported cases:

- `dropDuplicates` after aggregation

Breaking changes:
- `dropDuplicates` without aggregation doesn't work with `complete` or `update` mode.
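For reference, a minimal usage sketch of the supported streaming pattern (the `rate` source, column names, and console sink are illustrative assumptions, not part of this patch):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.OutputMode

val spark = SparkSession.builder().appName("dedup-sketch").getOrCreate()
import spark.implicits._

// Hypothetical stream with an event-time column and a key column.
val events = spark.readStream
  .format("rate")
  .load()                                    // columns: timestamp, value
  .withColumnRenamed("timestamp", "eventTime")
  .withColumn("guid", $"value" % 10)

// Deduplicate; the watermark bounds how much deduplication state is kept.
val deduped = events
  .withWatermark("eventTime", "10 minutes")
  .dropDuplicates("guid", "eventTime")

val query = deduped.writeStream
  .format("console")
  .outputMode(OutputMode.Append())
  .start()
```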

## How was this patch tested?

New unit tests, including the new `DeduplicateSuite` and additional cases in `UnsupportedOperationsSuite` and `ReplaceOperatorSuite`.

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #16970 from zsxwing/dedup.
Commit 9bf4e2baad (parent 7bf09433f5), authored by Shixiong Zhu on 2017-02-23 11:25:39 -08:00 and committed by Tathagata Das.
15 changed files with 578 additions and 58 deletions.


@@ -1158,6 +1158,12 @@ class DataFrame(object):
"""Return a new :class:`DataFrame` with duplicate rows removed,
optionally only considering certain columns.
For a static batch :class:`DataFrame`, it just drops duplicate rows. For a streaming
:class:`DataFrame`, it will keep all data across triggers as intermediate state to drop
duplicate rows. You can use :func:`withWatermark` to limit how late the duplicate data can
be and the system will accordingly limit the state. In addition, data older than the
watermark will be dropped to avoid any possibility of duplicates.
:func:`drop_duplicates` is an alias for :func:`dropDuplicates`.
>>> from pyspark.sql import Row


@@ -75,7 +75,7 @@ object UnsupportedOperationChecker {
if (watermarkAttributes.isEmpty) {
throwError(
s"$outputMode output mode not supported when there are streaming aggregations on " +
s"streaming DataFrames/DataSets")(plan)
s"streaming DataFrames/DataSets without watermark")(plan)
}
case InternalOutputModes.Complete if aggregates.isEmpty =>
@@ -120,6 +120,10 @@ object UnsupportedOperationChecker {
throwError("(map/flatMap)GroupsWithState is not supported after aggregation on a " +
"streaming DataFrame/Dataset")
case d: Deduplicate if collectStreamingAggregates(d).nonEmpty =>
throwError("dropDuplicates is not supported after aggregation on a " +
"streaming DataFrame/Dataset")
case Join(left, right, joinType, _) =>
joinType match {
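For illustration, a sketch (with assumed names, not from the patch) of the query shape that the new `Deduplicate` case rejects, i.e. `dropDuplicates` applied after a streaming aggregation:

```scala
// `lines` is assumed to be a streaming Dataset[String], e.g. from a socket source.
val counts = lines.groupBy("value").count()

// Starting a streaming query over this plan fails the check above with
// "dropDuplicates is not supported after aggregation on a streaming DataFrame/Dataset".
val rejected = counts.dropDuplicates("value")
```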


@@ -56,7 +56,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
ReplaceExpressions,
ComputeCurrentTime,
GetCurrentDatabase(sessionCatalog),
RewriteDistinctAggregates) ::
RewriteDistinctAggregates,
ReplaceDeduplicateWithAggregate) ::
//////////////////////////////////////////////////////////////////////////////////////////
// Optimizer rules start here
//////////////////////////////////////////////////////////////////////////////////////////
@@ -1142,6 +1143,24 @@ object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] {
}
}
/**
* Replaces logical [[Deduplicate]] operator with an [[Aggregate]] operator.
*/
object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case Deduplicate(keys, child, streaming) if !streaming =>
val keyExprIds = keys.map(_.exprId)
val aggCols = child.output.map { attr =>
if (keyExprIds.contains(attr.exprId)) {
attr
} else {
Alias(new First(attr).toAggregateExpression(), attr.name)(attr.exprId)
}
}
Aggregate(keys, aggCols, child)
}
}
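In DataFrame terms, the rewrite this rule performs for the batch case is roughly the following (a sketch with assumed column names and an assumed active SparkSession `spark`):

```scala
import org.apache.spark.sql.functions.first

// Batch DataFrame with a key column `a` and a payload column `b`.
val df = spark.range(10).selectExpr("id % 3 AS a", "id AS b")

// df.dropDuplicates("a") is optimized into an Aggregate that groups by the
// deduplication keys and takes first() of every non-key column.
val equivalent = df.groupBy("a").agg(first("b").as("b"))
```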
/**
* Replaces logical [[Intersect]] operator with a left-semi [[Join]] operator.
* {{{


@@ -864,3 +864,12 @@ case object OneRowRelation extends LeafNode {
override def output: Seq[Attribute] = Nil
override def computeStats(conf: CatalystConf): Statistics = Statistics(sizeInBytes = 1)
}
/** A logical plan for `dropDuplicates`. */
case class Deduplicate(
keys: Seq[Attribute],
child: LogicalPlan,
streaming: Boolean) extends UnaryNode {
override def output: Seq[Attribute] = child.output
}


@@ -28,7 +28,8 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{MapGroupsWithState, _}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.{IntegerType, LongType}
import org.apache.spark.sql.types.{IntegerType, LongType, MetadataBuilder}
import org.apache.spark.unsafe.types.CalendarInterval
/** A dummy command for testing unsupported operations. */
case class DummyCommand() extends Command
@@ -36,6 +37,11 @@ case class DummyCommand() extends Command
class UnsupportedOperationsSuite extends SparkFunSuite {
val attribute = AttributeReference("a", IntegerType, nullable = true)()
val watermarkMetadata = new MetadataBuilder()
.withMetadata(attribute.metadata)
.putLong(EventTimeWatermark.delayKey, 1000L)
.build()
val attributeWithWatermark = attribute.withMetadata(watermarkMetadata)
val batchRelation = LocalRelation(attribute)
val streamRelation = new TestStreamingRelation(attribute)
@@ -98,6 +104,27 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
outputMode = Update,
expectedMsgs = Seq("multiple streaming aggregations"))
assertSupportedInStreamingPlan(
"aggregate - streaming aggregations in update mode",
Aggregate(Nil, aggExprs("d"), streamRelation),
outputMode = Update)
assertSupportedInStreamingPlan(
"aggregate - streaming aggregations in complete mode",
Aggregate(Nil, aggExprs("d"), streamRelation),
outputMode = Complete)
assertSupportedInStreamingPlan(
"aggregate - streaming aggregations with watermark in append mode",
Aggregate(Seq(attributeWithWatermark), aggExprs("d"), streamRelation),
outputMode = Append)
assertNotSupportedInStreamingPlan(
"aggregate - streaming aggregations without watermark in append mode",
Aggregate(Nil, aggExprs("d"), streamRelation),
outputMode = Append,
expectedMsgs = Seq("streaming aggregations", "without watermark"))
// Aggregation: Distinct aggregates not supported on streaming relation
val distinctAggExprs = Seq(Count("*").toAggregateExpression(isDistinct = true).as("c"))
assertSupportedInStreamingPlan(
@@ -129,6 +156,33 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
outputMode = Complete,
expectedMsgs = Seq("(map/flatMap)GroupsWithState"))
assertSupportedInStreamingPlan(
"mapGroupsWithState - mapGroupsWithState on batch relation inside streaming relation",
MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), batchRelation),
outputMode = Append
)
// Deduplicate
assertSupportedInStreamingPlan(
"Deduplicate - Deduplicate on streaming relation before aggregation",
Aggregate(
Seq(attributeWithWatermark),
aggExprs("c"),
Deduplicate(Seq(att), streamRelation, streaming = true)),
outputMode = Append)
assertNotSupportedInStreamingPlan(
"Deduplicate - Deduplicate on streaming relation after aggregation",
Deduplicate(Seq(att), Aggregate(Nil, aggExprs("c"), streamRelation), streaming = true),
outputMode = Complete,
expectedMsgs = Seq("dropDuplicates"))
assertSupportedInStreamingPlan(
"Deduplicate - Deduplicate on batch relation inside a streaming query",
Deduplicate(Seq(att), batchRelation, streaming = false),
outputMode = Append
)
// Inner joins: Stream-stream not supported
testBinaryOperationInStreamingPlan(
"inner join",


@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.catalyst.expressions.aggregate.First
import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -30,7 +32,8 @@ class ReplaceOperatorSuite extends PlanTest {
Batch("Replace Operators", FixedPoint(100),
ReplaceDistinctWithAggregate,
ReplaceExceptWithAntiJoin,
ReplaceIntersectWithSemiJoin) :: Nil
ReplaceIntersectWithSemiJoin,
ReplaceDeduplicateWithAggregate) :: Nil
}
test("replace Intersect with Left-semi Join") {
@@ -71,4 +74,32 @@ class ReplaceOperatorSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
test("replace batch Deduplicate with Aggregate") {
val input = LocalRelation('a.int, 'b.int)
val attrA = input.output(0)
val attrB = input.output(1)
val query = Deduplicate(Seq(attrA), input, streaming = false) // dropDuplicates("a")
val optimized = Optimize.execute(query.analyze)
val correctAnswer =
Aggregate(
Seq(attrA),
Seq(
attrA,
Alias(new First(attrB).toAggregateExpression(), attrB.name)(attrB.exprId)
),
input)
comparePlans(optimized, correctAnswer)
}
test("don't replace streaming Deduplicate") {
val input = LocalRelation('a.int, 'b.int)
val attrA = input.output(0)
val query = Deduplicate(Seq(attrA), input, streaming = true) // dropDuplicates("a")
val optimized = Optimize.execute(query.analyze)
comparePlans(optimized, query)
}
}


@@ -557,7 +557,8 @@ class Dataset[T] private[sql](
* Spark will use this watermark for several purposes:
* - To know when a given time window aggregation can be finalized and thus can be emitted when
* using output modes that do not allow updates.
* - To minimize the amount of state that we need to keep for on-going aggregations.
* - To minimize the amount of state that we need to keep for on-going aggregations,
* `mapGroupsWithState` and `dropDuplicates` operators.
*
* The current watermark is computed by looking at the `MAX(eventTime)` seen across
* all of the partitions in the query minus a user specified `delayThreshold`. Due to the cost
@@ -1981,6 +1982,12 @@ class Dataset[T] private[sql](
* Returns a new Dataset that contains only the unique rows from this Dataset.
* This is an alias for `distinct`.
*
* For a static batch [[Dataset]], it just drops duplicate rows. For a streaming [[Dataset]], it
* will keep all data across triggers as intermediate state to drop duplicate rows. You can use
* [[withWatermark]] to limit how late the duplicate data can be and the system will accordingly
* limit the state. In addition, data older than the watermark will be dropped to avoid any
* possibility of duplicates.
*
* @group typedrel
* @since 2.0.0
*/
@@ -1990,13 +1997,19 @@ class Dataset[T] private[sql](
* (Scala-specific) Returns a new Dataset with duplicate rows removed, considering only
* the subset of columns.
*
* For a static batch [[Dataset]], it just drops duplicate rows. For a streaming [[Dataset]], it
* will keep all data across triggers as intermediate state to drop duplicate rows. You can use
* [[withWatermark]] to limit how late the duplicate data can be and the system will accordingly
* limit the state. In addition, data older than the watermark will be dropped to avoid any
* possibility of duplicates.
*
* @group typedrel
* @since 2.0.0
*/
def dropDuplicates(colNames: Seq[String]): Dataset[T] = withTypedPlan {
val resolver = sparkSession.sessionState.analyzer.resolver
val allColumns = queryExecution.analyzed.output
val groupCols = colNames.flatMap { colName =>
val groupCols = colNames.toSet.toSeq.flatMap { (colName: String) =>
// It is possible that there is more than one column with the same name,
// so we call filter instead of find.
val cols = allColumns.filter(col => resolver(col.name, colName))
@@ -2006,21 +2019,19 @@ class Dataset[T] private[sql](
}
cols
}
val groupColExprIds = groupCols.map(_.exprId)
val aggCols = logicalPlan.output.map { attr =>
if (groupColExprIds.contains(attr.exprId)) {
attr
} else {
Alias(new First(attr).toAggregateExpression(), attr.name)()
}
}
Aggregate(groupCols, aggCols, logicalPlan)
Deduplicate(groupCols, logicalPlan, isStreaming)
}
/**
* Returns a new Dataset with duplicate rows removed, considering only
* the subset of columns.
*
* For a static batch [[Dataset]], it just drops duplicate rows. For a streaming [[Dataset]], it
* will keep all data across triggers as intermediate state to drop duplicate rows. You can use
* [[withWatermark]] to limit how late the duplicate data can be and the system will accordingly
* limit the state. In addition, data older than the watermark will be dropped to avoid any
* possibility of duplicates.
*
* @group typedrel
* @since 2.0.0
*/
@@ -2030,6 +2041,12 @@ class Dataset[T] private[sql](
* Returns a new [[Dataset]] with duplicate rows removed, considering only
* the subset of columns.
*
* For a static batch [[Dataset]], it just drops duplicate rows. For a streaming [[Dataset]], it
* will keep all data across triggers as intermediate state to drop duplicate rows. You can use
* [[withWatermark]] to limit how late the duplicate data can be and the system will accordingly
* limit the state. In addition, data older than the watermark will be dropped to avoid any
* possibility of duplicates.
*
* @group typedrel
* @since 2.0.0
*/
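As a typed-API usage note mirroring the docs above (the case class and stream are assumptions for illustration): keeping the watermarked column among the deduplication keys is what allows old state to be evicted, and the new `colNames.toSet.toSeq` step collapses repeated names passed to `dropDuplicates`.

```scala
import java.sql.Timestamp
import org.apache.spark.sql.Dataset

case class Click(userId: String, url: String, eventTime: Timestamp)

// `clicks` is assumed to be a streaming Dataset[Click].
def dedupClicks(clicks: Dataset[Click]): Dataset[Click] =
  clicks
    .withWatermark("eventTime", "30 minutes")
    // Repeated names are collapsed; the keys are effectively ("userId", "eventTime").
    // Because eventTime carries the watermark, keys older than the watermark are dropped.
    .dropDuplicates(Seq("userId", "eventTime", "userId"))
```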


@@ -22,9 +22,10 @@ import org.apache.spark.sql.{SaveMode, Strategy}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.First
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, EventTimeWatermark, LogicalPlan, MapGroupsWithState}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
@@ -244,6 +245,18 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}
/**
* Used to plan the streaming deduplicate operator.
*/
object StreamingDeduplicationStrategy extends Strategy {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case Deduplicate(keys, child, true) =>
StreamingDeduplicateExec(keys, planLater(child)) :: Nil
case _ => Nil
}
}
/**
* Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface.
*/


@@ -45,6 +45,7 @@ class IncrementalExecution(
sparkSession.sessionState.planner.StatefulAggregationStrategy +:
sparkSession.sessionState.planner.MapGroupsWithStateStrategy +:
sparkSession.sessionState.planner.StreamingRelationStrategy +:
sparkSession.sessionState.planner.StreamingDeduplicationStrategy +:
sparkSession.sessionState.experimentalMethods.extraStrategies
// Modified planner with stateful operations.
@@ -93,6 +94,15 @@ class IncrementalExecution(
keys,
Some(stateId),
child) :: Nil))
case StreamingDeduplicateExec(keys, child, None, None) =>
val stateId =
OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId)
StreamingDeduplicateExec(
keys,
child,
Some(stateId),
Some(currentEventTimeWatermark))
case MapGroupsWithStateExec(
f, kDeser, vDeser, group, data, output, None, stateDeser, stateSer, child) =>
val stateId =


@@ -25,12 +25,11 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjecti
import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalKeyedState}
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{DataType, NullType, StructType}
import org.apache.spark.util.CompletionIterator
@@ -68,6 +67,40 @@ trait StateStoreWriter extends StatefulOperator {
"numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows"))
}
/** An operator that supports watermark. */
trait WatermarkSupport extends SparkPlan {
/** The keys that may have a watermark attribute. */
def keyExpressions: Seq[Attribute]
/** The watermark value. */
def eventTimeWatermark: Option[Long]
/** Generate a predicate that matches data older than the watermark */
lazy val watermarkPredicate: Option[Predicate] = {
val optionalWatermarkAttribute =
keyExpressions.find(_.metadata.contains(EventTimeWatermark.delayKey))
optionalWatermarkAttribute.map { watermarkAttribute =>
// If we are evicting based on a window, use the end of the window. Otherwise just
// use the attribute itself.
val evictionExpression =
if (watermarkAttribute.dataType.isInstanceOf[StructType]) {
LessThanOrEqual(
GetStructField(watermarkAttribute, 1),
Literal(eventTimeWatermark.get * 1000))
} else {
LessThanOrEqual(
watermarkAttribute,
Literal(eventTimeWatermark.get * 1000))
}
logInfo(s"Filtering state store on: $evictionExpression")
newPredicate(evictionExpression, keyExpressions)
}
}
}
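A minimal sketch of the comparison this predicate encodes (an assumed helper, not Spark's API): the watermark is tracked in milliseconds while row-level event-time values are in microseconds, hence the `* 1000` above, and for window keys the window's end (struct field 1) is compared instead of the attribute itself.

```scala
// Illustrative only: the eviction rule as a plain function.
def isOlderThanWatermark(eventTimeMicros: Long, watermarkMs: Long): Boolean =
  eventTimeMicros <= watermarkMs * 1000

// With a watermark of 15 s (15000 ms): a key stamped 10 s is evicted, one stamped 20 s is kept.
assert(isOlderThanWatermark(10L * 1000 * 1000, 15000L))
assert(!isOlderThanWatermark(20L * 1000 * 1000, 15000L))
```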
/**
* For each input tuple, the key is calculated and the value from the [[StateStore]] is added
* to the stream (in addition to the input tuple) if present.
@@ -76,7 +109,7 @@ case class StateStoreRestoreExec(
keyExpressions: Seq[Attribute],
stateId: Option[OperatorStateId],
child: SparkPlan)
extends execution.UnaryExecNode with StateStoreReader {
extends UnaryExecNode with StateStoreReader {
override protected def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
@@ -113,31 +146,7 @@ case class StateStoreSaveExec(
outputMode: Option[OutputMode] = None,
eventTimeWatermark: Option[Long] = None,
child: SparkPlan)
extends execution.UnaryExecNode with StateStoreWriter {
/** Generate a predicate that matches data older than the watermark */
private lazy val watermarkPredicate: Option[Predicate] = {
val optionalWatermarkAttribute =
keyExpressions.find(_.metadata.contains(EventTimeWatermark.delayKey))
optionalWatermarkAttribute.map { watermarkAttribute =>
// If we are evicting based on a window, use the end of the window. Otherwise just
// use the attribute itself.
val evictionExpression =
if (watermarkAttribute.dataType.isInstanceOf[StructType]) {
LessThanOrEqual(
GetStructField(watermarkAttribute, 1),
Literal(eventTimeWatermark.get * 1000))
} else {
LessThanOrEqual(
watermarkAttribute,
Literal(eventTimeWatermark.get * 1000))
}
logInfo(s"Filtering state store on: $evictionExpression")
newPredicate(evictionExpression, keyExpressions)
}
}
extends UnaryExecNode with StateStoreWriter with WatermarkSupport {
override protected def doExecute(): RDD[InternalRow] = {
metrics // force lazy init at driver
@@ -146,8 +155,8 @@ case class StateStoreSaveExec(
child.execute().mapPartitionsWithStateStore(
getStateId.checkpointLocation,
operatorId = getStateId.operatorId,
storeVersion = getStateId.batchId,
getStateId.operatorId,
getStateId.batchId,
keyExpressions.toStructType,
child.output.toStructType,
sqlContext.sessionState,
@@ -262,8 +271,8 @@ case class MapGroupsWithStateExec(
override protected def doExecute(): RDD[InternalRow] = {
child.execute().mapPartitionsWithStateStore[InternalRow](
getStateId.checkpointLocation,
operatorId = getStateId.operatorId,
storeVersion = getStateId.batchId,
getStateId.operatorId,
getStateId.batchId,
groupingAttributes.toStructType,
child.output.toStructType,
sqlContext.sessionState,
@@ -321,3 +330,70 @@ case class MapGroupsWithStateExec(
}
}
}
/** Physical operator for executing streaming Deduplicate. */
case class StreamingDeduplicateExec(
keyExpressions: Seq[Attribute],
child: SparkPlan,
stateId: Option[OperatorStateId] = None,
eventTimeWatermark: Option[Long] = None)
extends UnaryExecNode with StateStoreWriter with WatermarkSupport {
/** Distribute by grouping attributes */
override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(keyExpressions) :: Nil
override protected def doExecute(): RDD[InternalRow] = {
metrics // force lazy init at driver
child.execute().mapPartitionsWithStateStore(
getStateId.checkpointLocation,
getStateId.operatorId,
getStateId.batchId,
keyExpressions.toStructType,
child.output.toStructType,
sqlContext.sessionState,
Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) =>
val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
val numOutputRows = longMetric("numOutputRows")
val numTotalStateRows = longMetric("numTotalStateRows")
val numUpdatedStateRows = longMetric("numUpdatedStateRows")
val baseIterator = watermarkPredicate match {
case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row))
case None => iter
}
val result = baseIterator.filter { r =>
val row = r.asInstanceOf[UnsafeRow]
val key = getKey(row)
val value = store.get(key)
if (value.isEmpty) {
store.put(key.copy(), StreamingDeduplicateExec.EMPTY_ROW)
numUpdatedStateRows += 1
numOutputRows += 1
true
} else {
// Drop duplicated rows
false
}
}
CompletionIterator[InternalRow, Iterator[InternalRow]](result, {
watermarkPredicate.foreach(f => store.remove(f.eval _))
store.commit()
numTotalStateRows += store.numKeys()
})
}
}
override def output: Seq[Attribute] = child.output
override def outputPartitioning: Partitioning = child.outputPartitioning
}
object StreamingDeduplicateExec {
private val EMPTY_ROW =
UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null))
}
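Summarizing the operator, here is a simplified, self-contained sketch of the per-partition logic in `doExecute` above (plain Scala collections standing in for the `StateStore`; helper parameter names are assumptions):

```scala
import scala.collection.mutable

// Drop input rows already older than the watermark, emit only the first occurrence
// of each key, remember seen keys, then evict keys the watermark has passed.
// (Spark performs the eviction and store.commit() in a CompletionIterator callback.)
def dedupPartition[R, K](
    rows: Iterator[R],
    keyOf: R => K,
    rowIsLate: R => Boolean,     // stands in for watermarkPredicate on input rows
    keyIsExpired: K => Boolean,  // stands in for watermarkPredicate on stored keys
    seenKeys: mutable.Set[K]): Seq[R] = {
  val output = rows
    .filterNot(rowIsLate)
    .filter { row =>
      val key = keyOf(row)
      if (seenKeys.contains(key)) {
        false                    // duplicate within the retained state: drop
      } else {
        seenKeys += key          // Spark stores key -> EMPTY_ROW in the StateStore
        true                     // first occurrence: emit
      }
    }
    .toSeq
  seenKeys --= seenKeys.filter(keyIsExpired).toSet  // state eviction by watermark
  output
}
```

Because only the key is kept in state (the value is the shared `EMPTY_ROW`), the watermark is what keeps the number of stored keys bounded.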


@@ -0,0 +1,252 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.streaming
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.execution.streaming.state.StateStore
import org.apache.spark.sql.functions._
class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
import testImplicits._
override def afterAll(): Unit = {
super.afterAll()
StateStore.stop()
}
test("deduplicate with all columns") {
val inputData = MemoryStream[String]
val result = inputData.toDS().dropDuplicates()
testStream(result, Append)(
AddData(inputData, "a"),
CheckLastBatch("a"),
assertNumStateRows(total = 1, updated = 1),
AddData(inputData, "a"),
CheckLastBatch(),
assertNumStateRows(total = 1, updated = 0),
AddData(inputData, "b"),
CheckLastBatch("b"),
assertNumStateRows(total = 2, updated = 1)
)
}
test("deduplicate with some columns") {
val inputData = MemoryStream[(String, Int)]
val result = inputData.toDS().dropDuplicates("_1")
testStream(result, Append)(
AddData(inputData, "a" -> 1),
CheckLastBatch("a" -> 1),
assertNumStateRows(total = 1, updated = 1),
AddData(inputData, "a" -> 2), // Dropped
CheckLastBatch(),
assertNumStateRows(total = 1, updated = 0),
AddData(inputData, "b" -> 1),
CheckLastBatch("b" -> 1),
assertNumStateRows(total = 2, updated = 1)
)
}
test("multiple deduplicates") {
val inputData = MemoryStream[(String, Int)]
val result = inputData.toDS().dropDuplicates().dropDuplicates("_1")
testStream(result, Append)(
AddData(inputData, "a" -> 1),
CheckLastBatch("a" -> 1),
assertNumStateRows(total = Seq(1L, 1L), updated = Seq(1L, 1L)),
AddData(inputData, "a" -> 2), // Dropped from the second `dropDuplicates`
CheckLastBatch(),
assertNumStateRows(total = Seq(1L, 2L), updated = Seq(0L, 1L)),
AddData(inputData, "b" -> 1),
CheckLastBatch("b" -> 1),
assertNumStateRows(total = Seq(2L, 3L), updated = Seq(1L, 1L))
)
}
test("deduplicate with watermark") {
val inputData = MemoryStream[Int]
val result = inputData.toDS()
.withColumn("eventTime", $"value".cast("timestamp"))
.withWatermark("eventTime", "10 seconds")
.dropDuplicates()
.select($"eventTime".cast("long").as[Long])
testStream(result, Append)(
AddData(inputData, (1 to 5).flatMap(_ => (10 to 15)): _*),
CheckLastBatch(10 to 15: _*),
assertNumStateRows(total = 6, updated = 6),
AddData(inputData, 25), // Advance watermark to 15 seconds
CheckLastBatch(25),
assertNumStateRows(total = 7, updated = 1),
AddData(inputData, 25), // Drop states less than watermark
CheckLastBatch(),
assertNumStateRows(total = 1, updated = 0),
AddData(inputData, 10), // Should not emit anything as data less than watermark
CheckLastBatch(),
assertNumStateRows(total = 1, updated = 0),
AddData(inputData, 45), // Advance watermark to 35 seconds
CheckLastBatch(45),
assertNumStateRows(total = 2, updated = 1),
AddData(inputData, 45), // Drop states less than watermark
CheckLastBatch(),
assertNumStateRows(total = 1, updated = 0)
)
}
test("deduplicate with aggregate - append mode") {
val inputData = MemoryStream[Int]
val windowedaggregate = inputData.toDS()
.withColumn("eventTime", $"value".cast("timestamp"))
.withWatermark("eventTime", "10 seconds")
.dropDuplicates()
.withWatermark("eventTime", "10 seconds")
.groupBy(window($"eventTime", "5 seconds") as 'window)
.agg(count("*") as 'count)
.select($"window".getField("start").cast("long").as[Long], $"count".as[Long])
testStream(windowedaggregate)(
AddData(inputData, (1 to 5).flatMap(_ => (10 to 15)): _*),
CheckLastBatch(),
// states in aggregate in [10, 14), [15, 20) (2 windows)
// states in deduplicate is 10 to 15
assertNumStateRows(total = Seq(2L, 6L), updated = Seq(2L, 6L)),
AddData(inputData, 25), // Advance watermark to 15 seconds
CheckLastBatch(),
// states in aggregate in [10, 14), [15, 20) and [25, 30) (3 windows)
// states in deduplicate is 10 to 15 and 25
assertNumStateRows(total = Seq(3L, 7L), updated = Seq(1L, 1L)),
AddData(inputData, 25), // Emit items less than watermark and drop their state
CheckLastBatch((10 -> 5)), // 5 items (10 to 14) after deduplicate
// states in aggregate in [15, 20) and [25, 30) (2 windows, note aggregate uses the end of
// window to evict items, so [15, 20) is still in the state store)
// states in deduplicate is 25
assertNumStateRows(total = Seq(2L, 1L), updated = Seq(0L, 0L)),
AddData(inputData, 10), // Should not emit anything as data less than watermark
CheckLastBatch(),
assertNumStateRows(total = Seq(2L, 1L), updated = Seq(0L, 0L)),
AddData(inputData, 40), // Advance watermark to 30 seconds
CheckLastBatch(),
// states in aggregate in [15, 20), [25, 30) and [40, 45)
// states in deduplicate is 25 and 40,
assertNumStateRows(total = Seq(3L, 2L), updated = Seq(1L, 1L)),
AddData(inputData, 40), // Emit items less than watermark and drop their state
CheckLastBatch((15 -> 1), (25 -> 1)),
// states in aggregate in [40, 45)
// states in deduplicate is 40,
assertNumStateRows(total = Seq(1L, 1L), updated = Seq(0L, 0L))
)
}
test("deduplicate with aggregate - update mode") {
val inputData = MemoryStream[(String, Int)]
val result = inputData.toDS()
.select($"_1" as "str", $"_2" as "num")
.dropDuplicates()
.groupBy("str")
.agg(sum("num"))
.as[(String, Long)]
testStream(result, Update)(
AddData(inputData, "a" -> 1),
CheckLastBatch("a" -> 1L),
assertNumStateRows(total = Seq(1L, 1L), updated = Seq(1L, 1L)),
AddData(inputData, "a" -> 1), // Dropped
CheckLastBatch(),
assertNumStateRows(total = Seq(1L, 1L), updated = Seq(0L, 0L)),
AddData(inputData, "a" -> 2),
CheckLastBatch("a" -> 3L),
assertNumStateRows(total = Seq(1L, 2L), updated = Seq(1L, 1L)),
AddData(inputData, "b" -> 1),
CheckLastBatch("b" -> 1L),
assertNumStateRows(total = Seq(2L, 3L), updated = Seq(1L, 1L))
)
}
test("deduplicate with aggregate - complete mode") {
val inputData = MemoryStream[(String, Int)]
val result = inputData.toDS()
.select($"_1" as "str", $"_2" as "num")
.dropDuplicates()
.groupBy("str")
.agg(sum("num"))
.as[(String, Long)]
testStream(result, Complete)(
AddData(inputData, "a" -> 1),
CheckLastBatch("a" -> 1L),
assertNumStateRows(total = Seq(1L, 1L), updated = Seq(1L, 1L)),
AddData(inputData, "a" -> 1), // Dropped
CheckLastBatch("a" -> 1L),
assertNumStateRows(total = Seq(1L, 1L), updated = Seq(0L, 0L)),
AddData(inputData, "a" -> 2),
CheckLastBatch("a" -> 3L),
assertNumStateRows(total = Seq(1L, 2L), updated = Seq(1L, 1L)),
AddData(inputData, "b" -> 1),
CheckLastBatch("a" -> 3L, "b" -> 1L),
assertNumStateRows(total = Seq(2L, 3L), updated = Seq(1L, 1L))
)
}
test("deduplicate with file sink") {
withTempDir { output =>
withTempDir { checkpointDir =>
val outputPath = output.getAbsolutePath
val inputData = MemoryStream[String]
val result = inputData.toDS().dropDuplicates()
val q = result.writeStream
.format("parquet")
.outputMode(Append)
.option("checkpointLocation", checkpointDir.getPath)
.start(outputPath)
try {
inputData.addData("a")
q.processAllAvailable()
checkDataset(spark.read.parquet(outputPath).as[String], "a")
inputData.addData("a") // Dropped
q.processAllAvailable()
checkDataset(spark.read.parquet(outputPath).as[String], "a")
inputData.addData("b")
q.processAllAvailable()
checkDataset(spark.read.parquet(outputPath).as[String], "a", "b")
} finally {
q.stop()
}
}
}
}
}


@@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.streaming.state.StateStore
/** Class to check custom state types */
case class RunningCount(count: Long)
class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll {
class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
import testImplicits._
@@ -321,13 +321,6 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll {
CheckLastBatch(("a", 3L)) // task should not fail, and should show correct count
)
}
private def assertNumStateRows(total: Long, updated: Long): AssertOnQuery = AssertOnQuery { q =>
val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get
assert(progressWithData.stateOperators(0).numRowsTotal === total, "incorrect total rows")
assert(progressWithData.stateOperators(0).numRowsUpdated === updated, "incorrect updates rows")
true
}
}
object MapGroupsWithStateSuite {


@@ -0,0 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.streaming
trait StateStoreMetricsTest extends StreamTest {
def assertNumStateRows(total: Seq[Long], updated: Seq[Long]): AssertOnQuery =
AssertOnQuery { q =>
val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get
assert(
progressWithData.stateOperators.map(_.numRowsTotal) === total,
"incorrect total rows")
assert(
progressWithData.stateOperators.map(_.numRowsUpdated) === updated,
"incorrect updates rows")
true
}
def assertNumStateRows(total: Long, updated: Long): AssertOnQuery =
assertNumStateRows(Seq(total), Seq(updated))
}


@@ -338,7 +338,7 @@ class StreamSuite extends StreamTest {
.writeStream
.format("memory")
.queryName("testquery")
.outputMode("complete")
.outputMode("append")
.start()
try {
query.processAllAvailable()


@@ -35,7 +35,7 @@ object FailureSinglton {
var firstTime = true
}
class StreamingAggregationSuite extends StreamTest with BeforeAndAfterAll {
class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
override def afterAll(): Unit = {
super.afterAll()