[SPARK-28224][SQL] Check overflow in decimal Sum aggregate
## What changes were proposed in this pull request? - Currently `sum` in aggregates for decimal type can overflow and return null. - `Sum` expression codegens arithmetic on `sql.Decimal` and the output which preserves scale and precision goes into `UnsafeRowWriter`. Here overflowing will be converted to null when writing out. - It also does not go through this branch in `DecimalAggregates` because it is expecting the precision of the sum (not of the elements to be summed) to be less than 5. See 4ebff5b6d6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala (L1400-L1403)
- This PR adds the check at the final result of the sum operator itself. See 4ebff5b6d6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala (L372-L376)
https://issues.apache.org/jira/browse/SPARK-28224 ## How was this patch tested? - Added an integration test on dataframe suite cc mgaido91 JoshRosen Closes #25033 from mickjermsurawong-stripe/SPARK-28224. Authored-by: Mick Jermsurawong <mickjermsurawong@stripe.com> Signed-off-by: Takeshi Yamamuro <yamamuro@apache.org>
This commit is contained in:
parent
26f344354b
commit
b79cf0d143
|
@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
|
|||
import org.apache.spark.sql.catalyst.dsl.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.util.TypeUtils
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
@ExpressionDescription(
|
||||
|
@ -89,5 +90,9 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
|
|||
)
|
||||
}
|
||||
|
||||
override lazy val evaluateExpression: Expression = sum
|
||||
override lazy val evaluateExpression: Expression = resultType match {
|
||||
case d: DecimalType => CheckOverflow(sum, d, SQLConf.get.decimalOperationsNullOnOverflow)
|
||||
case _ => sum
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -38,7 +38,7 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExc
|
|||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSparkSession}
|
||||
import org.apache.spark.sql.test.SQLTestData.{NullStrings, TestData2}
|
||||
import org.apache.spark.sql.test.SQLTestData.{DecimalData, NullStrings, TestData2}
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.util.Utils
|
||||
import org.apache.spark.util.random.XORShiftRandom
|
||||
|
@ -156,6 +156,27 @@ class DataFrameSuite extends QueryTest with SharedSparkSession {
|
|||
structDf.select(xxhash64($"a", $"record.*")))
|
||||
}
|
||||
|
||||
test("SPARK-28224: Aggregate sum big decimal overflow") {
|
||||
val largeDecimals = spark.sparkContext.parallelize(
|
||||
DecimalData(BigDecimal("1"* 20 + ".123"), BigDecimal("1"* 20 + ".123")) ::
|
||||
DecimalData(BigDecimal("9"* 20 + ".123"), BigDecimal("9"* 20 + ".123")) :: Nil).toDF()
|
||||
|
||||
Seq(true, false).foreach { nullOnOverflow =>
|
||||
withSQLConf((SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key, nullOnOverflow.toString)) {
|
||||
val structDf = largeDecimals.select("a").agg(sum("a"))
|
||||
if (nullOnOverflow) {
|
||||
checkAnswer(structDf, Row(null))
|
||||
} else {
|
||||
val e = intercept[SparkException] {
|
||||
structDf.collect
|
||||
}
|
||||
assert(e.getCause.getClass.equals(classOf[ArithmeticException]))
|
||||
assert(e.getCause.getMessage.contains("cannot be represented as Decimal"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("Star Expansion - explode should fail with a meaningful message if it takes a star") {
|
||||
val df = Seq(("1,2"), ("4"), ("7,8,9")).toDF("csv")
|
||||
val e = intercept[AnalysisException] {
|
||||
|
|
Loading…
Reference in a new issue