mimir-pip/lib/src/org/mimirdb/pip/Pip.scala

46 lines
1.6 KiB
Scala

package org.mimirdb.pip
import org.apache.spark.sql.types.UDTRegistration
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.udaf
import org.apache.spark.sql.catalyst.expressions.Expression
import distribution.DistributionFamily
/**
 * Entry points for the Pip plugin.
 */
object Pip {

  /**
   * Initialize Pip with the current spark session.
   *
   * First registers Pip's user-defined types. Then installs Pip's
   * expression-level SQL functions in the session's function registry.
   * Finally registers Pip's UDFs and aggregates through `spark.udf`.
   *
   * @param spark the session that receives Pip's SQL functions
   */
  def init(spark: SparkSession): Unit = {
    registerUserDefinedTypes()

    // Expression builders are installed directly in the session's function
    // registry; `createOrReplaceTempFunction` replaces any temp function
    // already registered under the same name.
    val registry = spark.sessionState.functionRegistry
    def builtin(name: String)(builder: Seq[Expression] => Expression): Unit =
      registry.createOrReplaceTempFunction(name, builder, "scala_udf")

    builtin("gaussian")(distribution.numerical.Gaussian.Constructor(_))
    builtin("uniform")(distribution.numerical.Uniform.Constructor(_))
    builtin("num_const")(distribution.numerical.ConstantNumber.Constructor(_))
    builtin("clamp")(distribution.numerical.Clamp.Constructor)
    builtin("discretize")(distribution.numerical.Discretized.Constructor)

    // Scalar UDFs
    val udfs = spark.udf
    udfs.register("entropy", udf.Entropy.udf)
    udfs.register("kl_divergence", udf.KLDivergence.udf)

    // Aggregates
    udfs.register("uniform_mixture", distribution.numerical.NumericalMixture.uniform)
    udfs.register("histogram", udaf(udf.Histogram))
  }

  /**
   * Pair each Pip user class with its Spark UDT class, by fully-qualified name.
   * Registration goes through the `UDTRegistration` object and does not take
   * the session as an argument.
   */
  private def registerUserDefinedTypes(): Unit =
    Seq(
      classOf[udt.RandomVariable].getName() ->
        classOf[udt.RandomVariableType].getName(),
      classOf[udt.UnivariateDistribution].getName() ->
        classOf[udt.UnivariateDistributionType].getName()
    ).foreach { case (userClassName, udtClassName) =>
      UDTRegistration.register(userClassName, udtClassName)
    }
}