[SPARK-35720][SQL] Support casting of String to timestamp without time zone type
### What changes were proposed in this pull request? Extend the Cast expression and support StringType in casting to TimestampWithoutTZType. Closes #32898 ### Why are the changes needed? To conform the ANSI SQL standard which requires to support such casting. ### Does this PR introduce _any_ user-facing change? No, the new timestamp type is not released yet. ### How was this patch tested? Unit test Closes #32936 from gengliangwang/castStringToTswtz. Authored-by: Gengliang Wang <gengliang@apache.org> Signed-off-by: Gengliang Wang <gengliang@apache.org>
This commit is contained in:
parent
79362c4efc
commit
05e2b76852
|
@ -70,6 +70,7 @@ object Cast {
|
||||||
case (_: NumericType, TimestampType) => true
|
case (_: NumericType, TimestampType) => true
|
||||||
case (TimestampWithoutTZType, TimestampType) => true
|
case (TimestampWithoutTZType, TimestampType) => true
|
||||||
|
|
||||||
|
case (StringType, TimestampWithoutTZType) => true
|
||||||
case (DateType, TimestampWithoutTZType) => true
|
case (DateType, TimestampWithoutTZType) => true
|
||||||
case (TimestampType, TimestampWithoutTZType) => true
|
case (TimestampType, TimestampWithoutTZType) => true
|
||||||
|
|
||||||
|
@ -513,6 +514,14 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
|
||||||
}
|
}
|
||||||
|
|
||||||
private[this] def castToTimestampWithoutTZ(from: DataType): Any => Any = from match {
|
private[this] def castToTimestampWithoutTZ(from: DataType): Any => Any = from match {
|
||||||
|
case StringType =>
|
||||||
|
buildCast[UTF8String](_, utfs => {
|
||||||
|
if (ansiEnabled) {
|
||||||
|
DateTimeUtils.stringToTimestampWithoutTimeZoneAnsi(utfs)
|
||||||
|
} else {
|
||||||
|
DateTimeUtils.stringToTimestampWithoutTimeZone(utfs).orNull
|
||||||
|
}
|
||||||
|
})
|
||||||
case DateType =>
|
case DateType =>
|
||||||
buildCast[Int](_, d => daysToMicros(d, ZoneOffset.UTC))
|
buildCast[Int](_, d => daysToMicros(d, ZoneOffset.UTC))
|
||||||
case TimestampType =>
|
case TimestampType =>
|
||||||
|
@ -1410,6 +1419,24 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
|
||||||
private[this] def castToTimestampWithoutTZCode(
|
private[this] def castToTimestampWithoutTZCode(
|
||||||
from: DataType,
|
from: DataType,
|
||||||
ctx: CodegenContext): CastFunction = from match {
|
ctx: CodegenContext): CastFunction = from match {
|
||||||
|
case StringType =>
|
||||||
|
val longOpt = ctx.freshVariable("longOpt", classOf[Option[Long]])
|
||||||
|
(c, evPrim, evNull) =>
|
||||||
|
if (ansiEnabled) {
|
||||||
|
code"""
|
||||||
|
$evPrim =
|
||||||
|
$dateTimeUtilsCls.stringToTimestampWithoutTimeZoneAnsi($c);
|
||||||
|
"""
|
||||||
|
} else {
|
||||||
|
code"""
|
||||||
|
scala.Option<Long> $longOpt = $dateTimeUtilsCls.stringToTimestampWithoutTimeZone($c);
|
||||||
|
if ($longOpt.isDefined()) {
|
||||||
|
$evPrim = ((Long) $longOpt.get()).longValue();
|
||||||
|
} else {
|
||||||
|
$evNull = true;
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
}
|
||||||
case DateType =>
|
case DateType =>
|
||||||
(c, evPrim, evNull) =>
|
(c, evPrim, evNull) =>
|
||||||
code"$evPrim = $dateTimeUtilsCls.daysToMicros($c, java.time.ZoneOffset.UTC);"
|
code"$evPrim = $dateTimeUtilsCls.daysToMicros($c, java.time.ZoneOffset.UTC);"
|
||||||
|
@ -2016,6 +2043,7 @@ object AnsiCast {
|
||||||
case (DateType, TimestampType) => true
|
case (DateType, TimestampType) => true
|
||||||
case (TimestampWithoutTZType, TimestampType) => true
|
case (TimestampWithoutTZType, TimestampType) => true
|
||||||
|
|
||||||
|
case (StringType, TimestampWithoutTZType) => true
|
||||||
case (DateType, TimestampWithoutTZType) => true
|
case (DateType, TimestampWithoutTZType) => true
|
||||||
case (TimestampType, TimestampWithoutTZType) => true
|
case (TimestampType, TimestampWithoutTZType) => true
|
||||||
|
|
||||||
|
|
|
@ -30,7 +30,7 @@ import sun.util.calendar.ZoneInfo
|
||||||
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
|
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
|
||||||
import org.apache.spark.sql.catalyst.util.RebaseDateTime._
|
import org.apache.spark.sql.catalyst.util.RebaseDateTime._
|
||||||
import org.apache.spark.sql.errors.QueryExecutionErrors
|
import org.apache.spark.sql.errors.QueryExecutionErrors
|
||||||
import org.apache.spark.sql.types.{DateType, Decimal, TimestampType}
|
import org.apache.spark.sql.types.{DateType, Decimal, TimestampType, TimestampWithoutTZType}
|
||||||
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
|
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -219,7 +219,8 @@ object DateTimeUtils {
|
||||||
def cleanLegacyTimestampStr(s: UTF8String): UTF8String = s.replace(gmtUtf8, UTF8String.EMPTY_UTF8)
|
def cleanLegacyTimestampStr(s: UTF8String): UTF8String = s.replace(gmtUtf8, UTF8String.EMPTY_UTF8)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Trims and parses a given UTF8 timestamp string to the corresponding a corresponding [[Long]]
|
* Trims and parses a given UTF8 timestamp string to the corresponding timestamp segments,
|
||||||
|
* time zone id and whether it is just time without a date.
|
||||||
* value. The return type is [[Option]] in order to distinguish between 0L and null. The following
|
* value. The return type is [[Option]] in order to distinguish between 0L and null. The following
|
||||||
* formats are allowed:
|
* formats are allowed:
|
||||||
*
|
*
|
||||||
|
@ -243,10 +244,13 @@ object DateTimeUtils {
|
||||||
* - +|-hh:mm:ss
|
* - +|-hh:mm:ss
|
||||||
* - +|-hhmmss
|
* - +|-hhmmss
|
||||||
* - Region-based zone IDs in the form `area/city`, such as `Europe/Paris`
|
* - Region-based zone IDs in the form `area/city`, such as `Europe/Paris`
|
||||||
|
*
|
||||||
|
* @return timestamp segments, time zone id and whether the input is just time without a date. If
|
||||||
|
* the input string can't be parsed as timestamp, the result timestamp segments are empty.
|
||||||
*/
|
*/
|
||||||
def stringToTimestamp(s: UTF8String, timeZoneId: ZoneId): Option[Long] = {
|
private def parseTimestampString(s: UTF8String): (Array[Int], Option[ZoneId], Boolean) = {
|
||||||
if (s == null) {
|
if (s == null) {
|
||||||
return None
|
return (Array.empty, None, false)
|
||||||
}
|
}
|
||||||
var tz: Option[String] = None
|
var tz: Option[String] = None
|
||||||
val segments: Array[Int] = Array[Int](1, 1, 1, 0, 0, 0, 0, 0, 0)
|
val segments: Array[Int] = Array[Int](1, 1, 1, 0, 0, 0, 0, 0, 0)
|
||||||
|
@ -267,7 +271,7 @@ object DateTimeUtils {
|
||||||
if (b == '-') {
|
if (b == '-') {
|
||||||
if (i == 0 && j != 4) {
|
if (i == 0 && j != 4) {
|
||||||
// year should have exact four digits
|
// year should have exact four digits
|
||||||
return None
|
return (Array.empty, None, false)
|
||||||
}
|
}
|
||||||
segments(i) = currentSegmentValue
|
segments(i) = currentSegmentValue
|
||||||
currentSegmentValue = 0
|
currentSegmentValue = 0
|
||||||
|
@ -278,7 +282,7 @@ object DateTimeUtils {
|
||||||
currentSegmentValue = 0
|
currentSegmentValue = 0
|
||||||
i = 4
|
i = 4
|
||||||
} else {
|
} else {
|
||||||
return None
|
return (Array.empty, None, false)
|
||||||
}
|
}
|
||||||
} else if (i == 2) {
|
} else if (i == 2) {
|
||||||
if (b == ' ' || b == 'T') {
|
if (b == ' ' || b == 'T') {
|
||||||
|
@ -286,7 +290,7 @@ object DateTimeUtils {
|
||||||
currentSegmentValue = 0
|
currentSegmentValue = 0
|
||||||
i += 1
|
i += 1
|
||||||
} else {
|
} else {
|
||||||
return None
|
return (Array.empty, None, false)
|
||||||
}
|
}
|
||||||
} else if (i == 3 || i == 4) {
|
} else if (i == 3 || i == 4) {
|
||||||
if (b == ':') {
|
if (b == ':') {
|
||||||
|
@ -294,7 +298,7 @@ object DateTimeUtils {
|
||||||
currentSegmentValue = 0
|
currentSegmentValue = 0
|
||||||
i += 1
|
i += 1
|
||||||
} else {
|
} else {
|
||||||
return None
|
return (Array.empty, None, false)
|
||||||
}
|
}
|
||||||
} else if (i == 5 || i == 6) {
|
} else if (i == 5 || i == 6) {
|
||||||
if (b == '-' || b == '+') {
|
if (b == '-' || b == '+') {
|
||||||
|
@ -322,7 +326,7 @@ object DateTimeUtils {
|
||||||
currentSegmentValue = 0
|
currentSegmentValue = 0
|
||||||
i += 1
|
i += 1
|
||||||
} else {
|
} else {
|
||||||
return None
|
return (Array.empty, None, false)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -337,7 +341,7 @@ object DateTimeUtils {
|
||||||
segments(i) = currentSegmentValue
|
segments(i) = currentSegmentValue
|
||||||
if (!justTime && i == 0 && j != 4) {
|
if (!justTime && i == 0 && j != 4) {
|
||||||
// year should have exact four digits
|
// year should have exact four digits
|
||||||
return None
|
return (Array.empty, None, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
while (digitsMilli < 6) {
|
while (digitsMilli < 6) {
|
||||||
|
@ -350,13 +354,48 @@ object DateTimeUtils {
|
||||||
segments(6) /= 10
|
segments(6) /= 10
|
||||||
digitsMilli -= 1
|
digitsMilli -= 1
|
||||||
}
|
}
|
||||||
|
// This step also validates time zone part
|
||||||
|
val zoneId = tz.map {
|
||||||
|
case "+" => ZoneOffset.ofHoursMinutes(segments(7), segments(8))
|
||||||
|
case "-" => ZoneOffset.ofHoursMinutes(-segments(7), -segments(8))
|
||||||
|
case zoneName: String => getZoneId(zoneName.trim)
|
||||||
|
}
|
||||||
|
(segments, zoneId, justTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Trims and parses a given UTF8 timestamp string to the corresponding a corresponding [[Long]]
|
||||||
|
* value. The return type is [[Option]] in order to distinguish between 0L and null. The following
|
||||||
|
* formats are allowed:
|
||||||
|
*
|
||||||
|
* `yyyy`
|
||||||
|
* `yyyy-[m]m`
|
||||||
|
* `yyyy-[m]m-[d]d`
|
||||||
|
* `yyyy-[m]m-[d]d `
|
||||||
|
* `yyyy-[m]m-[d]d [h]h:[m]m:[s]s.[ms][ms][ms][us][us][us][zone_id]`
|
||||||
|
* `yyyy-[m]m-[d]dT[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us][zone_id]`
|
||||||
|
* `[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us][zone_id]`
|
||||||
|
* `T[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us][zone_id]`
|
||||||
|
*
|
||||||
|
* where `zone_id` should have one of the forms:
|
||||||
|
* - Z - Zulu time zone UTC+0
|
||||||
|
* - +|-[h]h:[m]m
|
||||||
|
* - A short id, see https://docs.oracle.com/javase/8/docs/api/java/time/ZoneId.html#SHORT_IDS
|
||||||
|
* - An id with one of the prefixes UTC+, UTC-, GMT+, GMT-, UT+ or UT-,
|
||||||
|
* and a suffix in the formats:
|
||||||
|
* - +|-h[h]
|
||||||
|
* - +|-hh[:]mm
|
||||||
|
* - +|-hh:mm:ss
|
||||||
|
* - +|-hhmmss
|
||||||
|
* - Region-based zone IDs in the form `area/city`, such as `Europe/Paris`
|
||||||
|
*/
|
||||||
|
def stringToTimestamp(s: UTF8String, timeZoneId: ZoneId): Option[Long] = {
|
||||||
try {
|
try {
|
||||||
val zoneId = tz match {
|
val (segments, parsedZoneId, justTime) = parseTimestampString(s)
|
||||||
case None => timeZoneId
|
if (segments.isEmpty) {
|
||||||
case Some("+") => ZoneOffset.ofHoursMinutes(segments(7), segments(8))
|
return None
|
||||||
case Some("-") => ZoneOffset.ofHoursMinutes(-segments(7), -segments(8))
|
|
||||||
case Some(zoneName: String) => getZoneId(zoneName.trim)
|
|
||||||
}
|
}
|
||||||
|
val zoneId = parsedZoneId.getOrElse(timeZoneId)
|
||||||
val nanoseconds = MICROSECONDS.toNanos(segments(6))
|
val nanoseconds = MICROSECONDS.toNanos(segments(6))
|
||||||
val localTime = LocalTime.of(segments(3), segments(4), segments(5), nanoseconds.toInt)
|
val localTime = LocalTime.of(segments(3), segments(4), segments(5), nanoseconds.toInt)
|
||||||
val localDate = if (justTime) {
|
val localDate = if (justTime) {
|
||||||
|
@ -378,6 +417,60 @@ object DateTimeUtils {
|
||||||
throw QueryExecutionErrors.cannotCastUTF8StringToDataTypeError(s, TimestampType)
|
throw QueryExecutionErrors.cannotCastUTF8StringToDataTypeError(s, TimestampType)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Trims and parses a given UTF8 string to a corresponding [[Long]] value which representing the
|
||||||
|
* number of microseconds since the epoch. The result is independent of time zones,
|
||||||
|
* which means that zone ID in the input string will be ignored.
|
||||||
|
* The return type is [[Option]] in order to distinguish between 0L and null. The following
|
||||||
|
* formats are allowed:
|
||||||
|
*
|
||||||
|
* `yyyy`
|
||||||
|
* `yyyy-[m]m`
|
||||||
|
* `yyyy-[m]m-[d]d`
|
||||||
|
* `yyyy-[m]m-[d]d `
|
||||||
|
* `yyyy-[m]m-[d]d [h]h:[m]m:[s]s.[ms][ms][ms][us][us][us][zone_id]`
|
||||||
|
* `yyyy-[m]m-[d]dT[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us][zone_id]`
|
||||||
|
*
|
||||||
|
* where `zone_id` should have one of the forms:
|
||||||
|
* - Z - Zulu time zone UTC+0
|
||||||
|
* - +|-[h]h:[m]m
|
||||||
|
* - A short id, see https://docs.oracle.com/javase/8/docs/api/java/time/ZoneId.html#SHORT_IDS
|
||||||
|
* - An id with one of the prefixes UTC+, UTC-, GMT+, GMT-, UT+ or UT-,
|
||||||
|
* and a suffix in the formats:
|
||||||
|
* - +|-h[h]
|
||||||
|
* - +|-hh[:]mm
|
||||||
|
* - +|-hh:mm:ss
|
||||||
|
* - +|-hhmmss
|
||||||
|
* - Region-based zone IDs in the form `area/city`, such as `Europe/Paris`
|
||||||
|
*
|
||||||
|
* Note: The input string has to contains year/month/day fields, otherwise Spark can't determine
|
||||||
|
* the value of timestamp without time zone.
|
||||||
|
*/
|
||||||
|
def stringToTimestampWithoutTimeZone(s: UTF8String): Option[Long] = {
|
||||||
|
try {
|
||||||
|
val (segments, _, justTime) = parseTimestampString(s)
|
||||||
|
// If the input string can't be parsed as a timestamp, or it contains only the time part of a
|
||||||
|
// timestamp and we can't determine its date, return None.
|
||||||
|
if (segments.isEmpty || justTime) {
|
||||||
|
return None
|
||||||
|
}
|
||||||
|
val nanoseconds = MICROSECONDS.toNanos(segments(6))
|
||||||
|
val localTime = LocalTime.of(segments(3), segments(4), segments(5), nanoseconds.toInt)
|
||||||
|
val localDate = LocalDate.of(segments(0), segments(1), segments(2))
|
||||||
|
val localDateTime = LocalDateTime.of(localDate, localTime)
|
||||||
|
Some(localDateTimeToMicros(localDateTime))
|
||||||
|
} catch {
|
||||||
|
case NonFatal(_) => None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def stringToTimestampWithoutTimeZoneAnsi(s: UTF8String): Long = {
|
||||||
|
stringToTimestampWithoutTimeZone(s).getOrElse {
|
||||||
|
throw QueryExecutionErrors.cannotCastUTF8StringToDataTypeError(s, TimestampWithoutTZType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// See issue SPARK-35679
|
// See issue SPARK-35679
|
||||||
// min second cause overflow in instant to micro
|
// min second cause overflow in instant to micro
|
||||||
private val MIN_SECONDS = Math.floorDiv(Long.MinValue, MICROS_PER_SECOND)
|
private val MIN_SECONDS = Math.floorDiv(Long.MinValue, MICROS_PER_SECOND)
|
||||||
|
|
|
@ -401,6 +401,19 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase {
|
||||||
checkExceptionInExpression[ArithmeticException](
|
checkExceptionInExpression[ArithmeticException](
|
||||||
cast(cast(Literal("2147483648"), FloatType), IntegerType), "overflow")
|
cast(cast(Literal("2147483648"), FloatType), IntegerType), "overflow")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("SPARK-35720: cast invalid string input to timestamp without time zone") {
|
||||||
|
Seq("00:00:00",
|
||||||
|
"a",
|
||||||
|
"123",
|
||||||
|
"a2021-06-17",
|
||||||
|
"2021-06-17abc",
|
||||||
|
"2021-06-17 00:00:00ABC").foreach { invalidInput =>
|
||||||
|
checkExceptionInExpression[DateTimeException](
|
||||||
|
cast(invalidInput, TimestampWithoutTZType),
|
||||||
|
s"Cannot cast $invalidInput to TimestampWithoutTZType")
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -651,4 +651,15 @@ class CastSuite extends CastSuiteBase {
|
||||||
checkEvaluation(cast(cast(interval, StringType), YearMonthIntervalType()), period)
|
checkEvaluation(cast(cast(interval, StringType), YearMonthIntervalType()), period)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("SPARK-35720: cast invalid string input to timestamp without time zone") {
|
||||||
|
Seq("00:00:00",
|
||||||
|
"a",
|
||||||
|
"123",
|
||||||
|
"a2021-06-17",
|
||||||
|
"2021-06-17abc",
|
||||||
|
"2021-06-17 00:00:00ABC").foreach { invalidInput =>
|
||||||
|
checkEvaluation(cast(invalidInput, TimestampWithoutTZType), null)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -929,4 +929,24 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
|
||||||
verifyCastFailure(cast(timestampWithoutTZLiteral, numericType), Some(errorMsg))
|
verifyCastFailure(cast(timestampWithoutTZLiteral, numericType), Some(errorMsg))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("SPARK-35720: cast string to timestamp without timezone") {
|
||||||
|
specialTs.foreach { s =>
|
||||||
|
val expectedTs = LocalDateTime.parse(s)
|
||||||
|
checkEvaluation(cast(s, TimestampWithoutTZType), expectedTs)
|
||||||
|
// Trim spaces before casting
|
||||||
|
checkEvaluation(cast(" " + s + " ", TimestampWithoutTZType), expectedTs)
|
||||||
|
// The result is independent of timezone
|
||||||
|
outstandingZoneIds.foreach { zoneId =>
|
||||||
|
checkEvaluation(cast(s + zoneId.toString, TimestampWithoutTZType), expectedTs)
|
||||||
|
val tsWithMicros = s + ".123456"
|
||||||
|
val expectedTsWithNanoSeconds = LocalDateTime.parse(tsWithMicros)
|
||||||
|
checkEvaluation(cast(tsWithMicros + zoneId.toString, TimestampWithoutTZType),
|
||||||
|
expectedTsWithNanoSeconds)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// The input string can contain date only
|
||||||
|
checkEvaluation(cast("2021-06-17", TimestampWithoutTZType),
|
||||||
|
LocalDateTime.of(2021, 6, 17, 0, 0))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue