From 4036ad9ad9d2fd0e5e5fb8b9a86bf7f4e408b1b9 Mon Sep 17 00:00:00 2001
From: Ivan Sadikov
Date: Mon, 19 Jul 2021 17:48:32 +0900
Subject: [PATCH] [SPARK-36163][SQL] Propagate correct JDBC properties in JDBC
 connector provider and add "connectionProvider" option

### What changes were proposed in this pull request?

This PR fixes two issues highlighted in https://issues.apache.org/jira/browse/SPARK-36163:
- The JDBC connection provider propagates incorrect connection properties.
- There is ambiguity when more than one JDBC connection provider is available.

I updated `BasicConnectionProvider` to use `jdbcOptions.asConnectionProperties` so that JDBC data source specific options are removed.

I also added a `connectionProvider` data source option that specifies the name of the provider, e.g. `db2`, `presto`, to allow enforcing that specific provider in case of ambiguity.

### Why are the changes needed?

Users can leverage `spark.sql.sources.disabledJdbcConnProviderList`, but it is cumbersome: it requires disabling all other providers, which is problematic when ambiguous providers are needed in two or more different JDBC queries.

### Does this PR introduce _any_ user-facing change?

Yes. This PR introduces a new JDBC data source option `connectionProvider` that allows users to select a specific JDBC connection provider by its short name. I updated the SQL guide doc and README.

Before this change, the only way to resolve ambiguity was a SQL conf that disables all of the other JDBC connection providers. After this change, users can specify the exact connection provider they need per data source.
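For illustration, a user hitting the ambiguity error could pin a provider per query like this (a sketch; the URL, table name, and chosen provider name are illustrative — only the `connectionProvider` option itself is added by this PR):

```scala
// Hypothetical usage of the new option: pin the JDBC connection provider by name.
val df = spark.read
  .format("jdbc")
  .option("url", "jdbc:postgresql://localhost/postgres")  // illustrative URL
  .option("dbtable", "test_table")                        // illustrative table
  .option("connectionProvider", "postgres")               // provider short name
  .load()
```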
### How was this patch tested?

I updated the existing `ConnectionProviderSuite` and added a new `BasicConnectionProviderSuite`.

Closes #33370 from sadikovi/fix-jdbc-conn-provider.

Authored-by: Ivan Sadikov
Signed-off-by: Hyukjin Kwon
---
 docs/sql-data-sources-jdbc.md                 | 19 +++-
 .../datasources/jdbc/JDBCOptions.scala        |  4 +
 .../datasources/jdbc/JdbcUtils.scala          |  3 +-
 .../connection/BasicConnectionProvider.scala  |  2 +-
 .../jdbc/connection/ConnectionProvider.scala  | 42 +++++++--
 .../scala/org/apache/spark/sql/jdbc/README.md | 10 ++-
 .../BasicConnectionProviderSuite.scala        | 57 ++++++++++++
 .../connection/ConnectionProviderSuite.scala  | 88 ++++++++++++++++++-
 8 files changed, 209 insertions(+), 16 deletions(-)
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/BasicConnectionProviderSuite.scala

diff --git a/docs/sql-data-sources-jdbc.md b/docs/sql-data-sources-jdbc.md
index c973e8a3ff..6d44a229bf 100644
--- a/docs/sql-data-sources-jdbc.md
+++ b/docs/sql-data-sources-jdbc.md
@@ -9,9 +9,9 @@ license: |
   The ASF licenses this file to You under the Apache License, Version 2.0
   (the "License"); you may not use this file except in compliance with
   the License. You may obtain a copy of the License at
-
+
      http://www.apache.org/licenses/LICENSE-2.0
-
+
   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -191,7 +191,7 @@ logging into the data sources.
     <td>write</td>
   </tr>
-
+
   <tr>
     <td><code>cascadeTruncate</code></td>
     <td>the default cascading truncate behaviour of the JDBC database in question, specified in the <code>isCascadeTruncate</code> in each <code>JDBCDialect</code></td>
@@ -275,11 +275,22 @@ logging into the data sources.
     <td>read/write</td>
   </tr>
+
+  <tr>
+    <td><code>connectionProvider</code></td>
+    <td>(none)</td>
+    <td>
+      The name of the JDBC connection provider to use to connect to this URL, e.g. <code>db2</code>, <code>mssql</code>.
+      Must be one of the providers loaded with the JDBC data source. Used to disambiguate when more than one provider can handle
+      the specified driver and options. The selected provider must not be disabled by <code>spark.sql.sources.disabledJdbcConnProviderList</code>.
+    </td>
+    <td>read/write</td>
+  </tr>
 </table>
 
 Note that kerberos authentication with keytab is not always supported by the JDBC driver.
 Before using keytab and principal configuration options, please make sure the following requirements are met:
 
-* The included JDBC driver version supports kerberos authentication with keytab.
+* The included JDBC driver version supports kerberos authentication with keytab.
 * There is a built-in connection provider which supports the used database.
 
 There is a built-in connection providers for the following databases:

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
index 97d4f2d976..e3baafbb4d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
@@ -207,6 +207,9 @@ class JDBCOptions(
   val tableComment = parameters.getOrElse(JDBC_TABLE_COMMENT, "").toString
 
   val refreshKrb5Config = parameters.getOrElse(JDBC_REFRESH_KRB5_CONFIG, "false").toBoolean
+
+  // User specified JDBC connection provider name
+  val connectionProviderName = parameters.get(JDBC_CONNECTION_PROVIDER)
 }
 
 class JdbcOptionsInWrite(
@@ -263,4 +266,5 @@ object JDBCOptions {
   val JDBC_PRINCIPAL = newOption("principal")
   val JDBC_TABLE_COMMENT = newOption("tableComment")
   val JDBC_REFRESH_KRB5_CONFIG = newOption("refreshKrb5Config")
+  val JDBC_CONNECTION_PROVIDER = newOption("connectionProvider")
 }

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 60fcaf94e1..7b555bd281 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -61,7 +61,8 @@ object JdbcUtils extends Logging {
     () => {
       DriverRegistry.register(driverClass)
       val driver: Driver = DriverRegistry.get(driverClass)
-      val connection = ConnectionProvider.create(driver, options.parameters)
+      val connection =
+        ConnectionProvider.create(driver, options.parameters, options.connectionProviderName)
       require(connection != null,
         s"The driver could not open a JDBC connection. Check the URL: ${options.url}")
       connection
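The next hunk is the heart of the properties fix: `BasicConnectionProvider` switches from `asProperties`, which carries every data source option (including Spark-specific ones such as `dbtable` and the new `connectionProvider`), to `asConnectionProperties`, which drops every option registered through `JDBCOptions.newOption` so that only driver pass-through properties reach `Driver.connect`. A rough sketch of the difference (the option values are illustrative):

```scala
import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions

val opts = new JDBCOptions(Map(
  "url" -> "jdbc:postgresql://localhost/postgres",  // registered Spark option
  "dbtable" -> "test_table",                        // registered Spark option
  "user" -> "alice"))                               // pass-through driver property

opts.asProperties            // contains url, dbtable and user
opts.asConnectionProperties  // contains only user; registered options are dropped
```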
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/BasicConnectionProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/BasicConnectionProvider.scala
index 890205f2f6..66854f2801 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/BasicConnectionProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/BasicConnectionProvider.scala
@@ -42,7 +42,7 @@ private[jdbc] class BasicConnectionProvider extends JdbcConnectionProvider with
   override def getConnection(driver: Driver, options: Map[String, String]): Connection = {
     val jdbcOptions = new JDBCOptions(options)
     val properties = getAdditionalProperties(jdbcOptions)
-    jdbcOptions.asProperties.asScala.foreach { case(k, v) =>
+    jdbcOptions.asConnectionProperties.asScala.foreach { case(k, v) =>
       properties.put(k, v)
     }
     logDebug(s"JDBC connection initiated with URL: ${jdbcOptions.url} and properties: $properties")

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala
index fbc69704f1..e3d82757e6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala
@@ -25,12 +25,13 @@ import scala.collection.mutable
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.security.SecurityConfigurationLock
+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.jdbc.JdbcConnectionProvider
 import org.apache.spark.util.Utils
 
-private[jdbc] object ConnectionProvider extends Logging {
-  private val providers = loadProviders()
+protected abstract class ConnectionProviderBase extends Logging {
+  protected val providers = loadProviders()
 
   def loadProviders(): Seq[JdbcConnectionProvider] = {
     val loader = ServiceLoader.load(classOf[JdbcConnectionProvider],
@@ -55,17 +56,42 @@ private[jdbc] object ConnectionProvider extends Logging {
     providers.filterNot(p => disabledProviders.contains(p.name)).toSeq
   }
 
-  def create(driver: Driver, options: Map[String, String]): Connection = {
+  def create(
+      driver: Driver,
+      options: Map[String, String],
+      connectionProviderName: Option[String]): Connection = {
     val filteredProviders = providers.filter(_.canHandle(driver, options))
-    require(filteredProviders.size == 1,
-      "JDBC connection initiated but not exactly one connection provider found which can handle " +
-        s"it. Found active providers: ${filteredProviders.mkString(", ")}")
+
+    if (filteredProviders.isEmpty) {
+      throw new IllegalArgumentException(
+        "Empty list of JDBC connection providers for the specified driver and options")
+    }
+
+    val selectedProvider = connectionProviderName match {
+      case Some(providerName) =>
+        // It is assumed that no two providers will have the same name
+        filteredProviders.filter(_.name == providerName).headOption.getOrElse {
+          throw new IllegalArgumentException(
+            s"Could not find a JDBC connection provider with name '$providerName' " +
+            "that can handle the specified driver and options. " +
+            s"Available providers are ${providers.mkString("[", ", ", "]")}")
+        }
+      case None =>
+        if (filteredProviders.size != 1) {
+          throw new IllegalArgumentException(
+            "JDBC connection initiated but more than one connection provider was found. Use " +
+            s"'${JDBCOptions.JDBC_CONNECTION_PROVIDER}' option to select a specific provider. " +
+            s"Found active providers ${filteredProviders.mkString("[", ", ", "]")}")
+        }
+        filteredProviders.head
+    }
+
     SecurityConfigurationLock.synchronized {
       // Inside getConnection it's safe to get parent again because SecurityConfigurationLock
       // makes sure it's untouched
       val parent = Configuration.getConfiguration
       try {
-        filteredProviders.head.getConnection(driver, options)
+        selectedProvider.getConnection(driver, options)
       } finally {
         logDebug("Restoring original security configuration")
         Configuration.setConfiguration(parent)
@@ -73,3 +99,5 @@
     }
   }
 }
+
+private[jdbc] object ConnectionProvider extends ConnectionProviderBase

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/README.md b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/README.md
index f8a4ae08f8..72196be014 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/README.md
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/README.md
@@ -6,9 +6,9 @@ license: |
   The ASF licenses this file to You under the Apache License, Version 2.0
   (the "License"); you may not use this file except in compliance with
   the License. You may obtain a copy of the License at
-
+
      http://www.apache.org/licenses/LICENSE-2.0
-
+
   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -46,6 +46,12 @@ so they can be turned off and can be replaced with custom implementation. All CP
 which must be unique. One can set the following configuration entry in `SparkConf` to turn off CPs:
 `spark.sql.sources.disabledJdbcConnProviderList=name1,name2`.
 
+## How to enforce a specific JDBC connection provider?
+
+When more than one JDBC connection provider can handle a specific driver and options, it is possible to
+disambiguate and enforce a particular CP for the JDBC data source. One can set the DataFrame
+option `connectionProvider` to specify the name of the CP they want to use.
+
 ## How a JDBC connection provider found when new connection initiated?
 
 When a Spark source initiates JDBC connection it looks for a CP which supports the included driver,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/BasicConnectionProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/BasicConnectionProviderSuite.scala
new file mode 100644
index 0000000000..823fdcae9d
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/BasicConnectionProviderSuite.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.jdbc.connection
+
+import java.sql.{Connection, Driver}
+import java.util.Properties
+
+import org.mockito.ArgumentMatchers.any
+import org.mockito.Mockito.when
+import org.mockito.invocation.InvocationOnMock
+import org.scalatestplus.mockito.MockitoSugar
+
+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
+
+class BasicConnectionProviderSuite extends ConnectionProviderSuiteBase with MockitoSugar {
+  test("Check properties of BasicConnectionProvider") {
+    val opts = options("jdbc:postgresql://localhost/postgres")
+    val provider = new BasicConnectionProvider()
+    assert(provider.name == "basic")
+    assert(provider.getAdditionalProperties(opts).isEmpty())
+  }
+
+  test("Check that JDBC options don't contain data source configs") {
+    val provider = new BasicConnectionProvider()
+    val driver = mock[Driver]
+    when(driver.connect(any(), any())).thenAnswer((invocation: InvocationOnMock) => {
+      val props = invocation.getArguments().apply(1).asInstanceOf[Properties]
+      val conn = mock[Connection]
+      when(conn.getClientInfo()).thenReturn(props)
+      conn
+    })
+
+    val opts = Map(
+      JDBCOptions.JDBC_URL -> "jdbc:postgresql://localhost/postgres",
+      JDBCOptions.JDBC_TABLE_NAME -> "table",
+      JDBCOptions.JDBC_CONNECTION_PROVIDER -> "basic")
+    val conn = provider.getConnection(driver, opts)
+    assert(!conn.getClientInfo().containsKey(JDBCOptions.JDBC_URL))
+    assert(!conn.getClientInfo().containsKey(JDBCOptions.JDBC_TABLE_NAME))
+    assert(!conn.getClientInfo().containsKey(JDBCOptions.JDBC_CONNECTION_PROVIDER))
+  }
+}

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProviderSuite.scala
index 32d8fce28a..6674483c29 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProviderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProviderSuite.scala
@@ -17,13 +17,21 @@
 
 package org.apache.spark.sql.execution.datasources.jdbc.connection
 
+import java.sql.{Connection, Driver}
 import javax.security.auth.login.Configuration
 
+import org.scalatestplus.mockito.MockitoSugar
+
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.internal.StaticSQLConf
+import org.apache.spark.sql.jdbc.JdbcConnectionProvider
 import org.apache.spark.sql.test.SharedSparkSession
 
-class ConnectionProviderSuite extends ConnectionProviderSuiteBase with SharedSparkSession {
+class ConnectionProviderSuite
+  extends ConnectionProviderSuiteBase
+  with SharedSparkSession
+  with MockitoSugar {
+
   test("All built-in providers must be loaded") {
     IntentionallyFaultyConnectionProvider.constructed = false
     val providers = ConnectionProvider.loadProviders()
@@ -38,6 +46,84 @@ class ConnectionProviderSuite extends ConnectionProviderSuiteBase with SharedSpa
     assert(providers.size === 6)
   }
 
+  test("Throw an error selecting from an empty list of providers on create") {
+    val providerBase = new ConnectionProviderBase() {
+      override val providers = Seq.empty
+    }
+
+    val err1 = intercept[IllegalArgumentException] {
+      providerBase.create(mock[Driver], Map.empty, None)
+    }
+    assert(err1.getMessage.contains("Empty list of JDBC connection providers"))
+
+    val err2 = intercept[IllegalArgumentException] {
+      providerBase.create(mock[Driver], Map.empty, Some("test"))
+    }
+    assert(err2.getMessage.contains("Empty list of JDBC connection providers"))
+  }
+
+  test("Throw an error when more than one provider is available on create") {
+    val provider1 = new JdbcConnectionProvider() {
+      override val name: String = "test1"
+      override def canHandle(driver: Driver, options: Map[String, String]): Boolean = true
+      override def getConnection(driver: Driver, options: Map[String, String]): Connection =
+        throw new RuntimeException()
+    }
+    val provider2 = new JdbcConnectionProvider() {
+      override val name: String = "test2"
+      override def canHandle(driver: Driver, options: Map[String, String]): Boolean = true
+      override def getConnection(driver: Driver, options: Map[String, String]): Connection =
+        throw new RuntimeException()
+    }
+
+    val providerBase = new ConnectionProviderBase() {
+      override val providers = Seq(provider1, provider2)
+    }
+
+    val err = intercept[IllegalArgumentException] {
+      providerBase.create(mock[Driver], Map.empty, None)
+    }
+    assert(err.getMessage.contains("more than one connection provider was found"))
+  }
+
+  test("Handle user specified JDBC connection provider") {
+    val provider1 = new JdbcConnectionProvider() {
+      override val name: String = "test1"
+      override def canHandle(driver: Driver, options: Map[String, String]): Boolean = true
+      override def getConnection(driver: Driver, options: Map[String, String]): Connection =
+        throw new RuntimeException()
+    }
+    val provider2 = new JdbcConnectionProvider() {
+      override val name: String = "test2"
+      override def canHandle(driver: Driver, options: Map[String, String]): Boolean = true
+      override def getConnection(driver: Driver, options: Map[String, String]): Connection =
+        mock[Connection]
+    }
+
+    val providerBase = new ConnectionProviderBase() {
+      override val providers = Seq(provider1, provider2)
+    }
+    // We don't expect any exceptions or null here
+    assert(providerBase.create(mock[Driver], Map.empty, Some("test2")).isInstanceOf[Connection])
+  }
+
+  test("Throw an error when user specified provider that does not exist") {
+    val provider = new JdbcConnectionProvider() {
+      override val name: String = "provider"
+      override def canHandle(driver: Driver, options: Map[String, String]): Boolean = true
+      override def getConnection(driver: Driver, options: Map[String, String]): Connection =
+        throw new RuntimeException()
+    }
+
+    val providerBase = new ConnectionProviderBase() {
+      override val providers = Seq(provider)
+    }
+    val err = intercept[IllegalArgumentException] {
+      providerBase.create(mock[Driver], Map.empty, Some("test"))
+    }
+    assert(err.getMessage.contains("Could not find a JDBC connection provider with name 'test'"))
+  }
+
   test("Multiple security configs must be reachable") {
     Configuration.setConfiguration(null)
     val postgresProvider = new PostgresConnectionProvider()