[SPARK-35177][SQL] Fix arithmetic overflow in parsing the minimal interval by IntervalUtils.fromYearMonthString

### What changes were proposed in this pull request?
IntervalUtils.fromYearMonthString should handle Int.MinValue months correctly.
In current logic, just use `Math.addExact(Math.multiplyExact(years, 12), months)` to calculate  negative total months will overflow when actual total months is Int.MinValue, this pr fixes this bug.

### Why are the changes needed?
IntervalUtils.fromYearMonthString should handle Int.MinValue months correctly

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Added UT

Closes #32281 from AngersZhuuuu/SPARK-35177.

Authored-by: Angerszhuuuu <angers.zhu@gmail.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
This commit is contained in:
Angerszhuuuu 2021-04-22 11:59:38 +03:00 committed by Max Gekk
parent b17a0e6931
commit bb5459fb26
2 changed files with 19 additions and 7 deletions

View file

@ -100,12 +100,11 @@ object IntervalUtils {
*/ */
def fromYearMonthString(input: String): CalendarInterval = { def fromYearMonthString(input: String): CalendarInterval = {
require(input != null, "Interval year-month string must be not null") require(input != null, "Interval year-month string must be not null")
def toInterval(yearStr: String, monthStr: String): CalendarInterval = { def toInterval(yearStr: String, monthStr: String, sign: Int): CalendarInterval = {
try { try {
val years = toLongWithRange(YEAR, yearStr, 0, Integer.MAX_VALUE).toInt val years = toLongWithRange(YEAR, yearStr, 0, Integer.MAX_VALUE / MONTHS_PER_YEAR)
val months = toLongWithRange(MONTH, monthStr, 0, 11).toInt val totalMonths = sign * (years * MONTHS_PER_YEAR + toLongWithRange(MONTH, monthStr, 0, 11))
val totalMonths = Math.addExact(Math.multiplyExact(years, 12), months) new CalendarInterval(Math.toIntExact(totalMonths), 0, 0)
new CalendarInterval(totalMonths, 0, 0)
} catch { } catch {
case NonFatal(e) => case NonFatal(e) =>
throw new IllegalArgumentException( throw new IllegalArgumentException(
@ -114,9 +113,9 @@ object IntervalUtils {
} }
input.trim match { input.trim match {
case yearMonthPattern("-", yearStr, monthStr) => case yearMonthPattern("-", yearStr, monthStr) =>
negateExact(toInterval(yearStr, monthStr)) toInterval(yearStr, monthStr, -1)
case yearMonthPattern(_, yearStr, monthStr) => case yearMonthPattern(_, yearStr, monthStr) =>
toInterval(yearStr, monthStr) toInterval(yearStr, monthStr, 1)
case _ => case _ =>
throw new IllegalArgumentException( throw new IllegalArgumentException(
s"Interval string does not match year-month format of 'y-m': $input") s"Interval string does not match year-month format of 'y-m': $input")

View file

@ -169,6 +169,19 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper {
fromYearMonthString) fromYearMonthString)
failFuncWithInvalidInput("-\t99-15", "Interval string does not match year-month format", failFuncWithInvalidInput("-\t99-15", "Interval string does not match year-month format",
fromYearMonthString) fromYearMonthString)
assert(fromYearMonthString("178956970-6") == new CalendarInterval(Int.MaxValue - 1, 0, 0))
assert(fromYearMonthString("178956970-7") == new CalendarInterval(Int.MaxValue, 0, 0))
val e1 = intercept[IllegalArgumentException]{
assert(fromYearMonthString("178956970-8") == new CalendarInterval(Int.MinValue, 0, 0))
}.getMessage
assert(e1.contains("integer overflow"))
assert(fromYearMonthString("-178956970-8") == new CalendarInterval(Int.MinValue, 0, 0))
val e2 = intercept[IllegalArgumentException]{
assert(fromYearMonthString("-178956970-9") == new CalendarInterval(Int.MinValue, 0, 0))
}.getMessage
assert(e2.contains("integer overflow"))
} }
test("from day-time string - legacy") { test("from day-time string - legacy") {