mimir-pip/lib/src/org/mimirdb/pip/lib/KDTree.scala

251 lines
6.5 KiB
Scala

package org.mimirdb.pip.lib
import scala.reflect.ClassTag
/**
 * A mutable k-d tree mapping `dimensions`-dimensional keys to values.
 *
 * Keys are arrays of `I` whose coordinates are compared with
 * `Ordering[I].compare`. Every [[Inner]] node splits on one dimension: keys
 * whose coordinate on that dimension is `< split` live in its left subtree,
 * keys whose coordinate is `>= split` live in its right subtree. Inserting a
 * key equal (on every coordinate) to one already stored overwrites its value.
 *
 * @param dimensions the number of coordinates in every key
 * @tparam I the coordinate type
 * @tparam V the value type
 */
class KDTree[I: Ordering, V](dimensions: Int)(implicit iTag: ClassTag[I])
{
  type Key = Array[I]

  /** The root of the tree, or None if the tree is empty. */
  var root: Option[Node] = None

  /**
   * Insert `value` at `position`, replacing the value already stored at an
   * identical position.
   */
  def insert(position: Key, value: V): Unit =
  {
    assert(position.size == dimensions)
    root = Some(root match {
      case Some(node) => node.insert(position, value)
      case None => Leaf(position, value, 0)
    })
  }

  /**
   * Insert a batch of elements. An empty tree is built directly from the
   * batch; otherwise elements are routed down the tree and every leaf they
   * reach is rebuilt together with them. When the same position occurs more
   * than once, the last occurrence wins (matching `insert`).
   */
  def insertAll(elements: Iterable[(Key, V)]): Unit =
  {
    // Materialize once so the input is not traversed repeatedly.
    val batch = elements.toSeq
    batch.foreach { case (key, _) => assert(key.size == dimensions) }
    if(batch.nonEmpty){
      root = Some(root match {
        case Some(node) => node.insertAll(batch)
        case None => fragmentSeq(batch, 0)
      })
    }
  }

  /**
   * Remove the element stored at `position`.
   *
   * @return the removed value, or None if no element is stored there
   */
  def remove(position: Key): Option[V] =
    root match {
      case None => None
      case Some(node) =>
        val (newRoot, removed) = node.remove(position)
        root = newRoot
        removed
    }

  /**
   * Find the stored element closest to `position` under `measure`.
   *
   * @return the position and value of the closest element
   * @throws NoSuchElementException if the tree is empty
   */
  def nearest(position: Key, measure: Distance[I]): (Key, V) =
  {
    val node = root.getOrElse {
      throw new NoSuchElementException("nearest() called on an empty KDTree")
    }
    val (found, _) = node.nearest(position, measure, None)
    (found.position, found.value)
  }

  /** The number of nodes on the longest root-to-leaf path (0 if empty). */
  def depth(): Int =
    root.map { _.depth() }.getOrElse(0)

  override def toString(): String =
    root.map { _.toString("", "") }.getOrElse("[Empty Tree]")

  /**
   * Build a subtree holding `elements` (which must be non-empty), preferring
   * to split on `dim`.
   *
   * The split is placed near the median coordinate, then moved so that every
   * key with coordinate `< split` goes left and every key with coordinate
   * `>= split` goes right -- the same rule `Inner` uses to route lookups.
   * If all keys tie on `dim`, the next dimension on which they differ is used
   * instead. If all keys are identical, one leaf holding the last entry's
   * value is returned.
   */
  private def fragmentSeq(elements: Seq[(Key, V)], dim: Int): Node =
  {
    assert(elements.nonEmpty)
    if(elements.size == 1){
      val (key, value) = elements.head
      Leaf(key, value, dim)
    } else {
      val firstKey = elements.head._1
      val splitDim = cyclicDims(dim).find { d =>
        elements.exists { elem => Ordering[I].compare(elem._1(d), firstKey(d)) != 0 }
      }
      splitDim match {
        case None =>
          // Every entry has the same position: later entries overwrite
          // earlier ones, as `insert` does. (sortBy is stable, so input
          // order among identical keys is preserved by the recursion.)
          val (key, value) = elements.last
          Leaf(key, value, dim)
        case Some(d) =>
          val sorted = elements.sortBy { elem => elem._1(d) }
          val idx = splitIndex(sorted, d)
          Inner(
            split = sorted(idx)._1(d),
            dim = d,
            left = fragmentSeq(sorted.take(idx), nextDim(d)),
            right = fragmentSeq(sorted.drop(idx), nextDim(d))
          )
      }
    }
  }

  /**
   * Given `sorted` (ordered by coordinate `d` and holding at least two
   * distinct coordinates on `d`), return an index `i` with `0 < i < size`
   * such that every element before `i` has a coordinate on `d` strictly less
   * than `sorted(i)`'s. This keeps ties with the split value on the right.
   */
  private def splitIndex(sorted: Seq[(Key, V)], d: Int): Int =
  {
    val mid = (sorted.size + 1) / 2 // +1 to round up.
    val pivot = sorted(mid)._1(d)
    // Slide left to the first element sharing the median's coordinate...
    val firstAtPivot =
      sorted.indexWhere { elem => Ordering[I].compare(elem._1(d), pivot) >= 0 }
    if(firstAtPivot > 0){
      firstAtPivot
    } else {
      // ...unless the median is also the minimum; then slide right past it.
      sorted.indexWhere { elem => Ordering[I].compare(elem._1(d), pivot) > 0 }
    }
  }

  /** All dimensions in cyclic order, beginning with `start`. */
  private def cyclicDims(start: Int): Iterator[Int] =
    Iterator.iterate(start) { d => nextDim(d) }.take(dimensions)

  /** True if `a` and `b` have the same length and compare equal on every coordinate. */
  private def sameKey(a: Key, b: Key): Boolean =
    a.length == b.length &&
      (0 until a.length).forall { d => Ordering[I].compare(a(d), b(d)) == 0 }

  @inline def nextDim(dim: Int): Int =
    (dim+1)%dimensions

  sealed trait Node
  {
    /** Insert (or overwrite) an element; returns the node replacing this one. */
    def insert(position: Key, value: V): Node
    /** Insert a batch of elements; returns the node replacing this one. */
    def insertAll(elements: Seq[(Key, V)]): Node
    /** Remove `position`; returns the replacement node (None if the subtree is now empty) and the removed value. */
    def remove(position: Key): (Option[Node], Option[V])
    /** Return whichever of `best` and this subtree's closest leaf is nearer to `position`, with its distance. */
    def nearest(position: Key, measure: Distance[I], best: Option[(Leaf, Double)]): (Leaf, Double)
    def depth(): Int
    def toString(firstPrefix: String, restPrefix: String): String
  }

  /**
   * An interior node splitting on dimension `dim`: keys with
   * `key(dim) < split` live under `left`, keys with `key(dim) >= split`
   * live under `right`.
   */
  case class Inner(
    split: I,
    dim: Int,
    var left: Node,
    var right: Node
  ) extends Node
  {
    /** True if `position` belongs in the right subtree. */
    private def goesRight(position: Key): Boolean =
      Ordering[I].compare(position(dim), split) >= 0

    def insert(position: Key, value: V): Node =
    {
      // Children may restructure (a Leaf becomes an Inner when it splits),
      // so the returned node must be stored back.
      if(goesRight(position)){
        right = right.insert(position, value)
      } else {
        left = left.insert(position, value)
      }
      this
    }

    def insertAll(elements: Seq[(Key, V)]): Node =
      if(elements.size == 1){
        val (key, value) = elements.head
        insert(key, value)
      } else {
        val (forRight, forLeft) =
          elements.partition { case (key, _) => goesRight(key) }
        if(forLeft.nonEmpty){
          left = left.insertAll(forLeft)
        }
        if(forRight.nonEmpty){
          right = right.insertAll(forRight)
        }
        this
      }

    def remove(position: Key): (Option[Node], Option[V]) =
      if(goesRight(position)){
        right.remove(position) match {
          case (Some(newRight), removed) =>
            right = newRight
            (Some(this), removed)
          case (None, removed) =>
            // The right subtree is gone; the left subtree takes this node's place.
            (Some(left), removed)
        }
      } else {
        left.remove(position) match {
          case (Some(newLeft), removed) =>
            left = newLeft
            (Some(this), removed)
          case (None, removed) =>
            // The left subtree is gone; the right subtree takes this node's place.
            (Some(right), removed)
        }
      }

    def nearest(position: Key, measure: Distance[I], best: Option[(Leaf, Double)]): (Leaf, Double) =
    {
      // Search the side of the split that contains `position` first.
      val (near, far) =
        if(goesRight(position)) (right, left) else (left, right)
      val candidate = near.nearest(position, measure, best)
      // Visit the far side only when the best match so far is farther away
      // than the splitting plane. (Presumably pointToPlane lower-bounds the
      // distance to any point across the split -- confirm against Distance.)
      if(candidate._2 > measure.pointToPlane(position, split, dim)){
        far.nearest(position, measure, Some(candidate))
      } else {
        candidate
      }
    }

    def depth(): Int =
      Math.max(left.depth(), right.depth())+1

    def toString(firstPrefix: String, restPrefix: String): String =
      firstPrefix + s" Dim[$dim] < $split\n" +
        left.toString(restPrefix+" +-", restPrefix+" | ") + "\n" +
        right.toString(restPrefix+" +-", restPrefix+" ")
  }

  /**
   * A single stored element. `dim` is the dimension this leaf tries first
   * when it must split to make room for another key.
   */
  case class Leaf(
    position: Key,
    var value: V,
    dim: Int
  ) extends Node
  {
    def insert(otherPosition: Key, otherValue: V): Node =
    {
      // Split on the first dimension (starting at `dim`) where the keys
      // differ; a dimension where they tie cannot separate them.
      cyclicDims(dim).find { d =>
        Ordering[I].compare(position(d), otherPosition(d)) != 0
      } match {
        case None =>
          // Same position on every coordinate: overwrite in place.
          value = otherValue
          this
        case Some(d) =>
          val incoming = Leaf(otherPosition, otherValue, nextDim(d))
          val existing = copy(dim = nextDim(d))
          if(Ordering[I].compare(position(d), otherPosition(d)) > 0){
            Inner(split = position(d), dim = d, left = incoming, right = existing)
          } else {
            Inner(split = otherPosition(d), dim = d, left = existing, right = incoming)
          }
      }
    }

    def insertAll(elements: Seq[(Key, V)]): Node =
      // This leaf's entry goes first so that incoming duplicates overwrite it.
      fragmentSeq((position, value) +: elements, dim)

    def remove(position: Key): (Option[Node], Option[V]) =
      // Arrays use reference equality under `==`; compare coordinates instead.
      if(sameKey(position, this.position)){
        (None, Some(value))
      } else {
        (Some(this), None)
      }

    def nearest(position: Key, measure: Distance[I], best: Option[(Leaf, Double)]): (Leaf, Double) =
    {
      val distance = measure.pointToPoint(position, this.position)
      best match {
        case Some( best ) if best._2 <= distance => best
        case _ => (this, distance)
      }
    }

    def depth(): Int = 1

    def toString(firstPrefix: String, restPrefix: String): String =
      firstPrefix + s" <${position.mkString(", ")}> -> $value"
  }
}