[SPARK-28163][SS] Use CaseInsensitiveMap for KafkaOffsetReader

## What changes were proposed in this pull request?

There are "unsafe" conversions in the Kafka connector.
`CaseInsensitiveStringMap` comes in which is then converted the following way:
```
...
options.asScala.toMap
...
```
The main problem with this is that in such a case the map loses its case-insensitive nature
(a case-insensitive map converts the key to lower case when `get`/`contains` is called).

In this PR I'm using `CaseInsensitiveMap` to solve this problem.
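
For illustration, here is a minimal, standalone sketch of the behaviour (not code from the patch; the object name and the option key are just examples). It shows how a plain `asScala.toMap` conversion drops the case-insensitive lookup, while wrapping the result in `CaseInsensitiveMap` restores it:

```
import scala.collection.JavaConverters._

import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.util.CaseInsensitiveStringMap

object CaseInsensitiveConversionDemo {
  def main(args: Array[String]): Unit = {
    // The user supplies the option with arbitrary casing.
    val options = new CaseInsensitiveStringMap(
      Map("KAFKACONSUMER.POLLTIMEOUTMS" -> "1000").asJava)

    // CaseInsensitiveStringMap itself resolves keys regardless of case.
    assert(options.get("kafkaConsumer.pollTimeoutMs") == "1000")

    // A plain conversion produces an ordinary Map whose lookups are exact-match only,
    // so the camelCase key used by the connector no longer resolves.
    val plain: Map[String, String] = options.asScala.toMap
    assert(!plain.contains("kafkaConsumer.pollTimeoutMs"))

    // Wrapping the converted map in CaseInsensitiveMap brings the behaviour back.
    val caseInsensitive = CaseInsensitiveMap(plain)
    assert(caseInsensitive("kafkaConsumer.pollTimeoutMs") == "1000")
  }
}
```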

## How was this patch tested?

Existing + additional unit tests.
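
For reference, the affected suite can be run on its own with something like the following (the exact module name may differ across branches):

```
build/sbt "sql-kafka-0-10/testOnly *KafkaSourceProviderSuite"
```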

Closes #24967 from gaborgsomogyi/SPARK-28163.

Authored-by: Gabor Somogyi <gabor.g.somogyi@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
7 changed files with 117 additions and 94 deletions

KafkaBatch.scala:

```
@@ -22,12 +22,13 @@ import org.apache.kafka.common.TopicPartition

 import org.apache.spark.SparkEnv
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.sources.v2.reader.{Batch, InputPartition, PartitionReaderFactory}

 private[kafka010] class KafkaBatch(
     strategy: ConsumerStrategy,
-    sourceOptions: Map[String, String],
+    sourceOptions: CaseInsensitiveMap[String],
     specifiedKafkaParams: Map[String, String],
     failOnDataLoss: Boolean,
     startingOffsets: KafkaOffsetRangeLimit,
@@ -38,7 +39,7 @@ private[kafka010] class KafkaBatch(
   assert(endingOffsets != EarliestOffsetRangeLimit,
     "Ending offset not allowed to be set to earliest offsets.")

-  private val pollTimeoutMs = sourceOptions.getOrElse(
+  private[kafka010] val pollTimeoutMs = sourceOptions.getOrElse(
     KafkaSourceProvider.CONSUMER_POLL_TIMEOUT,
     (SparkEnv.get.conf.get(NETWORK_TIMEOUT) * 1000L).toString
   ).toLong
```

KafkaContinuousStream.scala:

```
@@ -46,7 +46,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
  * properly read.
  */
 class KafkaContinuousStream(
-    offsetReader: KafkaOffsetReader,
+    private[kafka010] val offsetReader: KafkaOffsetReader,
     kafkaParams: ju.Map[String, Object],
     options: CaseInsensitiveStringMap,
     metadataPath: String,
@@ -54,7 +54,7 @@ class KafkaContinuousStream(
     failOnDataLoss: Boolean)
   extends ContinuousStream with Logging {

-  private val pollTimeoutMs =
+  private[kafka010] val pollTimeoutMs =
     options.getLong(KafkaSourceProvider.CONSUMER_POLL_TIMEOUT, 512)

   // Initialized when creating reader factories. If this diverges from the partitions at the latest
```

KafkaMicroBatchStream.scala:

```
@@ -56,19 +56,19 @@ import org.apache.spark.util.UninterruptibleThread
  * and not use wrong broker addresses.
  */
 private[kafka010] class KafkaMicroBatchStream(
-    kafkaOffsetReader: KafkaOffsetReader,
+    private[kafka010] val kafkaOffsetReader: KafkaOffsetReader,
     executorKafkaParams: ju.Map[String, Object],
     options: CaseInsensitiveStringMap,
     metadataPath: String,
     startingOffsets: KafkaOffsetRangeLimit,
     failOnDataLoss: Boolean) extends RateControlMicroBatchStream with Logging {

-  private val pollTimeoutMs = options.getLong(
+  private[kafka010] val pollTimeoutMs = options.getLong(
     KafkaSourceProvider.CONSUMER_POLL_TIMEOUT,
     SparkEnv.get.conf.get(NETWORK_TIMEOUT) * 1000L)

-  private val maxOffsetsPerTrigger = Option(options.get(KafkaSourceProvider.MAX_OFFSET_PER_TRIGGER))
-    .map(_.toLong)
+  private[kafka010] val maxOffsetsPerTrigger = Option(options.get(
+    KafkaSourceProvider.MAX_OFFSET_PER_TRIGGER)).map(_.toLong)

   private val rangeCalculator = KafkaOffsetRangeCalculator(options)

```

KafkaOffsetReader.scala:

```
@@ -30,6 +30,7 @@ import org.apache.kafka.clients.consumer.{Consumer, ConsumerConfig, KafkaConsume
 import org.apache.kafka.common.TopicPartition

 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.types._
 import org.apache.spark.util.{ThreadUtils, UninterruptibleThread}

@@ -47,7 +48,7 @@ import org.apache.spark.util.{ThreadUtils, UninterruptibleThread}
 private[kafka010] class KafkaOffsetReader(
     consumerStrategy: ConsumerStrategy,
     val driverKafkaParams: ju.Map[String, Object],
-    readerOptions: Map[String, String],
+    readerOptions: CaseInsensitiveMap[String],
     driverGroupIdPrefix: String) extends Logging {
   /**
    * Used to ensure execute fetch operations execute in an UninterruptibleThread
@@ -88,10 +89,10 @@ private[kafka010] class KafkaOffsetReader(
     _consumer
   }

-  private val maxOffsetFetchAttempts =
+  private[kafka010] val maxOffsetFetchAttempts =
     readerOptions.getOrElse(KafkaSourceProvider.FETCH_OFFSET_NUM_RETRY, "3").toInt

-  private val offsetFetchAttemptIntervalMs =
+  private[kafka010] val offsetFetchAttemptIntervalMs =
     readerOptions.getOrElse(KafkaSourceProvider.FETCH_OFFSET_RETRY_INTERVAL_MS, "1000").toLong

   private def nextGroupId(): String = {
```

KafkaRelation.scala:

```
@@ -24,7 +24,7 @@ import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{Row, SQLContext}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
 import org.apache.spark.sql.sources.{BaseRelation, TableScan}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.unsafe.types.UTF8String
@@ -33,7 +33,7 @@ import org.apache.spark.unsafe.types.UTF8String
 private[kafka010] class KafkaRelation(
     override val sqlContext: SQLContext,
     strategy: ConsumerStrategy,
-    sourceOptions: Map[String, String],
+    sourceOptions: CaseInsensitiveMap[String],
     specifiedKafkaParams: Map[String, String],
     failOnDataLoss: Boolean,
     startingOffsets: KafkaOffsetRangeLimit,
```

KafkaSourceProvider.scala:

```
@@ -78,32 +78,32 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
       schema: Option[StructType],
       providerName: String,
       parameters: Map[String, String]): Source = {
-    validateStreamOptions(parameters)
+    val caseInsensitiveParameters = CaseInsensitiveMap(parameters)
+    validateStreamOptions(caseInsensitiveParameters)
     // Each running query should use its own group id. Otherwise, the query may be only assigned
     // partial data since Kafka will assign partitions to multiple consumers having the same group
     // id. Hence, we should generate a unique id for each query.
-    val uniqueGroupId = streamingUniqueGroupId(parameters, metadataPath)
+    val uniqueGroupId = streamingUniqueGroupId(caseInsensitiveParameters, metadataPath)

-    val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
     val specifiedKafkaParams = convertToSpecifiedParams(parameters)

-    val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams,
-      STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
+    val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
+      caseInsensitiveParameters, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)

     val kafkaOffsetReader = new KafkaOffsetReader(
-      strategy(caseInsensitiveParams),
+      strategy(caseInsensitiveParameters),
       kafkaParamsForDriver(specifiedKafkaParams),
-      parameters,
+      caseInsensitiveParameters,
       driverGroupIdPrefix = s"$uniqueGroupId-driver")

     new KafkaSource(
       sqlContext,
       kafkaOffsetReader,
       kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId),
-      parameters,
+      caseInsensitiveParameters,
       metadataPath,
       startingStreamOffsets,
-      failOnDataLoss(caseInsensitiveParams))
+      failOnDataLoss(caseInsensitiveParameters))
   }

   override def getTable(options: CaseInsensitiveStringMap): KafkaTable = {
@@ -119,24 +119,24 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
   override def createRelation(
       sqlContext: SQLContext,
       parameters: Map[String, String]): BaseRelation = {
-    validateBatchOptions(parameters)
-    val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
+    val caseInsensitiveParameters = CaseInsensitiveMap(parameters)
+    validateBatchOptions(caseInsensitiveParameters)
     val specifiedKafkaParams = convertToSpecifiedParams(parameters)

     val startingRelationOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
-      caseInsensitiveParams, STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit)
+      caseInsensitiveParameters, STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit)
     assert(startingRelationOffsets != LatestOffsetRangeLimit)

-    val endingRelationOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams,
-      ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
+    val endingRelationOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
+      caseInsensitiveParameters, ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
     assert(endingRelationOffsets != EarliestOffsetRangeLimit)

     new KafkaRelation(
       sqlContext,
-      strategy(caseInsensitiveParams),
-      sourceOptions = parameters,
+      strategy(caseInsensitiveParameters),
+      sourceOptions = caseInsensitiveParameters,
       specifiedKafkaParams = specifiedKafkaParams,
-      failOnDataLoss = failOnDataLoss(caseInsensitiveParams),
+      failOnDataLoss = failOnDataLoss(caseInsensitiveParameters),
       startingOffsets = startingRelationOffsets,
       endingOffsets = endingRelationOffsets)
   }
@@ -420,23 +420,22 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
     }

     override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = {
-      val parameters = options.asScala.toMap
-      validateStreamOptions(parameters)
+      val caseInsensitiveOptions = CaseInsensitiveMap(options.asScala.toMap)
+      validateStreamOptions(caseInsensitiveOptions)
       // Each running query should use its own group id. Otherwise, the query may be only assigned
       // partial data since Kafka will assign partitions to multiple consumers having the same group
       // id. Hence, we should generate a unique id for each query.
-      val uniqueGroupId = streamingUniqueGroupId(parameters, checkpointLocation)
+      val uniqueGroupId = streamingUniqueGroupId(caseInsensitiveOptions, checkpointLocation)

-      val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
-      val specifiedKafkaParams = convertToSpecifiedParams(parameters)
+      val specifiedKafkaParams = convertToSpecifiedParams(caseInsensitiveOptions)

       val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
-        caseInsensitiveParams, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
+        caseInsensitiveOptions, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)

       val kafkaOffsetReader = new KafkaOffsetReader(
-        strategy(parameters),
+        strategy(caseInsensitiveOptions),
         kafkaParamsForDriver(specifiedKafkaParams),
-        parameters,
+        caseInsensitiveOptions,
         driverGroupIdPrefix = s"$uniqueGroupId-driver")

       new KafkaMicroBatchStream(
@@ -445,32 +444,26 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
         options,
         checkpointLocation,
         startingStreamOffsets,
-        failOnDataLoss(caseInsensitiveParams))
+        failOnDataLoss(caseInsensitiveOptions))
     }

     override def toContinuousStream(checkpointLocation: String): ContinuousStream = {
-      val parameters = options.asScala.toMap
-      validateStreamOptions(parameters)
+      val caseInsensitiveOptions = CaseInsensitiveMap(options.asScala.toMap)
+      validateStreamOptions(caseInsensitiveOptions)
       // Each running query should use its own group id. Otherwise, the query may be only assigned
       // partial data since Kafka will assign partitions to multiple consumers having the same group
       // id. Hence, we should generate a unique id for each query.
-      val uniqueGroupId = streamingUniqueGroupId(parameters, checkpointLocation)
+      val uniqueGroupId = streamingUniqueGroupId(caseInsensitiveOptions, checkpointLocation)

-      val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
-      val specifiedKafkaParams =
-        parameters
-          .keySet
-          .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka."))
-          .map { k => k.drop(6).toString -> parameters(k) }
-          .toMap
+      val specifiedKafkaParams = convertToSpecifiedParams(caseInsensitiveOptions)

       val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
-        caseInsensitiveParams, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
+        caseInsensitiveOptions, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)

       val kafkaOffsetReader = new KafkaOffsetReader(
-        strategy(caseInsensitiveParams),
+        strategy(caseInsensitiveOptions),
         kafkaParamsForDriver(specifiedKafkaParams),
-        parameters,
+        caseInsensitiveOptions,
         driverGroupIdPrefix = s"$uniqueGroupId-driver")

       new KafkaContinuousStream(
@@ -479,7 +472,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
         options,
         checkpointLocation,
         startingStreamOffsets,
-        failOnDataLoss(caseInsensitiveParams))
+        failOnDataLoss(caseInsensitiveOptions))
     }
   }
 }
```

KafkaSourceProviderSuite.scala:

```
@@ -22,60 +22,92 @@ import java.util.Locale

 import scala.collection.JavaConverters._

 import org.mockito.Mockito.{mock, when}
-import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester}

 import org.apache.spark.{SparkConf, SparkEnv, SparkFunSuite}
 import org.apache.spark.sql.sources.v2.reader.Scan
 import org.apache.spark.sql.util.CaseInsensitiveStringMap

-class KafkaSourceProviderSuite extends SparkFunSuite with PrivateMethodTester {
+class KafkaSourceProviderSuite extends SparkFunSuite {

-  private val pollTimeoutMsMethod = PrivateMethod[Long]('pollTimeoutMs)
-  private val maxOffsetsPerTriggerMethod = PrivateMethod[Option[Long]]('maxOffsetsPerTrigger)
+  private val expected = "1111"

   override protected def afterEach(): Unit = {
     SparkEnv.set(null)
     super.afterEach()
   }

-  test("micro-batch mode - options should be handled as case-insensitive") {
-    def verifyFieldsInMicroBatchStream(
-        options: CaseInsensitiveStringMap,
-        expectedPollTimeoutMs: Long,
-        expectedMaxOffsetsPerTrigger: Option[Long]): Unit = {
-      // KafkaMicroBatchStream reads Spark conf from SparkEnv for default value
-      // hence we set mock SparkEnv here before creating KafkaMicroBatchStream
-      val sparkEnv = mock(classOf[SparkEnv])
-      when(sparkEnv.conf).thenReturn(new SparkConf())
-      SparkEnv.set(sparkEnv)
-
-      val scan = getKafkaDataSourceScan(options)
-      val stream = scan.toMicroBatchStream("dummy").asInstanceOf[KafkaMicroBatchStream]
-      assert(expectedPollTimeoutMs === getField(stream, pollTimeoutMsMethod))
-      assert(expectedMaxOffsetsPerTrigger === getField(stream, maxOffsetsPerTriggerMethod))
-    }
-
-    val expectedValue = 1000L
-    buildCaseInsensitiveStringMapForUpperAndLowerKey(
-      KafkaSourceProvider.CONSUMER_POLL_TIMEOUT -> expectedValue.toString,
-      KafkaSourceProvider.MAX_OFFSET_PER_TRIGGER -> expectedValue.toString)
-      .foreach(verifyFieldsInMicroBatchStream(_, expectedValue, Some(expectedValue)))
+  test("batch mode - options should be handled as case-insensitive") {
+    verifyFieldsInBatch(KafkaSourceProvider.CONSUMER_POLL_TIMEOUT, expected, batch => {
+      assert(expected.toLong === batch.pollTimeoutMs)
+    })
   }

-  test("SPARK-28142 - continuous mode - options should be handled as case-insensitive") {
-    def verifyFieldsInContinuousStream(
-        options: CaseInsensitiveStringMap,
-        expectedPollTimeoutMs: Long): Unit = {
+  test("micro-batch mode - options should be handled as case-insensitive") {
+    verifyFieldsInMicroBatchStream(KafkaSourceProvider.CONSUMER_POLL_TIMEOUT, expected, stream => {
+      assert(expected.toLong === stream.pollTimeoutMs)
+    })
+    verifyFieldsInMicroBatchStream(KafkaSourceProvider.MAX_OFFSET_PER_TRIGGER, expected, stream => {
+      assert(Some(expected.toLong) === stream.maxOffsetsPerTrigger)
+    })
+    verifyFieldsInMicroBatchStream(KafkaSourceProvider.FETCH_OFFSET_NUM_RETRY, expected, stream => {
+      assert(expected.toInt === stream.kafkaOffsetReader.maxOffsetFetchAttempts)
+    })
+    verifyFieldsInMicroBatchStream(KafkaSourceProvider.FETCH_OFFSET_RETRY_INTERVAL_MS, expected,
+      stream => {
+        assert(expected.toLong === stream.kafkaOffsetReader.offsetFetchAttemptIntervalMs)
+      })
+  }
+
+  test("continuous mode - options should be handled as case-insensitive") {
+    verifyFieldsInContinuousStream(KafkaSourceProvider.CONSUMER_POLL_TIMEOUT, expected, stream => {
+      assert(expected.toLong === stream.pollTimeoutMs)
+    })
+    verifyFieldsInContinuousStream(KafkaSourceProvider.FETCH_OFFSET_NUM_RETRY, expected, stream => {
+      assert(expected.toInt === stream.offsetReader.maxOffsetFetchAttempts)
+    })
+    verifyFieldsInContinuousStream(KafkaSourceProvider.FETCH_OFFSET_RETRY_INTERVAL_MS, expected,
+      stream => {
+        assert(expected.toLong === stream.offsetReader.offsetFetchAttemptIntervalMs)
+      })
+  }
+
+  private def verifyFieldsInBatch(
+      key: String,
+      value: String,
+      validate: (KafkaBatch) => Unit): Unit = {
+    buildCaseInsensitiveStringMapForUpperAndLowerKey(key -> value).foreach { options =>
+      val scan = getKafkaDataSourceScan(options)
+      val batch = scan.toBatch().asInstanceOf[KafkaBatch]
+      validate(batch)
+    }
+  }
+
+  private def verifyFieldsInMicroBatchStream(
+      key: String,
+      value: String,
+      validate: (KafkaMicroBatchStream) => Unit): Unit = {
+    // KafkaMicroBatchStream reads Spark conf from SparkEnv for default value
+    // hence we set mock SparkEnv here before creating KafkaMicroBatchStream
+    val sparkEnv = mock(classOf[SparkEnv])
+    when(sparkEnv.conf).thenReturn(new SparkConf())
+    SparkEnv.set(sparkEnv)
+
+    buildCaseInsensitiveStringMapForUpperAndLowerKey(key -> value).foreach { options =>
+      val scan = getKafkaDataSourceScan(options)
+      val stream = scan.toMicroBatchStream("dummy").asInstanceOf[KafkaMicroBatchStream]
+      validate(stream)
+    }
+  }
+
+  private def verifyFieldsInContinuousStream(
+      key: String,
+      value: String,
+      validate: (KafkaContinuousStream) => Unit): Unit = {
+    buildCaseInsensitiveStringMapForUpperAndLowerKey(key -> value).foreach { options =>
       val scan = getKafkaDataSourceScan(options)
       val stream = scan.toContinuousStream("dummy").asInstanceOf[KafkaContinuousStream]
-      assert(expectedPollTimeoutMs === getField(stream, pollTimeoutMsMethod))
+      validate(stream)
     }
-
-    val expectedValue = 1000
-    buildCaseInsensitiveStringMapForUpperAndLowerKey(
-      KafkaSourceProvider.CONSUMER_POLL_TIMEOUT -> expectedValue.toString)
-      .foreach(verifyFieldsInContinuousStream(_, expectedValue))
   }

   private def buildCaseInsensitiveStringMapForUpperAndLowerKey(
@@ -95,8 +127,4 @@ class KafkaSourceProviderSuite extends SparkFunSuite with PrivateMethodTester {
     val provider = new KafkaSourceProvider()
     provider.getTable(options).newScanBuilder(options).build()
   }
-
-  private def getField[T](obj: AnyRef, method: PrivateMethod[T]): T = {
-    obj.invokePrivate(method())
-  }
 }
```
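
The `buildCaseInsensitiveStringMapForUpperAndLowerKey` helper referenced above sits outside the shown hunks. As a rough sketch of the idea only (the body below is an assumption, including the placeholder `kafka.bootstrap.servers`/`subscribe` values, and it relies on the suite's existing `java.util.Locale` and `JavaConverters` imports): each option is exercised once with upper-cased and once with lower-cased keys, so a test only passes if the lookup really is case-insensitive.

```
  // Sketch only; the real helper is not part of the diff shown above.
  private def buildCaseInsensitiveStringMapForUpperAndLowerKey(
      options: (String, String)*): Seq[CaseInsensitiveStringMap] = {
    // Minimal options a Kafka source needs so that getTable/newScanBuilder succeed
    // (placeholder values, assumed here).
    val required = Map("kafka.bootstrap.servers" -> "dummy", "subscribe" -> "dummy")
    Seq(
      options.map { case (k, v) => k.toUpperCase(Locale.ROOT) -> v },
      options.map { case (k, v) => k.toLowerCase(Locale.ROOT) -> v }
    ).map(opts => new CaseInsensitiveStringMap((opts.toMap ++ required).asJava))
  }
```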