[SPARK-36323][SQL] Support ANSI interval literals for TimeWindow
### What changes were proposed in this pull request?

This PR proposes to support ANSI interval literals for `TimeWindow`.

### Why are the changes needed?

Watermark already supports ANSI interval literals, so it is natural to support them for `TimeWindow` as well.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New test.

Closes #33551 from sarutak/window-interval.

Authored-by: Kousuke Saruta <sarutak@oss.nttdata.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
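For illustration only: after this change, the duration strings accepted by `window` can use ANSI `INTERVAL` syntax in addition to the conventional multi-unit form. A minimal sketch (the DataFrame `df` and its `eventTime` column are hypothetical, not part of this patch):

```scala
import org.apache.spark.sql.functions.{col, window}

// Hypothetical DataFrame with a timestamp column "eventTime".
// Both duration strings below should parse to the same 10-second tumbling window.
val conventional = df.groupBy(window(col("eventTime"), "10 seconds")).count()
val ansi = df.groupBy(window(col("eventTime"), "INTERVAL '10' SECOND")).count()
```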
parent 6a8dd3229a
commit db18866742
**sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala**

```diff
@@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_DAY
 import org.apache.spark.sql.catalyst.util.IntervalUtils
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String

 case class TimeWindow(
     timeColumn: Expression,
@@ -110,12 +109,12 @@ object TimeWindow {
    * precision.
    */
   def getIntervalInMicroSeconds(interval: String): Long = {
-    val cal = IntervalUtils.stringToInterval(UTF8String.fromString(interval))
+    val cal = IntervalUtils.fromIntervalString(interval)
     if (cal.months != 0) {
       throw new IllegalArgumentException(
         s"Intervals greater than a month is not supported ($interval).")
     }
-    cal.days * MICROS_PER_DAY + cal.microseconds
+    Math.addExact(Math.multiplyExact(cal.days, MICROS_PER_DAY), cal.microseconds)
   }

   /**
```
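A side note on the second change above: replacing plain `Long` arithmetic with `Math.multiplyExact`/`Math.addExact` makes an overflowing interval fail loudly instead of silently wrapping around. A minimal standalone sketch of the difference:

```scala
// Plain Long arithmetic wraps around silently on overflow:
val wrapped = Long.MaxValue + 1L  // == Long.MinValue, no error raised

// The java.lang.Math *Exact variants fail fast instead:
try {
  Math.addExact(Long.MaxValue, 1L)
} catch {
  case e: ArithmeticException => println(e.getMessage)  // prints "long overflow"
}
```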
**sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala**

```diff
@@ -25,11 +25,14 @@ import java.util.concurrent.TimeUnit
 import scala.collection.mutable
+import scala.util.control.NonFatal

+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.util.DateTimeConstants._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils.millisToMicros
 import org.apache.spark.sql.catalyst.util.IntervalStringStyles.{ANSI_STYLE, HIVE_STYLE, IntervalStyle}
-import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.types.{DayTimeIntervalType => DT, Decimal, YearMonthIntervalType => YM}
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -433,6 +436,24 @@ object IntervalUtils {
     }
   }

+  /**
+   * Parse all kinds of interval literals including unit-to-unit form and unit list form
+   */
+  def fromIntervalString(input: String): CalendarInterval = try {
+    if (input.toLowerCase(Locale.ROOT).trim.startsWith("interval")) {
+      CatalystSqlParser.parseExpression(input) match {
+        case Literal(months: Int, _: YearMonthIntervalType) => new CalendarInterval(months, 0, 0)
+        case Literal(micros: Long, _: DayTimeIntervalType) => new CalendarInterval(0, 0, micros)
+        case Literal(cal: CalendarInterval, CalendarIntervalType) => cal
+      }
+    } else {
+      stringToInterval(UTF8String.fromString(input))
+    }
+  } catch {
+    case NonFatal(e) =>
+      throw QueryCompilationErrors.cannotParseIntervalError(input, e)
+  }
+
   private val dayTimePatternLegacy =
     "^([+|-])?((\\d+) )?((\\d+):)?(\\d+):(\\d+)(\\.(\\d+))?$".r
```
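Given the method added above, both parsing paths converge on a `CalendarInterval`. A minimal sketch of the behavior it implements, with results written as comments:

```scala
import org.apache.spark.sql.catalyst.util.IntervalUtils

// Multi-unit form: handled by the existing stringToInterval path.
IntervalUtils.fromIntervalString("10 seconds")            // 10-second CalendarInterval

// ANSI form: the input starts with "interval", so it is routed through
// CatalystSqlParser and the resulting interval literal is converted.
IntervalUtils.fromIntervalString("INTERVAL '10' SECOND")  // same 10-second interval
IntervalUtils.fromIntervalString("interval '1-2' year to month")  // 14 months

// Unparseable input surfaces as an AnalysisException via cannotParseIntervalError:
// IntervalUtils.fromIntervalString("2 apples")  // throws AnalysisException
```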
**sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala**

```diff
@@ -2261,8 +2261,8 @@ private[spark] object QueryCompilationErrors {
       s"""Cannot resolve column name "$colName" among (${fieldsStr})${extraMsg}""")
   }

-  def cannotParseTimeDelayError(delayThreshold: String, e: Throwable): Throwable = {
-    new AnalysisException(s"Unable to parse time delay '$delayThreshold'", cause = Some(e))
+  def cannotParseIntervalError(delayThreshold: String, e: Throwable): Throwable = {
+    new AnalysisException(s"Unable to parse '$delayThreshold'", cause = Some(e))
   }

   def invalidJoinTypeInJoinWithError(joinType: JoinType): Throwable = {
```
**sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala**

```diff
@@ -17,10 +17,14 @@

 package org.apache.spark.sql.catalyst.expressions

+import scala.reflect.ClassTag
+
 import org.scalatest.PrivateMethodTester

 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.util.DateTimeConstants._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampNTZType, TimestampType}

 class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with PrivateMethodTester {
@@ -31,16 +35,16 @@ class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with PrivateMethodTester {
     }
   }

-  private def checkErrorMessage(msg: String, value: String): Unit = {
+  private def checkErrorMessage[E <: Exception : ClassTag](msg: String, value: String): Unit = {
     val validDuration = "10 second"
     val validTime = "5 second"
-    val e1 = intercept[IllegalArgumentException] {
+    val e1 = intercept[E] {
       TimeWindow(Literal(10L), value, validDuration, validTime).windowDuration
     }
-    val e2 = intercept[IllegalArgumentException] {
+    val e2 = intercept[E] {
       TimeWindow(Literal(10L), validDuration, value, validTime).slideDuration
     }
-    val e3 = intercept[IllegalArgumentException] {
+    val e3 = intercept[E] {
       TimeWindow(Literal(10L), validDuration, validDuration, value).startTime
     }
     Seq(e1, e2, e3).foreach { e =>
```
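Why the new `ClassTag` bound on `checkErrorMessage`: the helper now intercepts a caller-chosen exception type, and checking "was an `E` thrown?" needs `E`'s runtime class, which plain generics erase. A minimal standalone sketch of the pattern, independent of ScalaTest (the helper name here is illustrative only):

```scala
import scala.reflect.ClassTag

// Illustrative helper, not part of the patch: returns the thrown E or fails.
def expectException[E <: Exception : ClassTag](body: => Any): E = {
  val expectedClass = implicitly[ClassTag[E]].runtimeClass
  val thrown = try { body; None } catch { case e: Exception => Some(e) }
  thrown match {
    case Some(e) if expectedClass.isInstance(e) => e.asInstanceOf[E]
    case other => throw new AssertionError(s"Expected ${expectedClass.getName}, got $other")
  }
}

// Usage: expectException[IllegalArgumentException] { require(false, "boom") }
```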
TimeWindowSuite.scala (continued)

```diff
@@ -50,18 +54,18 @@ class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with PrivateMethodTester {

   test("blank intervals throw exception") {
     for (blank <- Seq(null, " ", "\n", "\t")) {
-      checkErrorMessage(
+      checkErrorMessage[AnalysisException](
         "The window duration, slide duration and start time cannot be null or blank.", blank)
     }
   }

   test("invalid intervals throw exception") {
-    checkErrorMessage(
+    checkErrorMessage[AnalysisException](
       "did not correspond to a valid interval string.", "2 apples")
   }

   test("intervals greater than a month throws exception") {
-    checkErrorMessage(
+    checkErrorMessage[IllegalArgumentException](
       "Intervals greater than or equal to a month is not supported (1 month).", "1 month")
   }
@@ -111,7 +115,7 @@ class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with PrivateMethodTester {
   }

   test("parse sql expression for duration in microseconds - invalid interval") {
-    intercept[IllegalArgumentException] {
+    intercept[AnalysisException] {
       TimeWindow.invokePrivate(parseExpression(Literal("2 apples")))
     }
   }
@@ -147,4 +151,46 @@ class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with PrivateMethodTester {
     assert(timestampNTZWindow.dataType == StructType(
       Seq(StructField("start", TimestampNTZType), StructField("end", TimestampNTZType))))
   }
+
+  Seq("true", "false").foreach { legacyIntervalEnabled =>
+    test("SPARK-36323: Support ANSI interval literals for TimeWindow " +
+      s"(${SQLConf.LEGACY_INTERVAL_ENABLED.key}=$legacyIntervalEnabled)") {
+      withSQLConf(SQLConf.LEGACY_INTERVAL_ENABLED.key -> legacyIntervalEnabled) {
+        Seq(
+          // Conventional form and some variants
+          (Seq("3 days", "Interval 3 day", "inTerval '3' day"), 3 * MICROS_PER_DAY),
+          (Seq(" 5 hours", "INTERVAL 5 hour", "interval '5' hour"), 5 * MICROS_PER_HOUR),
+          (Seq("\t8 minutes", "interval 8 minute", "interval '8' minute"), 8 * MICROS_PER_MINUTE),
+          (Seq(
+            "10 seconds", "interval 10 second", "interval '10' second"), 10 * MICROS_PER_SECOND),
+          (Seq(
+            "1 day 2 hours 3 minutes 4 seconds",
+            " interval 1 day 2 hours 3 minutes 4 seconds",
+            "\tinterval '1' day '2' hours '3' minutes '4' seconds",
+            "interval '1 2:3:4' day to second"),
+            MICROS_PER_DAY + 2 * MICROS_PER_HOUR + 3 * MICROS_PER_MINUTE + 4 * MICROS_PER_SECOND)
+        ).foreach { case (intervalVariants, expectedMs) =>
+          intervalVariants.foreach { case interval =>
+            val timeWindow = TimeWindow(Literal(10L, TimestampType), interval, interval, interval)
+            val expected =
+              TimeWindow(Literal(10L, TimestampType), expectedMs, expectedMs, expectedMs)
+            assert(timeWindow === expected)
+          }
+        }
+
+        // year-month interval literals are not supported for TimeWindow.
+        Seq(
+          "1 years", "interval 1 year", "interval '1' year",
+          "1 months", "interval 1 month", "interval '1' month",
+          " 1 year 2 months",
+          "interval 1 year 2 month",
+          "interval '1' year '2' month",
+          "\tinterval '1-2' year to month").foreach { interval =>
+          intercept[IllegalArgumentException] {
+            TimeWindow(Literal(10L, TimestampType), interval, interval, interval)
+          }
+        }
+      }
+    }
+  }
 }
```
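As a sanity check on the combined case in the new test, the expected value for `interval '1 2:3:4' day to second` works out as follows (the constant values are the standard ones from `DateTimeConstants`):

```scala
val MICROS_PER_SECOND = 1000000L
val MICROS_PER_MINUTE = 60L * MICROS_PER_SECOND  // 60,000,000
val MICROS_PER_HOUR = 60L * MICROS_PER_MINUTE    // 3,600,000,000
val MICROS_PER_DAY = 24L * MICROS_PER_HOUR       // 86,400,000,000

// 1 day 2 hours 3 minutes 4 seconds, in microseconds:
val expected =
  MICROS_PER_DAY + 2 * MICROS_PER_HOUR + 3 * MICROS_PER_MINUTE + 4 * MICROS_PER_SECOND
// expected == 93784000000L (93,784 seconds)
```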
**sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala**

```diff
@@ -18,7 +18,6 @@
 package org.apache.spark.sql

 import java.io.{ByteArrayOutputStream, CharArrayWriter, DataOutputStream}
-import java.util.Locale

 import scala.annotation.varargs
 import scala.collection.JavaConverters._
@@ -44,7 +43,7 @@ import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions}
 import org.apache.spark.sql.catalyst.optimizer.CombineUnions
-import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils}
+import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection}
@@ -65,7 +64,6 @@ import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.SchemaUtils
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.unsafe.array.ByteArrayMethods
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 import org.apache.spark.util.Utils

 private[sql] object Dataset {
@@ -741,21 +739,7 @@ class Dataset[T] private[sql](
   // We only accept an existing column name, not a derived column here as a watermark that is
   // defined on a derived column cannot referenced elsewhere in the plan.
   def withWatermark(eventTime: String, delayThreshold: String): Dataset[T] = withTypedPlan {
-    val parsedDelay = try {
-      if (delayThreshold.toLowerCase(Locale.ROOT).trim.startsWith("interval")) {
-        CatalystSqlParser.parseExpression(delayThreshold) match {
-          case Literal(months: Int, _: YearMonthIntervalType) =>
-            new CalendarInterval(months, 0, 0)
-          case Literal(micros: Long, _: DayTimeIntervalType) =>
-            new CalendarInterval(0, 0, micros)
-        }
-      } else {
-        IntervalUtils.stringToInterval(UTF8String.fromString(delayThreshold))
-      }
-    } catch {
-      case NonFatal(e) =>
-        throw QueryCompilationErrors.cannotParseTimeDelayError(delayThreshold, e)
-    }
+    val parsedDelay = IntervalUtils.fromIntervalString(delayThreshold)
     require(!IntervalUtils.isNegative(parsedDelay),
       s"delay threshold ($delayThreshold) should not be negative.")
     EliminateEventTimeWatermark(
```
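With `withWatermark` now delegating to the shared `fromIntervalString`, watermark delays and window durations accept the same interval syntax. A minimal sketch (the streaming DataFrame `events` and its `eventTime` column are hypothetical):

```scala
// Both calls should parse to the same 10-minute delay threshold.
val conventional = events.withWatermark("eventTime", "10 minutes")
val ansi = events.withWatermark("eventTime", "INTERVAL '10' MINUTE")
```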