145 lines
4.8 KiB
Scala
145 lines
4.8 KiB
Scala
package org.mimirdb.pip.distribution.numerical
|
|
|
|
import scala.util.Random
|
|
import java.io.ObjectOutputStream
|
|
import java.io.ObjectInputStream
|
|
import org.mimirdb.pip.udt.UnivariateDistribution
|
|
import org.apache.spark.sql.functions
|
|
import org.mimirdb.pip.udt.UnivariateDistributionConstructor
|
|
import org.apache.spark.sql.catalyst.expressions.Expression
|
|
import org.mimirdb.pip.distribution.DistributionFamily
|
|
|
|
/**
 * Distribution family that restricts another numerical distribution (the
 * "child") to an optionally bounded interval.
 *
 * The parameters of a clamped distribution are a [[Clamp.Params]] value that
 * wraps the child's family and parameters together with the optional lower
 * and upper bounds.
 */
object Clamp
  extends NumericalDistributionFamily
{

  /**
   * Build clamp parameters around an existing distribution, bounded on both
   * sides.
   *
   * @param col   the distribution to clamp; its family is cast to
   *              [[NumericalDistributionFamily]] (a non-numerical family fails
   *              with a `ClassCastException`)
   * @param low   the lower bound
   * @param high  the upper bound
   */
  def apply(col: UnivariateDistribution, low: Double, high: Double): Params =
    Params(
      family = col.family.asInstanceOf[NumericalDistributionFamily],
      params = col.params,
      low = Some(low),
      high = Some(high)
    )

  /**
   * Parameters of a clamped distribution.
   *
   * @param family  the family of the child distribution
   * @param params  the child distribution's own parameters
   * @param low     the lower bound, or `None` if unbounded below
   * @param high    the upper bound, or `None` if unbounded above
   */
  case class Params(
    family: NumericalDistributionFamily,
    params: Any,
    low: Option[Double],
    high: Option[Double]
  )

  /** View untyped family parameters as clamp [[Params]]. */
  private def clampParams(params: Any): Params =
    params.asInstanceOf[Params]

  /**
   * Human-readable description: the clamp interval followed by the child's
   * own description.  A missing bound is rendered as an open infinite end.
   */
  def describe(params: Any): String =
  {
    val clamp = clampParams(params)
    // A missing lower bound is negative infinity (previously printed as "(∞").
    val lower = clamp.low.map { "[" + _ }.getOrElse { "(-∞" }
    val upper = clamp.high.map { _.toString + "]" }.getOrElse { "∞)" }
    "Clamp at " + lower + ", " + upper + " of " + clamp.family.describe(clamp.params)
  }

  /**
   * The smallest value of the clamped distribution: the larger of the lower
   * bound (if any) and the child's minimum.
   */
  def min(params: Any): Double =
  {
    val clamp = clampParams(params)
    val childMin = clamp.family.min(clamp.params)
    clamp.low.fold(childMin) { Math.max(_, childMin) }
  }

  /**
   * The largest value of the clamped distribution: the smaller of the upper
   * bound (if any) and the child's maximum.
   */
  def max(params: Any): Double =
  {
    val clamp = clampParams(params)
    // Must consult the child's max (this previously read the child's min).
    val childMax = clamp.family.max(clamp.params)
    clamp.high.fold(childMax) { Math.min(_, childMax) }
  }

  /**
   * Draw one sample from the clamped distribution.
   *
   * If the child family supports both a CDF and an inverse CDF, a point is
   * drawn uniformly between the CDF values of the two bounds and mapped back
   * through the inverse CDF.  Otherwise the child is rejection-sampled with a
   * budget of 1000 attempts.
   *
   * @throws RuntimeException if rejection sampling yields no accepted value
   */
  def sample(params: Any, random: Random): Any =
  {
    val clamp = clampParams(params)
    // Dispatch on the child *family*: the Params wrapper itself never
    // implements CDFSupported/ICDFSupported.
    clamp.family match {
      case c:(CDFSupported with ICDFSupported) =>
        {
          val lowBound  = clamp.low.map { c.cdf(_, clamp.params) }.getOrElse { 0.0 }
          val highBound = clamp.high.map { c.cdf(_, clamp.params) }.getOrElse { 1.0 }

          val samplePt = random.nextDouble() * (highBound - lowBound) + lowBound

          // The inverse CDF belongs to the child, so it gets the child's params.
          c.icdf(samplePt, clamp.params)
        }

      case _ =>
        {
          clamp.family.rejectionSample(1000, clamp.params, random){
            case v:Double =>
              clamp.low.forall { _ <= v } &&
                clamp.high.forall { v <= _ }
          }.getOrElse {
            throw new RuntimeException(s"Aborting sampling from rare event: ${clamp.family.describe(clamp.params)}")
          }
        }
    }
  }

  /**
   * Write clamp parameters to `out`: the lower bound, the upper bound, the
   * child family's label, and then the child's own serialized parameters.
   * A missing bound is encoded as NaN.
   */
  def serialize(out: ObjectOutputStream, params: Any): Unit =
  {
    val clamp = clampParams(params)
    out.writeDouble(clamp.low.getOrElse { Double.NaN })
    out.writeDouble(clamp.high.getOrElse { Double.NaN })
    out.writeUTF(clamp.family.label)
    clamp.family.serialize(out, clamp.params)
  }

  /**
   * Read clamp parameters in the order written by `serialize`.  A NaN bound is
   * decoded as "no bound".
   *
   * @param out  the stream to read from (an input stream, despite the name)
   */
  def deserialize(out: ObjectInputStream): Any =
  {
    // Reads happen in declaration order, which must mirror `serialize`.
    val low    = Some(out.readDouble()).filterNot { _.isNaN }
    val high   = Some(out.readDouble()).filterNot { _.isNaN }
    val family = DistributionFamily(out.readUTF()).asInstanceOf[NumericalDistributionFamily]
    val params = family.deserialize(out)
    Params(
      family = family,
      params = params,
      low = low,
      high = high
    )
  }

  /**
   * Approximate the CDF of the clamped distribution at `value`.
   *
   * When the child reports a fast CDF, the child's CDF at `value` is rescaled
   * to the span between the child's CDF at the two bounds (0 below the span,
   * 1 above it).  Otherwise this defers to the inherited implementation.
   */
  override def approximateCDF(value: Double, params: Any, samples: Int): Double =
  {
    val clamp = clampParams(params)
    // Ask the child about the child's own params, as approximateCDFIsFast
    // below does (this previously passed the clamp params).
    if(clamp.family.approximateCDFIsFast(clamp.params))
    {
      // NOTE(review): the child is queried with a fixed budget of 1000 rather
      // than `samples`; presumably a fast CDF ignores it -- confirm.
      val lowBound  = clamp.low.map { clamp.family.approximateCDF(_, clamp.params, 1000) }.getOrElse { 0.0 }
      val highBound = clamp.high.map { clamp.family.approximateCDF(_, clamp.params, 1000) }.getOrElse { 1.0 }
      val actual    = clamp.family.approximateCDF(value, clamp.params, 1000)

      // NOTE(review): if lowBound == highBound == actual the final branch
      // divides zero by zero; the intended result for that case is unclear.
      if(actual < lowBound){ 0.0 }
      else if(actual > highBound){ 1.0 }
      else { (actual - lowBound) / (highBound - lowBound) }
    } else {
      super.approximateCDF(value, params, samples)
    }
  }

  /** The clamped CDF is fast exactly when the child's CDF is fast. */
  override def approximateCDFIsFast(params: Any): Boolean =
  {
    val clamp = clampParams(params)
    clamp.family.approximateCDFIsFast(clamp.params)
  }

  /**
   * Expression-level constructor for clamped distributions.
   *
   * @param args  the argument expressions; the evaluated values are consumed
   *              by `params` as (encoded distribution, low, high)
   */
  case class Constructor(args: Seq[Expression])
    extends UnivariateDistributionConstructor
  {
    def family = Clamp

    /**
     * Build clamp parameters from evaluated arguments: `values(0)` is decoded
     * as a distribution, `values(1)` and `values(2)` are cast to `Double` and
     * used as the lower and upper bounds.
     */
    def params(values: Seq[Any]) =
    {
      Clamp(
        col = UnivariateDistribution.decode(values(0)),
        low = values(1).asInstanceOf[Double],
        high = values(2).asInstanceOf[Double]
      )
    }

    def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) =
      copy(args = newChildren)
  }

}