Cleaning up function registration; Adding Vizier plugin support.
parent
3c65411e05
commit
b2799bbac3
|
@ -0,0 +1,6 @@
|
|||
{
|
||||
"schema_version" : 1,
|
||||
"plugin_class" : "org.mimirdb.pip.Pip$",
|
||||
"name" : "Mimir-Pip",
|
||||
"description" : "Tools for working with probability distributions"
|
||||
}
|
|
@ -4,6 +4,7 @@ package org.mimirdb.pip
|
|||
import org.apache.spark.sql.types.UDTRegistration
|
||||
import org.apache.spark.sql.SparkSession
|
||||
import org.apache.spark.sql.functions.udaf
|
||||
import org.apache.spark.sql.catalyst.expressions.Expression
|
||||
import distribution.DistributionFamily
|
||||
|
||||
|
||||
|
@ -12,6 +13,7 @@ import distribution.DistributionFamily
|
|||
*/
|
||||
object Pip {
|
||||
|
||||
|
||||
/**
|
||||
* Initialize Pip with the current spark session
|
||||
*/
|
||||
|
@ -20,10 +22,15 @@ object Pip {
|
|||
classOf[udt.RandomVariableType].getName())
|
||||
UDTRegistration.register(classOf[udt.UnivariateDistribution].getName(),
|
||||
classOf[udt.UnivariateDistributionType].getName())
|
||||
|
||||
spark.udf.register("gaussian", distribution.Gaussian.udf)
|
||||
spark.udf.register("clamp", distribution.Clamp.udf)
|
||||
spark.udf.register("discretize", distribution.Discretized.udf)
|
||||
|
||||
def registerFunction(name: String, fn: (Seq[Expression] => Expression)) =
|
||||
spark.sessionState
|
||||
.functionRegistry
|
||||
.createOrReplaceTempFunction(name, fn, "scala_udf")
|
||||
|
||||
registerFunction("gaussian", distribution.Gaussian.Constructor(_))
|
||||
registerFunction("clamp", distribution.Clamp.Constructor)
|
||||
registerFunction("discretize", distribution.Discretized.Constructor)
|
||||
spark.udf.register("entropy", udf.Entropy.udf)
|
||||
spark.udf.register("kl_divergence", udf.KLDivergence.udf)
|
||||
|
||||
|
|
|
@ -5,21 +5,20 @@ import java.io.ObjectOutputStream
|
|||
import java.io.ObjectInputStream
|
||||
import org.mimirdb.pip.udt.UnivariateDistribution
|
||||
import org.apache.spark.sql.functions
|
||||
import org.mimirdb.pip.udt.UnivariateDistributionConstructor
|
||||
import org.apache.spark.sql.catalyst.expressions.Expression
|
||||
|
||||
object Clamp
|
||||
extends NumericalDistributionFamily
|
||||
{
|
||||
|
||||
def apply(col: UnivariateDistribution, low: Double, high: Double): UnivariateDistribution =
|
||||
UnivariateDistribution(this, Params(
|
||||
def apply(col: UnivariateDistribution, low: Double, high: Double): Params =
|
||||
Params(
|
||||
family = col.family.asInstanceOf[NumericalDistributionFamily],
|
||||
params = col.params,
|
||||
low = Some(low),
|
||||
high = Some(high)
|
||||
))
|
||||
|
||||
def udf = functions.udf(apply(_, _, _))
|
||||
|
||||
)
|
||||
|
||||
case class Params(
|
||||
family: NumericalDistributionFamily,
|
||||
|
@ -124,4 +123,22 @@ object Clamp
|
|||
params.asInstanceOf[Params].family.approximateCDFIsFast(
|
||||
params.asInstanceOf[Params].params
|
||||
)
|
||||
|
||||
case class Constructor(args: Seq[Expression])
|
||||
extends UnivariateDistributionConstructor
|
||||
{
|
||||
def family = Clamp
|
||||
def params(values: Seq[Any]) =
|
||||
{
|
||||
Clamp(
|
||||
col = UnivariateDistribution.decode(values(0)),
|
||||
low = values(1).asInstanceOf[Double],
|
||||
high = values(2).asInstanceOf[Double],
|
||||
)
|
||||
}
|
||||
|
||||
def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) =
|
||||
copy(args = newChildren)
|
||||
}
|
||||
|
||||
}
|
|
@ -7,6 +7,8 @@ import org.mimirdb.pip.SampleParams
|
|||
import org.mimirdb.pip.udt.UnivariateDistribution
|
||||
import org.apache.spark.sql.functions
|
||||
import scala.collection.Searching._
|
||||
import org.mimirdb.pip.udt.UnivariateDistributionConstructor
|
||||
import org.apache.spark.sql.catalyst.expressions.Expression
|
||||
|
||||
object Discretized
|
||||
extends NumericalDistributionFamily
|
||||
|
@ -127,7 +129,7 @@ object Discretized
|
|||
.forall { case (a, b) => a.low == b.low && a.high == b.high }
|
||||
}
|
||||
|
||||
def apply(base: UnivariateDistribution, bins: Array[Double], samples: Int): UnivariateDistribution =
|
||||
def apply(base: UnivariateDistribution, bins: Array[Double], samples: Int): Params =
|
||||
{
|
||||
assert(bins.size >= 2)
|
||||
val baseFamily = base.family.asInstanceOf[NumericalDistributionFamily]
|
||||
|
@ -171,11 +173,25 @@ object Discretized
|
|||
// println(bins.mkString(", "))
|
||||
check(params)
|
||||
|
||||
UnivariateDistribution(
|
||||
family = this,
|
||||
params = params
|
||||
)
|
||||
return params
|
||||
}
|
||||
|
||||
def udf = functions.udf(apply(_, _, _))
|
||||
case class Constructor(args: Seq[Expression])
|
||||
extends UnivariateDistributionConstructor
|
||||
{
|
||||
def family = Discretized
|
||||
def params(values: Seq[Any]) =
|
||||
Discretized(
|
||||
base = UnivariateDistribution.decode(values(0)),
|
||||
bins =
|
||||
values(1).asInstanceOf[Double].until(
|
||||
values(2).asInstanceOf[Double],
|
||||
values(3).asInstanceOf[Double]
|
||||
).toArray,
|
||||
samples = values(4).asInstanceOf[Int]
|
||||
)
|
||||
|
||||
def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) =
|
||||
copy(args = newChildren)
|
||||
}
|
||||
}
|
|
@ -6,8 +6,12 @@ import java.io.ObjectOutputStream
|
|||
import java.io.ObjectInputStream
|
||||
import org.apache.commons.math3.special.Erf
|
||||
import org.apache.spark.sql.Column
|
||||
import org.apache.spark.sql.types.DoubleType
|
||||
import org.mimirdb.pip.SampleParams
|
||||
import org.mimirdb.pip.udt.UnivariateDistribution
|
||||
import org.mimirdb.pip.udt.UnivariateDistributionType
|
||||
import org.mimirdb.pip.udt.UnivariateDistributionConstructor
|
||||
import org.apache.spark.sql.catalyst.expressions.Expression
|
||||
|
||||
/**
|
||||
* The Gaussian (normal) distribution
|
||||
|
@ -57,12 +61,18 @@ object Gaussian
|
|||
+ params.asInstanceOf[Params].mean
|
||||
}
|
||||
|
||||
def apply(mean: Double, sd: Double): UnivariateDistribution =
|
||||
UnivariateDistribution(this, Params(mean, sd))
|
||||
case class Constructor(args: Seq[Expression])
|
||||
extends UnivariateDistributionConstructor
|
||||
{
|
||||
def family = Gaussian
|
||||
def params(values: Seq[Any]) =
|
||||
Params(mean = values(0).asInstanceOf[Double],
|
||||
sd = values(1).asInstanceOf[Double])
|
||||
|
||||
def udf = org.apache.spark.sql.functions.udf(apply(_:Double, _:Double))
|
||||
def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) =
|
||||
copy(args = newChildren)
|
||||
}
|
||||
|
||||
def apply(mean: Column, sd: Column): Column = udf(mean, sd)
|
||||
|
||||
def plus(a: Params, b: Params): Params =
|
||||
{
|
||||
|
|
|
@ -19,7 +19,7 @@ object Entropy
|
|||
val max = family.max(dist.params)
|
||||
val step = (max - min) / BUCKETS
|
||||
|
||||
val params = Discretized(dist, (min.until(max, step)).toArray, 1000).params
|
||||
val params = Discretized(dist, (min.until(max, step)).toArray, 1000)
|
||||
Discretized.entropy(params)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,13 +16,13 @@ object KLDivergence
|
|||
Discretized.klDivergence(target.params, base.params)
|
||||
case (_:NumericalDistributionFamily, Discretized) =>
|
||||
Discretized.klDivergence(
|
||||
Discretized(target, Discretized.bins(base.params), 1000).params,
|
||||
Discretized(target, Discretized.bins(base.params), 1000),
|
||||
base.params
|
||||
)
|
||||
case (Discretized, _:NumericalDistributionFamily) =>
|
||||
Discretized.klDivergence(
|
||||
target.params,
|
||||
Discretized(base, Discretized.bins(target.params), 1000).params,
|
||||
Discretized(base, Discretized.bins(target.params), 1000),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,15 +2,29 @@ package org.mimirdb.pip.udt
|
|||
|
||||
import org.mimirdb.pip.distribution.DistributionFamily
|
||||
import org.apache.spark.sql.types.{ DataType, UserDefinedType, BinaryType }
|
||||
import org.apache.spark.sql.catalyst.expressions.Expression
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import java.util.UUID
|
||||
|
||||
@SerialVersionUID(100)
|
||||
case class UnivariateDistribution(family: DistributionFamily, params: Any)
|
||||
extends Serializable
|
||||
{
|
||||
override def toString(): String =
|
||||
family.describe(params)
|
||||
}
|
||||
|
||||
class UnivariateDistributionType extends UserDefinedType[UnivariateDistribution]
|
||||
object UnivariateDistribution
|
||||
{
|
||||
def decode(datum: Any): UnivariateDistribution =
|
||||
datum match {
|
||||
case ud:UnivariateDistribution => return ud
|
||||
case ba:Array[Byte] => UnivariateDistributionType.deserialize(ba)
|
||||
}
|
||||
}
|
||||
|
||||
class UnivariateDistributionType extends UserDefinedType[UnivariateDistribution] with Serializable
|
||||
{
|
||||
|
||||
override def equals(other: Any): Boolean = other match {
|
||||
|
@ -30,7 +44,8 @@ class UnivariateDistributionType extends UserDefinedType[UnivariateDistribution]
|
|||
return byteBuffer.toByteArray()
|
||||
}
|
||||
|
||||
override def deserialize(datum: Any): UnivariateDistribution = {
|
||||
override def deserialize(datum: Any): UnivariateDistribution =
|
||||
{
|
||||
val bis = new java.io.ByteArrayInputStream(datum.asInstanceOf[Array[Byte]])
|
||||
val in = new java.io.ObjectInputStream(bis)
|
||||
val dist = DistributionFamily(in.readUTF())
|
||||
|
@ -45,3 +60,22 @@ class UnivariateDistributionType extends UserDefinedType[UnivariateDistribution]
|
|||
}
|
||||
|
||||
case object UnivariateDistributionType extends org.mimirdb.pip.udt.UnivariateDistributionType
|
||||
|
||||
abstract class UnivariateDistributionConstructor
|
||||
extends Expression with Serializable with CodegenFallback
|
||||
{
|
||||
def args: Seq[Expression]
|
||||
def family: DistributionFamily
|
||||
def params(args: Seq[Any]): Any
|
||||
|
||||
def dataType = UnivariateDistributionType
|
||||
|
||||
def nullable = false
|
||||
def eval(input: InternalRow): Any =
|
||||
{
|
||||
val dist = UnivariateDistribution(family, params(args.map { _.eval(input) }))
|
||||
UnivariateDistributionType.serialize(dist)
|
||||
}
|
||||
|
||||
def children = args
|
||||
}
|
Loading…
Reference in New Issue