[SPARK-2983] [PySpark] improve performance of sortByKey()

1. skip partitionBy() when numPartitions is 1
2. use bisect_left (O(log N)) instead of a linear scan (O(N)) in
   rangePartitionFunc

Author: Davies Liu <davies.liu@gmail.com>

Closes #1898 from davies/sort and squashes the following commits:

0a9608b [Davies Liu] Merge branch 'master' into sort
1cf9565 [Davies Liu] improve performance of sortByKey()
Davies Liu authored 2014-08-13 14:57:12 -07:00; committed by Matei Zaharia
parent c974a716e1
commit 434bea1c00
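
To illustrate the second change: for sorted boundaries, bisect.bisect_left finds the same partition index as the old linear scan, but in O(log N) comparisons instead of O(N). A minimal, Spark-free sketch; the bounds values and the identity keyfunc are made up for illustration, and the boundaries are assumed to be in ascending order:

    import bisect

    bounds = ['d', 'm', 't']  # hypothetical sorted boundaries for 4 partitions
    keyfunc = lambda k: k     # hypothetical identity key function

    def partition_linear(k):
        # old approach: O(N) walk over the boundaries
        p = 0
        while p < len(bounds) and keyfunc(k) > bounds[p]:
            p += 1
        return p

    def partition_bisect(k):
        # new approach: O(log N) binary search for the same index
        return bisect.bisect_left(bounds, keyfunc(k))

    for k in ['a', 'd', 'e', 'm', 'z']:
        assert partition_linear(k) == partition_bisect(k)

Both functions return the index of the first boundary that is >= the key, which is exactly the partition the key belongs to.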


@@ -30,6 +30,7 @@ from tempfile import NamedTemporaryFile
 from threading import Thread
 import warnings
 import heapq
+import bisect
 from random import Random
 from math import sqrt, log
@@ -574,6 +575,8 @@ class RDD(object):
         # noqa
 
         >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)]
+        >>> sc.parallelize(tmp).sortByKey(True, 1).collect()
+        [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)]
         >>> sc.parallelize(tmp).sortByKey(True, 2).collect()
         [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)]
         >>> tmp2 = [('Mary', 1), ('had', 2), ('a', 3), ('little', 4), ('lamb', 5)]
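
The doctest added above exercises the new single-partition fast path introduced in the next hunk. Stripped of Spark, that path is just one local sort; a rough sketch (the function name and arguments are illustrative, not part of the patch):

    def sort_single_partition(records, ascending=True, keyfunc=lambda k: k):
        # with one output partition there is nothing to range-partition:
        # coalesce to a single partition and sort it in one pass
        return sorted(records, key=lambda kv: keyfunc(kv[0]),
                      reverse=not ascending)

    print(sort_single_partition([('b', 2), ('a', 1), ('1', 3)]))
    # [('1', 3), ('a', 1), ('b', 2)]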
@@ -584,42 +587,40 @@ class RDD(object):
         if numPartitions is None:
             numPartitions = self._defaultReducePartitions()
 
-        bounds = list()
+        if numPartitions == 1:
+            if self.getNumPartitions() > 1:
+                self = self.coalesce(1)
+
+            def sort(iterator):
+                return sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))
+
+            return self.mapPartitions(sort)
 
         # first compute the boundary of each part via sampling: we want to partition
         # the key-space into bins such that the bins have roughly the same
         # number of (key, value) pairs falling into them
-        if numPartitions > 1:
-            rddSize = self.count()
-            # constant from Spark's RangePartitioner
-            maxSampleSize = numPartitions * 20.0
-            fraction = min(maxSampleSize / max(rddSize, 1), 1.0)
+        rddSize = self.count()
+        maxSampleSize = numPartitions * 20.0  # constant from Spark's RangePartitioner
+        fraction = min(maxSampleSize / max(rddSize, 1), 1.0)
 
-            samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect()
-            samples = sorted(samples, reverse=(not ascending), key=keyfunc)
+        samples = self.sample(False, fraction, 1).map(
+            lambda (k, v): k).collect()
+        samples = sorted(samples, reverse=(not ascending), key=keyfunc)
 
-            # we have numPartitions many parts but one of them has
-            # an implicit boundary
-            for i in range(0, numPartitions - 1):
-                index = (len(samples) - 1) * (i + 1) / numPartitions
-                bounds.append(samples[index])
+        # we have numPartitions many parts but one of them has
+        # an implicit boundary
+        bounds = [samples[len(samples) * (i + 1) / numPartitions]
+                  for i in range(0, numPartitions - 1)]
 
         def rangePartitionFunc(k):
-            p = 0
-            while p < len(bounds) and keyfunc(k) > bounds[p]:
-                p += 1
+            p = bisect.bisect_left(bounds, keyfunc(k))
             if ascending:
                 return p
             else:
                 return numPartitions - 1 - p
 
         def mapFunc(iterator):
-            yield sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))
+            return sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))
 
-        return (self.partitionBy(numPartitions, partitionFunc=rangePartitionFunc)
-                .mapPartitions(mapFunc, preservesPartitioning=True)
-                .flatMap(lambda x: x, preservesPartitioning=True))
+        return self.partitionBy(numPartitions, rangePartitionFunc).mapPartitions(mapFunc, True)
 
     def sortBy(self, keyfunc, ascending=True, numPartitions=None):
         """