mirror of https://github.com/apache/lucene.git
LUCENE-9905: Move HNSW build parameters to codec (#166)
Previously, the max connections and beam width parameters could be configured as field type attributes. This PR moves them to be parameters on Lucene90HnswVectorFormat, to avoid exposing details of the vector format implementation in the API.
This commit is contained in:
parent dbb4c265d5
commit 05ae738fc9
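For context, here is a minimal sketch of how an application now supplies non-default HNSW build parameters through the codec instead of through field type attributes, mirroring the pattern the tests in this diff use. The class name, field name, vector values, and the maxConn/beamWidth settings are illustrative only, not part of the commit.

import org.apache.lucene.codecs.VectorFormat;
import org.apache.lucene.codecs.lucene90.Lucene90Codec;
import org.apache.lucene.codecs.lucene90.Lucene90HnswVectorFormat;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.VectorField;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.store.ByteBuffersDirectory;
import org.apache.lucene.store.Directory;

public class HnswCodecConfigExample {
  public static void main(String[] args) throws Exception {
    // Illustrative values; after this change they are codec-level settings,
    // not field type attributes.
    int maxConn = 32;
    int beamWidth = 64;

    // Pass the HNSW build parameters to the vector format via the codec.
    IndexWriterConfig iwc = new IndexWriterConfig();
    iwc.setCodec(
        new Lucene90Codec() {
          @Override
          public VectorFormat getVectorFormatForField(String field) {
            return new Lucene90HnswVectorFormat(maxConn, beamWidth);
          }
        });

    // The field type now only records dimension and similarity.
    FieldType type = VectorField.createFieldType(3, VectorValues.SimilarityFunction.EUCLIDEAN);

    try (Directory dir = new ByteBuffersDirectory();
        IndexWriter iw = new IndexWriter(dir, iwc)) {
      Document doc = new Document();
      doc.add(new VectorField("vector", new float[] {1f, 2f, 3f}, type));
      iw.addDocument(doc);
    }
  }
}

Because getVectorFormatForField is keyed by field name, different fields can be given differently configured formats from the same codec.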
Lucene90HnswVectorFormat.java:

@@ -23,6 +23,7 @@ import org.apache.lucene.codecs.VectorReader;
 import org.apache.lucene.codecs.VectorWriter;
 import org.apache.lucene.index.SegmentReadState;
 import org.apache.lucene.index.SegmentWriteState;
+import org.apache.lucene.util.hnsw.HnswGraph;
 
 /**
  * Lucene 9.0 vector format, which encodes numeric vector values and an optional associated graph
@@ -76,14 +77,37 @@ public final class Lucene90HnswVectorFormat extends VectorFormat {
   static final int VERSION_START = 0;
   static final int VERSION_CURRENT = VERSION_START;
 
-  /** Sole constructor */
+  public static final int DEFAULT_MAX_CONN = 16;
+  public static final int DEFAULT_BEAM_WIDTH = 16;
+
+  /**
+   * Controls how many of the nearest neighbor candidates are connected to the new node. Defaults to
+   * {@link Lucene90HnswVectorFormat#DEFAULT_MAX_CONN}. See {@link HnswGraph} for more details.
+   */
+  private final int maxConn;
+
+  /**
+   * The number of candidate neighbors to track while searching the graph for each newly inserted
+   * node. Defaults to {@link Lucene90HnswVectorFormat#DEFAULT_BEAM_WIDTH}. See {@link HnswGraph}
+   * for details.
+   */
+  private final int beamWidth;
+
   public Lucene90HnswVectorFormat() {
     super("Lucene90HnswVectorFormat");
+    this.maxConn = DEFAULT_MAX_CONN;
+    this.beamWidth = DEFAULT_BEAM_WIDTH;
   }
 
+  public Lucene90HnswVectorFormat(int maxConn, int beamWidth) {
+    super("Lucene90HnswVectorFormat");
+    this.maxConn = maxConn;
+    this.beamWidth = beamWidth;
+  }
+
   @Override
   public VectorWriter fieldsWriter(SegmentWriteState state) throws IOException {
-    return new Lucene90HnswVectorWriter(state);
+    return new Lucene90HnswVectorWriter(state, maxConn, beamWidth);
   }
 
   @Override
Lucene90HnswVectorWriter.java:

@@ -45,9 +45,14 @@ public final class Lucene90HnswVectorWriter extends VectorWriter {
   private final SegmentWriteState segmentWriteState;
   private final IndexOutput meta, vectorData, vectorIndex;
 
+  private final int maxConn;
+  private final int beamWidth;
   private boolean finished;
 
-  Lucene90HnswVectorWriter(SegmentWriteState state) throws IOException {
+  Lucene90HnswVectorWriter(SegmentWriteState state, int maxConn, int beamWidth) throws IOException {
+    this.maxConn = maxConn;
+    this.beamWidth = beamWidth;
+
     assert state.fieldInfos.hasVectorValues();
     segmentWriteState = state;
@@ -129,8 +134,8 @@ public final class Lucene90HnswVectorWriter extends VectorWriter {
             vectorIndexOffset,
             offsets,
             count,
-            fieldInfo.getAttribute(HnswGraphBuilder.HNSW_MAX_CONN_ATTRIBUTE_KEY),
-            fieldInfo.getAttribute(HnswGraphBuilder.HNSW_BEAM_WIDTH_ATTRIBUTE_KEY));
+            maxConn,
+            beamWidth);
       } else {
         throw new IllegalArgumentException(
             "Indexing an HNSW graph requires a random access vector values, got " + vectors);
@@ -196,36 +201,9 @@ public final class Lucene90HnswVectorWriter extends VectorWriter {
       long graphDataOffset,
       long[] offsets,
       int count,
-      String maxConnStr,
-      String beamWidthStr)
+      int maxConn,
+      int beamWidth)
       throws IOException {
-    int maxConn, beamWidth;
-    if (maxConnStr == null) {
-      maxConn = HnswGraphBuilder.DEFAULT_MAX_CONN;
-    } else {
-      try {
-        maxConn = Integer.parseInt(maxConnStr);
-      } catch (
-          @SuppressWarnings("unused")
-          NumberFormatException e) {
-        throw new NumberFormatException(
-            "Received non integer value for max-connections parameter of HnswGraphBuilder, value: "
-                + maxConnStr);
-      }
-    }
-    if (beamWidthStr == null) {
-      beamWidth = HnswGraphBuilder.DEFAULT_BEAM_WIDTH;
-    } else {
-      try {
-        beamWidth = Integer.parseInt(beamWidthStr);
-      } catch (
-          @SuppressWarnings("unused")
-          NumberFormatException e) {
-        throw new NumberFormatException(
-            "Received non integer value for beam-width parameter of HnswGraphBuilder, value: "
-                + beamWidthStr);
-      }
-    }
     HnswGraphBuilder hnswGraphBuilder =
         new HnswGraphBuilder(vectorValues, maxConn, beamWidth, HnswGraphBuilder.randSeed);
     hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
VectorField.java:

@@ -18,7 +18,6 @@
 package org.apache.lucene.document;
 
 import org.apache.lucene.index.VectorValues;
-import org.apache.lucene.util.hnsw.HnswGraphBuilder;
 
 /**
  * A field that contains a single floating-point numeric vector (or none) for each document. Vectors
@@ -57,34 +56,16 @@ public class VectorField extends Field {
   }
 
   /**
-   * Public method to create HNSW field type with the given max-connections and beam-width
-   * parameters that would be used by HnswGraphBuilder while constructing HNSW graph.
+   * A convenience method for creating a vector field type.
    *
    * @param dimension dimension of vectors
    * @param similarityFunction a function defining vector proximity.
-   * @param maxConn max-connections at each HNSW graph node
-   * @param beamWidth size of list to be used while constructing HNSW graph
   * @throws IllegalArgumentException if any parameter is null, or has dimension > 1024.
   */
-  public static FieldType createHnswType(
-      int dimension,
-      VectorValues.SimilarityFunction similarityFunction,
-      int maxConn,
-      int beamWidth) {
-    if (dimension == 0) {
-      throw new IllegalArgumentException("cannot index an empty vector");
-    }
-    if (dimension > VectorValues.MAX_DIMENSIONS) {
-      throw new IllegalArgumentException(
-          "cannot index vectors with dimension greater than " + VectorValues.MAX_DIMENSIONS);
-    }
-    if (similarityFunction == null || similarityFunction == VectorValues.SimilarityFunction.NONE) {
-      throw new IllegalArgumentException("similarity function must not be: " + similarityFunction);
-    }
+  public static FieldType createFieldType(
+      int dimension, VectorValues.SimilarityFunction similarityFunction) {
     FieldType type = new FieldType();
     type.setVectorDimensionsAndSimilarityFunction(dimension, similarityFunction);
-    type.putAttribute(HnswGraphBuilder.HNSW_MAX_CONN_ATTRIBUTE_KEY, String.valueOf(maxConn));
-    type.putAttribute(HnswGraphBuilder.HNSW_BEAM_WIDTH_ATTRIBUTE_KEY, String.valueOf(beamWidth));
     type.freeze();
     return type;
   }
HnswGraphBuilder.java:

@@ -38,18 +38,6 @@ public final class HnswGraphBuilder {
   // expose for testing.
   public static long randSeed = DEFAULT_RAND_SEED;
 
-  /* These "default" hyper-parameter settings are exposed (and non-final) to enable performance
-   * testing since the indexing API doesn't provide any control over them.
-   */
-
-  // default max connections per node
-  public static final int DEFAULT_MAX_CONN = 16;
-  public static String HNSW_MAX_CONN_ATTRIBUTE_KEY = "max_connections";
-
-  // default candidate list size
-  public static final int DEFAULT_BEAM_WIDTH = 16;
-  public static String HNSW_BEAM_WIDTH_ATTRIBUTE_KEY = "beam_width";
-
   private final int maxConn;
   private final int beamWidth;
   private final NeighborArray scratch;
@@ -66,11 +54,6 @@ public final class HnswGraphBuilder {
   // colliding
   private RandomAccessVectorValues buildVectors;
 
-  /** Construct the builder with default configurations */
-  public HnswGraphBuilder(RandomAccessVectorValuesProducer vectors) {
-    this(vectors, DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, randSeed);
-  }
-
   /**
    * Reads all the vectors from a VectorValues, builds a graph connecting them by their dense
    * ordinals, using the given hyperparameter settings, and returns the resulting graph.
TestLucene90HnswVectorFormat.java:

@@ -21,7 +21,6 @@ import org.apache.lucene.index.BaseVectorFormatTestCase;
 import org.apache.lucene.util.TestUtil;
 
 public class TestLucene90HnswVectorFormat extends BaseVectorFormatTestCase {
-
   @Override
   protected Codec getCodec() {
     return TestUtil.getDefaultCodec();
TestKnnGraph.java:

@@ -27,6 +27,9 @@ import java.util.LinkedList;
 import java.util.List;
 import java.util.Set;
 import org.apache.lucene.codecs.Codec;
+import org.apache.lucene.codecs.VectorFormat;
+import org.apache.lucene.codecs.lucene90.Lucene90Codec;
+import org.apache.lucene.codecs.lucene90.Lucene90HnswVectorFormat;
 import org.apache.lucene.codecs.lucene90.Lucene90HnswVectorReader;
 import org.apache.lucene.codecs.perfield.PerFieldVectorFormat;
 import org.apache.lucene.document.Document;
@@ -52,8 +55,9 @@ public class TestKnnGraph extends LuceneTestCase {
 
   private static final String KNN_GRAPH_FIELD = "vector";
 
-  private static int maxConn = HnswGraphBuilder.DEFAULT_MAX_CONN;
+  private static int maxConn = Lucene90HnswVectorFormat.DEFAULT_MAX_CONN;
 
+  private Codec codec;
   private SimilarityFunction similarityFunction;
 
   @Before
@@ -62,20 +66,29 @@ public class TestKnnGraph extends LuceneTestCase {
     if (random().nextBoolean()) {
       maxConn = random().nextInt(256) + 3;
     }
+
+    codec =
+        new Lucene90Codec() {
+          @Override
+          public VectorFormat getVectorFormatForField(String field) {
+            return new Lucene90HnswVectorFormat(
+                maxConn, Lucene90HnswVectorFormat.DEFAULT_BEAM_WIDTH);
+          }
+        };
+
     int similarity = random().nextInt(SimilarityFunction.values().length - 1) + 1;
     similarityFunction = SimilarityFunction.values()[similarity];
   }
 
   @After
   public void cleanup() {
-    maxConn = HnswGraphBuilder.DEFAULT_MAX_CONN;
+    maxConn = Lucene90HnswVectorFormat.DEFAULT_MAX_CONN;
   }
 
   /** Basic test of creating documents in a graph */
   public void testBasic() throws Exception {
     try (Directory dir = newDirectory();
-        IndexWriter iw =
-            new IndexWriter(dir, newIndexWriterConfig(null).setCodec(Codec.forName("Lucene90")))) {
+        IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig(null).setCodec(codec))) {
       int numDoc = atLeast(10);
       int dimension = atLeast(3);
       float[][] values = new float[numDoc][];
@@ -94,8 +107,7 @@ public class TestKnnGraph extends LuceneTestCase {
 
   public void testSingleDocument() throws Exception {
     try (Directory dir = newDirectory();
-        IndexWriter iw =
-            new IndexWriter(dir, newIndexWriterConfig(null).setCodec(Codec.forName("Lucene90")))) {
+        IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig(null).setCodec(codec))) {
       float[][] values = new float[][] {new float[] {0, 1, 2}};
       add(iw, 0, values[0]);
       assertConsistentGraph(iw, values);
@@ -107,8 +119,7 @@ public class TestKnnGraph extends LuceneTestCase {
   /** Verify that the graph properties are preserved when merging */
   public void testMerge() throws Exception {
     try (Directory dir = newDirectory();
-        IndexWriter iw =
-            new IndexWriter(dir, newIndexWriterConfig(null).setCodec(Codec.forName("Lucene90")))) {
+        IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig(null).setCodec(codec))) {
       int numDoc = atLeast(100);
       int dimension = atLeast(10);
       float[][] values = randomVectors(numDoc, dimension);
@@ -160,7 +171,7 @@ public class TestKnnGraph extends LuceneTestCase {
     try (Directory dir = newDirectory()) {
       IndexWriterConfig iwc = newIndexWriterConfig();
       iwc.setMergePolicy(new LogDocMergePolicy()); // for predictable segment ordering when merging
-      iwc.setCodec(Codec.forName("Lucene90")); // don't use SimpleTextCodec
+      iwc.setCodec(codec); // don't use SimpleTextCodec
       try (IndexWriter iw = new IndexWriter(dir, iwc)) {
         for (int i = 0; i < values.length; i++) {
           add(iw, i, values[i]);
@@ -218,7 +229,7 @@ public class TestKnnGraph extends LuceneTestCase {
     // We can't use dot product here since the vectors are laid out on a grid, not a sphere.
     similarityFunction = SimilarityFunction.EUCLIDEAN;
     IndexWriterConfig config = newIndexWriterConfig();
-    config.setCodec(Codec.forName("Lucene90")); // test is not compatible with simpletext
+    config.setCodec(codec); // test is not compatible with simpletext
     try (Directory dir = newDirectory();
         IndexWriter iw = new IndexWriter(dir, config)) {
       // Add a document for every cartesian point in an NxN square so we can
@@ -447,9 +458,7 @@ public class TestKnnGraph extends LuceneTestCase {
       throws IOException {
     Document doc = new Document();
     if (vector != null) {
-      FieldType fieldType =
-          VectorField.createHnswType(
-              vector.length, similarityFunction, maxConn, HnswGraphBuilder.DEFAULT_BEAM_WIDTH);
+      FieldType fieldType = VectorField.createFieldType(vector.length, similarityFunction);
       doc.add(new VectorField(KNN_GRAPH_FIELD, vector, fieldType));
     }
     String idString = Integer.toString(id);
KnnGraphTester.java:

@@ -35,6 +35,9 @@ import java.nio.file.Paths;
 import java.util.HashSet;
 import java.util.Locale;
 import java.util.Set;
+import org.apache.lucene.codecs.VectorFormat;
+import org.apache.lucene.codecs.lucene90.Lucene90Codec;
+import org.apache.lucene.codecs.lucene90.Lucene90HnswVectorFormat;
 import org.apache.lucene.codecs.lucene90.Lucene90HnswVectorReader;
 import org.apache.lucene.document.Document;
 import org.apache.lucene.document.FieldType;
@@ -566,13 +569,19 @@ public class KnnGraphTester {
 
   private int createIndex(Path docsPath, Path indexPath) throws IOException {
     IndexWriterConfig iwc = new IndexWriterConfig().setOpenMode(IndexWriterConfig.OpenMode.CREATE);
+    iwc.setCodec(
+        new Lucene90Codec() {
+          @Override
+          public VectorFormat getVectorFormatForField(String field) {
+            return new Lucene90HnswVectorFormat(maxConn, beamWidth);
+          }
+        });
     // iwc.setMergePolicy(NoMergePolicy.INSTANCE);
     iwc.setRAMBufferSizeMB(1994d);
     // iwc.setMaxBufferedDocs(10000);
 
     FieldType fieldType =
-        VectorField.createHnswType(
-            dim, VectorValues.SimilarityFunction.DOT_PRODUCT, maxConn, beamWidth);
+        VectorField.createFieldType(dim, VectorValues.SimilarityFunction.DOT_PRODUCT);
     if (quiet == false) {
       iwc.setInfoStream(new PrintStreamInfoStream(System.out));
       System.out.println("creating index in " + indexPath);
TestHnsw.java:

@@ -24,7 +24,9 @@ import java.util.Arrays;
 import java.util.HashSet;
 import java.util.Random;
 import java.util.Set;
-import org.apache.lucene.codecs.Codec;
+import org.apache.lucene.codecs.VectorFormat;
+import org.apache.lucene.codecs.lucene90.Lucene90Codec;
+import org.apache.lucene.codecs.lucene90.Lucene90HnswVectorFormat;
 import org.apache.lucene.codecs.lucene90.Lucene90HnswVectorReader;
 import org.apache.lucene.codecs.perfield.PerFieldVectorFormat;
 import org.apache.lucene.document.Document;
@@ -49,23 +51,34 @@ import org.apache.lucene.util.VectorUtil;
 /** Tests HNSW KNN graphs */
 public class TestHnsw extends LuceneTestCase {
 
-  // test writing out and reading in a graph gives the same graph
+  // test writing out and reading in a graph gives the expected graph
   public void testReadWrite() throws IOException {
     int dim = random().nextInt(100) + 1;
     int nDoc = random().nextInt(100) + 1;
     RandomVectorValues vectors = new RandomVectorValues(nDoc, dim, random());
     RandomVectorValues v2 = vectors.copy(), v3 = vectors.copy();
 
+    int maxConn = random().nextInt(10) + 5;
+    int beamWidth = random().nextInt(10) + 5;
     long seed = random().nextLong();
-    HnswGraphBuilder.randSeed = seed;
-    HnswGraphBuilder builder = new HnswGraphBuilder(vectors);
+    HnswGraphBuilder builder = new HnswGraphBuilder(vectors, maxConn, beamWidth, seed);
     HnswGraph hnsw = builder.build(vectors);
 
+    // Recreate the graph while indexing with the same random seed and write it out
+    HnswGraphBuilder.randSeed = seed;
     try (Directory dir = newDirectory()) {
       int nVec = 0, indexedDoc = 0;
       // Don't merge randomly, create a single segment because we rely on the docid ordering for
       // this test
-      IndexWriterConfig iwc = new IndexWriterConfig().setCodec(Codec.forName("Lucene90"));
+      IndexWriterConfig iwc =
+          new IndexWriterConfig()
+              .setCodec(
+                  new Lucene90Codec() {
+                    @Override
+                    public VectorFormat getVectorFormatForField(String field) {
+                      return new Lucene90HnswVectorFormat(maxConn, beamWidth);
+                    }
+                  });
       try (IndexWriter iw = new IndexWriter(dir, iwc)) {
         while (v2.nextDoc() != NO_MORE_DOCS) {
           while (indexedDoc < v2.docID()) {
BaseVectorFormatTestCase.java:

@@ -85,28 +85,6 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCase
     expectThrows(IllegalArgumentException.class, () -> field.setVectorValue(null));
   }
 
-  public void testFieldCreateFieldType() {
-    expectThrows(
-        IllegalArgumentException.class,
-        () -> VectorField.createHnswType(0, VectorValues.SimilarityFunction.EUCLIDEAN, 16, 16));
-    expectThrows(
-        IllegalArgumentException.class,
-        () ->
-            VectorField.createHnswType(
-                VectorValues.MAX_DIMENSIONS + 1,
-                VectorValues.SimilarityFunction.EUCLIDEAN,
-                16,
-                16));
-    expectThrows(
-        IllegalArgumentException.class,
-        () -> VectorField.createHnswType(VectorValues.MAX_DIMENSIONS + 1, null, 16, 16));
-    expectThrows(
-        IllegalArgumentException.class,
-        () ->
-            VectorField.createHnswType(
-                VectorValues.MAX_DIMENSIONS + 1, VectorValues.SimilarityFunction.NONE, 16, 16));
-  }
-
   // Illegal schema change tests:
   public void testIllegalDimChangeTwoDocs() throws Exception {
     // illegal change in the same segment