// mimir-pip/lib/src/org/mimirdb/pip/udf/KLDivergence.scala
//
// 31 lines
// 1017 B
// Scala

package org.mimirdb.pip.udf
import org.mimirdb.pip.udt.UnivariateDistribution
import org.mimirdb.pip.distribution.numerical.Discretized
import org.mimirdb.pip.distribution.numerical.NumericalDistributionFamily
import org.apache.spark.sql.functions
/**
 * Kullback-Leibler divergence between two [[UnivariateDistribution]]s.
 *
 * Both sides are compared as [[Discretized]] histograms over one shared set of
 * bins. If only one side is already `Discretized`, the other side is
 * discretized onto those existing bins before comparing.
 */
object KLDivergence
{
  /**
   * Third argument passed to `Discretized(dist, bins, _)` when a
   * non-discretized distribution is projected onto existing bins.
   *
   * NOTE(review): presumably a bucket/sample count — confirm against
   * `Discretized.apply`.
   */
  val BUCKETS: Int = 1000

  /**
   * Computes the KL divergence of `target` relative to `base`.
   *
   * The arguments are forwarded to `Discretized.klDivergence` in the same
   * (target, base) order.
   *
   * Supported family pairings:
   *  - both `Discretized`, with bins that `Discretized.sameBins` accepts;
   *  - a numerical `target` with a `Discretized` `base`: `target` is
   *    discretized onto `base`'s bins;
   *  - a `Discretized` `target` with a numerical `base`: `base` is
   *    discretized onto `target`'s bins.
   *
   * @param target distribution whose divergence is measured
   * @param base   reference distribution
   * @return the divergence computed by `Discretized.klDivergence`
   * @throws IllegalArgumentException if the pair of families is not one of the
   *                                  supported pairings above
   */
  def apply(target: UnivariateDistribution, base: UnivariateDistribution): Double =
  {
    (target.family, base.family) match {
      // Histograms already share bins: compare them directly.
      case (Discretized, Discretized) if Discretized.sameBins(target.params, base.params) =>
        Discretized.klDivergence(target.params, base.params)

      // Project target onto base's bins.
      // NOTE(review): if Discretized is itself a NumericalDistributionFamily,
      // this case also covers two Discretized inputs whose bins differ — confirm.
      case (_: NumericalDistributionFamily, Discretized) =>
        Discretized.klDivergence(
          Discretized(target, Discretized.bins(base.params), BUCKETS),
          base.params
        )

      // Project base onto target's bins.
      case (Discretized, _: NumericalDistributionFamily) =>
        Discretized.klDivergence(
          target.params,
          Discretized(base, Discretized.bins(target.params), BUCKETS)
        )

      // Fail with a descriptive error instead of a bare scala.MatchError.
      case (targetFamily, baseFamily) =>
        throw new IllegalArgumentException(
          s"KLDivergence: unsupported distribution families " +
            s"(target = $targetFamily, base = $baseFamily); " +
            "at least one side must be Discretized and the other numerical"
        )
    }
  }

  /** Spark SQL UDF wrapping [[apply]]; takes (target, base) columns and yields a Double. */
  def udf = functions.udf(apply(_, _))
}