[SPARK-36754][SQL] ArrayIntersect handle duplicated Double.NaN and Float.NaN

### What changes were proposed in this pull request?
For the query
```sql
select array_intersect(array(cast('nan' as double), 1d), array(cast('nan' as double)))
```
This returns `[]`, but it should return `[NaN]`, since Spark SQL's NaN semantics treat `NaN` as equal to `NaN` (the added unit tests below expect `NaN` in the intersection).
As with https://github.com/apache/spark/pull/33955, this issue is caused by `OpenHashSet` being unable to handle `Double.NaN` and `Float.NaN`. This PR fixes the problem in the same way, by switching to `SQLOpenHashSet`.
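As background (illustration only, not part of this change): on the JVM, a primitive `NaN` never compares equal to anything, including itself, so a hash set keyed on primitive equality can never find a `NaN` it already stores, while boxed `java.lang.Double.equals` does treat `NaN` as equal to `NaN`, which matches Spark SQL's semantics. A minimal sketch:
```scala
object NaNEqualityDemo extends App {
  val a = Double.NaN
  val b = Double.NaN

  // Primitive comparison: NaN is never == to anything, including itself,
  // so a primitive-keyed hash set lookup for NaN always misses.
  println(a == b) // false

  // Boxed comparison: java.lang.Double.equals compares bit patterns,
  // so NaN equals NaN -- the behavior Spark SQL expects (NaN = NaN is true).
  println(Double.box(a).equals(Double.box(b))) // true
}
```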

### Why are the changes needed?
Fix a correctness bug: `array_intersect` mishandles `NaN` values.

### Does this PR introduce _any_ user-facing change?
Yes. `ArrayIntersect` now treats `NaN` values as equal, so a `NaN` contained in both input arrays appears exactly once in the result.
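For example, mirroring the added unit tests:
```sql
select array_intersect(array(cast('nan' as double), 1d), array(cast('nan' as double), 1d, 2d))
-- with this change: [NaN, 1.0]
```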

### How was this patch tested?
Added unit tests.

Closes #33995 from AngersZhuuuu/SPARK-36754.

Authored-by: Angerszhuuuu <angers.zhu@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Authored by Angerszhuuuu on 2021-09-20 16:48:59 +08:00, committed by Wenchen Fan
commit 2fc7f2f702, parent a396dd6216
2 changed files with 58 additions and 25 deletions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
```diff
@@ -3847,33 +3847,42 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBinaryLike
     if (TypeUtils.typeWithProperEquals(elementType)) {
       (array1, array2) =>
         if (array1.numElements() != 0 && array2.numElements() != 0) {
-          val hs = new OpenHashSet[Any]
-          val hsResult = new OpenHashSet[Any]
-          var foundNullElement = false
+          val hs = new SQLOpenHashSet[Any]
+          val hsResult = new SQLOpenHashSet[Any]
+          val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
+          val withArray2NaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs,
+            (value: Any) => hs.add(value),
+            (valueNaN: Any) => {} )
+          val withArray1NaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hsResult,
+            (value: Any) =>
+              if (hs.contains(value) && !hsResult.contains(value)) {
+                arrayBuffer += value
+                hsResult.add(value)
+              },
+            (valueNaN: Any) =>
+              if (hs.containsNaN()) {
+                arrayBuffer += valueNaN
+              })
           var i = 0
           while (i < array2.numElements()) {
             if (array2.isNullAt(i)) {
-              foundNullElement = true
+              hs.addNull()
             } else {
               val elem = array2.get(i, elementType)
-              hs.add(elem)
+              withArray2NaNCheckFunc(elem)
             }
             i += 1
           }
-          val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
           i = 0
           while (i < array1.numElements()) {
             if (array1.isNullAt(i)) {
-              if (foundNullElement) {
+              if (hs.containsNull() && !hsResult.containsNull()) {
                 arrayBuffer += null
-                foundNullElement = false
+                hsResult.addNull()
               }
             } else {
               val elem = array1.get(i, elementType)
-              if (hs.contains(elem) && !hsResult.contains(elem)) {
-                arrayBuffer += elem
-                hsResult.add(elem)
-              }
+              withArray1NaNCheckFunc(elem)
             }
             i += 1
           }
@@ -3948,10 +3957,9 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBinaryLike
       val ptName = CodeGenerator.primitiveTypeName(jt)
 
       nullSafeCodeGen(ctx, ev, (array1, array2) => {
-        val foundNullElement = ctx.freshName("foundNullElement")
         val nullElementIndex = ctx.freshName("nullElementIndex")
         val builder = ctx.freshName("builder")
-        val openHashSet = classOf[OpenHashSet[_]].getName
+        val openHashSet = classOf[SQLOpenHashSet[_]].getName
         val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
         val hashSet = ctx.freshName("hashSet")
         val hashSetResult = ctx.freshName("hashSetResult")
@@ -3963,7 +3971,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBinaryLike
           if (left.dataType.asInstanceOf[ArrayType].containsNull) {
             s"""
                |if ($array2.isNullAt($i)) {
-               |  $foundNullElement = true;
+               |  $hashSet.addNull();
                |} else {
                |  $body
                |}
@@ -3981,19 +3989,18 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBinaryLike
           }
 
         val writeArray2ToHashSet = withArray2NullCheck(
-          s"""
-             |$jt $value = ${genGetValue(array2, i)};
-             |$hashSet.add$hsPostFix($hsValueCast$value);
-           """.stripMargin)
+          s"$jt $value = ${genGetValue(array2, i)};" +
+            SQLOpenHashSet.withNaNCheckCode(elementType, value, hashSet,
+              s"$hashSet.add$hsPostFix($hsValueCast$value);", (valueNaN: String) => ""))
 
         def withArray1NullAssignment(body: String) =
           if (left.dataType.asInstanceOf[ArrayType].containsNull) {
             if (right.dataType.asInstanceOf[ArrayType].containsNull) {
               s"""
                  |if ($array1.isNullAt($i)) {
-                 |  if ($foundNullElement) {
+                 |  if ($hashSet.containsNull() && !$hashSetResult.containsNull()) {
                  |    $nullElementIndex = $size;
-                 |    $foundNullElement = false;
+                 |    $hashSetResult.addNull();
                  |    $size++;
                  |    $builder.$$plus$$eq($nullValueHolder);
                  |  }
@@ -4012,9 +4019,8 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBinaryLike
             body
           }
 
-        val processArray1 = withArray1NullAssignment(
+        val body =
           s"""
-             |$jt $value = ${genGetValue(array1, i)};
             |if ($hashSet.contains($hsValueCast$value) &&
             |  !$hashSetResult.contains($hsValueCast$value)) {
             |  if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
@@ -4023,12 +4029,22 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBinaryLike
             |  $hashSetResult.add$hsPostFix($hsValueCast$value);
             |  $builder.$$plus$$eq($value);
             |}
-           """.stripMargin)
+           """.stripMargin
+
+        val processArray1 = withArray1NullAssignment(
+          s"$jt $value = ${genGetValue(array1, i)};" +
+            SQLOpenHashSet.withNaNCheckCode(elementType, value, hashSetResult, body,
+              (valueNaN: Any) =>
+                s"""
+                   |if ($hashSet.containsNaN()) {
+                   |  ++$size;
+                   |  $builder.$$plus$$eq($valueNaN);
+                   |}
+                 """.stripMargin))
 
         // Only need to track null element index when result array's element is nullable.
         val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) {
           s"""
-             |boolean $foundNullElement = false;
             |int $nullElementIndex = -1;
           """.stripMargin
         } else {
```
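Reviewer note: `SQLOpenHashSet` and its `withNaNCheckFunc` / `withNaNCheckCode` helpers were introduced in https://github.com/apache/spark/pull/33955. The sketch below approximates the interpreted-path contract as it is used above; the stand-in set class and dispatch logic are illustrative assumptions rather than Spark's actual implementation, with only the `addNaN`/`containsNaN`-style out-of-band tracking taken from the calls visible in this diff:
```scala
// Hedged sketch (simplified; not Spark's SQLOpenHashSet): route NaN values
// around a hash table that cannot match primitive NaN keys.
object NaNCheckSketch {
  final class NaNAwareSet {
    private val values = scala.collection.mutable.HashSet.empty[Any]
    private var hasNaN = false
    def add(v: Any): Unit = values += v
    def contains(v: Any): Boolean = values.contains(v)
    def addNaN(): Unit = hasNaN = true
    def containsNaN: Boolean = hasNaN
  }

  // Returns a handler: the first NaN seen is marked on the set and passed to
  // handleNaN; every non-NaN value goes to handleNotNaN.
  def withNaNCheckFunc(
      set: NaNAwareSet,
      handleNotNaN: Any => Unit,
      handleNaN: Any => Unit): Any => Unit = { value =>
    val isNaN = value match {
      case d: Double => d.isNaN
      case f: Float => f.isNaN
      case _ => false
    }
    if (isNaN) {
      if (!set.containsNaN) {
        set.addNaN()
        handleNaN(value)
      }
    } else {
      handleNotNaN(value)
    }
  }
}
```
In the `evalIntersect` change above, the array2 pass supplies a no-op `handleNaN` (the set itself records that a `NaN` was seen), while the array1 pass appends `NaN` to the result buffer only when the array2 set reports `containsNaN()`.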

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
```diff
@@ -2327,6 +2327,23 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
       Seq(Float.NaN, null, 1f))
   }
 
+  test("SPARK-36754: ArrayIntersect should handle duplicated Double.NaN and Float.Nan") {
+    checkEvaluation(ArrayIntersect(
+      Literal.apply(Array(Double.NaN, 1d)), Literal.apply(Array(Double.NaN, 1d, 2d))),
+      Seq(Double.NaN, 1d))
+    checkEvaluation(ArrayIntersect(
+      Literal.create(Seq(null, Double.NaN, null, 1d), ArrayType(DoubleType)),
+      Literal.create(Seq(null, Double.NaN, null), ArrayType(DoubleType))),
+      Seq(null, Double.NaN))
+    checkEvaluation(ArrayIntersect(
+      Literal.apply(Array(Float.NaN, 1f)), Literal.apply(Array(Float.NaN, 1f, 2f))),
+      Seq(Float.NaN, 1f))
+    checkEvaluation(ArrayIntersect(
+      Literal.create(Seq(null, Float.NaN, null, 1f), ArrayType(FloatType)),
+      Literal.create(Seq(null, Float.NaN, null), ArrayType(FloatType))),
+      Seq(null, Float.NaN))
+  }
+
   test("SPARK-36741: ArrayDistinct should handle duplicated Double.NaN and Float.Nan") {
     checkEvaluation(ArrayDistinct(
       Literal.create(Seq(Double.NaN, Double.NaN, null, null, 1d, 1d), ArrayType(DoubleType))),
```