mimir-pip/lib/src/org/mimirdb/pip/distribution/numerical/Clamp.scala

145 lines
4.8 KiB
Scala

package org.mimirdb.pip.distribution.numerical
import scala.util.Random
import java.io.ObjectOutputStream
import java.io.ObjectInputStream
import org.mimirdb.pip.udt.UnivariateDistribution
import org.apache.spark.sql.functions
import org.mimirdb.pip.udt.UnivariateDistributionConstructor
import org.apache.spark.sql.catalyst.expressions.Expression
import org.mimirdb.pip.distribution.DistributionFamily
/**
 * Distribution family that restricts another numerical distribution to an
 * optional lower and/or upper bound.
 *
 * The parameters of this family are [[Clamp.Params]]: the wrapped family, the
 * wrapped family's own parameters, and the two optional bounds.
 */
object Clamp
  extends NumericalDistributionFamily
{
  /**
   * Build clamp parameters around an existing distribution.
   *
   * @param col   the distribution to clamp; its family is cast to
   *              NumericalDistributionFamily
   * @param low   lower bound
   * @param high  upper bound
   * @return      parameters wrapping `col` with both bounds set
   */
  def apply(col: UnivariateDistribution, low: Double, high: Double): Params =
    Params(
      family = col.family.asInstanceOf[NumericalDistributionFamily],
      params = col.params,
      low = Some(low),
      high = Some(high)
    )

  /**
   * Parameters of a clamped distribution.
   *
   * @param family  family of the wrapped distribution
   * @param params  parameters of the wrapped distribution (as understood by `family`)
   * @param low     lower bound, or None if unbounded below
   * @param high    upper bound, or None if unbounded above
   */
  case class Params(
    family: NumericalDistributionFamily,
    params: Any,
    low: Option[Double],
    high: Option[Double]
  )

  /**
   * Human-readable description of the clamp and of the wrapped distribution.
   * A missing bound is rendered as an open, infinite end of the interval.
   */
  def describe(params: Any): String =
  {
    val clamp = params.asInstanceOf[Params]
    // An absent lower bound is negative infinity (the original text omitted the sign).
    val lowText = clamp.low.map { "[" + _ }.getOrElse { "(-∞" }
    val highText = clamp.high.map { _.toString + "]" }.getOrElse { "∞)" }
    s"Clamp at $lowText, $highText of ${clamp.family.describe(clamp.params)}"
  }

  /**
   * Smallest value of the clamped distribution: the larger of the lower bound
   * (if any) and the wrapped distribution's own minimum.
   */
  def min(params: Any): Double =
  {
    val clamp = params.asInstanceOf[Params]
    val childMin = clamp.family.min(clamp.params)
    clamp.low.fold(childMin) { Math.max(_, childMin) }
  }

  /**
   * Largest value of the clamped distribution: the smaller of the upper bound
   * (if any) and the wrapped distribution's own maximum.
   */
  def max(params: Any): Double =
  {
    val clamp = params.asInstanceOf[Params]
    // Must consult the wrapped family's max (not its min) for the upper end.
    val childMax = clamp.family.max(clamp.params)
    clamp.high.fold(childMax) { Math.min(_, childMax) }
  }

  /**
   * Draw one sample from the clamped distribution.
   *
   * If the wrapped family supports both a CDF and an inverse CDF, a uniform
   * draw is taken between the CDF values of the two bounds and mapped back
   * through the inverse CDF.  Otherwise the wrapped family's `rejectionSample`
   * is used, accepting only values inside the bounds.
   *
   * @throws RuntimeException if rejection sampling yields no accepted value
   */
  def sample(params: Any, random: Random): Any =
  {
    val clamp = params.asInstanceOf[Params]
    // Dispatch on the wrapped *family*: the Params wrapper itself never
    // carries the CDF/ICDF capabilities.
    clamp.family match {
      case c:(CDFSupported with ICDFSupported) =>
        {
          val lowBound = clamp.low.map { c.cdf(_, clamp.params) }.getOrElse { 0.0 }
          val highBound = clamp.high.map { c.cdf(_, clamp.params) }.getOrElse { 1.0 }
          val samplePt = random.nextDouble() * (highBound - lowBound) + lowBound
          // The inverse CDF belongs to the wrapped family, so it needs the
          // wrapped family's parameters rather than the Clamp parameters.
          c.icdf(samplePt, clamp.params)
        }
      case _ =>
        {
          // 1000 is passed through as in the original; presumably the
          // attempt limit of rejectionSample -- TODO confirm.
          clamp.family.rejectionSample(1000, clamp.params, random){
            case v:Double =>
              clamp.low.forall { _ <= v } && clamp.high.forall { v <= _ }
          }.getOrElse {
            throw new RuntimeException(s"Aborting sampling from rare event: ${clamp.family.describe(clamp.params)}")
          }
        }
    }
  }

  /**
   * Write the clamp parameters: low bound, high bound, the wrapped family's
   * label, then the wrapped family's own serialized parameters.
   *
   * A missing bound is written as NaN.
   */
  def serialize(out: ObjectOutputStream, params: Any): Unit =
  {
    val clamp = params.asInstanceOf[Params]
    out.writeDouble(clamp.low.getOrElse { Double.NaN })
    out.writeDouble(clamp.high.getOrElse { Double.NaN })
    out.writeUTF(clamp.family.label)
    clamp.family.serialize(out, clamp.params)
  }

  /**
   * Read clamp parameters in the order written by `serialize`.
   *
   * A NaN bound is read back as None (so a bound that was itself NaN does not
   * survive a round trip as Some).
   */
  def deserialize(out: ObjectInputStream): Any =
  {
    // NaN is the on-the-wire marker for "no bound".
    def readBound(): Option[Double] =
    {
      val raw = out.readDouble()
      if(raw.isNaN()) { None } else { Some(raw) }
    }
    val low = readBound()
    val high = readBound()
    val family = DistributionFamily(out.readUTF()).asInstanceOf[NumericalDistributionFamily]
    val params = family.deserialize(out)
    Params(
      family = family,
      params = params,
      low = low,
      high = high
    )
  }

  /**
   * Approximate CDF of the clamped distribution at `value`.
   *
   * When the wrapped family reports a fast CDF, the wrapped CDF is rescaled
   * to the interval between the CDF values of the two bounds; otherwise the
   * inherited implementation is used.
   */
  override def approximateCDF(value: Double, params: Any, samples: Int): Double =
  {
    val clamp = params.asInstanceOf[Params]
    // The wrapped family must be asked about its own parameters.
    if(clamp.family.approximateCDFIsFast(clamp.params))
    {
      val lowBound = clamp.low.map { clamp.family.approximateCDF(_, clamp.params, 1000) }.getOrElse { 0.0 }
      val highBound = clamp.high.map { clamp.family.approximateCDF(_, clamp.params, 1000) }.getOrElse { 1.0 }
      val actual = clamp.family.approximateCDF(value, clamp.params, 1000)
      // println(s"CDF of $value @ Clamp Bounds: [${clamp.low} -> $lowBound, ${clamp.high} -> $highBound]: ${clamp.family.describe(clamp.params)}")
      // Inclusive comparisons give the same results as the rescaling formula
      // at the end points, and avoid 0/0 when both bounds map to the same CDF value.
      if(actual <= lowBound){ 0.0 }
      else if(actual >= highBound){ 1.0 }
      else { (actual - lowBound) / (highBound - lowBound) }
    } else {
      super.approximateCDF(value, params, samples)
    }
  }

  /**
   * The clamped CDF is fast exactly when the wrapped family's CDF is fast.
   */
  override def approximateCDFIsFast(params: Any): Boolean =
  {
    val clamp = params.asInstanceOf[Params]
    clamp.family.approximateCDFIsFast(clamp.params)
  }

  /**
   * Expression-level constructor for clamped distributions.
   *
   * Expects three evaluated arguments: an encoded distribution, the low
   * bound and the high bound (both cast to Double).
   */
  case class Constructor(args: Seq[Expression])
    extends UnivariateDistributionConstructor
  {
    def family = Clamp
    def params(values: Seq[Any]) =
    {
      Clamp(
        col = UnivariateDistribution.decode(values(0)),
        low = values(1).asInstanceOf[Double],
        high = values(2).asInstanceOf[Double]
      )
    }
    def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) =
      copy(args = newChildren)
  }
}