[SPARK-23539][SS] Add support for Kafka headers in Structured Streaming
## What changes were proposed in this pull request?

This update adds support for Kafka headers functionality in Structured Streaming.

## How was this patch tested?

With the following unit tests:

- KafkaRelationSuite: "default starting and ending offsets with headers" (new)
- KafkaSinkSuite: "batch - write to kafka" (updated)

Closes #22282 from dongjinleekr/feature/SPARK-23539.

Lead-authored-by: Lee Dongjin <dongjin@apache.org>
Co-authored-by: Jungtaek Lim <kabhwan@gmail.com>
Signed-off-by: Sean Owen <sean.owen@databricks.com>
Parent: 77e9b58d4f
Commit: 1675d5114e
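For orientation before the diff, here is a minimal sketch of the reader-side surface this commit adds (broker address and topic name are placeholder values):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("headers-demo").getOrCreate()

// Batch read; with includeHeaders the source appends a "headers" column
// of type array<struct<key: string, value: binary>>.
val df = spark.read
  .format("kafka")
  .option("kafka.bootstrap.servers", "host1:port1")
  .option("subscribe", "topic1")
  .option("includeHeaders", "true")
  .load()

df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)", "headers").show()
```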
docs/structured-streaming-kafka-integration.md

@@ -27,6 +27,8 @@ For Scala/Java applications using SBT/Maven project definitions, link your appli
     artifactId = spark-sql-kafka-0-10_{{site.SCALA_BINARY_VERSION}}
     version = {{site.SPARK_VERSION_SHORT}}
 
+Please note that to use the headers functionality, your Kafka client version should be version 0.11.0.0 or up.
+
 For Python applications, you need to add this above library and its dependencies when deploying your
 application. See the [Deploying](#deploying) subsection below.
@@ -50,6 +52,17 @@ val df = spark
 df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
   .as[(String, String)]
 
+// Subscribe to 1 topic, with headers
+val df = spark
+  .readStream
+  .format("kafka")
+  .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
+  .option("subscribe", "topic1")
+  .option("includeHeaders", "true")
+  .load()
+df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)", "headers")
+  .as[(String, String, Array[(String, Array[Byte])])]
+
 // Subscribe to multiple topics
 val df = spark
   .readStream
@@ -84,6 +97,16 @@ Dataset<Row> df = spark
   .load();
 df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)");
 
+// Subscribe to 1 topic, with headers
+Dataset<Row> df = spark
+  .readStream()
+  .format("kafka")
+  .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
+  .option("subscribe", "topic1")
+  .option("includeHeaders", "true")
+  .load();
+df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)", "headers");
+
 // Subscribe to multiple topics
 Dataset<Row> df = spark
   .readStream()
@@ -116,6 +139,16 @@ df = spark \
   .load()
 df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
 
+# Subscribe to 1 topic, with headers
+df = spark \
+  .readStream \
+  .format("kafka") \
+  .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \
+  .option("subscribe", "topic1") \
+  .option("includeHeaders", "true") \
+  .load()
+df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)", "headers")
+
 # Subscribe to multiple topics
 df = spark \
   .readStream \
@@ -286,6 +319,10 @@ Each row in the source has the following schema:
   <td>timestampType</td>
   <td>int</td>
 </tr>
+<tr>
+  <td>headers (optional)</td>
+  <td>array</td>
+</tr>
 </table>
 
 The following options must be set for the Kafka source
@@ -425,6 +462,13 @@ The following configurations are optional:
   issues, set the Kafka consumer session timeout (by setting option "kafka.session.timeout.ms") to
   be very small. When this is set, option "groupIdPrefix" will be ignored.</td>
 </tr>
+<tr>
+  <td>includeHeaders</td>
+  <td>boolean</td>
+  <td>false</td>
+  <td>streaming and batch</td>
+  <td>Whether to include the Kafka headers in the row.</td>
+</tr>
 </table>
 
 ### Consumer Caching
@@ -522,6 +566,10 @@ The Dataframe being written to Kafka should have the following columns in schema
   <td>value (required)</td>
   <td>string or binary</td>
 </tr>
+<tr>
+  <td>headers (optional)</td>
+  <td>array</td>
+</tr>
 <tr>
   <td>topic (*optional)</td>
   <td>string</td>
@@ -559,6 +607,13 @@ The following configurations are optional:
   <td>Sets the topic that all rows will be written to in Kafka. This option overrides any
   topic column that may exist in the data.</td>
 </tr>
+<tr>
+  <td>includeHeaders</td>
+  <td>boolean</td>
+  <td>false</td>
+  <td>streaming and batch</td>
+  <td>Whether to include the Kafka headers in the row.</td>
+</tr>
 </table>
 
 ### Creating a Kafka Sink for Streaming Queries
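To complement the reader docs above, a hedged sketch of the sink side: rows carrying a `headers` column of the expected array-of-struct type are written out with those headers attached. Column construction here uses only public SQL functions; broker, topic, and header names are placeholders.

```scala
import org.apache.spark.sql.functions.{array, lit, struct}
import spark.implicits._  // assumes `spark` is an active SparkSession

// Attach one header ("traceId" -> bytes of "abc") to every record.
val out = Seq(("k1", "v1"), ("k2", "v2")).toDF("key", "value")
  .withColumn("headers", array(
    struct(lit("traceId").as("key"), lit("abc".getBytes("UTF-8")).as("value"))))

out.write
  .format("kafka")
  .option("kafka.bootstrap.servers", "host1:port1")
  .option("topic", "topic1")
  .save()
```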
KafkaBatch.scala

@@ -31,7 +31,8 @@ private[kafka010] class KafkaBatch(
     specifiedKafkaParams: Map[String, String],
     failOnDataLoss: Boolean,
     startingOffsets: KafkaOffsetRangeLimit,
-    endingOffsets: KafkaOffsetRangeLimit)
+    endingOffsets: KafkaOffsetRangeLimit,
+    includeHeaders: Boolean)
   extends Batch with Logging {
   assert(startingOffsets != LatestOffsetRangeLimit,
     "Starting offset not allowed to be set to latest offsets.")
@@ -90,7 +91,7 @@ private[kafka010] class KafkaBatch(
       KafkaSourceProvider.kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId)
     offsetRanges.map { range =>
       new KafkaBatchInputPartition(
-        range, executorKafkaParams, pollTimeoutMs, failOnDataLoss)
+        range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, includeHeaders)
     }.toArray
   }
KafkaBatchPartitionReader.scala

@@ -29,13 +29,14 @@ private[kafka010] case class KafkaBatchInputPartition(
     offsetRange: KafkaOffsetRange,
     executorKafkaParams: ju.Map[String, Object],
     pollTimeoutMs: Long,
-    failOnDataLoss: Boolean) extends InputPartition
+    failOnDataLoss: Boolean,
+    includeHeaders: Boolean) extends InputPartition
 
 private[kafka010] object KafkaBatchReaderFactory extends PartitionReaderFactory {
   override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
     val p = partition.asInstanceOf[KafkaBatchInputPartition]
     KafkaBatchPartitionReader(p.offsetRange, p.executorKafkaParams, p.pollTimeoutMs,
-      p.failOnDataLoss)
+      p.failOnDataLoss, p.includeHeaders)
   }
 }
 
@@ -44,12 +45,14 @@ private case class KafkaBatchPartitionReader(
     offsetRange: KafkaOffsetRange,
     executorKafkaParams: ju.Map[String, Object],
     pollTimeoutMs: Long,
-    failOnDataLoss: Boolean) extends PartitionReader[InternalRow] with Logging {
+    failOnDataLoss: Boolean,
+    includeHeaders: Boolean) extends PartitionReader[InternalRow] with Logging {
 
   private val consumer = KafkaDataConsumer.acquire(offsetRange.topicPartition, executorKafkaParams)
 
   private val rangeToRead = resolveRange(offsetRange)
-  private val converter = new KafkaRecordToUnsafeRowConverter
+  private val unsafeRowProjector = new KafkaRecordToRowConverter()
+    .toUnsafeRowProjector(includeHeaders)
 
   private var nextOffset = rangeToRead.fromOffset
   private var nextRow: UnsafeRow = _
@@ -58,7 +61,7 @@ private case class KafkaBatchPartitionReader(
     if (nextOffset < rangeToRead.untilOffset) {
       val record = consumer.get(nextOffset, rangeToRead.untilOffset, pollTimeoutMs, failOnDataLoss)
       if (record != null) {
-        nextRow = converter.toUnsafeRow(record)
+        nextRow = unsafeRowProjector(record)
         nextOffset = record.offset + 1
         true
       } else {
KafkaContinuousStream.scala

@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.connector.read.InputPartition
 import org.apache.spark.sql.connector.read.streaming.{ContinuousPartitionReader, ContinuousPartitionReaderFactory, ContinuousStream, Offset, PartitionOffset}
-import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
+import org.apache.spark.sql.kafka010.KafkaSourceProvider._
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
 /**
@@ -56,6 +56,7 @@ class KafkaContinuousStream(
 
   private[kafka010] val pollTimeoutMs =
     options.getLong(KafkaSourceProvider.CONSUMER_POLL_TIMEOUT, 512)
+  private val includeHeaders = options.getBoolean(INCLUDE_HEADERS, false)
 
   // Initialized when creating reader factories. If this diverges from the partitions at the latest
   // offsets, we need to reconfigure.
@@ -88,7 +89,7 @@ class KafkaContinuousStream(
     if (deletedPartitions.nonEmpty) {
       val message = if (
           offsetReader.driverKafkaParams.containsKey(ConsumerConfig.GROUP_ID_CONFIG)) {
-        s"$deletedPartitions are gone. ${KafkaSourceProvider.CUSTOM_GROUP_ID_ERROR_MESSAGE}"
+        s"$deletedPartitions are gone. ${CUSTOM_GROUP_ID_ERROR_MESSAGE}"
       } else {
         s"$deletedPartitions are gone. Some data may have been missed."
       }
@@ -102,7 +103,7 @@ class KafkaContinuousStream(
     startOffsets.toSeq.map {
       case (topicPartition, start) =>
         KafkaContinuousInputPartition(
-          topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss)
+          topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss, includeHeaders)
     }.toArray
   }
 
@@ -153,19 +154,22 @@ class KafkaContinuousStream(
  * @param pollTimeoutMs The timeout for Kafka consumer polling.
  * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets
  *                       are skipped.
+ * @param includeHeaders Flag indicating whether to include Kafka records' headers.
  */
 case class KafkaContinuousInputPartition(
-  topicPartition: TopicPartition,
-  startOffset: Long,
-  kafkaParams: ju.Map[String, Object],
-  pollTimeoutMs: Long,
-  failOnDataLoss: Boolean) extends InputPartition
+    topicPartition: TopicPartition,
+    startOffset: Long,
+    kafkaParams: ju.Map[String, Object],
+    pollTimeoutMs: Long,
+    failOnDataLoss: Boolean,
+    includeHeaders: Boolean) extends InputPartition
 
 object KafkaContinuousReaderFactory extends ContinuousPartitionReaderFactory {
   override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = {
     val p = partition.asInstanceOf[KafkaContinuousInputPartition]
     new KafkaContinuousPartitionReader(
-      p.topicPartition, p.startOffset, p.kafkaParams, p.pollTimeoutMs, p.failOnDataLoss)
+      p.topicPartition, p.startOffset, p.kafkaParams, p.pollTimeoutMs,
+      p.failOnDataLoss, p.includeHeaders)
   }
 }
 
@@ -184,9 +188,11 @@ class KafkaContinuousPartitionReader(
     startOffset: Long,
     kafkaParams: ju.Map[String, Object],
     pollTimeoutMs: Long,
-    failOnDataLoss: Boolean) extends ContinuousPartitionReader[InternalRow] {
+    failOnDataLoss: Boolean,
+    includeHeaders: Boolean) extends ContinuousPartitionReader[InternalRow] {
   private val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams)
-  private val converter = new KafkaRecordToUnsafeRowConverter
+  private val unsafeRowProjector = new KafkaRecordToRowConverter()
+    .toUnsafeRowProjector(includeHeaders)
 
   private var nextKafkaOffset = startOffset
   private var currentRecord: ConsumerRecord[Array[Byte], Array[Byte]] = _
@@ -225,7 +231,7 @@ class KafkaContinuousPartitionReader(
   }
 
   override def get(): UnsafeRow = {
-    converter.toUnsafeRow(currentRecord)
+    unsafeRowProjector(currentRecord)
   }
 
   override def getOffset(): KafkaSourcePartitionOffset = {
KafkaMicroBatchStream.scala

@@ -29,7 +29,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory}
 import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset}
 import org.apache.spark.sql.execution.streaming.sources.RateControlMicroBatchStream
-import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
+import org.apache.spark.sql.kafka010.KafkaSourceProvider._
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.util.UninterruptibleThread
 
@@ -64,6 +64,8 @@ private[kafka010] class KafkaMicroBatchStream(
   private[kafka010] val maxOffsetsPerTrigger = Option(options.get(
     KafkaSourceProvider.MAX_OFFSET_PER_TRIGGER)).map(_.toLong)
 
+  private val includeHeaders = options.getBoolean(INCLUDE_HEADERS, false)
+
   private val rangeCalculator = KafkaOffsetRangeCalculator(options)
 
   private var endPartitionOffsets: KafkaSourceOffset = _
@@ -112,7 +114,7 @@ private[kafka010] class KafkaMicroBatchStream(
     if (deletedPartitions.nonEmpty) {
       val message =
         if (kafkaOffsetReader.driverKafkaParams.containsKey(ConsumerConfig.GROUP_ID_CONFIG)) {
-          s"$deletedPartitions are gone. ${KafkaSourceProvider.CUSTOM_GROUP_ID_ERROR_MESSAGE}"
+          s"$deletedPartitions are gone. ${CUSTOM_GROUP_ID_ERROR_MESSAGE}"
         } else {
           s"$deletedPartitions are gone. Some data may have been missed."
         }
@@ -146,7 +148,8 @@ private[kafka010] class KafkaMicroBatchStream(
 
     // Generate factories based on the offset ranges
     offsetRanges.map { range =>
-      KafkaBatchInputPartition(range, executorKafkaParams, pollTimeoutMs, failOnDataLoss)
+      KafkaBatchInputPartition(range, executorKafkaParams, pollTimeoutMs,
+        failOnDataLoss, includeHeaders)
     }.toArray
   }
KafkaOffsetReader.scala

@@ -31,7 +31,6 @@ import org.apache.kafka.common.TopicPartition
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
-import org.apache.spark.sql.types._
 import org.apache.spark.util.{ThreadUtils, UninterruptibleThread}
 
 /**
@@ -421,16 +420,3 @@ private[kafka010] class KafkaOffsetReader(
       _consumer = null // will automatically get reinitialized again
     }
   }
-
-private[kafka010] object KafkaOffsetReader {
-
-  def kafkaSchema: StructType = StructType(Seq(
-    StructField("key", BinaryType),
-    StructField("value", BinaryType),
-    StructField("topic", StringType),
-    StructField("partition", IntegerType),
-    StructField("offset", LongType),
-    StructField("timestamp", TimestampType),
-    StructField("timestampType", IntegerType)
-  ))
-}
KafkaRecordToRowConverter.scala (new file)

@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import java.sql.Timestamp
+
+import scala.collection.JavaConverters._
+
+import org.apache.kafka.clients.consumer.ConsumerRecord
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+/** A simple class for converting Kafka ConsumerRecord to InternalRow/UnsafeRow */
+private[kafka010] class KafkaRecordToRowConverter {
+  import KafkaRecordToRowConverter._
+
+  private val toUnsafeRowWithoutHeaders = UnsafeProjection.create(schemaWithoutHeaders)
+  private val toUnsafeRowWithHeaders = UnsafeProjection.create(schemaWithHeaders)
+
+  val toInternalRowWithoutHeaders: Record => InternalRow =
+    (cr: Record) => InternalRow(
+      cr.key, cr.value, UTF8String.fromString(cr.topic), cr.partition, cr.offset,
+      DateTimeUtils.fromJavaTimestamp(new Timestamp(cr.timestamp)), cr.timestampType.id
+    )
+
+  val toInternalRowWithHeaders: Record => InternalRow =
+    (cr: Record) => InternalRow(
+      cr.key, cr.value, UTF8String.fromString(cr.topic), cr.partition, cr.offset,
+      DateTimeUtils.fromJavaTimestamp(new Timestamp(cr.timestamp)), cr.timestampType.id,
+      if (cr.headers.iterator().hasNext) {
+        new GenericArrayData(cr.headers.iterator().asScala
+          .map(header =>
+            InternalRow(UTF8String.fromString(header.key()), header.value())
+          ).toArray)
+      } else {
+        null
+      }
+    )
+
+  def toUnsafeRowWithoutHeadersProjector: Record => UnsafeRow =
+    (cr: Record) => toUnsafeRowWithoutHeaders(toInternalRowWithoutHeaders(cr))
+
+  def toUnsafeRowWithHeadersProjector: Record => UnsafeRow =
+    (cr: Record) => toUnsafeRowWithHeaders(toInternalRowWithHeaders(cr))
+
+  def toUnsafeRowProjector(includeHeaders: Boolean): Record => UnsafeRow = {
+    if (includeHeaders) toUnsafeRowWithHeadersProjector else toUnsafeRowWithoutHeadersProjector
+  }
+}
+
+private[kafka010] object KafkaRecordToRowConverter {
+  type Record = ConsumerRecord[Array[Byte], Array[Byte]]
+
+  val headersType = ArrayType(StructType(Array(
+    StructField("key", StringType),
+    StructField("value", BinaryType))))
+
+  private val schemaWithoutHeaders = new StructType(Array(
+    StructField("key", BinaryType),
+    StructField("value", BinaryType),
+    StructField("topic", StringType),
+    StructField("partition", IntegerType),
+    StructField("offset", LongType),
+    StructField("timestamp", TimestampType),
+    StructField("timestampType", IntegerType)
+  ))
+
+  private val schemaWithHeaders =
+    new StructType(schemaWithoutHeaders.fields :+ StructField("headers", headersType))
+
+  def kafkaSchema(includeHeaders: Boolean): StructType = {
+    if (includeHeaders) schemaWithHeaders else schemaWithoutHeaders
+  }
+}
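The headers schema this new converter introduces can be reproduced with the public types API; a standalone sketch:

```scala
import org.apache.spark.sql.types._

// Same shape as KafkaRecordToRowConverter.headersType above:
// an array of (key: string, value: binary) structs.
val headersType = ArrayType(StructType(Seq(
  StructField("key", StringType),
  StructField("value", BinaryType))))

println(headersType.catalogString)  // array<struct<key:string,value:binary>>
```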
KafkaRecordToUnsafeRowConverter.scala (deleted)

@@ -1,54 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.kafka010
-
-import org.apache.kafka.clients.consumer.ConsumerRecord
-
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.unsafe.types.UTF8String
-
-/** A simple class for converting Kafka ConsumerRecord to UnsafeRow */
-private[kafka010] class KafkaRecordToUnsafeRowConverter {
-  private val rowWriter = new UnsafeRowWriter(7)
-
-  def toUnsafeRow(record: ConsumerRecord[Array[Byte], Array[Byte]]): UnsafeRow = {
-    rowWriter.reset()
-    rowWriter.zeroOutNullBytes()
-
-    if (record.key == null) {
-      rowWriter.setNullAt(0)
-    } else {
-      rowWriter.write(0, record.key)
-    }
-    if (record.value == null) {
-      rowWriter.setNullAt(1)
-    } else {
-      rowWriter.write(1, record.value)
-    }
-    rowWriter.write(2, UTF8String.fromString(record.topic))
-    rowWriter.write(3, record.partition)
-    rowWriter.write(4, record.offset)
-    rowWriter.write(
-      5,
-      DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(record.timestamp)))
-    rowWriter.write(6, record.timestampType.id)
-    rowWriter.getRow()
-  }
-}
KafkaRelation.scala

@@ -24,10 +24,9 @@ import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{Row, SQLContext}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.sources.{BaseRelation, TableScan}
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.unsafe.types.UTF8String
 
 
 private[kafka010] class KafkaRelation(
@@ -36,6 +35,7 @@ private[kafka010] class KafkaRelation(
     sourceOptions: CaseInsensitiveMap[String],
     specifiedKafkaParams: Map[String, String],
     failOnDataLoss: Boolean,
+    includeHeaders: Boolean,
     startingOffsets: KafkaOffsetRangeLimit,
     endingOffsets: KafkaOffsetRangeLimit)
   extends BaseRelation with TableScan with Logging {
@@ -49,7 +49,9 @@ private[kafka010] class KafkaRelation(
     (sqlContext.sparkContext.conf.get(NETWORK_TIMEOUT) * 1000L).toString
   ).toLong
 
-  override def schema: StructType = KafkaOffsetReader.kafkaSchema
+  private val converter = new KafkaRecordToRowConverter()
+
+  override def schema: StructType = KafkaRecordToRowConverter.kafkaSchema(includeHeaders)
 
   override def buildScan(): RDD[Row] = {
     // Each running query should use its own group id. Otherwise, the query may be only assigned
@@ -100,18 +102,14 @@ private[kafka010] class KafkaRelation(
     // Create an RDD that reads from Kafka and get the (key, value) pair as byte arrays.
     val executorKafkaParams =
       KafkaSourceProvider.kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId)
+    val toInternalRow = if (includeHeaders) {
+      converter.toInternalRowWithHeaders
+    } else {
+      converter.toInternalRowWithoutHeaders
+    }
     val rdd = new KafkaSourceRDD(
       sqlContext.sparkContext, executorKafkaParams, offsetRanges,
-      pollTimeoutMs, failOnDataLoss).map { cr =>
-      InternalRow(
-        cr.key,
-        cr.value,
-        UTF8String.fromString(cr.topic),
-        cr.partition,
-        cr.offset,
-        DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(cr.timestamp)),
-        cr.timestampType.id)
-    }
+      pollTimeoutMs, failOnDataLoss).map(toInternalRow)
     sqlContext.internalCreateDataFrame(rdd.setName("kafka"), schema).rdd
   }
KafkaSource.scala

@@ -31,12 +31,11 @@ import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT
 import org.apache.spark.scheduler.ExecutorCacheTaskLocation
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.kafka010.KafkaSource._
-import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
+import org.apache.spark.sql.kafka010.KafkaSourceProvider._
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
 
 /**
  * A [[Source]] that reads data from Kafka using the following design.
@@ -84,13 +83,15 @@ private[kafka010] class KafkaSource(
 
   private val sc = sqlContext.sparkContext
 
-  private val pollTimeoutMs = sourceOptions.getOrElse(
-    KafkaSourceProvider.CONSUMER_POLL_TIMEOUT,
-    (sc.conf.get(NETWORK_TIMEOUT) * 1000L).toString
-  ).toLong
+  private val pollTimeoutMs =
+    sourceOptions.getOrElse(CONSUMER_POLL_TIMEOUT, (sc.conf.get(NETWORK_TIMEOUT) * 1000L).toString)
+      .toLong
 
   private val maxOffsetsPerTrigger =
-    sourceOptions.get(KafkaSourceProvider.MAX_OFFSET_PER_TRIGGER).map(_.toLong)
+    sourceOptions.get(MAX_OFFSET_PER_TRIGGER).map(_.toLong)
+
+  private val includeHeaders =
+    sourceOptions.getOrElse(INCLUDE_HEADERS, "false").toBoolean
 
   /**
    * Lazily initialize `initialPartitionOffsets` to make sure that `KafkaConsumer.poll` is only
@@ -113,7 +114,9 @@ private[kafka010] class KafkaSource(
 
   private var currentPartitionOffsets: Option[Map[TopicPartition, Long]] = None
 
-  override def schema: StructType = KafkaOffsetReader.kafkaSchema
+  private val converter = new KafkaRecordToRowConverter()
+
+  override def schema: StructType = KafkaRecordToRowConverter.kafkaSchema(includeHeaders)
 
   /** Returns the maximum available offset for this source. */
   override def getOffset: Option[Offset] = {
@@ -223,7 +226,7 @@ private[kafka010] class KafkaSource(
     val deletedPartitions = fromPartitionOffsets.keySet.diff(untilPartitionOffsets.keySet)
     if (deletedPartitions.nonEmpty) {
       val message = if (kafkaReader.driverKafkaParams.containsKey(ConsumerConfig.GROUP_ID_CONFIG)) {
-        s"$deletedPartitions are gone. ${KafkaSourceProvider.CUSTOM_GROUP_ID_ERROR_MESSAGE}"
+        s"$deletedPartitions are gone. ${CUSTOM_GROUP_ID_ERROR_MESSAGE}"
       } else {
         s"$deletedPartitions are gone. Some data may have been missed."
       }
@@ -267,16 +270,14 @@ private[kafka010] class KafkaSource(
     }.toArray
 
     // Create an RDD that reads from Kafka and get the (key, value) pair as byte arrays.
-    val rdd = new KafkaSourceRDD(
-      sc, executorKafkaParams, offsetRanges, pollTimeoutMs, failOnDataLoss).map { cr =>
-      InternalRow(
-        cr.key,
-        cr.value,
-        UTF8String.fromString(cr.topic),
-        cr.partition,
-        cr.offset,
-        DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(cr.timestamp)),
-        cr.timestampType.id)
-    }
+    val rdd = if (includeHeaders) {
+      new KafkaSourceRDD(
+        sc, executorKafkaParams, offsetRanges, pollTimeoutMs, failOnDataLoss)
+        .map(converter.toInternalRowWithHeaders)
+    } else {
+      new KafkaSourceRDD(
+        sc, executorKafkaParams, offsetRanges, pollTimeoutMs, failOnDataLoss)
+        .map(converter.toInternalRowWithoutHeaders)
+    }
 
     logInfo("GetBatch generating RDD of offset range: " +
KafkaSourceProvider.scala

@@ -69,7 +69,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
     val caseInsensitiveParameters = CaseInsensitiveMap(parameters)
     validateStreamOptions(caseInsensitiveParameters)
     require(schema.isEmpty, "Kafka source has a fixed schema and cannot be set with a custom one")
-    (shortName(), KafkaOffsetReader.kafkaSchema)
+    val includeHeaders = caseInsensitiveParameters.getOrElse(INCLUDE_HEADERS, "false").toBoolean
+    (shortName(), KafkaRecordToRowConverter.kafkaSchema(includeHeaders))
   }
 
   override def createSource(
@@ -107,7 +108,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
   }
 
   override def getTable(options: CaseInsensitiveStringMap): KafkaTable = {
-    new KafkaTable
+    val includeHeaders = options.getBoolean(INCLUDE_HEADERS, false)
+    new KafkaTable(includeHeaders)
   }
 
   /**
@@ -131,12 +133,15 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
       caseInsensitiveParameters, ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
     assert(endingRelationOffsets != EarliestOffsetRangeLimit)
 
+    val includeHeaders = caseInsensitiveParameters.getOrElse(INCLUDE_HEADERS, "false").toBoolean
+
     new KafkaRelation(
       sqlContext,
       strategy(caseInsensitiveParameters),
       sourceOptions = caseInsensitiveParameters,
       specifiedKafkaParams = specifiedKafkaParams,
       failOnDataLoss = failOnDataLoss(caseInsensitiveParameters),
+      includeHeaders = includeHeaders,
       startingOffsets = startingRelationOffsets,
       endingOffsets = endingRelationOffsets)
   }
@@ -359,11 +364,11 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
     }
   }
 
-  class KafkaTable extends Table with SupportsRead with SupportsWrite {
+  class KafkaTable(includeHeaders: Boolean) extends Table with SupportsRead with SupportsWrite {
 
     override def name(): String = "KafkaTable"
 
-    override def schema(): StructType = KafkaOffsetReader.kafkaSchema
+    override def schema(): StructType = KafkaRecordToRowConverter.kafkaSchema(includeHeaders)
 
     override def capabilities(): ju.Set[TableCapability] = {
       import TableCapability._
@@ -403,8 +408,11 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
   }
 
   class KafkaScan(options: CaseInsensitiveStringMap) extends Scan {
+    val includeHeaders = options.getBoolean(INCLUDE_HEADERS, false)
 
-    override def readSchema(): StructType = KafkaOffsetReader.kafkaSchema
+    override def readSchema(): StructType = {
+      KafkaRecordToRowConverter.kafkaSchema(includeHeaders)
+    }
 
     override def toBatch(): Batch = {
       val caseInsensitiveOptions = CaseInsensitiveMap(options.asScala.toMap)
@@ -423,7 +431,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
         specifiedKafkaParams,
         failOnDataLoss(caseInsensitiveOptions),
         startingRelationOffsets,
-        endingRelationOffsets)
+        endingRelationOffsets,
+        includeHeaders)
     }
 
     override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = {
@@ -498,6 +507,7 @@ private[kafka010] object KafkaSourceProvider extends Logging {
   private[kafka010] val FETCH_OFFSET_RETRY_INTERVAL_MS = "fetchoffset.retryintervalms"
   private[kafka010] val CONSUMER_POLL_TIMEOUT = "kafkaconsumer.polltimeoutms"
   private val GROUP_ID_PREFIX = "groupidprefix"
+  private[kafka010] val INCLUDE_HEADERS = "includeheaders"
 
   val TOPIC_OPTION_KEY = "topic"
KafkaWriteTask.scala (KafkaRowWriter)

@@ -19,9 +19,13 @@ package org.apache.spark.sql.kafka010
 
 import java.{util => ju}
 
-import org.apache.kafka.clients.producer.{Callback, KafkaProducer, ProducerRecord, RecordMetadata}
+import scala.collection.JavaConverters._
 
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.kafka.clients.producer.{Callback, KafkaProducer, ProducerRecord, RecordMetadata}
+import org.apache.kafka.common.header.Header
+import org.apache.kafka.common.header.internals.RecordHeader
+
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection}
 import org.apache.spark.sql.types.{BinaryType, StringType}
 
@@ -88,7 +92,17 @@ private[kafka010] abstract class KafkaRowWriter(
       throw new NullPointerException(s"null topic present in the data. Use the " +
         s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.")
     }
-    val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value)
+    val record = if (projectedRow.isNullAt(3)) {
+      new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, null, key, value)
+    } else {
+      val headerArray = projectedRow.getArray(3)
+      val headers = (0 until headerArray.numElements()).map { i =>
+        val struct = headerArray.getStruct(i, 2)
+        new RecordHeader(struct.getUTF8String(0).toString, struct.getBinary(1))
+          .asInstanceOf[Header]
+      }
+      new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, null, key, value, headers.asJava)
+    }
     producer.send(record, callback)
   }
 
@@ -131,9 +145,26 @@ private[kafka010] abstract class KafkaRowWriter(
         throw new IllegalStateException(s"${KafkaWriter.VALUE_ATTRIBUTE_NAME} " +
           s"attribute unsupported type ${t.catalogString}")
     }
+    val headersExpression = inputSchema
+      .find(_.name == KafkaWriter.HEADERS_ATTRIBUTE_NAME).getOrElse(
+        Literal(CatalystTypeConverters.convertToCatalyst(null),
+          KafkaRecordToRowConverter.headersType)
+      )
+    headersExpression.dataType match {
+      case KafkaRecordToRowConverter.headersType => // good
+      case t =>
+        throw new IllegalStateException(s"${KafkaWriter.HEADERS_ATTRIBUTE_NAME} " +
+          s"attribute unsupported type ${t.catalogString}")
+    }
     UnsafeProjection.create(
-      Seq(topicExpression, Cast(keyExpression, BinaryType),
-        Cast(valueExpression, BinaryType)), inputSchema)
+      Seq(
+        topicExpression,
+        Cast(keyExpression, BinaryType),
+        Cast(valueExpression, BinaryType),
+        headersExpression
+      ),
+      inputSchema
+    )
   }
 }
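The writer change above ultimately maps the headers array onto the Kafka client API. A standalone sketch of that mapping with plain Kafka classes, mirroring how KafkaRowWriter builds its records (topic, key, value, and header contents here are placeholders):

```scala
import scala.collection.JavaConverters._

import org.apache.kafka.clients.producer.ProducerRecord
import org.apache.kafka.common.header.Header
import org.apache.kafka.common.header.internals.RecordHeader

// One RecordHeader per (key, value) pair from the row's headers array.
val headers: java.lang.Iterable[Header] = Seq(
  new RecordHeader("traceId", "abc".getBytes("UTF-8")).asInstanceOf[Header]).asJava

// ProducerRecord(topic, partition, key, value, headers); partition left null.
val record = new ProducerRecord[Array[Byte], Array[Byte]](
  "topic1", null, "k".getBytes("UTF-8"), "v".getBytes("UTF-8"), headers)
```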
KafkaWriter.scala

@@ -21,9 +21,10 @@ import java.{util => ju}
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{AnalysisException, SparkSession}
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.execution.{QueryExecution, SQLExecution}
-import org.apache.spark.sql.types.{BinaryType, StringType}
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.types.{BinaryType, MapType, StringType}
 import org.apache.spark.util.Utils
 
 /**
@@ -39,6 +40,7 @@ private[kafka010] object KafkaWriter extends Logging {
   val TOPIC_ATTRIBUTE_NAME: String = "topic"
   val KEY_ATTRIBUTE_NAME: String = "key"
   val VALUE_ATTRIBUTE_NAME: String = "value"
+  val HEADERS_ATTRIBUTE_NAME: String = "headers"
 
   override def toString: String = "KafkaWriter"
 
@@ -75,6 +77,15 @@ private[kafka010] object KafkaWriter extends Logging {
         throw new AnalysisException(s"$VALUE_ATTRIBUTE_NAME attribute type " +
           s"must be a ${StringType.catalogString} or ${BinaryType.catalogString}")
     }
+    schema.find(_.name == HEADERS_ATTRIBUTE_NAME).getOrElse(
+      Literal(CatalystTypeConverters.convertToCatalyst(null),
+        KafkaRecordToRowConverter.headersType)
+    ).dataType match {
+      case KafkaRecordToRowConverter.headersType => // good
+      case _ =>
+        throw new AnalysisException(s"$HEADERS_ATTRIBUTE_NAME attribute type " +
+          s"must be a ${KafkaRecordToRowConverter.headersType.catalogString}")
+    }
   }
 
   def write(
KafkaDataConsumerSuite.scala

@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.kafka010
 
+import java.nio.charset.StandardCharsets
 import java.util.concurrent.{Executors, TimeUnit}
 
 import scala.collection.JavaConverters._
@@ -91,7 +92,7 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester
   test("new KafkaDataConsumer instance in case of Task retry") {
     try {
       val kafkaParams = getKafkaParams()
-      val key = new CacheKey(groupId, topicPartition)
+      val key = CacheKey(groupId, topicPartition)
 
       val context1 = new TaskContextImpl(0, 0, 0, 0, 0, null, null, null)
       TaskContext.setTaskContext(context1)
@@ -137,7 +138,8 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester
   }
 
   test("SPARK-23623: concurrent use of KafkaDataConsumer") {
-    val data: immutable.IndexedSeq[String] = prepareTestTopicHavingTestMessages(topic)
+    val data: immutable.IndexedSeq[(String, Seq[(String, Array[Byte])])] =
+      prepareTestTopicHavingTestMessages(topic)
 
     val topicPartition = new TopicPartition(topic, 0)
     val kafkaParams = getKafkaParams()
@@ -157,10 +159,22 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester
       try {
         val range = consumer.getAvailableOffsetRange()
         val rcvd = range.earliest until range.latest map { offset =>
-          val bytes = consumer.get(offset, Long.MaxValue, 10000, failOnDataLoss = false).value()
-          new String(bytes)
+          val record = consumer.get(offset, Long.MaxValue, 10000, failOnDataLoss = false)
+          val value = new String(record.value(), StandardCharsets.UTF_8)
+          val headers = record.headers().toArray.map(header => (header.key(), header.value())).toSeq
+          (value, headers)
         }
-        assert(rcvd == data)
+        data.zip(rcvd).foreach { case (expected, actual) =>
+          // value
+          assert(expected._1 === actual._1)
+          // headers
+          expected._2.zip(actual._2).foreach { case (l, r) =>
+            // header key
+            assert(l._1 === r._1)
+            // header value
+            assert(l._2 === r._2)
+          }
+        }
       } catch {
         case e: Throwable =>
          error = e
@@ -307,9 +321,9 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester
   }
 
   private def prepareTestTopicHavingTestMessages(topic: String) = {
-    val data = (1 to 1000).map(_.toString)
+    val data = (1 to 1000).map(i => (i.toString, Seq[(String, Array[Byte])]()))
     testUtils.createTopic(topic, 1)
-    testUtils.sendMessages(topic, data.toArray)
+    testUtils.sendMessages(topic, data.toArray, None)
     data
   }
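On the consumer side, the suite reads headers straight off the Kafka client record; a standalone helper in the same spirit:

```scala
import scala.collection.JavaConverters._

import org.apache.kafka.clients.consumer.ConsumerRecord

// Collect (key, value-bytes) pairs from a record's headers,
// as the updated assertions above do.
def headerPairs(record: ConsumerRecord[String, String]): Seq[(String, Array[Byte])] =
  record.headers().asScala.map(h => (h.key(), h.value())).toSeq
```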
KafkaRelationSuite.scala

@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.kafka010
 
+import java.nio.charset.StandardCharsets.UTF_8
 import java.util.Locale
 import java.util.concurrent.atomic.AtomicInteger
 
@@ -70,7 +71,8 @@ abstract class KafkaRelationSuiteBase extends QueryTest with SharedSparkSession
   protected def createDF(
       topic: String,
       withOptions: Map[String, String] = Map.empty[String, String],
-      brokerAddress: Option[String] = None) = {
+      brokerAddress: Option[String] = None,
+      includeHeaders: Boolean = false) = {
     val df = spark
       .read
       .format("kafka")
@@ -80,7 +82,13 @@ abstract class KafkaRelationSuiteBase extends QueryTest with SharedSparkSession
     withOptions.foreach {
       case (key, value) => df.option(key, value)
     }
-    df.load().selectExpr("CAST(value AS STRING)")
+    if (includeHeaders) {
+      df.option("includeHeaders", "true")
+      df.load()
+        .selectExpr("CAST(value AS STRING)", "headers")
+    } else {
+      df.load().selectExpr("CAST(value AS STRING)")
+    }
   }
 
   test("explicit earliest to latest offsets") {
@@ -147,6 +155,27 @@ abstract class KafkaRelationSuiteBase extends QueryTest with SharedSparkSession
     checkAnswer(df, (0 to 30).map(_.toString).toDF)
   }
 
+  test("default starting and ending offsets with headers") {
+    val topic = newTopic()
+    testUtils.createTopic(topic, partitions = 3)
+    testUtils.sendMessage(
+      topic, ("1", Seq()), Some(0)
+    )
+    testUtils.sendMessage(
+      topic, ("2", Seq(("a", "b".getBytes(UTF_8)), ("c", "d".getBytes(UTF_8)))), Some(1)
+    )
+    testUtils.sendMessage(
+      topic, ("3", Seq(("e", "f".getBytes(UTF_8)), ("e", "g".getBytes(UTF_8)))), Some(2)
+    )
+
+    // Implicit offset values, should default to earliest and latest
+    val df = createDF(topic, includeHeaders = true)
+    // Test that we default to "earliest" and "latest"
+    checkAnswer(df, Seq(("1", null),
+      ("2", Seq(("a", "b".getBytes(UTF_8)), ("c", "d".getBytes(UTF_8)))),
+      ("3", Seq(("e", "f".getBytes(UTF_8)), ("e", "g".getBytes(UTF_8))))).toDF)
+  }
+
   test("reuse same dataframe in query") {
     // This test ensures that we do not cache the Kafka Consumer in KafkaRelation
     val topic = newTopic()
KafkaSinkSuite.scala

@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.kafka010
 
+import java.nio.charset.StandardCharsets.UTF_8
 import java.util.Locale
 import java.util.concurrent.atomic.AtomicInteger
 
@@ -32,7 +33,7 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming._
 import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.{BinaryType, DataType}
+import org.apache.spark.sql.types.{BinaryType, DataType, StringType, StructField, StructType}
 
 abstract class KafkaSinkSuiteBase extends QueryTest with SharedSparkSession with KafkaTest {
   protected var testUtils: KafkaTestUtils = _
@@ -59,13 +60,14 @@ abstract class KafkaSinkSuiteBase extends QueryTest with SharedSparkSession with
 
   protected def newTopic(): String = s"topic-${topicId.getAndIncrement()}"
 
-  protected def createKafkaReader(topic: String): DataFrame = {
+  protected def createKafkaReader(topic: String, includeHeaders: Boolean = false): DataFrame = {
     spark.read
       .format("kafka")
      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
       .option("startingOffsets", "earliest")
       .option("endingOffsets", "latest")
       .option("subscribe", topic)
+      .option("includeHeaders", includeHeaders.toString)
       .load()
   }
 }
@@ -368,15 +370,51 @@ abstract class KafkaSinkBatchSuiteBase extends KafkaSinkSuiteBase {
   test("batch - write to kafka") {
     val topic = newTopic()
     testUtils.createTopic(topic)
-    val df = Seq("1", "2", "3", "4", "5").map(v => (topic, v)).toDF("topic", "value")
+    val data = Seq(
+      Row(topic, "1", Seq(
+        Row("a", "b".getBytes(UTF_8))
+      )),
+      Row(topic, "2", Seq(
+        Row("c", "d".getBytes(UTF_8)),
+        Row("e", "f".getBytes(UTF_8))
+      )),
+      Row(topic, "3", Seq(
+        Row("g", "h".getBytes(UTF_8)),
+        Row("g", "i".getBytes(UTF_8))
+      )),
+      Row(topic, "4", null),
+      Row(topic, "5", Seq(
+        Row("j", "k".getBytes(UTF_8)),
+        Row("j", "l".getBytes(UTF_8)),
+        Row("m", "n".getBytes(UTF_8))
+      ))
+    )
+
+    val df = spark.createDataFrame(
+      spark.sparkContext.parallelize(data),
+      StructType(Seq(StructField("topic", StringType), StructField("value", StringType),
+        StructField("headers", KafkaRecordToRowConverter.headersType)))
+    )
+
     df.write
       .format("kafka")
       .option("kafka.bootstrap.servers", testUtils.brokerAddress)
       .option("topic", topic)
       .save()
     checkAnswer(
-      createKafkaReader(topic).selectExpr("CAST(value as STRING) value"),
-      Row("1") :: Row("2") :: Row("3") :: Row("4") :: Row("5") :: Nil)
+      createKafkaReader(topic, includeHeaders = true).selectExpr(
+        "CAST(value as STRING) value", "headers"
+      ),
+      Row("1", Seq(Row("a", "b".getBytes(UTF_8)))) ::
+        Row("2", Seq(Row("c", "d".getBytes(UTF_8)), Row("e", "f".getBytes(UTF_8)))) ::
+        Row("3", Seq(Row("g", "h".getBytes(UTF_8)), Row("g", "i".getBytes(UTF_8)))) ::
+        Row("4", null) ::
+        Row("5", Seq(
+          Row("j", "k".getBytes(UTF_8)),
+          Row("j", "l".getBytes(UTF_8)),
+          Row("m", "n".getBytes(UTF_8)))) ::
+        Nil
+    )
   }
 
   test("batch - null topic field value, and no topic option") {
KafkaTestUtils.scala

@@ -41,6 +41,8 @@ import org.apache.kafka.clients.consumer.KafkaConsumer
 import org.apache.kafka.clients.producer._
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.config.SaslConfigs
+import org.apache.kafka.common.header.Header
+import org.apache.kafka.common.header.internals.RecordHeader
 import org.apache.kafka.common.network.ListenerName
 import org.apache.kafka.common.security.auth.SecurityProtocol.{PLAINTEXT, SASL_PLAINTEXT}
 import org.apache.kafka.common.serialization.{StringDeserializer, StringSerializer}
@@ -369,17 +371,36 @@ class KafkaTestUtils(
       topic: String,
       messages: Array[String],
       partition: Option[Int]): Seq[(String, RecordMetadata)] = {
+    sendMessages(topic, messages.map(m => (m, Seq())), partition)
+  }
+
+  /** Send record to the Kafka broker with headers using specified partition */
+  def sendMessage(topic: String,
+      record: (String, Seq[(String, Array[Byte])]),
+      partition: Option[Int]): Seq[(String, RecordMetadata)] = {
+    sendMessages(topic, Array(record).toSeq, partition)
+  }
+
+  /** Send the array of records to the Kafka broker with headers using specified partition */
+  def sendMessages(topic: String,
+      records: Seq[(String, Seq[(String, Array[Byte])])],
+      partition: Option[Int]): Seq[(String, RecordMetadata)] = {
     producer = new KafkaProducer[String, String](producerConfiguration)
     val offsets = try {
-      messages.map { m =>
-        val record = partition match {
-          case Some(p) => new ProducerRecord[String, String](topic, p, null, m)
-          case None => new ProducerRecord[String, String](topic, m)
+      records.map { case (value, header) =>
+        val headers = header.map { case (k, v) =>
+          new RecordHeader(k, v).asInstanceOf[Header]
         }
-        val metadata =
-          producer.send(record).get(10, TimeUnit.SECONDS)
-        logInfo(s"\tSent $m to partition ${metadata.partition}, offset ${metadata.offset}")
-        (m, metadata)
+        val record = partition match {
+          case Some(p) =>
+            new ProducerRecord[String, String](topic, p, null, value, headers.asJava)
+          case None =>
+            new ProducerRecord[String, String](topic, null, null, value, headers.asJava)
+        }
+        val metadata = producer.send(record).get(10, TimeUnit.SECONDS)
+        logInfo(s"\tSent ($value, $header) to partition ${metadata.partition}, " +
+          s"offset ${metadata.offset}")
+        (value, metadata)
       }
     } finally {
       if (producer != null) {