[SPARK-23914][SQL] Add array_union function

## What changes were proposed in this pull request?

This PR adds the SQL function `array_union`. Its behavior is modeled on Presto's function of the same name.

This function returns an array of the elements in the union of array1 and array2, without duplicates.

Note: The order of elements in the result is not defined.

## How was this patch tested?

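Added unit tests (UTs) to `CollectionExpressionsSuite` and `DataFrameFunctionsSuite`, plus a doctest for the Python function.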

Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com>

Closes #21061 from kiszk/SPARK-23914.
This commit is contained in:
Kazuaki Ishizaki 2018-07-12 17:42:29 +09:00 committed by Takuya UESHIN
parent 5ad4735bda
commit 301bff7063
7 changed files with 499 additions and 3 deletions

View file

@ -2033,6 +2033,25 @@ def array_distinct(col):
return Column(sc._jvm.functions.array_distinct(_to_java_column(col)))
@ignore_unicode_prefix
@since(2.4)
def array_union(col1, col2):
    """
    Collection function: builds the union of two array columns, i.e. an array holding the
    elements found in col1 or col2, without duplicates.

    :param col1: name of column containing array
    :param col2: name of column containing array

    >>> from pyspark.sql import Row
    >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])])
    >>> df.select(array_union(df.c1, df.c2)).collect()
    [Row(array_union(c1, c2)=[u'b', u'a', u'c', u'd', u'f'])]
    """
    # Resolve the JVM-side `functions` object through the active SparkContext.
    jvm_functions = SparkContext._active_spark_context._jvm.functions
    # Convert both arguments to Java columns, left operand first.
    java_columns = [_to_java_column(c) for c in (col1, col2)]
    return Column(jvm_functions.array_union(*java_columns))
@since(1.4)
def explode(col):
"""Returns a new row for each element in the given array or map.

View file

@ -450,7 +450,7 @@ public final class UnsafeArrayData extends ArrayData {
return values;
}
private static UnsafeArrayData fromPrimitiveArray(
public static UnsafeArrayData fromPrimitiveArray(
Object arr, int offset, int length, int elementSize) {
final long headerInBytes = calculateHeaderPortionInBytes(length);
final long valueRegionInBytes = (long)elementSize * length;
@ -463,14 +463,27 @@ public final class UnsafeArrayData extends ArrayData {
final long[] data = new long[(int)totalSizeInLongs];
Platform.putLong(data, Platform.LONG_ARRAY_OFFSET, length);
if (arr != null) {
Platform.copyMemory(arr, offset, data,
Platform.LONG_ARRAY_OFFSET + headerInBytes, valueRegionInBytes);
}
UnsafeArrayData result = new UnsafeArrayData();
result.pointTo(data, Platform.LONG_ARRAY_OFFSET, (int)totalSizeInLongs * 8);
return result;
}
/**
 * Creates an {@code UnsafeArrayData} for {@code length} elements of {@code elementSize}
 * bytes each without supplying source values: it delegates to {@code fromPrimitiveArray}
 * with a {@code null} source array.
 */
public static UnsafeArrayData forPrimitiveArray(int offset, int length, int elementSize) {
  final Object noSourceArray = null;
  return fromPrimitiveArray(noSourceArray, offset, length, elementSize);
}
/**
 * Tells whether an array of {@code length} elements, each {@code elementSize} bytes wide,
 * is too large for this class: the header plus the value region, rounded up to whole
 * 8-byte words, would exceed {@code Integer.MAX_VALUE / 8} words.
 *
 * @return {@code true} when the computed size is over that bound
 */
public static boolean shouldUseGenericArrayData(int elementSize, int length) {
  final long headerBytes = calculateHeaderPortionInBytes(length);
  // Widen before multiplying so that the product cannot overflow an int.
  final long valueBytes = (long) elementSize * length;
  final long totalBytes = headerBytes + valueBytes;
  final long totalWords = (totalBytes + 7) / 8;
  return totalWords > Integer.MAX_VALUE / 8;
}
/**
 * Builds an {@code UnsafeArrayData} from a {@code boolean[]}, treating every element as
 * one byte wide.
 */
public static UnsafeArrayData fromPrimitiveArray(boolean[] arr) {
  final int valuesOffset = Platform.BOOLEAN_ARRAY_OFFSET;
  final int elementCount = arr.length;
  return fromPrimitiveArray(arr, valuesOffset, elementCount, 1);
}

View file

@ -414,6 +414,7 @@ object FunctionRegistry {
expression[ArrayJoin]("array_join"),
expression[ArrayPosition]("array_position"),
expression[ArraySort]("array_sort"),
expression[ArrayUnion]("array_union"),
expression[CreateMap]("map"),
expression[CreateNamedStruct]("named_struct"),
expression[ElementAt]("element_at"),

View file

@ -3486,3 +3486,322 @@ case class ArrayDistinct(child: Expression)
override def prettyName: String = "array_distinct"
}
/**
 * Shared foundation for the binary array "set" expressions: [[ArrayUnion]] today, with
 * ArrayIntersect and ArrayExcept expected to follow.
 */
abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast {

  /** An array of `elementType` that may contain nulls as soon as one input array may. */
  override def dataType: DataType = {
    val inputArrayTypes = children.map(_.dataType.asInstanceOf[ArrayType])
    val resultContainsNull = inputArrayTypes.exists(_.containsNull)
    ArrayType(elementType, resultContainsNull)
  }

  /** Runs the inherited check first and, if it passes, the ordering check on the element type. */
  override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match {
    case inherited if inherited.isSuccess =>
      val resultElementType = dataType.asInstanceOf[ArrayType].elementType
      TypeUtils.checkForOrderingExpr(resultElementType, s"function $prettyName")
    case failure =>
      failure
  }

  /** Interpreted ordering over `elementType`, built lazily and never serialized. */
  @transient protected lazy val ordering: Ordering[Any] =
    TypeUtils.getInterpretedOrdering(elementType)

  /** True for every atomic element type except `BinaryType`; false for everything else. */
  @transient protected lazy val elementTypeSupportEquals = elementType match {
    case _: AtomicType => elementType != BinaryType
    case _ => false
  }
}
object ArraySetLike {
  /**
   * Always throws a `RuntimeException` whose message reports `length` together with
   * `ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH` as the array size limit.
   *
   * @param length the element count to report in the message
   */
  def throwUnionLengthOverflowException(length: Int): Unit = {
    val sizeLimit = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
    val message = s"Unsuccessful try to union arrays with $length elements due to " +
      s"exceeding the array size limit $sizeLimit."
    throw new RuntimeException(message)
  }
}
/**
* Returns an array of the elements in the union of x and y, without duplicates.
*
* The evaluation paths in this class scan `left` first and then `right`, keep a value the first
* time it is seen, and keep at most one null element.
*/
@ExpressionDescription(
usage = """
_FUNC_(array1, array2) - Returns an array of the elements in the union of array1 and array2,
without duplicates.
""",
examples = """
Examples:
> SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5));
array(1, 2, 3, 5)
""",
since = "2.4.0")
case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike {
// Scratch sets holding the int / long values already seen. nullSafeEval assigns a fresh set
// before each call of evalIntLongPrimitiveType.
// NOTE(review): these are mutable, non-transient fields of the expression, so the set of the last
// evaluation stays reachable afterwards - confirm that this is acceptable.
var hsInt: OpenHashSet[Int] = _
var hsLong: OpenHashSet[Long] = _
/**
* Reads the int at `idx` of `array`; if it is not yet in `hsInt`, adds it to the set and, when
* `resultArray` is non-null, stores it at `pos` of `resultArray`.
*
* @return true if the value was new (so the caller has to advance `pos`), false otherwise
*/
def assignInt(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = {
val elem = array.getInt(idx)
if (!hsInt.contains(elem)) {
if (resultArray != null) {
resultArray.setInt(pos, elem)
}
hsInt.add(elem)
true
} else {
false
}
}
/**
* Long counterpart of `assignInt`, backed by `hsLong`.
*
* @return true if the value was new (so the caller has to advance `pos`), false otherwise
*/
def assignLong(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = {
val elem = array.getLong(idx)
if (!hsLong.contains(elem)) {
if (resultArray != null) {
resultArray.setLong(pos, elem)
}
hsLong.add(elem)
true
} else {
false
}
}
/**
* Scans `array1` and then `array2`, de-duplicating through `hsInt` (or `hsLong` when
* `isLongType` is true) and keeping only the first null element.
*
* When `resultArray` is null nothing is stored and the call only counts; otherwise every kept
* value (or null) is written to `resultArray` at consecutive positions starting at 0.
*
* @return the number of positions used, i.e. the number of elements in the union
*/
def evalIntLongPrimitiveType(
array1: ArrayData,
array2: ArrayData,
resultArray: ArrayData,
isLongType: Boolean): Int = {
// store elements into resultArray
// 0 until the first null element has been emitted, 1 afterwards
var nullElementSize = 0
var pos = 0
Seq(array1, array2).foreach { array =>
var i = 0
while (i < array.numElements()) {
// Guard evaluated before every element: fail once the elements gathered so far are over the
// limit. NOTE(review): the exception is given `size`, which does not include the null slot.
val size = if (!isLongType) hsInt.size else hsLong.size
if (size + nullElementSize > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
ArraySetLike.throwUnionLengthOverflowException(size)
}
if (array.isNullAt(i)) {
// only the first null element takes a position in the result
if (nullElementSize == 0) {
if (resultArray != null) {
resultArray.setNullAt(pos)
}
pos += 1
nullElementSize = 1
}
} else {
val assigned = if (!isLongType) {
assignInt(array, i, resultArray, pos)
} else {
assignLong(array, i, resultArray, pos)
}
if (assigned) {
pos += 1
}
}
i += 1
}
}
pos
}
/**
* Interpreted evaluation of the union of two non-null arrays. Three strategies are used:
*  - IntegerType / LongType elements: a counting pass followed by a filling pass over
*    `evalIntLongPrimitiveType`, which avoids boxing;
*  - any other element type with `elementTypeSupportEquals`: a single pass over an
*    `OpenHashSet[Any]` and an `ArrayBuffer`;
*  - otherwise: `ArrayUnion.unionOrdering`, which compares elements with `ordering`.
*/
override def nullSafeEval(input1: Any, input2: Any): Any = {
val array1 = input1.asInstanceOf[ArrayData]
val array2 = input2.asInstanceOf[ArrayData]
if (elementTypeSupportEquals) {
elementType match {
case IntegerType =>
// avoid boxing of primitive int array elements
// calculate result array size
hsInt = new OpenHashSet[Int]
val elements = evalIntLongPrimitiveType(array1, array2, null, false)
// start from an empty set again so that the filling pass treats every value as new
hsInt = new OpenHashSet[Int]
// use GenericArrayData when UnsafeArrayData reports that the result would be too large for it
val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData(
IntegerType.defaultSize, elements)) {
new GenericArrayData(new Array[Any](elements))
} else {
UnsafeArrayData.forPrimitiveArray(
Platform.INT_ARRAY_OFFSET, elements, IntegerType.defaultSize)
}
evalIntLongPrimitiveType(array1, array2, resultArray, false)
resultArray
case LongType =>
// avoid boxing of primitive long array elements
// calculate result array size
hsLong = new OpenHashSet[Long]
val elements = evalIntLongPrimitiveType(array1, array2, null, true)
// start from an empty set again so that the filling pass treats every value as new
hsLong = new OpenHashSet[Long]
val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData(
LongType.defaultSize, elements)) {
new GenericArrayData(new Array[Any](elements))
} else {
UnsafeArrayData.forPrimitiveArray(
Platform.LONG_ARRAY_OFFSET, elements, LongType.defaultSize)
}
evalIntLongPrimitiveType(array1, array2, resultArray, true)
resultArray
case _ =>
// single pass: `hs` remembers the non-null values, `arrayBuffer` collects the result in order
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
val hs = new OpenHashSet[Any]
var foundNullElement = false
Seq(array1, array2).foreach { array =>
var i = 0
while (i < array.numElements()) {
if (array.isNullAt(i)) {
if (!foundNullElement) {
arrayBuffer += null
foundNullElement = true
}
} else {
val elem = array.get(i, elementType)
if (!hs.contains(elem)) {
// NOTE(review): the limit is tested before the append and only on this non-null branch, so the
// buffer can end up one element above the limit - confirm that this is intended.
if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.size)
}
arrayBuffer += elem
hs.add(elem)
}
}
i += 1
}
}
new GenericArrayData(arrayBuffer)
}
} else {
ArrayUnion.unionOrdering(array1, array2, elementType, ordering)
}
}
/**
* Generates the Java code for the union. With `elementTypeSupportEquals` the generated code runs
* two loops over both inputs around an `OpenHashSet` (first to compute the result size, then to
* fill the result); otherwise it calls `ArrayUnion.unionOrdering`.
*/
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
// fresh names for the Java locals used by the generated code
val i = ctx.freshName("i")
val pos = ctx.freshName("pos")
val value = ctx.freshName("value")
val size = ctx.freshName("size")
// Pieces of the generated code that depend on the element type, in this order: suffix appended
// to the OpenHashSet constructor and `add` call, name of the ClassTag factory method, getter
// call on the input array, setter call on the result array, Java type of one element, cast put
// in front of the value passed to `contains`, and the code that allocates the result array.
// An empty ClassTag name marks the ordering-based fallback.
val (postFix, openHashElementType, getter, setter, javaTypeName, castOp, arrayBuilder) =
if (elementTypeSupportEquals) {
elementType match {
// byte, short and int share the Int-specialized set; long uses the Long-specialized one
case ByteType | ShortType | IntegerType | LongType =>
val ptName = CodeGenerator.primitiveTypeName(elementType)
val unsafeArray = ctx.freshName("unsafeArray")
(if (elementType == LongType) s"$$mcJ$$sp" else s"$$mcI$$sp",
if (elementType == LongType) "Long" else "Int",
s"get$ptName($i)", s"set$ptName($pos, $value)", CodeGenerator.javaType(elementType),
if (elementType == LongType) "(long)" else "(int)",
s"""
|${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")}
|${ev.value} = $unsafeArray;
""".stripMargin)
// every other type that supports equals is handled as Object in a GenericArrayData
case _ =>
val genericArrayData = classOf[GenericArrayData].getName
val et = ctx.addReferenceObj("elementType", elementType)
("", "Object",
s"get($i, $et)", s"update($pos, $value)", "Object", "",
s"${ev.value} = new $genericArrayData(new Object[$size]);")
}
} else {
("", "", "", "", "", "", "")
}
nullSafeCodeGen(ctx, ev, (array1, array2) => {
if (openHashElementType != "") {
// Here, we ensure elementTypeSupportEquals is true
val foundNullElement = ctx.freshName("foundNullElement")
val openHashSet = classOf[OpenHashSet[_]].getName
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()"
val hs = ctx.freshName("hs")
val arrayData = classOf[ArrayData].getName
val arrays = ctx.freshName("arrays")
val array = ctx.freshName("array")
val arrayDataIdx = ctx.freshName("arrayDataIdx")
// Generated code: the first loop only fills the set and notes whether a null was seen, which
// gives the result size; the result array is then allocated, the set and the null flag are
// reset, and the second loop writes each first occurrence at the next position.
s"""
|$openHashSet $hs = new $openHashSet$postFix($classTag);
|boolean $foundNullElement = false;
|$arrayData[] $arrays = new $arrayData[]{$array1, $array2};
|for (int $arrayDataIdx = 0; $arrayDataIdx < 2; $arrayDataIdx++) {
|  $arrayData $array = $arrays[$arrayDataIdx];
|  for (int $i = 0; $i < $array.numElements(); $i++) {
|    if ($array.isNullAt($i)) {
|      $foundNullElement = true;
|    } else {
|      $hs.add$postFix($array.$getter);
|    }
|  }
|}
|int $size = $hs.size() + ($foundNullElement ? 1 : 0);
|$arrayBuilder
|$hs = new $openHashSet$postFix($classTag);
|$foundNullElement = false;
|int $pos = 0;
|for (int $arrayDataIdx = 0; $arrayDataIdx < 2; $arrayDataIdx++) {
|  $arrayData $array = $arrays[$arrayDataIdx];
|  for (int $i = 0; $i < $array.numElements(); $i++) {
|    if ($array.isNullAt($i)) {
|      if (!$foundNullElement) {
|        ${ev.value}.setNullAt($pos++);
|        $foundNullElement = true;
|      }
|    } else {
|      $javaTypeName $value = $array.$getter;
|      if (!$hs.contains($castOp $value)) {
|        $hs.add$postFix($value);
|        ${ev.value}.$setter;
|        $pos++;
|      }
|    }
|  }
|}
""".stripMargin
} else {
// Fallback: the generated code calls unionOrdering on the ArrayUnion companion object (reached
// through MODULE$), passing the element type and the ordering as reference objects.
val arrayUnion = classOf[ArrayUnion].getName
val et = ctx.addReferenceObj("elementTypeUnion", elementType)
val order = ctx.addReferenceObj("orderingUnion", ordering)
val method = "unionOrdering"
s"${ev.value} = $arrayUnion$$.MODULE$$.$method($array1, $array2, $et, $order);"
}
})
}
override def prettyName: String = "array_union"
}
object ArrayUnion {
/**
* Union of two arrays for element types that are compared through an `Ordering` instead of a
* hash set. The elements of `array1` and then of `array2` are visited in order; an element is
* appended to the result unless an equivalent one is already there. At most one null is kept.
*
* Each non-null element is compared against the result built so far, so the work grows with
* the product of the input size and the result size.
*
* @param array1 first input array
* @param array2 second input array
* @param elementType type used to read the elements of both arrays
* @param ordering ordering whose `equiv` decides whether two non-null elements are duplicates
* @return a `GenericArrayData` holding the kept elements in first-seen order
*/
def unionOrdering(
array1: ArrayData,
array2: ArrayData,
elementType: DataType,
ordering: Ordering[Any]): ArrayData = {
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
var alreadyIncludeNull = false
Seq(array1, array2).foreach(_.foreach(elementType, (_, elem) => {
// `found` ends up true when elem must not be appended
var found = false
if (elem == null) {
// the first null is appended below, later nulls are treated as duplicates
if (alreadyIncludeNull) {
found = true
} else {
alreadyIncludeNull = true
}
} else {
// check elem is already stored in arrayBuffer or not?
var j = 0
while (!found && j < arrayBuffer.size) {
val va = arrayBuffer(j)
// the stored null (if any) is skipped so that it is never handed to the ordering
if (va != null && ordering.equiv(va, elem)) {
found = true
}
j = j + 1
}
}
if (!found) {
// NOTE(review): the limit is tested before the append, so the buffer can end up one element
// above the limit - confirm that this is intended.
if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.length)
}
arrayBuffer += elem
}
}))
new GenericArrayData(arrayBuffer)
}
}

View file

@ -1304,4 +1304,85 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(ArrayDistinct(c1), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)))
checkEvaluation(ArrayDistinct(c2), Seq[Seq[Int]](null, Seq[Int](2, 1)))
}
// Expression-level checks of ArrayUnion. The expected sequences list the elements of the left
// array first, followed by the new elements of the right array.
test("Array Union") {
// a00-a04: int arrays (a02 contains a null, a03 contains a duplicate, a04 is empty)
val a00 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
val a01 = Literal.create(Seq(4, 2), ArrayType(IntegerType, containsNull = false))
val a02 = Literal.create(Seq(1, 2, null, 4, 5), ArrayType(IntegerType, containsNull = true))
val a03 = Literal.create(Seq(-5, 4, -3, 2, 4), ArrayType(IntegerType, containsNull = false))
val a04 = Literal.create(Seq.empty[Int], ArrayType(IntegerType, containsNull = false))
// a05-a08: byte and short arrays
val a05 = Literal.create(Seq[Byte](1, 2, 3), ArrayType(ByteType, containsNull = false))
val a06 = Literal.create(Seq[Byte](4, 2), ArrayType(ByteType, containsNull = false))
val a07 = Literal.create(Seq[Short](1, 2, 3), ArrayType(ShortType, containsNull = false))
val a08 = Literal.create(Seq[Short](4, 2), ArrayType(ShortType, containsNull = false))
// a10-a14: long arrays mirroring the int fixtures
val a10 = Literal.create(Seq(1L, 2L, 3L), ArrayType(LongType, containsNull = false))
val a11 = Literal.create(Seq(4L, 2L), ArrayType(LongType, containsNull = false))
val a12 = Literal.create(Seq(1L, 2L, null, 4L, 5L), ArrayType(LongType, containsNull = true))
val a13 = Literal.create(Seq(-5L, 4L, -3L, 2L, -1L), ArrayType(LongType, containsNull = false))
val a14 = Literal.create(Seq.empty[Long], ArrayType(LongType, containsNull = false))
// a20-a22: string arrays
val a20 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType, containsNull = false))
val a21 = Literal.create(Seq("c", "d", "a", "f"), ArrayType(StringType, containsNull = false))
val a22 = Literal.create(Seq("b", null, "a", "g"), ArrayType(StringType, containsNull = true))
// a30: array holding only nulls; a31: a null array value
val a30 = Literal.create(Seq(null, null), ArrayType(IntegerType))
val a31 = Literal.create(null, ArrayType(StringType))
checkEvaluation(ArrayUnion(a00, a01), Seq(1, 2, 3, 4))
checkEvaluation(ArrayUnion(a02, a03), Seq(1, 2, null, 4, 5, -5, -3))
checkEvaluation(ArrayUnion(a03, a02), Seq(-5, 4, -3, 2, 1, null, 5))
checkEvaluation(ArrayUnion(a02, a04), Seq(1, 2, null, 4, 5))
checkEvaluation(ArrayUnion(a05, a06), Seq[Byte](1, 2, 3, 4))
checkEvaluation(ArrayUnion(a07, a08), Seq[Short](1, 2, 3, 4))
checkEvaluation(ArrayUnion(a10, a11), Seq(1L, 2L, 3L, 4L))
checkEvaluation(ArrayUnion(a12, a13), Seq(1L, 2L, null, 4L, 5L, -5L, -3L, -1L))
checkEvaluation(ArrayUnion(a13, a12), Seq(-5L, 4L, -3L, 2L, -1L, 1L, null, 5L))
checkEvaluation(ArrayUnion(a12, a14), Seq(1L, 2L, null, 4L, 5L))
checkEvaluation(ArrayUnion(a20, a21), Seq("b", "a", "c", "d", "f"))
checkEvaluation(ArrayUnion(a20, a22), Seq("b", "a", "c", null, "g"))
// repeated nulls collapse into a single null element
checkEvaluation(ArrayUnion(a30, a30), Seq(null))
// a null array on either side is expected to give a null result
checkEvaluation(ArrayUnion(a20, a31), null)
checkEvaluation(ArrayUnion(a31, a20), null)
// Binary elements: ArraySetLike.elementTypeSupportEquals is false for BinaryType, so these cases
// go through the ordering-based implementation.
val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)),
ArrayType(BinaryType))
val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)),
ArrayType(BinaryType))
val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](4, 3)),
ArrayType(BinaryType))
val b3 = Literal.create(Seq[Array[Byte]](
Array[Byte](1, 2), Array[Byte](4, 3), Array[Byte](1, 2)), ArrayType(BinaryType))
val b4 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), null), ArrayType(BinaryType))
val b5 = Literal.create(Seq[Array[Byte]](null, Array[Byte](1, 2)), ArrayType(BinaryType))
val b6 = Literal.create(Seq.empty, ArrayType(BinaryType))
val arrayWithBinaryNull = Literal.create(Seq(null), ArrayType(BinaryType))
checkEvaluation(ArrayUnion(b0, b1),
Seq(Array[Byte](5, 6), Array[Byte](1, 2), Array[Byte](2, 1), Array[Byte](4, 3)))
checkEvaluation(ArrayUnion(b0, b2),
Seq(Array[Byte](5, 6), Array[Byte](1, 2), Array[Byte](4, 3)))
checkEvaluation(ArrayUnion(b2, b4), Seq(Array[Byte](1, 2), Array[Byte](4, 3), null))
checkEvaluation(ArrayUnion(b3, b0),
Seq(Array[Byte](1, 2), Array[Byte](4, 3), Array[Byte](5, 6)))
checkEvaluation(ArrayUnion(b4, b0), Seq(Array[Byte](1, 2), null, Array[Byte](5, 6)))
checkEvaluation(ArrayUnion(b4, b5), Seq(Array[Byte](1, 2), null))
checkEvaluation(ArrayUnion(b6, b4), Seq(Array[Byte](1, 2), null))
checkEvaluation(ArrayUnion(b4, arrayWithBinaryNull), Seq(Array[Byte](1, 2), null))
// Nested arrays as elements; Seq(1, 2) and Seq(2, 1) are expected to stay distinct.
val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)),
ArrayType(ArrayType(IntegerType)))
val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)),
ArrayType(ArrayType(IntegerType)))
checkEvaluation(ArrayUnion(aa0, aa1),
Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4), Seq[Int](5, 6), Seq[Int](2, 1)))
// containsNull of the result type is true as soon as one input type allows nulls
assert(ArrayUnion(a00, a01).dataType.asInstanceOf[ArrayType].containsNull === false)
assert(ArrayUnion(a00, a02).dataType.asInstanceOf[ArrayType].containsNull === true)
assert(ArrayUnion(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false)
assert(ArrayUnion(a20, a22).dataType.asInstanceOf[ArrayType].containsNull === true)
}
}

View file

@ -3204,6 +3204,7 @@ object functions {
/**
* Remove all elements that equal to element from the given array.
*
* @group collection_funcs
* @since 2.4.0
*/
@ -3218,6 +3219,16 @@ object functions {
*/
def array_distinct(e: Column): Column = withExpr { ArrayDistinct(e.expr) }
/**
 * Computes the union of the two given array columns: the result holds the elements found in
 * either array, without duplicates.
 *
 * @group collection_funcs
 * @since 2.4.0
 */
def array_union(col1: Column, col2: Column): Column =
  withExpr(ArrayUnion(col1.expr, col2.expr))
/**
* Creates a new row for each element in the given array or map column.
*

View file

@ -1198,6 +1198,58 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
"argument 1 requires (array or map) type, however, '`_1`' is of string type"))
}
// End-to-end checks of array_union; every case is run through both the Column function and the
// SQL expression string.
test("array_union functions") {
// int arrays without nulls
val df1 = Seq((Array(1, 2, 3), Array(4, 2))).toDF("a", "b")
val ans1 = Row(Seq(1, 2, 3, 4))
checkAnswer(df1.select(array_union($"a", $"b")), ans1)
checkAnswer(df1.selectExpr("array_union(a, b)"), ans1)
// boxed Integer array with a null element on the left, primitive int array on the right
val df2 = Seq((Array[Integer](1, 2, null, 4, 5), Array(-5, 4, -3, 2, -1))).toDF("a", "b")
val ans2 = Row(Seq(1, 2, null, 4, 5, -5, -3, -1))
checkAnswer(df2.select(array_union($"a", $"b")), ans2)
checkAnswer(df2.selectExpr("array_union(a, b)"), ans2)
// long arrays without nulls
val df3 = Seq((Array(1L, 2L, 3L), Array(4L, 2L))).toDF("a", "b")
val ans3 = Row(Seq(1L, 2L, 3L, 4L))
checkAnswer(df3.select(array_union($"a", $"b")), ans3)
checkAnswer(df3.selectExpr("array_union(a, b)"), ans3)
// boxed Long array with a null element on the left, primitive long array on the right
val df4 = Seq((Array[java.lang.Long](1L, 2L, null, 4L, 5L), Array(-5L, 4L, -3L, 2L, -1L)))
.toDF("a", "b")
val ans4 = Row(Seq(1L, 2L, null, 4L, 5L, -5L, -3L, -1L))
checkAnswer(df4.select(array_union($"a", $"b")), ans4)
checkAnswer(df4.selectExpr("array_union(a, b)"), ans4)
// string arrays, with a null element on the right
val df5 = Seq((Array("b", "a", "c"), Array("b", null, "a", "g"))).toDF("a", "b")
val ans5 = Row(Seq("b", "a", "c", null, "g"))
checkAnswer(df5.select(array_union($"a", $"b")), ans5)
checkAnswer(df5.selectExpr("array_union(a, b)"), ans5)
// The remaining cases are expected to be rejected with an AnalysisException.
// one column built from a bare null literal
val df6 = Seq((null, Array("a"))).toDF("a", "b")
intercept[AnalysisException] {
df6.select(array_union($"a", $"b"))
}
intercept[AnalysisException] {
df6.selectExpr("array_union(a, b)")
}
// both columns built from bare null literals
val df7 = Seq((null, null)).toDF("a", "b")
intercept[AnalysisException] {
df7.select(array_union($"a", $"b"))
}
intercept[AnalysisException] {
df7.selectExpr("array_union(a, b)")
}
// element types that do not match: array of int arrays versus array of strings
val df8 = Seq((Array(Array(1)), Array("a"))).toDF("a", "b")
intercept[AnalysisException] {
df8.select(array_union($"a", $"b"))
}
intercept[AnalysisException] {
df8.selectExpr("array_union(a, b)")
}
}
test("concat function - arrays") {
val nseqi : Seq[Int] = null
val nseqs : Seq[String] = null