mimir-pip/lib/src/org/mimirdb/pip/lib/KDTree.scala

359 lines
10 KiB
Scala

package org.mimirdb.pip.lib
import scala.reflect.ClassTag
/**
 * A k-d tree that maps fixed-length positions to sets of values.
 *
 * Every position is an array with exactly `dimensions` coordinates.
 * Coordinates are compared only through the implicit `Ordering[I]`.
 * Values stored at the same position (all coordinates comparing equal)
 * share a single leaf.
 *
 * @tparam I           coordinate type
 * @tparam V           value type
 * @param  dimensions  number of coordinates in every position
 */
class KDTree[I: Ordering, V](dimensions: Int)(implicit iTag: ClassTag[I])
{
  /** A position in the tree: one coordinate per dimension. */
  type Key = Array[I]

  /** Root of the tree; `None` while the tree holds no nodes. */
  var root: Option[Node] = None

  // Number of stored elements, as tracked by insert / insertAll / remove.
  // NOTE(review): values are kept in a Set, so re-inserting an identical
  // (position, value) pair still bumps this counter.
  private var _size = 0

  // The ordering used for every coordinate comparison in the tree.
  private val coordOrdering: Ordering[I] = implicitly[Ordering[I]]

  /**
   * Add one value at the given position.
   *
   * @param position  coordinates; must have exactly `dimensions` entries
   * @param value     the value to store at `position`
   */
  def insert(position: Key, value: V): Unit =
  {
    assert(position.size == dimensions)
    root = Some(root match {
      case Some(node) => node.insert(position, value)
      case None       => Leaf(position, Set(value), 0)
    })
    _size += 1
  }

  /**
   * Add a batch of (position, value) pairs.
   *
   * On an empty tree the batch is used to build the tree from scratch
   * (see `fragmentSeq`); otherwise it is pushed down the existing nodes.
   *
   * @param elements  the pairs to add; every position must have exactly
   *                  `dimensions` entries
   */
  def insertAll(elements: Iterable[(Key, V)]): Unit =
  {
    // Materialise once so the input is traversed a single time.
    val batch = elements.toSeq
    batch.foreach { case (key, _) => assert(key.size == dimensions) }
    if(batch.nonEmpty){
      root = Some(root match {
        case Some(node) => node.insertAll(batch)
        case None       => fragmentSeq(batch, 0)
      })
      _size += batch.size
    }
  }

  /**
   * Remove one value from the given position.
   *
   * @param position  coordinates of the element to remove
   * @param value     the value to remove
   * @return          true if the value was present and has been removed
   */
  def remove(position: Key, value: V): Boolean =
    root match {
      case None => false
      case Some(node) =>
        val (newRoot, removed) = node.remove(position, value)
        root = newRoot
        // Only a successful removal changes the element count.
        if(removed){ _size -= 1 }
        removed
    }

  /** True when the tree holds no nodes. */
  def isEmpty: Boolean = root.isEmpty

  /** Number of elements currently tracked by the tree. */
  def size: Int = _size

  /**
   * Find the leaf closest to `position`.
   *
   * @param position  the query position
   * @param measure   supplies the point-to-point and point-to-plane distances
   * @param ignore    called with a leaf's (position, values, distance); leaves
   *                  for which it returns true are skipped
   * @return          the position, values and distance of the closest leaf
   *                  that was not ignored, or `None` if there is none
   */
  def nearest(position: Key, measure: Distance[I], ignore: (Array[I], Set[V], Double) => Boolean = {(_, _, _) => false}): Option[(Key, Set[V], Double)] =
    root
      .flatMap { _.nearest(position = position, measure = measure, best = None, ignore = ignore) }
      .map { case (found, distance) => (found.position, found.values, distance) }

  /** Number of nodes on the longest root-to-leaf path; 0 for an empty tree. */
  def depth(): Int =
    root.map { _.depth() }.getOrElse(0)

  /** Multi-line rendering of the tree structure. */
  override def toString(): String =
    root.map { _.toString("", "") }.getOrElse("[Empty Tree]")

  /**
   * Build a subtree holding all of `elements`.
   *
   * Dimensions are tried in order starting at `dimBase` (wrapping around);
   * the first one on which the elements do not all share a coordinate is
   * used for the split. If the elements agree on every dimension they are
   * merged into a single leaf.
   *
   * @param elements  non-empty sequence of (position, value) pairs
   * @param dimBase   the first dimension to try splitting on
   */
  private def fragmentSeq(elements: Seq[(Key, V)], dimBase: Int): Node =
  {
    assert(!elements.isEmpty)
    if(elements.size == 1){
      val (key, value) = elements.head
      Leaf(key, Set(value), dimBase)
    } else {
      // The iterator is lazy: later dimensions are only examined when the
      // earlier ones could not separate the elements.
      (0 until dimensions).iterator
        .map { offset => splitOn(elements, (dimBase + offset) % dimensions) }
        .collectFirst { case Some(node) => node }
        .getOrElse {
          // All positions are equal on all dimensions.
          Leaf(elements.head._1, elements.map { _._2 }.toSet, dim = dimBase)
        }
    }
  }

  /**
   * Try to split `elements` on dimension `dim`.
   *
   * @return an `Inner` node with two non-empty subtrees, or `None` if every
   *         element has the same coordinate on `dim`
   */
  private def splitOn(elements: Seq[(Key, V)], dim: Int): Option[Node] =
  {
    val sorted = elements.sortBy { case (key, _) => key(dim) }
    // +1 to round up.
    val median = sorted((sorted.size + 1) / 2)._1(dim)
    val threshold: Option[I] =
      if(coordOrdering.lt(sorted.head._1(dim), median)){
        // Something lies strictly below the median, so it is a valid split.
        Some(median)
      } else {
        // Nothing is below the median, so the smallest coordinate is the
        // median itself. Split at the next larger coordinate, if one exists.
        sorted.iterator
          .map { case (key, _) => key(dim) }
          .find { coordinate => coordOrdering.gt(coordinate, median) }
      }
    threshold.map { split =>
      val (below, atOrAbove) =
        sorted.partition { case (key, _) => coordOrdering.lt(key(dim), split) }
      Inner(
        split = split,
        dim = dim,
        left = fragmentSeq(below, nextDim(dim)),
        right = fragmentSeq(atOrAbove, nextDim(dim))
      )
    }
  }

  /** The dimension following `dim`, wrapping around to 0. */
  @inline def nextDim(dim: Int): Int =
    (dim+1)%dimensions

  /** Operations shared by inner nodes and leaves. */
  sealed trait Node
  {
    /** Insert one value; returns the node that now roots this subtree. */
    def insert(position: Key, value: V): Node
    /** Insert a batch; returns the node that now roots this subtree. */
    def insertAll(elements: Seq[(Key, V)]): Node
    /**
     * Remove one value.
     *
     * @return the node that now roots this subtree (`None` if the subtree
     *         became empty) and whether anything was removed
     */
    def remove(position: Key, value: V): (Option[Node], Boolean)
    /** Improve on `best` with the closest non-ignored leaf in this subtree. */
    def nearest(position: Key, measure: Distance[I], best: Option[(Leaf, Double)], ignore: (Array[I], Set[V], Double) => Boolean): Option[(Leaf, Double)]
    /** Number of nodes on the longest path from this node down to a leaf. */
    def depth(): Int
    /** Render this subtree; `firstPrefix` starts the first line, `restPrefix` the others. */
    def toString(firstPrefix: String, restPrefix: String): String
  }

  /**
   * A splitting node: positions whose coordinate on `dim` is less than
   * `split` belong to `left`, all others to `right`.
   */
  case class Inner(
    split: I,
    dim: Int,
    var left: Node,
    var right: Node
  ) extends Node
  {
    // True when `position` belongs to the left subtree.
    private def goesLeft(position: Key): Boolean =
      coordOrdering.lt(position(dim), split)

    def insert(position: Key, value: V): Node =
    {
      // The child may be replaced (a leaf turns into an inner node when a
      // new position arrives), so the returned node must be stored.
      if(goesLeft(position)){ left = left.insert(position, value) }
      else { right = right.insert(position, value) }
      this
    }

    def insertAll(elements: Seq[(Key, V)]): Node =
    {
      if(elements.size == 1){
        insert(elements.head._1, elements.head._2)
      } else {
        val (forLeft, forRight) =
          elements.partition { case (key, _) => goesLeft(key) }
        if(forLeft.nonEmpty){ left = left.insertAll(forLeft) }
        if(forRight.nonEmpty){ right = right.insertAll(forRight) }
        this
      }
    }

    def remove(position: Key, value: V): (Option[Node], Boolean) =
    {
      if(goesLeft(position)){
        left.remove(position, value) match {
          case (Some(newLeft), removed) =>
            left = newLeft
            (Some(this), removed)
          // The left subtree vanished: the right subtree replaces this node.
          case (None, removed) =>
            (Some(right), removed)
        }
      } else {
        right.remove(position, value) match {
          case (Some(newRight), removed) =>
            right = newRight
            (Some(this), removed)
          // The right subtree vanished: the left subtree replaces this node.
          case (None, removed) =>
            (Some(left), removed)
        }
      }
    }

    def nearest(position: Key, measure: Distance[I], best: Option[(Leaf, Double)], ignore: (Array[I], Set[V], Double) => Boolean): Option[(Leaf, Double)] =
    {
      // Search the side containing `position` first.
      val (near, far) =
        if(goesLeft(position)){ (left, right) } else { (right, left) }
      val candidate =
        near.nearest(
          position = position,
          measure = measure,
          best = best,
          ignore = ignore,
        )
      val splitDistance =
        measure.pointToPlane(position, split, dim)
      // The far side is searched when nothing has been found yet, or when the
      // current candidate is farther away than the splitting plane.
      val mustSearchFar =
        candidate.forall { case (_, found) => found > splitDistance }
      if(mustSearchFar){
        far.nearest(
          position = position,
          measure = measure,
          best = candidate,
          ignore = ignore,
        )
      } else {
        candidate
      }
    }

    def depth(): Int =
      Math.max(left.depth(), right.depth())+1

    def toString(firstPrefix: String, restPrefix: String): String =
      firstPrefix + s" Dim[$dim] < $split\n" +
      left.toString(restPrefix+" +-", restPrefix+" | ") + "\n" +
      right.toString(restPrefix+" +-", restPrefix+" ")
  }

  /**
   * A leaf holding every value stored at one position.
   *
   * @param position  the coordinates shared by all values in this leaf
   * @param values    the values stored at `position`
   * @param dim       the first dimension examined when this leaf has to be split
   */
  case class Leaf(
    position: Key,
    var values: Set[V],
    dim: Int
  ) extends Node
  {
    def insert(otherPosition: Key, otherValue: V): Node =
    {
      // Starting at `dim` and wrapping around, find the first dimension on
      // which the two positions differ, with the comparison result.
      val firstDifference: Option[(Int, Int)] =
        (0 until dimensions).iterator
          .map { offset => (dim + offset) % dimensions }
          .map { d => (d, coordOrdering.compare(position(d), otherPosition(d))) }
          .find { case (_, comparison) => comparison != 0 }

      firstDifference match {
        // Identical position: just add the value to this leaf.
        case None =>
          values = values + otherValue
          this
        case Some((splitDim, comparison)) =>
          val childDim = nextDim(splitDim)
          val incoming = Leaf(otherPosition, Set(otherValue), childDim)
          val existing = copy(dim = childDim)
          if(comparison > 0){
            // This leaf's coordinate is the larger one: it becomes the split
            // value and this leaf goes right.
            Inner(split = position(splitDim), dim = splitDim, left = incoming, right = existing)
          } else {
            // The new coordinate is the larger one: it becomes the split
            // value and the new leaf goes right.
            Inner(split = otherPosition(splitDim), dim = splitDim, left = existing, right = incoming)
          }
      }
    }

    def insertAll(elements: Seq[(Key, V)]): Node =
    {
      // Rebuild this subtree from the leaf's own values plus the new ones.
      val existing = values.toSeq.map { v => (position, v) }
      fragmentSeq(existing ++ elements, dim)
    }

    def remove(position: Key, value: V): (Option[Node], Boolean) =
    {
      // Arrays compare by reference under ==, so compare the coordinates
      // one by one with the same ordering that insert uses.
      val samePosition =
        (position eq this.position) || (
          position.length == this.position.length &&
          position.indices.forall { i =>
            coordOrdering.equiv(position(i), this.position(i))
          }
        )
      if(!samePosition || !values.contains(value)){
        (Some(this), false)
      } else if(values.size == 1){
        // Last value at this position: the leaf disappears.
        (None, true)
      } else {
        values = values - value
        (Some(this), true)
      }
    }

    def nearest(position: Key, measure: Distance[I], best: Option[(Leaf, Double)], ignore: (Array[I], Set[V], Double) => Boolean): Option[(Leaf, Double)] =
    {
      val distance = measure.pointToPoint(position, this.position)
      if(ignore(this.position, this.values, distance)){
        // The caller asked to skip this leaf.
        best
      } else {
        best match {
          // A leaf found earlier is at least as close: keep it.
          case Some((_, bestDistance)) if bestDistance <= distance => best
          // ... otherwise this leaf is the new best.
          case _ => Some( (this, distance) )
        }
      }
    }

    def depth(): Int = 1

    def toString(firstPrefix: String, restPrefix: String): String =
    {
      val coordinates = position.mkString(", ")
      val contents = values.mkString(", ")
      firstPrefix + s" <$coordinates> -> $contents"
    }
  }
}