From 67cae8f6b434d32e95ab51eed46e6711c3767250 Mon Sep 17 00:00:00 2001
From: Oliver Kennedy
Date: Wed, 24 Jul 2024 17:05:07 -0400
Subject: [PATCH] Initial commit of cleaned up version

---
 .gitignore                                    |   2 +
 build.sc                                      |  30 +++
 src/org/mimirdb/histogram/Histogram.scala     | 251 ++++++++++++++++++
 src/org/mimirdb/histogram/Plugin.scala        |  19 ++
 .../mimirdb/histogram/udf/Constructors.scala  |  13 +
 src/org/mimirdb/histogram/util/SeqUtils.scala |  12 +
 .../mimirdb/histogram/util/StringUtils.scala  |  16 ++
 src/resources/vizier-plugin.json              |   6 +
 .../mimirdb/histogram/TestHistograms.scala    |  22 ++
 .../org/mimirdb/histogram/TestResources.scala |  15 ++
 10 files changed, 386 insertions(+)
 create mode 100644 .gitignore
 create mode 100644 build.sc
 create mode 100644 src/org/mimirdb/histogram/Histogram.scala
 create mode 100644 src/org/mimirdb/histogram/Plugin.scala
 create mode 100644 src/org/mimirdb/histogram/udf/Constructors.scala
 create mode 100644 src/org/mimirdb/histogram/util/SeqUtils.scala
 create mode 100644 src/org/mimirdb/histogram/util/StringUtils.scala
 create mode 100644 src/resources/vizier-plugin.json
 create mode 100644 test/src/org/mimirdb/histogram/TestHistograms.scala
 create mode 100644 test/src/org/mimirdb/histogram/TestResources.scala

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..c052de3
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+.metals
+/out
diff --git a/build.sc b/build.sc
new file mode 100644
index 0000000..6f6b356
--- /dev/null
+++ b/build.sc
@@ -0,0 +1,30 @@
+import $ivy.`com.lihaoyi::mill-contrib-bloop:$MILL_VERSION`
+import mill._
+import mill.main._
+import mill.scalalib._
+import coursier.maven.{ MavenRepository }
+
+/*************************************************
+ *** The Mimir Histogram Plugin
+ *************************************************/
+object mimir_histogram extends RootModule with ScalaModule {
+  val VERSION = "0.0.1-SNAPSHOT"
+
+  def scalaVersion = "2.12.15"
+
+  def ivyDeps = Agg(
+    ivy"org.apache.spark::spark-sql:3.3.1",
+    ivy"org.apache.spark::spark-core:3.3.1",
+    ivy"org.apache.spark::spark-mllib:3.3.1",
+    ivy"org.apache.commons:commons-math3:3.6.1"
+  )
+
+  object test extends ScalaTests with TestModule {
+    def ivyDeps = Agg(
+      ivy"org.scalactic::scalactic:3.2.17",
+      ivy"org.scalatest::scalatest:3.2.17"
+    )
+    def testFramework = "org.scalatest.tools.Framework"
+  }
+
+}
diff --git a/src/org/mimirdb/histogram/Histogram.scala b/src/org/mimirdb/histogram/Histogram.scala
new file mode 100644
index 0000000..4e61cf9
--- /dev/null
+++ b/src/org/mimirdb/histogram/Histogram.scala
@@ -0,0 +1,251 @@
+package org.mimirdb.histogram
+
+import org.apache.spark.sql.types.{ DataType, UserDefinedType, ArrayType, DoubleType }
+import org.apache.spark.sql.catalyst.util.ArrayData
+import org.mimirdb.histogram.util.{ StringUtils, SeqUtils }
+
+/* One histogram bucket: the value range [low, high) and the probability mass p that it holds. */
+case class HistogramBin(low: Double, high: Double, p: Double)
+
+case class Histogram(bins: Array[HistogramBin])
+{
+  /* Validate that the bins are non-empty and contiguous, and that their masses sum to 1. */
+  def check(): Histogram =
+  {
+    assert(!bins.isEmpty)
+    assert(
+      Math.abs(bins.map { _.p }.sum - 1.0) < Histogram.ACCURACY,
+      s"Bin probabilities do not sum to 1: ${bins.map { _.p }.sum} = ${bins.map { _.p }.mkString(" + ")}"
+    )
+    assert(bins.head.low < bins.head.high)
+    var curr = bins.head.high
+    for(x <- bins.tail)
+    {
+      assert(x.low < x.high)
+      assert(x.low == curr)
+      curr = x.high
+    }
+    return this
+  }
+
+  /* Inverse-CDF sampling: pick a bin in proportion to its mass, then sample uniformly within it. */
+  def sample(random: scala.util.Random): Double =
+  {
+    var x = random.nextDouble()
+    var currentBins = bins
+    while(x > currentBins.head.p && currentBins.size > 1){
+      x -= currentBins.head.p
+      currentBins = currentBins.tail
+    }
+    x /= currentBins.head.p
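+    // x is now the residual mass within the selected bin, rescaled to [0, 1];
+    // map it linearly onto the bin's value range.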
+    return x * (currentBins.head.high - currentBins.head.low) + currentBins.head.low
+  }
+
+  /* Fraction of the total mass that falls at or below `value`. */
+  def cdf(value: Double): Double =
+  {
+    bins.map { bin =>
+      if(value >= bin.high){ bin.p }
+      else if(value >= bin.low){
+        bin.p * ((value - bin.low)/(bin.high - bin.low))
+      }
+      else { 0 }
+    }.sum
+  }
+
+  /* Kullback-Leibler divergence D(this || target); both histograms must share bin boundaries. */
+  def klDivergence(target: Histogram): Double =
+  {
+    assert(sameBinsAs(target))
+    bins.zip(target.bins)
+        .map { case (self, other) =>
+          if(self.p > 0){
+            if(other.p > 0){
+              self.p * Math.log(self.p / other.p)
+            } else {
+              Double.PositiveInfinity
+            }
+          } else { 0 }
+        }.sum
+  }
+
+  def entropy: Double =
+    bins.map { bin =>
+      if(bin.p > 0){ - Math.log(bin.p) * bin.p }
+      else { 0 }
+    }.sum
+
+  def min: Double = bins.head.low
+  def max: Double = bins.last.high
+
+  def binBoundaries: Array[Double] =
+    (bins.head.low +: bins.map { _.high }).toArray
+
+  def sameBinsAs(other: Histogram): Boolean =
+  {
+    (bins.size == other.bins.size) &&
+      bins.zip(other.bins).forall {
+        case (a, b) => a.low == b.low && a.high == b.high
+      }
+  }
+
+  override def toString(): String =
+  {
+    if(bins.isEmpty) { "[empty histogram]" }
+    else {
+      s"[${
+        bins.head.low
+      }, ${
+        bins.last.high
+      }] -> [${StringUtils.consoleHistogram(bins.map { _.p })}]"
+    }
+  }
+}
+
+object Histogram
+{
+  val ACCURACY = 0.0001
+
+  /* Bin-wise average of histograms that all share the same bin boundaries. */
+  def average(histograms: Seq[Histogram]): Histogram =
+  {
+    assert(!histograms.isEmpty)
+    for(p <- histograms.sliding(2) if p.size == 2) {
+      assert( p(0) sameBinsAs p(1) )
+    }
+    Histogram(
+      SeqUtils.pivot(histograms.map { _.bins.toSeq }).map { vs =>
+        HistogramBin(
+          low = vs(0).low,
+          high = vs(0).high,
+          p = vs.map { _.p }.sum / vs.size
+        )
+      }.toArray
+    )
+  }
+
+}
+
+class HistogramType extends UserDefinedType[Histogram] with Serializable
+{
+
+  override def equals(other: Any): Boolean = other match {
+    case _: UserDefinedType[_] => other.isInstanceOf[HistogramType]
+    case _ => false
+  }
+  def sqlType: DataType = ArrayType(DoubleType)
+  def userClass = classOf[Histogram]
+
+  /* Histograms are stored as the flat array [low_0, p_0, high_0, p_1, high_1, ...],
+   * where each bin's high doubles as the next bin's low. */
+  def deserialize(datum: Any): Histogram =
+  {
+    datum match {
+      case h: Histogram => h
+      case a: Array[Double] =>
+        assert(a.size % 2 == 1)
+        Histogram(
+          (0 until (a.size / 2)).map { i =>
+            HistogramBin(a(i*2), a(i*2+2), a(i*2+1))
+          }.toArray
+        )
+      case a: ArrayData =>
+        deserialize(a.toDoubleArray)
+    }
+  }
+  def serialize(datum: Histogram): ArrayData =
+  {
+    ArrayData.toArrayData(
+      (
+        datum.bins(0).low +:
+          datum.bins.flatMap { bin =>
+            Seq(bin.p, bin.high)
+          }
+      ).toArray
+    )
+  }
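+
+  // Illustrative round-trip: the two-bin histogram {[0, 1): 0.25, [1, 2): 0.75}
+  // serializes to Array(0.0, 0.25, 1.0, 0.75, 2.0), and deserialize rebuilds
+  // exactly those two bins from that array.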
+
+  // The constructors below are carried over from the original
+  // UnivariateDistribution framework; they stay commented out until that
+  // framework is ported to this plugin.
+
+  // def apply(base: UnivariateDistribution, bins: Array[Double], samples: Int): Params =
+  // {
+  //   assert(bins.size >= 2)
+  //   val baseFamily = base.family.asInstanceOf[NumericalDistributionFamily]
+
+  //   val params:Params =
+  //     if(baseFamily.approximateCDFIsFast(base.params)){
+  //       val startCDF = baseFamily.approximateCDF(bins.head, base.params, 1000, leadingEdge = true)
+  //       val endCDF = baseFamily.approximateCDF(bins.last, base.params, 1000, leadingEdge = false)
+  //       val adjustCDF = endCDF - startCDF
+  //       var lastCDF = startCDF
+  //       var lastBin = bins.head
+  //       assert(adjustCDF > 0, s"Error histogramming $base using CDF [${bins.head} - ${bins.last}]: $startCDF; $endCDF; $adjustCDF")
+  //       // println(s"Fast Path: $startCDF")
+  //       bins.tail.map { binHigh =>
+  //         val binLow = lastBin
+  //         var cdf = baseFamily.approximateCDF(binHigh, base.params, 1000, leadingEdge = false)
+  //         val result = Bin(binLow, binHigh, (cdf - lastCDF) / adjustCDF)
+  //         lastCDF = baseFamily.approximateCDF(binHigh, base.params, 1000, leadingEdge = true)
+  //         lastBin = binHigh
+  //         result
+  //       }:Params
+  //     } else {
+  //       // println(s"For $base, sampling histogram")
+  //       val counts = Array.fill(bins.size-1)(0)
+  //       var missed = 0
+  //       for(i <- 0 until samples)
+  //       {
+  //         val sample = base.family.sample(base, scala.util.Random).asInstanceOf[Double]
+  //         val bin = bins.search(sample)
+  //         // println(s"Sample: $sample")
+  //         if(bin.insertionPoint == 0 || bin.insertionPoint > bins.size){
+  //           // println(s"  MISSED")
+  //           missed += 1
+  //         } else {
+  //           counts(bin.insertionPoint - 1) += 1
+  //         }
+  //       }
+  //       counts.zipWithIndex.map { case (count, bin) =>
+  //         val binLow = bins(bin)
+  //         val binHigh = bins(bin+1)
+  //         val cdf = count.toDouble / (samples - missed)
+  //         Bin(binLow, binHigh, cdf)
+  //       }:Params
+  //     }
+  //   // println(bins.mkString(", "))
+  //   check(params)
+
+  //   return params
+  // }
+
+  // def apply(bins: Seq[Double], values: Seq[Double]): Params =
+  // {
+  //   assert(bins.size == values.size+1)
+  //   val total = values.sum
+  //   assert(total > 0)
+  //   values.indices.map { i =>
+  //     Bin(bins(i), bins(i+1), values(i))
+  //   }
+  // }
+
+  // case class Constructor(args: Seq[Expression])
+  //   extends UnivariateDistributionConstructor
+  // {
+  //   def family = Discretized
+  //   def params(values: Seq[Any]) =
+  //     Discretized(
+  //       base = UnivariateDistribution.decode(values(0)),
+  //       bins =
+  //         values(1).asInstanceOf[Double].until(
+  //           values(2).asInstanceOf[Double],
+  //           values(3).asInstanceOf[Double]
+  //         ).toArray,
+  //       samples = values(4).asInstanceOf[Int]
+  //     )
+
+  //   def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) =
+  //     copy(args = newChildren)
+  // }
+
+  // case class FromBinsConstructor(args: Seq[Expression])
+  //   extends UnivariateDistributionConstructor
+  // {
+  //   def family = Discretized
+  //   def params(values: Seq[Any]) =
+  //     Discretized(
+  //       bins = values(0).asInstanceOf[ArrayData].toDoubleArray().toSeq,
+  //       values = values(1).asInstanceOf[ArrayData].toDoubleArray().toSeq,
+  //     )
+
+  //   def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) =
+  //     copy(args = newChildren)
+  // }
+}
+
+object HistogramType extends HistogramType
\ No newline at end of file
diff --git a/src/org/mimirdb/histogram/Plugin.scala b/src/org/mimirdb/histogram/Plugin.scala
new file mode 100644
index 0000000..f7490a7
--- /dev/null
+++ b/src/org/mimirdb/histogram/Plugin.scala
@@ -0,0 +1,19 @@
+package org.mimirdb.histogram
+
+import org.apache.spark.sql.types.UDTRegistration
+import org.apache.spark.sql.SparkSession
+
+object Plugin
+{
+  /* Register the Histogram UDT and the histogram UDFs with a Spark session. */
+  def init(spark: SparkSession): Unit =
+  {
+    UDTRegistration.register(
+      classOf[Histogram].getName(),
+      classOf[HistogramType].getName()
+    )
+    spark.udf.register("HIST_FromArray",
+      udf.Constructors.fromArrayUdf)
+  }
+
+}
\ No newline at end of file
diff --git a/src/org/mimirdb/histogram/udf/Constructors.scala b/src/org/mimirdb/histogram/udf/Constructors.scala
new file mode 100644
index 0000000..4f494bc
--- /dev/null
+++ b/src/org/mimirdb/histogram/udf/Constructors.scala
@@ -0,0 +1,13 @@
+package org.mimirdb.histogram.udf
+
+import org.apache.spark.sql.functions.udf
+import org.mimirdb.histogram._
+
+object Constructors
+{
+  /* Build (and validate) a Histogram from its flat [low, p, high, p, high, ...] encoding. */
+  def fromArray(elems: Array[Double]): Histogram =
+  {
+    HistogramType.deserialize(elems).check()
+  }
+  val fromArrayUdf = udf(fromArray(_))
+}
\ No newline at end of file
diff --git a/src/org/mimirdb/histogram/util/SeqUtils.scala b/src/org/mimirdb/histogram/util/SeqUtils.scala
new file mode 100644
index 0000000..90ed728
--- /dev/null
+++ b/src/org/mimirdb/histogram/util/SeqUtils.scala
@@ -0,0 +1,12 @@
+package org.mimirdb.histogram.util
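+
+// An illustrative example of the transpose implemented below:
+//   SeqUtils.pivot(Seq(Seq(1, 2), Seq(3, 4))) == Seq(Seq(1, 3), Seq(2, 4))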
+
+object SeqUtils
+{
+  /* Transpose a sequence of equal-length sequences (rows become columns). */
+  def pivot[A](elems: Seq[Seq[A]]): Seq[Seq[A]] =
+  {
+    if(elems.isEmpty){ return Seq.empty }
+    // Fold over every row, prepending each row's elements onto the column
+    // accumulators, so that the original row order is preserved.
+    elems.foldRight(elems.head.map { _ => List.empty[A] }){ (vs, lists) =>
+      vs.zip(lists).map { case (v, list) => v :: list }
+    }
+  }
+}
\ No newline at end of file
diff --git a/src/org/mimirdb/histogram/util/StringUtils.scala b/src/org/mimirdb/histogram/util/StringUtils.scala
new file mode 100644
index 0000000..e0c49f0
--- /dev/null
+++ b/src/org/mimirdb/histogram/util/StringUtils.scala
@@ -0,0 +1,16 @@
+package org.mimirdb.histogram.util
+
+object StringUtils
+{
+  /* Render each value in [min, max] as one of the eight unicode block glyphs
+   * U+2581 '▁' through U+2588 '█' (or a space for an empty bin). */
+  def consoleHistogram(bins: Seq[Double], min: Double = 0.0, max: Double = 1.0): String =
+  {
+    bins.map { b =>
+      val p = Math.min(8, Math.max(0, Math.ceil((b - min) / (max - min) * 8).toInt))
+      p match {
+        case 0 => " "
+        case x => Character.toChars(x + 0x2580).mkString
+      }
+    }
+    .mkString
+  }
+}
\ No newline at end of file
diff --git a/src/resources/vizier-plugin.json b/src/resources/vizier-plugin.json
new file mode 100644
index 0000000..ba6eafa
--- /dev/null
+++ b/src/resources/vizier-plugin.json
@@ -0,0 +1,6 @@
+{
+  "schema_version" : 1,
+  "plugin_class" : "org.mimirdb.histogram.Plugin$",
+  "name" : "Mimir-Pip",
+  "description" : "Tools for working with histograms."
+}
\ No newline at end of file
diff --git a/test/src/org/mimirdb/histogram/TestHistograms.scala b/test/src/org/mimirdb/histogram/TestHistograms.scala
new file mode 100644
index 0000000..0be8140
--- /dev/null
+++ b/test/src/org/mimirdb/histogram/TestHistograms.scala
@@ -0,0 +1,22 @@
+package org.mimirdb.histogram
+
+import org.scalatest.flatspec.AnyFlatSpec
+import org.apache.spark.sql.functions._
+import org.mimirdb.histogram.udf.Constructors._
+
+class TestHistograms extends AnyFlatSpec {
+  import TestResources.spark
+
+  "histograms" should "be constructible and introspectable" in {
+    val df = spark.range(0, 10)
+    // Each row becomes the single-bin histogram [id, id+1] holding mass 1.0
+    df.select(fromArrayUdf(
+        array(df("id").cast("double"),
+              lit(1.0),
+              (df("id") + 1).cast("double")
+        )) as "test"
+      )
+      .show()
+  }
+
+}
\ No newline at end of file
diff --git a/test/src/org/mimirdb/histogram/TestResources.scala b/test/src/org/mimirdb/histogram/TestResources.scala
new file mode 100644
index 0000000..e56dc44
--- /dev/null
+++ b/test/src/org/mimirdb/histogram/TestResources.scala
@@ -0,0 +1,15 @@
+package org.mimirdb.histogram
+
+import org.apache.spark.sql.SparkSession
+
+/* Spark session shared across tests */
+object TestResources {
+  val spark = SparkSession.builder
+                          .appName("pip")
+                          .master("local[*]")
+                          .getOrCreate()
+
+  spark.sparkContext.setLogLevel("WARN")
+
+  Plugin.init(spark)
+}
\ No newline at end of file