RDD sample() and takeSample() prototypes for PySpark
This commit is contained in:
parent
a9db1b7b6e
commit
a511c5379e
|
@ -21,6 +21,7 @@ from collections import defaultdict
|
|||
from itertools import chain, ifilter, imap, product
|
||||
import operator
|
||||
import os
|
||||
import sys
|
||||
import shlex
|
||||
from subprocess import Popen, PIPE
|
||||
from tempfile import NamedTemporaryFile
|
||||
|
@ -32,6 +33,7 @@ from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \
|
|||
from pyspark.join import python_join, python_left_outer_join, \
|
||||
python_right_outer_join, python_cogroup
|
||||
from pyspark.statcounter import StatCounter
|
||||
from pyspark.rddsampler import RDDSampler
|
||||
|
||||
from py4j.java_collections import ListConverter, MapConverter
|
||||
|
||||
|
@ -165,14 +167,60 @@ class RDD(object):
|
|||
.reduceByKey(lambda x, _: x) \
|
||||
.map(lambda (x, _): x)
|
||||
|
||||
def sample(self, withReplacement, fraction, seed):
    """
    Return a sampled subset of this RDD (relies on numpy and falls back
    on default random generator if numpy is unavailable).

    Sampling is done per-partition by RDDSampler.func, so results are
    deterministic for a given (seed, partition) pair.

    :param withReplacement: if True, each element may appear in the
        sample more than once (Poisson sampling); otherwise Bernoulli.
    :param fraction: expected fraction of the RDD kept in the sample.
    :param seed: seed for the random number generator.

    >>> sc.parallelize(range(0, 100)).sample(False, 0.1, 2).collect() #doctest: +SKIP
    [2, 3, 20, 21, 24, 41, 42, 66, 67, 89, 90, 98]
    """
    return self.mapPartitionsWithSplit(RDDSampler(withReplacement, fraction, seed).func, True)
|
||||
|
||||
# this is ported from scala/spark/RDD.scala
def takeSample(self, withReplacement, num, seed):
    """
    Return a fixed-size sampled subset of this RDD (currently requires numpy).

    :param withReplacement: if True, elements may be sampled multiple times.
    :param num: exact number of elements to return.
    :param seed: seed for the random number generator.
    :raises ValueError: if num is negative.

    >>> sc.parallelize(range(0, 10)).takeSample(True, 10, 1) #doctest: +SKIP
    [4, 2, 1, 8, 2, 7, 0, 4, 1, 4]
    """
    if num < 0:
        raise ValueError("Sample size cannot be negative")

    initialCount = self.count()
    # Empty request or empty RDD: answer is [] — also avoids the
    # ZeroDivisionError the fraction computation below would hit when
    # initialCount == 0.
    if num == 0 or initialCount == 0:
        return []

    # Over-sample by this factor so one pass usually collects enough.
    multiplier = 3.0

    # Cap the requested total at sys.maxint - 1 (Python 2 int bound).
    if initialCount > sys.maxint - 1:
        maxSelected = sys.maxint - 1
    else:
        maxSelected = initialCount

    if num > initialCount and not withReplacement:
        # Without replacement we can return at most the whole RDD.
        total = maxSelected
        fraction = multiplier * (maxSelected + 1) / initialCount
    else:
        fraction = multiplier * (num + 1) / initialCount
        total = num

    samples = self.sample(withReplacement, fraction, seed).collect()

    # If the first sample didn't turn out large enough, keep trying to take samples;
    # this shouldn't happen often because we use a big multiplier for their initial size.
    # See: scala/spark/RDD.scala
    while len(samples) < total:
        # Bump the seed for each retry, wrapping before Py2's int bound.
        if seed > sys.maxint - 2:
            seed = -1
        seed += 1
        samples = self.sample(withReplacement, fraction, seed).collect()

    # Shuffle on the driver so the truncation below is unbiased.
    sampler = RDDSampler(withReplacement, fraction, seed + 1)
    sampler.shuffle(samples)
    return samples[0:total]
|
||||
|
||||
def union(self, other):
|
||||
"""
|
||||
|
|
112
python/pyspark/rddsampler.py
Normal file
112
python/pyspark/rddsampler.py
Normal file
|
@ -0,0 +1,112 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
# contributor license agreements. See the NOTICE file distributed with
|
||||
# this work for additional information regarding copyright ownership.
|
||||
# The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
# (the "License"); you may not use this file except in compliance with
|
||||
# the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import sys
|
||||
import random
|
||||
|
||||
class RDDSampler(object):
    """Per-partition sampler backing RDD.sample() and RDD.takeSample().

    Uses a private numpy RandomState when numpy is installed; otherwise
    falls back to the stdlib random module.  NOTE(review): the fallback
    path seeds the *global* random module, so two samplers running in the
    same process would perturb each other's streams — acceptable for a
    prototype, but worth confirming before wider use.
    """

    def __init__(self, withReplacement, fraction, seed):
        # Probe for numpy once up front; the working import happens again
        # inside initRandomGenerator (presumably so the sampler pickles
        # without holding a module reference — TODO confirm).
        try:
            import numpy
            self._use_numpy = True
        except ImportError:
            sys.stderr.write(
                "NumPy does not appear to be installed. "
                "Falling back to default random generator for sampling.\n")
            self._use_numpy = False

        self._seed = seed
        self._withReplacement = withReplacement
        self._fraction = fraction
        self._random = None              # numpy RandomState (numpy path only)
        self._split = None               # split the generator was seeded for
        self._rand_initialized = False

    def initRandomGenerator(self, split):
        """(Re)seed the generator, then burn `split` draws so each split
        effectively gets a distinct stream from the same base seed."""
        if self._use_numpy:
            import numpy
            self._random = numpy.random.RandomState(self._seed)
            for _ in range(0, split):
                # discard the next few values in the sequence to have a
                # different seed for the different splits
                self._random.randint(sys.maxint)   # sys.maxint: Python 2
        else:
            random.seed(self._seed)
            for _ in range(0, split):
                # discard the next few values in the sequence to have a
                # different seed for the different splits
                random.randint(0, sys.maxint)
        self._split = split
        self._rand_initialized = True

    def getUniformSample(self, split):
        """Return one uniform draw in [0, 1) for the given split."""
        if not self._rand_initialized or split != self._split:
            self.initRandomGenerator(split)

        if self._use_numpy:
            return self._random.random_sample()
        else:
            return random.uniform(0.0, 1.0)

    def getPoissonSample(self, split, mean):
        """Return one Poisson(mean)-distributed count for the given split."""
        if not self._rand_initialized or split != self._split:
            self.initRandomGenerator(split)

        if self._use_numpy:
            return self._random.poisson(mean)
        else:
            # Simulate Poisson(mean) by counting arrivals in [0, 1] with
            # interarrival gaps delta_j ~ Exp(rate=mean).  (The original
            # comment said "Poisson(lambda = 1/mean)", which was wrong.)
            arrivals = 0
            clock = random.expovariate(mean)
            while clock <= 1.0:
                arrivals += 1
                clock += random.expovariate(mean)
            return arrivals

    def shuffle(self, vals):
        """Shuffle vals in place.  Only ever called on the master, so the
        generator is (re)seeded for split 0."""
        # BUG FIX: the original guard read `split != self._split` with
        # `split` undefined in this scope, raising NameError on any call
        # made after the generator had been initialized.
        if not self._rand_initialized or self._split != 0:
            self.initRandomGenerator(0)  # split does not matter on the master

        if self._use_numpy:
            self._random.shuffle(vals)
        else:
            random.shuffle(vals)

    def func(self, split, iterator):
        """Generator yielding the sampled elements of one partition."""
        if self._withReplacement:
            for obj in iterator:
                # For large datasets, the expected number of occurrences of each element in a sample with
                # replacement is Poisson(frac). We use that to get a count for each element.
                count = self.getPoissonSample(split, mean=self._fraction)
                for _ in range(0, count):
                    yield obj
        else:
            for obj in iterator:
                if self.getUniformSample(split) <= self._fraction:
                    yield obj
|
||||
|
||||
|
||||
|
||||
|
Loading…
Reference in a new issue