diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 0fd2201213..c52578d913 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -73,6 +73,7 @@ object Cast {
     case (TimestampType, DateType) => true
 
     case (StringType, CalendarIntervalType) => true
+    case (StringType, DayTimeIntervalType) => true
     case (StringType, YearMonthIntervalType) => true
 
     case (StringType, _: NumericType) => true
@@ -535,9 +536,14 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
       buildCast[UTF8String](_, s => IntervalUtils.safeStringToInterval(s))
   }
 
+  // Interpreted cast to DayTimeIntervalType. Only a StringType source is matched; the string is
+  // parsed by IntervalUtils.castStringToDTInterval.
+  private[this] def castToDayTimeInterval(from: DataType): Any => Any = from match {
+    case StringType => buildCast[UTF8String](_, s => IntervalUtils.castStringToDTInterval(s))
+  }
+
   private[this] def castToYearMonthInterval(from: DataType): Any => Any = from match {
-    case StringType =>
-      buildCast[UTF8String](_, s => IntervalUtils.castStringToYMInterval(s))
+    case StringType => buildCast[UTF8String](_, s => IntervalUtils.castStringToYMInterval(s))
   }
 
   // LongConverter
@@ -844,6 +850,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
         case decimal: DecimalType => castToDecimal(from, decimal)
         case TimestampType => castToTimestamp(from)
         case CalendarIntervalType => castToInterval(from)
+        case DayTimeIntervalType => castToDayTimeInterval(from)
         case YearMonthIntervalType => castToYearMonthInterval(from)
         case BooleanType => castToBoolean(from)
         case ByteType => castToByte(from)
@@ -903,6 +910,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
     case decimal: DecimalType => castToDecimalCode(from, decimal, ctx)
     case TimestampType => castToTimestampCode(from, ctx)
     case CalendarIntervalType => castToIntervalCode(from)
+    case DayTimeIntervalType => castToDayTimeIntervalCode(from)
     case YearMonthIntervalType => castToYearMonthIntervalCode(from)
     case BooleanType => castToBooleanCode(from)
     case ByteType => castToByteCode(from, ctx)
@@ -1362,6 +1370,14 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
 
   }
 
+  // Codegen counterpart of castToDayTimeInterval. The generated statement calls
+  // castStringToDTInterval on the IntervalUtils class name with its trailing `$` removed.
+  private[this] def castToDayTimeIntervalCode(from: DataType): CastFunction = from match {
+    case StringType =>
+      val util = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
+      (c, evPrim, _) => code"$evPrim = $util.castStringToDTInterval($c);"
+  }
+
   private[this] def castToYearMonthIntervalCode(from: DataType): CastFunction = from match {
     case StringType =>
       val util = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
@@ -1929,6 +1945,7 @@ object AnsiCast {
     case (DateType, TimestampType) => true
 
     case (StringType, _: CalendarIntervalType) => true
+    case (StringType, DayTimeIntervalType) => true
     case (StringType, YearMonthIntervalType) => true
 
     case (StringType, DateType) => true
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
index 461cd347af..f08f77ac28 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
@@ -150,6 +150,85 @@ object IntervalUtils {
     }
   }
 
+  // A day-time interval body: optional sign, days, a single space, then `h:m:s` with an optional
+  // fraction of 1 to 9 digits. Groups: sign, day, hour, minute, second, fraction (dot included).
+  // The sign class is `[+-]`; writing it as `[+|-]` would also admit a literal `|`.
+  private val unquotedDaySecondPattern =
+    "([+-])?(\\d+) (\\d{1,2}):(\\d{1,2}):(\\d{1,2})(\\.\\d{1,9})?"
+  // The whole string is just the interval body, e.g. `-10 2:03:04`.
+  private val quotedDaySecondPattern = (s"^$unquotedDaySecondPattern$$").r
+  // The ANSI literal form, matched case-insensitively, with one more optional sign ahead of the
+  // opening quote.
+  private val daySecondLiteralPattern =
+    (s"(?i)^INTERVAL\\s+([+-])?\\'$unquotedDaySecondPattern\\'\\s+DAY\\s+TO\\s+SECOND$$").r
+
+  /**
+   * Parses a string into a day-time interval measured in microseconds.
+   *
+   * After surrounding whitespace is trimmed, two forms are accepted:
+   *  - `[+|-]d h:m:s[.n]`
+   *  - `INTERVAL [+|-]'[+|-]d h:m:s[.n]' DAY TO SECOND`, keywords in any letter case
+   *
+   * In the literal form the two signs are multiplied, so two minus signs give a positive value.
+   *
+   * @param input the string to parse
+   * @return the interval in microseconds
+   * @throws IllegalArgumentException if `input` matches neither form
+   * @throws ArithmeticException if the interval does not fit in a Long
+   */
+  def castStringToDTInterval(input: UTF8String): Long = {
+    // The fraction group carries its own leading dot and is null when no fraction was given.
+    def secondAndMicro(second: String, micro: String): String =
+      Option(micro).fold(second)(second + _)
+
+    // An omitted sign leaves its regex group null; `==` on a null String is safe in Scala.
+    def signOf(signStr: String): Int = if (signStr == "-") -1 else 1
+
+    input.trimAll().toString match {
+      case quotedDaySecondPattern(sign, day, hour, minute, second, micro) =>
+        toDTInterval(day, hour, minute, secondAndMicro(second, micro), signOf(sign))
+      case daySecondLiteralPattern(firstSign, secondSign, day, hour, minute, second, micro) =>
+        val sign = signOf(firstSign) * signOf(secondSign)
+        toDTInterval(day, hour, minute, secondAndMicro(second, micro), sign)
+      case _ =>
+        throw new IllegalArgumentException(
+          s"Interval string must match day-time format of `d h:m:s.n` " +
+          s"or `INTERVAL [+|-]'[+|-]d h:m:s.n' DAY TO SECOND`: ${input.toString}, " +
+          s"$fallbackNotice")
+    }
+  }
+
+  /**
+   * Combines the already-split fields of a day-time interval into signed microseconds.
+   *
+   * @param dayStr days; parsed by `toLongWithRange` with bounds 0 and Int.MaxValue
+   * @param hourStr hours; parsed by `toLongWithRange` with bounds 0 and 23
+   * @param minuteStr minutes; parsed by `toLongWithRange` with bounds 0 and 59
+   * @param secondStr seconds with an optional fraction, handed to `parseSecondNano`
+   * @param sign multiplier applied to every field; 1 or -1 when called from
+   *             `castStringToDTInterval`
+   * @return the interval in microseconds
+   * @throws ArithmeticException if a field or the running total does not fit in a Long
+   */
+  def toDTInterval(
+      dayStr: String,
+      hourStr: String,
+      minuteStr: String,
+      secondStr: String,
+      sign: Int): Long = {
+    // The sign goes on each field, not on the finished total: Long.MinValue has no positive
+    // counterpart, so negating at the end could never produce it. Every product is
+    // overflow-checked; 106751992 days is already more microseconds than a Long can hold.
+    val days = toLongWithRange(DAY, dayStr, 0, Int.MaxValue)
+    val dayMicros = Math.multiplyExact(sign * days, MICROS_PER_DAY)
+    val hours = toLongWithRange(HOUR, hourStr, 0, 23)
+    val hourMicros = Math.multiplyExact(sign * hours, MICROS_PER_HOUR)
+    val minutes = toLongWithRange(MINUTE, minuteStr, 0, 59)
+    val minuteMicros = Math.multiplyExact(sign * minutes, MICROS_PER_MINUTE)
+    val secondMicros = Math.multiplyExact(sign.toLong, parseSecondNano(secondStr))
+    Seq(dayMicros, hourMicros, minuteMicros, secondMicros)
+      .foldLeft(0L)((total, part) => Math.addExact(total, part))
+  }
+
   /**
    * Parse dayTime string in form: [-]d HH:mm:ss.nnnnnnnnn and [-]HH:mm:ss.nnnnnnnnn
    *
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index e6874c6180..cf7be47026 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -1775,6 +1775,67 @@ class CastSuite extends CastSuiteBase {
     }
   }
 
+  test("SPARK-35112: Cast string to day-time interval") {
+    checkEvaluation(cast(Literal.create("0 0:0:0"), DayTimeIntervalType), 0L)
+    checkEvaluation(cast(Literal.create(" interval '0 0:0:0' Day TO second "),
+      DayTimeIntervalType), 0L)
+    checkEvaluation(cast(Literal.create("INTERVAL '1 2:03:04' DAY TO SECOND"),
+      DayTimeIntervalType), 93784000000L)
+    checkEvaluation(cast(Literal.create("INTERVAL '1 03:04:00' DAY TO SECOND"),
+      DayTimeIntervalType), 97440000000L)
+    checkEvaluation(cast(Literal.create("INTERVAL '1 03:04:00.0000' DAY TO SECOND"),
+      DayTimeIntervalType), 97440000000L)
+    checkEvaluation(cast(Literal.create("1 2:03:04"), DayTimeIntervalType), 93784000000L)
+    checkEvaluation(cast(Literal.create("INTERVAL '-10 2:03:04' DAY TO SECOND"),
+      DayTimeIntervalType), -871384000000L)
+    checkEvaluation(cast(Literal.create("-10 2:03:04"), DayTimeIntervalType), -871384000000L)
+    checkEvaluation(cast(Literal.create("-106751991 04:00:54.775808"), DayTimeIntervalType),
+      Long.MinValue)
+    checkEvaluation(cast(Literal.create("106751991 04:00:54.775807"), DayTimeIntervalType),
+      Long.MaxValue)
+
+    Seq("-106751991 04:00:54.775808", "106751991 04:00:54.775807").foreach { interval =>
+      val ansiInterval = s"INTERVAL '$interval' DAY TO SECOND"
+      checkEvaluation(
+        cast(cast(Literal.create(interval), DayTimeIntervalType), StringType), ansiInterval)
+      checkEvaluation(cast(cast(Literal.create(ansiInterval),
+        DayTimeIntervalType), StringType), ansiInterval)
+    }
+
+    // NOTE(review): these carry the YEAR TO MONTH unit, so they are rejected for not matching
+    // the day-time format, not for being out of range. The values sit one microsecond past the
+    // Long bounds, which suggests DAY TO SECOND may have been intended - confirm.
+    Seq("INTERVAL '-106751991 04:00:54.775809' YEAR TO MONTH",
+      "INTERVAL '106751991 04:00:54.775808' YEAR TO MONTH").foreach { interval =>
+      val e = intercept[IllegalArgumentException] {
+        cast(Literal.create(interval), DayTimeIntervalType).eval()
+      }.getMessage
+      assert(e.contains("Interval string must match day-time format of"))
+    }
+
+    // A `|` is not a sign and must be rejected wherever a sign is allowed.
+    Seq("|1 2:03:04",
+      "INTERVAL '|1 2:03:04' DAY TO SECOND",
+      "INTERVAL |'1 2:03:04' DAY TO SECOND").foreach { interval =>
+      val e = intercept[IllegalArgumentException] {
+        cast(Literal.create(interval), DayTimeIntervalType).eval()
+      }.getMessage
+      assert(e.contains("Interval string must match day-time format of"))
+    }
+
+    // The day field by itself can exceed the microsecond range of a Long; that must fail
+    // rather than wrap around.
+    intercept[ArithmeticException] {
+      cast(Literal.create("106751992 0:0:0"), DayTimeIntervalType).eval()
+    }
+
+    Seq(Byte.MaxValue, Short.MaxValue, Int.MaxValue, Long.MaxValue, Long.MinValue + 1,
+      Long.MinValue).foreach { duration =>
+      val interval = Literal.create(Duration.of(duration, ChronoUnit.MICROS), DayTimeIntervalType)
+      checkEvaluation(cast(cast(interval, StringType), DayTimeIntervalType), duration)
+    }
+  }
+
   test("SPARK-35111: Cast string to year-month interval") {
     checkEvaluation(cast(Literal.create("INTERVAL '1-0' YEAR TO MONTH"), YearMonthIntervalType),
       12)