# [SPARK-28163][SS] Use CaseInsensitiveMap for KafkaOffsetReader
## What changes were proposed in this pull request?

There are "unsafe" conversions in the Kafka connector. A `CaseInsensitiveStringMap` comes in, which is then converted the following way:

```
...
options.asScala.toMap
...
```

The main problem with this is that the result loses its case-insensitive nature (a case-insensitive map converts the key to lower case when `get`/`contains` is called; a plain `Map` does not). In this PR I'm using `CaseInsensitiveMap` to solve this problem.

## How was this patch tested?

Existing + additional unit tests.

Closes #24967 from gaborgsomogyi/SPARK-28163.

Authored-by: Gabor Somogyi <gabor.g.somogyi@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
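To make the failure mode concrete, below is a minimal sketch of the bug and the fix. The demo object name and the option key are illustrative only; `CaseInsensitiveStringMap` and `CaseInsensitiveMap` are the real Spark classes this PR touches:

```scala
import scala.collection.JavaConverters._

import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.util.CaseInsensitiveStringMap

object LostCaseInsensitivityDemo extends App {
  // The user may supply an option key with arbitrary casing.
  val options = new CaseInsensitiveStringMap(
    Map("KAFKACONSUMER.POLLTIMEOUTMS" -> "1000").asJava)

  // Lookups through CaseInsensitiveStringMap itself work regardless of case.
  assert(options.get("kafkaConsumer.pollTimeoutMs") == "1000")

  // After options.asScala.toMap the result is a plain Map: get/contains no
  // longer normalize the lookup key, so a mixed-case constant used by the
  // connector misses the user-supplied value.
  val plain: Map[String, String] = options.asScala.toMap
  assert(!plain.contains("kafkaConsumer.pollTimeoutMs"))

  // Wrapping the converted map in CaseInsensitiveMap (the fix in this PR)
  // restores case-insensitive lookups.
  val fixed = CaseInsensitiveMap(plain)
  assert(fixed.get("kafkaConsumer.pollTimeoutMs").contains("1000"))
}
```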
parent 5368eaa2fc
commit 5663386f4b
KafkaBatch.scala:

```diff
@@ -22,12 +22,13 @@ import org.apache.kafka.common.TopicPartition
 import org.apache.spark.SparkEnv
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.sources.v2.reader.{Batch, InputPartition, PartitionReaderFactory}

 private[kafka010] class KafkaBatch(
     strategy: ConsumerStrategy,
-    sourceOptions: Map[String, String],
+    sourceOptions: CaseInsensitiveMap[String],
     specifiedKafkaParams: Map[String, String],
     failOnDataLoss: Boolean,
     startingOffsets: KafkaOffsetRangeLimit,
@@ -38,7 +39,7 @@ private[kafka010] class KafkaBatch(
   assert(endingOffsets != EarliestOffsetRangeLimit,
     "Ending offset not allowed to be set to earliest offsets.")

-  private val pollTimeoutMs = sourceOptions.getOrElse(
+  private[kafka010] val pollTimeoutMs = sourceOptions.getOrElse(
     KafkaSourceProvider.CONSUMER_POLL_TIMEOUT,
     (SparkEnv.get.conf.get(NETWORK_TIMEOUT) * 1000L).toString
   ).toLong
```
KafkaContinuousStream.scala:

```diff
@@ -46,7 +46,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
  * properly read.
  */
 class KafkaContinuousStream(
-    offsetReader: KafkaOffsetReader,
+    private[kafka010] val offsetReader: KafkaOffsetReader,
     kafkaParams: ju.Map[String, Object],
     options: CaseInsensitiveStringMap,
     metadataPath: String,
@@ -54,7 +54,7 @@ class KafkaContinuousStream(
     failOnDataLoss: Boolean)
   extends ContinuousStream with Logging {

-  private val pollTimeoutMs =
+  private[kafka010] val pollTimeoutMs =
     options.getLong(KafkaSourceProvider.CONSUMER_POLL_TIMEOUT, 512)

   // Initialized when creating reader factories. If this diverges from the partitions at the latest
```
KafkaMicroBatchStream.scala:

```diff
@@ -56,19 +56,19 @@ import org.apache.spark.util.UninterruptibleThread
  * and not use wrong broker addresses.
  */
 private[kafka010] class KafkaMicroBatchStream(
-    kafkaOffsetReader: KafkaOffsetReader,
+    private[kafka010] val kafkaOffsetReader: KafkaOffsetReader,
     executorKafkaParams: ju.Map[String, Object],
     options: CaseInsensitiveStringMap,
     metadataPath: String,
     startingOffsets: KafkaOffsetRangeLimit,
     failOnDataLoss: Boolean) extends RateControlMicroBatchStream with Logging {

-  private val pollTimeoutMs = options.getLong(
+  private[kafka010] val pollTimeoutMs = options.getLong(
     KafkaSourceProvider.CONSUMER_POLL_TIMEOUT,
     SparkEnv.get.conf.get(NETWORK_TIMEOUT) * 1000L)

-  private val maxOffsetsPerTrigger = Option(options.get(KafkaSourceProvider.MAX_OFFSET_PER_TRIGGER))
-    .map(_.toLong)
+  private[kafka010] val maxOffsetsPerTrigger = Option(options.get(
+    KafkaSourceProvider.MAX_OFFSET_PER_TRIGGER)).map(_.toLong)

   private val rangeCalculator = KafkaOffsetRangeCalculator(options)
```
KafkaOffsetReader.scala:

```diff
@@ -30,6 +30,7 @@ import org.apache.kafka.clients.consumer.{Consumer, ConsumerConfig, KafkaConsumer}
 import org.apache.kafka.common.TopicPartition

 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.types._
 import org.apache.spark.util.{ThreadUtils, UninterruptibleThread}

@@ -47,7 +48,7 @@ import org.apache.spark.util.{ThreadUtils, UninterruptibleThread}
 private[kafka010] class KafkaOffsetReader(
     consumerStrategy: ConsumerStrategy,
     val driverKafkaParams: ju.Map[String, Object],
-    readerOptions: Map[String, String],
+    readerOptions: CaseInsensitiveMap[String],
     driverGroupIdPrefix: String) extends Logging {
   /**
    * Used to ensure execute fetch operations execute in an UninterruptibleThread
@@ -88,10 +89,10 @@ private[kafka010] class KafkaOffsetReader(
     _consumer
   }

-  private val maxOffsetFetchAttempts =
+  private[kafka010] val maxOffsetFetchAttempts =
     readerOptions.getOrElse(KafkaSourceProvider.FETCH_OFFSET_NUM_RETRY, "3").toInt

-  private val offsetFetchAttemptIntervalMs =
+  private[kafka010] val offsetFetchAttemptIntervalMs =
     readerOptions.getOrElse(KafkaSourceProvider.FETCH_OFFSET_RETRY_INTERVAL_MS, "1000").toLong

   private def nextGroupId(): String = {
```
KafkaRelation.scala:

```diff
@@ -24,7 +24,7 @@ import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{Row, SQLContext}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
 import org.apache.spark.sql.sources.{BaseRelation, TableScan}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.unsafe.types.UTF8String
@@ -33,7 +33,7 @@ import org.apache.spark.unsafe.types.UTF8String
 private[kafka010] class KafkaRelation(
     override val sqlContext: SQLContext,
     strategy: ConsumerStrategy,
-    sourceOptions: Map[String, String],
+    sourceOptions: CaseInsensitiveMap[String],
     specifiedKafkaParams: Map[String, String],
     failOnDataLoss: Boolean,
     startingOffsets: KafkaOffsetRangeLimit,
```
KafkaSourceProvider.scala:

```diff
@@ -78,32 +78,32 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
       schema: Option[StructType],
       providerName: String,
       parameters: Map[String, String]): Source = {
-    validateStreamOptions(parameters)
+    val caseInsensitiveParameters = CaseInsensitiveMap(parameters)
+    validateStreamOptions(caseInsensitiveParameters)
     // Each running query should use its own group id. Otherwise, the query may be only assigned
     // partial data since Kafka will assign partitions to multiple consumers having the same group
     // id. Hence, we should generate a unique id for each query.
-    val uniqueGroupId = streamingUniqueGroupId(parameters, metadataPath)
+    val uniqueGroupId = streamingUniqueGroupId(caseInsensitiveParameters, metadataPath)

-    val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
     val specifiedKafkaParams = convertToSpecifiedParams(parameters)

-    val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams,
-      STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
+    val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
+      caseInsensitiveParameters, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)

     val kafkaOffsetReader = new KafkaOffsetReader(
-      strategy(caseInsensitiveParams),
+      strategy(caseInsensitiveParameters),
       kafkaParamsForDriver(specifiedKafkaParams),
-      parameters,
+      caseInsensitiveParameters,
       driverGroupIdPrefix = s"$uniqueGroupId-driver")

     new KafkaSource(
       sqlContext,
       kafkaOffsetReader,
       kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId),
-      parameters,
+      caseInsensitiveParameters,
       metadataPath,
       startingStreamOffsets,
-      failOnDataLoss(caseInsensitiveParams))
+      failOnDataLoss(caseInsensitiveParameters))
   }

   override def getTable(options: CaseInsensitiveStringMap): KafkaTable = {
@@ -119,24 +119,24 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
   override def createRelation(
       sqlContext: SQLContext,
       parameters: Map[String, String]): BaseRelation = {
-    validateBatchOptions(parameters)
-    val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
+    val caseInsensitiveParameters = CaseInsensitiveMap(parameters)
+    validateBatchOptions(caseInsensitiveParameters)
     val specifiedKafkaParams = convertToSpecifiedParams(parameters)

     val startingRelationOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
-      caseInsensitiveParams, STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit)
+      caseInsensitiveParameters, STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit)
     assert(startingRelationOffsets != LatestOffsetRangeLimit)

-    val endingRelationOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams,
-      ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
+    val endingRelationOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
+      caseInsensitiveParameters, ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
     assert(endingRelationOffsets != EarliestOffsetRangeLimit)

     new KafkaRelation(
       sqlContext,
-      strategy(caseInsensitiveParams),
-      sourceOptions = parameters,
+      strategy(caseInsensitiveParameters),
+      sourceOptions = caseInsensitiveParameters,
       specifiedKafkaParams = specifiedKafkaParams,
-      failOnDataLoss = failOnDataLoss(caseInsensitiveParams),
+      failOnDataLoss = failOnDataLoss(caseInsensitiveParameters),
       startingOffsets = startingRelationOffsets,
       endingOffsets = endingRelationOffsets)
   }
@@ -420,23 +420,22 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
     }

     override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = {
-      val parameters = options.asScala.toMap
-      validateStreamOptions(parameters)
+      val caseInsensitiveOptions = CaseInsensitiveMap(options.asScala.toMap)
+      validateStreamOptions(caseInsensitiveOptions)
       // Each running query should use its own group id. Otherwise, the query may be only assigned
       // partial data since Kafka will assign partitions to multiple consumers having the same group
       // id. Hence, we should generate a unique id for each query.
-      val uniqueGroupId = streamingUniqueGroupId(parameters, checkpointLocation)
+      val uniqueGroupId = streamingUniqueGroupId(caseInsensitiveOptions, checkpointLocation)

-      val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
-      val specifiedKafkaParams = convertToSpecifiedParams(parameters)
+      val specifiedKafkaParams = convertToSpecifiedParams(caseInsensitiveOptions)

       val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
-        caseInsensitiveParams, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
+        caseInsensitiveOptions, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)

       val kafkaOffsetReader = new KafkaOffsetReader(
-        strategy(parameters),
+        strategy(caseInsensitiveOptions),
         kafkaParamsForDriver(specifiedKafkaParams),
-        parameters,
+        caseInsensitiveOptions,
         driverGroupIdPrefix = s"$uniqueGroupId-driver")

       new KafkaMicroBatchStream(
@@ -445,32 +444,26 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
         options,
         checkpointLocation,
         startingStreamOffsets,
-        failOnDataLoss(caseInsensitiveParams))
+        failOnDataLoss(caseInsensitiveOptions))
     }

     override def toContinuousStream(checkpointLocation: String): ContinuousStream = {
-      val parameters = options.asScala.toMap
-      validateStreamOptions(parameters)
+      val caseInsensitiveOptions = CaseInsensitiveMap(options.asScala.toMap)
+      validateStreamOptions(caseInsensitiveOptions)
       // Each running query should use its own group id. Otherwise, the query may be only assigned
       // partial data since Kafka will assign partitions to multiple consumers having the same group
       // id. Hence, we should generate a unique id for each query.
-      val uniqueGroupId = streamingUniqueGroupId(parameters, checkpointLocation)
+      val uniqueGroupId = streamingUniqueGroupId(caseInsensitiveOptions, checkpointLocation)

-      val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
-      val specifiedKafkaParams =
-        parameters
-          .keySet
-          .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka."))
-          .map { k => k.drop(6).toString -> parameters(k) }
-          .toMap
+      val specifiedKafkaParams = convertToSpecifiedParams(caseInsensitiveOptions)

       val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
-        caseInsensitiveParams, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
+        caseInsensitiveOptions, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)

       val kafkaOffsetReader = new KafkaOffsetReader(
-        strategy(caseInsensitiveParams),
+        strategy(caseInsensitiveOptions),
         kafkaParamsForDriver(specifiedKafkaParams),
-        parameters,
+        caseInsensitiveOptions,
         driverGroupIdPrefix = s"$uniqueGroupId-driver")

       new KafkaContinuousStream(
@@ -479,7 +472,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
         options,
         checkpointLocation,
         startingStreamOffsets,
-        failOnDataLoss(caseInsensitiveParams))
+        failOnDataLoss(caseInsensitiveOptions))
     }
   }
 }
```
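Aside: `convertToSpecifiedParams` itself is not shown in this diff, but the inline block it replaces in `toContinuousStream` (removed above) reveals its behaviour. The sketch below is inferred from those removed lines; the actual signature in `KafkaSourceProvider` may differ:

```scala
import java.util.Locale

// Inferred from the removed inline code: keep only the options whose key
// starts with "kafka." (case-insensitively), strip that 6-character prefix,
// and pass the rest through unchanged as Kafka client parameters.
def convertToSpecifiedParams(parameters: Map[String, String]): Map[String, String] = {
  parameters
    .keySet
    .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka."))
    .map { k => k.drop(6) -> parameters(k) }
    .toMap
}
```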
KafkaSourceProviderSuite.scala:

```diff
@@ -22,60 +22,92 @@ import java.util.Locale
 import scala.collection.JavaConverters._

 import org.mockito.Mockito.{mock, when}
 import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester}

 import org.apache.spark.{SparkConf, SparkEnv, SparkFunSuite}
 import org.apache.spark.sql.sources.v2.reader.Scan
 import org.apache.spark.sql.util.CaseInsensitiveStringMap

-class KafkaSourceProviderSuite extends SparkFunSuite with PrivateMethodTester {
+class KafkaSourceProviderSuite extends SparkFunSuite {

-  private val pollTimeoutMsMethod = PrivateMethod[Long]('pollTimeoutMs)
-  private val maxOffsetsPerTriggerMethod = PrivateMethod[Option[Long]]('maxOffsetsPerTrigger)
+  private val expected = "1111"

   override protected def afterEach(): Unit = {
     SparkEnv.set(null)
     super.afterEach()
   }

-  test("micro-batch mode - options should be handled as case-insensitive") {
-    def verifyFieldsInMicroBatchStream(
-        options: CaseInsensitiveStringMap,
-        expectedPollTimeoutMs: Long,
-        expectedMaxOffsetsPerTrigger: Option[Long]): Unit = {
-      // KafkaMicroBatchStream reads Spark conf from SparkEnv for default value
-      // hence we set mock SparkEnv here before creating KafkaMicroBatchStream
-      val sparkEnv = mock(classOf[SparkEnv])
-      when(sparkEnv.conf).thenReturn(new SparkConf())
-      SparkEnv.set(sparkEnv)
-
-      val scan = getKafkaDataSourceScan(options)
-      val stream = scan.toMicroBatchStream("dummy").asInstanceOf[KafkaMicroBatchStream]
-
-      assert(expectedPollTimeoutMs === getField(stream, pollTimeoutMsMethod))
-      assert(expectedMaxOffsetsPerTrigger === getField(stream, maxOffsetsPerTriggerMethod))
-    }
-
-    val expectedValue = 1000L
-    buildCaseInsensitiveStringMapForUpperAndLowerKey(
-      KafkaSourceProvider.CONSUMER_POLL_TIMEOUT -> expectedValue.toString,
-      KafkaSourceProvider.MAX_OFFSET_PER_TRIGGER -> expectedValue.toString)
-      .foreach(verifyFieldsInMicroBatchStream(_, expectedValue, Some(expectedValue)))
+  test("batch mode - options should be handled as case-insensitive") {
+    verifyFieldsInBatch(KafkaSourceProvider.CONSUMER_POLL_TIMEOUT, expected, batch => {
+      assert(expected.toLong === batch.pollTimeoutMs)
+    })
   }

-  test("SPARK-28142 - continuous mode - options should be handled as case-insensitive") {
-    def verifyFieldsInContinuousStream(
-        options: CaseInsensitiveStringMap,
-        expectedPollTimeoutMs: Long): Unit = {
+  test("micro-batch mode - options should be handled as case-insensitive") {
+    verifyFieldsInMicroBatchStream(KafkaSourceProvider.CONSUMER_POLL_TIMEOUT, expected, stream => {
+      assert(expected.toLong === stream.pollTimeoutMs)
+    })
+    verifyFieldsInMicroBatchStream(KafkaSourceProvider.MAX_OFFSET_PER_TRIGGER, expected, stream => {
+      assert(Some(expected.toLong) === stream.maxOffsetsPerTrigger)
+    })
+    verifyFieldsInMicroBatchStream(KafkaSourceProvider.FETCH_OFFSET_NUM_RETRY, expected, stream => {
+      assert(expected.toInt === stream.kafkaOffsetReader.maxOffsetFetchAttempts)
+    })
+    verifyFieldsInMicroBatchStream(KafkaSourceProvider.FETCH_OFFSET_RETRY_INTERVAL_MS, expected,
+      stream => {
+        assert(expected.toLong === stream.kafkaOffsetReader.offsetFetchAttemptIntervalMs)
+    })
+  }
+
+  test("continuous mode - options should be handled as case-insensitive") {
+    verifyFieldsInContinuousStream(KafkaSourceProvider.CONSUMER_POLL_TIMEOUT, expected, stream => {
+      assert(expected.toLong === stream.pollTimeoutMs)
+    })
+    verifyFieldsInContinuousStream(KafkaSourceProvider.FETCH_OFFSET_NUM_RETRY, expected, stream => {
+      assert(expected.toInt === stream.offsetReader.maxOffsetFetchAttempts)
+    })
+    verifyFieldsInContinuousStream(KafkaSourceProvider.FETCH_OFFSET_RETRY_INTERVAL_MS, expected,
+      stream => {
+        assert(expected.toLong === stream.offsetReader.offsetFetchAttemptIntervalMs)
+    })
+  }
+
+  private def verifyFieldsInBatch(
+      key: String,
+      value: String,
+      validate: (KafkaBatch) => Unit): Unit = {
+    buildCaseInsensitiveStringMapForUpperAndLowerKey(key -> value).foreach { options =>
+      val scan = getKafkaDataSourceScan(options)
+      val batch = scan.toBatch().asInstanceOf[KafkaBatch]
+      validate(batch)
+    }
+  }
+
+  private def verifyFieldsInMicroBatchStream(
+      key: String,
+      value: String,
+      validate: (KafkaMicroBatchStream) => Unit): Unit = {
+    // KafkaMicroBatchStream reads Spark conf from SparkEnv for default value
+    // hence we set mock SparkEnv here before creating KafkaMicroBatchStream
+    val sparkEnv = mock(classOf[SparkEnv])
+    when(sparkEnv.conf).thenReturn(new SparkConf())
+    SparkEnv.set(sparkEnv)
+
+    buildCaseInsensitiveStringMapForUpperAndLowerKey(key -> value).foreach { options =>
+      val scan = getKafkaDataSourceScan(options)
+      val stream = scan.toMicroBatchStream("dummy").asInstanceOf[KafkaMicroBatchStream]
+      validate(stream)
+    }
+  }
+
+  private def verifyFieldsInContinuousStream(
+      key: String,
+      value: String,
+      validate: (KafkaContinuousStream) => Unit): Unit = {
+    buildCaseInsensitiveStringMapForUpperAndLowerKey(key -> value).foreach { options =>
       val scan = getKafkaDataSourceScan(options)
       val stream = scan.toContinuousStream("dummy").asInstanceOf[KafkaContinuousStream]
-      assert(expectedPollTimeoutMs === getField(stream, pollTimeoutMsMethod))
+      validate(stream)
     }
-
-    val expectedValue = 1000
-    buildCaseInsensitiveStringMapForUpperAndLowerKey(
-      KafkaSourceProvider.CONSUMER_POLL_TIMEOUT -> expectedValue.toString)
-      .foreach(verifyFieldsInContinuousStream(_, expectedValue))
   }

   private def buildCaseInsensitiveStringMapForUpperAndLowerKey(
@@ -95,8 +127,4 @@ class KafkaSourceProviderSuite extends SparkFunSuite with PrivateMethodTester {
     val provider = new KafkaSourceProvider()
     provider.getTable(options).newScanBuilder(options).build()
   }
-
-  private def getField[T](obj: AnyRef, method: PrivateMethod[T]): T = {
-    obj.invokePrivate(method())
-  }
 }
```
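The `buildCaseInsensitiveStringMapForUpperAndLowerKey` helper sits outside the hunk, so its body is not part of this diff. Judging by its name and call sites, it presumably yields the given options twice, once with upper-cased and once with lower-cased keys, so every assertion runs against both spellings. A purely hypothetical sketch, not the committed code:

```scala
import java.util.Locale

import scala.collection.JavaConverters._

import org.apache.spark.sql.util.CaseInsensitiveStringMap

// Hypothetical reconstruction: emit the given key/value pairs twice, once with
// upper-cased keys and once with lower-cased keys, each wrapped in the
// CaseInsensitiveStringMap that the DSv2 entry points accept.
def buildCaseInsensitiveStringMapForUpperAndLowerKey(
    keyValues: (String, String)*): Seq[CaseInsensitiveStringMap] = {
  val keyConverters: Seq[String => String] =
    Seq(_.toUpperCase(Locale.ROOT), _.toLowerCase(Locale.ROOT))
  keyConverters.map { convert =>
    new CaseInsensitiveStringMap(
      keyValues.map { case (k, v) => convert(k) -> v }.toMap.asJava)
  }
}
```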