[SPARK-23093][SS] Don't change run id when reconfiguring a continuous processing query.
## What changes were proposed in this pull request?

Keep the run ID static, using a different ID for the epoch coordinator to avoid cross-execution message contamination.

## How was this patch tested?

New and existing unit tests.

Author: Jose Torres <jose@databricks.com>

Closes #20282 from jose-torres/fix-runid.
parent 86a8450318
commit e946c63dd5
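The gist of the fix: a continuous query now keeps one `runId` for its whole lifetime, and each reconfiguration instead mints a fresh epoch coordinator ID scoped under that run ID, so RPC messages from a stale execution can never land on the new coordinator. A minimal sketch of that ID scheme (illustrative classes only, not the patch's own):

```scala
import java.util.UUID

// Hypothetical stand-in for a query execution; not Spark's StreamExecution.
class QueryRun {
  // Fixed once for the lifetime of the run (mirrors `override val runId`).
  val runId: UUID = UUID.randomUUID

  // Minted anew on every (re)configuration (mirrors s"$runId--${UUID.randomUUID}").
  def newEpochCoordinatorId(): String = s"$runId--${UUID.randomUUID}"
}

object IdSchemeDemo extends App {
  val run = new QueryRun
  // Same runId prefix, different coordinator suffix each time:
  println(run.newEpochCoordinatorId())
  println(run.newEpochCoordinatorId())
}
```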
```diff
@@ -58,7 +58,8 @@ case class DataSourceV2ScanExec(
 
       case _: ContinuousReader =>
         EpochCoordinatorRef.get(
-            sparkContext.getLocalProperty(ContinuousExecution.RUN_ID_KEY), sparkContext.env)
+            sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
+            sparkContext.env)
           .askSync[Unit](SetReaderPartitions(readTasks.size()))
         new ContinuousDataSourceRDD(sparkContext, sqlContext, readTasks)
           .asInstanceOf[RDD[InternalRow]]
```
```diff
@@ -64,7 +64,8 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan)
     val runTask = writer match {
       case w: ContinuousWriter =>
         EpochCoordinatorRef.get(
-            sparkContext.getLocalProperty(ContinuousExecution.RUN_ID_KEY), sparkContext.env)
+            sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
+            sparkContext.env)
           .askSync[Unit](SetWriterPartitions(rdd.getNumPartitions))
 
         (context: TaskContext, iter: Iterator[InternalRow]) =>
```
```diff
@@ -135,7 +136,7 @@ object DataWritingSparkTask extends Logging {
       iter: Iterator[InternalRow]): WriterCommitMessage = {
     val dataWriter = writeTask.createDataWriter(context.partitionId(), context.attemptNumber())
     val epochCoordinator = EpochCoordinatorRef.get(
-      context.getLocalProperty(ContinuousExecution.RUN_ID_KEY),
+      context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
       SparkEnv.get)
     val currentMsg: WriterCommitMessage = null
     var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong
```
```diff
@@ -142,8 +142,7 @@ abstract class StreamExecution(
 
   override val id: UUID = UUID.fromString(streamMetadata.id)
 
-  override def runId: UUID = currentRunId
-  protected var currentRunId = UUID.randomUUID
+  override val runId: UUID = UUID.randomUUID
 
   /**
    * Pretty identified string of printing in logs. Format is
```
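With this hunk `runId` goes back to being a `val`, fixed when the `StreamExecution` is constructed; the `currentRunId` var that reconfiguration used to reassign is gone. A small sketch of the behavioral difference, using hypothetical classes rather than Spark's:

```scala
import java.util.UUID

// Old shape: runId delegated to a var, so a reconfiguration changed what callers saw.
class ReassignableRunId {
  protected var currentRunId = UUID.randomUUID
  def runId: UUID = currentRunId
  def reconfigure(): Unit = currentRunId = UUID.randomUUID
}

// New shape: runId is assigned once and can never change afterwards.
class FixedRunId {
  val runId: UUID = UUID.randomUUID
}

object RunIdDemo extends App {
  val old = new ReassignableRunId
  val before = old.runId
  old.reconfigure()
  println(before == old.runId) // false: observers saw the run ID change mid-query

  val fixed = new FixedRunId
  println(fixed.runId)         // stable for the lifetime of the execution
}
```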
```diff
@@ -59,7 +59,7 @@ class ContinuousDataSourceRDD(
 
     val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readTask.createDataReader()
 
-    val runId = context.getLocalProperty(ContinuousExecution.RUN_ID_KEY)
+    val coordinatorId = context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)
 
     // This queue contains two types of messages:
     // * (null, null) representing an epoch boundary.
```
```diff
@@ -68,7 +68,7 @@ class ContinuousDataSourceRDD(
 
     val epochPollFailed = new AtomicBoolean(false)
     val epochPollExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor(
-      s"epoch-poll--${runId}--${context.partitionId()}")
+      s"epoch-poll--$coordinatorId--${context.partitionId()}")
     val epochPollRunnable = new EpochPollRunnable(queue, context, epochPollFailed)
     epochPollExecutor.scheduleWithFixedDelay(
       epochPollRunnable, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS)
```
```diff
@@ -86,7 +86,7 @@ class ContinuousDataSourceRDD(
       epochPollExecutor.shutdown()
     })
 
-    val epochEndpoint = EpochCoordinatorRef.get(runId, SparkEnv.get)
+    val epochEndpoint = EpochCoordinatorRef.get(coordinatorId, SparkEnv.get)
     new Iterator[UnsafeRow] {
       private val POLL_TIMEOUT_MS = 1000
 
```
```diff
@@ -150,7 +150,7 @@ class EpochPollRunnable(
   private[continuous] var failureReason: Throwable = _
 
   private val epochEndpoint = EpochCoordinatorRef.get(
-    context.getLocalProperty(ContinuousExecution.RUN_ID_KEY), SparkEnv.get)
+    context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get)
   private var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong
 
   override def run(): Unit = {
```
```diff
@@ -177,7 +177,7 @@ class DataReaderThread(
     failedFlag: AtomicBoolean)
   extends Thread(
     s"continuous-reader--${context.partitionId()}--" +
-      s"${context.getLocalProperty(ContinuousExecution.RUN_ID_KEY)}") {
+      s"${context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)}") {
   private[continuous] var failureReason: Throwable = _
 
   override def run(): Unit = {
```
```diff
@@ -57,6 +57,9 @@ class ContinuousExecution(
   @volatile protected var continuousSources: Seq[ContinuousReader] = _
   override protected def sources: Seq[BaseStreamingSource] = continuousSources
 
+  // For use only in test harnesses.
+  private[sql] var currentEpochCoordinatorId: String = _
+
   override lazy val logicalPlan: LogicalPlan = {
     assert(queryExecutionThread eq Thread.currentThread,
       "logicalPlan must be initialized in StreamExecutionThread " +
```
```diff
@@ -149,7 +152,6 @@ class ContinuousExecution(
    * @param sparkSessionForQuery Isolated [[SparkSession]] to run the continuous query with.
    */
   private def runContinuous(sparkSessionForQuery: SparkSession): Unit = {
-    currentRunId = UUID.randomUUID
     // A list of attributes that will need to be updated.
     val replacements = new ArrayBuffer[(Attribute, Attribute)]
     // Translate from continuous relation to the underlying data source.
```
```diff
@@ -219,15 +221,19 @@ class ContinuousExecution(
       lastExecution.executedPlan // Force the lazy generation of execution plan
     }
 
-    sparkSession.sparkContext.setLocalProperty(
+    sparkSessionForQuery.sparkContext.setLocalProperty(
       ContinuousExecution.START_EPOCH_KEY, currentBatchId.toString)
-    sparkSession.sparkContext.setLocalProperty(
-      ContinuousExecution.RUN_ID_KEY, runId.toString)
+    // Add another random ID on top of the run ID, to distinguish epoch coordinators across
+    // reconfigurations.
+    val epochCoordinatorId = s"$runId--${UUID.randomUUID}"
+    currentEpochCoordinatorId = epochCoordinatorId
+    sparkSessionForQuery.sparkContext.setLocalProperty(
+      ContinuousExecution.EPOCH_COORDINATOR_ID_KEY, epochCoordinatorId)
 
     // Use the parent Spark session for the endpoint since it's where this query ID is registered.
     val epochEndpoint =
      EpochCoordinatorRef.create(
-        writer.get(), reader, this, currentBatchId, sparkSession, SparkEnv.get)
+        writer.get(), reader, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get)
     val epochUpdateThread = new Thread(new Runnable {
       override def run: Unit = {
         try {
```
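The patch hands the coordinator ID to executors through Spark local properties: the driver sets the property before the epoch's job is launched, and every task reads it back from its `TaskContext`. A self-contained sketch of that mechanism (the key name below is made up for the demo; the patch uses the `__epoch_coordinator_id` key shown in the next hunk):

```scala
import org.apache.spark.{SparkConf, SparkContext, TaskContext}

object LocalPropertyDemo extends App {
  val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("demo"))
  try {
    // Driver side: set before the job is submitted; inherited by its tasks.
    sc.setLocalProperty("demo.coordinator.id", "run-1--reconfig-2")
    val seen = sc.parallelize(1 to 4, numSlices = 2)
      // Executor side: each task reads the property off its TaskContext.
      .map(_ => TaskContext.get().getLocalProperty("demo.coordinator.id"))
      .distinct()
      .collect()
    println(seen.mkString(", ")) // run-1--reconfig-2
  } finally {
    sc.stop()
  }
}
```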
```diff
@@ -359,5 +365,5 @@ class ContinuousExecution(
 
 object ContinuousExecution {
   val START_EPOCH_KEY = "__continuous_start_epoch"
-  val RUN_ID_KEY = "__run_id"
+  val EPOCH_COORDINATOR_ID_KEY = "__epoch_coordinator_id"
 }
```
```diff
@@ -79,7 +79,7 @@ private[sql] case class ReportPartitionOffset(
 
 /** Helper object used to create reference to [[EpochCoordinator]]. */
 private[sql] object EpochCoordinatorRef extends Logging {
-  private def endpointName(runId: String) = s"EpochCoordinator-$runId"
+  private def endpointName(id: String) = s"EpochCoordinator-$id"
 
   /**
    * Create a reference to a new [[EpochCoordinator]].
```
```diff
@@ -88,18 +88,19 @@ private[sql] object EpochCoordinatorRef extends Logging {
       writer: ContinuousWriter,
       reader: ContinuousReader,
       query: ContinuousExecution,
+      epochCoordinatorId: String,
       startEpoch: Long,
       session: SparkSession,
       env: SparkEnv): RpcEndpointRef = synchronized {
     val coordinator = new EpochCoordinator(
       writer, reader, query, startEpoch, session, env.rpcEnv)
-    val ref = env.rpcEnv.setupEndpoint(endpointName(query.runId.toString()), coordinator)
+    val ref = env.rpcEnv.setupEndpoint(endpointName(epochCoordinatorId), coordinator)
     logInfo("Registered EpochCoordinator endpoint")
     ref
   }
 
-  def get(runId: String, env: SparkEnv): RpcEndpointRef = synchronized {
-    val rpcEndpointRef = RpcUtils.makeDriverRef(endpointName(runId), env.conf, env.rpcEnv)
+  def get(id: String, env: SparkEnv): RpcEndpointRef = synchronized {
+    val rpcEndpointRef = RpcUtils.makeDriverRef(endpointName(id), env.conf, env.rpcEnv)
     logDebug("Retrieved existing EpochCoordinator endpoint")
     rpcEndpointRef
   }
```
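`EpochCoordinatorRef` resolves the coordinator purely by endpoint name, so the uniqueness of that name is what actually isolates executions: registering each coordinator under `EpochCoordinator-<epochCoordinatorId>` means a straggling task that still holds an old ID simply cannot address the new coordinator. A toy, name-keyed registry illustrating the point (invented types; Spark's real `RpcEnv` is far richer):

```scala
import java.util.concurrent.ConcurrentHashMap

// Toy stand-in for an RPC environment that routes messages purely by endpoint name.
object ToyRpcEnv {
  private val endpoints = new ConcurrentHashMap[String, String => Unit]()

  def setupEndpoint(name: String)(handler: String => Unit): Unit =
    require(endpoints.putIfAbsent(name, handler) == null, s"duplicate endpoint: $name")

  def send(name: String, msg: String): Unit =
    Option(endpoints.get(name)) match {
      case Some(handler) => handler(msg)
      case None          => println(s"dropped: no endpoint named $name")
    }
}

object IsolationDemo extends App {
  ToyRpcEnv.setupEndpoint("EpochCoordinator-run1--cfgA")(m => println(s"old coordinator got: $m"))
  ToyRpcEnv.setupEndpoint("EpochCoordinator-run1--cfgB")(m => println(s"new coordinator got: $m"))
  // A stale task still addressing cfgA cannot contaminate the cfgB coordinator.
  ToyRpcEnv.send("EpochCoordinator-run1--cfgA", "ReportPartitionOffset(stale)")
}
```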
```diff
@@ -263,7 +263,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
     def apply(): AssertOnQuery =
       Execute {
         case s: ContinuousExecution =>
-          val newEpoch = EpochCoordinatorRef.get(s.runId.toString, SparkEnv.get)
+          val newEpoch = EpochCoordinatorRef.get(s.currentEpochCoordinatorId, SparkEnv.get)
             .askSync[Long](IncrementAndGetEpoch)
           s.awaitEpoch(newEpoch - 1)
         case _ => throw new IllegalStateException("microbatch cannot increment epoch")
```
```diff
@@ -174,6 +174,31 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
     }
   }
 
+  test("continuous processing listeners should receive QueryTerminatedEvent") {
+    val df = spark.readStream.format("rate").load()
+    val listeners = (1 to 5).map(_ => new EventCollector)
+    try {
+      listeners.foreach(listener => spark.streams.addListener(listener))
+      testStream(df, OutputMode.Append, useV2Sink = true)(
+        StartStream(Trigger.Continuous(1000)),
+        StopStream,
+        AssertOnQuery { query =>
+          eventually(Timeout(streamingTimeout)) {
+            listeners.foreach(listener => assert(listener.terminationEvent !== null))
+            listeners.foreach(listener => assert(listener.terminationEvent.id === query.id))
+            listeners.foreach(listener => assert(listener.terminationEvent.runId === query.runId))
+            listeners.foreach(listener => assert(listener.terminationEvent.exception === None))
+          }
+          listeners.foreach(listener => listener.checkAsyncErrors())
+          listeners.foreach(listener => listener.reset())
+          true
+        }
+      )
+    } finally {
+      listeners.foreach(spark.streams.removeListener)
+    }
+  }
+
   test("adding and removing listener") {
     def isListenerActive(listener: EventCollector): Boolean = {
       listener.reset()
```