[SPARK-21565][SS] Propagate metadata in attribute replacement.
## What changes were proposed in this pull request? Propagate metadata in attribute replacement during streaming execution. This is necessary for EventTimeWatermarks consuming replaced attributes. ## How was this patch tested? new unit test, which was verified to fail before the fix Author: Jose Torres <joseph-torres@databricks.com> Closes #18840 from joseph-torres/SPARK-21565.
This commit is contained in:
parent
4f7ec3a316
commit
cce25b360e
|
@ -628,7 +628,8 @@ class StreamExecution(
|
|||
// Rewire the plan to use the new attributes that were returned by the source.
|
||||
val replacementMap = AttributeMap(replacements)
|
||||
val triggerLogicalPlan = withNewSources transformAllExpressions {
|
||||
case a: Attribute if replacementMap.contains(a) => replacementMap(a)
|
||||
case a: Attribute if replacementMap.contains(a) =>
|
||||
replacementMap(a).withMetadata(a.metadata)
|
||||
case ct: CurrentTimestamp =>
|
||||
CurrentBatchTimestamp(offsetSeqMetadata.batchTimestampMs,
|
||||
ct.dataType)
|
||||
|
|
|
@ -391,6 +391,34 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche
|
|||
checkDataset[Long](df, 1L to 100L: _*)
|
||||
}
|
||||
|
||||
test("SPARK-21565: watermark operator accepts attributes from replacement") {
|
||||
withTempDir { dir =>
|
||||
dir.delete()
|
||||
|
||||
val df = Seq(("a", 100.0, new java.sql.Timestamp(100L)))
|
||||
.toDF("symbol", "price", "eventTime")
|
||||
df.write.json(dir.getCanonicalPath)
|
||||
|
||||
val input = spark.readStream.schema(df.schema)
|
||||
.json(dir.getCanonicalPath)
|
||||
|
||||
val groupEvents = input
|
||||
.withWatermark("eventTime", "2 seconds")
|
||||
.groupBy("symbol", "eventTime")
|
||||
.agg(count("price") as 'count)
|
||||
.select("symbol", "eventTime", "count")
|
||||
val q = groupEvents.writeStream
|
||||
.outputMode("append")
|
||||
.format("console")
|
||||
.start()
|
||||
try {
|
||||
q.processAllAvailable()
|
||||
} finally {
|
||||
q.stop()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private def assertNumStateRows(numTotalRows: Long): AssertOnQuery = AssertOnQuery { q =>
|
||||
val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get
|
||||
assert(progressWithData.stateOperators(0).numRowsTotal === numTotalRows)
|
||||
|
|
Loading…
Reference in a new issue