[SPARK-25734][SQL] Literal should have a value corresponding to dataType

## What changes were proposed in this pull request?
`Literal.value` should hold a value whose runtime type corresponds to `dataType`. This PR adds code to verify that and fixes the existing tests so that they satisfy the check.

## How was this patch tested?
Modified the existing tests.

Closes #22724 from maropu/SPARK-25734.

Authored-by: Takeshi Yamamuro <yamamuro@apache.org>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
Takeshi Yamamuro 2018-10-17 11:02:39 +08:00 committed by Wenchen Fan
parent e9af9460bc
commit a9f685bb70
10 changed files with 92 additions and 39 deletions

View file

@ -52,7 +52,7 @@ private[kafka010] object KafkaWriter extends Logging {
s"'$TOPIC_ATTRIBUTE_NAME' attribute is present. Use the " +
s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a topic.")
} else {
Literal(topic.get, StringType)
Literal.create(topic.get, StringType)
}
).dataType match {
case StringType => // good

View file

@ -40,9 +40,10 @@ import org.json4s.JsonAST._
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types._
import org.apache.spark.util.Utils
object Literal {
val TrueLiteral: Literal = Literal(true, BooleanType)
@ -196,6 +197,47 @@ object Literal {
case other =>
throw new RuntimeException(s"no default for type $dataType")
}
/**
 * Asserts that `value` has the runtime representation that `dataType` calls for.
 *
 * A `null` value is accepted for every data type. Array values are only sampled: just the
 * first element (if any) is inspected. Map values are checked through their key array and
 * value array, each treated as an array of the corresponding type. User-defined types are
 * checked against their `sqlType`. Any data type not listed below is rejected.
 *
 * @param value    the candidate literal value
 * @param dataType the data type the literal is declared with
 * @throws IllegalArgumentException if `value` does not correspond to `dataType`
 */
private[expressions] def validateLiteralValue(value: Any, dataType: DataType): Unit = {
  // Recursive check of one value against one expected type.
  def conforms(candidate: Any, expected: DataType): Boolean = {
    if (candidate == null) {
      // null is a legal value regardless of the declared type
      true
    } else {
      expected match {
        case BooleanType => candidate.isInstanceOf[Boolean]
        case ByteType => candidate.isInstanceOf[Byte]
        case ShortType => candidate.isInstanceOf[Short]
        case IntegerType | DateType => candidate.isInstanceOf[Int]
        case LongType | TimestampType => candidate.isInstanceOf[Long]
        case FloatType => candidate.isInstanceOf[Float]
        case DoubleType => candidate.isInstanceOf[Double]
        case _: DecimalType => candidate.isInstanceOf[Decimal]
        case CalendarIntervalType => candidate.isInstanceOf[CalendarInterval]
        case BinaryType => candidate.isInstanceOf[Array[Byte]]
        case StringType => candidate.isInstanceOf[UTF8String]
        case struct: StructType =>
          candidate match {
            case row: InternalRow =>
              // every field of the row must conform to the field type at the same ordinal
              struct.fields.zipWithIndex.forall { case (field, ordinal) =>
                conforms(row.get(ordinal, field.dataType), field.dataType)
              }
            case _ => false
          }
        case array: ArrayType =>
          candidate match {
            case elements: ArrayData =>
              // only the first element is examined; an empty array always passes
              elements.numElements() == 0 ||
                conforms(elements.get(0, array.elementType), array.elementType)
            case _ => false
          }
        case map: MapType =>
          candidate match {
            case entries: MapData =>
              conforms(entries.keyArray(), ArrayType(map.keyType)) &&
                conforms(entries.valueArray(), ArrayType(map.valueType))
            case _ => false
          }
        case ObjectType(cls) => cls.isInstance(candidate)
        case udt: UserDefinedType[_] => conforms(candidate, udt.sqlType)
        case _ => false
      }
    }
  }

  // The message is passed by name, so `value.getClass` is only evaluated when the check
  // fails, and a failing value is never null.
  require(conforms(value, dataType),
    s"Literal must have a corresponding value to ${dataType.catalogString}, " +
    s"but class ${Utils.getSimpleName(value.getClass)} found.")
}
}
/**
@ -240,6 +282,8 @@ object DecimalLiteral {
*/
case class Literal (value: Any, dataType: DataType) extends LeafExpression {
Literal.validateLiteralValue(value, dataType)
override def foldable: Boolean = true
override def nullable: Boolean = value == null

View file

@ -742,7 +742,7 @@ class TypeCoercionSuite extends AnalysisTest {
val nullLit = Literal.create(null, NullType)
val floatNullLit = Literal.create(null, FloatType)
val floatLit = Literal.create(1.0f, FloatType)
val timestampLit = Literal.create("2017-04-12", TimestampType)
val timestampLit = Literal.create(Timestamp.valueOf("2017-04-12 00:00:00"), TimestampType)
val decimalLit = Literal(new java.math.BigDecimal("1000000000000000000000"))
val tsArrayLit = Literal(Array(new Timestamp(System.currentTimeMillis())))
val strArrayLit = Literal(Array("c"))
@ -793,11 +793,11 @@ class TypeCoercionSuite extends AnalysisTest {
ruleTest(TypeCoercion.FunctionArgumentConversion,
CreateArray(Literal(1.0)
:: Literal(1)
:: Literal.create(1.0, FloatType)
:: Literal.create(1.0f, FloatType)
:: Nil),
CreateArray(Literal(1.0)
:: Cast(Literal(1), DoubleType)
:: Cast(Literal.create(1.0, FloatType), DoubleType)
:: Cast(Literal.create(1.0f, FloatType), DoubleType)
:: Nil))
ruleTest(TypeCoercion.FunctionArgumentConversion,
@ -834,23 +834,23 @@ class TypeCoercionSuite extends AnalysisTest {
ruleTest(TypeCoercion.FunctionArgumentConversion,
CreateMap(Literal(1)
:: Literal("a")
:: Literal.create(2.0, FloatType)
:: Literal.create(2.0f, FloatType)
:: Literal("b")
:: Nil),
CreateMap(Cast(Literal(1), FloatType)
:: Literal("a")
:: Literal.create(2.0, FloatType)
:: Literal.create(2.0f, FloatType)
:: Literal("b")
:: Nil))
ruleTest(TypeCoercion.FunctionArgumentConversion,
CreateMap(Literal.create(null, DecimalType(5, 3))
:: Literal("a")
:: Literal.create(2.0, FloatType)
:: Literal.create(2.0f, FloatType)
:: Literal("b")
:: Nil),
CreateMap(Literal.create(null, DecimalType(5, 3)).cast(DoubleType)
:: Literal("a")
:: Literal.create(2.0, FloatType).cast(DoubleType)
:: Literal.create(2.0f, FloatType).cast(DoubleType)
:: Literal("b")
:: Nil))
// type coercion for map values
@ -895,11 +895,11 @@ class TypeCoercionSuite extends AnalysisTest {
ruleTest(TypeCoercion.FunctionArgumentConversion,
operator(Literal(1.0)
:: Literal(1)
:: Literal.create(1.0, FloatType)
:: Literal.create(1.0f, FloatType)
:: Nil),
operator(Literal(1.0)
:: Cast(Literal(1), DoubleType)
:: Cast(Literal.create(1.0, FloatType), DoubleType)
:: Cast(Literal.create(1.0f, FloatType), DoubleType)
:: Nil))
ruleTest(TypeCoercion.FunctionArgumentConversion,
operator(Literal(1L)
@ -966,7 +966,7 @@ class TypeCoercionSuite extends AnalysisTest {
val falseLit = Literal.create(false, BooleanType)
val stringLit = Literal.create("c", StringType)
val floatLit = Literal.create(1.0f, FloatType)
val timestampLit = Literal.create("2017-04-12", TimestampType)
val timestampLit = Literal.create(Timestamp.valueOf("2017-04-12 00:00:00"), TimestampType)
val decimalLit = Literal(new java.math.BigDecimal("1000000000000000000000"))
ruleTest(rule,
@ -1016,14 +1016,16 @@ class TypeCoercionSuite extends AnalysisTest {
CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a")))
)
ruleTest(TypeCoercion.CaseWhenCoercion,
CaseWhen(Seq((Literal(true), Literal(1.2))), Literal.create(1, DecimalType(7, 2))),
CaseWhen(Seq((Literal(true), Literal(1.2))),
Cast(Literal.create(1, DecimalType(7, 2)), DoubleType))
Literal.create(BigDecimal.valueOf(1), DecimalType(7, 2))),
CaseWhen(Seq((Literal(true), Literal(1.2))),
Cast(Literal.create(BigDecimal.valueOf(1), DecimalType(7, 2)), DoubleType))
)
ruleTest(TypeCoercion.CaseWhenCoercion,
CaseWhen(Seq((Literal(true), Literal(100L))), Literal.create(1, DecimalType(7, 2))),
CaseWhen(Seq((Literal(true), Literal(100L))),
Literal.create(BigDecimal.valueOf(1), DecimalType(7, 2))),
CaseWhen(Seq((Literal(true), Cast(Literal(100L), DecimalType(22, 2)))),
Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2)))
Cast(Literal.create(BigDecimal.valueOf(1), DecimalType(7, 2)), DecimalType(22, 2)))
)
}

View file

@ -623,7 +623,8 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
test("SPARK-21513: to_json support map[string, struct] to json") {
val schema = MapType(StringType, StructType(StructField("a", IntegerType) :: Nil))
val input = Literal.create(ArrayBasedMapData(Map("test" -> InternalRow(1))), schema)
val input = Literal(
ArrayBasedMapData(Map(UTF8String.fromString("test") -> InternalRow(1))), schema)
checkEvaluation(
StructsToJson(Map.empty, input),
"""{"test":{"a":1}}"""
@ -633,7 +634,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
test("SPARK-21513: to_json support map[struct, struct] to json") {
val schema = MapType(StructType(StructField("a", IntegerType) :: Nil),
StructType(StructField("b", IntegerType) :: Nil))
val input = Literal.create(ArrayBasedMapData(Map(InternalRow(1) -> InternalRow(2))), schema)
val input = Literal(ArrayBasedMapData(Map(InternalRow(1) -> InternalRow(2))), schema)
checkEvaluation(
StructsToJson(Map.empty, input),
"""{"[1]":{"b":2}}"""
@ -642,7 +643,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
test("SPARK-21513: to_json support map[string, integer] to json") {
val schema = MapType(StringType, IntegerType)
val input = Literal.create(ArrayBasedMapData(Map("a" -> 1)), schema)
val input = Literal(ArrayBasedMapData(Map(UTF8String.fromString("a") -> 1)), schema)
checkEvaluation(
StructsToJson(Map.empty, input),
"""{"a":1}"""
@ -651,17 +652,18 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
test("to_json - array with maps") {
val inputSchema = ArrayType(MapType(StringType, IntegerType))
val input = new GenericArrayData(ArrayBasedMapData(
Map("a" -> 1)) :: ArrayBasedMapData(Map("b" -> 2)) :: Nil)
val input = new GenericArrayData(
ArrayBasedMapData(Map(UTF8String.fromString("a") -> 1)) ::
ArrayBasedMapData(Map(UTF8String.fromString("b") -> 2)) :: Nil)
val output = """[{"a":1},{"b":2}]"""
checkEvaluation(
StructsToJson(Map.empty, Literal.create(input, inputSchema), gmtId),
StructsToJson(Map.empty, Literal(input, inputSchema), gmtId),
output)
}
test("to_json - array with single map") {
val inputSchema = ArrayType(MapType(StringType, IntegerType))
val input = new GenericArrayData(ArrayBasedMapData(Map("a" -> 1)) :: Nil)
val input = new GenericArrayData(ArrayBasedMapData(Map(UTF8String.fromString("a") -> 1)) :: Nil)
val output = """[{"a":1}]"""
checkEvaluation(
StructsToJson(Map.empty, Literal.create(input, inputSchema), gmtId),

View file

@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.expressions
import java.sql.Timestamp
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
@ -107,8 +109,8 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val nullLit = Literal.create(null, NullType)
val floatNullLit = Literal.create(null, FloatType)
val floatLit = Literal.create(1.01f, FloatType)
val timestampLit = Literal.create("2017-04-12", TimestampType)
val decimalLit = Literal.create(10.2, DecimalType(20, 2))
val timestampLit = Literal.create(Timestamp.valueOf("2017-04-12 00:00:00"), TimestampType)
val decimalLit = Literal.create(BigDecimal.valueOf(10.2), DecimalType(20, 2))
assert(analyze(new Nvl(decimalLit, stringLit)).dataType == StringType)
assert(analyze(new Nvl(doubleLit, decimalLit)).dataType == DoubleType)

View file

@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
import java.sql.{Date, Timestamp}
import java.sql.Timestamp
import java.util.TimeZone
import org.apache.spark.SparkFunSuite
@ -32,9 +32,9 @@ class SortOrderExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
val b2 = Literal.create(true, BooleanType)
val i1 = Literal.create(20132983, IntegerType)
val i2 = Literal.create(-20132983, IntegerType)
val l1 = Literal.create(20132983, LongType)
val l2 = Literal.create(-20132983, LongType)
val millis = 1524954911000L;
val l1 = Literal.create(20132983L, LongType)
val l2 = Literal.create(-20132983L, LongType)
val millis = 1524954911000L
// Explicitly choose a time zone, since Date objects can create different values depending on
// local time zone of the machine on which the test is running
val oldDefaultTZ = TimeZone.getDefault
@ -57,7 +57,7 @@ class SortOrderExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
val dec1 = Literal(Decimal(20132983L, 10, 2))
val dec2 = Literal(Decimal(20132983L, 19, 2))
val dec3 = Literal(Decimal(20132983L, 21, 2))
val list1 = Literal(List(1, 2), ArrayType(IntegerType))
val list1 = Literal.create(Seq(1, 2), ArrayType(IntegerType))
val nullVal = Literal.create(null, IntegerType)
checkEvaluation(SortPrefix(SortOrder(b1, Ascending)), 0L)

View file

@ -105,9 +105,9 @@ class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with Priva
}
test("parse sql expression for duration in microseconds - long") {
val dur = TimeWindow.invokePrivate(parseExpression(Literal.create(2 << 52, LongType)))
val dur = TimeWindow.invokePrivate(parseExpression(Literal.create(2L << 52, LongType)))
assert(dur.isInstanceOf[Long])
assert(dur === (2 << 52))
assert(dur === (2L << 52))
}
test("parse sql expression for duration in microseconds - invalid interval") {

View file

@ -232,11 +232,14 @@ class PercentileSuite extends SparkFunSuite {
BooleanType, StringType, DateType, TimestampType, CalendarIntervalType, NullType)
invalidDataTypes.foreach { dataType =>
val percentage = Literal(0.5, dataType)
val percentage = Literal.default(dataType)
val percentile4 = new Percentile(child, percentage)
assertEqual(percentile4.checkInputDataTypes(),
TypeCheckFailure(s"argument 2 requires double type, however, " +
s"'0.5' is of ${dataType.simpleString} type."))
val checkResult = percentile4.checkInputDataTypes()
assert(checkResult.isFailure)
Seq("argument 2 requires double type, however, ",
s"is of ${dataType.simpleString} type.").foreach { errMsg =>
assert(checkResult.asInstanceOf[TypeCheckFailure].message.contains(errMsg))
}
}
}

View file

@ -210,13 +210,13 @@ case class AnalyzeColumnCommand(
def struct(exprs: Expression*): CreateNamedStruct = CreateStruct(exprs.map { expr =>
expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() }
})
val one = Literal(1, LongType)
val one = Literal(1L, LongType)
// the approximate ndv (num distinct value) should never be larger than the number of rows
val numNonNulls = if (col.nullable) Count(col) else Count(one)
val ndv = Least(Seq(HyperLogLogPlusPlus(col, conf.ndvMaxError), numNonNulls))
val numNulls = Subtract(Count(one), numNonNulls)
val defaultSize = Literal(col.dataType.defaultSize, LongType)
val defaultSize = Literal(col.dataType.defaultSize.toLong, LongType)
val nullArray = Literal(null, ArrayType(LongType))
def fixedLenTypeStruct: CreateNamedStruct = {

View file

@ -228,7 +228,7 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
test("join key rewritten") {
val l = Literal(1L)
val i = Literal(2)
val s = Literal.create(3, ShortType)
val s = Literal.create(3.toShort, ShortType)
val ss = Literal("hello")
assert(HashJoin.rewriteKeyExpr(l :: Nil) === l :: Nil)