From 809b88a16287ffb87835b72419fdc9150b9de0e0 Mon Sep 17 00:00:00 2001
From: Terry Kim
Date: Wed, 28 Jul 2021 14:00:29 +0800
Subject: [PATCH] [SPARK-36006][SQL] Migrate ALTER TABLE ... ADD/REPLACE COLUMNS commands to use UnresolvedTable to resolve the identifier

### What changes were proposed in this pull request?

This PR proposes to migrate the `ALTER TABLE ... ADD/REPLACE COLUMNS` commands to use `UnresolvedTable` as a `child` to resolve the table identifier. This allows the consistent resolution rules (temp view first, etc.) to be applied to both v1 and v2 commands. More info about the consistent resolution rule proposal can be found in the [JIRA ticket](https://issues.apache.org/jira/browse/SPARK-29900) or the [proposal doc](https://docs.google.com/document/d/1hvLjGA8y_W_hhilpngXVub1Ebv8RsMap986nENCFnrg/edit?usp=sharing).

### Why are the changes needed?

This is part of the effort to make the relation lookup behavior consistent: [SPARK-29900](https://issues.apache.org/jira/browse/SPARK-29900).

### Does this PR introduce _any_ user-facing change?

After this PR, the above `ALTER TABLE ... ADD/REPLACE COLUMNS` commands have consistent resolution behavior.

### How was this patch tested?

Updated existing tests.

Closes #33200 from imback82/alter_add_cols.

Authored-by: Terry Kim
Signed-off-by: Wenchen Fan
---
 .../sql/catalyst/analysis/Analyzer.scala      | 167 +++++++-----------
 .../sql/catalyst/analysis/CheckAnalysis.scala | 105 ++---------
 .../catalyst/analysis/ResolveCatalogs.scala   |  35 +---
 .../sql/catalyst/parser/AstBuilder.scala      |  25 ++-
 .../catalyst/plans/logical/statements.scala   |  23 +--
 .../catalyst/plans/logical/v2Commands.scala   | 117 +++++++-----
 .../sql/connector/catalog/CatalogV2Util.scala |  15 +-
 .../sql/errors/QueryCompilationErrors.scala   |   6 +-
 ...eateTablePartitioningValidationSuite.scala |   9 +-
 .../sql/catalyst/parser/DDLParserSuite.scala  | 116 ++++++++----
 .../analysis/ResolveSessionCatalog.scala      |  52 +-----
 .../datasources/v2/DataSourceV2Strategy.scala |   5 +-
 .../spark/sql/connector/AlterTableTests.scala |   3 +-
 .../V2CommandsCaseSensitivitySuite.scala      | 104 ++++++-----
 .../sql/execution/command/DDLSuite.scala      |   6 +-
 15 files changed, 335 insertions(+), 453 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 3e048ed9b6..f031f0816d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -44,7 +44,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils}
 import org.apache.spark.sql.connector.catalog._
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
-import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, After, ColumnChange, ColumnPosition, DeleteColumn}
+import org.apache.spark.sql.connector.catalog.TableChange.{After, ColumnPosition}
 import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction, BoundFunction, ScalarFunction}
 import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME
 import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform}
@@ -269,7 +269,7 @@ class Analyzer(override val catalogManager: CatalogManager)
       ResolveRelations ::
       ResolveTables ::
       ResolvePartitionSpec ::
-      ResolveAlterTableCommands ::
ResolveAlterTableCommands :: + ResolveAlterTableColumnCommands :: AddMetadataColumns :: DeduplicateRelations :: ResolveReferences :: @@ -312,7 +312,6 @@ class Analyzer(override val catalogManager: CatalogManager) Batch("Post-Hoc Resolution", Once, Seq(ResolveCommandsWithIfExists) ++ postHocResolutionRules: _*), - Batch("Normalize Alter Table", Once, ResolveAlterTableChanges), Batch("Remove Unresolved Hints", Once, new ResolveHints.RemoveAllHints), Batch("Nondeterministic", Once, @@ -1082,11 +1081,6 @@ class Analyzer(override val catalogManager: CatalogManager) case _ => write } - case alter @ AlterTable(_, _, u: UnresolvedV2Relation, _) => - CatalogV2Util.loadRelation(u.catalog, u.tableName) - .map(rel => alter.copy(table = rel)) - .getOrElse(alter) - case u: UnresolvedV2Relation => CatalogV2Util.loadRelation(u.catalog, u.tableName).getOrElse(u) } @@ -3611,16 +3605,69 @@ class Analyzer(override val catalogManager: CatalogManager) /** * Rule to mostly resolve, normalize and rewrite column names based on case sensitivity - * for alter table commands. + * for alter table column commands. */ - object ResolveAlterTableCommands extends Rule[LogicalPlan] { + object ResolveAlterTableColumnCommands extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { - case a: AlterTableCommand if a.table.resolved && hasUnresolvedFieldName(a) => + case a: AlterTableColumnCommand if a.table.resolved && hasUnresolvedFieldName(a) => val table = a.table.asInstanceOf[ResolvedTable] a.transformExpressions { case u: UnresolvedFieldName => resolveFieldNames(table, u.name, u) } + case a @ AlterTableAddColumns(r: ResolvedTable, cols) if !a.resolved => + // 'colsToAdd' keeps track of new columns being added. It stores a mapping from a + // normalized parent name of fields to field names that belong to the parent. + // For example, if we add columns "a.b.c", "a.b.d", and "a.c", 'colsToAdd' will become + // Map(Seq("a", "b") -> Seq("c", "d"), Seq("a") -> Seq("c")). + val colsToAdd = mutable.Map.empty[Seq[String], Seq[String]] + def resolvePosition( + col: QualifiedColType, + parentSchema: StructType, + resolvedParentName: Seq[String]): Option[FieldPosition] = { + val fieldsAdded = colsToAdd.getOrElse(resolvedParentName, Nil) + val resolvedPosition = col.position.map { + case u: UnresolvedFieldPosition => u.position match { + case after: After => + val allFields = parentSchema.fieldNames ++ fieldsAdded + allFields.find(n => conf.resolver(n, after.column())) match { + case Some(colName) => + ResolvedFieldPosition(ColumnPosition.after(colName)) + case None => + val name = if (resolvedParentName.isEmpty) "root" else resolvedParentName.quoted + throw QueryCompilationErrors.referenceColNotFoundForAlterTableChangesError( + after, name) + } + case _ => ResolvedFieldPosition(u.position) + } + case resolved => resolved + } + colsToAdd(resolvedParentName) = fieldsAdded :+ col.colName + resolvedPosition + } + val schema = r.table.schema + val resolvedCols = cols.map { col => + col.path match { + case Some(parent: UnresolvedFieldName) => + // Adding a nested field, need to resolve the parent column and position. 
+              val resolvedParent = resolveFieldNames(r, parent.name, parent)
+              val parentSchema = resolvedParent.field.dataType match {
+                case s: StructType => s
+                case _ => throw QueryCompilationErrors.invalidFieldName(
+                  col.name, parent.name, parent.origin)
+              }
+              val resolvedPosition = resolvePosition(col, parentSchema, resolvedParent.name)
+              col.copy(path = Some(resolvedParent), position = resolvedPosition)
+            case _ =>
+              // Adding to the root. Just need to resolve position.
+              val resolvedPosition = resolvePosition(col, schema, Nil)
+              col.copy(position = resolvedPosition)
+          }
+        }
+        val resolved = a.copy(columnsToAdd = resolvedCols)
+        resolved.copyTagsFrom(a)
+        resolved
+
       case a @ AlterTableAlterColumn(
           table: ResolvedTable, ResolvedFieldName(path, field), dataType, _, _, position) =>
         val newDataType = dataType.flatMap { dt =>
@@ -3655,108 +3702,14 @@ class Analyzer(override val catalogManager: CatalogManager)
           fieldName, includeCollections = true, conf.resolver, context.origin
         ).map { case (path, field) => ResolvedFieldName(path, field)
-      }.getOrElse(throw QueryCompilationErrors.missingFieldError(fieldName, table, context))
+      }.getOrElse(throw QueryCompilationErrors.missingFieldError(fieldName, table, context.origin))
     }

-    private def hasUnresolvedFieldName(a: AlterTableCommand): Boolean = {
+    private def hasUnresolvedFieldName(a: AlterTableColumnCommand): Boolean = {
       a.expressions.exists(_.find(_.isInstanceOf[UnresolvedFieldName]).isDefined)
     }
   }

-  /** Rule to mostly resolve, normalize and rewrite column names based on case sensitivity. */
-  object ResolveAlterTableChanges extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
-      case a @ AlterTable(_, _, t: NamedRelation, changes) if t.resolved =>
-        // 'colsToAdd' keeps track of new columns being added. It stores a mapping from a
-        // normalized parent name of fields to field names that belong to the parent.
-        // For example, if we add columns "a.b.c", "a.b.d", and "a.c", 'colsToAdd' will become
-        // Map(Seq("a", "b") -> Seq("c", "d"), Seq("a") -> Seq("c")).
-        val colsToAdd = mutable.Map.empty[Seq[String], Seq[String]]
-        val schema = t.schema
-        val normalizedChanges = changes.flatMap {
-          case add: AddColumn =>
-            def addColumn(
-                parentSchema: StructType,
-                parentName: String,
-                normalizedParentName: Seq[String]): TableChange = {
-              val fieldsAdded = colsToAdd.getOrElse(normalizedParentName, Nil)
-              val pos = findColumnPosition(add.position(), parentName, parentSchema, fieldsAdded)
-              val field = add.fieldNames().last
-              colsToAdd(normalizedParentName) = fieldsAdded :+ field
-              TableChange.addColumn(
-                (normalizedParentName :+ field).toArray,
-                add.dataType(),
-                add.isNullable,
-                add.comment,
-                pos)
-            }
-            val parent = add.fieldNames().init
-            if (parent.nonEmpty) {
-              // Adding a nested field, need to normalize the parent column and position
-              val target = schema.findNestedField(parent, includeCollections = true, conf.resolver)
-              if (target.isEmpty) {
-                // Leave unresolved. Throws error in CheckAnalysis
-                Some(add)
-              } else {
-                val (normalizedName, sf) = target.get
-                sf.dataType match {
-                  case struct: StructType =>
-                    Some(addColumn(struct, parent.quoted, normalizedName :+ sf.name))
-                  case other =>
-                    Some(add)
-                }
-              }
-            } else {
-              // Adding to the root. Just need to normalize position
-              Some(addColumn(schema, "root", Nil))
-            }
-
-          case delete: DeleteColumn =>
-            resolveFieldNames(schema, delete.fieldNames(), TableChange.deleteColumn)
-              .orElse(Some(delete))
-
-          case column: ColumnChange =>
-            // This is informational for future developers
-            throw QueryExecutionErrors.columnChangeUnsupportedError
-          case other => Some(other)
-        }
-
-        a.copy(changes = normalizedChanges)
-    }
-
-    /**
-     * Returns the table change if the field can be resolved, returns None if the column is not
-     * found. An error will be thrown in CheckAnalysis for columns that can't be resolved.
-     */
-    private def resolveFieldNames(
-        schema: StructType,
-        fieldNames: Array[String],
-        copy: Array[String] => TableChange): Option[TableChange] = {
-      val fieldOpt = schema.findNestedField(
-        fieldNames, includeCollections = true, conf.resolver)
-      fieldOpt.map { case (path, field) => copy((path :+ field.name).toArray) }
-    }
-
-    private def findColumnPosition(
-        position: ColumnPosition,
-        parentName: String,
-        struct: StructType,
-        fieldsAdded: Seq[String]): ColumnPosition = {
-      position match {
-        case null => null
-        case after: After =>
-          (struct.fieldNames ++ fieldsAdded).find(n => conf.resolver(n, after.column())) match {
-            case Some(colName) =>
-              ColumnPosition.after(colName)
-            case None =>
-              throw QueryCompilationErrors.referenceColNotFoundForAlterTableChangesError(after,
-                parentName)
-          }
-        case other => other
-      }
-    }
-  }
-
   /**
    * A rule that marks a command as analyzed so that its children are removed to avoid
    * being optimized. This rule should run after all other analysis rules are run.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index c1578483ca..2d8ac6446e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, TypeUtils}
 import org.apache.spark.sql.connector.catalog.{LookupCatalog, SupportsPartitionManagement}
-import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, After, ColumnPosition, DeleteColumn}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -140,13 +139,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
       case u: UnresolvedV2Relation =>
         u.failAnalysis(s"Table not found: ${u.originalNameParts.quoted}")

-      case AlterTable(_, _, u: UnresolvedV2Relation, _) if isView(u.originalNameParts) =>
-        u.failAnalysis(
-          s"Invalid command: '${u.originalNameParts.quoted}' is a view not a table.")
-
-      case AlterTable(_, _, u: UnresolvedV2Relation, _) =>
-        failAnalysis(s"Table not found: ${u.originalNameParts.quoted}")
-
       case command: V2PartitionCommand =>
         command.table match {
           case r @ ResolvedTable(_, _, table, _) => table match {
@@ -449,87 +441,8 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
       case write: V2WriteCommand if write.resolved =>
         write.query.schema.foreach(f => TypeUtils.failWithIntervalType(f.dataType))

-      case alter: AlterTableCommand if alter.table.resolved =>
-        checkAlterTableCommand(alter)
-
-      case alter: AlterTable if alter.table.resolved =>
-        val table = alter.table
-        def findField(operation: String, fieldName: Array[String]): StructField = {
-          // include collections because structs nested in maps and arrays may be altered
-          val field = table.schema.findNestedField(fieldName, includeCollections = true)
-          if (field.isEmpty) {
-            alter.failAnalysis(
-              s"Cannot $operation missing field ${fieldName.quoted} in ${table.name} schema: " +
-                table.schema.treeString)
-          }
-          field.get._2
-        }
-        def positionArgumentExists(
-            position: ColumnPosition,
-            struct: StructType,
-            fieldsAdded: Seq[String]): Unit = {
-          position match {
-            case after: After =>
-              val allFields = struct.fieldNames ++ fieldsAdded
-              if (!allFields.contains(after.column())) {
-                alter.failAnalysis(s"Couldn't resolve positional argument $position amongst " +
-                  s"${allFields.mkString("[", ", ", "]")}")
-              }
-            case _ =>
-          }
-        }
-        def findParentStruct(operation: String, fieldNames: Array[String]): StructType = {
-          val parent = fieldNames.init
-          val field = if (parent.nonEmpty) {
-            findField(operation, parent).dataType
-          } else {
-            table.schema
-          }
-          field match {
-            case s: StructType => s
-            case o => alter.failAnalysis(s"Cannot $operation ${fieldNames.quoted}, because " +
-              s"its parent is not a StructType. Found $o")
-          }
-        }
-        def checkColumnNotExists(
-            operation: String,
-            fieldNames: Array[String],
-            struct: StructType): Unit = {
-          if (struct.findNestedField(fieldNames, includeCollections = true).isDefined) {
-            alter.failAnalysis(s"Cannot $operation column, because ${fieldNames.quoted} " +
-              s"already exists in ${struct.treeString}")
-          }
-        }
-
-        val colsToDelete = mutable.Set.empty[Seq[String]]
-        // 'colsToAdd' keeps track of new columns being added. It stores a mapping from a parent
-        // name of fields to field names that belong to the parent. For example, if we add
-        // columns "a.b.c", "a.b.d", and "a.c", 'colsToAdd' will become
-        // Map(Seq("a", "b") -> Seq("c", "d"), Seq("a") -> Seq("c")).
-        val colsToAdd = mutable.Map.empty[Seq[String], Seq[String]]
-
-        alter.changes.foreach {
-          case add: AddColumn =>
-            // If a column to add is a part of columns to delete, we don't need to check
-            // if column already exists - applies to REPLACE COLUMNS scenario.
-            if (!colsToDelete.contains(add.fieldNames())) {
-              checkColumnNotExists("add", add.fieldNames(), table.schema)
-            }
-            val parent = findParentStruct("add", add.fieldNames())
-            val parentName = add.fieldNames().init
-            val fieldsAdded = colsToAdd.getOrElse(parentName, Nil)
-            positionArgumentExists(add.position(), parent, fieldsAdded)
-            TypeUtils.failWithIntervalType(add.dataType())
-            colsToAdd(parentName) = fieldsAdded :+ add.fieldNames().last
-          case delete: DeleteColumn =>
-            findField("delete", delete.fieldNames)
-            // REPLACE COLUMNS has deletes followed by adds. Remember the deleted columns
-            // so that add operations do not fail when the columns to add exist and they
-            // are to be deleted.
-            colsToDelete += delete.fieldNames
-          case _ =>
-          // no validation needed for set and remove property
-        }
+      case alter: AlterTableColumnCommand if alter.table.resolved =>
+        checkAlterTableColumnCommand(alter)

       case _ => // Falls back to the following checks
     }
@@ -1025,17 +938,23 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
   /**
    * Validates the options used for alter table commands after table and columns are resolved.
   */
-  private def checkAlterTableCommand(alter: AlterTableCommand): Unit = {
-    def checkColumnNotExists(fieldNames: Seq[String], struct: StructType): Unit = {
+  private def checkAlterTableColumnCommand(alter: AlterTableColumnCommand): Unit = {
+    def checkColumnNotExists(op: String, fieldNames: Seq[String], struct: StructType): Unit = {
       if (struct.findNestedField(fieldNames, includeCollections = true).isDefined) {
-        alter.failAnalysis(s"Cannot ${alter.operation} column, because ${fieldNames.quoted} " +
+        alter.failAnalysis(s"Cannot $op column, because ${fieldNames.quoted} " +
           s"already exists in ${struct.treeString}")
       }
     }

     alter match {
+      case AlterTableAddColumns(table: ResolvedTable, colsToAdd) =>
+        colsToAdd.foreach { colToAdd =>
+          checkColumnNotExists("add", colToAdd.name, table.schema)
+        }
+
       case AlterTableRenameColumn(table: ResolvedTable, col: ResolvedFieldName, newName) =>
-        checkColumnNotExists(col.path :+ newName, table.schema)
+        checkColumnNotExists("rename", col.path :+ newName, table.schema)
+
       case a @ AlterTableAlterColumn(table: ResolvedTable, col: ResolvedFieldName, _, _, _, _) =>
         val fieldName = col.name.quoted
         if (a.dataType.isDefined) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala
index c1e3644e26..1365cf6a13 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis

 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, LookupCatalog, TableChange}
+import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, LookupCatalog}

 /**
  * Resolves catalogs from the multi-part identifiers in SQL statements, and convert the statements
@@ -31,39 +31,6 @@ class ResolveCatalogs(val catalogManager: CatalogManager)
   import org.apache.spark.sql.connector.catalog.CatalogV2Util._

   override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-    case AlterTableAddColumnsStatement(
-        nameParts @ NonSessionCatalogAndTable(catalog, tbl), cols) =>
-      val changes = cols.map { col =>
-        TableChange.addColumn(
-          col.name.toArray,
-          col.dataType,
-          col.nullable,
-          col.comment.orNull,
-          col.position.orNull)
-      }
-      createAlterTable(nameParts, catalog, tbl, changes)
-
-    case AlterTableReplaceColumnsStatement(
-        nameParts @ NonSessionCatalogAndTable(catalog, tbl), cols) =>
-      val changes: Seq[TableChange] = loadTable(catalog, tbl.asIdentifier) match {
-        case Some(table) =>
-          // REPLACE COLUMNS deletes all the existing columns and adds new columns specified.
-          val deleteChanges = table.schema.fieldNames.map { name =>
-            TableChange.deleteColumn(Array(name))
-          }
-          val addChanges = cols.map { col =>
-            TableChange.addColumn(
-              col.name.toArray,
-              col.dataType,
-              col.nullable,
-              col.comment.orNull,
-              col.position.orNull)
-          }
-          deleteChanges ++ addChanges
-        case None => Seq()
-      }
-      createAlterTable(nameParts, catalog, tbl, changes)
-
     case c @ CreateTableStatement(
         NonSessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _, _, _) =>
       CreateV2Table(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index d213549ec4..a715686c8f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -3589,16 +3589,19 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
    */
   override def visitQualifiedColTypeWithPosition(
       ctx: QualifiedColTypeWithPositionContext): QualifiedColType = withOrigin(ctx) {
+    val name = typedVisit[Seq[String]](ctx.name)
     QualifiedColType(
-      name = typedVisit[Seq[String]](ctx.name),
+      path = if (name.length > 1) Some(UnresolvedFieldName(name.init)) else None,
+      colName = name.last,
       dataType = typedVisit[DataType](ctx.dataType),
       nullable = ctx.NULL == null,
       comment = Option(ctx.commentSpec()).map(visitCommentSpec),
-      position = Option(ctx.colPosition).map(typedVisit[ColumnPosition]))
+      position = Option(ctx.colPosition).map(pos =>
+        UnresolvedFieldPosition(typedVisit[ColumnPosition](pos))))
   }

   /**
-   * Parse a [[AlterTableAddColumnsStatement]] command.
+   * Parse a [[AlterTableAddColumns]] command.
    *
    * For example:
    * {{{
@@ -3607,8 +3610,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
    * }}}
    */
   override def visitAddTableColumns(ctx: AddTableColumnsContext): LogicalPlan = withOrigin(ctx) {
-    AlterTableAddColumnsStatement(
-      visitMultipartIdentifier(ctx.multipartIdentifier),
+    val colToken = if (ctx.COLUMN() != null) "COLUMN" else "COLUMNS"
+    AlterTableAddColumns(
+      createUnresolvedTable(ctx.multipartIdentifier, s"ALTER TABLE ... ADD $colToken"),
       ctx.columns.qualifiedColTypeWithPosition.asScala.map(typedVisit[QualifiedColType]).toSeq
     )
   }
@@ -3726,8 +3730,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
     if (ctx.partitionSpec != null) {
       operationNotAllowed("ALTER TABLE table PARTITION partition_spec REPLACE COLUMNS", ctx)
     }
-    AlterTableReplaceColumnsStatement(
-      visitMultipartIdentifier(ctx.multipartIdentifier),
+    AlterTableReplaceColumns(
REPLACE COLUMNS"), ctx.columns.qualifiedColTypeWithPosition.asScala.map { colType => if (colType.NULL != null) { throw QueryParsingErrors.operationInHiveStyleCommandUnsupportedError( @@ -3737,7 +3741,12 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg throw QueryParsingErrors.operationInHiveStyleCommandUnsupportedError( "Column position", "REPLACE COLUMNS", ctx) } - typedVisit[QualifiedColType](colType) + val col = typedVisit[QualifiedColType](colType) + if (col.path.isDefined) { + throw QueryParsingErrors.operationInHiveStyleCommandUnsupportedError( + "Replacing with a nested column", "REPLACE COLUMNS", ctx) + } + col }.toSeq ) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala index 8777b3caf8..378e6baca2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.analysis.ViewType +import org.apache.spark.sql.catalyst.analysis.{FieldName, FieldPosition, ViewType} import org.apache.spark.sql.catalyst.catalog.{BucketSpec, FunctionResource} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.trees.{LeafLike, UnaryLike} -import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types.{DataType, StructType} @@ -226,25 +225,19 @@ case class ReplaceTableAsSelectStatement( /** - * Column data as parsed by ALTER TABLE ... ADD COLUMNS. + * Column data as parsed by ALTER TABLE ... (ADD|REPLACE) COLUMNS. */ case class QualifiedColType( - name: Seq[String], + path: Option[FieldName], + colName: String, dataType: DataType, nullable: Boolean, comment: Option[String], - position: Option[ColumnPosition]) + position: Option[FieldPosition]) { + def name: Seq[String] = path.map(_.name).getOrElse(Nil) :+ colName -/** - * ALTER TABLE ... ADD COLUMNS command, as parsed from SQL. - */ -case class AlterTableAddColumnsStatement( - tableName: Seq[String], - columnsToAdd: Seq[QualifiedColType]) extends LeafParsedStatement - -case class AlterTableReplaceColumnsStatement( - tableName: Seq[String], - columnsToAdd: Seq[QualifiedColType]) extends LeafParsedStatement + def resolved: Boolean = path.forall(_.resolved) && position.forall(_.resolved) +} /** * An INSERT INTO statement, as parsed from SQL. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
index a82001a6e4..d2b59095e3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -22,9 +22,8 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, Unevaluable}
 import org.apache.spark.sql.catalyst.plans.DescribeCommandSchema
 import org.apache.spark.sql.catalyst.trees.BinaryLike
-import org.apache.spark.sql.catalyst.util.CharVarcharUtils
+import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, TypeUtils}
 import org.apache.spark.sql.connector.catalog._
-import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, ColumnChange}
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.connector.write.Write
 import org.apache.spark.sql.types.{BooleanType, DataType, MetadataBuilder, StringType, StructType}
@@ -545,38 +544,6 @@ case class NoopCommand(
     commandName: String,
     multipartIdentifier: Seq[String]) extends LeafCommand

-/**
- * The logical plan of the ALTER TABLE command.
- */
-case class AlterTable(
-    catalog: TableCatalog,
-    ident: Identifier,
-    table: NamedRelation,
-    changes: Seq[TableChange]) extends LeafCommand {
-
-  override lazy val resolved: Boolean = table.resolved && {
-    changes.forall {
-      case add: AddColumn =>
-        add.fieldNames match {
-          case Array(_) =>
-            // a top-level field can always be added
-            true
-          case _ =>
-            // the parent field must exist
-            table.schema.findNestedField(add.fieldNames.init, includeCollections = true).isDefined
-        }
-
-      case colChange: ColumnChange =>
-        // the column that will be changed must exist
-        table.schema.findNestedField(colChange.fieldNames, includeCollections = true).isDefined
-
-      case _ =>
-        // property changes require no resolution checks
-        true
-    }
-  }
-}
-
 /**
  * The logical plan of the ALTER [TABLE|VIEW] ... RENAME TO command.
  */
@@ -1112,21 +1079,84 @@ case class UnsetTableProperties(
     copy(table = newChild)
 }

-trait AlterTableCommand extends UnaryCommand {
+trait AlterTableColumnCommand extends UnaryCommand {
   def table: LogicalPlan
-  def operation: String
   def changes: Seq[TableChange]
   override def child: LogicalPlan = table
 }

+/**
+ * The logical plan of the ALTER TABLE ... ADD COLUMNS command.
+ */
+case class AlterTableAddColumns(
+    table: LogicalPlan,
+    columnsToAdd: Seq[QualifiedColType]) extends AlterTableColumnCommand {
+  columnsToAdd.foreach { c =>
+    TypeUtils.failWithIntervalType(c.dataType)
+  }
+
+  override lazy val resolved: Boolean = table.resolved && columnsToAdd.forall(_.resolved)
+
+  override def changes: Seq[TableChange] = {
+    columnsToAdd.map { col =>
+      require(col.path.forall(_.resolved),
+        "FieldName should be resolved before it's converted to TableChange.")
+      require(col.position.forall(_.resolved),
+        "FieldPosition should be resolved before it's converted to TableChange.")
+      TableChange.addColumn(
+        col.name.toArray,
+        col.dataType,
+        col.nullable,
+        col.comment.orNull,
+        col.position.map(_.position).orNull)
+    }
+  }
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
+    copy(table = newChild)
+}
+
+/**
+ * The logical plan of the ALTER TABLE ... REPLACE COLUMNS command.
+ */
+case class AlterTableReplaceColumns(
+    table: LogicalPlan,
+    columnsToAdd: Seq[QualifiedColType]) extends AlterTableColumnCommand {
+  columnsToAdd.foreach { c =>
+    TypeUtils.failWithIntervalType(c.dataType)
+  }
+
+  override lazy val resolved: Boolean = table.resolved && columnsToAdd.forall(_.resolved)
+
+  override def changes: Seq[TableChange] = {
+    // REPLACE COLUMNS deletes all the existing columns and adds new columns specified.
+    require(table.resolved)
+    val deleteChanges = table.schema.fieldNames.map { name =>
+      TableChange.deleteColumn(Array(name))
+    }
+    val addChanges = columnsToAdd.map { col =>
+      assert(col.path.isEmpty)
+      assert(col.position.isEmpty)
+      TableChange.addColumn(
+        col.name.toArray,
+        col.dataType,
+        col.nullable,
+        col.comment.orNull,
+        null)
+    }
+    deleteChanges ++ addChanges
+  }
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
+    copy(table = newChild)
+}
+
 /**
  * The logical plan of the ALTER TABLE ... DROP COLUMNS command.
  */
 case class AlterTableDropColumns(
     table: LogicalPlan,
-    columnsToDrop: Seq[FieldName]) extends AlterTableCommand {
-  override def operation: String = "delete"
-
+    columnsToDrop: Seq[FieldName]) extends AlterTableColumnCommand {
   override def changes: Seq[TableChange] = {
     columnsToDrop.map { col =>
       require(col.resolved, "FieldName should be resolved before it's converted to TableChange.")
@@ -1144,9 +1174,7 @@ case class AlterTableDropColumns(
 case class AlterTableRenameColumn(
     table: LogicalPlan,
     column: FieldName,
-    newName: String) extends AlterTableCommand {
-  override def operation: String = "rename"
-
+    newName: String) extends AlterTableColumnCommand {
   override def changes: Seq[TableChange] = {
     require(column.resolved, "FieldName should be resolved before it's converted to TableChange.")
     Seq(TableChange.renameColumn(column.name.toArray, newName))
@@ -1165,10 +1193,7 @@ case class AlterTableAlterColumn(
     dataType: Option[DataType],
     nullable: Option[Boolean],
     comment: Option[String],
-    position: Option[FieldPosition]) extends AlterTableCommand {
-
-  override def operation: String = "update"
-
+    position: Option[FieldPosition]) extends AlterTableColumnCommand {
   override def changes: Seq[TableChange] = {
     require(column.resolved, "FieldName should be resolved before it's converted to TableChange.")
     val colName = column.name.toArray
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala
index a9e87724fe..69625a121e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala
@@ -22,8 +22,8 @@ import java.util.Collections

 import scala.collection.JavaConverters._

-import org.apache.spark.sql.catalyst.analysis.{NamedRelation, NoSuchDatabaseException, NoSuchNamespaceException, NoSuchTableException, UnresolvedV2Relation}
-import org.apache.spark.sql.catalyst.plans.logical.{AlterTable, CreateTableAsSelectStatement, CreateTableStatement, ReplaceTableAsSelectStatement, ReplaceTableStatement, SerdeInfo}
+import org.apache.spark.sql.catalyst.analysis.{NamedRelation, NoSuchDatabaseException, NoSuchNamespaceException, NoSuchTableException}
+import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelectStatement, CreateTableStatement, ReplaceTableAsSelectStatement, ReplaceTableStatement, SerdeInfo}
 import org.apache.spark.sql.connector.catalog.TableChange._
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.types.{ArrayType, MapType, StructField, StructType}
@@ -356,17 +356,6 @@ private[sql] object CatalogV2Util {
     properties ++ Map(TableCatalog.PROP_OWNER -> Utils.getCurrentUserName())
   }

-  def createAlterTable(
-      originalNameParts: Seq[String],
-      catalog: CatalogPlugin,
-      tableName: Seq[String],
-      changes: Seq[TableChange]): AlterTable = {
-    val tableCatalog = catalog.asTableCatalog
-    val ident = tableName.asIdentifier
-    val unresolved = UnresolvedV2Relation(originalNameParts, tableCatalog, ident)
-    AlterTable(tableCatalog, ident, unresolved, changes)
-  }
-
   def getTableProviderCatalog(
       provider: SupportsCatalogOptions,
       catalogManager: CatalogManager,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index f31de476e9..1421fa39cf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -2352,12 +2352,12 @@ private[spark] object QueryCompilationErrors {
   }

   def missingFieldError(
-      fieldName: Seq[String], table: ResolvedTable, context: Expression): Throwable = {
+      fieldName: Seq[String], table: ResolvedTable, context: Origin): Throwable = {
     throw new AnalysisException(
       s"Missing field ${fieldName.quoted} in table ${table.name} with schema:\n" +
         table.schema.treeString,
-      context.origin.line,
-      context.origin.startPosition)
+      context.line,
+      context.startPosition)
   }

   def invalidFieldName(fieldName: Seq[String], path: Seq[String], context: Origin): Throwable = {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala
index f7e57e3b27..c869524b5d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala
@@ -17,9 +17,11 @@
 package org.apache.spark.sql.catalyst.analysis

+import java.util
+
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, LeafNode}
-import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog, TableCatalog}
+import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog, Table, TableCapability, TableCatalog}
 import org.apache.spark.sql.connector.expressions.Expressions
 import org.apache.spark.sql.types.{DoubleType, LongType, StringType, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -151,3 +153,8 @@ private[sql] case object TestRelation2 extends LeafNode with NamedRelation {
     CreateTablePartitioningValidationSuite.schema.toAttributes
 }

+private[sql] case object TestTable2 extends Table {
+  override def name: String = "table_name"
+  override def schema: StructType = CreateTablePartitioningValidationSuite.schema
+  override def capabilities: util.Set[TableCapability] = new util.HashSet[TableCapability]()
+}
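As a quick illustration of the REPLACE COLUMNS semantics implemented by `AlterTableReplaceColumns.changes` above (a hedged sketch, not patch code; it assumes a resolved table whose schema is `a int, b string`):

```scala
// Sketch: the TableChanges emitted when (a int, b string) is replaced by (x string).
// Every existing top-level column is deleted first, then the new columns are added.
import org.apache.spark.sql.connector.catalog.TableChange
import org.apache.spark.sql.types.StringType

val existingFields = Seq("a", "b") // taken from the resolved table's schema
val deleteChanges = existingFields.map(name => TableChange.deleteColumn(Array(name)))
val addChanges = Seq(TableChange.addColumn(Array("x"), StringType, true, null, null))
val changes = deleteChanges ++ addChanges
// => deleteColumn(a), deleteColumn(b), addColumn(x)
```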
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
index bca2680b52..ea35f8b13d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
@@ -788,88 +788,125 @@ class DDLParserSuite extends AnalysisTest {
   test("alter table: add column") {
     comparePlans(
       parsePlan("ALTER TABLE table_name ADD COLUMN x int"),
-      AlterTableAddColumnsStatement(Seq("table_name"), Seq(
-        QualifiedColType(Seq("x"), IntegerType, true, None, None)
+      AlterTableAddColumns(
+        UnresolvedTable(Seq("table_name"), "ALTER TABLE ... ADD COLUMN", None),
+        Seq(QualifiedColType(None, "x", IntegerType, true, None, None)
       )))
   }

   test("alter table: add multiple columns") {
     comparePlans(
       parsePlan("ALTER TABLE table_name ADD COLUMNS x int, y string"),
-      AlterTableAddColumnsStatement(Seq("table_name"), Seq(
-        QualifiedColType(Seq("x"), IntegerType, true, None, None),
-        QualifiedColType(Seq("y"), StringType, true, None, None)
+      AlterTableAddColumns(
+        UnresolvedTable(Seq("table_name"), "ALTER TABLE ... ADD COLUMNS", None),
+        Seq(QualifiedColType(None, "x", IntegerType, true, None, None),
+          QualifiedColType(None, "y", StringType, true, None, None)
       )))
   }

   test("alter table: add column with COLUMNS") {
     comparePlans(
       parsePlan("ALTER TABLE table_name ADD COLUMNS x int"),
-      AlterTableAddColumnsStatement(Seq("table_name"), Seq(
-        QualifiedColType(Seq("x"), IntegerType, true, None, None)
+      AlterTableAddColumns(
+        UnresolvedTable(Seq("table_name"), "ALTER TABLE ... ADD COLUMNS", None),
+        Seq(QualifiedColType(None, "x", IntegerType, true, None, None)
       )))
   }

   test("alter table: add column with COLUMNS (...)") {
     comparePlans(
       parsePlan("ALTER TABLE table_name ADD COLUMNS (x int)"),
-      AlterTableAddColumnsStatement(Seq("table_name"), Seq(
-        QualifiedColType(Seq("x"), IntegerType, true, None, None)
+      AlterTableAddColumns(
+        UnresolvedTable(Seq("table_name"), "ALTER TABLE ... ADD COLUMNS", None),
+        Seq(QualifiedColType(None, "x", IntegerType, true, None, None)
       )))
   }

   test("alter table: add column with COLUMNS (...) and COMMENT") {
     comparePlans(
       parsePlan("ALTER TABLE table_name ADD COLUMNS (x int COMMENT 'doc')"),
-      AlterTableAddColumnsStatement(Seq("table_name"), Seq(
-        QualifiedColType(Seq("x"), IntegerType, true, Some("doc"), None)
+      AlterTableAddColumns(
+        UnresolvedTable(Seq("table_name"), "ALTER TABLE ... ADD COLUMNS", None),
+        Seq(QualifiedColType(None, "x", IntegerType, true, Some("doc"), None)
       )))
   }

   test("alter table: add non-nullable column") {
     comparePlans(
       parsePlan("ALTER TABLE table_name ADD COLUMN x int NOT NULL"),
-      AlterTableAddColumnsStatement(Seq("table_name"), Seq(
-        QualifiedColType(Seq("x"), IntegerType, false, None, None)
+      AlterTableAddColumns(
+        UnresolvedTable(Seq("table_name"), "ALTER TABLE ... ADD COLUMN", None),
+        Seq(QualifiedColType(None, "x", IntegerType, false, None, None)
       )))
   }

   test("alter table: add column with COMMENT") {
     comparePlans(
       parsePlan("ALTER TABLE table_name ADD COLUMN x int COMMENT 'doc'"),
-      AlterTableAddColumnsStatement(Seq("table_name"), Seq(
-        QualifiedColType(Seq("x"), IntegerType, true, Some("doc"), None)
+      AlterTableAddColumns(
ADD COLUMN", None), + Seq(QualifiedColType(None, "x", IntegerType, true, Some("doc"), None) ))) } test("alter table: add column with position") { comparePlans( parsePlan("ALTER TABLE table_name ADD COLUMN x int FIRST"), - AlterTableAddColumnsStatement(Seq("table_name"), Seq( - QualifiedColType(Seq("x"), IntegerType, true, None, Some(first())) + AlterTableAddColumns( + UnresolvedTable(Seq("table_name"), "ALTER TABLE ... ADD COLUMN", None), + Seq(QualifiedColType( + None, + "x", + IntegerType, + true, + None, + Some(UnresolvedFieldPosition(first()))) ))) comparePlans( parsePlan("ALTER TABLE table_name ADD COLUMN x int AFTER y"), - AlterTableAddColumnsStatement(Seq("table_name"), Seq( - QualifiedColType(Seq("x"), IntegerType, true, None, Some(after("y"))) + AlterTableAddColumns( + UnresolvedTable(Seq("table_name"), "ALTER TABLE ... ADD COLUMN", None), + Seq(QualifiedColType( + None, + "x", + IntegerType, + true, + None, + Some(UnresolvedFieldPosition(after("y")))) ))) } test("alter table: add column with nested column name") { comparePlans( parsePlan("ALTER TABLE table_name ADD COLUMN x.y.z int COMMENT 'doc'"), - AlterTableAddColumnsStatement(Seq("table_name"), Seq( - QualifiedColType(Seq("x", "y", "z"), IntegerType, true, Some("doc"), None) + AlterTableAddColumns( + UnresolvedTable(Seq("table_name"), "ALTER TABLE ... ADD COLUMN", None), + Seq(QualifiedColType( + Some(UnresolvedFieldName(Seq("x", "y"))), "z", IntegerType, true, Some("doc"), None) ))) } test("alter table: add multiple columns with nested column name") { comparePlans( parsePlan("ALTER TABLE table_name ADD COLUMN x.y.z int COMMENT 'doc', a.b string FIRST"), - AlterTableAddColumnsStatement(Seq("table_name"), Seq( - QualifiedColType(Seq("x", "y", "z"), IntegerType, true, Some("doc"), None), - QualifiedColType(Seq("a", "b"), StringType, true, None, Some(first())) + AlterTableAddColumns( + UnresolvedTable(Seq("table_name"), "ALTER TABLE ... ADD COLUMN", None), + Seq( + QualifiedColType( + Some(UnresolvedFieldName(Seq("x", "y"))), + "z", + IntegerType, + true, + Some("doc"), + None), + QualifiedColType( + Some(UnresolvedFieldName(Seq("a"))), + "b", + StringType, + true, + None, + Some(UnresolvedFieldPosition(first()))) ))) } @@ -1062,32 +1099,32 @@ class DDLParserSuite extends AnalysisTest { comparePlans( parsePlan(sql1), - AlterTableReplaceColumnsStatement( - Seq("table_name"), - Seq(QualifiedColType(Seq("x"), StringType, true, None, None)))) + AlterTableReplaceColumns( + UnresolvedTable(Seq("table_name"), "ALTER TABLE ... REPLACE COLUMNS", None), + Seq(QualifiedColType(None, "x", StringType, true, None, None)))) comparePlans( parsePlan(sql2), - AlterTableReplaceColumnsStatement( - Seq("table_name"), - Seq(QualifiedColType(Seq("x"), StringType, true, Some("x1"), None)))) + AlterTableReplaceColumns( + UnresolvedTable(Seq("table_name"), "ALTER TABLE ... REPLACE COLUMNS", None), + Seq(QualifiedColType(None, "x", StringType, true, Some("x1"), None)))) comparePlans( parsePlan(sql3), - AlterTableReplaceColumnsStatement( - Seq("table_name"), + AlterTableReplaceColumns( + UnresolvedTable(Seq("table_name"), "ALTER TABLE ... 
REPLACE COLUMNS", None), Seq( - QualifiedColType(Seq("x"), StringType, true, Some("x1"), None), - QualifiedColType(Seq("y"), IntegerType, true, None, None) + QualifiedColType(None, "x", StringType, true, Some("x1"), None), + QualifiedColType(None, "y", IntegerType, true, None, None) ))) comparePlans( parsePlan(sql4), - AlterTableReplaceColumnsStatement( - Seq("table_name"), + AlterTableReplaceColumns( + UnresolvedTable(Seq("table_name"), "ALTER TABLE ... REPLACE COLUMNS", None), Seq( - QualifiedColType(Seq("x"), StringType, true, Some("x1"), None), - QualifiedColType(Seq("y"), IntegerType, true, Some("y1"), None) + QualifiedColType(None, "x", StringType, true, Some("x1"), None), + QualifiedColType(None, "y", IntegerType, true, Some("y1"), None) ))) intercept("ALTER TABLE table_name PARTITION (a='1') REPLACE COLUMNS (x string)", @@ -1098,6 +1135,9 @@ class DDLParserSuite extends AnalysisTest { intercept("ALTER TABLE table_name REPLACE COLUMNS (x string FIRST)", "Column position is not supported in Hive-style REPLACE COLUMNS") + + intercept("ALTER TABLE table_name REPLACE COLUMNS (a.b.c string)", + "Replacing with a nested column is not supported in Hive-style REPLACE COLUMNS") } test("alter view: rename view") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index b9c82c39bb..3a2f525162 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, toPrettySQL} -import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, CatalogV2Util, Identifier, LookupCatalog, SupportsNamespaces, TableChange, V1Table} +import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, CatalogV2Util, Identifier, LookupCatalog, SupportsNamespaces, V1Table} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.command._ @@ -46,51 +46,17 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { - case AlterTableAddColumnsStatement( - nameParts @ SessionCatalogAndTable(catalog, tbl), cols) => - loadTable(catalog, tbl.asIdentifier).collect { - case v1Table: V1Table => - cols.foreach { c => - assertTopLevelColumn(c.name, "AlterTableAddColumnsCommand") - if (!c.nullable) { - throw QueryCompilationErrors.addColumnWithV1TableCannotSpecifyNotNullError - } - } - AlterTableAddColumnsCommand(tbl.asTableIdentifier, cols.map(convertToStructField)) - }.getOrElse { - val changes = cols.map { col => - TableChange.addColumn( - col.name.toArray, - col.dataType, - col.nullable, - col.comment.orNull, - col.position.orNull) + case AlterTableAddColumns(ResolvedV1TableIdentifier(ident), cols) => + cols.foreach { c => + assertTopLevelColumn(c.name, "AlterTableAddColumnsCommand") + if (!c.nullable) { + throw QueryCompilationErrors.addColumnWithV1TableCannotSpecifyNotNullError } - createAlterTable(nameParts, catalog, 
-        createAlterTable(nameParts, catalog, tbl, changes)
       }
+      AlterTableAddColumnsCommand(ident.asTableIdentifier, cols.map(convertToStructField))

-    case AlterTableReplaceColumnsStatement(
-        nameParts @ SessionCatalogAndTable(catalog, tbl), cols) =>
-      val changes: Seq[TableChange] = loadTable(catalog, tbl.asIdentifier) match {
-        case Some(_: V1Table) =>
-          throw QueryCompilationErrors.replaceColumnsOnlySupportedWithV2TableError
-        case Some(table) =>
-          // REPLACE COLUMNS deletes all the existing columns and adds new columns specified.
-          val deleteChanges = table.schema.fieldNames.map { name =>
-            TableChange.deleteColumn(Array(name))
-          }
-          val addChanges = cols.map { col =>
-            TableChange.addColumn(
-              col.name.toArray,
-              col.dataType,
-              col.nullable,
-              col.comment.orNull,
-              col.position.orNull)
-          }
-          deleteChanges ++ addChanges
-        case None => Seq() // Unresolved table will be handled in CheckAnalysis.
-      }
-      createAlterTable(nameParts, catalog, tbl, changes)
+    case AlterTableReplaceColumns(ResolvedV1TableIdentifier(_), _) =>
+      throw QueryCompilationErrors.replaceColumnsOnlySupportedWithV2TableError

     case a @ AlterTableAlterColumn(ResolvedV1TableAndIdentifier(table, ident), _, _, _, _, _) =>
       if (a.column.name.length > 1) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index 4d77674f12..3d69029e94 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -288,9 +288,6 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
     case _: NoopCommand =>
       LocalTableScanExec(Nil, Nil) :: Nil

-    case AlterTable(catalog, ident, _, changes) =>
-      AlterTableExec(catalog, ident, changes) :: Nil
-
     case RenameTable(r @ ResolvedTable(catalog, oldIdent, _, _), newIdent, isView) =>
       if (isView) {
         throw QueryCompilationErrors.cannotRenameTableWithAlterViewError()
@@ -445,7 +442,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
       val changes = keys.map(key => TableChange.removeProperty(key))
       AlterTableExec(table.catalog, table.identifier, changes) :: Nil

-    case a: AlterTableCommand if a.table.resolved =>
+    case a: AlterTableColumnCommand if a.table.resolved =>
       val table = a.table.asInstanceOf[ResolvedTable]
       AlterTableExec(table.catalog, table.identifier, a.changes) :: Nil

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala
index 6c6a155780..004a64ac69 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala
@@ -356,8 +356,7 @@ trait AlterTableTests extends SharedSparkSession {
         sql(s"ALTER TABLE $t ADD COLUMN point.z double")
       }

-      assert(exc.getMessage.contains("point"))
-      assert(exc.getMessage.contains("missing field"))
+      assert(exc.getMessage.contains("Missing field point"))
     }
   }
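The case-sensitivity suite below exercises the new resolution path. Roughly, position resolution uses the session resolver, so `AFTER ID` normalizes to the real field name when `spark.sql.caseSensitive=false` and fails analysis when it is `true` (a simplified sketch of the lookup, not the actual rule code):

```scala
// Sketch: AFTER-position lookup as performed by ResolveAlterTableColumnCommands,
// reduced to the resolver-driven find over the parent's field names.
import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition

val fieldNames = Seq("id", "data", "point")
def resolveAfter(ref: String, caseSensitive: Boolean): Option[ColumnPosition] = {
  val resolver: (String, String) => Boolean =
    if (caseSensitive) _ == _ else (a, b) => a.equalsIgnoreCase(b)
  fieldNames.find(n => resolver(n, ref)).map(n => ColumnPosition.after(n))
}

assert(resolveAfter("ID", caseSensitive = false).isDefined) // normalized to AFTER id
assert(resolveAfter("ID", caseSensitive = true).isEmpty)    // analysis error in the rule
```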
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala
index 722f3ea138..6651576150 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala
@@ -17,10 +17,10 @@
 package org.apache.spark.sql.connector

-import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, CreateTablePartitioningValidationSuite, ResolvedTable, TestRelation2, UnresolvedFieldName}
-import org.apache.spark.sql.catalyst.plans.logical.{AlterTable, AlterTableAlterColumn, AlterTableCommand, AlterTableDropColumns, AlterTableRenameColumn, CreateTableAsSelect, LogicalPlan, ReplaceTableAsSelect}
+import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, CreateTablePartitioningValidationSuite, ResolvedTable, TestRelation2, TestTable2, UnresolvedFieldName, UnresolvedFieldPosition}
+import org.apache.spark.sql.catalyst.plans.logical.{AlterTableAddColumns, AlterTableAlterColumn, AlterTableColumnCommand, AlterTableDropColumns, AlterTableRenameColumn, CreateTableAsSelect, LogicalPlan, QualifiedColType, ReplaceTableAsSelect}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.connector.catalog.{Identifier, TableChange}
+import org.apache.spark.sql.connector.catalog.Identifier
 import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition
 import org.apache.spark.sql.connector.expressions.Expressions
 import org.apache.spark.sql.execution.datasources.PreprocessTableCreation
@@ -35,7 +35,7 @@ class V2CommandsCaseSensitivitySuite extends SharedSparkSession with AnalysisTes
   private val table = ResolvedTable(
     catalog,
     Identifier.of(Array(), "table_name"),
-    null,
+    TestTable2,
     schema.toAttributes)

   override protected def extendedAnalysisRules: Seq[Rule[LogicalPlan]] = {
@@ -140,8 +140,11 @@ class V2CommandsCaseSensitivitySuite extends SharedSparkSession with AnalysisTes
     Seq("POINT.Z", "poInt.z", "poInt.Z").foreach { ref =>
       val field = ref.split("\\.")
       alterTableTest(
-        TableChange.addColumn(field, LongType),
-        Seq("add", field.head)
+        AlterTableAddColumns(
+          table,
+          Seq(QualifiedColType(
+            Some(UnresolvedFieldName(field.init)), field.last, LongType, true, None, None))),
+        Seq("Missing field " + field.head)
       )
     }
   }

@@ -149,8 +152,15 @@
   test("AlterTable: add column resolution - positional") {
     Seq("ID", "iD").foreach { ref =>
       alterTableTest(
-        TableChange.addColumn(
-          Array("f"), LongType, true, null, ColumnPosition.after(ref)),
+        AlterTableAddColumns(
+          table,
+          Seq(QualifiedColType(
+            None,
+            "f",
+            LongType,
+            true,
+            None,
+            Some(UnresolvedFieldPosition(ColumnPosition.after(ref)))))),
         Seq("reference column", ref)
       )
     }
   }

@@ -158,11 +168,22 @@
   test("AlterTable: add column resolution - column position referencing new column") {
     alterTableTest(
-      Seq(
-        TableChange.addColumn(
-          Array("x"), LongType, true, null, ColumnPosition.after("id")),
-        TableChange.addColumn(
-          Array("y"), LongType, true, null, ColumnPosition.after("X"))),
+      AlterTableAddColumns(
+        table,
+        Seq(QualifiedColType(
+          None,
+          "x",
+          LongType,
+          true,
+          None,
+          Some(UnresolvedFieldPosition(ColumnPosition.after("id")))),
+          QualifiedColType(
+            None,
+            "y",
+            LongType,
+            true,
+            None,
+            Some(UnresolvedFieldPosition(ColumnPosition.after("X")))))),
       Seq("Couldn't find the reference column for AFTER X at root")
     )
   }

@@ -170,8 +191,15 @@
   test("AlterTable: add column resolution - nested positional") {
     Seq("X", "Y").foreach { ref =>
       alterTableTest(
-        TableChange.addColumn(
-          Array("point", "z"), LongType, true, null, ColumnPosition.after(ref)),
+        AlterTableAddColumns(
+          table,
+          Seq(QualifiedColType(
+            Some(UnresolvedFieldName(Seq("point"))),
+            "z",
+            LongType,
+            true,
+            None,
+            Some(UnresolvedFieldPosition(ColumnPosition.after(ref)))))),
         Seq("reference column", ref)
       )
     }
   }

@@ -179,11 +207,22 @@
   test("AlterTable: add column resolution - column position referencing new nested column") {
     alterTableTest(
-      Seq(
-        TableChange.addColumn(
-          Array("point", "z"), LongType, true, null),
-        TableChange.addColumn(
-          Array("point", "zz"), LongType, true, null, ColumnPosition.after("Z"))),
+      AlterTableAddColumns(
+        table,
+        Seq(QualifiedColType(
+          Some(UnresolvedFieldName(Seq("point"))),
+          "z",
+          LongType,
+          true,
+          None,
+          None),
+          QualifiedColType(
+            Some(UnresolvedFieldName(Seq("point"))),
+            "zz",
+            LongType,
+            true,
+            None,
+            Some(UnresolvedFieldPosition(ColumnPosition.after("Z")))))),
       Seq("Couldn't find the reference column for AFTER Z at point")
     )
   }

@@ -233,30 +272,7 @@
     }
   }

-  private def alterTableTest(change: TableChange, error: Seq[String]): Unit = {
-    alterTableTest(Seq(change), error)
-  }
-
-  private def alterTableTest(changes: Seq[TableChange], error: Seq[String]): Unit = {
-    Seq(true, false).foreach { caseSensitive =>
-      withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) {
-        val plan = AlterTable(
-          catalog,
-          Identifier.of(Array(), "table_name"),
-          TestRelation2,
-          changes
-        )
-
-        if (caseSensitive) {
-          assertAnalysisError(plan, error, caseSensitive)
-        } else {
-          assertAnalysisSuccess(plan, caseSensitive)
-        }
-      }
-    }
-  }
-
-  private def alterTableTest(alter: AlterTableCommand, error: Seq[String]): Unit = {
+  private def alterTableTest(alter: AlterTableColumnCommand, error: Seq[String]): Unit = {
     Seq(true, false).foreach { caseSensitive =>
       withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) {
         if (caseSensitive) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
index 8842ba07b8..ba87883b64 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
@@ -2305,7 +2305,8 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
       val e = intercept[AnalysisException] {
         sql("ALTER TABLE tmp_v ADD COLUMNS (c3 INT)")
       }
-      assert(e.message.contains("'tmp_v' is a view not a table"))
+      assert(e.message.contains(
+        "tmp_v is a temp view. 'ALTER TABLE ... ADD COLUMNS' expects a table."))
     }
   }

@@ -2315,7 +2316,8 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
      val e = intercept[AnalysisException] {
        sql("ALTER TABLE v1 ADD COLUMNS (c3 INT)")
      }
-      assert(e.message.contains("ALTER ADD COLUMNS does not support views"))
+      assert(e.message.contains(
+        "default.v1 is a view. 'ALTER TABLE ... ADD COLUMNS' expects a table."))
     }
   }
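The user-facing effect of the migration is the uniform `UnresolvedTable` error reporting checked in the updated `DDLSuite`; a minimal spark-shell reproduction (hypothetical session, mirroring the test above):

```scala
// Sketch: consistent temp-view-first resolution after this patch.
spark.range(3).createOrReplaceTempView("tmp_v")
// Now fails analysis with:
//   "tmp_v is a temp view. 'ALTER TABLE ... ADD COLUMNS' expects a table."
spark.sql("ALTER TABLE tmp_v ADD COLUMNS (c3 INT)")
```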