[SPARK-11076] [SQL] Add decimal support for floor and ceil
Actually all of the `UnaryMathExpression` doens't support the Decimal, will create follow ups for supporing it. This is the first PR which will be good to review the approach I am taking. Author: Cheng Hao <hao.cheng@intel.com> Closes #9086 from chenghao-intel/ceiling.
This commit is contained in:
parent
4ace4f8a9c
commit
9808052b5a
|
@ -55,7 +55,7 @@ abstract class LeafMathExpression(c: Double, name: String)
|
|||
abstract class UnaryMathExpression(val f: Double => Double, name: String)
|
||||
extends UnaryExpression with Serializable with ImplicitCastInputTypes {
|
||||
|
||||
override def inputTypes: Seq[DataType] = Seq(DoubleType)
|
||||
override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType)
|
||||
override def dataType: DataType = DoubleType
|
||||
override def nullable: Boolean = true
|
||||
override def toString: String = s"$name($child)"
|
||||
|
@ -153,13 +153,28 @@ case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN"
|
|||
case class Cbrt(child: Expression) extends UnaryMathExpression(math.cbrt, "CBRT")
|
||||
|
||||
case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL") {
|
||||
override def dataType: DataType = LongType
|
||||
protected override def nullSafeEval(input: Any): Any = {
|
||||
f(input.asInstanceOf[Double]).toLong
|
||||
override def dataType: DataType = child.dataType match {
|
||||
case dt @ DecimalType.Fixed(_, 0) => dt
|
||||
case DecimalType.Fixed(precision, scale) =>
|
||||
DecimalType.bounded(precision - scale + 1, 0)
|
||||
case _ => LongType
|
||||
}
|
||||
|
||||
override def inputTypes: Seq[AbstractDataType] =
|
||||
Seq(TypeCollection(DoubleType, DecimalType))
|
||||
|
||||
protected override def nullSafeEval(input: Any): Any = child.dataType match {
|
||||
case DoubleType => f(input.asInstanceOf[Double]).toLong
|
||||
case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].ceil
|
||||
}
|
||||
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
|
||||
defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))")
|
||||
child.dataType match {
|
||||
case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c")
|
||||
case DecimalType.Fixed(precision, scale) =>
|
||||
defineCodeGen(ctx, ev, c => s"$c.ceil()")
|
||||
case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -205,13 +220,28 @@ case class Exp(child: Expression) extends UnaryMathExpression(math.exp, "EXP")
|
|||
case class Expm1(child: Expression) extends UnaryMathExpression(math.expm1, "EXPM1")
|
||||
|
||||
case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR") {
|
||||
override def dataType: DataType = LongType
|
||||
protected override def nullSafeEval(input: Any): Any = {
|
||||
f(input.asInstanceOf[Double]).toLong
|
||||
override def dataType: DataType = child.dataType match {
|
||||
case dt @ DecimalType.Fixed(_, 0) => dt
|
||||
case DecimalType.Fixed(precision, scale) =>
|
||||
DecimalType.bounded(precision - scale + 1, 0)
|
||||
case _ => LongType
|
||||
}
|
||||
|
||||
override def inputTypes: Seq[AbstractDataType] =
|
||||
Seq(TypeCollection(DoubleType, DecimalType))
|
||||
|
||||
protected override def nullSafeEval(input: Any): Any = child.dataType match {
|
||||
case DoubleType => f(input.asInstanceOf[Double]).toLong
|
||||
case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].floor
|
||||
}
|
||||
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
|
||||
defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))")
|
||||
child.dataType match {
|
||||
case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c")
|
||||
case DecimalType.Fixed(precision, scale) =>
|
||||
defineCodeGen(ctx, ev, c => s"$c.floor()")
|
||||
case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -107,7 +107,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
|
|||
* Set this Decimal to the given BigDecimal value, with a given precision and scale.
|
||||
*/
|
||||
def set(decimal: BigDecimal, precision: Int, scale: Int): Decimal = {
|
||||
this.decimalVal = decimal.setScale(scale, ROUNDING_MODE)
|
||||
this.decimalVal = decimal.setScale(scale, ROUND_HALF_UP)
|
||||
require(
|
||||
decimalVal.precision <= precision,
|
||||
s"Decimal precision ${decimalVal.precision} exceeds max precision $precision")
|
||||
|
@ -198,6 +198,16 @@ final class Decimal extends Ordered[Decimal] with Serializable {
|
|||
* @return true if successful, false if overflow would occur
|
||||
*/
|
||||
def changePrecision(precision: Int, scale: Int): Boolean = {
|
||||
changePrecision(precision, scale, ROUND_HALF_UP)
|
||||
}
|
||||
|
||||
/**
|
||||
* Update precision and scale while keeping our value the same, and return true if successful.
|
||||
*
|
||||
* @return true if successful, false if overflow would occur
|
||||
*/
|
||||
private[sql] def changePrecision(precision: Int, scale: Int,
|
||||
roundMode: BigDecimal.RoundingMode.Value): Boolean = {
|
||||
// fast path for UnsafeProjection
|
||||
if (precision == this.precision && scale == this.scale) {
|
||||
return true
|
||||
|
@ -231,7 +241,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
|
|||
if (decimalVal.ne(null)) {
|
||||
// We get here if either we started with a BigDecimal, or we switched to one because we would
|
||||
// have overflowed our Long; in either case we must rescale decimalVal to the new scale.
|
||||
val newVal = decimalVal.setScale(scale, ROUNDING_MODE)
|
||||
val newVal = decimalVal.setScale(scale, roundMode)
|
||||
if (newVal.precision > precision) {
|
||||
return false
|
||||
}
|
||||
|
@ -309,10 +319,26 @@ final class Decimal extends Ordered[Decimal] with Serializable {
|
|||
}
|
||||
|
||||
def abs: Decimal = if (this.compare(Decimal.ZERO) < 0) this.unary_- else this
|
||||
|
||||
def floor: Decimal = if (scale == 0) this else {
|
||||
val value = this.clone()
|
||||
value.changePrecision(
|
||||
DecimalType.bounded(precision - scale + 1, 0).precision, 0, ROUND_FLOOR)
|
||||
value
|
||||
}
|
||||
|
||||
def ceil: Decimal = if (scale == 0) this else {
|
||||
val value = this.clone()
|
||||
value.changePrecision(
|
||||
DecimalType.bounded(precision - scale + 1, 0).precision, 0, ROUND_CEILING)
|
||||
value
|
||||
}
|
||||
}
|
||||
|
||||
object Decimal {
|
||||
private val ROUNDING_MODE = BigDecimal.RoundingMode.HALF_UP
|
||||
val ROUND_HALF_UP = BigDecimal.RoundingMode.HALF_UP
|
||||
val ROUND_CEILING = BigDecimal.RoundingMode.CEILING
|
||||
val ROUND_FLOOR = BigDecimal.RoundingMode.FLOOR
|
||||
|
||||
/** Maximum number of decimal digits a Long can represent */
|
||||
val MAX_LONG_DIGITS = 18
|
||||
|
|
|
@ -78,7 +78,18 @@ object LiteralGenerator {
|
|||
Double.NaN, Double.PositiveInfinity, Double.NegativeInfinity)
|
||||
} yield Literal.create(f, DoubleType)
|
||||
|
||||
// TODO: decimal type
|
||||
// TODO cache the generated data
|
||||
def decimalLiteralGen(precision: Int, scale: Int): Gen[Literal] = {
|
||||
assert(scale >= 0)
|
||||
assert(precision >= scale)
|
||||
Arbitrary.arbBigInt.arbitrary.map { s =>
|
||||
val a = (s % BigInt(10).pow(precision - scale)).toString()
|
||||
val b = (s % BigInt(10).pow(scale)).abs.toString()
|
||||
Literal.create(
|
||||
Decimal(BigDecimal(s"$a.$b"), precision, scale),
|
||||
DecimalType(precision, scale))
|
||||
}
|
||||
}
|
||||
|
||||
lazy val stringLiteralGen: Gen[Literal] =
|
||||
for { s <- Arbitrary.arbString.arbitrary } yield Literal.create(s, StringType)
|
||||
|
@ -122,6 +133,7 @@ object LiteralGenerator {
|
|||
case StringType => stringLiteralGen
|
||||
case BinaryType => binaryLiteralGen
|
||||
case CalendarIntervalType => calendarIntervalLiterGen
|
||||
case DecimalType.Fixed(precision, scale) => decimalLiteralGen(precision, scale)
|
||||
case dt => throw new IllegalArgumentException(s"not supported type $dt")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -246,11 +246,21 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
test("ceil") {
|
||||
testUnary(Ceil, (d: Double) => math.ceil(d).toLong)
|
||||
checkConsistencyBetweenInterpretedAndCodegen(Ceil, DoubleType)
|
||||
|
||||
testUnary(Ceil, (d: Decimal) => d.ceil, (-20 to 20).map(x => Decimal(x * 0.1)))
|
||||
checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 3))
|
||||
checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 0))
|
||||
checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(5, 0))
|
||||
}
|
||||
|
||||
test("floor") {
|
||||
testUnary(Floor, (d: Double) => math.floor(d).toLong)
|
||||
checkConsistencyBetweenInterpretedAndCodegen(Floor, DoubleType)
|
||||
|
||||
testUnary(Floor, (d: Decimal) => d.floor, (-20 to 20).map(x => Decimal(x * 0.1)))
|
||||
checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 3))
|
||||
checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 0))
|
||||
checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(5, 0))
|
||||
}
|
||||
|
||||
test("factorial") {
|
||||
|
|
Loading…
Reference in a new issue