
Commit e8a08e2

Code review comments.
1 parent 49e6ba7 commit e8a08e2

File tree

1 file changed: +40 −29 lines changed

python/pyspark/rdd.py

Lines changed: 40 additions & 29 deletions
@@ -30,7 +30,6 @@
 from threading import Thread
 import warnings
 import heapq
-import bisect

 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
     BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long
@@ -95,55 +94,70 @@ def __exit__(self, type, value, tb):
 class MaxHeapQ(object):
     """
     An implementation of MaxHeap.
-
+    >>> import pyspark.rdd
+    >>> heap = pyspark.rdd.MaxHeapQ(5)
+    >>> [heap.insert(i) for i in range(10)]
+    [None, None, None, None, None, None, None, None, None, None]
+    >>> sorted(heap.getElements())
+    [0, 1, 2, 3, 4]
+    >>> heap = pyspark.rdd.MaxHeapQ(5)
+    >>> [heap.insert(i) for i in range(9, -1, -1)]
+    [None, None, None, None, None, None, None, None, None, None]
+    >>> sorted(heap.getElements())
+    [0, 1, 2, 3, 4]
+    >>> heap = pyspark.rdd.MaxHeapQ(1)
+    >>> [heap.insert(i) for i in range(9, -1, -1)]
+    [None, None, None, None, None, None, None, None, None, None]
+    >>> heap.getElements()
+    [0]
     """
-
-    def __init__(self):
+
+    def __init__(self, maxsize):
         # we start from q[1], this makes calculating children as trivial as 2 * k
         self.q = [0]
-
+        self.maxsize = maxsize
+
     def _swim(self, k):
         while (k > 1) and (self.q[k/2] < self.q[k]):
             self._swap(k, k/2)
             k = k/2
-
+
     def _swap(self, i, j):
         t = self.q[i]
         self.q[i] = self.q[j]
         self.q[j] = t

     def _sink(self, k):
-        N=len(self.q)-1
-        while 2*k <= N:
-            j = 2*k
+        N = self.size()
+        while 2 * k <= N:
+            j = 2 * k
             # Here we test if both children are greater than parent
             # if not swap with larger one.
-            if j<N and self.q[j] < self.q[j+1]:
-                j = j+1
+            if j < N and self.q[j] < self.q[j + 1]:
+                j = j + 1
             if(self.q[k] > self.q[j]):
                 break
             self._swap(k, j)
             k = j

+    def size(self):
+        return len(self.q) - 1
+
     def insert(self, value):
-        self.q.append(value)
-        self._swim(len(self.q) - 1)
+        if (self.size()) < self.maxsize:
+            self.q.append(value)
+            self._swim(self.size())
+        else:
+            self._replaceRoot(value)

-    def getQ(self):
+    def getElements(self):
         return self.q[1:]

-    def replaceRoot(self, value):
+    def _replaceRoot(self, value):
         if(self.q[1] > value):
             self.q[1] = value
             self._sink(1)

-    def delMax(self):
-        r = self.q[1]
-        self.q[1] = self.q[len(self.q) - 1]
-        self.q.pop()
-        self._sink(1)
-        return r
-
 class RDD(object):
     """
     A Resilient Distributed Dataset (RDD), the basic abstraction in Spark.
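For readers scanning the hunk above: MaxHeapQ stores its first element at q[1] so that the children of node k sit at indices 2*k and 2*k + 1, and the review makes the queue size-bounded, folding the old caller-side "is it full yet" check into insert() itself. Below is a minimal standalone sketch of the same keep-the-num-smallest idea, built on the standard-library heapq rather than the commit's hand-rolled heap; keep_smallest is an illustrative name and not part of this commit.

    # Sketch only: heapq is a min-heap, so storing negated values puts the
    # largest current candidate at heap[0], mirroring MaxHeapQ's max-at-root
    # invariant.
    import heapq

    def keep_smallest(iterable, maxsize):
        heap = []  # negated values of the current maxsize smallest items
        for v in iterable:
            if len(heap) < maxsize:
                heapq.heappush(heap, -v)
            elif -heap[0] > v:
                # Evict the current maximum and admit the smaller value,
                # the same move MaxHeapQ._replaceRoot performs.
                heapq.heapreplace(heap, -v)
        return sorted(-x for x in heap)

    assert keep_smallest(range(9, -1, -1), 5) == [0, 1, 2, 3, 4]
    assert keep_smallest(range(10), 5) == [0, 1, 2, 3, 4]

Negation is the usual trick for getting max-heap behavior out of heapq; the hand-rolled heap in the commit maintains the same invariant directly via _swim and _sink.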
@@ -778,15 +792,12 @@ def takeOrdered(self, num, key=None):
         """

         def topNKeyedElems(iterator, key_=None):
-            q = MaxHeapQ()
+            q = MaxHeapQ(num)
             for k in iterator:
-                if not (key_ == None):
+                if key_ != None:
                     k = (key_(k), k)
-                if (len(q.q) -1) < num:
-                    q.insert(k)
-                else:
-                    q.replaceRoot(k)
-            yield q.getQ()
+                q.insert(k)
+            yield q.getElements()

         def merge(a, b):
             return next(topNKeyedElems(a + b))
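Taken together with the MaxHeapQ change, each partition now simply streams every element into a size-capped heap and yields its contents, with no bookkeeping in the caller. A usage sketch, assuming a live SparkContext bound to sc as in the pyspark shell; takeOrdered returns the num smallest elements, ordered by key when one is given:

    >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7], 2).takeOrdered(6)
    [1, 2, 3, 4, 5, 6]
    >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7], 2).takeOrdered(6, key=lambda x: -x)
    [10, 9, 7, 6, 5, 4]

When key_ is given, each element is wrapped as (key_(k), k) before insertion, so tuple comparison orders the heap by key; a negating key turns "num smallest" into "num largest".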
