[SPARK-36702][SQL] ArrayUnion should handle duplicated Double.NaN and Float.NaN

### What changes were proposed in this pull request?
For the query
```
select array_union(array(cast('nan' as double), cast('nan' as double)), array())
```
This returns `[NaN, NaN]`, but it should return `[NaN]`.
The issue is that `OpenHashSet` cannot deduplicate `Double.NaN` and `Float.NaN` either.
This PR adds a wrapper around `OpenHashSet` that handles `null`, `Double.NaN`, and `Float.NaN` together.
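For context, the failure mode is ordinary IEEE-754 semantics: `NaN` is never equal to itself under primitive comparison, so a hash set that deduplicates with primitive `==` can never recognize a second `NaN`. A minimal illustration in plain Scala (not Spark code, just the language semantics the fix has to work around):

```scala
object NaNEquality {
  def main(args: Array[String]): Unit = {
    // Primitive comparison: NaN != NaN, so ==-based set lookups always miss.
    println(Double.NaN == Double.NaN) // false

    // Boxed equality: java.lang.Double#equals compares bit patterns and
    // treats NaN as equal to NaN, the semantics array_union should follow.
    val a: java.lang.Double = Double.NaN
    val b: java.lang.Double = Double.NaN
    println(a.equals(b)) // true
  }
}
```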

### Why are the changes needed?
Fixes a correctness bug: `array_union` keeps duplicated `NaN` elements.

### Does this PR introduce _any_ user-facing change?
Yes. `ArrayUnion` no longer returns duplicated `NaN` values; the query above now returns `[NaN]` instead of `[NaN, NaN]`.

### How was this patch tested?
Added unit tests (see `CollectionExpressionsSuite` below).

Closes #33955 from AngersZhuuuu/SPARK-36702-WrapOpenHashSet.

Lead-authored-by: Angerszhuuuu <angers.zhu@gmail.com>
Co-authored-by: AngersZhuuuu <angers.zhu@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
3 changed files with 133 additions and 17 deletions

**collectionOperations.scala** (`ArrayUnion` interpreted and codegen paths)

```diff
@@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils._
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.SQLOpenHashSet
 import org.apache.spark.unsafe.UTF8StringBuilder
 import org.apache.spark.unsafe.array.ByteArrayMethods
 import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
@@ -3575,18 +3576,24 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLike
     if (TypeUtils.typeWithProperEquals(elementType)) {
       (array1, array2) =>
         val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
-        val hs = new OpenHashSet[Any]
-        var foundNullElement = false
+        val hs = new SQLOpenHashSet[Any]()
+        val isNaN = SQLOpenHashSet.isNaN(elementType)
         Seq(array1, array2).foreach { array =>
           var i = 0
           while (i < array.numElements()) {
             if (array.isNullAt(i)) {
-              if (!foundNullElement) {
+              if (!hs.containsNull) {
+                hs.addNull
                 arrayBuffer += null
-                foundNullElement = true
               }
             } else {
               val elem = array.get(i, elementType)
+              if (isNaN(elem)) {
+                if (!hs.containsNaN) {
+                  arrayBuffer += elem
+                  hs.addNaN
+                }
+              } else {
               if (!hs.contains(elem)) {
                 if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
                   ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.size)
@@ -3595,6 +3602,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLike
                 hs.add(elem)
               }
             }
+            }
             i += 1
           }
         }
@@ -3649,13 +3657,12 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLike
     val ptName = CodeGenerator.primitiveTypeName(jt)

     nullSafeCodeGen(ctx, ev, (array1, array2) => {
-      val foundNullElement = ctx.freshName("foundNullElement")
       val nullElementIndex = ctx.freshName("nullElementIndex")
       val builder = ctx.freshName("builder")
       val array = ctx.freshName("array")
       val arrays = ctx.freshName("arrays")
       val arrayDataIdx = ctx.freshName("arrayDataIdx")
-      val openHashSet = classOf[OpenHashSet[_]].getName
+      val openHashSet = classOf[SQLOpenHashSet[_]].getName
       val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
       val hashSet = ctx.freshName("hashSet")
       val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName
@@ -3665,9 +3672,9 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLike
       if (dataType.asInstanceOf[ArrayType].containsNull) {
         s"""
            |if ($array.isNullAt($i)) {
-           |  if (!$foundNullElement) {
+           |  if (!$hashSet.containsNull()) {
            |    $nullElementIndex = $size;
-           |    $foundNullElement = true;
+           |    $hashSet.addNull();
            |    $size++;
            |    $builder.$$plus$$eq($nullValueHolder);
            |  }
@@ -3679,9 +3686,28 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLike
           body
       }

-      val processArray = withArrayNullAssignment(
+      def withNaNCheck(body: String): String = {
+        (elementType match {
+          case DoubleType => Some(s"java.lang.Double.isNaN((double)$value)")
+          case FloatType => Some(s"java.lang.Float.isNaN((float)$value)")
+          case _ => None
+        }).map { isNaN =>
+          s"""
+             |if ($isNaN) {
+             |  if (!$hashSet.containsNaN()) {
+             |    $size++;
+             |    $hashSet.addNaN();
+             |    $builder.$$plus$$eq($value);
+             |  }
+             |} else {
+             |  $body
+             |}
+           """.stripMargin
+        }.getOrElse(body)
+      }
+
+      val body =
         s"""
-           |$jt $value = ${genGetValue(array, i)};
            |if (!$hashSet.contains($hsValueCast$value)) {
            |  if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
            |    break;
@@ -3689,12 +3715,13 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLike
            |  $hashSet.add$hsPostFix($hsValueCast$value);
            |  $builder.$$plus$$eq($value);
            |}
-         """.stripMargin)
+         """.stripMargin
+      val processArray =
+        withArrayNullAssignment(s"$jt $value = ${genGetValue(array, i)};" + withNaNCheck(body))

       // Only need to track null element index when result array's element is nullable.
       val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) {
         s"""
-           |boolean $foundNullElement = false;
            |int $nullElementIndex = -1;
          """.stripMargin
       } else {
```
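For readers unfamiliar with the codegen style above: the patch composes string "decorators", with `withArrayNullAssignment(... + withNaNCheck(body))` wrapping the dedup body first in a NaN branch and then in a null branch. A standalone sketch of that composition pattern in plain Scala (the names `arr`, `v`, `i` and the placeholder comments are illustrative only, not the actual generated output):

```scala
object CodegenCompositionSketch {
  // Mirrors withNaNCheck: wrap the dedup body so a NaN element is
  // recorded at most once before the normal hash-set path runs.
  def withNaNCheck(value: String, body: String): String =
    s"""if (java.lang.Double.isNaN((double)$value)) {
       |  /* record the NaN once, e.g. hashSet.addNaN() */
       |} else {
       |  $body
       |}""".stripMargin

  // Mirrors withArrayNullAssignment: wrap everything in the null check.
  def withNullCheck(array: String, i: String, body: String): String =
    s"""if ($array.isNullAt($i)) {
       |  /* record the null once, e.g. hashSet.addNull() */
       |} else {
       |  $body
       |}""".stripMargin

  def main(args: Array[String]): Unit = {
    val dedupBody = "/* add v to hashSet and builder if unseen */"
    val element = "double v = arr.getDouble(i);\n" + withNaNCheck("v", dedupBody)
    println(withNullCheck("arr", "i", element))
  }
}
```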

**SQLOpenHashSet.scala** (new file, +72 lines)

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.util

import scala.reflect._

import org.apache.spark.annotation.Private
import org.apache.spark.sql.types.{DataType, DoubleType, FloatType}
import org.apache.spark.util.collection.OpenHashSet

// A wrapper of OpenHashSet that can handle null, Double.NaN and Float.NaN
// w.r.t. the SQL semantics.
@Private
class SQLOpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag](
    initialCapacity: Int,
    loadFactor: Double) {

  def this(initialCapacity: Int) = this(initialCapacity, 0.7)

  def this() = this(64)

  private val hashSet = new OpenHashSet[T](initialCapacity, loadFactor)

  // null and NaN are tracked with dedicated flags rather than stored in
  // the underlying set, so each is deduplicated exactly once.
  private var containNull = false
  private var containNaN = false

  def addNull(): Unit = {
    containNull = true
  }

  def addNaN(): Unit = {
    containNaN = true
  }

  def add(k: T): Unit = {
    hashSet.add(k)
  }

  def contains(k: T): Boolean = {
    hashSet.contains(k)
  }

  def containsNull(): Boolean = containNull

  def containsNaN(): Boolean = containNaN
}

object SQLOpenHashSet {
  // Returns a NaN predicate for the given element type; non-floating-point
  // types can never contain NaN, so the predicate is constantly false.
  def isNaN(dataType: DataType): Any => Boolean = {
    dataType match {
      case DoubleType =>
        (value: Any) => java.lang.Double.isNaN(value.asInstanceOf[java.lang.Double])
      case FloatType =>
        (value: Any) => java.lang.Float.isNaN(value.asInstanceOf[java.lang.Float])
      case _ => (_: Any) => false
    }
  }
}
```
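A minimal sketch of how a caller is meant to drive the wrapper, mirroring the interpreted `ArrayUnion` path above (the element sequence and assertions are illustrative; the real code iterates over `ArrayData`, not a `Seq`):

```scala
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.util.SQLOpenHashSet

object SQLOpenHashSetUsage {
  def main(args: Array[String]): Unit = {
    val hs = new SQLOpenHashSet[Any]()
    val isNaN = SQLOpenHashSet.isNaN(DoubleType)

    // null and NaN bypass the underlying OpenHashSet entirely and are
    // deduplicated via the dedicated flags.
    Seq(Double.NaN, Double.NaN, null, 1d).foreach {
      case null => if (!hs.containsNull()) hs.addNull()
      case v if isNaN(v) => if (!hs.containsNaN()) hs.addNaN()
      case v => if (!hs.contains(v)) hs.add(v)
    }

    assert(hs.containsNull() && hs.containsNaN() && hs.contains(1d))
  }
}
```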

**CollectionExpressionsSuite.scala** (new test)

```diff
@@ -2309,4 +2309,21 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
       }
     }
   }
+
+  test("SPARK-36702: ArrayUnion should handle duplicated Double.NaN and Float.NaN") {
+    checkEvaluation(ArrayUnion(
+      Literal.apply(Array(Double.NaN, Double.NaN)), Literal.apply(Array(1d))),
+      Seq(Double.NaN, 1d))
+    checkEvaluation(ArrayUnion(
+      Literal.create(Seq(Double.NaN, null), ArrayType(DoubleType)),
+      Literal.create(Seq(Double.NaN, null, 1d), ArrayType(DoubleType))),
+      Seq(Double.NaN, null, 1d))
+    checkEvaluation(ArrayUnion(
+      Literal.apply(Array(Float.NaN, Float.NaN)), Literal.apply(Array(1f))),
+      Seq(Float.NaN, 1f))
+    checkEvaluation(ArrayUnion(
+      Literal.create(Seq(Float.NaN, null), ArrayType(FloatType)),
+      Literal.create(Seq(Float.NaN, null, 1f), ArrayType(FloatType))),
+      Seq(Float.NaN, null, 1f))
+  }
 }
```