73 lines
1.9 KiB
Scala
73 lines
1.9 KiB
Scala
package org.mimirdb.pip.lib
|
|
|
|
import org.scalatest.flatspec.AnyFlatSpec
|
|
import TestData._
|
|
|
|
/**
 * Tests for [[KDTree]].
 *
 * Builds a tree over `TEST_POINTS`, storing each point with its index in
 * `TEST_POINTS` as its value. It then checks `nearest` against a brute-force
 * scan using `Measure.pointToPoint`, both on the full tree and after all but the
 * first 20 points have been removed.
 */
class KDTreeTests extends AnyFlatSpec {

  "The KD Tree" should "be correct" in {

    val tree = new KDTree[Double, Int](DIMENSIONS)

    // The value stored with each point is its position in TEST_POINTS. That lets
    // `nearest(...)._2` be compared directly with a brute-force index.
    tree.insertAll(TEST_POINTS.zipWithIndex)

    // Roughly balanced. NOTE(review): the bound of 10 assumes TEST_POINTS holds
    // on the order of 100 points; confirm against TestData if the fixture changes.
    assert(tree.depth < 10)

    // The nearest point to a point that is in the tree should be itself.
    assert(tree.nearest(TEST_POINTS(0), Measure)._2 == 0)

    // For every query in TEST_QUERIES, assert that the tree's nearest neighbour is
    // the point a brute-force scan over the first `liveCount` entries of
    // TEST_POINTS picks. Scanning a prefix is enough because the points left in
    // the tree are always a prefix of TEST_POINTS (see the removal step below).
    // `phase` only labels the failure message.
    def assertNearestMatchesBruteForce(liveCount: Int, phase: String): Unit = {
      val live = TEST_POINTS.take(liveCount)
      for ((q, qid) <- TEST_QUERIES.zipWithIndex) {
        // Distance from the query to every live point, then the index of the
        // smallest one (minBy keeps the first minimum on ties).
        val expectedId =
          live.map { Measure.pointToPoint(_, q) }.zipWithIndex.minBy(_._1)._2

        val actualId = tree.nearest(q, Measure)._2

        assert(actualId == expectedId,
          s"Query $qid not closest ($phase): tree returned $actualId, brute force found $expectedId"
        )
      }
    }

    // Nearest points to random queries should be the actual minimum.
    assertNearestMatchesBruteForce(TEST_POINTS.size, "full tree")

    // Remove every point except the first `amountToKeep`.
    val amountToKeep = 20
    for (point <- TEST_POINTS.drop(amountToKeep)) {
      tree.remove(point)
    }

    // After removal, answers must come only from the surviving prefix.
    assertNearestMatchesBruteForce(amountToKeep, "after removal")
  }
}