[SPARK-24737][SQL] Type coercion between StructTypes.
## What changes were proposed in this pull request? We can support type coercion between `StructType`s where all the internal types are compatible. ## How was this patch tested? Added tests. Author: Takuya UESHIN <ueshin@databricks.com> Closes #21713 from ueshin/issues/SPARK-24737/structtypecoercion.
This commit is contained in:
parent
e71e93aaaa
commit
01fcba2c68
|
@ -102,25 +102,7 @@ object TypeCoercion {
|
|||
case (_: TimestampType, _: DateType) | (_: DateType, _: TimestampType) =>
|
||||
Some(TimestampType)
|
||||
|
||||
case (t1 @ StructType(fields1), t2 @ StructType(fields2)) if t1.sameType(t2) =>
|
||||
Some(StructType(fields1.zip(fields2).map { case (f1, f2) =>
|
||||
// Since `t1.sameType(t2)` is true, two StructTypes have the same DataType
|
||||
// except `name` (in case of `spark.sql.caseSensitive=false`) and `nullable`.
|
||||
// - Different names: use f1.name
|
||||
// - Different nullabilities: `nullable` is true iff one of them is nullable.
|
||||
val dataType = findTightestCommonType(f1.dataType, f2.dataType).get
|
||||
StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable)
|
||||
}))
|
||||
|
||||
case (a1 @ ArrayType(et1, hasNull1), a2 @ ArrayType(et2, hasNull2)) if a1.sameType(a2) =>
|
||||
findTightestCommonType(et1, et2).map(ArrayType(_, hasNull1 || hasNull2))
|
||||
|
||||
case (m1 @ MapType(kt1, vt1, hasNull1), m2 @ MapType(kt2, vt2, hasNull2)) if m1.sameType(m2) =>
|
||||
val keyType = findTightestCommonType(kt1, kt2)
|
||||
val valueType = findTightestCommonType(vt1, vt2)
|
||||
Some(MapType(keyType.get, valueType.get, hasNull1 || hasNull2))
|
||||
|
||||
case _ => None
|
||||
case (t1, t2) => findTypeForComplex(t1, t2, findTightestCommonType)
|
||||
}
|
||||
|
||||
/** Promotes all the way to StringType. */
|
||||
|
@ -166,6 +148,30 @@ object TypeCoercion {
|
|||
case (l, r) => None
|
||||
}
|
||||
|
||||
private def findTypeForComplex(
|
||||
t1: DataType,
|
||||
t2: DataType,
|
||||
findTypeFunc: (DataType, DataType) => Option[DataType]): Option[DataType] = (t1, t2) match {
|
||||
case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) =>
|
||||
findTypeFunc(et1, et2).map(ArrayType(_, containsNull1 || containsNull2))
|
||||
case (MapType(kt1, vt1, valueContainsNull1), MapType(kt2, vt2, valueContainsNull2)) =>
|
||||
findTypeFunc(kt1, kt2).flatMap { kt =>
|
||||
findTypeFunc(vt1, vt2).map { vt =>
|
||||
MapType(kt, vt, valueContainsNull1 || valueContainsNull2)
|
||||
}
|
||||
}
|
||||
case (StructType(fields1), StructType(fields2)) if fields1.length == fields2.length =>
|
||||
val resolver = SQLConf.get.resolver
|
||||
fields1.zip(fields2).foldLeft(Option(new StructType())) {
|
||||
case (Some(struct), (field1, field2)) if resolver(field1.name, field2.name) =>
|
||||
findTypeFunc(field1.dataType, field2.dataType).map {
|
||||
dt => struct.add(field1.name, dt, field1.nullable || field2.nullable)
|
||||
}
|
||||
case _ => None
|
||||
}
|
||||
case _ => None
|
||||
}
|
||||
|
||||
/**
|
||||
* Case 2 type widening (see the classdoc comment above for TypeCoercion).
|
||||
*
|
||||
|
@ -176,17 +182,7 @@ object TypeCoercion {
|
|||
findTightestCommonType(t1, t2)
|
||||
.orElse(findWiderTypeForDecimal(t1, t2))
|
||||
.orElse(stringPromotion(t1, t2))
|
||||
.orElse((t1, t2) match {
|
||||
case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) =>
|
||||
findWiderTypeForTwo(et1, et2).map(ArrayType(_, containsNull1 || containsNull2))
|
||||
case (MapType(kt1, vt1, valueContainsNull1), MapType(kt2, vt2, valueContainsNull2)) =>
|
||||
findWiderTypeForTwo(kt1, kt2).flatMap { kt =>
|
||||
findWiderTypeForTwo(vt1, vt2).map { vt =>
|
||||
MapType(kt, vt, valueContainsNull1 || valueContainsNull2)
|
||||
}
|
||||
}
|
||||
case _ => None
|
||||
})
|
||||
.orElse(findTypeForComplex(t1, t2, findWiderTypeForTwo))
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -222,18 +218,7 @@ object TypeCoercion {
|
|||
t2: DataType): Option[DataType] = {
|
||||
findTightestCommonType(t1, t2)
|
||||
.orElse(findWiderTypeForDecimal(t1, t2))
|
||||
.orElse((t1, t2) match {
|
||||
case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) =>
|
||||
findWiderTypeWithoutStringPromotionForTwo(et1, et2)
|
||||
.map(ArrayType(_, containsNull1 || containsNull2))
|
||||
case (MapType(kt1, vt1, valueContainsNull1), MapType(kt2, vt2, valueContainsNull2)) =>
|
||||
findWiderTypeWithoutStringPromotionForTwo(kt1, kt2).flatMap { kt =>
|
||||
findWiderTypeWithoutStringPromotionForTwo(vt1, vt2).map { vt =>
|
||||
MapType(kt, vt, valueContainsNull1 || valueContainsNull2)
|
||||
}
|
||||
}
|
||||
case _ => None
|
||||
})
|
||||
.orElse(findTypeForComplex(t1, t2, findWiderTypeWithoutStringPromotionForTwo))
|
||||
}
|
||||
|
||||
def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = {
|
||||
|
|
|
@ -54,7 +54,7 @@ class TypeCoercionSuite extends AnalysisTest {
|
|||
// | NullType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | BinaryType | BooleanType | StringType | DateType | TimestampType | ArrayType | MapType | StructType | NullType | CalendarIntervalType | DecimalType(38, 18) | DoubleType | IntegerType |
|
||||
// | CalendarIntervalType | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | CalendarIntervalType | X | X | X |
|
||||
// +----------------------+----------+-----------+-------------+----------+------------+-----------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+
|
||||
// Note: StructType* is castable only when the internal child types also match; otherwise, not castable.
|
||||
// Note: StructType* is castable when all the internal child types are castable according to the table.
|
||||
// Note: ArrayType* is castable when the element type is castable according to the table.
|
||||
// Note: MapType* is castable when both the key type and the value type are castable according to the table.
|
||||
// scalastyle:on line.size.limit
|
||||
|
@ -397,7 +397,7 @@ class TypeCoercionSuite extends AnalysisTest {
|
|||
widenTest(
|
||||
StructType(Seq(StructField("a", IntegerType, nullable = false))),
|
||||
StructType(Seq(StructField("a", DoubleType, nullable = false))),
|
||||
None)
|
||||
Some(StructType(Seq(StructField("a", DoubleType, nullable = false)))))
|
||||
|
||||
widenTest(
|
||||
StructType(Seq(StructField("a", IntegerType, nullable = false))),
|
||||
|
@ -454,15 +454,18 @@ class TypeCoercionSuite extends AnalysisTest {
|
|||
def widenTestWithStringPromotion(
|
||||
t1: DataType,
|
||||
t2: DataType,
|
||||
expected: Option[DataType]): Unit = {
|
||||
checkWidenType(TypeCoercion.findWiderTypeForTwo, t1, t2, expected)
|
||||
expected: Option[DataType],
|
||||
isSymmetric: Boolean = true): Unit = {
|
||||
checkWidenType(TypeCoercion.findWiderTypeForTwo, t1, t2, expected, isSymmetric)
|
||||
}
|
||||
|
||||
def widenTestWithoutStringPromotion(
|
||||
t1: DataType,
|
||||
t2: DataType,
|
||||
expected: Option[DataType]): Unit = {
|
||||
checkWidenType(TypeCoercion.findWiderTypeWithoutStringPromotionForTwo, t1, t2, expected)
|
||||
expected: Option[DataType],
|
||||
isSymmetric: Boolean = true): Unit = {
|
||||
checkWidenType(
|
||||
TypeCoercion.findWiderTypeWithoutStringPromotionForTwo, t1, t2, expected, isSymmetric)
|
||||
}
|
||||
|
||||
// Decimal
|
||||
|
@ -492,6 +495,10 @@ class TypeCoercionSuite extends AnalysisTest {
|
|||
ArrayType(MapType(IntegerType, FloatType), containsNull = false),
|
||||
ArrayType(MapType(LongType, DoubleType), containsNull = false),
|
||||
Some(ArrayType(MapType(LongType, DoubleType), containsNull = false)))
|
||||
widenTestWithStringPromotion(
|
||||
ArrayType(new StructType().add("num", ShortType), containsNull = false),
|
||||
ArrayType(new StructType().add("num", LongType), containsNull = false),
|
||||
Some(ArrayType(new StructType().add("num", LongType), containsNull = false)))
|
||||
|
||||
// MapType
|
||||
widenTestWithStringPromotion(
|
||||
|
@ -506,6 +513,64 @@ class TypeCoercionSuite extends AnalysisTest {
|
|||
MapType(IntegerType, MapType(ShortType, TimestampType), valueContainsNull = false),
|
||||
MapType(LongType, MapType(DoubleType, StringType), valueContainsNull = false),
|
||||
Some(MapType(LongType, MapType(DoubleType, StringType), valueContainsNull = false)))
|
||||
widenTestWithStringPromotion(
|
||||
MapType(IntegerType, new StructType().add("num", ShortType), valueContainsNull = false),
|
||||
MapType(LongType, new StructType().add("num", LongType), valueContainsNull = false),
|
||||
Some(MapType(LongType, new StructType().add("num", LongType), valueContainsNull = false)))
|
||||
|
||||
// StructType
|
||||
widenTestWithStringPromotion(
|
||||
new StructType()
|
||||
.add("num", ShortType, nullable = true).add("ts", StringType, nullable = false),
|
||||
new StructType()
|
||||
.add("num", DoubleType, nullable = false).add("ts", TimestampType, nullable = true),
|
||||
Some(new StructType()
|
||||
.add("num", DoubleType, nullable = true).add("ts", StringType, nullable = true)))
|
||||
widenTestWithStringPromotion(
|
||||
new StructType()
|
||||
.add("arr", ArrayType(ShortType, containsNull = false), nullable = false),
|
||||
new StructType()
|
||||
.add("arr", ArrayType(DoubleType, containsNull = true), nullable = false),
|
||||
Some(new StructType()
|
||||
.add("arr", ArrayType(DoubleType, containsNull = true), nullable = false)))
|
||||
widenTestWithStringPromotion(
|
||||
new StructType()
|
||||
.add("map", MapType(ShortType, TimestampType, valueContainsNull = true), nullable = false),
|
||||
new StructType()
|
||||
.add("map", MapType(DoubleType, StringType, valueContainsNull = false), nullable = false),
|
||||
Some(new StructType()
|
||||
.add("map", MapType(DoubleType, StringType, valueContainsNull = true), nullable = false)))
|
||||
|
||||
widenTestWithStringPromotion(
|
||||
new StructType().add("num", IntegerType),
|
||||
new StructType().add("num", LongType).add("str", StringType),
|
||||
None)
|
||||
widenTestWithoutStringPromotion(
|
||||
new StructType().add("num", IntegerType),
|
||||
new StructType().add("num", LongType).add("str", StringType),
|
||||
None)
|
||||
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
|
||||
widenTestWithStringPromotion(
|
||||
new StructType().add("a", IntegerType),
|
||||
new StructType().add("A", LongType),
|
||||
None)
|
||||
widenTestWithoutStringPromotion(
|
||||
new StructType().add("a", IntegerType),
|
||||
new StructType().add("A", LongType),
|
||||
None)
|
||||
}
|
||||
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
|
||||
widenTestWithStringPromotion(
|
||||
new StructType().add("a", IntegerType),
|
||||
new StructType().add("A", LongType),
|
||||
Some(new StructType().add("a", LongType)),
|
||||
isSymmetric = false)
|
||||
widenTestWithoutStringPromotion(
|
||||
new StructType().add("a", IntegerType),
|
||||
new StructType().add("A", LongType),
|
||||
Some(new StructType().add("a", LongType)),
|
||||
isSymmetric = false)
|
||||
}
|
||||
|
||||
// Without string promotion
|
||||
widenTestWithoutStringPromotion(IntegerType, StringType, None)
|
||||
|
@ -520,6 +585,14 @@ class TypeCoercionSuite extends AnalysisTest {
|
|||
MapType(StringType, IntegerType), MapType(TimestampType, IntegerType), None)
|
||||
widenTestWithoutStringPromotion(
|
||||
MapType(IntegerType, StringType), MapType(IntegerType, TimestampType), None)
|
||||
widenTestWithoutStringPromotion(
|
||||
new StructType().add("a", IntegerType),
|
||||
new StructType().add("a", StringType),
|
||||
None)
|
||||
widenTestWithoutStringPromotion(
|
||||
new StructType().add("a", StringType),
|
||||
new StructType().add("a", IntegerType),
|
||||
None)
|
||||
|
||||
// String promotion
|
||||
widenTestWithStringPromotion(IntegerType, StringType, Some(StringType))
|
||||
|
@ -544,6 +617,14 @@ class TypeCoercionSuite extends AnalysisTest {
|
|||
MapType(IntegerType, StringType),
|
||||
MapType(IntegerType, TimestampType),
|
||||
Some(MapType(IntegerType, StringType)))
|
||||
widenTestWithStringPromotion(
|
||||
new StructType().add("a", IntegerType),
|
||||
new StructType().add("a", StringType),
|
||||
Some(new StructType().add("a", StringType)))
|
||||
widenTestWithStringPromotion(
|
||||
new StructType().add("a", StringType),
|
||||
new StructType().add("a", IntegerType),
|
||||
Some(new StructType().add("a", StringType)))
|
||||
}
|
||||
|
||||
private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) {
|
||||
|
|
Loading…
Reference in a new issue