[SPARK-19435][SQL] Type coercion between ArrayTypes
## What changes were proposed in this pull request? This PR proposes to support type coercion between `ArrayType`s where the element types are compatible. **Before** ``` Seq(Array(1)).toDF("a").selectExpr("greatest(a, array(1D))") org.apache.spark.sql.AnalysisException: cannot resolve 'greatest(`a`, array(1.0D))' due to data type mismatch: The expressions should all have the same type, got GREATEST(array<int>, array<double>).; line 1 pos 0; Seq(Array(1)).toDF("a").selectExpr("least(a, array(1D))") org.apache.spark.sql.AnalysisException: cannot resolve 'least(`a`, array(1.0D))' due to data type mismatch: The expressions should all have the same type, got LEAST(array<int>, array<double>).; line 1 pos 0; sql("SELECT * FROM values (array(0)), (array(1D)) as data(a)") org.apache.spark.sql.AnalysisException: incompatible types found in column a for inline table; line 1 pos 14 Seq(Array(1)).toDF("a").union(Seq(Array(1D)).toDF("b")) org.apache.spark.sql.AnalysisException: Union can only be performed on tables with the compatible column types. ArrayType(DoubleType,false) <> ArrayType(IntegerType,false) at the first column of the second table;; sql("SELECT IF(1=1, array(1), array(1D))") org.apache.spark.sql.AnalysisException: cannot resolve '(IF((1 = 1), array(1), array(1.0D)))' due to data type mismatch: differing types in '(IF((1 = 1), array(1), array(1.0D)))' (array<int> and array<double>).; line 1 pos 7; ``` **After** ```scala Seq(Array(1)).toDF("a").selectExpr("greatest(a, array(1D))") res5: org.apache.spark.sql.DataFrame = [greatest(a, array(1.0)): array<double>] Seq(Array(1)).toDF("a").selectExpr("least(a, array(1D))") res6: org.apache.spark.sql.DataFrame = [least(a, array(1.0)): array<double>] sql("SELECT * FROM values (array(0)), (array(1D)) as data(a)") res8: org.apache.spark.sql.DataFrame = [a: array<double>] Seq(Array(1)).toDF("a").union(Seq(Array(1D)).toDF("b")) res10: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [a: array<double>] sql("SELECT IF(1=1, array(1), array(1D))") res15: org.apache.spark.sql.DataFrame = [(IF((1 = 1), array(1), array(1.0))): array<double>] ``` ## How was this patch tested? Unit tests in `TypeCoercion` and Jenkins tests and building with scala 2.10 ```scala ./dev/change-scala-version.sh 2.10 ./build/mvn -Pyarn -Phadoop-2.4 -Dscala-2.10 -DskipTests clean package ``` Author: hyukjinkwon <gurwls223@gmail.com> Closes #16777 from HyukjinKwon/SPARK-19435.
This commit is contained in:
parent
905fdf0c24
commit
9af8f743b0
|
@ -101,13 +101,11 @@ object TypeCoercion {
|
|||
case _ => None
|
||||
}
|
||||
|
||||
/** Similar to [[findTightestCommonType]], but can promote all the way to StringType. */
|
||||
def findTightestCommonTypeToString(left: DataType, right: DataType): Option[DataType] = {
|
||||
findTightestCommonType(left, right).orElse((left, right) match {
|
||||
case (StringType, t2: AtomicType) if t2 != BinaryType && t2 != BooleanType => Some(StringType)
|
||||
case (t1: AtomicType, StringType) if t1 != BinaryType && t1 != BooleanType => Some(StringType)
|
||||
case _ => None
|
||||
})
|
||||
/** Promotes all the way to StringType. */
|
||||
private def stringPromotion(dt1: DataType, dt2: DataType): Option[DataType] = (dt1, dt2) match {
|
||||
case (StringType, t2: AtomicType) if t2 != BinaryType && t2 != BooleanType => Some(StringType)
|
||||
case (t1: AtomicType, StringType) if t1 != BinaryType && t1 != BooleanType => Some(StringType)
|
||||
case _ => None
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -117,7 +115,55 @@ object TypeCoercion {
|
|||
* loss of precision when widening decimal and double, and promotion to string.
|
||||
*/
|
||||
private[analysis] def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = {
|
||||
(t1, t2) match {
|
||||
findTightestCommonType(t1, t2)
|
||||
.orElse(findWiderTypeForDecimal(t1, t2))
|
||||
.orElse(stringPromotion(t1, t2))
|
||||
.orElse((t1, t2) match {
|
||||
case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) =>
|
||||
findWiderTypeForTwo(et1, et2).map(ArrayType(_, containsNull1 || containsNull2))
|
||||
case _ => None
|
||||
})
|
||||
}
|
||||
|
||||
private def findWiderCommonType(types: Seq[DataType]): Option[DataType] = {
|
||||
types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
|
||||
case Some(d) => findWiderTypeForTwo(d, c)
|
||||
case None => None
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Similar to [[findWiderTypeForTwo]] that can handle decimal types, but can't promote to
|
||||
* string. If the wider decimal type exceeds system limitation, this rule will truncate
|
||||
* the decimal type before return it.
|
||||
*/
|
||||
private[analysis] def findWiderTypeWithoutStringPromotionForTwo(
|
||||
t1: DataType,
|
||||
t2: DataType): Option[DataType] = {
|
||||
findTightestCommonType(t1, t2)
|
||||
.orElse(findWiderTypeForDecimal(t1, t2))
|
||||
.orElse((t1, t2) match {
|
||||
case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) =>
|
||||
findWiderTypeWithoutStringPromotionForTwo(et1, et2)
|
||||
.map(ArrayType(_, containsNull1 || containsNull2))
|
||||
case _ => None
|
||||
})
|
||||
}
|
||||
|
||||
def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = {
|
||||
types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
|
||||
case Some(d) => findWiderTypeWithoutStringPromotionForTwo(d, c)
|
||||
case None => None
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Finds a wider type when one or both types are decimals. If the wider decimal type exceeds
|
||||
* system limitation, this rule will truncate the decimal type. If a decimal and other fractional
|
||||
* types are compared, returns a double type.
|
||||
*/
|
||||
private def findWiderTypeForDecimal(dt1: DataType, dt2: DataType): Option[DataType] = {
|
||||
(dt1, dt2) match {
|
||||
case (t1: DecimalType, t2: DecimalType) =>
|
||||
Some(DecimalPrecision.widerDecimalType(t1, t2))
|
||||
case (t: IntegralType, d: DecimalType) =>
|
||||
|
@ -126,40 +172,10 @@ object TypeCoercion {
|
|||
Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
|
||||
case (_: FractionalType, _: DecimalType) | (_: DecimalType, _: FractionalType) =>
|
||||
Some(DoubleType)
|
||||
case _ =>
|
||||
findTightestCommonTypeToString(t1, t2)
|
||||
case _ => None
|
||||
}
|
||||
}
|
||||
|
||||
private def findWiderCommonType(types: Seq[DataType]) = {
|
||||
types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
|
||||
case Some(d) => findWiderTypeForTwo(d, c)
|
||||
case None => None
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Similar to [[findWiderCommonType]] that can handle decimal types, but can't promote to
|
||||
* string. If the wider decimal type exceeds system limitation, this rule will truncate
|
||||
* the decimal type before return it.
|
||||
*/
|
||||
def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = {
|
||||
types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
|
||||
case Some(d) => findTightestCommonType(d, c).orElse((d, c) match {
|
||||
case (t1: DecimalType, t2: DecimalType) =>
|
||||
Some(DecimalPrecision.widerDecimalType(t1, t2))
|
||||
case (t: IntegralType, d: DecimalType) =>
|
||||
Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
|
||||
case (d: DecimalType, t: IntegralType) =>
|
||||
Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
|
||||
case (_: FractionalType, _: DecimalType) | (_: DecimalType, _: FractionalType) =>
|
||||
Some(DoubleType)
|
||||
case _ => None
|
||||
})
|
||||
case None => None
|
||||
})
|
||||
}
|
||||
|
||||
private def haveSameType(exprs: Seq[Expression]): Boolean =
|
||||
exprs.map(_.dataType).distinct.length == 1
|
||||
|
||||
|
|
|
@ -53,7 +53,8 @@ class TypeCoercionSuite extends PlanTest {
|
|||
// | NullType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | BinaryType | BooleanType | StringType | DateType | TimestampType | ArrayType | MapType | StructType | NullType | CalendarIntervalType | DecimalType(38, 18) | DoubleType | IntegerType |
|
||||
// | CalendarIntervalType | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | CalendarIntervalType | X | X | X |
|
||||
// +----------------------+----------+-----------+-------------+----------+------------+-----------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+
|
||||
// Note: ArrayType*, MapType*, StructType* are castable only when the internal child types also match; otherwise, not castable
|
||||
// Note: MapType*, StructType* are castable only when the internal child types also match; otherwise, not castable.
|
||||
// Note: ArrayType* is castable when the element type is castable according to the table.
|
||||
// scalastyle:on line.size.limit
|
||||
|
||||
private def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = {
|
||||
|
@ -125,6 +126,20 @@ class TypeCoercionSuite extends PlanTest {
|
|||
}
|
||||
}
|
||||
|
||||
private def checkWidenType(
|
||||
widenFunc: (DataType, DataType) => Option[DataType],
|
||||
t1: DataType,
|
||||
t2: DataType,
|
||||
expected: Option[DataType]): Unit = {
|
||||
var found = widenFunc(t1, t2)
|
||||
assert(found == expected,
|
||||
s"Expected $expected as wider common type for $t1 and $t2, found $found")
|
||||
// Test both directions to make sure the widening is symmetric.
|
||||
found = widenFunc(t2, t1)
|
||||
assert(found == expected,
|
||||
s"Expected $expected as wider common type for $t2 and $t1, found $found")
|
||||
}
|
||||
|
||||
test("implicit type cast - ByteType") {
|
||||
val checkedType = ByteType
|
||||
checkTypeCasting(checkedType, castableTypes = numericTypes ++ Seq(StringType))
|
||||
|
@ -308,15 +323,8 @@ class TypeCoercionSuite extends PlanTest {
|
|||
}
|
||||
|
||||
test("tightest common bound for types") {
|
||||
def widenTest(t1: DataType, t2: DataType, tightestCommon: Option[DataType]) {
|
||||
var found = TypeCoercion.findTightestCommonType(t1, t2)
|
||||
assert(found == tightestCommon,
|
||||
s"Expected $tightestCommon as tightest common type for $t1 and $t2, found $found")
|
||||
// Test both directions to make sure the widening is symmetric.
|
||||
found = TypeCoercion.findTightestCommonType(t2, t1)
|
||||
assert(found == tightestCommon,
|
||||
s"Expected $tightestCommon as tightest common type for $t2 and $t1, found $found")
|
||||
}
|
||||
def widenTest(t1: DataType, t2: DataType, expected: Option[DataType]): Unit =
|
||||
checkWidenType(TypeCoercion.findTightestCommonType, t1, t2, expected)
|
||||
|
||||
// Null
|
||||
widenTest(NullType, NullType, Some(NullType))
|
||||
|
@ -355,7 +363,6 @@ class TypeCoercionSuite extends PlanTest {
|
|||
widenTest(DecimalType(2, 1), DoubleType, None)
|
||||
widenTest(DecimalType(2, 1), IntegerType, None)
|
||||
widenTest(DoubleType, DecimalType(2, 1), None)
|
||||
widenTest(IntegerType, DecimalType(2, 1), None)
|
||||
|
||||
// StringType
|
||||
widenTest(NullType, StringType, Some(StringType))
|
||||
|
@ -379,6 +386,60 @@ class TypeCoercionSuite extends PlanTest {
|
|||
widenTest(ArrayType(IntegerType), StructType(Seq()), None)
|
||||
}
|
||||
|
||||
test("wider common type for decimal and array") {
|
||||
def widenTestWithStringPromotion(
|
||||
t1: DataType,
|
||||
t2: DataType,
|
||||
expected: Option[DataType]): Unit = {
|
||||
checkWidenType(TypeCoercion.findWiderTypeForTwo, t1, t2, expected)
|
||||
}
|
||||
|
||||
def widenTestWithoutStringPromotion(
|
||||
t1: DataType,
|
||||
t2: DataType,
|
||||
expected: Option[DataType]): Unit = {
|
||||
checkWidenType(TypeCoercion.findWiderTypeWithoutStringPromotionForTwo, t1, t2, expected)
|
||||
}
|
||||
|
||||
// Decimal
|
||||
widenTestWithStringPromotion(
|
||||
DecimalType(2, 1), DecimalType(3, 2), Some(DecimalType(3, 2)))
|
||||
widenTestWithStringPromotion(
|
||||
DecimalType(2, 1), DoubleType, Some(DoubleType))
|
||||
widenTestWithStringPromotion(
|
||||
DecimalType(2, 1), IntegerType, Some(DecimalType(11, 1)))
|
||||
widenTestWithStringPromotion(
|
||||
DecimalType(2, 1), LongType, Some(DecimalType(21, 1)))
|
||||
|
||||
// ArrayType
|
||||
widenTestWithStringPromotion(
|
||||
ArrayType(ShortType, containsNull = true),
|
||||
ArrayType(DoubleType, containsNull = false),
|
||||
Some(ArrayType(DoubleType, containsNull = true)))
|
||||
widenTestWithStringPromotion(
|
||||
ArrayType(TimestampType, containsNull = false),
|
||||
ArrayType(StringType, containsNull = true),
|
||||
Some(ArrayType(StringType, containsNull = true)))
|
||||
widenTestWithStringPromotion(
|
||||
ArrayType(ArrayType(IntegerType), containsNull = false),
|
||||
ArrayType(ArrayType(LongType), containsNull = false),
|
||||
Some(ArrayType(ArrayType(LongType), containsNull = false)))
|
||||
|
||||
// Without string promotion
|
||||
widenTestWithoutStringPromotion(IntegerType, StringType, None)
|
||||
widenTestWithoutStringPromotion(StringType, TimestampType, None)
|
||||
widenTestWithoutStringPromotion(ArrayType(LongType), ArrayType(StringType), None)
|
||||
widenTestWithoutStringPromotion(ArrayType(StringType), ArrayType(TimestampType), None)
|
||||
|
||||
// String promotion
|
||||
widenTestWithStringPromotion(IntegerType, StringType, Some(StringType))
|
||||
widenTestWithStringPromotion(StringType, TimestampType, Some(StringType))
|
||||
widenTestWithStringPromotion(
|
||||
ArrayType(LongType), ArrayType(StringType), Some(ArrayType(StringType)))
|
||||
widenTestWithStringPromotion(
|
||||
ArrayType(StringType), ArrayType(TimestampType), Some(ArrayType(StringType)))
|
||||
}
|
||||
|
||||
private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) {
|
||||
ruleTest(Seq(rule), initial, transformed)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue