[SPARK-36767][SQL] ArrayMin/ArrayMax/SortArray/ArraySort add comment and Unit test

### What changes were proposed in this pull request?
Add comment about how ArrayMin/ArrayMax/SortArray/ArraySort handle NaN and add Unit test for this

### Why are the changes needed?
Add Unit test

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Added UT

Closes #34008 from AngersZhuuuu/SPARK-36740.

Authored-by: Angerszhuuuu <angers.zhu@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit 69e006dd53)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
Angerszhuuuu 2021-09-17 21:42:08 +08:00 committed by Wenchen Fan
parent 895218996a
commit 61d7f1da1b
5 changed files with 46 additions and 13 deletions

View file

@ -914,9 +914,9 @@ object ArraySortLike {
@ExpressionDescription( @ExpressionDescription(
usage = """ usage = """
_FUNC_(array[, ascendingOrder]) - Sorts the input array in ascending or descending order _FUNC_(array[, ascendingOrder]) - Sorts the input array in ascending or descending order
according to the natural ordering of the array elements. Null elements will be placed according to the natural ordering of the array elements. NaN is greater than any non-NaN
at the beginning of the returned array in ascending order or at the end of the returned elements for double/float type. Null elements will be placed at the beginning of the returned
array in descending order. array in ascending order or at the end of the returned array in descending order.
""", """,
examples = """ examples = """
Examples: Examples:
@ -1767,7 +1767,9 @@ case class ArrayJoin(
* Returns the minimum value in the array. * Returns the minimum value in the array.
*/ */
@ExpressionDescription( @ExpressionDescription(
usage = "_FUNC_(array) - Returns the minimum value in the array. NULL elements are skipped.", usage = """
_FUNC_(array) - Returns the minimum value in the array. NaN is greater than
any non-NaN elements for double/float type. NULL elements are skipped.""",
examples = """ examples = """
Examples: Examples:
> SELECT _FUNC_(array(1, 20, null, 3)); > SELECT _FUNC_(array(1, 20, null, 3));
@ -1838,7 +1840,9 @@ case class ArrayMin(child: Expression)
* Returns the maximum value in the array. * Returns the maximum value in the array.
*/ */
@ExpressionDescription( @ExpressionDescription(
usage = "_FUNC_(array) - Returns the maximum value in the array. NULL elements are skipped.", usage = """
_FUNC_(array) - Returns the maximum value in the array. NaN is greater than
any non-NaN elements for double/float type. NULL elements are skipped.""",
examples = """ examples = """
Examples: Examples:
> SELECT _FUNC_(array(1, 20, null, 3)); > SELECT _FUNC_(array(1, 20, null, 3));

View file

@ -351,10 +351,12 @@ case class ArrayTransform(
// scalastyle:off line.size.limit // scalastyle:off line.size.limit
@ExpressionDescription( @ExpressionDescription(
usage = """_FUNC_(expr, func) - Sorts the input array. If func is omitted, sort usage = """_FUNC_(expr, func) - Sorts the input array. If func is omitted, sort
in ascending order. The elements of the input array must be orderable. Null elements in ascending order. The elements of the input array must be orderable.
will be placed at the end of the returned array. Since 3.0.0 this function also sorts NaN is greater than any non-NaN elements for double/float type.
and returns the array based on the given comparator function. The comparator will Null elements will be placed at the end of the returned array.
take two arguments representing two elements of the array. Since 3.0.0 this function also sorts and returns the array based on the
given comparator function. The comparator will take two arguments representing
two elements of the array.
It returns -1, 0, or 1 as the first element is less than, equal to, or greater It returns -1, 0, or 1 as the first element is less than, equal to, or greater
than the second element. If the comparator function returns other than the second element. If the comparator function returns other
values (including null), the function will fail and raise an error. values (including null), the function will fail and raise an error.

View file

@ -2331,4 +2331,21 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
Literal.create(Seq(Float.NaN, null), ArrayType(FloatType)), Literal.create(Seq(Float.NaN, null), ArrayType(FloatType)),
Literal.create(Seq(Float.NaN, null, 1f), ArrayType(FloatType))), true) Literal.create(Seq(Float.NaN, null, 1f), ArrayType(FloatType))), true)
} }
test("SPARK-36740: ArrayMin/ArrayMax/SortArray should handle NaN greater then non-NaN value") {
// ArrayMin
checkEvaluation(ArrayMin(
Literal.create(Seq(Double.NaN, 1d, 2d), ArrayType(DoubleType))), 1d)
checkEvaluation(ArrayMin(
Literal.create(Seq(Double.NaN, 1d, 2d, null), ArrayType(DoubleType))), 1d)
// ArrayMax
checkEvaluation(ArrayMax(
Literal.create(Seq(Double.NaN, 1d, 2d), ArrayType(DoubleType))), Double.NaN)
checkEvaluation(ArrayMax(
Literal.create(Seq(Double.NaN, 1d, 2d, null), ArrayType(DoubleType))), Double.NaN)
// SortArray
checkEvaluation(new SortArray(
Literal.create(Seq(Double.NaN, 1d, 2d, null), ArrayType(DoubleType))),
Seq(null, 1d, 2d, Double.NaN))
}
} }

View file

@ -832,4 +832,10 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
assert(mapFilter2_1.semanticEquals(mapFilter2_2)) assert(mapFilter2_1.semanticEquals(mapFilter2_2))
assert(!mapFilter2_1.semanticEquals(mapFilter2_3)) assert(!mapFilter2_1.semanticEquals(mapFilter2_3))
} }
test("SPARK-36740: ArraySort should handle NaN greater then non-NaN value") {
checkEvaluation(arraySort(
Literal.create(Seq(Double.NaN, 1d, 2d, null), ArrayType(DoubleType))),
Seq(1d, 2d, Double.NaN, null))
}
} }

View file

@ -3815,6 +3815,7 @@ object functions {
/** /**
* Sorts the input array in ascending order. The elements of the input array must be orderable. * Sorts the input array in ascending order. The elements of the input array must be orderable.
* NaN is greater than any non-NaN elements for double/float type.
* Null elements will be placed at the end of the returned array. * Null elements will be placed at the end of the returned array.
* *
* @group collection_funcs * @group collection_funcs
@ -4524,8 +4525,9 @@ object functions {
/** /**
* Sorts the input array for the given column in ascending or descending order, * Sorts the input array for the given column in ascending or descending order,
* according to the natural ordering of the array elements. * according to the natural ordering of the array elements. NaN is greater than any non-NaN
* Null elements will be placed at the beginning of the returned array in ascending order or * elements for double/float type. Null elements will be placed at the beginning of the returned
* array in ascending order or
* at the end of the returned array in descending order. * at the end of the returned array in descending order.
* *
* @group collection_funcs * @group collection_funcs
@ -4534,7 +4536,8 @@ object functions {
def sort_array(e: Column, asc: Boolean): Column = withExpr { SortArray(e.expr, lit(asc).expr) } def sort_array(e: Column, asc: Boolean): Column = withExpr { SortArray(e.expr, lit(asc).expr) }
/** /**
* Returns the minimum value in the array. * Returns the minimum value in the array. NaN is greater than any non-NaN elements for
* double/float type. NULL elements are skipped.
* *
* @group collection_funcs * @group collection_funcs
* @since 2.4.0 * @since 2.4.0
@ -4542,7 +4545,8 @@ object functions {
def array_min(e: Column): Column = withExpr { ArrayMin(e.expr) } def array_min(e: Column): Column = withExpr { ArrayMin(e.expr) }
/** /**
* Returns the maximum value in the array. * Returns the maximum value in the array. NaN is greater than any non-NaN elements for
* double/float type. NULL elements are skipped.
* *
* @group collection_funcs * @group collection_funcs
* @since 2.4.0 * @since 2.4.0