[SPARK-28389][SQL] Use Java 8 API in add_months

## What changes were proposed in this pull request?

In the PR, I propose to use the `plusMonths()` method of `LocalDate` to add months to a date. This method adds the specified amount to the months field of `LocalDate` in three steps:
1. Add the input months to the month-of-year field
2. Check if the resulting date would be invalid
3. Adjust the day-of-month to the last valid day if necessary

The difference between the current behavior and the proposed one is in handling the last day of the month in the original date. For example, adding 1 month to `2019-02-28` will produce `2019-03-28`, compared to the current implementation where the result is `2019-03-31`.

The proposed behavior is implemented in MySQL and PostgreSQL.

## How was this patch tested?

By existing test suites `DateExpressionsSuite`, `DateFunctionsSuite` and `DateTimeUtilsSuite`.

Closes #25153 from MaxGekk/add-months.

Authored-by: Maxim Gekk <max.gekk@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
Maxim Gekk 2019-07-15 20:49:39 +08:00 committed by Wenchen Fan
parent a7a02a86ad
commit f241fc7776
6 changed files with 15 additions and 66 deletions

View file

@@ -151,6 +151,8 @@ license: |
- Since Spark 3.0, substitution order of nested WITH clauses is changed and an inner CTE definition takes precedence over an outer. In version 2.4 and earlier, `WITH t AS (SELECT 1), t2 AS (WITH t AS (SELECT 2) SELECT * FROM t) SELECT * FROM t2` returns `1` while in version 3.0 it returns `2`. The previous behaviour can be restored by setting `spark.sql.legacy.ctePrecedence.enabled` to `true`.
- Since Spark 3.0, the `add_months` function adjusts the resulting date to the last day of the month only if it is invalid. For example, `select add_months(DATE'2019-01-31', 1)` results in `2019-02-28`. In Spark version 2.4 and earlier, the resulting date is adjusted when it is invalid, or the original date is the last day of a month. For example, adding a month to `2019-02-28` results in `2019-03-31`.
## Upgrading from Spark SQL 2.4 to 2.4.1
- The value of `spark.executor.heartbeatInterval`, when specified without units like "30" rather than "30s", was

View file

@@ -2641,8 +2641,8 @@ object Sequence {
while (t < exclusiveItem ^ stepSign < 0) {
arr(i) = fromLong(t / scale)
t = timestampAddInterval(t, stepMonths, stepMicros, timeZone)
i += 1
t = timestampAddInterval(startMicros, i * stepMonths, i * stepMicros, timeZone)
}
// truncate array to the correct length
@@ -2676,12 +2676,6 @@ object Sequence {
|${genSequenceLengthCode(ctx, startMicros, stopMicros, intervalInMicros, arrLength)}
""".stripMargin
val timestampAddIntervalCode =
s"""
|$t = org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampAddInterval(
| $t, $stepMonths, $stepMicros, $genTimeZone);
""".stripMargin
s"""
|final int $stepMonths = $step.months;
|final long $stepMicros = $step.microseconds;
@@ -2705,8 +2699,9 @@ object Sequence {
|
| while ($t < $exclusiveItem ^ $stepSign < 0) {
| $arr[$i] = ($elemType) ($t / ${scale}L);
| $timestampAddIntervalCode
| $i += 1;
| $t = org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampAddInterval(
| $startMicros, $i * $stepMonths, $i * $stepMicros, $genTimeZone);
| }
|
| if ($arr.length > $i) {

View file

@@ -505,60 +505,12 @@ object DateTimeUtils {
LocalDate.ofEpochDay(date).getDayOfMonth
}
/**
* The number of days for each month (not leap year)
*/
private val monthDays = Array(31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31)
/**
* Returns the date value for the first day of the given month.
* The month is expressed in months since year zero (17999 BC), starting from 0.
*/
private def firstDayOfMonth(absoluteMonth: Int): SQLDate = {
val absoluteYear = absoluteMonth / 12
var monthInYear = absoluteMonth - absoluteYear * 12
var date = getDateFromYear(absoluteYear)
if (monthInYear >= 2 && isLeap(absoluteYear + YearZero)) {
date += 1
}
while (monthInYear > 0) {
date += monthDays(monthInYear - 1)
monthInYear -= 1
}
date
}
/**
* Returns the date value for January 1 of the given year.
* The year is expressed in years since year zero (17999 BC), starting from 0.
*/
private def getDateFromYear(absoluteYear: Int): SQLDate = {
val absoluteDays = (absoluteYear * 365 + absoluteYear / 400 - absoluteYear / 100
+ absoluteYear / 4)
absoluteDays - toYearZero
}
/**
* Add date and year-month interval.
* Returns a date value, expressed in days since 1.1.1970.
*/
def dateAddMonths(days: SQLDate, months: Int): SQLDate = {
val (year, monthInYear, dayOfMonth, daysToMonthEnd) = splitDate(days)
val absoluteMonth = (year - YearZero) * 12 + monthInYear - 1 + months
val nonNegativeMonth = if (absoluteMonth >= 0) absoluteMonth else 0
val currentMonthInYear = nonNegativeMonth % 12
val currentYear = nonNegativeMonth / 12
val leapDay = if (currentMonthInYear == 1 && isLeap(currentYear + YearZero)) 1 else 0
val lastDayOfMonth = monthDays(currentMonthInYear) + leapDay
val currentDayInMonth = if (daysToMonthEnd == 0 || dayOfMonth >= lastDayOfMonth) {
// last day of the month
lastDayOfMonth
} else {
dayOfMonth
}
firstDayOfMonth(nonNegativeMonth) + currentDayInMonth - 1
LocalDate.ofEpochDay(days).plusMonths(months).toEpochDay.toInt
}
/**

View file

@@ -463,7 +463,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(AddMonths(Literal.create(null, DateType), Literal.create(null, IntegerType)),
null)
checkEvaluation(
AddMonths(Literal(Date.valueOf("2015-01-30")), Literal(Int.MinValue)), -7293498)
AddMonths(Literal(Date.valueOf("2015-01-30")), Literal(Int.MinValue)), -938165455)
checkEvaluation(
AddMonths(Literal(Date.valueOf("2016-02-28")), positiveIntLit), 1014213)
checkEvaluation(

View file

@@ -359,18 +359,18 @@ class DateTimeUtilsSuite extends SparkFunSuite {
test("date add months") {
val input = days(1997, 2, 28, 10, 30)
assert(dateAddMonths(input, 36) === days(2000, 2, 29))
assert(dateAddMonths(input, -13) === days(1996, 1, 31))
assert(dateAddMonths(input, 36) === days(2000, 2, 28))
assert(dateAddMonths(input, -13) === days(1996, 1, 28))
}
test("timestamp add months") {
val ts1 = date(1997, 2, 28, 10, 30, 0)
val ts2 = date(2000, 2, 29, 10, 30, 0, 123000)
val ts2 = date(2000, 2, 28, 10, 30, 0, 123000)
assert(timestampAddInterval(ts1, 36, 123000, defaultTz) === ts2)
val ts3 = date(1997, 2, 27, 16, 0, 0, 0, TimeZonePST)
val ts4 = date(2000, 2, 27, 16, 0, 0, 123000, TimeZonePST)
val ts5 = date(2000, 2, 29, 0, 0, 0, 123000, TimeZoneGMT)
val ts5 = date(2000, 2, 28, 0, 0, 0, 123000, TimeZoneGMT)
assert(timestampAddInterval(ts3, 36, 123000, TimeZonePST) === ts4)
assert(timestampAddInterval(ts3, 36, 123000, TimeZoneGMT) === ts5)
}

View file

@@ -301,11 +301,11 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext {
val df = Seq((1, t1, d1), (3, t2, d2)).toDF("n", "t", "d")
checkAnswer(
df.selectExpr(s"d - $i"),
Seq(Row(Date.valueOf("2015-07-30")), Row(Date.valueOf("2015-12-30"))))
Seq(Row(Date.valueOf("2015-07-29")), Row(Date.valueOf("2015-12-28"))))
checkAnswer(
df.selectExpr(s"t - $i"),
Seq(Row(Timestamp.valueOf("2015-07-31 23:59:59")),
Row(Timestamp.valueOf("2015-12-31 00:00:00"))))
Row(Timestamp.valueOf("2015-12-29 00:00:00"))))
}
test("function add_months") {
@@ -314,10 +314,10 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext {
val df = Seq((1, d1), (2, d2)).toDF("n", "d")
checkAnswer(
df.select(add_months(col("d"), 1)),
Seq(Row(Date.valueOf("2015-09-30")), Row(Date.valueOf("2015-03-31"))))
Seq(Row(Date.valueOf("2015-09-30")), Row(Date.valueOf("2015-03-28"))))
checkAnswer(
df.selectExpr("add_months(d, -1)"),
Seq(Row(Date.valueOf("2015-07-31")), Row(Date.valueOf("2015-01-31"))))
Seq(Row(Date.valueOf("2015-07-31")), Row(Date.valueOf("2015-01-28"))))
}
test("function months_between") {