[SPARK-19435][SQL] Type coercion between ArrayTypes

## What changes were proposed in this pull request?

This PR proposes to support type coercion between `ArrayType`s where the element types are compatible.

**Before**

```
Seq(Array(1)).toDF("a").selectExpr("greatest(a, array(1D))")
org.apache.spark.sql.AnalysisException: cannot resolve 'greatest(`a`, array(1.0D))' due to data type mismatch: The expressions should all have the same type, got GREATEST(array<int>, array<double>).; line 1 pos 0;

Seq(Array(1)).toDF("a").selectExpr("least(a, array(1D))")
org.apache.spark.sql.AnalysisException: cannot resolve 'least(`a`, array(1.0D))' due to data type mismatch: The expressions should all have the same type, got LEAST(array<int>, array<double>).; line 1 pos 0;

sql("SELECT * FROM values (array(0)), (array(1D)) as data(a)")
org.apache.spark.sql.AnalysisException: incompatible types found in column a for inline table; line 1 pos 14

Seq(Array(1)).toDF("a").union(Seq(Array(1D)).toDF("b"))
org.apache.spark.sql.AnalysisException: Union can only be performed on tables with the compatible column types. ArrayType(DoubleType,false) <> ArrayType(IntegerType,false) at the first column of the second table;;

sql("SELECT IF(1=1, array(1), array(1D))")
org.apache.spark.sql.AnalysisException: cannot resolve '(IF((1 = 1), array(1), array(1.0D)))' due to data type mismatch: differing types in '(IF((1 = 1), array(1), array(1.0D)))' (array<int> and array<double>).; line 1 pos 7;
```

**After**

```scala
Seq(Array(1)).toDF("a").selectExpr("greatest(a, array(1D))")
res5: org.apache.spark.sql.DataFrame = [greatest(a, array(1.0)): array<double>]

Seq(Array(1)).toDF("a").selectExpr("least(a, array(1D))")
res6: org.apache.spark.sql.DataFrame = [least(a, array(1.0)): array<double>]

sql("SELECT * FROM values (array(0)), (array(1D)) as data(a)")
res8: org.apache.spark.sql.DataFrame = [a: array<double>]

Seq(Array(1)).toDF("a").union(Seq(Array(1D)).toDF("b"))
res10: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [a: array<double>]

sql("SELECT IF(1=1, array(1), array(1D))")
res15: org.apache.spark.sql.DataFrame = [(IF((1 = 1), array(1), array(1.0))): array<double>]
```

## How was this patch tested?

Unit tests in `TypeCoercion` and Jenkins tests and

building with scala 2.10

```scala
./dev/change-scala-version.sh 2.10
./build/mvn -Pyarn -Phadoop-2.4 -Dscala-2.10 -DskipTests clean package
```

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #16777 from HyukjinKwon/SPARK-19435.
This commit is contained in:
hyukjinkwon 2017-02-13 13:10:57 -08:00 committed by Xiao Li
parent 905fdf0c24
commit 9af8f743b0
2 changed files with 127 additions and 50 deletions

View file

@ -101,13 +101,11 @@ object TypeCoercion {
case _ => None
}
/** Similar to [[findTightestCommonType]], but can promote all the way to StringType. */
def findTightestCommonTypeToString(left: DataType, right: DataType): Option[DataType] = {
findTightestCommonType(left, right).orElse((left, right) match {
case (StringType, t2: AtomicType) if t2 != BinaryType && t2 != BooleanType => Some(StringType)
case (t1: AtomicType, StringType) if t1 != BinaryType && t1 != BooleanType => Some(StringType)
case _ => None
})
/** Promotes all the way to StringType. */
private def stringPromotion(dt1: DataType, dt2: DataType): Option[DataType] = (dt1, dt2) match {
case (StringType, t2: AtomicType) if t2 != BinaryType && t2 != BooleanType => Some(StringType)
case (t1: AtomicType, StringType) if t1 != BinaryType && t1 != BooleanType => Some(StringType)
case _ => None
}
/**
@ -117,7 +115,55 @@ object TypeCoercion {
* loss of precision when widening decimal and double, and promotion to string.
*/
private[analysis] def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = {
(t1, t2) match {
findTightestCommonType(t1, t2)
.orElse(findWiderTypeForDecimal(t1, t2))
.orElse(stringPromotion(t1, t2))
.orElse((t1, t2) match {
case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) =>
findWiderTypeForTwo(et1, et2).map(ArrayType(_, containsNull1 || containsNull2))
case _ => None
})
}
private def findWiderCommonType(types: Seq[DataType]): Option[DataType] = {
types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
case Some(d) => findWiderTypeForTwo(d, c)
case None => None
})
}
/**
* Similar to [[findWiderTypeForTwo]] that can handle decimal types, but can't promote to
* string. If the wider decimal type exceeds system limitation, this rule will truncate
* the decimal type before return it.
*/
private[analysis] def findWiderTypeWithoutStringPromotionForTwo(
t1: DataType,
t2: DataType): Option[DataType] = {
findTightestCommonType(t1, t2)
.orElse(findWiderTypeForDecimal(t1, t2))
.orElse((t1, t2) match {
case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) =>
findWiderTypeWithoutStringPromotionForTwo(et1, et2)
.map(ArrayType(_, containsNull1 || containsNull2))
case _ => None
})
}
def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = {
types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
case Some(d) => findWiderTypeWithoutStringPromotionForTwo(d, c)
case None => None
})
}
/**
* Finds a wider type when one or both types are decimals. If the wider decimal type exceeds
* system limitation, this rule will truncate the decimal type. If a decimal and other fractional
* types are compared, returns a double type.
*/
private def findWiderTypeForDecimal(dt1: DataType, dt2: DataType): Option[DataType] = {
(dt1, dt2) match {
case (t1: DecimalType, t2: DecimalType) =>
Some(DecimalPrecision.widerDecimalType(t1, t2))
case (t: IntegralType, d: DecimalType) =>
@ -126,40 +172,10 @@ object TypeCoercion {
Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
case (_: FractionalType, _: DecimalType) | (_: DecimalType, _: FractionalType) =>
Some(DoubleType)
case _ =>
findTightestCommonTypeToString(t1, t2)
case _ => None
}
}
private def findWiderCommonType(types: Seq[DataType]) = {
types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
case Some(d) => findWiderTypeForTwo(d, c)
case None => None
})
}
/**
* Similar to [[findWiderCommonType]] that can handle decimal types, but can't promote to
* string. If the wider decimal type exceeds system limitation, this rule will truncate
* the decimal type before return it.
*/
def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = {
types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
case Some(d) => findTightestCommonType(d, c).orElse((d, c) match {
case (t1: DecimalType, t2: DecimalType) =>
Some(DecimalPrecision.widerDecimalType(t1, t2))
case (t: IntegralType, d: DecimalType) =>
Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
case (d: DecimalType, t: IntegralType) =>
Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
case (_: FractionalType, _: DecimalType) | (_: DecimalType, _: FractionalType) =>
Some(DoubleType)
case _ => None
})
case None => None
})
}
private def haveSameType(exprs: Seq[Expression]): Boolean =
exprs.map(_.dataType).distinct.length == 1

View file

@ -53,7 +53,8 @@ class TypeCoercionSuite extends PlanTest {
// | NullType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | BinaryType | BooleanType | StringType | DateType | TimestampType | ArrayType | MapType | StructType | NullType | CalendarIntervalType | DecimalType(38, 18) | DoubleType | IntegerType |
// | CalendarIntervalType | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | CalendarIntervalType | X | X | X |
// +----------------------+----------+-----------+-------------+----------+------------+-----------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+
// Note: ArrayType*, MapType*, StructType* are castable only when the internal child types also match; otherwise, not castable
// Note: MapType*, StructType* are castable only when the internal child types also match; otherwise, not castable.
// Note: ArrayType* is castable when the element type is castable according to the table.
// scalastyle:on line.size.limit
private def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = {
@ -125,6 +126,20 @@ class TypeCoercionSuite extends PlanTest {
}
}
private def checkWidenType(
widenFunc: (DataType, DataType) => Option[DataType],
t1: DataType,
t2: DataType,
expected: Option[DataType]): Unit = {
var found = widenFunc(t1, t2)
assert(found == expected,
s"Expected $expected as wider common type for $t1 and $t2, found $found")
// Test both directions to make sure the widening is symmetric.
found = widenFunc(t2, t1)
assert(found == expected,
s"Expected $expected as wider common type for $t2 and $t1, found $found")
}
test("implicit type cast - ByteType") {
val checkedType = ByteType
checkTypeCasting(checkedType, castableTypes = numericTypes ++ Seq(StringType))
@ -308,15 +323,8 @@ class TypeCoercionSuite extends PlanTest {
}
test("tightest common bound for types") {
def widenTest(t1: DataType, t2: DataType, tightestCommon: Option[DataType]) {
var found = TypeCoercion.findTightestCommonType(t1, t2)
assert(found == tightestCommon,
s"Expected $tightestCommon as tightest common type for $t1 and $t2, found $found")
// Test both directions to make sure the widening is symmetric.
found = TypeCoercion.findTightestCommonType(t2, t1)
assert(found == tightestCommon,
s"Expected $tightestCommon as tightest common type for $t2 and $t1, found $found")
}
def widenTest(t1: DataType, t2: DataType, expected: Option[DataType]): Unit =
checkWidenType(TypeCoercion.findTightestCommonType, t1, t2, expected)
// Null
widenTest(NullType, NullType, Some(NullType))
@ -355,7 +363,6 @@ class TypeCoercionSuite extends PlanTest {
widenTest(DecimalType(2, 1), DoubleType, None)
widenTest(DecimalType(2, 1), IntegerType, None)
widenTest(DoubleType, DecimalType(2, 1), None)
widenTest(IntegerType, DecimalType(2, 1), None)
// StringType
widenTest(NullType, StringType, Some(StringType))
@ -379,6 +386,60 @@ class TypeCoercionSuite extends PlanTest {
widenTest(ArrayType(IntegerType), StructType(Seq()), None)
}
test("wider common type for decimal and array") {
def widenTestWithStringPromotion(
t1: DataType,
t2: DataType,
expected: Option[DataType]): Unit = {
checkWidenType(TypeCoercion.findWiderTypeForTwo, t1, t2, expected)
}
def widenTestWithoutStringPromotion(
t1: DataType,
t2: DataType,
expected: Option[DataType]): Unit = {
checkWidenType(TypeCoercion.findWiderTypeWithoutStringPromotionForTwo, t1, t2, expected)
}
// Decimal
widenTestWithStringPromotion(
DecimalType(2, 1), DecimalType(3, 2), Some(DecimalType(3, 2)))
widenTestWithStringPromotion(
DecimalType(2, 1), DoubleType, Some(DoubleType))
widenTestWithStringPromotion(
DecimalType(2, 1), IntegerType, Some(DecimalType(11, 1)))
widenTestWithStringPromotion(
DecimalType(2, 1), LongType, Some(DecimalType(21, 1)))
// ArrayType
widenTestWithStringPromotion(
ArrayType(ShortType, containsNull = true),
ArrayType(DoubleType, containsNull = false),
Some(ArrayType(DoubleType, containsNull = true)))
widenTestWithStringPromotion(
ArrayType(TimestampType, containsNull = false),
ArrayType(StringType, containsNull = true),
Some(ArrayType(StringType, containsNull = true)))
widenTestWithStringPromotion(
ArrayType(ArrayType(IntegerType), containsNull = false),
ArrayType(ArrayType(LongType), containsNull = false),
Some(ArrayType(ArrayType(LongType), containsNull = false)))
// Without string promotion
widenTestWithoutStringPromotion(IntegerType, StringType, None)
widenTestWithoutStringPromotion(StringType, TimestampType, None)
widenTestWithoutStringPromotion(ArrayType(LongType), ArrayType(StringType), None)
widenTestWithoutStringPromotion(ArrayType(StringType), ArrayType(TimestampType), None)
// String promotion
widenTestWithStringPromotion(IntegerType, StringType, Some(StringType))
widenTestWithStringPromotion(StringType, TimestampType, Some(StringType))
widenTestWithStringPromotion(
ArrayType(LongType), ArrayType(StringType), Some(ArrayType(StringType)))
widenTestWithStringPromotion(
ArrayType(StringType), ArrayType(TimestampType), Some(ArrayType(StringType)))
}
private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) {
ruleTest(Seq(rule), initial, transformed)
}