diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 8719b2e065..cb081b80ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1965,6 +1965,36 @@ case class ElementAt(left: Expression, right: Expression) } } + private def nullability(elements: Seq[Expression], ordinal: Int): Boolean = { + if (ordinal == 0) { + false + } else if (elements.length < math.abs(ordinal)) { + true + } else { + if (ordinal < 0) { + elements(elements.length + ordinal).nullable + } else { + elements(ordinal - 1).nullable + } + } + } + + override def computeNullabilityFromArray(child: Expression, ordinal: Expression): Boolean = { + if (ordinal.foldable && !ordinal.nullable) { + val intOrdinal = ordinal.eval().asInstanceOf[Number].intValue() + child match { + case CreateArray(ar, _) => + nullability(ar, intOrdinal) + case GetArrayStructFields(CreateArray(elements, _), field, _, _, _) => + nullability(elements, intOrdinal) || field.nullable + case _ => + true + } + } else { + true + } + } + override def nullable: Boolean = left.dataType match { case _: ArrayType => computeNullabilityFromArray(left, right) case _: MapType => true diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 856c1fad9b..d59d13d49c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1122,11 +1122,18 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val a = AttributeReference("a", IntegerType, nullable = false)() val b = AttributeReference("b", IntegerType, nullable = true)() val array = CreateArray(a :: b :: Nil) - assert(!ElementAt(array, Literal(0)).nullable) - assert(ElementAt(array, Literal(1)).nullable) - assert(!ElementAt(array, Subtract(Literal(2), Literal(2))).nullable) + assert(!ElementAt(array, Literal(1)).nullable) + assert(!ElementAt(array, Literal(-2)).nullable) + assert(ElementAt(array, Literal(2)).nullable) + assert(ElementAt(array, Literal(-1)).nullable) + assert(!ElementAt(array, Subtract(Literal(2), Literal(1))).nullable) assert(ElementAt(array, AttributeReference("ordinal", IntegerType)()).nullable) + // CreateArray case invalid indices + assert(!ElementAt(array, Literal(0)).nullable) + assert(ElementAt(array, Literal(4)).nullable) + assert(ElementAt(array, Literal(-4)).nullable) + // GetArrayStructFields case val f1 = StructField("a", IntegerType, nullable = false) val f2 = StructField("b", IntegerType, nullable = true) @@ -1135,19 +1142,34 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val inputArray1 = CreateArray(c :: Nil) val inputArray1ContainsNull = c.nullable val stArray1 = GetArrayStructFields(inputArray1, f1, 0, 2, inputArray1ContainsNull) - assert(!ElementAt(stArray1, Literal(0)).nullable) + assert(!ElementAt(stArray1, Literal(1)).nullable) + assert(!ElementAt(stArray1, Literal(-1)).nullable) val stArray2 = GetArrayStructFields(inputArray1, f2, 1, 2, inputArray1ContainsNull) - assert(ElementAt(stArray2, Literal(0)).nullable) + assert(ElementAt(stArray2, Literal(1)).nullable) + assert(ElementAt(stArray2, Literal(-1)).nullable) val d = AttributeReference("d", structType, nullable = true)() val inputArray2 = CreateArray(c :: d :: Nil) val inputArray2ContainsNull = c.nullable || d.nullable val stArray3 = GetArrayStructFields(inputArray2, f1, 0, 2, inputArray2ContainsNull) - assert(!ElementAt(stArray3, Literal(0)).nullable) - assert(ElementAt(stArray3, Literal(1)).nullable) + assert(!ElementAt(stArray3, Literal(1)).nullable) + assert(!ElementAt(stArray3, Literal(-2)).nullable) + assert(ElementAt(stArray3, Literal(2)).nullable) + assert(ElementAt(stArray3, Literal(-1)).nullable) val stArray4 = GetArrayStructFields(inputArray2, f2, 1, 2, inputArray2ContainsNull) - assert(ElementAt(stArray4, Literal(0)).nullable) assert(ElementAt(stArray4, Literal(1)).nullable) + assert(ElementAt(stArray4, Literal(-2)).nullable) + assert(ElementAt(stArray4, Literal(2)).nullable) + assert(ElementAt(stArray4, Literal(-1)).nullable) + + // GetArrayStructFields case invalid indices + assert(!ElementAt(stArray3, Literal(0)).nullable) + assert(ElementAt(stArray3, Literal(4)).nullable) + assert(ElementAt(stArray3, Literal(-4)).nullable) + + assert(ElementAt(stArray4, Literal(0)).nullable) + assert(ElementAt(stArray4, Literal(4)).nullable) + assert(ElementAt(stArray4, Literal(-4)).nullable) } test("Concat") {