From 2f4f7936fdc06a84abbb264d4f7899b9084e606c Mon Sep 17 00:00:00 2001
From: Angerszhuuuu
Date: Wed, 28 Jul 2021 14:04:24 +0800
Subject: [PATCH] [SPARK-33865][SPARK-36202][SQL] When HiveDDL, we need to check avro schema too

### What changes were proposed in this pull request?
Unify the schema check code of FileFormat and also check Avro schema field names at CREATE TABLE DDL time.

### Why are the changes needed?
Refactor code.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Not needed beyond the new test added in AvroSuite.

Closes #33441 from AngersZhuuuu/SPARK-36202.

Authored-by: Angerszhuuuu
Signed-off-by: Wenchen Fan
(cherry picked from commit 86f44578e5204487930f334aecdd97255681a3fc)
Signed-off-by: Wenchen Fan
---
 .../spark/sql/avro/AvroFileFormat.scala       | 12 +++++++
 .../org/apache/spark/sql/avro/AvroSuite.scala | 30 ++++++++++++++++
 .../spark/sql/execution/command/ddl.scala     | 35 ++++++++++++++-----
 .../datasources/DataSourceUtils.scala         | 16 +++++++++
 .../execution/datasources/FileFormat.scala    |  6 ++++
 .../datasources/orc/OrcFileFormat.scala       | 28 +++++----------
 .../parquet/ParquetFileFormat.scala           |  4 +++
 .../parquet/ParquetSchemaConverter.scala      | 10 ------
 8 files changed, 104 insertions(+), 37 deletions(-)

diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
index c2cea41354..398cb02242 100755
--- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
+++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
@@ -153,6 +153,18 @@ private[sql] class AvroFileFormat extends FileFormat
   }
 
   override def supportDataType(dataType: DataType): Boolean = AvroUtils.supportsDataType(dataType)
+
+  override def supportFieldName(name: String): Boolean = {
+    if (name.length == 0) {
+      false
+    } else {
+      name.zipWithIndex.forall {
+        case (c, 0) if !Character.isLetter(c) && c != '_' => false
+        case (c, _) if !Character.isLetterOrDigit(c) && c != '_' => false
+        case _ => true
+      }
+    }
+  }
 }
 
 private[avro] object AvroFileFormat {
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
index ffad851132..f93c61a424 100644
--- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
+++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
@@ -2158,6 +2158,36 @@ abstract class AvroSuite
       }
     }
   }
+
+  test("SPARK-33865: CREATE TABLE DDL with avro should check col name") {
+    withTable("test_ddl") {
+      withView("v") {
+        spark.range(1).createTempView("v")
+        withTempDir { dir =>
+          val e = intercept[AnalysisException] {
+            sql(
+              s"""
+                 |CREATE TABLE test_ddl USING AVRO
+                 |LOCATION '${dir}'
+                 |AS SELECT ID, IF(ID=1,1,0) FROM v""".stripMargin)
+          }.getMessage
+          assert(e.contains("Column name \"(IF((ID = 1), 1, 0))\" contains invalid character(s)."))
+        }
+
+        withTempDir { dir =>
+          spark.sql(
+            s"""
+               |CREATE TABLE test_ddl USING AVRO
+               |LOCATION '${dir}'
+               |AS SELECT ID, IF(ID=1,ID,0) AS A, ABS(ID) AS B
+               |FROM v""".stripMargin)
+          val expectedSchema = StructType(Seq(StructField("ID", LongType, true),
+            StructField("A", LongType, true), StructField("B", LongType, true)))
+          assert(spark.table("test_ddl").schema == expectedSchema)
+        }
+      }
+    }
+  }
 }
 
 class AvroV1Suite extends AvroSuite {
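For quick reference, the naming rule that the new AvroFileFormat.supportFieldName override enforces (and that the SPARK-33865 test above relies on) can be reproduced standalone. The snippet below is illustrative only, needs nothing but a plain JVM, and mirrors rather than calls the patched code:

object AvroFieldNameRuleSketch {
  // Same rule as the override added above: the first character must be a letter
  // or '_', and every later character a letter, digit, or '_'.
  def isValidAvroFieldName(name: String): Boolean =
    name.nonEmpty && name.zipWithIndex.forall {
      case (c, 0) => Character.isLetter(c) || c == '_'
      case (c, _) => Character.isLetterOrDigit(c) || c == '_'
    }

  def main(args: Array[String]): Unit = {
    assert(isValidAvroFieldName("ID"))
    assert(isValidAvroFieldName("_tmp_1"))
    assert(!isValidAvroFieldName("(IF((ID = 1), 1, 0))")) // the generated name rejected in the test above
    assert(!isValidAvroFieldName("1st_col"))              // must not start with a digit
  }
}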
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index 140f9d7dbe..ea1b6566fa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -29,6 +29,7 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs._
 import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
 
+import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.RDD_PARALLEL_LISTING_THRESHOLD
 import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.catalyst.TableIdentifier
@@ -40,9 +41,8 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.connector.catalog.{CatalogV2Util, TableCatalog}
 import org.apache.spark.sql.connector.catalog.SupportsNamespaces._
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
-import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
-import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter
+import org.apache.spark.sql.execution.datasources.{DataSource, DataSourceUtils, FileFormat, HadoopFsRelation, LogicalRelation}
+import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
 import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.PartitioningUtils
@@ -860,7 +860,7 @@ case class AlterTableSetLocationCommand(
 }
 
 
-object DDLUtils {
+object DDLUtils extends Logging {
   val HIVE_PROVIDER = "hive"
 
   def isHiveTable(table: CatalogTable): Boolean = {
@@ -933,19 +933,38 @@ object DDLUtils {
         case HIVE_PROVIDER =>
           val serde = table.storage.serde
           if (serde == HiveSerDe.sourceToSerDe("orc").get.serde) {
-            OrcFileFormat.checkFieldNames(schema)
+            checkDataColNames("orc", schema)
           } else if (serde == HiveSerDe.sourceToSerDe("parquet").get.serde ||
             serde == Some("parquet.hive.serde.ParquetHiveSerDe") ||
             serde == Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) {
-            ParquetSchemaConverter.checkFieldNames(schema)
+            checkDataColNames("parquet", schema)
+          } else if (serde == HiveSerDe.sourceToSerDe("avro").get.serde) {
+            checkDataColNames("avro", schema)
           }
-        case "parquet" => ParquetSchemaConverter.checkFieldNames(schema)
-        case "orc" => OrcFileFormat.checkFieldNames(schema)
+        case "parquet" => checkDataColNames("parquet", schema)
+        case "orc" => checkDataColNames("orc", schema)
+        case "avro" => checkDataColNames("avro", schema)
         case _ =>
       }
     }
   }
 
+  def checkDataColNames(provider: String, schema: StructType): Unit = {
+    val source = try {
+      DataSource.lookupDataSource(provider, SQLConf.get).getConstructor().newInstance()
+    } catch {
+      case e: Throwable =>
+        logError(s"Failed to find data source: $provider when check data column names.", e)
+        return
+    }
+    source match {
+      case f: FileFormat => DataSourceUtils.checkFieldNames(f, schema)
+      case f: FileDataSourceV2 =>
+        DataSourceUtils.checkFieldNames(f.fallbackFileFormat.newInstance(), schema)
+      case _ =>
+    }
+  }
+
   /**
    * Throws exception if outputPath tries to overwrite inputpath.
   */
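The new DDLUtils.checkDataColNames(provider, schema) overload looks the provider up via DataSource.lookupDataSource and then delegates to that format's field-name rule. A rough driver-side sketch of that behaviour follows; it assumes Spark's sql module is on the classpath, and the object and column names are purely illustrative:

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.types.{LongType, StructField, StructType}

object CheckDataColNamesSketch {
  def main(args: Array[String]): Unit = {
    // "(IF((ID = 1), 1, 0))" contains spaces and parentheses, so the Parquet
    // rule added later in this patch rejects it and an AnalysisException is
    // raised via columnNameContainsInvalidCharactersError.
    val badSchema = StructType(Seq(StructField("(IF((ID = 1), 1, 0))", LongType)))
    try {
      DDLUtils.checkDataColNames("parquet", badSchema)
    } catch {
      case e: AnalysisException => println(e.getMessage)
    }
  }
}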
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
index 2b10e4efd9..b562d44dec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
@@ -58,6 +58,22 @@ object DataSourceUtils {
     Serialization.read[Seq[String]](str)
   }
 
+  /**
+   * Verify if the field name is supported in datasource. This verification should be done
+   * in a driver side.
+   */
+  def checkFieldNames(format: FileFormat, schema: StructType): Unit = {
+    schema.foreach { field =>
+      if (!format.supportFieldName(field.name)) {
+        throw QueryCompilationErrors.columnNameContainsInvalidCharactersError(field.name)
+      }
+      field.dataType match {
+        case s: StructType => checkFieldNames(format, s)
+        case _ =>
+      }
+    }
+  }
+
   /**
    * Verify if the schema is supported in datasource. This verification should be done
    * in a driver side.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala
index 7fd48caa65..beb1f4d38e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala
@@ -163,6 +163,12 @@ trait FileFormat {
    * By default all data types are supported.
    */
   def supportDataType(dataType: DataType): Boolean = true
+
+  /**
+   * Returns whether this format supports the given field name in read/write path.
+   * By default all field names are supported.
+   */
+  def supportFieldName(name: String): Boolean = true
 }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
index 9024c7809f..85c0ff01cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
@@ -36,31 +36,12 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
-import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.{SerializableConfiguration, Utils}
 
 private[sql] object OrcFileFormat {
-  private def checkFieldName(name: String): Unit = {
-    try {
-      TypeDescription.fromString(s"struct<`$name`:int>")
-    } catch {
-      case _: IllegalArgumentException =>
-        throw QueryCompilationErrors.columnNameContainsInvalidCharactersError(name)
-    }
-  }
-
-  def checkFieldNames(schema: StructType): Unit = {
-    schema.foreach { field =>
-      checkFieldName(field.name)
-      field.dataType match {
-        case s: StructType => checkFieldNames(s)
-        case _ =>
-      }
-    }
-  }
 
   def getQuotedSchemaString(dataType: DataType): String = dataType match {
     case _: AtomicType => dataType.catalogString
@@ -279,4 +260,13 @@ class OrcFileFormat
 
     case _ => false
   }
+
+  override def supportFieldName(name: String): Boolean = {
+    try {
+      TypeDescription.fromString(s"struct<`$name`:int>")
+      true
+    } catch {
+      case _: IllegalArgumentException => false
+    }
+  }
 }
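DataSourceUtils.checkFieldNames (added earlier in this patch) walks nested StructTypes, so an invalid name is caught even when it is buried inside a struct column. The sketch below illustrates this with ParquetFileFormat, whose name rule follows in the next file; the schema and names are made up for the example:

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.execution.datasources.DataSourceUtils
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}

object NestedFieldNameCheckSketch {
  def main(args: Array[String]): Unit = {
    // "inner col" contains a space; checkFieldNames recurses into the nested
    // struct and throws for it even though the top-level names are fine.
    val nested = StructType(Seq(StructField("inner col", IntegerType)))
    val schema = StructType(Seq(
      StructField("id", LongType),
      StructField("payload", nested)))
    try {
      DataSourceUtils.checkFieldNames(new ParquetFileFormat(), schema)
    } catch {
      case e: AnalysisException => println(e.getMessage)
    }
  }
}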
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
index ee229a334f..586952aafb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
@@ -386,6 +386,10 @@ class ParquetFileFormat
 
     case _ => false
   }
+
+  override def supportFieldName(name: String): Boolean = {
+    !name.matches(".*[ ,;{}()\n\t=].*")
+  }
 }
 
 object ParquetFileFormat extends Logging {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
index f3bfd99368..217c020358 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
@@ -591,16 +591,6 @@ private[sql] object ParquetSchemaConverter {
     }
   }
 
-  def checkFieldNames(schema: StructType): Unit = {
-    schema.foreach { field =>
-      checkFieldName(field.name)
-      field.dataType match {
-        case s: StructType => checkFieldNames(s)
-        case _ =>
-      }
-    }
-  }
-
   def checkConversionRequirement(f: => Boolean, message: String): Unit = {
     if (!f) {
       throw new AnalysisException(message)
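A small standalone sketch of how the Parquet supportFieldName override added above behaves; the column names are made up for the example:

import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat

object ParquetFieldNameRuleSketch {
  def main(args: Array[String]): Unit = {
    val parquet = new ParquetFileFormat()
    // supportFieldName rejects any name containing one of: space , ; { } ( ) \n \t =
    assert(parquet.supportFieldName("col_1"))
    assert(!parquet.supportFieldName("col 1")) // space
    assert(!parquet.supportFieldName("a;b"))   // semicolon
    assert(!parquet.supportFieldName("a=b"))   // equals sign
  }
}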