[SPARK-19727][SQL][FOLLOWUP] Fix for round function that modifies original column
## What changes were proposed in this pull request? This is a followup of https://github.com/apache/spark/pull/17075 , to fix the bug in codegen path. ## How was this patch tested? new regression test Author: Wenchen Fan <wenchen@databricks.com> Closes #19576 from cloud-fan/bug.
This commit is contained in:
parent
e80da8129a
commit
7fdacbc77b
|
@ -310,7 +310,7 @@ object CatalystTypeConverters {
|
|||
case d: JavaBigInteger => Decimal(d)
|
||||
case d: Decimal => d
|
||||
}
|
||||
decimal.toPrecision(dataType.precision, dataType.scale).orNull
|
||||
decimal.toPrecision(dataType.precision, dataType.scale)
|
||||
}
|
||||
override def toScala(catalystValue: Decimal): JavaBigDecimal = {
|
||||
if (catalystValue == null) null
|
||||
|
|
|
@ -387,10 +387,9 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
|
|||
/**
|
||||
* Create new `Decimal` with precision and scale given in `decimalType` (if any),
|
||||
* returning null if it overflows or creating a new `value` and returning it if successful.
|
||||
*
|
||||
*/
|
||||
private[this] def toPrecision(value: Decimal, decimalType: DecimalType): Decimal =
|
||||
value.toPrecision(decimalType.precision, decimalType.scale).orNull
|
||||
value.toPrecision(decimalType.precision, decimalType.scale)
|
||||
|
||||
|
||||
private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match {
|
||||
|
|
|
@ -85,7 +85,7 @@ case class CheckOverflow(child: Expression, dataType: DecimalType) extends Unary
|
|||
override def nullable: Boolean = true
|
||||
|
||||
override def nullSafeEval(input: Any): Any =
|
||||
input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale).orNull
|
||||
input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale)
|
||||
|
||||
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
nullSafeCodeGen(ctx, ev, eval => {
|
||||
|
|
|
@ -1044,7 +1044,7 @@ abstract class RoundBase(child: Expression, scale: Expression,
|
|||
dataType match {
|
||||
case DecimalType.Fixed(_, s) =>
|
||||
val decimal = input1.asInstanceOf[Decimal]
|
||||
decimal.toPrecision(decimal.precision, s, mode).orNull
|
||||
decimal.toPrecision(decimal.precision, s, mode)
|
||||
case ByteType =>
|
||||
BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte
|
||||
case ShortType =>
|
||||
|
@ -1076,12 +1076,8 @@ abstract class RoundBase(child: Expression, scale: Expression,
|
|||
val evaluationCode = dataType match {
|
||||
case DecimalType.Fixed(_, s) =>
|
||||
s"""
|
||||
if (${ce.value}.changePrecision(${ce.value}.precision(), ${s},
|
||||
java.math.BigDecimal.${modeStr})) {
|
||||
${ev.value} = ${ce.value};
|
||||
} else {
|
||||
${ev.isNull} = true;
|
||||
}"""
|
||||
${ev.value} = ${ce.value}.toPrecision(${ce.value}.precision(), $s, Decimal.$modeStr());
|
||||
${ev.isNull} = ${ev.value} == null;"""
|
||||
case ByteType =>
|
||||
if (_scale < 0) {
|
||||
s"""
|
||||
|
|
|
@ -234,22 +234,17 @@ final class Decimal extends Ordered[Decimal] with Serializable {
|
|||
changePrecision(precision, scale, ROUND_HALF_UP)
|
||||
}
|
||||
|
||||
def changePrecision(precision: Int, scale: Int, mode: Int): Boolean = mode match {
|
||||
case java.math.BigDecimal.ROUND_HALF_UP => changePrecision(precision, scale, ROUND_HALF_UP)
|
||||
case java.math.BigDecimal.ROUND_HALF_EVEN => changePrecision(precision, scale, ROUND_HALF_EVEN)
|
||||
}
|
||||
|
||||
/**
|
||||
* Create new `Decimal` with given precision and scale.
|
||||
*
|
||||
* @return `Some(decimal)` if successful or `None` if overflow would occur
|
||||
* @return a non-null `Decimal` value if successful or `null` if overflow would occur.
|
||||
*/
|
||||
private[sql] def toPrecision(
|
||||
precision: Int,
|
||||
scale: Int,
|
||||
roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP): Option[Decimal] = {
|
||||
roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP): Decimal = {
|
||||
val copy = clone()
|
||||
if (copy.changePrecision(precision, scale, roundMode)) Some(copy) else None
|
||||
if (copy.changePrecision(precision, scale, roundMode)) copy else null
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -257,8 +252,10 @@ final class Decimal extends Ordered[Decimal] with Serializable {
|
|||
*
|
||||
* @return true if successful, false if overflow would occur
|
||||
*/
|
||||
private[sql] def changePrecision(precision: Int, scale: Int,
|
||||
roundMode: BigDecimal.RoundingMode.Value): Boolean = {
|
||||
private[sql] def changePrecision(
|
||||
precision: Int,
|
||||
scale: Int,
|
||||
roundMode: BigDecimal.RoundingMode.Value): Boolean = {
|
||||
// fast path for UnsafeProjection
|
||||
if (precision == this.precision && scale == this.scale) {
|
||||
return true
|
||||
|
@ -393,14 +390,20 @@ final class Decimal extends Ordered[Decimal] with Serializable {
|
|||
|
||||
def floor: Decimal = if (scale == 0) this else {
|
||||
val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision
|
||||
toPrecision(newPrecision, 0, ROUND_FLOOR).getOrElse(
|
||||
throw new AnalysisException(s"Overflow when setting precision to $newPrecision"))
|
||||
val res = toPrecision(newPrecision, 0, ROUND_FLOOR)
|
||||
if (res == null) {
|
||||
throw new AnalysisException(s"Overflow when setting precision to $newPrecision")
|
||||
}
|
||||
res
|
||||
}
|
||||
|
||||
def ceil: Decimal = if (scale == 0) this else {
|
||||
val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision
|
||||
toPrecision(newPrecision, 0, ROUND_CEILING).getOrElse(
|
||||
throw new AnalysisException(s"Overflow when setting precision to $newPrecision"))
|
||||
val res = toPrecision(newPrecision, 0, ROUND_CEILING)
|
||||
if (res == null) {
|
||||
throw new AnalysisException(s"Overflow when setting precision to $newPrecision")
|
||||
}
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -213,7 +213,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
|
|||
assert(d.changePrecision(10, 0, mode))
|
||||
assert(d.toString === bd.setScale(0, mode).toString(), s"num: $sign$n, mode: $mode")
|
||||
|
||||
val copy = d.toPrecision(10, 0, mode).orNull
|
||||
val copy = d.toPrecision(10, 0, mode)
|
||||
assert(copy !== null)
|
||||
assert(d.ne(copy))
|
||||
assert(d === copy)
|
||||
|
|
|
@ -258,6 +258,18 @@ class MathFunctionsSuite extends QueryTest with SharedSQLContext {
|
|||
)
|
||||
}
|
||||
|
||||
test("round/bround with table columns") {
|
||||
withTable("t") {
|
||||
Seq(BigDecimal("5.9")).toDF("i").write.saveAsTable("t")
|
||||
checkAnswer(
|
||||
sql("select i, round(i) from t"),
|
||||
Seq(Row(BigDecimal("5.9"), BigDecimal("6"))))
|
||||
checkAnswer(
|
||||
sql("select i, bround(i) from t"),
|
||||
Seq(Row(BigDecimal("5.9"), BigDecimal("6"))))
|
||||
}
|
||||
}
|
||||
|
||||
test("exp") {
|
||||
testOneToOneMathFunction(exp, math.exp)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue