LUCENE-10577: enable quantization of HNSW vectors to 8 bits (#1054)

* LUCENE-10577: enable supplying, storing, and comparing HNSW vectors with 8 bit precision
Michael Sokolov 2022-08-10 17:09:07 -04:00 committed by GitHub
parent 59a0917e25
commit a693fe819b
75 changed files with 2230 additions and 488 deletions
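In short, a vector field can now be indexed with VectorEncoding.BYTE, storing each dimension as one signed byte instead of a 32-bit float, which cuts vector storage to a quarter. As a rough illustration of the idea (this helper is not part of the commit, and the assumption that inputs lie in [-1, 1] is ours), an application could quantize its float vectors before indexing like so:

import org.apache.lucene.util.BytesRef;

// Hypothetical helper, not in this patch: linearly quantize a float vector whose
// components lie in [-1, 1] into the signed-byte range used by VectorEncoding.BYTE.
final class ByteQuantizer {
  static BytesRef quantize(float[] vector) {
    byte[] out = new byte[vector.length];
    for (int i = 0; i < vector.length; i++) {
      float v = Math.max(-1f, Math.min(1f, vector[i])); // clamp to [-1, 1]
      out[i] = (byte) Math.round(v * 127f); // scale into [-127, 127]
    }
    return new BytesRef(out);
  }
}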

View File

@ -30,6 +30,7 @@ import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.SegmentInfo;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataOutput;
@ -214,6 +215,7 @@ public final class Lucene60FieldInfosFormat extends FieldInfosFormat {
pointIndexDimensionCount,
pointNumBytes,
0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN,
isSoftDeletesField);
} catch (IllegalStateException e) {

View File

@ -32,7 +32,6 @@ import org.apache.lucene.codecs.StoredFieldsFormat;
import org.apache.lucene.codecs.TermVectorsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90CompoundFormat;
import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
import org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat;
import org.apache.lucene.codecs.lucene90.Lucene90LiveDocsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90NormsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90PointsFormat;

View File

@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene90;
package org.apache.lucene.backward_codecs.lucene90;
import java.io.IOException;
import java.util.Collections;
@ -29,6 +29,7 @@ import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.SegmentInfo;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataOutput;
@ -191,6 +192,7 @@ public final class Lucene90FieldInfosFormat extends FieldInfosFormat {
pointIndexDimensionCount,
pointNumBytes,
vectorDimension,
VectorEncoding.FLOAT32,
vectorDistFunc,
isSoftDeletesField);
infos[i].checkConsistency();

View File

@ -18,6 +18,7 @@
package org.apache.lucene.backward_codecs.lucene90;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.search.TopDocsCollector.EMPTY_TOPDOCS;
import java.io.IOException;
import java.nio.ByteBuffer;
@ -36,6 +37,7 @@ import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
@ -277,6 +279,21 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
}
@Override
public TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null) {
// The field does not exist or does not index vectors
return EMPTY_TOPDOCS;
}
VectorSimilarityFunction similarityFunction = fieldEntry.similarityFunction;
VectorValues vectorValues = getVectorValues(field);
return exhaustiveSearch(vectorValues, acceptDocs, similarityFunction, target, k);
}
private OffHeapVectorValues getOffHeapVectorValues(FieldEntry fieldEntry) throws IOException {
IndexInput bytesSlice =
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);

View File

@ -17,6 +17,7 @@
package org.apache.lucene.backward_codecs.lucene91;
import java.util.Objects;
import org.apache.lucene.backward_codecs.lucene90.Lucene90FieldInfosFormat;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.CompoundFormat;
import org.apache.lucene.codecs.DocValuesFormat;
@ -32,7 +33,6 @@ import org.apache.lucene.codecs.StoredFieldsFormat;
import org.apache.lucene.codecs.TermVectorsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90CompoundFormat;
import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
import org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat;
import org.apache.lucene.codecs.lucene90.Lucene90LiveDocsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90NormsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90PointsFormat;

View File

@ -25,6 +25,7 @@ import java.util.Objects;
import java.util.SplittableRandom;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.InfoStream;
@ -55,7 +56,7 @@ public final class Lucene91HnswGraphBuilder {
private final RandomAccessVectorValues vectorValues;
private final SplittableRandom random;
private final Lucene91BoundsChecker bound;
private final HnswGraphSearcher graphSearcher;
private final HnswGraphSearcher<float[]> graphSearcher;
final Lucene91OnHeapHnswGraph hnsw;
@ -101,7 +102,8 @@ public final class Lucene91HnswGraphBuilder {
int levelOfFirstNode = getRandomGraphLevel(ml, random);
this.hnsw = new Lucene91OnHeapHnswGraph(maxConn, levelOfFirstNode);
this.graphSearcher =
new HnswGraphSearcher(
new HnswGraphSearcher<>(
VectorEncoding.FLOAT32,
similarityFunction,
new NeighborQueue(beamWidth, true),
new FixedBitSet(vectorValues.size()));

View File

@ -18,6 +18,7 @@
package org.apache.lucene.backward_codecs.lucene91;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.search.TopDocsCollector.EMPTY_TOPDOCS;
import java.io.IOException;
import java.nio.ByteBuffer;
@ -34,8 +35,10 @@ import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
@ -244,6 +247,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
target,
k,
vectorValues,
VectorEncoding.FLOAT32,
fieldEntry.similarityFunction,
getGraph(fieldEntry),
getAcceptOrds(acceptDocs, fieldEntry),
@ -265,6 +269,21 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
}
@Override
public TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null) {
// The field does not exist or does not index vectors
return EMPTY_TOPDOCS;
}
VectorSimilarityFunction similarityFunction = fieldEntry.similarityFunction;
VectorValues vectorValues = getVectorValues(field);
return exhaustiveSearch(vectorValues, acceptDocs, similarityFunction, target, k);
}
private OffHeapVectorValues getOffHeapVectorValues(FieldEntry fieldEntry) throws IOException {
IndexInput bytesSlice =
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);

View File

@ -144,8 +144,8 @@
* <li>{@link org.apache.lucene.codecs.lucene90.Lucene90SegmentInfoFormat Segment info}. This
* contains metadata about a segment, such as the number of documents, what files it uses, and
* information about how the segment is sorted
* <li>{@link org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat Field names}. This
* contains metadata about the set of named fields used in the index.
* <li>{@link org.apache.lucene.backward_codecs.lucene90.Lucene90FieldInfosFormat Field names}.
* This contains metadata about the set of named fields used in the index.
* <li>{@link org.apache.lucene.codecs.lucene90.Lucene90StoredFieldsFormat Stored Field values}.
* This contains, for each document, a list of attribute-value pairs, where the attributes are
* field names. These are used to store auxiliary information about the document, such as its
@ -240,7 +240,7 @@
* systems that frequently run out of file handles.</td>
* </tr>
* <tr>
* <td>{@link org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat Fields}</td>
* <td>{@link org.apache.lucene.backward_codecs.lucene90.Lucene90FieldInfosFormat Fields}</td>
* <td>.fnm</td>
* <td>Stores information about the fields</td>
* </tr>

View File

@ -17,6 +17,7 @@
package org.apache.lucene.backward_codecs.lucene92;
import java.util.Objects;
import org.apache.lucene.backward_codecs.lucene90.Lucene90FieldInfosFormat;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.CompoundFormat;
import org.apache.lucene.codecs.DocValuesFormat;
@ -32,7 +33,6 @@ import org.apache.lucene.codecs.StoredFieldsFormat;
import org.apache.lucene.codecs.TermVectorsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90CompoundFormat;
import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
import org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat;
import org.apache.lucene.codecs.lucene90.Lucene90LiveDocsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90NormsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90PointsFormat;

View File

@ -18,6 +18,7 @@
package org.apache.lucene.backward_codecs.lucene92;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.search.TopDocsCollector.EMPTY_TOPDOCS;
import java.io.IOException;
import java.util.Arrays;
@ -30,8 +31,10 @@ import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
@ -237,6 +240,7 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
target,
k,
vectorValues,
VectorEncoding.FLOAT32,
fieldEntry.similarityFunction,
getGraph(fieldEntry),
vectorValues.getAcceptOrds(acceptDocs),
@ -258,6 +262,21 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
}
@Override
public TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null) {
// The field does not exist or does not index vectors
return EMPTY_TOPDOCS;
}
VectorSimilarityFunction similarityFunction = fieldEntry.similarityFunction;
VectorValues vectorValues = getVectorValues(field);
return exhaustiveSearch(vectorValues, acceptDocs, similarityFunction, target, k);
}
/** Get knn graph values; used for testing */
public HnswGraph getGraph(String field) throws IOException {
FieldInfo info = fieldInfos.fieldInfo(field);

View File

@ -144,8 +144,8 @@
* <li>{@link org.apache.lucene.codecs.lucene90.Lucene90SegmentInfoFormat Segment info}. This
* contains metadata about a segment, such as the number of documents, what files it uses, and
* information about how the segment is sorted
* <li>{@link org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat Field names}. This
* contains metadata about the set of named fields used in the index.
* <li>{@link org.apache.lucene.backward_codecs.lucene90.Lucene90FieldInfosFormat Field names}.
* This contains metadata about the set of named fields used in the index.
* <li>{@link org.apache.lucene.codecs.lucene90.Lucene90StoredFieldsFormat Stored Field values}.
* This contains, for each document, a list of attribute-value pairs, where the attributes are
* field names. These are used to store auxiliary information about the document, such as its
@ -240,7 +240,7 @@
* systems that frequently run out of file handles.</td>
* </tr>
* <tr>
* <td>{@link org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat Fields}</td>
* <td>{@link org.apache.lucene.backward_codecs.lucene90.Lucene90FieldInfosFormat Fields}</td>
* <td>.fnm</td>
* <td>Stores information about the fields</td>
* </tr>

View File

@ -31,6 +31,7 @@ import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
@ -148,7 +149,8 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
OnHeapHnswGraph graph =
offHeapVectors.size() == 0
? null
: writeGraph(offHeapVectors, fieldInfo.getVectorSimilarityFunction());
: writeGraph(
offHeapVectors, VectorEncoding.FLOAT32, fieldInfo.getVectorSimilarityFunction());
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
writeMeta(
fieldInfo,
@ -266,13 +268,20 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
}
private OnHeapHnswGraph writeGraph(
RandomAccessVectorValuesProducer vectorValues, VectorSimilarityFunction similarityFunction)
RandomAccessVectorValuesProducer vectorValues,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction)
throws IOException {
// build graph
HnswGraphBuilder hnswGraphBuilder =
new HnswGraphBuilder(
vectorValues, similarityFunction, M, beamWidth, HnswGraphBuilder.randSeed);
HnswGraphBuilder<?> hnswGraphBuilder =
HnswGraphBuilder.create(
vectorValues,
vectorEncoding,
similarityFunction,
M,
beamWidth,
HnswGraphBuilder.randSeed);
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());

View File

@ -28,6 +28,7 @@ import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.SegmentInfo;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.Directory;
@ -68,7 +69,8 @@ public class SimpleTextFieldInfosFormat extends FieldInfosFormat {
static final BytesRef INDEX_DIM_COUNT = new BytesRef(" index dimensional count ");
static final BytesRef DIM_NUM_BYTES = new BytesRef(" dimensional num bytes ");
static final BytesRef VECTOR_NUM_DIMS = new BytesRef(" vector number of dimensions ");
static final BytesRef VECTOR_SEARCH_STRATEGY = new BytesRef(" vector search strategy ");
static final BytesRef VECTOR_ENCODING = new BytesRef(" vector encoding ");
static final BytesRef VECTOR_SIMILARITY = new BytesRef(" vector similarity ");
static final BytesRef SOFT_DELETES = new BytesRef(" soft-deletes ");
@Override
@ -156,8 +158,13 @@ public class SimpleTextFieldInfosFormat extends FieldInfosFormat {
int vectorNumDimensions = Integer.parseInt(readString(VECTOR_NUM_DIMS.length, scratch));
SimpleTextUtil.readLine(input, scratch);
assert StringHelper.startsWith(scratch.get(), VECTOR_SEARCH_STRATEGY);
String scoreFunction = readString(VECTOR_SEARCH_STRATEGY.length, scratch);
assert StringHelper.startsWith(scratch.get(), VECTOR_ENCODING);
String encoding = readString(VECTOR_ENCODING.length, scratch);
VectorEncoding vectorEncoding = vectorEncoding(encoding);
SimpleTextUtil.readLine(input, scratch);
assert StringHelper.startsWith(scratch.get(), VECTOR_SIMILARITY);
String scoreFunction = readString(VECTOR_SIMILARITY.length, scratch);
VectorSimilarityFunction vectorDistFunc = distanceFunction(scoreFunction);
SimpleTextUtil.readLine(input, scratch);
@ -179,6 +186,7 @@ public class SimpleTextFieldInfosFormat extends FieldInfosFormat {
indexDimensionalCount,
dimensionalNumBytes,
vectorNumDimensions,
vectorEncoding,
vectorDistFunc,
isSoftDeletesField);
}
@ -201,6 +209,10 @@ public class SimpleTextFieldInfosFormat extends FieldInfosFormat {
return DocValuesType.valueOf(dvType);
}
public VectorEncoding vectorEncoding(String vectorEncoding) {
return VectorEncoding.valueOf(vectorEncoding);
}
public VectorSimilarityFunction distanceFunction(String scoreFunction) {
return VectorSimilarityFunction.valueOf(scoreFunction);
}
@ -297,7 +309,11 @@ public class SimpleTextFieldInfosFormat extends FieldInfosFormat {
SimpleTextUtil.write(out, Integer.toString(fi.getVectorDimension()), scratch);
SimpleTextUtil.writeNewline(out);
SimpleTextUtil.write(out, VECTOR_SEARCH_STRATEGY);
SimpleTextUtil.write(out, VECTOR_ENCODING);
SimpleTextUtil.write(out, fi.getVectorEncoding().name(), scratch);
SimpleTextUtil.writeNewline(out);
SimpleTextUtil.write(out, VECTOR_SIMILARITY);
SimpleTextUtil.write(out, fi.getVectorSimilarityFunction().name(), scratch);
SimpleTextUtil.writeNewline(out);

View File

@ -42,6 +42,7 @@ import org.apache.lucene.store.BufferedChecksumIndexInput;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
@ -181,6 +182,13 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
return new TopDocs(new TotalHits(numVisited, relation), topScoreDocs);
}
@Override
public TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
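// materialize the filter iterator into a BitSet, then delegate to the existing
// search path with an effectively unlimited visit budget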
int numDocs = (int) acceptDocs.cost();
return search(field, target, k, BitSet.of(acceptDocs, numDocs), Integer.MAX_VALUE);
}
@Override
public void checkIntegrity() throws IOException {
IndexInput clone = dataIn.clone();

View File

@ -23,6 +23,7 @@ import org.apache.lucene.codecs.lucene90.tests.MockTermStateFactory;
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.ByteBuffersDataOutput;
import org.apache.lucene.store.ByteBuffersIndexOutput;
@ -116,6 +117,7 @@ public class TestBlockWriter extends LuceneTestCase {
0,
0,
0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN,
true);
}

View File

@ -41,6 +41,7 @@ import org.apache.lucene.index.ImpactsEnum;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.ByteBuffersDirectory;
import org.apache.lucene.store.DataInput;
@ -203,6 +204,7 @@ public class TestSTBlockReader extends LuceneTestCase {
0,
0,
0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN,
false);
}

View File

@ -20,8 +20,12 @@ package org.apache.lucene.codecs;
import java.io.IOException;
import org.apache.lucene.util.Accountable;
/** Vectors' writer for a field */
public abstract class KnnFieldVectorsWriter implements Accountable {
/**
* Vectors' writer for a field
*
* @param <T> an array type; the type of vectors to be written
*/
public abstract class KnnFieldVectorsWriter<T> implements Accountable {
/** Sole constructor */
protected KnnFieldVectorsWriter() {}
@ -30,5 +34,13 @@ public abstract class KnnFieldVectorsWriter implements Accountable {
* Add new docID with its vector value to the given field for indexing. Doc IDs must be added in
* increasing order.
*/
public abstract void addValue(int docID, float[] vectorValue) throws IOException;
public abstract void addValue(int docID, Object vectorValue) throws IOException;
/**
* Used to copy values being indexed to internal storage.
*
* @param vectorValue an array containing the vector value to add
* @return a copy of the value; a new array
*/
public abstract T copyValue(T vectorValue);
}
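For context: the new type parameter lets each vector encoding buffer values in its native representation (float[] for FLOAT32, BytesRef for BYTE), and copyValue exists because callers may reuse the array they pass to addValue. A minimal sketch of a float[] implementation, assuming a simple on-heap buffer (the class itself is illustrative, not part of this patch):

import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.util.ArrayUtil;

// Illustrative buffering writer: copies each incoming vector because the caller
// may reuse the same float[] instance across addValue calls.
class BufferingFloatVectorsWriter extends KnnFieldVectorsWriter<float[]> {
  private final List<float[]> buffered = new ArrayList<>();

  @Override
  public void addValue(int docID, Object vectorValue) {
    buffered.add(copyValue((float[]) vectorValue));
  }

  @Override
  public float[] copyValue(float[] vectorValue) {
    return ArrayUtil.copyOfSubArray(vectorValue, 0, vectorValue.length);
  }

  @Override
  public long ramBytesUsed() {
    long bytes = 0;
    for (float[] v : buffered) {
      bytes += (long) v.length * Float.BYTES;
    }
    return bytes;
  }
}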

View File

@ -18,9 +18,11 @@
package org.apache.lucene.codecs;
import java.io.IOException;
import org.apache.lucene.codecs.lucene94.Lucene94HnswVectorsFormat;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopDocsCollector;
import org.apache.lucene.util.Bits;
@ -76,6 +78,15 @@ public abstract class KnnVectorsFormat implements NamedSPILoader.NamedSPI {
/** Returns a {@link KnnVectorsReader} to read the vectors from the index. */
public abstract KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException;
/**
* Returns the current KnnVectorsFormat version number. Indexes written using the format will be
* "stamped" with this version.
*/
public int currentVersion() {
// return the version supported by older codecs that did not override this method
return Lucene94HnswVectorsFormat.VERSION_START;
}
/**
* EMPTY throws an exception when written. It acts as a sentinel indicating a Codec that does not
* support vectors.
@ -104,6 +115,12 @@ public abstract class KnnVectorsFormat implements NamedSPILoader.NamedSPI {
return TopDocsCollector.EMPTY_TOPDOCS;
}
@Override
public TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
return TopDocsCollector.EMPTY_TOPDOCS;
}
@Override
public void close() {}

View File

@ -20,12 +20,16 @@ package org.apache.lucene.codecs;
import java.io.Closeable;
import java.io.IOException;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.HitQueue;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
/** Reads vectors from an index. */
public abstract class KnnVectorsReader implements Closeable, Accountable {
@ -75,11 +79,39 @@ public abstract class KnnVectorsReader implements Closeable, Accountable {
* @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null}
* if they are all allowed to match.
* @param visitedLimit the maximum number of nodes that the search is allowed to visit
* @return the k nearest neighbor documents, along with their (searchStrategy-specific) scores.
* @return the k nearest neighbor documents, along with their (similarity-specific) scores.
*/
public abstract TopDocs search(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException;
/**
* Return the k nearest neighbor documents as determined by comparison of their vector values for
* this field, to the given vector, by the field's similarity function. The score of each document
* is derived from the vector similarity in a way that ensures scores are positive and that a
* larger score corresponds to a higher ranking.
*
* <p>The search is exact, guaranteeing the true k closest neighbors will be returned. Typically
* this requires an exhaustive scan of the entire index. It is intended to be used when the number
* of potential matches is limited.
*
* <p>The returned {@link TopDocs} will contain a {@link ScoreDoc} for each nearest neighbor, in
* order of their similarity to the query vector (decreasing scores). The {@link TotalHits}
* contains the number of documents visited during the search; since this method takes no visit
* limit, the relation is always {@code TotalHits.Relation.EQUAL_TO}.
*
* <p>The behavior is undefined if the given field doesn't have KNN vectors enabled on its {@link
* FieldInfo}. The return value is never {@code null}.
*
* @param field the vector field to search
* @param target the vector-valued query
* @param k the number of docs to return
* @param acceptDocs {@link DocIdSetIterator} that represents the allowed documents to match.
* @return the k nearest neighbor documents, along with their (similarity-specific) scores.
*/
public abstract TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException;
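A sketch of the calling convention, assuming the caller already holds a per-segment KnnVectorsReader, a query vector, and a DocIdSetIterator over the filtered documents (reader, queryVector, and filterIterator are all assumed names):

// Exact kNN over the filtered documents; every accepted document is visited and scored.
TopDocs exact = reader.searchExhaustively("my_vector_field", queryVector, 10, filterIterator);
for (ScoreDoc sd : exact.scoreDocs) {
  System.out.println("doc=" + sd.doc + " score=" + sd.score);
}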
/**
* Returns an instance optimized for merging. This instance may only be consumed in the thread
* that called {@link #getMergeInstance()}.
@ -89,4 +121,67 @@ public abstract class KnnVectorsReader implements Closeable, Accountable {
public KnnVectorsReader getMergeInstance() {
return this;
}
/** Brute-force implementation of {@link #searchExhaustively} for float-encoded vectors. */
protected static TopDocs exhaustiveSearch(
VectorValues vectorValues,
DocIdSetIterator acceptDocs,
VectorSimilarityFunction similarityFunction,
float[] target,
int k)
throws IOException {
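// the queue is pre-filled with k sentinel hits (score = -Infinity), so top()
// always holds the current score cut-off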
HitQueue queue = new HitQueue(k, true);
ScoreDoc topDoc = queue.top();
int doc;
while ((doc = acceptDocs.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
int vectorDoc = vectorValues.advance(doc);
assert vectorDoc == doc;
float score = similarityFunction.compare(vectorValues.vectorValue(), target);
if (score >= topDoc.score) {
topDoc.score = score;
topDoc.doc = doc;
topDoc = queue.updateTop();
}
}
return topDocsFromHitQueue(queue, acceptDocs.cost());
}
/** Brute-force implementation of {@link #searchExhaustively} for byte-encoded (binary) vectors. */
protected static TopDocs exhaustiveSearch(
VectorValues vectorValues,
DocIdSetIterator acceptDocs,
VectorSimilarityFunction similarityFunction,
BytesRef target,
int k)
throws IOException {
HitQueue queue = new HitQueue(k, true);
ScoreDoc topDoc = queue.top();
int doc;
while ((doc = acceptDocs.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
int vectorDoc = vectorValues.advance(doc);
assert vectorDoc == doc;
float score = similarityFunction.compare(vectorValues.binaryValue(), target);
if (score >= topDoc.score) {
topDoc.score = score;
topDoc.doc = doc;
topDoc = queue.updateTop();
}
}
return topDocsFromHitQueue(queue, acceptDocs.cost());
}
private static TopDocs topDocsFromHitQueue(HitQueue queue, long numHits) {
// Remove any remaining sentinel values
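// (sentinels carry score -Infinity, and real similarity scores are never negative)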
while (queue.size() > 0 && queue.top().score < 0) {
queue.pop();
}
ScoreDoc[] topScoreDocs = new ScoreDoc[queue.size()];
for (int i = topScoreDocs.length - 1; i >= 0; i--) {
topScoreDocs[i] = queue.pop();
}
TotalHits totalHits = new TotalHits(numHits, TotalHits.Relation.EQUAL_TO);
return new TopDocs(totalHits, topScoreDocs);
}
}

View File

@ -37,14 +37,15 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
protected KnnVectorsWriter() {}
/** Add new field for indexing */
public abstract KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException;
public abstract KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException;
/** Flush all buffered data on disk * */
public abstract void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException;
/** Write field for merging */
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
KnnFieldVectorsWriter writer = addField(fieldInfo);
@SuppressWarnings("unchecked")
public <T> void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
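// the unchecked cast is safe as long as addField returns a writer whose value type
// matches the field's vector encoding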
KnnFieldVectorsWriter<T> writer = (KnnFieldVectorsWriter<T>) addField(fieldInfo);
VectorValues mergedValues = MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
for (int doc = mergedValues.nextDoc();
doc != DocIdSetIterator.NO_MORE_DOCS;

View File

@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene94;
import java.io.IOException;
import org.apache.lucene.index.FilterVectorValues;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.util.BytesRef;
/** A {@link VectorValues} view that reads byte-encoded data, expanding each signed byte to a float. */
public class ExpandingVectorValues extends FilterVectorValues {
private final float[] value;
/** @param in the wrapped values */
protected ExpandingVectorValues(VectorValues in) {
super(in);
value = new float[in.dimension()];
}
@Override
public float[] vectorValue() throws IOException {
BytesRef binaryValue = binaryValue();
byte[] bytes = binaryValue.bytes;
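// widen each stored signed byte back to a float sample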
for (int i = 0, j = binaryValue.offset; i < value.length; i++, j++) {
value[i] = bytes[j];
}
return value;
}
}

View File

@ -32,7 +32,6 @@ import org.apache.lucene.codecs.StoredFieldsFormat;
import org.apache.lucene.codecs.TermVectorsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90CompoundFormat;
import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
import org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat;
import org.apache.lucene.codecs.lucene90.Lucene90LiveDocsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90NormsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90PointsFormat;
@ -69,7 +68,7 @@ public class Lucene94Codec extends Codec {
}
private final TermVectorsFormat vectorsFormat = new Lucene90TermVectorsFormat();
private final FieldInfosFormat fieldInfosFormat = new Lucene90FieldInfosFormat();
private final FieldInfosFormat fieldInfosFormat = new Lucene94FieldInfosFormat();
private final SegmentInfoFormat segmentInfosFormat = new Lucene90SegmentInfoFormat();
private final LiveDocsFormat liveDocsFormat = new Lucene90LiveDocsFormat();
private final CompoundFormat compoundFormat = new Lucene90CompoundFormat();
@ -100,6 +99,11 @@ public class Lucene94Codec extends Codec {
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return Lucene94Codec.this.getKnnVectorsFormatForField(field);
}
@Override
public int currentVersion() {
return Lucene94HnswVectorsFormat.VERSION_CURRENT;
}
};
private final StoredFieldsFormat storedFieldsFormat;

View File

@ -0,0 +1,385 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene94;
import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.DocValuesFormat;
import org.apache.lucene.codecs.FieldInfosFormat;
import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.SegmentInfo;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataOutput;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
/**
* Lucene 9.4 Field Infos format.
*
* <p>Field names are stored in the field info file, with suffix <code>.fnm</code>.
*
* <p>FieldInfos (.fnm) --&gt; Header,FieldsCount, &lt;FieldName,FieldNumber,
* FieldBits,DocValuesBits,DocValuesGen,Attributes,DimensionCount,DimensionNumBytes&gt;
* <sup>FieldsCount</sup>,Footer
*
* <p>Data types:
*
* <ul>
* <li>Header --&gt; {@link CodecUtil#checkIndexHeader IndexHeader}
* <li>FieldsCount --&gt; {@link DataOutput#writeVInt VInt}
* <li>FieldName --&gt; {@link DataOutput#writeString String}
* <li>FieldBits, IndexOptions, DocValuesBits --&gt; {@link DataOutput#writeByte Byte}
* <li>FieldNumber, DimensionCount, DimensionNumBytes --&gt; {@link DataOutput#writeInt VInt}
* <li>Attributes --&gt; {@link DataOutput#writeMapOfStrings Map&lt;String,String&gt;}
* <li>DocValuesGen --&gt; {@link DataOutput#writeLong(long) Int64}
* <li>Footer --&gt; {@link CodecUtil#writeFooter CodecFooter}
* </ul>
*
* Field Descriptions:
*
* <ul>
* <li>FieldsCount: the number of fields in this file.
* <li>FieldName: name of the field as a UTF-8 String.
* <li>FieldNumber: the field's number. Note that unlike previous versions of Lucene, the fields
* are not numbered implicitly by their order in the file; each field's number is stored
* explicitly.
* <li>FieldBits: a byte containing field options.
* <ul>
* <li>The low order bit (0x1) is one for fields that have term vectors stored, and zero for
* fields without term vectors.
* <li>If the second lowest order-bit is set (0x2), norms are omitted for the indexed field.
* <li>If the third lowest-order bit is set (0x4), payloads are stored for the indexed
* field.
* </ul>
* <li>IndexOptions: a byte containing index options.
* <ul>
* <li>0: not indexed
* <li>1: indexed as DOCS_ONLY
* <li>2: indexed as DOCS_AND_FREQS
* <li>3: indexed as DOCS_AND_FREQS_AND_POSITIONS
* <li>4: indexed as DOCS_AND_FREQS_AND_POSITIONS_AND_OFFSETS
* </ul>
* <li>DocValuesBits: a byte containing per-document value types. The type is recorded as two
* four-bit integers, with the high-order bits representing <code>norms</code> options, and
* the low-order bits representing {@code DocValues} options. Each four-bit integer can be
* decoded as follows:
* <ul>
* <li>0: no DocValues for this field.
* <li>1: NumericDocValues. ({@link DocValuesType#NUMERIC})
* <li>2: BinaryDocValues. ({@code DocValuesType#BINARY})
* <li>3: SortedDocValues. ({@code DocValuesType#SORTED})
* </ul>
* <li>DocValuesGen is the generation count of the field's DocValues. If this is -1, there are no
* DocValues updates to that field. Anything above zero means there are updates stored by
* {@link DocValuesFormat}.
* <li>Attributes: a key-value map of codec-private attributes.
* <li>PointDimensionCount, PointNumBytes: these are non-zero only if the field is indexed as
* points, e.g. using {@link org.apache.lucene.document.LongPoint}
* <li>VectorDimension: this is non-zero only if the field is indexed as vectors.
* <li>VectorEncoding: a byte containing the encoding of vector values:
* <ul>
* <li>0: BYTE. Samples are stored as signed bytes.
* <li>1: FLOAT32. Samples are stored in IEEE 32-bit floating point format.
* </ul>
* <li>VectorSimilarityFunction: a byte containing the distance function used for similarity
* calculation.
* <ul>
* <li>0: EUCLIDEAN distance. ({@link VectorSimilarityFunction#EUCLIDEAN})
* <li>1: DOT_PRODUCT similarity. ({@link VectorSimilarityFunction#DOT_PRODUCT})
* <li>2: COSINE similarity. ({@link VectorSimilarityFunction#COSINE})
* </ul>
* </ul>
*
* @lucene.experimental
*/
public final class Lucene94FieldInfosFormat extends FieldInfosFormat {
/** Sole constructor. */
public Lucene94FieldInfosFormat() {}
@Override
public FieldInfos read(
Directory directory, SegmentInfo segmentInfo, String segmentSuffix, IOContext context)
throws IOException {
final String fileName =
IndexFileNames.segmentFileName(segmentInfo.name, segmentSuffix, EXTENSION);
try (ChecksumIndexInput input = directory.openChecksumInput(fileName, context)) {
Throwable priorE = null;
FieldInfo[] infos = null;
try {
CodecUtil.checkIndexHeader(
input,
Lucene94FieldInfosFormat.CODEC_NAME,
Lucene94FieldInfosFormat.FORMAT_START,
Lucene94FieldInfosFormat.FORMAT_CURRENT,
segmentInfo.getId(),
segmentSuffix);
final int size = input.readVInt(); // read in the size
infos = new FieldInfo[size];
// previous field's attribute map, we share when possible:
Map<String, String> lastAttributes = Collections.emptyMap();
for (int i = 0; i < size; i++) {
String name = input.readString();
final int fieldNumber = input.readVInt();
if (fieldNumber < 0) {
throw new CorruptIndexException(
"invalid field number for field: " + name + ", fieldNumber=" + fieldNumber, input);
}
byte bits = input.readByte();
boolean storeTermVector = (bits & STORE_TERMVECTOR) != 0;
boolean omitNorms = (bits & OMIT_NORMS) != 0;
boolean storePayloads = (bits & STORE_PAYLOADS) != 0;
boolean isSoftDeletesField = (bits & SOFT_DELETES_FIELD) != 0;
final IndexOptions indexOptions = getIndexOptions(input, input.readByte());
// DV Types are packed in one byte
final DocValuesType docValuesType = getDocValuesType(input, input.readByte());
final long dvGen = input.readLong();
Map<String, String> attributes = input.readMapOfStrings();
// just use the last field's map if it's the same
if (attributes.equals(lastAttributes)) {
attributes = lastAttributes;
}
lastAttributes = attributes;
int pointDataDimensionCount = input.readVInt();
int pointNumBytes;
int pointIndexDimensionCount = pointDataDimensionCount;
if (pointDataDimensionCount != 0) {
pointIndexDimensionCount = input.readVInt();
pointNumBytes = input.readVInt();
} else {
pointNumBytes = 0;
}
final int vectorDimension = input.readVInt();
final VectorEncoding vectorEncoding = getVectorEncoding(input, input.readByte());
final VectorSimilarityFunction vectorDistFunc = getDistFunc(input, input.readByte());
try {
infos[i] =
new FieldInfo(
name,
fieldNumber,
storeTermVector,
omitNorms,
storePayloads,
indexOptions,
docValuesType,
dvGen,
attributes,
pointDataDimensionCount,
pointIndexDimensionCount,
pointNumBytes,
vectorDimension,
vectorEncoding,
vectorDistFunc,
isSoftDeletesField);
infos[i].checkConsistency();
} catch (IllegalStateException e) {
throw new CorruptIndexException(
"invalid fieldinfo for field: " + name + ", fieldNumber=" + fieldNumber, input, e);
}
}
} catch (Throwable exception) {
priorE = exception;
} finally {
CodecUtil.checkFooter(input, priorE);
}
return new FieldInfos(infos);
}
}
static {
// We "mirror" DocValues enum values with the constants below; let's try to ensure if we add a
// new DocValuesType while this format is
// still used for writing, we remember to fix this encoding:
assert DocValuesType.values().length == 6;
}
private static byte docValuesByte(DocValuesType type) {
switch (type) {
case NONE:
return 0;
case NUMERIC:
return 1;
case BINARY:
return 2;
case SORTED:
return 3;
case SORTED_SET:
return 4;
case SORTED_NUMERIC:
return 5;
default:
// BUG
throw new AssertionError("unhandled DocValuesType: " + type);
}
}
private static DocValuesType getDocValuesType(IndexInput input, byte b) throws IOException {
switch (b) {
case 0:
return DocValuesType.NONE;
case 1:
return DocValuesType.NUMERIC;
case 2:
return DocValuesType.BINARY;
case 3:
return DocValuesType.SORTED;
case 4:
return DocValuesType.SORTED_SET;
case 5:
return DocValuesType.SORTED_NUMERIC;
default:
throw new CorruptIndexException("invalid docvalues byte: " + b, input);
}
}
private static VectorEncoding getVectorEncoding(IndexInput input, byte b) throws IOException {
if (b < 0 || b >= VectorEncoding.values().length) {
throw new CorruptIndexException("invalid vector encoding: " + b, input);
}
return VectorEncoding.values()[b];
}
private static VectorSimilarityFunction getDistFunc(IndexInput input, byte b) throws IOException {
if (b < 0 || b >= VectorSimilarityFunction.values().length) {
throw new CorruptIndexException("invalid distance function: " + b, input);
}
return VectorSimilarityFunction.values()[b];
}
static {
// We "mirror" IndexOptions enum values with the constants below; let's try to ensure if we add
// a new IndexOption while this format is
// still used for writing, we remember to fix this encoding:
assert IndexOptions.values().length == 5;
}
private static byte indexOptionsByte(IndexOptions indexOptions) {
switch (indexOptions) {
case NONE:
return 0;
case DOCS:
return 1;
case DOCS_AND_FREQS:
return 2;
case DOCS_AND_FREQS_AND_POSITIONS:
return 3;
case DOCS_AND_FREQS_AND_POSITIONS_AND_OFFSETS:
return 4;
default:
// BUG:
throw new AssertionError("unhandled IndexOptions: " + indexOptions);
}
}
private static IndexOptions getIndexOptions(IndexInput input, byte b) throws IOException {
switch (b) {
case 0:
return IndexOptions.NONE;
case 1:
return IndexOptions.DOCS;
case 2:
return IndexOptions.DOCS_AND_FREQS;
case 3:
return IndexOptions.DOCS_AND_FREQS_AND_POSITIONS;
case 4:
return IndexOptions.DOCS_AND_FREQS_AND_POSITIONS_AND_OFFSETS;
default:
// BUG
throw new CorruptIndexException("invalid IndexOptions byte: " + b, input);
}
}
@Override
public void write(
Directory directory,
SegmentInfo segmentInfo,
String segmentSuffix,
FieldInfos infos,
IOContext context)
throws IOException {
final String fileName =
IndexFileNames.segmentFileName(segmentInfo.name, segmentSuffix, EXTENSION);
try (IndexOutput output = directory.createOutput(fileName, context)) {
CodecUtil.writeIndexHeader(
output,
Lucene94FieldInfosFormat.CODEC_NAME,
Lucene94FieldInfosFormat.FORMAT_CURRENT,
segmentInfo.getId(),
segmentSuffix);
output.writeVInt(infos.size());
for (FieldInfo fi : infos) {
fi.checkConsistency();
output.writeString(fi.name);
output.writeVInt(fi.number);
byte bits = 0x0;
if (fi.hasVectors()) bits |= STORE_TERMVECTOR;
if (fi.omitsNorms()) bits |= OMIT_NORMS;
if (fi.hasPayloads()) bits |= STORE_PAYLOADS;
if (fi.isSoftDeletesField()) bits |= SOFT_DELETES_FIELD;
output.writeByte(bits);
output.writeByte(indexOptionsByte(fi.getIndexOptions()));
// pack the DV type and hasNorms in one byte
output.writeByte(docValuesByte(fi.getDocValuesType()));
output.writeLong(fi.getDocValuesGen());
output.writeMapOfStrings(fi.attributes());
output.writeVInt(fi.getPointDimensionCount());
if (fi.getPointDimensionCount() != 0) {
output.writeVInt(fi.getPointIndexDimensionCount());
output.writeVInt(fi.getPointNumBytes());
}
output.writeVInt(fi.getVectorDimension());
output.writeByte((byte) fi.getVectorEncoding().ordinal());
output.writeByte((byte) fi.getVectorSimilarityFunction().ordinal());
}
CodecUtil.writeFooter(output);
}
}
/** Extension of field infos */
static final String EXTENSION = "fnm";
// Codec header
static final String CODEC_NAME = "Lucene94FieldInfos";
static final int FORMAT_START = 0;
static final int FORMAT_CURRENT = FORMAT_START;
// Field flags
static final byte STORE_TERMVECTOR = 0x1;
static final byte OMIT_NORMS = 0x2;
static final byte STORE_PAYLOADS = 0x4;
static final byte SOFT_DELETES_FIELD = 0x8;
}

View File

@ -38,8 +38,9 @@ import org.apache.lucene.util.hnsw.HnswGraph;
* <p>For each field:
*
* <ul>
* <li>Floating-point vector data ordered by field, document ordinal, and vector dimension. The
* floats are stored in little-endian byte order
* <li>Vector data ordered by field, document ordinal, and vector dimension. When the
* vectorEncoding is BYTE, each sample is stored as a single byte. When it is FLOAT32, each
* sample is stored as an IEEE float in little-endian byte order.
* <li>DocIds encoded by {@link IndexedDISI#writeBitSet(DocIdSetIterator, IndexOutput, byte)};
* note that this is only written in the sparse case
* <li>OrdToDoc was encoded by {@link org.apache.lucene.util.packed.DirectMonotonicWriter}, note
@ -89,7 +90,7 @@ import org.apache.lucene.util.hnsw.HnswGraph;
* <ul>
* <li><b>[int]</b> the number of nodes on this level
* <li><b>array[int]</b> for levels greater than 0 list of nodes on this level, stored as
* the the level 0th nodes ordinals.
* the level 0th nodes' ordinals.
* </ul>
* </ul>
*
@ -104,8 +105,8 @@ public final class Lucene94HnswVectorsFormat extends KnnVectorsFormat {
static final String VECTOR_DATA_EXTENSION = "vec";
static final String VECTOR_INDEX_EXTENSION = "vex";
static final int VERSION_START = 0;
static final int VERSION_CURRENT = VERSION_START;
public static final int VERSION_START = 0;
public static final int VERSION_CURRENT = 1;
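// version 1 (this change) adds support for byte-encoded vectors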
/** Default number of maximum connections per node */
public static final int DEFAULT_MAX_CONN = 16;
@ -156,6 +157,11 @@ public final class Lucene94HnswVectorsFormat extends KnnVectorsFormat {
return new Lucene94HnswVectorsReader(state);
}
@Override
public int currentVersion() {
return VERSION_CURRENT;
}
@Override
public String toString() {
return "Lucene94HnswVectorsFormat(name=Lucene94HnswVectorsFormat, maxConn="

View File

@ -18,6 +18,8 @@
package org.apache.lucene.codecs.lucene94;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.search.TopDocsCollector.EMPTY_TOPDOCS;
import static org.apache.lucene.util.VectorUtil.toBytesRef;
import java.io.IOException;
import java.util.Arrays;
@ -30,8 +32,10 @@ import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
@ -169,16 +173,23 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
+ fieldEntry.dimension);
}
long numBytes = (long) fieldEntry.size() * dimension * Float.BYTES;
int byteSize =
switch (info.getVectorEncoding()) {
case BYTE -> Byte.BYTES;
case FLOAT32 -> Float.BYTES;
};
long numBytes = (long) fieldEntry.size * dimension * byteSize;
if (numBytes != fieldEntry.vectorDataLength) {
throw new IllegalStateException(
"Vector data length "
+ fieldEntry.vectorDataLength
+ " not matching size="
+ fieldEntry.size()
+ fieldEntry.size
+ " * dim="
+ dimension
+ " * 4 = "
+ " * byteSize="
+ byteSize
+ " = "
+ numBytes);
}
}
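To make the length check concrete: for example, a field with 1,000 vectors of dimension 128 must occupy exactly 128,000 bytes of vector data under BYTE encoding, versus 512,000 bytes under FLOAT32. A standalone sanity check mirroring the arithmetic above:

// Standalone arithmetic mirroring validateFieldEntry's expected-size computation.
long size = 1_000, dim = 128;
long byteEncoded = size * dim * Byte.BYTES; // 128_000
long floatEncoded = size * dim * Float.BYTES; // 512_000
assert byteEncoded == 128_000L && floatEncoded == 512_000L;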
@ -193,9 +204,18 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
return VectorSimilarityFunction.values()[similarityFunctionId];
}
private VectorEncoding readVectorEncoding(DataInput input) throws IOException {
int encodingId = input.readInt();
if (encodingId < 0 || encodingId >= VectorEncoding.values().length) {
throw new CorruptIndexException("Invalid vector encoding id: " + encodingId, input);
}
return VectorEncoding.values()[encodingId];
}
private FieldEntry readField(IndexInput input) throws IOException {
VectorEncoding vectorEncoding = readVectorEncoding(input);
VectorSimilarityFunction similarityFunction = readSimilarityFunction(input);
return new FieldEntry(input, similarityFunction);
return new FieldEntry(input, vectorEncoding, similarityFunction);
}
@Override
@ -216,7 +236,12 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
@Override
public VectorValues getVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
return OffHeapVectorValues.load(fieldEntry, vectorData);
VectorValues values = OffHeapVectorValues.load(fieldEntry, vectorData);
if (fieldEntry.vectorEncoding == VectorEncoding.BYTE) {
return new ExpandingVectorValues(values);
} else {
return values;
}
}
@Override
@ -237,6 +262,7 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
target,
k,
vectorValues,
fieldEntry.vectorEncoding,
fieldEntry.similarityFunction,
getGraph(fieldEntry),
vectorValues.getAcceptOrds(acceptDocs),
@ -258,6 +284,25 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
}
@Override
public TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null) {
// The field does not exist or does not index vectors
return EMPTY_TOPDOCS;
}
VectorSimilarityFunction similarityFunction = fieldEntry.similarityFunction;
VectorValues vectorValues = getVectorValues(field);
return switch (fieldEntry.vectorEncoding) {
case BYTE -> exhaustiveSearch(
vectorValues, acceptDocs, similarityFunction, toBytesRef(target), k);
case FLOAT32 -> exhaustiveSearch(vectorValues, acceptDocs, similarityFunction, target, k);
};
}
/** Get knn graph values; used for testing */
public HnswGraph getGraph(String field) throws IOException {
FieldInfo info = fieldInfos.fieldInfo(field);
@ -286,6 +331,7 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
static class FieldEntry {
final VectorSimilarityFunction similarityFunction;
final VectorEncoding vectorEncoding;
final long vectorDataOffset;
final long vectorDataLength;
final long vectorIndexOffset;
@ -315,8 +361,13 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
final DirectMonotonicReader.Meta meta;
final long addressesLength;
FieldEntry(IndexInput input, VectorSimilarityFunction similarityFunction) throws IOException {
FieldEntry(
IndexInput input,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction)
throws IOException {
this.similarityFunction = similarityFunction;
this.vectorEncoding = vectorEncoding;
vectorDataOffset = input.readVLong();
vectorDataLength = input.readVLong();
vectorIndexOffset = input.readVLong();

View File

@ -65,7 +65,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
private final int M;
private final int beamWidth;
private final List<FieldWriter> fields = new ArrayList<>();
private final List<FieldWriter<?>> fields = new ArrayList<>();
private boolean finished;
Lucene94HnswVectorsWriter(SegmentWriteState state, int M, int beamWidth) throws IOException {
@ -121,15 +121,16 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
}
@Override
public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException {
FieldWriter newField = new FieldWriter(fieldInfo, M, beamWidth, segmentWriteState.infoStream);
public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
FieldWriter<?> newField =
FieldWriter.create(fieldInfo, M, beamWidth, segmentWriteState.infoStream);
fields.add(newField);
return newField;
}
@Override
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
for (FieldWriter field : fields) {
for (FieldWriter<?> field : fields) {
if (sortMap == null) {
writeField(field, maxDoc);
} else {
@ -159,22 +160,20 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
@Override
public long ramBytesUsed() {
long total = 0;
for (FieldWriter field : fields) {
for (FieldWriter<?> field : fields) {
total += field.ramBytesUsed();
}
return total;
}
private void writeField(FieldWriter fieldData, int maxDoc) throws IOException {
private void writeField(FieldWriter<?> fieldData, int maxDoc) throws IOException {
// write vector values
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
final ByteBuffer buffer =
ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
final BytesRef binaryValue = new BytesRef(buffer.array());
for (float[] vector : fieldData.vectors) {
buffer.asFloatBuffer().put(vector);
vectorData.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
switch (fieldData.fieldInfo.getVectorEncoding()) {
case BYTE -> writeByteVectors(fieldData);
case FLOAT32 -> writeFloat32Vectors(fieldData);
}
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
// write graph
@ -194,7 +193,24 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
graph);
}
private void writeSortingField(FieldWriter fieldData, int maxDoc, Sorter.DocMap sortMap)
private void writeFloat32Vectors(FieldWriter<?> fieldData) throws IOException {
final ByteBuffer buffer =
ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
final BytesRef binaryValue = new BytesRef(buffer.array());
for (Object v : fieldData.vectors) {
buffer.asFloatBuffer().put((float[]) v);
vectorData.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
}
}
private void writeByteVectors(FieldWriter<?> fieldData) throws IOException {
for (Object v : fieldData.vectors) {
BytesRef vector = (BytesRef) v;
vectorData.writeBytes(vector.bytes, vector.offset, vector.length);
}
}
private void writeSortingField(FieldWriter<?> fieldData, int maxDoc, Sorter.DocMap sortMap)
throws IOException {
final int[] docIdOffsets = new int[sortMap.size()];
int offset = 1; // 0 means no vector for this (field, document)
@ -221,15 +237,11 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
}
// write vector values
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
final ByteBuffer buffer =
ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
final BytesRef binaryValue = new BytesRef(buffer.array());
for (int ordinal : ordMap) {
float[] vector = fieldData.vectors.get(ordinal);
buffer.asFloatBuffer().put(vector);
vectorData.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
}
long vectorDataOffset =
switch (fieldData.fieldInfo.getVectorEncoding()) {
case BYTE -> writeSortedByteVectors(fieldData, ordMap);
case FLOAT32 -> writeSortedFloat32Vectors(fieldData, ordMap);
};
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
// write graph
@ -249,6 +261,29 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
mockGraph);
}
private long writeSortedFloat32Vectors(FieldWriter<?> fieldData, int[] ordMap)
throws IOException {
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
final ByteBuffer buffer =
ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
final BytesRef binaryValue = new BytesRef(buffer.array());
for (int ordinal : ordMap) {
float[] vector = (float[]) fieldData.vectors.get(ordinal);
buffer.asFloatBuffer().put(vector);
vectorData.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
}
return vectorDataOffset;
}
private long writeSortedByteVectors(FieldWriter<?> fieldData, int[] ordMap) throws IOException {
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
for (int ordinal : ordMap) {
byte[] vector = (byte[]) fieldData.vectors.get(ordinal);
vectorData.writeBytes(vector, 0, vector.length);
}
return vectorDataOffset;
}
// reconstruct graph substituting old ordinals with new ordinals
private HnswGraph reconstructAndWriteGraph(
OnHeapHnswGraph graph, int[] newToOldMap, int[] oldToNewMap) throws IOException {
@ -354,7 +389,8 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
boolean success = false;
try {
// write the vector data to a temporary file
DocsWithFieldSet docsWithField = writeVectorData(tempVectorData, vectors);
DocsWithFieldSet docsWithField =
writeVectorData(tempVectorData, vectors, fieldInfo.getVectorEncoding().byteSize);
CodecUtil.writeFooter(tempVectorData);
IOUtils.close(tempVectorData);
@ -365,21 +401,22 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
vectorData.copyBytes(vectorDataInput, vectorDataInput.length() - CodecUtil.footerLength());
CodecUtil.retrieveChecksum(vectorDataInput);
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
long vectorIndexOffset = vectorIndex.getFilePointer();
// build the graph using the temporary vector data
// we use Lucene94HnswVectorsReader.DenseOffHeapVectorValues for the graph construction
// doesn't need to know docIds
// TODO: separate random access vector values from DocIdSetIterator?
int byteSize = vectors.dimension() * fieldInfo.getVectorEncoding().byteSize;
OffHeapVectorValues offHeapVectors =
new OffHeapVectorValues.DenseOffHeapVectorValues(
vectors.dimension(), docsWithField.cardinality(), vectorDataInput);
vectors.dimension(), docsWithField.cardinality(), vectorDataInput, byteSize);
OnHeapHnswGraph graph = null;
if (offHeapVectors.size() != 0) {
// build graph
HnswGraphBuilder hnswGraphBuilder =
new HnswGraphBuilder(
HnswGraphBuilder<?> hnswGraphBuilder =
HnswGraphBuilder.create(
offHeapVectors,
fieldInfo.getVectorEncoding(),
fieldInfo.getVectorSimilarityFunction(),
M,
beamWidth,
@ -451,6 +488,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
HnswGraph graph)
throws IOException {
meta.writeInt(field.number);
meta.writeInt(field.getVectorEncoding().ordinal());
meta.writeInt(field.getVectorSimilarityFunction().ordinal());
meta.writeVLong(vectorDataOffset);
meta.writeVLong(vectorDataLength);
@ -520,13 +558,13 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
/**
* Writes the vector values to the output and returns a set of documents that contains vectors.
*/
private static DocsWithFieldSet writeVectorData(IndexOutput output, VectorValues vectors)
throws IOException {
private static DocsWithFieldSet writeVectorData(
IndexOutput output, VectorValues vectors, int scalarSize) throws IOException {
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
for (int docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc()) {
// write vector
BytesRef binaryValue = vectors.binaryValue();
assert binaryValue.length == vectors.dimension() * Float.BYTES;
assert binaryValue.length == vectors.dimension() * scalarSize;
output.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
docsWithField.add(docV);
}
@ -538,54 +576,69 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
IOUtils.close(meta, vectorData, vectorIndex);
}
private static class FieldWriter extends KnnFieldVectorsWriter {
private abstract static class FieldWriter<T> extends KnnFieldVectorsWriter<T> {
private final FieldInfo fieldInfo;
private final int dim;
private final DocsWithFieldSet docsWithField;
private final List<float[]> vectors;
private final RAVectorValues raVectorValues;
private final HnswGraphBuilder hnswGraphBuilder;
private final List<T> vectors;
private final RAVectorValues<T> raVectorValues;
private final HnswGraphBuilder<T> hnswGraphBuilder;
private int lastDocID = -1;
private int node = 0;
static FieldWriter<?> create(FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream)
throws IOException {
int dim = fieldInfo.getVectorDimension();
return switch (fieldInfo.getVectorEncoding()) {
case BYTE -> new FieldWriter<BytesRef>(fieldInfo, M, beamWidth, infoStream) {
@Override
public BytesRef copyValue(BytesRef value) {
return new BytesRef(ArrayUtil.copyOfSubArray(value.bytes, value.offset, value.offset + dim));
}
};
case FLOAT32 -> new FieldWriter<float[]>(fieldInfo, M, beamWidth, infoStream) {
@Override
public float[] copyValue(float[] value) {
return ArrayUtil.copyOfSubArray(value, 0, dim);
}
};
};
}
@SuppressWarnings("unchecked")
FieldWriter(FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream)
throws IOException {
this.fieldInfo = fieldInfo;
this.dim = fieldInfo.getVectorDimension();
this.docsWithField = new DocsWithFieldSet();
vectors = new ArrayList<>();
raVectorValues = new RAVectorValues(vectors, dim);
raVectorValues = new RAVectorValues<>(vectors, dim);
hnswGraphBuilder =
new HnswGraphBuilder(
() -> raVectorValues,
fieldInfo.getVectorSimilarityFunction(),
M,
beamWidth,
HnswGraphBuilder.randSeed);
(HnswGraphBuilder<T>)
HnswGraphBuilder.create(
() -> raVectorValues,
fieldInfo.getVectorEncoding(),
fieldInfo.getVectorSimilarityFunction(),
M,
beamWidth,
HnswGraphBuilder.randSeed);
hnswGraphBuilder.setInfoStream(infoStream);
}
@Override
public void addValue(int docID, float[] vectorValue) throws IOException {
@SuppressWarnings("unchecked")
public void addValue(int docID, Object value) throws IOException {
if (docID == lastDocID) {
throw new IllegalArgumentException(
"VectorValuesField \""
+ fieldInfo.name
+ "\" appears more than once in this document (only one value is allowed per field)");
}
if (vectorValue.length != dim) {
throw new IllegalArgumentException(
"Attempt to index a vector of dimension "
+ vectorValue.length
+ " but \""
+ fieldInfo.name
+ "\" has dimension "
+ dim);
}
T vectorValue = (T) value;
assert docID > lastDocID;
docsWithField.add(docID);
vectors.add(ArrayUtil.copyOfSubArray(vectorValue, 0, vectorValue.length));
vectors.add(copyValue(vectorValue));
if (node > 0) {
// start at node 1! node 0 is added implicitly, in the constructor
hnswGraphBuilder.addGraphNode(node, vectorValue);
@ -608,16 +661,16 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
return docsWithField.ramBytesUsed()
+ vectors.size()
* (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER)
+ vectors.size() * vectors.get(0).length * Float.BYTES
+ vectors.size() * fieldInfo.getVectorDimension() * fieldInfo.getVectorEncoding().byteSize
+ hnswGraphBuilder.getGraph().ramBytesUsed();
}
}
private static class RAVectorValues implements RandomAccessVectorValues {
private final List<float[]> vectors;
private static class RAVectorValues<T> implements RandomAccessVectorValues {
private final List<T> vectors;
private final int dim;
RAVectorValues(List<float[]> vectors, int dim) {
RAVectorValues(List<T> vectors, int dim) {
this.vectors = vectors;
this.dim = dim;
}
@ -634,12 +687,12 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
@Override
public float[] vectorValue(int targetOrd) throws IOException {
return vectors.get(targetOrd);
return (float[]) vectors.get(targetOrd);
}
@Override
public BytesRef binaryValue(int targetOrd) throws IOException {
throw new UnsupportedOperationException();
return (BytesRef) vectors.get(targetOrd);
}
}
}
View File
@ -41,11 +41,11 @@ abstract class OffHeapVectorValues extends VectorValues
protected final int byteSize;
protected final float[] value;
OffHeapVectorValues(int dimension, int size, IndexInput slice) {
OffHeapVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
this.dimension = dimension;
this.size = size;
this.slice = slice;
byteSize = Float.BYTES * dimension;
this.byteSize = byteSize;
byteBuffer = ByteBuffer.allocate(byteSize);
value = new float[dimension];
binaryValue = new BytesRef(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
@ -93,10 +93,16 @@ abstract class OffHeapVectorValues extends VectorValues
}
IndexInput bytesSlice =
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
int byteSize =
switch (fieldEntry.vectorEncoding) {
case BYTE -> fieldEntry.dimension;
case FLOAT32 -> fieldEntry.dimension * Float.BYTES;
};
if (fieldEntry.docsWithFieldOffset == -1) {
return new DenseOffHeapVectorValues(fieldEntry.dimension, fieldEntry.size, bytesSlice);
return new DenseOffHeapVectorValues(
fieldEntry.dimension, fieldEntry.size, bytesSlice, byteSize);
} else {
return new SparseOffHeapVectorValues(fieldEntry, vectorData, bytesSlice);
return new SparseOffHeapVectorValues(fieldEntry, vectorData, bytesSlice, byteSize);
}
}
@ -106,8 +112,8 @@ abstract class OffHeapVectorValues extends VectorValues
private int doc = -1;
public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice) {
super(dimension, size, slice);
public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
super(dimension, size, slice, byteSize);
}
@Override
@ -145,7 +151,7 @@ abstract class OffHeapVectorValues extends VectorValues
@Override
public RandomAccessVectorValues randomAccess() throws IOException {
return new DenseOffHeapVectorValues(dimension, size, slice.clone());
return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize);
}
@Override
@ -167,10 +173,13 @@ abstract class OffHeapVectorValues extends VectorValues
private final Lucene94HnswVectorsReader.FieldEntry fieldEntry;
public SparseOffHeapVectorValues(
Lucene94HnswVectorsReader.FieldEntry fieldEntry, IndexInput dataIn, IndexInput slice)
Lucene94HnswVectorsReader.FieldEntry fieldEntry,
IndexInput dataIn,
IndexInput slice,
int byteSize)
throws IOException {
super(fieldEntry.dimension, fieldEntry.size, slice);
super(fieldEntry.dimension, fieldEntry.size, slice, byteSize);
this.fieldEntry = fieldEntry;
final RandomAccessInput addressesData =
dataIn.randomAccessSlice(fieldEntry.addressesOffset, fieldEntry.addressesLength);
@ -218,7 +227,7 @@ abstract class OffHeapVectorValues extends VectorValues
@Override
public RandomAccessVectorValues randomAccess() throws IOException {
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone());
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize);
}
@Override
@ -248,7 +257,7 @@ abstract class OffHeapVectorValues extends VectorValues
private static class EmptyOffHeapVectorValues extends OffHeapVectorValues {
public EmptyOffHeapVectorValues(int dimension) {
super(dimension, 0, null);
super(dimension, 0, null, 0);
}
private int doc = -1;
View File
@ -144,7 +144,7 @@
* <li>{@link org.apache.lucene.codecs.lucene90.Lucene90SegmentInfoFormat Segment info}. This
* contains metadata about a segment, such as the number of documents, what files it uses, and
* information about how the segment is sorted
* <li>{@link org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat Field names}. This
* <li>{@link org.apache.lucene.codecs.lucene94.Lucene94FieldInfosFormat Field names}. This
* contains metadata about the set of named fields used in the index.
* <li>{@link org.apache.lucene.codecs.lucene90.Lucene90StoredFieldsFormat Stored Field values}.
* This contains, for each document, a list of attribute-value pairs, where the attributes are
@ -240,7 +240,7 @@
* systems that frequently run out of file handles.</td>
* </tr>
* <tr>
* <td>{@link org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat Fields}</td>
* <td>{@link org.apache.lucene.codecs.lucene94.Lucene94FieldInfosFormat Fields}</td>
* <td>.fnm</td>
* <td>Stores information about the fields</td>
* </tr>
View File
@ -33,6 +33,7 @@ import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
@ -101,7 +102,7 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
}
@Override
public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException {
public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
KnnVectorsWriter writer = getInstance(fieldInfo);
return writer.addField(fieldInfo);
}
@ -267,6 +268,17 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
}
}
@Override
public TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
KnnVectorsReader knnVectorsReader = fields.get(field);
if (knnVectorsReader == null) {
return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
} else {
return knnVectorsReader.searchExhaustively(field, target, k, acceptDocs);
}
}
@Override
public void close() throws IOException {
IOUtils.close(fields.values());
View File
@ -24,6 +24,7 @@ import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.IndexableFieldType;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
@ -44,6 +45,7 @@ public class FieldType implements IndexableFieldType {
private int indexDimensionCount;
private int dimensionNumBytes;
private int vectorDimension;
private VectorEncoding vectorEncoding = VectorEncoding.FLOAT32;
private VectorSimilarityFunction vectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN;
private Map<String, String> attributes;
@ -62,6 +64,7 @@ public class FieldType implements IndexableFieldType {
this.indexDimensionCount = ref.pointIndexDimensionCount();
this.dimensionNumBytes = ref.pointNumBytes();
this.vectorDimension = ref.vectorDimension();
this.vectorEncoding = ref.vectorEncoding();
this.vectorSimilarityFunction = ref.vectorSimilarityFunction();
if (ref.getAttributes() != null) {
this.attributes = new HashMap<>(ref.getAttributes());
@ -371,8 +374,8 @@ public class FieldType implements IndexableFieldType {
}
/** Enable vector indexing, with the specified number of dimensions, encoding, and similarity function. */
public void setVectorDimensionsAndSimilarityFunction(
int numDimensions, VectorSimilarityFunction distFunc) {
public void setVectorAttributes(
int numDimensions, VectorEncoding encoding, VectorSimilarityFunction similarity) {
checkIfFrozen();
if (numDimensions <= 0) {
throw new IllegalArgumentException("vector numDimensions must be > 0; got " + numDimensions);
@ -385,7 +388,8 @@ public class FieldType implements IndexableFieldType {
+ numDimensions);
}
this.vectorDimension = numDimensions;
this.vectorSimilarityFunction = Objects.requireNonNull(distFunc);
this.vectorSimilarityFunction = Objects.requireNonNull(similarity);
this.vectorEncoding = Objects.requireNonNull(encoding);
}
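A minimal sketch of the renamed API (the dimension and similarity chosen here are illustrative): the encoding now travels together with the dimension and similarity function, and freezing prevents later mutation.
FieldType type = new FieldType();
type.setVectorAttributes(16, VectorEncoding.BYTE, VectorSimilarityFunction.DOT_PRODUCT);
type.freeze();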
@Override
@ -393,6 +397,11 @@ public class FieldType implements IndexableFieldType {
return vectorDimension;
}
@Override
public VectorEncoding vectorEncoding() {
return vectorEncoding;
}
@Override
public VectorSimilarityFunction vectorSimilarityFunction() {
return vectorSimilarityFunction;
View File
@ -17,8 +17,10 @@
package org.apache.lucene.document;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.VectorUtil;
/**
@ -39,7 +41,18 @@ public class KnnVectorField extends Field {
if (v == null) {
throw new IllegalArgumentException("vector value must not be null");
}
int dimension = v.length;
return createType(v.length, VectorEncoding.FLOAT32, similarityFunction);
}
private static FieldType createType(BytesRef v, VectorSimilarityFunction similarityFunction) {
if (v == null) {
throw new IllegalArgumentException("vector value must not be null");
}
return createType(v.length, VectorEncoding.BYTE, similarityFunction);
}
private static FieldType createType(
int dimension, VectorEncoding vectorEncoding, VectorSimilarityFunction similarityFunction) {
if (dimension == 0) {
throw new IllegalArgumentException("cannot index an empty vector");
}
@ -51,13 +64,13 @@ public class KnnVectorField extends Field {
throw new IllegalArgumentException("similarity function must not be null");
}
FieldType type = new FieldType();
type.setVectorDimensionsAndSimilarityFunction(dimension, similarityFunction);
type.setVectorAttributes(dimension, vectorEncoding, similarityFunction);
type.freeze();
return type;
}
/**
* A convenience method for creating a vector field type.
* A convenience method for creating a vector field type with the default FLOAT32 encoding.
*
* @param dimension dimension of vectors
* @param similarityFunction a function defining vector proximity.
@ -65,8 +78,21 @@ public class KnnVectorField extends Field {
*/
public static FieldType createFieldType(
int dimension, VectorSimilarityFunction similarityFunction) {
return createFieldType(dimension, VectorEncoding.FLOAT32, similarityFunction);
}
/**
* A convenience method for creating a vector field type.
*
* @param dimension dimension of vectors
* @param vectorEncoding the encoding of the scalar values
* @param similarityFunction a function defining vector proximity.
* @throws IllegalArgumentException if any parameter is null, or the dimension is &gt; 1024.
*/
public static FieldType createFieldType(
int dimension, VectorEncoding vectorEncoding, VectorSimilarityFunction similarityFunction) {
FieldType type = new FieldType();
type.setVectorDimensionsAndSimilarityFunction(dimension, similarityFunction);
type.setVectorAttributes(dimension, vectorEncoding, similarityFunction);
type.freeze();
return type;
}
@ -74,8 +100,8 @@ public class KnnVectorField extends Field {
/**
* Creates a numeric vector field. Fields are single-valued: each document has either one value or
* no value. Vectors of a single field share the same dimension and similarity function. Note that
* some strategies (like {@link VectorSimilarityFunction#DOT_PRODUCT}) require values to be
* unit-length, which can be enforced using {@link VectorUtil#l2normalize(float[])}.
* some vector similarities (like {@link VectorSimilarityFunction#DOT_PRODUCT}) require values to
* be unit-length, which can be enforced using {@link VectorUtil#l2normalize(float[])}.
*
* @param name field name
* @param vector value
@ -88,6 +114,23 @@ public class KnnVectorField extends Field {
fieldsData = vector;
}
/**
* Creates a numeric vector field. Fields are single-valued: each document has either one value or
* no value. Vectors of a single field share the same dimension and similarity function. Note that
* some vector similarities (like {@link VectorSimilarityFunction#DOT_PRODUCT}) require all values
* to share a constant norm.
*
* @param name field name
* @param vector value
* @param similarityFunction a function defining vector proximity.
* @throws IllegalArgumentException if any parameter is null, or the vector is empty or has
* dimension &gt; 1024.
*/
public KnnVectorField(String name, BytesRef vector, VectorSimilarityFunction similarityFunction) {
super(name, createType(vector, similarityFunction));
fieldsData = vector;
}
/**
* Creates a numeric vector field with the default EUCLIDEAN (L2) similarity. Fields are
* single-valued: each document has either one value or no value. Vectors of a single field share
@ -117,6 +160,21 @@ public class KnnVectorField extends Field {
fieldsData = vector;
}
/**
* Creates a numeric vector field. Fields are single-valued: each document has either one value or
* no value. Vectors of a single field share the same dimension and similarity function.
*
* @param name field name
* @param vector value
* @param fieldType field type
* @throws IllegalArgumentException if any parameter is null, or the vector is empty or has
* dimension &gt; 1024.
*/
public KnnVectorField(String name, BytesRef vector, FieldType fieldType) {
super(name, fieldType);
fieldsData = vector;
}
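A usage sketch for the new BytesRef constructor (field name and values illustrative): each signed byte is one vector dimension, and, per the BYTE encoding contract, all vectors of the field should share one norm when used with DOT_PRODUCT.
Document doc = new Document();
BytesRef vector = new BytesRef(new byte[] {1, 2, 3, 4}); // a 4-dimensional byte vector
doc.add(new KnnVectorField("vector", vector, VectorSimilarityFunction.DOT_PRODUCT));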
/** Return the vector value of this field */
public float[] vectorValue() {
return (float[]) fieldsData;
View File
@ -45,7 +45,7 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
protected BufferingKnnVectorsWriter() {}
@Override
public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException {
public KnnFieldVectorsWriter<float[]> addField(FieldInfo fieldInfo) throws IOException {
FieldWriter newField = new FieldWriter(fieldInfo);
fields.add(newField);
return newField;
@ -88,6 +88,12 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
throw new UnsupportedOperationException();
}
@Override
public TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
throw new UnsupportedOperationException();
}
};
writeField(fieldData.fieldInfo, knnVectorsReader, maxDoc);
@ -122,6 +128,12 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
throw new UnsupportedOperationException();
}
@Override
public TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
throw new UnsupportedOperationException();
}
@Override
public VectorValues getVectorValues(String field) throws IOException {
return MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
@ -137,7 +149,7 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
protected abstract void writeField(
FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader, int maxDoc) throws IOException;
private static class FieldWriter extends KnnFieldVectorsWriter {
private static class FieldWriter extends KnnFieldVectorsWriter<float[]> {
private final FieldInfo fieldInfo;
private final int dim;
private final DocsWithFieldSet docsWithField;
@ -153,35 +165,45 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
}
@Override
public void addValue(int docID, float[] vectorValue) {
public void addValue(int docID, Object value) {
if (docID == lastDocID) {
throw new IllegalArgumentException(
"VectorValuesField \""
+ fieldInfo.name
+ "\" appears more than once in this document (only one value is allowed per field)");
}
if (vectorValue.length != dim) {
throw new IllegalArgumentException(
"Attempt to index a vector of dimension "
+ vectorValue.length
+ " but \""
+ fieldInfo.name
+ "\" has dimension "
+ dim);
}
assert docID > lastDocID;
float[] vectorValue =
switch (fieldInfo.getVectorEncoding()) {
case FLOAT32 -> (float[]) value;
case BYTE -> bytesToFloats((BytesRef) value);
};
docsWithField.add(docID);
vectors.add(ArrayUtil.copyOfSubArray(vectorValue, 0, vectorValue.length));
vectors.add(copyValue(vectorValue));
lastDocID = docID;
}
private float[] bytesToFloats(BytesRef b) {
// This is used only by SimpleTextKnnVectorsWriter
float[] floats = new float[dim];
for (int i = 0; i < dim; i++) {
floats[i] = b.bytes[i + b.offset];
}
return floats;
}
@Override
public float[] copyValue(float[] vectorValue) {
return ArrayUtil.copyOfSubArray(vectorValue, 0, dim);
}
@Override
public long ramBytesUsed() {
if (vectors.size() == 0) return 0;
return docsWithField.ramBytesUsed()
+ vectors.size()
* (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER)
+ vectors.size() * vectors.get(0).length * Float.BYTES;
+ vectors.size() * dim * Float.BYTES;
}
}
View File
@ -25,6 +25,7 @@ import org.apache.lucene.codecs.NormsProducer;
import org.apache.lucene.codecs.PointsReader;
import org.apache.lucene.codecs.StoredFieldsReader;
import org.apache.lucene.codecs.TermVectorsReader;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
@ -235,6 +236,19 @@ public abstract class CodecReader extends LeafReader {
return getVectorReader().search(field, target, k, acceptDocs, visitedLimit);
}
@Override
public final TopDocs searchNearestVectorsExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
ensureOpen();
FieldInfo fi = getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
// Field does not exist or does not index vectors
return null;
}
return getVectorReader().searchExhaustively(field, target, k, acceptDocs);
}
@Override
protected void doClose() throws IOException {}
View File
@ -18,6 +18,7 @@
package org.apache.lucene.index;
import java.io.IOException;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
@ -58,6 +59,12 @@ abstract class DocValuesLeafReader extends LeafReader {
throw new UnsupportedOperationException();
}
@Override
public TopDocs searchNearestVectorsExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
throw new UnsupportedOperationException();
}
@Override
public final void checkIntegrity() throws IOException {
throw new UnsupportedOperationException();
View File
@ -56,6 +56,7 @@ public final class FieldInfo {
// if it is a positive value, it means this field indexes vectors
private final int vectorDimension;
private final VectorEncoding vectorEncoding;
private final VectorSimilarityFunction vectorSimilarityFunction;
// whether this field is used as the soft-deletes field
@ -80,6 +81,7 @@ public final class FieldInfo {
int pointIndexDimensionCount,
int pointNumBytes,
int vectorDimension,
VectorEncoding vectorEncoding,
VectorSimilarityFunction vectorSimilarityFunction,
boolean softDeletesField) {
this.name = Objects.requireNonNull(name);
@ -105,6 +107,7 @@ public final class FieldInfo {
this.pointIndexDimensionCount = pointIndexDimensionCount;
this.pointNumBytes = pointNumBytes;
this.vectorDimension = vectorDimension;
this.vectorEncoding = vectorEncoding;
this.vectorSimilarityFunction = vectorSimilarityFunction;
this.softDeletesField = softDeletesField;
this.checkConsistency();
@ -229,8 +232,10 @@ public final class FieldInfo {
verifySameVectorOptions(
fieldName,
this.vectorDimension,
this.vectorEncoding,
this.vectorSimilarityFunction,
o.vectorDimension,
o.vectorEncoding,
o.vectorSimilarityFunction);
}
@ -347,19 +352,25 @@ public final class FieldInfo {
static void verifySameVectorOptions(
String fieldName,
int vd1,
VectorEncoding ve1,
VectorSimilarityFunction vsf1,
int vd2,
VectorEncoding ve2,
VectorSimilarityFunction vsf2) {
if (vd1 != vd2 || vsf1 != vsf2) {
if (vd1 != vd2 || vsf1 != vsf2 || ve1 != ve2) {
throw new IllegalArgumentException(
"cannot change field \""
+ fieldName
+ "\" from vector dimension="
+ vd1
+ ", vector encoding="
+ ve1
+ ", vector similarity function="
+ vsf1
+ " to inconsistent vector dimension="
+ vd2
+ ", vector encoding="
+ ve2
+ ", vector similarity function="
+ vsf2);
}
@ -470,6 +481,11 @@ public final class FieldInfo {
return vectorDimension;
}
/** Returns {@link VectorEncoding} for the field */
public VectorEncoding getVectorEncoding() {
return vectorEncoding;
}
/** Returns {@link VectorSimilarityFunction} for the field */
public VectorSimilarityFunction getVectorSimilarityFunction() {
return vectorSimilarityFunction;
View File
@ -308,10 +308,15 @@ public class FieldInfos implements Iterable<FieldInfo> {
static final class FieldVectorProperties {
final int numDimensions;
final VectorEncoding vectorEncoding;
final VectorSimilarityFunction similarityFunction;
FieldVectorProperties(int numDimensions, VectorSimilarityFunction similarityFunction) {
FieldVectorProperties(
int numDimensions,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction) {
this.numDimensions = numDimensions;
this.vectorEncoding = vectorEncoding;
this.similarityFunction = similarityFunction;
}
}
@ -401,7 +406,8 @@ public class FieldInfos implements Iterable<FieldInfo> {
fi.getPointNumBytes()));
vectorProps.put(
fieldName,
new FieldVectorProperties(fi.getVectorDimension(), fi.getVectorSimilarityFunction()));
new FieldVectorProperties(
fi.getVectorDimension(), fi.getVectorEncoding(), fi.getVectorSimilarityFunction()));
}
return fieldNumber.intValue();
}
@ -459,8 +465,10 @@ public class FieldInfos implements Iterable<FieldInfo> {
verifySameVectorOptions(
fieldName,
props.numDimensions,
props.vectorEncoding,
props.similarityFunction,
fi.getVectorDimension(),
fi.getVectorEncoding(),
fi.getVectorSimilarityFunction());
}
@ -503,6 +511,7 @@ public class FieldInfos implements Iterable<FieldInfo> {
0,
0,
0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN,
(softDeletesFieldName != null && softDeletesFieldName.equals(fieldName)));
addOrGet(fi);
@ -584,6 +593,7 @@ public class FieldInfos implements Iterable<FieldInfo> {
0,
0,
0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN,
isSoftDeletesField);
}
@ -698,6 +708,7 @@ public class FieldInfos implements Iterable<FieldInfo> {
fi.getPointIndexDimensionCount(),
fi.getPointNumBytes(),
fi.getVectorDimension(),
fi.getVectorEncoding(),
fi.getVectorSimilarityFunction(),
fi.isSoftDeletesField());
byName.put(fiNew.getName(), fiNew);
View File
@ -18,6 +18,7 @@ package org.apache.lucene.index;
import java.io.IOException;
import java.util.Iterator;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.AttributeSource;
import org.apache.lucene.util.Bits;
@ -357,6 +358,12 @@ public abstract class FilterLeafReader extends LeafReader {
return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
}
@Override
public TopDocs searchNearestVectorsExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
return in.searchNearestVectorsExhaustively(field, target, k, acceptDocs);
}
@Override
public Fields getTermVectors(int docID) throws IOException {
ensureOpen();
View File
@ -101,6 +101,9 @@ public interface IndexableFieldType {
/** The number of dimensions of the field's vector value */
int vectorDimension();
/** The {@link VectorEncoding} of the field's vector value */
VectorEncoding vectorEncoding();
/** The {@link VectorSimilarityFunction} of the field's vector value */
VectorSimilarityFunction vectorSimilarityFunction();
View File
@ -628,6 +628,7 @@ final class IndexingChain implements Accountable {
s.pointIndexDimensionCount,
s.pointNumBytes,
s.vectorDimension,
s.vectorEncoding,
s.vectorSimilarityFunction,
pf.fieldName.equals(fieldInfos.getSoftDeletesFieldName())));
pf.setFieldInfo(fi);
@ -712,7 +713,11 @@ final class IndexingChain implements Accountable {
pf.pointValuesWriter.addPackedValue(docID, field.binaryValue());
}
if (fieldType.vectorDimension() != 0) {
pf.knnFieldVectorsWriter.addValue(docID, ((KnnVectorField) field).vectorValue());
switch (fieldType.vectorEncoding()) {
case BYTE -> pf.knnFieldVectorsWriter.addValue(docID, field.binaryValue());
case FLOAT32 -> pf.knnFieldVectorsWriter.addValue(
docID, ((KnnVectorField) field).vectorValue());
}
}
return indexedField;
}
@ -776,7 +781,10 @@ final class IndexingChain implements Accountable {
fieldType.pointNumBytes());
}
if (fieldType.vectorDimension() != 0) {
schema.setVectors(fieldType.vectorSimilarityFunction(), fieldType.vectorDimension());
schema.setVectors(
fieldType.vectorEncoding(),
fieldType.vectorSimilarityFunction(),
fieldType.vectorDimension());
}
if (fieldType.getAttributes() != null && fieldType.getAttributes().isEmpty() == false) {
schema.updateAttributes(fieldType.getAttributes());
@ -988,7 +996,7 @@ final class IndexingChain implements Accountable {
PointValuesWriter pointValuesWriter;
// Non-null if this field had vectors in this segment
KnnFieldVectorsWriter knnFieldVectorsWriter;
KnnFieldVectorsWriter<?> knnFieldVectorsWriter;
/** We use this to know when a PerField is seen for the first time in the current document. */
long fieldGen = -1;
@ -1281,6 +1289,7 @@ final class IndexingChain implements Accountable {
private int pointIndexDimensionCount = 0;
private int pointNumBytes = 0;
private int vectorDimension = 0;
private VectorEncoding vectorEncoding = VectorEncoding.FLOAT32;
private VectorSimilarityFunction vectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN;
private static String errMsg =
@ -1361,11 +1370,14 @@ final class IndexingChain implements Accountable {
}
}
void setVectors(VectorSimilarityFunction similarityFunction, int dimension) {
void setVectors(
VectorEncoding encoding, VectorSimilarityFunction similarityFunction, int dimension) {
if (vectorDimension == 0) {
this.vectorDimension = dimension;
this.vectorEncoding = encoding;
this.vectorSimilarityFunction = similarityFunction;
this.vectorDimension = dimension;
} else {
assertSame("vector encoding", vectorEncoding, encoding);
assertSame("vector similarity function", vectorSimilarityFunction, similarityFunction);
assertSame("vector dimension", vectorDimension, dimension);
}
@ -1381,6 +1393,7 @@ final class IndexingChain implements Accountable {
pointIndexDimensionCount = 0;
pointNumBytes = 0;
vectorDimension = 0;
vectorEncoding = VectorEncoding.FLOAT32;
vectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN;
}
@ -1391,6 +1404,7 @@ final class IndexingChain implements Accountable {
assertSame("doc values type", fi.getDocValuesType(), docValuesType);
assertSame(
"vector similarity function", fi.getVectorSimilarityFunction(), vectorSimilarityFunction);
assertSame("vector encoding", fi.getVectorEncoding(), vectorEncoding);
assertSame("vector dimension", fi.getVectorDimension(), vectorDimension);
assertSame("point dimension", fi.getPointDimensionCount(), pointDimensionCount);
View File

@ -17,6 +17,7 @@
package org.apache.lucene.index;
import java.io.IOException;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
@ -235,6 +236,30 @@ public abstract class LeafReader extends IndexReader {
public abstract TopDocs searchNearestVectors(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException;
/**
* Return the k nearest neighbor documents as determined by comparison of their vector values for
* this field, to the given vector, by the field's similarity function. The score of each document
* is derived from the vector similarity in a way that ensures scores are positive and that a
* larger score corresponds to a higher ranking.
*
* <p>The search is exact, meaning the results are guaranteed to be the true k closest neighbors.
* This typically requires an exhaustive scan of all candidate documents.
*
* <p>The returned {@link TopDocs} will contain a {@link ScoreDoc} for each nearest neighbor,
* sorted in order of their similarity to the query vector (decreasing scores). The {@link
* TotalHits} contains the number of documents visited during the search.
*
* @param field the vector field to search
* @param target the vector-valued query
* @param k the number of docs to return
* @param acceptDocs {@link DocIdSetIterator} that represents the allowed documents to match, or
* {@code null} if they are all allowed to match.
* @return the k nearest neighbor documents, along with their (similarity-specific) scores.
* @lucene.experimental
*/
public abstract TopDocs searchNearestVectorsExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException;
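A calling sketch (field name and k are illustrative), assuming an already-open reader; per the contract above, a null acceptDocs allows every document with a vector to match.
static TopDocs exactNeighbors(LeafReader reader, float[] queryVector) throws IOException {
  // exhaustive, exact scan over all candidate documents in this leaf
  return reader.searchNearestVectorsExhaustively("vector", queryVector, 10, null);
}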
/**
* Get the {@link FieldInfos} describing all fields in this reader.
*
View File
@ -26,6 +26,7 @@ import java.util.Objects;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
@ -403,6 +404,16 @@ public class ParallelLeafReader extends LeafReader {
: reader.searchNearestVectors(fieldName, target, k, acceptDocs, visitedLimit);
}
@Override
public TopDocs searchNearestVectorsExhaustively(
String fieldName, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
ensureOpen();
LeafReader reader = fieldToReader.get(fieldName);
return reader == null
? null
: reader.searchNearestVectorsExhaustively(fieldName, target, k, acceptDocs);
}
@Override
public void checkIntegrity() throws IOException {
ensureOpen();
View File
@ -722,6 +722,7 @@ final class ReadersAndUpdates {
fi.getPointIndexDimensionCount(),
fi.getPointNumBytes(),
fi.getVectorDimension(),
fi.getVectorEncoding(),
fi.getVectorSimilarityFunction(),
fi.isSoftDeletesField());
}
View File
@ -27,6 +27,7 @@ import org.apache.lucene.codecs.NormsProducer;
import org.apache.lucene.codecs.PointsReader;
import org.apache.lucene.codecs.StoredFieldsReader;
import org.apache.lucene.codecs.TermVectorsReader;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
@ -172,6 +173,12 @@ public final class SlowCodecReaderWrapper {
return reader.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
}
@Override
public TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
return reader.searchNearestVectorsExhaustively(field, target, k, acceptDocs);
}
@Override
public void checkIntegrity() {
// We already checkIntegrity the entire reader up front
View File
@ -31,6 +31,7 @@ import org.apache.lucene.codecs.NormsProducer;
import org.apache.lucene.codecs.PointsReader;
import org.apache.lucene.codecs.StoredFieldsReader;
import org.apache.lucene.codecs.TermVectorsReader;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TopDocs;
@ -389,6 +390,12 @@ public final class SortingCodecReader extends FilterCodecReader {
throw new UnsupportedOperationException();
}
@Override
public TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
throw new UnsupportedOperationException();
}
@Override
public void close() throws IOException {
delegate.close();
View File
@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.index;
/** The numeric datatype of the vector values. */
public enum VectorEncoding {
/**
* Encodes vector using 8 bits of precision per sample. Use only with DOT_PRODUCT similarity.
* NOTE: this can enable significant storage savings and faster searches, at the cost of some
* possible loss of precision. In order to use it, all vectors must be of the same norm, as
* measured by the sum of the squares of the scalar values, and those values must be in the range
* [-128, 127]. This applies to both document and query vectors. Using nonconforming vectors can
* result in errors or poor search results.
*/
BYTE(1),
/** Encodes vector using 32 bits of precision per sample in IEEE floating point format. */
FLOAT32(4);
/**
* The number of bytes required to encode a scalar in this format. A vector will require dimension
* * byteSize.
*/
public final int byteSize;
VectorEncoding(int byteSize) {
this.byteSize = byteSize;
}
}
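A quick sketch of what byteSize means for storage (the dimension is illustrative): a 768-dimensional vector takes 3072 bytes as FLOAT32 but only 768 bytes as BYTE, a 4x reduction.
int dim = 768;
long float32Size = (long) dim * VectorEncoding.FLOAT32.byteSize; // 3072 bytes per vector
long byteSize = (long) dim * VectorEncoding.BYTE.byteSize; // 768 bytes per vector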
View File
@ -18,6 +18,8 @@ package org.apache.lucene.index;
import static org.apache.lucene.util.VectorUtil.*;
import org.apache.lucene.util.BytesRef;
/**
* Vector similarity function; used in search to return top K most similar vectors to a target
* vector. This is a label describing the method used during indexing and searching of the vectors
@ -31,6 +33,11 @@ public enum VectorSimilarityFunction {
public float compare(float[] v1, float[] v2) {
return 1 / (1 + squareDistance(v1, v2));
}
@Override
public float compare(BytesRef v1, BytesRef v2) {
return 1 / (1 + squareDistance(v1, v2));
}
},
/**
@ -44,6 +51,11 @@ public enum VectorSimilarityFunction {
public float compare(float[] v1, float[] v2) {
return (1 + dotProduct(v1, v2)) / 2;
}
@Override
public float compare(BytesRef v1, BytesRef v2) {
return dotProductScore(v1, v2);
}
},
/**
@ -57,6 +69,11 @@ public enum VectorSimilarityFunction {
public float compare(float[] v1, float[] v2) {
return (1 + cosine(v1, v2)) / 2;
}
@Override
public float compare(BytesRef v1, BytesRef v2) {
return (1 + cosine(v1, v2)) / 2;
}
};
/**
@ -68,4 +85,15 @@ public enum VectorSimilarityFunction {
* @return the value of the similarity function applied to the two vectors
*/
public abstract float compare(float[] v1, float[] v2);
/**
* Calculates a similarity score between the two vectors with a specified function. Higher
* similarity scores correspond to closer vectors. The offsets and lengths of the BytesRefs
* determine the vector data that is compared. Each (signed) byte represents a vector dimension.
*
* @param v1 a vector
* @param v2 another vector, of the same dimension
* @return the value of the similarity function applied to the two vectors
*/
public abstract float compare(BytesRef v1, BytesRef v2);
}
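A small sketch of the new byte-vector comparison path (values illustrative), with the DOT_PRODUCT arithmetic spelled out.
BytesRef v1 = new BytesRef(new byte[] {10, 0, -5});
BytesRef v2 = new BytesRef(new byte[] {10, 0, -5});
float score = VectorSimilarityFunction.DOT_PRODUCT.compare(v1, v2);
// dot product = 10*10 + 0*0 + (-5)*(-5) = 125
// score = (1 + 125) / (3 * (1 << 15)) = 126 / 98304 ≈ 0.00128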
View File
@ -65,7 +65,7 @@ class VectorValuesConsumer {
}
}
public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException {
public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
initKnnVectorsWriter(fieldInfo.name);
return writer.addField(fieldInfo);
}

View File

@ -24,11 +24,8 @@ import java.util.Comparator;
import java.util.Objects;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
@ -133,22 +130,21 @@ public class KnnVectorQuery extends Query {
return NO_RESULTS;
}
BitSet bitSet = createBitSet(scorer.iterator(), liveDocs, maxDoc);
BitSetIterator filterIterator = new BitSetIterator(bitSet, bitSet.cardinality());
BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, maxDoc);
if (filterIterator.cost() <= k) {
if (acceptDocs.cardinality() <= k) {
// If there are <= k possible matches, short-circuit and perform exact search, since HNSW
// must always visit at least k documents
return exactSearch(ctx, filterIterator);
return exactSearch(ctx, new BitSetIterator(acceptDocs, acceptDocs.cardinality()));
}
// Perform the approximate kNN search
TopDocs results = approximateSearch(ctx, bitSet, (int) filterIterator.cost());
TopDocs results = approximateSearch(ctx, acceptDocs, acceptDocs.cardinality());
if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO) {
return results;
} else {
// We stopped the kNN search because it visited too many nodes, so fall back to exact search
return exactSearch(ctx, filterIterator);
return exactSearch(ctx, new BitSetIterator(acceptDocs, acceptDocs.cardinality()));
}
}
@ -178,45 +174,9 @@ public class KnnVectorQuery extends Query {
}
// We allow this to be overridden so that tests can check what search strategy is used
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator)
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptDocs)
throws IOException {
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
// The field does not exist or does not index vectors
return NO_RESULTS;
}
VectorSimilarityFunction similarityFunction = fi.getVectorSimilarityFunction();
VectorValues vectorValues = context.reader().getVectorValues(field);
HitQueue queue = new HitQueue(k, true);
ScoreDoc topDoc = queue.top();
int doc;
while ((doc = acceptIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
int vectorDoc = vectorValues.advance(doc);
assert vectorDoc == doc;
float[] vector = vectorValues.vectorValue();
float score = similarityFunction.compare(vector, target);
if (score >= topDoc.score) {
topDoc.score = score;
topDoc.doc = doc;
topDoc = queue.updateTop();
}
}
// Remove any remaining sentinel values
while (queue.size() > 0 && queue.top().score < 0) {
queue.pop();
}
ScoreDoc[] topScoreDocs = new ScoreDoc[queue.size()];
for (int i = topScoreDocs.length - 1; i >= 0; i--) {
topScoreDocs[i] = queue.pop();
}
TotalHits totalHits = new TotalHits(acceptIterator.cost(), TotalHits.Relation.EQUAL_TO);
return new TopDocs(totalHits, topScoreDocs);
return context.reader().searchNearestVectorsExhaustively(field, target, k, acceptDocs);
}
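For context, a minimal end-to-end sketch of the query that drives this method (directory, field name, and k are illustrative); the exact path above is taken when a filter leaves at most k candidates, or when the approximate search exceeds its visit limit.
static TopDocs nearest(Directory dir, float[] queryVector) throws IOException {
  try (DirectoryReader reader = DirectoryReader.open(dir)) {
    IndexSearcher searcher = new IndexSearcher(reader);
    return searcher.search(new KnnVectorQuery("vector", queryVector, 10), 10);
  }
}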
private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
View File
@ -121,6 +121,24 @@ public final class VectorUtil {
return (float) (sum / Math.sqrt(norm1 * norm2));
}
/** Returns the cosine similarity between the two vectors. */
public static float cosine(BytesRef a, BytesRef b) {
// Note: this will not overflow if dim < 2^17, since max(byte * byte) = 2^14 and each
// accumulator sums dim such terms.
int sum = 0;
int norm1 = 0;
int norm2 = 0;
int aOffset = a.offset, bOffset = b.offset;
for (int i = 0; i < a.length; i++) {
byte elem1 = a.bytes[aOffset++];
byte elem2 = b.bytes[bOffset++];
sum += elem1 * elem2;
norm1 += elem1 * elem1;
norm2 += elem2 * elem2;
}
return (float) (sum / Math.sqrt((double) norm1 * (double) norm2));
}
/**
* Returns the sum of squared differences of the two vectors.
*
@ -135,7 +153,7 @@ public final class VectorUtil {
int dim = v1.length;
int i;
for (i = 0; i + 8 <= dim; i += 8) {
squareSum += squareDistanceUnrolled8(v1, v2, i);
squareSum += squareDistanceUnrolled(v1, v2, i);
}
for (; i < dim; i++) {
float diff = v1[i] - v2[i];
@ -144,7 +162,7 @@ public final class VectorUtil {
return squareSum;
}
private static float squareDistanceUnrolled8(float[] v1, float[] v2, int index) {
private static float squareDistanceUnrolled(float[] v1, float[] v2, int index) {
float diff0 = v1[index + 0] - v2[index + 0];
float diff1 = v1[index + 1] - v2[index + 1];
float diff2 = v1[index + 2] - v2[index + 2];
@ -163,6 +181,18 @@ public final class VectorUtil {
+ diff7 * diff7;
}
/** Returns the sum of squared differences of the two vectors. */
public static float squareDistance(BytesRef a, BytesRef b) {
// Note: this will not overflow if dim < 2^15, since the max squared difference of two signed
// bytes is 255^2 < 2^16.
int squareSum = 0;
int aOffset = a.offset, bOffset = b.offset;
for (int i = 0; i < a.length; i++) {
int diff = a.bytes[aOffset++] - b.bytes[bOffset++];
squareSum += diff * diff;
}
return squareSum;
}
/**
* Modifies the argument to be unit length, dividing by its l2-norm. IllegalArgumentException is
* thrown for zero vectors.
@ -213,4 +243,48 @@ public final class VectorUtil {
u[i] += v[i];
}
}
/**
* Dot product computed over signed bytes.
*
* @param a bytes containing a vector
* @param b bytes containing another vector, of the same dimension
* @return the value of the dot product of the two vectors
*/
public static float dotProduct(BytesRef a, BytesRef b) {
assert a.length == b.length;
int total = 0;
int aOffset = a.offset, bOffset = b.offset;
for (int i = 0; i < a.length; i++) {
total += a.bytes[aOffset++] * b.bytes[bOffset++];
}
return total;
}
/**
* Dot product score computed over signed bytes, scaled to be in [0, 1].
*
* @param a bytes containing a vector
* @param b bytes containing another vector, of the same dimension
* @return the value of the similarity function applied to the two vectors
*/
public static float dotProductScore(BytesRef a, BytesRef b) {
// divide by 2 * 2^14 (maximum absolute value of product of 2 signed bytes) * len
return (1 + dotProduct(a, b)) / (float) (a.length * (1 << 15));
}
/**
* Convert a floating point vector to an array of bytes using casting; the vector values should be
* in [-128,127]
*
* @param vector a vector
* @return a new BytesRef containing the vector's values cast to byte.
*/
public static BytesRef toBytesRef(float[] vector) {
BytesRef b = new BytesRef(new byte[vector.length]);
for (int i = 0; i < vector.length; i++) {
b.bytes[i] = (byte) vector[i];
}
return b;
}
}
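A sketch of preparing float vectors for the BYTE encoding with these helpers (the scale factor is an illustrative choice, not part of this change): values must land in [-128, 127] after scaling, and all vectors of a field must share one norm for DOT_PRODUCT.
float[] floats = new float[] {0.25f, -0.5f, 1.0f};
for (int i = 0; i < floats.length; i++) {
  floats[i] *= 64f; // scale into byte range before casting; 64 is an arbitrary example factor
}
BytesRef quantized = VectorUtil.toBytesRef(floats);
float selfScore = VectorUtil.dotProductScore(quantized, quantized);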
View File
@ -25,15 +25,19 @@ import java.util.Objects;
import java.util.SplittableRandom;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.InfoStream;
/**
* Builder for HNSW graph. See {@link HnswGraph} for a gloss on the algorithm and the meaning of the
* hyperparameters.
*
* @param <T> the type of vector
*/
public final class HnswGraphBuilder {
public final class HnswGraphBuilder<T> {
/** Default random seed for level generation * */
private static final long DEFAULT_RAND_SEED = 42;
@ -49,9 +53,10 @@ public final class HnswGraphBuilder {
private final NeighborArray scratch;
private final VectorSimilarityFunction similarityFunction;
private final VectorEncoding vectorEncoding;
private final RandomAccessVectorValues vectorValues;
private final SplittableRandom random;
private final HnswGraphSearcher graphSearcher;
private final HnswGraphSearcher<T> graphSearcher;
final OnHeapHnswGraph hnsw;
@ -59,7 +64,18 @@ public final class HnswGraphBuilder {
// we need two sources of vectors in order to perform diversity check comparisons without
// colliding
private RandomAccessVectorValues buildVectors;
private final RandomAccessVectorValues buildVectors;
public static HnswGraphBuilder<?> create(
RandomAccessVectorValuesProducer vectors,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction,
int M,
int beamWidth,
long seed)
throws IOException {
return new HnswGraphBuilder<>(vectors, vectorEncoding, similarityFunction, M, beamWidth, seed);
}
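A construction sketch using the new factory (M, beamWidth, and the producer are illustrative): create replaces the public constructor and selects the vector type from the encoding.
static OnHeapHnswGraph buildGraph(RandomAccessVectorValuesProducer vectors) throws IOException {
  HnswGraphBuilder<?> builder =
      HnswGraphBuilder.create(
          vectors,
          VectorEncoding.FLOAT32,
          VectorSimilarityFunction.DOT_PRODUCT,
          16, // M: max connections per node
          100, // beamWidth
          HnswGraphBuilder.randSeed);
  return builder.build(vectors.randomAccess());
}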
/**
* Reads all the vectors from a VectorValues, builds a graph connecting them by their dense
@ -73,8 +89,9 @@ public final class HnswGraphBuilder {
* @param seed the seed for a random number generator used during graph construction. Provide this
* to ensure repeatable construction.
*/
public HnswGraphBuilder(
private HnswGraphBuilder(
RandomAccessVectorValuesProducer vectors,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction,
int M,
int beamWidth,
@ -82,6 +99,7 @@ public final class HnswGraphBuilder {
throws IOException {
vectorValues = vectors.randomAccess();
buildVectors = vectors.randomAccess();
this.vectorEncoding = Objects.requireNonNull(vectorEncoding);
this.similarityFunction = Objects.requireNonNull(similarityFunction);
if (M <= 0) {
throw new IllegalArgumentException("maxConn must be positive");
@ -97,7 +115,8 @@ public final class HnswGraphBuilder {
int levelOfFirstNode = getRandomGraphLevel(ml, random);
this.hnsw = new OnHeapHnswGraph(M, levelOfFirstNode);
this.graphSearcher =
new HnswGraphSearcher(
new HnswGraphSearcher<>(
vectorEncoding,
similarityFunction,
new NeighborQueue(beamWidth, true),
new FixedBitSet(vectorValues.size()));
@ -110,7 +129,7 @@ public final class HnswGraphBuilder {
* enables efficient retrieval without extra data copying, while avoiding collision of the
* returned values.
*
* @param vectors the vectors for which to build a nearest neighbors graph. Must be an independet
* @param vectors the vectors for which to build a nearest neighbors graph. Must be an independent
* accessor for the vectors
*/
public OnHeapHnswGraph build(RandomAccessVectorValues vectors) throws IOException {
@ -121,15 +140,19 @@ public final class HnswGraphBuilder {
if (infoStream.isEnabled(HNSW_COMPONENT)) {
infoStream.message(HNSW_COMPONENT, "build graph from " + vectors.size() + " vectors");
}
addVectors(vectors);
return hnsw;
}
private void addVectors(RandomAccessVectorValues vectors) throws IOException {
long start = System.nanoTime(), t = start;
// start at node 1! node 0 is added implicitly, in the constructor
for (int node = 1; node < vectors.size(); node++) {
addGraphNode(node, vectors.vectorValue(node));
addGraphNode(node, vectors);
if ((node % 10000 == 0) && infoStream.isEnabled(HNSW_COMPONENT)) {
t = printGraphBuildStatus(node, start, t);
}
}
return hnsw;
}
/** Set info-stream to output debugging information * */
@ -142,7 +165,7 @@ public final class HnswGraphBuilder {
}
/** Inserts a doc with vector value to the graph */
public void addGraphNode(int node, float[] value) throws IOException {
public void addGraphNode(int node, T value) throws IOException {
NeighborQueue candidates;
final int nodeLevel = getRandomGraphLevel(ml, random);
int curMaxLevel = hnsw.numLevels() - 1;
@ -167,6 +190,18 @@ public final class HnswGraphBuilder {
}
}
public void addGraphNode(int node, RandomAccessVectorValues values) throws IOException {
addGraphNode(node, getValue(node, values));
}
@SuppressWarnings("unchecked")
private T getValue(int node, RandomAccessVectorValues values) throws IOException {
return switch (vectorEncoding) {
case BYTE -> (T) values.binaryValue(node);
case FLOAT32 -> (T) values.vectorValue(node);
};
}
private long printGraphBuildStatus(int node, long start, long t) {
long now = System.nanoTime();
infoStream.message(
@ -215,7 +250,7 @@ public final class HnswGraphBuilder {
int cNode = candidates.node[i];
float cScore = candidates.score[i];
assert cNode < hnsw.size();
if (diversityCheck(vectorValues.vectorValue(cNode), cScore, neighbors, buildVectors)) {
if (diversityCheck(cNode, cScore, neighbors)) {
neighbors.add(cNode, cScore);
}
}
@ -237,19 +272,38 @@ public final class HnswGraphBuilder {
* @param score the score of the new candidate and node n, to be compared with scores of the
* candidate and n's neighbors
* @param neighbors the neighbors selected so far
* @param vectorValues source of values used for making comparisons between candidate and existing
* neighbors
* @return whether the candidate is diverse given the existing neighbors
*/
private boolean diversityCheck(
float[] candidate,
float score,
NeighborArray neighbors,
RandomAccessVectorValues vectorValues)
private boolean diversityCheck(int candidate, float score, NeighborArray neighbors)
throws IOException {
return isDiverse(candidate, neighbors, score);
}
private boolean isDiverse(int candidate, NeighborArray neighbors, float score)
throws IOException {
return switch (vectorEncoding) {
case BYTE -> isDiverse(vectorValues.binaryValue(candidate), neighbors, score);
case FLOAT32 -> isDiverse(vectorValues.vectorValue(candidate), neighbors, score);
};
}
private boolean isDiverse(float[] candidate, NeighborArray neighbors, float score)
throws IOException {
for (int i = 0; i < neighbors.size(); i++) {
float neighborSimilarity =
similarityFunction.compare(candidate, vectorValues.vectorValue(neighbors.node[i]));
similarityFunction.compare(candidate, buildVectors.vectorValue(neighbors.node[i]));
if (neighborSimilarity >= score) {
return false;
}
}
return true;
}
private boolean isDiverse(BytesRef candidate, NeighborArray neighbors, float score)
throws IOException {
for (int i = 0; i < neighbors.size(); i++) {
float neighborSimilarity =
similarityFunction.compare(candidate, buildVectors.binaryValue(neighbors.node[i]));
if (neighborSimilarity >= score) {
return false;
}
@ -262,24 +316,52 @@ public final class HnswGraphBuilder {
* neighbours
*/
private int findWorstNonDiverse(NeighborArray neighbors) throws IOException {
float minAcceptedSimilarity;
for (int i = neighbors.size() - 1; i > 0; i--) {
int cNode = neighbors.node[i];
float[] cVector = vectorValues.vectorValue(cNode);
minAcceptedSimilarity = neighbors.score[i];
// check the candidate against its better-scoring neighbors
for (int j = i - 1; j >= 0; j--) {
float neighborSimilarity =
similarityFunction.compare(cVector, buildVectors.vectorValue(neighbors.node[j]));
// node i is too similar to node j given its score relative to the base node
if (neighborSimilarity >= minAcceptedSimilarity) {
return i;
}
if (isWorstNonDiverse(i, neighbors, neighbors.score[i])) {
return i;
}
}
return neighbors.size() - 1;
}
private boolean isWorstNonDiverse(
int candidate, NeighborArray neighbors, float minAcceptedSimilarity) throws IOException {
return switch (vectorEncoding) {
case BYTE -> isWorstNonDiverse(
candidate, vectorValues.binaryValue(neighbors.node[candidate]), neighbors, minAcceptedSimilarity);
case FLOAT32 -> isWorstNonDiverse(
candidate, vectorValues.vectorValue(neighbors.node[candidate]), neighbors, minAcceptedSimilarity);
};
}
private boolean isWorstNonDiverse(
int candidateIndex, float[] candidate, NeighborArray neighbors, float minAcceptedSimilarity)
throws IOException {
// check the candidate against its better-scoring neighbors
for (int i = candidateIndex - 1; i >= 0; i--) {
float neighborSimilarity =
similarityFunction.compare(candidate, buildVectors.vectorValue(neighbors.node[i]));
// the candidate is too similar to neighbor i, given its score relative to the base node
if (neighborSimilarity >= minAcceptedSimilarity) {
return true;
}
}
return false;
}
private boolean isWorstNonDiverse(
int candidateIndex, BytesRef candidate, NeighborArray neighbors, float minAcceptedSimilarity)
throws IOException {
// check the candidate against its better-scoring neighbors
for (int i = candidateIndex - 1; i >= 0; i--) {
float neighborSimilarity =
similarityFunction.compare(candidate, buildVectors.binaryValue(neighbors.node[i]));
// the candidate is too similar to neighbor i, given its score relative to the base node
if (neighborSimilarity >= minAcceptedSimilarity) {
return true;
}
}
return false;
}
private static int getRandomGraphLevel(double ml, SplittableRandom random) {
double randDouble;
do {
View File
@ -18,21 +18,28 @@
package org.apache.lucene.util.hnsw;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.util.VectorUtil.toBytesRef;
import java.io.IOException;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.SparseFixedBitSet;
/**
* Searches an HNSW graph to find nearest neighbors to a query vector. For more background on the
* search algorithm, see {@link HnswGraph}.
*
* @param <T> the type of query vector
*/
public final class HnswGraphSearcher {
public class HnswGraphSearcher<T> {
private final VectorSimilarityFunction similarityFunction;
private final VectorEncoding vectorEncoding;
/**
* Scratch data structures that are used in each {@link #searchLevel} call. These can be expensive
* to allocate, so they're cleared and reused across calls.
@ -49,7 +56,11 @@ public final class HnswGraphSearcher {
* @param visited bit set that will track nodes that have already been visited
*/
public HnswGraphSearcher(
VectorSimilarityFunction similarityFunction, NeighborQueue candidates, BitSet visited) {
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction,
NeighborQueue candidates,
BitSet visited) {
this.vectorEncoding = vectorEncoding;
this.similarityFunction = similarityFunction;
this.candidates = candidates;
this.visited = visited;
@ -73,13 +84,68 @@ public final class HnswGraphSearcher {
float[] query,
int topK,
RandomAccessVectorValues vectors,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction,
HnswGraph graph,
Bits acceptOrds,
int visitedLimit)
throws IOException {
HnswGraphSearcher graphSearcher =
new HnswGraphSearcher(
if (query.length != vectors.dimension()) {
throw new IllegalArgumentException(
"vector query dimension: "
+ query.length
+ " differs from field dimension: "
+ vectors.dimension());
}
if (vectorEncoding == VectorEncoding.BYTE) {
return search(
toBytesRef(query),
topK,
vectors,
vectorEncoding,
similarityFunction,
graph,
acceptOrds,
visitedLimit);
}
HnswGraphSearcher<float[]> graphSearcher =
new HnswGraphSearcher<>(
vectorEncoding,
similarityFunction,
new NeighborQueue(topK, true),
new SparseFixedBitSet(vectors.size()));
NeighborQueue results;
int[] eps = new int[] {graph.entryNode()};
int numVisited = 0;
for (int level = graph.numLevels() - 1; level >= 1; level--) {
results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graph, null, visitedLimit);
numVisited += results.visitedCount();
visitedLimit -= results.visitedCount();
if (results.incomplete()) {
results.setVisitedCount(numVisited);
return results;
}
eps[0] = results.pop();
}
results =
graphSearcher.searchLevel(query, topK, 0, eps, vectors, graph, acceptOrds, visitedLimit);
results.setVisitedCount(results.visitedCount() + numVisited);
return results;
}
private static NeighborQueue search(
BytesRef query,
int topK,
RandomAccessVectorValues vectors,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction,
HnswGraph graph,
Bits acceptOrds,
int visitedLimit)
throws IOException {
HnswGraphSearcher<BytesRef> graphSearcher =
new HnswGraphSearcher<>(
vectorEncoding,
similarityFunction,
new NeighborQueue(topK, true),
new SparseFixedBitSet(vectors.size()));
@ -119,7 +185,8 @@ public final class HnswGraphSearcher {
* @return a priority queue holding the closest neighbors found
*/
public NeighborQueue searchLevel(
float[] query,
// Note: this is only public because Lucene91HnswGraphBuilder needs it
T query,
int topK,
int level,
final int[] eps,
@ -130,7 +197,7 @@ public final class HnswGraphSearcher {
}
private NeighborQueue searchLevel(
float[] query,
T query,
int topK,
int level,
final int[] eps,
@ -150,7 +217,7 @@ public final class HnswGraphSearcher {
results.markIncomplete();
break;
}
float score = similarityFunction.compare(query, vectors.vectorValue(ep));
float score = compare(query, vectors, ep);
numVisited++;
candidates.add(ep, score);
if (acceptOrds == null || acceptOrds.get(ep)) {
@ -185,7 +252,7 @@ public final class HnswGraphSearcher {
results.markIncomplete();
break;
}
float friendSimilarity = similarityFunction.compare(query, vectors.vectorValue(friendOrd));
float friendSimilarity = compare(query, vectors, friendOrd);
numVisited++;
if (friendSimilarity >= minAcceptedSimilarity) {
candidates.add(friendOrd, friendSimilarity);
@ -204,6 +271,14 @@ public final class HnswGraphSearcher {
return results;
}
private float compare(T query, RandomAccessVectorValues vectors, int ord) throws IOException {
if (vectorEncoding == VectorEncoding.BYTE) {
return similarityFunction.compare((BytesRef) query, vectors.binaryValue(ord));
} else {
return similarityFunction.compare((float[]) query, vectors.vectorValue(ord));
}
}
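Editorial note: a hedged usage sketch of the new entry point, using only the signature added in this hunk (hnsw, vectors, and the query values are hypothetical). A float[] query against a BYTE-encoded graph is quantized via toBytesRef(query) inside search(...) before the level-by-level descent:

  NeighborQueue nn =
      HnswGraphSearcher.search(
          new float[] {0.1f, 0.9f},             // query vector
          10,                                   // topK
          vectors,                              // RandomAccessVectorValues
          VectorEncoding.BYTE,
          VectorSimilarityFunction.DOT_PRODUCT,
          hnsw,                                 // a previously built HnswGraph
          null,                                 // acceptOrds == null accepts all ords
          Integer.MAX_VALUE);                   // visitedLimit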
private void prepareScratchState(int capacity) {
candidates.clear();
if (visited.length() < capacity) {

View File

@ -178,7 +178,7 @@ public class TestPerFieldKnnVectorsFormat extends BaseKnnVectorsFormatTestCase {
return new KnnVectorsWriter() {
@Override
public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException {
public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
fieldsWritten.add(fieldInfo.name);
return writer.addField(fieldInfo);
}

View File

@ -112,6 +112,7 @@ public class TestCodecs extends LuceneTestCase {
0,
0,
0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN,
false));
}

View File

@ -260,6 +260,7 @@ public class TestFieldInfos extends LuceneTestCase {
0,
0,
0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN,
false));
}
@ -279,6 +280,7 @@ public class TestFieldInfos extends LuceneTestCase {
0,
0,
0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN,
false));
assertEquals("Field numbers 0 through 9 were allocated", 10, idx);
@ -300,6 +302,7 @@ public class TestFieldInfos extends LuceneTestCase {
0,
0,
0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN,
false));
assertEquals("Field numbers should reset after clear()", 0, idx);

View File

@ -64,6 +64,7 @@ public class TestFieldsReader extends LuceneTestCase {
0,
0,
0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN,
field.name().equals(softDeletesFieldName)));
}

View File

@ -113,6 +113,11 @@ public class TestIndexableField extends LuceneTestCase {
return 0;
}
@Override
public VectorEncoding vectorEncoding() {
return VectorEncoding.FLOAT32;
}
@Override
public VectorSimilarityFunction vectorSimilarityFunction() {
return VectorSimilarityFunction.EUCLIDEAN;

View File

@ -67,6 +67,8 @@ public class TestKnnGraph extends LuceneTestCase {
private static int M = Lucene94HnswVectorsFormat.DEFAULT_MAX_CONN;
private Codec codec;
private Codec float32Codec;
private VectorEncoding vectorEncoding;
private VectorSimilarityFunction similarityFunction;
@Before
@ -86,6 +88,31 @@ public class TestKnnGraph extends LuceneTestCase {
int similarity = random().nextInt(VectorSimilarityFunction.values().length - 1) + 1;
similarityFunction = VectorSimilarityFunction.values()[similarity];
vectorEncoding = randomVectorEncoding();
codec =
new Lucene94Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new Lucene94HnswVectorsFormat(M, Lucene94HnswVectorsFormat.DEFAULT_BEAM_WIDTH);
}
};
if (vectorEncoding == VectorEncoding.FLOAT32) {
float32Codec = codec;
} else {
float32Codec =
new Lucene94Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new Lucene94HnswVectorsFormat(M, Lucene94HnswVectorsFormat.DEFAULT_BEAM_WIDTH);
}
};
}
}
private VectorEncoding randomVectorEncoding() {
return VectorEncoding.values()[random().nextInt(VectorEncoding.values().length)];
}
@After
@ -102,10 +129,7 @@ public class TestKnnGraph extends LuceneTestCase {
float[][] values = new float[numDoc][];
for (int i = 0; i < numDoc; i++) {
if (random().nextBoolean()) {
values[i] = new float[dimension];
for (int j = 0; j < dimension; j++) {
values[i][j] = random().nextFloat();
}
values[i] = randomVector(dimension);
}
add(iw, i, values[i]);
}
@ -117,6 +141,14 @@ public class TestKnnGraph extends LuceneTestCase {
try (Directory dir = newDirectory();
IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig(null).setCodec(codec))) {
float[][] values = new float[][] {new float[] {0, 1, 2}};
if (similarityFunction == VectorSimilarityFunction.DOT_PRODUCT) {
VectorUtil.l2normalize(values[0]);
}
if (vectorEncoding == VectorEncoding.BYTE) {
for (int i = 0; i < 3; i++) {
values[0][i] = (float) Math.floor(values[0][i] * 127);
}
}
add(iw, 0, values[0]);
assertConsistentGraph(iw, values);
iw.commit();
@ -133,11 +165,7 @@ public class TestKnnGraph extends LuceneTestCase {
float[][] values = randomVectors(numDoc, dimension);
for (int i = 0; i < numDoc; i++) {
if (random().nextBoolean()) {
values[i] = new float[dimension];
for (int j = 0; j < dimension; j++) {
values[i][j] = random().nextFloat();
}
VectorUtil.l2normalize(values[i]);
values[i] = randomVector(dimension);
}
add(iw, i, values[i]);
if (random().nextInt(10) == 3) {
@ -249,16 +277,26 @@ public class TestKnnGraph extends LuceneTestCase {
float[][] values = new float[numDoc][];
for (int i = 0; i < numDoc; i++) {
if (random().nextBoolean()) {
values[i] = new float[dimension];
for (int j = 0; j < dimension; j++) {
values[i][j] = random().nextFloat();
}
VectorUtil.l2normalize(values[i]);
values[i] = randomVector(dimension);
}
}
return values;
}
private float[] randomVector(int dimension) {
float[] value = new float[dimension];
for (int j = 0; j < dimension; j++) {
value[j] = random().nextFloat();
}
VectorUtil.l2normalize(value);
if (vectorEncoding == VectorEncoding.BYTE) {
for (int j = 0; j < dimension; j++) {
value[j] = (byte) (value[j] * 127);
}
}
return value;
}
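Editorial note: the cast above cannot overflow because l2normalize leaves every component in [-1, 1], so value[j] * 127 falls within [-127, 127] before the (byte) cast.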
int[][][] copyGraph(HnswGraph graphValues) throws IOException {
int[][][] graph = new int[graphValues.numLevels()][][];
int size = graphValues.size();
@ -285,7 +323,7 @@ public class TestKnnGraph extends LuceneTestCase {
// We can't use dot product here since the vectors are laid out on a grid, not a sphere.
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
IndexWriterConfig config = newIndexWriterConfig();
config.setCodec(codec); // test is not compatible with simpletext
config.setCodec(float32Codec);
try (Directory dir = newDirectory();
IndexWriter iw = new IndexWriter(dir, config)) {
indexData(iw);
@ -341,7 +379,7 @@ public class TestKnnGraph extends LuceneTestCase {
public void testMultiThreadedSearch() throws Exception {
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
IndexWriterConfig config = newIndexWriterConfig();
config.setCodec(codec);
config.setCodec(float32Codec);
Directory dir = newDirectory();
IndexWriter iw = new IndexWriter(dir, config);
indexData(iw);
@ -468,7 +506,7 @@ public class TestKnnGraph extends LuceneTestCase {
"vector did not match for doc " + i + ", id=" + id + ": " + Arrays.toString(scratch),
values[id],
scratch,
0f);
0);
numDocsWithVectors++;
}
// if IndexDisi.doc == NO_MORE_DOCS, we should not call IndexDisi.nextDoc()

View File

@ -196,6 +196,7 @@ public class TestPendingSoftDeletes extends TestPendingDeletes {
0,
0,
0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN,
true);
List<Integer> docsDeleted = Arrays.asList(1, 3, 7, 8, DocIdSetIterator.NO_MORE_DOCS);
@ -233,6 +234,7 @@ public class TestPendingSoftDeletes extends TestPendingDeletes {
0,
0,
0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN,
true);
for (DocValuesFieldUpdates update : updates) {
@ -295,6 +297,7 @@ public class TestPendingSoftDeletes extends TestPendingDeletes {
0,
0,
0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN,
true);
List<Integer> docsDeleted = Arrays.asList(1, DocIdSetIterator.NO_MORE_DOCS);
@ -362,6 +365,7 @@ public class TestPendingSoftDeletes extends TestPendingDeletes {
0,
0,
0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN,
true);
List<DocValuesFieldUpdates> updates =
@ -398,6 +402,7 @@ public class TestPendingSoftDeletes extends TestPendingDeletes {
0,
0,
0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN,
true);
updates = Arrays.asList(singleUpdate(Arrays.asList(1, DocIdSetIterator.NO_MORE_DOCS), 3, true));

View File

@ -25,6 +25,7 @@ import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.document.Document;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.Query;
@ -117,6 +118,12 @@ public class TestSegmentToThreadMapping extends LuceneTestCase {
return null;
}
@Override
public TopDocs searchNearestVectorsExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
return null;
}
@Override
protected void doClose() {}

View File

@ -19,6 +19,7 @@ package org.apache.lucene.search;
import static com.carrotsearch.randomizedtesting.RandomizedTest.frequently;
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT;
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.util.TestVectorUtil.randomVector;
@ -40,7 +41,7 @@ import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.util.LuceneTestCase;
@ -48,6 +49,7 @@ import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.VectorUtil;
@ -174,7 +176,7 @@ public class TestKnnVectorQuery extends LuceneTestCase {
KnnVectorQuery kvq = new KnnVectorQuery("field", new float[] {0}, 10);
IllegalArgumentException e =
expectThrows(IllegalArgumentException.class, () -> searcher.search(kvq, 10));
assertEquals("vector dimensions differ: 1!=2", e.getMessage());
assertEquals("vector query dimension: 1 differs from field dimension: 2", e.getMessage());
}
}
@ -239,43 +241,38 @@ public class TestKnnVectorQuery extends LuceneTestCase {
}
public void testScoreEuclidean() throws IOException {
try (Directory d = newDirectory()) {
try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
for (int j = 0; j < 5; j++) {
Document doc = new Document();
doc.add(
new KnnVectorField("field", new float[] {j, j}, VectorSimilarityFunction.EUCLIDEAN));
w.addDocument(doc);
}
}
try (IndexReader reader = DirectoryReader.open(d)) {
assertEquals(1, reader.leaves().size());
IndexSearcher searcher = new IndexSearcher(reader);
KnnVectorQuery query = new KnnVectorQuery("field", new float[] {2, 3}, 3);
Query rewritten = query.rewrite(reader);
Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1);
Scorer scorer = weight.scorer(reader.leaves().get(0));
float[][] vectors = new float[5][];
for (int j = 0; j < 5; j++) {
vectors[j] = new float[] {j, j};
}
try (Directory d = getIndexStore("field", vectors);
IndexReader reader = DirectoryReader.open(d)) {
assertEquals(1, reader.leaves().size());
IndexSearcher searcher = new IndexSearcher(reader);
KnnVectorQuery query = new KnnVectorQuery("field", new float[] {2, 3}, 3);
Query rewritten = query.rewrite(reader);
Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1);
Scorer scorer = weight.scorer(reader.leaves().get(0));
// prior to advancing, score is 0
assertEquals(-1, scorer.docID());
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
// prior to advancing, score is 0
assertEquals(-1, scorer.docID());
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
// test getMaxScore
assertEquals(0, scorer.getMaxScore(-1), 0);
assertEquals(0, scorer.getMaxScore(0), 0);
// This is 1 / ((l2distance((2,3), (2, 2)) = 1) + 1) = 0.5
assertEquals(1 / 2f, scorer.getMaxScore(2), 0);
assertEquals(1 / 2f, scorer.getMaxScore(Integer.MAX_VALUE), 0);
// test getMaxScore
assertEquals(0, scorer.getMaxScore(-1), 0);
assertEquals(0, scorer.getMaxScore(0), 0);
// This is 1 / ((l2distance((2,3), (2, 2)) = 1) + 1) = 0.5
assertEquals(1 / 2f, scorer.getMaxScore(2), 0);
assertEquals(1 / 2f, scorer.getMaxScore(Integer.MAX_VALUE), 0);
DocIdSetIterator it = scorer.iterator();
assertEquals(3, it.cost());
assertEquals(1, it.nextDoc());
assertEquals(1 / 6f, scorer.score(), 0);
assertEquals(3, it.advance(3));
assertEquals(1 / 2f, scorer.score(), 0);
assertEquals(NO_MORE_DOCS, it.advance(4));
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
}
DocIdSetIterator it = scorer.iterator();
assertEquals(3, it.cost());
assertEquals(1, it.nextDoc());
assertEquals(1 / 6f, scorer.score(), 0);
assertEquals(3, it.advance(3));
assertEquals(1 / 2f, scorer.score(), 0);
assertEquals(NO_MORE_DOCS, it.advance(4));
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
}
}
@ -764,9 +761,18 @@ public class TestKnnVectorQuery extends LuceneTestCase {
private Directory getIndexStore(String field, float[]... contents) throws IOException {
Directory indexStore = newDirectory();
RandomIndexWriter writer = new RandomIndexWriter(random(), indexStore);
VectorEncoding encoding = randomVectorEncoding();
for (int i = 0; i < contents.length; ++i) {
Document doc = new Document();
doc.add(new KnnVectorField(field, contents[i]));
if (encoding == VectorEncoding.BYTE) {
BytesRef v = new BytesRef(new byte[contents[i].length]);
for (int j = 0; j < v.length; j++) {
v.bytes[j] = (byte) contents[i][j];
}
doc.add(new KnnVectorField(field, v, EUCLIDEAN));
} else {
doc.add(new KnnVectorField(field, contents[i]));
}
doc.add(new StringField("id", "id" + i, Field.Store.YES));
writer.addDocument(doc);
}
@ -908,4 +914,8 @@ public class TestKnnVectorQuery extends LuceneTestCase {
return 31 * classHash() + docs.hashCode();
}
}
private VectorEncoding randomVectorEncoding() {
return VectorEncoding.values()[random().nextInt(VectorEncoding.values().length)];
}
}

View File

@ -50,6 +50,7 @@ import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.SortField.Type;
import org.apache.lucene.store.Directory;
@ -1126,6 +1127,7 @@ public class TestSortOptimization extends LuceneTestCase {
0,
0,
0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.DOT_PRODUCT,
fi.isSoftDeletesField());
newInfos[i] = noIndexFI;

View File

@ -18,6 +18,7 @@ package org.apache.lucene.util;
import java.util.Random;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil;
public class TestVectorUtil extends LuceneTestCase {
@ -130,6 +131,23 @@ public class TestVectorUtil extends LuceneTestCase {
return u;
}
private static BytesRef negative(BytesRef v) {
BytesRef u = new BytesRef(new byte[v.length]);
for (int i = 0; i < v.length; i++) {
// note: (byte) -(-128) overflows back to -128; randomVectorBytes (below) clips inputs to -127 to avoid this
u.bytes[i] = (byte) -v.bytes[i];
}
return u;
}
private static float l2(BytesRef v) {
float l2 = 0;
for (int i = v.offset; i < v.offset + v.length; i++) {
l2 += v.bytes[i] * v.bytes[i];
}
return l2;
}
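Editorial note: despite its name, l2 returns the squared norm (the sum of squared components), not its square root; the byte-vector tests below rely on that.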
private static float[] randomVector() {
return randomVector(random().nextInt(100) + 1);
}
@ -142,4 +160,88 @@ public class TestVectorUtil extends LuceneTestCase {
}
return v;
}
private static BytesRef randomVectorBytes() {
BytesRef v = TestUtil.randomBinaryTerm(random(), TestUtil.nextInt(random(), 1, 100));
// clip at -127 to avoid overflow
for (int i = v.offset; i < v.offset + v.length; i++) {
if (v.bytes[i] == -128) {
v.bytes[i] = -127;
}
}
return v;
}
public void testBasicDotProductBytes() {
BytesRef a = new BytesRef(new byte[] {1, 2, 3});
BytesRef b = new BytesRef(new byte[] {-10, 0, 5});
assertEquals(5, VectorUtil.dotProduct(a, b), 0);
assertEquals(5 / (3f * (1 << 15)), VectorUtil.dotProductScore(a, b), DELTA);
}
public void testSelfDotProductBytes() {
// the dot product of a vector with itself is equal to the sum of the squares of its components
BytesRef v = randomVectorBytes();
assertEquals(l2(v), VectorUtil.dotProduct(v, v), DELTA);
}
public void testOrthogonalDotProductBytes() {
// the dot product of two perpendicular vectors is 0
byte[] v = new byte[4];
v[0] = (byte) random().nextInt(100);
v[1] = (byte) random().nextInt(100);
v[2] = v[1];
v[3] = (byte) -v[0];
// also test computing using BytesRef with nonzero offset
assertEquals(0, VectorUtil.dotProduct(new BytesRef(v, 0, 2), new BytesRef(v, 2, 2)), DELTA);
}
public void testSelfSquareDistanceBytes() {
// the l2 distance of a vector with itself is zero
BytesRef v = randomVectorBytes();
assertEquals(0, VectorUtil.squareDistance(v, v), DELTA);
}
public void testBasicSquareDistanceBytes() {
assertEquals(
12,
VectorUtil.squareDistance(
new BytesRef(new byte[] {1, 2, 3}), new BytesRef(new byte[] {-1, 0, 5})),
0);
}
public void testRandomSquareDistanceBytes() {
// the square distance of a vector with its inverse is equal to four times the sum of squares of
// its components
BytesRef v = randomVectorBytes();
BytesRef u = negative(v);
assertEquals(4 * l2(v), VectorUtil.squareDistance(u, v), DELTA);
}
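Editorial note, a worked check of the factor of four: with u = -v, squareDistance(u, v) = sum_i (v_i - (-v_i))^2 = sum_i (2 * v_i)^2 = 4 * sum_i v_i^2 = 4 * l2(v).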
public void testBasicCosineBytes() {
assertEquals(
0.11952f,
VectorUtil.cosine(new BytesRef(new byte[] {1, 2, 3}), new BytesRef(new byte[] {-10, 0, 5})),
DELTA);
}
public void testSelfCosineBytes() {
// the cosine of a vector with itself is always equal to 1
BytesRef v = randomVectorBytes();
// ensure the vector is non-zero so that cosine is defined
v.bytes[0] = (byte) (random().nextInt(126) + 1);
assertEquals(1.0f, VectorUtil.cosine(v, v), DELTA);
}
public void testOrthogonalCosineBytes() {
// the cosine of two perpendicular vectors is 0
float[] v = new float[2];
v[0] = random().nextInt(100);
// ensure the vector is non-zero so that cosine is defined
v[1] = random().nextInt(1, 100);
float[] u = new float[2];
u[0] = v[1];
u[1] = -v[0];
assertEquals(0, VectorUtil.cosine(u, v), DELTA);
}
}

View File

@ -56,6 +56,7 @@ import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.ConstantScoreScorer;
import org.apache.lucene.search.ConstantScoreWeight;
@ -101,6 +102,7 @@ public class KnnGraphTester {
private int beamWidth;
private int maxConn;
private VectorSimilarityFunction similarityFunction;
private VectorEncoding vectorEncoding;
private FixedBitSet matchDocs;
private float selectivity;
private boolean prefilter;
@ -113,6 +115,7 @@ public class KnnGraphTester {
topK = 100;
fanout = topK;
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
vectorEncoding = VectorEncoding.FLOAT32;
selectivity = 1f;
prefilter = false;
}
@ -195,12 +198,30 @@ public class KnnGraphTester {
case "-docs":
docVectorsPath = Paths.get(args[++iarg]);
break;
case "-encoding":
String encoding = args[++iarg];
switch (encoding) {
case "byte":
vectorEncoding = VectorEncoding.BYTE;
break;
case "float32":
vectorEncoding = VectorEncoding.FLOAT32;
break;
default:
throw new IllegalArgumentException("-encoding can be 'byte' or 'float32' only");
}
break;
case "-metric":
String metric = args[++iarg];
if (metric.equals("euclidean")) {
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
} else if (metric.equals("angular") == false) {
throw new IllegalArgumentException("-metric can be 'angular' or 'euclidean' only");
switch (metric) {
case "euclidean":
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
break;
case "angular":
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
break;
default:
throw new IllegalArgumentException("-metric can be 'angular' or 'euclidean' only");
}
break;
case "-forceMerge":
@ -229,7 +250,7 @@ public class KnnGraphTester {
if (operation == null && reindex == false) {
usage();
}
if (prefilter == true && selectivity == 1f) {
if (prefilter && selectivity == 1f) {
throw new IllegalArgumentException("-prefilter requires filterSelectivity between 0 and 1");
}
indexPath = Paths.get(formatIndexPath(docVectorsPath));
@ -248,7 +269,9 @@ public class KnnGraphTester {
if (docVectorsPath == null) {
throw new IllegalArgumentException("missing -docs arg");
}
matchDocs = generateRandomBitSet(numDocs, selectivity);
if (selectivity < 1) {
matchDocs = generateRandomBitSet(numDocs, selectivity);
}
if (outputPath != null) {
testSearch(indexPath, queryPath, outputPath, null);
} else {
@ -285,14 +308,17 @@ public class KnnGraphTester {
}
}
@SuppressWarnings("unchecked")
private void dumpGraph(Path docsPath) throws IOException {
try (BinaryFileVectors vectors = new BinaryFileVectors(docsPath)) {
RandomAccessVectorValues values = vectors.randomAccess();
HnswGraphBuilder builder =
new HnswGraphBuilder(vectors, similarityFunction, maxConn, beamWidth, 0);
HnswGraphBuilder<float[]> builder =
(HnswGraphBuilder<float[]>)
HnswGraphBuilder.create(
vectors, vectorEncoding, similarityFunction, maxConn, beamWidth, 0);
// start at node 1
for (int i = 1; i < numDocs; i++) {
builder.addGraphNode(i, values.vectorValue(i));
builder.addGraphNode(i, values);
System.out.println("\nITERATION " + i);
dumpGraph(builder.hnsw);
}
@ -375,13 +401,8 @@ public class KnnGraphTester {
throws IOException {
TopDocs[] results = new TopDocs[numIters];
long elapsed, totalCpuTime, totalVisited = 0;
try (FileChannel q = FileChannel.open(queryPath)) {
int bufferSize = numIters * dim * Float.BYTES;
FloatBuffer targets =
q.map(FileChannel.MapMode.READ_ONLY, 0, bufferSize)
.order(ByteOrder.LITTLE_ENDIAN)
.asFloatBuffer();
float[] target = new float[dim];
try (FileChannel input = FileChannel.open(queryPath)) {
VectorReader targetReader = VectorReader.create(input, dim, vectorEncoding, numIters);
if (quiet == false) {
System.out.println("running " + numIters + " targets; topK=" + topK + ", fanout=" + fanout);
}
@ -392,21 +413,21 @@ public class KnnGraphTester {
DirectoryReader reader = DirectoryReader.open(dir)) {
IndexSearcher searcher = new IndexSearcher(reader);
numDocs = reader.maxDoc();
Query bitSetQuery = new BitSetQuery(matchDocs);
Query bitSetQuery = prefilter ? new BitSetQuery(matchDocs) : null;
for (int i = 0; i < numIters; i++) {
// warm up
targets.get(target);
float[] target = targetReader.next();
if (prefilter) {
doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, bitSetQuery);
} else {
doKnnVectorQuery(searcher, KNN_FIELD, target, (int) (topK / selectivity), fanout, null);
}
}
targets.position(0);
targetReader.reset();
start = System.nanoTime();
cpuTimeStartNs = bean.getCurrentThreadCpuTime();
for (int i = 0; i < numIters; i++) {
targets.get(target);
float[] target = targetReader.next();
if (prefilter) {
results[i] = doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, bitSetQuery);
} else {
@ -414,10 +435,12 @@ public class KnnGraphTester {
doKnnVectorQuery(
searcher, KNN_FIELD, target, (int) (topK / selectivity), fanout, null);
results[i].scoreDocs =
Arrays.stream(results[i].scoreDocs)
.filter(scoreDoc -> matchDocs == null || matchDocs.get(scoreDoc.doc))
.toArray(ScoreDoc[]::new);
if (matchDocs != null) {
results[i].scoreDocs =
Arrays.stream(results[i].scoreDocs)
.filter(scoreDoc -> matchDocs.get(scoreDoc.doc))
.toArray(ScoreDoc[]::new);
}
}
}
totalCpuTime = (bean.getCurrentThreadCpuTime() - cpuTimeStartNs) / 1_000_000;
@ -425,7 +448,14 @@ public class KnnGraphTester {
for (int i = 0; i < numIters; i++) {
totalVisited += results[i].totalHits.value;
for (ScoreDoc doc : results[i].scoreDocs) {
doc.doc = Integer.parseInt(reader.document(doc.doc).get("id"));
if (doc.doc != NO_MORE_DOCS) {
// there is a bug somewhere that can result in doc=NO_MORE_DOCS! I think it happens
// in some degenerate case (like input query has NaN in it?) that causes no results to
// be returned from HNSW search?
doc.doc = Integer.parseInt(reader.document(doc.doc).get("id"));
} else {
System.out.println("NO_MORE_DOCS!");
}
}
}
}
@ -477,6 +507,78 @@ public class KnnGraphTester {
}
}
private abstract static class VectorReader {
final float[] target;
final ByteBuffer bytes;
static VectorReader create(FileChannel input, int dim, VectorEncoding vectorEncoding, int n)
throws IOException {
int bufferSize = n * dim * vectorEncoding.byteSize;
return switch (vectorEncoding) {
case BYTE -> new VectorReaderByte(input, dim, bufferSize);
case FLOAT32 -> new VectorReaderFloat32(input, dim, bufferSize);
};
}
VectorReader(FileChannel input, int dim, int bufferSize) throws IOException {
bytes =
input.map(FileChannel.MapMode.READ_ONLY, 0, bufferSize).order(ByteOrder.LITTLE_ENDIAN);
target = new float[dim];
}
void reset() {
bytes.position(0);
}
abstract float[] next();
}
private static class VectorReaderFloat32 extends VectorReader {
private final FloatBuffer floats;
VectorReaderFloat32(FileChannel input, int dim, int bufferSize) throws IOException {
super(input, dim, bufferSize);
floats = bytes.asFloatBuffer();
}
@Override
void reset() {
super.reset();
floats.position(0);
}
@Override
float[] next() {
floats.get(target);
return target;
}
}
private static class VectorReaderByte extends VectorReader {
private byte[] scratch;
private BytesRef bytesRef;
VectorReaderByte(FileChannel input, int dim, int bufferSize) throws IOException {
super(input, dim, bufferSize);
scratch = new byte[dim];
bytesRef = new BytesRef(scratch);
}
@Override
float[] next() {
bytes.get(scratch);
for (int i = 0; i < scratch.length; i++) {
target[i] = scratch[i];
}
return target;
}
BytesRef nextBytes() {
bytes.get(scratch);
return bytesRef;
}
}
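Editorial note: a hedged usage sketch of the readers above (the path and dimension are hypothetical). One loop serves both encodings because next() always widens to float[], while nextBytes() exposes the raw bytes for indexing:

  // Sketch, not part of the patch:
  try (FileChannel in = FileChannel.open(Paths.get("docs.vec"))) {
    VectorReader reader = VectorReader.create(in, 256, VectorEncoding.BYTE, numDocs);
    for (int i = 0; i < numDocs; i++) {
      float[] v = reader.next(); // byte components widened to float
      // ... score or index v ...
    }
    reader.reset(); // rewind the mapped buffer for another pass
  }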
private static TopDocs doKnnVectorQuery(
IndexSearcher searcher, String field, float[] vector, int k, int fanout, Query filter)
throws IOException {
@ -529,7 +631,9 @@ public class KnnGraphTester {
if (Files.exists(nnPath) && isNewer(nnPath, docPath, queryPath) && selectivity == 1f) {
return readNN(nnPath);
} else {
int[][] nn = computeNN(docPath, queryPath);
// TODO: enable computing NN from high precision vectors when
// checking low-precision recall
int[][] nn = computeNN(docPath, queryPath, vectorEncoding);
if (selectivity == 1f) {
writeNN(nn, nnPath);
}
@ -589,52 +693,37 @@ public class KnnGraphTester {
return bitSet;
}
private int[][] computeNN(Path docPath, Path queryPath) throws IOException {
private int[][] computeNN(Path docPath, Path queryPath, VectorEncoding encoding)
throws IOException {
int[][] result = new int[numIters][];
if (quiet == false) {
System.out.println("computing true nearest neighbors of " + numIters + " target vectors");
}
try (FileChannel in = FileChannel.open(docPath);
FileChannel qIn = FileChannel.open(queryPath)) {
FloatBuffer queries =
qIn.map(FileChannel.MapMode.READ_ONLY, 0, numIters * dim * Float.BYTES)
.order(ByteOrder.LITTLE_ENDIAN)
.asFloatBuffer();
float[] vector = new float[dim];
float[] query = new float[dim];
VectorReader docReader = VectorReader.create(in, dim, encoding, numDocs);
VectorReader queryReader = VectorReader.create(qIn, dim, encoding, numIters);
for (int i = 0; i < numIters; i++) {
queries.get(query);
long totalBytes = (long) numDocs * dim * Float.BYTES;
final int maxBlockSize = (Integer.MAX_VALUE / (dim * Float.BYTES)) * (dim * Float.BYTES);
int offset = 0;
int j = 0;
// System.out.println("totalBytes=" + totalBytes);
while (j < numDocs) {
int blockSize = (int) Math.min(totalBytes - offset, maxBlockSize);
FloatBuffer vectors =
in.map(FileChannel.MapMode.READ_ONLY, offset, blockSize)
.order(ByteOrder.LITTLE_ENDIAN)
.asFloatBuffer();
offset += blockSize;
NeighborQueue queue = new NeighborQueue(topK, false);
for (; j < numDocs && vectors.hasRemaining(); j++) {
vectors.get(vector);
float d = similarityFunction.compare(query, vector);
if (matchDocs == null || matchDocs.get(j)) {
queue.insertWithOverflow(j, d);
}
}
result[i] = new int[topK];
for (int k = topK - 1; k >= 0; k--) {
result[i][k] = queue.topNode();
queue.pop();
// System.out.print(" " + n);
}
if (quiet == false && (i + 1) % 10 == 0) {
System.out.print(" " + (i + 1));
System.out.flush();
float[] query = queryReader.next();
NeighborQueue queue = new NeighborQueue(topK, false);
for (int j = 0; j < numDocs; j++) {
float[] doc = docReader.next();
float d = similarityFunction.compare(query, doc);
if (matchDocs == null || matchDocs.get(j)) {
queue.insertWithOverflow(j, d);
}
}
docReader.reset();
result[i] = new int[topK];
for (int k = topK - 1; k >= 0; k--) {
result[i][k] = queue.topNode();
queue.pop();
// System.out.print(" " + n);
}
if (quiet == false && (i + 1) % 10 == 0) {
System.out.print(" " + (i + 1));
System.out.flush();
}
}
}
return result;
@ -651,37 +740,29 @@ public class KnnGraphTester {
});
// iwc.setMergePolicy(NoMergePolicy.INSTANCE);
iwc.setRAMBufferSizeMB(1994d);
iwc.setUseCompoundFile(false);
// iwc.setMaxBufferedDocs(10000);
FieldType fieldType = KnnVectorField.createFieldType(dim, similarityFunction);
FieldType fieldType = KnnVectorField.createFieldType(dim, vectorEncoding, similarityFunction);
if (quiet == false) {
iwc.setInfoStream(new PrintStreamInfoStream(System.out));
System.out.println("creating index in " + indexPath);
}
long start = System.nanoTime();
long totalBytes = (long) numDocs * dim * Float.BYTES, offset = 0;
final int maxBlockSize = (Integer.MAX_VALUE / (dim * Float.BYTES)) * (dim * Float.BYTES);
try (FSDirectory dir = FSDirectory.open(indexPath);
IndexWriter iw = new IndexWriter(dir, iwc)) {
float[] vector = new float[dim];
try (FileChannel in = FileChannel.open(docsPath)) {
int i = 0;
while (i < numDocs) {
int blockSize = (int) Math.min(totalBytes - offset, maxBlockSize);
FloatBuffer vectors =
in.map(FileChannel.MapMode.READ_ONLY, offset, blockSize)
.order(ByteOrder.LITTLE_ENDIAN)
.asFloatBuffer();
offset += blockSize;
for (; vectors.hasRemaining() && i < numDocs; i++) {
vectors.get(vector);
Document doc = new Document();
// System.out.println("vector=" + vector[0] + "," + vector[1] + "...");
doc.add(new KnnVectorField(KNN_FIELD, vector, fieldType));
doc.add(new StoredField(ID_FIELD, i));
iw.addDocument(doc);
VectorReader vectorReader = VectorReader.create(in, dim, vectorEncoding, numDocs);
for (int i = 0; i < numDocs; i++) {
Document doc = new Document();
switch (vectorEncoding) {
case BYTE -> doc.add(
new KnnVectorField(
KNN_FIELD, ((VectorReaderByte) vectorReader).nextBytes(), fieldType));
case FLOAT32 -> doc.add(new KnnVectorField(KNN_FIELD, vectorReader.next(), fieldType));
}
doc.add(new StoredField(ID_FIELD, i));
iw.addDocument(doc);
}
if (quiet == false) {
System.out.println("Done indexing " + numDocs + " documents; now flush");

View File

@ -31,6 +31,7 @@ class MockVectorValues extends VectorValues
protected final float[][] denseValues;
protected final float[][] values;
private final int numVectors;
private final BytesRef binaryValue;
private int pos = -1;
@ -47,6 +48,9 @@ class MockVectorValues extends VectorValues
}
numVectors = count;
scratch = new float[dimension];
// used by tests that build a graph from bytes rather than floats
binaryValue = new BytesRef(dimension);
binaryValue.length = dimension;
}
public MockVectorValues copy() {
@ -89,7 +93,11 @@ class MockVectorValues extends VectorValues
@Override
public BytesRef binaryValue(int targetOrd) {
return null;
float[] value = vectorValue(targetOrd);
for (int i = 0; i < value.length; i++) {
binaryValue.bytes[i] = (byte) value[i];
}
return binaryValue;
}
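Editorial note: the (byte) cast above truncates toward zero and wraps outside [-128, 127]; the byte-oriented tests in this patch pre-scale values into [-127, 127] (see the various * 127 loops), so the cast is exact for them.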
private boolean seek(int target) {

View File

@ -18,6 +18,7 @@
package org.apache.lucene.util.hnsw;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.util.VectorUtil.toBytesRef;
import java.io.IOException;
import java.util.ArrayList;
@ -43,6 +44,7 @@ import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.IndexSearcher;
@ -60,25 +62,38 @@ import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
import org.junit.Before;
/** Tests HNSW KNN graphs */
public class TestHnswGraph extends LuceneTestCase {
VectorSimilarityFunction similarityFunction;
VectorEncoding vectorEncoding;
@Before
public void setup() {
similarityFunction =
VectorSimilarityFunction.values()[
random().nextInt(VectorSimilarityFunction.values().length - 1) + 1];
if (similarityFunction == VectorSimilarityFunction.DOT_PRODUCT) {
vectorEncoding =
VectorEncoding.values()[random().nextInt(VectorEncoding.values().length - 1) + 1];
} else {
vectorEncoding = VectorEncoding.FLOAT32;
}
}
// test writing out and reading in a graph gives the expected graph
public void testReadWrite() throws IOException {
int dim = random().nextInt(100) + 1;
int nDoc = random().nextInt(100) + 1;
RandomVectorValues vectors = new RandomVectorValues(nDoc, dim, random());
RandomVectorValues v2 = vectors.copy(), v3 = vectors.copy();
int M = random().nextInt(10) + 5;
int M = random().nextInt(4) + 2;
int beamWidth = random().nextInt(10) + 5;
long seed = random().nextLong();
VectorSimilarityFunction similarityFunction =
VectorSimilarityFunction.values()[
random().nextInt(VectorSimilarityFunction.values().length - 1) + 1];
HnswGraphBuilder builder =
new HnswGraphBuilder(vectors, similarityFunction, M, beamWidth, seed);
RandomVectorValues vectors = new RandomVectorValues(nDoc, dim, vectorEncoding, random());
RandomVectorValues v2 = vectors.copy(), v3 = vectors.copy();
HnswGraphBuilder<?> builder =
HnswGraphBuilder.create(vectors, vectorEncoding, similarityFunction, M, beamWidth, seed);
HnswGraph hnsw = builder.build(vectors);
// Recreate the graph while indexing with the same random seed and write it out
@ -131,6 +146,10 @@ public class TestHnswGraph extends LuceneTestCase {
}
}
private VectorEncoding randomVectorEncoding() {
return VectorEncoding.values()[random().nextInt(VectorEncoding.values().length)];
}
// test that sorted index returns the same search results are unsorted
public void testSortedAndUnsortedIndicesReturnSameResults() throws IOException {
int dim = random().nextInt(10) + 3;
@ -250,24 +269,27 @@ public class TestHnswGraph extends LuceneTestCase {
// oriented in the right directions
public void testAknnDiverse() throws IOException {
int nDoc = 100;
vectorEncoding = randomVectorEncoding();
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
CircularVectorValues vectors = new CircularVectorValues(nDoc);
HnswGraphBuilder builder =
new HnswGraphBuilder(
vectors, VectorSimilarityFunction.DOT_PRODUCT, 10, 100, random().nextInt());
HnswGraphBuilder<?> builder =
HnswGraphBuilder.create(
vectors, vectorEncoding, similarityFunction, 10, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors);
// run some searches
NeighborQueue nn =
HnswGraphSearcher.search(
new float[] {1, 0},
getTargetVector(),
10,
vectors.randomAccess(),
VectorSimilarityFunction.DOT_PRODUCT,
vectorEncoding,
similarityFunction,
hnsw,
null,
Integer.MAX_VALUE);
int[] nodes = nn.nodes();
assertTrue("Number of found results is not equal to [10].", nodes.length == 10);
assertEquals("Number of found results is not equal to [10].", 10, nodes.length);
int sum = 0;
for (int node : nodes) {
sum += node;
@ -289,23 +311,26 @@ public class TestHnswGraph extends LuceneTestCase {
public void testSearchWithAcceptOrds() throws IOException {
int nDoc = 100;
CircularVectorValues vectors = new CircularVectorValues(nDoc);
HnswGraphBuilder builder =
new HnswGraphBuilder(
vectors, VectorSimilarityFunction.DOT_PRODUCT, 16, 100, random().nextInt());
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
vectorEncoding = randomVectorEncoding();
HnswGraphBuilder<?> builder =
HnswGraphBuilder.create(
vectors, vectorEncoding, similarityFunction, 16, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors);
// the first 10 docs must not be deleted to ensure the expected recall
Bits acceptOrds = createRandomAcceptOrds(10, vectors.size);
NeighborQueue nn =
HnswGraphSearcher.search(
new float[] {1, 0},
getTargetVector(),
10,
vectors.randomAccess(),
VectorSimilarityFunction.DOT_PRODUCT,
vectorEncoding,
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
int[] nodes = nn.nodes();
assertTrue("Number of found results is not equal to [10].", nodes.length == 10);
assertEquals("Number of found results is not equal to [10].", 10, nodes.length);
int sum = 0;
for (int node : nodes) {
assertTrue("the results include a deleted document: " + node, acceptOrds.get(node));
@ -319,9 +344,11 @@ public class TestHnswGraph extends LuceneTestCase {
public void testSearchWithSelectiveAcceptOrds() throws IOException {
int nDoc = 100;
CircularVectorValues vectors = new CircularVectorValues(nDoc);
HnswGraphBuilder builder =
new HnswGraphBuilder(
vectors, VectorSimilarityFunction.DOT_PRODUCT, 16, 100, random().nextInt());
vectorEncoding = randomVectorEncoding();
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
HnswGraphBuilder<?> builder =
HnswGraphBuilder.create(
vectors, vectorEncoding, similarityFunction, 16, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors);
// Only mark a few vectors as accepted
BitSet acceptOrds = new FixedBitSet(vectors.size);
@ -333,10 +360,11 @@ public class TestHnswGraph extends LuceneTestCase {
int numAccepted = acceptOrds.cardinality();
NeighborQueue nn =
HnswGraphSearcher.search(
new float[] {1, 0},
getTargetVector(),
numAccepted,
vectors.randomAccess(),
VectorSimilarityFunction.DOT_PRODUCT,
vectorEncoding,
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
@ -347,12 +375,17 @@ public class TestHnswGraph extends LuceneTestCase {
}
}
private float[] getTargetVector() {
return new float[] {1, 0};
}
public void testSearchWithSkewedAcceptOrds() throws IOException {
int nDoc = 1000;
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
CircularVectorValues vectors = new CircularVectorValues(nDoc);
HnswGraphBuilder builder =
new HnswGraphBuilder(
vectors, VectorSimilarityFunction.EUCLIDEAN, 16, 100, random().nextInt());
HnswGraphBuilder<?> builder =
HnswGraphBuilder.create(
vectors, VectorEncoding.FLOAT32, similarityFunction, 16, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors);
// Skip over half of the documents that are closest to the query vector
@ -362,15 +395,16 @@ public class TestHnswGraph extends LuceneTestCase {
}
NeighborQueue nn =
HnswGraphSearcher.search(
new float[] {1, 0},
getTargetVector(),
10,
vectors.randomAccess(),
VectorSimilarityFunction.EUCLIDEAN,
VectorEncoding.FLOAT32,
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
int[] nodes = nn.nodes();
assertTrue("Number of found results is not equal to [10].", nodes.length == 10);
assertEquals("Number of found results is not equal to [10].", 10, nodes.length);
int sum = 0;
for (int node : nodes) {
assertTrue("the results include a deleted document: " + node, acceptOrds.get(node));
@ -383,20 +417,23 @@ public class TestHnswGraph extends LuceneTestCase {
public void testVisitedLimit() throws IOException {
int nDoc = 500;
vectorEncoding = randomVectorEncoding();
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
CircularVectorValues vectors = new CircularVectorValues(nDoc);
HnswGraphBuilder builder =
new HnswGraphBuilder(
vectors, VectorSimilarityFunction.DOT_PRODUCT, 16, 100, random().nextInt());
HnswGraphBuilder<?> builder =
HnswGraphBuilder.create(
vectors, vectorEncoding, similarityFunction, 16, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors);
int topK = 50;
int visitedLimit = topK + random().nextInt(5);
NeighborQueue nn =
HnswGraphSearcher.search(
new float[] {1, 0},
getTargetVector(),
topK,
vectors.randomAccess(),
VectorSimilarityFunction.DOT_PRODUCT,
vectorEncoding,
similarityFunction,
hnsw,
createRandomAcceptOrds(0, vectors.size),
visitedLimit);
@ -406,54 +443,68 @@ public class TestHnswGraph extends LuceneTestCase {
}
public void testHnswGraphBuilderInvalid() {
expectThrows(NullPointerException.class, () -> new HnswGraphBuilder(null, null, 0, 0, 0));
expectThrows(
NullPointerException.class, () -> HnswGraphBuilder.create(null, null, null, 0, 0, 0));
// M must be > 0
expectThrows(
IllegalArgumentException.class,
() ->
new HnswGraphBuilder(
HnswGraphBuilder.create(
new RandomVectorValues(1, 1, random()),
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN,
0,
10,
0));
// beamWidth must be > 0
expectThrows(
IllegalArgumentException.class,
() ->
new HnswGraphBuilder(
HnswGraphBuilder.create(
new RandomVectorValues(1, 1, random()),
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN,
10,
0,
0));
}
@SuppressWarnings("unchecked")
public void testDiversity() throws IOException {
vectorEncoding = randomVectorEncoding();
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
// Some carefully checked test cases with simple 2d vectors on the unit circle:
MockVectorValues vectors =
new MockVectorValues(
new float[][] {
unitVector2d(0.5),
unitVector2d(0.75),
unitVector2d(0.2),
unitVector2d(0.9),
unitVector2d(0.8),
unitVector2d(0.77),
});
float[][] values = {
unitVector2d(0.5),
unitVector2d(0.75),
unitVector2d(0.2),
unitVector2d(0.9),
unitVector2d(0.8),
unitVector2d(0.77),
};
if (vectorEncoding == VectorEncoding.BYTE) {
for (float[] v : values) {
for (int i = 0; i < v.length; i++) {
v[i] *= 127;
}
}
}
MockVectorValues vectors = new MockVectorValues(values);
// First add nodes until everybody gets a full neighbor list
HnswGraphBuilder builder =
new HnswGraphBuilder(
vectors, VectorSimilarityFunction.DOT_PRODUCT, 2, 10, random().nextInt());
HnswGraphBuilder<?> builder =
HnswGraphBuilder.create(
vectors, vectorEncoding, similarityFunction, 2, 10, random().nextInt());
// node 0 is added by the builder constructor
// builder.addGraphNode(vectors.vectorValue(0));
builder.addGraphNode(1, vectors.vectorValue(1));
builder.addGraphNode(2, vectors.vectorValue(2));
builder.addGraphNode(1, vectors);
builder.addGraphNode(2, vectors);
// now every node has tried to attach every other node as a neighbor, but
// some were excluded based on diversity check.
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
assertLevel0Neighbors(builder.hnsw, 1, 0);
assertLevel0Neighbors(builder.hnsw, 2, 0);
builder.addGraphNode(3, vectors.vectorValue(3));
builder.addGraphNode(3, vectors);
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
// we added 3 here
assertLevel0Neighbors(builder.hnsw, 1, 0, 3);
@ -461,7 +512,7 @@ public class TestHnswGraph extends LuceneTestCase {
assertLevel0Neighbors(builder.hnsw, 3, 1);
// supplant an existing neighbor
builder.addGraphNode(4, vectors.vectorValue(4));
builder.addGraphNode(4, vectors);
// 4 is the same distance from 0 that 2 is; we leave the existing node in place
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
assertLevel0Neighbors(builder.hnsw, 1, 0, 3, 4);
@ -470,7 +521,7 @@ public class TestHnswGraph extends LuceneTestCase {
assertLevel0Neighbors(builder.hnsw, 3, 1, 4);
assertLevel0Neighbors(builder.hnsw, 4, 1, 3);
builder.addGraphNode(5, vectors.vectorValue(5));
builder.addGraphNode(5, vectors);
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
assertLevel0Neighbors(builder.hnsw, 1, 0, 3, 4, 5);
assertLevel0Neighbors(builder.hnsw, 2, 0);
@ -494,29 +545,46 @@ public class TestHnswGraph extends LuceneTestCase {
public void testRandom() throws IOException {
int size = atLeast(100);
int dim = atLeast(10);
RandomVectorValues vectors = new RandomVectorValues(size, dim, random());
VectorSimilarityFunction similarityFunction =
VectorSimilarityFunction.values()[
random().nextInt(VectorSimilarityFunction.values().length - 1) + 1];
RandomVectorValues vectors = new RandomVectorValues(size, dim, vectorEncoding, random());
int topK = 5;
HnswGraphBuilder builder =
new HnswGraphBuilder(vectors, similarityFunction, 10, 30, random().nextLong());
HnswGraphBuilder<?> builder =
HnswGraphBuilder.create(
vectors, vectorEncoding, similarityFunction, 10, 30, random().nextLong());
OnHeapHnswGraph hnsw = builder.build(vectors);
Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(0, size);
int totalMatches = 0;
for (int i = 0; i < 100; i++) {
float[] query = randomVector(random(), dim);
NeighborQueue actual =
NeighborQueue actual;
float[] query;
BytesRef bQuery = null;
if (vectorEncoding == VectorEncoding.BYTE) {
query = randomVector8(random(), dim);
bQuery = toBytesRef(query);
} else {
query = randomVector(random(), dim);
}
actual =
HnswGraphSearcher.search(
query, 100, vectors, similarityFunction, hnsw, acceptOrds, Integer.MAX_VALUE);
query,
100,
vectors,
vectorEncoding,
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
while (actual.size() > topK) {
actual.pop();
}
NeighborQueue expected = new NeighborQueue(topK, false);
for (int j = 0; j < size; j++) {
if (vectors.vectorValue(j) != null && (acceptOrds == null || acceptOrds.get(j))) {
expected.add(j, similarityFunction.compare(query, vectors.vectorValue(j)));
if (vectorEncoding == VectorEncoding.BYTE) {
expected.add(j, similarityFunction.compare(bQuery, vectors.binaryValue(j)));
} else {
expected.add(j, similarityFunction.compare(query, vectors.vectorValue(j)));
}
if (expected.size() > topK) {
expected.pop();
}
@ -553,12 +621,14 @@ public class TestHnswGraph extends LuceneTestCase {
implements RandomAccessVectorValues, RandomAccessVectorValuesProducer {
private final int size;
private final float[] value;
private final BytesRef binaryValue;
int doc = -1;
CircularVectorValues(int size) {
this.size = size;
value = new float[2];
binaryValue = new BytesRef(new byte[2]);
}
public CircularVectorValues copy() {
@ -617,7 +687,11 @@ public class TestHnswGraph extends LuceneTestCase {
@Override
public BytesRef binaryValue(int ord) {
return null;
float[] vectorValue = vectorValue(ord);
for (int i = 0; i < vectorValue.length; i++) {
binaryValue.bytes[i] = (byte) (vectorValue[i] * 127);
}
return binaryValue;
}
}
@ -648,8 +722,9 @@ public class TestHnswGraph extends LuceneTestCase {
if (uDoc == NO_MORE_DOCS) {
break;
}
float delta = vectorEncoding == VectorEncoding.BYTE ? 1 : 1e-4f;
assertArrayEquals(
"vectors do not match for doc=" + uDoc, u.vectorValue(), v.vectorValue(), 1e-4f);
"vectors do not match for doc=" + uDoc, u.vectorValue(), v.vectorValue(), delta);
}
}
@ -657,7 +732,11 @@ public class TestHnswGraph extends LuceneTestCase {
static class RandomVectorValues extends MockVectorValues {
RandomVectorValues(int size, int dimension, Random random) {
super(createRandomVectors(size, dimension, random));
super(createRandomVectors(size, dimension, null, random));
}
RandomVectorValues(int size, int dimension, VectorEncoding vectorEncoding, Random random) {
super(createRandomVectors(size, dimension, vectorEncoding, random));
}
RandomVectorValues(RandomVectorValues other) {
@ -669,11 +748,21 @@ public class TestHnswGraph extends LuceneTestCase {
return new RandomVectorValues(this);
}
private static float[][] createRandomVectors(int size, int dimension, Random random) {
private static float[][] createRandomVectors(
int size, int dimension, VectorEncoding vectorEncoding, Random random) {
float[][] vectors = new float[size][];
for (int offset = 0; offset < size; offset += random.nextInt(3) + 1) {
vectors[offset] = randomVector(random, dimension);
}
if (vectorEncoding == VectorEncoding.BYTE) {
for (float[] vector : vectors) {
if (vector != null) {
for (int i = 0; i < vector.length; i++) {
vector[i] = (byte) (127 * vector[i]);
}
}
}
}
return vectors;
}
}
@ -701,8 +790,19 @@ public class TestHnswGraph extends LuceneTestCase {
float[] vec = new float[dim];
for (int i = 0; i < dim; i++) {
vec[i] = random.nextFloat();
if (random.nextBoolean()) {
vec[i] = -vec[i];
}
}
VectorUtil.l2normalize(vec);
return vec;
}
private static float[] randomVector8(Random random, int dim) {
float[] fvec = randomVector(random, dim);
for (int i = 0; i < dim; i++) {
fvec[i] *= 127;
}
return fvec;
}
}

View File

@ -34,8 +34,10 @@ import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.index.StoredFieldVisitor;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.Version;
@ -97,6 +99,7 @@ public class TermVectorLeafReader extends LeafReader {
0,
0,
0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN,
false);
fieldInfos = new FieldInfos(new FieldInfo[] {fieldInfo});
@ -166,6 +169,12 @@ public class TermVectorLeafReader extends LeafReader {
return null;
}
@Override
public TopDocs searchNearestVectorsExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
return null;
}
@Override
public void checkIntegrity() throws IOException {}

View File

@ -37,6 +37,7 @@ import org.apache.lucene.document.FieldType;
import org.apache.lucene.index.*;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorable;
@ -514,6 +515,7 @@ public class MemoryIndex {
fieldType.pointIndexDimensionCount(),
fieldType.pointNumBytes(),
fieldType.vectorDimension(),
fieldType.vectorEncoding(),
fieldType.vectorSimilarityFunction(),
false);
}
@ -546,6 +548,7 @@ public class MemoryIndex {
info.fieldInfo.getPointIndexDimensionCount(),
info.fieldInfo.getPointNumBytes(),
info.fieldInfo.getVectorDimension(),
info.fieldInfo.getVectorEncoding(),
info.fieldInfo.getVectorSimilarityFunction(),
info.fieldInfo.isSoftDeletesField());
} else if (existingDocValuesType != docValuesType) {
@ -1371,6 +1374,12 @@ public class MemoryIndex {
return null;
}
@Override
public TopDocs searchNearestVectorsExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
return null;
}
@Override
public void checkIntegrity() throws IOException {
// no-op

View File

@ -29,6 +29,7 @@ import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.Bits;
@ -61,7 +62,7 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
}
@Override
public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException {
public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
return delegate.addField(fieldInfo);
}
@ -131,6 +132,18 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
return hits;
}
@Override
public TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
FieldInfo fi = fis.fieldInfo(field);
assert fi != null && fi.getVectorDimension() > 0;
assert acceptDocs != null;
TopDocs hits = delegate.searchExhaustively(field, target, k, acceptDocs);
assert hits != null;
assert hits.scoreDocs.length <= k;
return hits;
}
@Override
public void close() throws IOException {
delegate.close();

View File

@ -36,6 +36,7 @@ import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.IndexableFieldType;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.index.SegmentInfo;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.internal.tests.IndexPackageAccess;
@ -305,6 +306,7 @@ public abstract class BaseFieldInfoFormatTestCase extends BaseIndexFileFormatTes
fieldType.pointIndexDimensionCount(),
fieldType.pointNumBytes(),
fieldType.vectorDimension(),
fieldType.vectorEncoding(),
fieldType.vectorSimilarityFunction(),
field.equals(softDeletesField));
addAttributes(fi);
@ -353,7 +355,8 @@ public abstract class BaseFieldInfoFormatTestCase extends BaseIndexFileFormatTes
int dimension = 1 + r.nextInt(VectorValues.MAX_DIMENSIONS);
VectorSimilarityFunction similarityFunction =
RandomPicks.randomFrom(r, VectorSimilarityFunction.values());
type.setVectorDimensionsAndSimilarityFunction(dimension, similarityFunction);
VectorEncoding encoding = RandomPicks.randomFrom(r, VectorEncoding.values());
type.setVectorAttributes(dimension, encoding, similarityFunction);
}
return type;
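The rename to setVectorAttributes reflects that a vector field now carries three attributes: dimension, encoding, and similarity function. A minimal usage sketch follows; the dimension and similarity chosen here are arbitrary example values, not defaults from this change.

import org.apache.lucene.document.FieldType;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;

static FieldType byteVectorType() {
  FieldType type = new FieldType();
  // 16 dims, 8-bit components, dot-product similarity (arbitrary example values)
  type.setVectorAttributes(16, VectorEncoding.BYTE, VectorSimilarityFunction.DOT_PRODUCT);
  type.freeze(); // lock the configuration before the type is shared
  return type;
}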
@ -422,6 +425,7 @@ public abstract class BaseFieldInfoFormatTestCase extends BaseIndexFileFormatTes
0,
0,
0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN,
false);
}

View File

@ -360,6 +360,7 @@ abstract class BaseIndexFileFormatTestCase extends LuceneTestCase {
proto.getPointIndexDimensionCount(),
proto.getPointNumBytes(),
proto.getVectorDimension(),
proto.getVectorEncoding(),
proto.getVectorSimilarityFunction(),
proto.isSoftDeletesField());

View File

@ -41,6 +41,7 @@ import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.Sort;
@ -51,8 +52,10 @@ import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.VectorUtil;
import org.junit.Before;
/**
* Base class aiming at testing {@link KnnVectorsFormat vector formats}. To test a new format, all
@ -63,9 +66,21 @@ import org.apache.lucene.util.VectorUtil;
*/
public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTestCase {
private VectorEncoding vectorEncoding;
private VectorSimilarityFunction similarityFunction;
@Before
public void init() {
vectorEncoding = randomVectorEncoding();
similarityFunction = randomSimilarity();
}
@Override
protected void addRandomFields(Document doc) {
doc.add(new KnnVectorField("v2", randomVector(30), VectorSimilarityFunction.EUCLIDEAN));
switch (vectorEncoding) {
case BYTE -> doc.add(new KnnVectorField("v2", randomVector8(30), similarityFunction));
case FLOAT32 -> doc.add(new KnnVectorField("v2", randomVector(30), similarityFunction));
}
}
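As the switch above shows, BYTE-encoded fields hand KnnVectorField a BytesRef whose bytes are the vector components, while FLOAT32 keeps the float[] path. A minimal indexing sketch using the BytesRef constructor exercised by this test; the field name and component values are arbitrary.

import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.BytesRef;

static Document byteVectorDoc() {
  // each byte is one vector component, already quantized into [-128, 127]
  BytesRef quantized = new BytesRef(new byte[] {12, -7, 53, 127});
  Document doc = new Document();
  doc.add(new KnnVectorField("v", quantized, VectorSimilarityFunction.DOT_PRODUCT));
  return doc;
}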
public void testFieldConstructor() {
@ -133,8 +148,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
IllegalArgumentException expected =
expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc2));
String errMsg =
"cannot change field \"f\" from vector dimension=4, vector similarity function=DOT_PRODUCT "
+ "to inconsistent vector dimension=3, vector similarity function=DOT_PRODUCT";
"cannot change field \"f\" from vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
+ "to inconsistent vector dimension=3, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT";
assertEquals(errMsg, expected.getMessage());
}
}
@ -170,8 +185,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
IllegalArgumentException expected =
expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc2));
String errMsg =
"cannot change field \"f\" from vector dimension=4, vector similarity function=DOT_PRODUCT "
+ "to inconsistent vector dimension=4, vector similarity function=EUCLIDEAN";
"cannot change field \"f\" from vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
+ "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=EUCLIDEAN";
assertEquals(errMsg, expected.getMessage());
}
}
@ -190,8 +205,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
IllegalArgumentException expected =
expectThrows(IllegalArgumentException.class, () -> w2.addDocument(doc2));
assertEquals(
"cannot change field \"f\" from vector dimension=4, vector similarity function=DOT_PRODUCT "
+ "to inconsistent vector dimension=1, vector similarity function=DOT_PRODUCT",
"cannot change field \"f\" from vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
+ "to inconsistent vector dimension=1, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT",
expected.getMessage());
}
}
@ -211,8 +226,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
IllegalArgumentException expected =
expectThrows(IllegalArgumentException.class, () -> w2.addDocument(doc2));
assertEquals(
"cannot change field \"f\" from vector dimension=4, vector similarity function=DOT_PRODUCT "
+ "to inconsistent vector dimension=4, vector similarity function=EUCLIDEAN",
"cannot change field \"f\" from vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
+ "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=EUCLIDEAN",
expected.getMessage());
}
}
@ -311,8 +326,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
expectThrows(
IllegalArgumentException.class, () -> w2.addIndexes(new Directory[] {dir}));
assertEquals(
"cannot change field \"f\" from vector dimension=5, vector similarity function=DOT_PRODUCT "
+ "to inconsistent vector dimension=4, vector similarity function=DOT_PRODUCT",
"cannot change field \"f\" from vector dimension=5, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
+ "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT",
expected.getMessage());
}
}
@ -333,8 +348,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
IllegalArgumentException expected =
expectThrows(IllegalArgumentException.class, () -> w2.addIndexes(dir));
assertEquals(
"cannot change field \"f\" from vector dimension=4, vector similarity function=EUCLIDEAN "
+ "to inconsistent vector dimension=4, vector similarity function=DOT_PRODUCT",
"cannot change field \"f\" from vector dimension=4, vector encoding=FLOAT32, vector similarity function=EUCLIDEAN "
+ "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT",
expected.getMessage());
}
}
@ -358,8 +373,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
IllegalArgumentException.class,
() -> w2.addIndexes(new CodecReader[] {(CodecReader) getOnlyLeafReader(r)}));
assertEquals(
"cannot change field \"f\" from vector dimension=5, vector similarity function=DOT_PRODUCT "
+ "to inconsistent vector dimension=4, vector similarity function=DOT_PRODUCT",
"cannot change field \"f\" from vector dimension=5, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
+ "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT",
expected.getMessage());
}
}
@ -384,8 +399,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
IllegalArgumentException.class,
() -> w2.addIndexes(new CodecReader[] {(CodecReader) getOnlyLeafReader(r)}));
assertEquals(
"cannot change field \"f\" from vector dimension=4, vector similarity function=EUCLIDEAN "
+ "to inconsistent vector dimension=4, vector similarity function=DOT_PRODUCT",
"cannot change field \"f\" from vector dimension=4, vector encoding=FLOAT32, vector similarity function=EUCLIDEAN "
+ "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT",
expected.getMessage());
}
}
@ -408,8 +423,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
IllegalArgumentException expected =
expectThrows(IllegalArgumentException.class, () -> TestUtil.addIndexesSlowly(w2, r));
assertEquals(
"cannot change field \"f\" from vector dimension=5, vector similarity function=DOT_PRODUCT "
+ "to inconsistent vector dimension=4, vector similarity function=DOT_PRODUCT",
"cannot change field \"f\" from vector dimension=5, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
+ "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT",
expected.getMessage());
}
}
@ -432,8 +447,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
IllegalArgumentException expected =
expectThrows(IllegalArgumentException.class, () -> TestUtil.addIndexesSlowly(w2, r));
assertEquals(
"cannot change field \"f\" from vector dimension=4, vector similarity function=EUCLIDEAN "
+ "to inconsistent vector dimension=4, vector similarity function=DOT_PRODUCT",
"cannot change field \"f\" from vector dimension=4, vector encoding=FLOAT32, vector similarity function=EUCLIDEAN "
+ "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT",
expected.getMessage());
}
}
@ -596,12 +611,12 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
int[] fieldDocCounts = new int[numFields];
double[] fieldTotals = new double[numFields];
int[] fieldDims = new int[numFields];
VectorSimilarityFunction[] fieldSearchStrategies = new VectorSimilarityFunction[numFields];
VectorSimilarityFunction[] fieldSimilarityFunctions = new VectorSimilarityFunction[numFields];
VectorEncoding[] fieldVectorEncodings = new VectorEncoding[numFields];
for (int i = 0; i < numFields; i++) {
fieldDims[i] = random().nextInt(20) + 1;
fieldSearchStrategies[i] =
VectorSimilarityFunction.values()[
random().nextInt(VectorSimilarityFunction.values().length)];
fieldSimilarityFunctions[i] = randomSimilarity();
fieldVectorEncodings[i] = randomVectorEncoding();
}
try (Directory dir = newDirectory();
RandomIndexWriter w = new RandomIndexWriter(random(), dir, newIndexWriterConfig())) {
@ -610,15 +625,23 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
for (int field = 0; field < numFields; field++) {
String fieldName = "int" + field;
if (random().nextInt(100) == 17) {
float[] v = randomVector(fieldDims[field]);
doc.add(new KnnVectorField(fieldName, v, fieldSearchStrategies[field]));
switch (fieldVectorEncodings[field]) {
case BYTE -> {
BytesRef b = randomVector8(fieldDims[field]);
doc.add(new KnnVectorField(fieldName, b, fieldSimilarityFunctions[field]));
fieldTotals[field] += b.bytes[b.offset];
}
case FLOAT32 -> {
float[] v = randomVector(fieldDims[field]);
doc.add(new KnnVectorField(fieldName, v, fieldSimilarityFunctions[field]));
fieldTotals[field] += v[0];
}
}
fieldDocCounts[field]++;
fieldTotals[field] += v[0];
}
}
w.addDocument(doc);
}
try (IndexReader r = w.getReader()) {
for (int field = 0; field < numFields; field++) {
int docCount = 0;
@ -634,12 +657,29 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
}
}
assertEquals(fieldDocCounts[field], docCount);
assertEquals(fieldTotals[field], checksum, 1e-5);
// Account for quantization done when indexing fields w/BYTE encoding
double delta = fieldVectorEncodings[field] == VectorEncoding.BYTE ? numDocs * 0.01 : 1e-5;
assertEquals(fieldTotals[field], checksum, delta);
}
}
}
}
private VectorSimilarityFunction randomSimilarity() {
return VectorSimilarityFunction.values()[
random().nextInt(VectorSimilarityFunction.values().length)];
}
private VectorEncoding randomVectorEncoding() {
Codec codec = getCodec();
if (codec.knnVectorsFormat().currentVersion()
>= Codec.forName("Lucene94").knnVectorsFormat().currentVersion()) {
return VectorEncoding.values()[random().nextInt(VectorEncoding.values().length)];
} else {
return VectorEncoding.FLOAT32;
}
}
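BYTE encoding is only representable from the Lucene94 format onward, so the helper above pins FLOAT32 whenever the codec under test is older; this keeps the randomized tests runnable against backward codecs that have no on-disk slot for an encoding.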
public void testIndexedValueNotAliased() throws Exception {
// We copy indexed values (as for BinaryDocValues) so the input float[] can be reused across
// calls to IndexWriter.addDocument.
@ -742,7 +782,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
assertEquals(3, vectorValues3.dimension());
assertEquals(1, vectorValues3.size());
vectorValues3.nextDoc();
assertEquals(1f, vectorValues3.vectorValue()[0], 0);
assertEquals(1f, vectorValues3.vectorValue()[0], 0.1);
assertEquals(NO_MORE_DOCS, vectorValues3.nextDoc());
}
}
@ -775,9 +815,9 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
if (random().nextBoolean() && values[i] != null) {
// sometimes use a shared scratch array
System.arraycopy(values[i], 0, scratch, 0, scratch.length);
add(iw, fieldName, i, scratch, VectorSimilarityFunction.EUCLIDEAN);
add(iw, fieldName, i, scratch, similarityFunction);
} else {
add(iw, fieldName, i, values[i], VectorSimilarityFunction.EUCLIDEAN);
add(iw, fieldName, i, values[i], similarityFunction);
}
if (random().nextInt(10) == 2) {
// sometimes delete a random document
@ -898,7 +938,6 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
int numDoc = atLeast(100);
int dimension = atLeast(10);
float[][] id2value = new float[numDoc][];
int[] id2ord = new int[numDoc];
for (int i = 0; i < numDoc; i++) {
int id = random().nextInt(numDoc);
float[] value;
@ -909,7 +948,6 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
value = null;
}
id2value[id] = value;
id2ord[id] = i;
add(iw, fieldName, id, value, VectorSimilarityFunction.EUCLIDEAN);
}
try (IndexReader reader = DirectoryReader.open(iw)) {
@ -1007,6 +1045,15 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
return v;
}
private BytesRef randomVector8(int dim) {
float[] v = randomVector(dim);
byte[] b = new byte[dim];
for (int i = 0; i < dim; i++) {
b[i] = (byte) (v[i] * 127);
}
return new BytesRef(b);
}
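The scale factor 127 maps a component assumed to lie in [-1, 1] onto the signed-byte range, and the cast truncates toward zero, so each component loses strictly less than 1/127 ≈ 0.0079 of precision. A self-contained sketch of that round trip (the values are arbitrary):

static void quantizationRoundTripDemo() {
  float x = 0.4321f;
  byte q = (byte) (x * 127); // 0.4321 * 127 = 54.87..., truncated to 54
  float back = q / 127f;     // 54 / 127 = 0.4252...
  // per-component error stays below 1/127, which is why the assertions in this
  // test case widen their deltas (0.1, numDocs * 0.01) for BYTE-encoded fields
  assert Math.abs(x - back) < 1f / 127;
}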
public void testCheckIndexIncludesVectors() throws Exception {
try (Directory dir = newDirectory()) {
try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
@ -1041,6 +1088,14 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
assertEquals(3, VectorSimilarityFunction.values().length);
}
public void testVectorEncodingOrdinals() {
// make sure we don't accidentally change the vector encoding identifiers by re-ordering the
// enum constants
assertEquals(0, VectorEncoding.BYTE.ordinal());
assertEquals(1, VectorEncoding.FLOAT32.ordinal());
assertEquals(2, VectorEncoding.values().length);
}
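This guards a serialization invariant: the encoding reaches disk as its enum ordinal, so reordering the constants would silently change what stored indices mean when read back. A sketch of the pattern being protected; the exact field-infos wire format is not shown in this excerpt, so the single-byte encoding below is an illustrative assumption.

import java.io.IOException;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.DataOutput;

static void writeEncoding(DataOutput out, VectorEncoding encoding) throws IOException {
  out.writeByte((byte) encoding.ordinal()); // the ordinal is what persists
}

static VectorEncoding readEncoding(DataInput in) throws IOException {
  // values()[ord] is exactly why re-ordering the constants would corrupt old indices
  return VectorEncoding.values()[in.readByte()];
}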
public void testAdvance() throws Exception {
try (Directory dir = newDirectory()) {
try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
@ -1091,10 +1146,6 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
public void testVectorValuesReportCorrectDocs() throws Exception {
final int numDocs = atLeast(1000);
final int dim = random().nextInt(20) + 1;
final VectorSimilarityFunction similarityFunction =
VectorSimilarityFunction.values()[
random().nextInt(VectorSimilarityFunction.values().length)];
double fieldValuesCheckSum = 0;
int fieldDocCount = 0;
long fieldSumDocIDs = 0;
@ -1106,9 +1157,18 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
int docID = random().nextInt(numDocs);
doc.add(new StoredField("id", docID));
if (random().nextInt(4) == 3) {
float[] vector = randomVector(dim);
doc.add(new KnnVectorField("knn_vector", vector, similarityFunction));
fieldValuesCheckSum += vector[0];
switch (vectorEncoding) {
case BYTE -> {
BytesRef b = randomVector8(dim);
fieldValuesCheckSum += b.bytes[b.offset];
doc.add(new KnnVectorField("knn_vector", b, similarityFunction));
}
case FLOAT32 -> {
float[] v = randomVector(dim);
fieldValuesCheckSum += v[0];
doc.add(new KnnVectorField("knn_vector", v, similarityFunction));
}
}
fieldDocCount++;
fieldSumDocIDs += docID;
}
@ -1134,7 +1194,10 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
}
}
}
assertEquals(fieldValuesCheckSum, checksum, 1e-3);
assertEquals(
fieldValuesCheckSum,
checksum,
vectorEncoding == VectorEncoding.BYTE ? numDocs * 0.2 : 1e-5);
assertEquals(fieldDocCount, docCount);
assertEquals(fieldSumDocIDs, sumDocIds);
}

View File

@ -40,6 +40,7 @@ import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.index.StoredFieldVisitor;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
@ -228,6 +229,12 @@ class MergeReaderWrapper extends LeafReader {
return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
}
@Override
public TopDocs searchNearestVectorsExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
return in.searchNearestVectorsExhaustively(field, target, k, acceptDocs);
}
@Override
public int numDocs() {
return in.numDocs();

View File

@ -88,6 +88,7 @@ public class MismatchedLeafReader extends FilterLeafReader {
oldInfo.getPointIndexDimensionCount(), // index dimension count
oldInfo.getPointNumBytes(), // dimension numBytes
oldInfo.getVectorDimension(), // number of dimensions of the field's vector
oldInfo.getVectorEncoding(), // numeric type of vector samples
// distance function for calculating similarity of the field's vector
oldInfo.getVectorSimilarityFunction(),
oldInfo.isSoftDeletesField()); // used as soft-deletes field

View File

@ -62,6 +62,7 @@ import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.TermState;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.tests.IndexPackageAccess;
import org.apache.lucene.internal.tests.TestSecrets;
@ -163,6 +164,7 @@ public class RandomPostingsTester {
0,
0,
0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN,
false);
fieldUpto++;
@ -734,6 +736,7 @@ public class RandomPostingsTester {
0,
0,
0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN,
false);
}

View File

@ -233,6 +233,12 @@ public class QueryUtils {
return null;
}
@Override
public TopDocs searchNearestVectorsExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
return null;
}
@Override
public FieldInfos getFieldInfos() {
return FieldInfos.EMPTY;