[SPARK-28695][SS] Use CaseInsensitiveMap in KafkaSourceProvider to make source param handling more robust

## What changes were proposed in this pull request?

[SPARK-28163](https://issues.apache.org/jira/browse/SPARK-28163) fixed a bug, and during the analysis we concluded it would be more robust to use `CaseInsensitiveMap` inside the Kafka source. This way fewer lower/upper case problems can arise in the future.

Please note this PR doesn't intend to solve any actual problem but to finish the concept introduced in [SPARK-28163](https://issues.apache.org/jira/browse/SPARK-28163) (I didn't want to make overly invasive changes in a bug-fix PR). In this PR I've changed `Map[String, String]` to `CaseInsensitiveMap[String]` to enforce its usage. These are the main use-cases (see the sketch after the list):
* `contains` => `CaseInsensitiveMap` solves it
* `get...` => `CaseInsensitiveMap` solves it
* `filter` => keys must be converted to lowercase because there is no guarantee that the keys of the incoming map are lowercase
* `find` => keys must be converted to lowercase because there is no guarantee that the keys of the incoming map are lowercase
* passing parameters to the Kafka consumer/producer => keys must be converted to lowercase because there is no guarantee that the keys of the incoming map are lowercase
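
To make these use-cases concrete, here is a minimal standalone sketch (not part of the patch): it assumes `spark-catalyst` is on the classpath, and the local `strategyKeys` set merely stands in for `STRATEGY_OPTION_KEYS`. Lookups go through the case-insensitive map directly, while anything that iterates over the keys normalizes them first, mirroring what the changed code does.

```scala
import java.util.Locale

import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap

object CaseInsensitiveParamsSketch {
  def main(args: Array[String]): Unit = {
    // Options as a user might pass them, with arbitrary casing.
    val params = CaseInsensitiveMap(Map("Subscribe" -> "topic-A", "maxOffsetsPerTrigger" -> "100"))

    // `contains` / `get` are case-insensitive out of the box.
    assert(params.contains("subscribe"))
    assert(params.get("maxoffsetspertrigger").contains("100"))

    // Before `filter` / `find` (and before passing keys on to the Kafka consumer/producer),
    // the keys are normalized explicitly, since the incoming key casing is not guaranteed.
    val strategyKeys = Set("assign", "subscribe", "subscribepattern")
    val lowercaseParams = params.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
    val strategy = lowercaseParams.find { case (k, _) => strategyKeys.contains(k) }
    assert(strategy.contains("subscribe" -> "topic-A"))
  }
}
```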

## How was this patch tested?

Existing unit tests.

Closes #25418 from gaborgsomogyi/SPARK-28695.

Authored-by: Gabor Somogyi <gabor.g.somogyi@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Authored by Gabor Somogyi on 2019-08-15 14:43:52 +08:00, committed by Wenchen Fan
commit a493031e2e (parent 0526529b31)
3 changed files with 71 additions and 64 deletions

```diff
@@ -31,7 +31,7 @@ import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT
 import org.apache.spark.scheduler.ExecutorCacheTaskLocation
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.kafka010.KafkaSource._
 import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
@@ -76,7 +76,7 @@ private[kafka010] class KafkaSource(
     sqlContext: SQLContext,
     kafkaReader: KafkaOffsetReader,
     executorKafkaParams: ju.Map[String, Object],
-    sourceOptions: Map[String, String],
+    sourceOptions: CaseInsensitiveMap[String],
     metadataPath: String,
     startingOffsets: KafkaOffsetRangeLimit,
     failOnDataLoss: Boolean)
```

```diff
@@ -67,7 +67,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
       schema: Option[StructType],
       providerName: String,
       parameters: Map[String, String]): (String, StructType) = {
-    validateStreamOptions(parameters)
+    val caseInsensitiveParameters = CaseInsensitiveMap(parameters)
+    validateStreamOptions(caseInsensitiveParameters)
     require(schema.isEmpty, "Kafka source has a fixed schema and cannot be set with a custom one")
     (shortName(), KafkaOffsetReader.kafkaSchema)
   }
@@ -85,7 +86,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
     // id. Hence, we should generate a unique id for each query.
     val uniqueGroupId = streamingUniqueGroupId(caseInsensitiveParameters, metadataPath)
 
-    val specifiedKafkaParams = convertToSpecifiedParams(parameters)
+    val specifiedKafkaParams = convertToSpecifiedParams(caseInsensitiveParameters)
 
     val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
       caseInsensitiveParameters, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
@@ -121,7 +122,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
       parameters: Map[String, String]): BaseRelation = {
     val caseInsensitiveParameters = CaseInsensitiveMap(parameters)
     validateBatchOptions(caseInsensitiveParameters)
-    val specifiedKafkaParams = convertToSpecifiedParams(parameters)
+    val specifiedKafkaParams = convertToSpecifiedParams(caseInsensitiveParameters)
 
     val startingRelationOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
       caseInsensitiveParameters, STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit)
@@ -146,8 +147,9 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
       parameters: Map[String, String],
       partitionColumns: Seq[String],
       outputMode: OutputMode): Sink = {
-    val defaultTopic = parameters.get(TOPIC_OPTION_KEY).map(_.trim)
-    val specifiedKafkaParams = kafkaParamsForProducer(parameters)
+    val caseInsensitiveParameters = CaseInsensitiveMap(parameters)
+    val defaultTopic = caseInsensitiveParameters.get(TOPIC_OPTION_KEY).map(_.trim)
+    val specifiedKafkaParams = kafkaParamsForProducer(caseInsensitiveParameters)
     new KafkaSink(sqlContext, specifiedKafkaParams, defaultTopic)
   }
@@ -163,8 +165,9 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
           s"${SaveMode.ErrorIfExists} (default).")
       case _ => // good
     }
-    val topic = parameters.get(TOPIC_OPTION_KEY).map(_.trim)
-    val specifiedKafkaParams = kafkaParamsForProducer(parameters)
+    val caseInsensitiveParameters = CaseInsensitiveMap(parameters)
+    val topic = caseInsensitiveParameters.get(TOPIC_OPTION_KEY).map(_.trim)
+    val specifiedKafkaParams = kafkaParamsForProducer(caseInsensitiveParameters)
     KafkaWriter.write(outerSQLContext.sparkSession, data.queryExecution, specifiedKafkaParams,
       topic)
@@ -184,28 +187,31 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
     }
   }
 
-  private def strategy(caseInsensitiveParams: Map[String, String]) =
-    caseInsensitiveParams.find(x => STRATEGY_OPTION_KEYS.contains(x._1)).get match {
-      case (ASSIGN, value) =>
-        AssignStrategy(JsonUtils.partitions(value))
-      case (SUBSCRIBE, value) =>
-        SubscribeStrategy(value.split(",").map(_.trim()).filter(_.nonEmpty))
-      case (SUBSCRIBE_PATTERN, value) =>
-        SubscribePatternStrategy(value.trim())
-      case _ =>
-        // Should never reach here as we are already matching on
-        // matched strategy names
-        throw new IllegalArgumentException("Unknown option")
-    }
+  private def strategy(params: CaseInsensitiveMap[String]) = {
+    val lowercaseParams = params.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
+
+    lowercaseParams.find(x => STRATEGY_OPTION_KEYS.contains(x._1)).get match {
+      case (ASSIGN, value) =>
+        AssignStrategy(JsonUtils.partitions(value))
+      case (SUBSCRIBE, value) =>
+        SubscribeStrategy(value.split(",").map(_.trim()).filter(_.nonEmpty))
+      case (SUBSCRIBE_PATTERN, value) =>
+        SubscribePatternStrategy(value.trim())
+      case _ =>
+        // Should never reach here as we are already matching on
+        // matched strategy names
+        throw new IllegalArgumentException("Unknown option")
+    }
+  }
 
-  private def failOnDataLoss(caseInsensitiveParams: Map[String, String]) =
-    caseInsensitiveParams.getOrElse(FAIL_ON_DATA_LOSS_OPTION_KEY, "true").toBoolean
+  private def failOnDataLoss(params: CaseInsensitiveMap[String]) =
+    params.getOrElse(FAIL_ON_DATA_LOSS_OPTION_KEY, "true").toBoolean
 
-  private def validateGeneralOptions(parameters: Map[String, String]): Unit = {
+  private def validateGeneralOptions(params: CaseInsensitiveMap[String]): Unit = {
     // Validate source options
-    val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
+    val lowercaseParams = params.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
     val specifiedStrategies =
-      caseInsensitiveParams.filter { case (k, _) => STRATEGY_OPTION_KEYS.contains(k) }.toSeq
+      lowercaseParams.filter { case (k, _) => STRATEGY_OPTION_KEYS.contains(k) }.toSeq
 
     if (specifiedStrategies.isEmpty) {
       throw new IllegalArgumentException(
@@ -217,7 +223,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
         + STRATEGY_OPTION_KEYS.mkString(", ") + ". See the docs for more details.")
     }
 
-    caseInsensitiveParams.find(x => STRATEGY_OPTION_KEYS.contains(x._1)).get match {
+    lowercaseParams.find(x => STRATEGY_OPTION_KEYS.contains(x._1)).get match {
       case (ASSIGN, value) =>
         if (!value.trim.startsWith("{")) {
           throw new IllegalArgumentException(
@@ -233,7 +239,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
             s"'subscribe' is '$value'")
         }
       case (SUBSCRIBE_PATTERN, value) =>
-        val pattern = caseInsensitiveParams(SUBSCRIBE_PATTERN).trim()
+        val pattern = params(SUBSCRIBE_PATTERN).trim()
         if (pattern.isEmpty) {
           throw new IllegalArgumentException(
             "Pattern to subscribe is empty as specified value for option " +
@@ -246,22 +252,22 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
     }
 
     // Validate minPartitions value if present
-    if (caseInsensitiveParams.contains(MIN_PARTITIONS_OPTION_KEY)) {
-      val p = caseInsensitiveParams(MIN_PARTITIONS_OPTION_KEY).toInt
+    if (params.contains(MIN_PARTITIONS_OPTION_KEY)) {
+      val p = params(MIN_PARTITIONS_OPTION_KEY).toInt
       if (p <= 0) throw new IllegalArgumentException("minPartitions must be positive")
     }
 
     // Validate user-specified Kafka options
-    if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.GROUP_ID_CONFIG}")) {
+    if (params.contains(s"kafka.${ConsumerConfig.GROUP_ID_CONFIG}")) {
       logWarning(CUSTOM_GROUP_ID_ERROR_MESSAGE)
-      if (caseInsensitiveParams.contains(GROUP_ID_PREFIX)) {
+      if (params.contains(GROUP_ID_PREFIX)) {
         logWarning("Option 'groupIdPrefix' will be ignored as " +
          s"option 'kafka.${ConsumerConfig.GROUP_ID_CONFIG}' has been set.")
      }
    }
 
-    if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.AUTO_OFFSET_RESET_CONFIG}")) {
+    if (params.contains(s"kafka.${ConsumerConfig.AUTO_OFFSET_RESET_CONFIG}")) {
       throw new IllegalArgumentException(
         s"""
           |Kafka option '${ConsumerConfig.AUTO_OFFSET_RESET_CONFIG}' is not supported.
@@ -275,14 +281,14 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
         """.stripMargin)
     }
 
-    if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG}")) {
+    if (params.contains(s"kafka.${ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG}")) {
       throw new IllegalArgumentException(
         s"Kafka option '${ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG}' is not supported as keys "
          + "are deserialized as byte arrays with ByteArrayDeserializer. Use DataFrame operations "
          + "to explicitly deserialize the keys.")
     }
 
-    if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG}"))
+    if (params.contains(s"kafka.${ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG}"))
     {
       throw new IllegalArgumentException(
         s"Kafka option '${ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG}' is not supported as "
@@ -295,29 +301,29 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
       ConsumerConfig.INTERCEPTOR_CLASSES_CONFIG) // interceptors can modify payload, so not safe
 
     otherUnsupportedConfigs.foreach { c =>
-      if (caseInsensitiveParams.contains(s"kafka.$c")) {
+      if (params.contains(s"kafka.$c")) {
        throw new IllegalArgumentException(s"Kafka option '$c' is not supported")
      }
    }
 
-    if (!caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG}")) {
+    if (!params.contains(s"kafka.${ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG}")) {
       throw new IllegalArgumentException(
         s"Option 'kafka.${ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG}' must be specified for " +
          s"configuring Kafka consumer")
     }
   }
 
-  private def validateStreamOptions(caseInsensitiveParams: Map[String, String]) = {
+  private def validateStreamOptions(params: CaseInsensitiveMap[String]) = {
     // Stream specific options
-    caseInsensitiveParams.get(ENDING_OFFSETS_OPTION_KEY).map(_ =>
+    params.get(ENDING_OFFSETS_OPTION_KEY).map(_ =>
       throw new IllegalArgumentException("ending offset not valid in streaming queries"))
-    validateGeneralOptions(caseInsensitiveParams)
+    validateGeneralOptions(params)
   }
 
-  private def validateBatchOptions(caseInsensitiveParams: Map[String, String]) = {
+  private def validateBatchOptions(params: CaseInsensitiveMap[String]) = {
     // Batch specific options
     KafkaSourceProvider.getKafkaOffsetRangeLimit(
-      caseInsensitiveParams, STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit) match {
+      params, STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit) match {
       case EarliestOffsetRangeLimit => // good to go
       case LatestOffsetRangeLimit =>
         throw new IllegalArgumentException("starting offset can't be latest " +
@@ -332,7 +338,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
     }
 
     KafkaSourceProvider.getKafkaOffsetRangeLimit(
-      caseInsensitiveParams, ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) match {
+      params, ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) match {
       case EarliestOffsetRangeLimit =>
         throw new IllegalArgumentException("ending offset can't be earliest " +
          "for batch queries on Kafka")
@@ -346,10 +352,10 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
       }
     }
 
-    validateGeneralOptions(caseInsensitiveParams)
+    validateGeneralOptions(params)
 
     // Don't want to throw an error, but at least log a warning.
-    if (caseInsensitiveParams.get(MAX_OFFSET_PER_TRIGGER.toLowerCase(Locale.ROOT)).isDefined) {
+    if (params.contains(MAX_OFFSET_PER_TRIGGER)) {
       logWarning("maxOffsetsPerTrigger option ignored in batch queries")
     }
   }
@@ -375,7 +381,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
     new WriteBuilder {
       private var inputSchema: StructType = _
       private val topic = Option(options.get(TOPIC_OPTION_KEY)).map(_.trim)
-      private val producerParams = kafkaParamsForProducer(options.asScala.toMap)
+      private val producerParams =
+        kafkaParamsForProducer(CaseInsensitiveMap(options.asScala.toMap))
 
       override def withInputDataSchema(schema: StructType): WriteBuilder = {
         this.inputSchema = schema
@@ -486,10 +493,10 @@ private[kafka010] object KafkaSourceProvider extends Logging {
   private[kafka010] val ENDING_OFFSETS_OPTION_KEY = "endingoffsets"
   private val FAIL_ON_DATA_LOSS_OPTION_KEY = "failondataloss"
   private[kafka010] val MIN_PARTITIONS_OPTION_KEY = "minpartitions"
-  private[kafka010] val MAX_OFFSET_PER_TRIGGER = "maxOffsetsPerTrigger"
-  private[kafka010] val FETCH_OFFSET_NUM_RETRY = "fetchOffset.numRetries"
-  private[kafka010] val FETCH_OFFSET_RETRY_INTERVAL_MS = "fetchOffset.retryIntervalMs"
-  private[kafka010] val CONSUMER_POLL_TIMEOUT = "kafkaConsumer.pollTimeoutMs"
+  private[kafka010] val MAX_OFFSET_PER_TRIGGER = "maxoffsetspertrigger"
+  private[kafka010] val FETCH_OFFSET_NUM_RETRY = "fetchoffset.numretries"
+  private[kafka010] val FETCH_OFFSET_RETRY_INTERVAL_MS = "fetchoffset.retryintervalms"
+  private[kafka010] val CONSUMER_POLL_TIMEOUT = "kafkaconsumer.polltimeoutms"
   private val GROUP_ID_PREFIX = "groupidprefix"
 
   val TOPIC_OPTION_KEY = "topic"
@@ -525,7 +532,7 @@ private[kafka010] object KafkaSourceProvider extends Logging {
   private val deserClassName = classOf[ByteArrayDeserializer].getName
 
   def getKafkaOffsetRangeLimit(
-      params: Map[String, String],
+      params: CaseInsensitiveMap[String],
       offsetOptionKey: String,
       defaultOffsets: KafkaOffsetRangeLimit): KafkaOffsetRangeLimit = {
     params.get(offsetOptionKey).map(_.trim) match {
@@ -583,9 +590,8 @@ private[kafka010] object KafkaSourceProvider extends Logging {
    * Returns a unique batch consumer group (group.id), allowing the user to set the prefix of
    * the consumer group
    */
-  private[kafka010] def batchUniqueGroupId(parameters: Map[String, String]): String = {
-    val groupIdPrefix = parameters
-      .getOrElse(GROUP_ID_PREFIX, "spark-kafka-relation")
+  private[kafka010] def batchUniqueGroupId(params: CaseInsensitiveMap[String]): String = {
+    val groupIdPrefix = params.getOrElse(GROUP_ID_PREFIX, "spark-kafka-relation")
     s"${groupIdPrefix}-${UUID.randomUUID}"
   }
@@ -594,29 +600,27 @@ private[kafka010] object KafkaSourceProvider extends Logging {
    * the consumer group
    */
   private def streamingUniqueGroupId(
-      parameters: Map[String, String],
+      params: CaseInsensitiveMap[String],
       metadataPath: String): String = {
-    val groupIdPrefix = parameters
-      .getOrElse(GROUP_ID_PREFIX, "spark-kafka-source")
+    val groupIdPrefix = params.getOrElse(GROUP_ID_PREFIX, "spark-kafka-source")
     s"${groupIdPrefix}-${UUID.randomUUID}-${metadataPath.hashCode}"
   }
 
   private[kafka010] def kafkaParamsForProducer(
-      parameters: Map[String, String]): ju.Map[String, Object] = {
-    val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
-    if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) {
+      params: CaseInsensitiveMap[String]): ju.Map[String, Object] = {
+    if (params.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) {
       throw new IllegalArgumentException(
         s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys "
          + "are serialized with ByteArraySerializer.")
     }
-    if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}")) {
+    if (params.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}")) {
       throw new IllegalArgumentException(
         s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as "
          + "value are serialized with ByteArraySerializer.")
     }
 
-    val specifiedKafkaParams = convertToSpecifiedParams(parameters)
+    val specifiedKafkaParams = convertToSpecifiedParams(params)
     KafkaConfigUpdater("executor", specifiedKafkaParams)
       .set(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, serClassName)
```

```diff
@@ -34,6 +34,7 @@ import org.scalatest.concurrent.PatienceConfiguration.Timeout
 import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.sql.{Dataset, ForeachWriter, SparkSession}
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
 import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
 import org.apache.spark.sql.execution.streaming._
@@ -1336,14 +1337,16 @@ abstract class KafkaSourceSuiteBase extends KafkaSourceTest {
       (ENDING_OFFSETS_OPTION_KEY, "laTest", LatestOffsetRangeLimit),
       (STARTING_OFFSETS_OPTION_KEY, """{"topic-A":{"0":23}}""",
         SpecificOffsetRangeLimit(Map(new TopicPartition("topic-A", 0) -> 23))))) {
-      val offset = getKafkaOffsetRangeLimit(Map(optionKey -> optionValue), optionKey, answer)
+      val offset = getKafkaOffsetRangeLimit(
+        CaseInsensitiveMap[String](Map(optionKey -> optionValue)), optionKey, answer)
       assert(offset === answer)
     }
 
     for ((optionKey, answer) <- Seq(
       (STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit),
       (ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit))) {
-      val offset = getKafkaOffsetRangeLimit(Map.empty, optionKey, answer)
+      val offset = getKafkaOffsetRangeLimit(
+        CaseInsensitiveMap[String](Map.empty), optionKey, answer)
       assert(offset === answer)
     }
   }
```