diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala index b679901531..4c31b52d62 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala @@ -158,7 +158,10 @@ object AnsiTypeCoercion extends TypeCoercionBase { case _ if expectedType.acceptsType(inType) => Some(inType) // Cast null type (usually from null literals) into target types - case (NullType, target) => Some(target.defaultConcreteType) + // By default, the result type is `target.defaultConcreteType`. When the target type is + // `TypeCollection`, a separate branch below finds the "closest convertible data type". + case (NullType, target) if !target.isInstanceOf[TypeCollection] => + Some(target.defaultConcreteType) // This type coercion system will allow implicit converting String type literals as other // primitive types, in case of breaking too many existing Spark SQL queries. @@ -191,9 +194,35 @@ object AnsiTypeCoercion extends TypeCoercionBase { case (DateType, TimestampType) => Some(TimestampType) // When we reach here, input type is not acceptable for any types in this type collection, - // try to find the first one we can implicitly cast. + // first try to find all the expected types we can implicitly cast: + // 1. if there are no convertible data types, return None; + // 2. if there is only one convertible data type, cast input as it; + // 3. otherwise if there are multiple convertible data types, find the closest convertible + // data type among them. If there is no such data type, return None.
case (_, TypeCollection(types)) => - types.flatMap(implicitCast(inType, _, isInputFoldable)).headOption + // Since Spark contains special objects like `NumericType` and `DecimalType`, which accept + // multiple types and they are `AbstractDataType` instead of `DataType`, here we use the + // conversion result as their representation. + val convertibleTypes = types.flatMap(implicitCast(inType, _, isInputFoldable)) + if (convertibleTypes.isEmpty) { + None + } else { + // find the closest convertible data type, which can be implicitly cast to all other + // convertible types. + val closestConvertibleType = convertibleTypes.find { dt => + convertibleTypes.forall { target => + implicitCast(dt, target, isInputFoldable = false).isDefined + } + } + // If the closest convertible type is Float type and the convertible types contain Double + // type, simply return Double type as the closest convertible type to avoid potential + // precision loss when converting an Integral type to Float type. + if (closestConvertibleType.contains(FloatType) && convertibleTypes.contains(DoubleType)) { + Some(DoubleType) + } else { + closestConvertibleType + } + } // Implicit cast between array types.
// diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala index 88e082f158..e3e61f022c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala @@ -345,8 +345,6 @@ class AnsiTypeCoercionSuite extends AnalysisTest { } test("eligible implicit type cast - TypeCollection") { - shouldCast(NullType, TypeCollection(StringType, BinaryType), StringType) - shouldCast(StringType, TypeCollection(StringType, BinaryType), StringType) shouldCast(BinaryType, TypeCollection(StringType, BinaryType), BinaryType) shouldCast(StringType, TypeCollection(BinaryType, StringType), StringType) @@ -356,17 +354,10 @@ class AnsiTypeCoercionSuite extends AnalysisTest { shouldCast(BinaryType, TypeCollection(BinaryType, IntegerType), BinaryType) shouldCast(BinaryType, TypeCollection(IntegerType, BinaryType), BinaryType) - shouldNotCast(IntegerType, TypeCollection(StringType, BinaryType)) - shouldNotCast(IntegerType, TypeCollection(BinaryType, StringType)) - shouldCast(DecimalType.SYSTEM_DEFAULT, TypeCollection(IntegerType, DecimalType), DecimalType.SYSTEM_DEFAULT) shouldCast(DecimalType(10, 2), TypeCollection(IntegerType, DecimalType), DecimalType(10, 2)) shouldCast(DecimalType(10, 2), TypeCollection(DecimalType, IntegerType), DecimalType(10, 2)) - shouldNotCast(IntegerType, TypeCollection(DecimalType(10, 2), StringType)) - - shouldNotCastStringInput(TypeCollection(NumericType, BinaryType)) - shouldCastStringLiteral(TypeCollection(NumericType, BinaryType), DoubleType) shouldCast( ArrayType(StringType, false), @@ -377,10 +368,32 @@ class AnsiTypeCoercionSuite extends AnalysisTest { ArrayType(StringType, true), TypeCollection(ArrayType(StringType), StringType), ArrayType(StringType, true)) + + // When 
there are multiple convertible types in the `TypeCollection`, use the closest + // convertible data type among convertible types. + shouldCast(IntegerType, TypeCollection(BinaryType, FloatType, LongType), LongType) + shouldCast(ShortType, TypeCollection(BinaryType, LongType, IntegerType), IntegerType) + shouldCast(ShortType, TypeCollection(DateType, LongType, IntegerType, DoubleType), IntegerType) + // If the result is Float type and Double type is also among the convertible target types, + // use Double Type instead of Float type. + shouldCast(LongType, TypeCollection(FloatType, DoubleType, StringType), DoubleType) } test("ineligible implicit type cast - TypeCollection") { + shouldNotCast(IntegerType, TypeCollection(StringType, BinaryType)) + shouldNotCast(IntegerType, TypeCollection(BinaryType, StringType)) shouldNotCast(IntegerType, TypeCollection(DateType, TimestampType)) + shouldNotCast(IntegerType, TypeCollection(DecimalType(10, 2), StringType)) + shouldNotCastStringInput(TypeCollection(NumericType, BinaryType)) + // When there are multiple convertible types in the `TypeCollection` and there is no such + // a data type that can be implicit cast to all the other convertible types in the collection. 
+ Seq(TypeCollection(NumericType, BinaryType), + TypeCollection(NumericType, DecimalType, BinaryType), + TypeCollection(IntegerType, LongType, BooleanType), + TypeCollection(DateType, TimestampType, BooleanType)).foreach { typeCollection => + shouldNotCastStringLiteral(typeCollection) + shouldNotCast(NullType, typeCollection) + } } test("tightest common bound for types") { diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql index 47c1a6debb..d44055d72e 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql @@ -17,9 +17,11 @@ select position('bar' in 'foobarbar'), position(null, 'foobarbar'), position('aa -- left && right select left("abcd", 2), left("abcd", 5), left("abcd", '2'), left("abcd", null); -select left(null, -2), left("abcd", -2), left("abcd", 0), left("abcd", 'a'); +select left(null, -2); +select left("abcd", -2), left("abcd", 0), left("abcd", 'a'); select right("abcd", 2), right("abcd", 5), right("abcd", '2'), right("abcd", null); -select right(null, -2), right("abcd", -2), right("abcd", 0), right("abcd", 'a'); +select right(null, -2); +select right("abcd", -2), right("abcd", 0), right("abcd", 'a'); -- split function SELECT split('aa1cc2ee3', '[1-9]+'); diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out index a1f1d87f5a..3f4399fe08 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 48 +-- Number of queries: 50 -- !query @@ -69,7 +69,16 @@ ab abcd ab NULL -- !query -select left(null, -2), left("abcd", -2), left("abcd", 0), left("abcd", 'a') +select 
left(null, -2) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +cannot resolve 'substring(NULL, 1, -2)' due to data type mismatch: argument 1 requires (string or binary) type, however, 'NULL' is of null type.; line 1 pos 7 + + +-- !query +select left("abcd", -2), left("abcd", 0), left("abcd", 'a') -- !query schema struct<> -- !query output @@ -87,12 +96,21 @@ cannot resolve 'substring('abcd', (- CAST('2' AS DOUBLE)), 2147483647)' due to d -- !query -select right(null, -2), right("abcd", -2), right("abcd", 0), right("abcd", 'a') +select right(null, -2) -- !query schema struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'substring('abcd', (- CAST('a' AS DOUBLE)), 2147483647)' due to data type mismatch: argument 2 requires int type, however, '(- CAST('a' AS DOUBLE))' is of double type.; line 1 pos 61 +cannot resolve 'substring(NULL, (- -2), 2147483647)' due to data type mismatch: argument 1 requires (string or binary) type, however, 'NULL' is of null type.; line 1 pos 7 + + +-- !query +select right("abcd", -2), right("abcd", 0), right("abcd", 'a') +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +cannot resolve 'substring('abcd', (- CAST('a' AS DOUBLE)), 2147483647)' due to data type mismatch: argument 2 requires int type, however, '(- CAST('a' AS DOUBLE))' is of double type.; line 1 pos 44 -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/text.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/text.sql.out index 0ecba2d11a..2387dd2441 100755 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/text.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/text.sql.out @@ -130,7 +130,7 @@ select concat_ws(',',10,20,null,30) struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'concat_ws(',', 10, 20, NULL, 30)' due to data type mismatch: argument 2 requires (array or 
string) type, however, '10' is of int type. argument 3 requires (array or string) type, however, '20' is of int type. argument 5 requires (array or string) type, however, '30' is of int type.; line 1 pos 7 +cannot resolve 'concat_ws(',', 10, 20, NULL, 30)' due to data type mismatch: argument 2 requires (array or string) type, however, '10' is of int type. argument 3 requires (array or string) type, however, '20' is of int type. argument 4 requires (array or string) type, however, 'NULL' is of null type. argument 5 requires (array or string) type, however, '30' is of int type.; line 1 pos 7 -- !query @@ -139,7 +139,7 @@ select concat_ws('',10,20,null,30) struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'concat_ws('', 10, 20, NULL, 30)' due to data type mismatch: argument 2 requires (array or string) type, however, '10' is of int type. argument 3 requires (array or string) type, however, '20' is of int type. argument 5 requires (array or string) type, however, '30' is of int type.; line 1 pos 7 +cannot resolve 'concat_ws('', 10, 20, NULL, 30)' due to data type mismatch: argument 2 requires (array or string) type, however, '10' is of int type. argument 3 requires (array or string) type, however, '20' is of int type. argument 4 requires (array or string) type, however, 'NULL' is of null type. argument 5 requires (array or string) type, however, '30' is of int type.; line 1 pos 7 -- !query @@ -148,7 +148,7 @@ select concat_ws(NULL,10,20,null,30) is null struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'concat_ws(CAST(NULL AS STRING), 10, 20, NULL, 30)' due to data type mismatch: argument 2 requires (array or string) type, however, '10' is of int type. argument 3 requires (array or string) type, however, '20' is of int type. 
argument 5 requires (array or string) type, however, '30' is of int type.; line 1 pos 7 +cannot resolve 'concat_ws(CAST(NULL AS STRING), 10, 20, NULL, 30)' due to data type mismatch: argument 2 requires (array or string) type, however, '10' is of int type. argument 3 requires (array or string) type, however, '20' is of int type. argument 4 requires (array or string) type, however, 'NULL' is of null type. argument 5 requires (array or string) type, however, '30' is of int type.; line 1 pos 7 -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index 03a45a287d..80e88d0566 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 48 +-- Number of queries: 50 -- !query @@ -69,11 +69,19 @@ ab abcd ab NULL -- !query -select left(null, -2), left("abcd", -2), left("abcd", 0), left("abcd", 'a') +select left(null, -2) -- !query schema -struct +struct -- !query output -NULL NULL +NULL + + +-- !query +select left("abcd", -2), left("abcd", 0), left("abcd", 'a') +-- !query schema +struct +-- !query output + NULL -- !query @@ -85,11 +93,19 @@ cd abcd cd NULL -- !query -select right(null, -2), right("abcd", -2), right("abcd", 0), right("abcd", 'a') +select right(null, -2) -- !query schema -struct +struct -- !query output -NULL NULL +NULL + + +-- !query +select right("abcd", -2), right("abcd", 0), right("abcd", 'a') +-- !query schema +struct +-- !query output + NULL -- !query