diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala
index 19e212d1f9..a7e9a22b13 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala
@@ -61,7 +61,7 @@ case class ApproxCountDistinctForIntervals(
   }
 
   override def inputTypes: Seq[AbstractDataType] = {
-    Seq(TypeCollection(NumericType, TimestampType, DateType), ArrayType)
+    Seq(TypeCollection(NumericType, TimestampType, DateType, TimestampNTZType), ArrayType)
   }
 
   // Mark as lazy so that endpointsExpression is not evaluated during tree transformation.
@@ -79,7 +79,7 @@ case class ApproxCountDistinctForIntervals(
       TypeCheckFailure("The endpoints provided must be constant literals")
     } else {
       endpointsExpression.dataType match {
-        case ArrayType(_: NumericType | DateType | TimestampType, _) =>
+        case ArrayType(_: NumericType | DateType | TimestampType | TimestampNTZType, _) =>
           if (endpoints.length < 2) {
             TypeCheckFailure("The number of endpoints must be >= 2 to construct intervals")
           } else {
@@ -122,7 +122,7 @@ case class ApproxCountDistinctForIntervals(
           n.numeric.toDouble(value.asInstanceOf[n.InternalType])
         case _: DateType =>
          value.asInstanceOf[Int].toDouble
-        case _: TimestampType =>
+        case TimestampType | TimestampNTZType =>
          value.asInstanceOf[Long].toDouble
       }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala
index 73f18d4fee..9d53673296 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
 
 import java.sql.{Date, Timestamp}
+import java.time.LocalDateTime
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
@@ -38,7 +39,7 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
     assert(
       wrongColumn.checkInputDataTypes() match {
         case TypeCheckFailure(msg)
-          if msg.contains("requires (numeric or timestamp or date) type") => true
+          if msg.contains("requires (numeric or timestamp or date or timestamp_ntz) type") => true
         case _ => false
       })
   }
@@ -199,7 +200,9 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
       (intRecords.map(DateTimeUtils.toJavaDate),
         intEndpoints.map(DateTimeUtils.toJavaDate), DateType),
       (intRecords.map(DateTimeUtils.toJavaTimestamp(_)),
-        intEndpoints.map(DateTimeUtils.toJavaTimestamp(_)), TimestampType)
+        intEndpoints.map(DateTimeUtils.toJavaTimestamp(_)), TimestampType),
+      (intRecords.map(DateTimeUtils.microsToLocalDateTime(_)),
+        intEndpoints.map(DateTimeUtils.microsToLocalDateTime(_)), TimestampNTZType)
     )
 
     inputs.foreach { case (records, endpoints, dataType) =>
@@ -209,6 +212,7 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
         val value = r match {
           case d: Date => DateTimeUtils.fromJavaDate(d)
           case t: Timestamp => DateTimeUtils.fromJavaTimestamp(t)
+          case ldt: LocalDateTime => DateTimeUtils.localDateTimeToMicros(ldt)
           case _ => r
         }
         input.update(0, value)
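
Note (not part of the patch above): a minimal, self-contained sketch of the conversion path the change relies on. TIMESTAMP_NTZ values appear as java.time.LocalDateTime at the external level and as microseconds-since-epoch Longs internally, so the aggregate can reuse the same Long-to-Double path it already uses for TimestampType. The localDateTimeToMicros helper below is an assumed stand-in that mirrors DateTimeUtils.localDateTimeToMicros for illustration only; the patch itself calls DateTimeUtils directly.

// Illustrative sketch only -- not Spark source code.
import java.time.{LocalDateTime, ZoneOffset}

object TimestampNtzConversionSketch {
  // Assumed stand-in for DateTimeUtils.localDateTimeToMicros: a TIMESTAMP_NTZ value is
  // interpreted field-by-field, with no time zone, as microseconds since the epoch.
  def localDateTimeToMicros(ldt: LocalDateTime): Long =
    ldt.toEpochSecond(ZoneOffset.UTC) * 1000000L + ldt.getNano / 1000L

  // The aggregate searches its interval endpoints on a Double scale, so the internal
  // Long representation and an external LocalDateTime map to the same Double.
  def toDoubleValue(value: Any): Double = value match {
    case ldt: LocalDateTime => localDateTimeToMicros(ldt).toDouble
    case micros: Long => micros.toDouble
  }

  def main(args: Array[String]): Unit = {
    val ldt = LocalDateTime.of(2021, 7, 1, 12, 30, 0)
    // Both forms land on the same point of the endpoint axis.
    assert(toDoubleValue(ldt) == toDoubleValue(localDateTimeToMicros(ldt)))
    println(toDoubleValue(ldt))
  }
}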