diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AnsiCastSuiteBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AnsiCastSuiteBase.scala index 20f70b9b50..f943440323 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AnsiCastSuiteBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AnsiCastSuiteBase.scala @@ -25,6 +25,14 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String +/** + * Test suite base for + * 1. [[Cast]] with ANSI mode enabled + * 2. [[AnsiCast]] + * 3. [[TryCast]] + * Note: for new test cases that work for [[Cast]], [[AnsiCast]] and [[TryCast]], please add them + * in `CastSuiteBase` instead of this file to ensure the test coverage. + */ abstract class AnsiCastSuiteBase extends CastSuiteBase { private def testIntMaxAndMin(dt: DataType): Unit = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 96596601b2..6e08500286 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -18,8 +18,6 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} -import java.time.{Duration, Period} -import java.time.temporal.ChronoUnit import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow @@ -30,12 +28,12 @@ import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._ import org.apache.spark.sql.catalyst.util.DateTimeUtils._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.sql.types.DayTimeIntervalType._ -import org.apache.spark.sql.types.YearMonthIntervalType._ import org.apache.spark.unsafe.types.UTF8String /** - * Test suite for data type casting expression [[Cast]]. + * Test suite for data type casting expression [[Cast]] with ANSI mode disabled. + * Note: for new test cases that work for [[Cast]], [[AnsiCast]] and [[TryCast]], please add them + * in `CastSuiteBase` instead of this file to ensure the test coverage. */ class CastSuite extends CastSuiteBase { @@ -568,92 +566,6 @@ class CastSuite extends CastSuiteBase { } } - test("SPARK-35112: Cast string to day-time interval") { - checkEvaluation(cast(Literal.create("0 0:0:0"), DayTimeIntervalType()), 0L) - checkEvaluation(cast(Literal.create(" interval '0 0:0:0' Day TO second "), - DayTimeIntervalType()), 0L) - checkEvaluation(cast(Literal.create("INTERVAL '1 2:03:04' DAY TO SECOND"), - DayTimeIntervalType()), 93784000000L) - checkEvaluation(cast(Literal.create("INTERVAL '1 03:04:00' DAY TO SECOND"), - DayTimeIntervalType()), 97440000000L) - checkEvaluation(cast(Literal.create("INTERVAL '1 03:04:00.0000' DAY TO SECOND"), - DayTimeIntervalType()), 97440000000L) - checkEvaluation(cast(Literal.create("1 2:03:04"), DayTimeIntervalType()), 93784000000L) - checkEvaluation(cast(Literal.create("INTERVAL '-10 2:03:04' DAY TO SECOND"), - DayTimeIntervalType()), -871384000000L) - checkEvaluation(cast(Literal.create("-10 2:03:04"), DayTimeIntervalType()), -871384000000L) - checkEvaluation(cast(Literal.create("-106751991 04:00:54.775808"), DayTimeIntervalType()), - Long.MinValue) - checkEvaluation(cast(Literal.create("106751991 04:00:54.775807"), DayTimeIntervalType()), - Long.MaxValue) - - Seq("-106751991 04:00:54.775808", "106751991 04:00:54.775807").foreach { interval => - val ansiInterval = s"INTERVAL '$interval' DAY TO SECOND" - checkEvaluation( - cast(cast(Literal.create(interval), DayTimeIntervalType()), StringType), ansiInterval) - checkEvaluation(cast(cast(Literal.create(ansiInterval), - DayTimeIntervalType()), StringType), ansiInterval) - } - - Seq("INTERVAL '-106751991 04:00:54.775809' YEAR TO MONTH", - "INTERVAL '106751991 04:00:54.775808' YEAR TO MONTH").foreach { interval => - val e = intercept[IllegalArgumentException] { - cast(Literal.create(interval), DayTimeIntervalType()).eval() - }.getMessage - assert(e.contains("Interval string must match day-time format of")) - } - - Seq(Byte.MaxValue, Short.MaxValue, Int.MaxValue, Long.MaxValue, Long.MinValue + 1, - Long.MinValue).foreach { duration => - val interval = Literal.create( - Duration.of(duration, ChronoUnit.MICROS), - DayTimeIntervalType()) - checkEvaluation(cast(cast(interval, StringType), DayTimeIntervalType()), duration) - } - } - - test("SPARK-35111: Cast string to year-month interval") { - checkEvaluation(cast(Literal.create("INTERVAL '1-0' YEAR TO MONTH"), - YearMonthIntervalType()), 12) - checkEvaluation(cast(Literal.create("INTERVAL '-1-0' YEAR TO MONTH"), - YearMonthIntervalType()), -12) - checkEvaluation(cast(Literal.create("INTERVAL -'-1-0' YEAR TO MONTH"), - YearMonthIntervalType()), 12) - checkEvaluation(cast(Literal.create("INTERVAL +'-1-0' YEAR TO MONTH"), - YearMonthIntervalType()), -12) - checkEvaluation(cast(Literal.create("INTERVAL +'+1-0' YEAR TO MONTH"), - YearMonthIntervalType()), 12) - checkEvaluation(cast(Literal.create("INTERVAL +'1-0' YEAR TO MONTH"), - YearMonthIntervalType()), 12) - checkEvaluation(cast(Literal.create(" interval +'1-0' YEAR TO MONTH "), - YearMonthIntervalType()), 12) - checkEvaluation(cast(Literal.create(" -1-0 "), YearMonthIntervalType()), -12) - checkEvaluation(cast(Literal.create("-1-0"), YearMonthIntervalType()), -12) - checkEvaluation(cast(Literal.create(null, StringType), YearMonthIntervalType()), null) - - Seq("0-0", "10-1", "-178956970-7", "178956970-7", "-178956970-8").foreach { interval => - val ansiInterval = s"INTERVAL '$interval' YEAR TO MONTH" - checkEvaluation( - cast(cast(Literal.create(interval), YearMonthIntervalType()), StringType), ansiInterval) - checkEvaluation(cast(cast(Literal.create(ansiInterval), - YearMonthIntervalType()), StringType), ansiInterval) - } - - Seq("INTERVAL '-178956970-9' YEAR TO MONTH", "INTERVAL '178956970-8' YEAR TO MONTH") - .foreach { interval => - val e = intercept[IllegalArgumentException] { - cast(Literal.create(interval), YearMonthIntervalType()).eval() - }.getMessage - assert(e.contains("Error parsing interval year-month string: integer overflow")) - } - - Seq(Byte.MaxValue, Short.MaxValue, Int.MaxValue, Int.MinValue + 1, Int.MinValue) - .foreach { period => - val interval = Literal.create(Period.ofMonths(period), YearMonthIntervalType()) - checkEvaluation(cast(cast(interval, StringType), YearMonthIntervalType()), period) - } - } - test("SPARK-35720: cast invalid string input to timestamp without time zone") { Seq("00:00:00", "a", @@ -664,34 +576,4 @@ class CastSuite extends CastSuiteBase { checkEvaluation(cast(invalidInput, TimestampWithoutTZType), null) } } - - test("SPARK-35820: Support cast DayTimeIntervalType in different fields") { - val duration = Duration.ofSeconds(12345678L, 123456789) - Seq((DayTimeIntervalType(DAY, DAY), 12268800000000L, -12268800000000L), - (DayTimeIntervalType(DAY, HOUR), 12344400000000L, -12344400000000L), - (DayTimeIntervalType(DAY, MINUTE), 12345660000000L, -12345660000000L), - (DayTimeIntervalType(DAY, SECOND), 12345678123456L, -12345678123457L), - (DayTimeIntervalType(HOUR, HOUR), 12344400000000L, -12344400000000L), - (DayTimeIntervalType(HOUR, MINUTE), 12345660000000L, -12345660000000L), - (DayTimeIntervalType(HOUR, SECOND), 12345678123456L, -12345678123457L), - (DayTimeIntervalType(MINUTE, MINUTE), 12345660000000L, -12345660000000L), - (DayTimeIntervalType(MINUTE, SECOND), 12345678123456L, -12345678123457L), - (DayTimeIntervalType(SECOND, SECOND), 12345678123456L, -12345678123457L)) - .foreach { case (dt, positive, negative) => - checkEvaluation( - cast(Literal.create(duration, DayTimeIntervalType(DAY, SECOND)), dt), positive) - checkEvaluation( - cast(Literal.create(duration.negated(), DayTimeIntervalType(DAY, SECOND)), dt), negative) - } - } - - test("SPARK-35819: Support cast YearMonthIntervalType in different fields") { - val ym = cast(Literal.create("1-1"), YearMonthIntervalType(YEAR, MONTH)) - Seq(YearMonthIntervalType(YEAR) -> 12, - YearMonthIntervalType(YEAR, MONTH) -> 13, - YearMonthIntervalType(MONTH) -> 13) - .foreach { case (dt, value) => - checkEvaluation(cast(ym, dt), value) - } - } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala index 2d484c73c0..0c74ca4fea 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala @@ -37,8 +37,13 @@ import org.apache.spark.sql.catalyst.util.IntervalUtils.microsToDuration import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.types.DataTypeTestUtils.{dayTimeIntervalTypes, yearMonthIntervalTypes} +import org.apache.spark.sql.types.DayTimeIntervalType.{DAY, HOUR, MINUTE, SECOND} +import org.apache.spark.sql.types.YearMonthIntervalType.{MONTH, YEAR} import org.apache.spark.unsafe.types.UTF8String +/** + * Common test suite for [[Cast]], [[AnsiCast]] and [[TryCast]] expressions. + */ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { protected def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): CastBase @@ -68,7 +73,9 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { } } - protected def isAlwaysNullable: Boolean = false + // Whether the test suite is for TryCast. If yes, there is no exceptions and the result is + // always nullable. + protected def isTryCast: Boolean = false protected def setConfigurationHint: String = "" @@ -268,8 +275,8 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { } test("cast from string") { - assert(cast("abcdef", StringType).nullable === isAlwaysNullable) - assert(cast("abcdef", BinaryType).nullable === isAlwaysNullable) + assert(cast("abcdef", StringType).nullable === isTryCast) + assert(cast("abcdef", BinaryType).nullable === isTryCast) assert(cast("abcdef", BooleanType).nullable) assert(cast("abcdef", TimestampType).nullable) assert(cast("abcdef", LongType).nullable) @@ -949,4 +956,124 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast("2021-06-17", TimestampWithoutTZType), LocalDateTime.of(2021, 6, 17, 0, 0)) } + + test("SPARK-35112: Cast string to day-time interval") { + checkEvaluation(cast(Literal.create("0 0:0:0"), DayTimeIntervalType()), 0L) + checkEvaluation(cast(Literal.create(" interval '0 0:0:0' Day TO second "), + DayTimeIntervalType()), 0L) + checkEvaluation(cast(Literal.create("INTERVAL '1 2:03:04' DAY TO SECOND"), + DayTimeIntervalType()), 93784000000L) + checkEvaluation(cast(Literal.create("INTERVAL '1 03:04:00' DAY TO SECOND"), + DayTimeIntervalType()), 97440000000L) + checkEvaluation(cast(Literal.create("INTERVAL '1 03:04:00.0000' DAY TO SECOND"), + DayTimeIntervalType()), 97440000000L) + checkEvaluation(cast(Literal.create("1 2:03:04"), DayTimeIntervalType()), 93784000000L) + checkEvaluation(cast(Literal.create("INTERVAL '-10 2:03:04' DAY TO SECOND"), + DayTimeIntervalType()), -871384000000L) + checkEvaluation(cast(Literal.create("-10 2:03:04"), DayTimeIntervalType()), -871384000000L) + checkEvaluation(cast(Literal.create("-106751991 04:00:54.775808"), DayTimeIntervalType()), + Long.MinValue) + checkEvaluation(cast(Literal.create("106751991 04:00:54.775807"), DayTimeIntervalType()), + Long.MaxValue) + + Seq("-106751991 04:00:54.775808", "106751991 04:00:54.775807").foreach { interval => + val ansiInterval = s"INTERVAL '$interval' DAY TO SECOND" + checkEvaluation( + cast(cast(Literal.create(interval), DayTimeIntervalType()), StringType), ansiInterval) + checkEvaluation(cast(cast(Literal.create(ansiInterval), + DayTimeIntervalType()), StringType), ansiInterval) + } + + if (!isTryCast) { + Seq("INTERVAL '-106751991 04:00:54.775809' YEAR TO MONTH", + "INTERVAL '106751991 04:00:54.775808' YEAR TO MONTH").foreach { interval => + val e = intercept[IllegalArgumentException] { + cast(Literal.create(interval), DayTimeIntervalType()).eval() + }.getMessage + assert(e.contains("Interval string must match day-time format of")) + } + } + + Seq(Byte.MaxValue, Short.MaxValue, Int.MaxValue, Long.MaxValue, Long.MinValue + 1, + Long.MinValue).foreach { duration => + val interval = Literal.create( + Duration.of(duration, ChronoUnit.MICROS), + DayTimeIntervalType()) + checkEvaluation(cast(cast(interval, StringType), DayTimeIntervalType()), duration) + } + } + + test("SPARK-35111: Cast string to year-month interval") { + checkEvaluation(cast(Literal.create("INTERVAL '1-0' YEAR TO MONTH"), + YearMonthIntervalType()), 12) + checkEvaluation(cast(Literal.create("INTERVAL '-1-0' YEAR TO MONTH"), + YearMonthIntervalType()), -12) + checkEvaluation(cast(Literal.create("INTERVAL -'-1-0' YEAR TO MONTH"), + YearMonthIntervalType()), 12) + checkEvaluation(cast(Literal.create("INTERVAL +'-1-0' YEAR TO MONTH"), + YearMonthIntervalType()), -12) + checkEvaluation(cast(Literal.create("INTERVAL +'+1-0' YEAR TO MONTH"), + YearMonthIntervalType()), 12) + checkEvaluation(cast(Literal.create("INTERVAL +'1-0' YEAR TO MONTH"), + YearMonthIntervalType()), 12) + checkEvaluation(cast(Literal.create(" interval +'1-0' YEAR TO MONTH "), + YearMonthIntervalType()), 12) + checkEvaluation(cast(Literal.create(" -1-0 "), YearMonthIntervalType()), -12) + checkEvaluation(cast(Literal.create("-1-0"), YearMonthIntervalType()), -12) + checkEvaluation(cast(Literal.create(null, StringType), YearMonthIntervalType()), null) + + Seq("0-0", "10-1", "-178956970-7", "178956970-7", "-178956970-8").foreach { interval => + val ansiInterval = s"INTERVAL '$interval' YEAR TO MONTH" + checkEvaluation( + cast(cast(Literal.create(interval), YearMonthIntervalType()), StringType), ansiInterval) + checkEvaluation(cast(cast(Literal.create(ansiInterval), + YearMonthIntervalType()), StringType), ansiInterval) + } + + if (!isTryCast) { + Seq("INTERVAL '-178956970-9' YEAR TO MONTH", "INTERVAL '178956970-8' YEAR TO MONTH") + .foreach { interval => + val e = intercept[IllegalArgumentException] { + cast(Literal.create(interval), YearMonthIntervalType()).eval() + }.getMessage + assert(e.contains("Error parsing interval year-month string: integer overflow")) + } + } + + Seq(Byte.MaxValue, Short.MaxValue, Int.MaxValue, Int.MinValue + 1, Int.MinValue) + .foreach { period => + val interval = Literal.create(Period.ofMonths(period), YearMonthIntervalType()) + checkEvaluation(cast(cast(interval, StringType), YearMonthIntervalType()), period) + } + } + + test("SPARK-35820: Support cast DayTimeIntervalType in different fields") { + val duration = Duration.ofSeconds(12345678L, 123456789) + Seq((DayTimeIntervalType(DAY, DAY), 12268800000000L, -12268800000000L), + (DayTimeIntervalType(DAY, HOUR), 12344400000000L, -12344400000000L), + (DayTimeIntervalType(DAY, MINUTE), 12345660000000L, -12345660000000L), + (DayTimeIntervalType(DAY, SECOND), 12345678123456L, -12345678123457L), + (DayTimeIntervalType(HOUR, HOUR), 12344400000000L, -12344400000000L), + (DayTimeIntervalType(HOUR, MINUTE), 12345660000000L, -12345660000000L), + (DayTimeIntervalType(HOUR, SECOND), 12345678123456L, -12345678123457L), + (DayTimeIntervalType(MINUTE, MINUTE), 12345660000000L, -12345660000000L), + (DayTimeIntervalType(MINUTE, SECOND), 12345678123456L, -12345678123457L), + (DayTimeIntervalType(SECOND, SECOND), 12345678123456L, -12345678123457L)) + .foreach { case (dt, positive, negative) => + checkEvaluation( + cast(Literal.create(duration, DayTimeIntervalType(DAY, SECOND)), dt), positive) + checkEvaluation( + cast(Literal.create(duration.negated(), DayTimeIntervalType(DAY, SECOND)), dt), negative) + } + } + + test("SPARK-35819: Support cast YearMonthIntervalType in different fields") { + val ym = cast(Literal.create("1-1"), YearMonthIntervalType(YEAR, MONTH)) + Seq(YearMonthIntervalType(YEAR) -> 12, + YearMonthIntervalType(YEAR, MONTH) -> 13, + YearMonthIntervalType(MONTH) -> 13) + .foreach { case (dt, value) => + checkEvaluation(cast(ym, dt), value) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala index 76ce96705c..1394ec8c8e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala @@ -30,7 +30,7 @@ class TryCastSuite extends AnsiCastSuiteBase { } } - override def isAlwaysNullable: Boolean = true + override def isTryCast: Boolean = true override protected def setConfigurationHint: String = ""