[SPARK-36074][SQL] Add error class for StructType.findNestedField
### What changes were proposed in this pull request? This PR adds the INVALID_FIELD_NAME and AMBIGUOUS_FIELD_NAME error classes for the errors thrown in `StructType.findNestedField`. It also cleans up the code there and adds a unit test for this method. ### Why are the changes needed? To follow the new error message framework. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests, plus the new `findNestedField` unit test in `StructTypeSuite`. Closes #33282 from cloud-fan/error. Authored-by: Wenchen Fan <wenchen@databricks.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
57a4f310df
commit
4a62e1e9c1
|
@ -1,4 +1,8 @@
|
|||
{
|
||||
"AMBIGUOUS_FIELD_NAME" : {
|
||||
"message" : [ "Field name %s is ambiguous and has %s matching fields in the struct." ],
|
||||
"sqlState" : "42000"
|
||||
},
|
||||
"DIVIDE_BY_ZERO" : {
|
||||
"message" : [ "divide by zero" ],
|
||||
"sqlState" : "22012"
|
||||
|
@ -7,6 +11,10 @@
|
|||
"message" : [ "Found duplicate keys '%s'" ],
|
||||
"sqlState" : "23000"
|
||||
},
|
||||
"INVALID_FIELD_NAME" : {
|
||||
"message" : [ "Field name %s is invalid: %s is not a struct." ],
|
||||
"sqlState" : "42000"
|
||||
},
|
||||
"MISSING_COLUMN" : {
|
||||
"message" : [ "cannot resolve '%s' given input columns: [%s]" ],
|
||||
"sqlState" : "42000"
|
||||
|
|
|
@ -20,6 +20,7 @@ package org.apache.spark.sql
|
|||
import org.apache.spark.{SparkThrowable, SparkThrowableHelper}
|
||||
import org.apache.spark.annotation.Stable
|
||||
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
|
||||
import org.apache.spark.sql.catalyst.trees.Origin
|
||||
|
||||
/**
|
||||
* Thrown when a query fails to analyze, usually because the query itself is invalid.
|
||||
|
@ -48,12 +49,11 @@ class AnalysisException protected[sql] (
|
|||
def this(
|
||||
errorClass: String,
|
||||
messageParameters: Array[String],
|
||||
line: Option[Int],
|
||||
startPosition: Option[Int]) =
|
||||
origin: Origin) =
|
||||
this(
|
||||
SparkThrowableHelper.getMessage(errorClass, messageParameters),
|
||||
line = line,
|
||||
startPosition = startPosition,
|
||||
line = origin.line,
|
||||
startPosition = origin.startPosition,
|
||||
errorClass = Some(errorClass),
|
||||
messageParameters = messageParameters)
|
||||
|
||||
|
|
|
@ -3613,7 +3613,9 @@ class Analyzer(override val catalogManager: CatalogManager)
|
|||
table: ResolvedTable,
|
||||
fieldName: Seq[String],
|
||||
context: Expression): ResolvedFieldName = {
|
||||
table.schema.findNestedField(fieldName, includeCollections = true, conf.resolver).map {
|
||||
table.schema.findNestedField(
|
||||
fieldName, includeCollections = true, conf.resolver, context.origin
|
||||
).map {
|
||||
case (path, field) => ResolvedFieldName(path, field)
|
||||
}.getOrElse(throw QueryCompilationErrors.missingFieldError(fieldName, table, context))
|
||||
}
|
||||
|
|
|
@ -51,8 +51,7 @@ package object analysis {
|
|||
throw new AnalysisException(
|
||||
errorClass = errorClass,
|
||||
messageParameters = messageParameters,
|
||||
line = t.origin.line,
|
||||
startPosition = t.origin.startPosition)
|
||||
origin = t.origin)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
|
|||
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, CreateMap, Expression, GroupingID, NamedExpression, SpecifiedWindowFrame, WindowFrame, WindowFunction, WindowSpecDefinition}
|
||||
import org.apache.spark.sql.catalyst.plans.JoinType
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoStatement, Join, LogicalPlan, SerdeInfo, Window}
|
||||
import org.apache.spark.sql.catalyst.trees.TreeNode
|
||||
import org.apache.spark.sql.catalyst.trees.{Origin, TreeNode}
|
||||
import org.apache.spark.sql.catalyst.util.{toPrettySQL, FailFastMode, ParseMode, PermissiveMode}
|
||||
import org.apache.spark.sql.connector.catalog._
|
||||
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
|
||||
|
@ -1352,9 +1352,12 @@ private[spark] object QueryCompilationErrors {
|
|||
s"${evalTypes.mkString(",")}")
|
||||
}
|
||||
|
||||
def ambiguousFieldNameError(fieldName: String, names: String): Throwable = {
|
||||
def ambiguousFieldNameError(
|
||||
fieldName: Seq[String], numMatches: Int, context: Origin): Throwable = {
|
||||
new AnalysisException(
|
||||
s"Ambiguous field name: $fieldName. Found multiple columns that can match: $names")
|
||||
errorClass = "AMBIGUOUS_FIELD_NAME",
|
||||
messageParameters = Array(fieldName.quoted, numMatches.toString),
|
||||
origin = context)
|
||||
}
|
||||
|
||||
def cannotUseIntervalTypeInTableSchemaError(): Throwable = {
|
||||
|
@ -2359,8 +2362,10 @@ private[spark] object QueryCompilationErrors {
|
|||
context.origin.startPosition)
|
||||
}
|
||||
|
||||
def invalidFieldName(fieldName: Seq[String], path: Seq[String]): Throwable = {
|
||||
def invalidFieldName(fieldName: Seq[String], path: Seq[String], context: Origin): Throwable = {
|
||||
new AnalysisException(
|
||||
s"Field name ${fieldName.quoted} is invalid, ${path.quoted} is not a struct.")
|
||||
errorClass = "INVALID_FIELD_NAME",
|
||||
messageParameters = Array(fieldName.quoted, path.quoted),
|
||||
origin = context)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -27,6 +27,7 @@ import org.apache.spark.annotation.Stable
|
|||
import org.apache.spark.sql.catalyst.analysis.Resolver
|
||||
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering}
|
||||
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, LegacyTypeStringParser}
|
||||
import org.apache.spark.sql.catalyst.trees.Origin
|
||||
import org.apache.spark.sql.catalyst.util.{truncatedString, StringUtils}
|
||||
import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat
|
||||
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
|
||||
|
@ -317,66 +318,69 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
|
|||
private[sql] def findNestedField(
|
||||
fieldNames: Seq[String],
|
||||
includeCollections: Boolean = false,
|
||||
resolver: Resolver = _ == _): Option[(Seq[String], StructField)] = {
|
||||
def prettyFieldName(nameParts: Seq[String]): String = {
|
||||
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
|
||||
nameParts.quoted
|
||||
}
|
||||
resolver: Resolver = _ == _,
|
||||
context: Origin = Origin()): Option[(Seq[String], StructField)] = {
|
||||
|
||||
def findField(
|
||||
struct: StructType,
|
||||
searchPath: Seq[String],
|
||||
normalizedPath: Seq[String]): Option[(Seq[String], StructField)] = {
|
||||
searchPath.headOption.flatMap { searchName =>
|
||||
val found = struct.fields.filter(f => resolver(searchName, f.name))
|
||||
if (found.length > 1) {
|
||||
val names = found.map(f => prettyFieldName(normalizedPath :+ f.name))
|
||||
.mkString("[", ", ", " ]")
|
||||
throw QueryCompilationErrors.ambiguousFieldNameError(
|
||||
prettyFieldName(normalizedPath :+ searchName), names)
|
||||
} else if (found.isEmpty) {
|
||||
None
|
||||
assert(searchPath.nonEmpty)
|
||||
val searchName = searchPath.head
|
||||
val found = struct.fields.filter(f => resolver(searchName, f.name))
|
||||
if (found.length > 1) {
|
||||
throw QueryCompilationErrors.ambiguousFieldNameError(fieldNames, found.length, context)
|
||||
} else if (found.isEmpty) {
|
||||
None
|
||||
} else {
|
||||
val field = found.head
|
||||
val currentPath = normalizedPath :+ field.name
|
||||
val newSearchPath = searchPath.tail
|
||||
if (newSearchPath.isEmpty) {
|
||||
Some(normalizedPath -> field)
|
||||
} else {
|
||||
val field = found.head
|
||||
(searchPath.tail, field.dataType, includeCollections) match {
|
||||
case (Seq(), _, _) =>
|
||||
Some(normalizedPath -> field)
|
||||
(newSearchPath, field.dataType) match {
|
||||
case (_, s: StructType) =>
|
||||
findField(s, newSearchPath, currentPath)
|
||||
|
||||
case (names, struct: StructType, _) =>
|
||||
findField(struct, names, normalizedPath :+ field.name)
|
||||
case _ if !includeCollections =>
|
||||
throw QueryCompilationErrors.invalidFieldName(fieldNames, currentPath, context)
|
||||
|
||||
case (_, _, false) =>
|
||||
None // types nested in maps and arrays are not used
|
||||
case (Seq("key", rest @ _*), MapType(keyType, _, _)) =>
|
||||
findFieldInCollection(keyType, false, rest, currentPath, "key")
|
||||
|
||||
case (Seq("key"), MapType(keyType, _, _), true) =>
|
||||
// return the key type as a struct field to include nullability
|
||||
Some((normalizedPath :+ field.name) -> StructField("key", keyType, nullable = false))
|
||||
case (Seq("value", rest @ _*), MapType(_, valueType, isNullable)) =>
|
||||
findFieldInCollection(valueType, isNullable, rest, currentPath, "value")
|
||||
|
||||
case (Seq("key", names @ _*), MapType(struct: StructType, _, _), true) =>
|
||||
findField(struct, names, normalizedPath ++ Seq(field.name, "key"))
|
||||
|
||||
case (Seq("value"), MapType(_, valueType, isNullable), true) =>
|
||||
// return the value type as a struct field to include nullability
|
||||
Some((normalizedPath :+ field.name) ->
|
||||
StructField("value", valueType, nullable = isNullable))
|
||||
|
||||
case (Seq("value", names @ _*), MapType(_, struct: StructType, _), true) =>
|
||||
findField(struct, names, normalizedPath ++ Seq(field.name, "value"))
|
||||
|
||||
case (Seq("element"), ArrayType(elementType, isNullable), true) =>
|
||||
// return the element type as a struct field to include nullability
|
||||
Some((normalizedPath :+ field.name) ->
|
||||
StructField("element", elementType, nullable = isNullable))
|
||||
|
||||
case (Seq("element", names @ _*), ArrayType(struct: StructType, _), true) =>
|
||||
findField(struct, names, normalizedPath ++ Seq(field.name, "element"))
|
||||
case (Seq("element", rest @ _*), ArrayType(elementType, isNullable)) =>
|
||||
findFieldInCollection(elementType, isNullable, rest, currentPath, "element")
|
||||
|
||||
case _ =>
|
||||
throw QueryCompilationErrors.invalidFieldName(fieldNames, normalizedPath)
|
||||
throw QueryCompilationErrors.invalidFieldName(fieldNames, currentPath, context)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def findFieldInCollection(
|
||||
dt: DataType,
|
||||
nullable: Boolean,
|
||||
searchPath: Seq[String],
|
||||
normalizedPath: Seq[String],
|
||||
collectionFieldName: String): Option[(Seq[String], StructField)] = {
|
||||
if (searchPath.isEmpty) {
|
||||
Some(normalizedPath -> StructField(collectionFieldName, dt, nullable))
|
||||
} else {
|
||||
val newPath = normalizedPath :+ collectionFieldName
|
||||
dt match {
|
||||
case s: StructType =>
|
||||
findField(s, searchPath, newPath)
|
||||
case _ =>
|
||||
throw QueryCompilationErrors.invalidFieldName(fieldNames, newPath, context)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
findField(this, fieldNames, Nil)
|
||||
}
|
||||
|
||||
|
|
|
@ -18,6 +18,8 @@
|
|||
package org.apache.spark.sql.types
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.AnalysisException
|
||||
import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, caseSensitiveResolution}
|
||||
import org.apache.spark.sql.catalyst.parser.ParseException
|
||||
import org.apache.spark.sql.catalyst.plans.SQLHelper
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
|
@ -273,4 +275,111 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper {
|
|||
checkIntervalDDL(start, end, DT.fieldToString)
|
||||
}
|
||||
}
|
||||
|
||||
test("findNestedField") {
  // Schema under test: two nested structs (one with names that differ only by case),
  // maps and arrays of primitives, and maps and arrays of structs.
  val innerStruct = new StructType()
    .add("s11", "int")
    .add("s12", "int")
  val input = new StructType()
    .add("s1", innerStruct)
    .add("s2", new StructType().add("x", "int").add("X", "int"))
    .add("m1", MapType(IntegerType, IntegerType))
    .add("m2", MapType(
      new StructType().add("a", "int"),
      new StructType().add("b", "int")
    ))
    .add("a1", ArrayType(IntegerType))
    .add("a2", ArrayType(new StructType().add("c", "int")))

  // Looks up `field` case-insensitively without descending into maps/arrays.
  def check(field: Seq[String], expect: Option[(Seq[String], StructField)]): Unit = {
    val res = input.findNestedField(field, resolver = caseInsensitiveResolution)
    assert(res == expect)
  }

  // Same as `check`, but name parts must match with exact case.
  def caseSensitiveCheck(field: Seq[String], expect: Option[(Seq[String], StructField)]): Unit = {
    val res = input.findNestedField(field, resolver = caseSensitiveResolution)
    assert(res == expect)
  }

  // Case-insensitive lookup that is also allowed to go through "key"/"value"/"element".
  def checkCollection(field: Seq[String], expect: Option[(Seq[String], StructField)]): Unit = {
    val res = input.findNestedField(field,
      includeCollections = true, resolver = caseInsensitiveResolution)
    assert(res == expect)
  }

  // struct type
  check(Seq("non_exist"), None)
  check(Seq("S1"), Some(Nil -> StructField("s1", innerStruct)))
  caseSensitiveCheck(Seq("S1"), None)
  check(Seq("s1", "S12"), Some(Seq("s1") -> StructField("s12", IntegerType)))
  caseSensitiveCheck(Seq("s1", "S12"), None)
  // A dotted string is one name part: it is not split, so nothing matches at the top level.
  check(Seq("S1.non_exist"), None)
  // An existing struct with a missing nested field yields None rather than an error.
  check(Seq("S1", "non_exist"), None)
  var e = intercept[AnalysisException] {
    check(Seq("S1", "S12", "S123"), None)
  }
  assert(e.getMessage.contains("Field name S1.S12.S123 is invalid: s1.s12 is not a struct"))

  // ambiguous name
  e = intercept[AnalysisException] {
    check(Seq("S2", "x"), None)
  }
  assert(e.getMessage.contains(
    "Field name S2.x is ambiguous and has 2 matching fields in the struct"))
  caseSensitiveCheck(Seq("s2", "x"), Some(Seq("s2") -> StructField("x", IntegerType)))

  // simple map type
  e = intercept[AnalysisException] {
    check(Seq("m1", "key"), None)
  }
  assert(e.getMessage.contains("Field name m1.key is invalid: m1 is not a struct"))
  checkCollection(Seq("m1", "key"), Some(Seq("m1") -> StructField("key", IntegerType, false)))
  checkCollection(Seq("M1", "value"), Some(Seq("m1") -> StructField("value", IntegerType)))
  e = intercept[AnalysisException] {
    checkCollection(Seq("M1", "key", "name"), None)
  }
  assert(e.getMessage.contains("Field name M1.key.name is invalid: m1.key is not a struct"))
  e = intercept[AnalysisException] {
    checkCollection(Seq("M1", "value", "name"), None)
  }
  assert(e.getMessage.contains("Field name M1.value.name is invalid: m1.value is not a struct"))

  // map of struct
  checkCollection(Seq("M2", "key", "A"),
    Some(Seq("m2", "key") -> StructField("a", IntegerType)))
  checkCollection(Seq("M2", "key", "non_exist"), None)
  checkCollection(Seq("M2", "value", "b"),
    Some(Seq("m2", "value") -> StructField("b", IntegerType)))
  checkCollection(Seq("M2", "value", "non_exist"), None)
  e = intercept[AnalysisException] {
    checkCollection(Seq("m2", "key", "A", "name"), None)
  }
  assert(e.getMessage.contains("Field name m2.key.A.name is invalid: m2.key.a is not a struct"))
  e = intercept[AnalysisException] {
    checkCollection(Seq("M2", "value", "b", "name"), None)
  }
  assert(e.getMessage.contains(
    "Field name M2.value.b.name is invalid: m2.value.b is not a struct"))

  // simple array type
  e = intercept[AnalysisException] {
    check(Seq("A1", "element"), None)
  }
  assert(e.getMessage.contains("Field name A1.element is invalid: a1 is not a struct"))
  checkCollection(Seq("A1", "element"), Some(Seq("a1") -> StructField("element", IntegerType)))
  e = intercept[AnalysisException] {
    checkCollection(Seq("A1", "element", "name"), None)
  }
  assert(e.getMessage.contains(
    "Field name A1.element.name is invalid: a1.element is not a struct"))

  // array of struct
  checkCollection(Seq("A2", "element", "C"),
    Some(Seq("a2", "element") -> StructField("c", IntegerType)))
  checkCollection(Seq("A2", "element", "non_exist"), None)
  e = intercept[AnalysisException] {
    checkCollection(Seq("a2", "element", "C", "name"), None)
  }
  assert(e.getMessage.contains(
    "Field name a2.element.C.name is invalid: a2.element.c is not a struct"))
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue