46 lines
1.6 KiB
Scala
46 lines
1.6 KiB
Scala
package org.mimirdb.pip
|
|
|
|
|
|
import org.apache.spark.sql.types.UDTRegistration
|
|
import org.apache.spark.sql.SparkSession
|
|
import org.apache.spark.sql.functions.udaf
|
|
import org.apache.spark.sql.catalyst.expressions.Expression
|
|
import distribution.DistributionFamily
|
|
|
|
|
|
/**
 * Entry points for the Pip plugin.
 *
 * The plugin is attached to a Spark session through [[Pip.init]], which installs
 * Pip's user-defined types and its named functions.
 */
object Pip {

  /**
   * Initialize Pip with the current spark session.
   *
   * Steps, in order:
   *  1. map the random-variable and univariate-distribution classes to their UDTs
   *     via `UDTRegistration` (a static call that does not go through `spark`);
   *  2. install the distribution constructors as expression builders in the
   *     session's function registry;
   *  3. register the scalar UDFs and the aggregates through `spark.udf`.
   *
   * NOTE(review): the UDT step is not guarded against repeated calls; confirm that
   * re-running `init` (e.g. for a second session) is harmless.
   *
   * @param spark the session that receives the Pip functions
   */
  def init(spark: SparkSession): Unit = {
    import distribution.numerical.{
      Clamp, ConstantNumber, Discretized, Gaussian, NumericalMixture, Uniform
    }

    // Pair a user-facing class with the UDT class that handles it, by class name.
    def bindUdt(userClass: Class[_], udtClass: Class[_]): Unit =
      UDTRegistration.register(userClass.getName, udtClass.getName)

    bindUdt(classOf[udt.RandomVariable], classOf[udt.RandomVariableType])
    bindUdt(classOf[udt.UnivariateDistribution], classOf[udt.UnivariateDistributionType])

    // Builders receive the unevaluated argument expressions and return the
    // expression the call should be replaced with.
    val registry = spark.sessionState.functionRegistry
    def addBuilder(fnName: String, builder: Seq[Expression] => Expression): Unit =
      registry.createOrReplaceTempFunction(fnName, builder, "scala_udf")

    addBuilder("gaussian", args => Gaussian.Constructor(args))
    addBuilder("uniform", args => Uniform.Constructor(args))
    addBuilder("num_const", args => ConstantNumber.Constructor(args))
    addBuilder("clamp", Clamp.Constructor)
    addBuilder("discretize", Discretized.Constructor)

    // Scalar functions
    spark.udf.register("entropy", udf.Entropy.udf)
    spark.udf.register("kl_divergence", udf.KLDivergence.udf)

    // Aggregates
    spark.udf.register("uniform_mixture", NumericalMixture.uniform)
    spark.udf.register("histogram", udaf(udf.Histogram))
  }
}
|
|
|