[SPARK-36754][SQL] ArrayIntersect handle duplicated Double.NaN and Float.NaN
### What changes were proposed in this pull request?
For query
```
select array_intersect(array(cast('nan' as double), 1d), array(cast('nan' as double)))
```
This returns [NaN], but it should return [].
This issue is caused by the fact that `OpenHashSet` cannot handle `Double.NaN` and `Float.NaN` either.
This PR fixes the issue based on https://github.com/apache/spark/pull/33955
### Why are the changes needed?
Fix bug
### Does this PR introduce _any_ user-facing change?
`ArrayIntersect` now treats equal `NaN` values as the same element, so duplicated `NaN` values no longer appear in results.
### How was this patch tested?
Added UT
Closes #33995 from AngersZhuuuu/SPARK-36754.
Authored-by: Angerszhuuuu <angers.zhu@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit 2fc7f2f702)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
b0249851f6
commit
337a1979d2
|
@ -3847,33 +3847,42 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina
|
|||
if (TypeUtils.typeWithProperEquals(elementType)) {
|
||||
(array1, array2) =>
|
||||
if (array1.numElements() != 0 && array2.numElements() != 0) {
|
||||
val hs = new OpenHashSet[Any]
|
||||
val hsResult = new OpenHashSet[Any]
|
||||
var foundNullElement = false
|
||||
val hs = new SQLOpenHashSet[Any]
|
||||
val hsResult = new SQLOpenHashSet[Any]
|
||||
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
|
||||
val withArray2NaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs,
|
||||
(value: Any) => hs.add(value),
|
||||
(valueNaN: Any) => {} )
|
||||
val withArray1NaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hsResult,
|
||||
(value: Any) =>
|
||||
if (hs.contains(value) && !hsResult.contains(value)) {
|
||||
arrayBuffer += value
|
||||
hsResult.add(value)
|
||||
},
|
||||
(valueNaN: Any) =>
|
||||
if (hs.containsNaN()) {
|
||||
arrayBuffer += valueNaN
|
||||
})
|
||||
var i = 0
|
||||
while (i < array2.numElements()) {
|
||||
if (array2.isNullAt(i)) {
|
||||
foundNullElement = true
|
||||
hs.addNull()
|
||||
} else {
|
||||
val elem = array2.get(i, elementType)
|
||||
hs.add(elem)
|
||||
withArray2NaNCheckFunc(elem)
|
||||
}
|
||||
i += 1
|
||||
}
|
||||
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
|
||||
i = 0
|
||||
while (i < array1.numElements()) {
|
||||
if (array1.isNullAt(i)) {
|
||||
if (foundNullElement) {
|
||||
if (hs.containsNull() && !hsResult.containsNull()) {
|
||||
arrayBuffer += null
|
||||
foundNullElement = false
|
||||
hsResult.addNull()
|
||||
}
|
||||
} else {
|
||||
val elem = array1.get(i, elementType)
|
||||
if (hs.contains(elem) && !hsResult.contains(elem)) {
|
||||
arrayBuffer += elem
|
||||
hsResult.add(elem)
|
||||
}
|
||||
withArray1NaNCheckFunc(elem)
|
||||
}
|
||||
i += 1
|
||||
}
|
||||
|
@ -3948,10 +3957,9 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina
|
|||
val ptName = CodeGenerator.primitiveTypeName(jt)
|
||||
|
||||
nullSafeCodeGen(ctx, ev, (array1, array2) => {
|
||||
val foundNullElement = ctx.freshName("foundNullElement")
|
||||
val nullElementIndex = ctx.freshName("nullElementIndex")
|
||||
val builder = ctx.freshName("builder")
|
||||
val openHashSet = classOf[OpenHashSet[_]].getName
|
||||
val openHashSet = classOf[SQLOpenHashSet[_]].getName
|
||||
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
|
||||
val hashSet = ctx.freshName("hashSet")
|
||||
val hashSetResult = ctx.freshName("hashSetResult")
|
||||
|
@ -3963,7 +3971,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina
|
|||
if (left.dataType.asInstanceOf[ArrayType].containsNull) {
|
||||
s"""
|
||||
|if ($array2.isNullAt($i)) {
|
||||
| $foundNullElement = true;
|
||||
| $hashSet.addNull();
|
||||
|} else {
|
||||
| $body
|
||||
|}
|
||||
|
@ -3981,19 +3989,18 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina
|
|||
}
|
||||
|
||||
val writeArray2ToHashSet = withArray2NullCheck(
|
||||
s"""
|
||||
|$jt $value = ${genGetValue(array2, i)};
|
||||
|$hashSet.add$hsPostFix($hsValueCast$value);
|
||||
""".stripMargin)
|
||||
s"$jt $value = ${genGetValue(array2, i)};" +
|
||||
SQLOpenHashSet.withNaNCheckCode(elementType, value, hashSet,
|
||||
s"$hashSet.add$hsPostFix($hsValueCast$value);", (valueNaN: String) => ""))
|
||||
|
||||
def withArray1NullAssignment(body: String) =
|
||||
if (left.dataType.asInstanceOf[ArrayType].containsNull) {
|
||||
if (right.dataType.asInstanceOf[ArrayType].containsNull) {
|
||||
s"""
|
||||
|if ($array1.isNullAt($i)) {
|
||||
| if ($foundNullElement) {
|
||||
| if ($hashSet.containsNull() && !$hashSetResult.containsNull()) {
|
||||
| $nullElementIndex = $size;
|
||||
| $foundNullElement = false;
|
||||
| $hashSetResult.addNull();
|
||||
| $size++;
|
||||
| $builder.$$plus$$eq($nullValueHolder);
|
||||
| }
|
||||
|
@ -4012,9 +4019,8 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina
|
|||
body
|
||||
}
|
||||
|
||||
val processArray1 = withArray1NullAssignment(
|
||||
val body =
|
||||
s"""
|
||||
|$jt $value = ${genGetValue(array1, i)};
|
||||
|if ($hashSet.contains($hsValueCast$value) &&
|
||||
| !$hashSetResult.contains($hsValueCast$value)) {
|
||||
| if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
|
||||
|
@ -4023,12 +4029,22 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina
|
|||
| $hashSetResult.add$hsPostFix($hsValueCast$value);
|
||||
| $builder.$$plus$$eq($value);
|
||||
|}
|
||||
""".stripMargin)
|
||||
""".stripMargin
|
||||
|
||||
val processArray1 = withArray1NullAssignment(
|
||||
s"$jt $value = ${genGetValue(array1, i)};" +
|
||||
SQLOpenHashSet.withNaNCheckCode(elementType, value, hashSetResult, body,
|
||||
(valueNaN: Any) =>
|
||||
s"""
|
||||
|if ($hashSet.containsNaN()) {
|
||||
| ++$size;
|
||||
| $builder.$$plus$$eq($valueNaN);
|
||||
|}
|
||||
""".stripMargin))
|
||||
|
||||
// Only need to track null element index when result array's element is nullable.
|
||||
val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) {
|
||||
s"""
|
||||
|boolean $foundNullElement = false;
|
||||
|int $nullElementIndex = -1;
|
||||
""".stripMargin
|
||||
} else {
|
||||
|
|
|
@ -2310,6 +2310,23 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
|
|||
Seq(Float.NaN, null, 1f))
|
||||
}
|
||||
|
||||
test("SPARK-36754: ArrayIntersect should handle duplicated Double.NaN and Float.Nan") {
|
||||
checkEvaluation(ArrayIntersect(
|
||||
Literal.apply(Array(Double.NaN, 1d)), Literal.apply(Array(Double.NaN, 1d, 2d))),
|
||||
Seq(Double.NaN, 1d))
|
||||
checkEvaluation(ArrayIntersect(
|
||||
Literal.create(Seq(null, Double.NaN, null, 1d), ArrayType(DoubleType)),
|
||||
Literal.create(Seq(null, Double.NaN, null), ArrayType(DoubleType))),
|
||||
Seq(null, Double.NaN))
|
||||
checkEvaluation(ArrayIntersect(
|
||||
Literal.apply(Array(Float.NaN, 1f)), Literal.apply(Array(Float.NaN, 1f, 2f))),
|
||||
Seq(Float.NaN, 1f))
|
||||
checkEvaluation(ArrayIntersect(
|
||||
Literal.create(Seq(null, Float.NaN, null, 1f), ArrayType(FloatType)),
|
||||
Literal.create(Seq(null, Float.NaN, null), ArrayType(FloatType))),
|
||||
Seq(null, Float.NaN))
|
||||
}
|
||||
|
||||
test("SPARK-36741: ArrayDistinct should handle duplicated Double.NaN and Float.Nan") {
|
||||
checkEvaluation(ArrayDistinct(
|
||||
Literal.create(Seq(Double.NaN, Double.NaN, null, null, 1d, 1d), ArrayType(DoubleType))),
|
||||
|
|
Loading…
Reference in a new issue