[SPARK-36006][SQL] Migrate ALTER TABLE ... ADD/REPLACE COLUMNS commands to use UnresolvedTable to resolve the identifier

### What changes were proposed in this pull request?

This PR proposes to migrate the `ALTER TABLE ... ADD COLUMNS` and `ALTER TABLE ... REPLACE COLUMNS` commands to use `UnresolvedTable` as a `child` to resolve the table identifier. This allows consistent resolution rules (temp view first, etc.) to be applied to both the v1 and v2 commands. More info about the consistent resolution rule proposal can be found in [JIRA](https://issues.apache.org/jira/browse/SPARK-29900) or the [proposal doc](https://docs.google.com/document/d/1hvLjGA8y_W_hhilpngXVub1Ebv8RsMap986nENCFnrg/edit?usp=sharing).
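For example, `ALTER TABLE t ADD COLUMN x int` now parses into an `AlterTableAddColumns` node whose child is an `UnresolvedTable`, and the analyzer resolves that child with the common lookup rules. A minimal sketch mirroring the updated `DDLParserSuite` expectations below (the plain `==` comparison stands in for the suite's `comparePlans` helper):

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedTable
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.logical.{AlterTableAddColumns, QualifiedColType}
import org.apache.spark.sql.types.IntegerType

// The table identifier stays unresolved at parse time; resolution happens in the
// analyzer using the same rules as other v1/v2 commands (temp view first, etc.).
val expected = AlterTableAddColumns(
  UnresolvedTable(Seq("t"), "ALTER TABLE ... ADD COLUMN", None),
  Seq(QualifiedColType(None, "x", IntegerType, true, None, None)))

assert(CatalystSqlParser.parsePlan("ALTER TABLE t ADD COLUMN x int") == expected)
```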

### Why are the changes needed?

This is a part of effort to make the relation lookup behavior consistent: [SPARK-29900](https://issues.apache.org/jira/browse/SPARK-29900).

### Does this PR introduce _any_ user-facing change?

Yes. After this PR, the `ALTER TABLE ... ADD/REPLACE COLUMNS` commands resolve their table identifier consistently with other commands: temp views are resolved first, and views are rejected with a uniform error message.
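For example, a temporary view that shadows the table name is now resolved first and rejected with the standard error message, matching the updated `DDLSuite` assertions below. A sketch of a `spark-shell` session:

```scala
// Hypothetical session; the error text is the one asserted in DDLSuite.
spark.sql("CREATE TEMPORARY VIEW tmp_v AS SELECT 1 AS c1")
spark.sql("ALTER TABLE tmp_v ADD COLUMNS (c3 INT)")
// org.apache.spark.sql.AnalysisException:
//   tmp_v is a temp view. 'ALTER TABLE ... ADD COLUMNS' expects a table.
```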

### How was this patch tested?

Updated existing tests.

Closes #33200 from imback82/alter_add_cols.

Authored-by: Terry Kim <yuminkim@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Terry Kim authored on 2021-07-28 14:00:29 +08:00; committed by Wenchen Fan
parent c9a7ff3f36
commit 809b88a162
15 changed files with 335 additions and 453 deletions


@@ -44,7 +44,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils}
 import org.apache.spark.sql.connector.catalog._
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
-import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, After, ColumnChange, ColumnPosition, DeleteColumn}
+import org.apache.spark.sql.connector.catalog.TableChange.{After, ColumnPosition}
 import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction, BoundFunction, ScalarFunction}
 import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME
 import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform}
@@ -269,7 +269,7 @@ class Analyzer(override val catalogManager: CatalogManager)
       ResolveRelations ::
       ResolveTables ::
       ResolvePartitionSpec ::
-      ResolveAlterTableCommands ::
+      ResolveAlterTableColumnCommands ::
       AddMetadataColumns ::
       DeduplicateRelations ::
       ResolveReferences ::
@@ -312,7 +312,6 @@ class Analyzer(override val catalogManager: CatalogManager)
     Batch("Post-Hoc Resolution", Once,
       Seq(ResolveCommandsWithIfExists) ++
      postHocResolutionRules: _*),
-    Batch("Normalize Alter Table", Once, ResolveAlterTableChanges),
     Batch("Remove Unresolved Hints", Once,
       new ResolveHints.RemoveAllHints),
     Batch("Nondeterministic", Once,
@@ -1082,11 +1081,6 @@ class Analyzer(override val catalogManager: CatalogManager)
         case _ => write
       }

-      case alter @ AlterTable(_, _, u: UnresolvedV2Relation, _) =>
-        CatalogV2Util.loadRelation(u.catalog, u.tableName)
-          .map(rel => alter.copy(table = rel))
-          .getOrElse(alter)
-
       case u: UnresolvedV2Relation =>
         CatalogV2Util.loadRelation(u.catalog, u.tableName).getOrElse(u)
     }
@@ -3611,16 +3605,69 @@ class Analyzer(override val catalogManager: CatalogManager)
   /**
    * Rule to mostly resolve, normalize and rewrite column names based on case sensitivity
-   * for alter table commands.
+   * for alter table column commands.
    */
-  object ResolveAlterTableCommands extends Rule[LogicalPlan] {
+  object ResolveAlterTableColumnCommands extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
-      case a: AlterTableCommand if a.table.resolved && hasUnresolvedFieldName(a) =>
+      case a: AlterTableColumnCommand if a.table.resolved && hasUnresolvedFieldName(a) =>
         val table = a.table.asInstanceOf[ResolvedTable]
         a.transformExpressions {
           case u: UnresolvedFieldName => resolveFieldNames(table, u.name, u)
         }

+      case a @ AlterTableAddColumns(r: ResolvedTable, cols) if !a.resolved =>
+        // 'colsToAdd' keeps track of new columns being added. It stores a mapping from a
+        // normalized parent name of fields to field names that belong to the parent.
+        // For example, if we add columns "a.b.c", "a.b.d", and "a.c", 'colsToAdd' will become
+        // Map(Seq("a", "b") -> Seq("c", "d"), Seq("a") -> Seq("c")).
+        val colsToAdd = mutable.Map.empty[Seq[String], Seq[String]]
+        def resolvePosition(
+            col: QualifiedColType,
+            parentSchema: StructType,
+            resolvedParentName: Seq[String]): Option[FieldPosition] = {
+          val fieldsAdded = colsToAdd.getOrElse(resolvedParentName, Nil)
+          val resolvedPosition = col.position.map {
+            case u: UnresolvedFieldPosition => u.position match {
+              case after: After =>
+                val allFields = parentSchema.fieldNames ++ fieldsAdded
+                allFields.find(n => conf.resolver(n, after.column())) match {
+                  case Some(colName) =>
+                    ResolvedFieldPosition(ColumnPosition.after(colName))
+                  case None =>
+                    val name = if (resolvedParentName.isEmpty) "root" else resolvedParentName.quoted
+                    throw QueryCompilationErrors.referenceColNotFoundForAlterTableChangesError(
+                      after, name)
+                }
+              case _ => ResolvedFieldPosition(u.position)
+            }
+            case resolved => resolved
+          }
+          colsToAdd(resolvedParentName) = fieldsAdded :+ col.colName
+          resolvedPosition
+        }
+        val schema = r.table.schema
+        val resolvedCols = cols.map { col =>
+          col.path match {
+            case Some(parent: UnresolvedFieldName) =>
+              // Adding a nested field, need to resolve the parent column and position.
+              val resolvedParent = resolveFieldNames(r, parent.name, parent)
+              val parentSchema = resolvedParent.field.dataType match {
+                case s: StructType => s
+                case _ => throw QueryCompilationErrors.invalidFieldName(
+                  col.name, parent.name, parent.origin)
+              }
+              val resolvedPosition = resolvePosition(col, parentSchema, resolvedParent.name)
+              col.copy(path = Some(resolvedParent), position = resolvedPosition)
+            case _ =>
+              // Adding to the root. Just need to resolve position.
+              val resolvedPosition = resolvePosition(col, schema, Nil)
+              col.copy(position = resolvedPosition)
+          }
+        }
+        val resolved = a.copy(columnsToAdd = resolvedCols)
+        resolved.copyTagsFrom(a)
+        resolved
+
       case a @ AlterTableAlterColumn(
           table: ResolvedTable, ResolvedFieldName(path, field), dataType, _, _, position) =>
         val newDataType = dataType.flatMap { dt =>
@@ -3655,108 +3702,14 @@ class Analyzer(override val catalogManager: CatalogManager)
         fieldName, includeCollections = true, conf.resolver, context.origin
       ).map {
         case (path, field) => ResolvedFieldName(path, field)
-      }.getOrElse(throw QueryCompilationErrors.missingFieldError(fieldName, table, context))
+      }.getOrElse(throw QueryCompilationErrors.missingFieldError(fieldName, table, context.origin))
     }

-    private def hasUnresolvedFieldName(a: AlterTableCommand): Boolean = {
+    private def hasUnresolvedFieldName(a: AlterTableColumnCommand): Boolean = {
       a.expressions.exists(_.find(_.isInstanceOf[UnresolvedFieldName]).isDefined)
     }
   }

-  /** Rule to mostly resolve, normalize and rewrite column names based on case sensitivity. */
-  object ResolveAlterTableChanges extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
-      case a @ AlterTable(_, _, t: NamedRelation, changes) if t.resolved =>
-        // 'colsToAdd' keeps track of new columns being added. It stores a mapping from a
-        // normalized parent name of fields to field names that belong to the parent.
-        // For example, if we add columns "a.b.c", "a.b.d", and "a.c", 'colsToAdd' will become
-        // Map(Seq("a", "b") -> Seq("c", "d"), Seq("a") -> Seq("c")).
-        val colsToAdd = mutable.Map.empty[Seq[String], Seq[String]]
-        val schema = t.schema
-        val normalizedChanges = changes.flatMap {
-          case add: AddColumn =>
-            def addColumn(
-                parentSchema: StructType,
-                parentName: String,
-                normalizedParentName: Seq[String]): TableChange = {
-              val fieldsAdded = colsToAdd.getOrElse(normalizedParentName, Nil)
-              val pos = findColumnPosition(add.position(), parentName, parentSchema, fieldsAdded)
-              val field = add.fieldNames().last
-              colsToAdd(normalizedParentName) = fieldsAdded :+ field
-              TableChange.addColumn(
-                (normalizedParentName :+ field).toArray,
-                add.dataType(),
-                add.isNullable,
-                add.comment,
-                pos)
-            }
-            val parent = add.fieldNames().init
-            if (parent.nonEmpty) {
-              // Adding a nested field, need to normalize the parent column and position
-              val target = schema.findNestedField(parent, includeCollections = true, conf.resolver)
-              if (target.isEmpty) {
-                // Leave unresolved. Throws error in CheckAnalysis
-                Some(add)
-              } else {
-                val (normalizedName, sf) = target.get
-                sf.dataType match {
-                  case struct: StructType =>
-                    Some(addColumn(struct, parent.quoted, normalizedName :+ sf.name))
-                  case other =>
-                    Some(add)
-                }
-              }
-            } else {
-              // Adding to the root. Just need to normalize position
-              Some(addColumn(schema, "root", Nil))
-            }
-          case delete: DeleteColumn =>
-            resolveFieldNames(schema, delete.fieldNames(), TableChange.deleteColumn)
-              .orElse(Some(delete))
-          case column: ColumnChange =>
-            // This is informational for future developers
-            throw QueryExecutionErrors.columnChangeUnsupportedError
-          case other => Some(other)
-        }
-        a.copy(changes = normalizedChanges)
-    }
-
-    /**
-     * Returns the table change if the field can be resolved, returns None if the column is not
-     * found. An error will be thrown in CheckAnalysis for columns that can't be resolved.
-     */
-    private def resolveFieldNames(
-        schema: StructType,
-        fieldNames: Array[String],
-        copy: Array[String] => TableChange): Option[TableChange] = {
-      val fieldOpt = schema.findNestedField(
-        fieldNames, includeCollections = true, conf.resolver)
-      fieldOpt.map { case (path, field) => copy((path :+ field.name).toArray) }
-    }
-
-    private def findColumnPosition(
-        position: ColumnPosition,
-        parentName: String,
-        struct: StructType,
-        fieldsAdded: Seq[String]): ColumnPosition = {
-      position match {
-        case null => null
-        case after: After =>
-          (struct.fieldNames ++ fieldsAdded).find(n => conf.resolver(n, after.column())) match {
-            case Some(colName) =>
-              ColumnPosition.after(colName)
-            case None =>
-              throw QueryCompilationErrors.referenceColNotFoundForAlterTableChangesError(after,
-                parentName)
-          }
-        case other => other
-      }
-    }
-  }
-
   /**
    * A rule that marks a command as analyzed so that its children are removed to avoid
    * being optimized. This rule should run after all other analysis rules are run.


@@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, TypeUtils}
 import org.apache.spark.sql.connector.catalog.{LookupCatalog, SupportsPartitionManagement}
-import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, After, ColumnPosition, DeleteColumn}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -140,13 +139,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
       case u: UnresolvedV2Relation =>
         u.failAnalysis(s"Table not found: ${u.originalNameParts.quoted}")

-      case AlterTable(_, _, u: UnresolvedV2Relation, _) if isView(u.originalNameParts) =>
-        u.failAnalysis(
-          s"Invalid command: '${u.originalNameParts.quoted}' is a view not a table.")
-
-      case AlterTable(_, _, u: UnresolvedV2Relation, _) =>
-        failAnalysis(s"Table not found: ${u.originalNameParts.quoted}")
-
       case command: V2PartitionCommand =>
         command.table match {
           case r @ ResolvedTable(_, _, table, _) => table match {
@@ -449,87 +441,8 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
       case write: V2WriteCommand if write.resolved =>
         write.query.schema.foreach(f => TypeUtils.failWithIntervalType(f.dataType))

-      case alter: AlterTableCommand if alter.table.resolved =>
-        checkAlterTableCommand(alter)
-
-      case alter: AlterTable if alter.table.resolved =>
-        val table = alter.table
-        def findField(operation: String, fieldName: Array[String]): StructField = {
-          // include collections because structs nested in maps and arrays may be altered
-          val field = table.schema.findNestedField(fieldName, includeCollections = true)
-          if (field.isEmpty) {
-            alter.failAnalysis(
-              s"Cannot $operation missing field ${fieldName.quoted} in ${table.name} schema: " +
-                table.schema.treeString)
-          }
-          field.get._2
-        }
-        def positionArgumentExists(
-            position: ColumnPosition,
-            struct: StructType,
-            fieldsAdded: Seq[String]): Unit = {
-          position match {
-            case after: After =>
-              val allFields = struct.fieldNames ++ fieldsAdded
-              if (!allFields.contains(after.column())) {
-                alter.failAnalysis(s"Couldn't resolve positional argument $position amongst " +
-                  s"${allFields.mkString("[", ", ", "]")}")
-              }
-            case _ =>
-          }
-        }
-        def findParentStruct(operation: String, fieldNames: Array[String]): StructType = {
-          val parent = fieldNames.init
-          val field = if (parent.nonEmpty) {
-            findField(operation, parent).dataType
-          } else {
-            table.schema
-          }
-          field match {
-            case s: StructType => s
-            case o => alter.failAnalysis(s"Cannot $operation ${fieldNames.quoted}, because " +
-              s"its parent is not a StructType. Found $o")
-          }
-        }
-        def checkColumnNotExists(
-            operation: String,
-            fieldNames: Array[String],
-            struct: StructType): Unit = {
-          if (struct.findNestedField(fieldNames, includeCollections = true).isDefined) {
-            alter.failAnalysis(s"Cannot $operation column, because ${fieldNames.quoted} " +
-              s"already exists in ${struct.treeString}")
-          }
-        }
-
-        val colsToDelete = mutable.Set.empty[Seq[String]]
-        // 'colsToAdd' keeps track of new columns being added. It stores a mapping from a parent
-        // name of fields to field names that belong to the parent. For example, if we add
-        // columns "a.b.c", "a.b.d", and "a.c", 'colsToAdd' will become
-        // Map(Seq("a", "b") -> Seq("c", "d"), Seq("a") -> Seq("c")).
-        val colsToAdd = mutable.Map.empty[Seq[String], Seq[String]]
-
-        alter.changes.foreach {
-          case add: AddColumn =>
-            // If a column to add is a part of columns to delete, we don't need to check
-            // if column already exists - applies to REPLACE COLUMNS scenario.
-            if (!colsToDelete.contains(add.fieldNames())) {
-              checkColumnNotExists("add", add.fieldNames(), table.schema)
-            }
-            val parent = findParentStruct("add", add.fieldNames())
-            val parentName = add.fieldNames().init
-            val fieldsAdded = colsToAdd.getOrElse(parentName, Nil)
-            positionArgumentExists(add.position(), parent, fieldsAdded)
-            TypeUtils.failWithIntervalType(add.dataType())
-            colsToAdd(parentName) = fieldsAdded :+ add.fieldNames().last
-          case delete: DeleteColumn =>
-            findField("delete", delete.fieldNames)
-            // REPLACE COLUMNS has deletes followed by adds. Remember the deleted columns
-            // so that add operations do not fail when the columns to add exist and they
-            // are to be deleted.
-            colsToDelete += delete.fieldNames
-          case _ =>
-          // no validation needed for set and remove property
-        }
-
+      case alter: AlterTableColumnCommand if alter.table.resolved =>
+        checkAlterTableColumnCommand(alter)
+
       case _ => // Falls back to the following checks
     }
@@ -1025,17 +938,23 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
   /**
    * Validates the options used for alter table commands after table and columns are resolved.
    */
-  private def checkAlterTableCommand(alter: AlterTableCommand): Unit = {
-    def checkColumnNotExists(fieldNames: Seq[String], struct: StructType): Unit = {
+  private def checkAlterTableColumnCommand(alter: AlterTableColumnCommand): Unit = {
+    def checkColumnNotExists(op: String, fieldNames: Seq[String], struct: StructType): Unit = {
       if (struct.findNestedField(fieldNames, includeCollections = true).isDefined) {
-        alter.failAnalysis(s"Cannot ${alter.operation} column, because ${fieldNames.quoted} " +
+        alter.failAnalysis(s"Cannot $op column, because ${fieldNames.quoted} " +
           s"already exists in ${struct.treeString}")
       }
     }

     alter match {
+      case AlterTableAddColumns(table: ResolvedTable, colsToAdd) =>
+        colsToAdd.foreach { colToAdd =>
+          checkColumnNotExists("add", colToAdd.name, table.schema)
+        }
       case AlterTableRenameColumn(table: ResolvedTable, col: ResolvedFieldName, newName) =>
-        checkColumnNotExists(col.path :+ newName, table.schema)
+        checkColumnNotExists("rename", col.path :+ newName, table.schema)
       case a @ AlterTableAlterColumn(table: ResolvedTable, col: ResolvedFieldName, _, _, _, _) =>
         val fieldName = col.name.quoted
         if (a.dataType.isDefined) {


@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis

 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, LookupCatalog, TableChange}
+import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, LookupCatalog}

 /**
  * Resolves catalogs from the multi-part identifiers in SQL statements, and convert the statements
@@ -31,39 +31,6 @@ class ResolveCatalogs(val catalogManager: CatalogManager)
   import org.apache.spark.sql.connector.catalog.CatalogV2Util._

   override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-    case AlterTableAddColumnsStatement(
-         nameParts @ NonSessionCatalogAndTable(catalog, tbl), cols) =>
-      val changes = cols.map { col =>
-        TableChange.addColumn(
-          col.name.toArray,
-          col.dataType,
-          col.nullable,
-          col.comment.orNull,
-          col.position.orNull)
-      }
-      createAlterTable(nameParts, catalog, tbl, changes)
-
-    case AlterTableReplaceColumnsStatement(
-        nameParts @ NonSessionCatalogAndTable(catalog, tbl), cols) =>
-      val changes: Seq[TableChange] = loadTable(catalog, tbl.asIdentifier) match {
-        case Some(table) =>
-          // REPLACE COLUMNS deletes all the existing columns and adds new columns specified.
-          val deleteChanges = table.schema.fieldNames.map { name =>
-            TableChange.deleteColumn(Array(name))
-          }
-          val addChanges = cols.map { col =>
-            TableChange.addColumn(
-              col.name.toArray,
-              col.dataType,
-              col.nullable,
-              col.comment.orNull,
-              col.position.orNull)
-          }
-          deleteChanges ++ addChanges
-        case None => Seq()
-      }
-      createAlterTable(nameParts, catalog, tbl, changes)
-
     case c @ CreateTableStatement(
          NonSessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _, _, _) =>
       CreateV2Table(


@@ -3589,16 +3589,19 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
    */
   override def visitQualifiedColTypeWithPosition(
      ctx: QualifiedColTypeWithPositionContext): QualifiedColType = withOrigin(ctx) {
+    val name = typedVisit[Seq[String]](ctx.name)
     QualifiedColType(
-      name = typedVisit[Seq[String]](ctx.name),
+      path = if (name.length > 1) Some(UnresolvedFieldName(name.init)) else None,
+      colName = name.last,
       dataType = typedVisit[DataType](ctx.dataType),
       nullable = ctx.NULL == null,
       comment = Option(ctx.commentSpec()).map(visitCommentSpec),
-      position = Option(ctx.colPosition).map(typedVisit[ColumnPosition]))
+      position = Option(ctx.colPosition).map(pos =>
+        UnresolvedFieldPosition(typedVisit[ColumnPosition](pos))))
   }

   /**
-   * Parse a [[AlterTableAddColumnsStatement]] command.
+   * Parse a [[AlterTableAddColumns]] command.
    *
    * For example:
    * {{{
@@ -3607,8 +3610,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
    * }}}
    */
   override def visitAddTableColumns(ctx: AddTableColumnsContext): LogicalPlan = withOrigin(ctx) {
-    AlterTableAddColumnsStatement(
-      visitMultipartIdentifier(ctx.multipartIdentifier),
+    val colToken = if (ctx.COLUMN() != null) "COLUMN" else "COLUMNS"
+    AlterTableAddColumns(
+      createUnresolvedTable(ctx.multipartIdentifier, s"ALTER TABLE ... ADD $colToken"),
       ctx.columns.qualifiedColTypeWithPosition.asScala.map(typedVisit[QualifiedColType]).toSeq
     )
   }
@@ -3726,8 +3730,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
     if (ctx.partitionSpec != null) {
       operationNotAllowed("ALTER TABLE table PARTITION partition_spec REPLACE COLUMNS", ctx)
     }
-    AlterTableReplaceColumnsStatement(
-      visitMultipartIdentifier(ctx.multipartIdentifier),
+    AlterTableReplaceColumns(
+      createUnresolvedTable(ctx.multipartIdentifier, "ALTER TABLE ... REPLACE COLUMNS"),
       ctx.columns.qualifiedColTypeWithPosition.asScala.map { colType =>
         if (colType.NULL != null) {
           throw QueryParsingErrors.operationInHiveStyleCommandUnsupportedError(
@@ -3737,7 +3741,12 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
           throw QueryParsingErrors.operationInHiveStyleCommandUnsupportedError(
             "Column position", "REPLACE COLUMNS", ctx)
         }
-        typedVisit[QualifiedColType](colType)
+        val col = typedVisit[QualifiedColType](colType)
+        if (col.path.isDefined) {
+          throw QueryParsingErrors.operationInHiveStyleCommandUnsupportedError(
+            "Replacing with a nested column", "REPLACE COLUMNS", ctx)
+        }
+        col
       }.toSeq
     )
   }


@@ -17,11 +17,10 @@

 package org.apache.spark.sql.catalyst.plans.logical

-import org.apache.spark.sql.catalyst.analysis.ViewType
+import org.apache.spark.sql.catalyst.analysis.{FieldName, FieldPosition, ViewType}
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, FunctionResource}
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.trees.{LeafLike, UnaryLike}
-import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.types.{DataType, StructType}
@@ -226,25 +225,19 @@ case class ReplaceTableAsSelectStatement(

 /**
- * Column data as parsed by ALTER TABLE ... ADD COLUMNS.
+ * Column data as parsed by ALTER TABLE ... (ADD|REPLACE) COLUMNS.
  */
 case class QualifiedColType(
-    name: Seq[String],
+    path: Option[FieldName],
+    colName: String,
     dataType: DataType,
     nullable: Boolean,
     comment: Option[String],
-    position: Option[ColumnPosition])
-
-/**
- * ALTER TABLE ... ADD COLUMNS command, as parsed from SQL.
- */
-case class AlterTableAddColumnsStatement(
-    tableName: Seq[String],
-    columnsToAdd: Seq[QualifiedColType]) extends LeafParsedStatement
-
-case class AlterTableReplaceColumnsStatement(
-    tableName: Seq[String],
-    columnsToAdd: Seq[QualifiedColType]) extends LeafParsedStatement
+    position: Option[FieldPosition]) {
+  def name: Seq[String] = path.map(_.name).getOrElse(Nil) :+ colName
+
+  def resolved: Boolean = path.forall(_.resolved) && position.forall(_.resolved)
+}

 /**
  * An INSERT INTO statement, as parsed from SQL.


@@ -22,9 +22,8 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, Unevaluable}
 import org.apache.spark.sql.catalyst.plans.DescribeCommandSchema
 import org.apache.spark.sql.catalyst.trees.BinaryLike
-import org.apache.spark.sql.catalyst.util.CharVarcharUtils
+import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, TypeUtils}
 import org.apache.spark.sql.connector.catalog._
-import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, ColumnChange}
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.connector.write.Write
 import org.apache.spark.sql.types.{BooleanType, DataType, MetadataBuilder, StringType, StructType}
@@ -545,38 +544,6 @@ case class NoopCommand(
     commandName: String,
     multipartIdentifier: Seq[String]) extends LeafCommand

-/**
- * The logical plan of the ALTER TABLE command.
- */
-case class AlterTable(
-    catalog: TableCatalog,
-    ident: Identifier,
-    table: NamedRelation,
-    changes: Seq[TableChange]) extends LeafCommand {
-
-  override lazy val resolved: Boolean = table.resolved && {
-    changes.forall {
-      case add: AddColumn =>
-        add.fieldNames match {
-          case Array(_) =>
-            // a top-level field can always be added
-            true
-          case _ =>
-            // the parent field must exist
-            table.schema.findNestedField(add.fieldNames.init, includeCollections = true).isDefined
-        }
-      case colChange: ColumnChange =>
-        // the column that will be changed must exist
-        table.schema.findNestedField(colChange.fieldNames, includeCollections = true).isDefined
-      case _ =>
-        // property changes require no resolution checks
-        true
-    }
-  }
-}
-
 /**
  * The logical plan of the ALTER [TABLE|VIEW] ... RENAME TO command.
  */
@@ -1112,21 +1079,84 @@ case class UnsetTableProperties(
     copy(table = newChild)
 }

-trait AlterTableCommand extends UnaryCommand {
+trait AlterTableColumnCommand extends UnaryCommand {
   def table: LogicalPlan
-  def operation: String
   def changes: Seq[TableChange]
   override def child: LogicalPlan = table
 }

+/**
+ * The logical plan of the ALTER TABLE ... ADD COLUMNS command.
+ */
+case class AlterTableAddColumns(
+    table: LogicalPlan,
+    columnsToAdd: Seq[QualifiedColType]) extends AlterTableColumnCommand {
+  columnsToAdd.foreach { c =>
+    TypeUtils.failWithIntervalType(c.dataType)
+  }
+
+  override lazy val resolved: Boolean = table.resolved && columnsToAdd.forall(_.resolved)
+
+  override def changes: Seq[TableChange] = {
+    columnsToAdd.map { col =>
+      require(col.path.forall(_.resolved),
+        "FieldName should be resolved before it's converted to TableChange.")
+      require(col.position.forall(_.resolved),
+        "FieldPosition should be resolved before it's converted to TableChange.")
+      TableChange.addColumn(
+        col.name.toArray,
+        col.dataType,
+        col.nullable,
+        col.comment.orNull,
+        col.position.map(_.position).orNull)
+    }
+  }
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
+    copy(table = newChild)
+}
+
+/**
+ * The logical plan of the ALTER TABLE ... REPLACE COLUMNS command.
+ */
+case class AlterTableReplaceColumns(
+    table: LogicalPlan,
+    columnsToAdd: Seq[QualifiedColType]) extends AlterTableColumnCommand {
+  columnsToAdd.foreach { c =>
+    TypeUtils.failWithIntervalType(c.dataType)
+  }
+
+  override lazy val resolved: Boolean = table.resolved && columnsToAdd.forall(_.resolved)
+
+  override def changes: Seq[TableChange] = {
+    // REPLACE COLUMNS deletes all the existing columns and adds new columns specified.
+    require(table.resolved)
+    val deleteChanges = table.schema.fieldNames.map { name =>
+      TableChange.deleteColumn(Array(name))
+    }
+    val addChanges = columnsToAdd.map { col =>
+      assert(col.path.isEmpty)
+      assert(col.position.isEmpty)
+      TableChange.addColumn(
+        col.name.toArray,
+        col.dataType,
+        col.nullable,
+        col.comment.orNull,
+        null)
+    }
+    deleteChanges ++ addChanges
+  }
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
+    copy(table = newChild)
+}
+
 /**
  * The logical plan of the ALTER TABLE ... DROP COLUMNS command.
  */
 case class AlterTableDropColumns(
     table: LogicalPlan,
-    columnsToDrop: Seq[FieldName]) extends AlterTableCommand {
-  override def operation: String = "delete"
-
+    columnsToDrop: Seq[FieldName]) extends AlterTableColumnCommand {
   override def changes: Seq[TableChange] = {
     columnsToDrop.map { col =>
       require(col.resolved, "FieldName should be resolved before it's converted to TableChange.")
@@ -1144,9 +1174,7 @@ case class AlterTableDropColumns(
 case class AlterTableRenameColumn(
     table: LogicalPlan,
     column: FieldName,
-    newName: String) extends AlterTableCommand {
-  override def operation: String = "rename"
-
+    newName: String) extends AlterTableColumnCommand {
   override def changes: Seq[TableChange] = {
     require(column.resolved, "FieldName should be resolved before it's converted to TableChange.")
     Seq(TableChange.renameColumn(column.name.toArray, newName))
@@ -1165,10 +1193,7 @@ case class AlterTableAlterColumn(
     dataType: Option[DataType],
     nullable: Option[Boolean],
     comment: Option[String],
-    position: Option[FieldPosition]) extends AlterTableCommand {
-  override def operation: String = "update"
-
+    position: Option[FieldPosition]) extends AlterTableColumnCommand {
   override def changes: Seq[TableChange] = {
     require(column.resolved, "FieldName should be resolved before it's converted to TableChange.")
     val colName = column.name.toArray

@@ -22,8 +22,8 @@ import java.util.Collections

 import scala.collection.JavaConverters._

-import org.apache.spark.sql.catalyst.analysis.{NamedRelation, NoSuchDatabaseException, NoSuchNamespaceException, NoSuchTableException, UnresolvedV2Relation}
-import org.apache.spark.sql.catalyst.plans.logical.{AlterTable, CreateTableAsSelectStatement, CreateTableStatement, ReplaceTableAsSelectStatement, ReplaceTableStatement, SerdeInfo}
+import org.apache.spark.sql.catalyst.analysis.{NamedRelation, NoSuchDatabaseException, NoSuchNamespaceException, NoSuchTableException}
+import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelectStatement, CreateTableStatement, ReplaceTableAsSelectStatement, ReplaceTableStatement, SerdeInfo}
 import org.apache.spark.sql.connector.catalog.TableChange._
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.types.{ArrayType, MapType, StructField, StructType}
@@ -356,17 +356,6 @@ private[sql] object CatalogV2Util {
     properties ++ Map(TableCatalog.PROP_OWNER -> Utils.getCurrentUserName())
   }

-  def createAlterTable(
-      originalNameParts: Seq[String],
-      catalog: CatalogPlugin,
-      tableName: Seq[String],
-      changes: Seq[TableChange]): AlterTable = {
-    val tableCatalog = catalog.asTableCatalog
-    val ident = tableName.asIdentifier
-    val unresolved = UnresolvedV2Relation(originalNameParts, tableCatalog, ident)
-    AlterTable(tableCatalog, ident, unresolved, changes)
-  }
-
   def getTableProviderCatalog(
       provider: SupportsCatalogOptions,
       catalogManager: CatalogManager,


@@ -2352,12 +2352,12 @@ private[spark] object QueryCompilationErrors {
   }

   def missingFieldError(
-      fieldName: Seq[String], table: ResolvedTable, context: Expression): Throwable = {
+      fieldName: Seq[String], table: ResolvedTable, context: Origin): Throwable = {
     throw new AnalysisException(
       s"Missing field ${fieldName.quoted} in table ${table.name} with schema:\n" +
         table.schema.treeString,
-      context.origin.line,
-      context.origin.startPosition)
+      context.line,
+      context.startPosition)
   }

   def invalidFieldName(fieldName: Seq[String], path: Seq[String], context: Origin): Throwable = {


@@ -17,9 +17,11 @@

 package org.apache.spark.sql.catalyst.analysis

+import java.util
+
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, LeafNode}
-import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog, TableCatalog}
+import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog, Table, TableCapability, TableCatalog}
 import org.apache.spark.sql.connector.expressions.Expressions
 import org.apache.spark.sql.types.{DoubleType, LongType, StringType, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -151,3 +153,8 @@ private[sql] case object TestRelation2 extends LeafNode with NamedRelation {
     CreateTablePartitioningValidationSuite.schema.toAttributes
 }
+
+private[sql] case object TestTable2 extends Table {
+  override def name: String = "table_name"
+  override def schema: StructType = CreateTablePartitioningValidationSuite.schema
+  override def capabilities: util.Set[TableCapability] = new util.HashSet[TableCapability]()
+}


@@ -788,88 +788,125 @@ class DDLParserSuite extends AnalysisTest {
   test("alter table: add column") {
     comparePlans(
       parsePlan("ALTER TABLE table_name ADD COLUMN x int"),
-      AlterTableAddColumnsStatement(Seq("table_name"), Seq(
-        QualifiedColType(Seq("x"), IntegerType, true, None, None)
+      AlterTableAddColumns(
+        UnresolvedTable(Seq("table_name"), "ALTER TABLE ... ADD COLUMN", None),
+        Seq(QualifiedColType(None, "x", IntegerType, true, None, None)
      )))
   }

   test("alter table: add multiple columns") {
     comparePlans(
       parsePlan("ALTER TABLE table_name ADD COLUMNS x int, y string"),
-      AlterTableAddColumnsStatement(Seq("table_name"), Seq(
-        QualifiedColType(Seq("x"), IntegerType, true, None, None),
-        QualifiedColType(Seq("y"), StringType, true, None, None)
+      AlterTableAddColumns(
+        UnresolvedTable(Seq("table_name"), "ALTER TABLE ... ADD COLUMNS", None),
+        Seq(QualifiedColType(None, "x", IntegerType, true, None, None),
+          QualifiedColType(None, "y", StringType, true, None, None)
      )))
   }

   test("alter table: add column with COLUMNS") {
     comparePlans(
       parsePlan("ALTER TABLE table_name ADD COLUMNS x int"),
-      AlterTableAddColumnsStatement(Seq("table_name"), Seq(
-        QualifiedColType(Seq("x"), IntegerType, true, None, None)
+      AlterTableAddColumns(
+        UnresolvedTable(Seq("table_name"), "ALTER TABLE ... ADD COLUMNS", None),
+        Seq(QualifiedColType(None, "x", IntegerType, true, None, None)
      )))
   }

   test("alter table: add column with COLUMNS (...)") {
     comparePlans(
       parsePlan("ALTER TABLE table_name ADD COLUMNS (x int)"),
-      AlterTableAddColumnsStatement(Seq("table_name"), Seq(
-        QualifiedColType(Seq("x"), IntegerType, true, None, None)
+      AlterTableAddColumns(
+        UnresolvedTable(Seq("table_name"), "ALTER TABLE ... ADD COLUMNS", None),
+        Seq(QualifiedColType(None, "x", IntegerType, true, None, None)
      )))
   }

   test("alter table: add column with COLUMNS (...) and COMMENT") {
     comparePlans(
       parsePlan("ALTER TABLE table_name ADD COLUMNS (x int COMMENT 'doc')"),
-      AlterTableAddColumnsStatement(Seq("table_name"), Seq(
-        QualifiedColType(Seq("x"), IntegerType, true, Some("doc"), None)
+      AlterTableAddColumns(
+        UnresolvedTable(Seq("table_name"), "ALTER TABLE ... ADD COLUMNS", None),
+        Seq(QualifiedColType(None, "x", IntegerType, true, Some("doc"), None)
      )))
   }

   test("alter table: add non-nullable column") {
     comparePlans(
       parsePlan("ALTER TABLE table_name ADD COLUMN x int NOT NULL"),
-      AlterTableAddColumnsStatement(Seq("table_name"), Seq(
-        QualifiedColType(Seq("x"), IntegerType, false, None, None)
+      AlterTableAddColumns(
+        UnresolvedTable(Seq("table_name"), "ALTER TABLE ... ADD COLUMN", None),
+        Seq(QualifiedColType(None, "x", IntegerType, false, None, None)
      )))
   }

   test("alter table: add column with COMMENT") {
     comparePlans(
       parsePlan("ALTER TABLE table_name ADD COLUMN x int COMMENT 'doc'"),
-      AlterTableAddColumnsStatement(Seq("table_name"), Seq(
-        QualifiedColType(Seq("x"), IntegerType, true, Some("doc"), None)
+      AlterTableAddColumns(
+        UnresolvedTable(Seq("table_name"), "ALTER TABLE ... ADD COLUMN", None),
+        Seq(QualifiedColType(None, "x", IntegerType, true, Some("doc"), None)
      )))
   }

   test("alter table: add column with position") {
     comparePlans(
       parsePlan("ALTER TABLE table_name ADD COLUMN x int FIRST"),
-      AlterTableAddColumnsStatement(Seq("table_name"), Seq(
-        QualifiedColType(Seq("x"), IntegerType, true, None, Some(first()))
+      AlterTableAddColumns(
+        UnresolvedTable(Seq("table_name"), "ALTER TABLE ... ADD COLUMN", None),
+        Seq(QualifiedColType(
+          None,
+          "x",
+          IntegerType,
+          true,
+          None,
+          Some(UnresolvedFieldPosition(first())))
      )))

     comparePlans(
       parsePlan("ALTER TABLE table_name ADD COLUMN x int AFTER y"),
-      AlterTableAddColumnsStatement(Seq("table_name"), Seq(
-        QualifiedColType(Seq("x"), IntegerType, true, None, Some(after("y")))
+      AlterTableAddColumns(
+        UnresolvedTable(Seq("table_name"), "ALTER TABLE ... ADD COLUMN", None),
+        Seq(QualifiedColType(
+          None,
+          "x",
+          IntegerType,
+          true,
+          None,
+          Some(UnresolvedFieldPosition(after("y"))))
      )))
   }

   test("alter table: add column with nested column name") {
     comparePlans(
       parsePlan("ALTER TABLE table_name ADD COLUMN x.y.z int COMMENT 'doc'"),
-      AlterTableAddColumnsStatement(Seq("table_name"), Seq(
-        QualifiedColType(Seq("x", "y", "z"), IntegerType, true, Some("doc"), None)
+      AlterTableAddColumns(
+        UnresolvedTable(Seq("table_name"), "ALTER TABLE ... ADD COLUMN", None),
+        Seq(QualifiedColType(
+          Some(UnresolvedFieldName(Seq("x", "y"))), "z", IntegerType, true, Some("doc"), None)
      )))
   }

   test("alter table: add multiple columns with nested column name") {
     comparePlans(
       parsePlan("ALTER TABLE table_name ADD COLUMN x.y.z int COMMENT 'doc', a.b string FIRST"),
-      AlterTableAddColumnsStatement(Seq("table_name"), Seq(
-        QualifiedColType(Seq("x", "y", "z"), IntegerType, true, Some("doc"), None),
-        QualifiedColType(Seq("a", "b"), StringType, true, None, Some(first()))
+      AlterTableAddColumns(
+        UnresolvedTable(Seq("table_name"), "ALTER TABLE ... ADD COLUMN", None),
+        Seq(
+          QualifiedColType(
+            Some(UnresolvedFieldName(Seq("x", "y"))),
+            "z",
+            IntegerType,
+            true,
+            Some("doc"),
+            None),
+          QualifiedColType(
+            Some(UnresolvedFieldName(Seq("a"))),
+            "b",
+            StringType,
+            true,
+            None,
+            Some(UnresolvedFieldPosition(first())))
      )))
   }
@@ -1062,32 +1099,32 @@ class DDLParserSuite extends AnalysisTest {
     comparePlans(
       parsePlan(sql1),
-      AlterTableReplaceColumnsStatement(
-        Seq("table_name"),
-        Seq(QualifiedColType(Seq("x"), StringType, true, None, None))))
+      AlterTableReplaceColumns(
+        UnresolvedTable(Seq("table_name"), "ALTER TABLE ... REPLACE COLUMNS", None),
+        Seq(QualifiedColType(None, "x", StringType, true, None, None))))

     comparePlans(
       parsePlan(sql2),
-      AlterTableReplaceColumnsStatement(
-        Seq("table_name"),
-        Seq(QualifiedColType(Seq("x"), StringType, true, Some("x1"), None))))
+      AlterTableReplaceColumns(
+        UnresolvedTable(Seq("table_name"), "ALTER TABLE ... REPLACE COLUMNS", None),
+        Seq(QualifiedColType(None, "x", StringType, true, Some("x1"), None))))

     comparePlans(
       parsePlan(sql3),
-      AlterTableReplaceColumnsStatement(
-        Seq("table_name"),
+      AlterTableReplaceColumns(
+        UnresolvedTable(Seq("table_name"), "ALTER TABLE ... REPLACE COLUMNS", None),
         Seq(
-          QualifiedColType(Seq("x"), StringType, true, Some("x1"), None),
-          QualifiedColType(Seq("y"), IntegerType, true, None, None)
+          QualifiedColType(None, "x", StringType, true, Some("x1"), None),
+          QualifiedColType(None, "y", IntegerType, true, None, None)
        )))

     comparePlans(
       parsePlan(sql4),
-      AlterTableReplaceColumnsStatement(
-        Seq("table_name"),
+      AlterTableReplaceColumns(
+        UnresolvedTable(Seq("table_name"), "ALTER TABLE ... REPLACE COLUMNS", None),
         Seq(
-          QualifiedColType(Seq("x"), StringType, true, Some("x1"), None),
-          QualifiedColType(Seq("y"), IntegerType, true, Some("y1"), None)
+          QualifiedColType(None, "x", StringType, true, Some("x1"), None),
+          QualifiedColType(None, "y", IntegerType, true, Some("y1"), None)
        )))

     intercept("ALTER TABLE table_name PARTITION (a='1') REPLACE COLUMNS (x string)",
@@ -1098,6 +1135,9 @@ class DDLParserSuite extends AnalysisTest {
     intercept("ALTER TABLE table_name REPLACE COLUMNS (x string FIRST)",
       "Column position is not supported in Hive-style REPLACE COLUMNS")
+
+    intercept("ALTER TABLE table_name REPLACE COLUMNS (a.b.c string)",
+      "Replacing with a nested column is not supported in Hive-style REPLACE COLUMNS")
   }

   test("alter view: rename view") {


@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, toPrettySQL}
-import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, CatalogV2Util, Identifier, LookupCatalog, SupportsNamespaces, TableChange, V1Table}
+import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, CatalogV2Util, Identifier, LookupCatalog, SupportsNamespaces, V1Table}
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.command._
@@ -46,51 +46,17 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager)
   import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._

   override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
-    case AlterTableAddColumnsStatement(
-        nameParts @ SessionCatalogAndTable(catalog, tbl), cols) =>
-      loadTable(catalog, tbl.asIdentifier).collect {
-        case v1Table: V1Table =>
-          cols.foreach { c =>
-            assertTopLevelColumn(c.name, "AlterTableAddColumnsCommand")
-            if (!c.nullable) {
-              throw QueryCompilationErrors.addColumnWithV1TableCannotSpecifyNotNullError
-            }
-          }
-          AlterTableAddColumnsCommand(tbl.asTableIdentifier, cols.map(convertToStructField))
-      }.getOrElse {
-        val changes = cols.map { col =>
-          TableChange.addColumn(
-            col.name.toArray,
-            col.dataType,
-            col.nullable,
-            col.comment.orNull,
-            col.position.orNull)
-        }
-        createAlterTable(nameParts, catalog, tbl, changes)
-      }
+    case AlterTableAddColumns(ResolvedV1TableIdentifier(ident), cols) =>
+      cols.foreach { c =>
+        assertTopLevelColumn(c.name, "AlterTableAddColumnsCommand")
+        if (!c.nullable) {
+          throw QueryCompilationErrors.addColumnWithV1TableCannotSpecifyNotNullError
+        }
+      }
+      AlterTableAddColumnsCommand(ident.asTableIdentifier, cols.map(convertToStructField))

-    case AlterTableReplaceColumnsStatement(
-        nameParts @ SessionCatalogAndTable(catalog, tbl), cols) =>
-      val changes: Seq[TableChange] = loadTable(catalog, tbl.asIdentifier) match {
-        case Some(_: V1Table) =>
-          throw QueryCompilationErrors.replaceColumnsOnlySupportedWithV2TableError
-        case Some(table) =>
-          // REPLACE COLUMNS deletes all the existing columns and adds new columns specified.
-          val deleteChanges = table.schema.fieldNames.map { name =>
-            TableChange.deleteColumn(Array(name))
-          }
-          val addChanges = cols.map { col =>
-            TableChange.addColumn(
-              col.name.toArray,
-              col.dataType,
-              col.nullable,
-              col.comment.orNull,
-              col.position.orNull)
-          }
-          deleteChanges ++ addChanges
-        case None => Seq() // Unresolved table will be handled in CheckAnalysis.
-      }
-      createAlterTable(nameParts, catalog, tbl, changes)
+    case AlterTableReplaceColumns(ResolvedV1TableIdentifier(_), _) =>
+      throw QueryCompilationErrors.replaceColumnsOnlySupportedWithV2TableError

     case a @ AlterTableAlterColumn(ResolvedV1TableAndIdentifier(table, ident), _, _, _, _, _) =>
       if (a.column.name.length > 1) {


@@ -288,9 +288,6 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
     case _: NoopCommand =>
       LocalTableScanExec(Nil, Nil) :: Nil

-    case AlterTable(catalog, ident, _, changes) =>
-      AlterTableExec(catalog, ident, changes) :: Nil
-
     case RenameTable(r @ ResolvedTable(catalog, oldIdent, _, _), newIdent, isView) =>
       if (isView) {
         throw QueryCompilationErrors.cannotRenameTableWithAlterViewError()
@@ -445,7 +442,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
       val changes = keys.map(key => TableChange.removeProperty(key))
       AlterTableExec(table.catalog, table.identifier, changes) :: Nil

-    case a: AlterTableCommand if a.table.resolved =>
+    case a: AlterTableColumnCommand if a.table.resolved =>
       val table = a.table.asInstanceOf[ResolvedTable]
       AlterTableExec(table.catalog, table.identifier, a.changes) :: Nil


@@ -356,8 +356,7 @@ trait AlterTableTests extends SharedSparkSession {
         sql(s"ALTER TABLE $t ADD COLUMN point.z double")
       }

-      assert(exc.getMessage.contains("point"))
-      assert(exc.getMessage.contains("missing field"))
+      assert(exc.getMessage.contains("Missing field point"))
     }
   }


@@ -17,10 +17,10 @@

 package org.apache.spark.sql.connector

-import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, CreateTablePartitioningValidationSuite, ResolvedTable, TestRelation2, UnresolvedFieldName}
-import org.apache.spark.sql.catalyst.plans.logical.{AlterTable, AlterTableAlterColumn, AlterTableCommand, AlterTableDropColumns, AlterTableRenameColumn, CreateTableAsSelect, LogicalPlan, ReplaceTableAsSelect}
+import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, CreateTablePartitioningValidationSuite, ResolvedTable, TestRelation2, TestTable2, UnresolvedFieldName, UnresolvedFieldPosition}
+import org.apache.spark.sql.catalyst.plans.logical.{AlterTableAddColumns, AlterTableAlterColumn, AlterTableColumnCommand, AlterTableDropColumns, AlterTableRenameColumn, CreateTableAsSelect, LogicalPlan, QualifiedColType, ReplaceTableAsSelect}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.connector.catalog.{Identifier, TableChange}
+import org.apache.spark.sql.connector.catalog.Identifier
 import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition
 import org.apache.spark.sql.connector.expressions.Expressions
 import org.apache.spark.sql.execution.datasources.PreprocessTableCreation
@@ -35,7 +35,7 @@ class V2CommandsCaseSensitivitySuite extends SharedSparkSession with AnalysisTes
   private val table = ResolvedTable(
     catalog,
     Identifier.of(Array(), "table_name"),
-    null,
+    TestTable2,
     schema.toAttributes)

   override protected def extendedAnalysisRules: Seq[Rule[LogicalPlan]] = {
@@ -140,8 +140,11 @@ class V2CommandsCaseSensitivitySuite extends SharedSparkSession with AnalysisTes
     Seq("POINT.Z", "poInt.z", "poInt.Z").foreach { ref =>
       val field = ref.split("\\.")
       alterTableTest(
-        TableChange.addColumn(field, LongType),
-        Seq("add", field.head)
+        AlterTableAddColumns(
+          table,
+          Seq(QualifiedColType(
+            Some(UnresolvedFieldName(field.init)), field.last, LongType, true, None, None))),
+        Seq("Missing field " + field.head)
       )
     }
   }
@@ -149,8 +152,15 @@ class V2CommandsCaseSensitivitySuite extends SharedSparkSession with AnalysisTes
   test("AlterTable: add column resolution - positional") {
     Seq("ID", "iD").foreach { ref =>
       alterTableTest(
-        TableChange.addColumn(
-          Array("f"), LongType, true, null, ColumnPosition.after(ref)),
+        AlterTableAddColumns(
+          table,
+          Seq(QualifiedColType(
+            None,
+            "f",
+            LongType,
+            true,
+            None,
+            Some(UnresolvedFieldPosition(ColumnPosition.after(ref)))))),
         Seq("reference column", ref)
       )
     }
@@ -158,11 +168,22 @@ class V2CommandsCaseSensitivitySuite extends SharedSparkSession with AnalysisTes
   test("AlterTable: add column resolution - column position referencing new column") {
     alterTableTest(
-      Seq(
-        TableChange.addColumn(
-          Array("x"), LongType, true, null, ColumnPosition.after("id")),
-        TableChange.addColumn(
-          Array("y"), LongType, true, null, ColumnPosition.after("X"))),
+      AlterTableAddColumns(
+        table,
+        Seq(QualifiedColType(
+          None,
+          "x",
+          LongType,
+          true,
+          None,
+          Some(UnresolvedFieldPosition(ColumnPosition.after("id")))),
+          QualifiedColType(
+            None,
+            "x",
+            LongType,
+            true,
+            None,
+            Some(UnresolvedFieldPosition(ColumnPosition.after("X")))))),
       Seq("Couldn't find the reference column for AFTER X at root")
     )
   }
@@ -170,8 +191,15 @@ class V2CommandsCaseSensitivitySuite extends SharedSparkSession with AnalysisTes
   test("AlterTable: add column resolution - nested positional") {
     Seq("X", "Y").foreach { ref =>
       alterTableTest(
-        TableChange.addColumn(
-          Array("point", "z"), LongType, true, null, ColumnPosition.after(ref)),
+        AlterTableAddColumns(
+          table,
+          Seq(QualifiedColType(
+            Some(UnresolvedFieldName(Seq("point"))),
+            "z",
+            LongType,
+            true,
+            None,
+            Some(UnresolvedFieldPosition(ColumnPosition.after(ref)))))),
         Seq("reference column", ref)
       )
     }
@@ -179,11 +207,22 @@ class V2CommandsCaseSensitivitySuite extends SharedSparkSession with AnalysisTes
   test("AlterTable: add column resolution - column position referencing new nested column") {
     alterTableTest(
-      Seq(
-        TableChange.addColumn(
-          Array("point", "z"), LongType, true, null),
-        TableChange.addColumn(
-          Array("point", "zz"), LongType, true, null, ColumnPosition.after("Z"))),
+      AlterTableAddColumns(
+        table,
+        Seq(QualifiedColType(
+          Some(UnresolvedFieldName(Seq("point"))),
+          "z",
+          LongType,
+          true,
+          None,
+          None),
+          QualifiedColType(
+            Some(UnresolvedFieldName(Seq("point"))),
+            "zz",
+            LongType,
+            true,
+            None,
+            Some(UnresolvedFieldPosition(ColumnPosition.after("Z")))))),
       Seq("Couldn't find the reference column for AFTER Z at point")
     )
   }
@@ -233,30 +272,7 @@ class V2CommandsCaseSensitivitySuite extends SharedSparkSession with AnalysisTes
     }
   }

-  private def alterTableTest(change: TableChange, error: Seq[String]): Unit = {
-    alterTableTest(Seq(change), error)
-  }
-
-  private def alterTableTest(changes: Seq[TableChange], error: Seq[String]): Unit = {
-    Seq(true, false).foreach { caseSensitive =>
-      withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) {
-        val plan = AlterTable(
-          catalog,
-          Identifier.of(Array(), "table_name"),
-          TestRelation2,
-          changes
-        )
-        if (caseSensitive) {
-          assertAnalysisError(plan, error, caseSensitive)
-        } else {
-          assertAnalysisSuccess(plan, caseSensitive)
-        }
-      }
-    }
-  }
-
-  private def alterTableTest(alter: AlterTableCommand, error: Seq[String]): Unit = {
+  private def alterTableTest(alter: AlterTableColumnCommand, error: Seq[String]): Unit = {
     Seq(true, false).foreach { caseSensitive =>
       withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) {
         if (caseSensitive) {

@@ -2305,7 +2305,8 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
       val e = intercept[AnalysisException] {
         sql("ALTER TABLE tmp_v ADD COLUMNS (c3 INT)")
       }
-      assert(e.message.contains("'tmp_v' is a view not a table"))
+      assert(e.message.contains(
+        "tmp_v is a temp view. 'ALTER TABLE ... ADD COLUMNS' expects a table."))
     }
   }
@@ -2315,7 +2316,8 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
       val e = intercept[AnalysisException] {
         sql("ALTER TABLE v1 ADD COLUMNS (c3 INT)")
       }
-      assert(e.message.contains("ALTER ADD COLUMNS does not support views"))
+      assert(e.message.contains(
+        "default.v1 is a view. 'ALTER TABLE ... ADD COLUMNS' expects a table."))
     }
   }
} }