79 lines
2.3 KiB
Scala
79 lines
2.3 KiB
Scala
package org.mimirdb.pip.lib
|
|
|
|
import org.scalatest.flatspec.AnyFlatSpec
|
|
import TestData._
|
|
import org.mimirdb.pip.distribution.Discretized
|
|
|
|
/**
 * Structural sanity checks for the cluster tree returned by
 * `HierarchicalClustering.naive`.
 */
class HierarchicalClusteringTests extends AnyFlatSpec {

  /**
   * Recursively asserts that no group in the tree has a radius larger than
   * the bound inherited from its parent.
   *
   * Singletons are accepted without any check. For a group, its radius is
   * compared against `parentRadius`, and that group's own radius then becomes
   * the bound for each of its children.
   *
   * @param cluster      the (sub)tree to verify
   * @param parentRadius upper bound for the radius of `cluster`; callers pass
   *                     `Double.PositiveInfinity` for the root
   */
  def testRadii(cluster: HierarchicalClustering.Cluster[Double, Int], parentRadius: Double): Unit =
    cluster match {
      case _: HierarchicalClustering.Singleton[Double, Int] => ()
      case group: HierarchicalClustering.Group[Double, Int] =>
        assert(group.radius <= parentRadius)
        group.children.foreach { child => testRadii(child, group.radius) }
    }

  "Naive Clustering" should "be correct" in {
    // Build the tree from the shared test points, each tagged with its index.
    val clusters =
      Time("Naive Cluster") {
        HierarchicalClustering.naive(
          TEST_POINTS.toIndexedSeq.zipWithIndex,
          Measure
        )
      }

    // The root is unconstrained, hence the infinite starting bound.
    testRadii(clusters, Double.PositiveInfinity)

    // Radii must never increase along `orderedIterator`; the fold threads the
    // previously seen radius through as the bound for the next cluster.
    clusters.orderedIterator.foldLeft(Double.PositiveInfinity) { (lastRadius, c) =>
      assert(c.radius <= lastRadius)
      c.radius
    }

    // Disabled diagnostics: scatterplot of cluster radius against the maximum
    // KL divergence between each element and the cluster centroid.
    //
    // println(
    //   Scatterplot(clusters.orderedIterator.collect {
    //     case c if c.radius > 0 =>
    //       {
    //         val centroid =
    //           positionToBins(
    //             Measure.centroid(c.elements.map { _.position }.toSeq)
    //           )
    //         val divergence =
    //           c.elements.map { e =>
    //             val bins = positionToBins(e.position)
    //             Discretized.klDivergence(bins, centroid)
    //           }.max
    //         println(s"${c.radius}, $divergence")
    //         (c.radius, divergence)
    //       }
    //     }.toSeq
    //   )
    // )

    // Same diagnostics, restricted to `clusters.threshold(0.3)`.
    //
    // println(
    //   Scatterplot(clusters.threshold(0.3).collect {
    //     case c if c.radius > 0 =>
    //       {
    //         val centroid =
    //           positionToBins(
    //             Measure.centroid(c.elements.map { _.position }.toSeq)
    //           )
    //         val divergence =
    //           c.elements.map { e =>
    //             val bins = positionToBins(e.position)
    //             Discretized.klDivergence(bins, centroid)
    //           }.max
    //         println(s"${c.radius}, $divergence")
    //         (c.radius, divergence)
    //       }
    //     }.toSeq
    //   )
    // )

    // println(clusters)
  }

}