[SPARK-23919][SQL] Add array_position function
## What changes were proposed in this pull request? This PR adds the SQL function `array_position`. Its behavior follows Presto's function of the same name: it returns the position of the first occurrence of the element in array x (or 0 if the element is not found), using a 1-based index, as a BigInt. ## How was this patch tested? Added unit tests. Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com> Closes #21037 from kiszk/SPARK-23919.
This commit is contained in:
parent
8bb0df2c65
commit
d5bec48b9c
|
@ -1845,6 +1845,23 @@ def array_contains(col, value):
|
|||
return Column(sc._jvm.functions.array_contains(_to_java_column(col), value))
|
||||
|
||||
|
||||
@since(2.4)
def array_position(col, value):
    """
    Collection function: finds where the given value first occurs in the given array and
    returns that position. The result is null if either argument is null.

    .. note:: Positions are 1-based, not 0-based. A result of 0 means the given value
        was not found in the array.

    >>> df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ['data'])
    >>> df.select(array_position(df.data, "a")).collect()
    [Row(array_position(data, a)=3), Row(array_position(data, a)=0)]
    """
    # Resolve the JVM-side function holder from the active context, then delegate to it.
    jvm_functions = SparkContext._active_spark_context._jvm.functions
    java_col = _to_java_column(col)
    return Column(jvm_functions.array_position(java_col, value))
|
||||
|
||||
|
||||
@since(1.4)
|
||||
def explode(col):
|
||||
"""Returns a new row for each element in the given array or map.
|
||||
|
|
|
@ -402,6 +402,7 @@ object FunctionRegistry {
|
|||
// collection functions
|
||||
expression[CreateArray]("array"),
|
||||
expression[ArrayContains]("array_contains"),
|
||||
expression[ArrayPosition]("array_position"),
|
||||
expression[CreateMap]("map"),
|
||||
expression[CreateNamedStruct]("named_struct"),
|
||||
expression[MapKeys]("map_keys"),
|
||||
|
|
|
@ -505,3 +505,59 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast
|
|||
|
||||
override def prettyName: String = "array_max"
|
||||
}
|
||||
|
||||
|
||||
/**
 * Returns the position of the first occurrence of an element in the given array, as a long.
 * Returns 0 if the element is not found in the array, and null if either argument is null.
 *
 * NOTE: positions are 1-based, not 0-based: the first element of the array has position 1.
 */
@ExpressionDescription(
  usage = """
    _FUNC_(array, element) - Returns the (1-based) index of the first occurrence of element
      in array as long, or 0 if element is not found.
  """,
  examples = """
    Examples:
      > SELECT _FUNC_(array(3, 2, 1), 1);
       3
  """,
  since = "2.4.0")
case class ArrayPosition(left: Expression, right: Expression)
  extends BinaryExpression with ImplicitCastInputTypes {

  override def dataType: DataType = LongType

  // Pattern match instead of casting `left.dataType` to ArrayType: an unchecked cast throws
  // ClassCastException as soon as the first argument is not an array. For a non-array input we
  // still declare ArrayType as the expected first type and leave the second argument as it is.
  override def inputTypes: Seq[AbstractDataType] = left.dataType match {
    case ArrayType(elementType, _) => Seq(ArrayType, elementType)
    case _ => Seq(ArrayType, right.dataType)
  }

  /**
   * Interpreted evaluation; both `arr` and `value` are non-null here.
   *
   * Scans the array front to back and stops at the first element equal to `value`.
   * Null elements never match. Returns the 1-based position, or 0 when nothing matches.
   */
  override def nullSafeEval(arr: Any, value: Any): Any = {
    val array = arr.asInstanceOf[ArrayData]
    val elementType = right.dataType
    val numElements = array.numElements()
    // A plain loop with a sentinel avoids a non-local `return` from inside a closure.
    var position = 0L
    var i = 0
    while (position == 0L && i < numElements) {
      // NOTE(review): `==` compares the boxed element values; for element types backed by JVM
      // arrays (e.g. binary) this is presumably reference equality, unlike the
      // `ctx.genEqual` comparison used by the generated code -- confirm.
      if (!array.isNullAt(i) && array.get(i, elementType) == value) {
        position = i + 1L
      }
      i += 1
    }
    position
  }

  override def prettyName: String = "array_position"

  /**
   * Generates Java code that walks the array, skips null elements, and records the 1-based
   * index of the first element equal to the searched value (0 when there is none).
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, (arr, value) => {
      val pos = ctx.freshName("arrayPosition")
      val i = ctx.freshName("i")
      val getValue = CodeGenerator.getValue(arr, right.dataType, i)
      s"""
         |int $pos = 0;
         |for (int $i = 0; $i < $arr.numElements(); $i ++) {
         |  if (!$arr.isNullAt($i) && ${ctx.genEqual(right.dataType, value, getValue)}) {
         |    $pos = $i + 1;
         |    break;
         |  }
         |}
         |${ev.value} = (long) $pos;
       """.stripMargin
    })
  }
}
|
||||
|
|
|
@ -169,4 +169,26 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
|
|||
checkEvaluation(Reverse(as7), null)
|
||||
checkEvaluation(Reverse(aa), Seq(Seq("e"), Seq("c", "d"), Seq("a", "b")))
|
||||
}
|
||||
|
||||
test("Array Position") {
  // Fixtures: an int array with a null in the middle, a string array that starts with null,
  // a long array holding only null, and a null array.
  val intsWithNull = Literal.create(Seq(1, null, 2, 3), ArrayType(IntegerType))
  val stringsWithNull = Literal.create(Seq[String](null, ""), ArrayType(StringType))
  val onlyNull = Literal.create(Seq(null), ArrayType(LongType))
  val nullArray = Literal.create(null, ArrayType(StringType))

  // Typed null values to search for.
  val nullInt = Literal.create(null, IntegerType)
  val nullString = Literal.create(null, StringType)
  val nullLong = Literal.create(null, LongType)

  // Int array: positions are 1-based, a missing value gives 0, a null value gives null.
  checkEvaluation(ArrayPosition(intsWithNull, Literal(3)), 4L)
  checkEvaluation(ArrayPosition(intsWithNull, Literal(1)), 1L)
  checkEvaluation(ArrayPosition(intsWithNull, Literal(0)), 0L)
  checkEvaluation(ArrayPosition(intsWithNull, nullInt), null)

  // String array with a leading null element.
  checkEvaluation(ArrayPosition(stringsWithNull, Literal("")), 2L)
  checkEvaluation(ArrayPosition(stringsWithNull, Literal("a")), 0L)
  checkEvaluation(ArrayPosition(stringsWithNull, nullString), null)

  // Array whose only element is null.
  checkEvaluation(ArrayPosition(onlyNull, Literal(1L)), 0L)
  checkEvaluation(ArrayPosition(onlyNull, nullLong), null)

  // Null array: the result is null whatever the searched value is.
  checkEvaluation(ArrayPosition(nullArray, Literal("")), null)
  checkEvaluation(ArrayPosition(nullArray, nullString), null)
}
|
||||
}
|
||||
|
|
|
@ -3038,6 +3038,20 @@ object functions {
|
|||
ArrayContains(column.expr, Literal(value))
|
||||
}
|
||||
|
||||
/**
 * Finds the first occurrence of `value` in the array held by `column` and returns its
 * position as a long. The result is null if either argument is null.
 *
 * @note Positions are 1-based, not 0-based. The result is 0 if `value` does not occur
 * in the array.
 *
 * @group collection_funcs
 * @since 2.4.0
 */
def array_position(column: Column, value: Any): Column = withExpr {
  val arrayExpr = column.expr
  val valueExpr = Literal(value)
  ArrayPosition(arrayExpr, valueExpr)
}
|
||||
|
||||
/**
|
||||
* Creates a new row for each element in the given array or map column.
|
||||
*
|
||||
|
|
|
@ -535,6 +535,40 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
|
|||
}
|
||||
}
|
||||
|
||||
test("array position function") {
  val df = Seq((Seq[Int](1, 2), "x"), (Seq[Int](), "x")).toDF("a", "b")

  // Searching for 1: found at position 1 in the first row, absent (0) from the empty array.
  val foundThenMissing = Seq(Row(1L), Row(0L))
  checkAnswer(df.select(array_position(df("a"), 1)), foundThenMissing)
  checkAnswer(df.selectExpr("array_position(a, 1)"), foundThenMissing)

  // A null search value yields null for every row.
  val allNull = Seq(Row(null), Row(null))
  checkAnswer(df.select(array_position(df("a"), null)), allNull)
  checkAnswer(df.selectExpr("array_position(a, null)"), allNull)

  // Arguments obtained by indexing into array literals.
  val foundFirst = Seq(Row(1L), Row(1L))
  checkAnswer(df.selectExpr("array_position(array(array(1), null)[0], 1)"), foundFirst)
  checkAnswer(df.selectExpr("array_position(array(1, null), array(1, null)[0])"), foundFirst)
}
|
||||
|
||||
private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
|
||||
import DataFrameFunctionsSuite.CodegenFallbackExpr
|
||||
for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {
|
||||
|
|
Loading…
Reference in a new issue