From 7cd89efca5a74dcc2457c7be5f2ef65aeb90a967 Mon Sep 17 00:00:00 2001
From: Angerszhuuuu
Date: Tue, 20 Jul 2021 21:08:03 +0800
Subject: [PATCH] [SPARK-36201][SQL][FOLLOWUP] Schema check should check inner field too

### What changes were proposed in this pull request?
When an inner field has an invalid field name, the schema check should validate that inner name too, not just the top-level column names.

![image](https://user-images.githubusercontent.com/46485123/126101009-c192d87f-1e18-4355-ad53-1419dacdeb76.png)

### Why are the changes needed?
Check early, fail early.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Added UT

Closes #33409 from AngersZhuuuu/SPARK-36201.

Authored-by: Angerszhuuuu
Signed-off-by: Wenchen Fan
(cherry picked from commit 251885772d41a572655e950a8e298315f222a803)
Signed-off-by: Wenchen Fan
---
 .../spark/sql/execution/command/ddl.scala     | 12 +++++------
 .../spark/sql/execution/command/tables.scala  |  2 +-
 .../datasources/orc/OrcFileFormat.scala       | 10 ++++++++--
 .../parquet/ParquetSchemaConverter.scala      | 10 ++++++++--
 .../sql/hive/execution/HiveDDLSuite.scala     | 20 +++++++++++++++++++
 5 files changed, 43 insertions(+), 11 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index 605d98ee54..140f9d7dbe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -924,23 +924,23 @@ object DDLUtils {
   }
 
   private[sql] def checkDataColNames(table: CatalogTable): Unit = {
-    checkDataColNames(table, table.dataSchema.fieldNames)
+    checkDataColNames(table, table.dataSchema)
   }
 
-  private[sql] def checkDataColNames(table: CatalogTable, colNames: Seq[String]): Unit = {
+  private[sql] def checkDataColNames(table: CatalogTable, schema: StructType): Unit = {
     table.provider.foreach {
       _.toLowerCase(Locale.ROOT) match {
         case HIVE_PROVIDER =>
           val serde = table.storage.serde
           if (serde == HiveSerDe.sourceToSerDe("orc").get.serde) {
-            OrcFileFormat.checkFieldNames(colNames)
+            OrcFileFormat.checkFieldNames(schema)
           } else if (serde == HiveSerDe.sourceToSerDe("parquet").get.serde ||
             serde == Some("parquet.hive.serde.ParquetHiveSerDe") ||
             serde == Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) {
-            ParquetSchemaConverter.checkFieldNames(colNames)
+            ParquetSchemaConverter.checkFieldNames(schema)
           }
-        case "parquet" => ParquetSchemaConverter.checkFieldNames(colNames)
-        case "orc" => OrcFileFormat.checkFieldNames(colNames)
+        case "parquet" => ParquetSchemaConverter.checkFieldNames(schema)
+        case "orc" => OrcFileFormat.checkFieldNames(schema)
         case _ =>
       }
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
index 059962192c..f740915b6d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
@@ -236,7 +236,7 @@ case class AlterTableAddColumnsCommand(
       (colsToAdd ++ catalogTable.schema).map(_.name),
       "in the table definition of " + table.identifier,
       conf.caseSensitiveAnalysis)
-    DDLUtils.checkDataColNames(catalogTable, colsToAdd.map(_.name))
+    DDLUtils.checkDataColNames(catalogTable, StructType(colsToAdd))
 
     val existingSchema = CharVarcharUtils.getRawSchema(catalogTable.dataSchema)
     catalog.alterTableDataSchema(table, StructType(existingSchema ++ colsToAdd))
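The two `checkFieldNames` rewrites below (ORC and Parquet) follow the same recursive walk over `StructType`. As a minimal standalone sketch of that pattern (hypothetical `FieldNameCheck` object, not part of the patch; `require` stands in for the `AnalysisException` the real converters throw):

```scala
import org.apache.spark.sql.types.StructType

object FieldNameCheck {
  // The characters the Parquet/ORC writers reject in field names,
  // per the error message asserted in this patch's test.
  private val invalidChars: Set[Char] = " ,;{}()\n\t=".toSet

  private def checkFieldName(name: String): Unit = {
    require(!name.exists(invalidChars),
      s"Attribute name \"$name\" contains invalid character(s) " +
        "among \" ,;{}()\\n\\t=\". Please use alias to rename it.")
  }

  // Check every field name at this level, then recurse into nested
  // structs so inner field names are validated too (the fix here).
  def checkFieldNames(schema: StructType): Unit = {
    schema.foreach { field =>
      checkFieldName(field.name)
      field.dataType match {
        case s: StructType => checkFieldNames(s)
        case _ =>
      }
    }
  }
}
```
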
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
index d6593cad4a..9024c7809f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
@@ -52,8 +52,14 @@ private[sql] object OrcFileFormat {
     }
   }
 
-  def checkFieldNames(names: Seq[String]): Unit = {
-    names.foreach(checkFieldName)
+  def checkFieldNames(schema: StructType): Unit = {
+    schema.foreach { field =>
+      checkFieldName(field.name)
+      field.dataType match {
+        case s: StructType => checkFieldNames(s)
+        case _ =>
+      }
+    }
   }
 
   def getQuotedSchemaString(dataType: DataType): String = dataType match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
index 1b26c69fe1..a23eebe6cb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
@@ -593,8 +593,14 @@ private[sql] object ParquetSchemaConverter {
       """.stripMargin.split("\n").mkString(" ").trim)
   }
 
-  def checkFieldNames(names: Seq[String]): Unit = {
-    names.foreach(checkFieldName)
+  def checkFieldNames(schema: StructType): Unit = {
+    schema.foreach { field =>
+      checkFieldName(field.name)
+      field.dataType match {
+        case s: StructType => checkFieldNames(s)
+        case _ =>
+      }
+    }
   }
 
   def checkConversionRequirement(f: => Boolean, message: String): Unit = {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
index 9a39f1872b..3e01fcbe16 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
@@ -3008,6 +3008,26 @@ class HiveDDLSuite
     }
   }
 
+  test("SPARK-36201: Add check for inner field of parquet/orc schema") {
+    withView("v") {
+      spark.range(1).createTempView("v")
+      withTempPath { path =>
+        val e = intercept[AnalysisException] {
+          spark.sql(
+            s"""
+               |INSERT OVERWRITE LOCAL DIRECTORY '${path.getCanonicalPath}'
+               |STORED AS PARQUET
+               |SELECT
+               |NAMED_STRUCT('ID', ID, 'IF(ID=1,ID,0)', IF(ID=1,ID,0), 'B', ABS(ID)) AS col1
+               |FROM v
+             """.stripMargin)
+        }.getMessage
+        assert(e.contains("Attribute name \"IF(ID=1,ID,0)\" contains" +
+          " invalid character(s) among \" ,;{}()\\n\\t=\". Please use alias to rename it."))
+      }
+    }
+  }
+
   test("SPARK-34261: Avoid side effect if create exists temporary function") {
     withUserDefinedFunction("f1" -> true) {
       sql("CREATE TEMPORARY FUNCTION f1 AS 'org.apache.hadoop.hive.ql.udf.UDFUUID'")
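For reference, the failure the new test asserts can be reproduced against the sketch above without a Hive session (hypothetical driver, reusing the `FieldNameCheck` object from the earlier snippet):

```scala
import org.apache.spark.sql.types.{LongType, StructField, StructType}

val schema = StructType(Seq(
  StructField("col1", StructType(Seq(
    StructField("ID", LongType),
    StructField("IF(ID=1,ID,0)", LongType))))))  // invalid inner name

// The top-level name "col1" is legal, so the pre-patch check (which only
// looked at table.dataSchema.fieldNames) let this schema through. The
// recursive check rejects it before any Parquet/ORC file is written.
FieldNameCheck.checkFieldNames(schema)
// throws IllegalArgumentException: requirement failed: Attribute name
// "IF(ID=1,ID,0)" contains invalid character(s) ...
```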