[SPARK-10442] [SQL] fix string to boolean cast

When we cast a string to boolean in Hive, it returns `true` if the length of the string is greater than 0, and Spark SQL follows this behavior.

However, this behavior is very different from other SQL systems:

1. [presto](https://github.com/facebook/presto/blob/master/presto-main/src/main/java/com/facebook/presto/type/VarcharOperators.java#L89-L118) will return `true` for 't' 'true' '1', `false` for 'f' 'false' '0', throw exception for others.
2. [redshift](http://docs.aws.amazon.com/redshift/latest/dg/r_Boolean_type.html) will return `true` for 't' 'true' 'y' 'yes' '1', `false` for 'f' 'false' 'n' 'no' '0', null for others.
3. [postgresql](http://www.postgresql.org/docs/devel/static/datatype-boolean.html) will return `true` for 't' 'true' 'y' 'yes' 'on' '1', `false` for 'f' 'false' 'n' 'no' 'off' '0', throw exception for others.
4. [vertica](https://my.vertica.com/docs/5.0/HTML/Master/2983.htm) will return `true` for 't' 'true' 'y' 'yes' '1', `false` for 'f' 'false' 'n' 'no' '0', null for others.
5. [impala](http://www.cloudera.com/content/cloudera/en/documentation/cloudera-impala/latest/topics/impala_boolean.html) throws an exception when trying to cast a string to boolean.
6. MySQL, Oracle, and SQL Server don't have a boolean type.

Whether we should change the cast behavior to match other SQL systems is not yet decided; this PR is a test to see how many compatibility tests would fail if we made the change.

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #8698 from cloud-fan/string2boolean.
This commit is contained in:
Wenchen Fan 2015-09-11 14:15:16 -07:00 committed by Yin Huai
parent c373866774
commit d5d647380f
4 changed files with 82 additions and 24 deletions

View file

@ -22,7 +22,7 @@ import java.math.{BigDecimal => JavaBigDecimal}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.catalyst.util.{StringUtils, DateTimeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@ -140,7 +140,15 @@ case class Cast(child: Expression, dataType: DataType)
// UDFToBoolean
private[this] def castToBoolean(from: DataType): Any => Any = from match {
case StringType =>
buildCast[UTF8String](_, _.numBytes() != 0)
buildCast[UTF8String](_, s => {
if (StringUtils.isTrueString(s)) {
true
} else if (StringUtils.isFalseString(s)) {
false
} else {
null
}
})
case TimestampType =>
buildCast[Long](_, t => t != 0)
case DateType =>
@ -646,7 +654,17 @@ case class Cast(child: Expression, dataType: DataType)
private[this] def castToBooleanCode(from: DataType): CastFunction = from match {
case StringType =>
(c, evPrim, evNull) => s"$evPrim = $c.numBytes() != 0;"
val stringUtils = StringUtils.getClass.getName.stripSuffix("$")
(c, evPrim, evNull) =>
s"""
if ($stringUtils.isTrueString($c)) {
$evPrim = true;
} else if ($stringUtils.isFalseString($c)) {
$evPrim = false;
} else {
$evNull = true;
}
"""
case TimestampType =>
(c, evPrim, evNull) => s"$evPrim = $c != 0;"
case DateType =>

View file

@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.util
import java.util.regex.Pattern
import org.apache.spark.unsafe.types.UTF8String
object StringUtils {
// replace the _ with .{1} exactly match 1 time of any character
@ -44,4 +46,10 @@ object StringUtils {
v
}
}
private[this] val trueStrings = Set("t", "true", "y", "yes", "1").map(UTF8String.fromString)
private[this] val falseStrings = Set("f", "false", "n", "no", "0").map(UTF8String.fromString)
def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase)
def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase)
}

View file

@ -503,9 +503,9 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
}
test("cast from array") {
val array = Literal.create(Seq("123", "abc", "", null),
val array = Literal.create(Seq("123", "true", "f", null),
ArrayType(StringType, containsNull = true))
val array_notNull = Literal.create(Seq("123", "abc", ""),
val array_notNull = Literal.create(Seq("123", "true", "f"),
ArrayType(StringType, containsNull = false))
checkNullCast(ArrayType(StringType), ArrayType(IntegerType))
@ -522,7 +522,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
{
val ret = cast(array, ArrayType(BooleanType, containsNull = true))
assert(ret.resolved === true)
checkEvaluation(ret, Seq(true, true, false, null))
checkEvaluation(ret, Seq(null, true, false, null))
}
{
val ret = cast(array, ArrayType(BooleanType, containsNull = false))
@ -541,12 +541,12 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
{
val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = true))
assert(ret.resolved === true)
checkEvaluation(ret, Seq(true, true, false))
checkEvaluation(ret, Seq(null, true, false))
}
{
val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = false))
assert(ret.resolved === true)
checkEvaluation(ret, Seq(true, true, false))
checkEvaluation(ret, Seq(null, true, false))
}
{
@ -557,10 +557,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
test("cast from map") {
val map = Literal.create(
Map("a" -> "123", "b" -> "abc", "c" -> "", "d" -> null),
Map("a" -> "123", "b" -> "true", "c" -> "f", "d" -> null),
MapType(StringType, StringType, valueContainsNull = true))
val map_notNull = Literal.create(
Map("a" -> "123", "b" -> "abc", "c" -> ""),
Map("a" -> "123", "b" -> "true", "c" -> "f"),
MapType(StringType, StringType, valueContainsNull = false))
checkNullCast(MapType(StringType, IntegerType), MapType(StringType, StringType))
@ -577,7 +577,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
{
val ret = cast(map, MapType(StringType, BooleanType, valueContainsNull = true))
assert(ret.resolved === true)
checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false, "d" -> null))
checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false, "d" -> null))
}
{
val ret = cast(map, MapType(StringType, BooleanType, valueContainsNull = false))
@ -600,12 +600,12 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
{
val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = true))
assert(ret.resolved === true)
checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false))
checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false))
}
{
val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = false))
assert(ret.resolved === true)
checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false))
checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false))
}
{
val ret = cast(map_notNull, MapType(IntegerType, StringType, valueContainsNull = true))
@ -630,8 +630,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
val struct = Literal.create(
InternalRow(
UTF8String.fromString("123"),
UTF8String.fromString("abc"),
UTF8String.fromString(""),
UTF8String.fromString("true"),
UTF8String.fromString("f"),
null),
StructType(Seq(
StructField("a", StringType, nullable = true),
@ -641,8 +641,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
val struct_notNull = Literal.create(
InternalRow(
UTF8String.fromString("123"),
UTF8String.fromString("abc"),
UTF8String.fromString("")),
UTF8String.fromString("true"),
UTF8String.fromString("f")),
StructType(Seq(
StructField("a", StringType, nullable = false),
StructField("b", StringType, nullable = false),
@ -672,7 +672,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
StructField("c", BooleanType, nullable = true),
StructField("d", BooleanType, nullable = true))))
assert(ret.resolved === true)
checkEvaluation(ret, InternalRow(true, true, false, null))
checkEvaluation(ret, InternalRow(null, true, false, null))
}
{
val ret = cast(struct, StructType(Seq(
@ -704,7 +704,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
StructField("b", BooleanType, nullable = true),
StructField("c", BooleanType, nullable = true))))
assert(ret.resolved === true)
checkEvaluation(ret, InternalRow(true, true, false))
checkEvaluation(ret, InternalRow(null, true, false))
}
{
val ret = cast(struct_notNull, StructType(Seq(
@ -712,7 +712,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
StructField("b", BooleanType, nullable = true),
StructField("c", BooleanType, nullable = false))))
assert(ret.resolved === true)
checkEvaluation(ret, InternalRow(true, true, false))
checkEvaluation(ret, InternalRow(null, true, false))
}
{
@ -731,8 +731,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
test("complex casting") {
val complex = Literal.create(
Row(
Seq("123", "abc", ""),
Map("a" ->"123", "b" -> "abc", "c" -> ""),
Seq("123", "true", "f"),
Map("a" ->"123", "b" -> "true", "c" -> "f"),
Row(0)),
StructType(Seq(
StructField("a",
@ -755,11 +755,11 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(ret.resolved === true)
checkEvaluation(ret, Row(
Seq(123, null, null),
Map("a" -> true, "b" -> true, "c" -> false),
Map("a" -> null, "b" -> true, "c" -> false),
Row(0L)))
}
test("case between string and interval") {
test("cast between string and interval") {
import org.apache.spark.unsafe.types.CalendarInterval
checkEvaluation(Cast(Literal("interval -3 month 7 hours"), CalendarIntervalType),
@ -769,4 +769,23 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
StringType),
"interval 1 years 3 months -3 days")
}
test("cast string to boolean") {
checkCast("t", true)
checkCast("true", true)
checkCast("tRUe", true)
checkCast("y", true)
checkCast("yes", true)
checkCast("1", true)
checkCast("f", false)
checkCast("false", false)
checkCast("FAlsE", false)
checkCast("n", false)
checkCast("no", false)
checkCast("0", false)
checkEvaluation(cast("abc", BooleanType), null)
checkEvaluation(cast("", BooleanType), null)
}
}

View file

@ -375,6 +375,19 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
}
}
test("saveAsTable()/load() - partitioned table - boolean type") {
sqlContext.range(2)
.select('id, ('id % 2 === 0).as("b"))
.write.partitionBy("b").saveAsTable("t")
withTable("t") {
checkAnswer(
sqlContext.table("t").sort('id),
Row(0, true) :: Row(1, false) :: Nil
)
}
}
test("saveAsTable()/load() - partitioned table - Overwrite") {
partitionedTestDF.write
.format(dataSourceName)