117 lines
3.2 KiB
Scala
117 lines
3.2 KiB
Scala
package org.mimirdb.pip.distribution.numerical
|
|
|
|
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
|
|
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
|
|
import org.apache.spark.sql.expressions.Aggregator
|
|
import org.apache.spark.sql.functions.udaf
|
|
import org.mimirdb.pip.udt.UnivariateDistribution
|
|
import java.util.UUID
|
|
import org.mimirdb.pip.SampleParams
|
|
import org.mimirdb.pip.distribution.DistributionFamily
|
|
|
|
/**
 * A finite mixture of numerical distributions.
 *
 * The parameters of a mixture are a sequence of [[NumericalMixture.Child]]
 * components; each pairs a component family and its parameters with the
 * weight `p` of drawing from that component.
 */
object NumericalMixture
  extends NumericalDistributionFamily
{

  /**
   * One component of a mixture.
   *
   * @param family  The distribution family of this component.
   * @param params  The family-specific parameters of this component.
   * @param p       The weight (probability) of drawing from this component.
   */
  case class Child(family: NumericalDistributionFamily, params: Any, p: Double)

  type Params = Seq[Child]

  /**
   * Draw one sample from the mixture.
   *
   * A uniform value in [0, 1) selects a component by walking the components
   * in order and subtracting each weight until the value falls within one.
   * The last component absorbs any mass left over if the weights sum to less
   * than 1.
   *
   * @param params  A [[Params]] sequence of components.
   * @param random  Source of randomness, also passed to the chosen component.
   * @return        A sample drawn from the selected component.
   * @throws IllegalArgumentException if the mixture has no components.
   */
  def sample(params: Any, random: scala.util.Random): Any =
  {
    val bins = params.asInstanceOf[Params]
    require(bins.nonEmpty, "Cannot sample from a mixture with no components")
    val chosen = selectComponent(bins, random.nextDouble())
    chosen.family.sample(chosen.params, random)
  }

  /**
   * Select the component whose cumulative-weight interval contains `p`.
   * Stops at the last component regardless of the remaining value.
   * `bins` must be non-empty.
   */
  @scala.annotation.tailrec
  private def selectComponent(bins: Params, p: Double): Child =
    // lengthCompare inspects at most two elements, so each step is O(1) even
    // for a List (where `.size` would walk the whole remaining list).
    if (bins.lengthCompare(1) <= 0 || p <= bins.head.p) { bins.head }
    else { selectComponent(bins.tail, p - bins.head.p) }

  /**
   * Human-readable summary: each component's description and weight,
   * formatted as `description->weight` and separated by "; ".
   */
  def describe(params: Any): String =
  {
    val components = params.asInstanceOf[Params].map { bin =>
      s"${bin.family.describe(bin.params)}->${bin.p}"
    }
    s"Mixture of ${components.mkString("; ")}"
  }

  /**
   * Write the mixture to `out`.
   *
   * The format is the component count, followed by, for each component, its
   * weight, its family label, and the family's own serialized parameters.
   */
  def serialize(out: java.io.ObjectOutputStream, params: Any): Unit =
  {
    val bins = params.asInstanceOf[Params]
    out.writeInt(bins.size)
    bins.foreach { bin =>
      out.writeDouble(bin.p)
      out.writeUTF(bin.family.label)
      bin.family.serialize(out, bin.params)
    }
  }

  /**
   * Read a mixture written by [[serialize]].
   *
   * Each component's family is resolved from its label via
   * `DistributionFamily(...)` and must be a [[NumericalDistributionFamily]].
   */
  def deserialize(in: java.io.ObjectInputStream): Any =
  {
    val len = in.readInt()
    (0 until len).map { _ =>
      // Field order must mirror serialize: weight, family label, family params.
      val p = in.readDouble()
      val dist = DistributionFamily(in.readUTF()).asInstanceOf[NumericalDistributionFamily]
      val params = dist.deserialize(in)
      Child(
        family = dist,
        params = params,
        p = p
      )
    }
  }

  /** Smallest of the components' minimum values. Throws if there are no components. */
  def min(params: Any) =
    params.asInstanceOf[Params].map { bin =>
      bin.family.min(bin.params)
    }.min

  /** Largest of the components' maximum values. Throws if there are no components. */
  def max(params: Any) =
    params.asInstanceOf[Params].map { bin =>
      bin.family.max(bin.params)
    }.max

  /**
   * Weighted sum of each component's approximate CDF at `value`.
   * This is exact up to the components' own approximations when the weights
   * sum to 1.
   */
  override def approximateCDF(value: Double, params: Any, samples: Int): Double =
  {
    params.asInstanceOf[Params].map { bin =>
      bin.p * bin.family.approximateCDF(value, bin.params, samples)
    }.sum
  }

  /** The mixture's CDF is fast only if every component's CDF is fast. */
  override def approximateCDFIsFast(params: Any): Boolean =
    params.asInstanceOf[Params].forall { c => c.family.approximateCDFIsFast(c.params) }

  /**
   * Aggregate that combines the distributions in a group into an
   * equally-weighted mixture.
   */
  object UniformAggregate extends Aggregator[UnivariateDistribution, List[UnivariateDistribution], UnivariateDistribution]
  {
    type T = List[UnivariateDistribution]

    def zero: T = List()

    /**
     * Add one input distribution to the buffer.
     * Null inputs (SQL NULLs) are skipped, as standard SQL aggregates do.
     * Previously they were buffered and caused an NPE in `finish`.
     */
    def reduce(buffer: T, dataPoint: UnivariateDistribution): T =
      if (dataPoint == null) { buffer } else { dataPoint +: buffer }

    def merge(b1: T, b2: T): T =
      b1 ++ b2

    /**
     * Build a mixture giving each buffered distribution weight 1/n.
     *
     * With no (non-null) inputs there is no valid mixture. An empty one would
     * fail later in `sample`/`min`/`max`, so SQL NULL is returned instead.
     * Assumes each input's family is numerical (cast below).
     */
    def finish(reduction: T): UnivariateDistribution =
      if (reduction.isEmpty) {
        null
      } else {
        val p = 1.0 / reduction.size
        UnivariateDistribution(
          family = NumericalMixture,
          params =
            reduction.map { dist =>
              Child(dist.family.asInstanceOf[NumericalDistributionFamily], dist.params, p)
            }
        )
      }

    def bufferEncoder: Encoder[T] = ExpressionEncoder()

    def outputEncoder: Encoder[UnivariateDistribution] = ExpressionEncoder()
  }

  /** Spark UDAF wrapping [[UniformAggregate]]. */
  val uniform = udaf(UniformAggregate)

}