From 9ea55fe771cd26b701ff72d9aa539570ca04cbfc Mon Sep 17 00:00:00 2001 From: Pablo Langa Date: Fri, 14 May 2021 12:40:36 +0800 Subject: [PATCH] [SPARK-35207][SQL] Normalize hash function behavior with negative zero (floating point types) ### What changes were proposed in this pull request? Generally, we would expect that x = y => hash( x ) = hash( y ). However +-0 hash to different values for floating point types. ``` scala> spark.sql("select hash(cast('0.0' as double)), hash(cast('-0.0' as double))").show +-------------------------+--------------------------+ |hash(CAST(0.0 AS DOUBLE))|hash(CAST(-0.0 AS DOUBLE))| +-------------------------+--------------------------+ | -1670924195| -853646085| +-------------------------+--------------------------+ scala> spark.sql("select cast('0.0' as double) == cast('-0.0' as double)").show +--------------------------------------------+ |(CAST(0.0 AS DOUBLE) = CAST(-0.0 AS DOUBLE))| +--------------------------------------------+ | true| +--------------------------------------------+ ``` Here is an extract from IEEE 754: > The two zeros are distinguishable arithmetically only by either division-byzero ( producing appropriately signed infinities ) or else by the CopySign function recommended by IEEE 754 /854. Infinities, SNaNs, NaNs and Subnormal numbers necessitate four more special cases From this, I deduce that the hash function must produce the same result for 0 and -0. ### Why are the changes needed? It is a correctness issue ### Does this PR introduce _any_ user-facing change? This changes only affect to the hash function applied to -0 value in float and double types ### How was this patch tested? Unit testing and manual testing Closes #32496 from planga82/feature/spark35207_hashnegativezero. Authored-by: Pablo Langa Signed-off-by: Wenchen Fan --- docs/sql-migration-guide.md | 2 ++ .../spark/sql/catalyst/expressions/hash.scala | 24 +++++++++++++++---- .../expressions/HashExpressionsSuite.scala | 11 +++++++++ 3 files changed, 33 insertions(+), 4 deletions(-) diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index 2ce42bc5a5..ff2ad0464e 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -87,6 +87,8 @@ license: | - In Spark 3.2, Spark supports `DayTimeIntervalType` and `YearMonthIntervalType` as inputs and outputs of `TRANSFORM` clause in Hive `SERDE` mode, the behavior is different between Hive `SERDE` mode and `ROW FORMAT DELIMITED` mode when these two types are used as inputs. In Hive `SERDE` mode, `DayTimeIntervalType` column is converted to `HiveIntervalDayTime`, its string format is `[-]?d h:m:s.n`, but in `ROW FORMAT DELIMITED` mode the format is `INTERVAL '[-]?d h:m:s.n' DAY TO TIME`. In Hive `SERDE` mode, `YearMonthIntervalType` column is converted to `HiveIntervalYearMonth`, its string format is `[-]?y-m`, but in `ROW FORMAT DELIMITED` mode the format is `INTERVAL '[-]?y-m' YEAR TO MONTH`. + - In Spark 3.2, `hash(0) == hash(-0)` for floating point types. Previously, different values were generated. + ## Upgrading from Spark SQL 3.0 to 3.1 - In Spark 3.1, statistical aggregation function includes `std`, `stddev`, `stddev_samp`, `variance`, `var_samp`, `skewness`, `kurtosis`, `covar_samp`, `corr` will return `NULL` instead of `Double.NaN` when `DivideByZero` occurs during expression evaluation, for example, when `stddev_samp` applied on a single element set. In Spark version 3.0 and earlier, it will return `Double.NaN` in such case. To restore the behavior before Spark 3.1, you can set `spark.sql.legacy.statisticalAggregate` to `true`. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index f3a8274318..65e7714a3d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -369,11 +369,25 @@ abstract class HashExpression[E] extends Expression { protected def genHashBoolean(input: String, result: String): String = genHashInt(s"$input ? 1 : 0", result) - protected def genHashFloat(input: String, result: String): String = - genHashInt(s"Float.floatToIntBits($input)", result) + protected def genHashFloat(input: String, result: String): String = { + s""" + |if($input == -0.0f) { + | ${genHashInt("0", result)} + |} else { + | ${genHashInt(s"Float.floatToIntBits($input)", result)} + |} + """.stripMargin + } - protected def genHashDouble(input: String, result: String): String = - genHashLong(s"Double.doubleToLongBits($input)", result) + protected def genHashDouble(input: String, result: String): String = { + s""" + |if($input == -0.0d) { + | ${genHashLong("0L", result)} + |} else { + | ${genHashLong(s"Double.doubleToLongBits($input)", result)} + |} + """.stripMargin + } protected def genHashDecimal( ctx: CodegenContext, @@ -523,7 +537,9 @@ abstract class InterpretedHashFunction { case s: Short => hashInt(s, seed) case i: Int => hashInt(i, seed) case l: Long => hashLong(l, seed) + case f: Float if (f == -0.0f) => hashInt(0, seed) case f: Float => hashInt(java.lang.Float.floatToIntBits(f), seed) + case d: Double if (d == -0.0d) => hashLong(0L, seed) case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed) case d: Decimal => val precision = dataType.asInstanceOf[DecimalType].precision diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala index 858d8f78be..bd981d1633 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala @@ -708,6 +708,17 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(HiveHash(Seq(yearMonth)), 1234) } + test("SPARK-35207: Compute hash consistent between -0.0 and 0.0") { + def checkResult(exprs1: Expression, exprs2: Expression): Unit = { + checkEvaluation(Murmur3Hash(Seq(exprs1), 42), Murmur3Hash(Seq(exprs2), 42).eval()) + checkEvaluation(XxHash64(Seq(exprs1), 42), XxHash64(Seq(exprs2), 42).eval()) + checkEvaluation(HiveHash(Seq(exprs1)), HiveHash(Seq(exprs2)).eval()) + } + + checkResult(Literal.create(-0D, DoubleType), Literal.create(0D, DoubleType)) + checkResult(Literal.create(-0F, FloatType), Literal.create(0F, FloatType)) + } + private def testHash(inputSchema: StructType): Unit = { val inputGenerator = RandomDataGenerator.forType(inputSchema, nullable = false).get val toRow = RowEncoder(inputSchema).createSerializer()