Merge pull request #649 from ryanlecompte/master
Add top K method to RDD using a bounded priority queue
This commit is contained in:
commit
f961aac8b2
|
@ -36,6 +36,7 @@ import spark.rdd.ZippedPartitionsRDD2
|
||||||
import spark.rdd.ZippedPartitionsRDD3
|
import spark.rdd.ZippedPartitionsRDD3
|
||||||
import spark.rdd.ZippedPartitionsRDD4
|
import spark.rdd.ZippedPartitionsRDD4
|
||||||
import spark.storage.StorageLevel
|
import spark.storage.StorageLevel
|
||||||
|
import spark.util.BoundedPriorityQueue
|
||||||
|
|
||||||
import SparkContext._
|
import SparkContext._
|
||||||
|
|
||||||
|
@ -723,6 +724,24 @@ abstract class RDD[T: ClassManifest](
|
||||||
case _ => throw new UnsupportedOperationException("empty collection")
|
case _ => throw new UnsupportedOperationException("empty collection")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Returns the top K elements from this RDD as defined by
 * the specified implicit Ordering[T].
 * @param num the number of top elements to return
 * @param ord the implicit ordering for T
 * @return an array of top elements
 */
def top(num: Int)(implicit ord: Ordering[T]): Array[T] = {
  // Keep only the best `num` elements of each partition in a bounded heap,
  // then merge the per-partition heaps pairwise; the surviving heap holds
  // the global top K.
  val perPartition = mapPartitions { iter =>
    val heap = new BoundedPriorityQueue[T](num)
    heap ++= iter
    Iterator.single(heap)
  }
  val merged = perPartition.reduce { (left, right) =>
    left ++= right
    left
  }
  merged.toArray
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Save this RDD as a text file, using string representations of elements.
|
* Save this RDD as a text file, using string representations of elements.
|
||||||
*/
|
*/
|
||||||
|
|
|
@ -86,7 +86,6 @@ JavaRDDLike[T, JavaRDD[T]] {
|
||||||
*/
|
*/
|
||||||
def subtract(other: JavaRDD[T], p: Partitioner): JavaRDD[T] =
|
def subtract(other: JavaRDD[T], p: Partitioner): JavaRDD[T] =
|
||||||
wrapRDD(rdd.subtract(other, p))
|
wrapRDD(rdd.subtract(other, p))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
object JavaRDD {
|
object JavaRDD {
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
package spark.api.java
|
package spark.api.java
|
||||||
|
|
||||||
import java.util.{List => JList}
|
import java.util.{List => JList, Comparator}
|
||||||
import scala.Tuple2
|
import scala.Tuple2
|
||||||
import scala.collection.JavaConversions._
|
import scala.collection.JavaConversions._
|
||||||
|
|
||||||
|
@ -359,4 +359,29 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
|
||||||
def toDebugString(): String = {
|
def toDebugString(): String = {
|
||||||
rdd.toDebugString
|
rdd.toDebugString
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Returns the top K elements from this RDD as defined by
 * the specified Comparator[T].
 * @param num the number of top elements to return
 * @param comp the comparator that defines the order
 * @return an array of top elements
 */
def top(num: Int, comp: Comparator[T]): JList[T] = {
  import scala.collection.JavaConversions._
  // Delegate to the underlying Scala RDD, adapting the Java Comparator
  // into a Scala Ordering, then copy the result into a Java list.
  val ordering = Ordering.comparatorToOrdering(comp)
  val topElems = rdd.top(num)(ordering)
  val asCollection: java.util.Collection[T] = topElems.toSeq
  new java.util.ArrayList(asCollection)
}
|
||||||
|
|
||||||
|
/**
 * Returns the top K elements from this RDD using the
 * natural ordering for T.
 * @param num the number of top elements to return
 * @return an array of top elements
 */
def top(num: Int): JList[T] = {
  // Guava's natural ordering stands in for Ordering.natural[T]; the cast is
  // required because T is not statically known to be Comparable here.
  val naturalComp =
    com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[T]]
  top(num, naturalComp)
}
|
||||||
}
|
}
|
||||||
|
|
45
core/src/main/scala/spark/util/BoundedPriorityQueue.scala
Normal file
45
core/src/main/scala/spark/util/BoundedPriorityQueue.scala
Normal file
|
@ -0,0 +1,45 @@
|
||||||
|
package spark.util

import java.io.Serializable
import java.util.{PriorityQueue => JPriorityQueue}
import scala.collection.generic.Growable
import scala.collection.JavaConverters._

/**
 * Bounded priority queue. This class wraps the original PriorityQueue
 * class and modifies it such that only the top K elements are retained.
 * The top K elements are defined by an implicit Ordering[A].
 */
class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A])
  extends Iterable[A] with Growable[A] with Serializable {

  // java.util.PriorityQueue rejects initial capacities < 1, so clamp the
  // capacity so that maxSize == 0 (e.g. rdd.top(0)) does not throw
  // IllegalArgumentException. The size bound itself is enforced by `+=`.
  private val underlying = new JPriorityQueue[A](math.max(1, maxSize), ord)

  override def iterator: Iterator[A] = underlying.iterator.asScala

  /** Adds all elements, retaining only the top maxSize of the union. */
  override def ++=(xs: TraversableOnce[A]): this.type = {
    xs.foreach { this += _ }
    this
  }

  /**
   * Adds a single element. While under capacity the element is simply
   * inserted; at capacity it replaces the current minimum only if it is
   * strictly greater, keeping the queue's contents the top maxSize seen.
   */
  override def +=(elem: A): this.type = {
    if (size < maxSize) underlying.offer(elem)
    else maybeReplaceLowest(elem)
    this
  }

  override def +=(elem1: A, elem2: A, elems: A*): this.type = {
    this += elem1 += elem2 ++= elems
  }

  override def clear() { underlying.clear() }

  /**
   * Replaces the smallest retained element with `a` when `a` is strictly
   * greater. Returns true if a replacement happened.
   */
  private def maybeReplaceLowest(a: A): Boolean = {
    // peek() is the minimum (heap root); null only when the queue is empty,
    // which happens iff maxSize == 0 — then nothing is ever retained.
    val head = underlying.peek()
    if (head != null && ord.gt(a, head)) {
      underlying.poll()
      underlying.offer(a)
    } else false
  }
}
|
||||||
|
|
|
@ -317,4 +317,23 @@ class RDDSuite extends FunSuite with LocalSparkContext {
|
||||||
assert(sample.size === checkSample.size)
|
assert(sample.size === checkSample.size)
|
||||||
for (i <- 0 until sample.size) assert(sample(i) === checkSample(i))
|
for (i <- 0 until sample.size) assert(sample(i) === checkSample(i))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("top with predefined ordering") {
  sc = new SparkContext("local", "test")
  // 1..99999 shuffled across two partitions; top(5) must still recover
  // the five largest values.
  val values = Array.range(1, 100000)
  val shuffled = scala.util.Random.shuffle(values)
  val ints = sc.makeRDD(shuffled, 2)
  val result = ints.top(5)
  assert(result.size === 5)
  assert(result.sorted === values.sorted.takeRight(5))
}
|
||||||
|
|
||||||
|
test("top with custom ordering") {
  sc = new SparkContext("local", "test")
  val words = Vector("a", "b", "c", "d")
  // Reverse the natural String ordering so top() yields the smallest words.
  implicit val ord = implicitly[Ordering[String]].reverse
  val rdd = sc.makeRDD(words, 2)
  val result = rdd.top(2)
  assert(result.size === 2)
  assert(result.sorted === Array("b", "a"))
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue