[SPARK-11490][SQL] variance should alias var_samp instead of var_pop.
stddev is an alias for stddev_samp, so variance should be consistent with it and alias var_samp rather than var_pop. This change also removes the internal Stddev and Variance expressions, keeping only StddevSamp/StddevPop and VarianceSamp/VariancePop. Author: Reynold Xin <rxin@databricks.com>. Closes #9449 from rxin/SPARK-11490.
This commit is contained in:
parent
e0fc9c7e59
commit
3bd6f5d2ae
|
@ -187,11 +187,11 @@ object FunctionRegistry {
|
|||
expression[Max]("max"),
|
||||
expression[Average]("mean"),
|
||||
expression[Min]("min"),
|
||||
expression[Stddev]("stddev"),
|
||||
expression[StddevSamp]("stddev"),
|
||||
expression[StddevPop]("stddev_pop"),
|
||||
expression[StddevSamp]("stddev_samp"),
|
||||
expression[Sum]("sum"),
|
||||
expression[Variance]("variance"),
|
||||
expression[VarianceSamp]("variance"),
|
||||
expression[VariancePop]("var_pop"),
|
||||
expression[VarianceSamp]("var_samp"),
|
||||
expression[Skewness]("skewness"),
|
||||
|
|
|
@ -297,10 +297,8 @@ object HiveTypeCoercion {
|
|||
case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
|
||||
case SumDistinct(e @ StringType()) => Sum(Cast(e, DoubleType))
|
||||
case Average(e @ StringType()) => Average(Cast(e, DoubleType))
|
||||
case Stddev(e @ StringType()) => Stddev(Cast(e, DoubleType))
|
||||
case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType))
|
||||
case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType))
|
||||
case Variance(e @ StringType()) => Variance(Cast(e, DoubleType))
|
||||
case VariancePop(e @ StringType()) => VariancePop(Cast(e, DoubleType))
|
||||
case VarianceSamp(e @ StringType()) => VarianceSamp(Cast(e, DoubleType))
|
||||
case Skewness(e @ StringType()) => Skewness(Cast(e, DoubleType))
|
||||
|
|
|
@ -159,14 +159,6 @@ package object dsl {
|
|||
def lower(e: Expression): Expression = Lower(e)
|
||||
def sqrt(e: Expression): Expression = Sqrt(e)
|
||||
def abs(e: Expression): Expression = Abs(e)
|
||||
def stddev(e: Expression): Expression = Stddev(e)
|
||||
def stddev_pop(e: Expression): Expression = StddevPop(e)
|
||||
def stddev_samp(e: Expression): Expression = StddevSamp(e)
|
||||
def variance(e: Expression): Expression = Variance(e)
|
||||
def var_pop(e: Expression): Expression = VariancePop(e)
|
||||
def var_samp(e: Expression): Expression = VarianceSamp(e)
|
||||
def skewness(e: Expression): Expression = Skewness(e)
|
||||
def kurtosis(e: Expression): Expression = Kurtosis(e)
|
||||
|
||||
implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name }
|
||||
// TODO more implicit class for literal?
|
||||
|
|
|
@ -328,13 +328,6 @@ case class Min(child: Expression) extends DeclarativeAggregate {
|
|||
override val evaluateExpression = min
|
||||
}
|
||||
|
||||
// Compute the sample standard deviation of a column
|
||||
case class Stddev(child: Expression) extends StddevAgg(child) {
|
||||
|
||||
override def isSample: Boolean = true
|
||||
override def prettyName: String = "stddev"
|
||||
}
|
||||
|
||||
// Compute the population standard deviation of a column
|
||||
case class StddevPop(child: Expression) extends StddevAgg(child) {
|
||||
|
||||
|
@ -1274,28 +1267,6 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w
|
|||
}
|
||||
}
|
||||
|
||||
case class Variance(child: Expression,
|
||||
mutableAggBufferOffset: Int = 0,
|
||||
inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) {
|
||||
|
||||
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
|
||||
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
|
||||
|
||||
override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
|
||||
copy(inputAggBufferOffset = newInputAggBufferOffset)
|
||||
|
||||
override def prettyName: String = "variance"
|
||||
|
||||
override protected val momentOrder = 2
|
||||
|
||||
override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
|
||||
require(moments.length == momentOrder + 1,
|
||||
s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")
|
||||
|
||||
if (n == 0.0) Double.NaN else moments(2) / n
|
||||
}
|
||||
}
|
||||
|
||||
case class VarianceSamp(child: Expression,
|
||||
mutableAggBufferOffset: Int = 0,
|
||||
inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) {
|
||||
|
|
|
@ -97,12 +97,6 @@ object Utils {
|
|||
mode = aggregate.Complete,
|
||||
isDistinct = false)
|
||||
|
||||
case expressions.Stddev(child) =>
|
||||
aggregate.AggregateExpression2(
|
||||
aggregateFunction = aggregate.Stddev(child),
|
||||
mode = aggregate.Complete,
|
||||
isDistinct = false)
|
||||
|
||||
case expressions.StddevPop(child) =>
|
||||
aggregate.AggregateExpression2(
|
||||
aggregateFunction = aggregate.StddevPop(child),
|
||||
|
@ -139,12 +133,6 @@ object Utils {
|
|||
mode = aggregate.Complete,
|
||||
isDistinct = false)
|
||||
|
||||
case expressions.Variance(child) =>
|
||||
aggregate.AggregateExpression2(
|
||||
aggregateFunction = aggregate.Variance(child),
|
||||
mode = aggregate.Complete,
|
||||
isDistinct = false)
|
||||
|
||||
case expressions.VariancePop(child) =>
|
||||
aggregate.AggregateExpression2(
|
||||
aggregateFunction = aggregate.VariancePop(child),
|
||||
|
|
|
@ -785,13 +785,6 @@ abstract class StddevAgg1(child: Expression) extends UnaryExpression with Partia
|
|||
|
||||
}
|
||||
|
||||
// Compute the sample standard deviation of a column
|
||||
case class Stddev(child: Expression) extends StddevAgg1(child) {
|
||||
|
||||
override def toString: String = s"STDDEV($child)"
|
||||
override def isSample: Boolean = true
|
||||
}
|
||||
|
||||
// Compute the population standard deviation of a column
|
||||
case class StddevPop(child: Expression) extends StddevAgg1(child) {
|
||||
|
||||
|
@ -807,20 +800,21 @@ case class StddevSamp(child: Expression) extends StddevAgg1(child) {
|
|||
}
|
||||
|
||||
case class ComputePartialStd(child: Expression) extends UnaryExpression with AggregateExpression1 {
|
||||
def this() = this(null)
|
||||
def this() = this(null)
|
||||
|
||||
override def children: Seq[Expression] = child :: Nil
|
||||
override def nullable: Boolean = false
|
||||
override def dataType: DataType = ArrayType(DoubleType)
|
||||
override def toString: String = s"computePartialStddev($child)"
|
||||
override def newInstance(): ComputePartialStdFunction =
|
||||
new ComputePartialStdFunction(child, this)
|
||||
override def children: Seq[Expression] = child :: Nil
|
||||
override def nullable: Boolean = false
|
||||
override def dataType: DataType = ArrayType(DoubleType)
|
||||
override def toString: String = s"computePartialStddev($child)"
|
||||
override def newInstance(): ComputePartialStdFunction =
|
||||
new ComputePartialStdFunction(child, this)
|
||||
}
|
||||
|
||||
case class ComputePartialStdFunction (
|
||||
expr: Expression,
|
||||
base: AggregateExpression1
|
||||
) extends AggregateFunction1 {
|
||||
) extends AggregateFunction1 {
|
||||
|
||||
def this() = this(null, null) // Required for serialization
|
||||
|
||||
private val computeType = DoubleType
|
||||
|
@ -1048,25 +1042,6 @@ case class Skewness(child: Expression) extends UnaryExpression with AggregateExp
|
|||
override def toString: String = s"SKEWNESS($child)"
|
||||
}
|
||||
|
||||
// placeholder
|
||||
case class Variance(child: Expression) extends UnaryExpression with AggregateExpression1 {
|
||||
|
||||
override def newInstance(): AggregateFunction1 = {
|
||||
throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " +
|
||||
"please set spark.sql.useAggregate2 = true")
|
||||
}
|
||||
|
||||
override def nullable: Boolean = false
|
||||
|
||||
override def dataType: DoubleType.type = DoubleType
|
||||
|
||||
override def foldable: Boolean = false
|
||||
|
||||
override def prettyName: String = "variance"
|
||||
|
||||
override def toString: String = s"VARIANCE($child)"
|
||||
}
|
||||
|
||||
// placeholder
|
||||
case class VariancePop(child: Expression) extends UnaryExpression with AggregateExpression1 {
|
||||
|
||||
|
|
|
@ -1383,7 +1383,7 @@ class DataFrame private[sql](
|
|||
val statistics = List[(String, Expression => Expression)](
|
||||
"count" -> Count,
|
||||
"mean" -> Average,
|
||||
"stddev" -> Stddev,
|
||||
"stddev" -> StddevSamp,
|
||||
"min" -> Min,
|
||||
"max" -> Max)
|
||||
|
||||
|
|
|
@ -96,10 +96,10 @@ class GroupedData protected[sql](
|
|||
case "avg" | "average" | "mean" => Average
|
||||
case "max" => Max
|
||||
case "min" => Min
|
||||
case "stddev" | "std" => Stddev
|
||||
case "stddev" | "std" => StddevSamp
|
||||
case "stddev_pop" => StddevPop
|
||||
case "stddev_samp" => StddevSamp
|
||||
case "variance" => Variance
|
||||
case "variance" => VarianceSamp
|
||||
case "var_pop" => VariancePop
|
||||
case "var_samp" => VarianceSamp
|
||||
case "sum" => Sum
|
||||
|
|
|
@ -329,13 +329,12 @@ object functions {
|
|||
def skewness(e: Column): Column = Skewness(e.expr)
|
||||
|
||||
/**
|
||||
* Aggregate function: returns the unbiased sample standard deviation of
|
||||
* the expression in a group.
|
||||
* Aggregate function: alias for [[stddev_samp]].
|
||||
*
|
||||
* @group agg_funcs
|
||||
* @since 1.6.0
|
||||
*/
|
||||
def stddev(e: Column): Column = Stddev(e.expr)
|
||||
def stddev(e: Column): Column = StddevSamp(e.expr)
|
||||
|
||||
/**
|
||||
* Aggregate function: returns the unbiased sample standard deviation of
|
||||
|
@ -388,12 +387,12 @@ object functions {
|
|||
def sumDistinct(columnName: String): Column = sumDistinct(Column(columnName))
|
||||
|
||||
/**
|
||||
* Aggregate function: returns the population variance of the values in a group.
|
||||
* Aggregate function: alias for [[var_samp]].
|
||||
*
|
||||
* @group agg_funcs
|
||||
* @since 1.6.0
|
||||
*/
|
||||
def variance(e: Column): Column = Variance(e.expr)
|
||||
def variance(e: Column): Column = VarianceSamp(e.expr)
|
||||
|
||||
/**
|
||||
* Aggregate function: returns the unbiased variance of the values in a group.
|
||||
|
|
|
@ -226,23 +226,18 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
|
|||
val absTol = 1e-8
|
||||
|
||||
val sparkVariance = testData2.agg(variance('a))
|
||||
val expectedVariance = Row(4.0 / 6.0)
|
||||
checkAggregatesWithTol(sparkVariance, expectedVariance, absTol)
|
||||
checkAggregatesWithTol(sparkVariance, Row(4.0 / 5.0), absTol)
|
||||
val sparkVariancePop = testData2.agg(var_pop('a))
|
||||
checkAggregatesWithTol(sparkVariancePop, expectedVariance, absTol)
|
||||
checkAggregatesWithTol(sparkVariancePop, Row(4.0 / 6.0), absTol)
|
||||
|
||||
val sparkVarianceSamp = testData2.agg(var_samp('a))
|
||||
val expectedVarianceSamp = Row(4.0 / 5.0)
|
||||
checkAggregatesWithTol(sparkVarianceSamp, expectedVarianceSamp, absTol)
|
||||
checkAggregatesWithTol(sparkVarianceSamp, Row(4.0 / 5.0), absTol)
|
||||
|
||||
val sparkSkewness = testData2.agg(skewness('a))
|
||||
val expectedSkewness = Row(0.0)
|
||||
checkAggregatesWithTol(sparkSkewness, expectedSkewness, absTol)
|
||||
checkAggregatesWithTol(sparkSkewness, Row(0.0), absTol)
|
||||
|
||||
val sparkKurtosis = testData2.agg(kurtosis('a))
|
||||
val expectedKurtosis = Row(-1.5)
|
||||
checkAggregatesWithTol(sparkKurtosis, expectedKurtosis, absTol)
|
||||
|
||||
checkAggregatesWithTol(sparkKurtosis, Row(-1.5), absTol)
|
||||
}
|
||||
|
||||
test("zero moments") {
|
||||
|
@ -251,7 +246,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
|
|||
|
||||
checkAnswer(
|
||||
emptyTableData.agg(variance('a)),
|
||||
Row(0.0))
|
||||
Row(Double.NaN))
|
||||
|
||||
checkAnswer(
|
||||
emptyTableData.agg(var_samp('a)),
|
||||
|
|
|
@ -536,7 +536,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
|
|||
checkAnswer(
|
||||
sql("SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a)," +
|
||||
"AVG(a), VARIANCE(a), STDDEV(a), SUM(a), COUNT(a) FROM nullInts"),
|
||||
Row(0, -1.5, 1, 3, 2, 2.0 / 3.0, 1, 6, 3)
|
||||
Row(0, -1.5, 1, 3, 2, 1.0, 1, 6, 3)
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -757,7 +757,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
|
|||
test("variance") {
|
||||
val absTol = 1e-8
|
||||
val sparkAnswer = sql("SELECT VARIANCE(a) FROM testData2")
|
||||
val expectedAnswer = Row(4.0 / 6.0)
|
||||
val expectedAnswer = Row(0.8)
|
||||
checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol)
|
||||
}
|
||||
|
||||
|
@ -784,16 +784,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
|
|||
|
||||
test("stddev agg") {
|
||||
checkAnswer(
|
||||
sql("SELECT a, stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2 GROUP BY a"),
|
||||
sql("SELECT a, stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2 GROUP BY a"),
|
||||
(1 to 3).map(i => Row(i, math.sqrt(1.0 / 2.0), math.sqrt(1.0 / 4.0), math.sqrt(1.0 / 2.0))))
|
||||
}
|
||||
|
||||
test("variance agg") {
|
||||
val absTol = 1e-8
|
||||
val sparkAnswer = sql("SELECT a, variance(b), var_samp(b), var_pop(b)" +
|
||||
"FROM testData2 GROUP BY a")
|
||||
val expectedAnswer = (1 to 3).map(i => Row(i, 1.0 / 4.0, 1.0 / 2.0, 1.0 / 4.0))
|
||||
checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol)
|
||||
checkAggregatesWithTol(
|
||||
sql("SELECT a, variance(b), var_samp(b), var_pop(b) FROM testData2 GROUP BY a"),
|
||||
(1 to 3).map(i => Row(i, 1.0 / 2.0, 1.0 / 2.0, 1.0 / 4.0)),
|
||||
absTol)
|
||||
}
|
||||
|
||||
test("skewness and kurtosis agg") {
|
||||
|
|
Loading…
Reference in a new issue