From 42275bb20d6849ee9df488d9ec1fa402f394ac89 Mon Sep 17 00:00:00 2001 From: gengjiaan Date: Sun, 18 Jul 2021 20:46:23 +0300 Subject: [PATCH] [SPARK-36090][SQL] Support TimestampNTZType in expression Sequence ### What changes were proposed in this pull request? The current implementation of `Sequence` accepts `TimestampType`, `DateType` and `IntegralType`. This PR lets `Sequence` also accept `TimestampNTZType`. ### Why are the changes needed? So that we can generate sequences of timestamps without time zone. ### Does this PR introduce _any_ user-facing change? 'Yes'. This PR lets `Sequence` accept `TimestampNTZType`. ### How was this patch tested? New tests. Closes #33360 from beliefer/SPARK-36090. Lead-authored-by: gengjiaan Co-authored-by: Jiaan Geng Signed-off-by: Max Gekk --- .../expressions/collectionOperations.scala | 48 ++++--- .../sql/catalyst/util/DateTimeUtils.scala | 21 ++- .../CollectionExpressionsSuite.scala | 122 +++++++++++++++++- 3 files changed, 172 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 2883d8dded..730b8d0f34 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -2568,7 +2568,7 @@ case class Sequence( val typesCorrect = startType.sameType(stop.dataType) && (startType match { - case TimestampType => + case TimestampType | TimestampNTZType => stepOpt.isEmpty || CalendarIntervalType.acceptsType(stepType) || YearMonthIntervalType.acceptsType(stepType) || DayTimeIntervalType.acceptsType(stepType) @@ -2614,20 +2614,20 @@ case class Sequence( val ct = ClassTag[T](iType.tag.mirror.runtimeClass(iType.tag.tpe)) new IntegralSequenceImpl(iType)(ct, iType.integral) - case TimestampType => + case 
TimestampType | TimestampNTZType => if (stepOpt.isEmpty || CalendarIntervalType.acceptsType(stepOpt.get.dataType)) { - new TemporalSequenceImpl[Long](LongType, 1, identity, zoneId) + new TemporalSequenceImpl[Long](LongType, start.dataType, 1, identity, zoneId) } else if (YearMonthIntervalType.acceptsType(stepOpt.get.dataType)) { - new PeriodSequenceImpl[Long](LongType, 1, identity, zoneId) + new PeriodSequenceImpl[Long](LongType, start.dataType, 1, identity, zoneId) } else { - new DurationSequenceImpl[Long](LongType, 1, identity, zoneId) + new DurationSequenceImpl[Long](LongType, start.dataType, 1, identity, zoneId) } case DateType => if (stepOpt.isEmpty || CalendarIntervalType.acceptsType(stepOpt.get.dataType)) { - new TemporalSequenceImpl[Int](IntegerType, MICROS_PER_DAY, _.toInt, zoneId) + new TemporalSequenceImpl[Int](IntegerType, start.dataType, MICROS_PER_DAY, _.toInt, zoneId) } else { - new PeriodSequenceImpl[Int](IntegerType, MICROS_PER_DAY, _.toInt, zoneId) + new PeriodSequenceImpl[Int](IntegerType, start.dataType, MICROS_PER_DAY, _.toInt, zoneId) } } @@ -2769,8 +2769,9 @@ object Sequence { } private class PeriodSequenceImpl[T: ClassTag] - (dt: IntegralType, scale: Long, fromLong: Long => T, zoneId: ZoneId) - (implicit num: Integral[T]) extends InternalSequenceBase(dt, scale, fromLong, zoneId) { + (dt: IntegralType, outerDataType: DataType, scale: Long, fromLong: Long => T, zoneId: ZoneId) + (implicit num: Integral[T]) + extends InternalSequenceBase(dt, outerDataType, scale, fromLong, zoneId) { override val defaultStep: DefaultStep = new DefaultStep( (dt.ordering.lteq _).asInstanceOf[LessThanOrEqualFn], @@ -2794,8 +2795,9 @@ object Sequence { } private class DurationSequenceImpl[T: ClassTag] - (dt: IntegralType, scale: Long, fromLong: Long => T, zoneId: ZoneId) - (implicit num: Integral[T]) extends InternalSequenceBase(dt, scale, fromLong, zoneId) { + (dt: IntegralType, outerDataType: DataType, scale: Long, fromLong: Long => T, zoneId: ZoneId) + (implicit 
num: Integral[T]) + extends InternalSequenceBase(dt, outerDataType, scale, fromLong, zoneId) { override val defaultStep: DefaultStep = new DefaultStep( (dt.ordering.lteq _).asInstanceOf[LessThanOrEqualFn], @@ -2819,8 +2821,9 @@ object Sequence { } private class TemporalSequenceImpl[T: ClassTag] - (dt: IntegralType, scale: Long, fromLong: Long => T, zoneId: ZoneId) - (implicit num: Integral[T]) extends InternalSequenceBase(dt, scale, fromLong, zoneId) { + (dt: IntegralType, outerDataType: DataType, scale: Long, fromLong: Long => T, zoneId: ZoneId) + (implicit num: Integral[T]) + extends InternalSequenceBase(dt, outerDataType, scale, fromLong, zoneId) { override val defaultStep: DefaultStep = new DefaultStep( (dt.ordering.lteq _).asInstanceOf[LessThanOrEqualFn], @@ -2845,7 +2848,7 @@ object Sequence { } private abstract class InternalSequenceBase[T: ClassTag] - (dt: IntegralType, scale: Long, fromLong: Long => T, zoneId: ZoneId) + (dt: IntegralType, outerDataType: DataType, scale: Long, fromLong: Long => T, zoneId: ZoneId) (implicit num: Integral[T]) extends InternalSequence { val defaultStep: DefaultStep @@ -2859,6 +2862,11 @@ object Sequence { protected def splitStep(input: Any): (Int, Int, Long) + private val addInterval: (Long, Int, Int, Long, ZoneId) => Long = outerDataType match { + case TimestampType | DateType => timestampAddInterval + case TimestampNTZType => timestampNTZAddInterval + } + override def eval(input1: Any, input2: Any, input3: Any): Array[T] = { val start = input1.asInstanceOf[T] val stop = input2.asInstanceOf[T] @@ -2897,8 +2905,7 @@ object Sequence { while (t < exclusiveItem ^ stepSign < 0) { arr(i) = fromLong(t / scale) i += 1 - t = timestampAddInterval( - startMicros, i * stepMonths, i * stepDays, i * stepMicros, zoneId) + t = addInterval(startMicros, i * stepMonths, i * stepDays, i * stepMicros, zoneId) } // truncate array to the correct length @@ -2909,6 +2916,13 @@ object Sequence { protected def stepSplitCode( stepMonths: String, 
stepDays: String, stepMicros: String, step: String): String + private val addIntervalCode = outerDataType match { + case TimestampType | DateType => + "org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampAddInterval" + case TimestampNTZType => + "org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampNTZAddInterval" + } + override def genCode( ctx: CodegenContext, start: String, @@ -2978,7 +2992,7 @@ object Sequence { | while ($t < $exclusiveItem ^ $stepSign < 0) { | $arr[$i] = ($elemType) ($t / ${scale}L); | $i += 1; - | $t = org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampAddInterval( + | $t = $addIntervalCode( | $startMicros, $i * $stepMonths, $i * $stepDays, $i * $stepMicros, $zid); | } | diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index ae444ebf3d..0825a115e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -728,7 +728,7 @@ object DateTimeUtils { } /** - * Adds a full interval (months, days, microseconds) a timestamp represented as the number of + * Adds a full interval (months, days, microseconds) to a timestamp represented as the number of * microseconds since 1970-01-01 00:00:00Z. * @return A timestamp value, expressed in microseconds since 1970-01-01 00:00:00Z. */ @@ -746,6 +746,25 @@ object DateTimeUtils { instantToMicros(resultTimestamp.toInstant) } + /** + * Adds a full interval (months, days, microseconds) to a timestamp without time zone + * represented as a local time in microsecond precision, which is independent of time zone. + * @return A timestamp without time zone value, expressed in range + * [0001-01-01T00:00:00.000000, 9999-12-31T23:59:59.999999]. 
+ */ + def timestampNTZAddInterval( + start: Long, + months: Int, + days: Int, + microseconds: Long, + zoneId: ZoneId): Long = { + val localDateTime = microsToLocalDateTime(start) + .plusMonths(months) + .plusDays(days) + .plus(microseconds, ChronoUnit.MICROS) + localDateTimeToMicros(localDateTime) + } + /** * Adds the interval's months and days to a date expressed as days since the epoch. * @return A date value, expressed in days since 1970-01-01. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 456ccafa57..bfecbf5766 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} -import java.time.{Duration, Period} +import java.time.{Duration, LocalDateTime, Period} import java.util.TimeZone import scala.language.implicitConversions @@ -1116,6 +1116,126 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper } } + test("SPARK-36090: Support TimestampNTZType in expression Sequence") { + checkEvaluation(new Sequence( + Literal(LocalDateTime.parse("2018-01-01T00:00:00")), + Literal(LocalDateTime.parse("2018-01-02T00:00:00")), + Literal(Duration.ofHours(12))), + Seq( + LocalDateTime.parse("2018-01-01T00:00:00"), + LocalDateTime.parse("2018-01-01T12:00:00"), + LocalDateTime.parse("2018-01-02T00:00:00"))) + + checkEvaluation(new Sequence( + Literal(LocalDateTime.parse("2018-01-01T00:00:00")), + Literal(LocalDateTime.parse("2018-01-02T00:00:01")), + Literal(Duration.ofHours(12))), + Seq( + LocalDateTime.parse("2018-01-01T00:00:00"), + LocalDateTime.parse("2018-01-01T12:00:00"), + 
LocalDateTime.parse("2018-01-02T00:00:00"))) + + checkEvaluation(new Sequence( + Literal(LocalDateTime.parse("2018-01-02T00:00:00")), + Literal(LocalDateTime.parse("2018-01-01T00:00:00")), + Literal(Duration.ofHours(-12))), + Seq( + LocalDateTime.parse("2018-01-02T00:00:00"), + LocalDateTime.parse("2018-01-01T12:00:00"), + LocalDateTime.parse("2018-01-01T00:00:00"))) + + checkEvaluation(new Sequence( + Literal(LocalDateTime.parse("2018-01-02T00:00:00")), + Literal(LocalDateTime.parse("2017-12-31T23:59:59")), + Literal(Duration.ofHours(-12))), + Seq( + LocalDateTime.parse("2018-01-02T00:00:00"), + LocalDateTime.parse("2018-01-01T12:00:00"), + LocalDateTime.parse("2018-01-01T00:00:00"))) + + checkEvaluation(new Sequence( + Literal(LocalDateTime.parse("2018-01-01T00:00:00")), + Literal(LocalDateTime.parse("2018-03-01T00:00:00")), + Literal(Period.ofMonths(1))), + Seq( + LocalDateTime.parse("2018-01-01T00:00:00"), + LocalDateTime.parse("2018-02-01T00:00:00"), + LocalDateTime.parse("2018-03-01T00:00:00"))) + + checkEvaluation(new Sequence( + Literal(LocalDateTime.parse("2018-03-01T00:00:00")), + Literal(LocalDateTime.parse("2018-01-01T00:00:00")), + Literal(Period.ofMonths(-1))), + Seq( + LocalDateTime.parse("2018-03-01T00:00:00"), + LocalDateTime.parse("2018-02-01T00:00:00"), + LocalDateTime.parse("2018-01-01T00:00:00"))) + + checkEvaluation(new Sequence( + Literal(LocalDateTime.parse("2018-01-31T00:00:00")), + Literal(LocalDateTime.parse("2018-04-30T00:00:00")), + Literal(Period.ofMonths(1))), + Seq( + LocalDateTime.parse("2018-01-31T00:00:00"), + LocalDateTime.parse("2018-02-28T00:00:00"), + LocalDateTime.parse("2018-03-31T00:00:00"), + LocalDateTime.parse("2018-04-30T00:00:00"))) + + checkEvaluation(new Sequence( + Literal(LocalDateTime.parse("2018-01-01T00:00:00")), + Literal(LocalDateTime.parse("2023-01-01T00:00:00")), + Literal(Period.of(1, 5, 0))), + Seq( + LocalDateTime.parse("2018-01-01T00:00:00.000"), + LocalDateTime.parse("2019-06-01T00:00:00.000"), + 
LocalDateTime.parse("2020-11-01T00:00:00.000"), + LocalDateTime.parse("2022-04-01T00:00:00.000"))) + + checkEvaluation(new Sequence( + Literal(LocalDateTime.parse("2022-04-01T00:00:00")), + Literal(LocalDateTime.parse("2017-01-01T00:00:00")), + Literal(Period.of(-1, -5, 0))), + Seq( + LocalDateTime.parse("2022-04-01T00:00:00.000"), + LocalDateTime.parse("2020-11-01T00:00:00.000"), + LocalDateTime.parse("2019-06-01T00:00:00.000"), + LocalDateTime.parse("2018-01-01T00:00:00.000"))) + + checkEvaluation(new Sequence( + Literal(LocalDateTime.parse("2018-01-01T00:00:00")), + Literal(LocalDateTime.parse("2018-01-04T00:00:00")), + Literal(Duration.ofDays(1))), + Seq( + LocalDateTime.parse("2018-01-01T00:00:00.000"), + LocalDateTime.parse("2018-01-02T00:00:00.000"), + LocalDateTime.parse("2018-01-03T00:00:00.000"), + LocalDateTime.parse("2018-01-04T00:00:00.000"))) + + checkEvaluation(new Sequence( + Literal(LocalDateTime.parse("2018-01-04T00:00:00")), + Literal(LocalDateTime.parse("2018-01-01T00:00:00")), + Literal(Duration.ofDays(-1))), + Seq( + LocalDateTime.parse("2018-01-04T00:00:00.000"), + LocalDateTime.parse("2018-01-03T00:00:00.000"), + LocalDateTime.parse("2018-01-02T00:00:00.000"), + LocalDateTime.parse("2018-01-01T00:00:00.000"))) + + checkExceptionInExpression[IllegalArgumentException]( + new Sequence( + Literal(LocalDateTime.parse("2018-01-01T00:00:00")), + Literal(LocalDateTime.parse("2018-01-04T00:00:00")), + Literal(Period.ofDays(1))), + EmptyRow, s"sequence boundaries: 1514764800000000 to 1515024000000000 by 0") + + checkExceptionInExpression[IllegalArgumentException]( + new Sequence( + Literal(LocalDateTime.parse("2018-01-04T00:00:00")), + Literal(LocalDateTime.parse("2018-01-01T00:00:00")), + Literal(Period.ofDays(-1))), + EmptyRow, s"sequence boundaries: 1515024000000000 to 1514764800000000 by 0") + } + test("Sequence with default step") { // +/- 1 for integral type checkEvaluation(new Sequence(Literal(1), Literal(3)), Seq(1, 2, 3))