[SPARK-36247][SQL] Check string length for char/varchar and apply type coercion in UPDATE/MERGE command

### What changes were proposed in this pull request?

We added the char/varchar support in 3.1, but the string length check is only applied to INSERT, not UPDATE/MERGE. This PR fixes it. This PR also adds the missing type coercion for UPDATE/MERGE.

### Why are the changes needed?

complete the char/varchar support and make UPDATE/MERGE easier to use by doing type coercion.

### Does this PR introduce _any_ user-facing change?

no

### How was this patch tested?

new UT. No built-in source support UPDATE/MERGE so end-to-end test is not applicable here.

Closes #33468 from cloud-fan/char.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
Wenchen Fan 2021-07-27 13:57:05 +08:00
parent 9a47483f74
commit 068f8d434a
6 changed files with 130 additions and 16 deletions

View file

@ -3300,6 +3300,41 @@ class Analyzer(override val catalogManager: CatalogManager)
} else {
v2Write
}
case u: UpdateTable if !u.skipSchemaResolution && u.resolved =>
resolveAssignments(u)
case m: MergeIntoTable if !m.skipSchemaResolution && m.resolved =>
resolveAssignments(m)
}
private def resolveAssignments(p: LogicalPlan): LogicalPlan = {
p.transformExpressions {
case assignment: Assignment =>
val nullHandled = if (!assignment.key.nullable && assignment.value.nullable) {
AssertNotNull(assignment.value)
} else {
assignment.value
}
val casted = if (assignment.key.dataType != nullHandled.dataType) {
AnsiCast(nullHandled, assignment.key.dataType)
} else {
nullHandled
}
val rawKeyType = assignment.key.transform {
case a: AttributeReference =>
CharVarcharUtils.getRawType(a.metadata).map(a.withDataType).getOrElse(a)
}.dataType
val finalValue = if (CharVarcharUtils.hasCharVarchar(rawKeyType)) {
CharVarcharUtils.stringLengthCheck(casted, rawKeyType)
} else {
casted
}
val cleanedKey = assignment.key.transform {
case a: AttributeReference => CharVarcharUtils.cleanAttrMetadata(a)
}
Assignment(cleanedKey, finalValue)
}
}
}
@ -4218,14 +4253,6 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
}
}
private def padOuterRefAttrCmp(outerAttr: Attribute, attr: Attribute): Seq[Expression] = {
val Seq(r, newAttr) = CharVarcharUtils.addPaddingInStringComparison(Seq(outerAttr, attr))
val newOuterRef = r.transform {
case ar: Attribute if ar.semanticEquals(outerAttr) => OuterReference(ar)
}
Seq(newOuterRef, newAttr)
}
private def addPadding(expr: Expression, charLength: Int, targetLength: Int): Expression = {
if (targetLength > charLength) StringRPad(expr, Literal(targetLength)) else expr
}

View file

@ -168,6 +168,7 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Un
override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName)
override def withMetadata(newMetadata: Metadata): Attribute = this
override def withExprId(newExprId: ExprId): UnresolvedAttribute = this
override def withDataType(newType: DataType): Attribute = this
final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_ATTRIBUTE)
override def toString: String = s"'$name"

View file

@ -123,6 +123,7 @@ abstract class Attribute extends LeafExpression with NamedExpression with NullIn
def withName(newName: String): Attribute
def withMetadata(newMetadata: Metadata): Attribute
def withExprId(newExprId: ExprId): Attribute
def withDataType(newType: DataType): Attribute
override def toAttribute: Attribute = this
def newInstance(): Attribute
@ -339,6 +340,10 @@ case class AttributeReference(
AttributeReference(name, dataType, nullable, newMetadata)(exprId, qualifier)
}
override def withDataType(newType: DataType): Attribute = {
AttributeReference(name, newType, nullable, metadata)(exprId, qualifier)
}
override protected final def otherCopyArgs: Seq[AnyRef] = {
exprId :: qualifier :: Nil
}
@ -395,6 +400,8 @@ case class PrettyAttribute(
override def exprId: ExprId = throw new UnsupportedOperationException
override def withExprId(newExprId: ExprId): Attribute =
throw new UnsupportedOperationException
override def withDataType(newType: DataType): Attribute =
throw new UnsupportedOperationException
override def nullable: Boolean = true
}

View file

@ -425,6 +425,12 @@ case class UpdateTable(
override def child: LogicalPlan = table
override protected def withNewChildInternal(newChild: LogicalPlan): UpdateTable =
copy(table = newChild)
def skipSchemaResolution: Boolean = table match {
case r: NamedRelation => r.skipSchemaResolution
case SubqueryAlias(_, r: NamedRelation) => r.skipSchemaResolution
case _ => false
}
}
/**
@ -437,6 +443,13 @@ case class MergeIntoTable(
matchedActions: Seq[MergeAction],
notMatchedActions: Seq[MergeAction]) extends BinaryCommand with SupportsSubquery {
def duplicateResolved: Boolean = targetTable.outputSet.intersect(sourceTable.outputSet).isEmpty
def skipSchemaResolution: Boolean = targetTable match {
case r: NamedRelation => r.skipSchemaResolution
case SubqueryAlias(_, r: NamedRelation) => r.skipSchemaResolution
case _ => false
}
override def left: LogicalPlan = targetTable
override def right: LogicalPlan = sourceTable
override protected def withNewChildrenInternal(
@ -466,7 +479,7 @@ case class UpdateAction(
newChildren: IndexedSeq[Expression]): UpdateAction =
copy(
condition = if (condition.isDefined) Some(newChildren.head) else None,
assignments = newChildren.tail.asInstanceOf[Seq[Assignment]])
assignments = newChildren.takeRight(assignments.length).asInstanceOf[Seq[Assignment]])
}
case class UpdateStarAction(condition: Option[Expression]) extends MergeAction {
@ -485,7 +498,7 @@ case class InsertAction(
newChildren: IndexedSeq[Expression]): InsertAction =
copy(
condition = if (condition.isDefined) Some(newChildren.head) else None,
assignments = newChildren.tail.asInstanceOf[Seq[Assignment]])
assignments = newChildren.takeRight(assignments.length).asInstanceOf[Seq[Assignment]])
}
case class InsertStarAction(condition: Option[Expression]) extends MergeAction {

View file

@ -149,7 +149,7 @@ object CharVarcharUtils extends Logging {
}.getOrElse(expr)
}
private def stringLengthCheck(expr: Expression, dt: DataType): Expression = {
def stringLengthCheck(expr: Expression, dt: DataType): Expression = {
dt match {
case CharType(length) =>
StaticInvoke(

View file

@ -28,7 +28,8 @@ import org.apache.spark.sql.{AnalysisException, SaveMode}
import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, Analyzer, EmptyFunctionRegistry, NoSuchTableException, ResolvedFieldName, ResolvedTable, ResolveSessionCatalog, UnresolvedAttribute, UnresolvedRelation, UnresolvedSubqueryColumnAliases, UnresolvedTable}
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, InSubquery, IntegerLiteral, ListQuery, Literal, StringLiteral}
import org.apache.spark.sql.catalyst.expressions.{AnsiCast, AttributeReference, EqualTo, Expression, InSubquery, IntegerLiteral, ListQuery, Literal, StringLiteral}
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}
import org.apache.spark.sql.catalyst.plans.logical.{AlterTableAlterColumn, AnalysisOnlyCommand, AppendData, Assignment, CreateTableAsSelect, CreateTableStatement, CreateV2Table, DeleteAction, DeleteFromTable, DescribeRelation, DropTable, InsertAction, LocalRelation, LogicalPlan, MergeIntoTable, OneRowRelation, Project, SetTableLocation, SetTableProperties, ShowTableProperties, SubqueryAlias, UnsetTableProperties, UpdateAction, UpdateTable}
import org.apache.spark.sql.catalyst.rules.Rule
@ -75,6 +76,13 @@ class PlanResolutionSuite extends AnalysisTest {
t
}
private val charVarcharTable: Table = {
val t = mock(classOf[Table])
when(t.schema()).thenReturn(new StructType().add("c1", "char(5)").add("c2", "varchar(5)"))
when(t.partitioning()).thenReturn(Array.empty[Transform])
t
}
private val v1Table: V1Table = {
val t = mock(classOf[CatalogTable])
when(t.schema).thenReturn(new StructType()
@ -109,6 +117,7 @@ class PlanResolutionSuite extends AnalysisTest {
case "tab" => table
case "tab1" => table1
case "tab2" => table2
case "charvarchar" => charVarcharTable
case name => throw new NoSuchTableException(name)
}
})
@ -1058,12 +1067,33 @@ class PlanResolutionSuite extends AnalysisTest {
}
}
val sql = "UPDATE non_existing SET id=1"
val parsed = parseAndResolve(sql)
parsed match {
val sql1 = "UPDATE non_existing SET id=1"
val parsed1 = parseAndResolve(sql1)
parsed1 match {
case u: UpdateTable =>
assert(u.table.isInstanceOf[UnresolvedRelation])
case _ => fail("Expect UpdateTable, but got:\n" + parsed.treeString)
case _ => fail("Expect UpdateTable, but got:\n" + parsed1.treeString)
}
val sql2 = "UPDATE testcat.charvarchar SET c1='a', c2=1"
val parsed2 = parseAndResolve(sql2)
parsed2 match {
case u: UpdateTable =>
assert(u.assignments.length == 2)
u.assignments(0).value match {
case s: StaticInvoke =>
assert(s.arguments.length == 2)
assert(s.functionName == "charTypeWriteSideCheck")
case other => fail("Expect StaticInvoke, but got: " + other)
}
u.assignments(1).value match {
case s: StaticInvoke =>
assert(s.arguments.length == 2)
assert(s.arguments.head.isInstanceOf[AnsiCast])
assert(s.functionName == "varcharTypeWriteSideCheck")
case other => fail("Expect StaticInvoke, but got: " + other)
}
case _ => fail("Expect UpdateTable, but got:\n" + parsed2.treeString)
}
}
@ -1568,6 +1598,42 @@ class PlanResolutionSuite extends AnalysisTest {
val e3 = intercept[AnalysisException](parseAndResolve(sql3))
assert(e3.message.contains(
"cannot resolve s in MERGE command given columns [testcat.tab2.i, testcat.tab2.x]"))
val sql4 =
"""
|MERGE INTO testcat.charvarchar
|USING testcat.tab2
|ON 1 = 1
|WHEN MATCHED THEN UPDATE SET c1='a', c2=1
|WHEN NOT MATCHED THEN INSERT (c1, c2) VALUES ('b', 2)
|""".stripMargin
val parsed4 = parseAndResolve(sql4)
parsed4 match {
case m: MergeIntoTable =>
assert(m.matchedActions.length == 1)
m.matchedActions.head match {
case UpdateAction(_, Seq(
Assignment(_, s1: StaticInvoke), Assignment(_, s2: StaticInvoke))) =>
assert(s1.arguments.length == 2)
assert(s1.functionName == "charTypeWriteSideCheck")
assert(s2.arguments.length == 2)
assert(s2.arguments.head.isInstanceOf[AnsiCast])
assert(s2.functionName == "varcharTypeWriteSideCheck")
case other => fail("Expect UpdateAction, but got: " + other)
}
assert(m.notMatchedActions.length == 1)
m.notMatchedActions.head match {
case InsertAction(_, Seq(
Assignment(_, s1: StaticInvoke), Assignment(_, s2: StaticInvoke))) =>
assert(s1.arguments.length == 2)
assert(s1.functionName == "charTypeWriteSideCheck")
assert(s2.arguments.length == 2)
assert(s2.arguments.head.isInstanceOf[AnsiCast])
assert(s2.functionName == "varcharTypeWriteSideCheck")
case other => fail("Expect UpdateAction, but got: " + other)
}
case other => fail("Expect MergeIntoTable, but got:\n" + other.treeString)
}
}
test("MERGE INTO TABLE - skip resolution on v2 tables that accept any schema") {