Cleaning up function registration; Adding Vizier plugin support.

main
Oliver Kennedy 2024-01-28 22:24:02 -05:00
parent 3c65411e05
commit b2799bbac3
Signed by: okennedy
GPG Key ID: 3E5F9B3ABD3FDB60
8 changed files with 115 additions and 25 deletions

View File

@@ -0,0 +1,6 @@
{
"schema_version" : 1,
"plugin_class" : "org.mimirdb.pip.Pip$",
"name" : "Mimir-Pip",
"description" : "Tools for working with probability distributions"
}
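
The descriptor above is presumably what Vizier reads to locate and load the plugin: plugin_class names the JVM class to instantiate. Note the trailing $ in org.mimirdb.pip.Pip$: a Scala object compiles to a class with that suffix, exposing the singleton through a static MODULE$ field. A minimal sketch of how a host could resolve it (the loading code here is invented for illustration; only the MODULE$ convention is standard):

// Hypothetical loader sketch: resolve the "plugin_class" entry by
// reflection. Scala objects expose their singleton instance via MODULE$.
val cls = Class.forName("org.mimirdb.pip.Pip$")
val plugin = cls.getField("MODULE$").get(null)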

View File

@@ -4,6 +4,7 @@ package org.mimirdb.pip
import org.apache.spark.sql.types.UDTRegistration
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.udaf
import org.apache.spark.sql.catalyst.expressions.Expression
import distribution.DistributionFamily
@@ -12,6 +13,7 @@ import distribution.DistributionFamily
*/
object Pip {
/**
* Initialize Pip with the current spark session
*/
@@ -20,10 +22,15 @@ object Pip {
classOf[udt.RandomVariableType].getName())
UDTRegistration.register(classOf[udt.UnivariateDistribution].getName(),
classOf[udt.UnivariateDistributionType].getName())
spark.udf.register("gaussian", distribution.Gaussian.udf)
spark.udf.register("clamp", distribution.Clamp.udf)
spark.udf.register("discretize", distribution.Discretized.udf)
def registerFunction(name: String, fn: (Seq[Expression] => Expression)) =
spark.sessionState
.functionRegistry
.createOrReplaceTempFunction(name, fn, "scala_udf")
registerFunction("gaussian", distribution.Gaussian.Constructor(_))
registerFunction("clamp", distribution.Clamp.Constructor)
registerFunction("discretize", distribution.Discretized.Constructor)
spark.udf.register("entropy", udf.Entropy.udf)
spark.udf.register("kl_divergence", udf.KLDivergence.udf)

View File

@@ -5,21 +5,20 @@ import java.io.ObjectOutputStream
import java.io.ObjectInputStream
import org.mimirdb.pip.udt.UnivariateDistribution
import org.apache.spark.sql.functions
import org.mimirdb.pip.udt.UnivariateDistributionConstructor
import org.apache.spark.sql.catalyst.expressions.Expression
object Clamp
extends NumericalDistributionFamily
{
def apply(col: UnivariateDistribution, low: Double, high: Double): UnivariateDistribution =
UnivariateDistribution(this, Params(
def apply(col: UnivariateDistribution, low: Double, high: Double): Params =
Params(
family = col.family.asInstanceOf[NumericalDistributionFamily],
params = col.params,
low = Some(low),
high = Some(high)
))
def udf = functions.udf(apply(_, _, _))
)
case class Params(
family: NumericalDistributionFamily,
@@ -124,4 +123,22 @@ object Clamp
params.asInstanceOf[Params].family.approximateCDFIsFast(
params.asInstanceOf[Params].params
)
case class Constructor(args: Seq[Expression])
extends UnivariateDistributionConstructor
{
def family = Clamp
def params(values: Seq[Any]) =
{
Clamp(
col = UnivariateDistribution.decode(values(0)),
low = values(1).asInstanceOf[Double],
high = values(2).asInstanceOf[Double],
)
}
def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) =
copy(args = newChildren)
}
}
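
Clamp.apply now returns bare Params; wrapping those params into a UnivariateDistribution and serializing it happens once, in UnivariateDistributionConstructor.eval (shown at the bottom of this commit). Since Constructor is an ordinary Catalyst Expression, constructors compose directly; a sketch using Literal arguments:

// Sketch: the inner Gaussian evaluates to the UDT's serialized bytes,
// which UnivariateDistribution.decode() accepts in Constructor.params.
import org.apache.spark.sql.catalyst.expressions.Literal
val clamped = Clamp.Constructor(Seq(
  Gaussian.Constructor(Seq(Literal(0.0), Literal(1.0))),
  Literal(-2.0), Literal(2.0)
))
val bytes = clamped.eval(null)  // serialized UnivariateDistribution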

View File

@@ -7,6 +7,8 @@ import org.mimirdb.pip.SampleParams
import org.mimirdb.pip.udt.UnivariateDistribution
import org.apache.spark.sql.functions
import scala.collection.Searching._
import org.mimirdb.pip.udt.UnivariateDistributionConstructor
import org.apache.spark.sql.catalyst.expressions.Expression
object Discretized
extends NumericalDistributionFamily
@@ -127,7 +129,7 @@ object Discretized
.forall { case (a, b) => a.low == b.low && a.high == b.high }
}
def apply(base: UnivariateDistribution, bins: Array[Double], samples: Int): UnivariateDistribution =
def apply(base: UnivariateDistribution, bins: Array[Double], samples: Int): Params =
{
assert(bins.size >= 2)
val baseFamily = base.family.asInstanceOf[NumericalDistributionFamily]
@@ -171,11 +173,25 @@ object Discretized
// println(bins.mkString(", "))
check(params)
UnivariateDistribution(
family = this,
params = params
)
return params
}
def udf = functions.udf(apply(_, _, _))
case class Constructor(args: Seq[Expression])
extends UnivariateDistributionConstructor
{
def family = Discretized
def params(values: Seq[Any]) =
Discretized(
base = UnivariateDistribution.decode(values(0)),
bins =
values(1).asInstanceOf[Double].until(
values(2).asInstanceOf[Double],
values(3).asInstanceOf[Double]
).toArray,
samples = values(4).asInstanceOf[Int]
)
def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) =
copy(args = newChildren)
}
}
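
Note that the SQL-facing discretize no longer takes a bin array directly: Constructor.params derives the bin boundaries from scalar (low, high, step) arguments via Scala's fractional ranges, then passes samples through to apply. A small illustration of that derivation:

// until() is end-exclusive, so the upper bound itself is not a boundary:
val bins = 0.0.until(10.0, 2.5).toArray  // Array(0.0, 2.5, 5.0, 7.5)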

View File

@@ -6,8 +6,12 @@ import java.io.ObjectOutputStream
import java.io.ObjectInputStream
import org.apache.commons.math3.special.Erf
import org.apache.spark.sql.Column
import org.apache.spark.sql.types.DoubleType
import org.mimirdb.pip.SampleParams
import org.mimirdb.pip.udt.UnivariateDistribution
import org.mimirdb.pip.udt.UnivariateDistributionType
import org.mimirdb.pip.udt.UnivariateDistributionConstructor
import org.apache.spark.sql.catalyst.expressions.Expression
/**
* The Gaussian (normal) distribution
@@ -57,12 +61,18 @@ object Gaussian
+ params.asInstanceOf[Params].mean
}
def apply(mean: Double, sd: Double): UnivariateDistribution =
UnivariateDistribution(this, Params(mean, sd))
case class Constructor(args: Seq[Expression])
extends UnivariateDistributionConstructor
{
def family = Gaussian
def params(values: Seq[Any]) =
Params(mean = values(0).asInstanceOf[Double],
sd = values(1).asInstanceOf[Double])
def udf = org.apache.spark.sql.functions.udf(apply(_:Double, _:Double))
def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) =
copy(args = newChildren)
}
def apply(mean: Column, sd: Column): Column = udf(mean, sd)
def plus(a: Params, b: Params): Params =
{
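
The hunk is cut off mid-way into plus. Its body is not part of this diff, but for reference, sums of independent Gaussians are again Gaussian, so the textbook implementation would be along these lines (a hypothetical sketch, not the file's actual code):

// Hypothetical: N(m1, s1) + N(m2, s2) = N(m1 + m2, sqrt(s1^2 + s2^2))
// for independent random variables.
Params(mean = a.mean + b.mean, sd = math.sqrt(a.sd * a.sd + b.sd * b.sd))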

View File

@@ -19,7 +19,7 @@ object Entropy
val max = family.max(dist.params)
val step = (max - min) / BUCKETS
val params = Discretized(dist, (min.until(max, step)).toArray, 1000).params
val params = Discretized(dist, (min.until(max, step)).toArray, 1000)
Discretized.entropy(params)
}
}
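
The only change here adapts the call site to Discretized.apply now returning Params directly instead of a wrapped UnivariateDistribution. For context, the quantity computed downstream is ordinary discrete entropy over the bin masses; a minimal sketch, assuming Discretized.entropy has this shape (its body is not in this diff):

// Shannon entropy over bin probabilities, in nats; empty bins contribute 0.
def entropy(binMasses: Seq[Double]): Double =
  -binMasses.filter(_ > 0.0).map(p => p * math.log(p)).sum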

View File

@@ -16,13 +16,13 @@ object KLDivergence
Discretized.klDivergence(target.params, base.params)
case (_:NumericalDistributionFamily, Discretized) =>
Discretized.klDivergence(
Discretized(target, Discretized.bins(base.params), 1000).params,
Discretized(target, Discretized.bins(base.params), 1000),
base.params
)
case (Discretized, _:NumericalDistributionFamily) =>
Discretized.klDivergence(
target.params,
Discretized(base, Discretized.bins(target.params), 1000).params,
Discretized(base, Discretized.bins(target.params), 1000),
)
}
}
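
Same adaptation as in Entropy: the continuous side is discretized onto the other side's bins (with 1000 samples) so the comparison happens bin-by-bin. Assuming Discretized.klDivergence implements the standard discrete definition (it is not shown in this diff), it would look roughly like:

// D_KL(target || base) over matched bins. Zero target mass contributes
// nothing; nonzero target mass against zero base mass diverges to +Infinity.
def klDivergence(target: Seq[Double], base: Seq[Double]): Double =
  target.zip(base).map {
    case (p, _) if p == 0.0 => 0.0
    case (p, q)             => p * math.log(p / q)
  }.sum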

View File

@@ -2,15 +2,29 @@ package org.mimirdb.pip.udt
import org.mimirdb.pip.distribution.DistributionFamily
import org.apache.spark.sql.types.{ DataType, UserDefinedType, BinaryType }
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.InternalRow
import java.util.UUID
@SerialVersionUID(100)
case class UnivariateDistribution(family: DistributionFamily, params: Any)
extends Serializable
{
override def toString(): String =
family.describe(params)
}
class UnivariateDistributionType extends UserDefinedType[UnivariateDistribution]
object UnivariateDistribution
{
def decode(datum: Any): UnivariateDistribution =
datum match {
case ud:UnivariateDistribution => return ud
case ba:Array[Byte] => UnivariateDistributionType.deserialize(ba)
}
}
class UnivariateDistributionType extends UserDefinedType[UnivariateDistribution] with Serializable
{
override def equals(other: Any): Boolean = other match {
@@ -30,7 +44,8 @@ class UnivariateDistributionType extends UserDefinedType[UnivariateDistribution]
return byteBuffer.toByteArray()
}
override def deserialize(datum: Any): UnivariateDistribution = {
override def deserialize(datum: Any): UnivariateDistribution =
{
val bis = new java.io.ByteArrayInputStream(datum.asInstanceOf[Array[Byte]])
val in = new java.io.ObjectInputStream(bis)
val dist = DistributionFamily(in.readUTF())
@@ -45,3 +60,22 @@ class UnivariateDistributionType extends UserDefinedType[UnivariateDistribution]
}
case object UnivariateDistributionType extends org.mimirdb.pip.udt.UnivariateDistributionType
abstract class UnivariateDistributionConstructor
extends Expression with Serializable with CodegenFallback
{
def args: Seq[Expression]
def family: DistributionFamily
def params(args: Seq[Any]): Any
def dataType = UnivariateDistributionType
def nullable = false
def eval(input: InternalRow): Any =
{
val dist = UnivariateDistribution(family, params(args.map { _.eval(input) }))
UnivariateDistributionType.serialize(dist)
}
def children = args
}
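
UnivariateDistributionConstructor is the shared base for Gaussian.Constructor, Clamp.Constructor, and Discretized.Constructor: subclasses supply family and params, while the base class evaluates the argument expressions, wraps the result in a UnivariateDistribution, and serializes it (interpreted evaluation only, via CodegenFallback). Together with the new decode helper, values round-trip cleanly; a sketch:

// Sketch: eval() yields the UDT's serialized bytes; decode() accepts
// either those bytes or an already-decoded UnivariateDistribution.
import org.apache.spark.sql.catalyst.expressions.Literal
val bytes = Gaussian.Constructor(Seq(Literal(0.0), Literal(1.0))).eval(null)
val dist = UnivariateDistribution.decode(bytes)
println(dist)  // toString delegates to family.describe(params)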