[SPARK-36755][SQL] ArraysOverlap should handle duplicated Double.NaN and Float.NaN

### What changes were proposed in this pull request?
For query
```
select arrays_overlap(array(cast('nan' as double), 1d), array(cast('nan' as double)))
```
This returns [false], but it should return [true].
This issue occurs because `scala.collection.mutable.HashSet` can't handle `Double.NaN` and `Float.NaN`.

### Why are the changes needed?
Fix bug

### Does this PR introduce _any_ user-facing change?
Yes. `arrays_overlap` now treats `NaN` values as equal, so two arrays that both contain `NaN` are reported as overlapping.

### How was this patch tested?
Added UT

Closes #34006 from AngersZhuuuu/SPARK-36755.

Authored-by: Angerszhuuuu <angers.zhu@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
Angerszhuuuu 2021-09-15 22:31:46 +08:00 committed by Wenchen Fan
parent 638085953f
commit b665782f0d
2 changed files with 15 additions and 2 deletions

View file

@ -1297,12 +1297,12 @@ case class ArraysOverlap(left: Expression, right: Expression)
(arr2, arr1)
}
if (smaller.numElements() > 0) {
val smallestSet = new mutable.HashSet[Any]
val smallestSet = new java.util.HashSet[Any]()
smaller.foreach(elementType, (_, v) =>
if (v == null) {
hasNull = true
} else {
smallestSet += v
smallestSet.add(v)
})
bigger.foreach(elementType, (_, v1) =>
if (v1 == null) {

View file

@ -2326,4 +2326,17 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
Literal.create(Seq(Float.NaN, null, 1f), ArrayType(FloatType))),
Seq(Float.NaN, null, 1f))
}
// Regression test for SPARK-36755: two arrays that both contain NaN must be
// reported as overlapping, for both Double and Float element types.
test("SPARK-36755: ArraysOverlap should handle duplicated Double.NaN and Float.NaN") {
  // Double arrays without nulls; NaN is the only shared element.
  checkEvaluation(ArraysOverlap(
    Literal.apply(Array(Double.NaN, 1d)), Literal.apply(Array(Double.NaN))), true)
  // Double arrays that also contain null; the shared NaN must still yield true.
  checkEvaluation(ArraysOverlap(
    Literal.create(Seq(Double.NaN, null), ArrayType(DoubleType)),
    Literal.create(Seq(Double.NaN, null, 1d), ArrayType(DoubleType))), true)
  // Float arrays without nulls; here the first argument is the shorter array.
  checkEvaluation(ArraysOverlap(
    Literal.apply(Array(Float.NaN)), Literal.apply(Array(Float.NaN, 1f))), true)
  // Float arrays that also contain null.
  checkEvaluation(ArraysOverlap(
    Literal.create(Seq(Float.NaN, null), ArrayType(FloatType)),
    Literal.create(Seq(Float.NaN, null, 1f), ArrayType(FloatType))), true)
}
}