diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index fba17d339e..2f63159592 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -70,6 +70,8 @@ object Cast {
     case (_: NumericType, TimestampType) => true
     case (TimestampWithoutTZType, TimestampType) => true
 
+    case (DateType, TimestampWithoutTZType) => true
+
     case (StringType, DateType) => true
     case (TimestampType, DateType) => true
     case (TimestampWithoutTZType, DateType) => true
@@ -315,6 +317,9 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
   // The brackets that are used in casting structs and maps to strings
   private val (leftBracket, rightBracket) = if (legacyCastToStr) ("[", "]") else ("{", "}")
 
+  // The class name of `DateTimeUtils`
+  protected def dateTimeUtilsCls: String = DateTimeUtils.getClass.getName.stripSuffix("$")
+
   // UDFToString
   private[this] def castToString(from: DataType): Any => Any = from match {
     case CalendarIntervalType =>
@@ -505,6 +510,11 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
       buildCast[Float](_, f => doubleToTimestamp(f.toDouble))
   }
 
+  private[this] def castToTimestampWithoutTZ(from: DataType): Any => Any = from match {
+    case DateType =>
+      buildCast[Int](_, d => daysToMicros(d, ZoneOffset.UTC))
+  }
+
   private[this] def decimalToTimestamp(d: Decimal): Long = {
     (d.toBigDecimal * MICROS_PER_SECOND).longValue
   }
@@ -856,6 +866,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
     case DateType => castToDate(from)
     case decimal: DecimalType => castToDecimal(from, decimal)
     case TimestampType => castToTimestamp(from)
+    case TimestampWithoutTZType => castToTimestampWithoutTZ(from)
     case CalendarIntervalType => castToInterval(from)
     case DayTimeIntervalType => castToDayTimeInterval(from)
     case YearMonthIntervalType => castToYearMonthInterval(from)
@@ -916,6 +927,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
     case DateType => castToDateCode(from, ctx)
     case decimal: DecimalType => castToDecimalCode(from, decimal, ctx)
     case TimestampType => castToTimestampCode(from, ctx)
+    case TimestampWithoutTZType => castToTimestampWithoutTZCode(from)
     case CalendarIntervalType => castToIntervalCode(from)
     case DayTimeIntervalType => castToDayTimeIntervalCode(from)
     case YearMonthIntervalType => castToYearMonthIntervalCode(from)
@@ -1209,9 +1221,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
             org.apache.spark.sql.catalyst.util.DateTimeUtils.microsToDays($c, $zid);"""
       case TimestampWithoutTZType =>
         (c, evPrim, evNull) =>
-          // scalastyle:off line.size.limit
-          code"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.microsToDays($c, java.time.ZoneOffset.UTC);"
-          // scalastyle:on line.size.limit
+          code"$evPrim = $dateTimeUtilsCls.microsToDays($c, java.time.ZoneOffset.UTC);"
       case _ =>
         (c, evPrim, evNull) => code"$evNull = true;"
     }
@@ -1371,6 +1381,12 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
       """
   }
 
+  private[this] def castToTimestampWithoutTZCode(from: DataType): CastFunction = from match {
+    case DateType =>
+      (c, evPrim, evNull) =>
+        code"$evPrim = $dateTimeUtilsCls.daysToMicros($c, java.time.ZoneOffset.UTC);"
+  }
+
   private[this] def castToIntervalCode(from: DataType): CastFunction = from match {
     case StringType =>
       val util = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
@@ -1955,6 +1971,8 @@ object AnsiCast {
     case (DateType, TimestampType) => true
     case (TimestampWithoutTZType, TimestampType) => true
 
+    case (DateType, TimestampWithoutTZType) => true
+
     case (StringType, _: CalendarIntervalType) => true
     case (StringType, DayTimeIntervalType) => true
     case (StringType, YearMonthIntervalType) => true
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index 51a77405d4..a8b9a263d6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -1267,6 +1267,15 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase {
       checkEvaluation(cast(dt, DateType), LocalDate.parse(s.split("T")(0)))
     }
   }
+
+  test("SPARK-35718: cast date type to timestamp without timezone") {
+    specialTs.foreach { s =>
+      val inputDate = LocalDate.parse(s.split("T")(0))
+      // The hour/minute/second fields of the expected result should be 0
+      val expectedTs = LocalDateTime.parse(s.split("T")(0) + "T00:00:00")
+      checkEvaluation(cast(inputDate, TimestampWithoutTZType), expectedTs)
+    }
+  }
 }
 
 /**
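
Note on the conversion semantics (not part of the patch): the new DateType to TimestampWithoutTZType branch widens an epoch-day count to microseconds at midnight, and ZoneOffset.UTC is passed only so that no wall-clock shift is applied. The sketch below is a self-contained, illustrative re-implementation of that arithmetic under those assumptions; the helper name daysToMicrosUtc, the object name, and the sample date are ours, not Spark's.

    // Illustrative only: mirrors what DateTimeUtils.daysToMicros(d, ZoneOffset.UTC)
    // computes for the UTC case; daysToMicrosUtc is a hypothetical helper, not Spark API.
    import java.time.{LocalDate, LocalDateTime, LocalTime, ZoneOffset}

    object DateToTimestampWithoutTZSketch {
      // Days since 1970-01-01 -> microseconds since the epoch at 00:00:00 (no zone shift).
      def daysToMicrosUtc(epochDays: Int): Long =
        LocalDate.ofEpochDay(epochDays.toLong)
          .atStartOfDay(ZoneOffset.UTC)
          .toInstant
          .getEpochSecond * 1000000L

      def main(args: Array[String]): Unit = {
        val date = LocalDate.parse("2021-06-10")
        val micros = daysToMicrosUtc(date.toEpochDay.toInt)
        // Reading the micros back as a zone-less local timestamp keeps the date
        // and yields a 00:00:00 time part, as the new CastSuite test expects.
        val localTs = LocalDateTime.ofEpochSecond(micros / 1000000L, 0, ZoneOffset.UTC)
        assert(localTs == LocalDateTime.of(date, LocalTime.MIDNIGHT))
        println(localTs) // 2021-06-10T00:00
      }
    }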