[SPARK-32387][SS] Extract UninterruptibleThread runner logic from KafkaOffsetReader
### What changes were proposed in this pull request? `UninterruptibleThread` running functionality is baked into `KafkaOffsetReader`, but it can be extracted into its own class. The main intention is to simplify `KafkaOffsetReader` in order to make it easier to solve SPARK-32032. In this PR I've made this extraction without any functional change. ### Why are the changes needed? `UninterruptibleThread` running functionality is baked into `KafkaOffsetReader`. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing + additional unit tests. Closes #29187 from gaborgsomogyi/SPARK-32387. Authored-by: Gabor Somogyi <gabor.g.somogyi@gmail.com> Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
This commit is contained in:
parent
e6ef27be52
commit
b890fdc8df
|
@ -0,0 +1,55 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.util
|
||||
|
||||
import java.util.concurrent.Executors
|
||||
|
||||
import scala.concurrent.{ExecutionContext, Future}
|
||||
import scala.concurrent.duration.Duration
|
||||
|
||||
/**
 * [[UninterruptibleThreadRunner]] guarantees that every task handed to it is executed on an
 * [[UninterruptibleThread]]. Kafka consumer usage is a good example of code that needs this.
 *
 * The runner owns a single-threaded executor whose worker is a daemon [[UninterruptibleThread]]
 * named `threadName`.
 *
 * @param threadName name given to the worker thread created by this runner
 */
private[spark] class UninterruptibleThreadRunner(threadName: String) {
  // Single-threaded executor; its thread factory wraps each Runnable in a daemon
  // UninterruptibleThread so that submitted work always runs on that thread type.
  private val thread = Executors.newSingleThreadExecutor((task: Runnable) => {
    val worker = new UninterruptibleThread(threadName) {
      override def run(): Unit = task.run()
    }
    worker.setDaemon(true)
    worker
  })
  private val execContext = ExecutionContext.fromExecutorService(thread)

  /**
   * Evaluates `body` on an [[UninterruptibleThread]] and returns its result.
   *
   * If the calling thread already is an [[UninterruptibleThread]], `body` is evaluated inline on
   * the caller's thread. Otherwise it is submitted to this runner's worker thread and the caller
   * waits, without a timeout, for the result.
   *
   * @param body the computation to evaluate
   * @tparam T result type of `body`
   * @return the value produced by `body`
   */
  def runUninterruptibly[T](body: => T): T = {
    Thread.currentThread match {
      case _: UninterruptibleThread =>
        body
      case _ =>
        val pending = Future(body)(execContext)
        ThreadUtils.awaitResult(pending, Duration.Inf)
    }
  }

  /** Shuts down the underlying executor. */
  def shutdown(): Unit = thread.shutdown()
}
|
|
@ -0,0 +1,64 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.util
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
|
||||
/**
 * Tests for [[UninterruptibleThreadRunner]]: work submitted from an ordinary thread must be
 * moved onto an [[UninterruptibleThread]], while work submitted from a thread that already is an
 * [[UninterruptibleThread]] must run inline on that same thread.
 */
class UninterruptibleThreadRunnerSuite extends SparkFunSuite {
  // Re-created for every test case in beforeEach and shut down in afterEach.
  private var runner: UninterruptibleThreadRunner = null

  override def beforeEach(): Unit = {
    // Keep the fixture chain intact so that parent set-up still runs.
    super.beforeEach()
    runner = new UninterruptibleThreadRunner("ThreadName")
  }

  override def afterEach(): Unit = {
    try {
      // Guard against a failed set-up that never assigned the runner.
      if (runner != null) {
        runner.shutdown()
        runner = null
      }
    } finally {
      super.afterEach()
    }
  }

  test("runUninterruptibly should switch to UninterruptibleThread") {
    assert(!Thread.currentThread().isInstanceOf[UninterruptibleThread])
    var isUninterruptibleThread = false
    runner.runUninterruptibly {
      isUninterruptibleThread = Thread.currentThread().isInstanceOf[UninterruptibleThread]
    }
    assert(isUninterruptibleThread, "The runner task must run in UninterruptibleThread")
  }

  test("runUninterruptibly should not add new UninterruptibleThread") {
    var isInitialUninterruptibleThread = false
    var isRunnerUninterruptibleThread = false
    var isSameThread = false
    val t = new UninterruptibleThread("test") {
      override def run(): Unit = {
        runUninterruptibly {
          val initialThread = Thread.currentThread()
          isInitialUninterruptibleThread = initialThread.isInstanceOf[UninterruptibleThread]
          runner.runUninterruptibly {
            val runnerThread = Thread.currentThread()
            isRunnerUninterruptibleThread = runnerThread.isInstanceOf[UninterruptibleThread]
            // Only record the outcome here: an assertion failing on this spawned thread would
            // not propagate to the test thread, so the check is made after join() below.
            isSameThread = runnerThread.eq(initialThread)
          }
        }
      }
    }
    t.start()
    t.join()
    assert(isInitialUninterruptibleThread,
      "The initiator must already run in UninterruptibleThread")
    assert(isRunnerUninterruptibleThread, "The runner task must run in UninterruptibleThread")
    assert(isSameThread, "The runner task must run in the initiator thread")
  }
}
|
|
@ -18,12 +18,9 @@
|
|||
package org.apache.spark.sql.kafka010
|
||||
|
||||
import java.{util => ju}
|
||||
import java.util.concurrent.Executors
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
import scala.concurrent.{ExecutionContext, Future}
|
||||
import scala.concurrent.duration.Duration
|
||||
import scala.util.control.NonFatal
|
||||
|
||||
import org.apache.kafka.clients.consumer.{Consumer, ConsumerConfig, KafkaConsumer, OffsetAndTimestamp}
|
||||
|
@ -33,7 +30,7 @@ import org.apache.spark.SparkEnv
|
|||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.scheduler.ExecutorCacheTaskLocation
|
||||
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
|
||||
import org.apache.spark.util.{ThreadUtils, UninterruptibleThread}
|
||||
import org.apache.spark.util.{UninterruptibleThread, UninterruptibleThreadRunner}
|
||||
|
||||
/**
|
||||
* This class uses Kafka's own [[KafkaConsumer]] API to read data offsets from Kafka.
|
||||
|
@ -51,19 +48,13 @@ private[kafka010] class KafkaOffsetReader(
|
|||
val driverKafkaParams: ju.Map[String, Object],
|
||||
readerOptions: CaseInsensitiveMap[String],
|
||||
driverGroupIdPrefix: String) extends Logging {
|
||||
|
||||
/**
|
||||
* Used to ensure execute fetch operations execute in an UninterruptibleThread
|
||||
* [[UninterruptibleThreadRunner]] ensures that all [[KafkaConsumer]] communication is called in an
|
||||
* [[UninterruptibleThread]]. In the case of streaming queries, we are already running in an
|
||||
* [[UninterruptibleThread]], however for batch mode this is not the case.
|
||||
*/
|
||||
val kafkaReaderThread = Executors.newSingleThreadExecutor((r: Runnable) => {
|
||||
val t = new UninterruptibleThread("Kafka Offset Reader") {
|
||||
override def run(): Unit = {
|
||||
r.run()
|
||||
}
|
||||
}
|
||||
t.setDaemon(true)
|
||||
t
|
||||
})
|
||||
val execContext = ExecutionContext.fromExecutorService(kafkaReaderThread)
|
||||
val uninterruptibleThreadRunner = new UninterruptibleThreadRunner("Kafka Offset Reader")
|
||||
|
||||
/**
|
||||
* Place [[groupId]] and [[nextId]] here so that they are initialized before any consumer is
|
||||
|
@ -126,14 +117,14 @@ private[kafka010] class KafkaOffsetReader(
|
|||
* Closes the connection to Kafka, and cleans up state.
|
||||
*/
|
||||
def close(): Unit = {
|
||||
if (_consumer != null) runUninterruptibly { stopConsumer() }
|
||||
kafkaReaderThread.shutdown()
|
||||
if (_consumer != null) uninterruptibleThreadRunner.runUninterruptibly { stopConsumer() }
|
||||
uninterruptibleThreadRunner.shutdown()
|
||||
}
|
||||
|
||||
/**
|
||||
* @return The Set of TopicPartitions for a given topic
|
||||
*/
|
||||
def fetchTopicPartitions(): Set[TopicPartition] = runUninterruptibly {
|
||||
def fetchTopicPartitions(): Set[TopicPartition] = uninterruptibleThreadRunner.runUninterruptibly {
|
||||
assert(Thread.currentThread().isInstanceOf[UninterruptibleThread])
|
||||
// Poll to get the latest assigned partitions
|
||||
consumer.poll(0)
|
||||
|
@ -531,7 +522,7 @@ private[kafka010] class KafkaOffsetReader(
|
|||
private def partitionsAssignedToConsumer(
|
||||
body: ju.Set[TopicPartition] => Map[TopicPartition, Long],
|
||||
fetchingEarliestOffset: Boolean = false)
|
||||
: Map[TopicPartition, Long] = runUninterruptibly {
|
||||
: Map[TopicPartition, Long] = uninterruptibleThreadRunner.runUninterruptibly {
|
||||
|
||||
withRetriesWithoutInterrupt {
|
||||
// Poll to get the latest assigned partitions
|
||||
|
@ -551,23 +542,6 @@ private[kafka010] class KafkaOffsetReader(
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* This method ensures that the closure is called in an [[UninterruptibleThread]].
|
||||
* This is required when communicating with the [[KafkaConsumer]]. In the case
|
||||
* of streaming queries, we are already running in an [[UninterruptibleThread]],
|
||||
* however for batch mode this is not the case.
|
||||
*/
|
||||
private def runUninterruptibly[T](body: => T): T = {
|
||||
if (!Thread.currentThread.isInstanceOf[UninterruptibleThread]) {
|
||||
val future = Future {
|
||||
body
|
||||
}(execContext)
|
||||
ThreadUtils.awaitResult(future, Duration.Inf)
|
||||
} else {
|
||||
body
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function that does multiple retries on a body of code that returns offsets.
|
||||
* Retries are needed to handle transient failures. For e.g. race conditions between getting
|
||||
|
|
Loading…
Reference in a new issue