[SPARK-6250][SPARK-6146][SPARK-5911][SQL] Types are now reserved words in DDL parser.
This PR creates a trait `DataTypeParser` used to parse data types. This trait aims to be the single place providing the functionality of parsing the string representation of data types. It is currently mixed into `DDLParser` and `SqlParser`, and is also used to parse the data type for `DataFrame.cast` and to convert a Hive metastore data type string back to a `DataType`.

JIRA: https://issues.apache.org/jira/browse/SPARK-6250

Author: Yin Huai <yhuai@databricks.com>

Closes #5078 from yhuai/ddlKeywords and squashes the following commits:

0e66097 [Yin Huai] Special handle struct<>.
fea6012 [Yin Huai] Style.
c9733fb [Yin Huai] Create a trait to parse data types.
parent ee569a0c71
commit 94a102acb8
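For a sense of what the new trait accepts, here is a minimal sketch distilled from the test suite added below. Since `DataTypeParser` is `private[sql]`, imagine this evaluated inside the `org.apache.spark.sql` package; the result comments are illustrative:

    import org.apache.spark.sql.types._

    // Type names are case-insensitive and may nest arbitrarily.
    DataTypeParser("array<doublE>")       // ArrayType(DoubleType, true)
    DataTypeParser("MAP<int, STRING>")    // MapType(IntegerType, StringType, true)
    // Field names with special characters must be backtick-quoted.
    DataTypeParser("struct<`x+y`: int>")  // StructType(StructField("x+y", IntegerType, true) :: Nil)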
@@ -35,7 +35,7 @@ import org.apache.spark.sql.types._
  * This is currently included mostly for illustrative purposes. Users wanting more complete support
  * for a SQL like language should checkout the HiveQL support in the sql/hive sub-project.
  */
-class SqlParser extends AbstractSparkSQLParser {
+class SqlParser extends AbstractSparkSQLParser with DataTypeParser {

   def parseExpression(input: String): Expression = {
     // Initialize the Keywords.
@@ -61,11 +61,8 @@ class SqlParser extends AbstractSparkSQLParser {
   protected val CAST = Keyword("CAST")
   protected val COALESCE = Keyword("COALESCE")
   protected val COUNT = Keyword("COUNT")
-  protected val DATE = Keyword("DATE")
-  protected val DECIMAL = Keyword("DECIMAL")
   protected val DESC = Keyword("DESC")
   protected val DISTINCT = Keyword("DISTINCT")
-  protected val DOUBLE = Keyword("DOUBLE")
   protected val ELSE = Keyword("ELSE")
   protected val END = Keyword("END")
   protected val EXCEPT = Keyword("EXCEPT")
@@ -78,7 +75,6 @@ class SqlParser extends AbstractSparkSQLParser {
   protected val IF = Keyword("IF")
   protected val IN = Keyword("IN")
   protected val INNER = Keyword("INNER")
-  protected val INT = Keyword("INT")
   protected val INSERT = Keyword("INSERT")
   protected val INTERSECT = Keyword("INTERSECT")
   protected val INTO = Keyword("INTO")
@@ -105,13 +101,11 @@ class SqlParser extends AbstractSparkSQLParser {
   protected val SELECT = Keyword("SELECT")
   protected val SEMI = Keyword("SEMI")
   protected val SQRT = Keyword("SQRT")
-  protected val STRING = Keyword("STRING")
   protected val SUBSTR = Keyword("SUBSTR")
   protected val SUBSTRING = Keyword("SUBSTRING")
   protected val SUM = Keyword("SUM")
   protected val TABLE = Keyword("TABLE")
   protected val THEN = Keyword("THEN")
-  protected val TIMESTAMP = Keyword("TIMESTAMP")
   protected val TRUE = Keyword("TRUE")
   protected val UNION = Keyword("UNION")
   protected val UPPER = Keyword("UPPER")
@@ -315,7 +309,9 @@ class SqlParser extends AbstractSparkSQLParser {
     )

   protected lazy val cast: Parser[Expression] =
-    CAST ~ "(" ~> expression ~ (AS ~> dataType) <~ ")" ^^ { case exp ~ t => Cast(exp, t) }
+    CAST ~ "(" ~> expression ~ (AS ~> dataType) <~ ")" ^^ {
+      case exp ~ t => Cast(exp, t)
+    }

   protected lazy val literal: Parser[Literal] =
     ( numericLiteral
@@ -387,19 +383,4 @@ class SqlParser extends AbstractSparkSQLParser {
     (ident <~ ".") ~ ident ~ rep("." ~> ident) ^^ {
       case i1 ~ i2 ~ rest => UnresolvedAttribute((Seq(i1, i2) ++ rest).mkString("."))
     }
-
-  protected lazy val dataType: Parser[DataType] =
-    ( STRING ^^^ StringType
-    | TIMESTAMP ^^^ TimestampType
-    | DOUBLE ^^^ DoubleType
-    | fixedDecimalType
-    | DECIMAL ^^^ DecimalType.Unlimited
-    | DATE ^^^ DateType
-    | INT ^^^ IntegerType
-    )
-
-  protected lazy val fixedDecimalType: Parser[DataType] =
-    (DECIMAL ~ "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ {
-      case precision ~ scale => DecimalType(precision.toInt, scale.toInt)
-    }
 }
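The net effect in `SqlParser`: the local seven-alternative `dataType` grammar above is gone, and `CAST` now resolves its target type through the mixed-in `DataTypeParser`, so the full shared grammar applies. A hedged sketch (assuming a `SQLContext` named `sqlContext` and a registered table `t` with columns `b` and `s`):

    // Before this change, SqlParser's CAST grammar covered only string, timestamp,
    // double, decimal, date and int; e.g. bigint or smallint would fail to parse.
    sqlContext.sql("SELECT CAST(b AS bigint), CAST(s AS smallint) FROM t")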
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import scala.language.implicitConversions
+import scala.util.matching.Regex
+import scala.util.parsing.combinator.syntactical.StandardTokenParsers
+
+import org.apache.spark.sql.catalyst.SqlLexical
+
+/**
+ * This is a data type parser that can be used to parse string representations of data types
+ * provided in SQL queries. This parser is mixed in with DDLParser and SqlParser.
+ */
+private[sql] trait DataTypeParser extends StandardTokenParsers {
+
+  // This is used to create a parser from a regex. We are using regexes for data type strings
+  // since these strings can be also used as column names or field names.
+  import lexical.Identifier
+  implicit def regexToParser(regex: Regex): Parser[String] = acceptMatch(
+    s"identifier matching regex ${regex}",
+    { case Identifier(str) if regex.unapplySeq(str).isDefined => str }
+  )
+
+  protected lazy val primitiveType: Parser[DataType] =
+    "(?i)string".r ^^^ StringType |
+    "(?i)float".r ^^^ FloatType |
+    "(?i)int".r ^^^ IntegerType |
+    "(?i)tinyint".r ^^^ ByteType |
+    "(?i)smallint".r ^^^ ShortType |
+    "(?i)double".r ^^^ DoubleType |
+    "(?i)bigint".r ^^^ LongType |
+    "(?i)binary".r ^^^ BinaryType |
+    "(?i)boolean".r ^^^ BooleanType |
+    fixedDecimalType |
+    "(?i)decimal".r ^^^ DecimalType.Unlimited |
+    "(?i)date".r ^^^ DateType |
+    "(?i)timestamp".r ^^^ TimestampType |
+    varchar
+
+  protected lazy val fixedDecimalType: Parser[DataType] =
+    ("(?i)decimal".r ~> "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ {
+      case precision ~ scale =>
+        DecimalType(precision.toInt, scale.toInt)
+    }
+
+  protected lazy val varchar: Parser[DataType] =
+    "(?i)varchar".r ~> "(" ~> (numericLit <~ ")") ^^^ StringType
+
+  protected lazy val arrayType: Parser[DataType] =
+    "(?i)array".r ~> "<" ~> dataType <~ ">" ^^ {
+      case tpe => ArrayType(tpe)
+    }
+
+  protected lazy val mapType: Parser[DataType] =
+    "(?i)map".r ~> "<" ~> dataType ~ "," ~ dataType <~ ">" ^^ {
+      case t1 ~ _ ~ t2 => MapType(t1, t2)
+    }
+
+  protected lazy val structField: Parser[StructField] =
+    ident ~ ":" ~ dataType ^^ {
+      case name ~ _ ~ tpe => StructField(name, tpe, nullable = true)
+    }
+
+  protected lazy val structType: Parser[DataType] =
+    ("(?i)struct".r ~> "<" ~> repsep(structField, ",") <~ ">" ^^ {
+      case fields => new StructType(fields.toArray)
+    }) |
+    ("(?i)struct".r ~ "<>" ^^^ StructType(Nil))
+
+  protected lazy val dataType: Parser[DataType] =
+    arrayType |
+    mapType |
+    structType |
+    primitiveType
+
+  def toDataType(dataTypeString: String): DataType = synchronized {
+    phrase(dataType)(new lexical.Scanner(dataTypeString)) match {
+      case Success(result, _) => result
+      case failure: NoSuccess => throw new DataTypeException(failMessage(dataTypeString))
+    }
+  }
+
+  private def failMessage(dataTypeString: String): String = {
+    s"Unsupported dataType: $dataTypeString. If you have a struct and a field name of it has " +
+      "any special characters, please use backticks (`) to quote that field name, e.g. `x+y`. " +
+      "Please note that backtick itself is not supported in a field name."
+  }
+}
+
+private[sql] object DataTypeParser {
+  lazy val dataTypeParser = new DataTypeParser {
+    override val lexical = new SqlLexical
+  }
+
+  def apply(dataTypeString: String): DataType = dataTypeParser.toDataType(dataTypeString)
+}
+
+/** The exception thrown from the [[DataTypeParser]]. */
+protected[sql] class DataTypeException(message: String) extends Exception(message)
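Two details of the new trait worth noting. The dedicated `struct<>` alternative exists most likely because the lexer tokenizes `<>` as a single not-equals delimiter, so the `repsep` branch never sees an empty field list between separate `<` and `>` tokens; and parse failures surface as a `DataTypeException` rather than a partial result. A sketch mirroring the suite below:

    DataTypeParser("strUCt<>")       // StructType(Nil): handled by the dedicated alternative
    DataTypeParser("struct<x: int")  // throws DataTypeException (unterminated struct)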
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import org.scalatest.FunSuite
+
+class DataTypeParserSuite extends FunSuite {
+
+  def checkDataType(dataTypeString: String, expectedDataType: DataType): Unit = {
+    test(s"parse ${dataTypeString.replace("\n", "")}") {
+      assert(DataTypeParser(dataTypeString) === expectedDataType)
+    }
+  }
+
+  def unsupported(dataTypeString: String): Unit = {
+    test(s"$dataTypeString is not supported") {
+      intercept[DataTypeException](DataTypeParser(dataTypeString))
+    }
+  }
+
+  checkDataType("int", IntegerType)
+  checkDataType("BooLean", BooleanType)
+  checkDataType("tinYint", ByteType)
+  checkDataType("smallINT", ShortType)
+  checkDataType("INT", IntegerType)
+  checkDataType("bigint", LongType)
+  checkDataType("float", FloatType)
+  checkDataType("dOUBle", DoubleType)
+  checkDataType("decimal(10, 5)", DecimalType(10, 5))
+  checkDataType("decimal", DecimalType.Unlimited)
+  checkDataType("DATE", DateType)
+  checkDataType("timestamp", TimestampType)
+  checkDataType("string", StringType)
+  checkDataType("varchAr(20)", StringType)
+  checkDataType("BINARY", BinaryType)
+
+  checkDataType("array<doublE>", ArrayType(DoubleType, true))
+  checkDataType("Array<map<int, tinYint>>", ArrayType(MapType(IntegerType, ByteType, true), true))
+  checkDataType(
+    "array<struct<tinYint:tinyint>>",
+    ArrayType(StructType(StructField("tinYint", ByteType, true) :: Nil), true)
+  )
+  checkDataType("MAP<int, STRING>", MapType(IntegerType, StringType, true))
+  checkDataType("MAp<int, ARRAY<double>>", MapType(IntegerType, ArrayType(DoubleType), true))
+  checkDataType(
+    "MAP<int, struct<varchar:string>>",
+    MapType(IntegerType, StructType(StructField("varchar", StringType, true) :: Nil), true)
+  )
+
+  checkDataType(
+    "struct<intType: int, ts:timestamp>",
+    StructType(
+      StructField("intType", IntegerType, true) ::
+      StructField("ts", TimestampType, true) :: Nil)
+  )
+  // It is fine to use the data type string as the column name.
+  checkDataType(
+    "Struct<int: int, timestamp:timestamp>",
+    StructType(
+      StructField("int", IntegerType, true) ::
+      StructField("timestamp", TimestampType, true) :: Nil)
+  )
+  checkDataType(
+    """
+      |struct<
+      |  struct:struct<deciMal:DECimal, anotherDecimal:decimAL(5,2)>,
+      |  MAP:Map<timestamp, varchar(10)>,
+      |  arrAy:Array<double>>
+    """.stripMargin,
+    StructType(
+      StructField("struct",
+        StructType(
+          StructField("deciMal", DecimalType.Unlimited, true) ::
+          StructField("anotherDecimal", DecimalType(5, 2), true) :: Nil), true) ::
+      StructField("MAP", MapType(TimestampType, StringType), true) ::
+      StructField("arrAy", ArrayType(DoubleType, true), true) :: Nil)
+  )
+  // A column name can be a reserved word in our DDL parser and SqlParser.
+  checkDataType(
+    "Struct<TABLE: string, CASE:boolean>",
+    StructType(
+      StructField("TABLE", StringType, true) ::
+      StructField("CASE", BooleanType, true) :: Nil)
+  )
+  // Use backticks to quote column names having special characters.
+  checkDataType(
+    "struct<`x+y`:int, `!@#$%^&*()`:string, `1_2.345<>:\"`:varchar(20)>",
+    StructType(
+      StructField("x+y", IntegerType, true) ::
+      StructField("!@#$%^&*()", StringType, true) ::
+      StructField("1_2.345<>:\"", StringType, true) :: Nil)
+  )
+  // Empty struct.
+  checkDataType("strUCt<>", StructType(Nil))
+
+  unsupported("it is not a data type")
+  unsupported("struct<x+y: int, 1.1:timestamp>")
+  unsupported("struct<x: int")
+  unsupported("struct<x int, y string>")
+  unsupported("struct<`x``y` int>")
+}
@@ -624,20 +624,7 @@ class Column(protected[sql] val expr: Expression) {
    *
    * @group expr_ops
    */
-  def cast(to: String): Column = cast(to.toLowerCase match {
-    case "string" | "str" => StringType
-    case "boolean" => BooleanType
-    case "byte" => ByteType
-    case "short" => ShortType
-    case "int" => IntegerType
-    case "long" => LongType
-    case "float" => FloatType
-    case "double" => DoubleType
-    case "decimal" => DecimalType.Unlimited
-    case "date" => DateType
-    case "timestamp" => TimestampType
-    case _ => throw new RuntimeException(s"""Unsupported cast type: "$to"""")
-  })
+  def cast(to: String): Column = cast(DataTypeParser(to))

   /**
    * Returns an ordering used in sorting.
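With the hand-rolled match deleted, `Column.cast(String)` accepts every type string the shared grammar understands instead of the fixed list above. A hypothetical snippet (`df` and its columns are assumptions; whether a given cast is semantically legal is still decided by `Cast` during analysis):

    df("salary").cast("decimal(10, 2)")  // previously only bare "decimal" was accepted
    df("age").cast("tinyint")            // previously an unsupported name
    // Note: the old "str" shorthand for string no longer matches the new grammar.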
@@ -34,7 +34,8 @@ import org.apache.spark.util.Utils
  * A parser for foreign DDL commands.
  */
 private[sql] class DDLParser(
-    parseQuery: String => LogicalPlan) extends AbstractSparkSQLParser with Logging {
+    parseQuery: String => LogicalPlan)
+  extends AbstractSparkSQLParser with DataTypeParser with Logging {

   def apply(input: String, exceptionOnError: Boolean): Option[LogicalPlan] = {
     try {
@@ -46,14 +47,6 @@ private[sql] class DDLParser(
     }
   }

-  def parseType(input: String): DataType = {
-    lexical.initialize(reservedWords)
-    phrase(dataType)(new lexical.Scanner(input)) match {
-      case Success(r, x) => r
-      case x => throw new DDLException(s"Unsupported dataType: $x")
-    }
-  }
-
   // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword`
   // properties via reflection the class in runtime for constructing the SqlLexical object
   protected val CREATE = Keyword("CREATE")
@@ -70,24 +63,6 @@ private[sql] class DDLParser(
   protected val COMMENT = Keyword("COMMENT")
   protected val REFRESH = Keyword("REFRESH")

-  // Data types.
-  protected val STRING = Keyword("STRING")
-  protected val BINARY = Keyword("BINARY")
-  protected val BOOLEAN = Keyword("BOOLEAN")
-  protected val TINYINT = Keyword("TINYINT")
-  protected val SMALLINT = Keyword("SMALLINT")
-  protected val INT = Keyword("INT")
-  protected val BIGINT = Keyword("BIGINT")
-  protected val FLOAT = Keyword("FLOAT")
-  protected val DOUBLE = Keyword("DOUBLE")
-  protected val DECIMAL = Keyword("DECIMAL")
-  protected val DATE = Keyword("DATE")
-  protected val TIMESTAMP = Keyword("TIMESTAMP")
-  protected val VARCHAR = Keyword("VARCHAR")
-  protected val ARRAY = Keyword("ARRAY")
-  protected val MAP = Keyword("MAP")
-  protected val STRUCT = Keyword("STRUCT")
-
   protected lazy val ddl: Parser[LogicalPlan] = createTable | describeTable | refreshTable

   protected def start: Parser[LogicalPlan] = ddl
@@ -189,58 +164,9 @@ private[sql] class DDLParser(
         new MetadataBuilder().putString(COMMENT.str.toLowerCase, comment).build()
       case None => Metadata.empty
     }

     StructField(columnName, typ, nullable = true, meta)
   }

-  protected lazy val primitiveType: Parser[DataType] =
-    STRING ^^^ StringType |
-    BINARY ^^^ BinaryType |
-    BOOLEAN ^^^ BooleanType |
-    TINYINT ^^^ ByteType |
-    SMALLINT ^^^ ShortType |
-    INT ^^^ IntegerType |
-    BIGINT ^^^ LongType |
-    FLOAT ^^^ FloatType |
-    DOUBLE ^^^ DoubleType |
-    fixedDecimalType |                   // decimal with precision/scale
-    DECIMAL ^^^ DecimalType.Unlimited |  // decimal with no precision/scale
-    DATE ^^^ DateType |
-    TIMESTAMP ^^^ TimestampType |
-    VARCHAR ~ "(" ~ numericLit ~ ")" ^^^ StringType
-
-  protected lazy val fixedDecimalType: Parser[DataType] =
-    (DECIMAL ~ "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ {
-      case precision ~ scale => DecimalType(precision.toInt, scale.toInt)
-    }
-
-  protected lazy val arrayType: Parser[DataType] =
-    ARRAY ~> "<" ~> dataType <~ ">" ^^ {
-      case tpe => ArrayType(tpe)
-    }
-
-  protected lazy val mapType: Parser[DataType] =
-    MAP ~> "<" ~> dataType ~ "," ~ dataType <~ ">" ^^ {
-      case t1 ~ _ ~ t2 => MapType(t1, t2)
-    }
-
-  protected lazy val structField: Parser[StructField] =
-    ident ~ ":" ~ dataType ^^ {
-      case fieldName ~ _ ~ tpe => StructField(fieldName, tpe, nullable = true)
-    }
-
-  protected lazy val structType: Parser[DataType] =
-    (STRUCT ~> "<" ~> repsep(structField, ",") <~ ">" ^^ {
-      case fields => StructType(fields)
-    }) |
-    (STRUCT ~> "<>" ^^ {
-      case fields => StructType(Nil)
-    })
-
-  private[sql] lazy val dataType: Parser[DataType] =
-    arrayType |
-    mapType |
-    structType |
-    primitiveType
 }

 private[sql] object ResolvedDataSource {
@@ -756,7 +756,7 @@ private[hive] case class MetastoreRelation
   implicit class SchemaAttribute(f: FieldSchema) {
     def toAttribute = AttributeReference(
       f.getName,
-      sqlContext.ddlParser.parseType(f.getType),
+      HiveMetastoreTypes.toDataType(f.getType),
       // Since data can be dumped in randomly with no validation, everything is nullable.
       nullable = true
     )(qualifiers = Seq(alias.getOrElse(tableName)))
@@ -779,11 +779,7 @@ private[hive] case class MetastoreRelation


 private[hive] object HiveMetastoreTypes {
-  protected val ddlParser = new DDLParser(HiveQl.parseSql(_))
-
-  def toDataType(metastoreType: String): DataType = synchronized {
-    ddlParser.parseType(metastoreType)
-  }
+  def toDataType(metastoreType: String): DataType = DataTypeParser(metastoreType)

   def toMetastoreType(dt: DataType): String = dt match {
     case ArrayType(elementType, _) => s"array<${toMetastoreType(elementType)}>"
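`HiveMetastoreTypes.toDataType` is now a thin wrapper over the shared parser, and the `synchronized` guard moved into `DataTypeParser.toDataType` itself. A small sketch of the round trip (the metastore type string is illustrative):

    val dt = HiveMetastoreTypes.toDataType("map<string,array<bigint>>")
    // MapType(StringType, ArrayType(LongType, true), true)
    HiveMetastoreTypes.toMetastoreType(dt)  // "map<string,array<bigint>>"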