[SPARK-23914][SQL] Add array_union function
## What changes were proposed in this pull request? The PR adds the SQL function `array_union`. The behavior of the function is based on Presto's one. This function returns an array of the elements in the union of array1 and array2. Note: The order of elements in the result is not defined. ## How was this patch tested? Added UTs Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com> Closes #21061 from kiszk/SPARK-23914.
This commit is contained in:
parent
5ad4735bda
commit
301bff7063
|
@ -2033,6 +2033,25 @@ def array_distinct(col):
|
|||
return Column(sc._jvm.functions.array_distinct(_to_java_column(col)))
|
||||
|
||||
|
||||
@ignore_unicode_prefix
@since(2.4)
def array_union(col1, col2):
    """
    Collection function: returns an array of the elements in the union of col1 and col2,
    without duplicates.

    :param col1: name of column containing array
    :param col2: name of column containing array

    >>> from pyspark.sql import Row
    >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])])
    >>> df.select(array_union(df.c1, df.c2)).collect()
    [Row(array_union(c1, c2)=[u'b', u'a', u'c', u'd', u'f'])]
    """
    # Resolve the JVM-side function holder once, then hand it both converted columns.
    jvm_functions = SparkContext._active_spark_context._jvm.functions
    left, right = _to_java_column(col1), _to_java_column(col2)
    return Column(jvm_functions.array_union(left, right))
|
||||
|
||||
|
||||
@since(1.4)
|
||||
def explode(col):
|
||||
"""Returns a new row for each element in the given array or map.
|
||||
|
|
|
@ -450,7 +450,7 @@ public final class UnsafeArrayData extends ArrayData {
|
|||
return values;
|
||||
}
|
||||
|
||||
private static UnsafeArrayData fromPrimitiveArray(
|
||||
public static UnsafeArrayData fromPrimitiveArray(
|
||||
Object arr, int offset, int length, int elementSize) {
|
||||
final long headerInBytes = calculateHeaderPortionInBytes(length);
|
||||
final long valueRegionInBytes = (long)elementSize * length;
|
||||
|
@ -463,14 +463,27 @@ public final class UnsafeArrayData extends ArrayData {
|
|||
final long[] data = new long[(int)totalSizeInLongs];
|
||||
|
||||
Platform.putLong(data, Platform.LONG_ARRAY_OFFSET, length);
|
||||
Platform.copyMemory(arr, offset, data,
|
||||
Platform.LONG_ARRAY_OFFSET + headerInBytes, valueRegionInBytes);
|
||||
if (arr != null) {
|
||||
Platform.copyMemory(arr, offset, data,
|
||||
Platform.LONG_ARRAY_OFFSET + headerInBytes, valueRegionInBytes);
|
||||
}
|
||||
|
||||
UnsafeArrayData result = new UnsafeArrayData();
|
||||
result.pointTo(data, Platform.LONG_ARRAY_OFFSET, (int)totalSizeInLongs * 8);
|
||||
return result;
|
||||
}
|
||||
|
||||
public static UnsafeArrayData forPrimitiveArray(int offset, int length, int elementSize) {
|
||||
return fromPrimitiveArray(null, offset, length, elementSize);
|
||||
}
|
||||
|
||||
public static boolean shouldUseGenericArrayData(int elementSize, int length) {
|
||||
final long headerInBytes = calculateHeaderPortionInBytes(length);
|
||||
final long valueRegionInBytes = (long)elementSize * length;
|
||||
final long totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8;
|
||||
return totalSizeInLongs > Integer.MAX_VALUE / 8;
|
||||
}
|
||||
|
||||
public static UnsafeArrayData fromPrimitiveArray(boolean[] arr) {
|
||||
return fromPrimitiveArray(arr, Platform.BOOLEAN_ARRAY_OFFSET, arr.length, 1);
|
||||
}
|
||||
|
|
|
@ -414,6 +414,7 @@ object FunctionRegistry {
|
|||
expression[ArrayJoin]("array_join"),
|
||||
expression[ArrayPosition]("array_position"),
|
||||
expression[ArraySort]("array_sort"),
|
||||
expression[ArrayUnion]("array_union"),
|
||||
expression[CreateMap]("map"),
|
||||
expression[CreateNamedStruct]("named_struct"),
|
||||
expression[ElementAt]("element_at"),
|
||||
|
|
|
@ -3486,3 +3486,322 @@ case class ArrayDistinct(child: Expression)
|
|||
|
||||
override def prettyName: String = "array_distinct"
|
||||
}
|
||||
|
||||
/**
 * Will become common base class for [[ArrayUnion]], ArrayIntersect, and ArrayExcept.
 *
 * The result is an array of `elementType` whose `containsNull` flag is set when the array
 * type of any child has it set.
 */
abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast {
  override def dataType: DataType = {
    val nullFlags = children.map(_.dataType.asInstanceOf[ArrayType].containsNull)
    ArrayType(elementType, nullFlags.contains(true))
  }

  /** Runs the parent check first; on success also requires an orderable element type. */
  override def checkInputDataTypes(): TypeCheckResult = {
    val baseCheck = super.checkInputDataTypes()
    if (!baseCheck.isSuccess) {
      baseCheck
    } else {
      val resultElementType = dataType.asInstanceOf[ArrayType].elementType
      TypeUtils.checkForOrderingExpr(resultElementType, s"function $prettyName")
    }
  }

  @transient protected lazy val ordering: Ordering[Any] =
    TypeUtils.getInterpretedOrdering(elementType)

  // True for every atomic element type except binary; all other types report false.
  @transient protected lazy val elementTypeSupportEquals = elementType match {
    case atomic: AtomicType => atomic != BinaryType
    case _ => false
  }
}
|
||||
|
||||
object ArraySetLike {
  /**
   * Always throws a RuntimeException stating that a union of `length` elements exceeds
   * `ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH`.
   */
  def throwUnionLengthOverflowException(length: Int): Unit = {
    val limit = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
    val message = s"Unsuccessful try to union arrays with $length elements due to " +
      s"exceeding the array size limit $limit."
    throw new RuntimeException(message)
  }
}
|
||||
|
||||
|
||||
/**
|
||||
* Returns an array of the elements in the union of x and y, without duplicates
|
||||
*/
|
||||
@ExpressionDescription(
|
||||
usage = """
|
||||
_FUNC_(array1, array2) - Returns an array of the elements in the union of array1 and array2,
|
||||
without duplicates.
|
||||
""",
|
||||
examples = """
|
||||
Examples:
|
||||
> SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5));
|
||||
array(1, 2, 3, 5)
|
||||
""",
|
||||
since = "2.4.0")
|
||||
case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike {
|
||||
var hsInt: OpenHashSet[Int] = _
|
||||
var hsLong: OpenHashSet[Long] = _
|
||||
|
||||
/**
 * Reads the int at `idx` of `array` and, if `hsInt` has not seen it yet, records it in
 * `hsInt` and (when `resultArray` is non-null) writes it at `pos`.
 *
 * @return true when the value was new, false when it was a duplicate
 */
def assignInt(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = {
  val candidate = array.getInt(idx)
  val isNew = !hsInt.contains(candidate)
  if (isNew) {
    // A null resultArray means the caller is only counting distinct values.
    if (resultArray != null) {
      resultArray.setInt(pos, candidate)
    }
    hsInt.add(candidate)
  }
  isNew
}
|
||||
|
||||
/**
 * Reads the long at `idx` of `array` and, if `hsLong` has not seen it yet, records it in
 * `hsLong` and (when `resultArray` is non-null) writes it at `pos`.
 *
 * @return true when the value was new, false when it was a duplicate
 */
def assignLong(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = {
  val candidate = array.getLong(idx)
  val isNew = !hsLong.contains(candidate)
  if (isNew) {
    // A null resultArray means the caller is only counting distinct values.
    if (resultArray != null) {
      resultArray.setLong(pos, candidate)
    }
    hsLong.add(candidate)
  }
  isNew
}
|
||||
|
||||
/**
 * Walks `array1` then `array2`, keeping the first occurrence of every distinct value and at
 * most one null. When `resultArray` is non-null the kept entries are written into it in
 * order; when it is null the method only counts them.
 *
 * Uses `hsLong` when `isLongType` is true and `hsInt` otherwise.
 *
 * @return the number of entries kept (distinct non-null values plus at most one null)
 */
def evalIntLongPrimitiveType(
    array1: ArrayData,
    array2: ArrayData,
    resultArray: ArrayData,
    isLongType: Boolean): Int = {
  var nullCount = 0
  var writePos = 0
  for (input <- Seq(array1, array2)) {
    var idx = 0
    while (idx < input.numElements()) {
      val distinctSoFar = if (isLongType) hsLong.size else hsInt.size
      if (distinctSoFar + nullCount > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
        ArraySetLike.throwUnionLengthOverflowException(distinctSoFar)
      }
      if (!input.isNullAt(idx)) {
        val stored =
          if (isLongType) {
            assignLong(input, idx, resultArray, writePos)
          } else {
            assignInt(input, idx, resultArray, writePos)
          }
        if (stored) {
          writePos += 1
        }
      } else if (nullCount == 0) {
        // Only the first null across both inputs occupies a slot.
        if (resultArray != null) {
          resultArray.setNullAt(writePos)
        }
        writePos += 1
        nullCount = 1
      }
      idx += 1
    }
  }
  writePos
}
|
||||
|
||||
/**
 * Interpreted evaluation of the union of two non-null arrays.
 *
 * Int and long elements take a two-pass primitive path (count, then fill); other element
 * types that support equality are de-duplicated through a hash set; the remaining types are
 * delegated to `ArrayUnion.unionOrdering`.
 */
override def nullSafeEval(input1: Any, input2: Any): Any = {
  val array1 = input1.asInstanceOf[ArrayData]
  val array2 = input2.asInstanceOf[ArrayData]

  // Two passes avoid boxing: the first only counts the entries to size the result array,
  // the second (after resetting the hash set) writes them.
  def unionPrimitive(isLongType: Boolean): ArrayData = {
    def resetSeen(): Unit = {
      if (isLongType) hsLong = new OpenHashSet[Long] else hsInt = new OpenHashSet[Int]
    }
    val (elementSize, baseOffset) = if (isLongType) {
      (LongType.defaultSize, Platform.LONG_ARRAY_OFFSET)
    } else {
      (IntegerType.defaultSize, Platform.INT_ARRAY_OFFSET)
    }
    resetSeen()
    val elements = evalIntLongPrimitiveType(array1, array2, null, isLongType)
    resetSeen()
    val resultArray: ArrayData =
      if (UnsafeArrayData.shouldUseGenericArrayData(elementSize, elements)) {
        new GenericArrayData(new Array[Any](elements))
      } else {
        UnsafeArrayData.forPrimitiveArray(baseOffset, elements, elementSize)
      }
    evalIntLongPrimitiveType(array1, array2, resultArray, isLongType)
    resultArray
  }

  if (!elementTypeSupportEquals) {
    ArrayUnion.unionOrdering(array1, array2, elementType, ordering)
  } else {
    elementType match {
      case IntegerType =>
        unionPrimitive(isLongType = false)
      case LongType =>
        unionPrimitive(isLongType = true)
      case _ =>
        val distinct = new scala.collection.mutable.ArrayBuffer[Any]
        val seen = new OpenHashSet[Any]
        var nullEmitted = false
        for (input <- Seq(array1, array2)) {
          var idx = 0
          while (idx < input.numElements()) {
            if (!input.isNullAt(idx)) {
              val elem = input.get(idx, elementType)
              if (!seen.contains(elem)) {
                if (distinct.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
                  ArraySetLike.throwUnionLengthOverflowException(distinct.size)
                }
                distinct += elem
                seen.add(elem)
              }
            } else if (!nullEmitted) {
              // Keep only the first null seen across both inputs.
              distinct += null
              nullEmitted = true
            }
            idx += 1
          }
        }
        new GenericArrayData(distinct)
    }
  }
}
|
||||
|
||||
/**
 * Generates Java code for the union.
 *
 * For element types that support equality the generated code makes two passes over both
 * inputs with an OpenHashSet: the first computes the result size, the second fills the
 * freshly allocated result array. Byte/short/int/long use a specialized hash set and an
 * unsafe array; other such types use an Object hash set and a GenericArrayData. Element
 * types without equality support call `ArrayUnion.unionOrdering` at runtime.
 */
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
  val idx = ctx.freshName("i")
  val writePos = ctx.freshName("pos")
  val elem = ctx.freshName("value")
  val size = ctx.freshName("size")

  // (hash-set specialization suffix, ClassTag factory name, getter call, setter call,
  //  Java element type, cast used for contains(), result-array allocation code)
  val (postFix, openHashElementType, getter, setter, javaTypeName, castOp, arrayBuilder) =
    elementType match {
      case _ if !elementTypeSupportEquals =>
        ("", "", "", "", "", "", "")
      case ByteType | ShortType | IntegerType | LongType =>
        val ptName = CodeGenerator.primitiveTypeName(elementType)
        val unsafeArray = ctx.freshName("unsafeArray")
        val (suffix, tagName, cast) =
          if (elementType == LongType) {
            ("$mcJ$sp", "Long", "(long)")
          } else {
            ("$mcI$sp", "Int", "(int)")
          }
        val javaType = CodeGenerator.javaType(elementType)
        val createArray =
          ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")
        val allocation =
          s"""
             |$createArray
             |${ev.value} = $unsafeArray;
           """.stripMargin
        (suffix, tagName, s"get$ptName($idx)", s"set$ptName($writePos, $elem)",
          javaType, cast, allocation)
      case _ =>
        val genericArrayData = classOf[GenericArrayData].getName
        val et = ctx.addReferenceObj("elementType", elementType)
        ("", "Object", s"get($idx, $et)", s"update($writePos, $elem)", "Object", "",
          s"${ev.value} = new $genericArrayData(new Object[$size]);")
    }

  nullSafeCodeGen(ctx, ev, (array1, array2) => {
    if (openHashElementType.nonEmpty) {
      // Reaching this branch implies elementTypeSupportEquals is true.
      val nullFlag = ctx.freshName("foundNullElement")
      val openHashSet = classOf[OpenHashSet[_]].getName
      val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()"
      val seen = ctx.freshName("hs")
      val arrayData = classOf[ArrayData].getName
      val inputs = ctx.freshName("arrays")
      val current = ctx.freshName("array")
      val inputIdx = ctx.freshName("arrayDataIdx")
      s"""
         |$openHashSet $seen = new $openHashSet$postFix($classTag);
         |boolean $nullFlag = false;
         |$arrayData[] $inputs = new $arrayData[]{$array1, $array2};
         |for (int $inputIdx = 0; $inputIdx < 2; $inputIdx++) {
         | $arrayData $current = $inputs[$inputIdx];
         | for (int $idx = 0; $idx < $current.numElements(); $idx++) {
         | if ($current.isNullAt($idx)) {
         | $nullFlag = true;
         | } else {
         | $seen.add$postFix($current.$getter);
         | }
         | }
         |}
         |int $size = $seen.size() + ($nullFlag ? 1 : 0);
         |$arrayBuilder
         |$seen = new $openHashSet$postFix($classTag);
         |$nullFlag = false;
         |int $writePos = 0;
         |for (int $inputIdx = 0; $inputIdx < 2; $inputIdx++) {
         | $arrayData $current = $inputs[$inputIdx];
         | for (int $idx = 0; $idx < $current.numElements(); $idx++) {
         | if ($current.isNullAt($idx)) {
         | if (!$nullFlag) {
         | ${ev.value}.setNullAt($writePos++);
         | $nullFlag = true;
         | }
         | } else {
         | $javaTypeName $elem = $current.$getter;
         | if (!$seen.contains($castOp $elem)) {
         | $seen.add$postFix($elem);
         | ${ev.value}.$setter;
         | $writePos++;
         | }
         | }
         | }
         |}
       """.stripMargin
    } else {
      // No equality support: call the ordering-based helper on the companion object.
      val arrayUnion = classOf[ArrayUnion].getName
      val et = ctx.addReferenceObj("elementTypeUnion", elementType)
      val order = ctx.addReferenceObj("orderingUnion", ordering)
      s"${ev.value} = $arrayUnion$$.MODULE$$.unionOrdering($array1, $array2, $et, $order);"
    }
  })
}
|
||||
|
||||
override def prettyName: String = "array_union"
|
||||
}
|
||||
|
||||
object ArrayUnion {
  /**
   * Unions two arrays by pairwise comparison with `ordering`, for element types that cannot
   * be de-duplicated through a hash set.
   *
   * Elements keep first-occurrence order (all of `array1`, then `array2`); at most one null
   * is kept. Every candidate is compared against the elements kept so far.
   *
   * @param elementType element type used to read both arrays
   * @param ordering ordering whose `equiv` decides whether two non-null elements are equal
   * @return a GenericArrayData holding the distinct elements
   */
  def unionOrdering(
      array1: ArrayData,
      array2: ArrayData,
      elementType: DataType,
      ordering: Ordering[Any]): ArrayData = {
    val distinct = new scala.collection.mutable.ArrayBuffer[Any]
    var nullSeen = false
    for (input <- Seq(array1, array2)) {
      input.foreach(elementType, (_, elem) => {
        val duplicate =
          if (elem == null) {
            // The first null is kept; every later one counts as a duplicate.
            val seenBefore = nullSeen
            nullSeen = true
            seenBefore
          } else {
            distinct.exists(kept => kept != null && ordering.equiv(kept, elem))
          }
        if (!duplicate) {
          if (distinct.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
            ArraySetLike.throwUnionLengthOverflowException(distinct.length)
          }
          distinct += elem
        }
      })
    }
    new GenericArrayData(distinct)
  }
}
|
||||
|
|
|
@ -1304,4 +1304,85 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
|
|||
checkEvaluation(ArrayDistinct(c1), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)))
|
||||
checkEvaluation(ArrayDistinct(c2), Seq[Seq[Int]](null, Seq[Int](2, 1)))
|
||||
}
|
||||
|
||||
test("Array Union") {
  // Evaluates ArrayUnion for every (left, right, expected) triple.
  def checkUnions(cases: Seq[(Literal, Literal, Any)]): Unit = {
    cases.foreach { case (left, right, expected) =>
      checkEvaluation(ArrayUnion(left, right), expected)
    }
  }

  val a00 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
  val a01 = Literal.create(Seq(4, 2), ArrayType(IntegerType, containsNull = false))
  val a02 = Literal.create(Seq(1, 2, null, 4, 5), ArrayType(IntegerType, containsNull = true))
  val a03 = Literal.create(Seq(-5, 4, -3, 2, 4), ArrayType(IntegerType, containsNull = false))
  val a04 = Literal.create(Seq.empty[Int], ArrayType(IntegerType, containsNull = false))
  val a05 = Literal.create(Seq[Byte](1, 2, 3), ArrayType(ByteType, containsNull = false))
  val a06 = Literal.create(Seq[Byte](4, 2), ArrayType(ByteType, containsNull = false))
  val a07 = Literal.create(Seq[Short](1, 2, 3), ArrayType(ShortType, containsNull = false))
  val a08 = Literal.create(Seq[Short](4, 2), ArrayType(ShortType, containsNull = false))

  val a10 = Literal.create(Seq(1L, 2L, 3L), ArrayType(LongType, containsNull = false))
  val a11 = Literal.create(Seq(4L, 2L), ArrayType(LongType, containsNull = false))
  val a12 = Literal.create(Seq(1L, 2L, null, 4L, 5L), ArrayType(LongType, containsNull = true))
  val a13 = Literal.create(Seq(-5L, 4L, -3L, 2L, -1L), ArrayType(LongType, containsNull = false))
  val a14 = Literal.create(Seq.empty[Long], ArrayType(LongType, containsNull = false))

  val a20 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType, containsNull = false))
  val a21 = Literal.create(Seq("c", "d", "a", "f"), ArrayType(StringType, containsNull = false))
  val a22 = Literal.create(Seq("b", null, "a", "g"), ArrayType(StringType, containsNull = true))

  val a30 = Literal.create(Seq(null, null), ArrayType(IntegerType))
  val a31 = Literal.create(null, ArrayType(StringType))

  // Integral and string element types, including nulls and null array inputs.
  checkUnions(Seq(
    (a00, a01, Seq(1, 2, 3, 4)),
    (a02, a03, Seq(1, 2, null, 4, 5, -5, -3)),
    (a03, a02, Seq(-5, 4, -3, 2, 1, null, 5)),
    (a02, a04, Seq(1, 2, null, 4, 5)),
    (a05, a06, Seq[Byte](1, 2, 3, 4)),
    (a07, a08, Seq[Short](1, 2, 3, 4)),

    (a10, a11, Seq(1L, 2L, 3L, 4L)),
    (a12, a13, Seq(1L, 2L, null, 4L, 5L, -5L, -3L, -1L)),
    (a13, a12, Seq(-5L, 4L, -3L, 2L, -1L, 1L, null, 5L)),
    (a12, a14, Seq(1L, 2L, null, 4L, 5L)),

    (a20, a21, Seq("b", "a", "c", "d", "f")),
    (a20, a22, Seq("b", "a", "c", null, "g")),

    (a30, a30, Seq(null)),
    (a20, a31, null),
    (a31, a20, null)))

  val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)),
    ArrayType(BinaryType))
  val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)),
    ArrayType(BinaryType))
  val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](4, 3)),
    ArrayType(BinaryType))
  val b3 = Literal.create(Seq[Array[Byte]](
    Array[Byte](1, 2), Array[Byte](4, 3), Array[Byte](1, 2)), ArrayType(BinaryType))
  val b4 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), null), ArrayType(BinaryType))
  val b5 = Literal.create(Seq[Array[Byte]](null, Array[Byte](1, 2)), ArrayType(BinaryType))
  val b6 = Literal.create(Seq.empty, ArrayType(BinaryType))
  val arrayWithBinaryNull = Literal.create(Seq(null), ArrayType(BinaryType))

  // Binary elements.
  checkUnions(Seq(
    (b0, b1,
      Seq(Array[Byte](5, 6), Array[Byte](1, 2), Array[Byte](2, 1), Array[Byte](4, 3))),
    (b0, b2, Seq(Array[Byte](5, 6), Array[Byte](1, 2), Array[Byte](4, 3))),
    (b2, b4, Seq(Array[Byte](1, 2), Array[Byte](4, 3), null)),
    (b3, b0, Seq(Array[Byte](1, 2), Array[Byte](4, 3), Array[Byte](5, 6))),
    (b4, b0, Seq(Array[Byte](1, 2), null, Array[Byte](5, 6))),
    (b4, b5, Seq(Array[Byte](1, 2), null)),
    (b6, b4, Seq(Array[Byte](1, 2), null)),
    (b4, arrayWithBinaryNull, Seq(Array[Byte](1, 2), null))))

  // Nested arrays.
  val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)),
    ArrayType(ArrayType(IntegerType)))
  val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)),
    ArrayType(ArrayType(IntegerType)))
  checkUnions(Seq(
    (aa0, aa1,
      Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4), Seq[Int](5, 6), Seq[Int](2, 1)))))

  // containsNull of the result type follows the inputs.
  Seq(
    (a00, a01, false),
    (a00, a02, true),
    (a20, a21, false),
    (a20, a22, true)
  ).foreach { case (left, right, expectedContainsNull) =>
    assert(ArrayUnion(left, right).dataType.asInstanceOf[ArrayType].containsNull ===
      expectedContainsNull)
  }
}
|
||||
}
|
||||
|
|
|
@ -3204,6 +3204,7 @@ object functions {
|
|||
|
||||
/**
|
||||
* Remove all elements that equal to element from the given array.
|
||||
*
|
||||
* @group collection_funcs
|
||||
* @since 2.4.0
|
||||
*/
|
||||
|
@ -3218,6 +3219,16 @@ object functions {
|
|||
*/
|
||||
def array_distinct(e: Column): Column = withExpr { ArrayDistinct(e.expr) }
|
||||
|
||||
/**
 * Returns an array of the elements in the union of the given two arrays, without duplicates.
 *
 * @param col1 first array column
 * @param col2 second array column
 *
 * @group collection_funcs
 * @since 2.4.0
 */
def array_union(col1: Column, col2: Column): Column =
  withExpr(ArrayUnion(col1.expr, col2.expr))
|
||||
|
||||
/**
|
||||
* Creates a new row for each element in the given array or map column.
|
||||
*
|
||||
|
|
|
@ -1198,6 +1198,58 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
|
|||
"argument 1 requires (array or map) type, however, '`_1`' is of string type"))
|
||||
}
|
||||
|
||||
test("array_union functions") {
  // Inputs that must produce a result, paired with the expected single row.
  val df1 = Seq((Array(1, 2, 3), Array(4, 2))).toDF("a", "b")
  val ans1 = Row(Seq(1, 2, 3, 4))

  val df2 = Seq((Array[Integer](1, 2, null, 4, 5), Array(-5, 4, -3, 2, -1))).toDF("a", "b")
  val ans2 = Row(Seq(1, 2, null, 4, 5, -5, -3, -1))

  val df3 = Seq((Array(1L, 2L, 3L), Array(4L, 2L))).toDF("a", "b")
  val ans3 = Row(Seq(1L, 2L, 3L, 4L))

  val df4 = Seq((Array[java.lang.Long](1L, 2L, null, 4L, 5L), Array(-5L, 4L, -3L, 2L, -1L)))
    .toDF("a", "b")
  val ans4 = Row(Seq(1L, 2L, null, 4L, 5L, -5L, -3L, -1L))

  val df5 = Seq((Array("b", "a", "c"), Array("b", null, "a", "g"))).toDF("a", "b")
  val ans5 = Row(Seq("b", "a", "c", null, "g"))

  Seq((df1, ans1), (df2, ans2), (df3, ans3), (df4, ans4), (df5, ans5)).foreach {
    case (df, expected) =>
      checkAnswer(df.select(array_union($"a", $"b")), expected)
      checkAnswer(df.selectExpr("array_union(a, b)"), expected)
  }

  // Inputs that must be rejected during analysis, through both the Column and SQL APIs.
  val df6 = Seq((null, Array("a"))).toDF("a", "b")
  val df7 = Seq((null, null)).toDF("a", "b")
  val df8 = Seq((Array(Array(1)), Array("a"))).toDF("a", "b")

  Seq(df6, df7, df8).foreach { df =>
    intercept[AnalysisException] {
      df.select(array_union($"a", $"b"))
    }
    intercept[AnalysisException] {
      df.selectExpr("array_union(a, b)")
    }
  }
}
|
||||
|
||||
test("concat function - arrays") {
|
||||
val nseqi : Seq[Int] = null
|
||||
val nseqs : Seq[String] = null
|
||||
|
|
Loading…
Reference in a new issue