diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json
index 02feb9dcc0..6ab113bee8 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -1,4 +1,8 @@
 {
+  "AMBIGUOUS_FIELD_NAME" : {
+    "message" : [ "Field name %s is ambiguous and has %s matching fields in the struct." ],
+    "sqlState" : "42000"
+  },
   "DIVIDE_BY_ZERO" : {
     "message" : [ "divide by zero" ],
     "sqlState" : "22012"
@@ -7,6 +11,10 @@
     "message" : [ "Found duplicate keys '%s'" ],
     "sqlState" : "23000"
   },
+  "INVALID_FIELD_NAME" : {
+    "message" : [ "Field name %s is invalid: %s is not a struct." ],
+    "sqlState" : "42000"
+  },
   "MISSING_COLUMN" : {
     "message" : [ "cannot resolve '%s' given input columns: [%s]" ],
     "sqlState" : "42000"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala
index 6299431a49..d0a3a713d6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql
 import org.apache.spark.{SparkThrowable, SparkThrowableHelper}
 import org.apache.spark.annotation.Stable
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.trees.Origin
 
 /**
  * Thrown when a query fails to analyze, usually because the query itself is invalid.
@@ -48,12 +49,11 @@ class AnalysisException protected[sql] (
   def this(
       errorClass: String,
       messageParameters: Array[String],
-      line: Option[Int],
-      startPosition: Option[Int]) =
+      origin: Origin) =
     this(
       SparkThrowableHelper.getMessage(errorClass, messageParameters),
-      line = line,
-      startPosition = startPosition,
+      line = origin.line,
+      startPosition = origin.startPosition,
       errorClass = Some(errorClass),
       messageParameters = messageParameters)
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 2d747f79c1..64f6b79fa2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -3613,7 +3613,9 @@ class Analyzer(override val catalogManager: CatalogManager)
       table: ResolvedTable,
       fieldName: Seq[String],
       context: Expression): ResolvedFieldName = {
-    table.schema.findNestedField(fieldName, includeCollections = true, conf.resolver).map {
+    table.schema.findNestedField(
+      fieldName, includeCollections = true, conf.resolver, context.origin
+    ).map {
       case (path, field) => ResolvedFieldName(path, field)
     }.getOrElse(throw QueryCompilationErrors.missingFieldError(fieldName, table, context))
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala
index 8ad8706cc5..81683adc23 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala
@@ -51,8 +51,7 @@ package object analysis {
       throw new AnalysisException(
         errorClass = errorClass,
         messageParameters = messageParameters,
-        line = t.origin.line,
-        startPosition = t.origin.startPosition)
+        origin = t.origin)
     }
   }
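For reference, a minimal sketch of the new `AnalysisException` constructor in use (hypothetical values; `MISSING_COLUMN` is just a stand-in error class from error-classes.json):

```scala
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.trees.Origin

// The Origin of the offending tree node now travels as one value instead of
// two separate Options for line and start position.
val origin = Origin(line = Some(3), startPosition = Some(12))
val e = new AnalysisException(
  errorClass = "MISSING_COLUMN",
  messageParameters = Array("c", "a, b"),
  origin = origin)
// e.line == Some(3) and e.startPosition == Some(12), recovered from origin
```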
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index d1dcbbcf6c..6322676661 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, CreateMap, Expression, GroupingID, NamedExpression, SpecifiedWindowFrame, WindowFrame, WindowFunction, WindowSpecDefinition}
 import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoStatement, Join, LogicalPlan, SerdeInfo, Window}
-import org.apache.spark.sql.catalyst.trees.TreeNode
+import org.apache.spark.sql.catalyst.trees.{Origin, TreeNode}
 import org.apache.spark.sql.catalyst.util.{toPrettySQL, FailFastMode, ParseMode, PermissiveMode}
 import org.apache.spark.sql.connector.catalog._
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
@@ -1352,9 +1352,12 @@ private[spark] object QueryCompilationErrors {
       s"${evalTypes.mkString(",")}")
   }
 
-  def ambiguousFieldNameError(fieldName: String, names: String): Throwable = {
+  def ambiguousFieldNameError(
+      fieldName: Seq[String], numMatches: Int, context: Origin): Throwable = {
     new AnalysisException(
-      s"Ambiguous field name: $fieldName. Found multiple columns that can match: $names")
+      errorClass = "AMBIGUOUS_FIELD_NAME",
+      messageParameters = Array(fieldName.quoted, numMatches.toString),
+      origin = context)
   }
 
   def cannotUseIntervalTypeInTableSchemaError(): Throwable = {
@@ -2359,8 +2362,10 @@ private[spark] object QueryCompilationErrors {
       context.origin.startPosition)
   }
 
-  def invalidFieldName(fieldName: Seq[String], path: Seq[String]): Throwable = {
+  def invalidFieldName(fieldName: Seq[String], path: Seq[String], context: Origin): Throwable = {
     new AnalysisException(
-      s"Field name ${fieldName.quoted} is invalid, ${path.quoted} is not a struct.")
+      errorClass = "INVALID_FIELD_NAME",
+      messageParameters = Array(fieldName.quoted, path.quoted),
+      origin = context)
   }
 }
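A rough sketch of what the rewritten helpers now produce (made-up field names and position; the message text comes from the `AMBIGUOUS_FIELD_NAME` template added above):

```scala
import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.errors.QueryCompilationErrors

// Both helpers now take the Origin of the command that referenced the field,
// so the thrown AnalysisException carries line/startPosition information.
val err = QueryCompilationErrors.ambiguousFieldNameError(
  fieldName = Seq("s2", "x"),
  numMatches = 2,
  context = Origin(line = Some(1), startPosition = Some(7)))
// err.getMessage contains:
//   "Field name s2.x is ambiguous and has 2 matching fields in the struct."
```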
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index a9ba2b968d..87ff4eb571 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -27,6 +27,7 @@ import org.apache.spark.annotation.Stable
 import org.apache.spark.sql.catalyst.analysis.Resolver
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering}
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, LegacyTypeStringParser}
+import org.apache.spark.sql.catalyst.trees.Origin
 import org.apache.spark.sql.catalyst.util.{truncatedString, StringUtils}
 import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
@@ -317,66 +318,69 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
   private[sql] def findNestedField(
       fieldNames: Seq[String],
       includeCollections: Boolean = false,
-      resolver: Resolver = _ == _): Option[(Seq[String], StructField)] = {
-    def prettyFieldName(nameParts: Seq[String]): String = {
-      import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
-      nameParts.quoted
-    }
+      resolver: Resolver = _ == _,
+      context: Origin = Origin()): Option[(Seq[String], StructField)] = {
 
     def findField(
         struct: StructType,
         searchPath: Seq[String],
         normalizedPath: Seq[String]): Option[(Seq[String], StructField)] = {
-      searchPath.headOption.flatMap { searchName =>
-        val found = struct.fields.filter(f => resolver(searchName, f.name))
-        if (found.length > 1) {
-          val names = found.map(f => prettyFieldName(normalizedPath :+ f.name))
-            .mkString("[", ", ", " ]")
-          throw QueryCompilationErrors.ambiguousFieldNameError(
-            prettyFieldName(normalizedPath :+ searchName), names)
-        } else if (found.isEmpty) {
-          None
+      assert(searchPath.nonEmpty)
+      val searchName = searchPath.head
+      val found = struct.fields.filter(f => resolver(searchName, f.name))
+      if (found.length > 1) {
+        throw QueryCompilationErrors.ambiguousFieldNameError(fieldNames, found.length, context)
+      } else if (found.isEmpty) {
+        None
+      } else {
+        val field = found.head
+        val currentPath = normalizedPath :+ field.name
+        val newSearchPath = searchPath.tail
+        if (newSearchPath.isEmpty) {
+          Some(normalizedPath -> field)
         } else {
-          val field = found.head
-          (searchPath.tail, field.dataType, includeCollections) match {
-            case (Seq(), _, _) =>
-              Some(normalizedPath -> field)
+          (newSearchPath, field.dataType) match {
+            case (_, s: StructType) =>
+              findField(s, newSearchPath, currentPath)
 
-            case (names, struct: StructType, _) =>
-              findField(struct, names, normalizedPath :+ field.name)
+            case _ if !includeCollections =>
+              throw QueryCompilationErrors.invalidFieldName(fieldNames, currentPath, context)
 
-            case (_, _, false) =>
-              None // types nested in maps and arrays are not used
+            case (Seq("key", rest @ _*), MapType(keyType, _, _)) =>
+              findFieldInCollection(keyType, false, rest, currentPath, "key")
 
-            case (Seq("key"), MapType(keyType, _, _), true) =>
-              // return the key type as a struct field to include nullability
-              Some((normalizedPath :+ field.name) -> StructField("key", keyType, nullable = false))
+            case (Seq("value", rest @ _*), MapType(_, valueType, isNullable)) =>
+              findFieldInCollection(valueType, isNullable, rest, currentPath, "value")
 
-            case (Seq("key", names @ _*), MapType(struct: StructType, _, _), true) =>
-              findField(struct, names, normalizedPath ++ Seq(field.name, "key"))
-
-            case (Seq("value"), MapType(_, valueType, isNullable), true) =>
-              // return the value type as a struct field to include nullability
-              Some((normalizedPath :+ field.name) ->
-                StructField("value", valueType, nullable = isNullable))
-
-            case (Seq("value", names @ _*), MapType(_, struct: StructType, _), true) =>
-              findField(struct, names, normalizedPath ++ Seq(field.name, "value"))
-
-            case (Seq("element"), ArrayType(elementType, isNullable), true) =>
-              // return the element type as a struct field to include nullability
-              Some((normalizedPath :+ field.name) ->
-                StructField("element", elementType, nullable = isNullable))
-
-            case (Seq("element", names @ _*), ArrayType(struct: StructType, _), true) =>
-              findField(struct, names, normalizedPath ++ Seq(field.name, "element"))
+            case (Seq("element", rest @ _*), ArrayType(elementType, isNullable)) =>
+              findFieldInCollection(elementType, isNullable, rest, currentPath, "element")
 
             case _ =>
-              throw QueryCompilationErrors.invalidFieldName(fieldNames, normalizedPath)
+              throw QueryCompilationErrors.invalidFieldName(fieldNames, currentPath, context)
           }
         }
       }
     }
+
+    def findFieldInCollection(
+        dt: DataType,
+        nullable: Boolean,
+        searchPath: Seq[String],
+        normalizedPath: Seq[String],
+        collectionFieldName: String): Option[(Seq[String], StructField)] = {
+      if (searchPath.isEmpty) {
+        Some(normalizedPath -> StructField(collectionFieldName, dt, nullable))
+      } else {
+        val newPath = normalizedPath :+ collectionFieldName
+        dt match {
+          case s: StructType =>
+            findField(s, searchPath, newPath)
+          case _ =>
+            throw QueryCompilationErrors.invalidFieldName(fieldNames, newPath, context)
+        }
+      }
+    }
+
     findField(this, fieldNames, Nil)
   }
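The refactor keeps `findNestedField`'s contract: the result pairs the normalized path of the parent with the resolved leaf field, and map keys/values and array elements are surfaced as synthetic struct fields when `includeCollections` is set. A small sketch of that contract (the method is `private[sql]`, so this only compiles inside Spark's own sql packages):

```scala
import org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
import org.apache.spark.sql.types._

val schema = new StructType()
  .add("s1", new StructType().add("s11", IntegerType))
  .add("a1", ArrayType(IntegerType))

// Struct member: the parent path is normalized to the schema's casing.
schema.findNestedField(Seq("S1", "S11"), resolver = caseInsensitiveResolution)
// => Some(Seq("s1") -> StructField("s11", IntegerType, true))

// Array element: nullability comes from the array's containsNull flag.
schema.findNestedField(Seq("a1", "element"),
  includeCollections = true, resolver = caseInsensitiveResolution)
// => Some(Seq("a1") -> StructField("element", IntegerType, true))
```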
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala
index 18821b8e1c..8db3831392 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala
@@ -18,6 +18,8 @@
 package org.apache.spark.sql.types
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, caseSensitiveResolution}
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.catalyst.plans.SQLHelper
 import org.apache.spark.sql.internal.SQLConf
@@ -273,4 +275,111 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper {
       checkIntervalDDL(start, end, DT.fieldToString)
     }
   }
+
+  test("findNestedField") {
+    val innerStruct = new StructType()
+      .add("s11", "int")
+      .add("s12", "int")
+    val input = new StructType()
+      .add("s1", innerStruct)
+      .add("s2", new StructType().add("x", "int").add("X", "int"))
+      .add("m1", MapType(IntegerType, IntegerType))
+      .add("m2", MapType(
+        new StructType().add("a", "int"),
+        new StructType().add("b", "int")
+      ))
+      .add("a1", ArrayType(IntegerType))
+      .add("a2", ArrayType(new StructType().add("c", "int")))
+
+    def check(field: Seq[String], expect: Option[(Seq[String], StructField)]): Unit = {
+      val res = input.findNestedField(field, resolver = caseInsensitiveResolution)
+      assert(res == expect)
+    }
+
+    def caseSensitiveCheck(field: Seq[String], expect: Option[(Seq[String], StructField)]): Unit = {
+      val res = input.findNestedField(field, resolver = caseSensitiveResolution)
+      assert(res == expect)
+    }
+
+    def checkCollection(field: Seq[String], expect: Option[(Seq[String], StructField)]): Unit = {
+      val res = input.findNestedField(field,
+        includeCollections = true, resolver = caseInsensitiveResolution)
+      assert(res == expect)
+    }
+
+    // struct type
+    check(Seq("non_exist"), None)
+    check(Seq("S1"), Some(Nil -> StructField("s1", innerStruct)))
+    caseSensitiveCheck(Seq("S1"), None)
+    check(Seq("s1", "S12"), Some(Seq("s1") -> StructField("s12", IntegerType)))
+    caseSensitiveCheck(Seq("s1", "S12"), None)
+    check(Seq("S1.non_exist"), None)
+    var e = intercept[AnalysisException] {
+      check(Seq("S1", "S12", "S123"), None)
+    }
+    assert(e.getMessage.contains("Field name S1.S12.S123 is invalid: s1.s12 is not a struct"))
+
+    // ambiguous name
+    e = intercept[AnalysisException] {
+      check(Seq("S2", "x"), None)
+    }
+    assert(e.getMessage.contains(
+      "Field name S2.x is ambiguous and has 2 matching fields in the struct"))
+    caseSensitiveCheck(Seq("s2", "x"), Some(Seq("s2") -> StructField("x", IntegerType)))
+
+    // simple map type
+    e = intercept[AnalysisException] {
+      check(Seq("m1", "key"), None)
+    }
+    assert(e.getMessage.contains("Field name m1.key is invalid: m1 is not a struct"))
+    checkCollection(Seq("m1", "key"), Some(Seq("m1") -> StructField("key", IntegerType, false)))
+    checkCollection(Seq("M1", "value"), Some(Seq("m1") -> StructField("value", IntegerType)))
+    e = intercept[AnalysisException] {
+      checkCollection(Seq("M1", "key", "name"), None)
+    }
+    assert(e.getMessage.contains("Field name M1.key.name is invalid: m1.key is not a struct"))
+    e = intercept[AnalysisException] {
+      checkCollection(Seq("M1", "value", "name"), None)
+    }
+    assert(e.getMessage.contains("Field name M1.value.name is invalid: m1.value is not a struct"))
+
+    // map of struct
+    checkCollection(Seq("M2", "key", "A"),
+      Some(Seq("m2", "key") -> StructField("a", IntegerType)))
+    checkCollection(Seq("M2", "key", "non_exist"), None)
+    checkCollection(Seq("M2", "value", "b"),
+      Some(Seq("m2", "value") -> StructField("b", IntegerType)))
+    checkCollection(Seq("M2", "value", "non_exist"), None)
+    e = intercept[AnalysisException] {
+      checkCollection(Seq("m2", "key", "A", "name"), None)
+    }
+    assert(e.getMessage.contains("Field name m2.key.A.name is invalid: m2.key.a is not a struct"))
+    e = intercept[AnalysisException] {
+      checkCollection(Seq("M2", "value", "b", "name"), None)
+    }
+    assert(e.getMessage.contains(
+      "Field name M2.value.b.name is invalid: m2.value.b is not a struct"))
+
+    // simple array type
+    e = intercept[AnalysisException] {
+      check(Seq("A1", "element"), None)
+    }
+    assert(e.getMessage.contains("Field name A1.element is invalid: a1 is not a struct"))
+    checkCollection(Seq("A1", "element"), Some(Seq("a1") -> StructField("element", IntegerType)))
+    e = intercept[AnalysisException] {
+      checkCollection(Seq("A1", "element", "name"), None)
+    }
+    assert(e.getMessage.contains(
+      "Field name A1.element.name is invalid: a1.element is not a struct"))
+
+    // array of struct
+    checkCollection(Seq("A2", "element", "C"),
+      Some(Seq("a2", "element") -> StructField("c", IntegerType)))
+    checkCollection(Seq("A2", "element", "non_exist"), None)
+    e = intercept[AnalysisException] {
+      checkCollection(Seq("a2", "element", "C", "name"), None)
+    }
+    assert(e.getMessage.contains(
+      "Field name a2.element.C.name is invalid: a2.element.c is not a struct"))
+  }
 }
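End to end, the new `context` parameter means DDL commands that resolve field names should now report the error class together with the position of the offending clause. A hedged illustration in test-suite style (assumes a `sql` helper, a registered v2 catalog `testcat`, and a table `testcat.t` with a non-struct column `i`; all names are invented):

```scala
import org.apache.spark.sql.AnalysisException

// ALTER COLUMN on a path through a non-struct column should now fail with
// the INVALID_FIELD_NAME error class rather than a bare message string.
val e = intercept[AnalysisException] {
  sql("ALTER TABLE testcat.t ALTER COLUMN i.x TYPE INT")
}
assert(e.getErrorClass == "INVALID_FIELD_NAME")
assert(e.getMessage.contains("Field name i.x is invalid: i is not a struct"))
// Unlike before, the exception should also carry the query position,
// threaded through from the ALTER command's Origin.
assert(e.line.isDefined)
```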