Skip to content

Commit 3c16374

Browse files
Jonathan EllisMichael Sokolov
authored andcommitted
Use HashMap (was TreeMap) for OnHeapHnswGraph neighbors
1 parent 1fa2be9 commit 3c16374

File tree

7 files changed

+146
-64
lines changed

7 files changed

+146
-64
lines changed

lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import java.util.Arrays;
2626
import org.apache.lucene.codecs.BufferingKnnVectorsWriter;
2727
import org.apache.lucene.codecs.CodecUtil;
28+
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsWriter;
2829
import org.apache.lucene.index.ByteVectorValues;
2930
import org.apache.lucene.index.DocsWithFieldSet;
3031
import org.apache.lucene.index.FieldInfo;
@@ -36,7 +37,6 @@
3637
import org.apache.lucene.store.IndexInput;
3738
import org.apache.lucene.store.IndexOutput;
3839
import org.apache.lucene.util.IOUtils;
39-
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
4040
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
4141

4242
/**
@@ -227,11 +227,10 @@ private void writeMeta(
227227
} else {
228228
meta.writeInt(graph.numLevels());
229229
for (int level = 0; level < graph.numLevels(); level++) {
230-
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
231-
meta.writeInt(nodesOnLevel.size()); // number of nodes on a level
230+
int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level));
231+
meta.writeInt(sortedNodes.length); // number of nodes on a level
232232
if (level > 0) {
233-
while (nodesOnLevel.hasNext()) {
234-
int node = nodesOnLevel.nextInt();
233+
for (int node : sortedNodes) {
235234
meta.writeInt(node); // list of nodes on a level
236235
}
237236
}
@@ -257,9 +256,8 @@ private Lucene91OnHeapHnswGraph writeGraph(
257256
// write vectors' neighbours on each level into the vectorIndex file
258257
int countOnLevel0 = graph.size();
259258
for (int level = 0; level < graph.numLevels(); level++) {
260-
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
261-
while (nodesOnLevel.hasNext()) {
262-
int node = nodesOnLevel.nextInt();
259+
int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level));
260+
for (int node : sortedNodes) {
263261
Lucene91NeighborArray neighbors = graph.getNeighbors(level, node);
264262
int size = neighbors.size();
265263
vectorIndex.writeInt(size);

lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.apache.lucene.codecs.BufferingKnnVectorsWriter;
2828
import org.apache.lucene.codecs.CodecUtil;
2929
import org.apache.lucene.codecs.lucene90.IndexedDISI;
30+
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsWriter;
3031
import org.apache.lucene.index.ByteVectorValues;
3132
import org.apache.lucene.index.DocsWithFieldSet;
3233
import org.apache.lucene.index.FieldInfo;
@@ -39,7 +40,6 @@
3940
import org.apache.lucene.store.IndexInput;
4041
import org.apache.lucene.store.IndexOutput;
4142
import org.apache.lucene.util.IOUtils;
42-
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
4343
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
4444
import org.apache.lucene.util.hnsw.NeighborArray;
4545
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
@@ -261,11 +261,10 @@ private void writeMeta(
261261
} else {
262262
meta.writeInt(graph.numLevels());
263263
for (int level = 0; level < graph.numLevels(); level++) {
264-
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
265-
meta.writeInt(nodesOnLevel.size()); // number of nodes on a level
264+
int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level));
265+
meta.writeInt(sortedNodes.length); // number of nodes on a level
266266
if (level > 0) {
267-
while (nodesOnLevel.hasNext()) {
268-
int node = nodesOnLevel.nextInt();
267+
for (int node : sortedNodes) {
269268
meta.writeInt(node); // list of nodes on a level
270269
}
271270
}
@@ -293,9 +292,8 @@ private OnHeapHnswGraph writeGraph(
293292
int countOnLevel0 = graph.size();
294293
for (int level = 0; level < graph.numLevels(); level++) {
295294
int maxConnOnLevel = level == 0 ? (M * 2) : M;
296-
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
297-
while (nodesOnLevel.hasNext()) {
298-
int node = nodesOnLevel.nextInt();
295+
int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level));
296+
for (int node : sortedNodes) {
299297
NeighborArray neighbors = graph.getNeighbors(level, node);
300298
int size = neighbors.size();
301299
vectorIndex.writeInt(size);

lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
3131
import org.apache.lucene.codecs.KnnVectorsWriter;
3232
import org.apache.lucene.codecs.lucene90.IndexedDISI;
33+
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsWriter;
3334
import org.apache.lucene.index.ByteVectorValues;
3435
import org.apache.lucene.index.DocsWithFieldSet;
3536
import org.apache.lucene.index.FieldInfo;
@@ -303,9 +304,8 @@ private HnswGraph reconstructAndWriteGraph(
303304
for (int level = 1; level < graph.numLevels(); level++) {
304305
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
305306
int[] newNodes = new int[nodesOnLevel.size()];
306-
int n = 0;
307-
while (nodesOnLevel.hasNext()) {
308-
newNodes[n++] = oldToNewMap[nodesOnLevel.nextInt()];
307+
for (int n = 0; nodesOnLevel.hasNext(); n++) {
308+
newNodes[n] = oldToNewMap[nodesOnLevel.nextInt()];
309309
}
310310
Arrays.sort(newNodes);
311311
nodesByLevel.add(newNodes);
@@ -481,9 +481,8 @@ private void writeGraph(OnHeapHnswGraph graph) throws IOException {
481481
int countOnLevel0 = graph.size();
482482
for (int level = 0; level < graph.numLevels(); level++) {
483483
int maxConnOnLevel = level == 0 ? (M * 2) : M;
484-
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
485-
while (nodesOnLevel.hasNext()) {
486-
int node = nodesOnLevel.nextInt();
484+
int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level));
485+
for (int node : sortedNodes) {
487486
NeighborArray neighbors = graph.getNeighbors(level, node);
488487
int size = neighbors.size();
489488
vectorIndex.writeInt(size);
@@ -570,11 +569,10 @@ private void writeMeta(
570569
} else {
571570
meta.writeInt(graph.numLevels());
572571
for (int level = 0; level < graph.numLevels(); level++) {
573-
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
574-
meta.writeInt(nodesOnLevel.size()); // number of nodes on a level
572+
int[] sortedNodes = Lucene95HnswVectorsWriter.getSortedNodes(graph.getNodesOnLevel(level));
573+
meta.writeInt(sortedNodes.length); // number of nodes on a level
575574
if (level > 0) {
576-
while (nodesOnLevel.hasNext()) {
577-
int node = nodesOnLevel.nextInt();
575+
for (int node : sortedNodes) {
578576
meta.writeInt(node); // list of nodes on a level
579577
}
580578
}

lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsWriter.java

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -315,9 +315,8 @@ private HnswGraph reconstructAndWriteGraph(
315315
for (int level = 1; level < graph.numLevels(); level++) {
316316
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
317317
int[] newNodes = new int[nodesOnLevel.size()];
318-
int n = 0;
319-
while (nodesOnLevel.hasNext()) {
320-
newNodes[n++] = oldToNewMap[nodesOnLevel.nextInt()];
318+
for (int n = 0; nodesOnLevel.hasNext(); n++) {
319+
newNodes[n] = oldToNewMap[nodesOnLevel.nextInt()];
321320
}
322321
Arrays.sort(newNodes);
323322
nodesByLevel.add(newNodes);
@@ -677,11 +676,10 @@ private int[][] writeGraph(OnHeapHnswGraph graph) throws IOException {
677676
int countOnLevel0 = graph.size();
678677
int[][] offsets = new int[graph.numLevels()][];
679678
for (int level = 0; level < graph.numLevels(); level++) {
680-
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
681-
offsets[level] = new int[nodesOnLevel.size()];
679+
int[] sortedNodes = getSortedNodes(graph.getNodesOnLevel(level));
680+
offsets[level] = new int[sortedNodes.length];
682681
int nodeOffsetId = 0;
683-
while (nodesOnLevel.hasNext()) {
684-
int node = nodesOnLevel.nextInt();
682+
for (int node : sortedNodes) {
685683
NeighborArray neighbors = graph.getNeighbors(level, node);
686684
int size = neighbors.size();
687685
// Write size in VInt as the neighbors list is typically small
@@ -706,6 +704,15 @@ private int[][] writeGraph(OnHeapHnswGraph graph) throws IOException {
706704
return offsets;
707705
}
708706

707+
public static int[] getSortedNodes(NodesIterator nodesOnLevel) {
708+
int[] sortedNodes = new int[nodesOnLevel.size()];
709+
for (int n = 0; nodesOnLevel.hasNext(); n++) {
710+
sortedNodes[n] = nodesOnLevel.nextInt();
711+
}
712+
Arrays.sort(sortedNodes);
713+
return sortedNodes;
714+
}
715+
709716
private void writeMeta(
710717
FieldInfo field,
711718
int maxDoc,
@@ -779,6 +786,7 @@ private void writeMeta(
779786
if (level > 0) {
780787
int[] nol = new int[nodesOnLevel.size()];
781788
int numberConsumed = nodesOnLevel.consume(nol);
789+
Arrays.sort(nol);
782790
assert numberConsumed == nodesOnLevel.size();
783791
meta.writeVInt(nol.length); // number of nodes on a level
784792
for (int i = nodesOnLevel.size() - 1; i > 0; --i) {

lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@ protected HnswGraph() {}
8181
public abstract int entryNode() throws IOException;
8282

8383
/**
84-
* Get all nodes on a given level as node 0th ordinals
84+
* Get all nodes on a given level as node 0th ordinals. The nodes are NOT guaranteed to be
85+
* presented in any particular order.
8586
*
8687
* @param level level for which to get all nodes
8788
* @return an iterator over nodes where {@code nextInt} returns a next node on the level
@@ -123,7 +124,8 @@ public NodesIterator getNodesOnLevel(int level) {
123124

124125
/**
125126
* Iterator over the graph nodes on a certain level, Iterator also provides the size – the total
126-
* number of nodes to be iterated over.
127+
* number of nodes to be iterated over. The nodes are NOT guaranteed to be presented in any
128+
* particular order.
127129
*/
128130
public abstract static class NodesIterator implements PrimitiveIterator.OfInt {
129131
protected final int size;

lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@
2020
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
2121

2222
import java.util.ArrayList;
23+
import java.util.HashMap;
2324
import java.util.List;
24-
import java.util.TreeMap;
25+
import java.util.Map;
2526
import org.apache.lucene.util.Accountable;
2627
import org.apache.lucene.util.RamUsageEstimator;
2728

@@ -40,12 +41,12 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
4041
// added to HnswBuilder, and the node values are the ordinals of those vectors.
4142
// Thus, on all levels, neighbors expressed as the level 0's nodes' ordinals.
4243
private final List<NeighborArray> graphLevel0;
43-
// Represents levels 1-N. Each level is represented with a TreeMap that maps a levels level 0
44+
// Represents levels 1-N. Each level is represented with a Map that maps a levels level 0
4445
// ordinal to its neighbors on that level. All nodes are in level 0, so we do not need to maintain
4546
// it in this list. However, to avoid changing list indexing, we always will make the first
4647
// element
4748
// null.
48-
private final List<TreeMap<Integer, NeighborArray>> graphUpperLevels;
49+
private final List<Map<Integer, NeighborArray>> graphUpperLevels;
4950
private final int nsize;
5051
private final int nsize0;
5152

@@ -76,7 +77,7 @@ public NeighborArray getNeighbors(int level, int node) {
7677
if (level == 0) {
7778
return graphLevel0.get(node);
7879
}
79-
TreeMap<Integer, NeighborArray> levelMap = graphUpperLevels.get(level);
80+
Map<Integer, NeighborArray> levelMap = graphUpperLevels.get(level);
8081
assert levelMap.containsKey(node);
8182
return levelMap.get(node);
8283
}
@@ -103,7 +104,7 @@ public void addNode(int level, int node) {
103104
// and make this node the graph's new entry point
104105
if (level >= numLevels) {
105106
for (int i = numLevels; i <= level; i++) {
106-
graphUpperLevels.add(new TreeMap<>());
107+
graphUpperLevels.add(new HashMap<>());
107108
}
108109
numLevels = level + 1;
109110
entryNode = node;
@@ -204,4 +205,15 @@ public long ramBytesUsed() {
204205
}
205206
return total;
206207
}
208+
209+
@Override
210+
public String toString() {
211+
return "OnHeapHnswGraph(size="
212+
+ size()
213+
+ ", numLevels="
214+
+ numLevels
215+
+ ", entryNode="
216+
+ entryNode
217+
+ ")";
218+
}
207219
}

0 commit comments

Comments
 (0)