[SPARK-35207][SQL] Normalize hash function behavior with negative zero (floating point types)
### What changes were proposed in this pull request? Generally, we would expect that x = y => hash( x ) = hash( y ). However +-0 hash to different values for floating point types. ``` scala> spark.sql("select hash(cast('0.0' as double)), hash(cast('-0.0' as double))").show +-------------------------+--------------------------+ |hash(CAST(0.0 AS DOUBLE))|hash(CAST(-0.0 AS DOUBLE))| +-------------------------+--------------------------+ | -1670924195| -853646085| +-------------------------+--------------------------+ scala> spark.sql("select cast('0.0' as double) == cast('-0.0' as double)").show +--------------------------------------------+ |(CAST(0.0 AS DOUBLE) = CAST(-0.0 AS DOUBLE))| +--------------------------------------------+ | true| +--------------------------------------------+ ``` Here is an extract from IEEE 754: > The two zeros are distinguishable arithmetically only by either division-byzero ( producing appropriately signed infinities ) or else by the CopySign function recommended by IEEE 754 /854. Infinities, SNaNs, NaNs and Subnormal numbers necessitate four more special cases From this, I deduce that the hash function must produce the same result for 0 and -0. ### Why are the changes needed? It is a correctness issue ### Does this PR introduce _any_ user-facing change? This changes only affect to the hash function applied to -0 value in float and double types ### How was this patch tested? Unit testing and manual testing Closes #32496 from planga82/feature/spark35207_hashnegativezero. Authored-by: Pablo Langa <soypab@gmail.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
f7af9ab8dc
commit
9ea55fe771
|
@ -87,6 +87,8 @@ license: |
|
|||
|
||||
- In Spark 3.2, Spark supports `DayTimeIntervalType` and `YearMonthIntervalType` as inputs and outputs of `TRANSFORM` clause in Hive `SERDE` mode, the behavior is different between Hive `SERDE` mode and `ROW FORMAT DELIMITED` mode when these two types are used as inputs. In Hive `SERDE` mode, `DayTimeIntervalType` column is converted to `HiveIntervalDayTime`, its string format is `[-]?d h:m:s.n`, but in `ROW FORMAT DELIMITED` mode the format is `INTERVAL '[-]?d h:m:s.n' DAY TO TIME`. In Hive `SERDE` mode, `YearMonthIntervalType` column is converted to `HiveIntervalYearMonth`, its string format is `[-]?y-m`, but in `ROW FORMAT DELIMITED` mode the format is `INTERVAL '[-]?y-m' YEAR TO MONTH`.
|
||||
|
||||
- In Spark 3.2, `hash(0) == hash(-0)` for floating point types. Previously, different values were generated.
|
||||
|
||||
## Upgrading from Spark SQL 3.0 to 3.1
|
||||
|
||||
- In Spark 3.1, statistical aggregation function includes `std`, `stddev`, `stddev_samp`, `variance`, `var_samp`, `skewness`, `kurtosis`, `covar_samp`, `corr` will return `NULL` instead of `Double.NaN` when `DivideByZero` occurs during expression evaluation, for example, when `stddev_samp` applied on a single element set. In Spark version 3.0 and earlier, it will return `Double.NaN` in such case. To restore the behavior before Spark 3.1, you can set `spark.sql.legacy.statisticalAggregate` to `true`.
|
||||
|
|
|
@ -369,11 +369,25 @@ abstract class HashExpression[E] extends Expression {
|
|||
protected def genHashBoolean(input: String, result: String): String =
|
||||
genHashInt(s"$input ? 1 : 0", result)
|
||||
|
||||
protected def genHashFloat(input: String, result: String): String =
|
||||
genHashInt(s"Float.floatToIntBits($input)", result)
|
||||
protected def genHashFloat(input: String, result: String): String = {
|
||||
s"""
|
||||
|if($input == -0.0f) {
|
||||
| ${genHashInt("0", result)}
|
||||
|} else {
|
||||
| ${genHashInt(s"Float.floatToIntBits($input)", result)}
|
||||
|}
|
||||
""".stripMargin
|
||||
}
|
||||
|
||||
protected def genHashDouble(input: String, result: String): String =
|
||||
genHashLong(s"Double.doubleToLongBits($input)", result)
|
||||
protected def genHashDouble(input: String, result: String): String = {
|
||||
s"""
|
||||
|if($input == -0.0d) {
|
||||
| ${genHashLong("0L", result)}
|
||||
|} else {
|
||||
| ${genHashLong(s"Double.doubleToLongBits($input)", result)}
|
||||
|}
|
||||
""".stripMargin
|
||||
}
|
||||
|
||||
protected def genHashDecimal(
|
||||
ctx: CodegenContext,
|
||||
|
@ -523,7 +537,9 @@ abstract class InterpretedHashFunction {
|
|||
case s: Short => hashInt(s, seed)
|
||||
case i: Int => hashInt(i, seed)
|
||||
case l: Long => hashLong(l, seed)
|
||||
case f: Float if (f == -0.0f) => hashInt(0, seed)
|
||||
case f: Float => hashInt(java.lang.Float.floatToIntBits(f), seed)
|
||||
case d: Double if (d == -0.0d) => hashLong(0L, seed)
|
||||
case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed)
|
||||
case d: Decimal =>
|
||||
val precision = dataType.asInstanceOf[DecimalType].precision
|
||||
|
|
|
@ -708,6 +708,17 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
checkEvaluation(HiveHash(Seq(yearMonth)), 1234)
|
||||
}
|
||||
|
||||
test("SPARK-35207: Compute hash consistent between -0.0 and 0.0") {
|
||||
def checkResult(exprs1: Expression, exprs2: Expression): Unit = {
|
||||
checkEvaluation(Murmur3Hash(Seq(exprs1), 42), Murmur3Hash(Seq(exprs2), 42).eval())
|
||||
checkEvaluation(XxHash64(Seq(exprs1), 42), XxHash64(Seq(exprs2), 42).eval())
|
||||
checkEvaluation(HiveHash(Seq(exprs1)), HiveHash(Seq(exprs2)).eval())
|
||||
}
|
||||
|
||||
checkResult(Literal.create(-0D, DoubleType), Literal.create(0D, DoubleType))
|
||||
checkResult(Literal.create(-0F, FloatType), Literal.create(0F, FloatType))
|
||||
}
|
||||
|
||||
private def testHash(inputSchema: StructType): Unit = {
|
||||
val inputGenerator = RandomDataGenerator.forType(inputSchema, nullable = false).get
|
||||
val toRow = RowEncoder(inputSchema).createSerializer()
|
||||
|
|
Loading…
Reference in a new issue