[SPARK-16714][SPARK-16735][SPARK-16646] array, map, greatest, least's type coercion should handle decimal type
## What changes were proposed in this pull request? Here is a table about the behaviours of `array`/`map` and `greatest`/`least` in Hive, MySQL and Postgres: | |Hive|MySQL|Postgres| |---|---|---|---|---| |`array`/`map`|can find a wider type with decimal type arguments, and will truncate the wider decimal type if necessary|can find a wider type with decimal type arguments, no truncation problem|can find a wider type with decimal type arguments, no truncation problem| |`greatest`/`least`|can find a wider type with decimal type arguments, and truncate if necessary, but can't do string promotion|can find a wider type with decimal type arguments, no truncation problem, but can't do string promotion|can find a wider type with decimal type arguments, no truncation problem, but can't do string promotion| I think these behaviours makes sense and Spark SQL should follow them. This PR fixes `array` and `map` by using `findWiderCommonType` to get the wider type. This PR fixes `greatest` and `least` by add a `findWiderTypeWithoutStringPromotion`, which provides similar semantic of `findWiderCommonType`, but without string promotion. ## How was this patch tested? new tests in `TypeCoersionSuite` Author: Wenchen Fan <wenchen@databricks.com> Author: Yin Huai <yhuai@databricks.com> Closes #14439 from cloud-fan/bug.
This commit is contained in:
parent
639df046a2
commit
b55f34370f
|
@ -108,18 +108,6 @@ object TypeCoercion {
|
|||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Similar to [[findTightestCommonType]], if can not find the TightestCommonType, try to use
|
||||
* [[findTightestCommonTypeToString]] to find the TightestCommonType.
|
||||
*/
|
||||
private def findTightestCommonTypeAndPromoteToString(types: Seq[DataType]): Option[DataType] = {
|
||||
types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
|
||||
case None => None
|
||||
case Some(d) =>
|
||||
findTightestCommonTypeToString(d, c)
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Find the tightest common type of a set of types by continuously applying
|
||||
* `findTightestCommonTypeOfTwo` on these types.
|
||||
|
@ -157,6 +145,28 @@ object TypeCoercion {
|
|||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Similar to [[findWiderCommonType]], but can't promote to string. This is also similar to
|
||||
* [[findTightestCommonType]], but can handle decimal types. If the wider decimal type exceeds
|
||||
* system limitation, this rule will truncate the decimal type before return it.
|
||||
*/
|
||||
private def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = {
|
||||
types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
|
||||
case Some(d) => findTightestCommonTypeOfTwo(d, c).orElse((d, c) match {
|
||||
case (t1: DecimalType, t2: DecimalType) =>
|
||||
Some(DecimalPrecision.widerDecimalType(t1, t2))
|
||||
case (t: IntegralType, d: DecimalType) =>
|
||||
Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
|
||||
case (d: DecimalType, t: IntegralType) =>
|
||||
Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
|
||||
case (_: FractionalType, _: DecimalType) | (_: DecimalType, _: FractionalType) =>
|
||||
Some(DoubleType)
|
||||
case _ => None
|
||||
})
|
||||
case None => None
|
||||
})
|
||||
}
|
||||
|
||||
private def haveSameType(exprs: Seq[Expression]): Boolean =
|
||||
exprs.map(_.dataType).distinct.length == 1
|
||||
|
||||
|
@ -440,7 +450,7 @@ object TypeCoercion {
|
|||
|
||||
case a @ CreateArray(children) if !haveSameType(children) =>
|
||||
val types = children.map(_.dataType)
|
||||
findTightestCommonTypeAndPromoteToString(types) match {
|
||||
findWiderCommonType(types) match {
|
||||
case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType)))
|
||||
case None => a
|
||||
}
|
||||
|
@ -451,7 +461,7 @@ object TypeCoercion {
|
|||
m.keys
|
||||
} else {
|
||||
val types = m.keys.map(_.dataType)
|
||||
findTightestCommonTypeAndPromoteToString(types) match {
|
||||
findWiderCommonType(types) match {
|
||||
case Some(finalDataType) => m.keys.map(Cast(_, finalDataType))
|
||||
case None => m.keys
|
||||
}
|
||||
|
@ -461,7 +471,7 @@ object TypeCoercion {
|
|||
m.values
|
||||
} else {
|
||||
val types = m.values.map(_.dataType)
|
||||
findTightestCommonTypeAndPromoteToString(types) match {
|
||||
findWiderCommonType(types) match {
|
||||
case Some(finalDataType) => m.values.map(Cast(_, finalDataType))
|
||||
case None => m.values
|
||||
}
|
||||
|
@ -494,16 +504,19 @@ object TypeCoercion {
|
|||
case None => c
|
||||
}
|
||||
|
||||
// When finding wider type for `Greatest` and `Least`, we should handle decimal types even if
|
||||
// we need to truncate, but we should not promote one side to string if the other side is
|
||||
// string.g
|
||||
case g @ Greatest(children) if !haveSameType(children) =>
|
||||
val types = children.map(_.dataType)
|
||||
findTightestCommonType(types) match {
|
||||
findWiderTypeWithoutStringPromotion(types) match {
|
||||
case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType)))
|
||||
case None => g
|
||||
}
|
||||
|
||||
case l @ Least(children) if !haveSameType(children) =>
|
||||
val types = children.map(_.dataType)
|
||||
findTightestCommonType(types) match {
|
||||
findWiderTypeWithoutStringPromotion(types) match {
|
||||
case Some(finalDataType) => Least(children.map(Cast(_, finalDataType)))
|
||||
case None => l
|
||||
}
|
||||
|
|
|
@ -209,7 +209,6 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
|
|||
for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) {
|
||||
assertError(operator(Seq('booleanField)), "requires at least 2 arguments")
|
||||
assertError(operator(Seq('intField, 'stringField)), "should all have the same type")
|
||||
assertError(operator(Seq('intField, 'decimalField)), "should all have the same type")
|
||||
assertError(operator(Seq('mapField, 'mapField)), "does not support ordering")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -283,6 +283,24 @@ class TypeCoercionSuite extends PlanTest {
|
|||
:: Cast(Literal(1), StringType)
|
||||
:: Cast(Literal("a"), StringType)
|
||||
:: Nil))
|
||||
|
||||
ruleTest(TypeCoercion.FunctionArgumentConversion,
|
||||
CreateArray(Literal.create(null, DecimalType(5, 3))
|
||||
:: Literal(1)
|
||||
:: Nil),
|
||||
CreateArray(Literal.create(null, DecimalType(5, 3)).cast(DecimalType(13, 3))
|
||||
:: Literal(1).cast(DecimalType(13, 3))
|
||||
:: Nil))
|
||||
|
||||
ruleTest(TypeCoercion.FunctionArgumentConversion,
|
||||
CreateArray(Literal.create(null, DecimalType(5, 3))
|
||||
:: Literal.create(null, DecimalType(22, 10))
|
||||
:: Literal.create(null, DecimalType(38, 38))
|
||||
:: Nil),
|
||||
CreateArray(Literal.create(null, DecimalType(5, 3)).cast(DecimalType(38, 38))
|
||||
:: Literal.create(null, DecimalType(22, 10)).cast(DecimalType(38, 38))
|
||||
:: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 38))
|
||||
:: Nil))
|
||||
}
|
||||
|
||||
test("CreateMap casts") {
|
||||
|
@ -298,6 +316,17 @@ class TypeCoercionSuite extends PlanTest {
|
|||
:: Cast(Literal.create(2.0, FloatType), FloatType)
|
||||
:: Literal("b")
|
||||
:: Nil))
|
||||
ruleTest(TypeCoercion.FunctionArgumentConversion,
|
||||
CreateMap(Literal.create(null, DecimalType(5, 3))
|
||||
:: Literal("a")
|
||||
:: Literal.create(2.0, FloatType)
|
||||
:: Literal("b")
|
||||
:: Nil),
|
||||
CreateMap(Literal.create(null, DecimalType(5, 3)).cast(DoubleType)
|
||||
:: Literal("a")
|
||||
:: Literal.create(2.0, FloatType).cast(DoubleType)
|
||||
:: Literal("b")
|
||||
:: Nil))
|
||||
// type coercion for map values
|
||||
ruleTest(TypeCoercion.FunctionArgumentConversion,
|
||||
CreateMap(Literal(1)
|
||||
|
@ -310,6 +339,17 @@ class TypeCoercionSuite extends PlanTest {
|
|||
:: Literal(2)
|
||||
:: Cast(Literal(3.0), StringType)
|
||||
:: Nil))
|
||||
ruleTest(TypeCoercion.FunctionArgumentConversion,
|
||||
CreateMap(Literal(1)
|
||||
:: Literal.create(null, DecimalType(38, 0))
|
||||
:: Literal(2)
|
||||
:: Literal.create(null, DecimalType(38, 38))
|
||||
:: Nil),
|
||||
CreateMap(Literal(1)
|
||||
:: Literal.create(null, DecimalType(38, 0)).cast(DecimalType(38, 38))
|
||||
:: Literal(2)
|
||||
:: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 38))
|
||||
:: Nil))
|
||||
// type coercion for both map keys and values
|
||||
ruleTest(TypeCoercion.FunctionArgumentConversion,
|
||||
CreateMap(Literal(1)
|
||||
|
@ -344,6 +384,33 @@ class TypeCoercionSuite extends PlanTest {
|
|||
:: Cast(Literal(1), DecimalType(22, 0))
|
||||
:: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0))
|
||||
:: Nil))
|
||||
ruleTest(TypeCoercion.FunctionArgumentConversion,
|
||||
operator(Literal(1.0)
|
||||
:: Literal.create(null, DecimalType(10, 5))
|
||||
:: Literal(1)
|
||||
:: Nil),
|
||||
operator(Literal(1.0).cast(DoubleType)
|
||||
:: Literal.create(null, DecimalType(10, 5)).cast(DoubleType)
|
||||
:: Literal(1).cast(DoubleType)
|
||||
:: Nil))
|
||||
ruleTest(TypeCoercion.FunctionArgumentConversion,
|
||||
operator(Literal.create(null, DecimalType(15, 0))
|
||||
:: Literal.create(null, DecimalType(10, 5))
|
||||
:: Literal(1)
|
||||
:: Nil),
|
||||
operator(Literal.create(null, DecimalType(15, 0)).cast(DecimalType(20, 5))
|
||||
:: Literal.create(null, DecimalType(10, 5)).cast(DecimalType(20, 5))
|
||||
:: Literal(1).cast(DecimalType(20, 5))
|
||||
:: Nil))
|
||||
ruleTest(TypeCoercion.FunctionArgumentConversion,
|
||||
operator(Literal.create(2L, LongType)
|
||||
:: Literal(1)
|
||||
:: Literal.create(null, DecimalType(10, 5))
|
||||
:: Nil),
|
||||
operator(Literal.create(2L, LongType).cast(DecimalType(25, 5))
|
||||
:: Literal(1).cast(DecimalType(25, 5))
|
||||
:: Literal.create(null, DecimalType(10, 5)).cast(DecimalType(25, 5))
|
||||
:: Nil))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue