149 lines
4.9 KiB
Scala
149 lines
4.9 KiB
Scala
package org.mimirdb.pip.distribution.numerical
|
|
|
|
import scala.util.Random
|
|
import java.io.Serializable
|
|
import java.io.ObjectOutputStream
|
|
import java.io.ObjectInputStream
|
|
import org.apache.commons.math3.special.Erf
|
|
import org.apache.spark.sql.Column
|
|
import org.apache.spark.sql.types.DoubleType
|
|
import org.mimirdb.pip.SampleParams
|
|
import org.mimirdb.pip.udt.UnivariateDistribution
|
|
import org.mimirdb.pip.udt.UnivariateDistributionType
|
|
import org.mimirdb.pip.udt.UnivariateDistributionConstructor
|
|
import org.apache.spark.sql.catalyst.expressions.Expression
|
|
import org.apache.spark.sql.Column
|
|
|
|
/**
|
|
* The Gaussian (normal) distribution
|
|
*
|
|
*/
|
|
/**
 * The Gaussian (normal) distribution, parameterized by a mean and a
 * standard deviation.
 *
 * Params.sd is always a *standard deviation* (not a variance); every
 * method below maintains that invariant.
 */
object Gaussian
  extends NumericalDistributionFamily
  with CDFSupported
  with ICDFSupported
{
  /** Build a Column expression constructing a Gaussian from mean/stddev columns. */
  def apply(mean: Column, stddev: Column): Column =
  {
    new Column(Constructor(Seq(mean.expr, stddev.expr)))
  }

  /** Distribution parameters: mean and standard deviation. */
  case class Params(mean: Double, sd: Double)

  /**
   * Draw one sample: scale a standard-normal variate by sd, shift by mean.
   *
   * NOTE: this must remain a single expression.  The previous version
   * split `... * sd` and `+ mean` across two lines; Scala's semicolon
   * inference parsed the second line as a separate unary-plus statement,
   * so the method silently returned only the mean and discarded the
   * random variate.
   */
  def sample(params: Any, random: scala.util.Random): Double =
  {
    val p = params.asInstanceOf[Params]
    random.nextGaussian() * p.sd + p.mean
  }

  /**
   * Write the parameters to the stream: mean first, then sd.
   * (The parameter is named `in` to match the family interface, but it
   * is the *output* stream.)
   */
  def serialize(in: ObjectOutputStream, params: Any): Unit =
  {
    val p = params.asInstanceOf[Params]
    in.writeDouble(p.mean)
    in.writeDouble(p.sd)
  }

  /** Read parameters back in the same order serialize wrote them. */
  def deserialize(in: ObjectInputStream): Params =
    Params(
      mean = in.readDouble(),
      sd = in.readDouble()
    )

  /** Gaussians have unbounded support. */
  def min(params: Any) = Double.NegativeInfinity
  def max(params: Any) = Double.PositiveInfinity

  /**
   * Cumulative distribution function:
   *   CDF(x) = (1 + erf((x - mean) / (sd * sqrt(2)))) / 2
   */
  def cdf(value: Double, params: Any): Double =
  {
    val p = params.asInstanceOf[Params]
    (1 + Erf.erf((value - p.mean) / (p.sd * Math.sqrt(2)))) / 2.0
  }

  /**
   * Inverse CDF (quantile function):
   *   ICDF(q) = mean + sd * sqrt(2) * erfInv(2q - 1)
   *
   * As in sample(), this must stay a single expression: a line break
   * before `+ p.mean` would make Scala parse it as a separate statement
   * and drop the erfInv term from the result.
   */
  def icdf(value: Double, params: Any): Double =
  {
    val p = params.asInstanceOf[Params]
    Erf.erfInv(value * 2 - 1) * p.sd * Math.sqrt(2) + p.mean
  }

  /** Catalyst expression that evaluates its argument expressions into Params. */
  case class Constructor(args: Seq[Expression])
    extends UnivariateDistributionConstructor
  {
    def family = Gaussian
    def params(values: Seq[Any]) =
      Params(mean = values(0).asInstanceOf[Double],
             sd = values(1).asInstanceOf[Double])

    def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) =
      copy(args = newChildren)
  }

  /**
   * Sum of two independent Gaussians: means add, variances add
   * (so sd is the sqrt of the summed squared sds).
   */
  def plus(a: Params, b: Params): Params =
  {
    val new_mean = a.mean + b.mean
    val new_sd = math.sqrt(math.pow(a.sd, 2) + math.pow(b.sd, 2))
    Params(new_mean, new_sd)
  }

  /**
   * Difference of two independent Gaussians: means subtract, variances
   * still add.  Negating b's sd is harmless since plus() squares it.
   */
  def minus(a: Params, b: Params): Params = {
    plus(a, Params(b.mean * -1, b.sd * -1))
  }

  /**
   * Product of two Gaussian PDFs.
   *
   * Follows https://ccrma.stanford.edu/~jos/sasp/Product_Two_Gaussian_PDFs.html
   * (the normalized product of two Gaussian PDFs is a Gaussian PDF with
   * the mean/variance below).  NOTE(review): the original comment was
   * unsure of this formula; it matches the cited source — confirm it is
   * the semantics callers expect.
   */
  def mult(a: Params, b: Params): Params = {
    val varA = math.pow(a.sd, 2)
    val varB = math.pow(b.sd, 2)
    val new_mean = (a.mean * varB + b.mean * varA) / (varA + varB)
    // The formula yields the product's VARIANCE: varA*varB / (varA+varB).
    // Take its square root so Params.sd remains a standard deviation.
    // (The previous version stored the variance directly in the sd slot.)
    val new_sd = math.sqrt((varA * varB) / (varA + varB))

    Params(new_mean, new_sd)
  }

  /** Human-readable summary of the parameters. */
  def describe(params: Any): String =
    s"Gauss(mean: ${params.asInstanceOf[Params].mean}, std-dev: ${params.asInstanceOf[Params].sd})"

  // /*
  //  Comparison < prototype.
  //
  //  Original idea:
  //  i) Take in a general parameter 'other'
  //  ii) match it to its true type
  //  iii) return the result of the comparison which could be an error
  //  Difficulties:
  //  */
  // def lt(self: Params, other: Any): Boolean = other match {
  //   case i: Integer => this < i.asInstanceOf[Float]
  //   case f: Float => f match {
  //     case y if (this.mean + (this.sd * 3) < y) => true //assumption here that 0.3% of the probability space will not matter
  //     case _ => false
  //   }
  //   case d: Double => this < d.asInstanceOf[Float]
  //   case g: Gauss => g match {
  //     case y if (y.mean >= this.mean && y.sd >= this.sd) => false
  //     case n if (n.mean < this.mean && n.sd < this.sd) => true
  //     case _ => ???//here is where we make an estimate and give a confidence level
  //   }
  //   case _ => ???// this should raise an exception of incompatible types; how to do this in spark/scala ecosystem?
  // }

  // def >=(other: Any): Boolean = !(this < other)
  // def >(other: Any): Boolean = other match {
  //   case i: Integer => this > i.asInstanceOf[Float]
  //   case f: Float => f match {
  //     case y if (this.mean + (3 * this.sd) > y) => true
  //     case n if (this.mean + (3 * this.sd) <= n )=> false
  //     case _ => ???//here is where we make an estimate and give a confidence level
  //   }
  //   case d: Double => this > d.asInstanceOf[Float]
  //   case _ => ???//incompatible types exception
  // }
  // def <=(other: Any): Boolean = !(this > other)

}
|