[SPARK-36074][SQL] Add error class for StructType.findNestedField

### What changes were proposed in this pull request?

This PR adds the `INVALID_FIELD_NAME` and `AMBIGUOUS_FIELD_NAME` error classes for the errors thrown by `StructType.findNestedField`. It also cleans up the code there and adds a unit test for this method.
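As a sketch of what the framework does with these entries (a minimal stand-in, not Spark's actual `SparkThrowableHelper`, which loads the formats from `error-classes.json`):

```scala
// Minimal sketch of the error-class lookup-and-format idea, assuming a
// hard-coded map instead of Spark's real error-classes.json resource file.
object ErrorClassSketch {
  private val messageFormats = Map(
    "AMBIGUOUS_FIELD_NAME" ->
      "Field name %s is ambiguous and has %s matching fields in the struct.",
    "INVALID_FIELD_NAME" ->
      "Field name %s is invalid: %s is not a struct.")

  // Mirrors the shape of SparkThrowableHelper.getMessage(errorClass, params).
  def getMessage(errorClass: String, parameters: Array[String]): String =
    messageFormats(errorClass).format(parameters: _*)

  def main(args: Array[String]): Unit = {
    // Prints: Field name m1.key is invalid: m1 is not a struct.
    println(getMessage("INVALID_FIELD_NAME", Array("m1.key", "m1")))
  }
}
```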

### Why are the changes needed?

To follow the new error message framework: error messages are defined centrally in `error-classes.json` with an error class and a SQLSTATE, instead of being hard-coded at each call site.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing tests, plus a new unit test for `findNestedField` in `StructTypeSuite`.

Closes #33282 from cloud-fan/error.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Commit 4a62e1e9c1 (parent 57a4f310df), 2021-07-13 21:13:58 +08:00
7 changed files with 183 additions and 56 deletions

#### core/src/main/resources/error/error-classes.json

```diff
@@ -1,4 +1,8 @@
 {
+  "AMBIGUOUS_FIELD_NAME" : {
+    "message" : [ "Field name %s is ambiguous and has %s matching fields in the struct." ],
+    "sqlState" : "42000"
+  },
   "DIVIDE_BY_ZERO" : {
     "message" : [ "divide by zero" ],
     "sqlState" : "22012"
@@ -7,6 +11,10 @@
     "message" : [ "Found duplicate keys '%s'" ],
     "sqlState" : "23000"
   },
+  "INVALID_FIELD_NAME" : {
+    "message" : [ "Field name %s is invalid: %s is not a struct." ],
+    "sqlState" : "42000"
+  },
   "MISSING_COLUMN" : {
     "message" : [ "cannot resolve '%s' given input columns: [%s]" ],
     "sqlState" : "42000"
```

#### sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala

```diff
@@ -20,6 +20,7 @@ package org.apache.spark.sql
 import org.apache.spark.{SparkThrowable, SparkThrowableHelper}
 import org.apache.spark.annotation.Stable
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.trees.Origin

 /**
  * Thrown when a query fails to analyze, usually because the query itself is invalid.
@@ -48,12 +49,11 @@ class AnalysisException protected[sql] (
   def this(
       errorClass: String,
       messageParameters: Array[String],
-      line: Option[Int],
-      startPosition: Option[Int]) =
+      origin: Origin) =
     this(
       SparkThrowableHelper.getMessage(errorClass, messageParameters),
-      line = line,
-      startPosition = startPosition,
+      line = origin.line,
+      startPosition = origin.startPosition,
       errorClass = Some(errorClass),
       messageParameters = messageParameters)
```
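The effect of this constructor change, as a toy sketch (the `Origin` and exception below are stand-ins, not Spark's real classes): callers that hold a tree node's `Origin` now pass it along whole instead of unpacking `line` and `startPosition` at every call site.

```scala
// Stand-in for org.apache.spark.sql.catalyst.trees.Origin; the real class
// carries more context than these two fields.
case class Origin(line: Option[Int] = None, startPosition: Option[Int] = None)

class SketchAnalysisException(
    message: String,
    val line: Option[Int],
    val startPosition: Option[Int]) extends Exception(message) {

  // New shape: the whole Origin is handed over and unpacked in one place.
  def this(errorClass: String, messageParameters: Array[String], origin: Origin) =
    this(
      s"[$errorClass] ${messageParameters.mkString(", ")}",
      origin.line,
      origin.startPosition)
}
```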

#### sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

```diff
@@ -3613,7 +3613,9 @@ class Analyzer(override val catalogManager: CatalogManager)
       table: ResolvedTable,
       fieldName: Seq[String],
       context: Expression): ResolvedFieldName = {
-    table.schema.findNestedField(fieldName, includeCollections = true, conf.resolver).map {
+    table.schema.findNestedField(
+      fieldName, includeCollections = true, conf.resolver, context.origin
+    ).map {
       case (path, field) => ResolvedFieldName(path, field)
     }.getOrElse(throw QueryCompilationErrors.missingFieldError(fieldName, table, context))
   }
```
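For reference, `conf.resolver` picks one of the two standard name-matching functions depending on `spark.sql.caseSensitive`; their contract is simple (a sketch mirroring the definitions in the `analysis` package object):

```scala
object Resolvers {
  // A Resolver is just (String, String) => Boolean. These mirror the two
  // standard resolvers from the org.apache.spark.sql.catalyst.analysis
  // package object, which conf.resolver selects between.
  val caseSensitiveResolution: (String, String) => Boolean = (a, b) => a == b
  val caseInsensitiveResolution: (String, String) => Boolean = (a, b) => a.equalsIgnoreCase(b)
}
```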

#### sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala

```diff
@@ -51,8 +51,7 @@ package object analysis {
       throw new AnalysisException(
         errorClass = errorClass,
         messageParameters = messageParameters,
-        line = t.origin.line,
-        startPosition = t.origin.startPosition)
+        origin = t.origin)
     }
   }
```

#### sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala

```diff
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, CreateMap, Expression, GroupingID, NamedExpression, SpecifiedWindowFrame, WindowFrame, WindowFunction, WindowSpecDefinition}
 import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoStatement, Join, LogicalPlan, SerdeInfo, Window}
-import org.apache.spark.sql.catalyst.trees.TreeNode
+import org.apache.spark.sql.catalyst.trees.{Origin, TreeNode}
 import org.apache.spark.sql.catalyst.util.{toPrettySQL, FailFastMode, ParseMode, PermissiveMode}
 import org.apache.spark.sql.connector.catalog._
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
@@ -1352,9 +1352,12 @@ private[spark] object QueryCompilationErrors {
       s"${evalTypes.mkString(",")}")
   }

-  def ambiguousFieldNameError(fieldName: String, names: String): Throwable = {
+  def ambiguousFieldNameError(
+      fieldName: Seq[String], numMatches: Int, context: Origin): Throwable = {
     new AnalysisException(
-      s"Ambiguous field name: $fieldName. Found multiple columns that can match: $names")
+      errorClass = "AMBIGUOUS_FIELD_NAME",
+      messageParameters = Array(fieldName.quoted, numMatches.toString),
+      origin = context)
   }

   def cannotUseIntervalTypeInTableSchemaError(): Throwable = {
@@ -2359,8 +2362,10 @@ private[spark] object QueryCompilationErrors {
       context.origin.startPosition)
   }

-  def invalidFieldName(fieldName: Seq[String], path: Seq[String]): Throwable = {
+  def invalidFieldName(fieldName: Seq[String], path: Seq[String], context: Origin): Throwable = {
     new AnalysisException(
-      s"Field name ${fieldName.quoted} is invalid, ${path.quoted} is not a struct.")
+      errorClass = "INVALID_FIELD_NAME",
+      messageParameters = Array(fieldName.quoted, path.quoted),
+      origin = context)
   }
 }
```
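Both constructors format the multi-part name with `fieldName.quoted` from `CatalogV2Implicits`, which is why the messages in the tests below show plain `m1.key` rather than backticked names. A sketch of that rule (assuming the `quoteIfNeeded` behavior of `CatalogV2Implicits` at the time; treat the details as illustrative):

```scala
object QuotedSketch {
  // Backtick-quote a name part only if it would be ambiguous otherwise
  // (assumed to mirror CatalogV2Implicits.quoteIfNeeded as of this PR).
  def quoteIfNeeded(part: String): String =
    if (part.contains(".") || part.contains("`")) s"`${part.replace("`", "``")}`"
    else part

  // Seq("m1", "key").quoted renders as "m1.key" in the messages above.
  def quoted(parts: Seq[String]): String = parts.map(quoteIfNeeded).mkString(".")

  def main(args: Array[String]): Unit = {
    println(quoted(Seq("m2", "key", "a"))) // m2.key.a
    println(quoted(Seq("a.b", "c")))       // `a.b`.c
  }
}
```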

#### sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala

```diff
@@ -27,6 +27,7 @@ import org.apache.spark.annotation.Stable
 import org.apache.spark.sql.catalyst.analysis.Resolver
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering}
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, LegacyTypeStringParser}
+import org.apache.spark.sql.catalyst.trees.Origin
 import org.apache.spark.sql.catalyst.util.{truncatedString, StringUtils}
 import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
@@ -317,66 +318,69 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
   private[sql] def findNestedField(
       fieldNames: Seq[String],
       includeCollections: Boolean = false,
-      resolver: Resolver = _ == _): Option[(Seq[String], StructField)] = {
-    def prettyFieldName(nameParts: Seq[String]): String = {
-      import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
-      nameParts.quoted
-    }
+      resolver: Resolver = _ == _,
+      context: Origin = Origin()): Option[(Seq[String], StructField)] = {

     def findField(
         struct: StructType,
         searchPath: Seq[String],
         normalizedPath: Seq[String]): Option[(Seq[String], StructField)] = {
-      searchPath.headOption.flatMap { searchName =>
+      assert(searchPath.nonEmpty)
+      val searchName = searchPath.head
       val found = struct.fields.filter(f => resolver(searchName, f.name))
       if (found.length > 1) {
-        val names = found.map(f => prettyFieldName(normalizedPath :+ f.name))
-          .mkString("[", ", ", " ]")
-        throw QueryCompilationErrors.ambiguousFieldNameError(
-          prettyFieldName(normalizedPath :+ searchName), names)
+        throw QueryCompilationErrors.ambiguousFieldNameError(fieldNames, found.length, context)
       } else if (found.isEmpty) {
         None
       } else {
         val field = found.head
-        (searchPath.tail, field.dataType, includeCollections) match {
-          case (Seq(), _, _) =>
+        val currentPath = normalizedPath :+ field.name
+        val newSearchPath = searchPath.tail
+        if (newSearchPath.isEmpty) {
           Some(normalizedPath -> field)
+        } else {
+          (newSearchPath, field.dataType) match {
+            case (_, s: StructType) =>
+              findField(s, newSearchPath, currentPath)

-          case (names, struct: StructType, _) =>
-            findField(struct, names, normalizedPath :+ field.name)
+            case _ if !includeCollections =>
+              throw QueryCompilationErrors.invalidFieldName(fieldNames, currentPath, context)

-          case (_, _, false) =>
-            None // types nested in maps and arrays are not used
+            case (Seq("key", rest @ _*), MapType(keyType, _, _)) =>
+              findFieldInCollection(keyType, false, rest, currentPath, "key")

-          case (Seq("key"), MapType(keyType, _, _), true) =>
-            // return the key type as a struct field to include nullability
-            Some((normalizedPath :+ field.name) -> StructField("key", keyType, nullable = false))
+            case (Seq("value", rest @ _*), MapType(_, valueType, isNullable)) =>
+              findFieldInCollection(valueType, isNullable, rest, currentPath, "value")

-          case (Seq("key", names @ _*), MapType(struct: StructType, _, _), true) =>
-            findField(struct, names, normalizedPath ++ Seq(field.name, "key"))
-
-          case (Seq("value"), MapType(_, valueType, isNullable), true) =>
-            // return the value type as a struct field to include nullability
-            Some((normalizedPath :+ field.name) ->
-              StructField("value", valueType, nullable = isNullable))
-
-          case (Seq("value", names @ _*), MapType(_, struct: StructType, _), true) =>
-            findField(struct, names, normalizedPath ++ Seq(field.name, "value"))
-
-          case (Seq("element"), ArrayType(elementType, isNullable), true) =>
-            // return the element type as a struct field to include nullability
-            Some((normalizedPath :+ field.name) ->
-              StructField("element", elementType, nullable = isNullable))
-
-          case (Seq("element", names @ _*), ArrayType(struct: StructType, _), true) =>
-            findField(struct, names, normalizedPath ++ Seq(field.name, "element"))
+            case (Seq("element", rest @ _*), ArrayType(elementType, isNullable)) =>
+              findFieldInCollection(elementType, isNullable, rest, currentPath, "element")

-          case _ =>
-            throw QueryCompilationErrors.invalidFieldName(fieldNames, normalizedPath)
+            case _ =>
+              throw QueryCompilationErrors.invalidFieldName(fieldNames, currentPath, context)
+          }
         }
       }
     }
+
+    def findFieldInCollection(
+        dt: DataType,
+        nullable: Boolean,
+        searchPath: Seq[String],
+        normalizedPath: Seq[String],
+        collectionFieldName: String): Option[(Seq[String], StructField)] = {
+      if (searchPath.isEmpty) {
+        Some(normalizedPath -> StructField(collectionFieldName, dt, nullable))
+      } else {
+        val newPath = normalizedPath :+ collectionFieldName
+        dt match {
+          case s: StructType =>
+            findField(s, searchPath, newPath)
+          case _ =>
+            throw QueryCompilationErrors.invalidFieldName(fieldNames, newPath, context)
+        }
+      }
+    }
+
     findField(this, fieldNames, Nil)
   }
```
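Since `findNestedField` is `private[sql]`, the sketch below only illustrates the addressing scheme it implements: map fields are reached through the virtual `key`/`value` names, array elements through `element`, and recursing past a leaf type raises `INVALID_FIELD_NAME` (expected results shown as comments):

```scala
import org.apache.spark.sql.types._

// Schema shapes like the ones the new tests use, with the paths the
// resolver accepts written out as comments.
object FindNestedFieldPaths {
  val schema: StructType = new StructType()
    .add("m2", MapType(
      new StructType().add("a", IntegerType),
      new StructType().add("b", IntegerType)))
    .add("a2", ArrayType(new StructType().add("c", IntegerType)))

  // Seq("m2", "key", "a")      -> Some(Seq("m2", "key") -> StructField("a", IntegerType))
  // Seq("m2", "value", "b")    -> Some(Seq("m2", "value") -> StructField("b", IntegerType))
  // Seq("a2", "element", "c")  -> Some(Seq("a2", "element") -> StructField("c", IntegerType))
  // Seq("m2", "key", "a", "x") -> AnalysisException(INVALID_FIELD_NAME):
  //   "Field name m2.key.a.x is invalid: m2.key.a is not a struct."
}
```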

#### sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala

```diff
@@ -18,6 +18,8 @@
 package org.apache.spark.sql.types

 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, caseSensitiveResolution}
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.catalyst.plans.SQLHelper
 import org.apache.spark.sql.internal.SQLConf
@@ -273,4 +275,111 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper {
       checkIntervalDDL(start, end, DT.fieldToString)
     }
   }
+
+  test("findNestedField") {
+    val innerStruct = new StructType()
+      .add("s11", "int")
+      .add("s12", "int")
+    val input = new StructType()
+      .add("s1", innerStruct)
+      .add("s2", new StructType().add("x", "int").add("X", "int"))
+      .add("m1", MapType(IntegerType, IntegerType))
+      .add("m2", MapType(
+        new StructType().add("a", "int"),
+        new StructType().add("b", "int")
+      ))
+      .add("a1", ArrayType(IntegerType))
+      .add("a2", ArrayType(new StructType().add("c", "int")))
+
+    def check(field: Seq[String], expect: Option[(Seq[String], StructField)]): Unit = {
+      val res = input.findNestedField(field, resolver = caseInsensitiveResolution)
+      assert(res == expect)
+    }
+
+    def caseSensitiveCheck(
+        field: Seq[String], expect: Option[(Seq[String], StructField)]): Unit = {
+      val res = input.findNestedField(field, resolver = caseSensitiveResolution)
+      assert(res == expect)
+    }
+
+    def checkCollection(field: Seq[String], expect: Option[(Seq[String], StructField)]): Unit = {
+      val res = input.findNestedField(field,
+        includeCollections = true, resolver = caseInsensitiveResolution)
+      assert(res == expect)
+    }
+
+    // struct type
+    check(Seq("non_exist"), None)
+    check(Seq("S1"), Some(Nil -> StructField("s1", innerStruct)))
+    caseSensitiveCheck(Seq("S1"), None)
+    check(Seq("s1", "S12"), Some(Seq("s1") -> StructField("s12", IntegerType)))
+    caseSensitiveCheck(Seq("s1", "S12"), None)
+    check(Seq("S1.non_exist"), None)
+    var e = intercept[AnalysisException] {
+      check(Seq("S1", "S12", "S123"), None)
+    }
+    assert(e.getMessage.contains("Field name S1.S12.S123 is invalid: s1.s12 is not a struct"))
+
+    // ambiguous name
+    e = intercept[AnalysisException] {
+      check(Seq("S2", "x"), None)
+    }
+    assert(e.getMessage.contains(
+      "Field name S2.x is ambiguous and has 2 matching fields in the struct"))
+    caseSensitiveCheck(Seq("s2", "x"), Some(Seq("s2") -> StructField("x", IntegerType)))
+
+    // simple map type
+    e = intercept[AnalysisException] {
+      check(Seq("m1", "key"), None)
+    }
+    assert(e.getMessage.contains("Field name m1.key is invalid: m1 is not a struct"))
+    checkCollection(Seq("m1", "key"), Some(Seq("m1") -> StructField("key", IntegerType, false)))
+    checkCollection(Seq("M1", "value"), Some(Seq("m1") -> StructField("value", IntegerType)))
+    e = intercept[AnalysisException] {
+      checkCollection(Seq("M1", "key", "name"), None)
+    }
+    assert(e.getMessage.contains("Field name M1.key.name is invalid: m1.key is not a struct"))
+    e = intercept[AnalysisException] {
+      checkCollection(Seq("M1", "value", "name"), None)
+    }
+    assert(e.getMessage.contains("Field name M1.value.name is invalid: m1.value is not a struct"))
+
+    // map of struct
+    checkCollection(Seq("M2", "key", "A"),
+      Some(Seq("m2", "key") -> StructField("a", IntegerType)))
+    checkCollection(Seq("M2", "key", "non_exist"), None)
+    checkCollection(Seq("M2", "value", "b"),
+      Some(Seq("m2", "value") -> StructField("b", IntegerType)))
+    checkCollection(Seq("M2", "value", "non_exist"), None)
+    e = intercept[AnalysisException] {
+      checkCollection(Seq("m2", "key", "A", "name"), None)
+    }
+    assert(e.getMessage.contains("Field name m2.key.A.name is invalid: m2.key.a is not a struct"))
+    e = intercept[AnalysisException] {
+      checkCollection(Seq("M2", "value", "b", "name"), None)
+    }
+    assert(e.getMessage.contains(
+      "Field name M2.value.b.name is invalid: m2.value.b is not a struct"))
+
+    // simple array type
+    e = intercept[AnalysisException] {
+      check(Seq("A1", "element"), None)
+    }
+    assert(e.getMessage.contains("Field name A1.element is invalid: a1 is not a struct"))
+    checkCollection(Seq("A1", "element"), Some(Seq("a1") -> StructField("element", IntegerType)))
+    e = intercept[AnalysisException] {
+      checkCollection(Seq("A1", "element", "name"), None)
+    }
+    assert(e.getMessage.contains(
+      "Field name A1.element.name is invalid: a1.element is not a struct"))
+
+    // array of struct
+    checkCollection(Seq("A2", "element", "C"),
+      Some(Seq("a2", "element") -> StructField("c", IntegerType)))
+    checkCollection(Seq("A2", "element", "non_exist"), None)
+    e = intercept[AnalysisException] {
+      checkCollection(Seq("a2", "element", "C", "name"), None)
+    }
+    assert(e.getMessage.contains(
+      "Field name a2.element.C.name is invalid: a2.element.c is not a struct"))
+  }
 }
```