[SPARK-22908][SS] Roll forward continuous processing Kafka support with fix to continuous Kafka data reader
## What changes were proposed in this pull request? The Kafka reader is now interruptible and can close itself. ## How was this patch tested? I locally ran one of the ContinuousKafkaSourceSuite tests in a tight loop. Before the fix, my machine ran out of open file descriptors a few iterations in; now it works fine. Author: Jose Torres <jose@databricks.com> Closes #20253 from jose-torres/fix-data-reader.
This commit is contained in:
parent
a9b845ebb5
commit
1667057851
260
external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala
vendored
Normal file
260
external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala
vendored
Normal file
|
@ -0,0 +1,260 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.kafka010
|
||||
|
||||
import java.{util => ju}
|
||||
import java.util.concurrent.TimeoutException
|
||||
|
||||
import org.apache.kafka.clients.consumer.{ConsumerRecord, OffsetOutOfRangeException}
|
||||
import org.apache.kafka.common.TopicPartition
|
||||
|
||||
import org.apache.spark.TaskContext
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.sql.SparkSession
|
||||
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter}
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeUtils
|
||||
import org.apache.spark.sql.kafka010.KafkaSource.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
|
||||
import org.apache.spark.sql.sources.v2.reader._
|
||||
import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset}
|
||||
import org.apache.spark.sql.types.StructType
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
|
||||
/**
|
||||
* A [[ContinuousReader]] for data from kafka.
|
||||
*
|
||||
* @param offsetReader a reader used to get kafka offsets. Note that the actual data will be
|
||||
* read by per-task consumers generated later.
|
||||
* @param kafkaParams String params for per-task Kafka consumers.
|
||||
* @param sourceOptions The [[org.apache.spark.sql.sources.v2.DataSourceV2Options]] params which
|
||||
* are not Kafka consumer params.
|
||||
* @param metadataPath Path to a directory this reader can use for writing metadata.
|
||||
* @param initialOffsets The Kafka offsets to start reading data at.
|
||||
* @param failOnDataLoss Flag indicating whether reading should fail in data loss
|
||||
* scenarios, where some offsets after the specified initial ones can't be
|
||||
* properly read.
|
||||
*/
|
||||
class KafkaContinuousReader(
|
||||
offsetReader: KafkaOffsetReader,
|
||||
kafkaParams: ju.Map[String, Object],
|
||||
sourceOptions: Map[String, String],
|
||||
metadataPath: String,
|
||||
initialOffsets: KafkaOffsetRangeLimit,
|
||||
failOnDataLoss: Boolean)
|
||||
extends ContinuousReader with SupportsScanUnsafeRow with Logging {
|
||||
|
||||
private lazy val session = SparkSession.getActiveSession.get
|
||||
private lazy val sc = session.sparkContext
|
||||
|
||||
private val pollTimeoutMs = sourceOptions.getOrElse("kafkaConsumer.pollTimeoutMs", "512").toLong
|
||||
|
||||
// Initialized when creating read tasks. If this diverges from the partitions at the latest
|
||||
// offsets, we need to reconfigure.
|
||||
// Exposed outside this object only for unit tests.
|
||||
private[sql] var knownPartitions: Set[TopicPartition] = _
|
||||
|
||||
override def readSchema: StructType = KafkaOffsetReader.kafkaSchema
|
||||
|
||||
private var offset: Offset = _
|
||||
override def setOffset(start: ju.Optional[Offset]): Unit = {
|
||||
offset = start.orElse {
|
||||
val offsets = initialOffsets match {
|
||||
case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets())
|
||||
case LatestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchLatestOffsets())
|
||||
case SpecificOffsetRangeLimit(p) => offsetReader.fetchSpecificOffsets(p, reportDataLoss)
|
||||
}
|
||||
logInfo(s"Initial offsets: $offsets")
|
||||
offsets
|
||||
}
|
||||
}
|
||||
|
||||
override def getStartOffset(): Offset = offset
|
||||
|
||||
override def deserializeOffset(json: String): Offset = {
|
||||
KafkaSourceOffset(JsonUtils.partitionOffsets(json))
|
||||
}
|
||||
|
||||
override def createUnsafeRowReadTasks(): ju.List[ReadTask[UnsafeRow]] = {
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset)
|
||||
|
||||
val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet
|
||||
val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet)
|
||||
val newPartitionOffsets = offsetReader.fetchEarliestOffsets(newPartitions.toSeq)
|
||||
|
||||
val deletedPartitions = oldStartPartitionOffsets.keySet.diff(currentPartitionSet)
|
||||
if (deletedPartitions.nonEmpty) {
|
||||
reportDataLoss(s"Some partitions were deleted: $deletedPartitions")
|
||||
}
|
||||
|
||||
val startOffsets = newPartitionOffsets ++
|
||||
oldStartPartitionOffsets.filterKeys(!deletedPartitions.contains(_))
|
||||
knownPartitions = startOffsets.keySet
|
||||
|
||||
startOffsets.toSeq.map {
|
||||
case (topicPartition, start) =>
|
||||
KafkaContinuousReadTask(
|
||||
topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss)
|
||||
.asInstanceOf[ReadTask[UnsafeRow]]
|
||||
}.asJava
|
||||
}
|
||||
|
||||
/** Stop this source and free any resources it has allocated. */
|
||||
def stop(): Unit = synchronized {
|
||||
offsetReader.close()
|
||||
}
|
||||
|
||||
override def commit(end: Offset): Unit = {}
|
||||
|
||||
override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = {
|
||||
val mergedMap = offsets.map {
|
||||
case KafkaSourcePartitionOffset(p, o) => Map(p -> o)
|
||||
}.reduce(_ ++ _)
|
||||
KafkaSourceOffset(mergedMap)
|
||||
}
|
||||
|
||||
override def needsReconfiguration(): Boolean = {
|
||||
knownPartitions != null && offsetReader.fetchLatestOffsets().keySet != knownPartitions
|
||||
}
|
||||
|
||||
override def toString(): String = s"KafkaSource[$offsetReader]"
|
||||
|
||||
/**
|
||||
* If `failOnDataLoss` is true, this method will throw an `IllegalStateException`.
|
||||
* Otherwise, just log a warning.
|
||||
*/
|
||||
private def reportDataLoss(message: String): Unit = {
|
||||
if (failOnDataLoss) {
|
||||
throw new IllegalStateException(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE")
|
||||
} else {
|
||||
logWarning(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A read task for continuous Kafka processing. This will be serialized and transformed into a
|
||||
* full reader on executors.
|
||||
*
|
||||
* @param topicPartition The (topic, partition) pair this task is responsible for.
|
||||
* @param startOffset The offset to start reading from within the partition.
|
||||
* @param kafkaParams Kafka consumer params to use.
|
||||
* @param pollTimeoutMs The timeout for Kafka consumer polling.
|
||||
* @param failOnDataLoss Flag indicating whether data reader should fail if some offsets
|
||||
* are skipped.
|
||||
*/
|
||||
case class KafkaContinuousReadTask(
|
||||
topicPartition: TopicPartition,
|
||||
startOffset: Long,
|
||||
kafkaParams: ju.Map[String, Object],
|
||||
pollTimeoutMs: Long,
|
||||
failOnDataLoss: Boolean) extends ReadTask[UnsafeRow] {
|
||||
override def createDataReader(): KafkaContinuousDataReader = {
|
||||
new KafkaContinuousDataReader(
|
||||
topicPartition, startOffset, kafkaParams, pollTimeoutMs, failOnDataLoss)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A per-task data reader for continuous Kafka processing.
|
||||
*
|
||||
* @param topicPartition The (topic, partition) pair this data reader is responsible for.
|
||||
* @param startOffset The offset to start reading from within the partition.
|
||||
* @param kafkaParams Kafka consumer params to use.
|
||||
* @param pollTimeoutMs The timeout for Kafka consumer polling.
|
||||
* @param failOnDataLoss Flag indicating whether data reader should fail if some offsets
|
||||
* are skipped.
|
||||
*/
|
||||
class KafkaContinuousDataReader(
|
||||
topicPartition: TopicPartition,
|
||||
startOffset: Long,
|
||||
kafkaParams: ju.Map[String, Object],
|
||||
pollTimeoutMs: Long,
|
||||
failOnDataLoss: Boolean) extends ContinuousDataReader[UnsafeRow] {
|
||||
private val topic = topicPartition.topic
|
||||
private val kafkaPartition = topicPartition.partition
|
||||
private val consumer = CachedKafkaConsumer.createUncached(topic, kafkaPartition, kafkaParams)
|
||||
|
||||
private val sharedRow = new UnsafeRow(7)
|
||||
private val bufferHolder = new BufferHolder(sharedRow)
|
||||
private val rowWriter = new UnsafeRowWriter(bufferHolder, 7)
|
||||
|
||||
private var nextKafkaOffset = startOffset
|
||||
private var currentRecord: ConsumerRecord[Array[Byte], Array[Byte]] = _
|
||||
|
||||
override def next(): Boolean = {
|
||||
var r: ConsumerRecord[Array[Byte], Array[Byte]] = null
|
||||
while (r == null) {
|
||||
if (TaskContext.get().isInterrupted() || TaskContext.get().isCompleted()) return false
|
||||
// Our consumer.get is not interruptible, so we have to set a low poll timeout, leaving
|
||||
// interrupt points to end the query rather than waiting for new data that might never come.
|
||||
try {
|
||||
r = consumer.get(
|
||||
nextKafkaOffset,
|
||||
untilOffset = Long.MaxValue,
|
||||
pollTimeoutMs,
|
||||
failOnDataLoss)
|
||||
} catch {
|
||||
// We didn't read within the timeout. We're supposed to block indefinitely for new data, so
|
||||
// swallow and ignore this.
|
||||
case _: TimeoutException =>
|
||||
|
||||
// This is a failOnDataLoss exception. Retry if nextKafkaOffset is within the data range,
|
||||
// or if it's the endpoint of the data range (i.e. the "true" next offset).
|
||||
case e: IllegalStateException if e.getCause.isInstanceOf[OffsetOutOfRangeException] =>
|
||||
val range = consumer.getAvailableOffsetRange()
|
||||
if (range.latest >= nextKafkaOffset && range.earliest <= nextKafkaOffset) {
|
||||
// retry
|
||||
} else {
|
||||
throw e
|
||||
}
|
||||
}
|
||||
}
|
||||
nextKafkaOffset = r.offset + 1
|
||||
currentRecord = r
|
||||
true
|
||||
}
|
||||
|
||||
override def get(): UnsafeRow = {
|
||||
bufferHolder.reset()
|
||||
|
||||
if (currentRecord.key == null) {
|
||||
rowWriter.setNullAt(0)
|
||||
} else {
|
||||
rowWriter.write(0, currentRecord.key)
|
||||
}
|
||||
rowWriter.write(1, currentRecord.value)
|
||||
rowWriter.write(2, UTF8String.fromString(currentRecord.topic))
|
||||
rowWriter.write(3, currentRecord.partition)
|
||||
rowWriter.write(4, currentRecord.offset)
|
||||
rowWriter.write(5,
|
||||
DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(currentRecord.timestamp)))
|
||||
rowWriter.write(6, currentRecord.timestampType.id)
|
||||
sharedRow.setTotalSize(bufferHolder.totalSize)
|
||||
sharedRow
|
||||
}
|
||||
|
||||
override def getOffset(): KafkaSourcePartitionOffset = {
|
||||
KafkaSourcePartitionOffset(topicPartition, nextKafkaOffset)
|
||||
}
|
||||
|
||||
override def close(): Unit = {
|
||||
consumer.close()
|
||||
}
|
||||
}
|
119
external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala
vendored
Normal file
119
external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala
vendored
Normal file
|
@ -0,0 +1,119 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.kafka010
|
||||
|
||||
import org.apache.kafka.clients.producer.{Callback, ProducerRecord, RecordMetadata}
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.sql.{Row, SparkSession}
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection}
|
||||
import org.apache.spark.sql.kafka010.KafkaSourceProvider.{kafkaParamsForProducer, TOPIC_OPTION_KEY}
|
||||
import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery
|
||||
import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter
|
||||
import org.apache.spark.sql.sources.v2.writer._
|
||||
import org.apache.spark.sql.streaming.OutputMode
|
||||
import org.apache.spark.sql.types.{BinaryType, StringType, StructType}
|
||||
|
||||
/**
|
||||
* Dummy commit message. The DataSourceV2 framework requires a commit message implementation but we
|
||||
* don't need to really send one.
|
||||
*/
|
||||
case object KafkaWriterCommitMessage extends WriterCommitMessage
|
||||
|
||||
/**
|
||||
* A [[ContinuousWriter]] for Kafka writing. Responsible for generating the writer factory.
|
||||
* @param topic The topic this writer is responsible for. If None, topic will be inferred from
|
||||
* a `topic` field in the incoming data.
|
||||
* @param producerParams Parameters for Kafka producers in each task.
|
||||
* @param schema The schema of the input data.
|
||||
*/
|
||||
class KafkaContinuousWriter(
|
||||
topic: Option[String], producerParams: Map[String, String], schema: StructType)
|
||||
extends ContinuousWriter with SupportsWriteInternalRow {
|
||||
|
||||
validateQuery(schema.toAttributes, producerParams.toMap[String, Object].asJava, topic)
|
||||
|
||||
override def createInternalRowWriterFactory(): KafkaContinuousWriterFactory =
|
||||
KafkaContinuousWriterFactory(topic, producerParams, schema)
|
||||
|
||||
override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
|
||||
override def abort(messages: Array[WriterCommitMessage]): Unit = {}
|
||||
}
|
||||
|
||||
/**
|
||||
* A [[DataWriterFactory]] for Kafka writing. Will be serialized and sent to executors to generate
|
||||
* the per-task data writers.
|
||||
* @param topic The topic that should be written to. If None, topic will be inferred from
|
||||
* a `topic` field in the incoming data.
|
||||
* @param producerParams Parameters for Kafka producers in each task.
|
||||
* @param schema The schema of the input data.
|
||||
*/
|
||||
case class KafkaContinuousWriterFactory(
|
||||
topic: Option[String], producerParams: Map[String, String], schema: StructType)
|
||||
extends DataWriterFactory[InternalRow] {
|
||||
|
||||
override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = {
|
||||
new KafkaContinuousDataWriter(topic, producerParams, schema.toAttributes)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A [[DataWriter]] for Kafka writing. One data writer will be created in each partition to
|
||||
* process incoming rows.
|
||||
*
|
||||
* @param targetTopic The topic that this data writer is targeting. If None, topic will be inferred
|
||||
* from a `topic` field in the incoming data.
|
||||
* @param producerParams Parameters to use for the Kafka producer.
|
||||
* @param inputSchema The attributes in the input data.
|
||||
*/
|
||||
class KafkaContinuousDataWriter(
|
||||
targetTopic: Option[String], producerParams: Map[String, String], inputSchema: Seq[Attribute])
|
||||
extends KafkaRowWriter(inputSchema, targetTopic) with DataWriter[InternalRow] {
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
private lazy val producer = CachedKafkaProducer.getOrCreate(
|
||||
new java.util.HashMap[String, Object](producerParams.asJava))
|
||||
|
||||
def write(row: InternalRow): Unit = {
|
||||
checkForErrors()
|
||||
sendRow(row, producer)
|
||||
}
|
||||
|
||||
def commit(): WriterCommitMessage = {
|
||||
// Send is asynchronous, but we can't commit until all rows are actually in Kafka.
|
||||
// This requires flushing and then checking that no callbacks produced errors.
|
||||
// We also check for errors before to fail as soon as possible - the check is cheap.
|
||||
checkForErrors()
|
||||
producer.flush()
|
||||
checkForErrors()
|
||||
KafkaWriterCommitMessage
|
||||
}
|
||||
|
||||
def abort(): Unit = {}
|
||||
|
||||
def close(): Unit = {
|
||||
checkForErrors()
|
||||
if (producer != null) {
|
||||
producer.flush()
|
||||
checkForErrors()
|
||||
CachedKafkaProducer.close(new java.util.HashMap[String, Object](producerParams.asJava))
|
||||
}
|
||||
}
|
||||
}
|
|
@ -117,10 +117,14 @@ private[kafka010] class KafkaOffsetReader(
|
|||
* Resolves the specific offsets based on Kafka seek positions.
|
||||
* This method resolves offset value -1 to the latest and -2 to the
|
||||
* earliest Kafka seek position.
|
||||
*
|
||||
* @param partitionOffsets the specific offsets to resolve
|
||||
* @param reportDataLoss callback to either report or log data loss depending on setting
|
||||
*/
|
||||
def fetchSpecificOffsets(
|
||||
partitionOffsets: Map[TopicPartition, Long]): Map[TopicPartition, Long] =
|
||||
runUninterruptibly {
|
||||
partitionOffsets: Map[TopicPartition, Long],
|
||||
reportDataLoss: String => Unit): KafkaSourceOffset = {
|
||||
val fetched = runUninterruptibly {
|
||||
withRetriesWithoutInterrupt {
|
||||
// Poll to get the latest assigned partitions
|
||||
consumer.poll(0)
|
||||
|
@ -145,6 +149,19 @@ private[kafka010] class KafkaOffsetReader(
|
|||
}
|
||||
}
|
||||
|
||||
partitionOffsets.foreach {
|
||||
case (tp, off) if off != KafkaOffsetRangeLimit.LATEST &&
|
||||
off != KafkaOffsetRangeLimit.EARLIEST =>
|
||||
if (fetched(tp) != off) {
|
||||
reportDataLoss(
|
||||
s"startingOffsets for $tp was $off but consumer reset to ${fetched(tp)}")
|
||||
}
|
||||
case _ =>
|
||||
// no real way to check that beginning or end is reasonable
|
||||
}
|
||||
KafkaSourceOffset(fetched)
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetch the earliest offsets for the topic partitions that are indicated
|
||||
* in the [[ConsumerStrategy]].
|
||||
|
|
|
@ -130,7 +130,7 @@ private[kafka010] class KafkaSource(
|
|||
val offsets = startingOffsets match {
|
||||
case EarliestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchEarliestOffsets())
|
||||
case LatestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchLatestOffsets())
|
||||
case SpecificOffsetRangeLimit(p) => fetchAndVerify(p)
|
||||
case SpecificOffsetRangeLimit(p) => kafkaReader.fetchSpecificOffsets(p, reportDataLoss)
|
||||
}
|
||||
metadataLog.add(0, offsets)
|
||||
logInfo(s"Initial offsets: $offsets")
|
||||
|
@ -138,21 +138,6 @@ private[kafka010] class KafkaSource(
|
|||
}.partitionToOffsets
|
||||
}
|
||||
|
||||
private def fetchAndVerify(specificOffsets: Map[TopicPartition, Long]) = {
|
||||
val result = kafkaReader.fetchSpecificOffsets(specificOffsets)
|
||||
specificOffsets.foreach {
|
||||
case (tp, off) if off != KafkaOffsetRangeLimit.LATEST &&
|
||||
off != KafkaOffsetRangeLimit.EARLIEST =>
|
||||
if (result(tp) != off) {
|
||||
reportDataLoss(
|
||||
s"startingOffsets for $tp was $off but consumer reset to ${result(tp)}")
|
||||
}
|
||||
case _ =>
|
||||
// no real way to check that beginning or end is reasonable
|
||||
}
|
||||
KafkaSourceOffset(result)
|
||||
}
|
||||
|
||||
private var currentPartitionOffsets: Option[Map[TopicPartition, Long]] = None
|
||||
|
||||
override def schema: StructType = KafkaOffsetReader.kafkaSchema
|
||||
|
|
|
@ -20,17 +20,22 @@ package org.apache.spark.sql.kafka010
|
|||
import org.apache.kafka.common.TopicPartition
|
||||
|
||||
import org.apache.spark.sql.execution.streaming.{Offset, SerializedOffset}
|
||||
import org.apache.spark.sql.sources.v2.streaming.reader.{Offset => OffsetV2, PartitionOffset}
|
||||
|
||||
/**
|
||||
* An [[Offset]] for the [[KafkaSource]]. This one tracks all partitions of subscribed topics and
|
||||
* their offsets.
|
||||
*/
|
||||
private[kafka010]
|
||||
case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends Offset {
|
||||
case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends OffsetV2 {
|
||||
|
||||
override val json = JsonUtils.partitionOffsets(partitionToOffsets)
|
||||
}
|
||||
|
||||
private[kafka010]
|
||||
case class KafkaSourcePartitionOffset(topicPartition: TopicPartition, partitionOffset: Long)
|
||||
extends PartitionOffset
|
||||
|
||||
/** Companion object of the [[KafkaSourceOffset]] */
|
||||
private[kafka010] object KafkaSourceOffset {
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
package org.apache.spark.sql.kafka010
|
||||
|
||||
import java.{util => ju}
|
||||
import java.util.{Locale, UUID}
|
||||
import java.util.{Locale, Optional, UUID}
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
|
@ -27,9 +27,12 @@ import org.apache.kafka.clients.producer.ProducerConfig
|
|||
import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySerializer}
|
||||
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext}
|
||||
import org.apache.spark.sql.execution.streaming.{Sink, Source}
|
||||
import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSession, SQLContext}
|
||||
import org.apache.spark.sql.execution.streaming.{Offset, Sink, Source}
|
||||
import org.apache.spark.sql.sources._
|
||||
import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options}
|
||||
import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport}
|
||||
import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter
|
||||
import org.apache.spark.sql.streaming.OutputMode
|
||||
import org.apache.spark.sql.types.StructType
|
||||
|
||||
|
@ -43,6 +46,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
|
|||
with StreamSinkProvider
|
||||
with RelationProvider
|
||||
with CreatableRelationProvider
|
||||
with ContinuousWriteSupport
|
||||
with ContinuousReadSupport
|
||||
with Logging {
|
||||
import KafkaSourceProvider._
|
||||
|
||||
|
@ -101,6 +106,43 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
|
|||
failOnDataLoss(caseInsensitiveParams))
|
||||
}
|
||||
|
||||
override def createContinuousReader(
|
||||
schema: Optional[StructType],
|
||||
metadataPath: String,
|
||||
options: DataSourceV2Options): KafkaContinuousReader = {
|
||||
val parameters = options.asMap().asScala.toMap
|
||||
validateStreamOptions(parameters)
|
||||
// Each running query should use its own group id. Otherwise, the query may be only assigned
|
||||
// partial data since Kafka will assign partitions to multiple consumers having the same group
|
||||
// id. Hence, we should generate a unique id for each query.
|
||||
val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}"
|
||||
|
||||
val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
|
||||
val specifiedKafkaParams =
|
||||
parameters
|
||||
.keySet
|
||||
.filter(_.toLowerCase(Locale.ROOT).startsWith("kafka."))
|
||||
.map { k => k.drop(6).toString -> parameters(k) }
|
||||
.toMap
|
||||
|
||||
val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams,
|
||||
STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
|
||||
|
||||
val kafkaOffsetReader = new KafkaOffsetReader(
|
||||
strategy(caseInsensitiveParams),
|
||||
kafkaParamsForDriver(specifiedKafkaParams),
|
||||
parameters,
|
||||
driverGroupIdPrefix = s"$uniqueGroupId-driver")
|
||||
|
||||
new KafkaContinuousReader(
|
||||
kafkaOffsetReader,
|
||||
kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId),
|
||||
parameters,
|
||||
metadataPath,
|
||||
startingStreamOffsets,
|
||||
failOnDataLoss(caseInsensitiveParams))
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a new base relation with the given parameters.
|
||||
*
|
||||
|
@ -181,26 +223,22 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
|
|||
}
|
||||
}
|
||||
|
||||
private def kafkaParamsForProducer(parameters: Map[String, String]): Map[String, String] = {
|
||||
val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
|
||||
if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) {
|
||||
throw new IllegalArgumentException(
|
||||
s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys "
|
||||
+ "are serialized with ByteArraySerializer.")
|
||||
}
|
||||
override def createContinuousWriter(
|
||||
queryId: String,
|
||||
schema: StructType,
|
||||
mode: OutputMode,
|
||||
options: DataSourceV2Options): Optional[ContinuousWriter] = {
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}"))
|
||||
{
|
||||
throw new IllegalArgumentException(
|
||||
s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as "
|
||||
+ "value are serialized with ByteArraySerializer.")
|
||||
}
|
||||
parameters
|
||||
.keySet
|
||||
.filter(_.toLowerCase(Locale.ROOT).startsWith("kafka."))
|
||||
.map { k => k.drop(6).toString -> parameters(k) }
|
||||
.toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName,
|
||||
ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName)
|
||||
val spark = SparkSession.getActiveSession.get
|
||||
val topic = Option(options.get(TOPIC_OPTION_KEY).orElse(null)).map(_.trim)
|
||||
// We convert the options argument from V2 -> Java map -> scala mutable -> scala immutable.
|
||||
val producerParams = kafkaParamsForProducer(options.asMap.asScala.toMap)
|
||||
|
||||
KafkaWriter.validateQuery(
|
||||
schema.toAttributes, new java.util.HashMap[String, Object](producerParams.asJava), topic)
|
||||
|
||||
Optional.of(new KafkaContinuousWriter(topic, producerParams, schema))
|
||||
}
|
||||
|
||||
private def strategy(caseInsensitiveParams: Map[String, String]) =
|
||||
|
@ -450,4 +488,27 @@ private[kafka010] object KafkaSourceProvider extends Logging {
|
|||
|
||||
def build(): ju.Map[String, Object] = map
|
||||
}
|
||||
|
||||
private[kafka010] def kafkaParamsForProducer(
|
||||
parameters: Map[String, String]): Map[String, String] = {
|
||||
val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
|
||||
if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) {
|
||||
throw new IllegalArgumentException(
|
||||
s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys "
|
||||
+ "are serialized with ByteArraySerializer.")
|
||||
}
|
||||
|
||||
if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}"))
|
||||
{
|
||||
throw new IllegalArgumentException(
|
||||
s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as "
|
||||
+ "value are serialized with ByteArraySerializer.")
|
||||
}
|
||||
parameters
|
||||
.keySet
|
||||
.filter(_.toLowerCase(Locale.ROOT).startsWith("kafka."))
|
||||
.map { k => k.drop(6).toString -> parameters(k) }
|
||||
.toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName,
|
||||
ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -33,10 +33,8 @@ import org.apache.spark.sql.types.{BinaryType, StringType}
|
|||
private[kafka010] class KafkaWriteTask(
|
||||
producerConfiguration: ju.Map[String, Object],
|
||||
inputSchema: Seq[Attribute],
|
||||
topic: Option[String]) {
|
||||
topic: Option[String]) extends KafkaRowWriter(inputSchema, topic) {
|
||||
// used to synchronize with Kafka callbacks
|
||||
@volatile private var failedWrite: Exception = null
|
||||
private val projection = createProjection
|
||||
private var producer: KafkaProducer[Array[Byte], Array[Byte]] = _
|
||||
|
||||
/**
|
||||
|
@ -46,23 +44,7 @@ private[kafka010] class KafkaWriteTask(
|
|||
producer = CachedKafkaProducer.getOrCreate(producerConfiguration)
|
||||
while (iterator.hasNext && failedWrite == null) {
|
||||
val currentRow = iterator.next()
|
||||
val projectedRow = projection(currentRow)
|
||||
val topic = projectedRow.getUTF8String(0)
|
||||
val key = projectedRow.getBinary(1)
|
||||
val value = projectedRow.getBinary(2)
|
||||
if (topic == null) {
|
||||
throw new NullPointerException(s"null topic present in the data. Use the " +
|
||||
s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.")
|
||||
}
|
||||
val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value)
|
||||
val callback = new Callback() {
|
||||
override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = {
|
||||
if (failedWrite == null && e != null) {
|
||||
failedWrite = e
|
||||
}
|
||||
}
|
||||
}
|
||||
producer.send(record, callback)
|
||||
sendRow(currentRow, producer)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -74,8 +56,49 @@ private[kafka010] class KafkaWriteTask(
|
|||
producer = null
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private def createProjection: UnsafeProjection = {
|
||||
private[kafka010] abstract class KafkaRowWriter(
|
||||
inputSchema: Seq[Attribute], topic: Option[String]) {
|
||||
|
||||
// used to synchronize with Kafka callbacks
|
||||
@volatile protected var failedWrite: Exception = _
|
||||
protected val projection = createProjection
|
||||
|
||||
private val callback = new Callback() {
|
||||
override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = {
|
||||
if (failedWrite == null && e != null) {
|
||||
failedWrite = e
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Send the specified row to the producer, with a callback that will save any exception
|
||||
* to failedWrite. Note that send is asynchronous; subclasses must flush() their producer before
|
||||
* assuming the row is in Kafka.
|
||||
*/
|
||||
protected def sendRow(
|
||||
row: InternalRow, producer: KafkaProducer[Array[Byte], Array[Byte]]): Unit = {
|
||||
val projectedRow = projection(row)
|
||||
val topic = projectedRow.getUTF8String(0)
|
||||
val key = projectedRow.getBinary(1)
|
||||
val value = projectedRow.getBinary(2)
|
||||
if (topic == null) {
|
||||
throw new NullPointerException(s"null topic present in the data. Use the " +
|
||||
s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.")
|
||||
}
|
||||
val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value)
|
||||
producer.send(record, callback)
|
||||
}
|
||||
|
||||
protected def checkForErrors(): Unit = {
|
||||
if (failedWrite != null) {
|
||||
throw failedWrite
|
||||
}
|
||||
}
|
||||
|
||||
private def createProjection = {
|
||||
val topicExpression = topic.map(Literal(_)).orElse {
|
||||
inputSchema.find(_.name == KafkaWriter.TOPIC_ATTRIBUTE_NAME)
|
||||
}.getOrElse {
|
||||
|
@ -112,11 +135,5 @@ private[kafka010] class KafkaWriteTask(
|
|||
Seq(topicExpression, Cast(keyExpression, BinaryType),
|
||||
Cast(valueExpression, BinaryType)), inputSchema)
|
||||
}
|
||||
|
||||
private def checkForErrors(): Unit = {
|
||||
if (failedWrite != null) {
|
||||
throw failedWrite
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -43,10 +43,9 @@ private[kafka010] object KafkaWriter extends Logging {
|
|||
override def toString: String = "KafkaWriter"
|
||||
|
||||
def validateQuery(
|
||||
queryExecution: QueryExecution,
|
||||
schema: Seq[Attribute],
|
||||
kafkaParameters: ju.Map[String, Object],
|
||||
topic: Option[String] = None): Unit = {
|
||||
val schema = queryExecution.analyzed.output
|
||||
schema.find(_.name == TOPIC_ATTRIBUTE_NAME).getOrElse(
|
||||
if (topic.isEmpty) {
|
||||
throw new AnalysisException(s"topic option required when no " +
|
||||
|
@ -84,7 +83,7 @@ private[kafka010] object KafkaWriter extends Logging {
|
|||
kafkaParameters: ju.Map[String, Object],
|
||||
topic: Option[String] = None): Unit = {
|
||||
val schema = queryExecution.analyzed.output
|
||||
validateQuery(queryExecution, kafkaParameters, topic)
|
||||
validateQuery(schema, kafkaParameters, topic)
|
||||
queryExecution.toRdd.foreachPartition { iter =>
|
||||
val writeTask = new KafkaWriteTask(kafkaParameters, schema, topic)
|
||||
Utils.tryWithSafeFinally(block = writeTask.execute(iter))(
|
||||
|
|
476
external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala
vendored
Normal file
476
external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala
vendored
Normal file
|
@ -0,0 +1,476 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.kafka010
|
||||
|
||||
import java.util.Locale
|
||||
import java.util.concurrent.atomic.AtomicInteger
|
||||
|
||||
import org.apache.kafka.clients.producer.ProducerConfig
|
||||
import org.apache.kafka.common.serialization.ByteArraySerializer
|
||||
import org.scalatest.time.SpanSugar._
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode}
|
||||
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SpecificInternalRow, UnsafeProjection}
|
||||
import org.apache.spark.sql.execution.streaming.MemoryStream
|
||||
import org.apache.spark.sql.streaming._
|
||||
import org.apache.spark.sql.types.{BinaryType, DataType}
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
/**
|
||||
* This is a temporary port of KafkaSinkSuite, since we do not yet have a V2 memory stream.
|
||||
* Once we have one, this will be changed to a specialization of KafkaSinkSuite and we won't have
|
||||
* to duplicate all the code.
|
||||
*/
|
||||
class KafkaContinuousSinkSuite extends KafkaContinuousTest {
|
||||
import testImplicits._
|
||||
|
||||
override val streamingTimeout = 30.seconds
|
||||
|
||||
override def beforeAll(): Unit = {
|
||||
super.beforeAll()
|
||||
testUtils = new KafkaTestUtils(
|
||||
withBrokerProps = Map("auto.create.topics.enable" -> "false"))
|
||||
testUtils.setup()
|
||||
}
|
||||
|
||||
override def afterAll(): Unit = {
|
||||
if (testUtils != null) {
|
||||
testUtils.teardown()
|
||||
testUtils = null
|
||||
}
|
||||
super.afterAll()
|
||||
}
|
||||
|
||||
test("streaming - write to kafka with topic field") {
|
||||
val inputTopic = newTopic()
|
||||
testUtils.createTopic(inputTopic, partitions = 1)
|
||||
|
||||
val input = spark
|
||||
.readStream
|
||||
.format("kafka")
|
||||
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
|
||||
.option("subscribe", inputTopic)
|
||||
.option("startingOffsets", "earliest")
|
||||
.load()
|
||||
|
||||
val topic = newTopic()
|
||||
testUtils.createTopic(topic)
|
||||
|
||||
val writer = createKafkaWriter(
|
||||
input.toDF(),
|
||||
withTopic = None,
|
||||
withOutputMode = Some(OutputMode.Append))(
|
||||
withSelectExpr = s"'$topic' as topic", "value")
|
||||
|
||||
val reader = createKafkaReader(topic)
|
||||
.selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value")
|
||||
.selectExpr("CAST(key as INT) key", "CAST(value as INT) value")
|
||||
.as[(Int, Int)]
|
||||
.map(_._2)
|
||||
|
||||
try {
|
||||
testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
|
||||
eventually(timeout(streamingTimeout)) {
|
||||
checkDatasetUnorderly(reader, 1, 2, 3, 4, 5)
|
||||
}
|
||||
testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10"))
|
||||
eventually(timeout(streamingTimeout)) {
|
||||
checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
|
||||
}
|
||||
} finally {
|
||||
writer.stop()
|
||||
}
|
||||
}
|
||||
|
||||
test("streaming - write w/o topic field, with topic option") {
|
||||
val inputTopic = newTopic()
|
||||
testUtils.createTopic(inputTopic, partitions = 1)
|
||||
|
||||
val input = spark
|
||||
.readStream
|
||||
.format("kafka")
|
||||
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
|
||||
.option("subscribe", inputTopic)
|
||||
.option("startingOffsets", "earliest")
|
||||
.load()
|
||||
|
||||
val topic = newTopic()
|
||||
testUtils.createTopic(topic)
|
||||
|
||||
val writer = createKafkaWriter(
|
||||
input.toDF(),
|
||||
withTopic = Some(topic),
|
||||
withOutputMode = Some(OutputMode.Append()))()
|
||||
|
||||
val reader = createKafkaReader(topic)
|
||||
.selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value")
|
||||
.selectExpr("CAST(key as INT) key", "CAST(value as INT) value")
|
||||
.as[(Int, Int)]
|
||||
.map(_._2)
|
||||
|
||||
try {
|
||||
testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
|
||||
eventually(timeout(streamingTimeout)) {
|
||||
checkDatasetUnorderly(reader, 1, 2, 3, 4, 5)
|
||||
}
|
||||
testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10"))
|
||||
eventually(timeout(streamingTimeout)) {
|
||||
checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
|
||||
}
|
||||
} finally {
|
||||
writer.stop()
|
||||
}
|
||||
}
|
||||
|
||||
test("streaming - topic field and topic option") {
|
||||
/* The purpose of this test is to ensure that the topic option
|
||||
* overrides the topic field. We begin by writing some data that
|
||||
* includes a topic field and value (e.g., 'foo') along with a topic
|
||||
* option. Then when we read from the topic specified in the option
|
||||
* we should see the data i.e., the data was written to the topic
|
||||
* option, and not to the topic in the data e.g., foo
|
||||
*/
|
||||
val inputTopic = newTopic()
|
||||
testUtils.createTopic(inputTopic, partitions = 1)
|
||||
|
||||
val input = spark
|
||||
.readStream
|
||||
.format("kafka")
|
||||
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
|
||||
.option("subscribe", inputTopic)
|
||||
.option("startingOffsets", "earliest")
|
||||
.load()
|
||||
|
||||
val topic = newTopic()
|
||||
testUtils.createTopic(topic)
|
||||
|
||||
val writer = createKafkaWriter(
|
||||
input.toDF(),
|
||||
withTopic = Some(topic),
|
||||
withOutputMode = Some(OutputMode.Append()))(
|
||||
withSelectExpr = "'foo' as topic", "CAST(value as STRING) value")
|
||||
|
||||
val reader = createKafkaReader(topic)
|
||||
.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
|
||||
.selectExpr("CAST(key AS INT)", "CAST(value AS INT)")
|
||||
.as[(Int, Int)]
|
||||
.map(_._2)
|
||||
|
||||
try {
|
||||
testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
|
||||
eventually(timeout(streamingTimeout)) {
|
||||
checkDatasetUnorderly(reader, 1, 2, 3, 4, 5)
|
||||
}
|
||||
testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10"))
|
||||
eventually(timeout(streamingTimeout)) {
|
||||
checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
|
||||
}
|
||||
} finally {
|
||||
writer.stop()
|
||||
}
|
||||
}
|
||||
|
||||
test("null topic attribute") {
|
||||
val inputTopic = newTopic()
|
||||
testUtils.createTopic(inputTopic, partitions = 1)
|
||||
|
||||
val input = spark
|
||||
.readStream
|
||||
.format("kafka")
|
||||
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
|
||||
.option("subscribe", inputTopic)
|
||||
.option("startingOffsets", "earliest")
|
||||
.load()
|
||||
val topic = newTopic()
|
||||
testUtils.createTopic(topic)
|
||||
|
||||
/* No topic field or topic option */
|
||||
var writer: StreamingQuery = null
|
||||
var ex: Exception = null
|
||||
try {
|
||||
writer = createKafkaWriter(input.toDF())(
|
||||
withSelectExpr = "CAST(null as STRING) as topic", "value"
|
||||
)
|
||||
testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
|
||||
eventually(timeout(streamingTimeout)) {
|
||||
assert(writer.exception.isDefined)
|
||||
ex = writer.exception.get
|
||||
}
|
||||
} finally {
|
||||
writer.stop()
|
||||
}
|
||||
assert(ex.getCause.getCause.getMessage
|
||||
.toLowerCase(Locale.ROOT)
|
||||
.contains("null topic present in the data."))
|
||||
}
|
||||
|
||||
test("streaming - write data with bad schema") {
|
||||
val inputTopic = newTopic()
|
||||
testUtils.createTopic(inputTopic, partitions = 1)
|
||||
|
||||
val input = spark
|
||||
.readStream
|
||||
.format("kafka")
|
||||
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
|
||||
.option("subscribe", inputTopic)
|
||||
.option("startingOffsets", "earliest")
|
||||
.load()
|
||||
val topic = newTopic()
|
||||
testUtils.createTopic(topic)
|
||||
|
||||
/* No topic field or topic option */
|
||||
var writer: StreamingQuery = null
|
||||
var ex: Exception = null
|
||||
try {
|
||||
writer = createKafkaWriter(input.toDF())(
|
||||
withSelectExpr = "value as key", "value"
|
||||
)
|
||||
testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
|
||||
eventually(timeout(streamingTimeout)) {
|
||||
assert(writer.exception.isDefined)
|
||||
ex = writer.exception.get
|
||||
}
|
||||
} finally {
|
||||
writer.stop()
|
||||
}
|
||||
assert(ex.getMessage
|
||||
.toLowerCase(Locale.ROOT)
|
||||
.contains("topic option required when no 'topic' attribute is present"))
|
||||
|
||||
try {
|
||||
/* No value field */
|
||||
writer = createKafkaWriter(input.toDF())(
|
||||
withSelectExpr = s"'$topic' as topic", "value as key"
|
||||
)
|
||||
testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
|
||||
eventually(timeout(streamingTimeout)) {
|
||||
assert(writer.exception.isDefined)
|
||||
ex = writer.exception.get
|
||||
}
|
||||
} finally {
|
||||
writer.stop()
|
||||
}
|
||||
assert(ex.getMessage.toLowerCase(Locale.ROOT).contains(
|
||||
"required attribute 'value' not found"))
|
||||
}
|
||||
|
||||
test("streaming - write data with valid schema but wrong types") {
|
||||
val inputTopic = newTopic()
|
||||
testUtils.createTopic(inputTopic, partitions = 1)
|
||||
|
||||
val input = spark
|
||||
.readStream
|
||||
.format("kafka")
|
||||
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
|
||||
.option("subscribe", inputTopic)
|
||||
.option("startingOffsets", "earliest")
|
||||
.load()
|
||||
.selectExpr("CAST(value as STRING) value")
|
||||
val topic = newTopic()
|
||||
testUtils.createTopic(topic)
|
||||
|
||||
var writer: StreamingQuery = null
|
||||
var ex: Exception = null
|
||||
try {
|
||||
/* topic field wrong type */
|
||||
writer = createKafkaWriter(input.toDF())(
|
||||
withSelectExpr = s"CAST('1' as INT) as topic", "value"
|
||||
)
|
||||
testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
|
||||
eventually(timeout(streamingTimeout)) {
|
||||
assert(writer.exception.isDefined)
|
||||
ex = writer.exception.get
|
||||
}
|
||||
} finally {
|
||||
writer.stop()
|
||||
}
|
||||
assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("topic type must be a string"))
|
||||
|
||||
try {
|
||||
/* value field wrong type */
|
||||
writer = createKafkaWriter(input.toDF())(
|
||||
withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as value"
|
||||
)
|
||||
testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
|
||||
eventually(timeout(streamingTimeout)) {
|
||||
assert(writer.exception.isDefined)
|
||||
ex = writer.exception.get
|
||||
}
|
||||
} finally {
|
||||
writer.stop()
|
||||
}
|
||||
assert(ex.getMessage.toLowerCase(Locale.ROOT).contains(
|
||||
"value attribute type must be a string or binarytype"))
|
||||
|
||||
try {
|
||||
/* key field wrong type */
|
||||
writer = createKafkaWriter(input.toDF())(
|
||||
withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as key", "value"
|
||||
)
|
||||
testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
|
||||
eventually(timeout(streamingTimeout)) {
|
||||
assert(writer.exception.isDefined)
|
||||
ex = writer.exception.get
|
||||
}
|
||||
} finally {
|
||||
writer.stop()
|
||||
}
|
||||
assert(ex.getMessage.toLowerCase(Locale.ROOT).contains(
|
||||
"key attribute type must be a string or binarytype"))
|
||||
}
|
||||
|
||||
test("streaming - write to non-existing topic") {
|
||||
val inputTopic = newTopic()
|
||||
testUtils.createTopic(inputTopic, partitions = 1)
|
||||
|
||||
val input = spark
|
||||
.readStream
|
||||
.format("kafka")
|
||||
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
|
||||
.option("subscribe", inputTopic)
|
||||
.option("startingOffsets", "earliest")
|
||||
.load()
|
||||
val topic = newTopic()
|
||||
|
||||
var writer: StreamingQuery = null
|
||||
var ex: Exception = null
|
||||
try {
|
||||
ex = intercept[StreamingQueryException] {
|
||||
writer = createKafkaWriter(input.toDF(), withTopic = Some(topic))()
|
||||
testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
|
||||
eventually(timeout(streamingTimeout)) {
|
||||
assert(writer.exception.isDefined)
|
||||
}
|
||||
throw writer.exception.get
|
||||
}
|
||||
} finally {
|
||||
writer.stop()
|
||||
}
|
||||
assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("job aborted"))
|
||||
}
|
||||
|
||||
test("streaming - exception on config serializer") {
|
||||
val inputTopic = newTopic()
|
||||
testUtils.createTopic(inputTopic, partitions = 1)
|
||||
testUtils.sendMessages(inputTopic, Array("0"))
|
||||
|
||||
val input = spark
|
||||
.readStream
|
||||
.format("kafka")
|
||||
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
|
||||
.option("subscribe", inputTopic)
|
||||
.load()
|
||||
var writer: StreamingQuery = null
|
||||
var ex: Exception = null
|
||||
try {
|
||||
writer = createKafkaWriter(
|
||||
input.toDF(),
|
||||
withOptions = Map("kafka.key.serializer" -> "foo"))()
|
||||
eventually(timeout(streamingTimeout)) {
|
||||
assert(writer.exception.isDefined)
|
||||
ex = writer.exception.get
|
||||
}
|
||||
assert(ex.getMessage.toLowerCase(Locale.ROOT).contains(
|
||||
"kafka option 'key.serializer' is not supported"))
|
||||
} finally {
|
||||
writer.stop()
|
||||
}
|
||||
|
||||
try {
|
||||
writer = createKafkaWriter(
|
||||
input.toDF(),
|
||||
withOptions = Map("kafka.value.serializer" -> "foo"))()
|
||||
eventually(timeout(streamingTimeout)) {
|
||||
assert(writer.exception.isDefined)
|
||||
ex = writer.exception.get
|
||||
}
|
||||
assert(ex.getMessage.toLowerCase(Locale.ROOT).contains(
|
||||
"kafka option 'value.serializer' is not supported"))
|
||||
} finally {
|
||||
writer.stop()
|
||||
}
|
||||
}
|
||||
|
||||
test("generic - write big data with small producer buffer") {
|
||||
/* This test ensures that we understand the semantics of Kafka when
|
||||
* is comes to blocking on a call to send when the send buffer is full.
|
||||
* This test will configure the smallest possible producer buffer and
|
||||
* indicate that we should block when it is full. Thus, no exception should
|
||||
* be thrown in the case of a full buffer.
|
||||
*/
|
||||
val topic = newTopic()
|
||||
testUtils.createTopic(topic, 1)
|
||||
val options = new java.util.HashMap[String, String]
|
||||
options.put("bootstrap.servers", testUtils.brokerAddress)
|
||||
options.put("buffer.memory", "16384") // min buffer size
|
||||
options.put("block.on.buffer.full", "true")
|
||||
options.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName)
|
||||
options.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName)
|
||||
val inputSchema = Seq(AttributeReference("value", BinaryType)())
|
||||
val data = new Array[Byte](15000) // large value
|
||||
val writeTask = new KafkaContinuousDataWriter(Some(topic), options.asScala.toMap, inputSchema)
|
||||
try {
|
||||
val fieldTypes: Array[DataType] = Array(BinaryType)
|
||||
val converter = UnsafeProjection.create(fieldTypes)
|
||||
val row = new SpecificInternalRow(fieldTypes)
|
||||
row.update(0, data)
|
||||
val iter = Seq.fill(1000)(converter.apply(row)).iterator
|
||||
iter.foreach(writeTask.write(_))
|
||||
writeTask.commit()
|
||||
} finally {
|
||||
writeTask.close()
|
||||
}
|
||||
}
|
||||
|
||||
private def createKafkaReader(topic: String): DataFrame = {
|
||||
spark.read
|
||||
.format("kafka")
|
||||
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
|
||||
.option("startingOffsets", "earliest")
|
||||
.option("endingOffsets", "latest")
|
||||
.option("subscribe", topic)
|
||||
.load()
|
||||
}
|
||||
|
||||
private def createKafkaWriter(
|
||||
input: DataFrame,
|
||||
withTopic: Option[String] = None,
|
||||
withOutputMode: Option[OutputMode] = None,
|
||||
withOptions: Map[String, String] = Map[String, String]())
|
||||
(withSelectExpr: String*): StreamingQuery = {
|
||||
var stream: DataStreamWriter[Row] = null
|
||||
val checkpointDir = Utils.createTempDir()
|
||||
var df = input.toDF()
|
||||
if (withSelectExpr.length > 0) {
|
||||
df = df.selectExpr(withSelectExpr: _*)
|
||||
}
|
||||
stream = df.writeStream
|
||||
.format("kafka")
|
||||
.option("checkpointLocation", checkpointDir.getCanonicalPath)
|
||||
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
|
||||
// We need to reduce blocking time to efficiently test non-existent partition behavior.
|
||||
.option("kafka.max.block.ms", "1000")
|
||||
.trigger(Trigger.Continuous(1000))
|
||||
.queryName("kafkaStream")
|
||||
withTopic.foreach(stream.option("topic", _))
|
||||
withOutputMode.foreach(stream.outputMode(_))
|
||||
withOptions.foreach(opt => stream.option(opt._1, opt._2))
|
||||
stream.start()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,96 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.kafka010
|
||||
|
||||
import java.util.Properties
|
||||
import java.util.concurrent.atomic.AtomicInteger
|
||||
|
||||
import org.scalatest.time.SpanSugar._
|
||||
import scala.collection.mutable
|
||||
import scala.util.Random
|
||||
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row}
|
||||
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
|
||||
import org.apache.spark.sql.execution.streaming.StreamExecution
|
||||
import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
|
||||
import org.apache.spark.sql.streaming.{StreamTest, Trigger}
|
||||
import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession}
|
||||
|
||||
// Run tests in KafkaSourceSuiteBase in continuous execution mode.
|
||||
class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest
|
||||
|
||||
class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest {
|
||||
import testImplicits._
|
||||
|
||||
override val brokerProps = Map("auto.create.topics.enable" -> "false")
|
||||
|
||||
test("subscribing topic by pattern with topic deletions") {
|
||||
val topicPrefix = newTopic()
|
||||
val topic = topicPrefix + "-seems"
|
||||
val topic2 = topicPrefix + "-bad"
|
||||
testUtils.createTopic(topic, partitions = 5)
|
||||
testUtils.sendMessages(topic, Array("-1"))
|
||||
require(testUtils.getLatestOffsets(Set(topic)).size === 5)
|
||||
|
||||
val reader = spark
|
||||
.readStream
|
||||
.format("kafka")
|
||||
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
|
||||
.option("kafka.metadata.max.age.ms", "1")
|
||||
.option("subscribePattern", s"$topicPrefix-.*")
|
||||
.option("failOnDataLoss", "false")
|
||||
|
||||
val kafka = reader.load()
|
||||
.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
|
||||
.as[(String, String)]
|
||||
val mapped = kafka.map(kv => kv._2.toInt + 1)
|
||||
|
||||
testStream(mapped)(
|
||||
makeSureGetOffsetCalled,
|
||||
AddKafkaData(Set(topic), 1, 2, 3),
|
||||
CheckAnswer(2, 3, 4),
|
||||
Execute { query =>
|
||||
testUtils.deleteTopic(topic)
|
||||
testUtils.createTopic(topic2, partitions = 5)
|
||||
eventually(timeout(streamingTimeout)) {
|
||||
assert(
|
||||
query.lastExecution.logical.collectFirst {
|
||||
case DataSourceV2Relation(_, r: KafkaContinuousReader) => r
|
||||
}.exists { r =>
|
||||
// Ensure the new topic is present and the old topic is gone.
|
||||
r.knownPartitions.exists(_.topic == topic2)
|
||||
},
|
||||
s"query never reconfigured to new topic $topic2")
|
||||
}
|
||||
},
|
||||
AddKafkaData(Set(topic2), 4, 5, 6),
|
||||
CheckAnswer(2, 3, 4, 5, 6, 7)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
class KafkaContinuousSourceStressForDontFailOnDataLossSuite
|
||||
extends KafkaSourceStressForDontFailOnDataLossSuite {
|
||||
override protected def startStream(ds: Dataset[Int]) = {
|
||||
ds.writeStream
|
||||
.format("memory")
|
||||
.queryName("memory")
|
||||
.start()
|
||||
}
|
||||
}
|
94
external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala
vendored
Normal file
94
external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala
vendored
Normal file
|
@ -0,0 +1,94 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.kafka010

import java.util.concurrent.atomic.AtomicInteger

import org.apache.spark.SparkContext
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd, SparkListenerTaskStart}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.execution.streaming.StreamExecution
import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
import org.apache.spark.sql.streaming.Trigger
import org.apache.spark.sql.test.TestSparkSession

// Trait to configure StreamTest for kafka continuous execution tests.
trait KafkaContinuousTest extends KafkaSourceTest {
  override val defaultTrigger = Trigger.Continuous(1000)
  override val defaultUseV2Sink = true

  // We need more than the default local[2] to be able to schedule all partitions simultaneously.
  override protected def createSparkSession = new TestSparkSession(
    new SparkContext(
      "local[10]",
      "continuous-stream-test-sql-context",
      sparkConf.set("spark.sql.testkey", "true")))

  // In addition to setting the partitions in Kafka, we have to wait until the query has
  // reconfigured to the new count so the test framework can hook in properly.
  override protected def setTopicPartitions(
      topic: String, newCount: Int, query: StreamExecution) = {
    testUtils.addPartitions(topic, newCount)
    eventually(timeout(streamingTimeout)) {
      assert(
        query.lastExecution.logical.collectFirst {
          case DataSourceV2Relation(_, r: KafkaContinuousReader) => r
        }.exists(_.knownPartitions.size == newCount),
        s"query never reconfigured to $newCount partitions")
    }
  }

  // Continuous processing tasks end asynchronously, so test that they actually end.
  private val tasksEndedListener = new SparkListener() {
    val activeTaskIdCount = new AtomicInteger(0)

    override def onTaskStart(start: SparkListenerTaskStart): Unit = {
      activeTaskIdCount.incrementAndGet()
    }

    override def onTaskEnd(end: SparkListenerTaskEnd): Unit = {
      activeTaskIdCount.decrementAndGet()
    }
  }

  override def beforeEach(): Unit = {
    super.beforeEach()
    spark.sparkContext.addSparkListener(tasksEndedListener)
  }

  override def afterEach(): Unit = {
    eventually(timeout(streamingTimeout)) {
      assert(tasksEndedListener.activeTaskIdCount.get() == 0)
    }
    spark.sparkContext.removeSparkListener(tasksEndedListener)
    super.afterEach()
  }

  test("ensure continuous stream is being used") {
    val query = spark.readStream
      .format("rate")
      .option("numPartitions", "1")
      .option("rowsPerSecond", "1")
      .load()

    testStream(query)(
      Execute(q => assert(q.isInstanceOf[ContinuousExecution]))
    )
  }
}
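The trait above only flips the shared test harness into continuous mode (continuous trigger, v2 memory sink, a larger local cluster). Outside the tests, the user-facing equivalent is starting a Kafka query with a continuous trigger; a minimal sketch, not part of this patch, with the broker address, topic, and checkpoint path as placeholders and assuming the console sink's continuous write support:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.Trigger

// Placeholder values throughout; sketch only.
val session = SparkSession.builder()
  .master("local[4]")
  .appName("continuous-kafka-sketch")
  .getOrCreate()

val stream = session.readStream
  .format("kafka")
  .option("kafka.bootstrap.servers", "localhost:9092")
  .option("subscribe", "input-topic")
  .load()
  .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")

// Trigger.Continuous replaces micro-batches with long-running epochs checkpointed every second.
val query = stream.writeStream
  .format("console")
  .option("checkpointLocation", "/tmp/continuous-kafka-sketch")
  .trigger(Trigger.Continuous("1 second"))
  .start()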
|
@ -34,11 +34,14 @@ import org.scalatest.concurrent.PatienceConfiguration.Timeout
|
|||
import org.scalatest.time.SpanSugar._
|
||||
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.sql.ForeachWriter
|
||||
import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row}
|
||||
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2Exec}
|
||||
import org.apache.spark.sql.execution.streaming._
|
||||
import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
|
||||
import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryWriter
|
||||
import org.apache.spark.sql.functions.{count, window}
|
||||
import org.apache.spark.sql.kafka010.KafkaSourceProvider._
|
||||
import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest}
|
||||
import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest, Trigger}
|
||||
import org.apache.spark.sql.streaming.util.StreamManualClock
|
||||
import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession}
|
||||
import org.apache.spark.util.Utils
|
||||
|
@ -49,9 +52,11 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext {
|
|||
|
||||
override val streamingTimeout = 30.seconds
|
||||
|
||||
protected val brokerProps = Map[String, Object]()
|
||||
|
||||
override def beforeAll(): Unit = {
|
||||
super.beforeAll()
|
||||
testUtils = new KafkaTestUtils
|
||||
testUtils = new KafkaTestUtils(brokerProps)
|
||||
testUtils.setup()
|
||||
}
|
||||
|
||||
|
@ -59,18 +64,25 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext {
|
|||
if (testUtils != null) {
|
||||
testUtils.teardown()
|
||||
testUtils = null
|
||||
super.afterAll()
|
||||
}
|
||||
super.afterAll()
|
||||
}
|
||||
|
||||
protected def makeSureGetOffsetCalled = AssertOnQuery { q =>
|
||||
// Because KafkaSource's initialPartitionOffsets is set lazily, we need to make sure
|
||||
// its "getOffset" is called before pushing any data. Otherwise, because of the race contion,
|
||||
// its "getOffset" is called before pushing any data. Otherwise, because of the race condition,
|
||||
// we don't know which data should be fetched when `startingOffsets` is latest.
|
||||
q.processAllAvailable()
|
||||
q match {
|
||||
case c: ContinuousExecution => c.awaitEpoch(0)
|
||||
case m: MicroBatchExecution => m.processAllAvailable()
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
protected def setTopicPartitions(topic: String, newCount: Int, query: StreamExecution) : Unit = {
|
||||
testUtils.addPartitions(topic, newCount)
|
||||
}
|
||||
|
||||
/**
|
||||
* Add data to Kafka.
|
||||
*
|
||||
|
@ -82,10 +94,11 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext {
|
|||
message: String = "",
|
||||
topicAction: (String, Option[Int]) => Unit = (_, _) => {}) extends AddData {
|
||||
|
||||
override def addData(query: Option[StreamExecution]): (Source, Offset) = {
|
||||
if (query.get.isActive) {
|
||||
override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = {
|
||||
query match {
|
||||
// Make sure no Spark job is running when deleting a topic
|
||||
query.get.processAllAvailable()
|
||||
case Some(m: MicroBatchExecution) => m.processAllAvailable()
|
||||
case _ =>
|
||||
}
|
||||
|
||||
val existingTopics = testUtils.getAllTopicsAndPartitionSize().toMap
|
||||
|
@ -97,16 +110,18 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext {
|
|||
topicAction(existingTopicPartitions._1, Some(existingTopicPartitions._2))
|
||||
}
|
||||
|
||||
// Read all topics again in case some topics are delete.
|
||||
val allTopics = testUtils.getAllTopicsAndPartitionSize().toMap.keys
|
||||
require(
|
||||
query.nonEmpty,
|
||||
"Cannot add data when there is no query for finding the active kafka source")
|
||||
|
||||
val sources = query.get.logicalPlan.collect {
|
||||
case StreamingExecutionRelation(source, _) if source.isInstanceOf[KafkaSource] =>
|
||||
source.asInstanceOf[KafkaSource]
|
||||
}
|
||||
case StreamingExecutionRelation(source: KafkaSource, _) => source
|
||||
} ++ (query.get.lastExecution match {
|
||||
case null => Seq()
|
||||
case e => e.logical.collect {
|
||||
case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader
|
||||
}
|
||||
})
|
||||
if (sources.isEmpty) {
|
||||
throw new Exception(
|
||||
"Could not find Kafka source in the StreamExecution logical plan to add data to")
|
||||
|
@ -137,14 +152,158 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext {
|
|||
override def toString: String =
|
||||
s"AddKafkaData(topics = $topics, data = $data, message = $message)"
|
||||
}
|
||||
|
||||
private val topicId = new AtomicInteger(0)
|
||||
protected def newTopic(): String = s"topic-${topicId.getAndIncrement()}"
|
||||
}
|
||||
|
||||
|
||||
class KafkaSourceSuite extends KafkaSourceTest {
|
||||
class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase {
|
||||
|
||||
import testImplicits._
|
||||
|
||||
private val topicId = new AtomicInteger(0)
|
||||
test("(de)serialization of initial offsets") {
|
||||
val topic = newTopic()
|
||||
testUtils.createTopic(topic, partitions = 5)
|
||||
|
||||
val reader = spark
|
||||
.readStream
|
||||
.format("kafka")
|
||||
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
|
||||
.option("subscribe", topic)
|
||||
|
||||
testStream(reader.load)(
|
||||
makeSureGetOffsetCalled,
|
||||
StopStream,
|
||||
StartStream(),
|
||||
StopStream)
|
||||
}
|
||||
|
||||
test("maxOffsetsPerTrigger") {
|
||||
val topic = newTopic()
|
||||
testUtils.createTopic(topic, partitions = 3)
|
||||
testUtils.sendMessages(topic, (100 to 200).map(_.toString).toArray, Some(0))
|
||||
testUtils.sendMessages(topic, (10 to 20).map(_.toString).toArray, Some(1))
|
||||
testUtils.sendMessages(topic, Array("1"), Some(2))
|
||||
|
||||
val reader = spark
|
||||
.readStream
|
||||
.format("kafka")
|
||||
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
|
||||
.option("kafka.metadata.max.age.ms", "1")
|
||||
.option("maxOffsetsPerTrigger", 10)
|
||||
.option("subscribe", topic)
|
||||
.option("startingOffsets", "earliest")
|
||||
val kafka = reader.load()
|
||||
.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
|
||||
.as[(String, String)]
|
||||
val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt)
|
||||
|
||||
val clock = new StreamManualClock
|
||||
|
||||
val waitUntilBatchProcessed = AssertOnQuery { q =>
|
||||
eventually(Timeout(streamingTimeout)) {
|
||||
if (!q.exception.isDefined) {
|
||||
assert(clock.isStreamWaitingAt(clock.getTimeMillis()))
|
||||
}
|
||||
}
|
||||
if (q.exception.isDefined) {
|
||||
throw q.exception.get
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
testStream(mapped)(
|
||||
StartStream(ProcessingTime(100), clock),
|
||||
waitUntilBatchProcessed,
|
||||
// 1 from smallest, 1 from middle, 8 from biggest
|
||||
CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107),
|
||||
AdvanceManualClock(100),
|
||||
waitUntilBatchProcessed,
|
||||
// smallest now empty, 1 more from middle, 9 more from biggest
|
||||
CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107,
|
||||
11, 108, 109, 110, 111, 112, 113, 114, 115, 116
|
||||
),
|
||||
StopStream,
|
||||
StartStream(ProcessingTime(100), clock),
|
||||
waitUntilBatchProcessed,
|
||||
// smallest now empty, 1 more from middle, 9 more from biggest
|
||||
CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107,
|
||||
11, 108, 109, 110, 111, 112, 113, 114, 115, 116,
|
||||
12, 117, 118, 119, 120, 121, 122, 123, 124, 125
|
||||
),
|
||||
AdvanceManualClock(100),
|
||||
waitUntilBatchProcessed,
|
||||
// smallest now empty, 1 more from middle, 9 more from biggest
|
||||
CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107,
|
||||
11, 108, 109, 110, 111, 112, 113, 114, 115, 116,
|
||||
12, 117, 118, 119, 120, 121, 122, 123, 124, 125,
|
||||
13, 126, 127, 128, 129, 130, 131, 132, 133, 134
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
test("input row metrics") {
|
||||
val topic = newTopic()
|
||||
testUtils.createTopic(topic, partitions = 5)
|
||||
testUtils.sendMessages(topic, Array("-1"))
|
||||
require(testUtils.getLatestOffsets(Set(topic)).size === 5)
|
||||
|
||||
val kafka = spark
|
||||
.readStream
|
||||
.format("kafka")
|
||||
.option("subscribe", topic)
|
||||
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
|
||||
.load()
|
||||
.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
|
||||
.as[(String, String)]
|
||||
|
||||
val mapped = kafka.map(kv => kv._2.toInt + 1)
|
||||
testStream(mapped)(
|
||||
StartStream(trigger = ProcessingTime(1)),
|
||||
makeSureGetOffsetCalled,
|
||||
AddKafkaData(Set(topic), 1, 2, 3),
|
||||
CheckAnswer(2, 3, 4),
|
||||
AssertOnQuery { query =>
|
||||
val recordsRead = query.recentProgress.map(_.numInputRows).sum
|
||||
recordsRead == 3
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
test("subscribing topic by pattern with topic deletions") {
|
||||
val topicPrefix = newTopic()
|
||||
val topic = topicPrefix + "-seems"
|
||||
val topic2 = topicPrefix + "-bad"
|
||||
testUtils.createTopic(topic, partitions = 5)
|
||||
testUtils.sendMessages(topic, Array("-1"))
|
||||
require(testUtils.getLatestOffsets(Set(topic)).size === 5)
|
||||
|
||||
val reader = spark
|
||||
.readStream
|
||||
.format("kafka")
|
||||
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
|
||||
.option("kafka.metadata.max.age.ms", "1")
|
||||
.option("subscribePattern", s"$topicPrefix-.*")
|
||||
.option("failOnDataLoss", "false")
|
||||
|
||||
val kafka = reader.load()
|
||||
.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
|
||||
.as[(String, String)]
|
||||
val mapped = kafka.map(kv => kv._2.toInt + 1)
|
||||
|
||||
testStream(mapped)(
|
||||
makeSureGetOffsetCalled,
|
||||
AddKafkaData(Set(topic), 1, 2, 3),
|
||||
CheckAnswer(2, 3, 4),
|
||||
Assert {
|
||||
testUtils.deleteTopic(topic)
|
||||
testUtils.createTopic(topic2, partitions = 5)
|
||||
true
|
||||
},
|
||||
AddKafkaData(Set(topic2), 4, 5, 6),
|
||||
CheckAnswer(2, 3, 4, 5, 6, 7)
|
||||
)
|
||||
}
|
||||
|
||||
testWithUninterruptibleThread(
|
||||
"deserialization of initial offset with Spark 2.1.0") {
|
||||
|
@ -237,86 +396,94 @@ class KafkaSourceSuite extends KafkaSourceTest {
|
|||
}
|
||||
}
|
||||
|
||||
test("(de)serialization of initial offsets") {
|
||||
test("KafkaSource with watermark") {
|
||||
val now = System.currentTimeMillis()
|
||||
val topic = newTopic()
|
||||
testUtils.createTopic(topic, partitions = 64)
|
||||
testUtils.createTopic(newTopic(), partitions = 1)
|
||||
testUtils.sendMessages(topic, Array(1).map(_.toString))
|
||||
|
||||
val reader = spark
|
||||
val kafka = spark
|
||||
.readStream
|
||||
.format("kafka")
|
||||
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
|
||||
.option("kafka.metadata.max.age.ms", "1")
|
||||
.option("startingOffsets", s"earliest")
|
||||
.option("subscribe", topic)
|
||||
.load()
|
||||
|
||||
testStream(reader.load)(
|
||||
makeSureGetOffsetCalled,
|
||||
StopStream,
|
||||
StartStream(),
|
||||
StopStream)
|
||||
val windowedAggregation = kafka
|
||||
.withWatermark("timestamp", "10 seconds")
|
||||
.groupBy(window($"timestamp", "5 seconds") as 'window)
|
||||
.agg(count("*") as 'count)
|
||||
.select($"window".getField("start") as 'window, $"count")
|
||||
|
||||
val query = windowedAggregation
|
||||
.writeStream
|
||||
.format("memory")
|
||||
.outputMode("complete")
|
||||
.queryName("kafkaWatermark")
|
||||
.start()
|
||||
query.processAllAvailable()
|
||||
val rows = spark.table("kafkaWatermark").collect()
|
||||
assert(rows.length === 1, s"Unexpected results: ${rows.toList}")
|
||||
val row = rows(0)
|
||||
// We cannot check the exact window start time as it depands on the time that messages were
|
||||
// inserted by the producer. So here we just use a low bound to make sure the internal
|
||||
// conversion works.
|
||||
assert(
|
||||
row.getAs[java.sql.Timestamp]("window").getTime >= now - 5 * 1000,
|
||||
s"Unexpected results: $row")
|
||||
assert(row.getAs[Int]("count") === 1, s"Unexpected results: $row")
|
||||
query.stop()
|
||||
}
|
||||
|
||||
test("maxOffsetsPerTrigger") {
|
||||
test("delete a topic when a Spark job is running") {
|
||||
KafkaSourceSuite.collectedData.clear()
|
||||
|
||||
val topic = newTopic()
|
||||
testUtils.createTopic(topic, partitions = 3)
|
||||
testUtils.sendMessages(topic, (100 to 200).map(_.toString).toArray, Some(0))
|
||||
testUtils.sendMessages(topic, (10 to 20).map(_.toString).toArray, Some(1))
|
||||
testUtils.sendMessages(topic, Array("1"), Some(2))
|
||||
testUtils.createTopic(topic, partitions = 1)
|
||||
testUtils.sendMessages(topic, (1 to 10).map(_.toString).toArray)
|
||||
|
||||
val reader = spark
|
||||
.readStream
|
||||
.format("kafka")
|
||||
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
|
||||
.option("kafka.metadata.max.age.ms", "1")
|
||||
.option("maxOffsetsPerTrigger", 10)
|
||||
.option("subscribe", topic)
|
||||
// If a topic is deleted and we try to poll data starting from offset 0,
|
||||
// the Kafka consumer will just block until timeout and return an empty result.
|
||||
// So set the timeout to 1 second to make this test fast.
|
||||
.option("kafkaConsumer.pollTimeoutMs", "1000")
|
||||
.option("startingOffsets", "earliest")
|
||||
.option("failOnDataLoss", "false")
|
||||
val kafka = reader.load()
|
||||
.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
|
||||
.as[(String, String)]
|
||||
val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt)
|
||||
|
||||
val clock = new StreamManualClock
|
||||
|
||||
val waitUntilBatchProcessed = AssertOnQuery { q =>
|
||||
eventually(Timeout(streamingTimeout)) {
|
||||
if (!q.exception.isDefined) {
|
||||
assert(clock.isStreamWaitingAt(clock.getTimeMillis()))
|
||||
}
|
||||
KafkaSourceSuite.globalTestUtils = testUtils
|
||||
// The following ForeachWriter will delete the topic before fetching data from Kafka
|
||||
// in executors.
|
||||
val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] {
|
||||
override def open(partitionId: Long, version: Long): Boolean = {
|
||||
KafkaSourceSuite.globalTestUtils.deleteTopic(topic)
|
||||
true
|
||||
}
|
||||
if (q.exception.isDefined) {
|
||||
throw q.exception.get
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
testStream(mapped)(
|
||||
StartStream(ProcessingTime(100), clock),
|
||||
waitUntilBatchProcessed,
|
||||
// 1 from smallest, 1 from middle, 8 from biggest
|
||||
CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107),
|
||||
AdvanceManualClock(100),
|
||||
waitUntilBatchProcessed,
|
||||
// smallest now empty, 1 more from middle, 9 more from biggest
|
||||
CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107,
|
||||
11, 108, 109, 110, 111, 112, 113, 114, 115, 116
|
||||
),
|
||||
StopStream,
|
||||
StartStream(ProcessingTime(100), clock),
|
||||
waitUntilBatchProcessed,
|
||||
// smallest now empty, 1 more from middle, 9 more from biggest
|
||||
CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107,
|
||||
11, 108, 109, 110, 111, 112, 113, 114, 115, 116,
|
||||
12, 117, 118, 119, 120, 121, 122, 123, 124, 125
|
||||
),
|
||||
AdvanceManualClock(100),
|
||||
waitUntilBatchProcessed,
|
||||
// smallest now empty, 1 more from middle, 9 more from biggest
|
||||
CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107,
|
||||
11, 108, 109, 110, 111, 112, 113, 114, 115, 116,
|
||||
12, 117, 118, 119, 120, 121, 122, 123, 124, 125,
|
||||
13, 126, 127, 128, 129, 130, 131, 132, 133, 134
|
||||
)
|
||||
)
|
||||
override def process(value: Int): Unit = {
|
||||
KafkaSourceSuite.collectedData.add(value)
|
||||
}
|
||||
|
||||
override def close(errorOrNull: Throwable): Unit = {}
|
||||
}).start()
|
||||
query.processAllAvailable()
|
||||
query.stop()
|
||||
// `failOnDataLoss` is `false`, we should not fail the query
|
||||
assert(query.exception.isEmpty)
|
||||
}
|
||||
}
|
||||
|
||||
class KafkaSourceSuiteBase extends KafkaSourceTest {
|
||||
|
||||
import testImplicits._
|
||||
|
||||
test("SPARK-22956: currentPartitionOffsets should be set when no new data comes in") {
|
||||
def getSpecificDF(range: Range.Inclusive): org.apache.spark.sql.Dataset[Int] = {
|
||||
|
@ -393,7 +560,7 @@ class KafkaSourceSuite extends KafkaSourceTest {
|
|||
.format("kafka")
|
||||
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
|
||||
.option("kafka.metadata.max.age.ms", "1")
|
||||
.option("subscribePattern", s"topic-.*")
|
||||
.option("subscribePattern", s"$topic.*")
|
||||
|
||||
val kafka = reader.load()
|
||||
.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
|
||||
|
@ -487,65 +654,6 @@ class KafkaSourceSuite extends KafkaSourceTest {
|
|||
}
|
||||
}
|
||||
|
||||
test("subscribing topic by pattern with topic deletions") {
|
||||
val topicPrefix = newTopic()
|
||||
val topic = topicPrefix + "-seems"
|
||||
val topic2 = topicPrefix + "-bad"
|
||||
testUtils.createTopic(topic, partitions = 5)
|
||||
testUtils.sendMessages(topic, Array("-1"))
|
||||
require(testUtils.getLatestOffsets(Set(topic)).size === 5)
|
||||
|
||||
val reader = spark
|
||||
.readStream
|
||||
.format("kafka")
|
||||
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
|
||||
.option("kafka.metadata.max.age.ms", "1")
|
||||
.option("subscribePattern", s"$topicPrefix-.*")
|
||||
.option("failOnDataLoss", "false")
|
||||
|
||||
val kafka = reader.load()
|
||||
.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
|
||||
.as[(String, String)]
|
||||
val mapped = kafka.map(kv => kv._2.toInt + 1)
|
||||
|
||||
testStream(mapped)(
|
||||
makeSureGetOffsetCalled,
|
||||
AddKafkaData(Set(topic), 1, 2, 3),
|
||||
CheckAnswer(2, 3, 4),
|
||||
Assert {
|
||||
testUtils.deleteTopic(topic)
|
||||
testUtils.createTopic(topic2, partitions = 5)
|
||||
true
|
||||
},
|
||||
AddKafkaData(Set(topic2), 4, 5, 6),
|
||||
CheckAnswer(2, 3, 4, 5, 6, 7)
|
||||
)
|
||||
}
|
||||
|
||||
test("starting offset is latest by default") {
|
||||
val topic = newTopic()
|
||||
testUtils.createTopic(topic, partitions = 5)
|
||||
testUtils.sendMessages(topic, Array("0"))
|
||||
require(testUtils.getLatestOffsets(Set(topic)).size === 5)
|
||||
|
||||
val reader = spark
|
||||
.readStream
|
||||
.format("kafka")
|
||||
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
|
||||
.option("subscribe", topic)
|
||||
|
||||
val kafka = reader.load()
|
||||
.selectExpr("CAST(value AS STRING)")
|
||||
.as[String]
|
||||
val mapped = kafka.map(_.toInt)
|
||||
|
||||
testStream(mapped)(
|
||||
makeSureGetOffsetCalled,
|
||||
AddKafkaData(Set(topic), 1, 2, 3),
|
||||
CheckAnswer(1, 2, 3) // should not have 0
|
||||
)
|
||||
}
|
||||
|
||||
test("bad source options") {
|
||||
def testBadOptions(options: (String, String)*)(expectedMsgs: String*): Unit = {
|
||||
val ex = intercept[IllegalArgumentException] {
|
||||
|
@ -605,77 +713,6 @@ class KafkaSourceSuite extends KafkaSourceTest {
|
|||
testUnsupportedConfig("kafka.auto.offset.reset", "latest")
|
||||
}
|
||||
|
||||
test("input row metrics") {
|
||||
val topic = newTopic()
|
||||
testUtils.createTopic(topic, partitions = 5)
|
||||
testUtils.sendMessages(topic, Array("-1"))
|
||||
require(testUtils.getLatestOffsets(Set(topic)).size === 5)
|
||||
|
||||
val kafka = spark
|
||||
.readStream
|
||||
.format("kafka")
|
||||
.option("subscribe", topic)
|
||||
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
|
||||
.load()
|
||||
.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
|
||||
.as[(String, String)]
|
||||
|
||||
val mapped = kafka.map(kv => kv._2.toInt + 1)
|
||||
testStream(mapped)(
|
||||
StartStream(trigger = ProcessingTime(1)),
|
||||
makeSureGetOffsetCalled,
|
||||
AddKafkaData(Set(topic), 1, 2, 3),
|
||||
CheckAnswer(2, 3, 4),
|
||||
AssertOnQuery { query =>
|
||||
val recordsRead = query.recentProgress.map(_.numInputRows).sum
|
||||
recordsRead == 3
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
test("delete a topic when a Spark job is running") {
|
||||
KafkaSourceSuite.collectedData.clear()
|
||||
|
||||
val topic = newTopic()
|
||||
testUtils.createTopic(topic, partitions = 1)
|
||||
testUtils.sendMessages(topic, (1 to 10).map(_.toString).toArray)
|
||||
|
||||
val reader = spark
|
||||
.readStream
|
||||
.format("kafka")
|
||||
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
|
||||
.option("kafka.metadata.max.age.ms", "1")
|
||||
.option("subscribe", topic)
|
||||
// If a topic is deleted and we try to poll data starting from offset 0,
|
||||
// the Kafka consumer will just block until timeout and return an empty result.
|
||||
// So set the timeout to 1 second to make this test fast.
|
||||
.option("kafkaConsumer.pollTimeoutMs", "1000")
|
||||
.option("startingOffsets", "earliest")
|
||||
.option("failOnDataLoss", "false")
|
||||
val kafka = reader.load()
|
||||
.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
|
||||
.as[(String, String)]
|
||||
KafkaSourceSuite.globalTestUtils = testUtils
|
||||
// The following ForeachWriter will delete the topic before fetching data from Kafka
|
||||
// in executors.
|
||||
val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] {
|
||||
override def open(partitionId: Long, version: Long): Boolean = {
|
||||
KafkaSourceSuite.globalTestUtils.deleteTopic(topic)
|
||||
true
|
||||
}
|
||||
|
||||
override def process(value: Int): Unit = {
|
||||
KafkaSourceSuite.collectedData.add(value)
|
||||
}
|
||||
|
||||
override def close(errorOrNull: Throwable): Unit = {}
|
||||
}).start()
|
||||
query.processAllAvailable()
|
||||
query.stop()
|
||||
// `failOnDataLoss` is `false`, we should not fail the query
|
||||
assert(query.exception.isEmpty)
|
||||
}
|
||||
|
||||
test("get offsets from case insensitive parameters") {
|
||||
for ((optionKey, optionValue, answer) <- Seq(
|
||||
(STARTING_OFFSETS_OPTION_KEY, "earLiEst", EarliestOffsetRangeLimit),
|
||||
|
@ -694,8 +731,6 @@ class KafkaSourceSuite extends KafkaSourceTest {
|
|||
}
|
||||
}
|
||||
|
||||
private def newTopic(): String = s"topic-${topicId.getAndIncrement()}"
|
||||
|
||||
private def assignString(topic: String, partitions: Iterable[Int]): String = {
|
||||
JsonUtils.partitions(partitions.map(p => new TopicPartition(topic, p)))
|
||||
}
|
||||
|
@ -741,6 +776,10 @@ class KafkaSourceSuite extends KafkaSourceTest {
|
|||
|
||||
testStream(mapped)(
|
||||
makeSureGetOffsetCalled,
|
||||
Execute { q =>
|
||||
// wait to reach the last offset in every partition
|
||||
q.awaitOffset(0, KafkaSourceOffset(partitionOffsets.mapValues(_ => 3L)))
|
||||
},
|
||||
CheckAnswer(-20, -21, -22, 0, 1, 2, 11, 12, 22),
|
||||
StopStream,
|
||||
StartStream(),
|
||||
|
@ -771,10 +810,13 @@ class KafkaSourceSuite extends KafkaSourceTest {
|
|||
.format("memory")
|
||||
.outputMode("append")
|
||||
.queryName("kafkaColumnTypes")
|
||||
.trigger(defaultTrigger)
|
||||
.start()
|
||||
query.processAllAvailable()
|
||||
val rows = spark.table("kafkaColumnTypes").collect()
|
||||
assert(rows.length === 1, s"Unexpected results: ${rows.toList}")
|
||||
var rows: Array[Row] = Array()
|
||||
eventually(timeout(streamingTimeout)) {
|
||||
rows = spark.table("kafkaColumnTypes").collect()
|
||||
assert(rows.length === 1, s"Unexpected results: ${rows.toList}")
|
||||
}
|
||||
val row = rows(0)
|
||||
assert(row.getAs[Array[Byte]]("key") === null, s"Unexpected results: $row")
|
||||
assert(row.getAs[Array[Byte]]("value") === "1".getBytes(UTF_8), s"Unexpected results: $row")
|
||||
|
@ -788,47 +830,6 @@ class KafkaSourceSuite extends KafkaSourceTest {
|
|||
query.stop()
|
||||
}
|
||||
|
||||
test("KafkaSource with watermark") {
|
||||
val now = System.currentTimeMillis()
|
||||
val topic = newTopic()
|
||||
testUtils.createTopic(newTopic(), partitions = 1)
|
||||
testUtils.sendMessages(topic, Array(1).map(_.toString))
|
||||
|
||||
val kafka = spark
|
||||
.readStream
|
||||
.format("kafka")
|
||||
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
|
||||
.option("kafka.metadata.max.age.ms", "1")
|
||||
.option("startingOffsets", s"earliest")
|
||||
.option("subscribe", topic)
|
||||
.load()
|
||||
|
||||
val windowedAggregation = kafka
|
||||
.withWatermark("timestamp", "10 seconds")
|
||||
.groupBy(window($"timestamp", "5 seconds") as 'window)
|
||||
.agg(count("*") as 'count)
|
||||
.select($"window".getField("start") as 'window, $"count")
|
||||
|
||||
val query = windowedAggregation
|
||||
.writeStream
|
||||
.format("memory")
|
||||
.outputMode("complete")
|
||||
.queryName("kafkaWatermark")
|
||||
.start()
|
||||
query.processAllAvailable()
|
||||
val rows = spark.table("kafkaWatermark").collect()
|
||||
assert(rows.length === 1, s"Unexpected results: ${rows.toList}")
|
||||
val row = rows(0)
|
||||
// We cannot check the exact window start time as it depands on the time that messages were
|
||||
// inserted by the producer. So here we just use a low bound to make sure the internal
|
||||
// conversion works.
|
||||
assert(
|
||||
row.getAs[java.sql.Timestamp]("window").getTime >= now - 5 * 1000,
|
||||
s"Unexpected results: $row")
|
||||
assert(row.getAs[Int]("count") === 1, s"Unexpected results: $row")
|
||||
query.stop()
|
||||
}
|
||||
|
||||
private def testFromLatestOffsets(
|
||||
topic: String,
|
||||
addPartitions: Boolean,
|
||||
|
@ -865,9 +866,7 @@ class KafkaSourceSuite extends KafkaSourceTest {
|
|||
AddKafkaData(Set(topic), 7, 8),
|
||||
CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9),
|
||||
AssertOnQuery("Add partitions") { query: StreamExecution =>
|
||||
if (addPartitions) {
|
||||
testUtils.addPartitions(topic, 10)
|
||||
}
|
||||
if (addPartitions) setTopicPartitions(topic, 10, query)
|
||||
true
|
||||
},
|
||||
AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16),
|
||||
|
@ -908,9 +907,7 @@ class KafkaSourceSuite extends KafkaSourceTest {
|
|||
StartStream(),
|
||||
CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9),
|
||||
AssertOnQuery("Add partitions") { query: StreamExecution =>
|
||||
if (addPartitions) {
|
||||
testUtils.addPartitions(topic, 10)
|
||||
}
|
||||
if (addPartitions) setTopicPartitions(topic, 10, query)
|
||||
true
|
||||
},
|
||||
AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16),
|
||||
|
@ -1042,20 +1039,8 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared
|
|||
}
|
||||
}
|
||||
|
||||
test("stress test for failOnDataLoss=false") {
|
||||
val reader = spark
|
||||
.readStream
|
||||
.format("kafka")
|
||||
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
|
||||
.option("kafka.metadata.max.age.ms", "1")
|
||||
.option("subscribePattern", "failOnDataLoss.*")
|
||||
.option("startingOffsets", "earliest")
|
||||
.option("failOnDataLoss", "false")
|
||||
.option("fetchOffset.retryIntervalMs", "3000")
|
||||
val kafka = reader.load()
|
||||
.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
|
||||
.as[(String, String)]
|
||||
val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] {
|
||||
protected def startStream(ds: Dataset[Int]) = {
|
||||
ds.writeStream.foreach(new ForeachWriter[Int] {
|
||||
|
||||
override def open(partitionId: Long, version: Long): Boolean = {
|
||||
true
|
||||
|
@ -1069,6 +1054,22 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared
|
|||
override def close(errorOrNull: Throwable): Unit = {
|
||||
}
|
||||
}).start()
|
||||
}
|
||||
|
||||
test("stress test for failOnDataLoss=false") {
|
||||
val reader = spark
|
||||
.readStream
|
||||
.format("kafka")
|
||||
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
|
||||
.option("kafka.metadata.max.age.ms", "1")
|
||||
.option("subscribePattern", "failOnDataLoss.*")
|
||||
.option("startingOffsets", "earliest")
|
||||
.option("failOnDataLoss", "false")
|
||||
.option("fetchOffset.retryIntervalMs", "3000")
|
||||
val kafka = reader.load()
|
||||
.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
|
||||
.as[(String, String)]
|
||||
val query = startStream(kafka.map(kv => kv._2.toInt))
|
||||
|
||||
val testTime = 1.minutes
|
||||
val startTime = System.currentTimeMillis()
|
||||
|
|
|
@ -191,6 +191,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
|
|||
ds = ds.asInstanceOf[DataSourceV2],
|
||||
conf = sparkSession.sessionState.conf)).asJava)
|
||||
|
||||
// Streaming also uses the data source V2 API. So it may be that the data source implements
|
||||
// v2, but has no v2 implementation for batch reads. In that case, we fall back to loading
|
||||
// the dataframe as a v1 source.
|
||||
val reader = (ds, userSpecifiedSchema) match {
|
||||
case (ds: ReadSupportWithSchema, Some(schema)) =>
|
||||
ds.createReader(schema, options)
|
||||
|
@ -208,23 +211,30 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
|
|||
}
|
||||
reader
|
||||
|
||||
case _ =>
|
||||
throw new AnalysisException(s"$cls does not support data reading.")
|
||||
case _ => null // fall back to v1
|
||||
}
|
||||
|
||||
Dataset.ofRows(sparkSession, DataSourceV2Relation(reader))
|
||||
if (reader == null) {
|
||||
loadV1Source(paths: _*)
|
||||
} else {
|
||||
Dataset.ofRows(sparkSession, DataSourceV2Relation(reader))
|
||||
}
|
||||
} else {
|
||||
// Code path for data source v1.
|
||||
sparkSession.baseRelationToDataFrame(
|
||||
DataSource.apply(
|
||||
sparkSession,
|
||||
paths = paths,
|
||||
userSpecifiedSchema = userSpecifiedSchema,
|
||||
className = source,
|
||||
options = extraOptions.toMap).resolveRelation())
|
||||
loadV1Source(paths: _*)
|
||||
}
|
||||
}
|
||||
|
||||
private def loadV1Source(paths: String*) = {
|
||||
// Code path for data source v1.
|
||||
sparkSession.baseRelationToDataFrame(
|
||||
DataSource.apply(
|
||||
sparkSession,
|
||||
paths = paths,
|
||||
userSpecifiedSchema = userSpecifiedSchema,
|
||||
className = source,
|
||||
options = extraOptions.toMap).resolveRelation())
|
||||
}
|
||||
|
||||
/**
|
||||
* Construct a `DataFrame` representing the database table accessible via JDBC URL
|
||||
* url named table and connection properties.
|
||||
|
|
|
@ -255,17 +255,24 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
|
|||
}
|
||||
}
|
||||
|
||||
case _ => throw new AnalysisException(s"$cls does not support data writing.")
|
||||
// Streaming also uses the data source V2 API. So it may be that the data source implements
|
||||
// v2, but has no v2 implementation for batch writes. In that case, we fall back to saving
|
||||
// as though it's a V1 source.
|
||||
case _ => saveToV1Source()
|
||||
}
|
||||
} else {
|
||||
// Code path for data source v1.
|
||||
runCommand(df.sparkSession, "save") {
|
||||
DataSource(
|
||||
sparkSession = df.sparkSession,
|
||||
className = source,
|
||||
partitionColumns = partitioningColumns.getOrElse(Nil),
|
||||
options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan))
|
||||
}
|
||||
saveToV1Source()
|
||||
}
|
||||
}
|
||||
|
||||
private def saveToV1Source(): Unit = {
|
||||
// Code path for data source v1.
|
||||
runCommand(df.sparkSession, "save") {
|
||||
DataSource(
|
||||
sparkSession = df.sparkSession,
|
||||
className = source,
|
||||
partitionColumns = partitioningColumns.getOrElse(Nil),
|
||||
options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@@ -81,9 +81,11 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan)
        (index, message: WriterCommitMessage) => messages(index) = message
      )

      logInfo(s"Data source writer $writer is committing.")
      writer.commit(messages)
      logInfo(s"Data source writer $writer committed.")
      if (!writer.isInstanceOf[ContinuousWriter]) {
        logInfo(s"Data source writer $writer is committing.")
        writer.commit(messages)
        logInfo(s"Data source writer $writer committed.")
      }
    } catch {
      case _: InterruptedException if writer.isInstanceOf[ContinuousWriter] =>
        // Interruption is how continuous queries are ended, so accept and ignore the exception.
|
|
|
@ -142,7 +142,8 @@ abstract class StreamExecution(
|
|||
|
||||
override val id: UUID = UUID.fromString(streamMetadata.id)
|
||||
|
||||
override val runId: UUID = UUID.randomUUID
|
||||
override def runId: UUID = currentRunId
|
||||
protected var currentRunId = UUID.randomUUID
|
||||
|
||||
/**
|
||||
* Pretty identified string of printing in logs. Format is
|
||||
|
@ -418,11 +419,17 @@ abstract class StreamExecution(
|
|||
* Blocks the current thread until processing for data from the given `source` has reached at
|
||||
* least the given `Offset`. This method is intended for use primarily when writing tests.
|
||||
*/
|
||||
private[sql] def awaitOffset(source: BaseStreamingSource, newOffset: Offset): Unit = {
|
||||
private[sql] def awaitOffset(sourceIndex: Int, newOffset: Offset): Unit = {
|
||||
assertAwaitThread()
|
||||
def notDone = {
|
||||
val localCommittedOffsets = committedOffsets
|
||||
!localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset
|
||||
if (sources == null) {
|
||||
// sources might not be initialized yet
|
||||
false
|
||||
} else {
|
||||
val source = sources(sourceIndex)
|
||||
!localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset
|
||||
}
|
||||
}
|
||||
|
||||
while (notDone) {
|
||||
|
@ -436,7 +443,7 @@ abstract class StreamExecution(
|
|||
awaitProgressLock.unlock()
|
||||
}
|
||||
}
|
||||
logDebug(s"Unblocked at $newOffset for $source")
|
||||
logDebug(s"Unblocked at $newOffset for ${sources(sourceIndex)}")
|
||||
}
|
||||
|
||||
/** A flag to indicate that a batch has completed with no new data available. */
|
||||
|
|
|
@@ -77,7 +77,6 @@ class ContinuousDataSourceRDD(
    dataReaderThread.start()

    context.addTaskCompletionListener(_ => {
      reader.close()
      dataReaderThread.interrupt()
      epochPollExecutor.shutdown()
    })

@@ -177,6 +176,7 @@ class DataReaderThread(
  private[continuous] var failureReason: Throwable = _

  override def run(): Unit = {
    TaskContext.setTaskContext(context)
    val baseReader = ContinuousDataSourceRDD.getBaseReader(reader)
    try {
      while (!context.isInterrupted && !context.isCompleted()) {

@@ -201,6 +201,8 @@ class DataReaderThread(
        failedFlag.set(true)
        // Don't rethrow the exception in this thread. It's not needed, and the default Spark
        // exception handler will kill the executor.
    } finally {
      reader.close()
    }
  }
}
|
|
|
@ -17,7 +17,9 @@
|
|||
|
||||
package org.apache.spark.sql.execution.streaming.continuous
|
||||
|
||||
import java.util.UUID
|
||||
import java.util.concurrent.TimeUnit
|
||||
import java.util.function.UnaryOperator
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.mutable.{ArrayBuffer, Map => MutableMap}
|
||||
|
@ -52,7 +54,7 @@ class ContinuousExecution(
|
|||
sparkSession, name, checkpointRoot, analyzedPlan, sink,
|
||||
trigger, triggerClock, outputMode, deleteCheckpointOnStop) {
|
||||
|
||||
@volatile protected var continuousSources: Seq[ContinuousReader] = Seq.empty
|
||||
@volatile protected var continuousSources: Seq[ContinuousReader] = _
|
||||
override protected def sources: Seq[BaseStreamingSource] = continuousSources
|
||||
|
||||
override lazy val logicalPlan: LogicalPlan = {
|
||||
|
@ -78,15 +80,17 @@ class ContinuousExecution(
|
|||
}
|
||||
|
||||
override protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit = {
|
||||
do {
|
||||
try {
|
||||
runContinuous(sparkSessionForStream)
|
||||
} catch {
|
||||
case _: InterruptedException if state.get().equals(RECONFIGURING) =>
|
||||
// swallow exception and run again
|
||||
state.set(ACTIVE)
|
||||
val stateUpdate = new UnaryOperator[State] {
|
||||
override def apply(s: State) = s match {
|
||||
// If we ended the query to reconfigure, reset the state to active.
|
||||
case RECONFIGURING => ACTIVE
|
||||
case _ => s
|
||||
}
|
||||
} while (state.get() == ACTIVE)
|
||||
}
|
||||
|
||||
do {
|
||||
runContinuous(sparkSessionForStream)
|
||||
} while (state.updateAndGet(stateUpdate) == ACTIVE)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -120,12 +124,16 @@ class ContinuousExecution(
|
|||
}
|
||||
committedOffsets = nextOffsets.toStreamProgress(sources)
|
||||
|
||||
// Forcibly align commit and offset logs by slicing off any spurious offset logs from
|
||||
// a previous run. We can't allow commits to an epoch that a previous run reached but
|
||||
// this run has not.
|
||||
offsetLog.purgeAfter(latestEpochId)
|
||||
// Get to an epoch ID that has definitely never been sent to a sink before. Since sink
|
||||
// commit happens between offset log write and commit log write, this means an epoch ID
|
||||
// which is not in the offset log.
|
||||
val (latestOffsetEpoch, _) = offsetLog.getLatest().getOrElse {
|
||||
throw new IllegalStateException(
|
||||
s"Offset log had no latest element. This shouldn't be possible because nextOffsets is" +
|
||||
s"an element.")
|
||||
}
|
||||
currentBatchId = latestOffsetEpoch + 1
|
||||
|
||||
currentBatchId = latestEpochId + 1
|
||||
logDebug(s"Resuming at epoch $currentBatchId with committed offsets $committedOffsets")
|
||||
nextOffsets
|
||||
case None =>
|
||||
|
@ -141,6 +149,7 @@ class ContinuousExecution(
|
|||
* @param sparkSessionForQuery Isolated [[SparkSession]] to run the continuous query with.
|
||||
*/
|
||||
private def runContinuous(sparkSessionForQuery: SparkSession): Unit = {
|
||||
currentRunId = UUID.randomUUID
|
||||
// A list of attributes that will need to be updated.
|
||||
val replacements = new ArrayBuffer[(Attribute, Attribute)]
|
||||
// Translate from continuous relation to the underlying data source.
|
||||
|
@ -225,13 +234,11 @@ class ContinuousExecution(
|
|||
triggerExecutor.execute(() => {
|
||||
startTrigger()
|
||||
|
||||
if (reader.needsReconfiguration()) {
|
||||
state.set(RECONFIGURING)
|
||||
if (reader.needsReconfiguration() && state.compareAndSet(ACTIVE, RECONFIGURING)) {
|
||||
stopSources()
|
||||
if (queryExecutionThread.isAlive) {
|
||||
sparkSession.sparkContext.cancelJobGroup(runId.toString)
|
||||
queryExecutionThread.interrupt()
|
||||
// No need to join - this thread is about to end anyway.
|
||||
}
|
||||
false
|
||||
} else if (isActive) {
|
||||
|
@ -259,6 +266,7 @@ class ContinuousExecution(
|
|||
sparkSessionForQuery, lastExecution)(lastExecution.toRdd)
|
||||
}
|
||||
} finally {
|
||||
epochEndpoint.askSync[Unit](StopContinuousExecutionWrites)
|
||||
SparkEnv.get.rpcEnv.stop(epochEndpoint)
|
||||
|
||||
epochUpdateThread.interrupt()
|
||||
|
@ -273,17 +281,22 @@ class ContinuousExecution(
|
|||
epoch: Long, reader: ContinuousReader, partitionOffsets: Seq[PartitionOffset]): Unit = {
|
||||
assert(continuousSources.length == 1, "only one continuous source supported currently")
|
||||
|
||||
if (partitionOffsets.contains(null)) {
|
||||
// If any offset is null, that means the corresponding partition hasn't seen any data yet, so
|
||||
// there's nothing meaningful to add to the offset log.
|
||||
}
|
||||
val globalOffset = reader.mergeOffsets(partitionOffsets.toArray)
|
||||
synchronized {
|
||||
if (queryExecutionThread.isAlive) {
|
||||
offsetLog.add(epoch, OffsetSeq.fill(globalOffset))
|
||||
} else {
|
||||
return
|
||||
}
|
||||
val oldOffset = synchronized {
|
||||
offsetLog.add(epoch, OffsetSeq.fill(globalOffset))
|
||||
offsetLog.get(epoch - 1)
|
||||
}
|
||||
|
||||
// If offset hasn't changed since last epoch, there's been no new data.
|
||||
if (oldOffset.contains(OffsetSeq.fill(globalOffset))) {
|
||||
noNewData = true
|
||||
}
|
||||
|
||||
awaitProgressLock.lock()
|
||||
try {
|
||||
awaitProgressLockCondition.signalAll()
|
||||
} finally {
|
||||
awaitProgressLock.unlock()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -39,6 +39,15 @@ private[continuous] sealed trait EpochCoordinatorMessage extends Serializable
|
|||
*/
|
||||
private[sql] case object IncrementAndGetEpoch extends EpochCoordinatorMessage
|
||||
|
||||
/**
|
||||
* The RpcEndpoint stop() will wait to clear out the message queue before terminating the
|
||||
* object. This can lead to a race condition where the query restarts at epoch n, a new
|
||||
* EpochCoordinator starts at epoch n, and then the old epoch coordinator commits epoch n + 1.
|
||||
* The framework doesn't provide a handle to wait on the message queue, so we use a synchronous
|
||||
* message to stop any writes to the ContinuousExecution object.
|
||||
*/
|
||||
private[sql] case object StopContinuousExecutionWrites extends EpochCoordinatorMessage
|
||||
|
||||
// Init messages
|
||||
/**
|
||||
* Set the reader and writer partition counts. Tasks may not be started until the coordinator
|
||||
|
@ -116,6 +125,8 @@ private[continuous] class EpochCoordinator(
|
|||
override val rpcEnv: RpcEnv)
|
||||
extends ThreadSafeRpcEndpoint with Logging {
|
||||
|
||||
private var queryWritesStopped: Boolean = false
|
||||
|
||||
private var numReaderPartitions: Int = _
|
||||
private var numWriterPartitions: Int = _
|
||||
|
||||
|
@ -147,12 +158,16 @@ private[continuous] class EpochCoordinator(
|
|||
partitionCommits.remove(k)
|
||||
}
|
||||
for (k <- partitionOffsets.keys.filter { case (e, _) => e < epoch }) {
|
||||
partitionCommits.remove(k)
|
||||
partitionOffsets.remove(k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def receive: PartialFunction[Any, Unit] = {
|
||||
// If we just drop these messages, we won't do any writes to the query. The lame duck tasks
|
||||
// won't shed errors or anything.
|
||||
case _ if queryWritesStopped => ()
|
||||
|
||||
case CommitPartitionEpoch(partitionId, epoch, message) =>
|
||||
logDebug(s"Got commit from partition $partitionId at epoch $epoch: $message")
|
||||
if (!partitionCommits.isDefinedAt((epoch, partitionId))) {
|
||||
|
@ -188,5 +203,9 @@ private[continuous] class EpochCoordinator(
|
|||
case SetWriterPartitions(numPartitions) =>
|
||||
numWriterPartitions = numPartitions
|
||||
context.reply(())
|
||||
|
||||
case StopContinuousExecutionWrites =>
|
||||
queryWritesStopped = true
|
||||
context.reply(())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource
|
|||
import org.apache.spark.sql.execution.streaming._
|
||||
import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
|
||||
import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2}
|
||||
import org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport
|
||||
|
||||
/**
|
||||
* Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems,
|
||||
|
@ -279,18 +280,29 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
|
|||
useTempCheckpointLocation = true,
|
||||
trigger = trigger)
|
||||
} else {
|
||||
val dataSource =
|
||||
DataSource(
|
||||
df.sparkSession,
|
||||
className = source,
|
||||
options = extraOptions.toMap,
|
||||
partitionColumns = normalizedParCols.getOrElse(Nil))
|
||||
val sink = trigger match {
|
||||
case _: ContinuousTrigger =>
|
||||
val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf)
|
||||
ds.newInstance() match {
|
||||
case w: ContinuousWriteSupport => w
|
||||
case _ => throw new AnalysisException(
|
||||
s"Data source $source does not support continuous writing")
|
||||
}
|
||||
case _ =>
|
||||
val ds = DataSource(
|
||||
df.sparkSession,
|
||||
className = source,
|
||||
options = extraOptions.toMap,
|
||||
partitionColumns = normalizedParCols.getOrElse(Nil))
|
||||
ds.createSink(outputMode)
|
||||
}
|
||||
|
||||
df.sparkSession.sessionState.streamingQueryManager.startQuery(
|
||||
extraOptions.get("queryName"),
|
||||
extraOptions.get("checkpointLocation"),
|
||||
df,
|
||||
extraOptions.toMap,
|
||||
dataSource.createSink(outputMode),
|
||||
sink,
|
||||
outputMode,
|
||||
useTempCheckpointLocation = source == "console",
|
||||
recoverFromCheckpointLocation = true,
|
||||
|
|
|
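A consequence of the sink-selection change above: with a continuous trigger, start() resolves the sink through ContinuousWriteSupport and fails fast for formats that do not implement it. A hedged sketch of that failure mode (the parquet example and paths are assumptions, not taken from the patch):

import org.apache.spark.sql.{AnalysisException, DataFrame}
import org.apache.spark.sql.streaming.Trigger

// Sketch only: a file sink has no continuous write support, so start() should throw.
def tryContinuousParquet(df: DataFrame): Unit = {
  try {
    df.writeStream
      .format("parquet")
      .option("path", "/tmp/out")                 // placeholder
      .option("checkpointLocation", "/tmp/cp")    // placeholder
      .trigger(Trigger.Continuous("1 second"))
      .start()
  } catch {
    case e: AnalysisException =>
      // Expected message shape: "Data source parquet does not support continuous writing"
      println(e.getMessage)
  }
}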
@ -38,8 +38,9 @@ import org.apache.spark.sql.{Dataset, Encoder, QueryTest, Row}
|
|||
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder}
|
||||
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
|
||||
import org.apache.spark.sql.catalyst.util._
|
||||
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
|
||||
import org.apache.spark.sql.execution.streaming._
|
||||
import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch}
|
||||
import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger, EpochCoordinatorRef, IncrementAndGetEpoch}
|
||||
import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2
|
||||
import org.apache.spark.sql.execution.streaming.state.StateStore
|
||||
import org.apache.spark.sql.streaming.StreamingQueryListener._
|
||||
|
@ -80,6 +81,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
|
|||
StateStore.stop() // stop the state store maintenance thread and unload store providers
|
||||
}
|
||||
|
||||
protected val defaultTrigger = Trigger.ProcessingTime(0)
|
||||
protected val defaultUseV2Sink = false
|
||||
|
||||
/** How long to wait for an active stream to catch up when checking a result. */
|
||||
val streamingTimeout = 10.seconds
|
||||
|
||||
|
@ -189,7 +193,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
|
|||
|
||||
/** Starts the stream, resuming if data has already been processed. It must not be running. */
|
||||
case class StartStream(
|
||||
trigger: Trigger = Trigger.ProcessingTime(0),
|
||||
trigger: Trigger = defaultTrigger,
|
||||
triggerClock: Clock = new SystemClock,
|
||||
additionalConfs: Map[String, String] = Map.empty,
|
||||
checkpointLocation: String = null)
|
||||
|
@ -276,7 +280,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
|
|||
def testStream(
|
||||
_stream: Dataset[_],
|
||||
outputMode: OutputMode = OutputMode.Append,
|
||||
useV2Sink: Boolean = false)(actions: StreamAction*): Unit = synchronized {
|
||||
useV2Sink: Boolean = defaultUseV2Sink)(actions: StreamAction*): Unit = synchronized {
|
||||
import org.apache.spark.sql.streaming.util.StreamManualClock
|
||||
|
||||
// `synchronized` is added to prevent the user from calling multiple `testStream`s concurrently
|
||||
|
@ -403,18 +407,11 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
|
|||
|
||||
def fetchStreamAnswer(currentStream: StreamExecution, lastOnly: Boolean) = {
|
||||
verify(currentStream != null, "stream not running")
|
||||
// Get the map of source index to the current source objects
|
||||
val indexToSource = currentStream
|
||||
.logicalPlan
|
||||
.collect { case StreamingExecutionRelation(s, _) => s }
|
||||
.zipWithIndex
|
||||
.map(_.swap)
|
||||
.toMap
|
||||
|
||||
// Block until all data added has been processed for all the source
|
||||
awaiting.foreach { case (sourceIndex, offset) =>
|
||||
failAfter(streamingTimeout) {
|
||||
currentStream.awaitOffset(indexToSource(sourceIndex), offset)
|
||||
currentStream.awaitOffset(sourceIndex, offset)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -473,6 +470,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
|
|||
// after starting the query.
|
||||
try {
|
||||
currentStream.awaitInitialization(streamingTimeout.toMillis)
|
||||
currentStream match {
|
||||
case s: ContinuousExecution => eventually("IncrementalExecution was not created") {
|
||||
s.lastExecution.executedPlan // will fail if lastExecution is null
|
||||
}
|
||||
case _ =>
|
||||
}
|
||||
} catch {
|
||||
case _: StreamingQueryException =>
|
||||
// Ignore the exception. `StopStream` or `ExpectFailure` will catch it as well.
|
||||
|
@ -600,7 +603,10 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
|
|||
|
||||
def findSourceIndex(plan: LogicalPlan): Option[Int] = {
|
||||
plan
|
||||
.collect { case StreamingExecutionRelation(s, _) => s }
|
||||
.collect {
|
||||
case StreamingExecutionRelation(s, _) => s
|
||||
case DataSourceV2Relation(_, r) => r
|
||||
}
|
||||
.zipWithIndex
|
||||
.find(_._1 == source)
|
||||
.map(_._2)
|
||||
|
@ -613,9 +619,13 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
|
|||
findSourceIndex(query.logicalPlan)
|
||||
}.orElse {
|
||||
findSourceIndex(stream.logicalPlan)
|
||||
}.orElse {
|
||||
queryToUse.flatMap { q =>
|
||||
findSourceIndex(q.lastExecution.logical)
|
||||
}
|
||||
}.getOrElse {
|
||||
throw new IllegalArgumentException(
|
||||
"Could find index of the source to which data was added")
|
||||
"Could not find index of the source to which data was added")
|
||||
}
|
||||
|
||||
// Store the expected offset of added data to wait for it later
|
||||
|
|