
136 lines
3.4 KiB

package org.mimirdb.pip.distribution
import java.io.ObjectOutputStream
import scala.util.Random
import java.io.ObjectInputStream
import org.apache.spark.sql.types.{ DataType, DoubleType }
import scala.collection.mutable
import java.util.UUID
/**
 * A random variable distribution.
 *
 * The parameters of a concrete distribution instance are passed to every method as an
 * untyped `Any`.
 */
trait DistributionFamily
{
  /**
   * The underlying datatype generated by this distribution
   */
  val baseType: DataType

  /**
   * Encode the params for an instance of this distribution
   *
   * @param out     the stream the encoded params are written to
   * @param params  the params to encode
   */
  def serialize(out: ObjectOutputStream, params: Any): Unit

  /**
   * Decode the params for an instance of this distribution
   *
   * @param in      the stream to read the encoded params from
   * @return        the decoded params
   */
  def deserialize(in: ObjectInputStream): Any

  /**
   * Draw a sample from this distribution.  You <b>must</b> generate random numbers using
   * the provided `random` source.
   *
   * @param params  the params of the distribution instance
   * @param random  the source of randomness to draw from
   * @return        the sampled value
   */
  def sample(params: Any, random: scala.util.Random): Any

  /**
   * Draw samples until a criterion is met.  Abort if no sample is accepted in
   * `maxSamples` draws.
   *
   * @param maxSamples  the maximum number of samples to draw
   * @param params      the params of the distribution instance
   * @param random      the source of randomness handed to [[sample]]
   * @param criterion   the acceptance test applied to each drawn sample
   * @return            the first accepted sample, or None if `maxSamples` draws were
   *                    all rejected (or `maxSamples` is not positive)
   */
  def rejectionSample(maxSamples: Int, params: Any, random: scala.util.Random)
                     (criterion: Any => Boolean): Option[Any] =
    // Iterator.fill evaluates its element lazily, once per step, and find stops at the
    // first accepted value, so no more samples are drawn than necessary.  This avoids a
    // non-local `return` from inside a closure.
    Iterator.fill(maxSamples) { sample(params, random) }
            .find(criterion)

  /**
   * Generate a summary of this distribution based on params
   *
   * @param params  the params of the distribution instance
   */
  def describe(params: Any): String

  /**
   * A unique label for this distribution: the lower-cased simple name of the runtime class.
   *
   * NOTE(review): for a Scala `object` the runtime class name usually ends in `$`, which
   * would then be part of the label -- confirm against the keys callers look up.
   */
  def label: String = this.getClass.getSimpleName.toLowerCase
}
/**
 * A [[DistributionFamily]] that specifically samples numbers
 */
trait NumericalDistributionFamily extends DistributionFamily
{
  val baseType = DoubleType

  /**
   * Compute the CDF at `value`.
   *
   * When this family also mixes in [[CDFSupported]] the exact `cdf` is used and `samples`
   * is ignored.  Otherwise the CDF is estimated as the fraction of `samples` fresh draws
   * that are less than or equal to `value`.  The estimate uses a new, unseeded random
   * source, so it can differ from call to call.
   *
   * @param value    the point at which to evaluate the CDF
   * @param params   the params of the distribution instance
   * @param samples  the number of draws used by the sampling fallback; with 0 the
   *                 fallback divides zero by zero and yields NaN
   * @return         the exact or estimated CDF at `value`
   */
  def approximateCDF(value: Double, params: Any, samples: Int): Double =
    this match {
      case exact: CDFSupported =>
        exact.cdf(value, params)
      case _ =>
        val rng = new scala.util.Random()
        val hits =
          Iterator.fill(samples) { sample(params, rng).asInstanceOf[Double] }
                  .count { drawn => drawn <= value }
        hits.toDouble / samples
    }

  /**
   * True exactly when [[approximateCDF]] will use an exact CDF, i.e. when this family
   * mixes in [[CDFSupported]].
   */
  def approximateCDFIsFast(params: Any): Boolean =
    this match {
      case _: CDFSupported => true
      case _               => false
    }

  /**
   * NOTE(review): presumably the smallest value this distribution can produce for
   * `params` -- confirm against implementations.
   */
  def min(params: Any): Double

  /**
   * NOTE(review): presumably the largest value this distribution can produce for
   * `params` -- confirm against implementations.
   */
  def max(params: Any): Double
}
/**
 * An add-on to NumericalDistributionFamily that indicates an exact CDF can be computed
 */
trait CDFSupported
{
  /**
   * Compute the exact CDF at `value` for the distribution instance described by `params`.
   */
  def cdf(value: Double, params: Any): Double

  /**
   * The datatype generated by the distribution this trait is mixed into.
   */
  val baseType: DataType

  // Checked while this trait is being initialised, so `baseType` has to be assigned by
  // then.  NOTE(review): this likely requires the trait that defines `baseType` to come
  // earlier in the mix-in order -- confirm at use sites.
  assert(baseType == DoubleType, "Non-numerical distributions can not support CDFs")
}
/**
 * An add-on to NumericalDistributionFamily that indicates an exact Inverse CDF can be computed
 */
trait ICDFSupported
{
  /**
   * Compute the exact inverse CDF at `value` for the distribution instance described by
   * `params`.
   */
  def icdf(value: Double, params: Any): Double

  /**
   * The datatype generated by the distribution this trait is mixed into.
   */
  val baseType: DataType

  // Checked while this trait is being initialised, so `baseType` has to be assigned by
  // then.  NOTE(review): this likely requires the trait that defines `baseType` to come
  // earlier in the mix-in order -- confirm at use sites.
  assert(baseType == DoubleType, "Non-numerical distributions can not support ICDFs")
}
/** Companion object for distributions: keeps a registry of all known distributions. */
object DistributionFamily
/**
 * The registry itself: distribution families keyed by label.
 */
val registered: mutable.Map[String, DistributionFamily] = mutable.Map.empty
/**
 * All currently registered distribution families, i.e. the values of `registered`.
 */
def all: Iterable[DistributionFamily] = registered.values
/**
 * Add a distribution family to the registry, keyed by its `label`.  A family already
 * stored under the same label is replaced.
 *
 * @param distribution  the family to register
 */
def register(distribution: DistributionFamily): Unit = {
  registered(distribution.label) = distribution
}
/**
 * Look up a registered distribution family by its label.
 *
 * @param distribution  the label to look up (a key of `registered`)
 * @return              the family registered under that label
 * @throws IllegalArgumentException if nothing is registered under `distribution`; the
 *                                  message lists the labels that are available
 */
def apply(distribution: String): DistributionFamily =
  registered.getOrElse(
    distribution,
    // The default is a by-name argument, so this is only evaluated on a miss.
    throw new IllegalArgumentException(s"Invalid distribution family '$distribution': Available: ${registered.keys.mkString(", ")}")
  )
/// Pre-defined distributions