From edf5cc64e4bfc34643952f3a9582beca20c4bddc Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Thu, 20 Sep 2018 20:22:55 +0800
Subject: [PATCH] [SPARK-25460][SS] DataSourceV2: SS sources do not respect
 SessionConfigSupport

## What changes were proposed in this pull request?

This PR proposes to respect `SessionConfigSupport` in SS datasources as well. Currently, session configurations are only respected in batch sources:

https://github.com/apache/spark/blob/e06da95cd9423f55cdb154a2778b0bddf7be984c/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala#L198-L203

https://github.com/apache/spark/blob/e06da95cd9423f55cdb154a2778b0bddf7be984c/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala#L244-L249

If a developer writes a DataSource V2 implementation that supports both structured streaming and batch jobs, a batch job picks up a session configuration, say a URL used to connect and fetch data (which end users might not even be aware of), but a structured streaming query does not, so the same setting has to be passed explicitly through options.

## How was this patch tested?

Unit tests were added.

Closes #22462 from HyukjinKwon/SPARK-25460.

Authored-by: hyukjinkwon
Signed-off-by: Wenchen Fan
---
 .../sql/streaming/DataStreamReader.scala      |  24 ++--
 .../sql/streaming/DataStreamWriter.scala      |  16 ++-
 .../sources/StreamingDataSourceV2Suite.scala  | 116 +++++++++++++++---
 3 files changed, 128 insertions(+), 28 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
index 2a4db4afbe..4c7dcedafe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -26,6 +26,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession}
 import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.execution.datasources.DataSource
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils
 import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2}
 import org.apache.spark.sql.sources.StreamSourceProvider
 import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, MicroBatchReadSupportProvider}
@@ -158,7 +159,6 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
     }
     val ds = DataSource.lookupDataSource(source, sparkSession.sqlContext.conf).newInstance()
-    val options = new DataSourceOptions(extraOptions.asJava)
     // We need to generate the V1 data source so we can pass it to the V2 relation as a shim.
     // We can't be sure at this point whether we'll actually want to use V2, since we don't know the
     // writer or whether the query is continuous.
@@ -173,13 +173,18 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
     }
     ds match {
       case s: MicroBatchReadSupportProvider =>
+        val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
+          ds = s, conf = sparkSession.sessionState.conf)
+        val options = sessionOptions ++ extraOptions
+        val dataSourceOptions = new DataSourceOptions(options.asJava)
         var tempReadSupport: MicroBatchReadSupport = null
         val schema = try {
           val tmpCheckpointPath = Utils.createTempDir(namePrefix = s"tempCP").getCanonicalPath
           tempReadSupport = if (userSpecifiedSchema.isDefined) {
-            s.createMicroBatchReadSupport(userSpecifiedSchema.get, tmpCheckpointPath, options)
+            s.createMicroBatchReadSupport(
+              userSpecifiedSchema.get, tmpCheckpointPath, dataSourceOptions)
           } else {
-            s.createMicroBatchReadSupport(tmpCheckpointPath, options)
+            s.createMicroBatchReadSupport(tmpCheckpointPath, dataSourceOptions)
           }
           tempReadSupport.fullSchema()
         } finally {
@@ -192,16 +197,21 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
         Dataset.ofRows(
           sparkSession,
           StreamingRelationV2(
-            s, source, extraOptions.toMap,
+            s, source, options,
             schema.toAttributes, v1Relation)(sparkSession))
       case s: ContinuousReadSupportProvider =>
+        val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
+          ds = s, conf = sparkSession.sessionState.conf)
+        val options = sessionOptions ++ extraOptions
+        val dataSourceOptions = new DataSourceOptions(options.asJava)
         var tempReadSupport: ContinuousReadSupport = null
         val schema = try {
           val tmpCheckpointPath = Utils.createTempDir(namePrefix = s"tempCP").getCanonicalPath
           tempReadSupport = if (userSpecifiedSchema.isDefined) {
-            s.createContinuousReadSupport(userSpecifiedSchema.get, tmpCheckpointPath, options)
+            s.createContinuousReadSupport(
+              userSpecifiedSchema.get, tmpCheckpointPath, dataSourceOptions)
           } else {
-            s.createContinuousReadSupport(tmpCheckpointPath, options)
+            s.createContinuousReadSupport(tmpCheckpointPath, dataSourceOptions)
           }
           tempReadSupport.fullSchema()
         } finally {
@@ -214,7 +224,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
         Dataset.ofRows(
           sparkSession,
           StreamingRelationV2(
-            s, source, extraOptions.toMap,
+            s, source, options,
             schema.toAttributes, v1Relation)(sparkSession))
       case _ =>
         // Code path for data source v1.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index 7866e4f70f..e9a15214d9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
 import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.execution.datasources.DataSource
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
 import org.apache.spark.sql.execution.streaming.sources._
@@ -298,23 +299,28 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
     } else {
       val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf)
       val disabledSources = df.sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",")
+      var options = extraOptions.toMap
       val sink = ds.newInstance() match {
         case w: StreamingWriteSupportProvider
-            if !disabledSources.contains(w.getClass.getCanonicalName) => w
+            if !disabledSources.contains(w.getClass.getCanonicalName) =>
+          val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
+            w, df.sparkSession.sessionState.conf)
+          options = sessionOptions ++ extraOptions
+          w
         case _ =>
           val ds = DataSource(
             df.sparkSession,
             className = source,
-            options = extraOptions.toMap,
+            options = options,
             partitionColumns = normalizedParCols.getOrElse(Nil))
           ds.createSink(outputMode)
       }

       df.sparkSession.sessionState.streamingQueryManager.startQuery(
-        extraOptions.get("queryName"),
-        extraOptions.get("checkpointLocation"),
+        options.get("queryName"),
+        options.get("checkpointLocation"),
         df,
-        extraOptions.toMap,
+        options,
         sink,
         outputMode,
         useTempCheckpointLocation = source == "console",
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala
index aeef4c8fe9..3a0e780a73 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.sources.v2._
 import org.apache.spark.sql.sources.v2.reader.{InputPartition, PartitionReaderFactory, ScanConfig, ScanConfigBuilder}
 import org.apache.spark.sql.sources.v2.reader.streaming._
 import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport
-import org.apache.spark.sql.streaming.{OutputMode, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery, StreamTest, Trigger}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils
@@ -56,13 +56,19 @@ case class FakeReadSupport() extends MicroBatchReadSupport with ContinuousReadSu
 trait FakeMicroBatchReadSupportProvider extends MicroBatchReadSupportProvider {
   override def createMicroBatchReadSupport(
       checkpointLocation: String,
-      options: DataSourceOptions): MicroBatchReadSupport = FakeReadSupport()
+      options: DataSourceOptions): MicroBatchReadSupport = {
+    LastReadOptions.options = options
+    FakeReadSupport()
+  }
 }

 trait FakeContinuousReadSupportProvider extends ContinuousReadSupportProvider {
   override def createContinuousReadSupport(
       checkpointLocation: String,
-      options: DataSourceOptions): ContinuousReadSupport = FakeReadSupport()
+      options: DataSourceOptions): ContinuousReadSupport = {
+    LastReadOptions.options = options
+    FakeReadSupport()
+  }
 }

 trait FakeStreamingWriteSupportProvider extends StreamingWriteSupportProvider {
@@ -71,16 +77,27 @@ trait FakeStreamingWriteSupportProvider extends StreamingWriteSupportProvider {
       schema: StructType,
       mode: OutputMode,
       options: DataSourceOptions): StreamingWriteSupport = {
+    LastWriteOptions.options = options
     throw new IllegalStateException("fake sink - cannot actually write")
   }
 }

-class FakeReadMicroBatchOnly extends DataSourceRegister with FakeMicroBatchReadSupportProvider {
+class FakeReadMicroBatchOnly
+    extends DataSourceRegister
+    with FakeMicroBatchReadSupportProvider
+    with SessionConfigSupport {
   override def shortName(): String = "fake-read-microbatch-only"
+
+  override def keyPrefix: String = shortName()
 }

-class FakeReadContinuousOnly extends DataSourceRegister with FakeContinuousReadSupportProvider {
+class FakeReadContinuousOnly
+    extends DataSourceRegister
+    with FakeContinuousReadSupportProvider
+    with SessionConfigSupport {
   override def shortName(): String = "fake-read-continuous-only"
+
+  override def keyPrefix: String = shortName()
 }

 class FakeReadBothModes extends DataSourceRegister
@@ -92,8 +109,13 @@ class FakeReadNeitherMode extends DataSourceRegister {
   override def shortName(): String = "fake-read-neither-mode"
 }

-class FakeWriteSupportProvider extends DataSourceRegister with FakeStreamingWriteSupportProvider {
+class FakeWriteSupportProvider
+    extends DataSourceRegister
+    with FakeStreamingWriteSupportProvider
+    with SessionConfigSupport {
   override def shortName(): String = "fake-write-microbatch-continuous"
+
+  override def keyPrefix: String = shortName()
 }

 class FakeNoWrite extends DataSourceRegister {
@@ -121,6 +143,21 @@ class FakeWriteSupportProviderV1Fallback extends DataSourceRegister
   override def shortName(): String = "fake-write-v1-fallback"
 }

+object LastReadOptions {
+  var options: DataSourceOptions = _
+
+  def clear(): Unit = {
+    options = null
+  }
+}
+
+object LastWriteOptions {
+  var options: DataSourceOptions = _
+
+  def clear(): Unit = {
+    options = null
+  }
+}

 class StreamingDataSourceV2Suite extends StreamTest {
@@ -130,6 +167,11 @@ class StreamingDataSourceV2Suite extends StreamTest {
     spark.conf.set("spark.sql.streaming.checkpointLocation", fakeCheckpoint.getCanonicalPath)
   }

+  override def afterEach(): Unit = {
+    LastReadOptions.clear()
+    LastWriteOptions.clear()
+  }
+
   val readFormats = Seq(
     "fake-read-microbatch-only",
     "fake-read-continuous-only",
@@ -143,7 +185,14 @@ class StreamingDataSourceV2Suite extends StreamTest {
     Trigger.ProcessingTime(1000),
     Trigger.Continuous(1000))

-  private def testPositiveCase(readFormat: String, writeFormat: String, trigger: Trigger) = {
+  private def testPositiveCase(readFormat: String, writeFormat: String, trigger: Trigger): Unit = {
+    testPositiveCaseWithQuery(readFormat, writeFormat, trigger)(() => _)
+  }
+
+  private def testPositiveCaseWithQuery(
+      readFormat: String,
+      writeFormat: String,
+      trigger: Trigger)(check: StreamingQuery => Unit): Unit = {
     val query = spark.readStream
       .format(readFormat)
       .load()
@@ -151,8 +200,8 @@ class StreamingDataSourceV2Suite extends StreamTest {
       .format(writeFormat)
       .trigger(trigger)
       .start()
+    check(query)
     query.stop()
-    query
   }

   private def testNegativeCase(
@@ -188,19 +237,54 @@ class StreamingDataSourceV2Suite extends StreamTest {
   test("disabled v2 write") {
     // Ensure the V2 path works normally and generates a V2 sink..
-    val v2Query = testPositiveCase(
-      "fake-read-microbatch-continuous", "fake-write-v1-fallback", Trigger.Once())
-    assert(v2Query.asInstanceOf[StreamingQueryWrapper].streamingQuery.sink
-      .isInstanceOf[FakeWriteSupportProviderV1Fallback])
+    testPositiveCaseWithQuery(
+      "fake-read-microbatch-continuous", "fake-write-v1-fallback", Trigger.Once()) { v2Query =>
+      assert(v2Query.asInstanceOf[StreamingQueryWrapper].streamingQuery.sink
+        .isInstanceOf[FakeWriteSupportProviderV1Fallback])
+    }

     // Ensure we create a V1 sink with the config. Note the config is a comma separated
     // list, including other fake entries.
     val fullSinkName = classOf[FakeWriteSupportProviderV1Fallback].getName
     withSQLConf(SQLConf.DISABLED_V2_STREAMING_WRITERS.key -> s"a,b,c,test,$fullSinkName,d,e") {
-      val v1Query = testPositiveCase(
-        "fake-read-microbatch-continuous", "fake-write-v1-fallback", Trigger.Once())
-      assert(v1Query.asInstanceOf[StreamingQueryWrapper].streamingQuery.sink
-        .isInstanceOf[FakeSink])
+      testPositiveCaseWithQuery(
+        "fake-read-microbatch-continuous", "fake-write-v1-fallback", Trigger.Once()) { v1Query =>
+        assert(v1Query.asInstanceOf[StreamingQueryWrapper].streamingQuery.sink
+          .isInstanceOf[FakeSink])
+      }
     }
   }

+  Seq(
+    Tuple2(classOf[FakeReadMicroBatchOnly], Trigger.Once()),
+    Tuple2(classOf[FakeReadContinuousOnly], Trigger.Continuous(1000))
+  ).foreach { case (source, trigger) =>
+    test(s"SPARK-25460: session options are respected in structured streaming sources - $source") {
+      // `keyPrefix` and `shortName` are the same in this test case
+      val readSource = source.newInstance().shortName()
+      val writeSource = "fake-write-microbatch-continuous"
+
+      val readOptionName = "optionA"
+      withSQLConf(s"spark.datasource.$readSource.$readOptionName" -> "true") {
+        testPositiveCaseWithQuery(readSource, writeSource, trigger) { _ =>
+          eventually(timeout(streamingTimeout)) {
+            // Write options should not be set.
+            assert(LastWriteOptions.options.getBoolean(readOptionName, false) == false)
+            assert(LastReadOptions.options.getBoolean(readOptionName, false) == true)
+          }
+        }
+      }
+
+      val writeOptionName = "optionB"
+      withSQLConf(s"spark.datasource.$writeSource.$writeOptionName" -> "true") {
+        testPositiveCaseWithQuery(readSource, writeSource, trigger) { _ =>
+          eventually(timeout(streamingTimeout)) {
+            // Read options should not be set.
+            assert(LastReadOptions.options.getBoolean(writeOptionName, false) == false)
+            assert(LastWriteOptions.options.getBoolean(writeOptionName, false) == true)
+          }
+        }
+      }
+    }
+  }
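
For readers unfamiliar with `SessionConfigSupport`, here is a minimal sketch (not part of the patch) of the pattern the fake test sources above follow: a DataSource V2 class mixes in `SessionConfigSupport` and exposes a `keyPrefix`, and session configurations named `spark.datasource.<keyPrefix>.<option>` are extracted by `DataSourceV2Utils.extractSessionConfigs` and merged into the options handed to the source; with this change that applies to streaming reads and writes as well as to batch. The class name and option key below are hypothetical.

```scala
// Illustrative only -- not part of the patch. The class name and config key
// are made-up examples; only the Spark API types are real.
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.sources.v2.SessionConfigSupport

class ExampleConfigAwareSource extends DataSourceRegister with SessionConfigSupport {
  // A real source would also implement MicroBatchReadSupportProvider,
  // ContinuousReadSupportProvider and/or StreamingWriteSupportProvider.
  override def shortName(): String = "example-source"

  // Session configs named spark.datasource.example-source.<option> are picked up
  // by DataSourceV2Utils.extractSessionConfigs and merged into the DataSourceOptions
  // the source receives; after this patch that also happens on the streaming paths.
  override def keyPrefix: String = shortName()
}
```

With a session config set once, e.g. `spark.conf.set("spark.datasource.example-source.url", "...")` (a hypothetical key), both `spark.read.format("example-source")` and, after this patch, `spark.readStream.format("example-source")` would see `url` in their options without it being repeated via `.option(...)`.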