[SPARK-23539][SS] Add support for Kafka headers in Structured Streaming

## What changes were proposed in this pull request?

This update adds support for the Kafka headers functionality in Structured Streaming: the Kafka source can expose each record's headers as an optional `headers` column (enabled via the `includeHeaders` option), and the Kafka sink can write an optional `headers` column back to Kafka.
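For context, a minimal usage sketch of the new option and column follows; the bootstrap servers, topic names, and the `spark` session here are placeholders rather than part of the patch:

```scala
// Batch read: with includeHeaders set, the source attaches each record's headers
// as an extra "headers" column of type array<struct<key: string, value: binary>>.
val df = spark.read
  .format("kafka")
  .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
  .option("subscribe", "topic1")
  .option("includeHeaders", "true")
  .load()
  .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)", "headers")

// Batch write: a "headers" column of the same array-of-struct type is optional;
// when it is present, the sink copies it onto the produced Kafka records.
df.selectExpr("CAST(value AS STRING) AS value", "headers")
  .write
  .format("kafka")
  .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
  .option("topic", "topic2")
  .save()
```

The same `includeHeaders` option applies to `readStream`, and the sink picks the `headers` column up by name, as the documentation changes below describe.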

## How was this patch tested?

With the following unit tests:

- KafkaRelationSuite: "default starting and ending offsets with headers" (new)
- KafkaSinkSuite: "batch - write to kafka" (updated)

Closes #22282 from dongjinleekr/feature/SPARK-23539.

Lead-authored-by: Lee Dongjin <dongjin@apache.org>
Co-authored-by: Jungtaek Lim <kabhwan@gmail.com>
Signed-off-by: Sean Owen <sean.owen@databricks.com>
Lee Dongjin 2019-09-13 12:31:28 -05:00 committed by Sean Owen
parent 77e9b58d4f
commit 1675d5114e
17 changed files with 404 additions and 158 deletions
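For reference, the optional `headers` column exposed by the source and accepted by the sink carries the following Spark SQL type; this is a sketch restating the `headersType` value from the new `KafkaRecordToRowConverter` in the diff below:

```scala
import org.apache.spark.sql.types._

// headers: array<struct<key: string, value: binary>>; the column is null
// for records that carry no headers.
val headersType = ArrayType(StructType(Array(
  StructField("key", StringType),
  StructField("value", BinaryType))))
```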


@@ -27,6 +27,8 @@ For Scala/Java applications using SBT/Maven project definitions, link your appli
 artifactId = spark-sql-kafka-0-10_{{site.SCALA_BINARY_VERSION}}
 version = {{site.SPARK_VERSION_SHORT}}
+Please note that to use the headers functionality, your Kafka client version should be version 0.11.0.0 or up.
 For Python applications, you need to add this above library and its dependencies when deploying your
 application. See the [Deploying](#deploying) subsection below.
@@ -50,6 +52,17 @@ val df = spark
 df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
 .as[(String, String)]
+// Subscribe to 1 topic, with headers
+val df = spark
+.readStream
+.format("kafka")
+.option("kafka.bootstrap.servers", "host1:port1,host2:port2")
+.option("subscribe", "topic1")
+.option("includeHeaders", "true")
+.load()
+df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)", "headers")
+.as[(String, String, Array[(String, Array[Byte])])]
 // Subscribe to multiple topics
 val df = spark
 .readStream
@@ -84,6 +97,16 @@ Dataset<Row> df = spark
 .load();
 df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)");
+// Subscribe to 1 topic, with headers
+Dataset<Row> df = spark
+.readStream()
+.format("kafka")
+.option("kafka.bootstrap.servers", "host1:port1,host2:port2")
+.option("subscribe", "topic1")
+.option("includeHeaders", "true")
+.load();
+df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)", "headers");
 // Subscribe to multiple topics
 Dataset<Row> df = spark
 .readStream()
@@ -116,6 +139,16 @@ df = spark \
 .load()
 df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+# Subscribe to 1 topic, with headers
+df = spark \
+.readStream \
+.format("kafka") \
+.option("kafka.bootstrap.servers", "host1:port1,host2:port2") \
+.option("subscribe", "topic1") \
+.option("includeHeaders", "true") \
+.load()
+df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)", "headers")
 # Subscribe to multiple topics
 df = spark \
 .readStream \
@@ -286,6 +319,10 @@ Each row in the source has the following schema:
 <td>timestampType</td>
 <td>int</td>
 </tr>
+<tr>
+<td>headers (optional)</td>
+<td>array</td>
+</tr>
 </table>
 The following options must be set for the Kafka source

@@ -425,6 +462,13 @@ The following configurations are optional:
 issues, set the Kafka consumer session timeout (by setting option "kafka.session.timeout.ms") to
 be very small. When this is set, option "groupIdPrefix" will be ignored.</td>
 </tr>
+<tr>
+<td>includeHeaders</td>
+<td>boolean</td>
+<td>false</td>
+<td>streaming and batch</td>
+<td>Whether to include the Kafka headers in the row.</td>
+</tr>
 </table>
 ### Consumer Caching

@@ -522,6 +566,10 @@ The Dataframe being written to Kafka should have the following columns in schema
 <td>value (required)</td>
 <td>string or binary</td>
 </tr>
+<tr>
+<td>headers (optional)</td>
+<td>array</td>
+</tr>
 <tr>
 <td>topic (*optional)</td>
 <td>string</td>

@@ -559,6 +607,13 @@ The following configurations are optional:
 <td>Sets the topic that all rows will be written to in Kafka. This option overrides any
 topic column that may exist in the data.</td>
 </tr>
+<tr>
+<td>includeHeaders</td>
+<td>boolean</td>
+<td>false</td>
+<td>streaming and batch</td>
+<td>Whether to include the Kafka headers in the row.</td>
+</tr>
 </table>
 ### Creating a Kafka Sink for Streaming Queries


@@ -31,7 +31,8 @@ private[kafka010] class KafkaBatch(
 specifiedKafkaParams: Map[String, String],
 failOnDataLoss: Boolean,
 startingOffsets: KafkaOffsetRangeLimit,
-endingOffsets: KafkaOffsetRangeLimit)
+endingOffsets: KafkaOffsetRangeLimit,
+includeHeaders: Boolean)
 extends Batch with Logging {
 assert(startingOffsets != LatestOffsetRangeLimit,
 "Starting offset not allowed to be set to latest offsets.")

@@ -90,7 +91,7 @@ private[kafka010] class KafkaBatch(
 KafkaSourceProvider.kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId)
 offsetRanges.map { range =>
 new KafkaBatchInputPartition(
-range, executorKafkaParams, pollTimeoutMs, failOnDataLoss)
+range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, includeHeaders)
 }.toArray
 }


@@ -29,13 +29,14 @@ private[kafka010] case class KafkaBatchInputPartition(
 offsetRange: KafkaOffsetRange,
 executorKafkaParams: ju.Map[String, Object],
 pollTimeoutMs: Long,
-failOnDataLoss: Boolean) extends InputPartition
+failOnDataLoss: Boolean,
+includeHeaders: Boolean) extends InputPartition
 private[kafka010] object KafkaBatchReaderFactory extends PartitionReaderFactory {
 override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
 val p = partition.asInstanceOf[KafkaBatchInputPartition]
 KafkaBatchPartitionReader(p.offsetRange, p.executorKafkaParams, p.pollTimeoutMs,
-p.failOnDataLoss)
+p.failOnDataLoss, p.includeHeaders)
 }
 }

@@ -44,12 +45,14 @@ private case class KafkaBatchPartitionReader(
 offsetRange: KafkaOffsetRange,
 executorKafkaParams: ju.Map[String, Object],
 pollTimeoutMs: Long,
-failOnDataLoss: Boolean) extends PartitionReader[InternalRow] with Logging {
+failOnDataLoss: Boolean,
+includeHeaders: Boolean) extends PartitionReader[InternalRow] with Logging {
 private val consumer = KafkaDataConsumer.acquire(offsetRange.topicPartition, executorKafkaParams)
 private val rangeToRead = resolveRange(offsetRange)
-private val converter = new KafkaRecordToUnsafeRowConverter
+private val unsafeRowProjector = new KafkaRecordToRowConverter()
+.toUnsafeRowProjector(includeHeaders)
 private var nextOffset = rangeToRead.fromOffset
 private var nextRow: UnsafeRow = _

@@ -58,7 +61,7 @@ private case class KafkaBatchPartitionReader(
 if (nextOffset < rangeToRead.untilOffset) {
 val record = consumer.get(nextOffset, rangeToRead.untilOffset, pollTimeoutMs, failOnDataLoss)
 if (record != null) {
-nextRow = converter.toUnsafeRow(record)
+nextRow = unsafeRowProjector(record)
 nextOffset = record.offset + 1
 true
 } else {


@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.connector.read.InputPartition
 import org.apache.spark.sql.connector.read.streaming.{ContinuousPartitionReader, ContinuousPartitionReaderFactory, ContinuousStream, Offset, PartitionOffset}
-import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
+import org.apache.spark.sql.kafka010.KafkaSourceProvider._
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 /**

@@ -56,6 +56,7 @@ class KafkaContinuousStream(
 private[kafka010] val pollTimeoutMs =
 options.getLong(KafkaSourceProvider.CONSUMER_POLL_TIMEOUT, 512)
+private val includeHeaders = options.getBoolean(INCLUDE_HEADERS, false)
 // Initialized when creating reader factories. If this diverges from the partitions at the latest
 // offsets, we need to reconfigure.

@@ -88,7 +89,7 @@ class KafkaContinuousStream(
 if (deletedPartitions.nonEmpty) {
 val message = if (
 offsetReader.driverKafkaParams.containsKey(ConsumerConfig.GROUP_ID_CONFIG)) {
-s"$deletedPartitions are gone. ${KafkaSourceProvider.CUSTOM_GROUP_ID_ERROR_MESSAGE}"
+s"$deletedPartitions are gone. ${CUSTOM_GROUP_ID_ERROR_MESSAGE}"
 } else {
 s"$deletedPartitions are gone. Some data may have been missed."
 }

@@ -102,7 +103,7 @@ class KafkaContinuousStream(
 startOffsets.toSeq.map {
 case (topicPartition, start) =>
 KafkaContinuousInputPartition(
-topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss)
+topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss, includeHeaders)
 }.toArray
 }

@@ -153,19 +154,22 @@ class KafkaContinuousStream(
 * @param pollTimeoutMs The timeout for Kafka consumer polling.
 * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets
 * are skipped.
+* @param includeHeaders Flag indicating whether to include Kafka records' headers.
 */
 case class KafkaContinuousInputPartition(
 topicPartition: TopicPartition,
 startOffset: Long,
 kafkaParams: ju.Map[String, Object],
 pollTimeoutMs: Long,
-failOnDataLoss: Boolean) extends InputPartition
+failOnDataLoss: Boolean,
+includeHeaders: Boolean) extends InputPartition
 object KafkaContinuousReaderFactory extends ContinuousPartitionReaderFactory {
 override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = {
 val p = partition.asInstanceOf[KafkaContinuousInputPartition]
 new KafkaContinuousPartitionReader(
-p.topicPartition, p.startOffset, p.kafkaParams, p.pollTimeoutMs, p.failOnDataLoss)
+p.topicPartition, p.startOffset, p.kafkaParams, p.pollTimeoutMs,
+p.failOnDataLoss, p.includeHeaders)
 }
 }

@@ -184,9 +188,11 @@ class KafkaContinuousPartitionReader(
 startOffset: Long,
 kafkaParams: ju.Map[String, Object],
 pollTimeoutMs: Long,
-failOnDataLoss: Boolean) extends ContinuousPartitionReader[InternalRow] {
+failOnDataLoss: Boolean,
+includeHeaders: Boolean) extends ContinuousPartitionReader[InternalRow] {
 private val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams)
-private val converter = new KafkaRecordToUnsafeRowConverter
+private val unsafeRowProjector = new KafkaRecordToRowConverter()
+.toUnsafeRowProjector(includeHeaders)
 private var nextKafkaOffset = startOffset
 private var currentRecord: ConsumerRecord[Array[Byte], Array[Byte]] = _

@@ -225,7 +231,7 @@ class KafkaContinuousPartitionReader(
 }
 override def get(): UnsafeRow = {
-converter.toUnsafeRow(currentRecord)
+unsafeRowProjector(currentRecord)
 }
 override def getOffset(): KafkaSourcePartitionOffset = {


@@ -29,7 +29,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory}
 import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset}
 import org.apache.spark.sql.execution.streaming.sources.RateControlMicroBatchStream
-import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
+import org.apache.spark.sql.kafka010.KafkaSourceProvider._
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.util.UninterruptibleThread

@@ -64,6 +64,8 @@ private[kafka010] class KafkaMicroBatchStream(
 private[kafka010] val maxOffsetsPerTrigger = Option(options.get(
 KafkaSourceProvider.MAX_OFFSET_PER_TRIGGER)).map(_.toLong)
+private val includeHeaders = options.getBoolean(INCLUDE_HEADERS, false)
 private val rangeCalculator = KafkaOffsetRangeCalculator(options)
 private var endPartitionOffsets: KafkaSourceOffset = _

@@ -112,7 +114,7 @@ private[kafka010] class KafkaMicroBatchStream(
 if (deletedPartitions.nonEmpty) {
 val message =
 if (kafkaOffsetReader.driverKafkaParams.containsKey(ConsumerConfig.GROUP_ID_CONFIG)) {
-s"$deletedPartitions are gone. ${KafkaSourceProvider.CUSTOM_GROUP_ID_ERROR_MESSAGE}"
+s"$deletedPartitions are gone. ${CUSTOM_GROUP_ID_ERROR_MESSAGE}"
 } else {
 s"$deletedPartitions are gone. Some data may have been missed."
 }

@@ -146,7 +148,8 @@ private[kafka010] class KafkaMicroBatchStream(
 // Generate factories based on the offset ranges
 offsetRanges.map { range =>
-KafkaBatchInputPartition(range, executorKafkaParams, pollTimeoutMs, failOnDataLoss)
+KafkaBatchInputPartition(range, executorKafkaParams, pollTimeoutMs,
+failOnDataLoss, includeHeaders)
 }.toArray
 }


@@ -31,7 +31,6 @@ import org.apache.kafka.common.TopicPartition
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
-import org.apache.spark.sql.types._
 import org.apache.spark.util.{ThreadUtils, UninterruptibleThread}
 /**

@@ -421,16 +420,3 @@ private[kafka010] class KafkaOffsetReader(
 _consumer = null // will automatically get reinitialized again
 }
 }
-private[kafka010] object KafkaOffsetReader {
-def kafkaSchema: StructType = StructType(Seq(
-StructField("key", BinaryType),
-StructField("value", BinaryType),
-StructField("topic", StringType),
-StructField("partition", IntegerType),
-StructField("offset", LongType),
-StructField("timestamp", TimestampType),
-StructField("timestampType", IntegerType)
-))
-}


@@ -0,0 +1,93 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.kafka010
import java.sql.Timestamp
import scala.collection.JavaConverters._
import org.apache.kafka.clients.consumer.ConsumerRecord
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
/** A simple class for converting Kafka ConsumerRecord to InternalRow/UnsafeRow */
private[kafka010] class KafkaRecordToRowConverter {
import KafkaRecordToRowConverter._
private val toUnsafeRowWithoutHeaders = UnsafeProjection.create(schemaWithoutHeaders)
private val toUnsafeRowWithHeaders = UnsafeProjection.create(schemaWithHeaders)
val toInternalRowWithoutHeaders: Record => InternalRow =
(cr: Record) => InternalRow(
cr.key, cr.value, UTF8String.fromString(cr.topic), cr.partition, cr.offset,
DateTimeUtils.fromJavaTimestamp(new Timestamp(cr.timestamp)), cr.timestampType.id
)
val toInternalRowWithHeaders: Record => InternalRow =
(cr: Record) => InternalRow(
cr.key, cr.value, UTF8String.fromString(cr.topic), cr.partition, cr.offset,
DateTimeUtils.fromJavaTimestamp(new Timestamp(cr.timestamp)), cr.timestampType.id,
if (cr.headers.iterator().hasNext) {
new GenericArrayData(cr.headers.iterator().asScala
.map(header =>
InternalRow(UTF8String.fromString(header.key()), header.value())
).toArray)
} else {
null
}
)
def toUnsafeRowWithoutHeadersProjector: Record => UnsafeRow =
(cr: Record) => toUnsafeRowWithoutHeaders(toInternalRowWithoutHeaders(cr))
def toUnsafeRowWithHeadersProjector: Record => UnsafeRow =
(cr: Record) => toUnsafeRowWithHeaders(toInternalRowWithHeaders(cr))
def toUnsafeRowProjector(includeHeaders: Boolean): Record => UnsafeRow = {
if (includeHeaders) toUnsafeRowWithHeadersProjector else toUnsafeRowWithoutHeadersProjector
}
}
private[kafka010] object KafkaRecordToRowConverter {
type Record = ConsumerRecord[Array[Byte], Array[Byte]]
val headersType = ArrayType(StructType(Array(
StructField("key", StringType),
StructField("value", BinaryType))))
private val schemaWithoutHeaders = new StructType(Array(
StructField("key", BinaryType),
StructField("value", BinaryType),
StructField("topic", StringType),
StructField("partition", IntegerType),
StructField("offset", LongType),
StructField("timestamp", TimestampType),
StructField("timestampType", IntegerType)
))
private val schemaWithHeaders =
new StructType(schemaWithoutHeaders.fields :+ StructField("headers", headersType))
def kafkaSchema(includeHeaders: Boolean): StructType = {
if (includeHeaders) schemaWithHeaders else schemaWithoutHeaders
}
}


@@ -1,54 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.kafka010
import org.apache.kafka.clients.consumer.ConsumerRecord
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.unsafe.types.UTF8String
/** A simple class for converting Kafka ConsumerRecord to UnsafeRow */
private[kafka010] class KafkaRecordToUnsafeRowConverter {
private val rowWriter = new UnsafeRowWriter(7)
def toUnsafeRow(record: ConsumerRecord[Array[Byte], Array[Byte]]): UnsafeRow = {
rowWriter.reset()
rowWriter.zeroOutNullBytes()
if (record.key == null) {
rowWriter.setNullAt(0)
} else {
rowWriter.write(0, record.key)
}
if (record.value == null) {
rowWriter.setNullAt(1)
} else {
rowWriter.write(1, record.value)
}
rowWriter.write(2, UTF8String.fromString(record.topic))
rowWriter.write(3, record.partition)
rowWriter.write(4, record.offset)
rowWriter.write(
5,
DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(record.timestamp)))
rowWriter.write(6, record.timestampType.id)
rowWriter.getRow()
}
}


@@ -24,10 +24,9 @@ import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{Row, SQLContext}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.sources.{BaseRelation, TableScan}
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.unsafe.types.UTF8String
 private[kafka010] class KafkaRelation(

@@ -36,6 +35,7 @@ private[kafka010] class KafkaRelation(
 sourceOptions: CaseInsensitiveMap[String],
 specifiedKafkaParams: Map[String, String],
 failOnDataLoss: Boolean,
+includeHeaders: Boolean,
 startingOffsets: KafkaOffsetRangeLimit,
 endingOffsets: KafkaOffsetRangeLimit)
 extends BaseRelation with TableScan with Logging {

@@ -49,7 +49,9 @@ private[kafka010] class KafkaRelation(
 (sqlContext.sparkContext.conf.get(NETWORK_TIMEOUT) * 1000L).toString
 ).toLong
-override def schema: StructType = KafkaOffsetReader.kafkaSchema
+private val converter = new KafkaRecordToRowConverter()
+override def schema: StructType = KafkaRecordToRowConverter.kafkaSchema(includeHeaders)
 override def buildScan(): RDD[Row] = {
 // Each running query should use its own group id. Otherwise, the query may be only assigned

@@ -100,18 +102,14 @@ private[kafka010] class KafkaRelation(
 // Create an RDD that reads from Kafka and get the (key, value) pair as byte arrays.
 val executorKafkaParams =
 KafkaSourceProvider.kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId)
+val toInternalRow = if (includeHeaders) {
+converter.toInternalRowWithHeaders
+} else {
+converter.toInternalRowWithoutHeaders
+}
 val rdd = new KafkaSourceRDD(
 sqlContext.sparkContext, executorKafkaParams, offsetRanges,
-pollTimeoutMs, failOnDataLoss).map { cr =>
-InternalRow(
-cr.key,
-cr.value,
-UTF8String.fromString(cr.topic),
-cr.partition,
-cr.offset,
-DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(cr.timestamp)),
-cr.timestampType.id)
-}
+pollTimeoutMs, failOnDataLoss).map(toInternalRow)
 sqlContext.internalCreateDataFrame(rdd.setName("kafka"), schema).rdd
 }


@@ -31,12 +31,11 @@ import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT
 import org.apache.spark.scheduler.ExecutorCacheTaskLocation
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.kafka010.KafkaSource._
-import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
+import org.apache.spark.sql.kafka010.KafkaSourceProvider._
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
 /**
 * A [[Source]] that reads data from Kafka using the following design.

@@ -84,13 +83,15 @@ private[kafka010] class KafkaSource(
 private val sc = sqlContext.sparkContext
-private val pollTimeoutMs = sourceOptions.getOrElse(
-KafkaSourceProvider.CONSUMER_POLL_TIMEOUT,
-(sc.conf.get(NETWORK_TIMEOUT) * 1000L).toString
-).toLong
+private val pollTimeoutMs =
+sourceOptions.getOrElse(CONSUMER_POLL_TIMEOUT, (sc.conf.get(NETWORK_TIMEOUT) * 1000L).toString)
+.toLong
 private val maxOffsetsPerTrigger =
-sourceOptions.get(KafkaSourceProvider.MAX_OFFSET_PER_TRIGGER).map(_.toLong)
+sourceOptions.get(MAX_OFFSET_PER_TRIGGER).map(_.toLong)
+private val includeHeaders =
+sourceOptions.getOrElse(INCLUDE_HEADERS, "false").toBoolean
 /**
 * Lazily initialize `initialPartitionOffsets` to make sure that `KafkaConsumer.poll` is only

@@ -113,7 +114,9 @@ private[kafka010] class KafkaSource(
 private var currentPartitionOffsets: Option[Map[TopicPartition, Long]] = None
-override def schema: StructType = KafkaOffsetReader.kafkaSchema
+private val converter = new KafkaRecordToRowConverter()
+override def schema: StructType = KafkaRecordToRowConverter.kafkaSchema(includeHeaders)
 /** Returns the maximum available offset for this source. */
 override def getOffset: Option[Offset] = {

@@ -223,7 +226,7 @@ private[kafka010] class KafkaSource(
 val deletedPartitions = fromPartitionOffsets.keySet.diff(untilPartitionOffsets.keySet)
 if (deletedPartitions.nonEmpty) {
 val message = if (kafkaReader.driverKafkaParams.containsKey(ConsumerConfig.GROUP_ID_CONFIG)) {
-s"$deletedPartitions are gone. ${KafkaSourceProvider.CUSTOM_GROUP_ID_ERROR_MESSAGE}"
+s"$deletedPartitions are gone. ${CUSTOM_GROUP_ID_ERROR_MESSAGE}"
 } else {
 s"$deletedPartitions are gone. Some data may have been missed."
 }

@@ -267,16 +270,14 @@ private[kafka010] class KafkaSource(
 }.toArray
 // Create an RDD that reads from Kafka and get the (key, value) pair as byte arrays.
-val rdd = new KafkaSourceRDD(
-sc, executorKafkaParams, offsetRanges, pollTimeoutMs, failOnDataLoss).map { cr =>
-InternalRow(
-cr.key,
-cr.value,
-UTF8String.fromString(cr.topic),
-cr.partition,
-cr.offset,
-DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(cr.timestamp)),
-cr.timestampType.id)
+val rdd = if (includeHeaders) {
+new KafkaSourceRDD(
+sc, executorKafkaParams, offsetRanges, pollTimeoutMs, failOnDataLoss)
+.map(converter.toInternalRowWithHeaders)
+} else {
+new KafkaSourceRDD(
+sc, executorKafkaParams, offsetRanges, pollTimeoutMs, failOnDataLoss)
+.map(converter.toInternalRowWithoutHeaders)
 }
 logInfo("GetBatch generating RDD of offset range: " +


@@ -69,7 +69,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
 val caseInsensitiveParameters = CaseInsensitiveMap(parameters)
 validateStreamOptions(caseInsensitiveParameters)
 require(schema.isEmpty, "Kafka source has a fixed schema and cannot be set with a custom one")
-(shortName(), KafkaOffsetReader.kafkaSchema)
+val includeHeaders = caseInsensitiveParameters.getOrElse(INCLUDE_HEADERS, "false").toBoolean
+(shortName(), KafkaRecordToRowConverter.kafkaSchema(includeHeaders))
 }
 override def createSource(

@@ -107,7 +108,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
 }
 override def getTable(options: CaseInsensitiveStringMap): KafkaTable = {
-new KafkaTable
+val includeHeaders = options.getBoolean(INCLUDE_HEADERS, false)
+new KafkaTable(includeHeaders)
 }
 /**

@@ -131,12 +133,15 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
 caseInsensitiveParameters, ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
 assert(endingRelationOffsets != EarliestOffsetRangeLimit)
+val includeHeaders = caseInsensitiveParameters.getOrElse(INCLUDE_HEADERS, "false").toBoolean
 new KafkaRelation(
 sqlContext,
 strategy(caseInsensitiveParameters),
 sourceOptions = caseInsensitiveParameters,
 specifiedKafkaParams = specifiedKafkaParams,
 failOnDataLoss = failOnDataLoss(caseInsensitiveParameters),
+includeHeaders = includeHeaders,
 startingOffsets = startingRelationOffsets,
 endingOffsets = endingRelationOffsets)
 }

@@ -359,11 +364,11 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
 }
 }
-class KafkaTable extends Table with SupportsRead with SupportsWrite {
+class KafkaTable(includeHeaders: Boolean) extends Table with SupportsRead with SupportsWrite {
 override def name(): String = "KafkaTable"
-override def schema(): StructType = KafkaOffsetReader.kafkaSchema
+override def schema(): StructType = KafkaRecordToRowConverter.kafkaSchema(includeHeaders)
 override def capabilities(): ju.Set[TableCapability] = {
 import TableCapability._

@@ -403,8 +408,11 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
 }
 class KafkaScan(options: CaseInsensitiveStringMap) extends Scan {
+val includeHeaders = options.getBoolean(INCLUDE_HEADERS, false)
-override def readSchema(): StructType = KafkaOffsetReader.kafkaSchema
+override def readSchema(): StructType = {
+KafkaRecordToRowConverter.kafkaSchema(includeHeaders)
+}
 override def toBatch(): Batch = {
 val caseInsensitiveOptions = CaseInsensitiveMap(options.asScala.toMap)

@@ -423,7 +431,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
 specifiedKafkaParams,
 failOnDataLoss(caseInsensitiveOptions),
 startingRelationOffsets,
-endingRelationOffsets)
+endingRelationOffsets,
+includeHeaders)
 }
 override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = {

@@ -498,6 +507,7 @@ private[kafka010] object KafkaSourceProvider extends Logging {
 private[kafka010] val FETCH_OFFSET_RETRY_INTERVAL_MS = "fetchoffset.retryintervalms"
 private[kafka010] val CONSUMER_POLL_TIMEOUT = "kafkaconsumer.polltimeoutms"
 private val GROUP_ID_PREFIX = "groupidprefix"
+private[kafka010] val INCLUDE_HEADERS = "includeheaders"
 val TOPIC_OPTION_KEY = "topic"


@@ -19,9 +19,13 @@ package org.apache.spark.sql.kafka010
 import java.{util => ju}
-import org.apache.kafka.clients.producer.{Callback, KafkaProducer, ProducerRecord, RecordMetadata}
-import org.apache.spark.sql.catalyst.InternalRow
+import scala.collection.JavaConverters._
+import org.apache.kafka.clients.producer.{Callback, KafkaProducer, ProducerRecord, RecordMetadata}
+import org.apache.kafka.common.header.Header
+import org.apache.kafka.common.header.internals.RecordHeader
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection}
 import org.apache.spark.sql.types.{BinaryType, StringType}

@@ -88,7 +92,17 @@ private[kafka010] abstract class KafkaRowWriter(
 throw new NullPointerException(s"null topic present in the data. Use the " +
 s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.")
 }
-val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value)
+val record = if (projectedRow.isNullAt(3)) {
+new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, null, key, value)
+} else {
+val headerArray = projectedRow.getArray(3)
+val headers = (0 until headerArray.numElements()).map { i =>
+val struct = headerArray.getStruct(i, 2)
+new RecordHeader(struct.getUTF8String(0).toString, struct.getBinary(1))
+.asInstanceOf[Header]
+}
+new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, null, key, value, headers.asJava)
+}
 producer.send(record, callback)
 }

@@ -131,9 +145,26 @@ private[kafka010] abstract class KafkaRowWriter(
 throw new IllegalStateException(s"${KafkaWriter.VALUE_ATTRIBUTE_NAME} " +
 s"attribute unsupported type ${t.catalogString}")
 }
+val headersExpression = inputSchema
+.find(_.name == KafkaWriter.HEADERS_ATTRIBUTE_NAME).getOrElse(
+Literal(CatalystTypeConverters.convertToCatalyst(null),
+KafkaRecordToRowConverter.headersType)
+)
+headersExpression.dataType match {
+case KafkaRecordToRowConverter.headersType => // good
+case t =>
+throw new IllegalStateException(s"${KafkaWriter.HEADERS_ATTRIBUTE_NAME} " +
+s"attribute unsupported type ${t.catalogString}")
+}
 UnsafeProjection.create(
-Seq(topicExpression, Cast(keyExpression, BinaryType),
-Cast(valueExpression, BinaryType)), inputSchema)
+Seq(
+topicExpression,
+Cast(keyExpression, BinaryType),
+Cast(valueExpression, BinaryType),
+headersExpression
+),
+inputSchema
+)
 }
 }


@@ -21,9 +21,10 @@ import java.{util => ju}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{AnalysisException, SparkSession}
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.execution.{QueryExecution, SQLExecution}
-import org.apache.spark.sql.types.{BinaryType, StringType}
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.types.{BinaryType, MapType, StringType}
 import org.apache.spark.util.Utils
 /**

@@ -39,6 +40,7 @@ private[kafka010] object KafkaWriter extends Logging {
 val TOPIC_ATTRIBUTE_NAME: String = "topic"
 val KEY_ATTRIBUTE_NAME: String = "key"
 val VALUE_ATTRIBUTE_NAME: String = "value"
+val HEADERS_ATTRIBUTE_NAME: String = "headers"
 override def toString: String = "KafkaWriter"

@@ -75,6 +77,15 @@ private[kafka010] object KafkaWriter extends Logging {
 throw new AnalysisException(s"$VALUE_ATTRIBUTE_NAME attribute type " +
 s"must be a ${StringType.catalogString} or ${BinaryType.catalogString}")
 }
+schema.find(_.name == HEADERS_ATTRIBUTE_NAME).getOrElse(
+Literal(CatalystTypeConverters.convertToCatalyst(null),
+KafkaRecordToRowConverter.headersType)
+).dataType match {
+case KafkaRecordToRowConverter.headersType => // good
+case _ =>
+throw new AnalysisException(s"$HEADERS_ATTRIBUTE_NAME attribute type " +
+s"must be a ${KafkaRecordToRowConverter.headersType.catalogString}")
+}
 }
def write( def write(


@@ -17,6 +17,7 @@
 package org.apache.spark.sql.kafka010
+import java.nio.charset.StandardCharsets
 import java.util.concurrent.{Executors, TimeUnit}
 import scala.collection.JavaConverters._

@@ -91,7 +92,7 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester
 test("new KafkaDataConsumer instance in case of Task retry") {
 try {
 val kafkaParams = getKafkaParams()
-val key = new CacheKey(groupId, topicPartition)
+val key = CacheKey(groupId, topicPartition)
 val context1 = new TaskContextImpl(0, 0, 0, 0, 0, null, null, null)
 TaskContext.setTaskContext(context1)

@@ -137,7 +138,8 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester
 }
 test("SPARK-23623: concurrent use of KafkaDataConsumer") {
-val data: immutable.IndexedSeq[String] = prepareTestTopicHavingTestMessages(topic)
+val data: immutable.IndexedSeq[(String, Seq[(String, Array[Byte])])] =
+prepareTestTopicHavingTestMessages(topic)
 val topicPartition = new TopicPartition(topic, 0)
 val kafkaParams = getKafkaParams()

@@ -157,10 +159,22 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester
 try {
 val range = consumer.getAvailableOffsetRange()
 val rcvd = range.earliest until range.latest map { offset =>
-val bytes = consumer.get(offset, Long.MaxValue, 10000, failOnDataLoss = false).value()
-new String(bytes)
+val record = consumer.get(offset, Long.MaxValue, 10000, failOnDataLoss = false)
+val value = new String(record.value(), StandardCharsets.UTF_8)
+val headers = record.headers().toArray.map(header => (header.key(), header.value())).toSeq
+(value, headers)
+}
+data.zip(rcvd).foreach { case (expected, actual) =>
+// value
+assert(expected._1 === actual._1)
+// headers
+expected._2.zip(actual._2).foreach { case (l, r) =>
+// header key
+assert(l._1 === r._1)
+// header value
+assert(l._2 === r._2)
+}
 }
-assert(rcvd == data)
 } catch {
 case e: Throwable =>
 error = e

@@ -307,9 +321,9 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester
 }
 private def prepareTestTopicHavingTestMessages(topic: String) = {
-val data = (1 to 1000).map(_.toString)
+val data = (1 to 1000).map(i => (i.toString, Seq[(String, Array[Byte])]()))
 testUtils.createTopic(topic, 1)
-testUtils.sendMessages(topic, data.toArray)
+testUtils.sendMessages(topic, data.toArray, None)
 data
 }


@@ -17,6 +17,7 @@
 package org.apache.spark.sql.kafka010
+import java.nio.charset.StandardCharsets.UTF_8
 import java.util.Locale
 import java.util.concurrent.atomic.AtomicInteger

@@ -70,7 +71,8 @@ abstract class KafkaRelationSuiteBase extends QueryTest with SharedSparkSession
 protected def createDF(
 topic: String,
 withOptions: Map[String, String] = Map.empty[String, String],
-brokerAddress: Option[String] = None) = {
+brokerAddress: Option[String] = None,
+includeHeaders: Boolean = false) = {
 val df = spark
 .read
 .format("kafka")

@@ -80,7 +82,13 @@ abstract class KafkaRelationSuiteBase extends QueryTest with SharedSparkSession
 withOptions.foreach {
 case (key, value) => df.option(key, value)
 }
-df.load().selectExpr("CAST(value AS STRING)")
+if (includeHeaders) {
+df.option("includeHeaders", "true")
+df.load()
+.selectExpr("CAST(value AS STRING)", "headers")
+} else {
+df.load().selectExpr("CAST(value AS STRING)")
+}
 }
 test("explicit earliest to latest offsets") {

@@ -147,6 +155,27 @@ abstract class KafkaRelationSuiteBase extends QueryTest with SharedSparkSession
 checkAnswer(df, (0 to 30).map(_.toString).toDF)
 }
+test("default starting and ending offsets with headers") {
+val topic = newTopic()
+testUtils.createTopic(topic, partitions = 3)
+testUtils.sendMessage(
+topic, ("1", Seq()), Some(0)
+)
+testUtils.sendMessage(
+topic, ("2", Seq(("a", "b".getBytes(UTF_8)), ("c", "d".getBytes(UTF_8)))), Some(1)
+)
+testUtils.sendMessage(
+topic, ("3", Seq(("e", "f".getBytes(UTF_8)), ("e", "g".getBytes(UTF_8)))), Some(2)
+)
+// Implicit offset values, should default to earliest and latest
+val df = createDF(topic, includeHeaders = true)
+// Test that we default to "earliest" and "latest"
+checkAnswer(df, Seq(("1", null),
+("2", Seq(("a", "b".getBytes(UTF_8)), ("c", "d".getBytes(UTF_8)))),
+("3", Seq(("e", "f".getBytes(UTF_8)), ("e", "g".getBytes(UTF_8))))).toDF)
+}
 test("reuse same dataframe in query") {
 // This test ensures that we do not cache the Kafka Consumer in KafkaRelation
 val topic = newTopic()


@@ -17,6 +17,7 @@
 package org.apache.spark.sql.kafka010
+import java.nio.charset.StandardCharsets.UTF_8
 import java.util.Locale
 import java.util.concurrent.atomic.AtomicInteger

@@ -32,7 +33,7 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming._
 import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.{BinaryType, DataType}
+import org.apache.spark.sql.types.{BinaryType, DataType, StringType, StructField, StructType}
 abstract class KafkaSinkSuiteBase extends QueryTest with SharedSparkSession with KafkaTest {
 protected var testUtils: KafkaTestUtils = _

@@ -59,13 +60,14 @@ abstract class KafkaSinkSuiteBase extends QueryTest with SharedSparkSession with
 protected def newTopic(): String = s"topic-${topicId.getAndIncrement()}"
-protected def createKafkaReader(topic: String): DataFrame = {
+protected def createKafkaReader(topic: String, includeHeaders: Boolean = false): DataFrame = {
 spark.read
 .format("kafka")
 .option("kafka.bootstrap.servers", testUtils.brokerAddress)
 .option("startingOffsets", "earliest")
 .option("endingOffsets", "latest")
 .option("subscribe", topic)
+.option("includeHeaders", includeHeaders.toString)
 .load()
 }
 }

@@ -368,15 +370,51 @@ abstract class KafkaSinkBatchSuiteBase extends KafkaSinkSuiteBase {
 test("batch - write to kafka") {
 val topic = newTopic()
 testUtils.createTopic(topic)
-val df = Seq("1", "2", "3", "4", "5").map(v => (topic, v)).toDF("topic", "value")
+val data = Seq(
+Row(topic, "1", Seq(
+Row("a", "b".getBytes(UTF_8))
+)),
+Row(topic, "2", Seq(
+Row("c", "d".getBytes(UTF_8)),
+Row("e", "f".getBytes(UTF_8))
+)),
+Row(topic, "3", Seq(
+Row("g", "h".getBytes(UTF_8)),
+Row("g", "i".getBytes(UTF_8))
+)),
+Row(topic, "4", null),
+Row(topic, "5", Seq(
+Row("j", "k".getBytes(UTF_8)),
+Row("j", "l".getBytes(UTF_8)),
+Row("m", "n".getBytes(UTF_8))
+))
+)
+val df = spark.createDataFrame(
+spark.sparkContext.parallelize(data),
+StructType(Seq(StructField("topic", StringType), StructField("value", StringType),
+StructField("headers", KafkaRecordToRowConverter.headersType)))
+)
 df.write
 .format("kafka")
 .option("kafka.bootstrap.servers", testUtils.brokerAddress)
 .option("topic", topic)
 .save()
 checkAnswer(
-createKafkaReader(topic).selectExpr("CAST(value as STRING) value"),
-Row("1") :: Row("2") :: Row("3") :: Row("4") :: Row("5") :: Nil)
+createKafkaReader(topic, includeHeaders = true).selectExpr(
+"CAST(value as STRING) value", "headers"
+),
+Row("1", Seq(Row("a", "b".getBytes(UTF_8)))) ::
+Row("2", Seq(Row("c", "d".getBytes(UTF_8)), Row("e", "f".getBytes(UTF_8)))) ::
+Row("3", Seq(Row("g", "h".getBytes(UTF_8)), Row("g", "i".getBytes(UTF_8)))) ::
+Row("4", null) ::
+Row("5", Seq(
+Row("j", "k".getBytes(UTF_8)),
+Row("j", "l".getBytes(UTF_8)),
+Row("m", "n".getBytes(UTF_8)))) ::
+Nil
+)
 }
 test("batch - null topic field value, and no topic option") {


@@ -41,6 +41,8 @@ import org.apache.kafka.clients.consumer.KafkaConsumer
 import org.apache.kafka.clients.producer._
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.config.SaslConfigs
+import org.apache.kafka.common.header.Header
+import org.apache.kafka.common.header.internals.RecordHeader
 import org.apache.kafka.common.network.ListenerName
 import org.apache.kafka.common.security.auth.SecurityProtocol.{PLAINTEXT, SASL_PLAINTEXT}
 import org.apache.kafka.common.serialization.{StringDeserializer, StringSerializer}

@@ -369,17 +371,36 @@ class KafkaTestUtils(
 topic: String,
 messages: Array[String],
 partition: Option[Int]): Seq[(String, RecordMetadata)] = {
+sendMessages(topic, messages.map(m => (m, Seq())), partition)
+}
+/** Send record to the Kafka broker with headers using specified partition */
+def sendMessage(topic: String,
+record: (String, Seq[(String, Array[Byte])]),
+partition: Option[Int]): Seq[(String, RecordMetadata)] = {
+sendMessages(topic, Array(record).toSeq, partition)
+}
+/** Send the array of records to the Kafka broker with headers using specified partition */
+def sendMessages(topic: String,
+records: Seq[(String, Seq[(String, Array[Byte])])],
+partition: Option[Int]): Seq[(String, RecordMetadata)] = {
 producer = new KafkaProducer[String, String](producerConfiguration)
 val offsets = try {
-messages.map { m =>
-val record = partition match {
-case Some(p) => new ProducerRecord[String, String](topic, p, null, m)
-case None => new ProducerRecord[String, String](topic, m)
-}
-val metadata =
-producer.send(record).get(10, TimeUnit.SECONDS)
-logInfo(s"\tSent $m to partition ${metadata.partition}, offset ${metadata.offset}")
-(m, metadata)
+records.map { case (value, header) =>
+val headers = header.map { case (k, v) =>
+new RecordHeader(k, v).asInstanceOf[Header]
+}
+val record = partition match {
+case Some(p) =>
+new ProducerRecord[String, String](topic, p, null, value, headers.asJava)
+case None =>
+new ProducerRecord[String, String](topic, null, null, value, headers.asJava)
+}
+val metadata = producer.send(record).get(10, TimeUnit.SECONDS)
+logInfo(s"\tSent ($value, $header) to partition ${metadata.partition}," +
+" offset ${metadata.offset}")
+(value, metadata)
 }
 } finally {
 if (producer != null) {