[SPARK-27134][SQL] array_distinct function does not work correctly with columns containing array of array
## What changes were proposed in this pull request?

Correct the logic used to compute the distinct elements. Below is a small repro snippet.

```
scala> val df = Seq(Seq(Seq(1, 2), Seq(1, 2), Seq(1, 2), Seq(3, 4), Seq(4, 5))).toDF("array_col")
df: org.apache.spark.sql.DataFrame = [array_col: array<array<int>>]

scala> val distinctDF = df.select(array_distinct(col("array_col")))
distinctDF: org.apache.spark.sql.DataFrame = [array_distinct(array_col): array<array<int>>]

scala> df.show(false)
+----------------------------------------+
|array_col                               |
+----------------------------------------+
|[[1, 2], [1, 2], [1, 2], [3, 4], [4, 5]]|
+----------------------------------------+
```

Error

```
scala> distinctDF.show(false)
+-------------------------+
|array_distinct(array_col)|
+-------------------------+
|[[1, 2], [1, 2], [1, 2]] |
+-------------------------+
```

Expected result

```
scala> distinctDF.show(false)
+-------------------------+
|array_distinct(array_col)|
+-------------------------+
|[[1, 2], [3, 4], [4, 5]] |
+-------------------------+
```

## How was this patch tested?

Added an additional test.

Closes #24073 from dilipbiswal/SPARK-27134.

Authored-by: Dilip Biswal <dbiswal@us.ibm.com>
Signed-off-by: Sean Owen <sean.owen@databricks.com>
This commit is contained in:
parent
7a136f8670
commit
aea9a574c4
|
@ -3112,29 +3112,29 @@ case class ArrayDistinct(child: Expression)
|
|||
(data: Array[AnyRef]) => new GenericArrayData(data.distinct.asInstanceOf[Array[Any]])
|
||||
} else {
|
||||
(data: Array[AnyRef]) => {
|
||||
var foundNullElement = false
|
||||
var pos = 0
|
||||
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[AnyRef]
|
||||
var alreadyStoredNull = false
|
||||
for (i <- 0 until data.length) {
|
||||
if (data(i) == null) {
|
||||
if (!foundNullElement) {
|
||||
foundNullElement = true
|
||||
pos = pos + 1
|
||||
if (data(i) != null) {
|
||||
var found = false
|
||||
var j = 0
|
||||
while (!found && j < arrayBuffer.size) {
|
||||
val va = arrayBuffer(j)
|
||||
found = (va != null) && ordering.equiv(va, data(i))
|
||||
j += 1
|
||||
}
|
||||
if (!found) {
|
||||
arrayBuffer += data(i)
|
||||
}
|
||||
} else {
|
||||
var j = 0
|
||||
var done = false
|
||||
while (j <= i && !done) {
|
||||
if (data(j) != null && ordering.equiv(data(j), data(i))) {
|
||||
done = true
|
||||
}
|
||||
j = j + 1
|
||||
}
|
||||
if (i == j - 1) {
|
||||
pos = pos + 1
|
||||
// De-duplicate the null values.
|
||||
if (!alreadyStoredNull) {
|
||||
arrayBuffer += data(i)
|
||||
alreadyStoredNull = true
|
||||
}
|
||||
}
|
||||
}
|
||||
new GenericArrayData(data.slice(0, pos))
|
||||
new GenericArrayData(arrayBuffer)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1364,6 +1364,8 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
|
|||
ArrayType(DoubleType))
|
||||
val a7 = Literal.create(Seq(1.123f, 0.1234f, 1.121f, 1.123f, 1.1230f, 1.121f, 0.1234f),
|
||||
ArrayType(FloatType))
|
||||
val a8 =
|
||||
Literal.create(Seq(2, 1, 2, 3, 4, 4, 5).map(_.toString.getBytes), ArrayType(BinaryType))
|
||||
|
||||
checkEvaluation(new ArrayDistinct(a0), Seq(2, 1, 3, 4, 5))
|
||||
checkEvaluation(new ArrayDistinct(a1), Seq.empty[Integer])
|
||||
|
@ -1373,6 +1375,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
|
|||
checkEvaluation(new ArrayDistinct(a5), Seq(true, false))
|
||||
checkEvaluation(new ArrayDistinct(a6), Seq(1.123, 0.1234, 1.121))
|
||||
checkEvaluation(new ArrayDistinct(a7), Seq(1.123f, 0.1234f, 1.121f))
|
||||
checkEvaluation(new ArrayDistinct(a8), Seq(2, 1, 3, 4, 5).map(_.toString.getBytes))
|
||||
|
||||
// complex data types
|
||||
val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2),
|
||||
|
@ -1393,9 +1396,17 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
|
|||
ArrayType(ArrayType(IntegerType)))
|
||||
val c2 = Literal.create(Seq[Seq[Int]](null, Seq[Int](2, 1), null, null, Seq[Int](2, 1), null),
|
||||
ArrayType(ArrayType(IntegerType)))
|
||||
val c3 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](1, 2), Seq[Int](1, 2),
|
||||
Seq[Int](3, 4), Seq[Int](4, 5)), ArrayType(ArrayType(IntegerType)))
|
||||
val c4 = Literal.create(Seq[Seq[Int]](null, Seq[Int](1, 2), Seq[Int](1, 2),
|
||||
Seq[Int](3, 4), Seq[Int](4, 5), null), ArrayType(ArrayType(IntegerType)))
|
||||
checkEvaluation(ArrayDistinct(c0), Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)))
|
||||
checkEvaluation(ArrayDistinct(c1), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)))
|
||||
checkEvaluation(ArrayDistinct(c2), Seq[Seq[Int]](null, Seq[Int](2, 1)))
|
||||
checkEvaluation(ArrayDistinct(c3), Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4),
|
||||
Seq[Int](4, 5)))
|
||||
checkEvaluation(ArrayDistinct(c4), Seq[Seq[Int]](null, Seq[Int](1, 2), Seq[Int](3, 4),
|
||||
Seq[Int](4, 5)))
|
||||
}
|
||||
|
||||
test("Array Union") {
|
||||
|
|
Loading…
Reference in a new issue