diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala index c5fd2f9d5a..6355e0f179 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala @@ -218,7 +218,7 @@ class AttributeSuite extends SparkFunSuite { // Attribute.fromStructField should accept any NumericType, not just DoubleType val longFldWithMeta = new StructField("x", LongType, false, metadata) assert(Attribute.fromStructField(longFldWithMeta).isNumeric) - val decimalFldWithMeta = new StructField("x", DecimalType(None), false, metadata) + val decimalFldWithMeta = new StructField("x", DecimalType(38, 18), false, metadata) assert(Attribute.fromStructField(decimalFldWithMeta).isNumeric) } } diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 10ad89ea14..b97d50c945 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -194,30 +194,33 @@ class TimestampType(AtomicType): class DecimalType(FractionalType): """Decimal (decimal.Decimal) data type. + + The DecimalType must have fixed precision (the maximum total number of digits) + and scale (the number of digits to the right of the decimal point). For example, (5, 2) + can support values in the range [-999.99, 999.99]. + + The precision can be up to 38; the scale must be less than or equal to the precision. + + When creating a DecimalType, the default precision and scale is (10, 0). When inferring + schema from decimal.Decimal objects, it will be DecimalType(38, 18). + + :param precision: the maximum total number of digits (default: 10) + :param scale: the number of digits to the right of the decimal point (default: 0) """ - def __init__(self, precision=None, scale=None): + def __init__(self, precision=10, scale=0): self.precision = precision self.scale = scale - self.hasPrecisionInfo = precision is not None + self.hasPrecisionInfo = True # this is public API def simpleString(self): - if self.hasPrecisionInfo: - return "decimal(%d,%d)" % (self.precision, self.scale) - else: - return "decimal(10,0)" + return "decimal(%d,%d)" % (self.precision, self.scale) def jsonValue(self): - if self.hasPrecisionInfo: - return "decimal(%d,%d)" % (self.precision, self.scale) - else: - return "decimal" + return "decimal(%d,%d)" % (self.precision, self.scale) def __repr__(self): - if self.hasPrecisionInfo: - return "DecimalType(%d,%d)" % (self.precision, self.scale) - else: - return "DecimalType()" + return "DecimalType(%d,%d)" % (self.precision, self.scale) class DoubleType(FractionalType): @@ -761,7 +764,10 @@ def _infer_type(obj): return obj.__UDT__ dataType = _type_mappings.get(type(obj)) - if dataType is not None: + if dataType is DecimalType: + # the precision and scale of `obj` may be different from row to row. + return DecimalType(38, 18) + elif dataType is not None: return dataType() if isinstance(obj, dict): diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java index d22ad6794d..5703de4239 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java @@ -111,12 +111,18 @@ public class DataTypes { return new ArrayType(elementType, containsNull); } + /** + * Creates a DecimalType by specifying the precision and scale.
+ */ public static DecimalType createDecimalType(int precision, int scale) { return DecimalType$.MODULE$.apply(precision, scale); } + /** + * Creates a DecimalType with default precision and scale, which are 10 and 0. + */ public static DecimalType createDecimalType() { - return DecimalType$.MODULE$.Unlimited(); + return DecimalType$.MODULE$.USER_DEFAULT(); } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 9a3f9694e4..88a457f87c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -75,7 +75,7 @@ private [sql] object JavaTypeInference { case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true) case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true) - case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true) + case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType.SYSTEM_DEFAULT, true) case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true) case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 21b1de1ab9..2442341da1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -131,10 +131,10 @@ trait ScalaReflection { case t if t <:< localTypeOf[String] => Schema(StringType, nullable = true) case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true) case t if t <:< localTypeOf[java.sql.Date] => Schema(DateType, nullable = true) - case t if t <:< localTypeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true) + case t if t <:< localTypeOf[BigDecimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) case t if t <:< localTypeOf[java.math.BigDecimal] => - Schema(DecimalType.Unlimited, nullable = true) - case t if t <:< localTypeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true) + Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) + case t if t <:< localTypeOf[Decimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) case t if t <:< localTypeOf[java.lang.Integer] => Schema(IntegerType, nullable = true) case t if t <:< localTypeOf[java.lang.Long] => Schema(LongType, nullable = true) case t if t <:< localTypeOf[java.lang.Double] => Schema(DoubleType, nullable = true) @@ -167,8 +167,8 @@ trait ScalaReflection { case obj: Float => FloatType case obj: Double => DoubleType case obj: java.sql.Date => DateType - case obj: java.math.BigDecimal => DecimalType.Unlimited - case obj: Decimal => DecimalType.Unlimited + case obj: java.math.BigDecimal => DecimalType.SYSTEM_DEFAULT + case obj: Decimal => DecimalType.SYSTEM_DEFAULT case obj: java.sql.Timestamp => TimestampType case null => NullType // For other cases, there is no obvious mapping from the type of the given object to a diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 29cfc064da..c494e5d704 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -322,7 +322,10 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val numericLiteral: Parser[Literal] = ( integral ^^ { case i => Literal(toNarrowestIntegerType(i)) } - | sign.? ~ unsignedFloat ^^ { case s ~ f => Literal((s.getOrElse("") + f).toDouble) } + | sign.? ~ unsignedFloat ^^ { + // TODO(davies): some precision may be lost, we should create a decimal literal + case s ~ f => Literal(BigDecimal(s.getOrElse("") + f).doubleValue()) + } ) protected lazy val unsignedFloat: Parser[String] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index e214545726..d56ceeadc9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -19,7 +19,9 @@ package org.apache.spark.sql.catalyst.analysis import javax.annotation.Nullable +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ @@ -58,8 +60,7 @@ object HiveTypeCoercion { IntegerType, LongType, FloatType, - DoubleType, - DecimalType.Unlimited) + DoubleType) /** * Find the tightest common type of two types that might be used in a binary expression. @@ -72,15 +73,16 @@ object HiveTypeCoercion { case (NullType, t1) => Some(t1) case (t1, NullType) => Some(t1) - // Promote numeric types to the highest of the two and all numeric types to unlimited decimal + case (t1: IntegralType, t2: DecimalType) if t2.isWiderThan(t1) => + Some(t2) + case (t1: DecimalType, t2: IntegralType) if t1.isWiderThan(t2) => + Some(t1) + + // Promote numeric types to the highest of the two case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) => val index = numericPrecedence.lastIndexWhere(t => t == t1 || t == t2) Some(numericPrecedence(index)) - // Fixed-precision decimals can up-cast into unlimited - case (DecimalType.Unlimited, _: DecimalType) => Some(DecimalType.Unlimited) - case (_: DecimalType, DecimalType.Unlimited) => Some(DecimalType.Unlimited) - case _ => None } @@ -101,7 +103,7 @@ object HiveTypeCoercion { types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { case None => None case Some(d) => - findTightestCommonTypeOfTwo(d, c).orElse(findTightestCommonTypeToString(d, c)) + findTightestCommonTypeToString(d, c) }) } @@ -158,6 +160,9 @@ object HiveTypeCoercion { * converted to DOUBLE. * - TINYINT, SMALLINT, and INT can all be converted to FLOAT. * - BOOLEAN types cannot be converted to any other type. + * - Any integral numeric type can be implicitly converted to decimal type. + * - Two different decimal types will be converted into a decimal type wide enough for both of them. + * - A decimal type will be converted into double if it appears together with a float or double. * * Additionally, all types when UNION-ed with strings will be promoted to strings. * Other string conversions are handled by PromoteStrings.
@@ -166,55 +171,50 @@ object HiveTypeCoercion { * - IntegerType to FloatType * - LongType to FloatType * - LongType to DoubleType + * - DecimalType to Double + * + * This rule is only applied to Union/Except/Intersect */ object WidenTypes extends Rule[LogicalPlan] { - private[this] def widenOutputTypes(planName: String, left: LogicalPlan, right: LogicalPlan): - (LogicalPlan, LogicalPlan) = { - - // TODO: with fixed-precision decimals - val castedInput = left.output.zip(right.output).map { - // When a string is found on one side, make the other side a string too. - case (lhs, rhs) if lhs.dataType == StringType && rhs.dataType != StringType => - (lhs, Alias(Cast(rhs, StringType), rhs.name)()) - case (lhs, rhs) if lhs.dataType != StringType && rhs.dataType == StringType => - (Alias(Cast(lhs, StringType), lhs.name)(), rhs) + private[this] def widenOutputTypes( + planName: String, + left: LogicalPlan, + right: LogicalPlan): (LogicalPlan, LogicalPlan) = { + val castedTypes = left.output.zip(right.output).map { case (lhs, rhs) if lhs.dataType != rhs.dataType => - logDebug(s"Resolving mismatched $planName input ${lhs.dataType}, ${rhs.dataType}") - findTightestCommonTypeOfTwo(lhs.dataType, rhs.dataType).map { widestType => - val newLeft = - if (lhs.dataType == widestType) lhs else Alias(Cast(lhs, widestType), lhs.name)() - val newRight = - if (rhs.dataType == widestType) rhs else Alias(Cast(rhs, widestType), rhs.name)() - - (newLeft, newRight) - }.getOrElse { - // If there is no applicable conversion, leave expression unchanged. - (lhs, rhs) + (lhs.dataType, rhs.dataType) match { + case (t1: DecimalType, t2: DecimalType) => + Some(DecimalPrecision.widerDecimalType(t1, t2)) + case (t: IntegralType, d: DecimalType) => + Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) + case (d: DecimalType, t: IntegralType) => + Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) + case (t: FractionalType, d: DecimalType) => + Some(DoubleType) + case (d: DecimalType, t: FractionalType) => + Some(DoubleType) + case _ => + findTightestCommonTypeToString(lhs.dataType, rhs.dataType) } - - case other => other + case other => None } - val (castedLeft, castedRight) = castedInput.unzip - - val newLeft = - if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) { - logDebug(s"Widening numeric types in $planName $castedLeft ${left.output}") - Project(castedLeft, left) - } else { - left + def castOutput(plan: LogicalPlan): LogicalPlan = { + val casted = plan.output.zip(castedTypes).map { + case (hs, Some(dt)) if dt != hs.dataType => + Alias(Cast(hs, dt), hs.name)() + case (hs, _) => hs } + Project(casted, plan) + } - val newRight = - if (castedRight.map(_.dataType) != right.output.map(_.dataType)) { - logDebug(s"Widening numeric types in $planName $castedRight ${right.output}") - Project(castedRight, right) - } else { - right - } - (newLeft, newRight) + if (castedTypes.exists(_.isDefined)) { + (castOutput(left), castOutput(right)) + } else { + (left, right) + } } def apply(plan: LogicalPlan): LogicalPlan = plan transform { @@ -334,144 +334,94 @@ object HiveTypeCoercion { * - SHORT gets turned into DECIMAL(5, 0) * - INT gets turned into DECIMAL(10, 0) * - LONG gets turned into DECIMAL(20, 0) - * - FLOAT and DOUBLE - * 1. Union, Intersect and Except operations: - * FLOAT gets turned into DECIMAL(7, 7), DOUBLE gets turned into DECIMAL(15, 15) (this is the - * same as Hive) - * 2. 
Other operation: - * FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE (this is the same as Hive, - * but note that unlimited decimals are considered bigger than doubles in WidenTypes) + * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE + * + * Note: Union/Except/Intersect is handled by WidenTypes */ // scalastyle:on object DecimalPrecision extends Rule[LogicalPlan] { import scala.math.{max, min} - // Conversion rules for integer types into fixed-precision decimals - private val intTypeToFixed: Map[DataType, DecimalType] = Map( - ByteType -> DecimalType(3, 0), - ShortType -> DecimalType(5, 0), - IntegerType -> DecimalType(10, 0), - LongType -> DecimalType(20, 0) - ) - private def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType - // Conversion rules for float and double into fixed-precision decimals - private val floatTypeToFixed: Map[DataType, DecimalType] = Map( - FloatType -> DecimalType(7, 7), - DoubleType -> DecimalType(15, 15) - ) + // Returns the decimal type that is wider than both of the given decimal types + def widerDecimalType(d1: DecimalType, d2: DecimalType): DecimalType = { + widerDecimalType(d1.precision, d1.scale, d2.precision, d2.scale) + } + // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2) + def widerDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = { + val scale = max(s1, s2) + val range = max(p1 - s1, p2 - s2) + DecimalType.bounded(range + scale, scale) + } - private def castDecimalPrecision( - left: LogicalPlan, - right: LogicalPlan): (LogicalPlan, LogicalPlan) = { - val castedInput = left.output.zip(right.output).map { - case (lhs, rhs) if lhs.dataType != rhs.dataType => - (lhs.dataType, rhs.dataType) match { - case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) => - // Decimals with precision/scale p1/s2 and p2/s2 will be promoted to - // DecimalType(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)) - val fixedType = DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2), max(s1, s2)) - (Alias(Cast(lhs, fixedType), lhs.name)(), Alias(Cast(rhs, fixedType), rhs.name)()) - case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => - (Alias(Cast(lhs, intTypeToFixed(t)), lhs.name)(), rhs) - case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) => - (lhs, Alias(Cast(rhs, intTypeToFixed(t)), rhs.name)()) - case (t, DecimalType.Fixed(p, s)) if floatTypeToFixed.contains(t) => - (Alias(Cast(lhs, floatTypeToFixed(t)), lhs.name)(), rhs) - case (DecimalType.Fixed(p, s), t) if floatTypeToFixed.contains(t) => - (lhs, Alias(Cast(rhs, floatTypeToFixed(t)), rhs.name)()) - case _ => (lhs, rhs) - } - case other => other - } + /** + * An expression used to wrap the children when promoting the precision of DecimalType, to avoid + * promoting it multiple times.
+ */ + case class ChangePrecision(child: Expression) extends UnaryExpression { + override def dataType: DataType = child.dataType + override def eval(input: InternalRow): Any = child.eval(input) + override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx) + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = "" + override def prettyName: String = "change_precision" + } - val (castedLeft, castedRight) = castedInput.unzip - - val newLeft = - if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) { - Project(castedLeft, left) - } else { - left - } - - val newRight = - if (castedRight.map(_.dataType) != right.output.map(_.dataType)) { - Project(castedRight, right) - } else { - right - } - (newLeft, newRight) + def changePrecision(e: Expression, dataType: DataType): Expression = { + ChangePrecision(Cast(e, dataType)) } def apply(plan: LogicalPlan): LogicalPlan = plan transform { - // fix decimal precision for union, intersect and except - case u @ Union(left, right) if u.childrenResolved && !u.resolved => - val (newLeft, newRight) = castDecimalPrecision(left, right) - Union(newLeft, newRight) - case i @ Intersect(left, right) if i.childrenResolved && !i.resolved => - val (newLeft, newRight) = castDecimalPrecision(left, right) - Intersect(newLeft, newRight) - case e @ Except(left, right) if e.childrenResolved && !e.resolved => - val (newLeft, newRight) = castDecimalPrecision(left, right) - Except(newLeft, newRight) - // fix decimal precision for expressions case q => q.transformExpressions { // Skip nodes whose children have not been resolved yet case e if !e.childrenResolved => e + // Skip nodes that are already promoted + case e: BinaryArithmetic if e.left.isInstanceOf[ChangePrecision] => e + case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Add(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) - ) + val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) + Add(changePrecision(e1, dt), changePrecision(e2, dt)) case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Subtract(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) - ) + val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) + Subtract(changePrecision(e1, dt), changePrecision(e2, dt)) case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Multiply(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(p1 + p2 + 1, s1 + s2) - ) + val dt = DecimalType.bounded(p1 + p2 + 1, s1 + s2) + Multiply(changePrecision(e1, dt), changePrecision(e2, dt)) case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Divide(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1)) - ) + val dt = DecimalType.bounded(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1)) + Divide(changePrecision(e1, dt), changePrecision(e2, dt)) case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Remainder(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) - ) + val resultType = 
DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + // resultType may have lower precision, so we cast them into wider type first. + val widerType = widerDecimalType(p1, s1, p2, s2) + Cast(Remainder(changePrecision(e1, widerType), changePrecision(e2, widerType)), + resultType) case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Pmod(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) - ) + val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + // resultType may have lower precision, so we cast them into wider type first. + val widerType = widerDecimalType(p1, s1, p2, s2) + Cast(Pmod(changePrecision(e1, widerType), changePrecision(e2, widerType)), resultType) - // When we compare 2 decimal types with different precisions, cast them to the smallest - // common precision. case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - val resultType = DecimalType(max(p1, p2), max(s1, s2)) + val resultType = widerDecimalType(p1, s1, p2, s2) b.makeCopy(Array(Cast(e1, resultType), Cast(e2, resultType))) // Promote integers inside a binary expression with fixed-precision decimals to decimals, // and fixed-precision decimals in an expression with floats / doubles to doubles case b @ BinaryOperator(left, right) if left.dataType != right.dataType => (left.dataType, right.dataType) match { - case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => - b.makeCopy(Array(Cast(left, intTypeToFixed(t)), right)) - case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) => - b.makeCopy(Array(left, Cast(right, intTypeToFixed(t)))) + case (t: IntegralType, DecimalType.Fixed(p, s)) => + b.makeCopy(Array(Cast(left, DecimalType.forType(t)), right)) + case (DecimalType.Fixed(p, s), t: IntegralType) => + b.makeCopy(Array(left, Cast(right, DecimalType.forType(t)))) case (t, DecimalType.Fixed(p, s)) if isFloat(t) => b.makeCopy(Array(left, Cast(right, DoubleType))) case (DecimalType.Fixed(p, s), t) if isFloat(t) => @@ -485,7 +435,6 @@ object HiveTypeCoercion { // SUM and AVERAGE are handled by the implementations of those expressions } } - } /** @@ -563,7 +512,7 @@ object HiveTypeCoercion { case e if !e.childrenResolved => e case Cast(e @ StringType(), t: IntegralType) => - Cast(Cast(e, DecimalType.Unlimited), t) + Cast(Cast(e, DecimalType.forType(LongType)), t) } } @@ -756,8 +705,8 @@ object HiveTypeCoercion { // Implicit cast among numeric types. When we reach here, input type is not acceptable. // If input is a numeric type but not decimal, and we expect a decimal type, - // cast the input to unlimited precision decimal. - case (_: NumericType, DecimalType) => Cast(e, DecimalType.Unlimited) + // cast the input to decimal. + case (d: NumericType, DecimalType) => Cast(e, DecimalType.forType(d)) // For any other numeric types, implicitly cast to each other, e.g. 
long -> int, int -> long case (_: NumericType, target: NumericType) => Cast(e, target) @@ -766,7 +715,7 @@ object HiveTypeCoercion { case (TimestampType, DateType) => Cast(e, DateType) // Implicit cast from/to string - case (StringType, DecimalType) => Cast(e, DecimalType.Unlimited) + case (StringType, DecimalType) => Cast(e, DecimalType.SYSTEM_DEFAULT) case (StringType, target: NumericType) => Cast(e, target) case (StringType, DateType) => Cast(e, DateType) case (StringType, TimestampType) => Cast(e, TimestampType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 5182175796..a7e3a49327 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -201,7 +201,7 @@ package object dsl { /** Creates a new AttributeReference of type decimal */ def decimal: AttributeReference = - AttributeReference(s, DecimalType.Unlimited, nullable = true)() + AttributeReference(s, DecimalType.SYSTEM_DEFAULT, nullable = true)() /** Creates a new AttributeReference of type decimal */ def decimal(precision: Int, scale: Int): AttributeReference = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index e66cd82848..c66854d52c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -300,12 +300,7 @@ case class Cast(child: Expression, dataType: DataType) * NOTE: this modifies `value` in-place, so don't call it on external data. */ private[this] def changePrecision(value: Decimal, decimalType: DecimalType): Decimal = { - decimalType match { - case DecimalType.Unlimited => - value - case DecimalType.Fixed(precision, scale) => - if (value.changePrecision(precision, scale)) value else null - } + if (value.changePrecision(decimalType.precision, decimalType.scale)) value else null } private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index b924af4cc8..88fb516e64 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -36,14 +36,13 @@ case class Average(child: Expression) extends AlgebraicAggregate { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) private val resultType = child.dataType match { - case DecimalType.Fixed(precision, scale) => - DecimalType(precision + 4, scale + 4) - case DecimalType.Unlimited => DecimalType.Unlimited + case DecimalType.Fixed(p, s) => + DecimalType.bounded(p + 4, s + 4) case _ => DoubleType } private val sumDataType = child.dataType match { - case _ @ DecimalType() => DecimalType.Unlimited + case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s) case _ => DoubleType } @@ -71,7 +70,14 @@ case class Average(child: Expression) extends AlgebraicAggregate { ) // If all input are nulls, currentCount will be 0 and we will get null after the division. 
- override val evaluateExpression = Cast(currentSum, resultType) / Cast(currentCount, resultType) + override val evaluateExpression = child.dataType match { + case DecimalType.Fixed(p, s) => + // increase the precision and scale to prevent precision loss + val dt = DecimalType.bounded(p + 14, s + 4) + Cast(Cast(currentSum, dt) / Cast(currentCount, dt), resultType) + case _ => + Cast(currentSum, resultType) / Cast(currentCount, resultType) + } } case class Count(child: Expression) extends AlgebraicAggregate { @@ -255,15 +261,11 @@ case class Sum(child: Expression) extends AlgebraicAggregate { private val resultType = child.dataType match { case DecimalType.Fixed(precision, scale) => - DecimalType(precision + 4, scale + 4) - case DecimalType.Unlimited => DecimalType.Unlimited + DecimalType.bounded(precision + 10, scale) case _ => child.dataType } - private val sumDataType = child.dataType match { - case _ @ DecimalType() => DecimalType.Unlimited - case _ => child.dataType - } + private val sumDataType = resultType private val currentSum = AttributeReference("currentSum", sumDataType)() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index d3295b8baf..73fde4e916 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -390,22 +390,21 @@ case class Average(child: Expression) extends UnaryExpression with PartialAggreg override def dataType: DataType = child.dataType match { case DecimalType.Fixed(precision, scale) => - DecimalType(precision + 4, scale + 4) // Add 4 digits after decimal point, like Hive - case DecimalType.Unlimited => - DecimalType.Unlimited + // Add 4 digits after decimal point, like Hive + DecimalType.bounded(precision + 4, scale + 4) case _ => DoubleType } override def asPartial: SplitEvaluation = { child.dataType match { - case DecimalType.Fixed(_, _) | DecimalType.Unlimited => - // Turn the child to unlimited decimals for calculation, before going back to fixed - val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")() + case DecimalType.Fixed(precision, scale) => + val partialSum = Alias(Sum(child), "PartialSum")() val partialCount = Alias(Count(child), "PartialCount")() - val castedSum = Cast(Sum(partialSum.toAttribute), DecimalType.Unlimited) - val castedCount = Cast(Sum(partialCount.toAttribute), DecimalType.Unlimited) + // partialSum has already increased the precision by 10 + val castedSum = Cast(Sum(partialSum.toAttribute), partialSum.dataType) + val castedCount = Sum(partialCount.toAttribute) SplitEvaluation( Cast(Divide(castedSum, castedCount), dataType), partialCount :: partialSum :: Nil) @@ -435,8 +434,8 @@ case class AverageFunction(expr: Expression, base: AggregateExpression1) private val calcType = expr.dataType match { - case DecimalType.Fixed(_, _) => - DecimalType.Unlimited + case DecimalType.Fixed(precision, scale) => + DecimalType.bounded(precision + 10, scale) case _ => expr.dataType } @@ -454,10 +453,9 @@ case class AverageFunction(expr: Expression, base: AggregateExpression1) null } else { expr.dataType match { - case DecimalType.Fixed(_, _) => - Cast(Divide( - Cast(sum, DecimalType.Unlimited), - Cast(Literal(count), DecimalType.Unlimited)), dataType).eval(null) + case DecimalType.Fixed(precision, scale) => + val dt = DecimalType.bounded(precision + 14, scale 
+ 4) + Cast(Divide(Cast(sum, dt), Cast(Literal(count), dt)), dataType).eval(null) case _ => Divide( Cast(sum, dataType), @@ -481,9 +479,8 @@ case class Sum(child: Expression) extends UnaryExpression with PartialAggregate1 override def dataType: DataType = child.dataType match { case DecimalType.Fixed(precision, scale) => - DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive - case DecimalType.Unlimited => - DecimalType.Unlimited + // Add 10 digits left of decimal point, like Hive + DecimalType.bounded(precision + 10, scale) case _ => child.dataType } @@ -491,7 +488,7 @@ case class Sum(child: Expression) extends UnaryExpression with PartialAggregate1 override def asPartial: SplitEvaluation = { child.dataType match { case DecimalType.Fixed(_, _) => - val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")() + val partialSum = Alias(Sum(child), "PartialSum")() SplitEvaluation( Cast(CombineSum(partialSum.toAttribute), dataType), partialSum :: Nil) @@ -515,8 +512,8 @@ case class SumFunction(expr: Expression, base: AggregateExpression1) extends Agg private val calcType = expr.dataType match { - case DecimalType.Fixed(_, _) => - DecimalType.Unlimited + case DecimalType.Fixed(precision, scale) => + DecimalType.bounded(precision + 10, scale) case _ => expr.dataType } @@ -572,8 +569,8 @@ case class CombineSumFunction(expr: Expression, base: AggregateExpression1) private val calcType = expr.dataType match { - case DecimalType.Fixed(_, _) => - DecimalType.Unlimited + case DecimalType.Fixed(precision, scale) => + DecimalType.bounded(precision + 10, scale) case _ => expr.dataType } @@ -608,9 +605,8 @@ case class SumDistinct(child: Expression) extends UnaryExpression with PartialAg override def nullable: Boolean = true override def dataType: DataType = child.dataType match { case DecimalType.Fixed(precision, scale) => - DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive - case DecimalType.Unlimited => - DecimalType.Unlimited + // Add 10 digits left of decimal point, like Hive + DecimalType.bounded(precision + 10, scale) case _ => child.dataType } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 05b5ad88fe..7c254a8750 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -88,6 +88,8 @@ abstract class BinaryArithmetic extends BinaryOperator { override def dataType: DataType = left.dataType + override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess + /** Name of the function for this expression on a [[Decimal]] type. 
*/ def decimalMethod: String = sys.error("BinaryArithmetics must override either decimalMethod or genCode") @@ -114,9 +116,6 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "+" - override lazy val resolved = - childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - private lazy val numeric = TypeUtils.getNumeric(dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = { @@ -146,9 +145,6 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti override def symbol: String = "-" - override lazy val resolved = - childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - private lazy val numeric = TypeUtils.getNumeric(dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = { @@ -179,9 +175,6 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti override def symbol: String = "*" override def decimalMethod: String = "$times" - override lazy val resolved = - childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - private lazy val numeric = TypeUtils.getNumeric(dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2) @@ -195,9 +188,6 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic override def decimalMethod: String = "$div" override def nullable: Boolean = true - override lazy val resolved = - childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - private lazy val div: (Any, Any) => Any = dataType match { case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div case it: IntegralType => it.integral.asInstanceOf[Integral[Any]].quot @@ -260,9 +250,6 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet override def decimalMethod: String = "remainder" override def nullable: Boolean = true - override lazy val resolved = - childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - private lazy val integral = dataType match { case i: IntegralType => i.integral.asInstanceOf[Integral[Any]] case i: FractionalType => i.asIntegral.asInstanceOf[Integral[Any]] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index f25ac32679..85060b7893 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -36,9 +36,9 @@ object Literal { case s: Short => Literal(s, ShortType) case s: String => Literal(UTF8String.fromString(s), StringType) case b: Boolean => Literal(b, BooleanType) - case d: BigDecimal => Literal(Decimal(d), DecimalType.Unlimited) - case d: java.math.BigDecimal => Literal(Decimal(d), DecimalType.Unlimited) - case d: Decimal => Literal(d, DecimalType.Unlimited) + case d: BigDecimal => Literal(Decimal(d), DecimalType(d.precision, d.scale)) + case d: java.math.BigDecimal => Literal(Decimal(d), DecimalType(d.precision(), d.scale())) + case d: Decimal => Literal(d, DecimalType(d.precision, d.scale)) case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType) case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType) case a: Array[Byte] => Literal(a, BinaryType) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index d06a7a2add..c610f70d38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.catalyst.plans -import org.apache.spark.sql.catalyst.expressions.{VirtualColumn, Attribute, AttributeSet, Expression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, VirtualColumn} import org.apache.spark.sql.catalyst.trees.TreeNode -import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType} +import org.apache.spark.sql.types.{DataType, StructType} abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanType] { self: PlanType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index e98fd2583b..591fb26e67 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -106,7 +106,7 @@ object DataType { private def nameToType(name: String): DataType = { val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\d+)\s*\)""".r name match { - case "decimal" => DecimalType.Unlimited + case "decimal" => DecimalType.USER_DEFAULT case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt) case other => nonDecimalNameToType(other) } @@ -177,7 +177,7 @@ object DataType { | "BinaryType" ^^^ BinaryType | "BooleanType" ^^^ BooleanType | "DateType" ^^^ DateType - | "DecimalType()" ^^^ DecimalType.Unlimited + | "DecimalType()" ^^^ DecimalType.USER_DEFAULT | fixedDecimalType | "TimestampType" ^^^ TimestampType ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala index 6b43224feb..6e081ea923 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala @@ -48,7 +48,7 @@ private[sql] trait DataTypeParser extends StandardTokenParsers { "(?i)binary".r ^^^ BinaryType | "(?i)boolean".r ^^^ BooleanType | fixedDecimalType | - "(?i)decimal".r ^^^ DecimalType.Unlimited | + "(?i)decimal".r ^^^ DecimalType.USER_DEFAULT | "(?i)date".r ^^^ DateType | "(?i)timestamp".r ^^^ TimestampType | varchar diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 377c75f6e8..26b24616d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -26,25 +26,46 @@ import org.apache.spark.sql.catalyst.expressions.Expression /** Precision parameters for a Decimal */ +@deprecated("Use DecimalType(precision, scale) directly", "1.5") case class PrecisionInfo(precision: Int, scale: Int) { if (scale > precision) { throw new AnalysisException( s"Decimal scale ($scale) cannot be greater than precision ($precision).") } + if (precision > DecimalType.MAX_PRECISION) { + throw new AnalysisException( + s"DecimalType can only support precision up to 38" + ) + } } /** * :: DeveloperApi :: * The data type 
representing `java.math.BigDecimal` values. - * A Decimal that might have fixed precision and scale, or unlimited values for these. + * A Decimal that must have fixed precision (the maximum number of digits) and scale (the number + * of digits to the right of the decimal point). + * + * The precision can be up to 38, and the scale can also be up to 38 (less than or equal to precision). + * + * The default precision and scale is (10, 0). * * Please use [[DataTypes.createDecimalType()]] to create a specific instance. */ @DeveloperApi -case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalType { +case class DecimalType(precision: Int, scale: Int) extends FractionalType { - /** No-arg constructor for kryo. */ - protected def this() = this(null) + // auxiliary constructors for Java + def this(precision: Int) = this(precision, 0) + def this() = this(10) + + @deprecated("Use DecimalType(precision, scale) instead", "1.5") + def this(precisionInfo: Option[PrecisionInfo]) { + this(precisionInfo.getOrElse(PrecisionInfo(10, 0)).precision, + precisionInfo.getOrElse(PrecisionInfo(10, 0)).scale) + } + + @deprecated("Use DecimalType.precision and DecimalType.scale instead", "1.5") + val precisionInfo = Some(PrecisionInfo(precision, scale)) private[sql] type InternalType = Decimal @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } @@ -53,18 +74,16 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT private[sql] val ordering = Decimal.DecimalIsFractional private[sql] val asIntegral = Decimal.DecimalAsIfIntegral - def precision: Int = precisionInfo.map(_.precision).getOrElse(-1) + override def typeName: String = s"decimal($precision,$scale)" - def scale: Int = precisionInfo.map(_.scale).getOrElse(-1) + override def toString: String = s"DecimalType($precision,$scale)" - override def typeName: String = precisionInfo match { - case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)" - case None => "decimal" - } - - override def toString: String = precisionInfo match { - case Some(PrecisionInfo(precision, scale)) => s"DecimalType($precision,$scale)" - case None => "DecimalType()" + private[sql] def isWiderThan(other: DataType): Boolean = other match { + case dt: DecimalType => + (precision - scale) >= (dt.precision - dt.scale) && scale >= dt.scale + case dt: IntegralType => + isWiderThan(DecimalType.forType(dt)) + case _ => false } /** @@ -72,10 +91,7 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT */ override def defaultSize: Int = 4096 - override def simpleString: String = precisionInfo match { - case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)" - case None => "decimal(10,0)" - } + override def simpleString: String = s"decimal($precision,$scale)" private[spark] override def asNullable: DecimalType = this } @@ -83,8 +99,47 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT /** Extra factory methods and pattern matchers for Decimals */ object DecimalType extends AbstractDataType { + import scala.math.min - override private[sql] def defaultConcreteType: DataType = Unlimited + val MAX_PRECISION = 38 + val MAX_SCALE = 38 + val SYSTEM_DEFAULT: DecimalType = DecimalType(MAX_PRECISION, 18) + val USER_DEFAULT: DecimalType = DecimalType(10, 0) + + @deprecated("Does not support unlimited precision, please specify the precision and scale", "1.5") + val Unlimited: DecimalType = SYSTEM_DEFAULT + + // The decimal types compatible with 
other numeric types + private[sql] val ByteDecimal = DecimalType(3, 0) + private[sql] val ShortDecimal = DecimalType(5, 0) + private[sql] val IntDecimal = DecimalType(10, 0) + private[sql] val LongDecimal = DecimalType(20, 0) + private[sql] val FloatDecimal = DecimalType(14, 7) + private[sql] val DoubleDecimal = DecimalType(30, 15) + + private[sql] def forType(dataType: DataType): DecimalType = dataType match { + case ByteType => ByteDecimal + case ShortType => ShortDecimal + case IntegerType => IntDecimal + case LongType => LongDecimal + case FloatType => FloatDecimal + case DoubleType => DoubleDecimal + } + + @deprecated("please specify precision and scale", "1.5") + def apply(): DecimalType = USER_DEFAULT + + @deprecated("Use DecimalType(precision, scale) instead", "1.5") + def apply(precisionInfo: Option[PrecisionInfo]): DecimalType = { + DecimalType(precisionInfo.getOrElse(PrecisionInfo(10, 0)).precision, + precisionInfo.getOrElse(PrecisionInfo(10, 0)).scale) + } + + private[sql] def bounded(precision: Int, scale: Int): DecimalType = { + DecimalType(min(precision, MAX_PRECISION), min(scale, MAX_SCALE)) + } + + override private[sql] def defaultConcreteType: DataType = SYSTEM_DEFAULT override private[sql] def acceptsType(other: DataType): Boolean = { other.isInstanceOf[DecimalType] @@ -92,31 +147,18 @@ object DecimalType extends AbstractDataType { override private[sql] def simpleString: String = "decimal" - val Unlimited: DecimalType = DecimalType(None) - private[sql] object Fixed { - def unapply(t: DecimalType): Option[(Int, Int)] = - t.precisionInfo.map(p => (p.precision, p.scale)) + def unapply(t: DecimalType): Option[(Int, Int)] = Some((t.precision, t.scale)) } private[sql] object Expression { def unapply(e: Expression): Option[(Int, Int)] = e.dataType match { - case t: DecimalType => t.precisionInfo.map(p => (p.precision, p.scale)) + case t: DecimalType => Some((t.precision, t.scale)) case _ => None } } - def apply(): DecimalType = Unlimited - - def apply(precision: Int, scale: Int): DecimalType = - DecimalType(Some(PrecisionInfo(precision, scale))) - def unapply(t: DataType): Boolean = t.isInstanceOf[DecimalType] def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[DecimalType] - - def isFixed(dataType: DataType): Boolean = dataType match { - case DecimalType.Fixed(_, _) => true - case _ => false - } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index 13aad467fa..b9f2ad7ec0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -94,8 +94,8 @@ object RandomDataGenerator { case BooleanType => Some(() => rand.nextBoolean()) case DateType => Some(() => new java.sql.Date(rand.nextInt())) case TimestampType => Some(() => new java.sql.Timestamp(rand.nextLong())) - case DecimalType.Unlimited => Some( - () => BigDecimal.apply(rand.nextLong, rand.nextInt, MathContext.UNLIMITED)) + case DecimalType.Fixed(precision, scale) => Some( + () => BigDecimal.apply(rand.nextLong, rand.nextInt, new MathContext(precision))) case DoubleType => randomNumeric[Double]( rand, r => longBitsToDouble(r.nextLong()), Seq(Double.MinValue, Double.MinPositiveValue, Double.MaxValue, Double.PositiveInfinity, Double.NegativeInfinity, Double.NaN, 0.0)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala index dbba93dba6..677ba0a180 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala @@ -50,9 +50,7 @@ class RandomDataGeneratorSuite extends SparkFunSuite { for ( dataType <- DataTypeTestUtils.atomicTypes; nullable <- Seq(true, false) - if !dataType.isInstanceOf[DecimalType] || - dataType.asInstanceOf[DecimalType].precisionInfo.isEmpty - ) { + if !dataType.isInstanceOf[DecimalType]) { test(s"$dataType (nullable=$nullable)") { testRandomDataGeneration(dataType) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index b4b00f5584..3b848cfdf7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -102,7 +102,7 @@ class ScalaReflectionSuite extends SparkFunSuite { StructField("byteField", ByteType, nullable = true), StructField("booleanField", BooleanType, nullable = true), StructField("stringField", StringType, nullable = true), - StructField("decimalField", DecimalType.Unlimited, nullable = true), + StructField("decimalField", DecimalType.SYSTEM_DEFAULT, nullable = true), StructField("dateField", DateType, nullable = true), StructField("timestampField", TimestampType, nullable = true), StructField("binaryField", BinaryType, nullable = true))), @@ -216,7 +216,7 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(DoubleType === typeOfObject(1.7976931348623157E308)) // DecimalType - assert(DecimalType.Unlimited === + assert(DecimalType.SYSTEM_DEFAULT === typeOfObject(new java.math.BigDecimal("1.7976931348623157E318"))) // DateType @@ -229,19 +229,19 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(NullType === typeOfObject(null)) def typeOfObject1: PartialFunction[Any, DataType] = typeOfObject orElse { - case value: java.math.BigInteger => DecimalType.Unlimited - case value: java.math.BigDecimal => DecimalType.Unlimited + case value: java.math.BigInteger => DecimalType.SYSTEM_DEFAULT + case value: java.math.BigDecimal => DecimalType.SYSTEM_DEFAULT case _ => StringType } - assert(DecimalType.Unlimited === typeOfObject1( + assert(DecimalType.SYSTEM_DEFAULT === typeOfObject1( new BigInteger("92233720368547758070"))) - assert(DecimalType.Unlimited === typeOfObject1( + assert(DecimalType.SYSTEM_DEFAULT === typeOfObject1( new java.math.BigDecimal("1.7976931348623157E318"))) assert(StringType === typeOfObject1(BigInt("92233720368547758070"))) def typeOfObject2: PartialFunction[Any, DataType] = typeOfObject orElse { - case value: java.math.BigInteger => DecimalType.Unlimited + case value: java.math.BigInteger => DecimalType.SYSTEM_DEFAULT } intercept[MatchError](typeOfObject2(BigInt("92233720368547758070"))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 58df1de983..7e67427237 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -55,7 +55,7 @@ object AnalysisSuite { AttributeReference("a", StringType)(), AttributeReference("b", 
StringType)(), AttributeReference("c", DoubleType)(), - AttributeReference("d", DecimalType.Unlimited)(), + AttributeReference("d", DecimalType.SYSTEM_DEFAULT)(), AttributeReference("e", ShortType)()) val nestedRelation = LocalRelation( @@ -158,7 +158,7 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { AttributeReference("a", StringType)(), AttributeReference("b", StringType)(), AttributeReference("c", DoubleType)(), - AttributeReference("d", DecimalType.Unlimited)(), + AttributeReference("d", DecimalType(10, 2))(), AttributeReference("e", ShortType)()) val plan = caseInsensitiveAnalyzer.execute( @@ -173,7 +173,7 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { assert(pl(0).dataType == DoubleType) assert(pl(1).dataType == DoubleType) assert(pl(2).dataType == DoubleType) - assert(pl(3).dataType == DecimalType.Unlimited) + assert(pl(3).dataType == DoubleType) // StringType will be promoted into Double assert(pl(4).dataType == DoubleType) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 7bac97b789..f9f15e7a66 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -34,7 +34,7 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { AttributeReference("i", IntegerType)(), AttributeReference("d1", DecimalType(2, 1))(), AttributeReference("d2", DecimalType(5, 2))(), - AttributeReference("u", DecimalType.Unlimited)(), + AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(), AttributeReference("f", FloatType)(), AttributeReference("b", DoubleType)() ) @@ -92,11 +92,11 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { } test("Comparison operations") { - checkComparison(EqualTo(i, d1), DecimalType(10, 1)) + checkComparison(EqualTo(i, d1), DecimalType(11, 1)) checkComparison(EqualNullSafe(d2, d1), DecimalType(5, 2)) - checkComparison(LessThan(i, d1), DecimalType(10, 1)) + checkComparison(LessThan(i, d1), DecimalType(11, 1)) checkComparison(LessThanOrEqual(d1, d2), DecimalType(5, 2)) - checkComparison(GreaterThan(d2, u), DecimalType.Unlimited) + checkComparison(GreaterThan(d2, u), DecimalType.SYSTEM_DEFAULT) checkComparison(GreaterThanOrEqual(d1, f), DoubleType) checkComparison(GreaterThan(d2, d2), DecimalType(5, 2)) } @@ -106,12 +106,12 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { checkUnion(i, d2, DecimalType(12, 2)) checkUnion(d1, d2, DecimalType(5, 2)) checkUnion(d2, d1, DecimalType(5, 2)) - checkUnion(d1, f, DecimalType(8, 7)) - checkUnion(f, d2, DecimalType(10, 7)) - checkUnion(d1, b, DecimalType(16, 15)) - checkUnion(b, d2, DecimalType(18, 15)) - checkUnion(d1, u, DecimalType.Unlimited) - checkUnion(u, d2, DecimalType.Unlimited) + checkUnion(d1, f, DoubleType) + checkUnion(f, d2, DoubleType) + checkUnion(d1, b, DoubleType) + checkUnion(b, d2, DoubleType) + checkUnion(d1, u, DecimalType.SYSTEM_DEFAULT) + checkUnion(u, d2, DecimalType.SYSTEM_DEFAULT) } test("bringing in primitive types") { @@ -125,13 +125,33 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { checkType(Add(d1, Cast(i, DoubleType)), DoubleType) } - test("unlimited decimals make everything else cast up") { - for (expr <- Seq(d1, d2, i, f, u)) { - checkType(Add(expr, u), 
DecimalType.Unlimited) - checkType(Subtract(expr, u), DecimalType.Unlimited) - checkType(Multiply(expr, u), DecimalType.Unlimited) - checkType(Divide(expr, u), DecimalType.Unlimited) - checkType(Remainder(expr, u), DecimalType.Unlimited) + test("maximum decimals") { + for (expr <- Seq(d1, d2, i, u)) { + checkType(Add(expr, u), DecimalType.SYSTEM_DEFAULT) + checkType(Subtract(expr, u), DecimalType.SYSTEM_DEFAULT) + } + + checkType(Multiply(d1, u), DecimalType(38, 19)) + checkType(Multiply(d2, u), DecimalType(38, 20)) + checkType(Multiply(i, u), DecimalType(38, 18)) + checkType(Multiply(u, u), DecimalType(38, 36)) + + checkType(Divide(u, d1), DecimalType(38, 21)) + checkType(Divide(u, d2), DecimalType(38, 24)) + checkType(Divide(u, i), DecimalType(38, 29)) + checkType(Divide(u, u), DecimalType(38, 38)) + + checkType(Remainder(d1, u), DecimalType(19, 18)) + checkType(Remainder(d2, u), DecimalType(21, 18)) + checkType(Remainder(i, u), DecimalType(28, 18)) + checkType(Remainder(u, u), DecimalType.SYSTEM_DEFAULT) + + for (expr <- Seq(f, b)) { + checkType(Add(expr, u), DoubleType) + checkType(Subtract(expr, u), DoubleType) + checkType(Multiply(expr, u), DoubleType) + checkType(Divide(expr, u), DoubleType) + checkType(Remainder(expr, u), DoubleType) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 835220c563..d0fb95b580 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -35,14 +35,14 @@ class HiveTypeCoercionSuite extends PlanTest { shouldCast(NullType, NullType, NullType) shouldCast(NullType, IntegerType, IntegerType) - shouldCast(NullType, DecimalType, DecimalType.Unlimited) + shouldCast(NullType, DecimalType, DecimalType.SYSTEM_DEFAULT) shouldCast(ByteType, IntegerType, IntegerType) shouldCast(IntegerType, IntegerType, IntegerType) shouldCast(IntegerType, LongType, LongType) - shouldCast(IntegerType, DecimalType, DecimalType.Unlimited) + shouldCast(IntegerType, DecimalType, DecimalType(10, 0)) shouldCast(LongType, IntegerType, IntegerType) - shouldCast(LongType, DecimalType, DecimalType.Unlimited) + shouldCast(LongType, DecimalType, DecimalType(20, 0)) shouldCast(DateType, TimestampType, TimestampType) shouldCast(TimestampType, DateType, DateType) @@ -71,8 +71,8 @@ class HiveTypeCoercionSuite extends PlanTest { shouldCast(IntegerType, TypeCollection(StringType, BinaryType), StringType) shouldCast(IntegerType, TypeCollection(BinaryType, StringType), StringType) - shouldCast( - DecimalType.Unlimited, TypeCollection(IntegerType, DecimalType), DecimalType.Unlimited) + shouldCast(DecimalType.SYSTEM_DEFAULT, + TypeCollection(IntegerType, DecimalType), DecimalType.SYSTEM_DEFAULT) shouldCast(DecimalType(10, 2), TypeCollection(IntegerType, DecimalType), DecimalType(10, 2)) shouldCast(DecimalType(10, 2), TypeCollection(DecimalType, IntegerType), DecimalType(10, 2)) shouldCast(IntegerType, TypeCollection(DecimalType(10, 2), StringType), DecimalType(10, 2)) @@ -82,7 +82,7 @@ class HiveTypeCoercionSuite extends PlanTest { // NumericType should not be changed when function accepts any of them. 
Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, - DecimalType.Unlimited, DecimalType(10, 2)).foreach { tpe => + DecimalType.SYSTEM_DEFAULT, DecimalType(10, 2)).foreach { tpe => shouldCast(tpe, NumericType, tpe) } @@ -107,8 +107,8 @@ class HiveTypeCoercionSuite extends PlanTest { shouldNotCast(IntegerType, TimestampType) shouldNotCast(LongType, DateType) shouldNotCast(LongType, TimestampType) - shouldNotCast(DecimalType.Unlimited, DateType) - shouldNotCast(DecimalType.Unlimited, TimestampType) + shouldNotCast(DecimalType.SYSTEM_DEFAULT, DateType) + shouldNotCast(DecimalType.SYSTEM_DEFAULT, TimestampType) shouldNotCast(IntegerType, TypeCollection(DateType, TimestampType)) @@ -160,14 +160,6 @@ class HiveTypeCoercionSuite extends PlanTest { widenTest(LongType, FloatType, Some(FloatType)) widenTest(LongType, DoubleType, Some(DoubleType)) - // Casting up to unlimited-precision decimal - widenTest(IntegerType, DecimalType.Unlimited, Some(DecimalType.Unlimited)) - widenTest(DoubleType, DecimalType.Unlimited, Some(DecimalType.Unlimited)) - widenTest(DecimalType(3, 2), DecimalType.Unlimited, Some(DecimalType.Unlimited)) - widenTest(DecimalType.Unlimited, IntegerType, Some(DecimalType.Unlimited)) - widenTest(DecimalType.Unlimited, DoubleType, Some(DecimalType.Unlimited)) - widenTest(DecimalType.Unlimited, DecimalType(3, 2), Some(DecimalType.Unlimited)) - // No up-casting for fixed-precision decimal (this is handled by arithmetic rules) widenTest(DecimalType(2, 1), DecimalType(3, 2), None) widenTest(DecimalType(2, 1), DoubleType, None) @@ -242,9 +234,9 @@ class HiveTypeCoercionSuite extends PlanTest { :: Literal(1) :: Literal(new java.math.BigDecimal("1000000000000000000000")) :: Nil), - Coalesce(Cast(Literal(1L), DecimalType()) - :: Cast(Literal(1), DecimalType()) - :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType()) + Coalesce(Cast(Literal(1L), DecimalType(22, 0)) + :: Cast(Literal(1), DecimalType(22, 0)) + :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0)) :: Nil)) } @@ -323,7 +315,7 @@ class HiveTypeCoercionSuite extends PlanTest { val left = LocalRelation( AttributeReference("i", IntegerType)(), - AttributeReference("u", DecimalType.Unlimited)(), + AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(), AttributeReference("b", ByteType)(), AttributeReference("d", DoubleType)()) val right = LocalRelation( @@ -333,7 +325,7 @@ class HiveTypeCoercionSuite extends PlanTest { AttributeReference("l", LongType)()) val wt = HiveTypeCoercion.WidenTypes - val expectedTypes = Seq(StringType, DecimalType.Unlimited, FloatType, DoubleType) + val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType) val r1 = wt(Union(left, right)).asInstanceOf[Union] val r2 = wt(Except(left, right)).asInstanceOf[Except] @@ -353,13 +345,13 @@ class HiveTypeCoercionSuite extends PlanTest { } } - val dp = HiveTypeCoercion.DecimalPrecision + val dp = HiveTypeCoercion.WidenTypes val left1 = LocalRelation( AttributeReference("l", DecimalType(10, 8))()) val right1 = LocalRelation( AttributeReference("r", DecimalType(5, 5))()) - val expectedType1 = Seq(DecimalType(math.max(8, 5) + math.max(10 - 8, 5 - 5), math.max(8, 5))) + val expectedType1 = Seq(DecimalType(10, 8)) val r1 = dp(Union(left1, right1)).asInstanceOf[Union] val r2 = dp(Except(left1, right1)).asInstanceOf[Except] @@ -372,12 +364,11 @@ class HiveTypeCoercionSuite extends PlanTest { checkOutput(r3.left, expectedType1) checkOutput(r3.right, expectedType1) - 
val plan1 = LocalRelation( - AttributeReference("l", DecimalType(10, 10))()) + val plan1 = LocalRelation(AttributeReference("l", DecimalType(10, 5))()) val rightTypes = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType) - val expectedTypes = Seq(DecimalType(3, 0), DecimalType(5, 0), DecimalType(10, 0), - DecimalType(20, 0), DecimalType(7, 7), DecimalType(15, 15)) + val expectedTypes = Seq(DecimalType(10, 5), DecimalType(10, 5), DecimalType(15, 5), + DecimalType(25, 5), DoubleType, DoubleType) rightTypes.zip(expectedTypes).map { case (rType, expectedType) => val plan2 = LocalRelation( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index ccf448eee0..facf65c155 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -185,7 +185,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkCast(1, 1.0) checkCast(123, "123") - checkEvaluation(cast(123, DecimalType.Unlimited), Decimal(123)) + checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123)) checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123)) checkEvaluation(cast(123, DecimalType(3, 1)), null) checkEvaluation(cast(123, DecimalType(2, 0)), null) @@ -203,7 +203,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkCast(1L, 1.0) checkCast(123L, "123") - checkEvaluation(cast(123L, DecimalType.Unlimited), Decimal(123)) + checkEvaluation(cast(123L, DecimalType.USER_DEFAULT), Decimal(123)) checkEvaluation(cast(123L, DecimalType(3, 0)), Decimal(123)) checkEvaluation(cast(123L, DecimalType(3, 1)), Decimal(123.0)) @@ -225,7 +225,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(cast(1000, TimestampType), LongType), 1.toLong) checkEvaluation(cast(cast(-1200, TimestampType), LongType), -2.toLong) - checkEvaluation(cast(123, DecimalType.Unlimited), Decimal(123)) + checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123)) checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123)) checkEvaluation(cast(123, DecimalType(3, 1)), null) checkEvaluation(cast(123, DecimalType(2, 0)), null) @@ -267,7 +267,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { assert(cast("abcdef", IntegerType).nullable === true) assert(cast("abcdef", ShortType).nullable === true) assert(cast("abcdef", ByteType).nullable === true) - assert(cast("abcdef", DecimalType.Unlimited).nullable === true) + assert(cast("abcdef", DecimalType.USER_DEFAULT).nullable === true) assert(cast("abcdef", DecimalType(4, 2)).nullable === true) assert(cast("abcdef", DoubleType).nullable === true) assert(cast("abcdef", FloatType).nullable === true) @@ -291,9 +291,9 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { c.getTimeInMillis * 1000) checkEvaluation(cast("abdef", StringType), "abdef") - checkEvaluation(cast("abdef", DecimalType.Unlimited), null) + checkEvaluation(cast("abdef", DecimalType.USER_DEFAULT), null) checkEvaluation(cast("abdef", TimestampType), null) - checkEvaluation(cast("12.65", DecimalType.Unlimited), Decimal(12.65)) + checkEvaluation(cast("12.65", DecimalType.SYSTEM_DEFAULT), Decimal(12.65)) checkEvaluation(cast(cast(sd, DateType), StringType), sd) checkEvaluation(cast(cast(d, StringType), DateType), 0) @@ -311,20 +311,20 @@ class CastSuite extends 
SparkFunSuite with ExpressionEvalHelper { 5.toLong) checkEvaluation( cast(cast(cast(cast(cast(cast("5", ByteType), TimestampType), - DecimalType.Unlimited), LongType), StringType), ShortType), + DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType), 0.toShort) checkEvaluation( cast(cast(cast(cast(cast(cast("5", TimestampType), ByteType), - DecimalType.Unlimited), LongType), StringType), ShortType), + DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType), null) - checkEvaluation(cast(cast(cast(cast(cast(cast("5", DecimalType.Unlimited), + checkEvaluation(cast(cast(cast(cast(cast(cast("5", DecimalType.SYSTEM_DEFAULT), ByteType), TimestampType), LongType), StringType), ShortType), 0.toShort) checkEvaluation(cast("23", DoubleType), 23d) checkEvaluation(cast("23", IntegerType), 23) checkEvaluation(cast("23", FloatType), 23f) - checkEvaluation(cast("23", DecimalType.Unlimited), Decimal(23)) + checkEvaluation(cast("23", DecimalType.USER_DEFAULT), Decimal(23)) checkEvaluation(cast("23", ByteType), 23.toByte) checkEvaluation(cast("23", ShortType), 23.toShort) checkEvaluation(cast("2012-12-11", DoubleType), null) @@ -338,7 +338,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Add(Literal(23d), cast(true, DoubleType)), 24d) checkEvaluation(Add(Literal(23), cast(true, IntegerType)), 24) checkEvaluation(Add(Literal(23f), cast(true, FloatType)), 24f) - checkEvaluation(Add(Literal(Decimal(23)), cast(true, DecimalType.Unlimited)), Decimal(24)) + checkEvaluation(Add(Literal(Decimal(23)), cast(true, DecimalType.USER_DEFAULT)), Decimal(24)) checkEvaluation(Add(Literal(23.toByte), cast(true, ByteType)), 24.toByte) checkEvaluation(Add(Literal(23.toShort), cast(true, ShortType)), 24.toShort) } @@ -362,10 +362,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { // - Values that would overflow the target precision should turn into null // - Because of this, casts to fixed-precision decimals should be nullable - assert(cast(123, DecimalType.Unlimited).nullable === false) - assert(cast(10.03f, DecimalType.Unlimited).nullable === true) - assert(cast(10.03, DecimalType.Unlimited).nullable === true) - assert(cast(Decimal(10.03), DecimalType.Unlimited).nullable === false) + assert(cast(123, DecimalType.USER_DEFAULT).nullable === true) + assert(cast(10.03f, DecimalType.SYSTEM_DEFAULT).nullable === true) + assert(cast(10.03, DecimalType.SYSTEM_DEFAULT).nullable === true) + assert(cast(Decimal(10.03), DecimalType.SYSTEM_DEFAULT).nullable === true) assert(cast(123, DecimalType(2, 1)).nullable === true) assert(cast(10.03f, DecimalType(2, 1)).nullable === true) @@ -373,7 +373,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { assert(cast(Decimal(10.03), DecimalType(2, 1)).nullable === true) - checkEvaluation(cast(10.03, DecimalType.Unlimited), Decimal(10.03)) + checkEvaluation(cast(10.03, DecimalType.SYSTEM_DEFAULT), Decimal(10.03)) checkEvaluation(cast(10.03, DecimalType(4, 2)), Decimal(10.03)) checkEvaluation(cast(10.03, DecimalType(3, 1)), Decimal(10.0)) checkEvaluation(cast(10.03, DecimalType(2, 0)), Decimal(10)) @@ -383,7 +383,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(Decimal(10.03), DecimalType(3, 1)), Decimal(10.0)) checkEvaluation(cast(Decimal(10.03), DecimalType(3, 2)), null) - checkEvaluation(cast(10.05, DecimalType.Unlimited), Decimal(10.05)) + checkEvaluation(cast(10.05, DecimalType.SYSTEM_DEFAULT), Decimal(10.05)) checkEvaluation(cast(10.05, DecimalType(4, 2)), 
Decimal(10.05)) checkEvaluation(cast(10.05, DecimalType(3, 1)), Decimal(10.1)) checkEvaluation(cast(10.05, DecimalType(2, 0)), Decimal(10)) @@ -409,10 +409,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(Decimal(-9.95), DecimalType(3, 1)), Decimal(-10.0)) checkEvaluation(cast(Decimal(-9.95), DecimalType(1, 0)), null) - checkEvaluation(cast(Double.NaN, DecimalType.Unlimited), null) - checkEvaluation(cast(1.0 / 0.0, DecimalType.Unlimited), null) - checkEvaluation(cast(Float.NaN, DecimalType.Unlimited), null) - checkEvaluation(cast(1.0f / 0.0f, DecimalType.Unlimited), null) + checkEvaluation(cast(Double.NaN, DecimalType.SYSTEM_DEFAULT), null) + checkEvaluation(cast(1.0 / 0.0, DecimalType.SYSTEM_DEFAULT), null) + checkEvaluation(cast(Float.NaN, DecimalType.SYSTEM_DEFAULT), null) + checkEvaluation(cast(1.0f / 0.0f, DecimalType.SYSTEM_DEFAULT), null) checkEvaluation(cast(Double.NaN, DecimalType(2, 1)), null) checkEvaluation(cast(1.0 / 0.0, DecimalType(2, 1)), null) @@ -427,7 +427,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(d, LongType), null) checkEvaluation(cast(d, FloatType), null) checkEvaluation(cast(d, DoubleType), null) - checkEvaluation(cast(d, DecimalType.Unlimited), null) + checkEvaluation(cast(d, DecimalType.SYSTEM_DEFAULT), null) checkEvaluation(cast(d, DecimalType(10, 2)), null) checkEvaluation(cast(d, StringType), "1970-01-01") checkEvaluation(cast(cast(d, TimestampType), StringType), "1970-01-01 00:00:00") @@ -454,7 +454,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { cast(cast(millis.toDouble / 1000, TimestampType), DoubleType), millis.toDouble / 1000) checkEvaluation( - cast(cast(Decimal(1), TimestampType), DecimalType.Unlimited), + cast(cast(Decimal(1), TimestampType), DecimalType.SYSTEM_DEFAULT), Decimal(1)) // A test for higher precision than millis diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index afa143bd5f..b31d6661c8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -60,7 +60,7 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper testIf(_.toFloat, FloatType) testIf(_.toDouble, DoubleType) - testIf(Decimal(_), DecimalType.Unlimited) + testIf(Decimal(_), DecimalType.USER_DEFAULT) testIf(identity, DateType) testIf(_.toLong, TimestampType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala index d924ff7a10..f6404d2161 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -33,7 +33,7 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal.create(null, LongType), null) checkEvaluation(Literal.create(null, StringType), null) checkEvaluation(Literal.create(null, BinaryType), null) - checkEvaluation(Literal.create(null, DecimalType()), null) + checkEvaluation(Literal.create(null, DecimalType.USER_DEFAULT), null) 
checkEvaluation(Literal.create(null, ArrayType(ByteType, true)), null) checkEvaluation(Literal.create(null, MapType(StringType, IntegerType)), null) checkEvaluation(Literal.create(null, StructType(Seq.empty)), null) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala index 0728f6695c..9efe44c832 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala @@ -30,7 +30,7 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { testFunc(1L, LongType) testFunc(1.0F, FloatType) testFunc(1.0, DoubleType) - testFunc(Decimal(1.5), DecimalType.Unlimited) + testFunc(Decimal(1.5), DecimalType(2, 1)) testFunc(new java.sql.Date(10), DateType) testFunc(new java.sql.Timestamp(10), TimestampType) testFunc("abcd", StringType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala index 7566cb59e3..48b7dc5745 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala @@ -121,6 +121,8 @@ class UnsafeFixedWidthAggregationMapSuite }.toSet seenKeys.size should be (groupKeys.size) seenKeys should be (groupKeys) + + map.free() } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 8819234e78..a5d9806c20 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -145,7 +145,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { DoubleType, StringType, BinaryType - // DecimalType.Unlimited, + // DecimalType.Default, // ArrayType(IntegerType) ) val converter = new UnsafeRowConverter(fieldTypes) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala index c6171b7b69..1ba290753c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala @@ -44,7 +44,7 @@ class DataTypeParserSuite extends SparkFunSuite { checkDataType("float", FloatType) checkDataType("dOUBle", DoubleType) checkDataType("decimal(10, 5)", DecimalType(10, 5)) - checkDataType("decimal", DecimalType.Unlimited) + checkDataType("decimal", DecimalType.USER_DEFAULT) checkDataType("DATE", DateType) checkDataType("timestamp", TimestampType) checkDataType("string", StringType) @@ -87,7 +87,7 @@ class DataTypeParserSuite extends SparkFunSuite { StructType( StructField("struct", StructType( - StructField("deciMal", DecimalType.Unlimited, true) :: + StructField("deciMal", DecimalType.USER_DEFAULT, true) :: StructField("anotherDecimal", DecimalType(5, 2), true) :: Nil), true) :: StructField("MAP", 
MapType(TimestampType, StringType), true) :: StructField("arrAy", ArrayType(DoubleType, true), true) :: Nil) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 14e7b4a956..88b221cd81 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -185,7 +185,7 @@ class DataTypeSuite extends SparkFunSuite { checkDataTypeJsonRepr(FloatType) checkDataTypeJsonRepr(DoubleType) checkDataTypeJsonRepr(DecimalType(10, 5)) - checkDataTypeJsonRepr(DecimalType.Unlimited) + checkDataTypeJsonRepr(DecimalType.SYSTEM_DEFAULT) checkDataTypeJsonRepr(DateType) checkDataTypeJsonRepr(TimestampType) checkDataTypeJsonRepr(StringType) @@ -219,7 +219,7 @@ class DataTypeSuite extends SparkFunSuite { checkDefaultSize(FloatType, 4) checkDefaultSize(DoubleType, 8) checkDefaultSize(DecimalType(10, 5), 4096) - checkDefaultSize(DecimalType.Unlimited, 4096) + checkDefaultSize(DecimalType.SYSTEM_DEFAULT, 4096) checkDefaultSize(DateType, 4) checkDefaultSize(TimestampType, 8) checkDefaultSize(StringType, 4096) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala index 32632b5d6e..0ee9ddac81 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala @@ -34,7 +34,7 @@ object DataTypeTestUtils { * decimal types. */ val fractionalTypes: Set[FractionalType] = Set( - DecimalType(precisionInfo = None), + DecimalType.SYSTEM_DEFAULT, DecimalType(2, 1), DoubleType, FloatType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 6e2a6525bf..b25dcbca82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -996,7 +996,7 @@ class ColumnName(name: String) extends Column(name) { * Creates a new [[StructField]] of type decimal. * @since 1.3.0 */ - def decimal: StructField = StructField(name, DecimalType.Unlimited) + def decimal: StructField = StructField(name, DecimalType.USER_DEFAULT) /** * Creates a new [[StructField]] of type decimal. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index fc72360c88..9d8415f063 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -375,7 +375,7 @@ private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 8) { private[sql] case class FIXED_DECIMAL(precision: Int, scale: Int) extends NativeColumnType( - DecimalType(Some(PrecisionInfo(precision, scale))), + DecimalType(precision, scale), 10, FIXED_DECIMAL.defaultSize) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 16176abe3a..5ed158b3d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -21,9 +21,9 @@ import org.apache.spark.TaskContext import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.trees._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.trees._ import org.apache.spark.sql.types._ case class AggregateEvaluation( @@ -92,8 +92,8 @@ case class GeneratedAggregate( case s @ Sum(expr) => val calcType = expr.dataType match { - case DecimalType.Fixed(_, _) => - DecimalType.Unlimited + case DecimalType.Fixed(p, s) => + DecimalType.bounded(p + 10, s) case _ => expr.dataType } @@ -121,8 +121,8 @@ case class GeneratedAggregate( case cs @ CombineSum(expr) => val calcType = expr.dataType match { - case DecimalType.Fixed(_, _) => - DecimalType.Unlimited + case DecimalType.Fixed(p, s) => + DecimalType.bounded(p + 10, s) case _ => expr.dataType } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 6b4a359db2..9d0fa894b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -25,6 +25,7 @@ import scala.util.Try import org.apache.hadoop.fs.Path import org.apache.hadoop.util.Shell + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.types._ @@ -236,7 +237,7 @@ private[sql] object PartitioningUtils { /** * Converts a string to a [[Literal]] with automatic type inference. Currently only supports - * [[IntegerType]], [[LongType]], [[DoubleType]], [[DecimalType.Unlimited]], and + * [[IntegerType]], [[LongType]], [[DoubleType]], [[DecimalType.SYSTEM_DEFAULT]], and * [[StringType]]. 
*/ private[sql] def inferPartitionColumnValue( @@ -249,7 +250,7 @@ private[sql] object PartitioningUtils { .orElse(Try(Literal.create(JLong.parseLong(raw), LongType))) // Then falls back to fractional types .orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType))) - .orElse(Try(Literal.create(new JBigDecimal(raw), DecimalType.Unlimited))) + .orElse(Try(Literal(new JBigDecimal(raw)))) // Then falls back to string .getOrElse { if (raw == defaultPartitionName) { @@ -268,7 +269,7 @@ private[sql] object PartitioningUtils { } private val upCastingOrder: Seq[DataType] = - Seq(NullType, IntegerType, LongType, FloatType, DoubleType, DecimalType.Unlimited, StringType) + Seq(NullType, IntegerType, LongType, FloatType, DoubleType, StringType) /** * Given a collection of [[Literal]]s, resolves possible type conflicts by up-casting "lower" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index 7a27fba178..3cf70db6b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -66,8 +66,8 @@ private[sql] object JDBCRDD extends Logging { case java.sql.Types.DATALINK => null case java.sql.Types.DATE => DateType case java.sql.Types.DECIMAL - if precision != 0 || scale != 0 => DecimalType(precision, scale) - case java.sql.Types.DECIMAL => DecimalType.Unlimited + if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale) + case java.sql.Types.DECIMAL => DecimalType.SYSTEM_DEFAULT case java.sql.Types.DISTINCT => null case java.sql.Types.DOUBLE => DoubleType case java.sql.Types.FLOAT => FloatType @@ -80,8 +80,8 @@ private[sql] object JDBCRDD extends Logging { case java.sql.Types.NCLOB => StringType case java.sql.Types.NULL => null case java.sql.Types.NUMERIC - if precision != 0 || scale != 0 => DecimalType(precision, scale) - case java.sql.Types.NUMERIC => DecimalType.Unlimited + if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale) + case java.sql.Types.NUMERIC => DecimalType.SYSTEM_DEFAULT case java.sql.Types.NVARCHAR => StringType case java.sql.Types.OTHER => null case java.sql.Types.REAL => DoubleType @@ -314,7 +314,7 @@ private[sql] class JDBCRDD( abstract class JDBCConversion case object BooleanConversion extends JDBCConversion case object DateConversion extends JDBCConversion - case class DecimalConversion(precisionInfo: Option[(Int, Int)]) extends JDBCConversion + case class DecimalConversion(precision: Int, scale: Int) extends JDBCConversion case object DoubleConversion extends JDBCConversion case object FloatConversion extends JDBCConversion case object IntegerConversion extends JDBCConversion @@ -331,8 +331,7 @@ private[sql] class JDBCRDD( schema.fields.map(sf => sf.dataType match { case BooleanType => BooleanConversion case DateType => DateConversion - case DecimalType.Unlimited => DecimalConversion(None) - case DecimalType.Fixed(d) => DecimalConversion(Some(d)) + case DecimalType.Fixed(p, s) => DecimalConversion(p, s) case DoubleType => DoubleConversion case FloatType => FloatConversion case IntegerType => IntegerConversion @@ -399,20 +398,13 @@ private[sql] class JDBCRDD( // DecimalType(12, 2). Thus, after saving the dataframe into parquet file and then // retrieve it, you will get wrong result 199.99. // So it is needed to set precision and scale for Decimal based on JDBC metadata. 
- case DecimalConversion(Some((p, s))) => + case DecimalConversion(p, s) => val decimalVal = rs.getBigDecimal(pos) if (decimalVal == null) { mutableRow.update(i, null) } else { mutableRow.update(i, Decimal(decimalVal, p, s)) } - case DecimalConversion(None) => - val decimalVal = rs.getBigDecimal(pos) - if (decimalVal == null) { - mutableRow.update(i, null) - } else { - mutableRow.update(i, Decimal(decimalVal)) - } case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos)) case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos)) case IntegerConversion => mutableRow.setInt(i, rs.getInt(pos)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala index f7ea852fe7..035e051008 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala @@ -89,8 +89,7 @@ package object jdbc { case BinaryType => stmt.setBytes(i + 1, row.getAs[Array[Byte]](i)) case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i)) case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i)) - case DecimalType.Unlimited => stmt.setBigDecimal(i + 1, - row.getAs[java.math.BigDecimal](i)) + case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i)) case _ => throw new IllegalArgumentException( s"Can't translate non-null value for field $i") } @@ -145,7 +144,7 @@ package object jdbc { case BinaryType => "BLOB" case TimestampType => "TIMESTAMP" case DateType => "DATE" - case DecimalType.Unlimited => "DECIMAL(40,20)" + case t: DecimalType => s"DECIMAL(${t.precision},${t.scale})" case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC") }) val nullable = if (field.nullable) "" else "NOT NULL" @@ -177,7 +176,7 @@ package object jdbc { case BinaryType => java.sql.Types.BLOB case TimestampType => java.sql.Types.TIMESTAMP case DateType => java.sql.Types.DATE - case DecimalType.Unlimited => java.sql.Types.DECIMAL + case t: DecimalType => java.sql.Types.DECIMAL case _ => throw new IllegalArgumentException( s"Can't translate null value for field $field") }) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala index afe2c6c11a..0eb3b04007 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala @@ -113,7 +113,7 @@ private[sql] object InferSchema { case INT | LONG => LongType // Since we do not have a data type backed by BigInteger, // when we see a Java BigInteger, we use DecimalType. - case BIG_INTEGER | BIG_DECIMAL => DecimalType.Unlimited + case BIG_INTEGER | BIG_DECIMAL => DecimalType.SYSTEM_DEFAULT case FLOAT | DOUBLE => DoubleType } @@ -168,8 +168,13 @@ private[sql] object InferSchema { HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2).getOrElse { // t1 or t2 is a StructType, ArrayType, or an unexpected type. (t1, t2) match { - case (other: DataType, NullType) => other - case (NullType, other: DataType) => other + // Double supports a larger range than fixed decimal; DecimalType.SYSTEM_DEFAULT should be + // enough in most cases, and it also has better precision. 
+ case (DoubleType, t: DecimalType) => + if (t == DecimalType.SYSTEM_DEFAULT) t else DoubleType + case (t: DecimalType, DoubleType) => + if (t == DecimalType.SYSTEM_DEFAULT) t else DoubleType + case (StructType(fields1), StructType(fields2)) => val newFields = (fields1 ++ fields2).groupBy(field => field.name).map { case (name, fieldTypes) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala index 1ea6926af6..1d3a0d15d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala @@ -439,10 +439,6 @@ private[parquet] class CatalystSchemaConverter( .length(minBytesForPrecision(precision)) .named(field.name) - case dec @ DecimalType.Unlimited if followParquetFormatSpec => - throw new AnalysisException( - s"Data type $dec is not supported. Decimal precision and scale must be specified.") - // =================================================== // ArrayType and MapType (for Spark versions <= 1.4.x) // =================================================== diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index e8851ddb68..d1040bf556 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -261,10 +261,10 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo case BinaryType => writer.addBinary( Binary.fromByteArray(value.asInstanceOf[Array[Byte]])) case d: DecimalType => - if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { + if (d.precision > 18) { sys.error(s"Unsupported datatype $d, cannot write to consumer") } - writeDecimal(value.asInstanceOf[Decimal], d.precisionInfo.get.precision) + writeDecimal(value.asInstanceOf[Decimal], d.precision) case _ => sys.error(s"Do not know how to writer $schema to consumer") } } @@ -415,10 +415,10 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { case BinaryType => writer.addBinary( Binary.fromByteArray(record(index).asInstanceOf[Array[Byte]])) case d: DecimalType => - if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { + if (d.precision > 18) { sys.error(s"Unsupported datatype $d, cannot write to consumer") } - writeDecimal(record(index).asInstanceOf[Decimal], d.precisionInfo.get.precision) + writeDecimal(record(index).asInstanceOf[Decimal], d.precision) case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer") } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java index fcb8f5499c..cb84e78d62 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java @@ -22,7 +22,6 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import org.apache.spark.sql.test.TestSQLContext$; import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -31,8 +30,14 @@ import org.junit.Test; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import 
org.apache.spark.api.java.function.Function; -import org.apache.spark.sql.*; -import org.apache.spark.sql.types.*; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.test.TestSQLContext$; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; // The test suite itself is Serializable so that anonymous Function implementations can be // serialized, as an alternative to converting these anonymous classes to static inner classes; @@ -159,7 +164,8 @@ public class JavaApplySchemaSuite implements Serializable { "\"bigInteger\":92233720368547758069, \"double\":1.7976931348623157E305, " + "\"boolean\":false, \"null\":null}")); List fields = new ArrayList(7); - fields.add(DataTypes.createStructField("bigInteger", DataTypes.createDecimalType(), true)); + fields.add(DataTypes.createStructField("bigInteger", DataTypes.createDecimalType(38, 18), + true)); fields.add(DataTypes.createStructField("boolean", DataTypes.BooleanType, true)); fields.add(DataTypes.createStructField("double", DataTypes.DoubleType, true)); fields.add(DataTypes.createStructField("integer", DataTypes.LongType, true)); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 01bc23277f..037e2048a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -148,7 +148,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { val dataTypes = Seq(StringType, BinaryType, NullType, BooleanType, ByteType, ShortType, IntegerType, LongType, - FloatType, DoubleType, DecimalType.Unlimited, DecimalType(6, 5), + FloatType, DoubleType, DecimalType.SYSTEM_DEFAULT, DecimalType(6, 5), DateType, TimestampType, ArrayType(IntegerType), MapType(StringType, LongType), struct) val fields = dataTypes.zipWithIndex.map { case (dataType, index) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 3d71deb13e..845ce669f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -109,7 +109,7 @@ class PlannerSuite extends SparkFunSuite { FloatType :: DoubleType :: DecimalType(10, 5) :: - DecimalType.Unlimited :: + DecimalType.SYSTEM_DEFAULT :: DateType :: TimestampType :: StringType :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala index 4a53fadd7e..54f82f89ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala @@ -54,7 +54,7 @@ class SparkSqlSerializer2DataTypeSuite extends SparkFunSuite { checkSupported(StringType, isSupported = true) checkSupported(BinaryType, isSupported = true) checkSupported(DecimalType(10, 5), isSupported = true) - checkSupported(DecimalType.Unlimited, isSupported = true) + checkSupported(DecimalType.SYSTEM_DEFAULT, isSupported = true) // If NullType is 
the only data type in the schema, we do not support it. checkSupported(NullType, isSupported = false) @@ -86,7 +86,7 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll val supportedTypes = Seq(StringType, BinaryType, NullType, BooleanType, ByteType, ShortType, IntegerType, LongType, - FloatType, DoubleType, DecimalType.Unlimited, DecimalType(6, 5), + FloatType, DoubleType, DecimalType.SYSTEM_DEFAULT, DecimalType(6, 5), DateType, TimestampType) val fields = supportedTypes.zipWithIndex.map { case (dataType, index) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 0f82f13088..42f2449afb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -134,7 +134,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { """.stripMargin.replaceAll("\n", " ")) - conn.prepareStatement("create table test.flttypes (a DOUBLE, b REAL, c DECIMAL(40, 20))" + conn.prepareStatement("create table test.flttypes (a DOUBLE, b REAL, c DECIMAL(38, 18))" ).executeUpdate() conn.prepareStatement("insert into test.flttypes values (" + "1.0000000000000002220446049250313080847263336181640625, " @@ -152,7 +152,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { s""" |create table test.nulltypes (a INT, b BOOLEAN, c TINYINT, d BINARY(20), e VARCHAR(20), |f VARCHAR_IGNORECASE(20), g CHAR(20), h BLOB, i CLOB, j TIME, k DATE, l TIMESTAMP, - |m DOUBLE, n REAL, o DECIMAL(40, 20)) + |m DOUBLE, n REAL, o DECIMAL(38, 18)) """.stripMargin.replaceAll("\n", " ")).executeUpdate() conn.prepareStatement("insert into test.nulltypes values (" + "null, null, null, null, null, null, null, null, null, " @@ -357,14 +357,14 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { test("H2 floating-point types") { val rows = sql("SELECT * FROM flttypes").collect() - assert(rows(0).getDouble(0) === 1.00000000000000022) // Yes, I meant ==. - assert(rows(0).getDouble(1) === 1.00000011920928955) // Yes, I meant ==. 
- assert(rows(0).getAs[BigDecimal](2) - .equals(new BigDecimal("123456789012345.54321543215432100000"))) - assert(rows(0).schema.fields(2).dataType === DecimalType(40, 20)) - val compareDecimal = sql("SELECT C FROM flttypes where C > C - 1").collect() - assert(compareDecimal(0).getAs[BigDecimal](0) - .equals(new BigDecimal("123456789012345.54321543215432100000"))) + assert(rows(0).getDouble(0) === 1.00000000000000022) + assert(rows(0).getDouble(1) === 1.00000011920928955) + assert(rows(0).getAs[BigDecimal](2) === + new BigDecimal("123456789012345.543215432154321000")) + assert(rows(0).schema.fields(2).dataType === DecimalType(38, 18)) + val result = sql("SELECT C FROM flttypes where C > C - 1").collect() + assert(result(0).getAs[BigDecimal](0) === + new BigDecimal("123456789012345.543215432154321000")) } test("SQL query as table name") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 1d04513a44..3ac312d6f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -63,18 +63,18 @@ class JsonSuite extends QueryTest with TestJsonData { checkTypePromotion(intNumber.toLong, enforceCorrectType(intNumber, LongType)) checkTypePromotion(intNumber.toDouble, enforceCorrectType(intNumber, DoubleType)) checkTypePromotion( - Decimal(intNumber), enforceCorrectType(intNumber, DecimalType.Unlimited)) + Decimal(intNumber), enforceCorrectType(intNumber, DecimalType.SYSTEM_DEFAULT)) val longNumber: Long = 9223372036854775807L checkTypePromotion(longNumber, enforceCorrectType(longNumber, LongType)) checkTypePromotion(longNumber.toDouble, enforceCorrectType(longNumber, DoubleType)) checkTypePromotion( - Decimal(longNumber), enforceCorrectType(longNumber, DecimalType.Unlimited)) + Decimal(longNumber), enforceCorrectType(longNumber, DecimalType.SYSTEM_DEFAULT)) val doubleNumber: Double = 1.7976931348623157E308d checkTypePromotion(doubleNumber.toDouble, enforceCorrectType(doubleNumber, DoubleType)) checkTypePromotion( - Decimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType.Unlimited)) + Decimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType.SYSTEM_DEFAULT)) checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(intNumber)), enforceCorrectType(intNumber, TimestampType)) @@ -115,7 +115,7 @@ class JsonSuite extends QueryTest with TestJsonData { checkDataType(NullType, IntegerType, IntegerType) checkDataType(NullType, LongType, LongType) checkDataType(NullType, DoubleType, DoubleType) - checkDataType(NullType, DecimalType.Unlimited, DecimalType.Unlimited) + checkDataType(NullType, DecimalType.SYSTEM_DEFAULT, DecimalType.SYSTEM_DEFAULT) checkDataType(NullType, StringType, StringType) checkDataType(NullType, ArrayType(IntegerType), ArrayType(IntegerType)) checkDataType(NullType, StructType(Nil), StructType(Nil)) @@ -126,7 +126,7 @@ class JsonSuite extends QueryTest with TestJsonData { checkDataType(BooleanType, IntegerType, StringType) checkDataType(BooleanType, LongType, StringType) checkDataType(BooleanType, DoubleType, StringType) - checkDataType(BooleanType, DecimalType.Unlimited, StringType) + checkDataType(BooleanType, DecimalType.SYSTEM_DEFAULT, StringType) checkDataType(BooleanType, StringType, StringType) checkDataType(BooleanType, ArrayType(IntegerType), StringType) checkDataType(BooleanType, StructType(Nil), StringType) @@ -135,7 +135,7 @@ class JsonSuite extends QueryTest 
with TestJsonData { checkDataType(IntegerType, IntegerType, IntegerType) checkDataType(IntegerType, LongType, LongType) checkDataType(IntegerType, DoubleType, DoubleType) - checkDataType(IntegerType, DecimalType.Unlimited, DecimalType.Unlimited) + checkDataType(IntegerType, DecimalType.SYSTEM_DEFAULT, DecimalType.SYSTEM_DEFAULT) checkDataType(IntegerType, StringType, StringType) checkDataType(IntegerType, ArrayType(IntegerType), StringType) checkDataType(IntegerType, StructType(Nil), StringType) @@ -143,23 +143,24 @@ class JsonSuite extends QueryTest with TestJsonData { // LongType checkDataType(LongType, LongType, LongType) checkDataType(LongType, DoubleType, DoubleType) - checkDataType(LongType, DecimalType.Unlimited, DecimalType.Unlimited) + checkDataType(LongType, DecimalType.SYSTEM_DEFAULT, DecimalType.SYSTEM_DEFAULT) checkDataType(LongType, StringType, StringType) checkDataType(LongType, ArrayType(IntegerType), StringType) checkDataType(LongType, StructType(Nil), StringType) // DoubleType checkDataType(DoubleType, DoubleType, DoubleType) - checkDataType(DoubleType, DecimalType.Unlimited, DecimalType.Unlimited) + checkDataType(DoubleType, DecimalType.SYSTEM_DEFAULT, DecimalType.SYSTEM_DEFAULT) checkDataType(DoubleType, StringType, StringType) checkDataType(DoubleType, ArrayType(IntegerType), StringType) checkDataType(DoubleType, StructType(Nil), StringType) - // DoubleType - checkDataType(DecimalType.Unlimited, DecimalType.Unlimited, DecimalType.Unlimited) - checkDataType(DecimalType.Unlimited, StringType, StringType) - checkDataType(DecimalType.Unlimited, ArrayType(IntegerType), StringType) - checkDataType(DecimalType.Unlimited, StructType(Nil), StringType) + // DecimalType + checkDataType(DecimalType.SYSTEM_DEFAULT, DecimalType.SYSTEM_DEFAULT, + DecimalType.SYSTEM_DEFAULT) + checkDataType(DecimalType.SYSTEM_DEFAULT, StringType, StringType) + checkDataType(DecimalType.SYSTEM_DEFAULT, ArrayType(IntegerType), StringType) + checkDataType(DecimalType.SYSTEM_DEFAULT, StructType(Nil), StringType) // StringType checkDataType(StringType, StringType, StringType) @@ -213,7 +214,7 @@ class JsonSuite extends QueryTest with TestJsonData { checkDataType( StructType( StructField("f1", IntegerType, true) :: Nil), - DecimalType.Unlimited, + DecimalType.SYSTEM_DEFAULT, StringType) } @@ -240,7 +241,7 @@ class JsonSuite extends QueryTest with TestJsonData { val jsonDF = ctx.read.json(primitiveFieldAndType) val expectedSchema = StructType( - StructField("bigInteger", DecimalType.Unlimited, true) :: + StructField("bigInteger", DecimalType.SYSTEM_DEFAULT, true) :: StructField("boolean", BooleanType, true) :: StructField("double", DoubleType, true) :: StructField("integer", LongType, true) :: @@ -270,7 +271,7 @@ class JsonSuite extends QueryTest with TestJsonData { val expectedSchema = StructType( StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) :: StructField("arrayOfArray2", ArrayType(ArrayType(DoubleType, true), true), true) :: - StructField("arrayOfBigInteger", ArrayType(DecimalType.Unlimited, true), true) :: + StructField("arrayOfBigInteger", ArrayType(DecimalType.SYSTEM_DEFAULT, true), true) :: StructField("arrayOfBoolean", ArrayType(BooleanType, true), true) :: StructField("arrayOfDouble", ArrayType(DoubleType, true), true) :: StructField("arrayOfInteger", ArrayType(LongType, true), true) :: @@ -284,7 +285,7 @@ class JsonSuite extends QueryTest with TestJsonData { StructField("field3", StringType, true) :: Nil), true), true) :: StructField("struct", StructType( 
StructField("field1", BooleanType, true) :: - StructField("field2", DecimalType.Unlimited, true) :: Nil), true) :: + StructField("field2", DecimalType.SYSTEM_DEFAULT, true) :: Nil), true) :: StructField("structWithArrayFields", StructType( StructField("field1", ArrayType(LongType, true), true) :: StructField("field2", ArrayType(StringType, true), true) :: Nil), true) :: Nil) @@ -385,7 +386,7 @@ class JsonSuite extends QueryTest with TestJsonData { val expectedSchema = StructType( StructField("num_bool", StringType, true) :: StructField("num_num_1", LongType, true) :: - StructField("num_num_2", DecimalType.Unlimited, true) :: + StructField("num_num_2", DecimalType.SYSTEM_DEFAULT, true) :: StructField("num_num_3", DoubleType, true) :: StructField("num_str", StringType, true) :: StructField("str_bool", StringType, true) :: Nil) @@ -421,11 +422,11 @@ class JsonSuite extends QueryTest with TestJsonData { Row(-89) :: Row(21474836370L) :: Row(21474836470L) :: Nil ) - // Widening to DecimalType + // Widening to DoubleType checkAnswer( - sql("select num_num_2 + 1.2 from jsonTable where num_num_2 > 1.1"), - Row(new java.math.BigDecimal("21474836472.1")) :: - Row(new java.math.BigDecimal("92233720368547758071.2")) :: Nil + sql("select num_num_2 + 1.3 from jsonTable where num_num_2 > 1.1"), + Row(21474836472.2) :: + Row(92233720368547758071.3) :: Nil ) // Widening to DoubleType @@ -442,8 +443,8 @@ class JsonSuite extends QueryTest with TestJsonData { // Number and String conflict: resolve the type as number in this query. checkAnswer( - sql("select num_str + 1.2 from jsonTable where num_str > 92233720368547758060"), - Row(new java.math.BigDecimal("92233720368547758061.2").doubleValue) + sql("select num_str + 1.2 from jsonTable where num_str >= 92233720368547758060"), + Row(new java.math.BigDecimal("92233720368547758071.2").doubleValue) ) // String and Boolean conflict: resolve the type as string. @@ -489,9 +490,9 @@ class JsonSuite extends QueryTest with TestJsonData { // in the Project. checkAnswer( jsonDF. - where('num_str > BigDecimal("92233720368547758060")). + where('num_str >= BigDecimal("92233720368547758060")). select(('num_str + 1.2).as("num")), - Row(new java.math.BigDecimal("92233720368547758061.2")) + Row(new java.math.BigDecimal("92233720368547758071.2").doubleValue()) ) // The following test will fail. The type of num_str is StringType. 
@@ -610,7 +611,7 @@ class JsonSuite extends QueryTest with TestJsonData { val jsonDF = ctx.read.json(path) val expectedSchema = StructType( - StructField("bigInteger", DecimalType.Unlimited, true) :: + StructField("bigInteger", DecimalType.SYSTEM_DEFAULT, true) :: StructField("boolean", BooleanType, true) :: StructField("double", DoubleType, true) :: StructField("integer", LongType, true) :: @@ -668,7 +669,7 @@ class JsonSuite extends QueryTest with TestJsonData { primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) val schema = StructType( - StructField("bigInteger", DecimalType.Unlimited, true) :: + StructField("bigInteger", DecimalType.SYSTEM_DEFAULT, true) :: StructField("boolean", BooleanType, true) :: StructField("double", DoubleType, true) :: StructField("integer", IntegerType, true) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index 7b16eba00d..3a5b860484 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -122,14 +122,6 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { sqlContext.read.parquet(dir.getCanonicalPath).collect() } } - - // Unlimited-length decimals are not yet supported - intercept[Throwable] { - withTempPath { dir => - makeDecimalRDD(DecimalType.Unlimited).write.parquet(dir.getCanonicalPath) - sqlContext.read.parquet(dir.getCanonicalPath).collect() - } - } } test("date type") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala index 4f98776b91..7f16b1125c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -509,7 +509,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { FloatType, DoubleType, DecimalType(10, 5), - DecimalType.Unlimited, + DecimalType.SYSTEM_DEFAULT, DateType, TimestampType, StringType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala index 54e1efb6e3..da53ec16b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala @@ -44,7 +44,7 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo StructField("doubleType", DoubleType, nullable = false), StructField("bigintType", LongType, nullable = false), StructField("tinyintType", ByteType, nullable = false), - StructField("decimalType", DecimalType.Unlimited, nullable = false), + StructField("decimalType", DecimalType.USER_DEFAULT, nullable = false), StructField("fixedDecimalType", DecimalType(5, 1), nullable = false), StructField("binaryType", BinaryType, nullable = false), StructField("booleanType", BooleanType, nullable = false), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 2c916f3322..143aadc08b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -202,7 +202,7 @@ class TableScanSuite extends DataSourceTest { StructField("longField_:,<>=+/~^", LongType, true) :: StructField("floatField", FloatType, true) :: StructField("doubleField", DoubleType, true) :: - StructField("decimalField1", DecimalType.Unlimited, true) :: + StructField("decimalField1", DecimalType.USER_DEFAULT, true) :: StructField("decimalField2", DecimalType(9, 2), true) :: StructField("dateField", DateType, true) :: StructField("timestampField", TimestampType, true) :: diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index a8f2ee37cb..592cfa0ee8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -179,7 +179,7 @@ private[hive] trait HiveInspectors { // writable case c: Class[_] if c == classOf[hadoopIo.DoubleWritable] => DoubleType case c: Class[_] if c == classOf[hiveIo.DoubleWritable] => DoubleType - case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType.Unlimited + case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType.SYSTEM_DEFAULT case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType case c: Class[_] if c == classOf[hiveIo.DateWritable] => DateType @@ -195,8 +195,8 @@ private[hive] trait HiveInspectors { case c: Class[_] if c == classOf[java.lang.String] => StringType case c: Class[_] if c == classOf[java.sql.Date] => DateType case c: Class[_] if c == classOf[java.sql.Timestamp] => TimestampType - case c: Class[_] if c == classOf[HiveDecimal] => DecimalType.Unlimited - case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType.Unlimited + case c: Class[_] if c == classOf[HiveDecimal] => DecimalType.SYSTEM_DEFAULT + case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType.SYSTEM_DEFAULT case c: Class[_] if c == classOf[Array[Byte]] => BinaryType case c: Class[_] if c == classOf[java.lang.Short] => ShortType case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType @@ -813,9 +813,6 @@ private[hive] trait HiveInspectors { private def decimalTypeInfo(decimalType: DecimalType): TypeInfo = decimalType match { case DecimalType.Fixed(precision, scale) => new DecimalTypeInfo(precision, scale) - case _ => new DecimalTypeInfo( - HiveShim.UNLIMITED_DECIMAL_PRECISION, - HiveShim.UNLIMITED_DECIMAL_SCALE) } def toTypeInfo: TypeInfo = dt match { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 8518e333e8..620b8a44d8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -377,7 +377,7 @@ private[hive] object HiveQl extends Logging { DecimalType(precision.getText.toInt, scale.getText.toInt) case Token("TOK_DECIMAL", precision :: Nil) => DecimalType(precision.getText.toInt, 0) - case Token("TOK_DECIMAL", Nil) => DecimalType.Unlimited + case Token("TOK_DECIMAL", Nil) => DecimalType.USER_DEFAULT case Token("TOK_BIGINT", Nil) => LongType case Token("TOK_INT", Nil) => IntegerType case Token("TOK_TINYINT", Nil) => ByteType @@ -1369,7 +1369,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("TOK_FUNCTION", 
Token("TOK_DECIMAL", precision :: Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), DecimalType(precision.getText.toInt, 0)) case Token("TOK_FUNCTION", Token("TOK_DECIMAL", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DecimalType.Unlimited) + Cast(nodeToExpr(arg), DecimalType.USER_DEFAULT) case Token("TOK_FUNCTION", Token("TOK_TIMESTAMP", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), TimestampType) case Token("TOK_FUNCTION", Token("TOK_DATE", Nil) :: arg :: Nil) =>