From 331f2657d9451ac9de85f576953afde187ff9bab Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Tue, 13 Aug 2019 16:48:30 -0700 Subject: [PATCH] [SPARK-27768][SQL] Support Infinity/NaN-related float/double literals case-insensitively ## What changes were proposed in this pull request? Here is the problem description from the JIRA. ``` When the inputs contain the constant 'infinity', Spark SQL does not generate the expected results. SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE)) FROM (VALUES ('1'), (CAST('infinity' AS DOUBLE))) v(x); SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE)) FROM (VALUES ('infinity'), ('1')) v(x); SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE)) FROM (VALUES ('infinity'), ('infinity')) v(x); SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE)) FROM (VALUES ('-infinity'), ('infinity')) v(x); The root cause: Spark SQL does not recognize the special constants in a case insensitive way. In PostgreSQL, they are recognized in a case insensitive way. Link: https://www.postgresql.org/docs/9.3/datatype-numeric.html ``` In this PR, the casting code is enhanced to handle these `special` string literals in case insensitive manner. ## How was this patch tested? Added tests in CastSuite and modified existing test suites. Closes #25331 from dilipbiswal/double_infinity. Authored-by: Dilip Biswal Signed-off-by: Dongjoon Hyun --- docs/sql-migration-guide-upgrade.md | 89 +++++++++++++ .../spark/sql/catalyst/expressions/Cast.scala | 119 ++++++++++++------ .../sql/catalyst/expressions/CastSuite.scala | 26 ++++ .../inputs/pgSQL/aggregates_part1.sql | 10 +- .../sql-tests/inputs/pgSQL/float4.sql | 2 - .../sql-tests/inputs/pgSQL/float8.sql | 4 +- .../results/pgSQL/aggregates_part1.sql.out | 8 +- .../sql-tests/results/pgSQL/float4.sql.out | 10 +- .../sql-tests/results/pgSQL/float8.sql.out | 8 +- 9 files changed, 214 insertions(+), 62 deletions(-) diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index b2bd8cefc3..a643a843a5 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -161,6 +161,95 @@ license: | - Since Spark 3.0, Dataset query fails if it contains ambiguous column reference that is caused by self join. A typical example: `val df1 = ...; val df2 = df1.filter(...);`, then `df1.join(df2, df1("a") > df2("a"))` returns an empty result which is quite confusing. This is because Spark cannot resolve Dataset column references that point to tables being self joined, and `df1("a")` is exactly the same as `df2("a")` in Spark. To restore the behavior before Spark 3.0, you can set `spark.sql.analyzer.failAmbiguousSelfJoin` to `false`. + - Since Spark 3.0, `Cast` function processes string literals such as 'Infinity', '+Infinity', '-Infinity', 'NaN', 'Inf', '+Inf', '-Inf' in case insensitive manner when casting the literals to `Double` or `Float` type to ensure greater compatibility with other database systems. This behaviour change is illustrated in the table below: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ Operation + + Result prior to Spark 3.0 + + Result starting Spark 3.0 +
+ CAST('infinity' AS DOUBLE)
+ CAST('+infinity' AS DOUBLE)
+ CAST('inf' AS DOUBLE)
+ CAST('+inf' AS DOUBLE)
+
+ NULL + + Double.PositiveInfinity +
+ CAST('-infinity' AS DOUBLE)
+ CAST('-inf' AS DOUBLE)
+
+ NULL + + Double.NegativeInfinity +
+ CAST('infinity' AS FLOAT)
+ CAST('+infinity' AS FLOAT)
+ CAST('inf' AS FLOAT)
+ CAST('+inf' AS FLOAT)
+
+ NULL + + Float.PositiveInfinity +
+ CAST('-infinity' AS FLOAT)
+ CAST('-inf' AS FLOAT)
+
+ NULL + + Float.NegativeInfinity +
+ CAST('nan' AS DOUBLE) + + NULL + + Double.NaN +
+ CAST('nan' AS FLOAT) + + NULL + + Float.NaN +
+ ## Upgrading from Spark SQL 2.4 to 2.4.1 - The value of `spark.executor.heartbeatInterval`, when specified without units like "30" rather than "30s", was diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index a0cb5da078..32e2707948 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import java.math.{BigDecimal => JavaBigDecimal} import java.time.ZoneId +import java.util.Locale import java.util.concurrent.TimeUnit._ import org.apache.spark.SparkException @@ -193,6 +194,22 @@ object Cast { } def resolvableNullability(from: Boolean, to: Boolean): Boolean = !from || to + + /** + * We process literals such as 'Infinity', 'Inf', '-Infinity' and 'NaN' etc in case + * insensitive manner to be compatible with other database systems such as PostgreSQL and DB2. + */ + def processFloatingPointSpecialLiterals(v: String, isFloat: Boolean): Any = { + v.trim.toLowerCase(Locale.ROOT) match { + case "inf" | "+inf" | "infinity" | "+infinity" => + if (isFloat) Float.PositiveInfinity else Double.PositiveInfinity + case "-inf" | "-infinity" => + if (isFloat) Float.NegativeInfinity else Double.NegativeInfinity + case "nan" => + if (isFloat) Float.NaN else Double.NaN + case _ => null + } + } } /** @@ -563,8 +580,12 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String // DoubleConverter private[this] def castToDouble(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, s => try s.toString.toDouble catch { - case _: NumberFormatException => null + buildCast[UTF8String](_, s => { + val doubleStr = s.toString + try doubleStr.toDouble catch { + case _: NumberFormatException => + Cast.processFloatingPointSpecialLiterals(doubleStr, false) + } }) case BooleanType => buildCast[Boolean](_, b => if (b) 1d else 0d) @@ -579,8 +600,12 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String // FloatConverter private[this] def castToFloat(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, s => try s.toString.toFloat catch { - case _: NumberFormatException => null + buildCast[UTF8String](_, s => { + val floatStr = s.toString + try floatStr.toFloat catch { + case _: NumberFormatException => + Cast.processFloatingPointSpecialLiterals(floatStr, true) + } }) case BooleanType => buildCast[Boolean](_, b => if (b) 1f else 0f) @@ -718,9 +743,9 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String case ByteType => castToByteCode(from, ctx) case ShortType => castToShortCode(from, ctx) case IntegerType => castToIntCode(from, ctx) - case FloatType => castToFloatCode(from) + case FloatType => castToFloatCode(from, ctx) case LongType => castToLongCode(from, ctx) - case DoubleType => castToDoubleCode(from) + case DoubleType => castToDoubleCode(from, ctx) case array: ArrayType => castArrayCode(from.asInstanceOf[ArrayType].elementType, array.elementType, ctx) @@ -1260,48 +1285,66 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String (c, evPrim, evNull) => code"$evPrim = (long) $c;" } - private[this] def castToFloatCode(from: DataType): CastFunction = from match { - case StringType => - (c, evPrim, evNull) => - code""" + private[this] def castToFloatCode(from: DataType, ctx: CodegenContext): CastFunction = { + from match { + case StringType => + val floatStr = ctx.freshVariable("floatStr", StringType) + (c, evPrim, evNull) => + code""" + final String $floatStr = $c.toString(); try { - $evPrim = Float.valueOf($c.toString()); + $evPrim = Float.valueOf($floatStr); } catch (java.lang.NumberFormatException e) { - $evNull = true; + final Float f = (Float) Cast.processFloatingPointSpecialLiterals($floatStr, true); + if (f == null) { + $evNull = true; + } else { + $evPrim = f.floatValue(); + } } """ - case BooleanType => - (c, evPrim, evNull) => code"$evPrim = $c ? 1.0f : 0.0f;" - case DateType => - (c, evPrim, evNull) => code"$evNull = true;" - case TimestampType => - (c, evPrim, evNull) => code"$evPrim = (float) (${timestampToDoubleCode(c)});" - case DecimalType() => - (c, evPrim, evNull) => code"$evPrim = $c.toFloat();" - case x: NumericType => - (c, evPrim, evNull) => code"$evPrim = (float) $c;" + case BooleanType => + (c, evPrim, evNull) => code"$evPrim = $c ? 1.0f : 0.0f;" + case DateType => + (c, evPrim, evNull) => code"$evNull = true;" + case TimestampType => + (c, evPrim, evNull) => code"$evPrim = (float) (${timestampToDoubleCode(c)});" + case DecimalType() => + (c, evPrim, evNull) => code"$evPrim = $c.toFloat();" + case x: NumericType => + (c, evPrim, evNull) => code"$evPrim = (float) $c;" + } } - private[this] def castToDoubleCode(from: DataType): CastFunction = from match { - case StringType => - (c, evPrim, evNull) => - code""" + private[this] def castToDoubleCode(from: DataType, ctx: CodegenContext): CastFunction = { + from match { + case StringType => + val doubleStr = ctx.freshVariable("doubleStr", StringType) + (c, evPrim, evNull) => + code""" + final String $doubleStr = $c.toString(); try { - $evPrim = Double.valueOf($c.toString()); + $evPrim = Double.valueOf($doubleStr); } catch (java.lang.NumberFormatException e) { - $evNull = true; + final Double d = (Double) Cast.processFloatingPointSpecialLiterals($doubleStr, false); + if (d == null) { + $evNull = true; + } else { + $evPrim = d.doubleValue(); + } } """ - case BooleanType => - (c, evPrim, evNull) => code"$evPrim = $c ? 1.0d : 0.0d;" - case DateType => - (c, evPrim, evNull) => code"$evNull = true;" - case TimestampType => - (c, evPrim, evNull) => code"$evPrim = ${timestampToDoubleCode(c)};" - case DecimalType() => - (c, evPrim, evNull) => code"$evPrim = $c.toDouble();" - case x: NumericType => - (c, evPrim, evNull) => code"$evPrim = (double) $c;" + case BooleanType => + (c, evPrim, evNull) => code"$evPrim = $c ? 1.0d : 0.0d;" + case DateType => + (c, evPrim, evNull) => code"$evNull = true;" + case TimestampType => + (c, evPrim, evNull) => code"$evPrim = ${timestampToDoubleCode(c)};" + case DecimalType() => + (c, evPrim, evNull) => code"$evPrim = $c.toDouble();" + case x: NumericType => + (c, evPrim, evNull) => code"$evPrim = (double) $c;" + } } private[this] def castArrayCode( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 69adb8e922..1f9fa22d30 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -1049,4 +1049,30 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { Cast(Literal(134.12), DecimalType(3, 2)), "cannot be represented") } } + + test("Process Infinity, -Infinity, NaN in case insensitive manner") { + Seq("inf", "+inf", "infinity", "+infiNity", " infinity ").foreach { value => + checkEvaluation(cast(value, FloatType), Float.PositiveInfinity) + } + Seq("-infinity", "-infiniTy", " -infinity ", " -inf ").foreach { value => + checkEvaluation(cast(value, FloatType), Float.NegativeInfinity) + } + Seq("inf", "+inf", "infinity", "+infiNity", " infinity ").foreach { value => + checkEvaluation(cast(value, DoubleType), Double.PositiveInfinity) + } + Seq("-infinity", "-infiniTy", " -infinity ", " -inf ").foreach { value => + checkEvaluation(cast(value, DoubleType), Double.NegativeInfinity) + } + Seq("nan", "nAn", " nan ").foreach { value => + checkEvaluation(cast(value, FloatType), Float.NaN) + } + Seq("nan", "nAn", " nan ").foreach { value => + checkEvaluation(cast(value, DoubleType), Double.NaN) + } + + // Invalid literals when casted to double and float results in null. + Seq(DoubleType, FloatType).foreach { dataType => + checkEvaluation(cast("badvalue", dataType), null) + } + } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/pgSQL/aggregates_part1.sql b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/aggregates_part1.sql index 801a16cf41..5d54be9341 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/pgSQL/aggregates_part1.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/aggregates_part1.sql @@ -59,16 +59,14 @@ select avg(CAST(null AS DOUBLE)) from range(1,4); select sum(CAST('NaN' AS DOUBLE)) from range(1,4); select avg(CAST('NaN' AS DOUBLE)) from range(1,4); --- [SPARK-27768] verify correct results for infinite inputs SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE)) -FROM (VALUES (CAST('1' AS DOUBLE)), (CAST('Infinity' AS DOUBLE))) v(x); +FROM (VALUES (CAST('1' AS DOUBLE)), (CAST('infinity' AS DOUBLE))) v(x); SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE)) -FROM (VALUES ('Infinity'), ('1')) v(x); +FROM (VALUES ('infinity'), ('1')) v(x); SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE)) -FROM (VALUES ('Infinity'), ('Infinity')) v(x); +FROM (VALUES ('infinity'), ('infinity')) v(x); SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE)) -FROM (VALUES ('-Infinity'), ('Infinity')) v(x); - +FROM (VALUES ('-infinity'), ('infinity')) v(x); -- test accuracy with a large input offset SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE)) diff --git a/sql/core/src/test/resources/sql-tests/inputs/pgSQL/float4.sql b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/float4.sql index 3dad5cd56b..058467695a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/pgSQL/float4.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/float4.sql @@ -38,7 +38,6 @@ INSERT INTO FLOAT4_TBL VALUES ('1.2345678901234e-20'); -- special inputs SELECT float('NaN'); --- [SPARK-28060] Float type can not accept some special inputs SELECT float('nan'); SELECT float(' NAN '); SELECT float('infinity'); @@ -49,7 +48,6 @@ SELECT float('N A N'); SELECT float('NaN x'); SELECT float(' INFINITY x'); --- [SPARK-28060] Float type can not accept some special inputs SELECT float('Infinity') + 100.0; SELECT float('Infinity') / float('Infinity'); SELECT float('nan') / float('nan'); diff --git a/sql/core/src/test/resources/sql-tests/inputs/pgSQL/float8.sql b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/float8.sql index 6f8e3b596e..957dabdeba 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/pgSQL/float8.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/float8.sql @@ -37,7 +37,6 @@ SELECT double('-10e-400'); -- special inputs SELECT double('NaN'); --- [SPARK-28060] Double type can not accept some special inputs SELECT double('nan'); SELECT double(' NAN '); SELECT double('infinity'); @@ -49,7 +48,6 @@ SELECT double('NaN x'); SELECT double(' INFINITY x'); SELECT double('Infinity') + 100.0; --- [SPARK-27768] Infinity, -Infinity, NaN should be recognized in a case insensitive manner SELECT double('Infinity') / double('Infinity'); SELECT double('NaN') / double('NaN'); -- [SPARK-28315] Decimal can not accept NaN as input @@ -190,7 +188,7 @@ SELECT tanh(double('1')); SELECT asinh(double('1')); SELECT acosh(double('2')); SELECT atanh(double('0.5')); --- [SPARK-27768] Infinity, -Infinity, NaN should be recognized in a case insensitive manner + -- test Inf/NaN cases for hyperbolic functions SELECT sinh(double('Infinity')); SELECT sinh(double('-Infinity')); diff --git a/sql/core/src/test/resources/sql-tests/results/pgSQL/aggregates_part1.sql.out b/sql/core/src/test/resources/sql-tests/results/pgSQL/aggregates_part1.sql.out index 51ca1d5586..29bafb42f5 100644 --- a/sql/core/src/test/resources/sql-tests/results/pgSQL/aggregates_part1.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/pgSQL/aggregates_part1.sql.out @@ -236,7 +236,7 @@ NaN -- !query 29 SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE)) -FROM (VALUES (CAST('1' AS DOUBLE)), (CAST('Infinity' AS DOUBLE))) v(x) +FROM (VALUES (CAST('1' AS DOUBLE)), (CAST('infinity' AS DOUBLE))) v(x) -- !query 29 schema struct -- !query 29 output @@ -245,7 +245,7 @@ Infinity NaN -- !query 30 SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE)) -FROM (VALUES ('Infinity'), ('1')) v(x) +FROM (VALUES ('infinity'), ('1')) v(x) -- !query 30 schema struct -- !query 30 output @@ -254,7 +254,7 @@ Infinity NaN -- !query 31 SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE)) -FROM (VALUES ('Infinity'), ('Infinity')) v(x) +FROM (VALUES ('infinity'), ('infinity')) v(x) -- !query 31 schema struct -- !query 31 output @@ -263,7 +263,7 @@ Infinity NaN -- !query 32 SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE)) -FROM (VALUES ('-Infinity'), ('Infinity')) v(x) +FROM (VALUES ('-infinity'), ('infinity')) v(x) -- !query 32 schema struct -- !query 32 output diff --git a/sql/core/src/test/resources/sql-tests/results/pgSQL/float4.sql.out b/sql/core/src/test/resources/sql-tests/results/pgSQL/float4.sql.out index 86d88007d8..6e47cff91a 100644 --- a/sql/core/src/test/resources/sql-tests/results/pgSQL/float4.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/pgSQL/float4.sql.out @@ -63,7 +63,7 @@ SELECT float('nan') -- !query 7 schema struct -- !query 7 output -NULL +NaN -- !query 8 @@ -71,7 +71,7 @@ SELECT float(' NAN ') -- !query 8 schema struct -- !query 8 output -NULL +NaN -- !query 9 @@ -79,7 +79,7 @@ SELECT float('infinity') -- !query 9 schema struct -- !query 9 output -NULL +Infinity -- !query 10 @@ -87,7 +87,7 @@ SELECT float(' -INFINiTY ') -- !query 10 schema struct -- !query 10 output -NULL +-Infinity -- !query 11 @@ -135,7 +135,7 @@ SELECT float('nan') / float('nan') -- !query 16 schema struct<(CAST(CAST(nan AS FLOAT) AS DOUBLE) / CAST(CAST(nan AS FLOAT) AS DOUBLE)):double> -- !query 16 output -NULL +NaN -- !query 17 diff --git a/sql/core/src/test/resources/sql-tests/results/pgSQL/float8.sql.out b/sql/core/src/test/resources/sql-tests/results/pgSQL/float8.sql.out index eb9e8aa636..b4ea3c1ad1 100644 --- a/sql/core/src/test/resources/sql-tests/results/pgSQL/float8.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/pgSQL/float8.sql.out @@ -95,7 +95,7 @@ SELECT double('nan') -- !query 11 schema struct -- !query 11 output -NULL +NaN -- !query 12 @@ -103,7 +103,7 @@ SELECT double(' NAN ') -- !query 12 schema struct -- !query 12 output -NULL +NaN -- !query 13 @@ -111,7 +111,7 @@ SELECT double('infinity') -- !query 13 schema struct -- !query 13 output -NULL +Infinity -- !query 14 @@ -119,7 +119,7 @@ SELECT double(' -INFINiTY ') -- !query 14 schema struct -- !query 14 output -NULL +-Infinity -- !query 15