[SPARK-25460][SS] DataSourceV2: SS sources do not respect SessionConfigSupport

## What changes were proposed in this pull request?

This PR proposes to respect `SessionConfigSupport` in SS datasources as well. Currently these are only respected in batch sources:

e06da95cd9/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala (L198-L203)

e06da95cd9/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala (L244-L249)

If a developer makes a datasource V2 that supports both structured streaming and batch jobs, batch jobs respect a specific configuration, let's say, URL to connect and fetch data (which end users might not be aware of); however, structured streaming ends up with not supporting this (and should explicitly be set into options).

## How was this patch tested?

Unit tests were added.

Closes #22462 from HyukjinKwon/SPARK-25460.

Authored-by: hyukjinkwon <gurwls223@apache.org>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
hyukjinkwon 2018-09-20 20:22:55 +08:00 committed by Wenchen Fan
parent 89671a27e7
commit edf5cc64e4
3 changed files with 128 additions and 28 deletions

View file

@ -26,6 +26,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils
import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2}
import org.apache.spark.sql.sources.StreamSourceProvider
import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, MicroBatchReadSupportProvider}
@ -158,7 +159,6 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
}
val ds = DataSource.lookupDataSource(source, sparkSession.sqlContext.conf).newInstance()
val options = new DataSourceOptions(extraOptions.asJava)
// We need to generate the V1 data source so we can pass it to the V2 relation as a shim.
// We can't be sure at this point whether we'll actually want to use V2, since we don't know the
// writer or whether the query is continuous.
@ -173,13 +173,18 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
}
ds match {
case s: MicroBatchReadSupportProvider =>
val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
ds = s, conf = sparkSession.sessionState.conf)
val options = sessionOptions ++ extraOptions
val dataSourceOptions = new DataSourceOptions(options.asJava)
var tempReadSupport: MicroBatchReadSupport = null
val schema = try {
val tmpCheckpointPath = Utils.createTempDir(namePrefix = s"tempCP").getCanonicalPath
tempReadSupport = if (userSpecifiedSchema.isDefined) {
s.createMicroBatchReadSupport(userSpecifiedSchema.get, tmpCheckpointPath, options)
s.createMicroBatchReadSupport(
userSpecifiedSchema.get, tmpCheckpointPath, dataSourceOptions)
} else {
s.createMicroBatchReadSupport(tmpCheckpointPath, options)
s.createMicroBatchReadSupport(tmpCheckpointPath, dataSourceOptions)
}
tempReadSupport.fullSchema()
} finally {
@ -192,16 +197,21 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
Dataset.ofRows(
sparkSession,
StreamingRelationV2(
s, source, extraOptions.toMap,
s, source, options,
schema.toAttributes, v1Relation)(sparkSession))
case s: ContinuousReadSupportProvider =>
val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
ds = s, conf = sparkSession.sessionState.conf)
val options = sessionOptions ++ extraOptions
val dataSourceOptions = new DataSourceOptions(options.asJava)
var tempReadSupport: ContinuousReadSupport = null
val schema = try {
val tmpCheckpointPath = Utils.createTempDir(namePrefix = s"tempCP").getCanonicalPath
tempReadSupport = if (userSpecifiedSchema.isDefined) {
s.createContinuousReadSupport(userSpecifiedSchema.get, tmpCheckpointPath, options)
s.createContinuousReadSupport(
userSpecifiedSchema.get, tmpCheckpointPath, dataSourceOptions)
} else {
s.createContinuousReadSupport(tmpCheckpointPath, options)
s.createContinuousReadSupport(tmpCheckpointPath, dataSourceOptions)
}
tempReadSupport.fullSchema()
} finally {
@ -214,7 +224,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
Dataset.ofRows(
sparkSession,
StreamingRelationV2(
s, source, extraOptions.toMap,
s, source, options,
schema.toAttributes, v1Relation)(sparkSession))
case _ =>
// Code path for data source v1.

View file

@ -27,6 +27,7 @@ import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
import org.apache.spark.sql.execution.streaming.sources._
@ -298,23 +299,28 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
} else {
val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf)
val disabledSources = df.sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",")
var options = extraOptions.toMap
val sink = ds.newInstance() match {
case w: StreamingWriteSupportProvider
if !disabledSources.contains(w.getClass.getCanonicalName) => w
if !disabledSources.contains(w.getClass.getCanonicalName) =>
val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
w, df.sparkSession.sessionState.conf)
options = sessionOptions ++ extraOptions
w
case _ =>
val ds = DataSource(
df.sparkSession,
className = source,
options = extraOptions.toMap,
options = options,
partitionColumns = normalizedParCols.getOrElse(Nil))
ds.createSink(outputMode)
}
df.sparkSession.sessionState.streamingQueryManager.startQuery(
extraOptions.get("queryName"),
extraOptions.get("checkpointLocation"),
options.get("queryName"),
options.get("checkpointLocation"),
df,
extraOptions.toMap,
options,
sink,
outputMode,
useTempCheckpointLocation = source == "console",

View file

@ -27,7 +27,7 @@ import org.apache.spark.sql.sources.v2._
import org.apache.spark.sql.sources.v2.reader.{InputPartition, PartitionReaderFactory, ScanConfig, ScanConfigBuilder}
import org.apache.spark.sql.sources.v2.reader.streaming._
import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport
import org.apache.spark.sql.streaming.{OutputMode, StreamTest, Trigger}
import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery, StreamTest, Trigger}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils
@ -56,13 +56,19 @@ case class FakeReadSupport() extends MicroBatchReadSupport with ContinuousReadSu
trait FakeMicroBatchReadSupportProvider extends MicroBatchReadSupportProvider {
override def createMicroBatchReadSupport(
checkpointLocation: String,
options: DataSourceOptions): MicroBatchReadSupport = FakeReadSupport()
options: DataSourceOptions): MicroBatchReadSupport = {
LastReadOptions.options = options
FakeReadSupport()
}
}
trait FakeContinuousReadSupportProvider extends ContinuousReadSupportProvider {
override def createContinuousReadSupport(
checkpointLocation: String,
options: DataSourceOptions): ContinuousReadSupport = FakeReadSupport()
options: DataSourceOptions): ContinuousReadSupport = {
LastReadOptions.options = options
FakeReadSupport()
}
}
trait FakeStreamingWriteSupportProvider extends StreamingWriteSupportProvider {
@ -71,16 +77,27 @@ trait FakeStreamingWriteSupportProvider extends StreamingWriteSupportProvider {
schema: StructType,
mode: OutputMode,
options: DataSourceOptions): StreamingWriteSupport = {
LastWriteOptions.options = options
throw new IllegalStateException("fake sink - cannot actually write")
}
}
class FakeReadMicroBatchOnly extends DataSourceRegister with FakeMicroBatchReadSupportProvider {
class FakeReadMicroBatchOnly
extends DataSourceRegister
with FakeMicroBatchReadSupportProvider
with SessionConfigSupport {
override def shortName(): String = "fake-read-microbatch-only"
override def keyPrefix: String = shortName()
}
class FakeReadContinuousOnly extends DataSourceRegister with FakeContinuousReadSupportProvider {
class FakeReadContinuousOnly
extends DataSourceRegister
with FakeContinuousReadSupportProvider
with SessionConfigSupport {
override def shortName(): String = "fake-read-continuous-only"
override def keyPrefix: String = shortName()
}
class FakeReadBothModes extends DataSourceRegister
@ -92,8 +109,13 @@ class FakeReadNeitherMode extends DataSourceRegister {
override def shortName(): String = "fake-read-neither-mode"
}
class FakeWriteSupportProvider extends DataSourceRegister with FakeStreamingWriteSupportProvider {
class FakeWriteSupportProvider
extends DataSourceRegister
with FakeStreamingWriteSupportProvider
with SessionConfigSupport {
override def shortName(): String = "fake-write-microbatch-continuous"
override def keyPrefix: String = shortName()
}
class FakeNoWrite extends DataSourceRegister {
@ -121,6 +143,21 @@ class FakeWriteSupportProviderV1Fallback extends DataSourceRegister
override def shortName(): String = "fake-write-v1-fallback"
}
object LastReadOptions {
var options: DataSourceOptions = _
def clear(): Unit = {
options = null
}
}
object LastWriteOptions {
var options: DataSourceOptions = _
def clear(): Unit = {
options = null
}
}
class StreamingDataSourceV2Suite extends StreamTest {
@ -130,6 +167,11 @@ class StreamingDataSourceV2Suite extends StreamTest {
spark.conf.set("spark.sql.streaming.checkpointLocation", fakeCheckpoint.getCanonicalPath)
}
override def afterEach(): Unit = {
LastReadOptions.clear()
LastWriteOptions.clear()
}
val readFormats = Seq(
"fake-read-microbatch-only",
"fake-read-continuous-only",
@ -143,7 +185,14 @@ class StreamingDataSourceV2Suite extends StreamTest {
Trigger.ProcessingTime(1000),
Trigger.Continuous(1000))
private def testPositiveCase(readFormat: String, writeFormat: String, trigger: Trigger) = {
private def testPositiveCase(readFormat: String, writeFormat: String, trigger: Trigger): Unit = {
testPositiveCaseWithQuery(readFormat, writeFormat, trigger)(() => _)
}
private def testPositiveCaseWithQuery(
readFormat: String,
writeFormat: String,
trigger: Trigger)(check: StreamingQuery => Unit): Unit = {
val query = spark.readStream
.format(readFormat)
.load()
@ -151,8 +200,8 @@ class StreamingDataSourceV2Suite extends StreamTest {
.format(writeFormat)
.trigger(trigger)
.start()
check(query)
query.stop()
query
}
private def testNegativeCase(
@ -188,19 +237,54 @@ class StreamingDataSourceV2Suite extends StreamTest {
test("disabled v2 write") {
// Ensure the V2 path works normally and generates a V2 sink..
val v2Query = testPositiveCase(
"fake-read-microbatch-continuous", "fake-write-v1-fallback", Trigger.Once())
assert(v2Query.asInstanceOf[StreamingQueryWrapper].streamingQuery.sink
.isInstanceOf[FakeWriteSupportProviderV1Fallback])
testPositiveCaseWithQuery(
"fake-read-microbatch-continuous", "fake-write-v1-fallback", Trigger.Once()) { v2Query =>
assert(v2Query.asInstanceOf[StreamingQueryWrapper].streamingQuery.sink
.isInstanceOf[FakeWriteSupportProviderV1Fallback])
}
// Ensure we create a V1 sink with the config. Note the config is a comma separated
// list, including other fake entries.
val fullSinkName = classOf[FakeWriteSupportProviderV1Fallback].getName
withSQLConf(SQLConf.DISABLED_V2_STREAMING_WRITERS.key -> s"a,b,c,test,$fullSinkName,d,e") {
val v1Query = testPositiveCase(
"fake-read-microbatch-continuous", "fake-write-v1-fallback", Trigger.Once())
assert(v1Query.asInstanceOf[StreamingQueryWrapper].streamingQuery.sink
.isInstanceOf[FakeSink])
testPositiveCaseWithQuery(
"fake-read-microbatch-continuous", "fake-write-v1-fallback", Trigger.Once()) { v1Query =>
assert(v1Query.asInstanceOf[StreamingQueryWrapper].streamingQuery.sink
.isInstanceOf[FakeSink])
}
}
}
Seq(
Tuple2(classOf[FakeReadMicroBatchOnly], Trigger.Once()),
Tuple2(classOf[FakeReadContinuousOnly], Trigger.Continuous(1000))
).foreach { case (source, trigger) =>
test(s"SPARK-25460: session options are respected in structured streaming sources - $source") {
// `keyPrefix` and `shortName` are the same in this test case
val readSource = source.newInstance().shortName()
val writeSource = "fake-write-microbatch-continuous"
val readOptionName = "optionA"
withSQLConf(s"spark.datasource.$readSource.$readOptionName" -> "true") {
testPositiveCaseWithQuery(readSource, writeSource, trigger) { _ =>
eventually(timeout(streamingTimeout)) {
// Write options should not be set.
assert(LastWriteOptions.options.getBoolean(readOptionName, false) == false)
assert(LastReadOptions.options.getBoolean(readOptionName, false) == true)
}
}
}
val writeOptionName = "optionB"
withSQLConf(s"spark.datasource.$writeSource.$writeOptionName" -> "true") {
testPositiveCaseWithQuery(readSource, writeSource, trigger) { _ =>
eventually(timeout(streamingTimeout)) {
// Read options should not be set.
assert(LastReadOptions.options.getBoolean(writeOptionName, false) == false)
assert(LastWriteOptions.options.getBoolean(writeOptionName, false) == true)
}
}
}
}
}