diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index 605d98ee54..140f9d7dbe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -924,23 +924,23 @@ object DDLUtils {
   }
 
   private[sql] def checkDataColNames(table: CatalogTable): Unit = {
-    checkDataColNames(table, table.dataSchema.fieldNames)
+    checkDataColNames(table, table.dataSchema)
   }
 
-  private[sql] def checkDataColNames(table: CatalogTable, colNames: Seq[String]): Unit = {
+  private[sql] def checkDataColNames(table: CatalogTable, schema: StructType): Unit = {
     table.provider.foreach {
       _.toLowerCase(Locale.ROOT) match {
         case HIVE_PROVIDER =>
           val serde = table.storage.serde
           if (serde == HiveSerDe.sourceToSerDe("orc").get.serde) {
-            OrcFileFormat.checkFieldNames(colNames)
+            OrcFileFormat.checkFieldNames(schema)
           } else if (serde == HiveSerDe.sourceToSerDe("parquet").get.serde ||
             serde == Some("parquet.hive.serde.ParquetHiveSerDe") ||
             serde == Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) {
-            ParquetSchemaConverter.checkFieldNames(colNames)
+            ParquetSchemaConverter.checkFieldNames(schema)
           }
-        case "parquet" => ParquetSchemaConverter.checkFieldNames(colNames)
-        case "orc" => OrcFileFormat.checkFieldNames(colNames)
+        case "parquet" => ParquetSchemaConverter.checkFieldNames(schema)
+        case "orc" => OrcFileFormat.checkFieldNames(schema)
         case _ =>
       }
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
index 059962192c..f740915b6d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
@@ -236,7 +236,7 @@ case class AlterTableAddColumnsCommand(
       (colsToAdd ++ catalogTable.schema).map(_.name),
       "in the table definition of " + table.identifier,
       conf.caseSensitiveAnalysis)
-    DDLUtils.checkDataColNames(catalogTable, colsToAdd.map(_.name))
+    DDLUtils.checkDataColNames(catalogTable, StructType(colsToAdd))
 
     val existingSchema = CharVarcharUtils.getRawSchema(catalogTable.dataSchema)
     catalog.alterTableDataSchema(table, StructType(existingSchema ++ colsToAdd))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
index d6593cad4a..9024c7809f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
@@ -52,8 +52,14 @@ private[sql] object OrcFileFormat {
     }
   }
 
-  def checkFieldNames(names: Seq[String]): Unit = {
-    names.foreach(checkFieldName)
+  def checkFieldNames(schema: StructType): Unit = {
+    schema.foreach { field =>
+      checkFieldName(field.name)
+      field.dataType match {
+        case s: StructType => checkFieldNames(s)
+        case _ =>
+      }
+    }
   }
 
   def getQuotedSchemaString(dataType: DataType): String = dataType match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
index 1b26c69fe1..a23eebe6cb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
@@ -593,8 +593,14 @@ private[sql] object ParquetSchemaConverter {
       """.stripMargin.split("\n").mkString(" ").trim)
   }
 
-  def checkFieldNames(names: Seq[String]): Unit = {
-    names.foreach(checkFieldName)
+  def checkFieldNames(schema: StructType): Unit = {
+    schema.foreach { field =>
+      checkFieldName(field.name)
+      field.dataType match {
+        case s: StructType => checkFieldNames(s)
+        case _ =>
+      }
+    }
   }
 
   def checkConversionRequirement(f: => Boolean, message: String): Unit = {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
index 9a39f1872b..3e01fcbe16 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
@@ -3008,6 +3008,26 @@ class HiveDDLSuite
     }
   }
 
+  test("SPARK-36201: Add check for inner field of parquet/orc schema") {
+    withView("v") {
+      spark.range(1).createTempView("v")
+      withTempPath { path =>
+        val e = intercept[AnalysisException] {
+          spark.sql(
+            s"""
+               |INSERT OVERWRITE LOCAL DIRECTORY '${path.getCanonicalPath}'
+               |STORED AS PARQUET
+               |SELECT
+               |NAMED_STRUCT('ID', ID, 'IF(ID=1,ID,0)', IF(ID=1,ID,0), 'B', ABS(ID)) AS col1
+               |FROM v
+             """.stripMargin)
+        }.getMessage
+        assert(e.contains("Attribute name \"IF(ID=1,ID,0)\" contains" +
+          " invalid character(s) among \" ,;{}()\\n\\t=\". Please use alias to rename it."))
+      }
+    }
+  }
+
   test("SPARK-34261: Avoid side effect if create exists temporary function") {
     withUserDefinedFunction("f1" -> true) {
       sql("CREATE TEMPORARY FUNCTION f1 AS 'org.apache.hadoop.hive.ql.udf.UDFUUID'")
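
Note (illustration, not part of the patch): below is a minimal standalone sketch of the recursive field-name check this change introduces, assuming only Spark's public StructType/StructField API. The name checkFieldNamesSketch and the inline require message are hypothetical stand-ins; the real validation lives in the patched ParquetSchemaConverter.checkFieldNames / OrcFileFormat.checkFieldNames above.

import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

// Validate every field name and recurse into nested StructType fields,
// mirroring the patched checkFieldNames (only struct children are
// descended into, exactly as in the patch).
def checkFieldNamesSketch(schema: StructType): Unit = {
  schema.foreach { field =>
    // " ,;{}()\n\t=" are the characters Parquet rejects in field names.
    require(!field.name.matches(".*[ ,;{}()\n\t=].*"),
      s"""Attribute name "${field.name}" contains invalid character(s)""")
    field.dataType match {
      case s: StructType => checkFieldNamesSketch(s)
      case _ =>
    }
  }
}

// Before the patch only top-level names were inspected, so a nested field
// such as "IF(ID=1,ID,0)" slipped through; the recursive check rejects it.
val nested = StructType(Seq(
  StructField("col1", StructType(Seq(
    StructField("ID", IntegerType),
    StructField("IF(ID=1,ID,0)", IntegerType))))))
checkFieldNamesSketch(nested) // throws IllegalArgumentException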