[SPARK-25522][SQL] Improve type promotion for input arguments of elementAt function

## What changes were proposed in this pull request?
In ElementAt, when the first argument is of MapType, the map's key type and the second argument should be coerced to a common type based on findTightestCommonType. This does not happen currently, so we may produce a wrong result by incorrectly downcasting the right-hand side expression to int.

```SQL
spark-sql> select element_at(map(1,"one", 2, "two"), 2.2);

two
```
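
The wrong answer above comes from the lookup key being downcast to the map's int key type. Below is a minimal sketch of the intended coercion, using internal Catalyst APIs from a spark-shell session (not code from this PR): the map's key type and the lookup key are widened to their tightest common type, so a fractional key no longer truncates silently.

```scala
import org.apache.spark.sql.catalyst.analysis.TypeCoercion
import org.apache.spark.sql.types._

// The rule this change builds on: widen both sides instead of downcasting the key.
TypeCoercion.findTightestCommonType(IntegerType, DoubleType)  // Some(DoubleType)
TypeCoercion.findTightestCommonType(IntegerType, StringType)  // None => analysis error for element_at

// With the map keys widened to double, a fractional lookup key simply finds no match:
spark.sql("SELECT element_at(map(1, 'one', 2, 'two'), 2.2D)").collect()
// Array([null])
```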

Also, when the first argument is of ArrayType, the second argument should be an integer type, or a smaller integral type that can be safely cast to an integer. Currently we may perform an unsafe cast. In the following case we should fail with an error, since 1.24 is not an integer index, but instead we currently downcast it to int and return a result.

```SQL
spark-sql> select element_at(array(1,2), 1.24D);

1
```
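
For the array case, here is a minimal sketch of the behaviour this change aims for (spark-shell; not this PR's test code): a byte, short, or int index is accepted and widened to int, while an index that cannot be safely cast to int, such as a long or a double, is rejected at analysis time instead of being truncated.

```scala
// Accepted: a smallint index is implicitly widened to int (element_at is 1-based).
spark.sql("SELECT element_at(array(1, 2), 1S)").collect()
// Array([1])

// Rejected after this change: a fractional index now fails analysis rather than being
// downcast to int and returning the first element.
spark.sql("SELECT element_at(array(1, 2), 1.24D)").collect()
// org.apache.spark.sql.AnalysisException:
//   Input to function element_at should have been array followed by a int, but it's
//   [array<int>, double]; ...
```
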
This PR also adds support for implicit casts between two MapTypes, following the same logic that exists today for implicit casts between two ArrayTypes.
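
Below is a minimal sketch of the new map-to-map implicit cast, using internal Catalyst APIs (in the spirit of the TypeCoercionSuite test added below; the literal values are only illustrative): the key type may be widened as long as the key cast cannot produce nulls, and value nullability can only be relaxed, never tightened.

```scala
import org.apache.spark.sql.catalyst.analysis.TypeCoercion.ImplicitTypeCasts
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types._

val src = Literal.create(Map(1 -> "a"), MapType(IntegerType, StringType, valueContainsNull = true))

// Key widening int -> long is fine: the result wraps src in a Cast to map<bigint,string>.
ImplicitTypeCasts.implicitCast(src, MapType(LongType, StringType, valueContainsNull = true))
// Some(...)

// There is no implicit cast from an int key to a boolean key, so the map cast is refused.
ImplicitTypeCasts.implicitCast(src, MapType(BooleanType, StringType, valueContainsNull = true))
// None

// Nullable values cannot be cast to a map whose values must not be null.
ImplicitTypeCasts.implicitCast(src, MapType(LongType, StringType, valueContainsNull = false))
// None
```
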
## How was this patch tested?
Added new tests in DataFrameFunctionsSuite and TypeCoercionSuite.

Closes #22544 from dilipbiswal/SPARK-25522.

Authored-by: Dilip Biswal <dbiswal@us.ibm.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>

5 changed files with 154 additions and 22 deletions

@@ -950,6 +950,25 @@ object TypeCoercion {
if !Cast.forceNullable(fromType, toType) =>
implicitCast(fromType, toType).map(ArrayType(_, false)).orNull
// Implicit cast between Map types.
// Follows the same semantics of implicit casting between two array types.
// Refer to documentation above. Make sure that both key and values
// can not be null after the implicit cast operation by calling forceNullable
// method.
case (MapType(fromKeyType, fromValueType, fn), MapType(toKeyType, toValueType, tn))
if !Cast.forceNullable(fromKeyType, toKeyType) && Cast.resolvableNullability(fn, tn) =>
if (Cast.forceNullable(fromValueType, toValueType) && !tn) {
null
} else {
val newKeyType = implicitCast(fromKeyType, toKeyType).orNull
val newValueType = implicitCast(fromValueType, toValueType).orNull
if (newKeyType != null && newValueType != null) {
MapType(newKeyType, newValueType, tn)
} else {
null
}
}
case _ => null
}
Option(ret)

@@ -183,7 +183,7 @@ object Cast {
case _ => false
}
private def resolvableNullability(from: Boolean, to: Boolean) = !from || to
def resolvableNullability(from: Boolean, to: Boolean): Boolean = !from || to
}
/**

@@ -2154,21 +2154,34 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil
}
override def inputTypes: Seq[AbstractDataType] = {
Seq(TypeCollection(ArrayType, MapType),
left.dataType match {
case _: ArrayType => IntegerType
case _: MapType => mapKeyType
case _ => AnyDataType // no match for a wrong 'left' expression type
}
)
(left.dataType, right.dataType) match {
case (arr: ArrayType, e2: IntegralType) if (e2 != LongType) =>
Seq(arr, IntegerType)
case (MapType(keyType, valueType, hasNull), e2) =>
TypeCoercion.findTightestCommonType(keyType, e2) match {
case Some(dt) => Seq(MapType(dt, valueType, hasNull), dt)
case _ => Seq.empty
}
case (l, r) => Seq.empty
}
}
override def checkInputDataTypes(): TypeCheckResult = {
super.checkInputDataTypes() match {
case f: TypeCheckResult.TypeCheckFailure => f
case TypeCheckResult.TypeCheckSuccess if left.dataType.isInstanceOf[MapType] =>
TypeUtils.checkForOrderingExpr(mapKeyType, s"function $prettyName")
case TypeCheckResult.TypeCheckSuccess => TypeCheckResult.TypeCheckSuccess
(left.dataType, right.dataType) match {
case (_: ArrayType, e2) if e2 != IntegerType =>
TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " +
s"been ${ArrayType.simpleString} followed by a ${IntegerType.simpleString}, but it's " +
s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].")
case (MapType(e1, _, _), e2) if (!e2.sameType(e1)) =>
TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " +
s"been ${MapType.simpleString} followed by a value of same key type, but it's " +
s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].")
case (e1, _) if (!e1.isInstanceOf[MapType] && !e1.isInstanceOf[ArrayType]) =>
TypeCheckResult.TypeCheckFailure(s"The first argument to function $prettyName should " +
s"have been ${ArrayType.simpleString} or ${MapType.simpleString} type, but its " +
s"${left.dataType.catalogString} type.")
case _ => TypeCheckResult.TypeCheckSuccess
}
}

@@ -257,12 +257,43 @@ class TypeCoercionSuite extends AnalysisTest {
shouldNotCast(checkedType, IntegralType)
}
test("implicit type cast - MapType(StringType, StringType)") {
val checkedType = MapType(StringType, StringType)
checkTypeCasting(checkedType, castableTypes = Seq(checkedType))
shouldNotCast(checkedType, DecimalType)
shouldNotCast(checkedType, NumericType)
shouldNotCast(checkedType, IntegralType)
test("implicit type cast between two Map types") {
val sourceType = MapType(IntegerType, IntegerType, true)
val castableTypes = numericTypes ++ Seq(StringType).filter(!Cast.forceNullable(IntegerType, _))
val targetTypes = numericTypes.filter(!Cast.forceNullable(IntegerType, _)).map { t =>
MapType(t, sourceType.valueType, valueContainsNull = true)
}
val nonCastableTargetTypes = allTypes.filterNot(castableTypes.contains(_)).map {t =>
MapType(t, sourceType.valueType, valueContainsNull = true)
}
// Tests that it's possible to set up implicit casts between two map types when the
// source map's key type is integer and the target map's key type is one of Byte, Short,
// Long, Double, Float, Decimal(38, 18) or String.
targetTypes.foreach { targetType =>
shouldCast(sourceType, targetType, targetType)
}
// Tests that it's not possible to set up implicit casts between two map types when the
// source map's key type is integer and the target map's key type is one of Binary,
// Boolean, Date, Timestamp, Array, Struct, CalendarIntervalType or NullType.
nonCastableTargetTypes.foreach { targetType =>
shouldNotCast(sourceType, targetType)
}
// Tests that it's not possible to cast from a nullable map type to a non-nullable map type.
val targetNotNullableTypes = allTypes.filterNot(_ == IntegerType).map { t =>
MapType(t, sourceType.valueType, valueContainsNull = false)
}
val sourceMapExprWithValueNull =
CreateMap(Seq(Literal.default(sourceType.keyType),
Literal.create(null, sourceType.valueType)))
targetNotNullableTypes.foreach { targetType =>
val castDefault =
TypeCoercion.ImplicitTypeCasts.implicitCast(sourceMapExprWithValueNull, targetType)
assert(castDefault.isEmpty,
s"Should not be able to cast $sourceType to $targetType, but got $castDefault")
}
}
test("implicit type cast - StructType().add(\"a1\", StringType)") {

@@ -1211,11 +1211,80 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
Seq(Row("3"), Row(""), Row(null))
)
val e = intercept[AnalysisException] {
val e1 = intercept[AnalysisException] {
Seq(("a string element", 1)).toDF().selectExpr("element_at(_1, _2)")
}
assert(e.message.contains(
"argument 1 requires (array or map) type, however, '`_1`' is of string type"))
val errorMsg1 =
s"""
|The first argument to function element_at should have been array or map type, but
|its string type.
""".stripMargin.replace("\n", " ").trim()
assert(e1.message.contains(errorMsg1))
checkAnswer(
OneRowRelation().selectExpr("element_at(array(2, 1), 2S)"),
Seq(Row(1))
)
checkAnswer(
OneRowRelation().selectExpr("element_at(array('a', 'b'), 1Y)"),
Seq(Row("a"))
)
checkAnswer(
OneRowRelation().selectExpr("element_at(array(1, 2, 3), 3)"),
Seq(Row(3))
)
val e2 = intercept[AnalysisException] {
OneRowRelation().selectExpr("element_at(array('a', 'b'), 1L)")
}
val errorMsg2 =
s"""
|Input to function element_at should have been array followed by a int, but it's
|[array<string>, bigint].
""".stripMargin.replace("\n", " ").trim()
assert(e2.message.contains(errorMsg2))
checkAnswer(
OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 2Y)"),
Seq(Row("b"))
)
checkAnswer(
OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 1S)"),
Seq(Row("a"))
)
checkAnswer(
OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 2)"),
Seq(Row("b"))
)
checkAnswer(
OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 2L)"),
Seq(Row("b"))
)
checkAnswer(
OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 1.0D)"),
Seq(Row("a"))
)
checkAnswer(
OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 1.23D)"),
Seq(Row(null))
)
val e3 = intercept[AnalysisException] {
OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), '1')")
}
val errorMsg3 =
s"""
|Input to function element_at should have been map followed by a value of same
|key type, but it's [map<int,string>, string].
""".stripMargin.replace("\n", " ").trim()
assert(e3.message.contains(errorMsg3))
}
test("array_union functions") {