31 lines
1017 B
Scala
31 lines
1017 B
Scala
package org.mimirdb.pip.udf
|
|
|
|
import org.mimirdb.pip.udt.UnivariateDistribution
|
|
import org.mimirdb.pip.distribution.numerical.Discretized
|
|
import org.mimirdb.pip.distribution.numerical.NumericalDistributionFamily
|
|
import org.apache.spark.sql.functions
|
|
|
|
/**
 * KL-divergence between two univariate distributions, exposed both as a
 * plain function ([[apply]]) and as a Spark SQL UDF ([[udf]]).
 *
 * The actual divergence computation is delegated to
 * `Discretized.klDivergence`; this object only arranges for both arguments
 * to be expressed over the same set of bins before delegating.
 */
object KLDivergence
{
  /**
   * Third argument handed to `Discretized(...)` whenever one side has to be
   * re-expressed over the other side's bins.
   *
   * NOTE(review): presumably the number of samples/buckets used for the
   * approximation -- confirm against `Discretized.apply`.
   */
  val BUCKETS = 1000

  /**
   * Compute the divergence of `target` with respect to `base`.
   *
   * Supported combinations of distribution families:
   *  - both `Discretized` over the same bins: parameters are compared as-is;
   *  - numerical `target`, `Discretized` `base`: `target` is first
   *    discretized onto the bins of `base`;
   *  - `Discretized` `target`, numerical `base`: `base` is first discretized
   *    onto the bins of `target`.
   *
   * @param target the distribution passed as first argument to
   *               `Discretized.klDivergence`
   * @param base   the distribution passed as second argument to
   *               `Discretized.klDivergence`
   * @return the value produced by `Discretized.klDivergence`
   * @throws scala.MatchError if the pair of families matches none of the
   *                          cases above
   */
  def apply(target: UnivariateDistribution, base: UnivariateDistribution): Double =
  {
    (target.family, base.family) match {
      // Same binning on both sides: no conversion required.
      case (Discretized, Discretized) if Discretized.sameBins(target.params, base.params) =>
        Discretized.klDivergence(target.params, base.params)

      // Re-express the target over the base's bins, then compare.
      case (_: NumericalDistributionFamily, Discretized) =>
        val baseBins = Discretized.bins(base.params)
        Discretized.klDivergence(
          Discretized(target, baseBins, BUCKETS),
          base.params
        )

      // Re-express the base over the target's bins, then compare.
      case (Discretized, _: NumericalDistributionFamily) =>
        val targetBins = Discretized.bins(target.params)
        Discretized.klDivergence(
          target.params,
          Discretized(base, targetBins, BUCKETS)
        )
    }
  }

  /**
   * [[apply]] wrapped as a Spark SQL user-defined function.
   * Declared as a `def`, so a new wrapper is created on every access.
   */
  def udf = functions.udf(apply(_,_))
}