[SPARK-24123][SQL] Fix precision issues in monthsBetween with more than 8 digits

## What changes were proposed in this pull request?

SPARK-23902 introduced the ability to retrieve more than 8 digits in `monthsBetween`. Unfortunately, current implementation can cause precision loss in such a case. This was causing also a flaky UT.

This PR mirrors Hive's implementation in order to avoid precision loss also when more than 8 digits are returned.

## How was this patch tested?

running 10000000 times the flaky UT

Author: Marco Gaido <marcogaido91@gmail.com>

Closes #21196 from mgaido91/SPARK-24123.
This commit is contained in:
Marco Gaido 2018-05-02 13:49:15 -07:00 committed by gatorsmile
parent 8bd27025b7
commit 504c9cfd21

View file

@ -888,14 +888,19 @@ object DateTimeUtils {
val months1 = year1 * 12 + monthInYear1
val months2 = year2 * 12 + monthInYear2
val monthDiff = (months1 - months2).toDouble
if (dayInMonth1 == dayInMonth2 || ((daysToMonthEnd1 == 0) && (daysToMonthEnd2 == 0))) {
return (months1 - months2).toDouble
return monthDiff
}
// milliseconds is enough for 8 digits precision on the right side
val timeInDay1 = millis1 - daysToMillis(date1, timeZone)
val timeInDay2 = millis2 - daysToMillis(date2, timeZone)
val timesBetween = (timeInDay1 - timeInDay2).toDouble / MILLIS_PER_DAY
val diff = (months1 - months2).toDouble + (dayInMonth1 - dayInMonth2 + timesBetween) / 31.0
// using milliseconds can cause precision loss with more than 8 digits
// we follow Hive's implementation which uses seconds
val secondsInDay1 = (millis1 - daysToMillis(date1, timeZone)) / 1000L
val secondsInDay2 = (millis2 - daysToMillis(date2, timeZone)) / 1000L
val secondsDiff = (dayInMonth1 - dayInMonth2) * SECONDS_PER_DAY + secondsInDay1 - secondsInDay2
// 2678400D is the number of seconds in 31 days
// every month is considered to be 31 days long in this function
val diff = monthDiff + secondsDiff / 2678400D
if (roundOff) {
// rounding to 8 digits
math.round(diff * 1e8) / 1e8