30 | 30 | from threading import Thread |
31 | 31 | import warnings |
32 | 32 | import heapq |
33 | | -import bisect |
34 | 33 | |
35 | 34 | from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ |
36 | 35 | BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long |
@@ -95,55 +94,70 @@ def __exit__(self, type, value, tb): |
95 | 94 | class MaxHeapQ(object): |
96 | 95 | """ |
97 | 96 | An implementation of a bounded max-heap that keeps the maxsize smallest elements inserted. |
98 | | - |
| 97 | + >>> import pyspark.rdd |
| 98 | + >>> heap = pyspark.rdd.MaxHeapQ(5) |
| 99 | + >>> [heap.insert(i) for i in range(10)] |
| 100 | + [None, None, None, None, None, None, None, None, None, None] |
| 101 | + >>> sorted(heap.getElements()) |
| 102 | + [0, 1, 2, 3, 4] |
| 103 | + >>> heap = pyspark.rdd.MaxHeapQ(5) |
| 104 | + >>> [heap.insert(i) for i in range(9, -1, -1)] |
| 105 | + [None, None, None, None, None, None, None, None, None, None] |
| 106 | + >>> sorted(heap.getElements()) |
| 107 | + [0, 1, 2, 3, 4] |
| 108 | + >>> heap = pyspark.rdd.MaxHeapQ(1) |
| 109 | + >>> [heap.insert(i) for i in range(9, -1, -1)] |
| 110 | + [None, None, None, None, None, None, None, None, None, None] |
| 111 | + >>> heap.getElements() |
| 112 | + [0] |
99 | 113 | """ |
100 | | - |
101 | | - def __init__(self): |
| 114 | + |
| 115 | + def __init__(self, maxsize): |
102 | 116 | # We start from q[1]; this makes the children of node k simply 2 * k and 2 * k + 1. |
103 | 117 | self.q = [0] |
104 | | - |
| 118 | + self.maxsize = maxsize |
| 119 | + |
105 | 120 | def _swim(self, k): |
106 | 121 | while (k > 1) and (self.q[k / 2] < self.q[k]): |
107 | 122 | self._swap(k, k / 2) |
108 | 123 | k = k / 2 |
109 | | - |
| 124 | + |
110 | 125 | def _swap(self, i, j): |
111 | 126 | t = self.q[i] |
112 | 127 | self.q[i] = self.q[j] |
113 | 128 | self.q[j] = t |
114 | 129 | |
115 | 130 | def _sink(self, k): |
116 | | - N=len(self.q)-1 |
117 | | - while 2*k <= N: |
118 | | - j = 2*k |
| 131 | + N = self.size() |
| 132 | + while 2 * k <= N: |
| 133 | + j = 2 * k |
119 | 134 | # Pick the larger of the two children; if the parent already |
120 | 135 | # exceeds it, the heap property holds and we can stop sinking. |
121 | | - if j<N and self.q[j] < self.q[j+1]: |
122 | | - j = j+1 |
| 136 | + if j < N and self.q[j] < self.q[j + 1]: |
| 137 | + j = j + 1 |
123 | 138 | if self.q[k] > self.q[j]: |
124 | 139 | break |
125 | 140 | self._swap(k, j) |
126 | 141 | k = j |
127 | 142 | |
| 143 | + def size(self): |
| 144 | + return len(self.q) - 1 |
| 145 | + |
128 | 146 | def insert(self, value): |
129 | | - self.q.append(value) |
130 | | - self._swim(len(self.q) - 1) |
| 147 | + if self.size() < self.maxsize: |
| 148 | + self.q.append(value) |
| 149 | + self._swim(self.size()) |
| 150 | + else: |
| 151 | + self._replaceRoot(value) |
131 | 152 | |
132 | | - def getQ(self): |
| 153 | + def getElements(self): |
133 | 154 | return self.q[1:] |
134 | 155 | |
135 | | - def replaceRoot(self, value): |
| 156 | + def _replaceRoot(self, value): |
136 | 157 | if self.q[1] > value: |
137 | 158 | self.q[1] = value |
138 | 159 | self._sink(1) |
139 | 160 | |
140 | | - def delMax(self): |
141 | | - r = self.q[1] |
142 | | - self.q[1] = self.q[len(self.q) - 1] |
143 | | - self.q.pop() |
144 | | - self._sink(1) |
145 | | - return r |
146 | | - |
147 | 161 | class RDD(object): |
148 | 162 | """ |
149 | 163 | A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. |
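The heart of this hunk is the bounded heap: insert accepts values directly only while fewer than maxsize elements are held, and otherwise routes through _replaceRoot, which admits a value only if it is smaller than the current root (the largest element retained). The heap therefore always holds the maxsize smallest values seen so far, which is exactly what the new doctests exercise. A minimal standalone sketch of the same idea using only the standard library's heapq (a min-heap, so values are negated to emulate max-heap ordering); the name n_smallest is illustrative and not part of the patch:

    import heapq

    def n_smallest(iterator, n):
        # Keep the n smallest values seen so far in a max-heap,
        # emulated by pushing negated values onto Python's min-heap.
        heap = []
        for value in iterator:
            if len(heap) < n:
                heapq.heappush(heap, -value)
            elif value < -heap[0]:
                # The new value beats the largest retained one: evict the root.
                heapq.heapreplace(heap, -value)
        return sorted(-v for v in heap)

    print(n_smallest(range(9, -1, -1), 5))  # prints [0, 1, 2, 3, 4]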
@@ -778,15 +792,12 @@ def takeOrdered(self, num, key=None): |
778 | 792 | """ |
779 | 793 | |
780 | 794 | def topNKeyedElems(iterator, key_=None): |
781 | | - q = MaxHeapQ() |
| 795 | + q = MaxHeapQ(num) |
782 | 796 | for k in iterator: |
783 | | - if not (key_ == None): |
| 797 | + if key_ is not None: |
784 | 798 | k = (key_(k), k) |
785 | | - if (len(q.q) -1) < num: |
786 | | - q.insert(k) |
787 | | - else: |
788 | | - q.replaceRoot(k) |
789 | | - yield q.getQ() |
| 799 | + q.insert(k) |
| 800 | + yield q.getElements() |
790 | 801 | |
791 | 802 | def merge(a, b): |
792 | 803 | return next(topNKeyedElems(a + b)) |
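Each partition generator yields the contents of its own bounded heap, and merge simply re-runs topNKeyedElems over the concatenation of two partial results, so folding the per-partition lists together leaves the num smallest elements overall on the driver. A sketch of the equivalent semantics using the standard library's heapq.nsmallest; the partition contents below are made up purely for illustration:

    import heapq
    from functools import reduce

    num = 6
    partitions = [[10, 1, 2], [9, 3, 4], [5, 6, 7]]  # hypothetical partition data

    # Per-partition pass: keep only the num smallest elements of each partition.
    partial = [heapq.nsmallest(num, part) for part in partitions]

    # Driver-side fold, mirroring merge(a, b) = topNKeyedElems(a + b).
    result = reduce(lambda a, b: heapq.nsmallest(num, a + b), partial)
    print(sorted(result))  # prints [1, 2, 3, 4, 5, 6]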