mimir-pip/lib/src/org/mimirdb/pip/distribution/Gaussian.scala

149 lines
4.8 KiB
Scala

package org.mimirdb.pip.distribution
import scala.util.Random
import java.io.Serializable
import java.io.ObjectOutputStream
import java.io.ObjectInputStream
import org.apache.commons.math3.special.Erf
import org.apache.spark.sql.Column
import org.apache.spark.sql.types.DoubleType
import org.mimirdb.pip.SampleParams
import org.mimirdb.pip.udt.UnivariateDistribution
import org.mimirdb.pip.udt.UnivariateDistributionType
import org.mimirdb.pip.udt.UnivariateDistributionConstructor
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.Column
/**
 * The Gaussian (normal) distribution family.
 *
 * A distribution is described by [[Gaussian.Params]] (mean and standard
 * deviation).  Most methods receive the parameters as `Any` and cast them to
 * [[Gaussian.Params]]; passing any other type fails with a
 * `ClassCastException`.
 */
object Gaussian
  extends NumericalDistributionFamily
  with CDFSupported
  with ICDFSupported
{
  /**
   * Build a column holding a Gaussian distribution.
   *
   * @param mean    column providing the mean
   * @param stddev  column providing the standard deviation
   * @return        a column wrapping a [[Constructor]] over the two expressions
   */
  def apply(mean: Column, stddev: Column): Column =
    new Column(Constructor(Seq(mean.expr, stddev.expr)))

  /**
   * Parameters of a Gaussian distribution.
   *
   * @param mean  the mean
   * @param sd    the standard deviation (not the variance)
   */
  case class Params(mean: Double, sd: Double)

  /**
   * Draw one sample: a standard-normal draw scaled by `sd` and shifted by
   * `mean`.
   *
   * @param params  a [[Params]] instance
   * @param random  the source of randomness
   */
  def sample(params: Any, random: scala.util.Random): Double =
  {
    val p = params.asInstanceOf[Params]
    // Keep the whole expression on one line: in Scala 2 a line that *starts*
    // with `+` is parsed as a separate unary-plus statement, which would
    // silently drop the scaled draw and return only the mean.
    random.nextGaussian() * p.sd + p.mean
  }

  /**
   * Write the parameters to a stream as two doubles: mean, then sd.
   *
   * @param in      the output stream to write to (name kept for compatibility)
   * @param params  a [[Params]] instance
   */
  def serialize(in: ObjectOutputStream, params: Any): Unit =
  {
    val p = params.asInstanceOf[Params]
    in.writeDouble(p.mean)
    in.writeDouble(p.sd)
  }

  /**
   * Read parameters previously written by [[serialize]]: mean, then sd.
   *
   * @param in  the input stream to read from
   */
  def deserialize(in: ObjectInputStream): Params =
  {
    // Read into locals so the stream order (mean first) is explicit.
    val mean = in.readDouble()
    val sd = in.readDouble()
    Params(mean = mean, sd = sd)
  }

  /** Lower bound of the support; always negative infinity. */
  def min(params: Any): Double = Double.NegativeInfinity

  /** Upper bound of the support; always positive infinity. */
  def max(params: Any): Double = Double.PositiveInfinity

  /**
   * Cumulative distribution function:
   * `(1 + erf((value - mean) / (sd * sqrt(2)))) / 2`.
   *
   * @param value   the point at which to evaluate the CDF
   * @param params  a [[Params]] instance
   */
  def cdf(value: Double, params: Any): Double =
  {
    val p = params.asInstanceOf[Params]
    val z = (value - p.mean) / (p.sd * Math.sqrt(2))
    (1 + Erf.erf(z)) / 2.0
  }

  /**
   * Inverse CDF (quantile function):
   * `erfInv(2 * value - 1) * sd * sqrt(2) + mean`.
   *
   * @param value   a cumulative probability
   * @param params  a [[Params]] instance
   */
  def icdf(value: Double, params: Any): Double =
  {
    val p = params.asInstanceOf[Params]
    // Single-line expression on purpose: a continuation line starting with
    // `+` is a separate statement in Scala 2 and would return only the mean.
    Erf.erfInv(value * 2 - 1) * p.sd * Math.sqrt(2) + p.mean
  }

  /**
   * Expression node that constructs a Gaussian from its argument expressions.
   *
   * @param args  the child expressions; [[Gaussian.apply]] supplies the mean
   *              expression followed by the standard-deviation expression
   */
  case class Constructor(args: Seq[Expression])
    extends UnivariateDistributionConstructor
  {
    def family = Gaussian

    /**
     * Convert evaluated argument values into [[Params]].  Both values must
     * already be `Double`s: `values(0)` is the mean, `values(1)` the sd.
     */
    def params(values: Seq[Any]): Params =
      Params(
        mean = values(0).asInstanceOf[Double],
        sd = values(1).asInstanceOf[Double]
      )

    def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) =
      copy(args = newChildren)
  }

  /**
   * Parameters for the sum of two Gaussians: means add, and the standard
   * deviations combine as `sqrt(a.sd^2 + b.sd^2)`.
   */
  def plus(a: Params, b: Params): Params =
  {
    val newMean = a.mean + b.mean
    val newSd = math.sqrt(math.pow(a.sd, 2) + math.pow(b.sd, 2))
    Params(newMean, newSd)
  }

  /**
   * Parameters for the difference `a - b`, computed as `a + (-b)`.
   *
   * Negating a Gaussian flips the sign of its mean only; the standard
   * deviation is unchanged, so no negative `sd` is ever constructed.
   */
  def minus(a: Params, b: Params): Params =
    plus(a, Params(mean = -b.mean, sd = b.sd))

  /**
   * Parameters of the Gaussian obtained by multiplying the two density
   * functions, following
   * https://ccrma.stanford.edu/~jos/sasp/Product_Two_Gaussian_PDFs.html
   *
   *   mean     = (a.mean * b.sd^2 + b.mean * a.sd^2) / (a.sd^2 + b.sd^2)
   *   variance = (a.sd^2 * b.sd^2) / (a.sd^2 + b.sd^2)
   *
   * NOTE(review): this is the product of the two PDFs, which is not the same
   * thing as the distribution of the product of two Gaussian random
   * variables; the original author flagged the same doubt.  Confirm this is
   * the intended meaning of `mult`.
   *
   * If both standard deviations are zero the divisions are 0/0 and the
   * result fields are NaN.
   */
  def mult(a: Params, b: Params): Params =
  {
    val varA = math.pow(a.sd, 2)
    val varB = math.pow(b.sd, 2)
    val varSum = varA + varB
    val newMean = (a.mean * varB + b.mean * varA) / varSum
    // The formula yields a variance; Params stores a standard deviation, so
    // take the square root before storing it.
    val newSd = math.sqrt((varA * varB) / varSum)
    Params(newMean, newSd)
  }

  /** Human-readable description of the distribution's parameters. */
  def describe(params: Any): String =
  {
    val p = params.asInstanceOf[Params]
    s"Gauss(mean: ${p.mean}, std-dev: ${p.sd})"
  }

  // /*
  // Comparison < prototype.
  // Original idea:
  // i) Take in a general parameter 'other'
  // ii) match it to its true type
  // iii) return the result of the comparison which could be an error
  // Difficulties:
  // */
  // def lt(self: Params, other: Any): Boolean = other match {
  // case i: Integer => this < i.asInstanceOf[Float]
  // case f: Float => f match {
  // case y if (this.mean + (this.sd * 3) < y) => true //assumption here that 0.3% of the probability space will not matter
  // case _ => false
  // }
  // case d: Double => this < d.asInstanceOf[Float]
  // case g: Gauss => g match {
  // case y if (y.mean >= this.mean && y.sd >= this.sd) => false
  // case n if (n.mean < this.mean && n.sd < this.sd) => true
  // case _ => ???//here is where we make an estimate and give a confidence level
  // }
  // case _ => ???// this should raise an exception of incompatible types; how to do this in spark/scala ecosystem?
  // }
  // def >=(other: Any): Boolean = !(this < other)
  // def >(other: Any): Boolean = other match {
  // case i: Integer => this > i.asInstanceOf[Float]
  // case f: Float => f match {
  // case y if (this.mean + (3 * this.sd) > y) => true
  // case n if (this.mean + (3 * this.sd) <= n )=> false
  // case _ => ???//here is where we make an estimate and give a confidence level
  // }
  // case d: Double => this > d.asInstanceOf[Float]
  // case _ => ???//incompatible types exception
  // }
  // def <=(other: Any): Boolean = !(this > other)
}