[SPARK-27924][SQL][FOLLOW-UP] Improve ANSI SQL Boolean-Predicate
### What changes were proposed in this pull request? This PR follows https://github.com/apache/spark/pull/25074 and improves the implement. ### Why are the changes needed? Improve code. ### Does this PR introduce any user-facing change? No ### How was this patch tested? Exists UT Closes #27699 from beliefer/improve-boolean-test. Authored-by: beliefer <beliefer@163.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
2b744fe885
commit
1515d45b8d
|
@ -927,66 +927,6 @@ case class GreaterThanOrEqual(left: Expression, right: Expression)
|
|||
protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gteq(input1, input2)
|
||||
}
|
||||
|
||||
trait BooleanTest extends UnaryExpression with Predicate with ExpectsInputTypes {
|
||||
|
||||
def boolValueForComparison: Boolean
|
||||
def boolValueWhenNull: Boolean
|
||||
|
||||
override def nullable: Boolean = false
|
||||
override def inputTypes: Seq[DataType] = Seq(BooleanType)
|
||||
|
||||
override def eval(input: InternalRow): Any = {
|
||||
val value = child.eval(input)
|
||||
Option(value) match {
|
||||
case None => boolValueWhenNull
|
||||
case other => if (boolValueWhenNull) {
|
||||
value == !boolValueForComparison
|
||||
} else {
|
||||
value == boolValueForComparison
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
val eval = child.genCode(ctx)
|
||||
ev.copy(code = code"""
|
||||
${eval.code}
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
if (${eval.isNull}) {
|
||||
${ev.value} = $boolValueWhenNull;
|
||||
} else if ($boolValueWhenNull) {
|
||||
${ev.value} = ${eval.value} == !$boolValueForComparison;
|
||||
} else {
|
||||
${ev.value} = ${eval.value} == $boolValueForComparison;
|
||||
}
|
||||
""", isNull = FalseLiteral)
|
||||
}
|
||||
}
|
||||
|
||||
case class IsTrue(child: Expression) extends BooleanTest {
|
||||
override def boolValueForComparison: Boolean = true
|
||||
override def boolValueWhenNull: Boolean = false
|
||||
override def sql: String = s"(${child.sql} IS TRUE)"
|
||||
}
|
||||
|
||||
case class IsNotTrue(child: Expression) extends BooleanTest {
|
||||
override def boolValueForComparison: Boolean = true
|
||||
override def boolValueWhenNull: Boolean = true
|
||||
override def sql: String = s"(${child.sql} IS NOT TRUE)"
|
||||
}
|
||||
|
||||
case class IsFalse(child: Expression) extends BooleanTest {
|
||||
override def boolValueForComparison: Boolean = false
|
||||
override def boolValueWhenNull: Boolean = false
|
||||
override def sql: String = s"(${child.sql} IS FALSE)"
|
||||
}
|
||||
|
||||
case class IsNotFalse(child: Expression) extends BooleanTest {
|
||||
override def boolValueForComparison: Boolean = false
|
||||
override def boolValueWhenNull: Boolean = true
|
||||
override def sql: String = s"(${child.sql} IS NOT FALSE)"
|
||||
}
|
||||
|
||||
/**
|
||||
* IS UNKNOWN and IS NOT UNKNOWN are the same as IS NULL and IS NOT NULL, respectively,
|
||||
* except that the input expression must be of a boolean type.
|
||||
|
|
|
@ -1414,12 +1414,12 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
|
|||
case SqlBaseParser.NULL =>
|
||||
IsNull(e)
|
||||
case SqlBaseParser.TRUE => ctx.NOT match {
|
||||
case null => IsTrue(e)
|
||||
case _ => IsNotTrue(e)
|
||||
case null => EqualNullSafe(e, Literal(true))
|
||||
case _ => Not(EqualNullSafe(e, Literal(true)))
|
||||
}
|
||||
case SqlBaseParser.FALSE => ctx.NOT match {
|
||||
case null => IsFalse(e)
|
||||
case _ => IsNotFalse(e)
|
||||
case null => EqualNullSafe(e, Literal(false))
|
||||
case _ => Not(EqualNullSafe(e, Literal(false)))
|
||||
}
|
||||
case SqlBaseParser.UNKNOWN => ctx.NOT match {
|
||||
case null => IsUnknown(e)
|
||||
|
|
|
@ -522,37 +522,9 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
assert(expression == expected)
|
||||
}
|
||||
|
||||
val row0 = create_row(null)
|
||||
val row1 = create_row(false)
|
||||
val row2 = create_row(true)
|
||||
|
||||
test("istrue and isnottrue") {
|
||||
checkEvaluation(IsTrue(Literal.create(null, BooleanType)), false, row0)
|
||||
checkEvaluation(IsNotTrue(Literal.create(null, BooleanType)), true, row0)
|
||||
checkEvaluation(IsTrue(Literal.create(false, BooleanType)), false, row1)
|
||||
checkEvaluation(IsNotTrue(Literal.create(false, BooleanType)), true, row1)
|
||||
checkEvaluation(IsTrue(Literal.create(true, BooleanType)), true, row2)
|
||||
checkEvaluation(IsNotTrue(Literal.create(true, BooleanType)), false, row2)
|
||||
IsTrue(Literal.create(null, IntegerType)).checkInputDataTypes() match {
|
||||
case TypeCheckResult.TypeCheckFailure(msg) =>
|
||||
assert(msg.contains("argument 1 requires boolean type"))
|
||||
}
|
||||
}
|
||||
|
||||
test("isfalse and isnotfalse") {
|
||||
checkEvaluation(IsFalse(Literal.create(null, BooleanType)), false, row0)
|
||||
checkEvaluation(IsNotFalse(Literal.create(null, BooleanType)), true, row0)
|
||||
checkEvaluation(IsFalse(Literal.create(false, BooleanType)), true, row1)
|
||||
checkEvaluation(IsNotFalse(Literal.create(false, BooleanType)), false, row1)
|
||||
checkEvaluation(IsFalse(Literal.create(true, BooleanType)), false, row2)
|
||||
checkEvaluation(IsNotFalse(Literal.create(true, BooleanType)), true, row2)
|
||||
IsFalse(Literal.create(null, IntegerType)).checkInputDataTypes() match {
|
||||
case TypeCheckResult.TypeCheckFailure(msg) =>
|
||||
assert(msg.contains("argument 1 requires boolean type"))
|
||||
}
|
||||
}
|
||||
|
||||
test("isunknown and isnotunknown") {
|
||||
val row0 = create_row(null)
|
||||
|
||||
checkEvaluation(IsUnknown(Literal.create(null, BooleanType)), true, row0)
|
||||
checkEvaluation(IsNotUnknown(Literal.create(null, BooleanType)), false, row0)
|
||||
IsUnknown(Literal.create(null, IntegerType)).checkInputDataTypes() match {
|
||||
|
|
Loading…
Reference in a new issue