[SPARK-23931][SQL] Adds arrays_zip function to Spark SQL
Signed-off-by: DylanGuedes <djmgguedes@gmail.com> ## What changes were proposed in this pull request? Addition of the arrays_zip function to the Spark SQL functions. ## How was this patch tested? Unit tests that check that the results are correct. Author: DylanGuedes <djmgguedes@gmail.com> Closes #21045 from DylanGuedes/SPARK-23931.
This commit is contained in:
parent
3af1d3e6d9
commit
f0ef1b311d
|
@ -2394,6 +2394,23 @@ def array_repeat(col, count):
|
|||
return Column(sc._jvm.functions.array_repeat(_to_java_column(col), count))
|
||||
|
||||
|
||||
@since(2.4)
def arrays_zip(*cols):
    """
    Collection function: Zips the given arrays into one array of structs, where the N-th struct
    holds the N-th value of every input array.

    :param cols: columns of arrays to be merged.

    >>> from pyspark.sql.functions import arrays_zip
    >>> df = spark.createDataFrame([(([1, 2, 3], [2, 3, 4]))], ['vals1', 'vals2'])
    >>> df.select(arrays_zip(df.vals1, df.vals2).alias('zipped')).collect()
    [Row(zipped=[Row(vals1=1, vals2=2), Row(vals1=2, vals2=3), Row(vals1=3, vals2=4)])]
    """
    sc = SparkContext._active_spark_context
    # Convert every Python column to its JVM counterpart and pack them into a Scala Seq,
    # which is what the JVM-side arrays_zip(Column*) receives.
    jvm_columns = _to_seq(sc, cols, _to_java_column)
    zipped = sc._jvm.functions.arrays_zip(jvm_columns)
    return Column(zipped)
|
||||
|
||||
|
||||
# ---------------------------- User Defined Function ----------------------------------
|
||||
|
||||
class PandasUDFType(object):
|
||||
|
|
|
@ -423,6 +423,7 @@ object FunctionRegistry {
|
|||
expression[Size]("size"),
|
||||
expression[Slice]("slice"),
|
||||
expression[Size]("cardinality"),
|
||||
expression[ArraysZip]("arrays_zip"),
|
||||
expression[SortArray]("sort_array"),
|
||||
expression[ArrayMin]("array_min"),
|
||||
expression[ArrayMax]("array_max"),
|
||||
|
|
|
@ -128,6 +128,172 @@ case class MapKeys(child: Expression)
|
|||
override def prettyName: String = "map_keys"
|
||||
}
|
||||
|
||||
/**
 * Zips the given arrays into a single array of structs, where the N-th struct holds the N-th
 * element of every input array.
 *
 *  - The result has as many elements as the longest input array; a position that a shorter
 *    array does not have, or where it holds null, becomes a null field in the struct.
 *  - If any input array is null, the whole result is null.
 *  - With no input arrays at all, the result is an empty array.
 *
 * Struct fields are named after the child expression when it is a [[NamedExpression]], and
 * after the child's position ("0", "1", ...) otherwise.
 */
@ExpressionDescription(
  usage = """
    _FUNC_(a1, a2, ...) - Returns a merged array of structs in which the N-th struct contains all
    N-th values of input arrays.
  """,
  examples = """
    Examples:
      > SELECT _FUNC_(array(1, 2, 3), array(2, 3, 4));
        [[1, 2], [2, 3], [3, 4]]
      > SELECT _FUNC_(array(1, 2), array(2, 3), array(3, 4));
        [[1, 2, 3], [2, 3, 4]]
  """,
  since = "2.4.0")
case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsInputTypes {

  // Every argument must be an array (of any element type).
  override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.length)(ArrayType)

  override def dataType: DataType = ArrayType(mountSchema)

  override def nullable: Boolean = children.exists(_.nullable)

  @transient private lazy val arrayTypes = children.map(_.dataType.asInstanceOf[ArrayType])

  @transient private lazy val arrayElementTypes = arrayTypes.map(_.elementType)

  /**
   * Schema of one output struct: one nullable field per input array, typed with that array's
   * element type.
   */
  @transient private lazy val mountSchema: StructType = {
    val fields = children.zip(arrayElementTypes).zipWithIndex.map {
      case ((expr: NamedExpression, elementType), _) =>
        StructField(expr.name, elementType, nullable = true)
      case ((_, elementType), idx) =>
        StructField(idx.toString, elementType, nullable = true)
    }
    StructType(fields)
  }

  @transient lazy val numberOfArrays: Int = children.length

  @transient lazy val genericArrayData = classOf[GenericArrayData].getName

  /** Generated code for the zero-argument case: a non-null, empty array. */
  def emptyInputGenCode(ev: ExprCode): ExprCode = {
    ev.copy(code"""
      |${CodeGenerator.javaType(dataType)} ${ev.value} = new $genericArrayData(new Object[0]);
      |boolean ${ev.isNull} = false;
    """.stripMargin)
  }

  /**
   * Generated code for one or more input arrays.
   *
   * Phase 1 evaluates every child, stores its array in `arrVals` and tracks the largest
   * cardinality; the cardinality is set to -1 as soon as a null input is seen, which also
   * stops the evaluation of the remaining children. Phase 2 builds one row per position.
   */
  def nonEmptyInputGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val genericInternalRow = classOf[GenericInternalRow].getName
    val arrVals = ctx.freshName("arrVals")
    val biggestCardinality = ctx.freshName("biggestCardinality")

    val currentRow = ctx.freshName("currentRow")
    val i = ctx.freshName("i")
    val args = ctx.freshName("args")

    val evals = children.map(_.genCode(ctx))
    val getValuesAndCardinalities = evals.zipWithIndex.map { case (eval, index) =>
      s"""
        |if ($biggestCardinality != -1) {
        |  ${eval.code}
        |  if (!${eval.isNull}) {
        |    $arrVals[$index] = ${eval.value};
        |    $biggestCardinality = Math.max($biggestCardinality, ${eval.value}.numElements());
        |  } else {
        |    $biggestCardinality = -1;
        |  }
        |}
      """.stripMargin
    }

    // These snippets embed the children's generated code, which may read the current input
    // row. They must therefore be split with the current inputs forwarded to the helper
    // methods; a plain splitExpressions would only pass the two arguments listed below.
    val splittedGetValuesAndCardinalities = ctx.splitExpressionsWithCurrentInputs(
      expressions = getValuesAndCardinalities,
      funcName = "getValuesAndCardinalities",
      returnType = "int",
      makeSplitFunction = body =>
        s"""
          |$body
          |return $biggestCardinality;
        """.stripMargin,
      foldFunctions = _.map(funcCall => s"$biggestCardinality = $funcCall;").mkString("\n"),
      extraArguments =
        ("ArrayData[]", arrVals) ::
        ("int", biggestCardinality) :: Nil)

    val getValueForType = arrayElementTypes.zipWithIndex.map { case (eleType, idx) =>
      val g = CodeGenerator.getValue(s"$arrVals[$idx]", eleType, i)
      s"""
        |if ($i < $arrVals[$idx].numElements() && !$arrVals[$idx].isNullAt($i)) {
        |  $currentRow[$idx] = $g;
        |} else {
        |  $currentRow[$idx] = null;
        |}
      """.stripMargin
    }

    // These snippets only touch the loop index, the row under construction and the collected
    // arrays, all of which are passed explicitly, so a plain split is sufficient here.
    val getValueForTypeSplitted = ctx.splitExpressions(
      expressions = getValueForType,
      funcName = "extractValue",
      arguments =
        ("int", i) ::
        ("Object[]", currentRow) ::
        ("ArrayData[]", arrVals) :: Nil)

    val initVariables = s"""
      |ArrayData[] $arrVals = new ArrayData[$numberOfArrays];
      |int $biggestCardinality = 0;
      |${CodeGenerator.javaType(dataType)} ${ev.value} = null;
    """.stripMargin

    ev.copy(code"""
      |$initVariables
      |$splittedGetValuesAndCardinalities
      |boolean ${ev.isNull} = $biggestCardinality == -1;
      |if (!${ev.isNull}) {
      |  Object[] $args = new Object[$biggestCardinality];
      |  for (int $i = 0; $i < $biggestCardinality; $i ++) {
      |    Object[] $currentRow = new Object[$numberOfArrays];
      |    $getValueForTypeSplitted
      |    $args[$i] = new $genericInternalRow($currentRow);
      |  }
      |  ${ev.value} = new $genericArrayData($args);
      |}
    """.stripMargin)
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    if (numberOfArrays == 0) {
      emptyInputGenCode(ev)
    } else {
      nonEmptyInputGenCode(ctx, ev)
    }
  }

  /**
   * Interpreted evaluation: returns null when any input array is null, otherwise an array with
   * one row per position up to the length of the longest input.
   */
  override def eval(input: InternalRow): Any = {
    val inputArrays = children.map(_.eval(input).asInstanceOf[ArrayData])
    if (inputArrays.contains(null)) {
      null
    } else {
      // `max` fails on an empty collection, so the zero-argument case is handled explicitly.
      val biggestCardinality = if (inputArrays.isEmpty) {
        0
      } else {
        inputArrays.map(_.numElements()).max
      }

      val result = new Array[InternalRow](biggestCardinality)
      val zippedArrs: Seq[(ArrayData, Int)] = inputArrays.zipWithIndex

      for (i <- 0 until biggestCardinality) {
        val currentLayer: Seq[Object] = zippedArrs.map { case (arr, index) =>
          // Positions past the end of a shorter array are filled with null.
          if (i < arr.numElements() && !arr.isNullAt(i)) {
            arr.get(i, arrayElementTypes(index))
          } else {
            null
          }
        }

        result(i) = InternalRow.apply(currentLayer: _*)
      }
      new GenericArrayData(result)
    }
  }

  override def prettyName: String = "arrays_zip"
}
|
||||
|
||||
/**
|
||||
* Returns an unordered array containing the values of the map.
|
||||
*/
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
package org.apache.spark.sql.catalyst.expressions
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.Row
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
|
@ -315,6 +316,91 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
|
|||
Some(Literal.create(null, StringType))), null)
|
||||
}
|
||||
|
||||
test("ArraysZip") {
  // Input arrays of assorted element types and lengths, most of them containing nulls.
  val ints = Literal.create(Seq(9001, 9002, 9003, null), ArrayType(IntegerType))
  val longs = Literal.create(Seq(null, 1L, null, 4L, 11L), ArrayType(LongType))
  val otherInts = Literal.create(Seq(-1, -3, 900, null), ArrayType(IntegerType))
  val strings = Literal.create(Seq("a", null, "c"), ArrayType(StringType))
  val booleans = Literal.create(Seq(null, false, true), ArrayType(BooleanType))
  val doubles = Literal.create(Seq(1.1, null, 1.3, null), ArrayType(DoubleType))
  val emptyArray = Literal.create(Seq(), ArrayType(NullType))
  val singleNull = Literal.create(Seq(null), ArrayType(NullType))
  val bytes = Literal.create(Seq(192.toByte), ArrayType(ByteType))
  val nestedInts = Literal.create(
    Seq(Seq(1, 2, 3), null, Seq(4, 5), Seq(1, null, 3)), ArrayType(ArrayType(IntegerType)))
  val binaries = Literal.create(Seq(Array[Byte](1.toByte, 5.toByte)), ArrayType(BinaryType))

  // Pairs: the int array zipped with each of the other basic types.
  checkEvaluation(
    ArraysZip(Seq(ints, longs)),
    List(Row(9001, null), Row(9002, 1L), Row(9003, null), Row(null, 4L), Row(null, 11L)))

  checkEvaluation(
    ArraysZip(Seq(ints, otherInts)),
    List(Row(9001, -1), Row(9002, -3), Row(9003, 900), Row(null, null)))

  checkEvaluation(
    ArraysZip(Seq(ints, strings)),
    List(Row(9001, "a"), Row(9002, null), Row(9003, "c"), Row(null, null)))

  checkEvaluation(
    ArraysZip(Seq(ints, booleans)),
    List(Row(9001, null), Row(9002, false), Row(9003, true), Row(null, null)))

  checkEvaluation(
    ArraysZip(Seq(ints, doubles)),
    List(Row(9001, 1.1), Row(9002, null), Row(9003, 1.3), Row(null, null)))

  // An empty array and an array holding a single null both contribute only null fields.
  val intsWithNullColumn =
    List(Row(9001, null), Row(9002, null), Row(9003, null), Row(null, null))
  checkEvaluation(ArraysZip(Seq(ints, emptyArray)), intsWithNullColumn)
  checkEvaluation(ArraysZip(Seq(ints, singleNull)), intsWithNullColumn)

  // More than two inputs.
  checkEvaluation(
    ArraysZip(Seq(ints, longs, otherInts, strings)),
    List(
      Row(9001, null, -1, "a"),
      Row(9002, 1L, -3, null),
      Row(9003, null, 900, "c"),
      Row(null, 4L, null, null),
      Row(null, 11L, null, null)))

  checkEvaluation(
    ArraysZip(Seq(booleans, doubles, emptyArray, singleNull, bytes)),
    List(
      Row(null, 1.1, null, null, 192.toByte),
      Row(false, null, null, null, null),
      Row(true, 1.3, null, null, null),
      Row(null, null, null, null, null)))

  // Nested arrays and binary values as elements.
  checkEvaluation(
    ArraysZip(Seq(nestedInts, ints)),
    List(
      Row(List(1, 2, 3), 9001),
      Row(null, 9002),
      Row(List(4, 5), 9003),
      Row(List(1, null, 3), null)))

  checkEvaluation(
    ArraysZip(Seq(singleNull, binaries)),
    List(Row(null, Array[Byte](1.toByte, 5.toByte))))

  // One very long input array: the short one contributes nulls past its end.
  val longLiteral = Literal.create((0 to 1000).toSeq, ArrayType(IntegerType))
  val shortZippedWithLong =
    List(Row(9001, 0), Row(9002, 1), Row(9003, 2)) ++ (3 to 1000).map { Row(null, _) }.toList
  checkEvaluation(ArraysZip(Seq(ints, longLiteral)), shortZippedWithLong)

  // Very many input arrays: the int array followed by 1001 single-element arrays.
  val extraArrayCount = 1001
  val manyLiterals =
    Seq.fill(extraArrayCount)(Literal.create(Seq(1), ArrayType(IntegerType)))
  def wideRow(first: Any, filler: Any): Row =
    Row(first +: Seq.fill(extraArrayCount)(filler): _*)
  checkEvaluation(
    ArraysZip(ints +: manyLiterals),
    List(wideRow(9001, 1), wideRow(9002, null), wideRow(9003, null), wideRow(null, null)))

  // A null input array makes the whole result null; no inputs yield an empty array.
  checkEvaluation(ArraysZip(Seq(ints, Literal.create(null, ArrayType(IntegerType)))), null)
  checkEvaluation(ArraysZip(Seq()), List())
}
|
||||
|
||||
test("Array Min") {
|
||||
checkEvaluation(ArrayMin(Literal.create(Seq(-11, 10, 2), ArrayType(IntegerType))), -11)
|
||||
checkEvaluation(
|
||||
|
|
|
@ -3508,6 +3508,14 @@ object functions {
|
|||
*/
|
||||
def map_entries(e: Column): Column = withExpr { MapEntries(e.expr) }
|
||||
|
||||
/**
 * Zips the given array columns: the result is an array of structs whose N-th struct holds the
 * N-th value of each input array.
 * @group collection_funcs
 * @since 2.4.0
 */
def arrays_zip(e: Column*): Column = withExpr {
  val inputExpressions = e.map(column => column.expr)
  ArraysZip(inputExpressions)
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Mask functions
|
||||
//////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -479,6 +479,53 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
|
|||
)
|
||||
}
|
||||
|
||||
test("dataframe arrays_zip function") {
  // Every case is checked twice: through the Column API and through a SQL expression.

  // Two integer arrays of equal length.
  val equalLengthInts = Seq((Seq(9001, 9002, 9003), Seq(4, 5, 6))).toDF("val1", "val2")
  val zippedInts = Row(Seq(Row(9001, 4), Row(9002, 5), Row(9003, 6)))
  checkAnswer(equalLengthInts.select(arrays_zip($"val1", $"val2")), zippedInts)
  checkAnswer(equalLengthInts.selectExpr("arrays_zip(val1, val2)"), zippedInts)

  // Three arrays with different element types.
  val threeTypes =
    Seq((Seq("a", "b"), Seq(true, false), Seq(10, 11))).toDF("val1", "val2", "val3")
  val zippedThreeTypes = Row(Seq(Row("a", true, 10), Row("b", false, 11)))
  checkAnswer(threeTypes.select(arrays_zip($"val1", $"val2", $"val3")), zippedThreeTypes)
  checkAnswer(threeTypes.selectExpr("arrays_zip(val1, val2, val3)"), zippedThreeTypes)

  // The first array is shorter than the second.
  val shorterFirst = Seq((Seq("a", "b"), Seq(4, 5, 6))).toDF("val1", "val2")
  val zippedShorterFirst = Row(Seq(Row("a", 4), Row("b", 5), Row(null, 6)))
  checkAnswer(shorterFirst.select(arrays_zip($"val1", $"val2")), zippedShorterFirst)
  checkAnswer(shorterFirst.selectExpr("arrays_zip(val1, val2)"), zippedShorterFirst)

  // The first array ends with a null and the second array is shorter.
  val nullAndShorterSecond = Seq((Seq("a", "b", null), Seq(4L))).toDF("val1", "val2")
  val zippedNullAndShorter = Row(Seq(Row("a", 4L), Row("b", null), Row(null, null)))
  checkAnswer(nullAndShorterSecond.select(arrays_zip($"val1", $"val2")), zippedNullAndShorter)
  checkAnswer(nullAndShorterSecond.selectExpr("arrays_zip(val1, val2)"), zippedNullAndShorter)

  // Empty arrays and arrays made only of nulls.
  val nullOnlyArrays =
    Seq((Seq(-1), Seq(null), Seq(), Seq(null, null))).toDF("val1", "val2", "val3", "val4")
  val zippedNullOnly = Row(Seq(Row(-1, null, null, null), Row(null, null, null, null)))
  checkAnswer(
    nullOnlyArrays.select(arrays_zip($"val1", $"val2", $"val3", $"val4")), zippedNullOnly)
  checkAnswer(nullOnlyArrays.selectExpr("arrays_zip(val1, val2, val3, val4)"), zippedNullOnly)

  // Byte and double elements. Int.toByte truncates, so 192.toByte and 256.toByte are the
  // byte values -64 and 0; the same expressions are used for input and expectation.
  val bytesAndDoubles = Seq((Seq(192.toByte, 256.toByte), Seq(1.1), Seq(), Seq(null, null)))
    .toDF("v1", "v2", "v3", "v4")
  val zippedBytesAndDoubles = Row(Seq(
    Row(192.toByte, 1.1, null, null), Row(256.toByte, null, null, null)))
  checkAnswer(
    bytesAndDoubles.select(arrays_zip($"v1", $"v2", $"v3", $"v4")), zippedBytesAndDoubles)
  checkAnswer(bytesAndDoubles.selectExpr("arrays_zip(v1, v2, v3, v4)"), zippedBytesAndDoubles)

  // Nested arrays as elements.
  val nestedArrays = Seq((Seq(Seq(1, 2, 3), Seq(4, 5)), Seq(1.1, 2.2))).toDF("v1", "v2")
  val zippedNested = Row(Seq(
    Row(Seq(1, 2, 3), 1.1), Row(Seq(4, 5), 2.2)))
  checkAnswer(nestedArrays.select(arrays_zip($"v1", $"v2")), zippedNested)
  checkAnswer(nestedArrays.selectExpr("arrays_zip(v1, v2)"), zippedNested)

  // Binary elements.
  val binaryElements = Seq((Seq(Array[Byte](1.toByte, 5.toByte)), Seq(null))).toDF("v1", "v2")
  val zippedBinary = Row(Seq(
    Row(Array[Byte](1.toByte, 5.toByte), null)))
  checkAnswer(binaryElements.select(arrays_zip($"v1", $"v2")), zippedBinary)
  checkAnswer(binaryElements.selectExpr("arrays_zip(v1, v2)"), zippedBinary)
}
|
||||
|
||||
test("map size function") {
|
||||
val df = Seq(
|
||||
(Map[Int, Int](1 -> 1, 2 -> 2), "x"),
|
||||
|
|
Loading…
Reference in a new issue