mirror of https://github.com/apache/lucene.git
LUCENE-9695: don't merge deleted vectors (#2239)
parent b1cd6b691f
commit 38ec2602ce
@@ -153,13 +153,13 @@ public abstract class VectorWriter implements Closeable {
     private final DocIDMerger<VectorValuesSub> docIdMerger;
     private final int[] ordBase;
     private final int cost;
-    private final int size;
+    private int size;
 
     private int docId;
     private VectorValuesSub current;
-    // For each doc with a vector, record its ord in the segments being merged. This enables random
-    // access into the
-    // unmerged segments using the ords from the merged segment.
+    /* For each doc with a vector, record its ord in the segments being merged. This enables random
+     * access into the unmerged segments using the ords from the merged segment.
+     */
     private int[] ordMap;
     private int ord;
 
@@ -171,6 +171,10 @@ public abstract class VectorWriter implements Closeable {
         totalCost += sub.values.cost();
         totalSize += sub.values.size();
       }
+      /* This size includes deleted docs, but when we iterate over docs here (nextDoc())
+       * we skip deleted docs. So we sneakily update this size once we observe that iteration is complete.
+       * That way by the time we are asked to do random access for graph building, we have a correct size.
+       */
       cost = totalCost;
       size = totalSize;
       ordMap = new int[size];
@@ -194,6 +198,9 @@ public abstract class VectorWriter implements Closeable {
       current = docIdMerger.next();
       if (current == null) {
         docId = NO_MORE_DOCS;
+        /* update the size to reflect the number of *non-deleted* documents seen so we can support
+         * random access. */
+        size = ord;
       } else {
         docId = current.mappedDocID;
         ordMap[ord++] = ordBase[current.segmentIndex] + current.count - 1;
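Note on the VectorWriter hunks above: the merged view allocates its ord map for every doc (deleted ones included), and then shrinks the effective size once iteration, which skips deletions, is exhausted. A minimal self-contained sketch of that ord-remapping idea, using hypothetical names rather than Lucene's actual classes:

import java.util.Arrays;

class OrdRemapSketch {
  // liveDocs[oldOrd] == false marks a deleted document's vector.
  // Returns a map from merged ord to ord in the pre-merge ordering, trimmed to
  // the number of live vectors, mirroring the patch's late `size = ord` update.
  static int[] buildOrdMap(boolean[] liveDocs) {
    int[] ordMap = new int[liveDocs.length]; // pessimistic size, like `new int[size]`
    int ord = 0;
    for (int oldOrd = 0; oldOrd < liveDocs.length; oldOrd++) {
      if (liveDocs[oldOrd]) {
        ordMap[ord++] = oldOrd; // deleted docs receive no merged ord
      }
    }
    return Arrays.copyOf(ordMap, ord); // correct size for random access
  }
}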
@@ -210,6 +210,7 @@ public final class Lucene90VectorWriter extends VectorWriter {
     for (int i = 0; i < size; i++) {
       int node = nodes[i];
       assert node > lastNode : "nodes out of order: " + lastNode + "," + node;
+      assert node < offsets.length : "node too large: " + node + ">=" + offsets.length;
       graphData.writeVInt(node - lastNode);
       lastNode = node;
     }
@@ -139,6 +139,7 @@ public final class HnswGraph extends KnnGraphValues {
       graphValues.seek(topCandidateNode);
       int friendOrd;
       while ((friendOrd = graphValues.nextNeighbor()) != NO_MORE_DOCS) {
+        assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size;
         if (visited.get(friendOrd)) {
           continue;
         }
@@ -38,9 +38,9 @@ public final class HnswGraphBuilder {
   // expose for testing.
   public static long randSeed = DEFAULT_RAND_SEED;
 
-  // These "default" hyper-parameter settings are exposed (and non-final) to enable performance
-  // testing
-  // since the indexing API doesn't provide any control over them.
+  /* These "default" hyper-parameter settings are exposed (and non-final) to enable performance
+   * testing since the indexing API doesn't provide any control over them.
+   */
 
   // default max connections per node
   public static int DEFAULT_MAX_CONN = 16;
@@ -116,6 +116,9 @@ public final class HnswGraphBuilder {
       throw new IllegalArgumentException(
           "Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
     }
+    if (infoStream.isEnabled(HNSW_COMPONENT)) {
+      infoStream.message(HNSW_COMPONENT, "build graph from " + vectors.size() + " vectors");
+    }
     long start = System.nanoTime(), t = start;
     // start at node 1! node 0 is added implicitly, in the constructor
     for (int node = 1; node < vectors.size(); node++) {
@@ -149,23 +152,27 @@ public final class HnswGraphBuilder {
 
     int node = hnsw.addNode();
 
-    // connect neighbors to the new node, using a diversity heuristic that chooses successive
-    // nearest neighbors that are closer to the new node than they are to the previously-selected
-    // neighbors
-    addDiverseNeighbors(node, candidates, buildVectors);
+    /* connect neighbors to the new node, using a diversity heuristic that chooses successive
+     * nearest neighbors that are closer to the new node than they are to the previously-selected
+     * neighbors
+     */
+    addDiverseNeighbors(node, candidates);
   }
 
-  private void addDiverseNeighbors(
-      int node, NeighborQueue candidates, RandomAccessVectorValues vectors) throws IOException {
-    // For each of the beamWidth nearest candidates (going from best to worst), select it only if it
-    // is closer to target
-    // than it is to any of the already-selected neighbors (ie selected in this method, since the
-    // node is new and has no
-    // prior neighbors).
+  /* TODO: we are not maintaining nodes in strict score order; the forward links
+   * are added in sorted order, but the reverse implicit ones are not. Diversity heuristic should
+   * work better if we keep the neighbor arrays sorted. Possibly we should switch back to a heap?
+   * But first we should just see if sorting makes a significant difference.
+   */
+  private void addDiverseNeighbors(int node, NeighborQueue candidates) throws IOException {
+    /* For each of the beamWidth nearest candidates (going from best to worst), select it only if it
+     * is closer to target than it is to any of the already-selected neighbors (ie selected in this method,
+     * since the node is new and has no prior neighbors).
+     */
     NeighborArray neighbors = hnsw.getNeighbors(node);
     assert neighbors.size() == 0; // new node
     popToScratch(candidates);
-    selectDiverse(neighbors, scratch, vectors);
+    selectDiverse(neighbors, scratch);
 
     // Link the selected nodes to the new node, and the new node to the selected nodes (again
     // applying diversity heuristic)
@@ -175,21 +182,20 @@ public final class HnswGraphBuilder {
       NeighborArray nbrNbr = hnsw.getNeighbors(nbr);
       nbrNbr.add(node, neighbors.score[i]);
       if (nbrNbr.size() > maxConn) {
-        diversityUpdate(nbrNbr, buildVectors);
+        diversityUpdate(nbrNbr);
       }
     }
   }
 
-  private void selectDiverse(
-      NeighborArray neighbors, NeighborArray candidates, RandomAccessVectorValues vectors)
-      throws IOException {
+  private void selectDiverse(NeighborArray neighbors, NeighborArray candidates) throws IOException {
     // Select the best maxConn neighbors of the new node, applying the diversity heuristic
     for (int i = candidates.size() - 1; neighbors.size() < maxConn && i >= 0; i--) {
       // compare each neighbor (in distance order) against the closer neighbors selected so far,
       // only adding it if it is closer to the target than to any of the other selected neighbors
       int cNode = candidates.node[i];
       float cScore = candidates.score[i];
-      if (diversityCheck(vectors.vectorValue(cNode), cScore, neighbors, buildVectors)) {
+      assert cNode < hnsw.size();
+      if (diversityCheck(vectorValues.vectorValue(cNode), cScore, neighbors, buildVectors)) {
         neighbors.add(cNode, cScore);
       }
     }
@@ -232,10 +238,9 @@ public final class HnswGraphBuilder {
     return true;
   }
 
-  private void diversityUpdate(NeighborArray neighbors, RandomAccessVectorValues vectorValues)
-      throws IOException {
+  private void diversityUpdate(NeighborArray neighbors) throws IOException {
     assert neighbors.size() == maxConn + 1;
-    int replacePoint = findNonDiverse(neighbors, vectorValues);
+    int replacePoint = findNonDiverse(neighbors);
     if (replacePoint == -1) {
       // none found; check score against worst existing neighbor
       bound.set(neighbors.score[0]);
@@ -253,8 +258,7 @@ public final class HnswGraphBuilder {
   }
 
   // scan neighbors looking for diversity violations
-  private int findNonDiverse(NeighborArray neighbors, RandomAccessVectorValues vectorValues)
-      throws IOException {
+  private int findNonDiverse(NeighborArray neighbors) throws IOException {
     for (int i = neighbors.size() - 1; i >= 0; i--) {
       // check each neighbor against its better-scoring neighbors. If it fails diversity check with
       // them, drop it
@@ -263,7 +267,7 @@ public final class HnswGraphBuilder {
       float[] nbrVector = vectorValues.vectorValue(nbrNode);
       for (int j = maxConn; j > i; j--) {
         float diversityCheck =
-            searchStrategy.compare(nbrVector, vectorValues.vectorValue(neighbors.node[j]));
+            searchStrategy.compare(nbrVector, buildVectors.vectorValue(neighbors.node[j]));
         if (bound.check(diversityCheck) == false) {
           // node j is too similar to node i given its score relative to the base node
           // replace it with the new node, which is at [maxConn]
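The HnswGraphBuilder refactoring above threads the vector sources through fields (vectorValues, buildVectors) instead of method parameters; the diversity rule itself is unchanged: keep a candidate only if it scores higher against the new node than against every neighbor already selected. A hedged sketch of that check with a generic similarity function (illustrative names, not Lucene's API):

import java.util.List;
import java.util.function.ToDoubleBiFunction;

class DiversitySketch {
  // candidateScore: similarity of `candidate` to the new node (higher = closer).
  static boolean isDiverse(float[] candidate, double candidateScore,
      List<float[]> selected, ToDoubleBiFunction<float[], float[]> similarity) {
    for (float[] neighbor : selected) {
      // closer to an already-selected neighbor than to the new node: redundant
      if (similarity.applyAsDouble(candidate, neighbor) > candidateScore) {
        return false;
      }
    }
    return true;
  }
}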
@@ -30,11 +30,15 @@ import org.apache.lucene.codecs.Codec;
 import org.apache.lucene.codecs.lucene90.Lucene90VectorReader;
 import org.apache.lucene.document.Document;
 import org.apache.lucene.document.Field;
+import org.apache.lucene.document.SortedDocValuesField;
 import org.apache.lucene.document.StringField;
 import org.apache.lucene.document.VectorField;
 import org.apache.lucene.index.VectorValues.SearchStrategy;
 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.store.Directory;
+import org.apache.lucene.util.ArrayUtil;
+import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.LuceneTestCase;
+import org.apache.lucene.util.VectorUtil;
 import org.apache.lucene.util.hnsw.HnswGraphBuilder;
@@ -48,6 +52,8 @@ public class TestKnnGraph extends LuceneTestCase {
 
   private static int maxConn = HnswGraphBuilder.DEFAULT_MAX_CONN;
 
+  private SearchStrategy searchStrategy;
+
   @Before
   public void setup() {
     randSeed = random().nextLong();
@@ -55,6 +61,8 @@ public class TestKnnGraph extends LuceneTestCase {
       maxConn = HnswGraphBuilder.DEFAULT_MAX_CONN;
       HnswGraphBuilder.DEFAULT_MAX_CONN = random().nextInt(256) + 1;
     }
+    int strategy = random().nextInt(SearchStrategy.values().length - 1) + 1;
+    searchStrategy = SearchStrategy.values()[strategy];
   }
 
   @After
@@ -102,7 +110,7 @@ public class TestKnnGraph extends LuceneTestCase {
         new IndexWriter(dir, newIndexWriterConfig(null).setCodec(Codec.forName("Lucene90")))) {
       int numDoc = atLeast(100);
       int dimension = atLeast(10);
-      float[][] values = new float[numDoc][];
+      float[][] values = randomVectors(numDoc, dimension);
       for (int i = 0; i < numDoc; i++) {
         if (random().nextBoolean()) {
           values[i] = new float[dimension];
@@ -113,7 +121,6 @@ public class TestKnnGraph extends LuceneTestCase {
         }
         add(iw, i, values[i]);
         if (random().nextInt(10) == 3) {
-          // System.out.println("commit @" + i);
           iw.commit();
         }
       }
@@ -124,23 +131,88 @@ public class TestKnnGraph extends LuceneTestCase {
     }
   }
 
-  private void dumpGraph(KnnGraphValues values, int size) throws IOException {
-    for (int node = 0; node < size; node++) {
-      int n;
-      System.out.print("" + node + ":");
-      values.seek(node);
-      while ((n = values.nextNeighbor()) != NO_MORE_DOCS) {
-        System.out.print(" " + n);
-      }
-      System.out.println();
+  /**
+   * Verify that we get the *same* graph by indexing one segment as we do by indexing two segments
+   * and merging.
+   */
+  public void testMergeProducesSameGraph() throws Exception {
+    long seed = random().nextLong();
+    int numDoc = atLeast(100);
+    int dimension = atLeast(10);
+    float[][] values = randomVectors(numDoc, dimension);
+    int mergePoint = random().nextInt(numDoc);
+    int[][] mergedGraph = getIndexedGraph(values, mergePoint, seed);
+    int[][] singleSegmentGraph = getIndexedGraph(values, -1, seed);
+    assertGraphEquals(singleSegmentGraph, mergedGraph);
+  }
+
+  private void assertGraphEquals(int[][] expected, int[][] actual) {
+    assertEquals("graph sizes differ", expected.length, actual.length);
+    for (int i = 0; i < expected.length; i++) {
+      assertArrayEquals("difference at ord=" + i, expected[i], actual[i]);
+    }
+  }
+
+  // TODO: testSorted
+  // TODO: testDeletions
+  private int[][] getIndexedGraph(float[][] values, int mergePoint, long seed) throws IOException {
+    HnswGraphBuilder.randSeed = seed;
+    int[][] graph;
+    try (Directory dir = newDirectory()) {
+      IndexWriterConfig iwc = newIndexWriterConfig();
+      iwc.setMergePolicy(new LogDocMergePolicy()); // for predictable segment ordering when merging
+      iwc.setCodec(Codec.forName("Lucene90")); // don't use SimpleTextCodec
+      try (IndexWriter iw = new IndexWriter(dir, iwc)) {
+        for (int i = 0; i < values.length; i++) {
+          add(iw, i, values[i]);
+          if (i == mergePoint) {
+            // flush proactively to create a segment
+            iw.flush();
+          }
+        }
+        iw.forceMerge(1);
+      }
+      try (IndexReader reader = DirectoryReader.open(dir)) {
+        Lucene90VectorReader vectorReader =
+            ((Lucene90VectorReader) ((CodecReader) getOnlyLeafReader(reader)).getVectorReader());
+        graph = copyGraph(vectorReader.getGraphValues(KNN_GRAPH_FIELD));
+      }
+    }
+    return graph;
+  }
+
+  private float[][] randomVectors(int numDoc, int dimension) {
+    float[][] values = new float[numDoc][];
+    for (int i = 0; i < numDoc; i++) {
+      if (random().nextBoolean()) {
+        values[i] = new float[dimension];
+        for (int j = 0; j < dimension; j++) {
+          values[i][j] = random().nextFloat();
+        }
+        VectorUtil.l2normalize(values[i]);
+      }
+    }
+    return values;
+  }
+
+  int[][] copyGraph(KnnGraphValues values) throws IOException {
+    int size = values.size();
+    int[][] graph = new int[size][];
+    int[] scratch = new int[HnswGraphBuilder.DEFAULT_MAX_CONN];
+    for (int node = 0; node < size; node++) {
+      int n, count = 0;
+      values.seek(node);
+      while ((n = values.nextNeighbor()) != NO_MORE_DOCS) {
+        scratch[count++] = n;
+        // graph[node][i++] = n;
+      }
+      graph[node] = ArrayUtil.copyOfSubArray(scratch, 0, count);
     }
+    return graph;
   }
 
   /** Verify that searching does something reasonable */
   public void testSearch() throws Exception {
+    // We can't use dot product here since the vectors are laid out on a grid, not a sphere.
+    searchStrategy = SearchStrategy.EUCLIDEAN_HNSW;
     try (Directory dir = newDirectory();
         IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig())) {
       // Add a document for every cartesian point in an NxN square so we can
@@ -297,10 +369,7 @@ public class TestKnnGraph extends LuceneTestCase {
           "Graph has " + graphSize + " nodes, but one of them has no neighbors", graphSize > 1);
     }
     if (HnswGraphBuilder.DEFAULT_MAX_CONN > graphSize) {
-      // assert that the graph in each leaf is connected and undirected (ie links are
-      // reciprocated)
-      // We cannot assert this when diversity criterion is applied
-      // assertReciprocal(graph);
+      // assert that the graph in each leaf is connected
       assertConnected(graph);
     } else {
       // assert that max-connections was respected
@@ -330,20 +399,6 @@ public class TestKnnGraph extends LuceneTestCase {
     }
   }
 
-  private void assertReciprocal(int[][] graph) {
-    // The graph is undirected: if a -> b then b -> a.
-    for (int i = 0; i < graph.length; i++) {
-      if (graph[i] != null) {
-        for (int j = 0; j < graph[i].length; j++) {
-          int k = graph[i][j];
-          assertNotNull(graph[k]);
-          assertTrue(
-              "" + i + "->" + k + " is not reciprocated", Arrays.binarySearch(graph[k], i) >= 0);
-        }
-      }
-    }
-  }
-
   private void assertConnected(int[][] graph) {
     // every node in the graph is reachable from every other node
     Set<Integer> visited = new HashSet<>();
@@ -378,13 +433,19 @@ public class TestKnnGraph extends LuceneTestCase {
   }
 
   private void add(IndexWriter iw, int id, float[] vector) throws IOException {
+    add(iw, id, vector, searchStrategy);
+  }
+
+  private void add(IndexWriter iw, int id, float[] vector, SearchStrategy searchStrategy)
+      throws IOException {
     Document doc = new Document();
     if (vector != null) {
-      // TODO: choose random search strategy
-      doc.add(new VectorField(KNN_GRAPH_FIELD, vector, VectorValues.SearchStrategy.EUCLIDEAN_HNSW));
+      doc.add(new VectorField(KNN_GRAPH_FIELD, vector, searchStrategy));
     }
-    doc.add(new StringField("id", Integer.toString(id), Field.Store.YES));
-    // System.out.println("add " + id + " " + Arrays.toString(vector));
-    iw.addDocument(doc);
+    String idString = Integer.toString(id);
+    doc.add(new StringField("id", idString, Field.Store.YES));
+    doc.add(new SortedDocValuesField("id", new BytesRef(idString)));
+    // XSSystem.out.println("add " + idString + " " + Arrays.toString(vector));
+    iw.updateDocument(new Term("id", idString), doc);
   }
 }
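assertConnected, referenced in the TestKnnGraph hunks above, checks that every node in the graph is reachable. Over the int[][] adjacency lists that copyGraph produces, a breadth-first sketch of such a check (an illustration, not the test's exact body) looks like:

import java.util.ArrayDeque;
import java.util.HashSet;
import java.util.Set;

class ConnectivitySketch {
  static boolean isConnected(int[][] graph) {
    Set<Integer> visited = new HashSet<>();
    ArrayDeque<Integer> queue = new ArrayDeque<>();
    queue.add(0);
    visited.add(0);
    while (!queue.isEmpty()) {
      for (int neighbor : graph[queue.poll()]) {
        if (visited.add(neighbor)) { // add() returns false when already visited
          queue.add(neighbor);
        }
      }
    }
    return visited.size() == graph.length; // every node reached from node 0
  }
}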
@@ -20,6 +20,7 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
 
 import java.io.ByteArrayOutputStream;
 import java.io.IOException;
+import java.util.Arrays;
 import org.apache.lucene.codecs.Codec;
 import org.apache.lucene.document.Document;
 import org.apache.lucene.document.Field;
@@ -32,6 +33,7 @@ import org.apache.lucene.search.Sort;
 import org.apache.lucene.search.SortField;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.store.FSDirectory;
+import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.IOUtils;
 import org.apache.lucene.util.LuceneTestCase;
 import org.apache.lucene.util.TestUtil;
@@ -715,9 +717,9 @@ public class TestVectorValues extends LuceneTestCase {
         if (random().nextBoolean() && values[i] != null) {
           // sometimes use a shared scratch array
           System.arraycopy(values[i], 0, scratch, 0, scratch.length);
-          add(iw, fieldName, i, scratch);
+          add(iw, fieldName, i, scratch, SearchStrategy.NONE);
         } else {
-          add(iw, fieldName, i, values[i]);
+          add(iw, fieldName, i, values[i], SearchStrategy.NONE);
         }
         if (random().nextInt(10) == 2) {
           // sometimes delete a random document
@@ -733,7 +735,7 @@ public class TestVectorValues extends LuceneTestCase {
           iw.commit();
         }
       }
-      iw.forceMerge(1);
+      int numDeletes = 0;
       try (IndexReader reader = iw.getReader()) {
         int valueCount = 0, totalSize = 0;
         for (LeafReaderContext ctx : reader.leaves()) {
@@ -748,29 +750,108 @@ public class TestVectorValues extends LuceneTestCase {
             assertEquals(dimension, v.length);
             String idString = ctx.reader().document(docId).getField("id").stringValue();
             int id = Integer.parseInt(idString);
-            assertArrayEquals(idString, values[id], v, 0);
-            ++valueCount;
+            if (ctx.reader().getLiveDocs() == null || ctx.reader().getLiveDocs().get(docId)) {
+              assertArrayEquals(idString, values[id], v, 0);
+              ++valueCount;
+            } else {
+              ++numDeletes;
+              assertNull(values[id]);
+            }
           }
         }
         assertEquals(numValues, valueCount);
-        assertEquals(numValues, totalSize);
+        assertEquals(numValues, totalSize - numDeletes);
       }
     }
   }
 
-  private void add(IndexWriter iw, String field, int id, float[] vector) throws IOException {
-    add(iw, field, id, random().nextInt(100), vector);
+  /**
+   * Index random vectors, sometimes skipping documents, sometimes updating a document, sometimes
+   * merging, sometimes sorting the index, using an HNSW search strategy so as to also produce a
+   * graph, and verify that the expected values can be read back consistently.
+   */
+  public void testRandomWithUpdatesAndGraph() throws Exception {
+    IndexWriterConfig iwc = newIndexWriterConfig();
+    String fieldName = "field";
+    try (Directory dir = newDirectory();
+        IndexWriter iw = new IndexWriter(dir, iwc)) {
+      int numDoc = atLeast(100);
+      int dimension = atLeast(10);
+      float[][] values = new float[numDoc][];
+      float[][] id2value = new float[numDoc][];
+      int[] id2ord = new int[numDoc];
+      for (int i = 0; i < numDoc; i++) {
+        int id = random().nextInt(numDoc);
+        float[] value;
+        if (random().nextInt(7) != 3) {
+          // usually index a vector value for a doc
+          value = randomVector(dimension);
+        } else {
+          value = null;
+        }
+        values[i] = value;
+        id2value[id] = value;
+        id2ord[id] = i;
+        add(iw, fieldName, id, value, SearchStrategy.EUCLIDEAN_HNSW);
+      }
+      try (IndexReader reader = iw.getReader()) {
+        for (LeafReaderContext ctx : reader.leaves()) {
+          Bits liveDocs = ctx.reader().getLiveDocs();
+          VectorValues vectorValues = ctx.reader().getVectorValues(fieldName);
+          if (vectorValues == null) {
+            continue;
+          }
+          int docId;
+          while ((docId = vectorValues.nextDoc()) != NO_MORE_DOCS) {
+            float[] v = vectorValues.vectorValue();
+            assertEquals(dimension, v.length);
+            String idString = ctx.reader().document(docId).getField("id").stringValue();
+            int id = Integer.parseInt(idString);
+            if (liveDocs == null || liveDocs.get(docId)) {
+              assertArrayEquals(
+                  "values differ for id=" + idString + ", docid=" + docId + " leaf=" + ctx.ord,
+                  id2value[id],
+                  v,
+                  0);
+            } else {
+              if (id2value[id] != null) {
+                assertFalse(Arrays.equals(id2value[id], v));
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+
+  private void add(
+      IndexWriter iw, String field, int id, float[] vector, SearchStrategy searchStrategy)
+      throws IOException {
+    add(iw, field, id, random().nextInt(100), vector, searchStrategy);
   }
 
   private void add(IndexWriter iw, String field, int id, int sortkey, float[] vector)
       throws IOException {
+    add(iw, field, id, sortkey, vector, SearchStrategy.NONE);
+  }
+
+  private void add(
+      IndexWriter iw,
+      String field,
+      int id,
+      int sortkey,
+      float[] vector,
+      SearchStrategy searchStrategy)
+      throws IOException {
     Document doc = new Document();
     if (vector != null) {
-      doc.add(new VectorField(field, vector));
+      doc.add(new VectorField(field, vector, searchStrategy));
     }
     doc.add(new NumericDocValuesField("sortkey", sortkey));
-    doc.add(new StringField("id", Integer.toString(id), Field.Store.YES));
-    iw.addDocument(doc);
+    String idString = Integer.toString(id);
+    doc.add(new StringField("id", idString, Field.Store.YES));
+    Term idTerm = new Term("id", idString);
+    iw.updateDocument(idTerm, doc);
   }
 
   private float[] randomVector(int dim) {
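The updated assertions read vectors back per leaf while consulting live docs, since a segment's vector iterator can still surface deleted documents until a merge drops them. The access pattern, isolated into a small helper against the same APIs the test uses (field name illustrative):

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

import java.io.IOException;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.util.Bits;

class LiveVectorSketch {
  // Count vectors belonging to live (non-deleted) documents in one leaf.
  static int countLiveVectors(LeafReader leafReader, String field) throws IOException {
    Bits liveDocs = leafReader.getLiveDocs(); // null means the leaf has no deletions
    VectorValues vectors = leafReader.getVectorValues(field);
    int live = 0;
    if (vectors != null) {
      int docId;
      while ((docId = vectors.nextDoc()) != NO_MORE_DOCS) {
        if (liveDocs == null || liveDocs.get(docId)) {
          live++; // live doc: its vector should match what was indexed
        } // else: a deleted doc's vector, still visible until a merge removes it
      }
    }
    return live;
  }
}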
@@ -81,6 +81,7 @@ public class KnnGraphTester {
   private Path indexPath;
   private boolean quiet;
   private boolean reindex;
+  private boolean forceMerge;
   private int reindexTimeMsec;
 
   @SuppressForbidden(reason = "uses Random()")
@@ -176,7 +177,7 @@ public class KnnGraphTester {
         docVectorsPath = Paths.get(args[++iarg]);
         break;
       case "-forceMerge":
-        operation = "-forceMerge";
+        forceMerge = true;
         break;
       case "-quiet":
         quiet = true;
@@ -195,6 +196,9 @@ public class KnnGraphTester {
        throw new IllegalArgumentException("-docs argument is required when indexing");
      }
      reindexTimeMsec = createIndex(docVectorsPath, indexPath);
+      if (forceMerge) {
+        forceMerge();
+      }
    }
    if (operation != null) {
      switch (operation) {
@@ -208,9 +212,6 @@ public class KnnGraphTester {
            testSearch(indexPath, queryPath, null, getNN(docVectorsPath, queryPath));
          }
          break;
-        case "-forceMerge":
-          forceMerge();
-          break;
        case "-dump":
          dumpGraph(docVectorsPath);
          break;
@@ -405,19 +406,17 @@ public class KnnGraphTester {
     }
     float recall = checkResults(results, nn);
     totalVisited /= numIters;
-    if (quiet) {
-      System.out.printf(
-          Locale.ROOT,
-          "%5.3f\t%5.2f\t%d\t%d\t%d\t%d\t%d\t%d\n",
-          recall,
-          totalCpuTime / (float) numIters,
-          numDocs,
-          fanout,
-          HnswGraphBuilder.DEFAULT_MAX_CONN,
-          HnswGraphBuilder.DEFAULT_BEAM_WIDTH,
-          totalVisited,
-          reindexTimeMsec);
-    }
+    System.out.printf(
+        Locale.ROOT,
+        "%5.3f\t%5.2f\t%d\t%d\t%d\t%d\t%d\t%d\n",
+        recall,
+        totalCpuTime / (float) numIters,
+        numDocs,
+        fanout,
+        HnswGraphBuilder.DEFAULT_MAX_CONN,
+        HnswGraphBuilder.DEFAULT_BEAM_WIDTH,
+        totalVisited,
+        reindexTimeMsec);
   }
 }
 
@@ -444,11 +443,6 @@ public class KnnGraphTester {
       // System.out.println(Arrays.toString(results[i].scoreDocs));
       totalMatches += compareNN(nn[i], results[i]);
     }
-    if (quiet == false) {
-      System.out.println("total matches = " + totalMatches + " out of " + totalResults);
-      System.out.printf(
-          Locale.ROOT, "Average overlap = %.2f%%\n", ((100.0 * totalMatches) / totalResults));
-    }
     return totalMatches / (float) totalResults;
   }
 
@@ -578,6 +572,8 @@ public class KnnGraphTester {
     IndexWriterConfig iwc = new IndexWriterConfig().setOpenMode(IndexWriterConfig.OpenMode.CREATE);
     // iwc.setMergePolicy(NoMergePolicy.INSTANCE);
     iwc.setRAMBufferSizeMB(1994d);
+    // iwc.setMaxBufferedDocs(10000);
+
     if (quiet == false) {
       iwc.setInfoStream(new PrintStreamInfoStream(System.out));
       System.out.println("creating index in " + indexPath);
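With the change above, -forceMerge becomes a flag that triggers a merge right after reindexing rather than a standalone operation. A plausible invocation combining the flags visible in this diff (the vectors file and any flags not shown here are illustrative):

java KnnGraphTester -docs docvectors.bin -reindex -forceMerge -quiet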
@@ -58,7 +58,7 @@ public class TestHnsw extends LuceneTestCase {
     long seed = random().nextLong();
     HnswGraphBuilder.randSeed = seed;
     HnswGraphBuilder builder = new HnswGraphBuilder(vectors);
-    HnswGraph hnsw = builder.build(vectors.randomAccess());
+    HnswGraph hnsw = builder.build(vectors);
     // Recreate the graph while indexing with the same random seed and write it out
     HnswGraphBuilder.randSeed = seed;
     try (Directory dir = newDirectory()) {
@@ -104,9 +104,9 @@ public class TestHnsw extends LuceneTestCase {
   // oriented in the right directions
   public void testAknnDiverse() throws IOException {
     int nDoc = 100;
-    RandomAccessVectorValuesProducer vectors = new CircularVectorValues(nDoc);
+    CircularVectorValues vectors = new CircularVectorValues(nDoc);
     HnswGraphBuilder builder = new HnswGraphBuilder(vectors, 16, 100, random().nextInt());
-    HnswGraph hnsw = builder.build(vectors.randomAccess());
+    HnswGraph hnsw = builder.build(vectors);
     // run some searches
     NeighborQueue nn =
         HnswGraph.search(new float[] {1, 0}, 10, 5, vectors.randomAccess(), hnsw, random());