[SPARK-33460][SQL] Accessing map values should fail if key is not found
### What changes were proposed in this pull request? Instead of returning NULL, throws runtime NoSuchElementException towards invalid key accessing in map-like functions, such as element_at, GetMapValue, when ANSI mode is on. ### Why are the changes needed? For ANSI mode. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Added UT and Existing UT. Closes #30386 from leanken/leanken-SPARK-33460. Authored-by: xuewei.linxuewei <xuewei.linxuewei@alibaba-inc.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
6883f29465
commit
b5eca18af0
|
@ -112,12 +112,14 @@ SELECT * FROM t;
|
|||
The behavior of some SQL functions can be different under ANSI mode (`spark.sql.ansi.enabled=true`).
|
||||
- `size`: This function returns null for null input.
|
||||
- `element_at`: This function throws `ArrayIndexOutOfBoundsException` if using invalid indices.
|
||||
- `element_at`: This function throws `NoSuchElementException` if key does not exist in map.
|
||||
- `elt`: This function throws `ArrayIndexOutOfBoundsException` if using invalid indices.
|
||||
|
||||
### SQL Operators
|
||||
|
||||
The behavior of some SQL operators can be different under ANSI mode (`spark.sql.ansi.enabled=true`).
|
||||
- `array_col[index]`: This operator throws `ArrayIndexOutOfBoundsException` if using invalid indices.
|
||||
- `map_col[key]`: This operator throws `NoSuchElementException` if key does not exist in map.
|
||||
|
||||
### SQL Keywords
|
||||
|
||||
|
|
|
@ -55,8 +55,8 @@ case class ProjectionOverSchema(schema: StructType) {
|
|||
getProjection(child).map { projection => MapKeys(projection) }
|
||||
case MapValues(child) =>
|
||||
getProjection(child).map { projection => MapValues(projection) }
|
||||
case GetMapValue(child, key) =>
|
||||
getProjection(child).map { projection => GetMapValue(projection, key) }
|
||||
case GetMapValue(child, key, failOnError) =>
|
||||
getProjection(child).map { projection => GetMapValue(projection, key, failOnError) }
|
||||
case GetStructFieldObject(child, field: StructField) =>
|
||||
getProjection(child).map(p => (p, p.dataType)).map {
|
||||
case (projection, projSchema: StructType) =>
|
||||
|
|
|
@ -91,7 +91,7 @@ object SelectedField {
|
|||
}
|
||||
val newField = StructField(field.name, newFieldDataType, field.nullable)
|
||||
selectField(child, Option(ArrayType(struct(newField), containsNull)))
|
||||
case GetMapValue(child, _) =>
|
||||
case GetMapValue(child, _, _) =>
|
||||
// GetMapValue does not select a field from a struct (i.e. prune the struct) so it can't be
|
||||
// the top-level extractor. However it can be part of an extractor chain.
|
||||
val MapType(keyType, _, valueContainsNull) = child.dataType
|
||||
|
|
|
@ -1911,7 +1911,9 @@ case class ArrayPosition(left: Expression, right: Expression)
|
|||
If `spark.sql.ansi.enabled` is set to true, it throws ArrayIndexOutOfBoundsException
|
||||
for invalid indices.
|
||||
|
||||
_FUNC_(map, key) - Returns value for given key, or NULL if the key is not contained in the map
|
||||
_FUNC_(map, key) - Returns value for given key. The function returns NULL
|
||||
if the key is not contained in the map and `spark.sql.ansi.enabled` is set to false.
|
||||
If `spark.sql.ansi.enabled` is set to true, it throws NoSuchElementException instead.
|
||||
""",
|
||||
examples = """
|
||||
Examples:
|
||||
|
@ -1931,6 +1933,9 @@ case class ElementAt(
|
|||
|
||||
@transient private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType
|
||||
|
||||
@transient private lazy val mapValueContainsNull =
|
||||
left.dataType.asInstanceOf[MapType].valueContainsNull
|
||||
|
||||
@transient private lazy val arrayContainsNull = left.dataType.asInstanceOf[ArrayType].containsNull
|
||||
|
||||
@transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(mapKeyType)
|
||||
|
@ -1989,7 +1994,7 @@ case class ElementAt(
|
|||
override def nullable: Boolean = left.dataType match {
|
||||
case _: ArrayType =>
|
||||
computeNullabilityFromArray(left, right, failOnError, nullability)
|
||||
case _: MapType => true
|
||||
case _: MapType => if (failOnError) mapValueContainsNull else true
|
||||
}
|
||||
|
||||
override def nullSafeEval(value: Any, ordinal: Any): Any = doElementAt(value, ordinal)
|
||||
|
@ -2022,7 +2027,7 @@ case class ElementAt(
|
|||
}
|
||||
}
|
||||
case _: MapType =>
|
||||
(value, ordinal) => getValueEval(value, ordinal, mapKeyType, ordering)
|
||||
(value, ordinal) => getValueEval(value, ordinal, mapKeyType, ordering, failOnError)
|
||||
}
|
||||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
|
@ -2069,7 +2074,7 @@ case class ElementAt(
|
|||
""".stripMargin
|
||||
})
|
||||
case _: MapType =>
|
||||
doGetValueGenCode(ctx, ev, left.dataType.asInstanceOf[MapType])
|
||||
doGetValueGenCode(ctx, ev, left.dataType.asInstanceOf[MapType], failOnError)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -336,7 +336,12 @@ trait GetArrayItemUtil {
|
|||
trait GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {
|
||||
|
||||
// todo: current search is O(n), improve it.
|
||||
def getValueEval(value: Any, ordinal: Any, keyType: DataType, ordering: Ordering[Any]): Any = {
|
||||
def getValueEval(
|
||||
value: Any,
|
||||
ordinal: Any,
|
||||
keyType: DataType,
|
||||
ordering: Ordering[Any],
|
||||
failOnError: Boolean): Any = {
|
||||
val map = value.asInstanceOf[MapData]
|
||||
val length = map.numElements()
|
||||
val keys = map.keyArray()
|
||||
|
@ -352,14 +357,24 @@ trait GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {
|
|||
}
|
||||
}
|
||||
|
||||
if (!found || values.isNullAt(i)) {
|
||||
if (!found) {
|
||||
if (failOnError) {
|
||||
throw new NoSuchElementException(s"Key $ordinal does not exist.")
|
||||
} else {
|
||||
null
|
||||
}
|
||||
} else if (values.isNullAt(i)) {
|
||||
null
|
||||
} else {
|
||||
values.get(i, dataType)
|
||||
}
|
||||
}
|
||||
|
||||
def doGetValueGenCode(ctx: CodegenContext, ev: ExprCode, mapType: MapType): ExprCode = {
|
||||
def doGetValueGenCode(
|
||||
ctx: CodegenContext,
|
||||
ev: ExprCode,
|
||||
mapType: MapType,
|
||||
failOnError: Boolean): ExprCode = {
|
||||
val index = ctx.freshName("index")
|
||||
val length = ctx.freshName("length")
|
||||
val keys = ctx.freshName("keys")
|
||||
|
@ -368,12 +383,22 @@ trait GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {
|
|||
val values = ctx.freshName("values")
|
||||
val keyType = mapType.keyType
|
||||
val nullCheck = if (mapType.valueContainsNull) {
|
||||
s" || $values.isNullAt($index)"
|
||||
s"""else if ($values.isNullAt($index)) {
|
||||
${ev.isNull} = true;
|
||||
}
|
||||
"""
|
||||
} else {
|
||||
""
|
||||
}
|
||||
|
||||
val keyJavaType = CodeGenerator.javaType(keyType)
|
||||
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
|
||||
val keyNotFoundBranch = if (failOnError) {
|
||||
s"""throw new NoSuchElementException("Key " + $eval2 + " does not exist.");"""
|
||||
} else {
|
||||
s"${ev.isNull} = true;"
|
||||
}
|
||||
|
||||
s"""
|
||||
final int $length = $eval1.numElements();
|
||||
final ArrayData $keys = $eval1.keyArray();
|
||||
|
@ -390,9 +415,9 @@ trait GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {
|
|||
}
|
||||
}
|
||||
|
||||
if (!$found$nullCheck) {
|
||||
${ev.isNull} = true;
|
||||
} else {
|
||||
if (!$found) {
|
||||
$keyNotFoundBranch
|
||||
} $nullCheck else {
|
||||
${ev.value} = ${CodeGenerator.getValue(values, dataType, index)};
|
||||
}
|
||||
"""
|
||||
|
@ -405,9 +430,14 @@ trait GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {
|
|||
*
|
||||
* We need to do type checking here as `key` expression maybe unresolved.
|
||||
*/
|
||||
case class GetMapValue(child: Expression, key: Expression)
|
||||
case class GetMapValue(
|
||||
child: Expression,
|
||||
key: Expression,
|
||||
failOnError: Boolean = SQLConf.get.ansiEnabled)
|
||||
extends GetMapValueUtil with ExtractValue with NullIntolerant {
|
||||
|
||||
def this(child: Expression, key: Expression) = this(child, key, SQLConf.get.ansiEnabled)
|
||||
|
||||
@transient private lazy val ordering: Ordering[Any] =
|
||||
TypeUtils.getInterpretedOrdering(keyType)
|
||||
|
||||
|
@ -442,10 +472,10 @@ case class GetMapValue(child: Expression, key: Expression)
|
|||
|
||||
// todo: current search is O(n), improve it.
|
||||
override def nullSafeEval(value: Any, ordinal: Any): Any = {
|
||||
getValueEval(value, ordinal, keyType, ordering)
|
||||
getValueEval(value, ordinal, keyType, ordering, failOnError)
|
||||
}
|
||||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
doGetValueGenCode(ctx, ev, child.dataType.asInstanceOf[MapType])
|
||||
doGetValueGenCode(ctx, ev, child.dataType.asInstanceOf[MapType], failOnError)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -71,7 +71,7 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] {
|
|||
// out of bounds, mimic the runtime behavior and return null
|
||||
Literal(null, ga.dataType)
|
||||
}
|
||||
case GetMapValue(CreateMap(elems, _), key) => CaseKeyWhen(key, elems)
|
||||
case GetMapValue(CreateMap(elems, _), key, _) => CaseKeyWhen(key, elems)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1915,4 +1915,19 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("SPARK-33460: element_at NoSuchElementException") {
|
||||
Seq(true, false).foreach { ansiEnabled =>
|
||||
withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
|
||||
val map = Literal.create(Map(1 -> "a", 2 -> "b"), MapType(IntegerType, StringType))
|
||||
val expr: Expression = ElementAt(map, Literal(5))
|
||||
if (ansiEnabled) {
|
||||
val errMsg = "Key 5 does not exist."
|
||||
checkExceptionInExpression[Exception](expr, errMsg)
|
||||
} else {
|
||||
checkEvaluation(expr, null)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -85,6 +85,23 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
}
|
||||
}
|
||||
|
||||
test("SPARK-33460: GetMapValue NoSuchElementException") {
|
||||
Seq(true, false).foreach { ansiEnabled =>
|
||||
withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
|
||||
val map = Literal.create(Map(1 -> "a", 2 -> "b"), MapType(IntegerType, StringType))
|
||||
|
||||
if (ansiEnabled) {
|
||||
checkExceptionInExpression[Exception](
|
||||
GetMapValue(map, Literal(5)),
|
||||
"Key 5 does not exist."
|
||||
)
|
||||
} else {
|
||||
checkEvaluation(GetMapValue(map, Literal(5)), null)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("SPARK-26637 handles GetArrayItem nullability correctly when input array size is constant") {
|
||||
// CreateArray case
|
||||
val a = AttributeReference("a", IntegerType, nullable = false)()
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
--IMPORT map.sql
|
5
sql/core/src/test/resources/sql-tests/inputs/map.sql
Normal file
5
sql/core/src/test/resources/sql-tests/inputs/map.sql
Normal file
|
@ -0,0 +1,5 @@
|
|||
-- test cases for map functions
|
||||
|
||||
-- key does not exist
|
||||
select element_at(map(1, 'a', 2, 'b'), 5);
|
||||
select map(1, 'a', 2, 'b')[5];
|
|
@ -0,0 +1,20 @@
|
|||
-- Automatically generated by SQLQueryTestSuite
|
||||
-- Number of queries: 2
|
||||
|
||||
|
||||
-- !query
|
||||
select element_at(map(1, 'a', 2, 'b'), 5)
|
||||
-- !query schema
|
||||
struct<>
|
||||
-- !query output
|
||||
java.util.NoSuchElementException
|
||||
Key 5 does not exist.
|
||||
|
||||
|
||||
-- !query
|
||||
select map(1, 'a', 2, 'b')[5]
|
||||
-- !query schema
|
||||
struct<>
|
||||
-- !query output
|
||||
java.util.NoSuchElementException
|
||||
Key 5 does not exist.
|
18
sql/core/src/test/resources/sql-tests/results/map.sql.out
Normal file
18
sql/core/src/test/resources/sql-tests/results/map.sql.out
Normal file
|
@ -0,0 +1,18 @@
|
|||
-- Automatically generated by SQLQueryTestSuite
|
||||
-- Number of queries: 2
|
||||
|
||||
|
||||
-- !query
|
||||
select element_at(map(1, 'a', 2, 'b'), 5)
|
||||
-- !query schema
|
||||
struct<element_at(map(1, a, 2, b), 5):string>
|
||||
-- !query output
|
||||
NULL
|
||||
|
||||
|
||||
-- !query
|
||||
select map(1, 'a', 2, 'b')[5]
|
||||
-- !query schema
|
||||
struct<map(1, a, 2, b)[5]:string>
|
||||
-- !query output
|
||||
NULL
|
Loading…
Reference in a new issue