[SPARK-12201][SQL] add type coercion rule for greatest/least
Checked against Hive: greatest/least should cast their children to the tightest common type, e.g. `(int, long) => long`, `(int, string) => error`, `(decimal(10,5), decimal(5, 10)) => error`. Author: Wenchen Fan <wenchen@databricks.com>. Closes #10196 from cloud-fan/type-coercion.
This commit is contained in:
parent
75c60bf4ba
commit
381f17b540
|
@ -594,6 +594,20 @@ object HiveTypeCoercion {
|
|||
case None => c
|
||||
}
|
||||
|
||||
// Greatest whose children have more than one distinct data type: try to unify them
// through findTightestCommonType.
case g @ Greatest(children) if children.map(_.dataType).distinct.size > 1 =>
|
||||
val types = children.map(_.dataType)
|
||||
findTightestCommonType(types) match {
|
||||
// A common type exists: rebuild Greatest with every child cast to that type.
case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType)))
|
||||
// No common type: return the expression unchanged.
// NOTE(review): presumably a later type check reports the mismatch -- confirm.
case None => g
|
||||
}
|
||||
|
||||
// Least whose children have more than one distinct data type: try to unify them
// through findTightestCommonType.
case l @ Least(children) if children.map(_.dataType).distinct.size > 1 =>
|
||||
val types = children.map(_.dataType)
|
||||
findTightestCommonType(types) match {
|
||||
// A common type exists: rebuild Least with every child cast to that type.
case Some(finalDataType) => Least(children.map(Cast(_, finalDataType)))
|
||||
// No common type: return the expression unchanged.
// NOTE(review): presumably a later type check reports the mismatch -- confirm.
case None => l
|
||||
}
|
||||
|
||||
// NaNvl with a DoubleType first argument and a FloatType second argument:
// cast the second argument to DoubleType.
case NaNvl(l, r) if l.dataType == DoubleType && r.dataType == FloatType =>
|
||||
NaNvl(l, Cast(r, DoubleType))
|
||||
case NaNvl(l, r) if l.dataType == FloatType && r.dataType == DoubleType =>
|
||||
|
|
|
@ -32,6 +32,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
|
|||
'intField.int,
|
||||
'stringField.string,
|
||||
'booleanField.boolean,
|
||||
'decimalField.decimal(8, 0),
|
||||
'arrayField.array(StringType),
|
||||
'mapField.map(StringType, LongType))
|
||||
|
||||
|
@ -189,4 +190,13 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
|
|||
assertError(Round('intField, 'mapField), "requires int type")
|
||||
assertError(Round('booleanField, 'intField), "requires numeric type")
|
||||
}
|
||||
|
||||
// Checks the error messages expected from Greatest and Least for invalid argument lists.
test("check types for Greatest/Least") {
|
||||
// Run the same assertions against both constructors, each used as a
// Seq[Expression] => Expression function.
for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) {
|
||||
// A single-element argument list: expected message mentions the 2-argument minimum.
assertError(operator(Seq('booleanField)), "requires at least 2 arguments")
|
||||
// Children of differing types (int with string, int with decimal): expected message
// says all children should have the same type.
assertError(operator(Seq('intField, 'stringField)), "should all have the same type")
|
||||
assertError(operator(Seq('intField, 'decimalField)), "should all have the same type")
|
||||
// Two map-typed children: expected message says ordering is not supported.
assertError(operator(Seq('mapField, 'mapField)), "does not support ordering")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -251,6 +251,29 @@ class HiveTypeCoercionSuite extends PlanTest {
|
|||
:: Nil))
|
||||
}
|
||||
|
||||
// Checks that HiveTypeCoercion.FunctionArgumentConversion wraps every child of
// Greatest/Least in a Cast to one shared type.
test("greatest/least cast") {
|
||||
// Run the same rule tests against both constructors, each used as a
// Seq[Expression] => Expression function.
for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) {
|
||||
// Inputs: Literal(1.0), Literal(1) and a FloatType literal.
// Expected: every child cast to DoubleType.
ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
|
||||
operator(Literal(1.0)
|
||||
:: Literal(1)
|
||||
:: Literal.create(1.0, FloatType)
|
||||
:: Nil),
|
||||
operator(Cast(Literal(1.0), DoubleType)
|
||||
:: Cast(Literal(1), DoubleType)
|
||||
:: Cast(Literal.create(1.0, FloatType), DoubleType)
|
||||
:: Nil))
|
||||
// Inputs: Literal(1L), Literal(1) and a large java.math.BigDecimal literal.
// Expected: every child cast to DecimalType(22, 0).
ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
|
||||
operator(Literal(1L)
|
||||
:: Literal(1)
|
||||
:: Literal(new java.math.BigDecimal("1000000000000000000000"))
|
||||
:: Nil),
|
||||
operator(Cast(Literal(1L), DecimalType(22, 0))
|
||||
:: Cast(Literal(1), DecimalType(22, 0))
|
||||
:: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0))
|
||||
:: Nil))
|
||||
}
|
||||
}
|
||||
|
||||
test("nanvl casts") {
|
||||
ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
|
||||
NaNvl(Literal.create(1.0, FloatType), Literal.create(1.0, DoubleType)),
|
||||
|
|
Loading…
Reference in a new issue