diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index a50263c852..1182194e4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -914,9 +914,9 @@ object ArraySortLike { @ExpressionDescription( usage = """ _FUNC_(array[, ascendingOrder]) - Sorts the input array in ascending or descending order - according to the natural ordering of the array elements. Null elements will be placed - at the beginning of the returned array in ascending order or at the end of the returned - array in descending order. + according to the natural ordering of the array elements. NaN is greater than any non-NaN + elements for double/float type. Null elements will be placed at the beginning of the returned + array in ascending order or at the end of the returned array in descending order. """, examples = """ Examples: @@ -1767,7 +1767,9 @@ case class ArrayJoin( * Returns the minimum value in the array. */ @ExpressionDescription( - usage = "_FUNC_(array) - Returns the minimum value in the array. NULL elements are skipped.", + usage = """ + _FUNC_(array) - Returns the minimum value in the array. NaN is greater than + any non-NaN elements for double/float type. NULL elements are skipped.""", examples = """ Examples: > SELECT _FUNC_(array(1, 20, null, 3)); @@ -1838,7 +1840,9 @@ case class ArrayMin(child: Expression) * Returns the maximum value in the array. */ @ExpressionDescription( - usage = "_FUNC_(array) - Returns the maximum value in the array. NULL elements are skipped.", + usage = """ + _FUNC_(array) - Returns the maximum value in the array. NaN is greater than + any non-NaN elements for double/float type. 
NULL elements are skipped.""", examples = """ Examples: > SELECT _FUNC_(array(1, 20, null, 3)); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index bf9cc6cd65..bbcd3b4957 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -351,10 +351,12 @@ case class ArrayTransform( // scalastyle:off line.size.limit @ExpressionDescription( usage = """_FUNC_(expr, func) - Sorts the input array. If func is omitted, sort - in ascending order. The elements of the input array must be orderable. Null elements - will be placed at the end of the returned array. Since 3.0.0 this function also sorts - and returns the array based on the given comparator function. The comparator will - take two arguments representing two elements of the array. + in ascending order. The elements of the input array must be orderable. + NaN is greater than any non-NaN elements for double/float type. + Null elements will be placed at the end of the returned array. + Since 3.0.0 this function also sorts and returns the array based on the + given comparator function. The comparator will take two arguments representing + two elements of the array. It returns -1, 0, or 1 as the first element is less than, equal to, or greater than the second element. If the comparator function returns other values (including null), the function will fail and raise an error. 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 62098bc840..46b7a8fc07 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -2331,4 +2331,21 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper Literal.create(Seq(Float.NaN, null), ArrayType(FloatType)), Literal.create(Seq(Float.NaN, null, 1f), ArrayType(FloatType))), true) } + + test("SPARK-36740: ArrayMin/ArrayMax/SortArray should handle NaN greater than non-NaN value") { + // ArrayMin + checkEvaluation(ArrayMin( + Literal.create(Seq(Double.NaN, 1d, 2d), ArrayType(DoubleType))), 1d) + checkEvaluation(ArrayMin( + Literal.create(Seq(Double.NaN, 1d, 2d, null), ArrayType(DoubleType))), 1d) + // ArrayMax + checkEvaluation(ArrayMax( + Literal.create(Seq(Double.NaN, 1d, 2d), ArrayType(DoubleType))), Double.NaN) + checkEvaluation(ArrayMax( + Literal.create(Seq(Double.NaN, 1d, 2d, null), ArrayType(DoubleType))), Double.NaN) + // SortArray + checkEvaluation(new SortArray( + Literal.create(Seq(Double.NaN, 1d, 2d, null), ArrayType(DoubleType))), + Seq(null, 1d, 2d, Double.NaN)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index bb565ceb38..c0db6d8dc2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -832,4 +832,10 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper 
assert(mapFilter2_1.semanticEquals(mapFilter2_2)) assert(!mapFilter2_1.semanticEquals(mapFilter2_3)) } + + test("SPARK-36740: ArraySort should handle NaN greater than non-NaN value") { + checkEvaluation(arraySort( + Literal.create(Seq(Double.NaN, 1d, 2d, null), ArrayType(DoubleType))), + Seq(1d, 2d, Double.NaN, null)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index da82ac5211..a4c77b20c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3815,6 +3815,7 @@ object functions { /** * Sorts the input array in ascending order. The elements of the input array must be orderable. + * NaN is greater than any non-NaN elements for double/float type. * Null elements will be placed at the end of the returned array. * * @group collection_funcs @@ -4524,8 +4525,9 @@ object functions { /** * Sorts the input array for the given column in ascending or descending order, - * according to the natural ordering of the array elements. - * Null elements will be placed at the beginning of the returned array in ascending order or + * according to the natural ordering of the array elements. NaN is greater than any non-NaN + * elements for double/float type. Null elements will be placed at the beginning of the returned + * array in ascending order or * at the end of the returned array in descending order. * * @group collection_funcs @@ -4534,7 +4536,8 @@ object functions { def sort_array(e: Column, asc: Boolean): Column = withExpr { SortArray(e.expr, lit(asc).expr) } /** - * Returns the minimum value in the array. + * Returns the minimum value in the array. NaN is greater than any non-NaN elements for + * double/float type. NULL elements are skipped. 
* * @group collection_funcs * @since 2.4.0 @@ -4542,7 +4545,8 @@ object functions { def array_min(e: Column): Column = withExpr { ArrayMin(e.expr) } /** - * Returns the maximum value in the array. + * Returns the maximum value in the array. NaN is greater than any non-NaN elements for + * double/float type. NULL elements are skipped. * * @group collection_funcs * @since 2.4.0