mimir-pip/lib/src/org/mimirdb/pip/distribution/distributionFamily.scala

136 lines
3.4 KiB
Scala

package org.mimirdb.pip.distribution
import java.io.ObjectOutputStream
import scala.util.Random
import java.io.ObjectInputStream
import org.apache.spark.sql.types.{ DataType, DoubleType }
import scala.collection.mutable
import java.util.UUID
/**
 * A random variable distribution.
 *
 * A family knows how to (de)serialize, sample, and describe instances of one kind of
 * distribution. The per-instance parameters are passed around as an opaque `params: Any`
 * whose concrete representation is chosen by each implementation.
 */
trait DistributionFamily
{
  /**
   * The underlying datatype generated by this distribution
   */
  val baseType: DataType

  /**
   * Encode the params for an instance of this distribution.
   *
   * @param out    stream to write the encoded params to
   * @param params the instance params, in this family's own representation
   */
  def serialize(out: ObjectOutputStream, params: Any): Unit

  /**
   * Decode the params for an instance of this distribution (the counterpart of [[serialize]]).
   *
   * @param in stream to read the encoded params from
   * @return the decoded params
   */
  def deserialize(in: ObjectInputStream): Any

  /**
   * Draw a single sample from this distribution.
   *
   * Implementations <b>must</b> draw all of their randomness from the provided `random`
   * generator (never from a global or freshly-constructed one), so that whoever controls
   * the generator controls the sample.
   *
   * @param params the instance params
   * @param random the generator to draw randomness from
   */
  def sample(params: Any, random: scala.util.Random): Any

  /**
   * Draw samples until one satisfies `criterion`, giving up after `maxSamples` draws.
   *
   * @param maxSamples upper bound on the number of draws; a non-positive value draws nothing
   * @param params     the instance params, passed through to [[sample]]
   * @param random     the generator, passed through to [[sample]]
   * @param criterion  acceptance test applied to each draw, in order
   * @return the first accepted sample, or `None` if none of the draws was accepted
   */
  def rejectionSample(maxSamples: Int, params: Any, random: scala.util.Random)
                     (criterion: Any => Boolean): Option[Any] =
    // Iterator.fill evaluates its element lazily, so `find` stops drawing as soon as a
    // sample is accepted -- no non-local `return` out of a closure is needed.
    Iterator.fill(maxSamples)(sample(params, random)).find(criterion)

  /**
   * Generate a summary of an instance of this distribution based on params
   */
  def describe(params: Any): String

  /**
   * A unique label for this distribution; the companion registry stores families under it.
   *
   * Defaults to the lower-cased simple name of the runtime class. For a Scala `object`
   * the compiled class name carries a trailing `$` (e.g. `gaussian$`), so objects that
   * want a clean name should override this.
   */
  def label: String = this.getClass.getSimpleName.toLowerCase
}
/**
 * A [[DistributionFamily]] that specifically samples numbers; its `baseType` is `DoubleType`.
 */
trait NumericalDistributionFamily extends DistributionFamily
{
  val baseType = DoubleType

  /**
   * Compute, or estimate, the CDF `P(X <= value)` for an instance of this distribution.
   *
   * Families that mix in [[CDFSupported]] answer exactly via `cdf`, and `samples` is ignored.
   * Otherwise the CDF is estimated as the fraction of `samples` fresh draws that are
   * `<= value`. Those draws use a new, unseeded generator, so repeated calls may differ.
   *
   * @param value   point at which to evaluate the CDF
   * @param params  the instance params
   * @param samples number of draws for the sampled estimate; must be positive unless the
   *                family mixes in [[CDFSupported]]
   * @return the CDF value (exact) or its sampled estimate
   * @throws IllegalArgumentException if an estimate is needed and `samples <= 0`
   */
  def approximateCDF(value: Double, params: Any, samples: Int): Double =
    this match {
      case exact: CDFSupported => exact.cdf(value, params)
      case _ =>
        // With no draws the ratio below would be 0.0 / 0 = NaN (or -0.0 for a negative
        // count), silently poisoning downstream arithmetic; fail loudly instead.
        require(samples > 0, s"approximateCDF needs a positive sample count, got $samples")
        val rand = new Random()
        val hits = (0 until samples).count { _ =>
          sample(params, rand).asInstanceOf[Double] <= value
        }
        hits.toDouble / samples
    }

  /**
   * Whether [[approximateCDF]] answers through an exact `cdf` (this family mixes in
   * [[CDFSupported]]) rather than by sampling. `params` is currently not consulted.
   */
  def approximateCDFIsFast(params: Any): Boolean = this.isInstanceOf[CDFSupported]

  /**
   * Lower bound of this distribution for `params` -- presumably the minimum of its support;
   * confirm against the implementations.
   */
  def min(params: Any): Double

  /**
   * Upper bound of this distribution for `params` -- presumably the maximum of its support;
   * confirm against the implementations.
   */
  def max(params: Any): Double
}
/**
 * Add-on for a [[NumericalDistributionFamily]] whose CDF can be computed exactly.
 *
 * The `baseType` assertion runs while this trait is being initialized. It therefore only
 * sees the real value if `baseType` was initialized earlier in the mixin order, as when
 * writing `extends NumericalDistributionFamily with CDFSupported`. If `baseType` is set
 * later (by a class body, or by a trait listed after this one), it is still `null` here
 * and the assertion fails.
 */
trait CDFSupported
{
  /**
   * The exact cumulative distribution function for an instance of this distribution.
   *
   * @param value  point at which to evaluate the CDF
   * @param params the instance params
   */
  def cdf(value: Double, params: Any): Double

  /** Must be `DoubleType`; normally provided by [[NumericalDistributionFamily]]. */
  val baseType: DataType

  assert(baseType == DoubleType, "Non-numerical distributions can not support CDFs")
}
/**
 * Add-on for a [[NumericalDistributionFamily]] whose inverse CDF can be computed exactly.
 *
 * The `baseType` assertion runs while this trait is being initialized. It therefore only
 * sees the real value if `baseType` was initialized earlier in the mixin order, as when
 * writing `extends NumericalDistributionFamily with ICDFSupported`. If `baseType` is set
 * later (by a class body, or by a trait listed after this one), it is still `null` here
 * and the assertion fails.
 */
trait ICDFSupported
{
  /**
   * The exact inverse cumulative distribution function for an instance of this distribution.
   *
   * @param value  the argument to the inverse CDF
   * @param params the instance params
   */
  def icdf(value: Double, params: Any): Double

  /** Must be `DoubleType`; normally provided by [[NumericalDistributionFamily]]. */
  val baseType: DataType

  assert(baseType == DoubleType, "Non-numerical distributions can not support ICDFs")
}
/**
 * Companion object for distributions: keeps a registry of all known distribution families,
 * keyed by their lower-cased label, and resolves names to families case-insensitively.
 */
object DistributionFamily
{
  /** Registry of known families, keyed by lower-cased label. A plain, unsynchronized map. */
  val registered = mutable.Map[String, DistributionFamily]()

  /** All registered families, in no particular order. */
  def all: Iterable[DistributionFamily] = registered.values

  /**
   * Add `distribution` to the registry under its lower-cased label. Any family previously
   * registered under the same key is replaced.
   */
  def register(distribution: DistributionFamily): Unit =
    // Lower-case the key so it matches the case-insensitive lookup in `apply`, even when a
    // family overrides `label` with mixed case.
    registered(distribution.label.toLowerCase) = distribution

  /**
   * Look up a registered family by label, ignoring case.
   *
   * @param distribution the label to resolve
   * @return the registered family
   * @throws IllegalArgumentException if no family is registered under that label
   */
  def apply(distribution: String): DistributionFamily =
    registered.getOrElse(
      distribution.toLowerCase,
      throw new IllegalArgumentException(
        // Sorted so the message is stable regardless of hash-map iteration order.
        s"Invalid distribution family '$distribution': Available: ${registered.keys.toSeq.sorted.mkString(", ")}"
      )
    )

  /// Pre-defined distributions
  register(Gaussian)
  register(NumericalMixture)
  register(Clamp)
  register(Discretized)
  register(Uniform)
}