[SPARK-36792][SQL] InSet should handle NaN
### What changes were proposed in this pull request?
InSet should handle NaN
```
InSet(Literal(Double.NaN), Set(Double.NaN, 1d)) should return true, but return false.
```
### Why are the changes needed?
InSet should handle NaN
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added UT
Closes #34033 from AngersZhuuuu/SPARK-36792.
Authored-by: Angerszhuuuu <angers.zhu@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit 64f4bf47af
)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
d0c97d6ed9
commit
b7174188e5
|
@ -554,6 +554,16 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
|
||||||
}
|
}
|
||||||
|
|
||||||
@transient private[this] lazy val hasNull: Boolean = hset.contains(null)
|
@transient private[this] lazy val hasNull: Boolean = hset.contains(null)
|
||||||
|
@transient private[this] lazy val isNaN: Any => Boolean = child.dataType match {
|
||||||
|
case DoubleType => (value: Any) => java.lang.Double.isNaN(value.asInstanceOf[java.lang.Double])
|
||||||
|
case FloatType => (value: Any) => java.lang.Float.isNaN(value.asInstanceOf[java.lang.Float])
|
||||||
|
case _ => (_: Any) => false
|
||||||
|
}
|
||||||
|
@transient private[this] lazy val hasNaN = child.dataType match {
|
||||||
|
case DoubleType | FloatType => set.exists(isNaN)
|
||||||
|
case _ => false
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
override def nullable: Boolean = child.nullable || hasNull
|
override def nullable: Boolean = child.nullable || hasNull
|
||||||
|
|
||||||
|
@ -562,6 +572,8 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
|
||||||
protected override def nullSafeEval(value: Any): Any = {
|
protected override def nullSafeEval(value: Any): Any = {
|
||||||
if (set.contains(value)) {
|
if (set.contains(value)) {
|
||||||
true
|
true
|
||||||
|
} else if (isNaN(value)) {
|
||||||
|
hasNaN
|
||||||
} else if (hasNull) {
|
} else if (hasNull) {
|
||||||
null
|
null
|
||||||
} else {
|
} else {
|
||||||
|
@ -593,15 +605,33 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
|
||||||
private def genCodeWithSet(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
private def genCodeWithSet(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||||
nullSafeCodeGen(ctx, ev, c => {
|
nullSafeCodeGen(ctx, ev, c => {
|
||||||
val setTerm = ctx.addReferenceObj("set", set)
|
val setTerm = ctx.addReferenceObj("set", set)
|
||||||
|
|
||||||
val setIsNull = if (hasNull) {
|
val setIsNull = if (hasNull) {
|
||||||
s"${ev.isNull} = !${ev.value};"
|
s"${ev.isNull} = !${ev.value};"
|
||||||
} else {
|
} else {
|
||||||
""
|
""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
val ret = child.dataType match {
|
||||||
|
case DoubleType => Some((v: Any) => s"java.lang.Double.isNaN($v)")
|
||||||
|
case FloatType => Some((v: Any) => s"java.lang.Float.isNaN($v)")
|
||||||
|
case _ => None
|
||||||
|
}
|
||||||
|
|
||||||
|
ret.map { isNaN =>
|
||||||
|
s"""
|
||||||
|
|if ($setTerm.contains($c)) {
|
||||||
|
| ${ev.value} = true;
|
||||||
|
|} else if (${isNaN(c)}) {
|
||||||
|
| ${ev.value} = $hasNaN;
|
||||||
|
|}
|
||||||
|
|$setIsNull
|
||||||
|
|""".stripMargin
|
||||||
|
}.getOrElse(
|
||||||
s"""
|
s"""
|
||||||
|${ev.value} = $setTerm.contains($c);
|
|${ev.value} = $setTerm.contains($c);
|
||||||
|$setIsNull
|
|$setIsNull
|
||||||
""".stripMargin
|
""".stripMargin)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -644,4 +644,18 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
|
||||||
checkExpr(GreaterThan, Double.NaN, Double.NaN, false)
|
checkExpr(GreaterThan, Double.NaN, Double.NaN, false)
|
||||||
checkExpr(GreaterThan, 0.0, -0.0, false)
|
checkExpr(GreaterThan, 0.0, -0.0, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("SPARK-36792: InSet should handle Double.NaN and Float.NaN") {
|
||||||
|
checkInAndInSet(In(Literal(Double.NaN), Seq(Literal(Double.NaN), Literal(2d))), true)
|
||||||
|
checkInAndInSet(In(Literal.create(null, DoubleType),
|
||||||
|
Seq(Literal(Double.NaN), Literal(2d), Literal.create(null, DoubleType))), null)
|
||||||
|
checkInAndInSet(In(Literal.create(null, DoubleType),
|
||||||
|
Seq(Literal(Double.NaN), Literal(2d))), null)
|
||||||
|
checkInAndInSet(In(Literal(3d),
|
||||||
|
Seq(Literal(Double.NaN), Literal(2d))), false)
|
||||||
|
checkInAndInSet(In(Literal(3d),
|
||||||
|
Seq(Literal(Double.NaN), Literal(2d), Literal.create(null, DoubleType))), null)
|
||||||
|
checkInAndInSet(In(Literal(Double.NaN),
|
||||||
|
Seq(Literal(Double.NaN), Literal(2d), Literal.create(null, DoubleType))), true)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue