[SPARK-35766][SQL][TESTS] Break down CastSuite/AnsiCastSuite into multiple files

### What changes were proposed in this pull request?

Currently, the file CastSuite.scala becomes big: 2000 lines, 2 base classes, 4 test suites.
In my previous work of Timestamp without time zone, I planned to put new test cases in CastSuiteBase, but they were accidentally added in AnsiCastSuiteBase.

This PR is to break the file down into 3 files. It also moves the test cases about timestamp without time zone to the right base class.

### Why are the changes needed?

Make development and review easier.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Unit tests

Closes #32918 from gengliangwang/refactorCastSuite.

Authored-by: Gengliang Wang <gengliang@apache.org>
Signed-off-by: Gengliang Wang <gengliang@apache.org>
This commit is contained in:
Gengliang Wang 2021-06-16 00:17:04 +08:00
parent b74260f67f
commit c382d4009b
3 changed files with 1412 additions and 1356 deletions

View file

@ -0,0 +1,481 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions
import java.time.DateTimeException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
abstract class AnsiCastSuiteBase extends CastSuiteBase {
private def testIntMaxAndMin(dt: DataType): Unit = {
assert(Seq(IntegerType, ShortType, ByteType).contains(dt))
Seq(Int.MaxValue + 1L, Int.MinValue - 1L).foreach { value =>
checkExceptionInExpression[ArithmeticException](cast(value, dt), "overflow")
checkExceptionInExpression[ArithmeticException](cast(Decimal(value.toString), dt), "overflow")
checkExceptionInExpression[ArithmeticException](
cast(Literal(value * 1.5f, FloatType), dt), "overflow")
checkExceptionInExpression[ArithmeticException](
cast(Literal(value * 1.0, DoubleType), dt), "overflow")
}
}
private def testLongMaxAndMin(dt: DataType): Unit = {
assert(Seq(LongType, IntegerType).contains(dt))
Seq(Decimal(Long.MaxValue) + Decimal(1), Decimal(Long.MinValue) - Decimal(1)).foreach { value =>
checkExceptionInExpression[ArithmeticException](
cast(value, dt), "overflow")
checkExceptionInExpression[ArithmeticException](
cast((value * Decimal(1.1)).toFloat, dt), "overflow")
checkExceptionInExpression[ArithmeticException](
cast((value * Decimal(1.1)).toDouble, dt), "overflow")
}
}
test("ANSI mode: Throw exception on casting out-of-range value to byte type") {
testIntMaxAndMin(ByteType)
Seq(Byte.MaxValue + 1, Byte.MinValue - 1).foreach { value =>
checkExceptionInExpression[ArithmeticException](cast(value, ByteType), "overflow")
checkExceptionInExpression[ArithmeticException](
cast(Literal(value.toFloat, FloatType), ByteType), "overflow")
checkExceptionInExpression[ArithmeticException](
cast(Literal(value.toDouble, DoubleType), ByteType), "overflow")
}
Seq(Byte.MaxValue, 0.toByte, Byte.MinValue).foreach { value =>
checkEvaluation(cast(value, ByteType), value)
checkEvaluation(cast(value.toString, ByteType), value)
checkEvaluation(cast(Decimal(value.toString), ByteType), value)
checkEvaluation(cast(Literal(value.toFloat, FloatType), ByteType), value)
checkEvaluation(cast(Literal(value.toDouble, DoubleType), ByteType), value)
}
}
test("ANSI mode: Throw exception on casting out-of-range value to short type") {
testIntMaxAndMin(ShortType)
Seq(Short.MaxValue + 1, Short.MinValue - 1).foreach { value =>
checkExceptionInExpression[ArithmeticException](cast(value, ShortType), "overflow")
checkExceptionInExpression[ArithmeticException](
cast(Literal(value.toFloat, FloatType), ShortType), "overflow")
checkExceptionInExpression[ArithmeticException](
cast(Literal(value.toDouble, DoubleType), ShortType), "overflow")
}
Seq(Short.MaxValue, 0.toShort, Short.MinValue).foreach { value =>
checkEvaluation(cast(value, ShortType), value)
checkEvaluation(cast(value.toString, ShortType), value)
checkEvaluation(cast(Decimal(value.toString), ShortType), value)
checkEvaluation(cast(Literal(value.toFloat, FloatType), ShortType), value)
checkEvaluation(cast(Literal(value.toDouble, DoubleType), ShortType), value)
}
}
test("ANSI mode: Throw exception on casting out-of-range value to int type") {
testIntMaxAndMin(IntegerType)
testLongMaxAndMin(IntegerType)
Seq(Int.MaxValue, 0, Int.MinValue).foreach { value =>
checkEvaluation(cast(value, IntegerType), value)
checkEvaluation(cast(value.toString, IntegerType), value)
checkEvaluation(cast(Decimal(value.toString), IntegerType), value)
checkEvaluation(cast(Literal(value * 1.0, DoubleType), IntegerType), value)
}
checkEvaluation(cast(Int.MaxValue + 0.9D, IntegerType), Int.MaxValue)
checkEvaluation(cast(Int.MinValue - 0.9D, IntegerType), Int.MinValue)
}
test("ANSI mode: Throw exception on casting out-of-range value to long type") {
testLongMaxAndMin(LongType)
Seq(Long.MaxValue, 0, Long.MinValue).foreach { value =>
checkEvaluation(cast(value, LongType), value)
checkEvaluation(cast(value.toString, LongType), value)
checkEvaluation(cast(Decimal(value.toString), LongType), value)
}
checkEvaluation(cast(Long.MaxValue + 0.9F, LongType), Long.MaxValue)
checkEvaluation(cast(Long.MinValue - 0.9F, LongType), Long.MinValue)
checkEvaluation(cast(Long.MaxValue + 0.9D, LongType), Long.MaxValue)
checkEvaluation(cast(Long.MinValue - 0.9D, LongType), Long.MinValue)
}
test("ANSI mode: Throw exception on casting out-of-range value to decimal type") {
checkExceptionInExpression[ArithmeticException](
cast(Literal("134.12"), DecimalType(3, 2)), "cannot be represented")
checkExceptionInExpression[ArithmeticException](
cast(Literal(BigDecimal(134.12)), DecimalType(3, 2)), "cannot be represented")
checkExceptionInExpression[ArithmeticException](
cast(Literal(134.12), DecimalType(3, 2)), "cannot be represented")
}
test("ANSI mode: disallow type conversions between Numeric types and Timestamp type") {
import DataTypeTestUtils.numericTypes
checkInvalidCastFromNumericType(TimestampType)
var errorMsg =
"you can use functions TIMESTAMP_SECONDS/TIMESTAMP_MILLIS/TIMESTAMP_MICROS instead"
verifyCastFailure(cast(Literal(0L), TimestampType), Some(errorMsg))
val timestampLiteral = Literal(1L, TimestampType)
errorMsg = "you can use functions UNIX_SECONDS/UNIX_MILLIS/UNIX_MICROS instead."
numericTypes.foreach { numericType =>
verifyCastFailure(cast(timestampLiteral, numericType), Some(errorMsg))
}
}
test("ANSI mode: disallow type conversions between Numeric types and Date type") {
import DataTypeTestUtils.numericTypes
checkInvalidCastFromNumericType(DateType)
var errorMsg = "you can use function DATE_FROM_UNIX_DATE instead"
verifyCastFailure(cast(Literal(0L), DateType), Some(errorMsg))
val dateLiteral = Literal(1, DateType)
errorMsg = "you can use function UNIX_DATE instead"
numericTypes.foreach { numericType =>
verifyCastFailure(cast(dateLiteral, numericType), Some(errorMsg))
}
}
test("ANSI mode: disallow type conversions between Numeric types and Binary type") {
import DataTypeTestUtils.numericTypes
checkInvalidCastFromNumericType(BinaryType)
val binaryLiteral = Literal(new Array[Byte](1.toByte), BinaryType)
numericTypes.foreach { numericType =>
assert(cast(binaryLiteral, numericType).checkInputDataTypes().isFailure)
}
}
test("ANSI mode: disallow type conversions between Datatime types and Boolean types") {
val timestampLiteral = Literal(1L, TimestampType)
assert(cast(timestampLiteral, BooleanType).checkInputDataTypes().isFailure)
val dateLiteral = Literal(1, DateType)
assert(cast(dateLiteral, BooleanType).checkInputDataTypes().isFailure)
val booleanLiteral = Literal(true, BooleanType)
assert(cast(booleanLiteral, TimestampType).checkInputDataTypes().isFailure)
assert(cast(booleanLiteral, DateType).checkInputDataTypes().isFailure)
}
test("cast from invalid string to numeric should throw NumberFormatException") {
// cast to IntegerType
Seq(IntegerType, ShortType, ByteType, LongType).foreach { dataType =>
checkExceptionInExpression[NumberFormatException](
cast("string", dataType), "invalid input syntax for type numeric: string")
checkExceptionInExpression[NumberFormatException](
cast("123-string", dataType), "invalid input syntax for type numeric: 123-string")
checkExceptionInExpression[NumberFormatException](
cast("2020-07-19", dataType), "invalid input syntax for type numeric: 2020-07-19")
checkExceptionInExpression[NumberFormatException](
cast("1.23", dataType), "invalid input syntax for type numeric: 1.23")
}
Seq(DoubleType, FloatType, DecimalType.USER_DEFAULT).foreach { dataType =>
checkExceptionInExpression[NumberFormatException](
cast("string", dataType), "invalid input syntax for type numeric: string")
checkExceptionInExpression[NumberFormatException](
cast("123.000.00", dataType), "invalid input syntax for type numeric: 123.000.00")
checkExceptionInExpression[NumberFormatException](
cast("abc.com", dataType), "invalid input syntax for type numeric: abc.com")
}
}
protected def checkCastToNumericError(l: Literal, to: DataType, tryCastResult: Any): Unit = {
checkExceptionInExpression[NumberFormatException](
cast(l, to), "invalid input syntax for type numeric: true")
}
test("cast from invalid string array to numeric array should throw NumberFormatException") {
val array = Literal.create(Seq("123", "true", "f", null),
ArrayType(StringType, containsNull = true))
checkCastToNumericError(array, ArrayType(ByteType, containsNull = true),
Seq(123.toByte, null, null, null))
checkCastToNumericError(array, ArrayType(ShortType, containsNull = true),
Seq(123.toShort, null, null, null))
checkCastToNumericError(array, ArrayType(IntegerType, containsNull = true),
Seq(123, null, null, null))
checkCastToNumericError(array, ArrayType(LongType, containsNull = true),
Seq(123L, null, null, null))
}
test("Fast fail for cast string type to decimal type in ansi mode") {
checkEvaluation(cast("12345678901234567890123456789012345678", DecimalType(38, 0)),
Decimal("12345678901234567890123456789012345678"))
checkExceptionInExpression[ArithmeticException](
cast("123456789012345678901234567890123456789", DecimalType(38, 0)),
"out of decimal type range")
checkExceptionInExpression[ArithmeticException](
cast("12345678901234567890123456789012345678", DecimalType(38, 1)),
"cannot be represented as Decimal(38, 1)")
checkEvaluation(cast("0.00000000000000000000000000000000000001", DecimalType(38, 0)),
Decimal("0"))
checkEvaluation(cast("0.00000000000000000000000000000000000000000001", DecimalType(38, 0)),
Decimal("0"))
checkEvaluation(cast("0.00000000000000000000000000000000000001", DecimalType(38, 18)),
Decimal("0E-18"))
checkEvaluation(cast("6E-120", DecimalType(38, 0)),
Decimal("0"))
checkEvaluation(cast("6E+37", DecimalType(38, 0)),
Decimal("60000000000000000000000000000000000000"))
checkExceptionInExpression[ArithmeticException](
cast("6E+38", DecimalType(38, 0)),
"out of decimal type range")
checkExceptionInExpression[ArithmeticException](
cast("6E+37", DecimalType(38, 1)),
"cannot be represented as Decimal(38, 1)")
checkExceptionInExpression[NumberFormatException](
cast("abcd", DecimalType(38, 1)),
"invalid input syntax for type numeric")
}
protected def checkCastToBooleanError(l: Literal, to: DataType, tryCastResult: Any): Unit = {
checkExceptionInExpression[UnsupportedOperationException](
cast(l, to), s"invalid input syntax for type boolean")
}
test("ANSI mode: cast string to boolean with parse error") {
checkCastToBooleanError(Literal("abc"), BooleanType, null)
checkCastToBooleanError(Literal(""), BooleanType, null)
}
test("cast from array II") {
val array = Literal.create(Seq("123", "true", "f", null),
ArrayType(StringType, containsNull = true))
val array_notNull = Literal.create(Seq("123", "true", "f"),
ArrayType(StringType, containsNull = false))
{
val to: DataType = ArrayType(BooleanType, containsNull = true)
val ret = cast(array, to)
assert(ret.resolved)
checkCastToBooleanError(array, to, Seq(null, true, false, null))
}
{
val to: DataType = ArrayType(BooleanType, containsNull = true)
val ret = cast(array_notNull, to)
assert(ret.resolved)
checkCastToBooleanError(array_notNull, to, Seq(null, true, false))
}
}
test("cast from map II") {
val map = Literal.create(
Map("a" -> "123", "b" -> "true", "c" -> "f", "d" -> null),
MapType(StringType, StringType, valueContainsNull = true))
val map_notNull = Literal.create(
Map("a" -> "123", "b" -> "true", "c" -> "f"),
MapType(StringType, StringType, valueContainsNull = false))
checkNullCast(MapType(StringType, IntegerType), MapType(StringType, StringType))
{
val to: DataType = MapType(StringType, BooleanType, valueContainsNull = true)
val ret = cast(map, to)
assert(ret.resolved)
checkCastToBooleanError(map, to, Map("a" -> null, "b" -> true, "c" -> false, "d" -> null))
}
{
val to: DataType = MapType(StringType, BooleanType, valueContainsNull = true)
val ret = cast(map_notNull, to)
assert(ret.resolved)
checkCastToBooleanError(map_notNull, to, Map("a" -> null, "b" -> true, "c" -> false))
}
}
test("cast from struct II") {
checkNullCast(
StructType(Seq(
StructField("a", StringType),
StructField("b", IntegerType))),
StructType(Seq(
StructField("a", StringType),
StructField("b", StringType))))
val struct = Literal.create(
InternalRow(
UTF8String.fromString("123"),
UTF8String.fromString("true"),
UTF8String.fromString("f"),
null),
StructType(Seq(
StructField("a", StringType, nullable = true),
StructField("b", StringType, nullable = true),
StructField("c", StringType, nullable = true),
StructField("d", StringType, nullable = true))))
val struct_notNull = Literal.create(
InternalRow(
UTF8String.fromString("123"),
UTF8String.fromString("true"),
UTF8String.fromString("f")),
StructType(Seq(
StructField("a", StringType, nullable = false),
StructField("b", StringType, nullable = false),
StructField("c", StringType, nullable = false))))
{
val to: DataType = StructType(Seq(
StructField("a", BooleanType, nullable = true),
StructField("b", BooleanType, nullable = true),
StructField("c", BooleanType, nullable = true),
StructField("d", BooleanType, nullable = true)))
val ret = cast(struct, to)
assert(ret.resolved)
checkCastToBooleanError(struct, to, InternalRow(null, true, false, null))
}
{
val to: DataType = StructType(Seq(
StructField("a", BooleanType, nullable = true),
StructField("b", BooleanType, nullable = true),
StructField("c", BooleanType, nullable = true)))
val ret = cast(struct_notNull, to)
assert(ret.resolved)
checkCastToBooleanError(struct_notNull, to, InternalRow(null, true, false))
}
}
test("ANSI mode: cast string to timestamp with parse error") {
DateTimeTestUtils.outstandingZoneIds.foreach { zid =>
def checkCastWithParseError(str: String): Unit = {
checkExceptionInExpression[DateTimeException](
cast(Literal(str), TimestampType, Option(zid.getId)),
s"Cannot cast $str to TimestampType.")
}
checkCastWithParseError("123")
checkCastWithParseError("2015-03-18 123142")
checkCastWithParseError("2015-03-18T123123")
checkCastWithParseError("2015-03-18X")
checkCastWithParseError("2015/03/18")
checkCastWithParseError("2015.03.18")
checkCastWithParseError("20150318")
checkCastWithParseError("2015-031-8")
checkCastWithParseError("2015-03-18T12:03:17-0:70")
checkCastWithParseError("abdef")
}
}
test("ANSI mode: cast string to date with parse error") {
DateTimeTestUtils.outstandingZoneIds.foreach { zid =>
def checkCastWithParseError(str: String): Unit = {
checkExceptionInExpression[DateTimeException](
cast(Literal(str), DateType, Option(zid.getId)),
s"Cannot cast $str to DateType.")
}
checkCastWithParseError("12345")
checkCastWithParseError("12345-12-18")
checkCastWithParseError("2015-13-18")
checkCastWithParseError("2015-03-128")
checkCastWithParseError("2015/03/18")
checkCastWithParseError("2015.03.18")
checkCastWithParseError("20150318")
checkCastWithParseError("2015-031-8")
checkCastWithParseError("2015-03-18ABC")
checkCastWithParseError("abdef")
}
}
test("SPARK-26218: Fix the corner case of codegen when casting float to Integer") {
checkExceptionInExpression[ArithmeticException](
cast(cast(Literal("2147483648"), FloatType), IntegerType), "overflow")
}
}
/**
* Test suite for data type casting expression [[Cast]] with ANSI mode disabled.
*/
class CastSuiteWithAnsiModeOn extends AnsiCastSuiteBase {
override def beforeAll(): Unit = {
super.beforeAll()
SQLConf.get.setConf(SQLConf.ANSI_ENABLED, true)
}
override def afterAll(): Unit = {
super.afterAll()
SQLConf.get.unsetConf(SQLConf.ANSI_ENABLED)
}
override def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): CastBase = {
v match {
case lit: Expression => Cast(lit, targetType, timeZoneId)
case _ => Cast(Literal(v), targetType, timeZoneId)
}
}
override def setConfigurationHint: String =
s"set ${SQLConf.ANSI_ENABLED.key} as false"
}
/**
* Test suite for data type casting expression [[AnsiCast]] with ANSI mode enabled.
*/
class AnsiCastSuiteWithAnsiModeOn extends AnsiCastSuiteBase {
override def beforeAll(): Unit = {
super.beforeAll()
SQLConf.get.setConf(SQLConf.ANSI_ENABLED, true)
}
override def afterAll(): Unit = {
super.afterAll()
SQLConf.get.unsetConf(SQLConf.ANSI_ENABLED)
}
override def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): CastBase = {
v match {
case lit: Expression => AnsiCast(lit, targetType, timeZoneId)
case _ => AnsiCast(Literal(v), targetType, timeZoneId)
}
}
override def setConfigurationHint: String =
s"set ${SQLConf.STORE_ASSIGNMENT_POLICY.key} as" +
s" ${SQLConf.StoreAssignmentPolicy.LEGACY.toString}"
}
/**
* Test suite for data type casting expression [[AnsiCast]] with ANSI mode disabled.
*/
class AnsiCastSuiteWithAnsiModeOff extends AnsiCastSuiteBase {
override def beforeAll(): Unit = {
super.beforeAll()
SQLConf.get.setConf(SQLConf.ANSI_ENABLED, false)
}
override def afterAll(): Unit = {
super.afterAll()
SQLConf.get.unsetConf(SQLConf.ANSI_ENABLED)
}
override def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): CastBase = {
v match {
case lit: Expression => AnsiCast(lit, targetType, timeZoneId)
case _ => AnsiCast(Literal(v), targetType, timeZoneId)
}
}
override def setConfigurationHint: String =
s"set ${SQLConf.STORE_ASSIGNMENT_POLICY.key} as" +
s" ${SQLConf.StoreAssignmentPolicy.LEGACY.toString}"
}

View file

@ -0,0 +1,930 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions
import java.sql.{Date, Timestamp}
import java.time.{Duration, LocalDate, LocalDateTime, Period}
import java.time.temporal.ChronoUnit
import java.util.{Calendar, TimeZone}
import scala.collection.parallel.immutable.ParVector
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
import org.apache.spark.sql.catalyst.analysis.TypeCoercion.numericPrecedence
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
import org.apache.spark.sql.catalyst.util.IntervalUtils.microsToDuration
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.DataTypeTestUtils.dayTimeIntervalTypes
import org.apache.spark.unsafe.types.UTF8String
abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
protected def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): CastBase
// expected cannot be null
protected def checkCast(v: Any, expected: Any): Unit = {
checkEvaluation(cast(v, Literal(expected).dataType), expected)
}
protected def checkNullCast(from: DataType, to: DataType): Unit = {
checkEvaluation(cast(Literal.create(null, from), to, UTC_OPT), null)
}
protected def verifyCastFailure(c: CastBase, optionalExpectedMsg: Option[String] = None): Unit = {
val typeCheckResult = c.checkInputDataTypes()
assert(typeCheckResult.isFailure)
assert(typeCheckResult.isInstanceOf[TypeCheckFailure])
val message = typeCheckResult.asInstanceOf[TypeCheckFailure].message
if (optionalExpectedMsg.isDefined) {
assert(message.contains(optionalExpectedMsg.get))
} else if (setConfigurationHint.nonEmpty) {
assert(message.contains("with ANSI mode on"))
assert(message.contains(setConfigurationHint))
} else {
assert("cannot cast [a-zA-Z]+ to [a-zA-Z]+".r.findFirstIn(message).isDefined)
}
}
protected def isAlwaysNullable: Boolean = false
protected def setConfigurationHint: String = ""
test("null cast") {
import DataTypeTestUtils._
atomicTypes.zip(atomicTypes).foreach { case (from, to) =>
checkNullCast(from, to)
}
atomicTypes.foreach(dt => checkNullCast(NullType, dt))
atomicTypes.foreach(dt => checkNullCast(dt, StringType))
checkNullCast(StringType, BinaryType)
checkNullCast(StringType, BooleanType)
numericTypes.foreach(dt => checkNullCast(dt, BooleanType))
checkNullCast(StringType, TimestampType)
checkNullCast(DateType, TimestampType)
checkNullCast(StringType, DateType)
checkNullCast(TimestampType, DateType)
checkNullCast(StringType, CalendarIntervalType)
numericTypes.foreach(dt => checkNullCast(StringType, dt))
numericTypes.foreach(dt => checkNullCast(BooleanType, dt))
for (from <- numericTypes; to <- numericTypes) checkNullCast(from, to)
}
test("cast string to date") {
var c = Calendar.getInstance()
c.set(2015, 0, 1, 0, 0, 0)
c.set(Calendar.MILLISECOND, 0)
checkEvaluation(Cast(Literal("2015"), DateType), new Date(c.getTimeInMillis))
c = Calendar.getInstance()
c.set(2015, 2, 1, 0, 0, 0)
c.set(Calendar.MILLISECOND, 0)
checkEvaluation(Cast(Literal("2015-03"), DateType), new Date(c.getTimeInMillis))
c = Calendar.getInstance()
c.set(2015, 2, 18, 0, 0, 0)
c.set(Calendar.MILLISECOND, 0)
checkEvaluation(Cast(Literal("2015-03-18"), DateType), new Date(c.getTimeInMillis))
checkEvaluation(Cast(Literal("2015-03-18 "), DateType), new Date(c.getTimeInMillis))
checkEvaluation(Cast(Literal("2015-03-18 123142"), DateType), new Date(c.getTimeInMillis))
checkEvaluation(Cast(Literal("2015-03-18T123123"), DateType), new Date(c.getTimeInMillis))
checkEvaluation(Cast(Literal("2015-03-18T"), DateType), new Date(c.getTimeInMillis))
}
test("cast string to timestamp") {
new ParVector(ALL_TIMEZONES.toVector).foreach { zid =>
def checkCastStringToTimestamp(str: String, expected: Timestamp): Unit = {
checkEvaluation(cast(Literal(str), TimestampType, Option(zid.getId)), expected)
}
val tz = TimeZone.getTimeZone(zid)
var c = Calendar.getInstance(tz)
c.set(2015, 0, 1, 0, 0, 0)
c.set(Calendar.MILLISECOND, 0)
checkCastStringToTimestamp("2015", new Timestamp(c.getTimeInMillis))
c = Calendar.getInstance(tz)
c.set(2015, 2, 1, 0, 0, 0)
c.set(Calendar.MILLISECOND, 0)
checkCastStringToTimestamp("2015-03", new Timestamp(c.getTimeInMillis))
c = Calendar.getInstance(tz)
c.set(2015, 2, 18, 0, 0, 0)
c.set(Calendar.MILLISECOND, 0)
checkCastStringToTimestamp("2015-03-18", new Timestamp(c.getTimeInMillis))
checkCastStringToTimestamp("2015-03-18 ", new Timestamp(c.getTimeInMillis))
checkCastStringToTimestamp("2015-03-18T", new Timestamp(c.getTimeInMillis))
c = Calendar.getInstance(tz)
c.set(2015, 2, 18, 12, 3, 17)
c.set(Calendar.MILLISECOND, 0)
checkCastStringToTimestamp("2015-03-18 12:03:17", new Timestamp(c.getTimeInMillis))
checkCastStringToTimestamp("2015-03-18T12:03:17", new Timestamp(c.getTimeInMillis))
// If the string value includes timezone string, it represents the timestamp string
// in the timezone regardless of the timeZoneId parameter.
c = Calendar.getInstance(TimeZone.getTimeZone(UTC))
c.set(2015, 2, 18, 12, 3, 17)
c.set(Calendar.MILLISECOND, 0)
checkCastStringToTimestamp("2015-03-18T12:03:17Z", new Timestamp(c.getTimeInMillis))
checkCastStringToTimestamp("2015-03-18 12:03:17Z", new Timestamp(c.getTimeInMillis))
c = Calendar.getInstance(TimeZone.getTimeZone("GMT-01:00"))
c.set(2015, 2, 18, 12, 3, 17)
c.set(Calendar.MILLISECOND, 0)
checkCastStringToTimestamp("2015-03-18T12:03:17-1:0", new Timestamp(c.getTimeInMillis))
checkCastStringToTimestamp("2015-03-18T12:03:17-01:00", new Timestamp(c.getTimeInMillis))
c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30"))
c.set(2015, 2, 18, 12, 3, 17)
c.set(Calendar.MILLISECOND, 0)
checkCastStringToTimestamp("2015-03-18T12:03:17+07:30", new Timestamp(c.getTimeInMillis))
c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:03"))
c.set(2015, 2, 18, 12, 3, 17)
c.set(Calendar.MILLISECOND, 0)
checkCastStringToTimestamp("2015-03-18T12:03:17+7:3", new Timestamp(c.getTimeInMillis))
// tests for the string including milliseconds.
c = Calendar.getInstance(tz)
c.set(2015, 2, 18, 12, 3, 17)
c.set(Calendar.MILLISECOND, 123)
checkCastStringToTimestamp("2015-03-18 12:03:17.123", new Timestamp(c.getTimeInMillis))
checkCastStringToTimestamp("2015-03-18T12:03:17.123", new Timestamp(c.getTimeInMillis))
// If the string value includes timezone string, it represents the timestamp string
// in the timezone regardless of the timeZoneId parameter.
c = Calendar.getInstance(TimeZone.getTimeZone(UTC))
c.set(2015, 2, 18, 12, 3, 17)
c.set(Calendar.MILLISECOND, 456)
checkCastStringToTimestamp("2015-03-18T12:03:17.456Z", new Timestamp(c.getTimeInMillis))
checkCastStringToTimestamp("2015-03-18 12:03:17.456Z", new Timestamp(c.getTimeInMillis))
c = Calendar.getInstance(TimeZone.getTimeZone("GMT-01:00"))
c.set(2015, 2, 18, 12, 3, 17)
c.set(Calendar.MILLISECOND, 123)
checkCastStringToTimestamp("2015-03-18T12:03:17.123-1:0", new Timestamp(c.getTimeInMillis))
checkCastStringToTimestamp("2015-03-18T12:03:17.123-01:00", new Timestamp(c.getTimeInMillis))
c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30"))
c.set(2015, 2, 18, 12, 3, 17)
c.set(Calendar.MILLISECOND, 123)
checkCastStringToTimestamp("2015-03-18T12:03:17.123+07:30", new Timestamp(c.getTimeInMillis))
c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:03"))
c.set(2015, 2, 18, 12, 3, 17)
c.set(Calendar.MILLISECOND, 123)
checkCastStringToTimestamp("2015-03-18T12:03:17.123+7:3", new Timestamp(c.getTimeInMillis))
}
}
test("cast from boolean") {
checkEvaluation(cast(true, IntegerType), 1)
checkEvaluation(cast(false, IntegerType), 0)
checkEvaluation(cast(true, StringType), "true")
checkEvaluation(cast(false, StringType), "false")
checkEvaluation(cast(cast(1, BooleanType), IntegerType), 1)
checkEvaluation(cast(cast(0, BooleanType), IntegerType), 0)
}
test("cast from int") {
checkCast(0, false)
checkCast(1, true)
checkCast(-5, true)
checkCast(1, 1.toByte)
checkCast(1, 1.toShort)
checkCast(1, 1)
checkCast(1, 1.toLong)
checkCast(1, 1.0f)
checkCast(1, 1.0)
checkCast(123, "123")
checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123))
checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123))
checkEvaluation(cast(1, LongType), 1.toLong)
}
test("cast from long") {
checkCast(0L, false)
checkCast(1L, true)
checkCast(-5L, true)
checkCast(1L, 1.toByte)
checkCast(1L, 1.toShort)
checkCast(1L, 1)
checkCast(1L, 1.toLong)
checkCast(1L, 1.0f)
checkCast(1L, 1.0)
checkCast(123L, "123")
checkEvaluation(cast(123L, DecimalType.USER_DEFAULT), Decimal(123))
checkEvaluation(cast(123L, DecimalType(3, 0)), Decimal(123))
}
test("cast from float") {
checkCast(0.0f, false)
checkCast(0.5f, true)
checkCast(-5.0f, true)
checkCast(1.5f, 1.toByte)
checkCast(1.5f, 1.toShort)
checkCast(1.5f, 1)
checkCast(1.5f, 1.toLong)
checkCast(1.5f, 1.5)
checkCast(1.5f, "1.5")
}
test("cast from double") {
checkCast(0.0, false)
checkCast(0.5, true)
checkCast(-5.0, true)
checkCast(1.5, 1.toByte)
checkCast(1.5, 1.toShort)
checkCast(1.5, 1)
checkCast(1.5, 1.toLong)
checkCast(1.5, 1.5f)
checkCast(1.5, "1.5")
}
test("cast from string") {
assert(cast("abcdef", StringType).nullable === isAlwaysNullable)
assert(cast("abcdef", BinaryType).nullable === isAlwaysNullable)
assert(cast("abcdef", BooleanType).nullable)
assert(cast("abcdef", TimestampType).nullable)
assert(cast("abcdef", LongType).nullable)
assert(cast("abcdef", IntegerType).nullable)
assert(cast("abcdef", ShortType).nullable)
assert(cast("abcdef", ByteType).nullable)
assert(cast("abcdef", DecimalType.USER_DEFAULT).nullable)
assert(cast("abcdef", DecimalType(4, 2)).nullable)
assert(cast("abcdef", DoubleType).nullable)
assert(cast("abcdef", FloatType).nullable)
}
test("data type casting") {
val sd = "1970-01-01"
val d = Date.valueOf(sd)
val zts = sd + " 00:00:00"
val sts = sd + " 00:00:02"
val nts = sts + ".1"
val ts = withDefaultTimeZone(UTC)(Timestamp.valueOf(nts))
for (tz <- ALL_TIMEZONES) {
val timeZoneId = Option(tz.getId)
var c = Calendar.getInstance(TimeZoneUTC)
c.set(2015, 2, 8, 2, 30, 0)
checkEvaluation(
cast(cast(new Timestamp(c.getTimeInMillis), StringType, timeZoneId),
TimestampType, timeZoneId),
millisToMicros(c.getTimeInMillis))
c = Calendar.getInstance(TimeZoneUTC)
c.set(2015, 10, 1, 2, 30, 0)
checkEvaluation(
cast(cast(new Timestamp(c.getTimeInMillis), StringType, timeZoneId),
TimestampType, timeZoneId),
millisToMicros(c.getTimeInMillis))
}
checkEvaluation(cast("abdef", StringType), "abdef")
checkEvaluation(cast("12.65", DecimalType.SYSTEM_DEFAULT), Decimal(12.65))
checkEvaluation(cast(cast(sd, DateType), StringType), sd)
checkEvaluation(cast(cast(d, StringType), DateType), 0)
checkEvaluation(cast(cast(nts, TimestampType, UTC_OPT), StringType, UTC_OPT), nts)
checkEvaluation(
cast(cast(ts, StringType, UTC_OPT), TimestampType, UTC_OPT),
fromJavaTimestamp(ts))
// all convert to string type to check
checkEvaluation(
cast(cast(cast(nts, TimestampType, UTC_OPT), DateType, UTC_OPT), StringType),
sd)
checkEvaluation(
cast(cast(cast(ts, DateType, UTC_OPT), TimestampType, UTC_OPT), StringType, UTC_OPT),
zts)
checkEvaluation(cast(cast("abdef", BinaryType), StringType), "abdef")
checkEvaluation(cast(cast(cast(cast(
cast(cast("5", ByteType), ShortType), IntegerType), FloatType), DoubleType), LongType),
5.toLong)
checkEvaluation(cast("23", DoubleType), 23d)
checkEvaluation(cast("23", IntegerType), 23)
checkEvaluation(cast("23", FloatType), 23f)
checkEvaluation(cast("23", DecimalType.USER_DEFAULT), Decimal(23))
checkEvaluation(cast("23", ByteType), 23.toByte)
checkEvaluation(cast("23", ShortType), 23.toShort)
checkEvaluation(cast(123, IntegerType), 123)
checkEvaluation(cast(Literal.create(null, IntegerType), ShortType), null)
}
test("cast and add") {
checkEvaluation(Add(Literal(23d), cast(true, DoubleType)), 24d)
checkEvaluation(Add(Literal(23), cast(true, IntegerType)), 24)
checkEvaluation(Add(Literal(23f), cast(true, FloatType)), 24f)
checkEvaluation(Add(Literal(Decimal(23)), cast(true, DecimalType.USER_DEFAULT)), Decimal(24))
checkEvaluation(Add(Literal(23.toByte), cast(true, ByteType)), 24.toByte)
checkEvaluation(Add(Literal(23.toShort), cast(true, ShortType)), 24.toShort)
}
test("from decimal") {
checkCast(Decimal(0.0), false)
checkCast(Decimal(0.5), true)
checkCast(Decimal(-5.0), true)
checkCast(Decimal(1.5), 1.toByte)
checkCast(Decimal(1.5), 1.toShort)
checkCast(Decimal(1.5), 1)
checkCast(Decimal(1.5), 1.toLong)
checkCast(Decimal(1.5), 1.5f)
checkCast(Decimal(1.5), 1.5)
checkCast(Decimal(1.5), "1.5")
}
test("cast from array") {
val array = Literal.create(Seq("123", "true", "f", null),
ArrayType(StringType, containsNull = true))
val array_notNull = Literal.create(Seq("123", "true", "f"),
ArrayType(StringType, containsNull = false))
checkNullCast(ArrayType(StringType), ArrayType(IntegerType))
{
val array = Literal.create(Seq.empty, ArrayType(NullType, containsNull = false))
val ret = cast(array, ArrayType(IntegerType, containsNull = false))
assert(ret.resolved)
checkEvaluation(ret, Seq.empty)
}
{
val ret = cast(array, ArrayType(BooleanType, containsNull = false))
assert(ret.resolved === false)
}
{
val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = false))
assert(ret.resolved === false)
}
{
val ret = cast(array, IntegerType)
assert(ret.resolved === false)
}
}
test("cast from map") {
val map = Literal.create(
Map("a" -> "123", "b" -> "true", "c" -> "f", "d" -> null),
MapType(StringType, StringType, valueContainsNull = true))
val map_notNull = Literal.create(
Map("a" -> "123", "b" -> "true", "c" -> "f"),
MapType(StringType, StringType, valueContainsNull = false))
checkNullCast(MapType(StringType, IntegerType), MapType(StringType, StringType))
{
val ret = cast(map, MapType(StringType, BooleanType, valueContainsNull = false))
assert(ret.resolved === false)
}
{
val ret = cast(map, MapType(IntegerType, StringType, valueContainsNull = true))
assert(ret.resolved === false)
}
{
val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = false))
assert(ret.resolved === false)
}
{
val ret = cast(map_notNull, MapType(IntegerType, StringType, valueContainsNull = true))
assert(ret.resolved === false)
}
{
val ret = cast(map, IntegerType)
assert(ret.resolved === false)
}
}
test("cast from struct") {
checkNullCast(
StructType(Seq(
StructField("a", StringType),
StructField("b", IntegerType))),
StructType(Seq(
StructField("a", StringType),
StructField("b", StringType))))
val struct = Literal.create(
InternalRow(
UTF8String.fromString("123"),
UTF8String.fromString("true"),
UTF8String.fromString("f"),
null),
StructType(Seq(
StructField("a", StringType, nullable = true),
StructField("b", StringType, nullable = true),
StructField("c", StringType, nullable = true),
StructField("d", StringType, nullable = true))))
val struct_notNull = Literal.create(
InternalRow(
UTF8String.fromString("123"),
UTF8String.fromString("true"),
UTF8String.fromString("f")),
StructType(Seq(
StructField("a", StringType, nullable = false),
StructField("b", StringType, nullable = false),
StructField("c", StringType, nullable = false))))
{
val ret = cast(struct, StructType(Seq(
StructField("a", BooleanType, nullable = true),
StructField("b", BooleanType, nullable = true),
StructField("c", BooleanType, nullable = false),
StructField("d", BooleanType, nullable = true))))
assert(ret.resolved === false)
}
{
val ret = cast(struct_notNull, StructType(Seq(
StructField("a", BooleanType, nullable = true),
StructField("b", BooleanType, nullable = true),
StructField("c", BooleanType, nullable = false))))
assert(ret.resolved === false)
}
{
val ret = cast(struct, StructType(Seq(
StructField("a", StringType, nullable = true),
StructField("b", StringType, nullable = true),
StructField("c", StringType, nullable = true))))
assert(ret.resolved === false)
}
{
val ret = cast(struct, IntegerType)
assert(ret.resolved === false)
}
}
test("cast struct with a timestamp field") {
val originalSchema = new StructType().add("tsField", TimestampType, nullable = false)
// nine out of ten times I'm casting a struct, it's to normalize its fields nullability
val targetSchema = new StructType().add("tsField", TimestampType, nullable = true)
val inp = Literal.create(InternalRow(0L), originalSchema)
val expected = InternalRow(0L)
checkEvaluation(cast(inp, targetSchema), expected)
}
test("complex casting") {
val complex = Literal.create(
Row(
Seq("123", "true", "f"),
Map("a" -> "123", "b" -> "true", "c" -> "f"),
Row(0)),
StructType(Seq(
StructField("a",
ArrayType(StringType, containsNull = false), nullable = true),
StructField("m",
MapType(StringType, StringType, valueContainsNull = false), nullable = true),
StructField("s",
StructType(Seq(
StructField("i", IntegerType, nullable = true)))))))
val ret = cast(complex, StructType(Seq(
StructField("a",
ArrayType(IntegerType, containsNull = true), nullable = true),
StructField("m",
MapType(StringType, BooleanType, valueContainsNull = false), nullable = true),
StructField("s",
StructType(Seq(
StructField("l", LongType, nullable = true)))))))
assert(ret.resolved === false)
}
test("cast between string and interval") {
import org.apache.spark.unsafe.types.CalendarInterval
checkEvaluation(Cast(Literal(""), CalendarIntervalType), null)
checkEvaluation(Cast(Literal("interval -3 month 1 day 7 hours"), CalendarIntervalType),
new CalendarInterval(-3, 1, 7 * MICROS_PER_HOUR))
checkEvaluation(Cast(Literal.create(
new CalendarInterval(15, 9, -3 * MICROS_PER_HOUR), CalendarIntervalType),
StringType),
"1 years 3 months 9 days -3 hours")
checkEvaluation(Cast(Literal("INTERVAL 1 Second 1 microsecond"), CalendarIntervalType),
new CalendarInterval(0, 0, 1000001))
checkEvaluation(Cast(Literal("1 MONTH 1 Microsecond"), CalendarIntervalType),
new CalendarInterval(1, 0, 1))
}
test("cast string to boolean") {
checkCast("t", true)
checkCast("true", true)
checkCast("tRUe", true)
checkCast("y", true)
checkCast("yes", true)
checkCast("1", true)
checkCast("f", false)
checkCast("false", false)
checkCast("FAlsE", false)
checkCast("n", false)
checkCast("no", false)
checkCast("0", false)
}
protected def checkInvalidCastFromNumericType(to: DataType): Unit = {
assert(cast(1.toByte, to).checkInputDataTypes().isFailure)
assert(cast(1.toShort, to).checkInputDataTypes().isFailure)
assert(cast(1, to).checkInputDataTypes().isFailure)
assert(cast(1L, to).checkInputDataTypes().isFailure)
assert(cast(1.0.toFloat, to).checkInputDataTypes().isFailure)
assert(cast(1.0, to).checkInputDataTypes().isFailure)
}
test("SPARK-16729 type checking for casting to date type") {
assert(cast("1234", DateType).checkInputDataTypes().isSuccess)
assert(cast(new Timestamp(1), DateType).checkInputDataTypes().isSuccess)
assert(cast(false, DateType).checkInputDataTypes().isFailure)
checkInvalidCastFromNumericType(DateType)
}
test("SPARK-20302 cast with same structure") {
val from = new StructType()
.add("a", IntegerType)
.add("b", new StructType().add("b1", LongType))
val to = new StructType()
.add("a1", IntegerType)
.add("b1", new StructType().add("b11", LongType))
val input = Row(10, Row(12L))
checkEvaluation(cast(Literal.create(input, from), to), input)
}
test("SPARK-22500: cast for struct should not generate codes beyond 64KB") {
val N = 25
val fromInner = new StructType(
(1 to N).map(i => StructField(s"s$i", DoubleType)).toArray)
val toInner = new StructType(
(1 to N).map(i => StructField(s"i$i", IntegerType)).toArray)
val inputInner = Row.fromSeq((1 to N).map(i => i + 0.5))
val outputInner = Row.fromSeq((1 to N))
val fromOuter = new StructType(
(1 to N).map(i => StructField(s"s$i", fromInner)).toArray)
val toOuter = new StructType(
(1 to N).map(i => StructField(s"s$i", toInner)).toArray)
val inputOuter = Row.fromSeq((1 to N).map(_ => inputInner))
val outputOuter = Row.fromSeq((1 to N).map(_ => outputInner))
checkEvaluation(cast(Literal.create(inputOuter, fromOuter), toOuter), outputOuter)
}
test("SPARK-22570: Cast should not create a lot of global variables") {
val ctx = new CodegenContext
cast("1", IntegerType).genCode(ctx)
cast("2", LongType).genCode(ctx)
assert(ctx.inlinedMutableStates.length == 0)
}
test("up-cast") {
def isCastSafe(from: NumericType, to: NumericType): Boolean = (from, to) match {
case (_, dt: DecimalType) => dt.isWiderThan(from)
case (dt: DecimalType, _) => dt.isTighterThan(to)
case _ => numericPrecedence.indexOf(from) <= numericPrecedence.indexOf(to)
}
def makeComplexTypes(dt: NumericType, nullable: Boolean): Seq[DataType] = {
Seq(
new StructType().add("a", dt, nullable).add("b", dt, nullable),
ArrayType(dt, nullable),
MapType(dt, dt, nullable),
ArrayType(new StructType().add("a", dt, nullable), nullable),
new StructType().add("a", ArrayType(dt, nullable), nullable)
)
}
import DataTypeTestUtils._
numericTypes.foreach { from =>
val (safeTargetTypes, unsafeTargetTypes) = numericTypes.partition(to => isCastSafe(from, to))
safeTargetTypes.foreach { to =>
assert(Cast.canUpCast(from, to), s"It should be possible to up-cast $from to $to")
// If the nullability is compatible, we can up-cast complex types too.
Seq(true -> true, false -> false, false -> true).foreach { case (fn, tn) =>
makeComplexTypes(from, fn).zip(makeComplexTypes(to, tn)).foreach {
case (complexFromType, complexToType) =>
assert(Cast.canUpCast(complexFromType, complexToType))
}
}
makeComplexTypes(from, true).zip(makeComplexTypes(to, false)).foreach {
case (complexFromType, complexToType) =>
assert(!Cast.canUpCast(complexFromType, complexToType))
}
}
unsafeTargetTypes.foreach { to =>
assert(!Cast.canUpCast(from, to), s"It shouldn't be possible to up-cast $from to $to")
makeComplexTypes(from, true).zip(makeComplexTypes(to, true)).foreach {
case (complexFromType, complexToType) =>
assert(!Cast.canUpCast(complexFromType, complexToType))
}
}
}
numericTypes.foreach { dt =>
makeComplexTypes(dt, true).foreach { complexType =>
assert(!Cast.canUpCast(complexType, StringType))
}
}
atomicTypes.foreach { atomicType =>
assert(Cast.canUpCast(NullType, atomicType))
}
}
test("SPARK-27671: cast from nested null type in struct") {
import DataTypeTestUtils._
atomicTypes.foreach { atomicType =>
val struct = Literal.create(
InternalRow(null),
StructType(Seq(StructField("a", NullType, nullable = true))))
val ret = cast(struct, StructType(Seq(
StructField("a", atomicType, nullable = true))))
assert(ret.resolved)
checkEvaluation(ret, InternalRow(null))
}
}
test("Process Infinity, -Infinity, NaN in case insensitive manner") {
Seq("inf", "+inf", "infinity", "+infiNity", " infinity ").foreach { value =>
checkEvaluation(cast(value, FloatType), Float.PositiveInfinity)
}
Seq("-infinity", "-infiniTy", " -infinity ", " -inf ").foreach { value =>
checkEvaluation(cast(value, FloatType), Float.NegativeInfinity)
}
Seq("inf", "+inf", "infinity", "+infiNity", " infinity ").foreach { value =>
checkEvaluation(cast(value, DoubleType), Double.PositiveInfinity)
}
Seq("-infinity", "-infiniTy", " -infinity ", " -inf ").foreach { value =>
checkEvaluation(cast(value, DoubleType), Double.NegativeInfinity)
}
Seq("nan", "nAn", " nan ").foreach { value =>
checkEvaluation(cast(value, FloatType), Float.NaN)
}
Seq("nan", "nAn", " nan ").foreach { value =>
checkEvaluation(cast(value, DoubleType), Double.NaN)
}
}
test("SPARK-22825 Cast array to string") {
val ret1 = cast(Literal.create(Array(1, 2, 3, 4, 5)), StringType)
checkEvaluation(ret1, "[1, 2, 3, 4, 5]")
val ret2 = cast(Literal.create(Array("ab", "cde", "f")), StringType)
checkEvaluation(ret2, "[ab, cde, f]")
Seq(false, true).foreach { omitNull =>
withSQLConf(SQLConf.LEGACY_COMPLEX_TYPES_TO_STRING.key -> omitNull.toString) {
val ret3 = cast(Literal.create(Array("ab", null, "c")), StringType)
checkEvaluation(ret3, s"[ab,${if (omitNull) "" else " null"}, c]")
}
}
val ret4 =
cast(Literal.create(Array("ab".getBytes, "cde".getBytes, "f".getBytes)), StringType)
checkEvaluation(ret4, "[ab, cde, f]")
val ret5 = cast(
Literal.create(Array("2014-12-03", "2014-12-04", "2014-12-06").map(Date.valueOf)),
StringType)
checkEvaluation(ret5, "[2014-12-03, 2014-12-04, 2014-12-06]")
val ret6 = cast(
Literal.create(Array("2014-12-03 13:01:00", "2014-12-04 15:05:00")
.map(Timestamp.valueOf)),
StringType)
checkEvaluation(ret6, "[2014-12-03 13:01:00, 2014-12-04 15:05:00]")
val ret7 = cast(Literal.create(Array(Array(1, 2, 3), Array(4, 5))), StringType)
checkEvaluation(ret7, "[[1, 2, 3], [4, 5]]")
val ret8 = cast(
Literal.create(Array(Array(Array("a"), Array("b", "c")), Array(Array("d")))),
StringType)
checkEvaluation(ret8, "[[[a], [b, c]], [[d]]]")
}
test("SPARK-33291: Cast array with null elements to string") {
Seq(false, true).foreach { omitNull =>
withSQLConf(SQLConf.LEGACY_COMPLEX_TYPES_TO_STRING.key -> omitNull.toString) {
val ret1 = cast(Literal.create(Array(null, null)), StringType)
checkEvaluation(
ret1,
s"[${if (omitNull) "" else "null"},${if (omitNull) "" else " null"}]")
}
}
}
test("SPARK-22973 Cast map to string") {
Seq(
false -> ("{", "}"),
true -> ("[", "]")).foreach { case (legacyCast, (lb, rb)) =>
withSQLConf(SQLConf.LEGACY_COMPLEX_TYPES_TO_STRING.key -> legacyCast.toString) {
val ret1 = cast(Literal.create(Map(1 -> "a", 2 -> "b", 3 -> "c")), StringType)
checkEvaluation(ret1, s"${lb}1 -> a, 2 -> b, 3 -> c$rb")
val ret2 = cast(
Literal.create(Map("1" -> "a".getBytes, "2" -> null, "3" -> "c".getBytes)),
StringType)
checkEvaluation(ret2, s"${lb}1 -> a, 2 ->${if (legacyCast) "" else " null"}, 3 -> c$rb")
val ret3 = cast(
Literal.create(Map(
1 -> Date.valueOf("2014-12-03"),
2 -> Date.valueOf("2014-12-04"),
3 -> Date.valueOf("2014-12-05"))),
StringType)
checkEvaluation(ret3, s"${lb}1 -> 2014-12-03, 2 -> 2014-12-04, 3 -> 2014-12-05$rb")
val ret4 = cast(
Literal.create(Map(
1 -> Timestamp.valueOf("2014-12-03 13:01:00"),
2 -> Timestamp.valueOf("2014-12-04 15:05:00"))),
StringType)
checkEvaluation(ret4, s"${lb}1 -> 2014-12-03 13:01:00, 2 -> 2014-12-04 15:05:00$rb")
val ret5 = cast(
Literal.create(Map(
1 -> Array(1, 2, 3),
2 -> Array(4, 5, 6))),
StringType)
checkEvaluation(ret5, s"${lb}1 -> [1, 2, 3], 2 -> [4, 5, 6]$rb")
}
}
}
test("SPARK-22981 Cast struct to string") {
Seq(
false -> ("{", "}"),
true -> ("[", "]")).foreach { case (legacyCast, (lb, rb)) =>
withSQLConf(SQLConf.LEGACY_COMPLEX_TYPES_TO_STRING.key -> legacyCast.toString) {
val ret1 = cast(Literal.create((1, "a", 0.1)), StringType)
checkEvaluation(ret1, s"${lb}1, a, 0.1$rb")
val ret2 = cast(Literal.create(Tuple3[Int, String, String](1, null, "a")), StringType)
checkEvaluation(ret2, s"${lb}1,${if (legacyCast) "" else " null"}, a$rb")
val ret3 = cast(Literal.create(
(Date.valueOf("2014-12-03"), Timestamp.valueOf("2014-12-03 15:05:00"))), StringType)
checkEvaluation(ret3, s"${lb}2014-12-03, 2014-12-03 15:05:00$rb")
val ret4 = cast(Literal.create(((1, "a"), 5, 0.1)), StringType)
checkEvaluation(ret4, s"$lb${lb}1, a$rb, 5, 0.1$rb")
val ret5 = cast(Literal.create((Seq(1, 2, 3), "a", 0.1)), StringType)
checkEvaluation(ret5, s"$lb[1, 2, 3], a, 0.1$rb")
val ret6 = cast(Literal.create((1, Map(1 -> "a", 2 -> "b", 3 -> "c"))), StringType)
checkEvaluation(ret6, s"${lb}1, ${lb}1 -> a, 2 -> b, 3 -> c$rb$rb")
}
}
}
test("SPARK-33291: Cast struct with null elements to string") {
Seq(
false -> ("{", "}"),
true -> ("[", "]")).foreach { case (legacyCast, (lb, rb)) =>
withSQLConf(SQLConf.LEGACY_COMPLEX_TYPES_TO_STRING.key -> legacyCast.toString) {
val ret1 = cast(Literal.create(Tuple2[String, String](null, null)), StringType)
checkEvaluation(
ret1,
s"$lb${if (legacyCast) "" else "null"},${if (legacyCast) "" else " null"}$rb")
}
}
}
test("SPARK-34667: cast year-month interval to string") {
Seq(
Period.ofMonths(0) -> "0-0",
Period.ofMonths(1) -> "0-1",
Period.ofMonths(-1) -> "-0-1",
Period.ofYears(1) -> "1-0",
Period.ofYears(-1) -> "-1-0",
Period.ofYears(10).plusMonths(10) -> "10-10",
Period.ofYears(-123).minusMonths(6) -> "-123-6",
Period.ofMonths(Int.MaxValue) -> "178956970-7",
Period.ofMonths(Int.MinValue) -> "-178956970-8"
).foreach { case (period, intervalPayload) =>
checkEvaluation(
Cast(Literal(period), StringType),
s"INTERVAL '$intervalPayload' YEAR TO MONTH")
}
checkConsistencyBetweenInterpretedAndCodegen(
(child: Expression) => Cast(child, StringType), YearMonthIntervalType)
}
test("SPARK-34668: cast day-time interval to string") {
Seq(
Duration.ZERO -> "0 00:00:00",
Duration.of(1, ChronoUnit.MICROS) -> "0 00:00:00.000001",
Duration.ofMillis(-1) -> "-0 00:00:00.001",
Duration.ofMillis(1234) -> "0 00:00:01.234",
Duration.ofSeconds(-9).minus(999999, ChronoUnit.MICROS) -> "-0 00:00:09.999999",
Duration.ofMinutes(30).plusMillis(59010) -> "0 00:30:59.01",
Duration.ofHours(-23).minusSeconds(59) -> "-0 23:00:59",
Duration.ofDays(1).plus(12345678, ChronoUnit.MICROS) -> "1 00:00:12.345678",
Duration.ofDays(-1234).minusHours(23).minusMinutes(59).minusSeconds(59).minusMillis(999) ->
"-1234 23:59:59.999",
microsToDuration(Long.MaxValue) -> "106751991 04:00:54.775807",
microsToDuration(Long.MinValue + 1) -> "-106751991 04:00:54.775807",
microsToDuration(Long.MinValue) -> "-106751991 04:00:54.775808"
).foreach { case (period, intervalPayload) =>
checkEvaluation(
Cast(Literal(period), StringType),
s"INTERVAL '$intervalPayload' DAY TO SECOND")
}
dayTimeIntervalTypes.foreach { it =>
checkConsistencyBetweenInterpretedAndCodegen((child: Expression) =>
Cast(child, StringType), it)
}
}
private val specialTs = Seq(
"0001-01-01T00:00:00", // the fist timestamp of Common Era
"1582-10-15T23:59:59", // the cutover date from Julian to Gregorian calendar
"1970-01-01T00:00:00", // the epoch timestamp
"9999-12-31T23:59:59" // the last supported timestamp according to SQL standard
)
test("SPARK-35698: cast timestamp without time zone to string") {
specialTs.foreach { s =>
checkEvaluation(cast(LocalDateTime.parse(s), StringType), s.replace("T", " "))
}
}
test("SPARK-35711: cast timestamp without time zone to timestamp with local time zone") {
outstandingZoneIds.foreach { zoneId =>
withDefaultTimeZone(zoneId) {
specialTs.foreach { s =>
val input = LocalDateTime.parse(s)
val expectedTs = Timestamp.valueOf(s.replace("T", " "))
checkEvaluation(cast(input, TimestampType), expectedTs)
}
}
}
}
test("SPARK-35716: cast timestamp without time zone to date type") {
specialTs.foreach { s =>
val dt = LocalDateTime.parse(s)
checkEvaluation(cast(dt, DateType), LocalDate.parse(s.split("T")(0)))
}
}
test("SPARK-35718: cast date type to timestamp without timezone") {
specialTs.foreach { s =>
val inputDate = LocalDate.parse(s.split("T")(0))
// The hour/minute/second of the expect result should be 0
val expectedTs = LocalDateTime.parse(s.split("T")(0) + "T00:00:00")
checkEvaluation(cast(inputDate, TimestampWithoutTZType), expectedTs)
}
}
test("SPARK-35719: cast timestamp with local time zone to timestamp without timezone") {
outstandingZoneIds.foreach { zoneId =>
withDefaultTimeZone(zoneId) {
specialTs.foreach { s =>
val input = Timestamp.valueOf(s.replace("T", " "))
val expectedTs = LocalDateTime.parse(s)
checkEvaluation(cast(input, TimestampWithoutTZType), expectedTs)
}
}
}
}
test("disallow type conversions between Numeric types and Timestamp without time zone type") {
import DataTypeTestUtils.numericTypes
checkInvalidCastFromNumericType(TimestampWithoutTZType)
var errorMsg = "cannot cast bigint to timestamp without time zone"
verifyCastFailure(cast(Literal(0L), TimestampWithoutTZType), Some(errorMsg))
val timestampWithoutTZLiteral = Literal.create(LocalDateTime.now(), TimestampWithoutTZType)
errorMsg = "cannot cast timestamp without time zone to"
numericTypes.foreach { numericType =>
verifyCastFailure(cast(timestampWithoutTZLiteral, numericType), Some(errorMsg))
}
}
}