Make RandomAccessVectorValues an implementation detail of HNSW implementations rather than a proper API. (#11964)

`RandomAccessVectorValues` is internally used in our HNSW implementation to
provide random access to vectors, both at index and search time. In order to
better reflect this, this change does the following:
 - `RandomAccessVectorValues` moves to `org.apache.lucene.util.hnsw`.
 - `BufferingKnnVectorsWriter` no longer has a dependency on
   `RandomAccessVectorValues` and moves to `org.apache.lucene.codecs` since
   it's more of a utility class for KNN vector file formats than an index API.
   Maybe we should think of moving it near each file format that uses it
   instead.
 - `SortingCodecReader` no longer has a dependency on
   `RandomAccessVectorValues`.

Closes #10623
This commit is contained in:
Adrien Grand 2022-12-08 08:49:37 +01:00 committed by GitHub
parent 95df7e8109
commit a971120d05
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 180 additions and 97 deletions

View File

@ -22,10 +22,10 @@ import java.util.Locale;
import java.util.Objects;
import java.util.SplittableRandom;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
/**
* Builder for HNSW graph. See {@link Lucene90OnHeapHnswGraph} for a gloss on the algorithm and the

View File

@ -31,7 +31,6 @@ import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
@ -47,6 +46,7 @@ import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
/**
* Reads vectors from the index segments along with index data structures supporting KNN search.

View File

@ -23,12 +23,12 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.SplittableRandom;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.SparseFixedBitSet;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
/**
* An {@link HnswGraph} where all nodes and connections are held in memory. This class is used to

View File

@ -31,7 +31,6 @@ import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
@ -49,6 +48,7 @@ import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
/**
* Reads vectors from the index segments along with index data structures supporting KNN search.

View File

@ -20,12 +20,12 @@ package org.apache.lucene.backward_codecs.lucene92;
import java.io.IOException;
import java.nio.ByteBuffer;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.RandomAccessInput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.packed.DirectMonotonicReader;
/** Read the vector values from the index input. This supports both iterated and random access. */

View File

@ -34,7 +34,6 @@ import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorValues;
@ -51,6 +50,7 @@ import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.apache.lucene.util.hnsw.NeighborArray;
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.packed.DirectMonotonicWriter;
/**

View File

@ -20,12 +20,12 @@ package org.apache.lucene.backward_codecs.lucene94;
import java.io.IOException;
import java.nio.ByteBuffer;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.RandomAccessInput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.packed.DirectMonotonicReader;
/** Read the vector values from the index input. This supports both iterated and random access. */

View File

@ -21,12 +21,11 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.util.Arrays;
import org.apache.lucene.codecs.BufferingKnnVectorsWriter;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.index.BufferingKnnVectorsWriter;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
@ -35,6 +34,7 @@ import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
/**
* Writes vector values and knn graphs to index segments.

View File

@ -24,7 +24,6 @@ import java.util.Locale;
import java.util.Objects;
import java.util.SplittableRandom;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.FixedBitSet;
@ -32,6 +31,7 @@ import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
/**
* Builder for HNSW graph. See {@link HnswGraph} for a gloss on the algorithm and the meaning of the

View File

@ -21,13 +21,12 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.util.Arrays;
import org.apache.lucene.codecs.BufferingKnnVectorsWriter;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.index.BufferingKnnVectorsWriter;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
@ -37,6 +36,7 @@ import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
/**
* Writes vector values and knn graphs to index segments.

View File

@ -22,14 +22,13 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.util.Arrays;
import org.apache.lucene.codecs.BufferingKnnVectorsWriter;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.index.BufferingKnnVectorsWriter;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
@ -43,6 +42,7 @@ import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.apache.lucene.util.hnsw.NeighborArray;
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.packed.DirectMonotonicWriter;
/**

View File

@ -28,7 +28,6 @@ import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
@ -47,6 +46,7 @@ import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.StringHelper;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
/**
* Reads vector values from a simple text format. All vectors are read up front and cached in RAM in

View File

@ -23,8 +23,8 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.lucene.codecs.BufferingKnnVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.index.BufferingKnnVectorsWriter;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentWriteState;

View File

@ -15,16 +15,18 @@
* limitations under the License.
*/
package org.apache.lucene.index;
package org.apache.lucene.codecs;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.ArrayUtil;
@ -73,13 +75,13 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
@Override
public VectorValues getVectorValues(String field) throws IOException {
VectorValues vectorValues =
BufferedVectorValues vectorValues =
new BufferedVectorValues(
fieldData.docsWithField,
fieldData.vectors,
fieldData.fieldInfo.getVectorDimension());
return sortMap != null
? new VectorValues.SortingVectorValues(vectorValues, sortMap)
? new SortingVectorValues(vectorValues, sortMap)
: vectorValues;
}
@ -94,6 +96,67 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
}
}
/**
 * A {@link VectorValues} view over buffered vectors that visits documents in the order defined by
 * a sort map. Only forward iteration through {@link #nextDoc()} is supported; {@link
 * #advance(int)} always throws.
 */
private static class SortingVectorValues extends VectorValues {
  /** Independent copy of the buffered vectors, addressed by ordinal. */
  private final BufferedVectorValues randomAccess;

  /**
   * Indexed by new (sorted) doc ID. Holds the vector ordinal plus one, so that the default value
   * 0 marks documents that have no vector for this field.
   */
  private final int[] docIdOffsets;

  private int docId = -1;

  SortingVectorValues(BufferedVectorValues delegate, Sorter.DocMap sortMap) throws IOException {
    this.randomAccess = delegate.copy();
    this.docIdOffsets = new int[sortMap.size()];
    // The delegate is consumed in its own order; the n-th document it returns has ordinal n.
    int ordinal = 0;
    for (int oldDoc = delegate.nextDoc(); oldDoc != NO_MORE_DOCS; oldDoc = delegate.nextDoc()) {
      final int sortedDoc = sortMap.oldToNew(oldDoc);
      ordinal++;
      docIdOffsets[sortedDoc] = ordinal;
    }
  }

  @Override
  public int docID() {
    return docId;
  }

  @Override
  public int nextDoc() throws IOException {
    final int lastDoc = docIdOffsets.length - 1;
    while (docId < lastDoc) {
      docId++;
      if (docIdOffsets[docId] == 0) {
        // no vector for this document, keep scanning
        continue;
      }
      return docId;
    }
    docId = NO_MORE_DOCS;
    return docId;
  }

  @Override
  public BytesRef binaryValue() throws IOException {
    final int ordinal = docIdOffsets[docId] - 1;
    return randomAccess.binaryValue(ordinal);
  }

  @Override
  public float[] vectorValue() throws IOException {
    final int ordinal = docIdOffsets[docId] - 1;
    return randomAccess.vectorValue(ordinal);
  }

  @Override
  public int dimension() {
    return randomAccess.dimension();
  }

  @Override
  public int size() {
    return randomAccess.size();
  }

  /** Not supported by this iterator: always throws {@link UnsupportedOperationException}. */
  @Override
  public int advance(int target) throws IOException {
    throw new UnsupportedOperationException();
  }
}
@Override
public long ramBytesUsed() {
long total = 0;
@ -197,8 +260,7 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
}
}
private static class BufferedVectorValues extends VectorValues
implements RandomAccessVectorValues {
private static class BufferedVectorValues extends VectorValues {
final DocsWithFieldSet docsWithField;
@ -225,8 +287,7 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
docsWithFieldIter = docsWithField.iterator();
}
@Override
public RandomAccessVectorValues copy() {
public BufferedVectorValues copy() {
return new BufferedVectorValues(docsWithField, vectors, dimension);
}
@ -246,7 +307,6 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
return binaryValue;
}
@Override
public BytesRef binaryValue(int targetOrd) {
raBuffer.asFloatBuffer().put(vectors.get(targetOrd));
return raBinaryValue;
@ -257,7 +317,6 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
return vectors.get(ord);
}
@Override
public float[] vectorValue(int targetOrd) {
return vectors.get(targetOrd);
}

View File

@ -41,6 +41,7 @@ import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.apache.lucene.util.hnsw.NeighborArray;
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.packed.DirectMonotonicWriter;
/**

View File

@ -20,12 +20,12 @@ package org.apache.lucene.codecs.lucene95;
import java.io.IOException;
import java.nio.ByteBuffer;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.RandomAccessInput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.packed.DirectMonotonicReader;
/** Read the vector values from the index input. This supports both iterated and random access. */

View File

@ -20,6 +20,8 @@ package org.apache.lucene.index;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
@ -35,6 +37,7 @@ import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.IOSupplier;
import org.apache.lucene.util.packed.PackedInts;
@ -212,6 +215,88 @@ public final class SortingCodecReader extends FilterCodecReader {
}
}
/**
 * Sorting VectorValues that iterate over documents in the order of the provided sortMap.
 *
 * <p>The constructor drains the delegate and keeps a private copy of every vector, indexed by the
 * remapped (sorted) doc ID. Depending on the field's encoding, either {@link #binaryVectors} (for
 * {@link VectorEncoding#BYTE}) or {@link #vectors} (any other encoding) is populated; the other
 * array is {@code null}.
 */
private static class SortingVectorValues extends VectorValues {
  final int size;
  final int dimension;
  /** One bit per sorted doc ID; set when that document has a vector. */
  final FixedBitSet docsWithField;
  /** Float vectors by sorted doc ID; {@code null} when the encoding is BYTE. */
  final float[][] vectors;
  /** Scratch buffer reused by {@link #binaryValue()}; {@code null} when the encoding is BYTE. */
  final ByteBuffer vectorAsBytes;
  /** Binary vectors by sorted doc ID; {@code null} unless the encoding is BYTE. */
  final BytesRef[] binaryVectors;

  private int docId = -1;

  SortingVectorValues(VectorValues delegate, Sorter.DocMap sortMap, VectorEncoding encoding)
      throws IOException {
    this.size = delegate.size();
    this.dimension = delegate.dimension();
    docsWithField = new FixedBitSet(sortMap.size());
    if (encoding == VectorEncoding.BYTE) {
      vectors = null;
      binaryVectors = new BytesRef[sortMap.size()];
      vectorAsBytes = null;
    } else {
      vectors = new float[sortMap.size()][];
      binaryVectors = null;
      vectorAsBytes =
          ByteBuffer.allocate(delegate.dimension() * encoding.byteSize)
              .order(ByteOrder.LITTLE_ENDIAN);
    }
    for (int doc = delegate.nextDoc(); doc != NO_MORE_DOCS; doc = delegate.nextDoc()) {
      int newDocID = sortMap.oldToNew(doc);
      docsWithField.set(newDocID);
      // Copy the values: they are retained after the delegate has moved on to the next document.
      if (encoding == VectorEncoding.BYTE) {
        binaryVectors[newDocID] = BytesRef.deepCopyOf(delegate.binaryValue());
      } else {
        vectors[newDocID] = delegate.vectorValue().clone();
      }
    }
  }

  @Override
  public int docID() {
    return docId;
  }

  @Override
  public int nextDoc() throws IOException {
    if (docId == NO_MORE_DOCS) {
      // Already exhausted: docId + 1 would overflow into a negative target.
      return NO_MORE_DOCS;
    }
    return advance(docId + 1);
  }

  /**
   * Returns the current document's vector as bytes. For non-BYTE encodings the vector is written
   * into a scratch buffer shared by all calls, and the returned {@link BytesRef} wraps that
   * buffer's backing array, so its content is overwritten by the next call.
   */
  @Override
  public BytesRef binaryValue() throws IOException {
    if (binaryVectors != null) {
      return binaryVectors[docId];
    } else {
      vectorAsBytes.asFloatBuffer().put(vectors[docId]);
      return new BytesRef(vectorAsBytes.array());
    }
  }

  /**
   * Returns the current document's float vector. {@link #vectors} is {@code null} for BYTE-encoded
   * fields, so this is only usable for the other encodings.
   */
  @Override
  public float[] vectorValue() throws IOException {
    return vectors[docId];
  }

  @Override
  public int dimension() {
    return dimension;
  }

  @Override
  public int size() {
    return size;
  }

  @Override
  public int advance(int target) throws IOException {
    if (target >= docsWithField.length()) {
      // Record exhaustion so that docID() does not keep reporting the last matched document.
      docId = NO_MORE_DOCS;
      return docId;
    }
    return docId = docsWithField.nextSetBit(target);
  }
}
/**
* Return a sorted view of <code>reader</code> according to the order defined by <code>sort</code>
* . If the reader is already sorted, this method might return the reader as-is.
@ -380,7 +465,9 @@ public final class SortingCodecReader extends FilterCodecReader {
@Override
public VectorValues getVectorValues(String field) throws IOException {
return new VectorValues.SortingVectorValues(delegate.getVectorValues(field), docMap);
FieldInfo fi = in.getFieldInfos().fieldInfo(field);
return new SortingVectorValues(
delegate.getVectorValues(field), docMap, fi.getVectorEncoding());
}
@Override

View File

@ -70,65 +70,4 @@ public abstract class VectorValues extends DocIdSetIterator {
public BytesRef binaryValue() throws IOException {
throw new UnsupportedOperationException();
}
/**
 * {@link VectorValues} that walks documents in the order given by a sort map, reading each vector
 * by ordinal from a random-access copy of the delegate. Only {@link #nextDoc()} iteration is
 * supported; {@link #advance(int)} always throws.
 */
public static class SortingVectorValues extends VectorValues {
  private final RandomAccessVectorValues randomAccess;

  /**
   * Indexed by new (sorted) doc ID. Holds the vector ordinal plus one; the default value 0 means
   * there is no vector for this (field, document).
   */
  private final int[] docIdOffsets;

  private int docId = -1;

  SortingVectorValues(VectorValues delegate, Sorter.DocMap sortMap) throws IOException {
    // The delegate must also implement RandomAccessVectorValues; the cast fails otherwise.
    final RandomAccessVectorValues source = (RandomAccessVectorValues) delegate;
    this.randomAccess = source.copy();
    this.docIdOffsets = new int[sortMap.size()];
    int nextOffset = 1;
    for (int oldDoc = delegate.nextDoc(); oldDoc != NO_MORE_DOCS; oldDoc = delegate.nextDoc()) {
      final int sortedDoc = sortMap.oldToNew(oldDoc);
      docIdOffsets[sortedDoc] = nextOffset;
      nextOffset++;
    }
  }

  @Override
  public int docID() {
    return docId;
  }

  @Override
  public int nextDoc() throws IOException {
    final int lastDoc = docIdOffsets.length - 1;
    while (docId < lastDoc) {
      docId++;
      final boolean hasVector = docIdOffsets[docId] != 0;
      if (hasVector) {
        return docId;
      }
    }
    docId = NO_MORE_DOCS;
    return docId;
  }

  @Override
  public BytesRef binaryValue() throws IOException {
    final int ordinal = docIdOffsets[docId] - 1;
    return randomAccess.binaryValue(ordinal);
  }

  @Override
  public float[] vectorValue() throws IOException {
    final int ordinal = docIdOffsets[docId] - 1;
    return randomAccess.vectorValue(ordinal);
  }

  @Override
  public int dimension() {
    return randomAccess.dimension();
  }

  @Override
  public int size() {
    return randomAccess.size();
  }

  /** Not supported by this iterator: always throws {@link UnsupportedOperationException}. */
  @Override
  public int advance(int target) throws IOException {
    throw new UnsupportedOperationException();
  }
}
}

View File

@ -24,7 +24,6 @@ import java.util.Locale;
import java.util.Objects;
import java.util.SplittableRandom;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.BytesRef;

View File

@ -21,7 +21,6 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.util.VectorUtil.toBytesRef;
import java.io.IOException;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.BitSet;

View File

@ -15,13 +15,14 @@
* limitations under the License.
*/
package org.apache.lucene.index;
package org.apache.lucene.util.hnsw;
import java.io.IOException;
import org.apache.lucene.util.BytesRef;
/**
* Provides random access to vectors by dense ordinal.
* Provides random access to vectors by dense ordinal. This interface is used by HNSW-based
* implementations of KNN search.
*
* @lucene.experimental
*/

View File

@ -17,7 +17,6 @@
package org.apache.lucene.util.hnsw;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.BytesRef;

View File

@ -45,7 +45,6 @@ import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;