[SPARK-34761][SQL] Support add/subtract of a day-time interval to/from a timestamp
### What changes were proposed in this pull request? Support `timestamp +/- day-time interval`. In the PR, I propose to extend the `TimeAdd` expression and support `DayTimeIntervalType` as the `interval` parameter. The expression invokes the new method `DateTimeUtils.timestampAddDayTime()` which splits the input day-time interval to `days` and `microsecond adjustment` of a day, and adds `days` (and the microseconds) to a local timestamp derived from the given timestamp at the given time zone. The resulted local timestamp is converted back to the offset in microseconds since the epoch. Also I updated the rules that handle `CalendarIntervalType` and produce `TimeAdd` to take into account new type `DateTimeIntervalType` for the `interval` parameter of `TimeAdd`. ### Why are the changes needed? To conform the ANSI SQL standard which requires to support such operation over timestamps and intervals: <img width="811" alt="Screenshot 2021-03-12 at 11 36 14" src="https://user-images.githubusercontent.com/1580697/111081674-865d4900-8515-11eb-86c8-3538ecaf4804.png"> ### Does this PR introduce _any_ user-facing change? Should not since new intervals have not been released yet. ### How was this patch tested? By running new tests: ``` $ build/sbt "test:testOnly *DateTimeUtilsSuite" $ build/sbt "test:testOnly *DateExpressionsSuite" $ build/sbt "test:testOnly *ColumnExpressionSuite" ``` Closes #31855 from MaxGekk/timestamp-add-day-time-interval. Authored-by: Max Gekk <max.gekk@gmail.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
0a58029d52
commit
a48b2086dd
|
@ -342,11 +342,12 @@ class Analyzer(override val catalogManager: CatalogManager)
|
|||
case (YearMonthIntervalType, DateType) => DateAddYMInterval(r, l)
|
||||
case (TimestampType, YearMonthIntervalType) => TimestampAddYMInterval(l, r)
|
||||
case (YearMonthIntervalType, TimestampType) => TimestampAddYMInterval(r, l)
|
||||
case (CalendarIntervalType, CalendarIntervalType) => a
|
||||
case (CalendarIntervalType, CalendarIntervalType) |
|
||||
(DayTimeIntervalType, DayTimeIntervalType) => a
|
||||
case (DateType, CalendarIntervalType) => DateAddInterval(l, r, ansiEnabled = f)
|
||||
case (_, CalendarIntervalType) => Cast(TimeAdd(l, r), l.dataType)
|
||||
case (_, CalendarIntervalType | DayTimeIntervalType) => Cast(TimeAdd(l, r), l.dataType)
|
||||
case (CalendarIntervalType, DateType) => DateAddInterval(r, l, ansiEnabled = f)
|
||||
case (CalendarIntervalType, _) => Cast(TimeAdd(r, l), r.dataType)
|
||||
case (CalendarIntervalType | DayTimeIntervalType, _) => Cast(TimeAdd(r, l), r.dataType)
|
||||
case (DateType, dt) if dt != StringType => DateAdd(l, r)
|
||||
case (dt, DateType) if dt != StringType => DateAdd(r, l)
|
||||
case _ => a
|
||||
|
@ -356,10 +357,11 @@ class Analyzer(override val catalogManager: CatalogManager)
|
|||
DatetimeSub(l, r, DateAddYMInterval(l, UnaryMinus(r, f)))
|
||||
case (TimestampType, YearMonthIntervalType) =>
|
||||
DatetimeSub(l, r, TimestampAddYMInterval(l, UnaryMinus(r, f)))
|
||||
case (CalendarIntervalType, CalendarIntervalType) => s
|
||||
case (CalendarIntervalType, CalendarIntervalType) |
|
||||
(DayTimeIntervalType, DayTimeIntervalType) => s
|
||||
case (DateType, CalendarIntervalType) =>
|
||||
DatetimeSub(l, r, DateAddInterval(l, UnaryMinus(r, f), ansiEnabled = f))
|
||||
case (_, CalendarIntervalType) =>
|
||||
case (_, CalendarIntervalType | DayTimeIntervalType) =>
|
||||
Cast(DatetimeSub(l, r, TimeAdd(l, UnaryMinus(r, f))), l.dataType)
|
||||
case (TimestampType, _) => SubtractTimestamps(l, r)
|
||||
case (_, TimestampType) => SubtractTimestamps(l, r)
|
||||
|
|
|
@ -1264,25 +1264,33 @@ case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[S
|
|||
|
||||
override def toString: String = s"$left + $right"
|
||||
override def sql: String = s"${left.sql} + ${right.sql}"
|
||||
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, CalendarIntervalType)
|
||||
override def inputTypes: Seq[AbstractDataType] =
|
||||
Seq(TimestampType, TypeCollection(CalendarIntervalType, DayTimeIntervalType))
|
||||
|
||||
override def dataType: DataType = TimestampType
|
||||
|
||||
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
|
||||
copy(timeZoneId = Option(timeZoneId))
|
||||
|
||||
override def nullSafeEval(start: Any, interval: Any): Any = {
|
||||
val itvl = interval.asInstanceOf[CalendarInterval]
|
||||
DateTimeUtils.timestampAddInterval(
|
||||
start.asInstanceOf[Long], itvl.months, itvl.days, itvl.microseconds, zoneId)
|
||||
override def nullSafeEval(start: Any, interval: Any): Any = right.dataType match {
|
||||
case DayTimeIntervalType =>
|
||||
timestampAddDayTime(start.asInstanceOf[Long], interval.asInstanceOf[Long], zoneId)
|
||||
case CalendarIntervalType =>
|
||||
val i = interval.asInstanceOf[CalendarInterval]
|
||||
timestampAddInterval(start.asInstanceOf[Long], i.months, i.days, i.microseconds, zoneId)
|
||||
}
|
||||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName)
|
||||
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
|
||||
defineCodeGen(ctx, ev, (sd, i) => {
|
||||
s"""$dtu.timestampAddInterval($sd, $i.months, $i.days, $i.microseconds, $zid)"""
|
||||
})
|
||||
interval.dataType match {
|
||||
case DayTimeIntervalType =>
|
||||
defineCodeGen(ctx, ev, (sd, dt) => s"""$dtu.timestampAddDayTime($sd, $dt, $zid)""")
|
||||
case CalendarIntervalType =>
|
||||
defineCodeGen(ctx, ev, (sd, i) => {
|
||||
s"""$dtu.timestampAddInterval($sd, $i.months, $i.days, $i.microseconds, $zid)"""
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -591,6 +591,28 @@ object DateTimeUtils {
|
|||
instantToMicros(microsToInstant(micros).atZone(zoneId).plusMonths(months).toInstant)
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a day-time interval expressed in microseconds to a timestamp at the given time zone.
|
||||
* It converts the input timestamp to a local timestamp, and adds the interval by:
|
||||
* - Splitting the interval to days and microsecond adjustment in a day, and
|
||||
* - First of all, it adds days and then the time part.
|
||||
* The resulted local timestamp is converted back to an instant at the given time zone.
|
||||
*
|
||||
* @param micros The input timestamp value, expressed in microseconds since 1970-01-01 00:00:00Z.
|
||||
* @param dayTime The amount of microseconds to add. It can be positive or negative.
|
||||
* @param zoneId The time zone ID at which the operation is performed.
|
||||
* @return A timestamp value, expressed in microseconds since 1970-01-01 00:00:00Z.
|
||||
*/
|
||||
def timestampAddDayTime(micros: Long, dayTime: Long, zoneId: ZoneId): Long = {
|
||||
val days = dayTime / MICROS_PER_DAY
|
||||
val microseconds = dayTime - days * MICROS_PER_DAY
|
||||
val resultTimestamp = microsToInstant(micros)
|
||||
.atZone(zoneId)
|
||||
.plusDays(days)
|
||||
.plus(microseconds, ChronoUnit.MICROS)
|
||||
instantToMicros(resultTimestamp.toInstant)
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a full interval (months, days, microseconds) a timestamp represented as the number of
|
||||
* microseconds since 1970-01-01 00:00:00Z.
|
||||
|
|
|
@ -19,8 +19,9 @@ package org.apache.spark.sql.catalyst.expressions
|
|||
|
||||
import java.sql.{Date, Timestamp}
|
||||
import java.text.{ParseException, SimpleDateFormat}
|
||||
import java.time.{DateTimeException, Instant, LocalDate, Period, ZoneId}
|
||||
import java.time.{DateTimeException, Duration, Instant, LocalDate, Period, ZoneId}
|
||||
import java.time.format.DateTimeParseException
|
||||
import java.time.temporal.ChronoUnit
|
||||
import java.util.{Calendar, Locale, TimeZone}
|
||||
import java.util.concurrent.TimeUnit._
|
||||
|
||||
|
@ -1538,4 +1539,59 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
TimestampType, YearMonthIntervalType)
|
||||
}
|
||||
}
|
||||
|
||||
test("SPARK-34761: add a day-time interval to a timestamp") {
|
||||
val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS", Locale.US)
|
||||
for (zid <- outstandingZoneIds) {
|
||||
val timeZoneId = Option(zid.getId)
|
||||
sdf.setTimeZone(TimeZone.getTimeZone(zid))
|
||||
checkEvaluation(
|
||||
TimeAdd(
|
||||
Literal(new Timestamp(sdf.parse("2021-01-01 00:00:00.123").getTime)),
|
||||
Literal(Duration.ofDays(10).plusMinutes(10).plusMillis(321)),
|
||||
timeZoneId),
|
||||
DateTimeUtils.fromJavaTimestamp(
|
||||
new Timestamp(sdf.parse("2021-01-11 00:10:00.444").getTime)))
|
||||
checkEvaluation(
|
||||
TimeAdd(
|
||||
Literal(new Timestamp(sdf.parse("2021-01-01 00:10:00.123").getTime)),
|
||||
Literal(Duration.ofDays(-10).minusMinutes(9).minusMillis(120)),
|
||||
timeZoneId),
|
||||
DateTimeUtils.fromJavaTimestamp(
|
||||
new Timestamp(sdf.parse("2020-12-22 00:01:00.003").getTime)))
|
||||
|
||||
val e = intercept[Exception] {
|
||||
checkEvaluation(
|
||||
TimeAdd(
|
||||
Literal(new Timestamp(sdf.parse("2021-01-01 00:00:00.123").getTime)),
|
||||
Literal(Duration.of(Long.MaxValue, ChronoUnit.MICROS)),
|
||||
timeZoneId),
|
||||
null)
|
||||
}.getCause
|
||||
assert(e.isInstanceOf[ArithmeticException])
|
||||
assert(e.getMessage.contains("long overflow"))
|
||||
|
||||
checkEvaluation(
|
||||
TimeAdd(
|
||||
Literal.create(null, TimestampType),
|
||||
Literal(Duration.ofDays(1)),
|
||||
timeZoneId),
|
||||
null)
|
||||
checkEvaluation(
|
||||
TimeAdd(
|
||||
Literal(new Timestamp(sdf.parse("2021-01-01 00:00:00.123").getTime)),
|
||||
Literal.create(null, DayTimeIntervalType),
|
||||
timeZoneId),
|
||||
null)
|
||||
checkEvaluation(
|
||||
TimeAdd(
|
||||
Literal.create(null, TimestampType),
|
||||
Literal.create(null, DayTimeIntervalType),
|
||||
timeZoneId),
|
||||
null)
|
||||
checkConsistencyBetweenInterpretedAndCodegen(
|
||||
(ts: Expression, interval: Expression) => TimeAdd(ts, interval, timeZoneId),
|
||||
TimestampType, DayTimeIntervalType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -164,13 +164,10 @@ object LiteralGenerator {
|
|||
for { i <- Gen.choose(-100, 100) } yield Literal.create(i, IntegerType)
|
||||
|
||||
lazy val dayTimeIntervalLiteralGen: Gen[Literal] = {
|
||||
for {
|
||||
seconds <- Gen.choose(
|
||||
Duration.ofDays(-106751990).getSeconds,
|
||||
Duration.ofDays(106751990).getSeconds)
|
||||
nanoAdjustment <- Gen.choose(-999999000, 999999000)
|
||||
} yield {
|
||||
Literal.create(Duration.ofSeconds(seconds, nanoAdjustment), DayTimeIntervalType)
|
||||
calendarIntervalLiterGen.map { calendarIntervalLiteral =>
|
||||
Literal.create(
|
||||
calendarIntervalLiteral.value.asInstanceOf[CalendarInterval].extractAsDuration(),
|
||||
DayTimeIntervalType)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -711,4 +711,39 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper {
|
|||
intercept[IllegalArgumentException](getDayOfWeekFromString(UTF8String.fromString("xx")))
|
||||
intercept[IllegalArgumentException](getDayOfWeekFromString(UTF8String.fromString("\"quote")))
|
||||
}
|
||||
|
||||
test("SPARK-34761: timestamp add day-time interval") {
|
||||
// transit from Pacific Standard Time to Pacific Daylight Time
|
||||
assert(timestampAddDayTime(
|
||||
// 2019-3-9 is the end of Pacific Standard Time
|
||||
date(2019, 3, 9, 12, 0, 0, 123000, LA),
|
||||
MICROS_PER_DAY, LA) ===
|
||||
// 2019-3-10 is the start of Pacific Daylight Time
|
||||
date(2019, 3, 10, 12, 0, 0, 123000, LA))
|
||||
// just normal days
|
||||
outstandingZoneIds.foreach { zid =>
|
||||
assert(timestampAddDayTime(
|
||||
date(2021, 3, 18, 19, 44, 1, 100000, zid), 0, zid) ===
|
||||
date(2021, 3, 18, 19, 44, 1, 100000, zid))
|
||||
assert(timestampAddDayTime(
|
||||
date(2021, 1, 19, 0, 0, 0, 0, zid), -18 * MICROS_PER_DAY, zid) ===
|
||||
date(2021, 1, 1, 0, 0, 0, 0, zid))
|
||||
assert(timestampAddDayTime(
|
||||
date(2021, 3, 18, 19, 44, 1, 999999, zid), 10 * MICROS_PER_MINUTE, zid) ===
|
||||
date(2021, 3, 18, 19, 54, 1, 999999, zid))
|
||||
assert(timestampAddDayTime(
|
||||
date(2021, 3, 18, 19, 44, 1, 1, zid), -MICROS_PER_DAY - 1, zid) ===
|
||||
date(2021, 3, 17, 19, 44, 1, 0, zid))
|
||||
assert(timestampAddDayTime(
|
||||
date(2019, 5, 9, 12, 0, 0, 123456, zid), 2 * MICROS_PER_DAY + 1, zid) ===
|
||||
date(2019, 5, 11, 12, 0, 0, 123457, zid))
|
||||
}
|
||||
// transit from Pacific Daylight Time to Pacific Standard Time
|
||||
assert(timestampAddDayTime(
|
||||
// 2019-11-2 is the end of Pacific Daylight Time
|
||||
date(2019, 11, 2, 12, 0, 0, 123000, LA),
|
||||
MICROS_PER_DAY, LA) ===
|
||||
// 2019-11-3 is the start of Pacific Standard Time
|
||||
date(2019, 11, 3, 12, 0, 0, 123000, LA))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.apache.spark.sql
|
|||
|
||||
import java.sql.{Date, Timestamp}
|
||||
import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period}
|
||||
import java.time.temporal.ChronoUnit
|
||||
import java.util.Locale
|
||||
|
||||
import org.apache.hadoop.io.{LongWritable, Text}
|
||||
|
@ -2526,4 +2527,51 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
|
|||
assert(e.getMessage.contains("long overflow"))
|
||||
}
|
||||
}
|
||||
|
||||
test("SPARK-34761: add/subtract a day-time interval to/from a timestamp") {
|
||||
withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") {
|
||||
outstandingZoneIds.foreach { zid =>
|
||||
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> zid.getId) {
|
||||
Seq(
|
||||
(LocalDateTime.of(1900, 1, 1, 0, 0, 0, 123456000), Duration.ofDays(0)) ->
|
||||
LocalDateTime.of(1900, 1, 1, 0, 0, 0, 123456000),
|
||||
(LocalDateTime.of(1970, 1, 1, 0, 0, 0, 100000000), Duration.ofDays(-1)) ->
|
||||
LocalDateTime.of(1969, 12, 31, 0, 0, 0, 100000000),
|
||||
(LocalDateTime.of(2021, 3, 14, 1, 2, 3), Duration.ofDays(1)) ->
|
||||
LocalDateTime.of(2021, 3, 15, 1, 2, 3),
|
||||
(LocalDateTime.of(2020, 12, 31, 23, 59, 59, 999000000),
|
||||
Duration.ofDays(2 * 30).plusMillis(1)) -> LocalDateTime.of(2021, 3, 2, 0, 0, 0),
|
||||
(LocalDateTime.of(2020, 3, 16, 0, 0, 0, 1000), Duration.of(-1, ChronoUnit.MICROS)) ->
|
||||
LocalDateTime.of(2020, 3, 16, 0, 0, 0),
|
||||
(LocalDateTime.of(2020, 2, 29, 12, 13, 14), Duration.ofDays(365)) ->
|
||||
LocalDateTime.of(2021, 2, 28, 12, 13, 14),
|
||||
(LocalDateTime.of(1582, 10, 4, 1, 2, 3, 40000000),
|
||||
Duration.ofDays(10).plusMillis(60)) ->
|
||||
LocalDateTime.of(1582, 10, 14, 1, 2, 3, 100000000)
|
||||
).foreach { case ((ldt, duration), expected) =>
|
||||
val ts = ldt.atZone(zid).toInstant
|
||||
val result = expected.atZone(zid).toInstant
|
||||
val df = Seq((ts, duration, result)).toDF("ts", "interval", "result")
|
||||
checkAnswer(
|
||||
df.select($"ts" + $"interval", $"interval" + $"ts", $"result" - $"interval"),
|
||||
Row(result, result, ts))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Seq(
|
||||
"2021-03-16T18:56:00Z" -> "ts + i",
|
||||
"1900-03-16T18:56:00Z" -> "ts - i").foreach { case (instant, op) =>
|
||||
val e = intercept[SparkException] {
|
||||
Seq(
|
||||
(Instant.parse(instant), Duration.of(Long.MaxValue, ChronoUnit.MICROS)))
|
||||
.toDF("ts", "i")
|
||||
.selectExpr(op)
|
||||
.collect()
|
||||
}.getCause
|
||||
assert(e.isInstanceOf[ArithmeticException])
|
||||
assert(e.getMessage.contains("long overflow"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue