70 lines
2.0 KiB
Scala
70 lines
2.0 KiB
Scala
package org.mimirdb.pip.distribution.boolean
|
|
|
|
import org.mimirdb.pip.distribution.DistributionFamily
|
|
import org.mimirdb.pip.distribution.numerical.NumericalDistributionFamily
|
|
import org.mimirdb.pip.distribution.numerical.CDFSupported
|
|
|
|
/**
 * Boolean distribution family modelling the event `lower < X < upper`, where `X`
 * is drawn from a numerical base distribution that is identified by name.
 */
object Between
  extends BooleanDistributionFamily
{
  /**
   * Parameters of a `Between` event.
   *
   * @param lower      exclusive lower bound of the interval
   * @param upper      exclusive upper bound of the interval
   * @param baseDist   name of the base distribution family, resolved through `DistributionFamily(...)`
   * @param baseParams parameters handed to the base distribution family
   */
  case class Params(lower: Double, upper: Double, baseDist: String, baseParams: Any)
  {
    /** The base distribution family, looked up by `baseDist` and cast to a numerical family. */
    def dist =
      DistributionFamily(baseDist).asInstanceOf[NumericalDistributionFamily]

    /** Applies `op` to the base distribution family and its parameters. */
    def apply[A](op: (NumericalDistributionFamily, Any) => A): A =
      op(dist, baseParams)
  }

  /**
   * Probability that a draw from the base distribution falls in `(lower, upper)`.
   *
   * When the base family is `CDFSupported` the value is computed in closed form as
   * `cdf(upper) - cdf(lower)`. NOTE(review): that matches the strict inequalities used
   * by `sample` only if the base distribution has no point mass at the bounds.
   * The result is clamped at 0 so that an empty or inverted interval (`lower >= upper`)
   * or floating-point cancellation never produces a negative probability.
   * Otherwise this defers to the inherited implementation, forwarding `samples`.
   */
  override def approximateProbability(params: Any, samples: Int): Double =
  {
    val config = params.asInstanceOf[Params]
    config.dist match {
      case dist: CDFSupported =>
        // Bind both terms first: a leading `- ...` continuation line is split into a
        // separate statement by Scala 2 semicolon inference.
        val massBelowUpper = dist.cdf(config.upper, config.baseParams)
        val massBelowLower = dist.cdf(config.lower, config.baseParams)
        // An empty interval contains no values; never report negative probability.
        math.max(0.0, massBelowUpper - massBelowLower)
      case _ =>
        super.approximateProbability(params, samples)
    }
  }

  /** The closed-form path is available exactly when the base family is `CDFSupported`. */
  override def approximateProbabilityIsFast(params: Any): Boolean =
    params.asInstanceOf[Params].dist.isInstanceOf[CDFSupported]

  /** Human-readable rendering: `Between(lower < <base description> < upper)`. */
  def describe(params: Any): String =
  {
    val config = params.asInstanceOf[Params]
    s"Between(${config.lower} < ${config { _.describe(_) }} < ${config.upper})"
  }

  /**
   * Draws one value from the base distribution and reports whether it lies strictly
   * between `lower` and `upper`.
   */
  def sample(params: Any, random: scala.util.Random): Boolean =
  {
    val config = params.asInstanceOf[Params]
    // Assumes the base family's `sample` yields a Double — TODO confirm for all numerical families.
    val v: Double = config { _.sample(_, random).asInstanceOf[Double] }
    (v > config.lower) && (v < config.upper)
  }

  /**
   * Reads parameters in the layout written by `serialize`: `lower` (double),
   * `upper` (double), `baseDist` (modified UTF-8), then the base family's own parameters.
   */
  def deserialize(in: java.io.ObjectInputStream): Any =
  {
    val lower = in.readDouble()
    val upper = in.readDouble()
    val baseDist = in.readUTF()
    // The base family must be resolved before its parameters can be decoded.
    val dist = DistributionFamily(baseDist)
    Params(
      lower = lower,
      upper = upper,
      baseDist = baseDist,
      baseParams = dist.deserialize(in)
    )
  }

  /**
   * Writes `lower`, `upper`, the base family name, and then delegates the base
   * parameters to the base family's serializer. Must stay in sync with `deserialize`.
   */
  def serialize(out: java.io.ObjectOutputStream, params: Any): Unit =
  {
    val config = params.asInstanceOf[Params]
    out.writeDouble(config.lower)
    out.writeDouble(config.upper)
    out.writeUTF(config.baseDist)
    config { _.serialize(out, _) }
  }
}