[SPARK-28163][SS] Use CaseInsensitiveMap for KafkaOffsetReader

## What changes were proposed in this pull request?

There are "unsafe" conversions in the Kafka connector.
`CaseInsensitiveStringMap` comes in which is then converted the following way:
```
...
options.asScala.toMap
...
```
The main problem with this is that such case it looses its case insensitive nature
(case insensitive map is converting the key to lower case when get/contains called).

In this PR I'm using `CaseInsensitiveMap` to solve this problem.

## How was this patch tested?

Existing + additional unit tests.

Closes #24967 from gaborgsomogyi/SPARK-28163.

Authored-by: Gabor Somogyi <gabor.g.somogyi@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
Gabor Somogyi 2019-08-09 17:08:11 +08:00 committed by Wenchen Fan
parent 5368eaa2fc
commit 5663386f4b
7 changed files with 117 additions and 94 deletions

View file

@ -22,12 +22,13 @@ import org.apache.kafka.common.TopicPartition
import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.sources.v2.reader.{Batch, InputPartition, PartitionReaderFactory}
private[kafka010] class KafkaBatch(
strategy: ConsumerStrategy,
sourceOptions: Map[String, String],
sourceOptions: CaseInsensitiveMap[String],
specifiedKafkaParams: Map[String, String],
failOnDataLoss: Boolean,
startingOffsets: KafkaOffsetRangeLimit,
@ -38,7 +39,7 @@ private[kafka010] class KafkaBatch(
assert(endingOffsets != EarliestOffsetRangeLimit,
"Ending offset not allowed to be set to earliest offsets.")
private val pollTimeoutMs = sourceOptions.getOrElse(
private[kafka010] val pollTimeoutMs = sourceOptions.getOrElse(
KafkaSourceProvider.CONSUMER_POLL_TIMEOUT,
(SparkEnv.get.conf.get(NETWORK_TIMEOUT) * 1000L).toString
).toLong

View file

@ -46,7 +46,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
* properly read.
*/
class KafkaContinuousStream(
offsetReader: KafkaOffsetReader,
private[kafka010] val offsetReader: KafkaOffsetReader,
kafkaParams: ju.Map[String, Object],
options: CaseInsensitiveStringMap,
metadataPath: String,
@ -54,7 +54,7 @@ class KafkaContinuousStream(
failOnDataLoss: Boolean)
extends ContinuousStream with Logging {
private val pollTimeoutMs =
private[kafka010] val pollTimeoutMs =
options.getLong(KafkaSourceProvider.CONSUMER_POLL_TIMEOUT, 512)
// Initialized when creating reader factories. If this diverges from the partitions at the latest

View file

@ -56,19 +56,19 @@ import org.apache.spark.util.UninterruptibleThread
* and not use wrong broker addresses.
*/
private[kafka010] class KafkaMicroBatchStream(
kafkaOffsetReader: KafkaOffsetReader,
private[kafka010] val kafkaOffsetReader: KafkaOffsetReader,
executorKafkaParams: ju.Map[String, Object],
options: CaseInsensitiveStringMap,
metadataPath: String,
startingOffsets: KafkaOffsetRangeLimit,
failOnDataLoss: Boolean) extends RateControlMicroBatchStream with Logging {
private val pollTimeoutMs = options.getLong(
private[kafka010] val pollTimeoutMs = options.getLong(
KafkaSourceProvider.CONSUMER_POLL_TIMEOUT,
SparkEnv.get.conf.get(NETWORK_TIMEOUT) * 1000L)
private val maxOffsetsPerTrigger = Option(options.get(KafkaSourceProvider.MAX_OFFSET_PER_TRIGGER))
.map(_.toLong)
private[kafka010] val maxOffsetsPerTrigger = Option(options.get(
KafkaSourceProvider.MAX_OFFSET_PER_TRIGGER)).map(_.toLong)
private val rangeCalculator = KafkaOffsetRangeCalculator(options)

View file

@ -30,6 +30,7 @@ import org.apache.kafka.clients.consumer.{Consumer, ConsumerConfig, KafkaConsume
import org.apache.kafka.common.TopicPartition
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.types._
import org.apache.spark.util.{ThreadUtils, UninterruptibleThread}
@ -47,7 +48,7 @@ import org.apache.spark.util.{ThreadUtils, UninterruptibleThread}
private[kafka010] class KafkaOffsetReader(
consumerStrategy: ConsumerStrategy,
val driverKafkaParams: ju.Map[String, Object],
readerOptions: Map[String, String],
readerOptions: CaseInsensitiveMap[String],
driverGroupIdPrefix: String) extends Logging {
/**
* Used to ensure execute fetch operations execute in an UninterruptibleThread
@ -88,10 +89,10 @@ private[kafka010] class KafkaOffsetReader(
_consumer
}
private val maxOffsetFetchAttempts =
private[kafka010] val maxOffsetFetchAttempts =
readerOptions.getOrElse(KafkaSourceProvider.FETCH_OFFSET_NUM_RETRY, "3").toInt
private val offsetFetchAttemptIntervalMs =
private[kafka010] val offsetFetchAttemptIntervalMs =
readerOptions.getOrElse(KafkaSourceProvider.FETCH_OFFSET_RETRY_INTERVAL_MS, "1000").toLong
private def nextGroupId(): String = {

View file

@ -24,7 +24,7 @@ import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.types.UTF8String
@ -33,7 +33,7 @@ import org.apache.spark.unsafe.types.UTF8String
private[kafka010] class KafkaRelation(
override val sqlContext: SQLContext,
strategy: ConsumerStrategy,
sourceOptions: Map[String, String],
sourceOptions: CaseInsensitiveMap[String],
specifiedKafkaParams: Map[String, String],
failOnDataLoss: Boolean,
startingOffsets: KafkaOffsetRangeLimit,

View file

@ -78,32 +78,32 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
schema: Option[StructType],
providerName: String,
parameters: Map[String, String]): Source = {
validateStreamOptions(parameters)
val caseInsensitiveParameters = CaseInsensitiveMap(parameters)
validateStreamOptions(caseInsensitiveParameters)
// Each running query should use its own group id. Otherwise, the query may be only assigned
// partial data since Kafka will assign partitions to multiple consumers having the same group
// id. Hence, we should generate a unique id for each query.
val uniqueGroupId = streamingUniqueGroupId(parameters, metadataPath)
val uniqueGroupId = streamingUniqueGroupId(caseInsensitiveParameters, metadataPath)
val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
val specifiedKafkaParams = convertToSpecifiedParams(parameters)
val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams,
STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
caseInsensitiveParameters, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
val kafkaOffsetReader = new KafkaOffsetReader(
strategy(caseInsensitiveParams),
strategy(caseInsensitiveParameters),
kafkaParamsForDriver(specifiedKafkaParams),
parameters,
caseInsensitiveParameters,
driverGroupIdPrefix = s"$uniqueGroupId-driver")
new KafkaSource(
sqlContext,
kafkaOffsetReader,
kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId),
parameters,
caseInsensitiveParameters,
metadataPath,
startingStreamOffsets,
failOnDataLoss(caseInsensitiveParams))
failOnDataLoss(caseInsensitiveParameters))
}
override def getTable(options: CaseInsensitiveStringMap): KafkaTable = {
@ -119,24 +119,24 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
override def createRelation(
sqlContext: SQLContext,
parameters: Map[String, String]): BaseRelation = {
validateBatchOptions(parameters)
val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
val caseInsensitiveParameters = CaseInsensitiveMap(parameters)
validateBatchOptions(caseInsensitiveParameters)
val specifiedKafkaParams = convertToSpecifiedParams(parameters)
val startingRelationOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
caseInsensitiveParams, STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit)
caseInsensitiveParameters, STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit)
assert(startingRelationOffsets != LatestOffsetRangeLimit)
val endingRelationOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams,
ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
val endingRelationOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
caseInsensitiveParameters, ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
assert(endingRelationOffsets != EarliestOffsetRangeLimit)
new KafkaRelation(
sqlContext,
strategy(caseInsensitiveParams),
sourceOptions = parameters,
strategy(caseInsensitiveParameters),
sourceOptions = caseInsensitiveParameters,
specifiedKafkaParams = specifiedKafkaParams,
failOnDataLoss = failOnDataLoss(caseInsensitiveParams),
failOnDataLoss = failOnDataLoss(caseInsensitiveParameters),
startingOffsets = startingRelationOffsets,
endingOffsets = endingRelationOffsets)
}
@ -420,23 +420,22 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
}
override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = {
val parameters = options.asScala.toMap
validateStreamOptions(parameters)
val caseInsensitiveOptions = CaseInsensitiveMap(options.asScala.toMap)
validateStreamOptions(caseInsensitiveOptions)
// Each running query should use its own group id. Otherwise, the query may be only assigned
// partial data since Kafka will assign partitions to multiple consumers having the same group
// id. Hence, we should generate a unique id for each query.
val uniqueGroupId = streamingUniqueGroupId(parameters, checkpointLocation)
val uniqueGroupId = streamingUniqueGroupId(caseInsensitiveOptions, checkpointLocation)
val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
val specifiedKafkaParams = convertToSpecifiedParams(parameters)
val specifiedKafkaParams = convertToSpecifiedParams(caseInsensitiveOptions)
val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
caseInsensitiveParams, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
caseInsensitiveOptions, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
val kafkaOffsetReader = new KafkaOffsetReader(
strategy(parameters),
strategy(caseInsensitiveOptions),
kafkaParamsForDriver(specifiedKafkaParams),
parameters,
caseInsensitiveOptions,
driverGroupIdPrefix = s"$uniqueGroupId-driver")
new KafkaMicroBatchStream(
@ -445,32 +444,26 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
options,
checkpointLocation,
startingStreamOffsets,
failOnDataLoss(caseInsensitiveParams))
failOnDataLoss(caseInsensitiveOptions))
}
override def toContinuousStream(checkpointLocation: String): ContinuousStream = {
val parameters = options.asScala.toMap
validateStreamOptions(parameters)
val caseInsensitiveOptions = CaseInsensitiveMap(options.asScala.toMap)
validateStreamOptions(caseInsensitiveOptions)
// Each running query should use its own group id. Otherwise, the query may be only assigned
// partial data since Kafka will assign partitions to multiple consumers having the same group
// id. Hence, we should generate a unique id for each query.
val uniqueGroupId = streamingUniqueGroupId(parameters, checkpointLocation)
val uniqueGroupId = streamingUniqueGroupId(caseInsensitiveOptions, checkpointLocation)
val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
val specifiedKafkaParams =
parameters
.keySet
.filter(_.toLowerCase(Locale.ROOT).startsWith("kafka."))
.map { k => k.drop(6).toString -> parameters(k) }
.toMap
val specifiedKafkaParams = convertToSpecifiedParams(caseInsensitiveOptions)
val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
caseInsensitiveParams, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
caseInsensitiveOptions, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
val kafkaOffsetReader = new KafkaOffsetReader(
strategy(caseInsensitiveParams),
strategy(caseInsensitiveOptions),
kafkaParamsForDriver(specifiedKafkaParams),
parameters,
caseInsensitiveOptions,
driverGroupIdPrefix = s"$uniqueGroupId-driver")
new KafkaContinuousStream(
@ -479,7 +472,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
options,
checkpointLocation,
startingStreamOffsets,
failOnDataLoss(caseInsensitiveParams))
failOnDataLoss(caseInsensitiveOptions))
}
}
}

View file

@ -22,60 +22,92 @@ import java.util.Locale
import scala.collection.JavaConverters._
import org.mockito.Mockito.{mock, when}
import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester}
import org.apache.spark.{SparkConf, SparkEnv, SparkFunSuite}
import org.apache.spark.sql.sources.v2.reader.Scan
import org.apache.spark.sql.util.CaseInsensitiveStringMap
class KafkaSourceProviderSuite extends SparkFunSuite with PrivateMethodTester {
class KafkaSourceProviderSuite extends SparkFunSuite {
private val pollTimeoutMsMethod = PrivateMethod[Long]('pollTimeoutMs)
private val maxOffsetsPerTriggerMethod = PrivateMethod[Option[Long]]('maxOffsetsPerTrigger)
private val expected = "1111"
override protected def afterEach(): Unit = {
SparkEnv.set(null)
super.afterEach()
}
test("micro-batch mode - options should be handled as case-insensitive") {
def verifyFieldsInMicroBatchStream(
options: CaseInsensitiveStringMap,
expectedPollTimeoutMs: Long,
expectedMaxOffsetsPerTrigger: Option[Long]): Unit = {
// KafkaMicroBatchStream reads Spark conf from SparkEnv for default value
// hence we set mock SparkEnv here before creating KafkaMicroBatchStream
val sparkEnv = mock(classOf[SparkEnv])
when(sparkEnv.conf).thenReturn(new SparkConf())
SparkEnv.set(sparkEnv)
val scan = getKafkaDataSourceScan(options)
val stream = scan.toMicroBatchStream("dummy").asInstanceOf[KafkaMicroBatchStream]
assert(expectedPollTimeoutMs === getField(stream, pollTimeoutMsMethod))
assert(expectedMaxOffsetsPerTrigger === getField(stream, maxOffsetsPerTriggerMethod))
}
val expectedValue = 1000L
buildCaseInsensitiveStringMapForUpperAndLowerKey(
KafkaSourceProvider.CONSUMER_POLL_TIMEOUT -> expectedValue.toString,
KafkaSourceProvider.MAX_OFFSET_PER_TRIGGER -> expectedValue.toString)
.foreach(verifyFieldsInMicroBatchStream(_, expectedValue, Some(expectedValue)))
test("batch mode - options should be handled as case-insensitive") {
verifyFieldsInBatch(KafkaSourceProvider.CONSUMER_POLL_TIMEOUT, expected, batch => {
assert(expected.toLong === batch.pollTimeoutMs)
})
}
test("SPARK-28142 - continuous mode - options should be handled as case-insensitive") {
def verifyFieldsInContinuousStream(
options: CaseInsensitiveStringMap,
expectedPollTimeoutMs: Long): Unit = {
test("micro-batch mode - options should be handled as case-insensitive") {
verifyFieldsInMicroBatchStream(KafkaSourceProvider.CONSUMER_POLL_TIMEOUT, expected, stream => {
assert(expected.toLong === stream.pollTimeoutMs)
})
verifyFieldsInMicroBatchStream(KafkaSourceProvider.MAX_OFFSET_PER_TRIGGER, expected, stream => {
assert(Some(expected.toLong) === stream.maxOffsetsPerTrigger)
})
verifyFieldsInMicroBatchStream(KafkaSourceProvider.FETCH_OFFSET_NUM_RETRY, expected, stream => {
assert(expected.toInt === stream.kafkaOffsetReader.maxOffsetFetchAttempts)
})
verifyFieldsInMicroBatchStream(KafkaSourceProvider.FETCH_OFFSET_RETRY_INTERVAL_MS, expected,
stream => {
assert(expected.toLong === stream.kafkaOffsetReader.offsetFetchAttemptIntervalMs)
})
}
test("continuous mode - options should be handled as case-insensitive") {
verifyFieldsInContinuousStream(KafkaSourceProvider.CONSUMER_POLL_TIMEOUT, expected, stream => {
assert(expected.toLong === stream.pollTimeoutMs)
})
verifyFieldsInContinuousStream(KafkaSourceProvider.FETCH_OFFSET_NUM_RETRY, expected, stream => {
assert(expected.toInt === stream.offsetReader.maxOffsetFetchAttempts)
})
verifyFieldsInContinuousStream(KafkaSourceProvider.FETCH_OFFSET_RETRY_INTERVAL_MS, expected,
stream => {
assert(expected.toLong === stream.offsetReader.offsetFetchAttemptIntervalMs)
})
}
private def verifyFieldsInBatch(
key: String,
value: String,
validate: (KafkaBatch) => Unit): Unit = {
buildCaseInsensitiveStringMapForUpperAndLowerKey(key -> value).foreach { options =>
val scan = getKafkaDataSourceScan(options)
val batch = scan.toBatch().asInstanceOf[KafkaBatch]
validate(batch)
}
}
private def verifyFieldsInMicroBatchStream(
key: String,
value: String,
validate: (KafkaMicroBatchStream) => Unit): Unit = {
// KafkaMicroBatchStream reads Spark conf from SparkEnv for default value
// hence we set mock SparkEnv here before creating KafkaMicroBatchStream
val sparkEnv = mock(classOf[SparkEnv])
when(sparkEnv.conf).thenReturn(new SparkConf())
SparkEnv.set(sparkEnv)
buildCaseInsensitiveStringMapForUpperAndLowerKey(key -> value).foreach { options =>
val scan = getKafkaDataSourceScan(options)
val stream = scan.toMicroBatchStream("dummy").asInstanceOf[KafkaMicroBatchStream]
validate(stream)
}
}
private def verifyFieldsInContinuousStream(
key: String,
value: String,
validate: (KafkaContinuousStream) => Unit): Unit = {
buildCaseInsensitiveStringMapForUpperAndLowerKey(key -> value).foreach { options =>
val scan = getKafkaDataSourceScan(options)
val stream = scan.toContinuousStream("dummy").asInstanceOf[KafkaContinuousStream]
assert(expectedPollTimeoutMs === getField(stream, pollTimeoutMsMethod))
validate(stream)
}
val expectedValue = 1000
buildCaseInsensitiveStringMapForUpperAndLowerKey(
KafkaSourceProvider.CONSUMER_POLL_TIMEOUT -> expectedValue.toString)
.foreach(verifyFieldsInContinuousStream(_, expectedValue))
}
private def buildCaseInsensitiveStringMapForUpperAndLowerKey(
@ -95,8 +127,4 @@ class KafkaSourceProviderSuite extends SparkFunSuite with PrivateMethodTester {
val provider = new KafkaSourceProvider()
provider.getTable(options).newScanBuilder(options).build()
}
private def getField[T](obj: AnyRef, method: PrivateMethod[T]): T = {
obj.invokePrivate(method())
}
}