[SPARK-26218][SQL] Overflow on arithmetic operations returns incorrect result

## What changes were proposed in this pull request?

When an overflow occurs performing an arithmetic operation, we are returning an incorrect value. Instead, we should throw an exception, as stated in the SQL standard.

## How was this patch tested?

added UT + existing UTs (improved)

Closes #21599 from mgaido91/SPARK-24598.

Authored-by: Marco Gaido <marcogaido91@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
Marco Gaido 2019-08-01 14:51:38 +08:00 committed by Wenchen Fan
parent b3ffd8be14
commit ee41001949
14 changed files with 554 additions and 172 deletions

View file

@ -35,6 +35,7 @@ import org.apache.spark.unsafe.types.CalendarInterval
""")
case class UnaryMinus(child: Expression) extends UnaryExpression
with ExpectsInputTypes with NullIntolerant {
private val checkOverflow = SQLConf.get.arithmeticOperationsFailOnOverflow
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)
@ -42,10 +43,28 @@ case class UnaryMinus(child: Expression) extends UnaryExpression
override def toString: String = s"-$child"
private lazy val numeric = TypeUtils.getNumeric(dataType)
private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow)
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
case _: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()")
case ByteType | ShortType if checkOverflow =>
nullSafeCodeGen(ctx, ev, eval => {
val javaBoxedType = CodeGenerator.boxedType(dataType)
val javaType = CodeGenerator.javaType(dataType)
val originValue = ctx.freshName("origin")
s"""
|$javaType $originValue = ($javaType)($eval);
|if ($originValue == $javaBoxedType.MIN_VALUE) {
| throw new ArithmeticException("- " + $originValue + " caused overflow.");
|}
|${ev.value} = ($javaType)(-($originValue));
""".stripMargin
})
case IntegerType | LongType if checkOverflow =>
nullSafeCodeGen(ctx, ev, eval => {
val mathClass = classOf[Math].getName
s"${ev.value} = $mathClass.negateExact($eval);"
})
case dt: NumericType => nullSafeCodeGen(ctx, ev, eval => {
val originValue = ctx.freshName("origin")
// codegen would fail to compile if we just write (-($c))
@ -117,6 +136,8 @@ case class Abs(child: Expression)
abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {
protected val checkOverflow = SQLConf.get.arithmeticOperationsFailOnOverflow
override def dataType: DataType = left.dataType
override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess
@ -129,17 +150,57 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {
def calendarIntervalMethod: String =
sys.error("BinaryArithmetics must override either calendarIntervalMethod or genCode")
/** Name of the function for the exact version of this expression in [[Math]]. */
def exactMathMethod: String =
sys.error("BinaryArithmetics must override either exactMathMethod or genCode")
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
case _: DecimalType =>
// Overflow is handled in the CheckOverflow operator
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)")
case CalendarIntervalType =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$calendarIntervalMethod($eval2)")
// byte and short are casted into int when add, minus, times or divide
case ByteType | ShortType =>
defineCodeGen(ctx, ev,
(eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)")
case _ =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
val tmpResult = ctx.freshName("tmpResult")
val overflowCheck = if (checkOverflow) {
val javaType = CodeGenerator.boxedType(dataType)
s"""
|if ($tmpResult < $javaType.MIN_VALUE || $tmpResult > $javaType.MAX_VALUE) {
| throw new ArithmeticException($eval1 + " $symbol " + $eval2 + " caused overflow.");
|}
""".stripMargin
} else {
""
}
s"""
|${CodeGenerator.JAVA_INT} $tmpResult = $eval1 $symbol $eval2;
|$overflowCheck
|${ev.value} = (${CodeGenerator.javaType(dataType)})($tmpResult);
""".stripMargin
})
case IntegerType | LongType =>
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
val operation = if (checkOverflow) {
val mathClass = classOf[Math].getName
s"$mathClass.$exactMathMethod($eval1, $eval2)"
} else {
s"$eval1 $symbol $eval2"
}
s"""
|${ev.value} = $operation;
""".stripMargin
})
case DoubleType | FloatType =>
// When Double/Float overflows, there can be 2 cases:
// - precision loss: according to SQL standard, the number is truncated;
// - returns (+/-)Infinite: same behavior also other DBs have (eg. Postgres)
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
s"""
|${ev.value} = $eval1 $symbol $eval2;
""".stripMargin
})
}
}
@ -164,7 +225,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
override def calendarIntervalMethod: String = "add"
private lazy val numeric = TypeUtils.getNumeric(dataType)
private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow)
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
if (dataType.isInstanceOf[CalendarIntervalType]) {
@ -173,6 +234,8 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
numeric.plus(input1, input2)
}
}
override def exactMathMethod: String = "addExact"
}
@ExpressionDescription(
@ -192,7 +255,7 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
override def calendarIntervalMethod: String = "subtract"
private lazy val numeric = TypeUtils.getNumeric(dataType)
private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow)
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
if (dataType.isInstanceOf[CalendarIntervalType]) {
@ -201,6 +264,8 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
numeric.minus(input1, input2)
}
}
override def exactMathMethod: String = "subtractExact"
}
@ExpressionDescription(
@ -217,9 +282,11 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
override def symbol: String = "*"
override def decimalMethod: String = "$times"
private lazy val numeric = TypeUtils.getNumeric(dataType)
private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow)
protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2)
override def exactMathMethod: String = "multiplyExact"
}
// Common base trait for Divide and Remainder, since these two classes are almost identical

View file

@ -60,8 +60,13 @@ object TypeUtils {
}
}
def getNumeric(t: DataType): Numeric[Any] =
t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]]
def getNumeric(t: DataType, exactNumericRequired: Boolean = false): Numeric[Any] = {
if (exactNumericRequired) {
t.asInstanceOf[NumericType].exactNumeric.asInstanceOf[Numeric[Any]]
} else {
t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]]
}
}
def getInterpretedOrdering(t: DataType): Ordering[Any] = {
t match {

View file

@ -1780,6 +1780,15 @@ object SQLConf {
.booleanConf
.createWithDefault(false)
val ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW =
buildConf("spark.sql.arithmeticOperations.failOnOverFlow")
.doc("If it is set to true, all arithmetic operations on non-decimal fields throw an " +
"exception if an overflow occurs. If it is false (default), in case of overflow a wrong " +
"result is returned.")
.internal()
.booleanConf
.createWithDefault(false)
val LEGACY_HAVING_WITHOUT_GROUP_BY_AS_WHERE =
buildConf("spark.sql.legacy.parser.havingWithoutGroupByAsWhere")
.internal()
@ -2287,6 +2296,8 @@ class SQLConf extends Serializable with Logging {
def decimalOperationsNullOnOverflow: Boolean = getConf(DECIMAL_OPERATIONS_NULL_ON_OVERFLOW)
def arithmeticOperationsFailOnOverflow: Boolean = getConf(ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW)
def literalPickMinimumPrecision: Boolean = getConf(LITERAL_PICK_MINIMUM_PRECISION)
def continuousStreamingEpochBacklogQueueSize: Int =

View file

@ -142,6 +142,8 @@ abstract class NumericType extends AtomicType {
// desugared by the compiler into an argument to the objects constructor. This means there is no
// longer a no argument constructor and thus the JVM cannot serialize the object anymore.
private[sql] val numeric: Numeric[InternalType]
private[sql] def exactNumeric: Numeric[InternalType] = numeric
}

View file

@ -37,6 +37,7 @@ class ByteType private() extends IntegralType {
private[sql] val numeric = implicitly[Numeric[Byte]]
private[sql] val integral = implicitly[Integral[Byte]]
private[sql] val ordering = implicitly[Ordering[InternalType]]
override private[sql] val exactNumeric = ByteExactNumeric
/**
* The default size of a value of the ByteType is 1 byte.

View file

@ -37,6 +37,7 @@ class IntegerType private() extends IntegralType {
private[sql] val numeric = implicitly[Numeric[Int]]
private[sql] val integral = implicitly[Integral[Int]]
private[sql] val ordering = implicitly[Ordering[InternalType]]
override private[sql] val exactNumeric = IntegerExactNumeric
/**
* The default size of a value of the IntegerType is 4 bytes.

View file

@ -37,6 +37,7 @@ class LongType private() extends IntegralType {
private[sql] val numeric = implicitly[Numeric[Long]]
private[sql] val integral = implicitly[Integral[Long]]
private[sql] val ordering = implicitly[Ordering[InternalType]]
override private[sql] val exactNumeric = LongExactNumeric
/**
* The default size of a value of the LongType is 8 bytes.

View file

@ -37,6 +37,7 @@ class ShortType private() extends IntegralType {
private[sql] val numeric = implicitly[Numeric[Short]]
private[sql] val integral = implicitly[Integral[Short]]
private[sql] val ordering = implicitly[Ordering[InternalType]]
override private[sql] val exactNumeric = ShortExactNumeric
/**
* The default size of a value of the ShortType is 2 bytes.

View file

@ -0,0 +1,110 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.types
import scala.math.Numeric.{ByteIsIntegral, IntIsIntegral, LongIsIntegral, ShortIsIntegral}
import scala.math.Ordering
object ByteExactNumeric extends ByteIsIntegral with Ordering.ByteOrdering {
private def checkOverflow(res: Int, x: Byte, y: Byte, op: String): Unit = {
if (res > Byte.MaxValue || res < Byte.MinValue) {
throw new ArithmeticException(s"$x $op $y caused overflow.")
}
}
override def plus(x: Byte, y: Byte): Byte = {
val tmp = x + y
checkOverflow(tmp, x, y, "+")
tmp.toByte
}
override def minus(x: Byte, y: Byte): Byte = {
val tmp = x - y
checkOverflow(tmp, x, y, "-")
tmp.toByte
}
override def times(x: Byte, y: Byte): Byte = {
val tmp = x * y
checkOverflow(tmp, x, y, "*")
tmp.toByte
}
override def negate(x: Byte): Byte = {
if (x == Byte.MinValue) { // if and only if x is Byte.MinValue, overflow can happen
throw new ArithmeticException(s"- $x caused overflow.")
}
(-x).toByte
}
}
object ShortExactNumeric extends ShortIsIntegral with Ordering.ShortOrdering {
private def checkOverflow(res: Int, x: Short, y: Short, op: String): Unit = {
if (res > Short.MaxValue || res < Short.MinValue) {
throw new ArithmeticException(s"$x $op $y caused overflow.")
}
}
override def plus(x: Short, y: Short): Short = {
val tmp = x + y
checkOverflow(tmp, x, y, "+")
tmp.toShort
}
override def minus(x: Short, y: Short): Short = {
val tmp = x - y
checkOverflow(tmp, x, y, "-")
tmp.toShort
}
override def times(x: Short, y: Short): Short = {
val tmp = x * y
checkOverflow(tmp, x, y, "*")
tmp.toShort
}
override def negate(x: Short): Short = {
if (x == Short.MinValue) { // if and only if x is Byte.MinValue, overflow can happen
throw new ArithmeticException(s"- $x caused overflow.")
}
(-x).toShort
}
}
object IntegerExactNumeric extends IntIsIntegral with Ordering.IntOrdering {
override def plus(x: Int, y: Int): Int = Math.addExact(x, y)
override def minus(x: Int, y: Int): Int = Math.subtractExact(x, y)
override def times(x: Int, y: Int): Int = Math.multiplyExact(x, y)
override def negate(x: Int): Int = Math.negateExact(x)
}
object LongExactNumeric extends LongIsIntegral with Ordering.LongOrdering {
override def plus(x: Long, y: Long): Long = Math.addExact(x, y)
override def minus(x: Long, y: Long): Long = Math.subtractExact(x, y)
override def times(x: Long, y: Long): Long = Math.multiplyExact(x, y)
override def negate(x: Long): Long = Math.negateExact(x)
}

View file

@ -59,8 +59,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(Add(positiveIntLit, negativeIntLit), -1)
checkEvaluation(Add(positiveLongLit, negativeLongLit), -1L)
DataTypeTestUtils.numericAndInterval.foreach { tpe =>
checkConsistencyBetweenInterpretedAndCodegen(Add, tpe, tpe)
Seq("true", "false").foreach { checkOverflow =>
withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> checkOverflow) {
DataTypeTestUtils.numericAndInterval.foreach { tpe =>
checkConsistencyBetweenInterpretedAndCodegenAllowingException(Add, tpe, tpe)
}
}
}
}
@ -75,6 +79,22 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(UnaryMinus(Literal(Int.MinValue)), Int.MinValue)
checkEvaluation(UnaryMinus(Literal(Short.MinValue)), Short.MinValue)
checkEvaluation(UnaryMinus(Literal(Byte.MinValue)), Byte.MinValue)
withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "true") {
checkExceptionInExpression[ArithmeticException](
UnaryMinus(Literal(Long.MinValue)), "overflow")
checkExceptionInExpression[ArithmeticException](
UnaryMinus(Literal(Int.MinValue)), "overflow")
checkExceptionInExpression[ArithmeticException](
UnaryMinus(Literal(Short.MinValue)), "overflow")
checkExceptionInExpression[ArithmeticException](
UnaryMinus(Literal(Byte.MinValue)), "overflow")
checkEvaluation(UnaryMinus(positiveShortLit), (- positiveShort).toShort)
checkEvaluation(UnaryMinus(negativeShortLit), (- negativeShort).toShort)
checkEvaluation(UnaryMinus(positiveIntLit), - positiveInt)
checkEvaluation(UnaryMinus(negativeIntLit), - negativeInt)
checkEvaluation(UnaryMinus(positiveLongLit), - positiveLong)
checkEvaluation(UnaryMinus(negativeLongLit), - negativeLong)
}
checkEvaluation(UnaryMinus(positiveShortLit), (- positiveShort).toShort)
checkEvaluation(UnaryMinus(negativeShortLit), (- negativeShort).toShort)
checkEvaluation(UnaryMinus(positiveIntLit), - positiveInt)
@ -100,8 +120,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(Subtract(positiveIntLit, negativeIntLit), positiveInt - negativeInt)
checkEvaluation(Subtract(positiveLongLit, negativeLongLit), positiveLong - negativeLong)
DataTypeTestUtils.numericAndInterval.foreach { tpe =>
checkConsistencyBetweenInterpretedAndCodegen(Subtract, tpe, tpe)
Seq("true", "false").foreach { checkOverflow =>
withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> checkOverflow) {
DataTypeTestUtils.numericAndInterval.foreach { tpe =>
checkConsistencyBetweenInterpretedAndCodegenAllowingException(Subtract, tpe, tpe)
}
}
}
}
@ -118,8 +142,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(Multiply(positiveIntLit, negativeIntLit), positiveInt * negativeInt)
checkEvaluation(Multiply(positiveLongLit, negativeLongLit), positiveLong * negativeLong)
DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe =>
checkConsistencyBetweenInterpretedAndCodegen(Multiply, tpe, tpe)
Seq("true", "false").foreach { checkOverflow =>
withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> checkOverflow) {
DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe =>
checkConsistencyBetweenInterpretedAndCodegenAllowingException(Multiply, tpe, tpe)
}
}
}
}
@ -376,4 +404,100 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
Greatest(Seq(Literal(1), Literal(1))).genCode(ctx2)
assert(ctx2.inlinedMutableStates.size == 1)
}
test("SPARK-24598: overflow on long returns wrong result") {
val maxLongLiteral = Literal(Long.MaxValue)
val minLongLiteral = Literal(Long.MinValue)
val e1 = Add(maxLongLiteral, Literal(1L))
val e2 = Subtract(maxLongLiteral, Literal(-1L))
val e3 = Multiply(maxLongLiteral, Literal(2L))
val e4 = Add(minLongLiteral, minLongLiteral)
val e5 = Subtract(minLongLiteral, maxLongLiteral)
val e6 = Multiply(minLongLiteral, minLongLiteral)
withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "true") {
Seq(e1, e2, e3, e4, e5, e6).foreach { e =>
checkExceptionInExpression[ArithmeticException](e, "overflow")
}
}
withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "false") {
checkEvaluation(e1, Long.MinValue)
checkEvaluation(e2, Long.MinValue)
checkEvaluation(e3, -2L)
checkEvaluation(e4, 0L)
checkEvaluation(e5, 1L)
checkEvaluation(e6, 0L)
}
}
test("SPARK-24598: overflow on integer returns wrong result") {
val maxIntLiteral = Literal(Int.MaxValue)
val minIntLiteral = Literal(Int.MinValue)
val e1 = Add(maxIntLiteral, Literal(1))
val e2 = Subtract(maxIntLiteral, Literal(-1))
val e3 = Multiply(maxIntLiteral, Literal(2))
val e4 = Add(minIntLiteral, minIntLiteral)
val e5 = Subtract(minIntLiteral, maxIntLiteral)
val e6 = Multiply(minIntLiteral, minIntLiteral)
withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "true") {
Seq(e1, e2, e3, e4, e5, e6).foreach { e =>
checkExceptionInExpression[ArithmeticException](e, "overflow")
}
}
withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "false") {
checkEvaluation(e1, Int.MinValue)
checkEvaluation(e2, Int.MinValue)
checkEvaluation(e3, -2)
checkEvaluation(e4, 0)
checkEvaluation(e5, 1)
checkEvaluation(e6, 0)
}
}
test("SPARK-24598: overflow on short returns wrong result") {
val maxShortLiteral = Literal(Short.MaxValue)
val minShortLiteral = Literal(Short.MinValue)
val e1 = Add(maxShortLiteral, Literal(1.toShort))
val e2 = Subtract(maxShortLiteral, Literal((-1).toShort))
val e3 = Multiply(maxShortLiteral, Literal(2.toShort))
val e4 = Add(minShortLiteral, minShortLiteral)
val e5 = Subtract(minShortLiteral, maxShortLiteral)
val e6 = Multiply(minShortLiteral, minShortLiteral)
withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "true") {
Seq(e1, e2, e3, e4, e5, e6).foreach { e =>
checkExceptionInExpression[ArithmeticException](e, "overflow")
}
}
withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "false") {
checkEvaluation(e1, Short.MinValue)
checkEvaluation(e2, Short.MinValue)
checkEvaluation(e3, (-2).toShort)
checkEvaluation(e4, 0.toShort)
checkEvaluation(e5, 1.toShort)
checkEvaluation(e6, 0.toShort)
}
}
test("SPARK-24598: overflow on byte returns wrong result") {
val maxByteLiteral = Literal(Byte.MaxValue)
val minByteLiteral = Literal(Byte.MinValue)
val e1 = Add(maxByteLiteral, Literal(1.toByte))
val e2 = Subtract(maxByteLiteral, Literal((-1).toByte))
val e3 = Multiply(maxByteLiteral, Literal(2.toByte))
val e4 = Add(minByteLiteral, minByteLiteral)
val e5 = Subtract(minByteLiteral, maxByteLiteral)
val e6 = Multiply(minByteLiteral, minByteLiteral)
withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "true") {
Seq(e1, e2, e3, e4, e5, e6).foreach { e =>
checkExceptionInExpression[ArithmeticException](e, "overflow")
}
}
withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "false") {
checkEvaluation(e1, Byte.MinValue)
checkEvaluation(e2, Byte.MinValue)
checkEvaluation(e3, (-2).toByte)
checkEvaluation(e4, 0.toByte)
checkEvaluation(e5, 1.toByte)
checkEvaluation(e6, 0.toByte)
}
}
}

View file

@ -359,6 +359,26 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
}
}
/**
* Test evaluation results between Interpreted mode and Codegen mode, making sure we have
* consistent result regardless of the evaluation method we use. If an exception is thrown,
* it checks that both modes throw the same exception.
*
* This method test against binary expressions by feeding them arbitrary literals of `dataType1`
* and `dataType2`.
*/
def checkConsistencyBetweenInterpretedAndCodegenAllowingException(
c: (Expression, Expression) => Expression,
dataType1: DataType,
dataType2: DataType): Unit = {
forAll (
LiteralGenerator.randomGen(dataType1),
LiteralGenerator.randomGen(dataType2)
) { (l1: Literal, l2: Literal) =>
cmpInterpretWithCodegen(EmptyRow, c(l1, l2), true)
}
}
/**
* Test evaluation results between Interpreted mode and Codegen mode, making sure we have
* consistent result regardless of the evaluation method we use.
@ -398,21 +418,52 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
}
}
def cmpInterpretWithCodegen(inputRow: InternalRow, expr: Expression): Unit = {
val interpret = try {
evaluateWithoutCodegen(expr, inputRow)
def cmpInterpretWithCodegen(
inputRow: InternalRow,
expr: Expression,
exceptionAllowed: Boolean = false): Unit = {
val (interpret, interpretExc) = try {
(Some(evaluateWithoutCodegen(expr, inputRow)), None)
} catch {
case e: Exception => fail(s"Exception evaluating $expr", e)
case e: Exception => if (exceptionAllowed) {
(None, Some(e))
} else {
fail(s"Exception evaluating $expr", e)
}
}
val plan = generateProject(
GenerateMutableProjection.generate(Alias(expr, s"Optimized($expr)")() :: Nil),
expr)
val codegen = plan(inputRow).get(0, expr.dataType)
if (!compareResults(interpret, codegen)) {
fail(s"Incorrect evaluation: $expr, interpret: $interpret, codegen: $codegen")
val (codegen, codegenExc) = try {
(Some(plan(inputRow).get(0, expr.dataType)), None)
} catch {
case e: Exception => if (exceptionAllowed) {
(None, Some(e))
} else {
fail(s"Exception evaluating $expr", e)
}
}
if (interpret.isDefined && codegen.isDefined && !compareResults(interpret.get, codegen.get)) {
fail(s"Incorrect evaluation: $expr, interpret: ${interpret.get}, codegen: ${codegen.get}")
} else if (interpretExc.isDefined && codegenExc.isEmpty) {
fail(s"Incorrect evaluation: $expr, interpet threw exception ${interpretExc.get}")
} else if (interpretExc.isEmpty && codegenExc.isDefined) {
fail(s"Incorrect evaluation: $expr, codegen threw exception ${codegenExc.get}")
} else if (interpretExc.isDefined && codegenExc.isDefined
&& !compareExceptions(interpretExc.get, codegenExc.get)) {
fail(s"Different exception evaluating: $expr, " +
s"interpret: ${interpretExc.get}, codegen: ${codegenExc.get}")
}
}
/**
* Checks the equality between two exceptions. Returns true iff the two exceptions are instances
* of the same class and they have the same message.
*/
private[this] def compareExceptions(e1: Exception, e2: Exception): Boolean = {
e1.getClass == e2.getClass && e1.getMessage == e2.getMessage
}
/**

View file

@ -33,6 +33,10 @@ INSERT INTO INT4_TBL VALUES ('-2147483647');
-- INSERT INTO INT4_TBL(f1) VALUES ('123 5');
-- INSERT INTO INT4_TBL(f1) VALUES ('');
-- We cannot test this when failOnOverFlow=true here
-- because exception happens in the executors and the
-- output stacktrace cannot have an exact match
set spark.sql.arithmeticOperations.failOnOverFlow=false;
SELECT '' AS five, * FROM INT4_TBL;

View file

@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 53
-- Number of queries: 54
-- !query 0
@ -51,10 +51,18 @@ struct<>
-- !query 6
SELECT '' AS five, * FROM INT4_TBL
set spark.sql.arithmeticOperations.failOnOverFlow=false
-- !query 6 schema
struct<five:string,f1:int>
struct<key:string,value:string>
-- !query 6 output
spark.sql.arithmeticOperations.failOnOverFlow false
-- !query 7
SELECT '' AS five, * FROM INT4_TBL
-- !query 7 schema
struct<five:string,f1:int>
-- !query 7 output
-123456
-2147483647
0
@ -62,19 +70,8 @@ struct<five:string,f1:int>
2147483647
-- !query 7
SELECT '' AS four, i.* FROM INT4_TBL i WHERE i.f1 <> smallint('0')
-- !query 7 schema
struct<four:string,f1:int>
-- !query 7 output
-123456
-2147483647
123456
2147483647
-- !query 8
SELECT '' AS four, i.* FROM INT4_TBL i WHERE i.f1 <> int('0')
SELECT '' AS four, i.* FROM INT4_TBL i WHERE i.f1 <> smallint('0')
-- !query 8 schema
struct<four:string,f1:int>
-- !query 8 output
@ -85,15 +82,18 @@ struct<four:string,f1:int>
-- !query 9
SELECT '' AS one, i.* FROM INT4_TBL i WHERE i.f1 = smallint('0')
SELECT '' AS four, i.* FROM INT4_TBL i WHERE i.f1 <> int('0')
-- !query 9 schema
struct<one:string,f1:int>
struct<four:string,f1:int>
-- !query 9 output
0
-123456
-2147483647
123456
2147483647
-- !query 10
SELECT '' AS one, i.* FROM INT4_TBL i WHERE i.f1 = int('0')
SELECT '' AS one, i.* FROM INT4_TBL i WHERE i.f1 = smallint('0')
-- !query 10 schema
struct<one:string,f1:int>
-- !query 10 output
@ -101,16 +101,15 @@ struct<one:string,f1:int>
-- !query 11
SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 < smallint('0')
SELECT '' AS one, i.* FROM INT4_TBL i WHERE i.f1 = int('0')
-- !query 11 schema
struct<two:string,f1:int>
struct<one:string,f1:int>
-- !query 11 output
-123456
-2147483647
0
-- !query 12
SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 < int('0')
SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 < smallint('0')
-- !query 12 schema
struct<two:string,f1:int>
-- !query 12 output
@ -119,17 +118,16 @@ struct<two:string,f1:int>
-- !query 13
SELECT '' AS three, i.* FROM INT4_TBL i WHERE i.f1 <= smallint('0')
SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 < int('0')
-- !query 13 schema
struct<three:string,f1:int>
struct<two:string,f1:int>
-- !query 13 output
-123456
-2147483647
0
-- !query 14
SELECT '' AS three, i.* FROM INT4_TBL i WHERE i.f1 <= int('0')
SELECT '' AS three, i.* FROM INT4_TBL i WHERE i.f1 <= smallint('0')
-- !query 14 schema
struct<three:string,f1:int>
-- !query 14 output
@ -139,16 +137,17 @@ struct<three:string,f1:int>
-- !query 15
SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 > smallint('0')
SELECT '' AS three, i.* FROM INT4_TBL i WHERE i.f1 <= int('0')
-- !query 15 schema
struct<two:string,f1:int>
struct<three:string,f1:int>
-- !query 15 output
123456
2147483647
-123456
-2147483647
0
-- !query 16
SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 > int('0')
SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 > smallint('0')
-- !query 16 schema
struct<two:string,f1:int>
-- !query 16 output
@ -157,17 +156,16 @@ struct<two:string,f1:int>
-- !query 17
SELECT '' AS three, i.* FROM INT4_TBL i WHERE i.f1 >= smallint('0')
SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 > int('0')
-- !query 17 schema
struct<three:string,f1:int>
struct<two:string,f1:int>
-- !query 17 output
0
123456
123456
2147483647
-- !query 18
SELECT '' AS three, i.* FROM INT4_TBL i WHERE i.f1 >= int('0')
SELECT '' AS three, i.* FROM INT4_TBL i WHERE i.f1 >= smallint('0')
-- !query 18 schema
struct<three:string,f1:int>
-- !query 18 output
@ -177,51 +175,61 @@ struct<three:string,f1:int>
-- !query 19
SELECT '' AS one, i.* FROM INT4_TBL i WHERE (i.f1 % smallint('2')) = smallint('1')
SELECT '' AS three, i.* FROM INT4_TBL i WHERE i.f1 >= int('0')
-- !query 19 schema
struct<one:string,f1:int>
struct<three:string,f1:int>
-- !query 19 output
2147483647
0
123456
2147483647
-- !query 20
SELECT '' AS three, i.* FROM INT4_TBL i WHERE (i.f1 % int('2')) = smallint('0')
SELECT '' AS one, i.* FROM INT4_TBL i WHERE (i.f1 % smallint('2')) = smallint('1')
-- !query 20 schema
struct<three:string,f1:int>
struct<one:string,f1:int>
-- !query 20 output
2147483647
-- !query 21
SELECT '' AS three, i.* FROM INT4_TBL i WHERE (i.f1 % int('2')) = smallint('0')
-- !query 21 schema
struct<three:string,f1:int>
-- !query 21 output
-123456
0
123456
-- !query 21
SELECT '' AS five, i.f1, i.f1 * smallint('2') AS x FROM INT4_TBL i
-- !query 21 schema
struct<five:string,f1:int,x:int>
-- !query 21 output
-123456 -246912
-2147483647 2
0 0
123456 246912
2147483647 -2
-- !query 22
SELECT '' AS five, i.f1, i.f1 * smallint('2') AS x FROM INT4_TBL i
WHERE abs(f1) < 1073741824
-- !query 22 schema
struct<five:string,f1:int,x:int>
-- !query 22 output
-123456 -246912
-2147483647 2
0 0
123456 246912
2147483647 -2
-- !query 23
SELECT '' AS five, i.f1, i.f1 * smallint('2') AS x FROM INT4_TBL i
WHERE abs(f1) < 1073741824
-- !query 23 schema
struct<five:string,f1:int,x:int>
-- !query 23 output
-123456 -246912
0 0
123456 246912
-- !query 23
-- !query 24
SELECT '' AS five, i.f1, i.f1 * int('2') AS x FROM INT4_TBL i
-- !query 23 schema
-- !query 24 schema
struct<five:string,f1:int,x:int>
-- !query 23 output
-- !query 24 output
-123456 -246912
-2147483647 2
0 0
@ -229,32 +237,19 @@ struct<five:string,f1:int,x:int>
2147483647 -2
-- !query 24
-- !query 25
SELECT '' AS five, i.f1, i.f1 * int('2') AS x FROM INT4_TBL i
WHERE abs(f1) < 1073741824
-- !query 24 schema
-- !query 25 schema
struct<five:string,f1:int,x:int>
-- !query 24 output
-- !query 25 output
-123456 -246912
0 0
123456 246912
-- !query 25
SELECT '' AS five, i.f1, i.f1 + smallint('2') AS x FROM INT4_TBL i
-- !query 25 schema
struct<five:string,f1:int,x:int>
-- !query 25 output
-123456 -123454
-2147483647 -2147483645
0 2
123456 123458
2147483647 -2147483647
-- !query 26
SELECT '' AS five, i.f1, i.f1 + smallint('2') AS x FROM INT4_TBL i
WHERE f1 < 2147483646
-- !query 26 schema
struct<five:string,f1:int,x:int>
-- !query 26 output
@ -262,10 +257,12 @@ struct<five:string,f1:int,x:int>
-2147483647 -2147483645
0 2
123456 123458
2147483647 -2147483647
-- !query 27
SELECT '' AS five, i.f1, i.f1 + int('2') AS x FROM INT4_TBL i
SELECT '' AS five, i.f1, i.f1 + smallint('2') AS x FROM INT4_TBL i
WHERE f1 < 2147483646
-- !query 27 schema
struct<five:string,f1:int,x:int>
-- !query 27 output
@ -273,12 +270,10 @@ struct<five:string,f1:int,x:int>
-2147483647 -2147483645
0 2
123456 123458
2147483647 -2147483647
-- !query 28
SELECT '' AS five, i.f1, i.f1 + int('2') AS x FROM INT4_TBL i
WHERE f1 < 2147483646
-- !query 28 schema
struct<five:string,f1:int,x:int>
-- !query 28 output
@ -286,39 +281,40 @@ struct<five:string,f1:int,x:int>
-2147483647 -2147483645
0 2
123456 123458
2147483647 -2147483647
-- !query 29
SELECT '' AS five, i.f1, i.f1 - smallint('2') AS x FROM INT4_TBL i
SELECT '' AS five, i.f1, i.f1 + int('2') AS x FROM INT4_TBL i
WHERE f1 < 2147483646
-- !query 29 schema
struct<five:string,f1:int,x:int>
-- !query 29 output
-123456 -123458
-2147483647 2147483647
0 -2
123456 123454
2147483647 2147483645
-123456 -123454
-2147483647 -2147483645
0 2
123456 123458
-- !query 30
SELECT '' AS five, i.f1, i.f1 - smallint('2') AS x FROM INT4_TBL i
WHERE f1 > -2147483647
-- !query 30 schema
struct<five:string,f1:int,x:int>
-- !query 30 output
-123456 -123458
-2147483647 2147483647
0 -2
123456 123454
2147483647 2147483645
-- !query 31
SELECT '' AS five, i.f1, i.f1 - int('2') AS x FROM INT4_TBL i
SELECT '' AS five, i.f1, i.f1 - smallint('2') AS x FROM INT4_TBL i
WHERE f1 > -2147483647
-- !query 31 schema
struct<five:string,f1:int,x:int>
-- !query 31 output
-123456 -123458
-2147483647 2147483647
0 -2
123456 123454
2147483647 2147483645
@ -326,30 +322,30 @@ struct<five:string,f1:int,x:int>
-- !query 32
SELECT '' AS five, i.f1, i.f1 - int('2') AS x FROM INT4_TBL i
WHERE f1 > -2147483647
-- !query 32 schema
struct<five:string,f1:int,x:int>
-- !query 32 output
-123456 -123458
-2147483647 2147483647
0 -2
123456 123454
2147483647 2147483645
-- !query 33
SELECT '' AS five, i.f1, i.f1 / smallint('2') AS x FROM INT4_TBL i
SELECT '' AS five, i.f1, i.f1 - int('2') AS x FROM INT4_TBL i
WHERE f1 > -2147483647
-- !query 33 schema
struct<five:string,f1:int,x:int>
-- !query 33 output
-123456 -61728
-2147483647 -1073741823
0 0
123456 61728
2147483647 1073741823
-123456 -123458
0 -2
123456 123454
2147483647 2147483645
-- !query 34
SELECT '' AS five, i.f1, i.f1 / int('2') AS x FROM INT4_TBL i
SELECT '' AS five, i.f1, i.f1 / smallint('2') AS x FROM INT4_TBL i
-- !query 34 schema
struct<five:string,f1:int,x:int>
-- !query 34 output
@ -361,47 +357,51 @@ struct<five:string,f1:int,x:int>
-- !query 35
SELECT -2+3 AS one
SELECT '' AS five, i.f1, i.f1 / int('2') AS x FROM INT4_TBL i
-- !query 35 schema
struct<one:int>
struct<five:string,f1:int,x:int>
-- !query 35 output
1
-123456 -61728
-2147483647 -1073741823
0 0
123456 61728
2147483647 1073741823
-- !query 36
SELECT 4-2 AS two
SELECT -2+3 AS one
-- !query 36 schema
struct<two:int>
struct<one:int>
-- !query 36 output
2
1
-- !query 37
SELECT 2- -1 AS three
SELECT 4-2 AS two
-- !query 37 schema
struct<three:int>
struct<two:int>
-- !query 37 output
3
2
-- !query 38
SELECT 2 - -2 AS four
SELECT 2- -1 AS three
-- !query 38 schema
struct<four:int>
struct<three:int>
-- !query 38 output
4
3
-- !query 39
SELECT smallint('2') * smallint('2') = smallint('16') / smallint('4') AS true
SELECT 2 - -2 AS four
-- !query 39 schema
struct<true:boolean>
struct<four:int>
-- !query 39 output
true
4
-- !query 40
SELECT int('2') * smallint('2') = smallint('16') / int('4') AS true
SELECT smallint('2') * smallint('2') = smallint('16') / smallint('4') AS true
-- !query 40 schema
struct<true:boolean>
-- !query 40 output
@ -409,7 +409,7 @@ true
-- !query 41
SELECT smallint('2') * int('2') = int('16') / smallint('4') AS true
SELECT int('2') * smallint('2') = smallint('16') / int('4') AS true
-- !query 41 schema
struct<true:boolean>
-- !query 41 output
@ -417,70 +417,78 @@ true
-- !query 42
SELECT int('1000') < int('999') AS `false`
SELECT smallint('2') * int('2') = int('16') / smallint('4') AS true
-- !query 42 schema
struct<false:boolean>
struct<true:boolean>
-- !query 42 output
false
true
-- !query 43
SELECT 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 AS ten
SELECT int('1000') < int('999') AS `false`
-- !query 43 schema
struct<ten:int>
struct<false:boolean>
-- !query 43 output
10
false
-- !query 44
SELECT 2 + 2 / 2 AS three
SELECT 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 AS ten
-- !query 44 schema
struct<three:int>
struct<ten:int>
-- !query 44 output
3
10
-- !query 45
SELECT (2 + 2) / 2 AS two
SELECT 2 + 2 / 2 AS three
-- !query 45 schema
struct<two:int>
struct<three:int>
-- !query 45 output
2
3
-- !query 46
SELECT string(shiftleft(int(-1), 31))
SELECT (2 + 2) / 2 AS two
-- !query 46 schema
struct<CAST(shiftleft(CAST(-1 AS INT), 31) AS STRING):string>
struct<two:int>
-- !query 46 output
-2147483648
2
-- !query 47
SELECT string(int(shiftleft(int(-1), 31))+1)
SELECT string(shiftleft(int(-1), 31))
-- !query 47 schema
struct<CAST((CAST(shiftleft(CAST(-1 AS INT), 31) AS INT) + 1) AS STRING):string>
struct<CAST(shiftleft(CAST(-1 AS INT), 31) AS STRING):string>
-- !query 47 output
-2147483647
-2147483648
-- !query 48
SELECT int(-2147483648) % int(-1)
SELECT string(int(shiftleft(int(-1), 31))+1)
-- !query 48 schema
struct<(CAST(-2147483648 AS INT) % CAST(-1 AS INT)):int>
struct<CAST((CAST(shiftleft(CAST(-1 AS INT), 31) AS INT) + 1) AS STRING):string>
-- !query 48 output
0
-2147483647
-- !query 49
SELECT int(-2147483648) % smallint(-1)
SELECT int(-2147483648) % int(-1)
-- !query 49 schema
struct<(CAST(-2147483648 AS INT) % CAST(CAST(-1 AS SMALLINT) AS INT)):int>
struct<(CAST(-2147483648 AS INT) % CAST(-1 AS INT)):int>
-- !query 49 output
0
-- !query 50
SELECT int(-2147483648) % smallint(-1)
-- !query 50 schema
struct<(CAST(-2147483648 AS INT) % CAST(CAST(-1 AS SMALLINT) AS INT)):int>
-- !query 50 output
0
-- !query 51
SELECT x, int(x) AS int4_value
FROM (VALUES double(-2.5),
double(-1.5),
@ -489,9 +497,9 @@ FROM (VALUES double(-2.5),
double(0.5),
double(1.5),
double(2.5)) t(x)
-- !query 50 schema
-- !query 51 schema
struct<x:double,int4_value:int>
-- !query 50 output
-- !query 51 output
-0.5 0
-1.5 -1
-2.5 -2
@ -501,7 +509,7 @@ struct<x:double,int4_value:int>
2.5 2
-- !query 51
-- !query 52
SELECT x, int(x) AS int4_value
FROM (VALUES cast(-2.5 as decimal(38, 18)),
cast(-1.5 as decimal(38, 18)),
@ -510,9 +518,9 @@ FROM (VALUES cast(-2.5 as decimal(38, 18)),
cast(0.5 as decimal(38, 18)),
cast(1.5 as decimal(38, 18)),
cast(2.5 as decimal(38, 18))) t(x)
-- !query 51 schema
-- !query 52 schema
struct<x:decimal(38,18),int4_value:int>
-- !query 51 output
-- !query 52 output
-0.5 0
-1.5 -1
-2.5 -2
@ -522,9 +530,9 @@ struct<x:decimal(38,18),int4_value:int>
2.5 2
-- !query 52
-- !query 53
DROP TABLE INT4_TBL
-- !query 52 schema
-- !query 53 schema
struct<>
-- !query 52 output
-- !query 53 output

View file

@ -113,11 +113,7 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall
val random = new Random(seed)
def randomBound(): Long = {
val n = if (random.nextBoolean()) {
random.nextLong() % (Long.MaxValue / (100 * MAX_NUM_STEPS))
} else {
random.nextLong() / 2
}
val n = random.nextLong() % (Long.MaxValue / (100 * MAX_NUM_STEPS))
if (random.nextBoolean()) n else -n
}