diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 0efd1224f1..2bcbb92f1a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -128,30 +128,36 @@ object Literal { val dataType = DataType.parseDataType(json \ "dataType") json \ "value" match { case JNull => Literal.create(null, dataType) - case JString(str) => - val value = dataType match { - case BooleanType => str.toBoolean - case ByteType => str.toByte - case ShortType => str.toShort - case IntegerType => str.toInt - case LongType => str.toLong - case FloatType => str.toFloat - case DoubleType => str.toDouble - case StringType => UTF8String.fromString(str) - case DateType => java.sql.Date.valueOf(str) - case TimestampType => java.sql.Timestamp.valueOf(str) - case CalendarIntervalType => CalendarInterval.fromString(str) - case t: DecimalType => - val d = Decimal(str) - assert(d.changePrecision(t.precision, t.scale)) - d - case _ => null - } - Literal.create(value, dataType) + case JString(str) => fromString(str, dataType) case other => sys.error(s"$other is not a valid Literal json value") } } + /** + * Constructs a Literal from a String + */ + def fromString(str: String, dataType: DataType): Literal = { + val value = dataType match { + case BooleanType => str.toBoolean + case ByteType => str.toByte + case ShortType => str.toShort + case IntegerType => str.toInt + case LongType => str.toLong + case FloatType => str.toFloat + case DoubleType => str.toDouble + case StringType => UTF8String.fromString(str) + case DateType => java.sql.Date.valueOf(str) + case TimestampType => java.sql.Timestamp.valueOf(str) + case CalendarIntervalType => CalendarInterval.fromString(str) + case t: DecimalType => + val d = Decimal(str) + assert(d.changePrecision(t.precision, t.scale)) + d + case _ => null + } + Literal.create(value, dataType) + } + def create(v: Any, dataType: DataType): Literal = { Literal(CatalystTypeConverters.convertToCatalyst(v), dataType) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala index 86f80fe66d..3ea6bfac9d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -226,4 +226,25 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal('\u0000'), "\u0000") checkEvaluation(Literal.create('\n'), "\n") } + + test("fromString converts String/DataType input correctly") { + checkEvaluation(Literal.fromString(false.toString, BooleanType), false) + checkEvaluation(Literal.fromString(null, NullType), null) + checkEvaluation(Literal.fromString(Int.MaxValue.toByte.toString, ByteType), Int.MaxValue.toByte) + checkEvaluation(Literal.fromString(Short.MaxValue.toShort.toString, ShortType), Short.MaxValue + .toShort) + checkEvaluation(Literal.fromString(Int.MaxValue.toString, IntegerType), Int.MaxValue) + checkEvaluation(Literal.fromString(Long.MaxValue.toString, LongType), Long.MaxValue) + checkEvaluation(Literal.fromString(Float.MaxValue.toString, FloatType), Float.MaxValue) + checkEvaluation(Literal.fromString(Double.MaxValue.toString, DoubleType), Double.MaxValue) + checkEvaluation(Literal.fromString("1.23456", DecimalType(10, 5)), Decimal(1.23456)) + checkEvaluation(Literal.fromString("Databricks", StringType), "Databricks") + val dateString = "1970-01-01" + checkEvaluation(Literal.fromString(dateString, DateType), java.sql.Date.valueOf(dateString)) + val timestampString = "0000-01-01 00:00:00" + checkEvaluation(Literal.fromString(timestampString, TimestampType), + java.sql.Timestamp.valueOf(timestampString)) + val calInterval = new CalendarInterval(1, 1) + checkEvaluation(Literal.fromString(calInterval.toString, CalendarIntervalType), calInterval) + } }