[SPARK-14580][SPARK-14655][SQL] Hive IfCoercion should preserve predicate.

## What changes were proposed in this pull request?

Currently, `HiveTypeCoercion.IfCoercion` removes all predicates whose return-type are null. However, some UDFs need evaluations because they are designed to throw exceptions. This PR fixes that to preserve the predicates. Also, `assert_true` is implemented as Spark SQL function.

**Before**
```
scala> sql("select if(assert_true(false),2,3)").head
res2: org.apache.spark.sql.Row = [3]
```

**After**
```
scala> sql("select if(assert_true(false),2,3)").head
... ASSERT_TRUE ...
```

**Hive**
```
hive> select if(assert_true(false),2,3);
OK
Failed with exception java.io.IOException:org.apache.hadoop.hive.ql.metadata.HiveException: ASSERT_TRUE(): assertion failed.
```

## How was this patch tested?

Pass the Jenkins tests (including a new testcase in `HivePlanTest`)

Author: Dongjoon Hyun <dongjoon@apache.org>

Closes #12340 from dongjoon-hyun/SPARK-14580.
This commit is contained in:
Dongjoon Hyun 2016-04-18 12:26:56 -07:00 committed by Reynold Xin
parent b64482f49f
commit d280d1da1a
5 changed files with 70 additions and 7 deletions

View file

@ -329,6 +329,7 @@ object FunctionRegistry {
expression[SortArray]("sort_array"),
// misc functions
expression[AssertTrue]("assert_true"),
expression[Crc32]("crc32"),
expression[Md5]("md5"),
expression[Murmur3Hash]("hash"),

View file

@ -584,10 +584,10 @@ object HiveTypeCoercion {
val newRight = if (right.dataType == widestType) right else Cast(right, widestType)
If(pred, newLeft, newRight)
}.getOrElse(i) // If there is no applicable conversion, leave expression unchanged.
// Convert If(null literal, _, _) into boolean type.
// In the optimizer, we should short-circuit this directly into false value.
case If(pred, left, right) if pred.dataType == NullType =>
case If(Literal(null, NullType), left, right) =>
If(Literal.create(null, BooleanType), left, right)
case If(pred, left, right) if pred.dataType == NullType =>
If(Cast(pred, BooleanType), left, right)
}
}

View file

@ -486,6 +486,44 @@ case class PrintToStderr(child: Expression) extends UnaryExpression {
}
}
/**
* A function throws an exception if 'condition' is not true.
*/
@ExpressionDescription(
usage = "_FUNC_(condition) - Throw an exception if 'condition' is not true.")
case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def nullable: Boolean = true
override def inputTypes: Seq[DataType] = Seq(BooleanType)
override def dataType: DataType = NullType
override def prettyName: String = "assert_true"
override def eval(input: InternalRow) : Any = {
val v = child.eval(input)
if (v == null || java.lang.Boolean.FALSE.equals(v)) {
throw new RuntimeException(s"'${child.simpleString}' is not true!")
} else {
null
}
}
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
val eval = child.gen(ctx)
ev.isNull = "true"
ev.value = "null"
s"""${eval.code}
|if (${eval.isNull} || !${eval.value}) {
| throw new RuntimeException("'${child.simpleString}' is not true.");
|}
""".stripMargin
}
override def sql: String = s"assert_true(${child.sql})"
}
/**
* A xxHash64 64-bit hash expression.
*/

View file

@ -348,15 +348,22 @@ class HiveTypeCoercionSuite extends PlanTest {
test("type coercion for If") {
val rule = HiveTypeCoercion.IfCoercion
ruleTest(rule,
If(Literal(true), Literal(1), Literal(1L)),
If(Literal(true), Cast(Literal(1), LongType), Literal(1L))
)
If(Literal(true), Cast(Literal(1), LongType), Literal(1L)))
ruleTest(rule,
If(Literal.create(null, NullType), Literal(1), Literal(1)),
If(Literal.create(null, BooleanType), Literal(1), Literal(1))
)
If(Literal.create(null, BooleanType), Literal(1), Literal(1)))
ruleTest(rule,
If(AssertTrue(Literal.create(true, BooleanType)), Literal(1), Literal(2)),
If(Cast(AssertTrue(Literal.create(true, BooleanType)), BooleanType), Literal(1), Literal(2)))
ruleTest(rule,
If(AssertTrue(Literal.create(false, BooleanType)), Literal(1), Literal(2)),
If(Cast(AssertTrue(Literal.create(false, BooleanType)), BooleanType), Literal(1), Literal(2)))
}
test("type coercion for CaseKeyWhen") {

View file

@ -69,6 +69,23 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType)
}
test("assert_true") {
intercept[RuntimeException] {
checkEvaluation(AssertTrue(Literal(false, BooleanType)), null)
}
intercept[RuntimeException] {
checkEvaluation(AssertTrue(Cast(Literal(0), BooleanType)), null)
}
intercept[RuntimeException] {
checkEvaluation(AssertTrue(Literal.create(null, NullType)), null)
}
intercept[RuntimeException] {
checkEvaluation(AssertTrue(Literal.create(null, BooleanType)), null)
}
checkEvaluation(AssertTrue(Literal(true, BooleanType)), null)
checkEvaluation(AssertTrue(Cast(Literal(1), BooleanType)), null)
}
private val structOfString = new StructType().add("str", StringType)
private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false)
private val arrayOfString = ArrayType(StringType)