[SPARK-35207][SQL] Normalize hash function behavior with negative zero (floating point types)

### What changes were proposed in this pull request?

Generally, we would expect that `x = y` => `hash(x) = hash(y)`. However, +0.0 and -0.0 currently hash to different values for floating point types:
```
scala> spark.sql("select hash(cast('0.0' as double)), hash(cast('-0.0' as double))").show
+-------------------------+--------------------------+
|hash(CAST(0.0 AS DOUBLE))|hash(CAST(-0.0 AS DOUBLE))|
+-------------------------+--------------------------+
|              -1670924195|                -853646085|
+-------------------------+--------------------------+
scala> spark.sql("select cast('0.0' as double) == cast('-0.0' as double)").show
+--------------------------------------------+
|(CAST(0.0 AS DOUBLE) = CAST(-0.0 AS DOUBLE))|
+--------------------------------------------+
|                                        true|
+--------------------------------------------+
```
Here is an extract from IEEE 754:

> The two zeros are distinguishable arithmetically only by either division-by-zero (producing appropriately signed infinities) or else by the CopySign function recommended by IEEE 754/854. Infinities, SNaNs, NaNs and Subnormal numbers necessitate four more special cases.

From this, I deduce that the hash function must produce the same result for 0 and -0.
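
For reference, the differing hashes come from hashing the raw IEEE 754 bit pattern, and +0.0 and -0.0 differ only in the sign bit, so `floatToIntBits`/`doubleToLongBits` return different values for them (illustrative REPL session, not part of this patch):
```
scala> java.lang.Double.doubleToLongBits(0.0d)
res0: Long = 0

scala> java.lang.Double.doubleToLongBits(-0.0d)
res1: Long = -9223372036854775808

scala> java.lang.Float.floatToIntBits(-0.0f)
res2: Int = -2147483648
```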

### Why are the changes needed?

It is a correctness issue: two values that compare as equal should hash to the same value.

### Does this PR introduce _any_ user-facing change?

This change only affects the result of the hash function applied to the -0.0 value for float and double types.

### How was this patch tested?

Unit testing and manual testing

Closes #32496 from planga82/feature/spark35207_hashnegativezero.

Authored-by: Pablo Langa <soypab@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Pablo Langa, 2021-05-14 12:40:36 +08:00, committed by Wenchen Fan
commit 9ea55fe771 (parent f7af9ab8dc)
3 changed files with 33 additions and 4 deletions


@@ -87,6 +87,8 @@ license: |
   - In Spark 3.2, Spark supports `DayTimeIntervalType` and `YearMonthIntervalType` as inputs and outputs of `TRANSFORM` clause in Hive `SERDE` mode, the behavior is different between Hive `SERDE` mode and `ROW FORMAT DELIMITED` mode when these two types are used as inputs. In Hive `SERDE` mode, `DayTimeIntervalType` column is converted to `HiveIntervalDayTime`, its string format is `[-]?d h:m:s.n`, but in `ROW FORMAT DELIMITED` mode the format is `INTERVAL '[-]?d h:m:s.n' DAY TO TIME`. In Hive `SERDE` mode, `YearMonthIntervalType` column is converted to `HiveIntervalYearMonth`, its string format is `[-]?y-m`, but in `ROW FORMAT DELIMITED` mode the format is `INTERVAL '[-]?y-m' YEAR TO MONTH`.
+  - In Spark 3.2, `hash(0) == hash(-0)` for floating point types. Previously, different values were generated.
+
 ## Upgrading from Spark SQL 3.0 to 3.1

   - In Spark 3.1, statistical aggregation function includes `std`, `stddev`, `stddev_samp`, `variance`, `var_samp`, `skewness`, `kurtosis`, `covar_samp`, `corr` will return `NULL` instead of `Double.NaN` when `DivideByZero` occurs during expression evaluation, for example, when `stddev_samp` applied on a single element set. In Spark version 3.0 and earlier, it will return `Double.NaN` in such case. To restore the behavior before Spark 3.1, you can set `spark.sql.legacy.statisticalAggregate` to `true`.
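
For illustration, after this change the comparison below is expected to return `true` (sketch of the intended behavior, not a captured session):
```
scala> spark.sql("select hash(cast('0.0' as double)) == hash(cast('-0.0' as double))").show
```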


@@ -369,11 +369,25 @@ abstract class HashExpression[E] extends Expression {
   protected def genHashBoolean(input: String, result: String): String =
     genHashInt(s"$input ? 1 : 0", result)

-  protected def genHashFloat(input: String, result: String): String =
-    genHashInt(s"Float.floatToIntBits($input)", result)
+  protected def genHashFloat(input: String, result: String): String = {
+    s"""
+       |if($input == -0.0f) {
+       |  ${genHashInt("0", result)}
+       |} else {
+       |  ${genHashInt(s"Float.floatToIntBits($input)", result)}
+       |}
+     """.stripMargin
+  }

-  protected def genHashDouble(input: String, result: String): String =
-    genHashLong(s"Double.doubleToLongBits($input)", result)
+  protected def genHashDouble(input: String, result: String): String = {
+    s"""
+       |if($input == -0.0d) {
+       |  ${genHashLong("0L", result)}
+       |} else {
+       |  ${genHashLong(s"Double.doubleToLongBits($input)", result)}
+       |}
+     """.stripMargin
+  }

   protected def genHashDecimal(
       ctx: CodegenContext,
@@ -523,7 +537,9 @@ abstract class InterpretedHashFunction {
       case s: Short => hashInt(s, seed)
       case i: Int => hashInt(i, seed)
       case l: Long => hashLong(l, seed)
+      case f: Float if (f == -0.0f) => hashInt(0, seed)
       case f: Float => hashInt(java.lang.Float.floatToIntBits(f), seed)
+      case d: Double if (d == -0.0d) => hashLong(0L, seed)
       case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed)
       case d: Decimal =>
         val precision = dataType.asInstanceOf[DecimalType].precision
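
The same normalization can be seen outside of codegen with a small standalone sketch (not the PR's code; `doubleHashInput`/`floatHashInput` are hypothetical helpers mirroring the patched match cases). Note that because `-0.0 == 0.0` is true, the guard also matches positive zero, which is harmless since its bit pattern is already all zeros:
```scala
object NegativeZeroHashInput {
  // Hypothetical helper mirroring the patched Double case: -0.0 is mapped to
  // the bit pattern of +0.0 (0L) before it is fed to the hash function.
  def doubleHashInput(d: Double): Long =
    if (d == -0.0d) 0L else java.lang.Double.doubleToLongBits(d)

  // Same idea for Float: -0.0f is mapped to 0 (the bits of +0.0f).
  def floatHashInput(f: Float): Int =
    if (f == -0.0f) 0 else java.lang.Float.floatToIntBits(f)

  def main(args: Array[String]): Unit = {
    // Both zeros now produce identical hash inputs, so any downstream hash
    // (Murmur3, xxHash64, HiveHash) produces identical results for them.
    assert(doubleHashInput(0.0d) == doubleHashInput(-0.0d))
    assert(floatHashInput(0.0f) == floatHashInput(-0.0f))
    println("+0.0 and -0.0 map to the same hash input bits")
  }
}
```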


@@ -708,6 +708,17 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(HiveHash(Seq(yearMonth)), 1234)
   }

+  test("SPARK-35207: Compute hash consistent between -0.0 and 0.0") {
+    def checkResult(exprs1: Expression, exprs2: Expression): Unit = {
+      checkEvaluation(Murmur3Hash(Seq(exprs1), 42), Murmur3Hash(Seq(exprs2), 42).eval())
+      checkEvaluation(XxHash64(Seq(exprs1), 42), XxHash64(Seq(exprs2), 42).eval())
+      checkEvaluation(HiveHash(Seq(exprs1)), HiveHash(Seq(exprs2)).eval())
+    }
+
+    checkResult(Literal.create(-0D, DoubleType), Literal.create(0D, DoubleType))
+    checkResult(Literal.create(-0F, FloatType), Literal.create(0F, FloatType))
+  }
+
   private def testHash(inputSchema: StructType): Unit = {
     val inputGenerator = RandomDataGenerator.forType(inputSchema, nullable = false).get
     val toRow = RowEncoder(inputSchema).createSerializer()
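
Assuming the standard Spark sbt build, the new test can be run on its own with something like:
```
build/sbt "catalyst/testOnly org.apache.spark.sql.catalyst.expressions.HashExpressionsSuite -- -z SPARK-35207"
```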