Compare commits
11 Commits
main
...
spatial-in
Author | SHA1 | Date |
---|---|---|
Oliver Kennedy | 95c1c7bd49 | |
Oliver Kennedy | f278396c78 | |
Oliver Kennedy | 30ad0b2dae | |
Oliver Kennedy | 78c923c5af | |
Oliver Kennedy | e8e371f8fb | |
Oliver Kennedy | 443e34651d | |
Oliver Kennedy | 5479e9578c | |
Oliver Kennedy | a81aab3a68 | |
Oliver Kennedy | d34bc54770 | |
Oliver Kennedy | 1728aea5c8 | |
Oliver Kennedy | 32a996285b |
|
@ -3,5 +3,7 @@
|
|||
/src/out
|
||||
.bloop
|
||||
.vscode
|
||||
vizier.db
|
||||
mill-worker-*
|
||||
mill-runner-*
|
||||
decision_tree.ml
|
||||
|
|
1
build.sc
1
build.sc
|
@ -39,6 +39,7 @@ object mimir_pip extends RootModule with ScalaModule {
|
|||
def ivyDeps = Agg(
|
||||
ivy"org.apache.spark::spark-sql:3.3.1",
|
||||
ivy"org.apache.spark::spark-core:3.3.1",
|
||||
ivy"org.apache.spark::spark-mllib:3.3.1",
|
||||
ivy"org.apache.commons:commons-math3:3.6.1"
|
||||
)
|
||||
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
package org.apache.spark.ml.haxx
|
||||
|
||||
import org.apache.spark.ml.classification.{ DecisionTreeClassifier, DecisionTreeClassificationModel }
|
||||
import org.apache.spark.ml.tree.{
|
||||
Node,
|
||||
InternalNode,
|
||||
LeafNode,
|
||||
Split,
|
||||
CategoricalSplit,
|
||||
ContinuousSplit,
|
||||
}
|
||||
import org.mimirdb.pip.lib.DistSummary
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
|
||||
object ExtractDecisionTree
|
||||
{
|
||||
def apply(model: DecisionTreeClassificationModel, features: Seq[DistSummary.Feature]): DistSummary =
|
||||
{
|
||||
new DistSummary(
|
||||
features,
|
||||
extractNode(model.rootNode)
|
||||
)
|
||||
}
|
||||
|
||||
def extractNode(node: Node): DistSummary.Node =
|
||||
{
|
||||
node match {
|
||||
case n: InternalNode =>
|
||||
DistSummary.InnerNode(
|
||||
n.split,
|
||||
extractNode(n.leftChild),
|
||||
extractNode(n.rightChild),
|
||||
)
|
||||
case l: LeafNode =>
|
||||
DistSummary.LeafNode()
|
||||
}
|
||||
}
|
||||
|
||||
def print(model: DecisionTreeClassificationModel): Unit =
|
||||
{
|
||||
printNode(model.rootNode)
|
||||
}
|
||||
|
||||
def printNode(node: Node, prefix: String = ""): Unit =
|
||||
{
|
||||
node match {
|
||||
case n: InternalNode => printInternalNode(n, prefix)
|
||||
case _ =>
|
||||
println(s"$prefix [${node.getClass}] ${node}")
|
||||
}
|
||||
}
|
||||
def printInternalNode(node: InternalNode, prefix: String = ""): Unit =
|
||||
{
|
||||
println(s"$prefix [${node.getClass}] ${getSplit(node.split)}")
|
||||
printNode(node.leftChild, prefix + " ")
|
||||
printNode(node.rightChild, prefix + " ")
|
||||
}
|
||||
def getSplit(split: Split): String =
|
||||
split match {
|
||||
case c: CategoricalSplit => s"${c.featureIndex} -> [${c.leftCategories.mkString(",")}] vs [${c.rightCategories.mkString(",")}]"
|
||||
case c: ContinuousSplit => s"${c.featureIndex} -> ${c.threshold}"
|
||||
}
|
||||
}
|
|
@ -28,16 +28,25 @@ object Pip {
|
|||
.functionRegistry
|
||||
.createOrReplaceTempFunction(name, fn, "scala_udf")
|
||||
|
||||
registerFunction("gaussian", distribution.Gaussian.Constructor(_))
|
||||
registerFunction("uniform", distribution.Uniform.Constructor(_))
|
||||
registerFunction("num_const", distribution.ConstantNumber.Constructor(_))
|
||||
registerFunction("clamp", distribution.Clamp.Constructor)
|
||||
registerFunction("discretize", distribution.Discretized.Constructor)
|
||||
registerFunction("gaussian", distribution.numerical.Gaussian.Constructor(_))
|
||||
registerFunction("uniform", distribution.numerical.Uniform.Constructor(_))
|
||||
registerFunction("num_const", distribution.numerical.ConstantNumber.Constructor(_))
|
||||
registerFunction("clamp", distribution.numerical.Clamp.Constructor)
|
||||
registerFunction("discretize", distribution.numerical.Discretized.Constructor)
|
||||
registerFunction("dist_between", distribution.boolean.Between.Constructor)
|
||||
spark.udf.register("entropy", udf.Entropy.udf)
|
||||
spark.udf.register("kl_divergence", udf.KLDivergence.udf)
|
||||
spark.udf.register("pip_min", udf.pip_min.udf)
|
||||
spark.udf.register("pip_max", udf.pip_max.udf)
|
||||
spark.udf.register("pip_p_between", udf.pip_p_between.udf)
|
||||
spark.udf.register("pip_histogram", udf.pip_histogram.udf)
|
||||
spark.udf.register("pip_export", udf.Export.udf)
|
||||
spark.udf.register("pip_hc_bottom_up", lib.HierarchicalClustering.bottomUpUdf.udf)
|
||||
spark.udf.register("pip_hc_list_thresholds", lib.HierarchicalClustering.listThresholdsUdf.udf)
|
||||
spark.udf.register("pip_hc_extract_clusters", lib.HierarchicalClustering.extractClustersUdf.udf)
|
||||
|
||||
// Aggregates
|
||||
spark.udf.register("uniform_mixture", distribution.NumericalMixture.uniform)
|
||||
spark.udf.register("uniform_mixture", distribution.numerical.NumericalMixture.uniform)
|
||||
spark.udf.register("histogram", udaf(udf.Histogram))
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
package org.mimirdb.pip.distribution.boolean
|
||||
|
||||
import org.apache.spark.sql.catalyst.expressions.Expression
|
||||
import org.mimirdb.pip.udt.UnivariateDistributionConstructor
|
||||
|
||||
object Bernoulli
|
||||
extends BooleanDistributionFamily
|
||||
with ProbabilitySupported
|
||||
{
|
||||
|
||||
def probability(params: Any): Double =
|
||||
params.asInstanceOf[Double]
|
||||
|
||||
def describe(params: Any): String =
|
||||
s"Bernoulli($params)"
|
||||
|
||||
def sample(params: Any, random: scala.util.Random): Boolean =
|
||||
random.nextDouble() < params.asInstanceOf[Double]
|
||||
|
||||
def deserialize(in: java.io.ObjectInputStream): Any =
|
||||
in.readDouble()
|
||||
|
||||
def serialize(out: java.io.ObjectOutputStream, params: Any): Unit =
|
||||
out.writeDouble(params.asInstanceOf[Double])
|
||||
|
||||
case class Constructor(args: Seq[Expression])
|
||||
extends UnivariateDistributionConstructor
|
||||
{
|
||||
def family = Bernoulli
|
||||
def params(values: Seq[Any]) = values(0).asInstanceOf[Double]
|
||||
|
||||
def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) =
|
||||
copy(args = newChildren)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,96 @@
|
|||
package org.mimirdb.pip.distribution.boolean
|
||||
|
||||
import org.mimirdb.pip.distribution.DistributionFamily
|
||||
import org.mimirdb.pip.distribution.numerical.NumericalDistributionFamily
|
||||
import org.mimirdb.pip.distribution.numerical.CDFSupported
|
||||
import org.mimirdb.pip.udt.UnivariateDistribution
|
||||
import org.mimirdb.pip.udt.UnivariateDistributionConstructor
|
||||
import org.apache.spark.sql.catalyst.expressions.Expression
|
||||
|
||||
object Between
|
||||
extends BooleanDistributionFamily
|
||||
{
|
||||
case class Params(lower: Double, upper: Double, child: UnivariateDistribution)
|
||||
{
|
||||
def family =
|
||||
child.family.asInstanceOf[NumericalDistributionFamily]
|
||||
def apply[A](op: (NumericalDistributionFamily, Any) => A): A =
|
||||
{
|
||||
op(family, child.params)
|
||||
}
|
||||
}
|
||||
|
||||
override def approximateProbability(params: Any, samples: Int): Double =
|
||||
{
|
||||
val config = params.asInstanceOf[Params]
|
||||
|
||||
config.family match {
|
||||
case dist: CDFSupported =>
|
||||
val upperCDF =
|
||||
dist.cdf(config.upper, config.child.params)
|
||||
val lowerCDF =
|
||||
dist.cdf(config.lower, config.child.params)
|
||||
upperCDF - lowerCDF
|
||||
case dist => super.approximateProbability(params, samples)
|
||||
}
|
||||
}
|
||||
override def approximateProbabilityIsFast(params: Any): Boolean =
|
||||
params.asInstanceOf[Params].family.isInstanceOf[CDFSupported]
|
||||
|
||||
def describe(params: Any): String =
|
||||
{
|
||||
val config = params.asInstanceOf[Params]
|
||||
s"Between(${config.lower} < ${config { _.describe(_) }} < ${config.upper})"
|
||||
}
|
||||
|
||||
def sample(params: Any, random: scala.util.Random): Boolean =
|
||||
{
|
||||
val config = params.asInstanceOf[Params]
|
||||
val v:Double = config { _.sample(_, random).asInstanceOf[Double] }
|
||||
|
||||
return (v > config.lower) && (v < config.upper)
|
||||
}
|
||||
|
||||
def deserialize(in: java.io.ObjectInputStream): Any =
|
||||
{
|
||||
val lower = in.readDouble()
|
||||
val upper = in.readDouble()
|
||||
val dist = DistributionFamily(in.readUTF())
|
||||
Params(
|
||||
lower = lower,
|
||||
upper = upper,
|
||||
child =
|
||||
UnivariateDistribution(
|
||||
family = dist,
|
||||
params = dist.deserialize(in)
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
def serialize(out: java.io.ObjectOutputStream, params: Any): Unit =
|
||||
{
|
||||
val config = params.asInstanceOf[Params]
|
||||
out.writeDouble(config.lower)
|
||||
out.writeDouble(config.upper)
|
||||
val family = config.family
|
||||
out.writeUTF(family.label)
|
||||
family.serialize(out, params)
|
||||
}
|
||||
|
||||
case class Constructor(args: Seq[Expression])
|
||||
extends UnivariateDistributionConstructor
|
||||
{
|
||||
def family = Between
|
||||
def params(values: Seq[Any]) =
|
||||
{
|
||||
Params(
|
||||
lower = values(0).asInstanceOf[Double],
|
||||
child = UnivariateDistribution.decode(values(1)),
|
||||
upper = values(2).asInstanceOf[Double],
|
||||
)
|
||||
}
|
||||
|
||||
def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) =
|
||||
copy(args = newChildren)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,38 @@
|
|||
package org.mimirdb.pip.distribution.boolean
|
||||
|
||||
import org.apache.spark.sql.types.{ DataType, BooleanType }
|
||||
import org.mimirdb.pip.distribution.DistributionFamily
|
||||
|
||||
/**
|
||||
* A [Distribution] that specifically samples numbers
|
||||
*/
|
||||
trait BooleanDistributionFamily extends DistributionFamily
|
||||
{
|
||||
val baseType = BooleanType
|
||||
|
||||
def approximateProbability(params: Any, samples: Int): Double =
|
||||
this match {
|
||||
case c:ProbabilitySupported => c.probability(params)
|
||||
case _ =>
|
||||
{
|
||||
val rand = new scala.util.Random()
|
||||
(0 until samples).count { _ =>
|
||||
sample(params, rand).asInstanceOf[Boolean]
|
||||
}.toDouble / samples
|
||||
}
|
||||
}
|
||||
|
||||
def approximateProbabilityIsFast(params: Any): Boolean = this.isInstanceOf[ProbabilitySupported]
|
||||
}
|
||||
|
||||
/**
|
||||
* An add-on to NumericalDistributionFamily that indicates an exact CDF can be computed
|
||||
*/
|
||||
trait ProbabilitySupported
|
||||
{
|
||||
val baseType: DataType
|
||||
|
||||
assert(baseType == BooleanType, "Non-boolean distributions can not support probabilities")
|
||||
|
||||
def probability(params: Any): Double
|
||||
}
|
|
@ -57,57 +57,6 @@ trait DistributionFamily
|
|||
def label = this.getClass.getSimpleName.toLowerCase
|
||||
}
|
||||
|
||||
/**
|
||||
* A [Distribution] that specifically samples numbers
|
||||
*/
|
||||
trait NumericalDistributionFamily extends DistributionFamily
|
||||
{
|
||||
val baseType = DoubleType
|
||||
|
||||
/**
|
||||
* Compute the CDF
|
||||
*/
|
||||
def approximateCDF(value: Double, params: Any, samples: Int): Double =
|
||||
this match {
|
||||
case c:CDFSupported => c.cdf(value, params)
|
||||
case _ =>
|
||||
{
|
||||
val rand = new scala.util.Random()
|
||||
(0 until samples).count { _ =>
|
||||
sample(params, rand).asInstanceOf[Double] <= value
|
||||
}.toDouble / samples
|
||||
}
|
||||
}
|
||||
def approximateCDFIsFast(params: Any): Boolean = this.isInstanceOf[CDFSupported]
|
||||
|
||||
def min(params: Any): Double
|
||||
def max(params: Any): Double
|
||||
}
|
||||
|
||||
/**
|
||||
* An add-on to NumericalDistributionFamily that indicates an exact CDF can be computed
|
||||
*/
|
||||
trait CDFSupported
|
||||
{
|
||||
val baseType: DataType
|
||||
|
||||
assert(baseType == DoubleType, "Non-numerical distributions can not support CDFs")
|
||||
|
||||
def cdf(value: Double, params: Any): Double
|
||||
}
|
||||
|
||||
/**
|
||||
* An add-on to NumericalDistributionFamily that indicates an exact Inverse CDF can be computed
|
||||
*/
|
||||
trait ICDFSupported
|
||||
{
|
||||
val baseType: DataType
|
||||
|
||||
assert(baseType == DoubleType, "Non-numerical distributions can not support ICDFs")
|
||||
|
||||
def icdf(value: Double, params: Any): Double
|
||||
}
|
||||
|
||||
/**
|
||||
* Companion object for distributions: Keeps a registry of all known distributions
|
||||
*/
|
||||
|
@ -128,9 +77,9 @@ object DistributionFamily
|
|||
|
||||
|
||||
/// Pre-defined distributions
|
||||
register(Gaussian)
|
||||
register(NumericalMixture)
|
||||
register(Clamp)
|
||||
register(Discretized)
|
||||
register(Uniform)
|
||||
register(numerical.Gaussian)
|
||||
register(numerical.NumericalMixture)
|
||||
register(numerical.Clamp)
|
||||
register(numerical.Discretized)
|
||||
register(numerical.Uniform)
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package org.mimirdb.pip.distribution
|
||||
package org.mimirdb.pip.distribution.numerical
|
||||
|
||||
import scala.util.Random
|
||||
import java.io.ObjectOutputStream
|
||||
|
@ -7,6 +7,7 @@ import org.mimirdb.pip.udt.UnivariateDistribution
|
|||
import org.apache.spark.sql.functions
|
||||
import org.mimirdb.pip.udt.UnivariateDistributionConstructor
|
||||
import org.apache.spark.sql.catalyst.expressions.Expression
|
||||
import org.mimirdb.pip.distribution.DistributionFamily
|
||||
|
||||
object Clamp
|
||||
extends NumericalDistributionFamily
|
||||
|
@ -103,20 +104,25 @@ object Clamp
|
|||
)
|
||||
}
|
||||
|
||||
override def approximateCDF(value: Double, params: Any, samples: Int): Double =
|
||||
override def approximateCDF(value: Double, params: Any, samples: Int, leadingEdge: Boolean = false): Double =
|
||||
{
|
||||
val child = params.asInstanceOf[Params]
|
||||
if(child.family.approximateCDFIsFast(params))
|
||||
{
|
||||
val lowBound = child.low.map { child.family.approximateCDF(_, child.params, 1000) }.getOrElse { 0.0 }
|
||||
val highBound = child.high.map { child.family.approximateCDF(_, child.params, 1000) }.getOrElse { 1.0 }
|
||||
val actual = child.family.approximateCDF(value, child.params, 1000)
|
||||
val actual = child.family.approximateCDF(value, child.params, 1000, leadingEdge)
|
||||
// println(s"CDF of $value @ Clamp Bounds: [${child.low} -> $lowBound, ${child.high} -> $highBound]: ${child.family.describe(child.params)}")
|
||||
if(actual < lowBound){ return 0.0 }
|
||||
if(actual > highBound){ return 1.0 }
|
||||
if(leadingEdge){
|
||||
if(actual <= lowBound){ return 0.0 }
|
||||
if(actual > highBound){ return 1.0 }
|
||||
} else {
|
||||
if(actual < lowBound){ return 0.0 }
|
||||
if(actual >= highBound){ return 1.0 }
|
||||
}
|
||||
return (actual - lowBound) / (highBound - lowBound)
|
||||
} else {
|
||||
super.approximateCDF(value, params, samples)
|
||||
super.approximateCDF(value, params, samples, leadingEdge)
|
||||
}
|
||||
}
|
||||
override def approximateCDFIsFast(params: Any): Boolean =
|
|
@ -1,4 +1,4 @@
|
|||
package org.mimirdb.pip.distribution
|
||||
package org.mimirdb.pip.distribution.numerical
|
||||
|
||||
import scala.util.Random
|
||||
import java.io.Serializable
|
||||
|
@ -36,11 +36,12 @@ object ConstantNumber
|
|||
def min(params: Any) = params.asInstanceOf[Double]
|
||||
def max(params: Any) = params.asInstanceOf[Double]
|
||||
|
||||
def cdf(value: Double, params: Any): Double =
|
||||
def cdf(value: Double, params: Any, leadingEdge: Boolean = false): Double =
|
||||
{
|
||||
val p = params.asInstanceOf[Double]
|
||||
if(value < p) { 0.0 }
|
||||
else { 1.0 }
|
||||
if(leadingEdge && value <= p) { 0.0 }
|
||||
else if(value < p) { 0.0 }
|
||||
else { 1.0 }
|
||||
}
|
||||
|
||||
def icdf(value: Double, params: Any): Double =
|
|
@ -1,4 +1,4 @@
|
|||
package org.mimirdb.pip.distribution
|
||||
package org.mimirdb.pip.distribution.numerical
|
||||
|
||||
import scala.util.Random
|
||||
import java.io.ObjectOutputStream
|
||||
|
@ -49,7 +49,7 @@ object Discretized
|
|||
return x * (bins.head.high - bins.head.low) + bins.head.low
|
||||
}
|
||||
|
||||
def cdf(value: Double, params: Any): Double =
|
||||
def cdf(value: Double, params: Any, leadingEdge: Boolean = false): Double =
|
||||
{
|
||||
params.asInstanceOf[Params].map { bin =>
|
||||
if(value >= bin.high){ bin.p }
|
||||
|
@ -136,28 +136,32 @@ object Discretized
|
|||
|
||||
val params:Params =
|
||||
if(baseFamily.approximateCDFIsFast(base.params)){
|
||||
val startCDF = baseFamily.approximateCDF(bins.head, base.params, 1000)
|
||||
val endCDF = baseFamily.approximateCDF(bins.last, base.params, 1000)
|
||||
val startCDF = baseFamily.approximateCDF(bins.head, base.params, 1000, leadingEdge = true)
|
||||
val endCDF = baseFamily.approximateCDF(bins.last, base.params, 1000, leadingEdge = false)
|
||||
val adjustCDF = endCDF - startCDF
|
||||
var lastCDF = startCDF
|
||||
var lastBin = bins.head
|
||||
assert(adjustCDF > 0, s"Error histogramming $base Using CDF [${bins.head} - ${bins.last}]: $startCDF; $endCDF; $adjustCDF")
|
||||
// println(s"Fast Path: $startCDF")
|
||||
bins.tail.map { binHigh =>
|
||||
val binLow = lastBin
|
||||
var cdf = baseFamily.approximateCDF(binHigh, base.params, 1000)
|
||||
var cdf = baseFamily.approximateCDF(binHigh, base.params, 1000, leadingEdge = false)
|
||||
val result = Bin(binLow, binHigh, (cdf - lastCDF) / adjustCDF)
|
||||
lastCDF = cdf
|
||||
lastCDF = baseFamily.approximateCDF(binHigh, base.params, 1000, leadingEdge = true)
|
||||
lastBin = binHigh
|
||||
result
|
||||
}:Params
|
||||
} else {
|
||||
// println(s"For $base, sampling histogram")
|
||||
val counts = Array.fill(bins.size-1)(0)
|
||||
var missed = 0
|
||||
for(i <- 0 until samples)
|
||||
{
|
||||
val sample = base.family.sample(base, scala.util.Random).asInstanceOf[Double]
|
||||
val bin = bins.search(sample)
|
||||
// println(s"Sample: $sample")
|
||||
if(bin.insertionPoint == 0 || bin.insertionPoint > bins.size){
|
||||
// println(s" MISSED")
|
||||
missed += 1
|
||||
} else {
|
||||
counts(bin.insertionPoint - 1) += 1
|
|
@ -1,4 +1,4 @@
|
|||
package org.mimirdb.pip.distribution
|
||||
package org.mimirdb.pip.distribution.numerical
|
||||
|
||||
import scala.util.Random
|
||||
import java.io.Serializable
|
||||
|
@ -12,6 +12,7 @@ import org.mimirdb.pip.udt.UnivariateDistribution
|
|||
import org.mimirdb.pip.udt.UnivariateDistributionType
|
||||
import org.mimirdb.pip.udt.UnivariateDistributionConstructor
|
||||
import org.apache.spark.sql.catalyst.expressions.Expression
|
||||
import org.apache.spark.sql.Column
|
||||
|
||||
/**
|
||||
* The Gaussian (normal) distribution
|
||||
|
@ -22,6 +23,12 @@ object Gaussian
|
|||
with CDFSupported
|
||||
with ICDFSupported
|
||||
{
|
||||
def apply(mean: Column, stddev: Column): Column =
|
||||
{
|
||||
new Column(Constructor(Seq(mean.expr, stddev.expr)))
|
||||
}
|
||||
|
||||
|
||||
case class Params(mean: Double, sd: Double)
|
||||
|
||||
def sample(params: Any, random: scala.util.Random): Double =
|
||||
|
@ -47,7 +54,7 @@ object Gaussian
|
|||
def min(params: Any) = Double.NegativeInfinity
|
||||
def max(params: Any) = Double.PositiveInfinity
|
||||
|
||||
def cdf(value: Double, params: Any): Double =
|
||||
def cdf(value: Double, params: Any, leadingEdge: Boolean = false): Double =
|
||||
(
|
||||
1 + Erf.erf(
|
||||
(value - params.asInstanceOf[Params].mean)
|
|
@ -1,4 +1,4 @@
|
|||
package org.mimirdb.pip.distribution
|
||||
package org.mimirdb.pip.distribution.numerical
|
||||
|
||||
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
|
||||
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
|
||||
|
@ -7,6 +7,7 @@ import org.apache.spark.sql.functions.udaf
|
|||
import org.mimirdb.pip.udt.UnivariateDistribution
|
||||
import java.util.UUID
|
||||
import org.mimirdb.pip.SampleParams
|
||||
import org.mimirdb.pip.distribution.DistributionFamily
|
||||
|
||||
object NumericalMixture
|
||||
extends NumericalDistributionFamily
|
||||
|
@ -72,10 +73,10 @@ object NumericalMixture
|
|||
bin.family.max(bin.params)
|
||||
}.max
|
||||
|
||||
override def approximateCDF(value: Double, params: Any, samples: Int): Double =
|
||||
override def approximateCDF(value: Double, params: Any, samples: Int, leadingEdge: Boolean = false): Double =
|
||||
{
|
||||
params.asInstanceOf[Params].map { bin =>
|
||||
bin.p * bin.family.approximateCDF(value, bin.params, samples)
|
||||
bin.p * bin.family.approximateCDF(value, bin.params, samples, leadingEdge)
|
||||
}.sum
|
||||
}
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
package org.mimirdb.pip.distribution
|
||||
package org.mimirdb.pip.distribution.numerical
|
||||
|
||||
import scala.util.Random
|
||||
import java.io.Serializable
|
||||
|
@ -51,12 +51,13 @@ object Uniform
|
|||
def min(params: Any) = params.asInstanceOf[Params].min
|
||||
def max(params: Any) = params.asInstanceOf[Params].max
|
||||
|
||||
def cdf(value: Double, params: Any): Double =
|
||||
def cdf(value: Double, params: Any, leadingEdge: Boolean = false): Double =
|
||||
{
|
||||
val p = params.asInstanceOf[Params]
|
||||
if(value < p.min) { 0.0 }
|
||||
else if(value >= p.max) { 1.0 }
|
||||
else { (value - p.min) / p.width }
|
||||
if(!leadingEdge && value <= p.min) { 0.0 }
|
||||
else if(value < p.min) { 0.0 }
|
||||
else if(value > p.max) { 1.0 }
|
||||
else { (value - p.min) / p.width }
|
||||
}
|
||||
|
||||
def icdf(value: Double, params: Any): Double =
|
|
@ -0,0 +1,59 @@
|
|||
package org.mimirdb.pip.distribution.numerical
|
||||
|
||||
import org.apache.spark.sql.types.{ DataType, DoubleType }
|
||||
import org.mimirdb.pip.distribution.DistributionFamily
|
||||
|
||||
/**
|
||||
* A [Distribution] that specifically samples numbers
|
||||
*/
|
||||
trait NumericalDistributionFamily extends DistributionFamily
|
||||
{
|
||||
val baseType = DoubleType
|
||||
|
||||
/**
|
||||
* Compute the CDF
|
||||
*/
|
||||
def approximateCDF(value: Double, params: Any, samples: Int, leadingEdge: Boolean = false): Double =
|
||||
this match {
|
||||
case c:CDFSupported => c.cdf(value, params)
|
||||
case _ =>
|
||||
{
|
||||
val rand = new scala.util.Random()
|
||||
(0 until samples).count { _ =>
|
||||
if(leadingEdge){
|
||||
sample(params, rand).asInstanceOf[Double] < value
|
||||
} else {
|
||||
sample(params, rand).asInstanceOf[Double] <= value
|
||||
}
|
||||
}.toDouble / samples
|
||||
}
|
||||
}
|
||||
def approximateCDFIsFast(params: Any): Boolean = this.isInstanceOf[CDFSupported]
|
||||
|
||||
def min(params: Any): Double
|
||||
def max(params: Any): Double
|
||||
}
|
||||
|
||||
/**
|
||||
* An add-on to NumericalDistributionFamily that indicates an exact CDF can be computed
|
||||
*/
|
||||
trait CDFSupported
|
||||
{
|
||||
val baseType: DataType
|
||||
|
||||
assert(baseType == DoubleType, "Non-numerical distributions can not support CDFs")
|
||||
|
||||
def cdf(value: Double, params: Any, leadingEdge: Boolean = false): Double
|
||||
}
|
||||
|
||||
/**
|
||||
* An add-on to NumericalDistributionFamily that indicates an exact Inverse CDF can be computed
|
||||
*/
|
||||
trait ICDFSupported
|
||||
{
|
||||
val baseType: DataType
|
||||
|
||||
assert(baseType == DoubleType, "Non-numerical distributions can not support ICDFs")
|
||||
|
||||
def icdf(value: Double, params: Any): Double
|
||||
}
|
|
@ -0,0 +1,94 @@
|
|||
package org.mimirdb.pip.lib
|
||||
import org.apache.spark.ml.tree.{
|
||||
Split,
|
||||
CategoricalSplit,
|
||||
ContinuousSplit,
|
||||
}
|
||||
import scala.collection.mutable
|
||||
import org.mimirdb.pip.udt.UnivariateDistribution
|
||||
|
||||
class DistSummary(dimensions: Seq[DistSummary.Feature], root: DistSummary.Node)
|
||||
{
|
||||
def insert(address: Array[Double], dist: UnivariateDistribution): Unit =
|
||||
{
|
||||
root.insert(address, dist)
|
||||
}
|
||||
|
||||
def regions(): Seq[Array[DistSummary.Feature]] =
|
||||
{
|
||||
val buffer = mutable.Buffer[Array[DistSummary.Feature]]()
|
||||
root.regions(dimensions, buffer)
|
||||
buffer.toSeq
|
||||
}
|
||||
}
|
||||
|
||||
object DistSummary
|
||||
{
|
||||
sealed trait Node
|
||||
{
|
||||
def insert(address: Array[Double], dist: UnivariateDistribution): Unit
|
||||
def regions(dimensions: Seq[DistSummary.Feature], buffer: mutable.Buffer[Array[DistSummary.Feature]]): Unit
|
||||
}
|
||||
|
||||
case class InnerNode(split: Split, left: Node, right: Node) extends Node
|
||||
{
|
||||
def insert(address: Array[Double], dist: UnivariateDistribution): Unit =
|
||||
split match {
|
||||
case c:ContinuousSplit =>
|
||||
if(address(c.featureIndex) <= c.threshold) {
|
||||
left.insert(address, dist)
|
||||
} else {
|
||||
right.insert(address, dist)
|
||||
}
|
||||
case c:CategoricalSplit =>
|
||||
???
|
||||
}
|
||||
def regions(dimensions: Seq[DistSummary.Feature], buffer: mutable.Buffer[Array[DistSummary.Feature]]): Unit =
|
||||
{
|
||||
split match {
|
||||
case c:ContinuousSplit =>
|
||||
val (leftFeat, rightFeat) = dimensions(c.featureIndex)
|
||||
.asInstanceOf[ContinuousFeature]
|
||||
.split(c.threshold)
|
||||
val before = dimensions.take(c.featureIndex)
|
||||
val after = dimensions.drop(c.featureIndex+1)
|
||||
left.regions(
|
||||
before ++ Seq(leftFeat) ++ after,
|
||||
buffer
|
||||
)
|
||||
right.regions(
|
||||
before ++ Seq(rightFeat) ++ after,
|
||||
buffer
|
||||
)
|
||||
case c:CategoricalSplit =>
|
||||
???
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
case class LeafNode(distributions: mutable.ArrayBuffer[UnivariateDistribution] = mutable.ArrayBuffer.empty) extends Node
|
||||
{
|
||||
def insert(address: Array[Double], dist: UnivariateDistribution): Unit =
|
||||
distributions.append(dist)
|
||||
def regions(dimensions: Seq[DistSummary.Feature], buffer: mutable.Buffer[Array[DistSummary.Feature]]): Unit =
|
||||
{
|
||||
buffer.append(dimensions.toArray)
|
||||
}
|
||||
}
|
||||
|
||||
trait Feature
|
||||
|
||||
case class ContinuousFeature(min: Double, max: Double) extends Feature
|
||||
{
|
||||
def split(cutoff: Double): (ContinuousFeature, ContinuousFeature) =
|
||||
{
|
||||
// assert(min <= cutoff, s"Min not below cutoff: $min </= $cutoff")
|
||||
// assert(cutoff <= max, s"Max not above cutoff: $max >/= $cutoff")
|
||||
(
|
||||
new ContinuousFeature(min, Math.max(cutoff, min)),
|
||||
new ContinuousFeature(Math.min(cutoff, max), max),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,48 @@
|
|||
package org.mimirdb.pip.lib
|
||||
|
||||
trait Distance[I]
|
||||
{
|
||||
val min: I
|
||||
val max: I
|
||||
def pointToPoint(a: Array[I], b: Array[I]): Double
|
||||
def pointToPlane(a: Array[I], b: I, dim: Int): Double
|
||||
def centroid(elems: Iterable[Array[I]]): Array[I]
|
||||
}
|
||||
|
||||
object Distance
|
||||
{
|
||||
case class Manhattan(dimensions: Int, min: Double, max: Double) extends Distance[Double]
|
||||
{
|
||||
|
||||
def pointToPlane(a: Array[Double], b: Double, dim: Int): Double =
|
||||
{
|
||||
Math.abs(a(dim) - b)
|
||||
}
|
||||
def centroid(a: Iterable[Array[Double]]): Array[Double] =
|
||||
{
|
||||
val ret = Array.ofDim[Double](dimensions)
|
||||
for(pt <- a)
|
||||
{
|
||||
for(i <- 0 until dimensions)
|
||||
{
|
||||
ret(i) += pt(i)
|
||||
}
|
||||
}
|
||||
for(i <- 0 until dimensions)
|
||||
{
|
||||
ret(i) /= a.size
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
def pointToPoint(a: Array[Double], b: Array[Double]): Double =
|
||||
{
|
||||
var tot = 0.0
|
||||
for(i <- 0 until a.size)
|
||||
{
|
||||
tot += Math.abs(a(i) - b(i))
|
||||
}
|
||||
return tot
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,383 @@
|
|||
package org.mimirdb.pip.lib
|
||||
|
||||
import scala.collection.mutable
|
||||
import scala.reflect.ClassTag
|
||||
import scala.collection.mutable.PriorityQueue
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
import scala.collection.mutable.Stack
|
||||
import org.apache.spark.sql.functions
|
||||
import org.mimirdb.pip.util.ByteArrayUtils
|
||||
|
||||
object HierarchicalClustering
|
||||
{
|
||||
def naive[I, V](elements: IndexedSeq[(Array[I], V)], measure: Distance[I]): Cluster[I, V] =
|
||||
{
|
||||
|
||||
def distance(i: Int, cmp: Iterable[Int]): Double =
|
||||
cmp.map { j =>
|
||||
if(i == j){ 0.0 }
|
||||
else { measure.pointToPoint(elements(i)._1, elements(j)._1) }
|
||||
}.sum
|
||||
|
||||
assert(!elements.isEmpty)
|
||||
if(elements.size == 1){
|
||||
Singleton(elements.head._1, elements.head._2)
|
||||
} else {
|
||||
val all = (0 until elements.size)
|
||||
|
||||
val (furthestIdx, furthestDistance) =
|
||||
(0 until elements.size)
|
||||
.map { i => i -> distance(i, all) }
|
||||
.maxBy { _._2 }
|
||||
|
||||
val left =
|
||||
mutable.Set(all.filterNot { _ == furthestIdx }:_*)
|
||||
val right =
|
||||
mutable.Set[Int](furthestIdx)
|
||||
|
||||
var done = false
|
||||
while(left.size > 1 && !done){
|
||||
val (bestMoveCandidate, bestMoveScore) =
|
||||
left.toSeq
|
||||
.map { i =>
|
||||
val leftScore =
|
||||
(1.0 / (left.size - 1)) * distance(i, left)
|
||||
val rightScore =
|
||||
(1.0 / right.size) * distance(i, right)
|
||||
i -> (leftScore - rightScore)
|
||||
}
|
||||
.maxBy { _._2 }
|
||||
if(bestMoveScore <= 0) { done = true }
|
||||
else {
|
||||
left -= bestMoveCandidate
|
||||
right += bestMoveCandidate
|
||||
}
|
||||
}
|
||||
|
||||
Group(
|
||||
left = naive(left.toIndexedSeq.map { elements(_) }, measure),
|
||||
right = naive(right.toIndexedSeq.map { elements(_) }, measure),
|
||||
radius = furthestDistance,
|
||||
size = elements.size,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
def bottomUp[I: Ordering, V](elements: IndexedSeq[(Array[I], V)], measure: Distance[I])(implicit tag: ClassTag[I]): Cluster[I, V] =
|
||||
{
|
||||
assert(elements.size > 1)
|
||||
val tree = new KDTree[I, Int](elements(0)._1.size)
|
||||
// higher values get dequeued first--max heap
|
||||
val queue = mutable.PriorityQueue()(new Ordering[(Double, Int, Int)]{
|
||||
def compare(a: (Double, Int, Int), b: (Double, Int, Int)): Int =
|
||||
{
|
||||
// invert to get lower distances first
|
||||
-Ordering[Double].compare(a._1, b._1)
|
||||
}
|
||||
})
|
||||
|
||||
val clusters: mutable.ArrayBuffer[Option[(Array[I], Cluster[I, V])]] =
|
||||
mutable.ArrayBuffer(
|
||||
elements.map { e => Some(e._1 -> Singleton(e._1, e._2)) }:_*
|
||||
)
|
||||
|
||||
def nearest(point: Array[I], idx: Int) =
|
||||
tree.nearest(point, measure, ignore = { (_, others, _) => !others.exists { _ != idx } })
|
||||
|
||||
tree.insertAll(elements.map { _._1 }.zipWithIndex)
|
||||
for( (point, idx) <- elements.map { _._1 }.zipWithIndex )
|
||||
{
|
||||
nearest(point, idx) match {
|
||||
case Some( (nearestPoint, nearestIdxs, distance) ) =>
|
||||
// println(s"Nearest to $idx: $nearestIdxs")
|
||||
queue.enqueue( (distance, idx, nearestIdxs.filter { _ != idx }.head) )
|
||||
case None =>
|
||||
assert(false)
|
||||
}
|
||||
}
|
||||
|
||||
var lastCluster: Cluster[I, V] = null
|
||||
|
||||
// Each iteration removes at least two items from the pqueue and adds one
|
||||
// Hard bound the number of iterations defensively
|
||||
for(i <- 0 until elements.size*4)
|
||||
{
|
||||
if(queue.isEmpty) {
|
||||
// Normally, we'd return below, after the tree.isEmpty test. However
|
||||
// it's possible that the queue is padded out with unhandled todos.
|
||||
// If that happens, return the last pushed cluster here instead.
|
||||
assert(lastCluster != null)
|
||||
return lastCluster
|
||||
}
|
||||
val (distance, aIdx, bIdx) = queue.dequeue()
|
||||
|
||||
// println(s"Dequeue: $aIdx, $bIdx")
|
||||
|
||||
(clusters(aIdx), clusters(bIdx)) match {
|
||||
case (Some((aPos, aCluster)), Some((bPos, bCluster))) =>
|
||||
{
|
||||
val centroid = measure.centroid((aCluster.elements ++ bCluster.elements).map { _.position }.toSeq)
|
||||
// println(s"Merging \n $aIdx [${aCluster.elements.mkString(",")}] and \n $bIdx [${bCluster.elements.mkString(",")}]\n (at distance $distance -> [${centroid.mkString(",")}]) into ${clusters.size}")
|
||||
val radius = (aCluster.elements ++ bCluster.elements).map { e =>
|
||||
measure.pointToPoint(centroid, e.position)
|
||||
}.max
|
||||
// println(s"Remove: $aIdx @ $aPos")
|
||||
tree.remove(aPos, aIdx)
|
||||
// println(s"Remove: $bIdx @ $bPos")
|
||||
tree.remove(bPos, bIdx)
|
||||
clusters(aIdx) = None
|
||||
clusters(bIdx) = None
|
||||
|
||||
val newCluster =
|
||||
Group(
|
||||
left = aCluster,
|
||||
right = bCluster,
|
||||
radius = radius,
|
||||
size = aCluster.size + bCluster.size
|
||||
)
|
||||
|
||||
if(tree.isEmpty){
|
||||
return newCluster
|
||||
} else {
|
||||
val newIdx = clusters.size
|
||||
clusters.append( Some(centroid, newCluster) )
|
||||
|
||||
val (nearestPosition, nearestIdx, nearestDistance) =
|
||||
nearest(centroid, newIdx).get
|
||||
|
||||
// println(s"Insert $newIdx -> ${centroid.mkString(",")}")
|
||||
tree.insert(centroid, newIdx)
|
||||
// println(tree)
|
||||
|
||||
queue.enqueue(
|
||||
(nearestDistance, newIdx, nearestIdx.head)
|
||||
)
|
||||
lastCluster = newCluster
|
||||
}
|
||||
}
|
||||
case (Some( (aPos, aCluster) ), None) =>
|
||||
{
|
||||
// b got cleared out, find the next nearest point to a.
|
||||
val (nearestPosition, nearestIdx, nearestDistance) =
|
||||
nearest(aPos, aIdx).get
|
||||
// println(s"Dequeue $aIdx, $bIdx -> Redirect to $nearestIdx @ $nearestDistance")
|
||||
queue.enqueue(
|
||||
(nearestDistance, aIdx, nearestIdx.head)
|
||||
)
|
||||
// println(queue.mkString("\n "))
|
||||
// println(tree)
|
||||
}
|
||||
case _ =>
|
||||
{
|
||||
// println(s"Skipping $aIdx and $bIdx")
|
||||
() // if a or b were removed, we're done with them
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
???
|
||||
}
|
||||
|
||||
|
||||
class ClusterIterator[I, V](root: Cluster[I, V])
|
||||
extends Iterator[Cluster[I, V]]
|
||||
{
|
||||
val queue = PriorityQueue(root)(new Ordering[Cluster[I, V]]{
|
||||
def compare(a: Cluster[I, V], b: Cluster[I, V]) =
|
||||
Ordering[Double].compare(a.radius, b.radius)
|
||||
})
|
||||
|
||||
def hasNext: Boolean = !queue.isEmpty
|
||||
def next(): Cluster[I, V] =
|
||||
{
|
||||
val ret = queue.dequeue()
|
||||
ret.children.foreach { queue.enqueue(_) }
|
||||
return ret
|
||||
}
|
||||
}
|
||||
|
||||
class ClusterElementIterator[I, V](root: Cluster[I, V])
|
||||
extends Iterator[Singleton[I, V]]
|
||||
{
|
||||
val stack = Stack[Group[I, V]]()
|
||||
var nextSingleton: Singleton[I, V] = null
|
||||
pushLeft(root)
|
||||
|
||||
def pushLeft(node: Cluster[I, V]): Unit =
|
||||
{
|
||||
var tmp = node;
|
||||
while(tmp != null){
|
||||
tmp = tmp match {
|
||||
case g: Group[I, V] =>
|
||||
{
|
||||
stack.push(g)
|
||||
g.left
|
||||
}
|
||||
case s: Singleton[I, V] =>
|
||||
{
|
||||
nextSingleton = s
|
||||
null
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def hasNext: Boolean = (nextSingleton != null)
|
||||
def next: Singleton[I, V] =
|
||||
{
|
||||
val ret = nextSingleton
|
||||
if(stack.isEmpty){ nextSingleton = null }
|
||||
else { pushLeft(stack.pop.right) }
|
||||
return ret
|
||||
}
|
||||
}
|
||||
|
||||
sealed trait Cluster[I, V] extends Serializable
|
||||
{
|
||||
def radius: Double
|
||||
def render(firstPrefix: String, restPrefix: String): String
|
||||
def size: Int
|
||||
def children: Seq[Cluster[I, V]]
|
||||
def orderedIterator = new ClusterIterator(this)
|
||||
def elements = new ClusterElementIterator[I, V](this)
|
||||
def threshold(cutoff: Double): Iterator[Cluster[I, V]] =
|
||||
{
|
||||
assert(cutoff > 0.0)
|
||||
val queue =
|
||||
PriorityQueue()(new Ordering[Cluster[I, V]] {
|
||||
def compare(a: Cluster[I, V], b: Cluster[I, V]): Int =
|
||||
Ordering[Double].compare(a.radius, b.radius)
|
||||
})
|
||||
queue.enqueue(this)
|
||||
while(!queue.isEmpty && queue.head.radius > cutoff)
|
||||
{
|
||||
queue.dequeue() match {
|
||||
case g:Group[I, V] => {
|
||||
queue.enqueue(g.left)
|
||||
queue.enqueue(g.right)
|
||||
}
|
||||
case _ => assert(false, "Singleton with Radius <= 0.0")
|
||||
}
|
||||
}
|
||||
queue.iterator
|
||||
}
|
||||
override def toString(): String = render("", "")
|
||||
}
|
||||
|
||||
case class Group[I, V](left: Cluster[I,V], right: Cluster[I, V], radius: Double, size: Int) extends Cluster[I, V]
|
||||
{
|
||||
def children = Seq(left, right)
|
||||
def render(firstPrefix: String, restPrefix: String): String =
|
||||
firstPrefix + s"- [$radius]\n" +
|
||||
left.render(restPrefix + " +-", restPrefix + " | ") + "\n" +
|
||||
right.render(restPrefix + " +-", restPrefix + " ")
|
||||
}
|
||||
case class Singleton[I, V](position: Array[I], value: V) extends Cluster[I, V]
|
||||
{
|
||||
def children = Seq()
|
||||
def radius = 0.0
|
||||
def size = 1
|
||||
def render(firstPrefix: String, restPrefix: String): String =
|
||||
firstPrefix + " <" + position.mkString(", ") + s"> -> $value"
|
||||
|
||||
}
|
||||
|
||||
def encode[I,V](cluster: Cluster[I,V]): Array[Byte] =
|
||||
{
|
||||
ByteArrayUtils.encode { buf =>
|
||||
val todos = mutable.Stack[Either[Cluster[I,V], Group[I,V]]](Left(cluster))
|
||||
|
||||
while(!todos.isEmpty)
|
||||
{
|
||||
todos.pop match {
|
||||
case Left(v: Singleton[I,V]) =>
|
||||
buf.writeInt(1)
|
||||
buf.writeObject(v.position)
|
||||
buf.writeObject(v.value)
|
||||
case Left(v: Group[I,V]) =>
|
||||
todos.push(Right(v))
|
||||
todos.push(Left(v.left))
|
||||
todos.push(Left(v.right))
|
||||
case Right(v) =>
|
||||
buf.writeInt(2)
|
||||
buf.writeDouble(v.radius)
|
||||
buf.writeInt(v.size)
|
||||
}
|
||||
}
|
||||
buf.writeInt(3)
|
||||
}
|
||||
}
|
||||
|
||||
def decode[I,V](data: Array[Byte]): Cluster[I,V] =
|
||||
{
|
||||
ByteArrayUtils.decode(data) { buf =>
|
||||
val decoded = mutable.Stack[Cluster[I,V]]()
|
||||
var mode = buf.readInt()
|
||||
|
||||
while(mode != 3)
|
||||
{
|
||||
mode match {
|
||||
case 1 => {
|
||||
val position = buf.readObject().asInstanceOf[Array[I]]
|
||||
val value = buf.readObject().asInstanceOf[V]
|
||||
decoded.push(Singleton(position, value))
|
||||
}
|
||||
case 2 => {
|
||||
val left = decoded.pop()
|
||||
val right = decoded.pop()
|
||||
val radius = buf.readDouble()
|
||||
val size = buf.readInt()
|
||||
decoded.push(Group(left, right, radius, size))
|
||||
}
|
||||
}
|
||||
mode = buf.readInt()
|
||||
}
|
||||
|
||||
decoded.pop()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
object bottomUpUdf {
|
||||
def apply(elements: (Array[(Array[Double], String)]), min: Double, max: Double, method: String): Array[Byte] =
|
||||
encode(
|
||||
HierarchicalClustering.bottomUp(elements,
|
||||
method match {
|
||||
case "manhattan" => Distance.Manhattan(
|
||||
dimensions = elements.head._1.size,
|
||||
min = min,
|
||||
max = max,
|
||||
)
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
def udf = functions.udf(apply(_, _, _, _))
|
||||
}
|
||||
|
||||
object listThresholdsUdf {
|
||||
def apply(clusters: Array[Byte]): Array[Double] =
|
||||
decode[Array[Double],String](clusters)
|
||||
.orderedIterator
|
||||
.map { x => x.radius }
|
||||
.takeWhile { _ > 0.0 }
|
||||
.toArray
|
||||
|
||||
def udf = functions.udf(apply(_))
|
||||
}
|
||||
|
||||
object extractClustersUdf {
|
||||
def apply(clusters: Array[Byte], threshold: Double): Array[(String, Int)] =
|
||||
decode[Array[Double],String](clusters)
|
||||
.threshold(threshold)
|
||||
.zipWithIndex
|
||||
.flatMap { case (cluster, idx) =>
|
||||
cluster.elements.map { element =>
|
||||
(element.value, idx)
|
||||
}
|
||||
}
|
||||
.toArray
|
||||
|
||||
def udf = functions.udf(apply(_, _))
|
||||
}
|
||||
}
|
|
@ -0,0 +1,422 @@
|
|||
package org.mimirdb.pip.lib
|
||||
|
||||
import scala.collection.mutable
|
||||
import scala.reflect.ClassTag
|
||||
|
||||
class KDTree[I: Ordering, V](dimensions: Int)(implicit iTag: ClassTag[I])
|
||||
{
|
||||
type Key = Array[I]
|
||||
var root: Option[Node] = None
|
||||
private var _size = 0
|
||||
|
||||
def build(op: Builder => Unit): KDTree[I, V] =
|
||||
{
|
||||
val builder = new Builder
|
||||
op(builder)
|
||||
root = Some(builder.stack.pop())
|
||||
return this
|
||||
}
|
||||
|
||||
def insert(position: Key, value: V): Unit =
|
||||
{
|
||||
assert(position.size == dimensions)
|
||||
root = Some(root match {
|
||||
case Some(node) => node.insert(position, value)
|
||||
case None => Leaf(position, Set(value), 0)
|
||||
})
|
||||
_size = _size + 1
|
||||
}
|
||||
|
||||
def insertAll(elements: Iterable[(Key, V)]): Unit =
|
||||
{
|
||||
elements.foreach { el => assert(el._1.size == dimensions)}
|
||||
|
||||
if(elements.isEmpty) { return }
|
||||
|
||||
root = Some(root match {
|
||||
case Some(node) => node.insertAll(elements.toSeq)
|
||||
case None => fragmentSeq(elements.toSeq, 0)
|
||||
})
|
||||
_size = _size + elements.size
|
||||
}
|
||||
|
||||
def remove(position: Key, value: V): Boolean =
|
||||
{
|
||||
root match {
|
||||
case None => return false
|
||||
case Some(r) =>
|
||||
{
|
||||
val (newRoot, ret) = r.remove(position, value)
|
||||
root = newRoot
|
||||
_size = _size - 1
|
||||
return ret
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def isEmpty = root.isEmpty
|
||||
def size = _size
|
||||
|
||||
def nearest(position: Key, measure: Distance[I], ignore: (Array[I], Set[V], Double) => Boolean = {(_, _, _) => false}): Option[(Key, Set[V], Double)] =
|
||||
{
|
||||
root.flatMap { _.nearest(
|
||||
position = position,
|
||||
measure = measure,
|
||||
best = None,
|
||||
ignore = ignore
|
||||
) }.map { case (found, distance) =>
|
||||
( found.position, found.values, distance)
|
||||
}
|
||||
}
|
||||
def get(position: Key): Seq[V] =
|
||||
{
|
||||
root.toSeq.flatMap { _.get(position) }
|
||||
}
|
||||
|
||||
def depth(): Int =
|
||||
root.map { _.depth() }.getOrElse(0)
|
||||
|
||||
override def toString(): String =
|
||||
root.map { _.toString("", "") }.getOrElse("[Empty Tree]")
|
||||
|
||||
private def fragmentSeq(elements: Seq[(Key, V)], dimBase: Int): Node =
|
||||
{
|
||||
assert(!elements.isEmpty)
|
||||
if(elements.size == 1){
|
||||
return Leaf(elements(0)._1, Set(elements(0)._2), dimBase)
|
||||
}
|
||||
for(i <- 0 until dimensions)
|
||||
{
|
||||
val dim = (dimBase + i) % dimensions
|
||||
val sorted =
|
||||
elements.sorted(new Ordering[(Key, V)] {
|
||||
def compare(a: (Key, V), b: (Key, V)): Int =
|
||||
Ordering[I].compare(a._1(dim), b._1(dim))
|
||||
})
|
||||
val mid = (sorted.size+1) / 2 // +1 to round up.
|
||||
|
||||
val midSplit = sorted(mid)._1(dim)
|
||||
|
||||
val (leftElements, rightElements) =
|
||||
sorted.partition { case (position, v) =>
|
||||
Ordering[I].compare(position(dim), midSplit) < 0
|
||||
}
|
||||
|
||||
if(!leftElements.isEmpty && !rightElements.isEmpty){
|
||||
return Inner(
|
||||
split = midSplit,
|
||||
dim = dim,
|
||||
left = fragmentSeq(leftElements, nextDim(dim)),
|
||||
right = fragmentSeq(rightElements, nextDim(dim))
|
||||
)
|
||||
} else {
|
||||
val newSplit: Option[I] =
|
||||
leftElements.reverse.find { case (position, v) =>
|
||||
position(dim) != midSplit
|
||||
}.orElse { rightElements.find { case (position, v) =>
|
||||
position(dim) != midSplit
|
||||
}}.map { _._1(dim) }
|
||||
|
||||
if(newSplit.isDefined){
|
||||
val (leftElements, rightElements) =
|
||||
sorted.partition { case (position, v) =>
|
||||
Ordering[I].compare(position(dim), newSplit.get) < 0
|
||||
}
|
||||
assert(!leftElements.isEmpty)
|
||||
assert(!rightElements.isEmpty)
|
||||
return Inner(
|
||||
split = newSplit.get,
|
||||
dim = dim,
|
||||
left = fragmentSeq(leftElements, nextDim(dim)),
|
||||
right = fragmentSeq(rightElements, nextDim(dim))
|
||||
)
|
||||
}
|
||||
// If we get to this point, all of our element positions are
|
||||
// equal on the given dimension. Fall through the loop to try
|
||||
// the next dimension
|
||||
}
|
||||
}
|
||||
// If we get to this point, all of our element positions are equal
|
||||
// on all dimensions. Aggregate them into a single leaf
|
||||
Leaf(
|
||||
elements(0)._1,
|
||||
elements.map { _._2 }.toSet,
|
||||
dim = dimBase
|
||||
)
|
||||
}
|
||||
|
||||
@inline def nextDim(dim: Int) =
|
||||
(dim+1)%dimensions
|
||||
|
||||
def keysEqual(a: Array[I], b: Array[I]): Boolean =
|
||||
{
|
||||
assert(a.size == b.size)
|
||||
for(i <- 0 until a.size) { if(a(i) != b(i)) { return false } }
|
||||
return true
|
||||
}
|
||||
|
||||
sealed trait Node
|
||||
{
|
||||
def insert(position: Key, value: V): Node
|
||||
def insertAll(elements: Seq[(Key, V)]): Node
|
||||
def remove(position: Key, value: V): (Option[Node], Boolean)
|
||||
def nearest(position: Key, measure: Distance[I], best: Option[(Leaf, Double)], ignore: (Array[I], Set[V], Double) => Boolean): Option[(Leaf, Double)]
|
||||
def get(position: Key): Seq[V]
|
||||
def depth(): Int
|
||||
def toString(firstPrefix: String, restPrefix: String): String
|
||||
}
|
||||
|
||||
case class Inner(
|
||||
split: I,
|
||||
dim: Int,
|
||||
var left: Node,
|
||||
var right: Node
|
||||
) extends Node
|
||||
{
|
||||
|
||||
def insert(position: Key, value: V): Node =
|
||||
{
|
||||
if(Ordering[I].compare(position(dim), split) >= 0){
|
||||
// position(dim) >= split
|
||||
right = right.insert(position, value)
|
||||
} else {
|
||||
// position(dim) < split
|
||||
left = left.insert(position, value)
|
||||
}
|
||||
return this
|
||||
}
|
||||
def insertAll(elements: Seq[(Key, V)]): Node =
|
||||
{
|
||||
if(elements.size == 1){ insert(elements.head._1, elements.head._2) }
|
||||
else {
|
||||
val (forLeft, forRight) =
|
||||
elements.partition { case (k, v) =>
|
||||
Ordering[I].compare(k(dim), split) < 0
|
||||
}
|
||||
if(forLeft.size > 0){
|
||||
left = left.insertAll(forLeft)
|
||||
}
|
||||
if(forRight.size > 0){
|
||||
right = right.insertAll(forRight)
|
||||
}
|
||||
this
|
||||
}
|
||||
}
|
||||
def remove(position: Key, value: V): (Option[Node], Boolean) =
|
||||
{
|
||||
// println(s"Remove $position")
|
||||
if(Ordering[I].compare(position(dim), split) < 0){
|
||||
left.remove(position, value) match {
|
||||
case (Some(newLeft), ret) =>
|
||||
{
|
||||
left = newLeft
|
||||
return (Some(this), ret)
|
||||
}
|
||||
case (None, ret) =>
|
||||
{
|
||||
return (Some(right), ret)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
right.remove(position, value) match {
|
||||
case (Some(newRight), ret) =>
|
||||
{
|
||||
right = newRight
|
||||
return (Some(this), ret)
|
||||
}
|
||||
case (None, ret) =>
|
||||
{
|
||||
return (Some(left), ret)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
def nearest(position: Key, measure: Distance[I], best: Option[(Leaf, Double)], ignore: (Array[I], Set[V], Double) => Boolean): Option[(Leaf, Double)] =
|
||||
{
|
||||
val (near, far) =
|
||||
if(Ordering[I].compare(position(dim), split) >= 0){
|
||||
(right, left)
|
||||
} else {
|
||||
(left, right)
|
||||
}
|
||||
|
||||
// position is right
|
||||
val candidate =
|
||||
near.nearest(
|
||||
position = position,
|
||||
measure = measure,
|
||||
best = best,
|
||||
ignore = ignore,
|
||||
)
|
||||
|
||||
val splitDistance =
|
||||
measure.pointToPlane(position, split, dim)
|
||||
|
||||
if(candidate.isEmpty || candidate.get._2 > splitDistance){
|
||||
return far.nearest(
|
||||
position = position,
|
||||
measure = measure,
|
||||
best = candidate,
|
||||
ignore = ignore,
|
||||
)
|
||||
} else {
|
||||
return candidate
|
||||
}
|
||||
}
|
||||
def get(position: Key): Seq[V] =
|
||||
if(Ordering[I].compare(position(dim), split) >= 0){
|
||||
right.get(position)
|
||||
} else {
|
||||
left.get(position)
|
||||
}
|
||||
|
||||
def depth(): Int =
|
||||
Math.max(left.depth(), right.depth())+1
|
||||
def toString(firstPrefix: String, restPrefix: String): String =
|
||||
firstPrefix + s" Dim[$dim] < $split\n" +
|
||||
left.toString(restPrefix+" +-", restPrefix+" | ") + "\n" +
|
||||
right.toString(restPrefix+" +-", restPrefix+" ")
|
||||
|
||||
}
|
||||
|
||||
|
||||
case class Leaf(
|
||||
position: Key,
|
||||
var values: Set[V],
|
||||
dim: Int
|
||||
) extends Node
|
||||
{
|
||||
def insert(otherPosition: Key, otherValue: V): Node =
|
||||
{
|
||||
// println(s"Insert $otherValue")
|
||||
Ordering[I].compare(position(dim), otherPosition(dim)) match {
|
||||
case d if d == 0 =>
|
||||
{
|
||||
// println(s"Same -> $values")
|
||||
val otherComparison =
|
||||
(1 until dimensions).foldLeft(
|
||||
None:Option[(Int, Int)]
|
||||
) {
|
||||
case (f@Some(_), _) => f
|
||||
case (None, i) =>
|
||||
val dimNew = (i+dim)%dimensions
|
||||
val d = Ordering[I].compare(position(dimNew), otherPosition(dimNew))
|
||||
if(d == 0){ None }
|
||||
else { Some(dimNew, d) }
|
||||
}
|
||||
|
||||
otherComparison match {
|
||||
case None => // identical position
|
||||
{
|
||||
// println(" Identical")
|
||||
this.values = this.values + otherValue
|
||||
return this
|
||||
}
|
||||
case Some( (dimNew, d) ) if d >= 0 =>
|
||||
{
|
||||
// println(s" Diff on dim $dimNew from $values (greater)")
|
||||
Inner(
|
||||
split = position(dimNew),
|
||||
dim = dimNew,
|
||||
left = Leaf(otherPosition, Set(otherValue), nextDim(dimNew)),
|
||||
right = copy(dim = nextDim(dimNew)),
|
||||
)
|
||||
}
|
||||
case Some( (dimNew, d) ) =>
|
||||
{
|
||||
// println(s" Diff on dim $dimNew from $values (lesser)")
|
||||
Inner(
|
||||
split = otherPosition(dimNew),
|
||||
dim = dimNew,
|
||||
left = copy(dim = nextDim(dimNew)),
|
||||
right = Leaf(otherPosition, Set(otherValue), nextDim(dimNew)),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
case d if d >= 0 =>
|
||||
{
|
||||
// println("Greater")
|
||||
return Inner(
|
||||
split = position(dim),
|
||||
dim = dim,
|
||||
left = Leaf(otherPosition, Set(otherValue), nextDim(dim)),
|
||||
right = copy(dim = nextDim(dim)),
|
||||
)
|
||||
}
|
||||
case _ =>
|
||||
{
|
||||
// println("Lesser")
|
||||
return Inner(
|
||||
split = otherPosition(dim),
|
||||
dim = dim,
|
||||
left = copy(dim = nextDim(dim)),
|
||||
right = Leaf(otherPosition, Set(otherValue), nextDim(dim)),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
def insertAll(elements: Seq[(Key, V)]): Node =
|
||||
{
|
||||
|
||||
fragmentSeq(values.toSeq.map { (position, _) } ++ elements, dim)
|
||||
}
|
||||
def remove(position: Key, value: V): (Option[Node], Boolean) =
|
||||
{
|
||||
// println(s"Remove $value")
|
||||
if(keysEqual(position, this.position) && this.values.contains(value) )
|
||||
{
|
||||
if(values.size == 1){
|
||||
// println("Replacing self with None")
|
||||
return (None, true)
|
||||
} else {
|
||||
// println(s"Before: ${this.values}")
|
||||
this.values = this.values - value
|
||||
// println(s"After: ${this.values}")
|
||||
return (Some(this), true)
|
||||
}
|
||||
} else {
|
||||
// println(s"Position $position vs ${this.position} or value is different: ${keysEqual(position, this.position)} ; ${this.values.contains(value)}")
|
||||
return (Some(this), false)
|
||||
}
|
||||
}
|
||||
def nearest(position: Key, measure: Distance[I], best: Option[(Leaf, Double)], ignore: (Array[I], Set[V], Double) => Boolean): Option[(Leaf, Double)] =
|
||||
{
|
||||
val distance = measure.pointToPoint(position, this.position)
|
||||
best match {
|
||||
// If this point is *exactly* at position, and we want distinct matches, skip it
|
||||
case _ if ignore(this.position, this.values, distance) => best
|
||||
// If there exists a better point that we've found so far, skip this point
|
||||
case Some( best ) if best._2 <= distance => Some(best)
|
||||
// ... otherwise return this point
|
||||
case _ => Some( (this, distance) )
|
||||
}
|
||||
}
|
||||
def get(position: Key): Seq[V] =
|
||||
{
|
||||
if(keysEqual(this.position, position)){ return this.values.toSeq }
|
||||
else { Seq.empty }
|
||||
}
|
||||
def depth(): Int = 1
|
||||
|
||||
def toString(firstPrefix: String, restPrefix: String): String =
|
||||
firstPrefix + s" <${position.mkString(", ")}> -> ${values.mkString(", ")}"
|
||||
}
|
||||
|
||||
class Builder
|
||||
{
|
||||
val stack = new mutable.Stack[Node]()
|
||||
def inner(split: I, dim: Int): Unit =
|
||||
{
|
||||
val right = stack.pop()
|
||||
val left = stack.pop()
|
||||
stack.push(
|
||||
Inner(split, dim, left, right)
|
||||
)
|
||||
}
|
||||
def leaf(position: Array[I], values: Set[V]): Unit =
|
||||
{
|
||||
stack.push( Leaf(position, values, 0) )
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,21 @@
|
|||
package org.mimirdb.pip.lib
|
||||
|
||||
object Time {
|
||||
def apply[V](log: Double => Unit)(f: => V): V =
|
||||
{
|
||||
val start = System.nanoTime()
|
||||
val ret = f
|
||||
val end = System.nanoTime()
|
||||
log( (end-start) / 1000000000.0 )
|
||||
ret
|
||||
}
|
||||
|
||||
def apply[V](label: String)(f: => V): V =
|
||||
apply( t => println(s"[$label] $t s") )(f)
|
||||
|
||||
case class Timer(label: String)
|
||||
{
|
||||
var tot = 0.0
|
||||
def apply[V](f: => V): V = Time.apply( tot += _ ){ f }
|
||||
}
|
||||
}
|
|
@ -1,8 +1,8 @@
|
|||
package org.mimirdb.pip.udf
|
||||
|
||||
import org.mimirdb.pip.udt.UnivariateDistribution
|
||||
import org.mimirdb.pip.distribution.Discretized
|
||||
import org.mimirdb.pip.distribution.NumericalDistributionFamily
|
||||
import org.mimirdb.pip.distribution.numerical.Discretized
|
||||
import org.mimirdb.pip.distribution.numerical.NumericalDistributionFamily
|
||||
import org.apache.spark.sql.functions
|
||||
|
||||
object Entropy
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
package org.mimirdb.pip.udf
|
||||
|
||||
import org.mimirdb.pip.udt.UnivariateDistribution
|
||||
import org.mimirdb.pip.udt.UnivariateDistributionType
|
||||
import org.mimirdb.pip.distribution.numerical.Discretized
|
||||
import org.mimirdb.pip.distribution.numerical.NumericalDistributionFamily
|
||||
import org.apache.spark.sql.functions
|
||||
|
||||
object Export
|
||||
{
|
||||
def apply(dist: UnivariateDistribution): Array[Byte] =
|
||||
{
|
||||
UnivariateDistributionType.serialize(dist)
|
||||
}
|
||||
|
||||
def udf = functions.udf(apply(_))
|
||||
}
|
|
@ -1,13 +1,13 @@
|
|||
package org.mimirdb.pip.udf
|
||||
|
||||
import org.mimirdb.pip.udt.UnivariateDistribution
|
||||
import org.mimirdb.pip.distribution.Discretized
|
||||
import org.mimirdb.pip.distribution.NumericalDistributionFamily
|
||||
import org.mimirdb.pip.distribution.numerical.Discretized
|
||||
import org.mimirdb.pip.distribution.numerical.NumericalDistributionFamily
|
||||
import org.apache.spark.sql.functions
|
||||
|
||||
object KLDivergence
|
||||
{
|
||||
val BUCKETS = 1000
|
||||
val BUCKETS = 100
|
||||
|
||||
def apply(target: UnivariateDistribution, base: UnivariateDistribution): Double =
|
||||
{
|
||||
|
@ -16,13 +16,46 @@ object KLDivergence
|
|||
Discretized.klDivergence(target.params, base.params)
|
||||
case (_:NumericalDistributionFamily, Discretized) =>
|
||||
Discretized.klDivergence(
|
||||
Discretized(target, Discretized.bins(base.params), 1000),
|
||||
Discretized(target, Discretized.bins(base.params), BUCKETS),
|
||||
base.params
|
||||
)
|
||||
case (Discretized, _:NumericalDistributionFamily) =>
|
||||
Discretized.klDivergence(
|
||||
target.params,
|
||||
Discretized(base, Discretized.bins(target.params), 1000),
|
||||
Discretized(base, Discretized.bins(target.params), BUCKETS),
|
||||
)
|
||||
case (targetFam:NumericalDistributionFamily, baseFam:NumericalDistributionFamily) =>
|
||||
val min =
|
||||
Math.max(
|
||||
Double.MinValue,
|
||||
Math.min(
|
||||
targetFam.min(target.params),
|
||||
baseFam.min(base.params),
|
||||
)
|
||||
)
|
||||
val max =
|
||||
{
|
||||
val tmp =
|
||||
Math.min(
|
||||
Double.MaxValue,
|
||||
Math.max(
|
||||
targetFam.max(target.params),
|
||||
baseFam.max(base.params),
|
||||
)
|
||||
)
|
||||
if(min < tmp){ tmp }
|
||||
else { min + 1.0 }
|
||||
}
|
||||
|
||||
// println(s"BINS FROM $min - $max")
|
||||
val bins =
|
||||
(0 to BUCKETS).map { i =>
|
||||
min + ((max - min) / BUCKETS) * i
|
||||
}.toArray
|
||||
|
||||
Discretized.klDivergence(
|
||||
Discretized(target, bins, BUCKETS),
|
||||
Discretized(base, bins, BUCKETS),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
package org.mimirdb.pip.udf
|
||||
|
||||
import org.mimirdb.pip.udt.UnivariateDistribution
|
||||
import org.mimirdb.pip.distribution.numerical.Discretized
|
||||
import org.mimirdb.pip.distribution.numerical.NumericalDistributionFamily
|
||||
import org.mimirdb.pip.distribution.boolean.Between
|
||||
import org.apache.spark.sql.functions
|
||||
|
||||
object pip_min
|
||||
{
|
||||
def apply(target: UnivariateDistribution): Double =
|
||||
{
|
||||
target.family.asInstanceOf[NumericalDistributionFamily].min(target.params)
|
||||
}
|
||||
|
||||
def udf = functions.udf(apply(_))
|
||||
}
|
||||
object pip_max
|
||||
{
|
||||
def apply(target: UnivariateDistribution): Double =
|
||||
{
|
||||
target.family.asInstanceOf[NumericalDistributionFamily].max(target.params)
|
||||
}
|
||||
|
||||
def udf = functions.udf(apply(_))
|
||||
}
|
||||
|
||||
object pip_p_between
|
||||
{
|
||||
def apply(low: Double, target: UnivariateDistribution, high: Double): Double =
|
||||
{
|
||||
Between.approximateProbability(
|
||||
Between.Params(low, high, target),
|
||||
1000
|
||||
)
|
||||
}
|
||||
|
||||
def udf = functions.udf(apply(_, _, _))
|
||||
}
|
||||
|
||||
object pip_histogram
|
||||
{
|
||||
def apply(low: Double, high: Double, buckets: Int, target: UnivariateDistribution): Array[Double] =
|
||||
{
|
||||
val bucketStep = (high - low) / (buckets)
|
||||
(0 until buckets).map { idx =>
|
||||
pip_p_between(
|
||||
bucketStep * (idx + 0.0),
|
||||
target,
|
||||
bucketStep * (idx + 1.0),
|
||||
)
|
||||
}.toArray
|
||||
}
|
||||
|
||||
def udf = functions.udf(apply(_, _, _, _))
|
||||
}
|
|
@ -0,0 +1,34 @@
|
|||
package org.mimirdb.pip.util
|
||||
|
||||
import java.io.{
|
||||
ByteArrayOutputStream,
|
||||
ByteArrayInputStream,
|
||||
ObjectOutputStream,
|
||||
ObjectInputStream,
|
||||
}
|
||||
|
||||
object ByteArrayUtils
|
||||
{
|
||||
def serialize(o: Any): Array[Byte] =
|
||||
encode { (_: ObjectOutputStream).writeObject(o) }
|
||||
|
||||
def encode( op: ObjectOutputStream => Unit ): Array[Byte] =
|
||||
{
|
||||
val byteBuffer = new ByteArrayOutputStream()
|
||||
val out = new ObjectOutputStream(byteBuffer)
|
||||
op(out)
|
||||
out.flush()
|
||||
return byteBuffer.toByteArray()
|
||||
}
|
||||
|
||||
def deserialize[T <: Serializable](data: Array[Byte]): T =
|
||||
decode(data)((_:ObjectInputStream).readObject().asInstanceOf[T])
|
||||
|
||||
def decode[T](data: Array[Byte])(op: ObjectInputStream => T): T =
|
||||
{
|
||||
val bis = new ByteArrayInputStream(data)
|
||||
val in = new ObjectInputStream(bis)
|
||||
op(in)
|
||||
}
|
||||
|
||||
}
|
|
@ -39,7 +39,7 @@ class CreateGaussObject extends AnyFlatSpec {
|
|||
)
|
||||
|
||||
"A dataframe of Gauss objects" should "have values of type Gauss" in {
|
||||
assert(dfGaussObj.schema("gObj").dataType == RandomVariableType)
|
||||
assert(dfGaussObj.schema("gObj").dataType == udt.RandomVariableType)
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,21 @@
|
|||
package org.mimirdb.pip
|
||||
|
||||
import java.io.BufferedWriter
|
||||
import java.io.FileWriter
|
||||
import java.io.File
|
||||
|
||||
object Test
|
||||
{
|
||||
def log[T](label: String)(f: ((String => Unit) => T)): T =
|
||||
{
|
||||
val outFile = new File(s"out/log/$label")
|
||||
outFile.getParentFile().mkdirs()
|
||||
val out = new BufferedWriter(new FileWriter(outFile))
|
||||
val ret = f(m => {
|
||||
val newline = if(m.endsWith("\n")){ "" } else { "\n" }
|
||||
out.write(m+newline)
|
||||
})
|
||||
out.close()
|
||||
return ret
|
||||
}
|
||||
}
|
|
@ -0,0 +1,131 @@
|
|||
package org.mimirdb.pip.lib
|
||||
|
||||
import org.scalatest.flatspec.AnyFlatSpec
|
||||
import TestData._
|
||||
import org.mimirdb.pip.distribution.Discretized
|
||||
import org.mimirdb.pip.Test
|
||||
|
||||
class HierarchicalClusteringTests extends AnyFlatSpec {
|
||||
|
||||
def testRadii(cluster: HierarchicalClustering.Cluster[Double, Int], parentRadius: Double): Unit =
|
||||
{
|
||||
cluster match {
|
||||
case _:HierarchicalClustering.Singleton[Double, Int] => ()
|
||||
case g:HierarchicalClustering.Group[Double, Int] =>
|
||||
{
|
||||
// Adopt an approximation for now
|
||||
assert(g.radius <= parentRadius * 1.2)
|
||||
g.children.foreach { testRadii(_, g.radius) }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val Strategies =
|
||||
Seq[(String, (IndexedSeq[(Array[Double], Int)], Distance[Double]) => HierarchicalClustering.Cluster[Double, Int])](
|
||||
"Naive" -> { HierarchicalClustering.naive(_, _) },
|
||||
"BottomUp" -> { HierarchicalClustering.bottomUp(_, _) },
|
||||
)
|
||||
|
||||
"Clustering" should "be correct" in {
|
||||
for( (strategy, makeClusters) <- Strategies ){
|
||||
for(Measure <- Seq(ManhattanDistance, EuclideanDistance))
|
||||
{
|
||||
val test = Measure.getClass.getSimpleName.dropRight(9)
|
||||
val clusters =
|
||||
Time(s"$strategy-Cluster-$test"){
|
||||
makeClusters(
|
||||
TEST_POINTS.toIndexedSeq.zipWithIndex,
|
||||
Measure
|
||||
)
|
||||
}
|
||||
testRadii(clusters, Double.PositiveInfinity)
|
||||
var lastRadius = Double.PositiveInfinity
|
||||
|
||||
Test.log(s"$strategy-Cluster-$test-Full"){ log =>
|
||||
log(s"cluster_radius, kl_divergence")
|
||||
for(c <- clusters.orderedIterator) {
|
||||
if(c.radius > 0)
|
||||
{
|
||||
val centroid =
|
||||
positionToBins(
|
||||
Measure.centroid(c.elements.map { _.position }.toSeq)
|
||||
)
|
||||
val divergence =
|
||||
c.elements.map { e =>
|
||||
val bins = positionToBins(e.position)
|
||||
Discretized.klDivergence(bins, centroid)
|
||||
}.max
|
||||
log(s"${c.radius}, $divergence")
|
||||
(c.radius, divergence)
|
||||
}
|
||||
}
|
||||
}
|
||||
Test.log(s"$strategy-Cluster-$test-Cutoff_0.3"){ log =>
|
||||
log(s"cluster_radius, kl_divergence")
|
||||
for(c <- clusters.threshold(0.3)) {
|
||||
if(c.radius > 0)
|
||||
{
|
||||
val centroid =
|
||||
positionToBins(
|
||||
Measure.centroid(c.elements.map { _.position }.toSeq)
|
||||
)
|
||||
val divergence =
|
||||
c.elements.map { e =>
|
||||
val bins = positionToBins(e.position)
|
||||
Discretized.klDivergence(bins, centroid)
|
||||
}.max
|
||||
log(s"${c.radius}, $divergence")
|
||||
(c.radius, divergence)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
it should "run fast" in {
|
||||
val data = TestData.makeData(500).zipWithIndex
|
||||
Time("Hierarchical@500"){
|
||||
HierarchicalClustering.bottomUp(data, EuclideanDistance)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
it should "perform sensibly" in {
|
||||
// BURN IN
|
||||
val data = TestData.makeData(200).zipWithIndex
|
||||
for( (strategy, makeClusters) <- Strategies ){
|
||||
makeClusters(data, EuclideanDistance)
|
||||
}
|
||||
|
||||
val datasets = Seq(
|
||||
100,
|
||||
250,
|
||||
500,
|
||||
1000,
|
||||
2500,
|
||||
5000,
|
||||
10000,
|
||||
// 25000,
|
||||
// 50000
|
||||
).map { s =>
|
||||
s -> TestData.makeData(s).zipWithIndex
|
||||
}
|
||||
|
||||
for( (strategy, makeClusters) <- Strategies ){
|
||||
Test.log(s"$strategy-Time"){ log =>
|
||||
log("size, time_s")
|
||||
for( (size, data) <- datasets )
|
||||
{
|
||||
val skip = (strategy == "Naive" && size >= 1000)
|
||||
if(!skip){
|
||||
Time(t => { log(s"$size, $t"); println(s"$strategy @ $size: $t s") }){
|
||||
makeClusters(data, EuclideanDistance)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,73 @@
|
|||
package org.mimirdb.pip.lib
|
||||
|
||||
import org.scalatest.flatspec.AnyFlatSpec
|
||||
import TestData._
|
||||
|
||||
class KDTreeTests extends AnyFlatSpec {
|
||||
|
||||
"The KD Tree" should "be correct" in {
|
||||
val tree = new KDTree[Double, Int](DIMENSIONS)
|
||||
|
||||
tree.insertAll(TEST_POINTS.zipWithIndex)
|
||||
|
||||
// Roughly balanced (100 nodes)
|
||||
assert(tree.depth < 10)
|
||||
|
||||
// Nearest point to a point in the tree should be itself
|
||||
assert(tree.nearest(TEST_POINTS(0), ManhattanDistance).get._2 == 0)
|
||||
|
||||
// Nearest points to random points should be the
|
||||
// actual minimum
|
||||
for( (q, qid) <- TEST_QUERIES.zipWithIndex){
|
||||
|
||||
// Compute the point-to-point distance between the
|
||||
// query and every other point, and then recover
|
||||
// the idx of the closest point.
|
||||
val actualClosestId =
|
||||
TEST_POINTS.map {
|
||||
ManhattanDistance.pointToPoint(_, q)
|
||||
}.zipWithIndex.minBy( _._1 )._2
|
||||
|
||||
// Use the tree to find the nearest point
|
||||
val nearestId =
|
||||
tree.nearest(q, ManhattanDistance).get
|
||||
|
||||
assert(nearestId._2 == actualClosestId,
|
||||
s"Query $qid not closest"
|
||||
)
|
||||
}
|
||||
|
||||
val AMOUNT_TO_KEEP = 20
|
||||
val keep = TEST_POINTS.slice(0, AMOUNT_TO_KEEP)
|
||||
val drop = TEST_POINTS.slice(AMOUNT_TO_KEEP, TEST_POINTS.size)
|
||||
|
||||
for( (p, i) <- drop.zipWithIndex )
|
||||
{
|
||||
tree.remove(p, i+AMOUNT_TO_KEEP)
|
||||
}
|
||||
|
||||
// println(tree.toString)
|
||||
|
||||
// Nearest points to random points should still be the
|
||||
// actual minimum
|
||||
for( (q, qid) <- TEST_QUERIES.zipWithIndex){
|
||||
|
||||
// Compute the point-to-point distance between the
|
||||
// query and every other point, and then recover
|
||||
// the idx of the closest point.
|
||||
val actualClosestId =
|
||||
keep.map {
|
||||
ManhattanDistance.pointToPoint(_, q)
|
||||
}.zipWithIndex.minBy( _._1 )._2
|
||||
|
||||
// Use the tree to find the nearest point
|
||||
val nearestId =
|
||||
tree.nearest(q, ManhattanDistance).get
|
||||
|
||||
assert(nearestId._2 == actualClosestId,
|
||||
s"Query $qid not closest"
|
||||
)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
|
@ -0,0 +1,44 @@
|
|||
package org.mimirdb.pip.lib
|
||||
|
||||
object Scatterplot
|
||||
{
|
||||
def apply(elems: Iterable[(Double, Double)],
|
||||
width: Int = 80,
|
||||
height: Int = 40
|
||||
): String =
|
||||
{
|
||||
val buffer =
|
||||
Array.fill(height){ Array.fill(width)(" ") }
|
||||
|
||||
val minX = Math.min(elems.map { _._1 }.min, 0)
|
||||
val maxX = Math.max(elems.map { _._1 }.max, 0)
|
||||
val offsetX = maxX - minX
|
||||
|
||||
val minY = Math.min(elems.map { _._2 }.min, 0)
|
||||
val maxY = Math.max(elems.map { _._2 }.max, 0)
|
||||
val offsetY = maxY - minY
|
||||
|
||||
def xToScreen(x: Double): Int =
|
||||
((x - minX) / offsetX * (width-1)).toInt
|
||||
def yToScreen(y: Double): Int =
|
||||
((y - minY) / offsetY * (height-1)).toInt
|
||||
|
||||
val x0 = xToScreen(0.0)
|
||||
for( y <- 0 until height )
|
||||
{
|
||||
buffer(y)(x0) = "|"
|
||||
}
|
||||
val y0 = yToScreen(0.0)
|
||||
for( x <- 0 until width )
|
||||
{
|
||||
buffer(y0)(x) = "-"
|
||||
}
|
||||
buffer(y0)(x0) = "+"
|
||||
|
||||
for( (x, y) <- elems ){
|
||||
buffer(yToScreen(y))(xToScreen(x)) = "*"
|
||||
}
|
||||
|
||||
return buffer.reverse.map { _.mkString }.mkString("\n")
|
||||
}
|
||||
}
|
|
@ -0,0 +1,204 @@
|
|||
package org.mimirdb.pip.lib
|
||||
import org.mimirdb.pip.distribution.Discretized
|
||||
|
||||
import scala.util.Random
|
||||
|
||||
object TestData
|
||||
{
|
||||
|
||||
val DIMENSIONS = 10
|
||||
val BIN_SIZE = 1.0 / DIMENSIONS
|
||||
|
||||
def positionToBins(position: Array[Double]): Seq[Discretized.Bin] =
|
||||
{
|
||||
position.zipWithIndex.map { case (b, i) =>
|
||||
Discretized.Bin(i * BIN_SIZE, (i+1) * BIN_SIZE, b)
|
||||
}.toSeq
|
||||
}
|
||||
|
||||
trait BaseDistance extends Distance[Double]
|
||||
{
|
||||
val min = 0.0
|
||||
val max = 1.0
|
||||
def pointToPlane(a: Array[Double], b: Double, dim: Int): Double =
|
||||
{
|
||||
Math.abs(a(dim) - b)
|
||||
}
|
||||
def centroid(a: Iterable[Array[Double]]): Array[Double] =
|
||||
{
|
||||
val ret = Array.ofDim[Double](DIMENSIONS)
|
||||
for(pt <- a)
|
||||
{
|
||||
for(i <- 0 until DIMENSIONS)
|
||||
{
|
||||
ret(i) += pt(i)
|
||||
}
|
||||
}
|
||||
for(i <- 0 until DIMENSIONS)
|
||||
{
|
||||
ret(i) /= a.size
|
||||
}
|
||||
return ret
|
||||
}
|
||||
}
|
||||
object ManhattanDistance extends BaseDistance
|
||||
{
|
||||
def pointToPoint(a: Array[Double], b: Array[Double]): Double =
|
||||
{
|
||||
var tot = 0.0
|
||||
for(i <- 0 until a.size)
|
||||
{
|
||||
tot += Math.abs(a(i) - b(i))
|
||||
}
|
||||
return tot
|
||||
}
|
||||
}
|
||||
object EuclideanDistance extends BaseDistance
|
||||
{
|
||||
def pointToPoint(a: Array[Double], b: Array[Double]): Double =
|
||||
{
|
||||
var tot = 0.0
|
||||
for(i <- 0 until a.size)
|
||||
{
|
||||
val v = Math.abs(a(i) - b(i))
|
||||
tot += v * v
|
||||
}
|
||||
Math.sqrt(tot)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// from random import random
|
||||
// print("\n".join(
|
||||
// "Array(" + ", ".join(normalized)+"),"
|
||||
// for i in range(100)
|
||||
// for row in [ [random() for j in range(10)] ]
|
||||
// for row_tot in [ sum(row) ]
|
||||
// for normalized in [ [ str(j / row_tot) for j in row ] ]
|
||||
// ))
|
||||
|
||||
val TEST_POINTS = Seq(
|
||||
Array(0.02786369384379577, 0.2135647897466291, 0.11666002519724866, 0.16060600994096885, 0.1881947749642125, 0.06403963097044776, 0.1568800096393837, 0.008981488712747435, 0.03297208945761692, 0.030237487526949314),
|
||||
Array(0.16075003689412376, 0.15606698494069582, 0.16987815773456, 0.026862140241724158, 0.072614915707935, 0.18032312339045223, 0.019560722743463082, 0.012569542805068614, 0.11910516750709621, 0.08226920803488132),
|
||||
Array(0.17968588526797769, 0.034774845334380695, 0.1035055839973817, 0.1767062943179949, 0.10444261644493995, 0.09581453253635883, 0.12469490440947303, 0.019050600283866344, 0.060640468089698196, 0.10068426931792872),
|
||||
Array(0.054469864839301774, 0.15977197188406447, 0.05522890760609085, 0.14660316269034734, 0.12534406923791186, 0.03281310796715435, 0.02623648065320645, 0.12556136154727013, 0.15822772301042895, 0.11574335056422387),
|
||||
Array(0.036597335767986953, 0.06980192128780272, 0.20613728928706976, 0.13349734163305085, 0.08570545470184154, 0.16198326683513503, 0.03416408769409085, 0.011102285434559944, 0.06722279284149091, 0.19378822451697142),
|
||||
Array(0.014696065601550116, 0.08076142961134138, 0.17497911095843408, 0.16999633242830414, 0.03714803751862688, 0.1655275169641564, 0.05210590264424636, 0.030203608779210134, 0.13995137099241034, 0.13463062450172023),
|
||||
Array(0.09544182650487694, 0.12304472947862866, 0.06700532112081119, 0.09335611140516697, 0.05852684038662141, 0.15129577402652522, 0.005101140939778661, 0.13113776195353483, 0.16922048978853085, 0.10587000439552527),
|
||||
Array(0.09841302725336286, 0.004609906162025474, 0.0020254613404004705, 0.099335706548862, 0.19778223934417596, 0.13103099367880117, 0.020242231414964507, 0.09104737013055085, 0.24827107780267862, 0.10724198632417814),
|
||||
Array(0.10053434124679797, 0.09690053663156127, 0.016961064845680825, 0.1284012417040199, 0.07593036730094833, 0.03718316159082512, 0.27539117548986564, 0.07420006124257486, 0.08454724028355125, 0.10995080966417474),
|
||||
Array(0.14553640850413638, 0.022947418483945308, 0.13935357739008655, 0.10806158990354552, 0.03738831867137737, 0.1524217388990238, 0.08170769648381157, 0.11438558584127417, 0.07053737377098514, 0.1276602920518141),
|
||||
Array(0.13905750256508295, 0.06026839869850839, 0.14362824977673022, 0.08286464616062053, 0.043771751931986944, 0.16495463319563128, 0.05412684077454717, 0.001291793672741278, 0.15791801612536574, 0.15211816709878545),
|
||||
Array(0.03450571178106437, 0.14910161747541997, 0.20936934065843624, 0.02070157037201108, 0.1756759095703582, 0.08506914704162599, 0.15409531889595104, 0.004045881843195852, 0.09210520768839635, 0.07533029467354076),
|
||||
Array(0.2057234435786517, 0.10493719048149984, 0.020287338472567362, 0.20254672365888846, 0.028375418482746502, 0.01625435048306782, 0.14894698664688855, 0.20213846581757053, 0.03591455506722132, 0.03487552731089792),
|
||||
Array(0.06596983800464223, 0.05574212028918465, 0.14620125059558675, 0.05085120836064347, 0.15319423683346958, 0.12287030952445653, 0.04263171652560145, 0.09451250962985222, 0.15044704185459362, 0.1175797683819695),
|
||||
Array(0.18555377385749874, 0.07871307216511918, 0.15499599602862116, 0.1358584464428079, 0.060900665316316674, 0.09300938609128262, 0.1317806141718268, 0.08414074629373457, 0.03520282582526993, 0.039844473807522374),
|
||||
Array(0.14833652803655697, 0.05918767983340056, 0.17614058216936196, 0.11664356250084375, 0.02377138706047284, 0.05986212068247505, 0.06472307371490708, 0.18239528252523526, 0.004068441917034033, 0.16487134155971256),
|
||||
Array(0.08235276489612064, 0.10738346146426181, 0.21608377210555152, 0.17469183114182013, 0.105590268615304, 0.10751562144898245, 0.0032152082193963667, 0.0706292192391151, 0.043298004240022135, 0.08923984862942576),
|
||||
Array(0.10826799993848797, 0.03406112185025378, 0.10336881344936165, 0.18260565533775105, 0.14081818689739342, 0.02407061707107449, 0.0959948555290124, 0.12232632842609834, 0.1850108304024592, 0.0034755910981077477),
|
||||
Array(0.05256010963695636, 0.019164389480340625, 0.1127586018343423, 0.13136201930184083, 0.11227176484238988, 0.16928642657767784, 0.2204302839102031, 0.06404268034494189, 0.009416095747722201, 0.10870762832358485),
|
||||
Array(0.15954630727312244, 0.1916376979674094, 0.050378024749961677, 0.027356529819862578, 0.09989300492655918, 0.05340800601232893, 0.04646402819832021, 0.02768228194399227, 0.15448831033308144, 0.1891458087753619),
|
||||
Array(0.019200457769131485, 0.1378188474418964, 0.09384242345414984, 0.07951286357790391, 0.07556775333478064, 0.11200594535143274, 0.16501526319239118, 0.0448964581181867, 0.15480810624984326, 0.11733188151028393),
|
||||
Array(0.05206307179258512, 0.07055277248158913, 0.08428035122336934, 0.02917298975957042, 0.009695474765426227, 0.1611917483375634, 0.0710332933784978, 0.1743145539140939, 0.17579315105180218, 0.1719025932955026),
|
||||
Array(0.08445391973896008, 0.11651733184025334, 0.03511139009625676, 0.027928506021393452, 0.15396570613761867, 0.053631895244169545, 0.015542676077573613, 0.1182971185981754, 0.2272023118283988, 0.16734914441720045),
|
||||
Array(0.13930886603994833, 0.0014655710328042983, 0.1972830924367536, 0.054212838876357715, 0.16334652564534255, 0.04378019741714191, 0.1937464410874603, 0.001520443744550877, 0.08040206888269434, 0.124933954836946),
|
||||
Array(0.001501069391163595, 0.11912933962270829, 0.1139100036368197, 0.1697588821026031, 0.2002010249473232, 0.011009229447535176, 0.07085825240084082, 0.06118818225252381, 0.11645801786380054, 0.13598599833468167),
|
||||
Array(0.11198651380073302, 0.07016838904931964, 0.10501991622186009, 0.17093428610316377, 0.14139783951894258, 0.1359566584991451, 0.024403446712022783, 0.06283315341496183, 0.07431085312299472, 0.10298894355685662),
|
||||
Array(0.02674983175823556, 0.14102838319019134, 0.07471775101084588, 0.100662073794369, 0.1249608784899519, 0.01168085763966831, 0.10895479478392615, 0.14700798084213076, 0.08822437146565835, 0.17601307702502278),
|
||||
Array(0.07282509184093233, 0.16462082249063442, 0.011944597581792534, 0.05869211153850399, 0.17837675679678575, 0.14137448988829124, 0.03625761144310586, 0.20002926104446406, 0.11505874611373491, 0.020820511261754937),
|
||||
Array(0.07312467632109246, 0.07131405646070389, 0.05364984326304017, 0.1275885897183236, 0.15203832211005996, 0.14688485553287495, 0.053414244252040104, 0.11218171860255759, 0.10747214301281464, 0.10233155072649258),
|
||||
Array(0.052264907847167585, 0.13996869391737796, 0.1237850202151384, 0.15548961890806612, 0.039113931956737576, 0.12313879212039727, 0.10013984076208021, 0.02443537013724052, 0.1982285126031362, 0.04343531153265826),
|
||||
Array(0.060913521771762354, 0.10986915923669321, 0.1634626908013808, 0.06175139118263744, 0.04459449533029302, 0.08455861393061977, 0.009479421364473285, 0.05025890017186922, 0.21439085061081703, 0.20072095559945394),
|
||||
Array(0.015217540126475011, 0.1517449337660542, 0.14459441588156344, 0.14173612515367662, 0.04711350556716002, 0.024059517089045322, 0.23635518579554576, 0.1487064632659587, 0.015194521500470206, 0.07527779185405052),
|
||||
Array(0.05745323442281053, 0.17324336310424626, 0.1878853267088442, 0.06831265890243099, 0.09304489444215407, 0.030537555722050774, 0.1862273749084694, 0.07137934514175667, 0.12119313329031117, 0.01072311335692596),
|
||||
Array(0.08860513632762504, 0.04173430420618502, 0.16594185265286898, 0.11766022206961535, 0.16943251635941364, 0.09213380182191239, 0.0835760598725975, 0.05035706965912134, 0.018786463035932434, 0.17177257399472837),
|
||||
Array(0.0895669672527639, 0.032415022388931986, 0.14551951973176522, 0.03414782270970075, 0.10021140103430214, 0.09915006195274664, 0.12865709512186582, 0.15134017300691519, 0.10353937487222276, 0.1154525619287855),
|
||||
Array(0.03442521286270366, 0.06035874299364749, 0.2437951665921442, 0.043501578124571484, 0.0626675613469455, 1.9433901486353943e-06, 0.09810751486613992, 0.13263672018606582, 0.1872474160869221, 0.13725814355071117),
|
||||
Array(0.05745233251260445, 0.056509977834535624, 0.028561488118698523, 0.061159065978089434, 0.23159729489044467, 0.08024399331098168, 0.23839854038499997, 0.07219252165186028, 0.060291594577759254, 0.1135931907400262),
|
||||
Array(0.006545298567172542, 0.012024749274310546, 0.1981070785543689, 0.010983495021697626, 0.17095342603363287, 0.15511216552845877, 0.20061590070833488, 0.16898316302017716, 0.014957858914186684, 0.06171686437765987),
|
||||
Array(0.112452632250053, 0.0873577911566027, 0.09808604910468359, 0.11268407968923635, 0.028850384504150794, 0.22762058873094965, 0.044980420796501885, 0.2328021417794967, 3.159367582590486e-05, 0.05513431831249941),
|
||||
Array(0.1370870974799079, 0.1567295881948363, 0.045537507740712004, 0.1408988907940065, 0.08538191870740057, 0.02688394461781232, 0.10178178483134608, 0.018535253815448512, 0.061416252984871106, 0.22574776083365886),
|
||||
Array(0.0828723419576967, 0.09456812652251853, 0.087599663117369, 0.06940196305570533, 0.15814054917175766, 0.08649935758047966, 0.1511194981586785, 0.030327479239717826, 0.030069969120044247, 0.20940105207603266),
|
||||
Array(0.14762781614886256, 0.04023623467374839, 0.1748064132090026, 0.0026407953142249434, 0.0015271528530661253, 0.15465689623253037, 0.07383003179295365, 0.1908504597043735, 0.1350662560688042, 0.0787579440024339),
|
||||
Array(0.11882225305025361, 0.05836498159445185, 0.007323881155458715, 0.11200456710077376, 0.04877839292261081, 0.17700695269172076, 0.05125787050557799, 0.1245864966070685, 0.15069411272264194, 0.15116049164944198),
|
||||
Array(0.13112175832248862, 0.08319329882632863, 0.01760138535760466, 0.027501955918210818, 0.11123330954725533, 0.04399177715468523, 0.13546971264403335, 0.21036905882472404, 0.20325115371673885, 0.03626658968793039),
|
||||
Array(0.12312411171596355, 0.15851997068360643, 0.07252339856567444, 0.027003412117436752, 0.1365971405823504, 0.13664381594953265, 0.09357162473310766, 0.058408790174600246, 0.11319857048567528, 0.08040916499205246),
|
||||
Array(0.14236614881385598, 0.12973728493303155, 0.09450314652995326, 0.013288030160468837, 0.13727428356328655, 0.03659259516042904, 0.18996513121810632, 0.08930002405246872, 0.08382435962082861, 0.08314899594757105),
|
||||
Array(0.10236780204015723, 0.10035637646153997, 0.16609171402360112, 0.07102726518267725, 0.19238063351329496, 0.03111588057898452, 0.06596072225967456, 0.11939758795170409, 0.12277570183113637, 0.02852631615722993),
|
||||
Array(0.038427451265671826, 0.03552767539437494, 0.09219452547216199, 0.020730190479507946, 0.0525290169775004, 0.1971651865499051, 0.20127608608880615, 0.05659518743569417, 0.2217251820500291, 0.08382949828634836),
|
||||
Array(0.06906678154873087, 0.06424257692768835, 0.16045435036892947, 0.15674198874413098, 0.10205174258571183, 0.1600182411888799, 0.07227763616472796, 0.09726723426667229, 0.10729473698096009, 0.010584711223568294),
|
||||
Array(0.04303408979606997, 0.11506659899040815, 0.10512435670911709, 0.025136402035547175, 0.043604950211019705, 0.1463989451069652, 0.12135301739113916, 0.09917783050679913, 0.11126446973527448, 0.1898393395176599),
|
||||
Array(0.1383421159312698, 0.017020935780186985, 0.11095773109625401, 0.05317066820330746, 0.17993298108381683, 0.10922460166249907, 0.15272433817653416, 0.05730233669034009, 0.07162200938236066, 0.10970228199343099),
|
||||
Array(0.08238869085687614, 0.19828645657813956, 0.1033809052822856, 0.14219789273160144, 0.022054842876676983, 0.13446380986140266, 0.008520389446564379, 0.17372304480237788, 0.013514872280462185, 0.12146909528361316),
|
||||
Array(0.07146308358640838, 0.11554509979386697, 0.12195456606203131, 0.11026445069573283, 0.050695003724245846, 0.08847974634326809, 0.16501175848054325, 0.03635828429806049, 0.06892282119405908, 0.1713051858217837),
|
||||
Array(0.031003911118011148, 0.18453324078441044, 0.04589523340882497, 0.0014457919011631284, 0.19811328040024434, 0.17542175454472703, 0.04407676435028674, 0.15322298978330823, 0.11913313399701377, 0.04715389971201023),
|
||||
Array(0.12136288250452541, 0.12200057525297012, 0.0959592751166102, 0.13479202548737337, 0.04542223158973626, 0.098389670106214, 0.1039037274421444, 0.08616931699326366, 0.14065143310896572, 0.05134886239819683),
|
||||
Array(0.0610841293894362, 0.10188113215530137, 0.1612521333551918, 0.1603116702924079, 0.11483383120307003, 0.14787591994528515, 0.0401354177270283, 0.021346682661379342, 0.15213496497507203, 0.039144118295827764),
|
||||
Array(0.18620934268608338, 0.07256616550566836, 0.05405720101212947, 0.019072243525862603, 0.11437576612350972, 0.06392597855865172, 0.053496498025236684, 0.18402861234038595, 0.12323023175142413, 0.12903796047104776),
|
||||
Array(0.09612630437653323, 0.11218060999035627, 0.08494531942445314, 0.1545252974988426, 0.18043534216852417, 0.13695460159572437, 0.1373109579971919, 0.0306844560238321, 0.03264618464079689, 0.034190926283745204),
|
||||
Array(0.20457414906586738, 0.04001494152709804, 0.0832459217129384, 0.16588363609797907, 0.20215312685412995, 0.046690283788707876, 0.01062427956004096, 0.04897152328932626, 0.19156571500543407, 0.00627642309847815),
|
||||
Array(0.11129899181354237, 0.04683584505447576, 0.1572992495502157, 0.14641700712243083, 0.17155552523447126, 0.03075988820331225, 0.1656223506124421, 0.01558815127355019, 0.005279112927205255, 0.14934387820835415),
|
||||
Array(0.06767382466414455, 0.05375831181296221, 0.03991282557279321, 0.14448155207677496, 0.12640537574685487, 0.08587846976098351, 0.12024219323027723, 0.17571017720793888, 0.0028146701812585095, 0.18312259974601194),
|
||||
Array(0.15197947386288257, 0.06333665700705148, 0.09880659263608725, 0.08889102627448789, 0.1425165332698214, 0.12499207001522576, 0.09140509463456094, 0.16261101106037534, 0.06090731422621665, 0.014554227013290677),
|
||||
Array(0.027192311131099413, 0.17948327140041206, 0.17128646886226406, 0.013561308532162959, 0.15577896801106117, 0.06560073101026488, 0.12191557858569357, 0.020534308512634197, 0.14721014185043346, 0.09743691210397429),
|
||||
Array(0.09589145399363277, 0.161883046219811, 0.15958616151973698, 0.17358072813034178, 0.07913872162732073, 0.10915531019737526, 0.03333300622497551, 0.1196088388997209, 0.043633695057876225, 0.024189038129208553),
|
||||
Array(0.10240296921547659, 0.07979930760561445, 0.07494419877571525, 0.026943945001596473, 0.11299465698027668, 0.1251080756117822, 0.10742320477840746, 0.11585913821208153, 0.12234511368543098, 0.13217939013361843),
|
||||
Array(0.018272481282677513, 0.07881215294990411, 0.07329684760969837, 0.10577098545384651, 0.13834150656071298, 0.2612620102866262, 0.05146690621582695, 0.07516672003663458, 0.07702984811722105, 0.12058054148685161),
|
||||
Array(0.038323843482331955, 0.16082109101115258, 0.15622897692374535, 0.021090597948624654, 0.1631008356609383, 0.0577162642655181, 0.14948296472662748, 0.04081398698490246, 0.16827909007455513, 0.044142348921603874),
|
||||
Array(0.13905243603651857, 0.05937272808206395, 0.1295316158368605, 0.008030893488493477, 0.11269955036810822, 0.08330801795412077, 0.10463908779528827, 0.16710523981841136, 0.12119521465912618, 0.07506521596100874),
|
||||
Array(0.007052506966823561, 0.15005628267684076, 0.10592010799082947, 0.19374549331925403, 0.08510635081973832, 0.06266414145516808, 0.06929792991232463, 0.06720234090262948, 0.15025971652801526, 0.10869512942837657),
|
||||
Array(0.14076756079430117, 0.14219908274155096, 0.04919064264593424, 0.0041672195648900165, 0.12277866590145003, 0.0996550860392083, 0.18192102615499783, 0.1161200994003205, 0.007787134091381239, 0.13541348266596562),
|
||||
Array(0.08035870130087652, 0.041408830577689376, 0.05721717327547464, 0.05313940905601734, 0.14078976126274104, 0.0430548858233374, 0.1875912075375216, 0.12845666782358606, 0.15602065513815644, 0.11196270820459969),
|
||||
Array(0.08464213592351863, 0.1663362794095464, 0.11701432794811453, 0.1132305246217393, 0.04537987062087066, 0.14817422552931084, 0.0801461215400441, 0.04863611505341482, 0.17495773443566762, 0.02148266491777331),
|
||||
Array(0.14387591124606855, 0.0036153959999540007, 0.041261288784413375, 0.164162789329692, 0.09704271492032002, 0.1515295948870589, 0.07687415942230265, 0.0749058835388297, 0.1286797155368104, 0.1180525463345503),
|
||||
Array(0.11678527688359384, 0.05701481306075269, 0.036233802076254644, 0.016007525984071427, 0.19053078281109453, 0.08757587756821966, 0.17679759340566178, 0.08083439825244075, 0.17569698909701914, 0.06252294086089163),
|
||||
Array(0.10575602840156381, 0.08829117000476625, 0.1344773048869035, 0.1215749081064526, 0.055782367675050654, 0.06020409328653541, 0.052553470285285044, 0.1287637478862188, 0.1931919200521899, 0.059404989415033915),
|
||||
Array(0.05106095374343528, 0.07067388123959531, 0.007635580569668982, 0.07749573541523033, 0.0002908062524857475, 0.13335924831853244, 0.2541104723669051, 0.1926156883238885, 0.14400886481888273, 0.0687487689513756),
|
||||
Array(0.1946837775687716, 0.09941344324391299, 0.015689877925444753, 0.08915416153416461, 0.10047568992072241, 0.07878857782918293, 0.21183909989412525, 0.08800928522146421, 0.05409522006270209, 0.06785086679950908),
|
||||
Array(0.02089287310094577, 0.12865061496681493, 0.14867470087523896, 0.1290356586660657, 0.16187900052758378, 0.09629872811391459, 0.13550646865875327, 0.1432009110907077, 0.01789172189110897, 0.017969322108866054),
|
||||
Array(0.12889069548710222, 0.07177805872946336, 0.015574519037764451, 0.16599706072338122, 0.17635827057529121, 0.05360207894659873, 0.025043073723883084, 0.00841039526102416, 0.11703540427845688, 0.23731044323703468),
|
||||
Array(0.07692723997162568, 0.12477992505848887, 0.15233127930472243, 0.11490899123875933, 0.07367056097970794, 0.08470471978223348, 0.10988906546704379, 0.13425224747056186, 0.11930929073472826, 0.009226679992128485),
|
||||
Array(0.1476803839992263, 0.11737829061979849, 0.021037777863082675, 0.019593901851785844, 0.12152028001323316, 0.2600949943989843, 0.048220905571511145, 0.08398824455572904, 0.11187382280204103, 0.06861139832460808),
|
||||
Array(0.11553108094874572, 0.15160135563262345, 0.0765114299190399, 0.16035122231239185, 0.05160218332494809, 0.1351482916622384, 0.05717312513079071, 0.06744072248528292, 0.10781618199896906, 0.07682440658496992),
|
||||
Array(0.0013460797570203123, 0.19398646249663465, 0.16505519347895306, 0.006031889147179139, 0.06335279956378362, 0.05696159938098616, 0.17270295416968828, 0.042022589068066066, 0.20004958681511806, 0.09849084612257067),
|
||||
Array(0.110901595666389, 0.02813994433050754, 0.13287394783050363, 0.016796897890290808, 0.0953827610982354, 0.13966083811301064, 0.14956378010013927, 0.11106465442773401, 0.07778370592297548, 0.13783187462021415),
|
||||
Array(0.08047800613791091, 0.20550193792608948, 0.028433554624275197, 0.06972349687277409, 0.1895224006584368, 0.04394132302671363, 0.05391915581735497, 0.14343917585601257, 0.06736122532122182, 0.11767972375921057),
|
||||
Array(0.08120524567415402, 0.1325019699972708, 0.1403129142885263, 0.1155021290536291, 0.07928062442710866, 0.1029096864052427, 0.11380150370781267, 0.05044786372181201, 0.09677906233820972, 0.08725900038623399),
|
||||
Array(0.07341055713353306, 0.19496109162308006, 0.0991356776684781, 0.11550769181843593, 0.04818315020253669, 0.08361951827201225, 0.16335956188944104, 0.11659543070463235, 0.07375150317914303, 0.031475817508707614),
|
||||
Array(0.00931456554807569, 0.01910089441469917, 0.09513445045078563, 0.16626585205771405, 0.1332822725601691, 0.13857787263650406, 0.15434544153776117, 0.09695036299392755, 0.09812464404279193, 0.0889036437575716),
|
||||
Array(0.17448370568521204, 0.14180587401771316, 0.028245844069491873, 0.06132943246747549, 0.17445513703908866, 0.01594226338873566, 0.03770125570411761, 0.09222827138864834, 0.17949961294141703, 0.09430860329810005),
|
||||
Array(0.0640398951259249, 0.1840760812877578, 0.14655638655976444, 0.12672848715574925, 0.03719907009330565, 0.049839410953140116, 0.05737874189763877, 0.1605711480344574, 0.14377626239622437, 0.02983451649603743),
|
||||
Array(0.01554080270021833, 0.06282138247723013, 0.09780175378296682, 0.014820617689894525, 0.2115203235999544, 0.1882009025723618, 0.18598745074906148, 0.028158752817529553, 0.010473640701422066, 0.18467437290936095),
|
||||
Array(0.15650802915539924, 0.06260347808619696, 0.014681122978988595, 0.02293146608507115, 0.18076543062564712, 0.0794753490280972, 0.005596700126883443, 0.07231201241717639, 0.21854292519888413, 0.18658348629765584),
|
||||
Array(0.14158695686146752, 0.073752243870097, 0.12584239274702985, 0.04541979041504646, 0.08119921746780319, 0.08921688454251944, 0.061174608112587014, 0.12895619869703842, 0.1485642567225797, 0.10428745056383144),
|
||||
Array(0.004493512534186644, 0.22891409798555748, 0.11765683068937702, 0.013775782465550352, 0.13714842098668178, 0.10207697690191013, 0.1754825455443045, 0.11149658214106123, 0.02660961429105109, 0.08234563646031984),
|
||||
Array(0.016162492311754456, 0.08046306590581803, 0.006204266960017437, 0.13953545564278846, 0.022109927432781153, 0.13036969386397787, 0.21797705986478916, 0.12165267548642489, 0.079320128896581, 0.18620523363506764),
|
||||
Array(0.16330552446703467, 0.07729365834616232, 0.08491683478227961, 0.05372076913470597, 0.10603150542124849, 0.14163279994268874, 0.09806560340082228, 0.0665972382478773, 0.09541506488877434, 0.11302100136840619),
|
||||
Array(0.14870213069734303, 0.16705476117167162, 0.050864827435860155, 0.17029113854684527, 0.13547987517179494, 0.017422927951837527, 0.004147161243030273, 0.1483140185384386, 0.14604330222833023, 0.011679857014848218),
|
||||
Array(0.08992489387572739, 0.10518820408279016, 0.07312023884891752, 0.014449732818078313, 0.05075759011625807, 0.08390089421392577, 0.21442543374895132, 0.1127530473157631, 0.01838619941558753, 0.23709376556400089),
|
||||
Array(0.05685013365480455, 0.12323587951459787, 0.1385773832788023, 0.08673781720966815, 0.0472449947192977, 0.1213310020696927, 0.10356254661000977, 0.10036077345290412, 0.11919967363196823, 0.10289979585825443),
|
||||
Array(0.06210250850199482, 0.07394139633169498, 0.11347634992174827, 0.17509649537205516, 0.02603702022320766, 0.06454778194484233, 0.2009116520031418, 0.04134136491143604, 0.14825132903901547, 0.09429410175086352),
|
||||
)
|
||||
|
||||
val TEST_QUERIES = Seq(
|
||||
Array(0.09435371769303323, 0.1233102182440501, 0.035230203246711805, 0.14140749901031638, 0.06014744372531397, 0.0870518687946835, 0.13321966573485866, 0.04078493433809051, 0.12734968037671182, 0.15714476883623013),
|
||||
Array(0.17694310947774938, 0.17566273907454327, 0.04360988797718458, 0.03794594810984841, 0.07369866424191941, 0.061032184207235565, 0.0698539188324209, 0.1553009185894869, 0.041301383797290085, 0.16465124569232142),
|
||||
Array(0.05279511148513846, 0.05601157286831524, 0.14154814872259058, 0.10430107627142228, 0.15002313340757642, 0.06950414061236984, 0.11364641810832642, 0.12601144806948125, 0.0783011297392839, 0.10785782071549557),
|
||||
Array(0.1065115083401876, 0.17946469555540942, 0.03405652390742008, 0.1247966973583643, 0.18874236837731603, 0.12074782714672913, 0.0026503278103620063, 0.0519239521363894, 0.18352257373704428, 0.007583525630777577),
|
||||
Array(0.03667987113643319, 0.09009103953860809, 0.12088138476553231, 0.12088562286925937, 0.1264710071620609, 0.12832491899952636, 0.12412938396670012, 0.14543954572768858, 0.017121037654405494, 0.08997618817978538),
|
||||
Array(0.11157574921223326, 0.1919526862470603, 0.024086150217414424, 0.09159246885001404, 0.05265214855351364, 0.07075989738816084, 0.1399382629458219, 0.19541514307445249, 0.01929324251064926, 0.10273425100067997),
|
||||
Array(0.05969795877369263, 0.20241316322980035, 0.0638817532580585, 0.04689840818476116, 0.06803745210991967, 0.2334322835172291, 0.17343454057411228, 0.12209764056482299, 0.002864067981462762, 0.02724273180614056),
|
||||
Array(0.18839893622558046, 0.10620457231911698, 0.1292661879605831, 0.07176892412623401, 0.12387960260043697, 0.032602352038677444, 0.07875294484104155, 0.10347908925014886, 0.06552736954166116, 0.10012002109651934),
|
||||
Array(0.17025233466774073, 0.11121705371426796, 0.1556308914180297, 0.17664557463788913, 0.011031067141640242, 0.04954155612013968, 0.06614698282241012, 0.03981079415920263, 0.17042516269440885, 0.04929858262427107),
|
||||
Array(0.18831499479266145, 0.04636745966774445, 0.0018217837546960797, 0.1915066571154377, 0.07135828610276607, 0.14086006010078103, 0.051910285700104544, 0.07177581221219843, 0.18672258527455976, 0.0493620752790506),
|
||||
)
|
||||
|
||||
def makeData(size: Int): IndexedSeq[Array[Double]] =
|
||||
{
|
||||
(0 until size).map { _ =>
|
||||
(0 until DIMENSIONS).map { _ =>
|
||||
Random.nextDouble
|
||||
}.toArray
|
||||
}.toArray.toIndexedSeq
|
||||
}
|
||||
}
|
|
@ -3,31 +3,206 @@ package org.mimirdb.pip
|
|||
import org.apache.spark.sql.SparkSession
|
||||
import org.apache.log4j.{ Logger, Level }
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.sql.Row
|
||||
import org.apache.spark.ml.feature.{ StringIndexer, VectorAssembler, SQLTransformer }
|
||||
import scala.util.Random
|
||||
|
||||
/* To read file from URL */
|
||||
import sys.process._
|
||||
import java.net.URL
|
||||
import java.io.File
|
||||
|
||||
/*To check for existence */
|
||||
import java.nio.file.{Paths, Files}
|
||||
|
||||
/* For SedonaContext */
|
||||
import scala.collection.mutable.PriorityQueue
|
||||
import scala.io.Source
|
||||
import java.util.Base64
|
||||
import org.apache.sedona.spark.SedonaContext
|
||||
import org.apache.logging.log4j.core.tools.picocli.CommandLine.Help.Column
|
||||
import org.mimirdb.pip.distribution._
|
||||
import org.mimirdb.pip.lib._
|
||||
import org.mimirdb.pip.udt._
|
||||
import org.mimirdb.pip.udf._
|
||||
import collection.JavaConverters._
|
||||
import org.apache.spark.ml.classification.{ DecisionTreeClassifier, DecisionTreeClassificationModel }
|
||||
import org.apache.spark.ml.{ Pipeline, PipelineModel }
|
||||
import java.io.File
|
||||
import org.apache.spark.ml.haxx.ExtractDecisionTree
|
||||
|
||||
object Main
|
||||
{
|
||||
|
||||
/*I/O source: https://stackoverflow.com/questions/24162478/how-to-download-and-save-a-file-from-the-internet-using-scala#26422540 */
|
||||
def fileDownload(url: String, filename: String): Unit = {
|
||||
new URL(url) #> new File(filename) !!
|
||||
type Row = (UnivariateDistribution, String, Double, Double, Long)
|
||||
|
||||
val BUCKETS = 50
|
||||
val SPEED_MIN = 0
|
||||
val SPEED_MAX = 60
|
||||
val MODEL_FILE = new File("decision_tree.ml")
|
||||
|
||||
// There's a clear knee at 0.001
|
||||
// 195 0.001203387770256109
|
||||
// 196 2.7478019859472624E-15
|
||||
val THRESHOLD = 0.0001
|
||||
|
||||
def loadData(): Seq[Row] =
|
||||
{
|
||||
val base64 = Base64.getDecoder()
|
||||
Source.fromFile("test_data/locations.csv")
|
||||
.getLines
|
||||
.drop(1)
|
||||
.map { line =>
|
||||
val fields = line.split(",")
|
||||
val scene = fields(1)
|
||||
val x = fields(2).toDouble
|
||||
val y = fields(3).toDouble
|
||||
val timestamp = fields(4).toLong
|
||||
val dist = base64.decode(fields(5))
|
||||
(
|
||||
UnivariateDistribution.decode(dist),
|
||||
scene,
|
||||
x,
|
||||
y,
|
||||
timestamp,
|
||||
)
|
||||
}
|
||||
.toSeq
|
||||
}
|
||||
|
||||
def buildHierarchicalClusters(data: Seq[Row]): HierarchicalClustering.Cluster[Double, Row] =
|
||||
{
|
||||
val min = data.map { d => pip_min(d._1) }.min
|
||||
val max = data.map { d => pip_min(d._1) }.max
|
||||
|
||||
assert(min >= SPEED_MIN)
|
||||
assert(max <= SPEED_MAX)
|
||||
|
||||
val bucketStep = (SPEED_MAX - SPEED_MIN) / (BUCKETS)
|
||||
val bucketBounds =
|
||||
(0 until BUCKETS).map { b => (b * bucketStep, (b+1) * bucketStep) }
|
||||
|
||||
val coordinates =
|
||||
data.map { d =>
|
||||
val dist = d._1
|
||||
(
|
||||
bucketBounds.map { case (low, high) =>
|
||||
pip_p_between(low, dist, high)
|
||||
}.toArray,
|
||||
d
|
||||
)
|
||||
}
|
||||
.toIndexedSeq
|
||||
|
||||
HierarchicalClustering.bottomUp(coordinates,
|
||||
new Distance.Manhattan{
|
||||
val min = 0.0
|
||||
val max = 1.0
|
||||
val dimensions = BUCKETS
|
||||
})
|
||||
}
|
||||
|
||||
def exploreHistogramClusters(data: Seq[Row]): Unit =
|
||||
{
|
||||
val queue =
|
||||
PriorityQueue()(new Ordering[HierarchicalClustering.Group[Double, Row]] {
|
||||
def compare(a: HierarchicalClustering.Group[Double, Row],
|
||||
b: HierarchicalClustering.Group[Double, Row]): Int =
|
||||
Ordering[Double].compare(a.radius, b.radius)
|
||||
})
|
||||
|
||||
def enqueue(c: HierarchicalClustering.Cluster[Double, Row]): Unit =
|
||||
c match {
|
||||
case g:HierarchicalClustering.Group[Double, Row] if g.radius > 0.0 => queue.enqueue(g)
|
||||
case _ => ()
|
||||
}
|
||||
|
||||
enqueue(buildHierarchicalClusters(data))
|
||||
|
||||
var count = 0
|
||||
var lastRadius = Double.PositiveInfinity
|
||||
while(!queue.isEmpty){
|
||||
val group = queue.dequeue()
|
||||
count += 1
|
||||
if(group.radius < lastRadius){
|
||||
println(s"$count\t\t\t${group.radius}")
|
||||
lastRadius = group.radius
|
||||
}
|
||||
enqueue(group.left)
|
||||
enqueue(group.right)
|
||||
}
|
||||
}
|
||||
|
||||
def clusterHistograms(data: Seq[Row]): Seq[HierarchicalClustering.Cluster[Double, Row]] =
|
||||
{
|
||||
buildHierarchicalClusters(data)
|
||||
.threshold(THRESHOLD)
|
||||
.toSeq
|
||||
}
|
||||
|
||||
def buildDecisionTree(
|
||||
spark: SparkSession,
|
||||
clusterIds: Seq[HierarchicalClustering.Cluster[Double, Row]]
|
||||
): PipelineModel =
|
||||
{
|
||||
val clusters =
|
||||
spark.createDataFrame(
|
||||
clusterIds.zipWithIndex
|
||||
.flatMap { case (cluster, idx) =>
|
||||
cluster.elements.map { singleton =>
|
||||
Row( idx.toString,
|
||||
singleton.value._1,
|
||||
singleton.value._2,
|
||||
singleton.value._3,
|
||||
singleton.value._4,
|
||||
singleton.value._5,
|
||||
)
|
||||
}
|
||||
}.toList.asJava,
|
||||
StructType(Seq(
|
||||
StructField("dist_cluster", StringType),
|
||||
StructField("dist", UnivariateDistributionType),
|
||||
StructField("scene", StringType),
|
||||
StructField("x", DoubleType),
|
||||
StructField("y", DoubleType),
|
||||
StructField("ts", LongType),
|
||||
))
|
||||
)
|
||||
|
||||
val convertScenesToIndexes =
|
||||
new StringIndexer()
|
||||
.setInputCol("scene")
|
||||
.setOutputCol("scene_index")
|
||||
.fit(clusters)
|
||||
|
||||
val convertDistsToIndexes =
|
||||
new StringIndexer()
|
||||
.setInputCol("dist_cluster")
|
||||
.setOutputCol("dist_cluster_index")
|
||||
.fit(clusters)
|
||||
|
||||
val assembleFeatureVector =
|
||||
new VectorAssembler()
|
||||
.setInputCols(Array(
|
||||
// "scene_index",
|
||||
"x",
|
||||
"y"
|
||||
))
|
||||
.setOutputCol("features")
|
||||
val filterQuery =
|
||||
new SQLTransformer()
|
||||
.setStatement("SELECT * FROM __THIS__ WHERE scene_index = 0.0")
|
||||
|
||||
val dt = new DecisionTreeClassifier()
|
||||
.setLabelCol("dist_cluster_index")
|
||||
.setFeaturesCol("features")
|
||||
.setMaxDepth(20)
|
||||
|
||||
val pipeline = new Pipeline()
|
||||
.setStages(Array(
|
||||
convertScenesToIndexes,
|
||||
convertDistsToIndexes,
|
||||
assembleFeatureVector,
|
||||
filterQuery,
|
||||
dt
|
||||
))
|
||||
|
||||
pipeline.fit(clusters)
|
||||
}
|
||||
|
||||
def main(args: Array[String]): Unit =
|
||||
{
|
||||
println("Initialize Spark...")
|
||||
val spark = SparkSession.builder
|
||||
.appName("pip")
|
||||
.master("local[*]")
|
||||
|
@ -36,160 +211,61 @@ object Main
|
|||
spark.sparkContext.setLogLevel("WARN")
|
||||
spark.sparkContext.setCheckpointDir("spark-warehouse")
|
||||
|
||||
println("Initialize Pip...")
|
||||
Pip.init(spark)
|
||||
|
||||
/*
|
||||
Reproducing Vizier mars_rover workflow
|
||||
NOTE:
|
||||
this will probably be migrated into its own files later down the road
|
||||
*/
|
||||
println("Load speed distribution data...")
|
||||
val data = loadData()
|
||||
|
||||
import org.apache.spark.sql.DataFrameReader
|
||||
/* It appears hadoop doesn't like urls, so need to
|
||||
i) download the file locally
|
||||
ii) then read it in
|
||||
// Overall goal:
|
||||
// NuScene -> Scene / Timestamp / Speed Dist
|
||||
// Scene
|
||||
val model =
|
||||
if(MODEL_FILE.exists()){
|
||||
PipelineModel.load(MODEL_FILE.toString)
|
||||
} else {
|
||||
|
||||
*/
|
||||
val webData = "https://mars.nasa.gov/mmgis-maps/M20/Layers/json/M20_waypoints.json"
|
||||
val fileData: String = "marsRoverData.json"
|
||||
println("Construct and cluster distribution histograms...")
|
||||
val clusterIds = clusterHistograms(data)
|
||||
|
||||
/* We don't need to download if we already have the file.
|
||||
Source: http://stackoverflow.com/questions/21177107/ddg#21178667
|
||||
*/
|
||||
if (!Files.exists(Paths.get(fileData))) {
|
||||
println("We didn't find the file. Now downloading...")
|
||||
fileDownload(webData, fileData)
|
||||
println("Build Decision Tree...")
|
||||
val model = buildDecisionTree(spark, clusterIds)
|
||||
|
||||
model.write.save(MODEL_FILE.toString)
|
||||
|
||||
model
|
||||
}
|
||||
|
||||
val dt =
|
||||
model.stages.last
|
||||
.asInstanceOf[DecisionTreeClassificationModel]
|
||||
|
||||
val minX = data.map { _._3 }.min
|
||||
val maxX = data.map { _._3 }.max
|
||||
val minY = data.map { _._4 }.min
|
||||
val maxY = data.map { _._4 }.max
|
||||
|
||||
ExtractDecisionTree.print(dt)
|
||||
|
||||
val features =
|
||||
Seq(
|
||||
new DistSummary.ContinuousFeature(minX, maxX),
|
||||
new DistSummary.ContinuousFeature(minY, maxY),
|
||||
)
|
||||
|
||||
val regions =
|
||||
ExtractDecisionTree(dt, features)
|
||||
|
||||
for(row <- data)
|
||||
{
|
||||
val features =
|
||||
regions.insert( Array(row._3, row._4), row._1 )
|
||||
}
|
||||
|
||||
println("We did find the file, now reading.")
|
||||
assert(SedonaContext.create(spark) eq spark)
|
||||
var df = spark.read.option("multiLine", true).json(fileData)
|
||||
|
||||
/* Create temporary Spark view to query */
|
||||
df.createOrReplaceTempView("trips")
|
||||
|
||||
////////////////////////////////////////////////////////
|
||||
// Extract GeoJSON and properties field from the data
|
||||
////////////////////////////////////////////////////////
|
||||
df = spark.sql("""
|
||||
SELECT features.type,
|
||||
features.properties.*,
|
||||
ST_GeomFromGeoJSON(to_json(features.geometry)) as geo
|
||||
FROM (
|
||||
SELECT explode(features) AS features FROM trips
|
||||
)
|
||||
""").coalesce(1)
|
||||
// sqlDF.printSchema()
|
||||
// df.show(false)
|
||||
df.createOrReplaceTempView("traverse_data")
|
||||
|
||||
////////////////////////////////////////////////////////
|
||||
// Trip Times
|
||||
////////////////////////////////////////////////////////
|
||||
df = spark.sql("""
|
||||
SELECT *,
|
||||
dist_km - lag(dist_km, 1, 0) OVER (PARTITION BY 1 ORDER BY sol) AS km_traveled,
|
||||
sol - lag(sol, 1, 0) OVER (PARTITION BY 1 ORDER BY sol) AS sols_traveled
|
||||
FROM traverse_data
|
||||
WHERE dist_km > 0
|
||||
""")
|
||||
// df.show(false)
|
||||
df.createOrReplaceTempView("traverse_data")
|
||||
// spark.sql("""
|
||||
// SELECT max(km_traveled * 1000 / sols_traveled) as m_per_sol
|
||||
// FROM traverse_data
|
||||
// WHERE sols_traveled > 0
|
||||
// """).show()
|
||||
// return
|
||||
|
||||
////////////////////////////////////////////////////////
|
||||
// Trip Distances
|
||||
////////////////////////////////////////////////////////
|
||||
df = spark.sql("""
|
||||
SELECT ST_Point(
|
||||
CAST(lon as decimal(24,20)),
|
||||
CAST(lat as decimal(24,20))
|
||||
) as geometry,
|
||||
clamp(gaussian(cast(km_traveled as double) * 1000 / sols_traveled, 3.0), 0.0, 800) as m_per_sol
|
||||
FROM traverse_data
|
||||
WHERE sols_traveled > 0
|
||||
""")//.checkpoint()
|
||||
// tripDist.printSchema()
|
||||
// df.show(false)
|
||||
df.createOrReplaceTempView("trip_points")
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////
|
||||
// Bounding Box
|
||||
////////////////////////////////////////////////////////
|
||||
df = spark.sql("""
|
||||
SELECT min(lat) as min_lat,
|
||||
max(lat) as max_lat,
|
||||
min(lon) as min_lon,
|
||||
max(lon) as max_lon
|
||||
FROM traverse_data
|
||||
""")
|
||||
// df.show(false)
|
||||
df.createOrReplaceTempView("mission_region")
|
||||
|
||||
////////////////////////////////////////////////////////
|
||||
// Example Histogram Regions
|
||||
////////////////////////////////////////////////////////
|
||||
df = spark.sql("""
|
||||
SELECT id, ST_PolygonFromEnvelope(lon_low, lat_low, lon_high, lat_high) as geometry
|
||||
FROM (
|
||||
SELECT
|
||||
10 * lon_idx + lat_idx AS id,
|
||||
(max_lat - min_lat)/10 * lat_idx + min_lat AS lat_low,
|
||||
(max_lat - min_lat)/10 * (lat_idx+1) + min_lat AS lat_high,
|
||||
(max_lon - min_lon)/10 * lon_idx + min_lon AS lon_low,
|
||||
(max_lon - min_lon)/10 * (lon_idx+1) + min_lon AS lon_high
|
||||
FROM (SELECT id AS lon_idx from range(0,10)) AS lon_split,
|
||||
(SELECT id AS lat_idx from range(0,10)) AS lat_split,
|
||||
mission_region
|
||||
)
|
||||
""")
|
||||
// df.show(false)
|
||||
df.createOrReplaceTempView("bounding_boxes")
|
||||
|
||||
////////////////////////////////////////////////////////
|
||||
// Per-region distributions
|
||||
////////////////////////////////////////////////////////
|
||||
df = spark.sql("""
|
||||
SELECT box.id,
|
||||
array_agg(m_per_sol) as components,
|
||||
discretize(
|
||||
uniform_mixture(m_per_sol),
|
||||
array(0.0, 40.0, 80.0, 120.0, 160.0, 200.0, 240.0, 280.0, 320.0, 360.0, 400.0),
|
||||
1000
|
||||
) as m_per_sol
|
||||
FROM trip_points point,
|
||||
bounding_boxes box
|
||||
WHERE ST_Contains(box.geometry, point.geometry)
|
||||
GROUP BY box.id
|
||||
""")
|
||||
// df.show(false)
|
||||
df.createOrReplaceTempView("grid_squares")
|
||||
|
||||
////////////////////////////////////////////////////////
|
||||
// Per-region metrics
|
||||
////////////////////////////////////////////////////////
|
||||
df = spark.sql("""
|
||||
SELECT id,
|
||||
m_per_sol,
|
||||
entropy(m_per_sol) as entropy,
|
||||
array_max(
|
||||
transform(
|
||||
components,
|
||||
x -> kl_divergence(x, m_per_sol)
|
||||
)
|
||||
) as max_kl_div
|
||||
-- components
|
||||
FROM grid_squares
|
||||
-- LIMIT 1
|
||||
""")
|
||||
df.show(false)
|
||||
// println(cluster)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,78 @@
|
|||
package org.mimirdb.pip.lib
|
||||
import org.mimirdb.pip.distribution.numerical.Discretized
|
||||
|
||||
import scala.util.Random
|
||||
|
||||
object TestData
|
||||
{
|
||||
|
||||
val DIMENSIONS = 10
|
||||
val BIN_SIZE = 1.0 / DIMENSIONS
|
||||
|
||||
def positionToBins(position: Array[Double]): Seq[Discretized.Bin] =
|
||||
{
|
||||
position.zipWithIndex.map { case (b, i) =>
|
||||
Discretized.Bin(i * BIN_SIZE, (i+1) * BIN_SIZE, b)
|
||||
}.toSeq
|
||||
}
|
||||
|
||||
trait BaseDistance extends Distance[Double]
|
||||
{
|
||||
val min = 0.0
|
||||
val max = 1.0
|
||||
def pointToPlane(a: Array[Double], b: Double, dim: Int): Double =
|
||||
{
|
||||
Math.abs(a(dim) - b)
|
||||
}
|
||||
def centroid(a: Iterable[Array[Double]]): Array[Double] =
|
||||
{
|
||||
val ret = Array.ofDim[Double](DIMENSIONS)
|
||||
for(pt <- a)
|
||||
{
|
||||
for(i <- 0 until DIMENSIONS)
|
||||
{
|
||||
ret(i) += pt(i)
|
||||
}
|
||||
}
|
||||
for(i <- 0 until DIMENSIONS)
|
||||
{
|
||||
ret(i) /= a.size
|
||||
}
|
||||
return ret
|
||||
}
|
||||
}
|
||||
object ManhattanDistance extends BaseDistance
|
||||
{
|
||||
def pointToPoint(a: Array[Double], b: Array[Double]): Double =
|
||||
{
|
||||
var tot = 0.0
|
||||
for(i <- 0 until a.size)
|
||||
{
|
||||
tot += Math.abs(a(i) - b(i))
|
||||
}
|
||||
return tot
|
||||
}
|
||||
}
|
||||
object EuclideanDistance extends BaseDistance
|
||||
{
|
||||
def pointToPoint(a: Array[Double], b: Array[Double]): Double =
|
||||
{
|
||||
var tot = 0.0
|
||||
for(i <- 0 until a.size)
|
||||
{
|
||||
val v = Math.abs(a(i) - b(i))
|
||||
tot += v * v
|
||||
}
|
||||
Math.sqrt(tot)
|
||||
}
|
||||
}
|
||||
|
||||
def makeData(size: Int): IndexedSeq[Array[Double]] =
|
||||
{
|
||||
(0 until size).map { _ =>
|
||||
(0 until DIMENSIONS).map { _ =>
|
||||
Random.nextDouble
|
||||
}.toArray
|
||||
}.toArray.toIndexedSeq
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue