259 lines
6.7 KiB
Scala
259 lines
6.7 KiB
Scala
package org.mimirdb.pip.lib
|
|
import scala.reflect.ClassTag
|
|
|
|
/**
 * Distance measure consumed by the k-d tree's nearest-neighbour search.
 *
 * @tparam I the coordinate type of a single dimension
 */
trait KDDistance[I] {

  // NOTE(review): `min` and `max` are not referenced by the tree code in this
  // file; presumably they bound the coordinate domain -- confirm with callers.
  val min: I
  val max: I

  /** Distance between the two points `a` and `b`. */
  def pointToPoint(a: Array[I], b: Array[I]): Double

  /**
   * Distance from point `a` to the plane located at coordinate `b` along
   * dimension `dim` (the tree passes a node's split value as `b`).
   */
  def pointToPlane(a: Array[I], b: I, dim: Int): Double
}
|
|
|
|
/**
 * A mutable k-d tree mapping fixed-length coordinate arrays to values.
 *
 * Structural invariant: for every [[Inner]] node, each key stored under
 * `left` has `key(dim) < split` and each key stored under `right` has
 * `key(dim) >= split` (both according to the implicit `Ordering[I]`).
 * Insertion, removal and nearest-neighbour search all navigate by this rule.
 *
 * Keys are compared coordinate-by-coordinate with the `Ordering[I]`; array
 * identity is never used.  Storing a key that is already present replaces
 * the value held for it.
 *
 * @param dimensions number of coordinates every key must have
 * @tparam I coordinate type
 * @tparam V value type
 */
class KDTree[I: Ordering, V](dimensions: Int)(implicit iTag: ClassTag[I]) {

  type Key = Array[I]

  /** Root of the tree; `None` while the tree is empty. */
  var root: Option[Node] = None

  /** The coordinate ordering, resolved once for use by every node. */
  private val ord: Ordering[I] = implicitly[Ordering[I]]

  /** True when `a` and `b` have equal coordinates in every dimension. */
  private def sameKey(a: Key, b: Key): Boolean =
    a.length == b.length && a.indices.forall { i => ord.equiv(a(i), b(i)) }

  /**
   * Stores `value` at `position`, replacing any value already stored at an
   * equal position.
   */
  def insert(position: Key, value: V): Unit = {
    assert(position.size == dimensions)
    root = Some(root match {
      case Some(node) => node.insert(position, value)
      case None       => Leaf(position, value, 0)
    })
  }

  /**
   * Stores every `(position, value)` pair of `elements`.  When the same
   * position occurs more than once, the pair that comes last wins.
   */
  def insertAll(elements: Iterable[(Key, V)]): Unit = {
    elements.foreach { el => assert(el._1.size == dimensions) }
    if (elements.nonEmpty) {
      val batch = elements.toSeq
      root = Some(root match {
        case Some(node) => node.insertAll(batch)
        case None       => fragmentSeq(batch, 0)
      })
    }
  }

  /**
   * Removes the entry stored at `position`.
   *
   * @return the removed value, or `None` if no entry was found there
   */
  def remove(position: Key): Option[V] =
    root match {
      case None => None
      case Some(node) =>
        val (newRoot, removed) = node.remove(position)
        root = newRoot
        removed
    }

  /**
   * Finds the stored entry closest to `position` according to `measure`.
   *
   * @throws NoSuchElementException if the tree is empty
   */
  def nearest(position: Key, measure: KDDistance[I]): (Key, V) = {
    val start = root.getOrElse {
      throw new NoSuchElementException("nearest() called on an empty KDTree")
    }
    val (found, _) = start.nearest(position, measure, None)
    (found.position, found.value)
  }

  /** Number of nodes on the longest root-to-leaf path; 0 for an empty tree. */
  def depth(): Int =
    root.map { _.depth() }.getOrElse(0)

  override def toString(): String =
    root.map { _.toString("", "") }.getOrElse("[Empty Tree]")

  /**
   * Builds a subtree holding `elements`, preferring to split on `dim`.
   *
   * The split value is the median coordinate, and the partition is made by
   * value rather than by index so that equal coordinates always end up on
   * the `>= split` side.  If every element shares the same coordinate in a
   * dimension, the following dimensions are tried in turn.  Elements whose
   * keys are equal in all dimensions collapse into a single leaf carrying
   * the value of the last such element.
   */
  private def fragmentSeq(elements: Seq[(Key, V)], dim: Int): Node = {
    assert(elements.nonEmpty)
    if (elements.size == 1) {
      val (position, value) = elements.head
      Leaf(position, value, dim)
    } else {
      val reference = elements.head._1
      // First dimension, starting at `dim`, in which the keys actually differ.
      val splitDim =
        Iterator.iterate(dim)(d => nextDim(d)).take(dimensions).find { d =>
          elements.exists { el => !ord.equiv(el._1(d), reference(d)) }
        }
      splitDim match {
        case None =>
          // All keys are identical: keep only the most recent value.
          val (position, value) = elements.last
          Leaf(position, value, dim)
        case Some(d) =>
          // The sort is stable, so elements with equal keys keep input order.
          val sorted = elements.sortBy { el => el._1(d) }(ord)
          val lowest = sorted.head._1(d)
          val median = sorted((sorted.size + 1) / 2)._1(d)
          // The split must exceed the lowest coordinate, otherwise the left
          // side would be empty.
          val split =
            if (ord.gt(median, lowest)) median
            else sorted.iterator.map(_._1(d)).find(ord.gt(_, lowest)).getOrElse(median)
          val (below, atOrAbove) = sorted.span { el => ord.lt(el._1(d), split) }
          Inner(
            split = split,
            dim = d,
            left = fragmentSeq(below, nextDim(d)),
            right = fragmentSeq(atOrAbove, nextDim(d))
          )
      }
    }
  }

  /** The dimension following `dim`, wrapping around after the last one. */
  @inline def nextDim(dim: Int): Int =
    (dim + 1) % dimensions

  /** Common interface of the tree's nodes. */
  sealed trait Node {
    /** Inserts one entry and returns the node that now roots this subtree. */
    def insert(position: Key, value: V): Node
    /** Inserts many entries and returns the node that now roots this subtree. */
    def insertAll(elements: Seq[(Key, V)]): Node
    /** Removes `position`; returns the remaining subtree (if any) and the removed value (if any). */
    def remove(position: Key): (Option[Node], Option[V])
    /** Returns the closest leaf and its distance, considering `best` as the candidate so far. */
    def nearest(position: Key, measure: KDDistance[I], best: Option[(Leaf, Double)]): (Leaf, Double)
    def depth(): Int
    def toString(firstPrefix: String, restPrefix: String): String
  }

  /**
   * Internal node: keys with `key(dim) < split` live under `left`, all
   * other keys under `right`.
   */
  case class Inner(
    split: I,
    dim: Int,
    var left: Node,
    var right: Node
  ) extends Node {

    def insert(position: Key, value: V): Node = {
      // The child may be replaced (a leaf turns into an inner node), so the
      // returned node has to be stored back.
      if (ord.gteq(position(dim), split)) {
        right = right.insert(position, value)
      } else {
        left = left.insert(position, value)
      }
      this
    }

    def insertAll(elements: Seq[(Key, V)]): Node = {
      val (forLeft, forRight) =
        elements.partition { case (key, _) => ord.lt(key(dim), split) }
      if (forLeft.nonEmpty) { left = left.insertAll(forLeft) }
      if (forRight.nonEmpty) { right = right.insertAll(forRight) }
      this
    }

    def remove(position: Key): (Option[Node], Option[V]) = {
      val goLeft = ord.lt(position(dim), split)
      val (child, sibling) = if (goLeft) (left, right) else (right, left)
      child.remove(position) match {
        case (Some(remaining), removed) =>
          if (goLeft) { left = remaining } else { right = remaining }
          (Some(this), removed)
        case (None, removed) =>
          // The child vanished entirely; the sibling takes this node's place.
          (Some(sibling), removed)
      }
    }

    def nearest(position: Key, measure: KDDistance[I], best: Option[(Leaf, Double)]): (Leaf, Double) = {
      // Search the side of the split that contains `position` first.
      val (near, far) =
        if (ord.gteq(position(dim), split)) (right, left) else (left, right)
      val candidate = near.nearest(position, measure, best)
      val splitDistance = measure.pointToPlane(position, split, dim)
      // The far side is only worth searching when the splitting plane is
      // closer than the best match found so far.
      if (candidate._2 > splitDistance) {
        far.nearest(position, measure, Some(candidate))
      } else {
        candidate
      }
    }

    def depth(): Int =
      Math.max(left.depth(), right.depth()) + 1

    def toString(firstPrefix: String, restPrefix: String): String =
      firstPrefix + s" Dim[$dim] < $split\n" +
        left.toString(restPrefix+" +-", restPrefix+" | ") + "\n" +
        right.toString(restPrefix+" +-", restPrefix+" ")
  }

  /**
   * Leaf node holding one entry.  `dim` is the dimension that is tried
   * first when this leaf has to be split.
   */
  case class Leaf(
    position: Key,
    var value: V,
    dim: Int
  ) extends Node {

    def insert(otherPosition: Key, otherValue: V): Node = {
      if (sameKey(position, otherPosition)) {
        // Same point: overwrite in place.
        value = otherValue
        this
      } else {
        // Different point, even if it agrees with this one in `dim`:
        // split on the first dimension in which the two keys differ.
        fragmentSeq(Seq((position, value), (otherPosition, otherValue)), dim)
      }
    }

    def insertAll(elements: Seq[(Key, V)]): Node =
      // This leaf's entry goes first so that an incoming entry for the same
      // position overrides it.
      fragmentSeq((position, value) +: elements, dim)

    def remove(position: Key): (Option[Node], Option[V]) = {
      // Arrays compare by reference with `==`, so compare coordinates.
      if (sameKey(position, this.position)) {
        (None, Some(value))
      } else {
        (Some(this), None)
      }
    }

    def nearest(position: Key, measure: KDDistance[I], best: Option[(Leaf, Double)]): (Leaf, Double) = {
      val distance = measure.pointToPoint(position, this.position)
      best match {
        case Some(prior) if prior._2 <= distance => prior
        case _                                   => (this, distance)
      }
    }

    def depth(): Int = 1

    def toString(firstPrefix: String, restPrefix: String): String =
      firstPrefix + s" <${position.mkString(", ")}> -> $value"
  }
}
|