[SPARK-14580][SPARK-14655][SQL] Hive IfCoercion should preserve predicate.
## What changes were proposed in this pull request? Currently, `HiveTypeCoercion.IfCoercion` removes all predicates whose return-type are null. However, some UDFs need evaluations because they are designed to throw exceptions. This PR fixes that to preserve the predicates. Also, `assert_true` is implemented as Spark SQL function. **Before** ``` scala> sql("select if(assert_true(false),2,3)").head res2: org.apache.spark.sql.Row = [3] ``` **After** ``` scala> sql("select if(assert_true(false),2,3)").head ... ASSERT_TRUE ... ``` **Hive** ``` hive> select if(assert_true(false),2,3); OK Failed with exception java.io.IOException:org.apache.hadoop.hive.ql.metadata.HiveException: ASSERT_TRUE(): assertion failed. ``` ## How was this patch tested? Pass the Jenkins tests (including a new testcase in `HivePlanTest`) Author: Dongjoon Hyun <dongjoon@apache.org> Closes #12340 from dongjoon-hyun/SPARK-14580.
This commit is contained in:
parent
b64482f49f
commit
d280d1da1a
|
@ -329,6 +329,7 @@ object FunctionRegistry {
|
|||
expression[SortArray]("sort_array"),
|
||||
|
||||
// misc functions
|
||||
expression[AssertTrue]("assert_true"),
|
||||
expression[Crc32]("crc32"),
|
||||
expression[Md5]("md5"),
|
||||
expression[Murmur3Hash]("hash"),
|
||||
|
|
|
@ -584,10 +584,10 @@ object HiveTypeCoercion {
|
|||
val newRight = if (right.dataType == widestType) right else Cast(right, widestType)
|
||||
If(pred, newLeft, newRight)
|
||||
}.getOrElse(i) // If there is no applicable conversion, leave expression unchanged.
|
||||
// Convert If(null literal, _, _) into boolean type.
|
||||
// In the optimizer, we should short-circuit this directly into false value.
|
||||
case If(pred, left, right) if pred.dataType == NullType =>
|
||||
case If(Literal(null, NullType), left, right) =>
|
||||
If(Literal.create(null, BooleanType), left, right)
|
||||
case If(pred, left, right) if pred.dataType == NullType =>
|
||||
If(Cast(pred, BooleanType), left, right)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -486,6 +486,44 @@ case class PrintToStderr(child: Expression) extends UnaryExpression {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A function throws an exception if 'condition' is not true.
|
||||
*/
|
||||
@ExpressionDescription(
|
||||
usage = "_FUNC_(condition) - Throw an exception if 'condition' is not true.")
|
||||
case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
|
||||
|
||||
override def nullable: Boolean = true
|
||||
|
||||
override def inputTypes: Seq[DataType] = Seq(BooleanType)
|
||||
|
||||
override def dataType: DataType = NullType
|
||||
|
||||
override def prettyName: String = "assert_true"
|
||||
|
||||
override def eval(input: InternalRow) : Any = {
|
||||
val v = child.eval(input)
|
||||
if (v == null || java.lang.Boolean.FALSE.equals(v)) {
|
||||
throw new RuntimeException(s"'${child.simpleString}' is not true!")
|
||||
} else {
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
|
||||
val eval = child.gen(ctx)
|
||||
ev.isNull = "true"
|
||||
ev.value = "null"
|
||||
s"""${eval.code}
|
||||
|if (${eval.isNull} || !${eval.value}) {
|
||||
| throw new RuntimeException("'${child.simpleString}' is not true.");
|
||||
|}
|
||||
""".stripMargin
|
||||
}
|
||||
|
||||
override def sql: String = s"assert_true(${child.sql})"
|
||||
}
|
||||
|
||||
/**
|
||||
* A xxHash64 64-bit hash expression.
|
||||
*/
|
||||
|
|
|
@ -348,15 +348,22 @@ class HiveTypeCoercionSuite extends PlanTest {
|
|||
|
||||
test("type coercion for If") {
|
||||
val rule = HiveTypeCoercion.IfCoercion
|
||||
|
||||
ruleTest(rule,
|
||||
If(Literal(true), Literal(1), Literal(1L)),
|
||||
If(Literal(true), Cast(Literal(1), LongType), Literal(1L))
|
||||
)
|
||||
If(Literal(true), Cast(Literal(1), LongType), Literal(1L)))
|
||||
|
||||
ruleTest(rule,
|
||||
If(Literal.create(null, NullType), Literal(1), Literal(1)),
|
||||
If(Literal.create(null, BooleanType), Literal(1), Literal(1))
|
||||
)
|
||||
If(Literal.create(null, BooleanType), Literal(1), Literal(1)))
|
||||
|
||||
ruleTest(rule,
|
||||
If(AssertTrue(Literal.create(true, BooleanType)), Literal(1), Literal(2)),
|
||||
If(Cast(AssertTrue(Literal.create(true, BooleanType)), BooleanType), Literal(1), Literal(2)))
|
||||
|
||||
ruleTest(rule,
|
||||
If(AssertTrue(Literal.create(false, BooleanType)), Literal(1), Literal(2)),
|
||||
If(Cast(AssertTrue(Literal.create(false, BooleanType)), BooleanType), Literal(1), Literal(2)))
|
||||
}
|
||||
|
||||
test("type coercion for CaseKeyWhen") {
|
||||
|
|
|
@ -69,6 +69,23 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType)
|
||||
}
|
||||
|
||||
test("assert_true") {
|
||||
intercept[RuntimeException] {
|
||||
checkEvaluation(AssertTrue(Literal(false, BooleanType)), null)
|
||||
}
|
||||
intercept[RuntimeException] {
|
||||
checkEvaluation(AssertTrue(Cast(Literal(0), BooleanType)), null)
|
||||
}
|
||||
intercept[RuntimeException] {
|
||||
checkEvaluation(AssertTrue(Literal.create(null, NullType)), null)
|
||||
}
|
||||
intercept[RuntimeException] {
|
||||
checkEvaluation(AssertTrue(Literal.create(null, BooleanType)), null)
|
||||
}
|
||||
checkEvaluation(AssertTrue(Literal(true, BooleanType)), null)
|
||||
checkEvaluation(AssertTrue(Cast(Literal(1), BooleanType)), null)
|
||||
}
|
||||
|
||||
private val structOfString = new StructType().add("str", StringType)
|
||||
private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false)
|
||||
private val arrayOfString = ArrayType(StringType)
|
||||
|
|
Loading…
Reference in a new issue