[SPARK-11490][SQL] variance should alias var_samp instead of var_pop.

stddev is an alias for stddev_samp, so variance should be consistent with stddev and alias var_samp rather than var_pop.

Also took the opportunity to remove the internal Stddev and Variance expressions, keeping only StddevSamp/StddevPop and VarianceSamp/VariancePop.

Author: Reynold Xin <rxin@databricks.com>

Closes #9449 from rxin/SPARK-11490.
Authored by Reynold Xin on 2015-11-04 09:34:52 -08:00; committed by Yin Huai
parent e0fc9c7e59
commit 3bd6f5d2ae
11 changed files with 31 additions and 113 deletions
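
For reference, a minimal standalone sketch (plain Scala, not Spark code) of the two estimators whose aliases change here. The input values 1, 1, 2, 2, 3, 3 are an assumption chosen to match the 4.0 / 6.0 and 4.0 / 5.0 expectations in the updated tests below.

    // Sketch only: population vs. sample variance, the distinction behind this change.
    object VarianceSketch {
      // var_pop semantics: divide the sum of squared deviations by n.
      def varPop(xs: Seq[Double]): Double = {
        val mean = xs.sum / xs.size
        xs.map(x => (x - mean) * (x - mean)).sum / xs.size
      }

      // var_samp semantics: divide by n - 1 (Bessel's correction).
      def varSamp(xs: Seq[Double]): Double = {
        val mean = xs.sum / xs.size
        xs.map(x => (x - mean) * (x - mean)).sum / (xs.size - 1)
      }

      def main(args: Array[String]): Unit = {
        val xs = Seq(1.0, 1.0, 2.0, 2.0, 3.0, 3.0)  // assumed to mirror testData2's column a
        println(varPop(xs))   // 0.666... = 4.0 / 6.0, what variance() used to compute
        println(varSamp(xs))  // 0.8      = 4.0 / 5.0, what variance() computes now
      }
    }

If emptyTableData in the suite below is a single-row table (as its old expectation of 0.0 suggests), the sample estimator divides by n - 1 = 0 and returns NaN, which is why that expectation becomes Double.NaN.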


@@ -187,11 +187,11 @@ object FunctionRegistry {
  expression[Max]("max"),
  expression[Average]("mean"),
  expression[Min]("min"),
- expression[Stddev]("stddev"),
+ expression[StddevSamp]("stddev"),
  expression[StddevPop]("stddev_pop"),
  expression[StddevSamp]("stddev_samp"),
  expression[Sum]("sum"),
- expression[Variance]("variance"),
+ expression[VarianceSamp]("variance"),
  expression[VariancePop]("var_pop"),
  expression[VarianceSamp]("var_samp"),
  expression[Skewness]("skewness"),


@@ -297,10 +297,8 @@ object HiveTypeCoercion {
  case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
  case SumDistinct(e @ StringType()) => Sum(Cast(e, DoubleType))
  case Average(e @ StringType()) => Average(Cast(e, DoubleType))
- case Stddev(e @ StringType()) => Stddev(Cast(e, DoubleType))
  case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType))
  case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType))
- case Variance(e @ StringType()) => Variance(Cast(e, DoubleType))
  case VariancePop(e @ StringType()) => VariancePop(Cast(e, DoubleType))
  case VarianceSamp(e @ StringType()) => VarianceSamp(Cast(e, DoubleType))
  case Skewness(e @ StringType()) => Skewness(Cast(e, DoubleType))


@@ -159,14 +159,6 @@ package object dsl {
  def lower(e: Expression): Expression = Lower(e)
  def sqrt(e: Expression): Expression = Sqrt(e)
  def abs(e: Expression): Expression = Abs(e)
- def stddev(e: Expression): Expression = Stddev(e)
- def stddev_pop(e: Expression): Expression = StddevPop(e)
- def stddev_samp(e: Expression): Expression = StddevSamp(e)
- def variance(e: Expression): Expression = Variance(e)
- def var_pop(e: Expression): Expression = VariancePop(e)
- def var_samp(e: Expression): Expression = VarianceSamp(e)
- def skewness(e: Expression): Expression = Skewness(e)
- def kurtosis(e: Expression): Expression = Kurtosis(e)
  implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name }
  // TODO more implicit class for literal?


@@ -328,13 +328,6 @@ case class Min(child: Expression) extends DeclarativeAggregate {
  override val evaluateExpression = min
  }
- // Compute the sample standard deviation of a column
- case class Stddev(child: Expression) extends StddevAgg(child) {
- override def isSample: Boolean = true
- override def prettyName: String = "stddev"
- }
  // Compute the population standard deviation of a column
  case class StddevPop(child: Expression) extends StddevAgg(child) {
@@ -1274,28 +1267,6 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w
  }
  }
- case class Variance(child: Expression,
- mutableAggBufferOffset: Int = 0,
- inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) {
- override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
- copy(mutableAggBufferOffset = newMutableAggBufferOffset)
- override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
- copy(inputAggBufferOffset = newInputAggBufferOffset)
- override def prettyName: String = "variance"
- override protected val momentOrder = 2
- override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
- require(moments.length == momentOrder + 1,
- s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")
- if (n == 0.0) Double.NaN else moments(2) / n
- }
- }
  case class VarianceSamp(child: Expression,
  mutableAggBufferOffset: Int = 0,
  inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) {


@@ -97,12 +97,6 @@ object Utils {
  mode = aggregate.Complete,
  isDistinct = false)
- case expressions.Stddev(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Stddev(child),
- mode = aggregate.Complete,
- isDistinct = false)
  case expressions.StddevPop(child) =>
  aggregate.AggregateExpression2(
  aggregateFunction = aggregate.StddevPop(child),
@@ -139,12 +133,6 @@ object Utils {
  mode = aggregate.Complete,
  isDistinct = false)
- case expressions.Variance(child) =>
- aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Variance(child),
- mode = aggregate.Complete,
- isDistinct = false)
  case expressions.VariancePop(child) =>
  aggregate.AggregateExpression2(
  aggregateFunction = aggregate.VariancePop(child),


@@ -785,13 +785,6 @@ abstract class StddevAgg1(child: Expression) extends UnaryExpression with Partia
  }
- // Compute the sample standard deviation of a column
- case class Stddev(child: Expression) extends StddevAgg1(child) {
- override def toString: String = s"STDDEV($child)"
- override def isSample: Boolean = true
- }
  // Compute the population standard deviation of a column
  case class StddevPop(child: Expression) extends StddevAgg1(child) {
@@ -807,20 +800,21 @@ case class StddevSamp(child: Expression) extends StddevAgg1(child) {
  }
  case class ComputePartialStd(child: Expression) extends UnaryExpression with AggregateExpression1 {
- def this() = this(null)
+ def this() = this(null)

- override def children: Seq[Expression] = child :: Nil
- override def nullable: Boolean = false
- override def dataType: DataType = ArrayType(DoubleType)
- override def toString: String = s"computePartialStddev($child)"
- override def newInstance(): ComputePartialStdFunction =
- new ComputePartialStdFunction(child, this)
+ override def children: Seq[Expression] = child :: Nil
+ override def nullable: Boolean = false
+ override def dataType: DataType = ArrayType(DoubleType)
+ override def toString: String = s"computePartialStddev($child)"
+ override def newInstance(): ComputePartialStdFunction =
+ new ComputePartialStdFunction(child, this)
  }
  case class ComputePartialStdFunction (
  expr: Expression,
  base: AggregateExpression1
- ) extends AggregateFunction1 {
+ ) extends AggregateFunction1 {
  def this() = this(null, null) // Required for serialization
  private val computeType = DoubleType
@@ -1048,25 +1042,6 @@ case class Skewness(child: Expression) extends UnaryExpression with AggregateExpression1 {
  override def toString: String = s"SKEWNESS($child)"
  }
- // placeholder
- case class Variance(child: Expression) extends UnaryExpression with AggregateExpression1 {
- override def newInstance(): AggregateFunction1 = {
- throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " +
- "please set spark.sql.useAggregate2 = true")
- }
- override def nullable: Boolean = false
- override def dataType: DoubleType.type = DoubleType
- override def foldable: Boolean = false
- override def prettyName: String = "variance"
- override def toString: String = s"VARIANCE($child)"
- }
  // placeholder
  case class VariancePop(child: Expression) extends UnaryExpression with AggregateExpression1 {


@@ -1383,7 +1383,7 @@ class DataFrame private[sql](
  val statistics = List[(String, Expression => Expression)](
  "count" -> Count,
  "mean" -> Average,
- "stddev" -> Stddev,
+ "stddev" -> StddevSamp,
  "min" -> Min,
  "max" -> Max)


@@ -96,10 +96,10 @@ class GroupedData protected[sql](
  case "avg" | "average" | "mean" => Average
  case "max" => Max
  case "min" => Min
- case "stddev" | "std" => Stddev
+ case "stddev" | "std" => StddevSamp
  case "stddev_pop" => StddevPop
  case "stddev_samp" => StddevSamp
- case "variance" => Variance
+ case "variance" => VarianceSamp
  case "var_pop" => VariancePop
  case "var_samp" => VarianceSamp
  case "sum" => Sum


@@ -329,13 +329,12 @@ object functions {
  def skewness(e: Column): Column = Skewness(e.expr)
  /**
- * Aggregate function: returns the unbiased sample standard deviation of
- * the expression in a group.
+ * Aggregate function: alias for [[stddev_samp]].
  *
  * @group agg_funcs
  * @since 1.6.0
  */
- def stddev(e: Column): Column = Stddev(e.expr)
+ def stddev(e: Column): Column = StddevSamp(e.expr)
  /**
  * Aggregate function: returns the unbiased sample standard deviation of
@@ -388,12 +387,12 @@ object functions {
  def sumDistinct(columnName: String): Column = sumDistinct(Column(columnName))
  /**
- * Aggregate function: returns the population variance of the values in a group.
+ * Aggregate function: alias for [[var_samp]].
  *
  * @group agg_funcs
  * @since 1.6.0
  */
- def variance(e: Column): Column = Variance(e.expr)
+ def variance(e: Column): Column = VarianceSamp(e.expr)
  /**
  * Aggregate function: returns the unbiased variance of the values in a group.
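
A small usage sketch of the re-aliased DataFrame functions, assuming a local 1.6-era SparkContext/SQLContext; the object name, app name, and toy data are illustrative, not part of this change.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.functions.{stddev, stddev_samp, var_pop, var_samp, variance}

    object VarianceAliasDemo {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("variance-alias").setMaster("local[2]"))
        val sqlContext = new SQLContext(sc)
        import sqlContext.implicits._

        // Toy single-column DataFrame with values 1, 1, 2, 2, 3, 3.
        val df = Seq(1, 1, 2, 2, 3, 3).map(Tuple1.apply).toDF("a")

        // variance and stddev now resolve to the sample estimators.
        df.agg(variance($"a"), var_samp($"a"), var_pop($"a")).show()
        // variance(a) == var_samp(a) == 0.8, var_pop(a) == 0.666...

        df.agg(stddev($"a"), stddev_samp($"a")).show()
        // stddev(a) == stddev_samp(a) == sqrt(0.8)

        sc.stop()
      }
    }

Callers that relied on the old population semantics of variance or stddev should switch to var_pop or stddev_pop explicitly.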


@@ -226,23 +226,18 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
  val absTol = 1e-8
  val sparkVariance = testData2.agg(variance('a))
- val expectedVariance = Row(4.0 / 6.0)
- checkAggregatesWithTol(sparkVariance, expectedVariance, absTol)
+ checkAggregatesWithTol(sparkVariance, Row(4.0 / 5.0), absTol)
  val sparkVariancePop = testData2.agg(var_pop('a))
- checkAggregatesWithTol(sparkVariancePop, expectedVariance, absTol)
+ checkAggregatesWithTol(sparkVariancePop, Row(4.0 / 6.0), absTol)
  val sparkVarianceSamp = testData2.agg(var_samp('a))
- val expectedVarianceSamp = Row(4.0 / 5.0)
- checkAggregatesWithTol(sparkVarianceSamp, expectedVarianceSamp, absTol)
+ checkAggregatesWithTol(sparkVarianceSamp, Row(4.0 / 5.0), absTol)
  val sparkSkewness = testData2.agg(skewness('a))
- val expectedSkewness = Row(0.0)
- checkAggregatesWithTol(sparkSkewness, expectedSkewness, absTol)
+ checkAggregatesWithTol(sparkSkewness, Row(0.0), absTol)
  val sparkKurtosis = testData2.agg(kurtosis('a))
- val expectedKurtosis = Row(-1.5)
- checkAggregatesWithTol(sparkKurtosis, expectedKurtosis, absTol)
+ checkAggregatesWithTol(sparkKurtosis, Row(-1.5), absTol)
  }
test("zero moments") {
@@ -251,7 +246,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
  checkAnswer(
  emptyTableData.agg(variance('a)),
- Row(0.0))
+ Row(Double.NaN))
  checkAnswer(
  emptyTableData.agg(var_samp('a)),


@@ -536,7 +536,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
  checkAnswer(
  sql("SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a)," +
  "AVG(a), VARIANCE(a), STDDEV(a), SUM(a), COUNT(a) FROM nullInts"),
- Row(0, -1.5, 1, 3, 2, 2.0 / 3.0, 1, 6, 3)
+ Row(0, -1.5, 1, 3, 2, 1.0, 1, 6, 3)
  )
  }
@@ -757,7 +757,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
  test("variance") {
  val absTol = 1e-8
  val sparkAnswer = sql("SELECT VARIANCE(a) FROM testData2")
- val expectedAnswer = Row(4.0 / 6.0)
+ val expectedAnswer = Row(0.8)
  checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol)
  }
@@ -784,16 +784,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
  test("stddev agg") {
  checkAnswer(
- sql("SELECT a, stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2 GROUP BY a"),
+ sql("SELECT a, stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2 GROUP BY a"),
  (1 to 3).map(i => Row(i, math.sqrt(1.0 / 2.0), math.sqrt(1.0 / 4.0), math.sqrt(1.0 / 2.0))))
  }
  test("variance agg") {
  val absTol = 1e-8
- val sparkAnswer = sql("SELECT a, variance(b), var_samp(b), var_pop(b)" +
- "FROM testData2 GROUP BY a")
- val expectedAnswer = (1 to 3).map(i => Row(i, 1.0 / 4.0, 1.0 / 2.0, 1.0 / 4.0))
- checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol)
+ checkAggregatesWithTol(
+ sql("SELECT a, variance(b), var_samp(b), var_pop(b) FROM testData2 GROUP BY a"),
+ (1 to 3).map(i => Row(i, 1.0 / 2.0, 1.0 / 2.0, 1.0 / 4.0)),
+ absTol)
  }
  test("skewness and kurtosis agg") {