[SPARK-33391][SQL] element_at with CreateArray does not respect one-based index
### What changes were proposed in this pull request? `element_at` with `CreateArray` does not respect the one-based index. Repro steps: ``` var df = spark.sql("select element_at(array(3, 2, 1), 0)") df.printSchema() df = spark.sql("select element_at(array(3, 2, 1), 1)") df.printSchema() df = spark.sql("select element_at(array(3, 2, 1), 2)") df.printSchema() df = spark.sql("select element_at(array(3, 2, 1), 3)") df.printSchema() root -- element_at(array(3, 2, 1), 0): integer (nullable = false) root -- element_at(array(3, 2, 1), 1): integer (nullable = false) root -- element_at(array(3, 2, 1), 2): integer (nullable = false) root -- element_at(array(3, 2, 1), 3): integer (nullable = true) ``` The correct answers should be: index 0 -> true (out of bounds, so it falls back to the default, nullable = true); index 1 -> false; index 2 -> false; index 3 -> false. Expression evaluation respects the one-based index, but when computing nullability, `computeNullabilityFromArray` used a zero-based index. ### Why are the changes needed? This is a correctness issue. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Added a UT and verified with existing UTs. Closes #30296 from leanken/leanken-SPARK-33391. Authored-by: xuewei.linxuewei <xuewei.linxuewei@alibaba-inc.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
ad02ceda29
commit
e3a768dd79
|
@ -1965,6 +1965,36 @@ case class ElementAt(left: Expression, right: Expression)
|
|||
}
|
||||
}
|
||||
|
||||
/**
 * Computes the nullability of `element_at` applied to a literal array built by
 * `CreateArray`, honoring the one-based `ordinal` semantics of `element_at`.
 *
 * - `ordinal == 0`: an invalid index; evaluation fails before a null could be
 *   produced, so the result is treated as non-nullable.
 * - `|ordinal| > elements.length`: out of bounds; `element_at` returns null,
 *   so the result is nullable.
 * - otherwise: the nullability of the element actually selected — positive
 *   ordinals count from the front (1 is the first element), negative ordinals
 *   count from the back (-1 is the last element).
 *
 * @param elements the child expressions of the `CreateArray`
 * @param ordinal  the statically-known, one-based index
 * @return whether `element_at` on this array with this ordinal can be null
 */
private def nullability(elements: Seq[Expression], ordinal: Int): Boolean = {
  if (ordinal == 0) {
    false
  } else if (elements.length < math.abs(ordinal.toLong)) {
    // Compare as Long: math.abs(Int.MinValue) overflows and stays negative,
    // which would skip this out-of-bounds branch and make the index
    // arithmetic below throw IndexOutOfBoundsException during analysis.
    true
  } else {
    if (ordinal < 0) {
      // One-based from the end: -1 selects the last element.
      elements(elements.length + ordinal).nullable
    } else {
      // One-based from the front: 1 selects the first element.
      elements(ordinal - 1).nullable
    }
  }
}
|
||||
|
||||
/**
 * Statically determines whether `element_at(child, ordinal)` can produce null
 * when `child` is a literal array.
 *
 * Only a foldable, non-nullable ordinal lets us reason about which element is
 * selected; in every other situation we conservatively report nullable.
 *
 * @param child   the array-producing expression
 * @param ordinal the index expression (one-based for `element_at`)
 * @return false only when the selected element is provably non-null
 */
override def computeNullabilityFromArray(child: Expression, ordinal: Expression): Boolean = {
  if (!ordinal.foldable || ordinal.nullable) {
    // Index unknown at analysis time (or possibly null): assume nullable.
    true
  } else {
    val index = ordinal.eval().asInstanceOf[Number].intValue()
    child match {
      case CreateArray(elements, _) =>
        nullability(elements, index)
      case GetArrayStructFields(CreateArray(elements, _), field, _, _, _) =>
        // The extracted struct field may itself be nullable even when the
        // selected array element is not.
        field.nullable || nullability(elements, index)
      case _ =>
        // Any other array producer: no static information, assume nullable.
        true
    }
  }
}
|
||||
|
||||
override def nullable: Boolean = left.dataType match {
|
||||
case _: ArrayType => computeNullabilityFromArray(left, right)
|
||||
case _: MapType => true
|
||||
|
|
|
@ -1122,11 +1122,18 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
|
|||
val a = AttributeReference("a", IntegerType, nullable = false)()
|
||||
val b = AttributeReference("b", IntegerType, nullable = true)()
|
||||
val array = CreateArray(a :: b :: Nil)
|
||||
assert(!ElementAt(array, Literal(0)).nullable)
|
||||
assert(ElementAt(array, Literal(1)).nullable)
|
||||
assert(!ElementAt(array, Subtract(Literal(2), Literal(2))).nullable)
|
||||
assert(!ElementAt(array, Literal(1)).nullable)
|
||||
assert(!ElementAt(array, Literal(-2)).nullable)
|
||||
assert(ElementAt(array, Literal(2)).nullable)
|
||||
assert(ElementAt(array, Literal(-1)).nullable)
|
||||
assert(!ElementAt(array, Subtract(Literal(2), Literal(1))).nullable)
|
||||
assert(ElementAt(array, AttributeReference("ordinal", IntegerType)()).nullable)
|
||||
|
||||
// CreateArray case invalid indices
|
||||
assert(!ElementAt(array, Literal(0)).nullable)
|
||||
assert(ElementAt(array, Literal(4)).nullable)
|
||||
assert(ElementAt(array, Literal(-4)).nullable)
|
||||
|
||||
// GetArrayStructFields case
|
||||
val f1 = StructField("a", IntegerType, nullable = false)
|
||||
val f2 = StructField("b", IntegerType, nullable = true)
|
||||
|
@ -1135,19 +1142,34 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
|
|||
val inputArray1 = CreateArray(c :: Nil)
|
||||
val inputArray1ContainsNull = c.nullable
|
||||
val stArray1 = GetArrayStructFields(inputArray1, f1, 0, 2, inputArray1ContainsNull)
|
||||
assert(!ElementAt(stArray1, Literal(0)).nullable)
|
||||
assert(!ElementAt(stArray1, Literal(1)).nullable)
|
||||
assert(!ElementAt(stArray1, Literal(-1)).nullable)
|
||||
val stArray2 = GetArrayStructFields(inputArray1, f2, 1, 2, inputArray1ContainsNull)
|
||||
assert(ElementAt(stArray2, Literal(0)).nullable)
|
||||
assert(ElementAt(stArray2, Literal(1)).nullable)
|
||||
assert(ElementAt(stArray2, Literal(-1)).nullable)
|
||||
|
||||
val d = AttributeReference("d", structType, nullable = true)()
|
||||
val inputArray2 = CreateArray(c :: d :: Nil)
|
||||
val inputArray2ContainsNull = c.nullable || d.nullable
|
||||
val stArray3 = GetArrayStructFields(inputArray2, f1, 0, 2, inputArray2ContainsNull)
|
||||
assert(!ElementAt(stArray3, Literal(0)).nullable)
|
||||
assert(ElementAt(stArray3, Literal(1)).nullable)
|
||||
assert(!ElementAt(stArray3, Literal(1)).nullable)
|
||||
assert(!ElementAt(stArray3, Literal(-2)).nullable)
|
||||
assert(ElementAt(stArray3, Literal(2)).nullable)
|
||||
assert(ElementAt(stArray3, Literal(-1)).nullable)
|
||||
val stArray4 = GetArrayStructFields(inputArray2, f2, 1, 2, inputArray2ContainsNull)
|
||||
assert(ElementAt(stArray4, Literal(0)).nullable)
|
||||
assert(ElementAt(stArray4, Literal(1)).nullable)
|
||||
assert(ElementAt(stArray4, Literal(-2)).nullable)
|
||||
assert(ElementAt(stArray4, Literal(2)).nullable)
|
||||
assert(ElementAt(stArray4, Literal(-1)).nullable)
|
||||
|
||||
// GetArrayStructFields case invalid indices
|
||||
assert(!ElementAt(stArray3, Literal(0)).nullable)
|
||||
assert(ElementAt(stArray3, Literal(4)).nullable)
|
||||
assert(ElementAt(stArray3, Literal(-4)).nullable)
|
||||
|
||||
assert(ElementAt(stArray4, Literal(0)).nullable)
|
||||
assert(ElementAt(stArray4, Literal(4)).nullable)
|
||||
assert(ElementAt(stArray4, Literal(-4)).nullable)
|
||||
}
|
||||
|
||||
test("Concat") {
|
||||
|
|
Loading…
Reference in a new issue