Refactor and javadoc update for KNN vector writer classes (#13548)

This commit is contained in:
Patrick Zhai 2024-07-08 13:04:27 -07:00 committed by GitHub
parent 3304b60c9c
commit ceb4539609
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 74 additions and 77 deletions

View File

@ -258,7 +258,8 @@ New Features
Improvements
---------------------
(No changes)
* GITHUB#13548: Refactor and javadoc update for KNN vector writer classes. (Patrick Zhai)
Optimizations
---------------------

View File

@ -20,14 +20,18 @@ package org.apache.lucene.codecs;
import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DocIDMerger;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.internal.hppc.IntIntHashMap;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.util.Accountable;
@ -139,6 +143,60 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
}
}
/**
* Given old doc ids and an id mapping, maps old ordinal to new ordinal. Note: this method return
* nothing and output are written to parameters
*
* @param oldDocIds the old or current document ordinals. Must not be null.
* @param sortMap the document sorting map for how to make the new ordinals. Must not be null.
* @param old2NewOrd int[] maps from old ord to new ord
* @param new2OldOrd int[] maps from new ord to old ord
* @param newDocsWithField set of new doc ids which has the value
*/
public static void mapOldOrdToNewOrd(
DocsWithFieldSet oldDocIds,
Sorter.DocMap sortMap,
int[] old2NewOrd,
int[] new2OldOrd,
DocsWithFieldSet newDocsWithField)
throws IOException {
// TODO: a similar function exists in IncrementalHnswGraphMerger#getNewOrdMapping
// maybe we can do a further refactoring
Objects.requireNonNull(oldDocIds);
Objects.requireNonNull(sortMap);
assert (old2NewOrd != null || new2OldOrd != null || newDocsWithField != null);
assert (old2NewOrd == null || old2NewOrd.length == oldDocIds.cardinality());
assert (new2OldOrd == null || new2OldOrd.length == oldDocIds.cardinality());
IntIntHashMap newIdToOldOrd = new IntIntHashMap();
DocIdSetIterator iterator = oldDocIds.iterator();
int[] newDocIds = new int[oldDocIds.cardinality()];
int oldOrd = 0;
for (int oldDocId = iterator.nextDoc();
oldDocId != DocIdSetIterator.NO_MORE_DOCS;
oldDocId = iterator.nextDoc()) {
int newId = sortMap.oldToNew(oldDocId);
newIdToOldOrd.put(newId, oldOrd);
newDocIds[oldOrd] = newId;
oldOrd++;
}
Arrays.sort(newDocIds);
int newOrd = 0;
for (int newDocId : newDocIds) {
int currOldOrd = newIdToOldOrd.get(newDocId);
if (old2NewOrd != null) {
old2NewOrd[currOldOrd] = newOrd;
}
if (new2OldOrd != null) {
new2OldOrd[newOrd] = currOldOrd;
}
if (newDocsWithField != null) {
newDocsWithField.add(newDocId);
}
newOrd++;
}
}
/** View over multiple vector values supporting iterator-style access via DocIdMerger. */
public static final class MergedVectorValues {
private MergedVectorValues() {}

View File

@ -40,7 +40,7 @@ public class OrdToDocDISIReaderConfiguration {
* <p>Within outputMeta the format is as follows:
*
* <ul>
* <li><b>[int8]</b> if equals to -2, empty - no vectory values. If equals to -1, dense all
* <li><b>[int8]</b> if equals to -2, empty - no vector values. If equals to -1, dense all
* documents have values for a field. If equals to 0, sparse some documents missing
* values.
* <li>DocIds were encoded by {@link IndexedDISI#writeBitSet(DocIdSetIterator, IndexOutput,

View File

@ -56,8 +56,8 @@ import org.apache.lucene.store.IndexOutput;
* <li><b>[vlong]</b> length of this field's vectors, in bytes
* <li><b>[vint]</b> dimension of this field's vectors
* <li><b>[int]</b> the number of documents having values for this field
* <li><b>[int8]</b> if equals to -1, dense all documents have values for a field. If equals to
* 0, sparse some documents missing values.
* <li><b>[int8]</b> if equals to -2, empty - no vector values. If equals to -1, dense all
* documents have values for a field. If equals to 0, sparse some documents missing values.
* <li>DocIds were encoded by {@link IndexedDISI#writeBitSet(DocIdSetIterator, IndexOutput, byte)}
* <li>OrdToDoc was encoded by {@link org.apache.lucene.util.packed.DirectMonotonicWriter}, note
* that only in sparse case

View File

@ -44,7 +44,6 @@ import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
@ -191,27 +190,10 @@ public final class Lucene99FlatVectorsWriter extends FlatVectorsWriter {
private void writeSortingField(FieldWriter<?> fieldData, int maxDoc, Sorter.DocMap sortMap)
throws IOException {
final int[] docIdOffsets = new int[sortMap.size()];
int offset = 1; // 0 means no vector for this (field, document)
DocIdSetIterator iterator = fieldData.docsWithField.iterator();
for (int docID = iterator.nextDoc();
docID != DocIdSetIterator.NO_MORE_DOCS;
docID = iterator.nextDoc()) {
int newDocID = sortMap.oldToNew(docID);
docIdOffsets[newDocID] = offset++;
}
final int[] ordMap = new int[fieldData.docsWithField.cardinality()]; // new ord to old ord
DocsWithFieldSet newDocsWithField = new DocsWithFieldSet();
final int[] ordMap = new int[offset - 1]; // new ord to old ord
int ord = 0;
int doc = 0;
for (int docIdOffset : docIdOffsets) {
if (docIdOffset != 0) {
ordMap[ord] = docIdOffset - 1;
newDocsWithField.add(doc);
ord++;
}
doc++;
}
mapOldOrdToNewOrd(fieldData.docsWithField, sortMap, null, ordMap, newDocsWithField);
// write vector values
long vectorDataOffset =

View File

@ -24,14 +24,11 @@ import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.index.MergePolicy;
import org.apache.lucene.index.MergeScheduler;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.hnsw.HnswGraph;
/**
@ -69,11 +66,6 @@ import org.apache.lucene.util.hnsw.HnswGraph;
* <li><b>[vlong]</b> length of this field's index data, in bytes
* <li><b>[vint]</b> dimension of this field's vectors
* <li><b>[int]</b> the number of documents having values for this field
* <li><b>[int8]</b> if equals to -1, dense all documents have values for a field. If equals to
* 0, sparse some documents missing values.
* <li>DocIds were encoded by {@link IndexedDISI#writeBitSet(DocIdSetIterator, IndexOutput, byte)}
* <li>OrdToDoc was encoded by {@link org.apache.lucene.util.packed.DirectMonotonicWriter}, note
* that only in sparse case
* <li><b>[vint]</b> the maximum number of connections (neighbours) that each node can have
* <li><b>[vint]</b> number of levels in the graph
* <li>Graph nodes by level. For each level

View File

@ -196,29 +196,10 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
private void writeSortingField(FieldWriter<?> fieldData, Sorter.DocMap sortMap)
throws IOException {
final int[] docIdOffsets = new int[sortMap.size()];
int offset = 1; // 0 means no vector for this (field, document)
DocIdSetIterator iterator = fieldData.docsWithField.iterator();
for (int docID = iterator.nextDoc();
docID != DocIdSetIterator.NO_MORE_DOCS;
docID = iterator.nextDoc()) {
int newDocID = sortMap.oldToNew(docID);
docIdOffsets[newDocID] = offset++;
}
DocsWithFieldSet newDocsWithField = new DocsWithFieldSet();
final int[] ordMap = new int[offset - 1]; // new ord to old ord
final int[] oldOrdMap = new int[offset - 1]; // old ord to new ord
int ord = 0;
int doc = 0;
for (int docIdOffset : docIdOffsets) {
if (docIdOffset != 0) {
ordMap[ord] = docIdOffset - 1;
oldOrdMap[docIdOffset - 1] = ord;
newDocsWithField.add(doc);
ord++;
}
doc++;
}
final int[] ordMap = new int[fieldData.docsWithField.cardinality()]; // new ord to old ord
final int[] oldOrdMap = new int[fieldData.docsWithField.cardinality()]; // old ord to new ord
mapOldOrdToNewOrd(fieldData.docsWithField, sortMap, oldOrdMap, ordMap, null);
// write graph
long vectorIndexOffset = vectorIndex.getFilePointer();
OnHeapHnswGraph graph = fieldData.getGraph();

View File

@ -399,27 +399,10 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
private void writeSortingField(FieldWriter fieldData, int maxDoc, Sorter.DocMap sortMap)
throws IOException {
final int[] docIdOffsets = new int[sortMap.size()];
int offset = 1; // 0 means no vector for this (field, document)
DocIdSetIterator iterator = fieldData.docsWithField.iterator();
for (int docID = iterator.nextDoc();
docID != DocIdSetIterator.NO_MORE_DOCS;
docID = iterator.nextDoc()) {
int newDocID = sortMap.oldToNew(docID);
docIdOffsets[newDocID] = offset++;
}
final int[] ordMap = new int[fieldData.docsWithField.cardinality()]; // new ord to old ord
DocsWithFieldSet newDocsWithField = new DocsWithFieldSet();
final int[] ordMap = new int[offset - 1]; // new ord to old ord
int ord = 0;
int doc = 0;
for (int docIdOffset : docIdOffsets) {
if (docIdOffset != 0) {
ordMap[ord] = docIdOffset - 1;
newDocsWithField.add(doc);
ord++;
}
doc++;
}
mapOldOrdToNewOrd(fieldData.docsWithField, sortMap, null, ordMap, newDocsWithField);
// write vector values
long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES);

View File

@ -356,7 +356,7 @@ final class FieldUpdatesBuffer {
}
}
BytesRef nextTerm() throws IOException {
private BytesRef nextTerm() throws IOException {
if (lookAheadTermIterator != null) {
if (bufferedUpdate.termValue == null) {
lookAheadTermIterator.next();