[SPARK-9169][SQL] Improve unit test coverage for null expressions.
Author: Reynold Xin <rxin@databricks.com> Closes #7490 from rxin/unit-test-null-funcs and squashes the following commits: 7b276f0 [Reynold Xin] Move isNaN. 8307287 [Reynold Xin] [SPARK-9169][SQL] Improve unit test coverage for null expressions.
This commit is contained in:
parent
b9ef7ac98c
commit
fba3f5ba85
|
@ -21,8 +21,19 @@ import org.apache.spark.sql.catalyst.InternalRow
|
|||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
|
||||
import org.apache.spark.sql.catalyst.util.TypeUtils
|
||||
import org.apache.spark.sql.types.DataType
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
|
||||
/**
|
||||
* An expression that is evaluated to the first non-null input.
|
||||
*
|
||||
* {{{
|
||||
* coalesce(1, 2) => 1
|
||||
* coalesce(null, 1, 2) => 1
|
||||
* coalesce(null, null, 2) => 2
|
||||
* coalesce(null, null, null) => null
|
||||
* }}}
|
||||
*/
|
||||
case class Coalesce(children: Seq[Expression]) extends Expression {
|
||||
|
||||
/** Coalesce is nullable if all of its children are nullable, or if it has no children. */
|
||||
|
@ -70,6 +81,62 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
/**
 * A predicate that holds when its child evaluates to NaN or to null.
 *
 * `inputTypes` declares a type collection of [[DoubleType]] and [[FloatType]] for the child.
 * The predicate itself is declared non-nullable: a null child produces `true`, not null.
 */
case class IsNaN(child: Expression) extends UnaryExpression
  with Predicate with ImplicitCastInputTypes {

  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(DoubleType, FloatType))

  override def nullable: Boolean = false

  /**
   * Interpreted evaluation.
   *
   * @param input the row the child expression is evaluated against
   * @return `true` when the child value is null or NaN, `false` otherwise
   */
  override def eval(input: InternalRow): Any = {
    child.eval(input) match {
      case null => true
      case nonNull =>
        // Only the two floating point types are handled; anything else is a MatchError.
        child.dataType match {
          case DoubleType => nonNull.asInstanceOf[Double].isNaN
          case FloatType => nonNull.asInstanceOf[Float].isNaN
        }
    }
  }

  /**
   * Code-generated evaluation: emits the child's code followed by a boolean result that is
   * `true` for a null child and the outcome of the Java `isNaN` check otherwise.
   */
  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    val childCode = child.gen(ctx)
    // Name of the Java class whose static isNaN matches the child's type. The match is
    // deliberately limited to the same two cases as before, after the child code is generated.
    val javaClass = child.dataType match {
      case FloatType => "Float"
      case DoubleType => "Double"
    }
    s"""
      ${childCode.code}
      boolean ${ev.isNull} = false;
      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
      if (${childCode.isNull}) {
        ${ev.primitive} = true;
      } else {
        ${ev.primitive} = $javaClass.isNaN(${childCode.primitive});
      }
    """
  }
}
|
||||
|
||||
|
||||
/**
|
||||
* An expression that is evaluated to true if the input is null.
|
||||
*/
|
||||
case class IsNull(child: Expression) extends UnaryExpression with Predicate {
|
||||
override def nullable: Boolean = false
|
||||
|
||||
|
@ -83,13 +150,14 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate {
|
|||
ev.primitive = eval.isNull
|
||||
eval.code
|
||||
}
|
||||
|
||||
override def toString: String = s"IS NULL $child"
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* An expression that is evaluated to true if the input is not null.
|
||||
*/
|
||||
case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
|
||||
override def nullable: Boolean = false
|
||||
override def toString: String = s"IS NOT NULL $child"
|
||||
|
||||
override def eval(input: InternalRow): Any = {
|
||||
child.eval(input) != null
|
||||
|
@ -103,12 +171,13 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* A predicate that is evaluated to be true if there are at least `n` non-null values.
|
||||
* A predicate that is evaluated to be true if there are at least `n` non-null and non-NaN values.
|
||||
*/
|
||||
case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate {
|
||||
override def nullable: Boolean = false
|
||||
override def foldable: Boolean = false
|
||||
override def foldable: Boolean = children.forall(_.foldable)
|
||||
override def toString: String = s"AtLeastNNulls(n, ${children.mkString(",")})"
|
||||
|
||||
private[this] val childrenArray = children.toArray
|
||||
|
|
|
@ -18,7 +18,6 @@
|
|||
package org.apache.spark.sql.catalyst.expressions
|
||||
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
|
||||
import org.apache.spark.sql.catalyst.util.TypeUtils
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
|
||||
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
|
||||
|
@ -120,56 +119,6 @@ case class InSet(child: Expression, hset: Set[Any])
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Evaluates to `true` if it's NaN or null
|
||||
*/
|
||||
case class IsNaN(child: Expression) extends UnaryExpression
|
||||
with Predicate with ImplicitCastInputTypes {
|
||||
|
||||
// Accepted child types: a type collection of DoubleType and FloatType.
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(DoubleType, FloatType))
|
||||
|
||||
// The predicate never yields null: a null child is reported as `true` (see eval below).
override def nullable: Boolean = false
|
||||
|
||||
// Interpreted evaluation: returns true for a null child, otherwise the isNaN test
// that corresponds to the child's data type.
override def eval(input: InternalRow): Any = {
|
||||
val value = child.eval(input)
|
||||
if (value == null) {
|
||||
true
|
||||
} else {
|
||||
// Only DoubleType and FloatType are matched; there is no default case, so any
// other child type would fail with a MatchError.
child.dataType match {
|
||||
case DoubleType => value.asInstanceOf[Double].isNaN
|
||||
case FloatType => value.asInstanceOf[Float].isNaN
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Code-generated evaluation: emits the child's code, declares the result as not null,
// and sets the result to true for a null child or to Float/Double.isNaN otherwise.
// NOTE(review): `dataType` here is this expression's own type (presumably boolean,
// inherited from Predicate) -- confirm.
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
|
||||
val eval = child.gen(ctx)
|
||||
// The two branches differ only in which Java class supplies isNaN.
child.dataType match {
|
||||
case FloatType =>
|
||||
s"""
|
||||
${eval.code}
|
||||
boolean ${ev.isNull} = false;
|
||||
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
|
||||
if (${eval.isNull}) {
|
||||
${ev.primitive} = true;
|
||||
} else {
|
||||
${ev.primitive} = Float.isNaN(${eval.primitive});
|
||||
}
|
||||
"""
|
||||
case DoubleType =>
|
||||
s"""
|
||||
${eval.code}
|
||||
boolean ${ev.isNull} = false;
|
||||
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
|
||||
if (${eval.isNull}) {
|
||||
${ev.primitive} = true;
|
||||
} else {
|
||||
${ev.primitive} = Double.isNaN(${eval.primitive});
|
||||
}
|
||||
"""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate {
|
||||
|
||||
|
|
|
@ -18,48 +18,52 @@
|
|||
package org.apache.spark.sql.catalyst.expressions
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.catalyst.dsl.expressions._
|
||||
import org.apache.spark.sql.types.{BooleanType, StringType, ShortType}
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
/**
 * Unit tests for the null-related expressions `IsNull`, `IsNotNull`, `IsNaN` and `Coalesce`.
 */
class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {

  /**
   * Calls `testFunc` once for each listed data type, passing a representative non-null
   * value together with its type.
   *
   * @param testFunc assertion body to run for every (value, type) pair
   */
  def testAllTypes(testFunc: (Any, DataType) => Unit): Unit = {
    testFunc(false, BooleanType)
    testFunc(1.toByte, ByteType)
    testFunc(1.toShort, ShortType)
    testFunc(1, IntegerType)
    testFunc(1L, LongType)
    testFunc(1.0F, FloatType)
    testFunc(1.0, DoubleType)
    testFunc(Decimal(1.5), DecimalType.Unlimited)
    testFunc(new java.sql.Date(10), DateType)
    testFunc(new java.sql.Timestamp(10), TimestampType)
    testFunc("abcd", StringType)
  }

  test("isnull and isnotnull") {
    testAllTypes { (value: Any, tpe: DataType) =>
      checkEvaluation(IsNull(Literal.create(value, tpe)), false)
      checkEvaluation(IsNotNull(Literal.create(value, tpe)), true)
      checkEvaluation(IsNull(Literal.create(null, tpe)), true)
      checkEvaluation(IsNotNull(Literal.create(null, tpe)), false)
    }
  }

  test("IsNaN") {
    checkEvaluation(IsNaN(Literal(Double.NaN)), true)
    checkEvaluation(IsNaN(Literal(Float.NaN)), true)
    checkEvaluation(IsNaN(Literal(math.log(-3))), true)
    // A null input is reported as NaN.
    checkEvaluation(IsNaN(Literal.create(null, DoubleType)), true)
    checkEvaluation(IsNaN(Literal(Double.PositiveInfinity)), false)
    checkEvaluation(IsNaN(Literal(Float.MaxValue)), false)
    checkEvaluation(IsNaN(Literal(5.5f)), false)
  }

  test("coalesce") {
    testAllTypes { (value: Any, tpe: DataType) =>
      val lit = Literal.create(value, tpe)
      val nullLit = Literal.create(null, tpe)
      checkEvaluation(Coalesce(Seq(nullLit)), null)
      checkEvaluation(Coalesce(Seq(lit)), value)
      checkEvaluation(Coalesce(Seq(nullLit, lit)), value)
      checkEvaluation(Coalesce(Seq(nullLit, lit, lit)), value)
      checkEvaluation(Coalesce(Seq(nullLit, nullLit, lit)), value)
    }
  }
}
|
||||
|
|
|
@ -114,16 +114,10 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
checkEvaluation(
|
||||
And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1), Literal(2)))),
|
||||
true)
|
||||
}
|
||||
|
||||
test("IsNaN") {
|
||||
checkEvaluation(IsNaN(Literal(Double.NaN)), true)
|
||||
checkEvaluation(IsNaN(Literal(Float.NaN)), true)
|
||||
checkEvaluation(IsNaN(Literal(math.log(-3))), true)
|
||||
checkEvaluation(IsNaN(Literal.create(null, DoubleType)), true)
|
||||
checkEvaluation(IsNaN(Literal(Double.PositiveInfinity)), false)
|
||||
checkEvaluation(IsNaN(Literal(Float.MaxValue)), false)
|
||||
checkEvaluation(IsNaN(Literal(5.5f)), false)
|
||||
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"))), true)
|
||||
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true)
|
||||
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false)
|
||||
}
|
||||
|
||||
test("INSET") {
|
||||
|
|
Loading…
Reference in a new issue