Add back maxConn & beamWidth HNSW codec ctor (#12728)

follow up to https://github.com/apache/lucene/pull/12582

For user convenience, I added back the two parameter ctor for the HNSW codec.
This commit is contained in:
Benjamin Trent 2023-10-30 09:31:04 -04:00 committed by GitHub
parent 11436a848c
commit 2a8d187a99
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 31 additions and 21 deletions

View File

@ -191,6 +191,16 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
*
* @param maxConn the maximum number of connections to a node in the HNSW graph
* @param beamWidth the size of the queue maintained during graph construction.
*/
public Lucene99HnswVectorsFormat(int maxConn, int beamWidth) {
this(maxConn, beamWidth, null);
}
/**
* Constructs a format using the given graph construction parameters and scalar quantization.
*
* @param maxConn the maximum number of connections to a node in the HNSW graph
* @param beamWidth the size of the queue maintained during graph construction.
* @param scalarQuantize the scalar quantization format
* @param numMergeWorkers number of workers (threads) that will be used when doing merge. If
* larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec

View File

@ -234,7 +234,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
MergedQuantizedVectorValues byteVectorValues =
MergedQuantizedVectorValues.mergeQuantizedByteVectorValues(
fieldInfo, mergeState, mergedQuantizationState);
writeQuantizedVectorData(tempQuantizedVectorData, byteVectorValues);
DocsWithFieldSet docsWithField =
writeQuantizedVectorData(tempQuantizedVectorData, byteVectorValues);
CodecUtil.writeFooter(tempQuantizedVectorData);
IOUtils.close(tempQuantizedVectorData);
quantizationDataInput =
@ -254,7 +255,9 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
fieldInfo.getVectorSimilarityFunction(),
mergedQuantizationState,
new OffHeapQuantizedByteVectorValues.DenseOffHeapVectorValues(
fieldInfo.getVectorDimension(), byteVectorValues.size(), quantizationDataInput)));
fieldInfo.getVectorDimension(),
docsWithField.cardinality(),
quantizationDataInput)));
} finally {
if (success == false) {
IOUtils.closeWhileHandlingException(quantizationDataInput);

View File

@ -33,7 +33,7 @@ public class TestLucene99HnswVectorsFormat extends BaseKnnVectorsFormatTestCase
new FilterCodec("foo", Codec.getDefault()) {
@Override
public KnnVectorsFormat knnVectorsFormat() {
return new Lucene99HnswVectorsFormat(10, 20, null);
return new Lucene99HnswVectorsFormat(10, 20);
}
};
String expectedString =
@ -42,13 +42,11 @@ public class TestLucene99HnswVectorsFormat extends BaseKnnVectorsFormatTestCase
}
public void testLimits() {
expectThrows(IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(-1, 20, null));
expectThrows(IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(0, 20, null));
expectThrows(IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(20, 0, null));
expectThrows(IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(20, -1, null));
expectThrows(
IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(512 + 1, 20, null));
expectThrows(
IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(20, 3201, null));
expectThrows(IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(-1, 20));
expectThrows(IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(0, 20));
expectThrows(IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(20, 0));
expectThrows(IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(20, -1));
expectThrows(IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(512 + 1, 20));
expectThrows(IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(20, 3201));
}
}

View File

@ -170,8 +170,8 @@ public class TestPerFieldKnnVectorsFormat extends BaseKnnVectorsFormatTestCase {
try (Directory directory = newDirectory()) {
IndexWriterConfig iwc = newIndexWriterConfig(new MockAnalyzer(random()));
KnnVectorsFormat format1 =
new KnnVectorsFormatMaxDims32(new Lucene99HnswVectorsFormat(16, 100, null));
KnnVectorsFormat format2 = new Lucene99HnswVectorsFormat(16, 100, null);
new KnnVectorsFormatMaxDims32(new Lucene99HnswVectorsFormat(16, 100));
KnnVectorsFormat format2 = new Lucene99HnswVectorsFormat(16, 100);
iwc.setCodec(
new AssertingCodec() {
@Override

View File

@ -113,8 +113,7 @@ public class TestKnnGraph extends LuceneTestCase {
return new PerFieldKnnVectorsFormat() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new Lucene99HnswVectorsFormat(
M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH, null);
return new Lucene99HnswVectorsFormat(M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH);
}
};
}

View File

@ -165,7 +165,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
return new PerFieldKnnVectorsFormat() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new Lucene99HnswVectorsFormat(M, beamWidth, null);
return new Lucene99HnswVectorsFormat(M, beamWidth);
}
};
}
@ -237,7 +237,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
return new PerFieldKnnVectorsFormat() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new Lucene99HnswVectorsFormat(M, beamWidth, null);
return new Lucene99HnswVectorsFormat(M, beamWidth);
}
};
}
@ -298,7 +298,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
return new PerFieldKnnVectorsFormat() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new Lucene99HnswVectorsFormat(M, beamWidth, null);
return new Lucene99HnswVectorsFormat(M, beamWidth);
}
};
}
@ -312,7 +312,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
return new PerFieldKnnVectorsFormat() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new Lucene99HnswVectorsFormat(M, beamWidth, null);
return new Lucene99HnswVectorsFormat(M, beamWidth);
}
};
}

View File

@ -32,12 +32,12 @@ public class ConfigurableMCodec extends FilterCodec {
public ConfigurableMCodec() {
super("ConfigurableMCodec", TestUtil.getDefaultCodec());
knnVectorsFormat = new Lucene99HnswVectorsFormat(128, 100, null);
knnVectorsFormat = new Lucene99HnswVectorsFormat(128, 100);
}
public ConfigurableMCodec(int maxConn) {
super("ConfigurableMCodec", TestUtil.getDefaultCodec());
knnVectorsFormat = new Lucene99HnswVectorsFormat(maxConn, 100, null);
knnVectorsFormat = new Lucene99HnswVectorsFormat(maxConn, 100);
}
@Override