[SPARK-33391][SQL] element_at with CreateArray does not respect one-based index

### What changes were proposed in this pull request?

`element_at` with `CreateArray` does not respect the one-based index when computing nullability.

Repro steps:

```
var df = spark.sql("select element_at(array(3, 2, 1), 0)")
df.printSchema()

df = spark.sql("select element_at(array(3, 2, 1), 1)")
df.printSchema()

df = spark.sql("select element_at(array(3, 2, 1), 2)")
df.printSchema()

df = spark.sql("select element_at(array(3, 2, 1), 3)")
df.printSchema()

root
 |-- element_at(array(3, 2, 1), 0): integer (nullable = false)

root
 |-- element_at(array(3, 2, 1), 1): integer (nullable = false)

root
 |-- element_at(array(3, 2, 1), 2): integer (nullable = false)

root
 |-- element_at(array(3, 2, 1), 3): integer (nullable = true)

correct answer should be
0 true which is outOfBounds return default true.
1 false
2 false
3 false

```

Expression evaluation respects the one-based index, but the nullability check treats the ordinal as a zero-based index because it uses `computeNullabilityFromArray`.

### Why are the changes needed?

Correctness issue.

### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

Added UT and existing UT.

Closes #30296 from leanken/leanken-SPARK-33391.

Authored-by: xuewei.linxuewei <xuewei.linxuewei@alibaba-inc.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
xuewei.linxuewei 2020-11-10 07:23:47 +00:00 committed by Wenchen Fan
parent ad02ceda29
commit e3a768dd79
2 changed files with 60 additions and 8 deletions

View file

@ -1965,6 +1965,36 @@ case class ElementAt(left: Expression, right: Expression)
}
}
/**
 * Computes the nullability of the element selected from `elements` by a constant,
 * one-based `ordinal`. A negative ordinal counts from the end of the sequence
 * (-1 selects the last element).
 *
 * @param elements the element expressions of the array being indexed
 * @param ordinal  the one-based position; may be negative
 * @return `false` for ordinal 0; `true` when the ordinal lies outside the array;
 *         otherwise the nullability of the selected element expression
 */
private def nullability(elements: Seq[Expression], ordinal: Int): Boolean = {
  if (ordinal == 0) {
    // NOTE(review): presumably index 0 is rejected at evaluation time, so no value
    // (and hence no null) is ever produced for it -- confirm against eval.
    false
  } else if (elements.length < math.abs(ordinal.toLong)) {
    // Widen to Long before taking the absolute value: math.abs(Int.MinValue) overflows
    // and stays negative, which would skip this guard and index with a negative position.
    true
  } else if (ordinal < 0) {
    elements(elements.length + ordinal).nullable
  } else {
    elements(ordinal - 1).nullable
  }
}
/**
 * Decides whether indexing `child` at `ordinal` may yield null.
 *
 * A precise answer is only attempted when the ordinal is a foldable, non-nullable
 * expression and the child is either a `CreateArray` or a `GetArrayStructFields`
 * over a `CreateArray`; every other combination is reported as nullable.
 *
 * @param child   the array-producing expression being indexed
 * @param ordinal the index expression
 * @return `true` if the result may be null, `false` if it cannot be
 */
override def computeNullabilityFromArray(child: Expression, ordinal: Expression): Boolean = {
  val constantOrdinal = ordinal.foldable && !ordinal.nullable
  if (!constantOrdinal) {
    // The index is not known up front, so nothing can be ruled out.
    true
  } else {
    val position = ordinal.eval().asInstanceOf[Number].intValue()
    child match {
      case GetArrayStructFields(CreateArray(items, _), structField, _, _, _) =>
        // Null if either the selected struct or the extracted field can be null.
        nullability(items, position) || structField.nullable
      case CreateArray(items, _) =>
        nullability(items, position)
      case _ =>
        true
    }
  }
}
override def nullable: Boolean = left.dataType match {
case _: ArrayType => computeNullabilityFromArray(left, right)
case _: MapType => true

View file

@ -1122,11 +1122,18 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
val a = AttributeReference("a", IntegerType, nullable = false)()
val b = AttributeReference("b", IntegerType, nullable = true)()
val array = CreateArray(a :: b :: Nil)
assert(!ElementAt(array, Literal(0)).nullable)
assert(ElementAt(array, Literal(1)).nullable)
assert(!ElementAt(array, Subtract(Literal(2), Literal(2))).nullable)
assert(!ElementAt(array, Literal(1)).nullable)
assert(!ElementAt(array, Literal(-2)).nullable)
assert(ElementAt(array, Literal(2)).nullable)
assert(ElementAt(array, Literal(-1)).nullable)
assert(!ElementAt(array, Subtract(Literal(2), Literal(1))).nullable)
assert(ElementAt(array, AttributeReference("ordinal", IntegerType)()).nullable)
// CreateArray case invalid indices
assert(!ElementAt(array, Literal(0)).nullable)
assert(ElementAt(array, Literal(4)).nullable)
assert(ElementAt(array, Literal(-4)).nullable)
// GetArrayStructFields case
val f1 = StructField("a", IntegerType, nullable = false)
val f2 = StructField("b", IntegerType, nullable = true)
@ -1135,19 +1142,34 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
val inputArray1 = CreateArray(c :: Nil)
val inputArray1ContainsNull = c.nullable
val stArray1 = GetArrayStructFields(inputArray1, f1, 0, 2, inputArray1ContainsNull)
assert(!ElementAt(stArray1, Literal(0)).nullable)
assert(!ElementAt(stArray1, Literal(1)).nullable)
assert(!ElementAt(stArray1, Literal(-1)).nullable)
val stArray2 = GetArrayStructFields(inputArray1, f2, 1, 2, inputArray1ContainsNull)
assert(ElementAt(stArray2, Literal(0)).nullable)
assert(ElementAt(stArray2, Literal(1)).nullable)
assert(ElementAt(stArray2, Literal(-1)).nullable)
val d = AttributeReference("d", structType, nullable = true)()
val inputArray2 = CreateArray(c :: d :: Nil)
val inputArray2ContainsNull = c.nullable || d.nullable
val stArray3 = GetArrayStructFields(inputArray2, f1, 0, 2, inputArray2ContainsNull)
assert(!ElementAt(stArray3, Literal(0)).nullable)
assert(ElementAt(stArray3, Literal(1)).nullable)
assert(!ElementAt(stArray3, Literal(1)).nullable)
assert(!ElementAt(stArray3, Literal(-2)).nullable)
assert(ElementAt(stArray3, Literal(2)).nullable)
assert(ElementAt(stArray3, Literal(-1)).nullable)
val stArray4 = GetArrayStructFields(inputArray2, f2, 1, 2, inputArray2ContainsNull)
assert(ElementAt(stArray4, Literal(0)).nullable)
assert(ElementAt(stArray4, Literal(1)).nullable)
assert(ElementAt(stArray4, Literal(-2)).nullable)
assert(ElementAt(stArray4, Literal(2)).nullable)
assert(ElementAt(stArray4, Literal(-1)).nullable)
// GetArrayStructFields case invalid indices
assert(!ElementAt(stArray3, Literal(0)).nullable)
assert(ElementAt(stArray3, Literal(4)).nullable)
assert(ElementAt(stArray3, Literal(-4)).nullable)
assert(ElementAt(stArray4, Literal(0)).nullable)
assert(ElementAt(stArray4, Literal(4)).nullable)
assert(ElementAt(stArray4, Literal(-4)).nullable)
}
test("Concat") {