-
Notifications
You must be signed in to change notification settings - Fork 4.4k
Expand file tree
/
Copy pathIDSelector.cpp
More file actions
130 lines (113 loc) · 3.48 KB
/
IDSelector.cpp
File metadata and controls
130 lines (113 loc) · 3.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/IDSelector.h>
namespace faiss {
/***********************************************************************
* IDSelectorRange
***********************************************************************/
IDSelectorRange::IDSelectorRange(
idx_t imin_in,
idx_t imax_in,
bool assume_sorted_in)
: imin(imin_in), imax(imax_in), assume_sorted(assume_sorted_in) {}
bool IDSelectorRange::is_member(idx_t id) const {
return id >= imin && id < imax;
}
void IDSelectorRange::find_sorted_ids_bounds(
size_t list_size,
const idx_t* ids,
size_t* jmin_out,
size_t* jmax_out) const {
FAISS_ASSERT(assume_sorted);
if (list_size == 0 || imax <= ids[0] || imin > ids[list_size - 1]) {
*jmin_out = *jmax_out = 0;
return;
}
// bisection to find imin
if (ids[0] >= imin) {
*jmin_out = 0;
} else {
size_t j0 = 0, j1 = list_size;
while (j1 > j0 + 1) {
size_t jmed = (j0 + j1) / 2;
if (ids[jmed] >= imin) {
j1 = jmed;
} else {
j0 = jmed;
}
}
*jmin_out = j1;
}
// bisection to find imax
if (*jmin_out == list_size || ids[*jmin_out] >= imax) {
*jmax_out = *jmin_out;
} else {
size_t j0 = *jmin_out, j1 = list_size;
while (j1 > j0 + 1) {
size_t jmed = (j0 + j1) / 2;
if (ids[jmed] >= imax) {
j1 = jmed;
} else {
j0 = jmed;
}
}
*jmax_out = j1;
}
}
/***********************************************************************
* IDSelectorArray
***********************************************************************/
IDSelectorArray::IDSelectorArray(size_t n_in, const idx_t* ids_in)
: n(n_in), ids(ids_in) {}
bool IDSelectorArray::is_member(idx_t id) const {
for (size_t i = 0; i < n; i++) {
if (ids[i] == id) {
return true;
}
}
return false;
}
/***********************************************************************
* IDSelectorBatch
***********************************************************************/
IDSelectorBatch::IDSelectorBatch(size_t n, const idx_t* indices) {
nbits = 0;
while (n > (size_t{1} << nbits)) {
nbits++;
}
nbits += 5;
// for n = 1M, nbits = 25 is optimal, see P56659518
mask = ((idx_t)1 << nbits) - 1;
bloom.resize(size_t{1} << (nbits - 3), 0);
for (size_t i = 0; i < n; i++) {
idx_t id = indices[i];
set.insert(id);
id &= mask;
bloom[id >> 3] |= 1 << (id & 7);
}
}
bool IDSelectorBatch::is_member(idx_t i) const {
idx_t im = i & mask;
if (!(bloom[im >> 3] & (1 << (im & 7)))) {
return false;
}
return set.count(i);
}
/***********************************************************************
* IDSelectorBitmap
***********************************************************************/
IDSelectorBitmap::IDSelectorBitmap(size_t n_in, const uint8_t* bitmap_in)
: n(n_in), bitmap(bitmap_in) {}
bool IDSelectorBitmap::is_member(idx_t ii) const {
uint64_t i = ii;
if ((i >> 3) >= n) {
return false;
}
return (bitmap[i >> 3] >> (i & 7)) & 1;
}
} // namespace faiss