[SPARK-23926][SQL] Extending reverse function to support ArrayType arguments
## What changes were proposed in this pull request? This PR extends `reverse` functions to be able to operate over array columns and covers: - Introduction of `Reverse` expression that represents logic for reversing arrays and also strings - Removal of `StringReverse` expression - A wrapper for PySpark ## How was this patch tested? New tests added into: - CollectionExpressionsSuite - DataFrameFunctionsSuite ## Codegen examples ### Primitive type ``` val df = Seq( Seq(1, 3, 4, 2), null ).toDF("i") df.filter($"i".isNotNull || $"i".isNull).select(reverse($"i")).debugCodegen ``` Result: ``` /* 032 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 033 */ ArrayData inputadapter_value = inputadapter_isNull ? /* 034 */ null : (inputadapter_row.getArray(0)); /* 035 */ /* 036 */ boolean filter_value = true; /* 037 */ /* 038 */ if (!(!inputadapter_isNull)) { /* 039 */ filter_value = inputadapter_isNull; /* 040 */ } /* 041 */ if (!filter_value) continue; /* 042 */ /* 043 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1); /* 044 */ /* 045 */ boolean project_isNull = inputadapter_isNull; /* 046 */ ArrayData project_value = null; /* 047 */ /* 048 */ if (!inputadapter_isNull) { /* 049 */ final int project_length = inputadapter_value.numElements(); /* 050 */ project_value = inputadapter_value.copy(); /* 051 */ for(int k = 0; k < project_length / 2; k++) { /* 052 */ int l = project_length - k - 1; /* 053 */ boolean isNullAtK = project_value.isNullAt(k); /* 054 */ boolean isNullAtL = project_value.isNullAt(l); /* 055 */ if(!isNullAtK) { /* 056 */ int el = project_value.getInt(k); /* 057 */ if(!isNullAtL) { /* 058 */ project_value.setInt(k, project_value.getInt(l)); /* 059 */ } else { /* 060 */ project_value.setNullAt(k); /* 061 */ } /* 062 */ project_value.setInt(l, el); /* 063 */ } else if (!isNullAtL) { /* 064 */ project_value.setInt(k, project_value.getInt(l)); /* 065 */ project_value.setNullAt(l); /* 066 */ } /* 067 */ } /* 068 */ /* 069 */ } ``` ### Non-primitive type ``` val df = Seq( Seq("a", "c", "d", "b"), null ).toDF("s") df.filter($"s".isNotNull || $"s".isNull).select(reverse($"s")).debugCodegen ``` Result: ``` /* 032 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 033 */ ArrayData inputadapter_value = inputadapter_isNull ? /* 034 */ null : (inputadapter_row.getArray(0)); /* 035 */ /* 036 */ boolean filter_value = true; /* 037 */ /* 038 */ if (!(!inputadapter_isNull)) { /* 039 */ filter_value = inputadapter_isNull; /* 040 */ } /* 041 */ if (!filter_value) continue; /* 042 */ /* 043 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1); /* 044 */ /* 045 */ boolean project_isNull = inputadapter_isNull; /* 046 */ ArrayData project_value = null; /* 047 */ /* 048 */ if (!inputadapter_isNull) { /* 049 */ final int project_length = inputadapter_value.numElements(); /* 050 */ project_value = new org.apache.spark.sql.catalyst.util.GenericArrayData(new Object[project_length]); /* 051 */ for(int k = 0; k < project_length; k++) { /* 052 */ int l = project_length - k - 1; /* 053 */ project_value.update(k, inputadapter_value.getUTF8String(l)); /* 054 */ } /* 055 */ /* 056 */ } ``` Author: mn-mikke <mrkAha12346github> Closes #21034 from mn-mikke/feature/array-api-reverse-to-master.
This commit is contained in:
parent
cce469435d
commit
f81fa478ff
|
@ -1414,7 +1414,6 @@ _string_functions = {
|
|||
'uppercase. Words are delimited by whitespace.',
|
||||
'lower': 'Converts a string column to lower case.',
|
||||
'upper': 'Converts a string column to upper case.',
|
||||
'reverse': 'Reverses the string column and returns it as a new string column.',
|
||||
'ltrim': 'Trim the spaces from left end for the specified string value.',
|
||||
'rtrim': 'Trim the spaces from right end for the specified string value.',
|
||||
'trim': 'Trim the spaces from both ends for the specified string column.',
|
||||
|
@ -2128,6 +2127,25 @@ def sort_array(col, asc=True):
|
|||
return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc))
|
||||
|
||||
|
||||
@since(1.5)
|
||||
@ignore_unicode_prefix
|
||||
def reverse(col):
|
||||
"""
|
||||
Collection function: returns a reversed string or an array with reverse order of elements.
|
||||
|
||||
:param col: name of column or expression
|
||||
|
||||
>>> df = spark.createDataFrame([('Spark SQL',)], ['data'])
|
||||
>>> df.select(reverse(df.data).alias('s')).collect()
|
||||
[Row(s=u'LQS krapS')]
|
||||
>>> df = spark.createDataFrame([([2, 1, 3],) ,([1],) ,([],)], ['data'])
|
||||
>>> df.select(reverse(df.data).alias('r')).collect()
|
||||
[Row(r=[3, 1, 2]), Row(r=[1]), Row(r=[])]
|
||||
"""
|
||||
sc = SparkContext._active_spark_context
|
||||
return Column(sc._jvm.functions.reverse(_to_java_column(col)))
|
||||
|
||||
|
||||
@since(2.3)
|
||||
def map_keys(col):
|
||||
"""
|
||||
|
|
|
@ -336,7 +336,6 @@ object FunctionRegistry {
|
|||
expression[RegExpReplace]("regexp_replace"),
|
||||
expression[StringRepeat]("repeat"),
|
||||
expression[StringReplace]("replace"),
|
||||
expression[StringReverse]("reverse"),
|
||||
expression[RLike]("rlike"),
|
||||
expression[StringRPad]("rpad"),
|
||||
expression[StringTrimRight]("rtrim"),
|
||||
|
@ -411,6 +410,7 @@ object FunctionRegistry {
|
|||
expression[SortArray]("sort_array"),
|
||||
expression[ArrayMin]("array_min"),
|
||||
expression[ArrayMax]("array_max"),
|
||||
expression[Reverse]("reverse"),
|
||||
CreateStruct.registryEntry,
|
||||
|
||||
// misc functions
|
||||
|
|
|
@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
|
|||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
|
||||
/**
|
||||
* Given an array or map, returns its size. Returns -1 if null.
|
||||
|
@ -212,6 +213,93 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
|
|||
override def prettyName: String = "sort_array"
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a reversed string or an array with reverse order of elements.
|
||||
*/
|
||||
@ExpressionDescription(
|
||||
usage = "_FUNC_(array) - Returns a reversed string or an array with reverse order of elements.",
|
||||
examples = """
|
||||
Examples:
|
||||
> SELECT _FUNC_('Spark SQL');
|
||||
LQS krapS
|
||||
> SELECT _FUNC_(array(2, 1, 4, 3));
|
||||
[3, 4, 1, 2]
|
||||
""",
|
||||
since = "1.5.0",
|
||||
note = "Reverse logic for arrays is available since 2.4.0."
|
||||
)
|
||||
case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
|
||||
|
||||
// Input types are utilized by type coercion in ImplicitTypeCasts.
|
||||
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, ArrayType))
|
||||
|
||||
override def dataType: DataType = child.dataType
|
||||
|
||||
lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType
|
||||
|
||||
override def nullSafeEval(input: Any): Any = input match {
|
||||
case a: ArrayData => new GenericArrayData(a.toObjectArray(elementType).reverse)
|
||||
case s: UTF8String => s.reverse()
|
||||
}
|
||||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
nullSafeCodeGen(ctx, ev, c => dataType match {
|
||||
case _: StringType => stringCodeGen(ev, c)
|
||||
case _: ArrayType => arrayCodeGen(ctx, ev, c)
|
||||
})
|
||||
}
|
||||
|
||||
private def stringCodeGen(ev: ExprCode, childName: String): String = {
|
||||
s"${ev.value} = ($childName).reverse();"
|
||||
}
|
||||
|
||||
private def arrayCodeGen(ctx: CodegenContext, ev: ExprCode, childName: String): String = {
|
||||
val length = ctx.freshName("length")
|
||||
val javaElementType = CodeGenerator.javaType(elementType)
|
||||
val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType)
|
||||
|
||||
val initialization = if (isPrimitiveType) {
|
||||
s"$childName.copy()"
|
||||
} else {
|
||||
s"new ${classOf[GenericArrayData].getName()}(new Object[$length])"
|
||||
}
|
||||
|
||||
val numberOfIterations = if (isPrimitiveType) s"$length / 2" else length
|
||||
|
||||
val swapAssigments = if (isPrimitiveType) {
|
||||
val setFunc = "set" + CodeGenerator.primitiveTypeName(elementType)
|
||||
val getCall = (index: String) => CodeGenerator.getValue(ev.value, elementType, index)
|
||||
s"""|boolean isNullAtK = ${ev.value}.isNullAt(k);
|
||||
|boolean isNullAtL = ${ev.value}.isNullAt(l);
|
||||
|if(!isNullAtK) {
|
||||
| $javaElementType el = ${getCall("k")};
|
||||
| if(!isNullAtL) {
|
||||
| ${ev.value}.$setFunc(k, ${getCall("l")});
|
||||
| } else {
|
||||
| ${ev.value}.setNullAt(k);
|
||||
| }
|
||||
| ${ev.value}.$setFunc(l, el);
|
||||
|} else if (!isNullAtL) {
|
||||
| ${ev.value}.$setFunc(k, ${getCall("l")});
|
||||
| ${ev.value}.setNullAt(l);
|
||||
|}""".stripMargin
|
||||
} else {
|
||||
s"${ev.value}.update(k, ${CodeGenerator.getValue(childName, elementType, "l")});"
|
||||
}
|
||||
|
||||
s"""
|
||||
|final int $length = $childName.numElements();
|
||||
|${ev.value} = $initialization;
|
||||
|for(int k = 0; k < $numberOfIterations; k++) {
|
||||
| int l = $length - k - 1;
|
||||
| $swapAssigments
|
||||
|}
|
||||
""".stripMargin
|
||||
}
|
||||
|
||||
override def prettyName: String = "reverse"
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if the array (left) has the element (right)
|
||||
*/
|
||||
|
|
|
@ -1504,26 +1504,6 @@ case class StringRepeat(str: Expression, times: Expression)
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the reversed given string.
|
||||
*/
|
||||
@ExpressionDescription(
|
||||
usage = "_FUNC_(str) - Returns the reversed given string.",
|
||||
examples = """
|
||||
Examples:
|
||||
> SELECT _FUNC_('Spark SQL');
|
||||
LQS krapS
|
||||
""")
|
||||
case class StringReverse(child: Expression) extends UnaryExpression with String2StringExpression {
|
||||
override def convert(v: UTF8String): UTF8String = v.reverse()
|
||||
|
||||
override def prettyName: String = "reverse"
|
||||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
defineCodeGen(ctx, ev, c => s"($c).reverse()")
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a string consisting of n spaces.
|
||||
*/
|
||||
|
|
|
@ -125,4 +125,48 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
|
|||
checkEvaluation(
|
||||
ArrayMax(Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))), 1.123)
|
||||
}
|
||||
|
||||
test("Reverse") {
|
||||
// Primitive-type elements
|
||||
val ai0 = Literal.create(Seq(2, 1, 4, 3), ArrayType(IntegerType))
|
||||
val ai1 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
|
||||
val ai2 = Literal.create(Seq(null, 1, null, 3), ArrayType(IntegerType))
|
||||
val ai3 = Literal.create(Seq(2, null, 4, null), ArrayType(IntegerType))
|
||||
val ai4 = Literal.create(Seq(null, null, null), ArrayType(IntegerType))
|
||||
val ai5 = Literal.create(Seq(1), ArrayType(IntegerType))
|
||||
val ai6 = Literal.create(Seq.empty, ArrayType(IntegerType))
|
||||
val ai7 = Literal.create(null, ArrayType(IntegerType))
|
||||
|
||||
checkEvaluation(Reverse(ai0), Seq(3, 4, 1, 2))
|
||||
checkEvaluation(Reverse(ai1), Seq(3, 1, 2))
|
||||
checkEvaluation(Reverse(ai2), Seq(3, null, 1, null))
|
||||
checkEvaluation(Reverse(ai3), Seq(null, 4, null, 2))
|
||||
checkEvaluation(Reverse(ai4), Seq(null, null, null))
|
||||
checkEvaluation(Reverse(ai5), Seq(1))
|
||||
checkEvaluation(Reverse(ai6), Seq.empty)
|
||||
checkEvaluation(Reverse(ai7), null)
|
||||
|
||||
// Non-primitive-type elements
|
||||
val as0 = Literal.create(Seq("b", "a", "d", "c"), ArrayType(StringType))
|
||||
val as1 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType))
|
||||
val as2 = Literal.create(Seq(null, "a", null, "c"), ArrayType(StringType))
|
||||
val as3 = Literal.create(Seq("b", null, "d", null), ArrayType(StringType))
|
||||
val as4 = Literal.create(Seq(null, null, null), ArrayType(StringType))
|
||||
val as5 = Literal.create(Seq("a"), ArrayType(StringType))
|
||||
val as6 = Literal.create(Seq.empty, ArrayType(StringType))
|
||||
val as7 = Literal.create(null, ArrayType(StringType))
|
||||
val aa = Literal.create(
|
||||
Seq(Seq("a", "b"), Seq("c", "d"), Seq("e")),
|
||||
ArrayType(ArrayType(StringType)))
|
||||
|
||||
checkEvaluation(Reverse(as0), Seq("c", "d", "a", "b"))
|
||||
checkEvaluation(Reverse(as1), Seq("c", "a", "b"))
|
||||
checkEvaluation(Reverse(as2), Seq("c", null, "a", null))
|
||||
checkEvaluation(Reverse(as3), Seq(null, "d", null, "b"))
|
||||
checkEvaluation(Reverse(as4), Seq(null, null, null))
|
||||
checkEvaluation(Reverse(as5), Seq("a"))
|
||||
checkEvaluation(Reverse(as6), Seq.empty)
|
||||
checkEvaluation(Reverse(as7), null)
|
||||
checkEvaluation(Reverse(aa), Seq(Seq("e"), Seq("c", "d"), Seq("a", "b")))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -629,9 +629,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
test("REVERSE") {
|
||||
val s = 'a.string.at(0)
|
||||
val row1 = create_row("abccc")
|
||||
checkEvaluation(StringReverse(Literal("abccc")), "cccba", row1)
|
||||
checkEvaluation(StringReverse(s), "cccba", row1)
|
||||
checkEvaluation(StringReverse(Literal.create(null, StringType)), null, row1)
|
||||
checkEvaluation(Reverse(Literal("abccc")), "cccba", row1)
|
||||
checkEvaluation(Reverse(s), "cccba", row1)
|
||||
checkEvaluation(Reverse(Literal.create(null, StringType)), null, row1)
|
||||
}
|
||||
|
||||
test("SPACE") {
|
||||
|
|
|
@ -2464,14 +2464,6 @@ object functions {
|
|||
StringRepeat(str.expr, lit(n).expr)
|
||||
}
|
||||
|
||||
/**
|
||||
* Reverses the string column and returns it as a new string column.
|
||||
*
|
||||
* @group string_funcs
|
||||
* @since 1.5.0
|
||||
*/
|
||||
def reverse(str: Column): Column = withExpr { StringReverse(str.expr) }
|
||||
|
||||
/**
|
||||
* Trim the spaces from right end for the specified string value.
|
||||
*
|
||||
|
@ -3316,6 +3308,13 @@ object functions {
|
|||
*/
|
||||
def array_max(e: Column): Column = withExpr { ArrayMax(e.expr) }
|
||||
|
||||
/**
|
||||
* Returns a reversed string or an array with reverse order of elements.
|
||||
* @group collection_funcs
|
||||
* @since 1.5.0
|
||||
*/
|
||||
def reverse(e: Column): Column = withExpr { Reverse(e.expr) }
|
||||
|
||||
/**
|
||||
* Returns an unordered array containing the keys of the map.
|
||||
* @group collection_funcs
|
||||
|
|
|
@ -441,6 +441,100 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
|
|||
checkAnswer(df.selectExpr("array_max(a)"), answer)
|
||||
}
|
||||
|
||||
test("reverse function") {
|
||||
val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codegen on
|
||||
|
||||
// String test cases
|
||||
val oneRowDF = Seq(("Spark", 3215)).toDF("s", "i")
|
||||
|
||||
checkAnswer(
|
||||
oneRowDF.select(reverse('s)),
|
||||
Seq(Row("krapS"))
|
||||
)
|
||||
checkAnswer(
|
||||
oneRowDF.selectExpr("reverse(s)"),
|
||||
Seq(Row("krapS"))
|
||||
)
|
||||
checkAnswer(
|
||||
oneRowDF.select(reverse('i)),
|
||||
Seq(Row("5123"))
|
||||
)
|
||||
checkAnswer(
|
||||
oneRowDF.selectExpr("reverse(i)"),
|
||||
Seq(Row("5123"))
|
||||
)
|
||||
checkAnswer(
|
||||
oneRowDF.selectExpr("reverse(null)"),
|
||||
Seq(Row(null))
|
||||
)
|
||||
|
||||
// Array test cases (primitive-type elements)
|
||||
val idf = Seq(
|
||||
Seq(1, 9, 8, 7),
|
||||
Seq(5, 8, 9, 7, 2),
|
||||
Seq.empty,
|
||||
null
|
||||
).toDF("i")
|
||||
|
||||
checkAnswer(
|
||||
idf.select(reverse('i)),
|
||||
Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null))
|
||||
)
|
||||
checkAnswer(
|
||||
idf.filter(dummyFilter('i)).select(reverse('i)),
|
||||
Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null))
|
||||
)
|
||||
checkAnswer(
|
||||
idf.selectExpr("reverse(i)"),
|
||||
Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null))
|
||||
)
|
||||
checkAnswer(
|
||||
oneRowDF.selectExpr("reverse(array(1, null, 2, null))"),
|
||||
Seq(Row(Seq(null, 2, null, 1)))
|
||||
)
|
||||
checkAnswer(
|
||||
oneRowDF.filter(dummyFilter('i)).selectExpr("reverse(array(1, null, 2, null))"),
|
||||
Seq(Row(Seq(null, 2, null, 1)))
|
||||
)
|
||||
|
||||
// Array test cases (non-primitive-type elements)
|
||||
val sdf = Seq(
|
||||
Seq("c", "a", "b"),
|
||||
Seq("b", null, "c", null),
|
||||
Seq.empty,
|
||||
null
|
||||
).toDF("s")
|
||||
|
||||
checkAnswer(
|
||||
sdf.select(reverse('s)),
|
||||
Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null))
|
||||
)
|
||||
checkAnswer(
|
||||
sdf.filter(dummyFilter('s)).select(reverse('s)),
|
||||
Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null))
|
||||
)
|
||||
checkAnswer(
|
||||
sdf.selectExpr("reverse(s)"),
|
||||
Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null))
|
||||
)
|
||||
checkAnswer(
|
||||
oneRowDF.selectExpr("reverse(array(array(1, 2), array(3, 4)))"),
|
||||
Seq(Row(Seq(Seq(3, 4), Seq(1, 2))))
|
||||
)
|
||||
checkAnswer(
|
||||
oneRowDF.filter(dummyFilter('s)).selectExpr("reverse(array(array(1, 2), array(3, 4)))"),
|
||||
Seq(Row(Seq(Seq(3, 4), Seq(1, 2))))
|
||||
)
|
||||
|
||||
// Error test cases
|
||||
intercept[AnalysisException] {
|
||||
oneRowDF.selectExpr("reverse(struct(1, 'a'))")
|
||||
}
|
||||
intercept[AnalysisException] {
|
||||
oneRowDF.selectExpr("reverse(map(1, 'a'))")
|
||||
}
|
||||
}
|
||||
|
||||
private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
|
||||
import DataFrameFunctionsSuite.CodegenFallbackExpr
|
||||
for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {
|
||||
|
|
Loading…
Reference in a new issue