[SPARK-36247][SQL] Check string length for char/varchar and apply type coercion in UPDATE/MERGE command
### What changes were proposed in this pull request?
We added the char/varchar support in 3.1, but the string length check is only applied to INSERT, not UPDATE/MERGE. This PR fixes it. This PR also adds the missing type coercion for UPDATE/MERGE.
### Why are the changes needed?
complete the char/varchar support and make UPDATE/MERGE easier to use by doing type coercion.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
new UT. No built-in source support UPDATE/MERGE so end-to-end test is not applicable here.
Closes #33468 from cloud-fan/char.
Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit 068f8d434a
)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
6027137928
commit
14328e043d
|
@ -3300,6 +3300,41 @@ class Analyzer(override val catalogManager: CatalogManager)
|
||||||
} else {
|
} else {
|
||||||
v2Write
|
v2Write
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case u: UpdateTable if !u.skipSchemaResolution && u.resolved =>
|
||||||
|
resolveAssignments(u)
|
||||||
|
|
||||||
|
case m: MergeIntoTable if !m.skipSchemaResolution && m.resolved =>
|
||||||
|
resolveAssignments(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
private def resolveAssignments(p: LogicalPlan): LogicalPlan = {
|
||||||
|
p.transformExpressions {
|
||||||
|
case assignment: Assignment =>
|
||||||
|
val nullHandled = if (!assignment.key.nullable && assignment.value.nullable) {
|
||||||
|
AssertNotNull(assignment.value)
|
||||||
|
} else {
|
||||||
|
assignment.value
|
||||||
|
}
|
||||||
|
val casted = if (assignment.key.dataType != nullHandled.dataType) {
|
||||||
|
AnsiCast(nullHandled, assignment.key.dataType)
|
||||||
|
} else {
|
||||||
|
nullHandled
|
||||||
|
}
|
||||||
|
val rawKeyType = assignment.key.transform {
|
||||||
|
case a: AttributeReference =>
|
||||||
|
CharVarcharUtils.getRawType(a.metadata).map(a.withDataType).getOrElse(a)
|
||||||
|
}.dataType
|
||||||
|
val finalValue = if (CharVarcharUtils.hasCharVarchar(rawKeyType)) {
|
||||||
|
CharVarcharUtils.stringLengthCheck(casted, rawKeyType)
|
||||||
|
} else {
|
||||||
|
casted
|
||||||
|
}
|
||||||
|
val cleanedKey = assignment.key.transform {
|
||||||
|
case a: AttributeReference => CharVarcharUtils.cleanAttrMetadata(a)
|
||||||
|
}
|
||||||
|
Assignment(cleanedKey, finalValue)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4218,14 +4253,6 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private def padOuterRefAttrCmp(outerAttr: Attribute, attr: Attribute): Seq[Expression] = {
|
|
||||||
val Seq(r, newAttr) = CharVarcharUtils.addPaddingInStringComparison(Seq(outerAttr, attr))
|
|
||||||
val newOuterRef = r.transform {
|
|
||||||
case ar: Attribute if ar.semanticEquals(outerAttr) => OuterReference(ar)
|
|
||||||
}
|
|
||||||
Seq(newOuterRef, newAttr)
|
|
||||||
}
|
|
||||||
|
|
||||||
private def addPadding(expr: Expression, charLength: Int, targetLength: Int): Expression = {
|
private def addPadding(expr: Expression, charLength: Int, targetLength: Int): Expression = {
|
||||||
if (targetLength > charLength) StringRPad(expr, Literal(targetLength)) else expr
|
if (targetLength > charLength) StringRPad(expr, Literal(targetLength)) else expr
|
||||||
}
|
}
|
||||||
|
|
|
@ -168,6 +168,7 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Un
|
||||||
override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName)
|
override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName)
|
||||||
override def withMetadata(newMetadata: Metadata): Attribute = this
|
override def withMetadata(newMetadata: Metadata): Attribute = this
|
||||||
override def withExprId(newExprId: ExprId): UnresolvedAttribute = this
|
override def withExprId(newExprId: ExprId): UnresolvedAttribute = this
|
||||||
|
override def withDataType(newType: DataType): Attribute = this
|
||||||
final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_ATTRIBUTE)
|
final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_ATTRIBUTE)
|
||||||
|
|
||||||
override def toString: String = s"'$name"
|
override def toString: String = s"'$name"
|
||||||
|
|
|
@ -123,6 +123,7 @@ abstract class Attribute extends LeafExpression with NamedExpression with NullIn
|
||||||
def withName(newName: String): Attribute
|
def withName(newName: String): Attribute
|
||||||
def withMetadata(newMetadata: Metadata): Attribute
|
def withMetadata(newMetadata: Metadata): Attribute
|
||||||
def withExprId(newExprId: ExprId): Attribute
|
def withExprId(newExprId: ExprId): Attribute
|
||||||
|
def withDataType(newType: DataType): Attribute
|
||||||
|
|
||||||
override def toAttribute: Attribute = this
|
override def toAttribute: Attribute = this
|
||||||
def newInstance(): Attribute
|
def newInstance(): Attribute
|
||||||
|
@ -339,6 +340,10 @@ case class AttributeReference(
|
||||||
AttributeReference(name, dataType, nullable, newMetadata)(exprId, qualifier)
|
AttributeReference(name, dataType, nullable, newMetadata)(exprId, qualifier)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def withDataType(newType: DataType): Attribute = {
|
||||||
|
AttributeReference(name, newType, nullable, metadata)(exprId, qualifier)
|
||||||
|
}
|
||||||
|
|
||||||
override protected final def otherCopyArgs: Seq[AnyRef] = {
|
override protected final def otherCopyArgs: Seq[AnyRef] = {
|
||||||
exprId :: qualifier :: Nil
|
exprId :: qualifier :: Nil
|
||||||
}
|
}
|
||||||
|
@ -395,6 +400,8 @@ case class PrettyAttribute(
|
||||||
override def exprId: ExprId = throw new UnsupportedOperationException
|
override def exprId: ExprId = throw new UnsupportedOperationException
|
||||||
override def withExprId(newExprId: ExprId): Attribute =
|
override def withExprId(newExprId: ExprId): Attribute =
|
||||||
throw new UnsupportedOperationException
|
throw new UnsupportedOperationException
|
||||||
|
override def withDataType(newType: DataType): Attribute =
|
||||||
|
throw new UnsupportedOperationException
|
||||||
override def nullable: Boolean = true
|
override def nullable: Boolean = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -425,6 +425,12 @@ case class UpdateTable(
|
||||||
override def child: LogicalPlan = table
|
override def child: LogicalPlan = table
|
||||||
override protected def withNewChildInternal(newChild: LogicalPlan): UpdateTable =
|
override protected def withNewChildInternal(newChild: LogicalPlan): UpdateTable =
|
||||||
copy(table = newChild)
|
copy(table = newChild)
|
||||||
|
|
||||||
|
def skipSchemaResolution: Boolean = table match {
|
||||||
|
case r: NamedRelation => r.skipSchemaResolution
|
||||||
|
case SubqueryAlias(_, r: NamedRelation) => r.skipSchemaResolution
|
||||||
|
case _ => false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -437,6 +443,13 @@ case class MergeIntoTable(
|
||||||
matchedActions: Seq[MergeAction],
|
matchedActions: Seq[MergeAction],
|
||||||
notMatchedActions: Seq[MergeAction]) extends BinaryCommand with SupportsSubquery {
|
notMatchedActions: Seq[MergeAction]) extends BinaryCommand with SupportsSubquery {
|
||||||
def duplicateResolved: Boolean = targetTable.outputSet.intersect(sourceTable.outputSet).isEmpty
|
def duplicateResolved: Boolean = targetTable.outputSet.intersect(sourceTable.outputSet).isEmpty
|
||||||
|
|
||||||
|
def skipSchemaResolution: Boolean = targetTable match {
|
||||||
|
case r: NamedRelation => r.skipSchemaResolution
|
||||||
|
case SubqueryAlias(_, r: NamedRelation) => r.skipSchemaResolution
|
||||||
|
case _ => false
|
||||||
|
}
|
||||||
|
|
||||||
override def left: LogicalPlan = targetTable
|
override def left: LogicalPlan = targetTable
|
||||||
override def right: LogicalPlan = sourceTable
|
override def right: LogicalPlan = sourceTable
|
||||||
override protected def withNewChildrenInternal(
|
override protected def withNewChildrenInternal(
|
||||||
|
@ -466,7 +479,7 @@ case class UpdateAction(
|
||||||
newChildren: IndexedSeq[Expression]): UpdateAction =
|
newChildren: IndexedSeq[Expression]): UpdateAction =
|
||||||
copy(
|
copy(
|
||||||
condition = if (condition.isDefined) Some(newChildren.head) else None,
|
condition = if (condition.isDefined) Some(newChildren.head) else None,
|
||||||
assignments = newChildren.tail.asInstanceOf[Seq[Assignment]])
|
assignments = newChildren.takeRight(assignments.length).asInstanceOf[Seq[Assignment]])
|
||||||
}
|
}
|
||||||
|
|
||||||
case class UpdateStarAction(condition: Option[Expression]) extends MergeAction {
|
case class UpdateStarAction(condition: Option[Expression]) extends MergeAction {
|
||||||
|
@ -485,7 +498,7 @@ case class InsertAction(
|
||||||
newChildren: IndexedSeq[Expression]): InsertAction =
|
newChildren: IndexedSeq[Expression]): InsertAction =
|
||||||
copy(
|
copy(
|
||||||
condition = if (condition.isDefined) Some(newChildren.head) else None,
|
condition = if (condition.isDefined) Some(newChildren.head) else None,
|
||||||
assignments = newChildren.tail.asInstanceOf[Seq[Assignment]])
|
assignments = newChildren.takeRight(assignments.length).asInstanceOf[Seq[Assignment]])
|
||||||
}
|
}
|
||||||
|
|
||||||
case class InsertStarAction(condition: Option[Expression]) extends MergeAction {
|
case class InsertStarAction(condition: Option[Expression]) extends MergeAction {
|
||||||
|
|
|
@ -149,7 +149,7 @@ object CharVarcharUtils extends Logging {
|
||||||
}.getOrElse(expr)
|
}.getOrElse(expr)
|
||||||
}
|
}
|
||||||
|
|
||||||
private def stringLengthCheck(expr: Expression, dt: DataType): Expression = {
|
def stringLengthCheck(expr: Expression, dt: DataType): Expression = {
|
||||||
dt match {
|
dt match {
|
||||||
case CharType(length) =>
|
case CharType(length) =>
|
||||||
StaticInvoke(
|
StaticInvoke(
|
||||||
|
|
|
@ -28,7 +28,8 @@ import org.apache.spark.sql.{AnalysisException, SaveMode}
|
||||||
import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier}
|
import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier}
|
||||||
import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, Analyzer, EmptyFunctionRegistry, NoSuchTableException, ResolvedFieldName, ResolvedTable, ResolveSessionCatalog, UnresolvedAttribute, UnresolvedRelation, UnresolvedSubqueryColumnAliases, UnresolvedTable}
|
import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, Analyzer, EmptyFunctionRegistry, NoSuchTableException, ResolvedFieldName, ResolvedTable, ResolveSessionCatalog, UnresolvedAttribute, UnresolvedRelation, UnresolvedSubqueryColumnAliases, UnresolvedTable}
|
||||||
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, InMemoryCatalog, SessionCatalog}
|
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, InMemoryCatalog, SessionCatalog}
|
||||||
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, InSubquery, IntegerLiteral, ListQuery, Literal, StringLiteral}
|
import org.apache.spark.sql.catalyst.expressions.{AnsiCast, AttributeReference, EqualTo, Expression, InSubquery, IntegerLiteral, ListQuery, Literal, StringLiteral}
|
||||||
|
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
|
||||||
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}
|
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}
|
||||||
import org.apache.spark.sql.catalyst.plans.logical.{AlterTableAlterColumn, AnalysisOnlyCommand, AppendData, Assignment, CreateTableAsSelect, CreateTableStatement, CreateV2Table, DeleteAction, DeleteFromTable, DescribeRelation, DropTable, InsertAction, LocalRelation, LogicalPlan, MergeIntoTable, OneRowRelation, Project, SetTableLocation, SetTableProperties, ShowTableProperties, SubqueryAlias, UnsetTableProperties, UpdateAction, UpdateTable}
|
import org.apache.spark.sql.catalyst.plans.logical.{AlterTableAlterColumn, AnalysisOnlyCommand, AppendData, Assignment, CreateTableAsSelect, CreateTableStatement, CreateV2Table, DeleteAction, DeleteFromTable, DescribeRelation, DropTable, InsertAction, LocalRelation, LogicalPlan, MergeIntoTable, OneRowRelation, Project, SetTableLocation, SetTableProperties, ShowTableProperties, SubqueryAlias, UnsetTableProperties, UpdateAction, UpdateTable}
|
||||||
import org.apache.spark.sql.catalyst.rules.Rule
|
import org.apache.spark.sql.catalyst.rules.Rule
|
||||||
|
@ -75,6 +76,13 @@ class PlanResolutionSuite extends AnalysisTest {
|
||||||
t
|
t
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private val charVarcharTable: Table = {
|
||||||
|
val t = mock(classOf[Table])
|
||||||
|
when(t.schema()).thenReturn(new StructType().add("c1", "char(5)").add("c2", "varchar(5)"))
|
||||||
|
when(t.partitioning()).thenReturn(Array.empty[Transform])
|
||||||
|
t
|
||||||
|
}
|
||||||
|
|
||||||
private val v1Table: V1Table = {
|
private val v1Table: V1Table = {
|
||||||
val t = mock(classOf[CatalogTable])
|
val t = mock(classOf[CatalogTable])
|
||||||
when(t.schema).thenReturn(new StructType()
|
when(t.schema).thenReturn(new StructType()
|
||||||
|
@ -109,6 +117,7 @@ class PlanResolutionSuite extends AnalysisTest {
|
||||||
case "tab" => table
|
case "tab" => table
|
||||||
case "tab1" => table1
|
case "tab1" => table1
|
||||||
case "tab2" => table2
|
case "tab2" => table2
|
||||||
|
case "charvarchar" => charVarcharTable
|
||||||
case name => throw new NoSuchTableException(name)
|
case name => throw new NoSuchTableException(name)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -1058,12 +1067,33 @@ class PlanResolutionSuite extends AnalysisTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
val sql = "UPDATE non_existing SET id=1"
|
val sql1 = "UPDATE non_existing SET id=1"
|
||||||
val parsed = parseAndResolve(sql)
|
val parsed1 = parseAndResolve(sql1)
|
||||||
parsed match {
|
parsed1 match {
|
||||||
case u: UpdateTable =>
|
case u: UpdateTable =>
|
||||||
assert(u.table.isInstanceOf[UnresolvedRelation])
|
assert(u.table.isInstanceOf[UnresolvedRelation])
|
||||||
case _ => fail("Expect UpdateTable, but got:\n" + parsed.treeString)
|
case _ => fail("Expect UpdateTable, but got:\n" + parsed1.treeString)
|
||||||
|
}
|
||||||
|
|
||||||
|
val sql2 = "UPDATE testcat.charvarchar SET c1='a', c2=1"
|
||||||
|
val parsed2 = parseAndResolve(sql2)
|
||||||
|
parsed2 match {
|
||||||
|
case u: UpdateTable =>
|
||||||
|
assert(u.assignments.length == 2)
|
||||||
|
u.assignments(0).value match {
|
||||||
|
case s: StaticInvoke =>
|
||||||
|
assert(s.arguments.length == 2)
|
||||||
|
assert(s.functionName == "charTypeWriteSideCheck")
|
||||||
|
case other => fail("Expect StaticInvoke, but got: " + other)
|
||||||
|
}
|
||||||
|
u.assignments(1).value match {
|
||||||
|
case s: StaticInvoke =>
|
||||||
|
assert(s.arguments.length == 2)
|
||||||
|
assert(s.arguments.head.isInstanceOf[AnsiCast])
|
||||||
|
assert(s.functionName == "varcharTypeWriteSideCheck")
|
||||||
|
case other => fail("Expect StaticInvoke, but got: " + other)
|
||||||
|
}
|
||||||
|
case _ => fail("Expect UpdateTable, but got:\n" + parsed2.treeString)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1568,6 +1598,42 @@ class PlanResolutionSuite extends AnalysisTest {
|
||||||
val e3 = intercept[AnalysisException](parseAndResolve(sql3))
|
val e3 = intercept[AnalysisException](parseAndResolve(sql3))
|
||||||
assert(e3.message.contains(
|
assert(e3.message.contains(
|
||||||
"cannot resolve s in MERGE command given columns [testcat.tab2.i, testcat.tab2.x]"))
|
"cannot resolve s in MERGE command given columns [testcat.tab2.i, testcat.tab2.x]"))
|
||||||
|
|
||||||
|
val sql4 =
|
||||||
|
"""
|
||||||
|
|MERGE INTO testcat.charvarchar
|
||||||
|
|USING testcat.tab2
|
||||||
|
|ON 1 = 1
|
||||||
|
|WHEN MATCHED THEN UPDATE SET c1='a', c2=1
|
||||||
|
|WHEN NOT MATCHED THEN INSERT (c1, c2) VALUES ('b', 2)
|
||||||
|
|""".stripMargin
|
||||||
|
val parsed4 = parseAndResolve(sql4)
|
||||||
|
parsed4 match {
|
||||||
|
case m: MergeIntoTable =>
|
||||||
|
assert(m.matchedActions.length == 1)
|
||||||
|
m.matchedActions.head match {
|
||||||
|
case UpdateAction(_, Seq(
|
||||||
|
Assignment(_, s1: StaticInvoke), Assignment(_, s2: StaticInvoke))) =>
|
||||||
|
assert(s1.arguments.length == 2)
|
||||||
|
assert(s1.functionName == "charTypeWriteSideCheck")
|
||||||
|
assert(s2.arguments.length == 2)
|
||||||
|
assert(s2.arguments.head.isInstanceOf[AnsiCast])
|
||||||
|
assert(s2.functionName == "varcharTypeWriteSideCheck")
|
||||||
|
case other => fail("Expect UpdateAction, but got: " + other)
|
||||||
|
}
|
||||||
|
assert(m.notMatchedActions.length == 1)
|
||||||
|
m.notMatchedActions.head match {
|
||||||
|
case InsertAction(_, Seq(
|
||||||
|
Assignment(_, s1: StaticInvoke), Assignment(_, s2: StaticInvoke))) =>
|
||||||
|
assert(s1.arguments.length == 2)
|
||||||
|
assert(s1.functionName == "charTypeWriteSideCheck")
|
||||||
|
assert(s2.arguments.length == 2)
|
||||||
|
assert(s2.arguments.head.isInstanceOf[AnsiCast])
|
||||||
|
assert(s2.functionName == "varcharTypeWriteSideCheck")
|
||||||
|
case other => fail("Expect UpdateAction, but got: " + other)
|
||||||
|
}
|
||||||
|
case other => fail("Expect MergeIntoTable, but got:\n" + other.treeString)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
test("MERGE INTO TABLE - skip resolution on v2 tables that accept any schema") {
|
test("MERGE INTO TABLE - skip resolution on v2 tables that accept any schema") {
|
||||||
|
|
Loading…
Reference in a new issue