[SPARK-27924][SQL][FOLLOW-UP] Improve ANSI SQL Boolean-Predicate

### What changes were proposed in this pull request?
This PR follows https://github.com/apache/spark/pull/25074 and improves the implement.

### Why are the changes needed?
Improve code.

### Does this PR introduce any user-facing change?
No

### How was this patch tested?
Exists UT

Closes #27699 from beliefer/improve-boolean-test.

Authored-by: beliefer <beliefer@163.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
beliefer 2020-02-27 13:42:02 +08:00 committed by Wenchen Fan
parent 2b744fe885
commit 1515d45b8d
3 changed files with 6 additions and 94 deletions

View file

@ -927,66 +927,6 @@ case class GreaterThanOrEqual(left: Expression, right: Expression)
protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gteq(input1, input2)
}
trait BooleanTest extends UnaryExpression with Predicate with ExpectsInputTypes {
def boolValueForComparison: Boolean
def boolValueWhenNull: Boolean
override def nullable: Boolean = false
override def inputTypes: Seq[DataType] = Seq(BooleanType)
override def eval(input: InternalRow): Any = {
val value = child.eval(input)
Option(value) match {
case None => boolValueWhenNull
case other => if (boolValueWhenNull) {
value == !boolValueForComparison
} else {
value == boolValueForComparison
}
}
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val eval = child.genCode(ctx)
ev.copy(code = code"""
${eval.code}
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (${eval.isNull}) {
${ev.value} = $boolValueWhenNull;
} else if ($boolValueWhenNull) {
${ev.value} = ${eval.value} == !$boolValueForComparison;
} else {
${ev.value} = ${eval.value} == $boolValueForComparison;
}
""", isNull = FalseLiteral)
}
}
case class IsTrue(child: Expression) extends BooleanTest {
override def boolValueForComparison: Boolean = true
override def boolValueWhenNull: Boolean = false
override def sql: String = s"(${child.sql} IS TRUE)"
}
case class IsNotTrue(child: Expression) extends BooleanTest {
override def boolValueForComparison: Boolean = true
override def boolValueWhenNull: Boolean = true
override def sql: String = s"(${child.sql} IS NOT TRUE)"
}
case class IsFalse(child: Expression) extends BooleanTest {
override def boolValueForComparison: Boolean = false
override def boolValueWhenNull: Boolean = false
override def sql: String = s"(${child.sql} IS FALSE)"
}
case class IsNotFalse(child: Expression) extends BooleanTest {
override def boolValueForComparison: Boolean = false
override def boolValueWhenNull: Boolean = true
override def sql: String = s"(${child.sql} IS NOT FALSE)"
}
/**
* IS UNKNOWN and IS NOT UNKNOWN are the same as IS NULL and IS NOT NULL, respectively,
* except that the input expression must be of a boolean type.

View file

@ -1414,12 +1414,12 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
case SqlBaseParser.NULL =>
IsNull(e)
case SqlBaseParser.TRUE => ctx.NOT match {
case null => IsTrue(e)
case _ => IsNotTrue(e)
case null => EqualNullSafe(e, Literal(true))
case _ => Not(EqualNullSafe(e, Literal(true)))
}
case SqlBaseParser.FALSE => ctx.NOT match {
case null => IsFalse(e)
case _ => IsNotFalse(e)
case null => EqualNullSafe(e, Literal(false))
case _ => Not(EqualNullSafe(e, Literal(false)))
}
case SqlBaseParser.UNKNOWN => ctx.NOT match {
case null => IsUnknown(e)

View file

@ -522,37 +522,9 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(expression == expected)
}
val row0 = create_row(null)
val row1 = create_row(false)
val row2 = create_row(true)
test("istrue and isnottrue") {
checkEvaluation(IsTrue(Literal.create(null, BooleanType)), false, row0)
checkEvaluation(IsNotTrue(Literal.create(null, BooleanType)), true, row0)
checkEvaluation(IsTrue(Literal.create(false, BooleanType)), false, row1)
checkEvaluation(IsNotTrue(Literal.create(false, BooleanType)), true, row1)
checkEvaluation(IsTrue(Literal.create(true, BooleanType)), true, row2)
checkEvaluation(IsNotTrue(Literal.create(true, BooleanType)), false, row2)
IsTrue(Literal.create(null, IntegerType)).checkInputDataTypes() match {
case TypeCheckResult.TypeCheckFailure(msg) =>
assert(msg.contains("argument 1 requires boolean type"))
}
}
test("isfalse and isnotfalse") {
checkEvaluation(IsFalse(Literal.create(null, BooleanType)), false, row0)
checkEvaluation(IsNotFalse(Literal.create(null, BooleanType)), true, row0)
checkEvaluation(IsFalse(Literal.create(false, BooleanType)), true, row1)
checkEvaluation(IsNotFalse(Literal.create(false, BooleanType)), false, row1)
checkEvaluation(IsFalse(Literal.create(true, BooleanType)), false, row2)
checkEvaluation(IsNotFalse(Literal.create(true, BooleanType)), true, row2)
IsFalse(Literal.create(null, IntegerType)).checkInputDataTypes() match {
case TypeCheckResult.TypeCheckFailure(msg) =>
assert(msg.contains("argument 1 requires boolean type"))
}
}
test("isunknown and isnotunknown") {
val row0 = create_row(null)
checkEvaluation(IsUnknown(Literal.create(null, BooleanType)), true, row0)
checkEvaluation(IsNotUnknown(Literal.create(null, BooleanType)), false, row0)
IsUnknown(Literal.create(null, IntegerType)).checkInputDataTypes() match {