[SPARK-36741][SQL] ArrayDistinct should handle duplicated Double.NaN and Float.NaN

### What changes were proposed in this pull request?
For the query
```
select array_distinct(array(cast('nan' as double), cast('nan' as double)))
```
this returns `[NaN, NaN]`, but it should return `[NaN]`.
The issue is that `OpenHashSet` cannot handle `Double.NaN` and `Float.NaN` either.
This PR fixes the problem based on https://github.com/apache/spark/pull/33955
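
For illustration, a minimal repro in a Spark shell (assuming a running `SparkSession` bound to `spark`):

```scala
// Two NaN literals in one array; array_distinct should collapse them into one.
val df = spark.sql(
  "select array_distinct(array(cast('nan' as double), cast('nan' as double))) as arr")

df.show(false)
// Before this fix: [NaN, NaN]
// After this fix:  [NaN]
```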

### Why are the changes needed?
Fixes a correctness bug in `array_distinct`.

### Does this PR introduce _any_ user-facing change?
Yes. `ArrayDistinct` no longer returns duplicated `NaN` values.

### How was this patch tested?
Added unit tests.

Closes #33993 from AngersZhuuuu/SPARK-36741.

Authored-by: Angerszhuuuu <angers.zhu@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit e356f6aa11)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
3 changed files with 121 additions and 66 deletions


@@ -3412,32 +3412,59 @@ case class ArrayDistinct(child: Expression)
}
override def nullSafeEval(array: Any): Any = {
val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType)
val data = array.asInstanceOf[ArrayData]
doEvaluation(data)
}
@transient private lazy val doEvaluation = if (TypeUtils.typeWithProperEquals(elementType)) {
(data: Array[AnyRef]) => new GenericArrayData(data.distinct.asInstanceOf[Array[Any]])
(array: ArrayData) =>
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
val hs = new SQLOpenHashSet[Any]()
val withNaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs,
(value: Any) =>
if (!hs.contains(value)) {
if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.size)
}
arrayBuffer += value
hs.add(value)
},
(valueNaN: Any) => arrayBuffer += valueNaN)
var i = 0
while (i < array.numElements()) {
if (array.isNullAt(i)) {
if (!hs.containsNull) {
hs.addNull
arrayBuffer += null
}
} else {
(data: Array[AnyRef]) => {
val elem = array.get(i, elementType)
withNaNCheckFunc(elem)
}
i += 1
}
new GenericArrayData(arrayBuffer.toSeq)
} else {
(data: ArrayData) => {
val array = data.toArray[AnyRef](elementType)
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[AnyRef]
var alreadyStoredNull = false
for (i <- 0 until data.length) {
if (data(i) != null) {
for (i <- 0 until array.length) {
if (array(i) != null) {
var found = false
var j = 0
while (!found && j < arrayBuffer.size) {
val va = arrayBuffer(j)
found = (va != null) && ordering.equiv(va, data(i))
found = (va != null) && ordering.equiv(va, array(i))
j += 1
}
if (!found) {
arrayBuffer += data(i)
arrayBuffer += array(i)
}
} else {
// De-duplicate the null values.
if (!alreadyStoredNull) {
arrayBuffer += data(i)
arrayBuffer += array(i)
alreadyStoredNull = true
}
}
@@ -3456,10 +3483,9 @@ case class ArrayDistinct(child: Expression)
val ptName = CodeGenerator.primitiveTypeName(jt)
nullSafeCodeGen(ctx, ev, (array) => {
val foundNullElement = ctx.freshName("foundNullElement")
val nullElementIndex = ctx.freshName("nullElementIndex")
val builder = ctx.freshName("builder")
val openHashSet = classOf[OpenHashSet[_]].getName
val openHashSet = classOf[SQLOpenHashSet[_]].getName
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
val hashSet = ctx.freshName("hashSet")
val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName
@@ -3468,7 +3494,6 @@ case class ArrayDistinct(child: Expression)
// Only need to track null element index when array's element is nullable.
val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) {
s"""
|boolean $foundNullElement = false;
|int $nullElementIndex = -1;
""".stripMargin
} else {
@@ -3479,9 +3504,9 @@ case class ArrayDistinct(child: Expression)
if (dataType.asInstanceOf[ArrayType].containsNull) {
s"""
|if ($array.isNullAt($i)) {
| if (!$foundNullElement) {
| if (!$hashSet.containsNull()) {
| $nullElementIndex = $size;
| $foundNullElement = true;
| $hashSet.addNull();
| $size++;
| $builder.$$plus$$eq($nullValueHolder);
| }
@@ -3493,9 +3518,8 @@ case class ArrayDistinct(child: Expression)
body
}
val processArray = withArrayNullAssignment(
val body =
s"""
|$jt $value = ${genGetValue(array, i)};
|if (!$hashSet.contains($hsValueCast$value)) {
| if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
| break;
@@ -3503,7 +3527,16 @@ case class ArrayDistinct(child: Expression)
| $hashSet.add$hsPostFix($hsValueCast$value);
| $builder.$$plus$$eq($value);
|}
""".stripMargin)
""".stripMargin
val processArray = withArrayNullAssignment(
s"$jt $value = ${genGetValue(array, i)};" +
SQLOpenHashSet.withNaNCheckCode(elementType, value, hashSet, body,
(valueNaN: String) =>
s"""
|$size++;
|$builder.$$plus$$eq($valueNaN);
|""".stripMargin))
s"""
|$openHashSet $hashSet = new $openHashSet$hsPostFix($classTag);
@@ -3579,8 +3612,16 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
(array1, array2) =>
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
val hs = new SQLOpenHashSet[Any]()
val isNaN = SQLOpenHashSet.isNaN(elementType)
val valueNaN = SQLOpenHashSet.valueNaN(elementType)
val withNaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs,
(value: Any) =>
if (!hs.contains(value)) {
if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.size)
}
arrayBuffer += value
hs.add(value)
},
(valueNaN: Any) => arrayBuffer += valueNaN)
Seq(array1, array2).foreach { array =>
var i = 0
while (i < array.numElements()) {
@@ -3591,20 +3632,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
}
} else {
val elem = array.get(i, elementType)
if (isNaN(elem)) {
if (!hs.containsNaN) {
arrayBuffer += valueNaN
hs.addNaN
}
} else {
if (!hs.contains(elem)) {
if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.size)
}
arrayBuffer += elem
hs.add(elem)
}
}
withNaNCheckFunc(elem)
}
i += 1
}
@@ -3689,28 +3717,6 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
body
}
def withNaNCheck(body: String): String = {
(elementType match {
case DoubleType =>
Some((s"java.lang.Double.isNaN((double)$value)", "java.lang.Double.NaN"))
case FloatType =>
Some((s"java.lang.Float.isNaN((float)$value)", "java.lang.Float.NaN"))
case _ => None
}).map { case (isNaN, valueNaN) =>
s"""
|if ($isNaN) {
| if (!$hashSet.containsNaN()) {
| $size++;
| $hashSet.addNaN();
| $builder.$$plus$$eq($valueNaN);
| }
|} else {
| $body
|}
""".stripMargin
}
}.getOrElse(body)
val body =
s"""
|if (!$hashSet.contains($hsValueCast$value)) {
@@ -3721,8 +3727,14 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
| $builder.$$plus$$eq($value);
|}
""".stripMargin
val processArray =
withArrayNullAssignment(s"$jt $value = ${genGetValue(array, i)};" + withNaNCheck(body))
val processArray = withArrayNullAssignment(
s"$jt $value = ${genGetValue(array, i)};" +
SQLOpenHashSet.withNaNCheckCode(elementType, value, hashSet, body,
(valueNaN: String) =>
s"""
|$size++;
|$builder.$$plus$$eq($valueNaN);
""".stripMargin))
// Only need to track null element index when result array's element is nullable.
val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) {
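
Both interpreted paths above now funnel every non-null element through the `withNaNCheckFunc` callback instead of probing the hash set directly. Below is a minimal self-contained sketch of that wrapper pattern, using a plain `mutable.HashSet` plus an explicit seen-NaN flag as a hypothetical stand-in for `SQLOpenHashSet`:

```scala
import scala.collection.mutable

object NaNAwareDistinctSketch {
  // Dedup with NaN diverted around the hash set: under primitive IEEE-754
  // comparison NaN never equals itself, so a specialized hash set would keep
  // accepting "new" NaN values; a boolean flag tracks NaN instead.
  def distinct(values: Seq[Double]): Seq[Double] = {
    val seen = mutable.HashSet.empty[Double]
    val out = mutable.ArrayBuffer.empty[Double]
    var seenNaN = false

    // Mirrors the shape of withNaNCheckFunc: one handler for ordinary values,
    // one handler that fires only for the first NaN encountered.
    val insert: Double => Unit = { value =>
      if (value.isNaN) {
        if (!seenNaN) { seenNaN = true; out += Double.NaN }
      } else if (!seen.contains(value)) {
        seen += value
        out += value
      }
    }

    values.foreach(insert)
    out.toSeq
  }

  def main(args: Array[String]): Unit =
    println(distinct(Seq(Double.NaN, Double.NaN, 1.0, 1.0))) // NaN and 1.0 once each
}
```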


@@ -60,21 +60,55 @@ class SQLOpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag](
}
object SQLOpenHashSet {
def isNaN(dataType: DataType): Any => Boolean = {
dataType match {
def withNaNCheckFunc(
dataType: DataType,
hashSet: SQLOpenHashSet[Any],
handleNotNaN: Any => Unit,
handleNaN: Any => Unit): Any => Unit = {
val (isNaN, valueNaN) = dataType match {
case DoubleType =>
(value: Any) => java.lang.Double.isNaN(value.asInstanceOf[java.lang.Double])
((value: Any) => java.lang.Double.isNaN(value.asInstanceOf[java.lang.Double]),
java.lang.Double.NaN)
case FloatType =>
(value: Any) => java.lang.Float.isNaN(value.asInstanceOf[java.lang.Float])
case _ => (_: Any) => false
((value: Any) => java.lang.Float.isNaN(value.asInstanceOf[java.lang.Float]),
java.lang.Float.NaN)
case _ => ((_: Any) => false, null)
}
(value: Any) =>
if (isNaN(value)) {
if (!hashSet.containsNaN) {
hashSet.addNaN
handleNaN(valueNaN)
}
} else {
handleNotNaN(value)
}
}
def valueNaN(dataType: DataType): Any = {
dataType match {
case DoubleType => java.lang.Double.NaN
case FloatType => java.lang.Float.NaN
case _ => null
def withNaNCheckCode(
dataType: DataType,
valueName: String,
hashSet: String,
handleNotNaN: String,
handleNaN: String => String): String = {
val ret = dataType match {
case DoubleType =>
Some((s"java.lang.Double.isNaN((double)$valueName)", "java.lang.Double.NaN"))
case FloatType =>
Some((s"java.lang.Float.isNaN((float)$valueName)", "java.lang.Float.NaN"))
case _ => None
}
ret.map { case (isNaN, valueNaN) =>
s"""
|if ($isNaN) {
| if (!$hashSet.containsNaN()) {
| $hashSet.addNaN();
| ${handleNaN(valueNaN)}
| }
|} else {
| $handleNotNaN
|}
""".stripMargin
}.getOrElse(handleNotNaN)
}
}
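
The root cause that `withNaNCheckFunc` and `withNaNCheckCode` work around is IEEE-754 comparison: for the specialized primitive types, probing an open hash set is essentially a primitive `==` check, and NaN is never `==` to anything, itself included. A quick check of the relevant JVM semantics in the Scala REPL:

```scala
val nan = Double.NaN

println(nan == nan)                                // false: primitive IEEE-754 comparison
println(java.lang.Double.valueOf(nan).equals(nan)) // true: boxed equals treats NaN as equal to NaN
println(java.lang.Double.doubleToLongBits(nan) ==
  java.lang.Double.doubleToLongBits(Double.NaN))   // true: doubleToLongBits canonicalizes NaN
```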


@@ -2310,6 +2310,15 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
Seq(Float.NaN, null, 1f))
}
test("SPARK-36741: ArrayDistinct should handle duplicated Double.NaN and Float.Nan") {
checkEvaluation(ArrayDistinct(
Literal.create(Seq(Double.NaN, Double.NaN, null, null, 1d, 1d), ArrayType(DoubleType))),
Seq(Double.NaN, null, 1d))
checkEvaluation(ArrayDistinct(
Literal.create(Seq(Float.NaN, Float.NaN, null, null, 1f, 1f), ArrayType(FloatType))),
Seq(Float.NaN, null, 1f))
}
test("SPARK-36755: ArraysOverlap hould handle duplicated Double.NaN and Float.Nan") {
checkEvaluation(ArraysOverlap(
Literal.apply(Array(Double.NaN, 1d)), Literal.apply(Array(Double.NaN))), true)