diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md index 41921265b7..fd7208615a 100644 --- a/docs/sql-ref-ansi-compliance.md +++ b/docs/sql-ref-ansi-compliance.md @@ -61,6 +61,27 @@ Spark SQL has three kinds of type conversions: explicit casting, type coercion, When `spark.sql.ansi.enabled` is set to `true`, explicit casting by `CAST` syntax throws a runtime exception for illegal cast patterns defined in the standard, e.g. casts from a string to an integer. On the other hand, `INSERT INTO` syntax throws an analysis exception when the ANSI mode enabled via `spark.sql.storeAssignmentPolicy=ANSI`. +The type conversion of Spark ANSI mode follows the syntax rules of section 6.13 "cast specification" in [ISO/IEC 9075-2:2011 Information technology — Database languages - SQL — Part 2: Foundation (SQL/Foundation)"](https://www.iso.org/standard/53682.html), except it specially allows the following + straightforward type conversions which are disallowed as per the ANSI standard: +* NumericType <=> BooleanType +* StringType <=> BinaryType + + The valid combinations of target data type and source data type in a `CAST` expression are given by the following table. +“Y” indicates that the combination is syntactically valid without restriction and “N” indicates that the combination is not valid. + +| From\To | NumericType | StringType | DateType | TimestampType | IntervalType | BooleanType | BinaryType | ArrayType | MapType | StructType | +|-----------|---------|--------|------|-----------|----------|---------|--------|-------|-----|--------| +| NumericType | Y | Y | N | N | N | Y | N | N | N | N | +| StringType | Y | Y | Y | Y | Y | Y | Y | N | N | N | +| DateType | N | Y | Y | Y | N | N | N | N | N | N | +| TimestampType | N | Y | Y | Y | N | N | N | N | N | N | +| IntervalType | N | Y | N | N | Y | N | N | N | N | N | +| BooleanType | Y | Y | N | N | N | Y | N | N | N | N | +| BinaryType | Y | N | N | N | N | N | Y | N | N | N | +| ArrayType | N | N | N | N | N | N | N | Y | N | N | +| MapType | N | N | N | N | N | N | N | N | Y | N | +| StructType | N | N | N | N | N | N | N | N | N | Y | + Currently, the ANSI mode affects explicit casting and assignment casting only. In future releases, the behaviour of type coercion might change along with the other two type conversion rules. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 4af12d61e8..1257cf6e78 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -25,6 +25,7 @@ import java.util.concurrent.TimeUnit._ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} +import org.apache.spark.sql.catalyst.expressions.Cast.{canCast, forceNullable, resolvableNullability} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util._ @@ -258,13 +259,18 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit def dataType: DataType + /** + * Returns true iff we can cast `from` type to `to` type. + */ + def canCast(from: DataType, to: DataType): Boolean + override def toString: String = { val ansi = if (ansiEnabled) "ansi_" else "" s"${ansi}cast($child as ${dataType.simpleString})" } override def checkInputDataTypes(): TypeCheckResult = { - if (Cast.canCast(child.dataType, dataType)) { + if (canCast(child.dataType, dataType)) { TypeCheckResult.TypeCheckSuccess } else { TypeCheckResult.TypeCheckFailure( @@ -1753,6 +1759,12 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String copy(timeZoneId = Option(timeZoneId)) override protected val ansiEnabled: Boolean = SQLConf.get.ansiEnabled + + override def canCast(from: DataType, to: DataType): Boolean = if (ansiEnabled) { + AnsiCast.canCast(from, to) + } else { + Cast.canCast(from, to) + } } /** @@ -1770,6 +1782,110 @@ case class AnsiCast(child: Expression, dataType: DataType, timeZoneId: Option[St copy(timeZoneId = Option(timeZoneId)) override protected val ansiEnabled: Boolean = true + + override def canCast(from: DataType, to: DataType): Boolean = AnsiCast.canCast(from, to) +} + +object AnsiCast { + /** + * As per section 6.13 "cast specification" in "Information technology — Database languages " + + * "- SQL — Part 2: Foundation (SQL/Foundation)": + * If the is a , then the valid combinations of TD and SD + * in a are given by the following table. “Y” indicates that the + * combination is syntactically valid without restriction; “M” indicates that the combination + * is valid subject to other Syntax Rules in this Sub- clause being satisfied; and “N” indicates + * that the combination is not valid: + * SD TD + * EN AN C D T TS YM DT BO UDT B RT CT RW + * EN Y Y Y N N N M M N M N M N N + * AN Y Y Y N N N N N N M N M N N + * C Y Y Y Y Y Y Y Y Y M N M N N + * D N N Y Y N Y N N N M N M N N + * T N N Y N Y Y N N N M N M N N + * TS N N Y Y Y Y N N N M N M N N + * YM M N Y N N N Y N N M N M N N + * DT M N Y N N N N Y N M N M N N + * BO N N Y N N N N N Y M N M N N + * UDT M M M M M M M M M M M M M N + * B N N N N N N N N N M Y M N N + * RT M M M M M M M M M M M M N N + * CT N N N N N N N N N M N N M N + * RW N N N N N N N N N N N N N M + * + * Where: + * EN = Exact Numeric + * AN = Approximate Numeric + * C = Character (Fixed- or Variable-Length, or Character Large Object) + * D = Date + * T = Time + * TS = Timestamp + * YM = Year-Month Interval + * DT = Day-Time Interval + * BO = Boolean + * UDT = User-Defined Type + * B = Binary (Fixed- or Variable-Length or Binary Large Object) + * RT = Reference type + * CT = Collection type + * RW = Row type + * + * Spark's ANSI mode follows the syntax rules, except it specially allow the following + * straightforward type conversions which are disallowed as per the SQL standard: + * - Numeric <=> Boolean + * - String <=> Binary + */ + def canCast(from: DataType, to: DataType): Boolean = (from, to) match { + case (fromType, toType) if fromType == toType => true + + case (NullType, _) => true + + case (StringType, _: BinaryType) => true + + case (StringType, BooleanType) => true + case (_: NumericType, BooleanType) => true + + case (StringType, TimestampType) => true + case (DateType, TimestampType) => true + + case (StringType, _: CalendarIntervalType) => true + + case (StringType, DateType) => true + case (TimestampType, DateType) => true + + case (_: NumericType, _: NumericType) => true + case (StringType, _: NumericType) => true + case (BooleanType, _: NumericType) => true + + case (_: NumericType, StringType) => true + case (_: DateType, StringType) => true + case (_: TimestampType, StringType) => true + case (_: CalendarIntervalType, StringType) => true + case (BooleanType, StringType) => true + case (BinaryType, StringType) => true + + case (ArrayType(fromType, fn), ArrayType(toType, tn)) => + canCast(fromType, toType) && + resolvableNullability(fn || forceNullable(fromType, toType), tn) + + case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) => + canCast(fromKey, toKey) && + (!forceNullable(fromKey, toKey)) && + canCast(fromValue, toValue) && + resolvableNullability(fn || forceNullable(fromValue, toValue), tn) + + case (StructType(fromFields), StructType(toFields)) => + fromFields.length == toFields.length && + fromFields.zip(toFields).forall { + case (fromField, toField) => + canCast(fromField.dataType, toField.dataType) && + resolvableNullability( + fromField.nullable || forceNullable(fromField.dataType, toField.dataType), + toField.nullable) + } + + case (udt1: UserDefinedType[_], udt2: UserDefinedType[_]) if udt2.acceptsType(udt1) => true + + case _ => false + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 61133e2db5..afb76d8a5a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -38,9 +38,6 @@ import org.apache.spark.unsafe.types.UTF8String abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { - // Whether it is required to set SQLConf.ANSI_ENABLED as true for testing numeric overflow. - protected def requiredAnsiEnabledForOverflowTestCases: Boolean - protected def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): CastBase // expected cannot be null @@ -55,8 +52,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { test("null cast") { import DataTypeTestUtils._ - // follow [[org.apache.spark.sql.catalyst.expressions.Cast.canCast]] logic - // to ensure we test every possible cast situation here atomicTypes.zip(atomicTypes).foreach { case (from, to) => checkNullCast(from, to) } @@ -65,14 +60,10 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { atomicTypes.foreach(dt => checkNullCast(dt, StringType)) checkNullCast(StringType, BinaryType) checkNullCast(StringType, BooleanType) - checkNullCast(DateType, BooleanType) - checkNullCast(TimestampType, BooleanType) numericTypes.foreach(dt => checkNullCast(dt, BooleanType)) checkNullCast(StringType, TimestampType) - checkNullCast(BooleanType, TimestampType) checkNullCast(DateType, TimestampType) - numericTypes.foreach(dt => checkNullCast(dt, TimestampType)) checkNullCast(StringType, DateType) checkNullCast(TimestampType, DateType) @@ -80,8 +71,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { checkNullCast(StringType, CalendarIntervalType) numericTypes.foreach(dt => checkNullCast(StringType, dt)) numericTypes.foreach(dt => checkNullCast(BooleanType, dt)) - numericTypes.foreach(dt => checkNullCast(DateType, dt)) - numericTypes.foreach(dt => checkNullCast(TimestampType, dt)) for (from <- numericTypes; to <- numericTypes) checkNullCast(from, to) } @@ -215,6 +204,39 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(cast(0, BooleanType), IntegerType), 0) } + test("cast from int") { + checkCast(0, false) + checkCast(1, true) + checkCast(-5, true) + checkCast(1, 1.toByte) + checkCast(1, 1.toShort) + checkCast(1, 1) + checkCast(1, 1.toLong) + checkCast(1, 1.0f) + checkCast(1, 1.0) + checkCast(123, "123") + + checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123)) + checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123)) + checkEvaluation(cast(1, LongType), 1.toLong) + } + + test("cast from long") { + checkCast(0L, false) + checkCast(1L, true) + checkCast(-5L, true) + checkCast(1L, 1.toByte) + checkCast(1L, 1.toShort) + checkCast(1L, 1) + checkCast(1L, 1.toLong) + checkCast(1L, 1.0f) + checkCast(1L, 1.0) + checkCast(123L, "123") + + checkEvaluation(cast(123L, DecimalType.USER_DEFAULT), Decimal(123)) + checkEvaluation(cast(123L, DecimalType(3, 0)), Decimal(123)) + } + test("cast from float") { checkCast(0.0f, false) checkCast(0.5f, true) @@ -237,8 +259,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { checkCast(1.5, 1.toLong) checkCast(1.5, 1.5f) checkCast(1.5, "1.5") - - checkEvaluation(cast(cast(1.toDouble, TimestampType), DoubleType), 1.toDouble) } test("cast from string") { @@ -305,18 +325,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { cast(cast("5", ByteType), ShortType), IntegerType), FloatType), DoubleType), LongType), 5.toLong) - checkEvaluation( - cast(cast(cast(cast(cast(cast("5", ByteType), TimestampType), - DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType), - 5.toShort) - checkEvaluation( - cast(cast(cast(cast(cast(cast("5", TimestampType, UTC_OPT), ByteType), - DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType), - null) - checkEvaluation(cast(cast(cast(cast(cast(cast("5", DecimalType.SYSTEM_DEFAULT), - ByteType), TimestampType), LongType), StringType), ShortType), - 5.toShort) - checkEvaluation(cast("23", DoubleType), 23d) checkEvaluation(cast("23", IntegerType), 23) checkEvaluation(cast("23", FloatType), 23f) @@ -350,58 +358,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { checkCast(Decimal(1.5), "1.5") } - test("cast from date") { - val d = Date.valueOf("1970-01-01") - checkEvaluation(cast(d, ShortType), null) - checkEvaluation(cast(d, IntegerType), null) - checkEvaluation(cast(d, LongType), null) - checkEvaluation(cast(d, FloatType), null) - checkEvaluation(cast(d, DoubleType), null) - checkEvaluation(cast(d, DecimalType.SYSTEM_DEFAULT), null) - checkEvaluation(cast(d, DecimalType(10, 2)), null) - checkEvaluation(cast(d, StringType), "1970-01-01") - - checkEvaluation( - cast(cast(d, TimestampType, UTC_OPT), StringType, UTC_OPT), - "1970-01-01 00:00:00") - } - - test("cast from timestamp") { - val millis = 15 * 1000 + 3 - val seconds = millis * 1000 + 3 - val ts = new Timestamp(millis) - val tss = new Timestamp(seconds) - checkEvaluation(cast(ts, ShortType), 15.toShort) - checkEvaluation(cast(ts, IntegerType), 15) - checkEvaluation(cast(ts, LongType), 15.toLong) - checkEvaluation(cast(ts, FloatType), 15.003f) - checkEvaluation(cast(ts, DoubleType), 15.003) - - checkEvaluation(cast(cast(tss, ShortType), TimestampType), - fromJavaTimestamp(ts) * MILLIS_PER_SECOND) - checkEvaluation(cast(cast(tss, IntegerType), TimestampType), - fromJavaTimestamp(ts) * MILLIS_PER_SECOND) - checkEvaluation(cast(cast(tss, LongType), TimestampType), - fromJavaTimestamp(ts) * MILLIS_PER_SECOND) - checkEvaluation( - cast(cast(millis.toFloat / MILLIS_PER_SECOND, TimestampType), FloatType), - millis.toFloat / MILLIS_PER_SECOND) - checkEvaluation( - cast(cast(millis.toDouble / MILLIS_PER_SECOND, TimestampType), DoubleType), - millis.toDouble / MILLIS_PER_SECOND) - checkEvaluation( - cast(cast(Decimal(1), TimestampType), DecimalType.SYSTEM_DEFAULT), - Decimal(1)) - - // A test for higher precision than millis - checkEvaluation(cast(cast(0.000001, TimestampType), DoubleType), 0.000001) - - checkEvaluation(cast(Double.NaN, TimestampType), null) - checkEvaluation(cast(1.0 / 0.0, TimestampType), null) - checkEvaluation(cast(Float.NaN, TimestampType), null) - checkEvaluation(cast(1.0f / 0.0f, TimestampType), null) - } - test("cast from array") { val array = Literal.create(Seq("123", "true", "f", null), ArrayType(StringType, containsNull = true)) @@ -635,16 +591,20 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast("", BooleanType), null) } + protected def checkInvalidCastFromNumericType(to: DataType): Unit = { + assert(cast(1.toByte, to).checkInputDataTypes().isFailure) + assert(cast(1.toShort, to).checkInputDataTypes().isFailure) + assert(cast(1, to).checkInputDataTypes().isFailure) + assert(cast(1L, to).checkInputDataTypes().isFailure) + assert(cast(1.0.toFloat, to).checkInputDataTypes().isFailure) + assert(cast(1.0, to).checkInputDataTypes().isFailure) + } + test("SPARK-16729 type checking for casting to date type") { assert(cast("1234", DateType).checkInputDataTypes().isSuccess) assert(cast(new Timestamp(1), DateType).checkInputDataTypes().isSuccess) assert(cast(false, DateType).checkInputDataTypes().isFailure) - assert(cast(1.toByte, DateType).checkInputDataTypes().isFailure) - assert(cast(1.toShort, DateType).checkInputDataTypes().isFailure) - assert(cast(1, DateType).checkInputDataTypes().isFailure) - assert(cast(1L, DateType).checkInputDataTypes().isFailure) - assert(cast(1.0.toFloat, DateType).checkInputDataTypes().isFailure) - assert(cast(1.0, DateType).checkInputDataTypes().isFailure) + checkInvalidCastFromNumericType(DateType) } test("SPARK-20302 cast with same structure") { @@ -686,117 +646,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { assert(ctx.inlinedMutableStates.length == 0) } - test("SPARK-22825 Cast array to string") { - val ret1 = cast(Literal.create(Array(1, 2, 3, 4, 5)), StringType) - checkEvaluation(ret1, "[1, 2, 3, 4, 5]") - val ret2 = cast(Literal.create(Array("ab", "cde", "f")), StringType) - checkEvaluation(ret2, "[ab, cde, f]") - Seq(false, true).foreach { omitNull => - withSQLConf(SQLConf.LEGACY_COMPLEX_TYPES_TO_STRING.key -> omitNull.toString) { - val ret3 = cast(Literal.create(Array("ab", null, "c")), StringType) - checkEvaluation(ret3, s"[ab,${if (omitNull) "" else " null"}, c]") - } - } - val ret4 = - cast(Literal.create(Array("ab".getBytes, "cde".getBytes, "f".getBytes)), StringType) - checkEvaluation(ret4, "[ab, cde, f]") - val ret5 = cast( - Literal.create(Array("2014-12-03", "2014-12-04", "2014-12-06").map(Date.valueOf)), - StringType) - checkEvaluation(ret5, "[2014-12-03, 2014-12-04, 2014-12-06]") - val ret6 = cast( - Literal.create(Array("2014-12-03 13:01:00", "2014-12-04 15:05:00") - .map(Timestamp.valueOf)), - StringType) - checkEvaluation(ret6, "[2014-12-03 13:01:00, 2014-12-04 15:05:00]") - val ret7 = cast(Literal.create(Array(Array(1, 2, 3), Array(4, 5))), StringType) - checkEvaluation(ret7, "[[1, 2, 3], [4, 5]]") - val ret8 = cast( - Literal.create(Array(Array(Array("a"), Array("b", "c")), Array(Array("d")))), - StringType) - checkEvaluation(ret8, "[[[a], [b, c]], [[d]]]") - } - - test("SPARK-33291: Cast array with null elements to string") { - Seq(false, true).foreach { omitNull => - withSQLConf(SQLConf.LEGACY_COMPLEX_TYPES_TO_STRING.key -> omitNull.toString) { - val ret1 = cast(Literal.create(Array(null, null)), StringType) - checkEvaluation( - ret1, - s"[${if (omitNull) "" else "null"},${if (omitNull) "" else " null"}]") - } - } - } - - test("SPARK-22973 Cast map to string") { - Seq( - false -> ("{", "}"), - true -> ("[", "]")).foreach { case (legacyCast, (lb, rb)) => - withSQLConf(SQLConf.LEGACY_COMPLEX_TYPES_TO_STRING.key -> legacyCast.toString) { - val ret1 = cast(Literal.create(Map(1 -> "a", 2 -> "b", 3 -> "c")), StringType) - checkEvaluation(ret1, s"${lb}1 -> a, 2 -> b, 3 -> c$rb") - val ret2 = cast( - Literal.create(Map("1" -> "a".getBytes, "2" -> null, "3" -> "c".getBytes)), - StringType) - checkEvaluation(ret2, s"${lb}1 -> a, 2 ->${if (legacyCast) "" else " null"}, 3 -> c$rb") - val ret3 = cast( - Literal.create(Map( - 1 -> Date.valueOf("2014-12-03"), - 2 -> Date.valueOf("2014-12-04"), - 3 -> Date.valueOf("2014-12-05"))), - StringType) - checkEvaluation(ret3, s"${lb}1 -> 2014-12-03, 2 -> 2014-12-04, 3 -> 2014-12-05$rb") - val ret4 = cast( - Literal.create(Map( - 1 -> Timestamp.valueOf("2014-12-03 13:01:00"), - 2 -> Timestamp.valueOf("2014-12-04 15:05:00"))), - StringType) - checkEvaluation(ret4, s"${lb}1 -> 2014-12-03 13:01:00, 2 -> 2014-12-04 15:05:00$rb") - val ret5 = cast( - Literal.create(Map( - 1 -> Array(1, 2, 3), - 2 -> Array(4, 5, 6))), - StringType) - checkEvaluation(ret5, s"${lb}1 -> [1, 2, 3], 2 -> [4, 5, 6]$rb") - } - } - } - - test("SPARK-22981 Cast struct to string") { - Seq( - false -> ("{", "}"), - true -> ("[", "]")).foreach { case (legacyCast, (lb, rb)) => - withSQLConf(SQLConf.LEGACY_COMPLEX_TYPES_TO_STRING.key -> legacyCast.toString) { - val ret1 = cast(Literal.create((1, "a", 0.1)), StringType) - checkEvaluation(ret1, s"${lb}1, a, 0.1$rb") - val ret2 = cast(Literal.create(Tuple3[Int, String, String](1, null, "a")), StringType) - checkEvaluation(ret2, s"${lb}1,${if (legacyCast) "" else " null"}, a$rb") - val ret3 = cast(Literal.create( - (Date.valueOf("2014-12-03"), Timestamp.valueOf("2014-12-03 15:05:00"))), StringType) - checkEvaluation(ret3, s"${lb}2014-12-03, 2014-12-03 15:05:00$rb") - val ret4 = cast(Literal.create(((1, "a"), 5, 0.1)), StringType) - checkEvaluation(ret4, s"$lb${lb}1, a$rb, 5, 0.1$rb") - val ret5 = cast(Literal.create((Seq(1, 2, 3), "a", 0.1)), StringType) - checkEvaluation(ret5, s"$lb[1, 2, 3], a, 0.1$rb") - val ret6 = cast(Literal.create((1, Map(1 -> "a", 2 -> "b", 3 -> "c"))), StringType) - checkEvaluation(ret6, s"${lb}1, ${lb}1 -> a, 2 -> b, 3 -> c$rb$rb") - } - } - } - - test("SPARK-33291: Cast struct with null elements to string") { - Seq( - false -> ("{", "}"), - true -> ("[", "]")).foreach { case (legacyCast, (lb, rb)) => - withSQLConf(SQLConf.LEGACY_COMPLEX_TYPES_TO_STRING.key -> legacyCast.toString) { - val ret1 = cast(Literal.create(Tuple2[String, String](null, null)), StringType) - checkEvaluation( - ret1, - s"$lb${if (legacyCast) "" else "null"},${if (legacyCast) "" else " null"}$rb") - } - } - } - test("up-cast") { def isCastSafe(from: NumericType, to: NumericType): Boolean = (from, to) match { case (_, dt: DecimalType) => dt.isWiderThan(from) @@ -869,20 +718,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { } } - test("Throw exception on casting out-of-range value to decimal type") { - withSQLConf(SQLConf.ANSI_ENABLED.key -> requiredAnsiEnabledForOverflowTestCases.toString) { - checkExceptionInExpression[ArithmeticException]( - cast(Literal("134.12"), DecimalType(3, 2)), "cannot be represented") - checkExceptionInExpression[ArithmeticException]( - cast(Literal(Timestamp.valueOf("2019-07-25 22:04:36")), DecimalType(3, 2)), - "cannot be represented") - checkExceptionInExpression[ArithmeticException]( - cast(Literal(BigDecimal(134.12)), DecimalType(3, 2)), "cannot be represented") - checkExceptionInExpression[ArithmeticException]( - cast(Literal(134.12), DecimalType(3, 2)), "cannot be represented") - } - } - test("Process Infinity, -Infinity, NaN in case insensitive manner") { Seq("inf", "+inf", "infinity", "+infiNity", " infinity ").foreach { value => checkEvaluation(cast(value, FloatType), Float.PositiveInfinity) @@ -903,14 +738,15 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(value, DoubleType), Double.NaN) } } +} + +abstract class AnsiCastSuiteBase extends CastSuiteBase { private def testIntMaxAndMin(dt: DataType): Unit = { assert(Seq(IntegerType, ShortType, ByteType).contains(dt)) Seq(Int.MaxValue + 1L, Int.MinValue - 1L).foreach { value => checkExceptionInExpression[ArithmeticException](cast(value, dt), "overflow") checkExceptionInExpression[ArithmeticException](cast(Decimal(value.toString), dt), "overflow") - checkExceptionInExpression[ArithmeticException]( - cast(Literal(value * MICROS_PER_SECOND, TimestampType), dt), "overflow") checkExceptionInExpression[ArithmeticException]( cast(Literal(value * 1.5f, FloatType), dt), "overflow") checkExceptionInExpression[ArithmeticException]( @@ -930,98 +766,191 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { } } - test("Throw exception on casting out-of-range value to byte type") { - withSQLConf(SQLConf.ANSI_ENABLED.key -> requiredAnsiEnabledForOverflowTestCases.toString) { - testIntMaxAndMin(ByteType) - Seq(Byte.MaxValue + 1, Byte.MinValue - 1).foreach { value => - checkExceptionInExpression[ArithmeticException](cast(value, ByteType), "overflow") - checkExceptionInExpression[ArithmeticException]( - cast(Literal(value * MICROS_PER_SECOND, TimestampType), ByteType), "overflow") - checkExceptionInExpression[ArithmeticException]( - cast(Literal(value.toFloat, FloatType), ByteType), "overflow") - checkExceptionInExpression[ArithmeticException]( - cast(Literal(value.toDouble, DoubleType), ByteType), "overflow") - } + test("ANSI mode: Throw exception on casting out-of-range value to byte type") { + testIntMaxAndMin(ByteType) + Seq(Byte.MaxValue + 1, Byte.MinValue - 1).foreach { value => + checkExceptionInExpression[ArithmeticException](cast(value, ByteType), "overflow") + checkExceptionInExpression[ArithmeticException]( + cast(Literal(value.toFloat, FloatType), ByteType), "overflow") + checkExceptionInExpression[ArithmeticException]( + cast(Literal(value.toDouble, DoubleType), ByteType), "overflow") + } - Seq(Byte.MaxValue, 0.toByte, Byte.MinValue).foreach { value => - checkEvaluation(cast(value, ByteType), value) - checkEvaluation(cast(value.toString, ByteType), value) - checkEvaluation(cast(Decimal(value.toString), ByteType), value) - checkEvaluation(cast(Literal(value * MICROS_PER_SECOND, TimestampType), ByteType), value) - checkEvaluation(cast(Literal(value.toInt, DateType), ByteType), null) - checkEvaluation(cast(Literal(value.toFloat, FloatType), ByteType), value) - checkEvaluation(cast(Literal(value.toDouble, DoubleType), ByteType), value) - } + Seq(Byte.MaxValue, 0.toByte, Byte.MinValue).foreach { value => + checkEvaluation(cast(value, ByteType), value) + checkEvaluation(cast(value.toString, ByteType), value) + checkEvaluation(cast(Decimal(value.toString), ByteType), value) + checkEvaluation(cast(Literal(value.toFloat, FloatType), ByteType), value) + checkEvaluation(cast(Literal(value.toDouble, DoubleType), ByteType), value) } } - test("Throw exception on casting out-of-range value to short type") { - withSQLConf(SQLConf.ANSI_ENABLED.key -> requiredAnsiEnabledForOverflowTestCases.toString) { - testIntMaxAndMin(ShortType) - Seq(Short.MaxValue + 1, Short.MinValue - 1).foreach { value => - checkExceptionInExpression[ArithmeticException](cast(value, ShortType), "overflow") - checkExceptionInExpression[ArithmeticException]( - cast(Literal(value * MICROS_PER_SECOND, TimestampType), ShortType), "overflow") - checkExceptionInExpression[ArithmeticException]( - cast(Literal(value.toFloat, FloatType), ShortType), "overflow") - checkExceptionInExpression[ArithmeticException]( - cast(Literal(value.toDouble, DoubleType), ShortType), "overflow") - } + test("ANSI mode: Throw exception on casting out-of-range value to short type") { + testIntMaxAndMin(ShortType) + Seq(Short.MaxValue + 1, Short.MinValue - 1).foreach { value => + checkExceptionInExpression[ArithmeticException](cast(value, ShortType), "overflow") + checkExceptionInExpression[ArithmeticException]( + cast(Literal(value.toFloat, FloatType), ShortType), "overflow") + checkExceptionInExpression[ArithmeticException]( + cast(Literal(value.toDouble, DoubleType), ShortType), "overflow") + } - Seq(Short.MaxValue, 0.toShort, Short.MinValue).foreach { value => - checkEvaluation(cast(value, ShortType), value) - checkEvaluation(cast(value.toString, ShortType), value) - checkEvaluation(cast(Decimal(value.toString), ShortType), value) - checkEvaluation(cast(Literal(value * MICROS_PER_SECOND, TimestampType), ShortType), value) - checkEvaluation(cast(Literal(value.toInt, DateType), ShortType), null) - checkEvaluation(cast(Literal(value.toFloat, FloatType), ShortType), value) - checkEvaluation(cast(Literal(value.toDouble, DoubleType), ShortType), value) - } + Seq(Short.MaxValue, 0.toShort, Short.MinValue).foreach { value => + checkEvaluation(cast(value, ShortType), value) + checkEvaluation(cast(value.toString, ShortType), value) + checkEvaluation(cast(Decimal(value.toString), ShortType), value) + checkEvaluation(cast(Literal(value.toFloat, FloatType), ShortType), value) + checkEvaluation(cast(Literal(value.toDouble, DoubleType), ShortType), value) } } - test("Throw exception on casting out-of-range value to int type") { - withSQLConf(SQLConf.ANSI_ENABLED.key -> requiredAnsiEnabledForOverflowTestCases.toString) { - testIntMaxAndMin(IntegerType) - testLongMaxAndMin(IntegerType) + test("ANSI mode: Throw exception on casting out-of-range value to int type") { + testIntMaxAndMin(IntegerType) + testLongMaxAndMin(IntegerType) - Seq(Int.MaxValue, 0, Int.MinValue).foreach { value => - checkEvaluation(cast(value, IntegerType), value) - checkEvaluation(cast(value.toString, IntegerType), value) - checkEvaluation(cast(Decimal(value.toString), IntegerType), value) - checkEvaluation(cast(Literal(value * MICROS_PER_SECOND, TimestampType), IntegerType), value) - checkEvaluation(cast(Literal(value * 1.0, DoubleType), IntegerType), value) - } - checkEvaluation(cast(Int.MaxValue + 0.9D, IntegerType), Int.MaxValue) - checkEvaluation(cast(Int.MinValue - 0.9D, IntegerType), Int.MinValue) + Seq(Int.MaxValue, 0, Int.MinValue).foreach { value => + checkEvaluation(cast(value, IntegerType), value) + checkEvaluation(cast(value.toString, IntegerType), value) + checkEvaluation(cast(Decimal(value.toString), IntegerType), value) + checkEvaluation(cast(Literal(value * 1.0, DoubleType), IntegerType), value) + } + checkEvaluation(cast(Int.MaxValue + 0.9D, IntegerType), Int.MaxValue) + checkEvaluation(cast(Int.MinValue - 0.9D, IntegerType), Int.MinValue) + } + + test("ANSI mode: Throw exception on casting out-of-range value to long type") { + testLongMaxAndMin(LongType) + + Seq(Long.MaxValue, 0, Long.MinValue).foreach { value => + checkEvaluation(cast(value, LongType), value) + checkEvaluation(cast(value.toString, LongType), value) + checkEvaluation(cast(Decimal(value.toString), LongType), value) + } + checkEvaluation(cast(Long.MaxValue + 0.9F, LongType), Long.MaxValue) + checkEvaluation(cast(Long.MinValue - 0.9F, LongType), Long.MinValue) + checkEvaluation(cast(Long.MaxValue + 0.9D, LongType), Long.MaxValue) + checkEvaluation(cast(Long.MinValue - 0.9D, LongType), Long.MinValue) + } + + test("ANSI mode: Throw exception on casting out-of-range value to decimal type") { + checkExceptionInExpression[ArithmeticException]( + cast(Literal("134.12"), DecimalType(3, 2)), "cannot be represented") + checkExceptionInExpression[ArithmeticException]( + cast(Literal(BigDecimal(134.12)), DecimalType(3, 2)), "cannot be represented") + checkExceptionInExpression[ArithmeticException]( + cast(Literal(134.12), DecimalType(3, 2)), "cannot be represented") + } + + test("ANSI mode: disallow type conversions between Numeric types and Timestamp type") { + import DataTypeTestUtils.numericTypes + checkInvalidCastFromNumericType(TimestampType) + val timestampLiteral = Literal(1L, TimestampType) + numericTypes.foreach { numericType => + assert(cast(timestampLiteral, numericType).checkInputDataTypes().isFailure) } } - test("Throw exception on casting out-of-range value to long type") { - withSQLConf(SQLConf.ANSI_ENABLED.key -> requiredAnsiEnabledForOverflowTestCases.toString) { - testLongMaxAndMin(LongType) - - Seq(Long.MaxValue, 0, Long.MinValue).foreach { value => - checkEvaluation(cast(value, LongType), value) - checkEvaluation(cast(value.toString, LongType), value) - checkEvaluation(cast(Decimal(value.toString), LongType), value) - checkEvaluation(cast(Literal(value, TimestampType), LongType), - Math.floorDiv(value, MICROS_PER_SECOND)) - } - checkEvaluation(cast(Long.MaxValue + 0.9F, LongType), Long.MaxValue) - checkEvaluation(cast(Long.MinValue - 0.9F, LongType), Long.MinValue) - checkEvaluation(cast(Long.MaxValue + 0.9D, LongType), Long.MaxValue) - checkEvaluation(cast(Long.MinValue - 0.9D, LongType), Long.MinValue) + test("ANSI mode: disallow type conversions between Numeric types and Date type") { + import DataTypeTestUtils.numericTypes + checkInvalidCastFromNumericType(DateType) + val dateLiteral = Literal(1, DateType) + numericTypes.foreach { numericType => + assert(cast(dateLiteral, numericType).checkInputDataTypes().isFailure) } } + + test("ANSI mode: disallow type conversions between Numeric types and Binary type") { + import DataTypeTestUtils.numericTypes + checkInvalidCastFromNumericType(BinaryType) + val binaryLiteral = Literal(new Array[Byte](1.toByte), BinaryType) + numericTypes.foreach { numericType => + assert(cast(binaryLiteral, numericType).checkInputDataTypes().isFailure) + } + } + + test("ANSI mode: disallow type conversions between Datatime types and Boolean types") { + val timestampLiteral = Literal(1L, TimestampType) + assert(cast(timestampLiteral, BooleanType).checkInputDataTypes().isFailure) + val dateLiteral = Literal(1, DateType) + assert(cast(dateLiteral, BooleanType).checkInputDataTypes().isFailure) + + val booleanLiteral = Literal(true, BooleanType) + assert(cast(booleanLiteral, TimestampType).checkInputDataTypes().isFailure) + assert(cast(booleanLiteral, DateType).checkInputDataTypes().isFailure) + } + + test("ANSI mode: disallow casting complex types as String type") { + assert(cast(Literal.create(Array(1, 2, 3, 4, 5)), StringType).checkInputDataTypes().isFailure) + assert(cast(Literal.create(Map(1 -> "a")), StringType).checkInputDataTypes().isFailure) + assert(cast(Literal.create((1, "a", 0.1)), StringType).checkInputDataTypes().isFailure) + } + + test("cast from invalid string to numeric should throw NumberFormatException") { + // cast to IntegerType + Seq(IntegerType, ShortType, ByteType, LongType).foreach { dataType => + val array = Literal.create(Seq("123", "true", "f", null), + ArrayType(StringType, containsNull = true)) + checkExceptionInExpression[NumberFormatException]( + cast(array, ArrayType(dataType, containsNull = true)), + "invalid input syntax for type numeric: true") + checkExceptionInExpression[NumberFormatException]( + cast("string", dataType), "invalid input syntax for type numeric: string") + checkExceptionInExpression[NumberFormatException]( + cast("123-string", dataType), "invalid input syntax for type numeric: 123-string") + checkExceptionInExpression[NumberFormatException]( + cast("2020-07-19", dataType), "invalid input syntax for type numeric: 2020-07-19") + checkExceptionInExpression[NumberFormatException]( + cast("1.23", dataType), "invalid input syntax for type numeric: 1.23") + } + + Seq(DoubleType, FloatType, DecimalType.USER_DEFAULT).foreach { dataType => + checkExceptionInExpression[NumberFormatException]( + cast("string", dataType), "invalid input syntax for type numeric: string") + checkExceptionInExpression[NumberFormatException]( + cast("123.000.00", dataType), "invalid input syntax for type numeric: 123.000.00") + checkExceptionInExpression[NumberFormatException]( + cast("abc.com", dataType), "invalid input syntax for type numeric: abc.com") + } + } + + test("Fast fail for cast string type to decimal type in ansi mode") { + checkEvaluation(cast("12345678901234567890123456789012345678", DecimalType(38, 0)), + Decimal("12345678901234567890123456789012345678")) + checkExceptionInExpression[ArithmeticException]( + cast("123456789012345678901234567890123456789", DecimalType(38, 0)), + "out of decimal type range") + checkExceptionInExpression[ArithmeticException]( + cast("12345678901234567890123456789012345678", DecimalType(38, 1)), + "cannot be represented as Decimal(38, 1)") + + checkEvaluation(cast("0.00000000000000000000000000000000000001", DecimalType(38, 0)), + Decimal("0")) + checkEvaluation(cast("0.00000000000000000000000000000000000000000001", DecimalType(38, 0)), + Decimal("0")) + checkEvaluation(cast("0.00000000000000000000000000000000000001", DecimalType(38, 18)), + Decimal("0E-18")) + checkEvaluation(cast("6E-120", DecimalType(38, 0)), + Decimal("0")) + + checkEvaluation(cast("6E+37", DecimalType(38, 0)), + Decimal("60000000000000000000000000000000000000")) + checkExceptionInExpression[ArithmeticException]( + cast("6E+38", DecimalType(38, 0)), + "out of decimal type range") + checkExceptionInExpression[ArithmeticException]( + cast("6E+37", DecimalType(38, 1)), + "cannot be represented as Decimal(38, 1)") + + checkExceptionInExpression[NumberFormatException]( + cast("abcd", DecimalType(38, 1)), + "invalid input syntax for type numeric") + } } /** * Test suite for data type casting expression [[Cast]]. */ class CastSuite extends CastSuiteBase { - // It is required to set SQLConf.ANSI_ENABLED as true for testing numeric overflow. - override protected def requiredAnsiEnabledForOverflowTestCases: Boolean = true override def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): CastBase = { v match { @@ -1030,51 +959,26 @@ class CastSuite extends CastSuiteBase { } } - test("cast from int") { - checkCast(0, false) - checkCast(1, true) - checkCast(-5, true) - checkCast(1, 1.toByte) - checkCast(1, 1.toShort) - checkCast(1, 1) - checkCast(1, 1.toLong) - checkCast(1, 1.0f) - checkCast(1, 1.0) - checkCast(123, "123") + test("null cast #2") { + import DataTypeTestUtils._ - checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123)) - checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123)) - checkEvaluation(cast(123, DecimalType(3, 1)), null) - checkEvaluation(cast(123, DecimalType(2, 0)), null) + checkNullCast(DateType, BooleanType) + checkNullCast(TimestampType, BooleanType) + checkNullCast(BooleanType, TimestampType) + numericTypes.foreach(dt => checkNullCast(dt, TimestampType)) + numericTypes.foreach(dt => checkNullCast(TimestampType, dt)) + numericTypes.foreach(dt => checkNullCast(DateType, dt)) } - test("cast from long") { - checkCast(0L, false) - checkCast(1L, true) - checkCast(-5L, true) - checkCast(1L, 1.toByte) - checkCast(1L, 1.toShort) - checkCast(1L, 1) - checkCast(1L, 1.toLong) - checkCast(1L, 1.0f) - checkCast(1L, 1.0) - checkCast(123L, "123") - - checkEvaluation(cast(123L, DecimalType.USER_DEFAULT), Decimal(123)) - checkEvaluation(cast(123L, DecimalType(3, 0)), Decimal(123)) + test("cast from long #2") { checkEvaluation(cast(123L, DecimalType(3, 1)), null) - checkEvaluation(cast(123L, DecimalType(2, 0)), null) } - test("cast from int 2") { - checkEvaluation(cast(1, LongType), 1.toLong) - + test("cast from int #2") { checkEvaluation(cast(cast(1000, TimestampType), LongType), 1000.toLong) checkEvaluation(cast(cast(-1200, TimestampType), LongType), -1200.toLong) - checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123)) - checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123)) checkEvaluation(cast(123, DecimalType(3, 1)), null) checkEvaluation(cast(123, DecimalType(2, 0)), null) } @@ -1343,6 +1247,58 @@ class CastSuite extends CastSuiteBase { } } + test("cast from date") { + val d = Date.valueOf("1970-01-01") + checkEvaluation(cast(d, ShortType), null) + checkEvaluation(cast(d, IntegerType), null) + checkEvaluation(cast(d, LongType), null) + checkEvaluation(cast(d, FloatType), null) + checkEvaluation(cast(d, DoubleType), null) + checkEvaluation(cast(d, DecimalType.SYSTEM_DEFAULT), null) + checkEvaluation(cast(d, DecimalType(10, 2)), null) + checkEvaluation(cast(d, StringType), "1970-01-01") + + checkEvaluation( + cast(cast(d, TimestampType, UTC_OPT), StringType, UTC_OPT), + "1970-01-01 00:00:00") + } + + test("cast from timestamp") { + val millis = 15 * 1000 + 3 + val seconds = millis * 1000 + 3 + val ts = new Timestamp(millis) + val tss = new Timestamp(seconds) + checkEvaluation(cast(ts, ShortType), 15.toShort) + checkEvaluation(cast(ts, IntegerType), 15) + checkEvaluation(cast(ts, LongType), 15.toLong) + checkEvaluation(cast(ts, FloatType), 15.003f) + checkEvaluation(cast(ts, DoubleType), 15.003) + + checkEvaluation(cast(cast(tss, ShortType), TimestampType), + fromJavaTimestamp(ts) * MILLIS_PER_SECOND) + checkEvaluation(cast(cast(tss, IntegerType), TimestampType), + fromJavaTimestamp(ts) * MILLIS_PER_SECOND) + checkEvaluation(cast(cast(tss, LongType), TimestampType), + fromJavaTimestamp(ts) * MILLIS_PER_SECOND) + checkEvaluation( + cast(cast(millis.toFloat / MILLIS_PER_SECOND, TimestampType), FloatType), + millis.toFloat / MILLIS_PER_SECOND) + checkEvaluation( + cast(cast(millis.toDouble / MILLIS_PER_SECOND, TimestampType), DoubleType), + millis.toDouble / MILLIS_PER_SECOND) + checkEvaluation( + cast(cast(Decimal(1), TimestampType), DecimalType.SYSTEM_DEFAULT), + Decimal(1)) + + // A test for higher precision than millis + checkEvaluation(cast(cast(0.000001, TimestampType), DoubleType), 0.000001) + + checkEvaluation(cast(Double.NaN, TimestampType), null) + checkEvaluation(cast(1.0 / 0.0, TimestampType), null) + checkEvaluation(cast(Float.NaN, TimestampType), null) + checkEvaluation(cast(1.0f / 0.0f, TimestampType), null) + } + test("cast a timestamp before the epoch 1970-01-01 00:00:00Z") { withDefaultTimeZone(UTC) { val negativeTs = Timestamp.valueOf("1900-05-05 18:34:56.1") @@ -1396,14 +1352,172 @@ class CastSuite extends CastSuiteBase { checkEvaluation(cast("abcd", DecimalType(38, 1)), null) } + + test("SPARK-22825 Cast array to string") { + val ret1 = cast(Literal.create(Array(1, 2, 3, 4, 5)), StringType) + checkEvaluation(ret1, "[1, 2, 3, 4, 5]") + val ret2 = cast(Literal.create(Array("ab", "cde", "f")), StringType) + checkEvaluation(ret2, "[ab, cde, f]") + Seq(false, true).foreach { omitNull => + withSQLConf(SQLConf.LEGACY_COMPLEX_TYPES_TO_STRING.key -> omitNull.toString) { + val ret3 = cast(Literal.create(Array("ab", null, "c")), StringType) + checkEvaluation(ret3, s"[ab,${if (omitNull) "" else " null"}, c]") + } + } + val ret4 = + cast(Literal.create(Array("ab".getBytes, "cde".getBytes, "f".getBytes)), StringType) + checkEvaluation(ret4, "[ab, cde, f]") + val ret5 = cast( + Literal.create(Array("2014-12-03", "2014-12-04", "2014-12-06").map(Date.valueOf)), + StringType) + checkEvaluation(ret5, "[2014-12-03, 2014-12-04, 2014-12-06]") + val ret6 = cast( + Literal.create(Array("2014-12-03 13:01:00", "2014-12-04 15:05:00") + .map(Timestamp.valueOf)), + StringType) + checkEvaluation(ret6, "[2014-12-03 13:01:00, 2014-12-04 15:05:00]") + val ret7 = cast(Literal.create(Array(Array(1, 2, 3), Array(4, 5))), StringType) + checkEvaluation(ret7, "[[1, 2, 3], [4, 5]]") + val ret8 = cast( + Literal.create(Array(Array(Array("a"), Array("b", "c")), Array(Array("d")))), + StringType) + checkEvaluation(ret8, "[[[a], [b, c]], [[d]]]") + } + + test("SPARK-33291: Cast array with null elements to string") { + Seq(false, true).foreach { omitNull => + withSQLConf(SQLConf.LEGACY_COMPLEX_TYPES_TO_STRING.key -> omitNull.toString) { + val ret1 = cast(Literal.create(Array(null, null)), StringType) + checkEvaluation( + ret1, + s"[${if (omitNull) "" else "null"},${if (omitNull) "" else " null"}]") + } + } + } + + test("SPARK-22973 Cast map to string") { + Seq( + false -> ("{", "}"), + true -> ("[", "]")).foreach { case (legacyCast, (lb, rb)) => + withSQLConf(SQLConf.LEGACY_COMPLEX_TYPES_TO_STRING.key -> legacyCast.toString) { + val ret1 = cast(Literal.create(Map(1 -> "a", 2 -> "b", 3 -> "c")), StringType) + checkEvaluation(ret1, s"${lb}1 -> a, 2 -> b, 3 -> c$rb") + val ret2 = cast( + Literal.create(Map("1" -> "a".getBytes, "2" -> null, "3" -> "c".getBytes)), + StringType) + checkEvaluation(ret2, s"${lb}1 -> a, 2 ->${if (legacyCast) "" else " null"}, 3 -> c$rb") + val ret3 = cast( + Literal.create(Map( + 1 -> Date.valueOf("2014-12-03"), + 2 -> Date.valueOf("2014-12-04"), + 3 -> Date.valueOf("2014-12-05"))), + StringType) + checkEvaluation(ret3, s"${lb}1 -> 2014-12-03, 2 -> 2014-12-04, 3 -> 2014-12-05$rb") + val ret4 = cast( + Literal.create(Map( + 1 -> Timestamp.valueOf("2014-12-03 13:01:00"), + 2 -> Timestamp.valueOf("2014-12-04 15:05:00"))), + StringType) + checkEvaluation(ret4, s"${lb}1 -> 2014-12-03 13:01:00, 2 -> 2014-12-04 15:05:00$rb") + val ret5 = cast( + Literal.create(Map( + 1 -> Array(1, 2, 3), + 2 -> Array(4, 5, 6))), + StringType) + checkEvaluation(ret5, s"${lb}1 -> [1, 2, 3], 2 -> [4, 5, 6]$rb") + } + } + } + + test("SPARK-22981 Cast struct to string") { + Seq( + false -> ("{", "}"), + true -> ("[", "]")).foreach { case (legacyCast, (lb, rb)) => + withSQLConf(SQLConf.LEGACY_COMPLEX_TYPES_TO_STRING.key -> legacyCast.toString) { + val ret1 = cast(Literal.create((1, "a", 0.1)), StringType) + checkEvaluation(ret1, s"${lb}1, a, 0.1$rb") + val ret2 = cast(Literal.create(Tuple3[Int, String, String](1, null, "a")), StringType) + checkEvaluation(ret2, s"${lb}1,${if (legacyCast) "" else " null"}, a$rb") + val ret3 = cast(Literal.create( + (Date.valueOf("2014-12-03"), Timestamp.valueOf("2014-12-03 15:05:00"))), StringType) + checkEvaluation(ret3, s"${lb}2014-12-03, 2014-12-03 15:05:00$rb") + val ret4 = cast(Literal.create(((1, "a"), 5, 0.1)), StringType) + checkEvaluation(ret4, s"$lb${lb}1, a$rb, 5, 0.1$rb") + val ret5 = cast(Literal.create((Seq(1, 2, 3), "a", 0.1)), StringType) + checkEvaluation(ret5, s"$lb[1, 2, 3], a, 0.1$rb") + val ret6 = cast(Literal.create((1, Map(1 -> "a", 2 -> "b", 3 -> "c"))), StringType) + checkEvaluation(ret6, s"${lb}1, ${lb}1 -> a, 2 -> b, 3 -> c$rb$rb") + } + } + } + + test("SPARK-33291: Cast struct with null elements to string") { + Seq( + false -> ("{", "}"), + true -> ("[", "]")).foreach { case (legacyCast, (lb, rb)) => + withSQLConf(SQLConf.LEGACY_COMPLEX_TYPES_TO_STRING.key -> legacyCast.toString) { + val ret1 = cast(Literal.create(Tuple2[String, String](null, null)), StringType) + checkEvaluation( + ret1, + s"$lb${if (legacyCast) "" else "null"},${if (legacyCast) "" else " null"}$rb") + } + } + } + + test("data type casting II") { + checkEvaluation( + cast(cast(cast(cast(cast(cast("5", ByteType), TimestampType), + DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType), + 5.toShort) + checkEvaluation( + cast(cast(cast(cast(cast(cast("5", TimestampType, UTC_OPT), ByteType), + DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType), + null) + checkEvaluation(cast(cast(cast(cast(cast(cast("5", DecimalType.SYSTEM_DEFAULT), + ByteType), TimestampType), LongType), StringType), ShortType), + 5.toShort) + } + + test("Cast from double II") { + checkEvaluation(cast(cast(1.toDouble, TimestampType), DoubleType), 1.toDouble) + } } /** - * Test suite for data type casting expression [[AnsiCast]]. + * Test suite for data type casting expression [[Cast]] with ANSI mode disabled. */ -class AnsiCastSuite extends CastSuiteBase { - // It is not required to set SQLConf.ANSI_ENABLED as true for testing numeric overflow. - override protected def requiredAnsiEnabledForOverflowTestCases: Boolean = false +class CastSuiteWithAnsiModeOn extends AnsiCastSuiteBase { + override def beforeAll(): Unit = { + super.beforeAll() + SQLConf.get.setConf(SQLConf.ANSI_ENABLED, true) + } + + override def afterAll(): Unit = { + super.afterAll() + SQLConf.get.unsetConf(SQLConf.ANSI_ENABLED) + } + + override def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): CastBase = { + v match { + case lit: Expression => Cast(lit, targetType, timeZoneId) + case _ => Cast(Literal(v), targetType, timeZoneId) + } + } +} + +/** + * Test suite for data type casting expression [[AnsiCast]] with ANSI mode enabled. + */ +class AnsiCastSuiteWithAnsiModeOn extends AnsiCastSuiteBase { + override def beforeAll(): Unit = { + super.beforeAll() + SQLConf.get.setConf(SQLConf.ANSI_ENABLED, true) + } + + override def afterAll(): Unit = { + super.afterAll() + SQLConf.get.unsetConf(SQLConf.ANSI_ENABLED) + } override def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): CastBase = { v match { @@ -1411,78 +1525,26 @@ class AnsiCastSuite extends CastSuiteBase { case _ => AnsiCast(Literal(v), targetType, timeZoneId) } } +} - test("cast from invalid string to numeric should throw NumberFormatException") { - // cast to IntegerType - Seq(IntegerType, ShortType, ByteType, LongType).foreach { dataType => - val array = Literal.create(Seq("123", "true", "f", null), - ArrayType(StringType, containsNull = true)) - checkExceptionInExpression[NumberFormatException]( - cast(array, ArrayType(dataType, containsNull = true)), - "invalid input syntax for type numeric: true") - checkExceptionInExpression[NumberFormatException]( - cast("string", dataType), "invalid input syntax for type numeric: string") - checkExceptionInExpression[NumberFormatException]( - cast("123-string", dataType), "invalid input syntax for type numeric: 123-string") - checkExceptionInExpression[NumberFormatException]( - cast("2020-07-19", dataType), "invalid input syntax for type numeric: 2020-07-19") - checkExceptionInExpression[NumberFormatException]( - cast("1.23", dataType), "invalid input syntax for type numeric: 1.23") - } - - Seq(DoubleType, FloatType, DecimalType.USER_DEFAULT).foreach { dataType => - checkExceptionInExpression[NumberFormatException]( - cast("string", dataType), "invalid input syntax for type numeric: string") - checkExceptionInExpression[NumberFormatException]( - cast("123.000.00", dataType), "invalid input syntax for type numeric: 123.000.00") - checkExceptionInExpression[NumberFormatException]( - cast("abc.com", dataType), "invalid input syntax for type numeric: abc.com") - } +/** + * Test suite for data type casting expression [[AnsiCast]] with ANSI mode disabled. + */ +class AnsiCastSuiteWithAnsiModeOff extends AnsiCastSuiteBase { + override def beforeAll(): Unit = { + super.beforeAll() + SQLConf.get.setConf(SQLConf.ANSI_ENABLED, false) } - test("cast a timestamp before the epoch 1970-01-01 00:00:00Z") { - def errMsg(t: String): String = s"Casting -2198208303900000 to $t causes overflow" - withDefaultTimeZone(UTC) { - val negativeTs = Timestamp.valueOf("1900-05-05 18:34:56.1") - assert(negativeTs.getTime < 0) - val expectedSecs = Math.floorDiv(negativeTs.getTime, MILLIS_PER_SECOND) - checkExceptionInExpression[ArithmeticException](cast(negativeTs, ByteType), errMsg("byte")) - checkExceptionInExpression[ArithmeticException](cast(negativeTs, ShortType), errMsg("short")) - checkExceptionInExpression[ArithmeticException](cast(negativeTs, IntegerType), errMsg("int")) - checkEvaluation(cast(negativeTs, LongType), expectedSecs) - } + override def afterAll(): Unit = { + super.afterAll() + SQLConf.get.unsetConf(SQLConf.ANSI_ENABLED) } - test("Fast fail for cast string type to decimal type in ansi mode") { - checkEvaluation(cast("12345678901234567890123456789012345678", DecimalType(38, 0)), - Decimal("12345678901234567890123456789012345678")) - checkExceptionInExpression[ArithmeticException]( - cast("123456789012345678901234567890123456789", DecimalType(38, 0)), - "out of decimal type range") - checkExceptionInExpression[ArithmeticException]( - cast("12345678901234567890123456789012345678", DecimalType(38, 1)), - "cannot be represented as Decimal(38, 1)") - - checkEvaluation(cast("0.00000000000000000000000000000000000001", DecimalType(38, 0)), - Decimal("0")) - checkEvaluation(cast("0.00000000000000000000000000000000000000000001", DecimalType(38, 0)), - Decimal("0")) - checkEvaluation(cast("0.00000000000000000000000000000000000001", DecimalType(38, 18)), - Decimal("0E-18")) - checkEvaluation(cast("6E-120", DecimalType(38, 0)), - Decimal("0")) - - checkEvaluation(cast("6E+37", DecimalType(38, 0)), - Decimal("60000000000000000000000000000000000000")) - checkExceptionInExpression[ArithmeticException]( - cast("6E+38", DecimalType(38, 0)), - "out of decimal type range") - checkExceptionInExpression[ArithmeticException]( - cast("6E+37", DecimalType(38, 1)), - "cannot be represented as Decimal(38, 1)") - - checkExceptionInExpression[NumberFormatException]( - cast("abcd", DecimalType(38, 1)), - "invalid input syntax for type numeric") + override def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): CastBase = { + v match { + case lit: Expression => AnsiCast(lit, targetType, timeZoneId) + case _ => AnsiCast(Literal(v), targetType, timeZoneId) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 4686a0c69d..aaf8765c04 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -756,6 +756,47 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { } } + test("SPARK-33354: Throw exceptions on inserting invalid cast with ANSI casting policy") { + withSQLConf( + SQLConf.STORE_ASSIGNMENT_POLICY.key -> SQLConf.StoreAssignmentPolicy.ANSI.toString) { + withTable("t") { + sql("CREATE TABLE t(i int, t timestamp) USING parquet") + val msg = intercept[AnalysisException] { + sql("INSERT INTO t VALUES (TIMESTAMP('2010-09-02 14:10:10'), 1)") + }.getMessage + assert(msg.contains("Cannot safely cast 'i': timestamp to int")) + assert(msg.contains("Cannot safely cast 't': int to timestamp")) + } + + withTable("t") { + sql("CREATE TABLE t(i int, d date) USING parquet") + val msg = intercept[AnalysisException] { + sql("INSERT INTO t VALUES (date('2010-09-02'), 1)") + }.getMessage + assert(msg.contains("Cannot safely cast 'i': date to int")) + assert(msg.contains("Cannot safely cast 'd': int to date")) + } + + withTable("t") { + sql("CREATE TABLE t(b boolean, t timestamp) USING parquet") + val msg = intercept[AnalysisException] { + sql("INSERT INTO t VALUES (TIMESTAMP('2010-09-02 14:10:10'), true)") + }.getMessage + assert(msg.contains("Cannot safely cast 'b': timestamp to boolean")) + assert(msg.contains("Cannot safely cast 't': boolean to timestamp")) + } + + withTable("t") { + sql("CREATE TABLE t(b boolean, d date) USING parquet") + val msg = intercept[AnalysisException] { + sql("INSERT INTO t VALUES (date('2010-09-02'), true)") + }.getMessage + assert(msg.contains("Cannot safely cast 'b': date to boolean")) + assert(msg.contains("Cannot safely cast 'd': boolean to date")) + } + } + } + test("SPARK-30844: static partition should also follow StoreAssignmentPolicy") { SQLConf.StoreAssignmentPolicy.values.foreach { policy => withSQLConf(