[SPARK-9069] [SPARK-9264] [SQL] remove unlimited precision support for DecimalType
Romove Decimal.Unlimited (change to support precision up to 38, to match with Hive and other databases). In order to keep backward source compatibility, Decimal.Unlimited is still there, but change to Decimal(38, 18). If no precision and scale is provide, it's Decimal(10, 0) as before. Author: Davies Liu <davies@databricks.com> Closes #7605 from davies/decimal_unlimited and squashes the following commits: aa3f115 [Davies Liu] fix tests and style fb0d20d [Davies Liu] address comments bfaae35 [Davies Liu] fix style df93657 [Davies Liu] address comments and clean up 06727fd [Davies Liu] Merge branch 'master' of github.com:apache/spark into decimal_unlimited 4c28969 [Davies Liu] fix tests 8d783cc [Davies Liu] fix tests 788631c [Davies Liu] fix double with decimal in Union/except 1779bde [Davies Liu] fix scala style c9c7c78 [Davies Liu] remove Decimal.Unlimited
This commit is contained in:
parent
bebe3f7b45
commit
8a94eb23d5
|
@ -218,7 +218,7 @@ class AttributeSuite extends SparkFunSuite {
|
|||
// Attribute.fromStructField should accept any NumericType, not just DoubleType
|
||||
val longFldWithMeta = new StructField("x", LongType, false, metadata)
|
||||
assert(Attribute.fromStructField(longFldWithMeta).isNumeric)
|
||||
val decimalFldWithMeta = new StructField("x", DecimalType(None), false, metadata)
|
||||
val decimalFldWithMeta = new StructField("x", DecimalType(38, 18), false, metadata)
|
||||
assert(Attribute.fromStructField(decimalFldWithMeta).isNumeric)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -194,30 +194,33 @@ class TimestampType(AtomicType):
|
|||
|
||||
class DecimalType(FractionalType):
|
||||
"""Decimal (decimal.Decimal) data type.
|
||||
|
||||
The DecimalType must have fixed precision (the maximum total number of digits)
|
||||
and scale (the number of digits on the right of dot). For example, (5, 2) can
|
||||
support the value from [-999.99 to 999.99].
|
||||
|
||||
The precision can be up to 38, the scale must less or equal to precision.
|
||||
|
||||
When create a DecimalType, the default precision and scale is (10, 0). When infer
|
||||
schema from decimal.Decimal objects, it will be DecimalType(38, 18).
|
||||
|
||||
:param precision: the maximum total number of digits (default: 10)
|
||||
:param scale: the number of digits on right side of dot. (default: 0)
|
||||
"""
|
||||
|
||||
def __init__(self, precision=None, scale=None):
|
||||
def __init__(self, precision=10, scale=0):
|
||||
self.precision = precision
|
||||
self.scale = scale
|
||||
self.hasPrecisionInfo = precision is not None
|
||||
self.hasPrecisionInfo = True # this is public API
|
||||
|
||||
def simpleString(self):
|
||||
if self.hasPrecisionInfo:
|
||||
return "decimal(%d,%d)" % (self.precision, self.scale)
|
||||
else:
|
||||
return "decimal(10,0)"
|
||||
return "decimal(%d,%d)" % (self.precision, self.scale)
|
||||
|
||||
def jsonValue(self):
|
||||
if self.hasPrecisionInfo:
|
||||
return "decimal(%d,%d)" % (self.precision, self.scale)
|
||||
else:
|
||||
return "decimal"
|
||||
return "decimal(%d,%d)" % (self.precision, self.scale)
|
||||
|
||||
def __repr__(self):
|
||||
if self.hasPrecisionInfo:
|
||||
return "DecimalType(%d,%d)" % (self.precision, self.scale)
|
||||
else:
|
||||
return "DecimalType()"
|
||||
return "DecimalType(%d,%d)" % (self.precision, self.scale)
|
||||
|
||||
|
||||
class DoubleType(FractionalType):
|
||||
|
@ -761,7 +764,10 @@ def _infer_type(obj):
|
|||
return obj.__UDT__
|
||||
|
||||
dataType = _type_mappings.get(type(obj))
|
||||
if dataType is not None:
|
||||
if dataType is DecimalType:
|
||||
# the precision and scale of `obj` may be different from row to row.
|
||||
return DecimalType(38, 18)
|
||||
elif dataType is not None:
|
||||
return dataType()
|
||||
|
||||
if isinstance(obj, dict):
|
||||
|
|
|
@ -111,12 +111,18 @@ public class DataTypes {
|
|||
return new ArrayType(elementType, containsNull);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a DecimalType by specifying the precision and scale.
|
||||
*/
|
||||
public static DecimalType createDecimalType(int precision, int scale) {
|
||||
return DecimalType$.MODULE$.apply(precision, scale);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a DecimalType with default precision and scale, which are 10 and 0.
|
||||
*/
|
||||
public static DecimalType createDecimalType() {
|
||||
return DecimalType$.MODULE$.Unlimited();
|
||||
return DecimalType$.MODULE$.USER_DEFAULT();
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -75,7 +75,7 @@ private [sql] object JavaTypeInference {
|
|||
case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true)
|
||||
case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true)
|
||||
|
||||
case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true)
|
||||
case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType.SYSTEM_DEFAULT, true)
|
||||
case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
|
||||
case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
|
||||
|
||||
|
|
|
@ -131,10 +131,10 @@ trait ScalaReflection {
|
|||
case t if t <:< localTypeOf[String] => Schema(StringType, nullable = true)
|
||||
case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true)
|
||||
case t if t <:< localTypeOf[java.sql.Date] => Schema(DateType, nullable = true)
|
||||
case t if t <:< localTypeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true)
|
||||
case t if t <:< localTypeOf[BigDecimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
|
||||
case t if t <:< localTypeOf[java.math.BigDecimal] =>
|
||||
Schema(DecimalType.Unlimited, nullable = true)
|
||||
case t if t <:< localTypeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true)
|
||||
Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
|
||||
case t if t <:< localTypeOf[Decimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
|
||||
case t if t <:< localTypeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
|
||||
case t if t <:< localTypeOf[java.lang.Long] => Schema(LongType, nullable = true)
|
||||
case t if t <:< localTypeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
|
||||
|
@ -167,8 +167,8 @@ trait ScalaReflection {
|
|||
case obj: Float => FloatType
|
||||
case obj: Double => DoubleType
|
||||
case obj: java.sql.Date => DateType
|
||||
case obj: java.math.BigDecimal => DecimalType.Unlimited
|
||||
case obj: Decimal => DecimalType.Unlimited
|
||||
case obj: java.math.BigDecimal => DecimalType.SYSTEM_DEFAULT
|
||||
case obj: Decimal => DecimalType.SYSTEM_DEFAULT
|
||||
case obj: java.sql.Timestamp => TimestampType
|
||||
case null => NullType
|
||||
// For other cases, there is no obvious mapping from the type of the given object to a
|
||||
|
|
|
@ -322,7 +322,10 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
|
|||
|
||||
protected lazy val numericLiteral: Parser[Literal] =
|
||||
( integral ^^ { case i => Literal(toNarrowestIntegerType(i)) }
|
||||
| sign.? ~ unsignedFloat ^^ { case s ~ f => Literal((s.getOrElse("") + f).toDouble) }
|
||||
| sign.? ~ unsignedFloat ^^ {
|
||||
// TODO(davies): some precisions may loss, we should create decimal literal
|
||||
case s ~ f => Literal(BigDecimal(s.getOrElse("") + f).doubleValue())
|
||||
}
|
||||
)
|
||||
|
||||
protected lazy val unsignedFloat: Parser[String] =
|
||||
|
|
|
@ -19,7 +19,9 @@ package org.apache.spark.sql.catalyst.analysis
|
|||
|
||||
import javax.annotation.Nullable
|
||||
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
|
||||
import org.apache.spark.sql.catalyst.plans.logical._
|
||||
import org.apache.spark.sql.catalyst.rules.Rule
|
||||
import org.apache.spark.sql.types._
|
||||
|
@ -58,8 +60,7 @@ object HiveTypeCoercion {
|
|||
IntegerType,
|
||||
LongType,
|
||||
FloatType,
|
||||
DoubleType,
|
||||
DecimalType.Unlimited)
|
||||
DoubleType)
|
||||
|
||||
/**
|
||||
* Find the tightest common type of two types that might be used in a binary expression.
|
||||
|
@ -72,15 +73,16 @@ object HiveTypeCoercion {
|
|||
case (NullType, t1) => Some(t1)
|
||||
case (t1, NullType) => Some(t1)
|
||||
|
||||
// Promote numeric types to the highest of the two and all numeric types to unlimited decimal
|
||||
case (t1: IntegralType, t2: DecimalType) if t2.isWiderThan(t1) =>
|
||||
Some(t2)
|
||||
case (t1: DecimalType, t2: IntegralType) if t1.isWiderThan(t2) =>
|
||||
Some(t1)
|
||||
|
||||
// Promote numeric types to the highest of the two
|
||||
case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) =>
|
||||
val index = numericPrecedence.lastIndexWhere(t => t == t1 || t == t2)
|
||||
Some(numericPrecedence(index))
|
||||
|
||||
// Fixed-precision decimals can up-cast into unlimited
|
||||
case (DecimalType.Unlimited, _: DecimalType) => Some(DecimalType.Unlimited)
|
||||
case (_: DecimalType, DecimalType.Unlimited) => Some(DecimalType.Unlimited)
|
||||
|
||||
case _ => None
|
||||
}
|
||||
|
||||
|
@ -101,7 +103,7 @@ object HiveTypeCoercion {
|
|||
types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
|
||||
case None => None
|
||||
case Some(d) =>
|
||||
findTightestCommonTypeOfTwo(d, c).orElse(findTightestCommonTypeToString(d, c))
|
||||
findTightestCommonTypeToString(d, c)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -158,6 +160,9 @@ object HiveTypeCoercion {
|
|||
* converted to DOUBLE.
|
||||
* - TINYINT, SMALLINT, and INT can all be converted to FLOAT.
|
||||
* - BOOLEAN types cannot be converted to any other type.
|
||||
* - Any integral numeric type can be implicitly converted to decimal type.
|
||||
* - two different decimal types will be converted into a wider decimal type for both of them.
|
||||
* - decimal type will be converted into double if there float or double together with it.
|
||||
*
|
||||
* Additionally, all types when UNION-ed with strings will be promoted to strings.
|
||||
* Other string conversions are handled by PromoteStrings.
|
||||
|
@ -166,55 +171,50 @@ object HiveTypeCoercion {
|
|||
* - IntegerType to FloatType
|
||||
* - LongType to FloatType
|
||||
* - LongType to DoubleType
|
||||
* - DecimalType to Double
|
||||
*
|
||||
* This rule is only applied to Union/Except/Intersect
|
||||
*/
|
||||
object WidenTypes extends Rule[LogicalPlan] {
|
||||
|
||||
private[this] def widenOutputTypes(planName: String, left: LogicalPlan, right: LogicalPlan):
|
||||
(LogicalPlan, LogicalPlan) = {
|
||||
|
||||
// TODO: with fixed-precision decimals
|
||||
val castedInput = left.output.zip(right.output).map {
|
||||
// When a string is found on one side, make the other side a string too.
|
||||
case (lhs, rhs) if lhs.dataType == StringType && rhs.dataType != StringType =>
|
||||
(lhs, Alias(Cast(rhs, StringType), rhs.name)())
|
||||
case (lhs, rhs) if lhs.dataType != StringType && rhs.dataType == StringType =>
|
||||
(Alias(Cast(lhs, StringType), lhs.name)(), rhs)
|
||||
private[this] def widenOutputTypes(
|
||||
planName: String,
|
||||
left: LogicalPlan,
|
||||
right: LogicalPlan): (LogicalPlan, LogicalPlan) = {
|
||||
|
||||
val castedTypes = left.output.zip(right.output).map {
|
||||
case (lhs, rhs) if lhs.dataType != rhs.dataType =>
|
||||
logDebug(s"Resolving mismatched $planName input ${lhs.dataType}, ${rhs.dataType}")
|
||||
findTightestCommonTypeOfTwo(lhs.dataType, rhs.dataType).map { widestType =>
|
||||
val newLeft =
|
||||
if (lhs.dataType == widestType) lhs else Alias(Cast(lhs, widestType), lhs.name)()
|
||||
val newRight =
|
||||
if (rhs.dataType == widestType) rhs else Alias(Cast(rhs, widestType), rhs.name)()
|
||||
|
||||
(newLeft, newRight)
|
||||
}.getOrElse {
|
||||
// If there is no applicable conversion, leave expression unchanged.
|
||||
(lhs, rhs)
|
||||
(lhs.dataType, rhs.dataType) match {
|
||||
case (t1: DecimalType, t2: DecimalType) =>
|
||||
Some(DecimalPrecision.widerDecimalType(t1, t2))
|
||||
case (t: IntegralType, d: DecimalType) =>
|
||||
Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
|
||||
case (d: DecimalType, t: IntegralType) =>
|
||||
Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
|
||||
case (t: FractionalType, d: DecimalType) =>
|
||||
Some(DoubleType)
|
||||
case (d: DecimalType, t: FractionalType) =>
|
||||
Some(DoubleType)
|
||||
case _ =>
|
||||
findTightestCommonTypeToString(lhs.dataType, rhs.dataType)
|
||||
}
|
||||
|
||||
case other => other
|
||||
case other => None
|
||||
}
|
||||
|
||||
val (castedLeft, castedRight) = castedInput.unzip
|
||||
|
||||
val newLeft =
|
||||
if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) {
|
||||
logDebug(s"Widening numeric types in $planName $castedLeft ${left.output}")
|
||||
Project(castedLeft, left)
|
||||
} else {
|
||||
left
|
||||
def castOutput(plan: LogicalPlan): LogicalPlan = {
|
||||
val casted = plan.output.zip(castedTypes).map {
|
||||
case (hs, Some(dt)) if dt != hs.dataType =>
|
||||
Alias(Cast(hs, dt), hs.name)()
|
||||
case (hs, _) => hs
|
||||
}
|
||||
Project(casted, plan)
|
||||
}
|
||||
|
||||
val newRight =
|
||||
if (castedRight.map(_.dataType) != right.output.map(_.dataType)) {
|
||||
logDebug(s"Widening numeric types in $planName $castedRight ${right.output}")
|
||||
Project(castedRight, right)
|
||||
} else {
|
||||
right
|
||||
}
|
||||
(newLeft, newRight)
|
||||
if (castedTypes.exists(_.isDefined)) {
|
||||
(castOutput(left), castOutput(right))
|
||||
} else {
|
||||
(left, right)
|
||||
}
|
||||
}
|
||||
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
|
||||
|
@ -334,144 +334,94 @@ object HiveTypeCoercion {
|
|||
* - SHORT gets turned into DECIMAL(5, 0)
|
||||
* - INT gets turned into DECIMAL(10, 0)
|
||||
* - LONG gets turned into DECIMAL(20, 0)
|
||||
* - FLOAT and DOUBLE
|
||||
* 1. Union, Intersect and Except operations:
|
||||
* FLOAT gets turned into DECIMAL(7, 7), DOUBLE gets turned into DECIMAL(15, 15) (this is the
|
||||
* same as Hive)
|
||||
* 2. Other operation:
|
||||
* FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE (this is the same as Hive,
|
||||
* but note that unlimited decimals are considered bigger than doubles in WidenTypes)
|
||||
* - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE
|
||||
*
|
||||
* Note: Union/Except/Interact is handled by WidenTypes
|
||||
*/
|
||||
// scalastyle:on
|
||||
object DecimalPrecision extends Rule[LogicalPlan] {
|
||||
import scala.math.{max, min}
|
||||
|
||||
// Conversion rules for integer types into fixed-precision decimals
|
||||
private val intTypeToFixed: Map[DataType, DecimalType] = Map(
|
||||
ByteType -> DecimalType(3, 0),
|
||||
ShortType -> DecimalType(5, 0),
|
||||
IntegerType -> DecimalType(10, 0),
|
||||
LongType -> DecimalType(20, 0)
|
||||
)
|
||||
|
||||
private def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType
|
||||
|
||||
// Conversion rules for float and double into fixed-precision decimals
|
||||
private val floatTypeToFixed: Map[DataType, DecimalType] = Map(
|
||||
FloatType -> DecimalType(7, 7),
|
||||
DoubleType -> DecimalType(15, 15)
|
||||
)
|
||||
// Returns the wider decimal type that's wider than both of them
|
||||
def widerDecimalType(d1: DecimalType, d2: DecimalType): DecimalType = {
|
||||
widerDecimalType(d1.precision, d1.scale, d2.precision, d2.scale)
|
||||
}
|
||||
// max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)
|
||||
def widerDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = {
|
||||
val scale = max(s1, s2)
|
||||
val range = max(p1 - s1, p2 - s2)
|
||||
DecimalType.bounded(range + scale, scale)
|
||||
}
|
||||
|
||||
private def castDecimalPrecision(
|
||||
left: LogicalPlan,
|
||||
right: LogicalPlan): (LogicalPlan, LogicalPlan) = {
|
||||
val castedInput = left.output.zip(right.output).map {
|
||||
case (lhs, rhs) if lhs.dataType != rhs.dataType =>
|
||||
(lhs.dataType, rhs.dataType) match {
|
||||
case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) =>
|
||||
// Decimals with precision/scale p1/s2 and p2/s2 will be promoted to
|
||||
// DecimalType(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2))
|
||||
val fixedType = DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2), max(s1, s2))
|
||||
(Alias(Cast(lhs, fixedType), lhs.name)(), Alias(Cast(rhs, fixedType), rhs.name)())
|
||||
case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) =>
|
||||
(Alias(Cast(lhs, intTypeToFixed(t)), lhs.name)(), rhs)
|
||||
case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) =>
|
||||
(lhs, Alias(Cast(rhs, intTypeToFixed(t)), rhs.name)())
|
||||
case (t, DecimalType.Fixed(p, s)) if floatTypeToFixed.contains(t) =>
|
||||
(Alias(Cast(lhs, floatTypeToFixed(t)), lhs.name)(), rhs)
|
||||
case (DecimalType.Fixed(p, s), t) if floatTypeToFixed.contains(t) =>
|
||||
(lhs, Alias(Cast(rhs, floatTypeToFixed(t)), rhs.name)())
|
||||
case _ => (lhs, rhs)
|
||||
}
|
||||
case other => other
|
||||
}
|
||||
/**
|
||||
* An expression used to wrap the children when promote the precision of DecimalType to avoid
|
||||
* promote multiple times.
|
||||
*/
|
||||
case class ChangePrecision(child: Expression) extends UnaryExpression {
|
||||
override def dataType: DataType = child.dataType
|
||||
override def eval(input: InternalRow): Any = child.eval(input)
|
||||
override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx)
|
||||
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = ""
|
||||
override def prettyName: String = "change_precision"
|
||||
}
|
||||
|
||||
val (castedLeft, castedRight) = castedInput.unzip
|
||||
|
||||
val newLeft =
|
||||
if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) {
|
||||
Project(castedLeft, left)
|
||||
} else {
|
||||
left
|
||||
}
|
||||
|
||||
val newRight =
|
||||
if (castedRight.map(_.dataType) != right.output.map(_.dataType)) {
|
||||
Project(castedRight, right)
|
||||
} else {
|
||||
right
|
||||
}
|
||||
(newLeft, newRight)
|
||||
def changePrecision(e: Expression, dataType: DataType): Expression = {
|
||||
ChangePrecision(Cast(e, dataType))
|
||||
}
|
||||
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
|
||||
// fix decimal precision for union, intersect and except
|
||||
case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
|
||||
val (newLeft, newRight) = castDecimalPrecision(left, right)
|
||||
Union(newLeft, newRight)
|
||||
case i @ Intersect(left, right) if i.childrenResolved && !i.resolved =>
|
||||
val (newLeft, newRight) = castDecimalPrecision(left, right)
|
||||
Intersect(newLeft, newRight)
|
||||
case e @ Except(left, right) if e.childrenResolved && !e.resolved =>
|
||||
val (newLeft, newRight) = castDecimalPrecision(left, right)
|
||||
Except(newLeft, newRight)
|
||||
|
||||
// fix decimal precision for expressions
|
||||
case q => q.transformExpressions {
|
||||
// Skip nodes whose children have not been resolved yet
|
||||
case e if !e.childrenResolved => e
|
||||
|
||||
// Skip nodes who is already promoted
|
||||
case e: BinaryArithmetic if e.left.isInstanceOf[ChangePrecision] => e
|
||||
|
||||
case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
|
||||
Cast(
|
||||
Add(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
|
||||
DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
|
||||
)
|
||||
val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
|
||||
Add(changePrecision(e1, dt), changePrecision(e2, dt))
|
||||
|
||||
case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
|
||||
Cast(
|
||||
Subtract(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
|
||||
DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
|
||||
)
|
||||
val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
|
||||
Subtract(changePrecision(e1, dt), changePrecision(e2, dt))
|
||||
|
||||
case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
|
||||
Cast(
|
||||
Multiply(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
|
||||
DecimalType(p1 + p2 + 1, s1 + s2)
|
||||
)
|
||||
val dt = DecimalType.bounded(p1 + p2 + 1, s1 + s2)
|
||||
Multiply(changePrecision(e1, dt), changePrecision(e2, dt))
|
||||
|
||||
case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
|
||||
Cast(
|
||||
Divide(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
|
||||
DecimalType(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1))
|
||||
)
|
||||
val dt = DecimalType.bounded(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1))
|
||||
Divide(changePrecision(e1, dt), changePrecision(e2, dt))
|
||||
|
||||
case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
|
||||
Cast(
|
||||
Remainder(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
|
||||
DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
|
||||
)
|
||||
val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
|
||||
// resultType may have lower precision, so we cast them into wider type first.
|
||||
val widerType = widerDecimalType(p1, s1, p2, s2)
|
||||
Cast(Remainder(changePrecision(e1, widerType), changePrecision(e2, widerType)),
|
||||
resultType)
|
||||
|
||||
case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
|
||||
Cast(
|
||||
Pmod(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
|
||||
DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
|
||||
)
|
||||
val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
|
||||
// resultType may have lower precision, so we cast them into wider type first.
|
||||
val widerType = widerDecimalType(p1, s1, p2, s2)
|
||||
Cast(Pmod(changePrecision(e1, widerType), changePrecision(e2, widerType)), resultType)
|
||||
|
||||
// When we compare 2 decimal types with different precisions, cast them to the smallest
|
||||
// common precision.
|
||||
case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1),
|
||||
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
|
||||
val resultType = DecimalType(max(p1, p2), max(s1, s2))
|
||||
val resultType = widerDecimalType(p1, s1, p2, s2)
|
||||
b.makeCopy(Array(Cast(e1, resultType), Cast(e2, resultType)))
|
||||
|
||||
// Promote integers inside a binary expression with fixed-precision decimals to decimals,
|
||||
// and fixed-precision decimals in an expression with floats / doubles to doubles
|
||||
case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
|
||||
(left.dataType, right.dataType) match {
|
||||
case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) =>
|
||||
b.makeCopy(Array(Cast(left, intTypeToFixed(t)), right))
|
||||
case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) =>
|
||||
b.makeCopy(Array(left, Cast(right, intTypeToFixed(t))))
|
||||
case (t: IntegralType, DecimalType.Fixed(p, s)) =>
|
||||
b.makeCopy(Array(Cast(left, DecimalType.forType(t)), right))
|
||||
case (DecimalType.Fixed(p, s), t: IntegralType) =>
|
||||
b.makeCopy(Array(left, Cast(right, DecimalType.forType(t))))
|
||||
case (t, DecimalType.Fixed(p, s)) if isFloat(t) =>
|
||||
b.makeCopy(Array(left, Cast(right, DoubleType)))
|
||||
case (DecimalType.Fixed(p, s), t) if isFloat(t) =>
|
||||
|
@ -485,7 +435,6 @@ object HiveTypeCoercion {
|
|||
// SUM and AVERAGE are handled by the implementations of those expressions
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -563,7 +512,7 @@ object HiveTypeCoercion {
|
|||
case e if !e.childrenResolved => e
|
||||
|
||||
case Cast(e @ StringType(), t: IntegralType) =>
|
||||
Cast(Cast(e, DecimalType.Unlimited), t)
|
||||
Cast(Cast(e, DecimalType.forType(LongType)), t)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -756,8 +705,8 @@ object HiveTypeCoercion {
|
|||
// Implicit cast among numeric types. When we reach here, input type is not acceptable.
|
||||
|
||||
// If input is a numeric type but not decimal, and we expect a decimal type,
|
||||
// cast the input to unlimited precision decimal.
|
||||
case (_: NumericType, DecimalType) => Cast(e, DecimalType.Unlimited)
|
||||
// cast the input to decimal.
|
||||
case (d: NumericType, DecimalType) => Cast(e, DecimalType.forType(d))
|
||||
// For any other numeric types, implicitly cast to each other, e.g. long -> int, int -> long
|
||||
case (_: NumericType, target: NumericType) => Cast(e, target)
|
||||
|
||||
|
@ -766,7 +715,7 @@ object HiveTypeCoercion {
|
|||
case (TimestampType, DateType) => Cast(e, DateType)
|
||||
|
||||
// Implicit cast from/to string
|
||||
case (StringType, DecimalType) => Cast(e, DecimalType.Unlimited)
|
||||
case (StringType, DecimalType) => Cast(e, DecimalType.SYSTEM_DEFAULT)
|
||||
case (StringType, target: NumericType) => Cast(e, target)
|
||||
case (StringType, DateType) => Cast(e, DateType)
|
||||
case (StringType, TimestampType) => Cast(e, TimestampType)
|
||||
|
|
|
@ -201,7 +201,7 @@ package object dsl {
|
|||
|
||||
/** Creates a new AttributeReference of type decimal */
|
||||
def decimal: AttributeReference =
|
||||
AttributeReference(s, DecimalType.Unlimited, nullable = true)()
|
||||
AttributeReference(s, DecimalType.SYSTEM_DEFAULT, nullable = true)()
|
||||
|
||||
/** Creates a new AttributeReference of type decimal */
|
||||
def decimal(precision: Int, scale: Int): AttributeReference =
|
||||
|
|
|
@ -300,12 +300,7 @@ case class Cast(child: Expression, dataType: DataType)
|
|||
* NOTE: this modifies `value` in-place, so don't call it on external data.
|
||||
*/
|
||||
private[this] def changePrecision(value: Decimal, decimalType: DecimalType): Decimal = {
|
||||
decimalType match {
|
||||
case DecimalType.Unlimited =>
|
||||
value
|
||||
case DecimalType.Fixed(precision, scale) =>
|
||||
if (value.changePrecision(precision, scale)) value else null
|
||||
}
|
||||
if (value.changePrecision(decimalType.precision, decimalType.scale)) value else null
|
||||
}
|
||||
|
||||
private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match {
|
||||
|
|
|
@ -36,14 +36,13 @@ case class Average(child: Expression) extends AlgebraicAggregate {
|
|||
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType))
|
||||
|
||||
private val resultType = child.dataType match {
|
||||
case DecimalType.Fixed(precision, scale) =>
|
||||
DecimalType(precision + 4, scale + 4)
|
||||
case DecimalType.Unlimited => DecimalType.Unlimited
|
||||
case DecimalType.Fixed(p, s) =>
|
||||
DecimalType.bounded(p + 4, s + 4)
|
||||
case _ => DoubleType
|
||||
}
|
||||
|
||||
private val sumDataType = child.dataType match {
|
||||
case _ @ DecimalType() => DecimalType.Unlimited
|
||||
case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s)
|
||||
case _ => DoubleType
|
||||
}
|
||||
|
||||
|
@ -71,7 +70,14 @@ case class Average(child: Expression) extends AlgebraicAggregate {
|
|||
)
|
||||
|
||||
// If all input are nulls, currentCount will be 0 and we will get null after the division.
|
||||
override val evaluateExpression = Cast(currentSum, resultType) / Cast(currentCount, resultType)
|
||||
override val evaluateExpression = child.dataType match {
|
||||
case DecimalType.Fixed(p, s) =>
|
||||
// increase the precision and scale to prevent precision loss
|
||||
val dt = DecimalType.bounded(p + 14, s + 4)
|
||||
Cast(Cast(currentSum, dt) / Cast(currentCount, dt), resultType)
|
||||
case _ =>
|
||||
Cast(currentSum, resultType) / Cast(currentCount, resultType)
|
||||
}
|
||||
}
|
||||
|
||||
case class Count(child: Expression) extends AlgebraicAggregate {
|
||||
|
@ -255,15 +261,11 @@ case class Sum(child: Expression) extends AlgebraicAggregate {
|
|||
|
||||
private val resultType = child.dataType match {
|
||||
case DecimalType.Fixed(precision, scale) =>
|
||||
DecimalType(precision + 4, scale + 4)
|
||||
case DecimalType.Unlimited => DecimalType.Unlimited
|
||||
DecimalType.bounded(precision + 10, scale)
|
||||
case _ => child.dataType
|
||||
}
|
||||
|
||||
private val sumDataType = child.dataType match {
|
||||
case _ @ DecimalType() => DecimalType.Unlimited
|
||||
case _ => child.dataType
|
||||
}
|
||||
private val sumDataType = resultType
|
||||
|
||||
private val currentSum = AttributeReference("currentSum", sumDataType)()
|
||||
|
||||
|
|
|
@ -390,22 +390,21 @@ case class Average(child: Expression) extends UnaryExpression with PartialAggreg
|
|||
|
||||
override def dataType: DataType = child.dataType match {
|
||||
case DecimalType.Fixed(precision, scale) =>
|
||||
DecimalType(precision + 4, scale + 4) // Add 4 digits after decimal point, like Hive
|
||||
case DecimalType.Unlimited =>
|
||||
DecimalType.Unlimited
|
||||
// Add 4 digits after decimal point, like Hive
|
||||
DecimalType.bounded(precision + 4, scale + 4)
|
||||
case _ =>
|
||||
DoubleType
|
||||
}
|
||||
|
||||
override def asPartial: SplitEvaluation = {
|
||||
child.dataType match {
|
||||
case DecimalType.Fixed(_, _) | DecimalType.Unlimited =>
|
||||
// Turn the child to unlimited decimals for calculation, before going back to fixed
|
||||
val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")()
|
||||
case DecimalType.Fixed(precision, scale) =>
|
||||
val partialSum = Alias(Sum(child), "PartialSum")()
|
||||
val partialCount = Alias(Count(child), "PartialCount")()
|
||||
|
||||
val castedSum = Cast(Sum(partialSum.toAttribute), DecimalType.Unlimited)
|
||||
val castedCount = Cast(Sum(partialCount.toAttribute), DecimalType.Unlimited)
|
||||
// partialSum already increase the precision by 10
|
||||
val castedSum = Cast(Sum(partialSum.toAttribute), partialSum.dataType)
|
||||
val castedCount = Sum(partialCount.toAttribute)
|
||||
SplitEvaluation(
|
||||
Cast(Divide(castedSum, castedCount), dataType),
|
||||
partialCount :: partialSum :: Nil)
|
||||
|
@ -435,8 +434,8 @@ case class AverageFunction(expr: Expression, base: AggregateExpression1)
|
|||
|
||||
private val calcType =
|
||||
expr.dataType match {
|
||||
case DecimalType.Fixed(_, _) =>
|
||||
DecimalType.Unlimited
|
||||
case DecimalType.Fixed(precision, scale) =>
|
||||
DecimalType.bounded(precision + 10, scale)
|
||||
case _ =>
|
||||
expr.dataType
|
||||
}
|
||||
|
@ -454,10 +453,9 @@ case class AverageFunction(expr: Expression, base: AggregateExpression1)
|
|||
null
|
||||
} else {
|
||||
expr.dataType match {
|
||||
case DecimalType.Fixed(_, _) =>
|
||||
Cast(Divide(
|
||||
Cast(sum, DecimalType.Unlimited),
|
||||
Cast(Literal(count), DecimalType.Unlimited)), dataType).eval(null)
|
||||
case DecimalType.Fixed(precision, scale) =>
|
||||
val dt = DecimalType.bounded(precision + 14, scale + 4)
|
||||
Cast(Divide(Cast(sum, dt), Cast(Literal(count), dt)), dataType).eval(null)
|
||||
case _ =>
|
||||
Divide(
|
||||
Cast(sum, dataType),
|
||||
|
@ -481,9 +479,8 @@ case class Sum(child: Expression) extends UnaryExpression with PartialAggregate1
|
|||
|
||||
override def dataType: DataType = child.dataType match {
|
||||
case DecimalType.Fixed(precision, scale) =>
|
||||
DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive
|
||||
case DecimalType.Unlimited =>
|
||||
DecimalType.Unlimited
|
||||
// Add 10 digits left of decimal point, like Hive
|
||||
DecimalType.bounded(precision + 10, scale)
|
||||
case _ =>
|
||||
child.dataType
|
||||
}
|
||||
|
@ -491,7 +488,7 @@ case class Sum(child: Expression) extends UnaryExpression with PartialAggregate1
|
|||
override def asPartial: SplitEvaluation = {
|
||||
child.dataType match {
|
||||
case DecimalType.Fixed(_, _) =>
|
||||
val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")()
|
||||
val partialSum = Alias(Sum(child), "PartialSum")()
|
||||
SplitEvaluation(
|
||||
Cast(CombineSum(partialSum.toAttribute), dataType),
|
||||
partialSum :: Nil)
|
||||
|
@ -515,8 +512,8 @@ case class SumFunction(expr: Expression, base: AggregateExpression1) extends Agg
|
|||
|
||||
private val calcType =
|
||||
expr.dataType match {
|
||||
case DecimalType.Fixed(_, _) =>
|
||||
DecimalType.Unlimited
|
||||
case DecimalType.Fixed(precision, scale) =>
|
||||
DecimalType.bounded(precision + 10, scale)
|
||||
case _ =>
|
||||
expr.dataType
|
||||
}
|
||||
|
@ -572,8 +569,8 @@ case class CombineSumFunction(expr: Expression, base: AggregateExpression1)
|
|||
|
||||
private val calcType =
|
||||
expr.dataType match {
|
||||
case DecimalType.Fixed(_, _) =>
|
||||
DecimalType.Unlimited
|
||||
case DecimalType.Fixed(precision, scale) =>
|
||||
DecimalType.bounded(precision + 10, scale)
|
||||
case _ =>
|
||||
expr.dataType
|
||||
}
|
||||
|
@ -608,9 +605,8 @@ case class SumDistinct(child: Expression) extends UnaryExpression with PartialAg
|
|||
override def nullable: Boolean = true
|
||||
override def dataType: DataType = child.dataType match {
|
||||
case DecimalType.Fixed(precision, scale) =>
|
||||
DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive
|
||||
case DecimalType.Unlimited =>
|
||||
DecimalType.Unlimited
|
||||
// Add 10 digits left of decimal point, like Hive
|
||||
DecimalType.bounded(precision + 10, scale)
|
||||
case _ =>
|
||||
child.dataType
|
||||
}
|
||||
|
|
|
@ -88,6 +88,8 @@ abstract class BinaryArithmetic extends BinaryOperator {
|
|||
|
||||
override def dataType: DataType = left.dataType
|
||||
|
||||
override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess
|
||||
|
||||
/** Name of the function for this expression on a [[Decimal]] type. */
|
||||
def decimalMethod: String =
|
||||
sys.error("BinaryArithmetics must override either decimalMethod or genCode")
|
||||
|
@ -114,9 +116,6 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
|
|||
|
||||
override def symbol: String = "+"
|
||||
|
||||
override lazy val resolved =
|
||||
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
|
||||
|
||||
private lazy val numeric = TypeUtils.getNumeric(dataType)
|
||||
|
||||
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
|
||||
|
@ -146,9 +145,6 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
|
|||
|
||||
override def symbol: String = "-"
|
||||
|
||||
override lazy val resolved =
|
||||
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
|
||||
|
||||
private lazy val numeric = TypeUtils.getNumeric(dataType)
|
||||
|
||||
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
|
||||
|
@ -179,9 +175,6 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
|
|||
override def symbol: String = "*"
|
||||
override def decimalMethod: String = "$times"
|
||||
|
||||
override lazy val resolved =
|
||||
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
|
||||
|
||||
private lazy val numeric = TypeUtils.getNumeric(dataType)
|
||||
|
||||
protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2)
|
||||
|
@ -195,9 +188,6 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
|
|||
override def decimalMethod: String = "$div"
|
||||
override def nullable: Boolean = true
|
||||
|
||||
override lazy val resolved =
|
||||
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
|
||||
|
||||
private lazy val div: (Any, Any) => Any = dataType match {
|
||||
case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div
|
||||
case it: IntegralType => it.integral.asInstanceOf[Integral[Any]].quot
|
||||
|
@ -260,9 +250,6 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
|
|||
override def decimalMethod: String = "remainder"
|
||||
override def nullable: Boolean = true
|
||||
|
||||
override lazy val resolved =
|
||||
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
|
||||
|
||||
private lazy val integral = dataType match {
|
||||
case i: IntegralType => i.integral.asInstanceOf[Integral[Any]]
|
||||
case i: FractionalType => i.asIntegral.asInstanceOf[Integral[Any]]
|
||||
|
|
|
@ -36,9 +36,9 @@ object Literal {
|
|||
case s: Short => Literal(s, ShortType)
|
||||
case s: String => Literal(UTF8String.fromString(s), StringType)
|
||||
case b: Boolean => Literal(b, BooleanType)
|
||||
case d: BigDecimal => Literal(Decimal(d), DecimalType.Unlimited)
|
||||
case d: java.math.BigDecimal => Literal(Decimal(d), DecimalType.Unlimited)
|
||||
case d: Decimal => Literal(d, DecimalType.Unlimited)
|
||||
case d: BigDecimal => Literal(Decimal(d), DecimalType(d.precision, d.scale))
|
||||
case d: java.math.BigDecimal => Literal(Decimal(d), DecimalType(d.precision(), d.scale()))
|
||||
case d: Decimal => Literal(d, DecimalType(d.precision, d.scale))
|
||||
case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType)
|
||||
case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType)
|
||||
case a: Array[Byte] => Literal(a, BinaryType)
|
||||
|
|
|
@ -17,9 +17,9 @@
|
|||
|
||||
package org.apache.spark.sql.catalyst.plans
|
||||
|
||||
import org.apache.spark.sql.catalyst.expressions.{VirtualColumn, Attribute, AttributeSet, Expression}
|
||||
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, VirtualColumn}
|
||||
import org.apache.spark.sql.catalyst.trees.TreeNode
|
||||
import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType}
|
||||
import org.apache.spark.sql.types.{DataType, StructType}
|
||||
|
||||
abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanType] {
|
||||
self: PlanType =>
|
||||
|
|
|
@ -106,7 +106,7 @@ object DataType {
|
|||
private def nameToType(name: String): DataType = {
|
||||
val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\d+)\s*\)""".r
|
||||
name match {
|
||||
case "decimal" => DecimalType.Unlimited
|
||||
case "decimal" => DecimalType.USER_DEFAULT
|
||||
case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt)
|
||||
case other => nonDecimalNameToType(other)
|
||||
}
|
||||
|
@ -177,7 +177,7 @@ object DataType {
|
|||
| "BinaryType" ^^^ BinaryType
|
||||
| "BooleanType" ^^^ BooleanType
|
||||
| "DateType" ^^^ DateType
|
||||
| "DecimalType()" ^^^ DecimalType.Unlimited
|
||||
| "DecimalType()" ^^^ DecimalType.USER_DEFAULT
|
||||
| fixedDecimalType
|
||||
| "TimestampType" ^^^ TimestampType
|
||||
)
|
||||
|
|
|
@ -48,7 +48,7 @@ private[sql] trait DataTypeParser extends StandardTokenParsers {
|
|||
"(?i)binary".r ^^^ BinaryType |
|
||||
"(?i)boolean".r ^^^ BooleanType |
|
||||
fixedDecimalType |
|
||||
"(?i)decimal".r ^^^ DecimalType.Unlimited |
|
||||
"(?i)decimal".r ^^^ DecimalType.USER_DEFAULT |
|
||||
"(?i)date".r ^^^ DateType |
|
||||
"(?i)timestamp".r ^^^ TimestampType |
|
||||
varchar
|
||||
|
|
|
@ -26,25 +26,46 @@ import org.apache.spark.sql.catalyst.expressions.Expression
|
|||
|
||||
|
||||
/** Precision parameters for a Decimal */
|
||||
@deprecated("Use DecimalType(precision, scale) directly", "1.5")
|
||||
case class PrecisionInfo(precision: Int, scale: Int) {
|
||||
if (scale > precision) {
|
||||
throw new AnalysisException(
|
||||
s"Decimal scale ($scale) cannot be greater than precision ($precision).")
|
||||
}
|
||||
if (precision > DecimalType.MAX_PRECISION) {
|
||||
throw new AnalysisException(
|
||||
s"DecimalType can only support precision up to 38"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* :: DeveloperApi ::
|
||||
* The data type representing `java.math.BigDecimal` values.
|
||||
* A Decimal that might have fixed precision and scale, or unlimited values for these.
|
||||
* A Decimal that must have fixed precision (the maximum number of digits) and scale (the number
|
||||
* of digits on right side of dot).
|
||||
*
|
||||
* The precision can be up to 38, scale can also be up to 38 (less or equal to precision).
|
||||
*
|
||||
* The default precision and scale is (10, 0).
|
||||
*
|
||||
* Please use [[DataTypes.createDecimalType()]] to create a specific instance.
|
||||
*/
|
||||
@DeveloperApi
|
||||
case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalType {
|
||||
case class DecimalType(precision: Int, scale: Int) extends FractionalType {
|
||||
|
||||
/** No-arg constructor for kryo. */
|
||||
protected def this() = this(null)
|
||||
// default constructor for Java
|
||||
def this(precision: Int) = this(precision, 0)
|
||||
def this() = this(10)
|
||||
|
||||
@deprecated("Use DecimalType(precision, scale) instead", "1.5")
|
||||
def this(precisionInfo: Option[PrecisionInfo]) {
|
||||
this(precisionInfo.getOrElse(PrecisionInfo(10, 0)).precision,
|
||||
precisionInfo.getOrElse(PrecisionInfo(10, 0)).scale)
|
||||
}
|
||||
|
||||
@deprecated("Use DecimalType.precision and DecimalType.scale instead", "1.5")
|
||||
val precisionInfo = Some(PrecisionInfo(precision, scale))
|
||||
|
||||
private[sql] type InternalType = Decimal
|
||||
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] }
|
||||
|
@ -53,18 +74,16 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT
|
|||
private[sql] val ordering = Decimal.DecimalIsFractional
|
||||
private[sql] val asIntegral = Decimal.DecimalAsIfIntegral
|
||||
|
||||
def precision: Int = precisionInfo.map(_.precision).getOrElse(-1)
|
||||
override def typeName: String = s"decimal($precision,$scale)"
|
||||
|
||||
def scale: Int = precisionInfo.map(_.scale).getOrElse(-1)
|
||||
override def toString: String = s"DecimalType($precision,$scale)"
|
||||
|
||||
override def typeName: String = precisionInfo match {
|
||||
case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)"
|
||||
case None => "decimal"
|
||||
}
|
||||
|
||||
override def toString: String = precisionInfo match {
|
||||
case Some(PrecisionInfo(precision, scale)) => s"DecimalType($precision,$scale)"
|
||||
case None => "DecimalType()"
|
||||
private[sql] def isWiderThan(other: DataType): Boolean = other match {
|
||||
case dt: DecimalType =>
|
||||
(precision - scale) >= (dt.precision - dt.scale) && scale >= dt.scale
|
||||
case dt: IntegralType =>
|
||||
isWiderThan(DecimalType.forType(dt))
|
||||
case _ => false
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -72,10 +91,7 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT
|
|||
*/
|
||||
override def defaultSize: Int = 4096
|
||||
|
||||
override def simpleString: String = precisionInfo match {
|
||||
case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)"
|
||||
case None => "decimal(10,0)"
|
||||
}
|
||||
override def simpleString: String = s"decimal($precision,$scale)"
|
||||
|
||||
private[spark] override def asNullable: DecimalType = this
|
||||
}
|
||||
|
@ -83,8 +99,47 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT
|
|||
|
||||
/** Extra factory methods and pattern matchers for Decimals */
|
||||
object DecimalType extends AbstractDataType {
|
||||
import scala.math.min
|
||||
|
||||
override private[sql] def defaultConcreteType: DataType = Unlimited
|
||||
val MAX_PRECISION = 38
|
||||
val MAX_SCALE = 38
|
||||
val SYSTEM_DEFAULT: DecimalType = DecimalType(MAX_PRECISION, 18)
|
||||
val USER_DEFAULT: DecimalType = DecimalType(10, 0)
|
||||
|
||||
@deprecated("Does not support unlimited precision, please specify the precision and scale", "1.5")
|
||||
val Unlimited: DecimalType = SYSTEM_DEFAULT
|
||||
|
||||
// The decimal types compatible with other numberic types
|
||||
private[sql] val ByteDecimal = DecimalType(3, 0)
|
||||
private[sql] val ShortDecimal = DecimalType(5, 0)
|
||||
private[sql] val IntDecimal = DecimalType(10, 0)
|
||||
private[sql] val LongDecimal = DecimalType(20, 0)
|
||||
private[sql] val FloatDecimal = DecimalType(14, 7)
|
||||
private[sql] val DoubleDecimal = DecimalType(30, 15)
|
||||
|
||||
private[sql] def forType(dataType: DataType): DecimalType = dataType match {
|
||||
case ByteType => ByteDecimal
|
||||
case ShortType => ShortDecimal
|
||||
case IntegerType => IntDecimal
|
||||
case LongType => LongDecimal
|
||||
case FloatType => FloatDecimal
|
||||
case DoubleType => DoubleDecimal
|
||||
}
|
||||
|
||||
@deprecated("please specify precision and scale", "1.5")
|
||||
def apply(): DecimalType = USER_DEFAULT
|
||||
|
||||
@deprecated("Use DecimalType(precision, scale) instead", "1.5")
|
||||
def apply(precisionInfo: Option[PrecisionInfo]) {
|
||||
this(precisionInfo.getOrElse(PrecisionInfo(10, 0)).precision,
|
||||
precisionInfo.getOrElse(PrecisionInfo(10, 0)).scale)
|
||||
}
|
||||
|
||||
private[sql] def bounded(precision: Int, scale: Int): DecimalType = {
|
||||
DecimalType(min(precision, MAX_PRECISION), min(scale, MAX_SCALE))
|
||||
}
|
||||
|
||||
override private[sql] def defaultConcreteType: DataType = SYSTEM_DEFAULT
|
||||
|
||||
override private[sql] def acceptsType(other: DataType): Boolean = {
|
||||
other.isInstanceOf[DecimalType]
|
||||
|
@ -92,31 +147,18 @@ object DecimalType extends AbstractDataType {
|
|||
|
||||
override private[sql] def simpleString: String = "decimal"
|
||||
|
||||
val Unlimited: DecimalType = DecimalType(None)
|
||||
|
||||
private[sql] object Fixed {
|
||||
def unapply(t: DecimalType): Option[(Int, Int)] =
|
||||
t.precisionInfo.map(p => (p.precision, p.scale))
|
||||
def unapply(t: DecimalType): Option[(Int, Int)] = Some((t.precision, t.scale))
|
||||
}
|
||||
|
||||
private[sql] object Expression {
|
||||
def unapply(e: Expression): Option[(Int, Int)] = e.dataType match {
|
||||
case t: DecimalType => t.precisionInfo.map(p => (p.precision, p.scale))
|
||||
case t: DecimalType => Some((t.precision, t.scale))
|
||||
case _ => None
|
||||
}
|
||||
}
|
||||
|
||||
def apply(): DecimalType = Unlimited
|
||||
|
||||
def apply(precision: Int, scale: Int): DecimalType =
|
||||
DecimalType(Some(PrecisionInfo(precision, scale)))
|
||||
|
||||
def unapply(t: DataType): Boolean = t.isInstanceOf[DecimalType]
|
||||
|
||||
def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[DecimalType]
|
||||
|
||||
def isFixed(dataType: DataType): Boolean = dataType match {
|
||||
case DecimalType.Fixed(_, _) => true
|
||||
case _ => false
|
||||
}
|
||||
}
|
||||
|
|
|
@ -94,8 +94,8 @@ object RandomDataGenerator {
|
|||
case BooleanType => Some(() => rand.nextBoolean())
|
||||
case DateType => Some(() => new java.sql.Date(rand.nextInt()))
|
||||
case TimestampType => Some(() => new java.sql.Timestamp(rand.nextLong()))
|
||||
case DecimalType.Unlimited => Some(
|
||||
() => BigDecimal.apply(rand.nextLong, rand.nextInt, MathContext.UNLIMITED))
|
||||
case DecimalType.Fixed(precision, scale) => Some(
|
||||
() => BigDecimal.apply(rand.nextLong, rand.nextInt, new MathContext(precision)))
|
||||
case DoubleType => randomNumeric[Double](
|
||||
rand, r => longBitsToDouble(r.nextLong()), Seq(Double.MinValue, Double.MinPositiveValue,
|
||||
Double.MaxValue, Double.PositiveInfinity, Double.NegativeInfinity, Double.NaN, 0.0))
|
||||
|
|
|
@ -50,9 +50,7 @@ class RandomDataGeneratorSuite extends SparkFunSuite {
|
|||
for (
|
||||
dataType <- DataTypeTestUtils.atomicTypes;
|
||||
nullable <- Seq(true, false)
|
||||
if !dataType.isInstanceOf[DecimalType] ||
|
||||
dataType.asInstanceOf[DecimalType].precisionInfo.isEmpty
|
||||
) {
|
||||
if !dataType.isInstanceOf[DecimalType]) {
|
||||
test(s"$dataType (nullable=$nullable)") {
|
||||
testRandomDataGeneration(dataType)
|
||||
}
|
||||
|
|
|
@ -102,7 +102,7 @@ class ScalaReflectionSuite extends SparkFunSuite {
|
|||
StructField("byteField", ByteType, nullable = true),
|
||||
StructField("booleanField", BooleanType, nullable = true),
|
||||
StructField("stringField", StringType, nullable = true),
|
||||
StructField("decimalField", DecimalType.Unlimited, nullable = true),
|
||||
StructField("decimalField", DecimalType.SYSTEM_DEFAULT, nullable = true),
|
||||
StructField("dateField", DateType, nullable = true),
|
||||
StructField("timestampField", TimestampType, nullable = true),
|
||||
StructField("binaryField", BinaryType, nullable = true))),
|
||||
|
@ -216,7 +216,7 @@ class ScalaReflectionSuite extends SparkFunSuite {
|
|||
assert(DoubleType === typeOfObject(1.7976931348623157E308))
|
||||
|
||||
// DecimalType
|
||||
assert(DecimalType.Unlimited ===
|
||||
assert(DecimalType.SYSTEM_DEFAULT ===
|
||||
typeOfObject(new java.math.BigDecimal("1.7976931348623157E318")))
|
||||
|
||||
// DateType
|
||||
|
@ -229,19 +229,19 @@ class ScalaReflectionSuite extends SparkFunSuite {
|
|||
assert(NullType === typeOfObject(null))
|
||||
|
||||
def typeOfObject1: PartialFunction[Any, DataType] = typeOfObject orElse {
|
||||
case value: java.math.BigInteger => DecimalType.Unlimited
|
||||
case value: java.math.BigDecimal => DecimalType.Unlimited
|
||||
case value: java.math.BigInteger => DecimalType.SYSTEM_DEFAULT
|
||||
case value: java.math.BigDecimal => DecimalType.SYSTEM_DEFAULT
|
||||
case _ => StringType
|
||||
}
|
||||
|
||||
assert(DecimalType.Unlimited === typeOfObject1(
|
||||
assert(DecimalType.SYSTEM_DEFAULT === typeOfObject1(
|
||||
new BigInteger("92233720368547758070")))
|
||||
assert(DecimalType.Unlimited === typeOfObject1(
|
||||
assert(DecimalType.SYSTEM_DEFAULT === typeOfObject1(
|
||||
new java.math.BigDecimal("1.7976931348623157E318")))
|
||||
assert(StringType === typeOfObject1(BigInt("92233720368547758070")))
|
||||
|
||||
def typeOfObject2: PartialFunction[Any, DataType] = typeOfObject orElse {
|
||||
case value: java.math.BigInteger => DecimalType.Unlimited
|
||||
case value: java.math.BigInteger => DecimalType.SYSTEM_DEFAULT
|
||||
}
|
||||
|
||||
intercept[MatchError](typeOfObject2(BigInt("92233720368547758070")))
|
||||
|
|
|
@ -55,7 +55,7 @@ object AnalysisSuite {
|
|||
AttributeReference("a", StringType)(),
|
||||
AttributeReference("b", StringType)(),
|
||||
AttributeReference("c", DoubleType)(),
|
||||
AttributeReference("d", DecimalType.Unlimited)(),
|
||||
AttributeReference("d", DecimalType.SYSTEM_DEFAULT)(),
|
||||
AttributeReference("e", ShortType)())
|
||||
|
||||
val nestedRelation = LocalRelation(
|
||||
|
@ -158,7 +158,7 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
|
|||
AttributeReference("a", StringType)(),
|
||||
AttributeReference("b", StringType)(),
|
||||
AttributeReference("c", DoubleType)(),
|
||||
AttributeReference("d", DecimalType.Unlimited)(),
|
||||
AttributeReference("d", DecimalType(10, 2))(),
|
||||
AttributeReference("e", ShortType)())
|
||||
|
||||
val plan = caseInsensitiveAnalyzer.execute(
|
||||
|
@ -173,7 +173,7 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
|
|||
assert(pl(0).dataType == DoubleType)
|
||||
assert(pl(1).dataType == DoubleType)
|
||||
assert(pl(2).dataType == DoubleType)
|
||||
assert(pl(3).dataType == DecimalType.Unlimited)
|
||||
assert(pl(3).dataType == DoubleType) // StringType will be promoted into Double
|
||||
assert(pl(4).dataType == DoubleType)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -34,7 +34,7 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter {
|
|||
AttributeReference("i", IntegerType)(),
|
||||
AttributeReference("d1", DecimalType(2, 1))(),
|
||||
AttributeReference("d2", DecimalType(5, 2))(),
|
||||
AttributeReference("u", DecimalType.Unlimited)(),
|
||||
AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(),
|
||||
AttributeReference("f", FloatType)(),
|
||||
AttributeReference("b", DoubleType)()
|
||||
)
|
||||
|
@ -92,11 +92,11 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter {
|
|||
}
|
||||
|
||||
test("Comparison operations") {
|
||||
checkComparison(EqualTo(i, d1), DecimalType(10, 1))
|
||||
checkComparison(EqualTo(i, d1), DecimalType(11, 1))
|
||||
checkComparison(EqualNullSafe(d2, d1), DecimalType(5, 2))
|
||||
checkComparison(LessThan(i, d1), DecimalType(10, 1))
|
||||
checkComparison(LessThan(i, d1), DecimalType(11, 1))
|
||||
checkComparison(LessThanOrEqual(d1, d2), DecimalType(5, 2))
|
||||
checkComparison(GreaterThan(d2, u), DecimalType.Unlimited)
|
||||
checkComparison(GreaterThan(d2, u), DecimalType.SYSTEM_DEFAULT)
|
||||
checkComparison(GreaterThanOrEqual(d1, f), DoubleType)
|
||||
checkComparison(GreaterThan(d2, d2), DecimalType(5, 2))
|
||||
}
|
||||
|
@ -106,12 +106,12 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter {
|
|||
checkUnion(i, d2, DecimalType(12, 2))
|
||||
checkUnion(d1, d2, DecimalType(5, 2))
|
||||
checkUnion(d2, d1, DecimalType(5, 2))
|
||||
checkUnion(d1, f, DecimalType(8, 7))
|
||||
checkUnion(f, d2, DecimalType(10, 7))
|
||||
checkUnion(d1, b, DecimalType(16, 15))
|
||||
checkUnion(b, d2, DecimalType(18, 15))
|
||||
checkUnion(d1, u, DecimalType.Unlimited)
|
||||
checkUnion(u, d2, DecimalType.Unlimited)
|
||||
checkUnion(d1, f, DoubleType)
|
||||
checkUnion(f, d2, DoubleType)
|
||||
checkUnion(d1, b, DoubleType)
|
||||
checkUnion(b, d2, DoubleType)
|
||||
checkUnion(d1, u, DecimalType.SYSTEM_DEFAULT)
|
||||
checkUnion(u, d2, DecimalType.SYSTEM_DEFAULT)
|
||||
}
|
||||
|
||||
test("bringing in primitive types") {
|
||||
|
@ -125,13 +125,33 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter {
|
|||
checkType(Add(d1, Cast(i, DoubleType)), DoubleType)
|
||||
}
|
||||
|
||||
test("unlimited decimals make everything else cast up") {
|
||||
for (expr <- Seq(d1, d2, i, f, u)) {
|
||||
checkType(Add(expr, u), DecimalType.Unlimited)
|
||||
checkType(Subtract(expr, u), DecimalType.Unlimited)
|
||||
checkType(Multiply(expr, u), DecimalType.Unlimited)
|
||||
checkType(Divide(expr, u), DecimalType.Unlimited)
|
||||
checkType(Remainder(expr, u), DecimalType.Unlimited)
|
||||
test("maximum decimals") {
|
||||
for (expr <- Seq(d1, d2, i, u)) {
|
||||
checkType(Add(expr, u), DecimalType.SYSTEM_DEFAULT)
|
||||
checkType(Subtract(expr, u), DecimalType.SYSTEM_DEFAULT)
|
||||
}
|
||||
|
||||
checkType(Multiply(d1, u), DecimalType(38, 19))
|
||||
checkType(Multiply(d2, u), DecimalType(38, 20))
|
||||
checkType(Multiply(i, u), DecimalType(38, 18))
|
||||
checkType(Multiply(u, u), DecimalType(38, 36))
|
||||
|
||||
checkType(Divide(u, d1), DecimalType(38, 21))
|
||||
checkType(Divide(u, d2), DecimalType(38, 24))
|
||||
checkType(Divide(u, i), DecimalType(38, 29))
|
||||
checkType(Divide(u, u), DecimalType(38, 38))
|
||||
|
||||
checkType(Remainder(d1, u), DecimalType(19, 18))
|
||||
checkType(Remainder(d2, u), DecimalType(21, 18))
|
||||
checkType(Remainder(i, u), DecimalType(28, 18))
|
||||
checkType(Remainder(u, u), DecimalType.SYSTEM_DEFAULT)
|
||||
|
||||
for (expr <- Seq(f, b)) {
|
||||
checkType(Add(expr, u), DoubleType)
|
||||
checkType(Subtract(expr, u), DoubleType)
|
||||
checkType(Multiply(expr, u), DoubleType)
|
||||
checkType(Divide(expr, u), DoubleType)
|
||||
checkType(Remainder(expr, u), DoubleType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -35,14 +35,14 @@ class HiveTypeCoercionSuite extends PlanTest {
|
|||
|
||||
shouldCast(NullType, NullType, NullType)
|
||||
shouldCast(NullType, IntegerType, IntegerType)
|
||||
shouldCast(NullType, DecimalType, DecimalType.Unlimited)
|
||||
shouldCast(NullType, DecimalType, DecimalType.SYSTEM_DEFAULT)
|
||||
|
||||
shouldCast(ByteType, IntegerType, IntegerType)
|
||||
shouldCast(IntegerType, IntegerType, IntegerType)
|
||||
shouldCast(IntegerType, LongType, LongType)
|
||||
shouldCast(IntegerType, DecimalType, DecimalType.Unlimited)
|
||||
shouldCast(IntegerType, DecimalType, DecimalType(10, 0))
|
||||
shouldCast(LongType, IntegerType, IntegerType)
|
||||
shouldCast(LongType, DecimalType, DecimalType.Unlimited)
|
||||
shouldCast(LongType, DecimalType, DecimalType(20, 0))
|
||||
|
||||
shouldCast(DateType, TimestampType, TimestampType)
|
||||
shouldCast(TimestampType, DateType, DateType)
|
||||
|
@ -71,8 +71,8 @@ class HiveTypeCoercionSuite extends PlanTest {
|
|||
shouldCast(IntegerType, TypeCollection(StringType, BinaryType), StringType)
|
||||
shouldCast(IntegerType, TypeCollection(BinaryType, StringType), StringType)
|
||||
|
||||
shouldCast(
|
||||
DecimalType.Unlimited, TypeCollection(IntegerType, DecimalType), DecimalType.Unlimited)
|
||||
shouldCast(DecimalType.SYSTEM_DEFAULT,
|
||||
TypeCollection(IntegerType, DecimalType), DecimalType.SYSTEM_DEFAULT)
|
||||
shouldCast(DecimalType(10, 2), TypeCollection(IntegerType, DecimalType), DecimalType(10, 2))
|
||||
shouldCast(DecimalType(10, 2), TypeCollection(DecimalType, IntegerType), DecimalType(10, 2))
|
||||
shouldCast(IntegerType, TypeCollection(DecimalType(10, 2), StringType), DecimalType(10, 2))
|
||||
|
@ -82,7 +82,7 @@ class HiveTypeCoercionSuite extends PlanTest {
|
|||
|
||||
// NumericType should not be changed when function accepts any of them.
|
||||
Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType,
|
||||
DecimalType.Unlimited, DecimalType(10, 2)).foreach { tpe =>
|
||||
DecimalType.SYSTEM_DEFAULT, DecimalType(10, 2)).foreach { tpe =>
|
||||
shouldCast(tpe, NumericType, tpe)
|
||||
}
|
||||
|
||||
|
@ -107,8 +107,8 @@ class HiveTypeCoercionSuite extends PlanTest {
|
|||
shouldNotCast(IntegerType, TimestampType)
|
||||
shouldNotCast(LongType, DateType)
|
||||
shouldNotCast(LongType, TimestampType)
|
||||
shouldNotCast(DecimalType.Unlimited, DateType)
|
||||
shouldNotCast(DecimalType.Unlimited, TimestampType)
|
||||
shouldNotCast(DecimalType.SYSTEM_DEFAULT, DateType)
|
||||
shouldNotCast(DecimalType.SYSTEM_DEFAULT, TimestampType)
|
||||
|
||||
shouldNotCast(IntegerType, TypeCollection(DateType, TimestampType))
|
||||
|
||||
|
@ -160,14 +160,6 @@ class HiveTypeCoercionSuite extends PlanTest {
|
|||
widenTest(LongType, FloatType, Some(FloatType))
|
||||
widenTest(LongType, DoubleType, Some(DoubleType))
|
||||
|
||||
// Casting up to unlimited-precision decimal
|
||||
widenTest(IntegerType, DecimalType.Unlimited, Some(DecimalType.Unlimited))
|
||||
widenTest(DoubleType, DecimalType.Unlimited, Some(DecimalType.Unlimited))
|
||||
widenTest(DecimalType(3, 2), DecimalType.Unlimited, Some(DecimalType.Unlimited))
|
||||
widenTest(DecimalType.Unlimited, IntegerType, Some(DecimalType.Unlimited))
|
||||
widenTest(DecimalType.Unlimited, DoubleType, Some(DecimalType.Unlimited))
|
||||
widenTest(DecimalType.Unlimited, DecimalType(3, 2), Some(DecimalType.Unlimited))
|
||||
|
||||
// No up-casting for fixed-precision decimal (this is handled by arithmetic rules)
|
||||
widenTest(DecimalType(2, 1), DecimalType(3, 2), None)
|
||||
widenTest(DecimalType(2, 1), DoubleType, None)
|
||||
|
@ -242,9 +234,9 @@ class HiveTypeCoercionSuite extends PlanTest {
|
|||
:: Literal(1)
|
||||
:: Literal(new java.math.BigDecimal("1000000000000000000000"))
|
||||
:: Nil),
|
||||
Coalesce(Cast(Literal(1L), DecimalType())
|
||||
:: Cast(Literal(1), DecimalType())
|
||||
:: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType())
|
||||
Coalesce(Cast(Literal(1L), DecimalType(22, 0))
|
||||
:: Cast(Literal(1), DecimalType(22, 0))
|
||||
:: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0))
|
||||
:: Nil))
|
||||
}
|
||||
|
||||
|
@ -323,7 +315,7 @@ class HiveTypeCoercionSuite extends PlanTest {
|
|||
|
||||
val left = LocalRelation(
|
||||
AttributeReference("i", IntegerType)(),
|
||||
AttributeReference("u", DecimalType.Unlimited)(),
|
||||
AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(),
|
||||
AttributeReference("b", ByteType)(),
|
||||
AttributeReference("d", DoubleType)())
|
||||
val right = LocalRelation(
|
||||
|
@ -333,7 +325,7 @@ class HiveTypeCoercionSuite extends PlanTest {
|
|||
AttributeReference("l", LongType)())
|
||||
|
||||
val wt = HiveTypeCoercion.WidenTypes
|
||||
val expectedTypes = Seq(StringType, DecimalType.Unlimited, FloatType, DoubleType)
|
||||
val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType)
|
||||
|
||||
val r1 = wt(Union(left, right)).asInstanceOf[Union]
|
||||
val r2 = wt(Except(left, right)).asInstanceOf[Except]
|
||||
|
@ -353,13 +345,13 @@ class HiveTypeCoercionSuite extends PlanTest {
|
|||
}
|
||||
}
|
||||
|
||||
val dp = HiveTypeCoercion.DecimalPrecision
|
||||
val dp = HiveTypeCoercion.WidenTypes
|
||||
|
||||
val left1 = LocalRelation(
|
||||
AttributeReference("l", DecimalType(10, 8))())
|
||||
val right1 = LocalRelation(
|
||||
AttributeReference("r", DecimalType(5, 5))())
|
||||
val expectedType1 = Seq(DecimalType(math.max(8, 5) + math.max(10 - 8, 5 - 5), math.max(8, 5)))
|
||||
val expectedType1 = Seq(DecimalType(10, 8))
|
||||
|
||||
val r1 = dp(Union(left1, right1)).asInstanceOf[Union]
|
||||
val r2 = dp(Except(left1, right1)).asInstanceOf[Except]
|
||||
|
@ -372,12 +364,11 @@ class HiveTypeCoercionSuite extends PlanTest {
|
|||
checkOutput(r3.left, expectedType1)
|
||||
checkOutput(r3.right, expectedType1)
|
||||
|
||||
val plan1 = LocalRelation(
|
||||
AttributeReference("l", DecimalType(10, 10))())
|
||||
val plan1 = LocalRelation(AttributeReference("l", DecimalType(10, 5))())
|
||||
|
||||
val rightTypes = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType)
|
||||
val expectedTypes = Seq(DecimalType(3, 0), DecimalType(5, 0), DecimalType(10, 0),
|
||||
DecimalType(20, 0), DecimalType(7, 7), DecimalType(15, 15))
|
||||
val expectedTypes = Seq(DecimalType(10, 5), DecimalType(10, 5), DecimalType(15, 5),
|
||||
DecimalType(25, 5), DoubleType, DoubleType)
|
||||
|
||||
rightTypes.zip(expectedTypes).map { case (rType, expectedType) =>
|
||||
val plan2 = LocalRelation(
|
||||
|
|
|
@ -185,7 +185,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
checkCast(1, 1.0)
|
||||
checkCast(123, "123")
|
||||
|
||||
checkEvaluation(cast(123, DecimalType.Unlimited), Decimal(123))
|
||||
checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123))
|
||||
checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123))
|
||||
checkEvaluation(cast(123, DecimalType(3, 1)), null)
|
||||
checkEvaluation(cast(123, DecimalType(2, 0)), null)
|
||||
|
@ -203,7 +203,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
checkCast(1L, 1.0)
|
||||
checkCast(123L, "123")
|
||||
|
||||
checkEvaluation(cast(123L, DecimalType.Unlimited), Decimal(123))
|
||||
checkEvaluation(cast(123L, DecimalType.USER_DEFAULT), Decimal(123))
|
||||
checkEvaluation(cast(123L, DecimalType(3, 0)), Decimal(123))
|
||||
checkEvaluation(cast(123L, DecimalType(3, 1)), Decimal(123.0))
|
||||
|
||||
|
@ -225,7 +225,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
checkEvaluation(cast(cast(1000, TimestampType), LongType), 1.toLong)
|
||||
checkEvaluation(cast(cast(-1200, TimestampType), LongType), -2.toLong)
|
||||
|
||||
checkEvaluation(cast(123, DecimalType.Unlimited), Decimal(123))
|
||||
checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123))
|
||||
checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123))
|
||||
checkEvaluation(cast(123, DecimalType(3, 1)), null)
|
||||
checkEvaluation(cast(123, DecimalType(2, 0)), null)
|
||||
|
@ -267,7 +267,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
assert(cast("abcdef", IntegerType).nullable === true)
|
||||
assert(cast("abcdef", ShortType).nullable === true)
|
||||
assert(cast("abcdef", ByteType).nullable === true)
|
||||
assert(cast("abcdef", DecimalType.Unlimited).nullable === true)
|
||||
assert(cast("abcdef", DecimalType.USER_DEFAULT).nullable === true)
|
||||
assert(cast("abcdef", DecimalType(4, 2)).nullable === true)
|
||||
assert(cast("abcdef", DoubleType).nullable === true)
|
||||
assert(cast("abcdef", FloatType).nullable === true)
|
||||
|
@ -291,9 +291,9 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
c.getTimeInMillis * 1000)
|
||||
|
||||
checkEvaluation(cast("abdef", StringType), "abdef")
|
||||
checkEvaluation(cast("abdef", DecimalType.Unlimited), null)
|
||||
checkEvaluation(cast("abdef", DecimalType.USER_DEFAULT), null)
|
||||
checkEvaluation(cast("abdef", TimestampType), null)
|
||||
checkEvaluation(cast("12.65", DecimalType.Unlimited), Decimal(12.65))
|
||||
checkEvaluation(cast("12.65", DecimalType.SYSTEM_DEFAULT), Decimal(12.65))
|
||||
|
||||
checkEvaluation(cast(cast(sd, DateType), StringType), sd)
|
||||
checkEvaluation(cast(cast(d, StringType), DateType), 0)
|
||||
|
@ -311,20 +311,20 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
5.toLong)
|
||||
checkEvaluation(
|
||||
cast(cast(cast(cast(cast(cast("5", ByteType), TimestampType),
|
||||
DecimalType.Unlimited), LongType), StringType), ShortType),
|
||||
DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType),
|
||||
0.toShort)
|
||||
checkEvaluation(
|
||||
cast(cast(cast(cast(cast(cast("5", TimestampType), ByteType),
|
||||
DecimalType.Unlimited), LongType), StringType), ShortType),
|
||||
DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType),
|
||||
null)
|
||||
checkEvaluation(cast(cast(cast(cast(cast(cast("5", DecimalType.Unlimited),
|
||||
checkEvaluation(cast(cast(cast(cast(cast(cast("5", DecimalType.SYSTEM_DEFAULT),
|
||||
ByteType), TimestampType), LongType), StringType), ShortType),
|
||||
0.toShort)
|
||||
|
||||
checkEvaluation(cast("23", DoubleType), 23d)
|
||||
checkEvaluation(cast("23", IntegerType), 23)
|
||||
checkEvaluation(cast("23", FloatType), 23f)
|
||||
checkEvaluation(cast("23", DecimalType.Unlimited), Decimal(23))
|
||||
checkEvaluation(cast("23", DecimalType.USER_DEFAULT), Decimal(23))
|
||||
checkEvaluation(cast("23", ByteType), 23.toByte)
|
||||
checkEvaluation(cast("23", ShortType), 23.toShort)
|
||||
checkEvaluation(cast("2012-12-11", DoubleType), null)
|
||||
|
@ -338,7 +338,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
checkEvaluation(Add(Literal(23d), cast(true, DoubleType)), 24d)
|
||||
checkEvaluation(Add(Literal(23), cast(true, IntegerType)), 24)
|
||||
checkEvaluation(Add(Literal(23f), cast(true, FloatType)), 24f)
|
||||
checkEvaluation(Add(Literal(Decimal(23)), cast(true, DecimalType.Unlimited)), Decimal(24))
|
||||
checkEvaluation(Add(Literal(Decimal(23)), cast(true, DecimalType.USER_DEFAULT)), Decimal(24))
|
||||
checkEvaluation(Add(Literal(23.toByte), cast(true, ByteType)), 24.toByte)
|
||||
checkEvaluation(Add(Literal(23.toShort), cast(true, ShortType)), 24.toShort)
|
||||
}
|
||||
|
@ -362,10 +362,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
// - Values that would overflow the target precision should turn into null
|
||||
// - Because of this, casts to fixed-precision decimals should be nullable
|
||||
|
||||
assert(cast(123, DecimalType.Unlimited).nullable === false)
|
||||
assert(cast(10.03f, DecimalType.Unlimited).nullable === true)
|
||||
assert(cast(10.03, DecimalType.Unlimited).nullable === true)
|
||||
assert(cast(Decimal(10.03), DecimalType.Unlimited).nullable === false)
|
||||
assert(cast(123, DecimalType.USER_DEFAULT).nullable === true)
|
||||
assert(cast(10.03f, DecimalType.SYSTEM_DEFAULT).nullable === true)
|
||||
assert(cast(10.03, DecimalType.SYSTEM_DEFAULT).nullable === true)
|
||||
assert(cast(Decimal(10.03), DecimalType.SYSTEM_DEFAULT).nullable === true)
|
||||
|
||||
assert(cast(123, DecimalType(2, 1)).nullable === true)
|
||||
assert(cast(10.03f, DecimalType(2, 1)).nullable === true)
|
||||
|
@ -373,7 +373,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
assert(cast(Decimal(10.03), DecimalType(2, 1)).nullable === true)
|
||||
|
||||
|
||||
checkEvaluation(cast(10.03, DecimalType.Unlimited), Decimal(10.03))
|
||||
checkEvaluation(cast(10.03, DecimalType.SYSTEM_DEFAULT), Decimal(10.03))
|
||||
checkEvaluation(cast(10.03, DecimalType(4, 2)), Decimal(10.03))
|
||||
checkEvaluation(cast(10.03, DecimalType(3, 1)), Decimal(10.0))
|
||||
checkEvaluation(cast(10.03, DecimalType(2, 0)), Decimal(10))
|
||||
|
@ -383,7 +383,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
checkEvaluation(cast(Decimal(10.03), DecimalType(3, 1)), Decimal(10.0))
|
||||
checkEvaluation(cast(Decimal(10.03), DecimalType(3, 2)), null)
|
||||
|
||||
checkEvaluation(cast(10.05, DecimalType.Unlimited), Decimal(10.05))
|
||||
checkEvaluation(cast(10.05, DecimalType.SYSTEM_DEFAULT), Decimal(10.05))
|
||||
checkEvaluation(cast(10.05, DecimalType(4, 2)), Decimal(10.05))
|
||||
checkEvaluation(cast(10.05, DecimalType(3, 1)), Decimal(10.1))
|
||||
checkEvaluation(cast(10.05, DecimalType(2, 0)), Decimal(10))
|
||||
|
@ -409,10 +409,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
checkEvaluation(cast(Decimal(-9.95), DecimalType(3, 1)), Decimal(-10.0))
|
||||
checkEvaluation(cast(Decimal(-9.95), DecimalType(1, 0)), null)
|
||||
|
||||
checkEvaluation(cast(Double.NaN, DecimalType.Unlimited), null)
|
||||
checkEvaluation(cast(1.0 / 0.0, DecimalType.Unlimited), null)
|
||||
checkEvaluation(cast(Float.NaN, DecimalType.Unlimited), null)
|
||||
checkEvaluation(cast(1.0f / 0.0f, DecimalType.Unlimited), null)
|
||||
checkEvaluation(cast(Double.NaN, DecimalType.SYSTEM_DEFAULT), null)
|
||||
checkEvaluation(cast(1.0 / 0.0, DecimalType.SYSTEM_DEFAULT), null)
|
||||
checkEvaluation(cast(Float.NaN, DecimalType.SYSTEM_DEFAULT), null)
|
||||
checkEvaluation(cast(1.0f / 0.0f, DecimalType.SYSTEM_DEFAULT), null)
|
||||
|
||||
checkEvaluation(cast(Double.NaN, DecimalType(2, 1)), null)
|
||||
checkEvaluation(cast(1.0 / 0.0, DecimalType(2, 1)), null)
|
||||
|
@ -427,7 +427,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
checkEvaluation(cast(d, LongType), null)
|
||||
checkEvaluation(cast(d, FloatType), null)
|
||||
checkEvaluation(cast(d, DoubleType), null)
|
||||
checkEvaluation(cast(d, DecimalType.Unlimited), null)
|
||||
checkEvaluation(cast(d, DecimalType.SYSTEM_DEFAULT), null)
|
||||
checkEvaluation(cast(d, DecimalType(10, 2)), null)
|
||||
checkEvaluation(cast(d, StringType), "1970-01-01")
|
||||
checkEvaluation(cast(cast(d, TimestampType), StringType), "1970-01-01 00:00:00")
|
||||
|
@ -454,7 +454,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
cast(cast(millis.toDouble / 1000, TimestampType), DoubleType),
|
||||
millis.toDouble / 1000)
|
||||
checkEvaluation(
|
||||
cast(cast(Decimal(1), TimestampType), DecimalType.Unlimited),
|
||||
cast(cast(Decimal(1), TimestampType), DecimalType.SYSTEM_DEFAULT),
|
||||
Decimal(1))
|
||||
|
||||
// A test for higher precision than millis
|
||||
|
|
|
@ -60,7 +60,7 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
|
|||
|
||||
testIf(_.toFloat, FloatType)
|
||||
testIf(_.toDouble, DoubleType)
|
||||
testIf(Decimal(_), DecimalType.Unlimited)
|
||||
testIf(Decimal(_), DecimalType.USER_DEFAULT)
|
||||
|
||||
testIf(identity, DateType)
|
||||
testIf(_.toLong, TimestampType)
|
||||
|
|
|
@ -33,7 +33,7 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
checkEvaluation(Literal.create(null, LongType), null)
|
||||
checkEvaluation(Literal.create(null, StringType), null)
|
||||
checkEvaluation(Literal.create(null, BinaryType), null)
|
||||
checkEvaluation(Literal.create(null, DecimalType()), null)
|
||||
checkEvaluation(Literal.create(null, DecimalType.USER_DEFAULT), null)
|
||||
checkEvaluation(Literal.create(null, ArrayType(ByteType, true)), null)
|
||||
checkEvaluation(Literal.create(null, MapType(StringType, IntegerType)), null)
|
||||
checkEvaluation(Literal.create(null, StructType(Seq.empty)), null)
|
||||
|
|
|
@ -30,7 +30,7 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
testFunc(1L, LongType)
|
||||
testFunc(1.0F, FloatType)
|
||||
testFunc(1.0, DoubleType)
|
||||
testFunc(Decimal(1.5), DecimalType.Unlimited)
|
||||
testFunc(Decimal(1.5), DecimalType(2, 1))
|
||||
testFunc(new java.sql.Date(10), DateType)
|
||||
testFunc(new java.sql.Timestamp(10), TimestampType)
|
||||
testFunc("abcd", StringType)
|
||||
|
|
|
@ -121,6 +121,8 @@ class UnsafeFixedWidthAggregationMapSuite
|
|||
}.toSet
|
||||
seenKeys.size should be (groupKeys.size)
|
||||
seenKeys should be (groupKeys)
|
||||
|
||||
map.free()
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -145,7 +145,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
|
|||
DoubleType,
|
||||
StringType,
|
||||
BinaryType
|
||||
// DecimalType.Unlimited,
|
||||
// DecimalType.Default,
|
||||
// ArrayType(IntegerType)
|
||||
)
|
||||
val converter = new UnsafeRowConverter(fieldTypes)
|
||||
|
|
|
@ -44,7 +44,7 @@ class DataTypeParserSuite extends SparkFunSuite {
|
|||
checkDataType("float", FloatType)
|
||||
checkDataType("dOUBle", DoubleType)
|
||||
checkDataType("decimal(10, 5)", DecimalType(10, 5))
|
||||
checkDataType("decimal", DecimalType.Unlimited)
|
||||
checkDataType("decimal", DecimalType.USER_DEFAULT)
|
||||
checkDataType("DATE", DateType)
|
||||
checkDataType("timestamp", TimestampType)
|
||||
checkDataType("string", StringType)
|
||||
|
@ -87,7 +87,7 @@ class DataTypeParserSuite extends SparkFunSuite {
|
|||
StructType(
|
||||
StructField("struct",
|
||||
StructType(
|
||||
StructField("deciMal", DecimalType.Unlimited, true) ::
|
||||
StructField("deciMal", DecimalType.USER_DEFAULT, true) ::
|
||||
StructField("anotherDecimal", DecimalType(5, 2), true) :: Nil), true) ::
|
||||
StructField("MAP", MapType(TimestampType, StringType), true) ::
|
||||
StructField("arrAy", ArrayType(DoubleType, true), true) :: Nil)
|
||||
|
|
|
@ -185,7 +185,7 @@ class DataTypeSuite extends SparkFunSuite {
|
|||
checkDataTypeJsonRepr(FloatType)
|
||||
checkDataTypeJsonRepr(DoubleType)
|
||||
checkDataTypeJsonRepr(DecimalType(10, 5))
|
||||
checkDataTypeJsonRepr(DecimalType.Unlimited)
|
||||
checkDataTypeJsonRepr(DecimalType.SYSTEM_DEFAULT)
|
||||
checkDataTypeJsonRepr(DateType)
|
||||
checkDataTypeJsonRepr(TimestampType)
|
||||
checkDataTypeJsonRepr(StringType)
|
||||
|
@ -219,7 +219,7 @@ class DataTypeSuite extends SparkFunSuite {
|
|||
checkDefaultSize(FloatType, 4)
|
||||
checkDefaultSize(DoubleType, 8)
|
||||
checkDefaultSize(DecimalType(10, 5), 4096)
|
||||
checkDefaultSize(DecimalType.Unlimited, 4096)
|
||||
checkDefaultSize(DecimalType.SYSTEM_DEFAULT, 4096)
|
||||
checkDefaultSize(DateType, 4)
|
||||
checkDefaultSize(TimestampType, 8)
|
||||
checkDefaultSize(StringType, 4096)
|
||||
|
|
|
@ -34,7 +34,7 @@ object DataTypeTestUtils {
|
|||
* decimal types.
|
||||
*/
|
||||
val fractionalTypes: Set[FractionalType] = Set(
|
||||
DecimalType(precisionInfo = None),
|
||||
DecimalType.SYSTEM_DEFAULT,
|
||||
DecimalType(2, 1),
|
||||
DoubleType,
|
||||
FloatType
|
||||
|
|
|
@ -996,7 +996,7 @@ class ColumnName(name: String) extends Column(name) {
|
|||
* Creates a new [[StructField]] of type decimal.
|
||||
* @since 1.3.0
|
||||
*/
|
||||
def decimal: StructField = StructField(name, DecimalType.Unlimited)
|
||||
def decimal: StructField = StructField(name, DecimalType.USER_DEFAULT)
|
||||
|
||||
/**
|
||||
* Creates a new [[StructField]] of type decimal.
|
||||
|
|
|
@ -375,7 +375,7 @@ private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 8) {
|
|||
|
||||
private[sql] case class FIXED_DECIMAL(precision: Int, scale: Int)
|
||||
extends NativeColumnType(
|
||||
DecimalType(Some(PrecisionInfo(precision, scale))),
|
||||
DecimalType(precision, scale),
|
||||
10,
|
||||
FIXED_DECIMAL.defaultSize) {
|
||||
|
||||
|
|
|
@ -21,9 +21,9 @@ import org.apache.spark.TaskContext
|
|||
import org.apache.spark.annotation.DeveloperApi
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.trees._
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.plans.physical._
|
||||
import org.apache.spark.sql.catalyst.trees._
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
case class AggregateEvaluation(
|
||||
|
@ -92,8 +92,8 @@ case class GeneratedAggregate(
|
|||
case s @ Sum(expr) =>
|
||||
val calcType =
|
||||
expr.dataType match {
|
||||
case DecimalType.Fixed(_, _) =>
|
||||
DecimalType.Unlimited
|
||||
case DecimalType.Fixed(p, s) =>
|
||||
DecimalType.bounded(p + 10, s)
|
||||
case _ =>
|
||||
expr.dataType
|
||||
}
|
||||
|
@ -121,8 +121,8 @@ case class GeneratedAggregate(
|
|||
case cs @ CombineSum(expr) =>
|
||||
val calcType =
|
||||
expr.dataType match {
|
||||
case DecimalType.Fixed(_, _) =>
|
||||
DecimalType.Unlimited
|
||||
case DecimalType.Fixed(p, s) =>
|
||||
DecimalType.bounded(p + 10, s)
|
||||
case _ =>
|
||||
expr.dataType
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@ import scala.util.Try
|
|||
|
||||
import org.apache.hadoop.fs.Path
|
||||
import org.apache.hadoop.util.Shell
|
||||
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
|
||||
import org.apache.spark.sql.types._
|
||||
|
@ -236,7 +237,7 @@ private[sql] object PartitioningUtils {
|
|||
|
||||
/**
|
||||
* Converts a string to a [[Literal]] with automatic type inference. Currently only supports
|
||||
* [[IntegerType]], [[LongType]], [[DoubleType]], [[DecimalType.Unlimited]], and
|
||||
* [[IntegerType]], [[LongType]], [[DoubleType]], [[DecimalType.SYSTEM_DEFAULT]], and
|
||||
* [[StringType]].
|
||||
*/
|
||||
private[sql] def inferPartitionColumnValue(
|
||||
|
@ -249,7 +250,7 @@ private[sql] object PartitioningUtils {
|
|||
.orElse(Try(Literal.create(JLong.parseLong(raw), LongType)))
|
||||
// Then falls back to fractional types
|
||||
.orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType)))
|
||||
.orElse(Try(Literal.create(new JBigDecimal(raw), DecimalType.Unlimited)))
|
||||
.orElse(Try(Literal(new JBigDecimal(raw))))
|
||||
// Then falls back to string
|
||||
.getOrElse {
|
||||
if (raw == defaultPartitionName) {
|
||||
|
@ -268,7 +269,7 @@ private[sql] object PartitioningUtils {
|
|||
}
|
||||
|
||||
private val upCastingOrder: Seq[DataType] =
|
||||
Seq(NullType, IntegerType, LongType, FloatType, DoubleType, DecimalType.Unlimited, StringType)
|
||||
Seq(NullType, IntegerType, LongType, FloatType, DoubleType, StringType)
|
||||
|
||||
/**
|
||||
* Given a collection of [[Literal]]s, resolves possible type conflicts by up-casting "lower"
|
||||
|
|
|
@ -66,8 +66,8 @@ private[sql] object JDBCRDD extends Logging {
|
|||
case java.sql.Types.DATALINK => null
|
||||
case java.sql.Types.DATE => DateType
|
||||
case java.sql.Types.DECIMAL
|
||||
if precision != 0 || scale != 0 => DecimalType(precision, scale)
|
||||
case java.sql.Types.DECIMAL => DecimalType.Unlimited
|
||||
if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale)
|
||||
case java.sql.Types.DECIMAL => DecimalType.SYSTEM_DEFAULT
|
||||
case java.sql.Types.DISTINCT => null
|
||||
case java.sql.Types.DOUBLE => DoubleType
|
||||
case java.sql.Types.FLOAT => FloatType
|
||||
|
@ -80,8 +80,8 @@ private[sql] object JDBCRDD extends Logging {
|
|||
case java.sql.Types.NCLOB => StringType
|
||||
case java.sql.Types.NULL => null
|
||||
case java.sql.Types.NUMERIC
|
||||
if precision != 0 || scale != 0 => DecimalType(precision, scale)
|
||||
case java.sql.Types.NUMERIC => DecimalType.Unlimited
|
||||
if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale)
|
||||
case java.sql.Types.NUMERIC => DecimalType.SYSTEM_DEFAULT
|
||||
case java.sql.Types.NVARCHAR => StringType
|
||||
case java.sql.Types.OTHER => null
|
||||
case java.sql.Types.REAL => DoubleType
|
||||
|
@ -314,7 +314,7 @@ private[sql] class JDBCRDD(
|
|||
abstract class JDBCConversion
|
||||
case object BooleanConversion extends JDBCConversion
|
||||
case object DateConversion extends JDBCConversion
|
||||
case class DecimalConversion(precisionInfo: Option[(Int, Int)]) extends JDBCConversion
|
||||
case class DecimalConversion(precision: Int, scale: Int) extends JDBCConversion
|
||||
case object DoubleConversion extends JDBCConversion
|
||||
case object FloatConversion extends JDBCConversion
|
||||
case object IntegerConversion extends JDBCConversion
|
||||
|
@ -331,8 +331,7 @@ private[sql] class JDBCRDD(
|
|||
schema.fields.map(sf => sf.dataType match {
|
||||
case BooleanType => BooleanConversion
|
||||
case DateType => DateConversion
|
||||
case DecimalType.Unlimited => DecimalConversion(None)
|
||||
case DecimalType.Fixed(d) => DecimalConversion(Some(d))
|
||||
case DecimalType.Fixed(p, s) => DecimalConversion(p, s)
|
||||
case DoubleType => DoubleConversion
|
||||
case FloatType => FloatConversion
|
||||
case IntegerType => IntegerConversion
|
||||
|
@ -399,20 +398,13 @@ private[sql] class JDBCRDD(
|
|||
// DecimalType(12, 2). Thus, after saving the dataframe into parquet file and then
|
||||
// retrieve it, you will get wrong result 199.99.
|
||||
// So it is needed to set precision and scale for Decimal based on JDBC metadata.
|
||||
case DecimalConversion(Some((p, s))) =>
|
||||
case DecimalConversion(p, s) =>
|
||||
val decimalVal = rs.getBigDecimal(pos)
|
||||
if (decimalVal == null) {
|
||||
mutableRow.update(i, null)
|
||||
} else {
|
||||
mutableRow.update(i, Decimal(decimalVal, p, s))
|
||||
}
|
||||
case DecimalConversion(None) =>
|
||||
val decimalVal = rs.getBigDecimal(pos)
|
||||
if (decimalVal == null) {
|
||||
mutableRow.update(i, null)
|
||||
} else {
|
||||
mutableRow.update(i, Decimal(decimalVal))
|
||||
}
|
||||
case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos))
|
||||
case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos))
|
||||
case IntegerConversion => mutableRow.setInt(i, rs.getInt(pos))
|
||||
|
|
|
@ -89,8 +89,7 @@ package object jdbc {
|
|||
case BinaryType => stmt.setBytes(i + 1, row.getAs[Array[Byte]](i))
|
||||
case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i))
|
||||
case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i))
|
||||
case DecimalType.Unlimited => stmt.setBigDecimal(i + 1,
|
||||
row.getAs[java.math.BigDecimal](i))
|
||||
case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i))
|
||||
case _ => throw new IllegalArgumentException(
|
||||
s"Can't translate non-null value for field $i")
|
||||
}
|
||||
|
@ -145,7 +144,7 @@ package object jdbc {
|
|||
case BinaryType => "BLOB"
|
||||
case TimestampType => "TIMESTAMP"
|
||||
case DateType => "DATE"
|
||||
case DecimalType.Unlimited => "DECIMAL(40,20)"
|
||||
case t: DecimalType => s"DECIMAL(${t.precision}},${t.scale}})"
|
||||
case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC")
|
||||
})
|
||||
val nullable = if (field.nullable) "" else "NOT NULL"
|
||||
|
@ -177,7 +176,7 @@ package object jdbc {
|
|||
case BinaryType => java.sql.Types.BLOB
|
||||
case TimestampType => java.sql.Types.TIMESTAMP
|
||||
case DateType => java.sql.Types.DATE
|
||||
case DecimalType.Unlimited => java.sql.Types.DECIMAL
|
||||
case t: DecimalType => java.sql.Types.DECIMAL
|
||||
case _ => throw new IllegalArgumentException(
|
||||
s"Can't translate null value for field $field")
|
||||
})
|
||||
|
|
|
@ -113,7 +113,7 @@ private[sql] object InferSchema {
|
|||
case INT | LONG => LongType
|
||||
// Since we do not have a data type backed by BigInteger,
|
||||
// when we see a Java BigInteger, we use DecimalType.
|
||||
case BIG_INTEGER | BIG_DECIMAL => DecimalType.Unlimited
|
||||
case BIG_INTEGER | BIG_DECIMAL => DecimalType.SYSTEM_DEFAULT
|
||||
case FLOAT | DOUBLE => DoubleType
|
||||
}
|
||||
|
||||
|
@ -168,8 +168,13 @@ private[sql] object InferSchema {
|
|||
HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2).getOrElse {
|
||||
// t1 or t2 is a StructType, ArrayType, or an unexpected type.
|
||||
(t1, t2) match {
|
||||
case (other: DataType, NullType) => other
|
||||
case (NullType, other: DataType) => other
|
||||
// Double support larger range than fixed decimal, DecimalType.Maximum should be enough
|
||||
// in most case, also have better precision.
|
||||
case (DoubleType, t: DecimalType) =>
|
||||
if (t == DecimalType.SYSTEM_DEFAULT) t else DoubleType
|
||||
case (t: DecimalType, DoubleType) =>
|
||||
if (t == DecimalType.SYSTEM_DEFAULT) t else DoubleType
|
||||
|
||||
case (StructType(fields1), StructType(fields2)) =>
|
||||
val newFields = (fields1 ++ fields2).groupBy(field => field.name).map {
|
||||
case (name, fieldTypes) =>
|
||||
|
|
|
@ -439,10 +439,6 @@ private[parquet] class CatalystSchemaConverter(
|
|||
.length(minBytesForPrecision(precision))
|
||||
.named(field.name)
|
||||
|
||||
case dec @ DecimalType.Unlimited if followParquetFormatSpec =>
|
||||
throw new AnalysisException(
|
||||
s"Data type $dec is not supported. Decimal precision and scale must be specified.")
|
||||
|
||||
// ===================================================
|
||||
// ArrayType and MapType (for Spark versions <= 1.4.x)
|
||||
// ===================================================
|
||||
|
|
|
@ -261,10 +261,10 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
|
|||
case BinaryType => writer.addBinary(
|
||||
Binary.fromByteArray(value.asInstanceOf[Array[Byte]]))
|
||||
case d: DecimalType =>
|
||||
if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) {
|
||||
if (d.precision > 18) {
|
||||
sys.error(s"Unsupported datatype $d, cannot write to consumer")
|
||||
}
|
||||
writeDecimal(value.asInstanceOf[Decimal], d.precisionInfo.get.precision)
|
||||
writeDecimal(value.asInstanceOf[Decimal], d.precision)
|
||||
case _ => sys.error(s"Do not know how to writer $schema to consumer")
|
||||
}
|
||||
}
|
||||
|
@ -415,10 +415,10 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport {
|
|||
case BinaryType => writer.addBinary(
|
||||
Binary.fromByteArray(record(index).asInstanceOf[Array[Byte]]))
|
||||
case d: DecimalType =>
|
||||
if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) {
|
||||
if (d.precision > 18) {
|
||||
sys.error(s"Unsupported datatype $d, cannot write to consumer")
|
||||
}
|
||||
writeDecimal(record(index).asInstanceOf[Decimal], d.precisionInfo.get.precision)
|
||||
writeDecimal(record(index).asInstanceOf[Decimal], d.precision)
|
||||
case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,7 +22,6 @@ import java.util.ArrayList;
|
|||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import org.apache.spark.sql.test.TestSQLContext$;
|
||||
import org.junit.After;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Before;
|
||||
|
@ -31,8 +30,14 @@ import org.junit.Test;
|
|||
import org.apache.spark.api.java.JavaRDD;
|
||||
import org.apache.spark.api.java.JavaSparkContext;
|
||||
import org.apache.spark.api.java.function.Function;
|
||||
import org.apache.spark.sql.*;
|
||||
import org.apache.spark.sql.types.*;
|
||||
import org.apache.spark.sql.DataFrame;
|
||||
import org.apache.spark.sql.Row;
|
||||
import org.apache.spark.sql.RowFactory;
|
||||
import org.apache.spark.sql.SQLContext;
|
||||
import org.apache.spark.sql.test.TestSQLContext$;
|
||||
import org.apache.spark.sql.types.DataTypes;
|
||||
import org.apache.spark.sql.types.StructField;
|
||||
import org.apache.spark.sql.types.StructType;
|
||||
|
||||
// The test suite itself is Serializable so that anonymous Function implementations can be
|
||||
// serialized, as an alternative to converting these anonymous classes to static inner classes;
|
||||
|
@ -159,7 +164,8 @@ public class JavaApplySchemaSuite implements Serializable {
|
|||
"\"bigInteger\":92233720368547758069, \"double\":1.7976931348623157E305, " +
|
||||
"\"boolean\":false, \"null\":null}"));
|
||||
List<StructField> fields = new ArrayList<StructField>(7);
|
||||
fields.add(DataTypes.createStructField("bigInteger", DataTypes.createDecimalType(), true));
|
||||
fields.add(DataTypes.createStructField("bigInteger", DataTypes.createDecimalType(38, 18),
|
||||
true));
|
||||
fields.add(DataTypes.createStructField("boolean", DataTypes.BooleanType, true));
|
||||
fields.add(DataTypes.createStructField("double", DataTypes.DoubleType, true));
|
||||
fields.add(DataTypes.createStructField("integer", DataTypes.LongType, true));
|
||||
|
|
|
@ -148,7 +148,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
|
|||
val dataTypes =
|
||||
Seq(StringType, BinaryType, NullType, BooleanType,
|
||||
ByteType, ShortType, IntegerType, LongType,
|
||||
FloatType, DoubleType, DecimalType.Unlimited, DecimalType(6, 5),
|
||||
FloatType, DoubleType, DecimalType.SYSTEM_DEFAULT, DecimalType(6, 5),
|
||||
DateType, TimestampType,
|
||||
ArrayType(IntegerType), MapType(StringType, LongType), struct)
|
||||
val fields = dataTypes.zipWithIndex.map { case (dataType, index) =>
|
||||
|
|
|
@ -109,7 +109,7 @@ class PlannerSuite extends SparkFunSuite {
|
|||
FloatType ::
|
||||
DoubleType ::
|
||||
DecimalType(10, 5) ::
|
||||
DecimalType.Unlimited ::
|
||||
DecimalType.SYSTEM_DEFAULT ::
|
||||
DateType ::
|
||||
TimestampType ::
|
||||
StringType ::
|
||||
|
|
|
@ -54,7 +54,7 @@ class SparkSqlSerializer2DataTypeSuite extends SparkFunSuite {
|
|||
checkSupported(StringType, isSupported = true)
|
||||
checkSupported(BinaryType, isSupported = true)
|
||||
checkSupported(DecimalType(10, 5), isSupported = true)
|
||||
checkSupported(DecimalType.Unlimited, isSupported = true)
|
||||
checkSupported(DecimalType.SYSTEM_DEFAULT, isSupported = true)
|
||||
|
||||
// If NullType is the only data type in the schema, we do not support it.
|
||||
checkSupported(NullType, isSupported = false)
|
||||
|
@ -86,7 +86,7 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
|
|||
val supportedTypes =
|
||||
Seq(StringType, BinaryType, NullType, BooleanType,
|
||||
ByteType, ShortType, IntegerType, LongType,
|
||||
FloatType, DoubleType, DecimalType.Unlimited, DecimalType(6, 5),
|
||||
FloatType, DoubleType, DecimalType.SYSTEM_DEFAULT, DecimalType(6, 5),
|
||||
DateType, TimestampType)
|
||||
|
||||
val fields = supportedTypes.zipWithIndex.map { case (dataType, index) =>
|
||||
|
|
|
@ -134,7 +134,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
|
|||
""".stripMargin.replaceAll("\n", " "))
|
||||
|
||||
|
||||
conn.prepareStatement("create table test.flttypes (a DOUBLE, b REAL, c DECIMAL(40, 20))"
|
||||
conn.prepareStatement("create table test.flttypes (a DOUBLE, b REAL, c DECIMAL(38, 18))"
|
||||
).executeUpdate()
|
||||
conn.prepareStatement("insert into test.flttypes values ("
|
||||
+ "1.0000000000000002220446049250313080847263336181640625, "
|
||||
|
@ -152,7 +152,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
|
|||
s"""
|
||||
|create table test.nulltypes (a INT, b BOOLEAN, c TINYINT, d BINARY(20), e VARCHAR(20),
|
||||
|f VARCHAR_IGNORECASE(20), g CHAR(20), h BLOB, i CLOB, j TIME, k DATE, l TIMESTAMP,
|
||||
|m DOUBLE, n REAL, o DECIMAL(40, 20))
|
||||
|m DOUBLE, n REAL, o DECIMAL(38, 18))
|
||||
""".stripMargin.replaceAll("\n", " ")).executeUpdate()
|
||||
conn.prepareStatement("insert into test.nulltypes values ("
|
||||
+ "null, null, null, null, null, null, null, null, null, "
|
||||
|
@ -357,14 +357,14 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
|
|||
|
||||
test("H2 floating-point types") {
|
||||
val rows = sql("SELECT * FROM flttypes").collect()
|
||||
assert(rows(0).getDouble(0) === 1.00000000000000022) // Yes, I meant ==.
|
||||
assert(rows(0).getDouble(1) === 1.00000011920928955) // Yes, I meant ==.
|
||||
assert(rows(0).getAs[BigDecimal](2)
|
||||
.equals(new BigDecimal("123456789012345.54321543215432100000")))
|
||||
assert(rows(0).schema.fields(2).dataType === DecimalType(40, 20))
|
||||
val compareDecimal = sql("SELECT C FROM flttypes where C > C - 1").collect()
|
||||
assert(compareDecimal(0).getAs[BigDecimal](0)
|
||||
.equals(new BigDecimal("123456789012345.54321543215432100000")))
|
||||
assert(rows(0).getDouble(0) === 1.00000000000000022)
|
||||
assert(rows(0).getDouble(1) === 1.00000011920928955)
|
||||
assert(rows(0).getAs[BigDecimal](2) ===
|
||||
new BigDecimal("123456789012345.543215432154321000"))
|
||||
assert(rows(0).schema.fields(2).dataType === DecimalType(38, 18))
|
||||
val result = sql("SELECT C FROM flttypes where C > C - 1").collect()
|
||||
assert(result(0).getAs[BigDecimal](0) ===
|
||||
new BigDecimal("123456789012345.543215432154321000"))
|
||||
}
|
||||
|
||||
test("SQL query as table name") {
|
||||
|
|
|
@ -63,18 +63,18 @@ class JsonSuite extends QueryTest with TestJsonData {
|
|||
checkTypePromotion(intNumber.toLong, enforceCorrectType(intNumber, LongType))
|
||||
checkTypePromotion(intNumber.toDouble, enforceCorrectType(intNumber, DoubleType))
|
||||
checkTypePromotion(
|
||||
Decimal(intNumber), enforceCorrectType(intNumber, DecimalType.Unlimited))
|
||||
Decimal(intNumber), enforceCorrectType(intNumber, DecimalType.SYSTEM_DEFAULT))
|
||||
|
||||
val longNumber: Long = 9223372036854775807L
|
||||
checkTypePromotion(longNumber, enforceCorrectType(longNumber, LongType))
|
||||
checkTypePromotion(longNumber.toDouble, enforceCorrectType(longNumber, DoubleType))
|
||||
checkTypePromotion(
|
||||
Decimal(longNumber), enforceCorrectType(longNumber, DecimalType.Unlimited))
|
||||
Decimal(longNumber), enforceCorrectType(longNumber, DecimalType.SYSTEM_DEFAULT))
|
||||
|
||||
val doubleNumber: Double = 1.7976931348623157E308d
|
||||
checkTypePromotion(doubleNumber.toDouble, enforceCorrectType(doubleNumber, DoubleType))
|
||||
checkTypePromotion(
|
||||
Decimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType.Unlimited))
|
||||
Decimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType.SYSTEM_DEFAULT))
|
||||
|
||||
checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(intNumber)),
|
||||
enforceCorrectType(intNumber, TimestampType))
|
||||
|
@ -115,7 +115,7 @@ class JsonSuite extends QueryTest with TestJsonData {
|
|||
checkDataType(NullType, IntegerType, IntegerType)
|
||||
checkDataType(NullType, LongType, LongType)
|
||||
checkDataType(NullType, DoubleType, DoubleType)
|
||||
checkDataType(NullType, DecimalType.Unlimited, DecimalType.Unlimited)
|
||||
checkDataType(NullType, DecimalType.SYSTEM_DEFAULT, DecimalType.SYSTEM_DEFAULT)
|
||||
checkDataType(NullType, StringType, StringType)
|
||||
checkDataType(NullType, ArrayType(IntegerType), ArrayType(IntegerType))
|
||||
checkDataType(NullType, StructType(Nil), StructType(Nil))
|
||||
|
@ -126,7 +126,7 @@ class JsonSuite extends QueryTest with TestJsonData {
|
|||
checkDataType(BooleanType, IntegerType, StringType)
|
||||
checkDataType(BooleanType, LongType, StringType)
|
||||
checkDataType(BooleanType, DoubleType, StringType)
|
||||
checkDataType(BooleanType, DecimalType.Unlimited, StringType)
|
||||
checkDataType(BooleanType, DecimalType.SYSTEM_DEFAULT, StringType)
|
||||
checkDataType(BooleanType, StringType, StringType)
|
||||
checkDataType(BooleanType, ArrayType(IntegerType), StringType)
|
||||
checkDataType(BooleanType, StructType(Nil), StringType)
|
||||
|
@ -135,7 +135,7 @@ class JsonSuite extends QueryTest with TestJsonData {
|
|||
checkDataType(IntegerType, IntegerType, IntegerType)
|
||||
checkDataType(IntegerType, LongType, LongType)
|
||||
checkDataType(IntegerType, DoubleType, DoubleType)
|
||||
checkDataType(IntegerType, DecimalType.Unlimited, DecimalType.Unlimited)
|
||||
checkDataType(IntegerType, DecimalType.SYSTEM_DEFAULT, DecimalType.SYSTEM_DEFAULT)
|
||||
checkDataType(IntegerType, StringType, StringType)
|
||||
checkDataType(IntegerType, ArrayType(IntegerType), StringType)
|
||||
checkDataType(IntegerType, StructType(Nil), StringType)
|
||||
|
@ -143,23 +143,24 @@ class JsonSuite extends QueryTest with TestJsonData {
|
|||
// LongType
|
||||
checkDataType(LongType, LongType, LongType)
|
||||
checkDataType(LongType, DoubleType, DoubleType)
|
||||
checkDataType(LongType, DecimalType.Unlimited, DecimalType.Unlimited)
|
||||
checkDataType(LongType, DecimalType.SYSTEM_DEFAULT, DecimalType.SYSTEM_DEFAULT)
|
||||
checkDataType(LongType, StringType, StringType)
|
||||
checkDataType(LongType, ArrayType(IntegerType), StringType)
|
||||
checkDataType(LongType, StructType(Nil), StringType)
|
||||
|
||||
// DoubleType
|
||||
checkDataType(DoubleType, DoubleType, DoubleType)
|
||||
checkDataType(DoubleType, DecimalType.Unlimited, DecimalType.Unlimited)
|
||||
checkDataType(DoubleType, DecimalType.SYSTEM_DEFAULT, DecimalType.SYSTEM_DEFAULT)
|
||||
checkDataType(DoubleType, StringType, StringType)
|
||||
checkDataType(DoubleType, ArrayType(IntegerType), StringType)
|
||||
checkDataType(DoubleType, StructType(Nil), StringType)
|
||||
|
||||
// DoubleType
|
||||
checkDataType(DecimalType.Unlimited, DecimalType.Unlimited, DecimalType.Unlimited)
|
||||
checkDataType(DecimalType.Unlimited, StringType, StringType)
|
||||
checkDataType(DecimalType.Unlimited, ArrayType(IntegerType), StringType)
|
||||
checkDataType(DecimalType.Unlimited, StructType(Nil), StringType)
|
||||
// DecimalType
|
||||
checkDataType(DecimalType.SYSTEM_DEFAULT, DecimalType.SYSTEM_DEFAULT,
|
||||
DecimalType.SYSTEM_DEFAULT)
|
||||
checkDataType(DecimalType.SYSTEM_DEFAULT, StringType, StringType)
|
||||
checkDataType(DecimalType.SYSTEM_DEFAULT, ArrayType(IntegerType), StringType)
|
||||
checkDataType(DecimalType.SYSTEM_DEFAULT, StructType(Nil), StringType)
|
||||
|
||||
// StringType
|
||||
checkDataType(StringType, StringType, StringType)
|
||||
|
@ -213,7 +214,7 @@ class JsonSuite extends QueryTest with TestJsonData {
|
|||
checkDataType(
|
||||
StructType(
|
||||
StructField("f1", IntegerType, true) :: Nil),
|
||||
DecimalType.Unlimited,
|
||||
DecimalType.SYSTEM_DEFAULT,
|
||||
StringType)
|
||||
}
|
||||
|
||||
|
@ -240,7 +241,7 @@ class JsonSuite extends QueryTest with TestJsonData {
|
|||
val jsonDF = ctx.read.json(primitiveFieldAndType)
|
||||
|
||||
val expectedSchema = StructType(
|
||||
StructField("bigInteger", DecimalType.Unlimited, true) ::
|
||||
StructField("bigInteger", DecimalType.SYSTEM_DEFAULT, true) ::
|
||||
StructField("boolean", BooleanType, true) ::
|
||||
StructField("double", DoubleType, true) ::
|
||||
StructField("integer", LongType, true) ::
|
||||
|
@ -270,7 +271,7 @@ class JsonSuite extends QueryTest with TestJsonData {
|
|||
val expectedSchema = StructType(
|
||||
StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) ::
|
||||
StructField("arrayOfArray2", ArrayType(ArrayType(DoubleType, true), true), true) ::
|
||||
StructField("arrayOfBigInteger", ArrayType(DecimalType.Unlimited, true), true) ::
|
||||
StructField("arrayOfBigInteger", ArrayType(DecimalType.SYSTEM_DEFAULT, true), true) ::
|
||||
StructField("arrayOfBoolean", ArrayType(BooleanType, true), true) ::
|
||||
StructField("arrayOfDouble", ArrayType(DoubleType, true), true) ::
|
||||
StructField("arrayOfInteger", ArrayType(LongType, true), true) ::
|
||||
|
@ -284,7 +285,7 @@ class JsonSuite extends QueryTest with TestJsonData {
|
|||
StructField("field3", StringType, true) :: Nil), true), true) ::
|
||||
StructField("struct", StructType(
|
||||
StructField("field1", BooleanType, true) ::
|
||||
StructField("field2", DecimalType.Unlimited, true) :: Nil), true) ::
|
||||
StructField("field2", DecimalType.SYSTEM_DEFAULT, true) :: Nil), true) ::
|
||||
StructField("structWithArrayFields", StructType(
|
||||
StructField("field1", ArrayType(LongType, true), true) ::
|
||||
StructField("field2", ArrayType(StringType, true), true) :: Nil), true) :: Nil)
|
||||
|
@ -385,7 +386,7 @@ class JsonSuite extends QueryTest with TestJsonData {
|
|||
val expectedSchema = StructType(
|
||||
StructField("num_bool", StringType, true) ::
|
||||
StructField("num_num_1", LongType, true) ::
|
||||
StructField("num_num_2", DecimalType.Unlimited, true) ::
|
||||
StructField("num_num_2", DecimalType.SYSTEM_DEFAULT, true) ::
|
||||
StructField("num_num_3", DoubleType, true) ::
|
||||
StructField("num_str", StringType, true) ::
|
||||
StructField("str_bool", StringType, true) :: Nil)
|
||||
|
@ -421,11 +422,11 @@ class JsonSuite extends QueryTest with TestJsonData {
|
|||
Row(-89) :: Row(21474836370L) :: Row(21474836470L) :: Nil
|
||||
)
|
||||
|
||||
// Widening to DecimalType
|
||||
// Widening to DoubleType
|
||||
checkAnswer(
|
||||
sql("select num_num_2 + 1.2 from jsonTable where num_num_2 > 1.1"),
|
||||
Row(new java.math.BigDecimal("21474836472.1")) ::
|
||||
Row(new java.math.BigDecimal("92233720368547758071.2")) :: Nil
|
||||
sql("select num_num_2 + 1.3 from jsonTable where num_num_2 > 1.1"),
|
||||
Row(21474836472.2) ::
|
||||
Row(92233720368547758071.3) :: Nil
|
||||
)
|
||||
|
||||
// Widening to DoubleType
|
||||
|
@ -442,8 +443,8 @@ class JsonSuite extends QueryTest with TestJsonData {
|
|||
|
||||
// Number and String conflict: resolve the type as number in this query.
|
||||
checkAnswer(
|
||||
sql("select num_str + 1.2 from jsonTable where num_str > 92233720368547758060"),
|
||||
Row(new java.math.BigDecimal("92233720368547758061.2").doubleValue)
|
||||
sql("select num_str + 1.2 from jsonTable where num_str >= 92233720368547758060"),
|
||||
Row(new java.math.BigDecimal("92233720368547758071.2").doubleValue)
|
||||
)
|
||||
|
||||
// String and Boolean conflict: resolve the type as string.
|
||||
|
@ -489,9 +490,9 @@ class JsonSuite extends QueryTest with TestJsonData {
|
|||
// in the Project.
|
||||
checkAnswer(
|
||||
jsonDF.
|
||||
where('num_str > BigDecimal("92233720368547758060")).
|
||||
where('num_str >= BigDecimal("92233720368547758060")).
|
||||
select(('num_str + 1.2).as("num")),
|
||||
Row(new java.math.BigDecimal("92233720368547758061.2"))
|
||||
Row(new java.math.BigDecimal("92233720368547758071.2").doubleValue())
|
||||
)
|
||||
|
||||
// The following test will fail. The type of num_str is StringType.
|
||||
|
@ -610,7 +611,7 @@ class JsonSuite extends QueryTest with TestJsonData {
|
|||
val jsonDF = ctx.read.json(path)
|
||||
|
||||
val expectedSchema = StructType(
|
||||
StructField("bigInteger", DecimalType.Unlimited, true) ::
|
||||
StructField("bigInteger", DecimalType.SYSTEM_DEFAULT, true) ::
|
||||
StructField("boolean", BooleanType, true) ::
|
||||
StructField("double", DoubleType, true) ::
|
||||
StructField("integer", LongType, true) ::
|
||||
|
@ -668,7 +669,7 @@ class JsonSuite extends QueryTest with TestJsonData {
|
|||
primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path)
|
||||
|
||||
val schema = StructType(
|
||||
StructField("bigInteger", DecimalType.Unlimited, true) ::
|
||||
StructField("bigInteger", DecimalType.SYSTEM_DEFAULT, true) ::
|
||||
StructField("boolean", BooleanType, true) ::
|
||||
StructField("double", DoubleType, true) ::
|
||||
StructField("integer", IntegerType, true) ::
|
||||
|
|
|
@ -122,14 +122,6 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
|
|||
sqlContext.read.parquet(dir.getCanonicalPath).collect()
|
||||
}
|
||||
}
|
||||
|
||||
// Unlimited-length decimals are not yet supported
|
||||
intercept[Throwable] {
|
||||
withTempPath { dir =>
|
||||
makeDecimalRDD(DecimalType.Unlimited).write.parquet(dir.getCanonicalPath)
|
||||
sqlContext.read.parquet(dir.getCanonicalPath).collect()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("date type") {
|
||||
|
|
|
@ -509,7 +509,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
|
|||
FloatType,
|
||||
DoubleType,
|
||||
DecimalType(10, 5),
|
||||
DecimalType.Unlimited,
|
||||
DecimalType.SYSTEM_DEFAULT,
|
||||
DateType,
|
||||
TimestampType,
|
||||
StringType)
|
||||
|
|
|
@ -44,7 +44,7 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo
|
|||
StructField("doubleType", DoubleType, nullable = false),
|
||||
StructField("bigintType", LongType, nullable = false),
|
||||
StructField("tinyintType", ByteType, nullable = false),
|
||||
StructField("decimalType", DecimalType.Unlimited, nullable = false),
|
||||
StructField("decimalType", DecimalType.USER_DEFAULT, nullable = false),
|
||||
StructField("fixedDecimalType", DecimalType(5, 1), nullable = false),
|
||||
StructField("binaryType", BinaryType, nullable = false),
|
||||
StructField("booleanType", BooleanType, nullable = false),
|
||||
|
|
|
@ -202,7 +202,7 @@ class TableScanSuite extends DataSourceTest {
|
|||
StructField("longField_:,<>=+/~^", LongType, true) ::
|
||||
StructField("floatField", FloatType, true) ::
|
||||
StructField("doubleField", DoubleType, true) ::
|
||||
StructField("decimalField1", DecimalType.Unlimited, true) ::
|
||||
StructField("decimalField1", DecimalType.USER_DEFAULT, true) ::
|
||||
StructField("decimalField2", DecimalType(9, 2), true) ::
|
||||
StructField("dateField", DateType, true) ::
|
||||
StructField("timestampField", TimestampType, true) ::
|
||||
|
|
|
@ -179,7 +179,7 @@ private[hive] trait HiveInspectors {
|
|||
// writable
|
||||
case c: Class[_] if c == classOf[hadoopIo.DoubleWritable] => DoubleType
|
||||
case c: Class[_] if c == classOf[hiveIo.DoubleWritable] => DoubleType
|
||||
case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType.Unlimited
|
||||
case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType.SYSTEM_DEFAULT
|
||||
case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType
|
||||
case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType
|
||||
case c: Class[_] if c == classOf[hiveIo.DateWritable] => DateType
|
||||
|
@ -195,8 +195,8 @@ private[hive] trait HiveInspectors {
|
|||
case c: Class[_] if c == classOf[java.lang.String] => StringType
|
||||
case c: Class[_] if c == classOf[java.sql.Date] => DateType
|
||||
case c: Class[_] if c == classOf[java.sql.Timestamp] => TimestampType
|
||||
case c: Class[_] if c == classOf[HiveDecimal] => DecimalType.Unlimited
|
||||
case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType.Unlimited
|
||||
case c: Class[_] if c == classOf[HiveDecimal] => DecimalType.SYSTEM_DEFAULT
|
||||
case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType.SYSTEM_DEFAULT
|
||||
case c: Class[_] if c == classOf[Array[Byte]] => BinaryType
|
||||
case c: Class[_] if c == classOf[java.lang.Short] => ShortType
|
||||
case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType
|
||||
|
@ -813,9 +813,6 @@ private[hive] trait HiveInspectors {
|
|||
|
||||
private def decimalTypeInfo(decimalType: DecimalType): TypeInfo = decimalType match {
|
||||
case DecimalType.Fixed(precision, scale) => new DecimalTypeInfo(precision, scale)
|
||||
case _ => new DecimalTypeInfo(
|
||||
HiveShim.UNLIMITED_DECIMAL_PRECISION,
|
||||
HiveShim.UNLIMITED_DECIMAL_SCALE)
|
||||
}
|
||||
|
||||
def toTypeInfo: TypeInfo = dt match {
|
||||
|
|
|
@ -377,7 +377,7 @@ private[hive] object HiveQl extends Logging {
|
|||
DecimalType(precision.getText.toInt, scale.getText.toInt)
|
||||
case Token("TOK_DECIMAL", precision :: Nil) =>
|
||||
DecimalType(precision.getText.toInt, 0)
|
||||
case Token("TOK_DECIMAL", Nil) => DecimalType.Unlimited
|
||||
case Token("TOK_DECIMAL", Nil) => DecimalType.USER_DEFAULT
|
||||
case Token("TOK_BIGINT", Nil) => LongType
|
||||
case Token("TOK_INT", Nil) => IntegerType
|
||||
case Token("TOK_TINYINT", Nil) => ByteType
|
||||
|
@ -1369,7 +1369,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
|
|||
case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: Nil) :: arg :: Nil) =>
|
||||
Cast(nodeToExpr(arg), DecimalType(precision.getText.toInt, 0))
|
||||
case Token("TOK_FUNCTION", Token("TOK_DECIMAL", Nil) :: arg :: Nil) =>
|
||||
Cast(nodeToExpr(arg), DecimalType.Unlimited)
|
||||
Cast(nodeToExpr(arg), DecimalType.USER_DEFAULT)
|
||||
case Token("TOK_FUNCTION", Token("TOK_TIMESTAMP", Nil) :: arg :: Nil) =>
|
||||
Cast(nodeToExpr(arg), TimestampType)
|
||||
case Token("TOK_FUNCTION", Token("TOK_DATE", Nil) :: arg :: Nil) =>
|
||||
|
|
Loading…
Reference in a new issue