Initial commit of cleaned up version

Oliver Kennedy 2024-07-24 17:05:07 -04:00
commit 67cae8f6b4
Signed by: okennedy
GPG key ID: 3E5F9B3ABD3FDB60
10 changed files with 386 additions and 0 deletions

2
.gitignore vendored Normal file
View file

@@ -0,0 +1,2 @@
.metals
/out

30
build.sc Normal file
View file

@@ -0,0 +1,30 @@
import $ivy.`com.lihaoyi::mill-contrib-bloop:$MILL_VERSION`
import mill._
import mill.main._
import mill.scalalib._
import coursier.maven.{ MavenRepository }

/*************************************************
 *** The Mimir Histogram Library
 *************************************************/
object mimir_histogram extends RootModule with ScalaModule {
  val VERSION = "0.0.1-SNAPSHOT"

  def scalaVersion = "2.12.15"

  def ivyDeps = Agg(
    ivy"org.apache.spark::spark-sql:3.3.1",
    ivy"org.apache.spark::spark-core:3.3.1",
    ivy"org.apache.spark::spark-mllib:3.3.1",
    ivy"org.apache.commons:commons-math3:3.6.1"
  )

  object test extends ScalaTests with TestModule {
    def ivyDeps = Agg(
      ivy"org.scalactic::scalactic:3.2.17",
      ivy"org.scalatest::scalatest:3.2.17"
    )
    def testFramework = "org.scalatest.tools.Framework"
  }
}

View file

@@ -0,0 +1,251 @@
package org.mimirdb.histogram

import org.apache.spark.sql.types.{ DataType, UserDefinedType, ArrayType, DoubleType }
import org.apache.spark.sql.catalyst.util.ArrayData
import org.mimirdb.histogram.util.{ StringUtils, SeqUtils }
/** A single histogram bucket covering [low, high) with probability mass p. */
case class HistogramBin(low: Double, high: Double, p: Double)

/** A piecewise-constant probability distribution over contiguous bins. */
case class Histogram(bins: Array[HistogramBin])
{
  /** Sanity-check the histogram: bins must be non-empty, contiguous, and sum to ~1. */
  def check(): Histogram =
  {
    assert(!bins.isEmpty)
    assert(
      Math.abs(bins.map { _.p }.sum - 1.0) < Histogram.ACCURACY,
      s"Unexpected bin boundaries: ${bins.map { _.p }.sum} = ${bins.map { _.p }.mkString(" + ")}"
    )
    var curr = bins.head.high
    for(x <- bins.tail)
    {
      assert(x.low < x.high)
      assert(x.low == curr)
      curr = x.high
    }
    return this
  }

  /** Draw a value from the distribution by inverse-transform sampling over the bins. */
  def sample(random: scala.util.Random): Double =
  {
    var x = random.nextDouble()
    var currentBins = bins
    while(x > currentBins.head.p && currentBins.size > 1){
      x -= currentBins.head.p
      currentBins = currentBins.tail
    }
    x /= currentBins.head.p
    return x * (currentBins.head.high - currentBins.head.low) + currentBins.head.low
  }

  /** Cumulative probability of observing a value at or below `value`.
   *  (`params` and `leadingEdge` are currently unused.) */
  def cdf(value: Double, params: Any, leadingEdge: Boolean = false): Double =
  {
    bins.map { bin =>
      if(value >= bin.high){ bin.p }
      else if(value >= bin.low){
        bin.p * ((value - bin.low)/(bin.high - bin.low))
      }
      else { 0 }
    }.sum
  }

  /** KL-divergence from this histogram to `target`; both must share bin boundaries. */
  def klDivergence(target: Histogram): Double =
  {
    assert(sameBinsAs(target))
    bins.zip(target.bins)
        .map { case (t:HistogramBin, b:HistogramBin) =>
          if(t.p > 0){
            if(b.p > 0){
              t.p * Math.log(t.p / b.p)
            } else {
              Double.PositiveInfinity
            }
          } else { 0 }
        }.sum
  }

  /** Shannon entropy of the bin probabilities (in nats). */
  def entropy: Double =
    bins.map { bin =>
      if(bin.p > 0){ - Math.log(bin.p) * bin.p }
      else { 0 }
    }.sum

  def min: Double = bins.head.low
  def max: Double = bins.last.high

  def binBoundaries: Array[Double] =
    (bins.head.low +: bins.map { _.high }).toArray

  /** True if `other` has exactly the same bin boundaries. */
  def sameBinsAs(other: Histogram): Boolean =
  {
    (bins.size == other.bins.size) &&
      bins.zip(other.bins).forall {
        case (a, b) => a.low == b.low && a.high == b.high
      }
  }

  override def toString(): String =
  {
    if(bins.isEmpty) { "[empty histogram]" }
    else {
      s"[${bins.head.low}, ${bins.last.high}] -> [${StringUtils.consoleHistogram(bins.map { _.p })}]"
    }
  }
}
object Histogram
{
  val ACCURACY = 0.0001

  /** Bin-wise average of several histograms that share the same bin boundaries. */
  def average(histograms: Seq[Histogram]): Histogram =
  {
    // Guard against windows of size < 2 (e.g. when only one histogram is given).
    for(p <- histograms.sliding(2) if p.size == 2) {
      assert( p(0) sameBinsAs p(1) )
    }
    Histogram(
      SeqUtils.pivot(histograms.map { _.bins.toSeq }).map { vs =>
        HistogramBin(
          low = vs(0).low,
          high = vs(0).high,
          p = vs.map { _.p }.sum / vs.size
        )
      }.toArray
    )
  }
}
/** Spark SQL UserDefinedType that stores a Histogram as a flat array of doubles:
 *  [ low_0, p_0, high_0, p_1, high_1, ..., p_n, high_n ]
 *  (each bin's low is the preceding bin's high). */
class HistogramType extends UserDefinedType[Histogram] with Serializable
{
  override def equals(other: Any): Boolean = other match {
    case _: UserDefinedType[_] => other.isInstanceOf[HistogramType]
    case _ => false
  }

  def sqlType: DataType = ArrayType(DoubleType)
  def userClass = classOf[Histogram]

  def deserialize(datum: Any): Histogram =
  {
    datum match {
      case h: Histogram => h
      case a: Array[Double] =>
        // n bins serialize to 2n+1 doubles, so the array length must be odd.
        assert(a.size % 2 == 1)
        Histogram(
          (0 until (a.size / 2)).map { i =>
            HistogramBin(a(i*2), a(i*2+2), a(i*2+1))
          }.toArray
        )
      case a: ArrayData =>
        deserialize(a.toDoubleArray)
    }
  }

  def serialize(datum: Histogram): ArrayData =
  {
    ArrayData.toArrayData(
      (
        datum.bins(0).low +:
        datum.bins.flatMap { bin =>
          Seq(bin.p, bin.high)
        }
      ).toArray
    )
  }

  // def apply(base: UnivariateDistribution, bins: Array[Double], samples: Int): Params =
  // {
  //   assert(bins.size >= 2)
  //   val baseFamily = base.family.asInstanceOf[NumericalDistributionFamily]
  //   val params:Params =
  //     if(baseFamily.approximateCDFIsFast(base.params)){
  //       val startCDF = baseFamily.approximateCDF(bins.head, base.params, 1000, leadingEdge = true)
  //       val endCDF = baseFamily.approximateCDF(bins.last, base.params, 1000, leadingEdge = false)
  //       val adjustCDF = endCDF - startCDF
  //       var lastCDF = startCDF
  //       var lastBin = bins.head
  //       assert(adjustCDF > 0, s"Error histogramming $base Using CDF [${bins.head} - ${bins.last}]: $startCDF; $endCDF; $adjustCDF")
  //       // println(s"Fast Path: $startCDF")
  //       bins.tail.map { binHigh =>
  //         val binLow = lastBin
  //         var cdf = baseFamily.approximateCDF(binHigh, base.params, 1000, leadingEdge = false)
  //         val result = Bin(binLow, binHigh, (cdf - lastCDF) / adjustCDF)
  //         lastCDF = baseFamily.approximateCDF(binHigh, base.params, 1000, leadingEdge = true)
  //         lastBin = binHigh
  //         result
  //       }:Params
  //     } else {
  //       // println(s"For $base, sampling histogram")
  //       val counts = Array.fill(bins.size-1)(0)
  //       var missed = 0
  //       for(i <- 0 until samples)
  //       {
  //         val sample = base.family.sample(base, scala.util.Random).asInstanceOf[Double]
  //         val bin = bins.search(sample)
  //         // println(s"Sample: $sample")
  //         if(bin.insertionPoint == 0 || bin.insertionPoint > bins.size){
  //           // println(s"  MISSED")
  //           missed += 1
  //         } else {
  //           counts(bin.insertionPoint - 1) += 1
  //         }
  //       }
  //       counts.zipWithIndex.map { case (count, bin) =>
  //         val binLow = bins(bin)
  //         val binHigh = bins(bin+1)
  //         val cdf = count.toDouble / (samples - missed)
  //         Bin(binLow, binHigh, cdf)
  //       }:Params
  //     }
  //   // println(bins.mkString(", "))
  //   check(params)
  //   return params
  // }

  // def apply(bins: Seq[Double], values: Seq[Double]): Params =
  // {
  //   assert(bins.size == values.size+1)
  //   val total = values.sum
  //   assert(total > 0)
  //   values.indices.map { i =>
  //     Bin(bins(i), bins(i+1), values(i))
  //   }
  // }

  // case class Constructor(args: Seq[Expression])
  //   extends UnivariateDistributionConstructor
  // {
  //   def family = Discretized
  //   def params(values: Seq[Any]) =
  //     Discretized(
  //       base = UnivariateDistribution.decode(values(0)),
  //       bins =
  //         values(1).asInstanceOf[Double].until(
  //           values(2).asInstanceOf[Double],
  //           values(3).asInstanceOf[Double]
  //         ).toArray,
  //       samples = values(4).asInstanceOf[Int]
  //     )
  //   def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) =
  //     copy(args = newChildren)
  // }

  // case class FromBinsConstructor(args: Seq[Expression])
  //   extends UnivariateDistributionConstructor
  // {
  //   def family = Discretized
  //   def params(values: Seq[Any]) =
  //     Discretized(
  //       bins = values(0).asInstanceOf[ArrayData].toDoubleArray().toSeq,
  //       values = values(1).asInstanceOf[ArrayData].toDoubleArray().toSeq,
  //     )
  //   def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) =
  //     copy(args = newChildren)
  // }
}

object HistogramType extends HistogramType

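As a sketch of the flat encoding that HistogramType reads and writes, the round trip below uses made-up bin values:

  // Two bins: [0.0, 0.5) with mass 0.25 and [0.5, 1.0) with mass 0.75,
  // encoded as [ low_0, p_0, high_0, p_1, high_1 ].
  val encoded = Array(0.0, 0.25, 0.5, 0.75, 1.0)
  val h = HistogramType.deserialize(encoded).check()
  // h.bins == Array(HistogramBin(0.0, 0.5, 0.25), HistogramBin(0.5, 1.0, 0.75))
  // HistogramType.serialize(h).toDoubleArray sameElements encoded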
View file

@@ -0,0 +1,19 @@
package org.mimirdb.histogram

import org.apache.spark.sql.types.UDTRegistration
import org.apache.spark.sql.SparkSession

object Plugin
{
  def init(spark: SparkSession): Unit =
  {
    UDTRegistration.register(
      classOf[Histogram].getName(),
      classOf[HistogramType].getName()
    )
    spark.udf.register("HIST_FromArray",
      udf.Constructors.fromArrayUdf)
  }
}

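A rough usage sketch, assuming a live SparkSession named `spark`; the SQL query itself is illustrative only:

  // Register the Histogram UDT and the HIST_FromArray UDF, then build a histogram in SQL.
  Plugin.init(spark)
  spark.sql("SELECT HIST_FromArray(array(0.0D, 0.25D, 0.5D, 0.75D, 1.0D)) AS h").show()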
View file

@@ -0,0 +1,13 @@
package org.mimirdb.histogram.udf

import org.apache.spark.sql.functions.udf
import org.mimirdb.histogram._

object Constructors
{
  /** Build (and sanity-check) a Histogram from its flat double-array encoding. */
  def fromArray(elems: Array[Double]): Histogram =
  {
    HistogramType.deserialize(elems).check()
  }

  // Eta-expand explicitly so Spark's `udf` can infer the input and output types.
  val fromArrayUdf = udf(fromArray _)
}

View file

@@ -0,0 +1,12 @@
package org.mimirdb.histogram.util

object SeqUtils
{
  /** Transpose a sequence of equal-length rows into a sequence of columns. */
  def pivot[A](elems: Seq[Seq[A]]): Seq[Seq[A]] =
  {
    if(elems.isEmpty){ return Seq.empty }
    // Fold from the right so each column comes out in the original row order.
    elems.foldRight(elems.head.map { _ => List.empty[A] }){ (vs, lists) =>
      vs.zip(lists).map { case (v, list) => v :: list }
    }
  }
}

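A quick sketch of what pivot is meant to compute (row-major in, column-major out); the values are arbitrary:

  SeqUtils.pivot(Seq(Seq(1, 2, 3), Seq(4, 5, 6)))
  // == Seq(List(1, 4), List(2, 5), List(3, 6))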
View file

@@ -0,0 +1,16 @@
package org.mimirdb.histogram.util

object StringUtils
{
  /** Render a sequence of values in [min, max] as a one-line bar chart using the
   *  Unicode block elements U+2581 (▁) through U+2588 (█). */
  def consoleHistogram(bins: Seq[Double], min: Double = 0.0, max: Double = 1.0): String =
  {
    bins.map { b =>
      // Quantize into 8 levels to match the 8 available block heights.
      val p = Math.min(8, Math.max(0, Math.ceil((b - min) / (max - min) * 8).toInt))
      p match {
        case 0 => " "
        case x => Character.toChars(x + 0x2580).mkString
      }
    }
    .mkString
  }
}

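For example (assumed rendering; exact glyphs depend on the terminal font):

  StringUtils.consoleHistogram(Seq(0.0, 0.25, 0.5, 1.0))
  // With the default [0.0, 1.0] range this renders roughly as " ▂▄█"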
View file

@@ -0,0 +1,6 @@
{
  "schema_version" : 1,
  "plugin_class" : "org.mimirdb.pip.Histogram.VizierPlugin$",
  "name" : "Mimir-Pip",
  "description" : "Tools for working with histograms."
}

View file

@@ -0,0 +1,22 @@
package org.mimirdb.histogram

import org.scalatest.flatspec.AnyFlatSpec
import org.apache.spark.sql.functions._
import org.mimirdb.histogram.udf.Constructors._

class TestHistograms extends AnyFlatSpec {
  import TestResources.spark

  "histograms" should "be constructable and introspectable" in {
    val df = spark.range(0, 10)
    df.select(fromArrayUdf(
        array(df("id").cast("double"),
              lit(1.0),
              (df("id") + 1).cast("double")
        ))
        as "test"
      )
      .show()
  }
}

View file

@@ -0,0 +1,15 @@
package org.mimirdb.histogram

import org.apache.spark.sql.SparkSession

/* Spark Session used across tests */
object TestResources {
  val spark = SparkSession.builder
                          .appName("pip")
                          .master("local[*]")
                          .getOrCreate()
  spark.sparkContext.setLogLevel("WARN")

  Plugin.init(spark)
}