[SPARK-33865][SPARK-36202][SQL] When using Hive DDL, we need to check the Avro schema too

### What changes were proposed in this pull request?
Unify the schema-check code across FileFormat implementations, and also check Avro field names in CREATE TABLE DDL.
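
For context, a minimal sketch of the user-visible effect, assuming a `SparkSession` named `spark` with the Avro data source on the classpath (the scenario mirrors the new test in `AvroSuite` below):

```scala
import java.nio.file.Files
import org.apache.spark.sql.AnalysisException

spark.range(1).createTempView("v")
val dir = Files.createTempDirectory("avro_ddl_check").toString

try {
  // The un-aliased IF(...) expression produces a column name containing
  // parentheses and spaces, which Avro cannot store, so the CTAS DDL now
  // fails during analysis instead of later in the write path.
  spark.sql(
    s"""
       |CREATE TABLE test_ddl USING AVRO
       |LOCATION '$dir'
       |AS SELECT id, IF(id = 1, 1, 0) FROM v""".stripMargin)
} catch {
  case e: AnalysisException =>
    println(e.getMessage) // Column name "..." contains invalid character(s) ...
}
```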

### Why are the changes needed?
Code refactoring; it also closes the gap where Avro field names were not validated at CREATE TABLE time.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Added a test case in `AvroSuite` ("SPARK-33865: CREATE TABLE DDL with avro should check col name").

Closes #33441 from AngersZhuuuu/SPARK-36202.

Authored-by: Angerszhuuuu <angers.zhu@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit 86f44578e5)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Angerszhuuuu 2021-07-28 14:04:24 +08:00 committed by Wenchen Fan
parent cd6b303d0f
commit 2f4f7936fd
8 changed files with 104 additions and 37 deletions


@@ -153,6 +153,18 @@ private[sql] class AvroFileFormat extends FileFormat
}
override def supportDataType(dataType: DataType): Boolean = AvroUtils.supportsDataType(dataType)
override def supportFieldName(name: String): Boolean = {
if (name.length == 0) {
false
} else {
name.zipWithIndex.forall {
case (c, 0) if !Character.isLetter(c) && c != '_' => false
case (c, _) if !Character.isLetterOrDigit(c) && c != '_' => false
case _ => true
}
}
}
}
private[avro] object AvroFileFormat {
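
For reference, the check above encodes Avro's field-naming rule: the first character must be a letter or underscore, every later character a letter, digit, or underscore, and empty names are rejected. An illustrative, standalone restatement of the same predicate (not part of the patch):

```scala
// Same logic as AvroFileFormat.supportFieldName above, restated for clarity.
def avroFieldNameOk(name: String): Boolean =
  name.nonEmpty &&
    (Character.isLetter(name.head) || name.head == '_') &&
    name.tail.forall(c => Character.isLetterOrDigit(c) || c == '_')

assert(avroFieldNameOk("id"))
assert(avroFieldNameOk("_tmp_1"))
assert(!avroFieldNameOk(""))                       // empty names are invalid
assert(!avroFieldNameOk("1st_col"))                // must not start with a digit
assert(!avroFieldNameOk("(IF((ID = 1), 1, 0))"))   // parentheses and spaces are invalid
```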


@@ -2158,6 +2158,36 @@ abstract class AvroSuite
}
}
}
test("SPARK-33865: CREATE TABLE DDL with avro should check col name") {
withTable("test_ddl") {
withView("v") {
spark.range(1).createTempView("v")
withTempDir { dir =>
val e = intercept[AnalysisException] {
sql(
s"""
|CREATE TABLE test_ddl USING AVRO
|LOCATION '${dir}'
|AS SELECT ID, IF(ID=1,1,0) FROM v""".stripMargin)
}.getMessage
assert(e.contains("Column name \"(IF((ID = 1), 1, 0))\" contains invalid character(s)."))
}
withTempDir { dir =>
spark.sql(
s"""
|CREATE TABLE test_ddl USING AVRO
|LOCATION '${dir}'
|AS SELECT ID, IF(ID=1,ID,0) AS A, ABS(ID) AS B
|FROM v""".stripMargin)
val expectedSchema = StructType(Seq(StructField("ID", LongType, true),
StructField("A", LongType, true), StructField("B", LongType, true)))
assert(spark.table("test_ddl").schema == expectedSchema)
}
}
}
}
}
class AvroV1Suite extends AvroSuite {


@@ -29,6 +29,7 @@ import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs._
import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.RDD_PARALLEL_LISTING_THRESHOLD
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
@@ -40,9 +41,8 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.connector.catalog.{CatalogV2Util, TableCatalog}
import org.apache.spark.sql.connector.catalog.SupportsNamespaces._
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter
import org.apache.spark.sql.execution.datasources.{DataSource, DataSourceUtils, FileFormat, HadoopFsRelation, LogicalRelation}
import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.PartitioningUtils
@@ -860,7 +860,7 @@ case class AlterTableSetLocationCommand(
}
object DDLUtils {
object DDLUtils extends Logging {
val HIVE_PROVIDER = "hive"
def isHiveTable(table: CatalogTable): Boolean = {
@@ -933,19 +933,38 @@ object DDLUtils {
case HIVE_PROVIDER =>
val serde = table.storage.serde
if (serde == HiveSerDe.sourceToSerDe("orc").get.serde) {
OrcFileFormat.checkFieldNames(schema)
checkDataColNames("orc", schema)
} else if (serde == HiveSerDe.sourceToSerDe("parquet").get.serde ||
serde == Some("parquet.hive.serde.ParquetHiveSerDe") ||
serde == Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) {
ParquetSchemaConverter.checkFieldNames(schema)
checkDataColNames("parquet", schema)
} else if (serde == HiveSerDe.sourceToSerDe("avro").get.serde) {
checkDataColNames("avro", schema)
}
case "parquet" => ParquetSchemaConverter.checkFieldNames(schema)
case "orc" => OrcFileFormat.checkFieldNames(schema)
case "parquet" => checkDataColNames("parquet", schema)
case "orc" => checkDataColNames("orc", schema)
case "avro" => checkDataColNames("avro", schema)
case _ =>
}
}
}
def checkDataColNames(provider: String, schema: StructType): Unit = {
val source = try {
DataSource.lookupDataSource(provider, SQLConf.get).getConstructor().newInstance()
} catch {
case e: Throwable =>
logError(s"Failed to find data source: $provider when check data column names.", e)
return
}
source match {
case f: FileFormat => DataSourceUtils.checkFieldNames(f, schema)
case f: FileDataSourceV2 =>
DataSourceUtils.checkFieldNames(f.fallbackFileFormat.newInstance(), schema)
case _ =>
}
}
/**
* Throws an exception if outputPath tries to overwrite inputPath.
*/
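
A hedged sketch of exercising the new helper directly; `DDLUtils.checkDataColNames` is internal Spark API, so this is illustration only and assumes the built-in "parquet" provider is resolvable on the classpath:

```scala
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

// Parquet rejects field names containing ";", so the check throws.
try {
  DDLUtils.checkDataColNames("parquet", StructType(Seq(StructField("a;b", IntegerType))))
} catch {
  case e: AnalysisException => println(e.getMessage)
}

// A well-formed schema passes silently.
DDLUtils.checkDataColNames("parquet", StructType(Seq(StructField("a_b", IntegerType))))
```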


@@ -58,6 +58,22 @@ object DataSourceUtils {
Serialization.read[Seq[String]](str)
}
/**
* Verify that every field name in the schema is supported by the data source. This
* verification should be done on the driver side.
*/
def checkFieldNames(format: FileFormat, schema: StructType): Unit = {
schema.foreach { field =>
if (!format.supportFieldName(field.name)) {
throw QueryCompilationErrors.columnNameContainsInvalidCharactersError(field.name)
}
field.dataType match {
case s: StructType => checkFieldNames(format, s)
case _ =>
}
}
}
/**
* Verify that the schema is supported by the data source. This verification should be done
* on the driver side.


@@ -163,6 +163,12 @@ trait FileFormat {
* By default all data types are supported.
*/
def supportDataType(dataType: DataType): Boolean = true
/**
* Returns whether this format supports the given field name in the read/write path.
* By default all field names are supported.
*/
def supportFieldName(name: String): Boolean = true
}
/**
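
The hook makes the naming rule pluggable per format: a format only needs to override this one method. A hypothetical example (the `NoDollarParquetFileFormat` name and the extra "$" restriction are made up for illustration):

```scala
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat

// Hypothetical format that additionally bans "$" in field names.
class NoDollarParquetFileFormat extends ParquetFileFormat {
  override def supportFieldName(name: String): Boolean =
    super.supportFieldName(name) && !name.contains("$")
}
```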


@@ -36,31 +36,12 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.util.{SerializableConfiguration, Utils}
private[sql] object OrcFileFormat {
private def checkFieldName(name: String): Unit = {
try {
TypeDescription.fromString(s"struct<`$name`:int>")
} catch {
case _: IllegalArgumentException =>
throw QueryCompilationErrors.columnNameContainsInvalidCharactersError(name)
}
}
def checkFieldNames(schema: StructType): Unit = {
schema.foreach { field =>
checkFieldName(field.name)
field.dataType match {
case s: StructType => checkFieldNames(s)
case _ =>
}
}
}
def getQuotedSchemaString(dataType: DataType): String = dataType match {
case _: AtomicType => dataType.catalogString
@@ -279,4 +260,13 @@ class OrcFileFormat
case _ => false
}
override def supportFieldName(name: String): Boolean = {
try {
TypeDescription.fromString(s"struct<`$name`:int>")
true
} catch {
case _: IllegalArgumentException => false
}
}
}


@@ -386,6 +386,10 @@ class ParquetFileFormat
case _ => false
}
override def supportFieldName(name: String): Boolean = {
!name.matches(".*[ ,;{}()\n\t=].*")
}
}
object ParquetFileFormat extends Logging {
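
The Parquet rule above rejects any name containing a space, comma, semicolon, brace, parenthesis, newline, tab, or '='. A quick illustrative run of the same regex against a few sample names:

```scala
// Illustrative only: the same character class used in supportFieldName above.
val forbidden = ".*[ ,;{}()\n\t=].*"
Seq("id", "col_1", "a b", "x=y", "m;n", "f(x)").foreach { name =>
  println(s"'$name' supported by Parquet: ${!name.matches(forbidden)}")
}
```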


@@ -591,16 +591,6 @@ private[sql] object ParquetSchemaConverter {
}
}
def checkFieldNames(schema: StructType): Unit = {
schema.foreach { field =>
checkFieldName(field.name)
field.dataType match {
case s: StructType => checkFieldNames(s)
case _ =>
}
}
}
def checkConversionRequirement(f: => Boolean, message: String): Unit = {
if (!f) {
throw new AnalysisException(message)