[SPARK-23930][SQL] Add slice function
## What changes were proposed in this pull request? The PR add the `slice` function. The behavior of the function is based on Presto's one. The function slices an array according to the requested start index and length. ## How was this patch tested? added UTs Author: Marco Gaido <marcogaido91@gmail.com> Closes #21040 from mgaido91/SPARK-23930.
This commit is contained in:
parent
f06528015d
commit
e35ad3cadd
|
@ -1834,6 +1834,19 @@ def array_contains(col, value):
|
|||
return Column(sc._jvm.functions.array_contains(_to_java_column(col), value))
|
||||
|
||||
|
||||
@since(2.4)
|
||||
def slice(x, start, length):
|
||||
"""
|
||||
Collection function: returns an array containing all the elements in `x` from index `start`
|
||||
(or starting from the end if `start` is negative) with the specified `length`.
|
||||
>>> df = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ['x'])
|
||||
>>> df.select(slice(df.x, 2, 2).alias("sliced")).collect()
|
||||
[Row(sliced=[2, 3]), Row(sliced=[5])]
|
||||
"""
|
||||
sc = SparkContext._active_spark_context
|
||||
return Column(sc._jvm.functions.slice(_to_java_column(x), start, length))
|
||||
|
||||
|
||||
@ignore_unicode_prefix
|
||||
@since(2.4)
|
||||
def array_join(col, delimiter, null_replacement=None):
|
||||
|
|
|
@ -410,6 +410,7 @@ object FunctionRegistry {
|
|||
expression[MapKeys]("map_keys"),
|
||||
expression[MapValues]("map_values"),
|
||||
expression[Size]("size"),
|
||||
expression[Slice]("slice"),
|
||||
expression[Size]("cardinality"),
|
||||
expression[SortArray]("sort_array"),
|
||||
expression[ArrayMin]("array_min"),
|
||||
|
|
|
@ -42,6 +42,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
|
|||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.Platform
|
||||
import org.apache.spark.unsafe.array.ByteArrayMethods
|
||||
import org.apache.spark.unsafe.types._
|
||||
import org.apache.spark.util.{ParentClassLoader, Utils}
|
||||
|
||||
|
@ -730,6 +731,39 @@ class CodegenContext {
|
|||
""".stripMargin
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates code creating a [[UnsafeArrayData]].
|
||||
*
|
||||
* @param arrayName name of the array to create
|
||||
* @param numElements code representing the number of elements the array should contain
|
||||
* @param elementType data type of the elements in the array
|
||||
* @param additionalErrorMessage string to include in the error message
|
||||
*/
|
||||
def createUnsafeArray(
|
||||
arrayName: String,
|
||||
numElements: String,
|
||||
elementType: DataType,
|
||||
additionalErrorMessage: String): String = {
|
||||
val arraySize = freshName("size")
|
||||
val arrayBytes = freshName("arrayBytes")
|
||||
|
||||
s"""
|
||||
|long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
|
||||
| $numElements,
|
||||
| ${elementType.defaultSize});
|
||||
|if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
|
||||
| throw new RuntimeException("Unsuccessful try create array with " + $arraySize +
|
||||
| " bytes of data due to exceeding the limit " +
|
||||
| "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} bytes for UnsafeArrayData." +
|
||||
| "$additionalErrorMessage");
|
||||
|}
|
||||
|byte[] $arrayBytes = new byte[(int)$arraySize];
|
||||
|UnsafeArrayData $arrayName = new UnsafeArrayData();
|
||||
|Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, $numElements);
|
||||
|$arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, (int)$arraySize);
|
||||
""".stripMargin
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates code to do null safe execution, i.e. only execute the code when the input is not
|
||||
* null by adding null check if necessary.
|
||||
|
|
|
@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
|
|||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.Platform
|
||||
import org.apache.spark.unsafe.array.ByteArrayMethods
|
||||
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
|
||||
|
||||
|
@ -530,6 +529,129 @@ case class ArrayContains(left: Expression, right: Expression)
|
|||
override def prettyName: String = "array_contains"
|
||||
}
|
||||
|
||||
/**
|
||||
* Slices an array according to the requested start index and length
|
||||
*/
|
||||
// scalastyle:off line.size.limit
|
||||
@ExpressionDescription(
|
||||
usage = "_FUNC_(x, start, length) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.",
|
||||
examples = """
|
||||
Examples:
|
||||
> SELECT _FUNC_(array(1, 2, 3, 4), 2, 2);
|
||||
[2,3]
|
||||
> SELECT _FUNC_(array(1, 2, 3, 4), -2, 2);
|
||||
[3,4]
|
||||
""", since = "2.4.0")
|
||||
// scalastyle:on line.size.limit
|
||||
case class Slice(x: Expression, start: Expression, length: Expression)
|
||||
extends TernaryExpression with ImplicitCastInputTypes {
|
||||
|
||||
override def dataType: DataType = x.dataType
|
||||
|
||||
override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType)
|
||||
|
||||
override def children: Seq[Expression] = Seq(x, start, length)
|
||||
|
||||
lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType
|
||||
|
||||
override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = {
|
||||
val startInt = startVal.asInstanceOf[Int]
|
||||
val lengthInt = lengthVal.asInstanceOf[Int]
|
||||
val arr = xVal.asInstanceOf[ArrayData]
|
||||
val startIndex = if (startInt == 0) {
|
||||
throw new RuntimeException(
|
||||
s"Unexpected value for start in function $prettyName: SQL array indices start at 1.")
|
||||
} else if (startInt < 0) {
|
||||
startInt + arr.numElements()
|
||||
} else {
|
||||
startInt - 1
|
||||
}
|
||||
if (lengthInt < 0) {
|
||||
throw new RuntimeException(s"Unexpected value for length in function $prettyName: " +
|
||||
"length must be greater than or equal to 0.")
|
||||
}
|
||||
// startIndex can be negative if start is negative and its absolute value is greater than the
|
||||
// number of elements in the array
|
||||
if (startIndex < 0 || startIndex >= arr.numElements()) {
|
||||
return new GenericArrayData(Array.empty[AnyRef])
|
||||
}
|
||||
val data = arr.toSeq[AnyRef](elementType)
|
||||
new GenericArrayData(data.slice(startIndex, startIndex + lengthInt))
|
||||
}
|
||||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
nullSafeCodeGen(ctx, ev, (x, start, length) => {
|
||||
val startIdx = ctx.freshName("startIdx")
|
||||
val resLength = ctx.freshName("resLength")
|
||||
val defaultIntValue = CodeGenerator.defaultValue(CodeGenerator.JAVA_INT, false)
|
||||
s"""
|
||||
|${CodeGenerator.JAVA_INT} $startIdx = $defaultIntValue;
|
||||
|${CodeGenerator.JAVA_INT} $resLength = $defaultIntValue;
|
||||
|if ($start == 0) {
|
||||
| throw new RuntimeException("Unexpected value for start in function $prettyName: "
|
||||
| + "SQL array indices start at 1.");
|
||||
|} else if ($start < 0) {
|
||||
| $startIdx = $start + $x.numElements();
|
||||
|} else {
|
||||
| // arrays in SQL are 1-based instead of 0-based
|
||||
| $startIdx = $start - 1;
|
||||
|}
|
||||
|if ($length < 0) {
|
||||
| throw new RuntimeException("Unexpected value for length in function $prettyName: "
|
||||
| + "length must be greater than or equal to 0.");
|
||||
|} else if ($length > $x.numElements() - $startIdx) {
|
||||
| $resLength = $x.numElements() - $startIdx;
|
||||
|} else {
|
||||
| $resLength = $length;
|
||||
|}
|
||||
|${genCodeForResult(ctx, ev, x, startIdx, resLength)}
|
||||
""".stripMargin
|
||||
})
|
||||
}
|
||||
|
||||
def genCodeForResult(
|
||||
ctx: CodegenContext,
|
||||
ev: ExprCode,
|
||||
inputArray: String,
|
||||
startIdx: String,
|
||||
resLength: String): String = {
|
||||
val values = ctx.freshName("values")
|
||||
val i = ctx.freshName("i")
|
||||
val getValue = CodeGenerator.getValue(inputArray, elementType, s"$i + $startIdx")
|
||||
if (!CodeGenerator.isPrimitiveType(elementType)) {
|
||||
val arrayClass = classOf[GenericArrayData].getName
|
||||
s"""
|
||||
|Object[] $values;
|
||||
|if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) {
|
||||
| $values = new Object[0];
|
||||
|} else {
|
||||
| $values = new Object[$resLength];
|
||||
| for (int $i = 0; $i < $resLength; $i ++) {
|
||||
| $values[$i] = $getValue;
|
||||
| }
|
||||
|}
|
||||
|${ev.value} = new $arrayClass($values);
|
||||
""".stripMargin
|
||||
} else {
|
||||
val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
|
||||
s"""
|
||||
|if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) {
|
||||
| $resLength = 0;
|
||||
|}
|
||||
|${ctx.createUnsafeArray(values, resLength, elementType, s" $prettyName failed.")}
|
||||
|for (int $i = 0; $i < $resLength; $i ++) {
|
||||
| if ($inputArray.isNullAt($i + $startIdx)) {
|
||||
| $values.setNullAt($i);
|
||||
| } else {
|
||||
| $values.set$primitiveValueTypeName($i, $getValue);
|
||||
| }
|
||||
|}
|
||||
|${ev.value} = $values;
|
||||
""".stripMargin
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a String containing all the elements of the input array separated by the delimiter.
|
||||
*/
|
||||
|
@ -1127,24 +1249,11 @@ case class Concat(children: Seq[Expression]) extends Expression {
|
|||
}
|
||||
|
||||
private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = {
|
||||
val arrayName = ctx.freshName("array")
|
||||
val arraySizeName = ctx.freshName("size")
|
||||
val counter = ctx.freshName("counter")
|
||||
val arrayData = ctx.freshName("arrayData")
|
||||
|
||||
val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)
|
||||
|
||||
val unsafeArraySizeInBytes = s"""
|
||||
|long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
|
||||
| $numElemName,
|
||||
| ${elementType.defaultSize});
|
||||
|if ($arraySizeName > $MAX_ARRAY_LENGTH) {
|
||||
| throw new RuntimeException("Unsuccessful try to concat arrays with " + $arraySizeName +
|
||||
| " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH bytes" +
|
||||
| " for UnsafeArrayData.");
|
||||
|}
|
||||
""".stripMargin
|
||||
val baseOffset = Platform.BYTE_ARRAY_OFFSET
|
||||
val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
|
||||
|
||||
s"""
|
||||
|
@ -1152,11 +1261,7 @@ case class Concat(children: Seq[Expression]) extends Expression {
|
|||
| public ArrayData concat($javaType[] args) {
|
||||
| ${nullArgumentProtection()}
|
||||
| $numElemCode
|
||||
| $unsafeArraySizeInBytes
|
||||
| byte[] $arrayName = new byte[(int)$arraySizeName];
|
||||
| UnsafeArrayData $arrayData = new UnsafeArrayData();
|
||||
| Platform.putLong($arrayName, $baseOffset, $numElemName);
|
||||
| $arrayData.pointTo($arrayName, $baseOffset, (int)$arraySizeName);
|
||||
| ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")}
|
||||
| int $counter = 0;
|
||||
| for (int y = 0; y < ${children.length}; y++) {
|
||||
| for (int z = 0; z < args[y].numElements(); z++) {
|
||||
|
@ -1308,34 +1413,16 @@ case class Flatten(child: Expression) extends UnaryExpression {
|
|||
ctx: CodegenContext,
|
||||
childVariableName: String,
|
||||
arrayDataName: String): String = {
|
||||
val arrayName = ctx.freshName("array")
|
||||
val arraySizeName = ctx.freshName("size")
|
||||
val counter = ctx.freshName("counter")
|
||||
val tempArrayDataName = ctx.freshName("tempArrayData")
|
||||
|
||||
val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName)
|
||||
|
||||
val unsafeArraySizeInBytes = s"""
|
||||
|long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
|
||||
| $numElemName,
|
||||
| ${elementType.defaultSize});
|
||||
|if ($arraySizeName > $MAX_ARRAY_LENGTH) {
|
||||
| throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " +
|
||||
| $arraySizeName + " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH" +
|
||||
| " bytes for UnsafeArrayData.");
|
||||
|}
|
||||
""".stripMargin
|
||||
val baseOffset = Platform.BYTE_ARRAY_OFFSET
|
||||
|
||||
val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
|
||||
|
||||
s"""
|
||||
|$numElemCode
|
||||
|$unsafeArraySizeInBytes
|
||||
|byte[] $arrayName = new byte[(int)$arraySizeName];
|
||||
|UnsafeArrayData $tempArrayDataName = new UnsafeArrayData();
|
||||
|Platform.putLong($arrayName, $baseOffset, $numElemName);
|
||||
|$tempArrayDataName.pointTo($arrayName, $baseOffset, (int)$arraySizeName);
|
||||
|${ctx.createUnsafeArray(tempArrayDataName, numElemName, elementType, s" $prettyName failed.")}
|
||||
|int $counter = 0;
|
||||
|for (int k = 0; k < $childVariableName.numElements(); k++) {
|
||||
| ArrayData arr = $childVariableName.getArray(k);
|
||||
|
|
|
@ -136,6 +136,34 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
|
|||
checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
|
||||
}
|
||||
|
||||
test("Slice") {
|
||||
val a0 = Literal.create(Seq(1, 2, 3, 4, 5, 6), ArrayType(IntegerType))
|
||||
val a1 = Literal.create(Seq[String]("a", "b", "c", "d"), ArrayType(StringType))
|
||||
val a2 = Literal.create(Seq[String]("", null, "a", "b"), ArrayType(StringType))
|
||||
val a3 = Literal.create(Seq(1, 2, null, 4), ArrayType(IntegerType))
|
||||
|
||||
checkEvaluation(Slice(a0, Literal(1), Literal(2)), Seq(1, 2))
|
||||
checkEvaluation(Slice(a0, Literal(-3), Literal(2)), Seq(4, 5))
|
||||
checkEvaluation(Slice(a0, Literal(4), Literal(10)), Seq(4, 5, 6))
|
||||
checkEvaluation(Slice(a0, Literal(-1), Literal(2)), Seq(6))
|
||||
checkExceptionInExpression[RuntimeException](Slice(a0, Literal(1), Literal(-1)),
|
||||
"Unexpected value for length")
|
||||
checkExceptionInExpression[RuntimeException](Slice(a0, Literal(0), Literal(1)),
|
||||
"Unexpected value for start")
|
||||
checkEvaluation(Slice(a0, Literal(-20), Literal(1)), Seq.empty[Int])
|
||||
checkEvaluation(Slice(a1, Literal(-20), Literal(1)), Seq.empty[String])
|
||||
checkEvaluation(Slice(a0, Literal.create(null, IntegerType), Literal(2)), null)
|
||||
checkEvaluation(Slice(a0, Literal(2), Literal.create(null, IntegerType)), null)
|
||||
checkEvaluation(Slice(Literal.create(null, ArrayType(IntegerType)), Literal(1), Literal(2)),
|
||||
null)
|
||||
|
||||
checkEvaluation(Slice(a1, Literal(1), Literal(2)), Seq("a", "b"))
|
||||
checkEvaluation(Slice(a2, Literal(1), Literal(2)), Seq("", null))
|
||||
checkEvaluation(Slice(a0, Literal(10), Literal(1)), Seq.empty[Int])
|
||||
checkEvaluation(Slice(a1, Literal(10), Literal(1)), Seq.empty[String])
|
||||
checkEvaluation(Slice(a3, Literal(2), Literal(3)), Seq(2, null, 4))
|
||||
}
|
||||
|
||||
test("ArrayJoin") {
|
||||
def testArrays(
|
||||
arrays: Seq[Expression],
|
||||
|
|
|
@ -104,6 +104,12 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
|
|||
}
|
||||
}
|
||||
|
||||
protected def checkExceptionInExpression[T <: Throwable : ClassTag](
|
||||
expression: => Expression,
|
||||
expectedErrMsg: String): Unit = {
|
||||
checkExceptionInExpression[T](expression, InternalRow.empty, expectedErrMsg)
|
||||
}
|
||||
|
||||
protected def checkExceptionInExpression[T <: Throwable : ClassTag](
|
||||
expression: => Expression,
|
||||
inputRow: InternalRow,
|
||||
|
|
|
@ -223,7 +223,6 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
Literal.fromObject(new java.util.LinkedList[Int]),
|
||||
Map("nonexisting" -> Literal(1)))
|
||||
checkExceptionInExpression[Exception](initializeWithNonexistingMethod,
|
||||
InternalRow.fromSeq(Seq()),
|
||||
"""A method named "nonexisting" is not declared in any enclosing class """ +
|
||||
"nor any supertype")
|
||||
|
||||
|
|
|
@ -3039,6 +3039,16 @@ object functions {
|
|||
ArrayContains(column.expr, Literal(value))
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an array containing all the elements in `x` from index `start` (or starting from the
|
||||
* end if `start` is negative) with the specified `length`.
|
||||
* @group collection_funcs
|
||||
* @since 2.4.0
|
||||
*/
|
||||
def slice(x: Column, start: Int, length: Int): Column = withExpr {
|
||||
Slice(x.expr, Literal(start), Literal(length))
|
||||
}
|
||||
|
||||
/**
|
||||
* Concatenates the elements of `column` using the `delimiter`. Null values are replaced with
|
||||
* `nullReplacement`.
|
||||
|
|
|
@ -442,6 +442,22 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
|
|||
)
|
||||
}
|
||||
|
||||
test("slice function") {
|
||||
val df = Seq(
|
||||
Seq(1, 2, 3),
|
||||
Seq(4, 5)
|
||||
).toDF("x")
|
||||
|
||||
val answer = Seq(Row(Seq(2, 3)), Row(Seq(5)))
|
||||
|
||||
checkAnswer(df.select(slice(df("x"), 2, 2)), answer)
|
||||
checkAnswer(df.selectExpr("slice(x, 2, 2)"), answer)
|
||||
|
||||
val answerNegative = Seq(Row(Seq(3)), Row(Seq(5)))
|
||||
checkAnswer(df.select(slice(df("x"), -1, 1)), answerNegative)
|
||||
checkAnswer(df.selectExpr("slice(x, -1, 1)"), answerNegative)
|
||||
}
|
||||
|
||||
test("array_join function") {
|
||||
val df = Seq(
|
||||
(Seq[String]("a", "b"), ","),
|
||||
|
|
Loading…
Reference in a new issue