mirror of https://github.com/apache/lucene.git
LUCENE-9615: Expose HnswGraphBuilder index-time hyperparameters as FieldType attributes (from Shubham Beniwal)
This commit is contained in:
parent
8f75933f3d
commit
a53e8e7228
|
@ -123,7 +123,9 @@ public final class Lucene90VectorWriter extends VectorWriter {
|
|||
(RandomAccessVectorValuesProducer) vectors,
|
||||
vectorIndexOffset,
|
||||
offsets,
|
||||
count);
|
||||
count,
|
||||
fieldInfo.getAttribute(HnswGraphBuilder.HNSW_MAX_CONN_ATTRIBUTE_KEY),
|
||||
fieldInfo.getAttribute(HnswGraphBuilder.HNSW_BEAM_WIDTH_ATTRIBUTE_KEY));
|
||||
} else {
|
||||
throw new IllegalArgumentException(
|
||||
"Indexing an HNSW graph requires a random access vector values, got " + vectors);
|
||||
|
@ -188,9 +190,35 @@ public final class Lucene90VectorWriter extends VectorWriter {
|
|||
RandomAccessVectorValuesProducer vectorValues,
|
||||
long graphDataOffset,
|
||||
long[] offsets,
|
||||
int count)
|
||||
int count,
|
||||
String maxConnStr,
|
||||
String beamWidthStr)
|
||||
throws IOException {
|
||||
HnswGraphBuilder hnswGraphBuilder = new HnswGraphBuilder(vectorValues);
|
||||
int maxConn, beamWidth;
|
||||
if (maxConnStr == null) {
|
||||
maxConn = HnswGraphBuilder.DEFAULT_MAX_CONN;
|
||||
} else {
|
||||
try {
|
||||
maxConn = Integer.parseInt(maxConnStr);
|
||||
} catch (NumberFormatException e) {
|
||||
throw new NumberFormatException(
|
||||
"Received non integer value for max-connections parameter of HnswGraphBuilder, value: "
|
||||
+ maxConnStr);
|
||||
}
|
||||
}
|
||||
if (beamWidthStr == null) {
|
||||
beamWidth = HnswGraphBuilder.DEFAULT_BEAM_WIDTH;
|
||||
} else {
|
||||
try {
|
||||
beamWidth = Integer.parseInt(beamWidthStr);
|
||||
} catch (NumberFormatException e) {
|
||||
throw new NumberFormatException(
|
||||
"Received non integer value for beam-width parameter of HnswGraphBuilder, value: "
|
||||
+ beamWidthStr);
|
||||
}
|
||||
}
|
||||
HnswGraphBuilder hnswGraphBuilder =
|
||||
new HnswGraphBuilder(vectorValues, maxConn, beamWidth, System.currentTimeMillis());
|
||||
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
|
||||
HnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
package org.apache.lucene.document;
|
||||
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
|
||||
|
||||
/**
|
||||
* A field that contains a single floating-point numeric vector (or none) for each document. Vectors
|
||||
|
@ -32,7 +33,7 @@ import org.apache.lucene.index.VectorValues;
|
|||
*/
|
||||
public class VectorField extends Field {
|
||||
|
||||
private static FieldType getType(float[] v, VectorValues.SearchStrategy searchStrategy) {
|
||||
private static FieldType createType(float[] v, VectorValues.SearchStrategy searchStrategy) {
|
||||
if (v == null) {
|
||||
throw new IllegalArgumentException("vector value must not be null");
|
||||
}
|
||||
|
@ -53,6 +54,37 @@ public class VectorField extends Field {
|
|||
return type;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a frozen {@link FieldType} for an HNSW-indexed vector field. The given max-connections
|
||||
* and beam-width are recorded on the type as string attributes under HnswGraphBuilder's keys.
|
||||
*
|
||||
* @param dimension dimension of vectors; must be non-zero and at most VectorValues.MAX_DIMENSIONS
|
||||
* @param searchStrategy a function defining vector proximity; must be non-null and an HNSW strategy
|
||||
* @param maxConn max-connections at each HNSW graph node (stored as given; not validated here)
|
||||
* @param beamWidth size of list to be used while constructing HNSW graph (stored as given; not
*     validated here)
|
||||
* @throws IllegalArgumentException if dimension is 0 or greater than VectorValues.MAX_DIMENSIONS,
*     or if searchStrategy is null or is not an HNSW strategy.
|
||||
*/
|
||||
public static FieldType createHnswType(
|
||||
int dimension, VectorValues.SearchStrategy searchStrategy, int maxConn, int beamWidth) {
|
||||
if (dimension == 0) {
|
||||
throw new IllegalArgumentException("cannot index an empty vector");
|
||||
}
|
||||
if (dimension > VectorValues.MAX_DIMENSIONS) {
|
||||
throw new IllegalArgumentException(
|
||||
"cannot index vectors with dimension greater than " + VectorValues.MAX_DIMENSIONS);
|
||||
}
|
||||
if (searchStrategy == null || !searchStrategy.isHnsw()) {
|
||||
throw new IllegalArgumentException(
|
||||
"search strategy must not be null or non HNSW type, received: " + searchStrategy);
|
||||
}
|
||||
FieldType type = new FieldType();
|
||||
type.setVectorDimensionsAndSearchStrategy(dimension, searchStrategy);
|
||||
// The int hyperparameters are stored as string attribute values via String.valueOf.
type.putAttribute(HnswGraphBuilder.HNSW_MAX_CONN_ATTRIBUTE_KEY, String.valueOf(maxConn));
|
||||
type.putAttribute(HnswGraphBuilder.HNSW_BEAM_WIDTH_ATTRIBUTE_KEY, String.valueOf(beamWidth));
|
||||
type.freeze();
|
||||
return type;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a numeric vector field. Fields are single-valued: each document has either one value or
|
||||
* no value. Vectors of a single field share the same dimension and search strategy. Note that
|
||||
|
@ -66,7 +98,7 @@ public class VectorField extends Field {
|
|||
* dimension > 1024.
|
||||
*/
|
||||
public VectorField(String name, float[] vector, VectorValues.SearchStrategy searchStrategy) {
|
||||
super(name, getType(vector, searchStrategy));
|
||||
super(name, createType(vector, searchStrategy));
|
||||
fieldsData = vector;
|
||||
}
|
||||
|
||||
|
@ -84,6 +116,21 @@ public class VectorField extends Field {
|
|||
this(name, vector, VectorValues.SearchStrategy.EUCLIDEAN_HNSW);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a numeric vector field with a caller-supplied field type, for example one built by
|
||||
* {@link #createHnswType}. Fields are single-valued: each document has either one value or no value.
|
||||
*
|
||||
* @param name field name
|
||||
* @param vector value; stored as this field's data without any check in this constructor
|
||||
* @param fieldType field type, passed unchanged to the superclass constructor
|
||||
* @throws IllegalArgumentException NOTE(review): nothing in this constructor validates the vector
|
||||
* (null, empty, or oversized); any exception must come from the superclass constructor - confirm.
|
||||
*/
|
||||
public VectorField(String name, float[] vector, FieldType fieldType) {
|
||||
super(name, fieldType);
|
||||
fieldsData = vector;
|
||||
}
|
||||
|
||||
/** Return the vector value of this field */
|
||||
public float[] vectorValue() {
|
||||
return (float[]) fieldsData;
|
||||
|
|
|
@ -43,10 +43,12 @@ public final class HnswGraphBuilder {
|
|||
*/
|
||||
|
||||
// default max connections per node
|
||||
public static int DEFAULT_MAX_CONN = 16;
|
||||
public static final int DEFAULT_MAX_CONN = 16;
|
||||
// FieldType attribute key under which the max-connections setting is stored; final so that it
// cannot be reassigned at runtime.
public static final String HNSW_MAX_CONN_ATTRIBUTE_KEY = "max_connections";
|
||||
|
||||
// default candidate list size
|
||||
public static int DEFAULT_BEAM_WIDTH = 16;
|
||||
public static final int DEFAULT_BEAM_WIDTH = 16;
|
||||
// FieldType attribute key under which the beam-width setting is stored; final so that it cannot
// be reassigned at runtime.
public static final String HNSW_BEAM_WIDTH_ATTRIBUTE_KEY = "beam_width";
|
||||
|
||||
private final int maxConn;
|
||||
private final int beamWidth;
|
||||
|
|
|
@ -30,6 +30,7 @@ import org.apache.lucene.codecs.Codec;
|
|||
import org.apache.lucene.codecs.lucene90.Lucene90VectorReader;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.document.FieldType;
|
||||
import org.apache.lucene.document.SortedDocValuesField;
|
||||
import org.apache.lucene.document.StringField;
|
||||
import org.apache.lucene.document.VectorField;
|
||||
|
@ -58,8 +59,7 @@ public class TestKnnGraph extends LuceneTestCase {
|
|||
public void setup() {
|
||||
randSeed = random().nextLong();
|
||||
if (random().nextBoolean()) {
|
||||
maxConn = HnswGraphBuilder.DEFAULT_MAX_CONN;
|
||||
HnswGraphBuilder.DEFAULT_MAX_CONN = random().nextInt(256) + 2;
|
||||
maxConn = random().nextInt(256) + 2;
|
||||
}
|
||||
int strategy = random().nextInt(SearchStrategy.values().length - 1) + 1;
|
||||
searchStrategy = SearchStrategy.values()[strategy];
|
||||
|
@ -67,7 +67,7 @@ public class TestKnnGraph extends LuceneTestCase {
|
|||
|
||||
@After
|
||||
public void cleanup() {
|
||||
HnswGraphBuilder.DEFAULT_MAX_CONN = maxConn;
|
||||
maxConn = HnswGraphBuilder.DEFAULT_MAX_CONN;
|
||||
}
|
||||
|
||||
/** Basic test of creating documents in a graph */
|
||||
|
@ -196,7 +196,7 @@ public class TestKnnGraph extends LuceneTestCase {
|
|||
int[][] copyGraph(KnnGraphValues values) throws IOException {
|
||||
int size = values.size();
|
||||
int[][] graph = new int[size][];
|
||||
int[] scratch = new int[HnswGraphBuilder.DEFAULT_MAX_CONN];
|
||||
int[] scratch = new int[maxConn];
|
||||
for (int node = 0; node < size; node++) {
|
||||
int n, count = 0;
|
||||
values.seek(node);
|
||||
|
@ -368,12 +368,12 @@ public class TestKnnGraph extends LuceneTestCase {
|
|||
assertTrue(
|
||||
"Graph has " + graphSize + " nodes, but one of them has no neighbors", graphSize > 1);
|
||||
}
|
||||
if (HnswGraphBuilder.DEFAULT_MAX_CONN > graphSize) {
|
||||
if (maxConn > graphSize) {
|
||||
// assert that the graph in each leaf is connected
|
||||
assertConnected(graph);
|
||||
} else {
|
||||
// assert that max-connections was respected
|
||||
assertMaxConn(graph, HnswGraphBuilder.DEFAULT_MAX_CONN);
|
||||
assertMaxConn(graph, maxConn);
|
||||
}
|
||||
totalGraphDocs += graphSize;
|
||||
}
|
||||
|
@ -439,7 +439,10 @@ public class TestKnnGraph extends LuceneTestCase {
|
|||
throws IOException {
|
||||
Document doc = new Document();
|
||||
if (vector != null) {
|
||||
doc.add(new VectorField(KNN_GRAPH_FIELD, vector, searchStrategy));
|
||||
FieldType fieldType =
|
||||
VectorField.createHnswType(
|
||||
vector.length, searchStrategy, maxConn, HnswGraphBuilder.DEFAULT_BEAM_WIDTH);
|
||||
doc.add(new VectorField(KNN_GRAPH_FIELD, vector, fieldType));
|
||||
}
|
||||
String idString = Integer.toString(id);
|
||||
doc.add(new StringField("id", idString, Field.Store.YES));
|
||||
|
|
|
@ -25,6 +25,7 @@ import org.apache.lucene.codecs.Codec;
|
|||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.document.Field.Store;
|
||||
import org.apache.lucene.document.FieldType;
|
||||
import org.apache.lucene.document.NumericDocValuesField;
|
||||
import org.apache.lucene.document.StringField;
|
||||
import org.apache.lucene.document.VectorField;
|
||||
|
@ -76,11 +77,16 @@ public class TestVectorValues extends LuceneTestCase {
|
|||
public void testFieldConstructorExceptions() {
|
||||
expectThrows(IllegalArgumentException.class, () -> new VectorField(null, new float[1]));
|
||||
expectThrows(IllegalArgumentException.class, () -> new VectorField("f", null));
|
||||
expectThrows(IllegalArgumentException.class, () -> new VectorField("f", new float[1], null));
|
||||
expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
() -> new VectorField("f", new float[1], (SearchStrategy) null));
|
||||
expectThrows(IllegalArgumentException.class, () -> new VectorField("f", new float[0]));
|
||||
expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
() -> new VectorField("f", new float[VectorValues.MAX_DIMENSIONS + 1]));
|
||||
expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
() -> new VectorField("f", new float[VectorValues.MAX_DIMENSIONS + 1], (FieldType) null));
|
||||
}
|
||||
|
||||
public void testFieldSetValue() {
|
||||
|
@ -92,6 +98,25 @@ public class TestVectorValues extends LuceneTestCase {
|
|||
expectThrows(IllegalArgumentException.class, () -> field.setVectorValue(null));
|
||||
}
|
||||
|
||||
/** Checks that createHnswType rejects each kind of invalid argument on its own. */
public void testFieldCreateFieldType() {
|
||||
// A dimension of 0 is rejected.
expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
() -> VectorField.createHnswType(0, SearchStrategy.EUCLIDEAN_HNSW, 16, 16));
|
||||
// A dimension above the maximum is rejected.
expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
() ->
|
||||
VectorField.createHnswType(
|
||||
VectorValues.MAX_DIMENSIONS + 1, SearchStrategy.EUCLIDEAN_HNSW, 16, 16));
|
||||
// A null strategy is rejected. The dimension is kept in range so that the strategy check,
// not the dimension check that runs before it, is what throws.
expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
() -> VectorField.createHnswType(VectorValues.MAX_DIMENSIONS, null, 16, 16));
|
||||
// A non-HNSW strategy is rejected, again with an in-range dimension.
expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
() ->
|
||||
VectorField.createHnswType(
|
||||
VectorValues.MAX_DIMENSIONS, SearchStrategy.NONE, 16, 16));
|
||||
}
|
||||
|
||||
// Illegal schema change tests:
|
||||
|
||||
public void testIllegalDimChangeTwoDocs() throws Exception {
|
||||
|
|
|
@ -37,6 +37,7 @@ import java.util.Locale;
|
|||
import java.util.Set;
|
||||
import org.apache.lucene.codecs.lucene90.Lucene90VectorReader;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.FieldType;
|
||||
import org.apache.lucene.document.StoredField;
|
||||
import org.apache.lucene.document.VectorField;
|
||||
import org.apache.lucene.index.CodecReader;
|
||||
|
@ -83,6 +84,8 @@ public class KnnGraphTester {
|
|||
private boolean reindex;
|
||||
private boolean forceMerge;
|
||||
private int reindexTimeMsec;
|
||||
private int beamWidth;
|
||||
private int maxConn;
|
||||
|
||||
@SuppressForbidden(reason = "uses Random()")
|
||||
private KnnGraphTester() {
|
||||
|
@ -132,13 +135,13 @@ public class KnnGraphTester {
|
|||
if (iarg == args.length - 1) {
|
||||
throw new IllegalArgumentException("-beamWidthIndex requires a following number");
|
||||
}
|
||||
HnswGraphBuilder.DEFAULT_BEAM_WIDTH = Integer.parseInt(args[++iarg]);
|
||||
beamWidth = Integer.parseInt(args[++iarg]);
|
||||
break;
|
||||
case "-maxConn":
|
||||
if (iarg == args.length - 1) {
|
||||
throw new IllegalArgumentException("-maxConn requires a following number");
|
||||
}
|
||||
HnswGraphBuilder.DEFAULT_MAX_CONN = Integer.parseInt(args[++iarg]);
|
||||
maxConn = Integer.parseInt(args[++iarg]);
|
||||
break;
|
||||
case "-dim":
|
||||
if (iarg == args.length - 1) {
|
||||
|
@ -223,12 +226,7 @@ public class KnnGraphTester {
|
|||
}
|
||||
|
||||
private String formatIndexPath(Path docsPath) {
|
||||
return docsPath.getFileName()
|
||||
+ "-"
|
||||
+ HnswGraphBuilder.DEFAULT_MAX_CONN
|
||||
+ "-"
|
||||
+ HnswGraphBuilder.DEFAULT_BEAM_WIDTH
|
||||
+ ".index";
|
||||
return docsPath.getFileName() + "-" + maxConn + "-" + beamWidth + ".index";
|
||||
}
|
||||
|
||||
@SuppressForbidden(reason = "Prints stuff")
|
||||
|
@ -250,9 +248,7 @@ public class KnnGraphTester {
|
|||
private void dumpGraph(Path docsPath) throws IOException {
|
||||
try (BinaryFileVectors vectors = new BinaryFileVectors(docsPath)) {
|
||||
RandomAccessVectorValues values = vectors.randomAccess();
|
||||
HnswGraphBuilder builder =
|
||||
new HnswGraphBuilder(
|
||||
vectors, HnswGraphBuilder.DEFAULT_MAX_CONN, HnswGraphBuilder.DEFAULT_BEAM_WIDTH, 0);
|
||||
HnswGraphBuilder builder = new HnswGraphBuilder(vectors, maxConn, beamWidth, 0);
|
||||
// start at node 1
|
||||
for (int i = 1; i < numDocs; i++) {
|
||||
builder.addGraphNode(values.vectorValue(i));
|
||||
|
@ -413,8 +409,8 @@ public class KnnGraphTester {
|
|||
totalCpuTime / (float) numIters,
|
||||
numDocs,
|
||||
fanout,
|
||||
HnswGraphBuilder.DEFAULT_MAX_CONN,
|
||||
HnswGraphBuilder.DEFAULT_BEAM_WIDTH,
|
||||
maxConn,
|
||||
beamWidth,
|
||||
totalVisited,
|
||||
reindexTimeMsec);
|
||||
}
|
||||
|
@ -574,6 +570,9 @@ public class KnnGraphTester {
|
|||
iwc.setRAMBufferSizeMB(1994d);
|
||||
// iwc.setMaxBufferedDocs(10000);
|
||||
|
||||
FieldType fieldType =
|
||||
VectorField.createHnswType(
|
||||
dim, VectorValues.SearchStrategy.DOT_PRODUCT_HNSW, maxConn, beamWidth);
|
||||
if (quiet == false) {
|
||||
iwc.setInfoStream(new PrintStreamInfoStream(System.out));
|
||||
System.out.println("creating index in " + indexPath);
|
||||
|
@ -598,8 +597,7 @@ public class KnnGraphTester {
|
|||
vectors.get(vector);
|
||||
Document doc = new Document();
|
||||
// System.out.println("vector=" + vector[0] + "," + vector[1] + "...");
|
||||
doc.add(
|
||||
new VectorField(KNN_FIELD, vector, VectorValues.SearchStrategy.DOT_PRODUCT_HNSW));
|
||||
doc.add(new VectorField(KNN_FIELD, vector, fieldType));
|
||||
doc.add(new StoredField(ID_FIELD, i));
|
||||
iw.addDocument(doc);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue