[SPARK-21912][SQL] ORC/Parquet table should not create invalid column names
## What changes were proposed in this pull request? Currently, users meet job abortions while creating or altering ORC/Parquet tables with invalid column names. We had better prevent this by raising **AnalysisException** with a guide to use aliases instead like Paquet data source tables. **BEFORE** ```scala scala> sql("CREATE TABLE orc1 USING ORC AS SELECT 1 `a b`") 17/09/04 13:28:21 ERROR Utils: Aborting task java.lang.IllegalArgumentException: Error: : expected at the position 8 of 'struct<a b:int>' but ' ' is found. 17/09/04 13:28:21 ERROR FileFormatWriter: Job job_20170904132821_0001 aborted. 17/09/04 13:28:21 ERROR Executor: Exception in task 0.0 in stage 1.0 (TID 1) org.apache.spark.SparkException: Task failed while writing rows. ``` **AFTER** ```scala scala> sql("CREATE TABLE orc1 USING ORC AS SELECT 1 `a b`") 17/09/04 13:27:40 ERROR CreateDataSourceTableAsSelectCommand: Failed to write to table orc1 org.apache.spark.sql.AnalysisException: Attribute name "a b" contains invalid character(s) among " ,;{}()\n\t=". Please use alias to rename it.; ``` ## How was this patch tested? Pass the Jenkins with a new test case. Author: Dongjoon Hyun <dongjoon@apache.org> Closes #19124 from dongjoon-hyun/SPARK-21912.
This commit is contained in:
parent
ce7293c150
commit
eea2b877cf
|
@ -34,6 +34,9 @@ import org.apache.spark.sql.catalyst.catalog._
|
|||
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
|
||||
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
|
||||
import org.apache.spark.sql.execution.datasources.PartitioningUtils
|
||||
import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
|
||||
import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter
|
||||
import org.apache.spark.sql.internal.HiveSerDe
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.util.{SerializableConfiguration, ThreadUtils}
|
||||
|
||||
|
@ -848,4 +851,22 @@ object DDLUtils {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
private[sql] def checkDataSchemaFieldNames(table: CatalogTable): Unit = {
|
||||
table.provider.foreach {
|
||||
_.toLowerCase(Locale.ROOT) match {
|
||||
case HIVE_PROVIDER =>
|
||||
val serde = table.storage.serde
|
||||
if (serde == HiveSerDe.sourceToSerDe("orc").get.serde) {
|
||||
OrcFileFormat.checkFieldNames(table.dataSchema)
|
||||
} else if (serde == HiveSerDe.sourceToSerDe("parquet").get.serde ||
|
||||
serde == Some("parquet.hive.serde.ParquetHiveSerDe")) {
|
||||
ParquetSchemaConverter.checkFieldNames(table.dataSchema)
|
||||
}
|
||||
case "parquet" => ParquetSchemaConverter.checkFieldNames(table.dataSchema)
|
||||
case "orc" => OrcFileFormat.checkFieldNames(table.dataSchema)
|
||||
case _ =>
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -201,13 +201,14 @@ case class AlterTableAddColumnsCommand(
|
|||
|
||||
// make sure any partition columns are at the end of the fields
|
||||
val reorderedSchema = catalogTable.dataSchema ++ columns ++ catalogTable.partitionSchema
|
||||
val newSchema = catalogTable.schema.copy(fields = reorderedSchema.toArray)
|
||||
|
||||
SchemaUtils.checkColumnNameDuplication(
|
||||
reorderedSchema.map(_.name), "in the table definition of " + table.identifier,
|
||||
conf.caseSensitiveAnalysis)
|
||||
DDLUtils.checkDataSchemaFieldNames(catalogTable.copy(schema = newSchema))
|
||||
|
||||
catalog.alterTableSchema(
|
||||
table, catalogTable.schema.copy(fields = reorderedSchema.toArray))
|
||||
catalog.alterTableSchema(table, newSchema)
|
||||
|
||||
Seq.empty[Row]
|
||||
}
|
||||
|
|
|
@ -130,10 +130,12 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast
|
|||
|
||||
override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
|
||||
case CreateTable(tableDesc, mode, None) if DDLUtils.isDatasourceTable(tableDesc) =>
|
||||
DDLUtils.checkDataSchemaFieldNames(tableDesc)
|
||||
CreateDataSourceTableCommand(tableDesc, ignoreIfExists = mode == SaveMode.Ignore)
|
||||
|
||||
case CreateTable(tableDesc, mode, Some(query))
|
||||
if query.resolved && DDLUtils.isDatasourceTable(tableDesc) =>
|
||||
DDLUtils.checkDataSchemaFieldNames(tableDesc.copy(schema = query.schema))
|
||||
CreateDataSourceTableAsSelectCommand(tableDesc, mode, query)
|
||||
|
||||
case InsertIntoTable(l @ LogicalRelation(_: InsertableRelation, _, _, _),
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.execution.datasources.orc
|
||||
|
||||
import org.apache.orc.TypeDescription
|
||||
|
||||
import org.apache.spark.sql.AnalysisException
|
||||
import org.apache.spark.sql.types.StructType
|
||||
|
||||
private[sql] object OrcFileFormat {
|
||||
private def checkFieldName(name: String): Unit = {
|
||||
try {
|
||||
TypeDescription.fromString(s"struct<$name:int>")
|
||||
} catch {
|
||||
case _: IllegalArgumentException =>
|
||||
throw new AnalysisException(
|
||||
s"""Column name "$name" contains invalid character(s).
|
||||
|Please use alias to rename it.
|
||||
""".stripMargin.split("\n").mkString(" ").trim)
|
||||
}
|
||||
}
|
||||
|
||||
def checkFieldNames(schema: StructType): StructType = {
|
||||
schema.fieldNames.foreach(checkFieldName)
|
||||
schema
|
||||
}
|
||||
}
|
|
@ -556,7 +556,7 @@ private[parquet] class ParquetSchemaConverter(
|
|||
}
|
||||
}
|
||||
|
||||
private[parquet] object ParquetSchemaConverter {
|
||||
private[sql] object ParquetSchemaConverter {
|
||||
val SPARK_PARQUET_SCHEMA_NAME = "spark_schema"
|
||||
|
||||
val EMPTY_MESSAGE: MessageType =
|
||||
|
|
|
@ -2,9 +2,9 @@ CREATE DATABASE showdb;
|
|||
|
||||
USE showdb;
|
||||
|
||||
CREATE TABLE showcolumn1 (col1 int, `col 2` int) USING parquet;
|
||||
CREATE TABLE showcolumn1 (col1 int, `col 2` int) USING json;
|
||||
CREATE TABLE showcolumn2 (price int, qty int, year int, month int) USING parquet partitioned by (year, month);
|
||||
CREATE TEMPORARY VIEW showColumn3 (col3 int, `col 4` int) USING parquet;
|
||||
CREATE TEMPORARY VIEW showColumn3 (col3 int, `col 4` int) USING json;
|
||||
CREATE GLOBAL TEMP VIEW showColumn4 AS SELECT 1 as col1, 'abc' as `col 5`;
|
||||
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ struct<>
|
|||
|
||||
|
||||
-- !query 2
|
||||
CREATE TABLE showcolumn1 (col1 int, `col 2` int) USING parquet
|
||||
CREATE TABLE showcolumn1 (col1 int, `col 2` int) USING json
|
||||
-- !query 2 schema
|
||||
struct<>
|
||||
-- !query 2 output
|
||||
|
@ -35,7 +35,7 @@ struct<>
|
|||
|
||||
|
||||
-- !query 4
|
||||
CREATE TEMPORARY VIEW showColumn3 (col3 int, `col 4` int) USING parquet
|
||||
CREATE TEMPORARY VIEW showColumn3 (col3 int, `col 4` int) USING json
|
||||
-- !query 4 schema
|
||||
struct<>
|
||||
-- !query 4 output
|
||||
|
|
|
@ -151,9 +151,11 @@ object HiveAnalysis extends Rule[LogicalPlan] {
|
|||
InsertIntoHiveTable(r.tableMeta, partSpec, query, overwrite, ifPartitionNotExists)
|
||||
|
||||
case CreateTable(tableDesc, mode, None) if DDLUtils.isHiveTable(tableDesc) =>
|
||||
DDLUtils.checkDataSchemaFieldNames(tableDesc)
|
||||
CreateTableCommand(tableDesc, ignoreIfExists = mode == SaveMode.Ignore)
|
||||
|
||||
case CreateTable(tableDesc, mode, Some(query)) if DDLUtils.isHiveTable(tableDesc) =>
|
||||
DDLUtils.checkDataSchemaFieldNames(tableDesc)
|
||||
CreateHiveTableAsSelectCommand(tableDesc, query, mode)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2000,4 +2000,38 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
|
|||
assert(setOfPath.size() == pathSizeToDeleteOnExit)
|
||||
}
|
||||
}
|
||||
|
||||
test("SPARK-21912 ORC/Parquet table should not create invalid column names") {
|
||||
Seq(" ", ",", ";", "{", "}", "(", ")", "\n", "\t", "=").foreach { name =>
|
||||
withTable("t21912") {
|
||||
Seq("ORC", "PARQUET").foreach { source =>
|
||||
val m = intercept[AnalysisException] {
|
||||
sql(s"CREATE TABLE t21912(`col$name` INT) USING $source")
|
||||
}.getMessage
|
||||
assert(m.contains(s"contains invalid character(s)"))
|
||||
|
||||
val m2 = intercept[AnalysisException] {
|
||||
sql(s"CREATE TABLE t21912 USING $source AS SELECT 1 `col$name`")
|
||||
}.getMessage
|
||||
assert(m2.contains(s"contains invalid character(s)"))
|
||||
|
||||
withSQLConf(HiveUtils.CONVERT_METASTORE_PARQUET.key -> "false") {
|
||||
val m3 = intercept[AnalysisException] {
|
||||
sql(s"CREATE TABLE t21912(`col$name` INT) USING hive OPTIONS (fileFormat '$source')")
|
||||
}.getMessage
|
||||
assert(m3.contains(s"contains invalid character(s)"))
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: After SPARK-21929, we need to check ORC, too.
|
||||
Seq("PARQUET").foreach { source =>
|
||||
sql(s"CREATE TABLE t21912(`col` INT) USING $source")
|
||||
val m = intercept[AnalysisException] {
|
||||
sql(s"ALTER TABLE t21912 ADD COLUMNS(`col$name` INT)")
|
||||
}.getMessage
|
||||
assert(m.contains(s"contains invalid character(s)"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue