Add a new top-K method to RDD using a bounded priority queue
This commit is contained in:
parent
dc4073654b
commit
db5bca08ff
|
@ -35,6 +35,7 @@ import spark.rdd.ZippedPartitionsRDD2
|
|||
import spark.rdd.ZippedPartitionsRDD3
|
||||
import spark.rdd.ZippedPartitionsRDD4
|
||||
import spark.storage.StorageLevel
|
||||
import spark.util.BoundedPriorityQueue
|
||||
|
||||
import SparkContext._
|
||||
|
||||
|
@ -722,6 +723,29 @@ abstract class RDD[T: ClassManifest](
|
|||
case _ => throw new UnsupportedOperationException("empty collection")
|
||||
}
|
||||
|
||||
/**
 * Returns the top K elements from this RDD as defined by
 * the specified implicit Ordering[T].
 * @param num the number of top elements to return
 * @param ord the implicit ordering for T
 * @return an array of top elements
 */
def top(num: Int)(implicit ord: Ordering[T]): Array[T] = {
  // Build one bounded queue per partition, then merge the per-partition
  // queues pairwise; each queue retains only the `num` largest elements
  // it has seen, so the merged result holds the global top `num`.
  // NOTE: reduce throws on an empty RDD, same as the other aggregations here.
  val merged = mapPartitions { iter =>
    val pq = new BoundedPriorityQueue[T](num)
    pq ++= iter
    Iterator.single(pq)
  }.reduce { (pq1, pq2) =>
    pq1 ++= pq2
    pq1
  }

  // Copy the survivors into an array. Iteration order of the queue is
  // heap order (unspecified), not sorted order.
  val out = Array.newBuilder[T]
  out.sizeHint(merged.size)
  out ++= merged
  out.result()
}
|
||||
|
||||
/**
|
||||
* Save this RDD as a text file, using string representations of elements.
|
||||
*/
|
||||
|
|
48
core/src/main/scala/spark/util/BoundedPriorityQueue.scala
Normal file
48
core/src/main/scala/spark/util/BoundedPriorityQueue.scala
Normal file
|
@ -0,0 +1,48 @@
|
|||
package spark.util
|
||||
|
||||
import java.util.{PriorityQueue => JPriorityQueue}
|
||||
import scala.collection.generic.Growable
|
||||
|
||||
/**
|
||||
* Bounded priority queue. This class modifies the original PriorityQueue's
|
||||
* add/offer methods such that only the top K elements are retained. The top
|
||||
* K elements are defined by an implicit Ordering[A].
|
||||
*/
|
||||
/**
 * Bounded priority queue. This class modifies the original PriorityQueue's
 * add/offer methods such that only the top K elements are retained. The top
 * K elements are defined by an implicit Ordering[A].
 *
 * Internally this is a min-heap: `peek()` returns the smallest retained
 * element, so a new element only displaces the head when it compares greater.
 *
 * @param maxSize maximum number of elements to retain. Values <= 0 yield a
 *                queue that retains nothing; the JDK constructor rejects an
 *                initial capacity < 1, hence the `math.max` below (previously
 *                `new BoundedPriorityQueue(0)` threw IllegalArgumentException).
 */
class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A], mf: ClassManifest[A])
  extends JPriorityQueue[A](math.max(maxSize, 1), ord) with Growable[A] {

  // Capacity check is done against maxSize (not the JDK capacity), so a
  // maxSize <= 0 queue never admits an element.
  override def offer(a: A): Boolean = {
    if (size < maxSize) super.offer(a)
    else maybeReplaceLowest(a)
  }

  // Route java.util.Collection.add through offer so the bound is always honored.
  override def add(a: A): Boolean = offer(a)

  override def ++=(xs: TraversableOnce[A]): this.type = {
    xs.foreach(add)
    this
  }

  override def +=(elem: A): this.type = {
    add(elem)
    this
  }

  override def +=(elem1: A, elem2: A, elems: A*): this.type = {
    this += elem1 += elem2 ++= elems
  }

  // Evict the current minimum and insert `a` when `a` is strictly greater;
  // otherwise reject `a` and leave the queue unchanged.
  private def maybeReplaceLowest(a: A): Boolean = {
    val head = peek()
    if (head != null && ord.gt(a, head)) {
      poll()
      super.offer(a)
    } else false
  }
}
|
||||
|
||||
/**
 * Implicitly exposes a BoundedPriorityQueue as a Scala Iterable so callers
 * can feed it straight into collection operations (e.g. `builder ++= queue`).
 * Iteration is in the underlying heap's order, not sorted order.
 */
object BoundedPriorityQueue {
  import scala.collection.JavaConverters._

  implicit def asIterable[A](q: BoundedPriorityQueue[A]): Iterable[A] = q.asScala
}
|
||||
|
|
@ -317,4 +317,23 @@ class RDDSuite extends FunSuite with LocalSparkContext {
|
|||
assert(sample.size === checkSample.size)
|
||||
for (i <- 0 until sample.size) assert(sample(i) === checkSample(i))
|
||||
}
|
||||
|
||||
test("top with predefined ordering") {
  sc = new SparkContext("local", "test")
  // Shuffle so the result cannot be an artifact of input/partition order.
  val values = Array.range(1, 100000)
  val rdd = sc.makeRDD(scala.util.Random.shuffle(values), 2)
  val result = rdd.top(5)
  assert(result.size === 5)
  assert(result.sorted === values.sorted.takeRight(5))
}
|
||||
|
||||
test("top with custom ordering") {
  sc = new SparkContext("local", "test")
  val words = Vector("a", "b", "c", "d")
  // With the reversed ordering, top() yields the lexicographically smallest
  // strings, and .sorted (same implicit) arranges them descending.
  implicit val ord = implicitly[Ordering[String]].reverse
  val rdd = sc.makeRDD(words, 2)
  val result = rdd.top(2)
  assert(result.size === 2)
  assert(result.sorted === Array("b", "a"))
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue