From 14328e043d0233800869d5435291b3c0d4a65aa1 Mon Sep 17 00:00:00 2001
From: Wenchen Fan <wenchen@databricks.com>
Date: Tue, 27 Jul 2021 13:57:05 +0800
Subject: [PATCH] [SPARK-36247][SQL] Check string length for char/varchar and
 apply type coercion in UPDATE/MERGE command

### What changes were proposed in this pull request?

Char/varchar support was added in 3.1, but the string length check is only applied to INSERT, not to UPDATE/MERGE. This PR fixes that, and also adds the missing type coercion for UPDATE/MERGE assignments.

### Why are the changes needed?

To complete the char/varchar support and make UPDATE/MERGE easier to use by applying type coercion.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New unit tests. No built-in source supports UPDATE/MERGE, so an end-to-end test is not applicable here.

Closes #33468 from cloud-fan/char.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit 068f8d434ad9a6651006151de521d0799db8af52)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
---
 .../sql/catalyst/analysis/Analyzer.scala      | 43 +++++++++--
 .../sql/catalyst/analysis/unresolved.scala    |  1 +
 .../expressions/namedExpressions.scala        |  7 ++
 .../catalyst/plans/logical/v2Commands.scala   | 17 ++++-
 .../sql/catalyst/util/CharVarcharUtils.scala  |  2 +-
 .../command/PlanResolutionSuite.scala         | 76 +++++++++++++++++--
 6 files changed, 130 insertions(+), 16 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index ed7ad7f6d8..ee7b342c53 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -3300,6 +3300,41 @@ class Analyzer(override val catalogManager: CatalogManager)
       } else {
         v2Write
       }
+
+    case u: UpdateTable if !u.skipSchemaResolution && u.resolved =>
+      resolveAssignments(u)
+
+    case m: MergeIntoTable if !m.skipSchemaResolution && m.resolved =>
+      resolveAssignments(m)
+  }
+
+  private def resolveAssignments(p: LogicalPlan): LogicalPlan = {
+    p.transformExpressions {
+      case assignment: Assignment =>
+        val nullHandled = if (!assignment.key.nullable && assignment.value.nullable) {
+          AssertNotNull(assignment.value)
+        } else {
+          assignment.value
+        }
+        val casted = if (assignment.key.dataType != nullHandled.dataType) {
+          AnsiCast(nullHandled, assignment.key.dataType)
+        } else {
+          nullHandled
+        }
+        val rawKeyType = assignment.key.transform {
+          case a: AttributeReference =>
+            CharVarcharUtils.getRawType(a.metadata).map(a.withDataType).getOrElse(a)
+        }.dataType
+        val finalValue = if (CharVarcharUtils.hasCharVarchar(rawKeyType)) {
+          CharVarcharUtils.stringLengthCheck(casted, rawKeyType)
+        } else {
+          casted
+        }
+        val cleanedKey = assignment.key.transform {
+          case a: AttributeReference => CharVarcharUtils.cleanAttrMetadata(a)
+        }
+        Assignment(cleanedKey, finalValue)
+    }
   }
 }
@@ -4218,14 +4253,6 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
     }
   }
 
-  private def padOuterRefAttrCmp(outerAttr: Attribute, attr: Attribute): Seq[Expression] = {
-    val Seq(r, newAttr) = CharVarcharUtils.addPaddingInStringComparison(Seq(outerAttr, attr))
-    val newOuterRef = r.transform {
-      case ar: Attribute if ar.semanticEquals(outerAttr) => OuterReference(ar)
-    }
-    Seq(newOuterRef, newAttr)
-  }
-
   private def addPadding(expr: Expression, charLength: Int, targetLength: Int): Expression = {
     if (targetLength > charLength) StringRPad(expr, Literal(targetLength)) else expr
   }
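The new `resolveAssignments` rule above rewrites each assignment value in three ordered steps: wrap it in `AssertNotNull` when a nullable value is written to a non-nullable column, add an `AnsiCast` when the types differ, and finally apply the char/varchar write-side length check when the column's metadata carries a raw char/varchar type. A minimal standalone sketch of that decision chain follows; the `Expr`/`Attr`/`Cast`/`LengthCheck` types are illustrative stand-ins, not the Catalyst classes:

```scala
object AssignmentResolutionSketch {
  // Toy expression model; Catalyst expressions carry much more state.
  sealed trait Expr { def dataType: String; def nullable: Boolean }
  case class Attr(name: String, dataType: String, nullable: Boolean) extends Expr
  case class AssertNotNull(child: Expr) extends Expr {
    val dataType: String = child.dataType
    val nullable: Boolean = false // fails at runtime rather than writing NULL
  }
  case class Cast(child: Expr, dataType: String) extends Expr {
    val nullable: Boolean = child.nullable
  }
  case class LengthCheck(child: Expr, rawType: String) extends Expr {
    val dataType: String = child.dataType
    val nullable: Boolean = child.nullable
  }

  // rawKeyType plays the role of CharVarcharUtils.getRawType: Some("char(5)")
  // when the target column is really a char/varchar stored as a plain string.
  def resolveValue(key: Attr, value: Expr, rawKeyType: Option[String]): Expr = {
    val nullHandled =
      if (!key.nullable && value.nullable) AssertNotNull(value) else value
    val casted =
      if (key.dataType != nullHandled.dataType) Cast(nullHandled, key.dataType)
      else nullHandled
    rawKeyType.fold(casted)(raw => LengthCheck(casted, raw))
  }

  def main(args: Array[String]): Unit = {
    // e.g. UPDATE t SET c1 = 1 where c1 is a non-nullable CHAR(5) column:
    val resolved = resolveValue(
      Attr("c1", "string", nullable = false),
      Attr("v", "int", nullable = true),
      Some("char(5)"))
    assert(resolved ==
      LengthCheck(Cast(AssertNotNull(Attr("v", "int", true)), "string"), "char(5)"))
  }
}
```

The ordering matters: the cast runs before the length check, so a non-string value such as `c2=1` is first coerced to string and only then length-checked, which is exactly what the new tests assert via `s.arguments.head.isInstanceOf[AnsiCast]`.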
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 29d5410906..9f05367035 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -168,6 +168,7 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Un
   override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName)
   override def withMetadata(newMetadata: Metadata): Attribute = this
   override def withExprId(newExprId: ExprId): UnresolvedAttribute = this
+  override def withDataType(newType: DataType): Attribute = this
 
   final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_ATTRIBUTE)
   override def toString: String = s"'$name"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 2b8265f57b..ae2c66c86e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -123,6 +123,7 @@ abstract class Attribute extends LeafExpression with NamedExpression with NullIn
   def withName(newName: String): Attribute
   def withMetadata(newMetadata: Metadata): Attribute
   def withExprId(newExprId: ExprId): Attribute
+  def withDataType(newType: DataType): Attribute
 
   override def toAttribute: Attribute = this
   def newInstance(): Attribute
@@ -339,6 +340,10 @@ case class AttributeReference(
     AttributeReference(name, dataType, nullable, newMetadata)(exprId, qualifier)
   }
 
+  override def withDataType(newType: DataType): Attribute = {
+    AttributeReference(name, newType, nullable, metadata)(exprId, qualifier)
+  }
+
   override protected final def otherCopyArgs: Seq[AnyRef] = {
     exprId :: qualifier :: Nil
   }
@@ -395,6 +400,8 @@ case class PrettyAttribute(
   override def exprId: ExprId = throw new UnsupportedOperationException
   override def withExprId(newExprId: ExprId): Attribute =
     throw new UnsupportedOperationException
+  override def withDataType(newType: DataType): Attribute =
+    throw new UnsupportedOperationException
   override def nullable: Boolean = true
 }
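`withDataType` rounds out the existing `withName`/`withMetadata`/`withExprId` family: it returns an otherwise-identical attribute with only the data type replaced. The rule uses it to temporarily restore the raw char/varchar type stashed in the attribute's metadata, since char/varchar columns are stored as plain strings. A hedged sketch of that pattern with simplified stand-in types (`Meta` and `restoreRawType` are illustrative, not Spark APIs):

```scala
object WithDataTypeSketch {
  // Simplified stand-ins; Spark's Attribute also keeps exprId and
  // qualifier stable across the copy, which is the point of the pattern.
  case class Meta(rawType: Option[String])
  case class Attribute(name: String, dataType: String, metadata: Meta) {
    def withDataType(newType: String): Attribute = copy(dataType = newType)
  }

  // Mirrors CharVarcharUtils.getRawType(a.metadata).map(a.withDataType).getOrElse(a)
  def restoreRawType(a: Attribute): Attribute =
    a.metadata.rawType.map(a.withDataType).getOrElse(a)

  def main(args: Array[String]): Unit = {
    val a = Attribute("c1", "string", Meta(Some("char(5)")))
    // The raw char/varchar type survives only in metadata until a rule
    // like resolveAssignments restores it to compute the length check.
    assert(restoreRawType(a) == Attribute("c1", "char(5)", Meta(Some("char(5)"))))
  }
}
```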
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
index 3d88d6232c..fa897a84e2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -425,6 +425,12 @@ case class UpdateTable(
   override def child: LogicalPlan = table
   override protected def withNewChildInternal(newChild: LogicalPlan): UpdateTable =
     copy(table = newChild)
+
+  def skipSchemaResolution: Boolean = table match {
+    case r: NamedRelation => r.skipSchemaResolution
+    case SubqueryAlias(_, r: NamedRelation) => r.skipSchemaResolution
+    case _ => false
+  }
 }
 
 /**
@@ -437,6 +443,13 @@ case class MergeIntoTable(
     matchedActions: Seq[MergeAction],
     notMatchedActions: Seq[MergeAction]) extends BinaryCommand with SupportsSubquery {
   def duplicateResolved: Boolean = targetTable.outputSet.intersect(sourceTable.outputSet).isEmpty
+
+  def skipSchemaResolution: Boolean = targetTable match {
+    case r: NamedRelation => r.skipSchemaResolution
+    case SubqueryAlias(_, r: NamedRelation) => r.skipSchemaResolution
+    case _ => false
+  }
+
   override def left: LogicalPlan = targetTable
   override def right: LogicalPlan = sourceTable
   override protected def withNewChildrenInternal(
@@ -466,7 +479,7 @@ case class UpdateAction(
       newChildren: IndexedSeq[Expression]): UpdateAction =
     copy(
       condition = if (condition.isDefined) Some(newChildren.head) else None,
-      assignments = newChildren.tail.asInstanceOf[Seq[Assignment]])
+      assignments = newChildren.takeRight(assignments.length).asInstanceOf[Seq[Assignment]])
 }
 
 case class UpdateStarAction(condition: Option[Expression]) extends MergeAction {
@@ -485,7 +498,7 @@ case class InsertAction(
       newChildren: IndexedSeq[Expression]): InsertAction =
     copy(
       condition = if (condition.isDefined) Some(newChildren.head) else None,
-      assignments = newChildren.tail.asInstanceOf[Seq[Assignment]])
+      assignments = newChildren.takeRight(assignments.length).asInstanceOf[Seq[Assignment]])
 }
 
 case class InsertStarAction(condition: Option[Expression]) extends MergeAction {
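Aside from the new `skipSchemaResolution` helpers, the `tail` to `takeRight` change in `UpdateAction`/`InsertAction` fixes a latent bug in `withNewChildrenInternal`: `newChildren` begins with the condition only when one is defined, so `tail` silently drops the first assignment whenever `condition` is `None`, while `takeRight(assignments.length)` is correct in both cases. A small self-contained illustration, with plain strings standing in for Catalyst expressions:

```scala
object TakeRightSketch {
  def main(args: Array[String]): Unit = {
    val assignments = Seq("a1", "a2")

    // Children as flattened by the tree node: condition first, when defined.
    val withCond = Seq("cond") ++ assignments
    val noCond   = assignments

    // Old behavior: fine with a condition, but drops "a1" without one.
    assert(withCond.tail == Seq("a1", "a2"))
    assert(noCond.tail == Seq("a2")) // wrong!

    // New behavior: correct either way.
    assert(withCond.takeRight(assignments.length) == Seq("a1", "a2"))
    assert(noCond.takeRight(assignments.length) == Seq("a1", "a2"))
  }
}
```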
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala
index a5667756cd..3094b5ff81 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala
@@ -149,7 +149,7 @@ object CharVarcharUtils extends Logging {
     }.getOrElse(expr)
   }
 
-  private def stringLengthCheck(expr: Expression, dt: DataType): Expression = {
+  def stringLengthCheck(expr: Expression, dt: DataType): Expression = {
     dt match {
       case CharType(length) =>
         StaticInvoke(
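`stringLengthCheck` is made public here so the new rule can call it; it wraps the value in a `StaticInvoke` of `charTypeWriteSideCheck` or `varcharTypeWriteSideCheck`, which the tests below pattern-match on. A rough plain-Scala model of what those checks do at write time (the real implementations operate on `UTF8String`, and the exact trimming/padding behavior and error message here are approximations):

```scala
object WriteSideCheckSketch {
  // Approximate semantics: values longer than the limit fail, unless the
  // overflow consists only of trailing spaces, which are silently removed.
  def varcharWriteSideCheck(s: String, limit: Int): String =
    if (s.length <= limit) s
    else if (s.drop(limit).forall(_ == ' ')) s.take(limit)
    else throw new RuntimeException(
      s"Exceeds char/varchar type length limitation: $limit")

  // char(n) additionally right-pads the checked value to exactly n chars.
  def charWriteSideCheck(s: String, limit: Int): String =
    varcharWriteSideCheck(s, limit).padTo(limit, ' ')

  def main(args: Array[String]): Unit = {
    assert(charWriteSideCheck("a", 5) == "a    ")
    assert(varcharWriteSideCheck("abc   ", 5) == "abc  ")
    // varcharWriteSideCheck("abcdef", 5) would throw.
  }
}
```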
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
index e714fb3b13..25a8c4ea81 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
@@ -28,7 +28,8 @@ import org.apache.spark.sql.{AnalysisException, SaveMode}
 import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier}
 import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, Analyzer, EmptyFunctionRegistry, NoSuchTableException, ResolvedFieldName, ResolvedTable, ResolveSessionCatalog, UnresolvedAttribute, UnresolvedRelation, UnresolvedSubqueryColumnAliases, UnresolvedTable}
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, InMemoryCatalog, SessionCatalog}
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, InSubquery, IntegerLiteral, ListQuery, Literal, StringLiteral}
+import org.apache.spark.sql.catalyst.expressions.{AnsiCast, AttributeReference, EqualTo, Expression, InSubquery, IntegerLiteral, ListQuery, Literal, StringLiteral}
+import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}
 import org.apache.spark.sql.catalyst.plans.logical.{AlterTableAlterColumn, AnalysisOnlyCommand, AppendData, Assignment, CreateTableAsSelect, CreateTableStatement, CreateV2Table, DeleteAction, DeleteFromTable, DescribeRelation, DropTable, InsertAction, LocalRelation, LogicalPlan, MergeIntoTable, OneRowRelation, Project, SetTableLocation, SetTableProperties, ShowTableProperties, SubqueryAlias, UnsetTableProperties, UpdateAction, UpdateTable}
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -75,6 +76,13 @@ class PlanResolutionSuite extends AnalysisTest {
     t
   }
 
+  private val charVarcharTable: Table = {
+    val t = mock(classOf[Table])
+    when(t.schema()).thenReturn(new StructType().add("c1", "char(5)").add("c2", "varchar(5)"))
+    when(t.partitioning()).thenReturn(Array.empty[Transform])
+    t
+  }
+
   private val v1Table: V1Table = {
     val t = mock(classOf[CatalogTable])
     when(t.schema).thenReturn(new StructType()
@@ -109,6 +117,7 @@ class PlanResolutionSuite extends AnalysisTest {
         case "tab" => table
         case "tab1" => table1
         case "tab2" => table2
+        case "charvarchar" => charVarcharTable
         case name => throw new NoSuchTableException(name)
       }
     })
@@ -1058,12 +1067,33 @@
       }
     }
 
-    val sql = "UPDATE non_existing SET id=1"
-    val parsed = parseAndResolve(sql)
-    parsed match {
+    val sql1 = "UPDATE non_existing SET id=1"
+    val parsed1 = parseAndResolve(sql1)
+    parsed1 match {
       case u: UpdateTable =>
         assert(u.table.isInstanceOf[UnresolvedRelation])
-      case _ => fail("Expect UpdateTable, but got:\n" + parsed.treeString)
+      case _ => fail("Expect UpdateTable, but got:\n" + parsed1.treeString)
+    }
+
+    val sql2 = "UPDATE testcat.charvarchar SET c1='a', c2=1"
+    val parsed2 = parseAndResolve(sql2)
+    parsed2 match {
+      case u: UpdateTable =>
+        assert(u.assignments.length == 2)
+        u.assignments(0).value match {
+          case s: StaticInvoke =>
+            assert(s.arguments.length == 2)
+            assert(s.functionName == "charTypeWriteSideCheck")
+          case other => fail("Expect StaticInvoke, but got: " + other)
+        }
+        u.assignments(1).value match {
+          case s: StaticInvoke =>
+            assert(s.arguments.length == 2)
+            assert(s.arguments.head.isInstanceOf[AnsiCast])
+            assert(s.functionName == "varcharTypeWriteSideCheck")
+          case other => fail("Expect StaticInvoke, but got: " + other)
+        }
+      case _ => fail("Expect UpdateTable, but got:\n" + parsed2.treeString)
     }
   }
 
@@ -1568,6 +1598,42 @@
     val e3 = intercept[AnalysisException](parseAndResolve(sql3))
     assert(e3.message.contains(
       "cannot resolve s in MERGE command given columns [testcat.tab2.i, testcat.tab2.x]"))
+
+    val sql4 =
+      """
+        |MERGE INTO testcat.charvarchar
+        |USING testcat.tab2
+        |ON 1 = 1
+        |WHEN MATCHED THEN UPDATE SET c1='a', c2=1
+        |WHEN NOT MATCHED THEN INSERT (c1, c2) VALUES ('b', 2)
+        |""".stripMargin
+    val parsed4 = parseAndResolve(sql4)
+    parsed4 match {
+      case m: MergeIntoTable =>
+        assert(m.matchedActions.length == 1)
+        m.matchedActions.head match {
+          case UpdateAction(_, Seq(
+            Assignment(_, s1: StaticInvoke), Assignment(_, s2: StaticInvoke))) =>
+            assert(s1.arguments.length == 2)
+            assert(s1.functionName == "charTypeWriteSideCheck")
+            assert(s2.arguments.length == 2)
+            assert(s2.arguments.head.isInstanceOf[AnsiCast])
+            assert(s2.functionName == "varcharTypeWriteSideCheck")
+          case other => fail("Expect UpdateAction, but got: " + other)
+        }
+        assert(m.notMatchedActions.length == 1)
+        m.notMatchedActions.head match {
+          case InsertAction(_, Seq(
+            Assignment(_, s1: StaticInvoke), Assignment(_, s2: StaticInvoke))) =>
+            assert(s1.arguments.length == 2)
+            assert(s1.functionName == "charTypeWriteSideCheck")
+            assert(s2.arguments.length == 2)
+            assert(s2.arguments.head.isInstanceOf[AnsiCast])
+            assert(s2.functionName == "varcharTypeWriteSideCheck")
+          case other => fail("Expect InsertAction, but got: " + other)
+        }
+      case other => fail("Expect MergeIntoTable, but got:\n" + other.treeString)
+    }
   }
 
   test("MERGE INTO TABLE - skip resolution on v2 tables that accept any schema") {