[SPARK-36702][SQL] ArrayUnion handle duplicated Double.NaN and Float.NaN
### What changes were proposed in this pull request?
For query
```
select array_union(array(cast('nan' as double), cast('nan' as double)), array())
```
This returns [NaN, NaN], but it should return [NaN].
This issue is caused by `OpenHashSet` not being able to handle `Double.NaN` and `Float.NaN` either.
In this PR we add a wrapper around `OpenHashSet` that can handle `null`, `Double.NaN`, and `Float.NaN` together.
### Why are the changes needed?
Fix bug
### Does this PR introduce _any_ user-facing change?
ArrayUnion won't show duplicated `NaN` values.
### How was this patch tested?
Added UT
Closes #33955 from AngersZhuuuu/SPARK-36702-WrapOpenHashSet.
Lead-authored-by: Angerszhuuuu <angers.zhu@gmail.com>
Co-authored-by: AngersZhuuuu <angers.zhu@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit f71f37755d
)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
4a486f40cf
commit
a472612eb8
|
@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils._
|
|||
import org.apache.spark.sql.errors.QueryExecutionErrors
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.sql.util.SQLOpenHashSet
|
||||
import org.apache.spark.unsafe.UTF8StringBuilder
|
||||
import org.apache.spark.unsafe.array.ByteArrayMethods
|
||||
import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
|
||||
|
@ -3575,18 +3576,24 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
|
|||
if (TypeUtils.typeWithProperEquals(elementType)) {
|
||||
(array1, array2) =>
|
||||
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
|
||||
val hs = new OpenHashSet[Any]
|
||||
var foundNullElement = false
|
||||
val hs = new SQLOpenHashSet[Any]()
|
||||
val isNaN = SQLOpenHashSet.isNaN(elementType)
|
||||
Seq(array1, array2).foreach { array =>
|
||||
var i = 0
|
||||
while (i < array.numElements()) {
|
||||
if (array.isNullAt(i)) {
|
||||
if (!foundNullElement) {
|
||||
if (!hs.containsNull) {
|
||||
hs.addNull
|
||||
arrayBuffer += null
|
||||
foundNullElement = true
|
||||
}
|
||||
} else {
|
||||
val elem = array.get(i, elementType)
|
||||
if (isNaN(elem)) {
|
||||
if (!hs.containsNaN) {
|
||||
arrayBuffer += elem
|
||||
hs.addNaN
|
||||
}
|
||||
} else {
|
||||
if (!hs.contains(elem)) {
|
||||
if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
|
||||
ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.size)
|
||||
|
@ -3595,6 +3602,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
|
|||
hs.add(elem)
|
||||
}
|
||||
}
|
||||
}
|
||||
i += 1
|
||||
}
|
||||
}
|
||||
|
@ -3649,13 +3657,12 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
|
|||
val ptName = CodeGenerator.primitiveTypeName(jt)
|
||||
|
||||
nullSafeCodeGen(ctx, ev, (array1, array2) => {
|
||||
val foundNullElement = ctx.freshName("foundNullElement")
|
||||
val nullElementIndex = ctx.freshName("nullElementIndex")
|
||||
val builder = ctx.freshName("builder")
|
||||
val array = ctx.freshName("array")
|
||||
val arrays = ctx.freshName("arrays")
|
||||
val arrayDataIdx = ctx.freshName("arrayDataIdx")
|
||||
val openHashSet = classOf[OpenHashSet[_]].getName
|
||||
val openHashSet = classOf[SQLOpenHashSet[_]].getName
|
||||
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
|
||||
val hashSet = ctx.freshName("hashSet")
|
||||
val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName
|
||||
|
@ -3665,9 +3672,9 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
|
|||
if (dataType.asInstanceOf[ArrayType].containsNull) {
|
||||
s"""
|
||||
|if ($array.isNullAt($i)) {
|
||||
| if (!$foundNullElement) {
|
||||
| if (!$hashSet.containsNull()) {
|
||||
| $nullElementIndex = $size;
|
||||
| $foundNullElement = true;
|
||||
| $hashSet.addNull();
|
||||
| $size++;
|
||||
| $builder.$$plus$$eq($nullValueHolder);
|
||||
| }
|
||||
|
@ -3679,9 +3686,28 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
|
|||
body
|
||||
}
|
||||
|
||||
val processArray = withArrayNullAssignment(
|
||||
def withNaNCheck(body: String): String = {
|
||||
(elementType match {
|
||||
case DoubleType => Some(s"java.lang.Double.isNaN((double)$value)")
|
||||
case FloatType => Some(s"java.lang.Float.isNaN((float)$value)")
|
||||
case _ => None
|
||||
}).map { isNaN =>
|
||||
s"""
|
||||
|if ($isNaN) {
|
||||
| if (!$hashSet.containsNaN()) {
|
||||
| $size++;
|
||||
| $hashSet.addNaN();
|
||||
| $builder.$$plus$$eq($value);
|
||||
| }
|
||||
|} else {
|
||||
| $body
|
||||
|}
|
||||
""".stripMargin
|
||||
}
|
||||
}.getOrElse(body)
|
||||
|
||||
val body =
|
||||
s"""
|
||||
|$jt $value = ${genGetValue(array, i)};
|
||||
|if (!$hashSet.contains($hsValueCast$value)) {
|
||||
| if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
|
||||
| break;
|
||||
|
@ -3689,12 +3715,13 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
|
|||
| $hashSet.add$hsPostFix($hsValueCast$value);
|
||||
| $builder.$$plus$$eq($value);
|
||||
|}
|
||||
""".stripMargin)
|
||||
""".stripMargin
|
||||
val processArray =
|
||||
withArrayNullAssignment(s"$jt $value = ${genGetValue(array, i)};" + withNaNCheck(body))
|
||||
|
||||
// Only need to track null element index when result array's element is nullable.
|
||||
val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) {
|
||||
s"""
|
||||
|boolean $foundNullElement = false;
|
||||
|int $nullElementIndex = -1;
|
||||
""".stripMargin
|
||||
} else {
|
||||
|
|
|
@ -0,0 +1,72 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.util
|
||||
|
||||
import scala.reflect._
|
||||
|
||||
import org.apache.spark.annotation.Private
|
||||
import org.apache.spark.sql.types.{DataType, DoubleType, FloatType}
|
||||
import org.apache.spark.util.collection.OpenHashSet
|
||||
|
||||
// A wrapper around OpenHashSet that additionally tracks null, Double.NaN and
// Float.NaN with the SQL semantic: each is treated as a single distinct value,
// which the underlying OpenHashSet cannot represent correctly on its own.
@Private
class SQLOpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag](
    initialCapacity: Int,
    loadFactor: Double) {

  def this(initialCapacity: Int) = this(initialCapacity, 0.7)

  def this() = this(64)

  // Backing set for all "ordinary" elements (neither null nor NaN).
  private val hashSet = new OpenHashSet[T](initialCapacity, loadFactor)

  // Flag recording whether a null element has been added.
  private var hasNull = false
  // Flag recording whether a NaN element (Double.NaN or Float.NaN) has been added.
  private var hasNaN = false

  /** Records that a null element was seen. */
  def addNull(): Unit = hasNull = true

  /** Records that a NaN element was seen. */
  def addNaN(): Unit = hasNaN = true

  /** Adds an ordinary (non-null, non-NaN) element; callers filter null/NaN first. */
  def add(k: T): Unit = hashSet.add(k)

  /** Membership test for ordinary elements; null/NaN use the dedicated flags. */
  def contains(k: T): Boolean = hashSet.contains(k)

  /** Whether a null element has been recorded via [[addNull]]. */
  def containsNull(): Boolean = hasNull

  /** Whether a NaN element has been recorded via [[addNaN]]. */
  def containsNaN(): Boolean = hasNaN
}
|
||||
|
||||
object SQLOpenHashSet {
  /**
   * Returns a predicate that reports whether a boxed value of the given
   * `dataType` is NaN. For non-floating-point types the predicate is
   * constantly false. The java.lang unboxing path is kept deliberately so
   * the behavior (including on unexpected inputs) matches exactly.
   */
  def isNaN(dataType: DataType): Any => Boolean = dataType match {
    case DoubleType =>
      (v: Any) => java.lang.Double.isNaN(v.asInstanceOf[java.lang.Double])
    case FloatType =>
      (v: Any) => java.lang.Float.isNaN(v.asInstanceOf[java.lang.Float])
    case _ =>
      (_: Any) => false
  }
}
|
|
@ -2292,4 +2292,21 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Regression test for SPARK-36702: ArrayUnion must de-duplicate NaN values
// (previously OpenHashSet treated each NaN as distinct, yielding [NaN, NaN]),
// while leaving null de-duplication behavior intact.
test("SPARK-36702: ArrayUnion should handle duplicated Double.NaN and Float.Nan") {
  // Duplicated Double.NaN on one side collapses to a single NaN.
  checkEvaluation(ArrayUnion(
    Literal.apply(Array(Double.NaN, Double.NaN)), Literal.apply(Array(1d))),
    Seq(Double.NaN, 1d))
  // NaN and null appearing on both sides are each kept exactly once.
  checkEvaluation(ArrayUnion(
    Literal.create(Seq(Double.NaN, null), ArrayType(DoubleType)),
    Literal.create(Seq(Double.NaN, null, 1d), ArrayType(DoubleType))),
    Seq(Double.NaN, null, 1d))
  // Same checks for Float.NaN.
  checkEvaluation(ArrayUnion(
    Literal.apply(Array(Float.NaN, Float.NaN)), Literal.apply(Array(1f))),
    Seq(Float.NaN, 1f))
  checkEvaluation(ArrayUnion(
    Literal.create(Seq(Float.NaN, null), ArrayType(FloatType)),
    Literal.create(Seq(Float.NaN, null, 1f), ArrayType(FloatType))),
    Seq(Float.NaN, null, 1f))
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue