[SPARK-36741][SQL] ArrayDistinct handle duplicated Double.NaN and Float.NaN
### What changes were proposed in this pull request?
For query
```
select array_distinct(array(cast('nan' as double), cast('nan' as double)))
```
This returns [NaN, NaN], but it should return [NaN].
This issue is caused by the fact that `OpenHashSet` cannot handle `Double.NaN` and `Float.NaN` properly either.
This PR fixes the issue based on https://github.com/apache/spark/pull/33955
### Why are the changes needed?
Fix bug
### Does this PR introduce _any_ user-facing change?
ArrayDistinct won't show duplicated `NaN` values
### How was this patch tested?
Added UT
Closes #33993 from AngersZhuuuu/SPARK-36741.
Authored-by: Angerszhuuuu <angers.zhu@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit e356f6aa11
)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
16215755b7
commit
a78c4c44ed
|
@ -3412,32 +3412,59 @@ case class ArrayDistinct(child: Expression)
|
||||||
}
|
}
|
||||||
|
|
||||||
override def nullSafeEval(array: Any): Any = {
|
override def nullSafeEval(array: Any): Any = {
|
||||||
val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType)
|
val data = array.asInstanceOf[ArrayData]
|
||||||
doEvaluation(data)
|
doEvaluation(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
@transient private lazy val doEvaluation = if (TypeUtils.typeWithProperEquals(elementType)) {
|
@transient private lazy val doEvaluation = if (TypeUtils.typeWithProperEquals(elementType)) {
|
||||||
(data: Array[AnyRef]) => new GenericArrayData(data.distinct.asInstanceOf[Array[Any]])
|
(array: ArrayData) =>
|
||||||
|
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
|
||||||
|
val hs = new SQLOpenHashSet[Any]()
|
||||||
|
val withNaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs,
|
||||||
|
(value: Any) =>
|
||||||
|
if (!hs.contains(value)) {
|
||||||
|
if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
|
||||||
|
ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.size)
|
||||||
|
}
|
||||||
|
arrayBuffer += value
|
||||||
|
hs.add(value)
|
||||||
|
},
|
||||||
|
(valueNaN: Any) => arrayBuffer += valueNaN)
|
||||||
|
var i = 0
|
||||||
|
while (i < array.numElements()) {
|
||||||
|
if (array.isNullAt(i)) {
|
||||||
|
if (!hs.containsNull) {
|
||||||
|
hs.addNull
|
||||||
|
arrayBuffer += null
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
val elem = array.get(i, elementType)
|
||||||
|
withNaNCheckFunc(elem)
|
||||||
|
}
|
||||||
|
i += 1
|
||||||
|
}
|
||||||
|
new GenericArrayData(arrayBuffer.toSeq)
|
||||||
} else {
|
} else {
|
||||||
(data: Array[AnyRef]) => {
|
(data: ArrayData) => {
|
||||||
|
val array = data.toArray[AnyRef](elementType)
|
||||||
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[AnyRef]
|
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[AnyRef]
|
||||||
var alreadyStoredNull = false
|
var alreadyStoredNull = false
|
||||||
for (i <- 0 until data.length) {
|
for (i <- 0 until array.length) {
|
||||||
if (data(i) != null) {
|
if (array(i) != null) {
|
||||||
var found = false
|
var found = false
|
||||||
var j = 0
|
var j = 0
|
||||||
while (!found && j < arrayBuffer.size) {
|
while (!found && j < arrayBuffer.size) {
|
||||||
val va = arrayBuffer(j)
|
val va = arrayBuffer(j)
|
||||||
found = (va != null) && ordering.equiv(va, data(i))
|
found = (va != null) && ordering.equiv(va, array(i))
|
||||||
j += 1
|
j += 1
|
||||||
}
|
}
|
||||||
if (!found) {
|
if (!found) {
|
||||||
arrayBuffer += data(i)
|
arrayBuffer += array(i)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// De-duplicate the null values.
|
// De-duplicate the null values.
|
||||||
if (!alreadyStoredNull) {
|
if (!alreadyStoredNull) {
|
||||||
arrayBuffer += data(i)
|
arrayBuffer += array(i)
|
||||||
alreadyStoredNull = true
|
alreadyStoredNull = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -3456,10 +3483,9 @@ case class ArrayDistinct(child: Expression)
|
||||||
val ptName = CodeGenerator.primitiveTypeName(jt)
|
val ptName = CodeGenerator.primitiveTypeName(jt)
|
||||||
|
|
||||||
nullSafeCodeGen(ctx, ev, (array) => {
|
nullSafeCodeGen(ctx, ev, (array) => {
|
||||||
val foundNullElement = ctx.freshName("foundNullElement")
|
|
||||||
val nullElementIndex = ctx.freshName("nullElementIndex")
|
val nullElementIndex = ctx.freshName("nullElementIndex")
|
||||||
val builder = ctx.freshName("builder")
|
val builder = ctx.freshName("builder")
|
||||||
val openHashSet = classOf[OpenHashSet[_]].getName
|
val openHashSet = classOf[SQLOpenHashSet[_]].getName
|
||||||
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
|
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
|
||||||
val hashSet = ctx.freshName("hashSet")
|
val hashSet = ctx.freshName("hashSet")
|
||||||
val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName
|
val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName
|
||||||
|
@ -3468,7 +3494,6 @@ case class ArrayDistinct(child: Expression)
|
||||||
// Only need to track null element index when array's element is nullable.
|
// Only need to track null element index when array's element is nullable.
|
||||||
val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) {
|
val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) {
|
||||||
s"""
|
s"""
|
||||||
|boolean $foundNullElement = false;
|
|
||||||
|int $nullElementIndex = -1;
|
|int $nullElementIndex = -1;
|
||||||
""".stripMargin
|
""".stripMargin
|
||||||
} else {
|
} else {
|
||||||
|
@ -3479,9 +3504,9 @@ case class ArrayDistinct(child: Expression)
|
||||||
if (dataType.asInstanceOf[ArrayType].containsNull) {
|
if (dataType.asInstanceOf[ArrayType].containsNull) {
|
||||||
s"""
|
s"""
|
||||||
|if ($array.isNullAt($i)) {
|
|if ($array.isNullAt($i)) {
|
||||||
| if (!$foundNullElement) {
|
| if (!$hashSet.containsNull()) {
|
||||||
| $nullElementIndex = $size;
|
| $nullElementIndex = $size;
|
||||||
| $foundNullElement = true;
|
| $hashSet.addNull();
|
||||||
| $size++;
|
| $size++;
|
||||||
| $builder.$$plus$$eq($nullValueHolder);
|
| $builder.$$plus$$eq($nullValueHolder);
|
||||||
| }
|
| }
|
||||||
|
@ -3493,9 +3518,8 @@ case class ArrayDistinct(child: Expression)
|
||||||
body
|
body
|
||||||
}
|
}
|
||||||
|
|
||||||
val processArray = withArrayNullAssignment(
|
val body =
|
||||||
s"""
|
s"""
|
||||||
|$jt $value = ${genGetValue(array, i)};
|
|
||||||
|if (!$hashSet.contains($hsValueCast$value)) {
|
|if (!$hashSet.contains($hsValueCast$value)) {
|
||||||
| if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
|
| if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
|
||||||
| break;
|
| break;
|
||||||
|
@ -3503,7 +3527,16 @@ case class ArrayDistinct(child: Expression)
|
||||||
| $hashSet.add$hsPostFix($hsValueCast$value);
|
| $hashSet.add$hsPostFix($hsValueCast$value);
|
||||||
| $builder.$$plus$$eq($value);
|
| $builder.$$plus$$eq($value);
|
||||||
|}
|
|}
|
||||||
""".stripMargin)
|
""".stripMargin
|
||||||
|
|
||||||
|
val processArray = withArrayNullAssignment(
|
||||||
|
s"$jt $value = ${genGetValue(array, i)};" +
|
||||||
|
SQLOpenHashSet.withNaNCheckCode(elementType, value, hashSet, body,
|
||||||
|
(valueNaN: String) =>
|
||||||
|
s"""
|
||||||
|
|$size++;
|
||||||
|
|$builder.$$plus$$eq($valueNaN);
|
||||||
|
|""".stripMargin))
|
||||||
|
|
||||||
s"""
|
s"""
|
||||||
|$openHashSet $hashSet = new $openHashSet$hsPostFix($classTag);
|
|$openHashSet $hashSet = new $openHashSet$hsPostFix($classTag);
|
||||||
|
@ -3579,8 +3612,16 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
|
||||||
(array1, array2) =>
|
(array1, array2) =>
|
||||||
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
|
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
|
||||||
val hs = new SQLOpenHashSet[Any]()
|
val hs = new SQLOpenHashSet[Any]()
|
||||||
val isNaN = SQLOpenHashSet.isNaN(elementType)
|
val withNaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs,
|
||||||
val valueNaN = SQLOpenHashSet.valueNaN(elementType)
|
(value: Any) =>
|
||||||
|
if (!hs.contains(value)) {
|
||||||
|
if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
|
||||||
|
ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.size)
|
||||||
|
}
|
||||||
|
arrayBuffer += value
|
||||||
|
hs.add(value)
|
||||||
|
},
|
||||||
|
(valueNaN: Any) => arrayBuffer += valueNaN)
|
||||||
Seq(array1, array2).foreach { array =>
|
Seq(array1, array2).foreach { array =>
|
||||||
var i = 0
|
var i = 0
|
||||||
while (i < array.numElements()) {
|
while (i < array.numElements()) {
|
||||||
|
@ -3591,20 +3632,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
val elem = array.get(i, elementType)
|
val elem = array.get(i, elementType)
|
||||||
if (isNaN(elem)) {
|
withNaNCheckFunc(elem)
|
||||||
if (!hs.containsNaN) {
|
|
||||||
arrayBuffer += valueNaN
|
|
||||||
hs.addNaN
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if (!hs.contains(elem)) {
|
|
||||||
if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
|
|
||||||
ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.size)
|
|
||||||
}
|
|
||||||
arrayBuffer += elem
|
|
||||||
hs.add(elem)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
i += 1
|
i += 1
|
||||||
}
|
}
|
||||||
|
@ -3689,28 +3717,6 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
|
||||||
body
|
body
|
||||||
}
|
}
|
||||||
|
|
||||||
def withNaNCheck(body: String): String = {
|
|
||||||
(elementType match {
|
|
||||||
case DoubleType =>
|
|
||||||
Some((s"java.lang.Double.isNaN((double)$value)", "java.lang.Double.NaN"))
|
|
||||||
case FloatType =>
|
|
||||||
Some((s"java.lang.Float.isNaN((float)$value)", "java.lang.Float.NaN"))
|
|
||||||
case _ => None
|
|
||||||
}).map { case (isNaN, valueNaN) =>
|
|
||||||
s"""
|
|
||||||
|if ($isNaN) {
|
|
||||||
| if (!$hashSet.containsNaN()) {
|
|
||||||
| $size++;
|
|
||||||
| $hashSet.addNaN();
|
|
||||||
| $builder.$$plus$$eq($valueNaN);
|
|
||||||
| }
|
|
||||||
|} else {
|
|
||||||
| $body
|
|
||||||
|}
|
|
||||||
""".stripMargin
|
|
||||||
}
|
|
||||||
}.getOrElse(body)
|
|
||||||
|
|
||||||
val body =
|
val body =
|
||||||
s"""
|
s"""
|
||||||
|if (!$hashSet.contains($hsValueCast$value)) {
|
|if (!$hashSet.contains($hsValueCast$value)) {
|
||||||
|
@ -3721,8 +3727,14 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
|
||||||
| $builder.$$plus$$eq($value);
|
| $builder.$$plus$$eq($value);
|
||||||
|}
|
|}
|
||||||
""".stripMargin
|
""".stripMargin
|
||||||
val processArray =
|
val processArray = withArrayNullAssignment(
|
||||||
withArrayNullAssignment(s"$jt $value = ${genGetValue(array, i)};" + withNaNCheck(body))
|
s"$jt $value = ${genGetValue(array, i)};" +
|
||||||
|
SQLOpenHashSet.withNaNCheckCode(elementType, value, hashSet, body,
|
||||||
|
(valueNaN: String) =>
|
||||||
|
s"""
|
||||||
|
|$size++;
|
||||||
|
|$builder.$$plus$$eq($valueNaN);
|
||||||
|
""".stripMargin))
|
||||||
|
|
||||||
// Only need to track null element index when result array's element is nullable.
|
// Only need to track null element index when result array's element is nullable.
|
||||||
val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) {
|
val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) {
|
||||||
|
|
|
@ -60,21 +60,55 @@ class SQLOpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag](
|
||||||
}
|
}
|
||||||
|
|
||||||
object SQLOpenHashSet {
|
object SQLOpenHashSet {
|
||||||
def isNaN(dataType: DataType): Any => Boolean = {
|
def withNaNCheckFunc(
|
||||||
dataType match {
|
dataType: DataType,
|
||||||
|
hashSet: SQLOpenHashSet[Any],
|
||||||
|
handleNotNaN: Any => Unit,
|
||||||
|
handleNaN: Any => Unit): Any => Unit = {
|
||||||
|
val (isNaN, valueNaN) = dataType match {
|
||||||
case DoubleType =>
|
case DoubleType =>
|
||||||
(value: Any) => java.lang.Double.isNaN(value.asInstanceOf[java.lang.Double])
|
((value: Any) => java.lang.Double.isNaN(value.asInstanceOf[java.lang.Double]),
|
||||||
|
java.lang.Double.NaN)
|
||||||
case FloatType =>
|
case FloatType =>
|
||||||
(value: Any) => java.lang.Float.isNaN(value.asInstanceOf[java.lang.Float])
|
((value: Any) => java.lang.Float.isNaN(value.asInstanceOf[java.lang.Float]),
|
||||||
case _ => (_: Any) => false
|
java.lang.Float.NaN)
|
||||||
|
case _ => ((_: Any) => false, null)
|
||||||
}
|
}
|
||||||
|
(value: Any) =>
|
||||||
|
if (isNaN(value)) {
|
||||||
|
if (!hashSet.containsNaN) {
|
||||||
|
hashSet.addNaN
|
||||||
|
handleNaN(valueNaN)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
handleNotNaN(value)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
def valueNaN(dataType: DataType): Any = {
|
def withNaNCheckCode(
|
||||||
dataType match {
|
dataType: DataType,
|
||||||
case DoubleType => java.lang.Double.NaN
|
valueName: String,
|
||||||
case FloatType => java.lang.Float.NaN
|
hashSet: String,
|
||||||
case _ => null
|
handleNotNaN: String,
|
||||||
|
handleNaN: String => String): String = {
|
||||||
|
val ret = dataType match {
|
||||||
|
case DoubleType =>
|
||||||
|
Some((s"java.lang.Double.isNaN((double)$valueName)", "java.lang.Double.NaN"))
|
||||||
|
case FloatType =>
|
||||||
|
Some((s"java.lang.Float.isNaN((float)$valueName)", "java.lang.Float.NaN"))
|
||||||
|
case _ => None
|
||||||
}
|
}
|
||||||
|
ret.map { case (isNaN, valueNaN) =>
|
||||||
|
s"""
|
||||||
|
|if ($isNaN) {
|
||||||
|
| if (!$hashSet.containsNaN()) {
|
||||||
|
| $hashSet.addNaN();
|
||||||
|
| ${handleNaN(valueNaN)}
|
||||||
|
| }
|
||||||
|
|} else {
|
||||||
|
| $handleNotNaN
|
||||||
|
|}
|
||||||
|
""".stripMargin
|
||||||
|
}.getOrElse(handleNotNaN)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2310,6 +2310,15 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
|
||||||
Seq(Float.NaN, null, 1f))
|
Seq(Float.NaN, null, 1f))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("SPARK-36741: ArrayDistinct should handle duplicated Double.NaN and Float.Nan") {
|
||||||
|
checkEvaluation(ArrayDistinct(
|
||||||
|
Literal.create(Seq(Double.NaN, Double.NaN, null, null, 1d, 1d), ArrayType(DoubleType))),
|
||||||
|
Seq(Double.NaN, null, 1d))
|
||||||
|
checkEvaluation(ArrayDistinct(
|
||||||
|
Literal.create(Seq(Float.NaN, Float.NaN, null, null, 1f, 1f), ArrayType(FloatType))),
|
||||||
|
Seq(Float.NaN, null, 1f))
|
||||||
|
}
|
||||||
|
|
||||||
test("SPARK-36755: ArraysOverlap hould handle duplicated Double.NaN and Float.Nan") {
|
test("SPARK-36755: ArraysOverlap hould handle duplicated Double.NaN and Float.Nan") {
|
||||||
checkEvaluation(ArraysOverlap(
|
checkEvaluation(ArraysOverlap(
|
||||||
Literal.apply(Array(Double.NaN, 1d)), Literal.apply(Array(Double.NaN))), true)
|
Literal.apply(Array(Double.NaN, 1d)), Literal.apply(Array(Double.NaN))), true)
|
||||||
|
|
Loading…
Reference in a new issue