[SPARK-9069] [SPARK-9264] [SQL] remove unlimited precision support for DecimalType

Remove DecimalType.Unlimited (change to support precision up to 38, to match Hive and other databases).

To keep backward source compatibility, DecimalType.Unlimited is still there, but changed to DecimalType(38, 18).

If no precision and scale are provided, it is DecimalType(10, 0), as before.
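For reference, a minimal Scala sketch of the new defaults (all names come from this patch; DecimalType() stays deprecated but usable):

import org.apache.spark.sql.types._

DecimalType()               // DecimalType(10,0) -- the user default, as before
DecimalType.USER_DEFAULT    // DecimalType(10,0)
DecimalType.SYSTEM_DEFAULT  // DecimalType(38,18) -- what Unlimited now aliases
DecimalType.Unlimited       // deprecated, same as SYSTEM_DEFAULT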

Author: Davies Liu <davies@databricks.com>

Closes #7605 from davies/decimal_unlimited and squashes the following commits:

aa3f115 [Davies Liu] fix tests and style
fb0d20d [Davies Liu] address comments
bfaae35 [Davies Liu] fix style
df93657 [Davies Liu] address comments and clean up
06727fd [Davies Liu] Merge branch 'master' of github.com:apache/spark into decimal_unlimited
4c28969 [Davies Liu] fix tests
8d783cc [Davies Liu] fix tests
788631c [Davies Liu] fix double with decimal in Union/except
1779bde [Davies Liu] fix scala style
c9c7c78 [Davies Liu] remove Decimal.Unlimited
This commit is contained in:
Davies Liu 2015-07-23 18:31:13 -07:00 committed by Reynold Xin
parent bebe3f7b45
commit 8a94eb23d5
53 changed files with 459 additions and 473 deletions

View file

@ -218,7 +218,7 @@ class AttributeSuite extends SparkFunSuite {
// Attribute.fromStructField should accept any NumericType, not just DoubleType
val longFldWithMeta = new StructField("x", LongType, false, metadata)
assert(Attribute.fromStructField(longFldWithMeta).isNumeric)
val decimalFldWithMeta = new StructField("x", DecimalType(None), false, metadata)
val decimalFldWithMeta = new StructField("x", DecimalType(38, 18), false, metadata)
assert(Attribute.fromStructField(decimalFldWithMeta).isNumeric)
}
}

View file

@ -194,30 +194,33 @@ class TimestampType(AtomicType):
class DecimalType(FractionalType):
"""Decimal (decimal.Decimal) data type.
The DecimalType must have fixed precision (the maximum total number of digits)
and scale (the number of digits to the right of the decimal point). For example,
(5, 2) can support values in the range [-999.99, 999.99].
The precision can be up to 38; the scale must be less than or equal to the precision.
When creating a DecimalType, the default precision and scale is (10, 0). When
inferring a schema from decimal.Decimal objects, it will be DecimalType(38, 18).
:param precision: the maximum total number of digits (default: 10)
:param scale: the number of digits to the right of the decimal point (default: 0)
"""
def __init__(self, precision=None, scale=None):
def __init__(self, precision=10, scale=0):
self.precision = precision
self.scale = scale
self.hasPrecisionInfo = precision is not None
self.hasPrecisionInfo = True # this is public API
def simpleString(self):
if self.hasPrecisionInfo:
return "decimal(%d,%d)" % (self.precision, self.scale)
else:
return "decimal(10,0)"
return "decimal(%d,%d)" % (self.precision, self.scale)
def jsonValue(self):
if self.hasPrecisionInfo:
return "decimal(%d,%d)" % (self.precision, self.scale)
else:
return "decimal"
return "decimal(%d,%d)" % (self.precision, self.scale)
def __repr__(self):
if self.hasPrecisionInfo:
return "DecimalType(%d,%d)" % (self.precision, self.scale)
else:
return "DecimalType()"
return "DecimalType(%d,%d)" % (self.precision, self.scale)
class DoubleType(FractionalType):
@ -761,7 +764,10 @@ def _infer_type(obj):
return obj.__UDT__
dataType = _type_mappings.get(type(obj))
if dataType is not None:
if dataType is DecimalType:
# the precision and scale of `obj` may be different from row to row.
return DecimalType(38, 18)
elif dataType is not None:
return dataType()
if isinstance(obj, dict):
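The [-999.99, 999.99] range in the docstring above follows from the definition: for decimal(p, s) the largest representable magnitude is 10^(p-s) - 10^(-s). A quick Scala check of that arithmetic:

val (p, s) = (5, 2)
// largest magnitude a decimal(5,2) can hold
BigDecimal(10).pow(p - s) - BigDecimal(1) / BigDecimal(10).pow(s)   // 999.99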

View file

@ -111,12 +111,18 @@ public class DataTypes {
return new ArrayType(elementType, containsNull);
}
/**
* Creates a DecimalType by specifying the precision and scale.
*/
public static DecimalType createDecimalType(int precision, int scale) {
return DecimalType$.MODULE$.apply(precision, scale);
}
/**
* Creates a DecimalType with default precision and scale, which are 10 and 0.
*/
public static DecimalType createDecimalType() {
return DecimalType$.MODULE$.Unlimited();
return DecimalType$.MODULE$.USER_DEFAULT();
}
/**

View file

@ -75,7 +75,7 @@ private [sql] object JavaTypeInference {
case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true)
case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true)
case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true)
case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType.SYSTEM_DEFAULT, true)
case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)

View file

@ -131,10 +131,10 @@ trait ScalaReflection {
case t if t <:< localTypeOf[String] => Schema(StringType, nullable = true)
case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true)
case t if t <:< localTypeOf[java.sql.Date] => Schema(DateType, nullable = true)
case t if t <:< localTypeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true)
case t if t <:< localTypeOf[BigDecimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
case t if t <:< localTypeOf[java.math.BigDecimal] =>
Schema(DecimalType.Unlimited, nullable = true)
case t if t <:< localTypeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true)
Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
case t if t <:< localTypeOf[Decimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
case t if t <:< localTypeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
case t if t <:< localTypeOf[java.lang.Long] => Schema(LongType, nullable = true)
case t if t <:< localTypeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
@ -167,8 +167,8 @@ trait ScalaReflection {
case obj: Float => FloatType
case obj: Double => DoubleType
case obj: java.sql.Date => DateType
case obj: java.math.BigDecimal => DecimalType.Unlimited
case obj: Decimal => DecimalType.Unlimited
case obj: java.math.BigDecimal => DecimalType.SYSTEM_DEFAULT
case obj: Decimal => DecimalType.SYSTEM_DEFAULT
case obj: java.sql.Timestamp => TimestampType
case null => NullType
// For other cases, there is no obvious mapping from the type of the given object to a
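With this change, reflection-based schema inference maps BigDecimal fields to the system default instead of unlimited; a sketch using the schemaFor entry point shown above:

import org.apache.spark.sql.catalyst.ScalaReflection

ScalaReflection.schemaFor[BigDecimal].dataType   // DecimalType(38,18), i.e. SYSTEM_DEFAULT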

View file

@ -322,7 +322,10 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
protected lazy val numericLiteral: Parser[Literal] =
( integral ^^ { case i => Literal(toNarrowestIntegerType(i)) }
| sign.? ~ unsignedFloat ^^ { case s ~ f => Literal((s.getOrElse("") + f).toDouble) }
| sign.? ~ unsignedFloat ^^ {
// TODO(davies): some precision may be lost; we should create a decimal literal
case s ~ f => Literal(BigDecimal(s.getOrElse("") + f).doubleValue())
}
)
protected lazy val unsignedFloat: Parser[String] =
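The precision loss the TODO mentions is plain Double rounding: a literal with more than about 17 significant digits is silently truncated when routed through doubleValue(), e.g.:

BigDecimal("12345678901234567890.5").doubleValue()   // 1.2345678901234567E19, trailing digits lost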

View file

@ -19,7 +19,9 @@ package org.apache.spark.sql.catalyst.analysis
import javax.annotation.Nullable
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types._
@ -58,8 +60,7 @@ object HiveTypeCoercion {
IntegerType,
LongType,
FloatType,
DoubleType,
DecimalType.Unlimited)
DoubleType)
/**
* Find the tightest common type of two types that might be used in a binary expression.
@ -72,15 +73,16 @@ object HiveTypeCoercion {
case (NullType, t1) => Some(t1)
case (t1, NullType) => Some(t1)
// Promote numeric types to the highest of the two and all numeric types to unlimited decimal
case (t1: IntegralType, t2: DecimalType) if t2.isWiderThan(t1) =>
Some(t2)
case (t1: DecimalType, t2: IntegralType) if t1.isWiderThan(t2) =>
Some(t1)
// Promote numeric types to the highest of the two
case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) =>
val index = numericPrecedence.lastIndexWhere(t => t == t1 || t == t2)
Some(numericPrecedence(index))
// Fixed-precision decimals can up-cast into unlimited
case (DecimalType.Unlimited, _: DecimalType) => Some(DecimalType.Unlimited)
case (_: DecimalType, DecimalType.Unlimited) => Some(DecimalType.Unlimited)
case _ => None
}
@ -101,7 +103,7 @@ object HiveTypeCoercion {
types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
case None => None
case Some(d) =>
findTightestCommonTypeOfTwo(d, c).orElse(findTightestCommonTypeToString(d, c))
findTightestCommonTypeToString(d, c)
})
}
@ -158,6 +160,9 @@ object HiveTypeCoercion {
* converted to DOUBLE.
* - TINYINT, SMALLINT, and INT can all be converted to FLOAT.
* - BOOLEAN types cannot be converted to any other type.
* - Any integral numeric type can be implicitly converted to a decimal type.
* - Two different decimal types will be converted into a decimal type wide enough for both.
* - A decimal type will be converted into double if a float or double appears together with it.
*
* Additionally, all types when UNION-ed with strings will be promoted to strings.
* Other string conversions are handled by PromoteStrings.
@ -166,55 +171,50 @@ object HiveTypeCoercion {
* - IntegerType to FloatType
* - LongType to FloatType
* - LongType to DoubleType
* - DecimalType to Double
*
* This rule is only applied to Union/Except/Intersect
*/
object WidenTypes extends Rule[LogicalPlan] {
private[this] def widenOutputTypes(planName: String, left: LogicalPlan, right: LogicalPlan):
(LogicalPlan, LogicalPlan) = {
// TODO: with fixed-precision decimals
val castedInput = left.output.zip(right.output).map {
// When a string is found on one side, make the other side a string too.
case (lhs, rhs) if lhs.dataType == StringType && rhs.dataType != StringType =>
(lhs, Alias(Cast(rhs, StringType), rhs.name)())
case (lhs, rhs) if lhs.dataType != StringType && rhs.dataType == StringType =>
(Alias(Cast(lhs, StringType), lhs.name)(), rhs)
private[this] def widenOutputTypes(
planName: String,
left: LogicalPlan,
right: LogicalPlan): (LogicalPlan, LogicalPlan) = {
val castedTypes = left.output.zip(right.output).map {
case (lhs, rhs) if lhs.dataType != rhs.dataType =>
logDebug(s"Resolving mismatched $planName input ${lhs.dataType}, ${rhs.dataType}")
findTightestCommonTypeOfTwo(lhs.dataType, rhs.dataType).map { widestType =>
val newLeft =
if (lhs.dataType == widestType) lhs else Alias(Cast(lhs, widestType), lhs.name)()
val newRight =
if (rhs.dataType == widestType) rhs else Alias(Cast(rhs, widestType), rhs.name)()
(newLeft, newRight)
}.getOrElse {
// If there is no applicable conversion, leave expression unchanged.
(lhs, rhs)
(lhs.dataType, rhs.dataType) match {
case (t1: DecimalType, t2: DecimalType) =>
Some(DecimalPrecision.widerDecimalType(t1, t2))
case (t: IntegralType, d: DecimalType) =>
Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
case (d: DecimalType, t: IntegralType) =>
Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
case (t: FractionalType, d: DecimalType) =>
Some(DoubleType)
case (d: DecimalType, t: FractionalType) =>
Some(DoubleType)
case _ =>
findTightestCommonTypeToString(lhs.dataType, rhs.dataType)
}
case other => other
case other => None
}
val (castedLeft, castedRight) = castedInput.unzip
val newLeft =
if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) {
logDebug(s"Widening numeric types in $planName $castedLeft ${left.output}")
Project(castedLeft, left)
} else {
left
def castOutput(plan: LogicalPlan): LogicalPlan = {
val casted = plan.output.zip(castedTypes).map {
case (hs, Some(dt)) if dt != hs.dataType =>
Alias(Cast(hs, dt), hs.name)()
case (hs, _) => hs
}
Project(casted, plan)
}
val newRight =
if (castedRight.map(_.dataType) != right.output.map(_.dataType)) {
logDebug(s"Widening numeric types in $planName $castedRight ${right.output}")
Project(castedRight, right)
} else {
right
}
(newLeft, newRight)
if (castedTypes.exists(_.isDefined)) {
(castOutput(left), castOutput(right))
} else {
(left, right)
}
}
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
@ -334,144 +334,94 @@ object HiveTypeCoercion {
* - SHORT gets turned into DECIMAL(5, 0)
* - INT gets turned into DECIMAL(10, 0)
* - LONG gets turned into DECIMAL(20, 0)
* - FLOAT and DOUBLE
* 1. Union, Intersect and Except operations:
* FLOAT gets turned into DECIMAL(7, 7), DOUBLE gets turned into DECIMAL(15, 15) (this is the
* same as Hive)
* 2. Other operation:
* FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE (this is the same as Hive,
* but note that unlimited decimals are considered bigger than doubles in WidenTypes)
* - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE
*
* Note: Union/Except/Intersect is handled by WidenTypes
*/
// scalastyle:on
object DecimalPrecision extends Rule[LogicalPlan] {
import scala.math.{max, min}
// Conversion rules for integer types into fixed-precision decimals
private val intTypeToFixed: Map[DataType, DecimalType] = Map(
ByteType -> DecimalType(3, 0),
ShortType -> DecimalType(5, 0),
IntegerType -> DecimalType(10, 0),
LongType -> DecimalType(20, 0)
)
private def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType
// Conversion rules for float and double into fixed-precision decimals
private val floatTypeToFixed: Map[DataType, DecimalType] = Map(
FloatType -> DecimalType(7, 7),
DoubleType -> DecimalType(15, 15)
)
// Returns a decimal type wide enough to hold both of them
def widerDecimalType(d1: DecimalType, d2: DecimalType): DecimalType = {
widerDecimalType(d1.precision, d1.scale, d2.precision, d2.scale)
}
// max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)
def widerDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = {
val scale = max(s1, s2)
val range = max(p1 - s1, p2 - s2)
DecimalType.bounded(range + scale, scale)
}
private def castDecimalPrecision(
left: LogicalPlan,
right: LogicalPlan): (LogicalPlan, LogicalPlan) = {
val castedInput = left.output.zip(right.output).map {
case (lhs, rhs) if lhs.dataType != rhs.dataType =>
(lhs.dataType, rhs.dataType) match {
case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) =>
// Decimals with precision/scale p1/s2 and p2/s2 will be promoted to
// DecimalType(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2))
val fixedType = DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2), max(s1, s2))
(Alias(Cast(lhs, fixedType), lhs.name)(), Alias(Cast(rhs, fixedType), rhs.name)())
case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) =>
(Alias(Cast(lhs, intTypeToFixed(t)), lhs.name)(), rhs)
case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) =>
(lhs, Alias(Cast(rhs, intTypeToFixed(t)), rhs.name)())
case (t, DecimalType.Fixed(p, s)) if floatTypeToFixed.contains(t) =>
(Alias(Cast(lhs, floatTypeToFixed(t)), lhs.name)(), rhs)
case (DecimalType.Fixed(p, s), t) if floatTypeToFixed.contains(t) =>
(lhs, Alias(Cast(rhs, floatTypeToFixed(t)), rhs.name)())
case _ => (lhs, rhs)
}
case other => other
}
/**
* An expression used to wrap the children when promoting the precision of DecimalType, to
* avoid promoting multiple times.
*/
case class ChangePrecision(child: Expression) extends UnaryExpression {
override def dataType: DataType = child.dataType
override def eval(input: InternalRow): Any = child.eval(input)
override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx)
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = ""
override def prettyName: String = "change_precision"
}
val (castedLeft, castedRight) = castedInput.unzip
val newLeft =
if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) {
Project(castedLeft, left)
} else {
left
}
val newRight =
if (castedRight.map(_.dataType) != right.output.map(_.dataType)) {
Project(castedRight, right)
} else {
right
}
(newLeft, newRight)
def changePrecision(e: Expression, dataType: DataType): Expression = {
ChangePrecision(Cast(e, dataType))
}
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// fix decimal precision for union, intersect and except
case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
val (newLeft, newRight) = castDecimalPrecision(left, right)
Union(newLeft, newRight)
case i @ Intersect(left, right) if i.childrenResolved && !i.resolved =>
val (newLeft, newRight) = castDecimalPrecision(left, right)
Intersect(newLeft, newRight)
case e @ Except(left, right) if e.childrenResolved && !e.resolved =>
val (newLeft, newRight) = castDecimalPrecision(left, right)
Except(newLeft, newRight)
// fix decimal precision for expressions
case q => q.transformExpressions {
// Skip nodes whose children have not been resolved yet
case e if !e.childrenResolved => e
// Skip nodes that are already promoted
case e: BinaryArithmetic if e.left.isInstanceOf[ChangePrecision] => e
case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
Cast(
Add(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
)
val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
Add(changePrecision(e1, dt), changePrecision(e2, dt))
case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
Cast(
Subtract(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
)
val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
Subtract(changePrecision(e1, dt), changePrecision(e2, dt))
case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
Cast(
Multiply(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
DecimalType(p1 + p2 + 1, s1 + s2)
)
val dt = DecimalType.bounded(p1 + p2 + 1, s1 + s2)
Multiply(changePrecision(e1, dt), changePrecision(e2, dt))
case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
Cast(
Divide(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
DecimalType(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1))
)
val dt = DecimalType.bounded(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1))
Divide(changePrecision(e1, dt), changePrecision(e2, dt))
case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
Cast(
Remainder(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
)
val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
// resultType may have lower precision, so we cast the operands into a wider type first.
val widerType = widerDecimalType(p1, s1, p2, s2)
Cast(Remainder(changePrecision(e1, widerType), changePrecision(e2, widerType)),
resultType)
case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
Cast(
Pmod(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
)
val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
// resultType may have lower precision, so we cast the operands into a wider type first.
val widerType = widerDecimalType(p1, s1, p2, s2)
Cast(Pmod(changePrecision(e1, widerType), changePrecision(e2, widerType)), resultType)
// When we compare 2 decimal types with different precisions, cast them to the smallest
// common precision.
case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1),
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
val resultType = DecimalType(max(p1, p2), max(s1, s2))
val resultType = widerDecimalType(p1, s1, p2, s2)
b.makeCopy(Array(Cast(e1, resultType), Cast(e2, resultType)))
// Promote integers inside a binary expression with fixed-precision decimals to decimals,
// and fixed-precision decimals in an expression with floats / doubles to doubles
case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
(left.dataType, right.dataType) match {
case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) =>
b.makeCopy(Array(Cast(left, intTypeToFixed(t)), right))
case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) =>
b.makeCopy(Array(left, Cast(right, intTypeToFixed(t))))
case (t: IntegralType, DecimalType.Fixed(p, s)) =>
b.makeCopy(Array(Cast(left, DecimalType.forType(t)), right))
case (DecimalType.Fixed(p, s), t: IntegralType) =>
b.makeCopy(Array(left, Cast(right, DecimalType.forType(t))))
case (t, DecimalType.Fixed(p, s)) if isFloat(t) =>
b.makeCopy(Array(left, Cast(right, DoubleType)))
case (DecimalType.Fixed(p, s), t) if isFloat(t) =>
@ -485,7 +435,6 @@ object HiveTypeCoercion {
// SUM and AVERAGE are handled by the implementations of those expressions
}
}
}
/**
@ -563,7 +512,7 @@ object HiveTypeCoercion {
case e if !e.childrenResolved => e
case Cast(e @ StringType(), t: IntegralType) =>
Cast(Cast(e, DecimalType.Unlimited), t)
Cast(Cast(e, DecimalType.forType(LongType)), t)
}
}
@ -756,8 +705,8 @@ object HiveTypeCoercion {
// Implicit cast among numeric types. When we reach here, input type is not acceptable.
// If input is a numeric type but not decimal, and we expect a decimal type,
// cast the input to unlimited precision decimal.
case (_: NumericType, DecimalType) => Cast(e, DecimalType.Unlimited)
// cast the input to decimal.
case (d: NumericType, DecimalType) => Cast(e, DecimalType.forType(d))
// For any other numeric types, implicitly cast to each other, e.g. long -> int, int -> long
case (_: NumericType, target: NumericType) => Cast(e, target)
@ -766,7 +715,7 @@ object HiveTypeCoercion {
case (TimestampType, DateType) => Cast(e, DateType)
// Implicit cast from/to string
case (StringType, DecimalType) => Cast(e, DecimalType.Unlimited)
case (StringType, DecimalType) => Cast(e, DecimalType.SYSTEM_DEFAULT)
case (StringType, target: NumericType) => Cast(e, target)
case (StringType, DateType) => Cast(e, DateType)
case (StringType, TimestampType) => Cast(e, TimestampType)
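A worked example of the precision formulas above, done with plain arithmetic so it does not depend on catalyst-internal helpers (bounded below mirrors DecimalType.bounded from this patch):

import scala.math.{max, min}
def bounded(p: Int, s: Int) = (min(p, 38), min(s, 38))   // mirrors DecimalType.bounded

val (p1, s1) = (5, 2)    // decimal(5,2)
val (p2, s2) = (10, 4)   // decimal(10,4)

// widerDecimalType: max(s1,s2) + max(p1-s1, p2-s2), max(s1,s2)
bounded(max(s1, s2) + max(p1 - s1, p2 - s2), max(s1, s2))       // (10,4)
// Add/Subtract: one extra integer digit for the carry
bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))   // (11,4)
// Multiply: p1 + p2 + 1, s1 + s2
bounded(p1 + p2 + 1, s1 + s2)                                   // (16,6)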

View file

@ -201,7 +201,7 @@ package object dsl {
/** Creates a new AttributeReference of type decimal */
def decimal: AttributeReference =
AttributeReference(s, DecimalType.Unlimited, nullable = true)()
AttributeReference(s, DecimalType.SYSTEM_DEFAULT, nullable = true)()
/** Creates a new AttributeReference of type decimal */
def decimal(precision: Int, scale: Int): AttributeReference =

View file

@ -300,12 +300,7 @@ case class Cast(child: Expression, dataType: DataType)
* NOTE: this modifies `value` in-place, so don't call it on external data.
*/
private[this] def changePrecision(value: Decimal, decimalType: DecimalType): Decimal = {
decimalType match {
case DecimalType.Unlimited =>
value
case DecimalType.Fixed(precision, scale) =>
if (value.changePrecision(precision, scale)) value else null
}
if (value.changePrecision(decimalType.precision, decimalType.scale)) value else null
}
private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match {
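Since unlimited precision is gone, every decimal cast now goes through changePrecision and produces null on overflow; a small sketch (recall from the note above that changePrecision mutates the value in place):

import org.apache.spark.sql.types.Decimal

val d = Decimal(BigDecimal("123.45"))   // precision 5, scale 2
d.changePrecision(5, 2)                 // true: the value fits decimal(5,2)

val e = Decimal(BigDecimal("123.45"))
e.changePrecision(4, 2)                 // false: needs 5 digits, so the cast returns null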

View file

@ -36,14 +36,13 @@ case class Average(child: Expression) extends AlgebraicAggregate {
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType))
private val resultType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
DecimalType(precision + 4, scale + 4)
case DecimalType.Unlimited => DecimalType.Unlimited
case DecimalType.Fixed(p, s) =>
DecimalType.bounded(p + 4, s + 4)
case _ => DoubleType
}
private val sumDataType = child.dataType match {
case _ @ DecimalType() => DecimalType.Unlimited
case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s)
case _ => DoubleType
}
@ -71,7 +70,14 @@ case class Average(child: Expression) extends AlgebraicAggregate {
)
// If all inputs are null, currentCount will be 0 and we will get null after the division.
override val evaluateExpression = Cast(currentSum, resultType) / Cast(currentCount, resultType)
override val evaluateExpression = child.dataType match {
case DecimalType.Fixed(p, s) =>
// increase the precision and scale to prevent precision loss
val dt = DecimalType.bounded(p + 14, s + 4)
Cast(Cast(currentSum, dt) / Cast(currentCount, dt), resultType)
case _ =>
Cast(currentSum, resultType) / Cast(currentCount, resultType)
}
}
case class Count(child: Expression) extends AlgebraicAggregate {
@ -255,15 +261,11 @@ case class Sum(child: Expression) extends AlgebraicAggregate {
private val resultType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
DecimalType(precision + 4, scale + 4)
case DecimalType.Unlimited => DecimalType.Unlimited
DecimalType.bounded(precision + 10, scale)
case _ => child.dataType
}
private val sumDataType = child.dataType match {
case _ @ DecimalType() => DecimalType.Unlimited
case _ => child.dataType
}
private val sumDataType = resultType
private val currentSum = AttributeReference("currentSum", sumDataType)()
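For an input column of decimal(10,2), the new aggregate rules work out as follows (again mirroring the private DecimalType.bounded with plain arithmetic):

import scala.math.min
def bounded(p: Int, s: Int) = (min(p, 38), min(s, 38))

bounded(10 + 10, 2)      // (20,2): Sum result, also Average's internal sum buffer
bounded(10 + 4, 2 + 4)   // (14,6): Average result, 4 extra digits like Hive
bounded(10 + 14, 2 + 4)  // (24,6): division type inside Average, then cast to (14,6)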

View file

@ -390,22 +390,21 @@ case class Average(child: Expression) extends UnaryExpression with PartialAggreg
override def dataType: DataType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
DecimalType(precision + 4, scale + 4) // Add 4 digits after decimal point, like Hive
case DecimalType.Unlimited =>
DecimalType.Unlimited
// Add 4 digits after decimal point, like Hive
DecimalType.bounded(precision + 4, scale + 4)
case _ =>
DoubleType
}
override def asPartial: SplitEvaluation = {
child.dataType match {
case DecimalType.Fixed(_, _) | DecimalType.Unlimited =>
// Turn the child to unlimited decimals for calculation, before going back to fixed
val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")()
case DecimalType.Fixed(precision, scale) =>
val partialSum = Alias(Sum(child), "PartialSum")()
val partialCount = Alias(Count(child), "PartialCount")()
val castedSum = Cast(Sum(partialSum.toAttribute), DecimalType.Unlimited)
val castedCount = Cast(Sum(partialCount.toAttribute), DecimalType.Unlimited)
// partialSum already increases the precision by 10
val castedSum = Cast(Sum(partialSum.toAttribute), partialSum.dataType)
val castedCount = Sum(partialCount.toAttribute)
SplitEvaluation(
Cast(Divide(castedSum, castedCount), dataType),
partialCount :: partialSum :: Nil)
@ -435,8 +434,8 @@ case class AverageFunction(expr: Expression, base: AggregateExpression1)
private val calcType =
expr.dataType match {
case DecimalType.Fixed(_, _) =>
DecimalType.Unlimited
case DecimalType.Fixed(precision, scale) =>
DecimalType.bounded(precision + 10, scale)
case _ =>
expr.dataType
}
@ -454,10 +453,9 @@ case class AverageFunction(expr: Expression, base: AggregateExpression1)
null
} else {
expr.dataType match {
case DecimalType.Fixed(_, _) =>
Cast(Divide(
Cast(sum, DecimalType.Unlimited),
Cast(Literal(count), DecimalType.Unlimited)), dataType).eval(null)
case DecimalType.Fixed(precision, scale) =>
val dt = DecimalType.bounded(precision + 14, scale + 4)
Cast(Divide(Cast(sum, dt), Cast(Literal(count), dt)), dataType).eval(null)
case _ =>
Divide(
Cast(sum, dataType),
@ -481,9 +479,8 @@ case class Sum(child: Expression) extends UnaryExpression with PartialAggregate1
override def dataType: DataType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive
case DecimalType.Unlimited =>
DecimalType.Unlimited
// Add 10 digits left of decimal point, like Hive
DecimalType.bounded(precision + 10, scale)
case _ =>
child.dataType
}
@ -491,7 +488,7 @@ case class Sum(child: Expression) extends UnaryExpression with PartialAggregate1
override def asPartial: SplitEvaluation = {
child.dataType match {
case DecimalType.Fixed(_, _) =>
val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")()
val partialSum = Alias(Sum(child), "PartialSum")()
SplitEvaluation(
Cast(CombineSum(partialSum.toAttribute), dataType),
partialSum :: Nil)
@ -515,8 +512,8 @@ case class SumFunction(expr: Expression, base: AggregateExpression1) extends Agg
private val calcType =
expr.dataType match {
case DecimalType.Fixed(_, _) =>
DecimalType.Unlimited
case DecimalType.Fixed(precision, scale) =>
DecimalType.bounded(precision + 10, scale)
case _ =>
expr.dataType
}
@ -572,8 +569,8 @@ case class CombineSumFunction(expr: Expression, base: AggregateExpression1)
private val calcType =
expr.dataType match {
case DecimalType.Fixed(_, _) =>
DecimalType.Unlimited
case DecimalType.Fixed(precision, scale) =>
DecimalType.bounded(precision + 10, scale)
case _ =>
expr.dataType
}
@ -608,9 +605,8 @@ case class SumDistinct(child: Expression) extends UnaryExpression with PartialAg
override def nullable: Boolean = true
override def dataType: DataType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive
case DecimalType.Unlimited =>
DecimalType.Unlimited
// Add 10 digits left of decimal point, like Hive
DecimalType.bounded(precision + 10, scale)
case _ =>
child.dataType
}

View file

@ -88,6 +88,8 @@ abstract class BinaryArithmetic extends BinaryOperator {
override def dataType: DataType = left.dataType
override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess
/** Name of the function for this expression on a [[Decimal]] type. */
def decimalMethod: String =
sys.error("BinaryArithmetics must override either decimalMethod or genCode")
@ -114,9 +116,6 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "+"
override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
private lazy val numeric = TypeUtils.getNumeric(dataType)
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
@ -146,9 +145,6 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
override def symbol: String = "-"
override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
private lazy val numeric = TypeUtils.getNumeric(dataType)
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
@ -179,9 +175,6 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
override def symbol: String = "*"
override def decimalMethod: String = "$times"
override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
private lazy val numeric = TypeUtils.getNumeric(dataType)
protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2)
@ -195,9 +188,6 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
override def decimalMethod: String = "$div"
override def nullable: Boolean = true
override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
private lazy val div: (Any, Any) => Any = dataType match {
case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div
case it: IntegralType => it.integral.asInstanceOf[Integral[Any]].quot
@ -260,9 +250,6 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
override def decimalMethod: String = "remainder"
override def nullable: Boolean = true
override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
private lazy val integral = dataType match {
case i: IntegralType => i.integral.asInstanceOf[Integral[Any]]
case i: FractionalType => i.asIntegral.asInstanceOf[Integral[Any]]

View file

@ -36,9 +36,9 @@ object Literal {
case s: Short => Literal(s, ShortType)
case s: String => Literal(UTF8String.fromString(s), StringType)
case b: Boolean => Literal(b, BooleanType)
case d: BigDecimal => Literal(Decimal(d), DecimalType.Unlimited)
case d: java.math.BigDecimal => Literal(Decimal(d), DecimalType.Unlimited)
case d: Decimal => Literal(d, DecimalType.Unlimited)
case d: BigDecimal => Literal(Decimal(d), DecimalType(d.precision, d.scale))
case d: java.math.BigDecimal => Literal(Decimal(d), DecimalType(d.precision(), d.scale()))
case d: Decimal => Literal(d, DecimalType(d.precision, d.scale))
case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType)
case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType)
case a: Array[Byte] => Literal(a, BinaryType)
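Decimal literals now carry the exact precision and scale of their value instead of unlimited; for example:

import org.apache.spark.sql.catalyst.expressions.Literal

Literal(BigDecimal("123.45")).dataType   // DecimalType(5,2)
Literal(BigDecimal("0.5")).dataType      // DecimalType(1,1)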

View file

@ -17,9 +17,9 @@
package org.apache.spark.sql.catalyst.plans
import org.apache.spark.sql.catalyst.expressions.{VirtualColumn, Attribute, AttributeSet, Expression}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, VirtualColumn}
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType}
import org.apache.spark.sql.types.{DataType, StructType}
abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanType] {
self: PlanType =>

View file

@ -106,7 +106,7 @@ object DataType {
private def nameToType(name: String): DataType = {
val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\d+)\s*\)""".r
name match {
case "decimal" => DecimalType.Unlimited
case "decimal" => DecimalType.USER_DEFAULT
case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt)
case other => nonDecimalNameToType(other)
}
@ -177,7 +177,7 @@ object DataType {
| "BinaryType" ^^^ BinaryType
| "BooleanType" ^^^ BooleanType
| "DateType" ^^^ DateType
| "DecimalType()" ^^^ DecimalType.Unlimited
| "DecimalType()" ^^^ DecimalType.USER_DEFAULT
| fixedDecimalType
| "TimestampType" ^^^ TimestampType
)
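A bare decimal with no precision now resolves to the user default; a sketch via the public JSON entry point, which routes quoted type names through nameToType:

import org.apache.spark.sql.types._

DataType.fromJson("\"decimal\"")          // DecimalType(10,0), the user default
DataType.fromJson("\"decimal(38,18)\"")   // DecimalType(38,18)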

View file

@ -48,7 +48,7 @@ private[sql] trait DataTypeParser extends StandardTokenParsers {
"(?i)binary".r ^^^ BinaryType |
"(?i)boolean".r ^^^ BooleanType |
fixedDecimalType |
"(?i)decimal".r ^^^ DecimalType.Unlimited |
"(?i)decimal".r ^^^ DecimalType.USER_DEFAULT |
"(?i)date".r ^^^ DateType |
"(?i)timestamp".r ^^^ TimestampType |
varchar

View file

@ -26,25 +26,46 @@ import org.apache.spark.sql.catalyst.expressions.Expression
/** Precision parameters for a Decimal */
@deprecated("Use DecimalType(precision, scale) directly", "1.5")
case class PrecisionInfo(precision: Int, scale: Int) {
if (scale > precision) {
throw new AnalysisException(
s"Decimal scale ($scale) cannot be greater than precision ($precision).")
}
if (precision > DecimalType.MAX_PRECISION) {
throw new AnalysisException(
s"DecimalType can only support precision up to 38"
)
}
}
/**
* :: DeveloperApi ::
* The data type representing `java.math.BigDecimal` values.
* A Decimal that might have fixed precision and scale, or unlimited values for these.
* A Decimal that must have fixed precision (the maximum number of digits) and scale (the number
* of digits to the right of the decimal point).
*
* The precision can be up to 38; the scale can also be up to 38 (less than or equal to the precision).
*
* The default precision and scale is (10, 0).
*
* Please use [[DataTypes.createDecimalType()]] to create a specific instance.
*/
@DeveloperApi
case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalType {
case class DecimalType(precision: Int, scale: Int) extends FractionalType {
/** No-arg constructor for kryo. */
protected def this() = this(null)
// default constructor for Java
def this(precision: Int) = this(precision, 0)
def this() = this(10)
@deprecated("Use DecimalType(precision, scale) instead", "1.5")
def this(precisionInfo: Option[PrecisionInfo]) {
this(precisionInfo.getOrElse(PrecisionInfo(10, 0)).precision,
precisionInfo.getOrElse(PrecisionInfo(10, 0)).scale)
}
@deprecated("Use DecimalType.precision and DecimalType.scale instead", "1.5")
val precisionInfo = Some(PrecisionInfo(precision, scale))
private[sql] type InternalType = Decimal
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] }
@ -53,18 +74,16 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT
private[sql] val ordering = Decimal.DecimalIsFractional
private[sql] val asIntegral = Decimal.DecimalAsIfIntegral
def precision: Int = precisionInfo.map(_.precision).getOrElse(-1)
override def typeName: String = s"decimal($precision,$scale)"
def scale: Int = precisionInfo.map(_.scale).getOrElse(-1)
override def toString: String = s"DecimalType($precision,$scale)"
override def typeName: String = precisionInfo match {
case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)"
case None => "decimal"
}
override def toString: String = precisionInfo match {
case Some(PrecisionInfo(precision, scale)) => s"DecimalType($precision,$scale)"
case None => "DecimalType()"
private[sql] def isWiderThan(other: DataType): Boolean = other match {
case dt: DecimalType =>
(precision - scale) >= (dt.precision - dt.scale) && scale >= dt.scale
case dt: IntegralType =>
isWiderThan(DecimalType.forType(dt))
case _ => false
}
/**
@ -72,10 +91,7 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT
*/
override def defaultSize: Int = 4096
override def simpleString: String = precisionInfo match {
case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)"
case None => "decimal(10,0)"
}
override def simpleString: String = s"decimal($precision,$scale)"
private[spark] override def asNullable: DecimalType = this
}
@ -83,8 +99,47 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT
/** Extra factory methods and pattern matchers for Decimals */
object DecimalType extends AbstractDataType {
import scala.math.min
override private[sql] def defaultConcreteType: DataType = Unlimited
val MAX_PRECISION = 38
val MAX_SCALE = 38
val SYSTEM_DEFAULT: DecimalType = DecimalType(MAX_PRECISION, 18)
val USER_DEFAULT: DecimalType = DecimalType(10, 0)
@deprecated("Does not support unlimited precision, please specify the precision and scale", "1.5")
val Unlimited: DecimalType = SYSTEM_DEFAULT
// The decimal types compatible with other numeric types
private[sql] val ByteDecimal = DecimalType(3, 0)
private[sql] val ShortDecimal = DecimalType(5, 0)
private[sql] val IntDecimal = DecimalType(10, 0)
private[sql] val LongDecimal = DecimalType(20, 0)
private[sql] val FloatDecimal = DecimalType(14, 7)
private[sql] val DoubleDecimal = DecimalType(30, 15)
private[sql] def forType(dataType: DataType): DecimalType = dataType match {
case ByteType => ByteDecimal
case ShortType => ShortDecimal
case IntegerType => IntDecimal
case LongType => LongDecimal
case FloatType => FloatDecimal
case DoubleType => DoubleDecimal
}
@deprecated("please specify precision and scale", "1.5")
def apply(): DecimalType = USER_DEFAULT
@deprecated("Use DecimalType(precision, scale) instead", "1.5")
def apply(precisionInfo: Option[PrecisionInfo]): DecimalType = {
  DecimalType(precisionInfo.getOrElse(PrecisionInfo(10, 0)).precision,
    precisionInfo.getOrElse(PrecisionInfo(10, 0)).scale)
}
private[sql] def bounded(precision: Int, scale: Int): DecimalType = {
DecimalType(min(precision, MAX_PRECISION), min(scale, MAX_SCALE))
}
override private[sql] def defaultConcreteType: DataType = SYSTEM_DEFAULT
override private[sql] def acceptsType(other: DataType): Boolean = {
other.isInstanceOf[DecimalType]
@ -92,31 +147,18 @@ object DecimalType extends AbstractDataType {
override private[sql] def simpleString: String = "decimal"
val Unlimited: DecimalType = DecimalType(None)
private[sql] object Fixed {
def unapply(t: DecimalType): Option[(Int, Int)] =
t.precisionInfo.map(p => (p.precision, p.scale))
def unapply(t: DecimalType): Option[(Int, Int)] = Some((t.precision, t.scale))
}
private[sql] object Expression {
def unapply(e: Expression): Option[(Int, Int)] = e.dataType match {
case t: DecimalType => t.precisionInfo.map(p => (p.precision, p.scale))
case t: DecimalType => Some((t.precision, t.scale))
case _ => None
}
}
def apply(): DecimalType = Unlimited
def apply(precision: Int, scale: Int): DecimalType =
DecimalType(Some(PrecisionInfo(precision, scale)))
def unapply(t: DataType): Boolean = t.isInstanceOf[DecimalType]
def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[DecimalType]
def isFixed(dataType: DataType): Boolean = dataType match {
case DecimalType.Fixed(_, _) => true
case _ => false
}
}
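isWiderThan requires both the integer part and the scale to fit; a sketch (isWiderThan and forType are private[sql], so this assumes code compiled inside the org.apache.spark.sql package):

import org.apache.spark.sql.types._

DecimalType(20, 2).isWiderThan(DecimalType(10, 2))   // true: 18 integer digits >= 8, scale ok
DecimalType(10, 5).isWiderThan(IntegerType)          // false: 5 integer digits < IntDecimal's 10
DecimalType.forType(LongType)                        // DecimalType(20,0)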

View file

@ -94,8 +94,8 @@ object RandomDataGenerator {
case BooleanType => Some(() => rand.nextBoolean())
case DateType => Some(() => new java.sql.Date(rand.nextInt()))
case TimestampType => Some(() => new java.sql.Timestamp(rand.nextLong()))
case DecimalType.Unlimited => Some(
() => BigDecimal.apply(rand.nextLong, rand.nextInt, MathContext.UNLIMITED))
case DecimalType.Fixed(precision, scale) => Some(
() => BigDecimal.apply(rand.nextLong, rand.nextInt, new MathContext(precision)))
case DoubleType => randomNumeric[Double](
rand, r => longBitsToDouble(r.nextLong()), Seq(Double.MinValue, Double.MinPositiveValue,
Double.MaxValue, Double.PositiveInfinity, Double.NegativeInfinity, Double.NaN, 0.0))

View file

@ -50,9 +50,7 @@ class RandomDataGeneratorSuite extends SparkFunSuite {
for (
dataType <- DataTypeTestUtils.atomicTypes;
nullable <- Seq(true, false)
if !dataType.isInstanceOf[DecimalType] ||
dataType.asInstanceOf[DecimalType].precisionInfo.isEmpty
) {
if !dataType.isInstanceOf[DecimalType]) {
test(s"$dataType (nullable=$nullable)") {
testRandomDataGeneration(dataType)
}

View file

@ -102,7 +102,7 @@ class ScalaReflectionSuite extends SparkFunSuite {
StructField("byteField", ByteType, nullable = true),
StructField("booleanField", BooleanType, nullable = true),
StructField("stringField", StringType, nullable = true),
StructField("decimalField", DecimalType.Unlimited, nullable = true),
StructField("decimalField", DecimalType.SYSTEM_DEFAULT, nullable = true),
StructField("dateField", DateType, nullable = true),
StructField("timestampField", TimestampType, nullable = true),
StructField("binaryField", BinaryType, nullable = true))),
@ -216,7 +216,7 @@ class ScalaReflectionSuite extends SparkFunSuite {
assert(DoubleType === typeOfObject(1.7976931348623157E308))
// DecimalType
assert(DecimalType.Unlimited ===
assert(DecimalType.SYSTEM_DEFAULT ===
typeOfObject(new java.math.BigDecimal("1.7976931348623157E318")))
// DateType
@ -229,19 +229,19 @@ class ScalaReflectionSuite extends SparkFunSuite {
assert(NullType === typeOfObject(null))
def typeOfObject1: PartialFunction[Any, DataType] = typeOfObject orElse {
case value: java.math.BigInteger => DecimalType.Unlimited
case value: java.math.BigDecimal => DecimalType.Unlimited
case value: java.math.BigInteger => DecimalType.SYSTEM_DEFAULT
case value: java.math.BigDecimal => DecimalType.SYSTEM_DEFAULT
case _ => StringType
}
assert(DecimalType.Unlimited === typeOfObject1(
assert(DecimalType.SYSTEM_DEFAULT === typeOfObject1(
new BigInteger("92233720368547758070")))
assert(DecimalType.Unlimited === typeOfObject1(
assert(DecimalType.SYSTEM_DEFAULT === typeOfObject1(
new java.math.BigDecimal("1.7976931348623157E318")))
assert(StringType === typeOfObject1(BigInt("92233720368547758070")))
def typeOfObject2: PartialFunction[Any, DataType] = typeOfObject orElse {
case value: java.math.BigInteger => DecimalType.Unlimited
case value: java.math.BigInteger => DecimalType.SYSTEM_DEFAULT
}
intercept[MatchError](typeOfObject2(BigInt("92233720368547758070")))

View file

@ -55,7 +55,7 @@ object AnalysisSuite {
AttributeReference("a", StringType)(),
AttributeReference("b", StringType)(),
AttributeReference("c", DoubleType)(),
AttributeReference("d", DecimalType.Unlimited)(),
AttributeReference("d", DecimalType.SYSTEM_DEFAULT)(),
AttributeReference("e", ShortType)())
val nestedRelation = LocalRelation(
@ -158,7 +158,7 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
AttributeReference("a", StringType)(),
AttributeReference("b", StringType)(),
AttributeReference("c", DoubleType)(),
AttributeReference("d", DecimalType.Unlimited)(),
AttributeReference("d", DecimalType(10, 2))(),
AttributeReference("e", ShortType)())
val plan = caseInsensitiveAnalyzer.execute(
@ -173,7 +173,7 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
assert(pl(0).dataType == DoubleType)
assert(pl(1).dataType == DoubleType)
assert(pl(2).dataType == DoubleType)
assert(pl(3).dataType == DecimalType.Unlimited)
assert(pl(3).dataType == DoubleType) // StringType will be promoted into Double
assert(pl(4).dataType == DoubleType)
}
}

View file

@ -34,7 +34,7 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter {
AttributeReference("i", IntegerType)(),
AttributeReference("d1", DecimalType(2, 1))(),
AttributeReference("d2", DecimalType(5, 2))(),
AttributeReference("u", DecimalType.Unlimited)(),
AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(),
AttributeReference("f", FloatType)(),
AttributeReference("b", DoubleType)()
)
@ -92,11 +92,11 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter {
}
test("Comparison operations") {
checkComparison(EqualTo(i, d1), DecimalType(10, 1))
checkComparison(EqualTo(i, d1), DecimalType(11, 1))
checkComparison(EqualNullSafe(d2, d1), DecimalType(5, 2))
checkComparison(LessThan(i, d1), DecimalType(10, 1))
checkComparison(LessThan(i, d1), DecimalType(11, 1))
checkComparison(LessThanOrEqual(d1, d2), DecimalType(5, 2))
checkComparison(GreaterThan(d2, u), DecimalType.Unlimited)
checkComparison(GreaterThan(d2, u), DecimalType.SYSTEM_DEFAULT)
checkComparison(GreaterThanOrEqual(d1, f), DoubleType)
checkComparison(GreaterThan(d2, d2), DecimalType(5, 2))
}
@ -106,12 +106,12 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter {
checkUnion(i, d2, DecimalType(12, 2))
checkUnion(d1, d2, DecimalType(5, 2))
checkUnion(d2, d1, DecimalType(5, 2))
checkUnion(d1, f, DecimalType(8, 7))
checkUnion(f, d2, DecimalType(10, 7))
checkUnion(d1, b, DecimalType(16, 15))
checkUnion(b, d2, DecimalType(18, 15))
checkUnion(d1, u, DecimalType.Unlimited)
checkUnion(u, d2, DecimalType.Unlimited)
checkUnion(d1, f, DoubleType)
checkUnion(f, d2, DoubleType)
checkUnion(d1, b, DoubleType)
checkUnion(b, d2, DoubleType)
checkUnion(d1, u, DecimalType.SYSTEM_DEFAULT)
checkUnion(u, d2, DecimalType.SYSTEM_DEFAULT)
}
test("bringing in primitive types") {
@ -125,13 +125,33 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter {
checkType(Add(d1, Cast(i, DoubleType)), DoubleType)
}
test("unlimited decimals make everything else cast up") {
for (expr <- Seq(d1, d2, i, f, u)) {
checkType(Add(expr, u), DecimalType.Unlimited)
checkType(Subtract(expr, u), DecimalType.Unlimited)
checkType(Multiply(expr, u), DecimalType.Unlimited)
checkType(Divide(expr, u), DecimalType.Unlimited)
checkType(Remainder(expr, u), DecimalType.Unlimited)
test("maximum decimals") {
for (expr <- Seq(d1, d2, i, u)) {
checkType(Add(expr, u), DecimalType.SYSTEM_DEFAULT)
checkType(Subtract(expr, u), DecimalType.SYSTEM_DEFAULT)
}
checkType(Multiply(d1, u), DecimalType(38, 19))
checkType(Multiply(d2, u), DecimalType(38, 20))
checkType(Multiply(i, u), DecimalType(38, 18))
checkType(Multiply(u, u), DecimalType(38, 36))
checkType(Divide(u, d1), DecimalType(38, 21))
checkType(Divide(u, d2), DecimalType(38, 24))
checkType(Divide(u, i), DecimalType(38, 29))
checkType(Divide(u, u), DecimalType(38, 38))
checkType(Remainder(d1, u), DecimalType(19, 18))
checkType(Remainder(d2, u), DecimalType(21, 18))
checkType(Remainder(i, u), DecimalType(28, 18))
checkType(Remainder(u, u), DecimalType.SYSTEM_DEFAULT)
for (expr <- Seq(f, b)) {
checkType(Add(expr, u), DoubleType)
checkType(Subtract(expr, u), DoubleType)
checkType(Multiply(expr, u), DoubleType)
checkType(Divide(expr, u), DoubleType)
checkType(Remainder(expr, u), DoubleType)
}
}
}
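One of these expectations worked by hand, for Multiply(d1, u) with d1 = decimal(2,1) and u = decimal(38,18):

import scala.math.min
val (p1, s1) = (2, 1)
val (p2, s2) = (38, 18)
// Multiply: precision = p1 + p2 + 1, scale = s1 + s2, capped at MAX_PRECISION = 38
(min(p1 + p2 + 1, 38), s1 + s2)   // (38, 19), matching DecimalType(38, 19) above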

View file

@ -35,14 +35,14 @@ class HiveTypeCoercionSuite extends PlanTest {
shouldCast(NullType, NullType, NullType)
shouldCast(NullType, IntegerType, IntegerType)
shouldCast(NullType, DecimalType, DecimalType.Unlimited)
shouldCast(NullType, DecimalType, DecimalType.SYSTEM_DEFAULT)
shouldCast(ByteType, IntegerType, IntegerType)
shouldCast(IntegerType, IntegerType, IntegerType)
shouldCast(IntegerType, LongType, LongType)
shouldCast(IntegerType, DecimalType, DecimalType.Unlimited)
shouldCast(IntegerType, DecimalType, DecimalType(10, 0))
shouldCast(LongType, IntegerType, IntegerType)
shouldCast(LongType, DecimalType, DecimalType.Unlimited)
shouldCast(LongType, DecimalType, DecimalType(20, 0))
shouldCast(DateType, TimestampType, TimestampType)
shouldCast(TimestampType, DateType, DateType)
@ -71,8 +71,8 @@ class HiveTypeCoercionSuite extends PlanTest {
shouldCast(IntegerType, TypeCollection(StringType, BinaryType), StringType)
shouldCast(IntegerType, TypeCollection(BinaryType, StringType), StringType)
shouldCast(
DecimalType.Unlimited, TypeCollection(IntegerType, DecimalType), DecimalType.Unlimited)
shouldCast(DecimalType.SYSTEM_DEFAULT,
TypeCollection(IntegerType, DecimalType), DecimalType.SYSTEM_DEFAULT)
shouldCast(DecimalType(10, 2), TypeCollection(IntegerType, DecimalType), DecimalType(10, 2))
shouldCast(DecimalType(10, 2), TypeCollection(DecimalType, IntegerType), DecimalType(10, 2))
shouldCast(IntegerType, TypeCollection(DecimalType(10, 2), StringType), DecimalType(10, 2))
@ -82,7 +82,7 @@ class HiveTypeCoercionSuite extends PlanTest {
// NumericType should not be changed when function accepts any of them.
Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType,
DecimalType.Unlimited, DecimalType(10, 2)).foreach { tpe =>
DecimalType.SYSTEM_DEFAULT, DecimalType(10, 2)).foreach { tpe =>
shouldCast(tpe, NumericType, tpe)
}
@ -107,8 +107,8 @@ class HiveTypeCoercionSuite extends PlanTest {
shouldNotCast(IntegerType, TimestampType)
shouldNotCast(LongType, DateType)
shouldNotCast(LongType, TimestampType)
shouldNotCast(DecimalType.Unlimited, DateType)
shouldNotCast(DecimalType.Unlimited, TimestampType)
shouldNotCast(DecimalType.SYSTEM_DEFAULT, DateType)
shouldNotCast(DecimalType.SYSTEM_DEFAULT, TimestampType)
shouldNotCast(IntegerType, TypeCollection(DateType, TimestampType))
@ -160,14 +160,6 @@ class HiveTypeCoercionSuite extends PlanTest {
widenTest(LongType, FloatType, Some(FloatType))
widenTest(LongType, DoubleType, Some(DoubleType))
// Casting up to unlimited-precision decimal
widenTest(IntegerType, DecimalType.Unlimited, Some(DecimalType.Unlimited))
widenTest(DoubleType, DecimalType.Unlimited, Some(DecimalType.Unlimited))
widenTest(DecimalType(3, 2), DecimalType.Unlimited, Some(DecimalType.Unlimited))
widenTest(DecimalType.Unlimited, IntegerType, Some(DecimalType.Unlimited))
widenTest(DecimalType.Unlimited, DoubleType, Some(DecimalType.Unlimited))
widenTest(DecimalType.Unlimited, DecimalType(3, 2), Some(DecimalType.Unlimited))
// No up-casting for fixed-precision decimal (this is handled by arithmetic rules)
widenTest(DecimalType(2, 1), DecimalType(3, 2), None)
widenTest(DecimalType(2, 1), DoubleType, None)
@ -242,9 +234,9 @@ class HiveTypeCoercionSuite extends PlanTest {
:: Literal(1)
:: Literal(new java.math.BigDecimal("1000000000000000000000"))
:: Nil),
Coalesce(Cast(Literal(1L), DecimalType())
:: Cast(Literal(1), DecimalType())
:: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType())
Coalesce(Cast(Literal(1L), DecimalType(22, 0))
:: Cast(Literal(1), DecimalType(22, 0))
:: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0))
:: Nil))
}
@ -323,7 +315,7 @@ class HiveTypeCoercionSuite extends PlanTest {
val left = LocalRelation(
AttributeReference("i", IntegerType)(),
AttributeReference("u", DecimalType.Unlimited)(),
AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(),
AttributeReference("b", ByteType)(),
AttributeReference("d", DoubleType)())
val right = LocalRelation(
@ -333,7 +325,7 @@ class HiveTypeCoercionSuite extends PlanTest {
AttributeReference("l", LongType)())
val wt = HiveTypeCoercion.WidenTypes
val expectedTypes = Seq(StringType, DecimalType.Unlimited, FloatType, DoubleType)
val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType)
val r1 = wt(Union(left, right)).asInstanceOf[Union]
val r2 = wt(Except(left, right)).asInstanceOf[Except]
@ -353,13 +345,13 @@ class HiveTypeCoercionSuite extends PlanTest {
}
}
val dp = HiveTypeCoercion.DecimalPrecision
val dp = HiveTypeCoercion.WidenTypes
val left1 = LocalRelation(
AttributeReference("l", DecimalType(10, 8))())
val right1 = LocalRelation(
AttributeReference("r", DecimalType(5, 5))())
val expectedType1 = Seq(DecimalType(math.max(8, 5) + math.max(10 - 8, 5 - 5), math.max(8, 5)))
val expectedType1 = Seq(DecimalType(10, 8))
val r1 = dp(Union(left1, right1)).asInstanceOf[Union]
val r2 = dp(Except(left1, right1)).asInstanceOf[Except]
@ -372,12 +364,11 @@ class HiveTypeCoercionSuite extends PlanTest {
checkOutput(r3.left, expectedType1)
checkOutput(r3.right, expectedType1)
val plan1 = LocalRelation(
AttributeReference("l", DecimalType(10, 10))())
val plan1 = LocalRelation(AttributeReference("l", DecimalType(10, 5))())
val rightTypes = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType)
val expectedTypes = Seq(DecimalType(3, 0), DecimalType(5, 0), DecimalType(10, 0),
DecimalType(20, 0), DecimalType(7, 7), DecimalType(15, 15))
val expectedTypes = Seq(DecimalType(10, 5), DecimalType(10, 5), DecimalType(15, 5),
DecimalType(25, 5), DoubleType, DoubleType)
rightTypes.zip(expectedTypes).map { case (rType, expectedType) =>
val plan2 = LocalRelation(

View file

@ -185,7 +185,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
checkCast(1, 1.0)
checkCast(123, "123")
checkEvaluation(cast(123, DecimalType.Unlimited), Decimal(123))
checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123))
checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123))
checkEvaluation(cast(123, DecimalType(3, 1)), null)
checkEvaluation(cast(123, DecimalType(2, 0)), null)
@ -203,7 +203,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
checkCast(1L, 1.0)
checkCast(123L, "123")
checkEvaluation(cast(123L, DecimalType.Unlimited), Decimal(123))
checkEvaluation(cast(123L, DecimalType.USER_DEFAULT), Decimal(123))
checkEvaluation(cast(123L, DecimalType(3, 0)), Decimal(123))
checkEvaluation(cast(123L, DecimalType(3, 1)), Decimal(123.0))
@ -225,7 +225,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(cast(cast(1000, TimestampType), LongType), 1.toLong)
checkEvaluation(cast(cast(-1200, TimestampType), LongType), -2.toLong)
checkEvaluation(cast(123, DecimalType.Unlimited), Decimal(123))
checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123))
checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123))
checkEvaluation(cast(123, DecimalType(3, 1)), null)
checkEvaluation(cast(123, DecimalType(2, 0)), null)
@ -267,7 +267,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(cast("abcdef", IntegerType).nullable === true)
assert(cast("abcdef", ShortType).nullable === true)
assert(cast("abcdef", ByteType).nullable === true)
assert(cast("abcdef", DecimalType.Unlimited).nullable === true)
assert(cast("abcdef", DecimalType.USER_DEFAULT).nullable === true)
assert(cast("abcdef", DecimalType(4, 2)).nullable === true)
assert(cast("abcdef", DoubleType).nullable === true)
assert(cast("abcdef", FloatType).nullable === true)
@ -291,9 +291,9 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
c.getTimeInMillis * 1000)
checkEvaluation(cast("abdef", StringType), "abdef")
checkEvaluation(cast("abdef", DecimalType.Unlimited), null)
checkEvaluation(cast("abdef", DecimalType.USER_DEFAULT), null)
checkEvaluation(cast("abdef", TimestampType), null)
checkEvaluation(cast("12.65", DecimalType.Unlimited), Decimal(12.65))
checkEvaluation(cast("12.65", DecimalType.SYSTEM_DEFAULT), Decimal(12.65))
checkEvaluation(cast(cast(sd, DateType), StringType), sd)
checkEvaluation(cast(cast(d, StringType), DateType), 0)
@ -311,20 +311,20 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
5.toLong)
checkEvaluation(
cast(cast(cast(cast(cast(cast("5", ByteType), TimestampType),
DecimalType.Unlimited), LongType), StringType), ShortType),
DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType),
0.toShort)
checkEvaluation(
cast(cast(cast(cast(cast(cast("5", TimestampType), ByteType),
DecimalType.Unlimited), LongType), StringType), ShortType),
DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType),
null)
checkEvaluation(cast(cast(cast(cast(cast(cast("5", DecimalType.Unlimited),
checkEvaluation(cast(cast(cast(cast(cast(cast("5", DecimalType.SYSTEM_DEFAULT),
ByteType), TimestampType), LongType), StringType), ShortType),
0.toShort)
checkEvaluation(cast("23", DoubleType), 23d)
checkEvaluation(cast("23", IntegerType), 23)
checkEvaluation(cast("23", FloatType), 23f)
checkEvaluation(cast("23", DecimalType.Unlimited), Decimal(23))
checkEvaluation(cast("23", DecimalType.USER_DEFAULT), Decimal(23))
checkEvaluation(cast("23", ByteType), 23.toByte)
checkEvaluation(cast("23", ShortType), 23.toShort)
checkEvaluation(cast("2012-12-11", DoubleType), null)
@ -338,7 +338,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Add(Literal(23d), cast(true, DoubleType)), 24d)
checkEvaluation(Add(Literal(23), cast(true, IntegerType)), 24)
checkEvaluation(Add(Literal(23f), cast(true, FloatType)), 24f)
checkEvaluation(Add(Literal(Decimal(23)), cast(true, DecimalType.Unlimited)), Decimal(24))
checkEvaluation(Add(Literal(Decimal(23)), cast(true, DecimalType.USER_DEFAULT)), Decimal(24))
checkEvaluation(Add(Literal(23.toByte), cast(true, ByteType)), 24.toByte)
checkEvaluation(Add(Literal(23.toShort), cast(true, ShortType)), 24.toShort)
}
@ -362,10 +362,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
// - Values that would overflow the target precision should turn into null
// - Because of this, casts to fixed-precision decimals should be nullable
assert(cast(123, DecimalType.Unlimited).nullable === false)
assert(cast(10.03f, DecimalType.Unlimited).nullable === true)
assert(cast(10.03, DecimalType.Unlimited).nullable === true)
assert(cast(Decimal(10.03), DecimalType.Unlimited).nullable === false)
assert(cast(123, DecimalType.USER_DEFAULT).nullable === true)
assert(cast(10.03f, DecimalType.SYSTEM_DEFAULT).nullable === true)
assert(cast(10.03, DecimalType.SYSTEM_DEFAULT).nullable === true)
assert(cast(Decimal(10.03), DecimalType.SYSTEM_DEFAULT).nullable === true)
assert(cast(123, DecimalType(2, 1)).nullable === true)
assert(cast(10.03f, DecimalType(2, 1)).nullable === true)
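The nullable assertions above rest on a single runtime primitive: Decimal.changePrecision, which rounds a value to the target precision and scale and reports whether it still fits. A minimal sketch of the null-on-overflow rule, assuming only spark-catalyst on the classpath (castToDecimal is an illustrative helper, not Spark API):

import org.apache.spark.sql.types.{Decimal, DecimalType}

// A cast to a fixed-precision decimal yields null exactly when
// changePrecision reports that the rounded value no longer fits.
def castToDecimal(d: Double, t: DecimalType): Option[Decimal] = {
  val dec = Decimal(d)
  // changePrecision rounds in place and returns false on overflow
  if (dec.changePrecision(t.precision, t.scale)) Some(dec) else None
}

castToDecimal(10.03, DecimalType(4, 2))  // Some(10.03)
castToDecimal(123.0, DecimalType(2, 1))  // None, i.e. null in SQL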
@ -373,7 +373,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(cast(Decimal(10.03), DecimalType(2, 1)).nullable === true)
checkEvaluation(cast(10.03, DecimalType.Unlimited), Decimal(10.03))
checkEvaluation(cast(10.03, DecimalType.SYSTEM_DEFAULT), Decimal(10.03))
checkEvaluation(cast(10.03, DecimalType(4, 2)), Decimal(10.03))
checkEvaluation(cast(10.03, DecimalType(3, 1)), Decimal(10.0))
checkEvaluation(cast(10.03, DecimalType(2, 0)), Decimal(10))
@ -383,7 +383,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(cast(Decimal(10.03), DecimalType(3, 1)), Decimal(10.0))
checkEvaluation(cast(Decimal(10.03), DecimalType(3, 2)), null)
checkEvaluation(cast(10.05, DecimalType.Unlimited), Decimal(10.05))
checkEvaluation(cast(10.05, DecimalType.SYSTEM_DEFAULT), Decimal(10.05))
checkEvaluation(cast(10.05, DecimalType(4, 2)), Decimal(10.05))
checkEvaluation(cast(10.05, DecimalType(3, 1)), Decimal(10.1))
checkEvaluation(cast(10.05, DecimalType(2, 0)), Decimal(10))
@ -409,10 +409,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(cast(Decimal(-9.95), DecimalType(3, 1)), Decimal(-10.0))
checkEvaluation(cast(Decimal(-9.95), DecimalType(1, 0)), null)
checkEvaluation(cast(Double.NaN, DecimalType.Unlimited), null)
checkEvaluation(cast(1.0 / 0.0, DecimalType.Unlimited), null)
checkEvaluation(cast(Float.NaN, DecimalType.Unlimited), null)
checkEvaluation(cast(1.0f / 0.0f, DecimalType.Unlimited), null)
checkEvaluation(cast(Double.NaN, DecimalType.SYSTEM_DEFAULT), null)
checkEvaluation(cast(1.0 / 0.0, DecimalType.SYSTEM_DEFAULT), null)
checkEvaluation(cast(Float.NaN, DecimalType.SYSTEM_DEFAULT), null)
checkEvaluation(cast(1.0f / 0.0f, DecimalType.SYSTEM_DEFAULT), null)
checkEvaluation(cast(Double.NaN, DecimalType(2, 1)), null)
checkEvaluation(cast(1.0 / 0.0, DecimalType(2, 1)), null)
@ -427,7 +427,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(cast(d, LongType), null)
checkEvaluation(cast(d, FloatType), null)
checkEvaluation(cast(d, DoubleType), null)
checkEvaluation(cast(d, DecimalType.Unlimited), null)
checkEvaluation(cast(d, DecimalType.SYSTEM_DEFAULT), null)
checkEvaluation(cast(d, DecimalType(10, 2)), null)
checkEvaluation(cast(d, StringType), "1970-01-01")
checkEvaluation(cast(cast(d, TimestampType), StringType), "1970-01-01 00:00:00")
@ -454,7 +454,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
cast(cast(millis.toDouble / 1000, TimestampType), DoubleType),
millis.toDouble / 1000)
checkEvaluation(
cast(cast(Decimal(1), TimestampType), DecimalType.Unlimited),
cast(cast(Decimal(1), TimestampType), DecimalType.SYSTEM_DEFAULT),
Decimal(1))
// A test for higher precision than millis

View file

@ -60,7 +60,7 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
testIf(_.toFloat, FloatType)
testIf(_.toDouble, DoubleType)
testIf(Decimal(_), DecimalType.Unlimited)
testIf(Decimal(_), DecimalType.USER_DEFAULT)
testIf(identity, DateType)
testIf(_.toLong, TimestampType)

View file

@ -33,7 +33,7 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Literal.create(null, LongType), null)
checkEvaluation(Literal.create(null, StringType), null)
checkEvaluation(Literal.create(null, BinaryType), null)
checkEvaluation(Literal.create(null, DecimalType()), null)
checkEvaluation(Literal.create(null, DecimalType.USER_DEFAULT), null)
checkEvaluation(Literal.create(null, ArrayType(ByteType, true)), null)
checkEvaluation(Literal.create(null, MapType(StringType, IntegerType)), null)
checkEvaluation(Literal.create(null, StructType(Seq.empty)), null)

View file

@ -30,7 +30,7 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
testFunc(1L, LongType)
testFunc(1.0F, FloatType)
testFunc(1.0, DoubleType)
testFunc(Decimal(1.5), DecimalType.Unlimited)
testFunc(Decimal(1.5), DecimalType(2, 1))
testFunc(new java.sql.Date(10), DateType)
testFunc(new java.sql.Timestamp(10), TimestampType)
testFunc("abcd", StringType)

View file

@ -121,6 +121,8 @@ class UnsafeFixedWidthAggregationMapSuite
}.toSet
seenKeys.size should be (groupKeys.size)
seenKeys should be (groupKeys)
map.free()
}
}

View file

@ -145,7 +145,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
DoubleType,
StringType,
BinaryType
// DecimalType.Unlimited,
// DecimalType.Default,
// ArrayType(IntegerType)
)
val converter = new UnsafeRowConverter(fieldTypes)

View file

@ -44,7 +44,7 @@ class DataTypeParserSuite extends SparkFunSuite {
checkDataType("float", FloatType)
checkDataType("dOUBle", DoubleType)
checkDataType("decimal(10, 5)", DecimalType(10, 5))
checkDataType("decimal", DecimalType.Unlimited)
checkDataType("decimal", DecimalType.USER_DEFAULT)
checkDataType("DATE", DateType)
checkDataType("timestamp", TimestampType)
checkDataType("string", StringType)
@ -87,7 +87,7 @@ class DataTypeParserSuite extends SparkFunSuite {
StructType(
StructField("struct",
StructType(
StructField("deciMal", DecimalType.Unlimited, true) ::
StructField("deciMal", DecimalType.USER_DEFAULT, true) ::
StructField("anotherDecimal", DecimalType(5, 2), true) :: Nil), true) ::
StructField("MAP", MapType(TimestampType, StringType), true) ::
StructField("arrAy", ArrayType(DoubleType, true), true) :: Nil)

View file

@ -185,7 +185,7 @@ class DataTypeSuite extends SparkFunSuite {
checkDataTypeJsonRepr(FloatType)
checkDataTypeJsonRepr(DoubleType)
checkDataTypeJsonRepr(DecimalType(10, 5))
checkDataTypeJsonRepr(DecimalType.Unlimited)
checkDataTypeJsonRepr(DecimalType.SYSTEM_DEFAULT)
checkDataTypeJsonRepr(DateType)
checkDataTypeJsonRepr(TimestampType)
checkDataTypeJsonRepr(StringType)
@ -219,7 +219,7 @@ class DataTypeSuite extends SparkFunSuite {
checkDefaultSize(FloatType, 4)
checkDefaultSize(DoubleType, 8)
checkDefaultSize(DecimalType(10, 5), 4096)
checkDefaultSize(DecimalType.Unlimited, 4096)
checkDefaultSize(DecimalType.SYSTEM_DEFAULT, 4096)
checkDefaultSize(DateType, 4)
checkDefaultSize(TimestampType, 8)
checkDefaultSize(StringType, 4096)

View file

@ -34,7 +34,7 @@ object DataTypeTestUtils {
* decimal types.
*/
val fractionalTypes: Set[FractionalType] = Set(
DecimalType(precisionInfo = None),
DecimalType.SYSTEM_DEFAULT,
DecimalType(2, 1),
DoubleType,
FloatType

View file

@ -996,7 +996,7 @@ class ColumnName(name: String) extends Column(name) {
* Creates a new [[StructField]] of type decimal.
* @since 1.3.0
*/
def decimal: StructField = StructField(name, DecimalType.Unlimited)
def decimal: StructField = StructField(name, DecimalType.USER_DEFAULT)
/**
* Creates a new [[StructField]] of type decimal.

View file

@ -375,7 +375,7 @@ private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 8) {
private[sql] case class FIXED_DECIMAL(precision: Int, scale: Int)
extends NativeColumnType(
DecimalType(Some(PrecisionInfo(precision, scale))),
DecimalType(precision, scale),
10,
FIXED_DECIMAL.defaultSize) {

View file

@ -21,9 +21,9 @@ import org.apache.spark.TaskContext
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.trees._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.trees._
import org.apache.spark.sql.types._
case class AggregateEvaluation(
@ -92,8 +92,8 @@ case class GeneratedAggregate(
case s @ Sum(expr) =>
val calcType =
expr.dataType match {
case DecimalType.Fixed(_, _) =>
DecimalType.Unlimited
case DecimalType.Fixed(p, s) =>
DecimalType.bounded(p + 10, s)
case _ =>
expr.dataType
}
@ -121,8 +121,8 @@ case class GeneratedAggregate(
case cs @ CombineSum(expr) =>
val calcType =
expr.dataType match {
case DecimalType.Fixed(_, _) =>
DecimalType.Unlimited
case DecimalType.Fixed(p, s) =>
DecimalType.bounded(p + 10, s)
case _ =>
expr.dataType
}
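The new buffer type follows a simple widening rule: give the sum accumulator ten extra integral digits, clamped to the 38-digit maximum. DecimalType.bounded is private[sql], so the sketch below mimics its assumed behavior with a local helper:

import org.apache.spark.sql.types.DecimalType

// Hypothetical stand-in for DecimalType.bounded: clamp precision and
// scale to the maximum of 38 supported after this change.
def bounded(precision: Int, scale: Int): DecimalType =
  DecimalType(math.min(precision, 38), math.min(scale, 38))

bounded(32 + 10, 4)  // DecimalType(38,4): widened by ten digits, then capped
bounded(20 + 10, 2)  // DecimalType(30,2): no capping needed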

View file

@ -25,6 +25,7 @@ import scala.util.Try
import org.apache.hadoop.fs.Path
import org.apache.hadoop.util.Shell
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.types._
@ -236,7 +237,7 @@ private[sql] object PartitioningUtils {
/**
* Converts a string to a [[Literal]] with automatic type inference. Currently only supports
* [[IntegerType]], [[LongType]], [[DoubleType]], [[DecimalType.Unlimited]], and
* [[IntegerType]], [[LongType]], [[DoubleType]], [[DecimalType.SYSTEM_DEFAULT]], and
* [[StringType]].
*/
private[sql] def inferPartitionColumnValue(
@ -249,7 +250,7 @@ private[sql] object PartitioningUtils {
.orElse(Try(Literal.create(JLong.parseLong(raw), LongType)))
// Then falls back to fractional types
.orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType)))
.orElse(Try(Literal.create(new JBigDecimal(raw), DecimalType.Unlimited)))
.orElse(Try(Literal(new JBigDecimal(raw))))
// Then falls back to string
.getOrElse {
if (raw == defaultPartitionName) {
@ -268,7 +269,7 @@ private[sql] object PartitioningUtils {
}
private val upCastingOrder: Seq[DataType] =
Seq(NullType, IntegerType, LongType, FloatType, DoubleType, DecimalType.Unlimited, StringType)
Seq(NullType, IntegerType, LongType, FloatType, DoubleType, StringType)
/**
* Given a collection of [[Literal]]s, resolves possible type conflicts by up-casting "lower"
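The fallback chain above tries the narrowest type first and widens on parse failure, with decimals now inferred through Literal's own BigDecimal handling instead of DecimalType.Unlimited. A self-contained sketch of the same idea using plain Scala values in place of Catalyst Literals:

import java.lang.{Double => JDouble, Integer => JInteger, Long => JLong}
import java.math.{BigDecimal => JBigDecimal}
import scala.util.Try

// Integral types first, then fractional, then decimal, else keep the string.
def inferPartitionValue(raw: String): Any =
  Try(JInteger.parseInt(raw): Any)
    .orElse(Try(JLong.parseLong(raw): Any))
    .orElse(Try(JDouble.parseDouble(raw): Any))
    .orElse(Try(new JBigDecimal(raw): Any))
    .getOrElse(raw)

inferPartitionValue("15")     // 15 (Int)
inferPartitionValue("4.5")    // 4.5 (Double)
inferPartitionValue("drama")  // "drama" (String)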

View file

@ -66,8 +66,8 @@ private[sql] object JDBCRDD extends Logging {
case java.sql.Types.DATALINK => null
case java.sql.Types.DATE => DateType
case java.sql.Types.DECIMAL
if precision != 0 || scale != 0 => DecimalType(precision, scale)
case java.sql.Types.DECIMAL => DecimalType.Unlimited
if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale)
case java.sql.Types.DECIMAL => DecimalType.SYSTEM_DEFAULT
case java.sql.Types.DISTINCT => null
case java.sql.Types.DOUBLE => DoubleType
case java.sql.Types.FLOAT => FloatType
@ -80,8 +80,8 @@ private[sql] object JDBCRDD extends Logging {
case java.sql.Types.NCLOB => StringType
case java.sql.Types.NULL => null
case java.sql.Types.NUMERIC
if precision != 0 || scale != 0 => DecimalType(precision, scale)
case java.sql.Types.NUMERIC => DecimalType.Unlimited
if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale)
case java.sql.Types.NUMERIC => DecimalType.SYSTEM_DEFAULT
case java.sql.Types.NVARCHAR => StringType
case java.sql.Types.OTHER => null
case java.sql.Types.REAL => DoubleType
@ -314,7 +314,7 @@ private[sql] class JDBCRDD(
abstract class JDBCConversion
case object BooleanConversion extends JDBCConversion
case object DateConversion extends JDBCConversion
case class DecimalConversion(precisionInfo: Option[(Int, Int)]) extends JDBCConversion
case class DecimalConversion(precision: Int, scale: Int) extends JDBCConversion
case object DoubleConversion extends JDBCConversion
case object FloatConversion extends JDBCConversion
case object IntegerConversion extends JDBCConversion
@ -331,8 +331,7 @@ private[sql] class JDBCRDD(
schema.fields.map(sf => sf.dataType match {
case BooleanType => BooleanConversion
case DateType => DateConversion
case DecimalType.Unlimited => DecimalConversion(None)
case DecimalType.Fixed(d) => DecimalConversion(Some(d))
case DecimalType.Fixed(p, s) => DecimalConversion(p, s)
case DoubleType => DoubleConversion
case FloatType => FloatConversion
case IntegerType => IntegerConversion
@ -399,20 +398,13 @@ private[sql] class JDBCRDD(
// DecimalType(12, 2). Thus, after saving the DataFrame into a Parquet file and then
// retrieving it, you would get the wrong result 199.99.
// So the precision and scale for Decimal must be set based on the JDBC metadata.
case DecimalConversion(Some((p, s))) =>
case DecimalConversion(p, s) =>
val decimalVal = rs.getBigDecimal(pos)
if (decimalVal == null) {
mutableRow.update(i, null)
} else {
mutableRow.update(i, Decimal(decimalVal, p, s))
}
case DecimalConversion(None) =>
val decimalVal = rs.getBigDecimal(pos)
if (decimalVal == null) {
mutableRow.update(i, null)
} else {
mutableRow.update(i, Decimal(decimalVal))
}
case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos))
case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos))
case IntegerConversion => mutableRow.setInt(i, rs.getInt(pos))
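The decimal branch of the JDBC type mapping now distinguishes two cases: metadata that reports a real precision and scale, and drivers that report (0, 0). A hedged sketch of the rule (the clamping mimics what DecimalType.bounded is assumed to do):

import java.sql.Types
import org.apache.spark.sql.types._

// Columns with real metadata keep their (clamped) precision and scale;
// DECIMAL/NUMERIC columns reporting (0, 0) fall back to Decimal(38, 18).
def decimalCatalystType(sqlType: Int, precision: Int, scale: Int): Option[DataType] =
  sqlType match {
    case Types.DECIMAL | Types.NUMERIC if precision != 0 || scale != 0 =>
      Some(DecimalType(math.min(precision, 38), math.min(scale, 38)))
    case Types.DECIMAL | Types.NUMERIC =>
      Some(DecimalType.SYSTEM_DEFAULT)
    case _ => None
  }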

View file

@ -89,8 +89,7 @@ package object jdbc {
case BinaryType => stmt.setBytes(i + 1, row.getAs[Array[Byte]](i))
case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i))
case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i))
case DecimalType.Unlimited => stmt.setBigDecimal(i + 1,
row.getAs[java.math.BigDecimal](i))
case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i))
case _ => throw new IllegalArgumentException(
s"Can't translate non-null value for field $i")
}
@ -145,7 +144,7 @@ package object jdbc {
case BinaryType => "BLOB"
case TimestampType => "TIMESTAMP"
case DateType => "DATE"
case DecimalType.Unlimited => "DECIMAL(40,20)"
case t: DecimalType => s"DECIMAL(${t.precision}},${t.scale}})"
case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC")
})
val nullable = if (field.nullable) "" else "NOT NULL"
@ -177,7 +176,7 @@ package object jdbc {
case BinaryType => java.sql.Types.BLOB
case TimestampType => java.sql.Types.TIMESTAMP
case DateType => java.sql.Types.DATE
case DecimalType.Unlimited => java.sql.Types.DECIMAL
case t: DecimalType => java.sql.Types.DECIMAL
case _ => throw new IllegalArgumentException(
s"Can't translate null value for field $field")
})
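With the unlimited case gone, every decimal field maps to an explicit DECIMAL(p,s) in the generated DDL rather than the old hard-coded DECIMAL(40,20). A small sketch of that mapping and what it produces:

import org.apache.spark.sql.types.DecimalType

def decimalDdl(t: DecimalType): String = s"DECIMAL(${t.precision},${t.scale})"

decimalDdl(DecimalType(38, 18))  // "DECIMAL(38,18)"
decimalDdl(DecimalType(10, 0))   // "DECIMAL(10,0)"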

View file

@ -113,7 +113,7 @@ private[sql] object InferSchema {
case INT | LONG => LongType
// Since we do not have a data type backed by BigInteger,
// when we see a Java BigInteger, we use DecimalType.
case BIG_INTEGER | BIG_DECIMAL => DecimalType.Unlimited
case BIG_INTEGER | BIG_DECIMAL => DecimalType.SYSTEM_DEFAULT
case FLOAT | DOUBLE => DoubleType
}
@ -168,8 +168,13 @@ private[sql] object InferSchema {
HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2).getOrElse {
// t1 or t2 is a StructType, ArrayType, or an unexpected type.
(t1, t2) match {
case (other: DataType, NullType) => other
case (NullType, other: DataType) => other
// Double supports a larger range than fixed decimal and wins in most conflicts, but
// DecimalType.SYSTEM_DEFAULT is kept because it offers better precision.
case (DoubleType, t: DecimalType) =>
if (t == DecimalType.SYSTEM_DEFAULT) t else DoubleType
case (t: DecimalType, DoubleType) =>
if (t == DecimalType.SYSTEM_DEFAULT) t else DoubleType
case (StructType(fields1), StructType(fields2)) =>
val newFields = (fields1 ++ fields2).groupBy(field => field.name).map {
case (name, fieldTypes) =>
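Read in isolation, the scalar part of the merge rule is: nulls defer to the other side, and a double/decimal conflict resolves to double unless the decimal is the system default inferred for big integers. A condensed sketch (struct and array merging omitted):

import org.apache.spark.sql.types._

def mergeScalars(t1: DataType, t2: DataType): DataType = (t1, t2) match {
  case (other, NullType) => other
  case (NullType, other) => other
  case (DoubleType, t: DecimalType) =>
    if (t == DecimalType.SYSTEM_DEFAULT) t else DoubleType
  case (t: DecimalType, DoubleType) =>
    if (t == DecimalType.SYSTEM_DEFAULT) t else DoubleType
  case _ => StringType  // incompatible scalars degrade to string
}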

View file

@ -439,10 +439,6 @@ private[parquet] class CatalystSchemaConverter(
.length(minBytesForPrecision(precision))
.named(field.name)
case dec @ DecimalType.Unlimited if followParquetFormatSpec =>
throw new AnalysisException(
s"Data type $dec is not supported. Decimal precision and scale must be specified.")
// ===================================================
// ArrayType and MapType (for Spark versions <= 1.4.x)
// ===================================================

View file

@ -261,10 +261,10 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
case BinaryType => writer.addBinary(
Binary.fromByteArray(value.asInstanceOf[Array[Byte]]))
case d: DecimalType =>
if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) {
if (d.precision > 18) {
sys.error(s"Unsupported datatype $d, cannot write to consumer")
}
writeDecimal(value.asInstanceOf[Decimal], d.precisionInfo.get.precision)
writeDecimal(value.asInstanceOf[Decimal], d.precision)
case _ => sys.error(s"Do not know how to writer $schema to consumer")
}
}
@ -415,10 +415,10 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport {
case BinaryType => writer.addBinary(
Binary.fromByteArray(record(index).asInstanceOf[Array[Byte]]))
case d: DecimalType =>
if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) {
if (d.precision > 18) {
sys.error(s"Unsupported datatype $d, cannot write to consumer")
}
writeDecimal(record(index).asInstanceOf[Decimal], d.precisionInfo.get.precision)
writeDecimal(record(index).asInstanceOf[Decimal], d.precision)
case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer")
}
}
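The precision > 18 guard above appears to reflect how this writer encodes decimals: the unscaled value is read as a single 64-bit long (via Decimal.toUnscaledLong), and 18 decimal digits is the largest count guaranteed to fit. A quick check of that bound:

// 10^18 - 1 fits in a signed 64-bit long; 10^19 - 1 does not.
BigInt(10).pow(18) - 1 <= BigInt(Long.MaxValue)  // true
BigInt(10).pow(19) - 1 <= BigInt(Long.MaxValue)  // false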

View file

@ -22,7 +22,6 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.spark.sql.test.TestSQLContext$;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
@ -31,8 +30,14 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.*;
import org.apache.spark.sql.types.*;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.test.TestSQLContext$;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
// The test suite itself is Serializable so that anonymous Function implementations can be
// serialized, as an alternative to converting these anonymous classes to static inner classes;
@ -159,7 +164,8 @@ public class JavaApplySchemaSuite implements Serializable {
"\"bigInteger\":92233720368547758069, \"double\":1.7976931348623157E305, " +
"\"boolean\":false, \"null\":null}"));
List<StructField> fields = new ArrayList<StructField>(7);
fields.add(DataTypes.createStructField("bigInteger", DataTypes.createDecimalType(), true));
fields.add(DataTypes.createStructField("bigInteger", DataTypes.createDecimalType(38, 18),
true));
fields.add(DataTypes.createStructField("boolean", DataTypes.BooleanType, true));
fields.add(DataTypes.createStructField("double", DataTypes.DoubleType, true));
fields.add(DataTypes.createStructField("integer", DataTypes.LongType, true));

View file

@ -148,7 +148,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
val dataTypes =
Seq(StringType, BinaryType, NullType, BooleanType,
ByteType, ShortType, IntegerType, LongType,
FloatType, DoubleType, DecimalType.Unlimited, DecimalType(6, 5),
FloatType, DoubleType, DecimalType.SYSTEM_DEFAULT, DecimalType(6, 5),
DateType, TimestampType,
ArrayType(IntegerType), MapType(StringType, LongType), struct)
val fields = dataTypes.zipWithIndex.map { case (dataType, index) =>

View file

@ -109,7 +109,7 @@ class PlannerSuite extends SparkFunSuite {
FloatType ::
DoubleType ::
DecimalType(10, 5) ::
DecimalType.Unlimited ::
DecimalType.SYSTEM_DEFAULT ::
DateType ::
TimestampType ::
StringType ::

View file

@ -54,7 +54,7 @@ class SparkSqlSerializer2DataTypeSuite extends SparkFunSuite {
checkSupported(StringType, isSupported = true)
checkSupported(BinaryType, isSupported = true)
checkSupported(DecimalType(10, 5), isSupported = true)
checkSupported(DecimalType.Unlimited, isSupported = true)
checkSupported(DecimalType.SYSTEM_DEFAULT, isSupported = true)
// If NullType is the only data type in the schema, we do not support it.
checkSupported(NullType, isSupported = false)
@ -86,7 +86,7 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
val supportedTypes =
Seq(StringType, BinaryType, NullType, BooleanType,
ByteType, ShortType, IntegerType, LongType,
FloatType, DoubleType, DecimalType.Unlimited, DecimalType(6, 5),
FloatType, DoubleType, DecimalType.SYSTEM_DEFAULT, DecimalType(6, 5),
DateType, TimestampType)
val fields = supportedTypes.zipWithIndex.map { case (dataType, index) =>

View file

@ -134,7 +134,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
""".stripMargin.replaceAll("\n", " "))
conn.prepareStatement("create table test.flttypes (a DOUBLE, b REAL, c DECIMAL(40, 20))"
conn.prepareStatement("create table test.flttypes (a DOUBLE, b REAL, c DECIMAL(38, 18))"
).executeUpdate()
conn.prepareStatement("insert into test.flttypes values ("
+ "1.0000000000000002220446049250313080847263336181640625, "
@ -152,7 +152,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
s"""
|create table test.nulltypes (a INT, b BOOLEAN, c TINYINT, d BINARY(20), e VARCHAR(20),
|f VARCHAR_IGNORECASE(20), g CHAR(20), h BLOB, i CLOB, j TIME, k DATE, l TIMESTAMP,
|m DOUBLE, n REAL, o DECIMAL(40, 20))
|m DOUBLE, n REAL, o DECIMAL(38, 18))
""".stripMargin.replaceAll("\n", " ")).executeUpdate()
conn.prepareStatement("insert into test.nulltypes values ("
+ "null, null, null, null, null, null, null, null, null, "
@ -357,14 +357,14 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
test("H2 floating-point types") {
val rows = sql("SELECT * FROM flttypes").collect()
assert(rows(0).getDouble(0) === 1.00000000000000022) // Yes, I meant ==.
assert(rows(0).getDouble(1) === 1.00000011920928955) // Yes, I meant ==.
assert(rows(0).getAs[BigDecimal](2)
.equals(new BigDecimal("123456789012345.54321543215432100000")))
assert(rows(0).schema.fields(2).dataType === DecimalType(40, 20))
val compareDecimal = sql("SELECT C FROM flttypes where C > C - 1").collect()
assert(compareDecimal(0).getAs[BigDecimal](0)
.equals(new BigDecimal("123456789012345.54321543215432100000")))
assert(rows(0).getDouble(0) === 1.00000000000000022)
assert(rows(0).getDouble(1) === 1.00000011920928955)
assert(rows(0).getAs[BigDecimal](2) ===
new BigDecimal("123456789012345.543215432154321000"))
assert(rows(0).schema.fields(2).dataType === DecimalType(38, 18))
val result = sql("SELECT C FROM flttypes where C > C - 1").collect()
assert(result(0).getAs[BigDecimal](0) ===
new BigDecimal("123456789012345.543215432154321000"))
}
test("SQL query as table name") {

View file

@ -63,18 +63,18 @@ class JsonSuite extends QueryTest with TestJsonData {
checkTypePromotion(intNumber.toLong, enforceCorrectType(intNumber, LongType))
checkTypePromotion(intNumber.toDouble, enforceCorrectType(intNumber, DoubleType))
checkTypePromotion(
Decimal(intNumber), enforceCorrectType(intNumber, DecimalType.Unlimited))
Decimal(intNumber), enforceCorrectType(intNumber, DecimalType.SYSTEM_DEFAULT))
val longNumber: Long = 9223372036854775807L
checkTypePromotion(longNumber, enforceCorrectType(longNumber, LongType))
checkTypePromotion(longNumber.toDouble, enforceCorrectType(longNumber, DoubleType))
checkTypePromotion(
Decimal(longNumber), enforceCorrectType(longNumber, DecimalType.Unlimited))
Decimal(longNumber), enforceCorrectType(longNumber, DecimalType.SYSTEM_DEFAULT))
val doubleNumber: Double = 1.7976931348623157E308d
checkTypePromotion(doubleNumber.toDouble, enforceCorrectType(doubleNumber, DoubleType))
checkTypePromotion(
Decimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType.Unlimited))
Decimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType.SYSTEM_DEFAULT))
checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(intNumber)),
enforceCorrectType(intNumber, TimestampType))
@ -115,7 +115,7 @@ class JsonSuite extends QueryTest with TestJsonData {
checkDataType(NullType, IntegerType, IntegerType)
checkDataType(NullType, LongType, LongType)
checkDataType(NullType, DoubleType, DoubleType)
checkDataType(NullType, DecimalType.Unlimited, DecimalType.Unlimited)
checkDataType(NullType, DecimalType.SYSTEM_DEFAULT, DecimalType.SYSTEM_DEFAULT)
checkDataType(NullType, StringType, StringType)
checkDataType(NullType, ArrayType(IntegerType), ArrayType(IntegerType))
checkDataType(NullType, StructType(Nil), StructType(Nil))
@ -126,7 +126,7 @@ class JsonSuite extends QueryTest with TestJsonData {
checkDataType(BooleanType, IntegerType, StringType)
checkDataType(BooleanType, LongType, StringType)
checkDataType(BooleanType, DoubleType, StringType)
checkDataType(BooleanType, DecimalType.Unlimited, StringType)
checkDataType(BooleanType, DecimalType.SYSTEM_DEFAULT, StringType)
checkDataType(BooleanType, StringType, StringType)
checkDataType(BooleanType, ArrayType(IntegerType), StringType)
checkDataType(BooleanType, StructType(Nil), StringType)
@ -135,7 +135,7 @@ class JsonSuite extends QueryTest with TestJsonData {
checkDataType(IntegerType, IntegerType, IntegerType)
checkDataType(IntegerType, LongType, LongType)
checkDataType(IntegerType, DoubleType, DoubleType)
checkDataType(IntegerType, DecimalType.Unlimited, DecimalType.Unlimited)
checkDataType(IntegerType, DecimalType.SYSTEM_DEFAULT, DecimalType.SYSTEM_DEFAULT)
checkDataType(IntegerType, StringType, StringType)
checkDataType(IntegerType, ArrayType(IntegerType), StringType)
checkDataType(IntegerType, StructType(Nil), StringType)
@ -143,23 +143,24 @@ class JsonSuite extends QueryTest with TestJsonData {
// LongType
checkDataType(LongType, LongType, LongType)
checkDataType(LongType, DoubleType, DoubleType)
checkDataType(LongType, DecimalType.Unlimited, DecimalType.Unlimited)
checkDataType(LongType, DecimalType.SYSTEM_DEFAULT, DecimalType.SYSTEM_DEFAULT)
checkDataType(LongType, StringType, StringType)
checkDataType(LongType, ArrayType(IntegerType), StringType)
checkDataType(LongType, StructType(Nil), StringType)
// DoubleType
checkDataType(DoubleType, DoubleType, DoubleType)
checkDataType(DoubleType, DecimalType.Unlimited, DecimalType.Unlimited)
checkDataType(DoubleType, DecimalType.SYSTEM_DEFAULT, DecimalType.SYSTEM_DEFAULT)
checkDataType(DoubleType, StringType, StringType)
checkDataType(DoubleType, ArrayType(IntegerType), StringType)
checkDataType(DoubleType, StructType(Nil), StringType)
// DoubleType
checkDataType(DecimalType.Unlimited, DecimalType.Unlimited, DecimalType.Unlimited)
checkDataType(DecimalType.Unlimited, StringType, StringType)
checkDataType(DecimalType.Unlimited, ArrayType(IntegerType), StringType)
checkDataType(DecimalType.Unlimited, StructType(Nil), StringType)
// DecimalType
checkDataType(DecimalType.SYSTEM_DEFAULT, DecimalType.SYSTEM_DEFAULT,
DecimalType.SYSTEM_DEFAULT)
checkDataType(DecimalType.SYSTEM_DEFAULT, StringType, StringType)
checkDataType(DecimalType.SYSTEM_DEFAULT, ArrayType(IntegerType), StringType)
checkDataType(DecimalType.SYSTEM_DEFAULT, StructType(Nil), StringType)
// StringType
checkDataType(StringType, StringType, StringType)
@ -213,7 +214,7 @@ class JsonSuite extends QueryTest with TestJsonData {
checkDataType(
StructType(
StructField("f1", IntegerType, true) :: Nil),
DecimalType.Unlimited,
DecimalType.SYSTEM_DEFAULT,
StringType)
}
@ -240,7 +241,7 @@ class JsonSuite extends QueryTest with TestJsonData {
val jsonDF = ctx.read.json(primitiveFieldAndType)
val expectedSchema = StructType(
StructField("bigInteger", DecimalType.Unlimited, true) ::
StructField("bigInteger", DecimalType.SYSTEM_DEFAULT, true) ::
StructField("boolean", BooleanType, true) ::
StructField("double", DoubleType, true) ::
StructField("integer", LongType, true) ::
@ -270,7 +271,7 @@ class JsonSuite extends QueryTest with TestJsonData {
val expectedSchema = StructType(
StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) ::
StructField("arrayOfArray2", ArrayType(ArrayType(DoubleType, true), true), true) ::
StructField("arrayOfBigInteger", ArrayType(DecimalType.Unlimited, true), true) ::
StructField("arrayOfBigInteger", ArrayType(DecimalType.SYSTEM_DEFAULT, true), true) ::
StructField("arrayOfBoolean", ArrayType(BooleanType, true), true) ::
StructField("arrayOfDouble", ArrayType(DoubleType, true), true) ::
StructField("arrayOfInteger", ArrayType(LongType, true), true) ::
@ -284,7 +285,7 @@ class JsonSuite extends QueryTest with TestJsonData {
StructField("field3", StringType, true) :: Nil), true), true) ::
StructField("struct", StructType(
StructField("field1", BooleanType, true) ::
StructField("field2", DecimalType.Unlimited, true) :: Nil), true) ::
StructField("field2", DecimalType.SYSTEM_DEFAULT, true) :: Nil), true) ::
StructField("structWithArrayFields", StructType(
StructField("field1", ArrayType(LongType, true), true) ::
StructField("field2", ArrayType(StringType, true), true) :: Nil), true) :: Nil)
@ -385,7 +386,7 @@ class JsonSuite extends QueryTest with TestJsonData {
val expectedSchema = StructType(
StructField("num_bool", StringType, true) ::
StructField("num_num_1", LongType, true) ::
StructField("num_num_2", DecimalType.Unlimited, true) ::
StructField("num_num_2", DecimalType.SYSTEM_DEFAULT, true) ::
StructField("num_num_3", DoubleType, true) ::
StructField("num_str", StringType, true) ::
StructField("str_bool", StringType, true) :: Nil)
@ -421,11 +422,11 @@ class JsonSuite extends QueryTest with TestJsonData {
Row(-89) :: Row(21474836370L) :: Row(21474836470L) :: Nil
)
// Widening to DecimalType
// Widening to DoubleType
checkAnswer(
sql("select num_num_2 + 1.2 from jsonTable where num_num_2 > 1.1"),
Row(new java.math.BigDecimal("21474836472.1")) ::
Row(new java.math.BigDecimal("92233720368547758071.2")) :: Nil
sql("select num_num_2 + 1.3 from jsonTable where num_num_2 > 1.1"),
Row(21474836472.2) ::
Row(92233720368547758071.3) :: Nil
)
// Widening to DoubleType
@ -442,8 +443,8 @@ class JsonSuite extends QueryTest with TestJsonData {
// Number and String conflict: resolve the type as number in this query.
checkAnswer(
sql("select num_str + 1.2 from jsonTable where num_str > 92233720368547758060"),
Row(new java.math.BigDecimal("92233720368547758061.2").doubleValue)
sql("select num_str + 1.2 from jsonTable where num_str >= 92233720368547758060"),
Row(new java.math.BigDecimal("92233720368547758071.2").doubleValue)
)
// String and Boolean conflict: resolve the type as string.
@ -489,9 +490,9 @@ class JsonSuite extends QueryTest with TestJsonData {
// in the Project.
checkAnswer(
jsonDF.
where('num_str > BigDecimal("92233720368547758060")).
where('num_str >= BigDecimal("92233720368547758060")).
select(('num_str + 1.2).as("num")),
Row(new java.math.BigDecimal("92233720368547758061.2"))
Row(new java.math.BigDecimal("92233720368547758071.2").doubleValue())
)
// The following test will fail. The type of num_str is StringType.
@ -610,7 +611,7 @@ class JsonSuite extends QueryTest with TestJsonData {
val jsonDF = ctx.read.json(path)
val expectedSchema = StructType(
StructField("bigInteger", DecimalType.Unlimited, true) ::
StructField("bigInteger", DecimalType.SYSTEM_DEFAULT, true) ::
StructField("boolean", BooleanType, true) ::
StructField("double", DoubleType, true) ::
StructField("integer", LongType, true) ::
@ -668,7 +669,7 @@ class JsonSuite extends QueryTest with TestJsonData {
primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path)
val schema = StructType(
StructField("bigInteger", DecimalType.Unlimited, true) ::
StructField("bigInteger", DecimalType.SYSTEM_DEFAULT, true) ::
StructField("boolean", BooleanType, true) ::
StructField("double", DoubleType, true) ::
StructField("integer", IntegerType, true) ::

View file

@ -122,14 +122,6 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
sqlContext.read.parquet(dir.getCanonicalPath).collect()
}
}
// Unlimited-length decimals are not yet supported
intercept[Throwable] {
withTempPath { dir =>
makeDecimalRDD(DecimalType.Unlimited).write.parquet(dir.getCanonicalPath)
sqlContext.read.parquet(dir.getCanonicalPath).collect()
}
}
}
test("date type") {

View file

@ -509,7 +509,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
FloatType,
DoubleType,
DecimalType(10, 5),
DecimalType.Unlimited,
DecimalType.SYSTEM_DEFAULT,
DateType,
TimestampType,
StringType)

View file

@ -44,7 +44,7 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo
StructField("doubleType", DoubleType, nullable = false),
StructField("bigintType", LongType, nullable = false),
StructField("tinyintType", ByteType, nullable = false),
StructField("decimalType", DecimalType.Unlimited, nullable = false),
StructField("decimalType", DecimalType.USER_DEFAULT, nullable = false),
StructField("fixedDecimalType", DecimalType(5, 1), nullable = false),
StructField("binaryType", BinaryType, nullable = false),
StructField("booleanType", BooleanType, nullable = false),

View file

@ -202,7 +202,7 @@ class TableScanSuite extends DataSourceTest {
StructField("longField_:,<>=+/~^", LongType, true) ::
StructField("floatField", FloatType, true) ::
StructField("doubleField", DoubleType, true) ::
StructField("decimalField1", DecimalType.Unlimited, true) ::
StructField("decimalField1", DecimalType.USER_DEFAULT, true) ::
StructField("decimalField2", DecimalType(9, 2), true) ::
StructField("dateField", DateType, true) ::
StructField("timestampField", TimestampType, true) ::

View file

@ -179,7 +179,7 @@ private[hive] trait HiveInspectors {
// writable
case c: Class[_] if c == classOf[hadoopIo.DoubleWritable] => DoubleType
case c: Class[_] if c == classOf[hiveIo.DoubleWritable] => DoubleType
case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType.Unlimited
case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType.SYSTEM_DEFAULT
case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType
case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType
case c: Class[_] if c == classOf[hiveIo.DateWritable] => DateType
@ -195,8 +195,8 @@ private[hive] trait HiveInspectors {
case c: Class[_] if c == classOf[java.lang.String] => StringType
case c: Class[_] if c == classOf[java.sql.Date] => DateType
case c: Class[_] if c == classOf[java.sql.Timestamp] => TimestampType
case c: Class[_] if c == classOf[HiveDecimal] => DecimalType.Unlimited
case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType.Unlimited
case c: Class[_] if c == classOf[HiveDecimal] => DecimalType.SYSTEM_DEFAULT
case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType.SYSTEM_DEFAULT
case c: Class[_] if c == classOf[Array[Byte]] => BinaryType
case c: Class[_] if c == classOf[java.lang.Short] => ShortType
case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType
@ -813,9 +813,6 @@ private[hive] trait HiveInspectors {
private def decimalTypeInfo(decimalType: DecimalType): TypeInfo = decimalType match {
case DecimalType.Fixed(precision, scale) => new DecimalTypeInfo(precision, scale)
case _ => new DecimalTypeInfo(
HiveShim.UNLIMITED_DECIMAL_PRECISION,
HiveShim.UNLIMITED_DECIMAL_SCALE)
}
def toTypeInfo: TypeInfo = dt match {

View file

@ -377,7 +377,7 @@ private[hive] object HiveQl extends Logging {
DecimalType(precision.getText.toInt, scale.getText.toInt)
case Token("TOK_DECIMAL", precision :: Nil) =>
DecimalType(precision.getText.toInt, 0)
case Token("TOK_DECIMAL", Nil) => DecimalType.Unlimited
case Token("TOK_DECIMAL", Nil) => DecimalType.USER_DEFAULT
case Token("TOK_BIGINT", Nil) => LongType
case Token("TOK_INT", Nil) => IntegerType
case Token("TOK_TINYINT", Nil) => ByteType
@ -1369,7 +1369,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: Nil) :: arg :: Nil) =>
Cast(nodeToExpr(arg), DecimalType(precision.getText.toInt, 0))
case Token("TOK_FUNCTION", Token("TOK_DECIMAL", Nil) :: arg :: Nil) =>
Cast(nodeToExpr(arg), DecimalType.Unlimited)
Cast(nodeToExpr(arg), DecimalType.USER_DEFAULT)
case Token("TOK_FUNCTION", Token("TOK_TIMESTAMP", Nil) :: arg :: Nil) =>
Cast(nodeToExpr(arg), TimestampType)
case Token("TOK_FUNCTION", Token("TOK_DATE", Nil) :: arg :: Nil) =>