mimir-pip/lib/src/org/mimirdb/pip/distribution/distributionFamily.scala

85 lines
2.3 KiB
Scala

package org.mimirdb.pip.distribution
import java.io.ObjectOutputStream
import scala.util.Random
import java.io.ObjectInputStream
import org.apache.spark.sql.types.{ DataType, DoubleType }
import scala.collection.mutable
import java.util.UUID
/**
 * A family of random-variable distributions.
 *
 * An implementation describes how to sample, (de)serialize, and summarize an instance of the
 * family. The instance-specific parameters are passed around as an opaque `params` value whose
 * concrete type is chosen by the implementation.
 */
trait DistributionFamily
{
  /**
   * The underlying datatype generated by this distribution
   */
  val baseType: DataType

  /**
   * Encode the params for an instance of this distribution
   */
  def serialize(out: ObjectOutputStream, params: Any): Unit

  /**
   * Decode the params for an instance of this distribution; expected to read back exactly
   * what [[serialize]] wrote.
   */
  def deserialize(in: ObjectInputStream): Any

  /**
   * Draw a sample from this distribution. Implementations <b>must</b> draw all randomness
   * from the provided `random` generator (not from a global or freshly-seeded source), so that
   * the result is determined by the caller-supplied generator state.
   */
  def sample(params: Any, random: scala.util.Random): Any

  /**
   * Draw samples until one satisfies `criterion`, giving up after `maxSamples` draws.
   *
   * @param maxSamples maximum number of samples to draw; a non-positive value draws none
   * @param params     the distribution instance's parameters, forwarded to [[sample]]
   * @param random     the generator forwarded to every [[sample]] call
   * @param criterion  acceptance test applied to each drawn sample
   * @return the first accepted sample, or `None` if no sample was accepted within the budget
   */
  def rejectionSample(maxSamples: Int, params: Any, random: scala.util.Random)
                     (criterion: Any => Boolean): Option[Any] =
  {
    // The iterator is lazy: samples are drawn one at a time, and drawing stops at the first
    // accepted one, so exactly the same samples are consumed as with an early-exit loop.
    // Using `find` avoids a `return` inside the `for` closure, which would compile to a
    // non-local return implemented by throwing an exception.
    (0 until maxSamples).iterator
      .map { _ => sample(params, random) }
      .find(criterion)
  }

  /**
   * Generate a summary of this distribution based on params
   */
  def describe(params: Any): String

  /**
   * A unique label for this distribution, from which [[DistributionFamily.register]] derives
   * its registry key. Defaults to the lower-cased simple class name. For a Scala `object`,
   * the JVM class name ends in `$`, so that character is part of the default label; override
   * to supply a different label.
   */
  def label: String = this.getClass.getSimpleName.toLowerCase
}
/**
 * Companion object for distributions: keeps a registry of all known distribution families,
 * keyed by a normalized form of each family's [[DistributionFamily.label]].
 */
object DistributionFamily
{
  /** Registered families, keyed by normalized label (see `normalize`). */
  val registered = mutable.Map[String, DistributionFamily]()

  /**
   * Normalize a family name for both registration and lookup. The name is lower-cased
   * independently of the JVM default locale, and a trailing `$` is dropped. The `$` is what
   * `getSimpleName` appends for a Scala `object`, so both "gaussian" and "gaussian$" resolve.
   */
  private def normalize(name: String): String =
    name.toLowerCase(java.util.Locale.ROOT).stripSuffix("$")

  /** All registered distribution families (in no particular order). */
  def all: Iterable[DistributionFamily] = registered.values

  /**
   * Add a family to the registry under its normalized label. Any family previously registered
   * under the same normalized label is replaced.
   */
  def register(distribution: DistributionFamily): Unit =
    registered.put(normalize(distribution.label), distribution)

  /**
   * Look up a registered family by name (case-insensitive; a trailing `$` is ignored).
   *
   * @throws IllegalArgumentException if no family is registered under that name
   */
  def apply(distribution: String): DistributionFamily =
    registered.get(normalize(distribution))
      .getOrElse {
        // Sorted so the error message is deterministic across runs.
        throw new IllegalArgumentException(s"Invalid distribution family '$distribution': Available: ${registered.keys.toSeq.sorted.mkString(", ")}")
      }

  /// Pre-defined distributions
  register(numerical.Gaussian)
  register(numerical.NumericalMixture)
  register(numerical.Clamp)
  register(numerical.Discretized)
  register(numerical.Uniform)
}