[SPARK-12011][SQL] Stddev/Variance etc should support columnName as arguments
The Spark SQL aggregate functions `stddev`, `stddev_pop`, `stddev_samp`, `variance`, `var_pop`, `var_samp`, `skewness`, `kurtosis`, `collect_list` and `collect_set` should accept a `columnName` string as their argument, like the other aggregate functions (max/min/count/sum). Author: Yanbo Liang <ybliang8@gmail.com> Closes #9994 from yanboliang/SPARK-12011.
This commit is contained in:
parent
0c1e72e7f7
commit
6f6bb0e893
|
@ -214,6 +214,16 @@ object functions extends LegacyFunctions {
|
|||
*/
|
||||
def collect_list(e: Column): Column = {
  // Forward the column to the UDF registered under this name.
  val udfName = "collect_list"
  callUDF(udfName, e)
}
|
||||
|
||||
/**
|
||||
* Aggregate function: returns a list of objects with duplicates.
|
||||
*
|
||||
* For now this is an alias for the collect_list Hive UDAF.
|
||||
*
|
||||
* @group agg_funcs
|
||||
* @since 1.6.0
|
||||
*/
|
||||
def collect_list(columnName: String): Column = {
  // Turn the name into a Column, then reuse the Column-based overload.
  val column = Column(columnName)
  collect_list(column)
}
|
||||
|
||||
/**
|
||||
* Aggregate function: returns a set of objects with duplicate elements eliminated.
|
||||
*
|
||||
|
@ -224,6 +234,16 @@ object functions extends LegacyFunctions {
|
|||
*/
|
||||
def collect_set(e: Column): Column = {
  // Forward the column to the UDF registered under this name.
  val udfName = "collect_set"
  callUDF(udfName, e)
}
|
||||
|
||||
/**
|
||||
* Aggregate function: returns a set of objects with duplicate elements eliminated.
|
||||
*
|
||||
* For now this is an alias for the collect_set Hive UDAF.
|
||||
*
|
||||
* @group agg_funcs
|
||||
* @since 1.6.0
|
||||
*/
|
||||
def collect_set(columnName: String): Column = {
  // Turn the name into a Column, then reuse the Column-based overload.
  val column = Column(columnName)
  collect_set(column)
}
|
||||
|
||||
/**
|
||||
* Aggregate function: returns the Pearson Correlation Coefficient for two columns.
|
||||
*
|
||||
|
@ -312,6 +332,14 @@ object functions extends LegacyFunctions {
|
|||
*/
|
||||
def kurtosis(e: Column): Column = {
  // Build a Kurtosis expression over the column's underlying expression
  // and hand it to withAggregateFunction to obtain the resulting Column.
  withAggregateFunction(Kurtosis(e.expr))
}
|
||||
|
||||
/**
|
||||
* Aggregate function: returns the kurtosis of the values in a group.
|
||||
*
|
||||
* @group agg_funcs
|
||||
* @since 1.6.0
|
||||
*/
|
||||
def kurtosis(columnName: String): Column = {
  // Turn the name into a Column, then reuse the Column-based overload.
  val column = Column(columnName)
  kurtosis(column)
}
|
||||
|
||||
/**
|
||||
* Aggregate function: returns the last value in a group.
|
||||
*
|
||||
|
@ -386,6 +414,14 @@ object functions extends LegacyFunctions {
|
|||
*/
|
||||
def skewness(e: Column): Column = {
  // Build a Skewness expression over the column's underlying expression
  // and hand it to withAggregateFunction to obtain the resulting Column.
  withAggregateFunction(Skewness(e.expr))
}
|
||||
|
||||
/**
|
||||
* Aggregate function: returns the skewness of the values in a group.
|
||||
*
|
||||
* @group agg_funcs
|
||||
* @since 1.6.0
|
||||
*/
|
||||
def skewness(columnName: String): Column = {
  // Turn the name into a Column, then reuse the Column-based overload.
  val column = Column(columnName)
  skewness(column)
}
|
||||
|
||||
/**
|
||||
* Aggregate function: alias for [[stddev_samp]].
|
||||
*
|
||||
|
@ -394,6 +430,14 @@ object functions extends LegacyFunctions {
|
|||
*/
|
||||
def stddev(e: Column): Column = {
  // Alias of stddev_samp: it wraps the same StddevSamp expression.
  withAggregateFunction(StddevSamp(e.expr))
}
|
||||
|
||||
/**
|
||||
* Aggregate function: alias for [[stddev_samp]].
|
||||
*
|
||||
* @group agg_funcs
|
||||
* @since 1.6.0
|
||||
*/
|
||||
def stddev(columnName: String): Column = {
  // Turn the name into a Column, then reuse the Column-based overload.
  val column = Column(columnName)
  stddev(column)
}
|
||||
|
||||
/**
|
||||
* Aggregate function: returns the sample standard deviation of
|
||||
* the expression in a group.
|
||||
|
@ -403,6 +447,15 @@ object functions extends LegacyFunctions {
|
|||
*/
|
||||
def stddev_samp(e: Column): Column = {
  // Build a StddevSamp expression over the column's underlying expression
  // and hand it to withAggregateFunction to obtain the resulting Column.
  withAggregateFunction(StddevSamp(e.expr))
}
|
||||
|
||||
/**
|
||||
* Aggregate function: returns the sample standard deviation of
|
||||
* the expression in a group.
|
||||
*
|
||||
* @group agg_funcs
|
||||
* @since 1.6.0
|
||||
*/
|
||||
def stddev_samp(columnName: String): Column = {
  // Turn the name into a Column, then reuse the Column-based overload.
  val column = Column(columnName)
  stddev_samp(column)
}
|
||||
|
||||
/**
|
||||
* Aggregate function: returns the population standard deviation of
|
||||
* the expression in a group.
|
||||
|
@ -412,6 +465,15 @@ object functions extends LegacyFunctions {
|
|||
*/
|
||||
def stddev_pop(e: Column): Column = {
  // Build a StddevPop expression over the column's underlying expression
  // and hand it to withAggregateFunction to obtain the resulting Column.
  withAggregateFunction(StddevPop(e.expr))
}
|
||||
|
||||
/**
|
||||
* Aggregate function: returns the population standard deviation of
|
||||
* the expression in a group.
|
||||
*
|
||||
* @group agg_funcs
|
||||
* @since 1.6.0
|
||||
*/
|
||||
def stddev_pop(columnName: String): Column = {
  // Turn the name into a Column, then reuse the Column-based overload.
  val column = Column(columnName)
  stddev_pop(column)
}
|
||||
|
||||
/**
|
||||
* Aggregate function: returns the sum of all values in the expression.
|
||||
*
|
||||
|
@ -452,6 +514,14 @@ object functions extends LegacyFunctions {
|
|||
*/
|
||||
def variance(e: Column): Column = {
  // Alias of var_samp: it wraps the same VarianceSamp expression.
  withAggregateFunction(VarianceSamp(e.expr))
}
|
||||
|
||||
/**
|
||||
* Aggregate function: alias for [[var_samp]].
|
||||
*
|
||||
* @group agg_funcs
|
||||
* @since 1.6.0
|
||||
*/
|
||||
def variance(columnName: String): Column = {
  // Turn the name into a Column, then reuse the Column-based overload.
  val column = Column(columnName)
  variance(column)
}
|
||||
|
||||
/**
|
||||
* Aggregate function: returns the unbiased variance of the values in a group.
|
||||
*
|
||||
|
@ -460,6 +530,14 @@ object functions extends LegacyFunctions {
|
|||
*/
|
||||
def var_samp(e: Column): Column = {
  // Build a VarianceSamp expression over the column's underlying expression
  // and hand it to withAggregateFunction to obtain the resulting Column.
  withAggregateFunction(VarianceSamp(e.expr))
}
|
||||
|
||||
/**
|
||||
* Aggregate function: returns the unbiased variance of the values in a group.
|
||||
*
|
||||
* @group agg_funcs
|
||||
* @since 1.6.0
|
||||
*/
|
||||
def var_samp(columnName: String): Column = {
  // Turn the name into a Column, then reuse the Column-based overload.
  val column = Column(columnName)
  var_samp(column)
}
|
||||
|
||||
/**
|
||||
* Aggregate function: returns the population variance of the values in a group.
|
||||
*
|
||||
|
@ -468,6 +546,14 @@ object functions extends LegacyFunctions {
|
|||
*/
|
||||
def var_pop(e: Column): Column = {
  // Build a VariancePop expression over the column's underlying expression
  // and hand it to withAggregateFunction to obtain the resulting Column.
  withAggregateFunction(VariancePop(e.expr))
}
|
||||
|
||||
/**
|
||||
* Aggregate function: returns the population variance of the values in a group.
|
||||
*
|
||||
* @group agg_funcs
|
||||
* @since 1.6.0
|
||||
*/
|
||||
def var_pop(columnName: String): Column = {
  // Turn the name into a Column, then reuse the Column-based overload.
  val column = Column(columnName)
  var_pop(column)
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Window functions
|
||||
//////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -261,6 +261,9 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
|
|||
checkAnswer(
|
||||
testData2.agg(stddev('a), stddev_pop('a), stddev_samp('a)),
|
||||
Row(testData2ADev, math.sqrt(4 / 6.0), testData2ADev))
|
||||
checkAnswer(
|
||||
testData2.agg(stddev("a"), stddev_pop("a"), stddev_samp("a")),
|
||||
Row(testData2ADev, math.sqrt(4 / 6.0), testData2ADev))
|
||||
}
|
||||
|
||||
test("zero stddev") {
|
||||
|
|
Loading…
Reference in a new issue