136 lines
3.4 KiB
Scala
136 lines
3.4 KiB
Scala
package org.mimirdb.pip.distribution
|
|
|
|
import java.io.ObjectOutputStream
|
|
import scala.util.Random
|
|
import java.io.ObjectInputStream
|
|
import org.apache.spark.sql.types.{ DataType, DoubleType }
|
|
import scala.collection.mutable
|
|
import java.util.UUID
|
|
|
|
/**
 * A family of random variable distributions.
 *
 * An individual distribution within a family is described by an opaque `params`
 * value whose structure is defined by each concrete family; the family knows how
 * to encode, decode, sample from, and describe those params.
 */
trait DistributionFamily
{
  /**
   * The underlying datatype generated by this distribution
   */
  val baseType: DataType

  /**
   * Encode the params for an instance of this distribution
   *
   * @param out     The stream to write the encoded params to
   * @param params  The params of the distribution instance being encoded
   */
  def serialize(out: ObjectOutputStream, params: Any): Unit

  /**
   * Decode the params for an instance of this distribution
   *
   * @param in  The stream to read encoded params from
   * @return    The decoded params
   */
  def deserialize(in: ObjectInputStream): Any

  /**
   * Draw a sample from this distribution.  Implementations <b>must</b> draw all
   * of their randomness from the provided `random` (rather than from a private
   * generator), so that the caller controls the random stream.
   *
   * @param params  The params of the distribution instance to sample from
   * @param random  The source of randomness to draw from
   * @return        One sample, of the type described by `baseType`
   */
  def sample(params: Any, random: scala.util.Random): Any

  /**
   * Draw samples until one satisfies `criterion`, giving up after `maxSamples`
   * draws.
   *
   * @param maxSamples  The maximum number of samples to draw
   * @param params      The params of the distribution instance to sample from
   * @param random      The source of randomness passed to `sample`
   * @param criterion   Accepts or rejects a single sample
   * @return            The first accepted sample, or None if none of the
   *                    `maxSamples` draws was accepted (always None when
   *                    `maxSamples <= 0`)
   */
  def rejectionSample(maxSamples: Int, params: Any, random: scala.util.Random)
                     (criterion: Any => Boolean): Option[Any] =
    // The iterator is lazy: each sample is drawn only when `find` asks for it,
    // and drawing stops at the first accepted sample.  This replaces a
    // non-local `return` from inside a for-loop closure.
    Iterator.continually(sample(params, random))
      .take(maxSamples)
      .find(criterion)

  /**
   * Generate a summary string of this distribution based on params
   */
  def describe(params: Any): String

  /**
   * A unique label for this distribution; the DistributionFamily companion
   * object registers families under this label.
   *
   * Derived from the runtime class name.  The trailing `$` that the Scala
   * compiler appends to the class of an `object` is stripped.  The name is
   * lower-cased with `Locale.ROOT` so the result does not depend on the JVM's
   * default locale (e.g. a Turkish locale maps "I" to a dotless "ı").
   */
  def label: String =
    this.getClass.getSimpleName
      .stripSuffix("$")
      .toLowerCase(java.util.Locale.ROOT)
}
|
|
|
|
/**
 * A [[DistributionFamily]] whose samples are numbers (`Double`s).
 */
trait NumericalDistributionFamily extends DistributionFamily
{
  val baseType = DoubleType

  /**
   * Compute or estimate the CDF at `value`, i.e. the probability that a sample
   * is `<= value`.
   *
   * Families that mix in [[CDFSupported]] answer exactly via `cdf`, and
   * `samples` is ignored.  For every other family, the CDF is estimated as the
   * fraction of `samples` fresh draws that are `<= value`.  The estimate uses a
   * newly created, unseeded generator, so repeated calls may return different
   * values.
   *
   * NOTE(review): on the estimation path, `samples == 0` yields NaN and a
   * negative `samples` yields -0.0; callers presumably always pass a positive
   * count.
   */
  def approximateCDF(value: Double, params: Any, samples: Int): Double =
    this match {
      case exact: CDFSupported =>
        exact.cdf(value, params)

      case _ =>
        val rng = new scala.util.Random()
        // Draws are produced lazily, one per comparison.
        val hits =
          Iterator.fill(samples)(sample(params, rng))
            .count { drawn => drawn.asInstanceOf[Double] <= value }
        hits.toDouble / samples
    }

  /**
   * True when [[approximateCDF]] is answered exactly by `cdf` (this family mixes
   * in [[CDFSupported]]) rather than by sampling.  `params` is not consulted.
   */
  def approximateCDFIsFast(params: Any): Boolean =
    this match {
      case _: CDFSupported => true
      case _               => false
    }

  /** The minimum value for `params` (presumably the lower bound of the support; confirm against implementations). */
  def min(params: Any): Double

  /** The maximum value for `params` (presumably the upper bound of the support; confirm against implementations). */
  def max(params: Any): Double
}
|
|
|
|
/**
 * Mix-in for a NumericalDistributionFamily that can compute its CDF exactly.
 *
 * The assertion in this trait's body runs when the trait is initialized and
 * reads `baseType`.  It therefore only sees `DoubleType` if the trait that
 * assigns `baseType` is initialized first.  For example,
 * `extends NumericalDistributionFamily with CDFSupported` works.  With the
 * reverse order, `baseType` is still unassigned (null) and the assertion fails.
 */
trait CDFSupported
{
  /** The datatype generated by the distribution; must be `DoubleType`. */
  val baseType: DataType

  /**
   * The exact CDF at `value` for the distribution instance described by `params`.
   */
  def cdf(value: Double, params: Any): Double

  // Runs as part of trait initialization (see the ordering note above).
  assert(baseType == DoubleType, "Non-numerical distributions can not support CDFs")
}
|
|
|
|
/**
 * Mix-in for a NumericalDistributionFamily that can compute its inverse CDF
 * exactly.
 *
 * The assertion in this trait's body runs when the trait is initialized and
 * reads `baseType`.  It therefore only sees `DoubleType` if the trait that
 * assigns `baseType` is initialized first.  For example,
 * `extends NumericalDistributionFamily with ICDFSupported` works.  With the
 * reverse order, `baseType` is still unassigned (null) and the assertion fails.
 */
trait ICDFSupported
{
  /** The datatype generated by the distribution; must be `DoubleType`. */
  val baseType: DataType

  /**
   * The exact inverse CDF at `value` for the distribution instance described by
   * `params` (presumably `value` is a probability in [0, 1] — confirm against
   * implementations).
   */
  def icdf(value: Double, params: Any): Double

  // Runs as part of trait initialization (see the ordering note above).
  assert(baseType == DoubleType, "Non-numerical distributions can not support ICDFs")
}
|
|
|
|
/**
 * Companion object for distributions: keeps a registry of all known
 * distribution families, keyed by their lower-cased label.
 */
object DistributionFamily
{
  /** All registered families, keyed by normalized label. */
  val registered = mutable.Map[String, DistributionFamily]()

  /** Every registered distribution family, in no particular order. */
  def all: Iterable[DistributionFamily] = registered.values

  /**
   * Normalize a family name for use as a registry key.
   *
   * Lower-cases with `Locale.ROOT` so that registration and lookup do not depend
   * on the JVM's default locale (in a Turkish locale, for instance,
   * `"I".toLowerCase` is the dotless "ı").  Both `register` and `apply` go
   * through this helper, so keys and lookups always agree.
   */
  private def normalize(name: String): String =
    name.toLowerCase(java.util.Locale.ROOT)

  /**
   * Add `distribution` to the registry under its normalized label, replacing
   * any family previously registered under the same key.
   */
  def register(distribution: DistributionFamily): Unit =
    registered(normalize(distribution.label)) = distribution

  /**
   * Look up a registered family by name, case-insensitively.
   *
   * @param distribution  The family name to look up
   * @return              The registered family
   * @throws IllegalArgumentException if no family is registered under that name
   */
  def apply(distribution: String): DistributionFamily =
    registered.getOrElse(
      normalize(distribution),
      throw new IllegalArgumentException(s"Invalid distribution family '$distribution': Available: ${registered.keys.mkString(", ")}")
    )

  /// Pre-defined distributions
  register(Gaussian)
  register(NumericalMixture)
  register(Clamp)
  register(Discretized)
  register(Uniform)
}