[SPARK-33386][SQL] Accessing array elements in ElementAt/Elt/GetArrayItem should fail if index is out of bounds

### What changes were proposed in this pull request?

Instead of returning NULL, the `element_at`, `elt`, and `GetArrayItem` functions now throw a runtime `ArrayIndexOutOfBoundsException` when ANSI mode is enabled and the index is out of bounds.

### Why are the changes needed?

To conform to ANSI SQL semantics: under ANSI mode, an out-of-bounds array access should fail with an error rather than silently return NULL.

### Does this PR introduce any user-facing change?

When `spark.sql.ansi.enabled` is set to true, Spark throws `ArrayIndexOutOfBoundsException` when an out-of-range index is used to access array elements.
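
For example, a minimal sketch assuming an existing `SparkSession` named `spark` (expected outputs are taken from the golden files in this PR):

```scala
// ANSI mode on: an out-of-bound index now throws at runtime.
spark.conf.set("spark.sql.ansi.enabled", "true")
spark.sql("SELECT element_at(array(1, 2, 3), 5)").collect()
// java.lang.ArrayIndexOutOfBoundsException: Invalid index: 5, numElements: 3

// ANSI mode off (the default): the same query still returns NULL.
spark.conf.set("spark.sql.ansi.enabled", "false")
spark.sql("SELECT element_at(array(1, 2, 3), 5)").collect()
// Array([null])
```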

### How was this patch tested?

Added new unit tests and verified with existing ones.

Closes #30297 from leanken/leanken-SPARK-33386.

Authored-by: xuewei.linxuewei <xuewei.linxuewei@alibaba-inc.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
xuewei.linxuewei 2020-11-12 08:50:32 +00:00 committed by Wenchen Fan
parent 22baf05a9e
commit 6d31daeb6a
16 changed files with 579 additions and 99 deletions


@@ -110,7 +110,14 @@ SELECT * FROM t;
### SQL Functions
The behavior of some SQL functions can be different under ANSI mode (`spark.sql.ansi.enabled=true`).
- `size`: This function returns null for null input under ANSI mode.
- `size`: This function returns null for null input.
- `element_at`: This function throws `ArrayIndexOutOfBoundsException` if using invalid indices.
- `elt`: This function throws `ArrayIndexOutOfBoundsException` if using invalid indices.
### SQL Operators
The behavior of some SQL operators can be different under ANSI mode (`spark.sql.ansi.enabled=true`).
- `array_col[index]`: This operator throws `ArrayIndexOutOfBoundsException` if using invalid indices.
### SQL Keywords


@@ -840,8 +840,8 @@ object TypeCoercion {
plan resolveOperators { case p =>
p transformExpressionsUp {
// Skip nodes if unresolved or not enough children
case c @ Elt(children) if !c.childrenResolved || children.size < 2 => c
case c @ Elt(children) =>
case c @ Elt(children, _) if !c.childrenResolved || children.size < 2 => c
case c @ Elt(children, _) =>
val index = children.head
val newIndex = ImplicitTypeCasts.implicitCast(index, IntegerType).getOrElse(index)
val newInputs = if (conf.eltOutputAsString ||


@@ -34,8 +34,10 @@ case class ProjectionOverSchema(schema: StructType) {
expr match {
case a: AttributeReference if fieldNames.contains(a.name) =>
Some(a.copy(dataType = schema(a.name).dataType)(a.exprId, a.qualifier))
case GetArrayItem(child, arrayItemOrdinal) =>
getProjection(child).map { projection => GetArrayItem(projection, arrayItemOrdinal) }
case GetArrayItem(child, arrayItemOrdinal, failOnError) =>
getProjection(child).map {
projection => GetArrayItem(projection, arrayItemOrdinal, failOnError)
}
case a: GetArrayStructFields =>
getProjection(a.child).map(p => (p, p.dataType)).map {
case (projection, ArrayType(projSchema @ StructType(_), _)) =>


@@ -119,7 +119,7 @@ object SelectedField {
throw new AnalysisException(s"DataType '$x' is not supported by MapKeys.")
}
selectField(child, opt)
case GetArrayItem(child, _) =>
case GetArrayItem(child, _, _) =>
// GetArrayItem does not select a field from a struct (i.e. prune the struct) so it can't be
// the top-level extractor. However it can be part of an extractor chain.
val ArrayType(_, containsNull) = child.dataType


@@ -1906,8 +1906,10 @@ case class ArrayPosition(left: Expression, right: Expression)
@ExpressionDescription(
usage = """
_FUNC_(array, index) - Returns element of array at given (1-based) index. If index < 0,
accesses elements from the last to the first. Returns NULL if the index exceeds the length
of the array.
accesses elements from the last to the first. The function returns NULL
if the index exceeds the length of the array and `spark.sql.ansi.enabled` is set to false.
If `spark.sql.ansi.enabled` is set to true, it throws ArrayIndexOutOfBoundsException
for invalid indices.
_FUNC_(map, key) - Returns value for given key, or NULL if the key is not contained in the map
""",
@@ -1919,9 +1921,14 @@ case class ArrayPosition(left: Expression, right: Expression)
b
""",
since = "2.4.0")
case class ElementAt(left: Expression, right: Expression)
case class ElementAt(
left: Expression,
right: Expression,
failOnError: Boolean = SQLConf.get.ansiEnabled)
extends GetMapValueUtil with GetArrayItemUtil with NullIntolerant {
def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled)
@transient private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType
@transient private lazy val arrayContainsNull = left.dataType.asInstanceOf[ArrayType].containsNull
@@ -1969,7 +1976,7 @@ case class ElementAt(left: Expression, right: Expression)
if (ordinal == 0) {
false
} else if (elements.length < math.abs(ordinal)) {
true
!failOnError
} else {
if (ordinal < 0) {
elements(elements.length + ordinal).nullable
@@ -1979,24 +1986,9 @@ case class ElementAt(left: Expression, right: Expression)
}
}
override def computeNullabilityFromArray(child: Expression, ordinal: Expression): Boolean = {
if (ordinal.foldable && !ordinal.nullable) {
val intOrdinal = ordinal.eval().asInstanceOf[Number].intValue()
child match {
case CreateArray(ar, _) =>
nullability(ar, intOrdinal)
case GetArrayStructFields(CreateArray(elements, _), field, _, _, _) =>
nullability(elements, intOrdinal) || field.nullable
case _ =>
true
}
} else {
true
}
}
override def nullable: Boolean = left.dataType match {
case _: ArrayType => computeNullabilityFromArray(left, right)
case _: ArrayType =>
computeNullabilityFromArray(left, right, failOnError, nullability)
case _: MapType => true
}
@@ -2008,7 +2000,12 @@ case class ElementAt(left: Expression, right: Expression)
val array = value.asInstanceOf[ArrayData]
val index = ordinal.asInstanceOf[Int]
if (array.numElements() < math.abs(index)) {
null
if (failOnError) {
throw new ArrayIndexOutOfBoundsException(
s"Invalid index: $index, numElements: ${array.numElements()}")
} else {
null
}
} else {
val idx = if (index == 0) {
throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1")
@@ -2042,10 +2039,20 @@ case class ElementAt(left: Expression, right: Expression)
} else {
""
}
val indexOutOfBoundBranch = if (failOnError) {
s"""throw new ArrayIndexOutOfBoundsException(
| "Invalid index: " + $index + ", numElements: " + $eval1.numElements()
|);
""".stripMargin
} else {
s"${ev.isNull} = true;"
}
s"""
|int $index = (int) $eval2;
|if ($eval1.numElements() < Math.abs($index)) {
| ${ev.isNull} = true;
| $indexOutOfBoundBranch
|} else {
| if ($index == 0) {
| throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1");
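
A minimal interpreted-evaluation sketch of the new `failOnError` flag (hypothetical REPL usage; the expected results mirror the `CollectionExpressionsSuite` test added below):

```scala
import org.apache.spark.sql.catalyst.expressions.{ElementAt, Literal}
import org.apache.spark.sql.types.{ArrayType, IntegerType}

val array = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))

// Legacy behavior: an out-of-bound ordinal evaluates to null.
ElementAt(array, Literal(5), failOnError = false).eval()  // null

// ANSI behavior: the same ordinal throws.
ElementAt(array, Literal(5), failOnError = true).eval()
// java.lang.ArrayIndexOutOfBoundsException: Invalid index: 5, numElements: 3
```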


@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData, TypeUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -222,10 +223,15 @@ case class GetArrayStructFields(
*
* We need to do type checking here as the `ordinal` expression may be unresolved.
*/
case class GetArrayItem(child: Expression, ordinal: Expression)
case class GetArrayItem(
child: Expression,
ordinal: Expression,
failOnError: Boolean = SQLConf.get.ansiEnabled)
extends BinaryExpression with GetArrayItemUtil with ExpectsInputTypes with ExtractValue
with NullIntolerant {
def this(child: Expression, ordinal: Expression) = this(child, ordinal, SQLConf.get.ansiEnabled)
// We have done type checking for child in `ExtractValue`, so only need to check the `ordinal`.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType)
@@ -234,13 +240,29 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
override def left: Expression = child
override def right: Expression = ordinal
override def nullable: Boolean = computeNullabilityFromArray(left, right)
override def nullable: Boolean =
computeNullabilityFromArray(left, right, failOnError, nullability)
override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType
private def nullability(elements: Seq[Expression], ordinal: Int): Boolean = {
if (ordinal >= 0 && ordinal < elements.length) {
elements(ordinal).nullable
} else {
!failOnError
}
}
protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
val baseValue = value.asInstanceOf[ArrayData]
val index = ordinal.asInstanceOf[Number].intValue()
if (index >= baseValue.numElements() || index < 0 || baseValue.isNullAt(index)) {
if (index >= baseValue.numElements() || index < 0) {
if (failOnError) {
throw new ArrayIndexOutOfBoundsException(
s"Invalid index: $index, numElements: ${baseValue.numElements()}")
} else {
null
}
} else if (baseValue.isNullAt(index)) {
null
} else {
baseValue.get(index, dataType)
@@ -251,15 +273,28 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
val index = ctx.freshName("index")
val nullCheck = if (child.dataType.asInstanceOf[ArrayType].containsNull) {
s" || $eval1.isNullAt($index)"
s"""else if ($eval1.isNullAt($index)) {
${ev.isNull} = true;
}
"""
} else {
""
}
val indexOutOfBoundBranch = if (failOnError) {
s"""throw new ArrayIndexOutOfBoundsException(
| "Invalid index: " + $index + ", numElements: " + $eval1.numElements()
|);
""".stripMargin
} else {
s"${ev.isNull} = true;"
}
s"""
final int $index = (int) $eval2;
if ($index >= $eval1.numElements() || $index < 0$nullCheck) {
${ev.isNull} = true;
} else {
if ($index >= $eval1.numElements() || $index < 0) {
$indexOutOfBoundBranch
} $nullCheck else {
${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)};
}
"""
@@ -273,20 +308,24 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
trait GetArrayItemUtil {
/** `Null` is returned for invalid ordinals. */
protected def computeNullabilityFromArray(child: Expression, ordinal: Expression): Boolean = {
protected def computeNullabilityFromArray(
child: Expression,
ordinal: Expression,
failOnError: Boolean,
nullability: (Seq[Expression], Int) => Boolean): Boolean = {
val arrayContainsNull = child.dataType.asInstanceOf[ArrayType].containsNull
if (ordinal.foldable && !ordinal.nullable) {
val intOrdinal = ordinal.eval().asInstanceOf[Number].intValue()
child match {
case CreateArray(ar, _) if intOrdinal < ar.length =>
ar(intOrdinal).nullable
case GetArrayStructFields(CreateArray(elements, _), field, _, _, _)
if intOrdinal < elements.length =>
elements(intOrdinal).nullable || field.nullable
case CreateArray(ar, _) =>
nullability(ar, intOrdinal)
case GetArrayStructFields(CreateArray(elements, _), field, _, _, _) =>
nullability(elements, intOrdinal) || field.nullable
case _ =>
true
}
} else {
true
if (failOnError) arrayContainsNull else true
}
}
}
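
A sketch of the refactored nullability computation with a foldable ordinal (hypothetical REPL usage, grounded in the nullability tests further down):

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, CreateArray, GetArrayItem, Literal}
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType, nullable = false)()
val b = AttributeReference("b", IntegerType, nullable = true)()
val array = CreateArray(a :: b :: Nil)

// In-bound ordinals (0-based): nullability follows the selected element.
GetArrayItem(array, Literal(0)).nullable  // false (`a` is non-nullable)
GetArrayItem(array, Literal(1)).nullable  // true  (`b` is nullable)

// Out-of-bound ordinal: the result can only be null when the expression
// returns null instead of throwing, i.e. when failOnError is false.
GetArrayItem(array, Literal(2), failOnError = false).nullable  // true
GetArrayItem(array, Literal(2), failOnError = true).nullable   // false
```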


@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.UTF8StringBuilder
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
@@ -231,7 +232,12 @@ case class ConcatWs(children: Seq[Expression])
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(n, input1, input2, ...) - Returns the `n`-th input, e.g., returns `input2` when `n` is 2.",
usage = """
_FUNC_(n, input1, input2, ...) - Returns the `n`-th input, e.g., returns `input2` when `n` is 2.
The function returns NULL if the index exceeds the length of the array
and `spark.sql.ansi.enabled` is set to false. If `spark.sql.ansi.enabled` is set to true,
it throws ArrayIndexOutOfBoundsException for invalid indices.
""",
examples = """
Examples:
> SELECT _FUNC_(1, 'scala', 'java');
@@ -239,7 +245,11 @@ case class ConcatWs(children: Seq[Expression])
""",
since = "2.0.0")
// scalastyle:on line.size.limit
case class Elt(children: Seq[Expression]) extends Expression {
case class Elt(
children: Seq[Expression],
failOnError: Boolean = SQLConf.get.ansiEnabled) extends Expression {
def this(children: Seq[Expression]) = this(children, SQLConf.get.ansiEnabled)
private lazy val indexExpr = children.head
private lazy val inputExprs = children.tail.toArray
@@ -275,7 +285,12 @@ case class Elt(children: Seq[Expression]) extends Expression {
} else {
val index = indexObj.asInstanceOf[Int]
if (index <= 0 || index > inputExprs.length) {
null
if (failOnError) {
throw new ArrayIndexOutOfBoundsException(
s"Invalid index: $index, numElements: ${inputExprs.length}")
} else {
null
}
} else {
inputExprs(index - 1).eval(input)
}
@@ -323,6 +338,17 @@ case class Elt(children: Seq[Expression]) extends Expression {
""".stripMargin
}.mkString)
val indexOutOfBoundBranch = if (failOnError) {
s"""
|if (!$indexMatched) {
| throw new ArrayIndexOutOfBoundsException(
| "Invalid index: " + ${index.value} + ", numElements: " + ${inputExprs.length});
|}
""".stripMargin
} else {
""
}
ev.copy(
code"""
|${index.code}
@@ -332,6 +358,7 @@ case class Elt(children: Seq[Expression]) extends Expression {
|do {
| $codes
|} while (false);
|$indexOutOfBoundBranch
|final ${CodeGenerator.javaType(dataType)} ${ev.value} = $inputVal;
|final boolean ${ev.isNull} = ${ev.value} == null;
""".stripMargin)


@@ -61,7 +61,7 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] {
CreateArray(elems.map(GetStructField(_, ordinal, Some(field.name))), useStringTypeWhenEmpty)
// Remove redundant map lookup.
case ga @ GetArrayItem(CreateArray(elems, _), IntegerLiteral(idx)) =>
case ga @ GetArrayItem(CreateArray(elems, _), IntegerLiteral(idx), _) =>
// Instead of creating the array and then selecting one row, remove array creation
// altogether.
if (idx >= 0 && idx < elems.size) {


@@ -2144,9 +2144,10 @@ object SQLConf {
val ANSI_ENABLED = buildConf("spark.sql.ansi.enabled")
.doc("When true, Spark tries to conform to the ANSI SQL specification: 1. Spark will " +
"throw a runtime exception if an overflow occurs in any operation on integral/decimal " +
"field. 2. Spark will forbid using the reserved keywords of ANSI SQL as identifiers in " +
"the SQL parser.")
"throw an exception at runtime if the inputs to a SQL operator/function are invalid, " +
"e.g. overflow in arithmetic operations, out-of-range index when accessing array elements. " +
"2. Spark will forbid using the reserved keywords of ANSI SQL as identifiers in " +
"the SQL parser. 3. Spark will return NULL for null input for function `size`.")
.version("3.0.0")
.booleanConf
.createWithDefault(false)
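
End to end, the updated ANSI flag also covers the indexing operator; a minimal sketch assuming a `SparkSession` named `spark`:

```scala
spark.conf.set("spark.sql.ansi.enabled", "true")

spark.sql("SELECT array(1, 2, 3)[5]").collect()
// java.lang.ArrayIndexOutOfBoundsException: Invalid index: 5, numElements: 3

spark.sql("SELECT elt(0, '123', '456')").collect()
// java.lang.ArrayIndexOutOfBoundsException: Invalid index: 0, numElements: 2
```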


@@ -1118,58 +1118,62 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
}
test("correctly handles ElementAt nullability for arrays") {
// CreateArray case
val a = AttributeReference("a", IntegerType, nullable = false)()
val b = AttributeReference("b", IntegerType, nullable = true)()
val array = CreateArray(a :: b :: Nil)
assert(!ElementAt(array, Literal(1)).nullable)
assert(!ElementAt(array, Literal(-2)).nullable)
assert(ElementAt(array, Literal(2)).nullable)
assert(ElementAt(array, Literal(-1)).nullable)
assert(!ElementAt(array, Subtract(Literal(2), Literal(1))).nullable)
assert(ElementAt(array, AttributeReference("ordinal", IntegerType)()).nullable)
Seq(true, false).foreach { ansiEnabled =>
withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
// CreateArray case
val a = AttributeReference("a", IntegerType, nullable = false)()
val b = AttributeReference("b", IntegerType, nullable = true)()
val array = CreateArray(a :: b :: Nil)
assert(!ElementAt(array, Literal(1)).nullable)
assert(!ElementAt(array, Literal(-2)).nullable)
assert(ElementAt(array, Literal(2)).nullable)
assert(ElementAt(array, Literal(-1)).nullable)
assert(!ElementAt(array, Subtract(Literal(2), Literal(1))).nullable)
assert(ElementAt(array, AttributeReference("ordinal", IntegerType)()).nullable)
// CreateArray case invalid indices
assert(!ElementAt(array, Literal(0)).nullable)
assert(ElementAt(array, Literal(4)).nullable)
assert(ElementAt(array, Literal(-4)).nullable)
// CreateArray case invalid indices
assert(!ElementAt(array, Literal(0)).nullable)
assert(ElementAt(array, Literal(4)).nullable == !ansiEnabled)
assert(ElementAt(array, Literal(-4)).nullable == !ansiEnabled)
// GetArrayStructFields case
val f1 = StructField("a", IntegerType, nullable = false)
val f2 = StructField("b", IntegerType, nullable = true)
val structType = StructType(f1 :: f2 :: Nil)
val c = AttributeReference("c", structType, nullable = false)()
val inputArray1 = CreateArray(c :: Nil)
val inputArray1ContainsNull = c.nullable
val stArray1 = GetArrayStructFields(inputArray1, f1, 0, 2, inputArray1ContainsNull)
assert(!ElementAt(stArray1, Literal(1)).nullable)
assert(!ElementAt(stArray1, Literal(-1)).nullable)
val stArray2 = GetArrayStructFields(inputArray1, f2, 1, 2, inputArray1ContainsNull)
assert(ElementAt(stArray2, Literal(1)).nullable)
assert(ElementAt(stArray2, Literal(-1)).nullable)
// GetArrayStructFields case
val f1 = StructField("a", IntegerType, nullable = false)
val f2 = StructField("b", IntegerType, nullable = true)
val structType = StructType(f1 :: f2 :: Nil)
val c = AttributeReference("c", structType, nullable = false)()
val inputArray1 = CreateArray(c :: Nil)
val inputArray1ContainsNull = c.nullable
val stArray1 = GetArrayStructFields(inputArray1, f1, 0, 2, inputArray1ContainsNull)
assert(!ElementAt(stArray1, Literal(1)).nullable)
assert(!ElementAt(stArray1, Literal(-1)).nullable)
val stArray2 = GetArrayStructFields(inputArray1, f2, 1, 2, inputArray1ContainsNull)
assert(ElementAt(stArray2, Literal(1)).nullable)
assert(ElementAt(stArray2, Literal(-1)).nullable)
val d = AttributeReference("d", structType, nullable = true)()
val inputArray2 = CreateArray(c :: d :: Nil)
val inputArray2ContainsNull = c.nullable || d.nullable
val stArray3 = GetArrayStructFields(inputArray2, f1, 0, 2, inputArray2ContainsNull)
assert(!ElementAt(stArray3, Literal(1)).nullable)
assert(!ElementAt(stArray3, Literal(-2)).nullable)
assert(ElementAt(stArray3, Literal(2)).nullable)
assert(ElementAt(stArray3, Literal(-1)).nullable)
val stArray4 = GetArrayStructFields(inputArray2, f2, 1, 2, inputArray2ContainsNull)
assert(ElementAt(stArray4, Literal(1)).nullable)
assert(ElementAt(stArray4, Literal(-2)).nullable)
assert(ElementAt(stArray4, Literal(2)).nullable)
assert(ElementAt(stArray4, Literal(-1)).nullable)
val d = AttributeReference("d", structType, nullable = true)()
val inputArray2 = CreateArray(c :: d :: Nil)
val inputArray2ContainsNull = c.nullable || d.nullable
val stArray3 = GetArrayStructFields(inputArray2, f1, 0, 2, inputArray2ContainsNull)
assert(!ElementAt(stArray3, Literal(1)).nullable)
assert(!ElementAt(stArray3, Literal(-2)).nullable)
assert(ElementAt(stArray3, Literal(2)).nullable)
assert(ElementAt(stArray3, Literal(-1)).nullable)
val stArray4 = GetArrayStructFields(inputArray2, f2, 1, 2, inputArray2ContainsNull)
assert(ElementAt(stArray4, Literal(1)).nullable)
assert(ElementAt(stArray4, Literal(-2)).nullable)
assert(ElementAt(stArray4, Literal(2)).nullable)
assert(ElementAt(stArray4, Literal(-1)).nullable)
// GetArrayStructFields case invalid indices
assert(!ElementAt(stArray3, Literal(0)).nullable)
assert(ElementAt(stArray3, Literal(4)).nullable)
assert(ElementAt(stArray3, Literal(-4)).nullable)
// GetArrayStructFields case invalid indices
assert(!ElementAt(stArray3, Literal(0)).nullable)
assert(ElementAt(stArray3, Literal(4)).nullable == !ansiEnabled)
assert(ElementAt(stArray3, Literal(-4)).nullable == !ansiEnabled)
assert(ElementAt(stArray4, Literal(0)).nullable)
assert(ElementAt(stArray4, Literal(4)).nullable)
assert(ElementAt(stArray4, Literal(-4)).nullable)
assert(ElementAt(stArray4, Literal(0)).nullable)
assert(ElementAt(stArray4, Literal(4)).nullable)
assert(ElementAt(stArray4, Literal(-4)).nullable)
}
}
}
test("Concat") {
@@ -1883,4 +1887,32 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
Literal(stringToInterval("interval 1 year"))),
Seq(Date.valueOf("2018-01-01")))
}
test("SPARK-33386: element_at ArrayIndexOutOfBoundsException") {
Seq(true, false).foreach { ansiEnabled =>
withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
val array = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
var expr: Expression = ElementAt(array, Literal(5))
if (ansiEnabled) {
val errMsg = "Invalid index: 5, numElements: 3"
checkExceptionInExpression[Exception](expr, errMsg)
} else {
checkEvaluation(expr, null)
}
expr = ElementAt(array, Literal(-5))
if (ansiEnabled) {
val errMsg = "Invalid index: -5, numElements: 3"
checkExceptionInExpression[Exception](expr, errMsg)
} else {
checkEvaluation(expr, null)
}
// The "SQL array indices start at 1" exception is thrown in both modes.
expr = ElementAt(array, Literal(0))
val errMsg = "SQL array indices start at 1"
checkExceptionInExpression[Exception](expr, errMsg)
}
}
}
}


@@ -62,6 +62,29 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(GetArrayItem(nestedArray, Literal(0)), Seq(1))
}
test("SPARK-33386: GetArrayItem ArrayIndexOutOfBoundsException") {
Seq(true, false).foreach { ansiEnabled =>
withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
val array = Literal.create(Seq("a", "b"), ArrayType(StringType))
if (ansiEnabled) {
checkExceptionInExpression[Exception](
GetArrayItem(array, Literal(5)),
"Invalid index: 5, numElements: 2"
)
checkExceptionInExpression[Exception](
GetArrayItem(array, Literal(-1)),
"Invalid index: -1, numElements: 2"
)
} else {
checkEvaluation(GetArrayItem(array, Literal(5)), null)
checkEvaluation(GetArrayItem(array, Literal(-1)), null)
}
}
}
}
test("SPARK-26637 handles GetArrayItem nullability correctly when input array size is constant") {
// CreateArray case
val a = AttributeReference("a", IntegerType, nullable = false)()


@@ -18,9 +18,9 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -968,4 +968,34 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
GenerateUnsafeProjection.generate(
Sentences(Literal("\"quote"), Literal("\"quote"), Literal("\"quote")) :: Nil)
}
test("SPARK-33386: elt ArrayIndexOutOfBoundsException") {
Seq(true, false).foreach { ansiEnabled =>
withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
var expr: Expression = Elt(Seq(Literal(4), Literal("123"), Literal("456")))
if (ansiEnabled) {
val errMsg = "Invalid index: 4, numElements: 2"
checkExceptionInExpression[Exception](expr, errMsg)
} else {
checkEvaluation(expr, null)
}
expr = Elt(Seq(Literal(0), Literal("123"), Literal("456")))
if (ansiEnabled) {
val errMsg = "Invalid index: 0, numElements: 2"
checkExceptionInExpression[Exception](expr, errMsg)
} else {
checkEvaluation(expr, null)
}
expr = Elt(Seq(Literal(-1), Literal("123"), Literal("456")))
if (ansiEnabled) {
val errMsg = "Invalid index: -1, numElements: 2"
checkExceptionInExpression[Exception](expr, errMsg)
} else {
checkEvaluation(expr, null)
}
}
}
}
}


@@ -0,0 +1 @@
--IMPORT array.sql


@@ -90,3 +90,15 @@ select
size(date_array),
size(timestamp_array)
from primitive_arrays;
-- index out of range for array elements
select element_at(array(1, 2, 3), 5);
select element_at(array(1, 2, 3), -5);
select element_at(array(1, 2, 3), 0);
select elt(4, '123', '456');
select elt(0, '123', '456');
select elt(-1, '123', '456');
select array(1, 2, 3)[5];
select array(1, 2, 3)[-1];


@@ -0,0 +1,234 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 20
-- !query
create temporary view data as select * from values
("one", array(11, 12, 13), array(array(111, 112, 113), array(121, 122, 123))),
("two", array(21, 22, 23), array(array(211, 212, 213), array(221, 222, 223)))
as data(a, b, c)
-- !query schema
struct<>
-- !query output
-- !query
select * from data
-- !query schema
struct<a:string,b:array<int>,c:array<array<int>>>
-- !query output
one [11,12,13] [[111,112,113],[121,122,123]]
two [21,22,23] [[211,212,213],[221,222,223]]
-- !query
select a, b[0], b[0] + b[1] from data
-- !query schema
struct<a:string,b[0]:int,(b[0] + b[1]):int>
-- !query output
one 11 23
two 21 43
-- !query
select a, c[0][0] + c[0][0 + 1] from data
-- !query schema
struct<a:string,(c[0][0] + c[0][(0 + 1)]):int>
-- !query output
one 223
two 423
-- !query
create temporary view primitive_arrays as select * from values (
array(true),
array(2Y, 1Y),
array(2S, 1S),
array(2, 1),
array(2L, 1L),
array(9223372036854775809, 9223372036854775808),
array(2.0D, 1.0D),
array(float(2.0), float(1.0)),
array(date '2016-03-14', date '2016-03-13'),
array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000')
) as primitive_arrays(
boolean_array,
tinyint_array,
smallint_array,
int_array,
bigint_array,
decimal_array,
double_array,
float_array,
date_array,
timestamp_array
)
-- !query schema
struct<>
-- !query output
-- !query
select * from primitive_arrays
-- !query schema
struct<boolean_array:array<boolean>,tinyint_array:array<tinyint>,smallint_array:array<smallint>,int_array:array<int>,bigint_array:array<bigint>,decimal_array:array<decimal(19,0)>,double_array:array<double>,float_array:array<float>,date_array:array<date>,timestamp_array:array<timestamp>>
-- !query output
[true] [2,1] [2,1] [2,1] [2,1] [9223372036854775809,9223372036854775808] [2.0,1.0] [2.0,1.0] [2016-03-14,2016-03-13] [2016-11-15 20:54:00,2016-11-12 20:54:00]
-- !query
select
array_contains(boolean_array, true), array_contains(boolean_array, false),
array_contains(tinyint_array, 2Y), array_contains(tinyint_array, 0Y),
array_contains(smallint_array, 2S), array_contains(smallint_array, 0S),
array_contains(int_array, 2), array_contains(int_array, 0),
array_contains(bigint_array, 2L), array_contains(bigint_array, 0L),
array_contains(decimal_array, 9223372036854775809), array_contains(decimal_array, 1),
array_contains(double_array, 2.0D), array_contains(double_array, 0.0D),
array_contains(float_array, float(2.0)), array_contains(float_array, float(0.0)),
array_contains(date_array, date '2016-03-14'), array_contains(date_array, date '2016-01-01'),
array_contains(timestamp_array, timestamp '2016-11-15 20:54:00.000'), array_contains(timestamp_array, timestamp '2016-01-01 20:54:00.000')
from primitive_arrays
-- !query schema
struct<array_contains(boolean_array, true):boolean,array_contains(boolean_array, false):boolean,array_contains(tinyint_array, 2):boolean,array_contains(tinyint_array, 0):boolean,array_contains(smallint_array, 2):boolean,array_contains(smallint_array, 0):boolean,array_contains(int_array, 2):boolean,array_contains(int_array, 0):boolean,array_contains(bigint_array, 2):boolean,array_contains(bigint_array, 0):boolean,array_contains(decimal_array, 9223372036854775809):boolean,array_contains(decimal_array, CAST(1 AS DECIMAL(19,0))):boolean,array_contains(double_array, 2.0):boolean,array_contains(double_array, 0.0):boolean,array_contains(float_array, CAST(2.0 AS FLOAT)):boolean,array_contains(float_array, CAST(0.0 AS FLOAT)):boolean,array_contains(date_array, DATE '2016-03-14'):boolean,array_contains(date_array, DATE '2016-01-01'):boolean,array_contains(timestamp_array, TIMESTAMP '2016-11-15 20:54:00'):boolean,array_contains(timestamp_array, TIMESTAMP '2016-01-01 20:54:00'):boolean>
-- !query output
true false true false true false true false true false true false true false true false true false true false
-- !query
select array_contains(b, 11), array_contains(c, array(111, 112, 113)) from data
-- !query schema
struct<array_contains(b, 11):boolean,array_contains(c, array(111, 112, 113)):boolean>
-- !query output
false false
true true
-- !query
select
sort_array(boolean_array),
sort_array(tinyint_array),
sort_array(smallint_array),
sort_array(int_array),
sort_array(bigint_array),
sort_array(decimal_array),
sort_array(double_array),
sort_array(float_array),
sort_array(date_array),
sort_array(timestamp_array)
from primitive_arrays
-- !query schema
struct<sort_array(boolean_array, true):array<boolean>,sort_array(tinyint_array, true):array<tinyint>,sort_array(smallint_array, true):array<smallint>,sort_array(int_array, true):array<int>,sort_array(bigint_array, true):array<bigint>,sort_array(decimal_array, true):array<decimal(19,0)>,sort_array(double_array, true):array<double>,sort_array(float_array, true):array<float>,sort_array(date_array, true):array<date>,sort_array(timestamp_array, true):array<timestamp>>
-- !query output
[true] [1,2] [1,2] [1,2] [1,2] [9223372036854775808,9223372036854775809] [1.0,2.0] [1.0,2.0] [2016-03-13,2016-03-14] [2016-11-12 20:54:00,2016-11-15 20:54:00]
-- !query
select sort_array(array('b', 'd'), '1')
-- !query schema
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
cannot resolve 'sort_array(array('b', 'd'), '1')' due to data type mismatch: Sort order in second argument requires a boolean literal.; line 1 pos 7
-- !query
select sort_array(array('b', 'd'), cast(NULL as boolean))
-- !query schema
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
cannot resolve 'sort_array(array('b', 'd'), CAST(NULL AS BOOLEAN))' due to data type mismatch: Sort order in second argument requires a boolean literal.; line 1 pos 7
-- !query
select
size(boolean_array),
size(tinyint_array),
size(smallint_array),
size(int_array),
size(bigint_array),
size(decimal_array),
size(double_array),
size(float_array),
size(date_array),
size(timestamp_array)
from primitive_arrays
-- !query schema
struct<size(boolean_array):int,size(tinyint_array):int,size(smallint_array):int,size(int_array):int,size(bigint_array):int,size(decimal_array):int,size(double_array):int,size(float_array):int,size(date_array):int,size(timestamp_array):int>
-- !query output
1 2 2 2 2 2 2 2 2 2
-- !query
select element_at(array(1, 2, 3), 5)
-- !query schema
struct<>
-- !query output
java.lang.ArrayIndexOutOfBoundsException
Invalid index: 5, numElements: 3
-- !query
select element_at(array(1, 2, 3), -5)
-- !query schema
struct<>
-- !query output
java.lang.ArrayIndexOutOfBoundsException
Invalid index: -5, numElements: 3
-- !query
select element_at(array(1, 2, 3), 0)
-- !query schema
struct<>
-- !query output
java.lang.ArrayIndexOutOfBoundsException
SQL array indices start at 1
-- !query
select elt(4, '123', '456')
-- !query schema
struct<>
-- !query output
java.lang.ArrayIndexOutOfBoundsException
Invalid index: 4, numElements: 2
-- !query
select elt(0, '123', '456')
-- !query schema
struct<>
-- !query output
java.lang.ArrayIndexOutOfBoundsException
Invalid index: 0, numElements: 2
-- !query
select elt(-1, '123', '456')
-- !query schema
struct<>
-- !query output
java.lang.ArrayIndexOutOfBoundsException
Invalid index: -1, numElements: 2
-- !query
select array(1, 2, 3)[5]
-- !query schema
struct<>
-- !query output
java.lang.ArrayIndexOutOfBoundsException
Invalid index: 5, numElements: 3
-- !query
select array(1, 2, 3)[-1]
-- !query schema
struct<>
-- !query output
java.lang.ArrayIndexOutOfBoundsException
Invalid index: -1, numElements: 3


@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 12
-- Number of queries: 20
-- !query
@@ -160,3 +160,68 @@ from primitive_arrays
struct<size(boolean_array):int,size(tinyint_array):int,size(smallint_array):int,size(int_array):int,size(bigint_array):int,size(decimal_array):int,size(double_array):int,size(float_array):int,size(date_array):int,size(timestamp_array):int>
-- !query output
1 2 2 2 2 2 2 2 2 2
-- !query
select element_at(array(1, 2, 3), 5)
-- !query schema
struct<element_at(array(1, 2, 3), 5):int>
-- !query output
NULL
-- !query
select element_at(array(1, 2, 3), -5)
-- !query schema
struct<element_at(array(1, 2, 3), -5):int>
-- !query output
NULL
-- !query
select element_at(array(1, 2, 3), 0)
-- !query schema
struct<>
-- !query output
java.lang.ArrayIndexOutOfBoundsException
SQL array indices start at 1
-- !query
select elt(4, '123', '456')
-- !query schema
struct<elt(4, 123, 456):string>
-- !query output
NULL
-- !query
select elt(0, '123', '456')
-- !query schema
struct<elt(0, 123, 456):string>
-- !query output
NULL
-- !query
select elt(-1, '123', '456')
-- !query schema
struct<elt(-1, 123, 456):string>
-- !query output
NULL
-- !query
select array(1, 2, 3)[5]
-- !query schema
struct<array(1, 2, 3)[5]:int>
-- !query output
NULL
-- !query
select array(1, 2, 3)[-1]
-- !query schema
struct<array(1, 2, 3)[-1]:int>
-- !query output
NULL