[SPARK-23931][SQL] Adds arrays_zip function to sparksql

Signed-off-by: DylanGuedes <djmgguedesgmail.com>

## What changes were proposed in this pull request?

Addition of arrays_zip function to spark sql functions.

## How was this patch tested?

(Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests)
Unit tests that checks if the results are correct.

Author: DylanGuedes <djmgguedes@gmail.com>

Closes #21045 from DylanGuedes/SPARK-23931.
This commit is contained in:
DylanGuedes 2018-06-12 11:57:25 -07:00 committed by Takuya UESHIN
parent 3af1d3e6d9
commit f0ef1b311d
6 changed files with 325 additions and 0 deletions

View file

@ -2394,6 +2394,23 @@ def array_repeat(col, count):
return Column(sc._jvm.functions.array_repeat(_to_java_column(col), count))
@since(2.4)
def arrays_zip(*cols):
"""
Collection function: Returns a merged array of structs in which the N-th struct contains all
N-th values of input arrays.
:param cols: columns of arrays to be merged.
>>> from pyspark.sql.functions import arrays_zip
>>> df = spark.createDataFrame([(([1, 2, 3], [2, 3, 4]))], ['vals1', 'vals2'])
>>> df.select(arrays_zip(df.vals1, df.vals2).alias('zipped')).collect()
[Row(zipped=[Row(vals1=1, vals2=2), Row(vals1=2, vals2=3), Row(vals1=3, vals2=4)])]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.arrays_zip(_to_seq(sc, cols, _to_java_column)))
# ---------------------------- User Defined Function ----------------------------------
class PandasUDFType(object):

View file

@ -423,6 +423,7 @@ object FunctionRegistry {
expression[Size]("size"),
expression[Slice]("slice"),
expression[Size]("cardinality"),
expression[ArraysZip]("arrays_zip"),
expression[SortArray]("sort_array"),
expression[ArrayMin]("array_min"),
expression[ArrayMax]("array_max"),

View file

@ -128,6 +128,172 @@ case class MapKeys(child: Expression)
override def prettyName: String = "map_keys"
}
@ExpressionDescription(
usage = """
_FUNC_(a1, a2, ...) - Returns a merged array of structs in which the N-th struct contains all
N-th values of input arrays.
""",
examples = """
Examples:
> SELECT _FUNC_(array(1, 2, 3), array(2, 3, 4));
[[1, 2], [2, 3], [3, 4]]
> SELECT _FUNC_(array(1, 2), array(2, 3), array(3, 4));
[[1, 2, 3], [2, 3, 4]]
""",
since = "2.4.0")
case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.length)(ArrayType)
override def dataType: DataType = ArrayType(mountSchema)
override def nullable: Boolean = children.exists(_.nullable)
private lazy val arrayTypes = children.map(_.dataType.asInstanceOf[ArrayType])
private lazy val arrayElementTypes = arrayTypes.map(_.elementType)
@transient private lazy val mountSchema: StructType = {
val fields = children.zip(arrayElementTypes).zipWithIndex.map {
case ((expr: NamedExpression, elementType), _) =>
StructField(expr.name, elementType, nullable = true)
case ((_, elementType), idx) =>
StructField(idx.toString, elementType, nullable = true)
}
StructType(fields)
}
@transient lazy val numberOfArrays: Int = children.length
@transient lazy val genericArrayData = classOf[GenericArrayData].getName
def emptyInputGenCode(ev: ExprCode): ExprCode = {
ev.copy(code"""
|${CodeGenerator.javaType(dataType)} ${ev.value} = new $genericArrayData(new Object[0]);
|boolean ${ev.isNull} = false;
""".stripMargin)
}
def nonEmptyInputGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val genericInternalRow = classOf[GenericInternalRow].getName
val arrVals = ctx.freshName("arrVals")
val biggestCardinality = ctx.freshName("biggestCardinality")
val currentRow = ctx.freshName("currentRow")
val j = ctx.freshName("j")
val i = ctx.freshName("i")
val args = ctx.freshName("args")
val evals = children.map(_.genCode(ctx))
val getValuesAndCardinalities = evals.zipWithIndex.map { case (eval, index) =>
s"""
|if ($biggestCardinality != -1) {
| ${eval.code}
| if (!${eval.isNull}) {
| $arrVals[$index] = ${eval.value};
| $biggestCardinality = Math.max($biggestCardinality, ${eval.value}.numElements());
| } else {
| $biggestCardinality = -1;
| }
|}
""".stripMargin
}
val splittedGetValuesAndCardinalities = ctx.splitExpressions(
expressions = getValuesAndCardinalities,
funcName = "getValuesAndCardinalities",
returnType = "int",
makeSplitFunction = body =>
s"""
|$body
|return $biggestCardinality;
""".stripMargin,
foldFunctions = _.map(funcCall => s"$biggestCardinality = $funcCall;").mkString("\n"),
arguments =
("ArrayData[]", arrVals) ::
("int", biggestCardinality) :: Nil)
val getValueForType = arrayElementTypes.zipWithIndex.map { case (eleType, idx) =>
val g = CodeGenerator.getValue(s"$arrVals[$idx]", eleType, i)
s"""
|if ($i < $arrVals[$idx].numElements() && !$arrVals[$idx].isNullAt($i)) {
| $currentRow[$idx] = $g;
|} else {
| $currentRow[$idx] = null;
|}
""".stripMargin
}
val getValueForTypeSplitted = ctx.splitExpressions(
expressions = getValueForType,
funcName = "extractValue",
arguments =
("int", i) ::
("Object[]", currentRow) ::
("ArrayData[]", arrVals) :: Nil)
val initVariables = s"""
|ArrayData[] $arrVals = new ArrayData[$numberOfArrays];
|int $biggestCardinality = 0;
|${CodeGenerator.javaType(dataType)} ${ev.value} = null;
""".stripMargin
ev.copy(code"""
|$initVariables
|$splittedGetValuesAndCardinalities
|boolean ${ev.isNull} = $biggestCardinality == -1;
|if (!${ev.isNull}) {
| Object[] $args = new Object[$biggestCardinality];
| for (int $i = 0; $i < $biggestCardinality; $i ++) {
| Object[] $currentRow = new Object[$numberOfArrays];
| $getValueForTypeSplitted
| $args[$i] = new $genericInternalRow($currentRow);
| }
| ${ev.value} = new $genericArrayData($args);
|}
""".stripMargin)
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
if (numberOfArrays == 0) {
emptyInputGenCode(ev)
} else {
nonEmptyInputGenCode(ctx, ev)
}
}
override def eval(input: InternalRow): Any = {
val inputArrays = children.map(_.eval(input).asInstanceOf[ArrayData])
if (inputArrays.contains(null)) {
null
} else {
val biggestCardinality = if (inputArrays.isEmpty) {
0
} else {
inputArrays.map(_.numElements()).max
}
val result = new Array[InternalRow](biggestCardinality)
val zippedArrs: Seq[(ArrayData, Int)] = inputArrays.zipWithIndex
for (i <- 0 until biggestCardinality) {
val currentLayer: Seq[Object] = zippedArrs.map { case (arr, index) =>
if (i < arr.numElements() && !arr.isNullAt(i)) {
arr.get(i, arrayElementTypes(index))
} else {
null
}
}
result(i) = InternalRow.apply(currentLayer: _*)
}
new GenericArrayData(result)
}
}
override def prettyName: String = "arrays_zip"
}
/**
* Returns an unordered array containing the values of the map.
*/

View file

@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
@ -315,6 +316,91 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
Some(Literal.create(null, StringType))), null)
}
test("ArraysZip") {
val literals = Seq(
Literal.create(Seq(9001, 9002, 9003, null), ArrayType(IntegerType)),
Literal.create(Seq(null, 1L, null, 4L, 11L), ArrayType(LongType)),
Literal.create(Seq(-1, -3, 900, null), ArrayType(IntegerType)),
Literal.create(Seq("a", null, "c"), ArrayType(StringType)),
Literal.create(Seq(null, false, true), ArrayType(BooleanType)),
Literal.create(Seq(1.1, null, 1.3, null), ArrayType(DoubleType)),
Literal.create(Seq(), ArrayType(NullType)),
Literal.create(Seq(null), ArrayType(NullType)),
Literal.create(Seq(192.toByte), ArrayType(ByteType)),
Literal.create(
Seq(Seq(1, 2, 3), null, Seq(4, 5), Seq(1, null, 3)), ArrayType(ArrayType(IntegerType))),
Literal.create(Seq(Array[Byte](1.toByte, 5.toByte)), ArrayType(BinaryType))
)
checkEvaluation(ArraysZip(Seq(literals(0), literals(1))),
List(Row(9001, null), Row(9002, 1L), Row(9003, null), Row(null, 4L), Row(null, 11L)))
checkEvaluation(ArraysZip(Seq(literals(0), literals(2))),
List(Row(9001, -1), Row(9002, -3), Row(9003, 900), Row(null, null)))
checkEvaluation(ArraysZip(Seq(literals(0), literals(3))),
List(Row(9001, "a"), Row(9002, null), Row(9003, "c"), Row(null, null)))
checkEvaluation(ArraysZip(Seq(literals(0), literals(4))),
List(Row(9001, null), Row(9002, false), Row(9003, true), Row(null, null)))
checkEvaluation(ArraysZip(Seq(literals(0), literals(5))),
List(Row(9001, 1.1), Row(9002, null), Row(9003, 1.3), Row(null, null)))
checkEvaluation(ArraysZip(Seq(literals(0), literals(6))),
List(Row(9001, null), Row(9002, null), Row(9003, null), Row(null, null)))
checkEvaluation(ArraysZip(Seq(literals(0), literals(7))),
List(Row(9001, null), Row(9002, null), Row(9003, null), Row(null, null)))
checkEvaluation(ArraysZip(Seq(literals(0), literals(1), literals(2), literals(3))),
List(
Row(9001, null, -1, "a"),
Row(9002, 1L, -3, null),
Row(9003, null, 900, "c"),
Row(null, 4L, null, null),
Row(null, 11L, null, null)))
checkEvaluation(ArraysZip(Seq(literals(4), literals(5), literals(6), literals(7), literals(8))),
List(
Row(null, 1.1, null, null, 192.toByte),
Row(false, null, null, null, null),
Row(true, 1.3, null, null, null),
Row(null, null, null, null, null)))
checkEvaluation(ArraysZip(Seq(literals(9), literals(0))),
List(
Row(List(1, 2, 3), 9001),
Row(null, 9002),
Row(List(4, 5), 9003),
Row(List(1, null, 3), null)))
checkEvaluation(ArraysZip(Seq(literals(7), literals(10))),
List(Row(null, Array[Byte](1.toByte, 5.toByte))))
val longLiteral =
Literal.create((0 to 1000).toSeq, ArrayType(IntegerType))
checkEvaluation(ArraysZip(Seq(literals(0), longLiteral)),
List(Row(9001, 0), Row(9002, 1), Row(9003, 2)) ++
(3 to 1000).map { Row(null, _) }.toList)
val manyLiterals = (0 to 1000).map { _ =>
Literal.create(Seq(1), ArrayType(IntegerType))
}.toSeq
val numbers = List(
Row(Seq(9001) ++ (0 to 1000).map { _ => 1 }.toSeq: _*),
Row(Seq(9002) ++ (0 to 1000).map { _ => null }.toSeq: _*),
Row(Seq(9003) ++ (0 to 1000).map { _ => null }.toSeq: _*),
Row(Seq(null) ++ (0 to 1000).map { _ => null }.toSeq: _*))
checkEvaluation(ArraysZip(Seq(literals(0)) ++ manyLiterals),
List(numbers(0), numbers(1), numbers(2), numbers(3)))
checkEvaluation(ArraysZip(Seq(literals(0), Literal.create(null, ArrayType(IntegerType)))), null)
checkEvaluation(ArraysZip(Seq()), List())
}
test("Array Min") {
checkEvaluation(ArrayMin(Literal.create(Seq(-11, 10, 2), ArrayType(IntegerType))), -11)
checkEvaluation(

View file

@ -3508,6 +3508,14 @@ object functions {
*/
def map_entries(e: Column): Column = withExpr { MapEntries(e.expr) }
/**
* Returns a merged array of structs in which the N-th struct contains all N-th values of input
* arrays.
* @group collection_funcs
* @since 2.4.0
*/
def arrays_zip(e: Column*): Column = withExpr { ArraysZip(e.map(_.expr)) }
//////////////////////////////////////////////////////////////////////////////////////////////
// Mask functions
//////////////////////////////////////////////////////////////////////////////////////////////

View file

@ -479,6 +479,53 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
)
}
test("dataframe arrays_zip function") {
val df1 = Seq((Seq(9001, 9002, 9003), Seq(4, 5, 6))).toDF("val1", "val2")
val df2 = Seq((Seq("a", "b"), Seq(true, false), Seq(10, 11))).toDF("val1", "val2", "val3")
val df3 = Seq((Seq("a", "b"), Seq(4, 5, 6))).toDF("val1", "val2")
val df4 = Seq((Seq("a", "b", null), Seq(4L))).toDF("val1", "val2")
val df5 = Seq((Seq(-1), Seq(null), Seq(), Seq(null, null))).toDF("val1", "val2", "val3", "val4")
val df6 = Seq((Seq(192.toByte, 256.toByte), Seq(1.1), Seq(), Seq(null, null)))
.toDF("v1", "v2", "v3", "v4")
val df7 = Seq((Seq(Seq(1, 2, 3), Seq(4, 5)), Seq(1.1, 2.2))).toDF("v1", "v2")
val df8 = Seq((Seq(Array[Byte](1.toByte, 5.toByte)), Seq(null))).toDF("v1", "v2")
val expectedValue1 = Row(Seq(Row(9001, 4), Row(9002, 5), Row(9003, 6)))
checkAnswer(df1.select(arrays_zip($"val1", $"val2")), expectedValue1)
checkAnswer(df1.selectExpr("arrays_zip(val1, val2)"), expectedValue1)
val expectedValue2 = Row(Seq(Row("a", true, 10), Row("b", false, 11)))
checkAnswer(df2.select(arrays_zip($"val1", $"val2", $"val3")), expectedValue2)
checkAnswer(df2.selectExpr("arrays_zip(val1, val2, val3)"), expectedValue2)
val expectedValue3 = Row(Seq(Row("a", 4), Row("b", 5), Row(null, 6)))
checkAnswer(df3.select(arrays_zip($"val1", $"val2")), expectedValue3)
checkAnswer(df3.selectExpr("arrays_zip(val1, val2)"), expectedValue3)
val expectedValue4 = Row(Seq(Row("a", 4L), Row("b", null), Row(null, null)))
checkAnswer(df4.select(arrays_zip($"val1", $"val2")), expectedValue4)
checkAnswer(df4.selectExpr("arrays_zip(val1, val2)"), expectedValue4)
val expectedValue5 = Row(Seq(Row(-1, null, null, null), Row(null, null, null, null)))
checkAnswer(df5.select(arrays_zip($"val1", $"val2", $"val3", $"val4")), expectedValue5)
checkAnswer(df5.selectExpr("arrays_zip(val1, val2, val3, val4)"), expectedValue5)
val expectedValue6 = Row(Seq(
Row(192.toByte, 1.1, null, null), Row(256.toByte, null, null, null)))
checkAnswer(df6.select(arrays_zip($"v1", $"v2", $"v3", $"v4")), expectedValue6)
checkAnswer(df6.selectExpr("arrays_zip(v1, v2, v3, v4)"), expectedValue6)
val expectedValue7 = Row(Seq(
Row(Seq(1, 2, 3), 1.1), Row(Seq(4, 5), 2.2)))
checkAnswer(df7.select(arrays_zip($"v1", $"v2")), expectedValue7)
checkAnswer(df7.selectExpr("arrays_zip(v1, v2)"), expectedValue7)
val expectedValue8 = Row(Seq(
Row(Array[Byte](1.toByte, 5.toByte), null)))
checkAnswer(df8.select(arrays_zip($"v1", $"v2")), expectedValue8)
checkAnswer(df8.selectExpr("arrays_zip(v1, v2)"), expectedValue8)
}
test("map size function") {
val df = Seq(
(Map[Int, Int](1 -> 1, 2 -> 2), "x"),