[SPARK-33354][SQL] New explicit cast syntax rules in ANSI mode

### What changes were proposed in this pull request?

In section 6.13 of the ANSI SQL standard, there are syntax rules for valid combinations of the source and target data types.
![image](https://user-images.githubusercontent.com/1097932/98212874-17356f80-1ef9-11eb-8f2b-385f32db404a.png)

Comparing the ANSI CAST syntax rules with the current default behavior of Spark:
![image](https://user-images.githubusercontent.com/1097932/98789831-b7870a80-23b7-11eb-9b5f-469a42e0ee4a.png)

To make Spark's ANSI mode more compliant with the ANSI SQL standard, I propose to disallow the following casts in ANSI mode:
```
Timestamp <=> Boolean
Date <=> Boolean
Numeric <=> Timestamp
Numeric <=> Date
Numeric <=> Binary
String <=> Array
String <=> Map
String <=> Struct
```
The following casts are considered invalid in the ANSI SQL standard, but they are quite straightforward, so let's allow them for now:
```
Numeric <=> Boolean
String <=> Binary
```
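The resulting rule set for the scalar types above can be sketched as a toy pattern match. This is a hedged illustration only: the names (`Kind`, `ansiAllows`) are stand-ins, not Spark's actual `DataType` hierarchy or `AnsiCast.canCast` signature.

```scala
// Toy model of the proposed ANSI-mode rules for scalar types only.
// All identifiers here are illustrative, not Spark APIs.
sealed trait Kind
case object Numeric extends Kind
case object Str     extends Kind // String
case object Dt      extends Kind // Date
case object Ts      extends Kind // Timestamp
case object Bool    extends Kind // Boolean
case object Bin     extends Kind // Binary

def ansiAllows(from: Kind, to: Kind): Boolean = (from, to) match {
  case (f, t) if f == t                  => true
  case (Str, _) | (_, Str)               => true // String <=> every scalar, incl. Binary
  case (Numeric, Bool) | (Bool, Numeric) => true // allowed despite the standard
  case (Dt, Ts) | (Ts, Dt)               => true
  case _                                 => false // Numeric <=> Timestamp/Date/Binary and
                                                  // Timestamp/Date <=> Boolean are rejected
}
```

Under this sketch, `ansiAllows(Numeric, Ts)` and `ansiAllows(Ts, Bool)` are `false`, while the two conveniences kept above (`Numeric <=> Boolean`, `String <=> Binary`) remain `true`.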
### Why are the changes needed?

Better ANSI SQL compliance

### Does this PR introduce _any_ user-facing change?

Yes, the following casts will no longer be allowed in ANSI mode:
```
Timestamp <=> Boolean
Date <=> Boolean
Numeric <=> Timestamp
Numeric <=> Date
Numeric <=> Binary
String <=> Array
String <=> Map
String <=> Struct
```

### How was this patch tested?

Unit test

The ANSI Compliance doc preview:
![image](https://user-images.githubusercontent.com/1097932/98946017-2cd20880-24a8-11eb-8161-65749bfdd03a.png)

Closes #30260 from gengliangwang/ansiCanCast.

Authored-by: Gengliang Wang <gengliang.wang@databricks.com>
Signed-off-by: Takeshi Yamamuro <yamamuro@apache.org>
Commit: 9a4c79073b (parent: fbfc0bf628)
Committed: 2020-11-19 09:23:36 +09:00 by Takeshi Yamamuro
4 changed files with 630 additions and 390 deletions


@@ -61,6 +61,27 @@ Spark SQL has three kinds of type conversions: explicit casting, type coercion,
When `spark.sql.ansi.enabled` is set to `true`, explicit casting by `CAST` syntax throws a runtime exception for illegal cast patterns defined in the standard, e.g. casts from a string to an integer.
On the other hand, `INSERT INTO` syntax throws an analysis exception when ANSI mode is enabled via `spark.sql.storeAssignmentPolicy=ANSI`.
The type conversion of Spark ANSI mode follows the syntax rules of section 6.13 "cast specification" in [ISO/IEC 9075-2:2011 Information technology — Database languages - SQL — Part 2: Foundation (SQL/Foundation)"](https://www.iso.org/standard/53682.html), except it specially allows the following
straightforward type conversions which are disallowed as per the ANSI standard:
* NumericType <=> BooleanType
* StringType <=> BinaryType
The valid combinations of target data type and source data type in a `CAST` expression are given by the following table.
“Y” indicates that the combination is syntactically valid without restriction and “N” indicates that the combination is not valid.
| From\To | NumericType | StringType | DateType | TimestampType | IntervalType | BooleanType | BinaryType | ArrayType | MapType | StructType |
|-----------|---------|--------|------|-----------|----------|---------|--------|-------|-----|--------|
| NumericType | Y | Y | N | N | N | Y | N | N | N | N |
| StringType | Y | Y | Y | Y | Y | Y | Y | N | N | N |
| DateType | N | Y | Y | Y | N | N | N | N | N | N |
| TimestampType | N | Y | Y | Y | N | N | N | N | N | N |
| IntervalType | N | Y | N | N | Y | N | N | N | N | N |
| BooleanType | Y | Y | N | N | N | Y | N | N | N | N |
| BinaryType | N | Y | N | N | N | N | Y | N | N | N |
| ArrayType | N | N | N | N | N | N | N | Y | N | N |
| MapType | N | N | N | N | N | N | N | N | Y | N |
| StructType | N | N | N | N | N | N | N | N | N | Y |
Currently, the ANSI mode affects explicit casting and assignment casting only.
In future releases, the behaviour of type coercion might change along with the other two type conversion rules.


@@ -25,6 +25,7 @@ import java.util.concurrent.TimeUnit._
import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.Cast.{canCast, forceNullable, resolvableNullability}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util._
@@ -258,13 +259,18 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
def dataType: DataType
/**
* Returns true iff we can cast `from` type to `to` type.
*/
def canCast(from: DataType, to: DataType): Boolean
override def toString: String = {
val ansi = if (ansiEnabled) "ansi_" else ""
s"${ansi}cast($child as ${dataType.simpleString})"
}
override def checkInputDataTypes(): TypeCheckResult = {
if (Cast.canCast(child.dataType, dataType)) {
if (canCast(child.dataType, dataType)) {
TypeCheckResult.TypeCheckSuccess
} else {
TypeCheckResult.TypeCheckFailure(
@@ -1753,6 +1759,12 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
copy(timeZoneId = Option(timeZoneId))
override protected val ansiEnabled: Boolean = SQLConf.get.ansiEnabled
override def canCast(from: DataType, to: DataType): Boolean = if (ansiEnabled) {
AnsiCast.canCast(from, to)
} else {
Cast.canCast(from, to)
}
}
/**
@@ -1770,6 +1782,110 @@ case class AnsiCast(child: Expression, dataType: DataType, timeZoneId: Option[St
copy(timeZoneId = Option(timeZoneId))
override protected val ansiEnabled: Boolean = true
override def canCast(from: DataType, to: DataType): Boolean = AnsiCast.canCast(from, to)
}
object AnsiCast {
/**
* As per section 6.13 "cast specification" in "Information technology — Database languages
* - SQL — Part 2: Foundation (SQL/Foundation)":
* If the <cast operand> is a <value expression>, then the valid combinations of TD and SD
* in a <cast specification> are given by the following table. Y indicates that the
* combination is syntactically valid without restriction; M indicates that the combination
* is valid subject to other Syntax Rules in this Subclause being satisfied; and N indicates
* that the combination is not valid:
* SD TD
* EN AN C D T TS YM DT BO UDT B RT CT RW
* EN Y Y Y N N N M M N M N M N N
* AN Y Y Y N N N N N N M N M N N
* C Y Y Y Y Y Y Y Y Y M N M N N
* D N N Y Y N Y N N N M N M N N
* T N N Y N Y Y N N N M N M N N
* TS N N Y Y Y Y N N N M N M N N
* YM M N Y N N N Y N N M N M N N
* DT M N Y N N N N Y N M N M N N
* BO N N Y N N N N N Y M N M N N
* UDT M M M M M M M M M M M M M N
* B N N N N N N N N N M Y M N N
* RT M M M M M M M M M M M M N N
* CT N N N N N N N N N M N N M N
* RW N N N N N N N N N N N N N M
*
* Where:
* EN = Exact Numeric
* AN = Approximate Numeric
* C = Character (Fixed- or Variable-Length, or Character Large Object)
* D = Date
* T = Time
* TS = Timestamp
* YM = Year-Month Interval
* DT = Day-Time Interval
* BO = Boolean
* UDT = User-Defined Type
* B = Binary (Fixed- or Variable-Length or Binary Large Object)
* RT = Reference type
* CT = Collection type
* RW = Row type
*
* Spark's ANSI mode follows the syntax rules, except that it specially allows the following
* straightforward type conversions which are disallowed as per the SQL standard:
* - Numeric <=> Boolean
* - String <=> Binary
*/
def canCast(from: DataType, to: DataType): Boolean = (from, to) match {
case (fromType, toType) if fromType == toType => true
case (NullType, _) => true
case (StringType, _: BinaryType) => true
case (StringType, BooleanType) => true
case (_: NumericType, BooleanType) => true
case (StringType, TimestampType) => true
case (DateType, TimestampType) => true
case (StringType, _: CalendarIntervalType) => true
case (StringType, DateType) => true
case (TimestampType, DateType) => true
case (_: NumericType, _: NumericType) => true
case (StringType, _: NumericType) => true
case (BooleanType, _: NumericType) => true
case (_: NumericType, StringType) => true
case (_: DateType, StringType) => true
case (_: TimestampType, StringType) => true
case (_: CalendarIntervalType, StringType) => true
case (BooleanType, StringType) => true
case (BinaryType, StringType) => true
case (ArrayType(fromType, fn), ArrayType(toType, tn)) =>
canCast(fromType, toType) &&
resolvableNullability(fn || forceNullable(fromType, toType), tn)
case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
canCast(fromKey, toKey) &&
(!forceNullable(fromKey, toKey)) &&
canCast(fromValue, toValue) &&
resolvableNullability(fn || forceNullable(fromValue, toValue), tn)
case (StructType(fromFields), StructType(toFields)) =>
fromFields.length == toFields.length &&
fromFields.zip(toFields).forall {
case (fromField, toField) =>
canCast(fromField.dataType, toField.dataType) &&
resolvableNullability(
fromField.nullable || forceNullable(fromField.dataType, toField.dataType),
toField.nullable)
}
case (udt1: UserDefinedType[_], udt2: UserDefinedType[_]) if udt2.acceptsType(udt1) => true
case _ => false
}
}
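The container cases in `canCast` above compose element castability with nullability checks: an array (or map/struct field) cast is legal only if the element cast is legal and a possibly-null source element never flows into a non-nullable target slot. A self-contained toy sketch of the array rule follows; the helper names mirror Spark's `forceNullable`/`resolvableNullability`, but the types and the trivial `forceNullable` here are illustrative assumptions, not Spark's implementation.

```scala
// Toy model of the array rule: legal iff the element cast is legal AND
// nullability is preserved (nullable source => nullable target).
sealed trait Ty
case object IntTy extends Ty
case object StrTy extends Ty
case class ArrTy(elem: Ty, containsNull: Boolean) extends Ty

// Can the element cast itself introduce nulls? Assumed "no" in this sketch.
def forceNullable(from: Ty, to: Ty): Boolean = false

// A nullable source element is only resolvable into a nullable target slot.
def resolvableNullability(fromNullable: Boolean, toNullable: Boolean): Boolean =
  !fromNullable || toNullable

def canCast(from: Ty, to: Ty): Boolean = (from, to) match {
  case (f, t) if f == t => true
  case (IntTy, StrTy) | (StrTy, IntTy) => true
  case (ArrTy(fe, fn), ArrTy(te, tn)) =>
    canCast(fe, te) && resolvableNullability(fn || forceNullable(fe, te), tn)
  case _ => false
}
```

So `array<int>` with nullable elements casts to `array<string>` only when the target also allows null elements; the map and struct cases in the real code apply the same check per key, value, and field.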
/**


@@ -38,9 +38,6 @@ import org.apache.spark.unsafe.types.UTF8String
abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
// Whether it is required to set SQLConf.ANSI_ENABLED as true for testing numeric overflow.
protected def requiredAnsiEnabledForOverflowTestCases: Boolean
protected def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): CastBase
// expected cannot be null
@@ -55,8 +52,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
test("null cast") {
import DataTypeTestUtils._
// follow [[org.apache.spark.sql.catalyst.expressions.Cast.canCast]] logic
// to ensure we test every possible cast situation here
atomicTypes.zip(atomicTypes).foreach { case (from, to) =>
checkNullCast(from, to)
}
@@ -65,14 +60,10 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
atomicTypes.foreach(dt => checkNullCast(dt, StringType))
checkNullCast(StringType, BinaryType)
checkNullCast(StringType, BooleanType)
checkNullCast(DateType, BooleanType)
checkNullCast(TimestampType, BooleanType)
numericTypes.foreach(dt => checkNullCast(dt, BooleanType))
checkNullCast(StringType, TimestampType)
checkNullCast(BooleanType, TimestampType)
checkNullCast(DateType, TimestampType)
numericTypes.foreach(dt => checkNullCast(dt, TimestampType))
checkNullCast(StringType, DateType)
checkNullCast(TimestampType, DateType)
@@ -80,8 +71,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
checkNullCast(StringType, CalendarIntervalType)
numericTypes.foreach(dt => checkNullCast(StringType, dt))
numericTypes.foreach(dt => checkNullCast(BooleanType, dt))
numericTypes.foreach(dt => checkNullCast(DateType, dt))
numericTypes.foreach(dt => checkNullCast(TimestampType, dt))
for (from <- numericTypes; to <- numericTypes) checkNullCast(from, to)
}
@@ -215,6 +204,39 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(cast(cast(0, BooleanType), IntegerType), 0)
}
test("cast from int") {
checkCast(0, false)
checkCast(1, true)
checkCast(-5, true)
checkCast(1, 1.toByte)
checkCast(1, 1.toShort)
checkCast(1, 1)
checkCast(1, 1.toLong)
checkCast(1, 1.0f)
checkCast(1, 1.0)
checkCast(123, "123")
checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123))
checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123))
checkEvaluation(cast(1, LongType), 1.toLong)
}
test("cast from long") {
checkCast(0L, false)
checkCast(1L, true)
checkCast(-5L, true)
checkCast(1L, 1.toByte)
checkCast(1L, 1.toShort)
checkCast(1L, 1)
checkCast(1L, 1.toLong)
checkCast(1L, 1.0f)
checkCast(1L, 1.0)
checkCast(123L, "123")
checkEvaluation(cast(123L, DecimalType.USER_DEFAULT), Decimal(123))
checkEvaluation(cast(123L, DecimalType(3, 0)), Decimal(123))
}
test("cast from float") {
checkCast(0.0f, false)
checkCast(0.5f, true)
@@ -237,8 +259,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
checkCast(1.5, 1.toLong)
checkCast(1.5, 1.5f)
checkCast(1.5, "1.5")
checkEvaluation(cast(cast(1.toDouble, TimestampType), DoubleType), 1.toDouble)
}
test("cast from string") {
@@ -305,18 +325,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
cast(cast("5", ByteType), ShortType), IntegerType), FloatType), DoubleType), LongType),
5.toLong)
checkEvaluation(
cast(cast(cast(cast(cast(cast("5", ByteType), TimestampType),
DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType),
5.toShort)
checkEvaluation(
cast(cast(cast(cast(cast(cast("5", TimestampType, UTC_OPT), ByteType),
DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType),
null)
checkEvaluation(cast(cast(cast(cast(cast(cast("5", DecimalType.SYSTEM_DEFAULT),
ByteType), TimestampType), LongType), StringType), ShortType),
5.toShort)
checkEvaluation(cast("23", DoubleType), 23d)
checkEvaluation(cast("23", IntegerType), 23)
checkEvaluation(cast("23", FloatType), 23f)
@@ -350,58 +358,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
checkCast(Decimal(1.5), "1.5")
}
test("cast from date") {
val d = Date.valueOf("1970-01-01")
checkEvaluation(cast(d, ShortType), null)
checkEvaluation(cast(d, IntegerType), null)
checkEvaluation(cast(d, LongType), null)
checkEvaluation(cast(d, FloatType), null)
checkEvaluation(cast(d, DoubleType), null)
checkEvaluation(cast(d, DecimalType.SYSTEM_DEFAULT), null)
checkEvaluation(cast(d, DecimalType(10, 2)), null)
checkEvaluation(cast(d, StringType), "1970-01-01")
checkEvaluation(
cast(cast(d, TimestampType, UTC_OPT), StringType, UTC_OPT),
"1970-01-01 00:00:00")
}
test("cast from timestamp") {
val millis = 15 * 1000 + 3
val seconds = millis * 1000 + 3
val ts = new Timestamp(millis)
val tss = new Timestamp(seconds)
checkEvaluation(cast(ts, ShortType), 15.toShort)
checkEvaluation(cast(ts, IntegerType), 15)
checkEvaluation(cast(ts, LongType), 15.toLong)
checkEvaluation(cast(ts, FloatType), 15.003f)
checkEvaluation(cast(ts, DoubleType), 15.003)
checkEvaluation(cast(cast(tss, ShortType), TimestampType),
fromJavaTimestamp(ts) * MILLIS_PER_SECOND)
checkEvaluation(cast(cast(tss, IntegerType), TimestampType),
fromJavaTimestamp(ts) * MILLIS_PER_SECOND)
checkEvaluation(cast(cast(tss, LongType), TimestampType),
fromJavaTimestamp(ts) * MILLIS_PER_SECOND)
checkEvaluation(
cast(cast(millis.toFloat / MILLIS_PER_SECOND, TimestampType), FloatType),
millis.toFloat / MILLIS_PER_SECOND)
checkEvaluation(
cast(cast(millis.toDouble / MILLIS_PER_SECOND, TimestampType), DoubleType),
millis.toDouble / MILLIS_PER_SECOND)
checkEvaluation(
cast(cast(Decimal(1), TimestampType), DecimalType.SYSTEM_DEFAULT),
Decimal(1))
// A test for higher precision than millis
checkEvaluation(cast(cast(0.000001, TimestampType), DoubleType), 0.000001)
checkEvaluation(cast(Double.NaN, TimestampType), null)
checkEvaluation(cast(1.0 / 0.0, TimestampType), null)
checkEvaluation(cast(Float.NaN, TimestampType), null)
checkEvaluation(cast(1.0f / 0.0f, TimestampType), null)
}
test("cast from array") {
val array = Literal.create(Seq("123", "true", "f", null),
ArrayType(StringType, containsNull = true))
@@ -635,16 +591,20 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(cast("", BooleanType), null)
}
protected def checkInvalidCastFromNumericType(to: DataType): Unit = {
assert(cast(1.toByte, to).checkInputDataTypes().isFailure)
assert(cast(1.toShort, to).checkInputDataTypes().isFailure)
assert(cast(1, to).checkInputDataTypes().isFailure)
assert(cast(1L, to).checkInputDataTypes().isFailure)
assert(cast(1.0.toFloat, to).checkInputDataTypes().isFailure)
assert(cast(1.0, to).checkInputDataTypes().isFailure)
}
test("SPARK-16729 type checking for casting to date type") {
assert(cast("1234", DateType).checkInputDataTypes().isSuccess)
assert(cast(new Timestamp(1), DateType).checkInputDataTypes().isSuccess)
assert(cast(false, DateType).checkInputDataTypes().isFailure)
assert(cast(1.toByte, DateType).checkInputDataTypes().isFailure)
assert(cast(1.toShort, DateType).checkInputDataTypes().isFailure)
assert(cast(1, DateType).checkInputDataTypes().isFailure)
assert(cast(1L, DateType).checkInputDataTypes().isFailure)
assert(cast(1.0.toFloat, DateType).checkInputDataTypes().isFailure)
assert(cast(1.0, DateType).checkInputDataTypes().isFailure)
checkInvalidCastFromNumericType(DateType)
}
test("SPARK-20302 cast with same structure") {
@@ -686,117 +646,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
assert(ctx.inlinedMutableStates.length == 0)
}
test("SPARK-22825 Cast array to string") {
val ret1 = cast(Literal.create(Array(1, 2, 3, 4, 5)), StringType)
checkEvaluation(ret1, "[1, 2, 3, 4, 5]")
val ret2 = cast(Literal.create(Array("ab", "cde", "f")), StringType)
checkEvaluation(ret2, "[ab, cde, f]")
Seq(false, true).foreach { omitNull =>
withSQLConf(SQLConf.LEGACY_COMPLEX_TYPES_TO_STRING.key -> omitNull.toString) {
val ret3 = cast(Literal.create(Array("ab", null, "c")), StringType)
checkEvaluation(ret3, s"[ab,${if (omitNull) "" else " null"}, c]")
}
}
val ret4 =
cast(Literal.create(Array("ab".getBytes, "cde".getBytes, "f".getBytes)), StringType)
checkEvaluation(ret4, "[ab, cde, f]")
val ret5 = cast(
Literal.create(Array("2014-12-03", "2014-12-04", "2014-12-06").map(Date.valueOf)),
StringType)
checkEvaluation(ret5, "[2014-12-03, 2014-12-04, 2014-12-06]")
val ret6 = cast(
Literal.create(Array("2014-12-03 13:01:00", "2014-12-04 15:05:00")
.map(Timestamp.valueOf)),
StringType)
checkEvaluation(ret6, "[2014-12-03 13:01:00, 2014-12-04 15:05:00]")
val ret7 = cast(Literal.create(Array(Array(1, 2, 3), Array(4, 5))), StringType)
checkEvaluation(ret7, "[[1, 2, 3], [4, 5]]")
val ret8 = cast(
Literal.create(Array(Array(Array("a"), Array("b", "c")), Array(Array("d")))),
StringType)
checkEvaluation(ret8, "[[[a], [b, c]], [[d]]]")
}
test("SPARK-33291: Cast array with null elements to string") {
Seq(false, true).foreach { omitNull =>
withSQLConf(SQLConf.LEGACY_COMPLEX_TYPES_TO_STRING.key -> omitNull.toString) {
val ret1 = cast(Literal.create(Array(null, null)), StringType)
checkEvaluation(
ret1,
s"[${if (omitNull) "" else "null"},${if (omitNull) "" else " null"}]")
}
}
}
test("SPARK-22973 Cast map to string") {
Seq(
false -> ("{", "}"),
true -> ("[", "]")).foreach { case (legacyCast, (lb, rb)) =>
withSQLConf(SQLConf.LEGACY_COMPLEX_TYPES_TO_STRING.key -> legacyCast.toString) {
val ret1 = cast(Literal.create(Map(1 -> "a", 2 -> "b", 3 -> "c")), StringType)
checkEvaluation(ret1, s"${lb}1 -> a, 2 -> b, 3 -> c$rb")
val ret2 = cast(
Literal.create(Map("1" -> "a".getBytes, "2" -> null, "3" -> "c".getBytes)),
StringType)
checkEvaluation(ret2, s"${lb}1 -> a, 2 ->${if (legacyCast) "" else " null"}, 3 -> c$rb")
val ret3 = cast(
Literal.create(Map(
1 -> Date.valueOf("2014-12-03"),
2 -> Date.valueOf("2014-12-04"),
3 -> Date.valueOf("2014-12-05"))),
StringType)
checkEvaluation(ret3, s"${lb}1 -> 2014-12-03, 2 -> 2014-12-04, 3 -> 2014-12-05$rb")
val ret4 = cast(
Literal.create(Map(
1 -> Timestamp.valueOf("2014-12-03 13:01:00"),
2 -> Timestamp.valueOf("2014-12-04 15:05:00"))),
StringType)
checkEvaluation(ret4, s"${lb}1 -> 2014-12-03 13:01:00, 2 -> 2014-12-04 15:05:00$rb")
val ret5 = cast(
Literal.create(Map(
1 -> Array(1, 2, 3),
2 -> Array(4, 5, 6))),
StringType)
checkEvaluation(ret5, s"${lb}1 -> [1, 2, 3], 2 -> [4, 5, 6]$rb")
}
}
}
test("SPARK-22981 Cast struct to string") {
Seq(
false -> ("{", "}"),
true -> ("[", "]")).foreach { case (legacyCast, (lb, rb)) =>
withSQLConf(SQLConf.LEGACY_COMPLEX_TYPES_TO_STRING.key -> legacyCast.toString) {
val ret1 = cast(Literal.create((1, "a", 0.1)), StringType)
checkEvaluation(ret1, s"${lb}1, a, 0.1$rb")
val ret2 = cast(Literal.create(Tuple3[Int, String, String](1, null, "a")), StringType)
checkEvaluation(ret2, s"${lb}1,${if (legacyCast) "" else " null"}, a$rb")
val ret3 = cast(Literal.create(
(Date.valueOf("2014-12-03"), Timestamp.valueOf("2014-12-03 15:05:00"))), StringType)
checkEvaluation(ret3, s"${lb}2014-12-03, 2014-12-03 15:05:00$rb")
val ret4 = cast(Literal.create(((1, "a"), 5, 0.1)), StringType)
checkEvaluation(ret4, s"$lb${lb}1, a$rb, 5, 0.1$rb")
val ret5 = cast(Literal.create((Seq(1, 2, 3), "a", 0.1)), StringType)
checkEvaluation(ret5, s"$lb[1, 2, 3], a, 0.1$rb")
val ret6 = cast(Literal.create((1, Map(1 -> "a", 2 -> "b", 3 -> "c"))), StringType)
checkEvaluation(ret6, s"${lb}1, ${lb}1 -> a, 2 -> b, 3 -> c$rb$rb")
}
}
}
test("SPARK-33291: Cast struct with null elements to string") {
Seq(
false -> ("{", "}"),
true -> ("[", "]")).foreach { case (legacyCast, (lb, rb)) =>
withSQLConf(SQLConf.LEGACY_COMPLEX_TYPES_TO_STRING.key -> legacyCast.toString) {
val ret1 = cast(Literal.create(Tuple2[String, String](null, null)), StringType)
checkEvaluation(
ret1,
s"$lb${if (legacyCast) "" else "null"},${if (legacyCast) "" else " null"}$rb")
}
}
}
test("up-cast") {
def isCastSafe(from: NumericType, to: NumericType): Boolean = (from, to) match {
case (_, dt: DecimalType) => dt.isWiderThan(from)
@@ -869,20 +718,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
}
}
test("Throw exception on casting out-of-range value to decimal type") {
withSQLConf(SQLConf.ANSI_ENABLED.key -> requiredAnsiEnabledForOverflowTestCases.toString) {
checkExceptionInExpression[ArithmeticException](
cast(Literal("134.12"), DecimalType(3, 2)), "cannot be represented")
checkExceptionInExpression[ArithmeticException](
cast(Literal(Timestamp.valueOf("2019-07-25 22:04:36")), DecimalType(3, 2)),
"cannot be represented")
checkExceptionInExpression[ArithmeticException](
cast(Literal(BigDecimal(134.12)), DecimalType(3, 2)), "cannot be represented")
checkExceptionInExpression[ArithmeticException](
cast(Literal(134.12), DecimalType(3, 2)), "cannot be represented")
}
}
test("Process Infinity, -Infinity, NaN in case insensitive manner") {
Seq("inf", "+inf", "infinity", "+infiNity", " infinity ").foreach { value =>
checkEvaluation(cast(value, FloatType), Float.PositiveInfinity)
@@ -903,14 +738,15 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(cast(value, DoubleType), Double.NaN)
}
}
}
abstract class AnsiCastSuiteBase extends CastSuiteBase {
private def testIntMaxAndMin(dt: DataType): Unit = {
assert(Seq(IntegerType, ShortType, ByteType).contains(dt))
Seq(Int.MaxValue + 1L, Int.MinValue - 1L).foreach { value =>
checkExceptionInExpression[ArithmeticException](cast(value, dt), "overflow")
checkExceptionInExpression[ArithmeticException](cast(Decimal(value.toString), dt), "overflow")
checkExceptionInExpression[ArithmeticException](
cast(Literal(value * MICROS_PER_SECOND, TimestampType), dt), "overflow")
checkExceptionInExpression[ArithmeticException](
cast(Literal(value * 1.5f, FloatType), dt), "overflow")
checkExceptionInExpression[ArithmeticException](
@@ -930,98 +766,191 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
}
}
test("Throw exception on casting out-of-range value to byte type") {
withSQLConf(SQLConf.ANSI_ENABLED.key -> requiredAnsiEnabledForOverflowTestCases.toString) {
testIntMaxAndMin(ByteType)
Seq(Byte.MaxValue + 1, Byte.MinValue - 1).foreach { value =>
checkExceptionInExpression[ArithmeticException](cast(value, ByteType), "overflow")
checkExceptionInExpression[ArithmeticException](
cast(Literal(value * MICROS_PER_SECOND, TimestampType), ByteType), "overflow")
checkExceptionInExpression[ArithmeticException](
cast(Literal(value.toFloat, FloatType), ByteType), "overflow")
checkExceptionInExpression[ArithmeticException](
cast(Literal(value.toDouble, DoubleType), ByteType), "overflow")
}
test("ANSI mode: Throw exception on casting out-of-range value to byte type") {
testIntMaxAndMin(ByteType)
Seq(Byte.MaxValue + 1, Byte.MinValue - 1).foreach { value =>
checkExceptionInExpression[ArithmeticException](cast(value, ByteType), "overflow")
checkExceptionInExpression[ArithmeticException](
cast(Literal(value.toFloat, FloatType), ByteType), "overflow")
checkExceptionInExpression[ArithmeticException](
cast(Literal(value.toDouble, DoubleType), ByteType), "overflow")
}
Seq(Byte.MaxValue, 0.toByte, Byte.MinValue).foreach { value =>
checkEvaluation(cast(value, ByteType), value)
checkEvaluation(cast(value.toString, ByteType), value)
checkEvaluation(cast(Decimal(value.toString), ByteType), value)
checkEvaluation(cast(Literal(value * MICROS_PER_SECOND, TimestampType), ByteType), value)
checkEvaluation(cast(Literal(value.toInt, DateType), ByteType), null)
checkEvaluation(cast(Literal(value.toFloat, FloatType), ByteType), value)
checkEvaluation(cast(Literal(value.toDouble, DoubleType), ByteType), value)
}
Seq(Byte.MaxValue, 0.toByte, Byte.MinValue).foreach { value =>
checkEvaluation(cast(value, ByteType), value)
checkEvaluation(cast(value.toString, ByteType), value)
checkEvaluation(cast(Decimal(value.toString), ByteType), value)
checkEvaluation(cast(Literal(value.toFloat, FloatType), ByteType), value)
checkEvaluation(cast(Literal(value.toDouble, DoubleType), ByteType), value)
}
}
test("Throw exception on casting out-of-range value to short type") {
withSQLConf(SQLConf.ANSI_ENABLED.key -> requiredAnsiEnabledForOverflowTestCases.toString) {
testIntMaxAndMin(ShortType)
Seq(Short.MaxValue + 1, Short.MinValue - 1).foreach { value =>
checkExceptionInExpression[ArithmeticException](cast(value, ShortType), "overflow")
checkExceptionInExpression[ArithmeticException](
cast(Literal(value * MICROS_PER_SECOND, TimestampType), ShortType), "overflow")
checkExceptionInExpression[ArithmeticException](
cast(Literal(value.toFloat, FloatType), ShortType), "overflow")
checkExceptionInExpression[ArithmeticException](
cast(Literal(value.toDouble, DoubleType), ShortType), "overflow")
}
test("ANSI mode: Throw exception on casting out-of-range value to short type") {
testIntMaxAndMin(ShortType)
Seq(Short.MaxValue + 1, Short.MinValue - 1).foreach { value =>
checkExceptionInExpression[ArithmeticException](cast(value, ShortType), "overflow")
checkExceptionInExpression[ArithmeticException](
cast(Literal(value.toFloat, FloatType), ShortType), "overflow")
checkExceptionInExpression[ArithmeticException](
cast(Literal(value.toDouble, DoubleType), ShortType), "overflow")
}
Seq(Short.MaxValue, 0.toShort, Short.MinValue).foreach { value =>
checkEvaluation(cast(value, ShortType), value)
checkEvaluation(cast(value.toString, ShortType), value)
checkEvaluation(cast(Decimal(value.toString), ShortType), value)
checkEvaluation(cast(Literal(value * MICROS_PER_SECOND, TimestampType), ShortType), value)
checkEvaluation(cast(Literal(value.toInt, DateType), ShortType), null)
checkEvaluation(cast(Literal(value.toFloat, FloatType), ShortType), value)
checkEvaluation(cast(Literal(value.toDouble, DoubleType), ShortType), value)
}
Seq(Short.MaxValue, 0.toShort, Short.MinValue).foreach { value =>
checkEvaluation(cast(value, ShortType), value)
checkEvaluation(cast(value.toString, ShortType), value)
checkEvaluation(cast(Decimal(value.toString), ShortType), value)
checkEvaluation(cast(Literal(value.toFloat, FloatType), ShortType), value)
checkEvaluation(cast(Literal(value.toDouble, DoubleType), ShortType), value)
}
}
test("Throw exception on casting out-of-range value to int type") {
withSQLConf(SQLConf.ANSI_ENABLED.key -> requiredAnsiEnabledForOverflowTestCases.toString) {
testIntMaxAndMin(IntegerType)
testLongMaxAndMin(IntegerType)
test("ANSI mode: Throw exception on casting out-of-range value to int type") {
testIntMaxAndMin(IntegerType)
testLongMaxAndMin(IntegerType)
Seq(Int.MaxValue, 0, Int.MinValue).foreach { value =>
checkEvaluation(cast(value, IntegerType), value)
checkEvaluation(cast(value.toString, IntegerType), value)
checkEvaluation(cast(Decimal(value.toString), IntegerType), value)
checkEvaluation(cast(Literal(value * MICROS_PER_SECOND, TimestampType), IntegerType), value)
checkEvaluation(cast(Literal(value * 1.0, DoubleType), IntegerType), value)
}
checkEvaluation(cast(Int.MaxValue + 0.9D, IntegerType), Int.MaxValue)
checkEvaluation(cast(Int.MinValue - 0.9D, IntegerType), Int.MinValue)
Seq(Int.MaxValue, 0, Int.MinValue).foreach { value =>
checkEvaluation(cast(value, IntegerType), value)
checkEvaluation(cast(value.toString, IntegerType), value)
checkEvaluation(cast(Decimal(value.toString), IntegerType), value)
checkEvaluation(cast(Literal(value * 1.0, DoubleType), IntegerType), value)
}
checkEvaluation(cast(Int.MaxValue + 0.9D, IntegerType), Int.MaxValue)
checkEvaluation(cast(Int.MinValue - 0.9D, IntegerType), Int.MinValue)
}
test("ANSI mode: Throw exception on casting out-of-range value to long type") {
testLongMaxAndMin(LongType)
Seq(Long.MaxValue, 0, Long.MinValue).foreach { value =>
checkEvaluation(cast(value, LongType), value)
checkEvaluation(cast(value.toString, LongType), value)
checkEvaluation(cast(Decimal(value.toString), LongType), value)
}
checkEvaluation(cast(Long.MaxValue + 0.9F, LongType), Long.MaxValue)
checkEvaluation(cast(Long.MinValue - 0.9F, LongType), Long.MinValue)
checkEvaluation(cast(Long.MaxValue + 0.9D, LongType), Long.MaxValue)
checkEvaluation(cast(Long.MinValue - 0.9D, LongType), Long.MinValue)
}
test("ANSI mode: Throw exception on casting out-of-range value to decimal type") {
checkExceptionInExpression[ArithmeticException](
cast(Literal("134.12"), DecimalType(3, 2)), "cannot be represented")
checkExceptionInExpression[ArithmeticException](
cast(Literal(BigDecimal(134.12)), DecimalType(3, 2)), "cannot be represented")
checkExceptionInExpression[ArithmeticException](
cast(Literal(134.12), DecimalType(3, 2)), "cannot be represented")
}
test("ANSI mode: disallow type conversions between Numeric types and Timestamp type") {
import DataTypeTestUtils.numericTypes
checkInvalidCastFromNumericType(TimestampType)
val timestampLiteral = Literal(1L, TimestampType)
numericTypes.foreach { numericType =>
assert(cast(timestampLiteral, numericType).checkInputDataTypes().isFailure)
}
}
test("Throw exception on casting out-of-range value to long type") {
withSQLConf(SQLConf.ANSI_ENABLED.key -> requiredAnsiEnabledForOverflowTestCases.toString) {
testLongMaxAndMin(LongType)
Seq(Long.MaxValue, 0, Long.MinValue).foreach { value =>
checkEvaluation(cast(value, LongType), value)
checkEvaluation(cast(value.toString, LongType), value)
checkEvaluation(cast(Decimal(value.toString), LongType), value)
checkEvaluation(cast(Literal(value, TimestampType), LongType),
Math.floorDiv(value, MICROS_PER_SECOND))
}
checkEvaluation(cast(Long.MaxValue + 0.9F, LongType), Long.MaxValue)
checkEvaluation(cast(Long.MinValue - 0.9F, LongType), Long.MinValue)
checkEvaluation(cast(Long.MaxValue + 0.9D, LongType), Long.MaxValue)
checkEvaluation(cast(Long.MinValue - 0.9D, LongType), Long.MinValue)
test("ANSI mode: disallow type conversions between Numeric types and Date type") {
import DataTypeTestUtils.numericTypes
checkInvalidCastFromNumericType(DateType)
val dateLiteral = Literal(1, DateType)
numericTypes.foreach { numericType =>
assert(cast(dateLiteral, numericType).checkInputDataTypes().isFailure)
}
}
test("ANSI mode: disallow type conversions between Numeric types and Binary type") {
import DataTypeTestUtils.numericTypes
checkInvalidCastFromNumericType(BinaryType)
val binaryLiteral = Literal(new Array[Byte](1.toByte), BinaryType)
numericTypes.foreach { numericType =>
assert(cast(binaryLiteral, numericType).checkInputDataTypes().isFailure)
}
}
test("ANSI mode: disallow type conversions between Datetime types and Boolean types") {
val timestampLiteral = Literal(1L, TimestampType)
assert(cast(timestampLiteral, BooleanType).checkInputDataTypes().isFailure)
val dateLiteral = Literal(1, DateType)
assert(cast(dateLiteral, BooleanType).checkInputDataTypes().isFailure)
val booleanLiteral = Literal(true, BooleanType)
assert(cast(booleanLiteral, TimestampType).checkInputDataTypes().isFailure)
assert(cast(booleanLiteral, DateType).checkInputDataTypes().isFailure)
}
test("ANSI mode: disallow casting complex types as String type") {
assert(cast(Literal.create(Array(1, 2, 3, 4, 5)), StringType).checkInputDataTypes().isFailure)
assert(cast(Literal.create(Map(1 -> "a")), StringType).checkInputDataTypes().isFailure)
assert(cast(Literal.create((1, "a", 0.1)), StringType).checkInputDataTypes().isFailure)
}
test("cast from invalid string to numeric should throw NumberFormatException") {
// cast to IntegerType
Seq(IntegerType, ShortType, ByteType, LongType).foreach { dataType =>
val array = Literal.create(Seq("123", "true", "f", null),
ArrayType(StringType, containsNull = true))
checkExceptionInExpression[NumberFormatException](
cast(array, ArrayType(dataType, containsNull = true)),
"invalid input syntax for type numeric: true")
checkExceptionInExpression[NumberFormatException](
cast("string", dataType), "invalid input syntax for type numeric: string")
checkExceptionInExpression[NumberFormatException](
cast("123-string", dataType), "invalid input syntax for type numeric: 123-string")
checkExceptionInExpression[NumberFormatException](
cast("2020-07-19", dataType), "invalid input syntax for type numeric: 2020-07-19")
checkExceptionInExpression[NumberFormatException](
cast("1.23", dataType), "invalid input syntax for type numeric: 1.23")
}
Seq(DoubleType, FloatType, DecimalType.USER_DEFAULT).foreach { dataType =>
checkExceptionInExpression[NumberFormatException](
cast("string", dataType), "invalid input syntax for type numeric: string")
checkExceptionInExpression[NumberFormatException](
cast("123.000.00", dataType), "invalid input syntax for type numeric: 123.000.00")
checkExceptionInExpression[NumberFormatException](
cast("abc.com", dataType), "invalid input syntax for type numeric: abc.com")
}
}
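The rejected strings above mirror plain JVM integral parsing; ANSI mode surfaces the parse failure instead of returning null. A hedged standalone sketch (the `parsesAsLong` helper is illustrative, not Spark's API):

```scala
// Strings that Spark's ANSI cast rejects also fail java.lang.Long.parseLong:
// fractional text, date-formatted text, and trailing garbage are not integrals.
object StrictIntegralParse {
  def parsesAsLong(s: String): Boolean =
    try { java.lang.Long.parseLong(s.trim); true }
    catch { case _: NumberFormatException => false }

  def main(args: Array[String]): Unit = {
    assert(parsesAsLong("123"))
    assert(!parsesAsLong("1.23"))        // fractional string
    assert(!parsesAsLong("2020-07-19"))  // date-formatted string
    assert(!parsesAsLong("123-string"))  // trailing garbage
  }
}
```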
test("Fast fail for cast string type to decimal type in ansi mode") {
checkEvaluation(cast("12345678901234567890123456789012345678", DecimalType(38, 0)),
Decimal("12345678901234567890123456789012345678"))
checkExceptionInExpression[ArithmeticException](
cast("123456789012345678901234567890123456789", DecimalType(38, 0)),
"out of decimal type range")
checkExceptionInExpression[ArithmeticException](
cast("12345678901234567890123456789012345678", DecimalType(38, 1)),
"cannot be represented as Decimal(38, 1)")
checkEvaluation(cast("0.00000000000000000000000000000000000001", DecimalType(38, 0)),
Decimal("0"))
checkEvaluation(cast("0.00000000000000000000000000000000000000000001", DecimalType(38, 0)),
Decimal("0"))
checkEvaluation(cast("0.00000000000000000000000000000000000001", DecimalType(38, 18)),
Decimal("0E-18"))
checkEvaluation(cast("6E-120", DecimalType(38, 0)),
Decimal("0"))
checkEvaluation(cast("6E+37", DecimalType(38, 0)),
Decimal("60000000000000000000000000000000000000"))
checkExceptionInExpression[ArithmeticException](
cast("6E+38", DecimalType(38, 0)),
"out of decimal type range")
checkExceptionInExpression[ArithmeticException](
cast("6E+37", DecimalType(38, 1)),
"cannot be represented as Decimal(38, 1)")
checkExceptionInExpression[NumberFormatException](
cast("abcd", DecimalType(38, 1)),
"invalid input syntax for type numeric")
}
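The range checks above amount to a precision test on the parsed literal: a Decimal(38, 0) holds at most 38 integral digits. A standalone sketch using java.math.BigDecimal (the helper is illustrative, not Spark's Decimal implementation):

```scala
import java.math.{BigDecimal => JBigDecimal, RoundingMode}

// A value fits Decimal(38, 0) iff, after rounding to scale 0,
// it needs no more than 38 significant digits.
object DecimalRangeSketch {
  def fitsDecimal38_0(s: String): Boolean =
    new JBigDecimal(s).setScale(0, RoundingMode.HALF_UP).precision <= 38

  def main(args: Array[String]): Unit = {
    assert(fitsDecimal38_0("12345678901234567890123456789012345678"))    // 38 digits
    assert(!fitsDecimal38_0("123456789012345678901234567890123456789"))  // 39 digits
    assert(fitsDecimal38_0("6E+37"))   // 6 followed by 37 zeros: 38 digits
    assert(!fitsDecimal38_0("6E+38"))  // 39 digits: out of range
  }
}
```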
}
/**
* Test suite for data type casting expression [[Cast]].
*/
class CastSuite extends CastSuiteBase {
// It is required to set SQLConf.ANSI_ENABLED as true for testing numeric overflow.
override protected def requiredAnsiEnabledForOverflowTestCases: Boolean = true
override def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): CastBase = {
v match {
@@ -1030,51 +959,26 @@ class CastSuite extends CastSuiteBase {
}
}
test("cast from int") {
checkCast(0, false)
checkCast(1, true)
checkCast(-5, true)
checkCast(1, 1.toByte)
checkCast(1, 1.toShort)
checkCast(1, 1)
checkCast(1, 1.toLong)
checkCast(1, 1.0f)
checkCast(1, 1.0)
checkCast(123, "123")
checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123))
checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123))
checkEvaluation(cast(123, DecimalType(3, 1)), null)
checkEvaluation(cast(123, DecimalType(2, 0)), null)
}
test("null cast #2") {
import DataTypeTestUtils._
checkNullCast(DateType, BooleanType)
checkNullCast(TimestampType, BooleanType)
checkNullCast(BooleanType, TimestampType)
numericTypes.foreach(dt => checkNullCast(dt, TimestampType))
numericTypes.foreach(dt => checkNullCast(TimestampType, dt))
numericTypes.foreach(dt => checkNullCast(DateType, dt))
}
test("cast from long") {
checkCast(0L, false)
checkCast(1L, true)
checkCast(-5L, true)
checkCast(1L, 1.toByte)
checkCast(1L, 1.toShort)
checkCast(1L, 1)
checkCast(1L, 1.toLong)
checkCast(1L, 1.0f)
checkCast(1L, 1.0)
checkCast(123L, "123")
checkEvaluation(cast(123L, DecimalType.USER_DEFAULT), Decimal(123))
checkEvaluation(cast(123L, DecimalType(3, 0)), Decimal(123))
}
test("cast from long #2") {
checkEvaluation(cast(123L, DecimalType(3, 1)), null)
checkEvaluation(cast(123L, DecimalType(2, 0)), null)
}
test("cast from int #2") {
checkEvaluation(cast(1, LongType), 1.toLong)
checkEvaluation(cast(cast(1000, TimestampType), LongType), 1000.toLong)
checkEvaluation(cast(cast(-1200, TimestampType), LongType), -1200.toLong)
checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123))
checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123))
checkEvaluation(cast(123, DecimalType(3, 1)), null)
checkEvaluation(cast(123, DecimalType(2, 0)), null)
}
@@ -1343,6 +1247,58 @@ class CastSuite extends CastSuiteBase {
}
}
test("cast from date") {
val d = Date.valueOf("1970-01-01")
checkEvaluation(cast(d, ShortType), null)
checkEvaluation(cast(d, IntegerType), null)
checkEvaluation(cast(d, LongType), null)
checkEvaluation(cast(d, FloatType), null)
checkEvaluation(cast(d, DoubleType), null)
checkEvaluation(cast(d, DecimalType.SYSTEM_DEFAULT), null)
checkEvaluation(cast(d, DecimalType(10, 2)), null)
checkEvaluation(cast(d, StringType), "1970-01-01")
checkEvaluation(
cast(cast(d, TimestampType, UTC_OPT), StringType, UTC_OPT),
"1970-01-01 00:00:00")
}
test("cast from timestamp") {
val millis = 15 * 1000 + 3
val seconds = millis * 1000 + 3
val ts = new Timestamp(millis)
val tss = new Timestamp(seconds)
checkEvaluation(cast(ts, ShortType), 15.toShort)
checkEvaluation(cast(ts, IntegerType), 15)
checkEvaluation(cast(ts, LongType), 15.toLong)
checkEvaluation(cast(ts, FloatType), 15.003f)
checkEvaluation(cast(ts, DoubleType), 15.003)
checkEvaluation(cast(cast(tss, ShortType), TimestampType),
fromJavaTimestamp(ts) * MILLIS_PER_SECOND)
checkEvaluation(cast(cast(tss, IntegerType), TimestampType),
fromJavaTimestamp(ts) * MILLIS_PER_SECOND)
checkEvaluation(cast(cast(tss, LongType), TimestampType),
fromJavaTimestamp(ts) * MILLIS_PER_SECOND)
checkEvaluation(
cast(cast(millis.toFloat / MILLIS_PER_SECOND, TimestampType), FloatType),
millis.toFloat / MILLIS_PER_SECOND)
checkEvaluation(
cast(cast(millis.toDouble / MILLIS_PER_SECOND, TimestampType), DoubleType),
millis.toDouble / MILLIS_PER_SECOND)
checkEvaluation(
cast(cast(Decimal(1), TimestampType), DecimalType.SYSTEM_DEFAULT),
Decimal(1))
// A test for higher precision than millis
checkEvaluation(cast(cast(0.000001, TimestampType), DoubleType), 0.000001)
checkEvaluation(cast(Double.NaN, TimestampType), null)
checkEvaluation(cast(1.0 / 0.0, TimestampType), null)
checkEvaluation(cast(Float.NaN, TimestampType), null)
checkEvaluation(cast(1.0f / 0.0f, TimestampType), null)
}
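The timestamp-to-long conversions exercised above floor-divide microseconds by MICROS_PER_SECOND, so negative timestamps round toward negative infinity rather than toward zero. A plain-Scala sketch of that arithmetic (the constant's value is assumed to mirror the test's import):

```scala
// Floor division vs. truncating division for microsecond timestamps:
// -1 microsecond lies in second -1, not second 0.
object MicrosToSecondsSketch {
  val MICROS_PER_SECOND = 1000000L  // assumed value of the imported constant

  def main(args: Array[String]): Unit = {
    assert(Math.floorDiv(15003000L, MICROS_PER_SECOND) == 15L)  // 15.003s -> 15
    assert(Math.floorDiv(-1L, MICROS_PER_SECOND) == -1L)        // floors downward
    assert(-1L / MICROS_PER_SECOND == 0L)                       // plain / truncates
  }
}
```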
test("cast a timestamp before the epoch 1970-01-01 00:00:00Z") {
withDefaultTimeZone(UTC) {
val negativeTs = Timestamp.valueOf("1900-05-05 18:34:56.1")
@@ -1396,14 +1352,172 @@ class CastSuite extends CastSuiteBase {
checkEvaluation(cast("abcd", DecimalType(38, 1)), null)
}
test("SPARK-22825 Cast array to string") {
val ret1 = cast(Literal.create(Array(1, 2, 3, 4, 5)), StringType)
checkEvaluation(ret1, "[1, 2, 3, 4, 5]")
val ret2 = cast(Literal.create(Array("ab", "cde", "f")), StringType)
checkEvaluation(ret2, "[ab, cde, f]")
Seq(false, true).foreach { omitNull =>
withSQLConf(SQLConf.LEGACY_COMPLEX_TYPES_TO_STRING.key -> omitNull.toString) {
val ret3 = cast(Literal.create(Array("ab", null, "c")), StringType)
checkEvaluation(ret3, s"[ab,${if (omitNull) "" else " null"}, c]")
}
}
val ret4 =
cast(Literal.create(Array("ab".getBytes, "cde".getBytes, "f".getBytes)), StringType)
checkEvaluation(ret4, "[ab, cde, f]")
val ret5 = cast(
Literal.create(Array("2014-12-03", "2014-12-04", "2014-12-06").map(Date.valueOf)),
StringType)
checkEvaluation(ret5, "[2014-12-03, 2014-12-04, 2014-12-06]")
val ret6 = cast(
Literal.create(Array("2014-12-03 13:01:00", "2014-12-04 15:05:00")
.map(Timestamp.valueOf)),
StringType)
checkEvaluation(ret6, "[2014-12-03 13:01:00, 2014-12-04 15:05:00]")
val ret7 = cast(Literal.create(Array(Array(1, 2, 3), Array(4, 5))), StringType)
checkEvaluation(ret7, "[[1, 2, 3], [4, 5]]")
val ret8 = cast(
Literal.create(Array(Array(Array("a"), Array("b", "c")), Array(Array("d")))),
StringType)
checkEvaluation(ret8, "[[[a], [b, c]], [[d]]]")
}
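The expected strings above follow a simple recursive join with "[", ", ", "]" delimiters. A minimal sketch of that formatting (not Spark's actual codegen, which builds the result in UTF8String buffers):

```scala
// "[1, 2, 3, 4, 5]" and the nested "[[1, 2, 3], [4, 5]]" renderings
// are plain mkString joins applied level by level.
object ArrayToStringSketch {
  def main(args: Array[String]): Unit = {
    assert(Array(1, 2, 3, 4, 5).mkString("[", ", ", "]") == "[1, 2, 3, 4, 5]")
    val nested = Array(Array(1, 2, 3), Array(4, 5))
    val rendered = nested.map(_.mkString("[", ", ", "]")).mkString("[", ", ", "]")
    assert(rendered == "[[1, 2, 3], [4, 5]]")
  }
}
```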
test("SPARK-33291: Cast array with null elements to string") {
Seq(false, true).foreach { omitNull =>
withSQLConf(SQLConf.LEGACY_COMPLEX_TYPES_TO_STRING.key -> omitNull.toString) {
val ret1 = cast(Literal.create(Array(null, null)), StringType)
checkEvaluation(
ret1,
s"[${if (omitNull) "" else "null"},${if (omitNull) "" else " null"}]")
}
}
}
test("SPARK-22973 Cast map to string") {
Seq(
false -> ("{", "}"),
true -> ("[", "]")).foreach { case (legacyCast, (lb, rb)) =>
withSQLConf(SQLConf.LEGACY_COMPLEX_TYPES_TO_STRING.key -> legacyCast.toString) {
val ret1 = cast(Literal.create(Map(1 -> "a", 2 -> "b", 3 -> "c")), StringType)
checkEvaluation(ret1, s"${lb}1 -> a, 2 -> b, 3 -> c$rb")
val ret2 = cast(
Literal.create(Map("1" -> "a".getBytes, "2" -> null, "3" -> "c".getBytes)),
StringType)
checkEvaluation(ret2, s"${lb}1 -> a, 2 ->${if (legacyCast) "" else " null"}, 3 -> c$rb")
val ret3 = cast(
Literal.create(Map(
1 -> Date.valueOf("2014-12-03"),
2 -> Date.valueOf("2014-12-04"),
3 -> Date.valueOf("2014-12-05"))),
StringType)
checkEvaluation(ret3, s"${lb}1 -> 2014-12-03, 2 -> 2014-12-04, 3 -> 2014-12-05$rb")
val ret4 = cast(
Literal.create(Map(
1 -> Timestamp.valueOf("2014-12-03 13:01:00"),
2 -> Timestamp.valueOf("2014-12-04 15:05:00"))),
StringType)
checkEvaluation(ret4, s"${lb}1 -> 2014-12-03 13:01:00, 2 -> 2014-12-04 15:05:00$rb")
val ret5 = cast(
Literal.create(Map(
1 -> Array(1, 2, 3),
2 -> Array(4, 5, 6))),
StringType)
checkEvaluation(ret5, s"${lb}1 -> [1, 2, 3], 2 -> [4, 5, 6]$rb")
}
}
}
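The legacy flag above only changes the outer delimiters: "[...]" in legacy mode, "{...}" by default, with entries joined as "key -> value". A sketch of that rule (the `render` helper is illustrative, not Spark's implementation):

```scala
// Map-to-string formatting: delimiter pair selected by the legacy flag,
// entries rendered as "key -> value" in iteration order.
object MapToStringSketch {
  def render(entries: Seq[(Int, String)], legacy: Boolean): String = {
    val (lb, rb) = if (legacy) ("[", "]") else ("{", "}")
    entries.map { case (k, v) => s"$k -> $v" }.mkString(lb, ", ", rb)
  }

  def main(args: Array[String]): Unit = {
    val m = Seq(1 -> "a", 2 -> "b", 3 -> "c")
    assert(render(m, legacy = false) == "{1 -> a, 2 -> b, 3 -> c}")
    assert(render(m, legacy = true) == "[1 -> a, 2 -> b, 3 -> c]")
  }
}
```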
test("SPARK-22981 Cast struct to string") {
Seq(
false -> ("{", "}"),
true -> ("[", "]")).foreach { case (legacyCast, (lb, rb)) =>
withSQLConf(SQLConf.LEGACY_COMPLEX_TYPES_TO_STRING.key -> legacyCast.toString) {
val ret1 = cast(Literal.create((1, "a", 0.1)), StringType)
checkEvaluation(ret1, s"${lb}1, a, 0.1$rb")
val ret2 = cast(Literal.create(Tuple3[Int, String, String](1, null, "a")), StringType)
checkEvaluation(ret2, s"${lb}1,${if (legacyCast) "" else " null"}, a$rb")
val ret3 = cast(Literal.create(
(Date.valueOf("2014-12-03"), Timestamp.valueOf("2014-12-03 15:05:00"))), StringType)
checkEvaluation(ret3, s"${lb}2014-12-03, 2014-12-03 15:05:00$rb")
val ret4 = cast(Literal.create(((1, "a"), 5, 0.1)), StringType)
checkEvaluation(ret4, s"$lb${lb}1, a$rb, 5, 0.1$rb")
val ret5 = cast(Literal.create((Seq(1, 2, 3), "a", 0.1)), StringType)
checkEvaluation(ret5, s"$lb[1, 2, 3], a, 0.1$rb")
val ret6 = cast(Literal.create((1, Map(1 -> "a", 2 -> "b", 3 -> "c"))), StringType)
checkEvaluation(ret6, s"${lb}1, ${lb}1 -> a, 2 -> b, 3 -> c$rb$rb")
}
}
}
test("SPARK-33291: Cast struct with null elements to string") {
Seq(
false -> ("{", "}"),
true -> ("[", "]")).foreach { case (legacyCast, (lb, rb)) =>
withSQLConf(SQLConf.LEGACY_COMPLEX_TYPES_TO_STRING.key -> legacyCast.toString) {
val ret1 = cast(Literal.create(Tuple2[String, String](null, null)), StringType)
checkEvaluation(
ret1,
s"$lb${if (legacyCast) "" else "null"},${if (legacyCast) "" else " null"}$rb")
}
}
}
test("data type casting II") {
checkEvaluation(
cast(cast(cast(cast(cast(cast("5", ByteType), TimestampType),
DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType),
5.toShort)
checkEvaluation(
cast(cast(cast(cast(cast(cast("5", TimestampType, UTC_OPT), ByteType),
DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType),
null)
checkEvaluation(cast(cast(cast(cast(cast(cast("5", DecimalType.SYSTEM_DEFAULT),
ByteType), TimestampType), LongType), StringType), ShortType),
5.toShort)
}
test("Cast from double II") {
checkEvaluation(cast(cast(1.toDouble, TimestampType), DoubleType), 1.toDouble)
}
}
/**
* Test suite for data type casting expression [[Cast]] with ANSI mode enabled.
*/
class CastSuiteWithAnsiModeOn extends AnsiCastSuiteBase {
override def beforeAll(): Unit = {
super.beforeAll()
SQLConf.get.setConf(SQLConf.ANSI_ENABLED, true)
}
override def afterAll(): Unit = {
super.afterAll()
SQLConf.get.unsetConf(SQLConf.ANSI_ENABLED)
}
override def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): CastBase = {
v match {
case lit: Expression => Cast(lit, targetType, timeZoneId)
case _ => Cast(Literal(v), targetType, timeZoneId)
}
}
}
/**
* Test suite for data type casting expression [[AnsiCast]] with ANSI mode enabled.
*/
class AnsiCastSuiteWithAnsiModeOn extends AnsiCastSuiteBase {
override def beforeAll(): Unit = {
super.beforeAll()
SQLConf.get.setConf(SQLConf.ANSI_ENABLED, true)
}
override def afterAll(): Unit = {
super.afterAll()
SQLConf.get.unsetConf(SQLConf.ANSI_ENABLED)
}
override def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): CastBase = {
v match {
@@ -1411,78 +1525,26 @@ class AnsiCastSuite extends CastSuiteBase {
case _ => AnsiCast(Literal(v), targetType, timeZoneId)
}
}
}
/**
* Test suite for data type casting expression [[AnsiCast]] with ANSI mode disabled.
*/
class AnsiCastSuiteWithAnsiModeOff extends AnsiCastSuiteBase {
override def beforeAll(): Unit = {
super.beforeAll()
SQLConf.get.setConf(SQLConf.ANSI_ENABLED, false)
}
override def afterAll(): Unit = {
super.afterAll()
SQLConf.get.unsetConf(SQLConf.ANSI_ENABLED)
}
test("cast a timestamp before the epoch 1970-01-01 00:00:00Z") {
def errMsg(t: String): String = s"Casting -2198208303900000 to $t causes overflow"
withDefaultTimeZone(UTC) {
val negativeTs = Timestamp.valueOf("1900-05-05 18:34:56.1")
assert(negativeTs.getTime < 0)
val expectedSecs = Math.floorDiv(negativeTs.getTime, MILLIS_PER_SECOND)
checkExceptionInExpression[ArithmeticException](cast(negativeTs, ByteType), errMsg("byte"))
checkExceptionInExpression[ArithmeticException](cast(negativeTs, ShortType), errMsg("short"))
checkExceptionInExpression[ArithmeticException](cast(negativeTs, IntegerType), errMsg("int"))
checkEvaluation(cast(negativeTs, LongType), expectedSecs)
}
}
override def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): CastBase = {
v match {
case lit: Expression => AnsiCast(lit, targetType, timeZoneId)
case _ => AnsiCast(Literal(v), targetType, timeZoneId)
}
}
}


@@ -756,6 +756,47 @@ class InsertSuite extends DataSourceTest with SharedSparkSession {
}
}
test("SPARK-33354: Throw exceptions on inserting invalid cast with ANSI casting policy") {
withSQLConf(
SQLConf.STORE_ASSIGNMENT_POLICY.key -> SQLConf.StoreAssignmentPolicy.ANSI.toString) {
withTable("t") {
sql("CREATE TABLE t(i int, t timestamp) USING parquet")
val msg = intercept[AnalysisException] {
sql("INSERT INTO t VALUES (TIMESTAMP('2010-09-02 14:10:10'), 1)")
}.getMessage
assert(msg.contains("Cannot safely cast 'i': timestamp to int"))
assert(msg.contains("Cannot safely cast 't': int to timestamp"))
}
withTable("t") {
sql("CREATE TABLE t(i int, d date) USING parquet")
val msg = intercept[AnalysisException] {
sql("INSERT INTO t VALUES (date('2010-09-02'), 1)")
}.getMessage
assert(msg.contains("Cannot safely cast 'i': date to int"))
assert(msg.contains("Cannot safely cast 'd': int to date"))
}
withTable("t") {
sql("CREATE TABLE t(b boolean, t timestamp) USING parquet")
val msg = intercept[AnalysisException] {
sql("INSERT INTO t VALUES (TIMESTAMP('2010-09-02 14:10:10'), true)")
}.getMessage
assert(msg.contains("Cannot safely cast 'b': timestamp to boolean"))
assert(msg.contains("Cannot safely cast 't': boolean to timestamp"))
}
withTable("t") {
sql("CREATE TABLE t(b boolean, d date) USING parquet")
val msg = intercept[AnalysisException] {
sql("INSERT INTO t VALUES (date('2010-09-02'), true)")
}.getMessage
assert(msg.contains("Cannot safely cast 'b': date to boolean"))
assert(msg.contains("Cannot safely cast 'd': boolean to date"))
}
}
}
test("SPARK-30844: static partition should also follow StoreAssignmentPolicy") {
SQLConf.StoreAssignmentPolicy.values.foreach { policy =>
withSQLConf(