[SPARK-28224][SQL] Check overflow in decimal Sum aggregate

## What changes were proposed in this pull request?
- Currently `sum` in aggregates over decimal types can overflow and return null.
  - The `Sum` expression codegens arithmetic on `sql.Decimal`, and the output, which preserves scale and precision, goes into `UnsafeRowWriter`, where an overflowed value is converted to null on write (see the sketch after this list).
  - It also does not go through this branch in `DecimalAggregates`, because that rule expects the precision of the sum (not of the elements being summed) to be less than 5.
4ebff5b6d6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala (L1400-L1403)
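
For illustration (not part of the patch), a minimal spark-shell sketch of the pre-patch behavior; the column name `d` is arbitrary and the values mirror the test data added below:

```scala
import org.apache.spark.sql.functions.sum
import spark.implicits._

// A scala.BigDecimal column gets Spark's default DecimalType(38, 18), so a
// value with 20 integer digits uses all 38 digits of precision. Summing two
// of them needs 39 digits, which overflows the sum's result type.
val df = Seq(BigDecimal("9" * 20 + ".123"), BigDecimal("9" * 20 + ".123")).toDF("d")

// Before this patch the overflowed sum is silently written out as null:
df.agg(sum("d")).show()
// +------+
// |sum(d)|
// +------+
// |  null|
// +------+
```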

- This PR adds an overflow check on the final result of the sum operator itself.
4ebff5b6d6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala (L372-L376)
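
With the check in place, overflow behavior follows the existing `spark.sql.decimalOperations.nullOnOverflow` flag. A sketch of the intended semantics, continuing the example above:

```scala
// Flag on (the current default): the sum still evaluates to null on overflow,
// but now deliberately via CheckOverflow rather than as a side effect of
// UnsafeRowWriter.
spark.conf.set("spark.sql.decimalOperations.nullOnOverflow", "true")
df.agg(sum("d")).collect()  // Array([null])

// Flag off: the job fails with a SparkException whose cause is a
// java.lang.ArithmeticException ("... cannot be represented as Decimal(38, 18)").
spark.conf.set("spark.sql.decimalOperations.nullOnOverflow", "false")
df.agg(sum("d")).collect()  // throws
```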

https://issues.apache.org/jira/browse/SPARK-28224

## How was this patch tested?

- Added an integration test to `DataFrameSuite`

cc mgaido91 JoshRosen

Closes #25033 from mickjermsurawong-stripe/SPARK-28224.

Authored-by: Mick Jermsurawong <mickjermsurawong@stripe.com>
Signed-off-by: Takeshi Yamamuro <yamamuro@apache.org>
commit b79cf0d143 (parent 26f344354b)
2 changed files with 28 additions and 2 deletions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala:

```diff
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.TypeUtils
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 @ExpressionDescription(
@@ -89,5 +90,9 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes
     )
   }
 
-  override lazy val evaluateExpression: Expression = sum
+  override lazy val evaluateExpression: Expression = resultType match {
+    case d: DecimalType => CheckOverflow(sum, d, SQLConf.get.decimalOperationsNullOnOverflow)
+    case _ => sum
+  }
+
 }
```

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala:

```diff
@@ -38,7 +38,7 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSparkSession}
-import org.apache.spark.sql.test.SQLTestData.{NullStrings, TestData2}
+import org.apache.spark.sql.test.SQLTestData.{DecimalData, NullStrings, TestData2}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 import org.apache.spark.util.random.XORShiftRandom
@@ -156,6 +156,27 @@ class DataFrameSuite extends QueryTest with SharedSparkSession {
       structDf.select(xxhash64($"a", $"record.*")))
   }
 
+  test("SPARK-28224: Aggregate sum big decimal overflow") {
+    val largeDecimals = spark.sparkContext.parallelize(
+      DecimalData(BigDecimal("1"* 20 + ".123"), BigDecimal("1"* 20 + ".123")) ::
+      DecimalData(BigDecimal("9"* 20 + ".123"), BigDecimal("9"* 20 + ".123")) :: Nil).toDF()
+
+    Seq(true, false).foreach { nullOnOverflow =>
+      withSQLConf((SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key, nullOnOverflow.toString)) {
+        val structDf = largeDecimals.select("a").agg(sum("a"))
+        if (nullOnOverflow) {
+          checkAnswer(structDf, Row(null))
+        } else {
+          val e = intercept[SparkException] {
+            structDf.collect
+          }
+          assert(e.getCause.getClass.equals(classOf[ArithmeticException]))
+          assert(e.getCause.getMessage.contains("cannot be represented as Decimal"))
+        }
+      }
+    }
+  }
+
   test("Star Expansion - explode should fail with a meaningful message if it takes a star") {
     val df = Seq(("1,2"), ("4"), ("7,8,9")).toDF("csv")
     val e = intercept[AnalysisException] {
```