[SPARK-24327][SQL] Verify and normalize a partition column name based on the JDBC resolved schema

## What changes were proposed in this pull request?
This PR modifies the JDBC datasource code to verify and normalize a user-specified partition column name against the JDBC resolved schema before building `JDBCRelation`: a column name that differs only in case or quoting is normalized to the schema's field name, and a column missing from the schema is rejected with an `AnalysisException` (see the sketch below).
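
To illustrate the user-facing effect, here is a minimal sketch of a partitioned JDBC read. The connection URL, table, and option values are placeholders (not taken from this patch), and the snippet assumes a spark-shell-style session with an H2 driver on the classpath:

```scala
import org.apache.spark.sql.SparkSession

// Sketch only: the URL, table, and credentials below are placeholders.
val spark = SparkSession.builder().master("local[*]").appName("sketch").getOrCreate()

val df = spark.read
  .format("jdbc")
  .option("url", "jdbc:h2:mem:testdb;user=testUser;password=testPass") // placeholder URL
  .option("dbtable", "TEST.PARTITION")
  // A case-variant or dialect-quoted name is now resolved against the table's schema
  // and normalized (e.g. to the quoted form of THEID) before the per-partition WHERE
  // clauses are generated; a column not present in the schema fails with AnalysisException.
  .option("partitionColumn", "ThEiD")
  .option("lowerBound", 1)
  .option("upperBound", 4)
  .option("numPartitions", 3)
  .load()

// Three partitions, each scoped by a WHERE clause on the normalized column.
println(df.rdd.getNumPartitions)
```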

Closes #20370

## How was this patch tested?
Added tests in `JDBCSuite`.

Author: Takeshi Yamamuro <yamamuro@apache.org>

Closes #21379 from maropu/SPARK-24327.
Takeshi Yamamuro 2018-06-24 23:14:42 -07:00 committed by Xiao Li
parent a5849ad9a3
commit f596ebe4d3
4 changed files with 118 additions and 17 deletions


@@ -100,7 +100,7 @@ private[spark] object Utils extends Logging {
*/
val DEFAULT_MAX_TO_STRING_FIELDS = 25
private def maxNumToStringFields = {
private[spark] def maxNumToStringFields = {
if (SparkEnv.get != null) {
SparkEnv.get.conf.getInt("spark.debug.maxToStringFields", DEFAULT_MAX_TO_STRING_FIELDS)
} else {


@@ -22,10 +22,12 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.Partition
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext}
import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode, SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils
/**
* Instructions on how to partition the table among workers.
@@ -48,10 +50,17 @@ private[sql] object JDBCRelation extends Logging {
* Null value predicate is added to the first partition where clause to include
* the rows with null value for the partitions column.
*
* @param schema resolved schema of a JDBC table
* @param partitioning partition information to generate the where clause for each partition
* @param resolver function used to determine if two identifiers are equal
* @param jdbcOptions JDBC options that contains url
* @return an array of partitions with where clause for each partition
*/
def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = {
def columnPartition(
schema: StructType,
partitioning: JDBCPartitioningInfo,
resolver: Resolver,
jdbcOptions: JDBCOptions): Array[Partition] = {
if (partitioning == null || partitioning.numPartitions <= 1 ||
partitioning.lowerBound == partitioning.upperBound) {
return Array[Partition](JDBCPartition(null, 0))
@@ -78,7 +87,10 @@ private[sql] object JDBCRelation extends Logging {
// Overflow and silliness can happen if you subtract then divide.
// Here we get a little roundoff, but that's (hopefully) OK.
val stride: Long = upperBound / numPartitions - lowerBound / numPartitions
val column = partitioning.column
val column = verifyAndGetNormalizedColumnName(
schema, partitioning.column, resolver, jdbcOptions)
var i: Int = 0
var currentValue: Long = lowerBound
val ans = new ArrayBuffer[Partition]()
@@ -99,10 +111,57 @@
}
ans.toArray
}
// Verify column name based on the JDBC resolved schema
private def verifyAndGetNormalizedColumnName(
schema: StructType,
columnName: String,
resolver: Resolver,
jdbcOptions: JDBCOptions): String = {
val dialect = JdbcDialects.get(jdbcOptions.url)
schema.map(_.name).find { fieldName =>
resolver(fieldName, columnName) ||
resolver(dialect.quoteIdentifier(fieldName), columnName)
}.map(dialect.quoteIdentifier).getOrElse {
throw new AnalysisException(s"User-defined partition column $columnName not " +
s"found in the JDBC relation: ${schema.simpleString(Utils.maxNumToStringFields)}")
}
}
/**
* Takes a (schema, table) specification and returns the table's Catalyst schema.
* If `customSchema` defined in the JDBC options, replaces the schema's dataType with the
* custom schema's type.
*
* @param resolver function used to determine if two identifiers are equal
* @param jdbcOptions JDBC options that contains url, table and other information.
* @return resolved Catalyst schema of a JDBC table
*/
def getSchema(resolver: Resolver, jdbcOptions: JDBCOptions): StructType = {
val tableSchema = JDBCRDD.resolveTable(jdbcOptions)
jdbcOptions.customSchema match {
case Some(customSchema) => JdbcUtils.getCustomSchema(
tableSchema, customSchema, resolver)
case None => tableSchema
}
}
/**
* Resolves a Catalyst schema of a JDBC table and returns [[JDBCRelation]] with the schema.
*/
def apply(
parts: Array[Partition],
jdbcOptions: JDBCOptions)(
sparkSession: SparkSession): JDBCRelation = {
val schema = JDBCRelation.getSchema(sparkSession.sessionState.conf.resolver, jdbcOptions)
JDBCRelation(schema, parts, jdbcOptions)(sparkSession)
}
}
private[sql] case class JDBCRelation(
parts: Array[Partition], jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession)
override val schema: StructType,
parts: Array[Partition],
jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession)
extends BaseRelation
with PrunedFilteredScan
with InsertableRelation {
@@ -111,15 +170,6 @@ private[sql] case class JDBCRelation(
override val needConversion: Boolean = false
override val schema: StructType = {
val tableSchema = JDBCRDD.resolveTable(jdbcOptions)
jdbcOptions.customSchema match {
case Some(customSchema) => JdbcUtils.getCustomSchema(
tableSchema, customSchema, sparkSession.sessionState.conf.resolver)
case None => tableSchema
}
}
// Check if JDBCRDD.compileFilter can accept input filters
override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
filters.filter(JDBCRDD.compileFilter(_, JdbcDialects.get(jdbcOptions.url)).isEmpty)


@@ -48,8 +48,10 @@ class JdbcRelationProvider extends CreatableRelationProvider
JDBCPartitioningInfo(
partitionColumn.get, lowerBound.get, upperBound.get, numPartitions.get)
}
val parts = JDBCRelation.columnPartition(partitionInfo)
JDBCRelation(parts, jdbcOptions)(sqlContext.sparkSession)
val resolver = sqlContext.conf.resolver
val schema = JDBCRelation.getSchema(resolver, jdbcOptions)
val parts = JDBCRelation.columnPartition(schema, partitionInfo, resolver, jdbcOptions)
JDBCRelation(schema, parts, jdbcOptions)(sqlContext.sparkSession)
}
override def createRelation(


@@ -31,8 +31,9 @@ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.execution.DataSourceScanExec
import org.apache.spark.sql.execution.command.ExplainCommand
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation, JdbcUtils}
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition, JDBCRDD, JDBCRelation, JdbcUtils}
import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
@@ -238,6 +239,11 @@ class JDBCSuite extends SparkFunSuite
|OPTIONS (url '$url', dbtable 'TEST."mixedCaseCols"', user 'testUser', password 'testPass')
""".stripMargin.replaceAll("\n", " "))
conn.prepareStatement("CREATE TABLE test.partition (THEID INTEGER, `THE ID` INTEGER) " +
"AS SELECT 1, 1")
.executeUpdate()
conn.commit()
// Untested: IDENTITY, OTHER, UUID, ARRAY, and GEOMETRY types.
}
@@ -1206,4 +1212,47 @@ class JDBCSuite extends SparkFunSuite
}.getMessage
assert(errMsg.contains("Statement was canceled or the session timed out"))
}
test("SPARK-24327 verify and normalize a partition column based on a JDBC resolved schema") {
def testJdbcParitionColumn(partColName: String, expectedColumnName: String): Unit = {
val df = spark.read.format("jdbc")
.option("url", urlWithUserAndPass)
.option("dbtable", "TEST.PARTITION")
.option("partitionColumn", partColName)
.option("lowerBound", 1)
.option("upperBound", 4)
.option("numPartitions", 3)
.load()
val quotedPrtColName = testH2Dialect.quoteIdentifier(expectedColumnName)
df.logicalPlan match {
case LogicalRelation(JDBCRelation(_, parts, _), _, _, _) =>
val whereClauses = parts.map(_.asInstanceOf[JDBCPartition].whereClause).toSet
assert(whereClauses === Set(
s"$quotedPrtColName < 2 or $quotedPrtColName is null",
s"$quotedPrtColName >= 2 AND $quotedPrtColName < 3",
s"$quotedPrtColName >= 3"))
}
}
testJdbcParitionColumn("THEID", "THEID")
testJdbcParitionColumn("\"THEID\"", "THEID")
withSQLConf("spark.sql.caseSensitive" -> "false") {
testJdbcParitionColumn("ThEiD", "THEID")
}
testJdbcParitionColumn("THE ID", "THE ID")
def testIncorrectJdbcPartitionColumn(partColName: String): Unit = {
val errMsg = intercept[AnalysisException] {
testJdbcParitionColumn(partColName, "THEID")
}.getMessage
assert(errMsg.contains(s"User-defined partition column $partColName not found " +
"in the JDBC relation:"))
}
testIncorrectJdbcPartitionColumn("NoExistingColumn")
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
testIncorrectJdbcPartitionColumn(testH2Dialect.quoteIdentifier("ThEiD"))
}
}
}