[SPARK-29959][ML][PYSPARK] Summarizer support more metrics
### What changes were proposed in this pull request? Summarizer support more metrics: sum, std ### Why are the changes needed? Those metrics are widely used, it will be convenient to directly obtain them other than a conversion. in `NaiveBayes`: we want the sum of vectors, mean & weightSum need to computed then multiplied in `StandardScaler`,`AFTSurvivalRegression`,`LinearRegression`,`LinearSVC`,`LogisticRegression`: we need to obtain `variance` and then sqrt it to get std ### Does this PR introduce any user-facing change? yes, new metrics are exposed to end users ### How was this patch tested? added testsuites Closes #26596 from zhengruifeng/summarizer_add_metrics. Authored-by: zhengruifeng <ruifengz@foxmail.com> Signed-off-by: zhengruifeng <ruifengz@foxmail.com>
This commit is contained in:
parent
85cb388ae3
commit
03ac1b799c
|
@ -109,7 +109,8 @@ Refer to the [`ChiSquareTest` Python docs](api/python/index.html#pyspark.ml.stat
|
|||
## Summarizer
|
||||
|
||||
We provide vector column summary statistics for `Dataframe` through `Summarizer`.
|
||||
Available metrics are the column-wise max, min, mean, variance, and number of nonzeros, as well as the total count.
|
||||
Available metrics are the column-wise max, min, mean, sum, variance, std, and number of nonzeros,
|
||||
as well as the total count.
|
||||
|
||||
<div class="codetabs">
|
||||
<div data-lang="scala" markdown="1">
|
||||
|
|
|
@ -170,7 +170,7 @@ class LinearSVC @Since("2.2.0") (
|
|||
regParam, maxIter, fitIntercept, tol, standardization, threshold, aggregationDepth)
|
||||
|
||||
val (summarizer, labelSummarizer) = instances.treeAggregate(
|
||||
(createSummarizerBuffer("mean", "variance", "count"), new MultiClassSummarizer))(
|
||||
(createSummarizerBuffer("mean", "std", "count"), new MultiClassSummarizer))(
|
||||
seqOp = (c: (SummarizerBuffer, MultiClassSummarizer), instance: Instance) =>
|
||||
(c._1.add(instance.features, instance.weight), c._2.add(instance.label, instance.weight)),
|
||||
combOp = (c1: (SummarizerBuffer, MultiClassSummarizer),
|
||||
|
@ -207,7 +207,7 @@ class LinearSVC @Since("2.2.0") (
|
|||
throw new SparkException(msg)
|
||||
}
|
||||
|
||||
val featuresStd = summarizer.variance.toArray.map(math.sqrt)
|
||||
val featuresStd = summarizer.std.toArray
|
||||
val getFeaturesStd = (j: Int) => featuresStd(j)
|
||||
val regParamL2 = $(regParam)
|
||||
val bcFeaturesStd = instances.context.broadcast(featuresStd)
|
||||
|
|
|
@ -501,7 +501,7 @@ class LogisticRegression @Since("1.2.0") (
|
|||
fitIntercept)
|
||||
|
||||
val (summarizer, labelSummarizer) = instances.treeAggregate(
|
||||
(createSummarizerBuffer("mean", "variance", "count"), new MultiClassSummarizer))(
|
||||
(createSummarizerBuffer("mean", "std", "count"), new MultiClassSummarizer))(
|
||||
seqOp = (c: (SummarizerBuffer, MultiClassSummarizer), instance: Instance) =>
|
||||
(c._1.add(instance.features, instance.weight), c._2.add(instance.label, instance.weight)),
|
||||
combOp = (c1: (SummarizerBuffer, MultiClassSummarizer),
|
||||
|
@ -582,7 +582,7 @@ class LogisticRegression @Since("1.2.0") (
|
|||
}
|
||||
|
||||
val featuresMean = summarizer.mean.toArray
|
||||
val featuresStd = summarizer.variance.toArray.map(math.sqrt)
|
||||
val featuresStd = summarizer.std.toArray
|
||||
|
||||
if (!$(fitIntercept) && (0 until numFeatures).exists { i =>
|
||||
featuresStd(i) == 0.0 && featuresMean(i) != 0.0 }) {
|
||||
|
|
|
@ -186,16 +186,12 @@ class NaiveBayes @Since("1.5.0") (
|
|||
}
|
||||
|
||||
// Aggregates term frequencies per label.
|
||||
// TODO: Summarizer directly returns sum vector.
|
||||
val aggregated = dataset.groupBy(col($(labelCol)))
|
||||
.agg(sum(w).as("weightSum"), Summarizer.metrics("mean", "count")
|
||||
.agg(sum(w).as("weightSum"), Summarizer.metrics("sum", "count")
|
||||
.summary(validateUDF(col($(featuresCol))), w).as("summary"))
|
||||
.select($(labelCol), "weightSum", "summary.mean", "summary.count")
|
||||
.select($(labelCol), "weightSum", "summary.sum", "summary.count")
|
||||
.as[(Double, Double, Vector, Long)]
|
||||
.map { case (label, weightSum, mean, count) =>
|
||||
BLAS.scal(weightSum, mean)
|
||||
(label, weightSum, mean, count)
|
||||
}.collect().sortBy(_._1)
|
||||
.collect().sortBy(_._1)
|
||||
|
||||
val numFeatures = aggregated.head._3.size
|
||||
instr.logNumFeatures(numFeatures)
|
||||
|
@ -269,7 +265,6 @@ class NaiveBayes @Since("1.5.0") (
|
|||
}
|
||||
|
||||
// Aggregates mean vector and square-sum vector per label.
|
||||
// TODO: Summarizer directly returns square-sum vector.
|
||||
val aggregated = dataset.groupBy(col($(labelCol)))
|
||||
.agg(sum(w).as("weightSum"), Summarizer.metrics("mean", "normL2")
|
||||
.summary(col($(featuresCol)), w).as("summary"))
|
||||
|
|
|
@ -108,13 +108,11 @@ class StandardScaler @Since("1.4.0") (
|
|||
override def fit(dataset: Dataset[_]): StandardScalerModel = {
|
||||
transformSchema(dataset.schema, logging = true)
|
||||
|
||||
val Row(mean: Vector, variance: Vector) = dataset
|
||||
.select(Summarizer.metrics("mean", "variance").summary(col($(inputCol))).as("summary"))
|
||||
.select("summary.mean", "summary.variance")
|
||||
val Row(mean: Vector, std: Vector) = dataset
|
||||
.select(Summarizer.metrics("mean", "std").summary(col($(inputCol))).as("summary"))
|
||||
.select("summary.mean", "summary.std")
|
||||
.first()
|
||||
|
||||
val std = Vectors.dense(variance.toArray.map(math.sqrt))
|
||||
|
||||
copyValues(new StandardScalerModel(uid, std.compressed, mean.compressed).setParent(this))
|
||||
}
|
||||
|
||||
|
|
|
@ -215,13 +215,13 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
|
|||
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
|
||||
|
||||
val featuresSummarizer = instances.treeAggregate(
|
||||
createSummarizerBuffer("mean", "variance", "count"))(
|
||||
createSummarizerBuffer("mean", "std", "count"))(
|
||||
seqOp = (c: SummarizerBuffer, v: AFTPoint) => c.add(v.features),
|
||||
combOp = (c1: SummarizerBuffer, c2: SummarizerBuffer) => c1.merge(c2),
|
||||
depth = $(aggregationDepth)
|
||||
)
|
||||
|
||||
val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt)
|
||||
val featuresStd = featuresSummarizer.std.toArray
|
||||
val numFeatures = featuresStd.size
|
||||
|
||||
instr.logPipelineStage(this)
|
||||
|
|
|
@ -358,8 +358,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
|
|||
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
|
||||
|
||||
val (featuresSummarizer, ySummarizer) = instances.treeAggregate(
|
||||
(createSummarizerBuffer("mean", "variance"),
|
||||
createSummarizerBuffer("mean", "variance", "count")))(
|
||||
(createSummarizerBuffer("mean", "std"),
|
||||
createSummarizerBuffer("mean", "std", "count")))(
|
||||
seqOp = (c: (SummarizerBuffer, SummarizerBuffer), instance: Instance) =>
|
||||
(c._1.add(instance.features, instance.weight),
|
||||
c._2.add(Vectors.dense(instance.label), instance.weight)),
|
||||
|
@ -370,7 +370,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
|
|||
)
|
||||
|
||||
val yMean = ySummarizer.mean(0)
|
||||
val rawYStd = math.sqrt(ySummarizer.variance(0))
|
||||
val rawYStd = ySummarizer.std(0)
|
||||
|
||||
instr.logNumExamples(ySummarizer.count)
|
||||
instr.logNamedValue(Instrumentation.loggerTags.meanOfLabels, yMean)
|
||||
|
@ -421,7 +421,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
|
|||
// setting yStd=abs(yMean) ensures that y is not scaled anymore in l-bfgs algorithm.
|
||||
val yStd = if (rawYStd > 0) rawYStd else math.abs(yMean)
|
||||
val featuresMean = featuresSummarizer.mean.toArray
|
||||
val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt)
|
||||
val featuresStd = featuresSummarizer.std.toArray
|
||||
val bcFeaturesMean = instances.context.broadcast(featuresMean)
|
||||
val bcFeaturesStd = instances.context.broadcast(featuresStd)
|
||||
|
||||
|
|
|
@ -89,7 +89,9 @@ object Summarizer extends Logging {
|
|||
*
|
||||
* The following metrics are accepted (case sensitive):
|
||||
* - mean: a vector that contains the coefficient-wise mean.
|
||||
* - sum: a vector that contains the coefficient-wise sum.
|
||||
* - variance: a vector tha contains the coefficient-wise variance.
|
||||
* - std: a vector tha contains the coefficient-wise standard deviation.
|
||||
* - count: the count of all vectors seen.
|
||||
* - numNonzeros: a vector with the number of non-zeros for each coefficients
|
||||
* - max: the maximum for each coefficient.
|
||||
|
@ -106,7 +108,7 @@ object Summarizer extends Logging {
|
|||
@Since("2.3.0")
|
||||
@scala.annotation.varargs
|
||||
def metrics(metrics: String*): SummaryBuilder = {
|
||||
require(metrics.size >= 1, "Should include at least one metric")
|
||||
require(metrics.nonEmpty, "Should include at least one metric")
|
||||
val (typedMetrics, computeMetrics) = getRelevantMetrics(metrics)
|
||||
new SummaryBuilderImpl(typedMetrics, computeMetrics)
|
||||
}
|
||||
|
@ -119,6 +121,14 @@ object Summarizer extends Logging {
|
|||
@Since("2.3.0")
|
||||
def mean(col: Column): Column = mean(col, lit(1.0))
|
||||
|
||||
@Since("3.0.0")
|
||||
def sum(col: Column, weightCol: Column): Column = {
|
||||
getSingleMetric(col, weightCol, "sum")
|
||||
}
|
||||
|
||||
@Since("3.0.0")
|
||||
def sum(col: Column): Column = sum(col, lit(1.0))
|
||||
|
||||
@Since("2.3.0")
|
||||
def variance(col: Column, weightCol: Column): Column = {
|
||||
getSingleMetric(col, weightCol, "variance")
|
||||
|
@ -127,6 +137,14 @@ object Summarizer extends Logging {
|
|||
@Since("2.3.0")
|
||||
def variance(col: Column): Column = variance(col, lit(1.0))
|
||||
|
||||
@Since("3.0.0")
|
||||
def std(col: Column, weightCol: Column): Column = {
|
||||
getSingleMetric(col, weightCol, "std")
|
||||
}
|
||||
|
||||
@Since("3.0.0")
|
||||
def std(col: Column): Column = std(col, lit(1.0))
|
||||
|
||||
@Since("2.3.0")
|
||||
def count(col: Column, weightCol: Column): Column = {
|
||||
getSingleMetric(col, weightCol, "count")
|
||||
|
@ -245,7 +263,9 @@ private[ml] object SummaryBuilderImpl extends Logging {
|
|||
*/
|
||||
private val allMetrics: Seq[(String, Metric, DataType, Seq[ComputeMetric])] = Seq(
|
||||
("mean", Mean, vectorUDT, Seq(ComputeMean, ComputeWeightSum)),
|
||||
("sum", Sum, vectorUDT, Seq(ComputeMean, ComputeWeightSum)),
|
||||
("variance", Variance, vectorUDT, Seq(ComputeWeightSum, ComputeMean, ComputeM2n)),
|
||||
("std", Std, vectorUDT, Seq(ComputeWeightSum, ComputeMean, ComputeM2n)),
|
||||
("count", Count, LongType, Seq()),
|
||||
("numNonZeros", NumNonZeros, vectorUDT, Seq(ComputeNNZ)),
|
||||
("max", Max, vectorUDT, Seq(ComputeMax, ComputeNNZ)),
|
||||
|
@ -259,7 +279,9 @@ private[ml] object SummaryBuilderImpl extends Logging {
|
|||
*/
|
||||
sealed trait Metric extends Serializable
|
||||
private[stat] case object Mean extends Metric
|
||||
private[stat] case object Sum extends Metric
|
||||
private[stat] case object Variance extends Metric
|
||||
private[stat] case object Std extends Metric
|
||||
private[stat] case object Count extends Metric
|
||||
private[stat] case object NumNonZeros extends Metric
|
||||
private[stat] case object Max extends Metric
|
||||
|
@ -295,14 +317,15 @@ private[ml] object SummaryBuilderImpl extends Logging {
|
|||
private var totalCnt: Long = 0
|
||||
private var totalWeightSum: Double = 0.0
|
||||
private var weightSquareSum: Double = 0.0
|
||||
private var weightSum: Array[Double] = null
|
||||
private var currWeightSum: Array[Double] = null
|
||||
private var nnz: Array[Long] = null
|
||||
private var currMax: Array[Double] = null
|
||||
private var currMin: Array[Double] = null
|
||||
|
||||
def this() {
|
||||
this(
|
||||
Seq(Mean, Variance, Count, NumNonZeros, Max, Min, NormL2, NormL1),
|
||||
Seq(Mean, Sum, Variance, Std, Count, NumNonZeros,
|
||||
Max, Min, NormL2, NormL1),
|
||||
Seq(ComputeMean, ComputeM2n, ComputeM2, ComputeL1,
|
||||
ComputeWeightSum, ComputeNNZ, ComputeMax, ComputeMin)
|
||||
)
|
||||
|
@ -323,7 +346,9 @@ private[ml] object SummaryBuilderImpl extends Logging {
|
|||
if (requestedCompMetrics.contains(ComputeM2n)) { currM2n = Array.ofDim[Double](n) }
|
||||
if (requestedCompMetrics.contains(ComputeM2)) { currM2 = Array.ofDim[Double](n) }
|
||||
if (requestedCompMetrics.contains(ComputeL1)) { currL1 = Array.ofDim[Double](n) }
|
||||
if (requestedCompMetrics.contains(ComputeWeightSum)) { weightSum = Array.ofDim[Double](n) }
|
||||
if (requestedCompMetrics.contains(ComputeWeightSum)) {
|
||||
currWeightSum = Array.ofDim[Double](n)
|
||||
}
|
||||
if (requestedCompMetrics.contains(ComputeNNZ)) { nnz = Array.ofDim[Long](n) }
|
||||
if (requestedCompMetrics.contains(ComputeMax)) {
|
||||
currMax = Array.fill[Double](n)(Double.MinValue)
|
||||
|
@ -340,7 +365,7 @@ private[ml] object SummaryBuilderImpl extends Logging {
|
|||
val localCurrM2n = currM2n
|
||||
val localCurrM2 = currM2
|
||||
val localCurrL1 = currL1
|
||||
val localWeightSum = weightSum
|
||||
val localCurrWeightSum = currWeightSum
|
||||
val localNumNonzeros = nnz
|
||||
val localCurrMax = currMax
|
||||
val localCurrMin = currMin
|
||||
|
@ -353,17 +378,18 @@ private[ml] object SummaryBuilderImpl extends Logging {
|
|||
localCurrMin(index) = value
|
||||
}
|
||||
|
||||
if (localWeightSum != null) {
|
||||
if (localCurrWeightSum != null) {
|
||||
if (localCurrMean != null) {
|
||||
val prevMean = localCurrMean(index)
|
||||
val diff = value - prevMean
|
||||
localCurrMean(index) = prevMean + weight * diff / (localWeightSum(index) + weight)
|
||||
localCurrMean(index) = prevMean +
|
||||
weight * diff / (localCurrWeightSum(index) + weight)
|
||||
|
||||
if (localCurrM2n != null) {
|
||||
localCurrM2n(index) += weight * (value - localCurrMean(index)) * diff
|
||||
}
|
||||
}
|
||||
localWeightSum(index) += weight
|
||||
localCurrWeightSum(index) += weight
|
||||
}
|
||||
|
||||
if (localCurrM2 != null) {
|
||||
|
@ -402,9 +428,9 @@ private[ml] object SummaryBuilderImpl extends Logging {
|
|||
weightSquareSum += other.weightSquareSum
|
||||
var i = 0
|
||||
while (i < n) {
|
||||
if (weightSum != null) {
|
||||
val thisWeightSum = weightSum(i)
|
||||
val otherWeightSum = other.weightSum(i)
|
||||
if (currWeightSum != null) {
|
||||
val thisWeightSum = currWeightSum(i)
|
||||
val otherWeightSum = other.currWeightSum(i)
|
||||
val totalWeightSum = thisWeightSum + otherWeightSum
|
||||
|
||||
if (totalWeightSum != 0.0) {
|
||||
|
@ -420,7 +446,7 @@ private[ml] object SummaryBuilderImpl extends Logging {
|
|||
}
|
||||
}
|
||||
}
|
||||
weightSum(i) = totalWeightSum
|
||||
currWeightSum(i) = totalWeightSum
|
||||
}
|
||||
|
||||
// merge m2 together
|
||||
|
@ -442,7 +468,7 @@ private[ml] object SummaryBuilderImpl extends Logging {
|
|||
this.totalCnt = other.totalCnt
|
||||
this.totalWeightSum = other.totalWeightSum
|
||||
this.weightSquareSum = other.weightSquareSum
|
||||
if (other.weightSum != null) { this.weightSum = other.weightSum.clone() }
|
||||
if (other.currWeightSum != null) { this.currWeightSum = other.currWeightSum.clone() }
|
||||
if (other.nnz != null) { this.nnz = other.nnz.clone() }
|
||||
if (other.currMax != null) { this.currMax = other.currMax.clone() }
|
||||
if (other.currMin != null) { this.currMin = other.currMin.clone() }
|
||||
|
@ -460,12 +486,28 @@ private[ml] object SummaryBuilderImpl extends Logging {
|
|||
val realMean = Array.ofDim[Double](n)
|
||||
var i = 0
|
||||
while (i < n) {
|
||||
realMean(i) = currMean(i) * (weightSum(i) / totalWeightSum)
|
||||
realMean(i) = currMean(i) * (currWeightSum(i) / totalWeightSum)
|
||||
i += 1
|
||||
}
|
||||
Vectors.dense(realMean)
|
||||
}
|
||||
|
||||
/**
|
||||
* Sum of each dimension.
|
||||
*/
|
||||
def sum: Vector = {
|
||||
require(requestedMetrics.contains(Sum))
|
||||
require(totalWeightSum > 0, s"Nothing has been added to this summarizer.")
|
||||
|
||||
val realSum = Array.ofDim[Double](n)
|
||||
var i = 0
|
||||
while (i < n) {
|
||||
realSum(i) = currMean(i) * currWeightSum(i)
|
||||
i += 1
|
||||
}
|
||||
Vectors.dense(realSum)
|
||||
}
|
||||
|
||||
/**
|
||||
* Unbiased estimate of sample variance of each dimension.
|
||||
*/
|
||||
|
@ -473,8 +515,23 @@ private[ml] object SummaryBuilderImpl extends Logging {
|
|||
require(requestedMetrics.contains(Variance))
|
||||
require(totalWeightSum > 0, s"Nothing has been added to this summarizer.")
|
||||
|
||||
val realVariance = Array.ofDim[Double](n)
|
||||
val realVariance = computeVariance
|
||||
Vectors.dense(realVariance)
|
||||
}
|
||||
|
||||
/**
|
||||
* Unbiased estimate of standard deviation of each dimension.
|
||||
*/
|
||||
def std: Vector = {
|
||||
require(requestedMetrics.contains(Std))
|
||||
require(totalWeightSum > 0, s"Nothing has been added to this summarizer.")
|
||||
|
||||
val realVariance = computeVariance
|
||||
Vectors.dense(realVariance.map(math.sqrt))
|
||||
}
|
||||
|
||||
private def computeVariance: Array[Double] = {
|
||||
val realVariance = Array.ofDim[Double](n)
|
||||
val denominator = totalWeightSum - (weightSquareSum / totalWeightSum)
|
||||
|
||||
// Sample variance is computed, if the denominator is less than 0, the variance is just 0.
|
||||
|
@ -484,12 +541,12 @@ private[ml] object SummaryBuilderImpl extends Logging {
|
|||
val len = currM2n.length
|
||||
while (i < len) {
|
||||
// We prevent variance from negative value caused by numerical error.
|
||||
realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
|
||||
(totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0)
|
||||
realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * currWeightSum(i) *
|
||||
(totalWeightSum - currWeightSum(i)) / totalWeightSum) / denominator, 0.0)
|
||||
i += 1
|
||||
}
|
||||
}
|
||||
Vectors.dense(realVariance)
|
||||
realVariance
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -579,7 +636,9 @@ private[ml] object SummaryBuilderImpl extends Logging {
|
|||
override def eval(state: SummarizerBuffer): Any = {
|
||||
val metrics = requestedMetrics.map {
|
||||
case Mean => vectorUDT.serialize(state.mean)
|
||||
case Sum => vectorUDT.serialize(state.sum)
|
||||
case Variance => vectorUDT.serialize(state.variance)
|
||||
case Std => vectorUDT.serialize(state.std)
|
||||
case Count => state.count
|
||||
case NumNonZeros => vectorUDT.serialize(state.numNonzeros)
|
||||
case Max => vectorUDT.serialize(state.max)
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
package org.apache.spark.ml.stat
|
||||
|
||||
import org.apache.spark.{SparkException, SparkFunSuite}
|
||||
import org.apache.spark.ml.linalg.{Vector, Vectors}
|
||||
import org.apache.spark.ml.linalg._
|
||||
import org.apache.spark.ml.util.TestingUtils._
|
||||
import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
|
||||
import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, Statistics}
|
||||
|
@ -83,6 +83,28 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
Row(Row(summarizerWithoutWeight.mean), expWithoutWeight.mean))
|
||||
}
|
||||
|
||||
registerTest(s"$name - sum only") {
|
||||
val (df, c, w) = wrappedInit()
|
||||
val weightSum = summarizer.weightSum
|
||||
val expected1 = summarizer.mean.asML.copy
|
||||
BLAS.scal(weightSum, expected1)
|
||||
val expected2 = exp.mean.copy
|
||||
BLAS.scal(weightSum, expected2)
|
||||
compareRow(df.select(metrics("sum").summary(c, w), sum(c, w)).first(),
|
||||
Row(Row(expected1), expected2))
|
||||
}
|
||||
|
||||
registerTest(s"$name - sum only w/o weight") {
|
||||
val (df, c, _) = wrappedInit()
|
||||
val weightSum = summarizerWithoutWeight.weightSum
|
||||
val expected1 = summarizerWithoutWeight.mean.asML.copy
|
||||
BLAS.scal(weightSum, expected1)
|
||||
val expected2 = expWithoutWeight.mean.copy
|
||||
BLAS.scal(weightSum, expected2)
|
||||
compareRow(df.select(metrics("sum").summary(c), sum(c)).first(),
|
||||
Row(Row(expected1), expected2))
|
||||
}
|
||||
|
||||
registerTest(s"$name - variance only") {
|
||||
val (df, c, w) = wrappedInit()
|
||||
compareRow(df.select(metrics("variance").summary(c, w), variance(c, w)).first(),
|
||||
|
@ -95,6 +117,22 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
Row(Row(summarizerWithoutWeight.variance), expWithoutWeight.variance))
|
||||
}
|
||||
|
||||
registerTest(s"$name - std only") {
|
||||
val (df, c, w) = wrappedInit()
|
||||
val expected1 = Vectors.dense(summarizer.variance.toArray.map(math.sqrt))
|
||||
val expected2 = Vectors.dense(exp.variance.toArray.map(math.sqrt))
|
||||
compareRow(df.select(metrics("std").summary(c, w), std(c, w)).first(),
|
||||
Row(Row(expected1), expected2))
|
||||
}
|
||||
|
||||
registerTest(s"$name - std only w/o weight") {
|
||||
val (df, c, _) = wrappedInit()
|
||||
val expected1 = Vectors.dense(summarizerWithoutWeight.variance.toArray.map(math.sqrt))
|
||||
val expected2 = Vectors.dense(expWithoutWeight.variance.toArray.map(math.sqrt))
|
||||
compareRow(df.select(metrics("std").summary(c), std(c)).first(),
|
||||
Row(Row(expected1), expected2))
|
||||
}
|
||||
|
||||
registerTest(s"$name - count only") {
|
||||
val (df, c, w) = wrappedInit()
|
||||
compareRow(df.select(metrics("count").summary(c, w), count(c, w)).first(),
|
||||
|
@ -192,8 +230,12 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
assert(v1 ~== v2 absTol 1e-4)
|
||||
case (v1: Vector, v2: OldVector) =>
|
||||
assert(v1 ~== v2.asML absTol 1e-4)
|
||||
case (i1: Int, i2: Int) =>
|
||||
assert(i1 === i2)
|
||||
case (l1: Long, l2: Long) =>
|
||||
assert(l1 === l2)
|
||||
case (d1: Double, d2: Double) =>
|
||||
assert(d1 ~== d2 absTol 1e-4)
|
||||
case (r1: Row, r2: Row) =>
|
||||
compareRow(r1, r2)
|
||||
case (x1: Any, x2: Any) =>
|
||||
|
@ -531,6 +573,30 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
assert(summarizer3.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14)
|
||||
}
|
||||
|
||||
test("support new metrics: sum, std, numFeatures, sumL2, weightSum") {
|
||||
val summarizer1 = new SummarizerBuffer()
|
||||
.add(Vectors.dense(10.0, -10.0), 1e10)
|
||||
.add(Vectors.dense(0.0, 0.0), 1e-7)
|
||||
|
||||
val summarizer2 = new SummarizerBuffer()
|
||||
summarizer2.add(Vectors.dense(10.0, -10.0), 1e10)
|
||||
for (i <- 1 to 100) {
|
||||
summarizer2.add(Vectors.dense(0.0, 0.0), 1e-7)
|
||||
}
|
||||
|
||||
val summarizer3 = new SummarizerBuffer()
|
||||
for (i <- 1 to 100) {
|
||||
summarizer3.add(Vectors.dense(0.0, 0.0), 1e-7)
|
||||
}
|
||||
summarizer3.add(Vectors.dense(10.0, -10.0), 1e10)
|
||||
|
||||
Seq(summarizer1, summarizer2, summarizer3).foreach { summarizer =>
|
||||
val variance = summarizer.variance
|
||||
val expectedStd = Vectors.dense(variance.toArray.map(math.sqrt))
|
||||
assert(summarizer.std ~== expectedStd relTol 1e-14)
|
||||
}
|
||||
}
|
||||
|
||||
ignore("performance test") {
|
||||
/*
|
||||
Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.12
|
||||
|
|
|
@ -243,6 +243,14 @@ class Summarizer(object):
|
|||
"""
|
||||
return Summarizer._get_single_metric(col, weightCol, "mean")
|
||||
|
||||
@staticmethod
|
||||
@since("3.0.0")
|
||||
def sum(col, weightCol=None):
|
||||
"""
|
||||
return a column of sum summary
|
||||
"""
|
||||
return Summarizer._get_single_metric(col, weightCol, "sum")
|
||||
|
||||
@staticmethod
|
||||
@since("2.4.0")
|
||||
def variance(col, weightCol=None):
|
||||
|
@ -251,6 +259,14 @@ class Summarizer(object):
|
|||
"""
|
||||
return Summarizer._get_single_metric(col, weightCol, "variance")
|
||||
|
||||
@staticmethod
|
||||
@since("3.0.0")
|
||||
def std(col, weightCol=None):
|
||||
"""
|
||||
return a column of std summary
|
||||
"""
|
||||
return Summarizer._get_single_metric(col, weightCol, "std")
|
||||
|
||||
@staticmethod
|
||||
@since("2.4.0")
|
||||
def count(col, weightCol=None):
|
||||
|
@ -323,7 +339,9 @@ class Summarizer(object):
|
|||
|
||||
The following metrics are accepted (case sensitive):
|
||||
- mean: a vector that contains the coefficient-wise mean.
|
||||
- sum: a vector that contains the coefficient-wise sum.
|
||||
- variance: a vector tha contains the coefficient-wise variance.
|
||||
- std: a vector tha contains the coefficient-wise standard deviation.
|
||||
- count: the count of all vectors seen.
|
||||
- numNonzeros: a vector with the number of non-zeros for each coefficients
|
||||
- max: the maximum for each coefficient.
|
||||
|
|
Loading…
Reference in a new issue