[SPARK-23898][SQL] Simplify add & subtract code generation

## What changes were proposed in this pull request?
Code generation for the `Add` and `Subtract` expressions was not done using the `BinaryArithmetic.doCodeGen` method because these expressions also support `CalendarInterval`. This leads to a bit of duplication.

This PR gets rid of that duplication by adding `calendarIntervalMethod` to `BinaryArithmetic` and doing the code generation for `CalendarInterval` in `BinaryArithmetic` instead.

## How was this patch tested?
Existing tests.

Author: Herman van Hovell <hvanhovell@databricks.com>

Closes #21005 from hvanhovell/SPARK-23898.
This commit is contained in:
Herman van Hovell 2018-04-09 21:49:49 -07:00 committed by gatorsmile
parent f94f3624ea
commit 6498884154

View file

@ -43,7 +43,7 @@ case class UnaryMinus(child: Expression) extends UnaryExpression
private lazy val numeric = TypeUtils.getNumeric(dataType)
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()")
case _: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()")
case dt: NumericType => nullSafeCodeGen(ctx, ev, eval => {
val originValue = ctx.freshName("origin")
// codegen would fail to compile if we just write (-($c))
@ -52,7 +52,7 @@ case class UnaryMinus(child: Expression) extends UnaryExpression
${CodeGenerator.javaType(dt)} $originValue = (${CodeGenerator.javaType(dt)})($eval);
${ev.value} = (${CodeGenerator.javaType(dt)})(-($originValue));
"""})
case dt: CalendarIntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()")
case _: CalendarIntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()")
}
protected override def nullSafeEval(input: Any): Any = {
@ -104,7 +104,7 @@ case class Abs(child: Expression)
private lazy val numeric = TypeUtils.getNumeric(dataType)
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
case dt: DecimalType =>
case _: DecimalType =>
defineCodeGen(ctx, ev, c => s"$c.abs()")
case dt: NumericType =>
defineCodeGen(ctx, ev, c => s"(${CodeGenerator.javaType(dt)})(java.lang.Math.abs($c))")
@ -117,15 +117,21 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {
override def dataType: DataType = left.dataType
override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess
override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess
/** Name of the function for this expression on a [[Decimal]] type. */
def decimalMethod: String =
sys.error("BinaryArithmetics must override either decimalMethod or genCode")
/** Name of the function for this expression on a [[CalendarInterval]] type. */
def calendarIntervalMethod: String =
sys.error("BinaryArithmetics must override either calendarIntervalMethod or genCode")
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
case dt: DecimalType =>
case _: DecimalType =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)")
case CalendarIntervalType =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$calendarIntervalMethod($eval2)")
// byte and short are casted into int when add, minus, times or divide
case ByteType | ShortType =>
defineCodeGen(ctx, ev,
@ -152,6 +158,10 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "+"
override def decimalMethod: String = "$plus"
override def calendarIntervalMethod: String = "add"
private lazy val numeric = TypeUtils.getNumeric(dataType)
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
@ -161,18 +171,6 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
numeric.plus(input1, input2)
}
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
case dt: DecimalType =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$plus($eval2)")
case ByteType | ShortType =>
defineCodeGen(ctx, ev,
(eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)")
case CalendarIntervalType =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.add($eval2)")
case _ =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
}
}
@ExpressionDescription(
@ -188,6 +186,10 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
override def symbol: String = "-"
override def decimalMethod: String = "$minus"
override def calendarIntervalMethod: String = "subtract"
private lazy val numeric = TypeUtils.getNumeric(dataType)
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
@ -197,18 +199,6 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
numeric.minus(input1, input2)
}
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
case dt: DecimalType =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$minus($eval2)")
case ByteType | ShortType =>
defineCodeGen(ctx, ev,
(eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)")
case CalendarIntervalType =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.subtract($eval2)")
case _ =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
}
}
@ExpressionDescription(
@ -416,7 +406,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "pmod"
protected def checkTypesInternal(t: DataType) =
protected def checkTypesInternal(t: DataType): TypeCheckResult =
TypeUtils.checkForNumericExpr(t, "pmod")
override def inputType: AbstractDataType = NumericType