From a78c4c44ed5f7b8296205100d62c60a780a5d5c2 Mon Sep 17 00:00:00 2001 From: Angerszhuuuu Date: Fri, 17 Sep 2021 20:48:17 +0800 Subject: [PATCH] [SPARK-36741][SQL] ArrayDistinct handle duplicated Double.NaN and Float.Nan ### What changes were proposed in this pull request? For query ``` select array_distinct(array(cast('nan' as double), cast('nan' as double))) ``` This returns [NaN, NaN], but it should return [NaN]. This issue is caused by `OpenHashSet` can't handle `Double.NaN` and `Float.NaN` too. In this pr fix this based on https://github.com/apache/spark/pull/33955 ### Why are the changes needed? Fix bug ### Does this PR introduce _any_ user-facing change? ArrayDistinct won't show duplicated `NaN` value ### How was this patch tested? Added UT Closes #33993 from AngersZhuuuu/SPARK-36741. Authored-by: Angerszhuuuu Signed-off-by: Wenchen Fan (cherry picked from commit e356f6aa1119f4ceeafc7bcdea5f7b8f1f010638) Signed-off-by: Wenchen Fan --- .../expressions/collectionOperations.scala | 124 ++++++++++-------- .../spark/sql/util/SQLOpenHashSet.scala | 54 ++++++-- .../CollectionExpressionsSuite.scala | 9 ++ 3 files changed, 121 insertions(+), 66 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index ba000a383e..a50263c852 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3412,32 +3412,59 @@ case class ArrayDistinct(child: Expression) } override def nullSafeEval(array: Any): Any = { - val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) + val data = array.asInstanceOf[ArrayData] doEvaluation(data) } @transient private lazy val doEvaluation = if (TypeUtils.typeWithProperEquals(elementType)) { - (data: Array[AnyRef]) => new GenericArrayData(data.distinct.asInstanceOf[Array[Any]]) + (array: ArrayData) => + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + val hs = new SQLOpenHashSet[Any]() + val withNaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs, + (value: Any) => + if (!hs.contains(value)) { + if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.size) + } + arrayBuffer += value + hs.add(value) + }, + (valueNaN: Any) => arrayBuffer += valueNaN) + var i = 0 + while (i < array.numElements()) { + if (array.isNullAt(i)) { + if (!hs.containsNull) { + hs.addNull + arrayBuffer += null + } + } else { + val elem = array.get(i, elementType) + withNaNCheckFunc(elem) + } + i += 1 + } + new GenericArrayData(arrayBuffer.toSeq) } else { - (data: Array[AnyRef]) => { + (data: ArrayData) => { + val array = data.toArray[AnyRef](elementType) val arrayBuffer = new scala.collection.mutable.ArrayBuffer[AnyRef] var alreadyStoredNull = false - for (i <- 0 until data.length) { - if (data(i) != null) { + for (i <- 0 until array.length) { + if (array(i) != null) { var found = false var j = 0 while (!found && j < arrayBuffer.size) { val va = arrayBuffer(j) - found = (va != null) && ordering.equiv(va, data(i)) + found = (va != null) && ordering.equiv(va, array(i)) j += 1 } if (!found) { - arrayBuffer += data(i) + arrayBuffer += array(i) } } else { // De-duplicate the null values. if (!alreadyStoredNull) { - arrayBuffer += data(i) + arrayBuffer += array(i) alreadyStoredNull = true } } @@ -3456,10 +3483,9 @@ case class ArrayDistinct(child: Expression) val ptName = CodeGenerator.primitiveTypeName(jt) nullSafeCodeGen(ctx, ev, (array) => { - val foundNullElement = ctx.freshName("foundNullElement") val nullElementIndex = ctx.freshName("nullElementIndex") val builder = ctx.freshName("builder") - val openHashSet = classOf[OpenHashSet[_]].getName + val openHashSet = classOf[SQLOpenHashSet[_]].getName val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()" val hashSet = ctx.freshName("hashSet") val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName @@ -3468,7 +3494,6 @@ case class ArrayDistinct(child: Expression) // Only need to track null element index when array's element is nullable. val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) { s""" - |boolean $foundNullElement = false; |int $nullElementIndex = -1; """.stripMargin } else { @@ -3479,9 +3504,9 @@ case class ArrayDistinct(child: Expression) if (dataType.asInstanceOf[ArrayType].containsNull) { s""" |if ($array.isNullAt($i)) { - | if (!$foundNullElement) { + | if (!$hashSet.containsNull()) { | $nullElementIndex = $size; - | $foundNullElement = true; + | $hashSet.addNull(); | $size++; | $builder.$$plus$$eq($nullValueHolder); | } @@ -3493,9 +3518,8 @@ case class ArrayDistinct(child: Expression) body } - val processArray = withArrayNullAssignment( + val body = s""" - |$jt $value = ${genGetValue(array, i)}; |if (!$hashSet.contains($hsValueCast$value)) { | if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { | break; @@ -3503,7 +3527,16 @@ case class ArrayDistinct(child: Expression) | $hashSet.add$hsPostFix($hsValueCast$value); | $builder.$$plus$$eq($value); |} - """.stripMargin) + """.stripMargin + + val processArray = withArrayNullAssignment( + s"$jt $value = ${genGetValue(array, i)};" + + SQLOpenHashSet.withNaNCheckCode(elementType, value, hashSet, body, + (valueNaN: String) => + s""" + |$size++; + |$builder.$$plus$$eq($valueNaN); + |""".stripMargin)) s""" |$openHashSet $hashSet = new $openHashSet$hsPostFix($classTag); @@ -3579,8 +3612,16 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi (array1, array2) => val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] val hs = new SQLOpenHashSet[Any]() - val isNaN = SQLOpenHashSet.isNaN(elementType) - val valueNaN = SQLOpenHashSet.valueNaN(elementType) + val withNaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs, + (value: Any) => + if (!hs.contains(value)) { + if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.size) + } + arrayBuffer += value + hs.add(value) + }, + (valueNaN: Any) => arrayBuffer += valueNaN) Seq(array1, array2).foreach { array => var i = 0 while (i < array.numElements()) { @@ -3591,20 +3632,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi } } else { val elem = array.get(i, elementType) - if (isNaN(elem)) { - if (!hs.containsNaN) { - arrayBuffer += valueNaN - hs.addNaN - } - } else { - if (!hs.contains(elem)) { - if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { - ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.size) - } - arrayBuffer += elem - hs.add(elem) - } - } + withNaNCheckFunc(elem) } i += 1 } @@ -3689,28 +3717,6 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi body } - def withNaNCheck(body: String): String = { - (elementType match { - case DoubleType => - Some((s"java.lang.Double.isNaN((double)$value)", "java.lang.Double.NaN")) - case FloatType => - Some((s"java.lang.Float.isNaN((float)$value)", "java.lang.Float.NaN")) - case _ => None - }).map { case (isNaN, valueNaN) => - s""" - |if ($isNaN) { - | if (!$hashSet.containsNaN()) { - | $size++; - | $hashSet.addNaN(); - | $builder.$$plus$$eq($valueNaN); - | } - |} else { - | $body - |} - """.stripMargin - } - }.getOrElse(body) - val body = s""" |if (!$hashSet.contains($hsValueCast$value)) { @@ -3721,8 +3727,14 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi | $builder.$$plus$$eq($value); |} """.stripMargin - val processArray = - withArrayNullAssignment(s"$jt $value = ${genGetValue(array, i)};" + withNaNCheck(body)) + val processArray = withArrayNullAssignment( + s"$jt $value = ${genGetValue(array, i)};" + + SQLOpenHashSet.withNaNCheckCode(elementType, value, hashSet, body, + (valueNaN: String) => + s""" + |$size++; + |$builder.$$plus$$eq($valueNaN); + """.stripMargin)) // Only need to track null element index when result array's element is nullable. val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SQLOpenHashSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SQLOpenHashSet.scala index 083cfddf07..e09cd95db5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SQLOpenHashSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SQLOpenHashSet.scala @@ -60,21 +60,55 @@ class SQLOpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag]( } object SQLOpenHashSet { - def isNaN(dataType: DataType): Any => Boolean = { - dataType match { + def withNaNCheckFunc( + dataType: DataType, + hashSet: SQLOpenHashSet[Any], + handleNotNaN: Any => Unit, + handleNaN: Any => Unit): Any => Unit = { + val (isNaN, valueNaN) = dataType match { case DoubleType => - (value: Any) => java.lang.Double.isNaN(value.asInstanceOf[java.lang.Double]) + ((value: Any) => java.lang.Double.isNaN(value.asInstanceOf[java.lang.Double]), + java.lang.Double.NaN) case FloatType => - (value: Any) => java.lang.Float.isNaN(value.asInstanceOf[java.lang.Float]) - case _ => (_: Any) => false + ((value: Any) => java.lang.Float.isNaN(value.asInstanceOf[java.lang.Float]), + java.lang.Float.NaN) + case _ => ((_: Any) => false, null) } + (value: Any) => + if (isNaN(value)) { + if (!hashSet.containsNaN) { + hashSet.addNaN + handleNaN(valueNaN) + } + } else { + handleNotNaN(value) + } } - def valueNaN(dataType: DataType): Any = { - dataType match { - case DoubleType => java.lang.Double.NaN - case FloatType => java.lang.Float.NaN - case _ => null + def withNaNCheckCode( + dataType: DataType, + valueName: String, + hashSet: String, + handleNotNaN: String, + handleNaN: String => String): String = { + val ret = dataType match { + case DoubleType => + Some((s"java.lang.Double.isNaN((double)$valueName)", "java.lang.Double.NaN")) + case FloatType => + Some((s"java.lang.Float.isNaN((float)$valueName)", "java.lang.Float.NaN")) + case _ => None } + ret.map { case (isNaN, valueNaN) => + s""" + |if ($isNaN) { + | if (!$hashSet.containsNaN()) { + | $hashSet.addNaN(); + | ${handleNaN(valueNaN)} + | } + |} else { + | $handleNotNaN + |} + """.stripMargin + }.getOrElse(handleNotNaN) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index caca24a212..62098bc840 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -2310,6 +2310,15 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper Seq(Float.NaN, null, 1f)) } + test("SPARK-36741: ArrayDistinct should handle duplicated Double.NaN and Float.Nan") { + checkEvaluation(ArrayDistinct( + Literal.create(Seq(Double.NaN, Double.NaN, null, null, 1d, 1d), ArrayType(DoubleType))), + Seq(Double.NaN, null, 1d)) + checkEvaluation(ArrayDistinct( + Literal.create(Seq(Float.NaN, Float.NaN, null, null, 1f, 1f), ArrayType(FloatType))), + Seq(Float.NaN, null, 1f)) + } + test("SPARK-36755: ArraysOverlap hould handle duplicated Double.NaN and Float.Nan") { checkEvaluation(ArraysOverlap( Literal.apply(Array(Double.NaN, 1d)), Literal.apply(Array(Double.NaN))), true)