[SPARK-34769][SQL] AnsiTypeCoercion: return closest convertible type among TypeCollection

### What changes were proposed in this pull request?

Currently, when implicitly casting a data type to a `TypeCollection`, Spark returns the first convertible data type in the `TypeCollection`.
In ANSI mode, we can make the behavior more reasonable by returning the closest convertible data type in the `TypeCollection`.

In detail, we first try to find all the expected types the input can be implicitly cast to:
1. if there are no convertible data types, return None;
2. if there is only one convertible data type, cast the input to it;
3. otherwise, if there are multiple convertible data types, find the closest data type among them. If there is no such closest data type, return None.

Note that if the closest type is Float and the convertible types contain Double, we simply return Double as the closest type, to avoid potential precision loss when converting an integral type to Float.
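
To make the rule concrete, here is a minimal, self-contained Scala sketch of steps 1–3 plus the Float/Double caveat. The `canCast` relation and the toy type objects are illustrative assumptions standing in for Spark's real `implicitCast`; they are not Spark APIs.

```scala
// Toy model of "closest convertible type" selection; not Spark code.
object ClosestTypeSketch {
  sealed trait DT
  case object IntT extends DT
  case object LongT extends DT
  case object FloatT extends DT
  case object DoubleT extends DT
  case object BinaryT extends DT

  // Assumed ANSI-style implicit widening: a type casts to itself or to wider numeric types.
  private val wider: Map[DT, Set[DT]] = Map(
    IntT    -> Set(IntT, LongT, FloatT, DoubleT),
    LongT   -> Set(LongT, FloatT, DoubleT),
    FloatT  -> Set(FloatT, DoubleT),
    DoubleT -> Set(DoubleT),
    BinaryT -> Set(BinaryT))

  def canCast(from: DT, to: DT): Boolean = wider(from).contains(to)

  def closestConvertible(in: DT, collection: Seq[DT]): Option[DT] = {
    // 1. Collect every member of the collection the input can be cast to.
    val convertible = collection.filter(canCast(in, _))
    // 2./3. The closest convertible type is one that casts to all the others.
    val closest = convertible.find(c => convertible.forall(canCast(c, _)))
    // Caveat: prefer Double over Float when both are convertible.
    if (closest.contains(FloatT) && convertible.contains(DoubleT)) Some(DoubleT)
    else closest
  }

  def main(args: Array[String]): Unit = {
    println(closestConvertible(IntT, Seq(BinaryT, FloatT, LongT))) // Some(LongT), not FloatT
    println(closestConvertible(LongT, Seq(FloatT, DoubleT)))       // Some(DoubleT), caveat applies
    println(closestConvertible(BinaryT, Seq(IntT, DoubleT)))       // None
  }
}
```

The first call mirrors the new unit test `shouldCast(IntegerType, TypeCollection(BinaryType, FloatType, LongType), LongType)`: the old rule would have picked Float (the first convertible entry), while the closest-type rule picks Long.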

### Why are the changes needed?

Make the type coercion rule for `TypeCollection` more reasonable and ANSI compatible.
E.g. returning Long instead of Double for `implicitCast(int, TypeCollection(Double, Long))`.

From ANSI SQL Spec section 4.33 "SQL-invoked routines"
![Screen Shot 2021-03-17 at 4 05 06 PM](https://user-images.githubusercontent.com/1097932/111434916-5e104e80-86bd-11eb-8b3b-33090a68067d.png)

Section 9.6 "Subject routine determination"
![Screen Shot 2021-03-17 at 1 36 55 PM](https://user-images.githubusercontent.com/1097932/111420336-48445e80-86a8-11eb-9d50-34b325043bdb.png)

Section 10.4 "routine invocation"
![Screen Shot 2021-03-17 at 4 08 41 PM](https://user-images.githubusercontent.com/1097932/111434926-610b3f00-86bd-11eb-8c32-8c7935e055da.png)

### Does this PR introduce _any_ user-facing change?

Yes. In ANSI mode, implicit casting to a `TypeCollection` returns the closest convertible data type instead of the first convertible one.
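
For illustration, a sketch of the visible difference in ANSI mode; the expected error text is taken from the updated `string-functions.sql.out` golden files in the diff below, and the demo object name is hypothetical:

```scala
import org.apache.spark.sql.{AnalysisException, SparkSession}

object AnsiTypeCollectionDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    spark.conf.set("spark.sql.ansi.enabled", true)

    // `left` requires a (string or binary) first argument, i.e. a TypeCollection.
    // A null literal used to be cast to the first member (StringType); with this
    // change there is no closest convertible type among {string, binary}, so the
    // query now fails analysis in ANSI mode:
    //   cannot resolve 'substring(NULL, 1, -2)' due to data type mismatch:
    //   argument 1 requires (string or binary) type, however, 'NULL' is of null type.
    try spark.sql("SELECT left(null, -2)").show()
    catch { case e: AnalysisException => println(e.getMessage) }

    spark.stop()
  }
}
```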

### How was this patch tested?

Unit tests.

Closes #31859 from gengliangwang/implicitCastTypeCollection.

Lead-authored-by: Gengliang Wang <gengliang.wang@databricks.com>
Co-authored-by: Gengliang Wang <ltnwgl@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Commit abfd9b23cd (parent 84df54b495)
Authored by Gengliang Wang on 2021-03-24 15:04:03 +00:00, committed by Wenchen Fan
6 changed files with 106 additions and 28 deletions

AnsiTypeCoercion.scala:

```diff
@@ -158,7 +158,10 @@ object AnsiTypeCoercion extends TypeCoercionBase {
     case _ if expectedType.acceptsType(inType) => Some(inType)
 
     // Cast null type (usually from null literals) into target types
-    case (NullType, target) => Some(target.defaultConcreteType)
+    // By default, the result type is `target.defaultConcreteType`. When the target type is
+    // `TypeCollection`, there is another branch to find the "closest convertible data type" below.
+    case (NullType, target) if !target.isInstanceOf[TypeCollection] =>
+      Some(target.defaultConcreteType)
 
     // This type coercion system will allow implicit converting String type literals as other
     // primitive types, in case of breaking too many existing Spark SQL queries.
@@ -191,9 +194,35 @@ object AnsiTypeCoercion extends TypeCoercionBase {
     case (DateType, TimestampType) => Some(TimestampType)
 
     // When we reach here, input type is not acceptable for any types in this type collection;
-    // try to find the first one we can implicitly cast.
+    // first try to find all the expected types we can implicitly cast to:
+    //   1. if there are no convertible data types, return None;
+    //   2. if there is only one convertible data type, cast the input to it;
+    //   3. otherwise, if there are multiple convertible data types, find the closest convertible
+    //      data type among them. If there is no such data type, return None.
     case (_, TypeCollection(types)) =>
-      types.flatMap(implicitCast(inType, _, isInputFoldable)).headOption
+      // Since Spark contains special objects like `NumericType` and `DecimalType`, which accept
+      // multiple types and are `AbstractDataType` instead of `DataType`, here we use the
+      // conversion result as their representation.
+      val convertibleTypes = types.flatMap(implicitCast(inType, _, isInputFoldable))
+      if (convertibleTypes.isEmpty) {
+        None
+      } else {
+        // Find the closest convertible data type, which can be implicitly cast to all the other
+        // convertible types.
+        val closestConvertibleType = convertibleTypes.find { dt =>
+          convertibleTypes.forall { target =>
+            implicitCast(dt, target, isInputFoldable = false).isDefined
+          }
+        }
+        // If the closest convertible type is Float and the convertible types contain Double,
+        // simply return Double as the closest convertible type to avoid potential precision
+        // loss on converting an integral type to Float.
+        if (closestConvertibleType.contains(FloatType) && convertibleTypes.contains(DoubleType)) {
+          Some(DoubleType)
+        } else {
+          closestConvertibleType
+        }
+      }
 
     // Implicit cast between array types.
     //
```

AnsiTypeCoercionSuite.scala:

```diff
@@ -345,8 +345,6 @@ class AnsiTypeCoercionSuite extends AnalysisTest {
   }
 
   test("eligible implicit type cast - TypeCollection") {
-    shouldCast(NullType, TypeCollection(StringType, BinaryType), StringType)
-
     shouldCast(StringType, TypeCollection(StringType, BinaryType), StringType)
     shouldCast(BinaryType, TypeCollection(StringType, BinaryType), BinaryType)
     shouldCast(StringType, TypeCollection(BinaryType, StringType), StringType)
@@ -356,17 +354,10 @@ class AnsiTypeCoercionSuite extends AnalysisTest {
     shouldCast(BinaryType, TypeCollection(BinaryType, IntegerType), BinaryType)
     shouldCast(BinaryType, TypeCollection(IntegerType, BinaryType), BinaryType)
 
-    shouldNotCast(IntegerType, TypeCollection(StringType, BinaryType))
-    shouldNotCast(IntegerType, TypeCollection(BinaryType, StringType))
-
     shouldCast(DecimalType.SYSTEM_DEFAULT,
       TypeCollection(IntegerType, DecimalType), DecimalType.SYSTEM_DEFAULT)
     shouldCast(DecimalType(10, 2), TypeCollection(IntegerType, DecimalType), DecimalType(10, 2))
     shouldCast(DecimalType(10, 2), TypeCollection(DecimalType, IntegerType), DecimalType(10, 2))
-    shouldNotCast(IntegerType, TypeCollection(DecimalType(10, 2), StringType))
-
-    shouldNotCastStringInput(TypeCollection(NumericType, BinaryType))
-    shouldCastStringLiteral(TypeCollection(NumericType, BinaryType), DoubleType)
 
     shouldCast(
       ArrayType(StringType, false),
@@ -377,10 +368,32 @@ class AnsiTypeCoercionSuite extends AnalysisTest {
       ArrayType(StringType, true),
       TypeCollection(ArrayType(StringType), StringType),
       ArrayType(StringType, true))
+
+    // When there are multiple convertible types in the `TypeCollection`, use the closest
+    // convertible data type among the convertible types.
+    shouldCast(IntegerType, TypeCollection(BinaryType, FloatType, LongType), LongType)
+    shouldCast(ShortType, TypeCollection(BinaryType, LongType, IntegerType), IntegerType)
+    shouldCast(ShortType, TypeCollection(DateType, LongType, IntegerType, DoubleType), IntegerType)
+    // If the result is Float type and Double type is also among the convertible target types,
+    // use Double type instead of Float type.
+    shouldCast(LongType, TypeCollection(FloatType, DoubleType, StringType), DoubleType)
   }
 
   test("ineligible implicit type cast - TypeCollection") {
+    shouldNotCast(IntegerType, TypeCollection(StringType, BinaryType))
+    shouldNotCast(IntegerType, TypeCollection(BinaryType, StringType))
     shouldNotCast(IntegerType, TypeCollection(DateType, TimestampType))
+    shouldNotCast(IntegerType, TypeCollection(DecimalType(10, 2), StringType))
+    shouldNotCastStringInput(TypeCollection(NumericType, BinaryType))
+    // Casting should fail when there are multiple convertible types in the `TypeCollection`
+    // but no data type that can be implicitly cast to all the other convertible types.
+    Seq(TypeCollection(NumericType, BinaryType),
+      TypeCollection(NumericType, DecimalType, BinaryType),
+      TypeCollection(IntegerType, LongType, BooleanType),
+      TypeCollection(DateType, TimestampType, BooleanType)).foreach { typeCollection =>
+      shouldNotCastStringLiteral(typeCollection)
+      shouldNotCast(NullType, typeCollection)
+    }
   }
 
   test("tightest common bound for types") {
```

string-functions.sql:

```diff
@@ -17,9 +17,11 @@ select position('bar' in 'foobarbar'), position(null, 'foobarbar'), position('aa
 
 -- left && right
 select left("abcd", 2), left("abcd", 5), left("abcd", '2'), left("abcd", null);
-select left(null, -2), left("abcd", -2), left("abcd", 0), left("abcd", 'a');
+select left(null, -2);
+select left("abcd", -2), left("abcd", 0), left("abcd", 'a');
 select right("abcd", 2), right("abcd", 5), right("abcd", '2'), right("abcd", null);
-select right(null, -2), right("abcd", -2), right("abcd", 0), right("abcd", 'a');
+select right(null, -2);
+select right("abcd", -2), right("abcd", 0), right("abcd", 'a');
 
 -- split function
 SELECT split('aa1cc2ee3', '[1-9]+');
```

ansi/string-functions.sql.out:

```diff
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 48
+-- Number of queries: 50
 
 
 -- !query
@@ -69,7 +69,16 @@ ab	abcd	ab	NULL
 
 
 -- !query
-select left(null, -2), left("abcd", -2), left("abcd", 0), left("abcd", 'a')
+select left(null, -2)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'substring(NULL, 1, -2)' due to data type mismatch: argument 1 requires (string or binary) type, however, 'NULL' is of null type.; line 1 pos 7
+
+
+-- !query
+select left("abcd", -2), left("abcd", 0), left("abcd", 'a')
 -- !query schema
 struct<>
 -- !query output
@@ -87,12 +96,21 @@ cannot resolve 'substring('abcd', (- CAST('2' AS DOUBLE)), 2147483647)' due to d
 
 
 -- !query
-select right(null, -2), right("abcd", -2), right("abcd", 0), right("abcd", 'a')
+select right(null, -2)
 -- !query schema
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve 'substring('abcd', (- CAST('a' AS DOUBLE)), 2147483647)' due to data type mismatch: argument 2 requires int type, however, '(- CAST('a' AS DOUBLE))' is of double type.; line 1 pos 61
+cannot resolve 'substring(NULL, (- -2), 2147483647)' due to data type mismatch: argument 1 requires (string or binary) type, however, 'NULL' is of null type.; line 1 pos 7
+
+
+-- !query
+select right("abcd", -2), right("abcd", 0), right("abcd", 'a')
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'substring('abcd', (- CAST('a' AS DOUBLE)), 2147483647)' due to data type mismatch: argument 2 requires int type, however, '(- CAST('a' AS DOUBLE))' is of double type.; line 1 pos 44
 
 
 -- !query
```

concat_ws results:

```diff
@@ -130,7 +130,7 @@ select concat_ws(',',10,20,null,30)
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve 'concat_ws(',', 10, 20, NULL, 30)' due to data type mismatch: argument 2 requires (array<string> or string) type, however, '10' is of int type. argument 3 requires (array<string> or string) type, however, '20' is of int type. argument 5 requires (array<string> or string) type, however, '30' is of int type.; line 1 pos 7
+cannot resolve 'concat_ws(',', 10, 20, NULL, 30)' due to data type mismatch: argument 2 requires (array<string> or string) type, however, '10' is of int type. argument 3 requires (array<string> or string) type, however, '20' is of int type. argument 4 requires (array<string> or string) type, however, 'NULL' is of null type. argument 5 requires (array<string> or string) type, however, '30' is of int type.; line 1 pos 7
 
 
 -- !query
@@ -139,7 +139,7 @@ select concat_ws('',10,20,null,30)
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve 'concat_ws('', 10, 20, NULL, 30)' due to data type mismatch: argument 2 requires (array<string> or string) type, however, '10' is of int type. argument 3 requires (array<string> or string) type, however, '20' is of int type. argument 5 requires (array<string> or string) type, however, '30' is of int type.; line 1 pos 7
+cannot resolve 'concat_ws('', 10, 20, NULL, 30)' due to data type mismatch: argument 2 requires (array<string> or string) type, however, '10' is of int type. argument 3 requires (array<string> or string) type, however, '20' is of int type. argument 4 requires (array<string> or string) type, however, 'NULL' is of null type. argument 5 requires (array<string> or string) type, however, '30' is of int type.; line 1 pos 7
 
 
 -- !query
@@ -148,7 +148,7 @@ select concat_ws(NULL,10,20,null,30) is null
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve 'concat_ws(CAST(NULL AS STRING), 10, 20, NULL, 30)' due to data type mismatch: argument 2 requires (array<string> or string) type, however, '10' is of int type. argument 3 requires (array<string> or string) type, however, '20' is of int type. argument 5 requires (array<string> or string) type, however, '30' is of int type.; line 1 pos 7
+cannot resolve 'concat_ws(CAST(NULL AS STRING), 10, 20, NULL, 30)' due to data type mismatch: argument 2 requires (array<string> or string) type, however, '10' is of int type. argument 3 requires (array<string> or string) type, however, '20' is of int type. argument 4 requires (array<string> or string) type, however, 'NULL' is of null type. argument 5 requires (array<string> or string) type, however, '30' is of int type.; line 1 pos 7
 
 
 -- !query
```

string-functions.sql.out:

```diff
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 48
+-- Number of queries: 50
 
 
 -- !query
@@ -69,11 +69,19 @@ ab	abcd	ab	NULL
 
 
 -- !query
-select left(null, -2), left("abcd", -2), left("abcd", 0), left("abcd", 'a')
+select left(null, -2)
 -- !query schema
-struct<left(NULL, -2):string,left(abcd, -2):string,left(abcd, 0):string,left(abcd, a):string>
+struct<left(NULL, -2):string>
 -- !query output
-NULL			NULL
+NULL
+
+
+-- !query
+select left("abcd", -2), left("abcd", 0), left("abcd", 'a')
+-- !query schema
+struct<left(abcd, -2):string,left(abcd, 0):string,left(abcd, a):string>
+-- !query output
+		NULL
 
 
 -- !query
@@ -85,11 +93,19 @@ cd	abcd	cd	NULL
 
 
 -- !query
-select right(null, -2), right("abcd", -2), right("abcd", 0), right("abcd", 'a')
+select right(null, -2)
 -- !query schema
-struct<right(NULL, -2):string,right(abcd, -2):string,right(abcd, 0):string,right(abcd, a):string>
+struct<right(NULL, -2):string>
 -- !query output
-NULL			NULL
+NULL
+
+
+-- !query
+select right("abcd", -2), right("abcd", 0), right("abcd", 'a')
+-- !query schema
+struct<right(abcd, -2):string,right(abcd, 0):string,right(abcd, a):string>
+-- !query output
+		NULL
 
 
 -- !query
```