[SPARK-36323][SQL] Support ANSI interval literals for TimeWindow

### What changes were proposed in this pull request?

This PR proposes to support ANSI interval literals for `TimeWindow`.
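
For illustration, a minimal sketch (not part of this PR's diff) of what the change enables in an aggregation; the DataFrame `events` and its `eventTime` column are hypothetical:

```scala
import org.apache.spark.sql.functions.{col, count, window}

// Previously only forms like "10 seconds" were accepted as the window
// duration; with this change an ANSI interval literal works as well.
val counts = events
  .groupBy(window(col("eventTime"), "INTERVAL '10' SECOND"))
  .agg(count("*"))
```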

### Why are the changes needed?

Watermark already supports ANSI interval literals, so it is natural for `TimeWindow` to support them as well.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New tests added to `TimeWindowSuite`.

Closes #33551 from sarutak/window-interval.

Authored-by: Kousuke Saruta <sarutak@oss.nttdata.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
Commit db18866742 (parent 6a8dd3229a), committed 2021-07-29 08:51:51 +03:00 by Max Gekk.
5 changed files with 82 additions and 32 deletions.

TimeWindow.scala:

```diff
@@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_DAY
 import org.apache.spark.sql.catalyst.util.IntervalUtils
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
 
 case class TimeWindow(
     timeColumn: Expression,
@@ -110,12 +109,12 @@ object TimeWindow {
    * precision.
    */
   def getIntervalInMicroSeconds(interval: String): Long = {
-    val cal = IntervalUtils.stringToInterval(UTF8String.fromString(interval))
+    val cal = IntervalUtils.fromIntervalString(interval)
     if (cal.months != 0) {
       throw new IllegalArgumentException(
         s"Intervals greater than a month is not supported ($interval).")
     }
-    cal.days * MICROS_PER_DAY + cal.microseconds
+    Math.addExact(Math.multiplyExact(cal.days, MICROS_PER_DAY), cal.microseconds)
   }
 
   /**
```
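
Besides switching to the new parser entry point, the conversion now uses `Math.addExact`/`Math.multiplyExact`, so a duration whose microsecond value does not fit in a `Long` fails fast with an `ArithmeticException` instead of silently wrapping. A standalone sketch of the difference (the values here are hypothetical, not from the PR):

```scala
object OverflowSketch extends App {
  val microsPerDay = 86400000000L // same value as DateTimeConstants.MICROS_PER_DAY
  // Smallest day count whose microsecond equivalent no longer fits in a Long.
  val days = Long.MaxValue / microsPerDay + 1

  // Plain multiplication wraps around and yields a nonsense (negative) value.
  println(days * microsPerDay)

  // multiplyExact surfaces the overflow instead of corrupting the result.
  try Math.multiplyExact(days, microsPerDay)
  catch { case e: ArithmeticException => println(s"caught: ${e.getMessage}") }
}
```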

IntervalUtils.scala:

```diff
@@ -25,11 +25,14 @@ import java.util.concurrent.TimeUnit
 import scala.collection.mutable
 import scala.util.control.NonFatal
 
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.util.DateTimeConstants._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils.millisToMicros
 import org.apache.spark.sql.catalyst.util.IntervalStringStyles.{ANSI_STYLE, HIVE_STYLE, IntervalStyle}
-import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
 import org.apache.spark.sql.types.{DayTimeIntervalType => DT, Decimal, YearMonthIntervalType => YM}
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -433,6 +436,24 @@ object IntervalUtils {
     }
   }
 
+  /**
+   * Parse all kinds of interval literals including unit-to-unit form and unit list form
+   */
+  def fromIntervalString(input: String): CalendarInterval = try {
+    if (input.toLowerCase(Locale.ROOT).trim.startsWith("interval")) {
+      CatalystSqlParser.parseExpression(input) match {
+        case Literal(months: Int, _: YearMonthIntervalType) => new CalendarInterval(months, 0, 0)
+        case Literal(micros: Long, _: DayTimeIntervalType) => new CalendarInterval(0, 0, micros)
+        case Literal(cal: CalendarInterval, CalendarIntervalType) => cal
+      }
+    } else {
+      stringToInterval(UTF8String.fromString(input))
+    }
+  } catch {
+    case NonFatal(e) =>
+      throw QueryCompilationErrors.cannotParseIntervalError(input, e)
+  }
+
   private val dayTimePatternLegacy =
     "^([+|-])?((\\d+) )?((\\d+):)?(\\d+):(\\d+)(\\.(\\d+))?$".r
```

QueryCompilationErrors.scala:

```diff
@@ -2261,8 +2261,8 @@ private[spark] object QueryCompilationErrors {
       s"""Cannot resolve column name "$colName" among (${fieldsStr})${extraMsg}""")
   }
 
-  def cannotParseTimeDelayError(delayThreshold: String, e: Throwable): Throwable = {
-    new AnalysisException(s"Unable to parse time delay '$delayThreshold'", cause = Some(e))
+  def cannotParseIntervalError(delayThreshold: String, e: Throwable): Throwable = {
+    new AnalysisException(s"Unable to parse '$delayThreshold'", cause = Some(e))
   }
 
   def invalidJoinTypeInJoinWithError(joinType: JoinType): Throwable = {
```

TimeWindowSuite.scala:

```diff
@@ -17,10 +17,14 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import scala.reflect.ClassTag
+
 import org.scalatest.PrivateMethodTester
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.util.DateTimeConstants._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampNTZType, TimestampType}
 
 class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with PrivateMethodTester {
@@ -31,16 +35,16 @@ class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with Priva
     }
   }
 
-  private def checkErrorMessage(msg: String, value: String): Unit = {
+  private def checkErrorMessage[E <: Exception : ClassTag](msg: String, value: String): Unit = {
     val validDuration = "10 second"
     val validTime = "5 second"
-    val e1 = intercept[IllegalArgumentException] {
+    val e1 = intercept[E] {
       TimeWindow(Literal(10L), value, validDuration, validTime).windowDuration
     }
-    val e2 = intercept[IllegalArgumentException] {
+    val e2 = intercept[E] {
       TimeWindow(Literal(10L), validDuration, value, validTime).slideDuration
     }
-    val e3 = intercept[IllegalArgumentException] {
+    val e3 = intercept[E] {
       TimeWindow(Literal(10L), validDuration, validDuration, value).startTime
     }
     Seq(e1, e2, e3).foreach { e =>
@@ -50,18 +54,18 @@ class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with Priva
 
   test("blank intervals throw exception") {
     for (blank <- Seq(null, " ", "\n", "\t")) {
-      checkErrorMessage(
+      checkErrorMessage[AnalysisException](
         "The window duration, slide duration and start time cannot be null or blank.", blank)
     }
   }
 
   test("invalid intervals throw exception") {
-    checkErrorMessage(
+    checkErrorMessage[AnalysisException](
       "did not correspond to a valid interval string.", "2 apples")
   }
 
   test("intervals greater than a month throws exception") {
-    checkErrorMessage(
+    checkErrorMessage[IllegalArgumentException](
       "Intervals greater than or equal to a month is not supported (1 month).", "1 month")
   }
@@ -111,7 +115,7 @@ class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with Priva
   }
 
   test("parse sql expression for duration in microseconds - invalid interval") {
-    intercept[IllegalArgumentException] {
+    intercept[AnalysisException] {
       TimeWindow.invokePrivate(parseExpression(Literal("2 apples")))
     }
   }
@@ -147,4 +151,46 @@ class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with Priva
     assert(timestampNTZWindow.dataType == StructType(
       Seq(StructField("start", TimestampNTZType), StructField("end", TimestampNTZType))))
   }
+
+  Seq("true", "false").foreach { legacyIntervalEnabled =>
+    test("SPARK-36323: Support ANSI interval literals for TimeWindow " +
+        s"(${SQLConf.LEGACY_INTERVAL_ENABLED.key}=$legacyIntervalEnabled)") {
+      withSQLConf(SQLConf.LEGACY_INTERVAL_ENABLED.key -> legacyIntervalEnabled) {
+        Seq(
+          // Conventional form and some variants
+          (Seq("3 days", "Interval 3 day", "inTerval '3' day"), 3 * MICROS_PER_DAY),
+          (Seq(" 5 hours", "INTERVAL 5 hour", "interval '5' hour"), 5 * MICROS_PER_HOUR),
+          (Seq("\t8 minutes", "interval 8 minute", "interval '8' minute"), 8 * MICROS_PER_MINUTE),
+          (Seq(
+            "10 seconds", "interval 10 second", "interval '10' second"), 10 * MICROS_PER_SECOND),
+          (Seq(
+            "1 day 2 hours 3 minutes 4 seconds",
+            " interval 1 day 2 hours 3 minutes 4 seconds",
+            "\tinterval '1' day '2' hours '3' minutes '4' seconds",
+            "interval '1 2:3:4' day to second"),
+            MICROS_PER_DAY + 2 * MICROS_PER_HOUR + 3 * MICROS_PER_MINUTE + 4 * MICROS_PER_SECOND)
+        ).foreach { case (intervalVariants, expectedMs) =>
+          intervalVariants.foreach { case interval =>
+            val timeWindow = TimeWindow(Literal(10L, TimestampType), interval, interval, interval)
+            val expected =
+              TimeWindow(Literal(10L, TimestampType), expectedMs, expectedMs, expectedMs)
+            assert(timeWindow === expected)
+          }
+        }
+
+        // year-month interval literals are not supported for TimeWindow.
+        Seq(
+          "1 years", "interval 1 year", "interval '1' year",
+          "1 months", "interval 1 month", "interval '1' month",
+          " 1 year 2 months",
+          "interval 1 year 2 month",
+          "interval '1' year '2' month",
+          "\tinterval '1-2' year to month").foreach { interval =>
+          intercept[IllegalArgumentException] {
+            TimeWindow(Literal(10L, TimestampType), interval, interval, interval)
+          }
+        }
+      }
+    }
+  }
 }
```

Dataset.scala:

```diff
@@ -18,7 +18,6 @@
 package org.apache.spark.sql
 
 import java.io.{ByteArrayOutputStream, CharArrayWriter, DataOutputStream}
-import java.util.Locale
 
 import scala.annotation.varargs
 import scala.collection.JavaConverters._
@@ -44,7 +43,7 @@ import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions}
 import org.apache.spark.sql.catalyst.optimizer.CombineUnions
-import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils}
+import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection}
@@ -65,7 +64,6 @@ import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.SchemaUtils
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.unsafe.array.ByteArrayMethods
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 import org.apache.spark.util.Utils
 
 private[sql] object Dataset {
@@ -741,21 +739,7 @@ class Dataset[T] private[sql](
   // We only accept an existing column name, not a derived column here as a watermark that is
   // defined on a derived column cannot referenced elsewhere in the plan.
   def withWatermark(eventTime: String, delayThreshold: String): Dataset[T] = withTypedPlan {
-    val parsedDelay = try {
-      if (delayThreshold.toLowerCase(Locale.ROOT).trim.startsWith("interval")) {
-        CatalystSqlParser.parseExpression(delayThreshold) match {
-          case Literal(months: Int, _: YearMonthIntervalType) =>
-            new CalendarInterval(months, 0, 0)
-          case Literal(micros: Long, _: DayTimeIntervalType) =>
-            new CalendarInterval(0, 0, micros)
-        }
-      } else {
-        IntervalUtils.stringToInterval(UTF8String.fromString(delayThreshold))
-      }
-    } catch {
-      case NonFatal(e) =>
-        throw QueryCompilationErrors.cannotParseTimeDelayError(delayThreshold, e)
-    }
+    val parsedDelay = IntervalUtils.fromIntervalString(delayThreshold)
     require(!IntervalUtils.isNegative(parsedDelay),
       s"delay threshold ($delayThreshold) should not be negative.")
     EliminateEventTimeWatermark(
```
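
With the ad-hoc parsing above consolidated into `IntervalUtils.fromIntervalString`, `withWatermark` keeps accepting both spellings of a delay threshold. A brief usage sketch (the streaming Dataset `events` and its column are hypothetical):

```scala
// Both delay thresholds now go through the same fromIntervalString code path.
val legacyDelay = events.withWatermark("eventTime", "10 seconds")
val ansiDelay = events.withWatermark("eventTime", "INTERVAL '10' SECOND")
```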