Initial commit of cleaned up version
This commit is contained in:
commit
67cae8f6b4
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
.metals
|
||||||
|
/out
|
30
build.sc
Normal file
30
build.sc
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
import $ivy.`com.lihaoyi::mill-contrib-bloop:$MILL_VERSION`
|
||||||
|
import mill._
|
||||||
|
import mill.main._
|
||||||
|
import mill.scalalib._
|
||||||
|
import coursier.maven.{ MavenRepository }
|
||||||
|
|
||||||
|
/*************************************************
|
||||||
|
*** The Vizier Backend
|
||||||
|
*************************************************/
|
||||||
|
object mimir_histogram extends RootModule with ScalaModule {
|
||||||
|
val VERSION = "0.0.1-SNAPSHOT"
|
||||||
|
|
||||||
|
def scalaVersion = "2.12.15"
|
||||||
|
|
||||||
|
|
||||||
|
def ivyDeps = Agg(
|
||||||
|
ivy"org.apache.spark::spark-sql:3.3.1",
|
||||||
|
ivy"org.apache.spark::spark-core:3.3.1",
|
||||||
|
ivy"org.apache.spark::spark-mllib:3.3.1",
|
||||||
|
ivy"org.apache.commons:commons-math3:3.6.1"
|
||||||
|
)
|
||||||
|
|
||||||
|
object test extends ScalaTests with TestModule {
|
||||||
|
def ivyDeps = Agg(ivy"org.scalactic::scalactic:3.2.17",
|
||||||
|
ivy"org.scalatest::scalatest:3.2.17"
|
||||||
|
)
|
||||||
|
def testFramework = "org.scalatest.tools.Framework"
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
251
src/org/mimirdb/histogram/Histogram.scala
Normal file
251
src/org/mimirdb/histogram/Histogram.scala
Normal file
|
@ -0,0 +1,251 @@
|
||||||
|
package org.mimirdb.histogram
|
||||||
|
|
||||||
|
import org.apache.spark.sql.types.{ DataType, UserDefinedType, ArrayType, DoubleType }
|
||||||
|
import org.apache.spark.sql.catalyst.util.ArrayData
|
||||||
|
import org.mimirdb.histogram.util.{ StringUtils, SeqUtils }
|
||||||
|
|
||||||
|
case class HistogramBin(low: Double, high: Double, p: Double)
|
||||||
|
|
||||||
|
case class Histogram(bins: Array[HistogramBin])
|
||||||
|
{
|
||||||
|
def check(): Histogram =
|
||||||
|
{
|
||||||
|
assert(!bins.isEmpty)
|
||||||
|
assert(
|
||||||
|
Math.abs(bins.map { _.p }.sum - 1.0) < Histogram.ACCURACY,
|
||||||
|
s"Unexpected bin boundaries: ${bins.map { _.p }.sum} = ${bins.map { _.p }.mkString(" + ")}"
|
||||||
|
)
|
||||||
|
var curr = bins.head.high
|
||||||
|
for(x <- bins.tail)
|
||||||
|
{
|
||||||
|
assert(x.low < x.high)
|
||||||
|
assert(x.low == curr)
|
||||||
|
curr = x.high
|
||||||
|
}
|
||||||
|
return this
|
||||||
|
}
|
||||||
|
|
||||||
|
def sample(random: scala.util.Random): Double =
|
||||||
|
{
|
||||||
|
var x = random.nextDouble()
|
||||||
|
var currentBins = bins
|
||||||
|
while(x > currentBins.head.p && currentBins.size > 1){
|
||||||
|
x -= currentBins.head.p
|
||||||
|
currentBins = currentBins.tail
|
||||||
|
}
|
||||||
|
x /= currentBins.head.p
|
||||||
|
return x * (currentBins.head.high - currentBins.head.low) + currentBins.head.low
|
||||||
|
}
|
||||||
|
|
||||||
|
def cdf(value: Double, params: Any, leadingEdge: Boolean = false): Double =
|
||||||
|
{
|
||||||
|
bins.map { bin =>
|
||||||
|
if(value >= bin.high){ bin.p }
|
||||||
|
else if(value >= bin.low){
|
||||||
|
bin.p * ((value - bin.low)/(bin.high - bin.low))
|
||||||
|
}
|
||||||
|
else { 0 }
|
||||||
|
}.sum
|
||||||
|
}
|
||||||
|
|
||||||
|
def klDivergence(target: Histogram): Double =
|
||||||
|
{
|
||||||
|
assert(sameBinsAs(target))
|
||||||
|
bins .zip(target.bins)
|
||||||
|
.map { case (t:HistogramBin, b:HistogramBin) =>
|
||||||
|
if(t.p > 0){
|
||||||
|
if(b.p > 0){
|
||||||
|
t.p * Math.log(t.p / b.p)
|
||||||
|
} else {
|
||||||
|
Double.PositiveInfinity
|
||||||
|
}
|
||||||
|
} else { 0 }
|
||||||
|
}.sum
|
||||||
|
}
|
||||||
|
|
||||||
|
def entropy: Double =
|
||||||
|
bins.map { bin =>
|
||||||
|
if(bin.p > 0){ - Math.log(bin.p) * bin.p }
|
||||||
|
else { 0 }
|
||||||
|
}.sum
|
||||||
|
|
||||||
|
def min: Double = bins.head.low
|
||||||
|
def max: Double = bins.last.high
|
||||||
|
|
||||||
|
def binBoundaries: Array[Double] =
|
||||||
|
(bins.head.low +: bins.map { _.high }).toArray
|
||||||
|
|
||||||
|
def sameBinsAs(other: Histogram): Boolean =
|
||||||
|
{
|
||||||
|
(bins.size == other.bins.size) &&
|
||||||
|
bins.zip(other.bins).forall {
|
||||||
|
case (a, b) => a.low == b.low && a.high == b.high
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override def toString(): String =
|
||||||
|
{
|
||||||
|
if(bins.isEmpty) { "[empty histogram]" }
|
||||||
|
else {
|
||||||
|
s"[${
|
||||||
|
bins.head.low
|
||||||
|
}, ${
|
||||||
|
bins.last.high
|
||||||
|
}] -> [${StringUtils.consoleHistogram(bins.map { _.p })}]"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
object Histogram
|
||||||
|
{
|
||||||
|
val ACCURACY = 0.0001
|
||||||
|
|
||||||
|
def average(histograms: Seq[Histogram]): Histogram =
|
||||||
|
{
|
||||||
|
for(p <- histograms.sliding(2)) {
|
||||||
|
assert( p(0) sameBinsAs p(1) )
|
||||||
|
}
|
||||||
|
Histogram(
|
||||||
|
SeqUtils.pivot(histograms.map { _.bins.toSeq }).map { vs =>
|
||||||
|
HistogramBin(
|
||||||
|
low = vs(0).low,
|
||||||
|
high = vs(0).high,
|
||||||
|
p = vs.map { _.p }.sum / vs.size
|
||||||
|
)
|
||||||
|
}.toArray
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
class HistogramType extends UserDefinedType[Histogram] with Serializable
|
||||||
|
{
|
||||||
|
|
||||||
|
override def equals(other: Any): Boolean = other match {
|
||||||
|
case _: UserDefinedType[_] => other.isInstanceOf[HistogramType]
|
||||||
|
case _ => false
|
||||||
|
}
|
||||||
|
def sqlType: DataType = ArrayType(DoubleType)
|
||||||
|
def userClass = classOf[Histogram]
|
||||||
|
def deserialize(datum: Any): Histogram =
|
||||||
|
{
|
||||||
|
datum match {
|
||||||
|
case h: Histogram => h
|
||||||
|
case a: Array[Double] =>
|
||||||
|
assert(a.size % 2 == 1)
|
||||||
|
Histogram(
|
||||||
|
(0 until (a.size / 2)).map { i =>
|
||||||
|
HistogramBin(a(i*2), a(i*2+2), a(i*2+1))
|
||||||
|
}.toArray
|
||||||
|
)
|
||||||
|
case a: ArrayData =>
|
||||||
|
deserialize(a.toDoubleArray)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
def serialize(datum: Histogram): ArrayData =
|
||||||
|
{
|
||||||
|
ArrayData.toArrayData(
|
||||||
|
(
|
||||||
|
datum.bins(0).low +:
|
||||||
|
datum.bins.flatMap { bin =>
|
||||||
|
Seq(bin.p, bin.high)
|
||||||
|
}
|
||||||
|
).toArray
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// def apply(base: UnivariateDistribution, bins: Array[Double], samples: Int): Params =
|
||||||
|
// {
|
||||||
|
// assert(bins.size >= 2)
|
||||||
|
// val baseFamily = base.family.asInstanceOf[NumericalDistributionFamily]
|
||||||
|
|
||||||
|
// val params:Params =
|
||||||
|
// if(baseFamily.approximateCDFIsFast(base.params)){
|
||||||
|
// val startCDF = baseFamily.approximateCDF(bins.head, base.params, 1000, leadingEdge = true)
|
||||||
|
// val endCDF = baseFamily.approximateCDF(bins.last, base.params, 1000, leadingEdge = false)
|
||||||
|
// val adjustCDF = endCDF - startCDF
|
||||||
|
// var lastCDF = startCDF
|
||||||
|
// var lastBin = bins.head
|
||||||
|
// assert(adjustCDF > 0, s"Error histogramming $base Using CDF [${bins.head} - ${bins.last}]: $startCDF; $endCDF; $adjustCDF")
|
||||||
|
// // println(s"Fast Path: $startCDF")
|
||||||
|
// bins.tail.map { binHigh =>
|
||||||
|
// val binLow = lastBin
|
||||||
|
// var cdf = baseFamily.approximateCDF(binHigh, base.params, 1000, leadingEdge = false)
|
||||||
|
// val result = Bin(binLow, binHigh, (cdf - lastCDF) / adjustCDF)
|
||||||
|
// lastCDF = baseFamily.approximateCDF(binHigh, base.params, 1000, leadingEdge = true)
|
||||||
|
// lastBin = binHigh
|
||||||
|
// result
|
||||||
|
// }:Params
|
||||||
|
// } else {
|
||||||
|
// // println(s"For $base, sampling histogram")
|
||||||
|
// val counts = Array.fill(bins.size-1)(0)
|
||||||
|
// var missed = 0
|
||||||
|
// for(i <- 0 until samples)
|
||||||
|
// {
|
||||||
|
// val sample = base.family.sample(base, scala.util.Random).asInstanceOf[Double]
|
||||||
|
// val bin = bins.search(sample)
|
||||||
|
// // println(s"Sample: $sample")
|
||||||
|
// if(bin.insertionPoint == 0 || bin.insertionPoint > bins.size){
|
||||||
|
// // println(s" MISSED")
|
||||||
|
// missed += 1
|
||||||
|
// } else {
|
||||||
|
// counts(bin.insertionPoint - 1) += 1
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// counts.zipWithIndex.map { case (count, bin) =>
|
||||||
|
// val binLow = bins(bin)
|
||||||
|
// val binHigh = bins(bin+1)
|
||||||
|
// val cdf = count.toDouble / (samples - missed)
|
||||||
|
// Bin(binLow, binHigh, cdf)
|
||||||
|
// }:Params
|
||||||
|
// }
|
||||||
|
// // println(bins.mkString(", "))
|
||||||
|
// check(params)
|
||||||
|
|
||||||
|
// return params
|
||||||
|
// }
|
||||||
|
// def apply(bins: Seq[Double], values: Seq[Double]): Params =
|
||||||
|
// {
|
||||||
|
// assert(bins.size == values.size+1)
|
||||||
|
// val total = values.sum
|
||||||
|
// assert(total > 0)
|
||||||
|
// values.indices.map { i =>
|
||||||
|
// Bin(bins(i), bins(i+1), values(i))
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// case class Constructor(args: Seq[Expression])
|
||||||
|
// extends UnivariateDistributionConstructor
|
||||||
|
// {
|
||||||
|
// def family = Discretized
|
||||||
|
// def params(values: Seq[Any]) =
|
||||||
|
// Discretized(
|
||||||
|
// base = UnivariateDistribution.decode(values(0)),
|
||||||
|
// bins =
|
||||||
|
// values(1).asInstanceOf[Double].until(
|
||||||
|
// values(2).asInstanceOf[Double],
|
||||||
|
// values(3).asInstanceOf[Double]
|
||||||
|
// ).toArray,
|
||||||
|
// samples = values(4).asInstanceOf[Int]
|
||||||
|
// )
|
||||||
|
|
||||||
|
// def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) =
|
||||||
|
// copy(args = newChildren)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// case class FromBinsConstructor(args: Seq[Expression])
|
||||||
|
// extends UnivariateDistributionConstructor
|
||||||
|
// {
|
||||||
|
// def family = Discretized
|
||||||
|
// def params(values: Seq[Any]) =
|
||||||
|
// Discretized(
|
||||||
|
// bins = values(0).asInstanceOf[ArrayData].toDoubleArray().toSeq,
|
||||||
|
// values = values(1).asInstanceOf[ArrayData].toDoubleArray().toSeq,
|
||||||
|
// )
|
||||||
|
|
||||||
|
// def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) =
|
||||||
|
// copy(args = newChildren)
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
|
||||||
|
object HistogramType extends HistogramType
|
19
src/org/mimirdb/histogram/Plugin.scala
Normal file
19
src/org/mimirdb/histogram/Plugin.scala
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
package org.mimirdb.histogram
|
||||||
|
|
||||||
|
import org.apache.spark.sql.types.UDTRegistration
|
||||||
|
import org.apache.spark.sql.SparkSession
|
||||||
|
|
||||||
|
object Plugin
|
||||||
|
{
|
||||||
|
def init(spark: SparkSession): Unit =
|
||||||
|
{
|
||||||
|
UDTRegistration.register(
|
||||||
|
classOf[Histogram].getName(),
|
||||||
|
classOf[HistogramType].getName()
|
||||||
|
)
|
||||||
|
spark.udf.register("HIST_FromArray",
|
||||||
|
udf.Constructors.fromArrayUdf)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
13
src/org/mimirdb/histogram/udf/Constructors.scala
Normal file
13
src/org/mimirdb/histogram/udf/Constructors.scala
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
package org.mimirdb.histogram.udf
|
||||||
|
|
||||||
|
import org.apache.spark.sql.functions.udf
|
||||||
|
import org.mimirdb.histogram._
|
||||||
|
|
||||||
|
object Constructors
|
||||||
|
{
|
||||||
|
def fromArray(elems: Array[Double]): Histogram =
|
||||||
|
{
|
||||||
|
HistogramType.deserialize(elems).check()
|
||||||
|
}
|
||||||
|
val fromArrayUdf = udf(fromArray(_))
|
||||||
|
}
|
12
src/org/mimirdb/histogram/util/SeqUtils.scala
Normal file
12
src/org/mimirdb/histogram/util/SeqUtils.scala
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
package org.mimirdb.histogram.util
|
||||||
|
|
||||||
|
object SeqUtils
|
||||||
|
{
|
||||||
|
def pivot[A](elems: Seq[Seq[A]]): Seq[Seq[A]] =
|
||||||
|
{
|
||||||
|
if(elems.isEmpty){ return Seq.empty }
|
||||||
|
elems.tail.foldRight(elems.head.map { _ :: Nil }){ (vs, lists) =>
|
||||||
|
vs.zip(lists).map { case (v, list) => v :: list }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
16
src/org/mimirdb/histogram/util/StringUtils.scala
Normal file
16
src/org/mimirdb/histogram/util/StringUtils.scala
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
package org.mimirdb.histogram.util
|
||||||
|
|
||||||
|
object StringUtils
|
||||||
|
{
|
||||||
|
def consoleHistogram(bins: Seq[Double], min: Double = 0.0, max: Double = 1.0): String =
|
||||||
|
{
|
||||||
|
bins.map { b =>
|
||||||
|
val p = Math.min(10, (Math.max(0, Math.ceil((b - min) / (max - min) * 10).toInt)))
|
||||||
|
p match {
|
||||||
|
case 0 => " "
|
||||||
|
case x => Character.toChars(x + 0x2580).mkString
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.mkString
|
||||||
|
}
|
||||||
|
}
|
6
src/resources/vizier-plugin.json
Normal file
6
src/resources/vizier-plugin.json
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
{
|
||||||
|
"schema_version" : 1,
|
||||||
|
"plugin_class" : "org.mimirdb.pip.Histogram.VizierPlugin$",
|
||||||
|
"name" : "Mimir-Pip",
|
||||||
|
"description" : "Tools for working with histograms."
|
||||||
|
}
|
22
test/src/org/mimirdb/histogram/TestHistograms.scala
Normal file
22
test/src/org/mimirdb/histogram/TestHistograms.scala
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
package org.mimirdb.histogram
|
||||||
|
|
||||||
|
import org.scalatest.flatspec.AnyFlatSpec
|
||||||
|
import org.apache.spark.sql.functions._
|
||||||
|
import org.mimirdb.histogram.udf.Constructors._
|
||||||
|
|
||||||
|
class TestHistograms extends AnyFlatSpec {
|
||||||
|
import TestResources.spark
|
||||||
|
|
||||||
|
"histograms" should "be constructable and introspectable" in {
|
||||||
|
val df = spark.range(0, 10)
|
||||||
|
df.select(fromArrayUdf(
|
||||||
|
array(df("id").cast("double"),
|
||||||
|
lit(1.0),
|
||||||
|
(df("id") + 1).cast("double")
|
||||||
|
))
|
||||||
|
as "test"
|
||||||
|
)
|
||||||
|
.show()
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
15
test/src/org/mimirdb/histogram/TestResources.scala
Normal file
15
test/src/org/mimirdb/histogram/TestResources.scala
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
package org.mimirdb.histogram
|
||||||
|
|
||||||
|
import org.apache.spark.sql.SparkSession
|
||||||
|
|
||||||
|
/* Spark Session used across tests */
|
||||||
|
object TestResources {
|
||||||
|
val spark = SparkSession.builder
|
||||||
|
.appName("pip")
|
||||||
|
.master("local[*]")
|
||||||
|
.getOrCreate()
|
||||||
|
|
||||||
|
spark.sparkContext.setLogLevel("WARN")
|
||||||
|
|
||||||
|
Plugin.init(spark)
|
||||||
|
}
|
Loading…
Reference in a new issue