[SPARK-19497][SS] Implement streaming deduplication
## What changes were proposed in this pull request?

This PR adds a special streaming deduplication operator to support `dropDuplicates` with aggregation and watermark. It reuses the `dropDuplicates` API but creates a new logical plan `Deduplicate` and a new physical plan `StreamingDeduplicateExec`.

Supported cases:
- one or multiple `dropDuplicates()` without aggregation (with or without watermark)
- `dropDuplicates` before aggregation

Not supported:
- `dropDuplicates` after aggregation

Breaking changes:
- `dropDuplicates` without aggregation doesn't work with `complete` or `update` mode.

## How was this patch tested?

The new unit tests.

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #16970 from zsxwing/dedup.
parent 7bf09433f5
commit 9bf4e2baad
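For orientation, a minimal usage sketch of the new behavior, modeled on the tests added in `DeduplicateSuite` below; the local master, console sink, and the `MemoryStream` source are illustrative assumptions rather than part of this change:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.streaming.MemoryStream

val spark = SparkSession.builder().master("local[2]").appName("dedup-sketch").getOrCreate()
import spark.implicits._
implicit val sqlCtx = spark.sqlContext            // needed to construct a MemoryStream

val input = MemoryStream[Int]
val deduped = input.toDS()
  .withColumn("eventTime", $"value".cast("timestamp"))
  .withWatermark("eventTime", "10 seconds")       // bounds how much deduplication state is kept
  .dropDuplicates()                               // emit only the first occurrence of each row

val query = deduped.writeStream
  .format("console")
  .outputMode("append")                           // complete/update without aggregation is rejected
  .start()

input.addData(10, 11, 10)                         // the second 10 is dropped as a duplicate
query.processAllAvailable()
query.stop()
```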
@@ -1158,6 +1158,12 @@ class DataFrame(object):
        """Return a new :class:`DataFrame` with duplicate rows removed,
        optionally only considering certain columns.

        For a static batch :class:`DataFrame`, it just drops duplicate rows. For a streaming
        :class:`DataFrame`, it will keep all data across triggers as intermediate state to drop
        duplicate rows. You can use :func:`withWatermark` to limit how late the duplicate data can
        be, and the system will accordingly limit the state. In addition, data older than the
        watermark will be dropped to avoid any possibility of duplicates.

        :func:`drop_duplicates` is an alias for :func:`dropDuplicates`.

        >>> from pyspark.sql import Row
@@ -75,7 +75,7 @@ object UnsupportedOperationChecker {
          if (watermarkAttributes.isEmpty) {
            throwError(
              s"$outputMode output mode not supported when there are streaming aggregations on " +
                s"streaming DataFrames/DataSets")(plan)
                s"streaming DataFrames/DataSets without watermark")(plan)
          }

        case InternalOutputModes.Complete if aggregates.isEmpty =>

@@ -120,6 +120,10 @@ object UnsupportedOperationChecker {
          throwError("(map/flatMap)GroupsWithState is not supported after aggregation on a " +
            "streaming DataFrame/Dataset")

        case d: Deduplicate if collectStreamingAggregates(d).nonEmpty =>
          throwError("dropDuplicates is not supported after aggregation on a " +
            "streaming DataFrame/Dataset")

        case Join(left, right, joinType, _) =>

          joinType match {
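A hedged illustration of the check being added: deduplication before a streaming aggregation passes the checker, while deduplication after it is rejected when the query is started. `events` is an assumed streaming Dataset with `guid` and `eventTime` columns; the names are illustrative only.

```scala
import org.apache.spark.sql.functions._

// Supported: dropDuplicates before the streaming aggregation.
val ok = events
  .withWatermark("eventTime", "10 seconds")
  .dropDuplicates("guid")
  .groupBy(window(col("eventTime"), "5 seconds"))
  .count()

// Rejected by UnsupportedOperationChecker when the query is started:
// "dropDuplicates is not supported after aggregation on a streaming DataFrame/Dataset".
val rejected = events
  .groupBy(col("guid"))
  .count()
  .dropDuplicates("guid")
```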
@@ -56,7 +56,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
      ReplaceExpressions,
      ComputeCurrentTime,
      GetCurrentDatabase(sessionCatalog),
      RewriteDistinctAggregates) ::
      RewriteDistinctAggregates,
      ReplaceDeduplicateWithAggregate) ::
    //////////////////////////////////////////////////////////////////////////////////////////
    // Optimizer rules start here
    //////////////////////////////////////////////////////////////////////////////////////////

@@ -1142,6 +1143,24 @@ object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] {
  }
}

/**
 * Replaces logical [[Deduplicate]] operator with an [[Aggregate]] operator.
 */
object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case Deduplicate(keys, child, streaming) if !streaming =>
      val keyExprIds = keys.map(_.exprId)
      val aggCols = child.output.map { attr =>
        if (keyExprIds.contains(attr.exprId)) {
          attr
        } else {
          Alias(new First(attr).toAggregateExpression(), attr.name)(attr.exprId)
        }
      }
      Aggregate(keys, aggCols, child)
  }
}

/**
 * Replaces logical [[Intersect]] operator with a left-semi [[Join]] operator.
 * {{{
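In DataFrame terms, the rule means that batch `dropDuplicates` on a subset of columns still ends up planned as a grouping aggregate. A hedged sketch of that equivalence (assuming a `SparkSession` named `spark`; data values are illustrative):

```scala
import org.apache.spark.sql.functions.first
import spark.implicits._

val input = Seq(("a", 1), ("a", 2), ("b", 3)).toDF("a", "b")

// Planned as Deduplicate(a, ...) and then rewritten by ReplaceDeduplicateWithAggregate ...
val viaDropDuplicates = input.dropDuplicates("a")

// ... into roughly what this produces directly: group by the dedup key and keep a First()
// of every other column (which duplicate's value survives is not guaranteed).
val viaAggregate = input.groupBy($"a").agg(first($"b").as("b"))
```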
@@ -864,3 +864,12 @@ case object OneRowRelation extends LeafNode {
  override def output: Seq[Attribute] = Nil
  override def computeStats(conf: CatalystConf): Statistics = Statistics(sizeInBytes = 1)
}

/** A logical plan for `dropDuplicates`. */
case class Deduplicate(
    keys: Seq[Attribute],
    child: LogicalPlan,
    streaming: Boolean) extends UnaryNode {

  override def output: Seq[Attribute] = child.output
}
@@ -28,7 +28,8 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{MapGroupsWithState, _}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.{IntegerType, LongType}
import org.apache.spark.sql.types.{IntegerType, LongType, MetadataBuilder}
import org.apache.spark.unsafe.types.CalendarInterval

/** A dummy command for testing unsupported operations. */
case class DummyCommand() extends Command

@@ -36,6 +37,11 @@ case class DummyCommand() extends Command
class UnsupportedOperationsSuite extends SparkFunSuite {

  val attribute = AttributeReference("a", IntegerType, nullable = true)()
  val watermarkMetadata = new MetadataBuilder()
    .withMetadata(attribute.metadata)
    .putLong(EventTimeWatermark.delayKey, 1000L)
    .build()
  val attributeWithWatermark = attribute.withMetadata(watermarkMetadata)
  val batchRelation = LocalRelation(attribute)
  val streamRelation = new TestStreamingRelation(attribute)

@@ -98,6 +104,27 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
    outputMode = Update,
    expectedMsgs = Seq("multiple streaming aggregations"))

  assertSupportedInStreamingPlan(
    "aggregate - streaming aggregations in update mode",
    Aggregate(Nil, aggExprs("d"), streamRelation),
    outputMode = Update)

  assertSupportedInStreamingPlan(
    "aggregate - streaming aggregations in complete mode",
    Aggregate(Nil, aggExprs("d"), streamRelation),
    outputMode = Complete)

  assertSupportedInStreamingPlan(
    "aggregate - streaming aggregations with watermark in append mode",
    Aggregate(Seq(attributeWithWatermark), aggExprs("d"), streamRelation),
    outputMode = Append)

  assertNotSupportedInStreamingPlan(
    "aggregate - streaming aggregations without watermark in append mode",
    Aggregate(Nil, aggExprs("d"), streamRelation),
    outputMode = Append,
    expectedMsgs = Seq("streaming aggregations", "without watermark"))

  // Aggregation: Distinct aggregates not supported on streaming relation
  val distinctAggExprs = Seq(Count("*").toAggregateExpression(isDistinct = true).as("c"))
  assertSupportedInStreamingPlan(

@@ -129,6 +156,33 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
    outputMode = Complete,
    expectedMsgs = Seq("(map/flatMap)GroupsWithState"))

  assertSupportedInStreamingPlan(
    "mapGroupsWithState - mapGroupsWithState on batch relation inside streaming relation",
    MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), batchRelation),
    outputMode = Append
  )

  // Deduplicate
  assertSupportedInStreamingPlan(
    "Deduplicate - Deduplicate on streaming relation before aggregation",
    Aggregate(
      Seq(attributeWithWatermark),
      aggExprs("c"),
      Deduplicate(Seq(att), streamRelation, streaming = true)),
    outputMode = Append)

  assertNotSupportedInStreamingPlan(
    "Deduplicate - Deduplicate on streaming relation after aggregation",
    Deduplicate(Seq(att), Aggregate(Nil, aggExprs("c"), streamRelation), streaming = true),
    outputMode = Complete,
    expectedMsgs = Seq("dropDuplicates"))

  assertSupportedInStreamingPlan(
    "Deduplicate - Deduplicate on batch relation inside a streaming query",
    Deduplicate(Seq(att), batchRelation, streaming = false),
    outputMode = Append
  )

  // Inner joins: Stream-stream not supported
  testBinaryOperationInStreamingPlan(
    "inner join",
|
@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.optimizer
|
|||
|
||||
import org.apache.spark.sql.catalyst.dsl.expressions._
|
||||
import org.apache.spark.sql.catalyst.dsl.plans._
|
||||
import org.apache.spark.sql.catalyst.expressions.Alias
|
||||
import org.apache.spark.sql.catalyst.expressions.aggregate.First
|
||||
import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest}
|
||||
import org.apache.spark.sql.catalyst.plans.logical._
|
||||
import org.apache.spark.sql.catalyst.rules.RuleExecutor
|
||||
|
@ -30,7 +32,8 @@ class ReplaceOperatorSuite extends PlanTest {
|
|||
Batch("Replace Operators", FixedPoint(100),
|
||||
ReplaceDistinctWithAggregate,
|
||||
ReplaceExceptWithAntiJoin,
|
||||
ReplaceIntersectWithSemiJoin) :: Nil
|
||||
ReplaceIntersectWithSemiJoin,
|
||||
ReplaceDeduplicateWithAggregate) :: Nil
|
||||
}
|
||||
|
||||
test("replace Intersect with Left-semi Join") {
|
||||
|
@ -71,4 +74,32 @@ class ReplaceOperatorSuite extends PlanTest {
|
|||
|
||||
comparePlans(optimized, correctAnswer)
|
||||
}
|
||||
|
||||
test("replace batch Deduplicate with Aggregate") {
|
||||
val input = LocalRelation('a.int, 'b.int)
|
||||
val attrA = input.output(0)
|
||||
val attrB = input.output(1)
|
||||
val query = Deduplicate(Seq(attrA), input, streaming = false) // dropDuplicates("a")
|
||||
val optimized = Optimize.execute(query.analyze)
|
||||
|
||||
val correctAnswer =
|
||||
Aggregate(
|
||||
Seq(attrA),
|
||||
Seq(
|
||||
attrA,
|
||||
Alias(new First(attrB).toAggregateExpression(), attrB.name)(attrB.exprId)
|
||||
),
|
||||
input)
|
||||
|
||||
comparePlans(optimized, correctAnswer)
|
||||
}
|
||||
|
||||
test("don't replace streaming Deduplicate") {
|
||||
val input = LocalRelation('a.int, 'b.int)
|
||||
val attrA = input.output(0)
|
||||
val query = Deduplicate(Seq(attrA), input, streaming = true) // dropDuplicates("a")
|
||||
val optimized = Optimize.execute(query.analyze)
|
||||
|
||||
comparePlans(optimized, query)
|
||||
}
|
||||
}
|
||||
|
|
|
@@ -557,7 +557,8 @@ class Dataset[T] private[sql](
   * Spark will use this watermark for several purposes:
   *  - To know when a given time window aggregation can be finalized and thus can be emitted when
   *    using output modes that do not allow updates.
   *  - To minimize the amount of state that we need to keep for on-going aggregations.
   *  - To minimize the amount of state that we need to keep for on-going aggregations,
   *    `mapGroupsWithState` and `dropDuplicates` operators.
   *
   * The current watermark is computed by looking at the `MAX(eventTime)` seen across
   * all of the partitions in the query minus a user specified `delayThreshold`. Due to the cost

@@ -1981,6 +1982,12 @@ class Dataset[T] private[sql](
   * Returns a new Dataset that contains only the unique rows from this Dataset.
   * This is an alias for `distinct`.
   *
   * For a static batch [[Dataset]], it just drops duplicate rows. For a streaming [[Dataset]], it
   * will keep all data across triggers as intermediate state to drop duplicate rows. You can use
   * [[withWatermark]] to limit how late the duplicate data can be and the system will accordingly
   * limit the state. In addition, data older than the watermark will be dropped to avoid any
   * possibility of duplicates.
   *
   * @group typedrel
   * @since 2.0.0
   */

@@ -1990,13 +1997,19 @@ class Dataset[T] private[sql](
   * (Scala-specific) Returns a new Dataset with duplicate rows removed, considering only
   * the subset of columns.
   *
   * For a static batch [[Dataset]], it just drops duplicate rows. For a streaming [[Dataset]], it
   * will keep all data across triggers as intermediate state to drop duplicate rows. You can use
   * [[withWatermark]] to limit how late the duplicate data can be and the system will accordingly
   * limit the state. In addition, data older than the watermark will be dropped to avoid any
   * possibility of duplicates.
   *
   * @group typedrel
   * @since 2.0.0
   */
  def dropDuplicates(colNames: Seq[String]): Dataset[T] = withTypedPlan {
    val resolver = sparkSession.sessionState.analyzer.resolver
    val allColumns = queryExecution.analyzed.output
    val groupCols = colNames.flatMap { colName =>
    val groupCols = colNames.toSet.toSeq.flatMap { (colName: String) =>
      // It is possible that there is more than one column with the same name,
      // so we call filter instead of find.
      val cols = allColumns.filter(col => resolver(col.name, colName))

@@ -2006,21 +2019,19 @@ class Dataset[T] private[sql](
      }
      cols
    }
    val groupColExprIds = groupCols.map(_.exprId)
    val aggCols = logicalPlan.output.map { attr =>
      if (groupColExprIds.contains(attr.exprId)) {
        attr
      } else {
        Alias(new First(attr).toAggregateExpression(), attr.name)()
      }
    }
    Aggregate(groupCols, aggCols, logicalPlan)
    Deduplicate(groupCols, logicalPlan, isStreaming)
  }

  /**
   * Returns a new Dataset with duplicate rows removed, considering only
   * the subset of columns.
   *
   * For a static batch [[Dataset]], it just drops duplicate rows. For a streaming [[Dataset]], it
   * will keep all data across triggers as intermediate state to drop duplicate rows. You can use
   * [[withWatermark]] to limit how late the duplicate data can be and the system will accordingly
   * limit the state. In addition, data older than the watermark will be dropped to avoid any
   * possibility of duplicates.
   *
   * @group typedrel
   * @since 2.0.0
   */

@@ -2030,6 +2041,12 @@ class Dataset[T] private[sql](
   * Returns a new [[Dataset]] with duplicate rows removed, considering only
   * the subset of columns.
   *
   * For a static batch [[Dataset]], it just drops duplicate rows. For a streaming [[Dataset]], it
   * will keep all data across triggers as intermediate state to drop duplicate rows. You can use
   * [[withWatermark]] to limit how late the duplicate data can be and the system will accordingly
   * limit the state. In addition, data older than the watermark will be dropped to avoid any
   * possibility of duplicates.
   *
   * @group typedrel
   * @since 2.0.0
   */
@@ -22,9 +22,10 @@ import org.apache.spark.sql.{SaveMode, Strategy}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.First
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, EventTimeWatermark, LogicalPlan, MapGroupsWithState}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}

@@ -244,6 +245,18 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
    }
  }

  /**
   * Used to plan the streaming deduplicate operator.
   */
  object StreamingDeduplicationStrategy extends Strategy {
    override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
      case Deduplicate(keys, child, true) =>
        StreamingDeduplicateExec(keys, planLater(child)) :: Nil

      case _ => Nil
    }
  }

  /**
   * Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface.
   */
@@ -45,6 +45,7 @@ class IncrementalExecution(
      sparkSession.sessionState.planner.StatefulAggregationStrategy +:
      sparkSession.sessionState.planner.MapGroupsWithStateStrategy +:
      sparkSession.sessionState.planner.StreamingRelationStrategy +:
      sparkSession.sessionState.planner.StreamingDeduplicationStrategy +:
      sparkSession.sessionState.experimentalMethods.extraStrategies

  // Modified planner with stateful operations.

@@ -93,6 +94,15 @@ class IncrementalExecution(
            keys,
            Some(stateId),
            child) :: Nil))

      case StreamingDeduplicateExec(keys, child, None, None) =>
        val stateId =
          OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId)

        StreamingDeduplicateExec(
          keys,
          child,
          Some(stateId),
          Some(currentEventTimeWatermark))

      case MapGroupsWithStateExec(
          f, kDeser, vDeser, group, data, output, None, stateDeser, stateSer, child) =>
        val stateId =
@@ -25,12 +25,11 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjecti
import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalKeyedState}
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{DataType, NullType, StructType}
import org.apache.spark.util.CompletionIterator

@@ -68,6 +67,40 @@ trait StateStoreWriter extends StatefulOperator {
    "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows"))
}

/** An operator that supports watermark. */
trait WatermarkSupport extends SparkPlan {

  /** The keys that may have a watermark attribute. */
  def keyExpressions: Seq[Attribute]

  /** The watermark value. */
  def eventTimeWatermark: Option[Long]

  /** Generate a predicate that matches data older than the watermark */
  lazy val watermarkPredicate: Option[Predicate] = {
    val optionalWatermarkAttribute =
      keyExpressions.find(_.metadata.contains(EventTimeWatermark.delayKey))

    optionalWatermarkAttribute.map { watermarkAttribute =>
      // If we are evicting based on a window, use the end of the window. Otherwise just
      // use the attribute itself.
      val evictionExpression =
        if (watermarkAttribute.dataType.isInstanceOf[StructType]) {
          LessThanOrEqual(
            GetStructField(watermarkAttribute, 1),
            Literal(eventTimeWatermark.get * 1000))
        } else {
          LessThanOrEqual(
            watermarkAttribute,
            Literal(eventTimeWatermark.get * 1000))
        }

      logInfo(s"Filtering state store on: $evictionExpression")
      newPredicate(evictionExpression, keyExpressions)
    }
  }
}

/**
 * For each input tuple, the key is calculated and the value from the [[StateStore]] is added
 * to the stream (in addition to the input tuple) if present.
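A plain-Scala sketch (an assumption for illustration, not Spark code) of the comparison the predicate encodes: event-time values are stored in microseconds while the watermark is tracked in milliseconds, hence the `* 1000`; for window-typed keys the window end (field 1 of the struct) is compared instead of the attribute itself.

```scala
// True when a state key or input row sits at or below the current watermark and can be evicted.
def isExpired(eventTimeMicros: Long, watermarkMillis: Long): Boolean =
  eventTimeMicros <= watermarkMillis * 1000
```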
@@ -76,7 +109,7 @@ case class StateStoreRestoreExec(
    keyExpressions: Seq[Attribute],
    stateId: Option[OperatorStateId],
    child: SparkPlan)
  extends execution.UnaryExecNode with StateStoreReader {
  extends UnaryExecNode with StateStoreReader {

  override protected def doExecute(): RDD[InternalRow] = {
    val numOutputRows = longMetric("numOutputRows")

@@ -113,31 +146,7 @@ case class StateStoreSaveExec(
    outputMode: Option[OutputMode] = None,
    eventTimeWatermark: Option[Long] = None,
    child: SparkPlan)
  extends execution.UnaryExecNode with StateStoreWriter {

  /** Generate a predicate that matches data older than the watermark */
  private lazy val watermarkPredicate: Option[Predicate] = {
    val optionalWatermarkAttribute =
      keyExpressions.find(_.metadata.contains(EventTimeWatermark.delayKey))

    optionalWatermarkAttribute.map { watermarkAttribute =>
      // If we are evicting based on a window, use the end of the window. Otherwise just
      // use the attribute itself.
      val evictionExpression =
        if (watermarkAttribute.dataType.isInstanceOf[StructType]) {
          LessThanOrEqual(
            GetStructField(watermarkAttribute, 1),
            Literal(eventTimeWatermark.get * 1000))
        } else {
          LessThanOrEqual(
            watermarkAttribute,
            Literal(eventTimeWatermark.get * 1000))
        }

      logInfo(s"Filtering state store on: $evictionExpression")
      newPredicate(evictionExpression, keyExpressions)
    }
  }
  extends UnaryExecNode with StateStoreWriter with WatermarkSupport {

  override protected def doExecute(): RDD[InternalRow] = {
    metrics // force lazy init at driver

@@ -146,8 +155,8 @@ case class StateStoreSaveExec(

    child.execute().mapPartitionsWithStateStore(
      getStateId.checkpointLocation,
      operatorId = getStateId.operatorId,
      storeVersion = getStateId.batchId,
      getStateId.operatorId,
      getStateId.batchId,
      keyExpressions.toStructType,
      child.output.toStructType,
      sqlContext.sessionState,

@@ -262,8 +271,8 @@ case class MapGroupsWithStateExec(
  override protected def doExecute(): RDD[InternalRow] = {
    child.execute().mapPartitionsWithStateStore[InternalRow](
      getStateId.checkpointLocation,
      operatorId = getStateId.operatorId,
      storeVersion = getStateId.batchId,
      getStateId.operatorId,
      getStateId.batchId,
      groupingAttributes.toStructType,
      child.output.toStructType,
      sqlContext.sessionState,

@@ -321,3 +330,70 @@ case class MapGroupsWithStateExec(
      }
    }
  }


/** Physical operator for executing streaming Deduplicate. */
case class StreamingDeduplicateExec(
    keyExpressions: Seq[Attribute],
    child: SparkPlan,
    stateId: Option[OperatorStateId] = None,
    eventTimeWatermark: Option[Long] = None)
  extends UnaryExecNode with StateStoreWriter with WatermarkSupport {

  /** Distribute by grouping attributes */
  override def requiredChildDistribution: Seq[Distribution] =
    ClusteredDistribution(keyExpressions) :: Nil

  override protected def doExecute(): RDD[InternalRow] = {
    metrics // force lazy init at driver

    child.execute().mapPartitionsWithStateStore(
      getStateId.checkpointLocation,
      getStateId.operatorId,
      getStateId.batchId,
      keyExpressions.toStructType,
      child.output.toStructType,
      sqlContext.sessionState,
      Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) =>
      val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
      val numOutputRows = longMetric("numOutputRows")
      val numTotalStateRows = longMetric("numTotalStateRows")
      val numUpdatedStateRows = longMetric("numUpdatedStateRows")

      val baseIterator = watermarkPredicate match {
        case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row))
        case None => iter
      }

      val result = baseIterator.filter { r =>
        val row = r.asInstanceOf[UnsafeRow]
        val key = getKey(row)
        val value = store.get(key)
        if (value.isEmpty) {
          store.put(key.copy(), StreamingDeduplicateExec.EMPTY_ROW)
          numUpdatedStateRows += 1
          numOutputRows += 1
          true
        } else {
          // Drop duplicated rows
          false
        }
      }

      CompletionIterator[InternalRow, Iterator[InternalRow]](result, {
        watermarkPredicate.foreach(f => store.remove(f.eval _))
        store.commit()
        numTotalStateRows += store.numKeys()
      })
    }
  }

  override def output: Seq[Attribute] = child.output

  override def outputPartitioning: Partitioning = child.outputPartitioning
}

object StreamingDeduplicateExec {
  private val EMPTY_ROW =
    UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null))
}
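A simplified, non-Spark sketch (an assumption for illustration only) of the per-partition logic above: emit a row only the first time its key is seen, record the key as state, and evict expired keys once the batch has been consumed, mirroring the `CompletionIterator` cleanup.

```scala
import scala.collection.mutable

def dedupBatch[K, R](
    rows: Seq[R],
    key: R => K,
    seen: mutable.Set[K],          // stands in for the StateStore contents
    expired: K => Boolean): Seq[R] = {
  val out = rows.filter { r =>
    val k = key(r)
    if (seen.contains(k)) {
      false                        // duplicate within or across batches: drop
    } else {
      seen += k                    // first occurrence: record it and emit the row
      true
    }
  }
  seen.retain(k => !expired(k))    // watermark-based eviction after the batch
  out
}
```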
@@ -0,0 +1,252 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.streaming

import org.scalatest.BeforeAndAfterAll

import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.execution.streaming.state.StateStore
import org.apache.spark.sql.functions._

class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll {

  import testImplicits._

  override def afterAll(): Unit = {
    super.afterAll()
    StateStore.stop()
  }

  test("deduplicate with all columns") {
    val inputData = MemoryStream[String]
    val result = inputData.toDS().dropDuplicates()

    testStream(result, Append)(
      AddData(inputData, "a"),
      CheckLastBatch("a"),
      assertNumStateRows(total = 1, updated = 1),
      AddData(inputData, "a"),
      CheckLastBatch(),
      assertNumStateRows(total = 1, updated = 0),
      AddData(inputData, "b"),
      CheckLastBatch("b"),
      assertNumStateRows(total = 2, updated = 1)
    )
  }

  test("deduplicate with some columns") {
    val inputData = MemoryStream[(String, Int)]
    val result = inputData.toDS().dropDuplicates("_1")

    testStream(result, Append)(
      AddData(inputData, "a" -> 1),
      CheckLastBatch("a" -> 1),
      assertNumStateRows(total = 1, updated = 1),
      AddData(inputData, "a" -> 2), // Dropped
      CheckLastBatch(),
      assertNumStateRows(total = 1, updated = 0),
      AddData(inputData, "b" -> 1),
      CheckLastBatch("b" -> 1),
      assertNumStateRows(total = 2, updated = 1)
    )
  }

  test("multiple deduplicates") {
    val inputData = MemoryStream[(String, Int)]
    val result = inputData.toDS().dropDuplicates().dropDuplicates("_1")

    testStream(result, Append)(
      AddData(inputData, "a" -> 1),
      CheckLastBatch("a" -> 1),
      assertNumStateRows(total = Seq(1L, 1L), updated = Seq(1L, 1L)),

      AddData(inputData, "a" -> 2), // Dropped from the second `dropDuplicates`
      CheckLastBatch(),
      assertNumStateRows(total = Seq(1L, 2L), updated = Seq(0L, 1L)),

      AddData(inputData, "b" -> 1),
      CheckLastBatch("b" -> 1),
      assertNumStateRows(total = Seq(2L, 3L), updated = Seq(1L, 1L))
    )
  }

  test("deduplicate with watermark") {
    val inputData = MemoryStream[Int]
    val result = inputData.toDS()
      .withColumn("eventTime", $"value".cast("timestamp"))
      .withWatermark("eventTime", "10 seconds")
      .dropDuplicates()
      .select($"eventTime".cast("long").as[Long])

    testStream(result, Append)(
      AddData(inputData, (1 to 5).flatMap(_ => (10 to 15)): _*),
      CheckLastBatch(10 to 15: _*),
      assertNumStateRows(total = 6, updated = 6),

      AddData(inputData, 25), // Advance watermark to 15 seconds
      CheckLastBatch(25),
      assertNumStateRows(total = 7, updated = 1),

      AddData(inputData, 25), // Drop states less than watermark
      CheckLastBatch(),
      assertNumStateRows(total = 1, updated = 0),

      AddData(inputData, 10), // Should not emit anything as data less than watermark
      CheckLastBatch(),
      assertNumStateRows(total = 1, updated = 0),

      AddData(inputData, 45), // Advance watermark to 35 seconds
      CheckLastBatch(45),
      assertNumStateRows(total = 2, updated = 1),

      AddData(inputData, 45), // Drop states less than watermark
      CheckLastBatch(),
      assertNumStateRows(total = 1, updated = 0)
    )
  }

  test("deduplicate with aggregate - append mode") {
    val inputData = MemoryStream[Int]
    val windowedaggregate = inputData.toDS()
      .withColumn("eventTime", $"value".cast("timestamp"))
      .withWatermark("eventTime", "10 seconds")
      .dropDuplicates()
      .withWatermark("eventTime", "10 seconds")
      .groupBy(window($"eventTime", "5 seconds") as 'window)
      .agg(count("*") as 'count)
      .select($"window".getField("start").cast("long").as[Long], $"count".as[Long])

    testStream(windowedaggregate)(
      AddData(inputData, (1 to 5).flatMap(_ => (10 to 15)): _*),
      CheckLastBatch(),
      // states in aggregate in [10, 14), [15, 20) (2 windows)
      // states in deduplicate is 10 to 15
      assertNumStateRows(total = Seq(2L, 6L), updated = Seq(2L, 6L)),

      AddData(inputData, 25), // Advance watermark to 15 seconds
      CheckLastBatch(),
      // states in aggregate in [10, 14), [15, 20) and [25, 30) (3 windows)
      // states in deduplicate is 10 to 15 and 25
      assertNumStateRows(total = Seq(3L, 7L), updated = Seq(1L, 1L)),

      AddData(inputData, 25), // Emit items less than watermark and drop their state
      CheckLastBatch((10 -> 5)), // 5 items (10 to 14) after deduplicate
      // states in aggregate in [15, 20) and [25, 30) (2 windows, note aggregate uses the end of
      // window to evict items, so [15, 20) is still in the state store)
      // states in deduplicate is 25
      assertNumStateRows(total = Seq(2L, 1L), updated = Seq(0L, 0L)),

      AddData(inputData, 10), // Should not emit anything as data less than watermark
      CheckLastBatch(),
      assertNumStateRows(total = Seq(2L, 1L), updated = Seq(0L, 0L)),

      AddData(inputData, 40), // Advance watermark to 30 seconds
      CheckLastBatch(),
      // states in aggregate in [15, 20), [25, 30) and [40, 45)
      // states in deduplicate is 25 and 40,
      assertNumStateRows(total = Seq(3L, 2L), updated = Seq(1L, 1L)),

      AddData(inputData, 40), // Emit items less than watermark and drop their state
      CheckLastBatch((15 -> 1), (25 -> 1)),
      // states in aggregate in [40, 45)
      // states in deduplicate is 40,
      assertNumStateRows(total = Seq(1L, 1L), updated = Seq(0L, 0L))
    )
  }

  test("deduplicate with aggregate - update mode") {
    val inputData = MemoryStream[(String, Int)]
    val result = inputData.toDS()
      .select($"_1" as "str", $"_2" as "num")
      .dropDuplicates()
      .groupBy("str")
      .agg(sum("num"))
      .as[(String, Long)]

    testStream(result, Update)(
      AddData(inputData, "a" -> 1),
      CheckLastBatch("a" -> 1L),
      assertNumStateRows(total = Seq(1L, 1L), updated = Seq(1L, 1L)),
      AddData(inputData, "a" -> 1), // Dropped
      CheckLastBatch(),
      assertNumStateRows(total = Seq(1L, 1L), updated = Seq(0L, 0L)),
      AddData(inputData, "a" -> 2),
      CheckLastBatch("a" -> 3L),
      assertNumStateRows(total = Seq(1L, 2L), updated = Seq(1L, 1L)),
      AddData(inputData, "b" -> 1),
      CheckLastBatch("b" -> 1L),
      assertNumStateRows(total = Seq(2L, 3L), updated = Seq(1L, 1L))
    )
  }

  test("deduplicate with aggregate - complete mode") {
    val inputData = MemoryStream[(String, Int)]
    val result = inputData.toDS()
      .select($"_1" as "str", $"_2" as "num")
      .dropDuplicates()
      .groupBy("str")
      .agg(sum("num"))
      .as[(String, Long)]

    testStream(result, Complete)(
      AddData(inputData, "a" -> 1),
      CheckLastBatch("a" -> 1L),
      assertNumStateRows(total = Seq(1L, 1L), updated = Seq(1L, 1L)),
      AddData(inputData, "a" -> 1), // Dropped
      CheckLastBatch("a" -> 1L),
      assertNumStateRows(total = Seq(1L, 1L), updated = Seq(0L, 0L)),
      AddData(inputData, "a" -> 2),
      CheckLastBatch("a" -> 3L),
      assertNumStateRows(total = Seq(1L, 2L), updated = Seq(1L, 1L)),
      AddData(inputData, "b" -> 1),
      CheckLastBatch("a" -> 3L, "b" -> 1L),
      assertNumStateRows(total = Seq(2L, 3L), updated = Seq(1L, 1L))
    )
  }

  test("deduplicate with file sink") {
    withTempDir { output =>
      withTempDir { checkpointDir =>
        val outputPath = output.getAbsolutePath
        val inputData = MemoryStream[String]
        val result = inputData.toDS().dropDuplicates()
        val q = result.writeStream
          .format("parquet")
          .outputMode(Append)
          .option("checkpointLocation", checkpointDir.getPath)
          .start(outputPath)
        try {
          inputData.addData("a")
          q.processAllAvailable()
          checkDataset(spark.read.parquet(outputPath).as[String], "a")

          inputData.addData("a") // Dropped
          q.processAllAvailable()
          checkDataset(spark.read.parquet(outputPath).as[String], "a")

          inputData.addData("b")
          q.processAllAvailable()
          checkDataset(spark.read.parquet(outputPath).as[String], "a", "b")
        } finally {
          q.stop()
        }
      }
    }
  }
}
@@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.streaming.state.StateStore
/** Class to check custom state types */
case class RunningCount(count: Long)

class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll {
class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterAll {

  import testImplicits._

@@ -321,13 +321,6 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll {
      CheckLastBatch(("a", 3L)) // task should not fail, and should show correct count
    )
  }

  private def assertNumStateRows(total: Long, updated: Long): AssertOnQuery = AssertOnQuery { q =>
    val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get
    assert(progressWithData.stateOperators(0).numRowsTotal === total, "incorrect total rows")
    assert(progressWithData.stateOperators(0).numRowsUpdated === updated, "incorrect updates rows")
    true
  }
}

object MapGroupsWithStateSuite {
@@ -0,0 +1,36 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.streaming

trait StateStoreMetricsTest extends StreamTest {

  def assertNumStateRows(total: Seq[Long], updated: Seq[Long]): AssertOnQuery =
    AssertOnQuery { q =>
      val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get
      assert(
        progressWithData.stateOperators.map(_.numRowsTotal) === total,
        "incorrect total rows")
      assert(
        progressWithData.stateOperators.map(_.numRowsUpdated) === updated,
        "incorrect updates rows")
      true
    }

  def assertNumStateRows(total: Long, updated: Long): AssertOnQuery =
    assertNumStateRows(Seq(total), Seq(updated))
}
@@ -338,7 +338,7 @@ class StreamSuite extends StreamTest {
      .writeStream
      .format("memory")
      .queryName("testquery")
      .outputMode("complete")
      .outputMode("append")
      .start()
    try {
      query.processAllAvailable()
@@ -35,7 +35,7 @@ object FailureSinglton {
  var firstTime = true
}

class StreamingAggregationSuite extends StreamTest with BeforeAndAfterAll {
class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfterAll {

  override def afterAll(): Unit = {
    super.afterAll()