85 lines
2.3 KiB
Scala
85 lines
2.3 KiB
Scala
package org.mimirdb.pip.distribution
|
|
|
|
import java.io.ObjectOutputStream
|
|
import scala.util.Random
|
|
import java.io.ObjectInputStream
|
|
import org.apache.spark.sql.types.{ DataType, DoubleType }
|
|
import scala.collection.mutable
|
|
import java.util.UUID
|
|
|
|
/**
 * A random variable distribution.
 *
 * Every operation receives the parameters of a concrete instance of the distribution as an
 * untyped `params` value; implementations define how that value is encoded, decoded, sampled
 * from, and described.
 */
trait DistributionFamily
{
  /**
   * The underlying datatype generated by this distribution
   */
  val baseType: DataType

  /**
   * Encode the params for an instance of this distribution
   *
   * @param out     the stream that receives the encoded parameters
   * @param params  the parameters to encode
   */
  def serialize(out: ObjectOutputStream, params: Any): Unit

  /**
   * Decode the params for an instance of this distribution
   *
   * @param in  the stream to read the encoded parameters from
   * @return    the decoded parameters
   */
  def deserialize(in: ObjectInputStream): Any

  /**
   * Draw a sample from this distribution.  You <b>must</b> generate random numbers based on the
   * provided `random` generator.
   *
   * @param params  the parameters of the distribution instance to sample from
   * @param random  the source of randomness for this draw
   * @return        the sampled value
   */
  def sample(params: Any, random: scala.util.Random): Any

  /**
   * Draw samples until a criterion is met.  Abort if no sample is reached in N steps
   *
   * @param maxSamples  the maximum number of samples to draw; zero or a negative value draws none
   * @param params      the parameters of the distribution instance to sample from
   * @param random      the source of randomness handed to every call to [[sample]]
   * @param criterion   predicate deciding whether a drawn sample is acceptable
   * @return            the first accepted sample, or `None` if `maxSamples` draws were all
   *                    rejected
   */
  def rejectionSample(maxSamples: Int, params: Any, random: scala.util.Random)
                     (criterion: Any => Boolean): Option[Any] =
    // Iterator.fill evaluates its element lazily, once per `next()`, and `find` stops at the
    // first match, so no sample is drawn beyond the accepted one.  This avoids a nonlocal
    // `return` from inside a closure.
    Iterator.fill(maxSamples)(sample(params, random)).find(criterion)

  /**
   * Generate a summary of this distribution based on params
   *
   * @param params  the parameters of the distribution instance to summarize
   * @return        a textual description of the instance
   */
  def describe(params: Any): String

  /**
   * A unique label for this distribution.  Defaults to the lower-cased simple name of the
   * implementing class.
   *
   * NOTE(review): when the implementation is a Scala `object`, the runtime class name normally
   * ends in `$`, which `getSimpleName` would keep -- confirm the resulting label is the intended
   * one, or override `label` in the implementation.
   */
  def label: String = this.getClass.getSimpleName.toLowerCase
}
|
|
|
|
/**
 * Companion object for distributions: Keeps a registry of all known distributions
 */
object DistributionFamily
{
  /** Registry of distribution families, keyed by each family's `label`. */
  val registered: mutable.Map[String, DistributionFamily] = mutable.Map.empty

  /** Every distribution family currently held in the registry. */
  def all: Iterable[DistributionFamily] = registered.values

  /**
   * Add a distribution family to the registry under its `label`, replacing any family
   * previously stored under the same label.
   *
   * @param distribution  the family to register
   */
  def register(distribution: DistributionFamily): Unit = {
    registered(distribution.label) = distribution
  }

  /**
   * Look up a distribution family by name.  The name is lower-cased before the lookup.
   *
   * @param distribution  the name of the family to retrieve
   * @return              the family registered under the lower-cased name
   * @throws IllegalArgumentException if no family is registered under that name
   */
  def apply(distribution: String): DistributionFamily = {
    val key = distribution.toLowerCase
    registered.get(key) match {
      case Some(family) => family
      case None =>
        throw new IllegalArgumentException(s"Invalid distribution family '$distribution': Available: ${registered.keys.mkString(", ")}")
    }
  }

  /// Pre-defined distributions
  // These run as part of object initialization and must stay below the definition of
  // `registered`, which they write to.
  register(numerical.Gaussian)
  register(numerical.NumericalMixture)
  register(numerical.Clamp)
  register(numerical.Discretized)
  register(numerical.Uniform)
}