[SPARK-23093][SS] Don't change run id when reconfiguring a continuous processing query.

## What changes were proposed in this pull request?

Keep the run ID static, using a different ID for the epoch coordinator to avoid cross-execution message contamination.

## How was this patch tested?

new and existing unit tests

Author: Jose Torres <jose@databricks.com>

Closes #20282 from jose-torres/fix-runid.
This commit is contained in:
Jose Torres 2018-01-17 13:58:44 -08:00 committed by Shixiong Zhu
parent 86a8450318
commit e946c63dd5
8 changed files with 54 additions and 21 deletions

View file

@ -58,7 +58,8 @@ case class DataSourceV2ScanExec(
case _: ContinuousReader =>
EpochCoordinatorRef.get(
sparkContext.getLocalProperty(ContinuousExecution.RUN_ID_KEY), sparkContext.env)
sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
sparkContext.env)
.askSync[Unit](SetReaderPartitions(readTasks.size()))
new ContinuousDataSourceRDD(sparkContext, sqlContext, readTasks)
.asInstanceOf[RDD[InternalRow]]

View file

@ -64,7 +64,8 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan)
val runTask = writer match {
case w: ContinuousWriter =>
EpochCoordinatorRef.get(
sparkContext.getLocalProperty(ContinuousExecution.RUN_ID_KEY), sparkContext.env)
sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
sparkContext.env)
.askSync[Unit](SetWriterPartitions(rdd.getNumPartitions))
(context: TaskContext, iter: Iterator[InternalRow]) =>
@ -135,7 +136,7 @@ object DataWritingSparkTask extends Logging {
iter: Iterator[InternalRow]): WriterCommitMessage = {
val dataWriter = writeTask.createDataWriter(context.partitionId(), context.attemptNumber())
val epochCoordinator = EpochCoordinatorRef.get(
context.getLocalProperty(ContinuousExecution.RUN_ID_KEY),
context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
SparkEnv.get)
val currentMsg: WriterCommitMessage = null
var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong

View file

@ -142,8 +142,7 @@ abstract class StreamExecution(
override val id: UUID = UUID.fromString(streamMetadata.id)
override def runId: UUID = currentRunId
protected var currentRunId = UUID.randomUUID
override val runId: UUID = UUID.randomUUID
/**
* Pretty identified string of printing in logs. Format is

View file

@ -59,7 +59,7 @@ class ContinuousDataSourceRDD(
val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readTask.createDataReader()
val runId = context.getLocalProperty(ContinuousExecution.RUN_ID_KEY)
val coordinatorId = context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)
// This queue contains two types of messages:
// * (null, null) representing an epoch boundary.
@ -68,7 +68,7 @@ class ContinuousDataSourceRDD(
val epochPollFailed = new AtomicBoolean(false)
val epochPollExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor(
s"epoch-poll--${runId}--${context.partitionId()}")
s"epoch-poll--$coordinatorId--${context.partitionId()}")
val epochPollRunnable = new EpochPollRunnable(queue, context, epochPollFailed)
epochPollExecutor.scheduleWithFixedDelay(
epochPollRunnable, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS)
@ -86,7 +86,7 @@ class ContinuousDataSourceRDD(
epochPollExecutor.shutdown()
})
val epochEndpoint = EpochCoordinatorRef.get(runId, SparkEnv.get)
val epochEndpoint = EpochCoordinatorRef.get(coordinatorId, SparkEnv.get)
new Iterator[UnsafeRow] {
private val POLL_TIMEOUT_MS = 1000
@ -150,7 +150,7 @@ class EpochPollRunnable(
private[continuous] var failureReason: Throwable = _
private val epochEndpoint = EpochCoordinatorRef.get(
context.getLocalProperty(ContinuousExecution.RUN_ID_KEY), SparkEnv.get)
context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get)
private var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong
override def run(): Unit = {
@ -177,7 +177,7 @@ class DataReaderThread(
failedFlag: AtomicBoolean)
extends Thread(
s"continuous-reader--${context.partitionId()}--" +
s"${context.getLocalProperty(ContinuousExecution.RUN_ID_KEY)}") {
s"${context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)}") {
private[continuous] var failureReason: Throwable = _
override def run(): Unit = {

View file

@ -57,6 +57,9 @@ class ContinuousExecution(
@volatile protected var continuousSources: Seq[ContinuousReader] = _
override protected def sources: Seq[BaseStreamingSource] = continuousSources
// For use only in test harnesses.
private[sql] var currentEpochCoordinatorId: String = _
override lazy val logicalPlan: LogicalPlan = {
assert(queryExecutionThread eq Thread.currentThread,
"logicalPlan must be initialized in StreamExecutionThread " +
@ -149,7 +152,6 @@ class ContinuousExecution(
* @param sparkSessionForQuery Isolated [[SparkSession]] to run the continuous query with.
*/
private def runContinuous(sparkSessionForQuery: SparkSession): Unit = {
currentRunId = UUID.randomUUID
// A list of attributes that will need to be updated.
val replacements = new ArrayBuffer[(Attribute, Attribute)]
// Translate from continuous relation to the underlying data source.
@ -219,15 +221,19 @@ class ContinuousExecution(
lastExecution.executedPlan // Force the lazy generation of execution plan
}
sparkSession.sparkContext.setLocalProperty(
sparkSessionForQuery.sparkContext.setLocalProperty(
ContinuousExecution.START_EPOCH_KEY, currentBatchId.toString)
sparkSession.sparkContext.setLocalProperty(
ContinuousExecution.RUN_ID_KEY, runId.toString)
// Add another random ID on top of the run ID, to distinguish epoch coordinators across
// reconfigurations.
val epochCoordinatorId = s"$runId--${UUID.randomUUID}"
currentEpochCoordinatorId = epochCoordinatorId
sparkSessionForQuery.sparkContext.setLocalProperty(
ContinuousExecution.EPOCH_COORDINATOR_ID_KEY, epochCoordinatorId)
// Use the parent Spark session for the endpoint since it's where this query ID is registered.
val epochEndpoint =
EpochCoordinatorRef.create(
writer.get(), reader, this, currentBatchId, sparkSession, SparkEnv.get)
writer.get(), reader, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get)
val epochUpdateThread = new Thread(new Runnable {
override def run: Unit = {
try {
@ -359,5 +365,5 @@ class ContinuousExecution(
object ContinuousExecution {
val START_EPOCH_KEY = "__continuous_start_epoch"
val RUN_ID_KEY = "__run_id"
val EPOCH_COORDINATOR_ID_KEY = "__epoch_coordinator_id"
}

View file

@ -79,7 +79,7 @@ private[sql] case class ReportPartitionOffset(
/** Helper object used to create reference to [[EpochCoordinator]]. */
private[sql] object EpochCoordinatorRef extends Logging {
private def endpointName(runId: String) = s"EpochCoordinator-$runId"
private def endpointName(id: String) = s"EpochCoordinator-$id"
/**
* Create a reference to a new [[EpochCoordinator]].
@ -88,18 +88,19 @@ private[sql] object EpochCoordinatorRef extends Logging {
writer: ContinuousWriter,
reader: ContinuousReader,
query: ContinuousExecution,
epochCoordinatorId: String,
startEpoch: Long,
session: SparkSession,
env: SparkEnv): RpcEndpointRef = synchronized {
val coordinator = new EpochCoordinator(
writer, reader, query, startEpoch, session, env.rpcEnv)
val ref = env.rpcEnv.setupEndpoint(endpointName(query.runId.toString()), coordinator)
val ref = env.rpcEnv.setupEndpoint(endpointName(epochCoordinatorId), coordinator)
logInfo("Registered EpochCoordinator endpoint")
ref
}
def get(runId: String, env: SparkEnv): RpcEndpointRef = synchronized {
val rpcEndpointRef = RpcUtils.makeDriverRef(endpointName(runId), env.conf, env.rpcEnv)
def get(id: String, env: SparkEnv): RpcEndpointRef = synchronized {
val rpcEndpointRef = RpcUtils.makeDriverRef(endpointName(id), env.conf, env.rpcEnv)
logDebug("Retrieved existing EpochCoordinator endpoint")
rpcEndpointRef
}

View file

@ -263,7 +263,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
def apply(): AssertOnQuery =
Execute {
case s: ContinuousExecution =>
val newEpoch = EpochCoordinatorRef.get(s.runId.toString, SparkEnv.get)
val newEpoch = EpochCoordinatorRef.get(s.currentEpochCoordinatorId, SparkEnv.get)
.askSync[Long](IncrementAndGetEpoch)
s.awaitEpoch(newEpoch - 1)
case _ => throw new IllegalStateException("microbatch cannot increment epoch")

View file

@ -174,6 +174,31 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
}
}
test("continuous processing listeners should receive QueryTerminatedEvent") {
val df = spark.readStream.format("rate").load()
val listeners = (1 to 5).map(_ => new EventCollector)
try {
listeners.foreach(listener => spark.streams.addListener(listener))
testStream(df, OutputMode.Append, useV2Sink = true)(
StartStream(Trigger.Continuous(1000)),
StopStream,
AssertOnQuery { query =>
eventually(Timeout(streamingTimeout)) {
listeners.foreach(listener => assert(listener.terminationEvent !== null))
listeners.foreach(listener => assert(listener.terminationEvent.id === query.id))
listeners.foreach(listener => assert(listener.terminationEvent.runId === query.runId))
listeners.foreach(listener => assert(listener.terminationEvent.exception === None))
}
listeners.foreach(listener => listener.checkAsyncErrors())
listeners.foreach(listener => listener.reset())
true
}
)
} finally {
listeners.foreach(spark.streams.removeListener)
}
}
test("adding and removing listener") {
def isListenerActive(listener: EventCollector): Boolean = {
listener.reset()