[SPARK-30292][SQL] Throw Exception when invalid string is cast to numeric type in ANSI mode
### What changes were proposed in this pull request? If spark.sql.ansi.enabled is set, throw exception when cast to any numeric type do not follow the ANSI SQL standards. ### Why are the changes needed? ANSI SQL standards do not allow invalid strings to get casted into numeric types and throw exception for that. Currently spark sql gives NULL in such cases. Before: `select cast('str' as decimal) => NULL` After : `select cast('str' as decimal) => invalid input syntax for type numeric: str` These results are after setting `spark.sql.ansi.enabled=true` ### Does this PR introduce any user-facing change? Yes. Now when ansi mode is on users will get arithmetic exception for invalid strings. ### How was this patch tested? Unit Tests Added. Closes #26933 from iRakson/castDecimalANSI. Lead-authored-by: root1 <raksonrakesh@gmail.com> Co-authored-by: iRakson <raksonrakesh@gmail.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
88fc8dbc09
commit
e0efd213eb
|
@ -1294,6 +1294,52 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
|
|||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parses UTF8String(trimmed if needed) to long. This method is used when ANSI is enabled.
|
||||
*
|
||||
* @return If string contains valid numeric value then it returns the long value otherwise a
|
||||
* NumberFormatException is thrown.
|
||||
*/
|
||||
public long toLongExact() {
|
||||
LongWrapper result = new LongWrapper();
|
||||
if (toLong(result)) {
|
||||
return result.value;
|
||||
}
|
||||
throw new NumberFormatException("invalid input syntax for type numeric: " + this);
|
||||
}
|
||||
|
||||
/**
|
||||
* Parses UTF8String(trimmed if needed) to int. This method is used when ANSI is enabled.
|
||||
*
|
||||
* @return If string contains valid numeric value then it returns the int value otherwise a
|
||||
* NumberFormatException is thrown.
|
||||
*/
|
||||
public int toIntExact() {
|
||||
IntWrapper result = new IntWrapper();
|
||||
if (toInt(result)) {
|
||||
return result.value;
|
||||
}
|
||||
throw new NumberFormatException("invalid input syntax for type numeric: " + this);
|
||||
}
|
||||
|
||||
public short toShortExact() {
|
||||
int value = this.toIntExact();
|
||||
short result = (short) value;
|
||||
if (result == value) {
|
||||
return result;
|
||||
}
|
||||
throw new NumberFormatException("invalid input syntax for type numeric: " + this);
|
||||
}
|
||||
|
||||
public byte toByteExact() {
|
||||
int value = this.toIntExact();
|
||||
byte result = (byte) value;
|
||||
if (result == value) {
|
||||
return result;
|
||||
}
|
||||
throw new NumberFormatException("invalid input syntax for type numeric: " + this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return new String(getBytes(), StandardCharsets.UTF_8);
|
||||
|
|
|
@ -482,6 +482,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
|
|||
|
||||
// LongConverter
|
||||
private[this] def castToLong(from: DataType): Any => Any = from match {
|
||||
case StringType if ansiEnabled =>
|
||||
buildCast[UTF8String](_, _.toLongExact())
|
||||
case StringType =>
|
||||
val result = new LongWrapper()
|
||||
buildCast[UTF8String](_, s => if (s.toLong(result)) result.value else null)
|
||||
|
@ -499,6 +501,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
|
|||
|
||||
// IntConverter
|
||||
private[this] def castToInt(from: DataType): Any => Any = from match {
|
||||
case StringType if ansiEnabled =>
|
||||
buildCast[UTF8String](_, _.toIntExact())
|
||||
case StringType =>
|
||||
val result = new IntWrapper()
|
||||
buildCast[UTF8String](_, s => if (s.toInt(result)) result.value else null)
|
||||
|
@ -518,6 +522,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
|
|||
|
||||
// ShortConverter
|
||||
private[this] def castToShort(from: DataType): Any => Any = from match {
|
||||
case StringType if ansiEnabled =>
|
||||
buildCast[UTF8String](_, _.toShortExact())
|
||||
case StringType =>
|
||||
val result = new IntWrapper()
|
||||
buildCast[UTF8String](_, s => if (s.toShort(result)) {
|
||||
|
@ -559,6 +565,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
|
|||
|
||||
// ByteConverter
|
||||
private[this] def castToByte(from: DataType): Any => Any = from match {
|
||||
case StringType if ansiEnabled =>
|
||||
buildCast[UTF8String](_, _.toByteExact())
|
||||
case StringType =>
|
||||
val result = new IntWrapper()
|
||||
buildCast[UTF8String](_, s => if (s.toByte(result)) {
|
||||
|
@ -636,7 +644,12 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
|
|||
// Please refer to https://github.com/apache/spark/pull/26640
|
||||
changePrecision(Decimal(new JavaBigDecimal(s.toString.trim)), target)
|
||||
} catch {
|
||||
case _: NumberFormatException => null
|
||||
case _: NumberFormatException =>
|
||||
if (ansiEnabled) {
|
||||
throw new NumberFormatException(s"invalid input syntax for type numeric: $s")
|
||||
} else {
|
||||
null
|
||||
}
|
||||
})
|
||||
case BooleanType =>
|
||||
buildCast[Boolean](_, b => toPrecision(if (b) Decimal.ONE else Decimal.ZERO, target))
|
||||
|
@ -664,7 +677,12 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
|
|||
val doubleStr = s.toString
|
||||
try doubleStr.toDouble catch {
|
||||
case _: NumberFormatException =>
|
||||
Cast.processFloatingPointSpecialLiterals(doubleStr, false)
|
||||
val d = Cast.processFloatingPointSpecialLiterals(doubleStr, false)
|
||||
if(ansiEnabled && d == null) {
|
||||
throw new NumberFormatException(s"invalid input syntax for type numeric: $s")
|
||||
} else {
|
||||
d
|
||||
}
|
||||
}
|
||||
})
|
||||
case BooleanType =>
|
||||
|
@ -684,7 +702,12 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
|
|||
val floatStr = s.toString
|
||||
try floatStr.toFloat catch {
|
||||
case _: NumberFormatException =>
|
||||
Cast.processFloatingPointSpecialLiterals(floatStr, true)
|
||||
val f = Cast.processFloatingPointSpecialLiterals(floatStr, true)
|
||||
if (ansiEnabled && f == null) {
|
||||
throw new NumberFormatException(s"invalid input syntax for type numeric: $s")
|
||||
} else {
|
||||
f
|
||||
}
|
||||
}
|
||||
})
|
||||
case BooleanType =>
|
||||
|
@ -1128,12 +1151,17 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
|
|||
from match {
|
||||
case StringType =>
|
||||
(c, evPrim, evNull) =>
|
||||
val handleException = if (ansiEnabled) {
|
||||
s"""throw new NumberFormatException("invalid input syntax for type numeric: $c");"""
|
||||
} else {
|
||||
s"$evNull =true;"
|
||||
}
|
||||
code"""
|
||||
try {
|
||||
Decimal $tmp = Decimal.apply(new java.math.BigDecimal($c.toString().trim()));
|
||||
${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)}
|
||||
} catch (java.lang.NumberFormatException e) {
|
||||
$evNull = true;
|
||||
$handleException
|
||||
}
|
||||
"""
|
||||
case BooleanType =>
|
||||
|
@ -1355,6 +1383,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
|
|||
}
|
||||
|
||||
private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
|
||||
case StringType if ansiEnabled =>
|
||||
(c, evPrim, evNull) => code"$evPrim = $c.toByteExact();"
|
||||
case StringType =>
|
||||
val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper])
|
||||
(c, evPrim, evNull) =>
|
||||
|
@ -1386,6 +1416,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
|
|||
private[this] def castToShortCode(
|
||||
from: DataType,
|
||||
ctx: CodegenContext): CastFunction = from match {
|
||||
case StringType if ansiEnabled =>
|
||||
(c, evPrim, evNull) => code"$evPrim = $c.toShortExact();"
|
||||
case StringType =>
|
||||
val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper])
|
||||
(c, evPrim, evNull) =>
|
||||
|
@ -1415,6 +1447,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
|
|||
}
|
||||
|
||||
private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
|
||||
case StringType if ansiEnabled =>
|
||||
(c, evPrim, evNull) => code"$evPrim = $c.toIntExact();"
|
||||
case StringType =>
|
||||
val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper])
|
||||
(c, evPrim, evNull) =>
|
||||
|
@ -1443,9 +1477,10 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
|
|||
}
|
||||
|
||||
private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
|
||||
case StringType if ansiEnabled =>
|
||||
(c, evPrim, evNull) => code"$evPrim = $c.toLongExact();"
|
||||
case StringType =>
|
||||
val wrapper = ctx.freshVariable("longWrapper", classOf[UTF8String.LongWrapper])
|
||||
|
||||
(c, evPrim, evNull) =>
|
||||
code"""
|
||||
UTF8String.LongWrapper $wrapper = new UTF8String.LongWrapper();
|
||||
|
@ -1476,6 +1511,11 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
|
|||
case StringType =>
|
||||
val floatStr = ctx.freshVariable("floatStr", StringType)
|
||||
(c, evPrim, evNull) =>
|
||||
val handleNull = if (ansiEnabled) {
|
||||
s"""throw new NumberFormatException("invalid input syntax for type numeric: $c");"""
|
||||
} else {
|
||||
s"$evNull = true;"
|
||||
}
|
||||
code"""
|
||||
final String $floatStr = $c.toString();
|
||||
try {
|
||||
|
@ -1483,7 +1523,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
|
|||
} catch (java.lang.NumberFormatException e) {
|
||||
final Float f = (Float) Cast.processFloatingPointSpecialLiterals($floatStr, true);
|
||||
if (f == null) {
|
||||
$evNull = true;
|
||||
$handleNull
|
||||
} else {
|
||||
$evPrim = f.floatValue();
|
||||
}
|
||||
|
@ -1507,6 +1547,11 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
|
|||
case StringType =>
|
||||
val doubleStr = ctx.freshVariable("doubleStr", StringType)
|
||||
(c, evPrim, evNull) =>
|
||||
val handleNull = if (ansiEnabled) {
|
||||
s"""throw new NumberFormatException("invalid input syntax for type numeric: $c");"""
|
||||
} else {
|
||||
s"$evNull = true;"
|
||||
}
|
||||
code"""
|
||||
final String $doubleStr = $c.toString();
|
||||
try {
|
||||
|
@ -1514,7 +1559,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
|
|||
} catch (java.lang.NumberFormatException e) {
|
||||
final Double d = (Double) Cast.processFloatingPointSpecialLiterals($doubleStr, false);
|
||||
if (d == null) {
|
||||
$evNull = true;
|
||||
$handleNull
|
||||
} else {
|
||||
$evPrim = d.doubleValue();
|
||||
}
|
||||
|
|
|
@ -284,7 +284,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
|
|||
val gmtId = Option("GMT")
|
||||
|
||||
checkEvaluation(cast("abdef", StringType), "abdef")
|
||||
checkEvaluation(cast("abdef", DecimalType.USER_DEFAULT), null)
|
||||
checkEvaluation(cast("abdef", TimestampType, gmtId), null)
|
||||
checkEvaluation(cast("12.65", DecimalType.SYSTEM_DEFAULT), Decimal(12.65))
|
||||
|
||||
|
@ -324,7 +323,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
|
|||
checkEvaluation(cast("23", DecimalType.USER_DEFAULT), Decimal(23))
|
||||
checkEvaluation(cast("23", ByteType), 23.toByte)
|
||||
checkEvaluation(cast("23", ShortType), 23.toShort)
|
||||
checkEvaluation(cast("2012-12-11", DoubleType), null)
|
||||
checkEvaluation(cast(123, IntegerType), 123)
|
||||
|
||||
checkEvaluation(cast(Literal.create(null, IntegerType), ShortType), null)
|
||||
|
@ -410,15 +408,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
|
|||
|
||||
checkNullCast(ArrayType(StringType), ArrayType(IntegerType))
|
||||
|
||||
{
|
||||
val ret = cast(array, ArrayType(IntegerType, containsNull = true))
|
||||
assert(ret.resolved)
|
||||
checkEvaluation(ret, Seq(123, null, null, null))
|
||||
}
|
||||
{
|
||||
val ret = cast(array, ArrayType(IntegerType, containsNull = false))
|
||||
assert(ret.resolved === false)
|
||||
}
|
||||
{
|
||||
val ret = cast(array, ArrayType(BooleanType, containsNull = true))
|
||||
assert(ret.resolved)
|
||||
|
@ -429,15 +418,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
|
|||
assert(ret.resolved === false)
|
||||
}
|
||||
|
||||
{
|
||||
val ret = cast(array_notNull, ArrayType(IntegerType, containsNull = true))
|
||||
assert(ret.resolved)
|
||||
checkEvaluation(ret, Seq(123, null, null))
|
||||
}
|
||||
{
|
||||
val ret = cast(array_notNull, ArrayType(IntegerType, containsNull = false))
|
||||
assert(ret.resolved === false)
|
||||
}
|
||||
{
|
||||
val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = true))
|
||||
assert(ret.resolved)
|
||||
|
@ -464,15 +444,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
|
|||
|
||||
checkNullCast(MapType(StringType, IntegerType), MapType(StringType, StringType))
|
||||
|
||||
{
|
||||
val ret = cast(map, MapType(StringType, IntegerType, valueContainsNull = true))
|
||||
assert(ret.resolved)
|
||||
checkEvaluation(ret, Map("a" -> 123, "b" -> null, "c" -> null, "d" -> null))
|
||||
}
|
||||
{
|
||||
val ret = cast(map, MapType(StringType, IntegerType, valueContainsNull = false))
|
||||
assert(ret.resolved === false)
|
||||
}
|
||||
{
|
||||
val ret = cast(map, MapType(StringType, BooleanType, valueContainsNull = true))
|
||||
assert(ret.resolved)
|
||||
|
@ -486,16 +457,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
|
|||
val ret = cast(map, MapType(IntegerType, StringType, valueContainsNull = true))
|
||||
assert(ret.resolved === false)
|
||||
}
|
||||
|
||||
{
|
||||
val ret = cast(map_notNull, MapType(StringType, IntegerType, valueContainsNull = true))
|
||||
assert(ret.resolved)
|
||||
checkEvaluation(ret, Map("a" -> 123, "b" -> null, "c" -> null))
|
||||
}
|
||||
{
|
||||
val ret = cast(map_notNull, MapType(StringType, IntegerType, valueContainsNull = false))
|
||||
assert(ret.resolved === false)
|
||||
}
|
||||
{
|
||||
val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = true))
|
||||
assert(ret.resolved)
|
||||
|
@ -546,23 +507,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
|
|||
StructField("b", StringType, nullable = false),
|
||||
StructField("c", StringType, nullable = false))))
|
||||
|
||||
{
|
||||
val ret = cast(struct, StructType(Seq(
|
||||
StructField("a", IntegerType, nullable = true),
|
||||
StructField("b", IntegerType, nullable = true),
|
||||
StructField("c", IntegerType, nullable = true),
|
||||
StructField("d", IntegerType, nullable = true))))
|
||||
assert(ret.resolved)
|
||||
checkEvaluation(ret, InternalRow(123, null, null, null))
|
||||
}
|
||||
{
|
||||
val ret = cast(struct, StructType(Seq(
|
||||
StructField("a", IntegerType, nullable = true),
|
||||
StructField("b", IntegerType, nullable = true),
|
||||
StructField("c", IntegerType, nullable = false),
|
||||
StructField("d", IntegerType, nullable = true))))
|
||||
assert(ret.resolved === false)
|
||||
}
|
||||
{
|
||||
val ret = cast(struct, StructType(Seq(
|
||||
StructField("a", BooleanType, nullable = true),
|
||||
|
@ -581,21 +525,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
|
|||
assert(ret.resolved === false)
|
||||
}
|
||||
|
||||
{
|
||||
val ret = cast(struct_notNull, StructType(Seq(
|
||||
StructField("a", IntegerType, nullable = true),
|
||||
StructField("b", IntegerType, nullable = true),
|
||||
StructField("c", IntegerType, nullable = true))))
|
||||
assert(ret.resolved)
|
||||
checkEvaluation(ret, InternalRow(123, null, null))
|
||||
}
|
||||
{
|
||||
val ret = cast(struct_notNull, StructType(Seq(
|
||||
StructField("a", IntegerType, nullable = true),
|
||||
StructField("b", IntegerType, nullable = true),
|
||||
StructField("c", IntegerType, nullable = false))))
|
||||
assert(ret.resolved === false)
|
||||
}
|
||||
{
|
||||
val ret = cast(struct_notNull, StructType(Seq(
|
||||
StructField("a", BooleanType, nullable = true),
|
||||
|
@ -921,11 +850,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
|
|||
Seq("nan", "nAn", " nan ").foreach { value =>
|
||||
checkEvaluation(cast(value, DoubleType), Double.NaN)
|
||||
}
|
||||
|
||||
// Invalid literals when casted to double and float results in null.
|
||||
Seq(DoubleType, FloatType).foreach { dataType =>
|
||||
checkEvaluation(cast("badvalue", dataType), null)
|
||||
}
|
||||
}
|
||||
|
||||
private def testIntMaxAndMin(dt: DataType): Unit = {
|
||||
|
@ -1054,7 +978,6 @@ class CastSuite extends CastSuiteBase {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
test("cast from int") {
|
||||
checkCast(0, false)
|
||||
checkCast(1, true)
|
||||
|
@ -1214,6 +1137,125 @@ class CastSuite extends CastSuiteBase {
|
|||
val set = CollectSet(Literal(1))
|
||||
assert(Cast.canCast(set.dataType, ArrayType(StringType, false)))
|
||||
}
|
||||
|
||||
test("Cast should output null for invalid strings when ANSI is not enabled.") {
|
||||
withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
|
||||
checkEvaluation(cast("abdef", DecimalType.USER_DEFAULT), null)
|
||||
checkEvaluation(cast("2012-12-11", DoubleType), null)
|
||||
|
||||
// cast to array
|
||||
val array = Literal.create(Seq("123", "true", "f", null),
|
||||
ArrayType(StringType, containsNull = true))
|
||||
val array_notNull = Literal.create(Seq("123", "true", "f"),
|
||||
ArrayType(StringType, containsNull = false))
|
||||
|
||||
{
|
||||
val ret = cast(array, ArrayType(IntegerType, containsNull = true))
|
||||
assert(ret.resolved)
|
||||
checkEvaluation(ret, Seq(123, null, null, null))
|
||||
}
|
||||
{
|
||||
val ret = cast(array, ArrayType(IntegerType, containsNull = false))
|
||||
assert(ret.resolved === false)
|
||||
}
|
||||
{
|
||||
val ret = cast(array_notNull, ArrayType(IntegerType, containsNull = true))
|
||||
assert(ret.resolved)
|
||||
checkEvaluation(ret, Seq(123, null, null))
|
||||
}
|
||||
{
|
||||
val ret = cast(array_notNull, ArrayType(IntegerType, containsNull = false))
|
||||
assert(ret.resolved === false)
|
||||
}
|
||||
|
||||
// cast from map
|
||||
val map = Literal.create(
|
||||
Map("a" -> "123", "b" -> "true", "c" -> "f", "d" -> null),
|
||||
MapType(StringType, StringType, valueContainsNull = true))
|
||||
val map_notNull = Literal.create(
|
||||
Map("a" -> "123", "b" -> "true", "c" -> "f"),
|
||||
MapType(StringType, StringType, valueContainsNull = false))
|
||||
|
||||
{
|
||||
val ret = cast(map, MapType(StringType, IntegerType, valueContainsNull = true))
|
||||
assert(ret.resolved)
|
||||
checkEvaluation(ret, Map("a" -> 123, "b" -> null, "c" -> null, "d" -> null))
|
||||
}
|
||||
{
|
||||
val ret = cast(map, MapType(StringType, IntegerType, valueContainsNull = false))
|
||||
assert(ret.resolved === false)
|
||||
}
|
||||
{
|
||||
val ret = cast(map_notNull, MapType(StringType, IntegerType, valueContainsNull = true))
|
||||
assert(ret.resolved)
|
||||
checkEvaluation(ret, Map("a" -> 123, "b" -> null, "c" -> null))
|
||||
}
|
||||
{
|
||||
val ret = cast(map_notNull, MapType(StringType, IntegerType, valueContainsNull = false))
|
||||
assert(ret.resolved === false)
|
||||
}
|
||||
|
||||
// cast from struct
|
||||
val struct = Literal.create(
|
||||
InternalRow(
|
||||
UTF8String.fromString("123"),
|
||||
UTF8String.fromString("true"),
|
||||
UTF8String.fromString("f"),
|
||||
null),
|
||||
StructType(Seq(
|
||||
StructField("a", StringType, nullable = true),
|
||||
StructField("b", StringType, nullable = true),
|
||||
StructField("c", StringType, nullable = true),
|
||||
StructField("d", StringType, nullable = true))))
|
||||
val struct_notNull = Literal.create(
|
||||
InternalRow(
|
||||
UTF8String.fromString("123"),
|
||||
UTF8String.fromString("true"),
|
||||
UTF8String.fromString("f")),
|
||||
StructType(Seq(
|
||||
StructField("a", StringType, nullable = false),
|
||||
StructField("b", StringType, nullable = false),
|
||||
StructField("c", StringType, nullable = false))))
|
||||
|
||||
{
|
||||
val ret = cast(struct, StructType(Seq(
|
||||
StructField("a", IntegerType, nullable = true),
|
||||
StructField("b", IntegerType, nullable = true),
|
||||
StructField("c", IntegerType, nullable = true),
|
||||
StructField("d", IntegerType, nullable = true))))
|
||||
assert(ret.resolved)
|
||||
checkEvaluation(ret, InternalRow(123, null, null, null))
|
||||
}
|
||||
{
|
||||
val ret = cast(struct, StructType(Seq(
|
||||
StructField("a", IntegerType, nullable = true),
|
||||
StructField("b", IntegerType, nullable = true),
|
||||
StructField("c", IntegerType, nullable = false),
|
||||
StructField("d", IntegerType, nullable = true))))
|
||||
assert(ret.resolved === false)
|
||||
}
|
||||
{
|
||||
val ret = cast(struct_notNull, StructType(Seq(
|
||||
StructField("a", IntegerType, nullable = true),
|
||||
StructField("b", IntegerType, nullable = true),
|
||||
StructField("c", IntegerType, nullable = true))))
|
||||
assert(ret.resolved)
|
||||
checkEvaluation(ret, InternalRow(123, null, null))
|
||||
}
|
||||
{
|
||||
val ret = cast(struct_notNull, StructType(Seq(
|
||||
StructField("a", IntegerType, nullable = true),
|
||||
StructField("b", IntegerType, nullable = true),
|
||||
StructField("c", IntegerType, nullable = false))))
|
||||
assert(ret.resolved === false)
|
||||
}
|
||||
|
||||
// Invalid literals when casted to double and float results in null.
|
||||
Seq(DoubleType, FloatType).foreach { dataType =>
|
||||
checkEvaluation(cast("badvalue", dataType), null)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -1229,4 +1271,29 @@ class AnsiCastSuite extends CastSuiteBase {
|
|||
case _ => AnsiCast(Literal(v), targetType, timeZoneId)
|
||||
}
|
||||
}
|
||||
|
||||
test("cast from invalid string to numeric should throw NumberFormatException") {
|
||||
// cast to IntegerType
|
||||
Seq(IntegerType, ShortType, ByteType, LongType).foreach { dataType =>
|
||||
val array = Literal.create(Seq("123", "true", "f", null),
|
||||
ArrayType(StringType, containsNull = true))
|
||||
checkExceptionInExpression[NumberFormatException](
|
||||
cast(array, ArrayType(dataType, containsNull = true)), "invalid input")
|
||||
checkExceptionInExpression[NumberFormatException](
|
||||
cast("string", dataType), "invalid input")
|
||||
checkExceptionInExpression[NumberFormatException](
|
||||
cast("123-string", dataType), "invalid input")
|
||||
checkExceptionInExpression[NumberFormatException](
|
||||
cast("2020-07-19", dataType), "invalid input")
|
||||
}
|
||||
|
||||
Seq(DoubleType, FloatType, DecimalType.USER_DEFAULT).foreach { dataType =>
|
||||
checkExceptionInExpression[NumberFormatException](
|
||||
cast("string", dataType), "invalid input")
|
||||
checkExceptionInExpression[NumberFormatException](
|
||||
cast("123.000.00", dataType), "invalid input")
|
||||
checkExceptionInExpression[NumberFormatException](
|
||||
cast("abc.com", dataType), "invalid input")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -93,25 +93,28 @@ struct<CAST( -INFINiTY AS FLOAT):float>
|
|||
-- !query 11
|
||||
SELECT float('N A N')
|
||||
-- !query 11 schema
|
||||
struct<CAST(N A N AS FLOAT):float>
|
||||
struct<>
|
||||
-- !query 11 output
|
||||
NULL
|
||||
java.lang.NumberFormatException
|
||||
invalid input syntax for type numeric: N A N
|
||||
|
||||
|
||||
-- !query 12
|
||||
SELECT float('NaN x')
|
||||
-- !query 12 schema
|
||||
struct<CAST(NaN x AS FLOAT):float>
|
||||
struct<>
|
||||
-- !query 12 output
|
||||
NULL
|
||||
java.lang.NumberFormatException
|
||||
invalid input syntax for type numeric: NaN x
|
||||
|
||||
|
||||
-- !query 13
|
||||
SELECT float(' INFINITY x')
|
||||
-- !query 13 schema
|
||||
struct<CAST( INFINITY x AS FLOAT):float>
|
||||
struct<>
|
||||
-- !query 13 output
|
||||
NULL
|
||||
java.lang.NumberFormatException
|
||||
invalid input syntax for type numeric: INFINITY x
|
||||
|
||||
|
||||
-- !query 14
|
||||
|
@ -141,9 +144,10 @@ NaN
|
|||
-- !query 17
|
||||
SELECT float(decimal('nan'))
|
||||
-- !query 17 schema
|
||||
struct<CAST(CAST(nan AS DECIMAL(10,0)) AS FLOAT):float>
|
||||
struct<>
|
||||
-- !query 17 output
|
||||
NULL
|
||||
java.lang.NumberFormatException
|
||||
invalid input syntax for type numeric: nan
|
||||
|
||||
|
||||
-- !query 18
|
||||
|
|
|
@ -125,25 +125,28 @@ struct<CAST( -INFINiTY AS DOUBLE):double>
|
|||
-- !query 15
|
||||
SELECT double('N A N')
|
||||
-- !query 15 schema
|
||||
struct<CAST(N A N AS DOUBLE):double>
|
||||
struct<>
|
||||
-- !query 15 output
|
||||
NULL
|
||||
java.lang.NumberFormatException
|
||||
invalid input syntax for type numeric: N A N
|
||||
|
||||
|
||||
-- !query 16
|
||||
SELECT double('NaN x')
|
||||
-- !query 16 schema
|
||||
struct<CAST(NaN x AS DOUBLE):double>
|
||||
struct<>
|
||||
-- !query 16 output
|
||||
NULL
|
||||
java.lang.NumberFormatException
|
||||
invalid input syntax for type numeric: NaN x
|
||||
|
||||
|
||||
-- !query 17
|
||||
SELECT double(' INFINITY x')
|
||||
-- !query 17 schema
|
||||
struct<CAST( INFINITY x AS DOUBLE):double>
|
||||
struct<>
|
||||
-- !query 17 output
|
||||
NULL
|
||||
java.lang.NumberFormatException
|
||||
invalid input syntax for type numeric: INFINITY x
|
||||
|
||||
|
||||
-- !query 18
|
||||
|
@ -173,9 +176,10 @@ NaN
|
|||
-- !query 21
|
||||
SELECT double(decimal('nan'))
|
||||
-- !query 21 schema
|
||||
struct<CAST(CAST(nan AS DECIMAL(10,0)) AS DOUBLE):double>
|
||||
struct<>
|
||||
-- !query 21 output
|
||||
NULL
|
||||
java.lang.NumberFormatException
|
||||
invalid input syntax for type numeric: nan
|
||||
|
||||
|
||||
-- !query 22
|
||||
|
|
|
@ -62,17 +62,19 @@ struct<length(CAST(42 AS STRING)):int>
|
|||
-- !query 7
|
||||
select string('four: ') || 2+2
|
||||
-- !query 7 schema
|
||||
struct<(CAST(concat(CAST(four: AS STRING), CAST(2 AS STRING)) AS DOUBLE) + CAST(2 AS DOUBLE)):double>
|
||||
struct<>
|
||||
-- !query 7 output
|
||||
NULL
|
||||
java.lang.NumberFormatException
|
||||
invalid input syntax for type numeric: four: 2
|
||||
|
||||
|
||||
-- !query 8
|
||||
select 'four: ' || 2+2
|
||||
-- !query 8 schema
|
||||
struct<(CAST(concat(four: , CAST(2 AS STRING)) AS DOUBLE) + CAST(2 AS DOUBLE)):double>
|
||||
struct<>
|
||||
-- !query 8 output
|
||||
NULL
|
||||
java.lang.NumberFormatException
|
||||
invalid input syntax for type numeric: four: 2
|
||||
|
||||
|
||||
-- !query 9
|
||||
|
|
|
@ -452,15 +452,10 @@ from numerics
|
|||
window w as (order by f_numeric range between
|
||||
1.1 preceding and 'NaN' following)
|
||||
-- !query 28 schema
|
||||
struct<id:int,f_numeric:int,first(id, false) OVER (ORDER BY f_numeric ASC NULLS FIRST RANGE BETWEEN CAST((- 1.1) AS INT) FOLLOWING AND CAST(NaN AS INT) FOLLOWING):int,last(id, false) OVER (ORDER BY f_numeric ASC NULLS FIRST RANGE BETWEEN CAST((- 1.1) AS INT) FOLLOWING AND CAST(NaN AS INT) FOLLOWING):int>
|
||||
struct<>
|
||||
-- !query 28 output
|
||||
1 -3 NULL NULL
|
||||
2 -1 NULL NULL
|
||||
3 0 NULL NULL
|
||||
4 1 NULL NULL
|
||||
5 1 NULL NULL
|
||||
6 2 NULL NULL
|
||||
7 100 NULL NULL
|
||||
java.lang.NumberFormatException
|
||||
invalid input syntax for type numeric: NaN
|
||||
|
||||
|
||||
-- !query 29
|
||||
|
|
|
@ -498,10 +498,7 @@ SELECT a, b,
|
|||
SUM(b) OVER(ORDER BY A ROWS BETWEEN 1 PRECEDING AND CURRENT ROW)
|
||||
FROM (VALUES(1,1),(2,2),(3,(cast('nan' as int))),(4,3),(5,4)) t(a,b)
|
||||
-- !query 38 schema
|
||||
struct<a:int,b:int,sum(b) OVER (ORDER BY A ASC NULLS FIRST ROWS BETWEEN 1 PRECEDING AND CURRENT ROW):bigint>
|
||||
struct<>
|
||||
-- !query 38 output
|
||||
1 1 1
|
||||
2 2 3
|
||||
3 NULL 2
|
||||
4 3 3
|
||||
5 4 7
|
||||
org.apache.spark.sql.AnalysisException
|
||||
failed to evaluate expression CAST('nan' AS INT): invalid input syntax for type numeric: nan; line 3 pos 6
|
||||
|
|
Loading…
Reference in a new issue