[SPARK-21565][SS] Propagate metadata in attribute replacement.
## What changes were proposed in this pull request?
Propagate metadata in attribute replacement during streaming execution. This is necessary for EventTimeWatermark operators that consume replaced attributes.
## How was this patch tested?
A new unit test, which was verified to fail before the fix.
Author: Jose Torres <joseph-torres@databricks.com>
Closes #18840 from joseph-torres/SPARK-21565.
This commit is contained in:
parent
4f7ec3a316
commit
cce25b360e
|
@ -628,7 +628,8 @@ class StreamExecution(
|
||||||
// Rewire the plan to use the new attributes that were returned by the source.
|
// Rewire the plan to use the new attributes that were returned by the source.
|
||||||
val replacementMap = AttributeMap(replacements)
|
val replacementMap = AttributeMap(replacements)
|
||||||
val triggerLogicalPlan = withNewSources transformAllExpressions {
|
val triggerLogicalPlan = withNewSources transformAllExpressions {
|
||||||
case a: Attribute if replacementMap.contains(a) => replacementMap(a)
|
case a: Attribute if replacementMap.contains(a) =>
|
||||||
|
replacementMap(a).withMetadata(a.metadata)
|
||||||
case ct: CurrentTimestamp =>
|
case ct: CurrentTimestamp =>
|
||||||
CurrentBatchTimestamp(offsetSeqMetadata.batchTimestampMs,
|
CurrentBatchTimestamp(offsetSeqMetadata.batchTimestampMs,
|
||||||
ct.dataType)
|
ct.dataType)
|
||||||
|
|
|
@ -391,6 +391,34 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche
|
||||||
checkDataset[Long](df, 1L to 100L: _*)
|
checkDataset[Long](df, 1L to 100L: _*)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Regression test for SPARK-21565: attribute metadata (here, the watermark delay
// attached to `eventTime`) must survive the attribute replacement that happens when
// the streaming execution rewires the logical plan to the source's new attributes.
// Before the fix, the EventTimeWatermark operator lost the replaced attribute's
// metadata and the query failed; processAllAvailable() succeeding is the assertion.
test("SPARK-21565: watermark operator accepts attributes from replacement") {
  withTempDir { dir =>
    // withTempDir creates the directory; delete it so df.write.json can create it fresh.
    dir.delete()

    // Seed the directory with one JSON record carrying an event-time column.
    val df = Seq(("a", 100.0, new java.sql.Timestamp(100L)))
      .toDF("symbol", "price", "eventTime")
    df.write.json(dir.getCanonicalPath)

    // Read the same data back as a stream, reusing the batch schema.
    val input = spark.readStream.schema(df.schema)
      .json(dir.getCanonicalPath)

    // Watermark on eventTime, then aggregate — the aggregation forces the
    // watermark attribute through the replacement path under test.
    val groupEvents = input
      .withWatermark("eventTime", "2 seconds")
      .groupBy("symbol", "eventTime")
      .agg(count("price") as 'count)
      .select("symbol", "eventTime", "count")
    val q = groupEvents.writeStream
      .outputMode("append")
      .format("console")
      .start()
    try {
      // Fails here (pre-fix) if the watermark metadata was dropped.
      q.processAllAvailable()
    } finally {
      q.stop()
    }
  }
}
|
||||||
|
|
||||||
private def assertNumStateRows(numTotalRows: Long): AssertOnQuery = AssertOnQuery { q =>
|
private def assertNumStateRows(numTotalRows: Long): AssertOnQuery = AssertOnQuery { q =>
|
||||||
val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get
|
val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get
|
||||||
assert(progressWithData.stateOperators(0).numRowsTotal === numTotalRows)
|
assert(progressWithData.stateOperators(0).numRowsTotal === numTotalRows)
|
||||||
|
|
Loading…
Reference in a new issue