[SPARK-22908][SS] Roll forward continuous processing Kafka support with fix to continuous Kafka data reader

## What changes were proposed in this pull request?

The Kafka reader is now interruptible and can close itself.

## How was this patch tested?

I locally ran one of the KafkaContinuousSourceSuite tests in a tight loop. Before the fix, my machine ran out of open file descriptors after a few iterations; now it works fine.
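The pattern behind the fix is visible in `KafkaContinuousDataReader.next()` in the diff below: the consumer call is bounded by a short poll timeout, and every iteration re-checks the `TaskContext`, so an interrupted or completed task stops promptly and can close its consumer. A minimal sketch of that loop (the `readOnce` parameter stands in for the bounded `consumer.get(...)` call and is a hypothetical helper, not a real API):

```scala
import org.apache.spark.TaskContext

// Minimal sketch of the interruptible read loop; `readOnce` is assumed to
// return None when the poll timeout expires without data.
def pollUntilDataOrInterrupt[T](ctx: TaskContext, readOnce: () => Option[T]): Option[T] = {
  var record: Option[T] = None
  while (record.isEmpty) {
    // Re-check on every iteration so a cancelled or completed task exits
    // promptly and can close its consumer instead of blocking indefinitely.
    if (ctx.isInterrupted() || ctx.isCompleted()) return None
    record = readOnce()
  }
  record
}
```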

Author: Jose Torres <jose@databricks.com>

Closes #20253 from jose-torres/fix-data-reader.
Jose Torres 2018-01-16 18:11:27 -08:00 committed by Tathagata Das
parent a9b845ebb5
commit 1667057851
21 changed files with 1629 additions and 417 deletions


@ -0,0 +1,260 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.kafka010
import java.{util => ju}
import java.util.concurrent.TimeoutException
import org.apache.kafka.clients.consumer.{ConsumerRecord, OffsetOutOfRangeException}
import org.apache.kafka.common.TopicPartition
import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.kafka010.KafkaSource.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset}
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.types.UTF8String
/**
* A [[ContinuousReader]] for data from kafka.
*
* @param offsetReader a reader used to get kafka offsets. Note that the actual data will be
* read by per-task consumers generated later.
* @param kafkaParams String params for per-task Kafka consumers.
* @param sourceOptions The [[org.apache.spark.sql.sources.v2.DataSourceV2Options]] params which
* are not Kafka consumer params.
* @param metadataPath Path to a directory this reader can use for writing metadata.
* @param initialOffsets The Kafka offsets to start reading data at.
* @param failOnDataLoss Flag indicating whether reading should fail in data loss
* scenarios, where some offsets after the specified initial ones can't be
* properly read.
*/
class KafkaContinuousReader(
offsetReader: KafkaOffsetReader,
kafkaParams: ju.Map[String, Object],
sourceOptions: Map[String, String],
metadataPath: String,
initialOffsets: KafkaOffsetRangeLimit,
failOnDataLoss: Boolean)
extends ContinuousReader with SupportsScanUnsafeRow with Logging {
private lazy val session = SparkSession.getActiveSession.get
private lazy val sc = session.sparkContext
private val pollTimeoutMs = sourceOptions.getOrElse("kafkaConsumer.pollTimeoutMs", "512").toLong
// Initialized when creating read tasks. If this diverges from the partitions at the latest
// offsets, we need to reconfigure.
// Exposed outside this object only for unit tests.
private[sql] var knownPartitions: Set[TopicPartition] = _
override def readSchema: StructType = KafkaOffsetReader.kafkaSchema
private var offset: Offset = _
override def setOffset(start: ju.Optional[Offset]): Unit = {
offset = start.orElse {
val offsets = initialOffsets match {
case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets())
case LatestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchLatestOffsets())
case SpecificOffsetRangeLimit(p) => offsetReader.fetchSpecificOffsets(p, reportDataLoss)
}
logInfo(s"Initial offsets: $offsets")
offsets
}
}
override def getStartOffset(): Offset = offset
override def deserializeOffset(json: String): Offset = {
KafkaSourceOffset(JsonUtils.partitionOffsets(json))
}
override def createUnsafeRowReadTasks(): ju.List[ReadTask[UnsafeRow]] = {
import scala.collection.JavaConverters._
val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset)
val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet
val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet)
val newPartitionOffsets = offsetReader.fetchEarliestOffsets(newPartitions.toSeq)
val deletedPartitions = oldStartPartitionOffsets.keySet.diff(currentPartitionSet)
if (deletedPartitions.nonEmpty) {
reportDataLoss(s"Some partitions were deleted: $deletedPartitions")
}
val startOffsets = newPartitionOffsets ++
oldStartPartitionOffsets.filterKeys(!deletedPartitions.contains(_))
knownPartitions = startOffsets.keySet
startOffsets.toSeq.map {
case (topicPartition, start) =>
KafkaContinuousReadTask(
topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss)
.asInstanceOf[ReadTask[UnsafeRow]]
}.asJava
}
/** Stop this source and free any resources it has allocated. */
def stop(): Unit = synchronized {
offsetReader.close()
}
override def commit(end: Offset): Unit = {}
override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = {
val mergedMap = offsets.map {
case KafkaSourcePartitionOffset(p, o) => Map(p -> o)
}.reduce(_ ++ _)
KafkaSourceOffset(mergedMap)
}
override def needsReconfiguration(): Boolean = {
knownPartitions != null && offsetReader.fetchLatestOffsets().keySet != knownPartitions
}
override def toString(): String = s"KafkaSource[$offsetReader]"
/**
* If `failOnDataLoss` is true, this method will throw an `IllegalStateException`.
* Otherwise, just log a warning.
*/
private def reportDataLoss(message: String): Unit = {
if (failOnDataLoss) {
throw new IllegalStateException(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE")
} else {
logWarning(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE")
}
}
}
/**
* A read task for continuous Kafka processing. This will be serialized and transformed into a
* full reader on executors.
*
* @param topicPartition The (topic, partition) pair this task is responsible for.
* @param startOffset The offset to start reading from within the partition.
* @param kafkaParams Kafka consumer params to use.
* @param pollTimeoutMs The timeout for Kafka consumer polling.
* @param failOnDataLoss Flag indicating whether data reader should fail if some offsets
* are skipped.
*/
case class KafkaContinuousReadTask(
topicPartition: TopicPartition,
startOffset: Long,
kafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
failOnDataLoss: Boolean) extends ReadTask[UnsafeRow] {
override def createDataReader(): KafkaContinuousDataReader = {
new KafkaContinuousDataReader(
topicPartition, startOffset, kafkaParams, pollTimeoutMs, failOnDataLoss)
}
}
/**
* A per-task data reader for continuous Kafka processing.
*
* @param topicPartition The (topic, partition) pair this data reader is responsible for.
* @param startOffset The offset to start reading from within the partition.
* @param kafkaParams Kafka consumer params to use.
* @param pollTimeoutMs The timeout for Kafka consumer polling.
* @param failOnDataLoss Flag indicating whether data reader should fail if some offsets
* are skipped.
*/
class KafkaContinuousDataReader(
topicPartition: TopicPartition,
startOffset: Long,
kafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
failOnDataLoss: Boolean) extends ContinuousDataReader[UnsafeRow] {
private val topic = topicPartition.topic
private val kafkaPartition = topicPartition.partition
private val consumer = CachedKafkaConsumer.createUncached(topic, kafkaPartition, kafkaParams)
private val sharedRow = new UnsafeRow(7)
private val bufferHolder = new BufferHolder(sharedRow)
private val rowWriter = new UnsafeRowWriter(bufferHolder, 7)
private var nextKafkaOffset = startOffset
private var currentRecord: ConsumerRecord[Array[Byte], Array[Byte]] = _
override def next(): Boolean = {
var r: ConsumerRecord[Array[Byte], Array[Byte]] = null
while (r == null) {
if (TaskContext.get().isInterrupted() || TaskContext.get().isCompleted()) return false
// Our consumer.get is not interruptible, so we have to set a low poll timeout, leaving
// interrupt points to end the query rather than waiting for new data that might never come.
try {
r = consumer.get(
nextKafkaOffset,
untilOffset = Long.MaxValue,
pollTimeoutMs,
failOnDataLoss)
} catch {
// We didn't read within the timeout. We're supposed to block indefinitely for new data, so
// swallow and ignore this.
case _: TimeoutException =>
// This is a failOnDataLoss exception. Retry if nextKafkaOffset is within the data range,
// or if it's the endpoint of the data range (i.e. the "true" next offset).
case e: IllegalStateException if e.getCause.isInstanceOf[OffsetOutOfRangeException] =>
val range = consumer.getAvailableOffsetRange()
if (range.latest >= nextKafkaOffset && range.earliest <= nextKafkaOffset) {
// retry
} else {
throw e
}
}
}
nextKafkaOffset = r.offset + 1
currentRecord = r
true
}
override def get(): UnsafeRow = {
bufferHolder.reset()
if (currentRecord.key == null) {
rowWriter.setNullAt(0)
} else {
rowWriter.write(0, currentRecord.key)
}
rowWriter.write(1, currentRecord.value)
rowWriter.write(2, UTF8String.fromString(currentRecord.topic))
rowWriter.write(3, currentRecord.partition)
rowWriter.write(4, currentRecord.offset)
rowWriter.write(5,
DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(currentRecord.timestamp)))
rowWriter.write(6, currentRecord.timestampType.id)
sharedRow.setTotalSize(bufferHolder.totalSize)
sharedRow
}
override def getOffset(): KafkaSourcePartitionOffset = {
KafkaSourcePartitionOffset(topicPartition, nextKafkaOffset)
}
override def close(): Unit = {
consumer.close()
}
}
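
For context, here is a minimal usage sketch of the continuous Kafka source implemented above, mirroring the option and trigger pattern used by the test suites later in this diff. The broker address, topic name, and checkpoint path are placeholders, and the console sink is assumed to support continuous mode:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.Trigger

object ContinuousKafkaReadSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("continuous-kafka-read").getOrCreate()

    // "kafka."-prefixed options are forwarded to the per-task consumers;
    // the remaining options configure the source itself.
    val df = spark.readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", "host1:9092") // placeholder broker
      .option("subscribe", "input-topic")              // placeholder topic
      .option("startingOffsets", "earliest")
      .load()
      .selectExpr("CAST(key AS STRING) key", "CAST(value AS STRING) value")

    // Continuous trigger with a 1000 ms checkpoint interval, as in the tests
    // later in this diff.
    val query = df.writeStream
      .format("console")
      .option("checkpointLocation", "/tmp/kafka-continuous-read") // placeholder
      .trigger(Trigger.Continuous(1000))
      .start()

    query.awaitTermination()
  }
}
```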


@ -0,0 +1,119 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.kafka010
import org.apache.kafka.clients.producer.{Callback, ProducerRecord, RecordMetadata}
import scala.collection.JavaConverters._
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection}
import org.apache.spark.sql.kafka010.KafkaSourceProvider.{kafkaParamsForProducer, TOPIC_OPTION_KEY}
import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery
import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter
import org.apache.spark.sql.sources.v2.writer._
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.{BinaryType, StringType, StructType}
/**
* Dummy commit message. The DataSourceV2 framework requires a commit message implementation but we
* don't need to really send one.
*/
case object KafkaWriterCommitMessage extends WriterCommitMessage
/**
* A [[ContinuousWriter]] for Kafka writing. Responsible for generating the writer factory.
* @param topic The topic this writer is responsible for. If None, topic will be inferred from
* a `topic` field in the incoming data.
* @param producerParams Parameters for Kafka producers in each task.
* @param schema The schema of the input data.
*/
class KafkaContinuousWriter(
topic: Option[String], producerParams: Map[String, String], schema: StructType)
extends ContinuousWriter with SupportsWriteInternalRow {
validateQuery(schema.toAttributes, producerParams.toMap[String, Object].asJava, topic)
override def createInternalRowWriterFactory(): KafkaContinuousWriterFactory =
KafkaContinuousWriterFactory(topic, producerParams, schema)
override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
override def abort(messages: Array[WriterCommitMessage]): Unit = {}
}
/**
* A [[DataWriterFactory]] for Kafka writing. Will be serialized and sent to executors to generate
* the per-task data writers.
* @param topic The topic that should be written to. If None, topic will be inferred from
* a `topic` field in the incoming data.
* @param producerParams Parameters for Kafka producers in each task.
* @param schema The schema of the input data.
*/
case class KafkaContinuousWriterFactory(
topic: Option[String], producerParams: Map[String, String], schema: StructType)
extends DataWriterFactory[InternalRow] {
override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = {
new KafkaContinuousDataWriter(topic, producerParams, schema.toAttributes)
}
}
/**
* A [[DataWriter]] for Kafka writing. One data writer will be created in each partition to
* process incoming rows.
*
* @param targetTopic The topic that this data writer is targeting. If None, topic will be inferred
* from a `topic` field in the incoming data.
* @param producerParams Parameters to use for the Kafka producer.
* @param inputSchema The attributes in the input data.
*/
class KafkaContinuousDataWriter(
targetTopic: Option[String], producerParams: Map[String, String], inputSchema: Seq[Attribute])
extends KafkaRowWriter(inputSchema, targetTopic) with DataWriter[InternalRow] {
import scala.collection.JavaConverters._
private lazy val producer = CachedKafkaProducer.getOrCreate(
new java.util.HashMap[String, Object](producerParams.asJava))
def write(row: InternalRow): Unit = {
checkForErrors()
sendRow(row, producer)
}
def commit(): WriterCommitMessage = {
// Send is asynchronous, but we can't commit until all rows are actually in Kafka.
// This requires flushing and then checking that no callbacks produced errors.
// We also check for errors before to fail as soon as possible - the check is cheap.
checkForErrors()
producer.flush()
checkForErrors()
KafkaWriterCommitMessage
}
def abort(): Unit = {}
def close(): Unit = {
checkForErrors()
if (producer != null) {
producer.flush()
checkForErrors()
CachedKafkaProducer.close(new java.util.HashMap[String, Object](producerParams.asJava))
}
}
}
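
A minimal usage sketch of the continuous Kafka sink defined above, following the shape of `createKafkaWriter` in the sink test suite later in this diff; broker address, topics, and checkpoint path are placeholders:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.Trigger

object ContinuousKafkaWriteSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("continuous-kafka-write").getOrCreate()

    // Copy records from one Kafka topic to another with continuous processing.
    // The "topic" option names the target topic; alternatively a `topic` column
    // in the data can carry a per-record target.
    val query = spark.readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", "host1:9092") // placeholder broker
      .option("subscribe", "input-topic")              // placeholder source topic
      .load()
      .selectExpr("CAST(key AS STRING) key", "CAST(value AS STRING) value")
      .writeStream
      .format("kafka")
      .option("kafka.bootstrap.servers", "host1:9092")
      .option("topic", "output-topic")                             // placeholder sink topic
      .option("checkpointLocation", "/tmp/kafka-continuous-write") // placeholder
      .trigger(Trigger.Continuous(1000))
      .start()

    query.awaitTermination()
  }
}
```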


@ -117,10 +117,14 @@ private[kafka010] class KafkaOffsetReader(
* Resolves the specific offsets based on Kafka seek positions.
* This method resolves offset value -1 to the latest and -2 to the
* earliest Kafka seek position.
*
* @param partitionOffsets the specific offsets to resolve
* @param reportDataLoss callback to either report or log data loss depending on setting
*/
def fetchSpecificOffsets(
partitionOffsets: Map[TopicPartition, Long]): Map[TopicPartition, Long] =
runUninterruptibly {
partitionOffsets: Map[TopicPartition, Long],
reportDataLoss: String => Unit): KafkaSourceOffset = {
val fetched = runUninterruptibly {
withRetriesWithoutInterrupt {
// Poll to get the latest assigned partitions
consumer.poll(0)
@ -145,6 +149,19 @@ private[kafka010] class KafkaOffsetReader(
}
}
partitionOffsets.foreach {
case (tp, off) if off != KafkaOffsetRangeLimit.LATEST &&
off != KafkaOffsetRangeLimit.EARLIEST =>
if (fetched(tp) != off) {
reportDataLoss(
s"startingOffsets for $tp was $off but consumer reset to ${fetched(tp)}")
}
case _ =>
// no real way to check that beginning or end is reasonable
}
KafkaSourceOffset(fetched)
}
/**
* Fetch the earliest offsets for the topic partitions that are indicated
* in the [[ConsumerStrategy]].
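
As the doc comment above describes, a concrete offset is used as-is, while the sentinel values -1 and -2 resolve to the latest and earliest seek positions. An illustrative `startingOffsets` value using those sentinels (topic name, partition numbers, and offsets are placeholders; an active `spark` session is assumed):

```scala
// Partition 0 of "events" starts at offset 23, partition 1 at the earliest
// available offset (-2), and partition 2 at the latest offset (-1). If the
// consumer cannot honor a concrete offset and resets, reportDataLoss is called
// with the requested and observed offsets, as in the hunk above.
val startingOffsets = """{"events":{"0":23,"1":-2,"2":-1}}"""

val df = spark.readStream
  .format("kafka")
  .option("kafka.bootstrap.servers", "host1:9092") // placeholder broker
  .option("subscribe", "events")                   // placeholder topic
  .option("startingOffsets", startingOffsets)
  .load()
```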


@ -130,7 +130,7 @@ private[kafka010] class KafkaSource(
val offsets = startingOffsets match {
case EarliestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchEarliestOffsets())
case LatestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchLatestOffsets())
case SpecificOffsetRangeLimit(p) => fetchAndVerify(p)
case SpecificOffsetRangeLimit(p) => kafkaReader.fetchSpecificOffsets(p, reportDataLoss)
}
metadataLog.add(0, offsets)
logInfo(s"Initial offsets: $offsets")
@ -138,21 +138,6 @@ private[kafka010] class KafkaSource(
}.partitionToOffsets
}
private def fetchAndVerify(specificOffsets: Map[TopicPartition, Long]) = {
val result = kafkaReader.fetchSpecificOffsets(specificOffsets)
specificOffsets.foreach {
case (tp, off) if off != KafkaOffsetRangeLimit.LATEST &&
off != KafkaOffsetRangeLimit.EARLIEST =>
if (result(tp) != off) {
reportDataLoss(
s"startingOffsets for $tp was $off but consumer reset to ${result(tp)}")
}
case _ =>
// no real way to check that beginning or end is reasonable
}
KafkaSourceOffset(result)
}
private var currentPartitionOffsets: Option[Map[TopicPartition, Long]] = None
override def schema: StructType = KafkaOffsetReader.kafkaSchema


@ -20,17 +20,22 @@ package org.apache.spark.sql.kafka010
import org.apache.kafka.common.TopicPartition
import org.apache.spark.sql.execution.streaming.{Offset, SerializedOffset}
import org.apache.spark.sql.sources.v2.streaming.reader.{Offset => OffsetV2, PartitionOffset}
/**
* An [[Offset]] for the [[KafkaSource]]. This one tracks all partitions of subscribed topics and
* their offsets.
*/
private[kafka010]
case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends Offset {
case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends OffsetV2 {
override val json = JsonUtils.partitionOffsets(partitionToOffsets)
}
private[kafka010]
case class KafkaSourcePartitionOffset(topicPartition: TopicPartition, partitionOffset: Long)
extends PartitionOffset
/** Companion object of the [[KafkaSourceOffset]] */
private[kafka010] object KafkaSourceOffset {


@ -18,7 +18,7 @@
package org.apache.spark.sql.kafka010
import java.{util => ju}
import java.util.{Locale, UUID}
import java.util.{Locale, Optional, UUID}
import scala.collection.JavaConverters._
@ -27,9 +27,12 @@ import org.apache.kafka.clients.producer.ProducerConfig
import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySerializer}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext}
import org.apache.spark.sql.execution.streaming.{Sink, Source}
import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSession, SQLContext}
import org.apache.spark.sql.execution.streaming.{Offset, Sink, Source}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options}
import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport}
import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
@ -43,6 +46,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
with StreamSinkProvider
with RelationProvider
with CreatableRelationProvider
with ContinuousWriteSupport
with ContinuousReadSupport
with Logging {
import KafkaSourceProvider._
@ -101,6 +106,43 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
failOnDataLoss(caseInsensitiveParams))
}
override def createContinuousReader(
schema: Optional[StructType],
metadataPath: String,
options: DataSourceV2Options): KafkaContinuousReader = {
val parameters = options.asMap().asScala.toMap
validateStreamOptions(parameters)
// Each running query should use its own group id. Otherwise, the query may be only assigned
// partial data since Kafka will assign partitions to multiple consumers having the same group
// id. Hence, we should generate a unique id for each query.
val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}"
val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
val specifiedKafkaParams =
parameters
.keySet
.filter(_.toLowerCase(Locale.ROOT).startsWith("kafka."))
.map { k => k.drop(6).toString -> parameters(k) }
.toMap
val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams,
STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
val kafkaOffsetReader = new KafkaOffsetReader(
strategy(caseInsensitiveParams),
kafkaParamsForDriver(specifiedKafkaParams),
parameters,
driverGroupIdPrefix = s"$uniqueGroupId-driver")
new KafkaContinuousReader(
kafkaOffsetReader,
kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId),
parameters,
metadataPath,
startingStreamOffsets,
failOnDataLoss(caseInsensitiveParams))
}
/**
* Returns a new base relation with the given parameters.
*
@ -181,26 +223,22 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
}
}
private def kafkaParamsForProducer(parameters: Map[String, String]): Map[String, String] = {
val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) {
throw new IllegalArgumentException(
s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys "
+ "are serialized with ByteArraySerializer.")
}
override def createContinuousWriter(
queryId: String,
schema: StructType,
mode: OutputMode,
options: DataSourceV2Options): Optional[ContinuousWriter] = {
import scala.collection.JavaConverters._
if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}"))
{
throw new IllegalArgumentException(
s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as "
+ "value are serialized with ByteArraySerializer.")
}
parameters
.keySet
.filter(_.toLowerCase(Locale.ROOT).startsWith("kafka."))
.map { k => k.drop(6).toString -> parameters(k) }
.toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName,
ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName)
val spark = SparkSession.getActiveSession.get
val topic = Option(options.get(TOPIC_OPTION_KEY).orElse(null)).map(_.trim)
// We convert the options argument from V2 -> Java map -> scala mutable -> scala immutable.
val producerParams = kafkaParamsForProducer(options.asMap.asScala.toMap)
KafkaWriter.validateQuery(
schema.toAttributes, new java.util.HashMap[String, Object](producerParams.asJava), topic)
Optional.of(new KafkaContinuousWriter(topic, producerParams, schema))
}
private def strategy(caseInsensitiveParams: Map[String, String]) =
@ -450,4 +488,27 @@ private[kafka010] object KafkaSourceProvider extends Logging {
def build(): ju.Map[String, Object] = map
}
private[kafka010] def kafkaParamsForProducer(
parameters: Map[String, String]): Map[String, String] = {
val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) {
throw new IllegalArgumentException(
s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys "
+ "are serialized with ByteArraySerializer.")
}
if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}"))
{
throw new IllegalArgumentException(
s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as "
+ "value are serialized with ByteArraySerializer.")
}
parameters
.keySet
.filter(_.toLowerCase(Locale.ROOT).startsWith("kafka."))
.map { k => k.drop(6).toString -> parameters(k) }
.toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName,
ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName)
}
}
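
The provider above splits user options into two groups: keys prefixed with `kafka.` are stripped of the prefix and forwarded to the Kafka clients, while everything else (e.g. `subscribe`, `kafkaConsumer.pollTimeoutMs`) is consumed by the source itself. A standalone sketch of that transformation, with placeholder values:

```scala
import java.util.Locale

// User-supplied options, as they would arrive through DataSourceV2Options
// (values are placeholders).
val options = Map(
  "kafka.bootstrap.servers" -> "host1:9092",  // forwarded as "bootstrap.servers"
  "kafka.metadata.max.age.ms" -> "1",         // forwarded as "metadata.max.age.ms"
  "subscribe" -> "events",                    // consumed by the source itself
  "kafkaConsumer.pollTimeoutMs" -> "512"      // read directly by KafkaContinuousReader
)

// Same shape as the transformation in createContinuousReader above: keep only
// "kafka."-prefixed keys and drop the 6-character prefix.
val specifiedKafkaParams = options
  .keySet
  .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka."))
  .map { k => k.drop(6) -> options(k) }
  .toMap
// Map("bootstrap.servers" -> "host1:9092", "metadata.max.age.ms" -> "1")
```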


@ -33,10 +33,8 @@ import org.apache.spark.sql.types.{BinaryType, StringType}
private[kafka010] class KafkaWriteTask(
producerConfiguration: ju.Map[String, Object],
inputSchema: Seq[Attribute],
topic: Option[String]) {
topic: Option[String]) extends KafkaRowWriter(inputSchema, topic) {
// used to synchronize with Kafka callbacks
@volatile private var failedWrite: Exception = null
private val projection = createProjection
private var producer: KafkaProducer[Array[Byte], Array[Byte]] = _
/**
@ -46,23 +44,7 @@ private[kafka010] class KafkaWriteTask(
producer = CachedKafkaProducer.getOrCreate(producerConfiguration)
while (iterator.hasNext && failedWrite == null) {
val currentRow = iterator.next()
val projectedRow = projection(currentRow)
val topic = projectedRow.getUTF8String(0)
val key = projectedRow.getBinary(1)
val value = projectedRow.getBinary(2)
if (topic == null) {
throw new NullPointerException(s"null topic present in the data. Use the " +
s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.")
}
val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value)
val callback = new Callback() {
override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = {
if (failedWrite == null && e != null) {
failedWrite = e
}
}
}
producer.send(record, callback)
sendRow(currentRow, producer)
}
}
@ -74,8 +56,49 @@ private[kafka010] class KafkaWriteTask(
producer = null
}
}
}
private def createProjection: UnsafeProjection = {
private[kafka010] abstract class KafkaRowWriter(
inputSchema: Seq[Attribute], topic: Option[String]) {
// used to synchronize with Kafka callbacks
@volatile protected var failedWrite: Exception = _
protected val projection = createProjection
private val callback = new Callback() {
override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = {
if (failedWrite == null && e != null) {
failedWrite = e
}
}
}
/**
* Send the specified row to the producer, with a callback that will save any exception
* to failedWrite. Note that send is asynchronous; subclasses must flush() their producer before
* assuming the row is in Kafka.
*/
protected def sendRow(
row: InternalRow, producer: KafkaProducer[Array[Byte], Array[Byte]]): Unit = {
val projectedRow = projection(row)
val topic = projectedRow.getUTF8String(0)
val key = projectedRow.getBinary(1)
val value = projectedRow.getBinary(2)
if (topic == null) {
throw new NullPointerException(s"null topic present in the data. Use the " +
s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.")
}
val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value)
producer.send(record, callback)
}
protected def checkForErrors(): Unit = {
if (failedWrite != null) {
throw failedWrite
}
}
private def createProjection = {
val topicExpression = topic.map(Literal(_)).orElse {
inputSchema.find(_.name == KafkaWriter.TOPIC_ATTRIBUTE_NAME)
}.getOrElse {
@ -112,11 +135,5 @@ private[kafka010] class KafkaWriteTask(
Seq(topicExpression, Cast(keyExpression, BinaryType),
Cast(valueExpression, BinaryType)), inputSchema)
}
private def checkForErrors(): Unit = {
if (failedWrite != null) {
throw failedWrite
}
}
}
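
The `KafkaRowWriter` base class above records asynchronous send failures in a callback and surfaces them on the next `checkForErrors()` call; committing requires a `flush()` followed by a final check. A standalone sketch of that producer pattern using the plain Kafka client API (broker address and topic are placeholders):

```scala
import java.util.Properties

import org.apache.kafka.clients.producer.{Callback, KafkaProducer, ProducerRecord, RecordMetadata}
import org.apache.kafka.common.serialization.ByteArraySerializer

object AsyncKafkaWriteSketch {
  // Remember the first asynchronous failure; @volatile so the producer I/O
  // thread's write is visible to the task thread.
  @volatile private var failedWrite: Exception = null

  private val callback = new Callback {
    override def onCompletion(md: RecordMetadata, e: Exception): Unit = {
      if (failedWrite == null && e != null) failedWrite = e
    }
  }

  private def checkForErrors(): Unit = if (failedWrite != null) throw failedWrite

  def main(args: Array[String]): Unit = {
    val props = new Properties()
    props.put("bootstrap.servers", "host1:9092") // placeholder broker
    props.put("key.serializer", classOf[ByteArraySerializer].getName)
    props.put("value.serializer", classOf[ByteArraySerializer].getName)
    val producer = new KafkaProducer[Array[Byte], Array[Byte]](props)

    // send() is asynchronous: the record may still sit in the client buffer
    // when it returns, so failures surface only via the callback.
    producer.send(
      new ProducerRecord[Array[Byte], Array[Byte]]("events", "v".getBytes("UTF-8")),
      callback)

    // A commit must flush so every buffered record reaches Kafka, then
    // re-check that no callback recorded a failure.
    checkForErrors()
    producer.flush()
    checkForErrors()
    producer.close()
  }
}
```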


@ -43,10 +43,9 @@ private[kafka010] object KafkaWriter extends Logging {
override def toString: String = "KafkaWriter"
def validateQuery(
queryExecution: QueryExecution,
schema: Seq[Attribute],
kafkaParameters: ju.Map[String, Object],
topic: Option[String] = None): Unit = {
val schema = queryExecution.analyzed.output
schema.find(_.name == TOPIC_ATTRIBUTE_NAME).getOrElse(
if (topic.isEmpty) {
throw new AnalysisException(s"topic option required when no " +
@ -84,7 +83,7 @@ private[kafka010] object KafkaWriter extends Logging {
kafkaParameters: ju.Map[String, Object],
topic: Option[String] = None): Unit = {
val schema = queryExecution.analyzed.output
validateQuery(queryExecution, kafkaParameters, topic)
validateQuery(schema, kafkaParameters, topic)
queryExecution.toRdd.foreachPartition { iter =>
val writeTask = new KafkaWriteTask(kafkaParameters, schema, topic)
Utils.tryWithSafeFinally(block = writeTask.execute(iter))(


@ -0,0 +1,476 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.kafka010
import java.util.Locale
import java.util.concurrent.atomic.AtomicInteger
import org.apache.kafka.clients.producer.ProducerConfig
import org.apache.kafka.common.serialization.ByteArraySerializer
import org.scalatest.time.SpanSugar._
import scala.collection.JavaConverters._
import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SpecificInternalRow, UnsafeProjection}
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.streaming._
import org.apache.spark.sql.types.{BinaryType, DataType}
import org.apache.spark.util.Utils
/**
* This is a temporary port of KafkaSinkSuite, since we do not yet have a V2 memory stream.
* Once we have one, this will be changed to a specialization of KafkaSinkSuite and we won't have
* to duplicate all the code.
*/
class KafkaContinuousSinkSuite extends KafkaContinuousTest {
import testImplicits._
override val streamingTimeout = 30.seconds
override def beforeAll(): Unit = {
super.beforeAll()
testUtils = new KafkaTestUtils(
withBrokerProps = Map("auto.create.topics.enable" -> "false"))
testUtils.setup()
}
override def afterAll(): Unit = {
if (testUtils != null) {
testUtils.teardown()
testUtils = null
}
super.afterAll()
}
test("streaming - write to kafka with topic field") {
val inputTopic = newTopic()
testUtils.createTopic(inputTopic, partitions = 1)
val input = spark
.readStream
.format("kafka")
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
.option("subscribe", inputTopic)
.option("startingOffsets", "earliest")
.load()
val topic = newTopic()
testUtils.createTopic(topic)
val writer = createKafkaWriter(
input.toDF(),
withTopic = None,
withOutputMode = Some(OutputMode.Append))(
withSelectExpr = s"'$topic' as topic", "value")
val reader = createKafkaReader(topic)
.selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value")
.selectExpr("CAST(key as INT) key", "CAST(value as INT) value")
.as[(Int, Int)]
.map(_._2)
try {
testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
eventually(timeout(streamingTimeout)) {
checkDatasetUnorderly(reader, 1, 2, 3, 4, 5)
}
testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10"))
eventually(timeout(streamingTimeout)) {
checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
}
} finally {
writer.stop()
}
}
test("streaming - write w/o topic field, with topic option") {
val inputTopic = newTopic()
testUtils.createTopic(inputTopic, partitions = 1)
val input = spark
.readStream
.format("kafka")
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
.option("subscribe", inputTopic)
.option("startingOffsets", "earliest")
.load()
val topic = newTopic()
testUtils.createTopic(topic)
val writer = createKafkaWriter(
input.toDF(),
withTopic = Some(topic),
withOutputMode = Some(OutputMode.Append()))()
val reader = createKafkaReader(topic)
.selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value")
.selectExpr("CAST(key as INT) key", "CAST(value as INT) value")
.as[(Int, Int)]
.map(_._2)
try {
testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
eventually(timeout(streamingTimeout)) {
checkDatasetUnorderly(reader, 1, 2, 3, 4, 5)
}
testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10"))
eventually(timeout(streamingTimeout)) {
checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
}
} finally {
writer.stop()
}
}
test("streaming - topic field and topic option") {
/* The purpose of this test is to ensure that the topic option
* overrides the topic field. We begin by writing some data that
* includes a topic field and value (e.g., 'foo') along with a topic
* option. Then when we read from the topic specified in the option
* we should see the data i.e., the data was written to the topic
* option, and not to the topic in the data e.g., foo
*/
val inputTopic = newTopic()
testUtils.createTopic(inputTopic, partitions = 1)
val input = spark
.readStream
.format("kafka")
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
.option("subscribe", inputTopic)
.option("startingOffsets", "earliest")
.load()
val topic = newTopic()
testUtils.createTopic(topic)
val writer = createKafkaWriter(
input.toDF(),
withTopic = Some(topic),
withOutputMode = Some(OutputMode.Append()))(
withSelectExpr = "'foo' as topic", "CAST(value as STRING) value")
val reader = createKafkaReader(topic)
.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
.selectExpr("CAST(key AS INT)", "CAST(value AS INT)")
.as[(Int, Int)]
.map(_._2)
try {
testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
eventually(timeout(streamingTimeout)) {
checkDatasetUnorderly(reader, 1, 2, 3, 4, 5)
}
testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10"))
eventually(timeout(streamingTimeout)) {
checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
}
} finally {
writer.stop()
}
}
test("null topic attribute") {
val inputTopic = newTopic()
testUtils.createTopic(inputTopic, partitions = 1)
val input = spark
.readStream
.format("kafka")
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
.option("subscribe", inputTopic)
.option("startingOffsets", "earliest")
.load()
val topic = newTopic()
testUtils.createTopic(topic)
/* No topic field or topic option */
var writer: StreamingQuery = null
var ex: Exception = null
try {
writer = createKafkaWriter(input.toDF())(
withSelectExpr = "CAST(null as STRING) as topic", "value"
)
testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
eventually(timeout(streamingTimeout)) {
assert(writer.exception.isDefined)
ex = writer.exception.get
}
} finally {
writer.stop()
}
assert(ex.getCause.getCause.getMessage
.toLowerCase(Locale.ROOT)
.contains("null topic present in the data."))
}
test("streaming - write data with bad schema") {
val inputTopic = newTopic()
testUtils.createTopic(inputTopic, partitions = 1)
val input = spark
.readStream
.format("kafka")
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
.option("subscribe", inputTopic)
.option("startingOffsets", "earliest")
.load()
val topic = newTopic()
testUtils.createTopic(topic)
/* No topic field or topic option */
var writer: StreamingQuery = null
var ex: Exception = null
try {
writer = createKafkaWriter(input.toDF())(
withSelectExpr = "value as key", "value"
)
testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
eventually(timeout(streamingTimeout)) {
assert(writer.exception.isDefined)
ex = writer.exception.get
}
} finally {
writer.stop()
}
assert(ex.getMessage
.toLowerCase(Locale.ROOT)
.contains("topic option required when no 'topic' attribute is present"))
try {
/* No value field */
writer = createKafkaWriter(input.toDF())(
withSelectExpr = s"'$topic' as topic", "value as key"
)
testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
eventually(timeout(streamingTimeout)) {
assert(writer.exception.isDefined)
ex = writer.exception.get
}
} finally {
writer.stop()
}
assert(ex.getMessage.toLowerCase(Locale.ROOT).contains(
"required attribute 'value' not found"))
}
test("streaming - write data with valid schema but wrong types") {
val inputTopic = newTopic()
testUtils.createTopic(inputTopic, partitions = 1)
val input = spark
.readStream
.format("kafka")
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
.option("subscribe", inputTopic)
.option("startingOffsets", "earliest")
.load()
.selectExpr("CAST(value as STRING) value")
val topic = newTopic()
testUtils.createTopic(topic)
var writer: StreamingQuery = null
var ex: Exception = null
try {
/* topic field wrong type */
writer = createKafkaWriter(input.toDF())(
withSelectExpr = s"CAST('1' as INT) as topic", "value"
)
testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
eventually(timeout(streamingTimeout)) {
assert(writer.exception.isDefined)
ex = writer.exception.get
}
} finally {
writer.stop()
}
assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("topic type must be a string"))
try {
/* value field wrong type */
writer = createKafkaWriter(input.toDF())(
withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as value"
)
testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
eventually(timeout(streamingTimeout)) {
assert(writer.exception.isDefined)
ex = writer.exception.get
}
} finally {
writer.stop()
}
assert(ex.getMessage.toLowerCase(Locale.ROOT).contains(
"value attribute type must be a string or binarytype"))
try {
/* key field wrong type */
writer = createKafkaWriter(input.toDF())(
withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as key", "value"
)
testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
eventually(timeout(streamingTimeout)) {
assert(writer.exception.isDefined)
ex = writer.exception.get
}
} finally {
writer.stop()
}
assert(ex.getMessage.toLowerCase(Locale.ROOT).contains(
"key attribute type must be a string or binarytype"))
}
test("streaming - write to non-existing topic") {
val inputTopic = newTopic()
testUtils.createTopic(inputTopic, partitions = 1)
val input = spark
.readStream
.format("kafka")
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
.option("subscribe", inputTopic)
.option("startingOffsets", "earliest")
.load()
val topic = newTopic()
var writer: StreamingQuery = null
var ex: Exception = null
try {
ex = intercept[StreamingQueryException] {
writer = createKafkaWriter(input.toDF(), withTopic = Some(topic))()
testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
eventually(timeout(streamingTimeout)) {
assert(writer.exception.isDefined)
}
throw writer.exception.get
}
} finally {
writer.stop()
}
assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("job aborted"))
}
test("streaming - exception on config serializer") {
val inputTopic = newTopic()
testUtils.createTopic(inputTopic, partitions = 1)
testUtils.sendMessages(inputTopic, Array("0"))
val input = spark
.readStream
.format("kafka")
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
.option("subscribe", inputTopic)
.load()
var writer: StreamingQuery = null
var ex: Exception = null
try {
writer = createKafkaWriter(
input.toDF(),
withOptions = Map("kafka.key.serializer" -> "foo"))()
eventually(timeout(streamingTimeout)) {
assert(writer.exception.isDefined)
ex = writer.exception.get
}
assert(ex.getMessage.toLowerCase(Locale.ROOT).contains(
"kafka option 'key.serializer' is not supported"))
} finally {
writer.stop()
}
try {
writer = createKafkaWriter(
input.toDF(),
withOptions = Map("kafka.value.serializer" -> "foo"))()
eventually(timeout(streamingTimeout)) {
assert(writer.exception.isDefined)
ex = writer.exception.get
}
assert(ex.getMessage.toLowerCase(Locale.ROOT).contains(
"kafka option 'value.serializer' is not supported"))
} finally {
writer.stop()
}
}
test("generic - write big data with small producer buffer") {
/* This test ensures that we understand the semantics of Kafka when
* it comes to blocking on a call to send when the send buffer is full.
* This test will configure the smallest possible producer buffer and
* indicate that we should block when it is full. Thus, no exception should
* be thrown in the case of a full buffer.
*/
val topic = newTopic()
testUtils.createTopic(topic, 1)
val options = new java.util.HashMap[String, String]
options.put("bootstrap.servers", testUtils.brokerAddress)
options.put("buffer.memory", "16384") // min buffer size
options.put("block.on.buffer.full", "true")
options.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName)
options.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName)
val inputSchema = Seq(AttributeReference("value", BinaryType)())
val data = new Array[Byte](15000) // large value
val writeTask = new KafkaContinuousDataWriter(Some(topic), options.asScala.toMap, inputSchema)
try {
val fieldTypes: Array[DataType] = Array(BinaryType)
val converter = UnsafeProjection.create(fieldTypes)
val row = new SpecificInternalRow(fieldTypes)
row.update(0, data)
val iter = Seq.fill(1000)(converter.apply(row)).iterator
iter.foreach(writeTask.write(_))
writeTask.commit()
} finally {
writeTask.close()
}
}
private def createKafkaReader(topic: String): DataFrame = {
spark.read
.format("kafka")
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
.option("startingOffsets", "earliest")
.option("endingOffsets", "latest")
.option("subscribe", topic)
.load()
}
private def createKafkaWriter(
input: DataFrame,
withTopic: Option[String] = None,
withOutputMode: Option[OutputMode] = None,
withOptions: Map[String, String] = Map[String, String]())
(withSelectExpr: String*): StreamingQuery = {
var stream: DataStreamWriter[Row] = null
val checkpointDir = Utils.createTempDir()
var df = input.toDF()
if (withSelectExpr.length > 0) {
df = df.selectExpr(withSelectExpr: _*)
}
stream = df.writeStream
.format("kafka")
.option("checkpointLocation", checkpointDir.getCanonicalPath)
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
// We need to reduce blocking time to efficiently test non-existent partition behavior.
.option("kafka.max.block.ms", "1000")
.trigger(Trigger.Continuous(1000))
.queryName("kafkaStream")
withTopic.foreach(stream.option("topic", _))
withOutputMode.foreach(stream.outputMode(_))
withOptions.foreach(opt => stream.option(opt._1, opt._2))
stream.start()
}
}


@ -0,0 +1,96 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.kafka010
import java.util.Properties
import java.util.concurrent.atomic.AtomicInteger
import org.scalatest.time.SpanSugar._
import scala.collection.mutable
import scala.util.Random
import org.apache.spark.SparkContext
import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.execution.streaming.StreamExecution
import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
import org.apache.spark.sql.streaming.{StreamTest, Trigger}
import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession}
// Run tests in KafkaSourceSuiteBase in continuous execution mode.
class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest
class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest {
import testImplicits._
override val brokerProps = Map("auto.create.topics.enable" -> "false")
test("subscribing topic by pattern with topic deletions") {
val topicPrefix = newTopic()
val topic = topicPrefix + "-seems"
val topic2 = topicPrefix + "-bad"
testUtils.createTopic(topic, partitions = 5)
testUtils.sendMessages(topic, Array("-1"))
require(testUtils.getLatestOffsets(Set(topic)).size === 5)
val reader = spark
.readStream
.format("kafka")
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
.option("kafka.metadata.max.age.ms", "1")
.option("subscribePattern", s"$topicPrefix-.*")
.option("failOnDataLoss", "false")
val kafka = reader.load()
.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
.as[(String, String)]
val mapped = kafka.map(kv => kv._2.toInt + 1)
testStream(mapped)(
makeSureGetOffsetCalled,
AddKafkaData(Set(topic), 1, 2, 3),
CheckAnswer(2, 3, 4),
Execute { query =>
testUtils.deleteTopic(topic)
testUtils.createTopic(topic2, partitions = 5)
eventually(timeout(streamingTimeout)) {
assert(
query.lastExecution.logical.collectFirst {
case DataSourceV2Relation(_, r: KafkaContinuousReader) => r
}.exists { r =>
// Ensure the new topic is present and the old topic is gone.
r.knownPartitions.exists(_.topic == topic2)
},
s"query never reconfigured to new topic $topic2")
}
},
AddKafkaData(Set(topic2), 4, 5, 6),
CheckAnswer(2, 3, 4, 5, 6, 7)
)
}
}
class KafkaContinuousSourceStressForDontFailOnDataLossSuite
extends KafkaSourceStressForDontFailOnDataLossSuite {
override protected def startStream(ds: Dataset[Int]) = {
ds.writeStream
.format("memory")
.queryName("memory")
.start()
}
}


@ -0,0 +1,94 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.kafka010
import java.util.concurrent.atomic.AtomicInteger
import org.apache.spark.SparkContext
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd, SparkListenerTaskStart}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.execution.streaming.StreamExecution
import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
import org.apache.spark.sql.streaming.Trigger
import org.apache.spark.sql.test.TestSparkSession
// Trait to configure StreamTest for kafka continuous execution tests.
trait KafkaContinuousTest extends KafkaSourceTest {
override val defaultTrigger = Trigger.Continuous(1000)
override val defaultUseV2Sink = true
// We need more than the default local[2] to be able to schedule all partitions simultaneously.
override protected def createSparkSession = new TestSparkSession(
new SparkContext(
"local[10]",
"continuous-stream-test-sql-context",
sparkConf.set("spark.sql.testkey", "true")))
// In addition to setting the partitions in Kafka, we have to wait until the query has
// reconfigured to the new count so the test framework can hook in properly.
override protected def setTopicPartitions(
topic: String, newCount: Int, query: StreamExecution) = {
testUtils.addPartitions(topic, newCount)
eventually(timeout(streamingTimeout)) {
assert(
query.lastExecution.logical.collectFirst {
case DataSourceV2Relation(_, r: KafkaContinuousReader) => r
}.exists(_.knownPartitions.size == newCount),
s"query never reconfigured to $newCount partitions")
}
}
// Continuous processing tasks end asynchronously, so test that they actually end.
private val tasksEndedListener = new SparkListener() {
val activeTaskIdCount = new AtomicInteger(0)
override def onTaskStart(start: SparkListenerTaskStart): Unit = {
activeTaskIdCount.incrementAndGet()
}
override def onTaskEnd(end: SparkListenerTaskEnd): Unit = {
activeTaskIdCount.decrementAndGet()
}
}
override def beforeEach(): Unit = {
super.beforeEach()
spark.sparkContext.addSparkListener(tasksEndedListener)
}
override def afterEach(): Unit = {
eventually(timeout(streamingTimeout)) {
assert(tasksEndedListener.activeTaskIdCount.get() == 0)
}
spark.sparkContext.removeSparkListener(tasksEndedListener)
super.afterEach()
}
test("ensure continuous stream is being used") {
val query = spark.readStream
.format("rate")
.option("numPartitions", "1")
.option("rowsPerSecond", "1")
.load()
testStream(query)(
Execute(q => assert(q.isInstanceOf[ContinuousExecution]))
)
}
}


@ -34,11 +34,14 @@ import org.scalatest.concurrent.PatienceConfiguration.Timeout
import org.scalatest.time.SpanSugar._
import org.apache.spark.SparkContext
import org.apache.spark.sql.ForeachWriter
import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row}
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2Exec}
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryWriter
import org.apache.spark.sql.functions.{count, window}
import org.apache.spark.sql.kafka010.KafkaSourceProvider._
import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest}
import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest, Trigger}
import org.apache.spark.sql.streaming.util.StreamManualClock
import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession}
import org.apache.spark.util.Utils
@ -49,9 +52,11 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext {
override val streamingTimeout = 30.seconds
protected val brokerProps = Map[String, Object]()
override def beforeAll(): Unit = {
super.beforeAll()
testUtils = new KafkaTestUtils
testUtils = new KafkaTestUtils(brokerProps)
testUtils.setup()
}
@ -59,18 +64,25 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext {
if (testUtils != null) {
testUtils.teardown()
testUtils = null
super.afterAll()
}
super.afterAll()
}
protected def makeSureGetOffsetCalled = AssertOnQuery { q =>
// Because KafkaSource's initialPartitionOffsets is set lazily, we need to make sure
// its "getOffset" is called before pushing any data. Otherwise, because of the race contion,
// its "getOffset" is called before pushing any data. Otherwise, because of the race condition,
// we don't know which data should be fetched when `startingOffsets` is latest.
q.processAllAvailable()
q match {
case c: ContinuousExecution => c.awaitEpoch(0)
case m: MicroBatchExecution => m.processAllAvailable()
}
true
}
protected def setTopicPartitions(topic: String, newCount: Int, query: StreamExecution) : Unit = {
testUtils.addPartitions(topic, newCount)
}
/**
* Add data to Kafka.
*
@ -82,10 +94,11 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext {
message: String = "",
topicAction: (String, Option[Int]) => Unit = (_, _) => {}) extends AddData {
override def addData(query: Option[StreamExecution]): (Source, Offset) = {
if (query.get.isActive) {
override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = {
query match {
// Make sure no Spark job is running when deleting a topic
query.get.processAllAvailable()
case Some(m: MicroBatchExecution) => m.processAllAvailable()
case _ =>
}
val existingTopics = testUtils.getAllTopicsAndPartitionSize().toMap
@ -97,16 +110,18 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext {
topicAction(existingTopicPartitions._1, Some(existingTopicPartitions._2))
}
// Read all topics again in case some topics are deleted.
val allTopics = testUtils.getAllTopicsAndPartitionSize().toMap.keys
require(
query.nonEmpty,
"Cannot add data when there is no query for finding the active kafka source")
val sources = query.get.logicalPlan.collect {
case StreamingExecutionRelation(source, _) if source.isInstanceOf[KafkaSource] =>
source.asInstanceOf[KafkaSource]
}
case StreamingExecutionRelation(source: KafkaSource, _) => source
} ++ (query.get.lastExecution match {
case null => Seq()
case e => e.logical.collect {
case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader
}
})
if (sources.isEmpty) {
throw new Exception(
"Could not find Kafka source in the StreamExecution logical plan to add data to")
@ -137,14 +152,158 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext {
override def toString: String =
s"AddKafkaData(topics = $topics, data = $data, message = $message)"
}
private val topicId = new AtomicInteger(0)
protected def newTopic(): String = s"topic-${topicId.getAndIncrement()}"
}
class KafkaSourceSuite extends KafkaSourceTest {
class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase {
import testImplicits._
private val topicId = new AtomicInteger(0)
test("(de)serialization of initial offsets") {
val topic = newTopic()
testUtils.createTopic(topic, partitions = 5)
val reader = spark
.readStream
.format("kafka")
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
.option("subscribe", topic)
testStream(reader.load)(
makeSureGetOffsetCalled,
StopStream,
StartStream(),
StopStream)
}
test("maxOffsetsPerTrigger") {
val topic = newTopic()
testUtils.createTopic(topic, partitions = 3)
testUtils.sendMessages(topic, (100 to 200).map(_.toString).toArray, Some(0))
testUtils.sendMessages(topic, (10 to 20).map(_.toString).toArray, Some(1))
testUtils.sendMessages(topic, Array("1"), Some(2))
val reader = spark
.readStream
.format("kafka")
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
.option("kafka.metadata.max.age.ms", "1")
.option("maxOffsetsPerTrigger", 10)
.option("subscribe", topic)
.option("startingOffsets", "earliest")
val kafka = reader.load()
.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
.as[(String, String)]
val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt)
val clock = new StreamManualClock
val waitUntilBatchProcessed = AssertOnQuery { q =>
eventually(Timeout(streamingTimeout)) {
if (!q.exception.isDefined) {
assert(clock.isStreamWaitingAt(clock.getTimeMillis()))
}
}
if (q.exception.isDefined) {
throw q.exception.get
}
true
}
testStream(mapped)(
StartStream(ProcessingTime(100), clock),
waitUntilBatchProcessed,
// 1 from smallest, 1 from middle, 8 from biggest
CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107),
AdvanceManualClock(100),
waitUntilBatchProcessed,
// smallest now empty, 1 more from middle, 9 more from biggest
CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107,
11, 108, 109, 110, 111, 112, 113, 114, 115, 116
),
StopStream,
StartStream(ProcessingTime(100), clock),
waitUntilBatchProcessed,
// smallest now empty, 1 more from middle, 9 more from biggest
CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107,
11, 108, 109, 110, 111, 112, 113, 114, 115, 116,
12, 117, 118, 119, 120, 121, 122, 123, 124, 125
),
AdvanceManualClock(100),
waitUntilBatchProcessed,
// smallest now empty, 1 more from middle, 9 more from biggest
CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107,
11, 108, 109, 110, 111, 112, 113, 114, 115, 116,
12, 117, 118, 119, 120, 121, 122, 123, 124, 125,
13, 126, 127, 128, 129, 130, 131, 132, 133, 134
)
)
}
test("input row metrics") {
val topic = newTopic()
testUtils.createTopic(topic, partitions = 5)
testUtils.sendMessages(topic, Array("-1"))
require(testUtils.getLatestOffsets(Set(topic)).size === 5)
val kafka = spark
.readStream
.format("kafka")
.option("subscribe", topic)
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
.load()
.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
.as[(String, String)]
val mapped = kafka.map(kv => kv._2.toInt + 1)
testStream(mapped)(
StartStream(trigger = ProcessingTime(1)),
makeSureGetOffsetCalled,
AddKafkaData(Set(topic), 1, 2, 3),
CheckAnswer(2, 3, 4),
AssertOnQuery { query =>
val recordsRead = query.recentProgress.map(_.numInputRows).sum
recordsRead == 3
}
)
}
test("subscribing topic by pattern with topic deletions") {
val topicPrefix = newTopic()
val topic = topicPrefix + "-seems"
val topic2 = topicPrefix + "-bad"
testUtils.createTopic(topic, partitions = 5)
testUtils.sendMessages(topic, Array("-1"))
require(testUtils.getLatestOffsets(Set(topic)).size === 5)
val reader = spark
.readStream
.format("kafka")
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
.option("kafka.metadata.max.age.ms", "1")
.option("subscribePattern", s"$topicPrefix-.*")
.option("failOnDataLoss", "false")
val kafka = reader.load()
.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
.as[(String, String)]
val mapped = kafka.map(kv => kv._2.toInt + 1)
testStream(mapped)(
makeSureGetOffsetCalled,
AddKafkaData(Set(topic), 1, 2, 3),
CheckAnswer(2, 3, 4),
Assert {
testUtils.deleteTopic(topic)
testUtils.createTopic(topic2, partitions = 5)
true
},
AddKafkaData(Set(topic2), 4, 5, 6),
CheckAnswer(2, 3, 4, 5, 6, 7)
)
}
testWithUninterruptibleThread(
"deserialization of initial offset with Spark 2.1.0") {
@ -237,86 +396,94 @@ class KafkaSourceSuite extends KafkaSourceTest {
}
}
test("(de)serialization of initial offsets") {
test("KafkaSource with watermark") {
val now = System.currentTimeMillis()
val topic = newTopic()
testUtils.createTopic(topic, partitions = 64)
testUtils.createTopic(newTopic(), partitions = 1)
testUtils.sendMessages(topic, Array(1).map(_.toString))
val reader = spark
val kafka = spark
.readStream
.format("kafka")
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
.option("kafka.metadata.max.age.ms", "1")
.option("startingOffsets", s"earliest")
.option("subscribe", topic)
.load()
testStream(reader.load)(
makeSureGetOffsetCalled,
StopStream,
StartStream(),
StopStream)
val windowedAggregation = kafka
.withWatermark("timestamp", "10 seconds")
.groupBy(window($"timestamp", "5 seconds") as 'window)
.agg(count("*") as 'count)
.select($"window".getField("start") as 'window, $"count")
val query = windowedAggregation
.writeStream
.format("memory")
.outputMode("complete")
.queryName("kafkaWatermark")
.start()
query.processAllAvailable()
val rows = spark.table("kafkaWatermark").collect()
assert(rows.length === 1, s"Unexpected results: ${rows.toList}")
val row = rows(0)
    // We cannot check the exact window start time as it depends on the time that messages were
    // inserted by the producer. So here we just use a lower bound to make sure the internal
    // conversion works.
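    // For example, window($"timestamp", "5 seconds") rounds the timestamp down to a 5-second
    // boundary, so the window start is always within 5 seconds of the message timestamp, and
    // the message was produced no earlier than `now`.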
assert(
row.getAs[java.sql.Timestamp]("window").getTime >= now - 5 * 1000,
s"Unexpected results: $row")
assert(row.getAs[Int]("count") === 1, s"Unexpected results: $row")
query.stop()
}
test("maxOffsetsPerTrigger") {
test("delete a topic when a Spark job is running") {
KafkaSourceSuite.collectedData.clear()
val topic = newTopic()
testUtils.createTopic(topic, partitions = 3)
testUtils.sendMessages(topic, (100 to 200).map(_.toString).toArray, Some(0))
testUtils.sendMessages(topic, (10 to 20).map(_.toString).toArray, Some(1))
testUtils.sendMessages(topic, Array("1"), Some(2))
testUtils.createTopic(topic, partitions = 1)
testUtils.sendMessages(topic, (1 to 10).map(_.toString).toArray)
val reader = spark
.readStream
.format("kafka")
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
.option("kafka.metadata.max.age.ms", "1")
.option("maxOffsetsPerTrigger", 10)
.option("subscribe", topic)
// If a topic is deleted and we try to poll data starting from offset 0,
// the Kafka consumer will just block until timeout and return an empty result.
// So set the timeout to 1 second to make this test fast.
.option("kafkaConsumer.pollTimeoutMs", "1000")
.option("startingOffsets", "earliest")
.option("failOnDataLoss", "false")
val kafka = reader.load()
.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
.as[(String, String)]
val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt)
val clock = new StreamManualClock
val waitUntilBatchProcessed = AssertOnQuery { q =>
eventually(Timeout(streamingTimeout)) {
if (!q.exception.isDefined) {
assert(clock.isStreamWaitingAt(clock.getTimeMillis()))
}
KafkaSourceSuite.globalTestUtils = testUtils
// The following ForeachWriter will delete the topic before fetching data from Kafka
// in executors.
val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] {
override def open(partitionId: Long, version: Long): Boolean = {
KafkaSourceSuite.globalTestUtils.deleteTopic(topic)
true
}
if (q.exception.isDefined) {
throw q.exception.get
}
true
}
testStream(mapped)(
StartStream(ProcessingTime(100), clock),
waitUntilBatchProcessed,
// 1 from smallest, 1 from middle, 8 from biggest
CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107),
AdvanceManualClock(100),
waitUntilBatchProcessed,
// smallest now empty, 1 more from middle, 9 more from biggest
CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107,
11, 108, 109, 110, 111, 112, 113, 114, 115, 116
),
StopStream,
StartStream(ProcessingTime(100), clock),
waitUntilBatchProcessed,
// smallest now empty, 1 more from middle, 9 more from biggest
CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107,
11, 108, 109, 110, 111, 112, 113, 114, 115, 116,
12, 117, 118, 119, 120, 121, 122, 123, 124, 125
),
AdvanceManualClock(100),
waitUntilBatchProcessed,
// smallest now empty, 1 more from middle, 9 more from biggest
CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107,
11, 108, 109, 110, 111, 112, 113, 114, 115, 116,
12, 117, 118, 119, 120, 121, 122, 123, 124, 125,
13, 126, 127, 128, 129, 130, 131, 132, 133, 134
)
)
override def process(value: Int): Unit = {
KafkaSourceSuite.collectedData.add(value)
}
override def close(errorOrNull: Throwable): Unit = {}
}).start()
query.processAllAvailable()
query.stop()
// `failOnDataLoss` is `false`, we should not fail the query
assert(query.exception.isEmpty)
}
}
class KafkaSourceSuiteBase extends KafkaSourceTest {
import testImplicits._
test("SPARK-22956: currentPartitionOffsets should be set when no new data comes in") {
def getSpecificDF(range: Range.Inclusive): org.apache.spark.sql.Dataset[Int] = {
@ -393,7 +560,7 @@ class KafkaSourceSuite extends KafkaSourceTest {
.format("kafka")
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
.option("kafka.metadata.max.age.ms", "1")
.option("subscribePattern", s"topic-.*")
.option("subscribePattern", s"$topic.*")
val kafka = reader.load()
.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
@ -487,65 +654,6 @@ class KafkaSourceSuite extends KafkaSourceTest {
}
}
test("subscribing topic by pattern with topic deletions") {
val topicPrefix = newTopic()
val topic = topicPrefix + "-seems"
val topic2 = topicPrefix + "-bad"
testUtils.createTopic(topic, partitions = 5)
testUtils.sendMessages(topic, Array("-1"))
require(testUtils.getLatestOffsets(Set(topic)).size === 5)
val reader = spark
.readStream
.format("kafka")
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
.option("kafka.metadata.max.age.ms", "1")
.option("subscribePattern", s"$topicPrefix-.*")
.option("failOnDataLoss", "false")
val kafka = reader.load()
.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
.as[(String, String)]
val mapped = kafka.map(kv => kv._2.toInt + 1)
testStream(mapped)(
makeSureGetOffsetCalled,
AddKafkaData(Set(topic), 1, 2, 3),
CheckAnswer(2, 3, 4),
Assert {
testUtils.deleteTopic(topic)
testUtils.createTopic(topic2, partitions = 5)
true
},
AddKafkaData(Set(topic2), 4, 5, 6),
CheckAnswer(2, 3, 4, 5, 6, 7)
)
}
test("starting offset is latest by default") {
val topic = newTopic()
testUtils.createTopic(topic, partitions = 5)
testUtils.sendMessages(topic, Array("0"))
require(testUtils.getLatestOffsets(Set(topic)).size === 5)
val reader = spark
.readStream
.format("kafka")
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
.option("subscribe", topic)
val kafka = reader.load()
.selectExpr("CAST(value AS STRING)")
.as[String]
val mapped = kafka.map(_.toInt)
testStream(mapped)(
makeSureGetOffsetCalled,
AddKafkaData(Set(topic), 1, 2, 3),
CheckAnswer(1, 2, 3) // should not have 0
)
}
test("bad source options") {
def testBadOptions(options: (String, String)*)(expectedMsgs: String*): Unit = {
val ex = intercept[IllegalArgumentException] {
@ -605,77 +713,6 @@ class KafkaSourceSuite extends KafkaSourceTest {
testUnsupportedConfig("kafka.auto.offset.reset", "latest")
}
test("input row metrics") {
val topic = newTopic()
testUtils.createTopic(topic, partitions = 5)
testUtils.sendMessages(topic, Array("-1"))
require(testUtils.getLatestOffsets(Set(topic)).size === 5)
val kafka = spark
.readStream
.format("kafka")
.option("subscribe", topic)
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
.load()
.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
.as[(String, String)]
val mapped = kafka.map(kv => kv._2.toInt + 1)
testStream(mapped)(
StartStream(trigger = ProcessingTime(1)),
makeSureGetOffsetCalled,
AddKafkaData(Set(topic), 1, 2, 3),
CheckAnswer(2, 3, 4),
AssertOnQuery { query =>
val recordsRead = query.recentProgress.map(_.numInputRows).sum
recordsRead == 3
}
)
}
test("delete a topic when a Spark job is running") {
KafkaSourceSuite.collectedData.clear()
val topic = newTopic()
testUtils.createTopic(topic, partitions = 1)
testUtils.sendMessages(topic, (1 to 10).map(_.toString).toArray)
val reader = spark
.readStream
.format("kafka")
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
.option("kafka.metadata.max.age.ms", "1")
.option("subscribe", topic)
// If a topic is deleted and we try to poll data starting from offset 0,
// the Kafka consumer will just block until timeout and return an empty result.
// So set the timeout to 1 second to make this test fast.
.option("kafkaConsumer.pollTimeoutMs", "1000")
.option("startingOffsets", "earliest")
.option("failOnDataLoss", "false")
val kafka = reader.load()
.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
.as[(String, String)]
KafkaSourceSuite.globalTestUtils = testUtils
// The following ForeachWriter will delete the topic before fetching data from Kafka
// in executors.
val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] {
override def open(partitionId: Long, version: Long): Boolean = {
KafkaSourceSuite.globalTestUtils.deleteTopic(topic)
true
}
override def process(value: Int): Unit = {
KafkaSourceSuite.collectedData.add(value)
}
override def close(errorOrNull: Throwable): Unit = {}
}).start()
query.processAllAvailable()
query.stop()
// `failOnDataLoss` is `false`, we should not fail the query
assert(query.exception.isEmpty)
}
test("get offsets from case insensitive parameters") {
for ((optionKey, optionValue, answer) <- Seq(
(STARTING_OFFSETS_OPTION_KEY, "earLiEst", EarliestOffsetRangeLimit),
@ -694,8 +731,6 @@ class KafkaSourceSuite extends KafkaSourceTest {
}
}
private def newTopic(): String = s"topic-${topicId.getAndIncrement()}"
private def assignString(topic: String, partitions: Iterable[Int]): String = {
JsonUtils.partitions(partitions.map(p => new TopicPartition(topic, p)))
}
@ -741,6 +776,10 @@ class KafkaSourceSuite extends KafkaSourceTest {
testStream(mapped)(
makeSureGetOffsetCalled,
Execute { q =>
// wait to reach the last offset in every partition
q.awaitOffset(0, KafkaSourceOffset(partitionOffsets.mapValues(_ => 3L)))
},
CheckAnswer(-20, -21, -22, 0, 1, 2, 11, 12, 22),
StopStream,
StartStream(),
@ -771,10 +810,13 @@ class KafkaSourceSuite extends KafkaSourceTest {
.format("memory")
.outputMode("append")
.queryName("kafkaColumnTypes")
.trigger(defaultTrigger)
.start()
query.processAllAvailable()
val rows = spark.table("kafkaColumnTypes").collect()
assert(rows.length === 1, s"Unexpected results: ${rows.toList}")
var rows: Array[Row] = Array()
eventually(timeout(streamingTimeout)) {
rows = spark.table("kafkaColumnTypes").collect()
assert(rows.length === 1, s"Unexpected results: ${rows.toList}")
}
val row = rows(0)
assert(row.getAs[Array[Byte]]("key") === null, s"Unexpected results: $row")
assert(row.getAs[Array[Byte]]("value") === "1".getBytes(UTF_8), s"Unexpected results: $row")
@ -788,47 +830,6 @@ class KafkaSourceSuite extends KafkaSourceTest {
query.stop()
}
test("KafkaSource with watermark") {
val now = System.currentTimeMillis()
val topic = newTopic()
testUtils.createTopic(newTopic(), partitions = 1)
testUtils.sendMessages(topic, Array(1).map(_.toString))
val kafka = spark
.readStream
.format("kafka")
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
.option("kafka.metadata.max.age.ms", "1")
.option("startingOffsets", s"earliest")
.option("subscribe", topic)
.load()
val windowedAggregation = kafka
.withWatermark("timestamp", "10 seconds")
.groupBy(window($"timestamp", "5 seconds") as 'window)
.agg(count("*") as 'count)
.select($"window".getField("start") as 'window, $"count")
val query = windowedAggregation
.writeStream
.format("memory")
.outputMode("complete")
.queryName("kafkaWatermark")
.start()
query.processAllAvailable()
val rows = spark.table("kafkaWatermark").collect()
assert(rows.length === 1, s"Unexpected results: ${rows.toList}")
val row = rows(0)
    // We cannot check the exact window start time as it depends on the time that messages were
    // inserted by the producer. So here we just use a lower bound to make sure the internal
    // conversion works.
assert(
row.getAs[java.sql.Timestamp]("window").getTime >= now - 5 * 1000,
s"Unexpected results: $row")
assert(row.getAs[Int]("count") === 1, s"Unexpected results: $row")
query.stop()
}
private def testFromLatestOffsets(
topic: String,
addPartitions: Boolean,
@ -865,9 +866,7 @@ class KafkaSourceSuite extends KafkaSourceTest {
AddKafkaData(Set(topic), 7, 8),
CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9),
AssertOnQuery("Add partitions") { query: StreamExecution =>
if (addPartitions) {
testUtils.addPartitions(topic, 10)
}
if (addPartitions) setTopicPartitions(topic, 10, query)
true
},
AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16),
@ -908,9 +907,7 @@ class KafkaSourceSuite extends KafkaSourceTest {
StartStream(),
CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9),
AssertOnQuery("Add partitions") { query: StreamExecution =>
if (addPartitions) {
testUtils.addPartitions(topic, 10)
}
if (addPartitions) setTopicPartitions(topic, 10, query)
true
},
AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16),
@ -1042,20 +1039,8 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared
}
}
test("stress test for failOnDataLoss=false") {
val reader = spark
.readStream
.format("kafka")
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
.option("kafka.metadata.max.age.ms", "1")
.option("subscribePattern", "failOnDataLoss.*")
.option("startingOffsets", "earliest")
.option("failOnDataLoss", "false")
.option("fetchOffset.retryIntervalMs", "3000")
val kafka = reader.load()
.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
.as[(String, String)]
val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] {
protected def startStream(ds: Dataset[Int]) = {
ds.writeStream.foreach(new ForeachWriter[Int] {
override def open(partitionId: Long, version: Long): Boolean = {
true
@ -1069,6 +1054,22 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared
override def close(errorOrNull: Throwable): Unit = {
}
}).start()
}
test("stress test for failOnDataLoss=false") {
val reader = spark
.readStream
.format("kafka")
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
.option("kafka.metadata.max.age.ms", "1")
.option("subscribePattern", "failOnDataLoss.*")
.option("startingOffsets", "earliest")
.option("failOnDataLoss", "false")
.option("fetchOffset.retryIntervalMs", "3000")
val kafka = reader.load()
.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
.as[(String, String)]
val query = startStream(kafka.map(kv => kv._2.toInt))
val testTime = 1.minutes
val startTime = System.currentTimeMillis()

@ -191,6 +191,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
ds = ds.asInstanceOf[DataSourceV2],
conf = sparkSession.sessionState.conf)).asJava)
// Streaming also uses the data source V2 API. So it may be that the data source implements
// v2, but has no v2 implementation for batch reads. In that case, we fall back to loading
// the dataframe as a v1 source.
val reader = (ds, userSpecifiedSchema) match {
case (ds: ReadSupportWithSchema, Some(schema)) =>
ds.createReader(schema, options)
@ -208,23 +211,30 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
}
reader
case _ =>
throw new AnalysisException(s"$cls does not support data reading.")
case _ => null // fall back to v1
}
Dataset.ofRows(sparkSession, DataSourceV2Relation(reader))
if (reader == null) {
loadV1Source(paths: _*)
} else {
Dataset.ofRows(sparkSession, DataSourceV2Relation(reader))
}
} else {
// Code path for data source v1.
sparkSession.baseRelationToDataFrame(
DataSource.apply(
sparkSession,
paths = paths,
userSpecifiedSchema = userSpecifiedSchema,
className = source,
options = extraOptions.toMap).resolveRelation())
loadV1Source(paths: _*)
}
}
private def loadV1Source(paths: String*) = {
// Code path for data source v1.
sparkSession.baseRelationToDataFrame(
DataSource.apply(
sparkSession,
paths = paths,
userSpecifiedSchema = userSpecifiedSchema,
className = source,
options = extraOptions.toMap).resolveRelation())
}
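  // A hedged usage sketch of the fallback above: a source that implements DataSourceV2 only
  // for streaming (as the Kafka source in this patch appears to) can still be loaded as a
  // batch DataFrame, because load() finds no v2 batch reader and drops into loadV1Source.
  // The option values are placeholders, not taken from this patch.
  object V1ReadFallbackSketch {
    import org.apache.spark.sql.{DataFrame, SparkSession}

    def readKafkaAsBatch(spark: SparkSession): DataFrame = {
      spark.read
        .format("kafka")
        .option("kafka.bootstrap.servers", "host:9092") // placeholder broker
        .option("subscribe", "some-topic")              // placeholder topic
        .load() // no v2 batch reader here, so this goes through loadV1Source
    }
  }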
/**
* Construct a `DataFrame` representing the database table accessible via JDBC URL
* url named table and connection properties.

@ -255,17 +255,24 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
}
}
case _ => throw new AnalysisException(s"$cls does not support data writing.")
// Streaming also uses the data source V2 API. So it may be that the data source implements
// v2, but has no v2 implementation for batch writes. In that case, we fall back to saving
// as though it's a V1 source.
case _ => saveToV1Source()
}
} else {
// Code path for data source v1.
runCommand(df.sparkSession, "save") {
DataSource(
sparkSession = df.sparkSession,
className = source,
partitionColumns = partitioningColumns.getOrElse(Nil),
options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan))
}
saveToV1Source()
}
}
private def saveToV1Source(): Unit = {
// Code path for data source v1.
runCommand(df.sparkSession, "save") {
DataSource(
sparkSession = df.sparkSession,
className = source,
partitionColumns = partitioningColumns.getOrElse(Nil),
options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan))
}
}
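  // A hedged sketch of the matching write-side fallback: a v2 source without a v2 batch
  // writer still supports DataFrame.write through saveToV1Source. Values are placeholders.
  object V1WriteFallbackSketch {
    import org.apache.spark.sql.DataFrame

    def writeKafkaAsBatch(df: DataFrame): Unit = {
      df.write
        .format("kafka")
        .option("kafka.bootstrap.servers", "host:9092") // placeholder broker
        .option("topic", "some-topic")                  // placeholder topic
        .save() // no v2 batch writer here, so this goes through saveToV1Source
    }
  }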

@ -81,9 +81,11 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan)
(index, message: WriterCommitMessage) => messages(index) = message
)
logInfo(s"Data source writer $writer is committing.")
writer.commit(messages)
logInfo(s"Data source writer $writer committed.")
if (!writer.isInstanceOf[ContinuousWriter]) {
logInfo(s"Data source writer $writer is committing.")
writer.commit(messages)
logInfo(s"Data source writer $writer committed.")
}
} catch {
case _: InterruptedException if writer.isInstanceOf[ContinuousWriter] =>
// Interruption is how continuous queries are ended, so accept and ignore the exception.

@ -142,7 +142,8 @@ abstract class StreamExecution(
override val id: UUID = UUID.fromString(streamMetadata.id)
override val runId: UUID = UUID.randomUUID
override def runId: UUID = currentRunId
protected var currentRunId = UUID.randomUUID
/**
   * Pretty identified string for printing in logs. Format is
@ -418,11 +419,17 @@ abstract class StreamExecution(
* Blocks the current thread until processing for data from the given `source` has reached at
* least the given `Offset`. This method is intended for use primarily when writing tests.
*/
private[sql] def awaitOffset(source: BaseStreamingSource, newOffset: Offset): Unit = {
private[sql] def awaitOffset(sourceIndex: Int, newOffset: Offset): Unit = {
assertAwaitThread()
def notDone = {
val localCommittedOffsets = committedOffsets
!localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset
if (sources == null) {
// sources might not be initialized yet
false
} else {
val source = sources(sourceIndex)
!localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset
}
}
while (notDone) {
@ -436,7 +443,7 @@ abstract class StreamExecution(
awaitProgressLock.unlock()
}
}
logDebug(s"Unblocked at $newOffset for $source")
logDebug(s"Unblocked at $newOffset for ${sources(sourceIndex)}")
}
/** A flag to indicate that a batch has completed with no new data available. */

@ -77,7 +77,6 @@ class ContinuousDataSourceRDD(
dataReaderThread.start()
context.addTaskCompletionListener(_ => {
reader.close()
dataReaderThread.interrupt()
epochPollExecutor.shutdown()
})
@ -177,6 +176,7 @@ class DataReaderThread(
private[continuous] var failureReason: Throwable = _
override def run(): Unit = {
TaskContext.setTaskContext(context)
val baseReader = ContinuousDataSourceRDD.getBaseReader(reader)
try {
while (!context.isInterrupted && !context.isCompleted()) {
@ -201,6 +201,8 @@ class DataReaderThread(
failedFlag.set(true)
// Don't rethrow the exception in this thread. It's not needed, and the default Spark
// exception handler will kill the executor.
} finally {
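        // Close the reader on the thread that owns it once the read loop exits, instead of
        // relying on the task completion listener to interrupt this thread (that interrupt
        // call is removed above).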
reader.close()
}
}
}

@ -17,7 +17,9 @@
package org.apache.spark.sql.execution.streaming.continuous
import java.util.UUID
import java.util.concurrent.TimeUnit
import java.util.function.UnaryOperator
import scala.collection.JavaConverters._
import scala.collection.mutable.{ArrayBuffer, Map => MutableMap}
@ -52,7 +54,7 @@ class ContinuousExecution(
sparkSession, name, checkpointRoot, analyzedPlan, sink,
trigger, triggerClock, outputMode, deleteCheckpointOnStop) {
@volatile protected var continuousSources: Seq[ContinuousReader] = Seq.empty
@volatile protected var continuousSources: Seq[ContinuousReader] = _
override protected def sources: Seq[BaseStreamingSource] = continuousSources
override lazy val logicalPlan: LogicalPlan = {
@ -78,15 +80,17 @@ class ContinuousExecution(
}
override protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit = {
do {
try {
runContinuous(sparkSessionForStream)
} catch {
case _: InterruptedException if state.get().equals(RECONFIGURING) =>
// swallow exception and run again
state.set(ACTIVE)
val stateUpdate = new UnaryOperator[State] {
override def apply(s: State) = s match {
// If we ended the query to reconfigure, reset the state to active.
case RECONFIGURING => ACTIVE
case _ => s
}
} while (state.get() == ACTIVE)
}
do {
runContinuous(sparkSessionForStream)
} while (state.updateAndGet(stateUpdate) == ACTIVE)
}
/**
@ -120,12 +124,16 @@ class ContinuousExecution(
}
committedOffsets = nextOffsets.toStreamProgress(sources)
// Forcibly align commit and offset logs by slicing off any spurious offset logs from
// a previous run. We can't allow commits to an epoch that a previous run reached but
// this run has not.
offsetLog.purgeAfter(latestEpochId)
// Get to an epoch ID that has definitely never been sent to a sink before. Since sink
// commit happens between offset log write and commit log write, this means an epoch ID
// which is not in the offset log.
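        // For example, if the offset log contains epochs 0 through 5, any of those epochs may
        // already have been committed to the sink, so the first safe epoch to resume at is 6.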
val (latestOffsetEpoch, _) = offsetLog.getLatest().getOrElse {
throw new IllegalStateException(
s"Offset log had no latest element. This shouldn't be possible because nextOffsets is" +
s"an element.")
}
currentBatchId = latestOffsetEpoch + 1
currentBatchId = latestEpochId + 1
logDebug(s"Resuming at epoch $currentBatchId with committed offsets $committedOffsets")
nextOffsets
case None =>
@ -141,6 +149,7 @@ class ContinuousExecution(
* @param sparkSessionForQuery Isolated [[SparkSession]] to run the continuous query with.
*/
private def runContinuous(sparkSessionForQuery: SparkSession): Unit = {
currentRunId = UUID.randomUUID
// A list of attributes that will need to be updated.
val replacements = new ArrayBuffer[(Attribute, Attribute)]
// Translate from continuous relation to the underlying data source.
@ -225,13 +234,11 @@ class ContinuousExecution(
triggerExecutor.execute(() => {
startTrigger()
if (reader.needsReconfiguration()) {
state.set(RECONFIGURING)
if (reader.needsReconfiguration() && state.compareAndSet(ACTIVE, RECONFIGURING)) {
stopSources()
if (queryExecutionThread.isAlive) {
sparkSession.sparkContext.cancelJobGroup(runId.toString)
queryExecutionThread.interrupt()
// No need to join - this thread is about to end anyway.
}
false
} else if (isActive) {
@ -259,6 +266,7 @@ class ContinuousExecution(
sparkSessionForQuery, lastExecution)(lastExecution.toRdd)
}
} finally {
epochEndpoint.askSync[Unit](StopContinuousExecutionWrites)
SparkEnv.get.rpcEnv.stop(epochEndpoint)
epochUpdateThread.interrupt()
@ -273,17 +281,22 @@ class ContinuousExecution(
epoch: Long, reader: ContinuousReader, partitionOffsets: Seq[PartitionOffset]): Unit = {
assert(continuousSources.length == 1, "only one continuous source supported currently")
if (partitionOffsets.contains(null)) {
// If any offset is null, that means the corresponding partition hasn't seen any data yet, so
// there's nothing meaningful to add to the offset log.
}
val globalOffset = reader.mergeOffsets(partitionOffsets.toArray)
synchronized {
if (queryExecutionThread.isAlive) {
offsetLog.add(epoch, OffsetSeq.fill(globalOffset))
} else {
return
}
val oldOffset = synchronized {
offsetLog.add(epoch, OffsetSeq.fill(globalOffset))
offsetLog.get(epoch - 1)
}
// If offset hasn't changed since last epoch, there's been no new data.
if (oldOffset.contains(OffsetSeq.fill(globalOffset))) {
noNewData = true
}
awaitProgressLock.lock()
try {
awaitProgressLockCondition.signalAll()
} finally {
awaitProgressLock.unlock()
}
}

@ -39,6 +39,15 @@ private[continuous] sealed trait EpochCoordinatorMessage extends Serializable
*/
private[sql] case object IncrementAndGetEpoch extends EpochCoordinatorMessage
/**
* The RpcEndpoint stop() will wait to clear out the message queue before terminating the
* object. This can lead to a race condition where the query restarts at epoch n, a new
* EpochCoordinator starts at epoch n, and then the old epoch coordinator commits epoch n + 1.
* The framework doesn't provide a handle to wait on the message queue, so we use a synchronous
* message to stop any writes to the ContinuousExecution object.
*/
private[sql] case object StopContinuousExecutionWrites extends EpochCoordinatorMessage
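// In this patch the query thread sends this message synchronously just before stopping the
// endpoint: epochEndpoint.askSync[Unit](StopContinuousExecutionWrites).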
// Init messages
/**
* Set the reader and writer partition counts. Tasks may not be started until the coordinator
@ -116,6 +125,8 @@ private[continuous] class EpochCoordinator(
override val rpcEnv: RpcEnv)
extends ThreadSafeRpcEndpoint with Logging {
private var queryWritesStopped: Boolean = false
private var numReaderPartitions: Int = _
private var numWriterPartitions: Int = _
@ -147,12 +158,16 @@ private[continuous] class EpochCoordinator(
partitionCommits.remove(k)
}
for (k <- partitionOffsets.keys.filter { case (e, _) => e < epoch }) {
partitionCommits.remove(k)
partitionOffsets.remove(k)
}
}
}
override def receive: PartialFunction[Any, Unit] = {
    // If we just drop these messages, we won't do any writes to the query. The lame duck tasks
    // won't surface errors or have any other effect.
case _ if queryWritesStopped => ()
case CommitPartitionEpoch(partitionId, epoch, message) =>
logDebug(s"Got commit from partition $partitionId at epoch $epoch: $message")
if (!partitionCommits.isDefinedAt((epoch, partitionId))) {
@ -188,5 +203,9 @@ private[continuous] class EpochCoordinator(
case SetWriterPartitions(numPartitions) =>
numWriterPartitions = numPartitions
context.reply(())
case StopContinuousExecutionWrites =>
queryWritesStopped = true
context.reply(())
}
}

@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2}
import org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport
/**
* Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems,
@ -279,18 +280,29 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
useTempCheckpointLocation = true,
trigger = trigger)
} else {
val dataSource =
DataSource(
df.sparkSession,
className = source,
options = extraOptions.toMap,
partitionColumns = normalizedParCols.getOrElse(Nil))
val sink = trigger match {
case _: ContinuousTrigger =>
val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf)
ds.newInstance() match {
case w: ContinuousWriteSupport => w
case _ => throw new AnalysisException(
s"Data source $source does not support continuous writing")
}
case _ =>
val ds = DataSource(
df.sparkSession,
className = source,
options = extraOptions.toMap,
partitionColumns = normalizedParCols.getOrElse(Nil))
ds.createSink(outputMode)
}
df.sparkSession.sessionState.streamingQueryManager.startQuery(
extraOptions.get("queryName"),
extraOptions.get("checkpointLocation"),
df,
extraOptions.toMap,
dataSource.createSink(outputMode),
sink,
outputMode,
useTempCheckpointLocation = source == "console",
recoverFromCheckpointLocation = true,
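// A hedged usage sketch of what the ContinuousTrigger branch above enables: starting a
// continuous Kafka-to-Kafka query from user code. Broker addresses, topic names and the
// checkpoint path are placeholders, not taken from this patch.
object ContinuousKafkaUsageSketch {
  import org.apache.spark.sql.SparkSession
  import org.apache.spark.sql.streaming.{StreamingQuery, Trigger}

  def start(spark: SparkSession): StreamingQuery = {
    spark.readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", "host:9092") // placeholder broker
      .option("subscribe", "in-topic")                // placeholder input topic
      .load()
      .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
      .writeStream
      .format("kafka")
      .option("kafka.bootstrap.servers", "host:9092")
      .option("topic", "out-topic")                   // placeholder output topic
      .option("checkpointLocation", "/tmp/placeholder-checkpoint")
      .trigger(Trigger.Continuous("1 second"))
      .start()
  }
}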

@ -38,8 +38,9 @@ import org.apache.spark.sql.{Dataset, Encoder, QueryTest, Row}
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch}
import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger, EpochCoordinatorRef, IncrementAndGetEpoch}
import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2
import org.apache.spark.sql.execution.streaming.state.StateStore
import org.apache.spark.sql.streaming.StreamingQueryListener._
@ -80,6 +81,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
StateStore.stop() // stop the state store maintenance thread and unload store providers
}
protected val defaultTrigger = Trigger.ProcessingTime(0)
protected val defaultUseV2Sink = false
/** How long to wait for an active stream to catch up when checking a result. */
val streamingTimeout = 10.seconds
@ -189,7 +193,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
/** Starts the stream, resuming if data has already been processed. It must not be running. */
case class StartStream(
trigger: Trigger = Trigger.ProcessingTime(0),
trigger: Trigger = defaultTrigger,
triggerClock: Clock = new SystemClock,
additionalConfs: Map[String, String] = Map.empty,
checkpointLocation: String = null)
@ -276,7 +280,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
def testStream(
_stream: Dataset[_],
outputMode: OutputMode = OutputMode.Append,
useV2Sink: Boolean = false)(actions: StreamAction*): Unit = synchronized {
useV2Sink: Boolean = defaultUseV2Sink)(actions: StreamAction*): Unit = synchronized {
import org.apache.spark.sql.streaming.util.StreamManualClock
// `synchronized` is added to prevent the user from calling multiple `testStream`s concurrently
@ -403,18 +407,11 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
def fetchStreamAnswer(currentStream: StreamExecution, lastOnly: Boolean) = {
verify(currentStream != null, "stream not running")
// Get the map of source index to the current source objects
val indexToSource = currentStream
.logicalPlan
.collect { case StreamingExecutionRelation(s, _) => s }
.zipWithIndex
.map(_.swap)
.toMap
// Block until all data added has been processed for all the source
awaiting.foreach { case (sourceIndex, offset) =>
failAfter(streamingTimeout) {
currentStream.awaitOffset(indexToSource(sourceIndex), offset)
currentStream.awaitOffset(sourceIndex, offset)
}
}
@ -473,6 +470,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
// after starting the query.
try {
currentStream.awaitInitialization(streamingTimeout.toMillis)
currentStream match {
case s: ContinuousExecution => eventually("IncrementalExecution was not created") {
s.lastExecution.executedPlan // will fail if lastExecution is null
}
case _ =>
}
} catch {
case _: StreamingQueryException =>
// Ignore the exception. `StopStream` or `ExpectFailure` will catch it as well.
@ -600,7 +603,10 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
def findSourceIndex(plan: LogicalPlan): Option[Int] = {
plan
.collect { case StreamingExecutionRelation(s, _) => s }
.collect {
case StreamingExecutionRelation(s, _) => s
case DataSourceV2Relation(_, r) => r
}
.zipWithIndex
.find(_._1 == source)
.map(_._2)
@ -613,9 +619,13 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
findSourceIndex(query.logicalPlan)
}.orElse {
findSourceIndex(stream.logicalPlan)
}.orElse {
queryToUse.flatMap { q =>
findSourceIndex(q.lastExecution.logical)
}
}.getOrElse {
throw new IllegalArgumentException(
"Could find index of the source to which data was added")
"Could not find index of the source to which data was added")
}
// Store the expected offset of added data to wait for it later