[SPARK-32229][SQL] Fix PostgresConnectionProvider and MSSQLConnectionProvider by accessing wrapped driver
### What changes were proposed in this pull request? Postgres and MSSQL connection providers are not able to get custom `appEntry` because under some circumstances the driver is wrapped with `DriverWrapper`. Such case is not handled in the mentioned providers. In this PR I've added this edge case handling by passing unwrapped `Driver` from `JdbcUtils`. ### Why are the changes needed? `DriverWrapper` is not considered. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing + additional unit tests. Closes #30024 from gaborgsomogyi/SPARK-32229. Authored-by: Gabor Somogyi <gabor.g.somogyi@gmail.com> Signed-off-by: Takeshi Yamamuro <yamamuro@apache.org>
This commit is contained in:
parent
a44e008de3
commit
fbb6843620
|
@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.jdbc
|
|||
|
||||
import java.sql.{Driver, DriverManager}
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.mutable
|
||||
|
||||
import org.apache.spark.internal.Logging
|
||||
|
@ -58,5 +59,15 @@ object DriverRegistry extends Logging {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
def get(className: String): Driver = {
|
||||
DriverManager.getDrivers.asScala.collectFirst {
|
||||
case d: DriverWrapper if d.wrapped.getClass.getCanonicalName == className => d.wrapped
|
||||
case d if d.getClass.getCanonicalName == className => d
|
||||
}.getOrElse {
|
||||
throw new IllegalStateException(
|
||||
s"Did not find registered driver with class $className")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -17,10 +17,9 @@
|
|||
|
||||
package org.apache.spark.sql.execution.datasources.jdbc
|
||||
|
||||
import java.sql.{Connection, Driver, DriverManager, JDBCType, PreparedStatement, ResultSet, ResultSetMetaData, SQLException, SQLFeatureNotSupportedException}
|
||||
import java.sql.{Connection, Driver, JDBCType, PreparedStatement, ResultSet, ResultSetMetaData, SQLException, SQLFeatureNotSupportedException}
|
||||
import java.util.Locale
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.util.Try
|
||||
import scala.util.control.NonFatal
|
||||
|
||||
|
@ -56,17 +55,10 @@ object JdbcUtils extends Logging {
|
|||
val driverClass: String = options.driverClass
|
||||
() => {
|
||||
DriverRegistry.register(driverClass)
|
||||
val driver: Driver = DriverManager.getDrivers.asScala.collectFirst {
|
||||
case d: DriverWrapper if d.wrapped.getClass.getCanonicalName == driverClass => d
|
||||
case d if d.getClass.getCanonicalName == driverClass => d
|
||||
}.getOrElse {
|
||||
throw new IllegalStateException(
|
||||
s"Did not find registered driver with class $driverClass")
|
||||
}
|
||||
val driver: Driver = DriverRegistry.get(driverClass)
|
||||
val connection = ConnectionProvider.create(driver, options.parameters)
|
||||
require(connection != null,
|
||||
s"The driver could not open a JDBC connection. Check the URL: ${options.url}")
|
||||
|
||||
connection
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.execution.datasources.jdbc
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.execution.datasources.jdbc.connection.TestDriver
|
||||
|
||||
class DriverRegistrySuite extends SparkFunSuite {
|
||||
test("SPARK-32229: get must give back wrapped driver if wrapped") {
|
||||
val className = classOf[TestDriver].getName
|
||||
DriverRegistry.register(className)
|
||||
assert(DriverRegistry.get(className).isInstanceOf[TestDriver])
|
||||
}
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.execution.datasources.jdbc.connection
|
||||
|
||||
import java.sql.{Connection, Driver, DriverPropertyInfo}
|
||||
import java.util.Properties
|
||||
import java.util.logging.Logger
|
||||
|
||||
private[jdbc] class TestDriver() extends Driver {
|
||||
override def connect(url: String, info: Properties): Connection = null
|
||||
override def acceptsURL(url: String): Boolean = false
|
||||
override def getPropertyInfo(url: String, info: Properties): Array[DriverPropertyInfo] =
|
||||
Array.empty
|
||||
override def getMajorVersion: Int = 0
|
||||
override def getMinorVersion: Int = 0
|
||||
override def jdbcCompliant(): Boolean = false
|
||||
override def getParentLogger: Logger = null
|
||||
}
|
Loading…
Reference in a new issue