[SPARK-36702][SQL] ArrayUnion handle duplicated Double.NaN and Float.Nan
### What changes were proposed in this pull request? For query ``` select array_union(array(cast('nan' as double), cast('nan' as double)), array()) ``` This returns [NaN, NaN], but it should return [NaN]. This issue is caused by `OpenHashSet` can't handle `Double.NaN` and `Float.NaN` too. In this pr we add a wrap for OpenHashSet that can handle `null`, `Double.NaN`, `Float.NaN` together ### Why are the changes needed? Fix bug ### Does this PR introduce _any_ user-facing change? ArrayUnion won't show duplicated `NaN` value ### How was this patch tested? Added UT Closes #33955 from AngersZhuuuu/SPARK-36702-WrapOpenHashSet. Lead-authored-by: Angerszhuuuu <angers.zhu@gmail.com> Co-authored-by: AngersZhuuuu <angers.zhu@gmail.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
2d7dc7c7ce
commit
f71f37755d
|
@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils._
|
||||||
import org.apache.spark.sql.errors.QueryExecutionErrors
|
import org.apache.spark.sql.errors.QueryExecutionErrors
|
||||||
import org.apache.spark.sql.internal.SQLConf
|
import org.apache.spark.sql.internal.SQLConf
|
||||||
import org.apache.spark.sql.types._
|
import org.apache.spark.sql.types._
|
||||||
|
import org.apache.spark.sql.util.SQLOpenHashSet
|
||||||
import org.apache.spark.unsafe.UTF8StringBuilder
|
import org.apache.spark.unsafe.UTF8StringBuilder
|
||||||
import org.apache.spark.unsafe.array.ByteArrayMethods
|
import org.apache.spark.unsafe.array.ByteArrayMethods
|
||||||
import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
|
import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
|
||||||
|
@ -3575,18 +3576,24 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
|
||||||
if (TypeUtils.typeWithProperEquals(elementType)) {
|
if (TypeUtils.typeWithProperEquals(elementType)) {
|
||||||
(array1, array2) =>
|
(array1, array2) =>
|
||||||
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
|
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
|
||||||
val hs = new OpenHashSet[Any]
|
val hs = new SQLOpenHashSet[Any]()
|
||||||
var foundNullElement = false
|
val isNaN = SQLOpenHashSet.isNaN(elementType)
|
||||||
Seq(array1, array2).foreach { array =>
|
Seq(array1, array2).foreach { array =>
|
||||||
var i = 0
|
var i = 0
|
||||||
while (i < array.numElements()) {
|
while (i < array.numElements()) {
|
||||||
if (array.isNullAt(i)) {
|
if (array.isNullAt(i)) {
|
||||||
if (!foundNullElement) {
|
if (!hs.containsNull) {
|
||||||
|
hs.addNull
|
||||||
arrayBuffer += null
|
arrayBuffer += null
|
||||||
foundNullElement = true
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
val elem = array.get(i, elementType)
|
val elem = array.get(i, elementType)
|
||||||
|
if (isNaN(elem)) {
|
||||||
|
if (!hs.containsNaN) {
|
||||||
|
arrayBuffer += elem
|
||||||
|
hs.addNaN
|
||||||
|
}
|
||||||
|
} else {
|
||||||
if (!hs.contains(elem)) {
|
if (!hs.contains(elem)) {
|
||||||
if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
|
if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
|
||||||
ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.size)
|
ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.size)
|
||||||
|
@ -3595,6 +3602,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
|
||||||
hs.add(elem)
|
hs.add(elem)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
i += 1
|
i += 1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -3649,13 +3657,12 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
|
||||||
val ptName = CodeGenerator.primitiveTypeName(jt)
|
val ptName = CodeGenerator.primitiveTypeName(jt)
|
||||||
|
|
||||||
nullSafeCodeGen(ctx, ev, (array1, array2) => {
|
nullSafeCodeGen(ctx, ev, (array1, array2) => {
|
||||||
val foundNullElement = ctx.freshName("foundNullElement")
|
|
||||||
val nullElementIndex = ctx.freshName("nullElementIndex")
|
val nullElementIndex = ctx.freshName("nullElementIndex")
|
||||||
val builder = ctx.freshName("builder")
|
val builder = ctx.freshName("builder")
|
||||||
val array = ctx.freshName("array")
|
val array = ctx.freshName("array")
|
||||||
val arrays = ctx.freshName("arrays")
|
val arrays = ctx.freshName("arrays")
|
||||||
val arrayDataIdx = ctx.freshName("arrayDataIdx")
|
val arrayDataIdx = ctx.freshName("arrayDataIdx")
|
||||||
val openHashSet = classOf[OpenHashSet[_]].getName
|
val openHashSet = classOf[SQLOpenHashSet[_]].getName
|
||||||
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
|
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
|
||||||
val hashSet = ctx.freshName("hashSet")
|
val hashSet = ctx.freshName("hashSet")
|
||||||
val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName
|
val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName
|
||||||
|
@ -3665,9 +3672,9 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
|
||||||
if (dataType.asInstanceOf[ArrayType].containsNull) {
|
if (dataType.asInstanceOf[ArrayType].containsNull) {
|
||||||
s"""
|
s"""
|
||||||
|if ($array.isNullAt($i)) {
|
|if ($array.isNullAt($i)) {
|
||||||
| if (!$foundNullElement) {
|
| if (!$hashSet.containsNull()) {
|
||||||
| $nullElementIndex = $size;
|
| $nullElementIndex = $size;
|
||||||
| $foundNullElement = true;
|
| $hashSet.addNull();
|
||||||
| $size++;
|
| $size++;
|
||||||
| $builder.$$plus$$eq($nullValueHolder);
|
| $builder.$$plus$$eq($nullValueHolder);
|
||||||
| }
|
| }
|
||||||
|
@ -3679,9 +3686,28 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
|
||||||
body
|
body
|
||||||
}
|
}
|
||||||
|
|
||||||
val processArray = withArrayNullAssignment(
|
def withNaNCheck(body: String): String = {
|
||||||
|
(elementType match {
|
||||||
|
case DoubleType => Some(s"java.lang.Double.isNaN((double)$value)")
|
||||||
|
case FloatType => Some(s"java.lang.Float.isNaN((float)$value)")
|
||||||
|
case _ => None
|
||||||
|
}).map { isNaN =>
|
||||||
|
s"""
|
||||||
|
|if ($isNaN) {
|
||||||
|
| if (!$hashSet.containsNaN()) {
|
||||||
|
| $size++;
|
||||||
|
| $hashSet.addNaN();
|
||||||
|
| $builder.$$plus$$eq($value);
|
||||||
|
| }
|
||||||
|
|} else {
|
||||||
|
| $body
|
||||||
|
|}
|
||||||
|
""".stripMargin
|
||||||
|
}
|
||||||
|
}.getOrElse(body)
|
||||||
|
|
||||||
|
val body =
|
||||||
s"""
|
s"""
|
||||||
|$jt $value = ${genGetValue(array, i)};
|
|
||||||
|if (!$hashSet.contains($hsValueCast$value)) {
|
|if (!$hashSet.contains($hsValueCast$value)) {
|
||||||
| if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
|
| if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
|
||||||
| break;
|
| break;
|
||||||
|
@ -3689,12 +3715,13 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
|
||||||
| $hashSet.add$hsPostFix($hsValueCast$value);
|
| $hashSet.add$hsPostFix($hsValueCast$value);
|
||||||
| $builder.$$plus$$eq($value);
|
| $builder.$$plus$$eq($value);
|
||||||
|}
|
|}
|
||||||
""".stripMargin)
|
""".stripMargin
|
||||||
|
val processArray =
|
||||||
|
withArrayNullAssignment(s"$jt $value = ${genGetValue(array, i)};" + withNaNCheck(body))
|
||||||
|
|
||||||
// Only need to track null element index when result array's element is nullable.
|
// Only need to track null element index when result array's element is nullable.
|
||||||
val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) {
|
val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) {
|
||||||
s"""
|
s"""
|
||||||
|boolean $foundNullElement = false;
|
|
||||||
|int $nullElementIndex = -1;
|
|int $nullElementIndex = -1;
|
||||||
""".stripMargin
|
""".stripMargin
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -0,0 +1,72 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.sql.util
|
||||||
|
|
||||||
|
import scala.reflect._
|
||||||
|
|
||||||
|
import org.apache.spark.annotation.Private
|
||||||
|
import org.apache.spark.sql.types.{DataType, DoubleType, FloatType}
|
||||||
|
import org.apache.spark.util.collection.OpenHashSet
|
||||||
|
|
||||||
|
// A wrap of OpenHashSet that can handle null, Double.NaN and Float.NaN w.r.t. the SQL semantic.
|
||||||
|
@Private
|
||||||
|
class SQLOpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag](
|
||||||
|
initialCapacity: Int,
|
||||||
|
loadFactor: Double) {
|
||||||
|
|
||||||
|
def this(initialCapacity: Int) = this(initialCapacity, 0.7)
|
||||||
|
|
||||||
|
def this() = this(64)
|
||||||
|
|
||||||
|
private val hashSet = new OpenHashSet[T](initialCapacity, loadFactor)
|
||||||
|
|
||||||
|
private var containNull = false
|
||||||
|
private var containNaN = false
|
||||||
|
|
||||||
|
def addNull(): Unit = {
|
||||||
|
containNull = true
|
||||||
|
}
|
||||||
|
|
||||||
|
def addNaN(): Unit = {
|
||||||
|
containNaN = true
|
||||||
|
}
|
||||||
|
|
||||||
|
def add(k: T): Unit = {
|
||||||
|
hashSet.add(k)
|
||||||
|
}
|
||||||
|
|
||||||
|
def contains(k: T): Boolean = {
|
||||||
|
hashSet.contains(k)
|
||||||
|
}
|
||||||
|
|
||||||
|
def containsNull(): Boolean = containNull
|
||||||
|
|
||||||
|
def containsNaN(): Boolean = containNaN
|
||||||
|
}
|
||||||
|
|
||||||
|
object SQLOpenHashSet {
|
||||||
|
def isNaN(dataType: DataType): Any => Boolean = {
|
||||||
|
dataType match {
|
||||||
|
case DoubleType =>
|
||||||
|
(value: Any) => java.lang.Double.isNaN(value.asInstanceOf[java.lang.Double])
|
||||||
|
case FloatType =>
|
||||||
|
(value: Any) => java.lang.Float.isNaN(value.asInstanceOf[java.lang.Float])
|
||||||
|
case _ => (_: Any) => false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -2309,4 +2309,21 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("SPARK-36702: ArrayUnion should handle duplicated Double.NaN and Float.Nan") {
|
||||||
|
checkEvaluation(ArrayUnion(
|
||||||
|
Literal.apply(Array(Double.NaN, Double.NaN)), Literal.apply(Array(1d))),
|
||||||
|
Seq(Double.NaN, 1d))
|
||||||
|
checkEvaluation(ArrayUnion(
|
||||||
|
Literal.create(Seq(Double.NaN, null), ArrayType(DoubleType)),
|
||||||
|
Literal.create(Seq(Double.NaN, null, 1d), ArrayType(DoubleType))),
|
||||||
|
Seq(Double.NaN, null, 1d))
|
||||||
|
checkEvaluation(ArrayUnion(
|
||||||
|
Literal.apply(Array(Float.NaN, Float.NaN)), Literal.apply(Array(1f))),
|
||||||
|
Seq(Float.NaN, 1f))
|
||||||
|
checkEvaluation(ArrayUnion(
|
||||||
|
Literal.create(Seq(Float.NaN, null), ArrayType(FloatType)),
|
||||||
|
Literal.create(Seq(Float.NaN, null, 1f), ArrayType(FloatType))),
|
||||||
|
Seq(Float.NaN, null, 1f))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue