[SPARK-32387][SS] Extract UninterruptibleThread runner logic from KafkaOffsetReader

### What changes were proposed in this pull request?
`UninterruptibleThread` running functionality is baked into `KafkaOffsetReader` and can be extracted into its own class. The main intention is to simplify `KafkaOffsetReader` in order to make it easier to solve SPARK-32032. This PR performs the extraction without any functional change.
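
As a rough sketch of the shape of the change (the `OffsetReaderSketch` class and `fetchSomething` method below are hypothetical stand-ins; the real call sites are in the `KafkaOffsetReader` diff further down, and `UninterruptibleThreadRunner` is `private[spark]`, so this compiles only inside the `org.apache.spark` package tree):

```scala
import org.apache.spark.util.UninterruptibleThreadRunner

// Hypothetical stand-in for KafkaOffsetReader: before this PR the reader
// owned a private runUninterruptibly helper plus its own single-thread
// executor; after the extraction it only owns a runner instance.
class OffsetReaderSketch {
  private val uninterruptibleThreadRunner =
    new UninterruptibleThreadRunner("Kafka Offset Reader")

  // Fetch-style helpers delegate to the runner, exactly as before.
  def fetchSomething(): Long = uninterruptibleThreadRunner.runUninterruptibly {
    42L // placeholder for the real KafkaConsumer communication
  }

  def close(): Unit = uninterruptibleThreadRunner.shutdown()
}
```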

### Why are the changes needed?
`UninterruptibleThread` running functionality is baked into `KafkaOffsetReader`, mixing thread-management concerns into the offset-reading logic.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing + additional unit tests.

Closes #29187 from gaborgsomogyi/SPARK-32387.

Authored-by: Gabor Somogyi <gabor.g.somogyi@gmail.com>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
Commit b890fdc8df (parent e6ef27be52)
Committed 2020-07-24 11:41:42 -07:00 by Dongjoon Hyun
3 changed files with 129 additions and 36 deletions

core/src/main/scala/org/apache/spark/util/UninterruptibleThreadRunner.scala (new file)

@@ -0,0 +1,55 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.util

import java.util.concurrent.Executors

import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.Duration

/**
 * [[UninterruptibleThreadRunner]] ensures that all tasks run in an
 * [[UninterruptibleThread]]. A good example is Kafka consumer usage.
 */
private[spark] class UninterruptibleThreadRunner(threadName: String) {
  // Dedicated single-thread executor whose one worker is a daemon
  // UninterruptibleThread; every submitted task runs on it.
  private val thread = Executors.newSingleThreadExecutor((r: Runnable) => {
    val t = new UninterruptibleThread(threadName) {
      override def run(): Unit = {
        r.run()
      }
    }
    t.setDaemon(true)
    t
  })
  private val execContext = ExecutionContext.fromExecutorService(thread)

  // Runs `body` on the dedicated UninterruptibleThread and blocks for the
  // result, unless the caller is already an UninterruptibleThread, in which
  // case the body runs inline on the calling thread.
  def runUninterruptibly[T](body: => T): T = {
    if (!Thread.currentThread.isInstanceOf[UninterruptibleThread]) {
      val future = Future {
        body
      }(execContext)
      ThreadUtils.awaitResult(future, Duration.Inf)
    } else {
      body
    }
  }

  def shutdown(): Unit = {
    thread.shutdown()
  }
}
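
For reference, a minimal usage sketch (the object and thread names are illustrative, and since the runner is `private[spark]` it must live inside the `org.apache.spark` package tree):

```scala
package org.apache.spark.util

object UninterruptibleThreadRunnerExample {
  def main(args: Array[String]): Unit = {
    val runner = new UninterruptibleThreadRunner("example-runner")
    try {
      // Called from a plain thread: the body hops onto the dedicated daemon
      // UninterruptibleThread and the caller blocks until it returns.
      val hopped = runner.runUninterruptibly {
        Thread.currentThread().isInstanceOf[UninterruptibleThread]
      }
      println(s"ran on an UninterruptibleThread: $hopped") // prints true
    } finally {
      runner.shutdown() // stops the dedicated daemon thread
    }
  }
}
```

The `isInstanceOf` check in `runUninterruptibly` is the interesting design choice: a caller already on an `UninterruptibleThread` (a streaming query, or the runner's own thread) executes the body inline, which avoids an unnecessary thread hop and prevents a deadlock on the single-thread executor.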

core/src/test/scala/org/apache/spark/util/UninterruptibleThreadRunnerSuite.scala (new file)

@@ -0,0 +1,64 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.util

import org.apache.spark.SparkFunSuite

class UninterruptibleThreadRunnerSuite extends SparkFunSuite {
  private var runner: UninterruptibleThreadRunner = null

  override def beforeEach(): Unit = {
    runner = new UninterruptibleThreadRunner("ThreadName")
  }

  override def afterEach(): Unit = {
    runner.shutdown()
  }

  test("runUninterruptibly should switch to UninterruptibleThread") {
    assert(!Thread.currentThread().isInstanceOf[UninterruptibleThread])
    var isUninterruptibleThread = false
    runner.runUninterruptibly {
      isUninterruptibleThread = Thread.currentThread().isInstanceOf[UninterruptibleThread]
    }
    assert(isUninterruptibleThread, "The runner task must run in UninterruptibleThread")
  }

  test("runUninterruptibly should not add new UninterruptibleThread") {
    var isInitialUninterruptibleThread = false
    var isRunnerUninterruptibleThread = false
    val t = new UninterruptibleThread("test") {
      override def run(): Unit = {
        runUninterruptibly {
          val initialThread = Thread.currentThread()
          isInitialUninterruptibleThread = initialThread.isInstanceOf[UninterruptibleThread]
          runner.runUninterruptibly {
            val runnerThread = Thread.currentThread()
            isRunnerUninterruptibleThread = runnerThread.isInstanceOf[UninterruptibleThread]
            assert(runnerThread.eq(initialThread))
          }
        }
      }
    }
    t.start()
    t.join()
    assert(isInitialUninterruptibleThread,
      "The initiator must already run in UninterruptibleThread")
    assert(isRunnerUninterruptibleThread, "The runner task must run in UninterruptibleThread")
  }
}

external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala

@@ -18,12 +18,9 @@
 package org.apache.spark.sql.kafka010

 import java.{util => ju}
-import java.util.concurrent.Executors

 import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.{ExecutionContext, Future}
-import scala.concurrent.duration.Duration
 import scala.util.control.NonFatal

 import org.apache.kafka.clients.consumer.{Consumer, ConsumerConfig, KafkaConsumer, OffsetAndTimestamp}
@@ -33,7 +30,7 @@ import org.apache.spark.SparkEnv
 import org.apache.spark.internal.Logging
 import org.apache.spark.scheduler.ExecutorCacheTaskLocation
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
-import org.apache.spark.util.{ThreadUtils, UninterruptibleThread}
+import org.apache.spark.util.{UninterruptibleThread, UninterruptibleThreadRunner}

 /**
  * This class uses Kafka's own [[KafkaConsumer]] API to read data offsets from Kafka.
@@ -51,19 +48,13 @@ private[kafka010] class KafkaOffsetReader(
     val driverKafkaParams: ju.Map[String, Object],
     readerOptions: CaseInsensitiveMap[String],
     driverGroupIdPrefix: String) extends Logging {

   /**
-   * Used to ensure execute fetch operations execute in an UninterruptibleThread
+   * [[UninterruptibleThreadRunner]] ensures that all [[KafkaConsumer]] communication is called
+   * in an [[UninterruptibleThread]]. In the case of streaming queries, we are already running
+   * in an [[UninterruptibleThread]]; however, for batch mode this is not the case.
    */
-  val kafkaReaderThread = Executors.newSingleThreadExecutor((r: Runnable) => {
-    val t = new UninterruptibleThread("Kafka Offset Reader") {
-      override def run(): Unit = {
-        r.run()
-      }
-    }
-    t.setDaemon(true)
-    t
-  })
-  val execContext = ExecutionContext.fromExecutorService(kafkaReaderThread)
+  val uninterruptibleThreadRunner = new UninterruptibleThreadRunner("Kafka Offset Reader")

   /**
    * Place [[groupId]] and [[nextId]] here so that they are initialized before any consumer is
@@ -126,14 +117,14 @@
    * Closes the connection to Kafka, and cleans up state.
    */
   def close(): Unit = {
-    if (_consumer != null) runUninterruptibly { stopConsumer() }
-    kafkaReaderThread.shutdown()
+    if (_consumer != null) uninterruptibleThreadRunner.runUninterruptibly { stopConsumer() }
+    uninterruptibleThreadRunner.shutdown()
   }

   /**
    * @return The Set of TopicPartitions for a given topic
    */
-  def fetchTopicPartitions(): Set[TopicPartition] = runUninterruptibly {
+  def fetchTopicPartitions(): Set[TopicPartition] = uninterruptibleThreadRunner.runUninterruptibly {
     assert(Thread.currentThread().isInstanceOf[UninterruptibleThread])
     // Poll to get the latest assigned partitions
     consumer.poll(0)
@@ -531,7 +522,7 @@ private[kafka010] class KafkaOffsetReader(
   private def partitionsAssignedToConsumer(
       body: ju.Set[TopicPartition] => Map[TopicPartition, Long],
       fetchingEarliestOffset: Boolean = false)
-    : Map[TopicPartition, Long] = runUninterruptibly {
+    : Map[TopicPartition, Long] = uninterruptibleThreadRunner.runUninterruptibly {

     withRetriesWithoutInterrupt {
       // Poll to get the latest assigned partitions
@@ -551,23 +542,6 @@
     }
   }

-  /**
-   * This method ensures that the closure is called in an [[UninterruptibleThread]].
-   * This is required when communicating with the [[KafkaConsumer]]. In the case
-   * of streaming queries, we are already running in an [[UninterruptibleThread]],
-   * however for batch mode this is not the case.
-   */
-  private def runUninterruptibly[T](body: => T): T = {
-    if (!Thread.currentThread.isInstanceOf[UninterruptibleThread]) {
-      val future = Future {
-        body
-      }(execContext)
-      ThreadUtils.awaitResult(future, Duration.Inf)
-    } else {
-      body
-    }
-  }
-
   /**
    * Helper function that does multiple retries on a body of code that returns offsets.
    * Retries are needed to handle transient failures. For e.g. race conditions between getting