LUCENE-9615: Expose HnswGraphBuilder index-time hyperparameters as FieldType attributes (from Shubham Beniwal)

This commit is contained in:
sbeniwal12 2021-02-03 03:56:29 +05:30 committed by GitHub
parent 8f75933f3d
commit a53e8e7228
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 133 additions and 30 deletions

View File

@ -123,7 +123,9 @@ public final class Lucene90VectorWriter extends VectorWriter {
(RandomAccessVectorValuesProducer) vectors,
vectorIndexOffset,
offsets,
count);
count,
fieldInfo.getAttribute(HnswGraphBuilder.HNSW_MAX_CONN_ATTRIBUTE_KEY),
fieldInfo.getAttribute(HnswGraphBuilder.HNSW_BEAM_WIDTH_ATTRIBUTE_KEY));
} else {
throw new IllegalArgumentException(
"Indexing an HNSW graph requires a random access vector values, got " + vectors);
@ -188,9 +190,35 @@ public final class Lucene90VectorWriter extends VectorWriter {
RandomAccessVectorValuesProducer vectorValues,
long graphDataOffset,
long[] offsets,
int count)
int count,
String maxConnStr,
String beamWidthStr)
throws IOException {
HnswGraphBuilder hnswGraphBuilder = new HnswGraphBuilder(vectorValues);
int maxConn, beamWidth;
if (maxConnStr == null) {
maxConn = HnswGraphBuilder.DEFAULT_MAX_CONN;
} else {
try {
maxConn = Integer.parseInt(maxConnStr);
} catch (NumberFormatException e) {
throw new NumberFormatException(
"Received non integer value for max-connections parameter of HnswGraphBuilder, value: "
+ maxConnStr);
}
}
if (beamWidthStr == null) {
beamWidth = HnswGraphBuilder.DEFAULT_BEAM_WIDTH;
} else {
try {
beamWidth = Integer.parseInt(beamWidthStr);
} catch (NumberFormatException e) {
throw new NumberFormatException(
"Received non integer value for beam-width parameter of HnswGraphBuilder, value: "
+ beamWidthStr);
}
}
HnswGraphBuilder hnswGraphBuilder =
new HnswGraphBuilder(vectorValues, maxConn, beamWidth, System.currentTimeMillis());
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
HnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());

View File

@ -18,6 +18,7 @@
package org.apache.lucene.document;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
/**
* A field that contains a single floating-point numeric vector (or none) for each document. Vectors
@ -32,7 +33,7 @@ import org.apache.lucene.index.VectorValues;
*/
public class VectorField extends Field {
private static FieldType getType(float[] v, VectorValues.SearchStrategy searchStrategy) {
private static FieldType createType(float[] v, VectorValues.SearchStrategy searchStrategy) {
if (v == null) {
throw new IllegalArgumentException("vector value must not be null");
}
@ -53,6 +54,37 @@ public class VectorField extends Field {
return type;
}
/**
 * Creates a frozen {@link FieldType} for an HNSW-indexed vector field. The given max-connections
 * and beam-width values are recorded as field attributes (under
 * {@link HnswGraphBuilder#HNSW_MAX_CONN_ATTRIBUTE_KEY} and
 * {@link HnswGraphBuilder#HNSW_BEAM_WIDTH_ATTRIBUTE_KEY}) so that they can be used by
 * HnswGraphBuilder while constructing the HNSW graph.
 *
 * @param dimension dimension of vectors; must be between 1 and
 *     {@link VectorValues#MAX_DIMENSIONS} inclusive
 * @param searchStrategy a function defining vector proximity; must be an HNSW strategy
 * @param maxConn max-connections at each HNSW graph node; must be positive
 * @param beamWidth size of list to be used while constructing HNSW graph; must be positive
 * @return a frozen field type carrying the dimension, search strategy and both attributes
 * @throws IllegalArgumentException if {@code dimension} is not positive or exceeds
 *     {@link VectorValues#MAX_DIMENSIONS}, if {@code searchStrategy} is null or not an HNSW
 *     strategy, or if {@code maxConn} or {@code beamWidth} is not positive
 */
public static FieldType createHnswType(
    int dimension, VectorValues.SearchStrategy searchStrategy, int maxConn, int beamWidth) {
  if (dimension == 0) {
    throw new IllegalArgumentException("cannot index an empty vector");
  }
  if (dimension < 0) {
    throw new IllegalArgumentException(
        "vector dimension must be positive, received: " + dimension);
  }
  if (dimension > VectorValues.MAX_DIMENSIONS) {
    throw new IllegalArgumentException(
        "cannot index vectors with dimension greater than " + VectorValues.MAX_DIMENSIONS);
  }
  if (searchStrategy == null || !searchStrategy.isHnsw()) {
    throw new IllegalArgumentException(
        "search strategy must not be null or non HNSW type, received: " + searchStrategy);
  }
  // Fail fast: the attributes are stored as strings and only parsed when the graph is written,
  // so a nonsensical value would otherwise surface far away from the call that introduced it.
  if (maxConn <= 0) {
    throw new IllegalArgumentException("max-connections must be positive, received: " + maxConn);
  }
  if (beamWidth <= 0) {
    throw new IllegalArgumentException("beam-width must be positive, received: " + beamWidth);
  }
  FieldType type = new FieldType();
  type.setVectorDimensionsAndSearchStrategy(dimension, searchStrategy);
  type.putAttribute(HnswGraphBuilder.HNSW_MAX_CONN_ATTRIBUTE_KEY, String.valueOf(maxConn));
  type.putAttribute(HnswGraphBuilder.HNSW_BEAM_WIDTH_ATTRIBUTE_KEY, String.valueOf(beamWidth));
  // Freeze so the shared type cannot be modified after it is handed to callers.
  type.freeze();
  return type;
}
/**
* Creates a numeric vector field. Fields are single-valued: each document has either one value or
* no value. Vectors of a single field share the same dimension and search strategy. Note that
@ -66,7 +98,7 @@ public class VectorField extends Field {
* dimension > 1024.
*/
public VectorField(String name, float[] vector, VectorValues.SearchStrategy searchStrategy) {
super(name, getType(vector, searchStrategy));
super(name, createType(vector, searchStrategy));
fieldsData = vector;
}
@ -84,6 +116,21 @@ public class VectorField extends Field {
this(name, vector, VectorValues.SearchStrategy.EUCLIDEAN_HNSW);
}
/**
 * Creates a numeric vector field with a caller-supplied field type, for example one built by
 * {@code createHnswType}. Fields are single-valued: each document has either one value or no
 * value. Vectors of a single field share the same dimension and search strategy.
 *
 * <p>NOTE(review): this constructor does not itself verify that {@code vector.length} matches the
 * dimension configured on {@code fieldType}; confirm whether that is checked elsewhere.
 *
 * @param name field name
 * @param vector value
 * @param fieldType field type
 * @throws IllegalArgumentException if {@code vector} is null, or if the superclass constructor
 *     rejects {@code name} or {@code fieldType}
 */
public VectorField(String name, float[] vector, FieldType fieldType) {
  super(name, fieldType);
  // Same guard and message as the SearchStrategy-based construction path, so a null vector is
  // rejected up front instead of being stored silently.
  if (vector == null) {
    throw new IllegalArgumentException("vector value must not be null");
  }
  fieldsData = vector;
}
/** Return the vector value of this field */
public float[] vectorValue() {
return (float[]) fieldsData;

View File

@ -43,10 +43,12 @@ public final class HnswGraphBuilder {
*/
// default max connections per node
public static int DEFAULT_MAX_CONN = 16;
public static final int DEFAULT_MAX_CONN = 16;
// FieldType attribute key under which the max-connections hyperparameter is stored
public static final String HNSW_MAX_CONN_ATTRIBUTE_KEY = "max_connections";
// default candidate list size
public static int DEFAULT_BEAM_WIDTH = 16;
public static final int DEFAULT_BEAM_WIDTH = 16;
// FieldType attribute key under which the beam-width hyperparameter is stored
public static final String HNSW_BEAM_WIDTH_ATTRIBUTE_KEY = "beam_width";
private final int maxConn;
private final int beamWidth;

View File

@ -30,6 +30,7 @@ import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.lucene90.Lucene90VectorReader;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.SortedDocValuesField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.document.VectorField;
@ -58,8 +59,7 @@ public class TestKnnGraph extends LuceneTestCase {
public void setup() {
randSeed = random().nextLong();
if (random().nextBoolean()) {
maxConn = HnswGraphBuilder.DEFAULT_MAX_CONN;
HnswGraphBuilder.DEFAULT_MAX_CONN = random().nextInt(256) + 2;
maxConn = random().nextInt(256) + 2;
}
int strategy = random().nextInt(SearchStrategy.values().length - 1) + 1;
searchStrategy = SearchStrategy.values()[strategy];
@ -67,7 +67,7 @@ public class TestKnnGraph extends LuceneTestCase {
@After
public void cleanup() {
HnswGraphBuilder.DEFAULT_MAX_CONN = maxConn;
maxConn = HnswGraphBuilder.DEFAULT_MAX_CONN;
}
/** Basic test of creating documents in a graph */
@ -196,7 +196,7 @@ public class TestKnnGraph extends LuceneTestCase {
int[][] copyGraph(KnnGraphValues values) throws IOException {
int size = values.size();
int[][] graph = new int[size][];
int[] scratch = new int[HnswGraphBuilder.DEFAULT_MAX_CONN];
int[] scratch = new int[maxConn];
for (int node = 0; node < size; node++) {
int n, count = 0;
values.seek(node);
@ -368,12 +368,12 @@ public class TestKnnGraph extends LuceneTestCase {
assertTrue(
"Graph has " + graphSize + " nodes, but one of them has no neighbors", graphSize > 1);
}
if (HnswGraphBuilder.DEFAULT_MAX_CONN > graphSize) {
if (maxConn > graphSize) {
// assert that the graph in each leaf is connected
assertConnected(graph);
} else {
// assert that max-connections was respected
assertMaxConn(graph, HnswGraphBuilder.DEFAULT_MAX_CONN);
assertMaxConn(graph, maxConn);
}
totalGraphDocs += graphSize;
}
@ -439,7 +439,10 @@ public class TestKnnGraph extends LuceneTestCase {
throws IOException {
Document doc = new Document();
if (vector != null) {
doc.add(new VectorField(KNN_GRAPH_FIELD, vector, searchStrategy));
FieldType fieldType =
VectorField.createHnswType(
vector.length, searchStrategy, maxConn, HnswGraphBuilder.DEFAULT_BEAM_WIDTH);
doc.add(new VectorField(KNN_GRAPH_FIELD, vector, fieldType));
}
String idString = Integer.toString(id);
doc.add(new StringField("id", idString, Field.Store.YES));

View File

@ -25,6 +25,7 @@ import org.apache.lucene.codecs.Codec;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.Field.Store;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.document.VectorField;
@ -76,11 +77,16 @@ public class TestVectorValues extends LuceneTestCase {
public void testFieldConstructorExceptions() {
expectThrows(IllegalArgumentException.class, () -> new VectorField(null, new float[1]));
expectThrows(IllegalArgumentException.class, () -> new VectorField("f", null));
expectThrows(IllegalArgumentException.class, () -> new VectorField("f", new float[1], null));
expectThrows(
IllegalArgumentException.class,
() -> new VectorField("f", new float[1], (SearchStrategy) null));
expectThrows(IllegalArgumentException.class, () -> new VectorField("f", new float[0]));
expectThrows(
IllegalArgumentException.class,
() -> new VectorField("f", new float[VectorValues.MAX_DIMENSIONS + 1]));
expectThrows(
IllegalArgumentException.class,
() -> new VectorField("f", new float[VectorValues.MAX_DIMENSIONS + 1], (FieldType) null));
}
public void testFieldSetValue() {
@ -92,6 +98,25 @@ public class TestVectorValues extends LuceneTestCase {
expectThrows(IllegalArgumentException.class, () -> field.setVectorValue(null));
}
/** Verifies that VectorField.createHnswType rejects each kind of invalid argument. */
public void testFieldCreateFieldType() {
  // A dimension that passes the dimension checks, so that the search-strategy cases below
  // exercise the strategy validation rather than tripping the dimension limit first.
  final int validDimension = 1;

  // empty vector
  expectThrows(
      IllegalArgumentException.class,
      () -> VectorField.createHnswType(0, SearchStrategy.EUCLIDEAN_HNSW, 16, 16));
  // dimension above the supported maximum
  expectThrows(
      IllegalArgumentException.class,
      () ->
          VectorField.createHnswType(
              VectorValues.MAX_DIMENSIONS + 1, SearchStrategy.EUCLIDEAN_HNSW, 16, 16));
  // null search strategy
  expectThrows(
      IllegalArgumentException.class,
      () -> VectorField.createHnswType(validDimension, null, 16, 16));
  // non-HNSW search strategy
  expectThrows(
      IllegalArgumentException.class,
      () -> VectorField.createHnswType(validDimension, SearchStrategy.NONE, 16, 16));
}
// Illegal schema change tests:
public void testIllegalDimChangeTwoDocs() throws Exception {

View File

@ -37,6 +37,7 @@ import java.util.Locale;
import java.util.Set;
import org.apache.lucene.codecs.lucene90.Lucene90VectorReader;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.document.VectorField;
import org.apache.lucene.index.CodecReader;
@ -83,6 +84,8 @@ public class KnnGraphTester {
private boolean reindex;
private boolean forceMerge;
private int reindexTimeMsec;
// Index-time HNSW hyperparameters, assigned from the parsed command-line arguments
// (the -beamWidthIndex and -maxConn options).
// NOTE(review): without an initializer these are 0 unless the constructor or an argument sets
// them; the code they replace read HnswGraphBuilder.DEFAULT_BEAM_WIDTH / DEFAULT_MAX_CONN,
// so confirm the intended defaults are still applied somewhere.
private int beamWidth;
private int maxConn;
@SuppressForbidden(reason = "uses Random()")
private KnnGraphTester() {
@ -132,13 +135,13 @@ public class KnnGraphTester {
if (iarg == args.length - 1) {
throw new IllegalArgumentException("-beamWidthIndex requires a following number");
}
HnswGraphBuilder.DEFAULT_BEAM_WIDTH = Integer.parseInt(args[++iarg]);
beamWidth = Integer.parseInt(args[++iarg]);
break;
case "-maxConn":
if (iarg == args.length - 1) {
throw new IllegalArgumentException("-maxConn requires a following number");
}
HnswGraphBuilder.DEFAULT_MAX_CONN = Integer.parseInt(args[++iarg]);
maxConn = Integer.parseInt(args[++iarg]);
break;
case "-dim":
if (iarg == args.length - 1) {
@ -223,12 +226,7 @@ public class KnnGraphTester {
}
private String formatIndexPath(Path docsPath) {
return docsPath.getFileName()
+ "-"
+ HnswGraphBuilder.DEFAULT_MAX_CONN
+ "-"
+ HnswGraphBuilder.DEFAULT_BEAM_WIDTH
+ ".index";
return docsPath.getFileName() + "-" + maxConn + "-" + beamWidth + ".index";
}
@SuppressForbidden(reason = "Prints stuff")
@ -250,9 +248,7 @@ public class KnnGraphTester {
private void dumpGraph(Path docsPath) throws IOException {
try (BinaryFileVectors vectors = new BinaryFileVectors(docsPath)) {
RandomAccessVectorValues values = vectors.randomAccess();
HnswGraphBuilder builder =
new HnswGraphBuilder(
vectors, HnswGraphBuilder.DEFAULT_MAX_CONN, HnswGraphBuilder.DEFAULT_BEAM_WIDTH, 0);
HnswGraphBuilder builder = new HnswGraphBuilder(vectors, maxConn, beamWidth, 0);
// start at node 1
for (int i = 1; i < numDocs; i++) {
builder.addGraphNode(values.vectorValue(i));
@ -413,8 +409,8 @@ public class KnnGraphTester {
totalCpuTime / (float) numIters,
numDocs,
fanout,
HnswGraphBuilder.DEFAULT_MAX_CONN,
HnswGraphBuilder.DEFAULT_BEAM_WIDTH,
maxConn,
beamWidth,
totalVisited,
reindexTimeMsec);
}
@ -574,6 +570,9 @@ public class KnnGraphTester {
iwc.setRAMBufferSizeMB(1994d);
// iwc.setMaxBufferedDocs(10000);
FieldType fieldType =
VectorField.createHnswType(
dim, VectorValues.SearchStrategy.DOT_PRODUCT_HNSW, maxConn, beamWidth);
if (quiet == false) {
iwc.setInfoStream(new PrintStreamInfoStream(System.out));
System.out.println("creating index in " + indexPath);
@ -598,8 +597,7 @@ public class KnnGraphTester {
vectors.get(vector);
Document doc = new Document();
// System.out.println("vector=" + vector[0] + "," + vector[1] + "...");
doc.add(
new VectorField(KNN_FIELD, vector, VectorValues.SearchStrategy.DOT_PRODUCT_HNSW));
doc.add(new VectorField(KNN_FIELD, vector, fieldType));
doc.add(new StoredField(ID_FIELD, i));
iw.addDocument(doc);
}