[SPARK-19617][SS] Fix the race condition when starting and stopping a query quickly

## What changes were proposed in this pull request?

The streaming thread in StreamExecution uses the following checks to decide whether it should exit:
- Catching an `InterruptedException`.
- `StreamExecution.state` being `TERMINATED`.

When a query is started and stopped quickly, both checks can fail:
- [HADOOP-14084](https://issues.apache.org/jira/browse/HADOOP-14084) is hit and the `InterruptedException` is swallowed.
- `StreamExecution.stop` is called before `state` becomes `ACTIVE`, and then [runBatches](dcc2d540a5/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala (L252)) changes the state from `TERMINATED` back to `ACTIVE`.

If both cases happen, the query will hang forever.

This PR changes `state` to an `AtomicReference` and uses `compareAndSet` to ensure the state can only move from `INITIALIZING` to `ACTIVE`. It also removes the `runUninterruptibly` hack from `HDFSMetadataLog`, because HADOOP-14084 won't cause any problem once this race condition is fixed.
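For illustration, here is a minimal, self-contained sketch of the pattern (a hypothetical standalone object; the names mirror `StreamExecution` but this is not the actual code):

```scala
import java.util.concurrent.atomic.AtomicReference

// Sketch of the fixed state machine, assuming a simplified State hierarchy.
object StateTransitionSketch {
  sealed trait State
  case object INITIALIZING extends State
  case object ACTIVE extends State
  case object TERMINATED extends State

  private val state = new AtomicReference[State](INITIALIZING)

  // Runs on the streaming thread. The CAS succeeds only if stop() has not
  // already set the state to TERMINATED, so a stopped query can never be
  // flipped back to ACTIVE.
  def runBatches(): Unit = {
    if (state.compareAndSet(INITIALIZING, ACTIVE)) {
      while (state.get != TERMINATED) {
        // ... run one trigger ...
      }
    }
    // Otherwise stop() won the race; fall through and let cleanup run.
  }

  // Runs on the user thread. Setting TERMINATED before interrupting the
  // streaming thread guarantees the CAS above fails if stop() came first.
  def stop(): Unit = state.set(TERMINATED)
}
```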

## How was this patch tested?

Jenkins

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #16947 from zsxwing/SPARK-19617.
Shixiong Zhu 2017-02-17 19:04:45 -08:00
parent 988f6d7ee8
commit 15b144d2bf
8 changed files with 62 additions and 105 deletions


@@ -55,7 +55,7 @@ class KafkaSourceOffsetSuite extends OffsetSuite with SharedSQLContext {
   }
 
-  testWithUninterruptibleThread("OffsetSeqLog serialization - deserialization") {
+  test("OffsetSeqLog serialization - deserialization") {
     withTempDir { temp =>
       // use non-existent directory to test whether log make the dir
       val dir = new File(temp, "dir")


@@ -32,7 +32,6 @@ import org.json4s.jackson.Serialization
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.util.UninterruptibleThread
 
 /**
@@ -109,39 +108,12 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path:
   override def add(batchId: Long, metadata: T): Boolean = {
     get(batchId).map(_ => false).getOrElse {
       // Only write metadata when the batch has not yet been written
-      if (fileManager.isLocalFileSystem) {
-        Thread.currentThread match {
-          case ut: UninterruptibleThread =>
-            // When using a local file system, "writeBatch" must be called on a
-            // [[org.apache.spark.util.UninterruptibleThread]] so that interrupts can be disabled
-            // while writing the batch file.
-            //
-            // This is because Hadoop "Shell.runCommand" swallows InterruptException (HADOOP-14084).
-            // If the user tries to stop a query, and the thread running "Shell.runCommand" is
-            // interrupted, then InterruptException will be dropped and the query will be still
-            // running. (Note: `writeBatch` creates a file using HDFS APIs and will call
-            // "Shell.runCommand" to set the file permission if using the local file system)
-            //
-            // Hence, we make sure that "writeBatch" is called on [[UninterruptibleThread]] which
-            // allows us to disable interrupts here, in order to propagate the interrupt state
-            // correctly. Also see SPARK-19599.
-            ut.runUninterruptibly { writeBatch(batchId, metadata) }
-          case _ =>
-            throw new IllegalStateException(
-              "HDFSMetadataLog.add() on a local file system must be executed on " +
-                "a o.a.spark.util.UninterruptibleThread")
-        }
-      } else {
-        // For a distributed file system, such as HDFS or S3, if the network is broken, write
-        // operations may just hang until timeout. We should enable interrupts to allow stopping
-        // the query fast.
-        writeBatch(batchId, metadata)
-      }
+      writeBatch(batchId, metadata)
       true
     }
   }
 
-  def writeTempBatch(metadata: T): Option[Path] = {
+  private def writeTempBatch(metadata: T): Option[Path] = {
     while (true) {
       val tempPath = new Path(metadataPath, s".${UUID.randomUUID.toString}.tmp")
       try {
@@ -327,9 +299,6 @@ object HDFSMetadataLog {
 
     /** Recursively delete a path if it exists. Should not throw exception if file doesn't exist. */
     def delete(path: Path): Unit
-
-    /** Whether the file systme is a local FS. */
-    def isLocalFileSystem: Boolean
   }
 
   /**
@@ -374,13 +343,6 @@ object HDFSMetadataLog {
         // ignore if file has already been deleted
       }
     }
-
-    override def isLocalFileSystem: Boolean = fc.getDefaultFileSystem match {
-      case _: local.LocalFs | _: local.RawLocalFs =>
-        // LocalFs = RawLocalFs + ChecksumFs
-        true
-      case _ => false
-    }
   }
 
   /**
@@ -437,12 +399,5 @@ object HDFSMetadataLog {
         // ignore if file has already been deleted
       }
     }
-
-    override def isLocalFileSystem: Boolean = fs match {
-      case _: LocalFileSystem | _: RawLocalFileSystem =>
-        // LocalFileSystem = RawLocalFileSystem + ChecksumFileSystem
-        true
-      case _ => false
-    }
   }
 }


@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming
 
 import java.util.UUID
 import java.util.concurrent.{CountDownLatch, TimeUnit}
+import java.util.concurrent.atomic.AtomicReference
 import java.util.concurrent.locks.ReentrantLock
 
 import scala.collection.mutable.ArrayBuffer
@@ -157,8 +158,7 @@ class StreamExecution(
   }
 
   /** Defines the internal state of execution */
-  @volatile
-  private var state: State = INITIALIZING
+  private val state = new AtomicReference[State](INITIALIZING)
 
   @volatile
   var lastExecution: IncrementalExecution = _
@@ -178,8 +178,8 @@
 
   /**
    * The thread that runs the micro-batches of this stream. Note that this thread must be
-   * [[org.apache.spark.util.UninterruptibleThread]] to avoid swallowing `InterruptException` when
-   * using [[HDFSMetadataLog]]. See SPARK-19599 for more details.
+   * [[org.apache.spark.util.UninterruptibleThread]] to workaround KAFKA-1894: interrupting a
+   * running `KafkaConsumer` may cause endless loop.
    */
   val microBatchThread =
     new StreamExecutionThread(s"stream execution thread for $prettyIdString") {
@@ -200,10 +200,10 @@
   val offsetLog = new OffsetSeqLog(sparkSession, checkpointFile("offsets"))
 
   /** Whether all fields of the query have been initialized */
-  private def isInitialized: Boolean = state != INITIALIZING
+  private def isInitialized: Boolean = state.get != INITIALIZING
 
   /** Whether the query is currently active or not */
-  override def isActive: Boolean = state != TERMINATED
+  override def isActive: Boolean = state.get != TERMINATED
 
   /** Returns the [[StreamingQueryException]] if the query was terminated by an exception. */
   override def exception: Option[StreamingQueryException] = Option(streamDeathCause)
@@ -249,53 +249,56 @@
       updateStatusMessage("Initializing sources")
       // force initialization of the logical plan so that the sources can be created
       logicalPlan
-      state = ACTIVE
-      // Unblock `awaitInitialization`
-      initializationLatch.countDown()
-
-      triggerExecutor.execute(() => {
-        startTrigger()
-
-        val isTerminated =
-          if (isActive) {
-            reportTimeTaken("triggerExecution") {
-              if (currentBatchId < 0) {
-                // We'll do this initialization only once
-                populateStartOffsets()
-                logDebug(s"Stream running from $committedOffsets to $availableOffsets")
-              } else {
-                constructNextBatch()
-              }
-              if (dataAvailable) {
-                currentStatus = currentStatus.copy(isDataAvailable = true)
-                updateStatusMessage("Processing new data")
-                runBatch()
-              }
-            }
-            // Report trigger as finished and construct progress object.
-            finishTrigger(dataAvailable)
-            if (dataAvailable) {
-              // We'll increase currentBatchId after we complete processing current batch's data
-              currentBatchId += 1
-            } else {
-              currentStatus = currentStatus.copy(isDataAvailable = false)
-              updateStatusMessage("Waiting for data to arrive")
-              Thread.sleep(pollingDelayMs)
-            }
-            true
-          } else {
-            false
-          }
-
-        // Update committed offsets.
-        committedOffsets ++= availableOffsets
-        updateStatusMessage("Waiting for next trigger")
-        isTerminated
-      })
-      updateStatusMessage("Stopped")
+      if (state.compareAndSet(INITIALIZING, ACTIVE)) {
+        // Unblock `awaitInitialization`
+        initializationLatch.countDown()
+
+        triggerExecutor.execute(() => {
+          startTrigger()
+
+          val continueToRun =
+            if (isActive) {
+              reportTimeTaken("triggerExecution") {
+                if (currentBatchId < 0) {
+                  // We'll do this initialization only once
+                  populateStartOffsets()
+                  logDebug(s"Stream running from $committedOffsets to $availableOffsets")
+                } else {
+                  constructNextBatch()
+                }
+                if (dataAvailable) {
+                  currentStatus = currentStatus.copy(isDataAvailable = true)
+                  updateStatusMessage("Processing new data")
+                  runBatch()
+                }
+              }
+              // Report trigger as finished and construct progress object.
+              finishTrigger(dataAvailable)
+              if (dataAvailable) {
+                // We'll increase currentBatchId after we complete processing current batch's data
+                currentBatchId += 1
+              } else {
+                currentStatus = currentStatus.copy(isDataAvailable = false)
+                updateStatusMessage("Waiting for data to arrive")
+                Thread.sleep(pollingDelayMs)
+              }
+              true
+            } else {
+              false
+            }
+
+          // Update committed offsets.
+          committedOffsets ++= availableOffsets
+          updateStatusMessage("Waiting for next trigger")
+          continueToRun
+        })
+        updateStatusMessage("Stopped")
+      } else {
+        // `stop()` is already called. Let `finally` finish the cleanup.
+      }
     } catch {
-      case _: InterruptedException if state == TERMINATED => // interrupted by stop()
+      case _: InterruptedException if state.get == TERMINATED => // interrupted by stop()
         updateStatusMessage("Stopped")
       case e: Throwable =>
         streamDeathCause = new StreamingQueryException(
@@ -318,7 +321,7 @@
       initializationLatch.countDown()
 
       try {
-        state = TERMINATED
+        state.set(TERMINATED)
         currentStatus = status.copy(isTriggerActive = false, isDataAvailable = false)
 
         // Update metrics and status
@@ -562,7 +565,7 @@
   override def stop(): Unit = {
     // Set the state to TERMINATED so that the batching thread knows that it was interrupted
     // intentionally
-    state = TERMINATED
+    state.set(TERMINATED)
     if (microBatchThread.isAlive) {
       microBatchThread.interrupt()
       microBatchThread.join()


@@ -156,7 +156,7 @@ class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext
     })
   }
 
-  testWithUninterruptibleThread("compact") {
+  test("compact") {
     withFakeCompactibleFileStreamLog(
       fileCleanupDelayMs = Long.MaxValue,
       defaultCompactInterval = 3,
@@ -174,7 +174,7 @@ class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext
     })
   }
 
-  testWithUninterruptibleThread("delete expired file") {
+  test("delete expired file") {
     // Set `fileCleanupDelayMs` to 0 so that we can detect the deleting behaviour deterministically
     withFakeCompactibleFileStreamLog(
       fileCleanupDelayMs = 0,


@@ -129,7 +129,7 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext {
     }
   }
 
-  testWithUninterruptibleThread("compact") {
+  test("compact") {
     withSQLConf(SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key -> "3") {
       withFileStreamSinkLog { sinkLog =>
         for (batchId <- 0 to 10) {
@@ -149,7 +149,7 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext {
     }
   }
 
-  testWithUninterruptibleThread("delete expired file") {
+  test("delete expired file") {
     // Set FILE_SINK_LOG_CLEANUP_DELAY to 0 so that we can detect the deleting behaviour
     // deterministically and one min batches to retain
     withSQLConf(


@@ -57,7 +57,7 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext {
     }
   }
 
-  testWithUninterruptibleThread("HDFSMetadataLog: basic") {
+  test("HDFSMetadataLog: basic") {
     withTempDir { temp =>
       val dir = new File(temp, "dir") // use non-existent directory to test whether log make the dir
       val metadataLog = new HDFSMetadataLog[String](spark, dir.getAbsolutePath)
@@ -82,8 +82,7 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext {
     }
   }
 
-  testWithUninterruptibleThread(
-    "HDFSMetadataLog: fallback from FileContext to FileSystem", quietly = true) {
+  testQuietly("HDFSMetadataLog: fallback from FileContext to FileSystem") {
     spark.conf.set(
       s"fs.$scheme.impl",
       classOf[FakeFileSystem].getName)
@@ -103,7 +102,7 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext {
     }
   }
 
-  testWithUninterruptibleThread("HDFSMetadataLog: purge") {
+  test("HDFSMetadataLog: purge") {
     withTempDir { temp =>
       val metadataLog = new HDFSMetadataLog[String](spark, temp.getAbsolutePath)
       assert(metadataLog.add(0, "batch0"))
@@ -128,7 +127,7 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext {
     }
   }
 
-  testWithUninterruptibleThread("HDFSMetadataLog: restart") {
+  test("HDFSMetadataLog: restart") {
     withTempDir { temp =>
       val metadataLog = new HDFSMetadataLog[String](spark, temp.getAbsolutePath)
       assert(metadataLog.add(0, "batch0"))


@@ -36,7 +36,7 @@ class OffsetSeqLogSuite extends SparkFunSuite with SharedSQLContext {
       OffsetSeqMetadata("""{"batchWatermarkMs":1,"batchTimestampMs":2}"""))
   }
 
-  testWithUninterruptibleThread("OffsetSeqLog - serialization - deserialization") {
+  test("OffsetSeqLog - serialization - deserialization") {
     withTempDir { temp =>
      val dir = new File(temp, "dir") // use non-existent directory to test whether log make the dir
      val metadataLog = new OffsetSeqLog(spark, dir.getAbsolutePath)


@@ -1174,7 +1174,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest {
     assert(map.isNewFile("b", 10))
   }
 
-  testWithUninterruptibleThread("do not recheck that files exist during getBatch") {
+  test("do not recheck that files exist during getBatch") {
    withTempDir { temp =>
      spark.conf.set(
        s"fs.$scheme.impl",