[SPARK-25734][SQL] Literal should have a value corresponding to dataType
## What changes were proposed in this pull request? `Literal.value` should hold a value whose runtime type corresponds to `dataType`. This PR adds code to verify that and fixes the existing tests so that they comply. ## How was this patch tested? Modified the existing tests. Closes #22724 from maropu/SPARK-25734. Authored-by: Takeshi Yamamuro <yamamuro@apache.org> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
e9af9460bc
commit
a9f685bb70
|
@ -52,7 +52,7 @@ private[kafka010] object KafkaWriter extends Logging {
|
|||
s"'$TOPIC_ATTRIBUTE_NAME' attribute is present. Use the " +
|
||||
s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a topic.")
|
||||
} else {
|
||||
Literal(topic.get, StringType)
|
||||
Literal.create(topic.get, StringType)
|
||||
}
|
||||
).dataType match {
|
||||
case StringType => // good
|
||||
|
|
|
@ -40,9 +40,10 @@ import org.json4s.JsonAST._
|
|||
import org.apache.spark.sql.AnalysisException
|
||||
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeUtils
|
||||
import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, MapData}
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types._
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
object Literal {
|
||||
val TrueLiteral: Literal = Literal(true, BooleanType)
|
||||
|
@ -196,6 +197,47 @@ object Literal {
|
|||
case other =>
|
||||
throw new RuntimeException(s"no default for type $dataType")
|
||||
}
|
||||
|
||||
/**
 * Checks that `value` is an instance of the class this method expects for `dataType`,
 * and fails via `require` when it is not.
 *
 * A null value is accepted for every data type. Nested types are checked recursively:
 * every field of a struct is checked, but only the first element of an array (and hence
 * only the first key and first value of a map) is inspected.
 *
 * @param value    the candidate literal value
 * @param dataType the data type the value is declared to have
 */
private[expressions] def validateLiteralValue(value: Any, dataType: DataType): Unit = {
|
||||
// Returns true when `v` is acceptable for `dataType`; the parameter shadows the outer `dataType`.
def doValidate(v: Any, dataType: DataType): Boolean = dataType match {
|
||||
// null is accepted regardless of the declared type
case _ if v == null => true
|
||||
case BooleanType => v.isInstanceOf[Boolean]
|
||||
case ByteType => v.isInstanceOf[Byte]
|
||||
case ShortType => v.isInstanceOf[Short]
|
||||
// DateType values are expected to be Int, same as IntegerType
case IntegerType | DateType => v.isInstanceOf[Int]
|
||||
// TimestampType values are expected to be Long, same as LongType
case LongType | TimestampType => v.isInstanceOf[Long]
|
||||
case FloatType => v.isInstanceOf[Float]
|
||||
case DoubleType => v.isInstanceOf[Double]
|
||||
case _: DecimalType => v.isInstanceOf[Decimal]
|
||||
case CalendarIntervalType => v.isInstanceOf[CalendarInterval]
|
||||
case BinaryType => v.isInstanceOf[Array[Byte]]
|
||||
case StringType => v.isInstanceOf[UTF8String]
|
||||
case st: StructType =>
|
||||
v.isInstanceOf[InternalRow] && {
|
||||
val row = v.asInstanceOf[InternalRow]
|
||||
// every field is read at its ordinal and validated against the field's declared type
st.fields.map(_.dataType).zipWithIndex.forall {
|
||||
case (dt, i) => doValidate(row.get(i, dt), dt)
|
||||
}
|
||||
}
|
||||
case at: ArrayType =>
|
||||
v.isInstanceOf[ArrayData] && {
|
||||
val ar = v.asInstanceOf[ArrayData]
|
||||
// only element 0 is inspected; an empty array is always accepted
ar.numElements() == 0 || doValidate(ar.get(0, at.elementType), at.elementType)
|
||||
}
|
||||
case mt: MapType =>
|
||||
v.isInstanceOf[MapData] && {
|
||||
val map = v.asInstanceOf[MapData]
|
||||
// keys and values are validated as arrays, so only the first key and first value are inspected
doValidate(map.keyArray(), ArrayType(mt.keyType)) &&
|
||||
doValidate(map.valueArray(), ArrayType(mt.valueType))
|
||||
}
|
||||
case ObjectType(cls) => cls.isInstance(v)
|
||||
// a user-defined type is validated against its underlying sqlType
case udt: UserDefinedType[_] => doValidate(v, udt.sqlType)
|
||||
// any data type not listed above is rejected
case _ => false
|
||||
}
|
||||
// The message dereferences value.getClass; a null value always passes doValidate above.
require(doValidate(value, dataType),
|
||||
s"Literal must have a corresponding value to ${dataType.catalogString}, " +
|
||||
s"but class ${Utils.getSimpleName(value.getClass)} found.")
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -240,6 +282,8 @@ object DecimalLiteral {
|
|||
*/
|
||||
case class Literal (value: Any, dataType: DataType) extends LeafExpression {
|
||||
|
||||
Literal.validateLiteralValue(value, dataType)
|
||||
|
||||
override def foldable: Boolean = true
|
||||
override def nullable: Boolean = value == null
|
||||
|
||||
|
|
|
@ -742,7 +742,7 @@ class TypeCoercionSuite extends AnalysisTest {
|
|||
val nullLit = Literal.create(null, NullType)
|
||||
val floatNullLit = Literal.create(null, FloatType)
|
||||
val floatLit = Literal.create(1.0f, FloatType)
|
||||
val timestampLit = Literal.create("2017-04-12", TimestampType)
|
||||
val timestampLit = Literal.create(Timestamp.valueOf("2017-04-12 00:00:00"), TimestampType)
|
||||
val decimalLit = Literal(new java.math.BigDecimal("1000000000000000000000"))
|
||||
val tsArrayLit = Literal(Array(new Timestamp(System.currentTimeMillis())))
|
||||
val strArrayLit = Literal(Array("c"))
|
||||
|
@ -793,11 +793,11 @@ class TypeCoercionSuite extends AnalysisTest {
|
|||
ruleTest(TypeCoercion.FunctionArgumentConversion,
|
||||
CreateArray(Literal(1.0)
|
||||
:: Literal(1)
|
||||
:: Literal.create(1.0, FloatType)
|
||||
:: Literal.create(1.0f, FloatType)
|
||||
:: Nil),
|
||||
CreateArray(Literal(1.0)
|
||||
:: Cast(Literal(1), DoubleType)
|
||||
:: Cast(Literal.create(1.0, FloatType), DoubleType)
|
||||
:: Cast(Literal.create(1.0f, FloatType), DoubleType)
|
||||
:: Nil))
|
||||
|
||||
ruleTest(TypeCoercion.FunctionArgumentConversion,
|
||||
|
@ -834,23 +834,23 @@ class TypeCoercionSuite extends AnalysisTest {
|
|||
ruleTest(TypeCoercion.FunctionArgumentConversion,
|
||||
CreateMap(Literal(1)
|
||||
:: Literal("a")
|
||||
:: Literal.create(2.0, FloatType)
|
||||
:: Literal.create(2.0f, FloatType)
|
||||
:: Literal("b")
|
||||
:: Nil),
|
||||
CreateMap(Cast(Literal(1), FloatType)
|
||||
:: Literal("a")
|
||||
:: Literal.create(2.0, FloatType)
|
||||
:: Literal.create(2.0f, FloatType)
|
||||
:: Literal("b")
|
||||
:: Nil))
|
||||
ruleTest(TypeCoercion.FunctionArgumentConversion,
|
||||
CreateMap(Literal.create(null, DecimalType(5, 3))
|
||||
:: Literal("a")
|
||||
:: Literal.create(2.0, FloatType)
|
||||
:: Literal.create(2.0f, FloatType)
|
||||
:: Literal("b")
|
||||
:: Nil),
|
||||
CreateMap(Literal.create(null, DecimalType(5, 3)).cast(DoubleType)
|
||||
:: Literal("a")
|
||||
:: Literal.create(2.0, FloatType).cast(DoubleType)
|
||||
:: Literal.create(2.0f, FloatType).cast(DoubleType)
|
||||
:: Literal("b")
|
||||
:: Nil))
|
||||
// type coercion for map values
|
||||
|
@ -895,11 +895,11 @@ class TypeCoercionSuite extends AnalysisTest {
|
|||
ruleTest(TypeCoercion.FunctionArgumentConversion,
|
||||
operator(Literal(1.0)
|
||||
:: Literal(1)
|
||||
:: Literal.create(1.0, FloatType)
|
||||
:: Literal.create(1.0f, FloatType)
|
||||
:: Nil),
|
||||
operator(Literal(1.0)
|
||||
:: Cast(Literal(1), DoubleType)
|
||||
:: Cast(Literal.create(1.0, FloatType), DoubleType)
|
||||
:: Cast(Literal.create(1.0f, FloatType), DoubleType)
|
||||
:: Nil))
|
||||
ruleTest(TypeCoercion.FunctionArgumentConversion,
|
||||
operator(Literal(1L)
|
||||
|
@ -966,7 +966,7 @@ class TypeCoercionSuite extends AnalysisTest {
|
|||
val falseLit = Literal.create(false, BooleanType)
|
||||
val stringLit = Literal.create("c", StringType)
|
||||
val floatLit = Literal.create(1.0f, FloatType)
|
||||
val timestampLit = Literal.create("2017-04-12", TimestampType)
|
||||
val timestampLit = Literal.create(Timestamp.valueOf("2017-04-12 00:00:00"), TimestampType)
|
||||
val decimalLit = Literal(new java.math.BigDecimal("1000000000000000000000"))
|
||||
|
||||
ruleTest(rule,
|
||||
|
@ -1016,14 +1016,16 @@ class TypeCoercionSuite extends AnalysisTest {
|
|||
CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a")))
|
||||
)
|
||||
ruleTest(TypeCoercion.CaseWhenCoercion,
|
||||
CaseWhen(Seq((Literal(true), Literal(1.2))), Literal.create(1, DecimalType(7, 2))),
|
||||
CaseWhen(Seq((Literal(true), Literal(1.2))),
|
||||
Cast(Literal.create(1, DecimalType(7, 2)), DoubleType))
|
||||
Literal.create(BigDecimal.valueOf(1), DecimalType(7, 2))),
|
||||
CaseWhen(Seq((Literal(true), Literal(1.2))),
|
||||
Cast(Literal.create(BigDecimal.valueOf(1), DecimalType(7, 2)), DoubleType))
|
||||
)
|
||||
ruleTest(TypeCoercion.CaseWhenCoercion,
|
||||
CaseWhen(Seq((Literal(true), Literal(100L))), Literal.create(1, DecimalType(7, 2))),
|
||||
CaseWhen(Seq((Literal(true), Literal(100L))),
|
||||
Literal.create(BigDecimal.valueOf(1), DecimalType(7, 2))),
|
||||
CaseWhen(Seq((Literal(true), Cast(Literal(100L), DecimalType(22, 2)))),
|
||||
Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2)))
|
||||
Cast(Literal.create(BigDecimal.valueOf(1), DecimalType(7, 2)), DecimalType(22, 2)))
|
||||
)
|
||||
}
|
||||
|
||||
|
|
|
@ -623,7 +623,8 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
|
|||
|
||||
test("SPARK-21513: to_json support map[string, struct] to json") {
|
||||
val schema = MapType(StringType, StructType(StructField("a", IntegerType) :: Nil))
|
||||
val input = Literal.create(ArrayBasedMapData(Map("test" -> InternalRow(1))), schema)
|
||||
val input = Literal(
|
||||
ArrayBasedMapData(Map(UTF8String.fromString("test") -> InternalRow(1))), schema)
|
||||
checkEvaluation(
|
||||
StructsToJson(Map.empty, input),
|
||||
"""{"test":{"a":1}}"""
|
||||
|
@ -633,7 +634,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
|
|||
test("SPARK-21513: to_json support map[struct, struct] to json") {
|
||||
val schema = MapType(StructType(StructField("a", IntegerType) :: Nil),
|
||||
StructType(StructField("b", IntegerType) :: Nil))
|
||||
val input = Literal.create(ArrayBasedMapData(Map(InternalRow(1) -> InternalRow(2))), schema)
|
||||
val input = Literal(ArrayBasedMapData(Map(InternalRow(1) -> InternalRow(2))), schema)
|
||||
checkEvaluation(
|
||||
StructsToJson(Map.empty, input),
|
||||
"""{"[1]":{"b":2}}"""
|
||||
|
@ -642,7 +643,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
|
|||
|
||||
test("SPARK-21513: to_json support map[string, integer] to json") {
|
||||
val schema = MapType(StringType, IntegerType)
|
||||
val input = Literal.create(ArrayBasedMapData(Map("a" -> 1)), schema)
|
||||
val input = Literal(ArrayBasedMapData(Map(UTF8String.fromString("a") -> 1)), schema)
|
||||
checkEvaluation(
|
||||
StructsToJson(Map.empty, input),
|
||||
"""{"a":1}"""
|
||||
|
@ -651,17 +652,18 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
|
|||
|
||||
test("to_json - array with maps") {
|
||||
val inputSchema = ArrayType(MapType(StringType, IntegerType))
|
||||
val input = new GenericArrayData(ArrayBasedMapData(
|
||||
Map("a" -> 1)) :: ArrayBasedMapData(Map("b" -> 2)) :: Nil)
|
||||
val input = new GenericArrayData(
|
||||
ArrayBasedMapData(Map(UTF8String.fromString("a") -> 1)) ::
|
||||
ArrayBasedMapData(Map(UTF8String.fromString("b") -> 2)) :: Nil)
|
||||
val output = """[{"a":1},{"b":2}]"""
|
||||
checkEvaluation(
|
||||
StructsToJson(Map.empty, Literal.create(input, inputSchema), gmtId),
|
||||
StructsToJson(Map.empty, Literal(input, inputSchema), gmtId),
|
||||
output)
|
||||
}
|
||||
|
||||
test("to_json - array with single map") {
|
||||
val inputSchema = ArrayType(MapType(StringType, IntegerType))
|
||||
val input = new GenericArrayData(ArrayBasedMapData(Map("a" -> 1)) :: Nil)
|
||||
val input = new GenericArrayData(ArrayBasedMapData(Map(UTF8String.fromString("a") -> 1)) :: Nil)
|
||||
val output = """[{"a":1}]"""
|
||||
checkEvaluation(
|
||||
StructsToJson(Map.empty, Literal.create(input, inputSchema), gmtId),
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
|
||||
package org.apache.spark.sql.catalyst.expressions
|
||||
|
||||
import java.sql.Timestamp
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
|
||||
|
@ -107,8 +109,8 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
val nullLit = Literal.create(null, NullType)
|
||||
val floatNullLit = Literal.create(null, FloatType)
|
||||
val floatLit = Literal.create(1.01f, FloatType)
|
||||
val timestampLit = Literal.create("2017-04-12", TimestampType)
|
||||
val decimalLit = Literal.create(10.2, DecimalType(20, 2))
|
||||
val timestampLit = Literal.create(Timestamp.valueOf("2017-04-12 00:00:00"), TimestampType)
|
||||
val decimalLit = Literal.create(BigDecimal.valueOf(10.2), DecimalType(20, 2))
|
||||
|
||||
assert(analyze(new Nvl(decimalLit, stringLit)).dataType == StringType)
|
||||
assert(analyze(new Nvl(doubleLit, decimalLit)).dataType == DoubleType)
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
|
||||
package org.apache.spark.sql.catalyst.expressions
|
||||
|
||||
import java.sql.{Date, Timestamp}
|
||||
import java.sql.Timestamp
|
||||
import java.util.TimeZone
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
|
@ -32,9 +32,9 @@ class SortOrderExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
|
|||
val b2 = Literal.create(true, BooleanType)
|
||||
val i1 = Literal.create(20132983, IntegerType)
|
||||
val i2 = Literal.create(-20132983, IntegerType)
|
||||
val l1 = Literal.create(20132983, LongType)
|
||||
val l2 = Literal.create(-20132983, LongType)
|
||||
val millis = 1524954911000L;
|
||||
val l1 = Literal.create(20132983L, LongType)
|
||||
val l2 = Literal.create(-20132983L, LongType)
|
||||
val millis = 1524954911000L
|
||||
// Explicitly choose a time zone, since Date objects can create different values depending on
|
||||
// local time zone of the machine on which the test is running
|
||||
val oldDefaultTZ = TimeZone.getDefault
|
||||
|
@ -57,7 +57,7 @@ class SortOrderExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
|
|||
val dec1 = Literal(Decimal(20132983L, 10, 2))
|
||||
val dec2 = Literal(Decimal(20132983L, 19, 2))
|
||||
val dec3 = Literal(Decimal(20132983L, 21, 2))
|
||||
val list1 = Literal(List(1, 2), ArrayType(IntegerType))
|
||||
val list1 = Literal.create(Seq(1, 2), ArrayType(IntegerType))
|
||||
val nullVal = Literal.create(null, IntegerType)
|
||||
|
||||
checkEvaluation(SortPrefix(SortOrder(b1, Ascending)), 0L)
|
||||
|
|
|
@ -105,9 +105,9 @@ class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with Priva
|
|||
}
|
||||
|
||||
test("parse sql expression for duration in microseconds - long") {
|
||||
val dur = TimeWindow.invokePrivate(parseExpression(Literal.create(2 << 52, LongType)))
|
||||
val dur = TimeWindow.invokePrivate(parseExpression(Literal.create(2L << 52, LongType)))
|
||||
assert(dur.isInstanceOf[Long])
|
||||
assert(dur === (2 << 52))
|
||||
assert(dur === (2L << 52))
|
||||
}
|
||||
|
||||
test("parse sql expression for duration in microseconds - invalid interval") {
|
||||
|
|
|
@ -232,11 +232,14 @@ class PercentileSuite extends SparkFunSuite {
|
|||
BooleanType, StringType, DateType, TimestampType, CalendarIntervalType, NullType)
|
||||
|
||||
invalidDataTypes.foreach { dataType =>
|
||||
val percentage = Literal(0.5, dataType)
|
||||
val percentage = Literal.default(dataType)
|
||||
val percentile4 = new Percentile(child, percentage)
|
||||
assertEqual(percentile4.checkInputDataTypes(),
|
||||
TypeCheckFailure(s"argument 2 requires double type, however, " +
|
||||
s"'0.5' is of ${dataType.simpleString} type."))
|
||||
val checkResult = percentile4.checkInputDataTypes()
|
||||
assert(checkResult.isFailure)
|
||||
Seq("argument 2 requires double type, however, ",
|
||||
s"is of ${dataType.simpleString} type.").foreach { errMsg =>
|
||||
assert(checkResult.asInstanceOf[TypeCheckFailure].message.contains(errMsg))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -210,13 +210,13 @@ case class AnalyzeColumnCommand(
|
|||
def struct(exprs: Expression*): CreateNamedStruct = CreateStruct(exprs.map { expr =>
|
||||
expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() }
|
||||
})
|
||||
val one = Literal(1, LongType)
|
||||
val one = Literal(1L, LongType)
|
||||
|
||||
// the approximate ndv (num distinct value) should never be larger than the number of rows
|
||||
val numNonNulls = if (col.nullable) Count(col) else Count(one)
|
||||
val ndv = Least(Seq(HyperLogLogPlusPlus(col, conf.ndvMaxError), numNonNulls))
|
||||
val numNulls = Subtract(Count(one), numNonNulls)
|
||||
val defaultSize = Literal(col.dataType.defaultSize, LongType)
|
||||
val defaultSize = Literal(col.dataType.defaultSize.toLong, LongType)
|
||||
val nullArray = Literal(null, ArrayType(LongType))
|
||||
|
||||
def fixedLenTypeStruct: CreateNamedStruct = {
|
||||
|
|
|
@ -228,7 +228,7 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
|
|||
test("join key rewritten") {
|
||||
val l = Literal(1L)
|
||||
val i = Literal(2)
|
||||
val s = Literal.create(3, ShortType)
|
||||
val s = Literal.create(3.toShort, ShortType)
|
||||
val ss = Literal("hello")
|
||||
|
||||
assert(HashJoin.rewriteKeyExpr(l :: Nil) === l :: Nil)
|
||||
|
|
Loading…
Reference in a new issue