[SPARK-23930][SQL] Add slice function

## What changes were proposed in this pull request?

This PR adds the `slice` function. Its behavior is modeled on Presto's `slice` function.

The function slices an array according to the requested start index and length.

## How was this patch tested?

Added unit tests.

Author: Marco Gaido <marcogaido91@gmail.com>

Closes #21040 from mgaido91/SPARK-23930.
This commit is contained in:
Marco Gaido 2018-05-07 16:57:37 +09:00 committed by Takuya UESHIN
parent f06528015d
commit e35ad3cadd
9 changed files with 233 additions and 39 deletions

View file

@ -1834,6 +1834,19 @@ def array_contains(col, value):
return Column(sc._jvm.functions.array_contains(_to_java_column(col), value))
@since(2.4)
def slice(x, start, length):
    """
    Collection function: builds an array out of the elements of `x`, beginning at index
    `start` (counted from the end of the array when `start` is negative) and taking
    `length` elements.

    >>> df = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ['x'])
    >>> df.select(slice(df.x, 2, 2).alias("sliced")).collect()
    [Row(sliced=[2, 3]), Row(sliced=[5])]
    """
    # Delegate to the JVM-side `functions.slice`; `start` and `length` are passed through as-is.
    jvm_functions = SparkContext._active_spark_context._jvm.functions
    return Column(jvm_functions.slice(_to_java_column(x), start, length))
@ignore_unicode_prefix
@since(2.4)
def array_join(col, delimiter, null_replacement=None):

View file

@ -410,6 +410,7 @@ object FunctionRegistry {
expression[MapKeys]("map_keys"),
expression[MapValues]("map_values"),
expression[Size]("size"),
expression[Slice]("slice"),
expression[Size]("cardinality"),
expression[SortArray]("sort_array"),
expression[ArrayMin]("array_min"),

View file

@ -42,6 +42,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.types._
import org.apache.spark.util.{ParentClassLoader, Utils}
@ -730,6 +731,39 @@ class CodegenContext {
""".stripMargin
}
/**
 * Produces the Java source that allocates and initializes an [[UnsafeArrayData]].
 *
 * The emitted code computes the size of the backing byte array, throws a RuntimeException if
 * that size exceeds `ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH`, allocates the byte array,
 * writes the element count into it and points a new `UnsafeArrayData` at it.
 *
 * Note that `numElements` is spliced into the generated code more than once.
 *
 * @param arrayName name of the `UnsafeArrayData` variable declared by the generated code
 * @param numElements Java code evaluating to the number of elements the array should hold
 * @param elementType data type of the array elements; its default size drives the allocation
 * @param additionalErrorMessage text appended to the message of the size-limit exception
 * @return the generated Java code
 */
def createUnsafeArray(
    arrayName: String,
    numElements: String,
    elementType: DataType,
    additionalErrorMessage: String): String = {
  // Fresh names are requested in this order so the generated identifiers stay stable.
  val sizeInBytes = freshName("size")
  val backingBytes = freshName("arrayBytes")
  val elementSize = elementType.defaultSize
  val maxLength = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
  val baseOffset = Platform.BYTE_ARRAY_OFFSET
  s"""
     |long $sizeInBytes = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
     |  $numElements,
     |  $elementSize);
     |if ($sizeInBytes > $maxLength) {
     |  throw new RuntimeException("Unsuccessful try create array with " + $sizeInBytes +
     |    " bytes of data due to exceeding the limit " +
     |    "$maxLength bytes for UnsafeArrayData." +
     |    "$additionalErrorMessage");
     |}
     |byte[] $backingBytes = new byte[(int)$sizeInBytes];
     |UnsafeArrayData $arrayName = new UnsafeArrayData();
     |Platform.putLong($backingBytes, $baseOffset, $numElements);
     |$arrayName.pointTo($backingBytes, $baseOffset, (int)$sizeInBytes);
   """.stripMargin
}
/**
* Generates code to do null safe execution, i.e. only execute the code when the input is not
* null by adding null check if necessary.

View file

@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
@ -530,6 +529,129 @@ case class ArrayContains(left: Expression, right: Expression)
override def prettyName: String = "array_contains"
}
/**
 * Slices an array according to the requested start index and length.
 *
 * `start` is 1-based; a negative `start` counts backwards from the end of the array. A `start`
 * of 0 or a negative `length` raises a RuntimeException. When `start` falls outside the array
 * the result is an empty array, and a `length` that runs past the end of the array is truncated
 * to the elements that are available. The result is null if any input is null.
 */
// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = "_FUNC_(x, start, length) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.",
  examples = """
    Examples:
      > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2);
       [2,3]
      > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2);
       [3,4]
  """, since = "2.4.0")
// scalastyle:on line.size.limit
case class Slice(x: Expression, start: Expression, length: Expression)
  extends TernaryExpression with ImplicitCastInputTypes {

  override def dataType: DataType = x.dataType

  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType)

  override def children: Seq[Expression] = Seq(x, start, length)

  lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType

  /**
   * Interpreted evaluation; only invoked when none of the inputs is null.
   *
   * @param xVal the input array, as [[ArrayData]]
   * @param startVal the 1-based (or negative, end-relative) start position, as Int
   * @param lengthVal the maximum number of elements to return, as Int
   * @return a [[GenericArrayData]] holding the requested elements
   */
  override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = {
    val startInt = startVal.asInstanceOf[Int]
    val lengthInt = lengthVal.asInstanceOf[Int]
    val arr = xVal.asInstanceOf[ArrayData]
    val numElements = arr.numElements()
    // Translate the SQL position into a 0-based index.
    val startIndex = if (startInt == 0) {
      throw new RuntimeException(
        s"Unexpected value for start in function $prettyName: SQL array indices start at 1.")
    } else if (startInt < 0) {
      startInt + numElements
    } else {
      startInt - 1
    }
    if (lengthInt < 0) {
      throw new RuntimeException(s"Unexpected value for length in function $prettyName: " +
        "length must be greater than or equal to 0.")
    }
    // startIndex can be negative if start is negative and its absolute value is greater than the
    // number of elements in the array
    if (startIndex < 0 || startIndex >= numElements) {
      new GenericArrayData(Array.empty[AnyRef])
    } else {
      // Compute the exclusive end index in Long arithmetic and clamp it to the array size:
      // `startIndex + lengthInt` overflows Int for large lengths, which would turn the end index
      // negative and make `slice` return an empty result instead of the tail of the array.
      val endIndex = math.min(startIndex.toLong + lengthInt, numElements.toLong).toInt
      val data = arr.toSeq[AnyRef](elementType)
      new GenericArrayData(data.slice(startIndex, endIndex))
    }
  }

  /**
   * Generates Java code that validates `start` and `length`, derives the 0-based start index and
   * the result length, and then delegates to [[genCodeForResult]] to build the output array.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, (x, start, length) => {
      val startIdx = ctx.freshName("startIdx")
      val resLength = ctx.freshName("resLength")
      val defaultIntValue = CodeGenerator.defaultValue(CodeGenerator.JAVA_INT, false)
      s"""
         |${CodeGenerator.JAVA_INT} $startIdx = $defaultIntValue;
         |${CodeGenerator.JAVA_INT} $resLength = $defaultIntValue;
         |if ($start == 0) {
         |  throw new RuntimeException("Unexpected value for start in function $prettyName: "
         |    + "SQL array indices start at 1.");
         |} else if ($start < 0) {
         |  $startIdx = $start + $x.numElements();
         |} else {
         |  // arrays in SQL are 1-based instead of 0-based
         |  $startIdx = $start - 1;
         |}
         |if ($length < 0) {
         |  throw new RuntimeException("Unexpected value for length in function $prettyName: "
         |    + "length must be greater than or equal to 0.");
         |} else if ($length > $x.numElements() - $startIdx) {
         |  $resLength = $x.numElements() - $startIdx;
         |} else {
         |  $resLength = $length;
         |}
         |${genCodeForResult(ctx, ev, x, startIdx, resLength)}
       """.stripMargin
    })
  }

  /**
   * Generates the Java code that copies the selected elements into the result array.
   *
   * Non-primitive element types are copied into an `Object[]` wrapped in a GenericArrayData;
   * primitive element types are written into an UnsafeArrayData created through
   * `ctx.createUnsafeArray`. In both cases an out-of-range start index yields an empty array.
   *
   * @param ctx the codegen context used to obtain fresh variable names
   * @param ev the expression code whose `value` receives the result
   * @param inputArray name of the Java variable holding the input array
   * @param startIdx name of the Java variable holding the 0-based start index
   * @param resLength name of the Java variable holding the number of elements to copy
   * @return the generated Java code
   */
  def genCodeForResult(
      ctx: CodegenContext,
      ev: ExprCode,
      inputArray: String,
      startIdx: String,
      resLength: String): String = {
    val values = ctx.freshName("values")
    val i = ctx.freshName("i")
    val getValue = CodeGenerator.getValue(inputArray, elementType, s"$i + $startIdx")
    if (!CodeGenerator.isPrimitiveType(elementType)) {
      val arrayClass = classOf[GenericArrayData].getName
      s"""
         |Object[] $values;
         |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) {
         |  $values = new Object[0];
         |} else {
         |  $values = new Object[$resLength];
         |  for (int $i = 0; $i < $resLength; $i ++) {
         |    $values[$i] = $getValue;
         |  }
         |}
         |${ev.value} = new $arrayClass($values);
       """.stripMargin
    } else {
      val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
      s"""
         |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) {
         |  $resLength = 0;
         |}
         |${ctx.createUnsafeArray(values, resLength, elementType, s" $prettyName failed.")}
         |for (int $i = 0; $i < $resLength; $i ++) {
         |  if ($inputArray.isNullAt($i + $startIdx)) {
         |    $values.setNullAt($i);
         |  } else {
         |    $values.set$primitiveValueTypeName($i, $getValue);
         |  }
         |}
         |${ev.value} = $values;
       """.stripMargin
    }
  }
}
/**
* Creates a String containing all the elements of the input array separated by the delimiter.
*/
@ -1127,24 +1249,11 @@ case class Concat(children: Seq[Expression]) extends Expression {
}
private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = {
val arrayName = ctx.freshName("array")
val arraySizeName = ctx.freshName("size")
val counter = ctx.freshName("counter")
val arrayData = ctx.freshName("arrayData")
val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)
val unsafeArraySizeInBytes = s"""
|long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
| $numElemName,
| ${elementType.defaultSize});
|if ($arraySizeName > $MAX_ARRAY_LENGTH) {
| throw new RuntimeException("Unsuccessful try to concat arrays with " + $arraySizeName +
| " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH bytes" +
| " for UnsafeArrayData.");
|}
""".stripMargin
val baseOffset = Platform.BYTE_ARRAY_OFFSET
val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
s"""
@ -1152,11 +1261,7 @@ case class Concat(children: Seq[Expression]) extends Expression {
| public ArrayData concat($javaType[] args) {
| ${nullArgumentProtection()}
| $numElemCode
| $unsafeArraySizeInBytes
| byte[] $arrayName = new byte[(int)$arraySizeName];
| UnsafeArrayData $arrayData = new UnsafeArrayData();
| Platform.putLong($arrayName, $baseOffset, $numElemName);
| $arrayData.pointTo($arrayName, $baseOffset, (int)$arraySizeName);
| ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")}
| int $counter = 0;
| for (int y = 0; y < ${children.length}; y++) {
| for (int z = 0; z < args[y].numElements(); z++) {
@ -1308,34 +1413,16 @@ case class Flatten(child: Expression) extends UnaryExpression {
ctx: CodegenContext,
childVariableName: String,
arrayDataName: String): String = {
val arrayName = ctx.freshName("array")
val arraySizeName = ctx.freshName("size")
val counter = ctx.freshName("counter")
val tempArrayDataName = ctx.freshName("tempArrayData")
val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName)
val unsafeArraySizeInBytes = s"""
|long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
| $numElemName,
| ${elementType.defaultSize});
|if ($arraySizeName > $MAX_ARRAY_LENGTH) {
| throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " +
| $arraySizeName + " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH" +
| " bytes for UnsafeArrayData.");
|}
""".stripMargin
val baseOffset = Platform.BYTE_ARRAY_OFFSET
val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
s"""
|$numElemCode
|$unsafeArraySizeInBytes
|byte[] $arrayName = new byte[(int)$arraySizeName];
|UnsafeArrayData $tempArrayDataName = new UnsafeArrayData();
|Platform.putLong($arrayName, $baseOffset, $numElemName);
|$tempArrayDataName.pointTo($arrayName, $baseOffset, (int)$arraySizeName);
|${ctx.createUnsafeArray(tempArrayDataName, numElemName, elementType, s" $prettyName failed.")}
|int $counter = 0;
|for (int k = 0; k < $childVariableName.numElements(); k++) {
| ArrayData arr = $childVariableName.getArray(k);

View file

@ -136,6 +136,34 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
}
test("Slice") {
  val ints = Literal.create(Seq(1, 2, 3, 4, 5, 6), ArrayType(IntegerType))
  val strings = Literal.create(Seq[String]("a", "b", "c", "d"), ArrayType(StringType))
  val stringsWithNull = Literal.create(Seq[String]("", null, "a", "b"), ArrayType(StringType))
  val intsWithNull = Literal.create(Seq(1, 2, null, 4), ArrayType(IntegerType))

  // Shorthand for a slice whose start and length are integer literals.
  def sliceOf(array: Expression, start: Int, length: Int): Slice =
    Slice(array, Literal(start), Literal(length))

  // Positive and negative start positions, including a length running past the end.
  checkEvaluation(sliceOf(ints, 1, 2), Seq(1, 2))
  checkEvaluation(sliceOf(ints, -3, 2), Seq(4, 5))
  checkEvaluation(sliceOf(ints, 4, 10), Seq(4, 5, 6))
  checkEvaluation(sliceOf(ints, -1, 2), Seq(6))

  // Invalid arguments are rejected.
  checkExceptionInExpression[RuntimeException](sliceOf(ints, 1, -1),
    "Unexpected value for length")
  checkExceptionInExpression[RuntimeException](sliceOf(ints, 0, 1),
    "Unexpected value for start")

  // A negative start beyond the beginning of the array gives an empty result.
  checkEvaluation(sliceOf(ints, -20, 1), Seq.empty[Int])
  checkEvaluation(sliceOf(strings, -20, 1), Seq.empty[String])

  // Any null input makes the result null.
  checkEvaluation(Slice(ints, Literal.create(null, IntegerType), Literal(2)), null)
  checkEvaluation(Slice(ints, Literal(2), Literal.create(null, IntegerType)), null)
  checkEvaluation(sliceOf(Literal.create(null, ArrayType(IntegerType)), 1, 2),
    null)

  // Non-primitive element type, with and without null elements.
  checkEvaluation(sliceOf(strings, 1, 2), Seq("a", "b"))
  checkEvaluation(sliceOf(stringsWithNull, 1, 2), Seq("", null))

  // A start past the end of the array gives an empty result.
  checkEvaluation(sliceOf(ints, 10, 1), Seq.empty[Int])
  checkEvaluation(sliceOf(strings, 10, 1), Seq.empty[String])

  // Null elements inside the selected range are preserved.
  checkEvaluation(sliceOf(intsWithNull, 2, 3), Seq(2, null, 4))
}
test("ArrayJoin") {
def testArrays(
arrays: Seq[Expression],

View file

@ -104,6 +104,12 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
}
}
/**
 * Convenience overload for expressions that need no input: runs the row-based
 * `checkExceptionInExpression` against an empty input row.
 *
 * @param expression the expression expected to fail; passed by name
 * @param expectedErrMsg the error message text handed on to the row-based overload
 */
protected def checkExceptionInExpression[T <: Throwable : ClassTag](
    expression: => Expression,
    expectedErrMsg: String): Unit = {
  val emptyRow = InternalRow.empty
  checkExceptionInExpression[T](expression, emptyRow, expectedErrMsg)
}
protected def checkExceptionInExpression[T <: Throwable : ClassTag](
expression: => Expression,
inputRow: InternalRow,

View file

@ -223,7 +223,6 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
Literal.fromObject(new java.util.LinkedList[Int]),
Map("nonexisting" -> Literal(1)))
checkExceptionInExpression[Exception](initializeWithNonexistingMethod,
InternalRow.fromSeq(Seq()),
"""A method named "nonexisting" is not declared in any enclosing class """ +
"nor any supertype")

View file

@ -3039,6 +3039,16 @@ object functions {
ArrayContains(column.expr, Literal(value))
}
/**
 * Returns an array made of the elements of `x` beginning at index `start` (counted from the end
 * of the array when `start` is negative) and spanning the specified `length`.
 *
 * @param x the array column to slice
 * @param start the position to start from
 * @param length the number of elements to take
 * @group collection_funcs
 * @since 2.4.0
 */
def slice(x: Column, start: Int, length: Int): Column = withExpr {
  val startExpr = Literal(start)
  val lengthExpr = Literal(length)
  Slice(x.expr, startExpr, lengthExpr)
}
/**
* Concatenates the elements of `column` using the `delimiter`. Null values are replaced with
* `nullReplacement`.

View file

@ -442,6 +442,22 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
)
}
test("slice function") {
  val input = Seq(Seq(1, 2, 3), Seq(4, 5)).toDF("x")
  val xCol = input("x")

  // Positive start: same result through the Column API and the SQL expression.
  val expectedFromSecond = Seq(Row(Seq(2, 3)), Row(Seq(5)))
  checkAnswer(input.select(slice(xCol, 2, 2)), expectedFromSecond)
  checkAnswer(input.selectExpr("slice(x, 2, 2)"), expectedFromSecond)

  // Negative start: the position is counted from the end of each array.
  val expectedLastElement = Seq(Row(Seq(3)), Row(Seq(5)))
  checkAnswer(input.select(slice(xCol, -1, 1)), expectedLastElement)
  checkAnswer(input.selectExpr("slice(x, -1, 1)"), expectedLastElement)
}
test("array_join function") {
val df = Seq(
(Seq[String]("a", "b"), ","),