From 2fc7f2f702c6c08d9c76332f45e2902728ba2ee3 Mon Sep 17 00:00:00 2001 From: Angerszhuuuu Date: Mon, 20 Sep 2021 16:48:59 +0800 Subject: [PATCH] [SPARK-36754][SQL] ArrayIntersect handle duplicated Double.NaN and Float.NaN ### What changes were proposed in this pull request? For the query ``` select array_intersect(array(cast('nan' as double), 1d), array(cast('nan' as double))) ``` the result is [], but it should be [NaN]. This issue occurs because `OpenHashSet` can't handle `Double.NaN` and `Float.NaN`. This PR fixes it, building on https://github.com/apache/spark/pull/33955 ### Why are the changes needed? To fix the bug described above. ### Does this PR introduce _any_ user-facing change? Yes. ArrayIntersect now treats `NaN` values as equal, so a `NaN` present in both arrays is included in the result. ### How was this patch tested? Added unit tests. Closes #33995 from AngersZhuuuu/SPARK-36754. Authored-by: Angerszhuuuu Signed-off-by: Wenchen Fan --- .../expressions/collectionOperations.scala | 66 ++++++++++++------- .../CollectionExpressionsSuite.scala | 17 +++++ 2 files changed, 58 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 1182194e4c..b325e9aebb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3847,33 +3847,42 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina if (TypeUtils.typeWithProperEquals(elementType)) { (array1, array2) => if (array1.numElements() != 0 && array2.numElements() != 0) { - val hs = new OpenHashSet[Any] - val hsResult = new OpenHashSet[Any] - var foundNullElement = false + val hs = new SQLOpenHashSet[Any] + val hsResult = new SQLOpenHashSet[Any] + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + val withArray2NaNCheckFunc =
SQLOpenHashSet.withNaNCheckFunc(elementType, hs, + (value: Any) => hs.add(value), + (valueNaN: Any) => {} ) + val withArray1NaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hsResult, + (value: Any) => + if (hs.contains(value) && !hsResult.contains(value)) { + arrayBuffer += value + hsResult.add(value) + }, + (valueNaN: Any) => + if (hs.containsNaN()) { + arrayBuffer += valueNaN + }) var i = 0 while (i < array2.numElements()) { if (array2.isNullAt(i)) { - foundNullElement = true + hs.addNull() } else { val elem = array2.get(i, elementType) - hs.add(elem) + withArray2NaNCheckFunc(elem) } i += 1 } - val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] i = 0 while (i < array1.numElements()) { if (array1.isNullAt(i)) { - if (foundNullElement) { + if (hs.containsNull() && !hsResult.containsNull()) { arrayBuffer += null - foundNullElement = false + hsResult.addNull() } } else { val elem = array1.get(i, elementType) - if (hs.contains(elem) && !hsResult.contains(elem)) { - arrayBuffer += elem - hsResult.add(elem) - } + withArray1NaNCheckFunc(elem) } i += 1 } @@ -3948,10 +3957,9 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina val ptName = CodeGenerator.primitiveTypeName(jt) nullSafeCodeGen(ctx, ev, (array1, array2) => { - val foundNullElement = ctx.freshName("foundNullElement") val nullElementIndex = ctx.freshName("nullElementIndex") val builder = ctx.freshName("builder") - val openHashSet = classOf[OpenHashSet[_]].getName + val openHashSet = classOf[SQLOpenHashSet[_]].getName val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()" val hashSet = ctx.freshName("hashSet") val hashSetResult = ctx.freshName("hashSetResult") @@ -3963,7 +3971,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina if (left.dataType.asInstanceOf[ArrayType].containsNull) { s""" |if ($array2.isNullAt($i)) { - | $foundNullElement = true; + | $hashSet.addNull(); |} else { | $body |} @@ -3981,19 +3989,18 @@ 
case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina } val writeArray2ToHashSet = withArray2NullCheck( - s""" - |$jt $value = ${genGetValue(array2, i)}; - |$hashSet.add$hsPostFix($hsValueCast$value); - """.stripMargin) + s"$jt $value = ${genGetValue(array2, i)};" + + SQLOpenHashSet.withNaNCheckCode(elementType, value, hashSet, + s"$hashSet.add$hsPostFix($hsValueCast$value);", (valueNaN: String) => "")) def withArray1NullAssignment(body: String) = if (left.dataType.asInstanceOf[ArrayType].containsNull) { if (right.dataType.asInstanceOf[ArrayType].containsNull) { s""" |if ($array1.isNullAt($i)) { - | if ($foundNullElement) { + | if ($hashSet.containsNull() && !$hashSetResult.containsNull()) { | $nullElementIndex = $size; - | $foundNullElement = false; + | $hashSetResult.addNull(); | $size++; | $builder.$$plus$$eq($nullValueHolder); | } @@ -4012,9 +4019,8 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina body } - val processArray1 = withArray1NullAssignment( + val body = s""" - |$jt $value = ${genGetValue(array1, i)}; |if ($hashSet.contains($hsValueCast$value) && | !$hashSetResult.contains($hsValueCast$value)) { | if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { @@ -4023,12 +4029,22 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina | $hashSetResult.add$hsPostFix($hsValueCast$value); | $builder.$$plus$$eq($value); |} - """.stripMargin) + """.stripMargin + + val processArray1 = withArray1NullAssignment( + s"$jt $value = ${genGetValue(array1, i)};" + + SQLOpenHashSet.withNaNCheckCode(elementType, value, hashSetResult, body, + (valueNaN: Any) => + s""" + |if ($hashSet.containsNaN()) { + | ++$size; + | $builder.$$plus$$eq($valueNaN); + |} + """.stripMargin)) // Only need to track null element index when result array's element is nullable. 
val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) { s""" - |boolean $foundNullElement = false; |int $nullElementIndex = -1; """.stripMargin } else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index f9431f2c60..aa3c46c3d0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -2327,6 +2327,23 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper Seq(Float.NaN, null, 1f)) } + test("SPARK-36754: ArrayIntersect should handle duplicated Double.NaN and Float.Nan") { + checkEvaluation(ArrayIntersect( + Literal.apply(Array(Double.NaN, 1d)), Literal.apply(Array(Double.NaN, 1d, 2d))), + Seq(Double.NaN, 1d)) + checkEvaluation(ArrayIntersect( + Literal.create(Seq(null, Double.NaN, null, 1d), ArrayType(DoubleType)), + Literal.create(Seq(null, Double.NaN, null), ArrayType(DoubleType))), + Seq(null, Double.NaN)) + checkEvaluation(ArrayIntersect( + Literal.apply(Array(Float.NaN, 1f)), Literal.apply(Array(Float.NaN, 1f, 2f))), + Seq(Float.NaN, 1f)) + checkEvaluation(ArrayIntersect( + Literal.create(Seq(null, Float.NaN, null, 1f), ArrayType(FloatType)), + Literal.create(Seq(null, Float.NaN, null), ArrayType(FloatType))), + Seq(null, Float.NaN)) + } + test("SPARK-36741: ArrayDistinct should handle duplicated Double.NaN and Float.Nan") { checkEvaluation(ArrayDistinct( Literal.create(Seq(Double.NaN, Double.NaN, null, null, 1d, 1d), ArrayType(DoubleType))),