diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 10c42a7338..5ee596e06d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -263,6 +263,7 @@ class StreamExecution( try { sparkSession.sparkContext.setJobGroup(runId.toString, getBatchDescriptionString, interruptOnCancel = true) + /* Publish the query id (`id`, not `runId`) as a SparkContext local property so code run for this query can read it back with getLocalProperty. */ sparkSession.sparkContext.setLocalProperty(StreamExecution.QUERY_ID_KEY, id.toString) if (sparkSession.sessionState.conf.streamingMetricsEnabled) { sparkSession.sparkContext.env.metricsSystem.registerSource(streamMetrics) } @@ -842,6 +843,9 @@ class StreamExecution( } } +object StreamExecution { + /** Name of the SparkContext local property that StreamExecution sets to the query's `id.toString`. */ val QUERY_ID_KEY = "sql.streaming.queryId" +} /** * A special thread to run the stream query. Some codes require to run in the StreamExecutionThread diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 0925646beb..41f73b8529 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -613,6 +613,33 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } } + test("get the query id in source") { + /* Written inside getOffset and read by the assertions below; NOTE(review): presumably those run on different threads, hence @volatile - confirm. */ @volatile var queryId: String = null + val source = new Source { + override def stop(): Unit = {} + override def getOffset: Option[Offset] = { + /* Capture the local property exactly as it is visible from inside the source. */ queryId = spark.sparkContext.getLocalProperty(StreamExecution.QUERY_ID_KEY) + /* This source never reports an offset. */ None + } + override def getBatch(start: Option[Offset], end: Offset): DataFrame = spark.emptyDataFrame + override def schema: StructType = MockSourceProvider.fakeSchema + } + + MockSourceProvider.withMockSources(source) { 
val df = spark.readStream + .format("org.apache.spark.sql.streaming.util.MockSourceProvider") + .load() + testStream(df)( + AssertOnQuery { sq => + sq.processAllAvailable() + /* The property must carry the query id and must differ from the run id. */ assert(sq.id.toString === queryId) + assert(sq.runId.toString !== queryId) + true + } + ) + } + } + /** Create a streaming DF that only execute one batch in which it returns the given static DF */ private def createSingleTriggerStreamingDF(triggerDF: DataFrame): DataFrame = { require(!triggerDF.isStreaming)