[SPARK-23898][SQL] Simplify add & subtract code generation
## What changes were proposed in this pull request? Code generation for the `Add` and `Subtract` expressions was not done using the `BinaryArithmetic.doCodeGen` method because these expressions also support `CalendarInterval`. This leads to a bit of duplication. This PR gets rid of that duplication by adding `calendarIntervalMethod` to `BinaryArithmetic` and doing the code generation for `CalendarInterval` in `BinaryArithmetic` instead. ## How was this patch tested? Existing tests. Author: Herman van Hovell <hvanhovell@databricks.com> Closes #21005 from hvanhovell/SPARK-23898.
This commit is contained in:
parent
f94f3624ea
commit
6498884154
|
@ -43,7 +43,7 @@ case class UnaryMinus(child: Expression) extends UnaryExpression
|
|||
private lazy val numeric = TypeUtils.getNumeric(dataType)
|
||||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
|
||||
case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()")
|
||||
case _: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()")
|
||||
case dt: NumericType => nullSafeCodeGen(ctx, ev, eval => {
|
||||
val originValue = ctx.freshName("origin")
|
||||
// codegen would fail to compile if we just write (-($c))
|
||||
|
@ -52,7 +52,7 @@ case class UnaryMinus(child: Expression) extends UnaryExpression
|
|||
${CodeGenerator.javaType(dt)} $originValue = (${CodeGenerator.javaType(dt)})($eval);
|
||||
${ev.value} = (${CodeGenerator.javaType(dt)})(-($originValue));
|
||||
"""})
|
||||
case dt: CalendarIntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()")
|
||||
case _: CalendarIntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()")
|
||||
}
|
||||
|
||||
protected override def nullSafeEval(input: Any): Any = {
|
||||
|
@ -104,7 +104,7 @@ case class Abs(child: Expression)
|
|||
private lazy val numeric = TypeUtils.getNumeric(dataType)
|
||||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
|
||||
case dt: DecimalType =>
|
||||
case _: DecimalType =>
|
||||
defineCodeGen(ctx, ev, c => s"$c.abs()")
|
||||
case dt: NumericType =>
|
||||
defineCodeGen(ctx, ev, c => s"(${CodeGenerator.javaType(dt)})(java.lang.Math.abs($c))")
|
||||
|
@ -117,15 +117,21 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {
|
|||
|
||||
override def dataType: DataType = left.dataType
|
||||
|
||||
override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess
|
||||
override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess
|
||||
|
||||
/** Name of the function for this expression on a [[Decimal]] type. */
|
||||
def decimalMethod: String =
|
||||
sys.error("BinaryArithmetics must override either decimalMethod or genCode")
|
||||
|
||||
/** Name of the function for this expression on a [[CalendarInterval]] type. */
|
||||
def calendarIntervalMethod: String =
|
||||
sys.error("BinaryArithmetics must override either calendarIntervalMethod or genCode")
|
||||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
|
||||
case dt: DecimalType =>
|
||||
case _: DecimalType =>
|
||||
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)")
|
||||
case CalendarIntervalType =>
|
||||
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$calendarIntervalMethod($eval2)")
|
||||
// byte and short are casted into int when add, minus, times or divide
|
||||
case ByteType | ShortType =>
|
||||
defineCodeGen(ctx, ev,
|
||||
|
@ -152,6 +158,10 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
|
|||
|
||||
override def symbol: String = "+"
|
||||
|
||||
override def decimalMethod: String = "$plus"
|
||||
|
||||
override def calendarIntervalMethod: String = "add"
|
||||
|
||||
private lazy val numeric = TypeUtils.getNumeric(dataType)
|
||||
|
||||
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
|
||||
|
@ -161,18 +171,6 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
|
|||
numeric.plus(input1, input2)
|
||||
}
|
||||
}
|
||||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
|
||||
case dt: DecimalType =>
|
||||
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$plus($eval2)")
|
||||
case ByteType | ShortType =>
|
||||
defineCodeGen(ctx, ev,
|
||||
(eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)")
|
||||
case CalendarIntervalType =>
|
||||
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.add($eval2)")
|
||||
case _ =>
|
||||
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
|
||||
}
|
||||
}
|
||||
|
||||
@ExpressionDescription(
|
||||
|
@ -188,6 +186,10 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
|
|||
|
||||
override def symbol: String = "-"
|
||||
|
||||
override def decimalMethod: String = "$minus"
|
||||
|
||||
override def calendarIntervalMethod: String = "subtract"
|
||||
|
||||
private lazy val numeric = TypeUtils.getNumeric(dataType)
|
||||
|
||||
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
|
||||
|
@ -197,18 +199,6 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
|
|||
numeric.minus(input1, input2)
|
||||
}
|
||||
}
|
||||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
|
||||
case dt: DecimalType =>
|
||||
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$minus($eval2)")
|
||||
case ByteType | ShortType =>
|
||||
defineCodeGen(ctx, ev,
|
||||
(eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)")
|
||||
case CalendarIntervalType =>
|
||||
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.subtract($eval2)")
|
||||
case _ =>
|
||||
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
|
||||
}
|
||||
}
|
||||
|
||||
@ExpressionDescription(
|
||||
|
@ -416,7 +406,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
|
|||
|
||||
override def symbol: String = "pmod"
|
||||
|
||||
protected def checkTypesInternal(t: DataType) =
|
||||
protected def checkTypesInternal(t: DataType): TypeCheckResult =
|
||||
TypeUtils.checkForNumericExpr(t, "pmod")
|
||||
|
||||
override def inputType: AbstractDataType = NumericType
|
||||
|
|
Loading…
Reference in a new issue