[SPARK-35720][SQL] Support casting of String to timestamp without time zone type

### What changes were proposed in this pull request?

Extend the Cast expression and support StringType in casting to TimestampWithoutTZType.

Closes #32898

### Why are the changes needed?

To conform to the ANSI SQL standard, which requires support for such casting.

### Does this PR introduce _any_ user-facing change?

No, the new timestamp type is not released yet.

### How was this patch tested?

Unit test

Closes #32936 from gengliangwang/castStringToTswtz.

Authored-by: Gengliang Wang <gengliang@apache.org>
Signed-off-by: Gengliang Wang <gengliang@apache.org>
This commit is contained in:
Gengliang Wang 2021-06-18 02:02:10 +08:00
parent 79362c4efc
commit 05e2b76852
5 changed files with 180 additions and 15 deletions

View file

@ -70,6 +70,7 @@ object Cast {
case (_: NumericType, TimestampType) => true
case (TimestampWithoutTZType, TimestampType) => true
case (StringType, TimestampWithoutTZType) => true
case (DateType, TimestampWithoutTZType) => true
case (TimestampType, TimestampWithoutTZType) => true
@ -513,6 +514,14 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
}
private[this] def castToTimestampWithoutTZ(from: DataType): Any => Any = from match {
case StringType =>
buildCast[UTF8String](_, utfs => {
if (ansiEnabled) {
DateTimeUtils.stringToTimestampWithoutTimeZoneAnsi(utfs)
} else {
DateTimeUtils.stringToTimestampWithoutTimeZone(utfs).orNull
}
})
case DateType =>
buildCast[Int](_, d => daysToMicros(d, ZoneOffset.UTC))
case TimestampType =>
@ -1410,6 +1419,24 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
private[this] def castToTimestampWithoutTZCode(
from: DataType,
ctx: CodegenContext): CastFunction = from match {
case StringType =>
val longOpt = ctx.freshVariable("longOpt", classOf[Option[Long]])
(c, evPrim, evNull) =>
if (ansiEnabled) {
code"""
$evPrim =
$dateTimeUtilsCls.stringToTimestampWithoutTimeZoneAnsi($c);
"""
} else {
code"""
scala.Option<Long> $longOpt = $dateTimeUtilsCls.stringToTimestampWithoutTimeZone($c);
if ($longOpt.isDefined()) {
$evPrim = ((Long) $longOpt.get()).longValue();
} else {
$evNull = true;
}
"""
}
case DateType =>
(c, evPrim, evNull) =>
code"$evPrim = $dateTimeUtilsCls.daysToMicros($c, java.time.ZoneOffset.UTC);"
@ -2016,6 +2043,7 @@ object AnsiCast {
case (DateType, TimestampType) => true
case (TimestampWithoutTZType, TimestampType) => true
case (StringType, TimestampWithoutTZType) => true
case (DateType, TimestampWithoutTZType) => true
case (TimestampType, TimestampWithoutTZType) => true

View file

@ -30,7 +30,7 @@ import sun.util.calendar.ZoneInfo
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.RebaseDateTime._
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types.{DateType, Decimal, TimestampType}
import org.apache.spark.sql.types.{DateType, Decimal, TimestampType, TimestampWithoutTZType}
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
/**
@ -219,7 +219,8 @@ object DateTimeUtils {
def cleanLegacyTimestampStr(s: UTF8String): UTF8String = s.replace(gmtUtf8, UTF8String.EMPTY_UTF8)
/**
* Trims and parses a given UTF8 timestamp string to the corresponding a corresponding [[Long]]
* Trims and parses a given UTF8 timestamp string to the corresponding timestamp segments,
* time zone id and whether it is just time without a date.
* value. The return type is [[Option]] in order to distinguish between 0L and null. The following
* formats are allowed:
*
@ -243,10 +244,13 @@ object DateTimeUtils {
* - +|-hh:mm:ss
* - +|-hhmmss
* - Region-based zone IDs in the form `area/city`, such as `Europe/Paris`
*
* @return timestamp segments, time zone id and whether the input is just time without a date. If
* the input string can't be parsed as timestamp, the result timestamp segments are empty.
*/
def stringToTimestamp(s: UTF8String, timeZoneId: ZoneId): Option[Long] = {
private def parseTimestampString(s: UTF8String): (Array[Int], Option[ZoneId], Boolean) = {
if (s == null) {
return None
return (Array.empty, None, false)
}
var tz: Option[String] = None
val segments: Array[Int] = Array[Int](1, 1, 1, 0, 0, 0, 0, 0, 0)
@ -267,7 +271,7 @@ object DateTimeUtils {
if (b == '-') {
if (i == 0 && j != 4) {
// year should have exact four digits
return None
return (Array.empty, None, false)
}
segments(i) = currentSegmentValue
currentSegmentValue = 0
@ -278,7 +282,7 @@ object DateTimeUtils {
currentSegmentValue = 0
i = 4
} else {
return None
return (Array.empty, None, false)
}
} else if (i == 2) {
if (b == ' ' || b == 'T') {
@ -286,7 +290,7 @@ object DateTimeUtils {
currentSegmentValue = 0
i += 1
} else {
return None
return (Array.empty, None, false)
}
} else if (i == 3 || i == 4) {
if (b == ':') {
@ -294,7 +298,7 @@ object DateTimeUtils {
currentSegmentValue = 0
i += 1
} else {
return None
return (Array.empty, None, false)
}
} else if (i == 5 || i == 6) {
if (b == '-' || b == '+') {
@ -322,7 +326,7 @@ object DateTimeUtils {
currentSegmentValue = 0
i += 1
} else {
return None
return (Array.empty, None, false)
}
}
} else {
@ -337,7 +341,7 @@ object DateTimeUtils {
segments(i) = currentSegmentValue
if (!justTime && i == 0 && j != 4) {
// year should have exact four digits
return None
return (Array.empty, None, false)
}
while (digitsMilli < 6) {
@ -350,13 +354,48 @@ object DateTimeUtils {
segments(6) /= 10
digitsMilli -= 1
}
// This step also validates time zone part
val zoneId = tz.map {
case "+" => ZoneOffset.ofHoursMinutes(segments(7), segments(8))
case "-" => ZoneOffset.ofHoursMinutes(-segments(7), -segments(8))
case zoneName: String => getZoneId(zoneName.trim)
}
(segments, zoneId, justTime)
}
/**
* Trims and parses a given UTF8 timestamp string to a corresponding [[Long]]
* value. The return type is [[Option]] in order to distinguish between 0L and null. The following
* formats are allowed:
*
* `yyyy`
* `yyyy-[m]m`
* `yyyy-[m]m-[d]d`
* `yyyy-[m]m-[d]d `
* `yyyy-[m]m-[d]d [h]h:[m]m:[s]s.[ms][ms][ms][us][us][us][zone_id]`
* `yyyy-[m]m-[d]dT[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us][zone_id]`
* `[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us][zone_id]`
* `T[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us][zone_id]`
*
* where `zone_id` should have one of the forms:
* - Z - Zulu time zone UTC+0
* - +|-[h]h:[m]m
* - A short id, see https://docs.oracle.com/javase/8/docs/api/java/time/ZoneId.html#SHORT_IDS
* - An id with one of the prefixes UTC+, UTC-, GMT+, GMT-, UT+ or UT-,
* and a suffix in the formats:
* - +|-h[h]
* - +|-hh[:]mm
* - +|-hh:mm:ss
* - +|-hhmmss
* - Region-based zone IDs in the form `area/city`, such as `Europe/Paris`
*/
def stringToTimestamp(s: UTF8String, timeZoneId: ZoneId): Option[Long] = {
try {
val zoneId = tz match {
case None => timeZoneId
case Some("+") => ZoneOffset.ofHoursMinutes(segments(7), segments(8))
case Some("-") => ZoneOffset.ofHoursMinutes(-segments(7), -segments(8))
case Some(zoneName: String) => getZoneId(zoneName.trim)
val (segments, parsedZoneId, justTime) = parseTimestampString(s)
if (segments.isEmpty) {
return None
}
val zoneId = parsedZoneId.getOrElse(timeZoneId)
val nanoseconds = MICROSECONDS.toNanos(segments(6))
val localTime = LocalTime.of(segments(3), segments(4), segments(5), nanoseconds.toInt)
val localDate = if (justTime) {
@ -378,6 +417,60 @@ object DateTimeUtils {
throw QueryExecutionErrors.cannotCastUTF8StringToDataTypeError(s, TimestampType)
}
}
/**
* Trims and parses a given UTF8 string to a corresponding [[Long]] value representing the
* number of microseconds since the epoch. The result is independent of time zones,
* which means that zone ID in the input string will be ignored.
* The return type is [[Option]] in order to distinguish between 0L and null. The following
* formats are allowed:
*
* `yyyy`
* `yyyy-[m]m`
* `yyyy-[m]m-[d]d`
* `yyyy-[m]m-[d]d `
* `yyyy-[m]m-[d]d [h]h:[m]m:[s]s.[ms][ms][ms][us][us][us][zone_id]`
* `yyyy-[m]m-[d]dT[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us][zone_id]`
*
* where `zone_id` should have one of the forms:
* - Z - Zulu time zone UTC+0
* - +|-[h]h:[m]m
* - A short id, see https://docs.oracle.com/javase/8/docs/api/java/time/ZoneId.html#SHORT_IDS
* - An id with one of the prefixes UTC+, UTC-, GMT+, GMT-, UT+ or UT-,
* and a suffix in the formats:
* - +|-h[h]
* - +|-hh[:]mm
* - +|-hh:mm:ss
* - +|-hhmmss
* - Region-based zone IDs in the form `area/city`, such as `Europe/Paris`
*
* Note: The input string has to contain year/month/day fields; otherwise Spark can't determine
* the value of timestamp without time zone.
*/
def stringToTimestampWithoutTimeZone(s: UTF8String): Option[Long] = {
  try {
    // The parsed zone id (second tuple element) is deliberately discarded: the result does not
    // depend on any time zone.
    val (fields, _, timeOnly) = parseTimestampString(s)
    if (fields.nonEmpty && !timeOnly) {
      // fields(0..2) hold year/month/day, fields(3..5) hold hour/minute/second and
      // fields(6) holds the fractional second in microseconds.
      val nanosOfSecond = MICROSECONDS.toNanos(fields(6)).toInt
      val timePart = LocalTime.of(fields(3), fields(4), fields(5), nanosOfSecond)
      val datePart = LocalDate.of(fields(0), fields(1), fields(2))
      Some(localDateTimeToMicros(LocalDateTime.of(datePart, timePart)))
    } else {
      // Either the string could not be parsed at all (empty fields), or it carries only a
      // time-of-day and there is no date to anchor it to.
      None
    }
  } catch {
    // Any non-fatal failure while parsing or building the date/time yields "no value".
    case NonFatal(_) => None
  }
}
/**
 * Strict variant of [[stringToTimestampWithoutTimeZone]]: returns the parsed microsecond value,
 * or throws the error built by `QueryExecutionErrors.cannotCastUTF8StringToDataTypeError` when
 * the input cannot be converted.
 */
def stringToTimestampWithoutTimeZoneAnsi(s: UTF8String): Long = {
  stringToTimestampWithoutTimeZone(s) match {
    case Some(micros) => micros
    case None =>
      throw QueryExecutionErrors.cannotCastUTF8StringToDataTypeError(s, TimestampWithoutTZType)
  }
}
// See issue SPARK-35679
// min second cause overflow in instant to micro
private val MIN_SECONDS = Math.floorDiv(Long.MinValue, MICROS_PER_SECOND)

View file

@ -401,6 +401,19 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase {
checkExceptionInExpression[ArithmeticException](
cast(cast(Literal("2147483648"), FloatType), IntegerType), "overflow")
}
// Each malformed input below must make the cast raise a DateTimeException whose message
// names the offending input and the target type.
test("SPARK-35720: cast invalid string input to timestamp without time zone") {
  val malformedInputs = Seq(
    "00:00:00",
    "a",
    "123",
    "a2021-06-17",
    "2021-06-17abc",
    "2021-06-17 00:00:00ABC")
  for (invalidInput <- malformedInputs) {
    val castExpr = cast(invalidInput, TimestampWithoutTZType)
    checkExceptionInExpression[DateTimeException](
      castExpr,
      s"Cannot cast $invalidInput to TimestampWithoutTZType")
  }
}
}
/**

View file

@ -651,4 +651,15 @@ class CastSuite extends CastSuiteBase {
checkEvaluation(cast(cast(interval, StringType), YearMonthIntervalType()), period)
}
}
// Each malformed input below must evaluate to null when cast to TimestampWithoutTZType.
test("SPARK-35720: cast invalid string input to timestamp without time zone") {
  val malformedInputs = Seq(
    "00:00:00",
    "a",
    "123",
    "a2021-06-17",
    "2021-06-17abc",
    "2021-06-17 00:00:00ABC")
  for (invalidInput <- malformedInputs) {
    val castExpr = cast(invalidInput, TimestampWithoutTZType)
    checkEvaluation(castExpr, null)
  }
}
}

View file

@ -929,4 +929,24 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
verifyCastFailure(cast(timestampWithoutTZLiteral, numericType), Some(errorMsg))
}
}
// Valid timestamp strings must cast to the LocalDateTime they spell out, regardless of
// surrounding spaces or a trailing zone id.
test("SPARK-35720: cast string to timestamp without timezone") {
  for (tsText <- specialTs) {
    val wanted = LocalDateTime.parse(tsText)
    // Plain input.
    checkEvaluation(cast(tsText, TimestampWithoutTZType), wanted)
    // Leading and trailing spaces are trimmed before casting.
    val padded = " " + tsText + " "
    checkEvaluation(cast(padded, TimestampWithoutTZType), wanted)
    // A trailing zone id must not change the result.
    for (zoneId <- outstandingZoneIds) {
      val zoneSuffix = zoneId.toString
      checkEvaluation(cast(tsText + zoneSuffix, TimestampWithoutTZType), wanted)
      // Same check with a six-digit fractional second appended.
      val fractionalText = tsText + ".123456"
      val wantedFractional = LocalDateTime.parse(fractionalText)
      checkEvaluation(
        cast(fractionalText + zoneSuffix, TimestampWithoutTZType),
        wantedFractional)
    }
  }
  // A date-only string is accepted and maps to midnight of that date.
  val midnight = LocalDateTime.of(2021, 6, 17, 0, 0)
  checkEvaluation(cast("2021-06-17", TimestampWithoutTZType), midnight)
}
}