[SPARK-23179][SQL] Support option to throw exception if overflow occurs during Decimal arithmetic
## What changes were proposed in this pull request? SQL ANSI 2011 states that in case of overflow during arithmetic operations, an exception should be thrown. This is what most of the SQL DBs do (eg. SQLServer, DB2). Hive currently returns NULL (as Spark does) but HIVE-18291 is open to be SQL compliant. The PR introduce an option to decide which behavior Spark should follow, ie. returning NULL on overflow or throwing an exception. ## How was this patch tested? added UTs Closes #20350 from mgaido91/SPARK-23179. Authored-by: Marco Gaido <marcogaido91@gmail.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
7cbe01e8ef
commit
3139d642fa
|
@ -82,6 +82,8 @@ object DecimalPrecision extends TypeCoercionRule {
|
|||
PromotePrecision(Cast(e, dataType))
|
||||
}
|
||||
|
||||
private def nullOnOverflow: Boolean = SQLConf.get.decimalOperationsNullOnOverflow
|
||||
|
||||
override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
|
||||
// fix decimal precision for expressions
|
||||
case q => q.transformExpressionsUp(
|
||||
|
@ -105,7 +107,7 @@ object DecimalPrecision extends TypeCoercionRule {
|
|||
DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale)
|
||||
}
|
||||
CheckOverflow(Add(promotePrecision(e1, resultType), promotePrecision(e2, resultType)),
|
||||
resultType)
|
||||
resultType, nullOnOverflow)
|
||||
|
||||
case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
|
||||
val resultScale = max(s1, s2)
|
||||
|
@ -116,7 +118,7 @@ object DecimalPrecision extends TypeCoercionRule {
|
|||
DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale)
|
||||
}
|
||||
CheckOverflow(Subtract(promotePrecision(e1, resultType), promotePrecision(e2, resultType)),
|
||||
resultType)
|
||||
resultType, nullOnOverflow)
|
||||
|
||||
case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
|
||||
val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
|
||||
|
@ -126,7 +128,7 @@ object DecimalPrecision extends TypeCoercionRule {
|
|||
}
|
||||
val widerType = widerDecimalType(p1, s1, p2, s2)
|
||||
CheckOverflow(Multiply(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
|
||||
resultType)
|
||||
resultType, nullOnOverflow)
|
||||
|
||||
case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
|
||||
val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
|
||||
|
@ -148,7 +150,7 @@ object DecimalPrecision extends TypeCoercionRule {
|
|||
}
|
||||
val widerType = widerDecimalType(p1, s1, p2, s2)
|
||||
CheckOverflow(Divide(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
|
||||
resultType)
|
||||
resultType, nullOnOverflow)
|
||||
|
||||
case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
|
||||
val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
|
||||
|
@ -159,7 +161,7 @@ object DecimalPrecision extends TypeCoercionRule {
|
|||
// resultType may have lower precision, so we cast them into wider type first.
|
||||
val widerType = widerDecimalType(p1, s1, p2, s2)
|
||||
CheckOverflow(Remainder(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
|
||||
resultType)
|
||||
resultType, nullOnOverflow)
|
||||
|
||||
case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
|
||||
val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
|
||||
|
@ -170,7 +172,7 @@ object DecimalPrecision extends TypeCoercionRule {
|
|||
// resultType may have lower precision, so we cast them into wider type first.
|
||||
val widerType = widerDecimalType(p1, s1, p2, s2)
|
||||
CheckOverflow(Pmod(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
|
||||
resultType)
|
||||
resultType, nullOnOverflow)
|
||||
|
||||
case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1),
|
||||
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
|
||||
|
|
|
@ -236,7 +236,7 @@ object StreamingJoinHelper extends PredicateHelper with Logging {
|
|||
collect(left, negate) ++ collect(right, !negate)
|
||||
case UnaryMinus(child) =>
|
||||
collect(child, !negate)
|
||||
case CheckOverflow(child, _) =>
|
||||
case CheckOverflow(child, _, _) =>
|
||||
collect(child, negate)
|
||||
case PromotePrecision(child) =>
|
||||
collect(child, negate)
|
||||
|
|
|
@ -114,7 +114,7 @@ object RowEncoder {
|
|||
d,
|
||||
"fromDecimal",
|
||||
inputObject :: Nil,
|
||||
returnNullable = false), d)
|
||||
returnNullable = false), d, SQLConf.get.decimalOperationsNullOnOverflow)
|
||||
|
||||
case StringType => createSerializerForString(inputObject)
|
||||
|
||||
|
|
|
@ -81,30 +81,34 @@ case class PromotePrecision(child: Expression) extends UnaryExpression {
|
|||
|
||||
/**
|
||||
* Rounds the decimal to given scale and check whether the decimal can fit in provided precision
|
||||
* or not, returns null if not.
|
||||
* or not. If not, if `nullOnOverflow` is `true`, it returns `null`; otherwise an
|
||||
* `ArithmeticException` is thrown.
|
||||
*/
|
||||
case class CheckOverflow(child: Expression, dataType: DecimalType) extends UnaryExpression {
|
||||
case class CheckOverflow(
|
||||
child: Expression,
|
||||
dataType: DecimalType,
|
||||
nullOnOverflow: Boolean) extends UnaryExpression {
|
||||
|
||||
override def nullable: Boolean = true
|
||||
|
||||
override def nullSafeEval(input: Any): Any =
|
||||
input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale)
|
||||
input.asInstanceOf[Decimal].toPrecision(
|
||||
dataType.precision,
|
||||
dataType.scale,
|
||||
Decimal.ROUND_HALF_UP,
|
||||
nullOnOverflow)
|
||||
|
||||
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
nullSafeCodeGen(ctx, ev, eval => {
|
||||
val tmp = ctx.freshName("tmp")
|
||||
s"""
|
||||
| Decimal $tmp = $eval.clone();
|
||||
| if ($tmp.changePrecision(${dataType.precision}, ${dataType.scale})) {
|
||||
| ${ev.value} = $tmp;
|
||||
| } else {
|
||||
| ${ev.isNull} = true;
|
||||
| }
|
||||
|${ev.value} = $eval.toPrecision(
|
||||
| ${dataType.precision}, ${dataType.scale}, Decimal.ROUND_HALF_UP(), $nullOnOverflow);
|
||||
|${ev.isNull} = ${ev.value} == null;
|
||||
""".stripMargin
|
||||
})
|
||||
}
|
||||
|
||||
override def toString: String = s"CheckOverflow($child, $dataType)"
|
||||
override def toString: String = s"CheckOverflow($child, $dataType, $nullOnOverflow)"
|
||||
|
||||
override def sql: String = child.sql
|
||||
}
|
||||
|
|
|
@ -1138,8 +1138,10 @@ abstract class RoundBase(child: Expression, scale: Expression,
|
|||
val evaluationCode = dataType match {
|
||||
case DecimalType.Fixed(_, s) =>
|
||||
s"""
|
||||
${ev.value} = ${ce.value}.toPrecision(${ce.value}.precision(), $s, Decimal.$modeStr());
|
||||
${ev.isNull} = ${ev.value} == null;"""
|
||||
|${ev.value} = ${ce.value}.toPrecision(${ce.value}.precision(), $s,
|
||||
| Decimal.$modeStr(), true);
|
||||
|${ev.isNull} = ${ev.value} == null;
|
||||
""".stripMargin
|
||||
case ByteType =>
|
||||
if (_scale < 0) {
|
||||
s"""
|
||||
|
|
|
@ -1441,6 +1441,16 @@ object SQLConf {
|
|||
.booleanConf
|
||||
.createWithDefault(true)
|
||||
|
||||
val DECIMAL_OPERATIONS_NULL_ON_OVERFLOW =
|
||||
buildConf("spark.sql.decimalOperations.nullOnOverflow")
|
||||
.internal()
|
||||
.doc("When true (default), if an overflow on a decimal occurs, then NULL is returned. " +
|
||||
"Spark's older versions and Hive behave in this way. If turned to false, SQL ANSI 2011 " +
|
||||
"specification will be followed instead: an arithmetic exception is thrown, as most " +
|
||||
"of the SQL databases do.")
|
||||
.booleanConf
|
||||
.createWithDefault(true)
|
||||
|
||||
val LITERAL_PICK_MINIMUM_PRECISION =
|
||||
buildConf("spark.sql.legacy.literal.pickMinimumPrecision")
|
||||
.internal()
|
||||
|
@ -2205,6 +2215,8 @@ class SQLConf extends Serializable with Logging {
|
|||
|
||||
def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS)
|
||||
|
||||
def decimalOperationsNullOnOverflow: Boolean = getConf(DECIMAL_OPERATIONS_NULL_ON_OVERFLOW)
|
||||
|
||||
def literalPickMinimumPrecision: Boolean = getConf(LITERAL_PICK_MINIMUM_PRECISION)
|
||||
|
||||
def continuousStreamingEpochBacklogQueueSize: Int =
|
||||
|
|
|
@ -249,14 +249,25 @@ final class Decimal extends Ordered[Decimal] with Serializable {
|
|||
/**
|
||||
* Create new `Decimal` with given precision and scale.
|
||||
*
|
||||
* @return a non-null `Decimal` value if successful or `null` if overflow would occur.
|
||||
* @return a non-null `Decimal` value if successful. Otherwise, if `nullOnOverflow` is true, null
|
||||
* is returned; if `nullOnOverflow` is false, an `ArithmeticException` is thrown.
|
||||
*/
|
||||
private[sql] def toPrecision(
|
||||
precision: Int,
|
||||
scale: Int,
|
||||
roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP): Decimal = {
|
||||
roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP,
|
||||
nullOnOverflow: Boolean = true): Decimal = {
|
||||
val copy = clone()
|
||||
if (copy.changePrecision(precision, scale, roundMode)) copy else null
|
||||
if (copy.changePrecision(precision, scale, roundMode)) {
|
||||
copy
|
||||
} else {
|
||||
if (nullOnOverflow) {
|
||||
null
|
||||
} else {
|
||||
throw new ArithmeticException(
|
||||
s"$toDebugString cannot be represented as Decimal($precision, $scale).")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -45,18 +45,26 @@ class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
|
||||
test("CheckOverflow") {
|
||||
val d1 = Decimal("10.1")
|
||||
checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 0)), Decimal("10"))
|
||||
checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 1)), d1)
|
||||
checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 2)), d1)
|
||||
checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 3)), null)
|
||||
checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 0), true), Decimal("10"))
|
||||
checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 1), true), d1)
|
||||
checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 2), true), d1)
|
||||
checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 3), true), null)
|
||||
intercept[ArithmeticException](CheckOverflow(Literal(d1), DecimalType(4, 3), false).eval())
|
||||
intercept[ArithmeticException](checkEvaluationWithMutableProjection(
|
||||
CheckOverflow(Literal(d1), DecimalType(4, 3), false), null))
|
||||
|
||||
val d2 = Decimal(101, 3, 1)
|
||||
checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 0)), Decimal("10"))
|
||||
checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 1)), d2)
|
||||
checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 2)), d2)
|
||||
checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 3)), null)
|
||||
checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 0), true), Decimal("10"))
|
||||
checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 1), true), d2)
|
||||
checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 2), true), d2)
|
||||
checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 3), true), null)
|
||||
intercept[ArithmeticException](CheckOverflow(Literal(d2), DecimalType(4, 3), false).eval())
|
||||
intercept[ArithmeticException](checkEvaluationWithMutableProjection(
|
||||
CheckOverflow(Literal(d2), DecimalType(4, 3), false), null))
|
||||
|
||||
checkEvaluation(CheckOverflow(Literal.create(null, DecimalType(2, 1)), DecimalType(3, 2)), null)
|
||||
checkEvaluation(CheckOverflow(
|
||||
Literal.create(null, DecimalType(2, 1)), DecimalType(3, 2), true), null)
|
||||
checkEvaluation(CheckOverflow(
|
||||
Literal.create(null, DecimalType(2, 1)), DecimalType(3, 2), false), null)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -83,4 +83,28 @@ select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.1
|
|||
select 123456789123456789.1234567890 * 1.123456789123456789;
|
||||
select 12345678912345.123456789123 / 0.000000012345678;
|
||||
|
||||
-- throw an exception instead of returning NULL, according to SQL ANSI 2011
|
||||
set spark.sql.decimalOperations.nullOnOverflow=false;
|
||||
|
||||
-- test operations between decimals and constants
|
||||
select id, a*10, b/10 from decimals_test order by id;
|
||||
|
||||
-- test operations on constants
|
||||
select 10.3 * 3.0;
|
||||
select 10.3000 * 3.0;
|
||||
select 10.30000 * 30.0;
|
||||
select 10.300000000000000000 * 3.000000000000000000;
|
||||
select 10.300000000000000000 * 3.0000000000000000000;
|
||||
|
||||
-- arithmetic operations causing an overflow throw exception
|
||||
select (5e36 + 0.1) + 5e36;
|
||||
select (-4e36 - 0.1) - 7e36;
|
||||
select 12345678901234567890.0 * 12345678901234567890.0;
|
||||
select 1e35 / 0.1;
|
||||
|
||||
-- arithmetic operations causing a precision loss throw exception
|
||||
select 123456789123456789.1234567890 * 1.123456789123456789;
|
||||
select 123456789123456789.1234567890 * 1.123456789123456789;
|
||||
select 12345678912345.123456789123 / 0.000000012345678;
|
||||
|
||||
drop table decimals_test;
|
|
@ -1,5 +1,5 @@
|
|||
-- Automatically generated by SQLQueryTestSuite
|
||||
-- Number of queries: 40
|
||||
-- Number of queries: 54
|
||||
|
||||
|
||||
-- !query 0
|
||||
|
@ -328,8 +328,131 @@ NULL
|
|||
|
||||
|
||||
-- !query 39
|
||||
drop table decimals_test
|
||||
set spark.sql.decimalOperations.nullOnOverflow=false
|
||||
-- !query 39 schema
|
||||
struct<>
|
||||
struct<key:string,value:string>
|
||||
-- !query 39 output
|
||||
spark.sql.decimalOperations.nullOnOverflow false
|
||||
|
||||
|
||||
-- !query 40
|
||||
select id, a*10, b/10 from decimals_test order by id
|
||||
-- !query 40 schema
|
||||
struct<id:int,(CAST(a AS DECIMAL(38,18)) * CAST(CAST(10 AS DECIMAL(2,0)) AS DECIMAL(38,18))):decimal(38,18),(CAST(b AS DECIMAL(38,18)) / CAST(CAST(10 AS DECIMAL(2,0)) AS DECIMAL(38,18))):decimal(38,19)>
|
||||
-- !query 40 output
|
||||
1 1000 99.9
|
||||
2 123451.23 1234.5123
|
||||
3 1.234567891011 123.41
|
||||
4 1234567891234567890 0.1123456789123456789
|
||||
|
||||
|
||||
-- !query 41
|
||||
select 10.3 * 3.0
|
||||
-- !query 41 schema
|
||||
struct<(CAST(10.3 AS DECIMAL(3,1)) * CAST(3.0 AS DECIMAL(3,1))):decimal(6,2)>
|
||||
-- !query 41 output
|
||||
30.9
|
||||
|
||||
|
||||
-- !query 42
|
||||
select 10.3000 * 3.0
|
||||
-- !query 42 schema
|
||||
struct<(CAST(10.3000 AS DECIMAL(6,4)) * CAST(3.0 AS DECIMAL(6,4))):decimal(9,5)>
|
||||
-- !query 42 output
|
||||
30.9
|
||||
|
||||
|
||||
-- !query 43
|
||||
select 10.30000 * 30.0
|
||||
-- !query 43 schema
|
||||
struct<(CAST(10.30000 AS DECIMAL(7,5)) * CAST(30.0 AS DECIMAL(7,5))):decimal(11,6)>
|
||||
-- !query 43 output
|
||||
309
|
||||
|
||||
|
||||
-- !query 44
|
||||
select 10.300000000000000000 * 3.000000000000000000
|
||||
-- !query 44 schema
|
||||
struct<(CAST(10.300000000000000000 AS DECIMAL(20,18)) * CAST(3.000000000000000000 AS DECIMAL(20,18))):decimal(38,36)>
|
||||
-- !query 44 output
|
||||
30.9
|
||||
|
||||
|
||||
-- !query 45
|
||||
select 10.300000000000000000 * 3.0000000000000000000
|
||||
-- !query 45 schema
|
||||
struct<>
|
||||
-- !query 45 output
|
||||
java.lang.ArithmeticException
|
||||
Decimal(expanded,30.900000000000000000000000000000000000,38,36}) cannot be represented as Decimal(38, 37).
|
||||
|
||||
|
||||
-- !query 46
|
||||
select (5e36 + 0.1) + 5e36
|
||||
-- !query 46 schema
|
||||
struct<>
|
||||
-- !query 46 output
|
||||
java.lang.ArithmeticException
|
||||
Decimal(expanded,10000000000000000000000000000000000000.1,39,1}) cannot be represented as Decimal(38, 1).
|
||||
|
||||
|
||||
-- !query 47
|
||||
select (-4e36 - 0.1) - 7e36
|
||||
-- !query 47 schema
|
||||
struct<>
|
||||
-- !query 47 output
|
||||
java.lang.ArithmeticException
|
||||
Decimal(expanded,-11000000000000000000000000000000000000.1,39,1}) cannot be represented as Decimal(38, 1).
|
||||
|
||||
|
||||
-- !query 48
|
||||
select 12345678901234567890.0 * 12345678901234567890.0
|
||||
-- !query 48 schema
|
||||
struct<>
|
||||
-- !query 48 output
|
||||
java.lang.ArithmeticException
|
||||
Decimal(expanded,1.5241578753238836750190519987501905210E+38,38,-1}) cannot be represented as Decimal(38, 2).
|
||||
|
||||
|
||||
-- !query 49
|
||||
select 1e35 / 0.1
|
||||
-- !query 49 schema
|
||||
struct<>
|
||||
-- !query 49 output
|
||||
java.lang.ArithmeticException
|
||||
Decimal(expanded,1000000000000000000000000000000000000,37,0}) cannot be represented as Decimal(38, 3).
|
||||
|
||||
|
||||
-- !query 50
|
||||
select 123456789123456789.1234567890 * 1.123456789123456789
|
||||
-- !query 50 schema
|
||||
struct<>
|
||||
-- !query 50 output
|
||||
java.lang.ArithmeticException
|
||||
Decimal(expanded,138698367904130467.65432098851562262075,38,20}) cannot be represented as Decimal(38, 28).
|
||||
|
||||
|
||||
-- !query 51
|
||||
select 123456789123456789.1234567890 * 1.123456789123456789
|
||||
-- !query 51 schema
|
||||
struct<>
|
||||
-- !query 51 output
|
||||
java.lang.ArithmeticException
|
||||
Decimal(expanded,138698367904130467.65432098851562262075,38,20}) cannot be represented as Decimal(38, 28).
|
||||
|
||||
|
||||
-- !query 52
|
||||
select 12345678912345.123456789123 / 0.000000012345678
|
||||
-- !query 52 schema
|
||||
struct<>
|
||||
-- !query 52 output
|
||||
java.lang.ArithmeticException
|
||||
Decimal(expanded,1000000073899961059796.7258663315210392,38,16}) cannot be represented as Decimal(38, 18).
|
||||
|
||||
|
||||
-- !query 53
|
||||
drop table decimals_test
|
||||
-- !query 53 schema
|
||||
struct<>
|
||||
-- !query 53 output
|
||||
|
Loading…
Reference in a new issue