diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md index 8cb2a45c56..ef4cdc2608 100644 --- a/docs/structured-streaming-kafka-integration.md +++ b/docs/structured-streaming-kafka-integration.md @@ -27,6 +27,8 @@ For Scala/Java applications using SBT/Maven project definitions, link your appli artifactId = spark-sql-kafka-0-10_{{site.SCALA_BINARY_VERSION}} version = {{site.SPARK_VERSION_SHORT}} +Please note that to use the headers functionality, your Kafka client version should be 0.11.0.0 or newer. + For Python applications, you need to add this above library and its dependencies when deploying your application. See the [Deploying](#deploying) subsection below. @@ -50,6 +52,17 @@ val df = spark df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") .as[(String, String)] +// Subscribe to 1 topic, with headers +val df = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1") + .option("includeHeaders", "true") + .load() +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)", "headers") + .as[(String, String, Array[(String, Array[Byte])])] + // Subscribe to multiple topics val df = spark .readStream .format("kafka") @@ -84,6 +97,16 @@ Dataset df = spark .load(); df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); +// Subscribe to 1 topic, with headers +Dataset df = spark + .readStream() + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1") + .option("includeHeaders", "true") + .load(); +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)", "headers"); + // Subscribe to multiple topics Dataset df = spark .readStream() @@ -116,6 +139,16 @@ df = spark \ .load() df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +# Subscribe to 1 topic, with headers +df = spark \ + .readStream \ + .format("kafka") \ + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ + .option("subscribe", "topic1") \ + .option("includeHeaders", "true") \ + .load() +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)", "headers") + # Subscribe to multiple topics df = spark \ .readStream \ @@ -286,6 +319,10 @@ Each row in the source has the following schema: timestampType int + + headers (optional) + array + The following options must be set for the Kafka source @@ -425,6 +462,13 @@ The following configurations are optional: issues, set the Kafka consumer session timeout (by setting option "kafka.session.timeout.ms") to be very small. When this is set, option "groupIdPrefix" will be ignored. + + includeHeaders + boolean + false + streaming and batch + Whether to include the Kafka headers in the row. + ### Consumer Caching @@ -522,6 +566,10 @@ The Dataframe being written to Kafka should have the following columns in schema value (required) string or binary + + headers (optional) + array + topic (*optional) string @@ -559,6 +607,13 @@ The following configurations are optional: Sets the topic that all rows will be written to in Kafka. This option overrides any topic column that may exist in the data. + + includeHeaders + boolean + false + streaming and batch + Whether to include the Kafka headers in the row. 
+ ### Creating a Kafka Sink for Streaming Queries diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatch.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatch.scala index b958035b39..667c383681 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatch.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatch.scala @@ -31,7 +31,8 @@ private[kafka010] class KafkaBatch( specifiedKafkaParams: Map[String, String], failOnDataLoss: Boolean, startingOffsets: KafkaOffsetRangeLimit, - endingOffsets: KafkaOffsetRangeLimit) + endingOffsets: KafkaOffsetRangeLimit, + includeHeaders: Boolean) extends Batch with Logging { assert(startingOffsets != LatestOffsetRangeLimit, "Starting offset not allowed to be set to latest offsets.") @@ -90,7 +91,7 @@ private[kafka010] class KafkaBatch( KafkaSourceProvider.kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId) offsetRanges.map { range => new KafkaBatchInputPartition( - range, executorKafkaParams, pollTimeoutMs, failOnDataLoss) + range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, includeHeaders) }.toArray } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala index d4aa6774be..645b68b0c4 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala @@ -29,13 +29,14 @@ private[kafka010] case class KafkaBatchInputPartition( offsetRange: KafkaOffsetRange, executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends InputPartition + failOnDataLoss: Boolean, + includeHeaders: Boolean) extends InputPartition private[kafka010] object KafkaBatchReaderFactory extends PartitionReaderFactory { override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { val p = partition.asInstanceOf[KafkaBatchInputPartition] KafkaBatchPartitionReader(p.offsetRange, p.executorKafkaParams, p.pollTimeoutMs, - p.failOnDataLoss) + p.failOnDataLoss, p.includeHeaders) } } @@ -44,12 +45,14 @@ private case class KafkaBatchPartitionReader( offsetRange: KafkaOffsetRange, executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends PartitionReader[InternalRow] with Logging { + failOnDataLoss: Boolean, + includeHeaders: Boolean) extends PartitionReader[InternalRow] with Logging { private val consumer = KafkaDataConsumer.acquire(offsetRange.topicPartition, executorKafkaParams) private val rangeToRead = resolveRange(offsetRange) - private val converter = new KafkaRecordToUnsafeRowConverter + private val unsafeRowProjector = new KafkaRecordToRowConverter() + .toUnsafeRowProjector(includeHeaders) private var nextOffset = rangeToRead.fromOffset private var nextRow: UnsafeRow = _ @@ -58,7 +61,7 @@ private case class KafkaBatchPartitionReader( if (nextOffset < rangeToRead.untilOffset) { val record = consumer.get(nextOffset, rangeToRead.untilOffset, pollTimeoutMs, failOnDataLoss) if (record != null) { - nextRow = converter.toUnsafeRow(record) + nextRow = unsafeRowProjector(record) nextOffset = record.offset + 1 true } else { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousStream.scala 
b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousStream.scala index 1e8da4bc0f..9e7b7d6db2 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousStream.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousStream.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.connector.read.InputPartition import org.apache.spark.sql.connector.read.streaming.{ContinuousPartitionReader, ContinuousPartitionReaderFactory, ContinuousStream, Offset, PartitionOffset} -import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} +import org.apache.spark.sql.kafka010.KafkaSourceProvider._ import org.apache.spark.sql.util.CaseInsensitiveStringMap /** @@ -56,6 +56,7 @@ class KafkaContinuousStream( private[kafka010] val pollTimeoutMs = options.getLong(KafkaSourceProvider.CONSUMER_POLL_TIMEOUT, 512) + private val includeHeaders = options.getBoolean(INCLUDE_HEADERS, false) // Initialized when creating reader factories. If this diverges from the partitions at the latest // offsets, we need to reconfigure. @@ -88,7 +89,7 @@ class KafkaContinuousStream( if (deletedPartitions.nonEmpty) { val message = if ( offsetReader.driverKafkaParams.containsKey(ConsumerConfig.GROUP_ID_CONFIG)) { - s"$deletedPartitions are gone. ${KafkaSourceProvider.CUSTOM_GROUP_ID_ERROR_MESSAGE}" + s"$deletedPartitions are gone. ${CUSTOM_GROUP_ID_ERROR_MESSAGE}" } else { s"$deletedPartitions are gone. Some data may have been missed." } @@ -102,7 +103,7 @@ class KafkaContinuousStream( startOffsets.toSeq.map { case (topicPartition, start) => KafkaContinuousInputPartition( - topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss) + topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss, includeHeaders) }.toArray } @@ -153,19 +154,22 @@ class KafkaContinuousStream( * @param pollTimeoutMs The timeout for Kafka consumer polling. * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets * are skipped. + * @param includeHeaders Flag indicating whether to include Kafka records' headers. 
*/ case class KafkaContinuousInputPartition( - topicPartition: TopicPartition, - startOffset: Long, - kafkaParams: ju.Map[String, Object], - pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends InputPartition + topicPartition: TopicPartition, + startOffset: Long, + kafkaParams: ju.Map[String, Object], + pollTimeoutMs: Long, + failOnDataLoss: Boolean, + includeHeaders: Boolean) extends InputPartition object KafkaContinuousReaderFactory extends ContinuousPartitionReaderFactory { override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = { val p = partition.asInstanceOf[KafkaContinuousInputPartition] new KafkaContinuousPartitionReader( - p.topicPartition, p.startOffset, p.kafkaParams, p.pollTimeoutMs, p.failOnDataLoss) + p.topicPartition, p.startOffset, p.kafkaParams, p.pollTimeoutMs, + p.failOnDataLoss, p.includeHeaders) } } @@ -184,9 +188,11 @@ class KafkaContinuousPartitionReader( startOffset: Long, kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends ContinuousPartitionReader[InternalRow] { + failOnDataLoss: Boolean, + includeHeaders: Boolean) extends ContinuousPartitionReader[InternalRow] { private val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams) - private val converter = new KafkaRecordToUnsafeRowConverter + private val unsafeRowProjector = new KafkaRecordToRowConverter() + .toUnsafeRowProjector(includeHeaders) private var nextKafkaOffset = startOffset private var currentRecord: ConsumerRecord[Array[Byte], Array[Byte]] = _ @@ -225,7 +231,7 @@ class KafkaContinuousPartitionReader( } override def get(): UnsafeRow = { - converter.toUnsafeRow(currentRecord) + unsafeRowProjector(currentRecord) } override def getOffset(): KafkaSourcePartitionOffset = { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala index 52d91abc86..6ea6efe5d1 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory} import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset} import org.apache.spark.sql.execution.streaming.sources.RateControlMicroBatchStream -import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} +import org.apache.spark.sql.kafka010.KafkaSourceProvider._ import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.UninterruptibleThread @@ -64,6 +64,8 @@ private[kafka010] class KafkaMicroBatchStream( private[kafka010] val maxOffsetsPerTrigger = Option(options.get( KafkaSourceProvider.MAX_OFFSET_PER_TRIGGER)).map(_.toLong) + private val includeHeaders = options.getBoolean(INCLUDE_HEADERS, false) + private val rangeCalculator = KafkaOffsetRangeCalculator(options) private var endPartitionOffsets: KafkaSourceOffset = _ @@ -112,7 +114,7 @@ private[kafka010] class KafkaMicroBatchStream( if (deletedPartitions.nonEmpty) { val message = if (kafkaOffsetReader.driverKafkaParams.containsKey(ConsumerConfig.GROUP_ID_CONFIG)) { - s"$deletedPartitions are gone. ${KafkaSourceProvider.CUSTOM_GROUP_ID_ERROR_MESSAGE}" + s"$deletedPartitions are gone. 
${CUSTOM_GROUP_ID_ERROR_MESSAGE}" } else { s"$deletedPartitions are gone. Some data may have been missed." } @@ -146,7 +148,8 @@ private[kafka010] class KafkaMicroBatchStream( // Generate factories based on the offset ranges offsetRanges.map { range => - KafkaBatchInputPartition(range, executorKafkaParams, pollTimeoutMs, failOnDataLoss) + KafkaBatchInputPartition(range, executorKafkaParams, pollTimeoutMs, + failOnDataLoss, includeHeaders) }.toArray } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala index f3effd5300..20f2ce11d4 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala @@ -31,7 +31,6 @@ import org.apache.kafka.common.TopicPartition import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.types._ import org.apache.spark.util.{ThreadUtils, UninterruptibleThread} /** @@ -421,16 +420,3 @@ private[kafka010] class KafkaOffsetReader( _consumer = null // will automatically get reinitialized again } } - -private[kafka010] object KafkaOffsetReader { - - def kafkaSchema: StructType = StructType(Seq( - StructField("key", BinaryType), - StructField("value", BinaryType), - StructField("topic", StringType), - StructField("partition", IntegerType), - StructField("offset", LongType), - StructField("timestamp", TimestampType), - StructField("timestampType", IntegerType) - )) -} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToRowConverter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToRowConverter.scala new file mode 100644 index 0000000000..aed099c142 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToRowConverter.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.kafka010 + +import java.sql.Timestamp + +import scala.collection.JavaConverters._ + +import org.apache.kafka.clients.consumer.ConsumerRecord + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** A simple class for converting Kafka ConsumerRecord to InternalRow/UnsafeRow */ +private[kafka010] class KafkaRecordToRowConverter { + import KafkaRecordToRowConverter._ + + private val toUnsafeRowWithoutHeaders = UnsafeProjection.create(schemaWithoutHeaders) + private val toUnsafeRowWithHeaders = UnsafeProjection.create(schemaWithHeaders) + + val toInternalRowWithoutHeaders: Record => InternalRow = + (cr: Record) => InternalRow( + cr.key, cr.value, UTF8String.fromString(cr.topic), cr.partition, cr.offset, + DateTimeUtils.fromJavaTimestamp(new Timestamp(cr.timestamp)), cr.timestampType.id + ) + + val toInternalRowWithHeaders: Record => InternalRow = + (cr: Record) => InternalRow( + cr.key, cr.value, UTF8String.fromString(cr.topic), cr.partition, cr.offset, + DateTimeUtils.fromJavaTimestamp(new Timestamp(cr.timestamp)), cr.timestampType.id, + if (cr.headers.iterator().hasNext) { + new GenericArrayData(cr.headers.iterator().asScala + .map(header => + InternalRow(UTF8String.fromString(header.key()), header.value()) + ).toArray) + } else { + null + } + ) + + def toUnsafeRowWithoutHeadersProjector: Record => UnsafeRow = + (cr: Record) => toUnsafeRowWithoutHeaders(toInternalRowWithoutHeaders(cr)) + + def toUnsafeRowWithHeadersProjector: Record => UnsafeRow = + (cr: Record) => toUnsafeRowWithHeaders(toInternalRowWithHeaders(cr)) + + def toUnsafeRowProjector(includeHeaders: Boolean): Record => UnsafeRow = { + if (includeHeaders) toUnsafeRowWithHeadersProjector else toUnsafeRowWithoutHeadersProjector + } +} + +private[kafka010] object KafkaRecordToRowConverter { + type Record = ConsumerRecord[Array[Byte], Array[Byte]] + + val headersType = ArrayType(StructType(Array( + StructField("key", StringType), + StructField("value", BinaryType)))) + + private val schemaWithoutHeaders = new StructType(Array( + StructField("key", BinaryType), + StructField("value", BinaryType), + StructField("topic", StringType), + StructField("partition", IntegerType), + StructField("offset", LongType), + StructField("timestamp", TimestampType), + StructField("timestampType", IntegerType) + )) + + private val schemaWithHeaders = + new StructType(schemaWithoutHeaders.fields :+ StructField("headers", headersType)) + + def kafkaSchema(includeHeaders: Boolean): StructType = { + if (includeHeaders) schemaWithHeaders else schemaWithoutHeaders + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala deleted file mode 100644 index 306ef10b77..0000000000 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.kafka010 - -import org.apache.kafka.clients.consumer.ConsumerRecord - -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.unsafe.types.UTF8String - -/** A simple class for converting Kafka ConsumerRecord to UnsafeRow */ -private[kafka010] class KafkaRecordToUnsafeRowConverter { - private val rowWriter = new UnsafeRowWriter(7) - - def toUnsafeRow(record: ConsumerRecord[Array[Byte], Array[Byte]]): UnsafeRow = { - rowWriter.reset() - rowWriter.zeroOutNullBytes() - - if (record.key == null) { - rowWriter.setNullAt(0) - } else { - rowWriter.write(0, record.key) - } - if (record.value == null) { - rowWriter.setNullAt(1) - } else { - rowWriter.write(1, record.value) - } - rowWriter.write(2, UTF8String.fromString(record.topic)) - rowWriter.write(3, record.partition) - rowWriter.write(4, record.offset) - rowWriter.write( - 5, - DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(record.timestamp))) - rowWriter.write(6, record.timestampType.id) - rowWriter.getRow() - } -} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala index dc7087821b..886f6b0fe0 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala @@ -24,10 +24,9 @@ import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.sources.{BaseRelation, TableScan} import org.apache.spark.sql.types.StructType -import org.apache.spark.unsafe.types.UTF8String private[kafka010] class KafkaRelation( @@ -36,6 +35,7 @@ private[kafka010] class KafkaRelation( sourceOptions: CaseInsensitiveMap[String], specifiedKafkaParams: Map[String, String], failOnDataLoss: Boolean, + includeHeaders: Boolean, startingOffsets: KafkaOffsetRangeLimit, endingOffsets: KafkaOffsetRangeLimit) extends BaseRelation with TableScan with Logging { @@ -49,7 +49,9 @@ private[kafka010] class KafkaRelation( (sqlContext.sparkContext.conf.get(NETWORK_TIMEOUT) * 1000L).toString ).toLong - override def schema: StructType = KafkaOffsetReader.kafkaSchema + private val converter = new KafkaRecordToRowConverter() + + override def schema: StructType = KafkaRecordToRowConverter.kafkaSchema(includeHeaders) override def buildScan(): RDD[Row] = { // Each running query should use its own group id. 
Otherwise, the query may be only assigned @@ -100,18 +102,14 @@ private[kafka010] class KafkaRelation( // Create an RDD that reads from Kafka and get the (key, value) pair as byte arrays. val executorKafkaParams = KafkaSourceProvider.kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId) + val toInternalRow = if (includeHeaders) { + converter.toInternalRowWithHeaders + } else { + converter.toInternalRowWithoutHeaders + } val rdd = new KafkaSourceRDD( sqlContext.sparkContext, executorKafkaParams, offsetRanges, - pollTimeoutMs, failOnDataLoss).map { cr => - InternalRow( - cr.key, - cr.value, - UTF8String.fromString(cr.topic), - cr.partition, - cr.offset, - DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(cr.timestamp)), - cr.timestampType.id) - } + pollTimeoutMs, failOnDataLoss).map(toInternalRow) sqlContext.internalCreateDataFrame(rdd.setName("kafka"), schema).rdd } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index d1a35ec53b..29944dc3fb 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -31,12 +31,11 @@ import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.kafka010.KafkaSource._ -import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} +import org.apache.spark.sql.kafka010.KafkaSourceProvider._ import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String /** * A [[Source]] that reads data from Kafka using the following design. @@ -84,13 +83,15 @@ private[kafka010] class KafkaSource( private val sc = sqlContext.sparkContext - private val pollTimeoutMs = sourceOptions.getOrElse( - KafkaSourceProvider.CONSUMER_POLL_TIMEOUT, - (sc.conf.get(NETWORK_TIMEOUT) * 1000L).toString - ).toLong + private val pollTimeoutMs = + sourceOptions.getOrElse(CONSUMER_POLL_TIMEOUT, (sc.conf.get(NETWORK_TIMEOUT) * 1000L).toString) + .toLong private val maxOffsetsPerTrigger = - sourceOptions.get(KafkaSourceProvider.MAX_OFFSET_PER_TRIGGER).map(_.toLong) + sourceOptions.get(MAX_OFFSET_PER_TRIGGER).map(_.toLong) + + private val includeHeaders = + sourceOptions.getOrElse(INCLUDE_HEADERS, "false").toBoolean /** * Lazily initialize `initialPartitionOffsets` to make sure that `KafkaConsumer.poll` is only @@ -113,7 +114,9 @@ private[kafka010] class KafkaSource( private var currentPartitionOffsets: Option[Map[TopicPartition, Long]] = None - override def schema: StructType = KafkaOffsetReader.kafkaSchema + private val converter = new KafkaRecordToRowConverter() + + override def schema: StructType = KafkaRecordToRowConverter.kafkaSchema(includeHeaders) /** Returns the maximum available offset for this source. 
*/ override def getOffset: Option[Offset] = { @@ -223,7 +226,7 @@ private[kafka010] class KafkaSource( val deletedPartitions = fromPartitionOffsets.keySet.diff(untilPartitionOffsets.keySet) if (deletedPartitions.nonEmpty) { val message = if (kafkaReader.driverKafkaParams.containsKey(ConsumerConfig.GROUP_ID_CONFIG)) { - s"$deletedPartitions are gone. ${KafkaSourceProvider.CUSTOM_GROUP_ID_ERROR_MESSAGE}" + s"$deletedPartitions are gone. ${CUSTOM_GROUP_ID_ERROR_MESSAGE}" } else { s"$deletedPartitions are gone. Some data may have been missed." } @@ -267,16 +270,14 @@ private[kafka010] class KafkaSource( }.toArray // Create an RDD that reads from Kafka and get the (key, value) pair as byte arrays. - val rdd = new KafkaSourceRDD( - sc, executorKafkaParams, offsetRanges, pollTimeoutMs, failOnDataLoss).map { cr => - InternalRow( - cr.key, - cr.value, - UTF8String.fromString(cr.topic), - cr.partition, - cr.offset, - DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(cr.timestamp)), - cr.timestampType.id) + val rdd = if (includeHeaders) { + new KafkaSourceRDD( + sc, executorKafkaParams, offsetRanges, pollTimeoutMs, failOnDataLoss) + .map(converter.toInternalRowWithHeaders) + } else { + new KafkaSourceRDD( + sc, executorKafkaParams, offsetRanges, pollTimeoutMs, failOnDataLoss) + .map(converter.toInternalRowWithoutHeaders) } logInfo("GetBatch generating RDD of offset range: " + diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 53a6919e2d..a7f8db35d7 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -69,7 +69,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister val caseInsensitiveParameters = CaseInsensitiveMap(parameters) validateStreamOptions(caseInsensitiveParameters) require(schema.isEmpty, "Kafka source has a fixed schema and cannot be set with a custom one") - (shortName(), KafkaOffsetReader.kafkaSchema) + val includeHeaders = caseInsensitiveParameters.getOrElse(INCLUDE_HEADERS, "false").toBoolean + (shortName(), KafkaRecordToRowConverter.kafkaSchema(includeHeaders)) } override def createSource( @@ -107,7 +108,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } override def getTable(options: CaseInsensitiveStringMap): KafkaTable = { - new KafkaTable + val includeHeaders = options.getBoolean(INCLUDE_HEADERS, false) + new KafkaTable(includeHeaders) } /** @@ -131,12 +133,15 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister caseInsensitiveParameters, ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) assert(endingRelationOffsets != EarliestOffsetRangeLimit) + val includeHeaders = caseInsensitiveParameters.getOrElse(INCLUDE_HEADERS, "false").toBoolean + new KafkaRelation( sqlContext, strategy(caseInsensitiveParameters), sourceOptions = caseInsensitiveParameters, specifiedKafkaParams = specifiedKafkaParams, failOnDataLoss = failOnDataLoss(caseInsensitiveParameters), + includeHeaders = includeHeaders, startingOffsets = startingRelationOffsets, endingOffsets = endingRelationOffsets) } @@ -359,11 +364,11 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } } - class KafkaTable extends Table with SupportsRead with SupportsWrite { + class KafkaTable(includeHeaders: Boolean) extends Table with 
SupportsRead with SupportsWrite { override def name(): String = "KafkaTable" - override def schema(): StructType = KafkaOffsetReader.kafkaSchema + override def schema(): StructType = KafkaRecordToRowConverter.kafkaSchema(includeHeaders) override def capabilities(): ju.Set[TableCapability] = { import TableCapability._ @@ -403,8 +408,11 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } class KafkaScan(options: CaseInsensitiveStringMap) extends Scan { + val includeHeaders = options.getBoolean(INCLUDE_HEADERS, false) - override def readSchema(): StructType = KafkaOffsetReader.kafkaSchema + override def readSchema(): StructType = { + KafkaRecordToRowConverter.kafkaSchema(includeHeaders) + } override def toBatch(): Batch = { val caseInsensitiveOptions = CaseInsensitiveMap(options.asScala.toMap) @@ -423,7 +431,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister specifiedKafkaParams, failOnDataLoss(caseInsensitiveOptions), startingRelationOffsets, - endingRelationOffsets) + endingRelationOffsets, + includeHeaders) } override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = { @@ -498,6 +507,7 @@ private[kafka010] object KafkaSourceProvider extends Logging { private[kafka010] val FETCH_OFFSET_RETRY_INTERVAL_MS = "fetchoffset.retryintervalms" private[kafka010] val CONSUMER_POLL_TIMEOUT = "kafkaconsumer.polltimeoutms" private val GROUP_ID_PREFIX = "groupidprefix" + private[kafka010] val INCLUDE_HEADERS = "includeheaders" val TOPIC_OPTION_KEY = "topic" diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala index 041fac7717..b423ddc959 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala @@ -19,9 +19,13 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} -import org.apache.kafka.clients.producer.{Callback, KafkaProducer, ProducerRecord, RecordMetadata} +import scala.collection.JavaConverters._ -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.kafka.clients.producer.{Callback, KafkaProducer, ProducerRecord, RecordMetadata} +import org.apache.kafka.common.header.Header +import org.apache.kafka.common.header.internals.RecordHeader + +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection} import org.apache.spark.sql.types.{BinaryType, StringType} @@ -88,7 +92,17 @@ private[kafka010] abstract class KafkaRowWriter( throw new NullPointerException(s"null topic present in the data. 
Use the " + s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.") } - val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value) + val record = if (projectedRow.isNullAt(3)) { + new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, null, key, value) + } else { + val headerArray = projectedRow.getArray(3) + val headers = (0 until headerArray.numElements()).map { i => + val struct = headerArray.getStruct(i, 2) + new RecordHeader(struct.getUTF8String(0).toString, struct.getBinary(1)) + .asInstanceOf[Header] + } + new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, null, key, value, headers.asJava) + } producer.send(record, callback) } @@ -131,9 +145,26 @@ private[kafka010] abstract class KafkaRowWriter( throw new IllegalStateException(s"${KafkaWriter.VALUE_ATTRIBUTE_NAME} " + s"attribute unsupported type ${t.catalogString}") } + val headersExpression = inputSchema + .find(_.name == KafkaWriter.HEADERS_ATTRIBUTE_NAME).getOrElse( + Literal(CatalystTypeConverters.convertToCatalyst(null), + KafkaRecordToRowConverter.headersType) + ) + headersExpression.dataType match { + case KafkaRecordToRowConverter.headersType => // good + case t => + throw new IllegalStateException(s"${KafkaWriter.HEADERS_ATTRIBUTE_NAME} " + + s"attribute unsupported type ${t.catalogString}") + } UnsafeProjection.create( - Seq(topicExpression, Cast(keyExpression, BinaryType), - Cast(valueExpression, BinaryType)), inputSchema) + Seq( + topicExpression, + Cast(keyExpression, BinaryType), + Cast(valueExpression, BinaryType), + headersExpression + ), + inputSchema + ) } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala index e1a9191cc5..bbb060356f 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala @@ -21,9 +21,10 @@ import java.{util => ju} import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.{QueryExecution, SQLExecution} -import org.apache.spark.sql.types.{BinaryType, StringType} +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.types.{BinaryType, MapType, StringType} import org.apache.spark.util.Utils /** @@ -39,6 +40,7 @@ private[kafka010] object KafkaWriter extends Logging { val TOPIC_ATTRIBUTE_NAME: String = "topic" val KEY_ATTRIBUTE_NAME: String = "key" val VALUE_ATTRIBUTE_NAME: String = "value" + val HEADERS_ATTRIBUTE_NAME: String = "headers" override def toString: String = "KafkaWriter" @@ -75,6 +77,15 @@ private[kafka010] object KafkaWriter extends Logging { throw new AnalysisException(s"$VALUE_ATTRIBUTE_NAME attribute type " + s"must be a ${StringType.catalogString} or ${BinaryType.catalogString}") } + schema.find(_.name == HEADERS_ATTRIBUTE_NAME).getOrElse( + Literal(CatalystTypeConverters.convertToCatalyst(null), + KafkaRecordToRowConverter.headersType) + ).dataType match { + case KafkaRecordToRowConverter.headersType => // good + case _ => + throw new AnalysisException(s"$HEADERS_ATTRIBUTE_NAME attribute type " + + s"must be a ${KafkaRecordToRowConverter.headersType.catalogString}") + } } def write( diff --git 
a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala index 80f9a1b410..d97f627fba 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.kafka010 +import java.nio.charset.StandardCharsets import java.util.concurrent.{Executors, TimeUnit} import scala.collection.JavaConverters._ @@ -91,7 +92,7 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester test("new KafkaDataConsumer instance in case of Task retry") { try { val kafkaParams = getKafkaParams() - val key = new CacheKey(groupId, topicPartition) + val key = CacheKey(groupId, topicPartition) val context1 = new TaskContextImpl(0, 0, 0, 0, 0, null, null, null) TaskContext.setTaskContext(context1) @@ -137,7 +138,8 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester } test("SPARK-23623: concurrent use of KafkaDataConsumer") { - val data: immutable.IndexedSeq[String] = prepareTestTopicHavingTestMessages(topic) + val data: immutable.IndexedSeq[(String, Seq[(String, Array[Byte])])] = + prepareTestTopicHavingTestMessages(topic) val topicPartition = new TopicPartition(topic, 0) val kafkaParams = getKafkaParams() @@ -157,10 +159,22 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester try { val range = consumer.getAvailableOffsetRange() val rcvd = range.earliest until range.latest map { offset => - val bytes = consumer.get(offset, Long.MaxValue, 10000, failOnDataLoss = false).value() - new String(bytes) + val record = consumer.get(offset, Long.MaxValue, 10000, failOnDataLoss = false) + val value = new String(record.value(), StandardCharsets.UTF_8) + val headers = record.headers().toArray.map(header => (header.key(), header.value())).toSeq + (value, headers) + } + data.zip(rcvd).foreach { case (expected, actual) => + // value + assert(expected._1 === actual._1) + // headers + expected._2.zip(actual._2).foreach { case (l, r) => + // header key + assert(l._1 === r._1) + // header value + assert(l._2 === r._2) + } } - assert(rcvd == data) } catch { case e: Throwable => error = e @@ -307,9 +321,9 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester } private def prepareTestTopicHavingTestMessages(topic: String) = { - val data = (1 to 1000).map(_.toString) + val data = (1 to 1000).map(i => (i.toString, Seq[(String, Array[Byte])]())) testUtils.createTopic(topic, 1) - testUtils.sendMessages(topic, data.toArray) + testUtils.sendMessages(topic, data.toArray, None) data } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala index b4e1b78c7d..3c88609bcb 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.kafka010 +import java.nio.charset.StandardCharsets.UTF_8 import java.util.Locale import java.util.concurrent.atomic.AtomicInteger @@ -70,7 +71,8 @@ abstract class KafkaRelationSuiteBase extends QueryTest with SharedSparkSession protected def 
createDF( topic: String, withOptions: Map[String, String] = Map.empty[String, String], - brokerAddress: Option[String] = None) = { + brokerAddress: Option[String] = None, + includeHeaders: Boolean = false) = { val df = spark .read .format("kafka") @@ -80,7 +82,13 @@ abstract class KafkaRelationSuiteBase extends QueryTest with SharedSparkSession withOptions.foreach { case (key, value) => df.option(key, value) } - df.load().selectExpr("CAST(value AS STRING)") + if (includeHeaders) { + df.option("includeHeaders", "true") + df.load() + .selectExpr("CAST(value AS STRING)", "headers") + } else { + df.load().selectExpr("CAST(value AS STRING)") + } } test("explicit earliest to latest offsets") { @@ -147,6 +155,27 @@ abstract class KafkaRelationSuiteBase extends QueryTest with SharedSparkSession checkAnswer(df, (0 to 30).map(_.toString).toDF) } + test("default starting and ending offsets with headers") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 3) + testUtils.sendMessage( + topic, ("1", Seq()), Some(0) + ) + testUtils.sendMessage( + topic, ("2", Seq(("a", "b".getBytes(UTF_8)), ("c", "d".getBytes(UTF_8)))), Some(1) + ) + testUtils.sendMessage( + topic, ("3", Seq(("e", "f".getBytes(UTF_8)), ("e", "g".getBytes(UTF_8)))), Some(2) + ) + + // Implicit offset values, should default to earliest and latest + val df = createDF(topic, includeHeaders = true) + // Test that we default to "earliest" and "latest" + checkAnswer(df, Seq(("1", null), + ("2", Seq(("a", "b".getBytes(UTF_8)), ("c", "d".getBytes(UTF_8)))), + ("3", Seq(("e", "f".getBytes(UTF_8)), ("e", "g".getBytes(UTF_8))))).toDF) + } + test("reuse same dataframe in query") { // This test ensures that we do not cache the Kafka Consumer in KafkaRelation val topic = newTopic() diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala index 84ad41610c..fdda13b1bf 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.kafka010 +import java.nio.charset.StandardCharsets.UTF_8 import java.util.Locale import java.util.concurrent.atomic.AtomicInteger @@ -32,7 +33,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming._ import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.{BinaryType, DataType} +import org.apache.spark.sql.types.{BinaryType, DataType, StringType, StructField, StructType} abstract class KafkaSinkSuiteBase extends QueryTest with SharedSparkSession with KafkaTest { protected var testUtils: KafkaTestUtils = _ @@ -59,13 +60,14 @@ abstract class KafkaSinkSuiteBase extends QueryTest with SharedSparkSession with protected def newTopic(): String = s"topic-${topicId.getAndIncrement()}" - protected def createKafkaReader(topic: String): DataFrame = { + protected def createKafkaReader(topic: String, includeHeaders: Boolean = false): DataFrame = { spark.read .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("startingOffsets", "earliest") .option("endingOffsets", "latest") .option("subscribe", topic) + .option("includeHeaders", includeHeaders.toString) .load() } } @@ -368,15 +370,51 @@ abstract class KafkaSinkBatchSuiteBase extends KafkaSinkSuiteBase { test("batch - write 
to kafka") { val topic = newTopic() testUtils.createTopic(topic) - val df = Seq("1", "2", "3", "4", "5").map(v => (topic, v)).toDF("topic", "value") + val data = Seq( + Row(topic, "1", Seq( + Row("a", "b".getBytes(UTF_8)) + )), + Row(topic, "2", Seq( + Row("c", "d".getBytes(UTF_8)), + Row("e", "f".getBytes(UTF_8)) + )), + Row(topic, "3", Seq( + Row("g", "h".getBytes(UTF_8)), + Row("g", "i".getBytes(UTF_8)) + )), + Row(topic, "4", null), + Row(topic, "5", Seq( + Row("j", "k".getBytes(UTF_8)), + Row("j", "l".getBytes(UTF_8)), + Row("m", "n".getBytes(UTF_8)) + )) + ) + + val df = spark.createDataFrame( + spark.sparkContext.parallelize(data), + StructType(Seq(StructField("topic", StringType), StructField("value", StringType), + StructField("headers", KafkaRecordToRowConverter.headersType))) + ) + df.write .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("topic", topic) .save() checkAnswer( - createKafkaReader(topic).selectExpr("CAST(value as STRING) value"), - Row("1") :: Row("2") :: Row("3") :: Row("4") :: Row("5") :: Nil) + createKafkaReader(topic, includeHeaders = true).selectExpr( + "CAST(value as STRING) value", "headers" + ), + Row("1", Seq(Row("a", "b".getBytes(UTF_8)))) :: + Row("2", Seq(Row("c", "d".getBytes(UTF_8)), Row("e", "f".getBytes(UTF_8)))) :: + Row("3", Seq(Row("g", "h".getBytes(UTF_8)), Row("g", "i".getBytes(UTF_8)))) :: + Row("4", null) :: + Row("5", Seq( + Row("j", "k".getBytes(UTF_8)), + Row("j", "l".getBytes(UTF_8)), + Row("m", "n".getBytes(UTF_8)))) :: + Nil + ) } test("batch - null topic field value, and no topic option") { diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala index d7cb30f530..f7114129a3 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala @@ -41,6 +41,8 @@ import org.apache.kafka.clients.consumer.KafkaConsumer import org.apache.kafka.clients.producer._ import org.apache.kafka.common.TopicPartition import org.apache.kafka.common.config.SaslConfigs +import org.apache.kafka.common.header.Header +import org.apache.kafka.common.header.internals.RecordHeader import org.apache.kafka.common.network.ListenerName import org.apache.kafka.common.security.auth.SecurityProtocol.{PLAINTEXT, SASL_PLAINTEXT} import org.apache.kafka.common.serialization.{StringDeserializer, StringSerializer} @@ -369,17 +371,36 @@ class KafkaTestUtils( topic: String, messages: Array[String], partition: Option[Int]): Seq[(String, RecordMetadata)] = { + sendMessages(topic, messages.map(m => (m, Seq())), partition) + } + + /** Send record to the Kafka broker with headers using specified partition */ + def sendMessage(topic: String, + record: (String, Seq[(String, Array[Byte])]), + partition: Option[Int]): Seq[(String, RecordMetadata)] = { + sendMessages(topic, Array(record).toSeq, partition) + } + + /** Send the array of records to the Kafka broker with headers using specified partition */ + def sendMessages(topic: String, + records: Seq[(String, Seq[(String, Array[Byte])])], + partition: Option[Int]): Seq[(String, RecordMetadata)] = { producer = new KafkaProducer[String, String](producerConfiguration) val offsets = try { - messages.map { m => - val record = partition match { - case Some(p) => new ProducerRecord[String, String](topic, p, null, m) - case None => new 
ProducerRecord[String, String](topic, m) + records.map { case (value, header) => + val headers = header.map { case (k, v) => + new RecordHeader(k, v).asInstanceOf[Header] } - val metadata = - producer.send(record).get(10, TimeUnit.SECONDS) - logInfo(s"\tSent $m to partition ${metadata.partition}, offset ${metadata.offset}") - (m, metadata) + val record = partition match { + case Some(p) => + new ProducerRecord[String, String](topic, p, null, value, headers.asJava) + case None => + new ProducerRecord[String, String](topic, null, null, value, headers.asJava) + } + val metadata = producer.send(record).get(10, TimeUnit.SECONDS) + logInfo(s"\tSent ($value, $header) to partition ${metadata.partition}," + + s" offset ${metadata.offset}") + (value, metadata) } } finally { if (producer != null) {
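
For reviewers trying the patch out, here is a minimal, self-contained sketch of the round trip this change enables, written against the patched `spark-sql-kafka-0-10` connector. The broker address (`host1:port1`), topic names (`topic1`, `topic2`), and the `KafkaHeadersRoundTrip` object are illustrative placeholders, not part of the patch; the snippet only exercises the pieces the diff itself introduces (`includeHeaders`, the `headers` column, and the sink-side headers support).

```scala
// Sketch only: assumes a reachable Kafka broker and the patched connector on the classpath.
import org.apache.spark.sql.SparkSession

object KafkaHeadersRoundTrip {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("kafka-headers-sketch").getOrCreate()
    import spark.implicits._

    // Batch read: the "headers" column (array<struct<key: string, value: binary>>)
    // is only exposed when includeHeaders is set to true.
    val in = spark.read
      .format("kafka")
      .option("kafka.bootstrap.servers", "host1:port1") // placeholder
      .option("subscribe", "topic1")                    // placeholder
      .option("includeHeaders", "true")
      .load()

    // Typed view for processing: each header becomes a (key, value-bytes) pair.
    val typed = in
      .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)", "headers")
      .as[(String, String, Array[(String, Array[Byte])])]

    // Example transformation: count headers per record (headers may be null).
    val headerCounts = typed.map { case (_, value, headers) =>
      (value, Option(headers).map(_.length).getOrElse(0))
    }
    headerCounts.show()

    // Write the same records back to another topic: the sink picks up an optional
    // "headers" column of type array<struct<key: string, value: binary>>.
    in.select("key", "value", "headers")
      .write
      .format("kafka")
      .option("kafka.bootstrap.servers", "host1:port1") // placeholder
      .option("topic", "topic2")                        // placeholder
      .save()

    spark.stop()
  }
}
```

Note that on the write path the `headers` column should keep the `array<struct<key: string, value: binary>>` shape coming out of the source; re-encoding it as an array of Scala tuples would rename the struct fields to `_1`/`_2` and trip the schema check this patch adds in `KafkaWriter.validateQuery`.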