[SPARK-35857][SQL] The ANSI flag of Cast should be kept after being copied

### What changes were proposed in this pull request?

Make the ANSI flag part of the expression `Cast`'s parameter list, instead of fetching it from the session SQLConf.

### Why are the changes needed?

For views, it is important to show consistent results even when the ANSI configuration differs in the running session. This is why many expressions, such as `Add` and `Divide`, make the ANSI flag part of their case class parameter lists.

We should make the expression `Cast` consistent with them.
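
As a minimal, self-contained sketch of the mechanism (hypothetical `CastBefore`, `CastAfter`, and `Session` names, not the real Spark classes): a case class `copy()` reuses the stored constructor arguments, whereas a body `val` is re-evaluated from the current session whenever a new instance is constructed.

```scala
// Hypothetical stand-in for the session-level SQLConf flag.
object Session {
  var ansiEnabled: Boolean = true
}

// Before (simplified): the flag is a body val, read from the session at construction time.
case class CastBefore(child: String, dataType: String) {
  val ansiEnabled: Boolean = Session.ansiEnabled
}

// After (simplified): the flag is a constructor parameter with a session default,
// so it participates in copy(), equality, and tree transformations.
case class CastAfter(
    child: String,
    dataType: String,
    ansiEnabled: Boolean = Session.ansiEnabled)

object AnsiFlagSketch {
  def main(args: Array[String]): Unit = {
    val before = CastBefore("col", "int")
    val after = CastAfter("col", "int")

    Session.ansiEnabled = false // the running session switches ANSI mode off

    println(before.copy(child = "other").ansiEnabled) // false: the copy re-reads the session flag
    println(after.copy(child = "other").ansiEnabled)  // true: the copy keeps the captured flag
  }
}
```

The same applies when the analyzer or optimizer rebuilds an expression tree via `transform`/`withNewChildren`: with the flag as a constructor parameter, the value captured at analysis time travels with the `Cast` node.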

### Does this PR introduce _any_ user-facing change?

Yes. A `Cast` inside a view always behaves the same, independent of the ANSI mode SQL configuration in the current session.
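
For example, a rough expression-level illustration (a sketch assuming the Catalyst classes are on the classpath, e.g. in a spark-shell or catalyst test scope; not a test added by this PR):

```scala
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.types.IntegerType

// Pin the ANSI flag explicitly when building the expression.
val c = Cast(Literal("123"), IntegerType, None, ansiEnabled = true)

// With this patch, copy() keeps the flag instead of re-reading spark.sql.ansi.enabled
// from the current session, so the copy equals an explicitly ANSI-enabled Cast.
val copied = c.copy(child = Literal("456"))
assert(copied == Cast(Literal("456"), IntegerType, None, ansiEnabled = true))
```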

### How was this patch tested?

Existing UT

Closes #33027 from gengliangwang/ansiFlagInCast.

Authored-by: Gengliang Wang <gengliang@apache.org>
Signed-off-by: Gengliang Wang <gengliang@apache.org>
Gengliang Wang 2021-06-23 16:52:33 +08:00
parent 758b423a31
commit 6f51e37eb5
10 changed files with 25 additions and 20 deletions


@@ -431,7 +431,7 @@ class Analyzer(override val catalogManager: CatalogManager)
       case go @ GeneratorOuter(g: Generator) if g.resolved => MultiAlias(go, Nil)
       case e if !e.resolved => u
       case g: Generator => MultiAlias(g, Nil)
-      case c @ Cast(ne: NamedExpression, _, _) => Alias(c, ne.name)()
+      case c @ Cast(ne: NamedExpression, _, _, _) => Alias(c, ne.name)()
       case e: ExtractValue => Alias(e, toPrettySQL(e))()
       case e if optGenAliasFunc.isDefined =>
         Alias(child, optGenAliasFunc.get.apply(e))()


@@ -239,7 +239,7 @@ object StreamingJoinHelper extends PredicateHelper with Logging {
           collect(child, negate)
         case PromotePrecision(child) =>
           collect(child, negate)
-        case Cast(child, dataType, _) =>
+        case Cast(child, dataType, _, _) =>
           dataType match {
             case _: NumericType | _: TimestampType => collect(child, negate)
             case _ =>


@@ -1945,16 +1945,21 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
   """,
   since = "1.0.0",
   group = "conversion_funcs")
-case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String] = None)
+case class Cast(
+    child: Expression,
+    dataType: DataType,
+    timeZoneId: Option[String] = None,
+    override val ansiEnabled: Boolean = SQLConf.get.ansiEnabled)
   extends CastBase {

+  def this(child: Expression, dataType: DataType, timeZoneId: Option[String]) =
+    this(child, dataType, timeZoneId, ansiEnabled = SQLConf.get.ansiEnabled)
+
   override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
     copy(timeZoneId = Option(timeZoneId))

   final override def nodePatternsInternal(): Seq[TreePattern] = Seq(CAST)

-  override protected val ansiEnabled: Boolean = SQLConf.get.ansiEnabled
-
   override def canCast(from: DataType, to: DataType): Boolean = if (ansiEnabled) {
     AnsiCast.canCast(from, to)
   } else {
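
A brief hedged sketch of the two construction paths the hunk above leaves available (illustrative usage, not code from this PR):

```scala
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.types.LongType

// Case-class apply: the flag is pinned explicitly and survives copy()/transform.
val pinned = Cast(Literal("1"), LongType, None, ansiEnabled = true)

// Auxiliary 3-arg constructor kept by the hunk above: falls back to the session's
// spark.sql.ansi.enabled value at construction time, preserving source compatibility.
val sessionDefault = new Cast(Literal("1"), LongType, None)
```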


@@ -113,7 +113,7 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
     // Not a canonical form. In this case we first canonicalize the expression by swapping the
     // literal and cast side, then process the result and swap the literal and cast again to
     // restore the original order.
-    case BinaryComparison(Literal(_, literalType), Cast(fromExp, toType, _))
+    case BinaryComparison(Literal(_, literalType), Cast(fromExp, toType, _, _))
       if canImplicitlyCast(fromExp, toType, literalType) =>
       def swap(e: Expression): Expression = e match {
         case GreaterThan(left, right) => LessThan(right, left)
@@ -130,7 +130,7 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
     // In case both sides have numeric type, optimize the comparison by removing casts or
     // moving cast to the literal side.
     case be @ BinaryComparison(
-        Cast(fromExp, toType: NumericType, _), Literal(value, literalType))
+        Cast(fromExp, toType: NumericType, _, _), Literal(value, literalType))
       if canImplicitlyCast(fromExp, toType, literalType) =>
       simplifyNumericComparison(be, fromExp, toType, value)
@@ -141,7 +141,7 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
     // values.
     // 2. this rule only handles the case when both `fromExp` and value in `in.list` are of numeric
     // type.
-    case in @ In(Cast(fromExp, toType: NumericType, _), list @ Seq(firstLit, _*))
+    case in @ In(Cast(fromExp, toType: NumericType, _, _), list @ Seq(firstLit, _*))
       if canImplicitlyCast(fromExp, toType, firstLit.dataType) =>

       // There are 3 kinds of literals in the list:
@@ -184,7 +184,7 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
     // The same with `In` expression, the analyzer makes sure that the hset of InSet is already of
     // the same data type, so simply check `fromExp.dataType` can implicitly cast to `toType` and
     // both `fromExp.dataType` and `toType` is numeric type or not.
-    case inSet @ InSet(Cast(fromExp, toType: NumericType, _), hset)
+    case inSet @ InSet(Cast(fromExp, toType: NumericType, _, _), hset)
       if hset.nonEmpty && canImplicitlyCast(fromExp, toType, toType) =>

       // The same with `In`, there are 3 kinds of literals in the hset:


@@ -757,7 +757,7 @@ object NullPropagation extends Rule[LogicalPlan] {
     case q: LogicalPlan => q.transformExpressionsUpWithPruning(
       t => t.containsAnyPattern(NULL_CHECK, NULL_LITERAL, COUNT)
         || t.containsAllPatterns(WINDOW_EXPRESSION, CAST, LITERAL), ruleId) {
-      case e @ WindowExpression(Cast(Literal(0L, _), _, _), _) =>
+      case e @ WindowExpression(Cast(Literal(0L, _), _, _, _), _) =>
        Cast(Literal(0L), e.dataType, Option(conf.sessionLocalTimeZone))
       case e @ AggregateExpression(Count(exprs), _, _, _, _) if exprs.forall(isNullLiteral) =>
        Cast(Literal(0L), e.dataType, Option(conf.sessionLocalTimeZone))
@@ -934,8 +934,8 @@ object FoldablePropagation extends Rule[LogicalPlan] {
 object SimplifyCasts extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning(
     _.containsPattern(CAST), ruleId) {
-    case Cast(e, dataType, _) if e.dataType == dataType => e
-    case c @ Cast(e, dataType, _) => (e.dataType, dataType) match {
+    case Cast(e, dataType, _, _) if e.dataType == dataType => e
+    case c @ Cast(e, dataType, _, _) => (e.dataType, dataType) match {
       case (ArrayType(from, false), ArrayType(to, true)) if from == to => e
       case (MapType(fromKey, fromValue, false), MapType(toKey, toValue, true))
         if fromKey == toKey && fromValue == toValue => e
@@ -989,7 +989,7 @@ object CombineConcats extends Rule[LogicalPlan] {
       // If `spark.sql.function.concatBinaryAsString` is false, nested `Concat` exprs possibly
       // have `Concat`s with binary output. Since `TypeCoercion` casts them into strings,
       // we need to handle the case to combine all nested `Concat`s.
-      case c @ Cast(Concat(children), StringType, _) =>
+      case c @ Cast(Concat(children), StringType, _, _) =>
        val newChildren = children.map { e => c.copy(child = e) }
        stack.pushAll(newChildren.reverse)
       case child =>
@@ -1001,7 +1001,7 @@ object CombineConcats extends Rule[LogicalPlan] {
   private def hasNestedConcats(concat: Concat): Boolean = concat.children.exists {
     case c: Concat => true
-    case c @ Cast(Concat(children), StringType, _) => true
+    case c @ Cast(Concat(children), StringType, _, _) => true
     case _ => false
   }


@@ -67,9 +67,9 @@ trait ConstraintHelper {
         val candidateConstraints = predicates - eq
         inferredConstraints ++= replaceConstraints(candidateConstraints, l, r)
         inferredConstraints ++= replaceConstraints(candidateConstraints, r, l)
-      case eq @ EqualTo(l @ Cast(_: Attribute, _, _), r: Attribute) =>
+      case eq @ EqualTo(l @ Cast(_: Attribute, _, _, _), r: Attribute) =>
         inferredConstraints ++= replaceConstraints(predicates - eq, r, l)
-      case eq @ EqualTo(l: Attribute, r @ Cast(_: Attribute, _, _)) =>
+      case eq @ EqualTo(l: Attribute, r @ Cast(_: Attribute, _, _, _)) =>
         inferredConstraints ++= replaceConstraints(predicates - eq, l, r)
       case _ => // No inference
     }


@@ -179,7 +179,7 @@ class Column(val expr: Expression) extends Logging {
     // NamedExpression under this Cast.
     case c: Cast =>
       c.transformUp {
-        case c @ Cast(_: NamedExpression, _, _) => UnresolvedAlias(c)
+        case c @ Cast(_: NamedExpression, _, _, _) => UnresolvedAlias(c)
       } match {
         case ne: NamedExpression => ne
         case _ => UnresolvedAlias(expr, Some(Column.generateAlias))


@@ -52,7 +52,7 @@ case class SubqueryBroadcastExec(
       val key = buildKeys(index)
       val name = key match {
         case n: NamedExpression => n.name
-        case Cast(n: NamedExpression, _, _) => n.name
+        case Cast(n: NamedExpression, _, _, _) => n.name
         case _ => "key"
       }
       Seq(AttributeReference(name, key.dataType, key.nullable)())


@@ -65,7 +65,7 @@ object DetectAmbiguousSelfJoin extends Rule[LogicalPlan] {
   object AttrWithCast {
     def unapply(expr: Expression): Option[AttributeReference] = expr match {
-      case Cast(child, _, _) => unapply(child)
+      case Cast(child, _, _, _) => unapply(child)
       case a: AttributeReference => Some(a)
       case _ => None
     }


@@ -760,7 +760,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
     def unapply(expr: Expression): Option[Attribute] = {
       expr match {
         case attr: Attribute => Some(attr)
-        case Cast(child @ IntegralType(), dt: IntegralType, _)
+        case Cast(child @ IntegralType(), dt: IntegralType, _, _)
           if Cast.canUpCast(child.dataType.asInstanceOf[AtomicType], dt) => unapply(child)
         case _ => None
       }