[SPARK-27768][SQL] Support Infinity/NaN-related float/double literals case-insensitively

## What changes were proposed in this pull request?
Here is the problem description from the JIRA.
```
When the inputs contain the constant 'infinity', Spark SQL does not generate the expected results.

SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE))
FROM (VALUES ('1'), (CAST('infinity' AS DOUBLE))) v(x);
SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE))
FROM (VALUES ('infinity'), ('1')) v(x);
SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE))
FROM (VALUES ('infinity'), ('infinity')) v(x);
SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE))
FROM (VALUES ('-infinity'), ('infinity')) v(x);
 The root cause: Spark SQL does not recognize the special constants in a case insensitive way. In PostgreSQL, they are recognized in a case insensitive way.

Link: https://www.postgresql.org/docs/9.3/datatype-numeric.html
```

In this PR, the casting code is enhanced to handle these `special` string literals in case insensitive manner.

## How was this patch tested?
Added tests in CastSuite and modified existing test suites.

Closes #25331 from dilipbiswal/double_infinity.

Authored-by: Dilip Biswal <dbiswal@us.ibm.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
This commit is contained in:
Dilip Biswal 2019-08-13 16:48:30 -07:00 committed by Dongjoon Hyun
parent f1d6b19de5
commit 331f2657d9
9 changed files with 214 additions and 62 deletions

View file

@ -161,6 +161,95 @@ license: |
- Since Spark 3.0, Dataset query fails if it contains ambiguous column reference that is caused by self join. A typical example: `val df1 = ...; val df2 = df1.filter(...);`, then `df1.join(df2, df1("a") > df2("a"))` returns an empty result which is quite confusing. This is because Spark cannot resolve Dataset column references that point to tables being self joined, and `df1("a")` is exactly the same as `df2("a")` in Spark. To restore the behavior before Spark 3.0, you can set `spark.sql.analyzer.failAmbiguousSelfJoin` to `false`.
- Since Spark 3.0, `Cast` function processes string literals such as 'Infinity', '+Infinity', '-Infinity', 'NaN', 'Inf', '+Inf', '-Inf' in case insensitive manner when casting the literals to `Double` or `Float` type to ensure greater compatibility with other database systems. This behaviour change is illustrated in the table below:
<table class="table">
<tr>
<th>
<b>Operation</b>
</th>
<th>
<b>Result prior to Spark 3.0</b>
</th>
<th>
<b>Result starting Spark 3.0</b>
</th>
</tr>
<tr>
<td>
CAST('infinity' AS DOUBLE)<br>
CAST('+infinity' AS DOUBLE)<br>
CAST('inf' AS DOUBLE)<br>
CAST('+inf' AS DOUBLE)<br>
</td>
<td>
NULL
</td>
<td>
Double.PositiveInfinity
</td>
</tr>
<tr>
<td>
CAST('-infinity' AS DOUBLE)<br>
CAST('-inf' AS DOUBLE)<br>
</td>
<td>
NULL
</td>
<td>
Double.NegativeInfinity
</td>
</tr>
<tr>
<td>
CAST('infinity' AS FLOAT)<br>
CAST('+infinity' AS FLOAT)<br>
CAST('inf' AS FLOAT)<br>
CAST('+inf' AS FLOAT)<br>
</td>
<td>
NULL
</td>
<td>
Float.PositiveInfinity
</td>
</tr>
<tr>
<td>
CAST('-infinity' AS FLOAT)<br>
CAST('-inf' AS FLOAT)<br>
</td>
<td>
NULL
</td>
<td>
Float.NegativeInfinity
</td>
</tr>
<tr>
<td>
CAST('nan' AS DOUBLE)
</td>
<td>
NULL
</td>
<td>
Double.NaN
</td>
</tr>
<tr>
<td>
CAST('nan' AS FLOAT)
</td>
<td>
NULL
</td>
<td>
Float.NaN
</td>
</tr>
</table>
## Upgrading from Spark SQL 2.4 to 2.4.1
- The value of `spark.executor.heartbeatInterval`, when specified without units like "30" rather than "30s", was

View file

@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import java.math.{BigDecimal => JavaBigDecimal}
import java.time.ZoneId
import java.util.Locale
import java.util.concurrent.TimeUnit._
import org.apache.spark.SparkException
@ -193,6 +194,22 @@ object Cast {
}
def resolvableNullability(from: Boolean, to: Boolean): Boolean = !from || to
/**
* We process literals such as 'Infinity', 'Inf', '-Infinity' and 'NaN' etc in case
* insensitive manner to be compatible with other database systems such as PostgreSQL and DB2.
*/
def processFloatingPointSpecialLiterals(v: String, isFloat: Boolean): Any = {
v.trim.toLowerCase(Locale.ROOT) match {
case "inf" | "+inf" | "infinity" | "+infinity" =>
if (isFloat) Float.PositiveInfinity else Double.PositiveInfinity
case "-inf" | "-infinity" =>
if (isFloat) Float.NegativeInfinity else Double.NegativeInfinity
case "nan" =>
if (isFloat) Float.NaN else Double.NaN
case _ => null
}
}
}
/**
@ -563,8 +580,12 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
// DoubleConverter
private[this] def castToDouble(from: DataType): Any => Any = from match {
case StringType =>
buildCast[UTF8String](_, s => try s.toString.toDouble catch {
case _: NumberFormatException => null
buildCast[UTF8String](_, s => {
val doubleStr = s.toString
try doubleStr.toDouble catch {
case _: NumberFormatException =>
Cast.processFloatingPointSpecialLiterals(doubleStr, false)
}
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1d else 0d)
@ -579,8 +600,12 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
// FloatConverter
private[this] def castToFloat(from: DataType): Any => Any = from match {
case StringType =>
buildCast[UTF8String](_, s => try s.toString.toFloat catch {
case _: NumberFormatException => null
buildCast[UTF8String](_, s => {
val floatStr = s.toString
try floatStr.toFloat catch {
case _: NumberFormatException =>
Cast.processFloatingPointSpecialLiterals(floatStr, true)
}
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1f else 0f)
@ -718,9 +743,9 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
case ByteType => castToByteCode(from, ctx)
case ShortType => castToShortCode(from, ctx)
case IntegerType => castToIntCode(from, ctx)
case FloatType => castToFloatCode(from)
case FloatType => castToFloatCode(from, ctx)
case LongType => castToLongCode(from, ctx)
case DoubleType => castToDoubleCode(from)
case DoubleType => castToDoubleCode(from, ctx)
case array: ArrayType =>
castArrayCode(from.asInstanceOf[ArrayType].elementType, array.elementType, ctx)
@ -1260,48 +1285,66 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
(c, evPrim, evNull) => code"$evPrim = (long) $c;"
}
private[this] def castToFloatCode(from: DataType): CastFunction = from match {
case StringType =>
(c, evPrim, evNull) =>
code"""
private[this] def castToFloatCode(from: DataType, ctx: CodegenContext): CastFunction = {
from match {
case StringType =>
val floatStr = ctx.freshVariable("floatStr", StringType)
(c, evPrim, evNull) =>
code"""
final String $floatStr = $c.toString();
try {
$evPrim = Float.valueOf($c.toString());
$evPrim = Float.valueOf($floatStr);
} catch (java.lang.NumberFormatException e) {
$evNull = true;
final Float f = (Float) Cast.processFloatingPointSpecialLiterals($floatStr, true);
if (f == null) {
$evNull = true;
} else {
$evPrim = f.floatValue();
}
}
"""
case BooleanType =>
(c, evPrim, evNull) => code"$evPrim = $c ? 1.0f : 0.0f;"
case DateType =>
(c, evPrim, evNull) => code"$evNull = true;"
case TimestampType =>
(c, evPrim, evNull) => code"$evPrim = (float) (${timestampToDoubleCode(c)});"
case DecimalType() =>
(c, evPrim, evNull) => code"$evPrim = $c.toFloat();"
case x: NumericType =>
(c, evPrim, evNull) => code"$evPrim = (float) $c;"
case BooleanType =>
(c, evPrim, evNull) => code"$evPrim = $c ? 1.0f : 0.0f;"
case DateType =>
(c, evPrim, evNull) => code"$evNull = true;"
case TimestampType =>
(c, evPrim, evNull) => code"$evPrim = (float) (${timestampToDoubleCode(c)});"
case DecimalType() =>
(c, evPrim, evNull) => code"$evPrim = $c.toFloat();"
case x: NumericType =>
(c, evPrim, evNull) => code"$evPrim = (float) $c;"
}
}
private[this] def castToDoubleCode(from: DataType): CastFunction = from match {
case StringType =>
(c, evPrim, evNull) =>
code"""
private[this] def castToDoubleCode(from: DataType, ctx: CodegenContext): CastFunction = {
from match {
case StringType =>
val doubleStr = ctx.freshVariable("doubleStr", StringType)
(c, evPrim, evNull) =>
code"""
final String $doubleStr = $c.toString();
try {
$evPrim = Double.valueOf($c.toString());
$evPrim = Double.valueOf($doubleStr);
} catch (java.lang.NumberFormatException e) {
$evNull = true;
final Double d = (Double) Cast.processFloatingPointSpecialLiterals($doubleStr, false);
if (d == null) {
$evNull = true;
} else {
$evPrim = d.doubleValue();
}
}
"""
case BooleanType =>
(c, evPrim, evNull) => code"$evPrim = $c ? 1.0d : 0.0d;"
case DateType =>
(c, evPrim, evNull) => code"$evNull = true;"
case TimestampType =>
(c, evPrim, evNull) => code"$evPrim = ${timestampToDoubleCode(c)};"
case DecimalType() =>
(c, evPrim, evNull) => code"$evPrim = $c.toDouble();"
case x: NumericType =>
(c, evPrim, evNull) => code"$evPrim = (double) $c;"
case BooleanType =>
(c, evPrim, evNull) => code"$evPrim = $c ? 1.0d : 0.0d;"
case DateType =>
(c, evPrim, evNull) => code"$evNull = true;"
case TimestampType =>
(c, evPrim, evNull) => code"$evPrim = ${timestampToDoubleCode(c)};"
case DecimalType() =>
(c, evPrim, evNull) => code"$evPrim = $c.toDouble();"
case x: NumericType =>
(c, evPrim, evNull) => code"$evPrim = (double) $c;"
}
}
private[this] def castArrayCode(

View file

@ -1049,4 +1049,30 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
Cast(Literal(134.12), DecimalType(3, 2)), "cannot be represented")
}
}
test("Process Infinity, -Infinity, NaN in case insensitive manner") {
Seq("inf", "+inf", "infinity", "+infiNity", " infinity ").foreach { value =>
checkEvaluation(cast(value, FloatType), Float.PositiveInfinity)
}
Seq("-infinity", "-infiniTy", " -infinity ", " -inf ").foreach { value =>
checkEvaluation(cast(value, FloatType), Float.NegativeInfinity)
}
Seq("inf", "+inf", "infinity", "+infiNity", " infinity ").foreach { value =>
checkEvaluation(cast(value, DoubleType), Double.PositiveInfinity)
}
Seq("-infinity", "-infiniTy", " -infinity ", " -inf ").foreach { value =>
checkEvaluation(cast(value, DoubleType), Double.NegativeInfinity)
}
Seq("nan", "nAn", " nan ").foreach { value =>
checkEvaluation(cast(value, FloatType), Float.NaN)
}
Seq("nan", "nAn", " nan ").foreach { value =>
checkEvaluation(cast(value, DoubleType), Double.NaN)
}
// Invalid literals when casted to double and float results in null.
Seq(DoubleType, FloatType).foreach { dataType =>
checkEvaluation(cast("badvalue", dataType), null)
}
}
}

View file

@ -59,16 +59,14 @@ select avg(CAST(null AS DOUBLE)) from range(1,4);
select sum(CAST('NaN' AS DOUBLE)) from range(1,4);
select avg(CAST('NaN' AS DOUBLE)) from range(1,4);
-- [SPARK-27768] verify correct results for infinite inputs
SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE))
FROM (VALUES (CAST('1' AS DOUBLE)), (CAST('Infinity' AS DOUBLE))) v(x);
FROM (VALUES (CAST('1' AS DOUBLE)), (CAST('infinity' AS DOUBLE))) v(x);
SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE))
FROM (VALUES ('Infinity'), ('1')) v(x);
FROM (VALUES ('infinity'), ('1')) v(x);
SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE))
FROM (VALUES ('Infinity'), ('Infinity')) v(x);
FROM (VALUES ('infinity'), ('infinity')) v(x);
SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE))
FROM (VALUES ('-Infinity'), ('Infinity')) v(x);
FROM (VALUES ('-infinity'), ('infinity')) v(x);
-- test accuracy with a large input offset
SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE))

View file

@ -38,7 +38,6 @@ INSERT INTO FLOAT4_TBL VALUES ('1.2345678901234e-20');
-- special inputs
SELECT float('NaN');
-- [SPARK-28060] Float type can not accept some special inputs
SELECT float('nan');
SELECT float(' NAN ');
SELECT float('infinity');
@ -49,7 +48,6 @@ SELECT float('N A N');
SELECT float('NaN x');
SELECT float(' INFINITY x');
-- [SPARK-28060] Float type can not accept some special inputs
SELECT float('Infinity') + 100.0;
SELECT float('Infinity') / float('Infinity');
SELECT float('nan') / float('nan');

View file

@ -37,7 +37,6 @@ SELECT double('-10e-400');
-- special inputs
SELECT double('NaN');
-- [SPARK-28060] Double type can not accept some special inputs
SELECT double('nan');
SELECT double(' NAN ');
SELECT double('infinity');
@ -49,7 +48,6 @@ SELECT double('NaN x');
SELECT double(' INFINITY x');
SELECT double('Infinity') + 100.0;
-- [SPARK-27768] Infinity, -Infinity, NaN should be recognized in a case insensitive manner
SELECT double('Infinity') / double('Infinity');
SELECT double('NaN') / double('NaN');
-- [SPARK-28315] Decimal can not accept NaN as input
@ -190,7 +188,7 @@ SELECT tanh(double('1'));
SELECT asinh(double('1'));
SELECT acosh(double('2'));
SELECT atanh(double('0.5'));
-- [SPARK-27768] Infinity, -Infinity, NaN should be recognized in a case insensitive manner
-- test Inf/NaN cases for hyperbolic functions
SELECT sinh(double('Infinity'));
SELECT sinh(double('-Infinity'));

View file

@ -236,7 +236,7 @@ NaN
-- !query 29
SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE))
FROM (VALUES (CAST('1' AS DOUBLE)), (CAST('Infinity' AS DOUBLE))) v(x)
FROM (VALUES (CAST('1' AS DOUBLE)), (CAST('infinity' AS DOUBLE))) v(x)
-- !query 29 schema
struct<avg(CAST(x AS DOUBLE)):double,var_pop(CAST(x AS DOUBLE)):double>
-- !query 29 output
@ -245,7 +245,7 @@ Infinity NaN
-- !query 30
SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE))
FROM (VALUES ('Infinity'), ('1')) v(x)
FROM (VALUES ('infinity'), ('1')) v(x)
-- !query 30 schema
struct<avg(CAST(x AS DOUBLE)):double,var_pop(CAST(x AS DOUBLE)):double>
-- !query 30 output
@ -254,7 +254,7 @@ Infinity NaN
-- !query 31
SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE))
FROM (VALUES ('Infinity'), ('Infinity')) v(x)
FROM (VALUES ('infinity'), ('infinity')) v(x)
-- !query 31 schema
struct<avg(CAST(x AS DOUBLE)):double,var_pop(CAST(x AS DOUBLE)):double>
-- !query 31 output
@ -263,7 +263,7 @@ Infinity NaN
-- !query 32
SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE))
FROM (VALUES ('-Infinity'), ('Infinity')) v(x)
FROM (VALUES ('-infinity'), ('infinity')) v(x)
-- !query 32 schema
struct<avg(CAST(x AS DOUBLE)):double,var_pop(CAST(x AS DOUBLE)):double>
-- !query 32 output

View file

@ -63,7 +63,7 @@ SELECT float('nan')
-- !query 7 schema
struct<CAST(nan AS FLOAT):float>
-- !query 7 output
NULL
NaN
-- !query 8
@ -71,7 +71,7 @@ SELECT float(' NAN ')
-- !query 8 schema
struct<CAST( NAN AS FLOAT):float>
-- !query 8 output
NULL
NaN
-- !query 9
@ -79,7 +79,7 @@ SELECT float('infinity')
-- !query 9 schema
struct<CAST(infinity AS FLOAT):float>
-- !query 9 output
NULL
Infinity
-- !query 10
@ -87,7 +87,7 @@ SELECT float(' -INFINiTY ')
-- !query 10 schema
struct<CAST( -INFINiTY AS FLOAT):float>
-- !query 10 output
NULL
-Infinity
-- !query 11
@ -135,7 +135,7 @@ SELECT float('nan') / float('nan')
-- !query 16 schema
struct<(CAST(CAST(nan AS FLOAT) AS DOUBLE) / CAST(CAST(nan AS FLOAT) AS DOUBLE)):double>
-- !query 16 output
NULL
NaN
-- !query 17

View file

@ -95,7 +95,7 @@ SELECT double('nan')
-- !query 11 schema
struct<CAST(nan AS DOUBLE):double>
-- !query 11 output
NULL
NaN
-- !query 12
@ -103,7 +103,7 @@ SELECT double(' NAN ')
-- !query 12 schema
struct<CAST( NAN AS DOUBLE):double>
-- !query 12 output
NULL
NaN
-- !query 13
@ -111,7 +111,7 @@ SELECT double('infinity')
-- !query 13 schema
struct<CAST(infinity AS DOUBLE):double>
-- !query 13 output
NULL
Infinity
-- !query 14
@ -119,7 +119,7 @@ SELECT double(' -INFINiTY ')
-- !query 14 schema
struct<CAST( -INFINiTY AS DOUBLE):double>
-- !query 14 output
NULL
-Infinity
-- !query 15