[SPARK-21912][SQL] ORC/Parquet table should not create invalid column names

## What changes were proposed in this pull request?

Currently, users hit job abortions while creating or altering ORC/Parquet tables with invalid column names. We should prevent this by raising an **AnalysisException** that guides users to rename the column with an alias, as Parquet data source tables already do.

**BEFORE**
```scala
scala> sql("CREATE TABLE orc1 USING ORC AS SELECT 1 `a b`")
17/09/04 13:28:21 ERROR Utils: Aborting task
java.lang.IllegalArgumentException: Error: : expected at the position 8 of 'struct<a b:int>' but ' ' is found.
17/09/04 13:28:21 ERROR FileFormatWriter: Job job_20170904132821_0001 aborted.
17/09/04 13:28:21 ERROR Executor: Exception in task 0.0 in stage 1.0 (TID 1)
org.apache.spark.SparkException: Task failed while writing rows.
```

**AFTER**
```scala
scala> sql("CREATE TABLE orc1 USING ORC AS SELECT 1 `a b`")
17/09/04 13:27:40 ERROR CreateDataSourceTableAsSelectCommand: Failed to write to table orc1
org.apache.spark.sql.AnalysisException: Attribute name "a b" contains invalid character(s) among " ,;{}()\n\t=". Please use alias to rename it.;
```
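
As the new message suggests, the workaround is simply to alias the column to a valid name before writing. A minimal sketch (not part of this PR, reusing the `orc1` example above):

```scala
scala> // rename via alias so the stored column name has no invalid characters
scala> sql("CREATE TABLE orc1 USING ORC AS SELECT 1 AS a_b")
scala> sql("SELECT a_b FROM orc1").show()
+---+
|a_b|
+---+
|  1|
+---+
```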

## How was this patch tested?

Pass the Jenkins with a new test case.

Author: Dongjoon Hyun <dongjoon@apache.org>

Closes #19124 from dongjoon-hyun/SPARK-21912.
Dongjoon Hyun authored on 2017-09-06 22:20:48 -07:00, committed by gatorsmile
parent ce7293c150, commit eea2b877cf
9 changed files with 109 additions and 7 deletions


```diff
@@ -34,6 +34,9 @@ import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
 import org.apache.spark.sql.execution.datasources.PartitioningUtils
+import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
+import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter
+import org.apache.spark.sql.internal.HiveSerDe
 import org.apache.spark.sql.types._
 import org.apache.spark.util.{SerializableConfiguration, ThreadUtils}
@@ -848,4 +851,22 @@ object DDLUtils {
       }
     }
   }
+
+  private[sql] def checkDataSchemaFieldNames(table: CatalogTable): Unit = {
+    table.provider.foreach {
+      _.toLowerCase(Locale.ROOT) match {
+        case HIVE_PROVIDER =>
+          val serde = table.storage.serde
+          if (serde == HiveSerDe.sourceToSerDe("orc").get.serde) {
+            OrcFileFormat.checkFieldNames(table.dataSchema)
+          } else if (serde == HiveSerDe.sourceToSerDe("parquet").get.serde ||
+            serde == Some("parquet.hive.serde.ParquetHiveSerDe")) {
+            ParquetSchemaConverter.checkFieldNames(table.dataSchema)
+          }
+        case "parquet" => ParquetSchemaConverter.checkFieldNames(table.dataSchema)
+        case "orc" => OrcFileFormat.checkFieldNames(table.dataSchema)
+        case _ =>
+      }
+    }
+  }
 }
```


```diff
@@ -201,13 +201,14 @@ case class AlterTableAddColumnsCommand(
     // make sure any partition columns are at the end of the fields
     val reorderedSchema = catalogTable.dataSchema ++ columns ++ catalogTable.partitionSchema
+    val newSchema = catalogTable.schema.copy(fields = reorderedSchema.toArray)

     SchemaUtils.checkColumnNameDuplication(
       reorderedSchema.map(_.name), "in the table definition of " + table.identifier,
       conf.caseSensitiveAnalysis)
+    DDLUtils.checkDataSchemaFieldNames(catalogTable.copy(schema = newSchema))

-    catalog.alterTableSchema(
-      table, catalogTable.schema.copy(fields = reorderedSchema.toArray))
+    catalog.alterTableSchema(table, newSchema)

     Seq.empty[Row]
   }
```


```diff
@@ -130,10 +130,12 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast
   override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case CreateTable(tableDesc, mode, None) if DDLUtils.isDatasourceTable(tableDesc) =>
+      DDLUtils.checkDataSchemaFieldNames(tableDesc)
       CreateDataSourceTableCommand(tableDesc, ignoreIfExists = mode == SaveMode.Ignore)

     case CreateTable(tableDesc, mode, Some(query))
         if query.resolved && DDLUtils.isDatasourceTable(tableDesc) =>
+      DDLUtils.checkDataSchemaFieldNames(tableDesc.copy(schema = query.schema))
       CreateDataSourceTableAsSelectCommand(tableDesc, mode, query)

     case InsertIntoTable(l @ LogicalRelation(_: InsertableRelation, _, _, _),
```

```diff
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc
+
+import org.apache.orc.TypeDescription
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.types.StructType
+
+private[sql] object OrcFileFormat {
+  private def checkFieldName(name: String): Unit = {
+    try {
+      TypeDescription.fromString(s"struct<$name:int>")
+    } catch {
+      case _: IllegalArgumentException =>
+        throw new AnalysisException(
+          s"""Column name "$name" contains invalid character(s).
+             |Please use alias to rename it.
+           """.stripMargin.split("\n").mkString(" ").trim)
+    }
+  }
+
+  def checkFieldNames(schema: StructType): StructType = {
+    schema.fieldNames.foreach(checkFieldName)
+    schema
+  }
+}
```
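
The new helper delegates validation entirely to ORC's own schema parser: a name is accepted exactly when `TypeDescription.fromString` can parse it inside a `struct<...>` type string. A hedged, self-contained sketch of the same check (the helper name below is hypothetical, not the Spark object above; only `org.apache.orc:orc-core` is assumed on the classpath):

```scala
import org.apache.orc.TypeDescription

// True when ORC's schema parser accepts the column name embedded in a struct type string.
def isValidOrcFieldName(name: String): Boolean =
  try {
    TypeDescription.fromString(s"struct<$name:int>")
    true
  } catch {
    case _: IllegalArgumentException => false
  }

assert(isValidOrcFieldName("col"))   // plain identifier parses fine
assert(!isValidOrcFieldName("a b"))  // the space triggers the parse error seen in the BEFORE example
```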


```diff
@@ -556,7 +556,7 @@ private[parquet] class ParquetSchemaConverter(
   }
 }

-private[parquet] object ParquetSchemaConverter {
+private[sql] object ParquetSchemaConverter {
   val SPARK_PARQUET_SCHEMA_NAME = "spark_schema"

   val EMPTY_MESSAGE: MessageType =
```
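
Only the visibility changes here: the existing Parquet-side `checkFieldNames` is widened from `private[parquet]` to `private[sql]` so that `DDLUtils.checkDataSchemaFieldNames` can reuse it. The check itself is not part of this diff; judging from the AFTER message, it rejects any name containing one of `" ,;{}()\n\t="`. A rough standalone restatement of that rule (hypothetical helper name, not the Spark source):

```scala
// Characters Parquet field names may not contain, per the AFTER error message above.
val parquetInvalidChars: Set[Char] = " ,;{}()\n\t=".toSet

// Spark's real check lives in ParquetSchemaConverter and throws AnalysisException instead.
def isValidParquetFieldName(name: String): Boolean =
  !name.exists(parquetInvalidChars.contains)

assert(isValidParquetFieldName("price"))
assert(!isValidParquetFieldName("col 2"))   // the space, as in the SQL test fixtures below
```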


```diff
@@ -2,9 +2,9 @@ CREATE DATABASE showdb;

 USE showdb;

-CREATE TABLE showcolumn1 (col1 int, `col 2` int) USING parquet;
+CREATE TABLE showcolumn1 (col1 int, `col 2` int) USING json;
 CREATE TABLE showcolumn2 (price int, qty int, year int, month int) USING parquet partitioned by (year, month);
-CREATE TEMPORARY VIEW showColumn3 (col3 int, `col 4` int) USING parquet;
+CREATE TEMPORARY VIEW showColumn3 (col3 int, `col 4` int) USING json;
 CREATE GLOBAL TEMP VIEW showColumn4 AS SELECT 1 as col1, 'abc' as `col 5`;
```


```diff
@@ -19,7 +19,7 @@ struct<>

 -- !query 2
-CREATE TABLE showcolumn1 (col1 int, `col 2` int) USING parquet
+CREATE TABLE showcolumn1 (col1 int, `col 2` int) USING json
 -- !query 2 schema
 struct<>
 -- !query 2 output
@@ -35,7 +35,7 @@ struct<>

 -- !query 4
-CREATE TEMPORARY VIEW showColumn3 (col3 int, `col 4` int) USING parquet
+CREATE TEMPORARY VIEW showColumn3 (col3 int, `col 4` int) USING json
 -- !query 4 schema
 struct<>
 -- !query 4 output
```


```diff
@@ -151,9 +151,11 @@ object HiveAnalysis extends Rule[LogicalPlan] {
       InsertIntoHiveTable(r.tableMeta, partSpec, query, overwrite, ifPartitionNotExists)

     case CreateTable(tableDesc, mode, None) if DDLUtils.isHiveTable(tableDesc) =>
+      DDLUtils.checkDataSchemaFieldNames(tableDesc)
       CreateTableCommand(tableDesc, ignoreIfExists = mode == SaveMode.Ignore)

     case CreateTable(tableDesc, mode, Some(query)) if DDLUtils.isHiveTable(tableDesc) =>
+      DDLUtils.checkDataSchemaFieldNames(tableDesc)
       CreateHiveTableAsSelectCommand(tableDesc, query, mode)
   }
 }
```


```diff
@@ -2000,4 +2000,38 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
       assert(setOfPath.size() == pathSizeToDeleteOnExit)
     }
   }
+
+  test("SPARK-21912 ORC/Parquet table should not create invalid column names") {
+    Seq(" ", ",", ";", "{", "}", "(", ")", "\n", "\t", "=").foreach { name =>
+      withTable("t21912") {
+        Seq("ORC", "PARQUET").foreach { source =>
+          val m = intercept[AnalysisException] {
+            sql(s"CREATE TABLE t21912(`col$name` INT) USING $source")
+          }.getMessage
+          assert(m.contains(s"contains invalid character(s)"))
+
+          val m2 = intercept[AnalysisException] {
+            sql(s"CREATE TABLE t21912 USING $source AS SELECT 1 `col$name`")
+          }.getMessage
+          assert(m2.contains(s"contains invalid character(s)"))
+
+          withSQLConf(HiveUtils.CONVERT_METASTORE_PARQUET.key -> "false") {
+            val m3 = intercept[AnalysisException] {
+              sql(s"CREATE TABLE t21912(`col$name` INT) USING hive OPTIONS (fileFormat '$source')")
+            }.getMessage
+            assert(m3.contains(s"contains invalid character(s)"))
+          }
+        }
+
+        // TODO: After SPARK-21929, we need to check ORC, too.
+        Seq("PARQUET").foreach { source =>
+          sql(s"CREATE TABLE t21912(`col` INT) USING $source")
+          val m = intercept[AnalysisException] {
+            sql(s"ALTER TABLE t21912 ADD COLUMNS(`col$name` INT)")
+          }.getMessage
+          assert(m.contains(s"contains invalid character(s)"))
+        }
+      }
+    }
+  }
 }
```