[SPARK-7753] [MLLIB] Update KernelDensity API
Update `KernelDensity` API to make it extensible to different kernels in the future. `bandwidth` is used instead of `standardDeviation`. The static `kernelDensity` method is removed from `Statistics`. The implementation is updated using BLAS, while the algorithm remains the same. sryza srowen
Author: Xiangrui Meng <meng@databricks.com>
Closes #6279 from mengxr/SPARK-7753 and squashes the following commits:
4cdfadc [Xiangrui Meng] add example code in the doc
767fd5a [Xiangrui Meng] update KernelDensity API
(cherry picked from commit 947ea1cf5f
)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
This commit is contained in:
parent
b0e7c66338
commit
64762444e7
|
@ -17,52 +17,101 @@
|
|||
|
||||
package org.apache.spark.mllib.stat
|
||||
|
||||
import com.github.fommil.netlib.BLAS.{getInstance => blas}
|
||||
|
||||
import org.apache.spark.annotation.Experimental
|
||||
import org.apache.spark.api.java.JavaRDD
|
||||
import org.apache.spark.rdd.RDD
|
||||
|
||||
private[stat] object KernelDensity {
|
||||
/**
|
||||
* :: Experimental ::
|
||||
* Kernel density estimation. Given a sample from a population, estimate its probability density
|
||||
* function at each of the given evaluation points using kernels. Only Gaussian kernel is supported.
|
||||
*
|
||||
* Scala example:
|
||||
*
|
||||
* {{{
|
||||
* val sample = sc.parallelize(Seq(0.0, 1.0, 4.0, 4.0))
|
||||
* val kd = new KernelDensity()
|
||||
* .setSample(sample)
|
||||
* .setBandwidth(3.0)
|
||||
* val densities = kd.estimate(Array(-1.0, 2.0, 5.0))
|
||||
* }}}
|
||||
*/
|
||||
@Experimental
|
||||
class KernelDensity extends Serializable {
|
||||
|
||||
import KernelDensity._
|
||||
|
||||
/** Bandwidth of the kernel function. */
|
||||
private var bandwidth: Double = 1.0
|
||||
|
||||
/** A sample from a population. */
|
||||
private var sample: RDD[Double] = _
|
||||
|
||||
/**
|
||||
* Given a set of samples from a distribution, estimates its density at the set of given points.
|
||||
* Uses a Gaussian kernel with the given standard deviation.
|
||||
* Sets the bandwidth (standard deviation) of the Gaussian kernel (default: `1.0`).
|
||||
*/
|
||||
def estimate(samples: RDD[Double], standardDeviation: Double,
|
||||
evaluationPoints: Array[Double]): Array[Double] = {
|
||||
if (standardDeviation <= 0.0) {
|
||||
throw new IllegalArgumentException("Standard deviation must be positive")
|
||||
}
|
||||
|
||||
// This gets used in each Gaussian PDF computation, so compute it up front
|
||||
val logStandardDeviationPlusHalfLog2Pi =
|
||||
math.log(standardDeviation) + 0.5 * math.log(2 * math.Pi)
|
||||
|
||||
val (points, count) = samples.aggregate((new Array[Double](evaluationPoints.length), 0))(
|
||||
(x, y) => {
|
||||
var i = 0
|
||||
while (i < evaluationPoints.length) {
|
||||
x._1(i) += normPdf(y, standardDeviation, logStandardDeviationPlusHalfLog2Pi,
|
||||
evaluationPoints(i))
|
||||
i += 1
|
||||
}
|
||||
(x._1, i)
|
||||
},
|
||||
(x, y) => {
|
||||
var i = 0
|
||||
while (i < evaluationPoints.length) {
|
||||
x._1(i) += y._1(i)
|
||||
i += 1
|
||||
}
|
||||
(x._1, x._2 + y._2)
|
||||
})
|
||||
|
||||
var i = 0
|
||||
while (i < points.length) {
|
||||
points(i) /= count
|
||||
i += 1
|
||||
}
|
||||
points
|
||||
def setBandwidth(bandwidth: Double): this.type = {
|
||||
require(bandwidth > 0, s"Bandwidth must be positive, but got $bandwidth.")
|
||||
this.bandwidth = bandwidth
|
||||
this
|
||||
}
|
||||
|
||||
private def normPdf(mean: Double, standardDeviation: Double,
|
||||
logStandardDeviationPlusHalfLog2Pi: Double, x: Double): Double = {
|
||||
/**
|
||||
* Sets the sample to use for density estimation.
|
||||
*/
|
||||
def setSample(sample: RDD[Double]): this.type = {
|
||||
this.sample = sample
|
||||
this
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the sample to use for density estimation (for Java users).
|
||||
*/
|
||||
def setSample(sample: JavaRDD[java.lang.Double]): this.type = {
|
||||
this.sample = sample.rdd.asInstanceOf[RDD[Double]]
|
||||
this
|
||||
}
|
||||
|
||||
/**
|
||||
* Estimates probability density function at the given array of points.
|
||||
*/
|
||||
def estimate(points: Array[Double]): Array[Double] = {
|
||||
val sample = this.sample
|
||||
val bandwidth = this.bandwidth
|
||||
|
||||
require(sample != null, "Must set sample before calling estimate.")
|
||||
|
||||
val n = points.length
|
||||
// This gets used in each Gaussian PDF computation, so compute it up front
|
||||
val logStandardDeviationPlusHalfLog2Pi = math.log(bandwidth) + 0.5 * math.log(2 * math.Pi)
|
||||
val (densities, count) = sample.aggregate((new Array[Double](n), 0L))(
|
||||
(x, y) => {
|
||||
var i = 0
|
||||
while (i < n) {
|
||||
x._1(i) += normPdf(y, bandwidth, logStandardDeviationPlusHalfLog2Pi, points(i))
|
||||
i += 1
|
||||
}
|
||||
(x._1, n)
|
||||
},
|
||||
(x, y) => {
|
||||
blas.daxpy(n, 1.0, y._1, 1, x._1, 1)
|
||||
(x._1, x._2 + y._2)
|
||||
})
|
||||
blas.dscal(n, 1.0 / count, densities, 1)
|
||||
densities
|
||||
}
|
||||
}
|
||||
|
||||
private object KernelDensity {
|
||||
|
||||
/** Evaluates the PDF of a normal distribution. */
|
||||
def normPdf(
|
||||
mean: Double,
|
||||
standardDeviation: Double,
|
||||
logStandardDeviationPlusHalfLog2Pi: Double,
|
||||
x: Double): Double = {
|
||||
val x0 = x - mean
|
||||
val x1 = x0 / standardDeviation
|
||||
val logDensity = -0.5 * x1 * x1 - logStandardDeviationPlusHalfLog2Pi
|
||||
|
|
|
@ -149,18 +149,4 @@ object Statistics {
|
|||
def chiSqTest(data: RDD[LabeledPoint]): Array[ChiSqTestResult] = {
|
||||
ChiSqTest.chiSquaredFeatures(data)
|
||||
}
|
||||
|
||||
/**
|
||||
* Given an empirical distribution defined by the input RDD of samples, estimate its density at
|
||||
* each of the given evaluation points using a Gaussian kernel.
|
||||
*
|
||||
* @param samples The samples RDD used to define the empirical distribution.
|
||||
* @param standardDeviation The standard deviation of the kernel Gaussians.
|
||||
* @param evaluationPoints The points at which to estimate densities.
|
||||
* @return An array the same size as evaluationPoints with the density at each point.
|
||||
*/
|
||||
def kernelDensity(samples: RDD[Double], standardDeviation: Double,
|
||||
evaluationPoints: Iterable[Double]): Array[Double] = {
|
||||
KernelDensity.estimate(samples, standardDeviation, evaluationPoints.toArray)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,9 +17,8 @@
|
|||
|
||||
package org.apache.spark.mllib.stat
|
||||
|
||||
import org.scalatest.FunSuite
|
||||
|
||||
import org.apache.commons.math3.distribution.NormalDistribution
|
||||
import org.scalatest.FunSuite
|
||||
|
||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||
|
||||
|
@ -27,7 +26,7 @@ class KernelDensitySuite extends FunSuite with MLlibTestSparkContext {
|
|||
test("kernel density single sample") {
|
||||
val rdd = sc.parallelize(Array(5.0))
|
||||
val evaluationPoints = Array(5.0, 6.0)
|
||||
val densities = KernelDensity.estimate(rdd, 3.0, evaluationPoints)
|
||||
val densities = new KernelDensity().setSample(rdd).setBandwidth(3.0).estimate(evaluationPoints)
|
||||
val normal = new NormalDistribution(5.0, 3.0)
|
||||
val acceptableErr = 1e-6
|
||||
assert(densities(0) - normal.density(5.0) < acceptableErr)
|
||||
|
@ -37,7 +36,7 @@ class KernelDensitySuite extends FunSuite with MLlibTestSparkContext {
|
|||
test("kernel density multiple samples") {
|
||||
val rdd = sc.parallelize(Array(5.0, 10.0))
|
||||
val evaluationPoints = Array(5.0, 6.0)
|
||||
val densities = KernelDensity.estimate(rdd, 3.0, evaluationPoints)
|
||||
val densities = new KernelDensity().setSample(rdd).setBandwidth(3.0).estimate(evaluationPoints)
|
||||
val normal1 = new NormalDistribution(5.0, 3.0)
|
||||
val normal2 = new NormalDistribution(10.0, 3.0)
|
||||
val acceptableErr = 1e-6
|
||||
|
|
Loading…
Reference in a new issue