From 84c5ca33f95e982a15efd514f103e4b85c273567 Mon Sep 17 00:00:00 2001
From: Gengliang Wang
Date: Wed, 9 Jun 2021 14:59:46 +0800
Subject: [PATCH] [SPARK-35664][SQL] Support java.time.LocalDateTime as an external type of TimestampWithoutTZ type

### What changes were proposed in this pull request?

In this PR, I propose to extend the Spark SQL API to accept `java.time.LocalDateTime` as an external type of the recently added Catalyst type `TimestampWithoutTZType`. The Java class `java.time.LocalDateTime` has semantics similar to the ANSI SQL timestamp without time zone type, which makes it the most suitable external type for `TimestampWithoutTZType`. In more detail:

* Added `TimestampWithoutTZConverter`, which converts `java.time.LocalDateTime` instances to/from the internal representation of the Catalyst type `TimestampWithoutTZType` (a `Long`). The `TimestampWithoutTZConverter` object uses new methods of `DateTimeUtils` (a standalone sketch of their semantics follows the `---` marker below):
  * `localDateTimeToMicros()` converts the input date-time to the total number of microseconds since the epoch `1970-01-01T00:00:00`.
  * `microsToLocalDateTime()` performs the reverse conversion and returns a `java.time.LocalDateTime`.
* Supported the new type `TimestampWithoutTZType` in `RowEncoder` via the methods `createDeserializerForLocalDateTime()` and `createSerializerForLocalDateTime()`.
* Extended the `Literal` API to construct literals from `java.time.LocalDateTime` instances.

### Why are the changes needed?

To let users parallelize collections of `java.time.LocalDateTime` instances, construct timestamp without time zone columns, and collect such columns back to the driver side.

### Does this PR introduce _any_ user-facing change?

The PR extends existing functionality: users can parallelize instances of the `java.time.LocalDateTime` class and collect them back.

```
scala> val ds = Seq(java.time.LocalDateTime.parse("1970-01-01T00:00:00")).toDS
ds: org.apache.spark.sql.Dataset[java.time.LocalDateTime] = [value: timestampwithouttz]

scala> ds.collect()
res0: Array[java.time.LocalDateTime] = Array(1970-01-01T00:00)
```

### How was this patch tested?

New unit tests.

Closes #32814 from gengliangwang/LocalDateTime.

Authored-by: Gengliang Wang
Signed-off-by: Gengliang Wang
---
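[Editor's note, not part of the patch: the two `DateTimeUtils` helpers added below treat a `LocalDateTime` as a wall-clock value at UTC and convert it to/from microseconds since the epoch. A minimal standalone Scala sketch of that round trip, using only `java.time`; the names `toMicros`/`fromMicros` are illustrative and are not Spark APIs.]

```scala
import java.time.LocalDateTime
import java.time.temporal.ChronoUnit

val epoch = LocalDateTime.of(1970, 1, 1, 0, 0, 0)

// Same semantics as the new DateTimeUtils.localDateTimeToMicros: microseconds between
// 1970-01-01T00:00:00 and the given local date-time, with no time zone involved.
def toMicros(ldt: LocalDateTime): Long = ChronoUnit.MICROS.between(epoch, ldt)

// Inverse direction, mirroring the new DateTimeUtils.microsToLocalDateTime.
def fromMicros(micros: Long): LocalDateTime = epoch.plus(micros, ChronoUnit.MICROS)

assert(toMicros(LocalDateTime.parse("1970-01-02T03:46:40")) == 100000000000L)
assert(fromMicros(253402300799999999L) == LocalDateTime.parse("9999-12-31T23:59:59.999999"))
```

The asserted values match the new `DateTimeUtilsSuite` cases further down in this patch.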
 .../expressions/SpecializedGettersReader.java |  3 ++
 .../scala/org/apache/spark/sql/Encoders.scala |  8 ++++
 .../sql/catalyst/CatalystTypeConverters.scala | 21 ++++++++++-
 .../catalyst/DeserializerBuildHelper.scala    |  9 +++++
 .../spark/sql/catalyst/InternalRow.scala      |  4 +-
 .../sql/catalyst/JavaTypeInference.scala      |  7 ++++
 .../spark/sql/catalyst/ScalaReflection.scala  | 10 +++++
 .../sql/catalyst/SerializerBuildHelper.scala  |  9 +++++
 .../spark/sql/catalyst/dsl/package.scala      |  4 ++
 .../sql/catalyst/encoders/RowEncoder.scala    |  9 +++++
 .../InterpretedUnsafeProjection.scala         |  2 +-
 .../expressions/SpecificInternalRow.scala     |  4 +-
 .../expressions/codegen/CodeGenerator.scala   |  5 ++-
 .../sql/catalyst/expressions/literals.scala   | 10 +++--
 .../sql/catalyst/util/DateTimeUtils.scala     |  8 ++++
 .../org/apache/spark/sql/types/DataType.scala |  2 +-
 .../CatalystTypeConvertersSuite.scala         | 31 +++++++++++++++-
 .../catalyst/encoders/RowEncoderSuite.scala   | 10 +++++
 .../expressions/LiteralExpressionSuite.scala  | 11 ++++++
 .../catalyst/util/DateTimeUtilsSuite.scala    | 37 ++++++++++++++++++-
 .../org/apache/spark/sql/SQLImplicits.scala   |  3 ++
 .../apache/spark/sql/JavaDatasetSuite.java    | 13 +++++--
 .../org/apache/spark/sql/DatasetSuite.scala   |  5 +++
 23 files changed, 206 insertions(+), 19 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java
index 90f340b51c..9bce7e8945 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java
@@ -65,6 +65,9 @@ public final class SpecializedGettersReader {
     if (dataType instanceof TimestampType) {
       return obj.getLong(ordinal);
     }
+    if (dataType instanceof TimestampWithoutTZType) {
+      return obj.getLong(ordinal);
+    }
     if (dataType instanceof CalendarIntervalType) {
       return obj.getInterval(ordinal);
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
index d50829578e..4ead9505bb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
@@ -113,6 +113,14 @@ object Encoders {
    */
   def LOCALDATE: Encoder[java.time.LocalDate] = ExpressionEncoder()
 
+  /**
+   * Creates an encoder that serializes instances of the `java.time.LocalDateTime` class
+   * to the internal representation of nullable Catalyst's TimestampWithoutTZType.
+   *
+   * @since 3.2.0
+   */
+  def LOCALDATETIME: Encoder[java.time.LocalDateTime] = ExpressionEncoder()
+
   /**
    * An encoder for nullable timestamp type.
    *
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index ccf0a50b73..abd3bf4dda 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -21,7 +21,7 @@ import java.lang.{Iterable => JavaIterable}
 import java.math.{BigDecimal => JavaBigDecimal}
 import java.math.{BigInteger => JavaBigInteger}
 import java.sql.{Date, Timestamp}
-import java.time.{Duration, Instant, LocalDate, Period}
+import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period}
 import java.util.{Map => JavaMap}
 import javax.annotation.Nullable
@@ -66,6 +66,7 @@ object CatalystTypeConverters {
       case DateType => DateConverter
       case TimestampType if SQLConf.get.datetimeJava8ApiEnabled => InstantConverter
       case TimestampType => TimestampConverter
+      case TimestampWithoutTZType => TimestampWithoutTZConverter
       case dt: DecimalType => new DecimalConverter(dt)
       case BooleanType => BooleanConverter
       case ByteType => ByteConverter
@@ -354,6 +355,23 @@ object CatalystTypeConverters {
       DateTimeUtils.microsToInstant(row.getLong(column))
   }
 
+  private object TimestampWithoutTZConverter
+      extends CatalystTypeConverter[Any, LocalDateTime, Any] {
+    override def toCatalystImpl(scalaValue: Any): Any = scalaValue match {
+      case l: LocalDateTime => DateTimeUtils.localDateTimeToMicros(l)
+      case other => throw new IllegalArgumentException(
+        s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " +
+          s"cannot be converted to the ${TimestampWithoutTZType.sql} type")
+    }
+
+    override def toScala(catalystValue: Any): LocalDateTime =
+      if (catalystValue == null) null
+      else DateTimeUtils.microsToLocalDateTime(catalystValue.asInstanceOf[Long])
+
+    override def toScalaImpl(row: InternalRow, column: Int): LocalDateTime =
+      DateTimeUtils.microsToLocalDateTime(row.getLong(column))
+  }
+
   private class DecimalConverter(dataType: DecimalType)
     extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] {
 
@@ -489,6 +507,7 @@ object CatalystTypeConverters {
     case ld: LocalDate => LocalDateConverter.toCatalyst(ld)
     case t: Timestamp => TimestampConverter.toCatalyst(t)
     case i: Instant => InstantConverter.toCatalyst(i)
+    case l: LocalDateTime => TimestampWithoutTZConverter.toCatalyst(l)
    case d: BigDecimal => new DecimalConverter(DecimalType(d.precision, d.scale)).toCatalyst(d)
     case d: JavaBigDecimal => new DecimalConverter(DecimalType(d.precision, d.scale)).toCatalyst(d)
     case seq: Seq[Any] => new GenericArrayData(seq.map(convertToCatalyst).toArray)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
index eaa7c17bfd..0d3b9977e4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
@@ -118,6 +118,15 @@ object DeserializerBuildHelper {
       returnNullable = false)
   }
 
+  def createDeserializerForLocalDateTime(path: Expression): Expression = {
+    StaticInvoke(
+      DateTimeUtils.getClass,
+      ObjectType(classOf[java.time.LocalDateTime]),
+      "microsToLocalDateTime",
+      path :: Nil,
+      returnNullable = false)
+  }
+
   def createDeserializerForJavaBigDecimal(
       path: Expression,
       returnNullable: Boolean): Expression = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
index fd74f60c0c..202c718f63 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
@@ -134,7 +134,7 @@ object InternalRow {
     case ShortType => (input, ordinal) => input.getShort(ordinal)
     case IntegerType | DateType | YearMonthIntervalType =>
       (input, ordinal) => input.getInt(ordinal)
-    case LongType | TimestampType | DayTimeIntervalType =>
+    case LongType | TimestampType | TimestampWithoutTZType | DayTimeIntervalType =>
       (input, ordinal) => input.getLong(ordinal)
     case FloatType => (input, ordinal) => input.getFloat(ordinal)
     case DoubleType => (input, ordinal) => input.getDouble(ordinal)
@@ -171,7 +171,7 @@ object InternalRow {
     case ShortType => (input, v) => input.setShort(ordinal, v.asInstanceOf[Short])
     case IntegerType | DateType | YearMonthIntervalType =>
       (input, v) => input.setInt(ordinal, v.asInstanceOf[Int])
-    case LongType | TimestampType | DayTimeIntervalType =>
+    case LongType | TimestampType | TimestampWithoutTZType | DayTimeIntervalType =>
       (input, v) => input.setLong(ordinal, v.asInstanceOf[Long])
     case FloatType => (input, v) => input.setFloat(ordinal, v.asInstanceOf[Float])
     case DoubleType => (input, v) => input.setDouble(ordinal, v.asInstanceOf[Double])
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 73f809fd4e..807eb8cfd7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -119,6 +119,7 @@ object JavaTypeInference {
       case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
       case c: Class[_] if c == classOf[java.time.Instant] => (TimestampType, true)
       case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
+      case c: Class[_] if c == classOf[java.time.LocalDateTime] => (TimestampWithoutTZType, true)
       case c: Class[_] if c == classOf[java.time.Duration] => (DayTimeIntervalType, true)
       case c: Class[_] if c == classOf[java.time.Period] => (YearMonthIntervalType, true)
 
@@ -250,6 +251,9 @@ object JavaTypeInference {
       case c if c == classOf[java.sql.Timestamp] =>
         createDeserializerForSqlTimestamp(path)
 
+      case c if c == classOf[java.time.LocalDateTime] =>
+        createDeserializerForLocalDateTime(path)
+
       case c if c == classOf[java.time.Duration] =>
         createDeserializerForDuration(path)
 
@@ -409,6 +413,9 @@ object JavaTypeInference {
       case c if c == classOf[java.sql.Timestamp] =>
         createSerializerForSqlTimestamp(inputObject)
 
+      case c if c == classOf[java.time.LocalDateTime] =>
+        createSerializerForLocalDateTime(inputObject)
+
       case c if c == classOf[java.time.LocalDate] => createSerializerForJavaLocalDate(inputObject)
 
       case c if c == classOf[java.sql.Date] => createSerializerForSqlDate(inputObject)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index d60b1b719f..7ecf32da1a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -241,6 +241,9 @@ object ScalaReflection extends ScalaReflection {
       case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) =>
         createDeserializerForSqlTimestamp(path)
 
+      case t if isSubtype(t, localTypeOf[java.time.LocalDateTime]) =>
+        createDeserializerForLocalDateTime(path)
+
       case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
         createDeserializerForDuration(path)
 
@@ -524,6 +527,9 @@ object ScalaReflection extends ScalaReflection {
       case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) =>
         createSerializerForSqlTimestamp(inputObject)
 
+      case t if isSubtype(t, localTypeOf[java.time.LocalDateTime]) =>
+        createSerializerForLocalDateTime(inputObject)
+
       case t if isSubtype(t, localTypeOf[java.time.LocalDate]) =>
         createSerializerForJavaLocalDate(inputObject)
 
@@ -746,6 +752,8 @@ object ScalaReflection extends ScalaReflection {
       Schema(TimestampType, nullable = true)
     case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) =>
       Schema(TimestampType, nullable = true)
+    case t if isSubtype(t, localTypeOf[java.time.LocalDateTime]) =>
+      Schema(TimestampWithoutTZType, nullable = true)
     case t if isSubtype(t, localTypeOf[java.time.LocalDate]) => Schema(DateType, nullable = true)
     case t if isSubtype(t, localTypeOf[java.sql.Date]) => Schema(DateType, nullable = true)
     case t if isSubtype(t, localTypeOf[CalendarInterval]) =>
@@ -850,6 +858,7 @@ object ScalaReflection extends ScalaReflection {
     StringType -> classOf[UTF8String],
     DateType -> classOf[DateType.InternalType],
     TimestampType -> classOf[TimestampType.InternalType],
+    TimestampWithoutTZType -> classOf[TimestampWithoutTZType.InternalType],
     BinaryType -> classOf[BinaryType.InternalType],
     CalendarIntervalType -> classOf[CalendarInterval],
     DayTimeIntervalType -> classOf[DayTimeIntervalType.InternalType],
@@ -866,6 +875,7 @@ object ScalaReflection extends ScalaReflection {
     DoubleType -> classOf[java.lang.Double],
     DateType -> classOf[java.lang.Integer],
     TimestampType -> classOf[java.lang.Long],
+    TimestampWithoutTZType -> classOf[java.lang.Long],
     DayTimeIntervalType -> classOf[java.lang.Long],
     YearMonthIntervalType -> classOf[java.lang.Integer]
   )
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
index f80fab573c..0624698485 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
@@ -86,6 +86,15 @@ object SerializerBuildHelper {
       returnNullable = false)
   }
 
+  def createSerializerForLocalDateTime(inputObject: Expression): Expression = {
+    StaticInvoke(
+      DateTimeUtils.getClass,
+      TimestampWithoutTZType,
+      "localDateTimeToMicros",
+      inputObject :: Nil,
+      returnNullable = false)
+  }
+
   def createSerializerForJavaLocalDate(inputObject: Expression): Expression = {
     StaticInvoke(
       DateTimeUtils.getClass,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 626ece33f1..86998a7154 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -297,6 +297,10 @@ package object dsl {
     /** Creates a new AttributeReference of type timestamp */
     def timestamp: AttributeReference = AttributeReference(s, TimestampType, nullable = true)()
 
+    /** Creates a new AttributeReference of type timestamp without time zone */
+    def timestampWithoutTZ: AttributeReference =
+      AttributeReference(s, TimestampWithoutTZType, nullable = true)()
+
     /** Creates a new AttributeReference of the day-time interval type */
     def dayTimeInterval: AttributeReference = {
       AttributeReference(s, DayTimeIntervalType, nullable = true)()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 417ff7a439..83b91972c5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -53,6 +53,8 @@ import org.apache.spark.sql.types._
 *   TimestampType -> java.sql.Timestamp if spark.sql.datetime.java8API.enabled is false
 *   TimestampType -> java.time.Instant if spark.sql.datetime.java8API.enabled is true
 *
+ *   TimestampWithoutTZType -> java.time.LocalDateTime
+ *
 *   DayTimeIntervalType -> java.time.Duration
 *   YearMonthIntervalType -> java.time.Period
 *
@@ -103,6 +105,8 @@ object RowEncoder {
         createSerializerForSqlTimestamp(inputObject)
       }
 
+    case TimestampWithoutTZType => createSerializerForLocalDateTime(inputObject)
+
     case DateType =>
       if (SQLConf.get.datetimeJava8ApiEnabled) {
         createSerializerForJavaLocalDate(inputObject)
@@ -226,6 +230,8 @@ object RowEncoder {
       } else {
         ObjectType(classOf[java.sql.Timestamp])
       }
+    case TimestampWithoutTZType =>
+      ObjectType(classOf[java.time.LocalDateTime])
     case DateType =>
       if (SQLConf.get.datetimeJava8ApiEnabled) {
         ObjectType(classOf[java.time.LocalDate])
@@ -281,6 +287,9 @@ object RowEncoder {
         createDeserializerForSqlTimestamp(input)
       }
 
+    case TimestampWithoutTZType =>
+      createDeserializerForLocalDateTime(input)
+
     case DateType =>
       if (SQLConf.get.datetimeJava8ApiEnabled) {
         createDeserializerForLocalDate(input)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
index 908b73abad..e072c9a793 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
@@ -160,7 +160,7 @@ object InterpretedUnsafeProjection {
       case IntegerType | DateType | YearMonthIntervalType =>
         (v, i) => writer.write(i, v.getInt(i))
 
-      case LongType | TimestampType | DayTimeIntervalType =>
+      case LongType | TimestampType | TimestampWithoutTZType | DayTimeIntervalType =>
         (v, i) => writer.write(i, v.getLong(i))
 
       case FloatType =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala
index 0f26192468..849870f18c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala
@@ -195,8 +195,8 @@ final class SpecificInternalRow(val values: Array[MutableValue]) extends BaseGen
   private[this] def dataTypeToMutableValue(dataType: DataType): MutableValue = dataType match {
     // We use INT for DATE and YearMonthIntervalType internally
     case IntegerType | DateType | YearMonthIntervalType => new MutableInt
-    // We use Long for Timestamp and DayTimeInterval internally
-    case LongType | TimestampType | DayTimeIntervalType => new MutableLong
+    // We use Long for Timestamp, Timestamp without time zone and DayTimeInterval internally
+    case LongType | TimestampType | TimestampWithoutTZType | DayTimeIntervalType => new MutableLong
     case FloatType => new MutableFloat
     case DoubleType => new MutableDouble
     case BooleanType => new MutableBoolean
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 81ed646757..eec04d1ba5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -1817,7 +1817,7 @@ object CodeGenerator extends Logging {
     case ByteType => JAVA_BYTE
     case ShortType => JAVA_SHORT
     case IntegerType | DateType | YearMonthIntervalType => JAVA_INT
-    case LongType | TimestampType | DayTimeIntervalType => JAVA_LONG
+    case LongType | TimestampType | TimestampWithoutTZType | DayTimeIntervalType => JAVA_LONG
     case FloatType => JAVA_FLOAT
     case DoubleType => JAVA_DOUBLE
     case _: DecimalType => "Decimal"
@@ -1838,7 +1838,8 @@ object CodeGenerator extends Logging {
     case ByteType => java.lang.Byte.TYPE
     case ShortType => java.lang.Short.TYPE
     case IntegerType | DateType | YearMonthIntervalType => java.lang.Integer.TYPE
-    case LongType | TimestampType | DayTimeIntervalType => java.lang.Long.TYPE
+    case LongType | TimestampType | TimestampWithoutTZType | DayTimeIntervalType =>
+      java.lang.Long.TYPE
     case FloatType => java.lang.Float.TYPE
     case DoubleType => java.lang.Double.TYPE
     case _: DecimalType => classOf[Decimal]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 9ffa58b99d..27259992c3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -28,7 +28,7 @@ import java.lang.{Short => JavaShort}
 import java.math.{BigDecimal => JavaBigDecimal}
 import java.nio.charset.StandardCharsets
 import java.sql.{Date, Timestamp}
-import java.time.{Duration, Instant, LocalDate, Period}
+import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period}
 import java.util
 import java.util.Objects
 import javax.xml.bind.DatatypeConverter
@@ -80,6 +80,7 @@ object Literal {
     case d: Decimal => Literal(d, DecimalType(Math.max(d.precision, d.scale), d.scale))
     case i: Instant => Literal(instantToMicros(i), TimestampType)
     case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType)
+    case l: LocalDateTime => Literal(DateTimeUtils.localDateTimeToMicros(l), TimestampWithoutTZType)
     case ld: LocalDate => Literal(ld.toEpochDay.toInt, DateType)
     case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType)
     case d: Duration => Literal(durationToMicros(d), DayTimeIntervalType)
@@ -119,6 +120,7 @@ object Literal {
     case _ if clz == classOf[Date] => DateType
     case _ if clz == classOf[Instant] => TimestampType
     case _ if clz == classOf[Timestamp] => TimestampType
+    case _ if clz == classOf[LocalDateTime] => TimestampWithoutTZType
     case _ if clz == classOf[Duration] => DayTimeIntervalType
     case _ if clz == classOf[Period] => YearMonthIntervalType
     case _ if clz == classOf[JavaBigDecimal] => DecimalType.SYSTEM_DEFAULT
@@ -177,6 +179,7 @@ object Literal {
     case dt: DecimalType => Literal(Decimal(0, dt.precision, dt.scale))
     case DateType => create(0, DateType)
     case TimestampType => create(0L, TimestampType)
+    case TimestampWithoutTZType => create(0L, TimestampWithoutTZType)
     case DayTimeIntervalType => create(0L, DayTimeIntervalType)
     case YearMonthIntervalType => create(0, YearMonthIntervalType)
     case StringType => Literal("")
@@ -198,7 +201,8 @@ object Literal {
       case ByteType => v.isInstanceOf[Byte]
       case ShortType => v.isInstanceOf[Short]
       case IntegerType | DateType | YearMonthIntervalType => v.isInstanceOf[Int]
-      case LongType | TimestampType | DayTimeIntervalType => v.isInstanceOf[Long]
+      case LongType | TimestampType | TimestampWithoutTZType | DayTimeIntervalType =>
+        v.isInstanceOf[Long]
       case FloatType => v.isInstanceOf[Float]
       case DoubleType => v.isInstanceOf[Double]
       case _: DecimalType => v.isInstanceOf[Decimal]
@@ -422,7 +426,7 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression {
         }
       case ByteType | ShortType =>
         ExprCode.forNonNullValue(JavaCode.expression(s"($javaType)$value", dataType))
-      case TimestampType | LongType | DayTimeIntervalType =>
+      case TimestampType | TimestampWithoutTZType | LongType | DayTimeIntervalType =>
         toExprCode(s"${value}L")
       case _ =>
         val constRef = ctx.addReferenceObj("literal", value, javaType)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
index 9f4abde281..f2cc08d89d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
@@ -71,6 +71,14 @@ object DateTimeUtils {
     instantToMicros(instant)
   }
 
+  def microsToLocalDateTime(micros: Long): LocalDateTime = {
+    getLocalDateTime(micros, ZoneOffset.UTC)
+  }
+
+  def localDateTimeToMicros(localDateTime: LocalDateTime): Long = {
+    instantToMicros(localDateTime.toInstant(ZoneOffset.UTC))
+  }
+
   /**
    * Converts a local date at the default JVM time zone to the number of days since 1970-01-01
    * in the hybrid calendar (Julian + Gregorian) by discarding the time part. The resulted days are
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index 5c5742c812..a8618565cc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -171,7 +171,7 @@ object DataType {
   private val otherTypes = {
     Seq(NullType, DateType, TimestampType, BinaryType, IntegerType, BooleanType, LongType,
       DoubleType, FloatType, ShortType, ByteType, StringType, CalendarIntervalType,
-      DayTimeIntervalType, YearMonthIntervalType)
+      DayTimeIntervalType, YearMonthIntervalType, TimestampWithoutTZType)
       .map(t => t.typeName -> t).toMap
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
index 169c5d6a31..c116daba49 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst
 
-import java.time.{Duration, Instant, LocalDate, Period}
+import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period}
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.Row
@@ -189,6 +189,35 @@ class CatalystTypeConvertersSuite extends SparkFunSuite with SQLHelper {
     }
   }
 
+  test("SPARK-35664: converting java.time.LocalDateTime to TimestampWithoutTZType") {
+    Seq(
+      "0101-02-16T10:11:32",
+      "1582-10-02T01:02:03.04",
+      "1582-12-31T23:59:59.999999",
+      "1970-01-01T00:00:01.123",
+      "1972-12-31T23:59:59.123456",
+      "2019-02-16T18:12:30",
+      "2119-03-16T19:13:31").foreach { text =>
+      val input = LocalDateTime.parse(text)
+      val result = CatalystTypeConverters.convertToCatalyst(input)
+      val expected = DateTimeUtils.localDateTimeToMicros(input)
+      assert(result === expected)
+    }
+  }
+
+  test("SPARK-35664: converting TimestampWithoutTZType to java.time.LocalDateTime") {
+    Seq(
+      -9463427405253013L,
+      -244000001L,
+      0L,
+      99628200102030L,
+      1543749753123456L).foreach { us =>
+      val localDateTime = DateTimeUtils.microsToLocalDateTime(us)
+      assert(CatalystTypeConverters.createToScalaConverter(TimestampWithoutTZType)(us) ===
+        localDateTime)
+    }
+  }
+
   test("converting java.time.LocalDate to DateType") {
     Seq(
       "0101-02-16",
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 6c22c14870..b333f12b93 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -330,6 +330,16 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
     }
   }
 
+  test("SPARK-35664: encoding/decoding TimestampWithoutTZType to/from java.time.LocalDateTime") {
+    val schema = new StructType().add("t", TimestampWithoutTZType)
+    val encoder = RowEncoder(schema).resolveAndBind()
+    val localDateTime = java.time.LocalDateTime.parse("2019-02-26T16:56:00")
+    val row = toRow(encoder, Row(localDateTime))
+    assert(row.getLong(0) === DateTimeUtils.localDateTimeToMicros(localDateTime))
+    val readback = fromRow(encoder, row)
+    assert(readback.get(0) === localDateTime)
+  }
+
   test("encoding/decoding DateType to/from java.time.LocalDate") {
     withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") {
       val schema = new StructType().add("d", DateType)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
index bda43aac97..f25652870d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
@@ -360,6 +360,17 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
     }
   }
 
+  test("SPARK-35664: construct literals from java.time.LocalDateTime") {
+    Seq(
+      LocalDateTime.of(1, 1, 1, 0, 0, 0, 0),
+      LocalDateTime.of(2021, 5, 31, 23, 59, 59, 100),
+      LocalDateTime.of(2020, 2, 29, 23, 50, 57, 9999),
+      LocalDateTime.parse("9999-12-31T23:59:59.999999")
+    ).foreach { dateTime =>
+      checkEvaluation(Literal(dateTime), dateTime)
+    }
+  }
+
   test("SPARK-34605: construct literals from java.time.Duration") {
     Seq(
       Duration.ofNanos(0),
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala
index 5afd13ab9c..58a8bff62c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala
@@ -634,6 +634,39 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper {
     }
   }
 
+  test("SPARK-35664: microseconds to LocalDateTime") {
+    assert(microsToLocalDateTime(0) == LocalDateTime.parse("1970-01-01T00:00:00"))
+    assert(microsToLocalDateTime(100) == LocalDateTime.parse("1970-01-01T00:00:00.0001"))
+    assert(microsToLocalDateTime(100000000) == LocalDateTime.parse("1970-01-01T00:01:40"))
+    assert(microsToLocalDateTime(100000000000L) == LocalDateTime.parse("1970-01-02T03:46:40"))
+    assert(microsToLocalDateTime(253402300799999999L) ==
+      LocalDateTime.parse("9999-12-31T23:59:59.999999"))
+    assert(microsToLocalDateTime(Long.MinValue) ==
+      LocalDateTime.parse("-290308-12-21T19:59:05.224192"))
+    assert(microsToLocalDateTime(Long.MaxValue) ==
+      LocalDateTime.parse("+294247-01-10T04:00:54.775807"))
+  }
+
+  test("SPARK-35664: LocalDateTime to microseconds") {
+    assert(DateTimeUtils.localDateTimeToMicros(LocalDateTime.parse("1970-01-01T00:00:00")) == 0)
+    assert(
+      DateTimeUtils.localDateTimeToMicros(LocalDateTime.parse("1970-01-01T00:00:00.0001")) == 100)
+    assert(
+      DateTimeUtils.localDateTimeToMicros(LocalDateTime.parse("1970-01-01T00:01:40")) == 100000000)
+    assert(DateTimeUtils.localDateTimeToMicros(LocalDateTime.parse("1970-01-02T03:46:40")) ==
+      100000000000L)
+    assert(DateTimeUtils.localDateTimeToMicros(LocalDateTime.parse("9999-12-31T23:59:59.999999"))
+      == 253402300799999999L)
+    assert(DateTimeUtils.localDateTimeToMicros(LocalDateTime.parse("-1000-12-31T23:59:59.999999"))
+      == -93692592000000001L)
+    Seq(LocalDateTime.MIN, LocalDateTime.MAX).foreach { dt =>
+      val msg = intercept[ArithmeticException] {
+        DateTimeUtils.localDateTimeToMicros(dt)
+      }.getMessage
+      assert(msg == "long overflow")
+    }
+  }
+
   test("daysToMicros and microsToDays") {
     val input = date(2015, 12, 31, 16, zid = LA)
     assert(microsToDays(input, LA) === 16800)
@@ -780,8 +813,8 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper {
       (LocalDateTime.of(2021, 3, 14, 1, 0, 0), LocalDateTime.of(2021, 3, 14, 3, 0, 0)) ->
         TimeUnit.HOURS.toMicros(2)
     ).foreach { case ((start, end), expected) =>
-      val startMicros = localDateTimeToMicros(start, zid)
-      val endMicros = localDateTimeToMicros(end, zid)
+      val startMicros = DateTimeTestUtils.localDateTimeToMicros(start, zid)
+      val endMicros = DateTimeTestUtils.localDateTimeToMicros(end, zid)
       val result = subtractTimestamps(endMicros, startMicros, zid)
       assert(result === expected)
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index 90188cadfd..a3004ca2f9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -82,6 +82,9 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits {
   /** @since 3.0.0 */
   implicit def newLocalDateEncoder: Encoder[java.time.LocalDate] = Encoders.LOCALDATE
 
+  /** @since 3.2.0 */
+  implicit def newLocalDateTimeEncoder: Encoder[java.time.LocalDateTime] = Encoders.LOCALDATETIME
+
   /** @since 2.2.0 */
   implicit def newTimeStampEncoder: Encoder[java.sql.Timestamp] = Encoders.TIMESTAMP
 
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 3e988c2a23..645c9e9bd2 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -21,10 +21,7 @@ import java.io.Serializable;
 import java.math.BigDecimal;
 import java.sql.Date;
 import java.sql.Timestamp;
-import java.time.Duration;
-import java.time.Instant;
-import java.time.LocalDate;
-import java.time.Period;
+import java.time.*;
 import java.util.*;
 import javax.annotation.Nonnull;
 
@@ -413,6 +410,14 @@ public class JavaDatasetSuite implements Serializable {
     Assert.assertEquals(data, ds.collectAsList());
   }
 
+  @Test
+  public void testLocalDateTimeEncoder() {
+    Encoder<LocalDateTime> encoder = Encoders.LOCALDATETIME();
+    List<LocalDateTime> data = Arrays.asList(LocalDateTime.of(1, 1, 1, 1, 1));
+    Dataset<LocalDateTime> ds = spark.createDataset(data, encoder);
+    Assert.assertEquals(data, ds.collectAsList());
+  }
+
   @Test
   public void testDurationEncoder() {
     Encoder<Duration> encoder = Encoders.DURATION();
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index e933b4488b..a2b3f66948 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -2008,6 +2008,11 @@ class DatasetSuite extends QueryTest
     checkAnswer(withUDF, Row(Row(1), null, null) :: Row(Row(1), null, null) :: Nil)
   }
 
+  test("SPARK-35664: implicit encoder for java.time.LocalDateTime") {
+    val localDateTime = java.time.LocalDateTime.parse("2021-06-08T12:31:58.999999")
+    assert(Seq(localDateTime).toDS().head() === localDateTime)
+  }
+
   test("SPARK-34605: implicit encoder for java.time.Duration") {
     val duration = java.time.Duration.ofMinutes(10)
     assert(spark.range(1).map { _ => duration }.head === duration)
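
[Editor's note, not part of the patch: a rough end-to-end sketch of the user-facing surface added here, assuming a local `SparkSession`. With this change the single column should print with the new timestamp without time zone type, shown as `timestampwithouttz` in the PR description above; the object and app names below are illustrative only.]

```scala
import java.time.LocalDateTime

import org.apache.spark.sql.{Encoders, SparkSession}

object LocalDateTimeDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("ldt-demo").getOrCreate()
    import spark.implicits._

    // Uses the new implicit SQLImplicits.newLocalDateTimeEncoder added in this patch.
    val ds = Seq(LocalDateTime.parse("2021-06-08T12:31:58.999999")).toDS()
    ds.printSchema()
    assert(ds.head() == LocalDateTime.parse("2021-06-08T12:31:58.999999"))

    // The explicit Encoders.LOCALDATETIME encoder works the same way (also usable from Java).
    val localDateTimes = Seq(LocalDateTime.of(1, 1, 1, 1, 1))
    val ds2 = spark.createDataset(localDateTimes)(Encoders.LOCALDATETIME)
    assert(ds2.collect().toSeq == localDateTimes)

    spark.stop()
  }
}
```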