[SPARK-36741][SQL] ArrayDistinct handle duplicated Double.NaN and Float.NaN
### What changes were proposed in this pull request?
For query
```
select array_distinct(array(cast('nan' as double), cast('nan' as double)))
```
This returns [NaN, NaN], but it should return [NaN].
This issue is caused by the fact that `OpenHashSet` can't handle `Double.NaN` and `Float.NaN` either.
This PR fixes the issue, based on https://github.com/apache/spark/pull/33955
### Why are the changes needed?
Fix bug
### Does this PR introduce _any_ user-facing change?
Yes. `ArrayDistinct` will no longer return duplicated `NaN` values.
### How was this patch tested?
Added UT
Closes #33993 from AngersZhuuuu/SPARK-36741.
Authored-by: Angerszhuuuu <angers.zhu@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit e356f6aa11)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
16215755b7
commit
a78c4c44ed
|
@ -3412,32 +3412,59 @@ case class ArrayDistinct(child: Expression)
|
|||
}
|
||||
|
||||
override def nullSafeEval(array: Any): Any = {
|
||||
val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType)
|
||||
val data = array.asInstanceOf[ArrayData]
|
||||
doEvaluation(data)
|
||||
}
|
||||
|
||||
@transient private lazy val doEvaluation = if (TypeUtils.typeWithProperEquals(elementType)) {
|
||||
(data: Array[AnyRef]) => new GenericArrayData(data.distinct.asInstanceOf[Array[Any]])
|
||||
(array: ArrayData) =>
|
||||
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
|
||||
val hs = new SQLOpenHashSet[Any]()
|
||||
val withNaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs,
|
||||
(value: Any) =>
|
||||
if (!hs.contains(value)) {
|
||||
if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
|
||||
ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.size)
|
||||
}
|
||||
arrayBuffer += value
|
||||
hs.add(value)
|
||||
},
|
||||
(valueNaN: Any) => arrayBuffer += valueNaN)
|
||||
var i = 0
|
||||
while (i < array.numElements()) {
|
||||
if (array.isNullAt(i)) {
|
||||
if (!hs.containsNull) {
|
||||
hs.addNull
|
||||
arrayBuffer += null
|
||||
}
|
||||
} else {
|
||||
val elem = array.get(i, elementType)
|
||||
withNaNCheckFunc(elem)
|
||||
}
|
||||
i += 1
|
||||
}
|
||||
new GenericArrayData(arrayBuffer.toSeq)
|
||||
} else {
|
||||
(data: Array[AnyRef]) => {
|
||||
(data: ArrayData) => {
|
||||
val array = data.toArray[AnyRef](elementType)
|
||||
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[AnyRef]
|
||||
var alreadyStoredNull = false
|
||||
for (i <- 0 until data.length) {
|
||||
if (data(i) != null) {
|
||||
for (i <- 0 until array.length) {
|
||||
if (array(i) != null) {
|
||||
var found = false
|
||||
var j = 0
|
||||
while (!found && j < arrayBuffer.size) {
|
||||
val va = arrayBuffer(j)
|
||||
found = (va != null) && ordering.equiv(va, data(i))
|
||||
found = (va != null) && ordering.equiv(va, array(i))
|
||||
j += 1
|
||||
}
|
||||
if (!found) {
|
||||
arrayBuffer += data(i)
|
||||
arrayBuffer += array(i)
|
||||
}
|
||||
} else {
|
||||
// De-duplicate the null values.
|
||||
if (!alreadyStoredNull) {
|
||||
arrayBuffer += data(i)
|
||||
arrayBuffer += array(i)
|
||||
alreadyStoredNull = true
|
||||
}
|
||||
}
|
||||
|
@ -3456,10 +3483,9 @@ case class ArrayDistinct(child: Expression)
|
|||
val ptName = CodeGenerator.primitiveTypeName(jt)
|
||||
|
||||
nullSafeCodeGen(ctx, ev, (array) => {
|
||||
val foundNullElement = ctx.freshName("foundNullElement")
|
||||
val nullElementIndex = ctx.freshName("nullElementIndex")
|
||||
val builder = ctx.freshName("builder")
|
||||
val openHashSet = classOf[OpenHashSet[_]].getName
|
||||
val openHashSet = classOf[SQLOpenHashSet[_]].getName
|
||||
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
|
||||
val hashSet = ctx.freshName("hashSet")
|
||||
val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName
|
||||
|
@ -3468,7 +3494,6 @@ case class ArrayDistinct(child: Expression)
|
|||
// Only need to track null element index when array's element is nullable.
|
||||
val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) {
|
||||
s"""
|
||||
|boolean $foundNullElement = false;
|
||||
|int $nullElementIndex = -1;
|
||||
""".stripMargin
|
||||
} else {
|
||||
|
@ -3479,9 +3504,9 @@ case class ArrayDistinct(child: Expression)
|
|||
if (dataType.asInstanceOf[ArrayType].containsNull) {
|
||||
s"""
|
||||
|if ($array.isNullAt($i)) {
|
||||
| if (!$foundNullElement) {
|
||||
| if (!$hashSet.containsNull()) {
|
||||
| $nullElementIndex = $size;
|
||||
| $foundNullElement = true;
|
||||
| $hashSet.addNull();
|
||||
| $size++;
|
||||
| $builder.$$plus$$eq($nullValueHolder);
|
||||
| }
|
||||
|
@ -3493,9 +3518,8 @@ case class ArrayDistinct(child: Expression)
|
|||
body
|
||||
}
|
||||
|
||||
val processArray = withArrayNullAssignment(
|
||||
val body =
|
||||
s"""
|
||||
|$jt $value = ${genGetValue(array, i)};
|
||||
|if (!$hashSet.contains($hsValueCast$value)) {
|
||||
| if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
|
||||
| break;
|
||||
|
@ -3503,7 +3527,16 @@ case class ArrayDistinct(child: Expression)
|
|||
| $hashSet.add$hsPostFix($hsValueCast$value);
|
||||
| $builder.$$plus$$eq($value);
|
||||
|}
|
||||
""".stripMargin)
|
||||
""".stripMargin
|
||||
|
||||
val processArray = withArrayNullAssignment(
|
||||
s"$jt $value = ${genGetValue(array, i)};" +
|
||||
SQLOpenHashSet.withNaNCheckCode(elementType, value, hashSet, body,
|
||||
(valueNaN: String) =>
|
||||
s"""
|
||||
|$size++;
|
||||
|$builder.$$plus$$eq($valueNaN);
|
||||
|""".stripMargin))
|
||||
|
||||
s"""
|
||||
|$openHashSet $hashSet = new $openHashSet$hsPostFix($classTag);
|
||||
|
@ -3579,8 +3612,16 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
|
|||
(array1, array2) =>
|
||||
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
|
||||
val hs = new SQLOpenHashSet[Any]()
|
||||
val isNaN = SQLOpenHashSet.isNaN(elementType)
|
||||
val valueNaN = SQLOpenHashSet.valueNaN(elementType)
|
||||
val withNaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs,
|
||||
(value: Any) =>
|
||||
if (!hs.contains(value)) {
|
||||
if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
|
||||
ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.size)
|
||||
}
|
||||
arrayBuffer += value
|
||||
hs.add(value)
|
||||
},
|
||||
(valueNaN: Any) => arrayBuffer += valueNaN)
|
||||
Seq(array1, array2).foreach { array =>
|
||||
var i = 0
|
||||
while (i < array.numElements()) {
|
||||
|
@ -3591,20 +3632,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
|
|||
}
|
||||
} else {
|
||||
val elem = array.get(i, elementType)
|
||||
if (isNaN(elem)) {
|
||||
if (!hs.containsNaN) {
|
||||
arrayBuffer += valueNaN
|
||||
hs.addNaN
|
||||
}
|
||||
} else {
|
||||
if (!hs.contains(elem)) {
|
||||
if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
|
||||
ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.size)
|
||||
}
|
||||
arrayBuffer += elem
|
||||
hs.add(elem)
|
||||
}
|
||||
}
|
||||
withNaNCheckFunc(elem)
|
||||
}
|
||||
i += 1
|
||||
}
|
||||
|
@ -3689,28 +3717,6 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
|
|||
body
|
||||
}
|
||||
|
||||
def withNaNCheck(body: String): String = {
|
||||
(elementType match {
|
||||
case DoubleType =>
|
||||
Some((s"java.lang.Double.isNaN((double)$value)", "java.lang.Double.NaN"))
|
||||
case FloatType =>
|
||||
Some((s"java.lang.Float.isNaN((float)$value)", "java.lang.Float.NaN"))
|
||||
case _ => None
|
||||
}).map { case (isNaN, valueNaN) =>
|
||||
s"""
|
||||
|if ($isNaN) {
|
||||
| if (!$hashSet.containsNaN()) {
|
||||
| $size++;
|
||||
| $hashSet.addNaN();
|
||||
| $builder.$$plus$$eq($valueNaN);
|
||||
| }
|
||||
|} else {
|
||||
| $body
|
||||
|}
|
||||
""".stripMargin
|
||||
}
|
||||
}.getOrElse(body)
|
||||
|
||||
val body =
|
||||
s"""
|
||||
|if (!$hashSet.contains($hsValueCast$value)) {
|
||||
|
@ -3721,8 +3727,14 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
|
|||
| $builder.$$plus$$eq($value);
|
||||
|}
|
||||
""".stripMargin
|
||||
val processArray =
|
||||
withArrayNullAssignment(s"$jt $value = ${genGetValue(array, i)};" + withNaNCheck(body))
|
||||
val processArray = withArrayNullAssignment(
|
||||
s"$jt $value = ${genGetValue(array, i)};" +
|
||||
SQLOpenHashSet.withNaNCheckCode(elementType, value, hashSet, body,
|
||||
(valueNaN: String) =>
|
||||
s"""
|
||||
|$size++;
|
||||
|$builder.$$plus$$eq($valueNaN);
|
||||
""".stripMargin))
|
||||
|
||||
// Only need to track null element index when result array's element is nullable.
|
||||
val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) {
|
||||
|
|
|
@ -60,21 +60,55 @@ class SQLOpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag](
|
|||
}
|
||||
|
||||
object SQLOpenHashSet {
|
||||
def isNaN(dataType: DataType): Any => Boolean = {
|
||||
dataType match {
|
||||
def withNaNCheckFunc(
|
||||
dataType: DataType,
|
||||
hashSet: SQLOpenHashSet[Any],
|
||||
handleNotNaN: Any => Unit,
|
||||
handleNaN: Any => Unit): Any => Unit = {
|
||||
val (isNaN, valueNaN) = dataType match {
|
||||
case DoubleType =>
|
||||
(value: Any) => java.lang.Double.isNaN(value.asInstanceOf[java.lang.Double])
|
||||
((value: Any) => java.lang.Double.isNaN(value.asInstanceOf[java.lang.Double]),
|
||||
java.lang.Double.NaN)
|
||||
case FloatType =>
|
||||
(value: Any) => java.lang.Float.isNaN(value.asInstanceOf[java.lang.Float])
|
||||
case _ => (_: Any) => false
|
||||
((value: Any) => java.lang.Float.isNaN(value.asInstanceOf[java.lang.Float]),
|
||||
java.lang.Float.NaN)
|
||||
case _ => ((_: Any) => false, null)
|
||||
}
|
||||
(value: Any) =>
|
||||
if (isNaN(value)) {
|
||||
if (!hashSet.containsNaN) {
|
||||
hashSet.addNaN
|
||||
handleNaN(valueNaN)
|
||||
}
|
||||
} else {
|
||||
handleNotNaN(value)
|
||||
}
|
||||
}
|
||||
|
||||
def valueNaN(dataType: DataType): Any = {
|
||||
dataType match {
|
||||
case DoubleType => java.lang.Double.NaN
|
||||
case FloatType => java.lang.Float.NaN
|
||||
case _ => null
|
||||
def withNaNCheckCode(
|
||||
dataType: DataType,
|
||||
valueName: String,
|
||||
hashSet: String,
|
||||
handleNotNaN: String,
|
||||
handleNaN: String => String): String = {
|
||||
val ret = dataType match {
|
||||
case DoubleType =>
|
||||
Some((s"java.lang.Double.isNaN((double)$valueName)", "java.lang.Double.NaN"))
|
||||
case FloatType =>
|
||||
Some((s"java.lang.Float.isNaN((float)$valueName)", "java.lang.Float.NaN"))
|
||||
case _ => None
|
||||
}
|
||||
ret.map { case (isNaN, valueNaN) =>
|
||||
s"""
|
||||
|if ($isNaN) {
|
||||
| if (!$hashSet.containsNaN()) {
|
||||
| $hashSet.addNaN();
|
||||
| ${handleNaN(valueNaN)}
|
||||
| }
|
||||
|} else {
|
||||
| $handleNotNaN
|
||||
|}
|
||||
""".stripMargin
|
||||
}.getOrElse(handleNotNaN)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2310,6 +2310,15 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
|
|||
Seq(Float.NaN, null, 1f))
|
||||
}
|
||||
|
||||
test("SPARK-36741: ArrayDistinct should handle duplicated Double.NaN and Float.Nan") {
|
||||
checkEvaluation(ArrayDistinct(
|
||||
Literal.create(Seq(Double.NaN, Double.NaN, null, null, 1d, 1d), ArrayType(DoubleType))),
|
||||
Seq(Double.NaN, null, 1d))
|
||||
checkEvaluation(ArrayDistinct(
|
||||
Literal.create(Seq(Float.NaN, Float.NaN, null, null, 1f, 1f), ArrayType(FloatType))),
|
||||
Seq(Float.NaN, null, 1f))
|
||||
}
|
||||
|
||||
test("SPARK-36755: ArraysOverlap hould handle duplicated Double.NaN and Float.Nan") {
|
||||
checkEvaluation(ArraysOverlap(
|
||||
Literal.apply(Array(Double.NaN, 1d)), Literal.apply(Array(Double.NaN))), true)
|
||||
|
|
Loading…
Reference in a new issue