[SPARK-36017][SQL] Support TimestampNTZType in expression ApproximatePercentile

### What changes were proposed in this pull request?
The current `ApproximatePercentile` supports `TimestampType`, but does not yet support timestamp without time zone.
This PR will add the function.

### Why are the changes needed?
`ApproximatePercentile` needs to support `TimestampNTZType`.

### Does this PR introduce _any_ user-facing change?
'Yes'. `ApproximatePercentile` accepts `TimestampNTZType`.

### How was this patch tested?
New tests.

Closes #33241 from beliefer/SPARK-36017.

Authored-by: gengjiaan <gengjiaan@360.cn>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
(cherry picked from commit cc4463e818)
Signed-off-by: Max Gekk <max.gekk@gmail.com>
This commit is contained in:
gengjiaan 2021-07-07 12:41:11 +03:00 committed by Max Gekk
parent dd038aacd4
commit 25ea296c3c
2 changed files with 14 additions and 10 deletions

View file

@ -92,9 +92,9 @@ case class ApproximatePercentile(
private lazy val accuracy: Long = accuracyExpression.eval().asInstanceOf[Number].longValue
override def inputTypes: Seq[AbstractDataType] = {
// Support NumericType, DateType and TimestampType since their internal types are all numeric,
// and can be easily cast to double for processing.
Seq(TypeCollection(NumericType, DateType, TimestampType),
// Support NumericType, DateType, TimestampType and TimestampNTZType since their internal types
// are all numeric, and can be easily cast to double for processing.
Seq(TypeCollection(NumericType, DateType, TimestampType, TimestampNTZType),
TypeCollection(DoubleType, ArrayType(DoubleType, containsNull = false)), IntegralType)
}
@ -139,7 +139,7 @@ case class ApproximatePercentile(
// Convert the value to a double value
val doubleValue = child.dataType match {
case DateType => value.asInstanceOf[Int].toDouble
case TimestampType => value.asInstanceOf[Long].toDouble
case TimestampType | TimestampNTZType => value.asInstanceOf[Long].toDouble
case n: NumericType => n.numeric.toDouble(value.asInstanceOf[n.InternalType])
case other: DataType =>
throw QueryExecutionErrors.dataTypeUnexpectedError(other)
@ -158,7 +158,7 @@ case class ApproximatePercentile(
val doubleResult = buffer.getPercentiles(percentages)
val result = child.dataType match {
case DateType => doubleResult.map(_.toInt)
case TimestampType => doubleResult.map(_.toLong)
case TimestampType | TimestampNTZType => doubleResult.map(_.toLong)
case ByteType => doubleResult.map(_.toByte)
case ShortType => doubleResult.map(_.toShort)
case IntegerType => doubleResult.map(_.toInt)

View file

@ -18,6 +18,7 @@
package org.apache.spark.sql
import java.sql.{Date, Timestamp}
import java.time.LocalDateTime
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY
@ -89,23 +90,26 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSparkSession
test("percentile_approx, different column types") {
withTempView(table) {
val intSeq = 1 to 1000
val data: Seq[(java.math.BigDecimal, Date, Timestamp)] = intSeq.map { i =>
(new java.math.BigDecimal(i), DateTimeUtils.toJavaDate(i), DateTimeUtils.toJavaTimestamp(i))
val data: Seq[(java.math.BigDecimal, Date, Timestamp, LocalDateTime)] = intSeq.map { i =>
(new java.math.BigDecimal(i), DateTimeUtils.toJavaDate(i),
DateTimeUtils.toJavaTimestamp(i), DateTimeUtils.microsToLocalDateTime(i))
}
data.toDF("cdecimal", "cdate", "ctimestamp").createOrReplaceTempView(table)
data.toDF("cdecimal", "cdate", "ctimestamp", "ctimestampntz").createOrReplaceTempView(table)
checkAnswer(
spark.sql(
s"""SELECT
| percentile_approx(cdecimal, array(0.25, 0.5, 0.75D)),
| percentile_approx(cdate, array(0.25, 0.5, 0.75D)),
| percentile_approx(ctimestamp, array(0.25, 0.5, 0.75D))
| percentile_approx(ctimestamp, array(0.25, 0.5, 0.75D)),
| percentile_approx(ctimestampntz, array(0.25, 0.5, 0.75D))
|FROM $table
""".stripMargin),
Row(
Seq("250.000000000000000000", "500.000000000000000000", "750.000000000000000000")
.map(i => new java.math.BigDecimal(i)),
Seq(250, 500, 750).map(DateTimeUtils.toJavaDate),
Seq(250, 500, 750).map(i => DateTimeUtils.toJavaTimestamp(i.toLong)))
Seq(250, 500, 750).map(i => DateTimeUtils.toJavaTimestamp(i.toLong)),
Seq(250, 500, 750).map(i => DateTimeUtils.microsToLocalDateTime(i.toLong)))
)
}
}