[SPARK-23917][SQL] Add array_max function

## What changes were proposed in this pull request?

The PR adds the SQL function `array_max`. It takes an array as argument and returns the maximum value in it.

## How was this patch tested?

Added unit tests (UTs).

Author: Marco Gaido <marcogaido91@gmail.com>

Closes #21024 from mgaido91/SPARK-23917.
This commit is contained in:
Marco Gaido 2018-04-15 21:45:55 -07:00 committed by gatorsmile
parent c0964935d6
commit 6931022031
8 changed files with 133 additions and 6 deletions

View file

@ -2080,6 +2080,21 @@ def size(col):
return Column(sc._jvm.functions.size(_to_java_column(col)))
@since(2.4)
def array_max(col):
    """
    Collection function: finds the largest element of the array stored in the given column.
    Null elements do not take part in the comparison (see the second row of the example).

    :param col: name of column or expression

    >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data'])
    >>> df.select(array_max(df.data).alias('max')).collect()
    [Row(max=3), Row(max=10)]
    """
    # Resolve the JVM-side functions object once, then delegate to its array_max.
    jvm_functions = SparkContext._active_spark_context._jvm.functions
    return Column(jvm_functions.array_max(_to_java_column(col)))
@since(1.5)
def sort_array(col, asc=True):
"""

View file

@ -409,6 +409,7 @@ object FunctionRegistry {
expression[MapValues]("map_values"),
expression[Size]("size"),
expression[SortArray]("sort_array"),
expression[ArrayMax]("array_max"),
CreateStruct.registryEntry,
// misc functions

View file

@ -674,11 +674,7 @@ case class Greatest(children: Seq[Expression]) extends Expression {
val evals = evalChildren.map(eval =>
s"""
|${eval.code}
|if (!${eval.isNull} && (${ev.isNull} ||
| ${ctx.genGreater(dataType, eval.value, ev.value)})) {
| ${ev.isNull} = false;
| ${ev.value} = ${eval.value};
|}
|${ctx.reassignIfGreater(dataType, ev, eval)}
""".stripMargin
)

View file

@ -699,6 +699,23 @@ class CodegenContext {
case _ => s"(${genComp(dataType, c1, c2)}) > 0"
}
/**
 * Builds a code snippet that overwrites `partialResult` with `item` whenever `item` is
 * non-null and either `partialResult` is still null or `item` compares greater than it.
 *
 * @param dataType data type of the expressions
 * @param partialResult `ExprCode` holding the running result; its `isNull` and `value`
 *                      are assigned by the emitted snippet
 * @param item `ExprCode` of the candidate that may replace the running result
 * @return the snippet as a string
 */
def reassignIfGreater(dataType: DataType, partialResult: ExprCode, item: ExprCode): String = {
  // The "item > partialResult" comparison is rendered once and spliced into the guard below.
  val itemIsGreater = genGreater(dataType, item.value, partialResult.value)
  s"""
     |if (!${item.isNull} && (${partialResult.isNull} ||
     | $itemIsGreater)) {
     | ${partialResult.isNull} = false;
     | ${partialResult.value} = ${item.value};
     |}
   """.stripMargin
}
/**
* Generates code to do null safe execution, i.e. only execute the code when the input is not
* null by adding null check if necessary.

View file

@ -21,7 +21,7 @@ import java.util.Comparator
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
import org.apache.spark.sql.types._
/**
@ -287,3 +287,69 @@ case class ArrayContains(left: Expression, right: Expression)
override def prettyName: String = "array_contains"
}
/**
 * Expression that returns the greatest non-null element of an array.
 *
 * The result is null when no non-null element is found, e.g. for an empty array or an
 * array that holds only nulls.
 */
@ExpressionDescription(
  usage = "_FUNC_(array) - Returns the maximum value in the array. NULL elements are skipped.",
  examples = """
    Examples:
      > SELECT _FUNC_(array(1, 20, null, 3));
       20
  """, since = "2.4.0")
case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {

  override def prettyName: String = "array_max"

  // A result may be null even for a non-null input array (see nullSafeEval).
  override def nullable: Boolean = true

  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)

  /** The result type is the element type of the input array. */
  override def dataType: DataType = child.dataType match {
    case ArrayType(elementType, _) => elementType
    case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.")
  }

  // Lazy so that `dataType` is only inspected once the ordering is actually needed.
  private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)

  /**
   * Runs the inherited input check first; only if that passes is the element type
   * additionally required to support ordering.
   */
  override def checkInputDataTypes(): TypeCheckResult = {
    val inherited = super.checkInputDataTypes()
    if (!inherited.isSuccess) {
      inherited
    } else {
      TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName")
    }
  }

  /** Interpreted path: scans the array once, keeping the greatest non-null element seen. */
  override protected def nullSafeEval(input: Any): Any = {
    var currentMax: Any = null
    input.asInstanceOf[ArrayData].foreach(dataType, (_, element) => {
      val replacesMax =
        element != null && (currentMax == null || ordering.gt(element, currentMax))
      if (replacesMax) {
        currentMax = element
      }
    })
    currentMax
  }

  /**
   * Codegen path: emits a loop over the array elements whose body is the snippet produced
   * by `ctx.reassignIfGreater`, with `ev` acting as the running result.
   */
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val arrayCode = child.genCode(ctx)
    val elementJavaType = CodeGenerator.javaType(dataType)
    val idx = ctx.freshName("i")
    // Expression pair describing "the element at position idx" of the child array.
    val elementAccess = CodeGenerator.getValue(arrayCode.value, dataType, idx)
    val element = ExprCode("",
      isNull = JavaCode.isNullExpression(s"${arrayCode.value}.isNullAt($idx)"),
      value = JavaCode.expression(elementAccess, dataType))
    val keepLarger = ctx.reassignIfGreater(dataType, ev, element)
    val initialValue = CodeGenerator.defaultValue(dataType)
    ev.copy(code =
      s"""
         |${arrayCode.code}
         |boolean ${ev.isNull} = true;
         |$elementJavaType ${ev.value} = $initialValue;
         |if (!${arrayCode.isNull}) {
         | for (int $idx = 0; $idx < ${arrayCode.value}.numElements(); $idx ++) {
         | $keepLarger
         | }
         |}
      """.stripMargin)
  }
}

View file

@ -105,4 +105,14 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(ArrayContains(a3, Literal("")), null)
checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
}
test("Array max") {
  // Integers: the largest element is returned.
  val ints = Literal.create(Seq(1, 10, 2), ArrayType(IntegerType))
  checkEvaluation(ArrayMax(ints), 10)

  // Strings with a null element: the null is ignored.
  val stringsWithNull = Literal.create(Seq[String](null, "abc", ""), ArrayType(StringType))
  checkEvaluation(ArrayMax(stringsWithNull), "abc")

  // An array holding only a null yields null.
  val onlyNull = Literal.create(Seq(null), ArrayType(LongType))
  checkEvaluation(ArrayMax(onlyNull), null)

  // A null array yields null.
  val nullArray = Literal.create(null, ArrayType(StringType))
  checkEvaluation(ArrayMax(nullArray), null)

  // Doubles.
  val doubles = Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))
  checkEvaluation(ArrayMax(doubles), 1.123)
}
}

View file

@ -3300,6 +3300,14 @@ object functions {
*/
def sort_array(e: Column, asc: Boolean): Column = withExpr { SortArray(e.expr, lit(asc).expr) }
/**
 * Returns the largest element of the given array column.
 *
 * @group collection_funcs
 * @since 2.4.0
 */
def array_max(e: Column): Column = withExpr(ArrayMax(e.expr))
/**
* Returns an unordered array containing the keys of the map.
* @group collection_funcs

View file

@ -413,6 +413,20 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
)
}
test("array_max function") {
  // Rows cover: all values present, empty array, only a null, and nulls mixed with values.
  val arrays: Seq[Seq[Option[Int]]] = Seq(
    Seq(Some(1), Some(3), Some(2)),
    Seq.empty,
    Seq(None),
    Seq(None, Some(1), Some(-100)))
  val df = arrays.toDF("a")
  val expectedMaxima = Seq(Row(3), Row(null), Row(null), Row(1))

  // The Column-based API and the SQL expression must both produce the same rows.
  checkAnswer(df.select(array_max(df("a"))), expectedMaxima)
  checkAnswer(df.selectExpr("array_max(a)"), expectedMaxima)
}
private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
import DataFrameFunctionsSuite.CodegenFallbackExpr
for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {