From 9d3d25bca4d087d4e509c4fc83cf960397aebe9b Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Thu, 11 Mar 2021 10:08:43 +0000 Subject: [PATCH] [SPARK-34677][SQL] Support the `+`/`-` operators over ANSI SQL intervals ### What changes were proposed in this pull request? Extend the `Add`, `Subtract` and `UnaryMinus` expression to support `DayTimeIntervalType` and `YearMonthIntervalType` added by #31614. Note: the expressions can throw the `overflow` exception independently from the SQL config `spark.sql.ansi.enabled`. In this way, the modified expressions always behave in the ANSI mode for the intervals. ### Why are the changes needed? To conform to the ANSI SQL standard which defines `-/+` over intervals: Screenshot 2021-03-09 at 21 59 22 ### Does this PR introduce _any_ user-facing change? Should not since new types have not been released yet. ### How was this patch tested? By running new tests in the test suites: ``` $ build/sbt "test:testOnly *ArithmeticExpressionSuite" $ build/sbt "test:testOnly *ColumnExpressionSuite" ``` Closes #31789 from MaxGekk/add-subtruct-intervals. Authored-by: Max Gekk Signed-off-by: Wenchen Fan --- .../sql/catalyst/expressions/arithmetic.scala | 21 ++++++++++ .../spark/sql/types/AbstractDataType.scala | 6 ++- .../ExpressionTypeCheckingSuite.scala | 4 +- .../ArithmeticExpressionSuite.scala | 40 +++++++++++++++++++ .../spark/sql/types/DataTypeTestUtils.scala | 5 ++- .../sql-tests/results/ansi/literals.sql.out | 18 ++++----- .../sql-tests/results/literals.sql.out | 18 ++++----- .../native/windowFrameCoercion.sql.out | 6 +-- .../spark/sql/ColumnExpressionSuite.scala | 13 ++++++ 9 files changed, 106 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 2ee68e62ab..59831dae21 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -83,12 +83,19 @@ case class UnaryMinus( val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$") val method = if (failOnError) "negateExact" else "negate" defineCodeGen(ctx, ev, c => s"$iu.$method($c)") + case DayTimeIntervalType | YearMonthIntervalType => + nullSafeCodeGen(ctx, ev, eval => { + val mathClass = classOf[Math].getName + s"${ev.value} = $mathClass.negateExact($eval);" + }) } protected override def nullSafeEval(input: Any): Any = dataType match { case CalendarIntervalType if failOnError => IntervalUtils.negateExact(input.asInstanceOf[CalendarInterval]) case CalendarIntervalType => IntervalUtils.negate(input.asInstanceOf[CalendarInterval]) + case DayTimeIntervalType => Math.negateExact(input.asInstanceOf[Long]) + case YearMonthIntervalType => Math.negateExact(input.asInstanceOf[Int]) case _ => numeric.negate(input) } @@ -185,6 +192,12 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { case CalendarIntervalType => val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$") defineCodeGen(ctx, ev, (eval1, eval2) => s"$iu.$calendarIntervalMethod($eval1, $eval2)") + case DayTimeIntervalType | YearMonthIntervalType => + assert(exactMathMethod.isDefined, + s"The expression '$nodeName' must override the exactMathMethod() method " + + "if it is supposed to operate over interval types.") + val mathClass = classOf[Math].getName + defineCodeGen(ctx, ev, (eval1, eval2) => s"$mathClass.${exactMathMethod.get}($eval1, $eval2)") // byte and short are casted into int when add, minus, times or divide case ByteType | ShortType => nullSafeCodeGen(ctx, ev, (eval1, eval2) => { @@ -267,6 +280,10 @@ case class Add( case CalendarIntervalType => IntervalUtils.add( input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval]) + case DayTimeIntervalType => + Math.addExact(input1.asInstanceOf[Long], input2.asInstanceOf[Long]) + case YearMonthIntervalType => + Math.addExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int]) case _ => numeric.plus(input1, input2) } @@ -306,6 +323,10 @@ case class Subtract( case CalendarIntervalType => IntervalUtils.subtract( input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval]) + case DayTimeIntervalType => + Math.subtractExact(input1.asInstanceOf[Long], input2.asInstanceOf[Long]) + case YearMonthIntervalType => + Math.subtractExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int]) case _ => numeric.minus(input1, input2) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index 21ac32adca..02c95b286a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -82,7 +82,11 @@ private[sql] object TypeCollection { * Types that include numeric types and interval type. They are only used in unary_minus, * unary_positive, add and subtract operations. */ - val NumericAndInterval = TypeCollection(NumericType, CalendarIntervalType) + val NumericAndInterval = TypeCollection( + NumericType, + CalendarIntervalType, + DayTimeIntervalType, + YearMonthIntervalType) def apply(types: AbstractDataType*): TypeCollection = new TypeCollection(types) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 46634c9314..ee560ea4be 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -78,9 +78,9 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertErrorForDifferingTypes(BitwiseXor(Symbol("intField"), Symbol("booleanField"))) assertError(Add(Symbol("booleanField"), Symbol("booleanField")), - "requires (numeric or interval) type") + "requires (numeric or interval or daytimeinterval or yearmonthinterval) type") assertError(Subtract(Symbol("booleanField"), Symbol("booleanField")), - "requires (numeric or interval) type") + "requires (numeric or interval or daytimeinterval or yearmonthinterval) type") assertError(Multiply(Symbol("booleanField"), Symbol("booleanField")), "requires numeric type") assertError(Divide(Symbol("booleanField"), Symbol("booleanField")), "requires (double or decimal) type") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 14dd04afeb..ca97418e0d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} +import java.time.{Duration, Period} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow @@ -576,4 +577,43 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper } } } + + test("SPARK-34677: exact add and subtract of day-time and year-month intervals") { + Seq(true, false).foreach { failOnError => + checkExceptionInExpression[ArithmeticException]( + UnaryMinus( + Literal.create(Period.ofMonths(Int.MinValue), YearMonthIntervalType), + failOnError), + "overflow") + checkExceptionInExpression[ArithmeticException]( + Subtract( + Literal.create(Period.ofMonths(Int.MinValue), YearMonthIntervalType), + Literal.create(Period.ofMonths(10), YearMonthIntervalType), + failOnError + ), + "overflow") + checkExceptionInExpression[ArithmeticException]( + Add( + Literal.create(Period.ofMonths(Int.MaxValue), YearMonthIntervalType), + Literal.create(Period.ofMonths(10), YearMonthIntervalType), + failOnError + ), + "overflow") + + checkExceptionInExpression[ArithmeticException]( + Subtract( + Literal.create(Duration.ofDays(-106751991), DayTimeIntervalType), + Literal.create(Duration.ofDays(10), DayTimeIntervalType), + failOnError + ), + "overflow") + checkExceptionInExpression[ArithmeticException]( + Add( + Literal.create(Duration.ofDays(106751991), DayTimeIntervalType), + Literal.create(Duration.ofDays(10), DayTimeIntervalType), + failOnError + ), + "overflow") + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala index 07552a510b..769de33528 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala @@ -52,7 +52,10 @@ object DataTypeTestUtils { /** * Instances of all [[NumericType]]s and [[CalendarIntervalType]] */ - val numericAndInterval: Set[DataType] = numericTypeWithoutDecimal + CalendarIntervalType + val numericAndInterval: Set[DataType] = numericTypeWithoutDecimal ++ Set( + CalendarIntervalType, + DayTimeIntervalType, + YearMonthIntervalType) /** * All the types that support ordering diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/literals.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/literals.sql.out index ea74bb7175..1c290a0f3d 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/literals.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/literals.sql.out @@ -436,7 +436,7 @@ select +date '1999-01-01' struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve '(+ DATE '1999-01-01')' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'DATE '1999-01-01'' is of date type.; line 1 pos 7 +cannot resolve '(+ DATE '1999-01-01')' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'DATE '1999-01-01'' is of date type.; line 1 pos 7 -- !query @@ -445,7 +445,7 @@ select +timestamp '1999-01-01' struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve '(+ TIMESTAMP '1999-01-01 00:00:00')' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'TIMESTAMP '1999-01-01 00:00:00'' is of timestamp type.; line 1 pos 7 +cannot resolve '(+ TIMESTAMP '1999-01-01 00:00:00')' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'TIMESTAMP '1999-01-01 00:00:00'' is of timestamp type.; line 1 pos 7 -- !query @@ -462,7 +462,7 @@ select +map(1, 2) struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve '(+ map(1, 2))' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'map(1, 2)' is of map type.; line 1 pos 7 +cannot resolve '(+ map(1, 2))' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'map(1, 2)' is of map type.; line 1 pos 7 -- !query @@ -471,7 +471,7 @@ select +array(1,2) struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve '(+ array(1, 2))' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'array(1, 2)' is of array type.; line 1 pos 7 +cannot resolve '(+ array(1, 2))' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'array(1, 2)' is of array type.; line 1 pos 7 -- !query @@ -480,7 +480,7 @@ select +named_struct('a', 1, 'b', 'spark') struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve '(+ named_struct('a', 1, 'b', 'spark'))' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'named_struct('a', 1, 'b', 'spark')' is of struct type.; line 1 pos 7 +cannot resolve '(+ named_struct('a', 1, 'b', 'spark'))' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'named_struct('a', 1, 'b', 'spark')' is of struct type.; line 1 pos 7 -- !query @@ -489,7 +489,7 @@ select +X'1' struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve '(+ X'01')' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'X'01'' is of binary type.; line 1 pos 7 +cannot resolve '(+ X'01')' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'X'01'' is of binary type.; line 1 pos 7 -- !query @@ -498,7 +498,7 @@ select -date '1999-01-01' struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve '(- DATE '1999-01-01')' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'DATE '1999-01-01'' is of date type.; line 1 pos 7 +cannot resolve '(- DATE '1999-01-01')' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'DATE '1999-01-01'' is of date type.; line 1 pos 7 -- !query @@ -507,7 +507,7 @@ select -timestamp '1999-01-01' struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve '(- TIMESTAMP '1999-01-01 00:00:00')' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'TIMESTAMP '1999-01-01 00:00:00'' is of timestamp type.; line 1 pos 7 +cannot resolve '(- TIMESTAMP '1999-01-01 00:00:00')' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'TIMESTAMP '1999-01-01 00:00:00'' is of timestamp type.; line 1 pos 7 -- !query @@ -516,4 +516,4 @@ select -x'2379ACFe' struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve '(- X'2379ACFE')' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'X'2379ACFE'' is of binary type.; line 1 pos 7 +cannot resolve '(- X'2379ACFE')' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'X'2379ACFE'' is of binary type.; line 1 pos 7 diff --git a/sql/core/src/test/resources/sql-tests/results/literals.sql.out b/sql/core/src/test/resources/sql-tests/results/literals.sql.out index ea74bb7175..1c290a0f3d 100644 --- a/sql/core/src/test/resources/sql-tests/results/literals.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/literals.sql.out @@ -436,7 +436,7 @@ select +date '1999-01-01' struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve '(+ DATE '1999-01-01')' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'DATE '1999-01-01'' is of date type.; line 1 pos 7 +cannot resolve '(+ DATE '1999-01-01')' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'DATE '1999-01-01'' is of date type.; line 1 pos 7 -- !query @@ -445,7 +445,7 @@ select +timestamp '1999-01-01' struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve '(+ TIMESTAMP '1999-01-01 00:00:00')' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'TIMESTAMP '1999-01-01 00:00:00'' is of timestamp type.; line 1 pos 7 +cannot resolve '(+ TIMESTAMP '1999-01-01 00:00:00')' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'TIMESTAMP '1999-01-01 00:00:00'' is of timestamp type.; line 1 pos 7 -- !query @@ -462,7 +462,7 @@ select +map(1, 2) struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve '(+ map(1, 2))' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'map(1, 2)' is of map type.; line 1 pos 7 +cannot resolve '(+ map(1, 2))' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'map(1, 2)' is of map type.; line 1 pos 7 -- !query @@ -471,7 +471,7 @@ select +array(1,2) struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve '(+ array(1, 2))' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'array(1, 2)' is of array type.; line 1 pos 7 +cannot resolve '(+ array(1, 2))' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'array(1, 2)' is of array type.; line 1 pos 7 -- !query @@ -480,7 +480,7 @@ select +named_struct('a', 1, 'b', 'spark') struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve '(+ named_struct('a', 1, 'b', 'spark'))' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'named_struct('a', 1, 'b', 'spark')' is of struct type.; line 1 pos 7 +cannot resolve '(+ named_struct('a', 1, 'b', 'spark'))' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'named_struct('a', 1, 'b', 'spark')' is of struct type.; line 1 pos 7 -- !query @@ -489,7 +489,7 @@ select +X'1' struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve '(+ X'01')' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'X'01'' is of binary type.; line 1 pos 7 +cannot resolve '(+ X'01')' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'X'01'' is of binary type.; line 1 pos 7 -- !query @@ -498,7 +498,7 @@ select -date '1999-01-01' struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve '(- DATE '1999-01-01')' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'DATE '1999-01-01'' is of date type.; line 1 pos 7 +cannot resolve '(- DATE '1999-01-01')' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'DATE '1999-01-01'' is of date type.; line 1 pos 7 -- !query @@ -507,7 +507,7 @@ select -timestamp '1999-01-01' struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve '(- TIMESTAMP '1999-01-01 00:00:00')' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'TIMESTAMP '1999-01-01 00:00:00'' is of timestamp type.; line 1 pos 7 +cannot resolve '(- TIMESTAMP '1999-01-01 00:00:00')' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'TIMESTAMP '1999-01-01 00:00:00'' is of timestamp type.; line 1 pos 7 -- !query @@ -516,4 +516,4 @@ select -x'2379ACFe' struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve '(- X'2379ACFE')' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'X'2379ACFE'' is of binary type.; line 1 pos 7 +cannot resolve '(- X'2379ACFE')' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'X'2379ACFE'' is of binary type.; line 1 pos 7 diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/windowFrameCoercion.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/windowFrameCoercion.sql.out index 71ef82d48b..1520d807a1 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/windowFrameCoercion.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/windowFrameCoercion.sql.out @@ -168,7 +168,7 @@ SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as string) DESC RANGE BETWE struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'RANGE BETWEEN CURRENT ROW AND CAST(1 AS STRING) FOLLOWING' due to data type mismatch: The data type of the upper bound 'string' does not match the expected data type '(numeric or interval)'.; line 1 pos 21 +cannot resolve 'RANGE BETWEEN CURRENT ROW AND CAST(1 AS STRING) FOLLOWING' due to data type mismatch: The data type of the upper bound 'string' does not match the expected data type '(numeric or interval or daytimeinterval or yearmonthinterval)'.; line 1 pos 21 -- !query @@ -177,7 +177,7 @@ SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast('1' as binary) DESC RANGE BET struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'RANGE BETWEEN CURRENT ROW AND CAST(1 AS BINARY) FOLLOWING' due to data type mismatch: The data type of the upper bound 'binary' does not match the expected data type '(numeric or interval)'.; line 1 pos 21 +cannot resolve 'RANGE BETWEEN CURRENT ROW AND CAST(1 AS BINARY) FOLLOWING' due to data type mismatch: The data type of the upper bound 'binary' does not match the expected data type '(numeric or interval or daytimeinterval or yearmonthinterval)'.; line 1 pos 21 -- !query @@ -186,7 +186,7 @@ SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as boolean) DESC RANGE BETW struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'RANGE BETWEEN CURRENT ROW AND CAST(1 AS BOOLEAN) FOLLOWING' due to data type mismatch: The data type of the upper bound 'boolean' does not match the expected data type '(numeric or interval)'.; line 1 pos 21 +cannot resolve 'RANGE BETWEEN CURRENT ROW AND CAST(1 AS BOOLEAN) FOLLOWING' due to data type mismatch: The data type of the upper bound 'boolean' does not match the expected data type '(numeric or interval or daytimeinterval or yearmonthinterval)'.; line 1 pos 21 -- !query diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 4f64de4ae8..fac510502c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} +import java.time.{Duration, Period} import java.util.Locale import org.apache.hadoop.io.{LongWritable, Text} @@ -2375,4 +2376,16 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { assert(e2.getCause.isInstanceOf[RuntimeException]) assert(e2.getCause.getMessage == "hello") } + + test("SPARK-34677: negate/add/subtract year-month and day-time intervals") { + import testImplicits._ + val df = Seq((Period.ofMonths(10), Duration.ofDays(10), Period.ofMonths(1), Duration.ofDays(1))) + .toDF("year-month-A", "day-time-A", "year-month-B", "day-time-B") + val negatedDF = df.select(-$"year-month-A", -$"day-time-A") + checkAnswer(negatedDF, Row(Period.ofMonths(-10), Duration.ofDays(-10))) + val addDF = df.select($"year-month-A" + $"year-month-B", $"day-time-A" + $"day-time-B") + checkAnswer(addDF, Row(Period.ofMonths(11), Duration.ofDays(11))) + val subDF = df.select($"year-month-A" - $"year-month-B", $"day-time-A" - $"day-time-B") + checkAnswer(subDF, Row(Period.ofMonths(9), Duration.ofDays(9))) + } }