LUCENE-10577: enable quantization of HNSW vectors to 8 bits (#1054)

* LUCENE-10577: enable supplying, storing, and comparing HNSW vectors with 8 bit precision
Michael Sokolov 2022-08-10 17:09:07 -04:00 committed by GitHub
parent 59a0917e25
commit a693fe819b
75 changed files with 2230 additions and 488 deletions

View File

@@ -30,6 +30,7 @@ import org.apache.lucene.index.FieldInfos;
 import org.apache.lucene.index.IndexFileNames;
 import org.apache.lucene.index.IndexOptions;
 import org.apache.lucene.index.SegmentInfo;
+import org.apache.lucene.index.VectorEncoding;
 import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.store.ChecksumIndexInput;
 import org.apache.lucene.store.DataOutput;
@@ -214,6 +215,7 @@ public final class Lucene60FieldInfosFormat extends FieldInfosFormat {
                 pointIndexDimensionCount,
                 pointNumBytes,
                 0,
+                VectorEncoding.FLOAT32,
                 VectorSimilarityFunction.EUCLIDEAN,
                 isSoftDeletesField);
       } catch (IllegalStateException e) {

View File

@@ -32,7 +32,6 @@ import org.apache.lucene.codecs.StoredFieldsFormat;
 import org.apache.lucene.codecs.TermVectorsFormat;
 import org.apache.lucene.codecs.lucene90.Lucene90CompoundFormat;
 import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
-import org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat;
 import org.apache.lucene.codecs.lucene90.Lucene90LiveDocsFormat;
 import org.apache.lucene.codecs.lucene90.Lucene90NormsFormat;
 import org.apache.lucene.codecs.lucene90.Lucene90PointsFormat;

View File

@@ -14,7 +14,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.lucene.codecs.lucene90;
+package org.apache.lucene.backward_codecs.lucene90;

 import java.io.IOException;
 import java.util.Collections;
@@ -29,6 +29,7 @@ import org.apache.lucene.index.FieldInfos;
 import org.apache.lucene.index.IndexFileNames;
 import org.apache.lucene.index.IndexOptions;
 import org.apache.lucene.index.SegmentInfo;
+import org.apache.lucene.index.VectorEncoding;
 import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.store.ChecksumIndexInput;
 import org.apache.lucene.store.DataOutput;
@@ -191,6 +192,7 @@ public final class Lucene90FieldInfosFormat extends FieldInfosFormat {
                 pointIndexDimensionCount,
                 pointNumBytes,
                 vectorDimension,
+                VectorEncoding.FLOAT32,
                 vectorDistFunc,
                 isSoftDeletesField);
           infos[i].checkConsistency();

View File

@@ -18,6 +18,7 @@
 package org.apache.lucene.backward_codecs.lucene90;

 import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+import static org.apache.lucene.search.TopDocsCollector.EMPTY_TOPDOCS;

 import java.io.IOException;
 import java.nio.ByteBuffer;
@@ -36,6 +37,7 @@ import org.apache.lucene.index.RandomAccessVectorValuesProducer;
 import org.apache.lucene.index.SegmentReadState;
 import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.index.VectorValues;
+import org.apache.lucene.search.DocIdSetIterator;
 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.search.TotalHits;
@@ -277,6 +279,21 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
     return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
   }

+  @Override
+  public TopDocs searchExhaustively(
+      String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
+    FieldEntry fieldEntry = fields.get(field);
+    if (fieldEntry == null) {
+      // The field does not exist or does not index vectors
+      return EMPTY_TOPDOCS;
+    }
+    VectorSimilarityFunction similarityFunction = fieldEntry.similarityFunction;
+    VectorValues vectorValues = getVectorValues(field);
+    return exhaustiveSearch(vectorValues, acceptDocs, similarityFunction, target, k);
+  }
+
   private OffHeapVectorValues getOffHeapVectorValues(FieldEntry fieldEntry) throws IOException {
     IndexInput bytesSlice =
         vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);

View File

@@ -17,6 +17,7 @@
 package org.apache.lucene.backward_codecs.lucene91;

 import java.util.Objects;
+import org.apache.lucene.backward_codecs.lucene90.Lucene90FieldInfosFormat;
 import org.apache.lucene.codecs.Codec;
 import org.apache.lucene.codecs.CompoundFormat;
 import org.apache.lucene.codecs.DocValuesFormat;
@@ -32,7 +33,6 @@ import org.apache.lucene.codecs.StoredFieldsFormat;
 import org.apache.lucene.codecs.TermVectorsFormat;
 import org.apache.lucene.codecs.lucene90.Lucene90CompoundFormat;
 import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
-import org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat;
 import org.apache.lucene.codecs.lucene90.Lucene90LiveDocsFormat;
 import org.apache.lucene.codecs.lucene90.Lucene90NormsFormat;
 import org.apache.lucene.codecs.lucene90.Lucene90PointsFormat;

View File

@@ -25,6 +25,7 @@ import java.util.Objects;
 import java.util.SplittableRandom;
 import org.apache.lucene.index.RandomAccessVectorValues;
 import org.apache.lucene.index.RandomAccessVectorValuesProducer;
+import org.apache.lucene.index.VectorEncoding;
 import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.util.FixedBitSet;
 import org.apache.lucene.util.InfoStream;
@@ -55,7 +56,7 @@ public final class Lucene91HnswGraphBuilder {
   private final RandomAccessVectorValues vectorValues;
   private final SplittableRandom random;
   private final Lucene91BoundsChecker bound;
-  private final HnswGraphSearcher graphSearcher;
+  private final HnswGraphSearcher<float[]> graphSearcher;

   final Lucene91OnHeapHnswGraph hnsw;
@@ -101,7 +102,8 @@ public final class Lucene91HnswGraphBuilder {
     int levelOfFirstNode = getRandomGraphLevel(ml, random);
     this.hnsw = new Lucene91OnHeapHnswGraph(maxConn, levelOfFirstNode);
     this.graphSearcher =
-        new HnswGraphSearcher(
+        new HnswGraphSearcher<>(
+            VectorEncoding.FLOAT32,
             similarityFunction,
             new NeighborQueue(beamWidth, true),
             new FixedBitSet(vectorValues.size()));

View File

@@ -18,6 +18,7 @@
 package org.apache.lucene.backward_codecs.lucene91;

 import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+import static org.apache.lucene.search.TopDocsCollector.EMPTY_TOPDOCS;

 import java.io.IOException;
 import java.nio.ByteBuffer;
@@ -34,8 +35,10 @@ import org.apache.lucene.index.IndexFileNames;
 import org.apache.lucene.index.RandomAccessVectorValues;
 import org.apache.lucene.index.RandomAccessVectorValuesProducer;
 import org.apache.lucene.index.SegmentReadState;
+import org.apache.lucene.index.VectorEncoding;
 import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.index.VectorValues;
+import org.apache.lucene.search.DocIdSetIterator;
 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.search.TotalHits;
@@ -244,6 +247,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
             target,
             k,
             vectorValues,
+            VectorEncoding.FLOAT32,
             fieldEntry.similarityFunction,
             getGraph(fieldEntry),
             getAcceptOrds(acceptDocs, fieldEntry),
@@ -265,6 +269,21 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
     return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
   }

+  @Override
+  public TopDocs searchExhaustively(
+      String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
+    FieldEntry fieldEntry = fields.get(field);
+    if (fieldEntry == null) {
+      // The field does not exist or does not index vectors
+      return EMPTY_TOPDOCS;
+    }
+    VectorSimilarityFunction similarityFunction = fieldEntry.similarityFunction;
+    VectorValues vectorValues = getVectorValues(field);
+    return exhaustiveSearch(vectorValues, acceptDocs, similarityFunction, target, k);
+  }
+
   private OffHeapVectorValues getOffHeapVectorValues(FieldEntry fieldEntry) throws IOException {
     IndexInput bytesSlice =
         vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);

View File

@@ -144,8 +144,8 @@
  *   <li>{@link org.apache.lucene.codecs.lucene90.Lucene90SegmentInfoFormat Segment info}. This
  *       contains metadata about a segment, such as the number of documents, what files it uses, and
  *       information about how the segment is sorted
- *   <li>{@link org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat Field names}. This
- *       contains metadata about the set of named fields used in the index.
+ *   <li>{@link org.apache.lucene.backward_codecs.lucene90.Lucene90FieldInfosFormat Field names}.
+ *       This contains metadata about the set of named fields used in the index.
  *   <li>{@link org.apache.lucene.codecs.lucene90.Lucene90StoredFieldsFormat Stored Field values}.
  *       This contains, for each document, a list of attribute-value pairs, where the attributes are
  *       field names. These are used to store auxiliary information about the document, such as its
@@ -240,7 +240,7 @@
  * systems that frequently run out of file handles.</td>
  * </tr>
  * <tr>
- * <td>{@link org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat Fields}</td>
+ * <td>{@link org.apache.lucene.backward_codecs.lucene90.Lucene90FieldInfosFormat Fields}</td>
  * <td>.fnm</td>
  * <td>Stores information about the fields</td>
 * </tr>

View File

@@ -17,6 +17,7 @@
 package org.apache.lucene.backward_codecs.lucene92;

 import java.util.Objects;
+import org.apache.lucene.backward_codecs.lucene90.Lucene90FieldInfosFormat;
 import org.apache.lucene.codecs.Codec;
 import org.apache.lucene.codecs.CompoundFormat;
 import org.apache.lucene.codecs.DocValuesFormat;
@@ -32,7 +33,6 @@ import org.apache.lucene.codecs.StoredFieldsFormat;
 import org.apache.lucene.codecs.TermVectorsFormat;
 import org.apache.lucene.codecs.lucene90.Lucene90CompoundFormat;
 import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
-import org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat;
 import org.apache.lucene.codecs.lucene90.Lucene90LiveDocsFormat;
 import org.apache.lucene.codecs.lucene90.Lucene90NormsFormat;
 import org.apache.lucene.codecs.lucene90.Lucene90PointsFormat;

View File

@@ -18,6 +18,7 @@
 package org.apache.lucene.backward_codecs.lucene92;

 import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+import static org.apache.lucene.search.TopDocsCollector.EMPTY_TOPDOCS;

 import java.io.IOException;
 import java.util.Arrays;
@@ -30,8 +31,10 @@ import org.apache.lucene.index.FieldInfo;
 import org.apache.lucene.index.FieldInfos;
 import org.apache.lucene.index.IndexFileNames;
 import org.apache.lucene.index.SegmentReadState;
+import org.apache.lucene.index.VectorEncoding;
 import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.index.VectorValues;
+import org.apache.lucene.search.DocIdSetIterator;
 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.search.TotalHits;
@@ -237,6 +240,7 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
             target,
             k,
             vectorValues,
+            VectorEncoding.FLOAT32,
             fieldEntry.similarityFunction,
             getGraph(fieldEntry),
             vectorValues.getAcceptOrds(acceptDocs),
@@ -258,6 +262,21 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
     return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
   }

+  @Override
+  public TopDocs searchExhaustively(
+      String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
+    FieldEntry fieldEntry = fields.get(field);
+    if (fieldEntry == null) {
+      // The field does not exist or does not index vectors
+      return EMPTY_TOPDOCS;
+    }
+    VectorSimilarityFunction similarityFunction = fieldEntry.similarityFunction;
+    VectorValues vectorValues = getVectorValues(field);
+    return exhaustiveSearch(vectorValues, acceptDocs, similarityFunction, target, k);
+  }
+
   /** Get knn graph values; used for testing */
   public HnswGraph getGraph(String field) throws IOException {
     FieldInfo info = fieldInfos.fieldInfo(field);

View File

@@ -144,8 +144,8 @@
  *   <li>{@link org.apache.lucene.codecs.lucene90.Lucene90SegmentInfoFormat Segment info}. This
  *       contains metadata about a segment, such as the number of documents, what files it uses, and
  *       information about how the segment is sorted
- *   <li>{@link org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat Field names}. This
- *       contains metadata about the set of named fields used in the index.
+ *   <li>{@link org.apache.lucene.backward_codecs.lucene90.Lucene90FieldInfosFormat Field names}.
+ *       This contains metadata about the set of named fields used in the index.
  *   <li>{@link org.apache.lucene.codecs.lucene90.Lucene90StoredFieldsFormat Stored Field values}.
  *       This contains, for each document, a list of attribute-value pairs, where the attributes are
  *       field names. These are used to store auxiliary information about the document, such as its
@@ -240,7 +240,7 @@
  * systems that frequently run out of file handles.</td>
  * </tr>
  * <tr>
- * <td>{@link org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat Fields}</td>
+ * <td>{@link org.apache.lucene.backward_codecs.lucene90.Lucene90FieldInfosFormat Fields}</td>
  * <td>.fnm</td>
  * <td>Stores information about the fields</td>
 * </tr>

View File

@@ -31,6 +31,7 @@ import org.apache.lucene.index.FieldInfo;
 import org.apache.lucene.index.IndexFileNames;
 import org.apache.lucene.index.RandomAccessVectorValuesProducer;
 import org.apache.lucene.index.SegmentWriteState;
+import org.apache.lucene.index.VectorEncoding;
 import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.index.VectorValues;
 import org.apache.lucene.search.DocIdSetIterator;
@@ -148,7 +149,8 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
       OnHeapHnswGraph graph =
           offHeapVectors.size() == 0
               ? null
-              : writeGraph(offHeapVectors, fieldInfo.getVectorSimilarityFunction());
+              : writeGraph(
+                  offHeapVectors, VectorEncoding.FLOAT32, fieldInfo.getVectorSimilarityFunction());
       long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
       writeMeta(
           fieldInfo,
@@ -266,13 +268,20 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
   }

   private OnHeapHnswGraph writeGraph(
-      RandomAccessVectorValuesProducer vectorValues, VectorSimilarityFunction similarityFunction)
+      RandomAccessVectorValuesProducer vectorValues,
+      VectorEncoding vectorEncoding,
+      VectorSimilarityFunction similarityFunction)
       throws IOException {
     // build graph
-    HnswGraphBuilder hnswGraphBuilder =
-        new HnswGraphBuilder(
-            vectorValues, similarityFunction, M, beamWidth, HnswGraphBuilder.randSeed);
+    HnswGraphBuilder<?> hnswGraphBuilder =
+        HnswGraphBuilder.create(
+            vectorValues,
+            vectorEncoding,
+            similarityFunction,
+            M,
+            beamWidth,
+            HnswGraphBuilder.randSeed);
     hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
     OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());

View File

@@ -28,6 +28,7 @@ import org.apache.lucene.index.FieldInfos;
 import org.apache.lucene.index.IndexFileNames;
 import org.apache.lucene.index.IndexOptions;
 import org.apache.lucene.index.SegmentInfo;
+import org.apache.lucene.index.VectorEncoding;
 import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.store.ChecksumIndexInput;
 import org.apache.lucene.store.Directory;
@@ -68,7 +69,8 @@ public class SimpleTextFieldInfosFormat extends FieldInfosFormat {
   static final BytesRef INDEX_DIM_COUNT = new BytesRef(" index dimensional count ");
   static final BytesRef DIM_NUM_BYTES = new BytesRef(" dimensional num bytes ");
   static final BytesRef VECTOR_NUM_DIMS = new BytesRef(" vector number of dimensions ");
-  static final BytesRef VECTOR_SEARCH_STRATEGY = new BytesRef(" vector search strategy ");
+  static final BytesRef VECTOR_ENCODING = new BytesRef(" vector encoding ");
+  static final BytesRef VECTOR_SIMILARITY = new BytesRef(" vector similarity ");
   static final BytesRef SOFT_DELETES = new BytesRef(" soft-deletes ");

   @Override
@@ -156,8 +158,13 @@ public class SimpleTextFieldInfosFormat extends FieldInfosFormat {
         int vectorNumDimensions = Integer.parseInt(readString(VECTOR_NUM_DIMS.length, scratch));

         SimpleTextUtil.readLine(input, scratch);
-        assert StringHelper.startsWith(scratch.get(), VECTOR_SEARCH_STRATEGY);
-        String scoreFunction = readString(VECTOR_SEARCH_STRATEGY.length, scratch);
+        assert StringHelper.startsWith(scratch.get(), VECTOR_ENCODING);
+        String encoding = readString(VECTOR_ENCODING.length, scratch);
+        VectorEncoding vectorEncoding = vectorEncoding(encoding);
+
+        SimpleTextUtil.readLine(input, scratch);
+        assert StringHelper.startsWith(scratch.get(), VECTOR_SIMILARITY);
+        String scoreFunction = readString(VECTOR_SIMILARITY.length, scratch);
         VectorSimilarityFunction vectorDistFunc = distanceFunction(scoreFunction);

         SimpleTextUtil.readLine(input, scratch);
@@ -179,6 +186,7 @@ public class SimpleTextFieldInfosFormat extends FieldInfosFormat {
                 indexDimensionalCount,
                 dimensionalNumBytes,
                 vectorNumDimensions,
+                vectorEncoding,
                 vectorDistFunc,
                 isSoftDeletesField);
       }
@@ -201,6 +209,10 @@ public class SimpleTextFieldInfosFormat extends FieldInfosFormat {
     return DocValuesType.valueOf(dvType);
   }

+  public VectorEncoding vectorEncoding(String vectorEncoding) {
+    return VectorEncoding.valueOf(vectorEncoding);
+  }
+
   public VectorSimilarityFunction distanceFunction(String scoreFunction) {
     return VectorSimilarityFunction.valueOf(scoreFunction);
   }
@@ -297,7 +309,11 @@ public class SimpleTextFieldInfosFormat extends FieldInfosFormat {
       SimpleTextUtil.write(out, Integer.toString(fi.getVectorDimension()), scratch);
       SimpleTextUtil.writeNewline(out);

-      SimpleTextUtil.write(out, VECTOR_SEARCH_STRATEGY);
+      SimpleTextUtil.write(out, VECTOR_ENCODING);
+      SimpleTextUtil.write(out, fi.getVectorEncoding().name(), scratch);
+      SimpleTextUtil.writeNewline(out);
+
+      SimpleTextUtil.write(out, VECTOR_SIMILARITY);
       SimpleTextUtil.write(out, fi.getVectorSimilarityFunction().name(), scratch);
       SimpleTextUtil.writeNewline(out);

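With the markers above, a vector field's entry in the SimpleText .fnm file now carries an encoding line ahead of the similarity line. Roughly, for a hypothetical 4-dimensional FLOAT32 field (values illustrative, not taken from the commit):

   vector number of dimensions 4
   vector encoding FLOAT32
   vector similarity EUCLIDEAN
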
View File

@@ -42,6 +42,7 @@ import org.apache.lucene.store.BufferedChecksumIndexInput;
 import org.apache.lucene.store.ChecksumIndexInput;
 import org.apache.lucene.store.IOContext;
 import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.util.BitSet;
 import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.BytesRefBuilder;
@@ -181,6 +182,13 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
     return new TopDocs(new TotalHits(numVisited, relation), topScoreDocs);
   }

+  @Override
+  public TopDocs searchExhaustively(
+      String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
+    int numDocs = (int) acceptDocs.cost();
+    return search(field, target, k, BitSet.of(acceptDocs, numDocs), Integer.MAX_VALUE);
+  }
+
   @Override
   public void checkIntegrity() throws IOException {
     IndexInput clone = dataIn.clone();

View File

@@ -23,6 +23,7 @@ import org.apache.lucene.codecs.lucene90.tests.MockTermStateFactory;
 import org.apache.lucene.index.DocValuesType;
 import org.apache.lucene.index.FieldInfo;
 import org.apache.lucene.index.IndexOptions;
+import org.apache.lucene.index.VectorEncoding;
 import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.store.ByteBuffersDataOutput;
 import org.apache.lucene.store.ByteBuffersIndexOutput;
@@ -116,6 +117,7 @@ public class TestBlockWriter extends LuceneTestCase {
         0,
         0,
         0,
+        VectorEncoding.FLOAT32,
         VectorSimilarityFunction.EUCLIDEAN,
         true);
   }

View File

@@ -41,6 +41,7 @@ import org.apache.lucene.index.ImpactsEnum;
 import org.apache.lucene.index.IndexOptions;
 import org.apache.lucene.index.PostingsEnum;
 import org.apache.lucene.index.SegmentReadState;
+import org.apache.lucene.index.VectorEncoding;
 import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.store.ByteBuffersDirectory;
 import org.apache.lucene.store.DataInput;
@@ -203,6 +204,7 @@ public class TestSTBlockReader extends LuceneTestCase {
         0,
         0,
         0,
+        VectorEncoding.FLOAT32,
         VectorSimilarityFunction.EUCLIDEAN,
         false);
   }

View File

@@ -20,8 +20,12 @@ package org.apache.lucene.codecs;
 import java.io.IOException;
 import org.apache.lucene.util.Accountable;

-/** Vectors' writer for a field */
-public abstract class KnnFieldVectorsWriter implements Accountable {
+/**
+ * Vectors' writer for a field
+ *
+ * @param <T> an array type; the type of vectors to be written
+ */
+public abstract class KnnFieldVectorsWriter<T> implements Accountable {

   /** Sole constructor */
   protected KnnFieldVectorsWriter() {}
@@ -30,5 +34,13 @@ public abstract class KnnFieldVectorsWriter implements Accountable {
    * Add new docID with its vector value to the given field for indexing. Doc IDs must be added in
    * increasing order.
    */
-  public abstract void addValue(int docID, float[] vectorValue) throws IOException;
+  public abstract void addValue(int docID, Object vectorValue) throws IOException;
+
+  /**
+   * Used to copy values being indexed to internal storage.
+   *
+   * @param vectorValue an array containing the vector value to add
+   * @return a copy of the value; a new array
+   */
+  public abstract T copyValue(T vectorValue);
 }

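To make the new generic contract concrete, here is a minimal hypothetical subclass with T = float[]. The class name and its naive heap buffering are invented for illustration; only the overridden signatures come from the API above.

import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.util.RamUsageEstimator;

// Illustrative only: buffers float[] vectors on the heap.
class BufferingFieldWriter extends KnnFieldVectorsWriter<float[]> {
  private final List<float[]> buffered = new ArrayList<>();

  @Override
  public void addValue(int docID, Object vectorValue) {
    // the untyped entry point casts to the concrete vector type and keeps a private copy
    buffered.add(copyValue((float[]) vectorValue));
  }

  @Override
  public float[] copyValue(float[] vectorValue) {
    return vectorValue.clone(); // defensive copy, per the copyValue contract
  }

  @Override
  public long ramBytesUsed() {
    long bytes = 0;
    for (float[] v : buffered) {
      bytes += RamUsageEstimator.sizeOf(v); // account for each buffered array
    }
    return bytes;
  }
}
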
View File

@@ -18,9 +18,11 @@
 package org.apache.lucene.codecs;

 import java.io.IOException;
+import org.apache.lucene.codecs.lucene94.Lucene94HnswVectorsFormat;
 import org.apache.lucene.index.SegmentReadState;
 import org.apache.lucene.index.SegmentWriteState;
 import org.apache.lucene.index.VectorValues;
+import org.apache.lucene.search.DocIdSetIterator;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.search.TopDocsCollector;
 import org.apache.lucene.util.Bits;
@@ -76,6 +78,15 @@ public abstract class KnnVectorsFormat implements NamedSPILoader.NamedSPI {
   /** Returns a {@link KnnVectorsReader} to read the vectors from the index. */
   public abstract KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException;

+  /**
+   * Returns the current KnnVectorsFormat version number. Indexes written using the format will be
+   * "stamped" with this version.
+   */
+  public int currentVersion() {
+    // return the version supported by older codecs that did not override this method
+    return Lucene94HnswVectorsFormat.VERSION_START;
+  }
+
   /**
    * EMPTY throws an exception when written. It acts as a sentinel indicating a Codec that does not
    * support vectors.
@@ -104,6 +115,12 @@ public abstract class KnnVectorsFormat implements NamedSPILoader.NamedSPI {
           return TopDocsCollector.EMPTY_TOPDOCS;
         }

+        @Override
+        public TopDocs searchExhaustively(
+            String field, float[] target, int k, DocIdSetIterator acceptDocs) {
+          return TopDocsCollector.EMPTY_TOPDOCS;
+        }
+
         @Override
         public void close() {}

View File

@@ -20,12 +20,16 @@ package org.apache.lucene.codecs;
 import java.io.Closeable;
 import java.io.IOException;
 import org.apache.lucene.index.FieldInfo;
+import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.index.VectorValues;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.HitQueue;
 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.search.TotalHits;
 import org.apache.lucene.util.Accountable;
 import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BytesRef;

 /** Reads vectors from an index. */
 public abstract class KnnVectorsReader implements Closeable, Accountable {
@@ -75,11 +79,39 @@ public abstract class KnnVectorsReader implements Closeable, Accountable {
    * @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null}
    *     if they are all allowed to match.
    * @param visitedLimit the maximum number of nodes that the search is allowed to visit
-   * @return the k nearest neighbor documents, along with their (searchStrategy-specific) scores.
+   * @return the k nearest neighbor documents, along with their (similarity-specific) scores.
    */
   public abstract TopDocs search(
       String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException;

+  /**
+   * Return the k nearest neighbor documents as determined by comparison of their vector values for
+   * this field, to the given vector, by the field's similarity function. The score of each document
+   * is derived from the vector similarity in a way that ensures scores are positive and that a
+   * larger score corresponds to a higher ranking.
+   *
+   * <p>The search is exact, guaranteeing the true k closest neighbors will be returned. Typically
+   * this requires an exhaustive scan of the entire index. It is intended to be used when the number
+   * of potential matches is limited.
+   *
+   * <p>The returned {@link TopDocs} will contain a {@link ScoreDoc} for each nearest neighbor, in
+   * order of their similarity to the query vector (decreasing scores). The {@link TotalHits}
+   * contains the number of documents visited during the search.
+   *
+   * <p>The behavior is undefined if the given field doesn't have KNN vectors enabled on its {@link
+   * FieldInfo}. The return value is never {@code null}.
+   *
+   * @param field the vector field to search
+   * @param target the vector-valued query
+   * @param k the number of docs to return
+   * @param acceptDocs {@link DocIdSetIterator} that represents the allowed documents to match.
+   * @return the k nearest neighbor documents, along with their (similarity-specific) scores.
+   */
+  public abstract TopDocs searchExhaustively(
+      String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException;
+
   /**
    * Returns an instance optimized for merging. This instance may only be consumed in the thread
    * that called {@link #getMergeInstance()}.
@@ -89,4 +121,67 @@ public abstract class KnnVectorsReader implements Closeable, Accountable {
   public KnnVectorsReader getMergeInstance() {
     return this;
   }

+  /** {@link #searchExhaustively} */
+  protected static TopDocs exhaustiveSearch(
+      VectorValues vectorValues,
+      DocIdSetIterator acceptDocs,
+      VectorSimilarityFunction similarityFunction,
+      float[] target,
+      int k)
+      throws IOException {
+    HitQueue queue = new HitQueue(k, true);
+    ScoreDoc topDoc = queue.top();
+    int doc;
+    while ((doc = acceptDocs.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
+      int vectorDoc = vectorValues.advance(doc);
+      assert vectorDoc == doc;
+      float score = similarityFunction.compare(vectorValues.vectorValue(), target);
+      if (score >= topDoc.score) {
+        topDoc.score = score;
+        topDoc.doc = doc;
+        topDoc = queue.updateTop();
+      }
+    }
+    return topDocsFromHitQueue(queue, acceptDocs.cost());
+  }
+
+  /** {@link #searchExhaustively} */
+  protected static TopDocs exhaustiveSearch(
+      VectorValues vectorValues,
+      DocIdSetIterator acceptDocs,
+      VectorSimilarityFunction similarityFunction,
+      BytesRef target,
+      int k)
+      throws IOException {
+    HitQueue queue = new HitQueue(k, true);
+    ScoreDoc topDoc = queue.top();
+    int doc;
+    while ((doc = acceptDocs.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
+      int vectorDoc = vectorValues.advance(doc);
+      assert vectorDoc == doc;
+      float score = similarityFunction.compare(vectorValues.binaryValue(), target);
+      if (score >= topDoc.score) {
+        topDoc.score = score;
+        topDoc.doc = doc;
+        topDoc = queue.updateTop();
+      }
+    }
+    return topDocsFromHitQueue(queue, acceptDocs.cost());
+  }
+
+  private static TopDocs topDocsFromHitQueue(HitQueue queue, long numHits) {
+    // Remove any remaining sentinel values
+    while (queue.size() > 0 && queue.top().score < 0) {
+      queue.pop();
+    }
+    ScoreDoc[] topScoreDocs = new ScoreDoc[queue.size()];
+    for (int i = topScoreDocs.length - 1; i >= 0; i--) {
+      topScoreDocs[i] = queue.pop();
+    }
+    TotalHits totalHits = new TotalHits(numHits, TotalHits.Relation.EQUAL_TO);
+    return new TopDocs(totalHits, topScoreDocs);
+  }
 }

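The exhaustiveSearch helpers above lean on HitQueue's pre-populated mode: the queue starts full of k sentinel entries scoring below any real similarity score, so the weakest entry can be overwritten in place and re-heapified with updateTop() rather than paying for an add/poll pair on every hit. A self-contained sketch of that pattern, assuming non-negative scores as Lucene similarity scores are (the scores array is made-up input, not part of the commit):

import org.apache.lucene.search.HitQueue;
import org.apache.lucene.search.ScoreDoc;

class SentinelTopK {
  // Returns the k highest-scoring docs in decreasing score order.
  static ScoreDoc[] topK(float[] scores, int k) {
    HitQueue queue = new HitQueue(k, true); // true = pre-populate with sentinel entries
    ScoreDoc top = queue.top();
    for (int doc = 0; doc < scores.length; doc++) {
      if (scores[doc] >= top.score) {
        top.doc = doc;
        top.score = scores[doc];
        top = queue.updateTop(); // restore heap order after mutating the top
      }
    }
    // strip leftover sentinels, as topDocsFromHitQueue does above
    while (queue.size() > 0 && queue.top().score < 0) {
      queue.pop();
    }
    ScoreDoc[] result = new ScoreDoc[queue.size()];
    for (int i = result.length - 1; i >= 0; i--) {
      result[i] = queue.pop(); // drain ascending, fill back-to-front
    }
    return result;
  }
}
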
View File

@@ -37,14 +37,15 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
   protected KnnVectorsWriter() {}

   /** Add new field for indexing */
-  public abstract KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException;
+  public abstract KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException;

   /** Flush all buffered data on disk * */
   public abstract void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException;

   /** Write field for merging */
-  public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
-    KnnFieldVectorsWriter writer = addField(fieldInfo);
+  @SuppressWarnings("unchecked")
+  public <T> void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
+    KnnFieldVectorsWriter<T> writer = (KnnFieldVectorsWriter<T>) addField(fieldInfo);
     VectorValues mergedValues = MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
     for (int doc = mergedValues.nextDoc();
         doc != DocIdSetIterator.NO_MORE_DOCS;

View File

@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.codecs.lucene94;
+
+import java.io.IOException;
+import org.apache.lucene.index.FilterVectorValues;
+import org.apache.lucene.index.VectorValues;
+import org.apache.lucene.util.BytesRef;
+
+/** reads from byte-encoded data */
+public class ExpandingVectorValues extends FilterVectorValues {
+
+  private final float[] value;
+
+  /** @param in the wrapped values */
+  protected ExpandingVectorValues(VectorValues in) {
+    super(in);
+    value = new float[in.dimension()];
+  }
+
+  @Override
+  public float[] vectorValue() throws IOException {
+    BytesRef binaryValue = binaryValue();
+    byte[] bytes = binaryValue.bytes;
+    for (int i = 0, j = binaryValue.offset; i < value.length; i++, j++) {
+      value[i] = bytes[j];
+    }
+    return value;
+  }
+}

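ExpandingVectorValues relies on Java's implicit widening of signed bytes to floats; no scaling is applied, so a stored byte of -128 simply comes back as -128.0f. A toy round trip of the loop above (the input bytes are arbitrary):

import org.apache.lucene.util.BytesRef;

class WidenDemo {
  public static void main(String[] args) {
    BytesRef binaryValue = new BytesRef(new byte[] {12, -4, 127, -128});
    float[] value = new float[binaryValue.length];
    for (int i = 0, j = binaryValue.offset; i < value.length; i++, j++) {
      value[i] = binaryValue.bytes[j]; // signed widening: -128..127 -> -128.0f..127.0f
    }
    System.out.println(java.util.Arrays.toString(value)); // [12.0, -4.0, 127.0, -128.0]
  }
}
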
View File

@@ -32,7 +32,6 @@ import org.apache.lucene.codecs.StoredFieldsFormat;
 import org.apache.lucene.codecs.TermVectorsFormat;
 import org.apache.lucene.codecs.lucene90.Lucene90CompoundFormat;
 import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
-import org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat;
 import org.apache.lucene.codecs.lucene90.Lucene90LiveDocsFormat;
 import org.apache.lucene.codecs.lucene90.Lucene90NormsFormat;
 import org.apache.lucene.codecs.lucene90.Lucene90PointsFormat;
@@ -69,7 +68,7 @@ public class Lucene94Codec extends Codec {
   }

   private final TermVectorsFormat vectorsFormat = new Lucene90TermVectorsFormat();
-  private final FieldInfosFormat fieldInfosFormat = new Lucene90FieldInfosFormat();
+  private final FieldInfosFormat fieldInfosFormat = new Lucene94FieldInfosFormat();
   private final SegmentInfoFormat segmentInfosFormat = new Lucene90SegmentInfoFormat();
   private final LiveDocsFormat liveDocsFormat = new Lucene90LiveDocsFormat();
   private final CompoundFormat compoundFormat = new Lucene90CompoundFormat();
@@ -100,6 +99,11 @@ public class Lucene94Codec extends Codec {
         public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
           return Lucene94Codec.this.getKnnVectorsFormatForField(field);
         }
+
+        @Override
+        public int currentVersion() {
+          return Lucene94HnswVectorsFormat.VERSION_CURRENT;
+        }
       };

   private final StoredFieldsFormat storedFieldsFormat;

View File

@ -0,0 +1,385 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene94;
import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.DocValuesFormat;
import org.apache.lucene.codecs.FieldInfosFormat;
import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.SegmentInfo;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataOutput;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
/**
* Lucene 9.0 Field Infos format.
*
* <p>Field names are stored in the field info file, with suffix <code>.fnm</code>.
*
* <p>FieldInfos (.fnm) --&gt; Header,FieldsCount, &lt;FieldName,FieldNumber,
* FieldBits,DocValuesBits,DocValuesGen,Attributes,DimensionCount,DimensionNumBytes&gt;
* <sup>FieldsCount</sup>,Footer
*
* <p>Data types:
*
* <ul>
* <li>Header --&gt; {@link CodecUtil#checkIndexHeader IndexHeader}
* <li>FieldsCount --&gt; {@link DataOutput#writeVInt VInt}
* <li>FieldName --&gt; {@link DataOutput#writeString String}
* <li>FieldBits, IndexOptions, DocValuesBits --&gt; {@link DataOutput#writeByte Byte}
* <li>FieldNumber, DimensionCount, DimensionNumBytes --&gt; {@link DataOutput#writeInt VInt}
* <li>Attributes --&gt; {@link DataOutput#writeMapOfStrings Map&lt;String,String&gt;}
* <li>DocValuesGen --&gt; {@link DataOutput#writeLong(long) Int64}
* <li>Footer --&gt; {@link CodecUtil#writeFooter CodecFooter}
* </ul>
*
* Field Descriptions:
*
* <ul>
* <li>FieldsCount: the number of fields in this file.
* <li>FieldName: name of the field as a UTF-8 String.
* <li>FieldNumber: the field's number. Note that unlike previous versions of Lucene, the fields
* are not numbered implicitly by their order in the file, instead explicitly.
* <li>FieldBits: a byte containing field options.
* <ul>
* <li>The low order bit (0x1) is one for fields that have term vectors stored, and zero for
* fields without term vectors.
* <li>If the second lowest order-bit is set (0x2), norms are omitted for the indexed field.
* <li>If the third lowest-order bit is set (0x4), payloads are stored for the indexed
* field.
* </ul>
* <li>IndexOptions: a byte containing index options.
* <ul>
* <li>0: not indexed
* <li>1: indexed as DOCS_ONLY
* <li>2: indexed as DOCS_AND_FREQS
* <li>3: indexed as DOCS_AND_FREQS_AND_POSITIONS
* <li>4: indexed as DOCS_AND_FREQS_AND_POSITIONS_AND_OFFSETS
* </ul>
* <li>DocValuesBits: a byte containing per-document value types. The type recorded as two
* four-bit integers, with the high-order bits representing <code>norms</code> options, and
* the low-order bits representing {@code DocValues} options. Each four-bit integer can be
* decoded as such:
* <ul>
* <li>0: no DocValues for this field.
* <li>1: NumericDocValues. ({@link DocValuesType#NUMERIC})
* <li>2: BinaryDocValues. ({@code DocValuesType#BINARY})
* <li>3: SortedDocValues. ({@code DocValuesType#SORTED})
* </ul>
* <li>DocValuesGen is the generation count of the field's DocValues. If this is -1, there are no
* DocValues updates to that field. Anything above zero means there are updates stored by
* {@link DocValuesFormat}.
* <li>Attributes: a key-value map of codec-private attributes.
* <li>PointDimensionCount, PointNumBytes: these are non-zero only if the field is indexed as
* points, e.g. using {@link org.apache.lucene.document.LongPoint}
* <li>VectorDimension: it is non-zero if the field is indexed as vectors.
* <li>VectorEncoding: a byte containing the encoding of vector values:
* <ul>
* <li>0: BYTE. Samples are stored as signed bytes
* <li>1: FLOAT32. Samples are stored in IEEE 32-bit floating point format.
* </ul>
* <li>VectorSimilarityFunction: a byte containing distance function used for similarity
* calculation.
* <ul>
* <li>0: EUCLIDEAN distance. ({@link VectorSimilarityFunction#EUCLIDEAN})
* <li>1: DOT_PRODUCT similarity. ({@link VectorSimilarityFunction#DOT_PRODUCT})
* <li>2: COSINE similarity. ({@link VectorSimilarityFunction#COSINE})
* </ul>
* </ul>
*
* @lucene.experimental
*/
public final class Lucene94FieldInfosFormat extends FieldInfosFormat {
/** Sole constructor. */
public Lucene94FieldInfosFormat() {}
@Override
public FieldInfos read(
Directory directory, SegmentInfo segmentInfo, String segmentSuffix, IOContext context)
throws IOException {
final String fileName =
IndexFileNames.segmentFileName(segmentInfo.name, segmentSuffix, EXTENSION);
try (ChecksumIndexInput input = directory.openChecksumInput(fileName, context)) {
Throwable priorE = null;
FieldInfo[] infos = null;
try {
CodecUtil.checkIndexHeader(
input,
Lucene94FieldInfosFormat.CODEC_NAME,
Lucene94FieldInfosFormat.FORMAT_START,
Lucene94FieldInfosFormat.FORMAT_CURRENT,
segmentInfo.getId(),
segmentSuffix);
final int size = input.readVInt(); // read in the size
infos = new FieldInfo[size];
// previous field's attribute map, we share when possible:
Map<String, String> lastAttributes = Collections.emptyMap();
for (int i = 0; i < size; i++) {
String name = input.readString();
final int fieldNumber = input.readVInt();
if (fieldNumber < 0) {
throw new CorruptIndexException(
"invalid field number for field: " + name + ", fieldNumber=" + fieldNumber, input);
}
byte bits = input.readByte();
boolean storeTermVector = (bits & STORE_TERMVECTOR) != 0;
boolean omitNorms = (bits & OMIT_NORMS) != 0;
boolean storePayloads = (bits & STORE_PAYLOADS) != 0;
boolean isSoftDeletesField = (bits & SOFT_DELETES_FIELD) != 0;
final IndexOptions indexOptions = getIndexOptions(input, input.readByte());
// DV Types are packed in one byte
final DocValuesType docValuesType = getDocValuesType(input, input.readByte());
final long dvGen = input.readLong();
Map<String, String> attributes = input.readMapOfStrings();
// just use the last field's map if its the same
if (attributes.equals(lastAttributes)) {
attributes = lastAttributes;
}
lastAttributes = attributes;
int pointDataDimensionCount = input.readVInt();
int pointNumBytes;
int pointIndexDimensionCount = pointDataDimensionCount;
if (pointDataDimensionCount != 0) {
pointIndexDimensionCount = input.readVInt();
pointNumBytes = input.readVInt();
} else {
pointNumBytes = 0;
}
final int vectorDimension = input.readVInt();
final VectorEncoding vectorEncoding = getVectorEncoding(input, input.readByte());
final VectorSimilarityFunction vectorDistFunc = getDistFunc(input, input.readByte());
try {
infos[i] =
new FieldInfo(
name,
fieldNumber,
storeTermVector,
omitNorms,
storePayloads,
indexOptions,
docValuesType,
dvGen,
attributes,
pointDataDimensionCount,
pointIndexDimensionCount,
pointNumBytes,
vectorDimension,
vectorEncoding,
vectorDistFunc,
isSoftDeletesField);
infos[i].checkConsistency();
} catch (IllegalStateException e) {
throw new CorruptIndexException(
"invalid fieldinfo for field: " + name + ", fieldNumber=" + fieldNumber, input, e);
}
}
} catch (Throwable exception) {
priorE = exception;
} finally {
CodecUtil.checkFooter(input, priorE);
}
return new FieldInfos(infos);
}
}
static {
// We "mirror" DocValues enum values with the constants below; let's try to ensure if we add a
// new DocValuesType while this format is
// still used for writing, we remember to fix this encoding:
assert DocValuesType.values().length == 6;
}
private static byte docValuesByte(DocValuesType type) {
switch (type) {
case NONE:
return 0;
case NUMERIC:
return 1;
case BINARY:
return 2;
case SORTED:
return 3;
case SORTED_SET:
return 4;
case SORTED_NUMERIC:
return 5;
default:
// BUG
throw new AssertionError("unhandled DocValuesType: " + type);
}
}
private static DocValuesType getDocValuesType(IndexInput input, byte b) throws IOException {
switch (b) {
case 0:
return DocValuesType.NONE;
case 1:
return DocValuesType.NUMERIC;
case 2:
return DocValuesType.BINARY;
case 3:
return DocValuesType.SORTED;
case 4:
return DocValuesType.SORTED_SET;
case 5:
return DocValuesType.SORTED_NUMERIC;
default:
throw new CorruptIndexException("invalid docvalues byte: " + b, input);
}
}
private static VectorEncoding getVectorEncoding(IndexInput input, byte b) throws IOException {
if (b < 0 || b >= VectorEncoding.values().length) {
throw new CorruptIndexException("invalid vector encoding: " + b, input);
}
return VectorEncoding.values()[b];
}
private static VectorSimilarityFunction getDistFunc(IndexInput input, byte b) throws IOException {
if (b < 0 || b >= VectorSimilarityFunction.values().length) {
throw new CorruptIndexException("invalid distance function: " + b, input);
}
return VectorSimilarityFunction.values()[b];
}
static {
// We "mirror" IndexOptions enum values with the constants below; let's try to ensure if we add
// a new IndexOption while this format is
// still used for writing, we remember to fix this encoding:
assert IndexOptions.values().length == 5;
}
private static byte indexOptionsByte(IndexOptions indexOptions) {
switch (indexOptions) {
case NONE:
return 0;
case DOCS:
return 1;
case DOCS_AND_FREQS:
return 2;
case DOCS_AND_FREQS_AND_POSITIONS:
return 3;
case DOCS_AND_FREQS_AND_POSITIONS_AND_OFFSETS:
return 4;
default:
// BUG:
throw new AssertionError("unhandled IndexOptions: " + indexOptions);
}
}
private static IndexOptions getIndexOptions(IndexInput input, byte b) throws IOException {
switch (b) {
case 0:
return IndexOptions.NONE;
case 1:
return IndexOptions.DOCS;
case 2:
return IndexOptions.DOCS_AND_FREQS;
case 3:
return IndexOptions.DOCS_AND_FREQS_AND_POSITIONS;
case 4:
return IndexOptions.DOCS_AND_FREQS_AND_POSITIONS_AND_OFFSETS;
default:
// BUG
throw new CorruptIndexException("invalid IndexOptions byte: " + b, input);
}
}
@Override
public void write(
Directory directory,
SegmentInfo segmentInfo,
String segmentSuffix,
FieldInfos infos,
IOContext context)
throws IOException {
final String fileName =
IndexFileNames.segmentFileName(segmentInfo.name, segmentSuffix, EXTENSION);
try (IndexOutput output = directory.createOutput(fileName, context)) {
CodecUtil.writeIndexHeader(
output,
Lucene94FieldInfosFormat.CODEC_NAME,
Lucene94FieldInfosFormat.FORMAT_CURRENT,
segmentInfo.getId(),
segmentSuffix);
output.writeVInt(infos.size());
for (FieldInfo fi : infos) {
fi.checkConsistency();
output.writeString(fi.name);
output.writeVInt(fi.number);
byte bits = 0x0;
if (fi.hasVectors()) bits |= STORE_TERMVECTOR;
if (fi.omitsNorms()) bits |= OMIT_NORMS;
if (fi.hasPayloads()) bits |= STORE_PAYLOADS;
if (fi.isSoftDeletesField()) bits |= SOFT_DELETES_FIELD;
output.writeByte(bits);
output.writeByte(indexOptionsByte(fi.getIndexOptions()));
// pack the DV type and hasNorms in one byte
output.writeByte(docValuesByte(fi.getDocValuesType()));
output.writeLong(fi.getDocValuesGen());
output.writeMapOfStrings(fi.attributes());
output.writeVInt(fi.getPointDimensionCount());
if (fi.getPointDimensionCount() != 0) {
output.writeVInt(fi.getPointIndexDimensionCount());
output.writeVInt(fi.getPointNumBytes());
}
output.writeVInt(fi.getVectorDimension());
output.writeByte((byte) fi.getVectorEncoding().ordinal());
output.writeByte((byte) fi.getVectorSimilarityFunction().ordinal());
}
CodecUtil.writeFooter(output);
}
}
/** Extension of field infos */
static final String EXTENSION = "fnm";
// Codec header
static final String CODEC_NAME = "Lucene94FieldInfos";
static final int FORMAT_START = 0;
static final int FORMAT_CURRENT = FORMAT_START;
// Field flags
static final byte STORE_TERMVECTOR = 0x1;
static final byte OMIT_NORMS = 0x2;
static final byte STORE_PAYLOADS = 0x4;
static final byte SOFT_DELETES_FIELD = 0x8;
}
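The field infos format persists the new encoding as a single ordinal byte, bounds-checked on read exactly like the similarity-function byte. A minimal sketch of that round trip, with hypothetical helper names (encodeVectorEncoding/decodeVectorEncoding are not part of the format):

import org.apache.lucene.index.VectorEncoding;

class VectorEncodingByte {
  // Persist the encoding as its enum ordinal, in declaration order.
  static byte encodeVectorEncoding(VectorEncoding encoding) {
    return (byte) encoding.ordinal();
  }

  // Mirror the read-side bounds check: reject bytes written by an unknown, newer encoding.
  static VectorEncoding decodeVectorEncoding(byte b) {
    if (b < 0 || b >= VectorEncoding.values().length) {
      throw new IllegalArgumentException("invalid vector encoding byte: " + b);
    }
    return VectorEncoding.values()[b];
  }
}

Because only the ordinal is stored, new encodings can be appended to the end of the enum without disturbing the meaning of bytes in existing indexes.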

View File

@ -38,8 +38,9 @@ import org.apache.lucene.util.hnsw.HnswGraph;
* <p>For each field: * <p>For each field:
* *
* <ul> * <ul>
* <li>Floating-point vector data ordered by field, document ordinal, and vector dimension. The * <li>Vector data ordered by field, document ordinal, and vector dimension. When the
* floats are stored in little-endian byte order * vectorEncoding is BYTE, each sample is stored as a single byte. When it is FLOAT32, each
* sample is stored as an IEEE float in little-endian byte order.
* <li>DocIds encoded by {@link IndexedDISI#writeBitSet(DocIdSetIterator, IndexOutput, byte)}, * <li>DocIds encoded by {@link IndexedDISI#writeBitSet(DocIdSetIterator, IndexOutput, byte)},
* note that only in sparse case * note that only in sparse case
* <li>OrdToDoc was encoded by {@link org.apache.lucene.util.packed.DirectMonotonicWriter}, note * <li>OrdToDoc was encoded by {@link org.apache.lucene.util.packed.DirectMonotonicWriter}, note
@ -89,7 +90,7 @@ import org.apache.lucene.util.hnsw.HnswGraph;
* <ul> * <ul>
* <li><b>[int]</b> the number of nodes on this level * <li><b>[int]</b> the number of nodes on this level
* <li><b>array[int]</b> for levels greater than 0 list of nodes on this level, stored as * <li><b>array[int]</b> for levels greater than 0 list of nodes on this level, stored as
* the the level 0th nodes ordinals. * the level 0th nodes' ordinals.
* </ul> * </ul>
* </ul> * </ul>
* *
@ -104,8 +105,8 @@ public final class Lucene94HnswVectorsFormat extends KnnVectorsFormat {
static final String VECTOR_DATA_EXTENSION = "vec"; static final String VECTOR_DATA_EXTENSION = "vec";
static final String VECTOR_INDEX_EXTENSION = "vex"; static final String VECTOR_INDEX_EXTENSION = "vex";
static final int VERSION_START = 0; public static final int VERSION_START = 0;
static final int VERSION_CURRENT = VERSION_START; public static final int VERSION_CURRENT = 1;
/** Default number of maximum connections per node */ /** Default number of maximum connections per node */
public static final int DEFAULT_MAX_CONN = 16; public static final int DEFAULT_MAX_CONN = 16;
@ -156,6 +157,11 @@ public final class Lucene94HnswVectorsFormat extends KnnVectorsFormat {
return new Lucene94HnswVectorsReader(state); return new Lucene94HnswVectorsReader(state);
} }
@Override
public int currentVersion() {
return VERSION_CURRENT;
}
@Override @Override
public String toString() { public String toString() {
return "Lucene94HnswVectorsFormat(name=Lucene94HnswVectorsFormat, maxConn=" return "Lucene94HnswVectorsFormat(name=Lucene94HnswVectorsFormat, maxConn="

View File

@ -18,6 +18,8 @@
package org.apache.lucene.codecs.lucene94; package org.apache.lucene.codecs.lucene94;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.search.TopDocsCollector.EMPTY_TOPDOCS;
import static org.apache.lucene.util.VectorUtil.toBytesRef;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
@ -30,8 +32,10 @@ import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos; import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues; import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.TotalHits;
@ -169,16 +173,23 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
+ fieldEntry.dimension); + fieldEntry.dimension);
} }
long numBytes = (long) fieldEntry.size() * dimension * Float.BYTES; int byteSize =
switch (info.getVectorEncoding()) {
case BYTE -> Byte.BYTES;
case FLOAT32 -> Float.BYTES;
};
int numBytes = fieldEntry.size * dimension * byteSize;
if (numBytes != fieldEntry.vectorDataLength) { if (numBytes != fieldEntry.vectorDataLength) {
throw new IllegalStateException( throw new IllegalStateException(
"Vector data length " "Vector data length "
+ fieldEntry.vectorDataLength + fieldEntry.vectorDataLength
+ " not matching size=" + " not matching size="
+ fieldEntry.size() + fieldEntry.size
+ " * dim=" + " * dim="
+ dimension + dimension
+ " * 4 = " + " * byteSize="
+ byteSize
+ " = "
+ numBytes); + numBytes);
} }
} }
@ -193,9 +204,18 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
return VectorSimilarityFunction.values()[similarityFunctionId]; return VectorSimilarityFunction.values()[similarityFunctionId];
} }
private VectorEncoding readVectorEncoding(DataInput input) throws IOException {
int encodingId = input.readInt();
if (encodingId < 0 || encodingId >= VectorEncoding.values().length) {
throw new CorruptIndexException("Invalid vector encoding id: " + encodingId, input);
}
return VectorEncoding.values()[encodingId];
}
private FieldEntry readField(IndexInput input) throws IOException { private FieldEntry readField(IndexInput input) throws IOException {
VectorEncoding vectorEncoding = readVectorEncoding(input);
VectorSimilarityFunction similarityFunction = readSimilarityFunction(input); VectorSimilarityFunction similarityFunction = readSimilarityFunction(input);
return new FieldEntry(input, similarityFunction); return new FieldEntry(input, vectorEncoding, similarityFunction);
} }
@Override @Override
@ -216,7 +236,12 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
@Override @Override
public VectorValues getVectorValues(String field) throws IOException { public VectorValues getVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field); FieldEntry fieldEntry = fields.get(field);
return OffHeapVectorValues.load(fieldEntry, vectorData); VectorValues values = OffHeapVectorValues.load(fieldEntry, vectorData);
if (fieldEntry.vectorEncoding == VectorEncoding.BYTE) {
return new ExpandingVectorValues(values);
} else {
return values;
}
} }
@Override @Override
@ -237,6 +262,7 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
target, target,
k, k,
vectorValues, vectorValues,
fieldEntry.vectorEncoding,
fieldEntry.similarityFunction, fieldEntry.similarityFunction,
getGraph(fieldEntry), getGraph(fieldEntry),
vectorValues.getAcceptOrds(acceptDocs), vectorValues.getAcceptOrds(acceptDocs),
@ -258,6 +284,25 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs); return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
} }
@Override
public TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null) {
// The field does not exist or does not index vectors
return EMPTY_TOPDOCS;
}
VectorSimilarityFunction similarityFunction = fieldEntry.similarityFunction;
VectorValues vectorValues = getVectorValues(field);
return switch (fieldEntry.vectorEncoding) {
case BYTE -> exhaustiveSearch(
vectorValues, acceptDocs, similarityFunction, toBytesRef(target), k);
case FLOAT32 -> exhaustiveSearch(vectorValues, acceptDocs, similarityFunction, target, k);
};
}
/** Get knn graph values; used for testing */ /** Get knn graph values; used for testing */
public HnswGraph getGraph(String field) throws IOException { public HnswGraph getGraph(String field) throws IOException {
FieldInfo info = fieldInfos.fieldInfo(field); FieldInfo info = fieldInfos.fieldInfo(field);
@ -286,6 +331,7 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
static class FieldEntry { static class FieldEntry {
final VectorSimilarityFunction similarityFunction; final VectorSimilarityFunction similarityFunction;
final VectorEncoding vectorEncoding;
final long vectorDataOffset; final long vectorDataOffset;
final long vectorDataLength; final long vectorDataLength;
final long vectorIndexOffset; final long vectorIndexOffset;
@ -315,8 +361,13 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
final DirectMonotonicReader.Meta meta; final DirectMonotonicReader.Meta meta;
final long addressesLength; final long addressesLength;
FieldEntry(IndexInput input, VectorSimilarityFunction similarityFunction) throws IOException { FieldEntry(
IndexInput input,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction)
throws IOException {
this.similarityFunction = similarityFunction; this.similarityFunction = similarityFunction;
this.vectorEncoding = vectorEncoding;
vectorDataOffset = input.readVLong(); vectorDataOffset = input.readVLong();
vectorDataLength = input.readVLong(); vectorDataLength = input.readVLong();
vectorIndexOffset = input.readVLong(); vectorIndexOffset = input.readVLong();
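For BYTE-encoded fields, getVectorValues wraps the raw values in ExpandingVectorValues so callers still observe float[] values. A hedged sketch of the widening step (the method here is an illustrative stand-in, not the actual class's API):

import org.apache.lucene.util.BytesRef;

class ByteToFloatExpansion {
  // Widen each signed byte sample to a float so BYTE-encoded vectors can be
  // exposed through the existing float[]-based VectorValues API.
  static float[] expand(BytesRef vector, int dimension) {
    float[] floats = new float[dimension];
    for (int i = 0; i < dimension; i++) {
      floats[i] = vector.bytes[vector.offset + i];
    }
    return floats;
  }
}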

View File

@ -65,7 +65,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
private final int M; private final int M;
private final int beamWidth; private final int beamWidth;
private final List<FieldWriter> fields = new ArrayList<>(); private final List<FieldWriter<?>> fields = new ArrayList<>();
private boolean finished; private boolean finished;
Lucene94HnswVectorsWriter(SegmentWriteState state, int M, int beamWidth) throws IOException { Lucene94HnswVectorsWriter(SegmentWriteState state, int M, int beamWidth) throws IOException {
@ -121,15 +121,16 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
} }
@Override @Override
public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
FieldWriter newField = new FieldWriter(fieldInfo, M, beamWidth, segmentWriteState.infoStream); FieldWriter<?> newField =
FieldWriter.create(fieldInfo, M, beamWidth, segmentWriteState.infoStream);
fields.add(newField); fields.add(newField);
return newField; return newField;
} }
@Override @Override
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
for (FieldWriter field : fields) { for (FieldWriter<?> field : fields) {
if (sortMap == null) { if (sortMap == null) {
writeField(field, maxDoc); writeField(field, maxDoc);
} else { } else {
@ -159,22 +160,20 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
@Override @Override
public long ramBytesUsed() { public long ramBytesUsed() {
long total = 0; long total = 0;
for (FieldWriter field : fields) { for (FieldWriter<?> field : fields) {
total += field.ramBytesUsed(); total += field.ramBytesUsed();
} }
return total; return total;
} }
private void writeField(FieldWriter fieldData, int maxDoc) throws IOException { private void writeField(FieldWriter<?> fieldData, int maxDoc) throws IOException {
// write vector values // write vector values
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
final ByteBuffer buffer = switch (fieldData.fieldInfo.getVectorEncoding()) {
ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); case BYTE -> writeByteVectors(fieldData);
final BytesRef binaryValue = new BytesRef(buffer.array()); case FLOAT32 -> writeFloat32Vectors(fieldData);
for (float[] vector : fieldData.vectors) {
buffer.asFloatBuffer().put(vector);
vectorData.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
} }
;
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
// write graph // write graph
@ -194,7 +193,24 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
graph); graph);
} }
private void writeSortingField(FieldWriter fieldData, int maxDoc, Sorter.DocMap sortMap) private void writeFloat32Vectors(FieldWriter<?> fieldData) throws IOException {
final ByteBuffer buffer =
ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
final BytesRef binaryValue = new BytesRef(buffer.array());
for (Object v : fieldData.vectors) {
buffer.asFloatBuffer().put((float[]) v);
vectorData.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
}
}
private void writeByteVectors(FieldWriter<?> fieldData) throws IOException {
for (Object v : fieldData.vectors) {
BytesRef vector = (BytesRef) v;
vectorData.writeBytes(vector.bytes, vector.offset, vector.length);
}
}
private void writeSortingField(FieldWriter<?> fieldData, int maxDoc, Sorter.DocMap sortMap)
throws IOException { throws IOException {
final int[] docIdOffsets = new int[sortMap.size()]; final int[] docIdOffsets = new int[sortMap.size()];
int offset = 1; // 0 means no vector for this (field, document) int offset = 1; // 0 means no vector for this (field, document)
@ -221,15 +237,11 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
} }
// write vector values // write vector values
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); long vectorDataOffset =
final ByteBuffer buffer = switch (fieldData.fieldInfo.getVectorEncoding()) {
ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); case BYTE -> writeSortedByteVectors(fieldData, ordMap);
final BytesRef binaryValue = new BytesRef(buffer.array()); case FLOAT32 -> writeSortedFloat32Vectors(fieldData, ordMap);
for (int ordinal : ordMap) { };
float[] vector = fieldData.vectors.get(ordinal);
buffer.asFloatBuffer().put(vector);
vectorData.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
}
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
// write graph // write graph
@ -249,6 +261,29 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
mockGraph); mockGraph);
} }
private long writeSortedFloat32Vectors(FieldWriter<?> fieldData, int[] ordMap)
throws IOException {
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
final ByteBuffer buffer =
ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
final BytesRef binaryValue = new BytesRef(buffer.array());
for (int ordinal : ordMap) {
float[] vector = (float[]) fieldData.vectors.get(ordinal);
buffer.asFloatBuffer().put(vector);
vectorData.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
}
return vectorDataOffset;
}
private long writeSortedByteVectors(FieldWriter<?> fieldData, int[] ordMap) throws IOException {
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
for (int ordinal : ordMap) {
byte[] vector = (byte[]) fieldData.vectors.get(ordinal);
vectorData.writeBytes(vector, 0, vector.length);
}
return vectorDataOffset;
}
// reconstruct graph substituting old ordinals with new ordinals // reconstruct graph substituting old ordinals with new ordinals
private HnswGraph reconstructAndWriteGraph( private HnswGraph reconstructAndWriteGraph(
OnHeapHnswGraph graph, int[] newToOldMap, int[] oldToNewMap) throws IOException { OnHeapHnswGraph graph, int[] newToOldMap, int[] oldToNewMap) throws IOException {
@ -354,7 +389,8 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
boolean success = false; boolean success = false;
try { try {
// write the vector data to a temporary file // write the vector data to a temporary file
DocsWithFieldSet docsWithField = writeVectorData(tempVectorData, vectors); DocsWithFieldSet docsWithField =
writeVectorData(tempVectorData, vectors, fieldInfo.getVectorEncoding().byteSize);
CodecUtil.writeFooter(tempVectorData); CodecUtil.writeFooter(tempVectorData);
IOUtils.close(tempVectorData); IOUtils.close(tempVectorData);
@ -365,21 +401,22 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
vectorData.copyBytes(vectorDataInput, vectorDataInput.length() - CodecUtil.footerLength()); vectorData.copyBytes(vectorDataInput, vectorDataInput.length() - CodecUtil.footerLength());
CodecUtil.retrieveChecksum(vectorDataInput); CodecUtil.retrieveChecksum(vectorDataInput);
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
long vectorIndexOffset = vectorIndex.getFilePointer(); long vectorIndexOffset = vectorIndex.getFilePointer();
// build the graph using the temporary vector data // build the graph using the temporary vector data
// we use Lucene94HnswVectorsReader.DenseOffHeapVectorValues for the graph construction // we use Lucene94HnswVectorsReader.DenseOffHeapVectorValues for the graph construction
// doesn't need to know docIds // doesn't need to know docIds
// TODO: separate random access vector values from DocIdSetIterator? // TODO: separate random access vector values from DocIdSetIterator?
int byteSize = vectors.dimension() * fieldInfo.getVectorEncoding().byteSize;
OffHeapVectorValues offHeapVectors = OffHeapVectorValues offHeapVectors =
new OffHeapVectorValues.DenseOffHeapVectorValues( new OffHeapVectorValues.DenseOffHeapVectorValues(
vectors.dimension(), docsWithField.cardinality(), vectorDataInput); vectors.dimension(), docsWithField.cardinality(), vectorDataInput, byteSize);
OnHeapHnswGraph graph = null; OnHeapHnswGraph graph = null;
if (offHeapVectors.size() != 0) { if (offHeapVectors.size() != 0) {
// build graph // build graph
HnswGraphBuilder hnswGraphBuilder = HnswGraphBuilder<?> hnswGraphBuilder =
new HnswGraphBuilder( HnswGraphBuilder.create(
offHeapVectors, offHeapVectors,
fieldInfo.getVectorEncoding(),
fieldInfo.getVectorSimilarityFunction(), fieldInfo.getVectorSimilarityFunction(),
M, M,
beamWidth, beamWidth,
@ -451,6 +488,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
HnswGraph graph) HnswGraph graph)
throws IOException { throws IOException {
meta.writeInt(field.number); meta.writeInt(field.number);
meta.writeInt(field.getVectorEncoding().ordinal());
meta.writeInt(field.getVectorSimilarityFunction().ordinal()); meta.writeInt(field.getVectorSimilarityFunction().ordinal());
meta.writeVLong(vectorDataOffset); meta.writeVLong(vectorDataOffset);
meta.writeVLong(vectorDataLength); meta.writeVLong(vectorDataLength);
@ -520,13 +558,13 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
/** /**
* Writes the vector values to the output and returns a set of documents that contains vectors. * Writes the vector values to the output and returns a set of documents that contains vectors.
*/ */
private static DocsWithFieldSet writeVectorData(IndexOutput output, VectorValues vectors) private static DocsWithFieldSet writeVectorData(
throws IOException { IndexOutput output, VectorValues vectors, int scalarSize) throws IOException {
DocsWithFieldSet docsWithField = new DocsWithFieldSet(); DocsWithFieldSet docsWithField = new DocsWithFieldSet();
for (int docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc()) { for (int docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc()) {
// write vector // write vector
BytesRef binaryValue = vectors.binaryValue(); BytesRef binaryValue = vectors.binaryValue();
assert binaryValue.length == vectors.dimension() * Float.BYTES; assert binaryValue.length == vectors.dimension() * scalarSize;
output.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length); output.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
docsWithField.add(docV); docsWithField.add(docV);
} }
@ -538,54 +576,69 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
IOUtils.close(meta, vectorData, vectorIndex); IOUtils.close(meta, vectorData, vectorIndex);
} }
private static class FieldWriter extends KnnFieldVectorsWriter { private abstract static class FieldWriter<T> extends KnnFieldVectorsWriter<T> {
private final FieldInfo fieldInfo; private final FieldInfo fieldInfo;
private final int dim; private final int dim;
private final DocsWithFieldSet docsWithField; private final DocsWithFieldSet docsWithField;
private final List<float[]> vectors; private final List<T> vectors;
private final RAVectorValues raVectorValues; private final RAVectorValues<T> raVectorValues;
private final HnswGraphBuilder hnswGraphBuilder; private final HnswGraphBuilder<T> hnswGraphBuilder;
private int lastDocID = -1; private int lastDocID = -1;
private int node = 0; private int node = 0;
static FieldWriter<?> create(FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream)
throws IOException {
int dim = fieldInfo.getVectorDimension();
return switch (fieldInfo.getVectorEncoding()) {
case BYTE -> new FieldWriter<BytesRef>(fieldInfo, M, beamWidth, infoStream) {
@Override
public BytesRef copyValue(BytesRef value) {
return new BytesRef(ArrayUtil.copyOfSubArray(value.bytes, value.offset, dim));
}
};
case FLOAT32 -> new FieldWriter<float[]>(fieldInfo, M, beamWidth, infoStream) {
@Override
public float[] copyValue(float[] value) {
return ArrayUtil.copyOfSubArray(value, 0, dim);
}
};
};
}
@SuppressWarnings("unchecked")
FieldWriter(FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream) FieldWriter(FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream)
throws IOException { throws IOException {
this.fieldInfo = fieldInfo; this.fieldInfo = fieldInfo;
this.dim = fieldInfo.getVectorDimension(); this.dim = fieldInfo.getVectorDimension();
this.docsWithField = new DocsWithFieldSet(); this.docsWithField = new DocsWithFieldSet();
vectors = new ArrayList<>(); vectors = new ArrayList<>();
raVectorValues = new RAVectorValues(vectors, dim); raVectorValues = new RAVectorValues<>(vectors, dim);
hnswGraphBuilder = hnswGraphBuilder =
new HnswGraphBuilder( (HnswGraphBuilder<T>)
() -> raVectorValues, HnswGraphBuilder.create(
fieldInfo.getVectorSimilarityFunction(), () -> raVectorValues,
M, fieldInfo.getVectorEncoding(),
beamWidth, fieldInfo.getVectorSimilarityFunction(),
HnswGraphBuilder.randSeed); M,
beamWidth,
HnswGraphBuilder.randSeed);
hnswGraphBuilder.setInfoStream(infoStream); hnswGraphBuilder.setInfoStream(infoStream);
} }
@Override @Override
public void addValue(int docID, float[] vectorValue) throws IOException { @SuppressWarnings("unchecked")
public void addValue(int docID, Object value) throws IOException {
if (docID == lastDocID) { if (docID == lastDocID) {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"VectorValuesField \"" "VectorValuesField \""
+ fieldInfo.name + fieldInfo.name
+ "\" appears more than once in this document (only one value is allowed per field)"); + "\" appears more than once in this document (only one value is allowed per field)");
} }
if (vectorValue.length != dim) { T vectorValue = (T) value;
throw new IllegalArgumentException(
"Attempt to index a vector of dimension "
+ vectorValue.length
+ " but \""
+ fieldInfo.name
+ "\" has dimension "
+ dim);
}
assert docID > lastDocID; assert docID > lastDocID;
docsWithField.add(docID); docsWithField.add(docID);
vectors.add(ArrayUtil.copyOfSubArray(vectorValue, 0, vectorValue.length)); vectors.add(copyValue(vectorValue));
if (node > 0) { if (node > 0) {
// start at node 1! node 0 is added implicitly, in the constructor // start at node 1! node 0 is added implicitly, in the constructor
hnswGraphBuilder.addGraphNode(node, vectorValue); hnswGraphBuilder.addGraphNode(node, vectorValue);
@ -608,16 +661,16 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
return docsWithField.ramBytesUsed() return docsWithField.ramBytesUsed()
+ vectors.size() + vectors.size()
* (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER) * (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER)
+ vectors.size() * vectors.get(0).length * Float.BYTES + vectors.size() * fieldInfo.getVectorDimension() * fieldInfo.getVectorEncoding().byteSize
+ hnswGraphBuilder.getGraph().ramBytesUsed(); + hnswGraphBuilder.getGraph().ramBytesUsed();
} }
} }
private static class RAVectorValues implements RandomAccessVectorValues { private static class RAVectorValues<T> implements RandomAccessVectorValues {
private final List<float[]> vectors; private final List<T> vectors;
private final int dim; private final int dim;
RAVectorValues(List<float[]> vectors, int dim) { RAVectorValues(List<T> vectors, int dim) {
this.vectors = vectors; this.vectors = vectors;
this.dim = dim; this.dim = dim;
} }
@ -634,12 +687,12 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
@Override @Override
public float[] vectorValue(int targetOrd) throws IOException { public float[] vectorValue(int targetOrd) throws IOException {
return vectors.get(targetOrd); return (float[]) vectors.get(targetOrd);
} }
@Override @Override
public BytesRef binaryValue(int targetOrd) throws IOException { public BytesRef binaryValue(int targetOrd) throws IOException {
throw new UnsupportedOperationException(); return (BytesRef) vectors.get(targetOrd);
} }
} }
} }
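The FLOAT32 path keeps reusing a single little-endian buffer per field, while the BYTE path writes the stored bytes verbatim. A self-contained sketch of the little-endian layout, assuming a plain OutputStream where the writer uses IndexOutput:

import java.io.IOException;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

class Float32VectorSerialization {
  // Serialize float vectors the way writeFloat32Vectors does: one reused
  // little-endian buffer, flushed once per vector.
  static void writeFloat32Vectors(OutputStream out, Iterable<float[]> vectors, int dim)
      throws IOException {
    ByteBuffer buffer = ByteBuffer.allocate(dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
    for (float[] vector : vectors) {
      buffer.asFloatBuffer().put(vector); // overwrites the previous vector's bytes in place
      out.write(buffer.array(), 0, buffer.capacity());
    }
  }
}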

View File

@ -41,11 +41,11 @@ abstract class OffHeapVectorValues extends VectorValues
protected final int byteSize; protected final int byteSize;
protected final float[] value; protected final float[] value;
OffHeapVectorValues(int dimension, int size, IndexInput slice) { OffHeapVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
this.dimension = dimension; this.dimension = dimension;
this.size = size; this.size = size;
this.slice = slice; this.slice = slice;
byteSize = Float.BYTES * dimension; this.byteSize = byteSize;
byteBuffer = ByteBuffer.allocate(byteSize); byteBuffer = ByteBuffer.allocate(byteSize);
value = new float[dimension]; value = new float[dimension];
binaryValue = new BytesRef(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize); binaryValue = new BytesRef(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
@ -93,10 +93,16 @@ abstract class OffHeapVectorValues extends VectorValues
} }
IndexInput bytesSlice = IndexInput bytesSlice =
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength); vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
int byteSize =
switch (fieldEntry.vectorEncoding) {
case BYTE -> fieldEntry.dimension;
case FLOAT32 -> fieldEntry.dimension * Float.BYTES;
};
if (fieldEntry.docsWithFieldOffset == -1) { if (fieldEntry.docsWithFieldOffset == -1) {
return new DenseOffHeapVectorValues(fieldEntry.dimension, fieldEntry.size, bytesSlice); return new DenseOffHeapVectorValues(
fieldEntry.dimension, fieldEntry.size, bytesSlice, byteSize);
} else { } else {
return new SparseOffHeapVectorValues(fieldEntry, vectorData, bytesSlice); return new SparseOffHeapVectorValues(fieldEntry, vectorData, bytesSlice, byteSize);
} }
} }
@ -106,8 +112,8 @@ abstract class OffHeapVectorValues extends VectorValues
private int doc = -1; private int doc = -1;
public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice) { public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
super(dimension, size, slice); super(dimension, size, slice, byteSize);
} }
@Override @Override
@ -145,7 +151,7 @@ abstract class OffHeapVectorValues extends VectorValues
@Override @Override
public RandomAccessVectorValues randomAccess() throws IOException { public RandomAccessVectorValues randomAccess() throws IOException {
return new DenseOffHeapVectorValues(dimension, size, slice.clone()); return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize);
} }
@Override @Override
@ -167,10 +173,13 @@ abstract class OffHeapVectorValues extends VectorValues
private final Lucene94HnswVectorsReader.FieldEntry fieldEntry; private final Lucene94HnswVectorsReader.FieldEntry fieldEntry;
public SparseOffHeapVectorValues( public SparseOffHeapVectorValues(
Lucene94HnswVectorsReader.FieldEntry fieldEntry, IndexInput dataIn, IndexInput slice) Lucene94HnswVectorsReader.FieldEntry fieldEntry,
IndexInput dataIn,
IndexInput slice,
int byteSize)
throws IOException { throws IOException {
super(fieldEntry.dimension, fieldEntry.size, slice); super(fieldEntry.dimension, fieldEntry.size, slice, byteSize);
this.fieldEntry = fieldEntry; this.fieldEntry = fieldEntry;
final RandomAccessInput addressesData = final RandomAccessInput addressesData =
dataIn.randomAccessSlice(fieldEntry.addressesOffset, fieldEntry.addressesLength); dataIn.randomAccessSlice(fieldEntry.addressesOffset, fieldEntry.addressesLength);
@ -218,7 +227,7 @@ abstract class OffHeapVectorValues extends VectorValues
@Override @Override
public RandomAccessVectorValues randomAccess() throws IOException { public RandomAccessVectorValues randomAccess() throws IOException {
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone()); return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize);
} }
@Override @Override
@ -248,7 +257,7 @@ abstract class OffHeapVectorValues extends VectorValues
private static class EmptyOffHeapVectorValues extends OffHeapVectorValues { private static class EmptyOffHeapVectorValues extends OffHeapVectorValues {
public EmptyOffHeapVectorValues(int dimension) { public EmptyOffHeapVectorValues(int dimension) {
super(dimension, 0, null); super(dimension, 0, null, 0);
} }
private int doc = -1; private int doc = -1;
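Threading byteSize through the OffHeapVectorValues constructors keeps the random-access arithmetic uniform across encodings. A hedged sketch of that addressing math:

class OffHeapAddressing {
  // With a per-encoding sample width, vector `ord` always starts at
  // ord * byteSize in the data slice, whether samples are bytes or floats.
  static long vectorOffset(int ord, int dimension, boolean byteEncoded) {
    int byteSize = byteEncoded ? dimension : dimension * Float.BYTES; // mirrors the BYTE/FLOAT32 switch above
    return (long) ord * byteSize;
  }
}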

View File

@ -144,7 +144,7 @@
* <li>{@link org.apache.lucene.codecs.lucene90.Lucene90SegmentInfoFormat Segment info}. This * <li>{@link org.apache.lucene.codecs.lucene90.Lucene90SegmentInfoFormat Segment info}. This
* contains metadata about a segment, such as the number of documents, what files it uses, and * contains metadata about a segment, such as the number of documents, what files it uses, and
* information about how the segment is sorted * information about how the segment is sorted
* <li>{@link org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat Field names}. This * <li>{@link org.apache.lucene.codecs.lucene94.Lucene94FieldInfosFormat Field names}. This
* contains metadata about the set of named fields used in the index. * contains metadata about the set of named fields used in the index.
* <li>{@link org.apache.lucene.codecs.lucene90.Lucene90StoredFieldsFormat Stored Field values}. * <li>{@link org.apache.lucene.codecs.lucene90.Lucene90StoredFieldsFormat Stored Field values}.
* This contains, for each document, a list of attribute-value pairs, where the attributes are * This contains, for each document, a list of attribute-value pairs, where the attributes are
@ -240,7 +240,7 @@
* systems that frequently run out of file handles.</td> * systems that frequently run out of file handles.</td>
* </tr> * </tr>
* <tr> * <tr>
* <td>{@link org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat Fields}</td> * <td>{@link org.apache.lucene.codecs.lucene94.Lucene94FieldInfosFormat Fields}</td>
* <td>.fnm</td> * <td>.fnm</td>
* <td>Stores information about the fields</td> * <td>Stores information about the fields</td>
* </tr> * </tr>

View File

@ -33,6 +33,7 @@ import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter; import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorValues; import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.TotalHits;
@ -101,7 +102,7 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
} }
@Override @Override
public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
KnnVectorsWriter writer = getInstance(fieldInfo); KnnVectorsWriter writer = getInstance(fieldInfo);
return writer.addField(fieldInfo); return writer.addField(fieldInfo);
} }
@ -267,6 +268,17 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
} }
} }
@Override
public TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
KnnVectorsReader knnVectorsReader = fields.get(field);
if (knnVectorsReader == null) {
return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
} else {
return knnVectorsReader.searchExhaustively(field, target, k, acceptDocs);
}
}
@Override @Override
public void close() throws IOException { public void close() throws IOException {
IOUtils.close(fields.values()); IOUtils.close(fields.values());
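searchExhaustively scores every accepted document instead of walking the HNSW graph. A hedged, generic sketch of that brute-force top-k, with plain arrays and a dot product standing in for VectorValues and VectorSimilarityFunction:

import java.util.Comparator;
import java.util.PriorityQueue;

class BruteForceKnn {
  record Hit(int doc, float score) {}

  // Score every accepted doc and keep the k best in a min-heap: the shape of
  // an exhaustive (non-graph) nearest-neighbor search.
  static Hit[] search(float[][] vectors, int[] acceptedDocs, float[] target, int k) {
    PriorityQueue<Hit> heap =
        new PriorityQueue<>(Comparator.comparingDouble((Hit h) -> h.score()));
    for (int doc : acceptedDocs) {
      float score = dotProduct(vectors[doc], target);
      if (heap.size() < k) {
        heap.add(new Hit(doc, score));
      } else if (score > heap.peek().score()) {
        heap.poll(); // evict the current worst hit
        heap.add(new Hit(doc, score));
      }
    }
    return heap.stream()
        .sorted(Comparator.comparingDouble((Hit h) -> h.score()).reversed())
        .toArray(Hit[]::new);
  }

  static float dotProduct(float[] a, float[] b) {
    float sum = 0;
    for (int i = 0; i < a.length; i++) {
      sum += a[i] * b[i];
    }
    return sum;
  }
}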

View File

@ -24,6 +24,7 @@ import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.IndexOptions; import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.IndexableFieldType; import org.apache.lucene.index.IndexableFieldType;
import org.apache.lucene.index.PointValues; import org.apache.lucene.index.PointValues;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues; import org.apache.lucene.index.VectorValues;
@ -44,6 +45,7 @@ public class FieldType implements IndexableFieldType {
private int indexDimensionCount; private int indexDimensionCount;
private int dimensionNumBytes; private int dimensionNumBytes;
private int vectorDimension; private int vectorDimension;
private VectorEncoding vectorEncoding = VectorEncoding.FLOAT32;
private VectorSimilarityFunction vectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN; private VectorSimilarityFunction vectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN;
private Map<String, String> attributes; private Map<String, String> attributes;
@ -62,6 +64,7 @@ public class FieldType implements IndexableFieldType {
this.indexDimensionCount = ref.pointIndexDimensionCount(); this.indexDimensionCount = ref.pointIndexDimensionCount();
this.dimensionNumBytes = ref.pointNumBytes(); this.dimensionNumBytes = ref.pointNumBytes();
this.vectorDimension = ref.vectorDimension(); this.vectorDimension = ref.vectorDimension();
this.vectorEncoding = ref.vectorEncoding();
this.vectorSimilarityFunction = ref.vectorSimilarityFunction(); this.vectorSimilarityFunction = ref.vectorSimilarityFunction();
if (ref.getAttributes() != null) { if (ref.getAttributes() != null) {
this.attributes = new HashMap<>(ref.getAttributes()); this.attributes = new HashMap<>(ref.getAttributes());
@ -371,8 +374,8 @@ public class FieldType implements IndexableFieldType {
} }
/** Enable vector indexing, with the specified number of dimensions and distance function. */ /** Enable vector indexing, with the specified number of dimensions and distance function. */
public void setVectorDimensionsAndSimilarityFunction( public void setVectorAttributes(
int numDimensions, VectorSimilarityFunction distFunc) { int numDimensions, VectorEncoding encoding, VectorSimilarityFunction similarity) {
checkIfFrozen(); checkIfFrozen();
if (numDimensions <= 0) { if (numDimensions <= 0) {
throw new IllegalArgumentException("vector numDimensions must be > 0; got " + numDimensions); throw new IllegalArgumentException("vector numDimensions must be > 0; got " + numDimensions);
@ -385,7 +388,8 @@ public class FieldType implements IndexableFieldType {
+ numDimensions); + numDimensions);
} }
this.vectorDimension = numDimensions; this.vectorDimension = numDimensions;
this.vectorSimilarityFunction = Objects.requireNonNull(distFunc); this.vectorSimilarityFunction = Objects.requireNonNull(similarity);
this.vectorEncoding = Objects.requireNonNull(encoding);
} }
@Override @Override
@ -393,6 +397,11 @@ public class FieldType implements IndexableFieldType {
return vectorDimension; return vectorDimension;
} }
@Override
public VectorEncoding vectorEncoding() {
return vectorEncoding;
}
@Override @Override
public VectorSimilarityFunction vectorSimilarityFunction() { public VectorSimilarityFunction vectorSimilarityFunction() {
return vectorSimilarityFunction; return vectorSimilarityFunction;
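setVectorAttributes replaces setVectorDimensionsAndSimilarityFunction as the single entry point for declaring a vector field's shape. A short usage sketch; the dimension and similarity values are arbitrary choices here:

import org.apache.lucene.document.FieldType;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;

class ByteVectorFieldType {
  static FieldType byteVectorType(int dimension) {
    FieldType type = new FieldType();
    // declare an 8-bit-per-sample vector field; FLOAT32 remains the default encoding
    type.setVectorAttributes(dimension, VectorEncoding.BYTE, VectorSimilarityFunction.DOT_PRODUCT);
    type.freeze();
    return type;
  }
}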

View File

@ -17,8 +17,10 @@
package org.apache.lucene.document; package org.apache.lucene.document;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues; import org.apache.lucene.index.VectorValues;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.VectorUtil;
/** /**
@ -39,7 +41,18 @@ public class KnnVectorField extends Field {
if (v == null) { if (v == null) {
throw new IllegalArgumentException("vector value must not be null"); throw new IllegalArgumentException("vector value must not be null");
} }
int dimension = v.length; return createType(v.length, VectorEncoding.FLOAT32, similarityFunction);
}
private static FieldType createType(BytesRef v, VectorSimilarityFunction similarityFunction) {
if (v == null) {
throw new IllegalArgumentException("vector value must not be null");
}
return createType(v.length, VectorEncoding.BYTE, similarityFunction);
}
private static FieldType createType(
int dimension, VectorEncoding vectorEncoding, VectorSimilarityFunction similarityFunction) {
if (dimension == 0) { if (dimension == 0) {
throw new IllegalArgumentException("cannot index an empty vector"); throw new IllegalArgumentException("cannot index an empty vector");
} }
@ -51,13 +64,13 @@ public class KnnVectorField extends Field {
throw new IllegalArgumentException("similarity function must not be null"); throw new IllegalArgumentException("similarity function must not be null");
} }
FieldType type = new FieldType(); FieldType type = new FieldType();
type.setVectorDimensionsAndSimilarityFunction(dimension, similarityFunction); type.setVectorAttributes(dimension, vectorEncoding, similarityFunction);
type.freeze(); type.freeze();
return type; return type;
} }
/** /**
* A convenience method for creating a vector field type. * A convenience method for creating a vector field type with the default FLOAT32 encoding.
* *
* @param dimension dimension of vectors * @param dimension dimension of vectors
* @param similarityFunction a function defining vector proximity. * @param similarityFunction a function defining vector proximity.
@ -65,8 +78,21 @@ public class KnnVectorField extends Field {
*/ */
public static FieldType createFieldType( public static FieldType createFieldType(
int dimension, VectorSimilarityFunction similarityFunction) { int dimension, VectorSimilarityFunction similarityFunction) {
return createFieldType(dimension, VectorEncoding.FLOAT32, similarityFunction);
}
/**
* A convenience method for creating a vector field type.
*
* @param dimension dimension of vectors
* @param vectorEncoding the encoding of the scalar values
* @param similarityFunction a function defining vector proximity.
* @throws IllegalArgumentException if any parameter is null, or has dimension &gt; 1024.
*/
public static FieldType createFieldType(
int dimension, VectorEncoding vectorEncoding, VectorSimilarityFunction similarityFunction) {
FieldType type = new FieldType(); FieldType type = new FieldType();
type.setVectorDimensionsAndSimilarityFunction(dimension, similarityFunction); type.setVectorAttributes(dimension, vectorEncoding, similarityFunction);
type.freeze(); type.freeze();
return type; return type;
} }
@ -74,8 +100,8 @@ public class KnnVectorField extends Field {
/** /**
* Creates a numeric vector field. Fields are single-valued: each document has either one value or * Creates a numeric vector field. Fields are single-valued: each document has either one value or
* no value. Vectors of a single field share the same dimension and similarity function. Note that * no value. Vectors of a single field share the same dimension and similarity function. Note that
* some strategies (like {@link VectorSimilarityFunction#DOT_PRODUCT}) require values to be * some vector similarities (like {@link VectorSimilarityFunction#DOT_PRODUCT}) require values to
* unit-length, which can be enforced using {@link VectorUtil#l2normalize(float[])}. * be unit-length, which can be enforced using {@link VectorUtil#l2normalize(float[])}.
* *
* @param name field name * @param name field name
* @param vector value * @param vector value
@ -88,6 +114,23 @@ public class KnnVectorField extends Field {
fieldsData = vector; fieldsData = vector;
} }
/**
* Creates a numeric vector field. Fields are single-valued: each document has either one value or
* no value. Vectors of a single field share the same dimension and similarity function. Note that
* some vector similarities (like {@link VectorSimilarityFunction#DOT_PRODUCT}) require values to
* be constant-length.
*
* @param name field name
* @param vector value
* @param similarityFunction a function defining vector proximity.
* @throws IllegalArgumentException if any parameter is null, or the vector is empty or has
* dimension &gt; 1024.
*/
public KnnVectorField(String name, BytesRef vector, VectorSimilarityFunction similarityFunction) {
super(name, createType(vector, similarityFunction));
fieldsData = vector;
}
/** /**
* Creates a numeric vector field with the default EUCLIDEAN (L2) similarity. Fields are * Creates a numeric vector field with the default EUCLIDEAN (L2) similarity. Fields are
* single-valued: each document has either one value or no value. Vectors of a single field share * single-valued: each document has either one value or no value. Vectors of a single field share
@ -117,6 +160,21 @@ public class KnnVectorField extends Field {
fieldsData = vector; fieldsData = vector;
} }
/**
* Creates a numeric vector field. Fields are single-valued: each document has either one value or
* no value. Vectors of a single field share the same dimension and similarity function.
*
* @param name field name
* @param vector value
* @param fieldType field type
* @throws IllegalArgumentException if any parameter is null, or the vector is empty or has
* dimension &gt; 1024.
*/
public KnnVectorField(String name, BytesRef vector, FieldType fieldType) {
super(name, fieldType);
fieldsData = vector;
}
/** Return the vector value of this field */ /** Return the vector value of this field */
public float[] vectorValue() { public float[] vectorValue() {
return (float[]) fieldsData; return (float[]) fieldsData;
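With the new BytesRef constructors, 8-bit vectors arrive pre-quantized and Lucene stores the bytes verbatim. A hedged sketch of indexing one, including a naive application-side quantizer whose scale factor is an arbitrary choice, not part of this change:

import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.BytesRef;

class QuantizedVectorIndexing {
  // Naive scalar quantization: clamp each sample into the signed-byte range.
  // Producing the bytes (including this scale) is entirely up to the application.
  static BytesRef quantize(float[] vector) {
    byte[] bytes = new byte[vector.length];
    for (int i = 0; i < vector.length; i++) {
      bytes[i] = (byte) Math.max(-128, Math.min(127, Math.round(vector[i] * 127f)));
    }
    return new BytesRef(bytes);
  }

  static Document byteVectorDoc(float[] vector) {
    Document doc = new Document();
    doc.add(new KnnVectorField("vec", quantize(vector), VectorSimilarityFunction.DOT_PRODUCT));
    return doc;
  }
}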

View File

@ -45,7 +45,7 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
protected BufferingKnnVectorsWriter() {} protected BufferingKnnVectorsWriter() {}
@Override @Override
public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { public KnnFieldVectorsWriter<float[]> addField(FieldInfo fieldInfo) throws IOException {
FieldWriter newField = new FieldWriter(fieldInfo); FieldWriter newField = new FieldWriter(fieldInfo);
fields.add(newField); fields.add(newField);
return newField; return newField;
@ -88,6 +88,12 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) { String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
@Override
public TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
throw new UnsupportedOperationException();
}
}; };
writeField(fieldData.fieldInfo, knnVectorsReader, maxDoc); writeField(fieldData.fieldInfo, knnVectorsReader, maxDoc);
@ -122,6 +128,12 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
@Override
public TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
throw new UnsupportedOperationException();
}
@Override @Override
public VectorValues getVectorValues(String field) throws IOException { public VectorValues getVectorValues(String field) throws IOException {
return MergedVectorValues.mergeVectorValues(fieldInfo, mergeState); return MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
@ -137,7 +149,7 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
protected abstract void writeField( protected abstract void writeField(
FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader, int maxDoc) throws IOException; FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader, int maxDoc) throws IOException;
private static class FieldWriter extends KnnFieldVectorsWriter { private static class FieldWriter extends KnnFieldVectorsWriter<float[]> {
private final FieldInfo fieldInfo; private final FieldInfo fieldInfo;
private final int dim; private final int dim;
private final DocsWithFieldSet docsWithField; private final DocsWithFieldSet docsWithField;
@ -153,35 +165,45 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
} }
@Override @Override
public void addValue(int docID, float[] vectorValue) { public void addValue(int docID, Object value) {
if (docID == lastDocID) { if (docID == lastDocID) {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"VectorValuesField \"" "VectorValuesField \""
+ fieldInfo.name + fieldInfo.name
+ "\" appears more than once in this document (only one value is allowed per field)"); + "\" appears more than once in this document (only one value is allowed per field)");
} }
if (vectorValue.length != dim) {
throw new IllegalArgumentException(
"Attempt to index a vector of dimension "
+ vectorValue.length
+ " but \""
+ fieldInfo.name
+ "\" has dimension "
+ dim);
}
assert docID > lastDocID; assert docID > lastDocID;
float[] vectorValue =
switch (fieldInfo.getVectorEncoding()) {
case FLOAT32 -> (float[]) value;
case BYTE -> bytesToFloats((BytesRef) value);
};
docsWithField.add(docID); docsWithField.add(docID);
vectors.add(ArrayUtil.copyOfSubArray(vectorValue, 0, vectorValue.length)); vectors.add(copyValue(vectorValue));
lastDocID = docID; lastDocID = docID;
} }
private float[] bytesToFloats(BytesRef b) {
// This is used only by SimpleTextKnnVectorsWriter
float[] floats = new float[dim];
for (int i = 0; i < dim; i++) {
floats[i] = b.bytes[i + b.offset];
}
return floats;
}
@Override
public float[] copyValue(float[] vectorValue) {
return ArrayUtil.copyOfSubArray(vectorValue, 0, dim);
}
@Override @Override
public long ramBytesUsed() { public long ramBytesUsed() {
if (vectors.size() == 0) return 0; if (vectors.size() == 0) return 0;
return docsWithField.ramBytesUsed() return docsWithField.ramBytesUsed()
+ vectors.size() + vectors.size()
* (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER) * (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER)
+ vectors.size() * vectors.get(0).length * Float.BYTES; + vectors.size() * dim * Float.BYTES;
} }
} }

View File

@ -25,6 +25,7 @@ import org.apache.lucene.codecs.NormsProducer;
import org.apache.lucene.codecs.PointsReader; import org.apache.lucene.codecs.PointsReader;
import org.apache.lucene.codecs.StoredFieldsReader; import org.apache.lucene.codecs.StoredFieldsReader;
import org.apache.lucene.codecs.TermVectorsReader; import org.apache.lucene.codecs.TermVectorsReader;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits; import org.apache.lucene.util.Bits;
@ -235,6 +236,19 @@ public abstract class CodecReader extends LeafReader {
return getVectorReader().search(field, target, k, acceptDocs, visitedLimit); return getVectorReader().search(field, target, k, acceptDocs, visitedLimit);
} }
@Override
public final TopDocs searchNearestVectorsExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
ensureOpen();
FieldInfo fi = getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
// Field does not exist or does not index vectors
return null;
}
return getVectorReader().searchExhaustively(field, target, k, acceptDocs);
}
@Override @Override
protected void doClose() throws IOException {} protected void doClose() throws IOException {}

View File

@ -18,6 +18,7 @@
package org.apache.lucene.index; package org.apache.lucene.index;
import java.io.IOException; import java.io.IOException;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits; import org.apache.lucene.util.Bits;
@ -58,6 +59,12 @@ abstract class DocValuesLeafReader extends LeafReader {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
@Override
public TopDocs searchNearestVectorsExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
throw new UnsupportedOperationException();
}
@Override @Override
public final void checkIntegrity() throws IOException { public final void checkIntegrity() throws IOException {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();

View File

@ -56,6 +56,7 @@ public final class FieldInfo {
// if it is a positive value, it means this field indexes vectors // if it is a positive value, it means this field indexes vectors
private final int vectorDimension; private final int vectorDimension;
private final VectorEncoding vectorEncoding;
private final VectorSimilarityFunction vectorSimilarityFunction; private final VectorSimilarityFunction vectorSimilarityFunction;
// whether this field is used as the soft-deletes field // whether this field is used as the soft-deletes field
@ -80,6 +81,7 @@ public final class FieldInfo {
int pointIndexDimensionCount, int pointIndexDimensionCount,
int pointNumBytes, int pointNumBytes,
int vectorDimension, int vectorDimension,
VectorEncoding vectorEncoding,
VectorSimilarityFunction vectorSimilarityFunction, VectorSimilarityFunction vectorSimilarityFunction,
boolean softDeletesField) { boolean softDeletesField) {
this.name = Objects.requireNonNull(name); this.name = Objects.requireNonNull(name);
@ -105,6 +107,7 @@ public final class FieldInfo {
this.pointIndexDimensionCount = pointIndexDimensionCount; this.pointIndexDimensionCount = pointIndexDimensionCount;
this.pointNumBytes = pointNumBytes; this.pointNumBytes = pointNumBytes;
this.vectorDimension = vectorDimension; this.vectorDimension = vectorDimension;
this.vectorEncoding = vectorEncoding;
this.vectorSimilarityFunction = vectorSimilarityFunction; this.vectorSimilarityFunction = vectorSimilarityFunction;
this.softDeletesField = softDeletesField; this.softDeletesField = softDeletesField;
this.checkConsistency(); this.checkConsistency();
@ -229,8 +232,10 @@ public final class FieldInfo {
verifySameVectorOptions( verifySameVectorOptions(
fieldName, fieldName,
this.vectorDimension, this.vectorDimension,
this.vectorEncoding,
this.vectorSimilarityFunction, this.vectorSimilarityFunction,
o.vectorDimension, o.vectorDimension,
o.vectorEncoding,
o.vectorSimilarityFunction); o.vectorSimilarityFunction);
} }
@ -347,19 +352,25 @@ public final class FieldInfo {
static void verifySameVectorOptions( static void verifySameVectorOptions(
String fieldName, String fieldName,
int vd1, int vd1,
VectorEncoding ve1,
VectorSimilarityFunction vsf1, VectorSimilarityFunction vsf1,
int vd2, int vd2,
VectorEncoding ve2,
VectorSimilarityFunction vsf2) { VectorSimilarityFunction vsf2) {
if (vd1 != vd2 || vsf1 != vsf2) { if (vd1 != vd2 || vsf1 != vsf2 || ve1 != ve2) {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"cannot change field \"" "cannot change field \""
+ fieldName + fieldName
+ "\" from vector dimension=" + "\" from vector dimension="
+ vd1 + vd1
+ ", vector encoding="
+ ve1
+ ", vector similarity function=" + ", vector similarity function="
+ vsf1 + vsf1
+ " to inconsistent vector dimension=" + " to inconsistent vector dimension="
+ vd2 + vd2
+ ", vector encoding="
+ ve2
+ ", vector similarity function=" + ", vector similarity function="
+ vsf2); + vsf2);
} }
@ -470,6 +481,11 @@ public final class FieldInfo {
return vectorDimension; return vectorDimension;
} }
/** Returns the vector encoding of the vector value */
public VectorEncoding getVectorEncoding() {
return vectorEncoding;
}
/** Returns {@link VectorSimilarityFunction} for the field */ /** Returns {@link VectorSimilarityFunction} for the field */
public VectorSimilarityFunction getVectorSimilarityFunction() { public VectorSimilarityFunction getVectorSimilarityFunction() {
return vectorSimilarityFunction; return vectorSimilarityFunction;

View File

@ -308,10 +308,15 @@ public class FieldInfos implements Iterable<FieldInfo> {
static final class FieldVectorProperties { static final class FieldVectorProperties {
final int numDimensions; final int numDimensions;
final VectorEncoding vectorEncoding;
final VectorSimilarityFunction similarityFunction; final VectorSimilarityFunction similarityFunction;
FieldVectorProperties(int numDimensions, VectorSimilarityFunction similarityFunction) { FieldVectorProperties(
int numDimensions,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction) {
this.numDimensions = numDimensions; this.numDimensions = numDimensions;
this.vectorEncoding = vectorEncoding;
this.similarityFunction = similarityFunction; this.similarityFunction = similarityFunction;
} }
} }
@ -401,7 +406,8 @@ public class FieldInfos implements Iterable<FieldInfo> {
fi.getPointNumBytes())); fi.getPointNumBytes()));
vectorProps.put( vectorProps.put(
fieldName, fieldName,
new FieldVectorProperties(fi.getVectorDimension(), fi.getVectorSimilarityFunction())); new FieldVectorProperties(
fi.getVectorDimension(), fi.getVectorEncoding(), fi.getVectorSimilarityFunction()));
} }
return fieldNumber.intValue(); return fieldNumber.intValue();
} }
@ -459,8 +465,10 @@ public class FieldInfos implements Iterable<FieldInfo> {
verifySameVectorOptions( verifySameVectorOptions(
fieldName, fieldName,
props.numDimensions, props.numDimensions,
props.vectorEncoding,
props.similarityFunction, props.similarityFunction,
fi.getVectorDimension(), fi.getVectorDimension(),
fi.getVectorEncoding(),
fi.getVectorSimilarityFunction()); fi.getVectorSimilarityFunction());
} }
@ -503,6 +511,7 @@ public class FieldInfos implements Iterable<FieldInfo> {
0, 0,
0, 0,
0, 0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN, VectorSimilarityFunction.EUCLIDEAN,
(softDeletesFieldName != null && softDeletesFieldName.equals(fieldName))); (softDeletesFieldName != null && softDeletesFieldName.equals(fieldName)));
addOrGet(fi); addOrGet(fi);
@ -584,6 +593,7 @@ public class FieldInfos implements Iterable<FieldInfo> {
0, 0,
0, 0,
0, 0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN, VectorSimilarityFunction.EUCLIDEAN,
isSoftDeletesField); isSoftDeletesField);
} }
@ -698,6 +708,7 @@ public class FieldInfos implements Iterable<FieldInfo> {
fi.getPointIndexDimensionCount(), fi.getPointIndexDimensionCount(),
fi.getPointNumBytes(), fi.getPointNumBytes(),
fi.getVectorDimension(), fi.getVectorDimension(),
fi.getVectorEncoding(),
fi.getVectorSimilarityFunction(), fi.getVectorSimilarityFunction(),
fi.isSoftDeletesField()); fi.isSoftDeletesField());
byName.put(fiNew.getName(), fiNew); byName.put(fiNew.getName(), fiNew);

View File

@ -18,6 +18,7 @@ package org.apache.lucene.index;
import java.io.IOException; import java.io.IOException;
import java.util.Iterator; import java.util.Iterator;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.AttributeSource; import org.apache.lucene.util.AttributeSource;
import org.apache.lucene.util.Bits; import org.apache.lucene.util.Bits;
@ -357,6 +358,12 @@ public abstract class FilterLeafReader extends LeafReader {
return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit); return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
} }
@Override
public TopDocs searchNearestVectorsExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
return in.searchNearestVectorsExhaustively(field, target, k, acceptDocs);
}
@Override @Override
public Fields getTermVectors(int docID) throws IOException { public Fields getTermVectors(int docID) throws IOException {
ensureOpen(); ensureOpen();

View File

@ -101,6 +101,9 @@ public interface IndexableFieldType {
/** The number of dimensions of the field's vector value */ /** The number of dimensions of the field's vector value */
int vectorDimension(); int vectorDimension();
/** The {@link VectorEncoding} of the field's vector value */
VectorEncoding vectorEncoding();
/** The {@link VectorSimilarityFunction} of the field's vector value */ /** The {@link VectorSimilarityFunction} of the field's vector value */
VectorSimilarityFunction vectorSimilarityFunction(); VectorSimilarityFunction vectorSimilarityFunction();

View File

@ -628,6 +628,7 @@ final class IndexingChain implements Accountable {
s.pointIndexDimensionCount, s.pointIndexDimensionCount,
s.pointNumBytes, s.pointNumBytes,
s.vectorDimension, s.vectorDimension,
s.vectorEncoding,
s.vectorSimilarityFunction, s.vectorSimilarityFunction,
pf.fieldName.equals(fieldInfos.getSoftDeletesFieldName()))); pf.fieldName.equals(fieldInfos.getSoftDeletesFieldName())));
pf.setFieldInfo(fi); pf.setFieldInfo(fi);
@ -712,7 +713,11 @@ final class IndexingChain implements Accountable {
pf.pointValuesWriter.addPackedValue(docID, field.binaryValue()); pf.pointValuesWriter.addPackedValue(docID, field.binaryValue());
} }
if (fieldType.vectorDimension() != 0) { if (fieldType.vectorDimension() != 0) {
pf.knnFieldVectorsWriter.addValue(docID, ((KnnVectorField) field).vectorValue()); switch (fieldType.vectorEncoding()) {
case BYTE -> pf.knnFieldVectorsWriter.addValue(docID, field.binaryValue());
case FLOAT32 -> pf.knnFieldVectorsWriter.addValue(
docID, ((KnnVectorField) field).vectorValue());
}
} }
return indexedField; return indexedField;
} }
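The switch above shows that byte-encoded vectors reach the writer through the field's binaryValue(). A minimal indexing sketch follows (editorial, not part of the patch; it assumes the BytesRef-accepting KnnVectorField constructor and a FieldType carrying VectorEncoding.BYTE that this commit introduces in files not shown in this excerpt):

import java.io.IOException;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.util.BytesRef;

static void indexByteVector(IndexWriter writer, FieldType byteVectorType) throws IOException {
  // a 4-dimensional byte vector; every value must lie in [-128, 127]
  BytesRef vector = new BytesRef(new byte[] {12, -3, 45, 100});
  Document doc = new Document();
  doc.add(new KnnVectorField("vector_field", vector, byteVectorType)); // hypothetical field setup
  writer.addDocument(doc);
}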
@ -776,7 +781,10 @@ final class IndexingChain implements Accountable {
fieldType.pointNumBytes()); fieldType.pointNumBytes());
} }
if (fieldType.vectorDimension() != 0) { if (fieldType.vectorDimension() != 0) {
schema.setVectors(fieldType.vectorSimilarityFunction(), fieldType.vectorDimension()); schema.setVectors(
fieldType.vectorEncoding(),
fieldType.vectorSimilarityFunction(),
fieldType.vectorDimension());
} }
if (fieldType.getAttributes() != null && fieldType.getAttributes().isEmpty() == false) { if (fieldType.getAttributes() != null && fieldType.getAttributes().isEmpty() == false) {
schema.updateAttributes(fieldType.getAttributes()); schema.updateAttributes(fieldType.getAttributes());
@ -988,7 +996,7 @@ final class IndexingChain implements Accountable {
PointValuesWriter pointValuesWriter; PointValuesWriter pointValuesWriter;
// Non-null if this field had vectors in this segment // Non-null if this field had vectors in this segment
KnnFieldVectorsWriter knnFieldVectorsWriter; KnnFieldVectorsWriter<?> knnFieldVectorsWriter;
/** We use this to know when a PerField is seen for the first time in the current document. */ /** We use this to know when a PerField is seen for the first time in the current document. */
long fieldGen = -1; long fieldGen = -1;
@ -1281,6 +1289,7 @@ final class IndexingChain implements Accountable {
private int pointIndexDimensionCount = 0; private int pointIndexDimensionCount = 0;
private int pointNumBytes = 0; private int pointNumBytes = 0;
private int vectorDimension = 0; private int vectorDimension = 0;
private VectorEncoding vectorEncoding = VectorEncoding.FLOAT32;
private VectorSimilarityFunction vectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN; private VectorSimilarityFunction vectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN;
private static String errMsg = private static String errMsg =
@ -1361,11 +1370,14 @@ final class IndexingChain implements Accountable {
} }
} }
void setVectors(VectorSimilarityFunction similarityFunction, int dimension) { void setVectors(
VectorEncoding encoding, VectorSimilarityFunction similarityFunction, int dimension) {
if (vectorDimension == 0) { if (vectorDimension == 0) {
this.vectorDimension = dimension; this.vectorEncoding = encoding;
this.vectorSimilarityFunction = similarityFunction; this.vectorSimilarityFunction = similarityFunction;
this.vectorDimension = dimension;
} else { } else {
assertSame("vector encoding", vectorEncoding, encoding);
assertSame("vector similarity function", vectorSimilarityFunction, similarityFunction); assertSame("vector similarity function", vectorSimilarityFunction, similarityFunction);
assertSame("vector dimension", vectorDimension, dimension); assertSame("vector dimension", vectorDimension, dimension);
} }
@ -1381,6 +1393,7 @@ final class IndexingChain implements Accountable {
pointIndexDimensionCount = 0; pointIndexDimensionCount = 0;
pointNumBytes = 0; pointNumBytes = 0;
vectorDimension = 0; vectorDimension = 0;
vectorEncoding = VectorEncoding.FLOAT32;
vectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN; vectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN;
} }
@ -1391,6 +1404,7 @@ final class IndexingChain implements Accountable {
assertSame("doc values type", fi.getDocValuesType(), docValuesType); assertSame("doc values type", fi.getDocValuesType(), docValuesType);
assertSame( assertSame(
"vector similarity function", fi.getVectorSimilarityFunction(), vectorSimilarityFunction); "vector similarity function", fi.getVectorSimilarityFunction(), vectorSimilarityFunction);
assertSame("vector encoding", fi.getVectorEncoding(), vectorEncoding);
assertSame("vector dimension", fi.getVectorDimension(), vectorDimension); assertSame("vector dimension", fi.getVectorDimension(), vectorDimension);
assertSame("point dimension", fi.getPointDimensionCount(), pointDimensionCount); assertSame("point dimension", fi.getPointDimensionCount(), pointDimensionCount);
assertSame( assertSame(

View File

@ -17,6 +17,7 @@
package org.apache.lucene.index; package org.apache.lucene.index;
import java.io.IOException; import java.io.IOException;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.TotalHits;
@ -235,6 +236,30 @@ public abstract class LeafReader extends IndexReader {
public abstract TopDocs searchNearestVectors( public abstract TopDocs searchNearestVectors(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException; String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException;
/**
* Return the k nearest neighbor documents as determined by comparison of their vector values for
* this field, to the given vector, by the field's similarity function. The score of each document
* is derived from the vector similarity in a way that ensures scores are positive and that a
* larger score corresponds to a higher ranking.
*
* <p>The search is exact, meaning the results are guaranteed to be the true k closest neighbors.
* This typically requires an exhaustive scan of all candidate documents.
*
* <p>The returned {@link TopDocs} will contain a {@link ScoreDoc} for each nearest neighbor,
* sorted in order of their similarity to the query vector (decreasing scores). The {@link
* TotalHits} contains the number of documents visited during the search.
*
* @param field the vector field to search
* @param target the vector-valued query
* @param k the number of docs to return
* @param acceptDocs {@link DocIdSetIterator} that represents the allowed documents to match, or
* {@code null} if they are all allowed to match.
* @return the k nearest neighbor documents, along with their (similarity-specific) scores.
* @lucene.experimental
*/
public abstract TopDocs searchNearestVectorsExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException;
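As a usage illustration (editorial sketch, not from the patch), the new exact-search entry point can be driven directly against a leaf; the field name and query vector below are placeholders, and passing null for acceptDocs admits every document:

import java.io.IOException;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;

static void printExactNeighbors(LeafReader leaf) throws IOException {
  float[] query = {0.1f, 0.2f, 0.3f}; // placeholder query vector
  TopDocs top = leaf.searchNearestVectorsExhaustively("vector_field", query, 10, null);
  for (ScoreDoc sd : top.scoreDocs) { // sorted by decreasing similarity score
    System.out.println("doc=" + sd.doc + " score=" + sd.score);
  }
}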
/** /**
* Get the {@link FieldInfos} describing all fields in this reader. * Get the {@link FieldInfos} describing all fields in this reader.
* *

View File

@ -26,6 +26,7 @@ import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.SortedMap; import java.util.SortedMap;
import java.util.TreeMap; import java.util.TreeMap;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Sort; import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits; import org.apache.lucene.util.Bits;
@ -403,6 +404,16 @@ public class ParallelLeafReader extends LeafReader {
: reader.searchNearestVectors(fieldName, target, k, acceptDocs, visitedLimit); : reader.searchNearestVectors(fieldName, target, k, acceptDocs, visitedLimit);
} }
@Override
public TopDocs searchNearestVectorsExhaustively(
String fieldName, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
ensureOpen();
LeafReader reader = fieldToReader.get(fieldName);
return reader == null
? null
: reader.searchNearestVectorsExhaustively(fieldName, target, k, acceptDocs);
}
@Override @Override
public void checkIntegrity() throws IOException { public void checkIntegrity() throws IOException {
ensureOpen(); ensureOpen();

View File

@ -722,6 +722,7 @@ final class ReadersAndUpdates {
fi.getPointIndexDimensionCount(), fi.getPointIndexDimensionCount(),
fi.getPointNumBytes(), fi.getPointNumBytes(),
fi.getVectorDimension(), fi.getVectorDimension(),
fi.getVectorEncoding(),
fi.getVectorSimilarityFunction(), fi.getVectorSimilarityFunction(),
fi.isSoftDeletesField()); fi.isSoftDeletesField());
} }

View File

@ -27,6 +27,7 @@ import org.apache.lucene.codecs.NormsProducer;
import org.apache.lucene.codecs.PointsReader; import org.apache.lucene.codecs.PointsReader;
import org.apache.lucene.codecs.StoredFieldsReader; import org.apache.lucene.codecs.StoredFieldsReader;
import org.apache.lucene.codecs.TermVectorsReader; import org.apache.lucene.codecs.TermVectorsReader;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits; import org.apache.lucene.util.Bits;
@ -172,6 +173,12 @@ public final class SlowCodecReaderWrapper {
return reader.searchNearestVectors(field, target, k, acceptDocs, visitedLimit); return reader.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
} }
@Override
public TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
return reader.searchNearestVectorsExhaustively(field, target, k, acceptDocs);
}
@Override @Override
public void checkIntegrity() { public void checkIntegrity() {
// We already checkIntegrity the entire reader up front // We already checkIntegrity the entire reader up front

View File

@ -31,6 +31,7 @@ import org.apache.lucene.codecs.NormsProducer;
import org.apache.lucene.codecs.PointsReader; import org.apache.lucene.codecs.PointsReader;
import org.apache.lucene.codecs.StoredFieldsReader; import org.apache.lucene.codecs.StoredFieldsReader;
import org.apache.lucene.codecs.TermVectorsReader; import org.apache.lucene.codecs.TermVectorsReader;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Sort; import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField; import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
@ -389,6 +390,12 @@ public final class SortingCodecReader extends FilterCodecReader {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
@Override
public TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
throw new UnsupportedOperationException();
}
@Override @Override
public void close() throws IOException { public void close() throws IOException {
delegate.close(); delegate.close();

View File

@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.index;
/** The numeric datatype of the vector values. */
public enum VectorEncoding {
/**
 * Encodes a vector using 8 bits of precision per sample. Use only with DOT_PRODUCT similarity.
 * NOTE: this can enable significant storage savings and faster searches, at the cost of some
 * possible loss of precision. To use it, all vectors must have the same norm, as measured by the
 * sum of the squares of the scalar values, and those values must be in the range [-128, 127].
 * This applies to both document and query vectors. Using nonconforming vectors can result in
 * errors or poor search results.
 */
BYTE(1),
/** Encodes a vector using 32 bits of precision per sample, in IEEE 754 floating-point format. */
FLOAT32(4);
/**
 * The number of bytes required to encode a scalar in this format; a vector of {@code dimension}
 * values requires {@code dimension * byteSize} bytes.
 */
public final int byteSize;
VectorEncoding(int byteSize) {
this.byteSize = byteSize;
}
}
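To make the BYTE contract concrete, here is one way a caller could prepare conforming vectors (an editorial sketch, not part of the patch; the scale factor 127 is an assumption that keeps l2-normalized values inside the signed-byte range while giving every vector the same norm):

import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.VectorUtil;

static BytesRef quantizeForByteEncoding(float[] vector) {
  VectorUtil.l2normalize(vector); // all vectors now share the same (unit) norm
  for (int i = 0; i < vector.length; i++) {
    vector[i] *= 127f; // scale into [-127, 127], within the required [-128, 127]
  }
  return VectorUtil.toBytesRef(vector); // toBytesRef is added to VectorUtil by this commit
}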

View File

@ -18,6 +18,8 @@ package org.apache.lucene.index;
import static org.apache.lucene.util.VectorUtil.*; import static org.apache.lucene.util.VectorUtil.*;
import org.apache.lucene.util.BytesRef;
/** /**
* Vector similarity function; used in search to return top K most similar vectors to a target * Vector similarity function; used in search to return top K most similar vectors to a target
* vector. This is a label describing the method used during indexing and searching of the vectors * vector. This is a label describing the method used during indexing and searching of the vectors
@ -31,6 +33,11 @@ public enum VectorSimilarityFunction {
public float compare(float[] v1, float[] v2) { public float compare(float[] v1, float[] v2) {
return 1 / (1 + squareDistance(v1, v2)); return 1 / (1 + squareDistance(v1, v2));
} }
@Override
public float compare(BytesRef v1, BytesRef v2) {
return 1 / (1 + squareDistance(v1, v2));
}
}, },
/** /**
@ -44,6 +51,11 @@ public enum VectorSimilarityFunction {
public float compare(float[] v1, float[] v2) { public float compare(float[] v1, float[] v2) {
return (1 + dotProduct(v1, v2)) / 2; return (1 + dotProduct(v1, v2)) / 2;
} }
@Override
public float compare(BytesRef v1, BytesRef v2) {
return dotProductScore(v1, v2);
}
}, },
/** /**
@ -57,6 +69,11 @@ public enum VectorSimilarityFunction {
public float compare(float[] v1, float[] v2) { public float compare(float[] v1, float[] v2) {
return (1 + cosine(v1, v2)) / 2; return (1 + cosine(v1, v2)) / 2;
} }
@Override
public float compare(BytesRef v1, BytesRef v2) {
return (1 + cosine(v1, v2)) / 2;
}
}; };
/** /**
@ -68,4 +85,15 @@ public enum VectorSimilarityFunction {
* @return the value of the similarity function applied to the two vectors * @return the value of the similarity function applied to the two vectors
*/ */
public abstract float compare(float[] v1, float[] v2); public abstract float compare(float[] v1, float[] v2);
/**
* Calculates a similarity score between the two vectors with a specified function. Higher
* similarity scores correspond to closer vectors. The offsets and lengths of the BytesRefs
* determine the vector data that is compared. Each (signed) byte represents a vector dimension.
*
* @param v1 a vector
* @param v2 another vector, of the same dimension
* @return the value of the similarity function applied to the two vectors
*/
public abstract float compare(BytesRef v1, BytesRef v2);
} }

View File

@ -65,7 +65,7 @@ class VectorValuesConsumer {
} }
} }
public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
initKnnVectorsWriter(fieldInfo.name); initKnnVectorsWriter(fieldInfo.name);
return writer.addField(fieldInfo); return writer.addField(fieldInfo);
} }

View File

@ -24,11 +24,8 @@ import java.util.Comparator;
import java.util.Objects; import java.util.Objects;
import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.document.KnnVectorField; import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.util.BitSet; import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits; import org.apache.lucene.util.Bits;
@ -133,22 +130,21 @@ public class KnnVectorQuery extends Query {
return NO_RESULTS; return NO_RESULTS;
} }
BitSet bitSet = createBitSet(scorer.iterator(), liveDocs, maxDoc); BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, maxDoc);
BitSetIterator filterIterator = new BitSetIterator(bitSet, bitSet.cardinality());
if (filterIterator.cost() <= k) { if (acceptDocs.cardinality() <= k) {
// If there are <= k possible matches, short-circuit and perform exact search, since HNSW // If there are <= k possible matches, short-circuit and perform exact search, since HNSW
// must always visit at least k documents // must always visit at least k documents
return exactSearch(ctx, filterIterator); return exactSearch(ctx, new BitSetIterator(acceptDocs, acceptDocs.cardinality()));
} }
// Perform the approximate kNN search // Perform the approximate kNN search
TopDocs results = approximateSearch(ctx, bitSet, (int) filterIterator.cost()); TopDocs results = approximateSearch(ctx, acceptDocs, acceptDocs.cardinality());
if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO) { if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO) {
return results; return results;
} else { } else {
// We stopped the kNN search because it visited too many nodes, so fall back to exact search // We stopped the kNN search because it visited too many nodes, so fall back to exact search
return exactSearch(ctx, filterIterator); return exactSearch(ctx, new BitSetIterator(acceptDocs, acceptDocs.cardinality()));
} }
} }
@ -178,45 +174,9 @@ public class KnnVectorQuery extends Query {
} }
// We allow this to be overridden so that tests can check what search strategy is used // We allow this to be overridden so that tests can check what search strategy is used
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator) protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptDocs)
throws IOException { throws IOException {
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field); return context.reader().searchNearestVectorsExhaustively(field, target, k, acceptDocs);
if (fi == null || fi.getVectorDimension() == 0) {
// The field does not exist or does not index vectors
return NO_RESULTS;
}
VectorSimilarityFunction similarityFunction = fi.getVectorSimilarityFunction();
VectorValues vectorValues = context.reader().getVectorValues(field);
HitQueue queue = new HitQueue(k, true);
ScoreDoc topDoc = queue.top();
int doc;
while ((doc = acceptIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
int vectorDoc = vectorValues.advance(doc);
assert vectorDoc == doc;
float[] vector = vectorValues.vectorValue();
float score = similarityFunction.compare(vector, target);
if (score >= topDoc.score) {
topDoc.score = score;
topDoc.doc = doc;
topDoc = queue.updateTop();
}
}
// Remove any remaining sentinel values
while (queue.size() > 0 && queue.top().score < 0) {
queue.pop();
}
ScoreDoc[] topScoreDocs = new ScoreDoc[queue.size()];
for (int i = topScoreDocs.length - 1; i >= 0; i--) {
topScoreDocs[i] = queue.pop();
}
TotalHits totalHits = new TotalHits(acceptIterator.cost(), TotalHits.Relation.EQUAL_TO);
return new TopDocs(totalHits, topScoreDocs);
} }
private Query createRewrittenQuery(IndexReader reader, TopDocs topK) { private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {

View File

@ -121,6 +121,24 @@ public final class VectorUtil {
return (float) (sum / Math.sqrt(norm1 * norm2)); return (float) (sum / Math.sqrt(norm1 * norm2));
} }
/** Returns the cosine similarity between the two vectors. */
public static float cosine(BytesRef a, BytesRef b) {
// Note: this will not overflow if dim < 2^17, since max(byte * byte) = 2^14 and the sums
// accumulate in 32-bit ints.
int sum = 0;
int norm1 = 0;
int norm2 = 0;
int aOffset = a.offset, bOffset = b.offset;
for (int i = 0; i < a.length; i++) {
byte elem1 = a.bytes[aOffset++];
byte elem2 = b.bytes[bOffset++];
sum += elem1 * elem2;
norm1 += elem1 * elem1;
norm2 += elem2 * elem2;
}
return (float) (sum / Math.sqrt((double) norm1 * (double) norm2));
}
/** /**
* Returns the sum of squared differences of the two vectors. * Returns the sum of squared differences of the two vectors.
* *
@ -135,7 +153,7 @@ public final class VectorUtil {
int dim = v1.length; int dim = v1.length;
int i; int i;
for (i = 0; i + 8 <= dim; i += 8) { for (i = 0; i + 8 <= dim; i += 8) {
squareSum += squareDistanceUnrolled8(v1, v2, i); squareSum += squareDistanceUnrolled(v1, v2, i);
} }
for (; i < dim; i++) { for (; i < dim; i++) {
float diff = v1[i] - v2[i]; float diff = v1[i] - v2[i];
@ -144,7 +162,7 @@ public final class VectorUtil {
return squareSum; return squareSum;
} }
private static float squareDistanceUnrolled8(float[] v1, float[] v2, int index) { private static float squareDistanceUnrolled(float[] v1, float[] v2, int index) {
float diff0 = v1[index + 0] - v2[index + 0]; float diff0 = v1[index + 0] - v2[index + 0];
float diff1 = v1[index + 1] - v2[index + 1]; float diff1 = v1[index + 1] - v2[index + 1];
float diff2 = v1[index + 2] - v2[index + 2]; float diff2 = v1[index + 2] - v2[index + 2];
@ -163,6 +181,18 @@ public final class VectorUtil {
+ diff7 * diff7; + diff7 * diff7;
} }
/** Returns the sum of squared differences of the two vectors. */
public static float squareDistance(BytesRef a, BytesRef b) {
// Note: this will not overflow if dim < 2^15, since the squared difference of two signed bytes
// is at most 255 * 255 < 2^16.
int squareSum = 0;
int aOffset = a.offset, bOffset = b.offset;
for (int i = 0; i < a.length; i++) {
int diff = a.bytes[aOffset++] - b.bytes[bOffset++];
squareSum += diff * diff;
}
return squareSum;
}
/** /**
* Modifies the argument to be unit length, dividing by its l2-norm. IllegalArgumentException is * Modifies the argument to be unit length, dividing by its l2-norm. IllegalArgumentException is
* thrown for zero vectors. * thrown for zero vectors.
@ -213,4 +243,48 @@ public final class VectorUtil {
u[i] += v[i]; u[i] += v[i];
} }
} }
/**
* Dot product computed over signed bytes.
*
* @param a bytes containing a vector
* @param b bytes containing another vector, of the same dimension
* @return the value of the dot product of the two vectors
*/
public static float dotProduct(BytesRef a, BytesRef b) {
assert a.length == b.length;
int total = 0;
int aOffset = a.offset, bOffset = b.offset;
for (int i = 0; i < a.length; i++) {
total += a.bytes[aOffset++] * b.bytes[bOffset++];
}
return total;
}
/**
* Dot product score computed over signed bytes, scaled to be in [0, 1].
*
* @param a bytes containing a vector
* @param b bytes containing another vector, of the same dimension
* @return the value of the similarity function applied to the two vectors
*/
public static float dotProductScore(BytesRef a, BytesRef b) {
// dotProduct(a, b) lies within [-len * 2^14, len * 2^14], so dividing by len * 2^15 maps it
// into [-0.5, 0.5]; shifting by 0.5 keeps the score inside the documented [0, 1] range
float denom = (float) (a.length * (1 << 15));
return 0.5f + dotProduct(a, b) / denom;
}
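A quick numeric check of that scaling (editorial; it uses the 0.5f shift in the body above):

// For two parallel 2-d vectors of all 127s:
//   dotProduct = 2 * 127 * 127 = 32258, denom = 2 * (1 << 15) = 65536
//   score = 0.5 + 32258 / 65536 ≈ 0.9922
// For the anti-parallel pair (all 127s vs. all -127s):
//   score = 0.5 - 32258 / 65536 ≈ 0.0078
// Both results stay inside the documented [0, 1] range.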
/**
* Convert a floating point vector to an array of bytes using casting; the vector values should be
* in [-128,127]
*
* @param vector a vector
* @return a new BytesRef containing the vector's values cast to byte.
*/
public static BytesRef toBytesRef(float[] vector) {
BytesRef b = new BytesRef(new byte[vector.length]);
for (int i = 0; i < vector.length; i++) {
b.bytes[i] = (byte) vector[i];
}
return b;
}
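One caveat worth illustrating (editorial note): the cast truncates toward zero, so unit-scale floats must be scaled up before conversion:

// (byte) 0.9f == 0, so an un-scaled unit vector collapses to all zeros:
BytesRef zeros = VectorUtil.toBytesRef(new float[] {0.9f, -0.9f}); // bytes {0, 0}
// scale into the signed-byte range first:
BytesRef ok = VectorUtil.toBytesRef(new float[] {0.9f * 127, -0.9f * 127}); // bytes {114, -114}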
} }

View File

@ -25,15 +25,19 @@ import java.util.Objects;
import java.util.SplittableRandom; import java.util.SplittableRandom;
import org.apache.lucene.index.RandomAccessVectorValues; import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer; import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.InfoStream; import org.apache.lucene.util.InfoStream;
/** /**
* Builder for HNSW graph. See {@link HnswGraph} for a gloss on the algorithm and the meaning of the * Builder for HNSW graph. See {@link HnswGraph} for a gloss on the algorithm and the meaning of the
* hyperparameters. * hyperparameters.
*
* @param <T> the type of vector
*/ */
public final class HnswGraphBuilder { public final class HnswGraphBuilder<T> {
/** Default random seed for level generation */ /** Default random seed for level generation */
private static final long DEFAULT_RAND_SEED = 42; private static final long DEFAULT_RAND_SEED = 42;
@ -49,9 +53,10 @@ public final class HnswGraphBuilder {
private final NeighborArray scratch; private final NeighborArray scratch;
private final VectorSimilarityFunction similarityFunction; private final VectorSimilarityFunction similarityFunction;
private final VectorEncoding vectorEncoding;
private final RandomAccessVectorValues vectorValues; private final RandomAccessVectorValues vectorValues;
private final SplittableRandom random; private final SplittableRandom random;
private final HnswGraphSearcher graphSearcher; private final HnswGraphSearcher<T> graphSearcher;
final OnHeapHnswGraph hnsw; final OnHeapHnswGraph hnsw;
@ -59,7 +64,18 @@ public final class HnswGraphBuilder {
// we need two sources of vectors in order to perform diversity check comparisons without // we need two sources of vectors in order to perform diversity check comparisons without
// colliding // colliding
private RandomAccessVectorValues buildVectors; private final RandomAccessVectorValues buildVectors;
public static HnswGraphBuilder<?> create(
RandomAccessVectorValuesProducer vectors,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction,
int M,
int beamWidth,
long seed)
throws IOException {
return new HnswGraphBuilder<>(vectors, vectorEncoding, similarityFunction, M, beamWidth, seed);
}
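For orientation, a caller might wire the factory up as follows (editorial sketch; vectorProducer stands in for any RandomAccessVectorValuesProducer over the field's vectors):

HnswGraphBuilder<?> builder =
    HnswGraphBuilder.create(
        vectorProducer,
        VectorEncoding.FLOAT32,
        VectorSimilarityFunction.DOT_PRODUCT,
        16,    // M: max connections per graph node
        100,   // beamWidth: size of the candidate queue during construction
        42L);  // fixed seed, for repeatable graph construction
// build() requires an accessor independent of the builder's own two accessors
OnHeapHnswGraph graph = builder.build(vectorProducer.randomAccess());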
/** /**
* Reads all the vectors from a VectorValues, builds a graph connecting them by their dense * Reads all the vectors from a VectorValues, builds a graph connecting them by their dense
@ -73,8 +89,9 @@ public final class HnswGraphBuilder {
* @param seed the seed for a random number generator used during graph construction. Provide this * @param seed the seed for a random number generator used during graph construction. Provide this
* to ensure repeatable construction. * to ensure repeatable construction.
*/ */
public HnswGraphBuilder( private HnswGraphBuilder(
RandomAccessVectorValuesProducer vectors, RandomAccessVectorValuesProducer vectors,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction, VectorSimilarityFunction similarityFunction,
int M, int M,
int beamWidth, int beamWidth,
@ -82,6 +99,7 @@ public final class HnswGraphBuilder {
throws IOException { throws IOException {
vectorValues = vectors.randomAccess(); vectorValues = vectors.randomAccess();
buildVectors = vectors.randomAccess(); buildVectors = vectors.randomAccess();
this.vectorEncoding = Objects.requireNonNull(vectorEncoding);
this.similarityFunction = Objects.requireNonNull(similarityFunction); this.similarityFunction = Objects.requireNonNull(similarityFunction);
if (M <= 0) { if (M <= 0) {
throw new IllegalArgumentException("maxConn must be positive"); throw new IllegalArgumentException("maxConn must be positive");
@ -97,7 +115,8 @@ public final class HnswGraphBuilder {
int levelOfFirstNode = getRandomGraphLevel(ml, random); int levelOfFirstNode = getRandomGraphLevel(ml, random);
this.hnsw = new OnHeapHnswGraph(M, levelOfFirstNode); this.hnsw = new OnHeapHnswGraph(M, levelOfFirstNode);
this.graphSearcher = this.graphSearcher =
new HnswGraphSearcher( new HnswGraphSearcher<>(
vectorEncoding,
similarityFunction, similarityFunction,
new NeighborQueue(beamWidth, true), new NeighborQueue(beamWidth, true),
new FixedBitSet(vectorValues.size())); new FixedBitSet(vectorValues.size()));
@ -110,7 +129,7 @@ public final class HnswGraphBuilder {
* enables efficient retrieval without extra data copying, while avoiding collision of the * enables efficient retrieval without extra data copying, while avoiding collision of the
* returned values. * returned values.
* *
* @param vectors the vectors for which to build a nearest neighbors graph. Must be an independet * @param vectors the vectors for which to build a nearest neighbors graph. Must be an independent
* accessor for the vectors * accessor for the vectors
*/ */
public OnHeapHnswGraph build(RandomAccessVectorValues vectors) throws IOException { public OnHeapHnswGraph build(RandomAccessVectorValues vectors) throws IOException {
@ -121,15 +140,19 @@ public final class HnswGraphBuilder {
if (infoStream.isEnabled(HNSW_COMPONENT)) { if (infoStream.isEnabled(HNSW_COMPONENT)) {
infoStream.message(HNSW_COMPONENT, "build graph from " + vectors.size() + " vectors"); infoStream.message(HNSW_COMPONENT, "build graph from " + vectors.size() + " vectors");
} }
addVectors(vectors);
return hnsw;
}
private void addVectors(RandomAccessVectorValues vectors) throws IOException {
long start = System.nanoTime(), t = start; long start = System.nanoTime(), t = start;
// start at node 1! node 0 is added implicitly, in the constructor // start at node 1! node 0 is added implicitly, in the constructor
for (int node = 1; node < vectors.size(); node++) { for (int node = 1; node < vectors.size(); node++) {
addGraphNode(node, vectors.vectorValue(node)); addGraphNode(node, vectors);
if ((node % 10000 == 0) && infoStream.isEnabled(HNSW_COMPONENT)) { if ((node % 10000 == 0) && infoStream.isEnabled(HNSW_COMPONENT)) {
t = printGraphBuildStatus(node, start, t); t = printGraphBuildStatus(node, start, t);
} }
} }
return hnsw;
} }
/** Set info-stream to output debugging information */ /** Set info-stream to output debugging information */
@ -142,7 +165,7 @@ public final class HnswGraphBuilder {
} }
/** Inserts a doc with vector value to the graph */ /** Inserts a doc with vector value to the graph */
public void addGraphNode(int node, float[] value) throws IOException { public void addGraphNode(int node, T value) throws IOException {
NeighborQueue candidates; NeighborQueue candidates;
final int nodeLevel = getRandomGraphLevel(ml, random); final int nodeLevel = getRandomGraphLevel(ml, random);
int curMaxLevel = hnsw.numLevels() - 1; int curMaxLevel = hnsw.numLevels() - 1;
@ -167,6 +190,18 @@ public final class HnswGraphBuilder {
} }
} }
public void addGraphNode(int node, RandomAccessVectorValues values) throws IOException {
addGraphNode(node, getValue(node, values));
}
@SuppressWarnings("unchecked")
private T getValue(int node, RandomAccessVectorValues values) throws IOException {
return switch (vectorEncoding) {
case BYTE -> (T) values.binaryValue(node);
case FLOAT32 -> (T) values.vectorValue(node);
};
}
private long printGraphBuildStatus(int node, long start, long t) { private long printGraphBuildStatus(int node, long start, long t) {
long now = System.nanoTime(); long now = System.nanoTime();
infoStream.message( infoStream.message(
@ -215,7 +250,7 @@ public final class HnswGraphBuilder {
int cNode = candidates.node[i]; int cNode = candidates.node[i];
float cScore = candidates.score[i]; float cScore = candidates.score[i];
assert cNode < hnsw.size(); assert cNode < hnsw.size();
if (diversityCheck(vectorValues.vectorValue(cNode), cScore, neighbors, buildVectors)) { if (diversityCheck(cNode, cScore, neighbors)) {
neighbors.add(cNode, cScore); neighbors.add(cNode, cScore);
} }
} }
@ -237,19 +272,38 @@ public final class HnswGraphBuilder {
* @param score the score of the new candidate and node n, to be compared with scores of the * @param score the score of the new candidate and node n, to be compared with scores of the
* candidate and n's neighbors * candidate and n's neighbors
* @param neighbors the neighbors selected so far * @param neighbors the neighbors selected so far
* @param vectorValues source of values used for making comparisons between candidate and existing
* neighbors
* @return whether the candidate is diverse given the existing neighbors * @return whether the candidate is diverse given the existing neighbors
*/ */
private boolean diversityCheck( private boolean diversityCheck(int candidate, float score, NeighborArray neighbors)
float[] candidate, throws IOException {
float score, return isDiverse(candidate, neighbors, score);
NeighborArray neighbors, }
RandomAccessVectorValues vectorValues)
private boolean isDiverse(int candidate, NeighborArray neighbors, float score)
throws IOException {
return switch (vectorEncoding) {
case BYTE -> isDiverse(vectorValues.binaryValue(candidate), neighbors, score);
case FLOAT32 -> isDiverse(vectorValues.vectorValue(candidate), neighbors, score);
};
}
private boolean isDiverse(float[] candidate, NeighborArray neighbors, float score)
throws IOException { throws IOException {
for (int i = 0; i < neighbors.size(); i++) { for (int i = 0; i < neighbors.size(); i++) {
float neighborSimilarity = float neighborSimilarity =
similarityFunction.compare(candidate, vectorValues.vectorValue(neighbors.node[i])); similarityFunction.compare(candidate, buildVectors.vectorValue(neighbors.node[i]));
if (neighborSimilarity >= score) {
return false;
}
}
return true;
}
private boolean isDiverse(BytesRef candidate, NeighborArray neighbors, float score)
throws IOException {
for (int i = 0; i < neighbors.size(); i++) {
float neighborSimilarity =
similarityFunction.compare(candidate, buildVectors.binaryValue(neighbors.node[i]));
if (neighborSimilarity >= score) { if (neighborSimilarity >= score) {
return false; return false;
} }
@ -262,24 +316,52 @@ public final class HnswGraphBuilder {
* neighbours * neighbours
*/ */
private int findWorstNonDiverse(NeighborArray neighbors) throws IOException { private int findWorstNonDiverse(NeighborArray neighbors) throws IOException {
float minAcceptedSimilarity;
for (int i = neighbors.size() - 1; i > 0; i--) { for (int i = neighbors.size() - 1; i > 0; i--) {
int cNode = neighbors.node[i]; if (isWorstNonDiverse(i, neighbors, neighbors.score[i])) {
float[] cVector = vectorValues.vectorValue(cNode); return i;
minAcceptedSimilarity = neighbors.score[i];
// check the candidate against its better-scoring neighbors
for (int j = i - 1; j >= 0; j--) {
float neighborSimilarity =
similarityFunction.compare(cVector, buildVectors.vectorValue(neighbors.node[j]));
// node i is too similar to node j given its score relative to the base node
if (neighborSimilarity >= minAcceptedSimilarity) {
return i;
}
} }
} }
return neighbors.size() - 1; return neighbors.size() - 1;
} }
private boolean isWorstNonDiverse(
    int candidateIndex, NeighborArray neighbors, float minAcceptedSimilarity) throws IOException {
  // look up the graph node behind this neighbor slot; candidateIndex is an index into the
  // NeighborArray, not a node ordinal
  int candidateNode = neighbors.node[candidateIndex];
  return switch (vectorEncoding) {
    case BYTE -> isWorstNonDiverse(
        candidateIndex, vectorValues.binaryValue(candidateNode), neighbors, minAcceptedSimilarity);
    case FLOAT32 -> isWorstNonDiverse(
        candidateIndex, vectorValues.vectorValue(candidateNode), neighbors, minAcceptedSimilarity);
  };
}
private boolean isWorstNonDiverse(
    int candidateIndex, float[] candidate, NeighborArray neighbors, float minAcceptedSimilarity)
    throws IOException {
  for (int i = candidateIndex - 1; i >= 0; i--) {
    float neighborSimilarity =
        similarityFunction.compare(candidate, buildVectors.vectorValue(neighbors.node[i]));
    // the candidate is too similar to neighbor i, given its score relative to the base node
    if (neighborSimilarity >= minAcceptedSimilarity) {
      return true;
    }
  }
  return false;
}
private boolean isWorstNonDiverse(
    int candidateIndex, BytesRef candidate, NeighborArray neighbors, float minAcceptedSimilarity)
    throws IOException {
  for (int i = candidateIndex - 1; i >= 0; i--) {
    float neighborSimilarity =
        similarityFunction.compare(candidate, buildVectors.binaryValue(neighbors.node[i]));
    // the candidate is too similar to neighbor i, given its score relative to the base node
    if (neighborSimilarity >= minAcceptedSimilarity) {
      return true;
    }
  }
  return false;
}
private static int getRandomGraphLevel(double ml, SplittableRandom random) { private static int getRandomGraphLevel(double ml, SplittableRandom random) {
double randDouble; double randDouble;
do { do {

View File

@ -18,21 +18,28 @@
package org.apache.lucene.util.hnsw; package org.apache.lucene.util.hnsw;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.util.VectorUtil.toBytesRef;
import java.io.IOException; import java.io.IOException;
import org.apache.lucene.index.RandomAccessVectorValues; import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.BitSet; import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits; import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.SparseFixedBitSet; import org.apache.lucene.util.SparseFixedBitSet;
/** /**
* Searches an HNSW graph to find nearest neighbors to a query vector. For more background on the * Searches an HNSW graph to find nearest neighbors to a query vector. For more background on the
* search algorithm, see {@link HnswGraph}. * search algorithm, see {@link HnswGraph}.
*
* @param <T> the type of query vector
*/ */
public final class HnswGraphSearcher { public class HnswGraphSearcher<T> {
private final VectorSimilarityFunction similarityFunction; private final VectorSimilarityFunction similarityFunction;
private final VectorEncoding vectorEncoding;
/** /**
* Scratch data structures that are used in each {@link #searchLevel} call. These can be expensive * Scratch data structures that are used in each {@link #searchLevel} call. These can be expensive
* to allocate, so they're cleared and reused across calls. * to allocate, so they're cleared and reused across calls.
@ -49,7 +56,11 @@ public final class HnswGraphSearcher {
* @param visited bit set that will track nodes that have already been visited * @param visited bit set that will track nodes that have already been visited
*/ */
public HnswGraphSearcher( public HnswGraphSearcher(
VectorSimilarityFunction similarityFunction, NeighborQueue candidates, BitSet visited) { VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction,
NeighborQueue candidates,
BitSet visited) {
this.vectorEncoding = vectorEncoding;
this.similarityFunction = similarityFunction; this.similarityFunction = similarityFunction;
this.candidates = candidates; this.candidates = candidates;
this.visited = visited; this.visited = visited;
@ -73,13 +84,68 @@ public final class HnswGraphSearcher {
float[] query, float[] query,
int topK, int topK,
RandomAccessVectorValues vectors, RandomAccessVectorValues vectors,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction, VectorSimilarityFunction similarityFunction,
HnswGraph graph, HnswGraph graph,
Bits acceptOrds, Bits acceptOrds,
int visitedLimit) int visitedLimit)
throws IOException { throws IOException {
HnswGraphSearcher graphSearcher = if (query.length != vectors.dimension()) {
new HnswGraphSearcher( throw new IllegalArgumentException(
"vector query dimension: "
+ query.length
+ " differs from field dimension: "
+ vectors.dimension());
}
if (vectorEncoding == VectorEncoding.BYTE) {
return search(
toBytesRef(query),
topK,
vectors,
vectorEncoding,
similarityFunction,
graph,
acceptOrds,
visitedLimit);
}
HnswGraphSearcher<float[]> graphSearcher =
new HnswGraphSearcher<>(
vectorEncoding,
similarityFunction,
new NeighborQueue(topK, true),
new SparseFixedBitSet(vectors.size()));
NeighborQueue results;
int[] eps = new int[] {graph.entryNode()};
int numVisited = 0;
for (int level = graph.numLevels() - 1; level >= 1; level--) {
results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graph, null, visitedLimit);
numVisited += results.visitedCount();
visitedLimit -= results.visitedCount();
if (results.incomplete()) {
results.setVisitedCount(numVisited);
return results;
}
eps[0] = results.pop();
}
results =
graphSearcher.searchLevel(query, topK, 0, eps, vectors, graph, acceptOrds, visitedLimit);
results.setVisitedCount(results.visitedCount() + numVisited);
return results;
}
private static NeighborQueue search(
BytesRef query,
int topK,
RandomAccessVectorValues vectors,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction,
HnswGraph graph,
Bits acceptOrds,
int visitedLimit)
throws IOException {
HnswGraphSearcher<BytesRef> graphSearcher =
new HnswGraphSearcher<>(
vectorEncoding,
similarityFunction, similarityFunction,
new NeighborQueue(topK, true), new NeighborQueue(topK, true),
new SparseFixedBitSet(vectors.size())); new SparseFixedBitSet(vectors.size()));
@ -119,7 +185,8 @@ public final class HnswGraphSearcher {
* @return a priority queue holding the closest neighbors found * @return a priority queue holding the closest neighbors found
*/ */
public NeighborQueue searchLevel( public NeighborQueue searchLevel(
float[] query, // Note: this is only public because Lucene91HnswGraphBuilder needs it
T query,
int topK, int topK,
int level, int level,
final int[] eps, final int[] eps,
@ -130,7 +197,7 @@ public final class HnswGraphSearcher {
} }
private NeighborQueue searchLevel( private NeighborQueue searchLevel(
float[] query, T query,
int topK, int topK,
int level, int level,
final int[] eps, final int[] eps,
@ -150,7 +217,7 @@ public final class HnswGraphSearcher {
results.markIncomplete(); results.markIncomplete();
break; break;
} }
float score = similarityFunction.compare(query, vectors.vectorValue(ep)); float score = compare(query, vectors, ep);
numVisited++; numVisited++;
candidates.add(ep, score); candidates.add(ep, score);
if (acceptOrds == null || acceptOrds.get(ep)) { if (acceptOrds == null || acceptOrds.get(ep)) {
@ -185,7 +252,7 @@ public final class HnswGraphSearcher {
results.markIncomplete(); results.markIncomplete();
break; break;
} }
float friendSimilarity = similarityFunction.compare(query, vectors.vectorValue(friendOrd)); float friendSimilarity = compare(query, vectors, friendOrd);
numVisited++; numVisited++;
if (friendSimilarity >= minAcceptedSimilarity) { if (friendSimilarity >= minAcceptedSimilarity) {
candidates.add(friendOrd, friendSimilarity); candidates.add(friendOrd, friendSimilarity);
@ -204,6 +271,14 @@ public final class HnswGraphSearcher {
return results; return results;
} }
private float compare(T query, RandomAccessVectorValues vectors, int ord) throws IOException {
if (vectorEncoding == VectorEncoding.BYTE) {
return similarityFunction.compare((BytesRef) query, vectors.binaryValue(ord));
} else {
return similarityFunction.compare((float[]) query, vectors.vectorValue(ord));
}
}
private void prepareScratchState(int capacity) { private void prepareScratchState(int capacity) {
candidates.clear(); candidates.clear();
if (visited.length() < capacity) { if (visited.length() < capacity) {
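Tying the pieces together (editorial sketch; vectors and graph are placeholders for the field's RandomAccessVectorValues and its HnswGraph): a float[] query against a BYTE-encoded field is quantized through toBytesRef by the dispatch above, so callers use the same entry point for both encodings:

NeighborQueue results =
    HnswGraphSearcher.search(
        query,                  // float[]; routed through toBytesRef for BYTE fields
        10,                     // topK
        vectors,
        VectorEncoding.BYTE,
        VectorSimilarityFunction.DOT_PRODUCT,
        graph,
        null,                   // acceptOrds: null admits every ordinal
        Integer.MAX_VALUE);     // visitedLimit: effectively unbounded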

View File

@ -178,7 +178,7 @@ public class TestPerFieldKnnVectorsFormat extends BaseKnnVectorsFormatTestCase {
return new KnnVectorsWriter() { return new KnnVectorsWriter() {
@Override @Override
public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
fieldsWritten.add(fieldInfo.name); fieldsWritten.add(fieldInfo.name);
return writer.addField(fieldInfo); return writer.addField(fieldInfo);
} }

View File

@ -112,6 +112,7 @@ public class TestCodecs extends LuceneTestCase {
0, 0,
0, 0,
0, 0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN, VectorSimilarityFunction.EUCLIDEAN,
false)); false));
} }

View File

@ -260,6 +260,7 @@ public class TestFieldInfos extends LuceneTestCase {
0, 0,
0, 0,
0, 0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN, VectorSimilarityFunction.EUCLIDEAN,
false)); false));
} }
@ -279,6 +280,7 @@ public class TestFieldInfos extends LuceneTestCase {
0, 0,
0, 0,
0, 0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN, VectorSimilarityFunction.EUCLIDEAN,
false)); false));
assertEquals("Field numbers 0 through 9 were allocated", 10, idx); assertEquals("Field numbers 0 through 9 were allocated", 10, idx);
@ -300,6 +302,7 @@ public class TestFieldInfos extends LuceneTestCase {
0, 0,
0, 0,
0, 0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN, VectorSimilarityFunction.EUCLIDEAN,
false)); false));
assertEquals("Field numbers should reset after clear()", 0, idx); assertEquals("Field numbers should reset after clear()", 0, idx);

View File

@ -64,6 +64,7 @@ public class TestFieldsReader extends LuceneTestCase {
0, 0,
0, 0,
0, 0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN, VectorSimilarityFunction.EUCLIDEAN,
field.name().equals(softDeletesFieldName))); field.name().equals(softDeletesFieldName)));
} }

View File

@ -113,6 +113,11 @@ public class TestIndexableField extends LuceneTestCase {
return 0; return 0;
} }
@Override
public VectorEncoding vectorEncoding() {
return VectorEncoding.FLOAT32;
}
@Override @Override
public VectorSimilarityFunction vectorSimilarityFunction() { public VectorSimilarityFunction vectorSimilarityFunction() {
return VectorSimilarityFunction.EUCLIDEAN; return VectorSimilarityFunction.EUCLIDEAN;

View File

@ -67,6 +67,8 @@ public class TestKnnGraph extends LuceneTestCase {
private static int M = Lucene94HnswVectorsFormat.DEFAULT_MAX_CONN; private static int M = Lucene94HnswVectorsFormat.DEFAULT_MAX_CONN;
private Codec codec; private Codec codec;
private Codec float32Codec;
private VectorEncoding vectorEncoding;
private VectorSimilarityFunction similarityFunction; private VectorSimilarityFunction similarityFunction;
@Before @Before
@ -86,6 +88,31 @@ public class TestKnnGraph extends LuceneTestCase {
int similarity = random().nextInt(VectorSimilarityFunction.values().length - 1) + 1; int similarity = random().nextInt(VectorSimilarityFunction.values().length - 1) + 1;
similarityFunction = VectorSimilarityFunction.values()[similarity]; similarityFunction = VectorSimilarityFunction.values()[similarity];
vectorEncoding = randomVectorEncoding();
codec =
new Lucene94Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new Lucene94HnswVectorsFormat(M, Lucene94HnswVectorsFormat.DEFAULT_BEAM_WIDTH);
}
};
if (vectorEncoding == VectorEncoding.FLOAT32) {
float32Codec = codec;
} else {
float32Codec =
new Lucene94Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new Lucene94HnswVectorsFormat(M, Lucene94HnswVectorsFormat.DEFAULT_BEAM_WIDTH);
}
};
}
}
private VectorEncoding randomVectorEncoding() {
return VectorEncoding.values()[random().nextInt(VectorEncoding.values().length)];
} }
@After @After
@ -102,10 +129,7 @@ public class TestKnnGraph extends LuceneTestCase {
float[][] values = new float[numDoc][]; float[][] values = new float[numDoc][];
for (int i = 0; i < numDoc; i++) { for (int i = 0; i < numDoc; i++) {
if (random().nextBoolean()) { if (random().nextBoolean()) {
values[i] = new float[dimension]; values[i] = randomVector(dimension);
for (int j = 0; j < dimension; j++) {
values[i][j] = random().nextFloat();
}
} }
add(iw, i, values[i]); add(iw, i, values[i]);
} }
@ -117,6 +141,14 @@ public class TestKnnGraph extends LuceneTestCase {
try (Directory dir = newDirectory(); try (Directory dir = newDirectory();
IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig(null).setCodec(codec))) { IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig(null).setCodec(codec))) {
float[][] values = new float[][] {new float[] {0, 1, 2}}; float[][] values = new float[][] {new float[] {0, 1, 2}};
if (similarityFunction == VectorSimilarityFunction.DOT_PRODUCT) {
VectorUtil.l2normalize(values[0]);
}
if (vectorEncoding == VectorEncoding.BYTE) {
for (int i = 0; i < 3; i++) {
values[0][i] = (float) Math.floor(values[0][i] * 127);
}
}
add(iw, 0, values[0]); add(iw, 0, values[0]);
assertConsistentGraph(iw, values); assertConsistentGraph(iw, values);
iw.commit(); iw.commit();
@ -133,11 +165,7 @@ public class TestKnnGraph extends LuceneTestCase {
float[][] values = randomVectors(numDoc, dimension); float[][] values = randomVectors(numDoc, dimension);
for (int i = 0; i < numDoc; i++) { for (int i = 0; i < numDoc; i++) {
if (random().nextBoolean()) { if (random().nextBoolean()) {
values[i] = new float[dimension]; values[i] = randomVector(dimension);
for (int j = 0; j < dimension; j++) {
values[i][j] = random().nextFloat();
}
VectorUtil.l2normalize(values[i]);
} }
add(iw, i, values[i]); add(iw, i, values[i]);
if (random().nextInt(10) == 3) { if (random().nextInt(10) == 3) {
@ -249,16 +277,26 @@ public class TestKnnGraph extends LuceneTestCase {
float[][] values = new float[numDoc][]; float[][] values = new float[numDoc][];
for (int i = 0; i < numDoc; i++) { for (int i = 0; i < numDoc; i++) {
if (random().nextBoolean()) { if (random().nextBoolean()) {
values[i] = new float[dimension]; values[i] = randomVector(dimension);
for (int j = 0; j < dimension; j++) {
values[i][j] = random().nextFloat();
}
VectorUtil.l2normalize(values[i]);
} }
} }
return values; return values;
} }
private float[] randomVector(int dimension) {
float[] value = new float[dimension];
for (int j = 0; j < dimension; j++) {
value[j] = random().nextFloat();
}
VectorUtil.l2normalize(value);
if (vectorEncoding == VectorEncoding.BYTE) {
for (int j = 0; j < dimension; j++) {
value[j] = (byte) (value[j] * 127);
}
}
return value;
}
int[][][] copyGraph(HnswGraph graphValues) throws IOException { int[][][] copyGraph(HnswGraph graphValues) throws IOException {
int[][][] graph = new int[graphValues.numLevels()][][]; int[][][] graph = new int[graphValues.numLevels()][][];
int size = graphValues.size(); int size = graphValues.size();
@ -285,7 +323,7 @@ public class TestKnnGraph extends LuceneTestCase {
// We can't use dot product here since the vectors are laid out on a grid, not a sphere. // We can't use dot product here since the vectors are laid out on a grid, not a sphere.
similarityFunction = VectorSimilarityFunction.EUCLIDEAN; similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
IndexWriterConfig config = newIndexWriterConfig(); IndexWriterConfig config = newIndexWriterConfig();
config.setCodec(codec); // test is not compatible with simpletext config.setCodec(float32Codec);
try (Directory dir = newDirectory(); try (Directory dir = newDirectory();
IndexWriter iw = new IndexWriter(dir, config)) { IndexWriter iw = new IndexWriter(dir, config)) {
indexData(iw); indexData(iw);
@ -341,7 +379,7 @@ public class TestKnnGraph extends LuceneTestCase {
public void testMultiThreadedSearch() throws Exception { public void testMultiThreadedSearch() throws Exception {
similarityFunction = VectorSimilarityFunction.EUCLIDEAN; similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
IndexWriterConfig config = newIndexWriterConfig(); IndexWriterConfig config = newIndexWriterConfig();
config.setCodec(codec); config.setCodec(float32Codec);
Directory dir = newDirectory(); Directory dir = newDirectory();
IndexWriter iw = new IndexWriter(dir, config); IndexWriter iw = new IndexWriter(dir, config);
indexData(iw); indexData(iw);
@ -468,7 +506,7 @@ public class TestKnnGraph extends LuceneTestCase {
"vector did not match for doc " + i + ", id=" + id + ": " + Arrays.toString(scratch), "vector did not match for doc " + i + ", id=" + id + ": " + Arrays.toString(scratch),
values[id], values[id],
scratch, scratch,
0f); 0);
numDocsWithVectors++; numDocsWithVectors++;
} }
// if IndexDisi.doc == NO_MORE_DOCS, we should not call IndexDisi.nextDoc() // if IndexDisi.doc == NO_MORE_DOCS, we should not call IndexDisi.nextDoc()

View File

@ -196,6 +196,7 @@ public class TestPendingSoftDeletes extends TestPendingDeletes {
0, 0,
0, 0,
0, 0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN, VectorSimilarityFunction.EUCLIDEAN,
true); true);
List<Integer> docsDeleted = Arrays.asList(1, 3, 7, 8, DocIdSetIterator.NO_MORE_DOCS); List<Integer> docsDeleted = Arrays.asList(1, 3, 7, 8, DocIdSetIterator.NO_MORE_DOCS);
@ -233,6 +234,7 @@ public class TestPendingSoftDeletes extends TestPendingDeletes {
0, 0,
0, 0,
0, 0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN, VectorSimilarityFunction.EUCLIDEAN,
true); true);
for (DocValuesFieldUpdates update : updates) { for (DocValuesFieldUpdates update : updates) {
@ -295,6 +297,7 @@ public class TestPendingSoftDeletes extends TestPendingDeletes {
0, 0,
0, 0,
0, 0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN, VectorSimilarityFunction.EUCLIDEAN,
true); true);
List<Integer> docsDeleted = Arrays.asList(1, DocIdSetIterator.NO_MORE_DOCS); List<Integer> docsDeleted = Arrays.asList(1, DocIdSetIterator.NO_MORE_DOCS);
@ -362,6 +365,7 @@ public class TestPendingSoftDeletes extends TestPendingDeletes {
0, 0,
0, 0,
0, 0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN, VectorSimilarityFunction.EUCLIDEAN,
true); true);
List<DocValuesFieldUpdates> updates = List<DocValuesFieldUpdates> updates =
@ -398,6 +402,7 @@ public class TestPendingSoftDeletes extends TestPendingDeletes {
0, 0,
0, 0,
0, 0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN, VectorSimilarityFunction.EUCLIDEAN,
true); true);
updates = Arrays.asList(singleUpdate(Arrays.asList(1, DocIdSetIterator.NO_MORE_DOCS), 3, true)); updates = Arrays.asList(singleUpdate(Arrays.asList(1, DocIdSetIterator.NO_MORE_DOCS), 3, true));

View File

@ -25,6 +25,7 @@ import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import org.apache.lucene.document.Document; import org.apache.lucene.document.Document;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.Query; import org.apache.lucene.search.Query;
@ -117,6 +118,12 @@ public class TestSegmentToThreadMapping extends LuceneTestCase {
return null; return null;
} }
@Override
public TopDocs searchNearestVectorsExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
return null;
}
@Override @Override
protected void doClose() {} protected void doClose() {}

View File

@ -19,6 +19,7 @@ package org.apache.lucene.search;
import static com.carrotsearch.randomizedtesting.RandomizedTest.frequently; import static com.carrotsearch.randomizedtesting.RandomizedTest.frequently;
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT;
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.util.TestVectorUtil.randomVector; import static org.apache.lucene.util.TestVectorUtil.randomVector;
@ -40,7 +41,7 @@ import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term; import org.apache.lucene.index.Term;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.store.Directory; import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.LuceneTestCase;
@ -48,6 +49,7 @@ import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.BitSet; import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits; import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.VectorUtil;
@ -174,7 +176,7 @@ public class TestKnnVectorQuery extends LuceneTestCase {
KnnVectorQuery kvq = new KnnVectorQuery("field", new float[] {0}, 10); KnnVectorQuery kvq = new KnnVectorQuery("field", new float[] {0}, 10);
IllegalArgumentException e = IllegalArgumentException e =
expectThrows(IllegalArgumentException.class, () -> searcher.search(kvq, 10)); expectThrows(IllegalArgumentException.class, () -> searcher.search(kvq, 10));
assertEquals("vector query dimension: 1 differs from field dimension: 2", e.getMessage());
} }
} }
@ -239,43 +241,38 @@ public class TestKnnVectorQuery extends LuceneTestCase {
} }
public void testScoreEuclidean() throws IOException {
    float[][] vectors = new float[5][];
    for (int j = 0; j < 5; j++) {
      vectors[j] = new float[] {j, j};
    }
    try (Directory d = getIndexStore("field", vectors);
        IndexReader reader = DirectoryReader.open(d)) {
      assertEquals(1, reader.leaves().size());
      IndexSearcher searcher = new IndexSearcher(reader);
      KnnVectorQuery query = new KnnVectorQuery("field", new float[] {2, 3}, 3);
      Query rewritten = query.rewrite(reader);
      Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1);
      Scorer scorer = weight.scorer(reader.leaves().get(0));
// prior to advancing, score is 0 // prior to advancing, score is 0
assertEquals(-1, scorer.docID()); assertEquals(-1, scorer.docID());
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score); expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
// test getMaxScore // test getMaxScore
assertEquals(0, scorer.getMaxScore(-1), 0); assertEquals(0, scorer.getMaxScore(-1), 0);
assertEquals(0, scorer.getMaxScore(0), 0); assertEquals(0, scorer.getMaxScore(0), 0);
// This is 1 / ((l2distance((2,3), (2, 2)) = 1) + 1) = 0.5 // This is 1 / ((l2distance((2,3), (2, 2)) = 1) + 1) = 0.5
assertEquals(1 / 2f, scorer.getMaxScore(2), 0); assertEquals(1 / 2f, scorer.getMaxScore(2), 0);
assertEquals(1 / 2f, scorer.getMaxScore(Integer.MAX_VALUE), 0); assertEquals(1 / 2f, scorer.getMaxScore(Integer.MAX_VALUE), 0);
DocIdSetIterator it = scorer.iterator(); DocIdSetIterator it = scorer.iterator();
assertEquals(3, it.cost()); assertEquals(3, it.cost());
assertEquals(1, it.nextDoc()); assertEquals(1, it.nextDoc());
assertEquals(1 / 6f, scorer.score(), 0); assertEquals(1 / 6f, scorer.score(), 0);
assertEquals(3, it.advance(3)); assertEquals(3, it.advance(3));
assertEquals(1 / 2f, scorer.score(), 0); assertEquals(1 / 2f, scorer.score(), 0);
assertEquals(NO_MORE_DOCS, it.advance(4)); assertEquals(NO_MORE_DOCS, it.advance(4));
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score); expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
} }
} }
@ -764,9 +761,18 @@ public class TestKnnVectorQuery extends LuceneTestCase {
private Directory getIndexStore(String field, float[]... contents) throws IOException { private Directory getIndexStore(String field, float[]... contents) throws IOException {
Directory indexStore = newDirectory(); Directory indexStore = newDirectory();
RandomIndexWriter writer = new RandomIndexWriter(random(), indexStore); RandomIndexWriter writer = new RandomIndexWriter(random(), indexStore);
VectorEncoding encoding = randomVectorEncoding();
for (int i = 0; i < contents.length; ++i) { for (int i = 0; i < contents.length; ++i) {
Document doc = new Document(); Document doc = new Document();
if (encoding == VectorEncoding.BYTE) {
BytesRef v = new BytesRef(new byte[contents[i].length]);
for (int j = 0; j < v.length; j++) {
v.bytes[j] = (byte) contents[i][j];
}
doc.add(new KnnVectorField(field, v, EUCLIDEAN));
} else {
doc.add(new KnnVectorField(field, contents[i]));
}
doc.add(new StringField("id", "id" + i, Field.Store.YES)); doc.add(new StringField("id", "id" + i, Field.Store.YES));
writer.addDocument(doc); writer.addDocument(doc);
} }
@ -908,4 +914,8 @@ public class TestKnnVectorQuery extends LuceneTestCase {
return 31 * classHash() + docs.hashCode(); return 31 * classHash() + docs.hashCode();
} }
} }
private VectorEncoding randomVectorEncoding() {
return VectorEncoding.values()[random().nextInt(VectorEncoding.values().length)];
}
} }
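The BYTE branch of getIndexStore above packs each float component into a BytesRef with a plain narrowing cast. A minimal standalone sketch of that conversion (the class and method names here are illustrative, not part of this patch):

import org.apache.lucene.util.BytesRef;

class ByteVectorExample {
  // Pack float components into a BytesRef the way getIndexStore does.
  // Components are assumed to already fit in [-128, 127]; larger values are
  // truncated by the narrowing cast, so callers should scale/quantize first.
  static BytesRef toByteVector(float[] vector) {
    BytesRef v = new BytesRef(new byte[vector.length]);
    for (int j = 0; j < v.length; j++) {
      v.bytes[j] = (byte) vector[j];
    }
    return v;
  }
}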

View File

@ -50,6 +50,7 @@ import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.PointValues; import org.apache.lucene.index.PointValues;
import org.apache.lucene.index.Term; import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms; import org.apache.lucene.index.Terms;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.SortField.Type; import org.apache.lucene.search.SortField.Type;
import org.apache.lucene.store.Directory; import org.apache.lucene.store.Directory;
@ -1126,6 +1127,7 @@ public class TestSortOptimization extends LuceneTestCase {
0, 0,
0, 0,
0, 0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.DOT_PRODUCT, VectorSimilarityFunction.DOT_PRODUCT,
fi.isSoftDeletesField()); fi.isSoftDeletesField());
newInfos[i] = noIndexFI; newInfos[i] = noIndexFI;

View File

@ -18,6 +18,7 @@ package org.apache.lucene.util;
import java.util.Random; import java.util.Random;
import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil;
public class TestVectorUtil extends LuceneTestCase { public class TestVectorUtil extends LuceneTestCase {
@ -130,6 +131,23 @@ public class TestVectorUtil extends LuceneTestCase {
return u; return u;
} }
private static BytesRef negative(BytesRef v) {
BytesRef u = new BytesRef(new byte[v.length]);
for (int i = 0; i < v.length; i++) {
// note: negating -128 overflows back to -128 ((byte) 128 == -128);
// randomVectorBytes clips inputs to -127, so that case cannot arise here
u.bytes[i] = (byte) -v.bytes[i];
}
return u;
}
private static float l2(BytesRef v) {
float l2 = 0;
for (int i = v.offset; i < v.offset + v.length; i++) {
l2 += v.bytes[i] * v.bytes[i];
}
return l2;
}
private static float[] randomVector() { private static float[] randomVector() {
return randomVector(random().nextInt(100) + 1); return randomVector(random().nextInt(100) + 1);
} }
@ -142,4 +160,88 @@ public class TestVectorUtil extends LuceneTestCase {
} }
return v; return v;
} }
private static BytesRef randomVectorBytes() {
BytesRef v = TestUtil.randomBinaryTerm(random(), TestUtil.nextInt(random(), 1, 100));
// clip at -127 to avoid overflow
for (int i = v.offset; i < v.offset + v.length; i++) {
if (v.bytes[i] == -128) {
v.bytes[i] = -127;
}
}
return v;
}
public void testBasicDotProductBytes() {
BytesRef a = new BytesRef(new byte[] {1, 2, 3});
BytesRef b = new BytesRef(new byte[] {-10, 0, 5});
assertEquals(5, VectorUtil.dotProduct(a, b), 0);
assertEquals(5 / (3f * (1 << 15)), VectorUtil.dotProductScore(a, b), DELTA);
}
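To see where the expected values come from, the arithmetic can be worked by hand; only the dot product and the dims * 2^15 normalization from the assertion above are involved:

// dot({1, 2, 3}, {-10, 0, 5}) = 1 * -10 + 2 * 0 + 3 * 5 = 5
int dot = 1 * -10 + 2 * 0 + 3 * 5;
// the score assertion divides by dims * 2^15: 5 / (3 * 32768) ~= 5.09e-5
float score = dot / (3f * (1 << 15));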
public void testSelfDotProductBytes() {
// the dot product of a vector with itself is equal to the sum of the squares of its components
BytesRef v = randomVectorBytes();
assertEquals(l2(v), VectorUtil.dotProduct(v, v), DELTA);
}
public void testOrthogonalDotProductBytes() {
// the dot product of two perpendicular vectors is 0
byte[] v = new byte[4];
v[0] = (byte) random().nextInt(100);
v[1] = (byte) random().nextInt(100);
v[2] = v[1];
v[3] = (byte) -v[0];
// also test computing using BytesRef with nonzero offset
assertEquals(0, VectorUtil.dotProduct(new BytesRef(v, 0, 2), new BytesRef(v, 2, 2)), DELTA);
}
public void testSelfSquareDistanceBytes() {
// the l2 distance of a vector with itself is zero
BytesRef v = randomVectorBytes();
assertEquals(0, VectorUtil.squareDistance(v, v), DELTA);
}
public void testBasicSquareDistanceBytes() {
assertEquals(
12,
VectorUtil.squareDistance(
new BytesRef(new byte[] {1, 2, 3}), new BytesRef(new byte[] {-1, 0, 5})),
0);
}
public void testRandomSquareDistanceBytes() {
// the square distance of a vector with its inverse is equal to four times the sum of squares of
// its components
BytesRef v = randomVectorBytes();
BytesRef u = negative(v);
assertEquals(4 * l2(v), VectorUtil.squareDistance(u, v), DELTA);
}
public void testBasicCosineBytes() {
assertEquals(
0.11952f,
VectorUtil.cosine(new BytesRef(new byte[] {1, 2, 3}), new BytesRef(new byte[] {-10, 0, 5})),
DELTA);
}
public void testSelfCosineBytes() {
// the cosine of a vector with itself is always equal to 1
BytesRef v = randomVectorBytes();
// ensure the vector is non-zero so that cosine is defined
v.bytes[0] = (byte) (random().nextInt(126) + 1);
assertEquals(1.0f, VectorUtil.cosine(v, v), DELTA);
}
public void testOrthogonalCosineBytes() {
// the cosine of two perpendicular vectors is 0
float[] v = new float[2];
v[0] = random().nextInt(100);
// ensure the vector is non-zero so that cosine is defined
v[1] = random().nextInt(1, 100);
float[] u = new float[2];
u[0] = v[1];
u[1] = -v[0];
assertEquals(0, VectorUtil.cosine(u, v), DELTA);
}
} }
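The expected value in testBasicCosineBytes can likewise be checked by hand, assuming the standard cosine definition dot(a, b) / (|a| * |b|):

double dot = 1 * -10 + 2 * 0 + 3 * 5;                   // = 5
double norms = Math.sqrt((1 + 4 + 9) * (100 + 0 + 25)); // = sqrt(14 * 125) ~= 41.83
double cosine = dot / norms;                            // ~= 0.11952, within DELTA of the assertion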

View File

@ -56,6 +56,7 @@ import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.RandomAccessVectorValues; import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer; import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.ConstantScoreScorer; import org.apache.lucene.search.ConstantScoreScorer;
import org.apache.lucene.search.ConstantScoreWeight; import org.apache.lucene.search.ConstantScoreWeight;
@ -101,6 +102,7 @@ public class KnnGraphTester {
private int beamWidth; private int beamWidth;
private int maxConn; private int maxConn;
private VectorSimilarityFunction similarityFunction; private VectorSimilarityFunction similarityFunction;
private VectorEncoding vectorEncoding;
private FixedBitSet matchDocs; private FixedBitSet matchDocs;
private float selectivity; private float selectivity;
private boolean prefilter; private boolean prefilter;
@ -113,6 +115,7 @@ public class KnnGraphTester {
topK = 100; topK = 100;
fanout = topK; fanout = topK;
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
vectorEncoding = VectorEncoding.FLOAT32;
selectivity = 1f; selectivity = 1f;
prefilter = false; prefilter = false;
} }
@ -195,12 +198,30 @@ public class KnnGraphTester {
case "-docs": case "-docs":
docVectorsPath = Paths.get(args[++iarg]); docVectorsPath = Paths.get(args[++iarg]);
break; break;
case "-encoding":
String encoding = args[++iarg];
switch (encoding) {
case "byte":
vectorEncoding = VectorEncoding.BYTE;
break;
case "float32":
vectorEncoding = VectorEncoding.FLOAT32;
break;
default:
throw new IllegalArgumentException("-encoding can be 'byte' or 'float32' only");
}
break;
case "-metric":
          String metric = args[++iarg];
          switch (metric) {
            case "euclidean":
              similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
              break;
            case "angular":
              similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
              break;
            default:
              throw new IllegalArgumentException("-metric can be 'angular' or 'euclidean' only");
          }
          break;
case "-forceMerge": case "-forceMerge":
@ -229,7 +250,7 @@ public class KnnGraphTester {
if (operation == null && reindex == false) { if (operation == null && reindex == false) {
usage(); usage();
} }
if (prefilter && selectivity == 1f) {
throw new IllegalArgumentException("-prefilter requires filterSelectivity between 0 and 1"); throw new IllegalArgumentException("-prefilter requires filterSelectivity between 0 and 1");
} }
indexPath = Paths.get(formatIndexPath(docVectorsPath)); indexPath = Paths.get(formatIndexPath(docVectorsPath));
@ -248,7 +269,9 @@ public class KnnGraphTester {
if (docVectorsPath == null) { if (docVectorsPath == null) {
throw new IllegalArgumentException("missing -docs arg"); throw new IllegalArgumentException("missing -docs arg");
} }
if (selectivity < 1) {
matchDocs = generateRandomBitSet(numDocs, selectivity);
}
if (outputPath != null) { if (outputPath != null) {
testSearch(indexPath, queryPath, outputPath, null); testSearch(indexPath, queryPath, outputPath, null);
} else { } else {
@ -285,14 +308,17 @@ public class KnnGraphTester {
} }
} }
@SuppressWarnings("unchecked")
private void dumpGraph(Path docsPath) throws IOException { private void dumpGraph(Path docsPath) throws IOException {
try (BinaryFileVectors vectors = new BinaryFileVectors(docsPath)) {
      RandomAccessVectorValues values = vectors.randomAccess();
      HnswGraphBuilder<float[]> builder =
          (HnswGraphBuilder<float[]>)
              HnswGraphBuilder.create(
                  vectors, vectorEncoding, similarityFunction, maxConn, beamWidth, 0);
      // start at node 1
      for (int i = 1; i < numDocs; i++) {
        builder.addGraphNode(i, values);
System.out.println("\nITERATION " + i); System.out.println("\nITERATION " + i);
dumpGraph(builder.hnsw); dumpGraph(builder.hnsw);
} }
@ -375,13 +401,8 @@ public class KnnGraphTester {
throws IOException { throws IOException {
TopDocs[] results = new TopDocs[numIters]; TopDocs[] results = new TopDocs[numIters];
long elapsed, totalCpuTime, totalVisited = 0; long elapsed, totalCpuTime, totalVisited = 0;
try (FileChannel input = FileChannel.open(queryPath)) {
      VectorReader targetReader = VectorReader.create(input, dim, vectorEncoding, numIters);
if (quiet == false) { if (quiet == false) {
System.out.println("running " + numIters + " targets; topK=" + topK + ", fanout=" + fanout); System.out.println("running " + numIters + " targets; topK=" + topK + ", fanout=" + fanout);
} }
@ -392,21 +413,21 @@ public class KnnGraphTester {
DirectoryReader reader = DirectoryReader.open(dir)) { DirectoryReader reader = DirectoryReader.open(dir)) {
IndexSearcher searcher = new IndexSearcher(reader); IndexSearcher searcher = new IndexSearcher(reader);
numDocs = reader.maxDoc(); numDocs = reader.maxDoc();
Query bitSetQuery = prefilter ? new BitSetQuery(matchDocs) : null;
for (int i = 0; i < numIters; i++) { for (int i = 0; i < numIters; i++) {
// warm up // warm up
float[] target = targetReader.next();
if (prefilter) { if (prefilter) {
doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, bitSetQuery); doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, bitSetQuery);
} else { } else {
doKnnVectorQuery(searcher, KNN_FIELD, target, (int) (topK / selectivity), fanout, null); doKnnVectorQuery(searcher, KNN_FIELD, target, (int) (topK / selectivity), fanout, null);
} }
} }
targetReader.reset();
start = System.nanoTime(); start = System.nanoTime();
cpuTimeStartNs = bean.getCurrentThreadCpuTime(); cpuTimeStartNs = bean.getCurrentThreadCpuTime();
for (int i = 0; i < numIters; i++) { for (int i = 0; i < numIters; i++) {
float[] target = targetReader.next();
if (prefilter) { if (prefilter) {
results[i] = doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, bitSetQuery); results[i] = doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, bitSetQuery);
} else { } else {
@ -414,10 +435,12 @@ public class KnnGraphTester {
doKnnVectorQuery( doKnnVectorQuery(
searcher, KNN_FIELD, target, (int) (topK / selectivity), fanout, null); searcher, KNN_FIELD, target, (int) (topK / selectivity), fanout, null);
if (matchDocs != null) {
            results[i].scoreDocs =
                Arrays.stream(results[i].scoreDocs)
                    .filter(scoreDoc -> matchDocs.get(scoreDoc.doc))
                    .toArray(ScoreDoc[]::new);
          }
} }
} }
totalCpuTime = (bean.getCurrentThreadCpuTime() - cpuTimeStartNs) / 1_000_000; totalCpuTime = (bean.getCurrentThreadCpuTime() - cpuTimeStartNs) / 1_000_000;
@ -425,7 +448,14 @@ public class KnnGraphTester {
for (int i = 0; i < numIters; i++) { for (int i = 0; i < numIters; i++) {
totalVisited += results[i].totalHits.value; totalVisited += results[i].totalHits.value;
for (ScoreDoc doc : results[i].scoreDocs) { for (ScoreDoc doc : results[i].scoreDocs) {
if (doc.doc != NO_MORE_DOCS) {
            // there is a bug somewhere that can result in doc=NO_MORE_DOCS! It seems to happen
            // in some degenerate case (an input query containing NaN?) that causes HNSW search
            // to return no results.
            doc.doc = Integer.parseInt(reader.document(doc.doc).get("id"));
          } else {
            System.out.println("NO_MORE_DOCS!");
          }
} }
} }
} }
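The doc != NO_MORE_DOCS guard above matters because DocIdSetIterator.NO_MORE_DOCS is the sentinel value Integer.MAX_VALUE, not a real document id, so feeding it to reader.document() would fail. A one-line check makes the relationship explicit:

// NO_MORE_DOCS is a sentinel (Integer.MAX_VALUE), never a valid document id
assert DocIdSetIterator.NO_MORE_DOCS == Integer.MAX_VALUE;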
@ -477,6 +507,78 @@ public class KnnGraphTester {
} }
} }
private abstract static class VectorReader {
final float[] target;
final ByteBuffer bytes;
static VectorReader create(FileChannel input, int dim, VectorEncoding vectorEncoding, int n)
throws IOException {
int bufferSize = n * dim * vectorEncoding.byteSize;
return switch (vectorEncoding) {
case BYTE -> new VectorReaderByte(input, dim, bufferSize);
case FLOAT32 -> new VectorReaderFloat32(input, dim, bufferSize);
};
}
VectorReader(FileChannel input, int dim, int bufferSize) throws IOException {
bytes =
input.map(FileChannel.MapMode.READ_ONLY, 0, bufferSize).order(ByteOrder.LITTLE_ENDIAN);
target = new float[dim];
}
void reset() {
bytes.position(0);
}
abstract float[] next();
}
private static class VectorReaderFloat32 extends VectorReader {
private final FloatBuffer floats;
VectorReaderFloat32(FileChannel input, int dim, int bufferSize) throws IOException {
super(input, dim, bufferSize);
floats = bytes.asFloatBuffer();
}
@Override
void reset() {
super.reset();
floats.position(0);
}
@Override
float[] next() {
floats.get(target);
return target;
}
}
private static class VectorReaderByte extends VectorReader {
private byte[] scratch;
private BytesRef bytesRef;
VectorReaderByte(FileChannel input, int dim, int bufferSize) throws IOException {
super(input, dim, bufferSize);
scratch = new byte[dim];
bytesRef = new BytesRef(scratch);
}
@Override
float[] next() {
bytes.get(scratch);
for (int i = 0; i < scratch.length; i++) {
target[i] = scratch[i];
}
return target;
}
BytesRef nextBytes() {
bytes.get(scratch);
return bytesRef;
}
}
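Both readers hand vectors back as float[] (byte components are widened on the fly), and VectorReaderByte can additionally expose the raw bytes for indexing. The mapped buffer is sized as n * dim * byteSize, so a corpus of one million 100-dimensional vectors maps 100 MB as bytes versus 400 MB as float32. A hypothetical caller, sketched under the assumption of a raw little-endian file laid out that way (the path and numVectors are illustrative):

try (FileChannel in = FileChannel.open(Paths.get("docs.vec"))) { // path is illustrative
  VectorReader reader = VectorReader.create(in, dim, VectorEncoding.BYTE, numVectors);
  for (int i = 0; i < numVectors; i++) {
    float[] v = reader.next(); // note: the same float[] is reused across calls
    // ... consume v before the next call
  }
  reader.reset(); // rewind to re-read from the start
}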
private static TopDocs doKnnVectorQuery( private static TopDocs doKnnVectorQuery(
IndexSearcher searcher, String field, float[] vector, int k, int fanout, Query filter) IndexSearcher searcher, String field, float[] vector, int k, int fanout, Query filter)
throws IOException { throws IOException {
@ -529,7 +631,9 @@ public class KnnGraphTester {
if (Files.exists(nnPath) && isNewer(nnPath, docPath, queryPath) && selectivity == 1f) { if (Files.exists(nnPath) && isNewer(nnPath, docPath, queryPath) && selectivity == 1f) {
return readNN(nnPath); return readNN(nnPath);
} else {
      // TODO: enable computing NN from high precision vectors when
      // checking low-precision recall
      int[][] nn = computeNN(docPath, queryPath, vectorEncoding);
if (selectivity == 1f) { if (selectivity == 1f) {
writeNN(nn, nnPath); writeNN(nn, nnPath);
} }
@ -589,52 +693,37 @@ public class KnnGraphTester {
return bitSet; return bitSet;
} }
private int[][] computeNN(Path docPath, Path queryPath, VectorEncoding encoding)
      throws IOException {
    int[][] result = new int[numIters][];
    if (quiet == false) {
      System.out.println("computing true nearest neighbors of " + numIters + " target vectors");
    }
    try (FileChannel in = FileChannel.open(docPath);
        FileChannel qIn = FileChannel.open(queryPath)) {
      VectorReader docReader = VectorReader.create(in, dim, encoding, numDocs);
      VectorReader queryReader = VectorReader.create(qIn, dim, encoding, numIters);
      for (int i = 0; i < numIters; i++) {
        float[] query = queryReader.next();
        NeighborQueue queue = new NeighborQueue(topK, false);
        for (int j = 0; j < numDocs; j++) {
          float[] doc = docReader.next();
          float d = similarityFunction.compare(query, doc);
          if (matchDocs == null || matchDocs.get(j)) {
            queue.insertWithOverflow(j, d);
          }
        }
        docReader.reset();
        result[i] = new int[topK];
        for (int k = topK - 1; k >= 0; k--) {
          result[i][k] = queue.topNode();
          queue.pop();
        }
        if (quiet == false && (i + 1) % 10 == 0) {
          System.out.print(" " + (i + 1));
          System.out.flush();
        }
      }
    }
    return result;
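The brute-force pass relies on NeighborQueue as a bounded priority queue: insertWithOverflow retains at most topK entries, and topNode()/pop() appear to drain the retained hits least-competitive first, which is why the loop above fills result[i] from the back so the best neighbor lands at index 0. Schematically, using only the methods visible above (score(j) is a stand-in for the similarity computation, not a real API):

NeighborQueue queue = new NeighborQueue(topK, false);
for (int j = 0; j < numDocs; j++) {
  queue.insertWithOverflow(j, score(j)); // keeps only the topK most similar seen so far
}
int[] best = new int[topK];
for (int k = topK - 1; k >= 0; k--) { // drain from worst to best
  best[k] = queue.topNode();
  queue.pop();
}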
@ -651,37 +740,29 @@ public class KnnGraphTester {
}); });
// iwc.setMergePolicy(NoMergePolicy.INSTANCE); // iwc.setMergePolicy(NoMergePolicy.INSTANCE);
iwc.setRAMBufferSizeMB(1994d); iwc.setRAMBufferSizeMB(1994d);
iwc.setUseCompoundFile(false);
// iwc.setMaxBufferedDocs(10000); // iwc.setMaxBufferedDocs(10000);
FieldType fieldType = KnnVectorField.createFieldType(dim, vectorEncoding, similarityFunction);
if (quiet == false) { if (quiet == false) {
iwc.setInfoStream(new PrintStreamInfoStream(System.out)); iwc.setInfoStream(new PrintStreamInfoStream(System.out));
System.out.println("creating index in " + indexPath); System.out.println("creating index in " + indexPath);
} }
long start = System.nanoTime(); long start = System.nanoTime();
try (FSDirectory dir = FSDirectory.open(indexPath);
        IndexWriter iw = new IndexWriter(dir, iwc)) {
      try (FileChannel in = FileChannel.open(docsPath)) {
        VectorReader vectorReader = VectorReader.create(in, dim, vectorEncoding, numDocs);
        for (int i = 0; i < numDocs; i++) {
          Document doc = new Document();
          switch (vectorEncoding) {
            case BYTE -> doc.add(
                new KnnVectorField(
                    KNN_FIELD, ((VectorReaderByte) vectorReader).nextBytes(), fieldType));
            case FLOAT32 -> doc.add(new KnnVectorField(KNN_FIELD, vectorReader.next(), fieldType));
          }
          doc.add(new StoredField(ID_FIELD, i));
          iw.addDocument(doc);
        }
if (quiet == false) { if (quiet == false) {
System.out.println("Done indexing " + numDocs + " documents; now flush"); System.out.println("Done indexing " + numDocs + " documents; now flush");

View File

@ -31,6 +31,7 @@ class MockVectorValues extends VectorValues
protected final float[][] denseValues; protected final float[][] denseValues;
protected final float[][] values; protected final float[][] values;
private final int numVectors; private final int numVectors;
private final BytesRef binaryValue;
private int pos = -1; private int pos = -1;
@ -47,6 +48,9 @@ class MockVectorValues extends VectorValues
} }
numVectors = count; numVectors = count;
scratch = new float[dimension]; scratch = new float[dimension];
// used by tests that build a graph from bytes rather than floats
binaryValue = new BytesRef(dimension);
binaryValue.length = dimension;
} }
public MockVectorValues copy() { public MockVectorValues copy() {
@ -89,7 +93,11 @@ class MockVectorValues extends VectorValues
@Override @Override
public BytesRef binaryValue(int targetOrd) { public BytesRef binaryValue(int targetOrd) {
float[] value = vectorValue(targetOrd);
for (int i = 0; i < value.length; i++) {
binaryValue.bytes[i] = (byte) value[i];
}
return binaryValue;
} }
private boolean seek(int target) { private boolean seek(int target) {

View File

@ -18,6 +18,7 @@
package org.apache.lucene.util.hnsw; package org.apache.lucene.util.hnsw;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.util.VectorUtil.toBytesRef;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
@ -43,6 +44,7 @@ import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.RandomAccessVectorValues; import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer; import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues; import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.IndexSearcher;
@ -60,25 +62,38 @@ import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator; import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
import org.junit.Before;
/** Tests HNSW KNN graphs */ /** Tests HNSW KNN graphs */
public class TestHnswGraph extends LuceneTestCase { public class TestHnswGraph extends LuceneTestCase {
VectorSimilarityFunction similarityFunction;
VectorEncoding vectorEncoding;
@Before
public void setup() {
similarityFunction =
VectorSimilarityFunction.values()[
random().nextInt(VectorSimilarityFunction.values().length - 1) + 1];
if (similarityFunction == VectorSimilarityFunction.DOT_PRODUCT) {
vectorEncoding =
VectorEncoding.values()[random().nextInt(VectorEncoding.values().length - 1) + 1];
} else {
vectorEncoding = VectorEncoding.FLOAT32;
}
}
// test writing out and reading in a graph gives the expected graph // test writing out and reading in a graph gives the expected graph
public void testReadWrite() throws IOException { public void testReadWrite() throws IOException {
int dim = random().nextInt(100) + 1; int dim = random().nextInt(100) + 1;
int nDoc = random().nextInt(100) + 1; int nDoc = random().nextInt(100) + 1;
int M = random().nextInt(4) + 2;
    int beamWidth = random().nextInt(10) + 5;
    long seed = random().nextLong();
    RandomVectorValues vectors = new RandomVectorValues(nDoc, dim, vectorEncoding, random());
    RandomVectorValues v2 = vectors.copy(), v3 = vectors.copy();
    HnswGraphBuilder<?> builder =
        HnswGraphBuilder.create(vectors, vectorEncoding, similarityFunction, M, beamWidth, seed);
HnswGraph hnsw = builder.build(vectors); HnswGraph hnsw = builder.build(vectors);
// Recreate the graph while indexing with the same random seed and write it out // Recreate the graph while indexing with the same random seed and write it out
@ -131,6 +146,10 @@ public class TestHnswGraph extends LuceneTestCase {
} }
} }
private VectorEncoding randomVectorEncoding() {
return VectorEncoding.values()[random().nextInt(VectorEncoding.values().length)];
}
// test that sorted index returns the same search results are unsorted // test that sorted index returns the same search results are unsorted
public void testSortedAndUnsortedIndicesReturnSameResults() throws IOException { public void testSortedAndUnsortedIndicesReturnSameResults() throws IOException {
int dim = random().nextInt(10) + 3; int dim = random().nextInt(10) + 3;
@ -250,24 +269,27 @@ public class TestHnswGraph extends LuceneTestCase {
// oriented in the right directions // oriented in the right directions
public void testAknnDiverse() throws IOException { public void testAknnDiverse() throws IOException {
int nDoc = 100; int nDoc = 100;
vectorEncoding = randomVectorEncoding();
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
CircularVectorValues vectors = new CircularVectorValues(nDoc); CircularVectorValues vectors = new CircularVectorValues(nDoc);
HnswGraphBuilder<?> builder =
        HnswGraphBuilder.create(
            vectors, vectorEncoding, similarityFunction, 10, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors); OnHeapHnswGraph hnsw = builder.build(vectors);
// run some searches // run some searches
NeighborQueue nn = NeighborQueue nn =
HnswGraphSearcher.search( HnswGraphSearcher.search(
getTargetVector(),
10, 10,
vectors.randomAccess(), vectors.randomAccess(),
vectorEncoding,
similarityFunction,
hnsw, hnsw,
null, null,
Integer.MAX_VALUE); Integer.MAX_VALUE);
int[] nodes = nn.nodes(); int[] nodes = nn.nodes();
assertEquals("Number of found results is not equal to [10].", 10, nodes.length);
int sum = 0; int sum = 0;
for (int node : nodes) { for (int node : nodes) {
sum += node; sum += node;
@ -289,23 +311,26 @@ public class TestHnswGraph extends LuceneTestCase {
public void testSearchWithAcceptOrds() throws IOException { public void testSearchWithAcceptOrds() throws IOException {
int nDoc = 100; int nDoc = 100;
CircularVectorValues vectors = new CircularVectorValues(nDoc); CircularVectorValues vectors = new CircularVectorValues(nDoc);
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
    vectorEncoding = randomVectorEncoding();
    HnswGraphBuilder<?> builder =
        HnswGraphBuilder.create(
            vectors, vectorEncoding, similarityFunction, 16, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors); OnHeapHnswGraph hnsw = builder.build(vectors);
// the first 10 docs must not be deleted to ensure the expected recall // the first 10 docs must not be deleted to ensure the expected recall
Bits acceptOrds = createRandomAcceptOrds(10, vectors.size); Bits acceptOrds = createRandomAcceptOrds(10, vectors.size);
NeighborQueue nn = NeighborQueue nn =
HnswGraphSearcher.search( HnswGraphSearcher.search(
getTargetVector(),
10, 10,
vectors.randomAccess(), vectors.randomAccess(),
vectorEncoding,
similarityFunction,
hnsw, hnsw,
acceptOrds, acceptOrds,
Integer.MAX_VALUE); Integer.MAX_VALUE);
int[] nodes = nn.nodes(); int[] nodes = nn.nodes();
assertEquals("Number of found results is not equal to [10].", 10, nodes.length);
int sum = 0; int sum = 0;
for (int node : nodes) { for (int node : nodes) {
assertTrue("the results include a deleted document: " + node, acceptOrds.get(node)); assertTrue("the results include a deleted document: " + node, acceptOrds.get(node));
@ -319,9 +344,11 @@ public class TestHnswGraph extends LuceneTestCase {
public void testSearchWithSelectiveAcceptOrds() throws IOException { public void testSearchWithSelectiveAcceptOrds() throws IOException {
int nDoc = 100; int nDoc = 100;
CircularVectorValues vectors = new CircularVectorValues(nDoc); CircularVectorValues vectors = new CircularVectorValues(nDoc);
vectorEncoding = randomVectorEncoding();
    similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
    HnswGraphBuilder<?> builder =
        HnswGraphBuilder.create(
            vectors, vectorEncoding, similarityFunction, 16, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors); OnHeapHnswGraph hnsw = builder.build(vectors);
// Only mark a few vectors as accepted // Only mark a few vectors as accepted
BitSet acceptOrds = new FixedBitSet(vectors.size); BitSet acceptOrds = new FixedBitSet(vectors.size);
@ -333,10 +360,11 @@ public class TestHnswGraph extends LuceneTestCase {
int numAccepted = acceptOrds.cardinality(); int numAccepted = acceptOrds.cardinality();
NeighborQueue nn = NeighborQueue nn =
HnswGraphSearcher.search( HnswGraphSearcher.search(
getTargetVector(),
numAccepted, numAccepted,
vectors.randomAccess(), vectors.randomAccess(),
vectorEncoding,
similarityFunction,
hnsw, hnsw,
acceptOrds, acceptOrds,
Integer.MAX_VALUE); Integer.MAX_VALUE);
@ -347,12 +375,17 @@ public class TestHnswGraph extends LuceneTestCase {
} }
} }
private float[] getTargetVector() {
return new float[] {1, 0};
}
public void testSearchWithSkewedAcceptOrds() throws IOException { public void testSearchWithSkewedAcceptOrds() throws IOException {
int nDoc = 1000; int nDoc = 1000;
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
CircularVectorValues vectors = new CircularVectorValues(nDoc); CircularVectorValues vectors = new CircularVectorValues(nDoc);
HnswGraphBuilder<?> builder =
        HnswGraphBuilder.create(
            vectors, VectorEncoding.FLOAT32, similarityFunction, 16, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors); OnHeapHnswGraph hnsw = builder.build(vectors);
// Skip over half of the documents that are closest to the query vector // Skip over half of the documents that are closest to the query vector
@ -362,15 +395,16 @@ public class TestHnswGraph extends LuceneTestCase {
} }
NeighborQueue nn = NeighborQueue nn =
HnswGraphSearcher.search( HnswGraphSearcher.search(
getTargetVector(),
10, 10,
vectors.randomAccess(), vectors.randomAccess(),
VectorEncoding.FLOAT32,
similarityFunction,
hnsw, hnsw,
acceptOrds, acceptOrds,
Integer.MAX_VALUE); Integer.MAX_VALUE);
int[] nodes = nn.nodes(); int[] nodes = nn.nodes();
assertEquals("Number of found results is not equal to [10].", 10, nodes.length);
int sum = 0; int sum = 0;
for (int node : nodes) { for (int node : nodes) {
assertTrue("the results include a deleted document: " + node, acceptOrds.get(node)); assertTrue("the results include a deleted document: " + node, acceptOrds.get(node));
@ -383,20 +417,23 @@ public class TestHnswGraph extends LuceneTestCase {
public void testVisitedLimit() throws IOException { public void testVisitedLimit() throws IOException {
int nDoc = 500; int nDoc = 500;
vectorEncoding = randomVectorEncoding();
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
CircularVectorValues vectors = new CircularVectorValues(nDoc); CircularVectorValues vectors = new CircularVectorValues(nDoc);
HnswGraphBuilder<?> builder =
        HnswGraphBuilder.create(
            vectors, vectorEncoding, similarityFunction, 16, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors); OnHeapHnswGraph hnsw = builder.build(vectors);
int topK = 50; int topK = 50;
int visitedLimit = topK + random().nextInt(5); int visitedLimit = topK + random().nextInt(5);
NeighborQueue nn = NeighborQueue nn =
HnswGraphSearcher.search( HnswGraphSearcher.search(
getTargetVector(),
topK, topK,
vectors.randomAccess(), vectors.randomAccess(),
vectorEncoding,
similarityFunction,
hnsw, hnsw,
createRandomAcceptOrds(0, vectors.size), createRandomAcceptOrds(0, vectors.size),
visitedLimit); visitedLimit);
@ -406,54 +443,68 @@ public class TestHnswGraph extends LuceneTestCase {
} }
public void testHnswGraphBuilderInvalid() { public void testHnswGraphBuilderInvalid() {
expectThrows(
NullPointerException.class, () -> HnswGraphBuilder.create(null, null, null, 0, 0, 0));
// M must be > 0
expectThrows( expectThrows(
IllegalArgumentException.class, IllegalArgumentException.class,
() -> () ->
HnswGraphBuilder.create(
new RandomVectorValues(1, 1, random()), new RandomVectorValues(1, 1, random()),
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN, VectorSimilarityFunction.EUCLIDEAN,
0, 0,
10, 10,
0)); 0));
// beamWidth must be > 0
expectThrows( expectThrows(
IllegalArgumentException.class, IllegalArgumentException.class,
() -> () ->
HnswGraphBuilder.create(
new RandomVectorValues(1, 1, random()), new RandomVectorValues(1, 1, random()),
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN, VectorSimilarityFunction.EUCLIDEAN,
10, 10,
0, 0,
0)); 0));
} }
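The test pins down the argument validation in HnswGraphBuilder.create: a null vector source fails fast with NullPointerException, while non-positive M (max connections per node) and beamWidth are rejected with IllegalArgumentException. The checks presumably amount to something like this sketch (assumed shape, not copied from the implementation):

// Sketch of the validation the test exercises; parameter names are assumptions.
if (vectorValues == null) {
  throw new NullPointerException("vectorValues must not be null");
}
if (maxConn <= 0) {
  throw new IllegalArgumentException("maxConn must be positive");
}
if (beamWidth <= 0) {
  throw new IllegalArgumentException("beamWidth must be positive");
}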
@SuppressWarnings("unchecked")
public void testDiversity() throws IOException { public void testDiversity() throws IOException {
vectorEncoding = randomVectorEncoding();
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
// Some carefully checked test cases with simple 2d vectors on the unit circle: // Some carefully checked test cases with simple 2d vectors on the unit circle:
float[][] values = {
      unitVector2d(0.5),
      unitVector2d(0.75),
      unitVector2d(0.2),
      unitVector2d(0.9),
      unitVector2d(0.8),
      unitVector2d(0.77),
    };
    if (vectorEncoding == VectorEncoding.BYTE) {
      for (float[] v : values) {
        for (int i = 0; i < v.length; i++) {
          v[i] *= 127;
        }
      }
    }
MockVectorValues vectors = new MockVectorValues(values);
// First add nodes until everybody gets a full neighbor list // First add nodes until everybody gets a full neighbor list
HnswGraphBuilder<?> builder =
        HnswGraphBuilder.create(
            vectors, vectorEncoding, similarityFunction, 2, 10, random().nextInt());
// node 0 is added by the builder constructor // node 0 is added by the builder constructor
// builder.addGraphNode(vectors.vectorValue(0)); // builder.addGraphNode(vectors.vectorValue(0));
builder.addGraphNode(1, vectors);
    builder.addGraphNode(2, vectors);
// now every node has tried to attach every other node as a neighbor, but // now every node has tried to attach every other node as a neighbor, but
// some were excluded based on diversity check. // some were excluded based on diversity check.
assertLevel0Neighbors(builder.hnsw, 0, 1, 2); assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
assertLevel0Neighbors(builder.hnsw, 1, 0); assertLevel0Neighbors(builder.hnsw, 1, 0);
assertLevel0Neighbors(builder.hnsw, 2, 0); assertLevel0Neighbors(builder.hnsw, 2, 0);
builder.addGraphNode(3, vectors);
assertLevel0Neighbors(builder.hnsw, 0, 1, 2); assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
// we added 3 here // we added 3 here
assertLevel0Neighbors(builder.hnsw, 1, 0, 3); assertLevel0Neighbors(builder.hnsw, 1, 0, 3);
@ -461,7 +512,7 @@ public class TestHnswGraph extends LuceneTestCase {
assertLevel0Neighbors(builder.hnsw, 3, 1); assertLevel0Neighbors(builder.hnsw, 3, 1);
// supplant an existing neighbor // supplant an existing neighbor
builder.addGraphNode(4, vectors);
// 4 is the same distance from 0 that 2 is; we leave the existing node in place // 4 is the same distance from 0 that 2 is; we leave the existing node in place
assertLevel0Neighbors(builder.hnsw, 0, 1, 2); assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
assertLevel0Neighbors(builder.hnsw, 1, 0, 3, 4); assertLevel0Neighbors(builder.hnsw, 1, 0, 3, 4);
@ -470,7 +521,7 @@ public class TestHnswGraph extends LuceneTestCase {
assertLevel0Neighbors(builder.hnsw, 3, 1, 4); assertLevel0Neighbors(builder.hnsw, 3, 1, 4);
assertLevel0Neighbors(builder.hnsw, 4, 1, 3); assertLevel0Neighbors(builder.hnsw, 4, 1, 3);
builder.addGraphNode(5, vectors);
assertLevel0Neighbors(builder.hnsw, 0, 1, 2); assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
assertLevel0Neighbors(builder.hnsw, 1, 0, 3, 4, 5); assertLevel0Neighbors(builder.hnsw, 1, 0, 3, 4, 5);
assertLevel0Neighbors(builder.hnsw, 2, 0); assertLevel0Neighbors(builder.hnsw, 2, 0);
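The expected neighbor lists above encode HNSW's diversity heuristic: a candidate is kept as a neighbor of a node only while it is closer to that node than to any neighbor already selected for it, so near-duplicate links get pruned. A sketch of that predicate (names are illustrative; this is not Lucene's exact code):

boolean diverse(int candidate, int node, int[] selectedNeighbors) {
  float candidateToNode = similarity(candidate, node); // similarity() is a stand-in
  for (int neighbor : selectedNeighbors) {
    if (similarity(candidate, neighbor) > candidateToNode) {
      return false; // candidate is better represented by an existing neighbor
    }
  }
  return true;
}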
@ -494,29 +545,46 @@ public class TestHnswGraph extends LuceneTestCase {
public void testRandom() throws IOException { public void testRandom() throws IOException {
int size = atLeast(100); int size = atLeast(100);
int dim = atLeast(10); int dim = atLeast(10);
RandomVectorValues vectors = new RandomVectorValues(size, dim, vectorEncoding, random());
    int topK = 5;
    HnswGraphBuilder<?> builder =
        HnswGraphBuilder.create(
vectors, vectorEncoding, similarityFunction, 10, 30, random().nextLong());
OnHeapHnswGraph hnsw = builder.build(vectors); OnHeapHnswGraph hnsw = builder.build(vectors);
Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(0, size); Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(0, size);
int totalMatches = 0; int totalMatches = 0;
for (int i = 0; i < 100; i++) { for (int i = 0; i < 100; i++) {
NeighborQueue actual;
      float[] query;
BytesRef bQuery = null;
if (vectorEncoding == VectorEncoding.BYTE) {
query = randomVector8(random(), dim);
bQuery = toBytesRef(query);
} else {
query = randomVector(random(), dim);
}
actual =
HnswGraphSearcher.search( HnswGraphSearcher.search(
query,
100,
vectors,
vectorEncoding,
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
while (actual.size() > topK) { while (actual.size() > topK) {
actual.pop(); actual.pop();
} }
NeighborQueue expected = new NeighborQueue(topK, false); NeighborQueue expected = new NeighborQueue(topK, false);
for (int j = 0; j < size; j++) { for (int j = 0; j < size; j++) {
if (vectors.vectorValue(j) != null && (acceptOrds == null || acceptOrds.get(j))) { if (vectors.vectorValue(j) != null && (acceptOrds == null || acceptOrds.get(j))) {
if (vectorEncoding == VectorEncoding.BYTE) {
expected.add(j, similarityFunction.compare(bQuery, vectors.binaryValue(j)));
} else {
expected.add(j, similarityFunction.compare(query, vectors.vectorValue(j)));
}
if (expected.size() > topK) { if (expected.size() > topK) {
expected.pop(); expected.pop();
} }
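After this loop the test compares the approximate HNSW results against the exact brute-force top-K; presumably the overlap is tallied into totalMatches with a simple intersection count along these lines (the helper below is illustrative, not from the patch):

static int countMatches(int[] approximate, Set<Integer> exact) {
  int matches = 0;
  for (int node : approximate) {
    if (exact.contains(node)) {
      matches++; // each hit shared with the exact top-K counts toward recall
    }
  }
  return matches;
}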
@ -553,12 +621,14 @@ public class TestHnswGraph extends LuceneTestCase {
implements RandomAccessVectorValues, RandomAccessVectorValuesProducer { implements RandomAccessVectorValues, RandomAccessVectorValuesProducer {
private final int size; private final int size;
private final float[] value; private final float[] value;
private final BytesRef binaryValue;
int doc = -1; int doc = -1;
CircularVectorValues(int size) { CircularVectorValues(int size) {
this.size = size; this.size = size;
value = new float[2]; value = new float[2];
binaryValue = new BytesRef(new byte[2]);
} }
public CircularVectorValues copy() { public CircularVectorValues copy() {
@ -617,7 +687,11 @@ public class TestHnswGraph extends LuceneTestCase {
@Override @Override
public BytesRef binaryValue(int ord) { public BytesRef binaryValue(int ord) {
float[] vectorValue = vectorValue(ord);
for (int i = 0; i < vectorValue.length; i++) {
binaryValue.bytes[i] = (byte) (vectorValue[i] * 127);
}
return binaryValue;
} }
} }
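binaryValue here quantizes unit-circle components from [-1, 1] into [-127, 127] by scaling and truncating toward zero, so each component carries up to one byte-unit of quantization error; that is why the vector comparison below tolerates a delta of 1 for BYTE but only 1e-4 for FLOAT32. For a single component:

float component = 0.77f;
byte quantized = (byte) (component * 127);           // (byte) 97.79 -> 97
float error = Math.abs(component * 127 - quantized); // always < 1 byte unit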
@ -648,8 +722,9 @@ public class TestHnswGraph extends LuceneTestCase {
if (uDoc == NO_MORE_DOCS) { if (uDoc == NO_MORE_DOCS) {
break; break;
} }
float delta = vectorEncoding == VectorEncoding.BYTE ? 1 : 1e-4f;
assertArrayEquals(
          "vectors do not match for doc=" + uDoc, u.vectorValue(), v.vectorValue(), delta);
} }
} }
@ -657,7 +732,11 @@ public class TestHnswGraph extends LuceneTestCase {
static class RandomVectorValues extends MockVectorValues { static class RandomVectorValues extends MockVectorValues {
RandomVectorValues(int size, int dimension, Random random) { RandomVectorValues(int size, int dimension, Random random) {
super(createRandomVectors(size, dimension, null, random));
}
RandomVectorValues(int size, int dimension, VectorEncoding vectorEncoding, Random random) {
super(createRandomVectors(size, dimension, vectorEncoding, random));
} }
RandomVectorValues(RandomVectorValues other) { RandomVectorValues(RandomVectorValues other) {
@ -669,11 +748,21 @@ public class TestHnswGraph extends LuceneTestCase {
return new RandomVectorValues(this); return new RandomVectorValues(this);
} }
private static float[][] createRandomVectors(
int size, int dimension, VectorEncoding vectorEncoding, Random random) {
float[][] vectors = new float[size][]; float[][] vectors = new float[size][];
for (int offset = 0; offset < size; offset += random.nextInt(3) + 1) { for (int offset = 0; offset < size; offset += random.nextInt(3) + 1) {
vectors[offset] = randomVector(random, dimension); vectors[offset] = randomVector(random, dimension);
} }
if (vectorEncoding == VectorEncoding.BYTE) {
for (float[] vector : vectors) {
if (vector != null) {
for (int i = 0; i < vector.length; i++) {
vector[i] = (byte) (127 * vector[i]);
}
}
}
}
return vectors; return vectors;
} }
} }
@ -701,8 +790,19 @@ public class TestHnswGraph extends LuceneTestCase {
float[] vec = new float[dim]; float[] vec = new float[dim];
for (int i = 0; i < dim; i++) { for (int i = 0; i < dim; i++) {
vec[i] = random.nextFloat(); vec[i] = random.nextFloat();
if (random.nextBoolean()) {
vec[i] = -vec[i];
}
} }
VectorUtil.l2normalize(vec); VectorUtil.l2normalize(vec);
return vec; return vec;
} }
private static float[] randomVector8(Random random, int dim) {
float[] fvec = randomVector(random, dim);
for (int i = 0; i < dim; i++) {
fvec[i] *= 127;
}
return fvec;
}
} }

View File

@ -34,8 +34,10 @@ import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.index.SortedSetDocValues; import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.index.StoredFieldVisitor; import org.apache.lucene.index.StoredFieldVisitor;
import org.apache.lucene.index.Terms; import org.apache.lucene.index.Terms;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues; import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits; import org.apache.lucene.util.Bits;
import org.apache.lucene.util.Version; import org.apache.lucene.util.Version;
@ -97,6 +99,7 @@ public class TermVectorLeafReader extends LeafReader {
0, 0,
0, 0,
0, 0,
VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN, VectorSimilarityFunction.EUCLIDEAN,
false); false);
fieldInfos = new FieldInfos(new FieldInfo[] {fieldInfo}); fieldInfos = new FieldInfos(new FieldInfo[] {fieldInfo});
@ -166,6 +169,12 @@ public class TermVectorLeafReader extends LeafReader {
return null; return null;
} }
@Override
public TopDocs searchNearestVectorsExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
return null;
}
@Override @Override
public void checkIntegrity() throws IOException {} public void checkIntegrity() throws IOException {}

View File

@ -37,6 +37,7 @@ import org.apache.lucene.document.FieldType;
import org.apache.lucene.index.*; import org.apache.lucene.index.*;
import org.apache.lucene.search.Collector; import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager; import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query; import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorable; import org.apache.lucene.search.Scorable;
@ -514,6 +515,7 @@ public class MemoryIndex {
fieldType.pointIndexDimensionCount(), fieldType.pointIndexDimensionCount(),
fieldType.pointNumBytes(), fieldType.pointNumBytes(),
fieldType.vectorDimension(), fieldType.vectorDimension(),
fieldType.vectorEncoding(),
fieldType.vectorSimilarityFunction(), fieldType.vectorSimilarityFunction(),
false); false);
} }
@ -546,6 +548,7 @@ public class MemoryIndex {
info.fieldInfo.getPointIndexDimensionCount(), info.fieldInfo.getPointIndexDimensionCount(),
info.fieldInfo.getPointNumBytes(), info.fieldInfo.getPointNumBytes(),
info.fieldInfo.getVectorDimension(), info.fieldInfo.getVectorDimension(),
info.fieldInfo.getVectorEncoding(),
info.fieldInfo.getVectorSimilarityFunction(), info.fieldInfo.getVectorSimilarityFunction(),
info.fieldInfo.isSoftDeletesField()); info.fieldInfo.isSoftDeletesField());
} else if (existingDocValuesType != docValuesType) { } else if (existingDocValuesType != docValuesType) {
@ -1371,6 +1374,12 @@ public class MemoryIndex {
return null; return null;
} }
@Override
public TopDocs searchNearestVectorsExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
return null;
}
@Override @Override
public void checkIntegrity() throws IOException { public void checkIntegrity() throws IOException {
// no-op // no-op

View File

@ -29,6 +29,7 @@ import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter; import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorValues; import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.apache.lucene.tests.util.TestUtil; import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.Bits; import org.apache.lucene.util.Bits;
@ -61,7 +62,7 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
} }
@Override @Override
public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
return delegate.addField(fieldInfo); return delegate.addField(fieldInfo);
} }
@ -131,6 +132,18 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
return hits; return hits;
} }
@Override
public TopDocs searchExhaustively(
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
FieldInfo fi = fis.fieldInfo(field);
assert fi != null && fi.getVectorDimension() > 0;
assert acceptDocs != null;
TopDocs hits = delegate.searchExhaustively(field, target, k, acceptDocs);
assert hits != null;
assert hits.scoreDocs.length <= k;
return hits;
}
@Override @Override
public void close() throws IOException { public void close() throws IOException {
delegate.close(); delegate.close();
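The assertions above pin down the contract of the new exhaustive-search path: the field must exist and hold vectors, acceptDocs must be non-null, and the result is non-null with at most k hits. Semantically this path is a brute-force scan that scores every accepted vector, evidently the fallback for when the approximate graph search would visit too many candidates (note the visitedLimit parameter on the approximate path). A minimal sketch of such a scan in isolation, assuming a FLOAT32 field and skipping the acceptDocs intersection for brevity; BruteForceKnn is an illustrative helper, not code from this change:

import java.io.IOException;
import java.util.PriorityQueue;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;

class BruteForceKnn {
  // Score every vector in the field and keep the k best; a min-heap holds the current top-k.
  static TopDocs bruteForceSearch(
      VectorValues values, VectorSimilarityFunction sim, float[] target, int k)
      throws IOException {
    PriorityQueue<ScoreDoc> topK =
        new PriorityQueue<>(k, (a, b) -> Float.compare(a.score, b.score));
    for (int doc = values.nextDoc();
        doc != DocIdSetIterator.NO_MORE_DOCS;
        doc = values.nextDoc()) {
      float score = sim.compare(values.vectorValue(), target);
      if (topK.size() < k) {
        topK.add(new ScoreDoc(doc, score));
      } else if (score > topK.peek().score) {
        topK.poll();
        topK.add(new ScoreDoc(doc, score));
      }
    }
    ScoreDoc[] hits = new ScoreDoc[topK.size()];
    for (int i = hits.length - 1; i >= 0; i--) {
      hits[i] = topK.poll(); // heap drains ascending, so fill from the back for descending order
    }
    return new TopDocs(new TotalHits(hits.length, TotalHits.Relation.EQUAL_TO), hits);
  }
}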

View File

@@ -36,6 +36,7 @@ import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.IndexableFieldType;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.index.SegmentInfo;
+import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.internal.tests.IndexPackageAccess;
@@ -305,6 +306,7 @@ public abstract class BaseFieldInfoFormatTestCase extends BaseIndexFileFormatTes
fieldType.pointIndexDimensionCount(),
fieldType.pointNumBytes(),
fieldType.vectorDimension(),
+fieldType.vectorEncoding(),
fieldType.vectorSimilarityFunction(),
field.equals(softDeletesField));
addAttributes(fi);
@@ -353,7 +355,8 @@ public abstract class BaseFieldInfoFormatTestCase extends BaseIndexFileFormatTes
int dimension = 1 + r.nextInt(VectorValues.MAX_DIMENSIONS);
VectorSimilarityFunction similarityFunction =
RandomPicks.randomFrom(r, VectorSimilarityFunction.values());
-type.setVectorDimensionsAndSimilarityFunction(dimension, similarityFunction);
+VectorEncoding encoding = RandomPicks.randomFrom(r, VectorEncoding.values());
+type.setVectorAttributes(dimension, encoding, similarityFunction);
}
return type;
@@ -422,6 +425,7 @@ public abstract class BaseFieldInfoFormatTestCase extends BaseIndexFileFormatTes
0,
0,
0,
+VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN,
false);
}

View File

@@ -360,6 +360,7 @@ abstract class BaseIndexFileFormatTestCase extends LuceneTestCase {
proto.getPointIndexDimensionCount(),
proto.getPointNumBytes(),
proto.getVectorDimension(),
+proto.getVectorEncoding(),
proto.getVectorSimilarityFunction(),
proto.isSoftDeletesField());

View File

@@ -41,6 +41,7 @@ import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
+import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.Sort;
@@ -51,8 +52,10 @@ import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.VectorUtil;
+import org.junit.Before;

/**
* Base class aiming at testing {@link KnnVectorsFormat vectors formats}. To test a new format, all
@@ -63,9 +66,21 @@ import org.apache.lucene.util.VectorUtil;
*/
public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTestCase {

+private VectorEncoding vectorEncoding;
+private VectorSimilarityFunction similarityFunction;
+
+@Before
+public void init() {
+vectorEncoding = randomVectorEncoding();
+similarityFunction = randomSimilarity();
+}
+
@Override
protected void addRandomFields(Document doc) {
-doc.add(new KnnVectorField("v2", randomVector(30), VectorSimilarityFunction.EUCLIDEAN));
+switch (vectorEncoding) {
+case BYTE -> doc.add(new KnnVectorField("v2", randomVector8(30), similarityFunction));
+case FLOAT32 -> doc.add(new KnnVectorField("v2", randomVector(30), similarityFunction));
+}
}

public void testFieldConstructor() {
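With BYTE encoding, KnnVectorField takes a BytesRef whose bytes are the signed per-dimension values, as the switch above exercises; with FLOAT32 it takes a float[] as before. A sketch of building a document around one byte vector (the class name, field name, and values are illustrative):

import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.BytesRef;

class ByteVectorDocs {
  // Index one 4-dimensional byte vector; each byte is a signed component in [-128, 127].
  static Document docWithByteVector() {
    byte[] vector = {12, -34, 56, 127};
    Document doc = new Document();
    doc.add(new KnnVectorField("v2", new BytesRef(vector), VectorSimilarityFunction.DOT_PRODUCT));
    return doc;
  }
}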
@@ -133,8 +148,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
IllegalArgumentException expected =
expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc2));
String errMsg =
-"cannot change field \"f\" from vector dimension=4, vector similarity function=DOT_PRODUCT "
-+ "to inconsistent vector dimension=3, vector similarity function=DOT_PRODUCT";
+"cannot change field \"f\" from vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
++ "to inconsistent vector dimension=3, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT";
assertEquals(errMsg, expected.getMessage());
}
}
@@ -170,8 +185,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
IllegalArgumentException expected =
expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc2));
String errMsg =
-"cannot change field \"f\" from vector dimension=4, vector similarity function=DOT_PRODUCT "
-+ "to inconsistent vector dimension=4, vector similarity function=EUCLIDEAN";
+"cannot change field \"f\" from vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
++ "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=EUCLIDEAN";
assertEquals(errMsg, expected.getMessage());
}
}
@@ -190,8 +205,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
IllegalArgumentException expected =
expectThrows(IllegalArgumentException.class, () -> w2.addDocument(doc2));
assertEquals(
-"cannot change field \"f\" from vector dimension=4, vector similarity function=DOT_PRODUCT "
-+ "to inconsistent vector dimension=1, vector similarity function=DOT_PRODUCT",
+"cannot change field \"f\" from vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
++ "to inconsistent vector dimension=1, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT",
expected.getMessage());
}
}
@@ -211,8 +226,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
IllegalArgumentException expected =
expectThrows(IllegalArgumentException.class, () -> w2.addDocument(doc2));
assertEquals(
-"cannot change field \"f\" from vector dimension=4, vector similarity function=DOT_PRODUCT "
-+ "to inconsistent vector dimension=4, vector similarity function=EUCLIDEAN",
+"cannot change field \"f\" from vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
++ "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=EUCLIDEAN",
expected.getMessage());
}
}
@@ -311,8 +326,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
expectThrows(
IllegalArgumentException.class, () -> w2.addIndexes(new Directory[] {dir}));
assertEquals(
-"cannot change field \"f\" from vector dimension=5, vector similarity function=DOT_PRODUCT "
-+ "to inconsistent vector dimension=4, vector similarity function=DOT_PRODUCT",
+"cannot change field \"f\" from vector dimension=5, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
++ "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT",
expected.getMessage());
}
}
@@ -333,8 +348,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
IllegalArgumentException expected =
expectThrows(IllegalArgumentException.class, () -> w2.addIndexes(dir));
assertEquals(
-"cannot change field \"f\" from vector dimension=4, vector similarity function=EUCLIDEAN "
-+ "to inconsistent vector dimension=4, vector similarity function=DOT_PRODUCT",
+"cannot change field \"f\" from vector dimension=4, vector encoding=FLOAT32, vector similarity function=EUCLIDEAN "
++ "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT",
expected.getMessage());
}
}
@@ -358,8 +373,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
IllegalArgumentException.class,
() -> w2.addIndexes(new CodecReader[] {(CodecReader) getOnlyLeafReader(r)}));
assertEquals(
-"cannot change field \"f\" from vector dimension=5, vector similarity function=DOT_PRODUCT "
-+ "to inconsistent vector dimension=4, vector similarity function=DOT_PRODUCT",
+"cannot change field \"f\" from vector dimension=5, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
++ "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT",
expected.getMessage());
}
}
@@ -384,8 +399,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
IllegalArgumentException.class,
() -> w2.addIndexes(new CodecReader[] {(CodecReader) getOnlyLeafReader(r)}));
assertEquals(
-"cannot change field \"f\" from vector dimension=4, vector similarity function=EUCLIDEAN "
-+ "to inconsistent vector dimension=4, vector similarity function=DOT_PRODUCT",
+"cannot change field \"f\" from vector dimension=4, vector encoding=FLOAT32, vector similarity function=EUCLIDEAN "
++ "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT",
expected.getMessage());
}
}
@@ -408,8 +423,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
IllegalArgumentException expected =
expectThrows(IllegalArgumentException.class, () -> TestUtil.addIndexesSlowly(w2, r));
assertEquals(
-"cannot change field \"f\" from vector dimension=5, vector similarity function=DOT_PRODUCT "
-+ "to inconsistent vector dimension=4, vector similarity function=DOT_PRODUCT",
+"cannot change field \"f\" from vector dimension=5, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
++ "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT",
expected.getMessage());
}
}
@@ -432,8 +447,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
IllegalArgumentException expected =
expectThrows(IllegalArgumentException.class, () -> TestUtil.addIndexesSlowly(w2, r));
assertEquals(
-"cannot change field \"f\" from vector dimension=4, vector similarity function=EUCLIDEAN "
-+ "to inconsistent vector dimension=4, vector similarity function=DOT_PRODUCT",
+"cannot change field \"f\" from vector dimension=4, vector encoding=FLOAT32, vector similarity function=EUCLIDEAN "
++ "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT",
expected.getMessage());
}
}
@@ -596,12 +611,12 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
int[] fieldDocCounts = new int[numFields];
double[] fieldTotals = new double[numFields];
int[] fieldDims = new int[numFields];
-VectorSimilarityFunction[] fieldSearchStrategies = new VectorSimilarityFunction[numFields];
+VectorSimilarityFunction[] fieldSimilarityFunctions = new VectorSimilarityFunction[numFields];
+VectorEncoding[] fieldVectorEncodings = new VectorEncoding[numFields];
for (int i = 0; i < numFields; i++) {
fieldDims[i] = random().nextInt(20) + 1;
-fieldSearchStrategies[i] =
-VectorSimilarityFunction.values()[
-random().nextInt(VectorSimilarityFunction.values().length)];
+fieldSimilarityFunctions[i] = randomSimilarity();
+fieldVectorEncodings[i] = randomVectorEncoding();
}
try (Directory dir = newDirectory();
RandomIndexWriter w = new RandomIndexWriter(random(), dir, newIndexWriterConfig())) {
@@ -610,15 +625,23 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
for (int field = 0; field < numFields; field++) {
String fieldName = "int" + field;
if (random().nextInt(100) == 17) {
-float[] v = randomVector(fieldDims[field]);
-doc.add(new KnnVectorField(fieldName, v, fieldSearchStrategies[field]));
+switch (fieldVectorEncodings[field]) {
+case BYTE -> {
+BytesRef b = randomVector8(fieldDims[field]);
+doc.add(new KnnVectorField(fieldName, b, fieldSimilarityFunctions[field]));
+fieldTotals[field] += b.bytes[b.offset];
+}
+case FLOAT32 -> {
+float[] v = randomVector(fieldDims[field]);
+doc.add(new KnnVectorField(fieldName, v, fieldSimilarityFunctions[field]));
+fieldTotals[field] += v[0];
+}
+}
fieldDocCounts[field]++;
-fieldTotals[field] += v[0];
}
}
w.addDocument(doc);
}
try (IndexReader r = w.getReader()) {
for (int field = 0; field < numFields; field++) {
int docCount = 0;
@@ -634,12 +657,29 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
}
}
assertEquals(fieldDocCounts[field], docCount);
-assertEquals(fieldTotals[field], checksum, 1e-5);
+// Account for quantization done when indexing fields w/BYTE encoding
+double delta = fieldVectorEncodings[field] == VectorEncoding.BYTE ? numDocs * 0.01 : 1e-5;
+assertEquals(fieldTotals[field], checksum, delta);
}
}
}
}

+private VectorSimilarityFunction randomSimilarity() {
+return VectorSimilarityFunction.values()[
+random().nextInt(VectorSimilarityFunction.values().length)];
+}
+
+private VectorEncoding randomVectorEncoding() {
+Codec codec = getCodec();
+if (codec.knnVectorsFormat().currentVersion()
+>= Codec.forName("Lucene94").knnVectorsFormat().currentVersion()) {
+return VectorEncoding.values()[random().nextInt(VectorEncoding.values().length)];
+} else {
+return VectorEncoding.FLOAT32;
+}
+}
+
public void testIndexedValueNotAliased() throws Exception {
// We copy indexed values (as for BinaryDocValues) so the input float[] can be reused across
// calls to IndexWriter.addDocument.
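Two details in the hunk above are worth spelling out. First, the widened tolerance: the test's float components evidently lie in [-1, 1] (otherwise the byte cast in randomVector8 would overflow), and quantizing such a component to a byte by scaling with 127 moves it by at most roughly 1/127, about 0.008. Since the checksum sums the first component over at most numDocs documents, numDocs * 0.01 comfortably bounds the accumulated quantization error, while FLOAT32 fields keep the tight 1e-5 bound. Second, randomVectorEncoding() only draws BYTE when the codec under test reports a vector-format version at least that of Lucene94, because older formats can only store FLOAT32 vectors.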
@@ -742,7 +782,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
assertEquals(3, vectorValues3.dimension());
assertEquals(1, vectorValues3.size());
vectorValues3.nextDoc();
-assertEquals(1f, vectorValues3.vectorValue()[0], 0);
+assertEquals(1f, vectorValues3.vectorValue()[0], 0.1);
assertEquals(NO_MORE_DOCS, vectorValues3.nextDoc());
}
}
@@ -775,9 +815,9 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
if (random().nextBoolean() && values[i] != null) {
// sometimes use a shared scratch array
System.arraycopy(values[i], 0, scratch, 0, scratch.length);
-add(iw, fieldName, i, scratch, VectorSimilarityFunction.EUCLIDEAN);
+add(iw, fieldName, i, scratch, similarityFunction);
} else {
-add(iw, fieldName, i, values[i], VectorSimilarityFunction.EUCLIDEAN);
+add(iw, fieldName, i, values[i], similarityFunction);
}
if (random().nextInt(10) == 2) {
// sometimes delete a random document
@@ -898,7 +938,6 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
int numDoc = atLeast(100);
int dimension = atLeast(10);
float[][] id2value = new float[numDoc][];
-int[] id2ord = new int[numDoc];
for (int i = 0; i < numDoc; i++) {
int id = random().nextInt(numDoc);
float[] value;
@@ -909,7 +948,6 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
value = null;
}
id2value[id] = value;
-id2ord[id] = i;
add(iw, fieldName, id, value, VectorSimilarityFunction.EUCLIDEAN);
}
try (IndexReader reader = DirectoryReader.open(iw)) {
@@ -1007,6 +1045,15 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
return v;
}

+private BytesRef randomVector8(int dim) {
+float[] v = randomVector(dim);
+byte[] b = new byte[dim];
+for (int i = 0; i < dim; i++) {
+b[i] = (byte) (v[i] * 127);
+}
+return new BytesRef(b);
+}
+
public void testCheckIndexIncludesVectors() throws Exception {
try (Directory dir = newDirectory()) {
try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
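randomVector8 above is the test's quantizer: it maps float components in [-1, 1] to signed bytes by scaling with 127 and truncating toward zero. A standalone round-trip sketch of the same scheme, which makes the per-component error of at most about 1/127 easy to verify (the class and method names are illustrative):

import org.apache.lucene.util.BytesRef;

class Quantizer {
  // Scale float components in [-1, 1] to signed bytes in [-127, 127].
  static BytesRef toBytes(float[] v) {
    byte[] b = new byte[v.length];
    for (int i = 0; i < v.length; i++) {
      b[i] = (byte) (v[i] * 127);
    }
    return new BytesRef(b);
  }

  // Approximate inverse; comparing with the original vector shows the quantization error.
  static float[] toFloats(BytesRef b) {
    float[] v = new float[b.length];
    for (int i = 0; i < b.length; i++) {
      v[i] = b.bytes[b.offset + i] / 127f;
    }
    return v;
  }
}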
@@ -1041,6 +1088,14 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
assertEquals(3, VectorSimilarityFunction.values().length);
}

+public void testVectorEncodingOrdinals() {
+// make sure we don't accidentally mess up vector encoding identifiers by re-ordering their
+// enumerators
+assertEquals(0, VectorEncoding.BYTE.ordinal());
+assertEquals(1, VectorEncoding.FLOAT32.ordinal());
+assertEquals(2, VectorEncoding.values().length);
+}
+
public void testAdvance() throws Exception {
try (Directory dir = newDirectory()) {
try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
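The ordinal test above guards the on-disk format: the encoding is presumably persisted as its enum ordinal, so reordering or inserting enumerators would silently remap the encodings recorded in existing indexes. A hedged sketch of ordinal-based round-tripping (EncodingCodec and its methods are illustrative, not the actual codec code):

import java.io.IOException;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.DataOutput;

class EncodingCodec {
  static void writeEncoding(DataOutput out, VectorEncoding encoding) throws IOException {
    out.writeVInt(encoding.ordinal()); // the on-disk value is just the ordinal
  }

  static VectorEncoding readEncoding(DataInput in) throws IOException {
    int ord = in.readVInt();
    if (ord < 0 || ord >= VectorEncoding.values().length) {
      throw new IOException("invalid vector encoding ordinal: " + ord);
    }
    return VectorEncoding.values()[ord];
  }
}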
@@ -1091,10 +1146,6 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
public void testVectorValuesReportCorrectDocs() throws Exception {
final int numDocs = atLeast(1000);
final int dim = random().nextInt(20) + 1;
-final VectorSimilarityFunction similarityFunction =
-VectorSimilarityFunction.values()[
-random().nextInt(VectorSimilarityFunction.values().length)];
double fieldValuesCheckSum = 0;
int fieldDocCount = 0;
long fieldSumDocIDs = 0;
@@ -1106,9 +1157,18 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
int docID = random().nextInt(numDocs);
doc.add(new StoredField("id", docID));
if (random().nextInt(4) == 3) {
-float[] vector = randomVector(dim);
-doc.add(new KnnVectorField("knn_vector", vector, similarityFunction));
-fieldValuesCheckSum += vector[0];
+switch (vectorEncoding) {
+case BYTE -> {
+BytesRef b = randomVector8(dim);
+fieldValuesCheckSum += b.bytes[b.offset];
+doc.add(new KnnVectorField("knn_vector", b, similarityFunction));
+}
+case FLOAT32 -> {
+float[] v = randomVector(dim);
+fieldValuesCheckSum += v[0];
+doc.add(new KnnVectorField("knn_vector", v, similarityFunction));
+}
+}
fieldDocCount++;
fieldSumDocIDs += docID;
}
@@ -1134,7 +1194,10 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
}
}
}
-assertEquals(fieldValuesCheckSum, checksum, 1e-3);
+assertEquals(
+fieldValuesCheckSum,
+checksum,
+vectorEncoding == VectorEncoding.BYTE ? numDocs * 0.2 : 1e-5);
assertEquals(fieldDocCount, docCount);
assertEquals(fieldSumDocIDs, sumDocIds);
}

View File

@@ -40,6 +40,7 @@ import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.index.StoredFieldVisitor;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.VectorValues;
+import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
@@ -228,6 +229,12 @@ class MergeReaderWrapper extends LeafReader {
return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
}

+@Override
+public TopDocs searchNearestVectorsExhaustively(
+String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
+return in.searchNearestVectorsExhaustively(field, target, k, acceptDocs);
+}
+
@Override
public int numDocs() {
return in.numDocs();

View File

@@ -88,6 +88,7 @@ public class MismatchedLeafReader extends FilterLeafReader {
oldInfo.getPointIndexDimensionCount(), // index dimension count
oldInfo.getPointNumBytes(), // dimension numBytes
oldInfo.getVectorDimension(), // number of dimensions of the field's vector
+oldInfo.getVectorEncoding(), // numeric type of vector samples
// distance function for calculating similarity of the field's vector
oldInfo.getVectorSimilarityFunction(),
oldInfo.isSoftDeletesField()); // used as soft-deletes field

View File

@@ -62,6 +62,7 @@ import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.TermState;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
+import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.tests.IndexPackageAccess;
import org.apache.lucene.internal.tests.TestSecrets;
@@ -163,6 +164,7 @@ public class RandomPostingsTester {
0,
0,
0,
+VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN,
false);
fieldUpto++;
@@ -734,6 +736,7 @@ public class RandomPostingsTester {
0,
0,
0,
+VectorEncoding.FLOAT32,
VectorSimilarityFunction.EUCLIDEAN,
false);
}

View File

@@ -233,6 +233,12 @@ public class QueryUtils {
return null;
}

+@Override
+public TopDocs searchNearestVectorsExhaustively(
+String field, float[] target, int k, DocIdSetIterator acceptDocs) {
+return null;
+}
+
@Override
public FieldInfos getFieldInfos() {
return FieldInfos.EMPTY;