[SPARK-36247][SQL] Check string length for char/varchar and apply type coercion in UPDATE/MERGE command
### What changes were proposed in this pull request?
We added the char/varchar support in 3.1, but the string length check is only applied to INSERT, not UPDATE/MERGE. This PR fixes it. This PR also adds the missing type coercion for UPDATE/MERGE.
### Why are the changes needed?
To complete the char/varchar support and to make UPDATE/MERGE easier to use by applying type coercion.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
New unit tests. No built-in source supports UPDATE/MERGE, so an end-to-end test is not applicable here.
Closes #33468 from cloud-fan/char.
Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit 068f8d434a)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
6027137928
commit
14328e043d
|
@ -3300,6 +3300,41 @@ class Analyzer(override val catalogManager: CatalogManager)
|
|||
} else {
|
||||
v2Write
|
||||
}
|
||||
|
||||
case u: UpdateTable if !u.skipSchemaResolution && u.resolved =>
|
||||
resolveAssignments(u)
|
||||
|
||||
case m: MergeIntoTable if !m.skipSchemaResolution && m.resolved =>
|
||||
resolveAssignments(m)
|
||||
}
|
||||
|
||||
  /**
   * Resolves the assignments of an UPDATE/MERGE command. For each assignment this:
   *   1. wraps the value in AssertNotNull when writing a nullable value to a
   *      non-nullable column, so a NULL write fails at runtime instead of corrupting data;
   *   2. applies type coercion via AnsiCast when the value type differs from the column type;
   *   3. re-derives the raw char/varchar type from the key attribute's metadata and, if
   *      present, adds the write-side string-length check;
   *   4. strips the char/varchar metadata from the key so downstream rules see clean types.
   */
  private def resolveAssignments(p: LogicalPlan): LogicalPlan = {
    p.transformExpressions {
      case assignment: Assignment =>
        // Fail fast at runtime if a NULL would be written into a non-nullable column.
        val nullHandled = if (!assignment.key.nullable && assignment.value.nullable) {
          AssertNotNull(assignment.value)
        } else {
          assignment.value
        }
        // ANSI cast gives stricter (error-on-overflow) coercion than a plain Cast.
        val casted = if (assignment.key.dataType != nullHandled.dataType) {
          AnsiCast(nullHandled, assignment.key.dataType)
        } else {
          nullHandled
        }
        // Char/varchar columns are stored as plain strings; recover the declared
        // raw type (e.g. char(5)) from the attribute metadata, if any.
        val rawKeyType = assignment.key.transform {
          case a: AttributeReference =>
            CharVarcharUtils.getRawType(a.metadata).map(a.withDataType).getOrElse(a)
        }.dataType
        // Enforce the declared length limit on the value being written.
        val finalValue = if (CharVarcharUtils.hasCharVarchar(rawKeyType)) {
          CharVarcharUtils.stringLengthCheck(casted, rawKeyType)
        } else {
          casted
        }
        // Remove the char/varchar metadata from the key once the check is in place.
        val cleanedKey = assignment.key.transform {
          case a: AttributeReference => CharVarcharUtils.cleanAttrMetadata(a)
        }
        Assignment(cleanedKey, finalValue)
    }
  }
|
||||
}
|
||||
|
||||
|
@ -4218,14 +4253,6 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
|
|||
}
|
||||
}
|
||||
|
||||
  /**
   * Pads a comparison between an outer-reference attribute and a local attribute so
   * that char-type values compare correctly (char semantics require space padding to
   * the declared length). Returns the rewritten pair: the padded outer reference
   * (re-wrapped in OuterReference) followed by the padded local attribute.
   */
  private def padOuterRefAttrCmp(outerAttr: Attribute, attr: Attribute): Seq[Expression] = {
    val Seq(r, newAttr) = CharVarcharUtils.addPaddingInStringComparison(Seq(outerAttr, attr))
    // addPaddingInStringComparison drops the OuterReference wrapper; restore it on
    // the expression that came from the outer query.
    val newOuterRef = r.transform {
      case ar: Attribute if ar.semanticEquals(outerAttr) => OuterReference(ar)
    }
    Seq(newOuterRef, newAttr)
  }
|
||||
|
||||
  /**
   * Right-pads `expr` with spaces up to `targetLength` when the target is wider than
   * the expression's declared char length; otherwise returns `expr` unchanged.
   */
  private def addPadding(expr: Expression, charLength: Int, targetLength: Int): Expression = {
    if (targetLength > charLength) StringRPad(expr, Literal(targetLength)) else expr
  }
|
||||
|
|
|
@ -168,6 +168,7 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Un
|
|||
override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName)
|
||||
override def withMetadata(newMetadata: Metadata): Attribute = this
|
||||
override def withExprId(newExprId: ExprId): UnresolvedAttribute = this
|
||||
override def withDataType(newType: DataType): Attribute = this
|
||||
final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_ATTRIBUTE)
|
||||
|
||||
override def toString: String = s"'$name"
|
||||
|
|
|
@ -123,6 +123,7 @@ abstract class Attribute extends LeafExpression with NamedExpression with NullIn
|
|||
def withName(newName: String): Attribute
|
||||
def withMetadata(newMetadata: Metadata): Attribute
|
||||
def withExprId(newExprId: ExprId): Attribute
|
||||
def withDataType(newType: DataType): Attribute
|
||||
|
||||
override def toAttribute: Attribute = this
|
||||
def newInstance(): Attribute
|
||||
|
@ -339,6 +340,10 @@ case class AttributeReference(
|
|||
AttributeReference(name, dataType, nullable, newMetadata)(exprId, qualifier)
|
||||
}
|
||||
|
||||
  /**
   * Returns a copy of this attribute with a new data type, preserving the expression id,
   * qualifier, name, nullability and metadata.
   */
  override def withDataType(newType: DataType): Attribute = {
    AttributeReference(name, newType, nullable, metadata)(exprId, qualifier)
  }
|
||||
|
||||
override protected final def otherCopyArgs: Seq[AnyRef] = {
|
||||
exprId :: qualifier :: Nil
|
||||
}
|
||||
|
@ -395,6 +400,8 @@ case class PrettyAttribute(
|
|||
override def exprId: ExprId = throw new UnsupportedOperationException
|
||||
override def withExprId(newExprId: ExprId): Attribute =
|
||||
throw new UnsupportedOperationException
|
||||
override def withDataType(newType: DataType): Attribute =
|
||||
throw new UnsupportedOperationException
|
||||
override def nullable: Boolean = true
|
||||
}
|
||||
|
||||
|
|
|
@ -425,6 +425,12 @@ case class UpdateTable(
|
|||
override def child: LogicalPlan = table
|
||||
override protected def withNewChildInternal(newChild: LogicalPlan): UpdateTable =
|
||||
copy(table = newChild)
|
||||
|
||||
  // Whether the target relation accepts any schema, in which case schema resolution
  // (and the assignment coercion/length checks) should be skipped. Unwraps a single
  // SubqueryAlias to reach the underlying named relation.
  def skipSchemaResolution: Boolean = table match {
    case r: NamedRelation => r.skipSchemaResolution
    case SubqueryAlias(_, r: NamedRelation) => r.skipSchemaResolution
    case _ => false
  }
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -437,6 +443,13 @@ case class MergeIntoTable(
|
|||
matchedActions: Seq[MergeAction],
|
||||
notMatchedActions: Seq[MergeAction]) extends BinaryCommand with SupportsSubquery {
|
||||
def duplicateResolved: Boolean = targetTable.outputSet.intersect(sourceTable.outputSet).isEmpty
|
||||
|
||||
  // Whether the MERGE target accepts any schema, in which case schema resolution
  // (and the assignment coercion/length checks) should be skipped. Unwraps a single
  // SubqueryAlias to reach the underlying named relation.
  def skipSchemaResolution: Boolean = targetTable match {
    case r: NamedRelation => r.skipSchemaResolution
    case SubqueryAlias(_, r: NamedRelation) => r.skipSchemaResolution
    case _ => false
  }
|
||||
|
||||
override def left: LogicalPlan = targetTable
|
||||
override def right: LogicalPlan = sourceTable
|
||||
override protected def withNewChildrenInternal(
|
||||
|
@ -466,7 +479,7 @@ case class UpdateAction(
|
|||
newChildren: IndexedSeq[Expression]): UpdateAction =
|
||||
copy(
|
||||
condition = if (condition.isDefined) Some(newChildren.head) else None,
|
||||
assignments = newChildren.tail.asInstanceOf[Seq[Assignment]])
|
||||
assignments = newChildren.takeRight(assignments.length).asInstanceOf[Seq[Assignment]])
|
||||
}
|
||||
|
||||
case class UpdateStarAction(condition: Option[Expression]) extends MergeAction {
|
||||
|
@ -485,7 +498,7 @@ case class InsertAction(
|
|||
newChildren: IndexedSeq[Expression]): InsertAction =
|
||||
copy(
|
||||
condition = if (condition.isDefined) Some(newChildren.head) else None,
|
||||
assignments = newChildren.tail.asInstanceOf[Seq[Assignment]])
|
||||
assignments = newChildren.takeRight(assignments.length).asInstanceOf[Seq[Assignment]])
|
||||
}
|
||||
|
||||
case class InsertStarAction(condition: Option[Expression]) extends MergeAction {
|
||||
|
|
|
@ -149,7 +149,7 @@ object CharVarcharUtils extends Logging {
|
|||
}.getOrElse(expr)
|
||||
}
|
||||
|
||||
private def stringLengthCheck(expr: Expression, dt: DataType): Expression = {
|
||||
def stringLengthCheck(expr: Expression, dt: DataType): Expression = {
|
||||
dt match {
|
||||
case CharType(length) =>
|
||||
StaticInvoke(
|
||||
|
|
|
@ -28,7 +28,8 @@ import org.apache.spark.sql.{AnalysisException, SaveMode}
|
|||
import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier}
|
||||
import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, Analyzer, EmptyFunctionRegistry, NoSuchTableException, ResolvedFieldName, ResolvedTable, ResolveSessionCatalog, UnresolvedAttribute, UnresolvedRelation, UnresolvedSubqueryColumnAliases, UnresolvedTable}
|
||||
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, InMemoryCatalog, SessionCatalog}
|
||||
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, InSubquery, IntegerLiteral, ListQuery, Literal, StringLiteral}
|
||||
import org.apache.spark.sql.catalyst.expressions.{AnsiCast, AttributeReference, EqualTo, Expression, InSubquery, IntegerLiteral, ListQuery, Literal, StringLiteral}
|
||||
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
|
||||
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{AlterTableAlterColumn, AnalysisOnlyCommand, AppendData, Assignment, CreateTableAsSelect, CreateTableStatement, CreateV2Table, DeleteAction, DeleteFromTable, DescribeRelation, DropTable, InsertAction, LocalRelation, LogicalPlan, MergeIntoTable, OneRowRelation, Project, SetTableLocation, SetTableProperties, ShowTableProperties, SubqueryAlias, UnsetTableProperties, UpdateAction, UpdateTable}
|
||||
import org.apache.spark.sql.catalyst.rules.Rule
|
||||
|
@ -75,6 +76,13 @@ class PlanResolutionSuite extends AnalysisTest {
|
|||
t
|
||||
}
|
||||
|
||||
  // Mocked v2 table with one char(5) and one varchar(5) column, used to exercise the
  // char/varchar length-check and coercion logic in UPDATE/MERGE resolution tests.
  private val charVarcharTable: Table = {
    val t = mock(classOf[Table])
    when(t.schema()).thenReturn(new StructType().add("c1", "char(5)").add("c2", "varchar(5)"))
    // No partitioning needed for these tests.
    when(t.partitioning()).thenReturn(Array.empty[Transform])
    t
  }
|
||||
|
||||
private val v1Table: V1Table = {
|
||||
val t = mock(classOf[CatalogTable])
|
||||
when(t.schema).thenReturn(new StructType()
|
||||
|
@ -109,6 +117,7 @@ class PlanResolutionSuite extends AnalysisTest {
|
|||
case "tab" => table
|
||||
case "tab1" => table1
|
||||
case "tab2" => table2
|
||||
case "charvarchar" => charVarcharTable
|
||||
case name => throw new NoSuchTableException(name)
|
||||
}
|
||||
})
|
||||
|
@ -1058,12 +1067,33 @@ class PlanResolutionSuite extends AnalysisTest {
|
|||
}
|
||||
}
|
||||
|
||||
val sql = "UPDATE non_existing SET id=1"
|
||||
val parsed = parseAndResolve(sql)
|
||||
parsed match {
|
||||
val sql1 = "UPDATE non_existing SET id=1"
|
||||
val parsed1 = parseAndResolve(sql1)
|
||||
parsed1 match {
|
||||
case u: UpdateTable =>
|
||||
assert(u.table.isInstanceOf[UnresolvedRelation])
|
||||
case _ => fail("Expect UpdateTable, but got:\n" + parsed.treeString)
|
||||
case _ => fail("Expect UpdateTable, but got:\n" + parsed1.treeString)
|
||||
}
|
||||
|
||||
val sql2 = "UPDATE testcat.charvarchar SET c1='a', c2=1"
|
||||
val parsed2 = parseAndResolve(sql2)
|
||||
parsed2 match {
|
||||
case u: UpdateTable =>
|
||||
assert(u.assignments.length == 2)
|
||||
u.assignments(0).value match {
|
||||
case s: StaticInvoke =>
|
||||
assert(s.arguments.length == 2)
|
||||
assert(s.functionName == "charTypeWriteSideCheck")
|
||||
case other => fail("Expect StaticInvoke, but got: " + other)
|
||||
}
|
||||
u.assignments(1).value match {
|
||||
case s: StaticInvoke =>
|
||||
assert(s.arguments.length == 2)
|
||||
assert(s.arguments.head.isInstanceOf[AnsiCast])
|
||||
assert(s.functionName == "varcharTypeWriteSideCheck")
|
||||
case other => fail("Expect StaticInvoke, but got: " + other)
|
||||
}
|
||||
case _ => fail("Expect UpdateTable, but got:\n" + parsed2.treeString)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1568,6 +1598,42 @@ class PlanResolutionSuite extends AnalysisTest {
|
|||
val e3 = intercept[AnalysisException](parseAndResolve(sql3))
|
||||
assert(e3.message.contains(
|
||||
"cannot resolve s in MERGE command given columns [testcat.tab2.i, testcat.tab2.x]"))
|
||||
|
||||
val sql4 =
|
||||
"""
|
||||
|MERGE INTO testcat.charvarchar
|
||||
|USING testcat.tab2
|
||||
|ON 1 = 1
|
||||
|WHEN MATCHED THEN UPDATE SET c1='a', c2=1
|
||||
|WHEN NOT MATCHED THEN INSERT (c1, c2) VALUES ('b', 2)
|
||||
|""".stripMargin
|
||||
val parsed4 = parseAndResolve(sql4)
|
||||
parsed4 match {
|
||||
case m: MergeIntoTable =>
|
||||
assert(m.matchedActions.length == 1)
|
||||
m.matchedActions.head match {
|
||||
case UpdateAction(_, Seq(
|
||||
Assignment(_, s1: StaticInvoke), Assignment(_, s2: StaticInvoke))) =>
|
||||
assert(s1.arguments.length == 2)
|
||||
assert(s1.functionName == "charTypeWriteSideCheck")
|
||||
assert(s2.arguments.length == 2)
|
||||
assert(s2.arguments.head.isInstanceOf[AnsiCast])
|
||||
assert(s2.functionName == "varcharTypeWriteSideCheck")
|
||||
case other => fail("Expect UpdateAction, but got: " + other)
|
||||
}
|
||||
assert(m.notMatchedActions.length == 1)
|
||||
m.notMatchedActions.head match {
|
||||
case InsertAction(_, Seq(
|
||||
Assignment(_, s1: StaticInvoke), Assignment(_, s2: StaticInvoke))) =>
|
||||
assert(s1.arguments.length == 2)
|
||||
assert(s1.functionName == "charTypeWriteSideCheck")
|
||||
assert(s2.arguments.length == 2)
|
||||
assert(s2.arguments.head.isInstanceOf[AnsiCast])
|
||||
assert(s2.functionName == "varcharTypeWriteSideCheck")
|
||||
case other => fail("Expect UpdateAction, but got: " + other)
|
||||
}
|
||||
case other => fail("Expect MergeIntoTable, but got:\n" + other.treeString)
|
||||
}
|
||||
}
|
||||
|
||||
test("MERGE INTO TABLE - skip resolution on v2 tables that accept any schema") {
|
||||
|
|
Loading…
Reference in a new issue