[SPARK-35664][SQL] Support java.time.LocalDateTime as an external type of TimestampWithoutTZ type

### What changes were proposed in this pull request?

In the PR, I propose to extend Spark SQL API to accept `java.time.LocalDateTime` as an external type of recently added new Catalyst type - `TimestampWithoutTZ`. The Java class `java.time.LocalDateTime` has a similar semantic to ANSI SQL timestamp without timezone type, and it is the most suitable to be an external type for `TimestampWithoutTZType`. In more details:

* Added `TimestampWithoutTZConverter` which converts java.time.LocalDateTime instances to/from internal representation of the Catalyst type `TimestampWithoutTZType` (to Long type). The `TimestampWithoutTZConverter` object uses new methods of DateTimeUtils:
  * localDateTimeToMicros() converts the input date time to the total length in microseconds.
  * microsToLocalDateTime() obtains a java.time.LocalDateTime
* Support new type `TimestampWithoutTZType` in RowEncoder via the methods createDeserializerForLocalDateTime() and createSerializerForLocalDateTime().
* Extended the Literal API to construct literals from `java.time.LocalDateTime` instances.

### Why are the changes needed?

To allow users parallelization of `java.time.LocalDateTime` collections, and construct timestamp without time zone columns. Also to collect such columns back to the driver side.

### Does this PR introduce _any_ user-facing change?

The PR extends existing functionality. So, users can parallelize instances of the java.time.LocalDateTime class and collect them back.
```
scala> val ds = Seq(java.time.LocalDateTime.parse("1970-01-01T00:00:00")).toDS
ds: org.apache.spark.sql.Dataset[java.time.LocalDateTime] = [value: timestampwithouttz]

scala> ds.collect()
res0: Array[java.time.LocalDateTime] = Array(1970-01-01T00:00)
```
### How was this patch tested?

New unit tests

Closes #32814 from gengliangwang/LocalDateTime.

Authored-by: Gengliang Wang <gengliang@apache.org>
Signed-off-by: Gengliang Wang <gengliang@apache.org>
This commit is contained in:
Gengliang Wang 2021-06-09 14:59:46 +08:00
parent 825b620862
commit 84c5ca33f9
23 changed files with 206 additions and 19 deletions

View file

@ -65,6 +65,9 @@ public final class SpecializedGettersReader {
if (dataType instanceof TimestampType) { if (dataType instanceof TimestampType) {
return obj.getLong(ordinal); return obj.getLong(ordinal);
} }
if (dataType instanceof TimestampWithoutTZType) {
return obj.getLong(ordinal);
}
if (dataType instanceof CalendarIntervalType) { if (dataType instanceof CalendarIntervalType) {
return obj.getInterval(ordinal); return obj.getInterval(ordinal);
} }

View file

@ -113,6 +113,14 @@ object Encoders {
*/ */
def LOCALDATE: Encoder[java.time.LocalDate] = ExpressionEncoder() def LOCALDATE: Encoder[java.time.LocalDate] = ExpressionEncoder()
/**
* Creates an encoder that serializes instances of the `java.time.LocalDateTime` class
* to the internal representation of nullable Catalyst's DateType.
*
* @since 3.2.0
*/
def LOCALDATETIME: Encoder[java.time.LocalDateTime] = ExpressionEncoder()
/** /**
* An encoder for nullable timestamp type. * An encoder for nullable timestamp type.
* *

View file

@ -21,7 +21,7 @@ import java.lang.{Iterable => JavaIterable}
import java.math.{BigDecimal => JavaBigDecimal} import java.math.{BigDecimal => JavaBigDecimal}
import java.math.{BigInteger => JavaBigInteger} import java.math.{BigInteger => JavaBigInteger}
import java.sql.{Date, Timestamp} import java.sql.{Date, Timestamp}
import java.time.{Duration, Instant, LocalDate, Period} import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period}
import java.util.{Map => JavaMap} import java.util.{Map => JavaMap}
import javax.annotation.Nullable import javax.annotation.Nullable
@ -66,6 +66,7 @@ object CatalystTypeConverters {
case DateType => DateConverter case DateType => DateConverter
case TimestampType if SQLConf.get.datetimeJava8ApiEnabled => InstantConverter case TimestampType if SQLConf.get.datetimeJava8ApiEnabled => InstantConverter
case TimestampType => TimestampConverter case TimestampType => TimestampConverter
case TimestampWithoutTZType => TimestampWithoutTZConverter
case dt: DecimalType => new DecimalConverter(dt) case dt: DecimalType => new DecimalConverter(dt)
case BooleanType => BooleanConverter case BooleanType => BooleanConverter
case ByteType => ByteConverter case ByteType => ByteConverter
@ -354,6 +355,23 @@ object CatalystTypeConverters {
DateTimeUtils.microsToInstant(row.getLong(column)) DateTimeUtils.microsToInstant(row.getLong(column))
} }
private object TimestampWithoutTZConverter
extends CatalystTypeConverter[Any, LocalDateTime, Any] {
override def toCatalystImpl(scalaValue: Any): Any = scalaValue match {
case l: LocalDateTime => DateTimeUtils.localDateTimeToMicros(l)
case other => throw new IllegalArgumentException(
s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) "
+ s"cannot be converted to the ${TimestampWithoutTZType.sql} type")
}
override def toScala(catalystValue: Any): LocalDateTime =
if (catalystValue == null) null
else DateTimeUtils.microsToLocalDateTime(catalystValue.asInstanceOf[Long])
override def toScalaImpl(row: InternalRow, column: Int): LocalDateTime =
DateTimeUtils.microsToLocalDateTime(row.getLong(column))
}
private class DecimalConverter(dataType: DecimalType) private class DecimalConverter(dataType: DecimalType)
extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] { extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] {
@ -489,6 +507,7 @@ object CatalystTypeConverters {
case ld: LocalDate => LocalDateConverter.toCatalyst(ld) case ld: LocalDate => LocalDateConverter.toCatalyst(ld)
case t: Timestamp => TimestampConverter.toCatalyst(t) case t: Timestamp => TimestampConverter.toCatalyst(t)
case i: Instant => InstantConverter.toCatalyst(i) case i: Instant => InstantConverter.toCatalyst(i)
case l: LocalDateTime => TimestampWithoutTZConverter.toCatalyst(l)
case d: BigDecimal => new DecimalConverter(DecimalType(d.precision, d.scale)).toCatalyst(d) case d: BigDecimal => new DecimalConverter(DecimalType(d.precision, d.scale)).toCatalyst(d)
case d: JavaBigDecimal => new DecimalConverter(DecimalType(d.precision, d.scale)).toCatalyst(d) case d: JavaBigDecimal => new DecimalConverter(DecimalType(d.precision, d.scale)).toCatalyst(d)
case seq: Seq[Any] => new GenericArrayData(seq.map(convertToCatalyst).toArray) case seq: Seq[Any] => new GenericArrayData(seq.map(convertToCatalyst).toArray)

View file

@ -118,6 +118,15 @@ object DeserializerBuildHelper {
returnNullable = false) returnNullable = false)
} }
def createDeserializerForLocalDateTime(path: Expression): Expression = {
StaticInvoke(
DateTimeUtils.getClass,
ObjectType(classOf[java.time.LocalDateTime]),
"microsToLocalDateTime",
path :: Nil,
returnNullable = false)
}
def createDeserializerForJavaBigDecimal( def createDeserializerForJavaBigDecimal(
path: Expression, path: Expression,
returnNullable: Boolean): Expression = { returnNullable: Boolean): Expression = {

View file

@ -134,7 +134,7 @@ object InternalRow {
case ShortType => (input, ordinal) => input.getShort(ordinal) case ShortType => (input, ordinal) => input.getShort(ordinal)
case IntegerType | DateType | YearMonthIntervalType => case IntegerType | DateType | YearMonthIntervalType =>
(input, ordinal) => input.getInt(ordinal) (input, ordinal) => input.getInt(ordinal)
case LongType | TimestampType | DayTimeIntervalType => case LongType | TimestampType | TimestampWithoutTZType | DayTimeIntervalType =>
(input, ordinal) => input.getLong(ordinal) (input, ordinal) => input.getLong(ordinal)
case FloatType => (input, ordinal) => input.getFloat(ordinal) case FloatType => (input, ordinal) => input.getFloat(ordinal)
case DoubleType => (input, ordinal) => input.getDouble(ordinal) case DoubleType => (input, ordinal) => input.getDouble(ordinal)
@ -171,7 +171,7 @@ object InternalRow {
case ShortType => (input, v) => input.setShort(ordinal, v.asInstanceOf[Short]) case ShortType => (input, v) => input.setShort(ordinal, v.asInstanceOf[Short])
case IntegerType | DateType | YearMonthIntervalType => case IntegerType | DateType | YearMonthIntervalType =>
(input, v) => input.setInt(ordinal, v.asInstanceOf[Int]) (input, v) => input.setInt(ordinal, v.asInstanceOf[Int])
case LongType | TimestampType | DayTimeIntervalType => case LongType | TimestampType | TimestampWithoutTZType | DayTimeIntervalType =>
(input, v) => input.setLong(ordinal, v.asInstanceOf[Long]) (input, v) => input.setLong(ordinal, v.asInstanceOf[Long])
case FloatType => (input, v) => input.setFloat(ordinal, v.asInstanceOf[Float]) case FloatType => (input, v) => input.setFloat(ordinal, v.asInstanceOf[Float])
case DoubleType => (input, v) => input.setDouble(ordinal, v.asInstanceOf[Double]) case DoubleType => (input, v) => input.setDouble(ordinal, v.asInstanceOf[Double])

View file

@ -119,6 +119,7 @@ object JavaTypeInference {
case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true) case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
case c: Class[_] if c == classOf[java.time.Instant] => (TimestampType, true) case c: Class[_] if c == classOf[java.time.Instant] => (TimestampType, true)
case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true) case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
case c: Class[_] if c == classOf[java.time.LocalDateTime] => (TimestampWithoutTZType, true)
case c: Class[_] if c == classOf[java.time.Duration] => (DayTimeIntervalType, true) case c: Class[_] if c == classOf[java.time.Duration] => (DayTimeIntervalType, true)
case c: Class[_] if c == classOf[java.time.Period] => (YearMonthIntervalType, true) case c: Class[_] if c == classOf[java.time.Period] => (YearMonthIntervalType, true)
@ -250,6 +251,9 @@ object JavaTypeInference {
case c if c == classOf[java.sql.Timestamp] => case c if c == classOf[java.sql.Timestamp] =>
createDeserializerForSqlTimestamp(path) createDeserializerForSqlTimestamp(path)
case c if c == classOf[java.time.LocalDateTime] =>
createDeserializerForLocalDateTime(path)
case c if c == classOf[java.time.Duration] => case c if c == classOf[java.time.Duration] =>
createDeserializerForDuration(path) createDeserializerForDuration(path)
@ -409,6 +413,9 @@ object JavaTypeInference {
case c if c == classOf[java.sql.Timestamp] => createSerializerForSqlTimestamp(inputObject) case c if c == classOf[java.sql.Timestamp] => createSerializerForSqlTimestamp(inputObject)
case c if c == classOf[java.time.LocalDateTime] =>
createSerializerForLocalDateTime(inputObject)
case c if c == classOf[java.time.LocalDate] => createSerializerForJavaLocalDate(inputObject) case c if c == classOf[java.time.LocalDate] => createSerializerForJavaLocalDate(inputObject)
case c if c == classOf[java.sql.Date] => createSerializerForSqlDate(inputObject) case c if c == classOf[java.sql.Date] => createSerializerForSqlDate(inputObject)

View file

@ -241,6 +241,9 @@ object ScalaReflection extends ScalaReflection {
case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) => case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) =>
createDeserializerForSqlTimestamp(path) createDeserializerForSqlTimestamp(path)
case t if isSubtype(t, localTypeOf[java.time.LocalDateTime]) =>
createDeserializerForLocalDateTime(path)
case t if isSubtype(t, localTypeOf[java.time.Duration]) => case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
createDeserializerForDuration(path) createDeserializerForDuration(path)
@ -524,6 +527,9 @@ object ScalaReflection extends ScalaReflection {
case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) => case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) =>
createSerializerForSqlTimestamp(inputObject) createSerializerForSqlTimestamp(inputObject)
case t if isSubtype(t, localTypeOf[java.time.LocalDateTime]) =>
createSerializerForLocalDateTime(inputObject)
case t if isSubtype(t, localTypeOf[java.time.LocalDate]) => case t if isSubtype(t, localTypeOf[java.time.LocalDate]) =>
createSerializerForJavaLocalDate(inputObject) createSerializerForJavaLocalDate(inputObject)
@ -746,6 +752,8 @@ object ScalaReflection extends ScalaReflection {
Schema(TimestampType, nullable = true) Schema(TimestampType, nullable = true)
case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) => case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) =>
Schema(TimestampType, nullable = true) Schema(TimestampType, nullable = true)
case t if isSubtype(t, localTypeOf[java.time.LocalDateTime]) =>
Schema(TimestampWithoutTZType, nullable = true)
case t if isSubtype(t, localTypeOf[java.time.LocalDate]) => Schema(DateType, nullable = true) case t if isSubtype(t, localTypeOf[java.time.LocalDate]) => Schema(DateType, nullable = true)
case t if isSubtype(t, localTypeOf[java.sql.Date]) => Schema(DateType, nullable = true) case t if isSubtype(t, localTypeOf[java.sql.Date]) => Schema(DateType, nullable = true)
case t if isSubtype(t, localTypeOf[CalendarInterval]) => case t if isSubtype(t, localTypeOf[CalendarInterval]) =>
@ -850,6 +858,7 @@ object ScalaReflection extends ScalaReflection {
StringType -> classOf[UTF8String], StringType -> classOf[UTF8String],
DateType -> classOf[DateType.InternalType], DateType -> classOf[DateType.InternalType],
TimestampType -> classOf[TimestampType.InternalType], TimestampType -> classOf[TimestampType.InternalType],
TimestampWithoutTZType -> classOf[TimestampWithoutTZType.InternalType],
BinaryType -> classOf[BinaryType.InternalType], BinaryType -> classOf[BinaryType.InternalType],
CalendarIntervalType -> classOf[CalendarInterval], CalendarIntervalType -> classOf[CalendarInterval],
DayTimeIntervalType -> classOf[DayTimeIntervalType.InternalType], DayTimeIntervalType -> classOf[DayTimeIntervalType.InternalType],
@ -866,6 +875,7 @@ object ScalaReflection extends ScalaReflection {
DoubleType -> classOf[java.lang.Double], DoubleType -> classOf[java.lang.Double],
DateType -> classOf[java.lang.Integer], DateType -> classOf[java.lang.Integer],
TimestampType -> classOf[java.lang.Long], TimestampType -> classOf[java.lang.Long],
TimestampWithoutTZType -> classOf[java.lang.Long],
DayTimeIntervalType -> classOf[java.lang.Long], DayTimeIntervalType -> classOf[java.lang.Long],
YearMonthIntervalType -> classOf[java.lang.Integer] YearMonthIntervalType -> classOf[java.lang.Integer]
) )

View file

@ -86,6 +86,15 @@ object SerializerBuildHelper {
returnNullable = false) returnNullable = false)
} }
def createSerializerForLocalDateTime(inputObject: Expression): Expression = {
StaticInvoke(
DateTimeUtils.getClass,
TimestampWithoutTZType,
"localDateTimeToMicros",
inputObject :: Nil,
returnNullable = false)
}
def createSerializerForJavaLocalDate(inputObject: Expression): Expression = { def createSerializerForJavaLocalDate(inputObject: Expression): Expression = {
StaticInvoke( StaticInvoke(
DateTimeUtils.getClass, DateTimeUtils.getClass,

View file

@ -297,6 +297,10 @@ package object dsl {
/** Creates a new AttributeReference of type timestamp */ /** Creates a new AttributeReference of type timestamp */
def timestamp: AttributeReference = AttributeReference(s, TimestampType, nullable = true)() def timestamp: AttributeReference = AttributeReference(s, TimestampType, nullable = true)()
/** Creates a new AttributeReference of type timestamp without time zone */
def timestampWithoutTZ: AttributeReference =
AttributeReference(s, TimestampWithoutTZType, nullable = true)()
/** Creates a new AttributeReference of the day-time interval type */ /** Creates a new AttributeReference of the day-time interval type */
def dayTimeInterval: AttributeReference = { def dayTimeInterval: AttributeReference = {
AttributeReference(s, DayTimeIntervalType, nullable = true)() AttributeReference(s, DayTimeIntervalType, nullable = true)()

View file

@ -53,6 +53,8 @@ import org.apache.spark.sql.types._
* TimestampType -> java.sql.Timestamp if spark.sql.datetime.java8API.enabled is false * TimestampType -> java.sql.Timestamp if spark.sql.datetime.java8API.enabled is false
* TimestampType -> java.time.Instant if spark.sql.datetime.java8API.enabled is true * TimestampType -> java.time.Instant if spark.sql.datetime.java8API.enabled is true
* *
* TimestampWithoutTZType -> java.time.LocalDateTime
*
* DayTimeIntervalType -> java.time.Duration * DayTimeIntervalType -> java.time.Duration
* YearMonthIntervalType -> java.time.Period * YearMonthIntervalType -> java.time.Period
* *
@ -103,6 +105,8 @@ object RowEncoder {
createSerializerForSqlTimestamp(inputObject) createSerializerForSqlTimestamp(inputObject)
} }
case TimestampWithoutTZType => createSerializerForLocalDateTime(inputObject)
case DateType => case DateType =>
if (SQLConf.get.datetimeJava8ApiEnabled) { if (SQLConf.get.datetimeJava8ApiEnabled) {
createSerializerForJavaLocalDate(inputObject) createSerializerForJavaLocalDate(inputObject)
@ -226,6 +230,8 @@ object RowEncoder {
} else { } else {
ObjectType(classOf[java.sql.Timestamp]) ObjectType(classOf[java.sql.Timestamp])
} }
case TimestampWithoutTZType =>
ObjectType(classOf[java.time.LocalDateTime])
case DateType => case DateType =>
if (SQLConf.get.datetimeJava8ApiEnabled) { if (SQLConf.get.datetimeJava8ApiEnabled) {
ObjectType(classOf[java.time.LocalDate]) ObjectType(classOf[java.time.LocalDate])
@ -281,6 +287,9 @@ object RowEncoder {
createDeserializerForSqlTimestamp(input) createDeserializerForSqlTimestamp(input)
} }
case TimestampWithoutTZType =>
createDeserializerForLocalDateTime(input)
case DateType => case DateType =>
if (SQLConf.get.datetimeJava8ApiEnabled) { if (SQLConf.get.datetimeJava8ApiEnabled) {
createDeserializerForLocalDate(input) createDeserializerForLocalDate(input)

View file

@ -160,7 +160,7 @@ object InterpretedUnsafeProjection {
case IntegerType | DateType | YearMonthIntervalType => case IntegerType | DateType | YearMonthIntervalType =>
(v, i) => writer.write(i, v.getInt(i)) (v, i) => writer.write(i, v.getInt(i))
case LongType | TimestampType | DayTimeIntervalType => case LongType | TimestampType | TimestampWithoutTZType | DayTimeIntervalType =>
(v, i) => writer.write(i, v.getLong(i)) (v, i) => writer.write(i, v.getLong(i))
case FloatType => case FloatType =>

View file

@ -195,8 +195,8 @@ final class SpecificInternalRow(val values: Array[MutableValue]) extends BaseGen
private[this] def dataTypeToMutableValue(dataType: DataType): MutableValue = dataType match { private[this] def dataTypeToMutableValue(dataType: DataType): MutableValue = dataType match {
// We use INT for DATE and YearMonthIntervalType internally // We use INT for DATE and YearMonthIntervalType internally
case IntegerType | DateType | YearMonthIntervalType => new MutableInt case IntegerType | DateType | YearMonthIntervalType => new MutableInt
// We use Long for Timestamp and DayTimeInterval internally // We use Long for Timestamp, Timestamp without time zone and DayTimeInterval internally
case LongType | TimestampType | DayTimeIntervalType => new MutableLong case LongType | TimestampType | TimestampWithoutTZType | DayTimeIntervalType => new MutableLong
case FloatType => new MutableFloat case FloatType => new MutableFloat
case DoubleType => new MutableDouble case DoubleType => new MutableDouble
case BooleanType => new MutableBoolean case BooleanType => new MutableBoolean

View file

@ -1817,7 +1817,7 @@ object CodeGenerator extends Logging {
case ByteType => JAVA_BYTE case ByteType => JAVA_BYTE
case ShortType => JAVA_SHORT case ShortType => JAVA_SHORT
case IntegerType | DateType | YearMonthIntervalType => JAVA_INT case IntegerType | DateType | YearMonthIntervalType => JAVA_INT
case LongType | TimestampType | DayTimeIntervalType => JAVA_LONG case LongType | TimestampType | TimestampWithoutTZType | DayTimeIntervalType => JAVA_LONG
case FloatType => JAVA_FLOAT case FloatType => JAVA_FLOAT
case DoubleType => JAVA_DOUBLE case DoubleType => JAVA_DOUBLE
case _: DecimalType => "Decimal" case _: DecimalType => "Decimal"
@ -1838,7 +1838,8 @@ object CodeGenerator extends Logging {
case ByteType => java.lang.Byte.TYPE case ByteType => java.lang.Byte.TYPE
case ShortType => java.lang.Short.TYPE case ShortType => java.lang.Short.TYPE
case IntegerType | DateType | YearMonthIntervalType => java.lang.Integer.TYPE case IntegerType | DateType | YearMonthIntervalType => java.lang.Integer.TYPE
case LongType | TimestampType | DayTimeIntervalType => java.lang.Long.TYPE case LongType | TimestampType | TimestampWithoutTZType | DayTimeIntervalType =>
java.lang.Long.TYPE
case FloatType => java.lang.Float.TYPE case FloatType => java.lang.Float.TYPE
case DoubleType => java.lang.Double.TYPE case DoubleType => java.lang.Double.TYPE
case _: DecimalType => classOf[Decimal] case _: DecimalType => classOf[Decimal]

View file

@ -28,7 +28,7 @@ import java.lang.{Short => JavaShort}
import java.math.{BigDecimal => JavaBigDecimal} import java.math.{BigDecimal => JavaBigDecimal}
import java.nio.charset.StandardCharsets import java.nio.charset.StandardCharsets
import java.sql.{Date, Timestamp} import java.sql.{Date, Timestamp}
import java.time.{Duration, Instant, LocalDate, Period} import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period}
import java.util import java.util
import java.util.Objects import java.util.Objects
import javax.xml.bind.DatatypeConverter import javax.xml.bind.DatatypeConverter
@ -80,6 +80,7 @@ object Literal {
case d: Decimal => Literal(d, DecimalType(Math.max(d.precision, d.scale), d.scale)) case d: Decimal => Literal(d, DecimalType(Math.max(d.precision, d.scale), d.scale))
case i: Instant => Literal(instantToMicros(i), TimestampType) case i: Instant => Literal(instantToMicros(i), TimestampType)
case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType) case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType)
case l: LocalDateTime => Literal(DateTimeUtils.localDateTimeToMicros(l), TimestampWithoutTZType)
case ld: LocalDate => Literal(ld.toEpochDay.toInt, DateType) case ld: LocalDate => Literal(ld.toEpochDay.toInt, DateType)
case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType) case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType)
case d: Duration => Literal(durationToMicros(d), DayTimeIntervalType) case d: Duration => Literal(durationToMicros(d), DayTimeIntervalType)
@ -119,6 +120,7 @@ object Literal {
case _ if clz == classOf[Date] => DateType case _ if clz == classOf[Date] => DateType
case _ if clz == classOf[Instant] => TimestampType case _ if clz == classOf[Instant] => TimestampType
case _ if clz == classOf[Timestamp] => TimestampType case _ if clz == classOf[Timestamp] => TimestampType
case _ if clz == classOf[LocalDateTime] => TimestampWithoutTZType
case _ if clz == classOf[Duration] => DayTimeIntervalType case _ if clz == classOf[Duration] => DayTimeIntervalType
case _ if clz == classOf[Period] => YearMonthIntervalType case _ if clz == classOf[Period] => YearMonthIntervalType
case _ if clz == classOf[JavaBigDecimal] => DecimalType.SYSTEM_DEFAULT case _ if clz == classOf[JavaBigDecimal] => DecimalType.SYSTEM_DEFAULT
@ -177,6 +179,7 @@ object Literal {
case dt: DecimalType => Literal(Decimal(0, dt.precision, dt.scale)) case dt: DecimalType => Literal(Decimal(0, dt.precision, dt.scale))
case DateType => create(0, DateType) case DateType => create(0, DateType)
case TimestampType => create(0L, TimestampType) case TimestampType => create(0L, TimestampType)
case TimestampWithoutTZType => create(0L, TimestampWithoutTZType)
case DayTimeIntervalType => create(0L, DayTimeIntervalType) case DayTimeIntervalType => create(0L, DayTimeIntervalType)
case YearMonthIntervalType => create(0, YearMonthIntervalType) case YearMonthIntervalType => create(0, YearMonthIntervalType)
case StringType => Literal("") case StringType => Literal("")
@ -198,7 +201,8 @@ object Literal {
case ByteType => v.isInstanceOf[Byte] case ByteType => v.isInstanceOf[Byte]
case ShortType => v.isInstanceOf[Short] case ShortType => v.isInstanceOf[Short]
case IntegerType | DateType | YearMonthIntervalType => v.isInstanceOf[Int] case IntegerType | DateType | YearMonthIntervalType => v.isInstanceOf[Int]
case LongType | TimestampType | DayTimeIntervalType => v.isInstanceOf[Long] case LongType | TimestampType | TimestampWithoutTZType | DayTimeIntervalType =>
v.isInstanceOf[Long]
case FloatType => v.isInstanceOf[Float] case FloatType => v.isInstanceOf[Float]
case DoubleType => v.isInstanceOf[Double] case DoubleType => v.isInstanceOf[Double]
case _: DecimalType => v.isInstanceOf[Decimal] case _: DecimalType => v.isInstanceOf[Decimal]
@ -422,7 +426,7 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression {
} }
case ByteType | ShortType => case ByteType | ShortType =>
ExprCode.forNonNullValue(JavaCode.expression(s"($javaType)$value", dataType)) ExprCode.forNonNullValue(JavaCode.expression(s"($javaType)$value", dataType))
case TimestampType | LongType | DayTimeIntervalType => case TimestampType | TimestampWithoutTZType | LongType | DayTimeIntervalType =>
toExprCode(s"${value}L") toExprCode(s"${value}L")
case _ => case _ =>
val constRef = ctx.addReferenceObj("literal", value, javaType) val constRef = ctx.addReferenceObj("literal", value, javaType)

View file

@ -71,6 +71,14 @@ object DateTimeUtils {
instantToMicros(instant) instantToMicros(instant)
} }
def microsToLocalDateTime(micros: Long): LocalDateTime = {
getLocalDateTime(micros, ZoneOffset.UTC)
}
def localDateTimeToMicros(localDateTime: LocalDateTime): Long = {
instantToMicros(localDateTime.toInstant(ZoneOffset.UTC))
}
/** /**
* Converts a local date at the default JVM time zone to the number of days since 1970-01-01 * Converts a local date at the default JVM time zone to the number of days since 1970-01-01
* in the hybrid calendar (Julian + Gregorian) by discarding the time part. The resulted days are * in the hybrid calendar (Julian + Gregorian) by discarding the time part. The resulted days are

View file

@ -171,7 +171,7 @@ object DataType {
private val otherTypes = { private val otherTypes = {
Seq(NullType, DateType, TimestampType, BinaryType, IntegerType, BooleanType, LongType, Seq(NullType, DateType, TimestampType, BinaryType, IntegerType, BooleanType, LongType,
DoubleType, FloatType, ShortType, ByteType, StringType, CalendarIntervalType, DoubleType, FloatType, ShortType, ByteType, StringType, CalendarIntervalType,
DayTimeIntervalType, YearMonthIntervalType) DayTimeIntervalType, YearMonthIntervalType, TimestampWithoutTZType)
.map(t => t.typeName -> t).toMap .map(t => t.typeName -> t).toMap
} }

View file

@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst package org.apache.spark.sql.catalyst
import java.time.{Duration, Instant, LocalDate, Period} import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period}
import org.apache.spark.SparkFunSuite import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row import org.apache.spark.sql.Row
@ -189,6 +189,35 @@ class CatalystTypeConvertersSuite extends SparkFunSuite with SQLHelper {
} }
} }
test("SPARK-35664: converting java.time.LocalDateTime to TimestampWithoutTZType") {
Seq(
"0101-02-16T10:11:32",
"1582-10-02T01:02:03.04",
"1582-12-31T23:59:59.999999",
"1970-01-01T00:00:01.123",
"1972-12-31T23:59:59.123456",
"2019-02-16T18:12:30",
"2119-03-16T19:13:31").foreach { text =>
val input = LocalDateTime.parse(text)
val result = CatalystTypeConverters.convertToCatalyst(input)
val expected = DateTimeUtils.localDateTimeToMicros(input)
assert(result === expected)
}
}
test("SPARK-35664: converting TimestampWithoutTZType to java.time.LocalDateTime") {
Seq(
-9463427405253013L,
-244000001L,
0L,
99628200102030L,
1543749753123456L).foreach { us =>
val localDateTime = DateTimeUtils.microsToLocalDateTime(us)
assert(CatalystTypeConverters.createToScalaConverter(TimestampWithoutTZType)(us) ===
localDateTime)
}
}
test("converting java.time.LocalDate to DateType") { test("converting java.time.LocalDate to DateType") {
Seq( Seq(
"0101-02-16", "0101-02-16",

View file

@ -330,6 +330,16 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
} }
} }
test("SPARK-35664: encoding/decoding TimestampWithoutTZType to/from java.time.LocalDateTime") {
val schema = new StructType().add("t", TimestampWithoutTZType)
val encoder = RowEncoder(schema).resolveAndBind()
val localDateTime = java.time.LocalDateTime.parse("2019-02-26T16:56:00")
val row = toRow(encoder, Row(localDateTime))
assert(row.getLong(0) === DateTimeUtils.localDateTimeToMicros(localDateTime))
val readback = fromRow(encoder, row)
assert(readback.get(0) === localDateTime)
}
test("encoding/decoding DateType to/from java.time.LocalDate") { test("encoding/decoding DateType to/from java.time.LocalDate") {
withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") { withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") {
val schema = new StructType().add("d", DateType) val schema = new StructType().add("d", DateType)

View file

@ -360,6 +360,17 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
} }
} }
test("SPARK-35664: construct literals from java.time.LocalDateTime") {
Seq(
LocalDateTime.of(1, 1, 1, 0, 0, 0, 0),
LocalDateTime.of(2021, 5, 31, 23, 59, 59, 100),
LocalDateTime.of(2020, 2, 29, 23, 50, 57, 9999),
LocalDateTime.parse("9999-12-31T23:59:59.999999")
).foreach { dateTime =>
checkEvaluation(Literal(dateTime), dateTime)
}
}
test("SPARK-34605: construct literals from java.time.Duration") { test("SPARK-34605: construct literals from java.time.Duration") {
Seq( Seq(
Duration.ofNanos(0), Duration.ofNanos(0),

View file

@ -634,6 +634,39 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper {
} }
} }
test("SPARK-35664: microseconds to LocalDateTime") {
assert(microsToLocalDateTime(0) == LocalDateTime.parse("1970-01-01T00:00:00"))
assert(microsToLocalDateTime(100) == LocalDateTime.parse("1970-01-01T00:00:00.0001"))
assert(microsToLocalDateTime(100000000) == LocalDateTime.parse("1970-01-01T00:01:40"))
assert(microsToLocalDateTime(100000000000L) == LocalDateTime.parse("1970-01-02T03:46:40"))
assert(microsToLocalDateTime(253402300799999999L) ==
LocalDateTime.parse("9999-12-31T23:59:59.999999"))
assert(microsToLocalDateTime(Long.MinValue) ==
LocalDateTime.parse("-290308-12-21T19:59:05.224192"))
assert(microsToLocalDateTime(Long.MaxValue) ==
LocalDateTime.parse("+294247-01-10T04:00:54.775807"))
}
test("SPARK-35664: LocalDateTime to microseconds") {
assert(DateTimeUtils.localDateTimeToMicros(LocalDateTime.parse("1970-01-01T00:00:00")) == 0)
assert(
DateTimeUtils.localDateTimeToMicros(LocalDateTime.parse("1970-01-01T00:00:00.0001")) == 100)
assert(
DateTimeUtils.localDateTimeToMicros(LocalDateTime.parse("1970-01-01T00:01:40")) == 100000000)
assert(DateTimeUtils.localDateTimeToMicros(LocalDateTime.parse("1970-01-02T03:46:40")) ==
100000000000L)
assert(DateTimeUtils.localDateTimeToMicros(LocalDateTime.parse("9999-12-31T23:59:59.999999"))
== 253402300799999999L)
assert(DateTimeUtils.localDateTimeToMicros(LocalDateTime.parse("-1000-12-31T23:59:59.999999"))
== -93692592000000001L)
Seq(LocalDateTime.MIN, LocalDateTime.MAX).foreach { dt =>
val msg = intercept[ArithmeticException] {
DateTimeUtils.localDateTimeToMicros(dt)
}.getMessage
assert(msg == "long overflow")
}
}
test("daysToMicros and microsToDays") { test("daysToMicros and microsToDays") {
val input = date(2015, 12, 31, 16, zid = LA) val input = date(2015, 12, 31, 16, zid = LA)
assert(microsToDays(input, LA) === 16800) assert(microsToDays(input, LA) === 16800)
@ -780,8 +813,8 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper {
(LocalDateTime.of(2021, 3, 14, 1, 0, 0), LocalDateTime.of(2021, 3, 14, 3, 0, 0)) -> (LocalDateTime.of(2021, 3, 14, 1, 0, 0), LocalDateTime.of(2021, 3, 14, 3, 0, 0)) ->
TimeUnit.HOURS.toMicros(2) TimeUnit.HOURS.toMicros(2)
).foreach { case ((start, end), expected) => ).foreach { case ((start, end), expected) =>
val startMicros = localDateTimeToMicros(start, zid) val startMicros = DateTimeTestUtils.localDateTimeToMicros(start, zid)
val endMicros = localDateTimeToMicros(end, zid) val endMicros = DateTimeTestUtils.localDateTimeToMicros(end, zid)
val result = subtractTimestamps(endMicros, startMicros, zid) val result = subtractTimestamps(endMicros, startMicros, zid)
assert(result === expected) assert(result === expected)
} }

View file

@ -82,6 +82,9 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits {
/** @since 3.0.0 */ /** @since 3.0.0 */
implicit def newLocalDateEncoder: Encoder[java.time.LocalDate] = Encoders.LOCALDATE implicit def newLocalDateEncoder: Encoder[java.time.LocalDate] = Encoders.LOCALDATE
/** @since 3.2.0 */
implicit def newLocalDateTimeEncoder: Encoder[java.time.LocalDateTime] = Encoders.LOCALDATETIME
/** @since 2.2.0 */ /** @since 2.2.0 */
implicit def newTimeStampEncoder: Encoder[java.sql.Timestamp] = Encoders.TIMESTAMP implicit def newTimeStampEncoder: Encoder[java.sql.Timestamp] = Encoders.TIMESTAMP

View file

@ -21,10 +21,7 @@ import java.io.Serializable;
import java.math.BigDecimal; import java.math.BigDecimal;
import java.sql.Date; import java.sql.Date;
import java.sql.Timestamp; import java.sql.Timestamp;
import java.time.Duration; import java.time.*;
import java.time.Instant;
import java.time.LocalDate;
import java.time.Period;
import java.util.*; import java.util.*;
import javax.annotation.Nonnull; import javax.annotation.Nonnull;
@ -413,6 +410,14 @@ public class JavaDatasetSuite implements Serializable {
Assert.assertEquals(data, ds.collectAsList()); Assert.assertEquals(data, ds.collectAsList());
} }
@Test
public void testLocalDateTimeEncoder() {
Encoder<LocalDateTime> encoder = Encoders.LOCALDATETIME();
List<LocalDateTime> data = Arrays.asList(LocalDateTime.of(1, 1, 1, 1, 1));
Dataset<LocalDateTime> ds = spark.createDataset(data, encoder);
Assert.assertEquals(data, ds.collectAsList());
}
@Test @Test
public void testDurationEncoder() { public void testDurationEncoder() {
Encoder<Duration> encoder = Encoders.DURATION(); Encoder<Duration> encoder = Encoders.DURATION();

View file

@ -2008,6 +2008,11 @@ class DatasetSuite extends QueryTest
checkAnswer(withUDF, Row(Row(1), null, null) :: Row(Row(1), null, null) :: Nil) checkAnswer(withUDF, Row(Row(1), null, null) :: Row(Row(1), null, null) :: Nil)
} }
test("SPARK-35664: implicit encoder for java.time.LocalDateTime") {
val localDateTime = java.time.LocalDateTime.parse("2021-06-08T12:31:58.999999")
assert(Seq(localDateTime).toDS().head() === localDateTime)
}
test("SPARK-34605: implicit encoder for java.time.Duration") { test("SPARK-34605: implicit encoder for java.time.Duration") {
val duration = java.time.Duration.ofMinutes(10) val duration = java.time.Duration.ofMinutes(10)
assert(spark.range(1).map { _ => duration }.head === duration) assert(spark.range(1).map { _ => duration }.head === duration)