
Commit 49e6ba7

SPARK-1162 added takeOrdered to pyspark
1 parent ada310a commit 49e6ba7

1 file changed: +87 −6 lines changed

python/pyspark/rdd.py

Lines changed: 87 additions & 6 deletions
@@ -29,7 +29,8 @@
 from tempfile import NamedTemporaryFile
 from threading import Thread
 import warnings
-from heapq import heappush, heappop, heappushpop
+import heapq
+import bisect
 
 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
     BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long
@@ -41,9 +42,9 @@
 
 from py4j.java_collections import ListConverter, MapConverter
 
-
 __all__ = ["RDD"]
 
+
 def _extract_concise_traceback():
     """
     This function returns the traceback info for a callsite, returns a dict
@@ -91,6 +92,58 @@ def __exit__(self, type, value, tb):
         if _spark_stack_depth == 0:
             self._context._jsc.setCallSite(None)
 
+class MaxHeapQ(object):
+    """
+    An implementation of MaxHeap.
+
+    """
+
+    def __init__(self):
+        # we start from q[1], this makes calculating children as trivial as 2 * k
+        self.q = [0]
+
+    def _swim(self, k):
+        while (k > 1) and (self.q[k/2] < self.q[k]):
+            self._swap(k, k/2)
+            k = k/2
+
+    def _swap(self, i, j):
+        t = self.q[i]
+        self.q[i] = self.q[j]
+        self.q[j] = t
+
+    def _sink(self, k):
+        N = len(self.q) - 1
+        while 2 * k <= N:
+            j = 2 * k
+            # Here we test if both children are greater than parent
+            # if not swap with larger one.
+            if j < N and self.q[j] < self.q[j+1]:
+                j = j + 1
+            if(self.q[k] > self.q[j]):
+                break
+            self._swap(k, j)
+            k = j
+
+    def insert(self, value):
+        self.q.append(value)
+        self._swim(len(self.q) - 1)
+
+    def getQ(self):
+        return self.q[1:]
+
+    def replaceRoot(self, value):
+        if(self.q[1] > value):
+            self.q[1] = value
+            self._sink(1)
+
+    def delMax(self):
+        r = self.q[1]
+        self.q[1] = self.q[len(self.q) - 1]
+        self.q.pop()
+        self._sink(1)
+        return r
+
 class RDD(object):
     """
     A Resilient Distributed Dataset (RDD), the basic abstraction in Spark.
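Note on the hunk above: MaxHeapQ is a 1-indexed binary max-heap whose root q[1] is always the largest retained value, so keeping the num smallest elements of a partition only requires comparing each incoming element against the root and calling replaceRoot when the newcomer is smaller. A minimal standalone sketch of that same bounded "keep the N smallest seen so far" idea, written with the standard-library heapq module; the helper name keep_n_smallest and the negated-value trick (which assumes numeric inputs) are illustration only, not part of the commit:

import heapq

def keep_n_smallest(iterator, num):
    # heapq is a min-heap, so store negated values to simulate a max-heap.
    heap = []
    for x in iterator:
        if len(heap) < num:
            heapq.heappush(heap, -x)
        elif -heap[0] > x:
            # The largest retained value exceeds x: evict it and keep x,
            # which is the role MaxHeapQ.replaceRoot plays in the diff above.
            heapq.heapreplace(heap, -x)
    return sorted(-v for v in heap)

print(keep_n_smallest([10, 1, 2, 9, 3, 4, 5, 6, 7], 6))   # [1, 2, 3, 4, 5, 6]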
@@ -696,23 +749,51 @@ def top(self, num):
         Note: It returns the list sorted in descending order.
         >>> sc.parallelize([10, 4, 2, 12, 3]).top(1)
         [12]
-        >>> sc.parallelize([2, 3, 4, 5, 6]).cache().top(2)
-        [6, 5]
+        >>> sc.parallelize([2, 3, 4, 5, 6], 2).cache().top(2)
+        [5, 6]
         """
         def topIterator(iterator):
             q = []
             for k in iterator:
                 if len(q) < num:
-                    heappush(q, k)
+                    heapq.heappush(q, k)
                 else:
-                    heappushpop(q, k)
+                    heapq.heappushpop(q, k)
             yield q
 
         def merge(a, b):
             return next(topIterator(a + b))
 
         return sorted(self.mapPartitions(topIterator).reduce(merge), reverse=True)
 
+    def takeOrdered(self, num, key=None):
+        """
+        Get the N elements from a RDD ordered in ascending order or as specified
+        by the optional key function.
+
+        >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7]).takeOrdered(6)
+        [1, 2, 3, 4, 5, 6]
+        >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7], 2).takeOrdered(6, key=lambda x: -x)
+        [(-10, 10), (-9, 9), (-7, 7), (-6, 6), (-5, 5), (-4, 4)]
+        """
+
+        def topNKeyedElems(iterator, key_=None):
+            q = MaxHeapQ()
+            for k in iterator:
+                if not (key_ == None):
+                    k = (key_(k), k)
+                if (len(q.q) - 1) < num:
+                    q.insert(k)
+                else:
+                    q.replaceRoot(k)
+            yield q.getQ()
+
+        def merge(a, b):
+            return next(topNKeyedElems(a + b))
+
+        return sorted(self.mapPartitions(lambda i: topNKeyedElems(i, key)).reduce(merge))
+
+
     def take(self, num):
         """
         Take the first num elements of the RDD.
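For context, the new method is exercised the same way as the doctests added above. A minimal usage sketch, assuming a live SparkContext bound to sc as in PySpark's doctest environment:

rdd = sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7], 2)

rdd.takeOrdered(6)
# [1, 2, 3, 4, 5, 6]

# With a key function, this version of the method returns (key, element)
# pairs rather than bare elements, exactly as the doctest above records.
rdd.takeOrdered(6, key=lambda x: -x)
# [(-10, 10), (-9, 9), (-7, 7), (-6, 6), (-5, 5), (-4, 4)]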
