[SPARK-28163][SS] Use CaseInsensitiveMap for KafkaOffsetReader
## What changes were proposed in this pull request? There are "unsafe" conversions in the Kafka connector. A `CaseInsensitiveStringMap` comes in and is then converted in the following way: ``` ... options.asScala.toMap ... ``` The main problem with this is that in such a case the map loses its case-insensitive nature (a case-insensitive map converts the key to lower case when get/contains is called). In this PR I'm using `CaseInsensitiveMap` to solve this problem. ## How was this patch tested? Existing + additional unit tests. Closes #24967 from gaborgsomogyi/SPARK-28163. Authored-by: Gabor Somogyi <gabor.g.somogyi@gmail.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
5368eaa2fc
commit
5663386f4b
|
@ -22,12 +22,13 @@ import org.apache.kafka.common.TopicPartition
|
||||||
import org.apache.spark.SparkEnv
|
import org.apache.spark.SparkEnv
|
||||||
import org.apache.spark.internal.Logging
|
import org.apache.spark.internal.Logging
|
||||||
import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT
|
import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT
|
||||||
|
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
|
||||||
import org.apache.spark.sql.sources.v2.reader.{Batch, InputPartition, PartitionReaderFactory}
|
import org.apache.spark.sql.sources.v2.reader.{Batch, InputPartition, PartitionReaderFactory}
|
||||||
|
|
||||||
|
|
||||||
private[kafka010] class KafkaBatch(
|
private[kafka010] class KafkaBatch(
|
||||||
strategy: ConsumerStrategy,
|
strategy: ConsumerStrategy,
|
||||||
sourceOptions: Map[String, String],
|
sourceOptions: CaseInsensitiveMap[String],
|
||||||
specifiedKafkaParams: Map[String, String],
|
specifiedKafkaParams: Map[String, String],
|
||||||
failOnDataLoss: Boolean,
|
failOnDataLoss: Boolean,
|
||||||
startingOffsets: KafkaOffsetRangeLimit,
|
startingOffsets: KafkaOffsetRangeLimit,
|
||||||
|
@ -38,7 +39,7 @@ private[kafka010] class KafkaBatch(
|
||||||
assert(endingOffsets != EarliestOffsetRangeLimit,
|
assert(endingOffsets != EarliestOffsetRangeLimit,
|
||||||
"Ending offset not allowed to be set to earliest offsets.")
|
"Ending offset not allowed to be set to earliest offsets.")
|
||||||
|
|
||||||
private val pollTimeoutMs = sourceOptions.getOrElse(
|
private[kafka010] val pollTimeoutMs = sourceOptions.getOrElse(
|
||||||
KafkaSourceProvider.CONSUMER_POLL_TIMEOUT,
|
KafkaSourceProvider.CONSUMER_POLL_TIMEOUT,
|
||||||
(SparkEnv.get.conf.get(NETWORK_TIMEOUT) * 1000L).toString
|
(SparkEnv.get.conf.get(NETWORK_TIMEOUT) * 1000L).toString
|
||||||
).toLong
|
).toLong
|
||||||
|
|
|
@ -46,7 +46,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
|
||||||
* properly read.
|
* properly read.
|
||||||
*/
|
*/
|
||||||
class KafkaContinuousStream(
|
class KafkaContinuousStream(
|
||||||
offsetReader: KafkaOffsetReader,
|
private[kafka010] val offsetReader: KafkaOffsetReader,
|
||||||
kafkaParams: ju.Map[String, Object],
|
kafkaParams: ju.Map[String, Object],
|
||||||
options: CaseInsensitiveStringMap,
|
options: CaseInsensitiveStringMap,
|
||||||
metadataPath: String,
|
metadataPath: String,
|
||||||
|
@ -54,7 +54,7 @@ class KafkaContinuousStream(
|
||||||
failOnDataLoss: Boolean)
|
failOnDataLoss: Boolean)
|
||||||
extends ContinuousStream with Logging {
|
extends ContinuousStream with Logging {
|
||||||
|
|
||||||
private val pollTimeoutMs =
|
private[kafka010] val pollTimeoutMs =
|
||||||
options.getLong(KafkaSourceProvider.CONSUMER_POLL_TIMEOUT, 512)
|
options.getLong(KafkaSourceProvider.CONSUMER_POLL_TIMEOUT, 512)
|
||||||
|
|
||||||
// Initialized when creating reader factories. If this diverges from the partitions at the latest
|
// Initialized when creating reader factories. If this diverges from the partitions at the latest
|
||||||
|
|
|
@ -56,19 +56,19 @@ import org.apache.spark.util.UninterruptibleThread
|
||||||
* and not use wrong broker addresses.
|
* and not use wrong broker addresses.
|
||||||
*/
|
*/
|
||||||
private[kafka010] class KafkaMicroBatchStream(
|
private[kafka010] class KafkaMicroBatchStream(
|
||||||
kafkaOffsetReader: KafkaOffsetReader,
|
private[kafka010] val kafkaOffsetReader: KafkaOffsetReader,
|
||||||
executorKafkaParams: ju.Map[String, Object],
|
executorKafkaParams: ju.Map[String, Object],
|
||||||
options: CaseInsensitiveStringMap,
|
options: CaseInsensitiveStringMap,
|
||||||
metadataPath: String,
|
metadataPath: String,
|
||||||
startingOffsets: KafkaOffsetRangeLimit,
|
startingOffsets: KafkaOffsetRangeLimit,
|
||||||
failOnDataLoss: Boolean) extends RateControlMicroBatchStream with Logging {
|
failOnDataLoss: Boolean) extends RateControlMicroBatchStream with Logging {
|
||||||
|
|
||||||
private val pollTimeoutMs = options.getLong(
|
private[kafka010] val pollTimeoutMs = options.getLong(
|
||||||
KafkaSourceProvider.CONSUMER_POLL_TIMEOUT,
|
KafkaSourceProvider.CONSUMER_POLL_TIMEOUT,
|
||||||
SparkEnv.get.conf.get(NETWORK_TIMEOUT) * 1000L)
|
SparkEnv.get.conf.get(NETWORK_TIMEOUT) * 1000L)
|
||||||
|
|
||||||
private val maxOffsetsPerTrigger = Option(options.get(KafkaSourceProvider.MAX_OFFSET_PER_TRIGGER))
|
private[kafka010] val maxOffsetsPerTrigger = Option(options.get(
|
||||||
.map(_.toLong)
|
KafkaSourceProvider.MAX_OFFSET_PER_TRIGGER)).map(_.toLong)
|
||||||
|
|
||||||
private val rangeCalculator = KafkaOffsetRangeCalculator(options)
|
private val rangeCalculator = KafkaOffsetRangeCalculator(options)
|
||||||
|
|
||||||
|
|
|
@ -30,6 +30,7 @@ import org.apache.kafka.clients.consumer.{Consumer, ConsumerConfig, KafkaConsume
|
||||||
import org.apache.kafka.common.TopicPartition
|
import org.apache.kafka.common.TopicPartition
|
||||||
|
|
||||||
import org.apache.spark.internal.Logging
|
import org.apache.spark.internal.Logging
|
||||||
|
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
|
||||||
import org.apache.spark.sql.types._
|
import org.apache.spark.sql.types._
|
||||||
import org.apache.spark.util.{ThreadUtils, UninterruptibleThread}
|
import org.apache.spark.util.{ThreadUtils, UninterruptibleThread}
|
||||||
|
|
||||||
|
@ -47,7 +48,7 @@ import org.apache.spark.util.{ThreadUtils, UninterruptibleThread}
|
||||||
private[kafka010] class KafkaOffsetReader(
|
private[kafka010] class KafkaOffsetReader(
|
||||||
consumerStrategy: ConsumerStrategy,
|
consumerStrategy: ConsumerStrategy,
|
||||||
val driverKafkaParams: ju.Map[String, Object],
|
val driverKafkaParams: ju.Map[String, Object],
|
||||||
readerOptions: Map[String, String],
|
readerOptions: CaseInsensitiveMap[String],
|
||||||
driverGroupIdPrefix: String) extends Logging {
|
driverGroupIdPrefix: String) extends Logging {
|
||||||
/**
|
/**
|
||||||
* Used to ensure execute fetch operations execute in an UninterruptibleThread
|
* Used to ensure execute fetch operations execute in an UninterruptibleThread
|
||||||
|
@ -88,10 +89,10 @@ private[kafka010] class KafkaOffsetReader(
|
||||||
_consumer
|
_consumer
|
||||||
}
|
}
|
||||||
|
|
||||||
private val maxOffsetFetchAttempts =
|
private[kafka010] val maxOffsetFetchAttempts =
|
||||||
readerOptions.getOrElse(KafkaSourceProvider.FETCH_OFFSET_NUM_RETRY, "3").toInt
|
readerOptions.getOrElse(KafkaSourceProvider.FETCH_OFFSET_NUM_RETRY, "3").toInt
|
||||||
|
|
||||||
private val offsetFetchAttemptIntervalMs =
|
private[kafka010] val offsetFetchAttemptIntervalMs =
|
||||||
readerOptions.getOrElse(KafkaSourceProvider.FETCH_OFFSET_RETRY_INTERVAL_MS, "1000").toLong
|
readerOptions.getOrElse(KafkaSourceProvider.FETCH_OFFSET_RETRY_INTERVAL_MS, "1000").toLong
|
||||||
|
|
||||||
private def nextGroupId(): String = {
|
private def nextGroupId(): String = {
|
||||||
|
|
|
@ -24,7 +24,7 @@ import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.sql.{Row, SQLContext}
|
import org.apache.spark.sql.{Row, SQLContext}
|
||||||
import org.apache.spark.sql.catalyst.InternalRow
|
import org.apache.spark.sql.catalyst.InternalRow
|
||||||
import org.apache.spark.sql.catalyst.util.DateTimeUtils
|
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
|
||||||
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
|
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
|
||||||
import org.apache.spark.sql.types.StructType
|
import org.apache.spark.sql.types.StructType
|
||||||
import org.apache.spark.unsafe.types.UTF8String
|
import org.apache.spark.unsafe.types.UTF8String
|
||||||
|
@ -33,7 +33,7 @@ import org.apache.spark.unsafe.types.UTF8String
|
||||||
private[kafka010] class KafkaRelation(
|
private[kafka010] class KafkaRelation(
|
||||||
override val sqlContext: SQLContext,
|
override val sqlContext: SQLContext,
|
||||||
strategy: ConsumerStrategy,
|
strategy: ConsumerStrategy,
|
||||||
sourceOptions: Map[String, String],
|
sourceOptions: CaseInsensitiveMap[String],
|
||||||
specifiedKafkaParams: Map[String, String],
|
specifiedKafkaParams: Map[String, String],
|
||||||
failOnDataLoss: Boolean,
|
failOnDataLoss: Boolean,
|
||||||
startingOffsets: KafkaOffsetRangeLimit,
|
startingOffsets: KafkaOffsetRangeLimit,
|
||||||
|
|
|
@ -78,32 +78,32 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
|
||||||
schema: Option[StructType],
|
schema: Option[StructType],
|
||||||
providerName: String,
|
providerName: String,
|
||||||
parameters: Map[String, String]): Source = {
|
parameters: Map[String, String]): Source = {
|
||||||
validateStreamOptions(parameters)
|
val caseInsensitiveParameters = CaseInsensitiveMap(parameters)
|
||||||
|
validateStreamOptions(caseInsensitiveParameters)
|
||||||
// Each running query should use its own group id. Otherwise, the query may be only assigned
|
// Each running query should use its own group id. Otherwise, the query may be only assigned
|
||||||
// partial data since Kafka will assign partitions to multiple consumers having the same group
|
// partial data since Kafka will assign partitions to multiple consumers having the same group
|
||||||
// id. Hence, we should generate a unique id for each query.
|
// id. Hence, we should generate a unique id for each query.
|
||||||
val uniqueGroupId = streamingUniqueGroupId(parameters, metadataPath)
|
val uniqueGroupId = streamingUniqueGroupId(caseInsensitiveParameters, metadataPath)
|
||||||
|
|
||||||
val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
|
|
||||||
val specifiedKafkaParams = convertToSpecifiedParams(parameters)
|
val specifiedKafkaParams = convertToSpecifiedParams(parameters)
|
||||||
|
|
||||||
val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams,
|
val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
|
||||||
STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
|
caseInsensitiveParameters, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
|
||||||
|
|
||||||
val kafkaOffsetReader = new KafkaOffsetReader(
|
val kafkaOffsetReader = new KafkaOffsetReader(
|
||||||
strategy(caseInsensitiveParams),
|
strategy(caseInsensitiveParameters),
|
||||||
kafkaParamsForDriver(specifiedKafkaParams),
|
kafkaParamsForDriver(specifiedKafkaParams),
|
||||||
parameters,
|
caseInsensitiveParameters,
|
||||||
driverGroupIdPrefix = s"$uniqueGroupId-driver")
|
driverGroupIdPrefix = s"$uniqueGroupId-driver")
|
||||||
|
|
||||||
new KafkaSource(
|
new KafkaSource(
|
||||||
sqlContext,
|
sqlContext,
|
||||||
kafkaOffsetReader,
|
kafkaOffsetReader,
|
||||||
kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId),
|
kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId),
|
||||||
parameters,
|
caseInsensitiveParameters,
|
||||||
metadataPath,
|
metadataPath,
|
||||||
startingStreamOffsets,
|
startingStreamOffsets,
|
||||||
failOnDataLoss(caseInsensitiveParams))
|
failOnDataLoss(caseInsensitiveParameters))
|
||||||
}
|
}
|
||||||
|
|
||||||
override def getTable(options: CaseInsensitiveStringMap): KafkaTable = {
|
override def getTable(options: CaseInsensitiveStringMap): KafkaTable = {
|
||||||
|
@ -119,24 +119,24 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
|
||||||
override def createRelation(
|
override def createRelation(
|
||||||
sqlContext: SQLContext,
|
sqlContext: SQLContext,
|
||||||
parameters: Map[String, String]): BaseRelation = {
|
parameters: Map[String, String]): BaseRelation = {
|
||||||
validateBatchOptions(parameters)
|
val caseInsensitiveParameters = CaseInsensitiveMap(parameters)
|
||||||
val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
|
validateBatchOptions(caseInsensitiveParameters)
|
||||||
val specifiedKafkaParams = convertToSpecifiedParams(parameters)
|
val specifiedKafkaParams = convertToSpecifiedParams(parameters)
|
||||||
|
|
||||||
val startingRelationOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
|
val startingRelationOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
|
||||||
caseInsensitiveParams, STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit)
|
caseInsensitiveParameters, STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit)
|
||||||
assert(startingRelationOffsets != LatestOffsetRangeLimit)
|
assert(startingRelationOffsets != LatestOffsetRangeLimit)
|
||||||
|
|
||||||
val endingRelationOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams,
|
val endingRelationOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
|
||||||
ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
|
caseInsensitiveParameters, ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
|
||||||
assert(endingRelationOffsets != EarliestOffsetRangeLimit)
|
assert(endingRelationOffsets != EarliestOffsetRangeLimit)
|
||||||
|
|
||||||
new KafkaRelation(
|
new KafkaRelation(
|
||||||
sqlContext,
|
sqlContext,
|
||||||
strategy(caseInsensitiveParams),
|
strategy(caseInsensitiveParameters),
|
||||||
sourceOptions = parameters,
|
sourceOptions = caseInsensitiveParameters,
|
||||||
specifiedKafkaParams = specifiedKafkaParams,
|
specifiedKafkaParams = specifiedKafkaParams,
|
||||||
failOnDataLoss = failOnDataLoss(caseInsensitiveParams),
|
failOnDataLoss = failOnDataLoss(caseInsensitiveParameters),
|
||||||
startingOffsets = startingRelationOffsets,
|
startingOffsets = startingRelationOffsets,
|
||||||
endingOffsets = endingRelationOffsets)
|
endingOffsets = endingRelationOffsets)
|
||||||
}
|
}
|
||||||
|
@ -420,23 +420,22 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
|
||||||
}
|
}
|
||||||
|
|
||||||
override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = {
|
override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = {
|
||||||
val parameters = options.asScala.toMap
|
val caseInsensitiveOptions = CaseInsensitiveMap(options.asScala.toMap)
|
||||||
validateStreamOptions(parameters)
|
validateStreamOptions(caseInsensitiveOptions)
|
||||||
// Each running query should use its own group id. Otherwise, the query may be only assigned
|
// Each running query should use its own group id. Otherwise, the query may be only assigned
|
||||||
// partial data since Kafka will assign partitions to multiple consumers having the same group
|
// partial data since Kafka will assign partitions to multiple consumers having the same group
|
||||||
// id. Hence, we should generate a unique id for each query.
|
// id. Hence, we should generate a unique id for each query.
|
||||||
val uniqueGroupId = streamingUniqueGroupId(parameters, checkpointLocation)
|
val uniqueGroupId = streamingUniqueGroupId(caseInsensitiveOptions, checkpointLocation)
|
||||||
|
|
||||||
val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
|
val specifiedKafkaParams = convertToSpecifiedParams(caseInsensitiveOptions)
|
||||||
val specifiedKafkaParams = convertToSpecifiedParams(parameters)
|
|
||||||
|
|
||||||
val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
|
val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
|
||||||
caseInsensitiveParams, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
|
caseInsensitiveOptions, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
|
||||||
|
|
||||||
val kafkaOffsetReader = new KafkaOffsetReader(
|
val kafkaOffsetReader = new KafkaOffsetReader(
|
||||||
strategy(parameters),
|
strategy(caseInsensitiveOptions),
|
||||||
kafkaParamsForDriver(specifiedKafkaParams),
|
kafkaParamsForDriver(specifiedKafkaParams),
|
||||||
parameters,
|
caseInsensitiveOptions,
|
||||||
driverGroupIdPrefix = s"$uniqueGroupId-driver")
|
driverGroupIdPrefix = s"$uniqueGroupId-driver")
|
||||||
|
|
||||||
new KafkaMicroBatchStream(
|
new KafkaMicroBatchStream(
|
||||||
|
@ -445,32 +444,26 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
|
||||||
options,
|
options,
|
||||||
checkpointLocation,
|
checkpointLocation,
|
||||||
startingStreamOffsets,
|
startingStreamOffsets,
|
||||||
failOnDataLoss(caseInsensitiveParams))
|
failOnDataLoss(caseInsensitiveOptions))
|
||||||
}
|
}
|
||||||
|
|
||||||
override def toContinuousStream(checkpointLocation: String): ContinuousStream = {
|
override def toContinuousStream(checkpointLocation: String): ContinuousStream = {
|
||||||
val parameters = options.asScala.toMap
|
val caseInsensitiveOptions = CaseInsensitiveMap(options.asScala.toMap)
|
||||||
validateStreamOptions(parameters)
|
validateStreamOptions(caseInsensitiveOptions)
|
||||||
// Each running query should use its own group id. Otherwise, the query may be only assigned
|
// Each running query should use its own group id. Otherwise, the query may be only assigned
|
||||||
// partial data since Kafka will assign partitions to multiple consumers having the same group
|
// partial data since Kafka will assign partitions to multiple consumers having the same group
|
||||||
// id. Hence, we should generate a unique id for each query.
|
// id. Hence, we should generate a unique id for each query.
|
||||||
val uniqueGroupId = streamingUniqueGroupId(parameters, checkpointLocation)
|
val uniqueGroupId = streamingUniqueGroupId(caseInsensitiveOptions, checkpointLocation)
|
||||||
|
|
||||||
val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
|
val specifiedKafkaParams = convertToSpecifiedParams(caseInsensitiveOptions)
|
||||||
val specifiedKafkaParams =
|
|
||||||
parameters
|
|
||||||
.keySet
|
|
||||||
.filter(_.toLowerCase(Locale.ROOT).startsWith("kafka."))
|
|
||||||
.map { k => k.drop(6).toString -> parameters(k) }
|
|
||||||
.toMap
|
|
||||||
|
|
||||||
val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
|
val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
|
||||||
caseInsensitiveParams, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
|
caseInsensitiveOptions, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
|
||||||
|
|
||||||
val kafkaOffsetReader = new KafkaOffsetReader(
|
val kafkaOffsetReader = new KafkaOffsetReader(
|
||||||
strategy(caseInsensitiveParams),
|
strategy(caseInsensitiveOptions),
|
||||||
kafkaParamsForDriver(specifiedKafkaParams),
|
kafkaParamsForDriver(specifiedKafkaParams),
|
||||||
parameters,
|
caseInsensitiveOptions,
|
||||||
driverGroupIdPrefix = s"$uniqueGroupId-driver")
|
driverGroupIdPrefix = s"$uniqueGroupId-driver")
|
||||||
|
|
||||||
new KafkaContinuousStream(
|
new KafkaContinuousStream(
|
||||||
|
@ -479,7 +472,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
|
||||||
options,
|
options,
|
||||||
checkpointLocation,
|
checkpointLocation,
|
||||||
startingStreamOffsets,
|
startingStreamOffsets,
|
||||||
failOnDataLoss(caseInsensitiveParams))
|
failOnDataLoss(caseInsensitiveOptions))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,60 +22,92 @@ import java.util.Locale
|
||||||
import scala.collection.JavaConverters._
|
import scala.collection.JavaConverters._
|
||||||
|
|
||||||
import org.mockito.Mockito.{mock, when}
|
import org.mockito.Mockito.{mock, when}
|
||||||
import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester}
|
|
||||||
|
|
||||||
import org.apache.spark.{SparkConf, SparkEnv, SparkFunSuite}
|
import org.apache.spark.{SparkConf, SparkEnv, SparkFunSuite}
|
||||||
import org.apache.spark.sql.sources.v2.reader.Scan
|
import org.apache.spark.sql.sources.v2.reader.Scan
|
||||||
import org.apache.spark.sql.util.CaseInsensitiveStringMap
|
import org.apache.spark.sql.util.CaseInsensitiveStringMap
|
||||||
|
|
||||||
class KafkaSourceProviderSuite extends SparkFunSuite with PrivateMethodTester {
|
class KafkaSourceProviderSuite extends SparkFunSuite {
|
||||||
|
|
||||||
private val pollTimeoutMsMethod = PrivateMethod[Long]('pollTimeoutMs)
|
private val expected = "1111"
|
||||||
private val maxOffsetsPerTriggerMethod = PrivateMethod[Option[Long]]('maxOffsetsPerTrigger)
|
|
||||||
|
|
||||||
override protected def afterEach(): Unit = {
|
override protected def afterEach(): Unit = {
|
||||||
SparkEnv.set(null)
|
SparkEnv.set(null)
|
||||||
super.afterEach()
|
super.afterEach()
|
||||||
}
|
}
|
||||||
|
|
||||||
test("micro-batch mode - options should be handled as case-insensitive") {
|
test("batch mode - options should be handled as case-insensitive") {
|
||||||
def verifyFieldsInMicroBatchStream(
|
verifyFieldsInBatch(KafkaSourceProvider.CONSUMER_POLL_TIMEOUT, expected, batch => {
|
||||||
options: CaseInsensitiveStringMap,
|
assert(expected.toLong === batch.pollTimeoutMs)
|
||||||
expectedPollTimeoutMs: Long,
|
})
|
||||||
expectedMaxOffsetsPerTrigger: Option[Long]): Unit = {
|
|
||||||
// KafkaMicroBatchStream reads Spark conf from SparkEnv for default value
|
|
||||||
// hence we set mock SparkEnv here before creating KafkaMicroBatchStream
|
|
||||||
val sparkEnv = mock(classOf[SparkEnv])
|
|
||||||
when(sparkEnv.conf).thenReturn(new SparkConf())
|
|
||||||
SparkEnv.set(sparkEnv)
|
|
||||||
|
|
||||||
val scan = getKafkaDataSourceScan(options)
|
|
||||||
val stream = scan.toMicroBatchStream("dummy").asInstanceOf[KafkaMicroBatchStream]
|
|
||||||
|
|
||||||
assert(expectedPollTimeoutMs === getField(stream, pollTimeoutMsMethod))
|
|
||||||
assert(expectedMaxOffsetsPerTrigger === getField(stream, maxOffsetsPerTriggerMethod))
|
|
||||||
}
|
|
||||||
|
|
||||||
val expectedValue = 1000L
|
|
||||||
buildCaseInsensitiveStringMapForUpperAndLowerKey(
|
|
||||||
KafkaSourceProvider.CONSUMER_POLL_TIMEOUT -> expectedValue.toString,
|
|
||||||
KafkaSourceProvider.MAX_OFFSET_PER_TRIGGER -> expectedValue.toString)
|
|
||||||
.foreach(verifyFieldsInMicroBatchStream(_, expectedValue, Some(expectedValue)))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
test("SPARK-28142 - continuous mode - options should be handled as case-insensitive") {
|
test("micro-batch mode - options should be handled as case-insensitive") {
|
||||||
def verifyFieldsInContinuousStream(
|
verifyFieldsInMicroBatchStream(KafkaSourceProvider.CONSUMER_POLL_TIMEOUT, expected, stream => {
|
||||||
options: CaseInsensitiveStringMap,
|
assert(expected.toLong === stream.pollTimeoutMs)
|
||||||
expectedPollTimeoutMs: Long): Unit = {
|
})
|
||||||
|
verifyFieldsInMicroBatchStream(KafkaSourceProvider.MAX_OFFSET_PER_TRIGGER, expected, stream => {
|
||||||
|
assert(Some(expected.toLong) === stream.maxOffsetsPerTrigger)
|
||||||
|
})
|
||||||
|
verifyFieldsInMicroBatchStream(KafkaSourceProvider.FETCH_OFFSET_NUM_RETRY, expected, stream => {
|
||||||
|
assert(expected.toInt === stream.kafkaOffsetReader.maxOffsetFetchAttempts)
|
||||||
|
})
|
||||||
|
verifyFieldsInMicroBatchStream(KafkaSourceProvider.FETCH_OFFSET_RETRY_INTERVAL_MS, expected,
|
||||||
|
stream => {
|
||||||
|
assert(expected.toLong === stream.kafkaOffsetReader.offsetFetchAttemptIntervalMs)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
test("continuous mode - options should be handled as case-insensitive") {
|
||||||
|
verifyFieldsInContinuousStream(KafkaSourceProvider.CONSUMER_POLL_TIMEOUT, expected, stream => {
|
||||||
|
assert(expected.toLong === stream.pollTimeoutMs)
|
||||||
|
})
|
||||||
|
verifyFieldsInContinuousStream(KafkaSourceProvider.FETCH_OFFSET_NUM_RETRY, expected, stream => {
|
||||||
|
assert(expected.toInt === stream.offsetReader.maxOffsetFetchAttempts)
|
||||||
|
})
|
||||||
|
verifyFieldsInContinuousStream(KafkaSourceProvider.FETCH_OFFSET_RETRY_INTERVAL_MS, expected,
|
||||||
|
stream => {
|
||||||
|
assert(expected.toLong === stream.offsetReader.offsetFetchAttemptIntervalMs)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
private def verifyFieldsInBatch(
|
||||||
|
key: String,
|
||||||
|
value: String,
|
||||||
|
validate: (KafkaBatch) => Unit): Unit = {
|
||||||
|
buildCaseInsensitiveStringMapForUpperAndLowerKey(key -> value).foreach { options =>
|
||||||
|
val scan = getKafkaDataSourceScan(options)
|
||||||
|
val batch = scan.toBatch().asInstanceOf[KafkaBatch]
|
||||||
|
validate(batch)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private def verifyFieldsInMicroBatchStream(
|
||||||
|
key: String,
|
||||||
|
value: String,
|
||||||
|
validate: (KafkaMicroBatchStream) => Unit): Unit = {
|
||||||
|
// KafkaMicroBatchStream reads Spark conf from SparkEnv for default value
|
||||||
|
// hence we set mock SparkEnv here before creating KafkaMicroBatchStream
|
||||||
|
val sparkEnv = mock(classOf[SparkEnv])
|
||||||
|
when(sparkEnv.conf).thenReturn(new SparkConf())
|
||||||
|
SparkEnv.set(sparkEnv)
|
||||||
|
|
||||||
|
buildCaseInsensitiveStringMapForUpperAndLowerKey(key -> value).foreach { options =>
|
||||||
|
val scan = getKafkaDataSourceScan(options)
|
||||||
|
val stream = scan.toMicroBatchStream("dummy").asInstanceOf[KafkaMicroBatchStream]
|
||||||
|
validate(stream)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private def verifyFieldsInContinuousStream(
|
||||||
|
key: String,
|
||||||
|
value: String,
|
||||||
|
validate: (KafkaContinuousStream) => Unit): Unit = {
|
||||||
|
buildCaseInsensitiveStringMapForUpperAndLowerKey(key -> value).foreach { options =>
|
||||||
val scan = getKafkaDataSourceScan(options)
|
val scan = getKafkaDataSourceScan(options)
|
||||||
val stream = scan.toContinuousStream("dummy").asInstanceOf[KafkaContinuousStream]
|
val stream = scan.toContinuousStream("dummy").asInstanceOf[KafkaContinuousStream]
|
||||||
assert(expectedPollTimeoutMs === getField(stream, pollTimeoutMsMethod))
|
validate(stream)
|
||||||
}
|
}
|
||||||
|
|
||||||
val expectedValue = 1000
|
|
||||||
buildCaseInsensitiveStringMapForUpperAndLowerKey(
|
|
||||||
KafkaSourceProvider.CONSUMER_POLL_TIMEOUT -> expectedValue.toString)
|
|
||||||
.foreach(verifyFieldsInContinuousStream(_, expectedValue))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private def buildCaseInsensitiveStringMapForUpperAndLowerKey(
|
private def buildCaseInsensitiveStringMapForUpperAndLowerKey(
|
||||||
|
@ -95,8 +127,4 @@ class KafkaSourceProviderSuite extends SparkFunSuite with PrivateMethodTester {
|
||||||
val provider = new KafkaSourceProvider()
|
val provider = new KafkaSourceProvider()
|
||||||
provider.getTable(options).newScanBuilder(options).build()
|
provider.getTable(options).newScanBuilder(options).build()
|
||||||
}
|
}
|
||||||
|
|
||||||
private def getField[T](obj: AnyRef, method: PrivateMethod[T]): T = {
|
|
||||||
obj.invokePrivate(method())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue