added takeSample method

The takeSample method takes a specified number of samples from the RDD and
returns them in an array.
This commit is contained in:
Edison Tung 2011-11-21 16:38:44 -08:00
parent 3b9d9de583
commit a3bc012af8

View file

@ -91,6 +91,39 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
/**
 * Build a sampled view of this RDD.
 *
 * The arguments are forwarded unchanged to the `SampledRDD` constructor,
 * with this RDD as its first argument; no sampling logic lives here.
 *
 * @param withReplacement forwarded to `SampledRDD`
 * @param fraction forwarded to `SampledRDD`
 * @param seed forwarded to `SampledRDD`
 * @return the newly constructed `SampledRDD`
 */
def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] = {
  val sampled: RDD[T] = new SampledRDD(this, withReplacement, fraction, seed)
  sampled
}
/**
 * Return a fixed-size random sample of this RDD's elements as an array.
 *
 * The result holds `min(num, count())` elements. An oversized sample is
 * drawn through `SampledRDD` and then truncated; if a draw comes back with
 * too few elements it is repeated with a fresh seed until it is large enough.
 *
 * @param withReplacement forwarded to `SampledRDD`
 * @param num number of elements requested; must be non-negative
 * @param seed seed for the first draw, and for the generator that derives
 *             the seeds of any further draws
 * @return an array of `min(num, count())` sampled elements
 * @throws IllegalArgumentException if `num` is negative
 */
def takeSample(withReplacement: Boolean, num: Int, seed: Int): Array[T] = {
  if (num < 0) {
    throw new IllegalArgumentException("Negative number of elements requested: " + num)
  }
  // Over-sampling factor: ask for roughly three times as many elements as
  // needed so that a single draw is usually large enough.
  val multiplier = 3.0
  // count() is evaluated once and reused instead of being re-run per branch.
  val totalCount = count()
  // More elements requested than exist: cap at the population size.
  // toInt is safe here because totalCount < num <= Int.MaxValue.
  val total = if (num > totalCount) totalCount.toInt else num
  if (total == 0) {
    // Nothing to return (num == 0 or the RDD is empty); skip sampling entirely.
    new Array[T](0)
  } else {
    // total > 0 implies totalCount > 0, so the division below is well defined.
    // (num + 1.0) keeps the arithmetic in Double and avoids Int overflow.
    val fraction =
      if (num > totalCount) 1.0
      else Math.min(multiplier * (num + 1.0) / totalCount, 1.0)
    // NOTE(review): presumably SampledRDD yields the same elements for the same
    // seed -- confirm. If so, retrying with an unchanged seed could never
    // produce a larger sample, so each retry takes a new seed derived from `seed`.
    val rand = new java.util.Random(seed)
    // Collect once per attempt and check the collected length, rather than
    // running a separate count() followed by collect().
    var samples = new SampledRDD(this, withReplacement, fraction, seed).collect()
    while (samples.length < total) {
      samples = new SampledRDD(this, withReplacement, fraction, rand.nextInt()).collect()
    }
    samples.take(total)
  }
}
/**
 * Combine this RDD with `other` by wrapping both in a `UnionRDD`.
 *
 * @param other the RDD to place after this one in the `UnionRDD`'s input array
 * @return a new `UnionRDD` constructed from `sc` and the pair (this, other)
 */
def union(other: RDD[T]): RDD[T] = {
  val parents = Array(this, other)
  new UnionRDD(sc, parents)
}
/**
 * Operator alias for `union`: `a ++ b` is the same as `a.union(b)`.
 *
 * @param other the RDD passed on to `union`
 * @return whatever `union(other)` returns
 */
def ++(other: RDD[T]): RDD[T] = {
  union(other)
}