359 lines
10 KiB
Scala
359 lines
10 KiB
Scala
package org.mimirdb.pip.lib
|
|
import scala.reflect.ClassTag
|
|
|
|
class KDTree[I: Ordering, V](dimensions: Int)(implicit iTag: ClassTag[I])
|
|
{
|
|
type Key = Array[I]
|
|
var root: Option[Node] = None
|
|
private var _size = 0
|
|
|
|
/**
 * Store a single value at the given position.
 *
 * When the tree is empty a fresh leaf (splitting on dimension 0) becomes the
 * root; otherwise the insertion is delegated to the current root and the node
 * it hands back becomes the new root.  The element counter is incremented by
 * one on every call.
 *
 * @param position  coordinates of the point; must have exactly `dimensions` entries
 * @param value     the value to associate with `position`
 */
def insert(position: Key, value: V): Unit =
{
  assert(position.length == dimensions)
  val updatedRoot: Node =
    root.fold[Node](Leaf(position, Set(value), 0)) { _.insert(position, value) }
  root = Some(updatedRoot)
  _size += 1
}
|
|
|
|
/**
 * Store a batch of (position, value) pairs.
 *
 * Every position is first checked to have exactly `dimensions` entries.  An
 * empty batch leaves the tree untouched.  Otherwise the batch is handed to the
 * current root, or, if the tree is empty, a new subtree is built from it with
 * `fragmentSeq` starting at dimension 0.  The element counter grows by the
 * number of pairs supplied.
 *
 * @param elements  the (position, value) pairs to add
 */
def insertAll(elements: Iterable[(Key, V)]): Unit =
{
  elements.foreach { case (key, _) => assert(key.length == dimensions) }

  if(elements.nonEmpty){
    val batch = elements.toSeq
    val updatedRoot: Node =
      root match {
        case None           => fragmentSeq(batch, 0)
        case Some(existing) => existing.insertAll(batch)
      }
    root = Some(updatedRoot)
    _size += elements.size
  }
}
|
|
|
|
/**
 * Remove one value stored at the given position.
 *
 * The request is delegated to the root node; whatever node it hands back
 * (possibly none, if the tree became empty) replaces the root.  The element
 * counter is only decremented when a value was actually removed, so a miss
 * leaves `size` unchanged.
 *
 * @param position  coordinates of the point holding the value
 * @param value     the value to remove
 * @return          true if the value was found and removed, false otherwise
 *                  (including when the tree is empty)
 */
def remove(position: Key, value: V): Boolean =
{
  root match {
    case None => false
    case Some(r) =>
    {
      val (newRoot, removed) = r.remove(position, value)
      root = newRoot
      // Only count the removal if something was really taken out.
      if(removed){ _size -= 1 }
      removed
    }
  }
}
|
|
|
|
/** True when the tree currently has no root node. */
def isEmpty: Boolean = root.isEmpty

/** The current value of the tree's element counter. */
def size: Int = _size
|
|
|
|
/**
 * Look up the stored point closest to `position`.
 *
 * The search starts at the root with no best candidate and is carried out by
 * the nodes themselves; the winning leaf, if any, is unpacked into a plain
 * tuple.
 *
 * @param position  the query point
 * @param measure   the distance measure handed to the nodes for the search
 * @param ignore    predicate over (position, values, distance) passed down to
 *                  the nodes; the default never returns true
 * @return          the position, values and distance of the leaf found, or
 *                  None when the tree is empty or the search yields no leaf
 */
def nearest(position: Key, measure: Distance[I], ignore: (Array[I], Set[V], Double) => Boolean = {(_, _, _) => false}): Option[(Key, Set[V], Double)] =
{
  val winner: Option[(Leaf, Double)] =
    root match {
      case Some(node) => node.nearest(position, measure, None, ignore)
      case None       => None
    }
  winner.map { case (leaf, distance) => (leaf.position, leaf.values, distance) }
}
|
|
|
|
/** Depth of the tree: 0 when it is empty, otherwise the depth reported by the root. */
def depth(): Int =
  root.fold(0) { _.depth() }
|
|
|
|
/** Render the tree via its root node, or a fixed placeholder when it is empty. */
override def toString(): String =
  root match {
    case Some(node) => node.toString("", "")
    case None       => "[Empty Tree]"
  }
|
|
|
|
/**
 * Build a subtree holding all of `elements`.
 *
 * A single element becomes a leaf.  Otherwise each dimension is tried in turn,
 * starting at `dimBase` and wrapping around: the elements are sorted on that
 * dimension and split around a coordinate chosen so that both halves are
 * non-empty, and each half is built recursively starting at the following
 * dimension.  If the elements agree on every dimension they are merged into
 * one leaf carrying all of their values.
 *
 * All coordinate comparisons go through `Ordering[I]`, so the choice of split
 * value is consistent with the sort and with the partition that follows it.
 *
 * @param elements  the (position, value) pairs to place; must not be empty
 * @param dimBase   the first dimension to try splitting on
 * @return          the root of the newly built subtree
 */
private def fragmentSeq(elements: Seq[(Key, V)], dimBase: Int): Node =
{
  assert(elements.nonEmpty)
  val ord = Ordering[I]

  // Try to split `elements` on `dim`.  Yields None when every element has the
  // same coordinate on that dimension, i.e. no split can separate them.
  def splitOn(dim: Int): Option[Node] =
  {
    val sorted = elements.sortBy { _._1(dim) }(ord)
    val median = sorted((sorted.size + 1) / 2)._1(dim) // +1 to round up.

    // Elements strictly below the split go left, the rest go right.  The
    // median works as the split whenever something sorts strictly below it;
    // otherwise fall back to the first coordinate strictly above it.
    val pivot: Option[I] =
      if(ord.compare(sorted.head._1(dim), median) < 0){ Some(median) }
      else {
        sorted.iterator
              .map { _._1(dim) }
              .find { coordinate => ord.compare(coordinate, median) > 0 }
      }

    pivot.map { split =>
      val (below, atOrAbove) =
        sorted.partition { case (key, _) => ord.compare(key(dim), split) < 0 }
      Inner(
        split = split,
        dim   = dim,
        left  = fragmentSeq(below, nextDim(dim)),
        right = fragmentSeq(atOrAbove, nextDim(dim))
      )
    }
  }

  if(elements.size == 1){
    val (key, value) = elements.head
    Leaf(key, Set(value), dimBase)
  } else {
    // The iterator is lazy, so later dimensions are only examined when the
    // earlier ones could not be split.
    (0 until dimensions).iterator
      .map { offset => splitOn((dimBase + offset) % dimensions) }
      .collectFirst { case Some(node) => node }
      .getOrElse {
        // All positions are equal on all dimensions: aggregate into one leaf.
        Leaf(elements.head._1, elements.map { _._2 }.toSet, dimBase)
      }
  }
}
|
|
|
|
/** The dimension index following `dim`, taken modulo `dimensions`. */
@inline def nextDim(dim: Int): Int =
  (dim + 1) % dimensions
|
|
|
|
/** Operations shared by the interior and leaf nodes of the tree. */
sealed trait Node
{
  /**
   * Add one value at `position`.
   * @return the node that should take this node's place afterwards
   */
  def insert(position: Key, value: V): Node

  /**
   * Add a batch of (position, value) pairs.
   * @return the node that should take this node's place afterwards
   */
  def insertAll(elements: Seq[(Key, V)]): Node

  /**
   * Remove one value stored at `position`.
   * @return the node that should take this node's place (None if nothing is
   *         left), paired with whether the value was found and removed
   */
  def remove(position: Key, value: V): (Option[Node], Boolean)

  /**
   * Continue a nearest-neighbour search through this node.
   * @param position  the query point
   * @param measure   the distance measure in use
   * @param best      the best leaf found so far and its distance, if any
   * @param ignore    predicate over (position, values, distance) for leaves
   *                  that must not be returned
   * @return          the best leaf known after visiting this node, with its distance
   */
  def nearest(
    position: Key,
    measure: Distance[I],
    best: Option[(Leaf, Double)],
    ignore: (Array[I], Set[V], Double) => Boolean
  ): Option[(Leaf, Double)]

  /** Number of node levels in the subtree rooted here. */
  def depth(): Int

  /**
   * Render the subtree rooted here.
   * @param firstPrefix  text placed in front of this node's own line
   * @param restPrefix   text from which the prefixes of the lines below are built
   */
  def toString(firstPrefix: String, restPrefix: String): String
}
|
|
|
|
/**
 * An interior node that divides its points on one dimension.
 *
 * Points whose coordinate on `dim` compares strictly below `split` (under
 * `Ordering[I]`) belong to `left`; all others belong to `right`.
 *
 * @param split  the dividing coordinate
 * @param dim    the dimension the division is made on
 * @param left   subtree for coordinates below `split`
 * @param right  subtree for coordinates at or above `split`
 */
case class Inner(
  split: I,
  dim: Int,
  var left: Node,
  var right: Node
) extends Node
{

  /** True when `position` falls on the `left` side of the split. */
  private def goesLeft(position: Key): Boolean =
    Ordering[I].compare(position(dim), split) < 0

  /**
   * Add one value by delegating to the child on the matching side.
   *
   * The node returned by the child is stored back into `left`/`right`: a leaf
   * that has to split answers with a brand new subtree, and dropping that
   * result would lose the inserted value.
   *
   * @return this node
   */
  def insert(position: Key, value: V): Node =
  {
    if(goesLeft(position)){
      left = left.insert(position, value)
    } else {
      right = right.insert(position, value)
    }
    this
  }

  /**
   * Add a batch of pairs, routing each to the child on its side of the split.
   * A batch of exactly one element is handled by `insert`; a child that
   * receives no elements is left alone.
   *
   * @return this node
   */
  def insertAll(elements: Seq[(Key, V)]): Node =
  {
    if(elements.size == 1){
      val (key, value) = elements.head
      insert(key, value)
    } else {
      val (forLeft, forRight) =
        elements.partition { case (key, _) => goesLeft(key) }
      if(forLeft.nonEmpty){ left = left.insertAll(forLeft) }
      if(forRight.nonEmpty){ right = right.insertAll(forRight) }
      this
    }
  }

  /**
   * Remove a value from the child on the matching side.
   *
   * If that child still exists afterwards it is stored back and this node
   * stays in place; if the child vanished, the sibling subtree takes this
   * node's place.
   *
   * @return the replacement for this node, and whether the value was removed
   */
  def remove(position: Key, value: V): (Option[Node], Boolean) =
  {
    if(goesLeft(position)){
      val (remaining, removed) = left.remove(position, value)
      remaining match {
        case Some(newLeft) =>
          left = newLeft
          (Some(this), removed)
        case None =>
          (Some(right), removed)
      }
    } else {
      val (remaining, removed) = right.remove(position, value)
      remaining match {
        case Some(newRight) =>
          right = newRight
          (Some(this), removed)
        case None =>
          (Some(left), removed)
      }
    }
  }

  /**
   * Search the child on the query's side of the split first.  The other child
   * is searched as well only when no candidate exists yet, or when the
   * candidate is farther away than the splitting plane.
   */
  def nearest(position: Key, measure: Distance[I], best: Option[(Leaf, Double)], ignore: (Array[I], Set[V], Double) => Boolean): Option[(Leaf, Double)] =
  {
    val (near, far) =
      if(goesLeft(position)){ (left, right) } else { (right, left) }

    val candidate = near.nearest(position, measure, best, ignore)
    val splitDistance = measure.pointToPlane(position, split, dim)

    if(candidate.forall { _._2 > splitDistance }){
      far.nearest(position, measure, candidate, ignore)
    } else {
      candidate
    }
  }

  /** One more than the deeper of the two children. */
  def depth(): Int =
    Math.max(left.depth(), right.depth())+1

  /** This node's split on the first line, followed by the left and then the right subtree. */
  def toString(firstPrefix: String, restPrefix: String): String =
    Seq(
      firstPrefix + s" Dim[$dim] < $split",
      left.toString(restPrefix+" +-", restPrefix+" | "),
      right.toString(restPrefix+" +-", restPrefix+" ")
    ).mkString("\n")

}
|
|
|
|
|
|
/**
 * A terminal node holding every value stored at one position.
 *
 * @param position  the coordinates shared by all values in this leaf
 * @param values    the values stored at `position`
 * @param dim       the first dimension examined when this leaf has to be split
 */
case class Leaf(
  position: Key,
  var values: Set[V],
  dim: Int
) extends Node
{
  /**
   * Add one value.
   *
   * Dimensions are compared in turn, starting at `dim` and wrapping around,
   * until one is found on which the two positions differ.  If there is none the
   * positions are identical and the value joins this leaf.  Otherwise a new
   * interior node splitting on that dimension is returned: the larger of the
   * two coordinates becomes the split, the point with the smaller coordinate
   * goes left and the other goes right.
   *
   * @return this leaf, or the interior node replacing it
   */
  def insert(otherPosition: Key, otherValue: V): Node =
  {
    val firstDifference: Option[(Int, Int)] =
      (0 until dimensions).iterator
        .map { offset =>
          val candidateDim = (dim + offset) % dimensions
          (candidateDim, Ordering[I].compare(position(candidateDim), otherPosition(candidateDim)))
        }
        .find { _._2 != 0 }

    firstDifference match {
      case None => // identical position
        values = values + otherValue
        this
      case Some( (splitDim, comparison) ) =>
        val newcomer = Leaf(otherPosition, Set(otherValue), nextDim(splitDim))
        val resident = copy(dim = nextDim(splitDim))
        if(comparison > 0){
          // this leaf's coordinate is the larger one
          Inner(split = position(splitDim), dim = splitDim, left = newcomer, right = resident)
        } else {
          Inner(split = otherPosition(splitDim), dim = splitDim, left = resident, right = newcomer)
        }
    }
  }

  /** Rebuild this leaf's values together with `elements` into a new subtree starting at `dim`. */
  def insertAll(elements: Seq[(Key, V)]): Node =
  {
    val existing = values.toSeq.map { v => (position, v) }
    fragmentSeq(existing ++ elements, dim)
  }

  /**
   * Remove a value if it is stored here.
   *
   * Positions are matched by content, coordinate by coordinate under
   * `Ordering[I]`, rather than by array identity, so a caller may pass any
   * array holding the same coordinates.  Passing the very same array instance
   * always matches.
   *
   * @return (None, true) if the last value was removed, (this leaf, true) if
   *         other values remain, (this leaf, false) if nothing matched
   */
  def remove(position: Key, value: V): (Option[Node], Boolean) =
  {
    val samePoint =
      (position eq this.position) || (
        position.length == this.position.length &&
        position.indices.forall { i =>
          Ordering[I].compare(position(i), this.position(i)) == 0
        }
      )

    if(!samePoint || !values.contains(value)){
      (Some(this), false)
    } else if(values.size == 1){
      (None, true)
    } else {
      values = values - value
      (Some(this), true)
    }
  }

  /**
   * Offer this leaf as a nearest-neighbour candidate.  The incoming best is
   * kept when `ignore` rejects this leaf or when the best is at least as close;
   * otherwise this leaf and its distance are returned.
   */
  def nearest(position: Key, measure: Distance[I], best: Option[(Leaf, Double)], ignore: (Array[I], Set[V], Double) => Boolean): Option[(Leaf, Double)] =
  {
    val distance = measure.pointToPoint(position, this.position)
    if(ignore(this.position, this.values, distance)){ best }
    else if(best.exists { _._2 <= distance }){ best }
    else { Some( (this, distance) ) }
  }

  /** A leaf is a single level. */
  def depth(): Int = 1

  /** Render the position and the stored values on one line. */
  def toString(firstPrefix: String, restPrefix: String): String =
  {
    val coordinates = position.mkString(", ")
    val contents = values.mkString(", ")
    s"$firstPrefix <$coordinates> -> $contents"
  }
}
|
|
}
|