35 lines
1010 B
Scala
35 lines
1010 B
Scala
package org.mimirdb.pip.distribution.boolean
|
|
|
|
import org.apache.spark.sql.catalyst.expressions.Expression
|
|
import org.mimirdb.pip.udt.UnivariateDistributionConstructor
|
|
|
|
/**
 * The Bernoulli distribution over booleans.
 *
 * Its parameter value is a single `Double`: the probability that a sample is
 * `true`. The parameter is passed around as `Any` and cast back on use.
 */
object Bernoulli
  extends BooleanDistributionFamily
  with ProbabilitySupported
{
  /** Recovers the success probability from the opaque parameter value. */
  private def successProbability(params: Any): Double =
    params.asInstanceOf[Double]

  /**
   * @param params the success probability, stored as a `Double`
   * @return the probability that a sample from this distribution is `true`
   */
  def probability(params: Any): Double =
    successProbability(params)

  /** Renders the distribution as `Bernoulli(<params>)`. */
  def describe(params: Any): String =
    s"Bernoulli($params)"

  /**
   * Draws one boolean that is `true` with the stored probability.
   *
   * `nextDouble()` is uniform on [0, 1), so a probability of 0 never yields
   * `true` and a probability of 1 always does.
   */
  def sample(params: Any, random: scala.util.Random): Boolean =
    random.nextDouble() < successProbability(params)

  /** Reads back the parameters written by [[serialize]]: a single `Double`. */
  def deserialize(in: java.io.ObjectInputStream): Any =
    in.readDouble()

  /** Writes the success probability as a single `Double`. */
  def serialize(out: java.io.ObjectOutputStream, params: Any): Unit =
    out.writeDouble(successProbability(params))

  /**
   * Expression node that builds Bernoulli parameters from its arguments.
   *
   * @param args child expressions; only the first one (the success
   *             probability) is read by [[params]]
   */
  case class Constructor(args: Seq[Expression])
    extends UnivariateDistributionConstructor
  {
    def family = Bernoulli

    /**
     * Converts the evaluated first argument into a validated success probability.
     *
     * Any boxed `java.lang.Number` is accepted and widened to `Double`, so
     * integer- or float-valued arguments do not fail with a ClassCastException.
     *
     * @param values evaluated argument values; only `values(0)` is used
     * @return the success probability, guaranteed to lie in [0, 1]
     * @throws IllegalArgumentException if the value is non-numeric, or is
     *         outside [0, 1] (NaN is rejected as well)
     */
    def params(values: Seq[Any]): Double =
    {
      val p: Double = values(0) match {
        // NOTE(review): the original cast unboxed null to 0.0; that behaviour is
        // kept here — confirm a null probability should not propagate as null.
        case null                => 0.0
        case n: java.lang.Number => n.doubleValue
        case other =>
          throw new IllegalArgumentException(
            s"Bernoulli probability must be numeric, got $other"
          )
      }
      // A value outside [0, 1] is not a probability: `probability` would report
      // it verbatim and `sample` would silently become always/never true.
      require(p >= 0.0 && p <= 1.0, s"Bernoulli probability must be in [0, 1], got $p")
      p
    }

    /** Rebuilds this node with `newChildren` as its arguments (Spark tree-transformation hook). */
    def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) =
      copy(args = newChildren)
  }
}