diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
index 2858ff1162..e4ed84552b 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
@@ -31,7 +31,7 @@ import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT
 import org.apache.spark.scheduler.ExecutorCacheTaskLocation
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.kafka010.KafkaSource._
 import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
@@ -76,7 +76,7 @@ private[kafka010] class KafkaSource(
     sqlContext: SQLContext,
     kafkaReader: KafkaOffsetReader,
     executorKafkaParams: ju.Map[String, Object],
-    sourceOptions: Map[String, String],
+    sourceOptions: CaseInsensitiveMap[String],
     metadataPath: String,
     startingOffsets: KafkaOffsetRangeLimit,
     failOnDataLoss: Boolean)
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
index 372bcab1ca..c3f0be4be9 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
@@ -67,7 +67,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
       schema: Option[StructType],
       providerName: String,
       parameters: Map[String, String]): (String, StructType) = {
-    validateStreamOptions(parameters)
+    val caseInsensitiveParameters = CaseInsensitiveMap(parameters)
+    validateStreamOptions(caseInsensitiveParameters)
     require(schema.isEmpty, "Kafka source has a fixed schema and cannot be set with a custom one")
     (shortName(), KafkaOffsetReader.kafkaSchema)
   }
@@ -85,7 +86,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
     // id. Hence, we should generate a unique id for each query.
     val uniqueGroupId = streamingUniqueGroupId(caseInsensitiveParameters, metadataPath)

-    val specifiedKafkaParams = convertToSpecifiedParams(parameters)
+    val specifiedKafkaParams = convertToSpecifiedParams(caseInsensitiveParameters)

     val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
       caseInsensitiveParameters, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
@@ -121,7 +122,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
       parameters: Map[String, String]): BaseRelation = {
     val caseInsensitiveParameters = CaseInsensitiveMap(parameters)
     validateBatchOptions(caseInsensitiveParameters)
-    val specifiedKafkaParams = convertToSpecifiedParams(parameters)
+    val specifiedKafkaParams = convertToSpecifiedParams(caseInsensitiveParameters)

     val startingRelationOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
       caseInsensitiveParameters, STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit)
@@ -146,8 +147,9 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
       parameters: Map[String, String],
       partitionColumns: Seq[String],
       outputMode: OutputMode): Sink = {
-    val defaultTopic = parameters.get(TOPIC_OPTION_KEY).map(_.trim)
-    val specifiedKafkaParams = kafkaParamsForProducer(parameters)
+    val caseInsensitiveParameters = CaseInsensitiveMap(parameters)
+    val defaultTopic = caseInsensitiveParameters.get(TOPIC_OPTION_KEY).map(_.trim)
+    val specifiedKafkaParams = kafkaParamsForProducer(caseInsensitiveParameters)

     new KafkaSink(sqlContext, specifiedKafkaParams, defaultTopic)
   }
@@ -163,8 +165,9 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
           s"${SaveMode.ErrorIfExists} (default).")
       case _ => // good
     }
-    val topic = parameters.get(TOPIC_OPTION_KEY).map(_.trim)
-    val specifiedKafkaParams = kafkaParamsForProducer(parameters)
+    val caseInsensitiveParameters = CaseInsensitiveMap(parameters)
+    val topic = caseInsensitiveParameters.get(TOPIC_OPTION_KEY).map(_.trim)
+    val specifiedKafkaParams = kafkaParamsForProducer(caseInsensitiveParameters)

     KafkaWriter.write(outerSQLContext.sparkSession, data.queryExecution,
       specifiedKafkaParams, topic)
@@ -184,28 +187,31 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
     }
   }

-  private def strategy(caseInsensitiveParams: Map[String, String]) =
-    caseInsensitiveParams.find(x => STRATEGY_OPTION_KEYS.contains(x._1)).get match {
-      case (ASSIGN, value) =>
-        AssignStrategy(JsonUtils.partitions(value))
-      case (SUBSCRIBE, value) =>
-        SubscribeStrategy(value.split(",").map(_.trim()).filter(_.nonEmpty))
-      case (SUBSCRIBE_PATTERN, value) =>
-        SubscribePatternStrategy(value.trim())
-      case _ =>
-        // Should never reach here as we are already matching on
-        // matched strategy names
-        throw new IllegalArgumentException("Unknown option")
+  private def strategy(params: CaseInsensitiveMap[String]) = {
+    val lowercaseParams = params.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
+
+    lowercaseParams.find(x => STRATEGY_OPTION_KEYS.contains(x._1)).get match {
+      case (ASSIGN, value) =>
+        AssignStrategy(JsonUtils.partitions(value))
+      case (SUBSCRIBE, value) =>
+        SubscribeStrategy(value.split(",").map(_.trim()).filter(_.nonEmpty))
+      case (SUBSCRIBE_PATTERN, value) =>
+        SubscribePatternStrategy(value.trim())
+      case _ =>
+        // Should never reach here as we are already matching on
+        // matched strategy names
+        throw new IllegalArgumentException("Unknown option")
+    }
   }

-  private def failOnDataLoss(caseInsensitiveParams: Map[String, String]) =
-    caseInsensitiveParams.getOrElse(FAIL_ON_DATA_LOSS_OPTION_KEY, "true").toBoolean
+  private def failOnDataLoss(params: CaseInsensitiveMap[String]) =
+    params.getOrElse(FAIL_ON_DATA_LOSS_OPTION_KEY, "true").toBoolean

-  private def validateGeneralOptions(parameters: Map[String, String]): Unit = {
+  private def validateGeneralOptions(params: CaseInsensitiveMap[String]): Unit = {
     // Validate source options
-    val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
+    val lowercaseParams = params.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
     val specifiedStrategies =
-      caseInsensitiveParams.filter { case (k, _) => STRATEGY_OPTION_KEYS.contains(k) }.toSeq
+      lowercaseParams.filter { case (k, _) => STRATEGY_OPTION_KEYS.contains(k) }.toSeq

     if (specifiedStrategies.isEmpty) {
       throw new IllegalArgumentException(
@@ -217,7 +223,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
           + STRATEGY_OPTION_KEYS.mkString(", ") + ". See the docs for more details.")
     }

-    caseInsensitiveParams.find(x => STRATEGY_OPTION_KEYS.contains(x._1)).get match {
+    lowercaseParams.find(x => STRATEGY_OPTION_KEYS.contains(x._1)).get match {
       case (ASSIGN, value) =>
         if (!value.trim.startsWith("{")) {
           throw new IllegalArgumentException(
@@ -233,7 +239,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
             s"'subscribe' is '$value'")
         }
       case (SUBSCRIBE_PATTERN, value) =>
-        val pattern = caseInsensitiveParams(SUBSCRIBE_PATTERN).trim()
+        val pattern = params(SUBSCRIBE_PATTERN).trim()
         if (pattern.isEmpty) {
           throw new IllegalArgumentException(
             "Pattern to subscribe is empty as specified value for option " +
@@ -246,22 +252,22 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
     }

     // Validate minPartitions value if present
-    if (caseInsensitiveParams.contains(MIN_PARTITIONS_OPTION_KEY)) {
-      val p = caseInsensitiveParams(MIN_PARTITIONS_OPTION_KEY).toInt
+    if (params.contains(MIN_PARTITIONS_OPTION_KEY)) {
+      val p = params(MIN_PARTITIONS_OPTION_KEY).toInt
       if (p <= 0) throw new IllegalArgumentException("minPartitions must be positive")
     }

     // Validate user-specified Kafka options
-    if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.GROUP_ID_CONFIG}")) {
+    if (params.contains(s"kafka.${ConsumerConfig.GROUP_ID_CONFIG}")) {
       logWarning(CUSTOM_GROUP_ID_ERROR_MESSAGE)
-      if (caseInsensitiveParams.contains(GROUP_ID_PREFIX)) {
+      if (params.contains(GROUP_ID_PREFIX)) {
         logWarning("Option 'groupIdPrefix' will be ignored as " +
           s"option 'kafka.${ConsumerConfig.GROUP_ID_CONFIG}' has been set.")
       }
     }

-    if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.AUTO_OFFSET_RESET_CONFIG}")) {
+    if (params.contains(s"kafka.${ConsumerConfig.AUTO_OFFSET_RESET_CONFIG}")) {
       throw new IllegalArgumentException(
         s"""
           |Kafka option '${ConsumerConfig.AUTO_OFFSET_RESET_CONFIG}' is not supported.
@@ -275,14 +281,14 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
         """.stripMargin)
     }

-    if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG}")) {
+    if (params.contains(s"kafka.${ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG}")) {
       throw new IllegalArgumentException(
         s"Kafka option '${ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG}' is not supported as keys "
           + "are deserialized as byte arrays with ByteArrayDeserializer. Use DataFrame operations "
           + "to explicitly deserialize the keys.")
     }

-    if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG}"))
+    if (params.contains(s"kafka.${ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG}"))
     {
       throw new IllegalArgumentException(
         s"Kafka option '${ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG}' is not supported as "
@@ -295,29 +301,29 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
       ConsumerConfig.INTERCEPTOR_CLASSES_CONFIG) // interceptors can modify payload, so not safe

     otherUnsupportedConfigs.foreach { c =>
-      if (caseInsensitiveParams.contains(s"kafka.$c")) {
+      if (params.contains(s"kafka.$c")) {
         throw new IllegalArgumentException(s"Kafka option '$c' is not supported")
       }
     }

-    if (!caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG}")) {
+    if (!params.contains(s"kafka.${ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG}")) {
       throw new IllegalArgumentException(
         s"Option 'kafka.${ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG}' must be specified for " +
           s"configuring Kafka consumer")
     }
   }

-  private def validateStreamOptions(caseInsensitiveParams: Map[String, String]) = {
+  private def validateStreamOptions(params: CaseInsensitiveMap[String]) = {
     // Stream specific options
-    caseInsensitiveParams.get(ENDING_OFFSETS_OPTION_KEY).map(_ =>
+    params.get(ENDING_OFFSETS_OPTION_KEY).map(_ =>
       throw new IllegalArgumentException("ending offset not valid in streaming queries"))
-    validateGeneralOptions(caseInsensitiveParams)
+    validateGeneralOptions(params)
   }

-  private def validateBatchOptions(caseInsensitiveParams: Map[String, String]) = {
+  private def validateBatchOptions(params: CaseInsensitiveMap[String]) = {
     // Batch specific options
     KafkaSourceProvider.getKafkaOffsetRangeLimit(
-      caseInsensitiveParams, STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit) match {
+      params, STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit) match {
       case EarliestOffsetRangeLimit => // good to go
       case LatestOffsetRangeLimit =>
         throw new IllegalArgumentException("starting offset can't be latest " +
@@ -332,7 +338,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
     }

     KafkaSourceProvider.getKafkaOffsetRangeLimit(
-      caseInsensitiveParams, ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) match {
+      params, ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) match {
       case EarliestOffsetRangeLimit =>
         throw new IllegalArgumentException("ending offset can't be earliest " +
           "for batch queries on Kafka")
@@ -346,10 +352,10 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
         }
     }

-    validateGeneralOptions(caseInsensitiveParams)
+    validateGeneralOptions(params)

     // Don't want to throw an error, but at least log a warning.
-    if (caseInsensitiveParams.get(MAX_OFFSET_PER_TRIGGER.toLowerCase(Locale.ROOT)).isDefined) {
+    if (params.contains(MAX_OFFSET_PER_TRIGGER)) {
       logWarning("maxOffsetsPerTrigger option ignored in batch queries")
     }
   }
@@ -375,7 +381,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
     new WriteBuilder {
       private var inputSchema: StructType = _
       private val topic = Option(options.get(TOPIC_OPTION_KEY)).map(_.trim)
-      private val producerParams = kafkaParamsForProducer(options.asScala.toMap)
+      private val producerParams =
+        kafkaParamsForProducer(CaseInsensitiveMap(options.asScala.toMap))

       override def withInputDataSchema(schema: StructType): WriteBuilder = {
         this.inputSchema = schema
@@ -486,10 +493,10 @@ private[kafka010] object KafkaSourceProvider extends Logging {
   private[kafka010] val ENDING_OFFSETS_OPTION_KEY = "endingoffsets"
   private val FAIL_ON_DATA_LOSS_OPTION_KEY = "failondataloss"
   private[kafka010] val MIN_PARTITIONS_OPTION_KEY = "minpartitions"
-  private[kafka010] val MAX_OFFSET_PER_TRIGGER = "maxOffsetsPerTrigger"
-  private[kafka010] val FETCH_OFFSET_NUM_RETRY = "fetchOffset.numRetries"
-  private[kafka010] val FETCH_OFFSET_RETRY_INTERVAL_MS = "fetchOffset.retryIntervalMs"
-  private[kafka010] val CONSUMER_POLL_TIMEOUT = "kafkaConsumer.pollTimeoutMs"
+  private[kafka010] val MAX_OFFSET_PER_TRIGGER = "maxoffsetspertrigger"
+  private[kafka010] val FETCH_OFFSET_NUM_RETRY = "fetchoffset.numretries"
+  private[kafka010] val FETCH_OFFSET_RETRY_INTERVAL_MS = "fetchoffset.retryintervalms"
+  private[kafka010] val CONSUMER_POLL_TIMEOUT = "kafkaconsumer.polltimeoutms"
   private val GROUP_ID_PREFIX = "groupidprefix"

   val TOPIC_OPTION_KEY = "topic"
@@ -525,7 +532,7 @@ private[kafka010] object KafkaSourceProvider extends Logging {
   private val deserClassName = classOf[ByteArrayDeserializer].getName

   def getKafkaOffsetRangeLimit(
-      params: Map[String, String],
+      params: CaseInsensitiveMap[String],
       offsetOptionKey: String,
       defaultOffsets: KafkaOffsetRangeLimit): KafkaOffsetRangeLimit = {
     params.get(offsetOptionKey).map(_.trim) match {
@@ -583,9 +590,8 @@ private[kafka010] object KafkaSourceProvider extends Logging {
    * Returns a unique batch consumer group (group.id), allowing the user to set the prefix of
    * the consumer group
    */
-  private[kafka010] def batchUniqueGroupId(parameters: Map[String, String]): String = {
-    val groupIdPrefix = parameters
-      .getOrElse(GROUP_ID_PREFIX, "spark-kafka-relation")
+  private[kafka010] def batchUniqueGroupId(params: CaseInsensitiveMap[String]): String = {
+    val groupIdPrefix = params.getOrElse(GROUP_ID_PREFIX, "spark-kafka-relation")
     s"${groupIdPrefix}-${UUID.randomUUID}"
   }

@@ -594,29 +600,27 @@ private[kafka010] object KafkaSourceProvider extends Logging {
    * the consumer group
    */
   private def streamingUniqueGroupId(
-      parameters: Map[String, String],
+      params: CaseInsensitiveMap[String],
       metadataPath: String): String = {
-    val groupIdPrefix = parameters
-      .getOrElse(GROUP_ID_PREFIX, "spark-kafka-source")
+    val groupIdPrefix = params.getOrElse(GROUP_ID_PREFIX, "spark-kafka-source")
     s"${groupIdPrefix}-${UUID.randomUUID}-${metadataPath.hashCode}"
   }

   private[kafka010] def kafkaParamsForProducer(
-      parameters: Map[String, String]): ju.Map[String, Object] = {
-    val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
-    if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) {
+      params: CaseInsensitiveMap[String]): ju.Map[String, Object] = {
+    if (params.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) {
       throw new IllegalArgumentException(
         s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys "
           + "are serialized with ByteArraySerializer.")
     }

-    if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}")) {
+    if (params.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}")) {
       throw new IllegalArgumentException(
         s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as "
           + "value are serialized with ByteArraySerializer.")
     }

-    val specifiedKafkaParams = convertToSpecifiedParams(parameters)
+    val specifiedKafkaParams = convertToSpecifiedParams(params)

     KafkaConfigUpdater("executor", specifiedKafkaParams)
       .set(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, serClassName)
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
index bb9b3696fe..609c43803b 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
@@ -34,6 +34,7 @@ import org.scalatest.concurrent.PatienceConfiguration.Timeout
 import org.scalatest.time.SpanSugar._

 import org.apache.spark.sql.{Dataset, ForeachWriter, SparkSession}
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
 import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
 import org.apache.spark.sql.execution.streaming._
@@ -1336,14 +1337,16 @@ abstract class KafkaSourceSuiteBase extends KafkaSourceTest {
       (ENDING_OFFSETS_OPTION_KEY, "laTest", LatestOffsetRangeLimit),
       (STARTING_OFFSETS_OPTION_KEY, """{"topic-A":{"0":23}}""",
         SpecificOffsetRangeLimit(Map(new TopicPartition("topic-A", 0) -> 23))))) {
-      val offset = getKafkaOffsetRangeLimit(Map(optionKey -> optionValue), optionKey, answer)
+      val offset = getKafkaOffsetRangeLimit(
+        CaseInsensitiveMap[String](Map(optionKey -> optionValue)), optionKey, answer)
       assert(offset === answer)
     }

     for ((optionKey, answer) <- Seq(
       (STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit),
       (ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit))) {
-      val offset = getKafkaOffsetRangeLimit(Map.empty, optionKey, answer)
+      val offset = getKafkaOffsetRangeLimit(
+        CaseInsensitiveMap[String](Map.empty), optionKey, answer)
       assert(offset === answer)
     }
   }
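Not part of the patch above: a minimal sketch of the CaseInsensitiveMap behaviour the change relies on. The object name CaseInsensitiveMapSketch and its placement in the kafka010 package are illustrative assumptions; only the CaseInsensitiveMap API and the option names come from the diff.

package org.apache.spark.sql.kafka010

import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap

// Sketch only: shows that lookups ignore key case, which is why option-key
// constants such as MAX_OFFSET_PER_TRIGGER can be lowercased safely.
object CaseInsensitiveMapSketch {
  def main(args: Array[String]): Unit = {
    val options = CaseInsensitiveMap(
      Map("maxOffsetsPerTrigger" -> "100", "startingOffsets" -> "earliest"))

    // get/contains match regardless of the caller's spelling of the key.
    assert(options.contains("maxoffsetspertrigger"))
    assert(options.get("STARTINGOFFSETS").contains("earliest"))
  }
}

Because get and contains are already case-insensitive on the wrapped map, the patch only needs the explicit toLowerCase(Locale.ROOT) pass where it pattern-matches on the lower-case strategy key names.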