[SPARK-27134][SQL] array_distinct function does not work correctly with columns containing array of array

## What changes were proposed in this pull request?
Correct the logic to compute the distinct.

Below is a small repro snippet.

```
scala> val df = Seq(Seq(Seq(1, 2), Seq(1, 2), Seq(1, 2), Seq(3, 4), Seq(4, 5))).toDF("array_col")
df: org.apache.spark.sql.DataFrame = [array_col: array<array<int>>]

scala> val distinctDF = df.select(array_distinct(col("array_col")))
distinctDF: org.apache.spark.sql.DataFrame = [array_distinct(array_col): array<array<int>>]

scala> df.show(false)
+----------------------------------------+
|array_col                               |
+----------------------------------------+
|[[1, 2], [1, 2], [1, 2], [3, 4], [4, 5]]|
+----------------------------------------+
```
Error
```
scala> distinctDF.show(false)
+-------------------------+
|array_distinct(array_col)|
+-------------------------+
|[[1, 2], [1, 2], [1, 2]] |
+-------------------------+
```
Expected result
```
scala> distinctDF.show(false)
+-------------------------+
|array_distinct(array_col)|
+-------------------------+
|[[1, 2], [3, 4], [4, 5]] |
+-------------------------+
```
## How was this patch tested?
Added an additional test.

Closes #24073 from dilipbiswal/SPARK-27134.

Authored-by: Dilip Biswal <dbiswal@us.ibm.com>
Signed-off-by: Sean Owen <sean.owen@databricks.com>
This commit is contained in:
Dilip Biswal 2019-03-16 14:30:42 -05:00 committed by Sean Owen
parent 7a136f8670
commit aea9a574c4
2 changed files with 28 additions and 17 deletions

View file

@ -3112,29 +3112,29 @@ case class ArrayDistinct(child: Expression)
(data: Array[AnyRef]) => new GenericArrayData(data.distinct.asInstanceOf[Array[Any]])
} else {
(data: Array[AnyRef]) => {
var foundNullElement = false
var pos = 0
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[AnyRef]
var alreadyStoredNull = false
for (i <- 0 until data.length) {
if (data(i) == null) {
if (!foundNullElement) {
foundNullElement = true
pos = pos + 1
if (data(i) != null) {
var found = false
var j = 0
while (!found && j < arrayBuffer.size) {
val va = arrayBuffer(j)
found = (va != null) && ordering.equiv(va, data(i))
j += 1
}
if (!found) {
arrayBuffer += data(i)
}
} else {
var j = 0
var done = false
while (j <= i && !done) {
if (data(j) != null && ordering.equiv(data(j), data(i))) {
done = true
}
j = j + 1
}
if (i == j - 1) {
pos = pos + 1
// De-duplicate the null values.
if (!alreadyStoredNull) {
arrayBuffer += data(i)
alreadyStoredNull = true
}
}
}
new GenericArrayData(data.slice(0, pos))
new GenericArrayData(arrayBuffer)
}
}

View file

@ -1364,6 +1364,8 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
ArrayType(DoubleType))
val a7 = Literal.create(Seq(1.123f, 0.1234f, 1.121f, 1.123f, 1.1230f, 1.121f, 0.1234f),
ArrayType(FloatType))
val a8 =
Literal.create(Seq(2, 1, 2, 3, 4, 4, 5).map(_.toString.getBytes), ArrayType(BinaryType))
checkEvaluation(new ArrayDistinct(a0), Seq(2, 1, 3, 4, 5))
checkEvaluation(new ArrayDistinct(a1), Seq.empty[Integer])
@ -1373,6 +1375,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(new ArrayDistinct(a5), Seq(true, false))
checkEvaluation(new ArrayDistinct(a6), Seq(1.123, 0.1234, 1.121))
checkEvaluation(new ArrayDistinct(a7), Seq(1.123f, 0.1234f, 1.121f))
checkEvaluation(new ArrayDistinct(a8), Seq(2, 1, 3, 4, 5).map(_.toString.getBytes))
// complex data types
val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2),
@ -1393,9 +1396,17 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
ArrayType(ArrayType(IntegerType)))
val c2 = Literal.create(Seq[Seq[Int]](null, Seq[Int](2, 1), null, null, Seq[Int](2, 1), null),
ArrayType(ArrayType(IntegerType)))
val c3 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](1, 2), Seq[Int](1, 2),
Seq[Int](3, 4), Seq[Int](4, 5)), ArrayType(ArrayType(IntegerType)))
val c4 = Literal.create(Seq[Seq[Int]](null, Seq[Int](1, 2), Seq[Int](1, 2),
Seq[Int](3, 4), Seq[Int](4, 5), null), ArrayType(ArrayType(IntegerType)))
checkEvaluation(ArrayDistinct(c0), Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)))
checkEvaluation(ArrayDistinct(c1), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)))
checkEvaluation(ArrayDistinct(c2), Seq[Seq[Int]](null, Seq[Int](2, 1)))
checkEvaluation(ArrayDistinct(c3), Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4),
Seq[Int](4, 5)))
checkEvaluation(ArrayDistinct(c4), Seq[Seq[Int]](null, Seq[Int](1, 2), Seq[Int](3, 4),
Seq[Int](4, 5)))
}
test("Array Union") {