[SPARK-24327][SQL] Verify and normalize a partition column name based on the JDBC resolved schema
## What changes were proposed in this pull request? This PR modifies the JDBC datasource code to verify and normalize the partition column name against the JDBC resolved schema before building `JDBCRelation`. Closes #20370 ## How was this patch tested? Added tests in `JDBCSuite`. Author: Takeshi Yamamuro <yamamuro@apache.org> Closes #21379 from maropu/SPARK-24327.
This commit is contained in:
parent
a5849ad9a3
commit
f596ebe4d3
|
@ -100,7 +100,7 @@ private[spark] object Utils extends Logging {
|
|||
*/
|
||||
val DEFAULT_MAX_TO_STRING_FIELDS = 25
|
||||
|
||||
private def maxNumToStringFields = {
|
||||
private[spark] def maxNumToStringFields = {
|
||||
if (SparkEnv.get != null) {
|
||||
SparkEnv.get.conf.getInt("spark.debug.maxToStringFields", DEFAULT_MAX_TO_STRING_FIELDS)
|
||||
} else {
|
||||
|
|
|
@ -22,10 +22,12 @@ import scala.collection.mutable.ArrayBuffer
|
|||
import org.apache.spark.Partition
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext}
|
||||
import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode, SparkSession, SQLContext}
|
||||
import org.apache.spark.sql.catalyst.analysis._
|
||||
import org.apache.spark.sql.jdbc.JdbcDialects
|
||||
import org.apache.spark.sql.sources._
|
||||
import org.apache.spark.sql.types.StructType
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
/**
|
||||
* Instructions on how to partition the table among workers.
|
||||
|
@ -48,10 +50,17 @@ private[sql] object JDBCRelation extends Logging {
|
|||
* Null value predicate is added to the first partition where clause to include
|
||||
* the rows with a null value for the partition column.
|
||||
*
|
||||
* @param schema resolved schema of a JDBC table
|
||||
* @param partitioning partition information to generate the where clause for each partition
|
||||
* @param resolver function used to determine if two identifiers are equal
|
||||
* @param jdbcOptions JDBC options that contain the url
|
||||
* @return an array of partitions with where clause for each partition
|
||||
*/
|
||||
def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = {
|
||||
def columnPartition(
|
||||
schema: StructType,
|
||||
partitioning: JDBCPartitioningInfo,
|
||||
resolver: Resolver,
|
||||
jdbcOptions: JDBCOptions): Array[Partition] = {
|
||||
if (partitioning == null || partitioning.numPartitions <= 1 ||
|
||||
partitioning.lowerBound == partitioning.upperBound) {
|
||||
return Array[Partition](JDBCPartition(null, 0))
|
||||
|
@ -78,7 +87,10 @@ private[sql] object JDBCRelation extends Logging {
|
|||
// Overflow and silliness can happen if you subtract then divide.
|
||||
// Here we get a little roundoff, but that's (hopefully) OK.
|
||||
val stride: Long = upperBound / numPartitions - lowerBound / numPartitions
|
||||
val column = partitioning.column
|
||||
|
||||
val column = verifyAndGetNormalizedColumnName(
|
||||
schema, partitioning.column, resolver, jdbcOptions)
|
||||
|
||||
var i: Int = 0
|
||||
var currentValue: Long = lowerBound
|
||||
val ans = new ArrayBuffer[Partition]()
|
||||
|
@ -99,10 +111,57 @@ private[sql] object JDBCRelation extends Logging {
|
|||
}
|
||||
ans.toArray
|
||||
}
|
||||
|
||||
/**
 * Looks up the user-supplied partition column in the JDBC resolved schema and returns the
 * matching field name quoted by the JDBC dialect selected from `jdbcOptions.url`.
 *
 * A field matches when `resolver` accepts either its raw name or its dialect-quoted name
 * against `columnName`. An [[AnalysisException]] is thrown when no field matches.
 */
private def verifyAndGetNormalizedColumnName(
    schema: StructType,
    columnName: String,
    resolver: Resolver,
    jdbcOptions: JDBCOptions): String = {
  val jdbcDialect = JdbcDialects.get(jdbcOptions.url)

  // The user may pass the column either unquoted or already quoted by the dialect.
  def refersTo(fieldName: String): Boolean = {
    resolver(fieldName, columnName) ||
      resolver(jdbcDialect.quoteIdentifier(fieldName), columnName)
  }

  schema.fieldNames.find(refersTo) match {
    case Some(matchedName) =>
      jdbcDialect.quoteIdentifier(matchedName)
    case None =>
      val schemaString = schema.simpleString(Utils.maxNumToStringFields)
      throw new AnalysisException(
        s"User-defined partition column $columnName not found in the JDBC relation: " +
          schemaString)
  }
}
|
||||
|
||||
/**
 * Resolves the Catalyst schema of the JDBC table described by `jdbcOptions`.
 * When the options define `customSchema`, the resolved table schema is passed through
 * `JdbcUtils.getCustomSchema` together with that custom schema and `resolver`.
 *
 * @param resolver function used to determine if two identifiers are equal
 * @param jdbcOptions JDBC options handed to `JDBCRDD.resolveTable`; may define `customSchema`
 * @return resolved Catalyst schema of a JDBC table
 */
def getSchema(resolver: Resolver, jdbcOptions: JDBCOptions): StructType = {
  val resolvedSchema = JDBCRDD.resolveTable(jdbcOptions)
  jdbcOptions.customSchema
    .map(userSchema => JdbcUtils.getCustomSchema(resolvedSchema, userSchema, resolver))
    .getOrElse(resolvedSchema)
}
|
||||
|
||||
/**
 * Builds a [[JDBCRelation]] whose schema is obtained from `getSchema`, using the resolver
 * of the given session's SQL configuration.
 */
def apply(
    parts: Array[Partition],
    jdbcOptions: JDBCOptions)(
    sparkSession: SparkSession): JDBCRelation = {
  val resolver = sparkSession.sessionState.conf.resolver
  val resolvedSchema = JDBCRelation.getSchema(resolver, jdbcOptions)
  JDBCRelation(resolvedSchema, parts, jdbcOptions)(sparkSession)
}
|
||||
}
|
||||
|
||||
private[sql] case class JDBCRelation(
|
||||
parts: Array[Partition], jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession)
|
||||
override val schema: StructType,
|
||||
parts: Array[Partition],
|
||||
jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession)
|
||||
extends BaseRelation
|
||||
with PrunedFilteredScan
|
||||
with InsertableRelation {
|
||||
|
@ -111,15 +170,6 @@ private[sql] case class JDBCRelation(
|
|||
|
||||
override val needConversion: Boolean = false
|
||||
|
||||
override val schema: StructType = {
|
||||
val tableSchema = JDBCRDD.resolveTable(jdbcOptions)
|
||||
jdbcOptions.customSchema match {
|
||||
case Some(customSchema) => JdbcUtils.getCustomSchema(
|
||||
tableSchema, customSchema, sparkSession.sessionState.conf.resolver)
|
||||
case None => tableSchema
|
||||
}
|
||||
}
|
||||
|
||||
// Check if JDBCRDD.compileFilter can accept input filters
|
||||
override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
|
||||
filters.filter(JDBCRDD.compileFilter(_, JdbcDialects.get(jdbcOptions.url)).isEmpty)
|
||||
|
|
|
@ -48,8 +48,10 @@ class JdbcRelationProvider extends CreatableRelationProvider
|
|||
JDBCPartitioningInfo(
|
||||
partitionColumn.get, lowerBound.get, upperBound.get, numPartitions.get)
|
||||
}
|
||||
val parts = JDBCRelation.columnPartition(partitionInfo)
|
||||
JDBCRelation(parts, jdbcOptions)(sqlContext.sparkSession)
|
||||
val resolver = sqlContext.conf.resolver
|
||||
val schema = JDBCRelation.getSchema(resolver, jdbcOptions)
|
||||
val parts = JDBCRelation.columnPartition(schema, partitionInfo, resolver, jdbcOptions)
|
||||
JDBCRelation(schema, parts, jdbcOptions)(sqlContext.sparkSession)
|
||||
}
|
||||
|
||||
override def createRelation(
|
||||
|
|
|
@ -31,8 +31,9 @@ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
|
|||
import org.apache.spark.sql.execution.DataSourceScanExec
|
||||
import org.apache.spark.sql.execution.command.ExplainCommand
|
||||
import org.apache.spark.sql.execution.datasources.LogicalRelation
|
||||
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation, JdbcUtils}
|
||||
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition, JDBCRDD, JDBCRelation, JdbcUtils}
|
||||
import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.sources._
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.sql.types._
|
||||
|
@ -238,6 +239,11 @@ class JDBCSuite extends SparkFunSuite
|
|||
|OPTIONS (url '$url', dbtable 'TEST."mixedCaseCols"', user 'testUser', password 'testPass')
|
||||
""".stripMargin.replaceAll("\n", " "))
|
||||
|
||||
conn.prepareStatement("CREATE TABLE test.partition (THEID INTEGER, `THE ID` INTEGER) " +
|
||||
"AS SELECT 1, 1")
|
||||
.executeUpdate()
|
||||
conn.commit()
|
||||
|
||||
// Untested: IDENTITY, OTHER, UUID, ARRAY, and GEOMETRY types.
|
||||
}
|
||||
|
||||
|
@ -1206,4 +1212,47 @@ class JDBCSuite extends SparkFunSuite
|
|||
}.getMessage
|
||||
assert(errMsg.contains("Statement was canceled or the session timed out"))
|
||||
}
|
||||
|
||||
test("SPARK-24327 verify and normalize a partition column based on a JDBC resolved schema") {
  // Reads TEST.PARTITION split into three partitions on `partColName` and checks that every
  // generated where clause refers to the dialect-quoted `expectedColumnName`.
  def testJdbcPartitionColumn(partColName: String, expectedColumnName: String): Unit = {
    val df = spark.read.format("jdbc")
      .option("url", urlWithUserAndPass)
      .option("dbtable", "TEST.PARTITION")
      .option("partitionColumn", partColName)
      .option("lowerBound", 1)
      .option("upperBound", 4)
      .option("numPartitions", 3)
      .load()

    val quotedPrtColName = testH2Dialect.quoteIdentifier(expectedColumnName)
    df.logicalPlan match {
      case LogicalRelation(JDBCRelation(_, parts, _), _, _, _) =>
        val whereClauses = parts.map(_.asInstanceOf[JDBCPartition].whereClause).toSet
        assert(whereClauses === Set(
          s"$quotedPrtColName < 2 or $quotedPrtColName is null",
          s"$quotedPrtColName >= 2 AND $quotedPrtColName < 3",
          s"$quotedPrtColName >= 3"))
    }
  }

  // Plain, already-quoted, differently-cased and space-containing column names must all
  // be normalized to the quoted schema field name.
  testJdbcPartitionColumn("THEID", "THEID")
  testJdbcPartitionColumn("\"THEID\"", "THEID")
  withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
    testJdbcPartitionColumn("ThEiD", "THEID")
  }
  testJdbcPartitionColumn("THE ID", "THE ID")

  // Expects the load to fail with an AnalysisException naming the unknown partition column.
  def testIncorrectJdbcPartitionColumn(partColName: String): Unit = {
    val errMsg = intercept[AnalysisException] {
      testJdbcPartitionColumn(partColName, "THEID")
    }.getMessage
    assert(errMsg.contains(s"User-defined partition column $partColName not found " +
      "in the JDBC relation:"))
  }

  testIncorrectJdbcPartitionColumn("NoExistingColumn")
  withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
    testIncorrectJdbcPartitionColumn(testH2Dialect.quoteIdentifier("ThEiD"))
  }
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue