[SPARK-33386][SQL] Accessing array elements in ElementAt/Elt/GetArrayItem should fail if index is out of bounds
### What changes were proposed in this pull request?

Instead of returning NULL, throw a runtime `ArrayIndexOutOfBoundsException` when ANSI mode is enabled for the `element_at`, `elt`, and `GetArrayItem` functions.

### Why are the changes needed?

For ANSI mode.

### Does this PR introduce any user-facing change?

When `spark.sql.ansi.enabled` is true, Spark throws `ArrayIndexOutOfBoundsException` on an out-of-range index when accessing array elements.

### How was this patch tested?

Added UTs; also covered by existing UTs.

Closes #30297 from leanken/leanken-SPARK-33386.

Authored-by: xuewei.linxuewei <xuewei.linxuewei@alibaba-inc.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
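To make the user-facing change concrete, a spark-shell sketch (output abbreviated, assuming a build that contains this patch):

```scala
spark.conf.set("spark.sql.ansi.enabled", false)
spark.sql("SELECT element_at(array(1, 2, 3), 5)").collect()
// Array([null])

spark.conf.set("spark.sql.ansi.enabled", true)
spark.sql("SELECT element_at(array(1, 2, 3), 5)").collect()
// java.lang.ArrayIndexOutOfBoundsException: Invalid index: 5, numElements: 3
```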
Parent: 22baf05a9e
Commit: 6d31daeb6a
```diff
@@ -110,7 +110,14 @@ SELECT * FROM t;
 ### SQL Functions
 
 The behavior of some SQL functions can be different under ANSI mode (`spark.sql.ansi.enabled=true`).
-  - `size`: This function returns null for null input under ANSI mode.
+  - `size`: This function returns null for null input.
+  - `element_at`: This function throws `ArrayIndexOutOfBoundsException` if using invalid indices.
+  - `elt`: This function throws `ArrayIndexOutOfBoundsException` if using invalid indices.
+
+### SQL Operators
+
+The behavior of some SQL operators can be different under ANSI mode (`spark.sql.ansi.enabled=true`).
+  - `array_col[index]`: This operator throws `ArrayIndexOutOfBoundsException` if using invalid indices.
 
 ### SQL Keywords
```
```diff
@@ -840,8 +840,8 @@ object TypeCoercion {
     plan resolveOperators { case p =>
       p transformExpressionsUp {
         // Skip nodes if unresolved or not enough children
-        case c @ Elt(children) if !c.childrenResolved || children.size < 2 => c
-        case c @ Elt(children) =>
+        case c @ Elt(children, _) if !c.childrenResolved || children.size < 2 => c
+        case c @ Elt(children, _) =>
           val index = children.head
           val newIndex = ImplicitTypeCasts.implicitCast(index, IntegerType).getOrElse(index)
           val newInputs = if (conf.eltOutputAsString ||
```
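This rule is otherwise unrelated; it changes only because adding a field to a case class changes the arity of its generated `unapply`, so every pattern-match site must bind (or ignore) the new field. A minimal standalone sketch of the mechanics (illustrative types, not Spark's):

```scala
object PatternAritySketch {
  case class Elt(children: Seq[Int], failOnError: Boolean = false)

  def describe(e: Elt): String = e match {
    // `case Elt(children)` no longer compiles once the second field exists;
    // the extra `_` ignores `failOnError`, preserving the old behavior.
    case Elt(children, _) if children.size < 2 => s"too few children: ${children.size}"
    case Elt(children, _)                      => s"ok: ${children.size} children"
  }

  def main(args: Array[String]): Unit = {
    println(describe(Elt(Seq(1))))       // too few children: 1
    println(describe(Elt(Seq(1, 2, 3)))) // ok: 3 children
  }
}
```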
```diff
@@ -34,8 +34,10 @@ case class ProjectionOverSchema(schema: StructType) {
     expr match {
       case a: AttributeReference if fieldNames.contains(a.name) =>
         Some(a.copy(dataType = schema(a.name).dataType)(a.exprId, a.qualifier))
-      case GetArrayItem(child, arrayItemOrdinal) =>
-        getProjection(child).map { projection => GetArrayItem(projection, arrayItemOrdinal) }
+      case GetArrayItem(child, arrayItemOrdinal, failOnError) =>
+        getProjection(child).map {
+          projection => GetArrayItem(projection, arrayItemOrdinal, failOnError)
+        }
       case a: GetArrayStructFields =>
         getProjection(a.child).map(p => (p, p.dataType)).map {
           case (projection, ArrayType(projSchema @ StructType(_), _)) =>
```
```diff
@@ -119,7 +119,7 @@ object SelectedField {
           throw new AnalysisException(s"DataType '$x' is not supported by MapKeys.")
       }
       selectField(child, opt)
-    case GetArrayItem(child, _) =>
+    case GetArrayItem(child, _, _) =>
       // GetArrayItem does not select a field from a struct (i.e. prune the struct) so it can't be
       // the top-level extractor. However it can be part of an extractor chain.
       val ArrayType(_, containsNull) = child.dataType
```
```diff
@@ -1906,8 +1906,10 @@ case class ArrayPosition(left: Expression, right: Expression)
 @ExpressionDescription(
   usage = """
     _FUNC_(array, index) - Returns element of array at given (1-based) index. If index < 0,
-      accesses elements from the last to the first. Returns NULL if the index exceeds the length
-      of the array.
+      accesses elements from the last to the first. The function returns NULL
+      if the index exceeds the length of the array and `spark.sql.ansi.enabled` is set to false.
+      If `spark.sql.ansi.enabled` is set to true, it throws ArrayIndexOutOfBoundsException
+      for invalid indices.
 
     _FUNC_(map, key) - Returns value for given key, or NULL if the key is not contained in the map
   """,
```
```diff
@@ -1919,9 +1921,14 @@ case class ArrayPosition(left: Expression, right: Expression)
       b
   """,
   since = "2.4.0")
-case class ElementAt(left: Expression, right: Expression)
+case class ElementAt(
+    left: Expression,
+    right: Expression,
+    failOnError: Boolean = SQLConf.get.ansiEnabled)
   extends GetMapValueUtil with GetArrayItemUtil with NullIntolerant {
 
+  def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled)
+
   @transient private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType
 
   @transient private lazy val arrayContainsNull = left.dataType.asInstanceOf[ArrayType].containsNull
```
```diff
@@ -1969,7 +1976,7 @@ case class ElementAt(
     if (ordinal == 0) {
       false
     } else if (elements.length < math.abs(ordinal)) {
-      true
+      !failOnError
     } else {
       if (ordinal < 0) {
         elements(elements.length + ordinal).nullable
```
```diff
@@ -1979,24 +1986,9 @@ case class ElementAt(
       }
     }
 
-  override def computeNullabilityFromArray(child: Expression, ordinal: Expression): Boolean = {
-    if (ordinal.foldable && !ordinal.nullable) {
-      val intOrdinal = ordinal.eval().asInstanceOf[Number].intValue()
-      child match {
-        case CreateArray(ar, _) =>
-          nullability(ar, intOrdinal)
-        case GetArrayStructFields(CreateArray(elements, _), field, _, _, _) =>
-          nullability(elements, intOrdinal) || field.nullable
-        case _ =>
-          true
-      }
-    } else {
-      true
-    }
-  }
-
   override def nullable: Boolean = left.dataType match {
-    case _: ArrayType => computeNullabilityFromArray(left, right)
+    case _: ArrayType =>
+      computeNullabilityFromArray(left, right, failOnError, nullability)
     case _: MapType => true
   }
```
```diff
@@ -2008,7 +2000,12 @@ case class ElementAt(
         val array = value.asInstanceOf[ArrayData]
         val index = ordinal.asInstanceOf[Int]
         if (array.numElements() < math.abs(index)) {
-          null
+          if (failOnError) {
+            throw new ArrayIndexOutOfBoundsException(
+              s"Invalid index: $index, numElements: ${array.numElements()}")
+          } else {
+            null
+          }
         } else {
           val idx = if (index == 0) {
             throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1")
```
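The interpreted path above reads as a small pure function; a standalone sketch of the same branching (simplified to `Seq[Int]`; names are illustrative, not Spark's internal API):

```scala
// element_at after this change: 1-based, negative counts from the end,
// 0 always throws, and out-of-range either throws (ANSI) or yields NULL (None).
def elementAt(elements: Seq[Int], index: Int, failOnError: Boolean): Option[Int] =
  if (elements.length < math.abs(index)) {
    if (failOnError) {
      throw new ArrayIndexOutOfBoundsException(
        s"Invalid index: $index, numElements: ${elements.length}")
    } else None
  } else {
    val idx =
      if (index == 0) throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1")
      else if (index > 0) index - 1
      else elements.length + index
    Some(elements(idx))
  }

// elementAt(Seq(1, 2, 3), 5, failOnError = false)  == None
// elementAt(Seq(1, 2, 3), -1, failOnError = true)  == Some(3)
```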
```diff
@@ -2042,10 +2039,20 @@ case class ElementAt(
       } else {
         ""
       }
+
+      val indexOutOfBoundBranch = if (failOnError) {
+        s"""throw new ArrayIndexOutOfBoundsException(
+           |  "Invalid index: " + $index + ", numElements: " + $eval1.numElements()
+           |);
+         """.stripMargin
+      } else {
+        s"${ev.isNull} = true;"
+      }
+
       s"""
       |int $index = (int) $eval2;
       |if ($eval1.numElements() < Math.abs($index)) {
-      |  ${ev.isNull} = true;
+      |  $indexOutOfBoundBranch
       |} else {
       |  if ($index == 0) {
       |    throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1");
```
```diff
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
 import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData, TypeUtils}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
```
```diff
@@ -222,10 +223,15 @@ case class GetArrayStructFields(
  *
  * We need to do type checking here as `ordinal` expression maybe unresolved.
  */
-case class GetArrayItem(child: Expression, ordinal: Expression)
+case class GetArrayItem(
+    child: Expression,
+    ordinal: Expression,
+    failOnError: Boolean = SQLConf.get.ansiEnabled)
   extends BinaryExpression with GetArrayItemUtil with ExpectsInputTypes with ExtractValue
   with NullIntolerant {
 
+  def this(child: Expression, ordinal: Expression) = this(child, ordinal, SQLConf.get.ansiEnabled)
+
   // We have done type checking for child in `ExtractValue`, so only need to check the `ordinal`.
   override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType)
```
```diff
@@ -234,13 +240,29 @@ case class GetArrayItem(
 
   override def left: Expression = child
   override def right: Expression = ordinal
-  override def nullable: Boolean = computeNullabilityFromArray(left, right)
+  override def nullable: Boolean =
+    computeNullabilityFromArray(left, right, failOnError, nullability)
   override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType
 
+  private def nullability(elements: Seq[Expression], ordinal: Int): Boolean = {
+    if (ordinal >= 0 && ordinal < elements.length) {
+      elements(ordinal).nullable
+    } else {
+      !failOnError
+    }
+  }
+
   protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
     val baseValue = value.asInstanceOf[ArrayData]
     val index = ordinal.asInstanceOf[Number].intValue()
-    if (index >= baseValue.numElements() || index < 0 || baseValue.isNullAt(index)) {
+    if (index >= baseValue.numElements() || index < 0) {
+      if (failOnError) {
+        throw new ArrayIndexOutOfBoundsException(
+          s"Invalid index: $index, numElements: ${baseValue.numElements()}")
+      } else {
+        null
+      }
+    } else if (baseValue.isNullAt(index)) {
+      null
     } else {
       baseValue.get(index, dataType)
```
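Note the evaluation order the new code pins down: bounds are checked first, then the stored-null check, so a null element is still returned as NULL even under ANSI mode. A standalone sketch (simplified; illustrative names, not Spark's internal API):

```scala
// GetArrayItem after this change: 0-based, unlike element_at.
def getArrayItem(elements: Array[Integer], index: Int, failOnError: Boolean): Option[Int] =
  if (index >= elements.length || index < 0) {
    if (failOnError) {
      throw new ArrayIndexOutOfBoundsException(
        s"Invalid index: $index, numElements: ${elements.length}")
    } else None
  } else if (elements(index) == null) {
    None // a stored null element is NULL in both modes
  } else {
    Some(elements(index).intValue)
  }

// getArrayItem(Array[Integer](1, null, 3), 5, failOnError = false) == None
// getArrayItem(Array[Integer](1, null, 3), 1, failOnError = true)  == None
```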
```diff
@@ -251,15 +273,28 @@ case class GetArrayItem(
     nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
       val index = ctx.freshName("index")
       val nullCheck = if (child.dataType.asInstanceOf[ArrayType].containsNull) {
-        s" || $eval1.isNullAt($index)"
+        s"""else if ($eval1.isNullAt($index)) {
+              ${ev.isNull} = true;
+            }
+         """
       } else {
         ""
       }
+
+      val indexOutOfBoundBranch = if (failOnError) {
+        s"""throw new ArrayIndexOutOfBoundsException(
+           |  "Invalid index: " + $index + ", numElements: " + $eval1.numElements()
+           |);
+         """.stripMargin
+      } else {
+        s"${ev.isNull} = true;"
+      }
+
       s"""
         final int $index = (int) $eval2;
-        if ($index >= $eval1.numElements() || $index < 0$nullCheck) {
-          ${ev.isNull} = true;
-        } else {
+        if ($index >= $eval1.numElements() || $index < 0) {
+          $indexOutOfBoundBranch
+        } $nullCheck else {
           ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)};
         }
       """
```
```diff
@@ -273,20 +308,24 @@ case class GetArrayItem(
 trait GetArrayItemUtil {
 
   /** `Null` is returned for invalid ordinals. */
-  protected def computeNullabilityFromArray(child: Expression, ordinal: Expression): Boolean = {
+  protected def computeNullabilityFromArray(
+      child: Expression,
+      ordinal: Expression,
+      failOnError: Boolean,
+      nullability: (Seq[Expression], Int) => Boolean): Boolean = {
+    val arrayContainsNull = child.dataType.asInstanceOf[ArrayType].containsNull
     if (ordinal.foldable && !ordinal.nullable) {
       val intOrdinal = ordinal.eval().asInstanceOf[Number].intValue()
       child match {
-        case CreateArray(ar, _) if intOrdinal < ar.length =>
-          ar(intOrdinal).nullable
-        case GetArrayStructFields(CreateArray(elements, _), field, _, _, _)
-          if intOrdinal < elements.length =>
-          elements(intOrdinal).nullable || field.nullable
+        case CreateArray(ar, _) =>
+          nullability(ar, intOrdinal)
+        case GetArrayStructFields(CreateArray(elements, _), field, _, _, _) =>
+          nullability(elements, intOrdinal) || field.nullable
         case _ =>
           true
       }
     } else {
-      true
+      if (failOnError) arrayContainsNull else true
     }
   }
 }
```
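The helper now takes the per-ordinal rule as a callback because the two callers index differently. A side-by-side sketch of the rules they pass in, with element nullabilities reduced to `Seq[Boolean]` (the positive-ordinal branch of ElementAt's rule is inferred from the truncated hunk earlier, so treat that line as an assumption):

```scala
object NullabilityRules {
  // GetArrayItem's callback: 0-based ordinal.
  def getArrayItem(elems: Seq[Boolean], ordinal: Int, failOnError: Boolean): Boolean =
    if (ordinal >= 0 && ordinal < elems.length) elems(ordinal)
    else !failOnError // out of range: NULL result unless ANSI mode throws instead

  // ElementAt's callback: 1-based, negative counts from the end, 0 always throws.
  def elementAt(elems: Seq[Boolean], ordinal: Int, failOnError: Boolean): Boolean =
    if (ordinal == 0) false // always throws, so the result is never observed as NULL
    else if (elems.length < math.abs(ordinal)) !failOnError
    else if (ordinal < 0) elems(elems.length + ordinal)
    else elems(ordinal - 1) // assumption: positive branch mirrors the negative one
}
```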
```diff
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.UTF8StringBuilder
 import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
```
```diff
@@ -231,7 +232,12 @@ case class ConcatWs(children: Seq[Expression])
  */
 // scalastyle:off line.size.limit
 @ExpressionDescription(
-  usage = "_FUNC_(n, input1, input2, ...) - Returns the `n`-th input, e.g., returns `input2` when `n` is 2.",
+  usage = """
+    _FUNC_(n, input1, input2, ...) - Returns the `n`-th input, e.g., returns `input2` when `n` is 2.
+    The function returns NULL if the index exceeds the length of the array
+    and `spark.sql.ansi.enabled` is set to false. If `spark.sql.ansi.enabled` is set to true,
+    it throws ArrayIndexOutOfBoundsException for invalid indices.
+  """,
   examples = """
     Examples:
       > SELECT _FUNC_(1, 'scala', 'java');
```
```diff
@@ -239,7 +245,11 @@ case class ConcatWs(children: Seq[Expression])
   """,
   since = "2.0.0")
 // scalastyle:on line.size.limit
-case class Elt(children: Seq[Expression]) extends Expression {
+case class Elt(
+    children: Seq[Expression],
+    failOnError: Boolean = SQLConf.get.ansiEnabled) extends Expression {
+
+  def this(children: Seq[Expression]) = this(children, SQLConf.get.ansiEnabled)
 
   private lazy val indexExpr = children.head
   private lazy val inputExprs = children.tail.toArray
```
```diff
@@ -275,7 +285,12 @@ case class Elt(
     } else {
       val index = indexObj.asInstanceOf[Int]
       if (index <= 0 || index > inputExprs.length) {
-        null
+        if (failOnError) {
+          throw new ArrayIndexOutOfBoundsException(
+            s"Invalid index: $index, numElements: ${inputExprs.length}")
+        } else {
+          null
+        }
       } else {
         inputExprs(index - 1).eval(input)
       }
```
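For contrast with element_at, a standalone sketch of elt's rule: the index is 1-based over the trailing inputs and negative indices are always invalid (illustrative simplification, not Spark's internal API):

```scala
def elt(failOnError: Boolean)(index: Int, inputs: String*): Option[String] =
  if (index <= 0 || index > inputs.length) {
    if (failOnError) {
      throw new ArrayIndexOutOfBoundsException(
        s"Invalid index: $index, numElements: ${inputs.length}")
    } else None // legacy behavior: NULL
  } else {
    Some(inputs(index - 1))
  }

// elt(failOnError = false)(4, "123", "456") == None
// elt(failOnError = true)(2, "123", "456")  == Some("456")
```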
```diff
@@ -323,6 +338,17 @@ case class Elt(
       """.stripMargin
     }.mkString)
 
+    val indexOutOfBoundBranch = if (failOnError) {
+      s"""
+         |if (!$indexMatched) {
+         |  throw new ArrayIndexOutOfBoundsException(
+         |    "Invalid index: " + ${index.value} + ", numElements: " + ${inputExprs.length});
+         |}
+       """.stripMargin
+    } else {
+      ""
+    }
+
     ev.copy(
       code"""
         |${index.code}
```
```diff
@@ -332,6 +358,7 @@ case class Elt(
         |do {
         |  $codes
         |} while (false);
+        |$indexOutOfBoundBranch
         |final ${CodeGenerator.javaType(dataType)} ${ev.value} = $inputVal;
         |final boolean ${ev.isNull} = ${ev.value} == null;
       """.stripMargin)
```
```diff
@@ -61,7 +61,7 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] {
       CreateArray(elems.map(GetStructField(_, ordinal, Some(field.name))), useStringTypeWhenEmpty)
 
     // Remove redundant map lookup.
-    case ga @ GetArrayItem(CreateArray(elems, _), IntegerLiteral(idx)) =>
+    case ga @ GetArrayItem(CreateArray(elems, _), IntegerLiteral(idx), _) =>
       // Instead of creating the array and then selecting one row, remove array creation
       // altogether.
       if (idx >= 0 && idx < elems.size) {
```
```diff
@@ -2144,9 +2144,10 @@ object SQLConf {
 
   val ANSI_ENABLED = buildConf("spark.sql.ansi.enabled")
     .doc("When true, Spark tries to conform to the ANSI SQL specification: 1. Spark will " +
-      "throw a runtime exception if an overflow occurs in any operation on integral/decimal " +
-      "field. 2. Spark will forbid using the reserved keywords of ANSI SQL as identifiers in " +
-      "the SQL parser.")
+      "throw an exception at runtime if the inputs to a SQL operator/function are invalid, " +
+      "e.g. overflow in arithmetic operations, out-of-range index when accessing array elements. " +
+      "2. Spark will forbid using the reserved keywords of ANSI SQL as identifiers in " +
+      "the SQL parser. 3. Spark will return NULL for null input for function `size`.")
     .version("3.0.0")
     .booleanConf
     .createWithDefault(false)
```
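The first numbered behavior in the updated doc string now covers both overflow and array access; a spark-shell sketch of the two cases (output abbreviated, assuming a build with this patch):

```scala
spark.conf.set("spark.sql.ansi.enabled", true)

spark.sql("SELECT 2147483647 + 1").collect()
// java.lang.ArithmeticException: integer overflow

spark.sql("SELECT array(1, 2, 3)[5]").collect()
// java.lang.ArrayIndexOutOfBoundsException: Invalid index: 5, numElements: 3
```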
```diff
@@ -1118,58 +1118,62 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
   }
 
   test("correctly handles ElementAt nullability for arrays") {
-    // CreateArray case
-    val a = AttributeReference("a", IntegerType, nullable = false)()
-    val b = AttributeReference("b", IntegerType, nullable = true)()
-    val array = CreateArray(a :: b :: Nil)
-    assert(!ElementAt(array, Literal(1)).nullable)
-    assert(!ElementAt(array, Literal(-2)).nullable)
-    assert(ElementAt(array, Literal(2)).nullable)
-    assert(ElementAt(array, Literal(-1)).nullable)
-    assert(!ElementAt(array, Subtract(Literal(2), Literal(1))).nullable)
-    assert(ElementAt(array, AttributeReference("ordinal", IntegerType)()).nullable)
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+        // CreateArray case
+        val a = AttributeReference("a", IntegerType, nullable = false)()
+        val b = AttributeReference("b", IntegerType, nullable = true)()
+        val array = CreateArray(a :: b :: Nil)
+        assert(!ElementAt(array, Literal(1)).nullable)
+        assert(!ElementAt(array, Literal(-2)).nullable)
+        assert(ElementAt(array, Literal(2)).nullable)
+        assert(ElementAt(array, Literal(-1)).nullable)
+        assert(!ElementAt(array, Subtract(Literal(2), Literal(1))).nullable)
+        assert(ElementAt(array, AttributeReference("ordinal", IntegerType)()).nullable)
 
-    // CreateArray case invalid indices
-    assert(!ElementAt(array, Literal(0)).nullable)
-    assert(ElementAt(array, Literal(4)).nullable)
-    assert(ElementAt(array, Literal(-4)).nullable)
+        // CreateArray case invalid indices
+        assert(!ElementAt(array, Literal(0)).nullable)
+        assert(ElementAt(array, Literal(4)).nullable == !ansiEnabled)
+        assert(ElementAt(array, Literal(-4)).nullable == !ansiEnabled)
 
-    // GetArrayStructFields case
-    val f1 = StructField("a", IntegerType, nullable = false)
-    val f2 = StructField("b", IntegerType, nullable = true)
-    val structType = StructType(f1 :: f2 :: Nil)
-    val c = AttributeReference("c", structType, nullable = false)()
-    val inputArray1 = CreateArray(c :: Nil)
-    val inputArray1ContainsNull = c.nullable
-    val stArray1 = GetArrayStructFields(inputArray1, f1, 0, 2, inputArray1ContainsNull)
-    assert(!ElementAt(stArray1, Literal(1)).nullable)
-    assert(!ElementAt(stArray1, Literal(-1)).nullable)
-    val stArray2 = GetArrayStructFields(inputArray1, f2, 1, 2, inputArray1ContainsNull)
-    assert(ElementAt(stArray2, Literal(1)).nullable)
-    assert(ElementAt(stArray2, Literal(-1)).nullable)
+        // GetArrayStructFields case
+        val f1 = StructField("a", IntegerType, nullable = false)
+        val f2 = StructField("b", IntegerType, nullable = true)
+        val structType = StructType(f1 :: f2 :: Nil)
+        val c = AttributeReference("c", structType, nullable = false)()
+        val inputArray1 = CreateArray(c :: Nil)
+        val inputArray1ContainsNull = c.nullable
+        val stArray1 = GetArrayStructFields(inputArray1, f1, 0, 2, inputArray1ContainsNull)
+        assert(!ElementAt(stArray1, Literal(1)).nullable)
+        assert(!ElementAt(stArray1, Literal(-1)).nullable)
+        val stArray2 = GetArrayStructFields(inputArray1, f2, 1, 2, inputArray1ContainsNull)
+        assert(ElementAt(stArray2, Literal(1)).nullable)
+        assert(ElementAt(stArray2, Literal(-1)).nullable)
 
-    val d = AttributeReference("d", structType, nullable = true)()
-    val inputArray2 = CreateArray(c :: d :: Nil)
-    val inputArray2ContainsNull = c.nullable || d.nullable
-    val stArray3 = GetArrayStructFields(inputArray2, f1, 0, 2, inputArray2ContainsNull)
-    assert(!ElementAt(stArray3, Literal(1)).nullable)
-    assert(!ElementAt(stArray3, Literal(-2)).nullable)
-    assert(ElementAt(stArray3, Literal(2)).nullable)
-    assert(ElementAt(stArray3, Literal(-1)).nullable)
-    val stArray4 = GetArrayStructFields(inputArray2, f2, 1, 2, inputArray2ContainsNull)
-    assert(ElementAt(stArray4, Literal(1)).nullable)
-    assert(ElementAt(stArray4, Literal(-2)).nullable)
-    assert(ElementAt(stArray4, Literal(2)).nullable)
-    assert(ElementAt(stArray4, Literal(-1)).nullable)
+        val d = AttributeReference("d", structType, nullable = true)()
+        val inputArray2 = CreateArray(c :: d :: Nil)
+        val inputArray2ContainsNull = c.nullable || d.nullable
+        val stArray3 = GetArrayStructFields(inputArray2, f1, 0, 2, inputArray2ContainsNull)
+        assert(!ElementAt(stArray3, Literal(1)).nullable)
+        assert(!ElementAt(stArray3, Literal(-2)).nullable)
+        assert(ElementAt(stArray3, Literal(2)).nullable)
+        assert(ElementAt(stArray3, Literal(-1)).nullable)
+        val stArray4 = GetArrayStructFields(inputArray2, f2, 1, 2, inputArray2ContainsNull)
+        assert(ElementAt(stArray4, Literal(1)).nullable)
+        assert(ElementAt(stArray4, Literal(-2)).nullable)
+        assert(ElementAt(stArray4, Literal(2)).nullable)
+        assert(ElementAt(stArray4, Literal(-1)).nullable)
 
-    // GetArrayStructFields case invalid indices
-    assert(!ElementAt(stArray3, Literal(0)).nullable)
-    assert(ElementAt(stArray3, Literal(4)).nullable)
-    assert(ElementAt(stArray3, Literal(-4)).nullable)
+        // GetArrayStructFields case invalid indices
+        assert(!ElementAt(stArray3, Literal(0)).nullable)
+        assert(ElementAt(stArray3, Literal(4)).nullable == !ansiEnabled)
+        assert(ElementAt(stArray3, Literal(-4)).nullable == !ansiEnabled)
 
-    assert(ElementAt(stArray4, Literal(0)).nullable)
-    assert(ElementAt(stArray4, Literal(4)).nullable)
-    assert(ElementAt(stArray4, Literal(-4)).nullable)
+        assert(ElementAt(stArray4, Literal(0)).nullable)
+        assert(ElementAt(stArray4, Literal(4)).nullable)
+        assert(ElementAt(stArray4, Literal(-4)).nullable)
+      }
+    }
   }
 
   test("Concat") {
```
```diff
@@ -1883,4 +1887,32 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
       Literal(stringToInterval("interval 1 year"))),
       Seq(Date.valueOf("2018-01-01")))
   }
+
+  test("SPARK-33386: element_at ArrayIndexOutOfBoundsException") {
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+        val array = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
+        var expr: Expression = ElementAt(array, Literal(5))
+        if (ansiEnabled) {
+          val errMsg = "Invalid index: 5, numElements: 3"
+          checkExceptionInExpression[Exception](expr, errMsg)
+        } else {
+          checkEvaluation(expr, null)
+        }
+
+        expr = ElementAt(array, Literal(-5))
+        if (ansiEnabled) {
+          val errMsg = "Invalid index: -5, numElements: 3"
+          checkExceptionInExpression[Exception](expr, errMsg)
+        } else {
+          checkEvaluation(expr, null)
+        }
+
+        // SQL array indices start at 1 exception throws for both mode.
+        expr = ElementAt(array, Literal(0))
+        val errMsg = "SQL array indices start at 1"
+        checkExceptionInExpression[Exception](expr, errMsg)
+      }
+    }
+  }
 }
```
```diff
@@ -62,6 +62,29 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(GetArrayItem(nestedArray, Literal(0)), Seq(1))
   }
 
+  test("SPARK-33386: GetArrayItem ArrayIndexOutOfBoundsException") {
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+        val array = Literal.create(Seq("a", "b"), ArrayType(StringType))
+
+        if (ansiEnabled) {
+          checkExceptionInExpression[Exception](
+            GetArrayItem(array, Literal(5)),
+            "Invalid index: 5, numElements: 2"
+          )
+
+          checkExceptionInExpression[Exception](
+            GetArrayItem(array, Literal(-1)),
+            "Invalid index: -1, numElements: 2"
+          )
+        } else {
+          checkEvaluation(GetArrayItem(array, Literal(5)), null)
+          checkEvaluation(GetArrayItem(array, Literal(-1)), null)
+        }
+      }
+    }
+  }
+
   test("SPARK-26637 handles GetArrayItem nullability correctly when input array size is constant") {
     // CreateArray case
     val a = AttributeReference("a", IntegerType, nullable = false)()
```
```diff
@@ -18,9 +18,9 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
```
```diff
@@ -968,4 +968,34 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     GenerateUnsafeProjection.generate(
       Sentences(Literal("\"quote"), Literal("\"quote"), Literal("\"quote")) :: Nil)
   }
+
+  test("SPARK-33386: elt ArrayIndexOutOfBoundsException") {
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+        var expr: Expression = Elt(Seq(Literal(4), Literal("123"), Literal("456")))
+        if (ansiEnabled) {
+          val errMsg = "Invalid index: 4, numElements: 2"
+          checkExceptionInExpression[Exception](expr, errMsg)
+        } else {
+          checkEvaluation(expr, null)
+        }
+
+        expr = Elt(Seq(Literal(0), Literal("123"), Literal("456")))
+        if (ansiEnabled) {
+          val errMsg = "Invalid index: 0, numElements: 2"
+          checkExceptionInExpression[Exception](expr, errMsg)
+        } else {
+          checkEvaluation(expr, null)
+        }
+
+        expr = Elt(Seq(Literal(-1), Literal("123"), Literal("456")))
+        if (ansiEnabled) {
+          val errMsg = "Invalid index: -1, numElements: 2"
+          checkExceptionInExpression[Exception](expr, errMsg)
+        } else {
+          checkEvaluation(expr, null)
+        }
+      }
+    }
+  }
 }
```
```diff
@@ -0,0 +1 @@
+--IMPORT array.sql
```
```diff
@@ -90,3 +90,15 @@ select
   size(date_array),
   size(timestamp_array)
 from primitive_arrays;
+
+-- index out of range for array elements
+select element_at(array(1, 2, 3), 5);
+select element_at(array(1, 2, 3), -5);
+select element_at(array(1, 2, 3), 0);
+
+select elt(4, '123', '456');
+select elt(0, '123', '456');
+select elt(-1, '123', '456');
+
+select array(1, 2, 3)[5];
+select array(1, 2, 3)[-1];
```
sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out (new file, 234 lines)
```
@@ -0,0 +1,234 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 20


-- !query
create temporary view data as select * from values
  ("one", array(11, 12, 13), array(array(111, 112, 113), array(121, 122, 123))),
  ("two", array(21, 22, 23), array(array(211, 212, 213), array(221, 222, 223)))
  as data(a, b, c)
-- !query schema
struct<>
-- !query output



-- !query
select * from data
-- !query schema
struct<a:string,b:array<int>,c:array<array<int>>>
-- !query output
one   [11,12,13]   [[111,112,113],[121,122,123]]
two   [21,22,23]   [[211,212,213],[221,222,223]]


-- !query
select a, b[0], b[0] + b[1] from data
-- !query schema
struct<a:string,b[0]:int,(b[0] + b[1]):int>
-- !query output
one   11   23
two   21   43


-- !query
select a, c[0][0] + c[0][0 + 1] from data
-- !query schema
struct<a:string,(c[0][0] + c[0][(0 + 1)]):int>
-- !query output
one   223
two   423


-- !query
create temporary view primitive_arrays as select * from values (
  array(true),
  array(2Y, 1Y),
  array(2S, 1S),
  array(2, 1),
  array(2L, 1L),
  array(9223372036854775809, 9223372036854775808),
  array(2.0D, 1.0D),
  array(float(2.0), float(1.0)),
  array(date '2016-03-14', date '2016-03-13'),
  array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000')
) as primitive_arrays(
  boolean_array,
  tinyint_array,
  smallint_array,
  int_array,
  bigint_array,
  decimal_array,
  double_array,
  float_array,
  date_array,
  timestamp_array
)
-- !query schema
struct<>
-- !query output



-- !query
select * from primitive_arrays
-- !query schema
struct<boolean_array:array<boolean>,tinyint_array:array<tinyint>,smallint_array:array<smallint>,int_array:array<int>,bigint_array:array<bigint>,decimal_array:array<decimal(19,0)>,double_array:array<double>,float_array:array<float>,date_array:array<date>,timestamp_array:array<timestamp>>
-- !query output
[true]   [2,1]   [2,1]   [2,1]   [2,1]   [9223372036854775809,9223372036854775808]   [2.0,1.0]   [2.0,1.0]   [2016-03-14,2016-03-13]   [2016-11-15 20:54:00,2016-11-12 20:54:00]


-- !query
select
  array_contains(boolean_array, true), array_contains(boolean_array, false),
  array_contains(tinyint_array, 2Y), array_contains(tinyint_array, 0Y),
  array_contains(smallint_array, 2S), array_contains(smallint_array, 0S),
  array_contains(int_array, 2), array_contains(int_array, 0),
  array_contains(bigint_array, 2L), array_contains(bigint_array, 0L),
  array_contains(decimal_array, 9223372036854775809), array_contains(decimal_array, 1),
  array_contains(double_array, 2.0D), array_contains(double_array, 0.0D),
  array_contains(float_array, float(2.0)), array_contains(float_array, float(0.0)),
  array_contains(date_array, date '2016-03-14'), array_contains(date_array, date '2016-01-01'),
  array_contains(timestamp_array, timestamp '2016-11-15 20:54:00.000'), array_contains(timestamp_array, timestamp '2016-01-01 20:54:00.000')
from primitive_arrays
-- !query schema
struct<array_contains(boolean_array, true):boolean,array_contains(boolean_array, false):boolean,array_contains(tinyint_array, 2):boolean,array_contains(tinyint_array, 0):boolean,array_contains(smallint_array, 2):boolean,array_contains(smallint_array, 0):boolean,array_contains(int_array, 2):boolean,array_contains(int_array, 0):boolean,array_contains(bigint_array, 2):boolean,array_contains(bigint_array, 0):boolean,array_contains(decimal_array, 9223372036854775809):boolean,array_contains(decimal_array, CAST(1 AS DECIMAL(19,0))):boolean,array_contains(double_array, 2.0):boolean,array_contains(double_array, 0.0):boolean,array_contains(float_array, CAST(2.0 AS FLOAT)):boolean,array_contains(float_array, CAST(0.0 AS FLOAT)):boolean,array_contains(date_array, DATE '2016-03-14'):boolean,array_contains(date_array, DATE '2016-01-01'):boolean,array_contains(timestamp_array, TIMESTAMP '2016-11-15 20:54:00'):boolean,array_contains(timestamp_array, TIMESTAMP '2016-01-01 20:54:00'):boolean>
-- !query output
true   false   true   false   true   false   true   false   true   false   true   false   true   false   true   false   true   false   true   false


-- !query
select array_contains(b, 11), array_contains(c, array(111, 112, 113)) from data
-- !query schema
struct<array_contains(b, 11):boolean,array_contains(c, array(111, 112, 113)):boolean>
-- !query output
false   false
true   true


-- !query
select
  sort_array(boolean_array),
  sort_array(tinyint_array),
  sort_array(smallint_array),
  sort_array(int_array),
  sort_array(bigint_array),
  sort_array(decimal_array),
  sort_array(double_array),
  sort_array(float_array),
  sort_array(date_array),
  sort_array(timestamp_array)
from primitive_arrays
-- !query schema
struct<sort_array(boolean_array, true):array<boolean>,sort_array(tinyint_array, true):array<tinyint>,sort_array(smallint_array, true):array<smallint>,sort_array(int_array, true):array<int>,sort_array(bigint_array, true):array<bigint>,sort_array(decimal_array, true):array<decimal(19,0)>,sort_array(double_array, true):array<double>,sort_array(float_array, true):array<float>,sort_array(date_array, true):array<date>,sort_array(timestamp_array, true):array<timestamp>>
-- !query output
[true]   [1,2]   [1,2]   [1,2]   [1,2]   [9223372036854775808,9223372036854775809]   [1.0,2.0]   [1.0,2.0]   [2016-03-13,2016-03-14]   [2016-11-12 20:54:00,2016-11-15 20:54:00]


-- !query
select sort_array(array('b', 'd'), '1')
-- !query schema
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
cannot resolve 'sort_array(array('b', 'd'), '1')' due to data type mismatch: Sort order in second argument requires a boolean literal.; line 1 pos 7


-- !query
select sort_array(array('b', 'd'), cast(NULL as boolean))
-- !query schema
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
cannot resolve 'sort_array(array('b', 'd'), CAST(NULL AS BOOLEAN))' due to data type mismatch: Sort order in second argument requires a boolean literal.; line 1 pos 7


-- !query
select
  size(boolean_array),
  size(tinyint_array),
  size(smallint_array),
  size(int_array),
  size(bigint_array),
  size(decimal_array),
  size(double_array),
  size(float_array),
  size(date_array),
  size(timestamp_array)
from primitive_arrays
-- !query schema
struct<size(boolean_array):int,size(tinyint_array):int,size(smallint_array):int,size(int_array):int,size(bigint_array):int,size(decimal_array):int,size(double_array):int,size(float_array):int,size(date_array):int,size(timestamp_array):int>
-- !query output
1   2   2   2   2   2   2   2   2   2


-- !query
select element_at(array(1, 2, 3), 5)
-- !query schema
struct<>
-- !query output
java.lang.ArrayIndexOutOfBoundsException
Invalid index: 5, numElements: 3


-- !query
select element_at(array(1, 2, 3), -5)
-- !query schema
struct<>
-- !query output
java.lang.ArrayIndexOutOfBoundsException
Invalid index: -5, numElements: 3


-- !query
select element_at(array(1, 2, 3), 0)
-- !query schema
struct<>
-- !query output
java.lang.ArrayIndexOutOfBoundsException
SQL array indices start at 1


-- !query
select elt(4, '123', '456')
-- !query schema
struct<>
-- !query output
java.lang.ArrayIndexOutOfBoundsException
Invalid index: 4, numElements: 2


-- !query
select elt(0, '123', '456')
-- !query schema
struct<>
-- !query output
java.lang.ArrayIndexOutOfBoundsException
Invalid index: 0, numElements: 2


-- !query
select elt(-1, '123', '456')
-- !query schema
struct<>
-- !query output
java.lang.ArrayIndexOutOfBoundsException
Invalid index: -1, numElements: 2


-- !query
select array(1, 2, 3)[5]
-- !query schema
struct<>
-- !query output
java.lang.ArrayIndexOutOfBoundsException
Invalid index: 5, numElements: 3


-- !query
select array(1, 2, 3)[-1]
-- !query schema
struct<>
-- !query output
java.lang.ArrayIndexOutOfBoundsException
Invalid index: -1, numElements: 3
```
```diff
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 12
+-- Number of queries: 20
 
 
 -- !query
```
```diff
@@ -160,3 +160,68 @@ from primitive_arrays
 struct<size(boolean_array):int,size(tinyint_array):int,size(smallint_array):int,size(int_array):int,size(bigint_array):int,size(decimal_array):int,size(double_array):int,size(float_array):int,size(date_array):int,size(timestamp_array):int>
 -- !query output
 1   2   2   2   2   2   2   2   2   2
+
+
+-- !query
+select element_at(array(1, 2, 3), 5)
+-- !query schema
+struct<element_at(array(1, 2, 3), 5):int>
+-- !query output
+NULL
+
+
+-- !query
+select element_at(array(1, 2, 3), -5)
+-- !query schema
+struct<element_at(array(1, 2, 3), -5):int>
+-- !query output
+NULL
+
+
+-- !query
+select element_at(array(1, 2, 3), 0)
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArrayIndexOutOfBoundsException
+SQL array indices start at 1
+
+
+-- !query
+select elt(4, '123', '456')
+-- !query schema
+struct<elt(4, 123, 456):string>
+-- !query output
+NULL
+
+
+-- !query
+select elt(0, '123', '456')
+-- !query schema
+struct<elt(0, 123, 456):string>
+-- !query output
+NULL
+
+
+-- !query
+select elt(-1, '123', '456')
+-- !query schema
+struct<elt(-1, 123, 456):string>
+-- !query output
+NULL
+
+
+-- !query
+select array(1, 2, 3)[5]
+-- !query schema
+struct<array(1, 2, 3)[5]:int>
+-- !query output
+NULL
+
+
+-- !query
+select array(1, 2, 3)[-1]
+-- !query schema
+struct<array(1, 2, 3)[-1]:int>
+-- !query output
+NULL
```