mimir-pip/lib/src/org/mimirdb/pip/distribution/numerical/NumericalMixture.scala

117 lines
3.2 KiB
Scala

package org.mimirdb.pip.distribution.numerical
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.udaf
import org.mimirdb.pip.udt.UnivariateDistribution
import java.util.UUID
import org.mimirdb.pip.SampleParams
import org.mimirdb.pip.distribution.DistributionFamily
/**
 * A finite mixture of numerical distributions.
 *
 * Parameters are a `Seq[Child]`: each child pairs a component family and its
 * parameters with a weight `p`, used both to choose a component when sampling
 * and to weight the component CDFs in `approximateCDF`.
 *
 * NOTE(review): nothing in this object checks that the weights sum to 1.
 * `sample` lets the last component absorb any leftover mass, and
 * `approximateCDF` returns the plain weighted sum.
 */
object NumericalMixture
  extends NumericalDistributionFamily
{
  /**
   * One component of the mixture.
   *
   * @param family the component's distribution family
   * @param params parameters understood by `family`
   * @param p      the component's mixture weight
   */
  case class Child(family: NumericalDistributionFamily, params: Any, p: Double)

  type Params = Seq[Child]

  /**
   * Draw one value from the mixture.
   *
   * A single uniform draw in [0, 1) selects a component by walking the
   * cumulative weights in order. The selected component's own sampler is then
   * called with the same `random`.
   *
   * The components are walked with an iterator instead of re-checking
   * `size` on every step. `size` is O(n) on a `List`, which is what
   * `UniformAggregate` produces, so the old loop was quadratic in the number of
   * components.
   *
   * Throws `NoSuchElementException` if the mixture has no components.
   */
  def sample(params: Any, random: scala.util.Random): Any =
  {
    val components = params.asInstanceOf[Params].iterator

    // Move past the current component while the remaining mass exceeds its
    // weight. The last component absorbs whatever is left, for example when
    // rounding leaves the weights summing to slightly less than 1.
    @scala.annotation.tailrec
    def pick(remaining: Double, current: Child): Child =
      if(components.hasNext && remaining > current.p)
        pick(remaining - current.p, components.next())
      else
        current

    // Arguments are evaluated left to right, so the uniform draw happens
    // before the first component is taken. This matches the original order of
    // random-number consumption.
    val chosen = pick(random.nextDouble(), components.next())
    chosen.family.sample(chosen.params, random)
  }

  /** Human-readable form: `Mixture of <component>-><weight>; ...`. */
  def describe(params: Any): String =
  {
    val parts = params.asInstanceOf[Params].map { bin =>
      s"${bin.family.describe(bin.params)}->${bin.p}"
    }
    s"Mixture of ${parts.mkString("; ")}"
  }

  /**
   * Write the mixture to `out` in the following order:
   *  - the component count (Int);
   *  - for each component: its weight (Double), its family label (UTF), and
   *    the family's own serialized parameters.
   */
  def serialize(out: java.io.ObjectOutputStream, params: Any): Unit =
  {
    val bins = params.asInstanceOf[Params]
    out.writeInt(bins.size)
    for(bin <- bins)
    {
      out.writeDouble(bin.p)
      out.writeUTF(bin.family.label)
      bin.family.serialize(out, bin.params)
    }
  }

  /**
   * Inverse of `serialize`.
   *
   * Each component's family is looked up by label through `DistributionFamily`
   * and must be a `NumericalDistributionFamily`; if it is not, the cast throws.
   */
  def deserialize(in: java.io.ObjectInputStream): Any =
  {
    val len = in.readInt()
    (0 until len).map { _ =>
      // Fields are read in the exact order serialize wrote them.
      val p = in.readDouble()
      val dist = DistributionFamily(in.readUTF()).asInstanceOf[NumericalDistributionFamily]
      val params = dist.deserialize(in)
      Child(
        family = dist,
        params = params,
        p = p
      )
    }
  }

  /** Smallest of the components' minimums. Throws on an empty mixture. */
  def min(params: Any) =
    params.asInstanceOf[Params].map { bin =>
      bin.family.min(bin.params)
    }.min

  /** Largest of the components' maximums. Throws on an empty mixture. */
  def max(params: Any) =
    params.asInstanceOf[Params].map { bin =>
      bin.family.max(bin.params)
    }.max

  /**
   * Weighted sum of the components' approximate CDFs at `value`. Each
   * component is estimated with the given number of `samples`.
   */
  override def approximateCDF(value: Double, params: Any, samples: Int): Double =
  {
    params.asInstanceOf[Params].map { bin =>
      bin.p * bin.family.approximateCDF(value, bin.params, samples)
    }.sum
  }

  /** Fast only when every component's CDF is fast (vacuously true when empty). */
  override def approximateCDFIsFast(params: Any): Boolean =
    params.asInstanceOf[Params].forall { c => c.family.approximateCDFIsFast(c.params) }

  /**
   * Spark aggregator that gathers every input distribution into an equally
   * weighted mixture: each of the n inputs gets weight 1/n.
   *
   * Every input family is cast to `NumericalDistributionFamily` in `finish`, so
   * aggregating a non-numerical distribution fails there. With no inputs the
   * result is a mixture with no components, and sampling it throws.
   */
  object UniformAggregate extends Aggregator[UnivariateDistribution, List[UnivariateDistribution], UnivariateDistribution]
  {
    type T = List[UnivariateDistribution]
    def zero: T = List()
    // Prepending is O(1). The resulting order does not matter because every
    // component ends up with the same weight.
    def reduce(buffer: T, dataPoint: UnivariateDistribution): T =
      dataPoint +: buffer
    def merge(b1: T, b2: T): T =
      b1 ++ b2
    def finish(reduction: T): UnivariateDistribution =
    {
      // With an empty reduction this is Infinity, but it is never used because
      // the map below produces no children.
      val p = 1.0 / reduction.size
      UnivariateDistribution(
        family = NumericalMixture,
        params =
          reduction.map { dist =>
            Child(dist.family.asInstanceOf[NumericalDistributionFamily], dist.params, p)
          }
      )
    }
    def bufferEncoder: Encoder[T] = ExpressionEncoder()
    def outputEncoder: Encoder[UnivariateDistribution] = ExpressionEncoder()
  }

  /** `UniformAggregate` wrapped with `udaf` for use as a Spark user-defined aggregate function. */
  val uniform = udaf(UniformAggregate)
}