[SPARK-28695][SS] Use CaseInsensitiveMap in KafkaSourceProvider to make source param handling more robust

## What changes were proposed in this pull request?

[SPARK-28163](https://issues.apache.org/jira/browse/SPARK-28163) fixed a bug, and during the analysis we concluded it would be more robust to use `CaseInsensitiveMap` inside the Kafka source. This way, fewer lower/upper case handling problems would arise in the future.

Please note this PR doesn't intend to solve any actual problem but to finish the concept added in [SPARK-28163](https://issues.apache.org/jira/browse/SPARK-28163) (in the fix PR I didn't want to make overly invasive changes). In this PR I've changed `Map[String, String]` to `CaseInsensitiveMap[String]` to enforce its usage. These are the main use-cases:
* `contains` => `CaseInsensitiveMap` solves it
* `get...` => `CaseInsensitiveMap` solves it
* `filter` => keys must be converted to lowercase because there is no guarantee about the case of the keys in the incoming map
* `find` => keys must be converted to lowercase, for the same reason
* passing parameters to the Kafka consumer/producer => keys must be converted to lowercase, for the same reason (see the sketch after this list)
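
For illustration only (this is not code from the PR), here is a minimal sketch of the three kinds of use-cases, assuming Spark's catalyst module is on the classpath; `CaseInsensitiveParamsSketch` and the `STRATEGY_OPTION_KEYS` value are made up for the example and only mirror the lowercase key set in `KafkaSourceProvider`:

```scala
import java.util.Locale

import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap

// Hypothetical sketch of the use-cases above; STRATEGY_OPTION_KEYS is a stand-in
// for the lowercase option-key set used in KafkaSourceProvider.
object CaseInsensitiveParamsSketch {
  private val STRATEGY_OPTION_KEYS = Set("subscribe", "subscribepattern", "assign")

  def main(args: Array[String]): Unit = {
    // Options as a user might pass them, with arbitrary casing.
    val params = CaseInsensitiveMap(Map("Subscribe" -> "topicA", "startingOffsets" -> "earliest"))

    // `contains` / `get`: CaseInsensitiveMap already normalizes the lookup key.
    assert(params.contains("subscribe"))
    assert(params.get("STARTINGOFFSETS").contains("earliest"))

    // `filter` / `find` (and building the config map handed to the Kafka
    // consumer/producer): lowercase the keys explicitly, since there is no
    // guarantee about the case of the keys in the incoming map.
    val lowercaseParams = params.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
    val strategy = lowercaseParams.find { case (k, _) => STRATEGY_OPTION_KEYS.contains(k) }
    assert(strategy.contains("subscribe" -> "topicA"))
  }
}
```

The explicit lowercasing only matters for the iteration-style operations; plain `get`/`contains` lookups are normalized by `CaseInsensitiveMap` itself.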

## How was this patch tested?

Existing unit tests.

Closes #25418 from gaborgsomogyi/SPARK-28695.

Authored-by: Gabor Somogyi <gabor.g.somogyi@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Gabor Somogyi 2019-08-15 14:43:52 +08:00 committed by Wenchen Fan
parent 0526529b31
commit a493031e2e
3 changed files with 71 additions and 64 deletions


@@ -31,7 +31,7 @@ import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT
import org.apache.spark.scheduler.ExecutorCacheTaskLocation
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.kafka010.KafkaSource._
import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
@@ -76,7 +76,7 @@ private[kafka010] class KafkaSource(
sqlContext: SQLContext,
kafkaReader: KafkaOffsetReader,
executorKafkaParams: ju.Map[String, Object],
sourceOptions: Map[String, String],
sourceOptions: CaseInsensitiveMap[String],
metadataPath: String,
startingOffsets: KafkaOffsetRangeLimit,
failOnDataLoss: Boolean)


@@ -67,7 +67,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
schema: Option[StructType],
providerName: String,
parameters: Map[String, String]): (String, StructType) = {
validateStreamOptions(parameters)
val caseInsensitiveParameters = CaseInsensitiveMap(parameters)
validateStreamOptions(caseInsensitiveParameters)
require(schema.isEmpty, "Kafka source has a fixed schema and cannot be set with a custom one")
(shortName(), KafkaOffsetReader.kafkaSchema)
}
@@ -85,7 +86,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
// id. Hence, we should generate a unique id for each query.
val uniqueGroupId = streamingUniqueGroupId(caseInsensitiveParameters, metadataPath)
val specifiedKafkaParams = convertToSpecifiedParams(parameters)
val specifiedKafkaParams = convertToSpecifiedParams(caseInsensitiveParameters)
val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
caseInsensitiveParameters, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
@@ -121,7 +122,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
parameters: Map[String, String]): BaseRelation = {
val caseInsensitiveParameters = CaseInsensitiveMap(parameters)
validateBatchOptions(caseInsensitiveParameters)
val specifiedKafkaParams = convertToSpecifiedParams(parameters)
val specifiedKafkaParams = convertToSpecifiedParams(caseInsensitiveParameters)
val startingRelationOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
caseInsensitiveParameters, STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit)
@@ -146,8 +147,9 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
parameters: Map[String, String],
partitionColumns: Seq[String],
outputMode: OutputMode): Sink = {
val defaultTopic = parameters.get(TOPIC_OPTION_KEY).map(_.trim)
val specifiedKafkaParams = kafkaParamsForProducer(parameters)
val caseInsensitiveParameters = CaseInsensitiveMap(parameters)
val defaultTopic = caseInsensitiveParameters.get(TOPIC_OPTION_KEY).map(_.trim)
val specifiedKafkaParams = kafkaParamsForProducer(caseInsensitiveParameters)
new KafkaSink(sqlContext, specifiedKafkaParams, defaultTopic)
}
@@ -163,8 +165,9 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
s"${SaveMode.ErrorIfExists} (default).")
case _ => // good
}
val topic = parameters.get(TOPIC_OPTION_KEY).map(_.trim)
val specifiedKafkaParams = kafkaParamsForProducer(parameters)
val caseInsensitiveParameters = CaseInsensitiveMap(parameters)
val topic = caseInsensitiveParameters.get(TOPIC_OPTION_KEY).map(_.trim)
val specifiedKafkaParams = kafkaParamsForProducer(caseInsensitiveParameters)
KafkaWriter.write(outerSQLContext.sparkSession, data.queryExecution, specifiedKafkaParams,
topic)
@@ -184,28 +187,31 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
}
}
private def strategy(caseInsensitiveParams: Map[String, String]) =
caseInsensitiveParams.find(x => STRATEGY_OPTION_KEYS.contains(x._1)).get match {
case (ASSIGN, value) =>
AssignStrategy(JsonUtils.partitions(value))
case (SUBSCRIBE, value) =>
SubscribeStrategy(value.split(",").map(_.trim()).filter(_.nonEmpty))
case (SUBSCRIBE_PATTERN, value) =>
SubscribePatternStrategy(value.trim())
case _ =>
// Should never reach here as we are already matching on
// matched strategy names
throw new IllegalArgumentException("Unknown option")
private def strategy(params: CaseInsensitiveMap[String]) = {
val lowercaseParams = params.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
lowercaseParams.find(x => STRATEGY_OPTION_KEYS.contains(x._1)).get match {
case (ASSIGN, value) =>
AssignStrategy(JsonUtils.partitions(value))
case (SUBSCRIBE, value) =>
SubscribeStrategy(value.split(",").map(_.trim()).filter(_.nonEmpty))
case (SUBSCRIBE_PATTERN, value) =>
SubscribePatternStrategy(value.trim())
case _ =>
// Should never reach here as we are already matching on
// matched strategy names
throw new IllegalArgumentException("Unknown option")
}
}
private def failOnDataLoss(caseInsensitiveParams: Map[String, String]) =
caseInsensitiveParams.getOrElse(FAIL_ON_DATA_LOSS_OPTION_KEY, "true").toBoolean
private def failOnDataLoss(params: CaseInsensitiveMap[String]) =
params.getOrElse(FAIL_ON_DATA_LOSS_OPTION_KEY, "true").toBoolean
private def validateGeneralOptions(parameters: Map[String, String]): Unit = {
private def validateGeneralOptions(params: CaseInsensitiveMap[String]): Unit = {
// Validate source options
val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
val lowercaseParams = params.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
val specifiedStrategies =
caseInsensitiveParams.filter { case (k, _) => STRATEGY_OPTION_KEYS.contains(k) }.toSeq
lowercaseParams.filter { case (k, _) => STRATEGY_OPTION_KEYS.contains(k) }.toSeq
if (specifiedStrategies.isEmpty) {
throw new IllegalArgumentException(
@@ -217,7 +223,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
+ STRATEGY_OPTION_KEYS.mkString(", ") + ". See the docs for more details.")
}
caseInsensitiveParams.find(x => STRATEGY_OPTION_KEYS.contains(x._1)).get match {
lowercaseParams.find(x => STRATEGY_OPTION_KEYS.contains(x._1)).get match {
case (ASSIGN, value) =>
if (!value.trim.startsWith("{")) {
throw new IllegalArgumentException(
@@ -233,7 +239,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
s"'subscribe' is '$value'")
}
case (SUBSCRIBE_PATTERN, value) =>
val pattern = caseInsensitiveParams(SUBSCRIBE_PATTERN).trim()
val pattern = params(SUBSCRIBE_PATTERN).trim()
if (pattern.isEmpty) {
throw new IllegalArgumentException(
"Pattern to subscribe is empty as specified value for option " +
@@ -246,22 +252,22 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
}
// Validate minPartitions value if present
if (caseInsensitiveParams.contains(MIN_PARTITIONS_OPTION_KEY)) {
val p = caseInsensitiveParams(MIN_PARTITIONS_OPTION_KEY).toInt
if (params.contains(MIN_PARTITIONS_OPTION_KEY)) {
val p = params(MIN_PARTITIONS_OPTION_KEY).toInt
if (p <= 0) throw new IllegalArgumentException("minPartitions must be positive")
}
// Validate user-specified Kafka options
if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.GROUP_ID_CONFIG}")) {
if (params.contains(s"kafka.${ConsumerConfig.GROUP_ID_CONFIG}")) {
logWarning(CUSTOM_GROUP_ID_ERROR_MESSAGE)
if (caseInsensitiveParams.contains(GROUP_ID_PREFIX)) {
if (params.contains(GROUP_ID_PREFIX)) {
logWarning("Option 'groupIdPrefix' will be ignored as " +
s"option 'kafka.${ConsumerConfig.GROUP_ID_CONFIG}' has been set.")
}
}
if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.AUTO_OFFSET_RESET_CONFIG}")) {
if (params.contains(s"kafka.${ConsumerConfig.AUTO_OFFSET_RESET_CONFIG}")) {
throw new IllegalArgumentException(
s"""
|Kafka option '${ConsumerConfig.AUTO_OFFSET_RESET_CONFIG}' is not supported.
@@ -275,14 +281,14 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
""".stripMargin)
}
if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG}")) {
if (params.contains(s"kafka.${ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG}")) {
throw new IllegalArgumentException(
s"Kafka option '${ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG}' is not supported as keys "
+ "are deserialized as byte arrays with ByteArrayDeserializer. Use DataFrame operations "
+ "to explicitly deserialize the keys.")
}
if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG}"))
if (params.contains(s"kafka.${ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG}"))
{
throw new IllegalArgumentException(
s"Kafka option '${ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG}' is not supported as "
@@ -295,29 +301,29 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
ConsumerConfig.INTERCEPTOR_CLASSES_CONFIG) // interceptors can modify payload, so not safe
otherUnsupportedConfigs.foreach { c =>
if (caseInsensitiveParams.contains(s"kafka.$c")) {
if (params.contains(s"kafka.$c")) {
throw new IllegalArgumentException(s"Kafka option '$c' is not supported")
}
}
if (!caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG}")) {
if (!params.contains(s"kafka.${ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG}")) {
throw new IllegalArgumentException(
s"Option 'kafka.${ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG}' must be specified for " +
s"configuring Kafka consumer")
}
}
private def validateStreamOptions(caseInsensitiveParams: Map[String, String]) = {
private def validateStreamOptions(params: CaseInsensitiveMap[String]) = {
// Stream specific options
caseInsensitiveParams.get(ENDING_OFFSETS_OPTION_KEY).map(_ =>
params.get(ENDING_OFFSETS_OPTION_KEY).map(_ =>
throw new IllegalArgumentException("ending offset not valid in streaming queries"))
validateGeneralOptions(caseInsensitiveParams)
validateGeneralOptions(params)
}
private def validateBatchOptions(caseInsensitiveParams: Map[String, String]) = {
private def validateBatchOptions(params: CaseInsensitiveMap[String]) = {
// Batch specific options
KafkaSourceProvider.getKafkaOffsetRangeLimit(
caseInsensitiveParams, STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit) match {
params, STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit) match {
case EarliestOffsetRangeLimit => // good to go
case LatestOffsetRangeLimit =>
throw new IllegalArgumentException("starting offset can't be latest " +
@@ -332,7 +338,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
}
KafkaSourceProvider.getKafkaOffsetRangeLimit(
caseInsensitiveParams, ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) match {
params, ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) match {
case EarliestOffsetRangeLimit =>
throw new IllegalArgumentException("ending offset can't be earliest " +
"for batch queries on Kafka")
@@ -346,10 +352,10 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
}
}
validateGeneralOptions(caseInsensitiveParams)
validateGeneralOptions(params)
// Don't want to throw an error, but at least log a warning.
if (caseInsensitiveParams.get(MAX_OFFSET_PER_TRIGGER.toLowerCase(Locale.ROOT)).isDefined) {
if (params.contains(MAX_OFFSET_PER_TRIGGER)) {
logWarning("maxOffsetsPerTrigger option ignored in batch queries")
}
}
@@ -375,7 +381,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
new WriteBuilder {
private var inputSchema: StructType = _
private val topic = Option(options.get(TOPIC_OPTION_KEY)).map(_.trim)
private val producerParams = kafkaParamsForProducer(options.asScala.toMap)
private val producerParams =
kafkaParamsForProducer(CaseInsensitiveMap(options.asScala.toMap))
override def withInputDataSchema(schema: StructType): WriteBuilder = {
this.inputSchema = schema
@@ -486,10 +493,10 @@ private[kafka010] object KafkaSourceProvider extends Logging {
private[kafka010] val ENDING_OFFSETS_OPTION_KEY = "endingoffsets"
private val FAIL_ON_DATA_LOSS_OPTION_KEY = "failondataloss"
private[kafka010] val MIN_PARTITIONS_OPTION_KEY = "minpartitions"
private[kafka010] val MAX_OFFSET_PER_TRIGGER = "maxOffsetsPerTrigger"
private[kafka010] val FETCH_OFFSET_NUM_RETRY = "fetchOffset.numRetries"
private[kafka010] val FETCH_OFFSET_RETRY_INTERVAL_MS = "fetchOffset.retryIntervalMs"
private[kafka010] val CONSUMER_POLL_TIMEOUT = "kafkaConsumer.pollTimeoutMs"
private[kafka010] val MAX_OFFSET_PER_TRIGGER = "maxoffsetspertrigger"
private[kafka010] val FETCH_OFFSET_NUM_RETRY = "fetchoffset.numretries"
private[kafka010] val FETCH_OFFSET_RETRY_INTERVAL_MS = "fetchoffset.retryintervalms"
private[kafka010] val CONSUMER_POLL_TIMEOUT = "kafkaconsumer.polltimeoutms"
private val GROUP_ID_PREFIX = "groupidprefix"
val TOPIC_OPTION_KEY = "topic"
@@ -525,7 +532,7 @@ private[kafka010] object KafkaSourceProvider extends Logging {
private val deserClassName = classOf[ByteArrayDeserializer].getName
def getKafkaOffsetRangeLimit(
params: Map[String, String],
params: CaseInsensitiveMap[String],
offsetOptionKey: String,
defaultOffsets: KafkaOffsetRangeLimit): KafkaOffsetRangeLimit = {
params.get(offsetOptionKey).map(_.trim) match {
@@ -583,9 +590,8 @@ private[kafka010] object KafkaSourceProvider extends Logging {
* Returns a unique batch consumer group (group.id), allowing the user to set the prefix of
* the consumer group
*/
private[kafka010] def batchUniqueGroupId(parameters: Map[String, String]): String = {
val groupIdPrefix = parameters
.getOrElse(GROUP_ID_PREFIX, "spark-kafka-relation")
private[kafka010] def batchUniqueGroupId(params: CaseInsensitiveMap[String]): String = {
val groupIdPrefix = params.getOrElse(GROUP_ID_PREFIX, "spark-kafka-relation")
s"${groupIdPrefix}-${UUID.randomUUID}"
}
@@ -594,29 +600,27 @@ private[kafka010] object KafkaSourceProvider extends Logging {
* the consumer group
*/
private def streamingUniqueGroupId(
parameters: Map[String, String],
params: CaseInsensitiveMap[String],
metadataPath: String): String = {
val groupIdPrefix = parameters
.getOrElse(GROUP_ID_PREFIX, "spark-kafka-source")
val groupIdPrefix = params.getOrElse(GROUP_ID_PREFIX, "spark-kafka-source")
s"${groupIdPrefix}-${UUID.randomUUID}-${metadataPath.hashCode}"
}
private[kafka010] def kafkaParamsForProducer(
parameters: Map[String, String]): ju.Map[String, Object] = {
val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) {
params: CaseInsensitiveMap[String]): ju.Map[String, Object] = {
if (params.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) {
throw new IllegalArgumentException(
s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys "
+ "are serialized with ByteArraySerializer.")
}
if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}")) {
if (params.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}")) {
throw new IllegalArgumentException(
s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as "
+ "value are serialized with ByteArraySerializer.")
}
val specifiedKafkaParams = convertToSpecifiedParams(parameters)
val specifiedKafkaParams = convertToSpecifiedParams(params)
KafkaConfigUpdater("executor", specifiedKafkaParams)
.set(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, serClassName)


@@ -34,6 +34,7 @@ import org.scalatest.concurrent.PatienceConfiguration.Timeout
import org.scalatest.time.SpanSugar._
import org.apache.spark.sql.{Dataset, ForeachWriter, SparkSession}
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.streaming._
@@ -1336,14 +1337,16 @@ abstract class KafkaSourceSuiteBase extends KafkaSourceTest {
(ENDING_OFFSETS_OPTION_KEY, "laTest", LatestOffsetRangeLimit),
(STARTING_OFFSETS_OPTION_KEY, """{"topic-A":{"0":23}}""",
SpecificOffsetRangeLimit(Map(new TopicPartition("topic-A", 0) -> 23))))) {
val offset = getKafkaOffsetRangeLimit(Map(optionKey -> optionValue), optionKey, answer)
val offset = getKafkaOffsetRangeLimit(
CaseInsensitiveMap[String](Map(optionKey -> optionValue)), optionKey, answer)
assert(offset === answer)
}
for ((optionKey, answer) <- Seq(
(STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit),
(ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit))) {
val offset = getKafkaOffsetRangeLimit(Map.empty, optionKey, answer)
val offset = getKafkaOffsetRangeLimit(
CaseInsensitiveMap[String](Map.empty), optionKey, answer)
assert(offset === answer)
}
}