[SPARK-5203][SQL] fix union with different decimal type
When union non-decimal types with decimals, we use the following rules: - FIRST `intTypeToFixed`, then fixed union decimals with precision/scale p1/s2 and p2/s2 will be promoted to DecimalType(max(p1, p2), max(s1, s2)) - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE (this is the same as Hive, but note that unlimited decimals are considered bigger than doubles in WidenTypes) Author: guowei2 <guowei2@asiainfo.com> Closes #4004 from guowei2/SPARK-5203 and squashes the following commits: ff50f5f [guowei2] fix code style 11df1bf [guowei2] fix decimal union with double, double->Decimal(15,15) 0f345f9 [guowei2] fix structType merge with decimal 101ed4d [guowei2] fix build error after rebase 0b196e4 [guowei2] code style fe2c2ca [guowei2] handle union decimal precision in 'DecimalPrecision' 421d840 [guowei2] fix union types for decimal precision ef2c661 [guowei2] fix union with different decimal type
This commit is contained in:
parent
dc6dff248d
commit
c23ba81b8c
|
@ -285,6 +285,7 @@ trait HiveTypeCoercion {
|
|||
* Calculates and propagates precision for fixed-precision decimals. Hive has a number of
|
||||
* rules for this based on the SQL standard and MS SQL:
|
||||
* https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf
|
||||
* https://msdn.microsoft.com/en-us/library/ms190476.aspx
|
||||
*
|
||||
* In particular, if we have expressions e1 and e2 with precision/scale p1/s2 and p2/s2
|
||||
* respectively, then the following operations have the following precision / scale:
|
||||
|
@ -296,6 +297,7 @@ trait HiveTypeCoercion {
|
|||
* e1 * e2 p1 + p2 + 1 s1 + s2
|
||||
* e1 / e2 p1 - s1 + s2 + max(6, s1 + p2 + 1) max(6, s1 + p2 + 1)
|
||||
* e1 % e2 min(p1-s1, p2-s2) + max(s1, s2) max(s1, s2)
|
||||
* e1 union e2 max(s1, s2) + max(p1-s1, p2-s2) max(s1, s2)
|
||||
* sum(e1) p1 + 10 s1
|
||||
* avg(e1) p1 + 4 s1 + 4
|
||||
*
|
||||
|
@ -311,7 +313,12 @@ trait HiveTypeCoercion {
|
|||
* - SHORT gets turned into DECIMAL(5, 0)
|
||||
* - INT gets turned into DECIMAL(10, 0)
|
||||
* - LONG gets turned into DECIMAL(20, 0)
|
||||
* - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE (this is the same as Hive,
|
||||
* - FLOAT and DOUBLE
|
||||
* 1. Union operation:
|
||||
* FLOAT gets turned into DECIMAL(7, 7), DOUBLE gets turned into DECIMAL(15, 15) (this is the
|
||||
* same as Hive)
|
||||
* 2. Other operation:
|
||||
* FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE (this is the same as Hive,
|
||||
* but note that unlimited decimals are considered bigger than doubles in WidenTypes)
|
||||
*/
|
||||
// scalastyle:on
|
||||
|
@ -328,76 +335,127 @@ trait HiveTypeCoercion {
|
|||
|
||||
def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType
|
||||
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
|
||||
// Skip nodes whose children have not been resolved yet
|
||||
case e if !e.childrenResolved => e
|
||||
// Conversion rules for float and double into fixed-precision decimals
|
||||
val floatTypeToFixed: Map[DataType, DecimalType] = Map(
|
||||
FloatType -> DecimalType(7, 7),
|
||||
DoubleType -> DecimalType(15, 15)
|
||||
)
|
||||
|
||||
case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
|
||||
Cast(
|
||||
Add(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
|
||||
DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
|
||||
)
|
||||
|
||||
case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
|
||||
Cast(
|
||||
Subtract(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
|
||||
DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
|
||||
)
|
||||
|
||||
case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
|
||||
Cast(
|
||||
Multiply(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
|
||||
DecimalType(p1 + p2 + 1, s1 + s2)
|
||||
)
|
||||
|
||||
case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
|
||||
Cast(
|
||||
Divide(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
|
||||
DecimalType(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1))
|
||||
)
|
||||
|
||||
case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
|
||||
Cast(
|
||||
Remainder(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
|
||||
DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
|
||||
)
|
||||
|
||||
case LessThan(e1 @ DecimalType.Expression(p1, s1),
|
||||
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
|
||||
LessThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
|
||||
|
||||
case LessThanOrEqual(e1 @ DecimalType.Expression(p1, s1),
|
||||
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
|
||||
LessThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
|
||||
|
||||
case GreaterThan(e1 @ DecimalType.Expression(p1, s1),
|
||||
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
|
||||
GreaterThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
|
||||
|
||||
case GreaterThanOrEqual(e1 @ DecimalType.Expression(p1, s1),
|
||||
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
|
||||
GreaterThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
|
||||
|
||||
// Promote integers inside a binary expression with fixed-precision decimals to decimals,
|
||||
// and fixed-precision decimals in an expression with floats / doubles to doubles
|
||||
case b: BinaryExpression if b.left.dataType != b.right.dataType =>
|
||||
(b.left.dataType, b.right.dataType) match {
|
||||
case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) =>
|
||||
b.makeCopy(Array(Cast(b.left, intTypeToFixed(t)), b.right))
|
||||
case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) =>
|
||||
b.makeCopy(Array(b.left, Cast(b.right, intTypeToFixed(t))))
|
||||
case (t, DecimalType.Fixed(p, s)) if isFloat(t) =>
|
||||
b.makeCopy(Array(b.left, Cast(b.right, DoubleType)))
|
||||
case (DecimalType.Fixed(p, s), t) if isFloat(t) =>
|
||||
b.makeCopy(Array(Cast(b.left, DoubleType), b.right))
|
||||
case _ =>
|
||||
b
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
|
||||
// fix decimal precision for union
|
||||
case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
|
||||
val castedInput = left.output.zip(right.output).map {
|
||||
case (l, r) if l.dataType != r.dataType =>
|
||||
(l.dataType, r.dataType) match {
|
||||
case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) =>
|
||||
// Union decimals with precision/scale p1/s2 and p2/s2 will be promoted to
|
||||
// DecimalType(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2))
|
||||
val fixedType = DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2), max(s1, s2))
|
||||
(Alias(Cast(l, fixedType), l.name)(), Alias(Cast(r, fixedType), r.name)())
|
||||
case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) =>
|
||||
(Alias(Cast(l, intTypeToFixed(t)), l.name)(), r)
|
||||
case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) =>
|
||||
(l, Alias(Cast(r, intTypeToFixed(t)), r.name)())
|
||||
case (t, DecimalType.Fixed(p, s)) if floatTypeToFixed.contains(t) =>
|
||||
(Alias(Cast(l, floatTypeToFixed(t)), l.name)(), r)
|
||||
case (DecimalType.Fixed(p, s), t) if floatTypeToFixed.contains(t) =>
|
||||
(l, Alias(Cast(r, floatTypeToFixed(t)), r.name)())
|
||||
case _ => (l, r)
|
||||
}
|
||||
case other => other
|
||||
}
|
||||
|
||||
// TODO: MaxOf, MinOf, etc might want other rules
|
||||
val (castedLeft, castedRight) = castedInput.unzip
|
||||
|
||||
// SUM and AVERAGE are handled by the implementations of those expressions
|
||||
val newLeft =
|
||||
if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) {
|
||||
Project(castedLeft, left)
|
||||
} else {
|
||||
left
|
||||
}
|
||||
|
||||
val newRight =
|
||||
if (castedRight.map(_.dataType) != right.output.map(_.dataType)) {
|
||||
Project(castedRight, right)
|
||||
} else {
|
||||
right
|
||||
}
|
||||
|
||||
Union(newLeft, newRight)
|
||||
|
||||
// fix decimal precision for expressions
|
||||
case q => q.transformExpressions {
|
||||
// Skip nodes whose children have not been resolved yet
|
||||
case e if !e.childrenResolved => e
|
||||
|
||||
case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
|
||||
Cast(
|
||||
Add(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
|
||||
DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
|
||||
)
|
||||
|
||||
case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
|
||||
Cast(
|
||||
Subtract(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
|
||||
DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
|
||||
)
|
||||
|
||||
case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
|
||||
Cast(
|
||||
Multiply(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
|
||||
DecimalType(p1 + p2 + 1, s1 + s2)
|
||||
)
|
||||
|
||||
case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
|
||||
Cast(
|
||||
Divide(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
|
||||
DecimalType(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1))
|
||||
)
|
||||
|
||||
case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
|
||||
Cast(
|
||||
Remainder(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
|
||||
DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
|
||||
)
|
||||
|
||||
case LessThan(e1 @ DecimalType.Expression(p1, s1),
|
||||
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
|
||||
LessThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
|
||||
|
||||
case LessThanOrEqual(e1 @ DecimalType.Expression(p1, s1),
|
||||
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
|
||||
LessThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
|
||||
|
||||
case GreaterThan(e1 @ DecimalType.Expression(p1, s1),
|
||||
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
|
||||
GreaterThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
|
||||
|
||||
case GreaterThanOrEqual(e1 @ DecimalType.Expression(p1, s1),
|
||||
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
|
||||
GreaterThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
|
||||
|
||||
// Promote integers inside a binary expression with fixed-precision decimals to decimals,
|
||||
// and fixed-precision decimals in an expression with floats / doubles to doubles
|
||||
case b: BinaryExpression if b.left.dataType != b.right.dataType =>
|
||||
(b.left.dataType, b.right.dataType) match {
|
||||
case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) =>
|
||||
b.makeCopy(Array(Cast(b.left, intTypeToFixed(t)), b.right))
|
||||
case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) =>
|
||||
b.makeCopy(Array(b.left, Cast(b.right, intTypeToFixed(t))))
|
||||
case (t, DecimalType.Fixed(p, s)) if isFloat(t) =>
|
||||
b.makeCopy(Array(b.left, Cast(b.right, DoubleType)))
|
||||
case (DecimalType.Fixed(p, s), t) if isFloat(t) =>
|
||||
b.makeCopy(Array(Cast(b.left, DoubleType), b.right))
|
||||
case _ =>
|
||||
b
|
||||
}
|
||||
|
||||
// TODO: MaxOf, MinOf, etc might want other rules
|
||||
|
||||
// SUM and AVERAGE are handled by the implementations of those expressions
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -20,6 +20,7 @@ package org.apache.spark.sql.types
|
|||
import java.sql.Timestamp
|
||||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
import scala.math._
|
||||
import scala.math.Numeric.{FloatAsIfIntegral, DoubleAsIfIntegral}
|
||||
import scala.reflect.ClassTag
|
||||
import scala.reflect.runtime.universe.{TypeTag, runtimeMirror, typeTag}
|
||||
|
@ -934,7 +935,9 @@ object StructType {
|
|||
|
||||
case (DecimalType.Fixed(leftPrecision, leftScale),
|
||||
DecimalType.Fixed(rightPrecision, rightScale)) =>
|
||||
DecimalType(leftPrecision.max(rightPrecision), leftScale.max(rightScale))
|
||||
DecimalType(
|
||||
max(leftScale, rightScale) + max(leftPrecision - leftScale, rightPrecision - rightScale),
|
||||
max(leftScale, rightScale))
|
||||
|
||||
case (leftUdt: UserDefinedType[_], rightUdt: UserDefinedType[_])
|
||||
if leftUdt.userClass == rightUdt.userClass => leftUdt
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
package org.apache.spark.sql.catalyst.analysis
|
||||
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation}
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{Union, Project, LocalRelation}
|
||||
import org.apache.spark.sql.types._
|
||||
import org.scalatest.{BeforeAndAfter, FunSuite}
|
||||
|
||||
|
@ -31,7 +31,8 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter {
|
|||
AttributeReference("d1", DecimalType(2, 1))(),
|
||||
AttributeReference("d2", DecimalType(5, 2))(),
|
||||
AttributeReference("u", DecimalType.Unlimited)(),
|
||||
AttributeReference("f", FloatType)()
|
||||
AttributeReference("f", FloatType)(),
|
||||
AttributeReference("b", DoubleType)()
|
||||
)
|
||||
|
||||
val i: Expression = UnresolvedAttribute("i")
|
||||
|
@ -39,6 +40,7 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter {
|
|||
val d2: Expression = UnresolvedAttribute("d2")
|
||||
val u: Expression = UnresolvedAttribute("u")
|
||||
val f: Expression = UnresolvedAttribute("f")
|
||||
val b: Expression = UnresolvedAttribute("b")
|
||||
|
||||
before {
|
||||
catalog.registerTable(Seq("table"), relation)
|
||||
|
@ -58,6 +60,17 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter {
|
|||
assert(comparison.right.dataType === expectedType)
|
||||
}
|
||||
|
||||
private def checkUnion(left: Expression, right: Expression, expectedType: DataType): Unit = {
|
||||
val plan =
|
||||
Union(Project(Seq(Alias(left, "l")()), relation),
|
||||
Project(Seq(Alias(right, "r")()), relation))
|
||||
val (l, r) = analyzer(plan).collect {
|
||||
case Union(left, right) => (left.output.head, right.output.head)
|
||||
}.head
|
||||
assert(l.dataType === expectedType)
|
||||
assert(r.dataType === expectedType)
|
||||
}
|
||||
|
||||
test("basic operations") {
|
||||
checkType(Add(d1, d2), DecimalType(6, 2))
|
||||
checkType(Subtract(d1, d2), DecimalType(6, 2))
|
||||
|
@ -82,6 +95,19 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter {
|
|||
checkComparison(GreaterThan(d2, d2), DecimalType(5, 2))
|
||||
}
|
||||
|
||||
test("decimal precision for union") {
|
||||
checkUnion(d1, i, DecimalType(11, 1))
|
||||
checkUnion(i, d2, DecimalType(12, 2))
|
||||
checkUnion(d1, d2, DecimalType(5, 2))
|
||||
checkUnion(d2, d1, DecimalType(5, 2))
|
||||
checkUnion(d1, f, DecimalType(8, 7))
|
||||
checkUnion(f, d2, DecimalType(10, 7))
|
||||
checkUnion(d1, b, DecimalType(16, 15))
|
||||
checkUnion(b, d2, DecimalType(18, 15))
|
||||
checkUnion(d1, u, DecimalType.Unlimited)
|
||||
checkUnion(u, d2, DecimalType.Unlimited)
|
||||
}
|
||||
|
||||
test("bringing in primitive types") {
|
||||
checkType(Add(d1, i), DecimalType(12, 1))
|
||||
checkType(Add(d1, f), DoubleType)
|
||||
|
|
|
@ -468,4 +468,15 @@ class SQLQuerySuite extends QueryTest {
|
|||
sql(s"DROP TABLE $tableName")
|
||||
}
|
||||
}
|
||||
|
||||
test("SPARK-5203 union with different decimal precision") {
|
||||
Seq.empty[(Decimal, Decimal)]
|
||||
.toDF("d1", "d2")
|
||||
.select($"d1".cast(DecimalType(10, 15)).as("d"))
|
||||
.registerTempTable("dn")
|
||||
|
||||
sql("select d from dn union all select d * 2 from dn")
|
||||
.queryExecution.analyzed
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue