added takeSamples method
takeSamples method takes a specified number of samples from the RDD and outputs it in an array.
This commit is contained in:
parent
3b9d9de583
commit
a3bc012af8
|
@ -91,6 +91,39 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
|
|||
def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] =
|
||||
new SampledRDD(this, withReplacement, fraction, seed)
|
||||
|
||||
def takeSample(withReplacement: Boolean, num: Int, seed: Int): Array[T] = {
|
||||
var fraction = 0.0
|
||||
var total = 0
|
||||
var multiplier = 3.0
|
||||
|
||||
if (num > count()) {
|
||||
total = Math.min(count().toInt)
|
||||
fraction = 1.0
|
||||
}
|
||||
else if (num < 0) {
|
||||
throw(new IllegalArgumentException())
|
||||
}
|
||||
else {
|
||||
fraction = Math.min(multiplier*(num+1)/count(), 1.0)
|
||||
total = num.toInt
|
||||
}
|
||||
|
||||
var r = new SampledRDD(this, withReplacement, fraction, seed)
|
||||
|
||||
while (r.count() < total) {
|
||||
r = new SampledRDD(this, withReplacement, fraction, seed)
|
||||
}
|
||||
|
||||
var samples = r.collect()
|
||||
var arr = new Array[T](total)
|
||||
|
||||
for (i <- 0 to total - 1) {
|
||||
arr(i) = samples(i)
|
||||
}
|
||||
|
||||
return arr
|
||||
}
|
||||
|
||||
def union(other: RDD[T]): RDD[T] = new UnionRDD(sc, Array(this, other))
|
||||
|
||||
def ++(other: RDD[T]): RDD[T] = this.union(other)
|
||||
|
|
Loading…
Reference in a new issue