[SPARK-22815][SQL] Keep PromotePrecision in Optimized Plans
## What changes were proposed in this pull request? We could get incorrect results by running DecimalPrecision twice. This PR resolves the original found in https://github.com/apache/spark/pull/15048 and https://github.com/apache/spark/pull/14797. After this PR, it becomes easier to change it back using `children` instead of using `innerChildren`. ## How was this patch tested? The existing test. Author: gatorsmile <gatorsmile@gmail.com> Closes #20000 from gatorsmile/keepPromotePrecision.
This commit is contained in:
parent
28315714dd
commit
b779c93518
|
@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
|
|||
import scala.util.control.NonFatal
|
||||
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, AttributeSet, Cast, CheckOverflow, Expression, ExpressionSet, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Multiply, PreciseTimestampConversion, PredicateHelper, Subtract, TimeAdd, TimeSub, UnaryMinus}
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan}
|
||||
import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._
|
||||
|
@ -238,6 +238,8 @@ object StreamingJoinHelper extends PredicateHelper with Logging {
|
|||
collect(child, !negate)
|
||||
case CheckOverflow(child, _) =>
|
||||
collect(child, negate)
|
||||
case PromotePrecision(child) =>
|
||||
collect(child, negate)
|
||||
case Cast(child, dataType, _) =>
|
||||
dataType match {
|
||||
case _: NumericType | _: TimestampType => collect(child, negate)
|
||||
|
|
|
@ -70,10 +70,12 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un
|
|||
case class PromotePrecision(child: Expression) extends UnaryExpression {
|
||||
override def dataType: DataType = child.dataType
|
||||
override def eval(input: InternalRow): Any = child.eval(input)
|
||||
/** Just a simple pass-through for code generation. */
|
||||
override def genCode(ctx: CodegenContext): ExprCode = child.genCode(ctx)
|
||||
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev.copy("")
|
||||
override def prettyName: String = "promote_precision"
|
||||
override def sql: String = child.sql
|
||||
override lazy val canonicalized: Expression = child.canonicalized
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -614,7 +614,6 @@ object SimplifyCasts extends Rule[LogicalPlan] {
|
|||
object RemoveDispensableExpressions extends Rule[LogicalPlan] {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
|
||||
case UnaryPositive(child) => child
|
||||
case PromotePrecision(child) => child
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue