[SPARK-23917][SQL] Add array_max function
## What changes were proposed in this pull request?

This PR adds the SQL function `array_max`. It takes an array as its argument and returns the maximum value in it.

## How was this patch tested?

Added unit tests.

Author: Marco Gaido <marcogaido91@gmail.com>

Closes #21024 from mgaido91/SPARK-23917.
This commit is contained in:
parent
c0964935d6
commit
6931022031
|
@ -2080,6 +2080,21 @@ def size(col):
|
|||
return Column(sc._jvm.functions.size(_to_java_column(col)))
|
||||
|
||||
|
||||
@since(2.4)
def array_max(col):
    """
    Collection function: returns the largest value found in the array; null
    elements do not take part in the comparison (see the example below).

    :param col: name of column or expression

    >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data'])
    >>> df.select(array_max(df.data).alias('max')).collect()
    [Row(max=3), Row(max=10)]
    """
    # Resolve the JVM-side `functions` object from the active context, then delegate to it.
    jvm_functions = SparkContext._active_spark_context._jvm.functions
    java_col = _to_java_column(col)
    return Column(jvm_functions.array_max(java_col))
|
||||
|
||||
|
||||
@since(1.5)
|
||||
def sort_array(col, asc=True):
|
||||
"""
|
||||
|
|
|
@ -409,6 +409,7 @@ object FunctionRegistry {
|
|||
expression[MapValues]("map_values"),
|
||||
expression[Size]("size"),
|
||||
expression[SortArray]("sort_array"),
|
||||
expression[ArrayMax]("array_max"),
|
||||
CreateStruct.registryEntry,
|
||||
|
||||
// misc functions
|
||||
|
|
|
@ -674,11 +674,7 @@ case class Greatest(children: Seq[Expression]) extends Expression {
|
|||
val evals = evalChildren.map(eval =>
|
||||
s"""
|
||||
|${eval.code}
|
||||
|if (!${eval.isNull} && (${ev.isNull} ||
|
||||
| ${ctx.genGreater(dataType, eval.value, ev.value)})) {
|
||||
| ${ev.isNull} = false;
|
||||
| ${ev.value} = ${eval.value};
|
||||
|}
|
||||
|${ctx.reassignIfGreater(dataType, ev, eval)}
|
||||
""".stripMargin
|
||||
)
|
||||
|
||||
|
|
|
@ -699,6 +699,23 @@ class CodegenContext {
|
|||
case _ => s"(${genComp(dataType, c1, c2)}) > 0"
|
||||
}
|
||||
|
||||
/**
 * Builds the code snippet that overwrites `partialResult` with `item` when `item` is not
 * null and either `partialResult` is still null or `item` compares greater than it.
 *
 * @param dataType data type of the expressions
 * @param partialResult `ExprCode` holding the running result that may be overwritten
 * @param item `ExprCode` of the candidate value compared against the running result
 */
def reassignIfGreater(dataType: DataType, partialResult: ExprCode, item: ExprCode): String = {
  // Comparison of the candidate against the current result, rendered for `dataType`.
  val itemIsGreater = genGreater(dataType, item.value, partialResult.value)
  s"""
     |if (!${item.isNull} && (${partialResult.isNull} ||
     |  $itemIsGreater)) {
     |  ${partialResult.isNull} = false;
     |  ${partialResult.value} = ${item.value};
     |}
    """.stripMargin
}
|
||||
|
||||
/**
|
||||
* Generates code to do null safe execution, i.e. only execute the code when the input is not
|
||||
* null by adding null check if necessary.
|
||||
|
|
|
@ -21,7 +21,7 @@ import java.util.Comparator
|
|||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData}
|
||||
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
/**
|
||||
|
@ -287,3 +287,69 @@ case class ArrayContains(left: Expression, right: Expression)
|
|||
|
||||
override def prettyName: String = "array_contains"
|
||||
}
|
||||
|
||||
|
||||
/**
 * Expression returning the greatest non-null element of an array. The result is null when
 * the array itself is null or when it holds no non-null element.
 */
@ExpressionDescription(
  usage = "_FUNC_(array) - Returns the maximum value in the array. NULL elements are skipped.",
  examples = """
    Examples:
      > SELECT _FUNC_(array(1, 20, null, 3));
       20
  """, since = "2.4.0")
case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {

  override def prettyName: String = "array_max"

  // Always nullable: a non-null array may still contain only nulls (or nothing at all).
  override def nullable: Boolean = true

  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)

  // The result has the element type of the input array.
  override def dataType: DataType = child.dataType match {
    case ArrayType(elementType, _) => elementType
    case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.")
  }

  // Ordering used by the interpreted evaluation path.
  private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)

  override def checkInputDataTypes(): TypeCheckResult = {
    val inputCheck = super.checkInputDataTypes()
    if (!inputCheck.isSuccess) {
      inputCheck
    } else {
      // The input is an array; additionally require its element type to be orderable.
      TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName")
    }
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val arrayGen = child.genCode(ctx)
    val idx = ctx.freshName("i")
    // The array element at position `idx`, described as an ExprCode with no code of its own.
    val elementValue = CodeGenerator.getValue(arrayGen.value, dataType, idx)
    val element = ExprCode(
      code = "",
      isNull = JavaCode.isNullExpression(s"${arrayGen.value}.isNullAt($idx)"),
      value = JavaCode.expression(elementValue, dataType))
    val resultType = CodeGenerator.javaType(dataType)
    val initialValue = CodeGenerator.defaultValue(dataType)
    val keepLarger = ctx.reassignIfGreater(dataType, ev, element)
    // The result starts out null and is replaced while scanning the array elements.
    val generated =
      s"""
         |${arrayGen.code}
         |boolean ${ev.isNull} = true;
         |$resultType ${ev.value} = $initialValue;
         |if (!${arrayGen.isNull}) {
         |  for (int $idx = 0; $idx < ${arrayGen.value}.numElements(); $idx ++) {
         |    $keepLarger
         |  }
         |}
      """.stripMargin
    ev.copy(code = generated)
  }

  override protected def nullSafeEval(input: Any): Any = {
    val elements = input.asInstanceOf[ArrayData]
    // Stays null until the first non-null element is seen.
    var currentMax: Any = null
    elements.foreach(dataType, (_, candidate) => {
      val replaces =
        candidate != null && (currentMax == null || ordering.gt(candidate, currentMax))
      if (replaces) {
        currentMax = candidate
      }
    })
    currentMax
  }
}
|
||||
|
|
|
@ -105,4 +105,14 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
|
|||
checkEvaluation(ArrayContains(a3, Literal("")), null)
|
||||
checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
|
||||
}
|
||||
|
||||
test("Array max") {
  // Integers without nulls.
  val ints = Literal.create(Seq(1, 10, 2), ArrayType(IntegerType))
  checkEvaluation(ArrayMax(ints), 10)

  // A null element next to non-null strings is ignored.
  val stringsWithNull = Literal.create(Seq[String](null, "abc", ""), ArrayType(StringType))
  checkEvaluation(ArrayMax(stringsWithNull), "abc")

  // An array holding only a null has no maximum.
  val onlyNull = Literal.create(Seq(null), ArrayType(LongType))
  checkEvaluation(ArrayMax(onlyNull), null)

  // A null array evaluates to null.
  val nullArray = Literal.create(null, ArrayType(StringType))
  checkEvaluation(ArrayMax(nullArray), null)

  // Doubles.
  val doubles = Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))
  checkEvaluation(ArrayMax(doubles), 1.123)
}
|
||||
}
|
||||
|
|
|
@ -3300,6 +3300,14 @@ object functions {
|
|||
*/
|
||||
def sort_array(e: Column, asc: Boolean): Column = withExpr { SortArray(e.expr, lit(asc).expr) }
|
||||
|
||||
/**
 * Returns the maximum value found in the given array column.
 *
 * @group collection_funcs
 * @since 2.4.0
 */
def array_max(e: Column): Column = withExpr(ArrayMax(e.expr))
|
||||
|
||||
/**
|
||||
* Returns an unordered array containing the keys of the map.
|
||||
* @group collection_funcs
|
||||
|
|
|
@ -413,6 +413,20 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
|
|||
)
|
||||
}
|
||||
|
||||
test("array_max function") {
  // Rows cover: no nulls, an empty array, only a null, and a null mixed with values.
  val rows: Seq[Seq[Option[Int]]] = Seq(
    Seq(Some(1), Some(3), Some(2)),
    Seq.empty,
    Seq(None),
    Seq(None, Some(1), Some(-100)))
  val df = rows.toDF("a")

  val expected = Seq(Row(3), Row(null), Row(null), Row(1))

  // The Column-based API and the SQL expression must produce the same rows.
  checkAnswer(df.select(array_max(df("a"))), expected)
  checkAnswer(df.selectExpr("array_max(a)"), expected)
}
|
||||
|
||||
private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
|
||||
import DataFrameFunctionsSuite.CodegenFallbackExpr
|
||||
for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {
|
||||
|
|
Loading…
Reference in a new issue