Initial commit of cleaned up version
commit 67cae8f6b4
.gitignore (vendored, new file, +2)
@@ -0,0 +1,2 @@
+.metals
+/out
build.sc (new file, +30)
@@ -0,0 +1,30 @@
+import $ivy.`com.lihaoyi::mill-contrib-bloop:$MILL_VERSION`
+import mill._
+import mill.main._
+import mill.scalalib._
+import coursier.maven.{ MavenRepository }
+
+/*************************************************
+*** The Vizier Backend
+*************************************************/
+object mimir_histogram extends RootModule with ScalaModule {
+  val VERSION = "0.0.1-SNAPSHOT"
+
+  def scalaVersion = "2.12.15"
+
+
+  def ivyDeps = Agg(
+    ivy"org.apache.spark::spark-sql:3.3.1",
+    ivy"org.apache.spark::spark-core:3.3.1",
+    ivy"org.apache.spark::spark-mllib:3.3.1",
+    ivy"org.apache.commons:commons-math3:3.6.1"
+  )
+
+  object test extends ScalaTests with TestModule {
+    def ivyDeps = Agg(ivy"org.scalactic::scalactic:3.2.17",
+      ivy"org.scalatest::scalatest:3.2.17"
+    )
+    def testFramework = "org.scalatest.tools.Framework"
+  }
+
+}
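With a standard Mill launcher on the path, the module above should build and test with Mill's usual task syntax, e.g. mill compile and mill test; since the module extends RootModule, its tasks need no module prefix.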
src/org/mimirdb/histogram/Histogram.scala (new file, +251)
@@ -0,0 +1,251 @@
+package org.mimirdb.histogram
+
+import org.apache.spark.sql.types.{ DataType, UserDefinedType, ArrayType, DoubleType }
+import org.apache.spark.sql.catalyst.util.ArrayData
+import org.mimirdb.histogram.util.{ StringUtils, SeqUtils }
+
+case class HistogramBin(low: Double, high: Double, p: Double)
+
+case class Histogram(bins: Array[HistogramBin])
+{
+  def check(): Histogram =
+  {
+    assert(!bins.isEmpty)
+    assert(
+      Math.abs(bins.map { _.p }.sum - 1.0) < Histogram.ACCURACY,
+      s"Unexpected bin boundaries: ${bins.map { _.p }.sum} = ${bins.map { _.p }.mkString(" + ")}"
+    )
+    var curr = bins.head.high
+    for(x <- bins.tail)
+    {
+      assert(x.low < x.high)
+      assert(x.low == curr)
+      curr = x.high
+    }
+    return this
+  }
+
+  def sample(random: scala.util.Random): Double =
+  {
+    var x = random.nextDouble()
+    var currentBins = bins
+    while(x > currentBins.head.p && currentBins.size > 1){
+      x -= currentBins.head.p
+      currentBins = currentBins.tail
+    }
+    x /= currentBins.head.p
+    return x * (currentBins.head.high - currentBins.head.low) + currentBins.head.low
+  }
+
+  def cdf(value: Double, params: Any, leadingEdge: Boolean = false): Double =
+  {
+    bins.map { bin =>
+      if(value >= bin.high){ bin.p }
+      else if(value >= bin.low){
+        bin.p * ((value - bin.low)/(bin.high - bin.low))
+      }
+      else { 0 }
+    }.sum
+  }
+
+  def klDivergence(target: Histogram): Double =
+  {
+    assert(sameBinsAs(target))
+    bins .zip(target.bins)
+         .map { case (t:HistogramBin, b:HistogramBin) =>
+           if(t.p > 0){
+             if(b.p > 0){
+               t.p * Math.log(t.p / b.p)
+             } else {
+               Double.PositiveInfinity
+             }
+           } else { 0 }
+         }.sum
+  }
+
+  def entropy: Double =
+    bins.map { bin =>
+      if(bin.p > 0){ - Math.log(bin.p) * bin.p }
+      else { 0 }
+    }.sum
+
+  def min: Double = bins.head.low
+  def max: Double = bins.last.high
+
+  def binBoundaries: Array[Double] =
+    (bins.head.low +: bins.map { _.high }).toArray
+
+  def sameBinsAs(other: Histogram): Boolean =
+  {
+    (bins.size == other.bins.size) &&
+      bins.zip(other.bins).forall {
+        case (a, b) => a.low == b.low && a.high == b.high
+      }
+  }
+
+  override def toString(): String =
+  {
+    if(bins.isEmpty) { "[empty histogram]" }
+    else {
+      s"[${
+        bins.head.low
+      }, ${
+        bins.last.high
+      }] -> [${StringUtils.consoleHistogram(bins.map { _.p })}]"
+    }
+  }
+}
+
+object Histogram
+{
+  val ACCURACY = 0.0001
+
+  def average(histograms: Seq[Histogram]): Histogram =
+  {
+    // guard: sliding(2) on a single-element Seq yields a 1-element window
+    for(p <- histograms.sliding(2) if p.size > 1) {
+      assert( p(0) sameBinsAs p(1) )
+    }
+    Histogram(
+      SeqUtils.pivot(histograms.map { _.bins.toSeq }).map { vs =>
+        HistogramBin(
+          low = vs(0).low,
+          high = vs(0).high,
+          p = vs.map { _.p }.sum / vs.size
+        )
+      }.toArray
+    )
+  }
+
+}
+
+class HistogramType extends UserDefinedType[Histogram] with Serializable
+{
+
+  override def equals(other: Any): Boolean = other match {
+    case _: UserDefinedType[_] => other.isInstanceOf[HistogramType]
+    case _ => false
+  }
+  def sqlType: DataType = ArrayType(DoubleType)
+  def userClass = classOf[Histogram]
+  def deserialize(datum: Any): Histogram =
+  {
+    datum match {
+      case h: Histogram => h
+      case a: Array[Double] =>
+        assert(a.size % 2 == 1)
+        Histogram(
+          (0 until (a.size / 2)).map { i =>
+            HistogramBin(a(i*2), a(i*2+2), a(i*2+1))
+          }.toArray
+        )
+      case a: ArrayData =>
+        deserialize(a.toDoubleArray)
+    }
+  }
+  def serialize(datum: Histogram): ArrayData =
+  {
+    ArrayData.toArrayData(
+      (
+        datum.bins(0).low +:
+        datum.bins.flatMap { bin =>
+          Seq(bin.p, bin.high)
+        }
+      ).toArray
+    )
+  }
+
+  // def apply(base: UnivariateDistribution, bins: Array[Double], samples: Int): Params =
+  // {
+  //   assert(bins.size >= 2)
+  //   val baseFamily = base.family.asInstanceOf[NumericalDistributionFamily]
+
+  //   val params:Params =
+  //     if(baseFamily.approximateCDFIsFast(base.params)){
+  //       val startCDF = baseFamily.approximateCDF(bins.head, base.params, 1000, leadingEdge = true)
+  //       val endCDF = baseFamily.approximateCDF(bins.last, base.params, 1000, leadingEdge = false)
+  //       val adjustCDF = endCDF - startCDF
+  //       var lastCDF = startCDF
+  //       var lastBin = bins.head
+  //       assert(adjustCDF > 0, s"Error histogramming $base Using CDF [${bins.head} - ${bins.last}]: $startCDF; $endCDF; $adjustCDF")
+  //       // println(s"Fast Path: $startCDF")
+  //       bins.tail.map { binHigh =>
+  //         val binLow = lastBin
+  //         var cdf = baseFamily.approximateCDF(binHigh, base.params, 1000, leadingEdge = false)
+  //         val result = Bin(binLow, binHigh, (cdf - lastCDF) / adjustCDF)
+  //         lastCDF = baseFamily.approximateCDF(binHigh, base.params, 1000, leadingEdge = true)
+  //         lastBin = binHigh
+  //         result
+  //       }:Params
+  //     } else {
+  //       // println(s"For $base, sampling histogram")
+  //       val counts = Array.fill(bins.size-1)(0)
+  //       var missed = 0
+  //       for(i <- 0 until samples)
+  //       {
+  //         val sample = base.family.sample(base, scala.util.Random).asInstanceOf[Double]
+  //         val bin = bins.search(sample)
+  //         // println(s"Sample: $sample")
+  //         if(bin.insertionPoint == 0 || bin.insertionPoint > bins.size){
+  //           // println(s"  MISSED")
+  //           missed += 1
+  //         } else {
+  //           counts(bin.insertionPoint - 1) += 1
+  //         }
+  //       }
+  //       counts.zipWithIndex.map { case (count, bin) =>
+  //         val binLow = bins(bin)
+  //         val binHigh = bins(bin+1)
+  //         val cdf = count.toDouble / (samples - missed)
+  //         Bin(binLow, binHigh, cdf)
+  //       }:Params
+  //     }
+  //   // println(bins.mkString(", "))
+  //   check(params)
+
+  //   return params
+  // }
+  // def apply(bins: Seq[Double], values: Seq[Double]): Params =
+  // {
+  //   assert(bins.size == values.size+1)
+  //   val total = values.sum
+  //   assert(total > 0)
+  //   values.indices.map { i =>
+  //     Bin(bins(i), bins(i+1), values(i))
+  //   }
+  // }
+
+  // case class Constructor(args: Seq[Expression])
+  //   extends UnivariateDistributionConstructor
+  // {
+  //   def family = Discretized
+  //   def params(values: Seq[Any]) =
+  //     Discretized(
+  //       base = UnivariateDistribution.decode(values(0)),
+  //       bins =
+  //         values(1).asInstanceOf[Double].until(
+  //           values(2).asInstanceOf[Double],
+  //           values(3).asInstanceOf[Double]
+  //         ).toArray,
+  //       samples = values(4).asInstanceOf[Int]
+  //     )
+
+  //   def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) =
+  //     copy(args = newChildren)
+  // }
+
+  // case class FromBinsConstructor(args: Seq[Expression])
+  //   extends UnivariateDistributionConstructor
+  // {
+  //   def family = Discretized
+  //   def params(values: Seq[Any]) =
+  //     Discretized(
+  //       bins = values(0).asInstanceOf[ArrayData].toDoubleArray().toSeq,
+  //       values = values(1).asInstanceOf[ArrayData].toDoubleArray().toSeq,
+  //     )
+
+  //   def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) =
+  //     copy(args = newChildren)
+  // }
+}
+
+object HistogramType extends HistogramType
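For orientation, a minimal sketch of using the histogram API above directly, outside of Spark; the bin boundaries and probabilities are illustrative:

    import org.mimirdb.histogram.{ Histogram, HistogramBin }

    // Two contiguous bins over [0, 2) whose probabilities sum to 1.0
    val h = Histogram(Array(
      HistogramBin(0.0, 1.0, 0.25),
      HistogramBin(1.0, 2.0, 0.75)
    )).check()

    h.sample(new scala.util.Random(42)) // a Double drawn from [0, 2)
    h.entropy                           // -(0.25 * ln 0.25 + 0.75 * ln 0.75)
    h.cdf(1.5, null)                    // 0.25 + 0.75 * 0.5 = 0.625

Note that cdf ignores its params argument, so any value may be passed there.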
src/org/mimirdb/histogram/Plugin.scala (new file, +19)
@@ -0,0 +1,19 @@
+package org.mimirdb.histogram
+
+import org.apache.spark.sql.types.UDTRegistration
+import org.apache.spark.sql.SparkSession
+
+object Plugin
+{
+  def init(spark: SparkSession): Unit =
+  {
+    UDTRegistration.register(
+      classOf[Histogram].getName(),
+      classOf[HistogramType].getName()
+    )
+    spark.udf.register("HIST_FromArray",
+      udf.Constructors.fromArrayUdf)
+
+  }
+
+}
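Once init has run, the UDF is available from SQL as well as the DataFrame API. A minimal sketch, assuming a live SparkSession named spark:

    import org.mimirdb.histogram.Plugin

    Plugin.init(spark)
    // One bin [0, 1) with probability 1.0, in the flat array encoding
    spark.sql("SELECT HIST_FromArray(array(0.0d, 1.0d, 1.0d)) AS h").show()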
src/org/mimirdb/histogram/udf/Constructors.scala (new file, +13)
@@ -0,0 +1,13 @@
+package org.mimirdb.histogram.udf
+
+import org.apache.spark.sql.functions.udf
+import org.mimirdb.histogram._
+
+object Constructors
+{
+  def fromArray(elems: Array[Double]): Histogram =
+  {
+    HistogramType.deserialize(elems).check()
+  }
+  val fromArrayUdf = udf(fromArray(_))
+}
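fromArray expects the same flat encoding that HistogramType.deserialize parses: the lowest bin boundary first, then alternating (probability, upper boundary) pairs, so the array length is always odd. An illustrative direct call:

    import org.mimirdb.histogram.udf.Constructors

    // Bins [0, 1) with p = 0.25 and [1, 2) with p = 0.75
    Constructors.fromArray(Array(0.0, 0.25, 1.0, 0.75, 2.0))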
src/org/mimirdb/histogram/util/SeqUtils.scala (new file, +12)
@@ -0,0 +1,12 @@
+package org.mimirdb.histogram.util
+
+object SeqUtils
+{
+  def pivot[A](elems: Seq[Seq[A]]): Seq[Seq[A]] =
+  {
+    if(elems.isEmpty){ return Seq.empty }
+    elems.foldRight(elems.head.map { _ => List.empty[A] }){ (vs, lists) =>
+      vs.zip(lists).map { case (v, list) => v :: list }
+    }
+  }
+}
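pivot is, in effect, a transpose over equal-length inner sequences; for example:

    import org.mimirdb.histogram.util.SeqUtils

    SeqUtils.pivot(Seq(Seq(1, 2), Seq(3, 4), Seq(5, 6)))
    // => Seq(List(1, 3, 5), List(2, 4, 6))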
src/org/mimirdb/histogram/util/StringUtils.scala (new file, +16)
@@ -0,0 +1,16 @@
+package org.mimirdb.histogram.util
+
+object StringUtils
+{
+  def consoleHistogram(bins: Seq[Double], min: Double = 0.0, max: Double = 1.0): String =
+  {
+    bins.map { b =>
+      val p = Math.min(10, (Math.max(0, Math.ceil((b - min) / (max - min) * 10).toInt)))
+      p match {
+        case 0 => " "
+        case x => Character.toChars(x + 0x2580).mkString
+      }
+    }
+    .mkString
+  }
+}
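consoleHistogram maps each value in [min, max] onto a Unicode block glyph (U+2581 and up), rendering values at or below min as a space; this is what gives Histogram.toString its inline bar chart. For example:

    import org.mimirdb.histogram.util.StringUtils

    StringUtils.consoleHistogram(Seq(0.0, 0.3, 0.7, 1.0))
    // => " ▃▇▊" (exact glyphs depend on floating-point rounding)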
src/resources/vizier-plugin.json (new file, +6)
@@ -0,0 +1,6 @@
+{
+  "schema_version" : 1,
+  "plugin_class" : "org.mimirdb.histogram.Plugin$",
+  "name" : "Mimir-Pip",
+  "description" : "Tools for working with histograms."
+}
test/src/org/mimirdb/histogram/TestHistograms.scala (new file, +22)
@@ -0,0 +1,22 @@
+package org.mimirdb.histogram
+
+import org.scalatest.flatspec.AnyFlatSpec
+import org.apache.spark.sql.functions._
+import org.mimirdb.histogram.udf.Constructors._
+
+class TestHistograms extends AnyFlatSpec {
+  import TestResources.spark
+
+  "histograms" should "be constructable and introspectable" in {
+    val df = spark.range(0, 10)
+    df.select(fromArrayUdf(
+        array(df("id").cast("double"),
+              lit(1.0),
+              (df("id") + 1).cast("double")
+        ))
+        as "test"
+      )
+      .show()
+  }
+
+}
test/src/org/mimirdb/histogram/TestResources.scala (new file, +15)
@@ -0,0 +1,15 @@
+package org.mimirdb.histogram
+
+import org.apache.spark.sql.SparkSession
+
+/* Spark Session used across tests */
+object TestResources {
+  val spark = SparkSession.builder
+                          .appName("pip")
+                          .master("local[*]")
+                          .getOrCreate()
+
+  spark.sparkContext.setLogLevel("WARN")
+
+  Plugin.init(spark)
+}