|
29 | 29 | from tempfile import NamedTemporaryFile |
30 | 30 | from threading import Thread |
31 | 31 | import warnings |
32 | | -from heapq import heappush, heappop, heappushpop |
| 32 | +import heapq |
| 33 | +import bisect |
33 | 34 |
|
34 | 35 | from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ |
35 | 36 |     BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long |
|
41 | 42 |
|
42 | 43 | from py4j.java_collections import ListConverter, MapConverter |
43 | 44 |
|
44 | | - |
45 | 45 | __all__ = ["RDD"] |
46 | 46 |
|
| 47 | + |
47 | 48 | def _extract_concise_traceback(): |
48 | 49 | """ |
49 | 50 | This function returns the traceback info for a callsite, returns a dict |
@@ -91,6 +92,58 @@ def __exit__(self, type, value, tb): |
91 | 92 |         if _spark_stack_depth == 0: |
92 | 93 |             self._context._jsc.setCallSite(None) |
93 | 94 |
|
| 95 | +class MaxHeapQ(object): |
| 96 | +    """ |
| 97 | +    A max-heap, used here to keep the N smallest elements seen so far: |
| 98 | +    the root is the largest of them, and therefore the one to evict |
| 99 | +    when a smaller element arrives. |
| 100 | +    """ |
| 101 | + |
| 102 | +    def __init__(self): |
| 103 | +        # The heap is 1-indexed (q[0] is a sentinel): the children of |
| 104 | +        # node k sit at 2 * k and 2 * k + 1, and its parent at k // 2. |
| 105 | +        self.q = [0] |
| 106 | + |
| 107 | +    def _swim(self, k): |
| 108 | +        while k > 1 and self.q[k // 2] < self.q[k]: |
| 109 | +            self._swap(k, k // 2) |
| 110 | +            k = k // 2 |
| 111 | + |
| 112 | +    def _swap(self, i, j): |
| 113 | +        self.q[i], self.q[j] = self.q[j], self.q[i] |
| 114 | + |
| 115 | +    def _sink(self, k): |
| 116 | +        N = len(self.q) - 1 |
| 117 | +        while 2 * k <= N: |
| 118 | +            # Swap with the larger child until the parent is no smaller than it. |
| 119 | +            j = 2 * k |
| 120 | +            if j < N and self.q[j] < self.q[j + 1]: |
| 121 | +                j += 1 |
| 122 | +            if self.q[k] >= self.q[j]: |
| 123 | +                break |
| 124 | +            self._swap(k, j) |
| 125 | +            k = j |
| 126 | + |
| 127 | +    def insert(self, value): |
| 128 | +        self.q.append(value) |
| 129 | +        self._swim(len(self.q) - 1) |
| 130 | + |
| 131 | +    def getQ(self): |
| 132 | +        return self.q[1:] |
| 133 | + |
| 134 | +    def replaceRoot(self, value): |
| 135 | +        # Replace the current maximum only if the new value is smaller. |
| 136 | +        if self.q[1] > value: |
| 137 | +            self.q[1] = value |
| 138 | +            self._sink(1) |
| 139 | + |
| 140 | +    def delMax(self): |
| 141 | +        r = self.q[1] |
| 142 | +        self.q[1] = self.q[-1] |
| 143 | +        self.q.pop() |
| 144 | +        self._sink(1) |
| 145 | +        return r |
| 146 | + |
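A quick sketch of how this class is meant to be driven (illustrative only, not part of the commit): fill the heap up to the desired size N, then hand every further element to replaceRoot, which evicts the current maximum whenever something smaller arrives; what remains in getQ() are the N smallest elements seen.

    q = MaxHeapQ()
    N = 3
    for x in [10, 4, 2, 12, 3]:
        if len(q.q) - 1 < N:   # heap is 1-indexed, so its size is len(q.q) - 1
            q.insert(x)
        else:
            q.replaceRoot(x)
    print(sorted(q.getQ()))    # [2, 3, 4] -- the three smallest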
94 | 147 | class RDD(object): |
95 | 148 | """ |
96 | 149 | A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. |
@@ -696,23 +749,51 @@ def top(self, num): |
696 | 749 |         Note: It returns the list sorted in descending order. |
697 | 750 |         >>> sc.parallelize([10, 4, 2, 12, 3]).top(1) |
698 | 751 |         [12] |
699 | | -        >>> sc.parallelize([2, 3, 4, 5, 6]).cache().top(2) |
700 | | -        [6, 5] |
| 752 | +        >>> sc.parallelize([2, 3, 4, 5, 6], 2).cache().top(2) |
| 753 | +        [6, 5] |
701 | 754 |         """ |
702 | 755 |         def topIterator(iterator): |
703 | 756 |             q = [] |
704 | 757 |             for k in iterator: |
705 | 758 |                 if len(q) < num: |
706 | | -                    heappush(q, k) |
| 759 | +                    heapq.heappush(q, k) |
707 | 760 |                 else: |
708 | | -                    heappushpop(q, k) |
| 761 | +                    heapq.heappushpop(q, k) |
709 | 762 |             yield q |
710 | 763 |
|
711 | 764 |         def merge(a, b): |
712 | 765 |             return next(topIterator(a + b)) |
713 | 766 |
|
714 | 767 |         return sorted(self.mapPartitions(topIterator).reduce(merge), reverse=True) |
715 | 768 |
|
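The per-partition trick in top() is the mirror image of MaxHeapQ: heapq is a min-heap, so a size-num min-heap of the largest elements evicts its smallest member on each heappushpop. A standalone sketch of the same pattern (illustrative, outside the diff):

    import heapq

    def largest(num, items):
        q = []
        for k in items:
            if len(q) < num:
                heapq.heappush(q, k)
            else:
                heapq.heappushpop(q, k)   # keeps only the num largest seen
        return sorted(q, reverse=True)

    print(largest(2, [10, 4, 2, 12, 3]))  # [12, 10]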
| 769 | +    def takeOrdered(self, num, key=None): |
| 770 | +        """ |
| 771 | +        Get the N elements from an RDD ordered in ascending order, or as |
| 772 | +        specified by the optional key function. |
| 773 | + |
| 774 | +        >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7]).takeOrdered(6) |
| 775 | +        [1, 2, 3, 4, 5, 6] |
| 776 | +        >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7], 2).takeOrdered(6, key=lambda x: -x) |
| 777 | +        [10, 9, 7, 6, 5, 4] |
| 778 | +        """ |
| 779 | + |
| 780 | +        def topNKeyedElems(iterator, key_=None): |
| 781 | +            # Keep a bounded max-heap of the num smallest items seen so far; |
| 782 | +            # when a key function is given, items are (key(x), x) pairs. |
| 783 | +            q = MaxHeapQ() |
| 784 | +            for k in iterator: |
| 785 | +                if key_ is not None: |
| 786 | +                    k = (key_(k), k) |
| 787 | +                if len(q.q) - 1 < num: |
| 788 | +                    q.insert(k) |
| 789 | +                else: |
| 790 | +                    q.replaceRoot(k) |
| 791 | +            yield q.getQ() |
| 792 | + |
| 793 | +        def unKey(x, key_=None): |
| 794 | +            # Strip the decorating keys off the (key, value) pairs. |
| 795 | +            if key_ is not None: |
| 796 | +                x = [i[1] for i in x] |
| 797 | +            return x |
| 798 | + |
| 799 | +        def merge(a, b): |
| 800 | +            return next(topNKeyedElems(a + b)) |
| 801 | + |
| 802 | +        result = self.mapPartitions(lambda i: topNKeyedElems(i, key)).reduce(merge) |
| 803 | +        return sorted(unKey(result, key), key=key) |
| 804 | + |
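The (key(x), x) decoration in takeOrdered works because Python compares tuples lexicographically: ordering the pairs orders the elements by key, with the element itself breaking ties, and the second component lets the original value be recovered afterwards. Illustrative only:

    pairs = [(-x, x) for x in [10, 1, 9]]
    print(sorted(pairs))                   # [(-10, 10), (-9, 9), (-1, 1)]
    print([v for _, v in sorted(pairs)])   # [10, 9, 1]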
716 | 797 |     def take(self, num): |
717 | 798 |         """ |
718 | 799 |         Take the first num elements of the RDD. |
|