[SPARK-2983] [PySpark] improve performance of sortByKey()
1. Skip partitionBy() when numPartitions is 1.
2. Use bisect_left (O(log N)) instead of a linear scan (O(N)) in rangePartitioner.

Author: Davies Liu <davies.liu@gmail.com>

Closes #1898 from davies/sort and squashes the following commits:

0a9608b [Davies Liu] Merge branch 'master' into sort
1cf9565 [Davies Liu] improve performance of sortByKey()
parent c974a716e1
commit 434bea1c00
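A note on the second change: rangePartitionFunc assigns a key to a partition by finding its position in a sorted list of boundary keys. Because that list is sorted, a binary search via bisect.bisect_left lands on exactly the index the old linear scan produced, in O(log N) comparisons instead of O(N). A minimal standalone sketch of the equivalence (illustrative only; the integer bounds and keys are not from the patch):

    import bisect

    bounds = [10, 20, 30]   # sorted boundary keys for the first N - 1 partitions

    def linear_lookup(key):
        # old approach: walk the boundaries until the key fits -- O(N)
        p = 0
        while p < len(bounds) and key > bounds[p]:
            p += 1
        return p

    def bisect_lookup(key):
        # new approach: binary search for the same index -- O(log N)
        return bisect.bisect_left(bounds, key)

    # both map every key to the same partition index
    assert all(linear_lookup(k) == bisect_lookup(k) for k in range(40))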
@@ -30,6 +30,7 @@ from tempfile import NamedTemporaryFile
 from threading import Thread
 import warnings
 import heapq
+import bisect
 from random import Random
 from math import sqrt, log
@@ -574,6 +575,8 @@ class RDD(object):
         # noqa

         >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)]
         >>> sc.parallelize(tmp).sortByKey(True, 1).collect()
         [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)]
+        >>> sc.parallelize(tmp).sortByKey(True, 2).collect()
+        [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)]
         >>> tmp2 = [('Mary', 1), ('had', 2), ('a', 3), ('little', 4), ('lamb', 5)]
@@ -584,42 +587,40 @@ class RDD(object):
         if numPartitions is None:
             numPartitions = self._defaultReducePartitions()

-        bounds = list()
+        if numPartitions == 1:
+            if self.getNumPartitions() > 1:
+                self = self.coalesce(1)
+
+            def sort(iterator):
+                return sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))
+
+            return self.mapPartitions(sort)

         # first compute the boundary of each part via sampling: we want to partition
         # the key-space into bins such that the bins have roughly the same
         # number of (key, value) pairs falling into them
-        if numPartitions > 1:
-            rddSize = self.count()
-            # constant from Spark's RangePartitioner
-            maxSampleSize = numPartitions * 20.0
-            fraction = min(maxSampleSize / max(rddSize, 1), 1.0)
-
-            samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect()
-            samples = sorted(samples, reverse=(not ascending), key=keyfunc)
-
-            # we have numPartitions many parts but one of the them has
-            # an implicit boundary
-            for i in range(0, numPartitions - 1):
-                index = (len(samples) - 1) * (i + 1) / numPartitions
-                bounds.append(samples[index])
+        rddSize = self.count()
+        maxSampleSize = numPartitions * 20.0  # constant from Spark's RangePartitioner
+        fraction = min(maxSampleSize / max(rddSize, 1), 1.0)
+
+        samples = self.sample(False, fraction, 1).map(
+            lambda (k, v): k).collect()
+        samples = sorted(samples, reverse=(not ascending), key=keyfunc)
+
+        # we have numPartitions many parts but one of the them has
+        # an implicit boundary
+        bounds = [samples[len(samples) * (i + 1) / numPartitions]
+                  for i in range(0, numPartitions - 1)]

         def rangePartitionFunc(k):
-            p = 0
-            while p < len(bounds) and keyfunc(k) > bounds[p]:
-                p += 1
+            p = bisect.bisect_left(bounds, keyfunc(k))
             if ascending:
                 return p
             else:
                 return numPartitions - 1 - p

         def mapFunc(iterator):
-            yield sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))
+            return sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))

-        return (self.partitionBy(numPartitions, partitionFunc=rangePartitionFunc)
-                    .mapPartitions(mapFunc, preservesPartitioning=True)
-                    .flatMap(lambda x: x, preservesPartitioning=True))
+        return self.partitionBy(numPartitions, rangePartitionFunc).mapPartitions(mapFunc, True)

     def sortBy(self, keyfunc, ascending=True, numPartitions=None):
         """
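Putting the pieces together, the patched multi-partition path is: sample the keys, take numPartitions - 1 boundary keys from the sorted sample, route every record with bisect_left, and sort each partition locally; since each partition then holds a contiguous key range, concatenating the partitions yields a globally sorted result. A self-contained simulation of that flow (a sketch, not the patch itself: simulate_sort_by_key is an illustrative helper, it is Python 3, so // replaces the Python 2 integer /, it "samples" every key, and it omits the keyfunc and descending plumbing of the real method):

    import bisect

    def simulate_sort_by_key(data, num_partitions):
        # use every key as the "sample" (the real code samples a fraction)
        samples = sorted(k for k, _ in data)

        # num_partitions - 1 boundary keys; the last partition's upper
        # bound stays implicit, as in the patch
        bounds = [samples[len(samples) * (i + 1) // num_partitions]
                  for i in range(num_partitions - 1)]

        # route each record to its partition with an O(log N) lookup
        partitions = [[] for _ in range(num_partitions)]
        for k, v in data:
            partitions[bisect.bisect_left(bounds, k)].append((k, v))

        # each partition covers a contiguous key range, so sorting the
        # partitions independently and concatenating gives a total order
        return [kv for part in partitions for kv in sorted(part)]

    print(simulate_sort_by_key([('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)], 2))
    # [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)]

The numPartitions == 1 shortcut needs none of this: coalesce(1) pulls everything into a single partition and one local sort finishes the job, skipping the shuffle that partitionBy() would otherwise trigger.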