Skip to content

Commit 21ebd51

Browse files
Address Adrien's feedback
Add a field name for getMaxDimensions function
1 parent 3da9111 commit 21ebd51

File tree

7 files changed

+101
-16
lines changed

7 files changed

+101
-16
lines changed

lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsFormat.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,13 +80,15 @@ public static KnnVectorsFormat forName(String name) {
8080
public abstract KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException;
8181

8282
/**
83-
* Returns the maximum number of vector dimensions supported by this codec.
83+
* Returns the maximum number of vector dimensions supported by this codec for the given field
84+
* name
8485
*
8586
* <p>Codecs should override this method to specify the maximum number of dimensions they support.
8687
*
88+
* @param fieldName the field name
8789
* @return the maximum number of vector dimensions.
8890
*/
89-
public int getMaxDimensions() {
91+
public int getMaxDimensions(String fieldName) {
9092
return DEFAULT_MAX_DIMENSIONS;
9193
}
9294

lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsFormat.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException
189189
}
190190

191191
@Override
192-
public int getMaxDimensions() {
192+
public int getMaxDimensions(String fieldName) {
193193
return MAX_DIMENSIONS;
194194
}
195195

lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,11 @@ public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException
8080
return new FieldsReader(state);
8181
}
8282

83+
@Override
84+
public int getMaxDimensions(String fieldName) {
85+
return getKnnVectorsFormatForField(fieldName).getMaxDimensions(fieldName);
86+
}
87+
8388
/**
8489
* Returns the numeric vector format that should be used for writing new segments of <code>field
8590
* </code>.

lucene/core/src/java/org/apache/lucene/index/IndexingChain.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -625,7 +625,7 @@ private void initializeFieldInfo(PerField pf) throws IOException {
625625
validateMaxVectorDimension(
626626
pf.fieldName,
627627
s.vectorDimension,
628-
indexWriterConfig.getCodec().knnVectorsFormat().getMaxDimensions());
628+
indexWriterConfig.getCodec().knnVectorsFormat().getMaxDimensions(pf.fieldName));
629629
}
630630
FieldInfo fi =
631631
fieldInfos.add(

lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import org.apache.lucene.codecs.KnnVectorsFormat;
3030
import org.apache.lucene.codecs.KnnVectorsReader;
3131
import org.apache.lucene.codecs.KnnVectorsWriter;
32+
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat;
3233
import org.apache.lucene.document.Document;
3334
import org.apache.lucene.document.Field;
3435
import org.apache.lucene.document.KnnFloatVectorField;
@@ -43,6 +44,9 @@
4344
import org.apache.lucene.index.SegmentReadState;
4445
import org.apache.lucene.index.SegmentWriteState;
4546
import org.apache.lucene.index.Sorter;
47+
import org.apache.lucene.search.IndexSearcher;
48+
import org.apache.lucene.search.KnnFloatVectorQuery;
49+
import org.apache.lucene.search.Query;
4650
import org.apache.lucene.search.TopDocs;
4751
import org.apache.lucene.store.Directory;
4852
import org.apache.lucene.tests.analysis.MockAnalyzer;
@@ -162,6 +166,50 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
162166
}
163167
}
164168

169+
public void testMaxDimensionsPerFieldFormat() throws IOException {
170+
try (Directory directory = newDirectory()) {
171+
IndexWriterConfig iwc = newIndexWriterConfig(new MockAnalyzer(random()));
172+
KnnVectorsFormat format1 =
173+
new KnnVectorsFormatMaxDims32(new Lucene95HnswVectorsFormat(16, 100));
174+
KnnVectorsFormat format2 = new Lucene95HnswVectorsFormat(16, 100);
175+
iwc.setCodec(
176+
new AssertingCodec() {
177+
@Override
178+
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
179+
if ("field1".equals(field)) {
180+
return format1;
181+
} else {
182+
return format2;
183+
}
184+
}
185+
});
186+
try (IndexWriter writer = new IndexWriter(directory, iwc)) {
187+
Document doc1 = new Document();
188+
doc1.add(new KnnFloatVectorField("field1", new float[33]));
189+
Exception exc =
190+
expectThrows(IllegalArgumentException.class, () -> writer.addDocument(doc1));
191+
assertTrue(exc.getMessage().contains("vector's dimensions must be <= [32]"));
192+
193+
Document doc2 = new Document();
194+
doc2.add(new KnnFloatVectorField("field1", new float[32]));
195+
doc2.add(new KnnFloatVectorField("field2", new float[33]));
196+
writer.addDocument(doc2);
197+
}
198+
199+
// Check that the vectors were written
200+
try (IndexReader reader = DirectoryReader.open(directory)) {
201+
IndexSearcher searcher = new IndexSearcher(reader);
202+
Query query1 = new KnnFloatVectorQuery("field1", new float[32], 10);
203+
TopDocs topDocs1 = searcher.search(query1, 1);
204+
assertEquals(1, topDocs1.scoreDocs.length);
205+
206+
Query query2 = new KnnFloatVectorQuery("field2", new float[33], 10);
207+
TopDocs topDocs2 = searcher.search(query2, 1);
208+
assertEquals(1, topDocs2.scoreDocs.length);
209+
}
210+
}
211+
}
212+
165213
private static class WriteRecordingKnnVectorsFormat extends KnnVectorsFormat {
166214
private final KnnVectorsFormat delegate;
167215
private final Set<String> fieldsWritten;
@@ -216,4 +264,28 @@ public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException
216264
return delegate.fieldsReader(state);
217265
}
218266
}
267+
268+
private static class KnnVectorsFormatMaxDims32 extends KnnVectorsFormat {
269+
private final KnnVectorsFormat delegate;
270+
271+
public KnnVectorsFormatMaxDims32(KnnVectorsFormat delegate) {
272+
super(delegate.getName());
273+
this.delegate = delegate;
274+
}
275+
276+
@Override
277+
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
278+
return delegate.fieldsWriter(state);
279+
}
280+
281+
@Override
282+
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
283+
return delegate.fieldsReader(state);
284+
}
285+
286+
@Override
287+
public int getMaxDimensions(String fieldName) {
288+
return 32;
289+
}
290+
}
219291
}

lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseFieldInfoFormatTestCase.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ public void testRandom() throws Exception {
279279
var builder = INDEX_PACKAGE_ACCESS.newFieldInfosBuilder(softDeletesField);
280280

281281
for (String field : fieldNames) {
282-
IndexableFieldType fieldType = randomFieldType(random());
282+
IndexableFieldType fieldType = randomFieldType(random(), field);
283283
boolean storeTermVectors = false;
284284
boolean storePayloads = false;
285285
boolean omitNorms = false;
@@ -318,11 +318,11 @@ public void testRandom() throws Exception {
318318
dir.close();
319319
}
320320

321-
private int getVectorsMaxDimensions() {
322-
return Codec.getDefault().knnVectorsFormat().getMaxDimensions();
321+
private int getVectorsMaxDimensions(String fieldName) {
322+
return Codec.getDefault().knnVectorsFormat().getMaxDimensions(fieldName);
323323
}
324324

325-
private IndexableFieldType randomFieldType(Random r) {
325+
private IndexableFieldType randomFieldType(Random r, String fieldName) {
326326
FieldType type = new FieldType();
327327

328328
if (r.nextBoolean()) {
@@ -355,7 +355,7 @@ private IndexableFieldType randomFieldType(Random r) {
355355
}
356356

357357
if (r.nextBoolean()) {
358-
int dimension = 1 + r.nextInt(getVectorsMaxDimensions());
358+
int dimension = 1 + r.nextInt(getVectorsMaxDimensions(fieldName));
359359
VectorSimilarityFunction similarityFunction =
360360
RandomPicks.randomFrom(r, VectorSimilarityFunction.values());
361361
VectorEncoding encoding = RandomPicks.randomFrom(r, VectorEncoding.values());

lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,8 @@ protected void addRandomFields(Document doc) {
8585
}
8686
}
8787

88-
private int getVectorsMaxDimensions() {
89-
return Codec.getDefault().knnVectorsFormat().getMaxDimensions();
88+
private int getVectorsMaxDimensions(String fieldName) {
89+
return Codec.getDefault().knnVectorsFormat().getMaxDimensions(fieldName);
9090
}
9191

9292
public void testFieldConstructor() {
@@ -475,11 +475,13 @@ public void testIllegalDimensionTooLarge() throws Exception {
475475
Document doc = new Document();
476476
doc.add(
477477
new KnnFloatVectorField(
478-
"f", new float[getVectorsMaxDimensions() + 1], VectorSimilarityFunction.DOT_PRODUCT));
478+
"f",
479+
new float[getVectorsMaxDimensions("f") + 1],
480+
VectorSimilarityFunction.DOT_PRODUCT));
479481
Exception exc = expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc));
480482
assertTrue(
481483
exc.getMessage()
482-
.contains("vector's dimensions must be <= [" + getVectorsMaxDimensions() + "]"));
484+
.contains("vector's dimensions must be <= [" + getVectorsMaxDimensions("f") + "]"));
483485

484486
Document doc2 = new Document();
485487
doc2.add(new KnnFloatVectorField("f", new float[1], VectorSimilarityFunction.DOT_PRODUCT));
@@ -488,7 +490,9 @@ public void testIllegalDimensionTooLarge() throws Exception {
488490
Document doc3 = new Document();
489491
doc3.add(
490492
new KnnFloatVectorField(
491-
"f", new float[getVectorsMaxDimensions() + 1], VectorSimilarityFunction.DOT_PRODUCT));
493+
"f",
494+
new float[getVectorsMaxDimensions("f") + 1],
495+
VectorSimilarityFunction.DOT_PRODUCT));
492496
exc = expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc3));
493497
assertTrue(
494498
exc.getMessage()
@@ -498,11 +502,13 @@ public void testIllegalDimensionTooLarge() throws Exception {
498502
Document doc4 = new Document();
499503
doc4.add(
500504
new KnnFloatVectorField(
501-
"f", new float[getVectorsMaxDimensions() + 1], VectorSimilarityFunction.DOT_PRODUCT));
505+
"f",
506+
new float[getVectorsMaxDimensions("f") + 1],
507+
VectorSimilarityFunction.DOT_PRODUCT));
502508
exc = expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc4));
503509
assertTrue(
504510
exc.getMessage()
505-
.contains("vector's dimensions must be <= [" + getVectorsMaxDimensions() + "]"));
511+
.contains("vector's dimensions must be <= [" + getVectorsMaxDimensions("f") + "]"));
506512
}
507513
}
508514

0 commit comments

Comments
 (0)