[SPARK-24737][SQL] Type coercion between StructTypes.

## What changes were proposed in this pull request?

We can support type coercion between `StructType`s where all the internal types are compatible.

## How was this patch tested?

Added tests.

Author: Takuya UESHIN <ueshin@databricks.com>

Closes #21713 from ueshin/issues/SPARK-24737/structtypecoercion.
This commit is contained in:
Takuya UESHIN 2018-07-06 11:10:50 +08:00 committed by hyukjinkwon
parent e71e93aaaa
commit 01fcba2c68
2 changed files with 114 additions and 48 deletions

View file

@ -102,25 +102,7 @@ object TypeCoercion {
case (_: TimestampType, _: DateType) | (_: DateType, _: TimestampType) =>
Some(TimestampType)
case (t1 @ StructType(fields1), t2 @ StructType(fields2)) if t1.sameType(t2) =>
Some(StructType(fields1.zip(fields2).map { case (f1, f2) =>
// Since `t1.sameType(t2)` is true, two StructTypes have the same DataType
// except `name` (in case of `spark.sql.caseSensitive=false`) and `nullable`.
// - Different names: use f1.name
// - Different nullabilities: `nullable` is true iff one of them is nullable.
val dataType = findTightestCommonType(f1.dataType, f2.dataType).get
StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable)
}))
case (a1 @ ArrayType(et1, hasNull1), a2 @ ArrayType(et2, hasNull2)) if a1.sameType(a2) =>
findTightestCommonType(et1, et2).map(ArrayType(_, hasNull1 || hasNull2))
case (m1 @ MapType(kt1, vt1, hasNull1), m2 @ MapType(kt2, vt2, hasNull2)) if m1.sameType(m2) =>
val keyType = findTightestCommonType(kt1, kt2)
val valueType = findTightestCommonType(vt1, vt2)
Some(MapType(keyType.get, valueType.get, hasNull1 || hasNull2))
case _ => None
case (t1, t2) => findTypeForComplex(t1, t2, findTightestCommonType)
}
/** Promotes all the way to StringType. */
@ -166,6 +148,30 @@ object TypeCoercion {
case (l, r) => None
}
private def findTypeForComplex(
t1: DataType,
t2: DataType,
findTypeFunc: (DataType, DataType) => Option[DataType]): Option[DataType] = (t1, t2) match {
case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) =>
findTypeFunc(et1, et2).map(ArrayType(_, containsNull1 || containsNull2))
case (MapType(kt1, vt1, valueContainsNull1), MapType(kt2, vt2, valueContainsNull2)) =>
findTypeFunc(kt1, kt2).flatMap { kt =>
findTypeFunc(vt1, vt2).map { vt =>
MapType(kt, vt, valueContainsNull1 || valueContainsNull2)
}
}
case (StructType(fields1), StructType(fields2)) if fields1.length == fields2.length =>
val resolver = SQLConf.get.resolver
fields1.zip(fields2).foldLeft(Option(new StructType())) {
case (Some(struct), (field1, field2)) if resolver(field1.name, field2.name) =>
findTypeFunc(field1.dataType, field2.dataType).map {
dt => struct.add(field1.name, dt, field1.nullable || field2.nullable)
}
case _ => None
}
case _ => None
}
/**
* Case 2 type widening (see the classdoc comment above for TypeCoercion).
*
@ -176,17 +182,7 @@ object TypeCoercion {
findTightestCommonType(t1, t2)
.orElse(findWiderTypeForDecimal(t1, t2))
.orElse(stringPromotion(t1, t2))
.orElse((t1, t2) match {
case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) =>
findWiderTypeForTwo(et1, et2).map(ArrayType(_, containsNull1 || containsNull2))
case (MapType(kt1, vt1, valueContainsNull1), MapType(kt2, vt2, valueContainsNull2)) =>
findWiderTypeForTwo(kt1, kt2).flatMap { kt =>
findWiderTypeForTwo(vt1, vt2).map { vt =>
MapType(kt, vt, valueContainsNull1 || valueContainsNull2)
}
}
case _ => None
})
.orElse(findTypeForComplex(t1, t2, findWiderTypeForTwo))
}
/**
@ -222,18 +218,7 @@ object TypeCoercion {
t2: DataType): Option[DataType] = {
findTightestCommonType(t1, t2)
.orElse(findWiderTypeForDecimal(t1, t2))
.orElse((t1, t2) match {
case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) =>
findWiderTypeWithoutStringPromotionForTwo(et1, et2)
.map(ArrayType(_, containsNull1 || containsNull2))
case (MapType(kt1, vt1, valueContainsNull1), MapType(kt2, vt2, valueContainsNull2)) =>
findWiderTypeWithoutStringPromotionForTwo(kt1, kt2).flatMap { kt =>
findWiderTypeWithoutStringPromotionForTwo(vt1, vt2).map { vt =>
MapType(kt, vt, valueContainsNull1 || valueContainsNull2)
}
}
case _ => None
})
.orElse(findTypeForComplex(t1, t2, findWiderTypeWithoutStringPromotionForTwo))
}
def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = {

View file

@ -54,7 +54,7 @@ class TypeCoercionSuite extends AnalysisTest {
// | NullType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | BinaryType | BooleanType | StringType | DateType | TimestampType | ArrayType | MapType | StructType | NullType | CalendarIntervalType | DecimalType(38, 18) | DoubleType | IntegerType |
// | CalendarIntervalType | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | CalendarIntervalType | X | X | X |
// +----------------------+----------+-----------+-------------+----------+------------+-----------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+
// Note: StructType* is castable only when the internal child types also match; otherwise, not castable.
// Note: StructType* is castable when all the internal child types are castable according to the table.
// Note: ArrayType* is castable when the element type is castable according to the table.
// Note: MapType* is castable when both the key type and the value type are castable according to the table.
// scalastyle:on line.size.limit
@ -397,7 +397,7 @@ class TypeCoercionSuite extends AnalysisTest {
widenTest(
StructType(Seq(StructField("a", IntegerType, nullable = false))),
StructType(Seq(StructField("a", DoubleType, nullable = false))),
None)
Some(StructType(Seq(StructField("a", DoubleType, nullable = false)))))
widenTest(
StructType(Seq(StructField("a", IntegerType, nullable = false))),
@ -454,15 +454,18 @@ class TypeCoercionSuite extends AnalysisTest {
def widenTestWithStringPromotion(
t1: DataType,
t2: DataType,
expected: Option[DataType]): Unit = {
checkWidenType(TypeCoercion.findWiderTypeForTwo, t1, t2, expected)
expected: Option[DataType],
isSymmetric: Boolean = true): Unit = {
checkWidenType(TypeCoercion.findWiderTypeForTwo, t1, t2, expected, isSymmetric)
}
def widenTestWithoutStringPromotion(
t1: DataType,
t2: DataType,
expected: Option[DataType]): Unit = {
checkWidenType(TypeCoercion.findWiderTypeWithoutStringPromotionForTwo, t1, t2, expected)
expected: Option[DataType],
isSymmetric: Boolean = true): Unit = {
checkWidenType(
TypeCoercion.findWiderTypeWithoutStringPromotionForTwo, t1, t2, expected, isSymmetric)
}
// Decimal
@ -492,6 +495,10 @@ class TypeCoercionSuite extends AnalysisTest {
ArrayType(MapType(IntegerType, FloatType), containsNull = false),
ArrayType(MapType(LongType, DoubleType), containsNull = false),
Some(ArrayType(MapType(LongType, DoubleType), containsNull = false)))
widenTestWithStringPromotion(
ArrayType(new StructType().add("num", ShortType), containsNull = false),
ArrayType(new StructType().add("num", LongType), containsNull = false),
Some(ArrayType(new StructType().add("num", LongType), containsNull = false)))
// MapType
widenTestWithStringPromotion(
@ -506,6 +513,64 @@ class TypeCoercionSuite extends AnalysisTest {
MapType(IntegerType, MapType(ShortType, TimestampType), valueContainsNull = false),
MapType(LongType, MapType(DoubleType, StringType), valueContainsNull = false),
Some(MapType(LongType, MapType(DoubleType, StringType), valueContainsNull = false)))
widenTestWithStringPromotion(
MapType(IntegerType, new StructType().add("num", ShortType), valueContainsNull = false),
MapType(LongType, new StructType().add("num", LongType), valueContainsNull = false),
Some(MapType(LongType, new StructType().add("num", LongType), valueContainsNull = false)))
// StructType
widenTestWithStringPromotion(
new StructType()
.add("num", ShortType, nullable = true).add("ts", StringType, nullable = false),
new StructType()
.add("num", DoubleType, nullable = false).add("ts", TimestampType, nullable = true),
Some(new StructType()
.add("num", DoubleType, nullable = true).add("ts", StringType, nullable = true)))
widenTestWithStringPromotion(
new StructType()
.add("arr", ArrayType(ShortType, containsNull = false), nullable = false),
new StructType()
.add("arr", ArrayType(DoubleType, containsNull = true), nullable = false),
Some(new StructType()
.add("arr", ArrayType(DoubleType, containsNull = true), nullable = false)))
widenTestWithStringPromotion(
new StructType()
.add("map", MapType(ShortType, TimestampType, valueContainsNull = true), nullable = false),
new StructType()
.add("map", MapType(DoubleType, StringType, valueContainsNull = false), nullable = false),
Some(new StructType()
.add("map", MapType(DoubleType, StringType, valueContainsNull = true), nullable = false)))
widenTestWithStringPromotion(
new StructType().add("num", IntegerType),
new StructType().add("num", LongType).add("str", StringType),
None)
widenTestWithoutStringPromotion(
new StructType().add("num", IntegerType),
new StructType().add("num", LongType).add("str", StringType),
None)
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
widenTestWithStringPromotion(
new StructType().add("a", IntegerType),
new StructType().add("A", LongType),
None)
widenTestWithoutStringPromotion(
new StructType().add("a", IntegerType),
new StructType().add("A", LongType),
None)
}
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
widenTestWithStringPromotion(
new StructType().add("a", IntegerType),
new StructType().add("A", LongType),
Some(new StructType().add("a", LongType)),
isSymmetric = false)
widenTestWithoutStringPromotion(
new StructType().add("a", IntegerType),
new StructType().add("A", LongType),
Some(new StructType().add("a", LongType)),
isSymmetric = false)
}
// Without string promotion
widenTestWithoutStringPromotion(IntegerType, StringType, None)
@ -520,6 +585,14 @@ class TypeCoercionSuite extends AnalysisTest {
MapType(StringType, IntegerType), MapType(TimestampType, IntegerType), None)
widenTestWithoutStringPromotion(
MapType(IntegerType, StringType), MapType(IntegerType, TimestampType), None)
widenTestWithoutStringPromotion(
new StructType().add("a", IntegerType),
new StructType().add("a", StringType),
None)
widenTestWithoutStringPromotion(
new StructType().add("a", StringType),
new StructType().add("a", IntegerType),
None)
// String promotion
widenTestWithStringPromotion(IntegerType, StringType, Some(StringType))
@ -544,6 +617,14 @@ class TypeCoercionSuite extends AnalysisTest {
MapType(IntegerType, StringType),
MapType(IntegerType, TimestampType),
Some(MapType(IntegerType, StringType)))
widenTestWithStringPromotion(
new StructType().add("a", IntegerType),
new StructType().add("a", StringType),
Some(new StructType().add("a", StringType)))
widenTestWithStringPromotion(
new StructType().add("a", StringType),
new StructType().add("a", IntegerType),
Some(new StructType().add("a", StringType)))
}
private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) {