[SPARK-33865][SPARK-36202][SQL] When handling Hive DDL, we need to check the Avro schema too
### What changes were proposed in this pull request?
Unify the schema field-name check code across `FileFormat` implementations, and check Avro schema field names at CREATE TABLE DDL time as well.
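For illustration, here is a hypothetical spark-shell session based on the new `AvroSuite` test added in this patch; the table name and LOCATION are placeholders, not part of the change itself:

```scala
// Hypothetical spark-shell session mirroring the new AvroSuite test.
spark.range(1).createTempView("v")
spark.sql(
  """CREATE TABLE test_ddl USING AVRO
    |LOCATION '/tmp/test_ddl'
    |AS SELECT ID, IF(ID=1,1,0) FROM v""".stripMargin)
// With this change the CTAS now fails at analysis time with something like:
//   org.apache.spark.sql.AnalysisException:
//   Column name "(IF((ID = 1), 1, 0))" contains invalid character(s). ...
// Aliasing the expression (e.g. IF(ID=1,1,0) AS flag) makes the statement pass.
```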
### Why are the changes needed?
Refactor duplicated field-name checks into a single code path, and extend the check to Avro, whose field names were previously not validated at DDL time.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added a new test, "SPARK-33865: CREATE TABLE DDL with avro should check col name", in `AvroSuite`.
Closes #33441 from AngersZhuuuu/SPARK-36202.
Authored-by: Angerszhuuuu <angers.zhu@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit 86f44578e5)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
@@ -153,6 +153,18 @@ private[sql] class AvroFileFormat extends FileFormat
   }
 
   override def supportDataType(dataType: DataType): Boolean = AvroUtils.supportsDataType(dataType)
+
+  override def supportFieldName(name: String): Boolean = {
+    if (name.length == 0) {
+      false
+    } else {
+      name.zipWithIndex.forall {
+        case (c, 0) if !Character.isLetter(c) && c != '_' => false
+        case (c, _) if !Character.isLetterOrDigit(c) && c != '_' => false
+        case _ => true
+      }
+    }
+  }
 }
 
 private[avro] object AvroFileFormat {
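The rule above mirrors the Avro naming convention: a name must start with a letter or underscore and contain only letters, digits, and underscores. A minimal standalone sketch of the same predicate (not part of the patch, runnable outside Spark) applied to a few sample names:

```scala
// Standalone sketch of the Avro field-name rule used by AvroFileFormat above.
object AvroFieldNameCheckSketch {
  def supportFieldName(name: String): Boolean = {
    if (name.length == 0) {
      false
    } else {
      name.zipWithIndex.forall {
        case (c, 0) if !Character.isLetter(c) && c != '_' => false      // first char: letter or '_'
        case (c, _) if !Character.isLetterOrDigit(c) && c != '_' => false // rest: letter, digit, or '_'
        case _ => true
      }
    }
  }

  def main(args: Array[String]): Unit = {
    Seq("id", "_tmp", "col_1", "1col", "a b", "(IF((ID = 1), 1, 0))")
      .foreach(n => println(s"$n -> ${supportFieldName(n)}"))
    // id -> true, _tmp -> true, col_1 -> true, 1col -> false, a b -> false, (IF(...)) -> false
  }
}
```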
@@ -2158,6 +2158,36 @@ abstract class AvroSuite
       }
     }
   }
+
+  test("SPARK-33865: CREATE TABLE DDL with avro should check col name") {
+    withTable("test_ddl") {
+      withView("v") {
+        spark.range(1).createTempView("v")
+        withTempDir { dir =>
+          val e = intercept[AnalysisException] {
+            sql(
+              s"""
+                 |CREATE TABLE test_ddl USING AVRO
+                 |LOCATION '${dir}'
+                 |AS SELECT ID, IF(ID=1,1,0) FROM v""".stripMargin)
+          }.getMessage
+          assert(e.contains("Column name \"(IF((ID = 1), 1, 0))\" contains invalid character(s)."))
+        }
+
+        withTempDir { dir =>
+          spark.sql(
+            s"""
+               |CREATE TABLE test_ddl USING AVRO
+               |LOCATION '${dir}'
+               |AS SELECT ID, IF(ID=1,ID,0) AS A, ABS(ID) AS B
+               |FROM v""".stripMargin)
+          val expectedSchema = StructType(Seq(StructField("ID", LongType, true),
+            StructField("A", LongType, true), StructField("B", LongType, true)))
+          assert(spark.table("test_ddl").schema == expectedSchema)
+        }
+      }
+    }
+  }
 }
 
 class AvroV1Suite extends AvroSuite {
@@ -29,6 +29,7 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs._
 import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
 
+import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.RDD_PARALLEL_LISTING_THRESHOLD
 import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.catalyst.TableIdentifier
@@ -40,9 +41,8 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.connector.catalog.{CatalogV2Util, TableCatalog}
 import org.apache.spark.sql.connector.catalog.SupportsNamespaces._
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
-import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
-import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter
+import org.apache.spark.sql.execution.datasources.{DataSource, DataSourceUtils, FileFormat, HadoopFsRelation, LogicalRelation}
+import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
 import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.PartitioningUtils
@@ -860,7 +860,7 @@ case class AlterTableSetLocationCommand(
 }
 
 
-object DDLUtils {
+object DDLUtils extends Logging {
   val HIVE_PROVIDER = "hive"
 
   def isHiveTable(table: CatalogTable): Boolean = {
@@ -933,19 +933,38 @@ object DDLUtils {
         case HIVE_PROVIDER =>
           val serde = table.storage.serde
           if (serde == HiveSerDe.sourceToSerDe("orc").get.serde) {
-            OrcFileFormat.checkFieldNames(schema)
+            checkDataColNames("orc", schema)
           } else if (serde == HiveSerDe.sourceToSerDe("parquet").get.serde ||
             serde == Some("parquet.hive.serde.ParquetHiveSerDe") ||
             serde == Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) {
-            ParquetSchemaConverter.checkFieldNames(schema)
+            checkDataColNames("parquet", schema)
+          } else if (serde == HiveSerDe.sourceToSerDe("avro").get.serde) {
+            checkDataColNames("avro", schema)
           }
-        case "parquet" => ParquetSchemaConverter.checkFieldNames(schema)
-        case "orc" => OrcFileFormat.checkFieldNames(schema)
+        case "parquet" => checkDataColNames("parquet", schema)
+        case "orc" => checkDataColNames("orc", schema)
+        case "avro" => checkDataColNames("avro", schema)
         case _ =>
       }
     }
   }
+
+  def checkDataColNames(provider: String, schema: StructType): Unit = {
+    val source = try {
+      DataSource.lookupDataSource(provider, SQLConf.get).getConstructor().newInstance()
+    } catch {
+      case e: Throwable =>
+        logError(s"Failed to find data source: $provider when check data column names.", e)
+        return
+    }
+    source match {
+      case f: FileFormat => DataSourceUtils.checkFieldNames(f, schema)
+      case f: FileDataSourceV2 =>
+        DataSourceUtils.checkFieldNames(f.fallbackFileFormat.newInstance(), schema)
+      case _ =>
+    }
+  }
 
   /**
    * Throws exception if outputPath tries to overwrite inputpath.
    */
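This is the piece that covers the Hive DDL path from SPARK-33865: tables created through the Hive `avro` serde now go through the same field-name check as `USING AVRO` tables. An illustrative sketch of the intended effect follows; it assumes a Hive-enabled session, and the exact message text may differ:

```scala
// Illustrative only: assumes a SparkSession built with enableHiveSupport().
spark.sql(
  """CREATE TABLE hive_avro_tbl STORED AS AVRO
    |AS SELECT id, IF(id = 1, 1, 0) FROM range(2)""".stripMargin)
// Before this change the invalid auto-generated column name would only surface
// later, once the Avro serde tried to build its schema; now the CTAS is
// rejected during analysis, e.g.:
//   org.apache.spark.sql.AnalysisException:
//   Column name "(IF((id = 1), 1, 0))" contains invalid character(s). ...
```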
@@ -58,6 +58,22 @@ object DataSourceUtils {
     Serialization.read[Seq[String]](str)
   }
 
+  /**
+   * Verify if the field name is supported in datasource. This verification should be done
+   * in a driver side.
+   */
+  def checkFieldNames(format: FileFormat, schema: StructType): Unit = {
+    schema.foreach { field =>
+      if (!format.supportFieldName(field.name)) {
+        throw QueryCompilationErrors.columnNameContainsInvalidCharactersError(field.name)
+      }
+      field.dataType match {
+        case s: StructType => checkFieldNames(format, s)
+        case _ =>
+      }
+    }
+  }
+
   /**
    * Verify if the schema is supported in datasource. This verification should be done
    * in a driver side.
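The check walks nested structs, so an invalid name buried inside a struct column is caught as well. Below is a self-contained sketch of that recursion (not from the patch): it collects offending names instead of throwing, and uses a toy "no whitespace" predicate so it does not need a `FileFormat` instance:

```scala
import org.apache.spark.sql.types._

// Sketch of the same recursive walk as DataSourceUtils.checkFieldNames.
object NestedNameCheckSketch {
  def badNames(schema: StructType, ok: String => Boolean): Seq[String] =
    schema.fields.toSeq.flatMap { field =>
      val here = if (ok(field.name)) Nil else Seq(field.name)
      val below = field.dataType match {
        case s: StructType => badNames(s, ok)   // recurse into nested structs
        case _ => Nil
      }
      here ++ below
    }

  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(
      StructField("id", LongType),
      StructField("nested", StructType(Seq(
        StructField("a b", StringType))))))     // invalid name inside a struct
    // Prints the offending names, here just "a b".
    println(badNames(schema, n => !n.exists(_.isWhitespace)))
  }
}
```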
@@ -163,6 +163,12 @@ trait FileFormat {
    * By default all data types are supported.
    */
   def supportDataType(dataType: DataType): Boolean = true
+
+  /**
+   * Returns whether this format supports the given filed name in read/write path.
+   * By default all field name is supported.
+   */
+  def supportFieldName(name: String): Boolean = true
 }
 
 /**
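Since `supportFieldName` defaults to `true`, existing `FileFormat` implementations are unaffected unless they opt in. A hypothetical third-party format could override just this hook; the class below is a sketch only, with the other mandatory members stubbed out and a made-up rule:

```scala
import org.apache.hadoop.fs.FileStatus
import org.apache.hadoop.mapreduce.Job

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory}
import org.apache.spark.sql.types.StructType

// Hypothetical format: only the new supportFieldName hook is of interest here.
class StrictNamesFileFormat extends FileFormat {
  // Mandatory members of the trait, stubbed for the sketch.
  override def inferSchema(
      sparkSession: SparkSession,
      options: Map[String, String],
      files: Seq[FileStatus]): Option[StructType] = None

  override def prepareWrite(
      sparkSession: SparkSession,
      job: Job,
      options: Map[String, String],
      dataSchema: StructType): OutputWriterFactory =
    throw new UnsupportedOperationException("write path not implemented in this sketch")

  // Made-up rule: reject any field name containing whitespace.
  override def supportFieldName(name: String): Boolean = !name.exists(_.isWhitespace)
}
```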
@@ -36,31 +36,12 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
-import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.{SerializableConfiguration, Utils}
 
 private[sql] object OrcFileFormat {
-  private def checkFieldName(name: String): Unit = {
-    try {
-      TypeDescription.fromString(s"struct<`$name`:int>")
-    } catch {
-      case _: IllegalArgumentException =>
-        throw QueryCompilationErrors.columnNameContainsInvalidCharactersError(name)
-    }
-  }
-
-  def checkFieldNames(schema: StructType): Unit = {
-    schema.foreach { field =>
-      checkFieldName(field.name)
-      field.dataType match {
-        case s: StructType => checkFieldNames(s)
-        case _ =>
-      }
-    }
-  }
-
   def getQuotedSchemaString(dataType: DataType): String = dataType match {
     case _: AtomicType => dataType.catalogString
@@ -279,4 +260,13 @@ class OrcFileFormat
 
     case _ => false
   }
+
+  override def supportFieldName(name: String): Boolean = {
+    try {
+      TypeDescription.fromString(s"struct<`$name`:int>")
+      true
+    } catch {
+      case _: IllegalArgumentException => false
+    }
+  }
 }
@@ -386,6 +386,10 @@ class ParquetFileFormat
 
     case _ => false
   }
+
+  override def supportFieldName(name: String): Boolean = {
+    !name.matches(".*[ ,;{}()\n\t=].*")
+  }
 }
 
 object ParquetFileFormat extends Logging {
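For reference, a tiny standalone sketch (not part of the patch) showing which names the Parquet rule above rejects:

```scala
// Sketch only: the same regex as ParquetFileFormat.supportFieldName.
object ParquetNameRuleSketch {
  def supported(name: String): Boolean = !name.matches(".*[ ,;{}()\n\t=].*")

  def main(args: Array[String]): Unit = {
    Seq("id", "col_1", "a b", "f(x)", "k=v", "x;y")
      .foreach(n => println(s"$n -> ${supported(n)}"))
    // id -> true, col_1 -> true, a b -> false, f(x) -> false, k=v -> false, x;y -> false
  }
}
```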
@@ -591,16 +591,6 @@ private[sql] object ParquetSchemaConverter {
     }
   }
 
-  def checkFieldNames(schema: StructType): Unit = {
-    schema.foreach { field =>
-      checkFieldName(field.name)
-      field.dataType match {
-        case s: StructType => checkFieldNames(s)
-        case _ =>
-      }
-    }
-  }
-
   def checkConversionRequirement(f: => Boolean, message: String): Unit = {
     if (!f) {
       throw new AnalysisException(message)