[SPARK-5203][SQL] fix union with different decimal type

When union non-decimal types with decimals, we use the following rules:
      - FIRST `intTypeToFixed`, then fixed union decimals with precision/scale p1/s2 and p2/s2  will be promoted to
      DecimalType(max(p1, p2), max(s1, s2))
      - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE (this is the same as Hive,
      but note that unlimited decimals are considered bigger than doubles in WidenTypes)

Author: guowei2 <guowei2@asiainfo.com>

Closes #4004 from guowei2/SPARK-5203 and squashes the following commits:

ff50f5f [guowei2] fix code style
11df1bf [guowei2] fix decimal union with double, double->Decimal(15,15)
0f345f9 [guowei2] fix structType merge with decimal
101ed4d [guowei2] fix build error after rebase
0b196e4 [guowei2] code style
fe2c2ca [guowei2] handle union decimal precision in 'DecimalPrecision'
421d840 [guowei2] fix union types for decimal precision
ef2c661 [guowei2] fix union with different decimal type
This commit is contained in:
guowei2 2015-04-04 02:02:30 +08:00 committed by Cheng Lian
parent dc6dff248d
commit c23ba81b8c
4 changed files with 167 additions and 69 deletions

View file

@ -285,6 +285,7 @@ trait HiveTypeCoercion {
* Calculates and propagates precision for fixed-precision decimals. Hive has a number of
* rules for this based on the SQL standard and MS SQL:
* https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf
* https://msdn.microsoft.com/en-us/library/ms190476.aspx
*
* In particular, if we have expressions e1 and e2 with precision/scale p1/s2 and p2/s2
* respectively, then the following operations have the following precision / scale:
@ -296,6 +297,7 @@ trait HiveTypeCoercion {
* e1 * e2 p1 + p2 + 1 s1 + s2
* e1 / e2 p1 - s1 + s2 + max(6, s1 + p2 + 1) max(6, s1 + p2 + 1)
* e1 % e2 min(p1-s1, p2-s2) + max(s1, s2) max(s1, s2)
* e1 union e2 max(s1, s2) + max(p1-s1, p2-s2) max(s1, s2)
* sum(e1) p1 + 10 s1
* avg(e1) p1 + 4 s1 + 4
*
@ -311,7 +313,12 @@ trait HiveTypeCoercion {
* - SHORT gets turned into DECIMAL(5, 0)
* - INT gets turned into DECIMAL(10, 0)
* - LONG gets turned into DECIMAL(20, 0)
* - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE (this is the same as Hive,
* - FLOAT and DOUBLE
* 1. Union operation:
* FLOAT gets turned into DECIMAL(7, 7), DOUBLE gets turned into DECIMAL(15, 15) (this is the
* same as Hive)
* 2. Other operation:
* FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE (this is the same as Hive,
* but note that unlimited decimals are considered bigger than doubles in WidenTypes)
*/
// scalastyle:on
@ -328,76 +335,127 @@ trait HiveTypeCoercion {
def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
// Skip nodes whose children have not been resolved yet
case e if !e.childrenResolved => e
// Conversion rules for float and double into fixed-precision decimals
val floatTypeToFixed: Map[DataType, DecimalType] = Map(
FloatType -> DecimalType(7, 7),
DoubleType -> DecimalType(15, 15)
)
case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
Cast(
Add(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
)
case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
Cast(
Subtract(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
)
case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
Cast(
Multiply(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
DecimalType(p1 + p2 + 1, s1 + s2)
)
case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
Cast(
Divide(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
DecimalType(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1))
)
case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
Cast(
Remainder(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
)
case LessThan(e1 @ DecimalType.Expression(p1, s1),
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
LessThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
case LessThanOrEqual(e1 @ DecimalType.Expression(p1, s1),
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
LessThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
case GreaterThan(e1 @ DecimalType.Expression(p1, s1),
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
GreaterThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
case GreaterThanOrEqual(e1 @ DecimalType.Expression(p1, s1),
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
GreaterThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
// Promote integers inside a binary expression with fixed-precision decimals to decimals,
// and fixed-precision decimals in an expression with floats / doubles to doubles
case b: BinaryExpression if b.left.dataType != b.right.dataType =>
(b.left.dataType, b.right.dataType) match {
case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) =>
b.makeCopy(Array(Cast(b.left, intTypeToFixed(t)), b.right))
case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) =>
b.makeCopy(Array(b.left, Cast(b.right, intTypeToFixed(t))))
case (t, DecimalType.Fixed(p, s)) if isFloat(t) =>
b.makeCopy(Array(b.left, Cast(b.right, DoubleType)))
case (DecimalType.Fixed(p, s), t) if isFloat(t) =>
b.makeCopy(Array(Cast(b.left, DoubleType), b.right))
case _ =>
b
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// fix decimal precision for union
case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
val castedInput = left.output.zip(right.output).map {
case (l, r) if l.dataType != r.dataType =>
(l.dataType, r.dataType) match {
case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) =>
// Union decimals with precision/scale p1/s2 and p2/s2 will be promoted to
// DecimalType(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2))
val fixedType = DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2), max(s1, s2))
(Alias(Cast(l, fixedType), l.name)(), Alias(Cast(r, fixedType), r.name)())
case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) =>
(Alias(Cast(l, intTypeToFixed(t)), l.name)(), r)
case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) =>
(l, Alias(Cast(r, intTypeToFixed(t)), r.name)())
case (t, DecimalType.Fixed(p, s)) if floatTypeToFixed.contains(t) =>
(Alias(Cast(l, floatTypeToFixed(t)), l.name)(), r)
case (DecimalType.Fixed(p, s), t) if floatTypeToFixed.contains(t) =>
(l, Alias(Cast(r, floatTypeToFixed(t)), r.name)())
case _ => (l, r)
}
case other => other
}
// TODO: MaxOf, MinOf, etc might want other rules
val (castedLeft, castedRight) = castedInput.unzip
// SUM and AVERAGE are handled by the implementations of those expressions
val newLeft =
if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) {
Project(castedLeft, left)
} else {
left
}
val newRight =
if (castedRight.map(_.dataType) != right.output.map(_.dataType)) {
Project(castedRight, right)
} else {
right
}
Union(newLeft, newRight)
// fix decimal precision for expressions
case q => q.transformExpressions {
// Skip nodes whose children have not been resolved yet
case e if !e.childrenResolved => e
case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
Cast(
Add(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
)
case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
Cast(
Subtract(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
)
case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
Cast(
Multiply(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
DecimalType(p1 + p2 + 1, s1 + s2)
)
case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
Cast(
Divide(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
DecimalType(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1))
)
case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
Cast(
Remainder(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
)
case LessThan(e1 @ DecimalType.Expression(p1, s1),
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
LessThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
case LessThanOrEqual(e1 @ DecimalType.Expression(p1, s1),
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
LessThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
case GreaterThan(e1 @ DecimalType.Expression(p1, s1),
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
GreaterThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
case GreaterThanOrEqual(e1 @ DecimalType.Expression(p1, s1),
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
GreaterThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
// Promote integers inside a binary expression with fixed-precision decimals to decimals,
// and fixed-precision decimals in an expression with floats / doubles to doubles
case b: BinaryExpression if b.left.dataType != b.right.dataType =>
(b.left.dataType, b.right.dataType) match {
case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) =>
b.makeCopy(Array(Cast(b.left, intTypeToFixed(t)), b.right))
case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) =>
b.makeCopy(Array(b.left, Cast(b.right, intTypeToFixed(t))))
case (t, DecimalType.Fixed(p, s)) if isFloat(t) =>
b.makeCopy(Array(b.left, Cast(b.right, DoubleType)))
case (DecimalType.Fixed(p, s), t) if isFloat(t) =>
b.makeCopy(Array(Cast(b.left, DoubleType), b.right))
case _ =>
b
}
// TODO: MaxOf, MinOf, etc might want other rules
// SUM and AVERAGE are handled by the implementations of those expressions
}
}
}
/**

View file

@ -20,6 +20,7 @@ package org.apache.spark.sql.types
import java.sql.Timestamp
import scala.collection.mutable.ArrayBuffer
import scala.math._
import scala.math.Numeric.{FloatAsIfIntegral, DoubleAsIfIntegral}
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.{TypeTag, runtimeMirror, typeTag}
@ -934,7 +935,9 @@ object StructType {
case (DecimalType.Fixed(leftPrecision, leftScale),
DecimalType.Fixed(rightPrecision, rightScale)) =>
DecimalType(leftPrecision.max(rightPrecision), leftScale.max(rightScale))
DecimalType(
max(leftScale, rightScale) + max(leftPrecision - leftScale, rightPrecision - rightScale),
max(leftScale, rightScale))
case (leftUdt: UserDefinedType[_], rightUdt: UserDefinedType[_])
if leftUdt.userClass == rightUdt.userClass => leftUdt

View file

@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation}
import org.apache.spark.sql.catalyst.plans.logical.{Union, Project, LocalRelation}
import org.apache.spark.sql.types._
import org.scalatest.{BeforeAndAfter, FunSuite}
@ -31,7 +31,8 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter {
AttributeReference("d1", DecimalType(2, 1))(),
AttributeReference("d2", DecimalType(5, 2))(),
AttributeReference("u", DecimalType.Unlimited)(),
AttributeReference("f", FloatType)()
AttributeReference("f", FloatType)(),
AttributeReference("b", DoubleType)()
)
val i: Expression = UnresolvedAttribute("i")
@ -39,6 +40,7 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter {
val d2: Expression = UnresolvedAttribute("d2")
val u: Expression = UnresolvedAttribute("u")
val f: Expression = UnresolvedAttribute("f")
val b: Expression = UnresolvedAttribute("b")
before {
catalog.registerTable(Seq("table"), relation)
@ -58,6 +60,17 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter {
assert(comparison.right.dataType === expectedType)
}
private def checkUnion(left: Expression, right: Expression, expectedType: DataType): Unit = {
val plan =
Union(Project(Seq(Alias(left, "l")()), relation),
Project(Seq(Alias(right, "r")()), relation))
val (l, r) = analyzer(plan).collect {
case Union(left, right) => (left.output.head, right.output.head)
}.head
assert(l.dataType === expectedType)
assert(r.dataType === expectedType)
}
test("basic operations") {
checkType(Add(d1, d2), DecimalType(6, 2))
checkType(Subtract(d1, d2), DecimalType(6, 2))
@ -82,6 +95,19 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter {
checkComparison(GreaterThan(d2, d2), DecimalType(5, 2))
}
test("decimal precision for union") {
checkUnion(d1, i, DecimalType(11, 1))
checkUnion(i, d2, DecimalType(12, 2))
checkUnion(d1, d2, DecimalType(5, 2))
checkUnion(d2, d1, DecimalType(5, 2))
checkUnion(d1, f, DecimalType(8, 7))
checkUnion(f, d2, DecimalType(10, 7))
checkUnion(d1, b, DecimalType(16, 15))
checkUnion(b, d2, DecimalType(18, 15))
checkUnion(d1, u, DecimalType.Unlimited)
checkUnion(u, d2, DecimalType.Unlimited)
}
test("bringing in primitive types") {
checkType(Add(d1, i), DecimalType(12, 1))
checkType(Add(d1, f), DoubleType)

View file

@ -468,4 +468,15 @@ class SQLQuerySuite extends QueryTest {
sql(s"DROP TABLE $tableName")
}
}
test("SPARK-5203 union with different decimal precision") {
Seq.empty[(Decimal, Decimal)]
.toDF("d1", "d2")
.select($"d1".cast(DecimalType(10, 15)).as("d"))
.registerTempTable("dn")
sql("select d from dn union all select d * 2 from dn")
.queryExecution.analyzed
}
}